In [113]:
import jax.numpy as np
import tigercontrol
from tigercontrol.methods.control import ControlMethod
from jax import grad,jit
import jax.random as random
from tigercontrol.utils import generate_key
import jax
import scipy

class GPC(ControlMethod):
    """
    Description: GPC controller (presumably "Gradient Perturbation Controller").
    Plays a fixed linear state feedback -K x plus a learned linear function of
    the H most recent perturbations, and updates the perturbation weights M by
    gradient descent on a counterfactual loss differentiated with JAX.
    """
    
    compatibles = set([])

    def __init__(self):
        self.initialized = False

    def initialize(self, A, B, x, n, m, H, HH, K = None):
        """
        Description: Initialize the dynamics of the model
        Args:
            A,B (float/numpy.ndarray): system dynamics, x_{t+1} = A x_t + B u_t + w_t
            x (float/numpy.ndarray): current state. Accepted for interface
                compatibility only; the stored state starts at zeros(n).
            n (positive int): dimension of the state
            m (positive int): dimension of the controls
            H (positive int): history of the controller (number of matrices in M)
            HH (positive int): history of the system (how many steps the
                closed-loop recursion is unrolled in the loss)
            K (numpy.ndarray, optional): (m, n) feedback gain, applied as u = -K x.
                When None, the discrete-time LQR gain with identity state and
                control costs is computed.
        """
        self.initialized = True

        # Store dimensions and dynamics first so everything below can rely on them.
        self.n = n   ## dimension of  the state x 
        self.m = m   ## dimension of the control u
        self.A = A
        self.B = B
        self.H = H   ## how many control matrices
        self.HH = HH ## how many times to unfold the recursion

        def _update_past(self_past, x):
            # Drop the oldest row and put the newest perturbation in row 0.
            # Explicit row-wise concatenation avoids jax.ops.index_update (removed
            # from JAX) and a flattened roll that depended on self.n.
            newest = np.reshape(x, (1, -1))
            return np.concatenate((newest, self_past[:-1]), axis=0)
        self._update_past = jit(_update_past)

        if K is None:
            # Local import: a bare `import scipy` does not guarantee scipy.linalg is loaded.
            from scipy.linalg import solve_discrete_are
            # The dynamics are discrete-time and the gain formula below is the
            # discrete-time LQR gain, so the matching discrete Riccati equation is solved.
            X = solve_discrete_are(A, B, np.identity(n), np.identity(m))
            # compute LQR gain
            self.K = np.linalg.inv(B.T @ X @ B + np.identity(m)) @ (B.T @ X @ A)
        else:
            self.K = K

        self.x = np.zeros(n)        
        self.u = np.zeros(m)

        ## internal parameters to the class 
        self.T = 1 ## keep track of iterations, for the learning rate
        self.learning_rate = 1
        self.M = random.normal(generate_key(), shape=(H, m, n)) ## CANNOT BE SET TO ZERO

        # S[i] = (A - B K)^i B. The sign matches the control law u = -K x + ... used
        # in plan(), under which the closed-loop transition matrix is A - B K.
        closed_loop = A - B @ self.K
        self.S = []
        term = B
        for _ in range(HH):
            self.S.append(term)
            term = closed_loop @ term

        ## previous perturbations, row 0 is the most recent, the last row the oldest
        self.w_past = np.zeros((HH + H,n))

        # Build the compiled gradient once; re-creating jit(grad(...)) on every
        # plan() call makes a fresh function object each step and defeats jit's cache.
        self._grad_fn = jit(grad(self.the_complicated_loss_function))

        self.is_online = True

    def the_complicated_loss_function(self, M, w_past, S):
        """
        Description: Counterfactual loss; it is differentiated with JAX rather than by hand.
        Args:
            M (numpy.ndarray): (H, m, n) perturbation weights
            w_past (numpy.ndarray): (HH + H, n) perturbation history, most recent first
            S (list of numpy.ndarray): HH matrices of shape (n, m) propagating a
                control i steps forward
        Returns:
            Squared norm of the accumulated propagated controls (scalar)
        """
        final = np.zeros(self.n)
        for i in range(self.HH):
            # Control that M would have produced i steps ago: sum_j M[j] @ w_past[i + j]
            temp = np.tensordot(M, w_past[i:i+self.H], axes=([0,2],[0,1]))
            final = final + S[i] @ temp
        return np.sum(final ** 2)

    def plan(self,x_new):
        """
        Description: Updates internal parameters and then returns the estimated optimal action (only one)
        Args:
            x_new (numpy.ndarray): newly observed state
        Returns:
            Estimated optimal action
        """

        self.T +=1
        self.learning_rate = 1 / np.sqrt(self.T + 1)

        # Recover the perturbation that explains the transition from (x, u) to x_new.
        w_new = x_new - np.dot(self.A , self.x)  - np.dot(self.B , self.u)
        self.w_past = self._update_past(self.w_past, w_new)
        self.x = x_new

        # u = -K x + sum_i M[i] @ w_past[i]
        self.u = - self.K @ x_new + np.tensordot(self.M, self.w_past[:self.H], axes=([0,2],[0,1]))

        # Gradient step on M, taken after the action has been computed with the old M.
        self.M = self.M - self.learning_rate * self._grad_fn(self.M, self.w_past, self.S)
        
        return self.u

In [125]:
# Seed tigercontrol's random-key utility before any stochastic object is built.
# Presumably this makes later generate_key() draws repeatable — confirm set_key semantics.
from tigercontrol.utils.random import set_key
set_key(0)

In [126]:
# Create the "LDS-v0" problem and keep the value returned by initialize() as x
# (presumably the initial state; n and m presumably are the state and control
# dimensions — verify against the tigercontrol API).
problem = tigercontrol.problem("LDS-v0")
x = problem.initialize(n = 1, m = 1, noise_magnitude = 0.3, noise_distribution = 'normal')



In [128]:
# Build the controller on the problem's own A and B. Passing K = zeros((1, 1))
# skips the Riccati-based default gain, so the -K x feedback term is zero and
# only the learned perturbation term contributes to the action.
method = GPC()
method.initialize(A = problem.A, B = problem.B, x = x, n = 1, m = 1, H = 3, HH = 30, K = np.zeros((1,1)))

In [129]:
import time

In [130]:
# Print wall-clock timestamps around ten consecutive plan() calls.
# NOTE(review): x is never updated inside the loop, so every call receives the
# same state and the problem is never advanced — confirm whether the returned
# action should be fed back into the problem before the next plan() call.
print(time.ctime())
for i in range(10):
    print(method.plan(x))
print(time.ctime())

Fri Oct 18 16:38:13 2019
[0.20083119]
[0.3693685]
[0.0526429]
[-0.06397448]
[0.02098359]
[0.02807364]
[0.02046058]
[-0.00052501]
[0.00111327]
[-0.000693]
Fri Oct 18 16:38:18 2019


In [112]:
# NOTE(review): duplicate of the previous timing cell. Its recorded execution
# count (112) is lower than that of the class-definition cell (113), so the
# output shown below likely comes from an earlier run — re-run the notebook top
# to bottom or drop this cell.
print(time.ctime())
for i in range(10):
    print(method.plan(x))
print(time.ctime())

Fri Oct 18 16:34:53 2019
[0.20083119]
[0.3693685]
[0.15104842]
[-0.05962548]
[-0.00827806]
[0.00710458]
[0.03647117]
[-0.003102]
[0.007906]
[-0.02375939]
Fri Oct 18 16:34:58 2019
