# SVG(inf) for LQG

## Introduction
In this notebook we implement a simplified version of SVG($\infty$) for the Linear Quadratic Gaussian (LQG) problem. Our intention is to show how one can learn a parameterized policy in LQG via Stochastic Value Gradients (SVG). Note that even if we successfully learn a policy, we do not necessarily know which components of the SVG framework were instrumental in doing so. The focus of the rest of our work will be on analyzing the tenets of this framework in terms of **gradient estimation quality** and **performance curvature approximation**.

### Imports

In [1]:
from __future__ import annotations
import itertools
import logging
import os.path as osp
from numbers import Number
from typing import Optional

import pytorch_lightning as pl
import ray
import torch
import torch.nn as nn
import wandb
from ray.rllib import RolloutWorker
from ray.rllib import SampleBatch
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from raylab.policy import TorchPolicy
from raylab.policy.modules.actor import DeterministicPolicy
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import ConcatDataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.utils.data import TensorDataset
from tqdm.auto import trange

import lqsvg.torch.named as nt
from lqsvg.envs import lqr
from lqsvg.envs.lqr.gym import TorchLQGMixin
from lqsvg.experiment.data import TrajectoryData
from lqsvg.experiment.models import MonteCarloSVG
from lqsvg.experiment.models import glorot_init_model
from lqsvg.experiment.policy import make_worker
from lqsvg.experiment.tqdm_util import collect_with_progress
from lqsvg.experiment.utils import calver
from lqsvg.experiment.utils import group_batch_episodes
from lqsvg.experiment.utils import linear_feedback_distance
from lqsvg.experiment.utils import suppress_dataloader_warning
from lqsvg.policy.time_varying_linear import LQGPolicy

---
## Algorithm

