In [1]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
import pygame
import math
import time

# Constants from original environment
FULL_VIEW_SIZE = (1200, 800)
MAX_SPEED = 3
DOT_RADIUS = 30
TARGET_RADIUS = 10
GOAL_DETECTION_RADIUS = DOT_RADIUS + TARGET_RADIUS
START_POS = [FULL_VIEW_SIZE[0] // 2, FULL_VIEW_SIZE[1] // 2]

# Colors
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
GREEN = (0, 200, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)

# Initialize Pygame
pygame.init()
FONT_SIZE = 24
font = pygame.font.Font(None, FONT_SIZE)

class DynamicArbitrationEnv(gym.Env):
    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        if render_mode == "human":
            self.screen = pygame.display.set_mode(FULL_VIEW_SIZE)
            pygame.display.set_caption("PPO Training Visualization")
        
        self.observation_space = spaces.Box(
            low=np.array([0, 0, -MAX_SPEED, -MAX_SPEED, 
                         -1, -1, 
                         0, 0,
                         -1, -1]),
            high=np.array([FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], MAX_SPEED, MAX_SPEED,
                          1, 1,
                          FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
                          1, 1]),
            dtype=np.float32
        )
        
        self.action_space = spaces.Box(
            low=np.array([0]),
            high=np.array([1]),
            dtype=np.float32
        )
        
        self.dot_pos = None
        self.dot_vel = None
        self.target_pos = None
        self.step_count = 0
        self.max_steps = 500
        self.episode_reward = 0
        self.episode_count = 0
        self.last_render_time = time.time()
        
    def draw_arrow(self, start_pos, direction, color, length=60):
        if direction[0] == 0 and direction[1] == 0:
            return
            
        end_x = start_pos[0] + direction[0] * length
        end_y = start_pos[1] + direction[1] * length
        
        pygame.draw.line(self.screen, color, 
                        (int(start_pos[0]), int(start_pos[1])), 
                        (int(end_x), int(end_y)), 2)
        
        # Arrow head
        angle = math.atan2(direction[1], direction[0])
        arrow_size = 10
        arrow1_x = end_x - arrow_size * math.cos(angle + math.pi/6)
        arrow1_y = end_y - arrow_size * math.sin(angle + math.pi/6)
        arrow2_x = end_x - arrow_size * math.cos(angle - math.pi/6)
        arrow2_y = end_y - arrow_size * math.sin(angle - math.pi/6)
        
        pygame.draw.line(self.screen, color, 
                        (int(end_x), int(end_y)), 
                        (int(arrow1_x), int(arrow1_y)), 2)
        pygame.draw.line(self.screen, color, 
                        (int(end_x), int(end_y)), 
                        (int(arrow2_x), int(arrow2_y)), 2)

    def render(self, human_input, perfect_dir, combined_dir, gamma):
        if self.render_mode != "human":
            return
            
        # Limit rendering to 60 FPS
        current_time = time.time()
        if current_time - self.last_render_time < 1/60:
            return
        self.last_render_time = current_time
        
        self.screen.fill(WHITE)
        
        # Draw target
        pygame.draw.circle(self.screen, YELLOW, 
                         (int(self.target_pos[0]), int(self.target_pos[1])), 
                         TARGET_RADIUS)
        
        # Draw dot
        pygame.draw.circle(self.screen, BLACK, 
                         (int(self.dot_pos[0]), int(self.dot_pos[1])), 
                         DOT_RADIUS, 2)
        
        # Draw arrows
        if np.any(human_input):
            self.draw_arrow((self.dot_pos[0], self.dot_pos[1]), human_input, BLUE)
        if np.any(perfect_dir):
            self.draw_arrow((self.dot_pos[0], self.dot_pos[1]), perfect_dir, GREEN)
        if np.any(combined_dir):
            self.draw_arrow((self.dot_pos[0], self.dot_pos[1]), combined_dir, RED)
        
        # Draw info text
        texts = [
            f"Episode: {self.episode_count}",
            f"Step: {self.step_count}",
            f"Gamma: {gamma:.2f}",
            f"Reward: {self.episode_reward:.1f}"
        ]
        
        for i, text in enumerate(texts):
            text_surface = font.render(text, True, BLACK)
            self.screen.blit(text_surface, (10, 10 + i*25))
        
        pygame.display.flip()

    def reset(self, seed=None):
        super().reset(seed=seed)
        self.dot_pos = np.array(START_POS, dtype=np.float32)
        self.dot_vel = np.array([0.0, 0.0], dtype=np.float32)
        self.target_pos = np.array(self._generate_target(), dtype=np.float32)
        self.step_count = 0
        self.episode_reward = 0
        self.episode_count += 1
        
        return self._get_obs(), {}
    
    def _generate_target(self):
        x = np.random.uniform(100, FULL_VIEW_SIZE[0]-100)
        y = np.random.uniform(100, FULL_VIEW_SIZE[1]-100)
        return [x, y]
    
    def _get_obs(self):
        to_target = self.target_pos - self.dot_pos
        dist = np.linalg.norm(to_target)
        perfect_dir = to_target / dist if dist > 0 else np.array([0, 0])
        
        noise = np.random.normal(0, 0.3, size=2)
        human_input = perfect_dir + noise
        human_input = human_input / np.linalg.norm(human_input) if np.linalg.norm(human_input) > 0 else np.array([0, 0])
        
        return np.concatenate([
            self.dot_pos,
            self.dot_vel,
            human_input,
            self.target_pos,
            perfect_dir
        ])
    
    def step(self, action):
        self.step_count += 1
        gamma = float(action[0])
        
        obs = self._get_obs()
        human_input = obs[4:6]
        perfect_dir = obs[8:10]
        
        combined_dir = gamma * perfect_dir + (1 - gamma) * human_input
        if np.linalg.norm(combined_dir) > 0:
            combined_dir = combined_dir / np.linalg.norm(combined_dir)
        
        self.dot_vel = combined_dir * MAX_SPEED
        new_pos = self.dot_pos + self.dot_vel
        self.dot_pos = np.clip(new_pos, [0, 0], [FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1]])
        
        dist_to_target = np.linalg.norm(self.target_pos - self.dot_pos)
        progress_reward = -dist_to_target / np.sqrt(FULL_VIEW_SIZE[0]**2 + FULL_VIEW_SIZE[1]**2)
        gamma_penalty = -0.1 * (abs(gamma - 0.5) ** 2)
        
        if hasattr(self, 'last_gamma'):
            smoothness_penalty = -0.1 * abs(gamma - self.last_gamma)
        else:
            smoothness_penalty = 0
        self.last_gamma = gamma
        
        reward = progress_reward + gamma_penalty + smoothness_penalty
        self.episode_reward += reward
        
        done = False
        if dist_to_target < GOAL_DETECTION_RADIUS:
            done = True
            reward += 10.0
        elif self.step_count >= self.max_steps:
            done = True
            reward -= 5.0
            
        # Render the current state
        self.render(human_input, perfect_dir, combined_dir, gamma)
            
        return self._get_obs(), reward, done, False, {}

def train_ppo_with_viz():
    # Create environment with visualization
    env = DynamicArbitrationEnv(render_mode="human")
    env = DummyVecEnv([lambda: env])
    
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        verbose=1
    )
    
    try:
        model.learn(total_timesteps=1_000_000)
        model.save("dynamic_arbitration_ppo")
    finally:
        pygame.quit()

if __name__ == "__main__":
    train_ppo_with_viz()

  from pandas.core import (


Using cuda device




-----------------------------
| time/              |      |
|    fps             | 674  |
|    iterations      | 1    |
|    time_elapsed    | 3    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 530          |
|    iterations           | 2            |
|    time_elapsed         | 7            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0044803466 |
|    clip_fraction        | 0.0414       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.42        |
|    explained_variance   | 0.0966       |
|    learning_rate        | 0.0003       |
|    loss                 | 1.58         |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.00328     |
|    std                  | 0.996        |
|    value_loss           | 7.15         |
----------------

KeyboardInterrupt: 