In [2]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.utils import set_random_seed
import pygame
import matplotlib.pyplot as plt
import math
import time
from collections import deque

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Constants
FULL_VIEW_SIZE = (1200, 800)
MAX_SPEED = 3
DOT_RADIUS = 30
TARGET_RADIUS = 10
GOAL_DETECTION_RADIUS = DOT_RADIUS + TARGET_RADIUS
START_POS = [FULL_VIEW_SIZE[0] // 2, FULL_VIEW_SIZE[1] // 2]
NOISE_MAGNITUDE = 0.5

# Colors
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
GREEN = (0, 200, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)

class OptimizedCustomFeaturesExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 128):
        super().__init__(observation_space, features_dim)
        
        n_input = int(np.prod(observation_space.shape))
        
        self.network = nn.Sequential(
            nn.Linear(n_input, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.1),  # Add some dropout for regularization
            nn.Linear(256, features_dim),
            nn.ReLU(),
            nn.BatchNorm1d(features_dim)
        )
        
        # Initialize weights using orthogonal initialization
        for m in self.network.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                nn.init.constant_(m.bias, 0)
                
    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.network(observations)

class MetricsCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_gammas = []
        self.training_losses = []
        self.current_episode_gammas = []
        self.total_reward = 0
        
    def _on_step(self):
        # Get info from current step
        if torch.is_tensor(self.locals['actions']):
            gamma = self.locals['actions'][0].item()
        else:
            gamma = float(self.locals['actions'][0])
            
        self.current_episode_gammas.append(gamma)
        
        # Get reward
        if torch.is_tensor(self.locals['rewards']):
            reward = self.locals['rewards'][0].item()
        else:
            reward = float(self.locals['rewards'][0])
        self.total_reward += reward
        
        # Check if episode is done
        if self.locals['dones'][0]:
            self.episode_rewards.append(self.total_reward)
            self.episode_lengths.append(len(self.current_episode_gammas))
            self.episode_gammas.append(np.mean(self.current_episode_gammas))
            
            # Reset episode-specific metrics
            self.current_episode_gammas = []
            self.total_reward = 0
            
            # Store training loss if available
            if hasattr(self.model, 'logger') and 'train/loss' in self.model.logger.name_to_value:
                self.training_losses.append(self.model.logger.name_to_value['train/loss'])
        
        return True

    def save_metrics(self, save_dir="training_metrics"):
        os.makedirs(save_dir, exist_ok=True)
        
        # Plot episode rewards
        plt.figure(figsize=(10, 6))
        plt.plot(self.episode_rewards)
        plt.title('Episode Rewards Over Time')
        plt.xlabel('Episode')
        plt.ylabel('Total Reward')
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, 'episode_rewards.png'))
        plt.close()
        
        # Plot episode lengths
        plt.figure(figsize=(10, 6))
        plt.plot(self.episode_lengths)
        plt.title('Episode Lengths Over Time')
        plt.xlabel('Episode')
        plt.ylabel('Steps')
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, 'episode_lengths.png'))
        plt.close()
        
        # Plot average gamma values
        plt.figure(figsize=(10, 6))
        plt.plot(self.episode_gammas)
        plt.title('Average Gamma Values Per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Average Gamma')
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, 'episode_gammas.png'))
        plt.close()
        
        # Plot training loss if available
        if self.training_losses:
            plt.figure(figsize=(10, 6))
            plt.plot(self.training_losses)
            plt.title('Training Loss Over Time')
            plt.xlabel('Update')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, 'training_loss.png'))
            plt.close()

        # Save summary statistics
        with open(os.path.join(save_dir, 'training_summary.txt'), 'w') as f:
            f.write(f"Total Episodes: {len(self.episode_rewards)}\n")
            f.write(f"Average Reward: {np.mean(self.episode_rewards):.2f}\n")
            f.write(f"Average Episode Length: {np.mean(self.episode_lengths):.2f}\n")
            f.write(f"Average Gamma: {np.mean(self.episode_gammas):.2f}\n")
            if len(self.episode_rewards) >= 100:
                f.write(f"Final 100 Episodes Average Reward: {np.mean(self.episode_rewards[-100:]):.2f}\n")
            f.write(f"Best Episode Reward: {max(self.episode_rewards):.2f}\n")
            f.write(f"Worst Episode Reward: {min(self.episode_rewards):.2f}\n")