In [8]:
@ray.remote
class Experiment:
    """Ray actor that runs one SVG(inf) training run and logs it to W&B.

    Construction opens a W&B run (``mode="offline"``) whose config stores the
    hyperparameter dict. Call ``setup`` once to build the rollout worker,
    policy, model, and artifact, then ``execute`` to train and close the run.

    Args:
        config: hyperparameters; the keys read by this class are
            ``env_config``, ``true_model``, ``iterations``, ``policy_lr``,
            ``trajs_per_iter``, ``improvement_delta``, ``patience`` and
            ``svg_samples`` (the full mapping is also forwarded to
            ``EnvModel`` and ``DataModule``).
    """

    # A policy checkpoint is added to the artifact every this many iterations
    CKPT_INTERVAL = 10
    # Upper bound on model-fitting epochs; early stopping usually ends sooner
    MAX_MODEL_EPOCHS = 1000

    def __init__(self, config: dict):
        ver = calver()
        self.run = wandb.init(
            job_type="train",
            config=config,
            project="LQG-SVG",
            entity="angelovtt",
            reinit=True,
            tags=[ver],
            name="SVG(inf)",
            mode="offline",
            allow_val_change=False,
            save_code=True,
        )

        print("CalVer:", ver)
        # Everything below is populated by `setup`
        self.worker = None
        self.policy = None
        self.model = None
        self.artifact = None
        self.collect_metrics = None

    @property
    def dir(self) -> str:
        """Directory of the W&B run; checkpoints and model logs are written here."""
        return self.run.dir

    @property
    def hparams(self) -> dict:
        """Hyperparameters as stored in the W&B run config."""
        return self.run.config

    def setup(self):
        """Build the worker, policy, model, artifact and metrics collector.

        Must be called before ``execute``/``train``.
        """
        with nt.suppress_named_tensor_warning():
            self.worker = make_worker(self.hparams["env_config"])

        rllib_policy = self.worker.get_policy()
        self.policy = rllib_policy.module.actor
        self.model = EnvModel(rllib_policy, self.hparams)
        if self.hparams["true_model"]:
            # Plan through the environment's own dynamics instead of a learned model
            self.model.model = self.worker.env.module

        self.artifact = self.create_artifact()
        self.save_mdp_to_artifact()
        self.collect_metrics = CollectMetrics()
        self.suppress_lightning_info_logging()

    @staticmethod
    def suppress_lightning_info_logging():
        """Raise the 'lightning' logger level to WARNING."""
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/3431
        logging.getLogger('lightning').setLevel(logging.WARNING)

    def create_artifact(self) -> wandb.Artifact:
        """Create a model artifact named after the environment's dimensions."""
        env = self.worker.env
        return wandb.Artifact(f"svg_inf-lqg{env.n_state}.{env.n_ctrl}.{env.horizon}", type="model")

    def save_mdp_to_artifact(self):
        """Save the environment module's state dict and attach it to the artifact."""
        mdp = self.worker.env.module
        path = osp.join(self.dir, "mdp.pt")
        torch.save(mdp.state_dict(), path)
        self.artifact.add_file(path)

    def execute(self):
        """Train and finish run.

        If training raises, the W&B run is closed as failed before the
        exception propagates, so the actor does not keep a dangling run open.
        """
        try:
            self.train()
        except BaseException:
            # Re-raised below; only make sure the run does not stay open
            self.run.finish(exit_code=1)
            raise
        self.finish()

    def train(self):
        """Main algorithm logic.

        Each iteration collects trajectories with the current policy, refits
        the model on all data gathered so far, takes one policy optimizer
        step, and logs metrics.
        """
        dataset = self.create_dataset()
        optimizer = self.create_optimizer(self.policy)

        # Progress bar is disabled
        for itr in trange(self.hparams["iterations"], desc="SVG(inf)", unit="iteration", disable=True):
            if itr % self.CKPT_INTERVAL == 0:
                self.add_ckpt_to_artifact(itr)

            self.collect_trajs(self.worker, dataset)
            self.optimize_model(self.model, dataset)
            self.step_policy(self.policy, self.model, optimizer)

            # Logging
            self.log_iteration(itr)

    def create_dataset(self) -> pl.LightningDataModule:
        """Create the (initially empty) trajectory data module."""
        return DataModule(self.hparams)

    def create_optimizer(self, policy: nn.Module) -> Optimizer:
        """Create the Adam optimizer for the policy parameters."""
        return torch.optim.Adam(policy.parameters(), lr=self.hparams["policy_lr"])

    def collect_trajs(self, worker: RolloutWorker, dataset: pl.LightningDataModule):
        """Sample ``trajs_per_iter`` trajectories and add them to the dataset."""
        dataset.collect_trajectories(worker, n_trajs=self.hparams["trajs_per_iter"])

    def optimize_model(self, model: pl.LightningModule, dataset: pl.LightningDataModule):
        """Fit the model on the dataset with early stopping on validation loss.

        Does nothing when ``true_model`` is set. Otherwise logs the number of
        epochs run and the last validation loss (uncommitted, so they are
        merged into the next ``run.log`` call).
        """
        if self.hparams["true_model"]:
            return

        hparams = self.hparams

        validation_results = ValidationResults()
        early_stopping = pl.callbacks.EarlyStopping(
            model.early_stop_on,
            min_delta=hparams["improvement_delta"],
            patience=hparams["patience"],
            mode="min",
            strict=True,
        )
        trainer = pl.Trainer(
            default_root_dir=self.dir,
            callbacks=[early_stopping, validation_results],
            max_epochs=self.MAX_MODEL_EPOCHS,
            progress_bar_refresh_rate=0,  # don't show progress bar for model training
            weights_summary=None,  # don't print summary before training
            checkpoint_callback=False,  # don't save last model
        )
        with suppress_dataloader_warning():
            trainer.fit(model, datamodule=dataset)

        # Logging
        results = {
            "model_epochs": trainer.current_epoch + 1,
            "model_nll": validation_results.last_validation_loss().item(),
        }
        self.run.log(results, commit=False)

    def step_policy(self, policy: DeterministicPolicy, model: pl.LightningModule, optimizer: Optimizer):
        """Take one optimizer step on the negated Monte Carlo value estimate."""
        svg = MonteCarloSVG(policy, model.model)
        optimizer.zero_grad()
        # The optimizer minimizes, so negate the value we want to increase
        loss = -svg.value(self.hparams["svg_samples"])
        loss.backward()
        optimizer.step()

    def log_iteration(self, itr: int):
        """Log the iteration index, policy distance, and rollout metrics."""
        # NOTE(review): the first element of `env.solution` is presumably the
        # optimal policy, given the "distance_to_optimal" key -- confirm
        pistar, _, _ = self.worker.env.solution
        logs = {
            "iteration": itr,
            "distance_to_optimal": linear_feedback_distance(self.policy.standard_form(), pistar),
            **self.collect_metrics(self.worker)
        }
        self.run.log(logs)

    def finish(self):
        """Save the final checkpoint, log the artifact, and close the run."""
        self.add_ckpt_to_artifact(self.hparams["iterations"])
        self.run.log_artifact(self.artifact)
        self.run.finish()

    def add_ckpt_to_artifact(self, itr: int):
        """Save the policy module's state dict for iteration ``itr`` to the artifact."""
        policy: TorchPolicy = self.worker.get_policy()
        path = osp.join(self.dir, f"module-iter={itr}.pt")
        torch.save(policy.module.state_dict(), path)
        self.artifact.add_file(path)

