# In [None]:
import os
import math
import random
import pygame
import numpy as np
import gymnasium as gym
from pygame.math import Vector2
import torch  # Often needed for policy_kwargs like activation_fn

# Stable Baselines
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import CheckpointCallback, BaseCallback

# ---------------- Constants and Utilities ----------------
DEFAULT_SCREEN_WIDTH = 800    # simulation area width (the sidebar is drawn to the right of it)
DEFAULT_SCREEN_HEIGHT = 600   # simulation area height
DEFAULT_MAX_STEPS = 2000      # step count at which an episode is truncated

# Sensor settings. NUM_SENSORS is part of the observation size: each drone
# contributes 3 + NUM_SENSORS values (see obs_dim in DroneSwarmEnv.__init__),
# so changing it changes the observation space.
NUM_SENSORS = 8
SENSOR_STEP_SIZE = 5.0  # stepping distance for the pseudo-ray sensor march

# Colors (R, G, B)
WHITE       = (255, 255, 255)
BLACK       = (0, 0, 0)
GRAY        = (128, 128, 128)
BLUE        = (0, 0, 255)
YELLOW      = (255, 255, 0)
DARK_RED    = (100, 0, 0)
DARK_GREEN  = (0, 100, 0)
CRASH_COLOR = (255, 100, 0)
GREENISH    = (0, 200, 0)

SENSOR_OBS_COLOR  = (50, 50, 50)

# Small positive guard used to skip near-zero distances before dividing/normalising.
EPSILON = 1e-6

import torch
import torch.nn as nn
import torch.nn.functional as F
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class CustomLSTMExtractor(BaseFeaturesExtractor):
    """
    Feature extractor: fully connected projection -> ReLU -> single-layer LSTM.

    The observation is flattened, projected to ``lstm_hidden_size`` units,
    treated as a sequence of length 1 and passed through an LSTM. The LSTM's
    final hidden state is returned as the feature vector.

    NOTE: the LSTM runs from its default (zero) initial state on every forward
    pass and the resulting state is not stored, so nothing is carried between
    calls. With a length-1 sequence this acts as a gated feed-forward layer,
    not as recurrent memory across environment steps.

    Parameters:
      observation_space: The gym observation space. Multi-dimensional Box
        spaces are flattened; for a 1-D space the input size is ``shape[0]``.
      lstm_hidden_size: Width of the projection and of the LSTM; also the size
        of the returned feature vector (``features_dim``).
    """

    def __init__(self, observation_space, lstm_hidden_size=128):
        # features_dim tells SB3 the size of the vector this extractor returns.
        super().__init__(observation_space, features_dim=lstm_hidden_size)
        self.lstm_hidden_size = lstm_hidden_size
        # Flattened observation size (identical to shape[0] for 1-D spaces).
        input_dim = int(np.prod(observation_space.shape))

        # Projection applied before the LSTM.
        self.fc = nn.Linear(input_dim, lstm_hidden_size)
        # Single-layer LSTM, inputs laid out as (batch, seq, features).
        self.lstm = nn.LSTM(lstm_hidden_size, lstm_hidden_size, batch_first=True)

    def forward(self, observations):
        # (batch, *obs_shape) -> (batch, input_dim); a no-op for flat observations.
        x = torch.flatten(observations, start_dim=1)
        x = F.relu(self.fc(x))
        # Add a length-1 time axis: (batch, 1, lstm_hidden_size).
        x = x.unsqueeze(1)
        # Only the final hidden state is needed; output sequence and cell state are discarded.
        _, (h_n, _) = self.lstm(x)
        # h_n: (num_layers, batch, lstm_hidden_size); take the last layer -> (batch, lstm_hidden_size).
        return h_n[-1]



def limit_vector(vector: "Vector2", max_val: float) -> "Vector2":
    """Clamp the magnitude of ``vector`` to at most ``max_val``.

    Args:
        vector: 2-D vector to clamp; needs ``magnitude_squared()`` and scalar
            multiplication (e.g. ``pygame.math.Vector2``).
        max_val: Maximum allowed magnitude; expected to be non-negative.

    Returns:
        ``vector`` itself when its magnitude is already <= ``max_val``,
        otherwise a new vector with the same direction and magnitude ``max_val``.
    """
    mag_sq = vector.magnitude_squared()
    if mag_sq > max_val * max_val:
        # mag_sq is strictly greater than a non-negative number, so the root is
        # > 0 and the division is safe. Flooring the magnitude at EPSILON here
        # would wrongly rescale vectors shorter than sqrt(EPSILON).
        vector = vector * (max_val / math.sqrt(mag_sq))
    return vector

# ---------------- Entity Classes ----------------
class Obstacle:
    """Circular obstacle drifting at a constant speed and bouncing off the arena edges.

    The initial heading is drawn from the module-level ``random`` generator,
    so it is not controlled by any environment seed.
    """

    def __init__(self, x, y, radius, screen_width, screen_height, max_speed=0.8):
        self.position = Vector2(x, y)
        self.radius = radius  # collision radius, fixed for the obstacle's lifetime
        self.screen_width = screen_width
        self.screen_height = screen_height
        heading = random.uniform(0, 2 * math.pi)
        self.velocity = Vector2(max_speed * math.cos(heading), max_speed * math.sin(heading))
        # Rendered slightly smaller than the collision radius, but at least 1 px.
        self.draw_radius = max(1, self.radius - 2)

    def update(self):
        """Move one step; on each axis that left the allowed band, reverse and clamp."""
        self.position += self.velocity
        lo = self.radius
        for axis, extent in ((0, self.screen_width), (1, self.screen_height)):
            hi = extent - lo
            if not lo <= self.position[axis] <= hi:
                self.velocity[axis] *= -1
                self.position[axis] = np.clip(self.position[axis], lo, hi)

    def draw(self, surface):
        """Draw the obstacle as a filled gray circle."""
        center = (int(self.position.x), int(self.position.y))
        pygame.draw.circle(surface, GRAY, center, int(self.draw_radius))

class FastMovingTarget:
    """
    Circular target that is repelled by obstacles and, optionally, by drones.

    Each update adds a potential-field force to the velocity, caps the speed at
    ``max_speed``, moves, and then bounces off an inner band that lies
    ``target_wall_margin`` (from ``pf_params``; default 50, never less than the
    radius) inside each screen edge.

    * ``static_target=True``: the target never moves.
    * ``avoidance_enabled=True``: repelled by drones and obstacles.
    * ``avoidance_enabled=False``: repelled by obstacles only.

    ``pf_params`` supplies the force constants; missing keys fall back to the
    defaults used in the methods below.
    """

    def __init__(self, x, y, radius, screen_width, screen_height,
                 max_speed=2.5, pf_params=None, avoidance_enabled=True, static_target=False):
        self.position = Vector2(x, y)
        # Separate copy of the spawn point (the author notes it feeds the
        # observation; that code is not visible here).
        self.initial_position = Vector2(x, y)
        self.radius = radius
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.velocity = Vector2(max_speed, 0)
        # NOTE(review): the fallback is itself shorter than 0.1 whenever
        # max_speed < 0.1, so it only halves an already tiny speed.
        if self.velocity.length() < 0.1:
            self.velocity = Vector2(max_speed / 2, 0)
        self.max_speed = max_speed
        self.pf = pf_params or {}
        self.avoidance_enabled = avoidance_enabled
        self.static_target = static_target
        self.is_active = True  # set to False once the target has been hit

    def _drone_repulsion_force(self, drones):
        """Return the summed, per-drone-capped repulsion pushing the target away from nearby drones."""
        net_force = Vector2(0, 0)
        rep_strength = self.pf.get("drone_repulsion_strength", 100)
        drone_influence_radius = self.pf.get("drone_influence_radius", 50)
        max_force = self.pf.get("max_drone_repulsion_force", 15.0)
        influence_radius_sq = drone_influence_radius ** 2
        for d in drones:
            vec_from_drone = self.position - d.position
            dist_sq = vec_from_drone.magnitude_squared()
            # Lower bound skips coincident centres, where normalize() would fail.
            if EPSILON < dist_sq < influence_radius_sq:
                dist = math.sqrt(dist_sq)
                rep_magnitude = rep_strength * (1.0 / dist - 1.0 / drone_influence_radius) / dist
                force = vec_from_drone.normalize() * rep_magnitude
                net_force += limit_vector(force, max_force)
        return net_force

    def _add_obstacle_repulsion(self, net_force, obstacles):
        """Accumulate per-obstacle-capped obstacle repulsion into ``net_force`` and return it.

        Each obstacle closer than ``obs_influence_radius_for_target`` (R)
        contributes ``strength * (1/d - 1/R) / d`` along the obstacle->target
        direction, capped at ``max_obs_repulsion_force_for_target``.
        """
        rep_strength_obs = self.pf.get("obs_repulsion_strength_for_target", 3000)
        obs_influence_radius = self.pf.get("obs_influence_radius_for_target", 120)
        max_force_obs = self.pf.get("max_obs_repulsion_force_for_target", 25.0)
        influence_radius_sq_obs = obs_influence_radius ** 2
        for obs in obstacles:
            vec_from_obs = self.position - obs.position
            dist_sq = vec_from_obs.magnitude_squared()
            if EPSILON < dist_sq < influence_radius_sq_obs:
                dist = math.sqrt(dist_sq)
                rep_magnitude = rep_strength_obs * (1.0 / dist - 1.0 / obs_influence_radius) / dist
                force = vec_from_obs.normalize() * rep_magnitude
                net_force += limit_vector(force, max_force_obs)
        return net_force

    def _compute_potential_field_for_target(self, drones, obstacles):
        """Total repulsion: drones (skipped for static targets) followed by obstacles."""
        net_force = Vector2(0, 0) if self.static_target else self._drone_repulsion_force(drones)
        return self._add_obstacle_repulsion(net_force, obstacles)

    def _compute_obstacle_repulsion_force(self, obstacles):
        """Repulsion from obstacles only (used when drone avoidance is disabled)."""
        return self._add_obstacle_repulsion(Vector2(0, 0), obstacles)

    def update(self, drones, obstacles):
        """Advance one step; a no-op for inactive or static targets."""
        if not self.is_active or self.static_target:
            return
        if self.avoidance_enabled:
            force = self._compute_potential_field_for_target(drones, obstacles)
        else:
            force = self._compute_obstacle_repulsion_force(obstacles)
        self.velocity += force
        self.velocity = limit_vector(self.velocity, self.max_speed)
        self.position += self.velocity
        # Fallback of 50 matches DroneSwarmEnv's default "target_wall_margin".
        # A fallback of 500 would exceed half of an 800x600 arena and invert
        # the clip bounds below.
        target_wall_margin = max(self.radius, self.pf.get("target_wall_margin", 50))
        if not (target_wall_margin <= self.position.x <= self.screen_width - target_wall_margin):
            self.velocity.x *= -1
            self.position.x = np.clip(self.position.x, target_wall_margin, self.screen_width - target_wall_margin)
        if not (target_wall_margin <= self.position.y <= self.screen_height - target_wall_margin):
            self.velocity.y *= -1
            self.position.y = np.clip(self.position.y, target_wall_margin, self.screen_height - target_wall_margin)

    def draw(self, surface):
        """Draw active targets as filled green circles."""
        if self.is_active:
            pygame.draw.circle(surface, GREENISH,
                               (int(self.position.x), int(self.position.y)),
                               self.radius)

