In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
import os
import time
import imageio  # For creating videos from frames
import matplotlib.cm as cm
from matplotlib.colors import Normalize


In [None]:
# ----------------------------------------------------
# Simple PID Controller 
# ----------------------------------------------------
class PID:
    """Scalar PID controller with a clamped integral term.

    The controller drives ``measured`` towards ``setpoint``; the error used
    internally is ``setpoint - measured``.
    """

    def __init__(self, Kp, Ki, Kd, setpoint=0.0):
        """Store the three gains and the target value, and zero the memory terms."""
        self.Kp, self.Ki, self.Kd = Kp, Ki, Kd
        self.setpoint = setpoint
        self.integral = 0.0
        self.prev_error = 0.0
        # Largest magnitude the accumulated integral may reach (anti-windup).
        self.windup_limit = 5.0

    def reset(self):
        """Forget the accumulated integral and the previous error."""
        self.integral = self.prev_error = 0.0

    def update(self, measured, dt):
        """Advance the controller by one step and return the control output.

        Args:
            measured: Current value of the controlled quantity.
            dt: Time elapsed since the previous call; when it is not positive
                the derivative term is taken as zero.

        Returns:
            Kp * error + Ki * integral + Kd * derivative.
        """
        err = self.setpoint - measured

        # Integrate, then saturate symmetrically at +/- windup_limit.
        accumulated = self.integral + err * dt
        if abs(accumulated) > self.windup_limit:
            accumulated = np.sign(accumulated) * self.windup_limit
        self.integral = accumulated

        # Backward-difference derivative; guarded against a zero/negative step.
        if dt > 0:
            rate = (err - self.prev_error) / dt
        else:
            rate = 0.0
        self.prev_error = err

        return self.Kp * err + self.Ki * self.integral + self.Kd * rate

In [None]:
# ----------------------------------------------------
# Quaternion Utilities (Includes inverse and angle calculation)
# ----------------------------------------------------
def quat_multiply(q, r):
    """Return the Hamilton product ``q * r`` of two quaternions.

    Both inputs are 4-element sequences in ``[w, x, y, z]`` order; the result
    is a numpy array in the same order.
    """
    qw, qx, qy, qz = q
    rw, rx, ry, rz = r

    scalar = qw*rw - qx*rx - qy*ry - qz*rz
    i_part = qw*rx + qx*rw + qy*rz - qz*ry
    j_part = qw*ry - qx*rz + qy*rw + qz*rx
    k_part = qw*rz + qx*ry - qy*rx + qz*rw

    return np.array([scalar, i_part, j_part, k_part])

def quat_normalize(q):
    """Scale ``q`` to unit length.

    A quaternion whose norm is below 1e-10 cannot be normalized reliably and
    is handed back unchanged (the very same object).
    """
    magnitude = np.linalg.norm(q)
    if magnitude < 1e-10:
        return q
    return q / magnitude

def quat_conjugate(q):
    """Return the conjugate of ``q``: same scalar part, negated vector part.

    ``q`` is a 4-element sequence in ``[w, x, y, z]`` order.
    """
    scalar, vx, vy, vz = q
    return np.array([scalar, -vx, -vy, -vz])

def quat_to_rot_matrix(q):
    """Convert a ``[w, x, y, z]`` quaternion into a 3x3 rotation matrix.

    The input is normalized first, so it does not have to be exactly unit
    length.
    """
    w, x, y, z = quat_normalize(q)

    # Squares and pairwise products shared by several matrix entries.
    xx, yy, zz = x**2, y**2, z**2
    xy, xz, yz = x*y, x*z, y*z
    wx, wy, wz = x*w, y*w, z*w

    return np.array([
        [1-2*(yy+zz), 2*(xy - wz), 2*(xz + wy)],
        [2*(xy + wz), 1-2*(xx+zz), 2*(yz - wx)],
        [2*(xz - wy), 2*(yz + wx), 1-2*(xx+yy)],
    ])

def quat_from_axis_angle(axis, angle):
    """Build a unit quaternion ``[w, x, y, z]`` from an axis and an angle.

    Args:
        axis: 3-element rotation axis; it does not need to be unit length.
        angle: Rotation angle in radians.

    Returns:
        np.ndarray of shape (4,). A zero-length axis defines no rotation
        direction, so the identity quaternion is returned for it instead of
        the all-NaN result that dividing by a zero norm would produce.
    """
    axis = np.asarray(axis, dtype=float)
    axis_norm = np.linalg.norm(axis)
    if axis_norm == 0.0:
        # Degenerate axis: treat as "no rotation" rather than emitting NaNs.
        return np.array([1.0, 0.0, 0.0, 0.0])

    unit_axis = axis / axis_norm
    half_angle = angle / 2.0
    w = np.cos(half_angle)
    x, y, z = unit_axis * np.sin(half_angle)
    # Re-normalize to remove any round-off from the trigonometric terms.
    return quat_normalize(np.array([w, x, y, z]))

def quat_diff_angle(q1, q2):
    """Return the smallest rotation angle, in radians, between two quaternions.

    Both inputs are expected to be unit quaternions in ``[w, x, y, z]`` order.
    ``q`` and ``-q`` describe the same attitude, so the sign of the relative
    quaternion is ignored and the result always lies in ``[0, pi]``.
    """
    # The scalar part of conj(q1) * q2 is exactly the 4-D dot product of q1
    # and q2; taking its magnitude selects the shorter of the two arcs.
    cos_half = abs(np.dot(q1, q2))
    # Guard arccos against values marginally above 1 caused by round-off.
    cos_half = np.clip(cos_half, 0.0, 1.0)
    return 2.0 * np.arccos(cos_half)

def quat_distance(q, q_ref):
    """Return a normalized attitude error in ``[0, 1]``.

    0 means the two quaternions describe the same attitude and 1 means they
    are 180 degrees apart.
    """
    angle = quat_diff_angle(q, q_ref)
    # q and -q are the same attitude: an angle beyond pi is really the
    # complementary rotation 2*pi - angle. Folding it here keeps the result
    # inside [0, 1] whatever arc the helper reports.
    if angle > np.pi:
        angle = 2.0 * np.pi - angle
    return angle / np.pi
    # An earlier metric, 1 - dot(q, q_ref)**2, was dropped because it scores a
    # 180 degree rotation the same as no rotation at all.

def euler_to_quat(roll, pitch, yaw):
    """Convert roll, pitch and yaw angles (radians) to a ``[w, x, y, z]`` quaternion.

    The result is passed through ``quat_normalize`` before being returned.
    """
    half_roll, half_pitch, half_yaw = roll * 0.5, pitch * 0.5, yaw * 0.5

    cr, sr = np.cos(half_roll), np.sin(half_roll)
    cp, sp = np.cos(half_pitch), np.sin(half_pitch)
    cy, sy = np.cos(half_yaw), np.sin(half_yaw)

    components = [
        cr * cp * cy + sr * sp * sy,  # w
        sr * cp * cy - cr * sp * sy,  # x
        cr * sp * cy + sr * cp * sy,  # y
        cr * cp * sy - sr * sp * cy,  # z
    ]
    return quat_normalize(np.array(components))


