# Exercise 1.1
## This *.ipynb file contains:

- `LQRsolver`: a class required to be written, which can be __initialised__ with default time horizon T = 1 (which can also take value from users' input) and with the matrices specifying the LQR problem which are:
    - `H` (`torch.Size([n,n]) torch.tensor`): the linear relations of dynamics between the total `n` state processes.
    - `M` (`torch.Size([n,m]) torch.tensor`): the influences from `m` control variables to `n` state processes. 
    - `sigma`(`torch.Size([n,d]) torch.tensor`): the diffusion matrix from `d` Wiener processes to `n` state processes.
    - `C` (`torch.Size([n,n]) torch.tensor`): the contribution matrix from state processes to runnning reward.
    - `D` (`torch.Size([m,m]) torch.tensor`): the contribution matrix from the final value of state processes to runnning reward.
    - `R` (`torch.Size([n,n]) torch.tensor`): the contribution matrix from the final values of state processes to terminal reward.
    
    __Declaration for dimensions of matrices__:
    Here `n` should be compatible with the dimension of state variable space and m is the dimension of control variable space.
    Note that `n = m = 2` and `d = 1` in this exercise but can be extended to a higher dimension.

    There are __3 main methods__ built in `LQRsolver` including

    - `solve_riccati_ode`: __a numerical solver for Ricatti ODE__ which requires
    
        __input__ of
    
        - `time_grids`  `torch.Size([batch_size,l]) torch.Tensor`
        
        and __returns__
        
        - the values of solution function $S(t)$  `torch.Size([batch_size,l,n,n]) torch.tensor`
        
        (Two numerical methods are tried to be provided as options: Euler scheme and 4th order Runge-Kutta scheme)
    
    - `value_function`: __a computation of value function__ which requires

        __inputs__ (in sequence) of 
        
        - `t_batch`  `torch.Size([batch_size]) torch.tensor` whose entries took value initially from [0,1] but would be scaled by the given T in the further calculation
        - `x_batch`  `torch.Size([batch_size,1,n]) torch.tensor`
        
        and __returns__
        
        - values of value_function `torch.Size([batch_size]) torch.tensor`
        
        
    - `markov_control`: __a computation of Markov control function__ which requires

        __inputs__ (in sequence) of 
        
        - `t_batch`  `torch.Size([batch_size]) torch.tensor` whose entries took value initially from [0,1] but would be scaled by the given T in the further calculation
        - `x_batch`  `torch.Size([batch_size,1,n]) torch.tensor`
        
        and __returns__
        
        - values of Markov_control_function  `torch.Size([batch_size,n]) torch.tensor`
        
        
        
- __A runnable sample__ including: 

    - a whole set of matrices for __initialisation__
    - an example of __calculation of value function__ with given t_batch and x_batch
    - an example of __calculation of Markov control function__ with given t_batch and x_batch


In [76]:
import torch
import numpy as np
from scipy.interpolate import CubicSpline
import warnings #for alarming some inappropriate inputs

class LQRSolver:
    #default numerical method is Euler.
    #Euler and Runge-Kutta are supported in this code.
    def __init__(self, H, M, sigma, C, D, R, T = 1, method="euler"):
        
        if not self.is_semi_positive_definite(C):
            raise ValueError("Matrix C must be semi-positive definite.")
        if not self.is_semi_positive_definite(R):
            raise ValueError("Matrix R must be semi-positive definite.")
        if not self.is_positive_definite(D):
            raise ValueError("Matrix D must be positive definite.")
 
        self.H = H
        self.M = M
        self.sigma = sigma
        self.C = C
        self.D = D
        self.R = R
        self.T = T
        self.method = method 
    
    def is_positive_definite(self, matrix):
        
        eigvals, _ = torch.linalg.eig(matrix)
        real_parts = eigvals.real
        return torch.all(real_parts > 0) 

    def is_semi_positive_definite(self, matrix):

        eigvals, _ = torch.linalg.eig(matrix)
        real_parts = eigvals.real
        return torch.all(real_parts >= 0) 

    def solve_riccati_ode(self, time_grids):
        
        if not isinstance(time_grids, torch.Tensor):
            raise TypeError("time_grid should be an batch_size*1-D torch.Tensor")
        else:
            if not (torch.all(np.abs(time_grids[:,-1] - self.T) <= 1e-12) or torch.all(time_grids[:,0]>=0)):
                print()
                raise Exception("Please ensure that the first entry of time_grid >= 0 and the last entry is equal to T.")
            else:
                time_grids_in = time_grids

        time_grids_in = torch.flip(time_grids, dims=[1])
        S = self.R.clone()
        repl = torch.ones(time_grids_in.shape)
        rep_exd = repl.unsqueeze(-1).unsqueeze(-1)
        S_exd = S.unsqueeze(0).unsqueeze(0)

        S_repl = rep_exd*S_exd
        dt = time_grids_in[:,1:]-time_grids_in[:,:-1]

        for i in range(dt.shape[1]):
            

            if self.method == "euler":
                S_repl[:,i+1] = self.euler_step(S_repl[:,i], dt[:,i])
            elif self.method == "rk4":
                S_repl[:,i+1] = self.rk4_step(S_repl[:,i], dt[:,i])
            else:
                raise ValueError("Unsupported method")

        return torch.flip(S_repl, dims=[1])

    def euler_step(self, S_in, dt_in):
        
        dS_in = -2 * self.H.T @ S_in + S_in @ self.M @ torch.inverse(self.D) @ self.M.T @ S_in - self.C

        dt_resized = dt_in[:, None, None]

        return S_in + dS_in * dt_resized
    
    def rk4_step(S_in, dt_in):

        def riccati_derivative(S):
            return -2 * self.H.T @ S_in + S_in @ self.M @ torch.inverse(self.D) @ self.M.T @ S_in - self.C
        
        dt_resized = dt_in[:, None, None]

        k1 = riccati_derivative(S_in)
        k2 = riccati_derivative(S_in + 0.5 *  k1 * dt_resized)
        k3 = riccati_derivative(S_in + 0.5 *  k2 * dt_resized)
        k4 = riccati_derivative(S_in + k3 * dt_resized)

        return S_in + (k1 + 2*k2 + 2*k3 + k4)*dt_resized/6

    def value_function(self, t_batch, x_batch):

        # Verify the shapes of the inputs.

        if not (t_batch.dim() == 1 and torch.all((t_batch >= 0) & (t_batch <= 1))):
            raise TypeError("t_batch should be a 1D tensor in which every entry is in [0,1].")
        else:
            if not (x_batch.dim() == 3 and x_batch.size()[0] == len(t_batch) and x_batch.size()[1] == 1 and x_batch.size()[2] == self.H.size()[0]):
                raise TypeError("x_batch should have shape (%d, 1, %d)."%(len(t_batch),self.H.size(2)))
        
        time_grids = torch.stack([torch.linspace(float(t), self.T, 5000, dtype=torch.float32) for t in t_batch])
        
        S_tensor_tensor = self.solve_riccati_ode(time_grids)
        #print(S_tensor_tensor.shape)
        S_t_s = S_tensor_tensor[:, 0, :, :]
        S_t_T = S_tensor_tensor[:, 1:, :, :]
        x_batch_T = x_batch.transpose(1, 2) 
        
        xTSx = x_batch @ S_t_s @ x_batch_T

        dts = time_grids[:, 1:] - time_grids[:, :-1]

        sigma_T = self.sigma.transpose(1,2)
        print(self.sigma.shape)
        print(sigma_T.shape)
        print(S_t_T.transpose(0,1).shape)
        trace_for_int = torch.einsum('bcii->bc', self.sigma @ self.sigma.transpose(1, 2) @ S_t_T.transpose(0,1)).unsqueeze(1).unsqueeze(2)
        trace_for_int = trace_for_int.squeeze() 
        
        integral_part = dts @ trace_for_int

        v_tx = xTSx.squeeze() + torch.diag(integral_part).squeeze()
        
        return v_tx

    def markov_control(self, t_batch, x_batch):
        
        # Verify the shapes of the inputs.

        if not (t_batch.dim() == 1 and torch.all((t_batch >= 0) & (t_batch <= 1))):
            raise TypeError("t_batch should be a 1D tensor in which every entry is in [0,1].")
        else:
            if not (x_batch.dim() == 3 and x_batch.size()[0] == len(t_batch) and x_batch.size()[1] == 1 and x_batch.size()[2] == self.H.size()[0]):
                raise TypeError("x_batch should have shape (%d, 1, %d)."%(len(t_batch),self.H.size(2)))
        
        time_grids = torch.stack([torch.linspace(float(t), self.T, 5000, dtype=torch.float32) for t in t_batch])
        
        S_tensor_tensor = self.solve_riccati_ode(time_grids)

        #print(S_tensor_tensor.shape)
        
        S_t_s = S_tensor_tensor[:, 0, :, :]
        x_batch_T = x_batch.transpose(1, 2) 

        MT = self.M.T
        D_inv = torch.inverse(self.D)
        x = torch.transpose(x_batch,dim0 = 2,dim1 = 1)
        a_tx = - D_inv @ MT @ S_t_s @ x_batch_T
        a_tx = torch.transpose(a_tx,dim0 = 1,dim1 = 2).squeeze()

        return a_tx.unsqueeze(1)

In [71]:
#Initialization
H = torch.tensor([[1.0, 2.0], [-2.0, -3.0]], dtype=torch.float32)
M = torch.tensor([[1.0,0.0], [0.0,1.0]], dtype=torch.float32)
sigma = torch.tensor([[[0.5249],[0.4072]]], dtype=torch.float32) 
C = torch.tensor([[2.0, 0.0], [0.0, 1.0]], dtype=torch.float32)  # Positive semi-definite
D = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)  # Positive definite
R = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)  # Positive semi-definite
T = 5.0

