# Value gradient error for linear policies in LQG

The experiment is described in detail on [Overleaf](https://www.overleaf.com/read/cmbgmxxpxqzr).

**Versioning:** [CalVer](https://calver.org) `MM.DD.MICRO`

In [1]:
from __future__ import annotations

import logging
import os.path as osp

import lqsvg
import lqsvg.envs.lqr.utils as lqg_util
import lqsvg.experiment.utils as utils
import lqsvg.torch.named as nt
import pytorch_lightning as pl
import ray
from lqsvg.experiment.data import build_datamodule
from lqsvg.experiment.models import LightningModel
from lqsvg.experiment.worker import make_worker
from ray import tune
from raylab.policy.model_based.lightning import LightningTrainerSpec
from torch import Tensor

import wandb

In [2]:
class InputStatistics(pl.callbacks.Callback):
    """Lightning callback logging summary statistics of each training batch.

    After every training batch it logs the mean and standard deviation (taken
    over all elements) of the observations, actions and next observations,
    under ``train/<name>-mean`` and ``train/<name>-std``.
    """

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Tensor,
        batch: tuple[Tensor, Tensor, Tensor],
        batch_idx: int,
        dataloader_idx: int,
    ):
        # Only the batch and the module (used as the logger) matter here.
        del trainer, outputs, batch_idx, dataloader_idx
        # Unpacking also asserts the batch has exactly three components.
        obs, act, new_obs = batch
        # (metric key, statistic) pairs, evaluated and logged in this order.
        summaries = (
            ("train/obs-mean", obs.mean),
            ("train/obs-std", obs.std),
            ("train/act-mean", act.mean),
            ("train/act-std", act.std),
            ("train/new_obs-mean", new_obs.mean),
            ("train/new_obs-std", new_obs.std),
        )
        for key, statistic in summaries:
            pl_module.log(key, statistic())

In [3]:
class Experiment(tune.Trainable):
    """Ray Tune trainable that fits a ``LightningModel`` on LQG trajectories.

    ``setup`` creates a W&B run and then builds, in order, the environment
    worker, the model, the trajectory datamodule, the Lightning trainer and the
    W&B artifact that will hold the checkpoints. A single ``step`` trains the
    model, tests it, uploads the checkpoints and reports the trial as done.
    """

    def setup(self, config: dict):
        """Initialize the W&B run and every component used by ``step``.

        Args:
            config: Tune config for this trial. It is stored as the W&B run
                config and read back through ``hparams``.
        """
        # Each trial runs in its own Ray worker process (note the ``(pid=...)``
        # prefixes in the recorded output), so the driver-side call does not
        # reach it. Silence Lightning *before* the trainer is built. When this
        # call came last in ``setup``, Lightning's "GPU available" /
        # "TPU available" info lines were still printed during setup.
        utils.suppress_lightning_info_logging()

        self.run = wandb.init(
            name="SVG Prediction",
            config=config,
            project="LQG-SVG",
            entity="angelovtt",
            tags=[utils.calver()],
            reinit=True,
            mode="online",
            save_code=True,
        )

        self.make_worker()
        self.make_model()
        self.make_datamodule()
        self.make_lightning_trainer()
        self.make_artifact()

    @property
    def hparams(self):
        """The W&B run config, i.e. the Tune config passed to ``wandb.init``."""
        return self.run.config

    def make_worker(self):
        """Build the rollout worker for ``hparams.env_config``."""
        # Named-tensor warnings are suppressed only while the worker is built.
        with nt.suppress_named_tensor_warning():
            self.worker = make_worker(
                env_config=self.hparams.env_config, log_level=logging.WARNING
            )

    def make_model(self):
        """Build the Lightning model from the worker's policy and environment."""
        self.model = LightningModel(self.worker.get_policy(), self.worker.env)
        # Copy the tunable hyperparameters onto the model's own hparams.
        self.model.hparams.learning_rate = self.hparams.learning_rate
        self.model.hparams.mc_samples = self.hparams.mc_samples

    def make_datamodule(self):
        """Build the datamodule and collect ``hparams.total_trajs`` trajectories."""
        self.datamodule = build_datamodule(
            self.worker, total_trajs=self.hparams.total_trajs
        )
        # NOTE(review): ``prog=False`` presumably disables a progress display.
        self.datamodule.collect_trajectories(prog=False)

    def make_lightning_trainer(self):
        """Build the Lightning trainer with W&B logging, early stopping and checkpoints."""
        # Reuse this trial's W&B run instead of letting the logger start a new one.
        logger = pl.loggers.WandbLogger(
            save_dir=self.run.dir, log_model=False, experiment=self.run
        )

        early_stopping = pl.callbacks.EarlyStopping(
            monitor=LightningModel.early_stop_on,
            min_delta=float(self.hparams.improvement_delta),
            patience=int(self.hparams.patience),
            mode="min",
            strict=True,
        )
        # Checkpoints go under the run directory so ``step`` can add them to
        # the W&B artifact.
        checkpointing = pl.callbacks.ModelCheckpoint(
            dirpath=osp.join(self.run.dir, "checkpoints"),
            monitor=LightningModel.early_stop_on,
            save_top_k=-1,
            period=10,
            save_last=True,
        )
        self.trainer = pl.Trainer(
            default_root_dir=self.run.dir,
            logger=logger,
            num_sanity_val_steps=2,
            callbacks=[early_stopping, checkpointing, InputStatistics()],
            max_epochs=self.hparams.max_epochs,
            progress_bar_refresh_rate=0,  # don't show progress bar for model training
            weights_summary=None,  # don't print summary before training
        )

    def make_artifact(self):
        """Create the W&B model artifact, named after the environment dimensions."""
        env = self.worker.env
        # NOTE(review): concurrent trials with the same env dimensions log
        # versions of the same artifact name. The recorded output shows a W&B
        # 409 "Duplicate entry ... unique_artifact_collection_membership_version"
        # error when two trials did so. Consider a per-run name if this recurs.
        self.artifact = wandb.Artifact(
            f"svg_prediction-lqg{env.n_state}.{env.n_ctrl}.{env.horizon}", type="model"
        )

    def step(self) -> dict:
        """Train, test and upload checkpoints; this is the trial's only iteration.

        Returns:
            The first test dataloader's metrics, plus ``tune.result.DONE: True``
            so that Tune terminates the trial after this step.
        """
        self.log_env_info()
        with utils.suppress_dataloader_warning():
            self.trainer.fit(self.model, datamodule=self.datamodule)

            results = self.trainer.test(self.model, datamodule=self.datamodule)[0]
            self.run.summary.update(results)

        self.artifact.add_dir(self.trainer.checkpoint_callback.dirpath)
        self.run.log_artifact(self.artifact)
        return {tune.result.DONE: True, **results}

    def log_env_info(self):
        """Record stability/controllability checks and eigenvalues in the run summary."""
        dynamics = self.worker.env.dynamics
        eigvals = lqg_util.stationary_eigvals(dynamics)
        tests = {
            "stability": lqg_util.isstable(eigvals=eigvals),
            "controllability": lqg_util.iscontrollable(dynamics),
        }
        self.run.summary.update(tests)
        self.run.summary.update({"Fs_eigvals": wandb.Histogram(eigvals)})

    def cleanup(self):
        """Finish this trial's W&B run."""
        self.run.finish()

In [4]:
# Driver setup: start Ray with warning-level logging, register the lqsvg
# components, and silence Lightning's info logging in this process.
ray.init(logging_level=logging.WARNING)
lqsvg.register_all()
utils.suppress_lightning_info_logging()

config = {
    # Passed to ``make_worker`` to build the environment.
    "env_config": {
        "n_state": 2,
        "n_ctrl": 2,
        "horizon": 100,
        "stationary": True,
        "num_envs": 100,
    },
    # Copied onto the LightningModel's hparams.
    "learning_rate": 1e-3,
    "mc_samples": 32,
    # Number of trajectories collected by the datamodule.
    "total_trajs": 1000,
    # Early-stopping settings and the trainer's epoch budget.
    "improvement_delta": 0.0,
    "patience": 3,
    "max_epochs": 200,
}

# Launch two independent trials of the same configuration.
analysis = tune.run(
    Experiment,
    config=config,
    num_samples=2,
)

Trial name,status,loc
Experiment_14f74_00000,RUNNING,


[2m[36m(pid=9412)[0m wandb: Currently logged in as: angelovtt (use `wandb login --relogin` to force relogin)
[2m[36m(pid=9410)[0m wandb: Currently logged in as: angelovtt (use `wandb login --relogin` to force relogin)
[2m[36m(pid=9412)[0m wandb: Tracking run with wandb version 0.10.23
[2m[36m(pid=9412)[0m wandb: Syncing run SVG Prediction
[2m[36m(pid=9412)[0m wandb: ⭐️ View project at https://wandb.ai/angelovtt/LQG-SVG
[2m[36m(pid=9412)[0m wandb: 🚀 View run at https://wandb.ai/angelovtt/LQG-SVG/runs/3759xbb7
[2m[36m(pid=9412)[0m wandb: Run data is saved locally in /Users/angelolovatto/ray_results/Experiment_2021-03-24_12-24-53/Experiment_14f74_00001_1_2021-03-24_12-24-53/wandb/run-20210324_122501-3759xbb7
[2m[36m(pid=9412)[0m wandb: Run `wandb offline` to turn off syncing.
[2m[36m(pid=9410)[0m wandb: Tracking run with wandb version 0.10.23
[2m[36m(pid=9410)[0m wandb: Syncing run SVG Prediction
[2m[36m(pid=9410)[0m wandb: ⭐️ View project at https://wandb

[2m[36m(pid=9410)[0m 
[2m[36m(pid=9412)[0m 


[2m[36m(pid=9412)[0m GPU available: False, used: False
[2m[36m(pid=9412)[0m TPU available: None, using: 0 TPU cores
[2m[36m(pid=9412)[0m 2021-03-24 12:25:15,794	INFO trainable.py:100 -- Trainable.setup took 15.514 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
[2m[36m(pid=9410)[0m GPU available: False, used: False
[2m[36m(pid=9410)[0m TPU available: None, using: 0 TPU cores
[2m[36m(pid=9410)[0m 2021-03-24 12:25:15,824	INFO trainable.py:100 -- Trainable.setup took 15.543 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


[2m[36m(pid=9412)[0m --------------------------------------------------------------------------------
[2m[36m(pid=9412)[0m DATALOADER:0 TEST RESULTS
[2m[36m(pid=9412)[0m {'test/analytic_cossim': tensor(0.6994),
[2m[36m(pid=9412)[0m  'test/analytic_diff': tensor(134.5474),
[2m[36m(pid=9412)[0m  'test/analytic_svg_norm': tensor(963.7524),
[2m[36m(pid=9412)[0m  'test/analytic_value': tensor(-2066.0317),
[2m[36m(pid=9412)[0m  'test/loss': tensor(331.8568),
[2m[36m(pid=9412)[0m  'test/monte_carlo_cossim': tensor(0.6940),
[2m[36m(pid=9412)[0m  'test/monte_carlo_diff': tensor(70.6790),
[2m[36m(pid=9412)[0m  'test/monte_carlo_svg_norm': tensor(993.4412),
[2m[36m(pid=9412)[0m  'test/monte_carlo_value': tensor(-2129.9001),
[2m[36m(pid=9412)[0m  'true_svg_norm': tensor(319.8508),
[2m[36m(pid=9412)[0m  'true_value': tensor(-2200.5791)
[2m[36m(pid=9412)[0m }
[2m[36m(pid=9412)[0m ----------------------------------------------------------------------------

[2m[36m(pid=9412)[0m wandb: Adding directory to artifact (/Users/angelolovatto/ray_results/Experiment_2021-03-24_12-24-53/Experiment_14f74_00001_1_2021-03-24_12-24-53/wandb/run-20210324_122501-3759xbb7/files/checkpoints)... 
[2m[36m(pid=9412)[0m Done. 0.1s


[2m[36m(pid=9410)[0m --------------------------------------------------------------------------------
[2m[36m(pid=9410)[0m DATALOADER:0 TEST RESULTS
[2m[36m(pid=9410)[0m {'test/analytic_cossim': tensor(0.5825),
[2m[36m(pid=9410)[0m  'test/analytic_diff': tensor(-8.2647),
[2m[36m(pid=9410)[0m  'test/analytic_svg_norm': tensor(90.5345),
[2m[36m(pid=9410)[0m  'test/analytic_value': tensor(-231.6026),
[2m[36m(pid=9410)[0m  'test/loss': tensor(247.1650),
[2m[36m(pid=9410)[0m  'test/monte_carlo_cossim': tensor(0.5997),
[2m[36m(pid=9410)[0m  'test/monte_carlo_diff': tensor(-11.9832),
[2m[36m(pid=9410)[0m  'test/monte_carlo_svg_norm': tensor(99.0906),
[2m[36m(pid=9410)[0m  'test/monte_carlo_value': tensor(-235.3211),
[2m[36m(pid=9410)[0m  'true_svg_norm': tensor(249.2109),
[2m[36m(pid=9410)[0m  'true_value': tensor(-223.3379)}
[2m[36m(pid=9410)[0m --------------------------------------------------------------------------------
Result for Experiment_14f

[2m[36m(pid=9410)[0m wandb: Adding directory to artifact (/Users/angelolovatto/ray_results/Experiment_2021-03-24_12-24-53/Experiment_14f74_00000_0_2021-03-24_12-24-53/wandb/run-20210324_122501-2k5lzip2/files/checkpoints)... 
[2m[36m(pid=9410)[0m Done. 0.1s


Trial name,status,loc,iter,total time (s),test/loss,true_value,true_svg_norm
Experiment_14f74_00000,RUNNING,,,,,,
Experiment_14f74_00001,TERMINATED,,1.0,146.597,,-2200.58,319.851


[2m[36m(pid=9412)[0m wandb: Waiting for W&B process to finish, PID 9439
[2m[36m(pid=9412)[0m wandb: Program ended successfully.


Result for Experiment_14f74_00000:
  date: 2021-03-24_12-27-42
  done: true
  experiment_id: 86efad1a82214cc6a655e7d5f1eb7d66
  hostname: Angelos-MBP
  iterations_since_restore: 1
  node_ip: 192.168.15.8
  pid: 9410
  test/analytic_cossim: 0.5825487375259399
  test/analytic_diff: -8.264694213867188
  test/analytic_svg_norm: 90.5345458984375
  test/analytic_value: -231.6025848388672
  test/loss: 247.1649932861328
  test/monte_carlo_cossim: 0.5996895432472229
  test/monte_carlo_diff: -11.983184814453125
  test/monte_carlo_svg_norm: 99.0905990600586
  test/monte_carlo_value: -235.32107543945312
  time_since_restore: 146.70653319358826
  time_this_iter_s: 146.70653319358826
  time_total_s: 146.70653319358826
  timestamp: 1616599662
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: 14f74_00000
  true_svg_norm: 249.21087646484375
  true_value: -223.337890625
  


[2m[36m(pid=9410)[0m wandb: Waiting for W&B process to finish, PID 9441
[2m[36m(pid=9410)[0m wandb: Program ended successfully.


Trial name,status,loc,iter,total time (s),test/loss,true_value,true_svg_norm
Experiment_14f74_00000,TERMINATED,,1,146.707,,-223.338,249.211
Experiment_14f74_00001,TERMINATED,,1,146.597,,-2200.58,319.851


[2m[36m(pid=9412)[0m wandb: ERROR Error while calling W&B API: Error 1062: Duplicate entry '140944-11' for key 'unique_artifact_collection_membership_version' (<Response [409]>)
[2m[36m(pid=9410)[0m wandb: - 0.94MB of 0.94MB uploaded (0.00MB deduped)
[2m[36m(pid=9412)[0m wandb: - 0.94MB of 0.94MB uploaded (0.00MB deduped)
wandb:                                                                                
[2m[36m(pid=9410)[0m wandb: Find user logs for this run at: /Users/angelolovatto/ray_results/Experiment_2021-03-24_12-24-53/Experiment_14f74_00000_0_2021-03-24_12-24-53/wandb/run-20210324_122501-2k5lzip2/logs/debug.log
[2m[36m(pid=9410)[0m wandb: Find internal logs for this run at: /Users/angelolovatto/ray_results/Experiment_2021-03-24_12-24-53/Experiment_14f74_00000_0_2021-03-24_12-24-53/wandb/run-20210324_122501-2k5lzip2/logs/debug-internal.log
[2m[36m(pid=9410)[0m wandb: Run summary:
[2m[36m(pid=9410)[0m wandb:                   stability True
[2m[36m(pid=9

[2m[36m(pid=9410)[0m 


wandb:                                                                                
[2m[36m(pid=9412)[0m wandb: Find user logs for this run at: /Users/angelolovatto/ray_results/Experiment_2021-03-24_12-24-53/Experiment_14f74_00001_1_2021-03-24_12-24-53/wandb/run-20210324_122501-3759xbb7/logs/debug.log
[2m[36m(pid=9412)[0m wandb: Find internal logs for this run at: /Users/angelolovatto/ray_results/Experiment_2021-03-24_12-24-53/Experiment_14f74_00001_1_2021-03-24_12-24-53/wandb/run-20210324_122501-3759xbb7/logs/debug-internal.log
[2m[36m(pid=9412)[0m wandb: Run summary:
[2m[36m(pid=9412)[0m wandb:                   stability True
[2m[36m(pid=9412)[0m wandb:             controllability True
[2m[36m(pid=9412)[0m wandb:                    val/loss 333.80878
[2m[36m(pid=9412)[0m wandb:       val/monte_carlo_value -2155.76733
[2m[36m(pid=9412)[0m wandb:    val/monte_carlo_svg_norm 1024.82434
[2m[36m(pid=9412)[0m wandb:          val/analytic_value -2066.03174
[2

In [5]:
# Release the Ray resources started by ``ray.init`` in the experiment cell.
ray.shutdown()