In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time
from scipy.optimize import minimize
from joblib import Parallel, delayed
import os
from drone_dynamics import *


class SEIBR_DroneNavigation_FigEight:
    """Implementation of SE-IBR algorithm for multi-drone gate navigation"""
    
    def __init__(self, num_robots=3, planning_horizon=10, dt=0.2, max_iterations=5, dynamics_dt=0.01, n_gates=5, radius=6.0):
        """Initialize the multi-drone navigation system.

        Args:
            num_robots: Number of drones to create.
            planning_horizon: Number of waypoints per planned trajectory (``self.N``).
            dt: Planning timestep (s).
            max_iterations: Maximum number of SE-IBR iterations.
            dynamics_dt: Timestep of the drone dynamics model (s).
            n_gates: Initial gate count. Note that ``initialize_gates`` (called
                below) overwrites ``self.n_gates`` with the built-in track's count.
            radius: ``self.world_size`` is set to ``2 * radius``.
        """
        # System parameters
        self.num_robots = num_robots
        self.N = planning_horizon  # Planning horizon
        self.dt = dt  # Planning timestep
        self.max_iterations = max_iterations  # Max iterations for SE-IBR
        self.dynamics_dt = dynamics_dt  # Dynamics timestep
        # Number of dynamics steps per planning step. Round rather than truncate:
        # e.g. 0.3 / 0.1 evaluates to 2.9999999999999996, which int() turns into 2.
        # Clamp to at least 1 so a too-large dynamics_dt cannot yield 0 steps.
        self.dynamics_steps = max(1, int(round(dt / dynamics_dt)))

        # World parameters
        self.world_size = radius * 2  # Size of world (m)
        self.n_gates = n_gates  # Overwritten by initialize_gates() below

        # Robot parameters
        self.robot_radius = 0.3  # robot radius (m)
        self.min_distance = 0.8  # minimum safe distance between robots (m)
        self.max_velocity = np.array([3.0, 3.0, 2.0])  # max velocity per axis (m/s)
        self.gate_passing_tolerance = 0.5  # Distance to consider gate passed (m)
        self.collision_coeff = 0.32

        # SE-IBR parameters
        self.sensitivity_alpha = 0.5  # Sensitivity parameter for SE-IBR
        self.alpha_decay = 0.9  # Decay rate for sensitivity parameter
        self.lagrange_multipliers = {}  # Lagrange multipliers for constraints
        self.prev_lagrange_multipliers = {}  # Previous Lagrange multipliers

        # Progress measurement
        self.track_width = 2.0  # Width of track for progress calculation

        # One dynamics model + velocity-tracking controller stack per drone.
        self.drones = []
        self.controllers = []

        for i in range(num_robots):
            drone = DroneDynamics(dt=dynamics_dt)
            low_level_controller = LowLevelController(drone)
            velocity_tracker = VelocityTracker(drone, low_level_controller)
            self.drones.append(drone)
            self.controllers.append(velocity_tracker)

        self.initialize_gates()

        # Initialize trajectory predictors for warm-starting
        state_dim = 6 + 3 + 6  # drone state (pos+vel) + target + 2 neighbors
        traj_dim = planning_horizon * 3  # 3D positions over planning horizon
        self.predictors = [TrajectoryPredictor(state_dim, traj_dim) for _ in range(num_robots)]

        # Initialize other variables
        self.neighbors = [[] for _ in range(num_robots)]  # List of neighbors for each robot
        self.current_gate_indices = np.zeros(num_robots, dtype=int)  # Current gate index for each robot

        # Reset to initial configuration
        self.reset()
    
    def initialize_gates(self):
        """Initialize gates based on the new track layout"""
        self.n_gates = 6
        self.gate_width = 2.0
        self.gate_height = 2.0

        # Gate positions for the custom track
        gate_positions = [
            [5.0, -5, 1],      # Gate 1 (right)
            [10.0, 0, 2],       # Gate 2 (back right)
            [5.0, 5, 2],       # Gate 3 (back)
            [-5.0, -5, 0.5],  # Gate 4 (middle)
            [-10.0, 0, 0.5],   # Gate 5 (front left)
            [-5.0, 5, 1.5]    # Gate 6 (left)
        ]
        
        # Gate orientations (yaw angles in radians)
        gate_yaws = [
            np.pi/2,        # Gate 1 facing +y
            0,              # Gate 2 facing +x
            -np.pi/2,       # Gate 3 facing -y
            -np.pi/2,       # Gate 4 facing -y
            0,              # Gate 5 facing +x
            np.pi/2         # Gate 6 facing +y
        ]
        
        self.gates = []
        for idx in range(self.n_gates):
            center = np.array(gate_positions[idx], dtype=np.float32)
            yaw = gate_yaws[idx]
            
            gate = {
                "center": center,
                "yaw": yaw + np.pi/2,  # Maintain the original adjustment
                "width": self.gate_width,
                "height": self.gate_height
            }
            self.gates.append(gate)
        
        self.gate_positions = np.array([gate["center"] for gate in self.gates])
        self.gate_yaws = np.array([gate["yaw"] for gate in self.gates])

        # Store track-specific information for path planning
        self.track_segments = []
        for i in range(self.n_gates):
            next_i = (i + 1) % self.n_gates
            segment = {
                "start_gate": i,
                "end_gate": next_i,
                "start_pos": self.gate_positions[i],
                "end_pos": self.gate_positions[next_i],
                "start_yaw": self.gate_yaws[i],
                "end_yaw": self.gate_yaws[next_i],
                "height_diff": self.gate_positions[next_i][2] - self.gate_positions[i][2]
            }
            self.track_segments.append(segment)

    def reset(self):
        """Reset drone positions with better initial placement for the custom track"""
        # Safety parameters
        safety_distance_gates = 0.5  # Safe distance from gates
        safety_distance_drones = 0.8  # Minimum distance between drones
        max_attempts = 50     # Maximum attempts to find safe positions
        
        # Preferentially initialize drones near the first gate
        first_gate_pos = self.gate_positions[0]  # Gate 1 
        first_gate_yaw = self.gate_yaws[0]
        
        # Calculate gate normal (direction the gate is facing)
        gate_normal = np.array([np.cos(first_gate_yaw - np.pi/2), np.sin(first_gate_yaw - np.pi/2), 0])
        
        # Position starting area a bit before the first gate
        start_center = first_gate_pos - gate_normal * 2.5  # 2.5m before first gate
        
        # Sample positions for all drones around this starting area
        positions = []
        
        for drone_idx in range(self.num_robots):
            attempts = 0
            position_found = False
            
            # Try to find a safe position for this drone
            while attempts < max_attempts and not position_found:
                attempts += 1
                
                # Sample random position around start center
                offset = np.random.uniform(-1.5, 1.5, size=3)
                offset[2] = abs(offset[2]) * 0.3  # Smaller vertical variation
                candidate_position = start_center + offset
                candidate_position[2] = max(0.5, candidate_position[2])  # Ensure minimum height
                candidate_position = candidate_position.astype(np.float32)
                
                # Verify it's safe from gates
                safe_from_gates = True
                for gate_pos in self.gate_positions:
                    gate_pos_array = np.array(gate_pos)
                    distance = np.linalg.norm(candidate_position - gate_pos_array)
                    if distance < safety_distance_gates:
                        safe_from_gates = False
                        break
                
                if not safe_from_gates:
                    continue  # Try another position
                
                # Verify it's safe from other drones
                safe_from_drones = True
                for existing_pos in positions:
                    distance = np.linalg.norm(candidate_position - existing_pos)
                    if distance < safety_distance_drones:
                        safe_from_drones = False
                        break
                
                if not safe_from_drones:
                    continue  # Try another position
                
                # If we get here, the position is safe
                position_found = True
                positions.append(candidate_position)
            
            # If we couldn't find a safe position after max attempts, use a backup method
            if not position_found:
                # Fallback: place drone at a slightly random position near start center
                backup_position = start_center + np.random.uniform(-0.5, 0.5, size=3)
                backup_position[2] = max(0.5, backup_position[2])  # Ensure minimum height
                positions.append(backup_position.astype(np.float32))
                print(f"Warning: Using backup position for drone {drone_idx}")
        
        # Reset drone states and set current gate indices
        for i in range(self.num_robots):
            # Reset drone state
            self.drones[i].reset(positions[i])
            
            # Set current gate index
            self.current_gate_indices[i] = 0
        
        # Update neighbors
        self.find_neighbors()
        
        # Return positions for possible visualization or logging
        return positions
    
    def find_neighbors(self):
        """Find neighbors for each robot based on proximity"""
        positions = self.get_drone_positions()
        
        # Clear current neighbors
        self.neighbors = [[] for _ in range(self.num_robots)]
        
        # Find neighbors within a certain distance
        neighbor_distance = 4.0  # Consider robots within 4 meters as potential neighbors
        
        for i in range(self.num_robots):
            for j in range(self.num_robots):
                if i != j:  # Don't include self
                    dist = np.linalg.norm(positions[i] - positions[j])
                    if dist < neighbor_distance:
                        self.neighbors[i].append(j)
    
    def get_drone_positions(self):
        """Get current positions of all drones"""
        positions = np.zeros((self.num_robots, 3))
        for i in range(self.num_robots):
            positions[i] = self.drones[i].state[0:3]
        return positions
    
    def get_drone_velocities(self):
        """Get current velocities of all drones"""
        velocities = np.zeros((self.num_robots, 3))
        for i in range(self.num_robots):
            velocities[i] = self.drones[i].state[3:6]
        return velocities
    
    def create_gate_reference_frame(self, gate_position, gate_yaw):
        """Create a reference frame for measuring progress through a gate with caching"""
        # Check if we have a cache for reference frames
        if not hasattr(self, 'gate_frame_cache'):
            self.gate_frame_cache = {}
        
        # Check if this frame is already in the cache
        cache_key = (tuple(gate_position), gate_yaw)
        if cache_key in self.gate_frame_cache:
            return self.gate_frame_cache[cache_key]
        
        # Create the reference frame
        # Adjust for custom track's orientation convention
        forward = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
        lateral = np.array([-np.sin(gate_yaw - np.pi/2), np.cos(gate_yaw - np.pi/2), 0])
        vertical = np.array([0, 0, 1])
        
        frame = {
            'origin': gate_position, 
            'forward': forward, 
            'lateral': lateral, 
            'vertical': vertical
        }
        
        # Store in cache
        self.gate_frame_cache[cache_key] = frame
        
        return frame
    
    def compute_progress_to_gate(self, position, gate_position, gate_yaw):
        """Enhanced progress measurement with stronger gradients for custom track"""
        # Vector from position to gate
        to_gate = gate_position - position
        dist_to_gate = np.linalg.norm(to_gate)
        
        # Gate direction (normal to gate plane)
        # Adjust for the custom track orientation
        gate_dir = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
        
        # Project position onto gate normal direction
        projection = np.dot(to_gate, gate_dir)
        
        # Calculate lateral and vertical deviation
        gate_lateral = np.array([-np.sin(gate_yaw - np.pi/2), np.cos(gate_yaw - np.pi/2), 0])
        lateral_dev = np.abs(np.dot(to_gate, gate_lateral))
        
        # For custom track, use adaptive vertical deviation metrics based on gate height
        target_height = gate_position[2]
        height_diff = abs(position[2] - target_height)
        vertical_dev = height_diff
        
        # Custom track: stronger incentive to match gate height
        height_penalty = 1.0 * (height_diff**2) / (self.gate_height**2)
        
        # Stronger incentive to fly through the center of the gate
        deviation_penalty = 0.5 * (lateral_dev**2) / (self.track_width**2) + height_penalty
        
        # Much stronger reward for being in front of the gate
        front_reward = 3.0 * max(0, -projection)
        
        # Combined progress metric - more extreme values for clearer optimization
        progress = -dist_to_gate + front_reward - deviation_penalty
        
        return progress
    
    def initialize_straight_trajectory(self, robot_idx):
        """Create a trajectory that aggressively aims to pass through the gate with custom track adjustments"""
        trajectory = np.zeros((self.N, 3))
        
        # Start from current position
        current_pos = self.drones[robot_idx].state[0:3]
        current_vel = self.drones[robot_idx].state[3:6]
        trajectory[0] = current_pos
        
        # Target gate
        gate_idx = self.current_gate_indices[robot_idx]
        gate_pos = self.gate_positions[gate_idx]
        gate_yaw = self.gate_yaws[gate_idx]
        
        # Gate normal direction (pointing forward from gate)
        # Adjust for the custom track orientation
        gate_normal = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
        
        # Calculate a target point beyond the gate
        target_beyond_gate = gate_pos + gate_normal * 4.0  # Target 4m beyond gate
        
        # Vector from drone to gate
        to_gate = gate_pos - current_pos
        
        # Check if behind gate
        behind_gate = np.dot(to_gate, gate_normal) > 0
        
        # For the custom track, we need to handle height transitions better
        # Find the next gate for better transitions
        next_gate_idx = (gate_idx + 1) % self.n_gates
        next_gate_pos = self.gate_positions[next_gate_idx]
        
        # Create trajectory: aim directly through the gate center with height consideration
        gate_center_point = gate_pos.copy()
        
        # Adjust target height slightly based on gate height to improve passing
        if gate_pos[2] > current_pos[2]:
            # If gate is higher, aim slightly lower than center
            gate_center_point[2] -= 0.1
        else:
            # If gate is lower, aim slightly higher than center
            gate_center_point[2] += 0.1
        
        # Create a smooth trajectory through the gate
        for t in range(1, self.N):
            if t < self.N//2:
                # First half: aim precisely at gate center with smooth height transition
                alpha = t / (self.N//2)
                intermediate_point = (1 - alpha) * current_pos + alpha * gate_center_point
                trajectory[t] = intermediate_point
            else:
                # Second half: continue beyond gate with anticipation of next gate
                alpha = (t - self.N//2) / (self.N - self.N//2)
                
                # For last part of trajectory, blend toward next gate's height
                height_blend = max(0, (alpha - 0.5) * 2)  # Start height transition at 75% of second half
                height = (1 - height_blend) * gate_pos[2] + height_blend * next_gate_pos[2]
                
                # Basic position blending
                intermediate_point = (1 - alpha) * gate_center_point + alpha * target_beyond_gate
                intermediate_point[2] = height  # Apply height adjustment
                
                trajectory[t] = intermediate_point
        
        # Apply custom track adjustments for specific tricky segments
        trajectory = self.adjust_for_track_geometry(robot_idx, trajectory)
        
        return trajectory
    
    def adjust_for_track_geometry(self, robot_idx, trajectory):
        """Apply specific adjustments for the custom track geometry"""
        gate_idx = self.current_gate_indices[robot_idx]
        
        # Handle the different gate segments in the custom track
        
        # Gate 1 -> Gate 2: Right to back-right (increase height, slight right turn)
        if gate_idx == 0:
            # Smooth the height transition
            for t in range(trajectory.shape[0]//2, trajectory.shape[0]):
                # Gradually adjust height after passing gate
                progress = (t - trajectory.shape[0]//2) / (trajectory.shape[0] - trajectory.shape[0]//2)
                height_adjustment = 0.3 * progress  # Gradually move higher
                trajectory[t, 2] += height_adjustment
        
        # Gate 2 -> Gate 3: Back-right to back (sharp left turn at height)
        elif gate_idx == 1:
            # Widen the turn for the sharp left
            for t in range(trajectory.shape[0]//2, trajectory.shape[0]):
                progress = (t - trajectory.shape[0]//2) / (trajectory.shape[0] - trajectory.shape[0]//2)
                # Add slight outward adjustment for smoother turn
                outward_adjustment = np.array([0, 0.5, 0]) * progress
                trajectory[t] += outward_adjustment
        
        # Gate 3 -> Gate 4: Back to middle (height drop and diagonal movement)
        elif gate_idx == 2:
            # Gradual height descent
            for t in range(trajectory.shape[0]//3, trajectory.shape[0]):
                progress = (t - trajectory.shape[0]//3) / (trajectory.shape[0] - trajectory.shape[0]//3)
                # More aggressive early height adjustment for dropping gates
                height_adjustment = -0.4 * min(1.0, progress * 1.5)
                trajectory[t, 2] += height_adjustment
        
        # Gate 4 -> Gate 5: Middle to front-left (flat movement, turning left)
        elif gate_idx == 3:
            # Maintain height, smooth lateral transition
            for t in range(trajectory.shape[0]//2, trajectory.shape[0]):
                # Just ensure stable height 
                trajectory[t, 2] = max(trajectory[t, 2], 0.5)  # Ensure minimal safe height
        
        # Gate 5 -> Gate 6: Front-left to left (slight increase in height, sharp right turn)
        elif gate_idx == 4:
            # Smoother turn with height increase
            for t in range(trajectory.shape[0]//2, trajectory.shape[0]):
                progress = (t - trajectory.shape[0]//2) / (trajectory.shape[0] - trajectory.shape[0]//2)
                # Gradual height increase
                height_adjustment = 0.3 * progress
                # Widen the turn slightly
                lateral_adjustment = np.array([0.5, 0, 0]) * progress
                trajectory[t] += lateral_adjustment
                trajectory[t, 2] += height_adjustment
        
        # Gate 6 -> Gate 1: Left to right (diagonal movement across center, slight height decrease)
        elif gate_idx == 5:
            # Crossing center, adjust to prepare for first gate
            for t in range(trajectory.shape[0]//2, trajectory.shape[0]):
                progress = (t - trajectory.shape[0]//2) / (trajectory.shape[0] - trajectory.shape[0]//2)
                # Small height adjustment to prepare for gate 1
                height_adjustment = -0.2 * progress
                trajectory[t, 2] += height_adjustment
        
        return trajectory
    
    def optimize_trajectory(self, robot_idx, all_trajectories):
        """Improved trajectory optimization for custom track layout"""
        # Initialize with current trajectory
        if all_trajectories is None:
            trajectory = self.initialize_straight_trajectory(robot_idx)
        else:
            trajectory = all_trajectories[robot_idx].copy()
        
        current_pos = self.drones[robot_idx].state[0:3]
        current_vel = self.drones[robot_idx].state[3:6]
        
        # Get current target gate
        gate_idx = self.current_gate_indices[robot_idx]
        gate_position = self.gate_positions[gate_idx]
        gate_yaw = self.gate_yaws[gate_idx]
        
        # Create gate frame for progress measurement
        gate_frame = self.create_gate_reference_frame(gate_position, gate_yaw)
        
        # Find the next gate for better transitions
        next_gate_idx = (gate_idx + 1) % self.n_gates
        next_gate_position = self.gate_positions[next_gate_idx]
        next_gate_yaw = self.gate_yaws[next_gate_idx]
        
        # Optimization parameters - adaptive based on proximity to gate
        dist_to_gate = np.linalg.norm(current_pos - gate_position)
        close_to_gate = dist_to_gate < 3.0
        
        # Adaptive optimization parameters
        n_steps = 7 if close_to_gate else 5  # More steps when close to gate
        learning_rate = 0.3 if close_to_gate else 0.2  # Higher learning rate when close
        
        # Start with current trajectory (skip first point which is fixed)
        opt_traj = trajectory[1:].copy()
        
        # Cache neighbor trajectories
        neighbor_trajectories = {}
        for j in self.neighbors[robot_idx]:
            if all_trajectories is None:
                neighbor_trajectories[j] = self.initialize_straight_trajectory(j)
            else:
                neighbor_trajectories[j] = all_trajectories[j]
        
        # Calculate gate target point - target beyond the gate
        # Adjust for the custom track orientation
        gate_normal = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
        
        # Create a blended target that considers the next gate's position
        direct_target = gate_position + gate_normal * 2.0  # Original target
        
        # For the custom track, blend targets more intelligently
        # Look ahead to next gate when appropriate
        target_to_next = next_gate_position - gate_position
        target_to_next_norm = np.linalg.norm(target_to_next)
        
        # Define the blended target
        if target_to_next_norm > 0:
            # Normalize the vector to next gate
            target_to_next = target_to_next / target_to_next_norm
            
            # Blend based on track characteristics
            # More weight on next gate for sharp turns
            dot_product = np.dot(gate_normal, target_to_next)
            turn_sharpness = 1.0 - abs(dot_product)  # 0 for straight path, 1 for 90° turn
            
            # Adjust target blending based on turn sharpness
            blend_factor = 0.4 * turn_sharpness  # More influence of next gate for sharp turns
            blended_target = gate_position + (gate_normal * (1-blend_factor) + target_to_next * blend_factor) * 3.0
            
            # Also blend height
            height_diff = next_gate_position[2] - gate_position[2]
            blended_target[2] = gate_position[2] + height_diff * 0.3  # Start height transition
        else:
            blended_target = direct_target
        

        # Simple gradient-based optimization
        for step in range(n_steps):
            # Calculate gradients for each point in trajectory
            gradients = np.zeros_like(opt_traj)
            
            # Progress gradient based on our progress function
            for t in range(opt_traj.shape[0]):
                # Compute progress at current position
                progress = self.compute_progress_to_gate(opt_traj[t], gate_position, gate_yaw)                
                # Compute numerical gradient by testing small perturbations
                eps = 0.1
                gradient = np.zeros(3)
                
                for dim in range(3):
                    # Create perturbed position
                    perturbed_pos = opt_traj[t].copy()
                    perturbed_pos[dim] += eps
                    
                    # Compute progress at perturbed position
                    perturbed_progress = self.compute_progress_to_gate(perturbed_pos, gate_position, gate_yaw)
                    
                    # Compute gradient
                    gradient[dim] = (perturbed_progress - progress) / eps
                
                # Normalize gradient if it's large
                grad_norm = np.linalg.norm(gradient)
                if grad_norm > 1.0:
                    gradient = gradient / grad_norm
                
                # Apply higher weights for later timesteps
                time_weight = 1.0 + t/opt_traj.shape[0] * 0.5  # Reduced scaling
                # gradients[t] += gradient * time_weight

                track_influence = min(1.0, 0.3 + 0.7 * t / opt_traj.shape[0])  # Increases with time
                
                # Compute track guidance vector
                track_guidance = self.compute_track_guidance(robot_idx, opt_traj[t])

                # Apply stronger track guidance for later gates
                gate_idx = self.current_gate_indices[robot_idx]
                if gate_idx >= 2:  # After second gate, increase track guidance
                    track_influence *= 1.5
                
                # Add to gradients
                gradients[t] += track_guidance * track_influence
                
                # For later timesteps, add influence of next gate if we're likely to pass this one
                if t > opt_traj.shape[0] * 0.7 and dist_to_gate < 3.0:
                    # Add a small gradient component toward the next gate's height
                    height_gradient = np.zeros(3)
                    height_gradient[2] = np.sign(next_gate_position[2] - opt_traj[t, 2]) * 0.1
                    gradients[t] += height_gradient * (t / opt_traj.shape[0])
            
            # Smoothness gradient
            if opt_traj.shape[0] > 1:
                # Penalize large accelerations
                for t in range(1, opt_traj.shape[0]-1):
                    accel = opt_traj[t+1] - 2*opt_traj[t] + opt_traj[t-1]
                    gradients[t] -= 0.15 * accel  # Stronger smoothness penalty
                
                # Match initial velocity with stronger weighting
                init_vel = (opt_traj[0] - current_pos) / self.dt
                vel_diff = init_vel - current_vel
                gradients[0] -= 0.3 * vel_diff  # Increased weight for initial velocity matching
            
            # Enhanced collision avoidance - higher weighting and more proactive
            for j in self.neighbors[robot_idx]:
                other_traj = neighbor_trajectories[j]
                
                # Check if this is a problematic pair with collision history
                collision_history_factor = 1.0
                if hasattr(self, '_collision_counters'):
                    key = (min(robot_idx, j), max(robot_idx, j))
                    collision_count = self._collision_counters.get(key, 0)
                    # Increase avoidance strength for drones that collide frequently
                    collision_history_factor = 1.0 + min(2.0, collision_count * 0.5)
                
                # Also check if this pair is tagged for evasive action
                evasive_pair = False
                if hasattr(self, '_evasive_pairs'):
                    key = (min(robot_idx, j), max(robot_idx, j))
                    evasive_pair = key in self._evasive_pairs
                    if evasive_pair:
                        collision_history_factor *= 1.5  # Even stronger for evasive pairs
                
                for t in range(opt_traj.shape[0]):
                    t_idx = t + 1  # Adjusting index for full trajectory
                    if t_idx < other_traj.shape[0]:
                        # Vector from other to self
                        diff = opt_traj[t] - other_traj[t_idx]
                        dist = np.linalg.norm(diff)
                        
                        # More proactive collision avoidance - increased detection radius
                        avoidance_radius = self.min_distance * 2.0  # Increased radius
                        
                        # Apply stronger avoidance when close
                        if dist < avoidance_radius and dist > 0:
                            # Repulsive direction (away from other drone)
                            repulsion_dir = diff / dist
                            
                            # Stronger, progressive strength that increases more rapidly as drones get closer
                            # Using a quadratic falloff for more aggressive avoidance
                            proximity_factor = (avoidance_radius - dist) / avoidance_radius
                            strength = 1.0 + 3.0 * proximity_factor * proximity_factor
                            
                            # Include sensitivity term from SE-IBR
                            sensitivity = self.compute_sensitivity_minimal(robot_idx, j, gate_idx)
                            
                            # Apply stronger repulsive gradient with sensitivity adjustment
                            collision_gradient = repulsion_dir * strength * sensitivity * collision_history_factor
                            
                            # Apply gradient with higher weight for collision avoidance
                            # Weight decays with time to allow eventual goal-reaching
                            time_decay = 1.0 if t < 3 else (1.0 - 0.1 * (t-3))
                            collision_weight = 1.5 * time_decay  # Base weight with time decay
                            
                            # Even stronger for evasive pairs
                            if evasive_pair:
                                collision_weight *= 1.5
                                
                            gradients[t] += collision_weight * collision_gradient
                            
                            # Also add vertical separation for nearby drones - prefer going over
                            if dist < self.min_distance * 1.2:
                                vertical_bias = np.array([0, 0, 0.3 * collision_history_factor])
                                gradients[t] += vertical_bias
            
            # Apply gradients with learning rate
            opt_traj += learning_rate * gradients
            
            # Apply constraints directly
            # 1. Velocity constraints
            if opt_traj.shape[0] > 0:
                for t in range(opt_traj.shape[0]):
                    # Calculate velocity
                    if t == 0:
                        vel = (opt_traj[t] - current_pos) / self.dt
                    else:
                        vel = (opt_traj[t] - opt_traj[t-1]) / self.dt
                    
                    # Clip velocity with a safety margin
                    for i in range(3):
                        max_vel = self.max_velocity[i] * 0.95  # 5% safety margin
                        if abs(vel[i]) > max_vel:
                            vel[i] = np.sign(vel[i]) * max_vel
                    
                    # Recompute position from velocity
                    if t == 0:
                        opt_traj[t] = current_pos + vel * self.dt
                    else:
                        opt_traj[t] = opt_traj[t-1] + vel * self.dt
        
        # Combine with fixed initial position
        optimized_trajectory = np.vstack([current_pos, opt_traj])
        
        # Apply height constraints to keep drones at reasonable heights
        min_height = 0.5  # Minimum allowed height
        max_height = 5.0  # Maximum allowed height
        
        for t in range(optimized_trajectory.shape[0]):
            optimized_trajectory[t, 2] = np.clip(optimized_trajectory[t, 2], min_height, max_height)
        
        # Final collision avoidance check - ensure minimum separation at key points
        for j in self.neighbors[robot_idx]:
            other_traj = all_trajectories[j]
            
            # Focus on critical points - start, middle, end
            critical_points = [0, optimized_trajectory.shape[0]//2, optimized_trajectory.shape[0]-1]
            
            for t in critical_points:
                if t < other_traj.shape[0]:
                    diff = optimized_trajectory[t] - other_traj[t]
                    dist = np.linalg.norm(diff)
                    
                    # If too close, apply direct separation
                    if dist < self.min_distance * 0.9 and dist > 0:
                        # Calculate separation direction
                        sep_dir = diff / dist
                        
                        # Calculate how much to move
                        move_dist = (self.min_distance - dist) * 0.6  # Move 60% of the way
                        
                        # Apply move (only to non-initial points)
                        if t > 0:
                            optimized_trajectory[t] += sep_dir * move_dist
        
        # Apply custom track geometry adjustments
        optimized_trajectory = self.adjust_for_track_geometry(robot_idx, optimized_trajectory)
        
        return optimized_trajectory
    
    def compute_sensitivity_minimal(self, ego_idx, other_idx, ego_gate_idx):
        """Enhanced sensitivity computation with collision history awareness"""
        # Get current gate indices
        other_gate_idx = self.current_gate_indices[other_idx]
        
        # Check collision history to increase sensitivity for problematic pairs
        collision_history_factor = 1.0
        if hasattr(self, '_collision_counters'):
            key = (min(ego_idx, other_idx), max(ego_idx, other_idx))
            collision_count = self._collision_counters.get(key, 0)
            # Increase sensitivity for drones that collide frequently
            collision_history_factor = 1.0 + min(2.0, collision_count * 0.5)
        
        # Base sensitivity on relative gate progress with collision history
        if ego_gate_idx > other_gate_idx:
            return 1.0 * collision_history_factor  # Ego drone has priority (higher gate index)
        elif ego_gate_idx < other_gate_idx:
            return 0.2 * collision_history_factor  # Other drone has priority
        else:
            # Same gate, compare distances
            ego_dist = np.linalg.norm(self.drones[ego_idx].state[:3] - self.gate_positions[ego_gate_idx])
            other_dist = np.linalg.norm(self.drones[other_idx].state[:3] - self.gate_positions[other_gate_idx])
            
            if ego_dist < other_dist:
                return (0.8 * collision_history_factor)  # Ego drone is closer to gate
            else:
                return (0.4 * collision_history_factor)  # Other drone is closer to gate
    
    def check_gate_passed(self, robot_idx, verbose=False):
        """Enhanced gate passing detection for custom track"""
        # Get current gate information
        gate_idx = self.current_gate_indices[robot_idx]
        gate_pos = self.gate_positions[gate_idx]
        gate_yaw = self.gate_yaws[gate_idx]
        
        # Get robot position and velocity
        robot_pos = self.drones[robot_idx].state[0:3]
        robot_vel = self.drones[robot_idx].state[3:6]
        
        # For custom track with height differences, adapt tolerance
        height_difference = abs(gate_pos[2] - robot_pos[2])
        adaptive_tolerance = self.gate_passing_tolerance * (1.0 + height_difference * 0.3)
        
        # Vector from drone to gate
        to_gate = gate_pos - robot_pos
        dist_to_gate = np.linalg.norm(to_gate)
        
        # Gate normal direction (pointing forward from gate)
        # Adjust for the custom track orientation
        gate_normal = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
        
        # Calculate signed distance to gate plane (positive if behind gate, negative if in front)
        signed_dist = np.dot(to_gate, gate_normal)
        
        # Store previous positions and distances
        if not hasattr(self, '_prev_positions'):
            self._prev_positions = {}
            self._prev_signed_dists = {}
        
        if robot_idx not in self._prev_positions:
            self._prev_positions[robot_idx] = robot_pos
            self._prev_signed_dists[robot_idx] = signed_dist
            return False
        
        # Get previous values
        prev_pos = self._prev_positions[robot_idx]
        prev_signed_dist = self._prev_signed_dists[robot_idx]
        
        # Calculate lateral and vertical distance to center line
        gate_lateral = np.array([-np.sin(gate_yaw - np.pi/2), np.cos(gate_yaw - np.pi/2), 0])
        lateral_dist = np.abs(np.dot(to_gate, gate_lateral))
        vertical_dist = np.abs(to_gate[2])
        
        # Merged detection conditions optimized for custom track:
        
        # - Check absolute distance (more permissive for varying heights)
        close_enough = dist_to_gate < adaptive_tolerance
        
        # - Check if we crossed the gate plane
        crossed_plane = prev_signed_dist > 0 and signed_dist <= 0
        
        # - More permissive center criterion for custom track with height variations
        is_close_to_center = lateral_dist < 1.8 and vertical_dist < 1.8
        
        # - Check direction (more permissive for custom track)
        direction_check = True
        vel_magnitude = np.linalg.norm(robot_vel)
        if vel_magnitude > 0.1:  # Very low threshold
            vel_dir = robot_vel / vel_magnitude
            alignment = np.dot(vel_dir, gate_normal)
            direction_check = alignment > 0.0  # Most permissive for custom track
        
        # Direct distance override - needed for sharp turns in custom track
        direct_override = close_enough and is_close_to_center and signed_dist < 0
        
        # Update previous values
        self._prev_positions[robot_idx] = robot_pos
        self._prev_signed_dists[robot_idx] = signed_dist
        
        # Gate is passed if EITHER:
        # - We detected a proper crossing (crossed plane + close to center + right direction)
        # - OR we're very close to the gate on the front side with correct positioning
        gate_passed = (crossed_plane and is_close_to_center and direction_check) or direct_override
        
        # Additional debug information
        if (gate_passed or close_enough) and verbose:
            print(f"Gate check: Drone {robot_idx}, Dist={dist_to_gate:.2f}m, " 
                f"Signed={signed_dist:.2f}, Crossed={crossed_plane}, "
                f"Center={is_close_to_center}, Override={direct_override}")
        
        if gate_passed and verbose:
            print(f"GATE PASSED: Drone {robot_idx} passed gate {gate_idx}")
            new_gate_idx = (gate_idx + 1) % self.n_gates
            print(f"Setting new gate index to {new_gate_idx}")
        
        return gate_passed
    
    def plan(self, verbose=False):
        """Optimized planning with performance improvements and custom track handling"""
        start_time = time.time()
        
        # Update neighbors (less frequently for performance)
        if not hasattr(self, '_neighbor_update_counter'):
            self._neighbor_update_counter = 0
        
        self._neighbor_update_counter += 1
        if self._neighbor_update_counter % 3 == 0:
            self.find_neighbors()
        
        # Initialize or reuse trajectories
        if not hasattr(self, 'previous_trajectories') or self.previous_trajectories is None:
            all_trajectories = np.zeros((self.num_robots, self.N, 3))
            for i in range(self.num_robots):
                all_trajectories[i] = self.initialize_straight_trajectory(i)
        else:
            all_trajectories = self.previous_trajectories.copy()
        
        # Adaptive iterations based on proximity to gates
        close_to_gates = False
        positions = self.get_drone_positions()
        for i in range(self.num_robots):
            gate_idx = self.current_gate_indices[i]
            dist = np.linalg.norm(positions[i] - self.gate_positions[gate_idx])
            if dist < 2.0:
                close_to_gates = True
                break
        
        # Use more iterations for custom track to handle complex geometry
        iterations = 4 if close_to_gates else 3  # Increased iterations for custom track
        
        # SE-IBR iterations with performance optimization
        for iteration in range(iterations):
            # Sequential optimization
            for i in range(self.num_robots):
                # Optimize for all drones in custom track (more complex navigation)
                if iteration == 0 or len(self.neighbors[i]) > 0 or close_to_gates:
                    all_trajectories[i] = self.optimize_trajectory(i, all_trajectories)
        
        # Apply custom track adjustments for all drones
        for i in range(self.num_robots):
            # Fix: Pass a single trajectory, not the whole array
            adjusted_trajectory = self.adjust_for_track_geometry(i, all_trajectories[i])
            all_trajectories[i] = adjusted_trajectory
        
        # Store for next warm-start
        self.previous_trajectories = all_trajectories.copy()
        
        end_time = time.time()
        if hasattr(self, '_step_counter'):
            self._step_counter += 1
        else:
            self._step_counter = 0
            
        if self._step_counter % 10 == 0 and verbose:
            print(f"Planning took {end_time - start_time:.4f} seconds")
        
        return {
            'trajectories': all_trajectories,
            'computation_time': end_time - start_time
        }
    
    def execute_step(self, trajectories, sim_time=None, verbose=False):
        """Execute step with fine-tuned parameters for custom track"""
        if verbose:
            print(f"Process {os.getpid()}: Starting execute_step at t={sim_time}")
            print(f"Initial drone positions: {[self.drones[i].state[0:3] for i in range(self.num_robots)]}")
        
        # Store initial positions for validation
        starting_positions = [self.drones[i].state[0:3].copy() for i in range(self.num_robots)]
        
        # Compute velocity commands
        velocity_commands = np.zeros((self.num_robots, 3))
        for i in range(self.num_robots):
            if trajectories.shape[1] > 1:
                velocity_commands[i] = (trajectories[i, 1] - trajectories[i, 0]) / self.dt
        
        # Apply more adaptive gate targeting
        positions = self.get_drone_positions()
        for i in range(self.num_robots):
            gate_idx = self.current_gate_indices[i]
            gate_pos = self.gate_positions[gate_idx]
            gate_yaw = self.gate_yaws[gate_idx]
            
            # Gate normal direction (adjusted for custom track)
            gate_normal = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
            
            # Vector from drone to gate
            to_gate = gate_pos - positions[i]
            dist_to_gate = np.linalg.norm(to_gate)
            
            # Check if we're behind the gate (dot product with normal is positive)
            behind_gate = np.dot(to_gate, gate_normal) > 0
            
            # Calculate vertical component separately
            height_diff = gate_pos[2] - positions[i][2]
            
            # Fine-tuned boosting parameters for custom track:
            if behind_gate:
                # Adaptive boost parameters based on gate geometry
                boost_start_dist = 5.0  # Start boosting earlier
                boost_factor_base = 1.5  # Base boost factor
                
                if dist_to_gate < boost_start_dist:
                    vel_norm = np.linalg.norm(velocity_commands[i])
                    
                    if vel_norm > 0.1:
                        vel_dir = velocity_commands[i] / vel_norm
                        
                        # More permissive alignment requirement for custom track
                        required_alignment = min(0.4, 0.2 + 0.3 * (1.0 - dist_to_gate/boost_start_dist))
                        alignment = np.dot(vel_dir, gate_normal)
                        
                        if alignment > required_alignment:
                            # Adaptive boost factor - stronger when closer
                            progress = 1.0 - dist_to_gate/boost_start_dist
                            boost_factor = boost_factor_base + 0.8 * progress * (1.0 - progress * 0.5)
                            
                            # Adjust direction with stronger height component for custom track
                            direction_weight = min(0.5, 0.2 + 0.3 * progress)
                            
                            # For custom track: add height adjustment to direction
                            adjusted_dir = vel_dir.copy()
                            if abs(height_diff) > 0.3:
                                # Add stronger vertical component when height difference is significant
                                height_factor = min(0.3, abs(height_diff) * 0.1) * np.sign(height_diff)
                                adjusted_dir[2] += height_factor
                                # Renormalize
                                adjusted_dir = adjusted_dir / np.linalg.norm(adjusted_dir)
                            
                            # Blend with gate normal direction
                            adjusted_dir = (1.0 - direction_weight) * adjusted_dir + direction_weight * gate_normal
                            adjusted_dir = adjusted_dir / np.linalg.norm(adjusted_dir)
                            
                            # Apply boosted velocity
                            velocity_commands[i] = adjusted_dir * vel_norm * boost_factor
                            
                            # Ensure we don't exceed maximum velocity with a buffer
                            for dim in range(3):
                                max_vel = self.max_velocity[dim] * 0.95  # 5% buffer
                                velocity_commands[i][dim] = np.clip(velocity_commands[i][dim], 
                                                            -max_vel, max_vel)
        
        # Execute drone dynamics with enhanced collision checking
        all_states = [[self.drones[i].state.copy()] for i in range(self.num_robots)]
        
        # Track positions for collision checking
        timestep_positions = []
        timestep_velocities = []
        
        # Run dynamics steps
        for step in range(self.dynamics_steps):
            step_positions = []
            step_velocities = []
            
            # Update all drones
            for i in range(self.num_robots):
                # More cautious velocity control when drones are close
                adjusted_velocity = velocity_commands[i].copy()
                
                # Check proximity to other drones
                for j in range(self.num_robots):
                    if i != j:
                        dist = np.linalg.norm(self.drones[i].state[0:3] - self.drones[j].state[0:3])
                        # If drones are close, reduce velocity
                        if dist < self.min_distance * 1.2:
                            vel_norm = np.linalg.norm(adjusted_velocity)
                            if vel_norm > 0.5:
                                # Scale velocity down when close to avoid collisions
                                reduction = max(0.5, dist / (self.min_distance * 1.5)) 
                                adjusted_velocity = adjusted_velocity * reduction
                                
                control = self.controllers[i].compute_control(adjusted_velocity)
                self.drones[i].step(control)
                
                step_positions.append(self.drones[i].state[0:3])
                step_velocities.append(self.drones[i].state[3:6])
                
                # Check for gate passing with custom track detection
                if self.check_gate_passed(i, verbose=verbose):
                    if verbose:
                        print(f"Drone {i} passed gate {self.current_gate_indices[i]}")
                    
                    # Update to next gate
                    self.current_gate_indices[i] = (self.current_gate_indices[i] + 1) % self.n_gates
                    
                    # Immediately update velocity command to target new gate
                    if step < self.dynamics_steps - 1:
                        new_gate_idx = self.current_gate_indices[i]
                        new_gate_pos = self.gate_positions[new_gate_idx]
                        new_gate_yaw = self.gate_yaws[new_gate_idx]
                        
                        # Adjusted for custom track orientation
                        gate_normal = np.array([np.cos(new_gate_yaw - np.pi/2), np.sin(new_gate_yaw - np.pi/2), 0])
                        
                        # Direct velocity toward new gate with height consideration
                        to_new_gate = new_gate_pos - self.drones[i].state[0:3]
                        dist_to_new_gate = np.linalg.norm(to_new_gate)
                        
                        if dist_to_new_gate > 0:
                            # Calculate direction with additional height emphasis for custom track
                            xy_dist = np.linalg.norm(to_new_gate[:2])
                            height_diff = to_new_gate[2]
                            
                            # Stronger vertical component for significant height differences
                            height_emphasis = 1.0
                            if abs(height_diff) > 1.0:
                                height_emphasis = 1.3  # Increase emphasis for large height changes
                            
                            # Create adjusted direction vector
                            new_dir = to_new_gate.copy()
                            new_dir[2] *= height_emphasis  # Emphasize height component
                            new_dir = new_dir / np.linalg.norm(new_dir)  # Normalize
                            
                            # Set a strong initial velocity toward the new gate
                            vel_magnitude = np.linalg.norm(velocity_commands[i])
                            velocity_commands[i] = new_dir * max(vel_magnitude, 2.0)
                            
                            # Ensure we don't exceed maximum velocity
                            for dim in range(3):
                                velocity_commands[i][dim] = np.clip(velocity_commands[i][dim], 
                                                            -self.max_velocity[dim],
                                                            self.max_velocity[dim])
            
            timestep_positions.append(np.array(step_positions))
            timestep_velocities.append(np.array(step_velocities))
        
        # Check for collisions
        collisions = self.check_for_collisions()
        
        # Apply collision recovery
        if collisions:
            if verbose:
                print(f"WARNING: {len(collisions)} collisions detected.")
            # Record collisions and apply recovery
            for i, j, _ in collisions:
                # Record the collision in a counter
                if not hasattr(self, '_collision_counters'):
                    self._collision_counters = {}
                
                key1 = (min(i, j), max(i, j))
                self._collision_counters[key1] = self._collision_counters.get(key1, 0) + 1
                
                # Apply evasive action for repeated collisions
                if self._collision_counters.get(key1, 0) > 2:
                    if verbose:
                        print(f"Repeated collisions between drones {i} and {j} - applying evasive action")
                    # Apply stronger separation force in the next planning step
                    if not hasattr(self, '_evasive_pairs'):
                        self._evasive_pairs = set()
                    self._evasive_pairs.add(key1)
        
        # Save final states
        for i in range(self.num_robots):
            all_states[i].append(self.drones[i].state.copy())
            distance_moved = np.linalg.norm(self.drones[i].state[0:3] - starting_positions[i])
            max_possible_movement = self.max_velocity.max() * self.dt
        
            if distance_moved > max_possible_movement and verbose:
                print(f"ERROR: Drone {i} moved {distance_moved:.2f}m in one step")
                print(f"From: {starting_positions[i]}")
                print(f"To: {self.drones[i].state[0:3]}")
                print(f"Max theoretical distance: {max_possible_movement:.2f}m")
        
        return all_states, collisions
    
    def check_for_collisions(self):
        """Fast collision detection using vectorized operations"""
        positions = self.get_drone_positions()
        collisions = []
        
        # Only check drone pairs that were close in the previous step
        if not hasattr(self, '_close_pairs'):
            self._close_pairs = []
            # Check all pairs initially
            for i in range(self.num_robots):
                for j in range(i+1, self.num_robots):
                    self._close_pairs.append((i, j))
        
        new_close_pairs = []
        
        # Check only previously close pairs, and add new close pairs for next time
        for i, j in self._close_pairs:
            dist = np.linalg.norm(positions[i] - positions[j])
            
            if dist < self.min_distance * self.collision_coeff:
                collisions.append((i, j, dist))
            
            # Keep tracking if still relatively close
            if dist < self.min_distance * 3.0:
                new_close_pairs.append((i, j))
        
        # Also check neighbors to ensure we don't miss any collisions
        for i in range(self.num_robots):
            for j in self.neighbors[i]:
                if i < j and (i, j) not in new_close_pairs:
                    dist = np.linalg.norm(positions[i] - positions[j])
                    
                    if dist < self.min_distance * self.collision_coeff:
                        collisions.append((i, j, dist))
                    
                    if dist < self.min_distance * 3.0:
                        new_close_pairs.append((i, j))
        
        self._close_pairs = new_close_pairs
        return collisions
    
    def compute_track_guidance(self, robot_idx, position):
        """Compute a guidance vector to keep drones on the track"""
        gate_idx = self.current_gate_indices[robot_idx]
        next_gate_idx = (gate_idx + 1) % self.n_gates
        
        # Get current and next gate positions
        current_gate = self.gate_positions[gate_idx]
        next_gate = self.gate_positions[next_gate_idx]
        
        # Vector to current gate
        to_current = current_gate - position
        dist_to_current = np.linalg.norm(to_current)
        
        # Vector to next gate
        to_next = next_gate - position
        dist_to_next = np.linalg.norm(to_next)
        
        # Calculate ideal track path (vector from current to next gate)
        track_vector = next_gate - current_gate
        track_length = np.linalg.norm(track_vector)
        
        if track_length > 0:
            track_direction = track_vector / track_length
        else:
            # Fallback if gates are at same position
            return to_current / max(0.1, dist_to_current)
        
        # Project drone position onto track line
        t = np.clip(np.dot(position - current_gate, track_direction) / track_length, 0, 1)
        closest_point = current_gate + t * track_vector
        
        # Vector from drone to closest point on track
        to_track = closest_point - position
        dist_to_track = np.linalg.norm(to_track)
        
        # Blend guidance between:
        # 1. Direction to current gate (when far from it)
        # 2. Direction along track (when between gates)
        # 3. Direction to next gate (when close to current gate)
        
        # Weighting factors
        current_weight = max(0, 1.0 - t * 2)  # Decreases as we progress along track
        track_weight = 1.0 - current_weight - max(0, (t - 0.5) * 2)  # Highest in middle
        next_weight = max(0, (t - 0.5) * 2)  # Increases as we approach next gate
        
        # If we're too far from track, increase the weight to get back on track
        if dist_to_track > 2.0:
            track_correction = to_track / dist_to_track
            track_weight = min(1.0, track_weight + 0.3 + dist_to_track * 0.1)
            current_weight *= (1 - track_weight)
            next_weight *= (1 - track_weight)
        else:
            track_correction = to_track / max(0.1, dist_to_track)
        
        # Combine the vectors with weights
        guidance = np.zeros(3)
        if dist_to_current > 0:
            guidance += current_weight * (to_current / dist_to_current)
        if dist_to_track > 0:
            guidance += track_weight * track_correction
        if dist_to_next > 0:
            guidance += next_weight * (to_next / dist_to_next)
        
        # Normalize the guidance vector
        guidance_norm = np.linalg.norm(guidance)
        if guidance_norm > 0:
            guidance = guidance / guidance_norm
        
        return guidance

    def create_video_from_frames(self, frame_directory, output_file="simulation_video.mp4", fps=10):
        """
        Create a video from a directory of frame images.
        
        Parameters:
        -----------
        frame_directory : str
            Directory containing the frame images (PNG files)
        output_file : str
            Path to save the output video
        fps : int
            Frames per second for the output video
        
        Returns:
        --------
        str
            Path to the created video file or None if failed
        """
        try:
            import cv2
            import os
            import glob
            
            # Find all frame files
            frame_files = sorted(glob.glob(os.path.join(frame_directory, "step_*.png")))
            
            if not frame_files:
                print(f"No frame files found in {frame_directory}")
                return None
            
            print(f"Creating video from {len(frame_files)} frames...")
            
            # Read first frame to get dimensions
            first_frame = cv2.imread(frame_files[0])
            height, width, layers = first_frame.shape
            
            # Create video writer
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Use mp4v codec
            video = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
            
            # Add each frame to the video
            for frame_file in frame_files:
                video.write(cv2.imread(frame_file))
            
            # Release the video writer
            video.release()
            
            print(f"Video created successfully: {output_file}")
            return output_file
        
        except ImportError:
            print("Error: OpenCV (cv2) is required to create videos.")
            print("Install it with: pip install opencv-python")
            return None
        except Exception as e:
            print(f"Error creating video: {str(e)}")
            return None

    def fix_orbit_trajectories(self, robot_idx, all_trajectories, verbose=False):
        """Enhanced trajectory fixing specifically for custom track gates"""
        # Extract just this robot's trajectory
        trajectory = all_trajectories[robot_idx].copy()
        
        gate_idx = self.current_gate_indices[robot_idx]
        gate_pos = self.gate_positions[gate_idx]
        gate_yaw = self.gate_yaws[gate_idx]
        
        # Gate normal direction (adjusted for custom track)
        gate_normal = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
        
        # Next gate information for smoother transitions
        next_gate_idx = (gate_idx + 1) % self.n_gates
        next_gate_pos = self.gate_positions[next_gate_idx]
        
        # Target 4m beyond gate for more aggressive passing
        target_beyond_gate = gate_pos + gate_normal * 4.0
        
        # For custom track - adjust gate target height slightly for better passing
        gate_center_point = gate_pos.copy()
        
        # For upward gates, aim slightly below center
        if gate_pos[2] > trajectory[0, 2]:
            gate_center_point[2] -= 0.1
        # For downward gates, aim slightly above center
        elif gate_pos[2] < trajectory[0, 2]:
            gate_center_point[2] += 0.1
        
        # Create direct trajectory with special track handling
        for t in range(1, self.N):
            if t < self.N//2:
                # First half: aim precisely at adjusted gate center point
                alpha = t / (self.N//2)
                trajectory[t] = (1 - alpha) * trajectory[0] + alpha * gate_center_point
            else:
                # Second half: continue beyond gate with height transition to next gate
                alpha = (t - self.N//2) / (self.N - self.N//2)
                base_pos = (1 - alpha) * gate_center_point + alpha * target_beyond_gate
                
                # Start transitioning height toward next gate in latter part
                if alpha > 0.5:
                    height_blend = (alpha - 0.5) * 2  # 0 to 1 in latter half
                    height_adjustment = (next_gate_pos[2] - gate_pos[2]) * height_blend * 0.3
                    base_pos[2] += height_adjustment
                
                trajectory[t] = base_pos
        
        # Apply track-specific adjustments
        adjusted_trajectory = self.adjust_for_track_geometry(robot_idx, trajectory)
        
        # Return the adjusted trajectory (don't modify all_trajectories directly)
        return adjusted_trajectory

    
    def simulate(self, num_steps=50, visualize=True, save_path=None, create_video=False, verbose=False):
        """Run the plan/execute loop for ``num_steps`` planning steps.

        Each step calls ``self.plan()``, executes the planned trajectories with
        ``self.execute_step``, counts collisions and gate passes (a drone passes a
        gate when its current gate index changes during the step), and optionally
        renders a frame. Two recovery heuristics run along the way:

        * after step 30, if no gate has been passed at all, every 10th step calls
          ``self.force_gate_progress`` (forced index changes are not counted as
          gate passes);
        * every 20th step after step 20, each drone that moved less than 0.5 m since
          the previous check gets a 0.5 m/s velocity nudge toward its current gate.
          The nudge is applied regardless of ``verbose``; only the warning message
          depends on it.

        Args:
            num_steps: Number of planning steps to simulate.
            visualize: Render a frame every other step, on the last step, and on
                any step with collisions.
            save_path: Directory where rendered frames are saved as
                ``step_XXX.png`` (only when ``visualize`` is True).
            create_video: Assemble the saved frames into ``simulation_video.mp4``
                via ``self.create_video_from_frames``.
            verbose: Print progress and diagnostic messages.

        Returns:
            dict with keys ``positions``, ``velocities`` and ``gate_indices``
            (recorded initially, then every other step plus the final step),
            ``sim_time``, ``collisions`` (list of ``(step, step_collisions)``),
            ``total_collision_count``, ``velocity_stats`` and ``gates_passed``
            (per-drone counts).
        """
        # Initialize progress tracking
        self._simulation_step = 0
        self._gates_passed_total = 0

        # Reference positions for the periodic stuck-drone check below
        last_positions = self.get_drone_positions().copy()

        # Minimal data storage (states are sub-sampled inside the loop)
        all_positions = [self.get_drone_positions()]
        all_velocities = [self.get_drone_velocities()]
        all_gate_indices = [self.current_gate_indices.copy()]
        all_collisions = []

        sim_time = 0.0
        total_collision_count = 0
        gates_passed = [0] * self.num_robots

        # Precompute straight trajectories for speed (only if the class provides it)
        if hasattr(self, 'initialize_all_straight_trajectories'):
            self.initialize_all_straight_trajectories()

        # Render every vis_interval-th step (plus the last step and collision steps)
        vis_interval = 2

        # Print initial state for debugging
        if verbose:
            positions = self.get_drone_positions()
            print("\nINITIAL STATE:")
            for i in range(self.num_robots):
                gate_idx = self.current_gate_indices[i]
                dist = np.linalg.norm(positions[i] - self.gate_positions[gate_idx])
                print(f"Drone {i}: {dist:.2f}m from gate {gate_idx}")
            print()

        saved_frames = []
        for step in range(num_steps):
            self._simulation_step = step
            if verbose:
                print(f"\n--- Simulation step {step}/{num_steps} ---")

            # Snapshot of target gates before this step, used to detect gate passes
            current_gates = self.current_gate_indices.copy()

            plan_result = self.plan()
            _, step_collisions = self.execute_step(plan_result['trajectories'], sim_time)

            if step_collisions:
                all_collisions.append((step, step_collisions))
                total_collision_count += len(step_collisions)

            # A drone has passed a gate when its target index changed during the step
            step_gates_passed = 0
            for i in range(self.num_robots):
                if self.current_gate_indices[i] != current_gates[i]:
                    gates_passed[i] += 1
                    step_gates_passed += 1
                    if verbose:
                        old_gate = current_gates[i]
                        new_gate = self.current_gate_indices[i]
                        print(f"SUCCESS! Drone {i} passed gate {old_gate}, now targeting gate {new_gate}")

            self._gates_passed_total += step_gates_passed

            # Force progress if no gate has been passed after many steps. This runs
            # after the pass check, so forced index changes are not counted.
            if step > 30 and self._gates_passed_total == 0 and step % 10 == 0:
                if verbose:
                    print("WARNING: No gates passed after many steps - forcing progress")
                self.force_gate_progress(verbose=verbose)

            # Every 20 steps, nudge drones that barely moved since the previous check
            if step > 20 and step % 20 == 0:
                positions = self.get_drone_positions()
                stuck_drones = [
                    i for i in range(self.num_robots)
                    if np.linalg.norm(positions[i] - last_positions[i]) < 0.5
                ]

                # The intervention itself must not depend on verbosity; only the message does.
                if stuck_drones:
                    if verbose:
                        print(f"WARNING: Drones {stuck_drones} appear to be stuck - applying intervention")
                    for drone_idx in stuck_drones:
                        gate_pos = self.gate_positions[self.current_gate_indices[drone_idx]]
                        to_gate = gate_pos - positions[drone_idx]
                        dist = np.linalg.norm(to_gate)
                        if dist > 0:
                            # Gentle 0.5 m/s push toward the gate, added to state[3:6]
                            # (assumed to hold the velocity components - TODO confirm
                            # against DroneDynamics)
                            self.drones[drone_idx].state[3:6] += to_gate / dist * 0.5

                last_positions = positions.copy()

            sim_time += self.dt

            # Record state (only every other step, plus the final one, to save memory)
            if step % 2 == 0 or step == num_steps - 1:
                all_positions.append(self.get_drone_positions())
                all_velocities.append(self.get_drone_velocities())
                all_gate_indices.append(self.current_gate_indices.copy())

            should_visualize = visualize and (
                step % vis_interval == 0 or step == num_steps - 1 or bool(step_collisions)
            )
            if should_visualize:
                fig = self.visualize(plan_result, sim_time=sim_time, collisions=step_collisions)
                if save_path:
                    frame_path = os.path.join(save_path, f"step_{step:03d}.png")
                    fig.savefig(frame_path, dpi=80)  # Lower dpi keeps frames small
                    saved_frames.append(frame_path)
                plt.close(fig)

        velocities_array = np.array(all_velocities)
        velocity_stats = self.calculate_mean_velocities(velocities_array)

        if verbose:
            print(f"\nSimulation completed:")
            print(f" - Time: {sim_time:.2f} seconds")
            print(f" - Total collisions: {total_collision_count}")
            print(f" - Overall mean speed: {velocity_stats['overall_mean_speed']:.2f} m/s")
            print(f" - Drone mean speeds: {[f'{v:.2f}' for v in velocity_stats['drone_mean_speeds']]}")
            print(f" - Gates passed: {gates_passed}")

        if create_video and save_path and saved_frames:
            video_path = os.path.join(save_path, "simulation_video.mp4")
            self.create_video_from_frames(save_path, video_path)

        return {
            'positions': np.array(all_positions),
            'velocities': np.array(all_velocities),
            'gate_indices': np.array(all_gate_indices),
            'sim_time': sim_time,
            'collisions': all_collisions,
            'total_collision_count': total_collision_count,
            'velocity_stats': velocity_stats,
            'gates_passed': gates_passed
        }
    
    def force_gate_progress(self, verbose=False):
        """Advance every drone's target gate by one, wrapping around after the last gate.

        Used as a development aid when drones fail to make progress on their own.

        Args:
            verbose: Print a header line and each drone's old -> new gate index.

        Returns:
            Always True.
        """
        if verbose:
            print("FORCING ALL DRONES TO NEXT GATE")
        for drone in range(self.num_robots):
            previous = self.current_gate_indices[drone]
            self.current_gate_indices[drone] = (previous + 1) % self.n_gates
            if not verbose:
                continue
            print(f"  - Drone {drone}: Gate {previous} -> Gate {self.current_gate_indices[drone]}")
        return True
    
    def calculate_mean_velocities(self, velocities):
        """Compute the mean speed of each drone and the mean of those per-drone means.

        Args:
            velocities: Array indexed ``[timestep, drone, xyz]``; only the first
                ``self.num_robots`` drone columns are read.

        Returns:
            dict with ``'drone_mean_speeds'`` (list, one mean speed magnitude per
            drone) and ``'overall_mean_speed'`` (mean of that list).
        """
        # Speed = Euclidean norm of each velocity vector, averaged over time per drone.
        per_drone = [
            np.mean(np.linalg.norm(velocities[:, drone, :], axis=1))
            for drone in range(self.num_robots)
        ]
        return {
            'drone_mean_speeds': per_drone,
            'overall_mean_speed': np.mean(per_drone),
        }
    
    def visualize(self, plan_result=None, show_history=False, sim_time=None, collisions=None):
        """Render the track, drones and planned trajectories as a 3D matplotlib figure.

        Draws the world boundary circle, dashed lines between consecutive gate
        centers, a ring per gate (gates currently targeted by at least one drone are
        drawn in red and labelled in bold), each drone's position with a dotted
        line to its target gate, and the planned trajectories downsampled to 5
        points each.

        Args:
            plan_result: Optional dict whose ``'trajectories'`` array is indexed
                ``[drone, step, xyz]``. If it also contains a
                ``'trajectory_history'`` list of such arrays with more than one
                entry and ``show_history`` is True, the first and last entries
                are drawn dashed.
            show_history: Whether to draw the trajectory history described above.
            sim_time: If given, appended to the title as ``t=...s``.
            collisions: Optional iterable of collision records. Elements 0 and 1 of
                each record are used as the indices of the drones involved
                (presumably ``(i, j, ...)`` tuples - confirm against
                ``execute_step``). Involved drones and their trajectories are
                drawn in red, with a wireframe sphere of radius
                ``min_distance / 2``.

        Returns:
            The matplotlib Figure. It is not closed here; callers should close it
            (as ``simulate`` does).
        """
        # Pick the figure size: a smaller, low-dpi figure when matplotlib is
        # non-interactive and no other figures are open, full size otherwise.
        # A figure is created in both cases.
        if not plt.isinteractive() and not plt.get_fignums():
            # Low-res mode for faster plotting
            fig = plt.figure(figsize=(8, 6), dpi=80)
        else:
            fig = plt.figure(figsize=(12, 10))
        
        ax = fig.add_subplot(111, projection='3d')
        
        # World boundary: circle of radius world_size/2 in the z=0 plane
        theta = np.linspace(0, 2*np.pi, 30)  # Reduced from 100
        x = self.world_size/2 * np.cos(theta)
        y = self.world_size/2 * np.sin(theta)
        z = np.zeros_like(theta)
        ax.plot(x, y, z, color="k", alpha=0.3)
        
        # For custom track: plot track path to show the intended route
        for i in range(self.n_gates):
            next_i = (i + 1) % self.n_gates
            gate_pos = self.gate_positions[i]
            next_gate_pos = self.gate_positions[next_i]
            
            # Draw a line connecting gate centers
            ax.plot([gate_pos[0], next_gate_pos[0]], 
                    [gate_pos[1], next_gate_pos[1]], 
                    [gate_pos[2], next_gate_pos[2]], 
                    'k--', alpha=0.2, linewidth=1)
        
        # Plot gates with fewer points
        for i, gate in enumerate(self.gates):
            center = gate["center"]
            yaw = gate["yaw"]
            
            # Create a circle with fewer points
            gate_radius = 1.0
            theta = np.linspace(0, 2*np.pi, 15)  # Reduced from 30
            circle_x = gate_radius * np.cos(theta)
            circle_y = gate_radius * np.sin(theta)
            circle_z = np.zeros_like(theta)
            
            # Rotate and translate - adjusted for custom track orientation
            # NOTE(review): the rotation stays in the XY plane and circle_z is
            # constant, so every gate renders as a horizontal ring at its center
            # height regardless of yaw - confirm whether a vertical ring was intended.
            rot_x = np.cos(yaw - np.pi/2) * circle_x - np.sin(yaw - np.pi/2) * circle_y
            rot_y = np.sin(yaw - np.pi/2) * circle_x + np.cos(yaw - np.pi/2) * circle_y
            
            gate_x = center[0] + rot_x
            gate_y = center[1] + rot_y
            gate_z = center[2] + circle_z
            
            # Active gates (targeted by at least one drone) are drawn red and thicker
            if any(current_idx == i for current_idx in self.current_gate_indices):
                gate_color = 'r'
                alpha = 0.8
                linewidth = 2.5
            else:
                gate_color = 'orange'
                alpha = 0.5
                linewidth = 1.5
            
            ax.plot(gate_x, gate_y, gate_z, color=gate_color, linestyle='-', 
                    linewidth=linewidth, alpha=alpha)
            
            # Label gates with different styles based on activity
            if any(current_idx == i for current_idx in self.current_gate_indices):
                ax.text(center[0], center[1], center[2]+1.0, f"G{i}", 
                        color='red', fontsize=10, ha='center', weight='bold')
            else:
                ax.text(center[0], center[1], center[2]+0.8, f"G{i}", 
                        color='darkred', fontsize=8, ha='center', alpha=0.7)
        
        # Identify drones involved in collisions (elements 0 and 1 of each record)
        collided_drones = set()
        if collisions:
            for collision in collisions:
                collided_drones.add(collision[0])
                collided_drones.add(collision[1])
        
        # Plot drone positions
        positions = self.get_drone_positions()
        drone_colors = ['b', 'g', 'c', 'm', 'y']  # cycled when there are more than 5 drones
        
        for i in range(self.num_robots):
            color = drone_colors[i % len(drone_colors)]
            current_gate_idx = self.current_gate_indices[i]
            
            marker_size = 100
            if i in collided_drones:
                color = 'red'
                marker_size = 200
            
            ax.scatter(positions[i, 0], positions[i, 1], positions[i, 2], 
                    color=color, s=marker_size, label=f'D{i} (G{current_gate_idx})')
            
            # Only draw safety radius for collided drones
            if i in collided_drones:
                # Use fewer points for the sphere
                u, v = np.mgrid[0:2*np.pi:10j, 0:np.pi:5j]
                x = positions[i, 0] + self.min_distance/2 * np.cos(u) * np.sin(v)
                y = positions[i, 1] + self.min_distance/2 * np.sin(u) * np.sin(v)
                z = positions[i, 2] + self.min_distance/2 * np.cos(v)
                ax.plot_wireframe(x, y, z, color='red', alpha=0.2)
            
            # Draw line to current gate
            current_gate = self.gate_positions[current_gate_idx]
            ax.plot([positions[i, 0], current_gate[0]], 
                    [positions[i, 1], current_gate[1]], 
                    [positions[i, 2], current_gate[2]], 
                    color=color, linestyle=':', alpha=0.5)
        
        # Plot planned trajectories
        if plan_result is not None:
            trajectories = plan_result['trajectories']
            
            # Show trajectory history if requested (first and last iteration only)
            if show_history and 'trajectory_history' in plan_result and len(plan_result['trajectory_history']) > 1:
                history = plan_result['trajectory_history']
                history_to_show = [history[0], history[-1]]
                for iter_idx, iter_trajectories in enumerate(history_to_show):
                    alpha = 0.3 if iter_idx == 0 else 0.7
                    for i in range(self.num_robots):
                        color = drone_colors[i % len(drone_colors)]
                        # Downsample to 5 evenly spaced steps for faster plotting
                        idx = np.linspace(0, iter_trajectories.shape[1]-1, 5).astype(int)
                        ax.plot(iter_trajectories[i, idx, 0], iter_trajectories[i, idx, 1], 
                                iter_trajectories[i, idx, 2], color=color, alpha=alpha, linestyle='--')
            
            # Plot final trajectories (downsampled)
            for i in range(self.num_robots):
                color = drone_colors[i % len(drone_colors)]
                if i in collided_drones:
                    color = 'red'
                # Downsample to 5 evenly spaced steps for faster plotting
                idx = np.linspace(0, trajectories.shape[1]-1, 5).astype(int)
                ax.plot(trajectories[i, idx, 0], trajectories[i, idx, 1], trajectories[i, idx, 2], 
                        color=color, linewidth=2)
        
        # Set plot properties
        title = 'Multi-Drone Custom Track Navigation'
        if sim_time is not None:
            title += f' t={sim_time:.1f}s'
        if collisions:
            title += f' - {len(collisions)} COLLISIONS'
                
        ax.set_title(title)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        
        # Symmetric XY bounds (and 0..max for Z) at 1.2x the largest absolute gate coordinate
        max_bounds = np.max(np.abs(self.gate_positions)) * 1.2
        ax.set_xlim(-max_bounds, max_bounds)
        ax.set_ylim(-max_bounds, max_bounds)
        ax.set_zlim(0, max_bounds)
        
        # Use a smaller legend
        ax.legend(fontsize='small', loc='upper right')
        
        return fig


class TrajectoryPredictor:
    """Minimal heuristic trajectory predictor (no neural network / TensorFlow).

    Predicts ``traj_dim // 3`` future 3D positions. The path starts by
    extrapolating the current velocity, then blends onto the straight line from
    the current position to the target. The first point is also pushed away from
    nearby neighbors.
    """

    def __init__(self, state_dim, traj_dim, step_dt=0.2, repulsion_radius=3.0, repulsion_gain=0.2):
        """
        Args:
            state_dim: Size of the input state vector (stored; not used by ``predict``).
            traj_dim: Flattened trajectory length; ``traj_dim // 3`` positions are predicted.
            step_dt: Time (s) used to extrapolate the velocity per step (default 0.2,
                the previously hard-coded value).
            repulsion_radius: Neighbors closer than this distance (m) push the first point away.
            repulsion_gain: Offset (m) added per close neighbor, along the unit vector
                pointing away from it.
        """
        self.state_dim = state_dim
        self.traj_dim = traj_dim
        self.output_shape = (-1, 3)
        self.step_dt = step_dt
        self.repulsion_radius = repulsion_radius
        self.repulsion_gain = repulsion_gain

    def predict(self, state, target, neighbors):
        """Predict a short trajectory from ``state`` toward ``target``.

        Args:
            state: Sequence/array whose entries 0..2 are the position and 3..5 the velocity.
            target: 3D target position (list or array).
            neighbors: Iterable of neighbor states (entries 0..2 = position), e.g. a
                list of arrays or an ``(n, >=3)`` array; ``None`` or empty for none.

        Returns:
            ndarray of shape ``(traj_dim // 3, 3)`` with the predicted positions.
        """
        state = np.asarray(state, dtype=float)
        target = np.asarray(target, dtype=float)

        n_steps = self.traj_dim // 3
        traj = np.zeros((n_steps, 3))

        pos = state[:3]
        vel = state[3:6]

        # Small repulsive offset away from close neighbors (coincident ones are skipped).
        # Test against None explicitly: ``if neighbors:`` is ambiguous for ndarrays.
        adjust = np.zeros(3)
        if neighbors is not None:
            for n_state in neighbors:
                diff = pos - np.asarray(n_state, dtype=float)[:3]
                dist = np.linalg.norm(diff)
                if 0 < dist < self.repulsion_radius:
                    adjust += diff / dist * self.repulsion_gain

        for i in range(n_steps):
            t = i / (n_steps - 1) if n_steps > 1 else 0

            if i == 0:
                # First step follows current velocity with the neighbor adjustment
                traj[i] = pos + vel * self.step_dt + adjust
            else:
                # Blend a fading velocity projection with the straight line to the target
                vel_proj = traj[i - 1] + vel * self.step_dt * (1 - t)
                target_proj = target * t + pos * (1 - t)
                blend = min(1.0, i / 2)  # Fully on the straight line from step 2 on
                traj[i] = vel_proj * (1 - blend) + target_proj * blend

        return traj


def run_simulation(num_robots=3, num_steps=50, visualize=True, save_path=None):
    """Run one simulation episode and print summary statistics.

    The simulator is cached on the function object (``run_simulation._sim``) and
    reused across calls via ``reset()``, so the constructor does not run for every
    episode. The cache is rebuilt when ``num_robots`` differs from the cached
    instance's, because ``reset()`` cannot change the number of drones.

    Args:
        num_robots: Number of drones.
        num_steps: Number of planning steps to simulate.
        visualize: Render frames during the simulation.
        save_path: Directory for frames/video; created if missing.

    Returns:
        The results dict returned by ``simulate``.
    """
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    sim = getattr(run_simulation, '_sim', None)
    if sim is None or getattr(sim, 'num_robots', None) != num_robots:
        print("Initializing simulation...")
        sim = SEIBR_DroneNavigation_FigEight(
            num_robots=num_robots,
            planning_horizon=10,
            dt=0.1,
            max_iterations=5,  # we'll only use 2 in practice
            dynamics_dt=0.01,
        )
        run_simulation._sim = sim
    else:
        # Reuse the existing simulator with a fresh initial configuration
        sim.reset()

    results = sim.simulate(num_steps=num_steps, visualize=visualize, save_path=save_path, create_video=True)

    # Print summary statistics
    print("\nSimulation Results:")
    print(f"Total simulation time: {results['sim_time']:.2f} seconds")
    print(f"Total collisions: {results['total_collision_count']}")
    print(f"Gates passed by each drone: {results['gates_passed']}")
    print(f"Mean speeds: {[f'{v:.2f} m/s' for v in results['velocity_stats']['drone_mean_speeds']]}")
    print(f"Overall mean speed: {results['velocity_stats']['overall_mean_speed']:.2f} m/s")

    return results

def evaluate_algorithm(num_robots=3, num_episodes=10, num_steps_per_episode=50, 
                       visualize=False, save_path=None):
    """
    Evaluate algorithm performance across multiple episodes and calculate statistics.

    Each episode calls ``run_simulation``. The per-episode metrics are:

    * reward: total gates passed, minus 1 if any collision occurred;
    * speed: overall mean speed;
    * targets reached: gates passed per drone;
    * collision flag: 1 if any collision occurred, else 0.

    Args:
        num_robots: Number of drones to simulate
        num_episodes: Number of episodes to run
        num_steps_per_episode: Number of steps per episode
        visualize: Whether to visualize the simulation
        save_path: Directory for per-episode outputs (``episode_<k>`` subfolders)
            and an appended ``results.txt`` summary. When None, nothing is written
            to disk.

    Returns:
        Dictionary with the per-episode metric lists and their means and
        standard deviations.
    """
    rewards = []
    speeds = []
    targets_reached = []
    collision_flags = []

    for episode in range(num_episodes):
        # Create episode-specific save path if needed
        episode_save_path = None
        if save_path:
            episode_save_path = os.path.join(save_path, f"episode_{episode+1}")
            os.makedirs(episode_save_path, exist_ok=True)

        print(f"\n=== Running Episode {episode+1}/{num_episodes} ===")

        results = run_simulation(num_robots=num_robots,
                                 num_steps=num_steps_per_episode,
                                 visualize=visualize,
                                 save_path=episode_save_path)

        gates_total = sum(results['gates_passed'])
        # Binary collision flag: 1 if any collision occurred, 0 otherwise
        episode_collision = 1 if results['total_collision_count'] > 0 else 0
        # Simple reward: gates passed minus a unit collision penalty
        episode_reward = gates_total - episode_collision
        episode_speed = results['velocity_stats']['overall_mean_speed']
        # Average targets/gates reached per drone
        episode_targets = gates_total / num_robots

        rewards.append(episode_reward)
        speeds.append(episode_speed)
        targets_reached.append(episode_targets)
        collision_flags.append(episode_collision)

        print(f"Episode {episode+1} Results:")
        print(f"  Reward: {episode_reward:.2f}")
        print(f"  Speed: {episode_speed:.2f} m/s")
        print(f"  Average Targets Reached: {episode_targets:.2f}")
        print(f"  Collision Flag: {episode_collision}")

    # Statistics across all episodes
    mean_reward = np.mean(rewards)
    mean_speed = np.mean(speeds)
    std_speed = np.std(speeds)
    mean_targets_reached = np.mean(targets_reached)
    std_targets_reached = np.std(targets_reached)
    mean_collisions = np.mean(collision_flags)
    std_collisions = np.std(collision_flags)

    print(f"\nEvaluation results: Mean Reward={mean_reward:.2f}, Mean Speed={mean_speed:.2f}/±{std_speed:.2f}, "
          f"Mean Targets Reached={mean_targets_reached:.2f}/±{std_targets_reached:.2f}, "
          f"Mean Collisions={mean_collisions:.2f}/±{std_collisions:.2f}")

    # Only write the summary file when an output directory was requested;
    # os.path.join(None, ...) would raise TypeError for the default save_path.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        results_txt_path = os.path.join(save_path, "results.txt")
        # Explicit UTF-8: the summary contains the non-ASCII "±" character
        with open(results_txt_path, "a", encoding="utf-8") as f:
            f.write(f"mean_targets_reached: {mean_targets_reached:.2f}/±{std_targets_reached:.2f}\n")
            f.write(f"mean_velocity: {mean_speed:.2f}/±{std_speed:.2f}\n")
            f.write(f"mean_collision: {mean_collisions:.2f}/±{std_collisions:.2f}\n")
        print("  Saved results.txt.")

    return {
        'rewards': rewards,
        'mean_reward': mean_reward,
        'speeds': speeds,
        'mean_speed': mean_speed,
        'std_speed': std_speed,
        'targets_reached': targets_reached,
        'mean_targets_reached': mean_targets_reached,
        'std_targets_reached': std_targets_reached,
        'collision_flags': collision_flags,
        'mean_collisions': mean_collisions,
        'std_collisions': std_collisions
    }

if __name__ == "__main__":
    # Evaluate a 4-drone swarm over 10 episodes of 500 planning steps each,
    # rendering frames and saving outputs under the given directory.
    results = evaluate_algorithm(
        num_robots=4,
        num_episodes=10,
        num_steps_per_episode=500,
        visualize=True,
        save_path="4drones_fig_eight_results",
    )

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time
from scipy.optimize import minimize
from joblib import Parallel, delayed
import os
from drone_dynamics import *


class SEIBR_DroneNavigation_FigEight:
    """Implementation of SE-IBR algorithm for multi-drone gate navigation"""
    
    def __init__(self, num_robots=3, planning_horizon=10, dt=0.2, max_iterations=5, dynamics_dt=0.01, n_gates=5, radius=6.0):
        """Initialize the multi-drone navigation system.

        Args:
            num_robots: Number of drones.
            planning_horizon: Number of steps per planned trajectory (``self.N``).
            dt: Planning timestep (s).
            max_iterations: Maximum number of SE-IBR iterations.
            dynamics_dt: Integration timestep of the drone dynamics (s).
            n_gates: Stored as ``self.n_gates``, but ``initialize_gates`` then
                overwrites it with the size of its fixed layout (6), so this
                argument currently has no lasting effect.
            radius: Sets ``self.world_size = 2 * radius``; the gate layout is
                fixed and does not depend on it.
        """
        # System parameters
        self.num_robots = num_robots
        self.N = planning_horizon  # Planning horizon
        self.dt = dt  # Planning timestep
        self.max_iterations = max_iterations  # Max iterations for SE-IBR
        self.dynamics_dt = dynamics_dt  # Dynamics timestep
        # Number of dynamics steps per planning step. Round before converting:
        # the float ratio can land just below an integer (int(0.3 / 0.1) == 2),
        # which would silently drop a dynamics step from every planning step.
        self.dynamics_steps = int(round(dt / dynamics_dt))

        # World parameters
        self.world_size = radius * 2  # Size of world (m)
        self.n_gates = n_gates  # Overwritten by initialize_gates()

        # Robot parameters
        self.robot_radius = 0.3  # robot radius (m)
        self.min_distance = 0.8  # minimum safe distance between robots (m)
        self.max_velocity = np.array([3.0, 3.0, 2.0])  # max velocity (m/s)
        self.gate_passing_tolerance = 0.5  # Distance to consider gate passed (m)
        self.collision_coeff = 0.32

        # Track parameters
        self.track_width = 2.0  # Width of track for progress calculation
        self.max_track_deviation = 12.0  # Maximum allowed deviation from track

        # SE-IBR parameters
        self.sensitivity_alpha = 0.5  # Sensitivity parameter for SE-IBR
        self.alpha_decay = 0.9  # Decay rate for sensitivity parameter
        self.lagrange_multipliers = {}  # Lagrange multipliers for constraints
        self.prev_lagrange_multipliers = {}  # Previous Lagrange multipliers

        # Initialize drones, controllers, and gates
        self.drones = []
        self.controllers = []

        for i in range(num_robots):
            drone = DroneDynamics(dt=dynamics_dt)
            low_level_controller = LowLevelController(drone)
            velocity_tracker = VelocityTracker(drone, low_level_controller)
            self.drones.append(drone)
            self.controllers.append(velocity_tracker)

        self.initialize_gates()

        # Initialize trajectory predictors for warm-starting
        state_dim = 6 + 3 + 6  # drone state (pos+vel) + target + 2 neighbors
        traj_dim = planning_horizon * 3  # 3D positions over planning horizon
        self.predictors = [TrajectoryPredictor(state_dim, traj_dim) for _ in range(num_robots)]

        # Initialize other variables
        self.neighbors = [[] for _ in range(num_robots)]  # List of neighbors for each robot
        self.current_gate_indices = np.zeros(num_robots, dtype=int)  # Current gate index for each robot

        # Track-specific data
        self.track_segments = []
        self.initialize_track_segments()

        # Reset to initial configuration
        self.reset()
    
    def initialize_gates(self):
        """Build the fixed six-gate layout of the custom track.

        Sets ``self.n_gates``, ``self.gate_width``, ``self.gate_height``, the list of
        gate dicts ``self.gates`` (keys ``center``, ``yaw``, ``width``, ``height``),
        and the stacked arrays ``self.gate_positions`` (float32, one row per gate)
        and ``self.gate_yaws``. Each stored yaw is the nominal facing angle plus
        pi/2.
        """
        # (center [x, y, z], nominal facing yaw in radians), in track order
        layout = (
            ([5.0, -5, 1], np.pi/2),        # Gate 1 (right), facing +y
            ([10.0, 0, 2], 0),              # Gate 2 (back right), facing +x
            ([5.0, 5, 2], -np.pi/2),        # Gate 3 (back), facing -y
            ([-5.0, -5, 0.5], -np.pi/2),    # Gate 4 (middle), facing -y
            ([-10.0, 0, 0.5], 0),           # Gate 5 (front left), facing +x
            ([-5.0, 5, 1.5], np.pi/2),      # Gate 6 (left), facing +y
        )
        self.n_gates = len(layout)
        self.gate_width = 2.0
        self.gate_height = 2.0

        self.gates = [
            {
                "center": np.array(center, dtype=np.float32),
                "yaw": facing + np.pi/2,  # Maintain the original adjustment
                "width": self.gate_width,
                "height": self.gate_height,
            }
            for center, facing in layout
        ]

        self.gate_positions = np.array([entry["center"] for entry in self.gates])
        self.gate_yaws = np.array([entry["yaw"] for entry in self.gates])
    
    def initialize_track_segments(self):
        """Build ``self.track_segments``: one dict per gate-to-next-gate segment.

        Segment ``i`` runs from gate ``i`` to gate ``(i + 1) % n_gates``; the last
        segment closes the loop back to gate 0. Each dict stores the endpoint
        positions and yaws, the raw ``segment_vector``, its 3D unit
        ``segment_direction`` (all zeros when both gates share x/y), a horizontal
        ``segment_normal`` perpendicular to that direction (not unit length when
        the gates differ in height), the gate facing vectors
        ``start_normal``/``end_normal``, the 3D ``segment_length``, the
        ``height_diff`` and ``is_sharp_turn``.

        ``is_sharp_turn`` describes the turn at the segment's start gate. It is
        True when the previous segment has a non-zero direction and the dot
        product of the x/y components of the previous and current directions is
        below 0.5. Because the track is a closed loop, segment 0's previous
        segment is the last one.
        """
        self.track_segments = []

        # First pass: geometry of every segment
        for i in range(self.n_gates):
            next_i = (i + 1) % self.n_gates

            segment_vector = self.gate_positions[next_i] - self.gate_positions[i]
            xy_length = np.linalg.norm(segment_vector[:2])

            if xy_length > 0:
                segment_direction = segment_vector / np.linalg.norm(segment_vector)
                segment_normal = np.array([-segment_direction[1], segment_direction[0], 0])
            else:
                # Gates vertically aligned: no horizontal direction, use a default normal
                segment_direction = np.zeros(3)
                segment_normal = np.array([0, 1, 0])

            # Gate facing directions in the XY plane (stored yaw minus pi/2)
            start_normal = np.array([np.cos(self.gate_yaws[i] - np.pi/2),
                                     np.sin(self.gate_yaws[i] - np.pi/2), 0])
            end_normal = np.array([np.cos(self.gate_yaws[next_i] - np.pi/2),
                                   np.sin(self.gate_yaws[next_i] - np.pi/2), 0])

            self.track_segments.append({
                "start_gate": i,
                "end_gate": next_i,
                "start_pos": self.gate_positions[i],
                "end_pos": self.gate_positions[next_i],
                "start_yaw": self.gate_yaws[i],
                "end_yaw": self.gate_yaws[next_i],
                "segment_vector": segment_vector,
                "segment_direction": segment_direction,
                "segment_normal": segment_normal,
                "start_normal": start_normal,
                "end_normal": end_normal,
                "segment_length": np.linalg.norm(segment_vector),
                "height_diff": self.gate_positions[next_i][2] - self.gate_positions[i][2],
                "is_sharp_turn": False,  # Set in the second pass
            })

        # Second pass: turn at each segment's start gate. Index i - 1 wraps to the
        # last segment for i == 0, so the loop-closing turn is attributed to
        # segment 0 instead of overwriting the last segment's own flag.
        for i, segment in enumerate(self.track_segments):
            prev_direction = self.track_segments[i - 1]["segment_direction"]
            if np.linalg.norm(prev_direction) > 0:
                dot_product = np.dot(prev_direction[:2], segment["segment_direction"][:2])
                segment["is_sharp_turn"] = bool(dot_product < 0.5)
    
    def reset(self):
        """Reset drone positions with better initial placement for the custom track"""
        # Safety parameters
        safety_distance_gates = 0.5  # Safe distance from gates
        safety_distance_drones = 0.8  # Minimum distance between drones
        max_attempts = 50     # Maximum attempts to find safe positions
        
        # Preferentially initialize drones near the first gate
        first_gate_pos = self.gate_positions[0]  # Gate 1 
        first_gate_yaw = self.gate_yaws[0]
        
        # Calculate gate normal (direction the gate is facing)
        gate_normal = np.array([np.cos(first_gate_yaw - np.pi/2), np.sin(first_gate_yaw - np.pi/2), 0])
        
        # Position starting area a bit before the first gate
        start_center = first_gate_pos - gate_normal * 2.5  # 2.5m before first gate
        
        # Sample positions for all drones around this starting area
        positions = []
        
        for drone_idx in range(self.num_robots):
            attempts = 0
            position_found = False
            
            # Try to find a safe position for this drone
            while attempts < max_attempts and not position_found:
                attempts += 1
                
                # Sample random position around start center
                offset = np.random.uniform(-1.5, 1.5, size=3)
                offset[2] = abs(offset[2]) * 0.3  # Smaller vertical variation
                candidate_position = start_center + offset
                candidate_position[2] = max(0.5, candidate_position[2])  # Ensure minimum height
                candidate_position = candidate_position.astype(np.float32)
                
                # Verify it's safe from gates
                safe_from_gates = True
                for gate_pos in self.gate_positions:
                    gate_pos_array = np.array(gate_pos)
                    distance = np.linalg.norm(candidate_position - gate_pos_array)
                    if distance < safety_distance_gates:
                        safe_from_gates = False
                        break
                
                if not safe_from_gates:
                    continue  # Try another position
                
                # Verify it's safe from other drones
                safe_from_drones = True
                for existing_pos in positions:
                    distance = np.linalg.norm(candidate_position - existing_pos)
                    if distance < safety_distance_drones:
                        safe_from_drones = False
                        break
                
                if not safe_from_drones:
                    continue  # Try another position
                
                # If we get here, the position is safe
                position_found = True
                positions.append(candidate_position)
            
            # If we couldn't find a safe position after max attempts, use a backup method
            if not position_found:
                # Fallback: place drone at a slightly random position near start center
                backup_position = start_center + np.random.uniform(-0.5, 0.5, size=3)
                backup_position[2] = max(0.5, backup_position[2])  # Ensure minimum height
                positions.append(backup_position.astype(np.float32))
                print(f"Warning: Using backup position for drone {drone_idx}")
        
        # Reset drone states and set current gate indices
        for i in range(self.num_robots):
            # Reset drone state
            self.drones[i].reset(positions[i])
            
            # Set current gate index
            self.current_gate_indices[i] = 0
        
        # Update neighbors
        self.find_neighbors()
        
        # Return positions for possible visualization or logging
        return positions
    
    def find_neighbors(self):
        """Find neighbors for each robot based on proximity"""
        positions = self.get_drone_positions()
        
        # Clear current neighbors
        self.neighbors = [[] for _ in range(self.num_robots)]
        
        # Find neighbors within a certain distance
        neighbor_distance = 4.0  # Consider robots within 4 meters as potential neighbors
        
        for i in range(self.num_robots):
            for j in range(self.num_robots):
                if i != j:  # Don't include self
                    dist = np.linalg.norm(positions[i] - positions[j])
                    if dist < neighbor_distance:
                        self.neighbors[i].append(j)
    
    def get_drone_positions(self):
        """Get current positions of all drones"""
        positions = np.zeros((self.num_robots, 3))
        for i in range(self.num_robots):
            positions[i] = self.drones[i].state[0:3]
        return positions
    
    def get_drone_velocities(self):
        """Get current velocities of all drones"""
        velocities = np.zeros((self.num_robots, 3))
        for i in range(self.num_robots):
            velocities[i] = self.drones[i].state[3:6]
        return velocities
    
    def compute_track_guidance(self, robot_idx, position):
        """Compute a steering direction that keeps a drone on the track.

        The segment used is the entry of ``self.track_segments`` whose
        ``start_gate`` is the robot's current gate and whose ``end_gate`` is
        ``(gate + 1) % self.n_gates``. Let ``t`` in [0, 1] be the drone's
        clipped projection along that segment. The result blends:

        * the direction toward the segment start (current gate), weight
          ``max(0, 1 - 2t)``;
        * the direction toward the segment end (next gate), weight
          ``max(0, 2t - 1)``;
        * a "track correction" toward the closest point on the segment, weight
          equal to the remainder (boosted when more than 2 m off track);
        * a small vertical nudge toward the segment end's height, scaled by
          ``1 - current_weight``.

        Args:
            robot_idx: Index into ``self.current_gate_indices``.
            position: Drone position; presumably a length-3 numpy array
                (x, y, z) -- TODO confirm with callers.

        Returns:
            A length-3 array, normalised to unit length when non-zero (a zero
            vector is returned as-is). If no matching segment exists, returns
            the direction toward the current gate divided by
            ``max(0.1, distance)``, which is shorter than unit length when the
            drone is within 0.1 m of the gate.
        """
        gate_idx = self.current_gate_indices[robot_idx]
        next_gate_idx = (gate_idx + 1) % self.n_gates
        
        # Get current segment (linear search over the precomputed segment dicts)
        segment = None
        for seg in self.track_segments:
            if seg["start_gate"] == gate_idx and seg["end_gate"] == next_gate_idx:
                segment = seg
                break
        
        if segment is None:
            # Fallback if segment not found: head straight for the current gate.
            # max(0.1, ...) avoids division by zero at the gate centre.
            current_gate = self.gate_positions[gate_idx]
            to_current = current_gate - position
            dist_to_current = np.linalg.norm(to_current)
            return to_current / max(0.1, dist_to_current)
        
        # Get segment data
        start_pos = segment["start_pos"]
        end_pos = segment["end_pos"]
        segment_direction = segment["segment_direction"]  # assumed unit length -- TODO confirm
        segment_length = segment["segment_length"]
        
        # Vector to current gate
        to_current = start_pos - position
        dist_to_current = np.linalg.norm(to_current)
        
        # Vector to next gate
        to_next = end_pos - position
        dist_to_next = np.linalg.norm(to_next)
        
        # Project drone position onto track line; t is the fraction of the
        # segment covered, clipped to [0, 1]
        if segment_length > 0:
            t = np.clip(np.dot(position - start_pos, segment_direction) / segment_length, 0, 1)
            closest_point = start_pos + t * segment_length * segment_direction
        else:
            t = 0
            closest_point = start_pos
        
        # Vector from drone to closest point on track
        to_track = closest_point - position
        dist_to_track = np.linalg.norm(to_track)
        
        # Blend guidance between:
        # 1. Direction to current gate (early in the segment)
        # 2. Direction along track (peaks at the segment middle)
        # 3. Direction to next gate (late in the segment)
        
        # Weighting factors (sum to 1 before the adjustments below)
        current_weight = max(0, 1.0 - t * 2)  # Decreases as we progress along track
        track_weight = 1.0 - current_weight - max(0, (t - 0.5) * 2)  # Highest in middle
        next_weight = max(0, (t - 0.5) * 2)  # Increases as we approach next gate
        
        # If we're more than 2 m off the track, boost the track term and scale
        # the gate terms down by (1 - track_weight)
        if dist_to_track > 2.0:
            track_correction = to_track / dist_to_track
            track_weight = min(1.0, track_weight + 0.3 + dist_to_track * 0.1)
            current_weight *= (1 - track_weight)
            next_weight *= (1 - track_weight)
        else:
            # Within 0.1 m of the line the correction shrinks below unit length
            track_correction = to_track / max(0.1, dist_to_track)
        
        # Special handling for sharp turns
        if segment["is_sharp_turn"]:
            # Increase influence of next gate for sharper guidance around turns
            next_weight = max(next_weight, t * 0.5)
            
            # Mid-segment, bias the correction along segment_normal (intended as
            # an outward bias to widen the turn; normal's orientation is set
            # where track_segments is built -- confirm there)
            if t > 0.3 and t < 0.8:
                outward_bias = segment["segment_normal"] * 0.5
                track_correction = (track_correction + outward_bias)
                track_correction = track_correction / np.linalg.norm(track_correction)
        
        # Height guidance toward the segment end's height, magnitude capped at 0.3
        height_diff = end_pos[2] - position[2]
        height_guidance = np.array([0, 0, np.sign(height_diff) * min(0.3, abs(height_diff) * 0.2)])
        
        # Combine the vectors with weights
        # NOTE(review): when the drone lies exactly on the segment
        # (dist_to_track == 0) the track term, including any sharp-turn bias,
        # is skipped entirely.
        guidance = np.zeros(3)
        if dist_to_current > 0:
            guidance += current_weight * (to_current / dist_to_current)
        if dist_to_track > 0:
            guidance += track_weight * track_correction
        if dist_to_next > 0:
            guidance += next_weight * (to_next / dist_to_next)
        
        # Add height guidance
        guidance += height_guidance * (1.0 - current_weight)  # More height guidance as we progress
        
        # Normalize the guidance vector
        guidance_norm = np.linalg.norm(guidance)
        if guidance_norm > 0:
            guidance = guidance / guidance_norm
        
        return guidance
    
    def create_gate_reference_frame(self, gate_position, gate_yaw):
        """Create a reference frame for measuring progress through a gate with caching"""
        # Check if we have a cache for reference frames
        if not hasattr(self, 'gate_frame_cache'):
            self.gate_frame_cache = {}
        
        # Check if this frame is already in the cache
        cache_key = (tuple(gate_position), gate_yaw)
        if cache_key in self.gate_frame_cache:
            return self.gate_frame_cache[cache_key]
        
        # Create the reference frame
        # Adjust for custom track's orientation convention
        forward = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
        lateral = np.array([-np.sin(gate_yaw - np.pi/2), np.cos(gate_yaw - np.pi/2), 0])
        vertical = np.array([0, 0, 1])
        
        frame = {
            'origin': gate_position, 
            'forward': forward, 
            'lateral': lateral, 
            'vertical': vertical
        }
        
        # Store in cache
        self.gate_frame_cache[cache_key] = frame
        
        return frame
    
    def compute_progress_to_gate(self, position, gate_position, gate_yaw):
        """Enhanced progress measurement with stronger gradients for custom track"""
        # Vector from position to gate
        to_gate = gate_position - position
        dist_to_gate = np.linalg.norm(to_gate)
        
        # Gate direction (normal to gate plane)
        # Adjust for the custom track orientation
        gate_dir = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
        
        # Project position onto gate normal direction
        projection = np.dot(to_gate, gate_dir)
        
        # Calculate lateral and vertical deviation
        gate_lateral = np.array([-np.sin(gate_yaw - np.pi/2), np.cos(gate_yaw - np.pi/2), 0])
        lateral_dev = np.abs(np.dot(to_gate, gate_lateral))
        
        # For custom track, use adaptive vertical deviation metrics based on gate height
        target_height = gate_position[2]
        height_diff = abs(position[2] - target_height)
        vertical_dev = height_diff
        
        # Custom track: stronger incentive to match gate height
        height_penalty = 1.0 * (height_diff**2) / (self.gate_height**2)
        
        # Stronger incentive to fly through the center of the gate
        deviation_penalty = 0.5 * (lateral_dev**2) / (self.track_width**2) + height_penalty
        
        # Much stronger reward for being in front of the gate
        front_reward = 3.0 * max(0, -projection)
        
        # Combined progress metric - more extreme values for clearer optimization
        progress = -dist_to_gate + front_reward - deviation_penalty
        
        return progress

    
    def initialize_straight_trajectory(self, robot_idx):
        """Create a trajectory that aims to pass through the gates with custom track adjustments"""
        trajectory = np.zeros((self.N, 3))
        
        # Start from current position
        current_pos = self.drones[robot_idx].state[0:3]
        current_vel = self.drones[robot_idx].state[3:6]
        trajectory[0] = current_pos
        
        # Get current gate
        gate_idx = self.current_gate_indices[robot_idx]
        gate_pos = self.gate_positions[gate_idx]
        gate_yaw = self.gate_yaws[gate_idx]
        
        # Get next gate for better planning
        next_gate_idx = (gate_idx + 1) % self.n_gates
        next_gate_pos = self.gate_positions[next_gate_idx]
        next_gate_yaw = self.gate_yaws[next_gate_idx]
        
        # Special handling for the G1 to G2 transition
        if gate_idx == 1:  # From G1 to G2 - the problematic turn
            # Create a turning waypoint to force the correct turn
            turning_point = (gate_pos + next_gate_pos) / 2.0  # Midpoint
            turning_point[0] -= 2.0  # Shift left (X axis)
            turning_point[2] += 0.5  # Slight upward shift
            
            # Create a three-phase trajectory: gate -> turning point -> next gate
            midpoint = self.N // 3
            late_point = self.N * 2 // 3
            
            # Phase 1: Current position to current gate
            for t in range(1, midpoint):
                alpha = t / midpoint
                trajectory[t] = (1 - alpha) * current_pos + alpha * gate_pos
            
            # Phase 2: Current gate to turning point
            for t in range(midpoint, late_point):
                alpha = (t - midpoint) / (late_point - midpoint)
                trajectory[t] = (1 - alpha) * gate_pos + alpha * turning_point
            
            # Phase 3: Turning point to next gate
            for t in range(late_point, self.N):
                alpha = (t - late_point) / (self.N - late_point)
                trajectory[t] = (1 - alpha) * turning_point + alpha * next_gate_pos
        
        else:
            # Normal initialization for other gates
            midpoint = self.N // 2
            
            # First half: aim directly at current gate
            for t in range(1, midpoint + 1):
                alpha = t / midpoint
                trajectory[t] = (1 - alpha) * current_pos + alpha * gate_pos
            
            # Second half: aim directly at next gate
            for t in range(midpoint + 1, self.N):
                alpha = (t - midpoint) / (self.N - midpoint)
                trajectory[t] = (1 - alpha) * gate_pos + alpha * next_gate_pos
        
        # Apply custom track adjustments for specific tricky segments
        trajectory = self.adjust_for_track_geometry(robot_idx, trajectory)
        
        return trajectory

    def adjust_for_track_geometry(self, robot_idx, trajectory):
        """Apply specific adjustments for the custom track geometry"""
        gate_idx = self.current_gate_indices[robot_idx]
        
        # Get current gate information
        current_gate = self.gate_positions[gate_idx]
        current_yaw = self.gate_yaws[gate_idx]
        
        # Get next gate information
        next_gate_idx = (gate_idx + 1) % self.n_gates
        next_gate = self.gate_positions[next_gate_idx]
        next_yaw = self.gate_yaws[next_gate_idx]
        
        # Calculate track segment vector and properties
        segment_vector = next_gate - current_gate
        segment_length = np.linalg.norm(segment_vector)
        
        if segment_length > 0:
            segment_direction = segment_vector / segment_length
        else:
            segment_direction = np.array([0, 0, 1])  # Default vertical direction
        
        # CRITICAL FIX: Gate 1 -> Gate 2 transition (from back-right to back)
        # This is where drones are going off track in your images
        if gate_idx == 1:  # Gate 1 -> Gate 2
            # Apply a much more aggressive correction for this problematic segment
            # We need to force a hard left turn here
            
            # Calculate the midpoint of trajectory for splitting into phases
            mid_idx = trajectory.shape[0] // 3
            late_idx = trajectory.shape[0] * 2 // 3
            
            # Phase 1: Initial approach with more direct path to G2
            for t in range(1, mid_idx):
                progress = t / mid_idx
                # Create a direct path to G2
                direct_point = (1 - progress) * current_gate + progress * next_gate
                
                # Apply very strong blending toward direct path
                blend_factor = 0.8  # Heavy weight toward direct path
                original_point = trajectory[t].copy()
                trajectory[t] = (1 - blend_factor) * original_point + blend_factor * direct_point
            
            # Phase 2: Mid-trajectory with a forced turn path
            for t in range(mid_idx, late_idx):
                # Create a curved turning path
                progress = (t - mid_idx) / (late_idx - mid_idx)
                
                # Calculate an intermediate turning point 
                # This is a key point: Create a turning waypoint to force the correct turn
                # The turning point is above and to the left of a direct line
                turning_point = (current_gate + next_gate) / 2.0  # Midpoint
                turning_point[0] -= 2.0  # Shift left (X axis)
                turning_point[2] += 0.5  # Slight upward shift for smoother path
                
                # Blend between current gate, turning point, and next gate
                if progress < 0.5:
                    # First half: blend from current gate to turning point
                    sub_progress = progress * 2.0
                    point = (1 - sub_progress) * current_gate + sub_progress * turning_point
                else:
                    # Second half: blend from turning point to next gate
                    sub_progress = (progress - 0.5) * 2.0
                    point = (1 - sub_progress) * turning_point + sub_progress * next_gate
                
                # Apply very strong blending to force the turn
                blend_factor = 0.9  # Very heavy weight to force the turn
                original_point = trajectory[t].copy()
                trajectory[t] = (1 - blend_factor) * original_point + blend_factor * point
            
            # Phase 3: Final approach to Gate 2
            for t in range(late_idx, trajectory.shape[0]):
                progress = (t - late_idx) / (trajectory.shape[0] - late_idx)
                direct_point = (1 - progress) * turning_point + progress * next_gate
                
                # Very direct approach to next gate
                blend_factor = 0.9  # Very heavy weight to direct approach
                original_point = trajectory[t].copy()
                trajectory[t] = (1 - blend_factor) * original_point + blend_factor * direct_point
        
        # Other gate transitions can use the existing code with minor adjustments
        elif gate_idx == 0:  # Gate 0 -> Gate 1
            # Smooth the height transition
            for t in range(trajectory.shape[0]//2, trajectory.shape[0]):
                # Gradually adjust height after passing gate
                progress = (t - trajectory.shape[0]//2) / (trajectory.shape[0] - trajectory.shape[0]//2)
                height_adjustment = 0.3 * progress  # Gradually move higher
                trajectory[t, 2] += height_adjustment
                
                # Ensure we're heading in the right direction for the next gate
                if t > trajectory.shape[0] * 0.7:
                    # For last 30% of trajectory, blend toward next gate direction
                    late_progress = (t - 0.7 * trajectory.shape[0]) / (0.3 * trajectory.shape[0])
                    direction_point = trajectory[t] + segment_direction * late_progress
                    trajectory[t] = (1 - late_progress * 0.3) * trajectory[t] + (late_progress * 0.3) * direction_point
        
        # Gate 2 -> Gate 3: Back to middle (height drop and diagonal movement)
        elif gate_idx == 2:
            # More aggressive height adjustment starting earlier
            height_diff = next_gate[2] - current_gate[2]
            
            for t in range(1, trajectory.shape[0]):
                # Start height adjustment immediately
                progress = t / trajectory.shape[0]
                height_adjustment = height_diff * min(1.0, progress * 1.5)  # Faster descent
                
                # Directly modify trajectory to follow a more direct path
                direct_point = (1 - progress) * current_gate + progress * next_gate
                
                # Keep some of the original path but blend with direct path
                blend_factor = min(0.8, 0.4 + progress * 0.5)  # Increases with progress
                original_point = trajectory[t].copy()
                
                # Blend while preserving the adjusted height
                blended_point = (1 - blend_factor) * original_point + blend_factor * direct_point
                blended_point[2] = original_point[2] + height_adjustment
                
                trajectory[t] = blended_point
        
        # Gate 3 -> Gate 4: Middle to front-left (flat movement, turning left)
        elif gate_idx == 3:
            # Maintain height, smooth lateral transition
            for t in range(trajectory.shape[0]//2, trajectory.shape[0]):
                progress = (t - trajectory.shape[0]//2) / (trajectory.shape[0] - trajectory.shape[0]//2)
                
                # Ensure stable height and direct path
                trajectory[t, 2] = max(trajectory[t, 2], 0.5)  # Ensure minimal safe height
                
                # Create a direct path with slight adjustment for better approach to gate 5
                direct_point = (1 - progress) * current_gate + progress * next_gate
                
                # Blend with original trajectory
                blend_factor = min(0.7, 0.3 + progress * 0.5)
                original_point = trajectory[t].copy()
                blended_point = (1 - blend_factor) * original_point + blend_factor * direct_point
                
                # Preserve height
                blended_point[2] = trajectory[t, 2]
                trajectory[t] = blended_point
        
        # Gate 4 -> Gate 5: Front-left to left (slight increase in height, sharp right turn)
        elif gate_idx == 4:
            # This is another potentially difficult turn
            for t in range(trajectory.shape[0]//3, trajectory.shape[0]):
                progress = (t - trajectory.shape[0]//3) / (trajectory.shape[0] - trajectory.shape[0]//3)
                
                # Calculate a curved path point
                direct_point = (1 - progress) * current_gate + progress * next_gate
                
                # Add outward adjustment for this sharp turn
                curve_point = direct_point.copy()
                curve_point[0] += 2.0 * np.sin(progress * np.pi)  # Creates a curved path
                
                # Strong blending toward the curved path
                blend_factor = 0.7
                original_point = trajectory[t].copy()
                blended_point = (1 - blend_factor) * original_point + blend_factor * curve_point
                
                # Gradual height increase
                height_adjustment = (next_gate[2] - current_gate[2]) * progress
                blended_point[2] = original_point[2] + height_adjustment
                
                trajectory[t] = blended_point
        
        # Gate 5 -> Gate 0: Left to right (diagonal movement across center, slight height decrease)
        elif gate_idx == 5:
            # This completes the circuit - ensure we approach gate 0 correctly
            for t in range(trajectory.shape[0]//3, trajectory.shape[0]):
                progress = (t - trajectory.shape[0]//3) / (trajectory.shape[0] - trajectory.shape[0]//3)
                
                # Strong direct guidance to first gate
                direct_point = (1 - progress) * current_gate + progress * next_gate
                
                # Keep the trajectory closer to the direct line
                blend_factor = min(0.8, 0.5 + progress * 0.5)
                original_point = trajectory[t].copy()
                blended_point = (1 - blend_factor) * original_point + blend_factor * direct_point
                
                # Smooth height transition
                height_diff = next_gate[2] - current_gate[2]
                height_adjustment = height_diff * progress
                blended_point[2] = original_point[2] + height_adjustment
                
                trajectory[t] = blended_point
        
        return trajectory

    def optimize_trajectory(self, robot_idx, all_trajectories):
        """Trajectory optimization with enhanced track following"""
        # Initialize with current trajectory
        if all_trajectories is None:
            trajectory = self.initialize_straight_trajectory(robot_idx)
        else:
            trajectory = all_trajectories[robot_idx].copy()
        
        current_pos = self.drones[robot_idx].state[0:3]
        current_vel = self.drones[robot_idx].state[3:6]
        
        # Get current target gate
        gate_idx = self.current_gate_indices[robot_idx]
        gate_position = self.gate_positions[gate_idx]
        gate_yaw = self.gate_yaws[gate_idx]
        
        # Next gate info for better planning
        next_gate_idx = (gate_idx + 1) % self.n_gates
        next_gate_position = self.gate_positions[next_gate_idx]
        next_gate_yaw = self.gate_yaws[next_gate_idx]
        
        # Create gate frame for progress measurement
        gate_frame = self.create_gate_reference_frame(gate_position, gate_yaw)
        
        # Optimization parameters - adaptive based on proximity to gate
        dist_to_gate = np.linalg.norm(current_pos - gate_position)
        close_to_gate = dist_to_gate < 3.0
        
        # For later gates, increase optimization effort
        later_gate = gate_idx >= 2
        
        # Adaptive optimization parameters
        n_steps = 8 if close_to_gate or later_gate else 5  # More steps when close to gate or for later gates
        learning_rate = 0.3 if close_to_gate else 0.2  # Higher learning rate when close
        
        # Start with current trajectory (skip first point which is fixed)
        opt_traj = trajectory[1:].copy()
        
        # Cache neighbor trajectories
        neighbor_trajectories = {}
        for j in self.neighbors[robot_idx]:
            if all_trajectories is None:
                neighbor_trajectories[j] = self.initialize_straight_trajectory(j)
            else:
                neighbor_trajectories[j] = all_trajectories[j]
        
        # Calculate gate target point - target beyond the gate
        # Adjust for the custom track orientation
        gate_normal = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
        
        # Simple gradient-based optimization
        for step in range(n_steps):
            # Calculate gradients for each point in trajectory
            gradients = np.zeros_like(opt_traj)
            
            # Progress gradient based on our progress function
            for t in range(opt_traj.shape[0]):
                # Compute progress at current position
                progress = self.compute_progress_to_gate(opt_traj[t], gate_position, gate_yaw)
                
                # Compute numerical gradient by testing small perturbations
                eps = 0.1
                gradient = np.zeros(3)
                
                for dim in range(3):
                    # Create perturbed position
                    perturbed_pos = opt_traj[t].copy()
                    perturbed_pos[dim] += eps
                    
                    # Compute progress at perturbed position
                    perturbed_progress = self.compute_progress_to_gate(perturbed_pos, gate_position, gate_yaw)
                    
                    # Compute gradient
                    gradient[dim] = (perturbed_progress - progress) / eps
                
                # Normalize gradient if it's large
                grad_norm = np.linalg.norm(gradient)
                if grad_norm > 1.0:
                    gradient = gradient / grad_norm
                
                # Apply higher weights for later timesteps
                time_weight = 1.0 + t/opt_traj.shape[0] * 0.5
                gradients[t] += gradient * time_weight
                
                # Add track guidance gradient for keeping drones on track
                # This becomes more important for later timesteps and later gates
                track_influence = min(1.0, 0.3 + 0.7 * t / opt_traj.shape[0])
                
                # Increase track influence for later gates
                if gate_idx >= 2:
                    track_influence *= 1.5
                
                # Compute and apply track guidance
                track_guidance = self.compute_track_guidance(robot_idx, opt_traj[t])
                gradients[t] += track_guidance * track_influence
                
                # For later timesteps, add influence of next gate if we're likely to pass this one
                if t > opt_traj.shape[0] * 0.7 and dist_to_gate < 3.0:
                    # Add a small gradient component toward the next gate's height
                    height_gradient = np.zeros(3)
                    height_gradient[2] = np.sign(next_gate_position[2] - opt_traj[t, 2]) * 0.1
                    gradients[t] += height_gradient * (t / opt_traj.shape[0])
            
            # Smoothness gradient
            if opt_traj.shape[0] > 1:
                # Penalize large accelerations
                for t in range(1, opt_traj.shape[0]-1):
                    accel = opt_traj[t+1] - 2*opt_traj[t] + opt_traj[t-1]
                    gradients[t] -= 0.15 * accel  # Stronger smoothness penalty
                
                # Match initial velocity with stronger weighting
                init_vel = (opt_traj[0] - current_pos) / self.dt
                vel_diff = init_vel - current_vel
                gradients[0] -= 0.3 * vel_diff  # Increased weight for initial velocity matching
            
            # Enhanced collision avoidance - higher weighting and more proactive
            for j in self.neighbors[robot_idx]:
                other_traj = neighbor_trajectories[j]
                
                # Check if this is a problematic pair with collision history
                collision_history_factor = 1.0
                if hasattr(self, '_collision_counters'):
                    key = (min(robot_idx, j), max(robot_idx, j))
                    collision_count = self._collision_counters.get(key, 0)
                    # Increase avoidance strength for drones that collide frequently
                    collision_history_factor = 1.0 + min(2.0, collision_count * 0.5)
                
                # Also check if this pair is tagged for evasive action
                evasive_pair = False
                if hasattr(self, '_evasive_pairs'):
                    key = (min(robot_idx, j), max(robot_idx, j))
                    evasive_pair = key in self._evasive_pairs
                    if evasive_pair:
                        collision_history_factor *= 1.5  # Even stronger for evasive pairs
                
                for t in range(opt_traj.shape[0]):
                    t_idx = min(t + 1, other_traj.shape[0] - 1)  # Ensure valid index
                    # Vector from other to self
                    diff = opt_traj[t] - other_traj[t_idx]
                    dist = np.linalg.norm(diff)
                    
                    # More proactive collision avoidance - increased detection radius
                    avoidance_radius = self.min_distance * 2.0  # Increased radius
                    
                    # Apply stronger avoidance when close
                    if dist < avoidance_radius and dist > 0:
                        # Repulsive direction (away from other drone)
                        repulsion_dir = diff / dist
                        
                        # Stronger, progressive strength that increases more rapidly as drones get closer
                        # Using a quadratic falloff for more aggressive avoidance
                        proximity_factor = (avoidance_radius - dist) / avoidance_radius
                        strength = 1.0 + 3.0 * proximity_factor * proximity_factor
                        
                        # Include sensitivity term from SE-IBR
                        sensitivity = self.compute_sensitivity_minimal(robot_idx, j, gate_idx)
                        
                        # Apply stronger repulsive gradient with sensitivity adjustment
                        collision_gradient = repulsion_dir * strength * sensitivity * collision_history_factor
                        
                        # Apply gradient with higher weight for collision avoidance
                        # Weight decays with time to allow eventual goal-reaching
                        time_decay = 1.0 if t < 3 else (1.0 - 0.1 * (t-3))
                        collision_weight = 1.5 * time_decay  # Base weight with time decay
                        
                        # Even stronger for evasive pairs
                        if evasive_pair:
                            collision_weight *= 1.5
                            
                        gradients[t] += collision_weight * collision_gradient
                        
                        # Also add vertical separation for nearby drones - prefer going over
                        if dist < self.min_distance * 1.2:
                            vertical_bias = np.array([0, 0, 0.3 * collision_history_factor])
                            gradients[t] += vertical_bias
            
            # Apply gradients with learning rate
            opt_traj += learning_rate * gradients
            
            # Apply constraints directly
            # 1. Velocity constraints
            if opt_traj.shape[0] > 0:
                for t in range(opt_traj.shape[0]):
                    # Calculate velocity
                    if t == 0:
                        vel = (opt_traj[t] - current_pos) / self.dt
                    else:
                        vel = (opt_traj[t] - opt_traj[t-1]) / self.dt
                    
                    # Clip velocity with a safety margin
                    for i in range(3):
                        max_vel = self.max_velocity[i] * 0.95  # 5% safety margin
                        if abs(vel[i]) > max_vel:
                            vel[i] = np.sign(vel[i]) * max_vel
                    
                    # Recompute position from velocity
                    if t == 0:
                        opt_traj[t] = current_pos + vel * self.dt
                    else:
                        opt_traj[t] = opt_traj[t-1] + vel * self.dt
        
        # Combine with fixed initial position
        optimized_trajectory = np.vstack([current_pos, opt_traj])
        
        # Apply height constraints to keep drones at reasonable heights
        min_height = 0.5  # Minimum allowed height
        max_height = 5.0  # Maximum allowed height
        
        for t in range(optimized_trajectory.shape[0]):
            optimized_trajectory[t, 2] = np.clip(optimized_trajectory[t, 2], min_height, max_height)
        
        # Final collision avoidance check - ensure minimum separation at key points
        for j in self.neighbors[robot_idx]:
            other_traj = neighbor_trajectories[j]
            
            # Focus on critical points - start, middle, end
            critical_points = [0, optimized_trajectory.shape[0]//2, optimized_trajectory.shape[0]-1]
            
            for t in critical_points:
                if t < other_traj.shape[0]:
                    diff = optimized_trajectory[t] - other_traj[t]
                    dist = np.linalg.norm(diff)
                    
                    # If too close, apply direct separation
                    if dist < self.min_distance * 0.9 and dist > 0:
                        # Calculate separation direction
                        sep_dir = diff / dist
                        
                        # Calculate how much to move
                        move_dist = (self.min_distance - dist) * 0.6  # Move 60% of the way
                        
                        # Apply move (only to non-initial points)
                        if t > 0:
                            optimized_trajectory[t] += sep_dir * move_dist
        
        # Apply custom track geometry adjustments
        optimized_trajectory = self.adjust_for_track_geometry(robot_idx, optimized_trajectory)
        
        return optimized_trajectory
    
    def compute_sensitivity_minimal(self, ego_idx, other_idx, ego_gate_idx):
        """Enhanced sensitivity computation with collision history awareness"""
        # Get current gate indices
        other_gate_idx = self.current_gate_indices[other_idx]
        
        # Check collision history to increase sensitivity for problematic pairs
        collision_history_factor = 1.0
        if hasattr(self, '_collision_counters'):
            key = (min(ego_idx, other_idx), max(ego_idx, other_idx))
            collision_count = self._collision_counters.get(key, 0)
            # Increase sensitivity for drones that collide frequently
            collision_history_factor = 1.0 + min(2.0, collision_count * 0.5)
        
        # Base sensitivity on relative gate progress with collision history
        if ego_gate_idx > other_gate_idx:
            return 1.0 * collision_history_factor  # Ego drone has priority (higher gate index)
        elif ego_gate_idx < other_gate_idx:
            return 0.2 * collision_history_factor  # Other drone has priority
        else:
            # Same gate, compare distances
            ego_dist = np.linalg.norm(self.drones[ego_idx].state[:3] - self.gate_positions[ego_gate_idx])
            other_dist = np.linalg.norm(self.drones[other_idx].state[:3] - self.gate_positions[other_gate_idx])
            
            if ego_dist < other_dist:
                return (0.8 * collision_history_factor)  # Ego drone is closer to gate
            else:
                return (0.4 * collision_history_factor)  # Other drone is closer to gate
    
    # def check_gate_passed(self, robot_idx, verbose=False):
    #     """Enhanced gate passing detection for custom track"""
    #     # Get current gate information
    #     gate_idx = self.current_gate_indices[robot_idx]
    #     gate_pos = self.gate_positions[gate_idx]
    #     gate_yaw = self.gate_yaws[gate_idx]
        
    #     # Get robot position and velocity
    #     robot_pos = self.drones[robot_idx].state[0:3]
    #     robot_vel = self.drones[robot_idx].state[3:6]
        
    #     # For custom track with height differences, adapt tolerance
    #     height_difference = abs(gate_pos[2] - robot_pos[2])
    #     adaptive_tolerance = self.gate_passing_tolerance * (1.0 + height_difference * 0.3)
        
    #     # Vector from drone to gate
    #     to_gate = gate_pos - robot_pos
    #     dist_to_gate = np.linalg.norm(to_gate)
        
    #     # Gate normal direction (pointing forward from gate)
    #     # Adjust for the custom track orientation
    #     gate_normal = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
        
    #     # Calculate signed distance to gate plane (positive if behind gate, negative if in front)
    #     signed_dist = np.dot(to_gate, gate_normal)
        
    #     # Store previous positions and distances
    #     if not hasattr(self, '_prev_positions'):
    #         self._prev_positions = {}
    #         self._prev_signed_dists = {}
        
    #     if robot_idx not in self._prev_positions:
    #         self._prev_positions[robot_idx] = robot_pos
    #         self._prev_signed_dists[robot_idx] = signed_dist
    #         return False
        
    #     # Get previous values
    #     prev_pos = self._prev_positions[robot_idx]
    #     prev_signed_dist = self._prev_signed_dists[robot_idx]
        
    #     # Calculate lateral and vertical distance to center line
    #     gate_lateral = np.array([-np.sin(gate_yaw - np.pi/2), np.cos(gate_yaw - np.pi/2), 0])
    #     lateral_dist = np.abs(np.dot(to_gate, gate_lateral))
    #     vertical_dist = np.abs(to_gate[2])
        
    #     # More permissive gate passing for later gates
    #     if gate_idx >= 2:  # After second gate
    #         # Wider acceptance for lateral deviation
    #         is_close_to_center = lateral_dist < self.gate_width * 0.9 and vertical_dist < self.gate_height * 0.9
            
    #         # More permissive distance check
    #         adaptive_tolerance *= 1.5
            
    #         # More permissive direction check
    #         direction_check = True  # Accept any direction for later gates
    #     else:
    #         # Standard checks for early gates
    #         is_close_to_center = lateral_dist < 1.8 and vertical_dist < 1.8
            
    #         # Check direction (for early gates only)
    #         direction_check = True
    #         vel_magnitude = np.linalg.norm(robot_vel)
    #         if vel_magnitude > 0.1:
    #             vel_dir = robot_vel / vel_magnitude
    #             alignment = np.dot(vel_dir, gate_normal)
    #             direction_check = alignment > 0.0  # Most permissive for custom track
        
    #     # Check absolute distance
    #     close_enough = dist_to_gate < adaptive_tolerance
        
    #     # Check if we crossed the gate plane
    #     crossed_plane = prev_signed_dist > 0 and signed_dist <= 0
        
    #     # Direct distance override
    #     direct_override = close_enough and is_close_to_center and signed_dist < 0
        
    #     # Update previous values
    #     self._prev_positions[robot_idx] = robot_pos
    #     self._prev_signed_dists[robot_idx] = signed_dist
        
    #     # Gate is passed if EITHER:
    #     # - We detected a proper crossing (crossed plane + close to center + right direction)
    #     # - OR we're very close to the gate on the front side with correct positioning
    #     gate_passed = (crossed_plane and is_close_to_center and direction_check) or direct_override
        
    #     # Additional fallback for stuck drones
    #     if not gate_passed and gate_idx >= 2 and dist_to_gate < self.gate_width:
    #         # Very permissive check for later gates when drones are close but stuck
    #         gate_passed = close_enough and signed_dist < self.gate_width * 0.5
    #         if gate_passed and verbose:
    #             print(f"Using permissive gate passing for drone {robot_idx} at gate {gate_idx}")
        
    #     # Additional debug information
    #     if (gate_passed or close_enough) and verbose:
    #         print(f"Gate check: Drone {robot_idx}, Dist={dist_to_gate:.2f}m, " 
    #             f"Signed={signed_dist:.2f}, Crossed={crossed_plane}, "
    #             f"Center={is_close_to_center}, Override={direct_override}")
        
    #     if gate_passed and verbose:
    #         print(f"GATE PASSED: Drone {robot_idx} passed gate {gate_idx}")
    #         new_gate_idx = (gate_idx + 1) % self.n_gates
    #         print(f"Setting new gate index to {new_gate_idx}")
        
    #     return gate_passed
    
    def check_gate_passed(self, robot_idx, verbose=False):
        """Enhanced gate passing detection for custom track"""
        # Get current gate information
        gate_idx = self.current_gate_indices[robot_idx]
        gate_pos = self.gate_positions[gate_idx]
        gate_yaw = self.gate_yaws[gate_idx]
        
        # Get robot position and velocity
        robot_pos = self.drones[robot_idx].state[0:3]
        robot_vel = self.drones[robot_idx].state[3:6]
        
        # For custom track with height differences, adapt tolerance
        height_difference = abs(gate_pos[2] - robot_pos[2])
        adaptive_tolerance = self.gate_passing_tolerance * (1.0 + height_difference * 0.3)
        
        # Vector from drone to gate
        to_gate = gate_pos - robot_pos
        dist_to_gate = np.linalg.norm(to_gate)
        
        # Gate normal direction (pointing forward from gate)
        # Adjust for the custom track orientation
        gate_normal = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
        
        # Calculate signed distance to gate plane (positive if behind gate, negative if in front)
        signed_dist = np.dot(to_gate, gate_normal)
        
        # Store previous positions and distances
        if not hasattr(self, '_prev_positions'):
            self._prev_positions = {}
            self._prev_signed_dists = {}
        
        if robot_idx not in self._prev_positions:
            self._prev_positions[robot_idx] = robot_pos
            self._prev_signed_dists[robot_idx] = signed_dist
            return False
        
        # Get previous values
        prev_pos = self._prev_positions[robot_idx]
        prev_signed_dist = self._prev_signed_dists[robot_idx]
        
        # Calculate lateral and vertical distance to center line
        gate_lateral = np.array([-np.sin(gate_yaw - np.pi/2), np.cos(gate_yaw - np.pi/2), 0])
        lateral_dist = np.abs(np.dot(to_gate, gate_lateral))
        vertical_dist = np.abs(to_gate[2])
        
        # Special case for Gate 1 to Gate 2 transition - more permissive
        if gate_idx == 1:  # Gate 1
            # Wider acceptance for lateral deviation
            is_close_to_center = lateral_dist < self.gate_width * 1.2 and vertical_dist < self.gate_height * 1.2
            
            # More permissive distance check
            adaptive_tolerance *= 1.8
            
            # No direction check for this problematic gate
            direction_check = True
            
            # Check absolute distance
            close_enough = dist_to_gate < adaptive_tolerance
            
            # Check if we crossed the gate plane
            crossed_plane = prev_signed_dist > 0 and signed_dist <= 0
            
            # Direct distance override - even more permissive
            direct_override = close_enough and is_close_to_center and signed_dist < self.gate_width * 0.5
            
            # Update previous values
            self._prev_positions[robot_idx] = robot_pos
            self._prev_signed_dists[robot_idx] = signed_dist
            
            # Gate is passed if EITHER:
            # - We detected a proper crossing
            # - OR we're very close to the gate on the front side
            # - OR we've come very close to the gate in any manner (super permissive)
            gate_passed = (crossed_plane and is_close_to_center) or direct_override or (dist_to_gate < self.gate_width * 0.8)
            
        else:
            # More permissive gate passing for later gates
            if gate_idx >= 2:  # After second gate
                # Wider acceptance for lateral deviation
                is_close_to_center = lateral_dist < self.gate_width * 0.9 and vertical_dist < self.gate_height * 0.9
                
                # More permissive distance check
                adaptive_tolerance *= 1.5
                
                # More permissive direction check
                direction_check = True  # Accept any direction for later gates
            else:
                # Standard checks for early gates
                is_close_to_center = lateral_dist < 1.8 and vertical_dist < 1.8
                
                # Check direction (for early gates only)
                direction_check = True
                vel_magnitude = np.linalg.norm(robot_vel)
                if vel_magnitude > 0.1:
                    vel_dir = robot_vel / vel_magnitude
                    alignment = np.dot(vel_dir, gate_normal)
                    direction_check = alignment > 0.0  # Most permissive for custom track
            
            # Check absolute distance
            close_enough = dist_to_gate < adaptive_tolerance
            
            # Check if we crossed the gate plane
            crossed_plane = prev_signed_dist > 0 and signed_dist <= 0
            
            # Direct distance override
            direct_override = close_enough and is_close_to_center and signed_dist < 0
            
            # Update previous values
            self._prev_positions[robot_idx] = robot_pos
            self._prev_signed_dists[robot_idx] = signed_dist
            
            # Gate is passed if EITHER:
            # - We detected a proper crossing (crossed plane + close to center + right direction)
            # - OR we're very close to the gate on the front side with correct positioning
            gate_passed = (crossed_plane and is_close_to_center and direction_check) or direct_override
            
            # Additional fallback for stuck drones
            if not gate_passed and gate_idx >= 2 and dist_to_gate < self.gate_width:
                # Very permissive check for later gates when drones are close but stuck
                gate_passed = close_enough and signed_dist < self.gate_width * 0.5
                if gate_passed and verbose:
                    print(f"Using permissive gate passing for drone {robot_idx} at gate {gate_idx}")
        
        # Additional debug information
        if (gate_passed or close_enough) and verbose:
            print(f"Gate check: Drone {robot_idx}, Dist={dist_to_gate:.2f}m, " 
                f"Signed={signed_dist:.2f}, Crossed={crossed_plane}, "
                f"Center={is_close_to_center}, Override={direct_override}")
        
        if gate_passed and verbose:
            print(f"GATE PASSED: Drone {robot_idx} passed gate {gate_idx}")
            new_gate_idx = (gate_idx + 1) % self.n_gates
            print(f"Setting new gate index to {new_gate_idx}")
        
        return gate_passed

    def validate_gate_targets(self):
        """Validate and correct gate targets for all drones"""
        positions = self.get_drone_positions()
        any_corrected = False
        
        for i in range(self.num_robots):
            gate_idx = self.current_gate_indices[i]
            current_gate = self.gate_positions[gate_idx]
            
            # Distance to current target gate
            dist_to_gate = np.linalg.norm(positions[i] - current_gate)
            
            # Check if the drone is unreasonably far from its target gate
            if dist_to_gate > 25.0:  # Threshold for being off course
                # Find the closest gate to this drone
                closest_idx = 0
                min_dist = float('inf')
                
                for j in range(self.n_gates):
                    dist = np.linalg.norm(positions[i] - self.gate_positions[j])
                    if dist < min_dist:
                        min_dist = dist
                        closest_idx = j
                
                # If closest gate is different from current target and reasonably close
                if closest_idx != gate_idx and min_dist < 10.0:
                    # Reset gate target to closest gate
                    self.current_gate_indices[i] = closest_idx
                    any_corrected = True
        
        return any_corrected
    
    def check_and_recover_off_track_drones(self):
        """Detect and recover drones that have gone off track"""
        positions = self.get_drone_positions()
        any_recovered = False
        
        for i in range(self.num_robots):
            gate_idx = self.current_gate_indices[i]
            current_gate = self.gate_positions[gate_idx]
            next_gate_idx = (gate_idx + 1) % self.n_gates
            next_gate = self.gate_positions[next_gate_idx]
            
            # Calculate track vector and project drone position
            track_vector = next_gate - current_gate
            track_length = np.linalg.norm(track_vector)
            
            if track_length > 0:
                track_direction = track_vector / track_length
                t = np.clip(np.dot(positions[i] - current_gate, track_direction) / track_length, 0, 1)
                closest_point = current_gate + t * track_vector
                
                # Distance from drone to track
                to_track = closest_point - positions[i]
                dist_to_track = np.linalg.norm(to_track)
                
                # If drone is too far from track, apply recovery
                if dist_to_track > self.max_track_deviation:
                    # Drone is off track - apply correction to velocity
                    recovery_direction = to_track / dist_to_track
                    
                    # Adjust velocity to point back to track
                    correction_strength = min(1.0, (dist_to_track - self.max_track_deviation) / 5.0)
                    current_vel = self.drones[i].state[3:6]
                    vel_magnitude = np.linalg.norm(current_vel)
                    
                    # Blend current velocity with recovery direction
                    if vel_magnitude > 0:
                        new_vel_dir = (1 - correction_strength) * (current_vel / vel_magnitude) + correction_strength * recovery_direction
                        new_vel_dir = new_vel_dir / np.linalg.norm(new_vel_dir)
                        new_vel = new_vel_dir * vel_magnitude
                        
                        # Apply corrected velocity
                        self.drones[i].state[3:6] = new_vel
                    
                    any_recovered = True
                    
                    # Extreme case - teleport if extremely far off track
                    if dist_to_track > self.max_track_deviation * 2:
                        # Teleport back to a point near the track
                        teleport_point = positions[i] + to_track * 0.7  # Move 70% of the way back
                        self.drones[i].state[0:3] = teleport_point
        
        return any_recovered
    
    def plan(self, verbose=False):
        """Optimized planning with performance improvements and custom track handling"""
        start_time = time.time()
        
        # Update neighbors (less frequently for performance)
        if not hasattr(self, '_neighbor_update_counter'):
            self._neighbor_update_counter = 0
        
        self._neighbor_update_counter += 1
        if self._neighbor_update_counter % 3 == 0:
            self.find_neighbors()
        
        # Initialize or reuse trajectories
        if not hasattr(self, 'previous_trajectories') or self.previous_trajectories is None:
            all_trajectories = np.zeros((self.num_robots, self.N, 3))
            for i in range(self.num_robots):
                all_trajectories[i] = self.initialize_straight_trajectory(i)
        else:
            all_trajectories = self.previous_trajectories.copy()
        
        # Adaptive iterations based on proximity to gates
        close_to_gates = False
        positions = self.get_drone_positions()
        for i in range(self.num_robots):
            gate_idx = self.current_gate_indices[i]
            dist = np.linalg.norm(positions[i] - self.gate_positions[gate_idx])
            if dist < 2.0:
                close_to_gates = True
                break
        
        # Use more iterations for custom track to handle complex geometry
        iterations = 4 if close_to_gates else 3  # Increased iterations for custom track
        
        # Increase iterations for later gates which are more challenging
        for i in range(self.num_robots):
            if self.current_gate_indices[i] >= 2:  # After second gate
                iterations = max(iterations, 5)  # More iterations for later gates
                break
        
        # SE-IBR iterations with performance optimization
        for iteration in range(iterations):
            # Sequential optimization
            for i in range(self.num_robots):
                # Optimize for all drones in custom track (more complex navigation)
                if iteration == 0 or len(self.neighbors[i]) > 0 or close_to_gates:
                    all_trajectories[i] = self.optimize_trajectory(i, all_trajectories)
        
        # Apply custom track adjustments for all drones
        for i in range(self.num_robots):
            # Fix: Pass a single trajectory, not the whole array
            adjusted_trajectory = self.adjust_for_track_geometry(i, all_trajectories[i])
            all_trajectories[i] = adjusted_trajectory
        
        # Store for next warm-start
        self.previous_trajectories = all_trajectories.copy()
        
        end_time = time.time()
        if hasattr(self, '_step_counter'):
            self._step_counter += 1
        else:
            self._step_counter = 0
            
        if self._step_counter % 10 == 0 and verbose:
            print(f"Planning took {end_time - start_time:.4f} seconds")
        
        return {
            'trajectories': all_trajectories,
            'computation_time': end_time - start_time
        }
    
    def fix_orbit_trajectories(self, robot_idx, all_trajectories, verbose=False):
        """Enhanced trajectory fixing specifically for custom track gates"""
        # Extract just this robot's trajectory
        trajectory = all_trajectories.copy()
        
        gate_idx = self.current_gate_indices[robot_idx]
        gate_pos = self.gate_positions[gate_idx]
        gate_yaw = self.gate_yaws[gate_idx]
        
        # Gate normal direction (adjusted for custom track)
        gate_normal = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
        
        # Next gate information for smoother transitions
        next_gate_idx = (gate_idx + 1) % self.n_gates
        next_gate_pos = self.gate_positions[next_gate_idx]
        
        # Target 4m beyond gate for more aggressive passing
        target_beyond_gate = gate_pos + gate_normal * 4.0
        
        # For custom track - adjust gate target height slightly for better passing
        gate_center_point = gate_pos.copy()
        
        # For upward gates, aim slightly below center
        if gate_pos[2] > trajectory[0, 2]:
            gate_center_point[2] -= 0.1
        # For downward gates, aim slightly above center
        elif gate_pos[2] < trajectory[0, 2]:
            gate_center_point[2] += 0.1
        
        # Create direct trajectory with special track handling
        for t in range(1, self.N):
            if t < self.N//2:
                # First half: aim precisely at adjusted gate center point
                alpha = t / (self.N//2)
                trajectory[t] = (1 - alpha) * trajectory[0] + alpha * gate_center_point
            else:
                # Second half: continue beyond gate with height transition to next gate
                alpha = (t - self.N//2) / (self.N - self.N//2)
                base_pos = (1 - alpha) * gate_center_point + alpha * target_beyond_gate
                
                # Start transitioning height toward next gate in latter part
                if alpha > 0.5:
                    height_blend = (alpha - 0.5) * 2  # 0 to 1 in latter half
                    height_adjustment = (next_gate_pos[2] - gate_pos[2]) * height_blend * 0.3
                    base_pos[2] += height_adjustment
                
                trajectory[t] = base_pos
        
        # Apply track-specific adjustments
        adjusted_trajectory = self.adjust_for_track_geometry(robot_idx, trajectory)
        
        # Return the adjusted trajectory
        return adjusted_trajectory
    
    def execute_step(self, trajectories, sim_time=None, verbose=False):
        """Execute step with fine-tuned parameters for custom track"""
        if verbose:
            print(f"Process {os.getpid()}: Starting execute_step at t={sim_time}")
            print(f"Initial drone positions: {[self.drones[i].state[0:3] for i in range(self.num_robots)]}")
        
        # Store initial positions for validation
        starting_positions = [self.drones[i].state[0:3].copy() for i in range(self.num_robots)]
        
        # Compute velocity commands
        velocity_commands = np.zeros((self.num_robots, 3))
        for i in range(self.num_robots):
            if trajectories.shape[1] > 1:
                velocity_commands[i] = (trajectories[i, 1] - trajectories[i, 0]) / self.dt
        
        # Apply more adaptive gate targeting
        positions = self.get_drone_positions()
        for i in range(self.num_robots):
            gate_idx = self.current_gate_indices[i]
            gate_pos = self.gate_positions[gate_idx]
            gate_yaw = self.gate_yaws[gate_idx]
            
            # Gate normal direction (adjusted for custom track)
            gate_normal = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
            
            # Vector from drone to gate
            to_gate = gate_pos - positions[i]
            dist_to_gate = np.linalg.norm(to_gate)
            
            # Check if we're behind the gate (dot product with normal is positive)
            behind_gate = np.dot(to_gate, gate_normal) > 0
            
            # Calculate vertical component separately
            height_diff = gate_pos[2] - positions[i][2]
            
            # Fine-tuned boosting parameters for custom track:
            if behind_gate:
                # Adaptive boost parameters based on gate geometry
                boost_start_dist = 5.0  # Start boosting earlier
                boost_factor_base = 1.5  # Base boost factor
                
                if dist_to_gate < boost_start_dist:
                    vel_norm = np.linalg.norm(velocity_commands[i])
                    
                    if vel_norm > 0.1:
                        vel_dir = velocity_commands[i] / vel_norm
                        
                        # More permissive alignment requirement for custom track
                        required_alignment = min(0.4, 0.2 + 0.3 * (1.0 - dist_to_gate/boost_start_dist))
                        alignment = np.dot(vel_dir, gate_normal)
                        
                        if alignment > required_alignment:
                            # Adaptive boost factor - stronger when closer
                            progress = 1.0 - dist_to_gate/boost_start_dist
                            boost_factor = boost_factor_base + 0.8 * progress * (1.0 - progress * 0.5)
                            
                            # Adjust direction with stronger height component for custom track
                            direction_weight = min(0.5, 0.2 + 0.3 * progress)
                            
                            # For custom track: add height adjustment to direction
                            adjusted_dir = vel_dir.copy()
                            if abs(height_diff) > 0.3:
                                # Add stronger vertical component when height difference is significant
                                height_factor = min(0.3, abs(height_diff) * 0.1) * np.sign(height_diff)
                                adjusted_dir[2] += height_factor
                                # Renormalize
                                adjusted_dir = adjusted_dir / np.linalg.norm(adjusted_dir)
                            
                            # Blend with gate normal direction
                            adjusted_dir = (1.0 - direction_weight) * adjusted_dir + direction_weight * gate_normal
                            adjusted_dir = adjusted_dir / np.linalg.norm(adjusted_dir)
                            
                            # Apply boosted velocity
                            velocity_commands[i] = adjusted_dir * vel_norm * boost_factor
                            
                            # Ensure we don't exceed maximum velocity with a buffer
                            for dim in range(3):
                                max_vel = self.max_velocity[dim] * 0.95  # 5% buffer
                                velocity_commands[i][dim] = np.clip(velocity_commands[i][dim], 
                                                            -max_vel, max_vel)
        
        # Execute drone dynamics with enhanced collision checking
        all_states = [[self.drones[i].state.copy()] for i in range(self.num_robots)]
        
        # Track positions for collision checking
        timestep_positions = []
        timestep_velocities = []
        
        # Run dynamics steps
        for step in range(self.dynamics_steps):
            step_positions = []
            step_velocities = []
            
            # Update all drones
            for i in range(self.num_robots):
                # More cautious velocity control when drones are close
                adjusted_velocity = velocity_commands[i].copy()
                
                # Check proximity to other drones
                for j in range(self.num_robots):
                    if i != j:
                        dist = np.linalg.norm(self.drones[i].state[0:3] - self.drones[j].state[0:3])
                        # If drones are close, reduce velocity
                        if dist < self.min_distance * 1.2:
                            vel_norm = np.linalg.norm(adjusted_velocity)
                            if vel_norm > 0.5:
                                # Scale velocity down when close to avoid collisions
                                reduction = max(0.5, dist / (self.min_distance * 1.5)) 
                                adjusted_velocity = adjusted_velocity * reduction
                
                # Apply track guidance for off-track recovery
                position = self.drones[i].state[0:3]
                gate_idx = self.current_gate_indices[i]
                current_gate = self.gate_positions[gate_idx]
                next_gate_idx = (gate_idx + 1) % self.n_gates
                next_gate = self.gate_positions[next_gate_idx]
                
                # Calculate track vector
                track_vector = next_gate - current_gate
                track_length = np.linalg.norm(track_vector)
                
                if track_length > 0:
                    track_direction = track_vector / track_length
                    t = np.clip(np.dot(position - current_gate, track_direction) / track_length, 0, 1)
                    closest_point = current_gate + t * track_vector
                    
                    # Distance from drone to track
                    to_track = closest_point - position
                    dist_to_track = np.linalg.norm(to_track)
                    
                    # Apply mild correction if moving away from track
                    if dist_to_track > 4.0:
                        recovery_direction = to_track / dist_to_track
                        vel_norm = np.linalg.norm(adjusted_velocity)
                        
                        if vel_norm > 0.1:
                            current_dir = adjusted_velocity / vel_norm
                            # Check if moving away from track
                            if np.dot(current_dir, recovery_direction) < 0:
                                # Blend in recovery direction
                                correction_strength = min(0.3, dist_to_track / 20.0)
                                new_dir = current_dir + recovery_direction * correction_strength
                                new_dir = new_dir / np.linalg.norm(new_dir)
                                adjusted_velocity = new_dir * vel_norm
                
                # Apply the adjusted control
                control = self.controllers[i].compute_control(adjusted_velocity)
                self.drones[i].step(control)
                
                step_positions.append(self.drones[i].state[0:3])
                step_velocities.append(self.drones[i].state[3:6])
                
                # Check for gate passing with custom track detection
                if self.check_gate_passed(i, verbose=verbose):
                    if verbose:
                        print(f"Drone {i} passed gate {self.current_gate_indices[i]}")
                    
                    # Update to next gate
                    old_gate_idx = self.current_gate_indices[i]
                    self.current_gate_indices[i] = (self.current_gate_indices[i] + 1) % self.n_gates
                    new_gate_idx = self.current_gate_indices[i]
                    
                    # Immediately update velocity command to target new gate
                    if step < self.dynamics_steps - 1:
                        new_gate_pos = self.gate_positions[new_gate_idx]
                        new_gate_yaw = self.gate_yaws[new_gate_idx]
                        
                        # Adjusted for custom track orientation
                        gate_normal = np.array([np.cos(new_gate_yaw - np.pi/2), np.sin(new_gate_yaw - np.pi/2), 0])
                        
                        # Direct velocity toward new gate with height consideration
                        to_new_gate = new_gate_pos - self.drones[i].state[0:3]
                        dist_to_new_gate = np.linalg.norm(to_new_gate)
                        
                        if dist_to_new_gate > 0:
                            # Calculate direction with additional height emphasis for custom track
                            xy_dist = np.linalg.norm(to_new_gate[:2])
                            height_diff = to_new_gate[2]
                            
                            # Stronger vertical component for significant height differences
                            height_emphasis = 1.0
                            if abs(height_diff) > 1.0:
                                height_emphasis = 1.3  # Increase emphasis for large height changes
                            
                            # Create adjusted direction vector
                            new_dir = to_new_gate.copy()
                            new_dir[2] *= height_emphasis  # Emphasize height component
                            new_dir = new_dir / np.linalg.norm(new_dir)  # Normalize
                            
                            # Set a strong initial velocity toward the new gate
                            vel_magnitude = np.linalg.norm(velocity_commands[i])
                            velocity_commands[i] = new_dir * max(vel_magnitude, 2.0)
                            
                            # Ensure we don't exceed maximum velocity
                            for dim in range(3):
                                velocity_commands[i][dim] = np.clip(velocity_commands[i][dim], 
                                                            -self.max_velocity[dim],
                                                            self.max_velocity[dim])
                        
                        if verbose:
                            print(f"Updated velocity command for drone {i} to target gate {new_gate_idx}")
            
            timestep_positions.append(np.array(step_positions))
            timestep_velocities.append(np.array(step_velocities))
        
        # Check for collisions
        collisions = self.check_for_collisions()
        
        # Apply collision recovery
        if collisions:
            if verbose:
                print(f"WARNING: {len(collisions)} collisions detected.")
            # Record collisions and apply recovery
            for i, j, _ in collisions:
                # Record the collision in a counter
                if not hasattr(self, '_collision_counters'):
                    self._collision_counters = {}
                
                key1 = (min(i, j), max(i, j))
                self._collision_counters[key1] = self._collision_counters.get(key1, 0) + 1
                
                # Apply evasive action for repeated collisions
                if self._collision_counters.get(key1, 0) > 2:
                    if verbose:
                        print(f"Repeated collisions between drones {i} and {j} - applying evasive action")
                    # Apply stronger separation force in the next planning step
                    if not hasattr(self, '_evasive_pairs'):
                        self._evasive_pairs = set()
                    self._evasive_pairs.add(key1)
        
        # Save final states
        for i in range(self.num_robots):
            all_states[i].append(self.drones[i].state.copy())
            distance_moved = np.linalg.norm(self.drones[i].state[0:3] - starting_positions[i])
            max_possible_movement = self.max_velocity.max() * self.dt
        
            if distance_moved > max_possible_movement and verbose:
                print(f"ERROR: Drone {i} moved {distance_moved:.2f}m in one step")
                print(f"From: {starting_positions[i]}")
                print(f"To: {self.drones[i].state[0:3]}")
                print(f"Max theoretical distance: {max_possible_movement:.2f}m")
        
        return all_states, collisions
    
    def check_for_collisions(self):
        """Fast collision detection using vectorized operations"""
        positions = self.get_drone_positions()
        collisions = []
        
        # Only check drone pairs that were close in the previous step
        if not hasattr(self, '_close_pairs'):
            self._close_pairs = []
            # Check all pairs initially
            for i in range(self.num_robots):
                for j in range(i+1, self.num_robots):
                    self._close_pairs.append((i, j))
        
        new_close_pairs = []
        
        # Check only previously close pairs, and add new close pairs for next time
        for i, j in self._close_pairs:
            dist = np.linalg.norm(positions[i] - positions[j])
            
            if dist < self.min_distance * self.collision_coeff:
                collisions.append((i, j, dist))
            
            # Keep tracking if still relatively close
            if dist < self.min_distance * 3.0:
                new_close_pairs.append((i, j))
        
        # Also check neighbors to ensure we don't miss any collisions
        for i in range(self.num_robots):
            for j in self.neighbors[i]:
                if i < j and (i, j) not in new_close_pairs:
                    dist = np.linalg.norm(positions[i] - positions[j])
                    
                    if dist < self.min_distance * self.collision_coeff:
                        collisions.append((i, j, dist))
                    
                    if dist < self.min_distance * 3.0:
                        new_close_pairs.append((i, j))
        
        self._close_pairs = new_close_pairs
        return collisions
    
    def simulate(self, num_steps=50, visualize=True, save_path=None, create_video=False, verbose=False):
        """Simulation loop optimized for custom track"""
        # Initialize progress tracking
        self._simulation_step = 0
        self._gates_passed_total = 0
        
        # Add last positions tracking
        last_positions = self.get_drone_positions().copy()
        
        # Minimal data storage
        all_positions = [self.get_drone_positions()]
        all_velocities = [self.get_drone_velocities()]
        all_gate_indices = [self.current_gate_indices.copy()]
        all_collisions = []
        
        sim_time = 0.0
        total_collision_count = 0
        gates_passed = [0] * self.num_robots
        
        # Visualization interval
        vis_interval = 2
        
        # Print initial state for debugging
        if verbose:
            positions = self.get_drone_positions()
            print("\nINITIAL STATE:")
            for i in range(self.num_robots):
                gate_idx = self.current_gate_indices[i]
                dist = np.linalg.norm(positions[i] - self.gate_positions[gate_idx])
                print(f"Drone {i}: {dist:.2f}m from gate {gate_idx}")
            print()
        
        saved_frames = []
        for step in range(num_steps):
            self._simulation_step = step
            if verbose:
                print(f"\n--- Simulation step {step}/{num_steps} ---")
            
            # Record current gate indices
            current_gates = self.current_gate_indices.copy()
            
            # Validate gate targets periodically
            if step > 20 and step % 5 == 0:
                if self.validate_gate_targets() and verbose:
                    print("Corrected gate targets for off-track drones")
            
            # Plan
            plan_result = self.plan(verbose=(verbose and step % 10 == 0))
            
            # Execute
            _, step_collisions = self.execute_step(plan_result['trajectories'], sim_time, verbose=verbose and step % 10 == 0)
            
            # Check for and recover off-track drones
            if step > 10:
                if self.check_and_recover_off_track_drones() and verbose:
                    print("Applied recovery to off-track drones")
            
            # Track collisions
            if step_collisions:
                all_collisions.append((step, step_collisions))
                total_collision_count += len(step_collisions)
            
            # Check gates passed and update total
            step_gates_passed = 0
            for i in range(self.num_robots):
                if self.current_gate_indices[i] != current_gates[i]:
                    gates_passed[i] += 1
                    step_gates_passed += 1
                    if verbose:
                        old_gate = current_gates[i]
                        new_gate = self.current_gate_indices[i]
                        print(f"SUCCESS! Drone {i} passed gate {old_gate}, now targeting gate {new_gate}")
            
            self._gates_passed_total += step_gates_passed
            
            # Force progress if stuck
            if step > 30 and self._gates_passed_total == 0 and step % 10 == 0:
                # Force progress after many steps if no gates have been passed
                if verbose:
                    print("WARNING: No gates passed after many steps - forcing progress")
                self.force_gate_progress(verbose=verbose)
            
            # Check if any drone is far from both its current and next gate
            if step % 10 == 0 and step > 0:
                positions = self.get_drone_positions()
                
                for i in range(self.num_robots):
                    gate_idx = self.current_gate_indices[i]
                    current_gate = self.gate_positions[gate_idx]
                    next_gate_idx = (gate_idx + 1) % self.n_gates
                    next_gate = self.gate_positions[next_gate_idx]
                    
                    dist_to_current = np.linalg.norm(positions[i] - current_gate)
                    dist_to_next = np.linalg.norm(positions[i] - next_gate)
                    
                    # If drone is far from both gates, it's likely off track
                    if dist_to_current > 15.0 and dist_to_next > 15.0:
                        if verbose:
                            print(f"WARNING: Drone {i} appears to be off track (distances: {dist_to_current:.1f}, {dist_to_next:.1f})")
                        
                        # Extreme intervention - teleport back to a point between gates
                        if dist_to_current > 20.0 and dist_to_next > 20.0:
                            if verbose:
                                print(f"EMERGENCY: Teleporting drone {i} back to track")
                            
                            # Calculate a position on the track
                            track_point = (current_gate + next_gate) / 2
                            
                            # Move drone back to track with reduced velocity
                            self.drones[i].state[0:3] = track_point
                            self.drones[i].state[3:6] *= 0.1  # Reduce velocity
            
            # Special handling for custom track - detect specific stuck situations
            if step > 20 and step % 20 == 0:
                positions = self.get_drone_positions()
                # Check if drones haven't moved much in the last 20 steps
                stuck_drones = []
                for i in range(self.num_robots):
                    dist_moved = np.linalg.norm(positions[i] - last_positions[i])
                    if dist_moved < 0.5:  # Very little movement
                        stuck_drones.append(i)
                
                # If any drones are stuck, help them progress
                if stuck_drones and verbose:
                    print(f"WARNING: Drones {stuck_drones} appear to be stuck - applying intervention")
                    for drone_idx in stuck_drones:
                        # Aggressive intervention - move drone forward a bit
                        gate_idx = self.current_gate_indices[drone_idx]
                        gate_pos = self.gate_positions[gate_idx]
                        gate_yaw = self.gate_yaws[gate_idx]
                        
                        # Move in direction of gate
                        to_gate = gate_pos - positions[drone_idx]
                        dist = np.linalg.norm(to_gate)
                        if dist > 0:
                            move_dir = to_gate / dist
                            # Apply small boost to help unstick
                            boost_vel = move_dir * 0.5  # Gentle push
                            # Add to drone velocity
                            self.drones[drone_idx].state[3:6] += boost_vel
                
                # Update last positions
                last_positions = positions.copy()
            
            # Update time
            sim_time += self.dt
            
            # Record state (only every few steps to save memory)
            if step % 2 == 0 or step == num_steps-1:
                all_positions.append(self.get_drone_positions())
                all_velocities.append(self.get_drone_velocities())
                all_gate_indices.append(self.current_gate_indices.copy())
            
            # Visualize
            should_visualize = (
                visualize and 
                (step % vis_interval == 0 or step == num_steps-1 or step_collisions)
            )
            
            if should_visualize:
                fig = self.visualize(plan_result, sim_time=sim_time, collisions=step_collisions)
                if save_path:
                    frame_path = f"{save_path}/step_{step:03d}.png"
                    plt.savefig(frame_path, dpi=80)  # Lower dpi
                    saved_frames.append(frame_path)
                plt.close(fig)
            
            # Debug: Print current distances to gates every 10 steps
            if step % 10 == 0 and verbose:
                positions = self.get_drone_positions()
                print("\nCURRENT STATE:")
                for i in range(self.num_robots):
                    gate_idx = self.current_gate_indices[i]
                    dist = np.linalg.norm(positions[i] - self.gate_positions[gate_idx])
                    gate_yaw = self.gate_yaws[gate_idx]
                    gate_normal = np.array([np.cos(gate_yaw - np.pi/2), np.sin(gate_yaw - np.pi/2), 0])
                    to_gate = self.gate_positions[gate_idx] - positions[i]
                    signed_dist = np.dot(to_gate, gate_normal)
                    behind_gate = signed_dist > 0
                    status = "behind gate" if behind_gate else "in front of gate"
                    print(f"Drone {i}: {dist:.2f}m from gate {gate_idx} ({status}, signed dist: {signed_dist:.2f})")
                print()
        
        # Calculate velocity statistics
        velocities_array = np.array(all_velocities)
        velocity_stats = self.calculate_mean_velocities(velocities_array)
        
        if verbose:
            print(f"\nSimulation completed:")
            print(f" - Time: {sim_time:.2f} seconds")
            print(f" - Total collisions: {total_collision_count}")
            print(f" - Overall mean speed: {velocity_stats['overall_mean_speed']:.2f} m/s")
            print(f" - Drone mean speeds: {[f'{v:.2f}' for v in velocity_stats['drone_mean_speeds']]}")
            print(f" - Gates passed: {gates_passed}")

        if create_video and save_path and saved_frames:
            video_path = f"{save_path}/simulation_video.mp4"
            self.create_video_from_frames(save_path, video_path)
        
        # Return simulation results
        return {
            'positions': np.array(all_positions),
            'velocities': np.array(all_velocities),
            'sim_time': sim_time,
            'collisions': all_collisions,
            'total_collision_count': total_collision_count,
            'velocity_stats': velocity_stats,
            'gates_passed': gates_passed
        }
    
    def force_gate_progress(self, verbose=False):
        """Development function to force gate progress when drones are stuck"""
        # Force all drones to the next gate
        if verbose:
            print("FORCING ALL DRONES TO NEXT GATE")
        for i in range(self.num_robots):
            old_idx = self.current_gate_indices[i]
            self.current_gate_indices[i] = (self.current_gate_indices[i] + 1) % self.n_gates
            if verbose:
                print(f"  - Drone {i}: Gate {old_idx} -> Gate {self.current_gate_indices[i]}")
        return True
    
    def calculate_mean_velocities(self, velocities):
        """Calculate mean velocity for each drone and overall mean"""
        # velocities shape: [timesteps, num_drones, 3]
        
        # Mean velocity magnitude per drone over time
        drone_mean_speeds = []
        for i in range(self.num_robots):
            drone_velocities = velocities[:, i, :]
            # Calculate velocity magnitude at each timestep
            speeds = np.linalg.norm(drone_velocities, axis=1)
            mean_speed = np.mean(speeds)
            drone_mean_speeds.append(mean_speed)
        
        # Overall mean velocity magnitude
        overall_mean_speed = np.mean(drone_mean_speeds)
        
        return {
            'drone_mean_speeds': drone_mean_speeds,
            'overall_mean_speed': overall_mean_speed
        }
    
    def create_video_from_frames(self, frame_directory, output_file="simulation_video.mp4", fps=10):
        """Create a video from a directory of frame images"""
        try:
            import cv2
            import os
            import glob
            
            # Find all frame files
            frame_files = sorted(glob.glob(os.path.join(frame_directory, "step_*.png")))
            
            if not frame_files:
                print(f"No frame files found in {frame_directory}")
                return None
            
            print(f"Creating video from {len(frame_files)} frames...")
            
            # Read first frame to get dimensions
            first_frame = cv2.imread(frame_files[0])
            height, width, layers = first_frame.shape
            
            # Create video writer
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Use mp4v codec
            video = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
            
            # Add each frame to the video
            for frame_file in frame_files:
                video.write(cv2.imread(frame_file))
            
            # Release the video writer
            video.release()
            
            print(f"Video created successfully: {output_file}")
            return output_file
        
        except ImportError:
            print("Error: OpenCV (cv2) is required to create videos.")
            print("Install it with: pip install opencv-python")
            return None
        except Exception as e:
            print(f"Error creating video: {str(e)}")
            return None
    
    def visualize(self, plan_result=None, show_history=False, sim_time=None, collisions=None):
        """Render a 3D snapshot of the track, drones and planned trajectories.

        Args:
            plan_result: Optional planner output. Its 'trajectories' entry is
                indexed as ``[drone, step, xyz]`` and drawn downsampled to 5
                points; 'trajectory_history' is used when ``show_history``.
            show_history: Also draw the first and last entries of
                ``plan_result['trajectory_history']`` as dashed lines.
            sim_time: Simulation time appended to the title, if given.
            collisions: Iterable of collision tuples whose first two entries
                are drone indices; those drones are drawn in red with a
                wireframe sphere of radius ``min_distance / 2``.

        Returns:
            The created matplotlib Figure; the caller saves/closes it.
        """
        # Smaller, low-dpi figure when not interactive and no other figures
        # are open (the usual batch / frame-saving case).
        if not plt.isinteractive() and not plt.get_fignums():
            fig = plt.figure(figsize=(8, 6), dpi=80)
        else:
            fig = plt.figure(figsize=(12, 10))

        ax = fig.add_subplot(111, projection='3d')

        # World boundary circle at z=0 (coarse, 30 points)
        theta = np.linspace(0, 2*np.pi, 30)
        x = self.world_size/2 * np.cos(theta)
        y = self.world_size/2 * np.sin(theta)
        z = np.zeros_like(theta)
        ax.plot(x, y, z, color="k", alpha=0.3)

        # Track path: dashed segment plus a small direction arrow between consecutive gates
        for i in range(self.n_gates):
            next_i = (i + 1) % self.n_gates
            gate_pos = self.gate_positions[i]
            next_gate_pos = self.gate_positions[next_i]

            ax.plot([gate_pos[0], next_gate_pos[0]],
                    [gate_pos[1], next_gate_pos[1]],
                    [gate_pos[2], next_gate_pos[2]],
                    'k--', alpha=0.3, linewidth=1.5)

            mid_point = (gate_pos + next_gate_pos) / 2
            direction = next_gate_pos - gate_pos
            if np.linalg.norm(direction) > 0:
                direction = direction / np.linalg.norm(direction) * 0.5  # Scale arrow
                ax.quiver(mid_point[0], mid_point[1], mid_point[2],
                         direction[0], direction[1], direction[2],
                         color='gray', alpha=0.5, arrow_length_ratio=0.3)

        # Gates targeted by at least one drone are highlighted
        active_gates = set(self.current_gate_indices)

        for i, gate in enumerate(self.gates):
            center = gate["center"]
            yaw = gate["yaw"]
            is_active = i in active_gates

            # Gate ring drawn as a 1 m radius circle (15 points)
            gate_radius = 1.0
            theta = np.linspace(0, 2*np.pi, 15)
            circle_x = gate_radius * np.cos(theta)
            circle_y = gate_radius * np.sin(theta)
            circle_z = np.zeros_like(theta)

            # Rotate by (yaw - pi/2), matching the gate-normal convention used elsewhere
            rot_x = np.cos(yaw - np.pi/2) * circle_x - np.sin(yaw - np.pi/2) * circle_y
            rot_y = np.sin(yaw - np.pi/2) * circle_x + np.cos(yaw - np.pi/2) * circle_y

            gate_x = center[0] + rot_x
            gate_y = center[1] + rot_y
            gate_z = center[2] + circle_z

            if is_active:
                gate_color = 'red'
                alpha = 0.8
                linewidth = 2.5
            else:
                gate_color = 'orange'
                alpha = 0.5
                linewidth = 1.5

            ax.plot(gate_x, gate_y, gate_z, color=gate_color, linestyle='-',
                    linewidth=linewidth, alpha=alpha)

            # Gate normal (direction vector)
            gate_normal = np.array([np.cos(yaw - np.pi/2), np.sin(yaw - np.pi/2), 0])
            ax.quiver(center[0], center[1], center[2],
                     gate_normal[0], gate_normal[1], gate_normal[2],
                     color=gate_color, alpha=alpha, length=1.0, arrow_length_ratio=0.2)

            if is_active:
                ax.text(center[0], center[1], center[2]+1.0, f"G{i}",
                        color='red', fontsize=10, ha='center', weight='bold')
            else:
                ax.text(center[0], center[1], center[2]+0.8, f"G{i}",
                        color='darkred', fontsize=8, ha='center', alpha=0.7)

        # Drones involved in collisions
        collided_drones = set()
        if collisions:
            for collision in collisions:
                collided_drones.add(collision[0])
                collided_drones.add(collision[1])

        # Drone positions and velocities (fetched once, not per drone)
        positions = self.get_drone_positions()
        velocities = self.get_drone_velocities()
        drone_colors = ['b', 'g', 'c', 'm', 'y']

        for i in range(self.num_robots):
            color = drone_colors[i % len(drone_colors)]
            current_gate_idx = self.current_gate_indices[i]

            marker_size = 100
            if i in collided_drones:
                color = 'red'
                marker_size = 200

            ax.scatter(positions[i, 0], positions[i, 1], positions[i, 2],
                    color=color, s=marker_size, label=f'D{i} (G{current_gate_idx})')

            # Safety sphere only for collided drones (coarse wireframe)
            if i in collided_drones:
                u, v = np.mgrid[0:2*np.pi:10j, 0:np.pi:5j]
                x = positions[i, 0] + self.min_distance/2 * np.cos(u) * np.sin(v)
                y = positions[i, 1] + self.min_distance/2 * np.sin(u) * np.sin(v)
                z = positions[i, 2] + self.min_distance/2 * np.cos(v)
                ax.plot_wireframe(x, y, z, color='red', alpha=0.2)

            # Dotted line to the current target gate
            current_gate = self.gate_positions[current_gate_idx]
            ax.plot([positions[i, 0], current_gate[0]],
                    [positions[i, 1], current_gate[1]],
                    [positions[i, 2], current_gate[2]],
                    color=color, linestyle=':', alpha=0.5)

            # Velocity vector (skipped when nearly stationary)
            vel = velocities[i]
            vel_norm = np.linalg.norm(vel)
            if vel_norm > 0.1:
                vel_scale = 0.5
                ax.quiver(positions[i, 0], positions[i, 1], positions[i, 2],
                         vel[0], vel[1], vel[2],
                         color=color, alpha=0.7, length=vel_scale, arrow_length_ratio=0.2)

        # Planned trajectories
        if plan_result is not None:
            trajectories = plan_result['trajectories']

            if show_history and 'trajectory_history' in plan_result and len(plan_result['trajectory_history']) > 1:
                history = plan_result['trajectory_history']
                history_to_show = [history[0], history[-1]]
                for iter_idx, iter_trajectories in enumerate(history_to_show):
                    alpha = 0.3 if iter_idx == 0 else 0.7
                    for i in range(self.num_robots):
                        color = drone_colors[i % len(drone_colors)]
                        # Downsample to 5 points for faster plotting
                        idx = np.linspace(0, iter_trajectories.shape[1]-1, 5).astype(int)
                        ax.plot(iter_trajectories[i, idx, 0], iter_trajectories[i, idx, 1],
                                iter_trajectories[i, idx, 2], color=color, alpha=alpha, linestyle='--')

            for i in range(self.num_robots):
                color = drone_colors[i % len(drone_colors)]
                if i in collided_drones:
                    color = 'red'
                # Downsample to 5 points for faster plotting
                idx = np.linspace(0, trajectories.shape[1]-1, 5).astype(int)
                ax.plot(trajectories[i, idx, 0], trajectories[i, idx, 1], trajectories[i, idx, 2],
                        color=color, linewidth=2)

        title = 'Multi-Drone Custom Track Navigation'
        if sim_time is not None:
            title += f' t={sim_time:.1f}s'
        if collisions:
            title += f' - {len(collisions)} COLLISIONS'

        ax.set_title(title)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')

        # Axis extents from gate centres, floored so a degenerate layout
        # (e.g. all gates on one axis) never produces zero-width limits.
        min_extent = 1.5
        max_x = max(np.max(np.abs([gate["center"][0] for gate in self.gates])) * 1.3, min_extent)
        max_y = max(np.max(np.abs([gate["center"][1] for gate in self.gates])) * 1.3, min_extent)
        max_z = np.max([gate["center"][2] for gate in self.gates]) * 1.5

        ax.set_xlim(-max_x, max_x)
        ax.set_ylim(-max_y, max_y)
        ax.set_zlim(0, max(5.0, max_z))

        ax.legend(fontsize='small', loc='upper right')

        return fig


# Keep the TrajectoryPredictor class and supporting functions
class TrajectoryPredictor:
    """Minimal heuristic trajectory predictor (no neural network / TensorFlow).

    ``predict`` builds a short sequence of 3-D waypoints. The first waypoint
    continues the current velocity, nudged away from nearby neighbours. Later
    waypoints blend toward a straight-line path from the current position to
    ``target``.
    """

    def __init__(self, state_dim, traj_dim):
        """
        Args:
            state_dim: Dimension of the state vector. It is stored only and is
                not used by ``predict``.
            traj_dim: Flattened trajectory length. ``predict`` returns
                ``traj_dim // 3`` waypoints.
        """
        self.state_dim = state_dim
        self.traj_dim = traj_dim
        # Shape (-1, 3) for a flattened xyz trajectory; not used inside this class.
        self.output_shape = (-1, 3)

    def predict(self, state, target, neighbors):
        """Predict a waypoint trajectory from ``state`` toward ``target``.

        Args:
            state: Array-like whose entries [0:3] are position and [3:6] are
                velocity.
            target: Array-like 3-D target position.
            neighbors: Iterable of neighbour states (entries [0:3] are
                position), a 2-D NumPy array of such states, or ``None``.

        Returns:
            ``np.ndarray`` of shape ``(traj_dim // 3, 3)`` with the waypoints.
        """
        n_steps = self.traj_dim // 3
        traj = np.zeros((n_steps, 3))

        # Accept lists/tuples as well as arrays; plain lists would break the
        # scalar arithmetic below (e.g. ``list * 0.2``).
        state = np.asarray(state, dtype=float)
        target = np.asarray(target, dtype=float)
        pos = state[:3]
        vel = state[3:6]

        # Small repulsive offset away from every neighbour within 3 m.
        # Explicit ``None`` check: ``if neighbors:`` raises ValueError when
        # neighbours arrive as a multi-element NumPy array.
        adjust = np.zeros(3)
        if neighbors is not None:
            for n_state in neighbors:
                n_pos = np.asarray(n_state, dtype=float)[:3]
                diff = pos - n_pos
                dist = np.linalg.norm(diff)
                if 0 < dist < 3.0:
                    adjust += diff / dist * 0.2

        for i in range(n_steps):
            # Normalised progress along the horizon, in [0, 1].
            t = i / (n_steps - 1) if n_steps > 1 else 0

            if i == 0:
                # First waypoint: follow the current velocity plus the repulsion.
                traj[i] = pos + vel * 0.2 + adjust
            else:
                # Blend a decaying velocity projection with the straight line
                # pos -> target; blend reaches 1.0 from the third waypoint on.
                vel_proj = traj[i - 1] + vel * 0.2 * (1 - t)
                target_proj = target * t + pos * (1 - t)
                blend = min(1.0, i / 2)
                traj[i] = vel_proj * (1 - blend) + target_proj * blend

        return traj


def run_simulation(num_robots=3, num_gates=6, num_steps=50, visualize=True, save_path=None):
    """Run one simulation episode and print summary statistics.

    The simulator instance is cached on the function as
    ``run_simulation._sim``. Repeated calls reuse it and only call its
    ``reset()``. The cached instance is rebuilt whenever ``num_robots`` or
    ``num_gates`` differ from the values it was constructed with, so a changed
    configuration is never silently ignored.

    Args:
        num_robots: Number of drones.
        num_gates: Number of gates on the track.
        num_steps: Number of simulation steps passed to ``simulate``.
        visualize: Forwarded to ``simulate``.
        save_path: Output directory, created if missing. Forwarded to
            ``simulate``.

    Returns:
        The results dict returned by ``simulate``. This function reads the
        keys 'sim_time', 'total_collision_count', 'gates_passed' and
        'velocity_stats'.
    """
    if save_path:
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(save_path, exist_ok=True)

    # Rebuild the cached simulator if it is missing or was built for a
    # different robot/gate configuration; otherwise just reset it.
    config = (num_robots, num_gates)
    sim = getattr(run_simulation, '_sim', None)
    if sim is None or getattr(run_simulation, '_sim_config', None) != config:
        print("Initializing simulation...")
        sim = SEIBR_DroneNavigation_FigEight(
            num_robots=num_robots,
            planning_horizon=10,
            dt=0.1,
            max_iterations=5,
            dynamics_dt=0.01,
            n_gates=num_gates,
            radius=6.0
        )
        run_simulation._sim = sim
        run_simulation._sim_config = config
    else:
        sim.reset()

    results = sim.simulate(num_steps=num_steps, visualize=visualize, save_path=save_path, create_video=True)

    # Print summary statistics
    print("\nSimulation Results:")
    print(f"Total simulation time: {results['sim_time']:.2f} seconds")
    print(f"Total collisions: {results['total_collision_count']}")
    print(f"Gates passed by each drone: {results['gates_passed']}")
    print(f"Mean speeds: {[f'{v:.2f} m/s' for v in results['velocity_stats']['drone_mean_speeds']]}")
    print(f"Overall mean speed: {results['velocity_stats']['overall_mean_speed']:.2f} m/s")

    return results

def evaluate_algorithm(num_robots=3, num_gates=6, num_episodes=10, num_steps_per_episode=50, 
                       visualize=False, save_path=None):
    """Evaluate the algorithm over several episodes and aggregate metrics.

    Each episode calls ``run_simulation`` and derives four metrics:
      * reward: total gates passed by all drones, minus 1 if any collision
        occurred in the episode;
      * speed: ``velocity_stats['overall_mean_speed']``;
      * targets reached: total gates passed divided by ``num_robots``;
      * collision flag: 1 if any collision occurred, else 0.

    Args:
        num_robots: Number of drones per episode.
        num_gates: Number of gates on the track.
        num_episodes: Number of episodes to run.
        num_steps_per_episode: Simulation steps per episode.
        visualize: Forwarded to ``run_simulation``.
        save_path: If given, episode ``k`` writes into
            ``<save_path>/episode_<k>``, and summary lines are appended
            (UTF-8) to ``<save_path>/results.txt``.

    Returns:
        dict with the per-episode lists ('rewards', 'speeds',
        'targets_reached', 'collision_flags') plus their means and standard
        deviations.
    """
    # Per-episode metrics
    rewards = []
    speeds = []
    targets_reached = []
    collision_flags = []

    for episode in range(num_episodes):
        episode_save_path = None
        if save_path:
            episode_save_path = os.path.join(save_path, f"episode_{episode+1}")
            os.makedirs(episode_save_path, exist_ok=True)

        print(f"\n=== Running Episode {episode+1}/{num_episodes} ===")

        results = run_simulation(num_robots=num_robots,
                                 num_gates=num_gates,
                                 num_steps=num_steps_per_episode,
                                 visualize=visualize,
                                 save_path=episode_save_path)

        total_gates = sum(results['gates_passed'])
        # Binary: any number of collisions in the episode counts once.
        episode_collision = 1 if results['total_collision_count'] > 0 else 0
        # Reward = gates passed, with a single penalty if the episode had a collision.
        episode_reward = total_gates - episode_collision
        episode_speed = results['velocity_stats']['overall_mean_speed']
        # Average gates reached per drone
        episode_targets = total_gates / num_robots

        rewards.append(episode_reward)
        speeds.append(episode_speed)
        targets_reached.append(episode_targets)
        collision_flags.append(episode_collision)

        print(f"Episode {episode+1} Results:")
        print(f"  Reward: {episode_reward:.2f}")
        print(f"  Speed: {episode_speed:.2f} m/s")
        print(f"  Average Targets Reached: {episode_targets:.2f}")
        print(f"  Collision Flag: {episode_collision}")

    # Aggregate statistics across episodes
    mean_reward = np.mean(rewards)
    mean_speed = np.mean(speeds)
    std_speed = np.std(speeds)
    mean_targets_reached = np.mean(targets_reached)
    std_targets_reached = np.std(targets_reached)
    mean_collisions = np.mean(collision_flags)
    std_collisions = np.std(collision_flags)

    print(f"\nEvaluation results: Mean Reward={mean_reward:.2f}, Mean Speed={mean_speed:.2f}/±{std_speed:.2f}, "
          f"Mean Targets Reached={mean_targets_reached:.2f}/±{std_targets_reached:.2f}, "
          f"Mean Collisions={mean_collisions:.2f}/±{std_collisions:.2f}")

    if save_path:
        results_txt_path = os.path.join(save_path, "results.txt")
        # Explicit UTF-8: the lines contain the non-ASCII "±", and the default
        # open() encoding is platform/locale dependent.
        with open(results_txt_path, "a", encoding="utf-8") as f:
            f.write(f"mean_targets_reached: {mean_targets_reached:.2f}/±{std_targets_reached:.2f}\n")
            f.write(f"mean_velocity: {mean_speed:.2f}/±{std_speed:.2f}\n")
            f.write(f"mean_collision: {mean_collisions:.2f}/±{std_collisions:.2f}\n")
        print(f"  Saved results to {results_txt_path}")

    return {
        'rewards': rewards,
        'mean_reward': mean_reward,
        'speeds': speeds,
        'mean_speed': mean_speed,
        'std_speed': std_speed,
        'targets_reached': targets_reached,
        'mean_targets_reached': mean_targets_reached,
        'std_targets_reached': std_targets_reached,
        'collision_flags': collision_flags,
        'mean_collisions': mean_collisions,
        'std_collisions': std_collisions
    }




# Entry point: evaluate 2 drones over 10 episodes of 500 steps each.
# Guarded so that importing this module does not launch a long simulation run;
# the body still executes when run as a script or inside a notebook cell.
if __name__ == "__main__":
    results = evaluate_algorithm(num_robots=2,
                                 num_episodes=10,
                                 num_steps_per_episode=500,
                                 visualize=True,
                                 save_path="2drones_fig_eight_results")