Below is an example Python class that solves a finite-horizon linear–quadratic regulator (LQR) problem by integrating the Riccati ordinary differential equation (ODE).

In [1]:
import numpy as np
import torch
from scipy.integrate import solve_ivp

In [7]:
class LQR_Riccati:
    """Finite-horizon continuous-time LQR solved via the Riccati ODE.

    Minimizes  x(tf)^T Qf x(tf) + integral_{t0}^{tf} (x^T Q x + u^T R u) dt
    subject to dx/dt = A x + B u.  The optimal feedback is
    u(t) = -K(t) x(t) with K(t) = R^{-1} B^T P(t), where P solves

        dP/dt = -(A^T P + P A - P B R^{-1} B^T P + Q),   P(tf) = Qf.

    Args:
        A: (n, n) state matrix.
        B: (n, m) input matrix.
        Q: (n, n) running state cost.
        R: (m, m) input cost; must be invertible.
        T: horizon length; the default horizon is (0, T).
        Qf: (n, n) terminal state cost. Defaults to ``Q``, which is the
            boundary value this class has always used.
    """

    def __init__(self, A, B, Q, R, T, Qf=None):
        self.A = A
        self.B = B
        self.Q = Q
        self.R = R
        self.T = T
        # Terminal weight P(tf); falls back to Q for backward compatibility.
        self.Qf = Q if Qf is None else Qf

    def _solve_P(self, t_span=None, t_eval=None):
        """Integrate the Riccati ODE backward from t_span[1] to t_span[0].

        Args:
            t_span: (t0, tf) horizon; defaults to (0, T).
            t_eval: increasing times in [t0, tf]; defaults to 101 evenly
                spaced points.

        Returns:
            (t, P) where t is the increasing array of evaluation times and
            P has shape (n, n, len(t)) with P[:, :, k] = P(t[k]).

        Raises:
            RuntimeError: if the ODE integrator reports failure.
        """
        A = np.asarray(self.A, dtype=float)
        B = np.asarray(self.B, dtype=float)
        Q = np.asarray(self.Q, dtype=float)
        R = np.asarray(self.R, dtype=float)
        Qf = np.asarray(self.Qf, dtype=float)
        n = A.shape[0]
        # S = B R^{-1} B^T via a linear solve rather than an explicit inverse.
        S = B @ np.linalg.solve(R, B.T)

        def riccati_ode(t, p_vec):
            P = p_vec.reshape(n, n)
            dP_dt = -(A.T @ P + P @ A - P @ S @ P + Q)
            return dP_dt.ravel()

        if t_span is None:
            t_span = (0.0, self.T)
        t0, tf = float(t_span[0]), float(t_span[1])
        if t_eval is None:
            t_eval = np.linspace(t0, tf, num=101)
        t_eval = np.asarray(t_eval, dtype=float)

        # The boundary condition sits at the END of the horizon, so integrate
        # backward in time. Integrating forward from t0 is the unstable
        # direction and can converge to a non-stabilizing Riccati solution.
        sol = solve_ivp(
            riccati_ode,
            (tf, t0),
            Qf.ravel(),
            t_eval=t_eval[::-1],
            method='RK45',
            rtol=1e-6,
            atol=1e-8,
        )
        if not sol.success:
            raise RuntimeError(f"Riccati ODE integration failed: {sol.message}")
        # Undo the reversal so time increases along the last axis.
        P = sol.y[:, ::-1].reshape(n, n, -1)
        return t_eval, P

    def solve(self, t_span=None, t_eval=None):
        """Return the time-varying LQR gain K(t) = R^{-1} B^T P(t).

        Args:
            t_span: (t0, tf) horizon; defaults to (0, T). The terminal
                condition P(tf) = Qf is imposed at tf.
            t_eval: increasing times in [t0, tf] at which to report K;
                defaults to 101 evenly spaced points.

        Returns:
            ndarray of shape (m, n, len(t_eval)). K[:, :, k] is the gain at
            t_eval[k], so the optimal control there is u = -K[:, :, k] @ x.
        """
        _, P = self._solve_P(t_span, t_eval)
        B = np.asarray(self.B, dtype=float)
        R = np.asarray(self.R, dtype=float)
        gain = np.linalg.solve(R, B.T)  # R^{-1} B^T, shape (m, n)
        # Contract explicitly over the state index of each time slice. A plain
        # `@` would broadcast over the wrong axis of the (n, n, T) stack.
        return np.einsum('mj,jik->mik', gain, P)

    def value_function(self, x):
        """Return the optimal cost-to-go at t = 0: V(x) = x^T P(0) x.

        Args:
            x: state of shape (n,) or batch of states of shape (..., n);
                array-like or torch tensor.

        Returns:
            float32 torch tensor of shape x.shape[:-1] (0-d for one state).
        """
        t0 = 0.0
        _, P = self._solve_P(t_span=(t0, self.T), t_eval=np.array([t0]))
        P0 = torch.as_tensor(P[:, :, 0], dtype=torch.float32)
        x = torch.as_tensor(x, dtype=torch.float32)
        return torch.einsum('...i,ij,...j->...', x, P0, x)