In [2]:
import numpy as np
from scipy.integrate import solve_ivp
from scipy.integrate import quad
import torch
# from torchdiffeq import odeint
# import matplotlib.pyplot as plt

In [5]:
# torch
class LQR:
    def __init__(self, H, M, C, D, R, sigma, T, time_grid):
        """
        初始化 LQR 类

        Parameters:
        H, M, C, D, R: 线性二次调节器的矩阵
        sigma: 噪声项
        T: 终止时间
        time_grid: 时间网格 (numpy array)
        """
        self.H = H
        self.M = M
        self.C = C
        self.D = D
        self.R = R
        self.sigma = sigma
        self.T = T
        self.time_grid = time_grid
        self.S_values = self.solve_riccati_ode()

    def riccati_ode(self, t, S_flat):
        """Riccati ODE 求解函数，转换为向量形式"""
        S = torch.tensor(S_flat, dtype=torch.float32).reshape(2,2) # 2x2 矩阵
        S_dot = S @ self.M @ torch.linalg.inv(self.D) @ self.M.T @ S - self.H.T @ S - S @ self.H - self.C
        return S_dot.flatten()

    def solve_riccati_ode(self):
        """使用 solve_ivp 求解 Riccati ODE"""
        S_T = self.R.flatten()  # 终止条件 S(T) = R
        indices = torch.arange(self.time_grid.size(0) - 1, -1, -1)  # 生成倒序索引
        time_grid_re = torch.index_select(self.time_grid, 0, indices)
        sol = solve_ivp(self.riccati_ode, [self.T, 0], S_T, t_eval=time_grid_re, atol=1e-10, rtol=1e-10)  # 逆向求解
        S_matrices = sol.y.T[::-1].reshape(-1, 2, 2)  # 转换回矩阵格式
        return dict(zip(tuple(self.time_grid.tolist()), S_matrices))

    def get_nearest_S(self, t):
        """找到最近的 S(t)"""
        nearest_t = self.time_grid[torch.argmin(torch.abs(self.time_grid - t))]
        return self.S_values[nearest_t.tolist()]
    
    def value_function(self, t, x):
        """计算新的 v(t, x) = x^T S(t) x + ∫[t,T] tr(σσ^T S(r)) dr"""
        # 第一部分：x^T S(t) x
        # print(self.S_values)
        S_t = self.get_nearest_S(t)
        S_t = torch.tensor(S_t, dtype=torch.float32)
        value = x.T @ S_t @ x
        # print(value)
        
        # 第二部分：积分项 ∫[t,T] tr(σσ^T S(r)) dr
        def integrand(r):
            S_r = self.get_nearest_S(r)
            return torch.trace(self.sigma @ self.sigma.T @ S_r)
        
        integral, _ = quad(integrand, t, self.T)  # 使用数值积分计算积分项
        value += integral
        return value
    
    def optimal_control(self, t, x):
        """计算最优控制 a(t, x) = -D^(-1) M^T S(t) x"""
        S_t = self.get_nearest_S(t)
        S_t = torch.tensor(S_t, dtype=torch.float32)
        return -torch.linalg.inv(self.D) @ self.M.T @ S_t @ x
    
# Example problem data: 2-D state, 2-D control.
H = 0.5 * torch.tensor([[1.0, 1.0], [0.0, 1.0]])
M = torch.tensor([[1.0, 1.0], [0.0, 1.0]])
sigma = 0.5 * torch.eye(2)
C = 1.0 * torch.tensor([[1.0, 0.1], [0.1, 1.0]])
D = 0.1 * torch.tensor([[1.0, 0.1], [0.1, 1.0]])
R = 10.0 * torch.tensor([[1.0, 0.3], [0.3, 1.0]])
T = 1.0
time_grid = torch.linspace(0, T, steps=100)

# Build the regulator; the Riccati ODE is solved inside the constructor.
lqr = LQR(H=H, M=M, C=C, D=D, R=R, sigma=sigma, T=T, time_grid=time_grid)

# Evaluate the value function and the optimal control at one sample point.
t_test = 0
x_test = torch.tensor([1.0, 1.0])

print("Value function v(t, x):", lqr.value_function(t=t_test, x=x_test))
print("Optimal control a(t, x):", lqr.optimal_control(t=t_test, x=x_test))

Value function v(t, x): tensor(0.8846)
Optimal control a(t, x): tensor([-1.4239, -5.0703])


  return torch.trace(self.sigma @ self.sigma.T @ S_r)
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  integral, _ = quad(integrand, t, self.T)  # 使用数值积分计算积分项
