In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
from torch.distributions import MultivariateNormal
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

###############################################################################
# 1) SoftLQR Class
#    - Solves the Riccati ODE for the soft LQR problem.
#    - Provides a method to compute the optimal value function v*(t, x).
###############################################################################
class SoftLQR:
    def __init__(self, H, M, C, D, R, sigma, time_grid):
        """
        H, M, C, D, R, sigma: cost and dynamics matrices for LQR.
        time_grid: a 1D torch.Tensor specifying times from 0 to T (the horizon).
        """
        # Convert the PyTorch tensors to numpy arrays for use in scipy routines.
        self.H = H.numpy()
        self.M = M.numpy()
        self.C = C.numpy()
        self.D = D.numpy()
        self.R = R.numpy()
        self.sigma = sigma.numpy()

        # Store the time grid as a numpy array as well.
        self.time_grid = time_grid.numpy()

        # Will store solutions of the Riccati ODE.
        self.solution_finite = None
        self.solution_infinite = None
        self.solution_history = []
        self.stability_results = []

        # These are the relaxation parameter (tau) and scale parameter (gamma).
        self.tau = None
        self.gamma = None

    def solve_ricatti_finite(self, T_max, tau=0.5, gamma=1.0):
        """
        Solve the finite-horizon Riccati ODE backward from T_max to 0:
          S'(t) + Hᵀ S(t) + S(t) H + C - S(t) M Σ Mᵀ S(t) = 0,
        with terminal condition S(T_max) = R.
        Stores the solution in self.solution_history for each time,
        and self.solution_finite for the final matrix at time 0.
        """
        self.tau = tau
        self.gamma = gamma
        self.solution_history.clear()

        # Right-hand side of the Riccati ODE.
        def riccati_rhs(t, S_flat):
            S = S_flat.reshape(self.R.shape)
            gamma_adjust = tau / (2 * gamma**2)

            # Add a small diagonal to avoid singular matrix issues.
            inverse_term = np.linalg.inv(self.D + gamma_adjust * np.eye(self.D.shape[0])
                                         + 1e-8 * np.eye(self.D.shape[0]))

            # - (Hᵀ S + S H + C - S M Σ Mᵀ S).
            S_dot = -(self.H.T @ S + S @ self.H + self.C
                      - S.T @ self.M @ inverse_term @ self.M.T @ S)
            return S_dot.ravel()

        # Flatten the terminal condition S(T_max) = R.
        S_initial = self.R.ravel()

        # Solve backward from T_max to 0 using 'BDF' method.
        sol = solve_ivp(
            riccati_rhs, [T_max, 0], S_initial, method='BDF',
            t_eval=np.linspace(T_max, 0, 500), max_step=10.0,
            atol=1e-4, rtol=1e-3
        )
        if not sol.success:
            raise RuntimeError(f"Riccati ODE Solver failed: {sol.message}")

        # Store the solution at each time point in solution_history.
        for i in range(sol.y.shape[1]):
            self.solution_history.append(sol.y[:, i].reshape(self.R.shape).astype(np.float32))

        # The final matrix at time 0 is sol.y[:, -1].
        self.solution_finite = torch.from_numpy(sol.y[:, -1].reshape(self.R.shape).astype(np.float32))

    def solve_ricatti_infinite(self, tau=0.5, gamma=1.0):
        """
        Solve the infinite-horizon algebraic Riccati equation using solve_continuous_are.
        Stores the result in self.solution_infinite.
        """
        from scipy.linalg import solve_continuous_are, LinAlgError
        A = self.H
        B = self.M
        Q = self.C
        gamma_adjust = tau / (2 * gamma**2)

        try:
            S_ss = solve_continuous_are(A, B, Q, self.D + gamma_adjust * np.eye(self.D.shape[0]))
            self.solution_infinite = torch.from_numpy(S_ss.astype(np.float32))
        except LinAlgError:
            raise RuntimeError("Singular matrix encountered during CARE solution.")

    def get_S_at_time(self, t):
        """
        Returns the Riccati solution matrix S(t) at the closest point in the stored solution_history.
        """
        if not self.solution_history:
            raise ValueError("Solutions have not been computed. Please run solve_ricatti_finite first.")
        closest_time_index = np.argmin(np.abs(self.time_grid - t))
        return torch.tensor(self.solution_history[closest_time_index], dtype=torch.float32)

    def calculate_trace_integral_simpson(self, start_index):
        """
        Compute the integral of trace(σ σᵀ S(r)) from time_grid[start_index] to the end,
        using Simpson's rule. 
        """
        trace_integral = 0.0
        n = len(self.solution_history)
        if (n - start_index) < 3:
            return 0.0
        dt = self.time_grid[1] - self.time_grid[0]

        # We ensure an odd number of points for Simpson's rule.
        end_index = n if (n - start_index) % 2 == 1 else n - 1

        for i in range(start_index, end_index - 1, 2):
            S_i = self.solution_history[i]
            S_ip1 = self.solution_history[i+1]
            S_ip2 = self.solution_history[i+2]

            trace_i = np.trace(self.sigma @ self.sigma.T @ S_i)
            trace_ip1 = np.trace(self.sigma @ self.sigma.T @ S_ip1)
            trace_ip2 = np.trace(self.sigma @ self.sigma.T @ S_ip2)

            trace_integral += (trace_i + 4*trace_ip1 + trace_ip2) * dt / 3
        return trace_integral

    def calculate_control_problem_value(self, t, x):
        """
        Computes the soft-LQR optimal value function at time t, state x:
          v*(t,x) = xᵀ S(t) x + ∫ₜ^T trace(σ σᵀ S(r)) dr + (T - t)*kappa,
        where kappa = -τ ln( (τ^(m/2)) / (γ^m) * sqrt(det( (D + τ/(2γ²)I)⁻¹ )) ).
        """
        S_t = self.get_S_at_time(t)  # S(t) from the solution_history
        x_np = x.numpy()

        # xᵀ S(t) x
        quadratic_term = x_np @ (S_t.numpy() @ x_np)

        # We compute the constant kappa from tau, gamma, and D.
        m = S_t.shape[0]
        gamma_adjust = self.tau / (2 * self.gamma**2)
        inverse_sigma = np.linalg.inv(self.D + gamma_adjust * np.eye(m))
        det_sigma = np.linalg.det(inverse_sigma)

        # kappa
        C_D_tau_gamma = -self.tau * np.log(
            (self.tau**(m/2)) / (self.gamma**m) * math.sqrt(det_sigma)
        )

        # The integral of trace(σσᵀS(r)) from t to T.
        start_index = np.argmin(np.abs(self.time_grid - t))
        trace_integral = self.calculate_trace_integral_simpson(start_index)

        horizon = self.time_grid[-1] - t
        return torch.tensor(quadratic_term + trace_integral + horizon*C_D_tau_gamma, dtype=torch.float32)


