# In [None]:
# 报错的第五题 (Question 5 — the version that raised an error)

import numpy as np
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn as nn
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp, cumulative_trapezoid as cumtrapz

# ==========================================
# 1. Soft_LQR 类（Exercise 2 实现的代码）——改进版
# ==========================================
class Soft_LQR:
    """Entropy-regularised ("soft") LQR solved through its Riccati ODE.

    On construction the matrix Riccati ODE is integrated backwards from
    S(T) = R on the uniform grid ``torch.linspace(0, T, N + 1)``, and the
    integral I(t) = ∫_t^T tr(σσ^T S(r)) dr is tabulated on the same grid.
    Queries at an arbitrary time t use the nearest grid point.

    The code hard-codes 2x2 reshapes and ``torch.eye(2)``, so state and
    control are 2-dimensional.
    """

    def __init__(self, H, M, C, D, R, sigma, T, N, tau, gamma):
        """Store the problem data and precompute S(t) and I(t) on the grid.

        Args:
            H, M, C, D, R, sigma: 2x2 torch tensors (dynamics, control matrix,
                running state cost, running control cost, terminal cost,
                diffusion).
            T: Time horizon.
            N: Number of grid intervals; the grid has N + 1 points.
            tau: Entropy-regularisation strength.
            gamma: Parameter entering the effective control cost
                D + tau / (2 gamma^2) I.

        Raises:
            ValueError: If N < 1 (the integral tabulation needs a grid step).
        """
        if N < 1:
            raise ValueError(f"N must be at least 1, got {N}")
        self.H = H
        self.M = M
        self.C = C
        self.D = D
        self.R = R
        self.sigma = sigma
        self.T = T
        self.N = N
        self.time_grid = torch.linspace(0, T, N + 1)
        self.tau = tau
        self.gamma = gamma
        self.S_values = self.solve_riccati_ode()
        self.integral_values = self.precompute_integrals()

    def _effective_D(self):
        """Return D + tau / (2 gamma^2) I, the effective control-cost matrix."""
        return self.D + self.tau / (2 * (self.gamma ** 2)) * torch.eye(2)

    def _nearest_index(self, t):
        """Return the index of the grid point closest to t (first one on ties)."""
        return int(torch.argmin(torch.abs(self.time_grid - t)))

    def riccati_ode(self, t, S_flat):
        """Right-hand side of the Riccati ODE, in the flat form solve_ivp needs.

        S'(t) = S^T M D_eff^{-1} M^T S - H^T S - S H - C.

        Evaluated in float64: the solver below runs with atol = rtol = 1e-10,
        and float32 rounding noise in the derivative would swamp that
        tolerance and force the adaptive stepper into tiny steps.
        """
        f64 = torch.float64
        S = torch.as_tensor(S_flat, dtype=f64).reshape(2, 2)
        H = self.H.to(f64)
        M = self.M.to(f64)
        D_eff_inv = torch.linalg.inv(self._effective_D().to(f64))
        S_dot = S.T @ M @ D_eff_inv @ M.T @ S - H.T @ S - S @ H - self.C.to(f64)
        return S_dot.flatten()

    def solve_riccati_ode(self):
        """Integrate the Riccati ODE backwards from S(T) = R over the grid.

        Returns:
            dict mapping each grid time (a Python float, as produced by
            ``time_grid.tolist()``) to the 2x2 numpy array S(t).

        Raises:
            RuntimeError: If solve_ivp reports failure. Without this check a
                partial solution would be silently truncated by ``zip``.
        """
        grid = self.time_grid.numpy()
        S_T = np.asarray(self.R, dtype=np.float64).reshape(-1)
        sol = solve_ivp(self.riccati_ode, [self.T, 0], S_T, t_eval=grid[::-1],
                        atol=1e-10, rtol=1e-10)
        if not sol.success:
            raise RuntimeError(f"Riccati ODE integration failed: {sol.message}")
        # Columns of sol.y follow t_eval (T -> 0); flip back to increasing time.
        S_matrices = sol.y.T[::-1].reshape(-1, 2, 2)
        return dict(zip(self.time_grid.tolist(), S_matrices))

    def precompute_integrals(self):
        """Tabulate I(t) = ∫_t^T tr(σσ^T S(r)) dr on the grid (trapezoid rule).

        Returns:
            dict mapping each grid time (numpy float32, as obtained by
            iterating ``time_grid.numpy()``) to I(t).
        """
        times = self.time_grid.numpy()
        sigma_sq = self.sigma @ self.sigma.T
        f_vals = np.array([
            torch.trace(sigma_sq @ torch.tensor(self.S_values[t], dtype=torch.float32)).item()
            for t in self.time_grid.tolist()
        ])
        dt = times[1] - times[0]
        # Integrate the reversed samples from T down to each t, then flip back.
        cum_int = cumtrapz(f_vals[::-1], dx=dt, initial=0)[::-1]
        return dict(zip(times, cum_int))

    def get_nearest_S(self, t):
        """Return S at the grid point nearest to t (a 2x2 numpy array)."""
        return self.S_values[self.time_grid[self._nearest_index(t)].item()]

    def value_function(self, t, x):
        """Evaluate v(t, x) = x^T S(t) x + I(t) + (T - t) C_const on the nearest grid point.

        ``x`` may be a 1-D state or a column vector; for a column vector the
        result has shape (1, 1).
        """
        idx = self._nearest_index(t)
        S_t = torch.tensor(self.S_values[self.time_grid[idx].item()], dtype=torch.float32)
        # x.T is a no-op (and deprecated) for 1-D tensors; only transpose otherwise.
        x_left = x if x.dim() < 2 else x.T
        val = x_left @ S_t @ x
        val = val + self.integral_values[self.time_grid.numpy()[idx]]
        inv_matrix = torch.linalg.inv(self._effective_D())
        det_matrix = torch.det(inv_matrix)
        C_const = - self.tau * torch.log((self.tau / self.gamma ** 2) * torch.sqrt(det_matrix))
        return val + (self.T - t) * C_const

    def optimal_control(self, t, x):
        """Return the policy N(-D_eff^{-1} M^T S(t) x, tau * D_eff) at the nearest grid time."""
        S_t = torch.tensor(self.get_nearest_S(t), dtype=torch.float32)
        D_eff = self._effective_D()
        mean_control = -torch.linalg.inv(D_eff) @ self.M.T @ S_t @ x
        # NOTE(review): this covariance is tau * D_eff, while value_function's
        # constant uses det(D_eff^{-1}); confirm which convention is intended.
        cov_control = self.tau * D_eff
        return MultivariateNormal(mean_control, cov_control)

