In [2]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import pygame
import math
import time
from collections import deque
import json
import argparse

# Constants
FULL_VIEW_SIZE = (1200, 800)
MAX_SPEED = 3
DOT_RADIUS = 30
TARGET_RADIUS = 10
GOAL_DETECTION_RADIUS = DOT_RADIUS + TARGET_RADIUS
START_POS = [FULL_VIEW_SIZE[0] // 2, FULL_VIEW_SIZE[1] // 2]

# Colors
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
GREEN = (0, 200, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)

# Curriculum levels
CURRICULUM_LEVELS = {
    0: {"radius": 200, "noise": 0.1, "obstacles": 0},  # Easy: Close targets
    1: {"radius": 400, "noise": 0.2, "obstacles": 0},  # Medium
    2: {"radius": None, "noise": 0.3, "obstacles": 0}, # Hard
    3: {"radius": None, "noise": 0.3, "obstacles": 3}  # Expert
}

# Initialize Pygame
pygame.init()
FONT_SIZE = 24
font = pygame.font.Font(None, FONT_SIZE)

class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        assert input_dim % num_heads == 0, "input_dim must be divisible by num_heads"
        self.head_dim = input_dim // num_heads
        
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.proj = nn.Linear(input_dim, input_dim)
        
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        # Linear projections and reshape
        q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        
        # Transpose for attention computation
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        
        # Apply attention to values
        context = torch.matmul(attn, v)
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, -1)
        
        return self.proj(context)



class CustomFeaturesExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 64):
        super().__init__(observation_space, features_dim)
        
        n_input = int(np.prod(observation_space.shape))
        
        # Simple feed-forward network
        self.network = nn.Sequential(
            nn.Linear(n_input, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, features_dim),
            nn.ReLU()
        )
        
    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.network(observations)

class CurriculumCallback(BaseCallback):
    def __init__(self, env, verbose=0):
        super().__init__(verbose)
        self.env = env
        self.success_rate = deque(maxlen=100)
        self.current_level = 0
        
    def _on_step(self):
        if self.env.episode_done:
            self.success_rate.append(1 if self.env.episode_success else 0)
            
            if len(self.success_rate) == self.success_rate.maxlen:
                success_rate = sum(self.success_rate) / len(self.success_rate)
                if success_rate > 0.8 and self.current_level < len(CURRICULUM_LEVELS) - 1:
                    self.current_level += 1
                    self.env.set_curriculum_level(self.current_level)
                    print(f"Advancing to curriculum level {self.current_level}")
        return True