###############################################################################
# 2) LQREnvironmentWithPolicy
#    - Simulates the continuous-time LQR dynamics with Euler-Maruyama steps.
#    - Each step returns the new state and the immediate cost.
###############################################################################
class LQREnvironmentWithPolicy:
    def __init__(self, H, M, C, D, R, sigma, gamma, initial_distribution, T, dt):
        """
        H, M, C, D, R, sigma: same LQR matrices.
        gamma: scale parameter for the reference distribution (not a learning rate).
        initial_distribution: used to sample the initial state.
        T, dt: final time and time step for the environment simulation.
        """
        self.H = H
        self.M = M
        self.C = C
        self.D = D
        self.R = R
        self.sigma = sigma
        self.gamma = gamma
        self.initial_distribution = initial_distribution
        self.T = T
        self.dt = dt
        self.device = torch.device('cpu')
        self.current_state = None

        # N = number of steps in one episode
        self.N = int(T / dt)
        self.action_dim = M.size(1)

    def sample_initial_state(self):
        """
        Sample the initial state X0 from the provided distribution.
        """
        self.current_state = self.initial_distribution.to(self.device)
        return self.current_state

    def step(self, action):
        """
        One Euler-Maruyama step:
          X_{n+1} = X_n + (H X_n + M a_n) dt + sigma * sqrt(dt) * noise.
        Returns (new_state, cost).
        """
        action = action.view(-1, 1)
        noise = torch.randn((2,1), dtype=torch.float32, device=self.device) * math.sqrt(self.dt)

        current_state_col = self.current_state.unsqueeze(1)
        new_state_col = current_state_col + (self.H @ current_state_col + self.M @ action)*self.dt + self.sigma @ noise
        new_state = new_state_col.squeeze()

        # cost = (xᵀ C x + aᵀ D a)*dt
        state_cost = (current_state_col.T @ self.C @ current_state_col).item()
        action_cost = (action.T @ self.D @ action).item()
        cost = (state_cost + action_cost) * self.dt

        self.current_state = new_state
        return self.current_state, cost

    def observe_terminal_cost(self):
        """
        Return the terminal cost g_T = xᵀ R x at final time T.
        """
        return (self.current_state.T @ self.R @ self.current_state).item()

    def f(self, action, t, x):
        """
        The immediate 'physical' cost function f(x,a) = xᵀ C x + aᵀ D a.
        (Used in log-prob computations for the "fixed" policy.)
        """
        xCx = (x.T @ self.C @ x).item()
        aDa = (action.T @ self.D @ action).item()
        return xCx + aDa

    def gaussian_quadratic_integral(self):
        """
        Compute the normalizing constant for the distribution exp(-f(x,a)).
        This is used by the 'fixed_policy_log_prob' for debugging or reference.
        """
        try:
            epsilon = 1e-8
            adjusted_matrix = torch.eye(self.action_dim)/(2*self.gamma**2) - self.D
            adjusted_matrix += torch.eye(self.action_dim)*epsilon
            precision_matrix = torch.inverse(adjusted_matrix)
            integral_value = torch.sqrt((2*np.pi)**self.action_dim * torch.linalg.det(precision_matrix)).item()
            return integral_value
        except torch.linalg.LinAlgError as e:
            print("Matrix inversion failed:", e)
            return float('inf')

    def fixed_policy_log_prob(self, action, t, state):
        """
        This is the log of π(a|x) ∝ exp( -f(x,a) ), with some reference normalizing constant.
        It's not used in the actor training, but for demonstration if you want a fixed policy.
        """
        f_atx = self.f(action, t, state)
        integral_value = self.gaussian_quadratic_integral()
        log_denominator = np.log(integral_value)
        # log_prob = - f(x,a) - log_denominator
        return -f_atx - log_denominator


