# Value gradient error for linear policies in LQG

**Objective:** compare ground-truth and estimated stochastic value gradients (SVGs) for linear policies, to check whether it is possible at all to learn a model that accurately predicts the SVG when sample and computational efficiency are not a concern.
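The notebook does not show which error metric it logs. One possible way to compare the two gradients is relative L2 error plus cosine similarity; this is an assumption, not necessarily the metric used here. The helper names `flatten_grads` and `svg_error` are hypothetical. The sketch assumes plain (unnamed) tensors, one per policy parameter (e.g. the gradients with respect to $K$ and $k$).

```python
from __future__ import annotations

from typing import Sequence

import torch
import torch.nn.functional as F


def flatten_grads(grads: Sequence[torch.Tensor]) -> torch.Tensor:
    # Concatenate per-parameter gradients (e.g. for K and k) into one vector.
    return torch.cat([g.reshape(-1) for g in grads])


def svg_error(
    true_grads: Sequence[torch.Tensor], est_grads: Sequence[torch.Tensor]
) -> dict[str, float]:
    """Compare a ground-truth gradient with a model-based estimate (hypothetical helper)."""
    g = flatten_grads(true_grads)
    g_hat = flatten_grads(est_grads)
    # Relative L2 error: 0 for a perfect estimate.
    rel_err = torch.linalg.norm(g_hat - g) / torch.linalg.norm(g)
    # Cosine similarity: 1 when the estimate points in the true gradient's direction.
    cos_sim = F.cosine_similarity(g_hat, g, dim=0)
    return {"relative_error": rel_err.item(), "cosine_similarity": cos_sim.item()}


# Usage with random stand-ins for the true gradients of a time-varying linear policy.
K_grad, k_grad = torch.randn(100, 2, 2), torch.randn(100, 2)
noisy = [K_grad + 0.1 * torch.randn_like(K_grad), k_grad]
print(svg_error([K_grad, k_grad], noisy))
```

Cosine similarity ignores scale. It is still useful on its own, because a gradient estimate pointing in the right direction already yields an ascent direction.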

**Procedure:** for several LQG dimensions, we collect $X$ trajectories from the sampled policy. These trajectories are then used to train a model by optimizing the model learning objective: 
1. we split the data into training, validation and test sets;
2. we choose a batch size and loss improvement threshold;
3. we train the model by stochastic optimization of the model parameters on the learning objective;
4. we stop training when the validation loss stops improving (using PyTorch Lightning's `EarlyStopping` callback).
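The notebook delegates steps 1 and 2 to `build_datamodule`, whose internals are not shown. Below is a rough sketch of what these steps involve in plain PyTorch. It assumes trajectories are stored as tensors with a leading trajectory dimension, which is a hypothetical layout. The helper `split_trajectories` and its default fractions are also hypothetical. Splitting at the trajectory level keeps any single trajectory from appearing in more than one split.

```python
from __future__ import annotations

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def split_trajectories(
    obs: torch.Tensor,
    act: torch.Tensor,
    rew: torch.Tensor,
    fractions: tuple[float, float] = (0.8, 0.1),
    batch_size: int = 64,
    seed: int = 0,
):
    """Split whole trajectories into train/val/test loaders (hypothetical helper)."""
    # Each dataset item is one full trajectory (index along dim 0).
    dataset = TensorDataset(obs, act, rew)
    n_total = len(dataset)
    n_train = int(fractions[0] * n_total)
    n_val = int(fractions[1] * n_total)
    n_test = n_total - n_train - n_val  # remainder goes to the test set
    generator = torch.Generator().manual_seed(seed)
    train, val, test = random_split(
        dataset, [n_train, n_val, n_test], generator=generator
    )
    # Only the training loader is shuffled; its minibatches drive the optimizer.
    return (
        DataLoader(train, batch_size=batch_size, shuffle=True),
        DataLoader(val, batch_size=batch_size),
        DataLoader(test, batch_size=batch_size),
    )
```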

**Extensions:** interpolate between optimal and random policies while logging the SVG error.
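The notebook does not specify how this interpolation is done. One option is a convex combination in parameter space. It assumes both policies are time-varying linear policies $u_t = K_t x_t + k_t$, with gains stacked along a leading time dimension. Both points are assumptions. The function name `interpolate_linear_policy` is hypothetical.

```python
from __future__ import annotations

import torch


def interpolate_linear_policy(
    K_opt: torch.Tensor,
    k_opt: torch.Tensor,
    K_rand: torch.Tensor,
    k_rand: torch.Tensor,
    alpha: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Blend two linear policies: alpha=0 is optimal, alpha=1 is random."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    # Convex combination of the gain matrices and the bias vectors.
    K = (1 - alpha) * K_opt + alpha * K_rand
    k = (1 - alpha) * k_opt + alpha * k_rand
    return K, k


horizon, n_state, n_ctrl = 100, 2, 2
# Placeholders standing in for the optimal (e.g. LQR) and random policy parameters.
K_opt, k_opt = torch.zeros(horizon, n_ctrl, n_state), torch.zeros(horizon, n_ctrl)
K_rand, k_rand = torch.randn(horizon, n_ctrl, n_state), torch.randn(horizon, n_ctrl)
for alpha in [0.0, 0.25, 0.5, 0.75, 1.0]:
    K, k = interpolate_linear_policy(K_opt, k_opt, K_rand, k_rand, alpha)
    # ...evaluate the SVG error of the policy (K, k) here
```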

**Versioning:** [CalVer](https://calver.org) `MM.DD.MICRO`

In [1]:
from __future__ import annotations
from datetime import date

import lqsvg.torch.named as nt
import pytorch_lightning as pl
import torch
from raylab.policy.model_based.lightning import LightningTrainerSpec
from torch import Tensor

from data import build_datamodule
from models import LightningModel
from policy import make_worker
from utils import suppress_dataloader_warning

In [2]:
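def make_lightning_trainer(logger, spec: LightningTrainerSpec) -> pl.Trainer:
    # Stop once the monitored metric (LightningModel.early_stop_on) fails to decrease
    # by more than `improvement_delta` for `patience` consecutive validation checks
    early_stopping = pl.callbacks.EarlyStopping(
        monitor=LightningModel.early_stop_on,
        min_delta=spec.improvement_delta,
        patience=spec.patience,
        mode="min",
        strict=False,  # don't crash if the monitored metric is missing
    )
    trainer = pl.Trainer(
        logger=logger,
        num_sanity_val_steps=2,  # run 2 validation batches before training starts
        checkpoint_callback=False,  # disable model checkpointing
        callbacks=[early_stopping],
        max_epochs=spec.max_epochs,
        max_steps=spec.max_steps,
        progress_bar_refresh_rate=0,  # disable the progress bar
    )
    return trainer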
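The stopping rule configured above (`mode="min"`, `min_delta`, `patience`) can be sketched in plain Python. This is a simplified approximation of what `pl.callbacks.EarlyStopping` checks, not Lightning's code. The class name `EarlyStopper` is hypothetical, and the exact semantics may vary across Lightning versions.

```python
import math
from dataclasses import dataclass


@dataclass
class EarlyStopper:
    """Simplified mode="min" early-stopping rule (hypothetical, not Lightning's code)."""

    patience: int = 3
    min_delta: float = 0.0
    best: float = math.inf
    wait: int = 0

    def should_stop(self, val_loss: float) -> bool:
        # An improvement means the loss fell more than min_delta below the best so far.
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        # Stop after `patience` consecutive checks without improvement.
        return self.wait >= self.patience
```

With `patience=3` and `improvement_delta=0.0`, as in the experiment cell, any strict decrease of the monitored loss resets the counter.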

In [3]:
def calver() -> str:
    today = date.today()
    return f"{today.month}.{today.day}.0"

print("CalVer:", calver())

CalVer: 3.9.0


In [4]:
with nt.suppress_named_tensor_warning():
    env_config = dict(n_state=2, n_ctrl=2, horizon=100, num_envs=100)
    worker = make_worker(env_config)
    model = LightningModel(worker.get_policy(), worker.env)
    model.hparams.learning_rate = 1e-4
    datamodule = build_datamodule(worker, total_trajs=5000)

    logger = pl.loggers.WandbLogger(
        name="SVG Prediction",
        offline=False,
        project="LQG-SVG",
        log_model=False,
        entity="angelovtt",
        tags=[calver()],
    )
    spec = LightningTrainerSpec(max_epochs=1000, patience=3, improvement_delta=0.0)
    trainer = make_lightning_trainer(logger, spec)

    with suppress_dataloader_warning():
        trainer.fit(model, datamodule=datamodule)

Collecting:  98%|█████████▊| 4900/5000 [00:43<00:00, 113.65traj/s]
GPU available: False, used: False
TPU available: None, using: 0 TPU cores
[34m[1mwandb[0m: Currently logged in as: [33mangelovtt[0m (use `wandb login --relogin` to force relogin)



  | Name        | Type           | Params
-----------------------------------------------
0 | actor       | TVLinearPolicy | 600   
1 | model       | LQGModule      | 3.6 K 
2 | mdp         | LQGModule      | 3.6 K 
3 | policy_loss | PolicyLoss     | 0     
-----------------------------------------------
7.8 K     Trainable params
0         Non-trainable params
7.8 K     Total params
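The `Params` column can be sanity-checked against `env_config`. Suppose `TVLinearPolicy` stores a separate gain matrix $K_t$ and bias $k_t$ for each of the `horizon` steps; this is an assumption, since the class is not shown here. The count is then $100 \times (2 \cdot 2 + 2) = 600$, which matches the `actor` row. The helper name below is hypothetical.

```python
def tv_linear_policy_num_params(n_state: int, n_ctrl: int, horizon: int) -> int:
    # One gain matrix K_t (n_ctrl x n_state) and one bias k_t (n_ctrl) per time step.
    per_step = n_ctrl * n_state + n_ctrl
    return horizon * per_step


print(tv_linear_policy_num_params(n_state=2, n_ctrl=2, horizon=100))  # → 600
```

The reported total is the sum of the rows: 0.6 K for `actor` plus 3.6 K each for `model` and `mdp` gives 7.8 K.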
