In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to the local project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import numpy as np
import scipy.linalg as spl
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch

import lstd
import test_utils
import trajgen
import trajdata
import icnn
import valuefunc

## TrajDataset class

In [8]:
# Problem sizes. Q is p x p and R is q x q, and both are passed to the DARE
# solver below as the state and input cost matrices, so p is the state
# dimension and q the input dimension of the (A, B) system.
p, q = 5, 3
Q = np.eye(p)
R = np.eye(q)

# Scalars defined here for use by later cells; neither is consumed in this cell.
# NOTE(review): presumably sigma is a noise level and gamma a discount
# factor -- confirm against the helpers that receive them.
sigma, gamma = 0.1, 0.99

# Draw an environment from the project helper (Anorm=1.05 is forwarded as-is;
# its exact meaning is defined in test_utils -- TODO confirm).
A, B = test_utils.random_env(p, q, Anorm=1.05)

# Reference solution: Pstar solves the discrete-time algebraic Riccati
# equation for (A, B, Q, R); the gain is Kstar = -(B' P B + R)^+ B' P A.
Pstar = spl.solve_discrete_are(A, B, Q, R)
Kstar = -np.linalg.pinv(B.T @ Pstar @ B + R) @ (
    B.T @ Pstar @ A
)

In [9]:
# Sample num_trajs trajectories (T passed as the length argument) under the
# controller built from Kstar.
num_trajs, T = 3000, 10
ctrl = test_utils.linear_feedback_controller(Kstar)

# NOTE(review): sigma=0 is hard-coded here although `sigma = 0.1` is set in
# the setup cell -- confirm the rollouts are meant to ignore that value.
xtraj, utraj, rtraj, xtraj_ = test_utils.sample_multiple_traj(
    A, B, Q, R, ctrl, T,
    x0=None,
    sigma=0,
    num_traj=num_trajs,
)

# Wrap the four sampled outputs in the project's dataset type.
dataset = trajdata.TrajDataset(xtraj, utraj, rtraj, xtraj_)

In [44]:
# Build an ICNN-based value function and fit it on the sampled dataset.
# Constructor arguments are p and the list [50, 50, 1]
# (presumably input size and layer widths -- TODO confirm against valuefunc).
icnnvalue = valuefunc.ICNNValueFunc(p, [50, 50, 1])

# Training prints the loss every `print_interval` epochs when verbose is set
# (see the captured output below this cell).
icnnvalue.learn(
    dataset,
    gamma,
    num_epoch=400,
    batch_size=64,
    verbose=True,
    print_interval=50,
)

Epoch: 1 	 Training loss: 16733.836750825434
Epoch: 51 	 Training loss: 0.06111288143204297
Epoch: 101 	 Training loss: 0.09405480843728838
Epoch: 151 	 Training loss: 0.008599661122318776
Epoch: 201 	 Training loss: 0.028353810501377977
Epoch: 251 	 Training loss: 0.008675627033403204
Epoch: 301 	 Training loss: 0.0037263999183238146
Epoch: 351 	 Training loss: 0.0065733710315347285


## Vectorize LSTD policy evaluation

In [297]:
# Setup for the tracking experiment.
# NOTE: this cell rebinds p, q, sigma, A, B, Q, R, Pstar and Kstar from the
# first section, so earlier cells should not be re-run after it without
# restarting from the top.
p, q = 2, 1
sigma = 0
A, B = test_utils.random_env(p, q, Anorm=0.99)
Q = 100 * np.eye(p)
R = np.eye(q)
Tref = 20

# Build the augmented system used for tracking from the nominal one.
At, Bt, Qt, Rt = lstd.nominal_to_tracking(A, B, Q, R, Tref)

# Static tracking controller: solve the DARE directly on the augmented
# system, then form the gain Kstar = -(Bt' P Bt + Rt)^+ Bt' P At.
Pstar = spl.solve_discrete_are(At, Bt, Qt, Rt)
Kstar = -np.linalg.pinv(Bt.T @ Pstar @ Bt + Rt) @ (
    Bt.T @ Pstar @ At
)
static_opt_ctrl = test_utils.linear_feedback_controller(Kstar)

In [303]:
# Collect data on the augmented tracking system. The controller here comes
# from test_utils.random_controller(q), not from Kstar.
num_trajs, T = 3000, 5
ctrl = test_utils.random_controller(q)

xtraj, utraj, rtraj, xtraj_ = test_utils.sample_multiple_traj(
    At, Bt, Qt, Rt, ctrl, T,
    x0=None,
    sigma=0,
    num_traj=num_trajs,
)

In [304]:
# LSTD evaluation of the gain Kstar from the sampled trajectories.
# `gamma` is not defined in this section; it is the value set in the first
# section's setup cell (0.99), so that cell must have run first.
# NOTE(review): Pxu / Px are presumably the state-action and state value
# matrices -- confirm against lstd.evaluate.
Pxu, Px = lstd.evaluate(
    xtraj, utraj, rtraj, xtraj_,
    Kstar,
    gamma,
    sigma=0,
)

In [305]:
# Compare the LSTD estimate Px against the DARE solution Pstar; the value of
# this bare expression is what the cell displays.
# NOTE(review): Pstar was obtained from solve_discrete_are with no discount,
# while lstd.evaluate was given gamma (0.99 from the first section) --
# confirm the two quantities are meant to be directly comparable.
test_utils.relerr(Pstar, Px)

0.07014251404796701