In [72]:
solver = LQRSolver(H, M, sigma, C, D, R, T)

In [73]:
t_batch = torch.linspace(0.1,0.2,20)
x_batch = torch.rand([t_batch.shape[0],1,2])

In [74]:
solver.value_function(t_batch, x_batch)

torch.Size([1, 2, 1])
torch.Size([1, 1, 2])
torch.Size([4999, 20, 2, 2])


tensor([2.2372, 2.2663, 3.1043, 2.4302, 2.5439, 2.5151, 2.2395, 2.2572, 2.2584,
        2.8201, 2.3908, 2.2295, 2.2535, 2.1996, 2.4656, 3.3392, 2.8854, 3.8431,
        2.1922, 2.2647])

In [75]:
solver.markov_control(t_batch, x_batch)

tensor([[[ 0.1613, -0.0059]],

        [[ 0.0618, -0.0605]],

        [[-0.8084, -0.5288]],

        [[-0.4662, -0.2710]],

        [[-0.4357, -0.3056]],

        [[-0.4087, -0.2905]],

        [[-0.0757, -0.0568]],

        [[-0.1009, -0.0892]],

        [[ 0.0137, -0.0754]],

        [[-0.8757, -0.4902]],

        [[-0.1292, -0.1891]],

        [[ 0.0726, -0.0423]],

        [[ 0.0415, -0.0778]],

        [[ 0.2530, -0.0138]],

        [[-0.3342, -0.2642]],

        [[-0.8750, -0.5932]],

        [[-0.7343, -0.4731]],

        [[-1.0047, -0.7030]],

        [[ 0.1363,  0.0199]],

        [[-0.2487, -0.1536]]])