In [1]:
#Main libraries
import gymnasium as gym
from gymnasium import spaces
from gymnasium.utils import seeding
import airsim as airsim
from airsim import MultirotorClient
import numpy as np
import os
from tqdm import tqdm
import csv
import datetime
from torch.utils.tensorboard import SummaryWriter
#General
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from typing import Dict, Any, Tuple
import logging
from logging import handlers
from copy import deepcopy
from itertools import count
import tqdm
import pickle
from pathlib import Path
import time
#PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
import torch.nn.utils as torch_utils
import asyncio
import multiprocessing
multiprocessing.set_start_method("spawn", force=True)

client = airsim.MultirotorClient()
client.confirmConnection()
print("Connected to AirSim!")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)
print(device)

Connected!
Client Ver:1 (Min Req: 1), Server Ver:1 (Min Req: 1)

Connected to AirSim!
cuda


In [24]:
class MultiDroneCosysAirSimEnv():
    def __init__(self, num_drones=5, training_stage=0):
        self.client = airsim.MultirotorClient()
        self.client.confirmConnection()
        
        self.num_drones = num_drones
        self.drone_names = [f'Drone{i}' for i in range(num_drones)]
        
        # Action and observation spaces remain the same for each drone
        self.action_space = spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-5.0, high=5.0, shape=(21,), dtype=np.float32)
        
        self.episode = 0
        self.wind_factor = 0.0
        self.success_count = 0
        self.success_threshold = 10
        self.stage_threshold = 25
        
        
        # Track state for each drone
        self.drone_states = {
            drone_name: {
                'current_step': 0,
                'success_count': 0,
                'consecutive_success': 0,
                'current_target': None,
                'prev_distance': None,
                'reseting': False,
                'resetTick': 0,
                'resetingTime': 0,
                'currentTotalReward': 0,
            } for drone_name in self.drone_names
        }
        self.training_stage = training_stage
        self.setup_stage_params()
        self.initialize_drones()
        
    def initialize_drones(self):
        """Initialize all drones with proper controls and initial positions"""
        for drone_name in self.drone_names:
            self.client.enableApiControl(True, drone_name)
            self.client.armDisarm(True, drone_name)
            
        # Generate initial target
        self.current_target = self._generate_new_target()
        for drone_name in self.drone_names:
            self.drone_states[drone_name]['current_target'] = self.current_target

    def setup_stage_params(self):
        self.params = {
            0: {  # Hovering stage
                'wind': 0.0,
                'area': [5.0, 5.0, 5.0],
                'max_steps': 200,
                'target_type': 'hover',
                'height_range': [-2.0, -1.5],  # Desired hover height
                'target_radius': 0.5
            },
            1: {  # Near vertical movement
                'wind': 0.0,
                'area': [5.0, 5.0, 8.0],
                'max_steps': 250,
                'target_type': 'vertical',
                'height_range': [-4.0, -1.0],
                'target_radius': 0.3
            },
            2: {  # Close-range horizontal movement
                'wind': 0.0,
                'area': [10.0, 10.0, 5.0],
                'max_steps': 300,
                'target_type': 'horizontal_near',
                'target_radius': 0.3,
                'max_target_dist': 5.0
            },
            3: {  # Medium-range movement
                'wind': 0.0,
                'area': [15.0, 15.0, 8.0],
                'max_steps': 400,
                'target_type': 'free_near',
                'target_radius': 0.3,
                'max_target_dist': 10.0
            },
            4: {  # Long-range movement
                'wind': 0.0,
                'area': [20.0, 20.0, 10.0],
                'max_steps': 500,
                'target_type': 'free',
                'target_radius': 0.3,
                'max_target_dist': 15.0
            },
            5: {  # Extended range movement
                'wind': 0.0,
                'area': [30.0, 30.0, 15.0],  # Expanded area
                'max_steps': 600,
                'target_type': 'free',
                'target_radius': 0.3,
                'max_target_dist': 25.0  # Increased maximum distance
            },
            6: {  # Recovery training
                'wind': 0.0,
                'area': [20.0, 20.0, 10.0],
                'max_steps': 500,
                'target_type': 'free',
                'target_radius': 0.3,
                'max_target_dist': 15.0,
                'recovery_interval': 50  # Apply random action every 50 steps
            },
            7: {  # Moderate wind
                'wind': 0.5,
                'area': [30.0, 30.0, 15.0],
                'max_steps': 600,
                'target_type': 'free',
                'target_radius': 0.5,
                'max_target_dist': 25.0
            },
            8: {  # Strong wind
                'wind': 1.0,
                'area': [30.0, 30.0, 15.0],
                'max_steps': 600,
                'target_type': 'free',
                'target_radius': 0.5,  # Slightly larger radius due to strong wind
                'max_target_dist': 25.0
            }
        }
        self.current_params = self.params[self.training_stage]

    def _generate_new_target(self):
        """Internal method to generate a new target based on current parameters"""
        params = self.current_params
        if params['target_type'] == 'hover':
            return np.array([0.0, 0.0, -1.75])
            
        elif params['target_type'] == 'vertical':
            height = np.random.uniform(*params['height_range'])
            return np.array([0.0, 0.0, height])
            
        elif params['target_type'] == 'horizontal_near':
            angle = np.random.uniform(0, 2*np.pi)
            dist = np.random.uniform(2.0, params['max_target_dist'])
            return np.array([
                dist * np.cos(angle),
                dist * np.sin(angle),
                -2.0
            ])
            
        elif params['target_type'] in ['free_near', 'free']:
            while True:
                point = np.random.uniform(
                    low=[-params['max_target_dist'], -params['max_target_dist'], -4.0],
                    high=[params['max_target_dist'], params['max_target_dist'], -1.0],
                    size=(3,)
                )
                if np.linalg.norm(point - self._get_current_position()) > 2.0:
                    return point

    def _get_target(self):
        """Return the current target"""
        return self.current_target

    def _get_current_position(self, drone_name: str):
        state = self.client.getMultirotorState(vehicle_name=drone_name)
        pos = state.kinematics_estimated.position
        return np.array([pos.x_val, pos.y_val, pos.z_val], dtype=np.float32)

    def _get_current_velocity(self, drone_name: str):
        state = self.client.getMultirotorState(vehicle_name=drone_name)
        vel = state.kinematics_estimated.linear_velocity
        return np.array([vel.x_val, vel.y_val, vel.z_val], dtype=np.float32)
    
    def _get_imu_data(self, drone_name: str):
        imu_data = self.client.getImuData(vehicle_name=drone_name)
        ang_vel = imu_data.angular_velocity
        lin_acc = imu_data.linear_acceleration
        return np.array([
            ang_vel.x_val, ang_vel.y_val, ang_vel.z_val,
            lin_acc.x_val, lin_acc.y_val, lin_acc.z_val
        ], dtype=np.float32)

    def _get_observation(self, drone_name: str):
        position = self._get_current_position(drone_name)
        velocity = self._get_current_velocity(drone_name)
        imu_data = self._get_imu_data(drone_name)
        target = self.drone_states[drone_name]['current_target']
        distance = np.linalg.norm(position - target)
        
        area_bounds = np.array(self.current_params['area'])
        normalized_pos = np.clip(position / area_bounds, -1.0, 1.0)
        velocity_scale = np.array([10.0, 10.0, 5.0])
        normalized_vel = np.clip(velocity / velocity_scale, -1.0, 1.0)
        
        # Normalize IMU data
        angular_vel = np.clip(imu_data[:3] / 5.0, -1.0, 1.0)
        linear_acc = np.clip(imu_data[3:6] / 9.81, -3.0, 3.0)  # Allow for higher G-forces
        
        # Normalize target position
        normalized_target = np.clip(target / area_bounds, -1.0, 1.0)
        
        # Normalize distance relative to maximum possible distance in current area
        max_possible_distance = np.linalg.norm(area_bounds)
        normalized_distance = np.clip(distance / max_possible_distance, 0.0, 1.0)
        
        # Height error normalization
        height_error = (position[2] - target[2]) / area_bounds[2]
        normalized_height_error = np.clip(height_error, -1.0, 1.0)

        relative_pos = (target - position) / area_bounds
        target_direction = (target - position) / (np.linalg.norm(target - position) + 1e-6)
        relative_vel = np.dot(velocity, target_direction)
        
        observation = np.concatenate([
            normalized_pos,        # [0:3]
            normalized_vel,        # [3:6]
            angular_vel,          # [6:9]
            linear_acc,          # [9:12]
            normalized_target,  # [12:15]
            relative_pos,        # [15:18]
            [normalized_distance],# [18]
            [normalized_height_error],      # [19]
            [relative_vel]       # [20]
        ])
        return observation.astype(np.float32)

    def reset_step(self, drone_name):
        #Handle the stepwise reset process for a single drone
        drone_state = self.drone_states[drone_name]
        
        if drone_state['reseting'] == True:
            current_time = time.perf_counter()
            
            try:
                # Initial reset
                if drone_state['resetTick'] == 0 and drone_state['resetingTime'] == 0:
                    drone_state['resetingTime'] = current_time
                    self.client.client.call_async(
                        "resetVehicle",
                        drone_name,
                        airsim.Pose(
                            airsim.Vector3r(0, 0, 0),  # Reset to origin
                            airsim.Quaternionr(0.0, 0.0, 0.0, 1.0)
                        )
                    )
                    
                    drone_state['resetTick'] = 1
                    
                elif drone_state['resetTick'] == 2 and current_time - drone_state['resetingTime'] > 3.0:
                    # Then arm
                    self.client.armDisarm(True, drone_name)
                    time.sleep(0.1)  # Wait for arming to complete
                    drone_state['resetTick'] = 3
                    
                # Initiate takeoff - State 3
                elif drone_state['resetTick'] == 3 and current_time - drone_state['resetingTime'] > 4.0:
                    # Use hover instead of takeoff for more stability
                    self.client.hoverAsync(vehicle_name=drone_name)
                    drone_state['resetTick'] = 4
                    
                elif drone_state['resetTick'] == 4 and time.perf_counter() - drone_state['resetingTime'] > 6.0:
                    drone_state['reseting'] = False
                    drone_state['resetTick'] = 0
                    drone_state['current_step'] = 0
                    drone_state['currentTotalReward'] = 0
                    drone_state['resetingTime'] = 0
                    drone_state['consecutive_success'] = 0
                    self.episode += 1
                    print(self.episode)
            except Exception as e:
                print(f"Error during reset for {drone_name}: {str(e)}")
                # If there's an error, try to safely disable everything
                try:
                    self.client.armDisarm(False, drone_name)
                    self.client.enableApiControl(False, drone_name)
                except:
                    pass
                # Reset the reset process
                drone_state['resetTick'] = 0
                drone_state['resetingTime'] = 0

    def reset(self, seed=None):
        #Force reset all drones regardless of state
        if seed is not None:
            self.seed(seed)
        
        observations = {}
        infos = {}
        self.client.reset()
        # Reset all drones
        for drone_name in self.drone_names:
            self.client.enableApiControl(True, drone_name)
            self.client.armDisarm(True, drone_name)
            self.client.takeoffAsync(drone_name)
            # Reset state
            self.drone_states[drone_name].update({
                'current_step': 0,
                'success_count': 0,
                'consecutive_success': 0,
                'current_target': None,
                'prev_distance': None,
                'reseting': False,
                'resetTick': 0,
                'resetingTime': 0,
                'currentTotalReward': 0
            })
            self.current_target = self._generate_new_target()
            self.drone_states[drone_name]['current_target'] = self.current_target
            
            observations[drone_name] = self._get_observation(drone_name)
            infos[drone_name] = {}
            
        return observations, infos
    
    def step(self, actions: Dict[str, np.ndarray]):
            # actions should be a dict mapping drone_name to action array
            observations = {}
            rewards = {}
            terminateds = {}
            truncateds = {}
            infos = {}

            # Filter out resetting drones
            active_drones = {name: action for name, action in actions.items() 
                            if not self.drone_states[name]['reseting']}
            
            if active_drones:
                # Handle each drone's movement individually since there's no batch roll/pitch command
                for drone_name, action in active_drones.items():
                    self.drone_states[drone_name]['current_step'] += 1
                    action = np.clip(action, -1.0, 1.0)
                    
                    # Process controls for each drone
                    max_vx, max_vy, max_vz = 3.0, 3.0, 3.0
                    max_yaw_deg = 45.0
                    vx = float(action[0]) * max_vx
                    vy = float(action[1]) * max_vy
                    vz = float(action[2]) * max_vz
                    yaw_rate_deg = float(action[3]) * max_yaw_deg
                    
                    yaw_mode = airsim.YawMode(is_rate=True, yaw_or_rate=yaw_rate_deg)

                    # Execute movement for each drone
                    self.client.moveByVelocityBodyFrameAsync(
                        vx=vx,
                        vy=vy,
                        vz=vz,
                        duration=0.2,  # small step
                        drivetrain=airsim.DrivetrainType.MaxDegreeOfFreedom,
                        yaw_mode=yaw_mode,
                        vehicle_name = drone_name
                    )

            # Handle resetting drones
            for drone_name in self.drone_names:
                if self.drone_states[drone_name]['reseting']:
                    self.reset_step(drone_name)

            # Process observations and rewards for all drones
            for drone_name in actions.keys():
                if self.drone_states[drone_name]['reseting']:
                    # For resetting drones, provide zero observations and rewards
                    observations[drone_name] = np.zeros(self.observation_space.shape, dtype=np.float32)
                    rewards[drone_name] = 0.0
                    terminateds[drone_name] = True
                    truncateds[drone_name] = False
                else:
                    # Normal processing for active drones
                    observations[drone_name] = self._get_observation(drone_name)
                    reward, terminated = self._calculate_reward(drone_name)
                    rewards[drone_name] = reward
                    terminateds[drone_name] = terminated
                    truncateds[drone_name] = self.drone_states[drone_name]['current_step'] >= self.current_params['max_steps']
                    
                    if terminated or truncateds[drone_name]:
                        self.drone_states[drone_name]['reseting'] = True

                infos[drone_name] = {
                    'success_count': self.drone_states[drone_name]['success_count'],
                    'stage': self.training_stage,
                    'resetting': self.drone_states[drone_name]['reseting']
                }

            return observations, rewards, terminateds, truncateds, infos

    def _calculate_reward(self, drone_name):
        pos = self._get_current_position(drone_name)
        vel = self._get_current_velocity(drone_name)
        angular_vel = self._get_imu_data(drone_name)[:3]
        target = self.drone_states[drone_name]['current_target']
        distance = np.linalg.norm(pos - target)
        reward_scale = 0.1
        reward = 0.0

        # Immediate failure conditions
        if self.client.simGetCollisionInfo(vehicle_name=drone_name).has_collided:
            reward -= 100 * reward_scale
            return reward, True

        if any(abs(p) > a for p, a in zip(pos, self.current_params['area'])):
            reward -= 50 * reward_scale
            return reward, True

        # Main distance reward - smoother gradient
        distance_reward = -distance + 1

        if self.training_stage == 0:  # Hovering
            # Pure hovering - focus on stability
            vertical_distance = abs(pos[2] - target[2])
            reward += (-vertical_distance + 1.0) * reward_scale
            reward += (distance_reward * 0.5) * reward_scale

        elif self.training_stage == 1:  # Vertical movement
            # Reward vertical progress toward target
            vertical_distance = abs(pos[2] - target[2])
            reward += (2.0 * np.exp(-vertical_distance)) * reward_scale
            
            # Penalize horizontal drift more strongly
            horizontal_drift = np.linalg.norm(pos[:2] - target[:2])
            reward -= (0.5 * horizontal_drift) * reward_scale
            
            reward += (distance_reward * 0.5) * reward_scale

        elif self.training_stage == 2:  # Close-range horizontal
            # Reward horizontal progress
            prev_distance = self.drone_states[drone_name]['prev_distance']
            if prev_distance is not None:
                progress = prev_distance - distance
                reward += (3.0 * progress) * reward_scale  # Stronger reward for deliberate movement
            self.drone_states[drone_name]['prev_distance'] = distance
            
            reward += distance_reward * reward_scale

        elif self.training_stage >= 3:  # Medium-range movement
            prev_distance = self.drone_states[drone_name]['prev_distance']
            if prev_distance is not None:
                progress = prev_distance - distance
                reward += (2.0 * progress) * reward_scale
            self.drone_states[drone_name]['prev_distance'] = distance
            reward += distance_reward * reward_scale

        # Success condition
        success_radius = self.current_params['target_radius']
        if distance < success_radius:
            bonus = 0
            bonus += 25.0 * reward_scale
            self.drone_states[drone_name]['consecutive_success'] += 1
            
            if self.drone_states[drone_name]['consecutive_success'] % self.success_threshold == 0:
                self.current_target = self._generate_new_target()
                # Update target for all drones since they share the same target
                for name in self.drone_names:
                    self.drone_states[name]['current_target'] = self.current_target
                bonus += 50.0 * reward_scale
                self.drone_states[drone_name]['success_count'] += 1
            
            if self.drone_states[drone_name]['success_count'] >= self.stage_threshold:
                self.training_stage = min(8, self.training_stage + 1)
                print(f"\n=== Advanced to stage {self.training_stage} ===")
                # Reset success counts for all drones when advancing stage
                for name in self.drone_names:
                    self.drone_states[name]['success_count'] = 0
                self.setup_stage_params()
                
            reward += bonus
            return reward, True

        return reward, False
    def close(self):
        self.client.reset()
        self.client.enableApiControl(False)
        self.client.armDisarm(False)

