In [1]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.callbacks import BaseCallback
import pygame
import math
import time
import matplotlib.pyplot as plt
import os

# Constants
FULL_VIEW_SIZE = (1200, 800)
MAX_SPEED = 3
DOT_RADIUS = 30
TARGET_RADIUS = 10
GOAL_DETECTION_RADIUS = DOT_RADIUS + TARGET_RADIUS
START_POS = [FULL_VIEW_SIZE[0] // 2, FULL_VIEW_SIZE[1] // 2]
NOISE_MAGNITUDE = 0.5  # Match the game's noise level

# Colors
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
GREEN = (0, 200, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)

class CustomFeaturesExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 64):
        super().__init__(observation_space, features_dim)
        
        n_input = int(np.prod(observation_space.shape))
        
        self.network = nn.Sequential(
            nn.Linear(n_input, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, features_dim),
            nn.ReLU()
        )
        
    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.network(observations)

class MetricsCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_gammas = []
        self.training_losses = []
        self.current_episode_gammas = []
        self.total_reward = 0
        
    def _on_step(self):
        # Get info from current step
        info = self.locals['infos'][0]
        gamma = float(self.locals['actions'][0])
        self.current_episode_gammas.append(gamma)
        self.total_reward += self.locals['rewards'][0]
        
        # Check if episode is done
        if self.locals['dones'][0]:
            self.episode_rewards.append(self.total_reward)
            self.episode_lengths.append(len(self.current_episode_gammas))
            self.episode_gammas.append(np.mean(self.current_episode_gammas))
            
            # Reset episode-specific metrics
            self.current_episode_gammas = []
            self.total_reward = 0
            
            # Store training loss if available
            if hasattr(self.model, 'logger') and 'train/loss' in self.model.logger.name_to_value:
                self.training_losses.append(self.model.logger.name_to_value['train/loss'])
        
        return True

    def plot_metrics(self, save_dir="training_metrics"):
        """Plot and save training metrics."""
        os.makedirs(save_dir, exist_ok=True)
        
        # Plot episode rewards
        plt.figure(figsize=(10, 6))
        plt.plot(self.episode_rewards)
        plt.title('Episode Rewards Over Time')
        plt.xlabel('Episode')
        plt.ylabel('Total Reward')
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, 'episode_rewards.png'))
        plt.close()
        
        # Plot episode lengths
        plt.figure(figsize=(10, 6))
        plt.plot(self.episode_lengths)
        plt.title('Episode Lengths Over Time')
        plt.xlabel('Episode')
        plt.ylabel('Steps')
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, 'episode_lengths.png'))
        plt.close()
        
        # Plot average gamma values
        plt.figure(figsize=(10, 6))
        plt.plot(self.episode_gammas)
        plt.title('Average Gamma Values Per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Average Gamma')
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, 'episode_gammas.png'))
        plt.close()
        
        # Plot training loss if available
        if self.training_losses:
            plt.figure(figsize=(10, 6))
            plt.plot(self.training_losses)
            plt.title('Training Loss Over Time')
            plt.xlabel('Update')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, 'training_loss.png'))
            plt.close()
        
        # Plot moving averages
        window = 100
        if len(self.episode_rewards) > window:
            plt.figure(figsize=(10, 6))
            rewards_ma = np.convolve(self.episode_rewards, 
                                   np.ones(window)/window, 
                                   mode='valid')
            plt.plot(rewards_ma)
            plt.title(f'Moving Average of Rewards (Window={window})')
            plt.xlabel('Episode')
            plt.ylabel('Average Reward')
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, 'rewards_moving_average.png'))
            plt.close()
        
        # Save summary statistics
        with open(os.path.join(save_dir, 'training_summary.txt'), 'w') as f:
            f.write(f"Total Episodes: {len(self.episode_rewards)}\n")
            f.write(f"Average Reward: {np.mean(self.episode_rewards):.2f}\n")
            f.write(f"Average Episode Length: {np.mean(self.episode_lengths):.2f}\n")
            f.write(f"Average Gamma: {np.mean(self.episode_gammas):.2f}\n")
            f.write(f"Final 100 Episodes Average Reward: {np.mean(self.episode_rewards[-100:]):.2f}\n")
            f.write(f"Best Episode Reward: {np.max(self.episode_rewards):.2f}\n")
            f.write(f"Worst Episode Reward: {np.min(self.episode_rewards):.2f}\n")

