In [1]:
import tigercontrol
import jax
import jax.numpy as np
from tigercontrol.controllers import Controller



In [2]:
# Quadratic Loss
def quad_loss(x, u, Q = None, R = None):
    """
    Description: Quadratic state/action cost x^T Q x + u^T R u, reduced to a scalar
        with np.sum. When Q (resp. R) is omitted, the plain inner product x^T x
        (resp. u^T u) is used.
    Args:
        x: state column vector
        u: action column vector
        Q (optional): state weight matrix
        R (optional): action weight matrix
    """
    if Q is None:
        state_term = x.T @ x
    else:
        state_term = x.T @ Q @ x

    if R is None:
        action_term = u.T @ u
    else:
        action_term = u.T @ R @ u

    return np.sum(state_term + action_term)

# Policy Loss
def policy_loss(params, determine_action, w, look_back, env, cost_fn = quad_loss):
    """
    Description: Counterfactual loss of a perturbation policy. Starting from the
        zero state, replays the last `look_back` entries of the disturbance history
        `w`. Each action is chosen by `determine_action` from the disturbances that
        precede the one about to be applied. Returns the cost of the final state
        paired with the action chosen there; the state is not advanced after that
        final action.
    Args:
        params: policy parameters, forwarded unchanged to determine_action
        determine_action (callable): (params, state, noise_history) -> action
        w: disturbance history, oldest first
        look_back (int): number of rollout steps
        env: object exposing `n` (state dimension) and `dyn(x, u)` (noise-free dynamics)
        cost_fn (callable): cost_fn(state, action) -> scalar
    Returns:
        the value of cost_fn at the end of the rollout
    """
    state = np.zeros((env.n, 1))
    for offset in reversed(range(1, look_back + 1)):
        # Only the disturbances strictly before w[-offset] are visible when acting.
        action = determine_action(params, state, w[:-offset])
        state = env.dyn(state, action) + w[-offset]

    # Final action is scored but the state is not advanced.
    final_action = determine_action(params, state, w)
    return cost_fn(state, final_action)

# Action Loss
def action_loss(actions, w, look_back, env, cost_fn = quad_loss):
    """
    Description: Cost of replaying a fixed action sequence from the zero state.
        Applies actions[-look_back], ..., actions[-2] in order. After applying
        actions[-h] it adds the disturbance w[-h-1]. It then returns the cost of
        the resulting state paired with the last action actions[-1], which is
        scored but not applied.
    Args:
        actions: action sequence, oldest first
        w: disturbance history, oldest first
        look_back (int): number of actions considered (look_back - 1 transitions)
        env: object exposing `n` (state dimension) and `dyn(x, u)`
        cost_fn (callable): cost_fn(state, action) -> scalar
    NOTE(review): actions[-h] is paired with w[-h-1] rather than w[-h] (compare
        policy_loss). Confirm this offset against how callers append to `w` and
        `actions`.
    """
    state = np.zeros((env.n, 1))
    for h in reversed(range(2, look_back + 1)):
        state = env.dyn(state, actions[-h]) + w[-h - 1]

    # Last action is scored without advancing the state.
    return cost_fn(state, actions[-1])

In [3]:
# ---------- History Updates ----------

def update_noise(w, x, u, env, x_prev = None):
    """
    Description: Append the newest disturbance to a fixed-length noise history.
        The disturbance is w_t = x_{t+1} - dyn(x_t, u_t), i.e. the part of the
        observed state that the noise-free dynamics do not explain. The oldest
        entry is dropped, so the result has the same shape as `w` and holds the
        new disturbance at index -1.
    Args:
        w: noise history, oldest first (typically shape (L, n, 1))
        x: newly observed state x_{t+1}
        u: action u_t that led to x
        env: object exposing `dyn(x, u)` (noise-free dynamics)
        x_prev (optional): state x_t at which u was applied. Callers should pass
            it. When omitted, the legacy formula x - dyn(x, u) is kept for backward
            compatibility. That formula evaluates the dynamics at the *new* state
            and therefore does not recover the true disturbance.
    Returns:
        the updated noise history
    """
    base_state = x if x_prev is None else x_prev
    noise = x - env.dyn(base_state, u)
    # Drop the oldest entry and append the newest. This matches writing index 0
    # and rolling by -1, without relying on the removed `jax.ops.index_update`.
    return np.concatenate([w[1:], noise[None]], axis = 0)

