In [5]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.utils import set_random_seed
import torch.multiprocessing as mp

import matplotlib.pyplot as plt
import time

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

###############################################################################
# Constants & Hyperparameters
###############################################################################
FULL_VIEW_SIZE = (1200, 800)
MAX_SPEED = 3
DOT_RADIUS = 30
TARGET_RADIUS = 10
GOAL_DETECTION_RADIUS = DOT_RADIUS + TARGET_RADIUS
START_POS = [FULL_VIEW_SIZE[0] // 2, FULL_VIEW_SIZE[1] // 2]

NOISE_MAGNITUDE = 0.5  
NUM_GOALS = 3  

###############################################################################
# Custom Features Extractor
###############################################################################
class OptimizedCustomFeaturesExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 128):
        super().__init__(observation_space, features_dim)
        n_input = int(np.prod(observation_space.shape))
        
        self.network = nn.Sequential(
            nn.Linear(n_input, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, features_dim),
            nn.ReLU(),
            nn.BatchNorm1d(features_dim)
        )
        
        # Orthogonal init
        for m in self.network.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                nn.init.constant_(m.bias, 0)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.network(observations)

###############################################################################
# Metrics Callback
###############################################################################
class MetricsCallback(BaseCallback):
    """
    Tracks:
      - Episode total rewards
      - Average gamma per episode
      - Actor/critic losses
    Saves these metrics as plots + a text summary.
    """
    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []
        self.episode_gammas = []
        self.current_episode_gammas = []
        self.total_reward = 0
        
        self.actor_losses = []
        self.critic_losses = []
        self.action_low = None
        self.action_high = None

    def _on_step(self):
        # Record action bounds once
        if self.action_low is None:
            self.action_low = self.model.action_space.low[0]
            self.action_high = self.model.action_space.high[0]

        # Current gamma
        actions = self.locals['actions']
        if torch.is_tensor(actions):
            gamma = actions[0].item()
        else:
            gamma = float(actions[0])
        gamma = np.clip(gamma, self.action_low, self.action_high)
        self.current_episode_gammas.append(gamma)

        # Current reward
        rewards = self.locals['rewards']
        if torch.is_tensor(rewards):
            reward = rewards[0].item()
        else:
            reward = float(rewards[0])
        self.total_reward += reward

        # Actor/Critic losses from SB3 logger (if available)
        if (hasattr(self.model, 'logger') 
            and hasattr(self.model.logger, 'name_to_value')):
            logger_dict = self.model.logger.name_to_value
            if 'train/policy_gradient_loss' in logger_dict:
                self.actor_losses.append(logger_dict['train/policy_gradient_loss'])
            if 'train/value_loss' in logger_dict:
                self.critic_losses.append(logger_dict['train/value_loss'])

        # Episode end?
        if self.locals['dones'][0]:
            self.episode_rewards.append(self.total_reward)
            if len(self.current_episode_gammas) > 0:
                avg_g = np.mean(self.current_episode_gammas)
            else:
                avg_g = 0.0
            self.episode_gammas.append(avg_g)
            
            # Reset
            self.current_episode_gammas = []
            self.total_reward = 0
        
        return True

    def save_metrics(self, save_dir="training_metrics"):
        os.makedirs(save_dir, exist_ok=True)
        
        # 1) Episode reward
        plt.figure(figsize=(8, 5))
        plt.plot(self.episode_rewards, label="Episode Reward")
        plt.title('Average Reward per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.grid(True)
        plt.legend()
        plt.savefig(os.path.join(save_dir, 'average_reward.png'))
        plt.close()
        
        # 2) Average gamma
        plt.figure(figsize=(8, 5))
        plt.plot(self.episode_gammas, label="Average Gamma")
        plt.title('Average Gamma per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Gamma')
        plt.grid(True)
        plt.legend()
        plt.savefig(os.path.join(save_dir, 'average_gamma.png'))
        plt.close()
        
        # 3) Actor & critic loss
        if self.actor_losses and self.critic_losses:
            plt.figure(figsize=(8, 5))
            plt.plot(self.actor_losses, label="Actor Loss")
            plt.plot(self.critic_losses, label="Critic Loss")
            plt.title('Actor & Critic Loss')
            plt.xlabel('Training Update')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.legend()
            plt.savefig(os.path.join(save_dir, 'actor_critic_loss.png'))
            plt.close()

        # Summary
        with open(os.path.join(save_dir, 'training_summary.txt'), 'w') as f:
            f.write(f"Total Episodes: {len(self.episode_rewards)}\n")
            if len(self.episode_rewards) > 0:
                f.write(f"Overall Average Reward: {np.mean(self.episode_rewards):.3f}\n")
                f.write(f"Overall Average Gamma: {np.mean(self.episode_gammas):.3f}\n")
                f.write(f"Best Episode Reward: {max(self.episode_rewards):.3f}\n")
                f.write(f"Worst Episode Reward: {min(self.episode_rewards):.3f}\n")

###############################################################################
# Environment with Exponential Ramp & Override Penalty
###############################################################################
class DynamicArbitrationEnv(gym.Env):
    """
    Environment that:
      - Has multiple goals,
      - Chooses gamma in [0.15..0.7],
      - Gains more reward for picking higher gamma only when close to the goal,
      - Penalizes big changes in gamma, misalignment with "human" input, etc.
    """
    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        self.max_dist = np.sqrt(FULL_VIEW_SIZE[0] ** 2 + FULL_VIEW_SIZE[1] ** 2)
        
        # Multi-goal setup
        self.NUM_GOALS = NUM_GOALS
        self.targets = []
        self.current_target_idx = 0

        # For continuity penalty on gamma
        self.prev_gamma = None

        # Observations: [dot_x, dot_y, human_x, human_y,
        #                target_x, target_y, perfect_x, perfect_y,
        #                normalized_dist]
        self.observation_space = spaces.Box(
            low=np.array([0, 0, -1, -1, 0, 0, -1, -1, 0], dtype=np.float32),
            high=np.array([
                FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
                1, 1,
                FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
                1, 1,
                1
            ], dtype=np.float32),
            dtype=np.float32
        )

        # ---------------------------------------------------------------------
        # Action: gamma in [0.15..0.7]
        # ---------------------------------------------------------------------
        self.action_space = spaces.Box(
            low=np.array([0.15], dtype=np.float32),
            high=np.array([0.7], dtype=np.float32),
            dtype=np.float32
        )

        self.dot_pos = None
        self.step_count = 0
        self.max_steps = 300

    def _generate_target(self):
        """Random target within screen boundaries."""
        x = np.random.uniform(100, FULL_VIEW_SIZE[0] - 100)
        y = np.random.uniform(100, FULL_VIEW_SIZE[1] - 100)
        return np.array([x, y], dtype=np.float32)

    def _get_obs(self):
        """
        Synthetic 'human input' = perfect_dir + Gaussian noise, normalized.
        Returns 9D observation.
        """
        to_target = self.target_pos - self.dot_pos
        dist = np.linalg.norm(to_target)
        if dist > 1e-6:
            perfect_dir = to_target / dist
        else:
            perfect_dir = np.array([0.0, 0.0], dtype=np.float32)
        
        normalized_dist = dist / self.max_dist

        # "human" input is perfect_dir + noise
        noise = np.random.normal(0, NOISE_MAGNITUDE, size=2)
        human_input = perfect_dir + noise
        norm_h = np.linalg.norm(human_input)
        if norm_h > 1e-6:
            human_input /= norm_h

        # Combine
        return np.concatenate([
            self.dot_pos,        # (x, y)
            human_input,         # (hx, hy)
            self.target_pos,     # (tx, ty)
            perfect_dir,         # (px, py)
            [normalized_dist]
        ]).astype(np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.step_count = 0
        self.prev_gamma = None

        # Dot starts in the center
        self.dot_pos = np.array(START_POS, dtype=np.float32)

        # Generate multiple random targets
        self.targets = [self._generate_target() for _ in range(self.NUM_GOALS)]
        self.current_target_idx = 0
        self.target_pos = self.targets[self.current_target_idx]

        return self._get_obs(), {}

    def step(self, action):
        self.step_count += 1

        # Clip gamma to [0.15..0.7]
        gamma = float(np.clip(action[0], 0.15, 0.7))

        # Occasionally switch targets mid-episode (2% chance)
        if np.random.rand() < 0.02:
            old_idx = self.current_target_idx
            possible_indices = [i for i in range(self.NUM_GOALS) if i != old_idx]
            if possible_indices:
                self.current_target_idx = np.random.choice(possible_indices)
            self.target_pos = self.targets[self.current_target_idx]

        obs = self._get_obs()
        human_dir = obs[2:4]       # (hx, hy)
        perfect_dir = obs[6:8]     # (px, py)
        dist_ratio = obs[8]

        # Weighted combination
        combined_dir = gamma * perfect_dir + (1 - gamma) * human_dir
        norm_comb = np.linalg.norm(combined_dir)
        if norm_comb > 1e-6:
            combined_dir /= norm_comb

        # Move the dot
        move_vec = combined_dir * MAX_SPEED
        new_pos = self.dot_pos + move_vec
        self.dot_pos = np.clip(new_pos, [0, 0], [FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1]])

        dist_to_target = np.linalg.norm(self.target_pos - self.dot_pos)

        # ---------------------------------------------------------------------
        # Reward Shaping
        # ---------------------------------------------------------------------
        
        # (A) Exponential distance-based target gamma in [0.15..0.7]
        alpha = 3.0
        target_gamma = 0.15 + 0.55 * np.exp(-alpha * dist_ratio)
        
        # (B) penalty for deviation from target_gamma
        gamma_penalty_coef = 20.0
        gamma_penalty = -gamma_penalty_coef * (gamma - target_gamma)**2

        # (C) smoothness penalty
        if self.prev_gamma is None:
            smoothness_penalty = 0.0
        else:
            smoothness_coef = 2.0
            smoothness_penalty = -smoothness_coef * (gamma - self.prev_gamma)**2
        self.prev_gamma = gamma

        # (D) small alignment reward if final movement aligns with human_dir
        alignment_reward = 0.05 * np.dot(combined_dir, human_dir)

        # (E) override penalty if gamma is high but we’re ignoring the user
        #     "misalignment" = 1 - dot( combined_dir, human_dir )
        #     bigger gamma => bigger penalty if not aligned
        dot_ch = np.dot(combined_dir, human_dir)
        dot_ch = max(-1.0, min(1.0, dot_ch))  # clamp to [-1..1]
        misalignment = 1.0 - dot_ch
        override_penalty_coef = 5.0
        override_penalty = -override_penalty_coef * gamma * misalignment

        # (F) goal reward
        goal_reward = 0.0
        if dist_to_target < GOAL_DETECTION_RADIUS:
            goal_reward = 10.0
            # Optionally switch to a new target
            old_idx = self.current_target_idx
            possible_indices = [i for i in range(self.NUM_GOALS) if i != old_idx]
            if possible_indices:
                self.current_target_idx = np.random.choice(possible_indices)
            self.target_pos = self.targets[self.current_target_idx]

        # Combine final reward
        reward = (
            gamma_penalty
            + smoothness_penalty
            + alignment_reward
            + override_penalty
            + goal_reward
        )

        # Terminate or truncate
        terminated = False
        truncated = False
        if self.step_count >= self.max_steps:
            truncated = True

        return self._get_obs(), reward, terminated, truncated, {}

###############################################################################
# VecEnv Helper
###############################################################################
def make_env(rank, seed=0):
    def _init():
        env = DynamicArbitrationEnv(render_mode=None)
        env.reset(seed=seed + rank)
        return env
    set_random_seed(seed)
    return _init

###############################################################################
# Training Routine
###############################################################################
def train():
    model_save_dir = "trained_models"
    metrics_dir = "training_metrics"
    os.makedirs(model_save_dir, exist_ok=True)
    os.makedirs(metrics_dir, exist_ok=True)
    
    # Number of parallel environments
    n_envs = 8
    env = DummyVecEnv([make_env(i) for i in range(n_envs)])
    
    policy_kwargs = dict(
        features_extractor_class=OptimizedCustomFeaturesExtractor,
        features_extractor_kwargs=dict(features_dim=128),
        net_arch=dict(
            pi=[256, 128],
            vf=[256, 128]
        )
    )
    
    # Build the PPO model
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=3e-5,
        n_steps=256,
        batch_size=2048,
        ent_coef=0.005,
        n_epochs=4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        normalize_advantage=True,
        policy_kwargs=policy_kwargs,
        device=device,
        verbose=1
    )
    
    metrics_callback = MetricsCallback()
    
    try:
        print("Starting training...")
        model.learn(total_timesteps=5_000_000, callback=metrics_callback)
        # ^ you may want to increase or decrease total_timesteps 
        #   e.g. 120_000_000, depending on your compute resources
        
        # Save final model
        model.save(os.path.join(model_save_dir, "dynamic_arbitration_ppo_exp_override"))
        print("Model saved as dynamic_arbitration_ppo_exp_override.zip")
        
        # Save metrics
        metrics_callback.save_metrics(save_dir=metrics_dir)
        print("Metrics saved to training_metrics directory")
        
    except KeyboardInterrupt:
        print("\nTraining interrupted. Saving current state...")
        model.save(os.path.join(model_save_dir, "dynamic_arbitration_ppo_exp_override_interrupted"))
        metrics_callback.save_metrics(save_dir="training_metrics_interrupted")
        print("Intermediate state saved.")
    
    finally:
        env.close()

###############################################################################
# Main
###############################################################################
if __name__ == "__main__":
    mp.set_start_method('spawn', force=True)
    train()


Using device: cuda
Using cuda device
Starting training...


  gamma = float(actions[0])


-----------------------------
| time/              |      |
|    fps             | 3271 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 3304          |
|    iterations           | 2             |
|    time_elapsed         | 1             |
|    total_timesteps      | 4096          |
| train/                  |               |
|    approx_kl            | 6.7497866e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.42         |
|    explained_variance   | 0.00237       |
|    learning_rate        | 3e-05         |
|    loss                 | 587           |
|    n_updates            | 4             |
|    policy_gradient_loss | -0.00089      |
|    std                  | 1             |
|    value_loss           | 1.18e+03      