###############################################################################
# 3) GaussianPolicyNN (Actor)
#    - A neural network that outputs a Gaussian distribution over actions for each (t, x).
###############################################################################
class GaussianPolicyNN(nn.Module):
    def __init__(self, device=torch.device("cpu"), state_dim=2, hidden_dim=256):
        super().__init__()
        self.device = device
        # A small MLP with two hidden layers.
        # Input dimension = state_dim + 1, i.e. [t, x1, x2].
        self.fc1 = nn.Linear(state_dim + 1, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

        # Output heads: mean_head for the mean (2D), and log_std for the diagonal log-variance (2D).
        self.mean_head = nn.Linear(hidden_dim, state_dim)
        self.log_std = nn.Parameter(torch.zeros(state_dim))

        self.activation = nn.ReLU()

    def forward(self, t, x):
        """
        Forward pass: takes time t and state x, returns mean, log_std for the action distribution.
        """
        t = t.to(torch.float32)
        x = x.to(torch.float32)
        if t.dim() == 0:
            t = t.unsqueeze(0)

        # Concatenate t and x => shape (1,3) after unsqueeze(0).
        input_vec = torch.cat([t, x], dim=0).unsqueeze(0)

        # Pass through the MLP.
        z = self.activation(self.fc1(input_vec))
        z = self.activation(self.fc2(z))

        # mean is shape (2,), log_std is shape (2,)
        mean = self.mean_head(z).squeeze(0)
        return mean, self.log_std

    def get_action_distribution(self, t, x):
        """
        Construct a MultivariateNormal with the mean, cov given by the NN outputs.
        """
        mean, log_std = self.forward(t, x)
        cov = torch.diag(log_std.exp()**2)
        return MultivariateNormal(mean, cov)

    def sample_action(self, t, x):
        """
        Sample an action from the current policy at (t, x).
        """
        dist = self.get_action_distribution(t, x)
        return dist.sample()

    def log_prob(self, t, x, action):
        """
        Return log π_\theta(a | t, x).
        """
        dist = self.get_action_distribution(t, x)
        return dist.log_prob(action)


###############################################################################
# 4) Actor-only training loop
#    - Offline style: collect one full episode, then do one gradient update.
###############################################################################
def offline_actor_algorithm(env, soft_lqr, num_episodes=500, dt=0.005, tau=0.5, actor_lr=1e-3):
    """
    Runs the actor-only policy gradient algorithm using:
      advantage_n = cost_n + tau * log_prob_n * dt + (V_{n+1} - V_n) (+ terminal if last).
    The policy is updated once per episode.

    :param env: LQREnvironmentWithPolicy instance.
    :param soft_lqr: SoftLQR instance with a known baseline v*(t,x).
    :param num_episodes: how many episodes to train.
    :param dt: time step used in the environment (and for log_prob * dt).
    :param tau: the relaxation parameter weighting the entropy term.
    :param actor_lr: learning rate for the policy network.
    """
    device = torch.device("cpu")
    policy_net = GaussianPolicyNN(device=device).to(device)
    optimizer = optim.Adam(policy_net.parameters(), lr=actor_lr)

    losses = []  # track the "actor loss" each episode

    for episode in range(num_episodes):
        # Sample initial state from the environment
        X0 = env.sample_initial_state().to(device)
        X = X0.clone()

        # Lists to store trajectory info
        states, actions, costs, log_probs, times = [], [], [], [], []

        # 1) Roll out one full episode
        for n in range(env.N):
            t_n = n * dt
            states.append(X)
            times.append(t_n)

            # sample action from current policy
            a_n = policy_net.sample_action(t=torch.tensor(t_n, dtype=torch.float32, device=device), x=X)
            actions.append(a_n)

            # step environment
            X_next, cost_n = env.step(a_n)
            costs.append(cost_n)

            # store log_prob for the gradient update
            lp = policy_net.log_prob(t=torch.tensor(t_n, dtype=torch.float32, device=device),
                                     x=states[-1], action=a_n)
            log_probs.append(lp)

            X = X_next.clone()

        # terminal cost g_T
        g_T = env.observe_terminal_cost()

        # 2) Compute the policy gradient objective => "actor_loss"
        actor_loss = torch.tensor(0.0, device=device)

        for n in range(env.N):
            t_n = times[n]
            X_n = states[n]

            # baseline advantage uses V_next - V_n
            if n < env.N - 1:
                t_next = times[n+1]
                X_next = states[n+1]
                V_next = soft_lqr.calculate_control_problem_value(t_next, X_next).to(device)
            else:
                V_next = torch.tensor(0.0, device=device)

            V_n = soft_lqr.calculate_control_problem_value(t_n, X_n).to(device)

            # advantage = cost + tau*log_prob*dt + (V_next - V_n)
            adv = costs[n] + tau * log_probs[n] * dt + (V_next - V_n)

            # if last step, add terminal cost
            if n == env.N - 1:
                adv += g_T

            # Summation of - log_prob * advantage => negative because we want to maximize
            actor_loss = actor_loss - log_probs[n] * adv

        # 3) One gradient update
        optimizer.zero_grad()
        actor_loss.backward()
        optimizer.step()

        losses.append(actor_loss.item())

        # Print progress every 50 episodes
        if episode % 50 == 0:
            print(f"Episode {episode}, Actor Loss: {actor_loss.item()}")

    return policy_net

###############################################################################
# 5) Rollout and evaluation
###############################################################################
def rollout_cost_from_t_x(H, M, C, D, R, sigma, dt, t_start, T_final, x_init, policy_net, steps_noise_seed=None):
    """
    Simulate from time t_start to T_final with Euler-Maruyama steps, using the policy_net
    to choose actions. Accumulate cost (xᵀCx + aᵀDa)*dt, plus terminal cost xᵀ R x.
    Returns the total cost from that sub-trajectory.
    """
    if steps_noise_seed is not None:
        torch.manual_seed(steps_noise_seed)
    n_steps = int((T_final - t_start) / dt)
    x = x_init.clone()
    total_cost = 0.0

    for n in range(n_steps):
        t_n = t_start + n*dt
        a = policy_net.sample_action(t=torch.tensor(t_n, dtype=torch.float32), x=x)

        # cost = (xᵀCx + aᵀDa)*dt
        state_cost = (x.unsqueeze(0) @ C @ x.unsqueeze(1)).item()
        action_cost = (a.unsqueeze(0) @ D @ a.unsqueeze(1)).item()
        total_cost += (state_cost + action_cost) * dt

        # next state
        noise = torch.randn((2,), dtype=torch.float32) * math.sqrt(dt)
        x = x + (H @ x + M @ a)*dt + sigma @ noise

    # terminal cost
    terminal_cost = (x.unsqueeze(0) @ R @ x.unsqueeze(1)).item()
    total_cost += terminal_cost
    return total_cost

###############################################################################
# 6) Compare Learned vs. Optimal Controls
###############################################################################
def optimal_control_from_softlqr(soft_lqr, t, x, tau=0.2, gamma=1.0):
    """
    Computes the optimal control distribution from the SoftLQR solution:
      mean_opt = - Σ Mᵀ S(t) x,
      Σ_opt = (D + (τ/(2γ²))I)⁻¹,
      cov_opt = τ * (Σ_opt)⁻¹.
    """
    S_t = soft_lqr.get_S_at_time(t).numpy()
    m = soft_lqr.D.shape[0]

    # Σ_opt
    Sigma = np.linalg.inv(soft_lqr.D + (tau/(2*gamma**2)) * np.eye(m))

    # mean = - Σ_opt Mᵀ S(t) x
    mean_opt = - Sigma @ soft_lqr.M.T @ S_t @ x.numpy()

    # covariance = τ Σ_opt⁻¹
    cov_opt = tau * np.linalg.inv(Sigma)
    return mean_opt, cov_opt

def compare_controls(policy_net, soft_lqr,
                     times=[0.0, 1/6, 2/6, 0.5],
                     x_min=-3.0, x_max=3.0, grid_size=7,
                     tau=0.5, gamma=1.0):
    """
    Compare the learned policy's mean & covariance with the theoretical optimal control
    at a grid of (t,x). We measure the average Euclidean error for means and average
    Frobenius norm error for covariances, and print them.
    """
    x_vals = np.linspace(x_min, x_max, grid_size)
    mean_errors = []
    cov_errors = []

    for t_val in times:
        for x1 in x_vals:
            for x2 in x_vals:
                x_tensor = torch.tensor([x1, x2], dtype=torch.float32)

                # Learned distribution
                dist_learned = policy_net.get_action_distribution(torch.tensor(t_val, dtype=torch.float32), x_tensor)
                learned_mean = dist_learned.mean.detach().numpy()
                learned_cov = dist_learned.covariance_matrix.detach().numpy()

                # Optimal distribution
                optimal_mean, optimal_cov = optimal_control_from_softlqr(soft_lqr, t_val, x_tensor, tau, gamma)

                # Euclidean error for means
                mean_err = np.linalg.norm(learned_mean - optimal_mean)
                # Frobenius norm for covariance error
                cov_err = np.linalg.norm(learned_cov - optimal_cov)

                mean_errors.append(mean_err)
                cov_errors.append(cov_err)

    print("Max mean error:", np.max(mean_errors))
    print("Average mean error:", np.mean(mean_errors))

    return np.mean(mean_errors), np.mean(cov_errors)

###############################################################################
# 7) Main Script
###############################################################################

# 1) Set up LQR problem data.
H = torch.tensor([[1.0, 1.0],
                    [0.0, 1.0]], dtype=torch.float32) * 0.5
M = torch.tensor([[1.0, 0.0],
                    [0.0, 1.0]], dtype=torch.float32)
C = torch.tensor([[1.0, 0.1],
                    [0.1, 1.0]], dtype=torch.float32)
D = torch.tensor([[1.0, 0.0],
                    [0.0, 1.0]], dtype=torch.float32)
R = torch.tensor([[1.0, 0.3],
                    [0.3, 1.0]], dtype=torch.float32) * 10.0
sigma = torch.eye(2, dtype=torch.float32) * 0.5

# Final time T = 0.5, discretized into 100 points for the Riccati solution.
T = 0.5
N_points = 100
time_grid = torch.linspace(0, T, N_points)

# 2) Build a SoftLQR instance and solve the Riccati ODE.
soft_lqr = SoftLQR(H, M, C, D, R, sigma, time_grid)
soft_lqr.solve_ricatti_finite(T_max=T, tau=0.5, gamma=1.0)

# 3) Build environment with the same T, dt=0.005
initial_distribution = torch.FloatTensor(2).uniform_(-2.0, 2.0)
env_policy = LQREnvironmentWithPolicy(
    H, M, C, D, R, sigma, gamma=1.0,
    initial_distribution=initial_distribution,
    T=T, dt=0.005
)

# 4) Train the actor-only policy for 300 episodes, with tau=0.5 in the advantage
#    and a very small learning rate actor_lr=1e-6 to keep updates stable.
policy_net = offline_actor_algorithm(
    env_policy, soft_lqr,
    num_episodes=300, dt=0.005,
    tau=0.5, actor_lr=1e-6
)

# 5) Compare the learned control distribution with the optimal control distribution.
avg_mean_err, avg_cov_err = compare_controls(
    policy_net, soft_lqr,
    times=[0.0, 1/6, 2/6, 0.5],
    x_min=-3.0, x_max=3.0, grid_size=7,
    tau=0.5, gamma=1.0
)


  return (self.current_state.T @ self.R @ self.current_state).item()


Episode 0, Actor Loss: -13.504015922546387
Episode 50, Actor Loss: -15.347360610961914
Episode 100, Actor Loss: 7.568218231201172
Episode 150, Actor Loss: 4.637603759765625
Episode 200, Actor Loss: -13.90100383758545
Episode 250, Actor Loss: -9.350479125976562
Max mean error: 44.35244825049869
Average mean error: 16.685590350926848
