In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
from torch.distributions import MultivariateNormal
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

###############################################################################
# 1) SoftLQR Class
#    - Solves the Riccati ODE for the soft LQR problem.
#    - Provides a method to compute the optimal value function v*(t, x) and the optimal control.
###############################################################################
class SoftLQR:
    def __init__(self, H, M, C, D, R, sigma, time_grid):
        # Convert PyTorch tensors to numpy arrays.
        self.H = H.numpy()
        self.M = M.numpy()
        self.C = C.numpy()
        self.D = D.numpy()
        self.R = R.numpy()
        self.sigma = sigma.numpy()
        self.time_grid = time_grid.numpy()  # time grid covering [0, T]
        self.solution_finite = None
        self.solution_infinite = None
        self.solution_history = []  # Stores S(t) at each time point.
        self.stability_results = []
        self.tau = None
        self.gamma = None

    def solve_ricatti_finite(self, T_max, tau=0.5, gamma=1.0):
        """
        Solves the Riccati ODE backward from T_max to 0 with S(T_max)=R.
        Stores the solution history.
        """
        self.tau = tau
        self.gamma = gamma
        self.solution_history.clear()

        def riccati_rhs(t, S_flat):
            S = S_flat.reshape(self.R.shape)
            gamma_adjust = tau / (2 * gamma**2)
            inverse_term = np.linalg.inv(self.D + gamma_adjust * np.eye(self.D.shape[0]) + 1e-8 * np.eye(self.D.shape[0]))
            S_dot = -(self.H.T @ S + S @ self.H + self.C - S.T @ self.M @ inverse_term @ self.M.T @ S)
            return S_dot.ravel()

        S_initial = self.R.ravel()
        sol = solve_ivp(riccati_rhs, [T_max, 0], S_initial, method='BDF',
                        t_eval=np.linspace(T_max, 0, 500),
                        max_step=10.0, atol=1e-4, rtol=1e-3)
        if not sol.success:
            raise RuntimeError(f"Riccati ODE Solver failed: {sol.message}")

        for i in range(sol.y.shape[1]):
            self.solution_history.append(sol.y[:, i].reshape(self.R.shape).astype(np.float32))
        self.solution_finite = torch.from_numpy(sol.y[:, -1].reshape(self.R.shape).astype(np.float32))

    def solve_ricatti_infinite(self, tau=0.5, gamma=1.0):
        from scipy.linalg import solve_continuous_are, LinAlgError
        A = self.H; B = self.M; Q = self.C
        gamma_adjust = tau / (2 * gamma**2)
        try:
            S_ss = solve_continuous_are(A, B, Q, self.D + gamma_adjust * np.eye(self.D.shape[0]))
            self.solution_infinite = torch.from_numpy(S_ss.astype(np.float32))
        except LinAlgError:
            raise RuntimeError("Singular matrix encountered during CARE solution.")

    def get_S_at_time(self, t):
        """
        Returns S(t) (as a torch tensor) at the closest time point in solution_history.
        """
        if not self.solution_history:
            raise ValueError("Run solve_ricatti_finite first.")
        closest_time_index = np.argmin(np.abs(self.time_grid - t))
        return torch.tensor(self.solution_history[closest_time_index], dtype=torch.float32)

    def calculate_trace_integral_simpson(self, start_index):
        """
        Approximates ∫_{t}^{T} tr(σσᵀS(r))dr using Simpson's rule.
        """
        trace_integral = 0.0
        n = len(self.solution_history)
        if (n - start_index) < 3:
            return 0.0
        dt = self.time_grid[1] - self.time_grid[0]
        end_index = n if (n - start_index) % 2 == 1 else n - 1
        for i in range(start_index, end_index - 1, 2):
            S_i = self.solution_history[i]
            S_ip1 = self.solution_history[i+1]
            S_ip2 = self.solution_history[i+2]
            trace_i = np.trace(self.sigma @ self.sigma.T @ S_i)
            trace_ip1 = np.trace(self.sigma @ self.sigma.T @ S_ip1)
            trace_ip2 = np.trace(self.sigma @ self.sigma.T @ S_ip2)
            trace_integral += (trace_i + 4*trace_ip1 + trace_ip2) * dt / 3
        return trace_integral

    def calculate_control_problem_value(self, t, x):
        """
        Computes the optimal value:
          v*(t,x) = xᵀS(t)x + ∫ₜ^T tr(σσᵀS(r))dr + (T-t)*kappa,
        where kappa = -τ ln( (τ^(m/2))/(γ^m)*sqrt(det((D+τ/(2γ²)I)⁻¹) ).
        """
        S_t = self.get_S_at_time(t)
        x_np = x.numpy()
        quadratic_term = x_np @ (S_t.numpy() @ x_np)
        m = S_t.shape[0]
        gamma_adjust = self.tau / (2 * self.gamma**2)
        inverse_sigma = np.linalg.inv(self.D + gamma_adjust * np.eye(m))
        det_sigma = np.linalg.det(inverse_sigma)
        kappa = -self.tau * np.log((self.tau**(m/2))/(self.gamma**m) * math.sqrt(det_sigma))
        start_index = np.argmin(np.abs(self.time_grid - t))
        trace_integral = self.calculate_trace_integral_simpson(start_index)
        horizon = self.time_grid[-1] - t
        return torch.tensor(quadratic_term + trace_integral + horizon * kappa, dtype=torch.float32)

    def optimal_control(self, t, x, tau=0.5, gamma=1.0):
        """
        Computes the optimal control distribution parameters from the Riccati solution:
         μ* = -Σ_opt Mᵀ S(t)x, with Σ_opt = (D + (τ/(2γ²)) I)⁻¹,
         and covariance = τ * (Σ_opt)⁻¹.
        """
        S_t = self.get_S_at_time(t).numpy()
        m = self.D.shape[0]
        Sigma = np.linalg.inv(self.D + (tau/(2*gamma**2))*np.eye(m))
        mean_opt = -Sigma @ self.M.T @ S_t @ x.numpy()
        cov_opt = tau * np.linalg.inv(Sigma)
        return mean_opt, cov_opt


