In [75]:
import jax
import jax.numpy as np
import jax.random as random
import scipy
import scipy.linalg
from jax import grad, jit

import tigercontrol
from tigercontrol.methods import Method
from tigercontrol.utils import generate_key

class GPC(Method):
    """
    Description: Gradient Perturbation Controller (GPC) for a discrete-time
    linear system x_{t+1} = A x_t + B u_t + w_t.

    The action played is a fixed linear feedback on the state plus a learned
    linear function of the H most recent perturbations:

        u_t = -K x_t + sum_{i=0}^{H-1} M[i] @ w_past[i]

    where w_past[0] is the most recent perturbation. After every action the
    matrices M are updated by one gradient step on a counterfactual loss
    that unrolls the closed-loop dynamics over HH steps.
    """

    compatibles = set()

    def __init__(self):
        self.initialized = False

    def initialize(self, A, B, x, n, m, H, HH, K = None):
        """
        Description: Initialize the dynamics of the model
        Args:
            A,B (float/numpy.ndarray): system dynamics
            x (float/numpy.ndarray): current state. Accepted for interface
                compatibility but not used: the internal state starts at
                zeros(n), so the first state passed to plan() is recorded
                as the first perturbation.
            n (positive int): dimension of the state
            m (positive int): dimension of the controls
            H (positive int): history of the controller, i.e. how many past
                perturbations (and matrices M[i]) enter the action
            HH (positive int): history of the system, i.e. how many steps
                the counterfactual loss unrolls
            K (float/numpy.ndarray): optional (m, n) state-feedback gain.
                When None, the discrete-time LQR gain for identity state
                and control costs is computed from A and B.
        """

        def _update_past(self_past, x):
            # Push x in as row 0 and drop the oldest row. Written with
            # broadcast/concatenate instead of jax.ops.index_update, which
            # no longer exists in current JAX releases.
            newest = np.broadcast_to(x, (1,) + self_past.shape[1:])
            newest = newest.astype(self_past.dtype)
            return np.concatenate((newest, self_past[:-1]), axis=0)
        self._update_past = jit(_update_past)

        if K is None:
            # The gain formula below is the discrete-time one, so X has to
            # solve the *discrete* algebraic Riccati equation (identity
            # state cost Q and control cost R).
            X = scipy.linalg.solve_discrete_are(A, B, np.identity(n), np.identity(m))
            # LQR gain: K = (B^T X B + R)^{-1} B^T X A
            self.K = np.linalg.inv(B.T @ X @ B + np.identity(m)) @ (B.T @ X @ A)
        else:
            self.K = K

        self.x = np.zeros(n)
        self.u = np.zeros(m)

        self.n = n   ## dimension of the state x
        self.m = m   ## dimension of the control u
        self.A = A
        self.B = B
        self.H = H   ## how many control matrices
        self.HH = HH ## how many times to unfold the recursion

        ## internal parameters of the class
        self.T = 1 ## keep track of iterations, for the learning rate
        self.learning_rate = 1
        self.M = np.zeros((H, m, n))
        #self.M = random.normal(generate_key(), shape=(H, m, n)) / np.sqrt(0.5*(n+m)) # Glorot CANNOT BE SET TO ZERO

        # S[i] = (A - B K)^i B: effect on the state of a control played i steps ago.
        closed_loop = A - B @ self.K
        powers = [np.reshape(B, (n, m))]
        for _ in range(1, HH):
            powers.append(closed_loop @ powers[-1])
        self.S = np.stack(powers) if HH > 0 else np.zeros((0, n, m))

        ## previous perturbations, most recent at [0], oldest at [HH + H - 1]
        self.w_past = np.zeros((HH + H, n))

        self.is_online = True

        def the_complicated_loss_function(M, w_past, S):
            """
            Counterfactual loss: squared norm of sum_i S[i] @ (sum_h M[h] @ w_past[i+h]).
            It is differentiated automatically by JAX rather than by hand.
            """
            final = np.zeros(self.n)
            for i in range(self.HH):
                # action the perturbation-feedback term would have produced i steps ago
                temp = np.tensordot(M, w_past[i:i+self.H], axes=([0,2],[0,1]))
                final = final + S[i] @ temp
            return np.sum(final ** 2)

        self.grad_fn = jit(grad(the_complicated_loss_function))  # compiled gradient w.r.t. M

        # Only flag the controller as ready once every attribute exists.
        self.initialized = True

    def plan(self, x_new):
        """
        Description: Updates internal parameters and then returns the estimated optimal action (only one)
        Args:
            x_new (float/numpy.ndarray): newly observed state
        Returns:
            Estimated optimal action
        """

        self.T += 1
        self.learning_rate = 1 / np.sqrt(self.T + 1)

        # Perturbation that explains the transition from the previous state
        # and action to the newly observed state.
        w_new = x_new - np.dot(self.A, self.x) - np.dot(self.B, self.u)
        if self.T > 2:
            # Diagnostic output; the first call is skipped because its
            # "perturbation" is just the initial state.
            print("OUR w:", w_new)
        self.w_past = self._update_past(self.w_past, w_new)
        self.x = x_new

        # Linear state feedback plus learned feedback on the last H perturbations.
        self.u = - self.K @ x_new + np.tensordot(self.M, self.w_past[:self.H], axes=([0,2],[0,1]))

        # One gradient step on the counterfactual loss.
        self.M = self.M - self.learning_rate * self.grad_fn(self.M, self.w_past, self.S)

        return self.u

