In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.utils import set_random_seed
import torch.multiprocessing as mp

import pygame
import matplotlib.pyplot as plt
import time

# Pick the compute device: CUDA when PyTorch reports a usable GPU, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# --- World geometry and dynamics -------------------------------------------
# Size of the playing field as (width, height).
FULL_VIEW_SIZE = (1200, 800)
# Distance the dot travels per environment step.
MAX_SPEED = 3
# Radii of the controlled dot and of the target marker.
DOT_RADIUS = 30
TARGET_RADIUS = 10
# The target counts as reached once the two circles overlap.
GOAL_DETECTION_RADIUS = TARGET_RADIUS + DOT_RADIUS
# The dot starts every episode at the centre of the field.
START_POS = [extent // 2 for extent in FULL_VIEW_SIZE]
# Standard deviation of the Gaussian noise added to the simulated human input.
NOISE_MAGNITUDE = 0.5

# --- RGB colours used by the pygame renderer -------------------------------
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
GREEN = (0, 200, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)

class OptimizedCustomFeaturesExtractor(BaseFeaturesExtractor):
    """Two-layer MLP that maps a flat observation to a feature vector.

    The stack is Linear -> ReLU -> BatchNorm1d, twice: first to 256 hidden
    units, then to ``features_dim`` outputs. Linear weights are initialised
    orthogonally with gain sqrt(2) and biases are set to zero.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 128):
        """Build the network.

        Args:
            observation_space: Space whose ``shape`` determines the input width
                (the product of its dimensions).
            features_dim: Width of the feature vector returned by ``forward``.
        """
        super().__init__(observation_space, features_dim)

        hidden_units = 256
        flat_obs_size = int(np.prod(observation_space.shape))

        layers = [
            nn.Linear(flat_obs_size, hidden_units),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_units),
            nn.Linear(hidden_units, features_dim),
            nn.ReLU(),
            nn.BatchNorm1d(features_dim),
        ]
        self.network = nn.Sequential(*layers)

        # Orthogonal init with gain sqrt(2) on every linear layer; zero biases.
        relu_gain = np.sqrt(2)
        for layer in layers:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.orthogonal_(layer.weight, gain=relu_gain)
            nn.init.constant_(layer.bias, 0)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Run ``observations`` through the MLP and return the features."""
        features = self.network(observations)
        return features

class MetricsCallback(BaseCallback):
    """
    Training callback that records metrics for the FIRST environment of the
    vectorised env (index 0) and can plot them afterwards.

    Tracked:
      - Episode total rewards
      - Average gamma (the clipped action) per episode
      - Actor loss and critic loss from the SB3 logger, sampled once per
        rollout (i.e. once per training update)
    Then plots:
      1. Average reward per episode
      2. Average gamma per episode
      3. Actor & Critic loss
    """
    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []          # total reward of each finished episode
        self.episode_gammas = []           # mean gamma of each finished episode
        self.current_episode_gammas = []   # gammas of the episode in progress
        self.total_reward = 0              # running reward of the episode in progress

        self.actor_losses = []             # 'train/policy_gradient_loss' per update
        self.critic_losses = []            # 'train/value_loss' per update
        self.action_low = None             # filled lazily from the model's action space
        self.action_high = None

    @staticmethod
    def _first_scalar(values):
        """Return the first entry of ``values`` as a plain Python float.

        ``values[0]`` may itself be a one-element array or tensor (e.g. an
        action of shape ``(1,)``). Calling ``float()`` directly on such an
        array is deprecated in recent NumPy versions, so the element is
        flattened and indexed explicitly.
        """
        first = values[0]
        if torch.is_tensor(first):
            return float(first.reshape(-1)[0].item())
        return float(np.asarray(first).reshape(-1)[0])

    def _on_rollout_start(self) -> None:
        """Record the losses logged by the previous training update.

        The ``train/*`` values are written to the logger once per update.
        Reading them here, instead of on every environment step, stores
        exactly one point per update rather than one copy per step.
        """
        logger = getattr(self.model, 'logger', None)
        logged = getattr(logger, 'name_to_value', None)
        if not logged:
            return
        # Actor (policy) loss
        if 'train/policy_gradient_loss' in logged:
            self.actor_losses.append(logged['train/policy_gradient_loss'])
        # Critic (value) loss
        if 'train/value_loss' in logged:
            self.critic_losses.append(logged['train/value_loss'])

    def _on_step(self):
        """Accumulate reward and gamma for env 0; close out finished episodes.

        Returns:
            True, so training always continues.
        """
        if self.action_low is None:
            self.action_low = self.model.action_space.low[0]
            self.action_high = self.model.action_space.high[0]

        # The raw policy output may lie outside the action bounds; clip it so
        # the logged gamma stays within the action space limits.
        gamma = self._first_scalar(self.locals['actions'])
        gamma = np.clip(gamma, self.action_low, self.action_high)
        self.current_episode_gammas.append(gamma)

        # Track rewards
        self.total_reward += self._first_scalar(self.locals['rewards'])

        # Check if the episode ended
        if self.locals['dones'][0]:
            # Store final metrics of this episode
            self.episode_rewards.append(self.total_reward)
            self.episode_gammas.append(np.mean(self.current_episode_gammas))

            # Reset for next episode
            self.current_episode_gammas = []
            self.total_reward = 0

        return True

    @staticmethod
    def _save_line_plot(series, title, xlabel, ylabel, path):
        """Draw each ``(values, label)`` pair in ``series`` and save to ``path``."""
        plt.figure(figsize=(8, 5))
        for values, label in series:
            plt.plot(values, label=label)
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.grid(True)
        plt.legend()
        plt.savefig(path)
        plt.close()

    def save_metrics(self, save_dir="training_metrics"):
        """Write the metric plots and a text summary into ``save_dir``.

        Args:
            save_dir: Output directory; created if it does not exist.
        """
        os.makedirs(save_dir, exist_ok=True)

        # 1) Plot average reward per episode
        self._save_line_plot(
            [(self.episode_rewards, "Episode Reward")],
            'Average Reward per Episode', 'Episode', 'Reward',
            os.path.join(save_dir, 'average_reward.png'),
        )

        # 2) Plot average gamma per episode
        self._save_line_plot(
            [(self.episode_gammas, "Average Gamma")],
            'Average Gamma per Episode', 'Episode', 'Gamma',
            os.path.join(save_dir, 'average_gamma.png'),
        )

        # 3) Plot actor & critic loss (if we have them)
        if self.actor_losses and self.critic_losses:
            self._save_line_plot(
                [(self.actor_losses, "Actor Loss"),
                 (self.critic_losses, "Critic Loss")],
                'Actor & Critic Loss', 'Training Update (not per step)', 'Loss',
                os.path.join(save_dir, 'actor_critic_loss.png'),
            )

        # (Optional) Save summary stats
        with open(os.path.join(save_dir, 'training_summary.txt'), 'w') as f:
            f.write(f"Total Episodes: {len(self.episode_rewards)}\n")
            if len(self.episode_rewards) > 0:
                f.write(f"Overall Average Reward: {np.mean(self.episode_rewards):.3f}\n")
                f.write(f"Overall Average Gamma: {np.mean(self.episode_gammas):.3f}\n")
                f.write(f"Best Episode Reward: {max(self.episode_rewards):.3f}\n")
                f.write(f"Worst Episode Reward: {min(self.episode_rewards):.3f}\n")

class DynamicArbitrationEnv(gym.Env):
    """Shared-control environment where the agent picks an arbitration weight.

    A dot moves on a ``FULL_VIEW_SIZE`` field towards a target. Each step the
    agent outputs ``gamma`` in [0, 0.4]; the dot moves ``MAX_SPEED`` along the
    normalised blend ``gamma * perfect_dir + (1 - gamma) * human_input``,
    where ``human_input`` is the true direction to the target corrupted by
    Gaussian noise. The target is replaced when it is reached or after
    ``change_target_interval`` steps. Episodes are truncated after
    ``max_steps`` steps and never terminate otherwise.

    Observation (9 float32 values):
        dot x, dot y, human_input x/y, target x, target y, perfect_dir x/y,
        distance to target divided by the field diagonal.
    """

    def __init__(self, render_mode=None):
        """Create the environment.

        Args:
            render_mode: ``"human"`` to draw every step in a pygame window
                (created lazily on the first frame); ``None`` for no rendering.
        """
        super().__init__()
        self.render_mode = render_mode
        # Field diagonal, used to normalise distances.
        self.max_dist = np.sqrt(FULL_VIEW_SIZE[0]**2 + FULL_VIEW_SIZE[1]**2)

        # Target change parameters
        self.change_target_interval = 50  # Change target every N steps
        self.steps_since_target_change = 0
        self.target_pos = None

        # NOTE: these two attributes are not read by step(), which derives its
        # own thresholds from the goal radius; they are kept for compatibility.
        self.VERY_CLOSE_THRESHOLD = 0.1
        self.CLOSE_THRESHOLD = 0.2

        self.observation_space = spaces.Box(
            low=np.array([0, 0, -1, -1, 0, 0, -1, -1, 0], dtype=np.float32),
            high=np.array([
                FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
                1, 1,
                FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
                1, 1,
                1
            ], dtype=np.float32),
            dtype=np.float32
        )

        self.action_space = spaces.Box(
            low=np.array([0.0], dtype=np.float32),
            high=np.array([0.4], dtype=np.float32),
            dtype=np.float32
        )

        self.dot_pos = None
        self.step_count = 0
        self.max_steps = 300

        # Human input contained in the most recent observation; step() applies
        # the action to exactly this input.
        self._human_input = None

        # Rendering state; the pygame window is only created when first needed.
        self.screen = None
        self.font = None
        self.last_render_time = 0.0

    def _generate_target(self):
        """Generate a new target position, away from current position"""
        while True:
            # Use the env's own seeded generator so reset(seed=...) is honoured.
            x = self.np_random.uniform(100, FULL_VIEW_SIZE[0] - 100)
            y = self.np_random.uniform(100, FULL_VIEW_SIZE[1] - 100)
            new_pos = np.array([x, y], dtype=np.float32)

            # Reject candidates within a quarter of the field width of the
            # previous target.
            if self.target_pos is not None:
                dist_to_old = np.linalg.norm(new_pos - self.target_pos)
                if dist_to_old < FULL_VIEW_SIZE[0] / 4:
                    continue

            # Reject candidates within twice the goal radius of the dot.
            if self.dot_pos is not None:
                dist_to_dot = np.linalg.norm(new_pos - self.dot_pos)
                if dist_to_dot < GOAL_DETECTION_RADIUS * 2:
                    continue

            return new_pos

    def _target_geometry(self):
        """Return (unit direction to the target, distance / field diagonal)."""
        to_target = self.target_pos - self.dot_pos
        dist = np.linalg.norm(to_target)
        if dist > 0:
            perfect_dir = to_target / dist
        else:
            perfect_dir = np.zeros(2, dtype=np.float32)
        return perfect_dir, dist / self.max_dist

    def _sample_human_input(self, perfect_dir):
        """Return ``perfect_dir`` plus Gaussian noise, renormalised to unit length."""
        noise = self.np_random.normal(0, NOISE_MAGNITUDE, size=2)
        human_input = perfect_dir + noise
        norm = np.linalg.norm(human_input)
        if norm > 0:
            human_input = human_input / norm
        return human_input

    def _get_obs(self):
        """Build the observation and remember the sampled human input."""
        perfect_dir, normalized_dist = self._target_geometry()
        self._human_input = self._sample_human_input(perfect_dir)

        return np.concatenate([
            self.dot_pos,
            self._human_input,
            self.target_pos,
            perfect_dir,
            [normalized_dist]
        ]).astype(np.float32)  # match observation_space.dtype

    def reset(self, seed=None, options=None):
        """Put the dot at ``START_POS``, draw a fresh target and reset counters.

        Returns:
            ``(observation, info)`` with an empty info dict.
        """
        super().reset(seed=seed)
        self.dot_pos = np.array(START_POS, dtype=np.float32)
        self.target_pos = self._generate_target()
        self.step_count = 0
        self.steps_since_target_change = 0
        return self._get_obs(), {}

    def step(self, action):
        """Advance one step using ``action[0]`` as the arbitration weight.

        Returns:
            ``(observation, reward, terminated, truncated, info)``;
            ``terminated`` is always False and ``truncated`` becomes True once
            ``max_steps`` steps have been taken.
        """
        self.step_count += 1
        self.steps_since_target_change += 1

        # Get gamma from action
        gamma = float(np.clip(np.asarray(action).reshape(-1)[0], 0.0, 0.4))

        # State as seen by the agent when it chose this action. The human
        # input is the one from the last observation, not a fresh sample, so
        # the action is applied to the input the policy actually saw.
        perfect_dir, normalized_dist = self._target_geometry()
        human_input = self._human_input
        if human_input is None:
            human_input = self._sample_human_input(perfect_dir)

        # Blend the two directions and move at constant speed.
        combined_dir = gamma * perfect_dir + (1 - gamma) * human_input
        combined_norm = np.linalg.norm(combined_dir)
        if combined_norm > 0:
            combined_dir = combined_dir / combined_norm

        new_pos = self.dot_pos + combined_dir * MAX_SPEED
        self.dot_pos = np.clip(
            new_pos, [0, 0], [FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1]]
        ).astype(np.float32)

        # Distance calculations (after the move)
        dist_to_target = np.linalg.norm(self.target_pos - self.dot_pos)
        goal_radius_norm = GOAL_DETECTION_RADIUS / self.max_dist

        # Distance thresholds relative to the goal radius; compared against
        # the normalised distance measured BEFORE the move.
        very_close_threshold = goal_radius_norm * 2.0
        close_threshold = goal_radius_norm * 4.0

        # Gamma targeting with smoother transitions
        if normalized_dist < very_close_threshold:
            # Very close: target gamma rises from 0.2 to 0.4 as distance -> 0.
            progress_ratio = 1.0 - (normalized_dist / very_close_threshold)
            target_gamma = 0.2 + (0.2 * progress_ratio)
            gamma_reward = -25.0 * ((gamma - target_gamma) ** 2)

        elif normalized_dist < close_threshold:
            # Transition zone: target gamma rises from 0.2 to 0.3.
            progress_ratio = 1.0 - (
                (normalized_dist - very_close_threshold)
                / (close_threshold - very_close_threshold)
            )
            target_gamma = 0.2 + (0.1 * progress_ratio)
            gamma_reward = -15.0 * ((gamma - target_gamma) ** 2)

        else:
            # Far from target: penalise overshooting 0.2 more than undershooting.
            target_gamma = 0.2
            if gamma < target_gamma:
                gamma_reward = -5.0 * (gamma - target_gamma) ** 2
            else:
                gamma_reward = -20.0 * (gamma - target_gamma) ** 2

        # Small alignment reward (scaled by 0.1 here and again by 0.1 below).
        human_alignment = np.dot(combined_dir, human_input)
        alignment_reward = 0.1 * human_alignment

        # Goal and target change logic
        terminated = False
        truncated = False
        goal_reward = 0.0

        if dist_to_target < GOAL_DETECTION_RADIUS:
            # Reached the current target: bonus depends on gamma when the dot
            # was already very close before this move.
            if normalized_dist < very_close_threshold:
                if gamma > 0.35:
                    goal_reward = 50.0  # Big bonus for high gamma at goal
                elif gamma > 0.3:
                    goal_reward = 20.0  # Medium bonus for moderately high gamma
                else:
                    goal_reward = 5.0   # Small reward otherwise
            else:
                goal_reward = 5.0

            self.target_pos = self._generate_target()
            self.steps_since_target_change = 0

        elif self.steps_since_target_change >= self.change_target_interval:
            # Target timed out: replace it; small bonus for a low gamma.
            self.target_pos = self._generate_target()
            self.steps_since_target_change = 0
            if gamma <= 0.25:
                goal_reward = 0.5

        if self.step_count >= self.max_steps:
            truncated = True

        # Combined reward
        reward = float(
            5.0 * gamma_reward
            + 0.1 * alignment_reward
            + goal_reward
        )

        if self.render_mode == "human":
            self.render(human_input, perfect_dir, combined_dir, gamma, reward)

        return self._get_obs(), reward, terminated, truncated, {}

    def _init_display(self):
        """Create the pygame window and font used by render()."""
        pygame.init()
        self.screen = pygame.display.set_mode(FULL_VIEW_SIZE)
        self.font = pygame.font.Font(None, 24)

    def render(self, human_input, perfect_dir, combined_dir, gamma, reward):
        """Draw the current state with pygame, at most ~60 frames per second.

        Args:
            human_input: Noisy human direction (drawn in blue).
            perfect_dir: Exact direction to the target (drawn in green).
            combined_dir: Blended direction actually followed (drawn in red).
            gamma: Arbitration weight shown in the overlay.
            reward: Step reward shown in the overlay.
        """
        # Frames arriving faster than 1/60 s after the last one are skipped.
        current_time = time.time()
        if current_time - self.last_render_time < 1/60:
            return
        self.last_render_time = current_time

        if self.screen is None:
            self._init_display()
        # Let pygame process window events so the window stays responsive.
        pygame.event.pump()

        self.screen.fill(WHITE)

        dot_center = (int(self.dot_pos[0]), int(self.dot_pos[1]))

        pygame.draw.circle(self.screen, YELLOW,
                           (int(self.target_pos[0]), int(self.target_pos[1])),
                           TARGET_RADIUS)

        pygame.draw.circle(self.screen, BLACK, dot_center, DOT_RADIUS, 2)

        arrow_length = 50

        # Perfect direction (green), human direction (blue), combined (red).
        for direction, color in (
            (perfect_dir, GREEN),
            (human_input, BLUE),
            (combined_dir, RED),
        ):
            end_pos = (int(self.dot_pos[0] + direction[0] * arrow_length),
                       int(self.dot_pos[1] + direction[1] * arrow_length))
            pygame.draw.line(self.screen, color, dot_center, end_pos, 2)

        texts = [
            f"Step: {self.step_count}",
            f"Gamma: {gamma:.2f}",
            f"Reward: {reward:.2f}",
            f"Steps until target change: {self.change_target_interval - self.steps_since_target_change}"
        ]

        for i, text in enumerate(texts):
            text_surface = self.font.render(text, True, BLACK)
            self.screen.blit(text_surface, (10, 10 + i*25))

        pygame.display.flip()

    def close(self):
        """Shut down the pygame window if one was opened."""
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()
            self.screen = None
            self.font = None