class Drone:
    """Point-mass drone moved by an external force, with damping and a speed cap."""

    def __init__(self, x, y, drone_size=5, max_speed=3.0, screen_width=800, screen_height=600):
        self.position = Vector2(x, y)
        # Every drone starts moving along +x at full speed.
        self.velocity = Vector2(max_speed, 0)
        self.drone_size = drone_size  # radius used for drawing and collision checks
        self.max_speed = max_speed
        # Stored for reference; not read by this class's methods.
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.is_active = True  # cleared by the environment when the drone crashes

    def update(self, force: Vector2, damping=0.98):
        """Apply ``force``, damp, cap at ``max_speed``, then move. Inactive drones stay frozen."""
        if not self.is_active:
            return
        vel = self.velocity
        vel += force
        vel *= damping
        self.velocity = limit_vector(vel, self.max_speed)
        self.position += self.velocity

    def draw(self, surface):
        """Draw the drone as a filled blue circle; inactive drones are not drawn."""
        if not self.is_active:
            return
        center = (int(self.position.x), int(self.position.y))
        pygame.draw.circle(surface, BLUE, center, self.drone_size)

# ---------- Helper: Draw Dashed Line ----------
def draw_dashed_line(surface, color, start_pos, end_pos, dash_length=5, space_length=3):
    """Draw a 1-px dashed line from ``start_pos`` to ``end_pos`` on ``surface``.

    Dashes of length ``dash_length`` are separated by gaps of ``space_length``.
    The last dash is cut off at ``end_pos``. A zero-length line draws nothing.
    """
    origin = np.array(start_pos)
    target = np.array(end_pos)
    delta = target - origin
    total = np.linalg.norm(delta)
    if total == 0:
        return
    unit = delta / total
    period = dash_length + space_length
    for k in range(int(total // period) + 1):
        dash_start = origin + period * k * unit
        dash_end = dash_start + dash_length * unit
        # Never draw past the requested end point.
        if np.linalg.norm(dash_end - origin) > total:
            dash_end = target
        pygame.draw.line(surface, color, dash_start, dash_end, 1)





# ---------------- DroneSwarmEnv ----------------
class DroneSwarmEnv(gym.Env):
    """
    Environment for controlling a swarm of drones via potential fields.
    """
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 100000}

    def __init__(self,
                 render_mode="human",
                 max_steps=DEFAULT_MAX_STEPS,
                 screen_width=DEFAULT_SCREEN_WIDTH,
                 screen_height=DEFAULT_SCREEN_HEIGHT,
                 sensor_range=150,
                 # Drones
                 num_drones=5,
                 drone_size=5,
                 max_drone_speed=3.0,
                 # Obstacles
                 num_obstacles=5,
                 obstacle_max_speed=0.8,
                 obstacle_radius=20,
                 # Targets
                 num_targets=5,
                 fast_target_max_speed=2.5,
                 fast_target_radius=12,
                 # Potential-field hyperparameters (merged over the defaults below)
                 pf_params=None,
                 # Raw reward magnitudes (merged over the defaults below)
                 reward_config=None,
                 # Per-term raw->scaled ranges (merged over the defaults below)
                 reward_ranges=None,
                 # {name: {"min": ..., "max": ...}} bounds for tunable PF parameters
                 controlled_pf_params=None,
                 # Other parameters:
                 enable_target_avoidance=True,
                 drones_required_to_hit_target=2,
                 max_drones_destroyed=None,
                 intrinsic_reward_weight=0.1,
                 num_static_targets=0,
                 steps_required_to_hit_target=2,
                 hit_distance_threshold=None):
        """Configure the environment; entities are created later by ``reset``.

        Args:
            render_mode: "human" opens the pygame window immediately; other
                modes defer it.
            max_steps: Step count at which an episode is truncated.
            sensor_range: Stored for the sensor code (not used in this method).
            num_drones, drone_size, max_drone_speed: Drone swarm settings.
            num_obstacles, obstacle_max_speed, obstacle_radius: Obstacle settings.
            num_targets, fast_target_max_speed, fast_target_radius: Target settings.
            pf_params: Overrides merged over the default potential-field parameters.
            reward_config: Overrides merged over the default raw reward magnitudes.
            reward_ranges: Overrides merged over the default per-term
                ``{"raw": (lo, hi), "scaled": (lo, hi)}`` mapping used by ``scale_reward``.
            controlled_pf_params: Bounds for tunable PF parameters; replaces the
                default set entirely. Its keys are also stored in ``controlled_pf_keys``.
            enable_target_avoidance: Whether moving targets are also repelled by drones.
            drones_required_to_hit_target: Drones that must be in range of a target at once.
            max_drones_destroyed: Episode terminates once this many drones are
                lost (``None`` disables the limit).
            intrinsic_reward_weight: Multiplier on the exploration bonus.
            num_static_targets: Number of targets (the first ones spawned) that never move.
            steps_required_to_hit_target: Consecutive in-range steps needed to destroy a target.
            hit_distance_threshold: Drone-target distance counted as "in range";
                defaults to ``drone_size + fast_target_radius``.
        """
        super().__init__()
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.max_steps = max_steps
        self.sensor_range = sensor_range
        self.num_drones_init = num_drones
        self.drone_size = drone_size
        self.max_drone_speed = max_drone_speed
        self.num_obstacles_init = num_obstacles
        self.obstacle_max_speed = obstacle_max_speed
        self.obstacle_radius = obstacle_radius
        self.num_targets_init = num_targets
        self.fast_target_max_speed = fast_target_max_speed
        self.fast_target_radius = fast_target_radius
        self.enable_target_avoidance = enable_target_avoidance
        self.drones_required_to_hit_target = drones_required_to_hit_target
        self.max_drones_destroyed = max_drones_destroyed
        self.intrinsic_reward_weight = intrinsic_reward_weight
        self.num_static_targets = num_static_targets
        self.steps_required_to_hit_target = steps_required_to_hit_target
        self.hit_distance_threshold = hit_distance_threshold if hit_distance_threshold is not None else (self.drone_size + self.fast_target_radius)

        default_pf = {
            "attraction_strength": 0.0005, "max_attraction_force": 0.8,
            "obs_repulsion_strength": 7000, "obs_influence_radius": 100, "max_obs_repulsion_force": 20.0,
            "drone_repulsion_strength": 100, "drone_influence_radius": 50, "max_drone_repulsion_force": 15.0,
            "boundary_repulsion_strength": 6000, "boundary_influence": 40, "max_boundary_force": 18.0,
            "obs_repulsion_strength_for_target": 3000, "obs_influence_radius_for_target": 120, "max_obs_repulsion_force_for_target": 25.0,
            "target_wall_margin": 50
        }
        if pf_params is not None:
            default_pf.update(pf_params)
        self.pf = default_pf

        default_reward_config = {
            "step_penalty": 0.0,
            "hit_reward": 100.0,
            "hitting_reward": 10.0,  # NOTE(review): not read by compute_reward, which uses a raw pair count
            "crash_penalty": 50.0,
            "all_drones_lost_penalty": 200.0,
            "all_targets_hit_reward": 200.0,
            "timeout_penalty": 50.0,
            "max_drones_destroyed_penalty": -500.0
        }

        if reward_config is not None:
            default_reward_config.update(reward_config)
        self.reward_config = default_reward_config

        default_reward_ranges = {
            "step_penalty": {"raw": (-10.0, 0.0), "scaled": (-0.1, 0.0)},
            "hit_reward": {"raw": (0.0, 300.0), "scaled": (0.0, 10.0)},
            "crash_penalty": {"raw": (-150.0, 0.0), "scaled": (-10.0, 0.0)},
            "all_drones_lost": {"raw": (-500.0, 0.0), "scaled": (-20.0, 0.0)},
            "all_targets_hit": {"raw": (0.0, 300.0), "scaled": (0.0, 50.0)},
            "timeout": {"raw": (-500.0, 0.0), "scaled": (-10, 0.0)},
            "max_drones_destroyed": {"raw": (-500.0, 0.0), "scaled": (-20.0, 0.0)},
            "intrinsic_reward": {"raw": (0.0, 1.0), "scaled": (0.0, 0.1)},
            "hitting_reward": {"raw": (0.0, 20.0), "scaled": (0.0, 0.1)},
        }
        # Merge like pf_params / reward_config. A partial override must not drop
        # the remaining terms, because scale_reward passes unknown terms through unscaled.
        if reward_ranges is not None:
            default_reward_ranges.update(reward_ranges)
        self.reward_ranges = default_reward_ranges

        if controlled_pf_params is None:
            self.controlled_pf_params = {
                "attraction_strength": {"min": 0.0001, "max": 0.01},
                "obs_repulsion_strength": {"min": 500.0, "max": 50000.0},
            }
        else:
            self.controlled_pf_params = controlled_pf_params
        # NOTE(review): not referenced by the 2-D action space below; confirm where it is consumed.
        self.controlled_pf_keys = list(self.controlled_pf_params.keys())

        # Action: a point in [0, 1]^2. step() scales it by the screen size to
        # place the attraction point shared by all drones.
        action_dim = 2
        low  = np.zeros(action_dim, dtype=np.float32)
        high = np.ones(action_dim, dtype=np.float32)
        self.action_space = gym.spaces.Box(low=low, high=high, shape=(action_dim,), dtype=np.float32)

        # Observation: flat float32 vector bounded to [-1, 1].
        # Each of the num_drones_init drones contributes 3 + NUM_SENSORS values.
        # Per the author's notes these are 2 position values, NUM_SENSORS sensor
        # readings and 1 "hitting" flag. They are followed by 6 further values.
        # The exact layout is built in _get_obs (not visible here); TODO confirm
        # before changing either side.
        obs_dim = (self.num_drones_init * (3 + NUM_SENSORS)) + 6
        self.observation_space = gym.spaces.Box(
            low=-1.0, high=1.0,
            shape=(obs_dim,), dtype=np.float32
        )

        self.current_step = 0
        self.render_mode = render_mode
        self.screen = None
        self.clock = None
        self.info_font = None

        self.drones = []
        self.obstacles = []
        self.fast_targets = []
        self.target_pos = Vector2(self.screen_width * 0.5, self.screen_height * 0.5)

        # Per-episode statistics and last-step diagnostics.
        self.total_crashes_episode = 0
        self.total_hits_episode = 0
        self.current_episode_reward = 0.0
        self.last_raw_action = None
        self.last_extrinsic_reward = 0.0
        self.last_intrinsic_reward = 0.0
        self.last_step_reward = 0.0
        self.last_reward_breakdown = {}
        self.visited_cells = {}
        self.target_hit_steps = {}
        self.episode_num = 0

        # Sidebar UI state; the window is the arena plus a sidebar to its right.
        self.sidebar_width = 350
        self.total_width = self.screen_width + self.sidebar_width
        self.sidebar_scroll_offset = 0
        self.sidebar_hscroll_offset = 0
        self.slider_active = False
        self.hslider_active = False

        if self.render_mode == "human":
            self._initialize_pygame()

    def _initialize_pygame(self):
        """Create the pygame window, clock and info font on first use; later calls do nothing."""
        if self.screen is not None:
            return
        pygame.init()
        pygame.font.init()
        # Window width covers the arena plus the sidebar.
        self.screen = pygame.display.set_mode((self.total_width, self.screen_height))
        pygame.display.set_caption("DroneSwarmEnv_lstm")
        self.clock = pygame.time.Clock()
        try:
            font = pygame.font.SysFont('Arial', 16)
        except Exception:
            # Fall back to pygame's default font if Arial cannot be loaded.
            font = pygame.font.SysFont(None, 16)
        self.info_font = font




    def compute_intrinsic_reward(self):
        """Count-based exploration bonus for the swarm's centroid.

        The mean position of the active drones is bucketed into a 50x50 grid
        cell. That cell's count in ``self.visited_cells`` is incremented, and
        the bonus is ``1 / count``. Returns 0.0 (and records nothing) when no
        drone is active.
        """
        active = [drone for drone in self.drones if drone.is_active]
        if not active:
            return 0.0
        centroid_x = sum(drone.position.x for drone in active) / len(active)
        centroid_y = sum(drone.position.y for drone in active) / len(active)
        cell_size = 50
        key = (int(centroid_x // cell_size), int(centroid_y // cell_size))
        visits = self.visited_cells.get(key, 0) + 1
        self.visited_cells[key] = visits
        return 1.0 / visits

    def scale_reward(self, term_name, raw_value):
        """Linearly map ``raw_value`` from a term's raw range onto its scaled range.

        ``self.reward_ranges[term_name]`` is ``{"raw": (lo, hi), "scaled": (lo, hi)}``.
        The raw value is clipped into the raw range before mapping. Terms
        without an entry are returned unchanged. For a degenerate raw range
        (lo == hi) the result is the scaled minimum when ``raw_value <= lo``,
        otherwise the scaled maximum.
        """
        ranges = self.reward_ranges.get(term_name)
        if ranges is None:
            return raw_value
        raw_lo, raw_hi = ranges["raw"]
        out_lo, out_hi = ranges["scaled"]
        raw_span = raw_hi - raw_lo
        if raw_span == 0:
            return out_lo if raw_value <= raw_lo else out_hi
        fraction = (np.clip(raw_value, raw_lo, raw_hi) - raw_lo) / raw_span
        return out_lo + fraction * (out_hi - out_lo)

    def compute_reward(self, step_hits, step_crashes, all_drones_lost, all_targets_hit, timed_out):
        """Compute the scaled reward terms for one step.

        Raw terms are built from ``self.reward_config`` and mapped through
        ``scale_reward`` using ``self.reward_ranges``. Penalty terms are forced
        to be non-positive (``-abs(value)``). The default penalty ranges only
        cover raw values <= 0, so a positive raw penalty would be clipped to 0
        and silently vanish. Either sign convention in ``reward_config`` is
        therefore accepted.

        Args:
            step_hits: Targets destroyed this step.
            step_crashes: Drones that crashed this step.
            all_drones_lost: True when no drone remains active.
            all_targets_hit: True when no target remains active.
            timed_out: True when the timeout penalty should be applied.

        Returns:
            Dict of the scaled terms plus ``"extrinsic"``, ``"intrinsic"`` and ``"total"``.
        """
        cfg = self.reward_config

        def penalty(key, default):
            # Normalise a configured penalty to a non-positive raw value.
            return -abs(cfg.get(key, default))

        step_penalty_raw    = penalty("step_penalty", 0.0)
        hit_reward_raw      = step_hits * cfg.get("hit_reward", 100.0)
        crash_penalty_raw   = step_crashes * penalty("crash_penalty", 50.0)
        all_drones_lost_raw = penalty("all_drones_lost_penalty", 200.0) if all_drones_lost else 0.0
        all_targets_hit_raw = cfg.get("all_targets_hit_reward", 200.0) if all_targets_hit else 0.0
        timeout_raw         = penalty("timeout_penalty", 50.0) if timed_out else 0.0

        max_drones_penalty_raw = 0.0
        if self.max_drones_destroyed is not None:
            num_drones_lost = self.num_drones_init - sum(1 for d in self.drones if d.is_active)
            if num_drones_lost >= self.max_drones_destroyed:
                max_drones_penalty_raw = penalty("max_drones_destroyed_penalty", -500.0)

        # "Hitting" shaping term: +1 for every (active drone, active target) pair
        # closer than hit_distance_threshold. A drone near two targets counts twice.
        # NOTE(review): reward_config["hitting_reward"] is not used; the raw term is this count.
        active_targets = [ft for ft in self.fast_targets if ft.is_active]
        hitting_reward_raw = float(sum(
            1
            for d in self.drones if d.is_active
            for ft in active_targets
            if d.position.distance_to(ft.position) < self.hit_distance_threshold
        ))

        # Scale each term into its configured output range.
        step_penalty = self.scale_reward("step_penalty", step_penalty_raw)
        hit_reward = self.scale_reward("hit_reward", hit_reward_raw)
        hitting_reward = self.scale_reward("hitting_reward", hitting_reward_raw)
        crash_penalty = self.scale_reward("crash_penalty", crash_penalty_raw)
        all_drones_term = self.scale_reward("all_drones_lost", all_drones_lost_raw)
        all_targets_term = self.scale_reward("all_targets_hit", all_targets_hit_raw)
        timeout_term = self.scale_reward("timeout", timeout_raw)
        max_drones_term = self.scale_reward("max_drones_destroyed", max_drones_penalty_raw)

        extrinsic = (step_penalty + hit_reward + hitting_reward + crash_penalty +
                     all_drones_term + all_targets_term + timeout_term + max_drones_term)

        # Exploration bonus; compute_intrinsic_reward also updates the visit counts.
        intrinsic_raw = self.compute_intrinsic_reward() * self.intrinsic_reward_weight
        intrinsic = self.scale_reward("intrinsic_reward", intrinsic_raw)

        total_reward = extrinsic + intrinsic

        return {
            "step_penalty": step_penalty,
            "hit_reward": hit_reward,
            "hitting_reward": hitting_reward,
            "crash_penalty": crash_penalty,
            "all_drones_lost": all_drones_term,
            "max_drones_destroyed": max_drones_term,
            "all_targets_hit": all_targets_term,
            "timeout": timeout_term,
            "extrinsic": extrinsic,
            "intrinsic": intrinsic,
            "total": total_reward
        }


    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Clears the per-episode counters, then spawns drones, obstacles and
        targets using ``self.np_random``. Gymnasium's ``super().reset(seed=seed)``
        reseeds that generator when a seed is given. ``options`` is accepted for
        API compatibility but is not read.

        NOTE(review): obstacle headings come from ``Obstacle.__init__``, which
        uses the module-level ``random`` generator, so they are not governed by
        ``seed``.
        """
        super().reset(seed=seed)
        self.episode_num += 1
        self.current_step = 0
        self.total_crashes_episode = 0
        self.total_hits_episode = 0
        self.current_episode_reward = 0.0
        self.last_raw_action = None
        self.visited_cells = {}  # exploration visit counts used by compute_intrinsic_reward
        self.target_hit_steps = {}  # id(target) -> consecutive steps the hit condition has held

        rng = self.np_random

        # Drones: all spawn on one randomly chosen edge, inset by drone_size.
        # The order of rng calls below fixes the episode layout for a given seed.
        self.drones = []
        spawn_edge = rng.choice(["left", "right", "top", "bottom"])
        for i in range(self.num_drones_init):
            if spawn_edge == "left":
                x = self.drone_size
                y = rng.uniform(self.drone_size, self.screen_height - self.drone_size)
            elif spawn_edge == "right":
                x = self.screen_width - self.drone_size
                y = rng.uniform(self.drone_size, self.screen_height - self.drone_size)
            elif spawn_edge == "top":
                x = rng.uniform(self.drone_size, self.screen_width - self.drone_size)
                y = self.drone_size
            else:
                x = rng.uniform(self.drone_size, self.screen_width - self.drone_size)
                y = self.screen_height - self.drone_size
            d = Drone(x, y, drone_size=self.drone_size, max_speed=self.max_drone_speed,
                      screen_width=self.screen_width, screen_height=self.screen_height)
            self.drones.append(d)

        # Obstacles: uniformly placed within 15%..85% of each dimension.
        self.obstacles = []
        for _ in range(self.num_obstacles_init):
            ox = rng.uniform(self.screen_width * 0.15, self.screen_width * 0.85)
            oy = rng.uniform(self.screen_height * 0.15, self.screen_height * 0.85)
            obs_obj = Obstacle(ox, oy, self.obstacle_radius, self.screen_width, self.screen_height,
                           max_speed=self.obstacle_max_speed)
            self.obstacles.append(obs_obj)

        # Targets: rejection sampling inside the wall margin, with at most 30
        # attempts per requested target. A candidate is rejected if it is too
        # close to an obstacle, an already accepted target, or any drone.
        self.fast_targets = []
        margin = self.pf.get("target_wall_margin", 50)
        attempts = 0
        needed = self.num_targets_init
        min_dist_from_obs = 15
        min_dist_from_target = self.fast_target_radius * 3

        while len(self.fast_targets) < needed and attempts < needed * 30:
            attempts += 1
            tx = rng.uniform(margin, self.screen_width - margin)
            ty = rng.uniform(margin, self.screen_height - margin)
            new_pos = Vector2(tx, ty)
            valid_spawn = True
            for obs_obj in self.obstacles:
                if new_pos.distance_to(obs_obj.position) < (obs_obj.radius + self.fast_target_radius + min_dist_from_obs):
                    valid_spawn = False
                    break
            if not valid_spawn:
                continue
            for existing_ft in self.fast_targets:
                if new_pos.distance_to(existing_ft.position) < min_dist_from_target:
                    valid_spawn = False
                    break
            if not valid_spawn:
                continue
            # Keep a 50-unit clearance beyond touching distance from every drone.
            for d in self.drones:
                if new_pos.distance_to(d.position) < (self.fast_target_radius + self.drone_size + 50):
                    valid_spawn = False
                    break
            if not valid_spawn:
                continue

            # The first num_static_targets accepted targets never move.
            is_static = len(self.fast_targets) < self.num_static_targets
            ft = FastMovingTarget(
                tx, ty, self.fast_target_radius, self.screen_width, self.screen_height,
                max_speed=self.fast_target_max_speed, pf_params=self.pf,
                avoidance_enabled=self.enable_target_avoidance, static_target=is_static)
            ft.is_active = True  # redundant: the constructor already sets this
            self.fast_targets.append(ft)
            self.target_hit_steps[id(ft)] = 0

        # Sampling can give up early; the episode then runs with fewer targets.
        if len(self.fast_targets) < needed:
            print(f"Warning: Could only spawn {len(self.fast_targets)}/{needed} targets. Consider adjusting spawn parameters.")

        # Attraction point starts at the arena centre until the first action.
        self.target_pos = Vector2(self.screen_width * 0.5, self.screen_height * 0.5)

        obs = self._get_obs()
        info = self._get_info()
        if self.render_mode == "human":
            self._render_frame()
        return obs, info

    def step(self, action):
        """Advance the simulation by one step.

        Order of operations:
          1. Place the attraction point from ``action``.
          2. Move the drones; forces for all active drones are computed before
             any drone moves.
          3. Move obstacles and targets.
          4. Resolve crashes, then target hits.
          5. Compute termination and reward.

        Args:
            action: Two values in [0, 1]; scaled by the screen width/height to
                place the shared attraction point.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` per the Gymnasium API.
        """
        self.current_step += 1
        self.last_raw_action = action.copy()

        tx = action[0] * self.screen_width
        ty = action[1] * self.screen_height
        self.target_pos = Vector2(tx, ty)

        # Compute every force first so all drones react to the same pre-move state.
        active_drone_indices = [i for i, d in enumerate(self.drones) if d.is_active]
        forces = [self._compute_potential_field_for_drone(self.drones[i], i)
                  for i in active_drone_indices]
        for i, force in zip(active_drone_indices, forces):
            self.drones[i].update(force)

        for obs_obj in self.obstacles:
            obs_obj.update()

        # Targets react to the drones that are active before this step's crash check.
        active_drones_list = [d for d in self.drones if d.is_active]
        for ft in self.fast_targets:
            if ft.is_active:
                ft.update(active_drones_list, self.obstacles)

        step_crashes = self._step_resolve_crashes()
        self.total_crashes_episode += step_crashes
        num_drones_active = sum(1 for d in self.drones if d.is_active)

        # With no drones left, the per-target dwell counters are not touched.
        step_hits = self._step_resolve_target_hits() if num_drones_active > 0 else 0
        self.total_hits_episode += step_hits
        num_targets_active = sum(1 for ft in self.fast_targets if ft.is_active)

        all_drones_lost = (num_drones_active == 0 and self.num_drones_init > 0)
        all_targets_hit = (num_targets_active == 0 and self.num_targets_init > 0)
        drones_destroyed_threshold_reached = (
            self.max_drones_destroyed is not None
            and self.num_drones_init - num_drones_active >= self.max_drones_destroyed
        )

        timed_out = (self.current_step >= self.max_steps)
        terminated = all_drones_lost or all_targets_hit or drones_destroyed_threshold_reached
        truncated = (timed_out and not terminated)

        # Pass `truncated` as the timeout flag: an episode that terminates for
        # another reason on its final step must not also pay the timeout term.
        reward_breakdown = self.compute_reward(
            step_hits, step_crashes, all_drones_lost, all_targets_hit, truncated
        )
        reward = reward_breakdown["total"]
        # Sanitise before bookkeeping, so one bad value cannot poison
        # last_step_reward or current_episode_reward.
        if np.isnan(reward) or np.isinf(reward):
            print("Warning: NaN or Inf detected in reward!")
            reward = 0.0
        self.last_reward_breakdown = reward_breakdown
        self.last_extrinsic_reward = reward_breakdown["extrinsic"]
        self.last_intrinsic_reward = reward_breakdown["intrinsic"]
        self.last_step_reward = reward
        self.current_episode_reward += reward

        obs = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        if np.isnan(obs).any() or np.isinf(obs).any():
            print("Warning: NaN or Inf detected in observation!")
            obs = np.nan_to_num(obs, nan=0.0, posinf=1.0, neginf=-1.0)

        return obs, reward, terminated, truncated, info

    def _step_resolve_crashes(self):
        """Deactivate crashed drones and return how many crashed this step.

        No crash checks run during the first 100 steps (grace period after
        spawning). Collisions are evaluated against a snapshot of the drones
        active at the start of the check, and deactivations are applied
        afterwards. Two touching drones therefore both crash, regardless of
        their order in ``self.drones``.
        """
        if self.current_step < 100:
            return 0
        candidates = [d for d in self.drones if d.is_active]
        crashed = [d for d in candidates if self._step_drone_crashed(d, candidates)]
        for d in crashed:
            d.is_active = False
        return len(crashed)

    def _step_drone_crashed(self, drone, candidates):
        """Return True if ``drone`` left the arena or overlaps an obstacle or another candidate drone."""
        pos = drone.position
        margin = self.drone_size
        if not (margin <= pos.x <= self.screen_width - margin and
                margin <= pos.y <= self.screen_height - margin):
            return True
        for obs_obj in self.obstacles:
            if pos.distance_squared_to(obs_obj.position) < (obs_obj.radius + drone.drone_size) ** 2:
                return True
        for other in candidates:
            # The "- 1" requires more than one unit of overlap before two drones collide.
            if other is not drone and pos.distance_squared_to(other.position) < (drone.drone_size + other.drone_size - 1) ** 2:
                return True
        return False

    def _step_resolve_target_hits(self):
        """Update the per-target dwell counters and destroy targets held long enough.

        A target's counter increments on each step with at least
        ``drones_required_to_hit_target`` active drones within
        ``hit_distance_threshold``; otherwise it resets to 0. The target is
        destroyed once the counter reaches ``steps_required_to_hit_target``.
        Returns the number of targets destroyed this step.
        """
        active_drones = [d for d in self.drones if d.is_active]
        hits = 0
        for ft in self.fast_targets:
            if not ft.is_active:
                continue
            target_id = id(ft)
            drones_near_target = sum(
                1 for d in active_drones
                if d.position.distance_to(ft.position) < self.hit_distance_threshold
            )
            if drones_near_target >= self.drones_required_to_hit_target:
                self.target_hit_steps[target_id] = self.target_hit_steps.get(target_id, 0) + 1
            else:
                self.target_hit_steps[target_id] = 0
            if self.target_hit_steps[target_id] >= self.steps_required_to_hit_target:
                # Safe to deactivate immediately: hit checks depend only on drone positions.
                ft.is_active = False
                hits += 1
        return hits

    def _compute_potential_field_for_drone(self, drone, drone_idx):
        """
        Return the net potential-field force (a ``Vector2``) acting on one drone.

        The force is the sum of four kinds of terms, each capped with
        ``limit_vector`` before being added:
          * attraction toward ``self.target_pos``, growing linearly with distance;
          * repulsion from every obstacle centre closer than ``obs_influence_radius``;
          * repulsion from every other active drone closer than
            ``drone_influence_radius``;
          * repulsion from each screen edge closer than ``boundary_influence``.
        Gains, radii and caps are read from ``self.pf`` with built-in defaults.

        Parameters:
          drone: The drone to evaluate.
          drone_idx: Index of ``drone`` in ``self.drones``; that entry is skipped
            when summing drone-drone repulsion.
        """
        params = self.pf
        pos = drone.position
        total = Vector2(0, 0)

        # Attraction: magnitude k * distance along the unit vector to the goal.
        k_att = params.get("attraction_strength", 0.0005)
        att_cap = params.get("max_attraction_force", 0.8)
        to_goal = self.target_pos - pos
        goal_dist = to_goal.length()
        if goal_dist > EPSILON:
            total += limit_vector(to_goal.normalize() * k_att * goal_dist, att_cap)

        # Obstacle repulsion: k * (1/d - 1/R) / d, which reaches zero at radius R.
        # Distance is measured to the obstacle centre (its radius is not used here).
        k_obs = params.get("obs_repulsion_strength", 7000)
        obs_radius = params.get("obs_influence_radius", 100)
        obs_cap = params.get("max_obs_repulsion_force", 20.0)
        obs_radius_sq = obs_radius ** 2
        for obstacle in self.obstacles:
            away = pos - obstacle.position
            d_sq = away.magnitude_squared()
            if not (EPSILON < d_sq < obs_radius_sq):
                continue
            d = math.sqrt(d_sq)
            if d > EPSILON:
                magnitude = k_obs * (1.0 / d - 1.0 / obs_radius) / d
                total += limit_vector(away.normalize() * magnitude, obs_cap)

        # Drone-drone repulsion, same falloff as obstacles, active neighbours only.
        k_drone = params.get("drone_repulsion_strength", 100)
        drone_radius = params.get("drone_influence_radius", 50)
        drone_cap = params.get("max_drone_repulsion_force", 15.0)
        drone_radius_sq = drone_radius ** 2
        for idx, neighbour in enumerate(self.drones):
            if idx == drone_idx or not neighbour.is_active:
                continue
            away = pos - neighbour.position
            d_sq = away.magnitude_squared()
            if not (EPSILON < d_sq < drone_radius_sq):
                continue
            d = math.sqrt(d_sq)
            if d > EPSILON:
                magnitude = k_drone * (1.0 / d - 1.0 / drone_radius) / d
                total += limit_vector(away.normalize() * magnitude, drone_cap)

        # Edge repulsion: k * (1/max(gap, 1) - 1/B) inside a band of width B.
        k_wall = params.get("boundary_repulsion_strength", 6000)
        wall_band = params.get("boundary_influence", 40)
        wall_cap = params.get("max_boundary_force", 18.0)
        edges = (
            (pos.x, Vector2(1, 0)),                        # left edge pushes +x
            (self.screen_width - pos.x, Vector2(-1, 0)),   # right edge pushes -x
            (pos.y, Vector2(0, 1)),                        # top edge pushes +y
            (self.screen_height - pos.y, Vector2(0, -1)),  # bottom edge pushes -y
        )
        for gap, push in edges:
            if gap < wall_band:
                magnitude = k_wall * (1.0 / max(gap, 1.0) - 1.0 / wall_band)
                if magnitude > 0:
                    total += limit_vector(push * magnitude, wall_cap)
        return total


    def _get_sensor_reading(self, drone: Drone, angle: float) -> float:
        """
        Cast a pseudo-ray from the drone's position at the given angle.

        The ray is marched outward in increments of ``SENSOR_STEP_SIZE``. At each
        sample point it first checks the simulation boundary, then every
        obstacle (a hit is any sample within ``obstacle.radius`` of its centre).

        Parameters:
          drone: The drone whose position is the ray origin.
          angle: Ray direction in radians (0 points along +x).

        Returns:
          The normalized distance ``distance / sensor_range`` to the first
          boundary or obstacle hit, clamped to at most 1.0. Returns 1.0 if
          nothing is hit within ``sensor_range``.
        """
        sensor_range = self.sensor_range
        # The per-step displacement is loop-invariant; build it once.
        step = Vector2(math.cos(angle), math.sin(angle)) * SENSOR_STEP_SIZE
        current_pos = Vector2(drone.position.x, drone.position.y)
        distance = 0.0
        while distance < sensor_range:
            current_pos += step
            distance += SENSOR_STEP_SIZE
            out_of_bounds = (
                current_pos.x < 0 or current_pos.x > self.screen_width
                or current_pos.y < 0 or current_pos.y > self.screen_height
            )
            if out_of_bounds or any(
                current_pos.distance_to(obs_obj.position) <= obs_obj.radius
                for obs_obj in self.obstacles
            ):
                # The final step can overshoot sensor_range when the range is not
                # a multiple of SENSOR_STEP_SIZE. Clamp so a detection never reads
                # farther than the "nothing detected" value of 1.0.
                return min(distance / sensor_range, 1.0)
        return 1.0


    def _get_obs(self):
        """
        Build the flat float32 observation vector.

        Layout:
          * For each drone in ``self.drones`` order, 2 + NUM_SENSORS + 1 values:
              - the x, y position normalized by the screen size;
              - NUM_SENSORS ray readings at evenly spaced angles
                (see ``_get_sensor_reading``);
              - the number of active fast targets within
                ``hit_distance_threshold`` of the drone (a count, not a 0/1 flag).
            An inactive drone contributes the same number of ``-1.0`` values.
          * 3 target slots of normalized (x, y). These are active fast targets
            within ``sensor_range`` of at least one active drone, ordered by
            their distance to the closest such drone. Unused slots are
            ``(-1.0, -1.0)``.
        """
        width, height = self.screen_width, self.screen_height
        sensor_angles = [(2 * math.pi * k) / NUM_SENSORS for k in range(NUM_SENSORS)]
        per_drone_len = 2 + NUM_SENSORS + 1
        features = []

        for drone in self.drones:
            if not drone.is_active:
                features.extend([-1.0] * per_drone_len)
                continue
            features.append(drone.position.x / width)
            features.append(drone.position.y / height)
            features.extend(self._get_sensor_reading(drone, a) for a in sensor_angles)
            near_targets = sum(
                1 for ft in self.fast_targets
                if ft.is_active
                and drone.position.distance_to(ft.position) < self.hit_distance_threshold
            )
            features.append(float(near_targets))

        # Target slots: only targets currently sensed by some active drone.
        live_drones = [d for d in self.drones if d.is_active]
        sightings = []
        for ft in self.fast_targets:
            if not ft.is_active:
                continue
            in_range = [
                dist
                for dist in (d.position.distance_to(ft.position) for d in live_drones)
                if dist <= self.sensor_range
            ]
            if in_range:
                xy = (ft.position.x / width, ft.position.y / height)
                sightings.append((min(in_range), xy))
        sightings.sort(key=lambda entry: entry[0])

        nearest = sightings[:3]
        for _, xy in nearest:
            features.extend(xy)
        features.extend([-1.0, -1.0] * (3 - len(nearest)))

        return np.array(features, dtype=np.float32)



    def _get_info(self):
        """
        Return the diagnostics dict that ``step`` hands back as ``info``.

        Keys: ``drones_active`` and ``targets_active`` (current counts of active
        drones / fast targets), ``total_crashes`` and ``total_hits`` (episode
        totals) and ``steps`` (the current step index).
        """
        active_drones = [d for d in self.drones if d.is_active]
        active_targets = [ft for ft in self.fast_targets if ft.is_active]
        info = dict()
        info["drones_active"] = len(active_drones)
        info["targets_active"] = len(active_targets)
        info["total_crashes"] = self.total_crashes_episode
        info["total_hits"] = self.total_hits_episode
        info["steps"] = self.current_step
        return info

    def render(self):
        """
        Render the current frame according to ``self.render_mode``.

        Lazily initializes pygame via ``_initialize_pygame`` when no screen
        exists yet.

        Returns:
          The RGB array produced by ``_render_frame`` in ``"rgb_array"`` mode;
          ``None`` in ``"human"`` mode and for any other mode, in which case
          nothing is drawn.
        """
        mode = self.render_mode
        if mode not in ("human", "rgb_array"):
            return None
        if self.screen is None:
            self._initialize_pygame()
        frame = self._render_frame()
        return frame if mode == "rgb_array" else None


    def _build_sidebar_lines(self):
        """
        Build the text lines shown in the sidebar, top to bottom.

        The lines contain:
          * episode statistics and the last raw action;
          * the last reward breakdown;
          * the current value of each PF parameter named in
            ``self.controlled_pf_keys``.

        Values that cannot take a numeric format spec are shown with ``str()``
        instead of raising. An example is the ``'N/A'`` placeholder for a PF key
        missing from ``self.pf``.

        Returns:
            list[str]: the sidebar lines.
        """
        def fmt(value, spec):
            # format('N/A', '.4f') raises ValueError and format(None, '.2f')
            # raises TypeError; fall back to plain text so one odd value cannot
            # break the whole render.
            try:
                return format(value, spec)
            except (TypeError, ValueError):
                return str(value)

        num_drones_active = sum(1 for d in self.drones if d.is_active)
        num_targets_active = sum(1 for ft in self.fast_targets if ft.is_active)
        if self.last_raw_action is not None:
            action_line = f"Action: ({', '.join(f'{a:.2f}' for a in self.last_raw_action)})"
        else:
            action_line = "Action: (None yet)"

        lines = [
            f"Episode: {self.episode_num}",
            f"Ep Reward: {fmt(self.current_episode_reward, '.2f')}",
            f"Crashes: {self.total_crashes_episode}",
            f"Hits: {self.total_hits_episode}",
            action_line,
            f"Drones: {num_drones_active}/{self.num_drones_init}",
            f"Targets: {num_targets_active}/{self.num_targets_init}",
            f"Step: {self.current_step}/{self.max_steps}",
            f"Step Reward: {fmt(self.last_step_reward, '.2f')}",
            "--- Reward Breakdown ---",
        ]

        if hasattr(self, 'last_reward_breakdown') and self.last_reward_breakdown:
            for key, val in self.last_reward_breakdown.items():
                lines.append(f"  {key}: {fmt(val, '.2f')}")
        else:
            lines.append("  (No breakdown yet)")

        lines.append("--- Controlled PF Params ---")
        if hasattr(self, 'controlled_pf_keys') and hasattr(self, 'pf'):
            for key in self.controlled_pf_keys:
                lines.append(f"  {key}: {fmt(self.pf.get(key, 'N/A'), '.4f')}")
        else:
            lines.append("  (PF Params not available)")

        return lines


    def _render_sidebar(self, lines):
        """
        Render ``lines`` onto a new sidebar surface.

        Shows the window of lines that fits the sidebar height, starting at
        ``self.sidebar_scroll_offset``, plus a scrollbar when not every line fits.

        The offset is clamped to ``[0, len(lines) - visible_line_count]`` and
        written back to ``self.sidebar_scroll_offset``. Without the clamp,
        wheel-down events keep incrementing the offset without bound, which
        blanks the sidebar and forces the user to scroll back through the
        overshoot.

        Parameters:
          lines: Text lines to show, top to bottom.

        Returns:
          A pygame.Surface of size (sidebar_width, screen_height).
        """
        sidebar_width = self.sidebar_width
        sidebar_height = self.screen_height
        sidebar_surface = pygame.Surface((sidebar_width, sidebar_height))
        sidebar_surface.fill((50, 50, 50))  # Dark gray background

        line_height = self.info_font.get_linesize()
        v_spacing = 5
        # Number of lines that fit in the sidebar at once.
        max_lines = sidebar_height // (line_height + v_spacing)
        total_lines = len(lines)

        max_offset = max(0, total_lines - max_lines)
        start_index = min(max(0, self.sidebar_scroll_offset), max_offset)
        self.sidebar_scroll_offset = start_index
        visible_lines = lines[start_index:start_index + max_lines]

        y = v_spacing
        for line in visible_lines:
            try:
                text_surf = self.info_font.render(line, True, WHITE)
                sidebar_surface.blit(text_surf, (10, y))
            except Exception as e:
                # Show the failure in place of the line rather than aborting the frame.
                error_surf = self.info_font.render(f"Render error: {e}", True, (255, 100, 100))
                sidebar_surface.blit(error_surf, (10, y))
            y += line_height + v_spacing

        if total_lines > max_lines:
            # The thumb size is proportional to the visible fraction, with a 20 px
            # minimum. Its position maps [0, max_offset] onto the free track, so
            # it never runs past the bottom. max_offset >= 1 in this branch.
            slider_height = min(sidebar_height, max(20, sidebar_height * max_lines // total_lines))
            slider_y = int(start_index / max_offset * (sidebar_height - slider_height))
            slider_rect = pygame.Rect(sidebar_width - 15, slider_y, 10, slider_height)
            pygame.draw.rect(sidebar_surface, GRAY, slider_rect)

        return sidebar_surface

    def _render_frame(self):
        """
        Compose and present one frame: the sidebar on the left, the simulation to its right.

        Drains the pygame event queue. Mouse buttons 4 and 5 (wheel up/down)
        adjust ``self.sidebar_scroll_offset``; all other events are discarded.
        Ticks ``self.clock`` at ``metadata["render_fps"]`` when a clock exists.

        Returns:
          In ``"rgb_array"`` mode, the simulation area only (the sidebar is not
          included) as a (height, width, 3) array. If pygame cannot read the
          surface, a zero array of that shape is returned. Otherwise ``None``.
        """
        for event in pygame.event.get():
            if event.type != pygame.MOUSEBUTTONDOWN:
                continue
            if event.button == 4:  # wheel up
                self.sidebar_scroll_offset = max(0, self.sidebar_scroll_offset - 1)
            elif event.button == 5:  # wheel down
                self.sidebar_scroll_offset += 1

        sidebar = self._render_sidebar(self._build_sidebar_lines())
        world = self._render_simulation()

        screen = self.screen
        screen.fill(BLACK)
        screen.blit(sidebar, (0, 0))
        screen.blit(world, (self.sidebar_width, 0))
        pygame.display.flip()

        if self.clock:
            self.clock.tick(self.metadata["render_fps"])

        if self.render_mode != "rgb_array":
            return None
        try:
            pixels = pygame.surfarray.array3d(world)
        except pygame.error as e:
            print(f"Error getting surface array: {e}")
            return np.zeros((self.screen_height, self.screen_width, 3), dtype=np.uint8)
        # surfarray is indexed (x, y); swap to image (row, column) order.
        return np.transpose(pixels, axes=(1, 0, 2))

    def _render_simulation(self):
        """
        Draw the simulation area onto a fresh surface and return it.

        Draw order (later items appear on top):
          1. obstacles;
          2. the goal marker at ``self.target_pos``;
          3. fast targets, with a dark-red ring while they have hit progress
             and a small dark-green marker at their initial position;
          4. green lines from each active drone to every active fast target
             closer than ``hit_distance_threshold``;
          5. faint dashed sensor rays of length ``sensor_range``;
          6. the active drones.

        Decorations that fail to draw are skipped silently (best effort).

        Returns:
            pygame.Surface of size (screen_width, screen_height).
        """
        canvas = pygame.Surface((self.screen_width, self.screen_height))
        canvas.fill(BLACK)

        def drawable(obj):
            return hasattr(obj, 'draw') and callable(obj.draw)

        for obstacle in self.obstacles:
            if drawable(obstacle):
                obstacle.draw(canvas)

        # Goal marker: a filled yellow disc with a white outline.
        try:
            goal_xy = (int(self.target_pos.x), int(self.target_pos.y))
            pygame.draw.circle(canvas, YELLOW, goal_xy, 10)
            pygame.draw.circle(canvas, WHITE, goal_xy, 10, 2)
        except Exception:
            pass  # best-effort decoration

        tracks_hit_progress = hasattr(self, 'target_hit_steps')
        for target in self.fast_targets:
            if target.is_active and drawable(target):
                target.draw(canvas)
                if tracks_hit_progress and self.target_hit_steps.get(id(target), 0) > 0:
                    try:
                        ring_center = (int(target.position.x), int(target.position.y))
                        pygame.draw.circle(canvas, DARK_RED, ring_center, int(target.radius + 3), 2)
                    except Exception:
                        pass  # best-effort decoration
            if hasattr(target, 'initial_position'):
                try:
                    spawn_xy = (int(target.initial_position.x), int(target.initial_position.y))
                    pygame.draw.circle(canvas, DARK_GREEN, spawn_xy, 3, 1)
                except Exception:
                    pass  # best-effort decoration

        live_drones = [d for d in self.drones if d.is_active]

        # Hit indicators: drone -> nearby active target.
        hit_threshold = getattr(self, 'hit_distance_threshold', float('inf'))
        for drone in live_drones:
            for target in self.fast_targets:
                if target.is_active and drone.position.distance_to(target.position) < hit_threshold:
                    pygame.draw.line(
                        canvas,
                        GREENISH,
                        (int(drone.position.x), int(drone.position.y)),
                        (int(target.position.x), int(target.position.y)),
                        2,
                    )

        # Sensor rays go on a translucent layer so they stay subtle.
        ray_layer = pygame.Surface((self.screen_width, self.screen_height), pygame.SRCALPHA)
        ray_color = (255, 255, 255, 30)  # very faint white
        for drone in live_drones:
            origin = (int(drone.position.x), int(drone.position.y))
            for k in range(NUM_SENSORS):
                angle = (2 * math.pi * k) / NUM_SENSORS
                reach = Vector2(math.cos(angle), math.sin(angle)) * self.sensor_range
                tip = (int(drone.position.x + reach.x), int(drone.position.y + reach.y))
                draw_dashed_line(ray_layer, ray_color, origin, tip, dash_length=5, space_length=3)
        canvas.blit(ray_layer, (0, 0))

        for drone in live_drones:
            drone.draw(canvas)

        return canvas

    def close(self):
        """
        Shut down pygame and drop the screen, clock and font handles.

        This does nothing when no screen was created, so it is safe to call more
        than once.
        """
        if self.screen is None:
            return
        pygame.display.quit()
        pygame.font.quit()
        pygame.quit()
        self.screen = self.clock = self.info_font = None

# ---------------- Render Callback for Stable Baselines ----------------
class RenderCallback(BaseCallback):
    """
    Periodically render the first training environment.

    Parameters:
      render_freq: Render on every ``render_freq``-th call to ``_on_step``.
        A value <= 0 disables rendering instead of failing with a modulo-by-zero.
      verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, render_freq=100, verbose=0):
        super().__init__(verbose)
        self.render_freq = render_freq

    def _on_step(self) -> bool:
        if self.render_freq > 0 and self.n_calls % self.render_freq == 0:
            # env_method is part of the generic VecEnv API (DummyVecEnv,
            # SubprocVecEnv, VecEnv wrappers); `.envs` exists only on DummyVecEnv.
            self.training_env.env_method("render", indices=0)
        # Never stop training from this callback.
        return True

# --------------- Example Tuning Parameters Section ---------------
# Per reward term: a "raw" value range and a "scaled" target range. This dict is
# passed to the environment as train_env_params["reward_ranges"].
# NOTE(review): presumably the env maps raw values linearly onto the scaled
# range inside compute_reward — confirm there.
# A single definition is kept; a second assignment to the same name would
# silently discard the first.
tuning_reward_ranges = {
    "step_penalty": {"raw": (-10.0, 0.0), "scaled": (-0.01, 0.0)},
    "hit_reward": {"raw": (0.0, 300.0), "scaled": (0.0, 10.0)},
    "crash_penalty": {"raw": (-150.0, 0.0), "scaled": (-5.0, 0.0)},
    # Zero-width scaled range; presumably disables this term — confirm.
    "all_drones_lost": {"raw": (-500.0, 0.0), "scaled": (-0.0, 0.0)},
    "all_targets_hit": {"raw": (0.0, 300.0), "scaled": (0.0, 50.0)},
    "timeout": {"raw": (-500.0, 0.0), "scaled": (-10, 0.0)},
    "max_drones_destroyed": {"raw": (-500.0, 0.0), "scaled": (-40.0, 0.0)},
    "intrinsic_reward": {"raw": (0.0, 1.0), "scaled": (0.0, 0.1)},
    "hitting_reward": {"raw": (0.0, 20.0), "scaled": (0.0, 0.1)},
}

# ---------------- Example of Custom Init Params ----------------
# Constructor kwargs for DroneSwarmEnv. The smoke-test, training and retraining
# cells use this dict directly; the evaluation cell copies it and overrides
# "render_mode".
train_env_params = {
    # The visible render paths only act on "human" / "rgb_array", so "" means
    # nothing is drawn.
    "render_mode": "",
    # Episodes are truncated once current_step reaches max_steps.
    "max_steps": 3000,
    # Simulation area; also used to normalize observed positions.
    "screen_width": 600,
    "screen_height": 600,
    # NOTE(review): the following entity counts/sizes/speeds are consumed by
    # __init__/reset, which are not visible here — semantics inferred from names.
    "num_drones": 6,
    "drone_size": 6,
    "max_drone_speed": 7.0,
    "num_obstacles":2,
    "obstacle_max_speed":4,
    "obstacle_radius": 25,
    "num_targets": 10,
    # presumably 0.0 makes the fast targets stationary — confirm.
    "fast_target_max_speed": 0.0,
    "fast_target_radius": 10,
    # Length of each sensor ray, and the radius within which an active drone
    # "sees" a fast target for the observation's target slots.
    "sensor_range": 150,
    # Potential-field gains, influence radii and force caps. The keys read by
    # _compute_potential_field_for_drone via self.pf (presumably this dict —
    # confirm) include attraction_*, obs_*, drone_* and boundary_*. The
    # *_for_target and target_wall_margin keys are not read in the visible
    # code; presumably they steer the targets — confirm.
    "pf_params": {
        "attraction_strength": 0.01,
        "max_attraction_force": 1.0,
        "obs_repulsion_strength": 3000,
        "obs_influence_radius": 150,
        "max_obs_repulsion_force": 10000.0,
        "drone_repulsion_strength": 120,
        "drone_influence_radius": 150,
        "max_drone_repulsion_force": 18.0,
        "boundary_repulsion_strength": 1000,
        "boundary_influence": 10,
        "max_boundary_force": 20.0,
        "obs_repulsion_strength_for_target": 3000,
        "obs_influence_radius_for_target": 120,
        "max_obs_repulsion_force_for_target": 25.0,
        "target_wall_margin": 50
    },
    # Raw reward magnitudes; presumably consumed by compute_reward together
    # with reward_ranges — confirm.
    "reward_config": {
        "step_penalty": -0.01,
        "hit_reward": 150.0,
        "crash_penalty": -80.0,
        "all_drones_lost_penalty": -500.0,
        "all_targets_hit_reward": 300.0,
        "timeout_penalty": -500.0,
        "max_drones_destroyed_penalty": -500.0
    },
    "reward_ranges": tuning_reward_ranges,
    # NOTE(review): presumably selects which pf_params the action controls
    # (the sidebar lists self.controlled_pf_keys) — confirm; empty here.
    "controlled_pf_params": {},
    "enable_target_avoidance": False,
    # A fast target gains hit progress on a step when at least this many active
    # drones are closer than hit_distance_threshold to it; otherwise its
    # progress resets to 0.
    "drones_required_to_hit_target": 3,
    # Consecutive steps of hit progress needed to deactivate (hit) the target.
    "steps_required_to_hit_target": 1,
    # Drone-to-target distance (in position units) that counts as "near".
    "hit_distance_threshold": 50,
    # The episode terminates once this many drones have been lost.
    "max_drones_destroyed": 4,
    # NOTE(review): presumably scales the intrinsic reward term — confirm.
    "intrinsic_reward_weight": 1,
    # NOTE(review): presumably how many of num_targets are static — confirm.
    "num_static_targets": 7
}


pygame 2.6.0 (SDL 2.28.4, Python 3.11.5)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:


import time

# Smoke test: run the environment with random actions and print the final info.
env = DroneSwarmEnv(**train_env_params)
num_episodes = 1  # Number of episodes to test
try:
    for ep in range(num_episodes):
        obs, info = env.reset()
        terminated = truncated = False
        print(f"Starting episode {ep + 1}...")
        # Stop on termination OR truncation. Checking only the termination flag
        # keeps stepping past max_steps after a timeout until some termination
        # condition happens, which may be never.
        while not (terminated or truncated):
            # NOTE(review): the action layout depends on controlled_pf_params;
            # confirm against env.action_space before relying on its meaning.
            action = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)
            env.render()

        print(f"Episode {ep + 1} finished. Info: {info}")
finally:
    env.close()

Starting episode 1...
Episode 1 finished. Info: {'drones_active': 2, 'targets_active': 6, 'total_crashes': 4, 'total_hits': 4, 'steps': 744}


In [None]:

env = DroneSwarmEnv(**train_env_params)

log_dir = "./tensorboard_logs/"
os.makedirs(log_dir, exist_ok=True)

# save_freq counts callback calls. With the single env used here, that is one
# call per environment step.
checkpoint_callback = CheckpointCallback(
    save_freq=100000,
    save_path='./checkpoints_lstm/',
    name_prefix='ppo_drone_swarm',
)

# NOTE: train_env_params uses render_mode "", for which env.render() draws
# nothing; use "human" to actually watch training.
render_callback = RenderCallback(render_freq=1000)

# Recurrent PPO (LSTM policy) lives in sb3-contrib.
from sb3_contrib import RecurrentPPO

# Separate actor (pi) and critic (vf) MLP heads. SB3 >= 1.8 expects net_arch as
# a dict; the legacy list-of-dict form ([dict(pi=..., vf=...)]) is only
# accepted behind a deprecation warning.
policy_kwargs = dict(
    activation_fn=torch.nn.ELU,
    net_arch=dict(pi=[256, 256], vf=[256, 256]),
    lstm_hidden_size=256,  # hidden size of the policy's LSTM layer
)

model = RecurrentPPO(
    policy="MlpLstmPolicy",
    env=env,
    learning_rate=1e-4,
    n_steps=2048,     # rollout length per environment per update
    batch_size=128,   # minibatch size for each gradient step
    n_epochs=4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    ent_coef=0.01,
    vf_coef=0.5,
    max_grad_norm=0.5,
    tensorboard_log=log_dir,
    policy_kwargs=policy_kwargs,
    verbose=1,
    seed=None,
    device="cpu",
)

total_timesteps = 10_000_000
try:
    model.learn(total_timesteps=total_timesteps, callback=[checkpoint_callback, render_callback])
    model.save("ppo_drone_swarm_final")
    print("Training complete. Model saved as 'ppo_drone_swarm_final'.")
finally:
    env.close()



Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




Logging to ./tensorboard_logs/RecurrentPPO_4
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 474      |
|    ep_rew_mean     | -32      |
| time/              |          |
|    fps             | 1649     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 1024     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 647          |
|    ep_rew_mean          | -15.8        |
| time/                   |              |
|    fps                  | 655          |
|    iterations           | 2            |
|    time_elapsed         | 3            |
|    total_timesteps      | 2048         |
| train/                  |              |
|    approx_kl            | 0.0011755745 |
|    clip_fraction        | 0.000244     |
|    clip_range           | 0.2          |
|    entropy_loss         | -2.84        |
|    explaine

In [None]:
import os
import torch
import pygame
import numpy as np
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, BaseCallback

# DroneSwarmEnv, RenderCallback and train_env_params come from the cells above.
# If they live in a module, import them instead, e.g.:
# from drone_swarm_env import DroneSwarmEnv, RenderCallback, train_env_params

env = DroneSwarmEnv(**train_env_params)

# Same checkpoint cadence as the initial training run. A tiny save_freq (e.g. 2)
# would write a model zip every couple of steps over millions of steps.
checkpoint_callback = CheckpointCallback(
    save_freq=100_000,
    save_path='./checkpoints_lstm/',
    name_prefix='ppo_drone_swarm_retrain',
)
render_callback = RenderCallback(render_freq=50)

# NOTE(review): the training cell above uses sb3_contrib.RecurrentPPO. If this
# checkpoint came from that run, load it with RecurrentPPO.load instead:
# PPO's rollout loop does not pass LSTM states to the policy. Confirm which
# algorithm produced this file.
model = PPO.load("checkpoints/ppo_drone_swarm_378_steps.zip", env=env)

additional_timesteps = 5_000_000  # for example, 5 million additional timesteps

try:
    # reset_num_timesteps=False continues the loaded model's timestep counter.
    # As a result, total_timesteps is interpreted as *additional* steps, and
    # checkpoint names keep increasing instead of restarting from 0.
    model.learn(
        total_timesteps=additional_timesteps,
        callback=[checkpoint_callback, render_callback],
        reset_num_timesteps=False,
    )
    model.save("ppo_drone_swarm_final_retrained_lstm")
    print("Retraining complete. Model saved as 'ppo_drone_swarm_final_retrained_lstm'.")
finally:
    env.close()

In [None]:
import time
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from sb3_contrib import RecurrentPPO

# Evaluation environment with on-screen rendering. The shallow copy is enough
# because only a top-level key is overridden.
eval_env_params = train_env_params.copy()
eval_env_params["render_mode"] = "human"
eval_env = DroneSwarmEnv(**eval_env_params)

# "ppo_drone_swarm_final" is written by the RecurrentPPO training cell, so it is
# loaded with RecurrentPPO and its LSTM state is threaded through predict().
model = RecurrentPPO.load("ppo_drone_swarm_final", env=eval_env)

num_episodes = 5
try:
    for ep in range(num_episodes):
        obs, info = eval_env.reset()
        terminated = truncated = False
        episode_reward = 0.0
        # Start each episode from a fresh recurrent state.
        lstm_states = None
        episode_start = np.ones((1,), dtype=bool)
        while not (terminated or truncated):
            action, lstm_states = model.predict(
                obs,
                state=lstm_states,
                episode_start=episode_start,
                deterministic=True,
            )
            # In "human" mode step() already renders the frame, so no separate
            # render() call is needed.
            obs, reward, terminated, truncated, info = eval_env.step(action)
            episode_start = np.zeros((1,), dtype=bool)
            episode_reward += reward
            time.sleep(0.02)  # Slow down rendering for visualization.
        print(f"Episode {ep+1} reward: {episode_reward}")
finally:
    eval_env.close()