In [None]:
# ----------------------------------------------------
# 3D Quadrotor Dynamics with Damping
# ----------------------------------------------------
class QuadrotorDynamics3D:
    """Simplified 3-D quadrotor model with linear and angular damping.

    State (13,):   [x, y, z,  qw, qx, qy, qz,  vx, vy, vz,  wx, wy, wz]
    Control (4,):  [total thrust, desired wx, desired wy, desired wz]

    The angular velocity follows its command through a first-order lag
    instead of a torque/inertia model.
    """

    def __init__(self, mass=1.21, arm_length=0.15):
        """Set the physical parameters, actuator limits and drawing sizes."""
        self.m = mass

        # Diagonal inertia tensor [kg*m^2]: Ixx, Iyy, Izz.
        principal_inertias = [0.00706, 0.00706, 0.0136]
        self.J = np.diag(principal_inertias)
        self.J_inv = np.linalg.inv(self.J)  # precomputed inverse

        self.g = np.array([0, 0, -9.81])
        self.Tmin = 0.3       # minimum total thrust [N]
        self.Tmax = 19.0      # maximum total thrust [N]
        self.tau_omega = 0.1  # time constant of the angular-rate response

        # Damping coefficients
        self.linear_damping = 0.1    # opposes linear velocity
        self.angular_damping = 0.15  # opposes angular velocity

        # Command limits
        self.max_omega_rate = 5.0    # rad/s per axis
        self.max_thrust_rate = 50.0  # N/s; stored but not enforced by this class

        # Geometry consumed by the external visualizer
        self.arm_length = arm_length
        self.prop_radius = 0.05

    @staticmethod
    def _cap_magnitude(vec, limit):
        """Return ``vec`` rescaled to length ``limit`` when it is longer than that."""
        length = np.linalg.norm(vec)
        if length > limit:
            return vec * (limit / length)
        return vec

    def state_derivative(self, state, control):
        """Return d(state)/dt for the given state and (already limited) control."""
        attitude = state[3:7]     # quaternion [w, x, y, z]
        lin_vel = state[7:10]
        body_rate = state[10:13]

        thrust_cmd = control[0]
        rate_cmd = control[1:4]

        # --- Translation: thrust acts along the body z-axis ---
        body_to_world = quat_to_rot_matrix(attitude)
        thrust_world = body_to_world @ np.array([0, 0, thrust_cmd])
        drag = -self.linear_damping * lin_vel
        net_force = thrust_world + self.m * self.g + drag

        # --- Rotation: quaternion kinematics q_dot = 0.5 * q * [0, omega] ---
        rate_quat = np.concatenate(([0.0], body_rate))
        attitude_dot = 0.5 * quat_multiply(attitude, rate_quat)

        # First-order tracking of the commanded rate, plus a simplified
        # damping term added directly to the angular acceleration.
        omega_dot = (rate_cmd - body_rate) / self.tau_omega
        omega_dot += -self.angular_damping * body_rate

        deriv = np.zeros(13)
        deriv[:3] = lin_vel
        deriv[3:7] = attitude_dot
        deriv[7:10] = net_force / self.m
        deriv[10:] = omega_dot
        return deriv

    def rk4_integrate(self, state, control, dt):
        """Advance ``state`` by ``dt`` with classic RK4 plus safety post-processing.

        A non-finite input state is returned as a copy with only its
        quaternion repaired. Any exception or non-finite result makes the
        method fall back to returning the input state.
        """
        if not np.all(np.isfinite(state)):
            print("Warning: NaN or Inf state detected before RK4 step. Returning current state.")
            salvaged = state.copy()
            quat = salvaged[3:7]
            quat_usable = np.all(np.isfinite(quat)) and not (np.linalg.norm(quat) < 1e-6)
            if quat_usable:
                salvaged[3:7] = quat_normalize(quat)
            else:
                # Unusable orientation: fall back to the identity attitude.
                salvaged[3:7] = np.array([1.0, 0.0, 0.0, 0.0])
            return salvaged

        try:
            k1 = self.state_derivative(state, control)
            k2 = self.state_derivative(state + 0.5*dt*k1, control)
            k3 = self.state_derivative(state + 0.5*dt*k2, control)
            k4 = self.state_derivative(state + dt*k3, control)

            if all(np.all(np.isfinite(k)) for k in (k1, k2, k3, k4)):
                next_state = state + (dt/6.0)*(k1 + 2*k2 + 2*k3 + k4)
            else:
                print("Warning: NaN or Inf detected during RK4 computation. Using Euler step.")
                next_state = state + dt * k1

            # Keep the attitude a unit quaternion.
            next_state[3:7] = quat_normalize(next_state[3:7])

            # Safety caps: 30 m/s on linear speed, 8 rad/s on angular speed.
            next_state[7:10] = self._cap_magnitude(next_state[7:10], 30.0)
            next_state[10:13] = self._cap_magnitude(next_state[10:13], 8.0)

            if not np.all(np.isfinite(next_state)):
                print("ERROR: NaN or Inf state detected AFTER RK4 step. Returning previous state.")
                return state

            return next_state

        except Exception as e:
            print(f"ERROR in RK4 integration: {e}")
            print(f"State: {state}")
            print(f"Control: {control}")
            # Best effort: hand back the last known state.
            return state

    def dynamics(self, state, control, dt):
        """Clip the command to the actuator limits and integrate one step.

        Returns:
            (next_state, applied_control), where ``applied_control`` is the
            clipped command that was actually integrated.
        """
        thrust = np.clip(control[0], self.Tmin, self.Tmax)
        body_rates = np.clip(control[1:4], -self.max_omega_rate, self.max_omega_rate)
        # Note: no thrust slew-rate limit is applied here (max_thrust_rate is unused).

        applied = np.concatenate(([thrust], body_rates))
        return self.rk4_integrate(state, applied, dt), applied



