# Defining the LDS

In [1]:
import jax
import jax.numpy as np
import pandas as pd
import numpy as onp
import numpy.random as random
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.linalg import solve_discrete_are as dare
from jax import jit, grad
from tqdm import tqdm

In [2]:
# LDS specification: x_{t+1} = A x_t + B u_t + w_t with a double-integrator (A, B).
n, m = 2, 1                      # state dimension, control dimension
A = np.array([[1., 1.], [0., 1.]])
B = np.array([[0.], [1.]])
Q = np.eye(N=n)                  # state cost weight
R = np.eye(N=m)                  # control cost weight
x0 = np.zeros((n, 1))            # initial state
T = 1000                         # rollout horizon

# Display names of the benchmarked algorithms and their plot colours (same order).
alg_name = ['No Control', 'LQR/H2Control', 'HinfControl', 'GPC', 'BPC', 'OGRWControl']
color_code = dict(zip(alg_name, ['orange', 'blue', 'green', 'red', 'purple', 'black']))


def quad_cost(x, u):
    """Quadratic stage cost x^T Q x + u^T R u, reduced to a scalar."""
    return np.sum(x.T @ Q @ x + u.T @ R @ u)

# Func: Evaluate a given policy
def evaluate(controller, W, cost_fn):
    """Roll ``controller`` out on the module-level LDS for T steps.

    Starting from ``x0``, at each step the controller picks ``u = controller.act(x)``,
    the stage cost ``cost_fn(x, u)`` is recorded, and the state advances as
    ``x <- A x + B u + W[t]``.

    Returns:
        float32 array of the T per-step costs.
    """
    state = x0
    costs = []
    for step in range(T):
        action = controller.act(state)
        costs.append(cost_fn(state, action))
        state = A @ state + B @ action + W[step]
    return np.array(costs, dtype=np.float32)



# Controllers: No Control, LQR/H2 (stationary, finite-horizon, random-walk extension), H-inf, GPC, BPC

In [3]:
# Run zero control
class ZeroControl:
    """Baseline policy that never actuates: u_t = 0 regardless of the state."""

    def __init__(self):
        # Stateless; nothing to set up.
        return None

    def act(self, x):
        # The observed state is ignored; emit an all-zero (m, 1) control column.
        del x
        return np.zeros(shape=(m, 1))

In [4]:
# Solve H2 Control
class H2Control:
    """Infinite-horizon discrete-time LQR (H2) controller with policy u = -K x.

    P solves the discrete algebraic Riccati equation for (A, B, Q, R), and
    K = (R + B^T P B)^{-1} B^T P A.
    """

    def __init__(self, A, B, Q, R):
        riccati = dare(A, B, Q, R)
        bt_p = B.T @ riccati
        self.K = np.linalg.inv(R + bt_p @ B) @ (bt_p @ A)

    def act(self, x):
        gain = self.K
        return -gain @ x

In [None]:
# Solve the non-stationary/finite-horizon version for H2 Control
class H2ControlNonStat:
    """Finite-horizon (time-varying) LQR/H2 controller over T steps.

    The gains come from the backward Riccati recursion with terminal cost P_T = Q:

        K_t = (R + B^T P_{t+1} B)^{-1} B^T P_{t+1} A
        P_t = Q + A^T P_{t+1} A - A^T P_{t+1} B K_t

    ``act`` applies u_t = -K_t x_t and advances an internal step counter, so it
    can be called at most T times (K has exactly T entries).

    Args:
        A, B: system matrices, B of shape (n, m).
        Q, R: state / control cost weights.
        T: horizon length.
    """

    def __init__(self, A, B, Q, R, T):
        n, m = B.shape
        P = [np.zeros((n, n)) for _ in range(T + 1)]
        self.K = [np.zeros((m, n)) for _ in range(T)]
        self.t = 0
        P[T] = Q
        for t in range(T - 1, -1, -1):
            P_next = P[t + 1]
            # The step-t gain depends on the cost-to-go of the NEXT step, P_{t+1}
            # (using P_t here is an off-by-one that mistunes the final steps).
            self.K[t] = np.linalg.solve(R + B.T @ P_next @ B, B.T @ P_next @ A)
            P[t] = Q + A.T @ P_next @ A - A.T @ P_next @ B @ self.K[t]

    def act(self, x):
        u = -self.K[self.t] @ x
        self.t += 1
        return u

