In [1]:
import functools
import os
from typing import Callable, Iterable, Sequence, Union

import numpy as np
import pytorch_lightning as pl
import torch
from lqsvg import analysis, data, estimator
from lqsvg.envs import lqr
from lqsvg.envs.lqr.generators import LQGGenerator
from lqsvg.torch.nn.value import QuadQValue
from torch import Tensor, autograd, nn

from model_training_static_policy import DataModule, DataSpec, Experiment, make_modules

In [2]:
def obs_only(sampler: data.StateDynamics) -> Callable[[Tensor, Tensor], Tensor]:
    """Adapt a transition sampler so it returns only the sampled next state.

    Args:
        sampler: callable ``(obs, act) -> outputs`` whose first output is the
            next observation; the remaining outputs (the likelihood) are dropped.

    Returns:
        A function ``(obs, act) -> next_obs`` whose result has its dimension
        names removed (a plain, unnamed tensor).
    """

    def _sample_obs(obs: Tensor, act: Tensor) -> Tensor:
        outputs = sampler(obs, act)
        next_obs = outputs[0]
        return next_obs.rename(None)

    return _sample_obs

### Training

In [3]:
# Experiment hyperparameters; `seed` also seeds the LQG generator in the next cell.
config = dict(
    wandb=dict(name="Debug", mode="offline"),
    learning_rate=1e-3,
    weight_decay=1e-4,
    seed=124,
    env_config=dict(
        n_state=2,
        n_ctrl=2,
        horizon=50,
        passive_eigval_range=(0.9, 1.1),
    ),
    model=dict(type="linear"),
    pred_horizon=4,
    zero_q=False,
    datamodule=dict(
        trajectories=2000,
        train_batch_size=128,
        val_loss_batch_size=128,
        val_grad_batch_size=256,
        seq_len=4,
    ),
    trainer=dict(
        max_epochs=50,
        weights_summary="full",
    ),
)
experiment = Experiment(config)