# ==========================================
# 2. Actor 策略：参数化为线性映射 θ (2x2 矩阵)
# ==========================================
class Actor(nn.Module):
    """Linear-Gaussian policy whose mean is theta @ x; the covariance is supplied by the caller."""

    def __init__(self, state_dim=2, action_dim=2):
        super(Actor, self).__init__()
        # theta is initialised as a random (action_dim, state_dim) matrix.
        self.theta = nn.Parameter(torch.randn(action_dim, state_dim, dtype=torch.float32))

    def forward(self, x):
        """Return the policy mean theta @ x.

        Args:
            x: State of shape (state_dim,) or a column vector (state_dim, 1).

        Returns:
            Tensor of shape (action_dim,). Only the trailing singleton axis from
            a column-vector input is dropped. A bare ``squeeze()`` would also
            collapse a 1-dimensional action to a 0-d tensor, which
            ``MultivariateNormal`` rejects as a location.
        """
        mu = self.theta @ x
        if mu.dim() > 1 and mu.shape[-1] == 1:
            mu = mu.squeeze(-1)
        return mu

    def get_distribution(self, x, cov):
        """Return the action distribution N(theta @ x, cov)."""
        return MultivariateNormal(self.forward(x), cov)

# ==========================================
# 3. Critic 模型：近似价值函数 v(t,x;η)
# ==========================================
class Critic(nn.Module):
    """MLP approximation v(t, x; eta) of the value function.

    The input [t, x_0, x_1] passes through two ReLU hidden layers of width
    ``hidden_dim`` and a final linear layer with a single output.
    """

    def __init__(self, input_dim=3, hidden_dim=256):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t, x):
        """Evaluate the critic at scalar time ``t`` and 1-D state ``x``.

        Returns:
            The squeezed network output (a 0-d tensor for a single input).
        """
        time_feature = torch.tensor([t], dtype=torch.float32, device=x.device)
        features = torch.cat((time_feature, x))
        return self.net(features).squeeze()