In [None]:
# ----------------------------------------------------
# MPPI Controller (Adjusted Costs and Logic)
# ----------------------------------------------------
class MPPIController3D:
    """Sampling-based MPPI controller for the 13-state quadrotor model.

    Each call to ``get_action`` perturbs the nominal control sequence with
    Gaussian noise, rolls every perturbed sequence through ``quad_model``,
    scores the rollouts, and updates the nominal sequence with a
    softmax-weighted average of the perturbations. A PID term then trims the
    thrust command towards the reference altitude.
    """

    def __init__(self, quad_model, num_samples=200, horizon=20, dt=0.05, lambda_=0.05, name="MPPI", traj_type="Circle"):
        """Configure sampling, cost weights and the altitude PID.

        Args:
            quad_model: Dynamics model; must provide ``m``, ``Tmin``, ``Tmax``,
                ``max_omega_rate`` and ``dynamics(state, control, dt)``.
            num_samples: Number of rollouts sampled per control step.
            horizon: Number of planning steps per rollout.
            dt: Time step used inside the rollouts (and by the altitude PID).
            lambda_: Temperature of the softmax weighting.
            name: Label for this controller instance.
            traj_type: "Circle" or "Figure8" select dedicated tunings; any
                other value uses the "TiltedCircle" tuning.
        """
        self.model = quad_model
        self.name = name
        self.num_samples = num_samples
        self.horizon = horizon # Planning horizon steps
        self.dt = dt         # Timestep used for simulation WITHIN MPPI rollouts
        self.lambda_ = lambda_ # Temperature parameter for weighting
        self.control_dim = 4 # [Total Thrust, omega_x_d, omega_y_d, omega_z_d]
        self.state_dim = 13  # [x,y,z, w,qx,qy,qz, vx,vy,vz, om_x,om_y,om_z]
        self.traj_type = traj_type

        # Initial nominal control sequence (hover)
        self.nominal_u = np.tile(np.array([quad_model.m * 9.81, 0, 0, 0]), (self.horizon, 1))

        # Noise standard deviation for exploration - trajectory specific
        if traj_type == "Circle":
            # Circle needs more aggressive exploration for tight turns
            self.noise_std = np.array([2.8, 0.7, 0.7, 0.5])
        elif traj_type == "Figure8":
            # Figure8 needs more refined control for complex transitions
            self.noise_std = np.array([2.5, 0.6, 0.6, 0.4])
        else: # TiltedCircle
            # Standard noise parameters
            self.noise_std = np.array([2.5, 0.6, 0.6, 0.4])

        # --- Cost Function Weights ---
        # State tracking costs (make position dominant, velocity secondary)
        if traj_type == "Circle":
            # For Circle trajectory - focus on smooth circular motion with enhanced position tracking
            self.c_ref_p = 250.0     # Position tracking (High priority for circular path)
            self.c_ref_q = 10.0      # Orientation tracking (Moderate)
            self.c_ref_v = 25.0      # Velocity tracking (Moderate)
            self.c_ref_omega = 4.0   # Angular velocity tracking (Lower)
            self.R = np.diag([0.003, 0.04, 0.04, 0.08]) # Control magnitude cost
            self.R_delta = np.diag([0.008, 0.08, 0.08, 0.15]) # Control smoothness cost
            self.c_terminal_weight = 2.0 # Higher terminal weight for loop closure

            # Progressive reward structure for circle
            self.reward_thresholds = [0.2, 0.5, 1.0, 1.5, 2.0]  # Tighter thresholds
            self.reward_values = [100.0, 50.0, 25.0, 10.0, 5.0] # Higher rewards
            self.heading_reward_weight = 15.0  # Stronger heading alignment

        elif traj_type == "Figure8":
            # For Figure8 trajectory - focus on smooth transitions with proper velocity tracking
            self.c_ref_p = 200.0     # Position tracking (High, but allowing some flexibility)
            self.c_ref_q = 8.0       # Orientation tracking (Reduced for smoother turns)
            self.c_ref_v = 30.0      # Velocity tracking (Higher - important for figure 8)
            self.c_ref_omega = 3.0   # Angular velocity tracking (Lower)
            self.R = np.diag([0.002, 0.03, 0.03, 0.06]) # Control magnitude cost (reduced)
            self.R_delta = np.diag([0.006, 0.06, 0.06, 0.12]) # Control smoothness cost
            self.c_terminal_weight = 1.8 # Moderate terminal weight

            # Progressive reward structure for figure 8
            self.reward_thresholds = [0.2, 0.5, 1.0, 1.5, 2.0]
            self.reward_values = [80.0, 40.0, 20.0, 10.0, 5.0]
            self.heading_reward_weight = 20.0  # Higher heading reward for figure 8

        else:  # TiltedCircle
            # For TiltedCircle - balance altitude control with position tracking
            self.c_ref_p = 150.0     # Position tracking (High)
            self.c_ref_q = 15.0      # Orientation tracking (Moderate)
            self.c_ref_v = 30.0      # Velocity tracking (Moderate)
            self.c_ref_omega = 5.0   # Angular velocity tracking (Slightly higher for stability)
            self.R = np.diag([0.005, 0.05, 0.05, 0.1]) # Control magnitude cost
            self.R_delta = np.diag([0.01, 0.1, 0.1, 0.2]) # Control smoothness cost
            self.c_terminal_weight = 1.5 # Standard terminal weight

            # Progressive reward structure for tilted circle
            self.reward_thresholds = [0.1, 0.3, 0.5, 1.0, 1.5]
            self.reward_values = [50.0, 25.0, 15.0, 8.0, 4.0]
            self.heading_reward_weight = 10.0  # Standard heading reward

        # Control effort costs
        # Penalize deviation from hover thrust, and any angular rates
        hover_thrust = self.model.m * 9.81
        self.u_ref = np.array([hover_thrust, 0, 0, 0]) # Reference control (hover)

        # Enable progressive rewards for better loop closure
        self.use_progressive_reward = True

        # PID used to trim thrust towards the reference altitude (setpoint stays 0;
        # get_action feeds it the altitude offset).
        self.pid_alt = PID(Kp=0.8, Ki=0.2, Kd=0.1)

        # Distance threshold for the legacy (non-progressive) reward
        self.distance_threshold = 1.0

        # Visualization data storage
        self.rollouts = None
        self.rollout_costs = None
        self.best_rollout_idx = -1
        self.last_best_cost = float('inf')

    def reset(self):
        """Resets the nominal sequence and internal states."""
        self.nominal_u = np.tile(np.array([self.model.m * 9.81, 0, 0, 0]), (self.horizon, 1))
        self.pid_alt.reset()  # Reset PID controller
        self.last_best_cost = float('inf')
        self.rollouts = None
        self.rollout_costs = None
        self.best_rollout_idx = -1

    def compute_stage_cost(self, state, u, ref_state, is_terminal=False, time_weight=1.0):
        """Computes the cost for a single stage (state and control).

        Args:
            state: Simulated state (13,).
            u: Control applied to reach ``state`` (4,).
            ref_state: Reference state (13,) the simulated state is compared to.
            is_terminal: When True the weighted cost is additionally scaled by
                ``c_terminal_weight``.
            time_weight: Multiplier applied to the tracking + control cost.

        Returns:
            The stage cost after rewards have been subtracted, floored at 0.1.
        """
        # --- Tracking Cost ---
        pos_error_vec = state[0:3] - ref_state[0:3]
        pos_error_sq = np.sum(pos_error_vec**2)
        pos_error_norm = np.sqrt(pos_error_sq)

        orient_error = quat_distance(state[3:7], ref_state[3:7])  # normalized attitude error
        vel_error_sq = np.sum((state[7:10] - ref_state[7:10])**2)
        omega_error_sq = np.sum((state[10:13] - ref_state[10:13])**2)

        tracking_cost = (self.c_ref_p * pos_error_sq +
                         self.c_ref_q * orient_error +
                         self.c_ref_v * vel_error_sq +
                         self.c_ref_omega * omega_error_sq)

        # --- Control Cost ---
        u_diff = u - self.u_ref  # Difference from hover control
        control_mag_cost = u_diff.T @ self.R @ u_diff

        # Combine costs
        stage_cost = tracking_cost + control_mag_cost

        # Apply optional time weighting (prioritize near-term costs)
        stage_cost *= time_weight

        # Apply terminal weight multiplier
        if is_terminal:
            stage_cost *= self.c_terminal_weight

        # Progressive reward (Subtract from cost): only the first (tightest)
        # threshold that the position error satisfies pays out.
        if self.use_progressive_reward:
            for threshold, reward in zip(self.reward_thresholds, self.reward_values):
                if pos_error_norm < threshold:
                    stage_cost -= reward
                    break
        else:
            # Legacy threshold-based reward
            if pos_error_norm < self.distance_threshold:
                stage_cost -= 50.0  # Simple threshold-based reward

        # Heading alignment reward (Subtract from cost); skipped when either
        # velocity is too small for its direction to be meaningful.
        v_norm = np.linalg.norm(state[7:10])
        ref_v_norm = np.linalg.norm(ref_state[7:10])
        if v_norm > 0.1 and ref_v_norm > 0.1:
            vel_dot = np.dot(state[7:10], ref_state[7:10])
            cos_angle = np.clip(vel_dot / (v_norm * ref_v_norm), -1.0, 1.0)
            # Reward alignment (cos_angle=1 -> max reward, cos_angle=-1 -> min reward)
            alignment_reward = self.heading_reward_weight * (cos_angle + 1.0) / 2.0
            stage_cost -= alignment_reward

        # Floor the cost at a small positive value so rewards cannot make it
        # zero or negative.
        return max(stage_cost, 0.1)

    def compute_trajectory_cost(self, states, controls, ref_traj):
        """Computes the total cost for a rollout trajectory.

        Args:
            states: Rollout states, shape (H + 1, state_dim); ``states[0]`` is
                the start state and is not scored.
            controls: Applied controls, shape (H, control_dim).
            ref_traj: Reference states; ``ref_traj[j]`` is compared against
                ``states[j + 1]``.

        Returns:
            Sum of the stage costs plus the control-smoothness cost.
        """
        total_cost = 0.0
        horizon = len(controls) # Should match self.horizon

        # Add stage costs (with time weighting)
        for j in range(horizon):
            is_terminal = (j == horizon - 1)
            # Apply exponential time weighting to prioritize near-term tracking
            time_weight = np.exp(-0.05 * j) if self.traj_type in ["Circle", "Figure8"] else 1.0
            stage_cost = self.compute_stage_cost(states[j+1], controls[j], ref_traj[j],
                                               is_terminal=is_terminal, time_weight=time_weight)
            total_cost += stage_cost

        # Add control smoothness cost (delta_u cost)
        for j in range(horizon - 1):
            delta_u = controls[j+1] - controls[j]
            smoothness_cost = delta_u.T @ self.R_delta @ delta_u
            total_cost += smoothness_cost

        return total_cost

    def get_action(self, current_state, ref_traj, visualize=False):
        """
        Computes the optimal control action using MPPI.

        Args:
            current_state (np.ndarray): Current state of the quadrotor.
            ref_traj (np.ndarray): Reference trajectory for the horizon. Shape (H, state_dim).
                When it has fewer rows than the horizon, its last row is repeated.
            visualize (bool): Whether to store rollouts for visualization.

        Returns:
            np.ndarray: Optimal control action for the current step. Shape (control_dim,).
        """
        K = self.num_samples
        N = self.horizon

        # Ensure reference trajectory has sufficient length
        if ref_traj.shape[0] < N:
            last_ref = ref_traj[-1]
            padding = np.tile(last_ref, (N - ref_traj.shape[0], 1))
            ref_traj = np.vstack([ref_traj, padding])
        ref_states = ref_traj[:N] # Reference states for the horizon

        # Generate random perturbations (noise)
        noise = np.random.normal(loc=0.0, scale=self.noise_std, size=(K, N, self.control_dim))

        # Storage for simulated rollouts and their costs
        rollout_costs = np.zeros(K)
        if visualize:
            self.rollouts = np.zeros((K, N + 1, self.state_dim))
            self.rollouts[:, 0, :] = current_state # Start all rollouts from current state
            self.rollout_costs = np.zeros(K) # Store costs for viz

        # Simulate the K rollouts one after another
        best_cost_in_batch = float('inf')
        best_idx_in_batch = -1

        for k in range(K):
            state_k = current_state.copy()
            controls_k = np.zeros((N, self.control_dim))
            states_k = np.zeros((N + 1, self.state_dim))
            states_k[0] = state_k
            valid_rollout = True

            # Simulate one rollout
            for j in range(N):
                # Calculate perturbed control
                perturbed_u = self.nominal_u[j] + noise[k, j]

                # Apply dynamics to get next state
                state_k, clipped_u = self.model.dynamics(state_k, perturbed_u, self.dt)
                controls_k[j] = clipped_u # Store the actually applied control
                states_k[j+1] = state_k

                # Store for visualization
                if visualize:
                    self.rollouts[k, j+1, :] = state_k

                # Check for invalid state (e.g., NaN from integration failure)
                if np.any(np.isnan(state_k)):
                    valid_rollout = False
                    break # Stop simulating this rollout

            # Calculate cost for the completed rollout
            if valid_rollout:
                cost_k = self.compute_trajectory_cost(states_k, controls_k, ref_states)
            else:
                cost_k = 1e9 # Assign very high cost to invalid rollouts

            rollout_costs[k] = cost_k

            # Track best rollout within this batch
            if cost_k < best_cost_in_batch:
                best_cost_in_batch = cost_k
                best_idx_in_batch = k

            if visualize:
                self.rollout_costs[k] = cost_k

        # Store best rollout info for visualization
        self.best_rollout_idx = best_idx_in_batch
        self.last_best_cost = best_cost_in_batch if best_idx_in_batch >= 0 else self.last_best_cost

        # --- Compute Weights and Update Nominal Control ---
        if np.all(rollout_costs > 1e8): # Check if all rollouts failed
            print("Warning: All MPPI rollouts failed or had excessive cost. Using previous nominal.")
        else:
            # Subtract minimum cost for numerical stability (shift costs to be >= 0)
            min_cost = np.min(rollout_costs)
            costs_shifted = rollout_costs - min_cost

            # Scale by the cost spread so lambda_ acts on a normalized range
            max_cost = np.max(costs_shifted) + 1e-6

            # Compute weights using softmax (lower cost -> higher weight)
            exp_costs = np.exp(-costs_shifted / (self.lambda_ * max_cost + 1e-8))
            weights = exp_costs / (np.sum(exp_costs) + 1e-10) # Normalize weights

            # Update nominal control sequence by weighted average of perturbations
            weighted_noise = np.sum(weights[:, None, None] * noise, axis=0)
            self.nominal_u += weighted_noise

        # --- Get Control Action ---
        # Get the first control action from the updated nominal sequence
        u_optimal = self.nominal_u[0].copy()

        # Apply PID altitude correction to thrust command.
        # The PID computes (setpoint - measured) with setpoint 0, so feeding it
        # the offset (current - reference) yields error = reference - current:
        # thrust is reduced when the vehicle is too high and increased when it
        # is too low. Negating the offset here would turn the loop into
        # positive feedback.
        altitude_offset = current_state[2] - ref_states[0][2]
        pid_corr = self.pid_alt.update(altitude_offset, self.dt)
        u_optimal[0] += pid_corr

        # Clip final command to ensure it's within physical limits
        u_optimal[0] = np.clip(u_optimal[0], self.model.Tmin, self.model.Tmax)
        u_optimal[1:4] = np.clip(u_optimal[1:4], -self.model.max_omega_rate, self.model.max_omega_rate)

        # Shift nominal control sequence for the next time step
        self.nominal_u = np.roll(self.nominal_u, -1, axis=0)
        if N > 1:
            # np.roll wrapped the consumed first control to the end; replace it
            # by repeating the last planned control. With a single-step horizon
            # there is no other row to copy from.
            self.nominal_u[-1] = self.nominal_u[-2]

        return u_optimal