In [None]:
# Solve H2 Control for Random Walk
class ExtendedH2Control:
    """Finite-horizon H2 control on a state augmented with the last disturbance.

    The augmented state is z = [x; w] with dynamics
        x' = A x + w + B u,   w' = w,
    i.e. the disturbance is modelled as persisting from one step to the next.
    The surrounding comment targets random-walk disturbances. A
    ``H2ControlNonStat`` controller is solved on this augmented system with cost
    only on the x-block of z.

    The disturbance fed to the controller is reconstructed from consecutive
    states: w_{t-1} = x_t - A x_{t-1} - B u_{t-1}.
    """

    def __init__(self, A, B, Q, R, T):
        # Take the dimensions from B, not from module globals, so any (A, B) works.
        n, m = B.shape
        Aprime = onp.block([[A, np.eye(n)], [np.zeros((n, n)), np.eye(n)]])
        Bprime = onp.block([[B], [np.zeros((n, m))]])
        Qprime = onp.block([[Q, np.zeros((n, n))], [np.zeros((n, n)), np.zeros((n, n))]])
        self.A, self.B = A, B
        self.H2 = H2ControlNonStat(Aprime, Bprime, Qprime, R, T)
        # NOTE(review): the previous state starts at zeros, so the first disturbance
        # estimate equals the true initial state if that is nonzero.
        self.x, self.u = np.zeros((n, 1)), np.zeros((m, 1))

    def act(self, x):
        w_prev = x - self.A @ self.x - self.B @ self.u
        self.x = x
        self.u = self.H2.act(onp.block([[x], [w_prev]]))
        return self.u

In [None]:
# Solve Hinf Control
class HinfControl:
    """Finite-horizon H-infinity (minimax) state-feedback controller over T steps.

    Backward recursion with terminal value P_T = Q and S = B R^{-1} B^T - gamma^2 I:

        Lambda_t = I + S P_{t+1}
        K_t      = R^{-1} B^T P_{t+1} Lambda_t^{-1} A      (control gain, u_t = -K_t x_t)
        W_t      = gamma^2 P_{t+1} Lambda_t^{-1} A          (worst-case disturbance gain)
        P_t      = Q + A^T P_{t+1} Lambda_t^{-1} A

    With gamma = 0 this reduces to the finite-horizon LQR recursion.
    NOTE(review): ``gamma**2`` sits where textbook forms put gamma^{-2}, so
    ``gamma`` here acts like an inverse attenuation level. Confirm this convention.

    ``act`` can be called at most T times (K has exactly T entries).
    """

    def __init__(self, A, B, Q, R, T, gamma):
        n, m = B.shape
        P = [np.zeros((n, n)) for _ in range(T + 1)]
        self.K = [np.zeros((m, n)) for _ in range(T)]
        self.W = [np.zeros((n, n)) for _ in range(T)]
        self.t = 0
        P[T] = Q
        S = B @ np.linalg.solve(R, B.T) - gamma**2 * np.eye(n)
        for t in range(T - 1, -1, -1):
            P_next = P[t + 1]
            Lambda = np.eye(n) + S @ P_next
            # P_{t+1} Lambda^{-1} A. This equals (P_{t+1}^{-1} + S)^{-1} A without
            # inverting P_{t+1}, which would fail for a singular Q.
            PLA = P_next @ np.linalg.solve(Lambda, A)
            self.K[t] = np.linalg.solve(R, B.T @ PLA)
            self.W[t] = gamma**2 * PLA
            P[t] = Q + A.T @ PLA

    def act(self, x):
        # Negative feedback: with gamma = 0, K equals the LQR gain, which is applied as -K x.
        u = -self.K[self.t] @ x
        self.t += 1
        return u