###############################################################################
# 2) LQREnvironmentWithPolicy
#    - Simulates the continuous-time LQR dynamics with Euler-Maruyama steps.
#    - Each step returns the new state and the immediate cost.
###############################################################################
class LQREnvironmentWithPolicy:
    def __init__(self, H, M, C, D, R, sigma, gamma, initial_distribution, T, dt):
        """
        H, M, C, D, R, sigma: same LQR matrices.
        gamma: scale parameter for the reference distribution (not a learning rate).
        initial_distribution: used to sample the initial state.
        T, dt: final time and time step for the environment simulation.
        """
        self.H = H
        self.M = M
        self.C = C
        self.D = D
        self.R = R
        self.sigma = sigma
        self.gamma = gamma
        self.initial_distribution = initial_distribution
        self.T = T
        self.dt = dt
        self.device = torch.device('cpu')
        self.current_state = None

        # N = number of steps in one episode
        self.N = int(T / dt)
        self.action_dim = M.size(1)

    def sample_initial_state(self):
        """
        Sample the initial state X0 from the provided distribution.
        """
        self.current_state = self.initial_distribution.to(self.device)
        return self.current_state

    def step(self, action):
        """
        One Euler-Maruyama step:
          X_{n+1} = X_n + (H X_n + M a_n) dt + sigma * sqrt(dt) * noise.
        Returns (new_state, cost).
        """
        action = action.view(-1, 1)
        noise = torch.randn((2,1), dtype=torch.float32, device=self.device) * math.sqrt(self.dt)

        current_state_col = self.current_state.unsqueeze(1)
        new_state_col = current_state_col + (self.H @ current_state_col + self.M @ action)*self.dt + self.sigma @ noise
        new_state = new_state_col.squeeze()

        # cost = (xᵀ C x + aᵀ D a)*dt
        state_cost = (current_state_col.T @ self.C @ current_state_col).item()
        action_cost = (action.T @ self.D @ action).item()
        cost = (state_cost + action_cost) * self.dt

        self.current_state = new_state
        return self.current_state, cost

    def observe_terminal_cost(self):
        """
        Return the terminal cost g_T = xᵀ R x at final time T.
        """
        return (self.current_state.T @ self.R @ self.current_state).item()

    def f(self, action, t, x):
        """
        The immediate 'physical' cost function f(x,a) = xᵀ C x + aᵀ D a.
        (Used in log-prob computations for the "fixed" policy.)
        """
        xCx = (x.T @ self.C @ x).item()
        aDa = (action.T @ self.D @ action).item()
        return xCx + aDa

    def gaussian_quadratic_integral(self):
        """
        Compute the normalizing constant for the distribution exp(-f(x,a)).
        This is used by the 'fixed_policy_log_prob' for debugging or reference.
        """
        try:
            epsilon = 1e-8
            adjusted_matrix = torch.eye(self.action_dim)/(2*self.gamma**2) - self.D
            adjusted_matrix += torch.eye(self.action_dim)*epsilon
            precision_matrix = torch.inverse(adjusted_matrix)
            integral_value = torch.sqrt((2*np.pi)**self.action_dim * torch.linalg.det(precision_matrix)).item()
            return integral_value
        except torch.linalg.LinAlgError as e:
            print("Matrix inversion failed:", e)
            return float('inf')

    def fixed_policy_log_prob(self, action, t, state):
        """
        This is the log of π(a|x) ∝ exp( -f(x,a) ), with some reference normalizing constant.
        It's not used in the actor training, but for demonstration if you want a fixed policy.
        """
        f_atx = self.f(action, t, state)
        integral_value = self.gaussian_quadratic_integral()
        log_denominator = np.log(integral_value)
        # log_prob = - f(x,a) - log_denominator
        return -f_atx - log_denominator
    