# ==========================================
# 4. 利用当前 actor 策略生成一条轨迹（Euler 离散化）
# ==========================================
def simulate_actor_episode(actor, H, M, sigma, dt, T, x0, cov, C=None, D=None):
    """Roll out one episode of the current actor with an Euler-Maruyama scheme.

    x_{n+1} = x_n + (H x_n + M a_n) dt + sigma (xi_n sqrt(dt)), where a_n is
    drawn from ``actor.get_distribution(x_n, cov)`` and xi_n is standard normal.

    Args:
        actor: Object whose ``get_distribution(x, cov)`` returns a torch
            distribution over actions (e.g. ``Actor``).
        H, M, sigma: Dynamics matrices (torch tensors).
        dt: Time step.
        T: Horizon; the episode has round(T / dt) steps.
        x0: Initial state (the callers in this file pass a (2,) tensor).
        cov: Action covariance passed to ``actor.get_distribution``.
        C, D: Running-cost matrices for cost = x^T C x + a^T D a. When omitted
            they fall back to the module-level ``C_tensor`` / ``D_tensor``
            (which ``train_actor_critic`` assigns).

    Returns:
        (traj, log_probs, costs, times): detached states x_0..x_N; log-densities
        of the sampled actions, still differentiable w.r.t. the actor's
        parameters; per-step costs as Python floats; and times 0, dt, ..., N*dt.
    """
    if C is None:
        C = C_tensor
    if D is None:
        D = D_tensor

    traj = [x0.detach().clone()]
    log_probs = []
    costs = []
    times = [0.0]

    sqrt_dt = torch.sqrt(torch.tensor(dt, dtype=torch.float32))
    # round() rather than int(): e.g. int(0.3 / 0.1) == 2 would drop a step.
    num_steps = int(round(T / dt))
    x = x0
    for n in range(num_steps):
        dist = actor.get_distribution(x, cov)
        # sample(), not rsample(): the policy gradient needs the score
        # grad_theta log p(a | x) at a *fixed* action. With a reparameterised
        # action, a - mean no longer depends on the mean, so log_prob's
        # gradient w.r.t. the actor parameters cancels to exactly zero.
        a = dist.sample()
        log_probs.append(dist.log_prob(a))

        with torch.no_grad():
            # Running cost: x^T C x + a^T D a
            costs.append(((x @ C @ x) + (a @ D @ a)).item())
            noise = torch.randn(sigma.shape[-1])
            x = x + (H @ x + M @ a) * dt + sigma @ (noise * sqrt_dt)
        traj.append(x.detach().clone())
        times.append((n + 1) * dt)
    return traj, log_probs, costs, times