In [None]:
# GPC definition
class GPC:
    """Gradient Perturbation Controller.

    Policy: u_t = -K x_t + sum_i E[i] W[-M + i], where K is the stationary H2 gain
    and E of shape (M, m, n) weights the last M estimated disturbances.

    ``self.W`` is a FIFO buffer of the last H + M disturbance estimates (newest
    last), each w_{t-1} = x_t - A x_{t-1} - B u_{t-1}. Each step, E takes one
    gradient step on a counterfactual loss. The loss rolls the system forward H
    steps from y = 0 under the current E and the buffered disturbances, then
    charges ``cost_fn`` at the final state and action.

    Args:
        A, B, Q, R: LDS matrices and cost weights (Q, R only used for K).
        x0: initial state.
        M: disturbance memory length.
        H: counterfactual rollout length.
        lr: gradient step size.
        cost_fn: differentiable stage cost ``cost_fn(x, u)``.
    """

    def __init__(self, A, B, Q, R, x0, M, H, lr, cost_fn):
        n, m = B.shape
        self.lr, self.A, self.B, self.M = lr, A, B, M
        self.x, self.u, self.off, self.t = x0, np.zeros((m, 1)), np.zeros((m, 1)), 1
        self.K, self.E, self.W = H2Control(A, B, Q, R).K, np.zeros((M, m, n)), np.zeros((H + M, n, 1))

        def counterfact_loss(E, W):
            y = np.zeros((n, 1))
            # Step h plays the action built from window W[h:h+M] and then absorbs W[h+M].
            # H steps consume every buffered disturbance, including the newest.
            for h in range(H):
                v = -self.K @ y + np.tensordot(E, W[h : h + M], axes = ([0, 2], [0, 1]))
                y = A @ y + B @ v + W[h + M]
            # Final action uses W[H:H+M] == W[-M:], the same window the real policy uses.
            # (The stale loop index `h` was used here before, which repeated a window
            # and raised NameError for H == 1.)
            v = -self.K @ y + np.tensordot(E, W[H : H + M], axes = ([0, 2], [0, 1]))
            return cost_fn(y, v)

        self.grad = jit(grad(counterfact_loss))

    def act(self, x):
        # 1. Append the newest disturbance estimate and drop the oldest.
        #    (Replaces jax.ops.index_update + roll; that API was removed from JAX.)
        w = x - self.A @ self.x - self.B @ self.u
        self.W = np.concatenate([self.W[1:], w[None]], axis = 0)

        # 2. Get gradients of the counterfactual loss w.r.t. E
        delta_E = self.grad(self.E, self.W)

        # 3. Execute updates
        self.E -= self.lr * delta_E

        # 4. Update x & t and get action
        self.x, self.t = x, self.t + 1
        self.u = -self.K @ x + np.tensordot(self.E, self.W[-self.M:], axes = ([0, 2], [0, 1]))
        return self.u

In [None]:
# BPC definition
class BPC:
    """Perturbation controller that learns from observed cost values only.

    This is presumably the bandit perturbation controller. Its gradient estimate
    uses only ``cost_fn`` values, never derivatives.

    Policy: u_t = -K x_t + sum_i (E + delta * eps[-1])[i] W[-M + i], where K is the
    stationary H2 gain and E of shape (M, m, n) weights the last M disturbances.

    ``self.eps`` holds M random perturbation slices, each shaped like E, with a
    total Frobenius norm kept at (about) 1. The gradient estimate is the previous
    step's cost times the sum of those slices. After each step, E is projected onto
    the Frobenius ball of radius 1 - delta.

    Args:
        A, B, Q, R: LDS matrices and cost weights (Q, R only used for K).
        x0: initial state.
        M: disturbance memory length.
        H: accepted for signature parity with GPC; not used by this class.
        lr: step size.
        delta: perturbation scale (also shrinks the projection radius).
        cost_fn: stage cost ``cost_fn(x, u)``.
    """

    def __init__(self, A, B, Q, R, x0, M, H, lr, delta, cost_fn):
        n, m = B.shape
        self.n, self.m = n, m
        self.lr, self.A, self.B, self.M = lr, A, B, M
        self.x, self.u, self.delta, self.t = x0, np.zeros((m, 1)), delta, 0
        self.K, self.E, self.W = H2Control(A, B, Q, R).K, np.zeros((M, m, n)), np.zeros((M, n, 1))
        self.cost_fn = cost_fn
        self.off = np.zeros((m, 1))

        def _generate_uniform(shape, norm=1.00):
            # Gaussian sample rescaled to Frobenius norm `norm`: a uniformly random direction.
            v = random.normal(size=shape)
            v = norm * v / np.linalg.norm(v)
            return v
        self._generate_uniform = _generate_uniform
        self.eps = self._generate_uniform((M, M, m, n))

        self.eps_off = self._generate_uniform((M, m, 1))

    def act(self, x):
        # 1. Append the newest disturbance estimate and drop the oldest.
        #    (Replaces jax.ops.index_update + roll; that API was removed from JAX.)
        w = x - self.A @ self.x - self.B @ self.u
        self.W = np.concatenate([self.W[1:], w[None]], axis = 0)

        # 2. Get gradient estimates from the cost incurred at the previous step
        delta_E = self.cost_fn(self.x, self.u) * np.sum(self.eps, axis = 0)

        # 3. Execute updates
        self.E -= self.lr * delta_E

        # 3b. Project E back onto the ball of radius 1 - delta
        norm = np.linalg.norm(self.E)
        if norm > (1 - self.delta):
            self.E *= (1 - self.delta) / norm

        # 4. Replace the oldest perturbation slice so the total norm of eps stays 1.
        #    Clamp at 0: float rounding can push ||eps[1:]||^2 slightly above 1,
        #    and sqrt of a negative would poison E with NaN forever.
        remaining = 1 - np.linalg.norm(self.eps[1:]) ** 2
        fresh = self._generate_uniform(
            shape = (self.M, self.m, self.n), norm = np.sqrt(np.maximum(remaining, 0.)))
        self.eps = np.concatenate([self.eps[1:], fresh[None]], axis = 0)

        # 5. Update x & t and get (perturbed) action
        self.x, self.t = x, self.t + 1
        self.u = -self.K @ x + np.tensordot(self.E + self.delta * self.eps[-1],
                                            self.W[-self.M:], axes = ([0, 2], [0, 1]))
        return self.u

