# Test loading checkpoints and evaluating SVG gradient estimates

### Imports

In [1]:
from __future__ import annotations

import itertools
import os.path as osp

import torch
import torch.nn as nn
import wandb
from tqdm.auto import trange

import lqsvg.experiment.utils as utils
from lqsvg.envs import lqr
from lqsvg.experiment.models import AnalyticSVG
from lqsvg.experiment.models import MonteCarloSVG
from lqsvg.experiment.policy import make_worker

In [2]:
# Evaluation settings. They are handed to wandb as the run config and read
# back through `hparams` everywhere else in the notebook.
CONFIG = {
    # Used to build the worker and to pick the model artifact.
    "n_state": 8,
    "n_ctrl": 8,
    "horizon": 1000,
    # Number of Monte Carlo SVG estimates drawn per checkpoint.
    "svg_samples": 10,
    # Argument passed to each MonteCarloSVG estimator call.
    "monte_carlo_svg": 32,
    # Last checkpoint iteration to evaluate (inclusive).
    "iterations": 200,
}

run = wandb.init(
    # Where the run is recorded.
    project="LQG-SVG",
    entity="angelovtt",
    # How the run is labelled.
    name="Eval Grads",
    job_type="eval",
    tags=["eval", utils.calver()],
    # Run behaviour.
    config=CONFIG,
    mode="online",
    reinit=False,
    allow_val_change=False,
    save_code=True,
)
hparams = run.config

# Fetch the latest model artifact matching the configured problem dimensions.
artifact = run.use_artifact(
    f"angelovtt/LQG-SVG/svg_inf-lqg{hparams.n_state}.{hparams.n_ctrl}.{hparams.horizon}:latest",
    type="model",
)
artifact_dir = artifact.download()