In [None]:
# ----------------------------------------------------
# Reference Trajectory Generation (with Yaw Alignment)
# ----------------------------------------------------
def generate_reference_trajectory_3d(traj_type, duration, dt, scale=3.0, loops=1, constant_alt=3.0):
    """
    Generate a reference trajectory with basic yaw alignment.

    The reference attitude has zero roll and pitch; its yaw points along the
    horizontal velocity (the previous yaw is held while the horizontal speed
    is 0.1 or less). The reference angular velocity is a forward finite
    difference of consecutive reference quaternions, taken along the shortest
    rotation so that the quaternion sign flip at the +/-pi yaw wrap does not
    produce a spurious spike. The last sample has zero angular velocity.

    Args:
        traj_type (str): 'Circle', 'Figure8', 'TiltedCircle'.
        duration (float): Time duration for ONE loop. Must be positive.
        dt (float): Time step. Must be positive.
        scale (float): Size parameter (e.g., radius).
        loops (int): Number of times to repeat the trajectory.
        constant_alt (float): Base altitude.

    Returns:
        np.ndarray: Reference trajectory (N_total, 13)
                    [x, y, z, qw, qx, qy, qz, vx, vy, vz, wx, wy, wz]

    Raises:
        ValueError: If ``traj_type`` is unknown or ``duration``/``dt`` is not positive.
    """
    if traj_type not in ('Circle', 'Figure8', 'TiltedCircle'):
        raise ValueError(f"Unknown trajectory type: {traj_type}")
    if dt <= 0 or duration <= 0:
        raise ValueError(f"duration and dt must be positive (got duration={duration}, dt={dt})")

    # The small tolerance keeps exact multiples from being truncated by float
    # round-off (e.g. 0.3 / 0.1 evaluates to 2.9999999999999996).
    num_steps_loop = int(duration / dt + 1e-9)
    total_steps = num_steps_loop * loops
    t = np.linspace(0, duration, num_steps_loop, endpoint=False) # Time within one loop

    x, y, z = np.zeros(total_steps), np.zeros(total_steps), np.zeros(total_steps)
    vx, vy, vz = np.zeros(total_steps), np.zeros(total_steps), np.zeros(total_steps)

    omega = 2 * np.pi / duration # Angular frequency of the base pattern
    phase = omega * t

    # Every loop is an identical copy of the single-loop pattern.
    for i in range(loops):
        start_idx = i * num_steps_loop
        end_idx = (i + 1) * num_steps_loop

        if traj_type == 'Circle':
            x[start_idx:end_idx] = scale * np.cos(phase)
            y[start_idx:end_idx] = scale * np.sin(phase)
            z[start_idx:end_idx] = constant_alt

            vx[start_idx:end_idx] = -scale * omega * np.sin(phase)
            vy[start_idx:end_idx] = scale * omega * np.cos(phase)
            vz[start_idx:end_idx] = 0

        elif traj_type == 'Figure8':
            x[start_idx:end_idx] = scale * np.sin(phase)
            y[start_idx:end_idx] = scale * 0.5 * np.sin(2 * phase)
            z[start_idx:end_idx] = constant_alt

            vx[start_idx:end_idx] = scale * omega * np.cos(phase)
            # d/dt of 0.5*sin(2*omega*t): the factors 0.5 and 2 cancel.
            vy[start_idx:end_idx] = scale * omega * np.cos(2 * phase)
            vz[start_idx:end_idx] = 0

        else:  # 'TiltedCircle'
            alt_variation = scale * 0.3 # Amplitude of altitude change
            x[start_idx:end_idx] = scale * np.cos(phase)
            y[start_idx:end_idx] = scale * np.sin(phase)
            z[start_idx:end_idx] = constant_alt + alt_variation * np.sin(phase) # Z varies sinusoidally

            vx[start_idx:end_idx] = -scale * omega * np.sin(phase)
            vy[start_idx:end_idx] = scale * omega * np.cos(phase)
            vz[start_idx:end_idx] = alt_variation * omega * np.cos(phase)

    # --- Desired orientation: yaw faces the horizontal velocity ---
    q_des = np.zeros((total_steps, 4))
    yaw = 0.0 # Default yaw if starting from rest
    for i in range(total_steps):
        if np.linalg.norm([vx[i], vy[i]]) > 0.1: # Only update yaw if moving horizontally
            yaw = np.arctan2(vy[i], vx[i])
        # Zero roll and pitch in the reference (the controller handles banking).
        q_des[i] = euler_to_quat(0.0, 0.0, yaw)

    # --- Desired angular velocity: finite difference on quaternions ---
    omega_des = np.zeros((total_steps, 3))
    for i in range(total_steps - 1):
        # Rotation taking q_des[i] to q_des[i + 1], relative to the current attitude.
        q_diff = quat_multiply(quat_conjugate(q_des[i]), q_des[i + 1])

        # q and -q are the same rotation; pick the representative with a
        # non-negative scalar part so the angle below is the shortest one.
        if q_diff[0] < 0.0:
            q_diff = -q_diff

        w_diff = min(q_diff[0], 1.0) # Keep arccos in range despite round-off

        if w_diff < 1.0 - 1e-7: # Skip (near-)identical quaternions: rate stays zero
            angle = 2.0 * np.arccos(w_diff)
            axis = q_diff[1:] / np.sin(angle / 2.0)
            # Angular velocity vector: axis * angle_rate
            omega_des[i] = axis * (angle / dt)

    # Assemble complete reference trajectory
    ref = np.zeros((total_steps, 13))
    ref[:, 0] = x
    ref[:, 1] = y
    ref[:, 2] = z
    ref[:, 3:7] = q_des
    ref[:, 7] = vx
    ref[:, 8] = vy
    ref[:, 9] = vz
    ref[:, 10:13] = omega_des

    return ref