# Plot & repeat utils

In [None]:
def benchmark(M, W, cost_fn = quad_cost, lr = 0.001, delta = 0.001, no_control = False,
              gamma = None, grw = False, H = 3):
    """Run each controller once on disturbance sequence ``W`` and return per-step costs.

    Uses the module-level LDS (A, B, Q, R, x0, T). A controller that is switched
    off gets an all-NaN float32 placeholder of length T, so the returned tuple
    always lines up with ``alg_name``.

    Args:
        M: disturbance memory length for GPC / BPC.
        W: disturbance sequence indexed by time step.
        cost_fn: stage cost ``cost_fn(x, u)``.
        lr, delta: GPC/BPC step size and BPC perturbation scale.
        no_control: also run the zero-control baseline.
        gamma: H-inf parameter. H-inf runs whenever this is not None (gamma = 0
            is a valid setting and is no longer silently skipped).
        grw: also run the random-walk-augmented H2 controller.
        H: GPC counterfactual rollout length (BPC accepts it but does not use it).

    Returns:
        (zero, h2, hinf, gpc, bpc, ogrw) per-step cost arrays.
    """
    def skipped():
        # Placeholder for a disabled controller; every entry is NaN.
        return onp.full(T, onp.nan, dtype=onp.float32)

    loss_zero = evaluate(ZeroControl(), W, cost_fn) if no_control else skipped()
    loss_h2 = evaluate(H2Control(A, B, Q, R), W, cost_fn)
    loss_hinf = evaluate(HinfControl(A, B, Q, R, T, gamma), W, cost_fn) if gamma is not None else skipped()
    loss_ogrw = evaluate(ExtendedH2Control(A, B, Q, R, T), W, cost_fn) if grw else skipped()

    loss_gpc = evaluate(GPC(A, B, Q, R, x0, M, H, lr, cost_fn), W, cost_fn)
    loss_bpc = evaluate(BPC(A, B, Q, R, x0, M, H, lr, delta, cost_fn), W, cost_fn)

    return loss_zero, loss_h2, loss_hinf, loss_gpc, loss_bpc, loss_ogrw

In [None]:
def cummean(x):
    """Running mean of ``x``: element t is mean(x[0..t]). Works for any length."""
    return np.cumsum(x) / (np.arange(len(x)) + 1)

def to_dataframe(alg, loss, avg_loss):
    """Long-format DataFrame of one algorithm's cost curves.

    Columns: 'Algorithm' (constant ``alg``), 'Time' (0..len(loss)-1 as float32),
    'Instantaneous Cost' (``loss``) and 'Average Cost' (``avg_loss``). The length
    comes from ``loss`` rather than the global T, and JAX arrays are converted
    to NumPy explicitly before they reach pandas.
    """
    return pd.DataFrame(data = {'Algorithm': alg,
                                'Time': onp.arange(len(loss), dtype=onp.float32),
                                'Instantaneous Cost': onp.asarray(loss),
                                'Average Cost': onp.asarray(avg_loss)})

