In [1]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.utils import set_random_seed
import torch.multiprocessing as mp

import pygame
import matplotlib.pyplot as plt
import time

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Constants
FULL_VIEW_SIZE = (1200, 800)
MAX_SPEED = 3
DOT_RADIUS = 30
TARGET_RADIUS = 10
GOAL_DETECTION_RADIUS = DOT_RADIUS + TARGET_RADIUS
START_POS = [FULL_VIEW_SIZE[0] // 2, FULL_VIEW_SIZE[1] // 2]
NOISE_MAGNITUDE = 0.5

# Colors
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
GREEN = (0, 200, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)

class MetricsCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []
        self.episode_gammas = []
        self.current_episode_gammas = []
        self.total_reward = 0
        self.actor_losses = []
        self.critic_losses = []
        self.action_low = None
        self.action_high = None

    def _on_step(self):
        if self.action_low is None:
            self.action_low = self.model.action_space.low[0]
            self.action_high = self.model.action_space.high[0]

        gamma = float(self.locals['actions'][0])
        gamma = np.clip(gamma, self.action_low, self.action_high)
        self.current_episode_gammas.append(gamma)

        reward = float(self.locals['rewards'][0])
        self.total_reward += reward

        if hasattr(self.model, 'logger'):
            if 'train/policy_gradient_loss' in self.model.logger.name_to_value:
                self.actor_losses.append(
                    self.model.logger.name_to_value['train/policy_gradient_loss']
                )
            if 'train/value_loss' in self.model.logger.name_to_value:
                self.critic_losses.append(
                    self.model.logger.name_to_value['train/value_loss']
                )
        
        if self.locals['dones'][0]:
            self.episode_rewards.append(self.total_reward)
            self.episode_gammas.append(np.mean(self.current_episode_gammas))
            self.current_episode_gammas = []
            self.total_reward = 0
        
        return True

    def save_metrics(self, save_dir="training_metrics"):
        os.makedirs(save_dir, exist_ok=True)
        
        plt.figure(figsize=(8, 5))
        plt.plot(self.episode_rewards, label="Episode Reward")
        plt.title('Average Reward per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.grid(True)
        plt.legend()
        plt.savefig(os.path.join(save_dir, 'average_reward.png'))
        plt.close()
        
        plt.figure(figsize=(8, 5))
        plt.plot(self.episode_gammas, label="Average Gamma")
        plt.title('Average Gamma per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Gamma')
        plt.grid(True)
        plt.legend()
        plt.savefig(os.path.join(save_dir, 'average_gamma.png'))
        plt.close()
        
        if self.actor_losses and self.critic_losses:
            plt.figure(figsize=(8, 5))
            plt.plot(self.actor_losses, label="Actor Loss")
            plt.plot(self.critic_losses, label="Critic Loss")
            plt.title('Actor & Critic Loss')
            plt.xlabel('Training Update')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.legend()
            plt.savefig(os.path.join(save_dir, 'actor_critic_loss.png'))
            plt.close()

        with open(os.path.join(save_dir, 'training_summary.txt'), 'w') as f:
            f.write(f"Total Episodes: {len(self.episode_rewards)}\n")
            if len(self.episode_rewards) > 0:
                f.write(f"Overall Average Reward: {np.mean(self.episode_rewards):.3f}\n")
                f.write(f"Overall Average Gamma: {np.mean(self.episode_gammas):.3f}\n")
                f.write(f"Best Episode Reward: {max(self.episode_rewards):.3f}\n")
                f.write(f"Worst Episode Reward: {min(self.episode_rewards):.3f}\n")

class DynamicArbitrationEnv(gym.Env):
    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        self.max_dist = np.sqrt(FULL_VIEW_SIZE[0]**2 + FULL_VIEW_SIZE[1]**2)
        if render_mode == "human":
            pygame.init()
            self.screen = pygame.display.set_mode(FULL_VIEW_SIZE)
            pygame.display.set_caption("Training Visualization")
            self.font = pygame.font.Font(None, 24)
        
        self.observation_space = spaces.Box(
            low=np.array([0, 0, -1, -1, 0, 0, -1, -1, 0]),
            high=np.array([
                FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
                1, 1,
                FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
                1, 1, 1
            ], dtype=np.float32),
            dtype=np.float32
        )
        
        self.action_space = spaces.Box(
            low=np.array([0.0]),
            high=np.array([0.4]),
            dtype=np.float32
        )
        
        self.dot_pos = None
        self.target_pos = None
        self.step_count = 0
        self.max_steps = 300
        self.last_render_time = time.time()
        self.steps_since_target_change = 0
        self.change_target_interval = 50

    def _generate_target(self):
        x = np.random.uniform(100, FULL_VIEW_SIZE[0] - 100)
        y = np.random.uniform(100, FULL_VIEW_SIZE[1] - 100)
        return np.array([x, y], dtype=np.float32)

    def _get_obs(self):
        to_target = self.target_pos - self.dot_pos
        dist = np.linalg.norm(to_target)
        perfect_dir = to_target / dist if dist > 0 else np.array([0, 0])
        normalized_dist = dist / self.max_dist
        
        noise = np.random.normal(0, NOISE_MAGNITUDE, size=2)
        human_input = perfect_dir + noise
        if np.linalg.norm(human_input) > 0:
            human_input = human_input / np.linalg.norm(human_input)

        return np.concatenate([
            self.dot_pos,
            human_input,
            self.target_pos,
            perfect_dir,
            [normalized_dist]
        ])

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.dot_pos = np.array(START_POS, dtype=np.float32)
        self.target_pos = self._generate_target()
        self.step_count = 0
        self.steps_since_target_change = 0
        return self._get_obs(), {}

    def step(self, action):
        self.step_count += 1
        self.steps_since_target_change += 1
        
        gamma = float(np.clip(action[0], 0.0, 0.4))
        
        obs = self._get_obs()
        human_input = obs[2:4]
        perfect_dir = obs[6:8]
        
        combined_dir = gamma * perfect_dir + (1 - gamma) * human_input
        if np.linalg.norm(combined_dir) > 0:
            combined_dir = combined_dir / np.linalg.norm(combined_dir)
        
        move_speed = combined_dir * MAX_SPEED
        new_pos = self.dot_pos + move_speed
        self.dot_pos = np.clip(new_pos, [0, 0], [FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1]])
        
        dist_to_target = np.linalg.norm(self.target_pos - self.dot_pos)
        max_dist = np.sqrt(FULL_VIEW_SIZE[0]**2 + FULL_VIEW_SIZE[1]**2)
        
        progress_reward = -dist_to_target / max_dist
        
        distance_fraction = dist_to_target / max_dist
        if distance_fraction <= 0.1:
            ideal_gamma = 0.4
        else:
            ideal_gamma = 0.1
            
        gamma_penalty = -(gamma - ideal_gamma)**2
        human_alignment = np.dot(combined_dir, human_input)
        human_follow_reward = 0.5 * human_alignment * distance_fraction
        
        reward = 0.05 * progress_reward + 1.0 * gamma_penalty + human_follow_reward
        
        if self.steps_since_target_change >= self.change_target_interval:
            self.target_pos = self._generate_target()
            self.steps_since_target_change = 0
            reward += 0.5
        
        terminated = False
        truncated = False
        
        if dist_to_target < GOAL_DETECTION_RADIUS:
            reward += 50.0
            terminated = True
        elif self.step_count >= self.max_steps:
            truncated = True
        
        if self.render_mode == "human":
            self.render(human_input, perfect_dir, combined_dir, gamma, reward)
        
        return self._get_obs(), reward, terminated, truncated, {}

    def render(self, human_input, perfect_dir, combined_dir, gamma, reward):
        current_time = time.time()
        if current_time - self.last_render_time < 1/60:
            return
        self.last_render_time = current_time
        
        self.screen.fill(WHITE)
        
        pygame.draw.circle(self.screen, YELLOW, 
                         (int(self.target_pos[0]), int(self.target_pos[1])), 
                         TARGET_RADIUS)
        
        pygame.draw.circle(self.screen, BLACK, 
                         (int(self.dot_pos[0]), int(self.dot_pos[1])), 
                         DOT_RADIUS, 2)
        
        arrow_length = 50
        
        end_pos = (int(self.dot_pos[0] + perfect_dir[0] * arrow_length),
                   int(self.dot_pos[1] + perfect_dir[1] * arrow_length))
        pygame.draw.line(self.screen, GREEN, 
                       (int(self.dot_pos[0]), int(self.dot_pos[1])), 
                       end_pos, 2)
                        
        end_pos = (int(self.dot_pos[0] + human_input[0] * arrow_length),
                   int(self.dot_pos[1] + human_input[1] * arrow_length))
        pygame.draw.line(self.screen, BLUE, 
                       (int(self.dot_pos[0]), int(self.dot_pos[1])), 
                       end_pos, 2)
                        
        end_pos = (int(self.dot_pos[0] + combined_dir[0] * arrow_length),
                   int(self.dot_pos[1] + combined_dir[1] * arrow_length))
        pygame.draw.line(self.screen, RED, 
                       (int(self.dot_pos[0]), int(self.dot_pos[1])), 
                       end_pos, 2)
        
        texts = [
            f"Step: {self.step_count}",
            f"Gamma: {gamma:.2f}",
            f"Reward: {reward:.2f}",
            f"Steps until target change: {self.change_target_interval - self.steps_since_target_change}"
        ]
        
        for i, text in enumerate(texts):
            text_surface = self.font.render(text, True, BLACK)
            self.screen.blit(text_surface, (10, 10 + i*25))
        
        pygame.display.flip()

def make_env(rank, seed=0):
    def _init():
        env = DynamicArbitrationEnv(render_mode=None)
        env.reset(seed=seed + rank)
        return env
    set_random_seed(seed)
    return _init

def train():
    n_envs = 8
    env = DummyVecEnv([make_env(i) for i in range(n_envs)])
    
    policy_kwargs = dict(
        net_arch=dict(pi=[256, 128], vf=[256, 128])
    )
    
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=2.5e-4,
        n_steps=256,
        batch_size=128,
        ent_coef=0.001,
        n_epochs=4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        normalize_advantage=True,
        policy_kwargs=policy_kwargs,
        device=device,
        verbose=1
    )
    
    metrics_callback = MetricsCallback()
    
    try:
        print("Starting training...")
        model.learn(
            total_timesteps=2000000,
            callback=metrics_callback
        )
        
        print("Training complete! Saving model and metrics...")
        model.save("ppo_dynamic_arbitration_simple")
        metrics_callback.save_metrics()
        print("Model and metrics saved!")
        
    except KeyboardInterrupt:
        print("\nTraining interrupted. Saving current model and metrics...")
        model.save("ppo_dynamic_arbitration_simple_interrupted")
        metrics_callback.save_metrics(save_dir="training_metrics_interrupted")
    
    finally:
        env.close()
        pygame.quit()

if __name__ == "__main__":
    mp.set_start_method('spawn', force=True)
    train()

Using device: cuda
Using cuda device
Starting training...


  gamma = float(self.locals['actions'][0])


-----------------------------
| time/              |      |
|    fps             | 3933 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 3197        |
|    iterations           | 2           |
|    time_elapsed         | 1           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.004762459 |
|    clip_fraction        | 0.0116      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.42       |
|    explained_variance   | -0.00156    |
|    learning_rate        | 0.00025     |
|    loss                 | 7.28        |
|    n_updates            | 4           |
|    policy_gradient_loss | -0.000762   |
|    std                  | 0.995       |
|    value_loss           | 34.9        |
----------------------------------