### Logging

In [3]:
class CollectMetrics:
    """Callable that summarizes a worker's episodes over a rolling history window."""

    # Copied from https://github.com/ray-project/ray/blob/c409b5b63a6928e423428b700e528e35d791e8ea/rllib/execution/metric_ops.py#L47
    def __init__(self):
        self.episode_history = []
        self.to_be_collected = []

    def __call__(self, worker: RolloutWorker, min_history: int = 100, timeout_seconds: int = 180) -> dict:
        """Collect new episodes from ``worker`` and return their numeric summary.

        When fewer than ``min_history`` new episodes are available, the most
        recent ones from previous calls are added to smooth the summary.
        Only entries of the summary whose value is a ``Number`` are returned.
        """
        # Collect worker metrics.
        episodes, self.to_be_collected = collect_episodes(
            worker,
            to_be_collected=self.to_be_collected,
            timeout_seconds=timeout_seconds,
        )
        fresh_episodes = list(episodes)

        shortfall = min_history - len(episodes)
        if shortfall > 0:
            # Top up with the latest episodes seen in earlier calls
            episodes.extend(self.episode_history[-shortfall:])
            assert len(episodes) <= min_history

        # Only the newly collected episodes enter the history, trimmed to the window
        self.episode_history.extend(fresh_episodes)
        self.episode_history = self.episode_history[-min_history:]

        summary = summarize_episodes(episodes, fresh_episodes)
        return {
            key: value
            for key, value in summary.items()
            if isinstance(value, Number)
        }

In [4]:
class ValidationResults(pl.callbacks.Callback):
    """Lightning callback that keeps the per-batch outputs of the latest validation epoch."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Start as an empty list (not None) so the attribute always has one type
        self.saved_outputs: list[Tensor] = []

    def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Discard outputs of earlier validation epochs; only the latest is kept."""
        self.saved_outputs = []

    def on_validation_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Tensor, *args, **kwargs):
        """Record the output of one validation batch."""
        self.saved_outputs.append(outputs)

    def last_validation_loss(self) -> Tensor:
        """Return the unweighted mean of the recorded batch outputs.

        Raises:
            RuntimeError: if no validation batch has been recorded, e.g. the
                validation set is empty or no validation epoch has run yet.
        """
        if not self.saved_outputs:
            raise RuntimeError(
                "No validation outputs recorded; did a validation epoch run "
                "with at least one batch?"
            )
        return torch.stack(self.saved_outputs, dim=0).mean(dim=0)

### Modules

