# SCDAA Coursework

In [1]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
import torch

In [None]:
import torch
import numpy as np
from scipy.integrate import solve_ivp

class LQR:
    """Finite-horizon linear-quadratic regulator driven by a Riccati ODE.

    The class integrates, backwards from the terminal time ``T``,

        S'(t) = -2 H^T S(t) + S(t) M D^{-1} M^T S(t) - C,    S(T) = R,

    and derives from ``S``:

    * the value ``v(t, x) = x^T S(t) x``, plus
      ``int_t^T tr(sigma sigma^T S(r)) dr`` when ``sigma`` is set;
    * the feedback control ``a(t, x) = -D^{-1} M^T S(t) x``.

    The state dimension is read from ``R``, so the class is not limited to
    2x2 problems.

    NOTE(review): the linear term is kept exactly as ``-2 H^T S`` rather than
    the symmetric ``-(H^T S + S H)``; confirm this matches the problem
    statement being implemented.
    """

    # Tolerances handed to ``solve_ivp``. The solver defaults (1e-3 / 1e-6)
    # are too loose for a value function that is later compared to samples.
    _RTOL = 1e-8
    _ATOL = 1e-10

    def __init__(self, H, M, C, D, R, T):
        """Store the problem data.

        Args:
            H, M, C, D, R: Coefficient matrices (nested lists, NumPy arrays
                or tensors). Each one is stored as a float32 tensor.
            T: Terminal time at which ``S(T) = R`` is imposed.
        """
        self.H = self._as_float_tensor(H)
        self.M = self._as_float_tensor(M)
        self.C = self._as_float_tensor(C)
        self.D = self._as_float_tensor(D)
        self.R = self._as_float_tensor(R)
        self.T = T
        # Optional diffusion matrix (array-like or tensor with as many rows
        # as the state has components). While it is None the trace term of
        # the value function is skipped.
        self.sigma = None

    @staticmethod
    def _as_float_tensor(value):
        """Return an independent float32 tensor copy of ``value``."""
        return torch.as_tensor(value, dtype=torch.float32).detach().clone()

    def _diffusion_matrix(self):
        """Return ``self.sigma`` as a float64 array, or None when unset."""
        if self.sigma is None:
            return None
        sigma = torch.as_tensor(self.sigma, dtype=torch.float64)
        sigma = sigma.detach().cpu().numpy()
        if sigma.ndim != 2 or sigma.shape[0] != self.R.shape[0]:
            raise ValueError(
                "sigma must be a 2-D matrix with one row per state component"
            )
        return sigma

    def _flatten_states(self, x):
        """Reshape ``x`` of shape (batch, d) or (batch, 1, d) to (batch, d)."""
        dim = self.R.shape[0]
        if x.dim() < 2 or x.numel() != x.shape[0] * dim:
            raise ValueError(
                f"x must have shape (batch, {dim}) or (batch, 1, {dim}); "
                f"got {tuple(x.shape)}"
            )
        return x.reshape(x.shape[0], dim).to(torch.float64)

    def _solve_backward(self, time_grid, sigma=None):
        """Integrate the Riccati ODE and return it at every requested time.

        Args:
            time_grid: Times (tensor, array or sequence) in any order and
                possibly repeated; none may exceed ``T``.
            sigma: Optional float64 diffusion matrix. When given, the
                integral ``int_t^T tr(sigma sigma^T S(r)) dr`` is carried
                along as one extra ODE component.

        Returns:
            ``(S, integral)`` where ``S`` is a float64 array of shape
            (n, d, d) ordered like ``time_grid`` and ``integral`` is a
            float64 array of shape (n,), or None when ``sigma`` is None.

        Raises:
            ValueError: If a requested time lies beyond ``T``.
            RuntimeError: If the ODE solver reports a failure.
        """
        H, M, C, D, R = (
            mat.detach().cpu().numpy().astype(np.float64)
            for mat in (self.H, self.M, self.C, self.D, self.R)
        )
        dim = R.shape[0]
        n_entries = dim * dim
        horizon = float(self.T)

        times = torch.as_tensor(time_grid, dtype=torch.float64)
        times = times.detach().cpu().numpy().reshape(-1)
        # float32 grids can overshoot T by a rounding error (float32(0.1) is
        # larger than the double 0.1); snap those to T instead of failing.
        slack = 1e-6 * max(1.0, abs(horizon))
        if times.size and times.max() > horizon + slack:
            raise ValueError("time_grid contains times beyond the horizon T")
        times = np.minimum(times, horizon)

        terminal = R.reshape(-1)
        diffusion = None
        if sigma is not None:
            diffusion = sigma @ sigma.T
            # Extra component I(t) with I(T) = 0 and I' = -tr(sigma sigma^T S),
            # so that I(t) = int_t^T tr(sigma sigma^T S(r)) dr.
            terminal = np.append(terminal, 0.0)

        # Loop-invariant pieces of the right-hand side.
        feedback = M @ np.linalg.inv(D) @ M.T
        drift = -2.0 * H.T

        def rhs(_t, y):
            S = y[:n_entries].reshape(dim, dim)
            dS = (drift @ S + S @ feedback @ S - C).reshape(-1)
            if diffusion is None:
                return dS
            return np.append(dS, -np.trace(diffusion @ S))

        if times.size == 0:
            states = np.empty((0, terminal.size))
        else:
            # solve_ivp needs strictly monotone evaluation times, so solve on
            # the sorted unique times and scatter back to the caller's order.
            unique_times, inverse = np.unique(times, return_inverse=True)
            if unique_times[0] >= horizon:
                # Every requested time equals T: nothing to integrate.
                ascending = terminal[np.newaxis, :]
            else:
                sol = solve_ivp(
                    rhs,
                    (horizon, unique_times[0]),
                    terminal,
                    t_eval=unique_times[::-1],
                    rtol=self._RTOL,
                    atol=self._ATOL,
                )
                if not sol.success:
                    raise RuntimeError(
                        f"Riccati ODE integration failed: {sol.message}"
                    )
                # Columns of sol.y follow the descending evaluation times.
                ascending = sol.y.T[::-1]
            states = ascending[inverse.reshape(-1)]

        S_values = states[:, :n_entries].reshape(-1, dim, dim)
        integral = states[:, n_entries] if diffusion is not None else None
        return S_values, integral

    def _riccati_for_batch(self, t, flat_x, sigma=None):
        """Return ``S(t[i])`` (and the trace integral) paired with each state.

        State ``i`` is paired with ``t[i]``; surplus times are ignored.

        Raises:
            IndexError: If ``t`` has fewer entries than there are states.
        """
        S_values, integral = self._solve_backward(t, sigma)
        batch = flat_x.shape[0]
        if S_values.shape[0] < batch:
            raise IndexError(
                f"t has {S_values.shape[0]} entries but x has {batch} states"
            )
        S_batch = torch.from_numpy(np.ascontiguousarray(S_values[:batch]))
        S_batch = S_batch.to(flat_x.device)
        if integral is None:
            return S_batch, None
        tail = torch.from_numpy(np.ascontiguousarray(integral[:batch]))
        return S_batch, tail.to(flat_x.device)

    def solve_ricatti_ode(self, time_grid):
        """Return ``S(t)`` for every ``t`` in ``time_grid``.

        Args:
            time_grid: Times at which to evaluate the solution. They may be
                unsorted or repeated but must not exceed ``T``.

        Returns:
            float32 tensor of shape (len(time_grid), d, d) whose i-th slice
            is ``S(time_grid[i])``.
        """
        S_values, _ = self._solve_backward(time_grid)
        return torch.tensor(S_values, dtype=torch.float32)

    def control_problem_value(self, t, x):
        """Evaluate ``v(t_i, x_i)`` for each pair in the batch.

        Args:
            t: Batch of times, one per state.
            x: States of shape (batch, d) or (batch, 1, d).

        Returns:
            float32 tensor of shape (batch,) for 2-D ``x`` and (batch, 1)
            for 3-D ``x``, holding ``x^T S(t) x`` plus, when ``sigma`` is
            set, ``int_t^T tr(sigma sigma^T S(r)) dr``.
        """
        x = torch.as_tensor(x)
        flat_x = self._flatten_states(x)
        S_batch, tail = self._riccati_for_batch(
            t, flat_x, self._diffusion_matrix()
        )
        values = torch.einsum("bi,bij,bj->b", flat_x, S_batch, flat_x)
        if tail is not None:
            values = values + tail
        values = values.to(torch.float32)
        return values if x.dim() == 2 else values.unsqueeze(1)

    def markov_control_function(self, t, x):
        """Evaluate the feedback control ``-D^{-1} M^T S(t_i) x_i``.

        Args:
            t: Batch of times, one per state.
            x: States of shape (batch, d) or (batch, 1, d).

        Returns:
            float32 tensor of shape (batch, k), where ``k`` is the number of
            columns of ``M``.
        """
        x = torch.as_tensor(x)
        flat_x = self._flatten_states(x)
        S_batch, _ = self._riccati_for_batch(t, flat_x)
        gain = -(torch.inverse(self.D) @ self.M.T).to(torch.float64)
        gain = gain.to(flat_x.device)
        controls = torch.einsum("kd,bde,be->bk", gain, S_batch, flat_x)
        return controls.to(torch.float32)

# Example of how to initialize and use the class.
# H, M, C, D, R are matrices (2x2 nested lists below) and T is the scalar terminal time.
# lqr = LQR(H=[[0, 1], [-1, 0]], M=[[1, 0], [0, 1]], C=[[1, 0], [0, 1]], D=[[1, 0], [0, 1]], R=[[1, 0], [0, 1]], T=1)
# time_grid = torch.linspace(0, 1, steps=100)  # Example time grid
# x = torch.tensor([[[1.0, 2.0]]])  # Example state tensor
# print(lqr.markov_control_function(time_grid, x))