In [None]:
# ----------------------------------------------------
# UAV Visualization Functions
# ----------------------------------------------------
def draw_quadrotor_realistic(ax, position, quaternion, scale=1.0, alpha=1.0, arm_length=0.3, color='silver'):
    """Draw a stylised quadrotor (body sphere, arms, propellers, body axes) on a 3D axis.

    Args:
        ax: 3D axes-like object; only ``plot`` and ``plot_surface`` are called on it.
        position: Length-3 world position of the vehicle centre.
        quaternion: Orientation handed to ``quat_to_rot_matrix``; the returned matrix
            is used to rotate body-frame points into the world frame.
        scale: Multiplier applied to every model dimension (body, arms, propellers,
            body axes) and to the arm/propeller line widths.
        alpha: Base opacity; the body sphere and body axes use a fraction of it.
        arm_length: Unscaled arm length (centre to propeller hub).
        color: Colour of the central body sphere.

    Returns:
        The ``ax`` object that was passed in.
    """
    position = np.asarray(position, dtype=float)
    R = quat_to_rot_matrix(quaternion)
    body_radius = 0.1 * scale
    arm_length_scaled = arm_length * scale
    propeller_radius = 0.08 * scale

    def to_world(points_body):
        # Row-vector form of (R @ p + position) applied to every row p.
        return position + points_body @ R.T

    # --- Body: a coarse sphere centred on the vehicle (rotation-invariant) ---
    u = np.linspace(0, 2 * np.pi, 8)
    v = np.linspace(0, np.pi, 8)
    x_s = position[0] + body_radius * np.outer(np.cos(u), np.sin(v))
    y_s = position[1] + body_radius * np.outer(np.sin(u), np.sin(v))
    z_s = position[2] + body_radius * np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(x_s, y_s, z_s, color=color, alpha=alpha*0.6, linewidth=0, antialiased=False, shade=True)

    # --- Arms: two straight bars along body X and body Y ---
    # Drawn at 90% of the arm length so they end just short of the propeller hubs.
    arm_reach = 0.9 * arm_length_scaled
    arm_tips_world = to_world(np.array([
        [arm_reach, 0.0, 0.0], [-arm_reach, 0.0, 0.0],   # X arm end points
        [0.0, arm_reach, 0.0], [0.0, -arm_reach, 0.0],   # Y arm end points
    ]))
    for tip_a, tip_b in (arm_tips_world[0:2], arm_tips_world[2:4]):
        ax.plot([tip_a[0], tip_b[0]], [tip_a[1], tip_b[1]], [tip_a[2], tip_b[2]],
                color='dimgray', linewidth=3*scale, alpha=alpha)

    # --- Propellers: flat circles in the body XY plane, lifted slightly above the arms ---
    hub_lift = 0.02 * scale
    prop_bases_world = to_world(np.array([
        [arm_length_scaled, 0.0, hub_lift],    # Front (+X), drawn in red
        [-arm_length_scaled, 0.0, hub_lift],   # Back  (-X)
        [0.0, arm_length_scaled, hub_lift],    # +Y
        [0.0, -arm_length_scaled, hub_lift],   # -Y
    ]))
    prop_colors = ['red', 'black', 'black', 'black']
    theta = np.linspace(0, 2*np.pi, 10)
    prop_circle_body = np.column_stack([
        propeller_radius * np.cos(theta),
        propeller_radius * np.sin(theta),
        np.zeros_like(theta),
    ])
    # The circle is only rotated here; the hub position already includes `position`.
    prop_circle_offsets = prop_circle_body @ R.T
    for hub_world, prop_color in zip(prop_bases_world, prop_colors):
        ring = hub_world + prop_circle_offsets
        ax.plot(ring[:, 0], ring[:, 1], ring[:, 2], color=prop_color, linewidth=2*scale, alpha=alpha)

    # --- Body-frame triad (X red, Y green, Z blue) ---
    # Uses the *scaled* arm length so the triad shrinks/grows with the rest of the model.
    axis_length = 0.8 * arm_length_scaled
    axes_world = to_world(axis_length * np.eye(3))
    for axis_tip, axis_color in zip(axes_world, ('red', 'green', 'blue')):
        ax.plot([position[0], axis_tip[0]],
                [position[1], axis_tip[1]],
                [position[2], axis_tip[2]],
                color=axis_color, linewidth=2, alpha=alpha * 0.9)

    return ax