In [5]:
class DataModule(pl.LightningDataModule):
    """Growing dataset of trajectories with a random train/validation split.

    Trajectories are appended with ``collect_trajectories``. The split is
    recomputed by ``setup`` and, if new trajectories arrived since the last
    split, lazily by the dataloader methods. The lazy refresh matters because
    some Lightning releases track whether ``setup`` already ran and skip it on
    repeated ``fit`` calls with the same datamodule, which would otherwise
    leave the dataloaders on a stale split.

    Args:
        hparams: mapping with ``dataset_batch_size`` and ``train_val_split``
            (only the train fraction of the latter is used).
    """

    def __init__(self, hparams: dict):
        super().__init__()

        self.batch_size: int = hparams["dataset_batch_size"]
        self.train_val_split: tuple[float, float] = hparams["train_val_split"]

        self.full_dataset = None
        self.train_dataset, self.val_dataset = None, None
        self._itr_datasets: list[TensorDataset] = []
        # True whenever `full_dataset` changed after the last split
        self._split_is_stale = True

    def collect_trajectories(self, rollout_worker: RolloutWorker, n_trajs: int):
        """Sample ``n_trajs`` trajectories with the worker and add them to the dataset."""
        worker = rollout_worker
        self._check_rollout_worker(worker)

        sample_batch = collect_with_progress(worker, n_trajs, prog=False)
        sample_batch = group_batch_episodes(sample_batch)
        trajs = sample_batch.split_by_episode()
        self._check_collected_trajs(trajs, worker, n_trajs)

        traj_dataset = TrajectoryData.trajectory_dataset(trajs)
        self._itr_datasets += [traj_dataset]
        self.full_dataset = ConcatDataset(self._itr_datasets)
        # Existing train/val subsets do not cover the new trajectories
        self._split_is_stale = True

    def setup(self, stage: Optional[str] = None):
        """Randomly split all collected data into train and validation sets."""
        del stage
        self._split_datasets()

    def _split_datasets(self):
        """Recompute the random train/validation split over ``full_dataset``."""
        if self.full_dataset is None:
            raise RuntimeError(
                "No trajectories collected; call `collect_trajectories` first."
            )
        train_frac, _ = self.train_val_split
        total = len(self.full_dataset)
        train_trajs = int(train_frac * total)
        val_trajs = total - train_trajs

        self.train_dataset, self.val_dataset = random_split(self.full_dataset, (train_trajs, val_trajs))
        self._split_is_stale = False

    def _refresh_split_if_stale(self):
        """Re-split only if data was added since the last split."""
        if self._split_is_stale:
            self._split_datasets()

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training subset."""
        # pylint:disable=arguments-differ
        self._refresh_split_if_stale()
        return DataLoader(self.train_dataset, shuffle=True, batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        """Unshuffled loader over the validation subset."""
        # pylint:disable=arguments-differ
        self._refresh_split_if_stale()
        return DataLoader(self.val_dataset, shuffle=False, batch_size=self.batch_size)

    @staticmethod
    def _check_rollout_worker(worker: RolloutWorker):
        """Assert the worker's fragment length and batch mode match what collection expects."""
        assert worker.rollout_fragment_length == worker.env.horizon * worker.num_envs
        assert worker.batch_mode == "truncate_episodes"

    @staticmethod
    def _check_collected_trajs(trajs: list[SampleBatch], worker: RolloutWorker, n_trajs: int):
        """Assert every trajectory spans the full horizon and the total count is right."""
        traj_counts = [t.count for t in trajs]
        assert all(c == worker.env.horizon for c in traj_counts), traj_counts
        total_ts = sum(traj_counts)
        assert total_ts == n_trajs * worker.env.horizon, total_ts

In [6]:
class EnvModel(pl.LightningModule):
    """Lightning wrapper that fits the policy's model by trajectory log-likelihood."""

    # Name of the logged metric monitored for early stopping
    early_stop_on: str = "val/loss"

    def __init__(self, policy: LQGPolicy, hparams: dict):
        super().__init__()
        self.model = policy.module.model

        self.hparams.learning_rate = hparams["model_lr"]
        glorot_init_model(self.model)

    def configure_optimizers(self):
        """Adam over the parameters of the model's `trans` and `init` submodules."""
        trainable = itertools.chain(
            self.model.trans.parameters(),
            self.model.init.parameters(),
        )
        return torch.optim.Adam(
            nn.ParameterList(trainable), lr=self.hparams.learning_rate
        )

    def forward(self, obs: Tensor, act: Tensor, new_obs: Tensor) -> Tensor:
        """Batched trajectory log prob."""
        # pylint:disable=arguments-differ
        return self.model.log_prob(obs, act, new_obs)

    def _compute_loss_on_batch(
        self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int
    ) -> Tensor:
        """Mean negative log-likelihood of the trajectories in ``batch``."""
        del batch_idx
        # Attach dimension names to each tensor of the batch
        named = [tensor.refine_names("B", "H", "R") for tensor in batch]
        obs, act, new_obs = named
        log_prob = self(obs, act, new_obs)
        return -log_prob.mean()

    def training_step(
        self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int
    ) -> Tensor:
        """Compute and log the training loss for one batch."""
        # pylint:disable=arguments-differ
        train_loss = self._compute_loss_on_batch(batch, batch_idx)
        self.log("train/loss", train_loss)
        return train_loss

    def validation_step(
        self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int
    ) -> Tensor:
        """Compute and log the validation loss for one batch."""
        # pylint:disable=arguments-differ
        val_loss = self._compute_loss_on_batch(batch, batch_idx)
        self.log("val/loss", val_loss)
        return val_loss