def repeat_benchmark(M, Wgen, rep, cost_fn=quad_cost, lr=0.001,
                     delta=0.001, no_control=False, gamma=None, grw=False):
    """Run ``benchmark`` ``rep`` times, each on a fresh disturbance ``Wgen()``.

    Returns one long-format DataFrame holding every run and algorithm.
    Rows whose instantaneous cost is NaN (disabled controllers) are dropped.
    """
    per_run = []
    for _ in range(rep):
        losses = benchmark(M, Wgen(), cost_fn, lr, delta, no_control, gamma, grw)
        run_frames = [to_dataframe(name, curve, cummean(curve))
                      for name, curve in zip(alg_name, losses)]
        per_run.append(pd.concat(run_frames))
    combined = pd.concat(per_run)
    keep = combined['Instantaneous Cost'].notnull()
    return combined[keep]

In [None]:
def plot(title, data, scale='linear'):
    """Side-by-side line plots: instantaneous cost (left), running-average cost (right).

    One line per 'Algorithm', coloured by ``color_code``, with a +/- one standard
    deviation band across repetitions. Both y-axes use ``scale``.
    """
    _, axs = plt.subplots(ncols=2, figsize=(15, 4))
    for ax, metric in zip(axs, ('Instantaneous Cost', 'Average Cost')):
        ax.set_yscale(scale)
        curve = sns.lineplot(x='Time', y=metric, hue='Algorithm',
                             data=data, ax=ax, ci='sd', palette=color_code)
        curve.set_title(title)

# Experiments

In [None]:
# Sine perturbations: sin(t / (2*pi)) broadcast to the state coordinates via ones((m, n)).
Wgen = lambda: (np.sin(np.arange(T*m)/(2*np.pi)).reshape(T,m) @ np.ones((m, n))).reshape(T, n, 1)
quad_cost = lambda x, u: np.sum(x.T @ Q @ x + u.T @ R @ u)

# Time steps & Number of seeds/repetitions to test each method on!
T = 1000
rep = 25
for M in [3, 6]:
    for lr in [0.007, 0.003, 0.001]:
        for delta in [0.5, 0.3, 0.1, 0.05, 0.01]:
            print(f"running M = {M}, lr = {lr}, delta = {delta}")
            data = repeat_benchmark(M, Wgen, rep=rep, cost_fn=quad_cost, lr=lr, delta=delta)
            plot('Sinusoidal Perturbations', data)
            specs = f"{T}_{M}_{lr}_{delta}"
            plt.savefig("sin_quad_" + specs + ".pdf")
            # Release the figure. Otherwise all 30 sweep figures stay open in memory
            # and matplotlib warns once more than 20 are open.
            plt.close()

running M = 3, lr = 0.007, delta = 0.5
running M = 3, lr = 0.007, delta = 0.3
running M = 3, lr = 0.007, delta = 0.1
running M = 3, lr = 0.007, delta = 0.05
running M = 3, lr = 0.007, delta = 0.01
running M = 3, lr = 0.003, delta = 0.5


In [None]:
""" # DONE!
# gaussian random walk
def Wgen():
    W = random.normal(size = (T, n, 1), scale = 1/T**(0.5))
    for i in range(1, T):
        W[i] = W[i] + W[i-1]
    return W
    
T = 1000
for M in [3, 6, 10]:
  for lr in [0.007, 0.003, 0.001]:
    for delta in [0.05, 0.03, 0.01, 0.005, 0.001]: # gaussian random walk requires smaller deltas
        data = repeat_benchmark(M, Wgen, lr = lr, delta = delta)
        plot('Gaussian Random Walk Perturbations', data)
        specs = str(T) + "_" + str(M) + "_" + str(lr) + "_" + str(delta)
        plt.savefig("random_walk_quad_" + specs + ".pdf") 
"""

In [None]:
# Defining non-quadratic hinge loss with sine noise
Wgen = lambda: (np.sin(np.arange(T*m)/(2*np.pi)).reshape(T,m) @ np.ones((m, n))).reshape(T, n, 1)
# NOTE: despite its name, this is the L1 cost ||x||_1 + ||u||_1 (there is no max(0, .) hinge term).
hinge_loss = lambda x, u: np.sum(np.abs(x)) + np.sum(np.abs(u))

T = 1000
rep = 25
for M in [3, 6, 10]:
    for lr in [0.007, 0.003, 0.001]:
        for delta in [0.5, 0.3, 0.1, 0.05, 0.01]:
            data = repeat_benchmark(M, Wgen, rep=rep, cost_fn=hinge_loss, lr=lr, delta=delta)
            plot('Sinusoidal Perturbations - Hinge Loss', data)
            specs = f"{T}_{M}_{lr}_{delta}"
            plt.savefig("sin_hinge_" + specs + ".pdf")
            # Release the figure; otherwise all 45 sweep figures stay open in memory.
            plt.close()