class DynamicArbitrationEnv(gym.Env):
    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        
        if render_mode == "human":
            pygame.init()
            self.screen = pygame.display.set_mode(FULL_VIEW_SIZE)
            pygame.display.set_caption("Training Visualization")
            self.font = pygame.font.Font(None, 24)
        
        # Observation space: [dot_pos_x, dot_pos_y, human_input_x, human_input_y, 
        #                    target_pos_x, target_pos_y, perfect_dir_x, perfect_dir_y]
        self.observation_space = spaces.Box(
            low=np.array([0, 0, -1, -1, 0, 0, -1, -1]),
            high=np.array([FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1,
                          FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1]),
            dtype=np.float32
        )
        
        # Action space: gamma value between 0 and 1
        self.action_space = spaces.Box(
            low=np.array([0]),
            high=np.array([1]),
            dtype=np.float32
        )
        
        self.dot_pos = None
        self.target_pos = None
        self.step_count = 0
        self.max_steps = 300
        self.last_render_time = time.time()
        self.steps_since_target_change = 0
        self.change_target_interval = 50  # Change target every N steps
        
        # Add tracking for human input consistency
        self.recent_human_inputs = deque(maxlen=10)  # Track last 10 steps
        self.last_positions = deque(maxlen=5)  # Track last 5 positions
        
    def _generate_target(self):
        """Generate a new target position with minimum distance from current position"""
        while True:
            x = np.random.uniform(100, FULL_VIEW_SIZE[0]-100)
            y = np.random.uniform(100, FULL_VIEW_SIZE[1]-100)
            new_pos = np.array([x, y], dtype=np.float32)
            
            if self.dot_pos is not None:
                min_distance = np.sqrt(FULL_VIEW_SIZE[0]**2 + FULL_VIEW_SIZE[1]**2) * 0.3
                if np.linalg.norm(new_pos - self.dot_pos) >= min_distance:
                    return new_pos
            else:
                return new_pos

    def _calculate_human_consistency(self, current_input):
        """Calculate how consistently the human is moving toward a direction."""
        self.recent_human_inputs.append(current_input)
        
        if len(self.recent_human_inputs) < 3:  # Need minimum history
            return 0.0
            
        # Calculate average direction
        recent_inputs = np.array(self.recent_human_inputs)
        avg_direction = np.mean(recent_inputs, axis=0)
        avg_magnitude = np.linalg.norm(avg_direction)
        
        if avg_magnitude < 0.1:  # No consistent movement
            return 0.0
            
        # Calculate consistency as dot product with current direction
        consistencies = []
        for inp in list(self.recent_human_inputs)[-3:]:
            inp_norm = np.linalg.norm(inp)
            if inp_norm > 0.1:  # Only consider significant movements
                consistency = np.dot(inp/inp_norm, avg_direction/avg_magnitude)
                consistencies.append(consistency)
        
        return np.mean(consistencies) if consistencies else 0.0

    def _calculate_progress_rate(self):
        """Calculate the rate of progress toward the target"""
        if len(self.last_positions) < 2:
            return 0.0
        
        # Calculate the change in distance to target over recent positions
        old_pos = self.last_positions[0]
        new_pos = self.last_positions[-1]
        old_dist = np.linalg.norm(self.target_pos - old_pos)
        new_dist = np.linalg.norm(self.target_pos - new_pos)
        
        return (old_dist - new_dist) / old_dist if old_dist > 0 else 0.0

    def _get_obs(self):
        to_target = self.target_pos - self.dot_pos
        dist = np.linalg.norm(to_target)
        perfect_dir = to_target / dist if dist > 0 else np.array([0, 0])
        
        # Simulate human input with noise and some bias toward target
        noise = np.random.normal(0, NOISE_MAGNITUDE, size=2)
        # Mix perfect direction with noise
        target_bias = 0.3
        human_input = (1 - target_bias) * noise + target_bias * perfect_dir
        # Normalize
        human_input_mag = np.linalg.norm(human_input)
        human_input = human_input / human_input_mag if human_input_mag > 0 else np.array([0, 0])
        
        return np.concatenate([
            self.dot_pos,
            human_input,
            self.target_pos,
            perfect_dir
        ])

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.dot_pos = np.array(START_POS, dtype=np.float32)
        self.target_pos = self._generate_target()
        self.step_count = 0
        self.steps_since_target_change = 0
        self.recent_human_inputs.clear()
        self.last_positions.clear()
        self.last_positions.append(self.dot_pos.copy())
        return self._get_obs(), {}

    def step(self, action):
        self.step_count += 1
        self.steps_since_target_change += 1
        
        gamma = float(np.clip(action[0], 0, 1))
        obs = self._get_obs()
        human_input = obs[2:4]
        perfect_dir = obs[6:8]
        
        # Calculate movement and distances
        dist_to_target = np.linalg.norm(self.target_pos - self.dot_pos)
        normalized_dist = dist_to_target / np.sqrt(FULL_VIEW_SIZE[0]**2 + FULL_VIEW_SIZE[1]**2)
        
        # Calculate movement
        combined_dir = gamma * perfect_dir + (1 - gamma) * human_input
        if np.linalg.norm(combined_dir) > 0:
            combined_dir = combined_dir / np.linalg.norm(combined_dir)
        
        # Move dot
        move_speed = combined_dir * MAX_SPEED
        new_pos = self.dot_pos + move_speed
        self.dot_pos = np.clip(new_pos, [0, 0], [FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1]])
        
        # Update position history
        self.last_positions.append(self.dot_pos.copy())
        progress_rate = self._calculate_progress_rate()
        
        # Reward Components:
        
        # 1. Base progress reward - increased weight on progress
        progress_reward = -normalized_dist * 0.8 + progress_rate * 3.0
        
        # 2. Distance-based gamma targeting
        # Calculate target gamma based on distance to goal
        proximity_factor = 1.0 - normalized_dist
        base_target_gamma = 0.0  # Start from zero
        
        # Increase target gamma as we get closer
        if proximity_factor > 0.6:  # When we're within 40% of max distance
            # Smoothly increase from 0.3 to 0.5 based on proximity
            base_target_gamma = 0.3 + (proximity_factor - 0.6) * 0.5
        
        # Sharp penalty for high gamma when far from target
        far_penalty = -2.0 * gamma if normalized_dist > 0.8 else 0.0
        
        # Stronger push toward zero right after target change
        if self.steps_since_target_change < 5:
            target_gamma = 0.0
        else:
            target_gamma = base_target_gamma
            
        gamma_reward = -3.0 * (gamma - target_gamma)**2 + far_penalty
        
        # 3. Progress-based gamma bonus
        # Reward higher gamma when making good progress
        if progress_rate > 0:
            gamma_reward += progress_rate * gamma
        
        # 4. Movement alignment reward
        # Reward alignment between movement and target direction
        movement_alignment = np.dot(combined_dir, perfect_dir)
        alignment_reward = movement_alignment * proximity_factor * 2.0
        
        # Combine rewards with adjusted weights
        reward = (
            0.35 * progress_reward +    # Increased weight on progress
            0.35 * gamma_reward +       # Increased weight on gamma targeting
            0.2 * alignment_reward +    # New movement alignment component
            0.1 * (-abs(gamma_change_penalty) if 'gamma_change_penalty' in locals() else 0.0)  # Smoothness
        )
        
        # Additional reward shaping
        if dist_to_target < GOAL_DETECTION_RADIUS:
            terminated = True
            # Base completion bonus
            reward += 15.0
            # Extra reward for controlled approach
            if 0.3 <= gamma <= 0.5:
                reward += 5.0
        else:
            terminated = False
        
        truncated = self.step_count >= self.max_steps
        
        if self.render_mode == "human":
            self.render(human_input, perfect_dir, combined_dir, gamma, reward)
        
        return self._get_obs(), reward, terminated, truncated, {
            'gamma': gamma,
            'target_gamma': target_gamma,
            'progress_rate': progress_rate,
            'normalized_dist': normalized_dist
        }

    def render(self, human_input, perfect_dir, combined_dir, gamma, reward):
        current_time = time.time()
        if current_time - self.last_render_time < 1/60:
            return
        self.last_render_time = current_time
        
        self.screen.fill(WHITE)
        
        # Draw target
        pygame.draw.circle(self.screen, YELLOW, 
                         (int(self.target_pos[0]), int(self.target_pos[1])), 
                         TARGET_RADIUS)
        
        # Draw dot
        pygame.draw.circle(self.screen, BLACK, 
                         (int(self.dot_pos[0]), int(self.dot_pos[1])), 
                         DOT_RADIUS, 2)
        
        arrow_length = 50
        
        # Draw directions
        for direction, color in [(perfect_dir, GREEN), 
                               (human_input, BLUE), 
                               (combined_dir, RED)]:
            if np.any(direction):
                end_pos = (int(self.dot_pos[0] + direction[0] * arrow_length),
                          int(self.dot_pos[1] + direction[1] * arrow_length))
                pygame.draw.line(self.screen, color, 
                               (int(self.dot_pos[0]), int(self.dot_pos[1])), 
                               end_pos, 2)
        
        # Draw info text
        texts = [
            f"Step: {self.step_count}",
            f"Gamma: {gamma:.2f}",
            f"Reward: {reward:.2f}",
            f"Steps until target change: {self.change_target_interval - self.steps_since_target_change}",
            f"Human Consistency: {self._calculate_human_consistency(human_input):.2f}"
        ]
        
        for i, text in enumerate(texts):
            text_surface = self.font.render(text, True, BLACK)
            self.screen.blit(text_surface, (10, 10 + i*25))
        
        pygame.display.flip()