[34m[1mwandb[0m: Currently logged in as: [33mangelovtt[0m (use `wandb login --relogin` to force relogin)


In [3]:
print("Artifact dir:", artifact_dir)

# Build a single-environment worker with the same dimensions as the artifact.
worker = make_worker(
    {
        "n_state": hparams.n_state,
        "n_ctrl": hparams.n_ctrl,
        "horizon": hparams.horizon,
        "num_envs": 1,
    }
)

# Restore the saved MDP parameters into the worker's environment module and
# fail loudly if the load reports any mismatched keys.
key_mismatch = worker.env.module.load_state_dict(
    torch.load(osp.join(artifact_dir, "mdp.pt"))
)
assert not any(key_mismatch), key_mismatch

Artifact dir: ./artifacts/svg_inf-lqg8.8.1000:v6


  return super(Tensor, self).refine_names(names)


In [4]:
# Shorthand handles used by the evaluation helpers below.
# `mdp` is the environment module that was just restored from "mdp.pt".
mdp = worker.env.module
# `module` is the policy's module; the helpers read its `actor` and `model`
# attributes and reload its weights for every evaluated checkpoint.
module = worker.get_policy().module

In [5]:
def restore_iteration(module: nn.Module, iteration: int):
    """Load the checkpoint saved for ``iteration`` into ``module`` in place.

    The checkpoint is read from ``module-iter=<iteration>.pt`` inside the
    notebook-level ``artifact_dir``.

    Args:
        module: the module whose state dict is overwritten.
        iteration: the iteration number embedded in the checkpoint file name.

    Raises:
        AssertionError: if the result of ``load_state_dict`` reports any
            mismatched keys.
    """
    checkpoint_path = osp.join(artifact_dir, f"module-iter={iteration}.pt")
    state_dict = torch.load(checkpoint_path)
    mismatch = module.load_state_dict(state_dict)
    assert not any(mismatch), mismatch

In [6]:
def generate_svg_samples() -> list[lqr.Linear]:
    """Draw ``hparams.svg_samples`` Monte Carlo SVG estimates for the current policy.

    A single ``MonteCarloSVG`` estimator is built from the notebook-level
    ``module`` (its ``actor`` and ``model``) and called once per sample with
    ``hparams.monte_carlo_svg``.

    Returns:
        One entry per estimator call, taken from index 1 of the call's result
        (presumably the gradient, with the value estimate at index 0 -- confirm
        against ``MonteCarloSVG``).
    """
    estimator = MonteCarloSVG(module.actor, module.model)
    samples = []
    for _ in range(hparams.svg_samples):
        estimate = estimator(hparams.monte_carlo_svg)
        samples.append(estimate[1])
    return samples

In [7]:
def log_gradient_error(svg_samples: list[lqr.Linear]):
    """Log the mean cosine similarity between sampled SVGs and the AnalyticSVG gradient.

    The reference gradient is obtained from ``AnalyticSVG`` built on the
    notebook-level ``module.actor`` and ``mdp``. Each sample is compared to it
    with ``utils.linear_feedback_cossim`` and the mean is logged to the wandb
    run under ``"cossim with true grad"``.

    Args:
        svg_samples: gradient estimates to compare; must not be empty.

    Raises:
        ValueError: if ``svg_samples`` is empty, since the mean similarity is
            undefined and ``torch.stack`` would otherwise fail on an empty list.
    """
    # Check before the analytic solve so an empty input fails immediately.
    if not svg_samples:
        raise ValueError("svg_samples must contain at least one gradient estimate")

    solver = AnalyticSVG(module.actor, mdp)
    # Only the gradient is needed; the first element of the result is discarded.
    _, true_svg = solver()

    cossims = [utils.linear_feedback_cossim(g, true_svg) for g in svg_samples]
    # commit=False leaves the wandb step open; log_iteration issues the final
    # run.log call for the checkpoint without it.
    run.log({"cossim with true grad": torch.stack(cossims).mean().item()}, commit=False)

In [8]:
def log_empirical_variance(svg_samples: list[lqr.Linear]):
    """Log the mean pairwise cosine similarity among the sampled SVGs.

    Every unordered pair of distinct samples is compared with
    ``utils.linear_feedback_cossim`` and the mean is logged to the wandb run
    under ``"avg. pairwise cossim"``.

    Args:
        svg_samples: gradient estimates to compare; needs at least two entries.

    Raises:
        ValueError: if fewer than two samples are given, since there are no
            pairs to compare and ``torch.stack`` would otherwise fail on an
            empty list.
    """
    if len(svg_samples) < 2:
        raise ValueError(
            "svg_samples must contain at least two gradient estimates "
            f"to compute pairwise similarity, got {len(svg_samples)}"
        )

    # combinations(..., 2) yields each (i, j) pair with i < j exactly once, in
    # the same order as the nested index loop it replaces.
    cossims = [
        utils.linear_feedback_cossim(gi, gj)
        for gi, gj in itertools.combinations(svg_samples, 2)
    ]

    # commit=False leaves the wandb step open; log_iteration issues the final
    # run.log call for the checkpoint without it.
    run.log({"avg. pairwise cossim": torch.stack(cossims).mean().item()}, commit=False)

In [9]:
def log_iteration(iteration: int):
    """Evaluate one checkpoint and log its gradient-quality metrics.

    Restores the checkpoint for ``iteration`` into the notebook-level
    ``module``, draws a fresh batch of SVG samples, and logs both the
    similarity to the analytic gradient and the pairwise similarity among
    the samples.

    Args:
        iteration: the checkpoint iteration to evaluate.
    """
    restore_iteration(module, iteration)

    samples = generate_svg_samples()
    log_gradient_error(samples)
    log_empirical_variance(samples)

    # The two helpers above log with commit=False; this call does not,
    # presumably so all three metrics land in the same wandb step.
    run.log({"iteration": iteration})

In [10]:
# Evaluate every 10th checkpoint from iteration 0 through hparams.iterations
# inclusive, with a progress bar. Assumes checkpoints were saved at this
# cadence -- TODO confirm against the training run.
for iteration in trange(0, hparams.iterations + 1, 10, desc="Evaluating", unit="ckpt"):
    log_iteration(iteration)
    
# Close the wandb run once all checkpoints have been processed. This is not
# reached if an iteration raises.
run.finish()

Evaluating:   0%|          | 0/21 [00:00<?, ?ckpt/s]

VBox(children=(Label(value=' 0.32MB of 0.32MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
cossim with true grad,-0.05521
avg. pairwise cossim,0.59789
iteration,200.0
_runtime,987.0
_timestamp,1616166872.0
_step,20.0


0,1
cossim with true grad,▂█▇▆▅▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁
avg. pairwise cossim,▇█▇▇▆▅▅▅▄▄▄▃▃▃▃▂▂▁▁▁▁
iteration,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
_runtime,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▆▇▇▇██
_timestamp,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▆▇▇▇██
_step,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