In [3]:
class ReplayBuffer:
    def __init__(self, max_size, obs_dim, act_dim, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0
        self.device = device
        # Use float32 for all arrays to match network precision
        self.states = np.zeros((max_size, obs_dim), dtype=np.float32)
        self.actions = np.zeros((max_size, act_dim), dtype=np.float32)
        self.rewards = np.zeros(max_size, dtype=np.float32)  # Flattened array
        self.next_states = np.zeros((max_size, obs_dim), dtype=np.float32)
        self.dones = np.zeros(max_size, dtype=np.float32)    # Flattened array

    def add(self, state, action, reward, next_state, done):
        # Ensure incoming data is on CPU and in numpy format
        if torch.is_tensor(state):
            state = state.detach().cpu().numpy()
        if torch.is_tensor(action):
            action = action.detach().cpu().numpy()
        if torch.is_tensor(reward):
            reward = reward.detach().cpu().numpy()
        if torch.is_tensor(next_state):
            next_state = next_state.detach().cpu().numpy()
        if torch.is_tensor(done):
            done = done.detach().cpu().numpy()
            
        np.copyto(self.states[self.ptr], state)
        np.copyto(self.actions[self.ptr], action)
        self.rewards[self.ptr] = reward
        np.copyto(self.next_states[self.ptr], next_state)
        self.dones[self.ptr] = done
        
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)
    
    def sample(self, batch_size):
        ind = np.random.randint(0, self.size, size=batch_size)
        
        return (
            torch.FloatTensor(self.states[ind]).to(self.device),
            torch.FloatTensor(self.actions[ind]).to(self.device),
            torch.FloatTensor(self.rewards[ind]).to(self.device),
            torch.FloatTensor(self.next_states[ind]).to(self.device),
            torch.FloatTensor(self.dones[ind]).to(self.device)
        )
    
    def save(self, path):
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        
        # Save the buffer state and metadata
        save_dict = {
            'max_size': self.max_size,
            'ptr': self.ptr,
            'size': self.size,
            'states': self.states,
            'actions': self.actions,
            'rewards': self.rewards,
            'next_states': self.next_states,
            'dones': self.dones,
            'device': self.device
        }
        
        try:
            with open(path, 'wb') as f:
                pickle.dump(save_dict, f)
        except Exception as e:
            raise RuntimeError(f"Failed to save buffer to {path}: {e}")

    @staticmethod
    def load(path, device=None):
        try:
            with open(path, 'rb') as f:
                save_dict = pickle.load(f)
            
            # Create new buffer with saved dimensions
            obs_dim = save_dict['states'].shape[1]
            act_dim = save_dict['actions'].shape[1]
            buffer = ReplayBuffer(
                max_size=save_dict['max_size'],
                obs_dim=obs_dim,
                act_dim=act_dim,
                device=device or save_dict['device']
            )
            
            # Restore buffer state
            buffer.ptr = save_dict['ptr']
            buffer.size = save_dict['size']
            buffer.states = save_dict['states']
            buffer.actions = save_dict['actions']
            buffer.rewards = save_dict['rewards']
            buffer.next_states = save_dict['next_states']
            buffer.dones = save_dict['dones']
            
            return buffer
        except Exception as e:
            raise RuntimeError(f"Failed to load buffer from {path}: {e}")
        
    def __len__(self):
        return self.size