class DynamicArbitrationEnv(gym.Env):
    def __init__(self, render_mode=None, record_demonstrations=False):
        super().__init__()
        self.render_mode = render_mode
        self.record_demonstrations = record_demonstrations
        self.demonstrations = []
        self.current_episode_states = []
        self.curriculum_level = 0
        self.episode_success = False
        self.episode_done = False
        
        if render_mode == "human":
            self.screen = pygame.display.set_mode(FULL_VIEW_SIZE)
            pygame.display.set_caption("PPO Training Visualization")
        
        # Modified observation space to remove velocity
        self.observation_space = spaces.Box(
            low=np.array([0, 0,           # dot position
                         -1, -1,          # human input
                         0, 0,            # target position
                         -1, -1]),        # perfect direction
            high=np.array([FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
                          1, 1,
                          FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
                          1, 1]),
            dtype=np.float32
        )
        
        self.action_space = spaces.Box(
            low=np.array([0]),
            high=np.array([1]),
            dtype=np.float32
        )
        
        self.dot_pos = None
        self.target_pos = None
        self.step_count = 0
        self.max_steps = 500
        self.episode_reward = 0
        self.episode_count = 0
        self.last_render_time = time.time()

    def draw_arrow(self, start_pos, direction, color, length=60):
        if direction[0] == 0 and direction[1] == 0:
            return
            
        end_x = start_pos[0] + direction[0] * length
        end_y = start_pos[1] + direction[1] * length
        
        pygame.draw.line(self.screen, color, 
                        (int(start_pos[0]), int(start_pos[1])), 
                        (int(end_x), int(end_y)), 2)
        
        angle = math.atan2(direction[1], direction[0])
        arrow_size = 10
        arrow1_x = end_x - arrow_size * math.cos(angle + math.pi/6)
        arrow1_y = end_y - arrow_size * math.sin(angle + math.pi/6)
        arrow2_x = end_x - arrow_size * math.cos(angle - math.pi/6)
        arrow2_y = end_y - arrow_size * math.sin(angle - math.pi/6)
        
        pygame.draw.line(self.screen, color, 
                        (int(end_x), int(end_y)), 
                        (int(arrow1_x), int(arrow1_y)), 2)
        pygame.draw.line(self.screen, color, 
                        (int(end_x), int(end_y)), 
                        (int(arrow2_x), int(arrow2_y)), 2)

    def render(self, human_input, perfect_dir, combined_dir, gamma):
        if self.render_mode != "human":
            return
            
        current_time = time.time()
        if current_time - self.last_render_time < 1/60:
            return
        self.last_render_time = current_time
        
        self.screen.fill(WHITE)
        
        # Draw target
        pygame.draw.circle(self.screen, YELLOW, 
                         (int(self.target_pos[0]), int(self.target_pos[1])), 
                         TARGET_RADIUS)
        
        # Draw dot
        pygame.draw.circle(self.screen, BLACK, 
                         (int(self.dot_pos[0]), int(self.dot_pos[1])), 
                         DOT_RADIUS, 2)
        
        # Draw arrows
        if np.any(human_input):
            self.draw_arrow((self.dot_pos[0], self.dot_pos[1]), human_input, BLUE)
        if np.any(perfect_dir):
            self.draw_arrow((self.dot_pos[0], self.dot_pos[1]), perfect_dir, GREEN)
        if np.any(combined_dir):
            self.draw_arrow((self.dot_pos[0], self.dot_pos[1]), combined_dir, RED)
        
        # Draw info text
        texts = [
            f"Episode: {self.episode_count}",
            f"Step: {self.step_count}",
            f"Gamma: {gamma:.2f}",
            f"Reward: {self.episode_reward:.1f}",
            f"Curriculum Level: {self.curriculum_level}"
        ]
        
        for i, text in enumerate(texts):
            text_surface = font.render(text, True, BLACK)
            self.screen.blit(text_surface, (10, 10 + i*25))
        
        pygame.display.flip()

    def set_curriculum_level(self, level):
        self.curriculum_level = level

    def _generate_target(self):
        curr_params = CURRICULUM_LEVELS[self.curriculum_level]
        if curr_params["radius"] is None:
            x = np.random.uniform(100, FULL_VIEW_SIZE[0]-100)
            y = np.random.uniform(100, FULL_VIEW_SIZE[1]-100)
        else:
            angle = np.random.uniform(0, 2*np.pi)
            r = np.random.uniform(0, curr_params["radius"])
            x = START_POS[0] + r * np.cos(angle)
            y = START_POS[1] + r * np.sin(angle)
        return [np.clip(x, 100, FULL_VIEW_SIZE[0]-100),
                np.clip(y, 100, FULL_VIEW_SIZE[1]-100)]

    def reset(self, seed=None):
        super().reset(seed=seed)
        self.dot_pos = np.array(START_POS, dtype=np.float32)
        self.target_pos = np.array(self._generate_target(), dtype=np.float32)
        self.step_count = 0
        self.episode_reward = 0
        self.episode_count += 1
        
        return self._get_obs(), {}
    
    def _get_obs(self):
        to_target = self.target_pos - self.dot_pos
        dist = np.linalg.norm(to_target)
        perfect_dir = to_target / dist if dist > 0 else np.array([0, 0])
        
        curr_noise = CURRICULUM_LEVELS[self.curriculum_level]["noise"]
        noise = np.random.normal(0, curr_noise, size=2)
        human_input = perfect_dir + noise
        human_input = human_input / np.linalg.norm(human_input) if np.linalg.norm(human_input) > 0 else np.array([0, 0])
        
        return np.concatenate([
            self.dot_pos,
            human_input,
            self.target_pos,
            perfect_dir
        ])

    def step(self, action):
        self.step_count += 1
        gamma = float(action[0])
        
        obs = self._get_obs()
        human_input = obs[2:4]  # Updated indices since we removed velocity
        perfect_dir = obs[6:8]  # Updated indices
        
        if self.record_demonstrations:
            self.current_episode_states.append({
                'state': obs.tolist(),
                'action': gamma,
                'human_input': human_input.tolist(),
                'perfect_dir': perfect_dir.tolist()
            })
        
        combined_dir = gamma * perfect_dir + (1 - gamma) * human_input
        if np.linalg.norm(combined_dir) > 0:
            combined_dir = combined_dir / np.linalg.norm(combined_dir)
        
        # Move directly using combined direction (no velocity tracking)
        move_speed = combined_dir * MAX_SPEED
        new_pos = self.dot_pos + move_speed
        self.dot_pos = np.clip(new_pos, [0, 0], [FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1]])
        
        dist_to_target = np.linalg.norm(self.target_pos - self.dot_pos)
        progress_reward = -dist_to_target / np.sqrt(FULL_VIEW_SIZE[0]**2 + FULL_VIEW_SIZE[1]**2)
        gamma_penalty = -0.1 * (abs(gamma - 0.5) ** 2)
        
        if hasattr(self, 'last_gamma'):
            smoothness_penalty = -0.1 * abs(gamma - self.last_gamma)
        else:
            smoothness_penalty = 0
        self.last_gamma = gamma
        
        entropy_bonus = -0.01 * (gamma * np.log(gamma + 1e-10) + (1-gamma) * np.log(1-gamma + 1e-10))
        
        reward = progress_reward + gamma_penalty + smoothness_penalty + entropy_bonus
        self.episode_reward += reward
        
        done = False
        self.episode_success = False
        if dist_to_target < GOAL_DETECTION_RADIUS:
            done = True
            self.episode_success = True
            reward += 10.0
        elif self.step_count >= self.max_steps:
            done = True
            reward -= 5.0
            
        self.episode_done = done
        
        if done and self.record_demonstrations:
            self.demonstrations.append({
                'states': self.current_episode_states,
                'success': self.episode_success,
                'total_reward': self.episode_reward
            })
            self.current_episode_states = []
        
        self.render(human_input, perfect_dir, combined_dir, gamma)
        return self._get_obs(), reward, done, False, {}

    def save_demonstrations(self, filename):
        if self.demonstrations:
            with open(filename, 'w') as f:
                json.dump(self.demonstrations, f)
            print(f"Saved {len(self.demonstrations)} demonstrations to {filename}")