In [None]:
# ----------------------------------------------------
# Enhanced MPPI Visualization Class (Frame-Based)
# ----------------------------------------------------
class EnhancedMPPIVisualizer:
    """Frame-based visualizer for an MPPI-controlled UAV simulation.

    Each rendered frame is a two-panel figure: the left panel shows the
    reference path, a fading trail of past states and the current vehicle pose;
    the right panel shows the sampled MPPI rollouts coloured by cost. Frames are
    written as PNG files and can be stitched into a video with
    ``create_animation``.
    """

    def __init__(self, output_dir='mppi_visualization_output', uav_arm_length=0.15, uav_scale=1.0):
        """Create output directories and set rendering defaults.

        Args:
            output_dir: Directory (created if missing) that receives the
                ``combined_frames`` and ``animations`` sub-directories.
            uav_arm_length: Arm length forwarded to ``draw_quadrotor_realistic``.
            uav_scale: Scale factor forwarded to ``draw_quadrotor_realistic``.
        """
        self.output_dir = output_dir
        self.uav_arm_length = uav_arm_length
        self.uav_scale = uav_scale

        # Create output directories
        self.base_dir = os.path.abspath(output_dir)
        self.combined_frames_dir = os.path.join(self.base_dir, 'combined_frames')
        self.animations_dir = os.path.join(self.base_dir, 'animations')
        os.makedirs(self.combined_frames_dir, exist_ok=True)
        os.makedirs(self.animations_dir, exist_ok=True)

        # Visualization settings
        self.dpi = 120                      # Lower DPI for faster frame saving
        self.figsize_combined = (16, 8)     # Width, Height
        self.frame_interval = 2             # Save every Nth frame
        self.show_top_n_rollouts = 50       # Maximum number of rollouts drawn per frame
        self.trajectory_trail_length = 150  # Maximum number of past states in the trail

        # Paths of the frames written so far, in render order
        self.combined_frames = []

        # Styling. The seaborn style sheets were renamed across Matplotlib
        # releases, so pick whichever name this installation provides instead
        # of failing with OSError; if neither exists the current style is kept.
        for style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
            if style_name in plt.style.available:
                plt.style.use(style_name)  # NOTE: changes the global Matplotlib style
                break
        self.background_color = '#f0f0f0'
        self.reference_color = 'orangered'
        self.trajectory_color = 'royalblue'
        self.best_rollout_color = 'limegreen'
        self.rollout_cmap = cm.plasma  # Colormap for rollouts (cost based)

        # Figure/axes handles, created lazily and reused for every frame
        self.combined_fig = None
        self.combined_ax_traj = None
        self.combined_ax_rollout = None

        # Axis limits of the trajectory panel; fixed so frames do not jitter
        self.global_traj_bounds = None

        print(f"Enhanced Visualizer: Output will be saved to {self.base_dir}")

    def setup_combined_figure(self):
        """Create the two-panel 3D figure on first use; later calls are no-ops."""
        if self.combined_fig is not None:
            return
        self.combined_fig = plt.figure(figsize=self.figsize_combined, dpi=self.dpi)
        # The trajectory panel gets more horizontal space than the rollout panel.
        gs = self.combined_fig.add_gridspec(1, 2, width_ratios=[3, 2], wspace=0.3)
        self.combined_ax_traj = self.combined_fig.add_subplot(gs[0, 0], projection='3d')
        self.combined_ax_rollout = self.combined_fig.add_subplot(gs[0, 1], projection='3d')

    def _setup_3d_axis(self, ax, title, bounds):
        """Clear ``ax`` and apply the common title, labels, limits and view angle.

        Args:
            ax: 3D axis to reset.
            title: Title text for the axis.
            bounds: Dict with ``'xlim'``, ``'ylim'`` and ``'zlim'`` (min, max)
                tuples, or a falsy value to leave the limits automatic.

        Returns:
            The same ``ax``.
        """
        ax.clear()
        ax.set_title(title, fontsize=14, pad=15)
        ax.set_xlabel('X [m]', fontsize=10, labelpad=5)
        ax.set_ylabel('Y [m]', fontsize=10, labelpad=5)
        ax.set_zlabel('Z [m]', fontsize=10, labelpad=5)
        ax.tick_params(axis='both', which='major', labelsize=8)
        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

        # Fixed limits keep consecutive animation frames stable.
        if bounds:
            ax.set_xlim(bounds['xlim'])
            ax.set_ylim(bounds['ylim'])
            ax.set_zlim(bounds['zlim'])
            # Box aspect proportional to the data extents gives equal metric
            # scaling on all three axes (requires Matplotlib 3.3+).
            ax.set_box_aspect([bounds['xlim'][1]-bounds['xlim'][0],
                               bounds['ylim'][1]-bounds['ylim'][0],
                               bounds['zlim'][1]-bounds['zlim'][0]])

        # Consistent view angle for every frame
        ax.view_init(elev=25, azim=-110)
        return ax

    def set_global_trajectory_bounds(self, ref_trajectory, margin=2.0):
        """Derive fixed axis limits for the trajectory panel from a reference path.

        Args:
            ref_trajectory: 2D array whose first three columns are x, y, z.
            margin: Padding added around the x/y extents and above z; half of
                it is added below z.
        """
        min_coords = np.min(ref_trajectory[:, 0:3], axis=0)
        max_coords = np.max(ref_trajectory[:, 0:3], axis=0)
        self.global_traj_bounds = {
            'xlim': (min_coords[0] - margin, max_coords[0] + margin),
            'ylim': (min_coords[1] - margin, max_coords[1] + margin),
            # Lower z limit never rises above z=0 so the ground stays visible.
            'zlim': (min(0, min_coords[2]) - margin/2, max_coords[2] + margin)
        }
        print(f"Global trajectory bounds set: {self.global_traj_bounds}")

    @staticmethod
    def _format_cost(cost):
        """Return ``cost`` with two decimals, or ``"N/A"`` if missing or non-finite."""
        if cost is None:
            return "N/A"
        try:
            cost_value = float(cost)
        except (TypeError, ValueError):
            return "N/A"
        return f"{cost_value:.2f}" if np.isfinite(cost_value) else "N/A"

    @staticmethod
    def _rollout_bounds(center, half_range=5.0):
        """Axis limits for the rollout panel, centred on the current position."""
        return {
            'xlim': (center[0] - half_range, center[0] + half_range),
            'ylim': (center[1] - half_range, center[1] + half_range),
            'zlim': (max(0, center[2] - half_range/2), center[2] + half_range)
        }

    def _draw_trajectory_panel(self, current_state, reference_state, past_states, full_ref_traj):
        """Populate the left panel: reference path, trail, vehicle, target and speed."""
        ax = self.combined_ax_traj

        # Speed readout; the m/s -> km/h factor assumes state[7:10] is in m/s.
        vx, vy, vz = current_state[7:10]
        speed_kmh = np.sqrt(vx**2 + vy**2 + vz**2) * 3.6

        # Full reference path, drawn lightly for context
        if full_ref_traj is not None:
            ax.plot(full_ref_traj[:, 0], full_ref_traj[:, 1], full_ref_traj[:, 2],
                    color=self.reference_color, linestyle=':', linewidth=1.0, alpha=0.6, label='Full Reference')

        # Trail of the most recent states; opacity ramps up towards the newest segment
        if past_states is not None and len(past_states) > 1:
            trail = np.array(past_states[-self.trajectory_trail_length:])
            trail_alphas = np.linspace(0.2, 1.0, len(trail))
            for seg in range(len(trail) - 1):
                ax.plot(trail[seg:seg+2, 0], trail[seg:seg+2, 1], trail[seg:seg+2, 2],
                        color=self.trajectory_color, linewidth=2, alpha=trail_alphas[seg])

        # Current vehicle pose
        draw_quadrotor_realistic(ax, current_state[0:3], current_state[3:7],
                                 scale=self.uav_scale, arm_length=self.uav_arm_length, color='blue')

        # Current reference point (target)
        ax.scatter(reference_state[0], reference_state[1], reference_state[2],
                   color='red', s=60, marker='x', label='Current Target', depthshade=False, alpha=0.9)

        # Speed display in the top right corner of the panel
        ax.text2D(0.95, 0.95, f"Speed: {speed_kmh:.1f} km/h", transform=ax.transAxes,
                  horizontalalignment='right', verticalalignment='top',
                  bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))

        ax.legend(fontsize=8, loc='upper left')

    def _draw_rollout_panel(self, current_state, mppi_rollouts, mppi_costs, best_rollout_idx):
        """Populate the right panel: cost-coloured rollouts and the vehicle."""
        ax_rollout = self.combined_ax_rollout

        if mppi_rollouts is not None and mppi_costs is not None and len(mppi_rollouts) > 0:
            costs = np.asarray(mppi_costs, dtype=float)
            num_to_show = min(self.show_top_n_rollouts, mppi_rollouts.shape[0])

            # Colour scale built from finite costs only
            finite_costs = costs[np.isfinite(costs)]
            mapper = None
            if len(finite_costs) > 1:
                min_c, max_c = np.min(finite_costs), np.max(finite_costs)
                if max_c - min_c < 1e-6:
                    # All costs (nearly) identical: widen the range to avoid a zero span
                    norm = Normalize(vmin=min_c - 0.5, vmax=max_c + 0.5)
                else:
                    norm = Normalize(vmin=min_c, vmax=max_c)
                mapper = cm.ScalarMappable(norm=norm, cmap=self.rollout_cmap)

            # Keep the `num_to_show` cheapest rollouts, then iterate from the
            # most expensive of those down to the cheapest so the best ones are
            # issued last (argsort places NaN costs at the end).
            cheapest_first = np.argsort(costs)[:num_to_show]
            for k_idx in cheapest_first[::-1]:
                cost = costs[k_idx]
                if not np.isfinite(cost):
                    continue  # Skip rollouts with invalid cost

                rollout_path = mppi_rollouts[k_idx, :, 0:3]  # Position part of the rollout
                is_best = (k_idx == best_rollout_idx)
                if is_best:
                    line_color = self.best_rollout_color
                elif mapper is not None:
                    line_color = mapper.to_rgba(cost)
                else:
                    line_color = 'gray'

                ax_rollout.plot(rollout_path[:, 0], rollout_path[:, 1], rollout_path[:, 2],
                                color=line_color,
                                alpha=0.9 if is_best else 0.25,
                                linewidth=2.0 if is_best else 0.8,
                                zorder=10 if is_best else 1)

        # Vehicle at the start of the rollouts, drawn slightly smaller than in the left panel
        draw_quadrotor_realistic(ax_rollout, current_state[0:3], current_state[3:7],
                                 scale=self.uav_scale*0.8, arm_length=self.uav_arm_length*0.8,
                                 color='black', alpha=0.9)

    def render_combined_frame(self, t, current_state, reference_state, past_states, full_ref_traj,
                              mppi_rollouts, mppi_costs, best_rollout_idx, last_best_cost, step_num):
        """Render one combined trajectory/rollout frame and save it as a PNG.

        Frames are only produced when ``step_num`` is a multiple of
        ``self.frame_interval``; other calls return immediately.

        Args:
            t: Simulation time shown in the figure title.
            current_state: State vector; ``[0:3]`` is used as position,
                ``[3:7]`` as orientation quaternion and ``[7:10]`` as velocity.
            reference_state: Vector whose first three entries mark the current target.
            past_states: Sequence of earlier states (or ``None``); only the last
                ``self.trajectory_trail_length`` entries are drawn.
            full_ref_traj: Array whose first three columns are the reference
                path, or ``None``. Also used to initialise the global bounds
                when none have been set yet.
            mppi_rollouts: Array indexed as ``[rollout, step, state]`` or ``None``.
            mppi_costs: One cost per rollout, or ``None``.
            best_rollout_idx: Index of the rollout to highlight (may be ``None``).
            last_best_cost: Cost shown in the rollout title; ``None`` or a
                non-finite value is shown as ``N/A``.
            step_num: Simulation step index, used for frame skipping and the file name.
        """
        if step_num % self.frame_interval != 0:
            return  # Skip frame

        if self.global_traj_bounds is None and full_ref_traj is not None:
            self.set_global_trajectory_bounds(full_ref_traj)

        # --- Setup figure and axes ---
        self.setup_combined_figure()
        self._setup_3d_axis(self.combined_ax_traj, "UAV Trajectory", self.global_traj_bounds)
        # The rollout panel uses a tighter window that follows the vehicle.
        best_cost_str = self._format_cost(last_best_cost)
        self._setup_3d_axis(self.combined_ax_rollout,
                            f"MPPI Rollouts (Best Cost: {best_cost_str})",
                            self._rollout_bounds(current_state[0:3]))

        # --- Draw both panels ---
        self._draw_trajectory_panel(current_state, reference_state, past_states, full_ref_traj)
        self._draw_rollout_panel(current_state, mppi_rollouts, mppi_costs, best_rollout_idx)

        # --- Finalize and save ---
        self.combined_fig.suptitle(f'MPPI Simulation | Time: {t:.2f} s | Step: {step_num}', fontsize=16)
        self.combined_fig.tight_layout(rect=[0, 0, 1, 0.95])  # Leave room for the suptitle

        frame_filename = os.path.join(self.combined_frames_dir, f'frame_{step_num:05d}.png')
        try:
            self.combined_fig.savefig(frame_filename, dpi=self.dpi, facecolor=self.combined_fig.get_facecolor())
            self.combined_frames.append(frame_filename)
        except Exception as e:
            # Best effort: a failed frame is reported and omitted from the animation.
            print(f"Error saving frame {frame_filename}: {e}")

        # The figure is intentionally kept open and reused for the next frame.

    def close_figure(self):
        """Close the reusable figure (if any) and reset the cached handles."""
        if self.combined_fig is not None:
            plt.close(self.combined_fig)
            self.combined_fig = None
            self.combined_ax_traj = None
            self.combined_ax_rollout = None

    def create_animation(self, filename='mppi_simulation.mp4', fps=20):
        """Assemble the saved frames into a video inside ``self.animations_dir``.

        Args:
            filename: Output file name; the extension selects the imageio writer.
            fps: Frames per second of the resulting animation.

        Frames that cannot be read are reported and skipped; any other failure
        is reported without raising.
        """
        if not self.combined_frames:
            print("No frames found to create animation.")
            return

        animation_path = os.path.join(self.animations_dir, filename)
        print(f"Creating animation: {animation_path} at {fps} FPS...")

        try:
            with imageio.get_writer(animation_path, fps=fps) as writer:
                for frame_path in tqdm(self.combined_frames, desc="Building Animation"):
                    try:
                        image = imageio.v2.imread(frame_path)
                        writer.append_data(image)
                    except FileNotFoundError:
                        print(f"Warning: Frame not found {frame_path}, skipping.")
                    except Exception as e:
                        print(f"Error reading frame {frame_path}: {e}, skipping.")
            print(f"Animation saved successfully: {animation_path}")
        except Exception as e:
            print(f"Error creating animation: {e}")


