# Biased SVG estimation

In [None]:
from __future__ import annotations

import textwrap
from typing import Callable

import lqsvg.torch.named as nt
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ray
import seaborn as sns
import torch
import torch.nn as nn
from lqsvg.envs import lqr
from lqsvg.envs.lqr.generators import LQGGenerator
from lqsvg.envs.lqr.modules import (
    InitStateDynamics,
    LinearDynamicsModule,
    LQGModule,
    QuadraticReward,
)
from lqsvg.experiment.estimators import AnalyticSVG, BootstrappedSVG, MonteCarloSVG
from lqsvg.experiment.plot import default_figsize, plot_surface
from lqsvg.np_util import RNG
from lqsvg.policy.modules import TVLinearPolicy
from tqdm.auto import tqdm, trange

In [5]:
# Notebook-wide setup: default seaborn plot styling and a local Ray runtime.
sns.set()
# A second plain ray.init() in the same process raises RuntimeError, which
# breaks re-running this cell; ignore_reinit_error makes the call idempotent.
ray.init(ignore_reinit_error=True)

In [2]:
def make_generator(seed: int) -> LQGGenerator:
    """Return the LQG generator configuration shared by this notebook.

    Settings: ``n_state=2``, ``n_ctrl=2``, ``horizon=20``, ``stationary=True``,
    ``passive_eigval_range=(0.5, 1.5)``, ``controllable=True``. ``seed`` is
    forwarded to ``LQGGenerator`` unchanged.
    """
    config = dict(
        n_state=2,
        n_ctrl=2,
        horizon=20,
        stationary=True,
        passive_eigval_range=(0.5, 1.5),
        controllable=True,
    )
    return LQGGenerator(seed=seed, **config)

In [3]:
def make_lqg_module(generator: LQGGenerator) -> LQGModule:
    """Sample one LQG from ``generator`` and wrap it in an ``LQGModule``."""
    dyn, cst, initial = generator()
    module = LQGModule.from_existing(dyn, cst, initial)
    return module

In [6]:
def make_stabilizing_policy(module: LQGModule, rng: RNG) -> TVLinearPolicy:
    """Create a ``TVLinearPolicy`` sized for ``module`` and stabilize it in place.

    Stabilization uses the standard form of ``module.trans`` and the given ``rng``.
    """
    dims = (module.n_state, module.n_ctrl, module.horizon)
    stabilized = TVLinearPolicy(*dims)
    stabilized.stabilize_(module.trans.standard_form(), rng=rng)
    return stabilized

In [18]:
def grad_estimates(
    estimator,
    sample_sizes: list[int],
    estimates_per_sample_size: int = 10,
    pbar: bool = False,
) -> list[list[lqr.Linear]]:
    """Collect repeated SVG estimates for each sample size.

    Args:
        estimator: callable accepting ``samples=<int>`` and returning a pair
            whose second element is kept (the SVG estimate).
        sample_sizes: sample sizes to evaluate, in order.
        estimates_per_sample_size: how many independent calls per size.
        pbar: show a tqdm progress bar over ``sample_sizes`` when True.

    Returns:
        One inner list per entry of ``sample_sizes``, each holding
        ``estimates_per_sample_size`` estimates.
    """
    sizes = tqdm(
        sample_sizes, desc="Computing SVG by sample size", leave=False, disable=not pbar
    )
    return [
        [estimator(samples=n)[1] for _ in range(estimates_per_sample_size)]
        for n in sizes
    ]

## Unbiased estimator (DPG theorem)

In [25]:
class DPGEstimator(nn.Module):
    """SVG estimator wrapping ``BootstrappedSVG`` with a quadratic Q-function.

    The Q-function is built from ``policy`` and ``model`` via
    ``QuadQValue.from_policy``. Starting observations are drawn uniformly, with
    replacement, from ``state_dataset``.

    Attributes:
        state_dataset: starts as ``None``; must be assigned a tensor with a
            named ``"B"`` dimension before ``surrogate``/``forward`` are used.
        n_steps: forwarded as ``n_steps`` to the wrapped estimator.
    """

    def __init__(self, policy: TVLinearPolicy, model: LQGModule):
        super().__init__()
        self.policy = policy
        self.model = model
        self.qvalue = self.make_qvalue(policy, model)
        self.estimator = BootstrappedSVG(policy, model.trans, model.reward, self.qvalue)
        # Filled in by the caller before sampling (see sample_starting_obs).
        self.state_dataset = None
        self.n_steps: int = 0

    def sample_starting_obs(self, samples: int) -> Tensor:
        """Draw ``samples`` observations uniformly, with replacement, from the dataset.

        Raises:
            RuntimeError: if ``state_dataset`` has not been assigned yet.
        """
        if self.state_dataset is None:
            # Without this guard the failure is an opaque AttributeError on None.
            raise RuntimeError(
                "DPGEstimator.state_dataset is unset; assign a batch of starting "
                "states (with a 'B' dimension) before sampling."
            )
        idxs = torch.randint(low=0, high=self.state_dataset.size("B"), size=(samples,))
        return nt.index_select(self.state_dataset, dim="B", index=idxs)

    def surrogate(self, samples: int = 1) -> Tensor:
        """Return the wrapped estimator's surrogate on freshly sampled observations."""
        obs = self.sample_starting_obs(samples)
        return self.estimator.surrogate(obs, n_steps=self.n_steps)

    def forward(self, samples: int = 1) -> tuple[Tensor, lqr.Linear]:
        """Run the wrapped estimator on ``samples`` sampled starting observations.

        Returns whatever ``BootstrappedSVG`` returns; annotated as a pair whose
        second element is the SVG (``lqr.Linear``).
        """
        obs = self.sample_starting_obs(samples)
        return self.estimator(obs, n_steps=self.n_steps)

    @staticmethod
    def make_qvalue(policy: TVLinearPolicy, model: LQGModule) -> QuadQValue:
        """Build the quadratic Q-function of ``policy`` under ``model``.

        Passes the standard forms of the policy, transition dynamics and reward
        to ``QuadQValue.from_policy``.
        """
        # QuadQValue is not among this notebook's imports, so the bare name
        # raised NameError here. NOTE(review): module path assumed -- confirm.
        from lqsvg.policy.modules import QuadQValue

        return QuadQValue.from_policy(
            policy.standard_form(),
            model.trans.standard_form(),
            model.reward.standard_form(),
        )