def train_ppo_with_viz(demonstration_file=None):
    # Create environment with visualization
    env = DynamicArbitrationEnv(render_mode="human", record_demonstrations=False)
    env = DummyVecEnv([lambda: env])
    
    # Policy kwargs with custom feature extractor
    policy_kwargs = dict(
        features_extractor_class=CustomFeaturesExtractor,
        features_extractor_kwargs=dict(features_dim=64),
        net_arch=[dict(pi=[64, 64], vf=[64, 64])]  # Separate networks for policy and value function
    )
    
    # Initialize PPO with custom settings
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        policy_kwargs=policy_kwargs,
        device='cpu',  # Force CPU usage
        verbose=1
    )
    
    # Add curriculum learning callback
    curriculum_callback = CurriculumCallback(env.envs[0])
    
    try:
        # Train the model
        model.learn(
            total_timesteps=1_000_000,
            callback=curriculum_callback
        )
        
        # Save the model and demonstrations
        model.save("dynamic_arbitration_ppo")
        env.envs[0].save_demonstrations("training_demonstrations.json")
        
    except KeyboardInterrupt:
        print("\nTraining interrupted. Saving current model...")
        model.save("dynamic_arbitration_ppo_interrupted")
        env.envs[0].save_demonstrations("training_demonstrations_interrupted.json")
    finally:
        pygame.quit()

def load_and_run_model():
    """Function to load and run a trained model"""
    env = DynamicArbitrationEnv(render_mode="human")
    env = DummyVecEnv([lambda: env])
    
    try:
        model = PPO.load("dynamic_arbitration_ppo")
        
        obs, _ = env.reset()
        while True:
            action, _states = model.predict(obs)
            obs, rewards, dones, _, infos = env.step(action)
            
            if dones:
                obs, _ = env.reset()
                
    except KeyboardInterrupt:
        print("\nVisualization stopped by user")
    finally:
        pygame.quit()

def analyze_model_performance():
    """Function to analyze model performance"""
    env = DynamicArbitrationEnv(render_mode=None)
    env = DummyVecEnv([lambda: env])
    model = PPO.load("dynamic_arbitration_ppo")
    
    n_episodes = 100
    rewards = []
    success_rate = 0
    gamma_values = []
    
    for episode in range(n_episodes):
        obs, _ = env.reset()
        episode_reward = 0
        episode_gammas = []
        done = False
        
        while not done:
            action, _states = model.predict(obs)
            obs, reward, done, _, info = env.step(action)
            episode_reward += reward
            episode_gammas.append(action[0])
        
        rewards.append(episode_reward)
        gamma_values.extend(episode_gammas)
        if episode_reward > 0:  # Consider it successful if reward is positive
            success_rate += 1
    
    print(f"Average reward: {np.mean(rewards):.2f}")
    print(f"Success rate: {success_rate/n_episodes*100:.2f}%")
    print(f"Average gamma: {np.mean(gamma_values):.2f}")
    print(f"Gamma std: {np.std(gamma_values):.2f}")

def main(mode='train'):
    """Main function to run the program in different modes"""
    try:
        if mode == 'train':
            train_ppo_with_viz()
        elif mode == 'run':
            load_and_run_model()
        elif mode == 'analyze':
            analyze_model_performance()
    except Exception as e:
        print(f"Error: {e}")
    finally:
        pygame.quit()

if __name__ == "__main__":
    train_ppo_with_viz()

Using cpu device




-----------------------------
| time/              |      |
|    fps             | 1640 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
Advancing to curriculum level 1
Advancing to curriculum level 2
Advancing to curriculum level 3
-----------------------------------------
| time/                   |             |
|    fps                  | 891         |
|    iterations           | 2           |
|    time_elapsed         | 4           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.003775102 |
|    clip_fraction        | 0.0157      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.41       |
|    explained_variance   | -0.0389     |
|    learning_rate        | 0.0003      |
|    loss                 | 2.48        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.00211    |
|    std              