"""
class ReplayBuffer:
    def __init__(self, max_size, obs_dim, act_dim, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0
        self.device = device

        # Initialize buffers with correct shapes
        self.states = np.zeros((max_size, obs_dim), dtype=np.float32)
        self.actions = np.zeros((max_size, act_dim), dtype=np.float32)
        self.rewards = np.zeros((max_size, 1), dtype=np.float32)  # Changed shape to (max_size, 1)
        self.next_states = np.zeros((max_size, obs_dim), dtype=np.float32)
        self.dones = np.zeros((max_size, 1), dtype=np.float32)    # Changed shape to (max_size, 1)

    def add(self, state, action, reward, next_state, done):
        # Convert inputs to numpy arrays and ensure correct shapes
        state = np.array(state, dtype=np.float32).flatten()
        action = np.array(action, dtype=np.float32).flatten()
        reward = np.array(reward, dtype=np.float32).reshape(1)
        next_state = np.array(next_state, dtype=np.float32).flatten()
        done = np.array(done, dtype=np.float32).reshape(1)

        # Store transition
        self.states[self.ptr] = state
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.next_states[self.ptr] = next_state
        self.dones[self.ptr] = done

        # Update pointer and size
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        ind = np.random.randint(0, self.size, size=batch_size)
        
        return (
            torch.FloatTensor(self.states[ind]).to(self.device),
            torch.FloatTensor(self.actions[ind]).to(self.device),
            torch.FloatTensor(self.rewards[ind]).to(self.device),
            torch.FloatTensor(self.next_states[ind]).to(self.device),
            torch.FloatTensor(self.dones[ind]).to(self.device)
        )
    def save(self, path):
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        
        # Save the buffer state and metadata
        save_dict = {
            'max_size': self.max_size,
            'ptr': self.ptr,
            'size': self.size,
            'states': self.states,
            'actions': self.actions,
            'rewards': self.rewards,
            'next_states': self.next_states,
            'dones': self.dones,
            'device': self.device
        }
        
        try:
            with open(path, 'wb') as f:
                pickle.dump(save_dict, f)
        except Exception as e:
            raise RuntimeError(f"Failed to save buffer to {path}: {e}")

    @staticmethod
    def load(path, device=None):
        try:
            with open(path, 'rb') as f:
                save_dict = pickle.load(f)
            
            # Create new buffer with saved dimensions
            obs_dim = save_dict['states'].shape[1]
            act_dim = save_dict['actions'].shape[1]
            buffer = ReplayBuffer(
                max_size=save_dict['max_size'],
                obs_dim=obs_dim,
                act_dim=act_dim,
                device=device or save_dict['device']
            )
            
            # Restore buffer state
            buffer.ptr = save_dict['ptr']
            buffer.size = save_dict['size']
            buffer.states = save_dict['states']
            buffer.actions = save_dict['actions']
            buffer.rewards = save_dict['rewards']
            buffer.next_states = save_dict['next_states']
            buffer.dones = save_dict['dones']
            
            return buffer
        except Exception as e:
            raise RuntimeError(f"Failed to load buffer from {path}: {e}")
        
    def __len__(self):
        return self.size
"""
class TD3Actor(nn.Module):
    def __init__(self, obs_dim, act_dim, max_action, hidden_dim=512):
        super().__init__()
        self.feature_net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )
        
        self.action_net = nn.Sequential(
            nn.Linear(hidden_dim, act_dim),
            nn.Tanh()
        )
        
        self.max_action = max_action
        self.__init_args__ = (obs_dim, act_dim, max_action, hidden_dim)
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                nn.init.constant_(m.bias, 0)

    def forward(self, state):
        features = self.feature_net(state)
        return self.max_action * self.action_net(features)