[34m[1mwandb[0m: W&B syncing is set to `offline` in this directory.  Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.


In [4]:
# Sample a stationary, controllable LQG from a seeded generator and build the
# environment, policy and model modules from it.
generator = LQGGenerator(
    **experiment.hparams.env_config,
    stationary=True,
    controllable=True,
    rng=np.random.default_rng(experiment.hparams.seed),
)
lqg, policy, model = make_modules(generator, experiment.hparams)

In [5]:
# Same seed as the previous cell, but a wider passive eigenvalue range; the
# cells below compare the resulting dynamics and reward matrices.
lqg2, policy2, model2 = make_modules(
    LQGGenerator(
        **dict(experiment.hparams.env_config, passive_eigval_range=(0.5, 1.5)),
        stationary=True,
        controllable=True,
        rng=np.random.default_rng(experiment.hparams.seed),
    ),
    experiment.hparams,
)

In [6]:
# Transition kernel F of `lqg` at index 0 of the "H" dimension (presumably the first timestep).
lqg.trans.standard_form().F.select("H", 0)

tensor([[ 1.0689,  0.0089,  0.9776,  0.9827],
        [ 0.0089,  1.0637,  0.2107, -0.1850]], names=('R', 'C'))

In [7]:
# Same kernel for `lqg2` (wider eigenvalue range); compare with the cell above.
lqg2.trans.standard_form().F.select("H", 0)

tensor([[ 1.3447,  0.0446,  0.9776,  0.9827],
        [ 0.0446,  1.3186,  0.2107, -0.1850]], names=('R', 'C'))

In [8]:
# Quadratic reward matrix C of `lqg` at index 0 of the "H" dimension.
lqg.reward.standard_form().C.select("H", 0)

tensor([[ 0.4530, -0.2178,  0.0000,  0.0000],
        [-0.2178,  3.5398,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.3662, -0.0420],
        [ 0.0000,  0.0000, -0.0420,  0.7659]], names=('R', 'C'))

In [9]:
# Reward matrix C of `lqg2`; expected to match `lqg`'s since both use the same
# seed (passive_eigval_range presumably affects only the dynamics — confirm).
lqg2.reward.standard_form().C.select("H", 0)

tensor([[ 0.4530, -0.2178,  0.0000,  0.0000],
        [-0.2178,  3.5398,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.3662, -0.0420],
        [ 0.0000,  0.0000, -0.0420,  0.7659]], names=('R', 'C'))

In [10]:
# Eigendecomposition of the reward matrix C (e.g. to inspect the sign/scale of its eigenvalues).
torch.linalg.eigh(lqg.reward.standard_form().C.select("H", 0).rename(None))

torch.return_types.linalg_eigh(
eigenvalues=tensor([0.4377, 0.7630, 1.3691, 3.5550]),
eigenvectors=tensor([[-0.9975, -0.0000,  0.0000, -0.0700],
        [-0.0700, -0.0000,  0.0000,  0.9975],
        [-0.0000, -0.0694, -0.9976,  0.0000],
        [-0.0000, -0.9976,  0.0694,  0.0000]]))

In [12]:
# Train the model, or reuse weights previously cached on disk.
# The cache file name includes the seed so that changing the seed does not
# silently load weights trained on a different LQG. NOTE: other hyperparameters
# are not part of the file name — delete the file after changing them.
SAVE_PATH = f"state_dict{experiment.hparams.seed}.pt"

datamodule = DataModule(lqg, policy, DataSpec(**experiment.hparams.datamodule))
trainer = pl.Trainer(
    default_root_dir=experiment.run.dir,
    logger=False,
    callbacks=[pl.callbacks.EarlyStopping("val/loss")],
    num_sanity_val_steps=0,  # avoid evaluating gradients in the beginning?
    checkpoint_callback=False,  # don't save last model checkpoint
    **experiment.hparams.trainer,
)

# Validation metrics before loading/fitting, for reference.
trainer.validate(model, datamodule=datamodule)
if os.path.exists(SAVE_PATH):
    # Load onto CPU regardless of the device the weights were saved from;
    # load_state_dict then copies them into the model's existing parameters.
    model.model.load_state_dict(torch.load(SAVE_PATH, map_location="cpu"))
else:
    trainer.fit(model, datamodule=datamodule)
    torch.save(model.model.state_dict(), SAVE_PATH)
final_eval = trainer.test(model, datamodule=datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Validating: 0it [00:00, ?it/s]



   | Name                           | Type                    | Params
----------------------------------------------------------------------------
0  | _lqg                           | LQGModule               | 1.8 K 
1  | _lqg.trans                     | LinearDynamicsModule    | 800   
2  | _lqg.trans.params              | LinearNormalParams      | 800   
3  | _lqg.trans.params.cov_cholesky | CholeskyFactor          | 300   
4  | _lqg.trans.dist                | TVMultivariateNormal    | 0     
5  | _lqg.reward                    | QuadraticReward         | 1.0 K 
6  | _lqg.init                      | InitStateModule         | 8     
7  | _lqg.init.scale_tril           | CholeskyFactor          | 6     
8  | _lqg.init.dist                 | TVMultivariateNormal    | 0     
9  | _policy                        | TVLinearPolicy          | 300   
10 | _policy.encoder                | Identity                | 0     
11 | _policy.action_linear          | TVLinearFeedback        | 300   

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val/4/grad_acc': -0.3436862826347351,
 'val/4/relative_value_err': 0.058996159583330154,
 'val/4/relative_vval_err': 1.0137571096420288,
 'val/empirical_kl': 6.437851428985596,
 'val/loss': 4.444287300109863}
--------------------------------------------------------------------------------
DATALOADER:1 VALIDATE RESULTS
{'val/4/grad_acc': -0.3436862826347351,
 'val/4/relative_value_err': 0.058996159583330154,
 'val/4/relative_vval_err': 1.0137571096420288,
 'val/empirical_kl': 6.437851428985596,
 'val/loss': 4.444287300109863}
--------------------------------------------------------------------------------
Training: -1it [00:00, ?it/s]


  rank_zero_warn(


Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]


  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/4/grad_acc': -0.8137210011482239,
 'test/4/relative_value_err': 0.47244513034820557,
 'test/4/relative_vval_err': 0.011561265215277672,
 'test/empirical_kl': -0.0005883624544367194,
 'test/loss': 2.8502163887023926}
--------------------------------------------------------------------------------
DATALOADER:1 TEST RESULTS
{'test/4/grad_acc': -0.8137210011482239,
 'test/4/relative_value_err': 0.47244513034820557,
 'test/4/relative_vval_err': 0.011561265215277672,
 'test/empirical_kl': -0.0005883624544367194,
 'test/loss': 2.8502163887023926}
--------------------------------------------------------------------------------


### Get the validation data

In [13]:
# Keep the second validation dataloader and discard the first (presumably the
# second is the gradient-evaluation loader, cf. val_grad_batch_size — TODO confirm).
_, val_loader = datamodule.val_dataloader()

In [14]:
# First batch of the loader; each batch is a 1-tuple holding the observation tensor.
(obs,) = next(iter(val_loader))

In [15]:
# Name the dimensions: "B" for batch and "R" (presumably the observation-feature axis).
obs = obs.refine_names("B", "R")

### Get the true value and SVG

In [16]:
# Ground-truth value, and the true SVG packed in the same (K, k) container type as the policy.
val = model.true_val
svg = lqr.Linear(model.true_svg_K, model.true_svg_k)

In [17]:
# True gradient w.r.t. the feedback gain K, first entry along its leading dimension.
svg.K[0]

tensor([[ 0.8923, -1.0159],
        [ 0.5689,  1.1825]])

### Evaluate MAAC(K) with the true model

In [18]:
# MAAC(K) estimator built from the *true* environment: the true transition
# sampler, the true reward and the exact quadratic Q-function of the current
# policy (computed from the policy's and LQG's standard forms). This serves as
# the reference estimator for the learned-model estimates below.
maac_true = estimator.maac_estimator(
    policy,
    data.markovian_state_sampler(lqg.trans, lqg.trans.rsample),
    lqg.reward,
    QuadQValue.from_policy(
        policy.standard_form(), lqg.trans.standard_form(), lqg.reward.standard_form()
    ),
)

In [19]:
# Value and SVG estimates from the true-model MAAC estimator, using the
# configured prediction horizon rather than a duplicated hard-coded 4.
mc_val, mc_svg = maac_true(obs, experiment.hparams.pred_horizon)

In [20]:
# Value estimate from the true-model MAAC estimator.
mc_val

tensor(-169.7705, grad_fn=<MeanBackward0>)

In [21]:
# Estimated gradient w.r.t. K (first entry); compare with `svg.K[0]` above.
mc_svg.K[0]

tensor([[ 0.0710, -0.1295],
        [-0.0109,  0.1175]])

In [23]:
# Cosine similarity between the true SVG and the true-model MAAC estimate,
# shown as the cell's rich output rather than printed to stdout.
analysis.cosine_similarity(svg, mc_svg)

tensor(0.7966)


### Evaluate MAAC(K) with the learned model

In [24]:
# Cosine similarity between the true SVG and the learned-model estimate for
# rollout lengths k = 0, ..., 7.
for k in range(8):
    svg_k = model.estimator(obs, k)[1]
    print(analysis.cosine_similarity(svg, svg_k))

tensor(0.7778)
tensor(-0.5398)
tensor(-0.7367)
tensor(-0.8126)
tensor(-0.8227)
tensor(-0.8463)
tensor(-0.8810)
tensor(-0.8550)


### Compare model predictions

In [25]:
# Next-state samplers (obs, act) -> next_obs for the true dynamics and for the
# learned model, built via `rsample` (presumably reparameterized, so gradients
# flow through the samples) and stripped of the likelihood output by obs_only.
state_action_dynamics = obs_only(
    data.markovian_state_sampler(lqg.trans, lqg.trans.rsample)
)
state_action_model = obs_only(
    data.markovian_state_sampler(model.model, model.model.dist.rsample)
)

In [26]:
def state_dynamics(obs: Tensor) -> Tensor:
    """Sample the next state from the true dynamics when acting with `policy`."""
    action = policy(obs)
    return state_action_dynamics(obs, action)

def state_model(obs: Tensor) -> Tensor:
    """Sample the next state from the learned model when acting with `policy`."""
    action = policy(obs)
    return state_action_model(obs, action)

In [27]:
def state_mean_dynamics(obs: Tensor) -> Tensor:
    """Return the "loc" (mean) of the true next-state distribution under `policy`."""
    dist_params = lqg.trans(obs, policy(obs))
    return dist_params["loc"]

def state_mean_model(obs: Tensor) -> Tensor:
    """Return the "loc" (mean) of the learned model's next-state distribution under `policy`."""
    dist_params = model.model(obs, policy(obs))
    return dist_params["loc"]

In [28]:
# Mean squared error between the true and model next-state means on the batch
# (dimension names are dropped before computing the loss).
nn.functional.mse_loss(
    state_mean_dynamics(obs).rename(None),
    state_mean_model(obs).rename(None),
)

tensor(8.3321e-05, grad_fn=<MseLossBackward>)

### Compare model Jacobians

In [29]:
# Jacobians of the true next-state sampler w.r.t. state and action, evaluated
# at the first observation of the batch and its policy action.
true_jac_s, true_jac_a = autograd.functional.jacobian(
    state_action_dynamics, (obs.select("B", 0), policy(obs.select("B", 0)))
)

In [30]:
# The stacked [state | action] Jacobian should match the dynamics kernel F.
# The observation has an extra trailing feature (all-zero row/column in the
# output; presumably the time index), which F does not include.
print("State-action Jacobian:\n", torch.cat((true_jac_s, true_jac_a), dim=-1))
print("Dynamics kernel:\n", lqg.trans.standard_form().F.select("H", 0))

State-action Jacobian:
 tensor([[ 1.0689,  0.0089,  0.0000,  0.9776,  0.9827],
        [ 0.0089,  1.0637,  0.0000,  0.2107, -0.1850],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])
Dynamics kernel:
 tensor([[ 1.0689,  0.0089,  0.9776,  0.9827],
        [ 0.0089,  1.0637,  0.2107, -0.1850]], names=('R', 'C'))


In [31]:
# Same Jacobians for the learned model, at the same state/action point.
jac_s, jac_a = autograd.functional.jacobian(
    state_action_model,
    (obs.select("B", 0), policy(obs.select("B", 0))),
)

In [32]:
# Learned model's [state | action] Jacobian; compare with the true one above.
print("State-action Jacobian:\n", torch.cat((jac_s, jac_a), dim=-1))

State-action Jacobian:
 tensor([[-0.1072, -0.0347,  0.0000, -0.0238,  0.0202],
        [ 0.0403,  0.3481,  0.0000, -0.2826,  0.3444],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])


**Cosine similarity between rows of the Jacobians**

In [33]:
# Compare true vs. model Jacobians row by row (one row per next-state component).
# Only the first n_state rows are used: the last row (extra trailing observation
# feature) is all zeros in the outputs above, so its cosine similarity is degenerate.
n_state = experiment.hparams.env_config["n_state"]
for row in range(n_state):
    true_Js, true_Ja, model_Js, model_Ja = (
        t[..., row, :] for t in (true_jac_s, true_jac_a, jac_s, jac_a)
    )
    print(analysis.cosine_similarity((true_Js, true_Ja), (model_Js, model_Ja)))

tensor(-0.5780)
tensor(0.3968)


**Determinant of the on-policy state transition Jacobians**

In [34]:
# Jacobian of the on-policy true transition (action = policy(obs)) w.r.t. the state.
autograd.functional.jacobian(state_dynamics, obs.select("B", 0))

tensor([[-1.0490e-01, -3.9116e-08,  0.0000e+00],
        [ 6.5193e-09,  7.8521e-01,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00]])

In [35]:
# Determinant over the state block only ([:2, :2] drops the trailing extra feature).
torch.linalg.det(autograd.functional.jacobian(state_dynamics, obs.select("B", 0))[:2, :2])

tensor(-0.0824)

In [36]:
# Same determinant for the learned model's on-policy transition.
torch.linalg.det(autograd.functional.jacobian(state_model, obs.select("B", 0))[:2, :2])

tensor(-0.0834)

In [37]:
# Replicate the first observation across the batch, yielding a tensor with obs's shape.
# NOTE(review): displaying this dumps the full batch; consider slicing the output.
obs.select("B", 0).align_to("B", ...).expand(obs.shape)

tensor([[-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
        [-2.2201,  0.5578,  0.0000],
 