In [44]:
class env():
    """
    Description: Small linear dynamical system used to exercise the controllers
        below: a 2-state, 1-input double integrator
            x_{t+1} = A x_t + B u_t + w_t,  A = [[1, 1], [0, 1]],  B = [[0], [1]],
        driven by a deterministic periodic disturbance w_t = (t mod 7) / 100 that
        is added to every state coordinate.
    """
    def __init__(self):
        self.n, self.m = 2, 1  # state and action dimensions
        self.A = np.array([[1., 1.], [0., 1.]])
        self.B = np.array([[0.], [1.]])
        self.x = np.zeros((self.n, 1))
        # Noise-free dynamics; controllers use it to back out disturbances.
        self.dyn = lambda x, u: self.A @ x + self.B @ u
        self.t = 0

    def step(self, u):
        """
        Description: Advance the system one step with action u and return the new state.
        Args:
            u: action column vector of shape (m, 1)
        Returns:
            the new state, which is also stored in self.x
        """
        self.t += 1
        # True division: `self.t % 7 // 100` was always 0, silently disabling the disturbance.
        disturbance = (self.t % 7) / 100
        # Persist the state; previously self.x stayed at zero forever.
        self.x = self.A @ self.x + self.B @ u + disturbance
        return self.x

In [45]:
# This cell rebinds the name `env` from the class to an instance. On a re-run,
# fall back to the instance's class; otherwise it fails with
# "TypeError: 'env' object is not callable". Re-running yields a fresh environment.
env = (env if isinstance(env, type) else type(env))()

In [48]:

"""
Gradient Perturbation Controller
"""

import tigercontrol
from tigercontrol.controllers import Controller
from tigercontrol.controllers import LQR
from tigercontrol.controllers.boosting.core import quad_loss, policy_loss

import jax
import jax.numpy as np
from jax import grad

# GPC definition
class GPC_v2(Controller):
    """
    Description: Gradient Perturbation Controller. Plays a fixed linear policy -K x
        plus a learned linear combination of the last H observed disturbances.
        The combination weights (`params`) are adapted by gradient descent on the
        rollout loss `policy_loss`.
    """

    def __init__(self, env, K = None, H = 3, look_back = 3, cost_fn = quad_loss, lr = 0.001):
        """
        Description: Initialize the dynamics of the model
        Args:
            env (object): environment exposing n, m and dyn(x, u)
            K (float/numpy.ndarray): Starting policy (optional); defaults to zeros of shape (m, n)
            H (postive int): history of the controller
            look_back (positive int): history (rollout) of the system
            cost_fn (function): default cost function used by update_params
            lr (float/numpy.ndarray): learning rate(s)
        """

        self.n, self.m = env.n, env.m # State & Action Dimensions
        self.env = env
        self.cost_fn = cost_fn # previously accepted but silently ignored

        self.t = 1 # Update counter (the learning rate is currently constant)
        self.lr, self.H = lr, H # Model Hyperparameters
        self.look_back = look_back

        # Model Parameters: initial linear policy / perturbation contributions
        self.K = np.zeros((self.m, self.n)) if K is None else K
        self.params = np.zeros((H, self.m, self.n))

        # Past H + look_back noises, oldest first
        self.w = np.zeros((H + look_back, self.n, 1))

        # past state and past action
        self.x, self.u = np.zeros((self.n, 1)), np.zeros((self.m, 1))

        # Action = -K x + sum_j params[j] @ w[-H + j]: contract the history axis and
        # the state axis of params (H, m, n) with w[-H:] (H, n, 1), giving shape (m, 1).
        self.determine_action = lambda params, x, w: -self.K @ x + \
                                        np.tensordot(params, w[-self.H:], axes = ([0, 2], [0, 1]))

        self.grad_policy = grad(policy_loss)

    def update_params(self, grad = None, cost_fn = None):
        """
        Description: One gradient-descent step on the perturbation parameters.
        Args:
            grad (float/numpy.ndarray): gradient w.r.t. params; computed from
                policy_loss when not provided
            cost_fn (function): cost used when computing the gradient; defaults to
                the cost_fn given at construction
        """
        # 1. Update t
        self.t = self.t + 1

        # 2. Get gradients if not provided
        if grad is None:
            cost_fn = self.cost_fn if cost_fn is None else cost_fn
            grad = self.grad_policy(self.params, self.determine_action, self.w,
                                    self.look_back, self.env, cost_fn)

        # 3. Execute parameter updates
        self.params = self.params - self.lr * grad

    def update_history(self, x = None):
        """
        Description: Record the newly observed state. The disturbance of the last
            transition is computed from the *previous* state and action, and the
            next action is precomputed.
        Args:
            x (float/numpy.ndarray): observed state
        """
        # w_t = x_{t+1} - dyn(x_t, u_t). The previous code evaluated dyn at the new state.
        noise = x - self.env.dyn(self.x, self.u)
        self.w = np.concatenate([self.w[1:], noise[None]], axis = 0)
        self.x = x
        self.u = self.determine_action(self.params, x, self.w)

    def get_action(self, x):
        """
        Description: Return the action chosen by the controller for state x. No side-effects.
        Args:
            x (float/numpy.ndarray): system state
        """

        return self.determine_action(self.params, x, self.w)