def make_env(rank, seed=0):
    def _init():
        env = DynamicArbitrationEnv(render_mode=None)
        env.reset(seed=(seed + rank))
        return env
    set_random_seed(seed)
    return _init

def train():
    n_envs = 8  # Number of parallel environments
    env = SubprocVecEnv([make_env(i) for i in range(n_envs)])
    
    policy_kwargs = dict(
        features_extractor_class=OptimizedCustomFeaturesExtractor,
        features_extractor_kwargs=dict(features_dim=128),
        net_arch=dict(
            pi=[256, 128],  # Policy network architecture
            vf=[256, 128]   # Value function network architecture
        )
    )
    
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=2.5e-4,
        n_steps=256,
        batch_size=128,
        n_epochs=4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        normalize_advantage=True,
        policy_kwargs=policy_kwargs,
        device=device,
        verbose=1
    )
    
    # Create metrics callback
    metrics_callback = MetricsCallback()
    
    try:
        print("Starting training...")
        model.learn(
            total_timesteps=1_000_000,
            callback=metrics_callback
        )
        
        print("Training complete! Saving model and metrics...")
        model.save("dynamic_arbitration_ppo_optimized")
        metrics_callback.save_metrics()
        print("Model and metrics saved!")
        
    except KeyboardInterrupt:
        print("\nTraining interrupted. Saving current model and metrics...")
        model.save("dynamic_arbitration_ppo_optimized_interrupted")
        metrics_callback.save_metrics("training_metrics_interrupted")
    
    finally:
        env.close()
        pygame.quit()

if __name__ == "__main__":
    train()

Using device: cuda
Using cuda device
Starting training...


  gamma = float(self.locals['actions'][0])


-----------------------------
| time/              |      |
|    fps             | 2590 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 2188         |
|    iterations           | 2            |
|    time_elapsed         | 1            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0013810615 |
|    clip_fraction        | 0.00134      |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.41        |
|    explained_variance   | 0.0301       |
|    learning_rate        | 0.00025      |
|    loss                 | 0.841        |
|    n_updates            | 4            |
|    policy_gradient_loss | -0.00122     |
|    std                  | 0.987        |
|    value_loss           | 5.6          |
----------------