# ==========================================
# 5. Actor–Critic 算法主训练函数
# ==========================================
def train_actor_critic(num_episodes=500, dt=0.01, T=1.0, actor_lr=1e-4, critic_lr=1e-1):
    """Train a linear-Gaussian actor and an MLP critic on a fixed soft-LQR problem.

    Each episode:
      1. draws an initial state uniformly from [-2, 2]^2 and rolls out the
         current actor with ``simulate_actor_episode``;
      2. fits the critic (mean squared error) to Monte-Carlo cost-to-go targets
         G_n = sum_{k>=n} (c_k + tau log p_k) dt + x_T^T R x_T;
      3. takes one policy-gradient step on the actor using the critic's
         temporal differences (Algorithm 3).
    The distance between the actor mean and the Soft_LQR optimal mean at t = 0
    is recorded every episode; error and loss curves are plotted at the end.

    Args:
        num_episodes: Number of training episodes.
        dt: Euler time step.
        T: Time horizon.
        actor_lr: AdamW learning rate for the actor.
        critic_lr: AdamW learning rate for the critic.

    Returns:
        The trained ``(actor, critic)`` pair.

    Raises:
        ValueError: If T / dt rounds to fewer than one time step.
    """
    # Fixed environment parameters (same as in main()).
    H_np = np.array([[0.5, 0.5],
                     [0.0, 0.5]], dtype=np.float32)
    M_np = np.array([[1.0, 1.0],
                     [0.0, 1.0]], dtype=np.float32)
    sigma_np = np.eye(2, dtype=np.float32) * 0.5
    C_np = np.array([[1.0, 0.1],
                     [0.1, 1.0]], dtype=np.float32)
    D_np = np.array([[1.0, 0.1],
                     [0.1, 1.0]], dtype=np.float32) * 0.1
    R_terminal_np = np.array([[1.0, 0.3],
                              [0.3, 1.0]], dtype=np.float32) * 10.0
    tau = 0.1
    gamma_param = 10.0
    # round() rather than int(): int(T / dt) can truncate, e.g. int(0.3 / 0.1) == 2.
    N_steps = int(round(T / dt))
    if N_steps < 1:
        raise ValueError(f"T / dt must give at least one time step (T={T}, dt={dt})")

    # Convert to torch tensors.
    H_tensor = torch.tensor(H_np)
    M_tensor = torch.tensor(M_np)
    sigma_tensor = torch.tensor(sigma_np)
    # simulate_actor_episode reads C_tensor / D_tensor as module-level globals,
    # so they are published here before any rollout.
    global C_tensor, D_tensor, R_tensor
    C_tensor = torch.tensor(C_np)
    D_tensor = torch.tensor(D_np)
    R_tensor = torch.tensor(R_terminal_np)

    soft_lqr = Soft_LQR(H_tensor, M_tensor, C_tensor, D_tensor, R_tensor, sigma_tensor, T, N_steps,
                        tau, gamma_param)

    # Fixed policy covariance tau * (D + tau / (2 gamma^2) I).
    var_matrix = D_tensor + tau / (2 * (gamma_param ** 2)) * torch.eye(2)
    cov_opt = tau * var_matrix

    # Optimal feedback gain at t = 0 (optimal mean = K0 @ x). It is the same
    # for every episode, so compute it once.
    S0 = torch.tensor(soft_lqr.get_nearest_S(0), dtype=torch.float32)
    K0 = -torch.linalg.inv(var_matrix) @ M_tensor.T @ S0

    actor = Actor(state_dim=2, action_dim=2)
    critic = Critic(input_dim=3, hidden_dim=256)
    optimizer_actor = torch.optim.AdamW(actor.parameters(), lr=actor_lr, weight_decay=1e-4)
    optimizer_critic = torch.optim.AdamW(critic.parameters(), lr=critic_lr)

    def get_init_state():
        """Draw an initial state uniformly from [-2, 2]^2 as a (2,) tensor."""
        return torch.tensor([np.random.uniform(-2, 2), np.random.uniform(-2, 2)], dtype=torch.float32)

    error_list = []
    actor_loss_list = []
    critic_loss_list = []

    for ep in range(num_episodes):
        init_state = get_init_state()
        traj, log_probs, costs, times = simulate_actor_episode(
            actor, H_tensor, M_tensor, sigma_tensor, dt, T, init_state, cov_opt)

        # --- Critic predictions along the (detached) trajectory ---
        v_hat = torch.stack([critic(t_n, x_n) for t_n, x_n in zip(times, traj)])

        # --- Monte-Carlo cost-to-go targets, accumulated backwards from g(x_T) ---
        # The log-probabilities enter the critic targets as constants.
        G = (traj[-1] @ R_tensor @ traj[-1]).item()
        targets_rev = [G]
        for k in range(len(traj) - 2, -1, -1):
            G = (costs[k] + tau * log_probs[k].detach().item()) * dt + G
            targets_rev.append(G)
        G_target = torch.tensor(targets_rev[::-1], dtype=torch.float32)

        # --- Critic update: mean squared error ---
        critic_loss = ((v_hat - G_target) ** 2).mean()
        optimizer_critic.zero_grad()
        critic_loss.backward()
        optimizer_critic.step()

        # --- Actor update (Algorithm 3) ---
        # The critic is held fixed for the actor step. Its values must be
        # detached: critic_loss.backward() already freed their graph and
        # optimizer_critic.step() modified the critic weights in place, so
        # backpropagating through v_hat again raised a RuntimeError.
        # Costs are *minimised*, so we descend on
        #   mean_n log p(a_n|x_n) * [v_{n+1} - v_n + (c_n + tau log p(a_n|x_n)) dt],
        # i.e. lower the probability of actions that turned out costlier than
        # the critic predicted. (The old code negated this and ascended.)
        v_fixed = v_hat.detach()
        log_p = torch.stack(log_probs)
        running = torch.tensor(costs, dtype=torch.float32) + tau * log_p.detach()
        td = (v_fixed[1:] - v_fixed[:-1]) + running * dt
        actor_loss = (log_p * td).mean()

        optimizer_actor.zero_grad()
        actor_loss.backward()
        optimizer_actor.step()

        # Track the gap between the actor mean and the Soft_LQR optimal mean at t = 0.
        with torch.no_grad():
            error = torch.norm(actor(init_state) - K0 @ init_state).item()
        error_list.append(error)
        actor_loss_list.append(actor_loss.item())
        critic_loss_list.append(critic_loss.item())

        if (ep + 1) % 50 == 0:
            print(f"Episode {ep+1}/{num_episodes}, Actor Loss: {actor_loss.item():.6f}, Critic Loss: {critic_loss.item():.6f}, Error: {error:.6f}")

    # Plot the error and loss curves.
    plt.figure(figsize=(10,4))
    plt.subplot(1,2,1)
    plt.plot(error_list)
    plt.xlabel("Episode")
    plt.ylabel("Actor Mean Error")
    plt.title("Actor Mean Error vs Episodes")
    plt.grid(True)

    plt.subplot(1,2,2)
    plt.plot(actor_loss_list, label="Actor Loss")
    plt.plot(critic_loss_list, label="Critic Loss")
    plt.xlabel("Episode")
    plt.ylabel("Loss")
    plt.title("Loss vs Episodes")
    plt.legend()
    plt.grid(True)
    plt.show()

    return actor, critic

# ==========================================
# 6. 运行 Actor–Critic 算法
# ==========================================
if __name__ == "__main__":
    # Run the actor-critic experiment with every hyper-parameter spelled out.
    trained_actor, trained_critic = train_actor_critic(
        num_episodes=500,
        dt=0.01,
        T=1.0,
        actor_lr=1e-4,
        critic_lr=1e-1,
    )
