# In [4]:  (notebook cell prompt, kept as a comment so the file parses as Python)
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.utils import set_random_seed
import torch.multiprocessing as mp
import pygame
import matplotlib.pyplot as plt
import time

# Select the torch device once at import time: CUDA when available, else CPU.
# The chosen device is printed and later handed to PPO in train().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# World / geometry constants. FULL_VIEW_SIZE is also the pygame window size,
# so every position and radius below is expressed in pixels.
FULL_VIEW_SIZE = (1200, 800)
# Distance the dot travels per environment step (applied to a unit direction).
MAX_SPEED = 3
DOT_RADIUS = 30
TARGET_RADIUS = 10
# A target counts as reached when the centre distance drops below this value.
GOAL_DETECTION_RADIUS = DOT_RADIUS + TARGET_RADIUS
# The dot starts every episode at the centre of the view.
START_POS = [FULL_VIEW_SIZE[0] // 2, FULL_VIEW_SIZE[1] // 2]
# Standard deviation of the Gaussian noise added to the ideal direction to
# simulate the human input.
NOISE_MAGNITUDE = 0.5

# RGB colours used by the pygame renderer.
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
GREEN = (0, 200, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)

class SpatialAttention(nn.Module):
    """Element-wise sigmoid gate over a feature vector.

    A small MLP (``input_dim -> 64 -> input_dim``) produces one weight in
    (0, 1) per input feature; the input is rescaled by those weights.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_dim = 64
        # Layer order (and therefore state_dict indices) is part of the
        # checkpoint format, so it is kept exactly as-is.
        gate_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` scaled feature-wise by the learned gate."""
        return x * self.attention(x)

class DistanceAttention(nn.Module):
    """Element-wise sigmoid gate intended for distance-derived features.

    Same structure as a feature-wise gate with a narrower bottleneck:
    ``input_dim -> 32 -> input_dim`` followed by a sigmoid, multiplied back
    onto the input.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_dim = 32
        # Keep the layer order stable: state_dict keys depend on it.
        gate_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` scaled feature-wise by the learned gate."""
        return x * self.attention(x)

class OptimizedCustomFeaturesExtractor(BaseFeaturesExtractor):
    """Feature extractor that encodes positions, directions and distance separately.

    The flat observation is split by column:

    * ``[0:2]`` dot position, ``[4:6]`` target position -> position encoder
    * ``[2:4]`` human input, ``[6:8]`` perfect direction -> direction encoder
    * ``[8:9]`` normalised distance -> distance encoder

    The position and distance embeddings are passed through their attention
    gates, then all three embeddings are concatenated and fused by an MLP
    into a ``features_dim``-wide vector.

    Args:
        observation_space: Box space whose flat width is at least 9.
        features_dim: Width of the returned feature vector.

    Raises:
        ValueError: If the observation space has fewer than 9 entries.
    """

    # Number of observation columns consumed by ``forward``.
    _REQUIRED_OBS_DIM = 9

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)

        obs_dim = int(np.prod(observation_space.shape))
        if obs_dim < self._REQUIRED_OBS_DIM:
            raise ValueError(
                f"Expected an observation with at least {self._REQUIRED_OBS_DIM} "
                f"entries, got shape {observation_space.shape}"
            )

        # Position network (dot and target positions)
        self.position_encoder = nn.Sequential(
            nn.Linear(4, 128),
            nn.ReLU(),
            nn.LayerNorm(128),
            nn.Dropout(0.1)
        )

        # Direction network (human input and perfect direction)
        self.direction_encoder = nn.Sequential(
            nn.Linear(4, 128),
            nn.ReLU(),
            nn.LayerNorm(128),
            nn.Dropout(0.1)
        )

        # Distance network
        self.distance_encoder = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.LayerNorm(64),
            nn.Dropout(0.1)
        )

        # Attention gates; applied in forward() to the position and distance
        # embeddings respectively (widths 128 and 64 match those encoders).
        self.spatial_attention = SpatialAttention(128)
        self.distance_attention = DistanceAttention(64)

        # Fusion MLP over the concatenated embeddings (no residual paths).
        self.final_network = nn.Sequential(
            nn.Linear(128 + 128 + 64, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Dropout(0.1),
            nn.Linear(512, features_dim),
            nn.ReLU(),
            nn.LayerNorm(features_dim)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Orthogonally initialise every Linear layer (gain sqrt(2)), zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                nn.init.constant_(m.bias, 0)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Encode a batch of flat observations into ``features_dim`` features."""
        # Split observation into components
        dot_pos = observations[:, :2]         # [batch_size, 2]
        human_input = observations[:, 2:4]    # [batch_size, 2]
        target_pos = observations[:, 4:6]     # [batch_size, 2]
        perfect_dir = observations[:, 6:8]    # [batch_size, 2]
        distances = observations[:, 8:9]      # [batch_size, 1]

        positions = torch.cat([dot_pos, target_pos], dim=1)        # [batch_size, 4]
        directions = torch.cat([human_input, perfect_dir], dim=1)  # [batch_size, 4]

        # The attention modules were previously constructed but never called,
        # leaving their parameters untrained; gate the embeddings they were
        # sized for.
        pos_features = self.spatial_attention(self.position_encoder(positions))
        dir_features = self.direction_encoder(directions)
        dist_features = self.distance_attention(self.distance_encoder(distances))

        combined = torch.cat([pos_features, dir_features, dist_features], dim=1)

        return self.final_network(combined)

class MetricsCallback(BaseCallback):
    """Collect reward/gamma statistics for the first environment during training.

    Only index 0 of the vectorised ``actions``, ``rewards``, ``dones`` and
    ``new_obs`` entries in ``self.locals`` is tracked. Per episode it stores
    the summed reward and the mean (clipped) gamma; per step it stores a
    ``(normalized_distance, gamma)`` pair, where the distance is read from
    column 8 of ``new_obs``.

    Args:
        verbose: Verbosity forwarded to ``BaseCallback``.
        save_freq: Write a checkpoint of all metrics every ``save_freq``
            completed episodes.
    """

    def __init__(self, verbose=0, save_freq=1000):
        super().__init__(verbose)
        self.save_freq = save_freq
        self.episode_rewards = []
        self.episode_gammas = []
        self.current_episode_gammas = []
        self.total_reward = 0
        self.gamma_dist_pairs = []

        # Additional metrics. actor_losses / critic_losses are kept for
        # compatibility but are not populated by this callback.
        self.actor_losses = []
        self.critic_losses = []
        self.entropy_values = []
        # NOTE: despite the name this holds 'train/value_loss' samples.
        self.value_estimates = []

        # Action bounds, read lazily from the model on the first step.
        self.action_low = None
        self.action_high = None

    @staticmethod
    def _first_scalar(values):
        """Return the first entry of ``values`` as a plain Python float.

        Accepts a tensor, ndarray or sequence whose first entry is either a
        scalar or a one-element array (e.g. actions shaped ``(n_envs, 1)``).
        Going through ``reshape(-1)[0]`` avoids ``float()`` on an array with
        ndim > 0, which NumPy has deprecated.
        """
        first = values[0]
        if torch.is_tensor(first):
            return float(first.reshape(-1)[0].item())
        return float(np.asarray(first).reshape(-1)[0])

    def _on_step(self):
        """Record gamma, distance and reward for env 0; close out finished episodes."""
        if self.action_low is None:
            self.action_low = self.model.action_space.low[0]
            self.action_high = self.model.action_space.high[0]

        actions = self.locals.get('actions')
        if actions is None:
            return True

        # The raw policy output may lie outside the action bounds, so clip it
        # to the Box limits before recording it.
        gamma = float(np.clip(self._first_scalar(actions),
                              self.action_low, self.action_high))

        new_obs = self.locals.get('new_obs')
        if new_obs is not None:
            normalized_dist = float(new_obs[0][8])
            self.gamma_dist_pairs.append((normalized_dist, gamma))

        self.current_episode_gammas.append(gamma)

        rewards = self.locals.get('rewards')
        if rewards is not None:
            self.total_reward += self._first_scalar(rewards)

        # Handle episode completion
        if self.locals.get('dones', [False])[0]:
            self.episode_rewards.append(self.total_reward)
            if self.current_episode_gammas:
                self.episode_gammas.append(np.mean(self.current_episode_gammas))

            # Reset episode trackers
            self.current_episode_gammas = []
            self.total_reward = 0

            # Save intermediate results
            if len(self.episode_rewards) % self.save_freq == 0:
                self.save_metrics(f"metrics_checkpoint_{len(self.episode_rewards)}")

        return True

    def _on_rollout_end(self):
        """Sample the latest train/* logger values once per rollout.

        Sampling here rather than on every environment step keeps one entry
        per policy update, matching the 'Update' axis of the metrics plot,
        instead of repeating the same logged value for every step.
        """
        if not hasattr(self.model, 'logger'):
            return
        logged = self.model.logger.name_to_value
        if 'train/entropy_loss' in logged:
            self.entropy_values.append(logged['train/entropy_loss'])
        if 'train/value_loss' in logged:
            self.value_estimates.append(logged['train/value_loss'])

    def save_metrics(self, save_dir="training_metrics"):
        """Write raw arrays, plots and a text summary into ``save_dir``.

        The directory is created if it does not exist; existing files with
        the same names are overwritten.
        """
        os.makedirs(save_dir, exist_ok=True)

        # Save raw data
        np.save(os.path.join(save_dir, 'rewards.npy'), np.array(self.episode_rewards))
        np.save(os.path.join(save_dir, 'gammas.npy'), np.array(self.episode_gammas))
        np.save(os.path.join(save_dir, 'gamma_dist_pairs.npy'), np.array(self.gamma_dist_pairs))

        # Plot metrics
        self._plot_rewards(save_dir)
        self._plot_gammas(save_dir)
        self._plot_gamma_dist_relationship(save_dir)
        self._plot_additional_metrics(save_dir)

        # Save summary statistics
        self._save_summary(save_dir)

    def _plot_rewards(self, save_dir):
        """Save a line plot of total reward per episode as rewards.png."""
        plt.figure(figsize=(10, 6))
        plt.plot(self.episode_rewards)
        plt.title('Episode Rewards')
        plt.xlabel('Episode')
        plt.ylabel('Total Reward')
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, 'rewards.png'))
        plt.close()

    def _plot_gammas(self, save_dir):
        """Save a line plot of mean gamma per episode as gammas.png."""
        plt.figure(figsize=(10, 6))
        plt.plot(self.episode_gammas)
        plt.title('Average Gamma per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Average Gamma')
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, 'gammas.png'))
        plt.close()

    def _plot_gamma_dist_relationship(self, save_dir):
        """Save a hexbin of gamma against normalised distance; no-op without data."""
        if self.gamma_dist_pairs:
            distances, gammas = zip(*self.gamma_dist_pairs)
            plt.figure(figsize=(10, 6))
            plt.hexbin(distances, gammas, gridsize=30, cmap='viridis')
            plt.colorbar(label='Count')
            plt.title('Gamma vs Distance Distribution')
            plt.xlabel('Normalized Distance')
            plt.ylabel('Gamma')
            plt.savefig(os.path.join(save_dir, 'gamma_dist_distribution.png'))
            plt.close()

    def _plot_additional_metrics(self, save_dir):
        """Save entropy and value-loss curves; no-op until entropy samples exist."""
        if self.entropy_values:
            plt.figure(figsize=(10, 6))
            plt.plot(self.entropy_values, label='Entropy')
            plt.plot(self.value_estimates, label='Value Loss')
            plt.title('Training Metrics')
            plt.xlabel('Update')
            plt.ylabel('Value')
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, 'training_metrics.png'))
            plt.close()

    def _save_summary(self, save_dir):
        """Write episode count and reward/gamma averages to summary.txt."""
        with open(os.path.join(save_dir, 'summary.txt'), 'w') as f:
            f.write(f"Total Episodes: {len(self.episode_rewards)}\n")
            if self.episode_rewards:
                f.write(f"Average Reward: {np.mean(self.episode_rewards):.3f}\n")
                f.write(f"Best Reward: {np.max(self.episode_rewards):.3f}\n")
                f.write(f"Average Gamma: {np.mean(self.episode_gammas):.3f}\n")
                f.write(f"Last 100 Episodes Avg Reward: {np.mean(self.episode_rewards[-100:]):.3f}\n")
                f.write(f"Last 100 Episodes Avg Gamma: {np.mean(self.episode_gammas[-100:]):.3f}\n")

class DynamicArbitrationEnv(gym.Env):
    """2-D shared-control task: the agent picks the blend weight ``gamma``.

    Each step the dot moves ``MAX_SPEED`` along the normalised blend
    ``gamma * perfect_dir + (1 - gamma) * human_input``, where ``perfect_dir``
    points at the target and ``human_input`` is that direction plus Gaussian
    noise. The target is replaced when reached or after
    ``change_target_interval`` steps. Episodes never terminate; they are
    truncated after ``max_steps`` steps.

    Observation (float32, 9 values):
        ``[dot_x, dot_y, human_x, human_y, target_x, target_y,
        perfect_x, perfect_y, normalized_distance]``

    Action (float32, 1 value): ``gamma`` in ``[0.0, 0.4]``.

    All randomness is drawn from ``self.np_random`` so that
    ``reset(seed=...)`` controls the episode.

    Args:
        render_mode: ``"human"`` opens a pygame window and draws every step;
            ``None`` disables rendering.
    """

    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        # Diagonal of the view: the largest possible dot-target distance.
        self.max_dist = np.sqrt(FULL_VIEW_SIZE[0]**2 + FULL_VIEW_SIZE[1]**2)

        self.goal_radius_norm = GOAL_DETECTION_RADIUS / self.max_dist

        # Target change parameters
        self.change_target_interval = 50
        self.steps_since_target_change = 0
        self.target_pos = None

        # Distance thresholds (in normalised units) used by the gamma reward.
        self.VERY_CLOSE_THRESHOLD = self.goal_radius_norm * 2.0  # Within 2x goal radius
        self.CLOSE_THRESHOLD = self.goal_radius_norm * 4.0       # Within 4x goal radius
        self.MID_THRESHOLD = self.goal_radius_norm * 8.0         # Within 8x goal radius

        self.observation_space = spaces.Box(
            low=np.array([0, 0, -1, -1, 0, 0, -1, -1, 0], dtype=np.float32),
            high=np.array([
                FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
                1, 1,
                FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
                1, 1,
                1
            ], dtype=np.float32),
            dtype=np.float32
        )

        self.action_space = spaces.Box(
            low=np.array([0.0], dtype=np.float32),
            high=np.array([0.4], dtype=np.float32),
            dtype=np.float32
        )

        self.dot_pos = None
        self.step_count = 0
        self.max_steps = 300
        # Observation most recently handed to the agent; step() acts on it.
        self._last_obs = None

    def _generate_target(self):
        """Sample a new target at least 100 px from every edge.

        Candidates are rejected while they are closer than a quarter of the
        view width to the previous target, or closer than twice the goal
        radius to the dot.
        """
        while True:
            x = self.np_random.uniform(100, FULL_VIEW_SIZE[0] - 100)
            y = self.np_random.uniform(100, FULL_VIEW_SIZE[1] - 100)
            new_pos = np.array([x, y], dtype=np.float32)

            if self.target_pos is not None:
                dist_to_old = np.linalg.norm(new_pos - self.target_pos)
                if dist_to_old < FULL_VIEW_SIZE[0] / 4:
                    continue

            if self.dot_pos is not None:
                dist_to_dot = np.linalg.norm(new_pos - self.dot_pos)
                if dist_to_dot < GOAL_DETECTION_RADIUS * 2:
                    continue

            return new_pos

    def _get_obs(self):
        """Build a fresh float32 observation, sampling new human-input noise."""
        to_target = self.target_pos - self.dot_pos
        dist = np.linalg.norm(to_target)
        perfect_dir = to_target / dist if dist > 0 else np.zeros(2)
        normalized_dist = dist / self.max_dist

        noise = self.np_random.normal(0, NOISE_MAGNITUDE, size=2)
        human_input = perfect_dir + noise
        human_norm = np.linalg.norm(human_input)
        if human_norm > 0:
            human_input = human_input / human_norm

        # Cast once at the boundary so the result matches observation_space.
        return np.concatenate([
            self.dot_pos,
            human_input,
            self.target_pos,
            perfect_dir,
            [normalized_dist]
        ]).astype(np.float32)

    def reset(self, seed=None, options=None):
        """Put the dot at the centre, draw a target and return ``(obs, info)``."""
        super().reset(seed=seed)
        self.dot_pos = np.array(START_POS, dtype=np.float32)
        self.target_pos = self._generate_target()
        self.step_count = 0
        self.steps_since_target_change = 0
        self._last_obs = self._get_obs()
        return self._last_obs, {}

    def _gamma_reward(self, gamma, normalized_dist):
        """Penalty (<= 0) for deviating from the distance-dependent target gamma."""
        very_close = self.VERY_CLOSE_THRESHOLD
        close = self.CLOSE_THRESHOLD

        if normalized_dist < very_close:
            # Target rises from 0.2 to 0.4 as the dot closes in; the penalty
            # weight fades with the same ratio.
            progress_ratio = 1.0 - (normalized_dist / very_close)
            target_gamma = 0.2 + (0.2 * progress_ratio)
            return -25.0 * ((gamma - target_gamma) ** 2) * (1 - progress_ratio)

        if normalized_dist < close:
            # Transition zone: target rises from 0.2 to 0.3.
            progress_ratio = 1.0 - ((normalized_dist - very_close) /
                                    (close - very_close))
            target_gamma = 0.2 + (0.1 * progress_ratio)
            return -15.0 * ((gamma - target_gamma) ** 2)

        # Far from target: overshooting 0.2 is penalised 4x harder.
        target_gamma = 0.2
        weight = -5.0 if gamma < target_gamma else -20.0
        return weight * (gamma - target_gamma) ** 2

    def step(self, action):
        """Advance one step using ``action[0]`` as gamma.

        Returns:
            ``(obs, reward, terminated, truncated, info)``; ``terminated`` is
            always False and ``truncated`` becomes True at ``max_steps``.
        """
        self.step_count += 1
        self.steps_since_target_change += 1

        gamma = float(np.clip(action[0],
                              float(self.action_space.low[0]),
                              float(self.action_space.high[0])))

        # Act on the observation the agent actually saw. Re-sampling here
        # would blend a different noisy human input than the one the policy
        # conditioned its gamma on.
        obs = self._last_obs if self._last_obs is not None else self._get_obs()
        human_input = obs[2:4]
        perfect_dir = obs[6:8]
        normalized_dist = obs[8]

        # Movement logic
        combined_dir = gamma * perfect_dir + (1 - gamma) * human_input
        combined_norm = np.linalg.norm(combined_dir)
        if combined_norm > 0:
            combined_dir = combined_dir / combined_norm

        new_pos = self.dot_pos + combined_dir * MAX_SPEED
        self.dot_pos = np.clip(
            new_pos, [0, 0], [FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1]]
        ).astype(np.float32)

        dist_to_target = np.linalg.norm(self.target_pos - self.dot_pos)

        gamma_reward = self._gamma_reward(gamma, normalized_dist)

        # Small alignment reward
        human_alignment = np.dot(combined_dir, human_input)
        alignment_reward = 0.1 * human_alignment

        terminated = False
        truncated = False
        goal_reward = 0.0

        if dist_to_target < GOAL_DETECTION_RADIUS:
            # normalized_dist is the pre-move distance from the observation.
            if normalized_dist < self.VERY_CLOSE_THRESHOLD:
                if gamma > 0.35:
                    goal_reward = 50.0  # Big bonus for high gamma at goal
                elif gamma > 0.3:
                    goal_reward = 20.0  # Medium bonus for moderately high gamma
                else:
                    goal_reward = 5.0   # Small reward otherwise
            else:
                goal_reward = 5.0

            self.target_pos = self._generate_target()
            self.steps_since_target_change = 0

        elif self.steps_since_target_change >= self.change_target_interval:
            self.target_pos = self._generate_target()
            self.steps_since_target_change = 0
            if gamma <= 0.25:  # Reward low gamma during target changes
                goal_reward = 0.5

        if self.step_count >= self.max_steps:
            truncated = True

        reward = float(5.0 * gamma_reward +      # Main gamma control
                       0.1 * alignment_reward +  # Minor alignment influence
                       goal_reward)              # Goal achievement

        if self.render_mode == "human":
            self.render(human_input, perfect_dir, combined_dir, gamma, reward)

        self._last_obs = self._get_obs()
        return self._last_obs, reward, terminated, truncated, {}

    def render(self, human_input, perfect_dir, combined_dir, gamma, reward):
        """Draw the current state in a pygame window, capped at about 60 FPS.

        Green, blue and red lines show the perfect, human and blended
        directions. The window is created lazily on the first call.
        """
        if not hasattr(self, 'screen'):
            pygame.init()
            self.screen = pygame.display.set_mode(FULL_VIEW_SIZE)
            pygame.display.set_caption("Dynamic Arbitration Training")
            self.font = pygame.font.Font(None, 24)
            self.last_render_time = time.time()

        # Let pygame process window-system events so the window stays responsive.
        pygame.event.pump()

        current_time = time.time()
        if current_time - self.last_render_time < 1/60:
            return
        self.last_render_time = current_time

        self.screen.fill(WHITE)

        dot_center = (int(self.dot_pos[0]), int(self.dot_pos[1]))

        # Draw target
        pygame.draw.circle(self.screen, YELLOW,
                           (int(self.target_pos[0]), int(self.target_pos[1])),
                           TARGET_RADIUS)

        # Draw dot
        pygame.draw.circle(self.screen, BLACK, dot_center, DOT_RADIUS, 2)

        arrow_length = 50

        # Draw direction arrows (zero vectors are skipped)
        for direction, color in ((perfect_dir, GREEN),
                                 (human_input, BLUE),
                                 (combined_dir, RED)):
            if not np.all(direction == 0):
                end_pos = (int(self.dot_pos[0] + direction[0] * arrow_length),
                           int(self.dot_pos[1] + direction[1] * arrow_length))
                pygame.draw.line(self.screen, color, dot_center, end_pos, 2)

        # Draw text information
        texts = [
            f"Step: {self.step_count}",
            f"Gamma: {gamma:.2f}",
            f"Reward: {reward:.2f}",
            f"Steps until target change: {self.change_target_interval - self.steps_since_target_change}"
        ]

        for i, text in enumerate(texts):
            text_surface = self.font.render(text, True, BLACK)
            self.screen.blit(text_surface, (10, 10 + i*25))

        pygame.display.flip()

    def close(self):
        """Shut down the pygame window if one was opened."""
        if hasattr(self, 'screen'):
            pygame.display.quit()
            pygame.quit()
            del self.screen

def make_env(rank, seed=0):
    """Return a zero-argument factory that builds one seeded environment.

    ``set_random_seed(seed)`` is applied when the factory is created; the
    environment itself is reset with ``seed + rank`` when the factory is
    called, so each rank gets a distinct seed.
    """
    set_random_seed(seed)

    def _init():
        instance = DynamicArbitrationEnv(render_mode=None)
        instance.reset(seed=seed + rank)
        return instance

    return _init

def train(total_timesteps=500_000, n_envs=64):
    """Train PPO on ``DynamicArbitrationEnv`` and save the model and metrics.

    On normal completion the model is written to
    ``dynamic_arbitration_ppo_attention`` and metrics to
    ``training_metrics``. On Ctrl-C the current model and metrics are saved
    under the ``*_interrupted_attention`` names instead. The vectorised
    environment is always closed.

    Args:
        total_timesteps: Total environment steps passed to ``model.learn``.
        n_envs: Number of in-process environment copies.
    """
    n_steps = 512
    # Minibatches cannot usefully exceed one rollout (n_steps * n_envs);
    # with the defaults this stays at 2048.
    batch_size = min(2048, n_steps * n_envs)

    env = DummyVecEnv([make_env(i) for i in range(n_envs)])

    try:
        policy_kwargs = dict(
            features_extractor_class=OptimizedCustomFeaturesExtractor,
            features_extractor_kwargs=dict(features_dim=256),
            net_arch=dict(
                pi=[256, 256, 128],
                vf=[256, 256, 128]
            )
        )

        model = PPO(
            "MlpPolicy",
            env,
            learning_rate=2e-5,
            n_steps=n_steps,
            batch_size=batch_size,
            ent_coef=0.01,
            n_epochs=6,
            gamma=0.99,
            gae_lambda=0.95,
            clip_range=0.2,
            normalize_advantage=True,
            policy_kwargs=policy_kwargs,
            device=device,
            verbose=1
        )

        metrics_callback = MetricsCallback()

        try:
            print("Starting training...")
            model.learn(
                total_timesteps=total_timesteps,
                callback=metrics_callback
            )

            model.save("dynamic_arbitration_ppo_attention")
            print("Model saved as dynamic_arbitration_ppo_attention.zip")

            metrics_callback.save_metrics()
            print("Metrics saved to training_metrics directory")

        except KeyboardInterrupt:
            print("\nTraining interrupted. Saving current state...")
            model.save("dynamic_arbitration_ppo_interrupted_attention")
            metrics_callback.save_metrics(save_dir="training_metrics_interrupted_attention")
            print("Intermediate state saved")

    finally:
        # Runs even if model construction fails, so the envs are not leaked.
        env.close()

# Script entry point: runs training only when the file is executed directly.
if __name__ == "__main__":
    # Force the 'spawn' start method before anything creates processes.
    # train() builds a DummyVecEnv, so this presumably only matters if a
    # subprocess-based vec env is swapped in — TODO confirm it is still needed.
    mp.set_start_method('spawn', force=True)
    train()

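# --- Captured notebook output (not code; kept as comments) ---
# NOTE: this output appears to come from an earlier revision of the callback:
# the code above reads 'new_obs' rather than 'observations' and no longer
# contains the quoted source line.
#
# Using device: cuda
# Using cuda device
# Starting training...
#
#   gamma = float(actions[0].item() if torch.is_tensor(actions) else actions[0])
#
# KeyError: 'observations'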