def make_env(rank, seed=0):
    """Return a zero-argument factory that builds one seeded environment.

    Args:
        rank: Index of the environment; added to ``seed`` so every
            environment is reset with a different seed.
        seed: Base seed, also passed to ``set_random_seed`` when the factory
            is created.

    Returns:
        A callable that creates a ``DynamicArbitrationEnv`` without
        rendering, resets it with ``seed + rank`` and returns it.
    """
    set_random_seed(seed)

    def build_env():
        arbitration_env = DynamicArbitrationEnv(render_mode=None)
        arbitration_env.reset(seed=rank + seed)
        return arbitration_env

    return build_env

def train(total_timesteps=25_000_000, n_envs=64, seed=0):
    """Train a PPO agent on ``DynamicArbitrationEnv`` and save model + metrics.

    On normal completion the model is saved as ``dynamic_arbitration_ppo`` in
    the current working directory and metrics go to ``training_metrics``.
    On Ctrl-C the current model is saved as
    ``dynamic_arbitration_ppo_interrupted`` and metrics go to
    ``training_metrics_interrupted``.

    Args:
        total_timesteps: Number of environment steps to train for.
        n_envs: Number of environments stepped in the ``DummyVecEnv``.
        seed: Base seed handed to ``make_env`` (environment ``i`` is reset
            with ``seed + i``).
    """
    # Create directories if they don't exist
    # NOTE: model_save_dir is created but the model itself is written to the
    # current working directory below.
    model_save_dir = "trained_models"
    metrics_dir = "training_metrics"
    os.makedirs(model_save_dir, exist_ok=True)
    os.makedirs(metrics_dir, exist_ok=True)

    env = DummyVecEnv([make_env(i, seed) for i in range(n_envs)])

    # Everything after env creation runs under try/finally so the vec env is
    # closed even if model construction or training raises.
    try:
        policy_kwargs = dict(
            features_extractor_class=OptimizedCustomFeaturesExtractor,
            features_extractor_kwargs=dict(features_dim=128),
            net_arch=dict(
                pi=[256, 128],
                vf=[256, 128]
            )
        )

        n_steps = 256
        model = PPO(
            "MlpPolicy",
            env,
            learning_rate=3e-5,
            n_steps=n_steps,
            # A minibatch cannot be larger than one rollout
            # (n_steps * n_envs); the cap only matters for small n_envs.
            batch_size=min(2048, n_steps * n_envs),
            ent_coef=0.005,
            n_epochs=4,
            gamma=0.99,
            gae_lambda=0.95,
            clip_range=0.2,
            normalize_advantage=True,
            policy_kwargs=policy_kwargs,
            device=device,
            verbose=1
        )

        # Create metrics callback
        metrics_callback = MetricsCallback()

        try:
            print("Starting training...")
            model.learn(
                total_timesteps=total_timesteps,
                callback=metrics_callback
            )

            # Save the model
            model.save("dynamic_arbitration_ppo")
            print("Model saved as dynamic_arbitration_ppo.zip")

            # Save the metrics
            metrics_callback.save_metrics(save_dir=metrics_dir)
            print("Metrics saved to training_metrics directory")

        except KeyboardInterrupt:
            # Keep whatever has been learned so far when the user aborts.
            print("\nTraining interrupted. Saving current state...")
            model.save("dynamic_arbitration_ppo_interrupted")
            metrics_callback.save_metrics(save_dir="training_metrics_interrupted")
            print("Intermediate state saved")

    finally:
        env.close()

# Script entry point: configure multiprocessing, then run training.
if __name__ == "__main__":
    # Force the 'spawn' start method for torch.multiprocessing before training.
    # NOTE(review): train() builds a DummyVecEnv, so this presumably only
    # matters if a SubprocVecEnv is swapped in — confirm.
    mp.set_start_method('spawn', force=True)
    train()


Using device: cuda
Using cuda device
Starting training...


  gamma = float(self.locals['actions'][0])


------------------------------
| time/              |       |
|    fps             | 11083 |
|    iterations      | 1     |
|    time_elapsed    | 1     |
|    total_timesteps | 16384 |
------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 10439        |
|    iterations           | 2            |
|    time_elapsed         | 3            |
|    total_timesteps      | 32768        |
| train/                  |              |
|    approx_kl            | 9.157782e-05 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.42        |
|    explained_variance   | 0.00206      |
|    learning_rate        | 3e-05        |
|    loss                 | 617          |
|    n_updates            | 4            |
|    policy_gradient_loss | 8.04e-05     |
|    std                  | 0.999        |
|    value_loss           | 1.26e+03     |
---------