In [32]:
class EstimatorComparison:
    """One generated LQG instance with a stabilized policy and a DPG estimator.

    Args:
        seed: passed to ``make_generator`` and as ``rng`` to ``stabilize_``.
        total_states: requested size of the starting-state dataset; only whole
            rollouts are collected (``total_states // horizon`` of them).
    """

    def __init__(self, seed: int, total_states: int = 1000):
        self.generator = make_generator(seed)
        dynamics, cost, init = self.generator()
        self.model = LQGModule.from_existing(dynamics, cost, init)
        self.policy = TVLinearPolicy(
            self.model.n_state, self.model.n_ctrl, self.model.horizon
        )
        self.policy.stabilize_(dynamics, rng=seed)

        self.estimator = DPGEstimator(self.policy, self.model)
        # The estimator builds its own Q-function; expose it here for convenience.
        self.qvalue = self.estimator.qvalue
        self.state_dataset = None
        self.setup(total_states)

    def setup(self, total_states: int = 1000):
        """(Re)sample the starting-state dataset and hand it to the estimator."""
        self.state_dataset = self.starting_states(self.policy, self.model, total_states)
        self.estimator.state_dataset = self.state_dataset

    @staticmethod
    @torch.no_grad()
    def starting_states(policy: TVLinearPolicy, model: LQGModule, num: int) -> Tensor:
        """Collect observations from ``num // model.horizon`` on-policy rollouts.

        The rollouts' ``"H"`` and ``"B1"`` dimensions are flattened into a single
        ``"B"`` dimension. Runs without gradient tracking.

        Raises:
            ValueError: if ``num`` is smaller than the horizon, which would
                otherwise yield an empty dataset that cannot be sampled from.
        """
        n_trajs = num // model.horizon
        if n_trajs < 1:
            raise ValueError(
                f"need at least model.horizon={model.horizon} states, got num={num}"
            )
        rollout_module = MonteCarloSVG(policy, model)
        obs, _, _, _, _ = rollout_module.rsample_trajectory(torch.Size([n_trajs]))
        return obs.flatten(["H", "B1"], "B")

    def plot_optim_surface(self, samples: int, n_step: int):
        """Plot policy return over the plane spanned by a random and the SVG direction.

        Side effect: sets ``self.estimator.n_steps = n_step``, which persists
        after the call.
        """
        # NOTE(review): `tutil` and `analysis` are not among this notebook's
        # imports; bring them into scope before calling this method.
        self.estimator.n_steps = n_step
        _, svg = self.estimator(samples)
        direction = tutil.tensors_to_vector(svg).numpy()

        plt.figure(figsize=default_figsize(2, 2))
        X, Y, Z = analysis.optimization_surface(
            self.delta_to_return(),
            direction=direction,
            max_scaling=3.0,
            steps=20,
            rng=self.generator._rng,
        )
        ax = plot_surface(X, Y, Z, invert_xaxis=True)
        ax.set_xlabel("Random direction")
        ax.set_ylabel("SVG direction")
        ax.set_zlabel("Policy return")
        plt.show()

    def delta_to_return(self) -> Callable[[np.ndarray], np.ndarray]:
        """Return ``analysis.delta_to_return`` bound to the current policy and model.

        Presumably maps a flat parameter perturbation to the policy's return --
        TODO confirm against ``analysis.delta_to_return``.
        """
        policy = self.policy.standard_form()
        dynamics, cost, init = self.model.standard_form()
        return analysis.delta_to_return(policy, dynamics, cost, init)

## Biased estimator (MAAC)