###############################################################################
# 3) Critic Network (Value Function Approximator)
###############################################################################
class OnlyLinearValueNN(nn.Module):
    def __init__(self, device=torch.device("cpu")):
        """
        A simple critic network that approximates the value function:
          V(t,x) = xᵀ S(t) x + offset(t),
        with S(t) a positive semidefinite matrix output and offset a scalar.
        """
        super().__init__()
        self.device = device
        self.hidden_layer_width = 256

        self.hidden_layer = nn.Linear(1, self.hidden_layer_width, device=device)
        self.activation = nn.ReLU()
        self.matrix_layer = nn.Linear(self.hidden_layer_width, 4, device=device)
        self.offset_layer = nn.Linear(self.hidden_layer_width, 1, device=device)

    def forward(self, t):
        """
        Input: t (shape (1,)) 
        Returns: a 2x2 symmetric matrix S(t) and a scalar offset.
        """
        x = self.activation(self.hidden_layer(t))
        matrix_elements = self.matrix_layer(x)  # shape (1,4)
        matrix = matrix_elements.view(-1, 2, 2)
        S_out = torch.bmm(matrix, matrix.transpose(1,2)) + 1e-3*torch.eye(2).to(self.device)
        offset = self.offset_layer(x)
        return S_out, offset

    def predict_value(self, t, x):
        """
        Returns V(t,x) = xᵀ S(t) x + offset.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor([t], dtype=torch.float32, device=self.device)
        if t.dim() == 0:
            t = t.unsqueeze(0)
        S, offset = self.forward(t)
        S = S.squeeze(0)
        offset = offset.squeeze()
        x_col = x.view(-1,1)
        return (x_col.t() @ S @ x_col).squeeze() + offset


###############################################################################
# 4) Actor Network (Gaussian Policy)
###############################################################################
class GaussianPolicyNN(nn.Module):
    def __init__(self, device=torch.device("cpu"), state_dim=2, hidden_dim=256):
        """
        A neural network that outputs a Gaussian distribution over actions given (t,x).
        """
        super().__init__()
        self.device = device
        self.fc1 = nn.Linear(state_dim+1, hidden_dim, device=device)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim, device=device)
        self.mean_head = nn.Linear(hidden_dim, state_dim, device=device)
        self.log_std = nn.Parameter(torch.zeros(state_dim))
        self.activation = nn.ReLU()

    def forward(self, t, x):
        """
        Input: t (shape (1,)) and x (shape (2,))
        Output: mean (2,) and log_std (2,)
        """
        t = t.to(torch.float32)
        x = x.to(torch.float32)
        if t.dim() == 0:
            t = t.unsqueeze(0)
        inp = torch.cat([t, x], dim=0).unsqueeze(0)  # shape (1,3)
        z = self.activation(self.fc1(inp))
        z = self.activation(self.fc2(z))
        mean = self.mean_head(z).squeeze(0)
        return mean, self.log_std

    def get_action_distribution(self, t, x):
        mean, log_std = self.forward(t, x)
        cov = torch.diag(log_std.exp()**2)
        return MultivariateNormal(mean, cov)

    def sample_action(self, t, x):
        return self.get_action_distribution(t, x).sample()

    def log_prob(self, t, x, action):
        return self.get_action_distribution(t, x).log_prob(action)


###############################################################################
# 5) Actor-Critic Training Algorithm
###############################################################################
def actor_critic_algorithm(env, num_episodes=500, dt=0.005, tau=0.5, actor_lr=1e-3, critic_lr=1e-3):
    """
    Offline actor-critic algorithm:
      - Critic: learns V(t,x) by minimizing MSE between its prediction and the Monte Carlo return.
      - Actor: updated using policy gradient with advantage computed as:
            A_n = f_n + tau*log(pi(a_n|t_n,x_n))*dt + (V(t_{n+1},x_{n+1}) - V(t_n,x_n))
        plus terminal cost.
    """
    device = torch.device("cpu")
    # Initialize networks
    critic_net = OnlyLinearValueNN(device=device).to(device)
    actor_net  = GaussianPolicyNN(device=device).to(device)

    critic_optimizer = optim.Adam(critic_net.parameters(), lr=critic_lr)
    actor_optimizer  = optim.Adam(actor_net.parameters(), lr=actor_lr)

    for episode in range(num_episodes):
        # 1) Roll out an episode
        X0 = env.sample_initial_state().to(device)
        X = X0.clone()
        states, actions, costs, log_probs, times = [], [], [], [], []

        for n in range(env.N):
            t_n = n * dt
            states.append(X)
            times.append(t_n)
            a_n = actor_net.sample_action(torch.tensor(t_n, dtype=torch.float32, device=device), X)
            actions.append(a_n)
            X_next, cost_n = env.step(a_n)
            costs.append(cost_n)
            lp_n = actor_net.log_prob(torch.tensor(t_n, dtype=torch.float32, device=device), states[-1], a_n)
            log_probs.append(lp_n)
            X = X_next.clone()

        g_T = env.observe_terminal_cost()

        # 2) Compute Monte Carlo return for critic update.
        returns = []
        G = g_T
        for cost in reversed(costs):
            G = cost + G  # simple sum over episode (could use discounting if needed)
            returns.insert(0, G)

        returns = torch.tensor(returns, dtype=torch.float32, device=device)

        critic_loss = 0.0
        for n in range(env.N):
            t_n = times[n]
            V_pred = critic_net.predict_value(torch.tensor(t_n, dtype=torch.float32, device=device), states[n])
            critic_loss += (V_pred - returns[n])**2

        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()

        # 3) Actor update using advantage computed with critic.
        actor_loss = 0.0
        for n in range(env.N):
            t_n = times[n]
            if n < env.N - 1:
                V_next = critic_net.predict_value(torch.tensor(times[n+1], dtype=torch.float32, device=device), states[n+1])
            else:
                V_next = torch.tensor(0.0, device=device)
            V_n = critic_net.predict_value(torch.tensor(t_n, dtype=torch.float32, device=device), states[n])
            advantage = costs[n] + tau * log_probs[n] * dt + (V_next - V_n)
            if n == env.N - 1:
                advantage += g_T
            actor_loss = actor_loss - log_probs[n] * advantage

        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        if episode % 50 == 0:
            print(f"Episode {episode}: Critic Loss: {critic_loss.item():.4f}, Actor Loss: {actor_loss.item():.4f}")

    return actor_net, critic_net


###############################################################################
# 6) Comparison Functions
###############################################################################
def optimal_control_from_softlqr(soft_lqr, t, x, tau=0.5, gamma=1.0):
    """
    Computes the optimal control distribution from the Riccati solution:
      μ* = -Σ_opt Mᵀ S(t)x, with Σ_opt = (D + (τ/(2γ²))I)⁻¹,
      and covariance = τ * (Σ_opt)⁻¹.
    """
    S_t = soft_lqr.get_S_at_time(t).numpy()
    m = soft_lqr.D.shape[0]
    Sigma = np.linalg.inv(soft_lqr.D + (tau/(2*gamma**2)) * np.eye(m))
    mean_opt = -Sigma @ soft_lqr.M.T @ S_t @ x.numpy()
    cov_opt = tau * np.linalg.inv(Sigma)
    return mean_opt, cov_opt

def compare_value_functions(critic_net, soft_lqr, times=[0.0, 0.25, 0.5], x_vals=[-2.0, 0.0, 2.0]):
    """
    Compare the critic's learned value V(t,x) with the optimal value v*(t,x) from SoftLQR.
    """
    errors = []
    for t in times:
        for x1 in x_vals:
            for x2 in x_vals:
                x = torch.tensor([x1, x2], dtype=torch.float32)
                V_pred = critic_net.predict_value(torch.tensor([t], dtype=torch.float32), x).item()
                V_opt = soft_lqr.calculate_control_problem_value(t, x).item()
                err = abs(V_pred - V_opt)
                errors.append(err)
    avg_err = np.mean(errors)

    print("Max value function error:", np.max(errors))
    print("Average value function error:", avg_err)
    return avg_err

def compare_control_distributions(actor_net, soft_lqr, times=[0.0, 0.25, 0.5], x_vals=[-2.0, 0.0, 2.0], tau=0.5, gamma=1.0):
    """
    Compare the learned control distribution from the actor with the optimal control
    computed from SoftLQR.
    """
    mean_errors = []
    cov_errors = []
    for t in times:
        for x1 in x_vals:
            for x2 in x_vals:
                x = torch.tensor([x1, x2], dtype=torch.float32)
                # Learned distribution:
                dist_learned = actor_net.get_action_distribution(torch.tensor(t, dtype=torch.float32), x)
                learned_mean = dist_learned.mean.detach().numpy()
                learned_cov = dist_learned.covariance_matrix.detach().numpy()

                # Optimal distribution:
                optimal_mean, optimal_cov = optimal_control_from_softlqr(soft_lqr, t, x, tau, gamma)
                mean_err = np.linalg.norm(learned_mean - optimal_mean)
                cov_err = np.linalg.norm(learned_cov - optimal_cov)
                mean_errors.append(mean_err)
                cov_errors.append(cov_err)

    print("Max control mean error:", np.max(mean_errors))
    print("Average control mean error:", np.mean(mean_errors))

    return np.mean(mean_errors), np.mean(cov_errors)


###############################################################################
# 7) Main Script: Setup, Train, and Compare
###############################################################################

# LQR problem data
H = torch.tensor([[1.0, 1.0],
                    [0.0, 1.0]], dtype=torch.float32) * 0.5
M = torch.tensor([[1.0, 0.0],
                    [0.0, 1.0]], dtype=torch.float32)
C = torch.tensor([[1.0, 0.1],
                    [0.1, 1.0]], dtype=torch.float32)
D = torch.tensor([[1.0, 0.0],
                    [0.0, 1.0]], dtype=torch.float32)
R = torch.tensor([[1.0, 0.3],
                    [0.3, 1.0]], dtype=torch.float32) * 10.0
sigma = torch.eye(2, dtype=torch.float32) * 0.5

# Set final time and discretization for SoftLQR.
T = 0.5
N_points = 100
time_grid = torch.linspace(0, T, N_points)

# Build SoftLQR and solve the Riccati ODE.
soft_lqr = SoftLQR(H, M, C, D, R, sigma, time_grid)
soft_lqr.solve_ricatti_finite(T_max=T, tau=0.5, gamma=1.0)

# Build the environment.
initial_distribution = torch.FloatTensor(2).uniform_(-2.0, 2.0)
env = LQREnvironmentWithPolicy(H, M, C, D, R, sigma, gamma=1.0,
                                initial_distribution=initial_distribution,
                                T=T, dt=0.005)

# Train the actor-critic algorithm.
actor_net, critic_net = actor_critic_algorithm(env, num_episodes=501, dt=0.005, tau=0.5, actor_lr=1e-6, critic_lr=1e-6)
print("Actor-Critic training completed.")

# Compare learned value function (critic) with the optimal value function.
compare_value_functions(critic_net, soft_lqr, times=[0.0, 1/6, 2/6, 0.5], x_vals=[-2.0, 0.0, 2.0])

# Compare learned control distribution (actor) with the optimal control distribution.
compare_control_distributions(actor_net, soft_lqr, times=[0.0, 1/6, 2/6, 0.5], x_vals=[-2.0, 0.0, 2.0], tau=0.5, gamma=1.0)


Episode 0: Critic Loss: 176385.9531, Actor Loss: 148.0489
Episode 50: Critic Loss: 725275.0000, Actor Loss: 320.8912
Episode 100: Critic Loss: 438044.0000, Actor Loss: 165.0351
Episode 150: Critic Loss: 615037.4375, Actor Loss: 184.6769
Episode 200: Critic Loss: 179211.9531, Actor Loss: 96.0362
Episode 250: Critic Loss: 236634.4688, Actor Loss: 183.7198
Episode 300: Critic Loss: 209090.7969, Actor Loss: 108.9466
Episode 350: Critic Loss: 373778.8125, Actor Loss: 140.7334
Episode 400: Critic Loss: 502283.3438, Actor Loss: 204.0352
Episode 450: Critic Loss: 1042841.4375, Actor Loss: 305.3223
Episode 500: Critic Loss: 658376.8750, Actor Loss: 257.3591
Actor-Critic training completed.
Max value function error: 107.92829656600952
Average value function error: 45.151035868045355
Max control mean error: 29.744786486896096
Average control mean error: 13.540088084515704


(13.540088084515704, 0.5303600119706506)