In [49]:
### TEST GPC ###
T = 50
# Fresh environment so this cell does not depend on state left by other cells.
gpc_env = type(env)()
controller = GPC_v2(gpc_env)
for _ in range(T):
    u = controller.get_action(gpc_env.x)
    controller.update_params()
    x = gpc_env.step(u)
    controller.update_history(x)
assert np.all(np.isfinite(x)), "GPC rollout produced non-finite states"

In [95]:
import tigercontrol
from tigercontrol.controllers import Controller
from tigercontrol.controllers.boosting.core import quad_loss, action_loss
from tigercontrol.controllers.boosting.core import update_noise

import jax
import jax.numpy as np
from jax import grad

class DynaBoost(Controller):
    """
    Description: Online boosting over N Gradient Perturbation Controllers (GPC_v2).
        The played action is a running convex combination of the weak learners'
        actions, with weight 2 / (i + 2) on learner i. Each learner is updated with
        the gradient of a linear surrogate: the gradient of `action_loss` at the
        partial (prefix) actions it received, dotted with the actions its own
        policy would take.
    """

    def __init__(self, env, N = 3, H = 3, cost_fn = quad_loss):
        """
        Description: Initializes the boosted controller and its weak learners.
        Args:
            env (object): environment exposing n, m and dyn(x, u)
            N (int): default 3. Number of weak learners (must be positive)
            H (int): default 3. History length of each weak learner and of the
                action window used for the boosting gradients
            cost_fn (function): default cost used by update_params
        """
        self.initialized = True

        self.n, self.m = env.n, env.m # State & Action Dimensions
        self.env = env # System
        self.cost_fn = cost_fn # previously accepted but silently ignored

        # 1. Maintain N copies of the algorithm
        assert N > 0
        self.N, self.H = N, H

        # Past state
        self.x = np.zeros((self.n, 1))
        # Past 2H noises, oldest first
        self.w = np.zeros((2 * H, self.n, 1))

        # 2. Initialize the N weak learners. Their history length must equal H so
        #    that their policies read the last H entries of self.w.
        self.weak_controller = GPC_v2(env, H = H)
        self.controllers = [GPC_v2(env, H = H) for _ in range(N)]

        # Last H partial actions for every prefix of learners, oldest time first.
        # Prefix 0 is the empty prefix (zeros); prefix N is the played action.
        self.past_partial_actions = np.zeros((N + 1, H, self.m, 1))

        def get_partial_actions(x):
            """Partial boosted actions at state x for every prefix of learners, stacked as (N+1, m, 1)."""
            partial_u = np.zeros((self.m, 1))
            partials = [partial_u]
            for i, controller_i in enumerate(self.controllers):
                eta_i = 2 / (i + 2)
                # Use this controller's noise history. The weak learners' own `w`
                # is never updated, so `controller_i.get_action` would always see
                # zero disturbances and the learned params would have no effect.
                pred_u = controller_i.determine_action(controller_i.params, x, self.w)
                partial_u = (1 - eta_i) * partial_u + eta_i * pred_u
                partials.append(partial_u)
            return np.stack(partials)

        self.get_partial_actions = get_partial_actions

        self.grad_action = grad(action_loss)

        def get_grads(partial_actions, w, cost_fn = quad_loss):
            """Gradient of action_loss at the input prefix of each learner (prefixes 0..N-1)."""
            return [self.grad_action(partial_actions[i], w, self.H, self.env, cost_fn)
                    for i in range(self.N)]

        self.get_grads = get_grads

        def linear_loss(controller_i_params, grad_i, w):
            """Linear surrogate: sum of <grad_i[h], v_h> over the actions v_h of a weak policy."""
            linear_loss_i = 0.

            # Was `np.zeros((n, 1))` with an undefined global `n` (NameError).
            y = np.zeros((self.n, 1))

            for h in range(self.H):
                v = self.weak_controller.determine_action(controller_i_params, y, w[:h + self.H])
                # Inner product; np.dot of two (m, 1) columns fails for m > 1.
                linear_loss_i += np.sum(grad_i[h] * v)
                y = self.env.dyn(y, v) + w[h + self.H]

            # NOTE(review): this extra action reuses grad_i[h] with h == H - 1 from
            # the loop. The last gradient entry is therefore counted twice, across
            # H + 1 actions. Confirm whether the rollout should take exactly H actions.
            v = self.weak_controller.determine_action(controller_i_params, y, w[:h + self.H])
            linear_loss_i += np.sum(grad_i[h] * v)

            return np.sum(linear_loss_i)

        self.grad_linear = grad(linear_loss)

    def update_params(self, cost_fn = None, cost_val = None):
        """
        Description: One gradient step for every weak learner on its linear surrogate loss.
        Args:
            cost_fn (function): cost used for the boosting gradients; defaults to
                the cost_fn given at construction
            cost_val: unused; kept for interface compatibility
        """
        cost_fn = self.cost_fn if cost_fn is None else cost_fn
        grads = self.get_grads(self.past_partial_actions, self.w, cost_fn)
        for controller_i, grad_i in zip(self.controllers, grads):
            controller_i.update_params(grad = self.grad_linear(controller_i.params, grad_i, self.w))

    def update_history(self, x = None):
        """
        Description: Record the newly observed state x. Update the noise history
            and append the partial actions at x, dropping the oldest time step.
        """
        # Disturbance of the last transition, taken from the previous state and the
        # action actually played (full prefix, latest time). Previously dyn was
        # evaluated at the new state.
        played_u = self.past_partial_actions[-1][-1]
        noise = x - self.env.dyn(self.x, played_u)
        self.w = np.concatenate([self.w[1:], noise[None]], axis = 0)
        self.x = x
        new_partials = self.get_partial_actions(x)
        self.past_partial_actions = np.concatenate(
            [self.past_partial_actions[:, 1:], new_partials[:, None]], axis = 1)

    def get_action(self, x):
        """
        Description: Return the boosted action (full prefix of learners) for state x.
        """
        return self.get_partial_actions(x)[-1]

In [96]:
### TEST DynaBoost ###
T = 50
# Fresh environment so this cell does not depend on state left by other cells.
boost_env = type(env)()
controller = DynaBoost(boost_env)
for _ in range(T):
    u = controller.get_action(boost_env.x)
    controller.update_params()
    x = boost_env.step(u)
    controller.update_history(x)
assert np.all(np.isfinite(x)), "DynaBoost rollout produced non-finite states"