class TD3Critic(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_dim=512):
        super().__init__()
        self.state_net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
        )
        
        self.action_net = nn.Sequential(
            nn.Linear(act_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
        )
        
        self.q_net = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        self.__init_args__ = (obs_dim, act_dim, hidden_dim)
        self._init_weights()
        
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=1.0)
                nn.init.constant_(m.bias, 0)

    def forward(self, state, action):
        state_features = self.state_net(state)
        action_features = self.action_net(action)
        features = torch.cat([state_features, action_features], dim=-1)
        return self.q_net(features)
    """
class TD3Actor(nn.Module):
    def __init__(self, obs_dim, act_dim, max_action, hidden_dim=256):
        super().__init__()
        
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, act_dim),
            nn.Tanh()
        )
        
        self.max_action = max_action
        self.__init_args__ = (obs_dim, act_dim, max_action, hidden_dim)
        
        # Use orthogonal initialization
        for layer in self.net:
            if isinstance(layer, nn.Linear):
                nn.init.orthogonal_(layer.weight, gain=np.sqrt(2))
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, state):
        return self.max_action * self.net(state)

class TD3Critic(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_dim=256):
        super().__init__()
        
        # Q1 architecture
        self.q1_net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        # Q2 architecture
        self.q2_net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        self.__init_args__ = (obs_dim, act_dim, hidden_dim)
        self._init_weights()

    def _init_weights(self):
        for net in [self.q1_net, self.q2_net]:
            for m in net.modules():
                if isinstance(m, nn.Linear):
                    nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                    nn.init.constant_(m.bias, 0)

    def forward(self, state, action):
        # Ensure inputs are properly shaped
        if state.dim() == 1:
            state = state.unsqueeze(0)
        if action.dim() == 1:
            action = action.unsqueeze(0)
            
        # Concatenate state and action
        sa = torch.cat([state, action], dim=1)
        
        # Get Q-values
        q1 = self.q1_net(sa)
        q2 = self.q2_net(sa)
        
        return q1, q2

    def Q1(self, state, action):
        # Ensure inputs are properly shaped
        if state.dim() == 1:
            state = state.unsqueeze(0)
        if action.dim() == 1:
            action = action.unsqueeze(0)
            
        # Concatenate state and action
        sa = torch.cat([state, action], dim=1)
        
        return self.q1_net(sa)
"""
class TD3Trainer:
    def __init__(self, actor, critic1, critic2, actor_optimizer, critic_optimizer1, 
                 critic_optimizer2, max_action, device, gamma=0.99, tau=0.001):
        
        self.actor = actor
        self.critic1 = critic1
        self.critic2 = critic2
        self.actor_target = type(actor)(*actor.__init_args__).to(device)
        self.critic_target1 = type(critic1)(*critic1.__init_args__).to(device)
        self.critic_target2 = type(critic2)(*critic2.__init_args__).to(device)
        
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer1 = critic_optimizer1
        self.critic_optimizer2 = critic_optimizer2
        
        self.max_action = max_action
        self.device = device
        self.gamma = gamma
        self.tau = tau
        
        self.target_entropy = -float(actor.__init_args__[1])
        
        # Initialize episode tracking
        self.current_episode_rewards = []
        self.episode_returns = []
        
    def _hard_update_targets(self):
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target1.load_state_dict(self.critic1.state_dict())
        self.critic_target2.load_state_dict(self.critic2.state_dict())
    
    def _soft_update(self, target, source):
        with torch.no_grad():
            for target_param, param in zip(target.parameters(), source.parameters()):
                target_param.data.copy_(
                    self.tau * param.data + (1.0 - self.tau) * target_param.data
                )

    def train_step(self, replay_buffer, batch_size, noise_std, noise_clip, policy_delay, total_it, noise):
        state, action, reward, next_state, done = replay_buffer.sample(batch_size)
        reward = reward.view(-1)
        done = done.view(-1)

        with torch.no_grad():
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)
            
            target_Q1 = self.critic_target1(next_state, next_action)
            target_Q2 = self.critic_target2(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            
            if len(reward.shape) == 1:
                reward = reward.unsqueeze(-1)
            if len(done.shape) == 1:
                done = done.unsqueeze(-1)
            target_Q = reward + (1 - done) * self.gamma * target_Q
        
        # Critic 1 update
        current_Q1 = self.critic1(state, action)
        critic_mse_loss1 = F.mse_loss(current_Q1, target_Q)
        critic_l2_reg1 = 0.00001 * sum(torch.sum(param ** 2) for param in self.critic1.parameters())
        critic_loss1 = critic_mse_loss1 + critic_l2_reg1
        
        self.critic_optimizer1.zero_grad()
        critic_loss1.backward(retain_graph=True) 
        torch.nn.utils.clip_grad_norm_(self.critic1.parameters(), 1.0)
        self.critic_optimizer1.step()
        
        # Critic 2 update
        current_Q2 = self.critic2(state, action)
        critic_mse_loss2 = F.mse_loss(current_Q2, target_Q)
        critic_l2_reg2 = 0 #0.00001 * sum(torch.sum(param ** 2) for param in self.critic2.parameters())
        critic_loss2 = critic_mse_loss2 + critic_l2_reg2
        
        self.critic_optimizer2.zero_grad()
        critic_loss2.backward()  
        torch.nn.utils.clip_grad_norm_(self.critic2.parameters(), 1.0)
        self.critic_optimizer2.step()

        actor_loss = None
        if total_it % policy_delay == 0:
            actor_actions = self.actor(state)
            actor_loss = -self.critic1(state, actor_actions).mean()
            
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
            self.actor_optimizer.step()
            
            self._soft_update(self.actor_target, self.actor)
            self._soft_update(self.critic_target1, self.critic1)
            self._soft_update(self.critic_target2, self.critic2)
        
        return {
            'critic_loss1': critic_loss1.item(),
            'critic_loss2': critic_loss2.item(),
            'actor_loss': actor_loss.item() if actor_loss is not None else None
        }
    """
class TD3Trainer:
    def __init__(self, actor, critic1, critic2, actor_optimizer, critic_optimizer1, 
                 critic_optimizer2, max_action, device, gamma=0.99, tau=0.005):
        
        self.actor = actor
        self.critic1 = critic1
        self.critic2 = critic2
        self.actor_target = type(actor)(*actor.__init_args__).to(device)
        self.critic_target1 = type(critic1)(*critic1.__init_args__).to(device)
        self.critic_target2 = type(critic2)(*critic2.__init_args__).to(device)
        
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer1 = critic_optimizer1
        self.critic_optimizer2 = critic_optimizer2
        
        self.max_action = max_action
        self.device = device
        self.gamma = gamma
        self.tau = tau
        
        self._hard_update_targets()

    def _hard_update_targets(self):
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target1.load_state_dict(self.critic1.state_dict())
        self.critic_target2.load_state_dict(self.critic2.state_dict())
    
    def _soft_update(self, target, source):
        with torch.no_grad():
            for target_param, param in zip(target.parameters(), source.parameters()):
                target_param.data.copy_(
                    self.tau * param.data + (1.0 - self.tau) * target_param.data
                )

    def train_step(self, replay_buffer, batch_size, noise_std, noise_clip, policy_delay, total_it):
        # Sample from replay buffer
        state, action, reward, next_state, done = replay_buffer.sample(batch_size)
        
        # Ensure proper dimensions
        reward = reward.view(-1, 1)
        done = done.view(-1, 1)

        with torch.no_grad():
            # Select action according to policy and add clipped noise
            noise = (torch.randn_like(action) * noise_std).clamp(-noise_clip, noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)
            
            # Compute the target Q value
            target_Q1, target_Q2 = self.critic_target1(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (1 - done) * self.gamma * target_Q

        # Get current Q estimates
        current_Q1, _ = self.critic1(state, action)
        current_Q2, _ = self.critic2(state, action)

        # Compute critic loss
        critic_loss1 = F.mse_loss(current_Q1, target_Q)
        critic_loss2 = F.mse_loss(current_Q2, target_Q)

        # Optimize the critics
        self.critic_optimizer1.zero_grad()
        critic_loss1.backward()
        self.critic_optimizer1.step()

        self.critic_optimizer2.zero_grad()
        critic_loss2.backward()
        self.critic_optimizer2.step()

        actor_loss = None

        # Delayed policy updates
        if total_it % policy_delay == 0:
            # Compute actor loss
            actor_action = self.actor(state)
            Q1, _ = self.critic1(state, actor_action)
            actor_loss = -Q1.mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            self._soft_update(self.actor_target, self.actor)
            self._soft_update(self.critic_target1, self.critic1)
            self._soft_update(self.critic_target2, self.critic2)

        return {
            'critic_loss1': critic_loss1.item(),
            'critic_loss2': critic_loss2.item(),
            'actor_loss': actor_loss.item() if actor_loss is not None else None
        }
    """

In [22]:
def compute_gradient_norm(model):
    total_norm = 0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5

def compute_parameter_norm(model):
    total_norm = 0
    for p in model.parameters():
        param_norm = p.data.norm(2)
        total_norm += param_norm.item() ** 2
    return total_norm ** 0.5

def evaluate_policy(actor, env, num_episodes=10, device='cuda'):
    """Evaluate the policy for multiple drones without exploration noise"""
    eval_rewards = []
    eval_success = 0
    
    for _ in range(num_episodes):
        states, _ = env.reset()
        episode_reward = 0
        dones = {drone_name: False for drone_name in env.drone_names}
        
        while not all(dones.values()):
            actions = {}
            # Get actions for all active drones
            for drone_name, state in states.items():
                if not dones[drone_name]:
                    with torch.no_grad():
                        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                        actions[drone_name] = actor(state_tensor).cpu().numpy().flatten()
            
            # Step environment with all drone actions
            next_states, rewards, terminateds, truncateds, infos = env.step(actions)
            
            # Update states and track rewards
            episode_reward += np.mean(list(rewards.values()))  # Average reward across drones
            for drone_name in env.drone_names:
                if not dones[drone_name]:
                    dones[drone_name] = terminateds[drone_name] or truncateds[drone_name]
                    if infos[drone_name].get('success', False):
                        eval_success += 1
                        break  # Count only one success per episode
            
            states = next_states
        
        eval_rewards.append(episode_reward)
    
    return {
        'mean_reward': np.mean(eval_rewards),
        'std_reward': np.std(eval_rewards),
        'success_rate': eval_success / num_episodes,
        'rewards': eval_rewards
    }

def run_training(env, obs_dim, act_dim, max_action, episodes=10000):
    # Initialize environment and models
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Initialize TensorBoard writer with more descriptive name
    current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    log_dir = os.path.join('logs', f'TD3_train_{current_time}')
    train_writer = SummaryWriter(log_dir + '/train')
    eval_writer = SummaryWriter(log_dir + '/eval')
    print(f"TensorBoard logs will be saved to: {log_dir}")
    
    # Model initialization
    actor = TD3Actor(obs_dim, act_dim, max_action).to(device)
    critic1 = TD3Critic(obs_dim, act_dim).to(device)
    critic2 = TD3Critic(obs_dim, act_dim).to(device)
    
    actor_optimizer = optim.Adam(actor.parameters(), lr=3e-4)
    critic_optimizer1 = optim.Adam(critic1.parameters(), lr=3e-4)
    critic_optimizer2 = optim.Adam(critic2.parameters(), lr=3e-4)

    # Load checkpoint if exists
    total_steps = 0
    current_stage = 0
    best_rewards = {i: float('-inf') for i in range(9)}
    
    try:
        checkpoint = torch.load('models/TD3/Model/mtp_model_main.pth')
        actor.load_state_dict(checkpoint['actor_state_dict'])
        critic1.load_state_dict(checkpoint['critic1_state_dict'])
        critic2.load_state_dict(checkpoint['critic2_state_dict'])
        actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
        critic_optimizer1.load_state_dict(checkpoint['critic1_optimizer_state_dict'])
        critic_optimizer2.load_state_dict(checkpoint['critic2_optimizer_state_dict'])
        
        try:
            total_steps = checkpoint['total_steps']
            current_stage = checkpoint['stage']
            best_rewards = checkpoint['best_rewards']
        except KeyError:
            pass
        print(f"Loaded checkpoint successfully. Total steps: {total_steps}")
    except:
        print("No checkpoint found, starting fresh")

    # Initialize or load replay buffer
    try:
        replay_buffer = ReplayBuffer.load('models/TD3/Replay_Buffer/mtp_replay_buffer.pkl')
        print("Loaded buffer successfully")
    except:
        replay_buffer = ReplayBuffer(max_size=1_000_000, obs_dim=obs_dim, act_dim=act_dim)
        print("No buffer found, starting fresh")
    
    trainer = TD3Trainer(
        actor=actor,
        critic1=critic1,
        critic2=critic2,
        actor_optimizer=actor_optimizer,
        critic_optimizer1=critic_optimizer1,
        critic_optimizer2=critic_optimizer2,
        max_action=max_action,
        device=device
    )
    trainer._hard_update_targets()

    # Training hyperparameters
    batch_size = 256
    warmup_steps = 10000
    noise_std = 0.1
    noise_clip = 0.5
    policy_delay = 2
    eval_freq = 1000  # Evaluate every 1000 episodes
    
    # Create directories
    os.makedirs('models/TD3/Replay_Buffer', exist_ok=True)
    os.makedirs('models/TD3/Model', exist_ok=True)
    os.makedirs('logs', exist_ok=True)
    
    def get_exploration_noise(action):
        """Return exploration noise as numpy array with correct shape"""
        noise = np.random.normal(0, noise_std, size=action.shape)
        return np.clip(noise, -noise_clip, noise_clip)
    
    def log_training_step(train_info, step):
        # Log losses
        train_writer.add_scalar('Loss/Critic1', train_info['critic_loss1'], step)
        train_writer.add_scalar('Loss/Critic2', train_info['critic_loss2'], step)
        if train_info['actor_loss'] is not None:
            train_writer.add_scalar('Loss/Actor', train_info['actor_loss'], step)
        
        # Log gradient norms
        train_writer.add_scalar('Gradients/Actor_Norm', compute_gradient_norm(actor), step)
        train_writer.add_scalar('Gradients/Critic1_Norm', compute_gradient_norm(critic1), step)
        train_writer.add_scalar('Gradients/Critic2_Norm', compute_gradient_norm(critic2), step)
        
        # Log parameter norms
        train_writer.add_scalar('Parameters/Actor_Norm', compute_parameter_norm(actor), step)
        train_writer.add_scalar('Parameters/Critic1_Norm', compute_parameter_norm(critic1), step)
        train_writer.add_scalar('Parameters/Critic2_Norm', compute_parameter_norm(critic2), step)
        
        # Log learning rates
        train_writer.add_scalar('LearningRate/Actor', actor_optimizer.param_groups[0]['lr'], step)
        train_writer.add_scalar('LearningRate/Critic1', critic_optimizer1.param_groups[0]['lr'], step)
        train_writer.add_scalar('LearningRate/Critic2', critic_optimizer2.param_groups[0]['lr'], step)
    
    def run_evaluation(episode):
        # Create a separate environment for evaluation
        eval_env = env.__class__()  # Assuming env has a constructor that takes no arguments
        eval_env.training_stage = env.training_stage  # Sync the training stage
        
        # Run evaluation
        eval_results = evaluate_policy(actor, eval_env, num_episodes=10, device=device)
        
        # Log evaluation metrics
        eval_writer.add_scalar('Eval/Mean_Reward', eval_results['mean_reward'], episode)
        eval_writer.add_scalar('Eval/Reward_Std', eval_results['std_reward'], episode)
        eval_writer.add_scalar('Eval/Success_Rate', eval_results['success_rate'], episode)
        
        # Log reward distribution
        eval_writer.add_histogram('Eval/Reward_Distribution', 
                                torch.tensor(eval_results['rewards']), 
                                episode)
        
        eval_env.close()
        return eval_results
    
    
    # Modified reward tracking initialization
    recent_rewards = deque(maxlen=100)  # Track recent episode rewards
    stage_rewards = {i: [] for i in range(9)}  # Track rewards per stage
    best_eval_reward = float('-inf')
    episode_rewards = []
    # Initialize states and tracking

    progress_bar = tqdm.tqdm(range(episodes), desc="Training")
    for episode in progress_bar:
        current_episode_reward = 0  # Track current episode's reward sum
        states, _ = env.reset()
        dones = {drone_name: False for drone_name in env.drone_names}
        episode_rewards_per_drone = {drone_name: 0 for drone_name in env.drone_names}
        episodes_completed = 0
        total_steps = 0
        
        while not dones:
            actions = {}
            # Log current stage
            train_writer.add_scalar('Training/Current_Stage', current_stage, total_steps)
        
            # Select actions for each active drone
            for drone_name in env.drone_names:
                if not dones[drone_name]:
                    state = states[drone_name]
                    if total_steps < warmup_steps:
                        actions[drone_name] = np.random.uniform(-max_action, max_action, size=act_dim)
                        train_writer.add_scalar('Training/Exploration_Type', 0, total_steps)
                    else:
                        with torch.no_grad():
                            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                            action = actor(state_tensor).cpu().numpy().flatten()
                            noise = get_exploration_noise(torch.from_numpy(action))
                            noisy_action = action + noise
                            actions[drone_name] = np.clip(noisy_action, -max_action, max_action)
                        
                        train_writer.add_scalar('Training/Exploration_Type', 1, total_steps)
                        train_writer.add_scalar('Training/Exploration_Noise', noise, total_steps)
            
            # Log action statistics
            all_actions = np.array(list(actions.values()))  # Shape: (num_drones, action_dim)
            train_writer.add_histogram('Actions/Distribution', all_actions, total_steps)
            train_writer.add_scalar('Actions/Mean', np.mean(all_actions), total_steps)
            train_writer.add_scalar('Actions/Std', np.std(all_actions), total_steps)
            
            # Step environment
            next_states, rewards, terminateds, truncateds, info = env.step(actions)
            
            # Store transition
            for drone_name in env.drone_names:
                if not dones[drone_name]:
                    replay_buffer.add(
                        states[drone_name],
                        actions[drone_name],
                        rewards[drone_name],
                        next_states[drone_name],
                        float(terminateds[drone_name] or truncateds[drone_name])
                    )
                    current_episode_reward += rewards[drone_name]
                    # If this drone is done, reset just this drone and count it as an episode
                    if terminateds[drone_name] or truncateds[drone_name]:
                        # Log the completed episode for this drone
                        episode_rewards.append(current_episode_reward)
                        recent_rewards.append(current_episode_reward)
                        stage_rewards[current_stage].append(current_episode_reward)
                        
                        # Reset this drone's episode reward
                        episode_rewards_per_drone[drone_name] = 0
                        episodes_completed += 1
                        dones[drone_name] = True
            
            # Update state and metrics
            states.update(next_states)
            for drone_name in env.drone_names:
                if dones[drone_name]:
                    states[drone_name] = env._get_observation(drone_name)
                    dones[drone_name] = False

            total_steps += 1
            
            # Train agent
            if total_steps > warmup_steps and len(replay_buffer.states) > batch_size:
                train_info = trainer.train_step(
                    replay_buffer=replay_buffer,
                    batch_size=batch_size,
                    noise_std=noise_std,
                    noise_clip=noise_clip,
                    policy_delay=policy_delay,
                    total_it=total_steps,
                    #action=list(actions.values())[0],
                    #reward = list(rewards.values())[0],
                    #noise = current_noise
                )
                
                # Log training metrics
                log_training_step(train_info, total_steps)
            # Run evaluation periodically
        if episodes_completed % eval_freq == 0:
            eval_results = run_evaluation(episodes_completed)
            
            # Save best model based on evaluation
            if eval_results['mean_reward'] > best_eval_reward:
                best_eval_reward = eval_results['mean_reward']
                torch.save({
                    'actor_state_dict': actor.state_dict(),
                    'critic1_state_dict': critic1.state_dict(),
                    'critic2_state_dict': critic2.state_dict(),
                    'actor_optimizer_state_dict': actor_optimizer.state_dict(),
                    'critic1_optimizer_state_dict': critic_optimizer1.state_dict(),
                    'critic2_optimizer_state_dict': critic_optimizer2.state_dict(),
                    'total_steps': total_steps,
                    'episode': episodes_completed,
                    'stage': current_stage,
                    'eval_reward': best_eval_reward,
                    'best_rewards': best_rewards
                }, 'models/TD3/Model/mtp_model_best_eval.pth')
                
        train_writer.add_scalar('Training/Current_Stage', env.training_stage, total_steps)
        train_writer.add_scalar('Training/Episodes_Completed', episodes_completed, total_steps)
        train_writer.add_scalar('Training/Buffer_Size', len(replay_buffer), total_steps)
        
        # Stage-specific metrics
        if len(stage_rewards[current_stage]) >= 100:
            # Convert to numpy array and ensure float type
            stage_reward_array = np.array(stage_rewards[current_stage][-100:], dtype=np.float32)
            current_avg_reward = float(np.mean(stage_reward_array))
            
            train_writer.add_scalar(f'Stage_{current_stage}/Average_100_Episodes', 
                                  current_avg_reward, 
                                  len(stage_rewards[current_stage]))
            
            if current_avg_reward > best_rewards[current_stage]:
                best_rewards[current_stage] = current_avg_reward
                train_writer.add_scalar(f'Stage_{current_stage}/Best_Average_Reward', 
                                      current_avg_reward, 
                                      len(stage_rewards[current_stage]))
        
        # Update progress bar
        progress_bar.set_postfix({
            'stage': env.training_stage,
            'episodes': episodes_completed,
            'buffer': len(replay_buffer)
        })
        
        # Check for training completion
        if current_stage == 8 and env.success_count >= env.stage_threshold:
            print("\n=== Training completed successfully! ===")
            env.close()
            print("Environment closed. Running final evaluation...")
            final_eval_results = run_evaluation(episodes_completed)
            print(f"Final evaluation results: {final_eval_results}")
            train_writer.close()
            eval_writer.close()
            break
        
        # Periodic checkpointing
        if episodes_completed % 2500 == 0:
            checkpoint_path = f'models/TD3/Model/mtp_backup_{episodes_completed}.pth'
            torch.save({
                'actor_state_dict': actor.state_dict(),
                'critic1_state_dict': critic1.state_dict(),
                'critic2_state_dict': critic2.state_dict(),
                'actor_optimizer_state_dict': actor_optimizer.state_dict(),
                'critic1_optimizer_state_dict': critic_optimizer1.state_dict(),
                'critic2_optimizer_state_dict': critic_optimizer2.state_dict(),
                'episode': episodes_completed,
                'total_steps': total_steps,
                }, checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}")
    return actor, critic1, critic2, current_episode_reward

In [25]:
env = MultiDroneCosysAirSimEnv()

# Get environment dimensions
obs_dim = env.observation_space.shape[0]  
act_dim = env.action_space.shape[0]     
max_action = float(env.action_space.high[0]) 

# Run training
actor, critic1, critic2, rewards = run_training(
    env=env,
    obs_dim=obs_dim,
    act_dim=act_dim,
    max_action=max_action,
    episodes= 10_000
)

  checkpoint = torch.load('models/TD3/Model/mtp_model_main.pth')


Connected!
Client Ver:1 (Min Req: 1), Server Ver:1 (Min Req: 1)

Using device: cuda
TensorBoard logs will be saved to: logs\TD3_train_20250120-180707
No checkpoint found, starting fresh
No buffer found, starting fresh


Training:   0%|          | 0/10000 [00:00<?, ?it/s]

Connected!
Client Ver:1 (Min Req: 1), Server Ver:1 (Min Req: 1)



Training:   0%|          | 1/10000 [00:04<11:47:45,  4.25s/it, stage=0, episodes=0, buffer=0]

Checkpoint saved at models/TD3/Model/mtp_backup_0.pth
Connected!
Client Ver:1 (Min Req: 1), Server Ver:1 (Min Req: 1)



Training:   0%|          | 2/10000 [00:07<10:19:16,  3.72s/it, stage=0, episodes=0, buffer=0]


Checkpoint saved at models/TD3/Model/mtp_backup_0.pth
Connected!
Client Ver:1 (Min Req: 1), Server Ver:1 (Min Req: 1)



KeyboardInterrupt: 