class DynamicArbitrationEnv(gym.Env):
    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        
        if render_mode == "human":
            pygame.init()
            self.screen = pygame.display.set_mode(FULL_VIEW_SIZE)
            pygame.display.set_caption("Training Visualization")
            self.font = pygame.font.Font(None, 24)
        
        # Observation space: [dot_pos_x, dot_pos_y, human_input_x, human_input_y, 
        #                    target_pos_x, target_pos_y, perfect_dir_x, perfect_dir_y]
        self.observation_space = spaces.Box(
            low=np.array([0, 0, -1, -1, 0, 0, -1, -1]),
            high=np.array([FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1,
                          FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1]),
            dtype=np.float32
        )
        
        # Action space: gamma value between 0 and 1
        self.action_space = spaces.Box(
            low=np.array([0]),
            high=np.array([1]),
            dtype=np.float32
        )
        
        self.dot_pos = None
        self.target_pos = None
        self.step_count = 0
        self.max_steps = 300  # Shorter episodes for faster learning
        self.last_render_time = time.time()

    def _generate_target(self):
        # Simple random target generation
        x = np.random.uniform(100, FULL_VIEW_SIZE[0]-100)
        y = np.random.uniform(100, FULL_VIEW_SIZE[1]-100)
        return np.array([x, y], dtype=np.float32)

    def _get_obs(self):
        # Calculate perfect direction to target
        to_target = self.target_pos - self.dot_pos
        dist = np.linalg.norm(to_target)
        perfect_dir = to_target / dist if dist > 0 else np.array([0, 0])
        
        # Add noise to perfect direction to simulate human input
        noise = np.random.normal(0, NOISE_MAGNITUDE, size=2)
        human_input = perfect_dir + noise
        human_input = human_input / np.linalg.norm(human_input) if np.linalg.norm(human_input) > 0 else np.array([0, 0])
        
        return np.concatenate([
            self.dot_pos,
            human_input,
            self.target_pos,
            perfect_dir
        ])

    def reset(self, seed=None):
        super().reset(seed=seed)
        self.dot_pos = np.array(START_POS, dtype=np.float32)
        self.target_pos = self._generate_target()
        self.step_count = 0
        return self._get_obs(), {}

    def step(self, action):
        self.step_count += 1
        gamma = float(action[0])
        
        obs = self._get_obs()
        human_input = obs[2:4]
        perfect_dir = obs[6:8]
        
        # Combine directions
        combined_dir = gamma * perfect_dir + (1 - gamma) * human_input
        if np.linalg.norm(combined_dir) > 0:
            combined_dir = combined_dir / np.linalg.norm(combined_dir)
        
        # Move
        move_speed = combined_dir * MAX_SPEED
        new_pos = self.dot_pos + move_speed
        self.dot_pos = np.clip(new_pos, [0, 0], [FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1]])
        
        # Calculate reward
        dist_to_target = np.linalg.norm(self.target_pos - self.dot_pos)
        progress_reward = -dist_to_target / np.sqrt(FULL_VIEW_SIZE[0]**2 + FULL_VIEW_SIZE[1]**2)
        
        # Simplified reward structure
        reward = progress_reward
        
        # Check if done
        done = False
        if dist_to_target < GOAL_DETECTION_RADIUS:
            done = True
            reward += 5.0  # Bonus for reaching target
        elif self.step_count >= self.max_steps:
            done = True
        
        # Render if needed
        if self.render_mode == "human":
            self.render(human_input, perfect_dir, combined_dir, gamma, reward)
        
        return self._get_obs(), reward, done, False, {}

    def render(self, human_input, perfect_dir, combined_dir, gamma, reward):
        current_time = time.time()
        if current_time - self.last_render_time < 1/60:
            return
        self.last_render_time = current_time
        
        self.screen.fill(WHITE)
        
        # Draw target
        pygame.draw.circle(self.screen, YELLOW, 
                         (int(self.target_pos[0]), int(self.target_pos[1])), 
                         TARGET_RADIUS)
        
        # Draw dot
        pygame.draw.circle(self.screen, BLACK, 
                         (int(self.dot_pos[0]), int(self.dot_pos[1])), 
                         DOT_RADIUS, 2)
        
        # Draw info text
        texts = [
            f"Step: {self.step_count}",
            f"Gamma: {gamma:.2f}",
            f"Reward: {reward:.2f}"
        ]
        
        for i, text in enumerate(texts):
            text_surface = self.font.render(text, True, BLACK)
            self.screen.blit(text_surface, (10, 10 + i*25))
        
        pygame.display.flip()

def train():
    # Create and wrap the environment
    env = DynamicArbitrationEnv(render_mode="human")
    env = DummyVecEnv([lambda: env])
    
    # Create metrics callback
    metrics_callback = MetricsCallback()
    
    # Policy settings
    policy_kwargs = dict(
        features_extractor_class=CustomFeaturesExtractor,
        features_extractor_kwargs=dict(features_dim=64),
        net_arch=[dict(pi=[64, 64], vf=[64, 64])]
    )
    
    # Create and train the model
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        policy_kwargs=policy_kwargs,
        device='cpu',
        verbose=1
    )
    
    try:
        print("Starting training...")
        model.learn(
            total_timesteps=500_000,
            callback=metrics_callback
        )
        print("Training complete!")
        
        # Plot and save metrics
        print("Generating training metrics plots...")
        metrics_callback.plot_metrics()
        
        # Save the model
        model.save("dynamic_arbitration_ppo")
        print("Model saved!")
        
    except KeyboardInterrupt:
        print("\nTraining interrupted. Saving current model and metrics...")
        model.save("dynamic_arbitration_ppo_interrupted")
        metrics_callback.plot_metrics("training_metrics_interrupted")
    finally:
        pygame.quit()

if __name__ == "__main__":
    train()

Using cpu device




Starting training...


  gamma = float(self.locals['actions'][0])


-----------------------------
| time/              |      |
|    fps             | 2097 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 1003         |
|    iterations           | 2            |
|    time_elapsed         | 4            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0020793537 |
|    clip_fraction        | 0.00479      |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.42        |
|    explained_variance   | -0.0187      |
|    learning_rate        | 0.0003       |
|    loss                 | 0.356        |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.000614    |
|    std                  | 0.993        |
|    value_loss           | 2.57         |
----------------