In [76]:
from tigercontrol.utils.random import set_key
# Fix the seed before anything random is created.
# NOTE(review): presumably this sets tigercontrol's global random key so the
# environment noise below is reproducible — confirm against tigercontrol.utils.random.
set_key(0)

In [77]:
# Create the "LDS-v0" environment with a one-dimensional state and control,
# driven by normally distributed noise of magnitude 0.3.
# `x` holds whatever initialize() returns; it is used below as the first state.
environment = tigercontrol.environment("LDS-v0")
x = environment.initialize(
    n=1,
    m=1,
    noise_magnitude=0.3,
    noise_distribution='normal',
)



In [78]:
# Build the controller on the environment's own dynamics matrices.
# K is all zeros, so the state-feedback term -K x vanishes and the action
# comes only from the learned perturbation-feedback matrices M.
method = GPC()
method.initialize(
    A=environment.A,
    B=environment.B,
    x=x,
    n=1,
    m=1,
    H=3,    # number of past perturbations fed back into the action
    HH=30,  # number of steps unrolled in the controller's loss
    K=np.zeros((1, 1)),
)

In [79]:
#x_new = environment.step(np.zeros(1))
#u = np.zeros(1)

# Number of closed-loop steps to simulate. Defined here so the cell runs on a
# fresh kernel; 10 matches the ten iterations in the recorded output.
T = 10

# First action from the initial state, then advance the environment once.
u = method.plan(x)
x_new = environment.step(u)
for _ in range(T):
    # Perturbation actually realised by the environment on the last step,
    # printed next to the controller's own estimate ("OUR w") for comparison.
    print("True w:", x_new - np.dot(environment.A , x)  - np.dot(environment.B , u))
    u = method.plan(x_new)
    x = x_new
    x_new = environment.step(u)
    # Three most recent perturbations stored by the controller, newest first.
    print("Past 3 w values:", method.w_past[0:3].flatten())

True w: [0.06514227]
OUR w: [0.06514227]
Past 3 w values: [0.06514227 0.48648298 0.        ]
True w: [-0.19459005]
OUR w: [-0.19459005]
Past 3 w values: [-0.19459005  0.06514227  0.48648298]
True w: [0.40191883]
OUR w: [0.40191883]
Past 3 w values: [ 0.40191883 -0.19459005  0.06514227]
True w: [0.3115316]
OUR w: [0.3115316]
Past 3 w values: [ 0.3115316   0.40191883 -0.19459005]
True w: [-0.22628738]
OUR w: [-0.22628738]
Past 3 w values: [-0.22628738  0.3115316   0.40191883]
True w: [-0.18143398]
OUR w: [-0.18143398]
Past 3 w values: [-0.18143398 -0.22628738  0.3115316 ]
True w: [-0.70501757]
OUR w: [-0.70501757]
Past 3 w values: [-0.70501757 -0.18143398 -0.22628738]
True w: [-0.62629807]
OUR w: [-0.62629807]
Past 3 w values: [-0.62629807 -0.70501757 -0.18143398]
True w: [0.48998746]
OUR w: [0.48998746]
Past 3 w values: [ 0.48998746 -0.62629807 -0.70501757]
True w: [0.61472225]
OUR w: [0.61472225]
Past 3 w values: [ 0.61472225  0.48998746 -0.62629807]