---
## Run

In [7]:
ray.init(
    logging_level=logging.WARNING
)

# Hyperparameters shared by all runs on the small (2 states, 2 controls) LQG
small_config = {
    "iterations": 200,
    "trajs_per_iter": 20,
    "policy_lr": 3e-4,
    "improvement_delta": 0.0,
    "patience": 3,
    "svg_samples": 32,
    "dataset_batch_size": 32,
    "train_val_split": (0.8, 0.2),
    "model_lr": 1e-3,
    "env_config": dict(n_state=2, n_ctrl=2, horizon=100, num_envs=20),
    "true_model": False,
}
small_exps = [Experiment.remote(small_config) for _ in range(4)]

# Wait for every actor's setup so a failure surfaces here, with its own
# traceback, instead of being dropped with the unused future
ray.get([exp.setup.remote() for exp in small_exps])

# Trailing semicolon hides the uninformative list of `None` results
ray.get([exp.execute.remote() for exp in small_exps]);

2021-03-19 09:18:44,363	INFO services.py:1172 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
[2m[36m(pid=3001)[0m wandb: Currently logged in as: angelovtt (use `wandb login --relogin` to force relogin)
[2m[36m(pid=3003)[0m wandb: Currently logged in as: angelovtt (use `wandb login --relogin` to force relogin)
[2m[36m(pid=3002)[0m wandb: Currently logged in as: angelovtt (use `wandb login --relogin` to force relogin)
[2m[36m(pid=3000)[0m wandb: Currently logged in as: angelovtt (use `wandb login --relogin` to force relogin)
[2m[36m(pid=3001)[0m wandb: ERROR Error while calling W&B API: Error 1213: Deadlock found when trying to get lock; try restarting transaction (<Response [500]>)
[2m[36m(pid=3003)[0m wandb: Tracking run with wandb version 0.10.22
[2m[36m(pid=3003)[0m wandb: Syncing run SVG(inf)
[2m[36m(pid=3003)[0m wandb: ⭐️ View project at https://wandb.ai/angelovtt/LQG-SVG
[2m[36m(pid=3003)[0m wandb: 🚀 View run at https://wandb.ai/an

[2m[36m(pid=3003)[0m 
[2m[36m(pid=3003)[0m CalVer: 3.19.0
[2m[36m(pid=3000)[0m 
[2m[36m(pid=3000)[0m CalVer: 3.19.0
[2m[36m(pid=3002)[0m 
[2m[36m(pid=3002)[0m CalVer: 3.19.0


[2m[36m(pid=3003)[0m 2021-03-19 09:18:56,492	INFO sampler.py:1295 -- Outputs of compute_actions():
[2m[36m(pid=3003)[0m 
[2m[36m(pid=3003)[0m { 'default_policy': ( np.ndarray((20, 2), dtype=float32, min=-4.32, max=-0.305, mean=-1.847),
[2m[36m(pid=3003)[0m                       [],
[2m[36m(pid=3003)[0m                       {})}
[2m[36m(pid=3003)[0m 
[2m[36m(pid=3002)[0m 2021-03-19 09:18:56,480	INFO rollout_worker.py:1114 -- Built policy map: {'default_policy': LQGPolicy(
[2m[36m(pid=3002)[0m   Box(-inf, inf, (3,), float32),
[2m[36m(pid=3002)[0m   Box(-inf, inf, (2,), float32),
[2m[36m(pid=3002)[0m   {
[2m[36m(pid=3002)[0m     compile: false
[2m[36m(pid=3002)[0m     env_config: {}
[2m[36m(pid=3002)[0m     exploration_config:
[2m[36m(pid=3002)[0m       pure_exploration_steps: 0
[2m[36m(pid=3002)[0m       type: raylab.utils.exploration.GaussianNoise
[2m[36m(pid=3002)[0m     explore: true
[2m[36m(pid=3002)[0m     framework: torch
[2m[36

[2m[36m(pid=3001)[0m 
[2m[36m(pid=3001)[0m CalVer: 3.19.0


[2m[36m(pid=3001)[0m 2021-03-19 09:18:58,469	INFO sample_batch_builder.py:209 -- Trajectory fragment after postprocess_trajectory():
[2m[36m(pid=3001)[0m 
[2m[36m(pid=3001)[0m { 'agent0': { 'data': { 'actions': np.ndarray((100, 2), dtype=float32, min=-6.2, max=7.0, mean=0.079),
[2m[36m(pid=3001)[0m                         'agent_index': np.ndarray((100,), dtype=int64, min=0.0, max=0.0, mean=0.0),
[2m[36m(pid=3001)[0m                         'dones': np.ndarray((100,), dtype=bool, min=0.0, max=1.0, mean=0.01),
[2m[36m(pid=3001)[0m                         'eps_id': np.ndarray((100,), dtype=int64, min=652789120.0, max=652789120.0, mean=652789120.0),
[2m[36m(pid=3001)[0m                         'infos': np.ndarray((100,), dtype=object, head={}),
[2m[36m(pid=3001)[0m                         'new_obs': np.ndarray((100, 3), dtype=float32, min=-7.543, max=100.0, mean=16.809),
[2m[36m(pid=3001)[0m                         'obs': np.ndarray((100, 3), dtype=float32, min=

KeyboardInterrupt: 

In [9]:
# Hyperparameters shared by all runs on the larger (8 states, 8 controls) LQG
big_config = {
    "iterations": 200,
    "trajs_per_iter": 20,
    "policy_lr": 3e-4,
    "improvement_delta": 0.0,
    "patience": 3,
    "svg_samples": 32,
    "dataset_batch_size": 32,
    "train_val_split": (0.8, 0.2),
    "model_lr": 1e-3,
    "env_config": dict(n_state=8, n_ctrl=8, horizon=1000, num_envs=20),
    "true_model": False,
}
big_exps = [Experiment.remote(big_config) for _ in range(4)]

# Wait for every actor's setup so a failure surfaces here, with its own
# traceback, instead of being dropped with the unused future
ray.get([exp.setup.remote() for exp in big_exps])

# Trailing semicolon hides the uninformative list of `None` results
ray.get([exp.execute.remote() for exp in big_exps]);

[2m[36m(pid=3411)[0m CalVer: 3.19.0
[2m[36m(pid=3410)[0m CalVer: 3.19.0
[2m[36m(pid=3412)[0m CalVer: 3.19.0
[2m[36m(pid=3409)[0m CalVer: 3.19.0


[2m[36m(pid=3410)[0m wandb: W&B syncing is set to `offline` in this directory.  Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
[2m[36m(pid=3411)[0m wandb: W&B syncing is set to `offline` in this directory.  Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
[2m[36m(pid=3409)[0m wandb: W&B syncing is set to `offline` in this directory.  Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
[2m[36m(pid=3412)[0m wandb: W&B syncing is set to `offline` in this directory.  Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
[2m[36m(pid=3411)[0m 2021-03-19 09:37:48,548	INFO rollout_worker.py:1114 -- Built policy map: {'default_policy': LQGPolicy(
[2m[36m(pid=3411)[0m   Box(-inf, inf, (9,), float32),
[2m[36m(pid=3411)[0m   Box(-inf, inf, (8,), float32),
[2m[36m(pid=3411)[0m   {
[2m[36m(pid=3411)[0m     compile: false
[2m[36m(pid=3411)[0m     env_config: {}
[2m[36m(pid=3411)[0m     explorati

[None, None, None, None]

In [10]:
# Shut down the Ray runtime started with `ray.init` earlier in the notebook
ray.shutdown()