In [None]:
# ----------------------------------------------------
# Main Simulation Loop
# ----------------------------------------------------
if __name__ == "__main__":
    # Runs a closed-loop MPPI tracking simulation, renders frames while it runs,
    # builds a video from them and finally saves a tracking-performance plot.

    # --- Simulation Parameters ---
    SIM_DT = 0.05              # Simulation timestep [s] (needs to be small enough for the dynamics)
    CONTROL_DT = 0.05          # Controller update rate [s] (can be >= SIM_DT)
    TOTAL_DURATION = 400.0     # Total simulation time [s]
    TRAJECTORY_TYPE = 'Circle' # 'Circle', 'Figure8', 'TiltedCircle'
    TRAJ_LOOP_DURATION = 5.0   # Time for one loop of the trajectory [s]
    TRAJ_SCALE = 4.0           # Size of the trajectory [m]
    TRAJ_ALTITUDE = 4.0        # Base altitude [m]
    OUTPUT_DIR = "mppi_sim_output_3"  # Directory for results
    CRASH_ALTITUDE = -0.5      # Altitude [m] below which the drone is considered crashed
    STOP_ON_CRASH = False      # Set True to abort the run when the crash altitude is crossed

    # The loop below advances the plant and the controller in lock-step.
    assert SIM_DT == CONTROL_DT

    # --- Initialize ---
    quad_model = QuadrotorDynamics3D(mass=1.21, arm_length=0.15)
    controller = MPPIController3D(quad_model,
                                  num_samples=500,  # Lower for faster testing, increase for performance
                                  horizon=20,       # Planning horizon steps (20 * 0.05 s = 1.0 s lookahead)
                                  dt=CONTROL_DT,    # MPPI internal simulation dt
                                  lambda_=0.05,     # Exploration/exploitation trade-off
                                  traj_type=TRAJECTORY_TYPE)

    visualizer = EnhancedMPPIVisualizer(output_dir=OUTPUT_DIR,
                                        uav_arm_length=quad_model.arm_length,
                                        uav_scale=1.0)

    # --- Generate Reference Trajectory ---
    num_loops = int(np.ceil(TOTAL_DURATION / TRAJ_LOOP_DURATION))
    ref_trajectory = generate_reference_trajectory_3d(
        traj_type=TRAJECTORY_TYPE,
        duration=TRAJ_LOOP_DURATION,
        dt=SIM_DT,  # Generate reference at the simulation resolution
        scale=TRAJ_SCALE,
        loops=num_loops,
        constant_alt=TRAJ_ALTITUDE
    )
    # round() instead of int(): a quotient such as 7999.999... must not lose a step.
    sim_steps = int(round(TOTAL_DURATION / SIM_DT))
    # Hold the final reference state if the simulation outlasts the trajectory.
    if len(ref_trajectory) < sim_steps:
        padding = np.tile(ref_trajectory[-1], (sim_steps - len(ref_trajectory), 1))
        ref_trajectory = np.vstack([ref_trajectory, padding])
    ref_trajectory = ref_trajectory[:sim_steps]  # Match the simulation length exactly

    # Fixed axis limits for the visualization, derived from the reference path
    visualizer.set_global_trajectory_bounds(ref_trajectory, margin=TRAJ_SCALE * 0.5)

    # --- Initial State: start exactly on the reference ---
    initial_pos = ref_trajectory[0, 0:3]
    initial_quat = ref_trajectory[0, 3:7]
    initial_vel = ref_trajectory[0, 7:10]
    initial_omega = ref_trajectory[0, 10:13]
    current_state = np.concatenate([initial_pos, initial_quat, initial_vel, initial_omega])

    # --- Simulation Log ---
    state_history = [current_state]
    control_history = []
    time_history = [0.0]

    # --- Run Simulation ---
    print(f"\nStarting {TRAJECTORY_TYPE} Trajectory Simulation...")
    print(f" - Simulation DT: {SIM_DT}s, Control DT: {CONTROL_DT}s, Total Duration: {TOTAL_DURATION}s")
    print(f" - MPPI Horizon: {controller.horizon} steps ({controller.horizon * controller.dt:.2f}s)")
    print(f" - Saving output to: {visualizer.base_dir}\n")

    crash_warned = False
    start_time = time.time()
    for i in tqdm(range(sim_steps - 1), desc="Simulating"):
        current_time = i * SIM_DT

        # Reference slice for the controller horizon (shorter near the end of the run).
        # NOTE(review): confirm the controller accepts a slice shorter than its horizon.
        ref_start_idx = i
        ref_end_idx = min(ref_start_idx + controller.horizon, len(ref_trajectory))
        current_ref_traj = ref_trajectory[ref_start_idx:ref_end_idx]

        # Optimal control action from MPPI; visualize=True is passed so the
        # controller keeps the rollout data read by the visualizer below.
        control_action = controller.get_action(current_state, current_ref_traj, visualize=True)

        # Apply the control action to the dynamics model
        next_state, applied_control = quad_model.dynamics(current_state, control_action, SIM_DT)

        # Stop *before* logging or rendering a diverged state: non-finite
        # positions would otherwise be used as axis limits by the visualizer.
        if not np.all(np.isfinite(next_state)):
            print(f"\nERROR: Simulation became unstable (NaN/Inf state) at step {i+1}. Stopping.")
            break

        # --- Log Data ---
        time_history.append(current_time + SIM_DT)
        state_history.append(next_state)
        control_history.append(applied_control)

        # --- Render Visualization Frame ---
        # Only the tail of the history is drawn, so pass just that slice instead
        # of converting the whole (growing) history to an array on every step.
        visualizer.render_combined_frame(
            t=current_time + SIM_DT,
            current_state=next_state,
            reference_state=ref_trajectory[i+1],  # Target is the next reference state
            past_states=state_history[-visualizer.trajectory_trail_length:],
            full_ref_traj=ref_trajectory[:, 0:3],  # Full path for plotting context
            mppi_rollouts=controller.rollouts,
            mppi_costs=controller.rollout_costs,
            best_rollout_idx=controller.best_rollout_idx,
            last_best_cost=controller.last_best_cost,
            step_num=i + 1
        )

        # Update state for next iteration
        current_state = next_state

        # Simple crash detection on altitude
        if current_state[2] < CRASH_ALTITUDE:
            if STOP_ON_CRASH:
                print(f"\nWarning: Drone altitude below {CRASH_ALTITUDE}m at step {i+1}. Stopping.")
                break
            if not crash_warned:
                # Report once instead of on every step spent below the threshold.
                print(f"\nWarning: Drone altitude below {CRASH_ALTITUDE}m at step {i+1}. Continuing.")
                crash_warned = True

    end_time = time.time()
    print(f"\nSimulation finished in {end_time - start_time:.2f} seconds.")

    # --- Create Animation ---
    # Playback at real-time speed: one saved frame per `frame_interval` sim steps.
    animation_fps = max(1, int(round(1.0 / (SIM_DT * visualizer.frame_interval))))
    visualizer.create_animation(filename=f"{TRAJECTORY_TYPE}_sim.mp4", fps=animation_fps)

    # Close the reusable frame figure
    visualizer.close_figure()

    # --- Final performance plots (position tracking) ---
    print("Generating final performance plots...")
    state_history_np = np.array(state_history)
    control_history_np = np.array(control_history)
    time_history_np = np.array(time_history)
    ref_logged = ref_trajectory[:len(time_history_np)]  # Reference aligned with the logged samples

    fig_perf, axs = plt.subplots(3, 1, figsize=(12, 10), sharex=True)
    axs[0].plot(time_history_np, state_history_np[:, 0], label='Actual X', color=visualizer.trajectory_color)
    axs[0].plot(time_history_np, ref_logged[:, 0], label='Reference X', color=visualizer.reference_color, linestyle='--')
    axs[0].plot(time_history_np, state_history_np[:, 1], label='Actual Y', color=visualizer.trajectory_color, alpha=0.7)
    axs[0].plot(time_history_np, ref_logged[:, 1], label='Reference Y', color=visualizer.reference_color, linestyle='--', alpha=0.7)
    axs[0].set_ylabel('Position X, Y [m]')
    axs[0].legend()
    axs[0].grid(True)

    axs[1].plot(time_history_np, state_history_np[:, 2], label='Actual Z', color=visualizer.trajectory_color)
    axs[1].plot(time_history_np, ref_logged[:, 2], label='Reference Z', color=visualizer.reference_color, linestyle='--')
    axs[1].set_ylabel('Position Z [m]')
    axs[1].legend()
    axs[1].grid(True)

    # Euclidean position tracking error
    pos_error = np.linalg.norm(state_history_np[:, 0:3] - ref_logged[:, 0:3], axis=1)
    axs[2].plot(time_history_np, pos_error, label='Position Tracking Error', color='red')
    axs[2].set_ylabel('Error [m]')
    axs[2].set_xlabel('Time [s]')
    axs[2].legend()
    axs[2].grid(True)

    fig_perf.suptitle(f'{TRAJECTORY_TYPE} Tracking Performance', fontsize=16)
    fig_perf.tight_layout(rect=[0, 0, 1, 0.96])
    plot_filename = os.path.join(visualizer.animations_dir, f'{TRAJECTORY_TYPE}_performance.png')
    fig_perf.savefig(plot_filename)
    print(f"Performance plot saved to {plot_filename}")

    print("\nVisualization and analysis complete.")