# Ground-truth performance and gradient norms

This notebook inspects the initialization schemes for linear-quadratic-Gaussian (LQG) problems and time-varying linear policies described in Section 5.1 of the [paper](https://www.overleaf.com/read/cmbgmxxpxqzr).

### Checklist

- [x] Fix `n_state`, `n_ctrl`, `horizon`
- [x] Sample random LQGs
- [ ] Sample random policies
- [ ] Evaluate the expected return
- [ ] Evaluate the value gradient norm
- [ ] Search numpy and scipy for methods to visualize the distributions

### Imports

In [1]:
import torch
from typing import Tuple
from lqsvg.envs import lqr
from lqsvg.envs.lqr.gym import LQGSpec

In [2]:
# Problem dimensions fixed for this notebook (first item of the checklist).
n_state, n_ctrl, horizon = 2, 2, 1000

In [3]:
def sample_lqg(
    n_state: int, n_ctrl: int, horizon: int
) -> Tuple[lqr.LinSDynamics, lqr.QuadCost, lqr.GaussInit]:
    """Create an ``LQGSpec`` with the given sizes and return its ``make_lqg()`` result.

    Args:
        n_state: forwarded to ``LQGSpec`` as ``n_state``.
        n_ctrl: forwarded to ``LQGSpec`` as ``n_ctrl``.
        horizon: forwarded to ``LQGSpec`` as ``horizon``.

    Returns:
        Whatever ``LQGSpec.make_lqg()`` returns; annotated here as a
        (dynamics, cost, initial-state distribution) triple.
    """
    # NOTE(review): `stationary=False` presumably requests time-varying
    # parameters and `gen_seed=None` an unseeded generator -- confirm against
    # the `LQGSpec` definition.
    spec_kwargs = dict(
        n_state=n_state,
        n_ctrl=n_ctrl,
        horizon=horizon,
        stationary=False,
        gen_seed=None,
        num_envs=1,  # No effect
    )
    return LQGSpec(**spec_kwargs).make_lqg()

In [4]:
def test_sample_lqg():
    """Check that ``sample_lqg`` returns three components of the expected types."""
    sizes = (2, 2, 1000)  # n_state, n_ctrl, horizon
    dynamics, cost, init = sample_lqg(*sizes)

    # Pair each returned component with the class it must be an instance of.
    expectations = (
        (dynamics, lqr.LinSDynamics),
        (cost, lqr.QuadCost),
        (init, lqr.GaussInit),
    )
    for component, expected_type in expectations:
        assert isinstance(component, expected_type)

In [5]:
# Run the type checks on a freshly sampled LQG; an AssertionError means a component has an unexpected type.
test_sample_lqg()

  return super(Tensor, self).refine_names(names)


In [6]:
def sample_policy(
    dynamics: lqr.LinSDynamics, cost: lqr.QuadCost, noise_scale: float = 0.5
) -> lqr.Linear:
    """Sample a random linear policy by perturbing the LQG solver's policy.

    The solver ``lqr.NamedLQGControl`` is run on ``(dynamics, cost)`` and each
    tensor of the first element of its output (``pistar``) is perturbed with
    independent standard-normal noise scaled by ``noise_scale``.

    Args:
        dynamics: dynamics whose ``F`` tensor carries the named dimensions
            ``"H"``, ``"R"`` and ``"C"``; the problem sizes are read from them.
        cost: cost passed to the solver together with ``dynamics``.
        noise_scale: multiplier applied to the standard-normal perturbation.
            Defaults to ``0.5``, the value previously hard-coded here.

    Returns:
        The pair ``(K, k)`` of perturbed policy tensors.
    """
    # Recover problem sizes from F's named dimensions. The control size is
    # whatever remains of the "C" dimension after removing the state size.
    n_state = dynamics.F.size("R")
    n_ctrl = dynamics.F.size("C") - n_state
    horizon = dynamics.F.size("H")

    # NOTE(review): `pistar` is presumably the optimal policy; the other two
    # solver outputs are unused here -- confirm against `lqr.NamedLQGControl`.
    solver = lqr.NamedLQGControl(n_state, n_ctrl, horizon)
    pistar, _, _ = solver(dynamics, cost)

    # `pistar` must unpack into exactly two tensors (gain-like K and offset-like k).
    K, k = (g + torch.randn_like(g) * noise_scale for g in pistar)
    return (K, k)

In [7]:
def test_sample_policy():
    """Sample an LQG and a policy for it, then print the shapes and names of K and k."""
    horizon = 1000
    n_state = n_ctrl = 2
    # The initial-state distribution is not needed to sample a policy.
    dynamics, cost, _init = sample_lqg(n_state, n_ctrl, horizon)
    policy = sample_policy(dynamics, cost)
    K, k = policy
    print(f"""
        K: {K.shape}; {K.names}
        k: {k.shape}; {k.names}
    """)

In [8]:
# Run the policy-sampling check; it prints the shape and dimension names of K and k.
test_sample_policy()


        K: torch.Size([1000, 2, 2]); ('H', 'R', 'C')
        k: torch.Size([1000, 2]); ('H', 'R')
    
