In [4]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import math
import random
import time
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import entropy
from sklearn.isotonic import IsotonicRegression
import pickle
import gymnasium as gym
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback, CallbackList
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.utils import obs_as_tensor
from stable_baselines3.common.policies import ActorCriticPolicy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

# For rendering (optional):
try:
    import pygame
except ImportError:
    pygame = None

###############################################################################
# CONSTANTS & UTILS
###############################################################################
FULL_VIEW_SIZE = (1200, 800)
SCALING_FACTOR_X = FULL_VIEW_SIZE[0] / 600.0
SCALING_FACTOR_Y = FULL_VIEW_SIZE[1] / 600.0
SCALING_FACTOR   = (SCALING_FACTOR_X + SCALING_FACTOR_Y) / 2

DOT_RADIUS       = int(15 * SCALING_FACTOR)
TARGET_RADIUS    = int(10 * SCALING_FACTOR)
OBSTACLE_RADIUS  = int(10 * SCALING_FACTOR)
COLLISION_BUFFER = int(5  * SCALING_FACTOR)
MAX_SPEED        = 3 * SCALING_FACTOR
NOISE_MAGNITUDE  = 2.5
RENDER_FPS       = 30

START_POS = np.array([FULL_VIEW_SIZE[0]//2, FULL_VIEW_SIZE[1]//2], dtype=np.float32)

# Colors for visualization
WHITE = (255, 255, 255)
GRAY  = (128, 128, 128)
YELLOW= (255, 255, 0)
BLACK = (0, 0, 0)
RED   = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE  = (0, 0, 255)
PURPLE= (128, 0, 128)

def distance(a, b):
    return math.hypot(a[0] - b[0], a[1] - b[1])

def check_line_collision(start, end, center, radius):
    dx = end[0] - start[0]
    dy = end[1] - start[1]
    fx = center[0] - start[0]
    fy = center[1] - start[1]
    l2 = dx*dx + dy*dy
    if l2 < 1e-9:
        return distance(start, center) <= radius
    t = max(0, min(1, (fx*dx + fy*dy) / l2))
    px = start[0] + t*dx
    py = start[1] + t*dy
    return distance((px, py), center) <= radius

def line_collision(pos, new_pos, obstacles):
    for obs in obstacles:
        if check_line_collision(pos, new_pos, obs, OBSTACLE_RADIUS + COLLISION_BUFFER):
            return True
    return False

def inside_obstacle(pos, obstacles):
    for obs in obstacles:
        if distance(pos, obs) <= (OBSTACLE_RADIUS + DOT_RADIUS):
            return True
    return False

def potential_field_dir(pos, goal, obstacles):
    """
    Returns a normalized direction from pos to goal,
    plus repulsion from obstacles.
    """
    gx = goal[0] - pos[0]
    gy = goal[1] - pos[1]
    dg = math.hypot(gx, gy)
    if dg < 1e-6:
        return np.zeros(2, dtype=np.float32)
    att = np.array([gx / dg, gy / dg], dtype=np.float32)

    repulse_x = 0.0
    repulse_y = 0.0
    repulsion_radius = 23.0 * SCALING_FACTOR
    repulsion_gain   = 30000.0

    for obs in obstacles:
        dx = pos[0] - obs[0]
        dy = pos[1] - obs[1]
        dobs = math.hypot(dx, dy)
        if dobs < 1e-9:
            continue
        if dobs < repulsion_radius:
            pushx    = dx / dobs
            pushy    = dy / dobs
            strength = repulsion_gain / (dobs**2)
            repulse_x += pushx * strength
            repulse_y += pushy * strength

    px = att[0] + repulse_x
    py = att[1] + repulse_y
    mg = math.hypot(px, py)
    if mg < 1e-9:
        return np.zeros(2, dtype=np.float32)
    return np.array([px / mg, py / mg], dtype=np.float32)

###############################################################################
# BAYESIAN GOAL INFERENCE MODEL
###############################################################################
class BayesianGoalInference:
    """
    Recursive Bayesian goal inference model that maintains and updates
    a probability distribution over potential goals based on the human inputs.
    """
    def __init__(self, beta=10.0, w_theta=0.7, w_d=0.3, decay_rate=0.85):
        """
        Initialize the Bayesian goal inference model.
        
        Args:
            beta (float): Rationality parameter that determines how closely
                         the human follows the optimal policy
            w_theta (float): Weight for angular deviation in the cost function
            w_d (float): Weight for distance deviation in the cost function
            decay_rate (float): Decay rate for temporal smoothing of probabilities
        """
        self.beta = beta                # Rationality parameter
        self.w_theta = w_theta          # Weight for angular deviation in cost
        self.w_d = w_d                  # Weight for distance deviation in cost
        self.decay_rate = decay_rate    # For temporal smoothing
        self.goals = []                 # List of potential goals
        self.priors = None              # Prior probabilities for goals
        self.goal_probs = None          # Current goal probabilities
        self.calibrator = None          # For confidence calibration
        self.history = []               # History of probability updates
        self.max_hist_len = 30          # Maximum history length to store
    
    def initialize_goals(self, goals):
        """
        Initialize goal distribution with equal priors.
        
        Args:
            goals (list): List of potential goal positions
        """
        self.goals = goals
        n_goals = len(goals)
        # Uniform prior
        self.priors = np.ones(n_goals) / n_goals
        self.goal_probs = self.priors.copy()
        self.history = []
    
    def load_calibrator(self, calibrator_path):
        """
        Load a pre-trained confidence calibrator.
        
        Args:
            calibrator_path (str): Path to the saved calibrator model
        """
        try:
            with open(calibrator_path, 'rb') as f:
                self.calibrator = pickle.load(f)
            print(f"Loaded calibrator from {calibrator_path}")
        except:
            print(f"Could not load calibrator from {calibrator_path}")
            self.calibrator = None
    
    def train_calibrator(self, confidences, accuracies):
        """
        Train a confidence calibrator using isotonic regression.
        
        Args:
            confidences (list): List of predicted confidence values
            accuracies (list): List of binary accuracy values (0 or 1)
        """
        if len(confidences) < 10:
            print("Not enough data to train calibrator")
            return
        
        self.calibrator = IsotonicRegression(out_of_bounds='clip')
        self.calibrator.fit(confidences, accuracies)
        
        # Save the calibrator
        os.makedirs("models", exist_ok=True)
        with open("models/calibrator.pkl", 'wb') as f:
            pickle.dump(self.calibrator, f)
        print("Calibrator trained and saved")
    
    def compute_cost(self, human_input, agent_pos, goal_pos):
        """
        Compute the cost of human input based on deviation from optimal action.
        
        Args:
            human_input (numpy.ndarray): Normalized direction vector of human input
            agent_pos (numpy.ndarray): Current agent position
            goal_pos (numpy.ndarray): Goal position
            
        Returns:
            float: Cost value
        """
        # Compute optimal direction vector to goal
        goal_dir = goal_pos - agent_pos
        dist = np.linalg.norm(goal_dir)
        if dist < 1e-6:
            return 0.0
        
        goal_dir = goal_dir / dist  # Normalize
        
        # Compute angular deviation (radians)
        dot_product = np.clip(np.dot(human_input, goal_dir), -1.0, 1.0)
        theta_dev = abs(np.arccos(dot_product))
        
        # Compute distance (magnitude) deviation
        h_magnitude = np.linalg.norm(human_input)
        
        # Optimal magnitude scales inversely with proximity to target
        # (slow down as approaching the target)
        min_dist = TARGET_RADIUS + DOT_RADIUS
        if dist < min_dist:
            opt_magnitude = 0.0
        else:
            dist_factor = min(1.0, (dist - min_dist) / (200 * SCALING_FACTOR))
            opt_magnitude = dist_factor * 1.0  # Scale to [0,1]
        
        d_dev = abs(1.0 - h_magnitude / max(opt_magnitude, 1e-6))
        
        # Combine angular and distance deviation with weights
        cost = self.w_theta * theta_dev + self.w_d * d_dev
        return cost
    
    def update(self, agent_pos, human_input, obstacles=None):
        """
        Update the goal probabilities based on the observed human input.
        
        Args:
            agent_pos (numpy.ndarray): Current agent position
            human_input (numpy.ndarray): Normalized direction vector of human input
            obstacles (list, optional): List of obstacle positions
            
        Returns:
            numpy.ndarray: Updated goal probabilities
        """
        if len(self.goals) == 0:
            return np.array([])
        
        # Process human input toward each goal
        likelihoods = np.zeros(len(self.goals))
        costs = np.zeros(len(self.goals))
        
        for i, goal in enumerate(self.goals):
            cost = self.compute_cost(human_input, agent_pos, goal)
            costs[i] = cost
            # Noisy rational model: P(action|goal) ∝ exp(-β * cost)
            likelihoods[i] = np.exp(-self.beta * cost)
        
        # Normalize likelihoods to prevent numerical issues
        if np.sum(likelihoods) > 0:
            likelihoods = likelihoods / np.sum(likelihoods)
        else:
            likelihoods = np.ones_like(likelihoods) / len(likelihoods)
        
        # Bayesian update: P(goal|action) ∝ P(action|goal) * P(goal)
        raw_posteriors = likelihoods * self.goal_probs
        
        # Normalize
        if np.sum(raw_posteriors) > 0:
            posteriors = raw_posteriors / np.sum(raw_posteriors)
        else:
            # If all posteriors are zero, revert to prior
            posteriors = self.priors.copy()
        
        # Apply temporal smoothing using exponential moving average
        self.goal_probs = self.decay_rate * self.goal_probs + (1 - self.decay_rate) * posteriors
        
        # Normalize again after smoothing
        if np.sum(self.goal_probs) > 0:
            self.goal_probs = self.goal_probs / np.sum(self.goal_probs)
        
        # Apply calibration if available
        if self.calibrator is not None:
            max_prob = np.max(self.goal_probs)
            calibrated_max = self.calibrator.predict([max_prob])[0]
            # Adjust other probabilities proportionally
            if max_prob > 0:
                scale_factor = calibrated_max / max_prob
                self.goal_probs = self.goal_probs * scale_factor
                remainder = 1.0 - np.sum(self.goal_probs)
                if remainder > 0:
                    # Distribute remainder proportionally to other goals
                    indices = np.arange(len(self.goal_probs))
                    max_idx = np.argmax(self.goal_probs)
                    other_indices = indices[indices != max_idx]
                    if len(other_indices) > 0:
                        other_sum = np.sum(self.goal_probs[other_indices])
                        if other_sum > 0:
                            for idx in other_indices:
                                self.goal_probs[idx] += remainder * (self.goal_probs[idx] / other_sum)
                        else:
                            # If all other probs are zero, distribute uniformly
                            for idx in other_indices:
                                self.goal_probs[idx] += remainder / len(other_indices)
        
        # Keep track of probabilities history
        self.history.append(self.goal_probs.copy())
        if len(self.history) > self.max_hist_len:
            self.history.pop(0)
        
        return self.goal_probs
    
    def get_goal_probs(self):
        """
        Get the current goal probabilities.
        
        Returns:
            numpy.ndarray: Current goal probabilities
        """
        return self.goal_probs
    
    def get_most_likely_goal(self):
        """
        Get the most likely goal and its probability.
        
        Returns:
            tuple: (goal_position, probability)
        """
        if len(self.goals) == 0:
            return None, 0.0
        
        max_idx = np.argmax(self.goal_probs)
        return self.goals[max_idx], self.goal_probs[max_idx]
    
    def get_entropy(self):
        """
        Calculate the entropy of the current goal distribution.
        
        Returns:
            float: Entropy value
        """
        return entropy(self.goal_probs)
    
    def compute_expert_recommendation(self, agent_pos, obstacles):
        """
        Compute the expert's recommended action based on goal probabilities.
        
        Args:
            agent_pos (numpy.ndarray): Current agent position
            obstacles (list): List of obstacle positions
            
        Returns:
            numpy.ndarray: Normalized direction vector for expert recommendation
        """
        if len(self.goals) == 0:
            return np.zeros(2, dtype=np.float32)
        
        # Compute weighted direction based on goal probabilities
        weighted_dir = np.zeros(2, dtype=np.float32)
        
        for i, goal in enumerate(self.goals):
            # Get the expert direction for this goal using potential field
            expert_dir = potential_field_dir(agent_pos, goal, obstacles)
            # Weight by goal probability
            weighted_dir += self.goal_probs[i] * expert_dir
        
        # Normalize
        magnitude = np.linalg.norm(weighted_dir)
        if magnitude > 1e-6:
            weighted_dir = weighted_dir / magnitude
        
        return weighted_dir
    
    def reset(self):
        """Reset the model to initial state with uniform priors."""
        if len(self.goals) > 0:
            self.goal_probs = self.priors.copy()
        self.history = []

###############################################################################
# CUSTOM POLICY: DUAL-HEAD NETWORK ARCHITECTURE
###############################################################################
class DualHeadMlpPolicy(ActorCriticPolicy):
    """
    Custom policy with a dual-head architecture:
    - Shared encoder for feature extraction
    - Separate heads for goal inference and assistance determination
    """
    def __init__(self, observation_space, action_space, lr_schedule, **kwargs):
        super(DualHeadMlpPolicy, self).__init__(
            observation_space,
            action_space,
            lr_schedule,
            **kwargs
        )
    
    def _build_mlp_extractor(self, config=None):
        """
        Build the feature extraction network.
        
        Overridden to create a specialized architecture with better feature extraction
        for both goal inference and policy determination.
        """
        # Create a custom MLP extractor with separated streams
        # for goal inference and policy determination
        class DualHeadExtractor(nn.Module):
            def __init__(self, feature_dim, goal_features, policy_features):
                super(DualHeadExtractor, self).__init__()
                
                # Shared encoder
                self.shared = nn.Sequential(
                    nn.Linear(feature_dim, 256),
                    nn.ReLU(),
                    nn.Linear(256, 256),
                    nn.ReLU()
                )
                
                # Goal inference branch - predicts goal probabilities
                self.goal_branch = nn.Sequential(
                    nn.Linear(256, 128),
                    nn.ReLU(),
                    nn.Linear(128, goal_features),
                    nn.Softmax(dim=-1)
                )
                
                # Policy determination branch
                self.pi_branch = nn.Sequential(
                    nn.Linear(256 + goal_features, 256),  # Includes goal probs
                    nn.ReLU(),
                    nn.Linear(256, policy_features),
                    nn.ReLU()
                )
                
                # Value branch
                self.vf_branch = nn.Sequential(
                    nn.Linear(256 + goal_features, 256),  # Includes goal probs
                    nn.ReLU(),
                    nn.Linear(256, policy_features),
                    nn.ReLU()
                )
            
            def forward(self, features):
                shared_features = self.shared(features)
                
                # Goal inference prediction
                goal_probs = self.goal_branch(shared_features)
                
                # Concatenate shared features with predicted goal probabilities
                combined_features = torch.cat([shared_features, goal_probs], dim=-1)
                
                # Policy (pi) and value (vf) features
                pi_features = self.pi_branch(combined_features)
                vf_features = self.vf_branch(combined_features)
                
                return pi_features, vf_features, goal_probs
        
        # Number of potential goals - will be set by the environment
        n_goals = 8  # Default, will be updated at runtime
        
        # Create the dual-head extractor
        self.mlp_extractor = DualHeadExtractor(
            feature_dim=self.features_dim,
            goal_features=n_goals,
            policy_features=256
        )
        
        return self.mlp_extractor
    
    def forward(self, obs, deterministic=False):
        """
        Forward pass of the neural network.
        
        Args:
            obs: Observation tensor
            deterministic: Whether to sample or take the most likely action
            
        Returns:
            Tuple of (actions, values, log_probs, goal_probs)
        """
        # Extract features
        features = self.extract_features(obs)
        
        # Get policy features, value features, and goal probabilities
        latent_pi, latent_vf, goal_probs = self.mlp_extractor(features)
        
        # Get action distribution parameters
        mean_actions = self.action_net(latent_pi)
        log_std = torch.clamp(self.log_std, -20, 2)  # Constrain for numerical stability
        std = torch.exp(log_std)
        
        # Create action distribution
        distribution = Normal(mean_actions, std)
        
        # Sample or take most likely action
        if deterministic:
            actions = torch.tanh(mean_actions)
        else:
            actions = torch.tanh(distribution.rsample())
        
        log_prob = distribution.log_prob(actions)
        
        # Get value estimate
        values = self.value_net(latent_vf)
        
        return actions, values, log_prob, goal_probs

###############################################################################
# METRICS CALLBACK WITH GOAL INFERENCE TRACKING
###############################################################################
class EnhancedMetricsCallback(BaseCallback):
    """
    Enhanced callback that logs training metrics including goal inference performance.
    """
    def __init__(self, verbose=0):
        super().__init__(verbose)
        # Episode metrics
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_mean_gammas = []
        self.episode_std_gammas = []
        self.goal_inference_accuracy = []
        self.goal_entropy = []
        self.path_completion_accuracy = []  # Accuracy at different path completions
        
        # Current episode tracking
        self.total_reward = 0.0
        self.ep_length = 0
        self.current_episode_gammas = []
        self.current_goal_probs = []
        self.current_true_goal_idx = None
        self.trajectory_start = None
        self.trajectory_positions = []
        
        # Collision tracking
        self.n_collisions = 0
        self.n_episodes = 0
        
        # Loss tracking
        self.losses = []
        self.value_losses = []
        self.policy_losses = []
        self.entropy_losses = []
        self.training_steps = []
        self.n_updates = 0
        
        # Goal inference performance at different path completion percentages
        self.completion_thresholds = [0.25, 0.5, 0.75]
        self.completion_accuracies = {t: [] for t in self.completion_thresholds}

    def _on_training_start(self):
        """Called at the start of training."""
        self.episode_rewards.clear()
        self.episode_lengths.clear()
        self.episode_mean_gammas.clear()
        self.episode_std_gammas.clear()
        self.goal_inference_accuracy.clear()
        self.goal_entropy.clear()
        self.path_completion_accuracy.clear()
        
        self.total_reward = 0.0
        self.ep_length = 0
        self.current_episode_gammas.clear()
        self.current_goal_probs.clear()
        self.current_true_goal_idx = None
        self.trajectory_start = None
        self.trajectory_positions = []
        
        self.n_collisions = 0
        self.n_episodes = 0
        
        for t in self.completion_thresholds:
            self.completion_accuracies[t].clear()

    def _on_step(self) -> bool:
        """
        Called at each step of training.
        
        Returns:
            bool: Whether training should continue
        """
        actions = self.locals['actions']
        rewards = self.locals['rewards']
        dones = self.locals['dones']
        infos = self.locals['infos']
        obs = self.locals['new_obs']
        
        # Get the environment instance
        env = self.model.env.envs[0]
        
        # Compute gamma from the action (mapping [-1,1] -> [0,1])
        gamma_val = 0.5 * (actions[0][0] + 1.0)
        self.current_episode_gammas.append(gamma_val)
        
        # Track reward and episode length
        r = float(rewards[0])
        self.total_reward += r
        self.ep_length += 1
        
        # Store current position for trajectory analysis
        if hasattr(env, 'dot_pos'):
            current_pos = env.dot_pos.copy()
            self.trajectory_positions.append(current_pos)
            
            # Record start position if this is first step
            if self.trajectory_start is None:
                self.trajectory_start = current_pos.copy()
        
        # If using the dual-head policy, get goal probabilities
        if hasattr(env, 'bayesian_model') and hasattr(env, 'true_goal_idx'):
            # Get goal probabilities from environment's Bayesian model
            goal_probs = env.bayesian_model.get_goal_probs()
            self.current_goal_probs.append(goal_probs.copy())
            
            # Store true goal index for accuracy calculation
            if self.current_true_goal_idx is None:
                self.current_true_goal_idx = env.true_goal_idx
            
            # Calculate current path completion
            if hasattr(env, 'goal_pos') and len(self.trajectory_positions) > 1:
                goal_pos = env.goal_pos
                total_distance = distance(self.trajectory_start, goal_pos)
                
                if total_distance > 0:
                    current_distance = distance(current_pos, goal_pos)
                    path_completion = 1.0 - (current_distance / total_distance)
                    
                    # For each completion threshold, check if we've just crossed it
                    for threshold in self.completion_thresholds:
                        if path_completion >= threshold:
                            # Check if prediction is correct
                            pred_goal = np.argmax(goal_probs)
                            is_correct = (pred_goal == self.current_true_goal_idx)
                            
                            # Store in appropriate bin if not already captured for this episode
                            threshold_key = f"{int(threshold*100)}%"
                            if not hasattr(self, f"recorded_{threshold_key}"):
                                setattr(self, f"recorded_{threshold_key}", True)
                                self.completion_accuracies[threshold].append(int(is_correct))
        
        if dones[0]:
            # Calculate goal inference accuracy for this episode
            if self.current_goal_probs and self.current_true_goal_idx is not None:
                # Average accuracy across episode
                accuracies = []
                entropies = []
                
                for probs in self.current_goal_probs:
                    pred_goal = np.argmax(probs)
                    accuracies.append(int(pred_goal == self.current_true_goal_idx))
                    entropies.append(entropy(probs))
                
                # Store average accuracy and entropy
                self.goal_inference_accuracy.append(np.mean(accuracies))
                self.goal_entropy.append(np.mean(entropies))
            
            # Store episode statistics
            self.episode_rewards.append(self.total_reward)
            self.episode_lengths.append(self.ep_length)
            
            if self.current_episode_gammas:
                mean_gamma = np.mean(self.current_episode_gammas)
                std_gamma = np.std(self.current_episode_gammas)
                self.episode_mean_gammas.append(mean_gamma)
                self.episode_std_gammas.append(std_gamma)
            
            # Reset episode tracking
            self.total_reward = 0.0
            self.ep_length = 0
            self.current_episode_gammas.clear()
            self.current_goal_probs.clear()
            self.current_true_goal_idx = None
            self.trajectory_start = None
            self.trajectory_positions.clear()
            
            # Reset completion recording flags
            for threshold in self.completion_thresholds:
                threshold_key = f"{int(threshold*100)}%"
                if hasattr(self, f"recorded_{threshold_key}"):
                    delattr(self, f"recorded_{threshold_key}")
            
            # Increment episode and collision counters
            self.n_episodes += 1
            if 'terminal_reason' in infos[0] and infos[0]['terminal_reason'] == 'collision':
                self.n_collisions += 1
        
        return True

    def _on_rollout_end(self):
        """Called at the end of a rollout."""
        self.n_updates += 1
        logs = self.model.logger.name_to_value or {}
        
        # Record loss metrics
        if "train/loss" in logs:
            self.losses.append(logs["train/loss"])
            self.training_steps.append(self.n_updates)
        if "train/value_loss" in logs:
            self.value_losses.append(logs["train/value_loss"])
        if "train/policy_gradient_loss" in logs:
            self.policy_losses.append(logs["train/policy_gradient_loss"])
        if "train/entropy_loss" in logs:
            self.entropy_losses.append(logs["train/entropy_loss"])

    def _moving_average(self, data, window=10):
        """Calculate moving average of data."""
        if len(data) < window:
            return np.array(data)
        return np.convolve(data, np.ones(window)/window, mode='valid')

    def save_metrics(self, save_dir="training_metrics"):
        """
        Save all metrics and plots to the specified directory.
        
        Args:
            save_dir (str): Directory to save metrics
        """
        os.makedirs(save_dir, exist_ok=True)
        
        # 1. Episode Rewards
        if self.episode_rewards:
            plt.figure(figsize=(10, 6))
            plt.plot(self.episode_rewards, label="Episode Reward", alpha=0.6)
            ma_rewards = self._moving_average(self.episode_rewards, 10)
            if len(ma_rewards):
                plt.plot(range(10 - 1, 10 - 1 + len(ma_rewards)), 
                         ma_rewards, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Episode")
            plt.ylabel("Reward")
            plt.title("Episode Rewards")
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "episode_rewards.png"))
            plt.close()
        
        # 2. Average Gamma per Episode
        if self.episode_mean_gammas:
            plt.figure(figsize=(10, 6))
            plt.plot(self.episode_mean_gammas, label="Average Gamma", alpha=0.6)
            ma_gamma = self._moving_average(self.episode_mean_gammas, 10)
            if len(ma_gamma):
                plt.plot(range(10 - 1, 10 - 1 + len(ma_gamma)), 
                         ma_gamma, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Episode")
            plt.ylabel("Gamma (avg)")
            plt.title("Average Gamma per Episode")
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "average_gamma.png"))
            plt.close()
        
        # 3. Gamma Std per Episode
        if self.episode_std_gammas:
            plt.figure(figsize=(10, 6))
            plt.plot(self.episode_std_gammas, label="Gamma Std", alpha=0.6)
            ma_gstd = self._moving_average(self.episode_std_gammas, 10)
            if len(ma_gstd):
                plt.plot(range(10 - 1, 10 - 1 + len(ma_gstd)), 
                         ma_gstd, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Episode")
            plt.ylabel("Gamma Std")
            plt.title("Gamma Std per Episode")
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "gamma_std.png"))
            plt.close()
        
        # 4. Goal Inference Accuracy
        if self.goal_inference_accuracy:
            plt.figure(figsize=(10, 6))
            plt.plot(self.goal_inference_accuracy, label="Goal Inference Accuracy", alpha=0.6)
            ma_acc = self._moving_average(self.goal_inference_accuracy, 10)
            if len(ma_acc):
                plt.plot(range(10 - 1, 10 - 1 + len(ma_acc)), 
                         ma_acc, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Episode")
            plt.ylabel("Accuracy")
            plt.title("Goal Inference Accuracy per Episode")
            plt.ylim(0, 1.05)
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "goal_inference_accuracy.png"))
            plt.close()
        
        # 5. Goal Distribution Entropy
        if self.goal_entropy:
            plt.figure(figsize=(10, 6))
            plt.plot(self.goal_entropy, label="Goal Distribution Entropy", alpha=0.6)
            ma_ent = self._moving_average(self.goal_entropy, 10)
            if len(ma_ent):
                plt.plot(range(10 - 1, 10 - 1 + len(ma_ent)), 
                         ma_ent, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Episode")
            plt.ylabel("Entropy")
            plt.title("Goal Distribution Entropy per Episode")
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "goal_entropy.png"))
            plt.close()
        
        # 6. Path Completion Accuracy
        accuracy_by_completion = {}
        for threshold in self.completion_thresholds:
            if self.completion_accuracies[threshold]:
                accuracy_by_completion[threshold] = np.mean(self.completion_accuracies[threshold])
        
        if accuracy_by_completion:
            plt.figure(figsize=(10, 6))
            x = list(accuracy_by_completion.keys())
            y = [accuracy_by_completion[t] for t in x]
            plt.bar([str(int(t*100))+"%" for t in x], y)
            plt.xlabel("Path Completion")
            plt.ylabel("Accuracy")
            plt.title("Goal Inference Accuracy vs. Path Completion")
            plt.ylim(0, 1.05)
            plt.grid(True, axis='y')
            plt.savefig(os.path.join(save_dir, "completion_accuracy.png"))
            plt.close()
        
        # 7. Total Model Loss
        if self.losses:
            plt.figure(figsize=(10, 6))
            plt.plot(self.training_steps, self.losses, label="Total Model Loss", alpha=0.7)
            if len(self.losses) >= 10:
                ma_loss = self._moving_average(self.losses, 10)
                plt.plot(range(self.training_steps[0] + (10 - 1),
                               self.training_steps[0] + (10 - 1) + len(ma_loss)),
                         ma_loss, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Training Updates")
            plt.ylabel("Loss")
            plt.title("Total Model Loss Over Rollouts")
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "total_loss.png"))
            plt.close()
        
        # 8. Critic (Value) Loss
        if self.value_losses:
            plt.figure(figsize=(10, 6))
            plt.plot(self.value_losses, label="Value Loss", alpha=0.7)
            if len(self.value_losses) >= 10:
                ma_val_loss = self._moving_average(self.value_losses, 10)
                plt.plot(range(10 - 1, 10 - 1 + len(ma_val_loss)), 
                         ma_val_loss, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Rollout End #")
            plt.ylabel("Value Loss")
            plt.title("Value (Critic) Loss")
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "value_loss.png"))
            plt.close()
        
        # 9. Actor (Policy) Loss
        if self.policy_losses:
            plt.figure(figsize=(10, 6))
            plt.plot(self.policy_losses, label="Policy Loss", alpha=0.7)
            if len(self.policy_losses) >= 10:
                ma_pol_loss = self._moving_average(self.policy_losses, 10)
                plt.plot(range(10 - 1, 10 - 1 + len(ma_pol_loss)), 
                         ma_pol_loss, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Rollout End #")
            plt.ylabel("Policy Loss")
            plt.title("Policy (Actor) Loss")
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "policy_loss.png"))
            plt.close()
        
        # 10. Entropy Loss
        if self.entropy_losses:
            plt.figure(figsize=(10, 6))
            plt.plot(self.entropy_losses, label="Entropy Loss", alpha=0.7)
            if len(self.entropy_losses) >= 10:
                ma_ent_loss = self._moving_average(self.entropy_losses, 10)
                plt.plot(range(10 - 1, 10 - 1 + len(ma_ent_loss)), 
                         ma_ent_loss, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Rollout End #")
            plt.ylabel("Entropy Loss")
            plt.title("Entropy Loss (Exploration)")
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "entropy_loss.png"))
            plt.close()
        
        # 11. Episode Length
        if self.episode_lengths:
            plt.figure(figsize=(10, 6))
            plt.plot(self.episode_lengths, label="Episode Length", alpha=0.6)
            ma_length = self._moving_average(self.episode_lengths, 10)
            if len(ma_length):
                plt.plot(range(10 - 1, 10 - 1 + len(ma_length)), 
                         ma_length, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Episode")
            plt.ylabel("Length (# steps)")
            plt.title("Episode Length")
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "episode_length.png"))
            plt.close()
        
        # 12. Save summary text
        with open(os.path.join(save_dir, "summary.txt"), "w") as f:
            f.write(f"Total Episodes: {len(self.episode_rewards)}\n")
            if self.episode_rewards:
                avg_reward = np.mean(self.episode_rewards)
                f.write(f"Mean Episode Reward: {avg_reward:.3f}\n")
            f.write(f"Collisions Count: {self.n_collisions}\n")
            if self.episode_mean_gammas:
                mean_gamma_all = np.mean(self.episode_mean_gammas)
                f.write(f"Mean of Average-Gamma: {mean_gamma_all:.3f}\n")
            if self.goal_inference_accuracy:
                f.write(f"Mean Goal Inference Accuracy: {np.mean(self.goal_inference_accuracy):.3f}\n")
            for threshold in sorted(self.completion_thresholds):
                if self.completion_accuracies[threshold]:
                    f.write(f"Goal Accuracy @ {int(threshold*100)}% completion: {np.mean(self.completion_accuracies[threshold]):.3f}\n")

###############################################################################
# CONTEXT-ADAPTIVE SHARED AUTONOMY ENVIRONMENT
###############################################################################
class SharedAutonomyEnv(gym.Env):
    """
    Environment for training context-adaptive shared autonomy.
    
    Features:
    - Multiple potential targets with varying goal ambiguity
    - Dynamic obstacle environments
    - Recursive Bayesian goal inference
    - Context-dependent blending of human and AI control
    """
    metadata = {"render_modes": ["human"], "render_fps": RENDER_FPS}
    
    def __init__(self, visualize=False, use_joint_optimization=True):
        """
        Initialize the shared autonomy environment.
        
        Args:
            visualize (bool): Whether to visualize the environment
            use_joint_optimization (bool): Whether to use joint optimization of
                                          goal inference and assistance determination
        """
        super().__init__()
        self.visualize = visualize
        self.use_joint_optimization = use_joint_optimization
        
        # State observation includes:
        # [dot_x, dot_y, h_dir_x, h_dir_y, goal_x, goal_y, w_dir_x, w_dir_y, 
        #  dist_ratio, obs_dist_ratio, entropy, goal_probs (8)]
        # Space dimensions will be determined when goals are created
        
        # Action space: blending parameter gamma ∈ [-1, 1] (mapped to [0, 1])
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
        
        # Initialize environment state
        self.dot_pos = None
        self.goal_pos = None
        self.obstacles = []
        self.goals = []
        self.true_goal_idx = None
        self.step_count = 0
        self.max_steps = 300
        self.episode_reward = 0.0
        self.max_dist = math.hypot(FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1])
        
        # Hyperparameters for human simulation
        self.alpha = 3.0  # Directness of human control
        self.beta = 3.0   # Precision of human control
        
        # Thresholds for context-based reward
        self.goal_threshold = 100.0 * SCALING_FACTOR
        self.obs_threshold = 100.0 * SCALING_FACTOR
        
        # Curriculum learning settings
        self.SCENARIO_SEEDS = [0, 1, 2, 58, 487]
        self.scenario_index = 0
        self.episode_counter = 0
        self.random_seed_probability = 0.3
        
        # Initialize Bayesian model for goal inference
        self.bayesian_model = BayesianGoalInference(beta=10.0, w_theta=0.7, w_d=0.3)
        
        # Set up visualization
        if self.visualize and pygame is not None:
            pygame.init()
            self.window = pygame.display.set_mode(FULL_VIEW_SIZE)
            pygame.display.set_caption("Context-Adaptive Shared Autonomy")
            self.clock = pygame.time.Clock()
        else:
            self.window = None
            self.clock = None
    
    def reset(self, seed=None, options=None):
        """
        Reset the environment.
        
        Args:
            seed (int, optional): Random seed
            options (dict, optional): Additional options
            
        Returns:
            tuple: (observation, info)
        """
        # Reset the base environment
        super().reset(seed=seed)
        
        # Update episode counter
        self.episode_counter += 1
        
        # Choose a scenario seed
        use_random = (random.random() < self.random_seed_probability)
        if use_random:
            scenario_seed = random.randint(0, 9999999)
        else:
            scenario_seed = self.SCENARIO_SEEDS[self.scenario_index]
            self.scenario_index = (self.scenario_index + 1) % len(self.SCENARIO_SEEDS)
        
        # Create the environment layout
        self.randomize_env(scenario_seed)
        
        # Reset internal state
        self.step_count = 0
        self.episode_reward = 0.0
        self.dot_pos = START_POS.copy()
        
        # Choose a random goal from the available targets
        if self.goals:
            self.true_goal_idx = random.randint(0, len(self.goals) - 1)
            self.goal_pos = self.goals[self.true_goal_idx].copy()
        else:
            self.goal_pos = np.array([random.uniform(0.2*FULL_VIEW_SIZE[0], 0.8*FULL_VIEW_SIZE[0]),
                                     random.uniform(0.2*FULL_VIEW_SIZE[1], 0.8*FULL_VIEW_SIZE[1])],
                                    dtype=np.float32)
            self.true_goal_idx = 0
        
        # Reset Bayesian goal inference model
        self.bayesian_model.initialize_goals(self.goals)
        self.bayesian_model.reset()
        
        # Set up observation space based on number of goals
        n_goals = len(self.goals)
        low = np.array(
            [0, 0, -1, -1, 0, 0, -1, -1, 0, 0, 0] + [0] * n_goals,
            dtype=np.float32
        )
        high = np.array(
            [FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1,
             FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1, 1, 1, np.log(n_goals)] + [1] * n_goals,
            dtype=np.float32
        )
        self.observation_space = spaces.Box(low=low, high=high, shape=(11 + n_goals,), dtype=np.float32)
        
        return self._get_obs(), {}
    
    def randomize_env(self, scenario_seed):
        """
        Create a randomized environment layout.
        
        Args:
            scenario_seed (int): Random seed for environment generation
        """
        random.seed(scenario_seed)
        np.random.seed(scenario_seed)
        
        margin = 50 * SCALING_FACTOR
        N_GOALS = 8
        N_OBSTACLES = 5
        min_goal_distance = 300 * SCALING_FACTOR
        
        # Generate goals
        new_goals = []
        attempts = 0
        while len(new_goals) < N_GOALS and attempts < 2000:
            x = random.uniform(margin, FULL_VIEW_SIZE[0] - margin)
            y = random.uniform(margin, FULL_VIEW_SIZE[1] - margin)
            candidate = np.array([x, y], dtype=np.float32)
            
            # Ensure goal is far enough from start position
            if distance(candidate, START_POS) >= min_goal_distance:
                new_goals.append(candidate)
            attempts += 1
        
        self.goals = new_goals[:N_GOALS]
        
        # Generate obstacles
        new_obstacles = []
        if len(self.goals) > 1:
            obstacle_goals = random.sample(self.goals, k=min(min(N_GOALS-1, N_OBSTACLES), len(self.goals)-1))
        else:
            obstacle_goals = self.goals.copy()
        
        for goal in obstacle_goals:
            # Place obstacle between start and goal
            t = random.uniform(0.6, 0.8)
            base_point = START_POS + t*(goal - START_POS)
            
            # Create perpendicular offset
            vec = goal - START_POS
            vec_norm = np.linalg.norm(vec)
            if vec_norm < 1e-6:
                perp = np.array([0, 0], dtype=np.float32)
            else:
                perp = np.array([-vec[1], vec[0]], dtype=np.float32)
                perp /= np.linalg.norm(perp)
            
            # Apply random offset
            offset_mag = random.uniform(20*SCALING_FACTOR, 40*SCALING_FACTOR)
            offset = perp * offset_mag * random.choice([-1, 1])
            candidate = base_point + offset
            
            # Keep within bounds
            candidate[0] = np.clip(candidate[0], margin, FULL_VIEW_SIZE[0] - margin)
            candidate[1] = np.clip(candidate[1], margin, FULL_VIEW_SIZE[1] - margin)
            
            # Check validity
            valid = True
            if distance(candidate, START_POS) < (DOT_RADIUS + OBSTACLE_RADIUS + 10):
                valid = False
            if distance(candidate, goal) < (TARGET_RADIUS + OBSTACLE_RADIUS + 20):
                valid = False
            for obs in new_obstacles:
                if distance(candidate, obs) < (2*OBSTACLE_RADIUS + 10):
                    valid = False
            
            if valid:
                new_obstacles.append(candidate)
        
        self.obstacles = new_obstacles
    
    def step(self, action):
        """
        Execute one environment step.
        
        Args:
            action (numpy.ndarray): Action to take (gamma value)
            
        Returns:
            tuple: (observation, reward, done, truncated, info)
        """
        # Map the action from [-1,1] to gamma in [0,1]
        raw_a = float(action[0])
        gamma_val = 0.5 * (raw_a + 1.0)
        
        self.step_count += 1
        
        # Compute the world (expert) direction using the potential field
        w_dir = potential_field_dir(self.dot_pos, self.goal_pos, self.obstacles)
        
        # Add noise for the human direction (simulating imperfect human input)
        noise = np.random.normal(0, NOISE_MAGNITUDE, size=2)
        h_dir = w_dir + noise
        hm = np.hypot(h_dir[0], h_dir[1])
        if hm > 1e-6:
            h_dir /= hm
        
        # Update Bayesian model with latest human input
        goal_probs = self.bayesian_model.update(self.dot_pos, h_dir, self.obstacles)
        
        # Get expert recommendation from Bayesian model
        if self.use_joint_optimization:
            # For joint optimization, use the weighted expert recommendation
            w_dir = self.bayesian_model.compute_expert_recommendation(self.dot_pos, self.obstacles)
        # else we use the potential_field_dir directly (already computed above)
        
        # Combine the directions using gamma as the blending parameter
        c_dir = gamma_val * w_dir + (1 - gamma_val) * h_dir
        cm = np.hypot(c_dir[0], c_dir[1])
        if cm > 1e-6:
            c_dir /= cm
        
        # Compute the new position
        move_vec = c_dir * MAX_SPEED
        new_pos = self.dot_pos + move_vec
        
        # Update the dot position if no collision occurs along the move
        if not line_collision(self.dot_pos, new_pos, self.obstacles):
            new_pos[0] = np.clip(new_pos[0], 0, FULL_VIEW_SIZE[0])
            new_pos[1] = np.clip(new_pos[1], 0, FULL_VIEW_SIZE[1])
            self.dot_pos = new_pos
        
        # Check if the dot is inside an obstacle
        collided = inside_obstacle(self.dot_pos, self.obstacles)
        info = {}
        
        if collided:
            reward = -2.0
            done = True
            info["terminal_reason"] = "collision"
        else:
            # Check if reached the goal
            if distance(self.dot_pos, self.goal_pos) <= (TARGET_RADIUS + DOT_RADIUS):
                reward = 5.0
                done = True
                info["terminal_reason"] = "goal_reached"
            else:
                reward = 0.0
                done = False
                info["terminal_reason"] = None
        
        # Check for timeout
        truncated = (self.step_count >= self.max_steps)
        if truncated and not done:
            info["terminal_reason"] = "timeout"
        
        # ------------------ Reward Shaping ------------------
        # Compute distance to goal
        d_goal = distance(self.dot_pos, self.goal_pos)
        
        # Compute distance to closest obstacle
        if self.obstacles:
            d_obs = min(distance(self.dot_pos, obs) for obs in self.obstacles)
        else:
            d_obs = 999999.0
        
        # Progress reward
        prev_dist = distance(self.dot_pos - move_vec, self.goal_pos)
        progress_reward = (prev_dist - d_goal) * 0.1
        
        # Get goal probability for the true goal
        if len(goal_probs) > self.true_goal_idx:
            p_true = goal_probs[self.true_goal_idx]
        else:
            p_true = 0.0
        
        # Get maximum probability
        p_max = np.max(goal_probs) if len(goal_probs) > 0 else 0.0
        
        # Determine optimal gamma based on context
        # Compute the contextual reward based on the environment situation
        if d_goal < self.goal_threshold and d_obs < self.obs_threshold:
            # Near both goal and obstacles: higher gamma for safety
            desired_gamma = 0.6
        elif d_goal < self.goal_threshold:
            # Near goal: lower gamma (more human control for fine adjustment)
            desired_gamma = 0.3
        elif d_obs < self.obs_threshold:
            # Near obstacle: higher gamma (more AI control for safety)
            desired_gamma = 0.8
        else:
            # Open space: gamma depends on goal certainty
            # Higher uncertainty = lower gamma (more human control)
            goal_entropy = entropy(goal_probs) if len(goal_probs) > 0 else np.log(len(self.goals))
            max_entropy = np.log(len(self.goals))
            uncertainty = goal_entropy / max_entropy if max_entropy > 0 else 0.5
            desired_gamma = 0.3 + 0.5 * (1 - uncertainty)  # Maps [0,1] uncertainty to [0.8,0.3] gamma
        
        # Compute reward components based on manuscript
        # -c_collision · 𝟙_collision: penalty for collisions (already handled above)
        # +c_near · γ · p_max · 𝟙_near: reward for being near the predicted target with appropriate assistance
        # -c_far · γ · 𝟙_far: penalty for high assistance when far from all targets
        # +c_progress · p_max · (d_{t-1} - d_t): reward for making progress toward the goal
        # -c_γ · γ^2: quadratic penalty to minimize intervention
        # +c_goal · log(p_true): reward for correctly identifying the true goal
        
        c_near = 0.2
        c_far = 0.1
        c_progress = 0.5
        c_gamma = 0.5
        c_goal = 0.3
        
        near_flag = d_goal < self.goal_threshold
        far_flag = d_goal > 2 * self.goal_threshold
        
        near_reward = c_near * gamma_val * p_max * float(near_flag)
        far_penalty = c_far * gamma_val * float(far_flag)
        progress_term = c_progress * p_max * progress_reward
        gamma_penalty = c_gamma * (gamma_val - desired_gamma) ** 2
        goal_reward = c_goal * np.log(p_true + 1e-6)  # Add small epsilon to avoid log(0)
        
        # Combine all reward components
        shaped_reward = near_reward - far_penalty + progress_term - gamma_penalty + goal_reward
        
        # Total reward
        reward += shaped_reward
        
        self.episode_reward += reward
        # ------------------ End Reward Shaping ------------------
        
        # Get observation
        obs = self._get_obs()
        
        return obs, float(reward), done, truncated, info
    
    def _get_obs(self):
        """
        Get the current observation.
        
        Returns:
            numpy.ndarray: Current observation
        """
        # Basic state information
        to_g = self.goal_pos - self.dot_pos
        d = math.hypot(to_g[0], to_g[1])
        dist_ratio = d / self.max_dist if self.max_dist > 1e-6 else 0.0
        
        # Compute directions
        w_dir = potential_field_dir(self.dot_pos, self.goal_pos, self.obstacles)
        noise = np.random.normal(0, NOISE_MAGNITUDE, size=2)
        h_dir = w_dir + noise
        hm = np.hypot(h_dir[0], h_dir[1])
        if hm > 1e-6:
            h_dir /= hm
        
        # Distance to closest obstacle
        if self.obstacles:
            min_obs_distance = min(distance(self.dot_pos, obs) for obs in self.obstacles)
        else:
            min_obs_distance = self.max_dist
        obs_dist_ratio = min_obs_distance / self.max_dist
        
        # Get goal probabilities and entropy
        goal_probs = self.bayesian_model.get_goal_probs()
        goal_ent = entropy(goal_probs) if len(goal_probs) > 0 else np.log(len(self.goals))
        
        # Combine all features
        base_obs = np.array([
            self.dot_pos[0], self.dot_pos[1],
            h_dir[0], h_dir[1],
            self.goal_pos[0], self.goal_pos[1],
            w_dir[0], w_dir[1],
            dist_ratio, obs_dist_ratio,
            goal_ent
        ], dtype=np.float32)
        
        # Append goal probabilities
        obs = np.concatenate([base_obs, goal_probs]).astype(np.float32)
        
        return obs
    
    def render(self):
        """Render the environment."""
        if not self.visualize or (pygame is None):
            return
        
        # Handle user events (e.g., window close)
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                return
        
        # Clear the screen
        self.window.fill(WHITE)
        
        # Draw obstacles
        for obs in self.obstacles:
            pygame.draw.circle(self.window, GRAY, (int(obs[0]), int(obs[1])), OBSTACLE_RADIUS)
        
        # Draw all potential targets
        for i, gpos in enumerate(self.goals):
            # Color intensity based on goal probability
            goal_probs = self.bayesian_model.get_goal_probs()
            if i < len(goal_probs):
                prob = goal_probs[i]
                # Interpolate color from yellow to green based on probability
                color = (int(255 * (1 - prob)), int(255), 0)
            else:
                color = YELLOW
            
            # Draw the target
            pygame.draw.circle(self.window, color, (int(gpos[0]), int(gpos[1])), TARGET_RADIUS)
            
            # Label the target
            font = pygame.font.SysFont(None, 24)
            label = font.render(str(i), True, BLACK)
            self.window.blit(label, (int(gpos[0]) - 5, int(gpos[1]) - 8))
        
        # Draw the true goal with a black outline
        pygame.draw.circle(self.window, BLACK, 
                          (int(self.goal_pos[0]), int(self.goal_pos[1])), 
                          TARGET_RADIUS+2, width=2)
        
        # Draw the agent
        pygame.draw.circle(self.window, BLACK, 
                          (int(self.dot_pos[0]), int(self.dot_pos[1])), 
                          DOT_RADIUS, width=2)
        
        # Display gamma value
        if hasattr(self, 'last_gamma'):
            font = pygame.font.SysFont(None, 30)
            gamma_text = font.render(f"γ: {self.last_gamma:.2f}", True, BLACK)
            self.window.blit(gamma_text, (10, 10))
        
        # Update the display
        pygame.display.flip()
        self.clock.tick(RENDER_FPS)
    
    def close(self):
        """Close the environment."""
        if self.visualize and pygame is not None:
            pygame.quit()
        super().close()

###############################################################################
# CURRICULUM CALLBACK
###############################################################################
class CurriculumCallback(BaseCallback):
    """
    Callback for curriculum learning during training.
    
    Progressively increases the difficulty of the environment:
    1. Basic goal-directed behavior
    2. Basic collision avoidance
    3. Challenging obstacle configurations
    4. Goal ambiguity with multiple potential targets
    5. Full complexity
    """
    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.stage = 1
        self.success_count = 0
        self.required_successes = 100
        self.max_stage = 5
        self.stage_rewards = []
        self.stage_lengths = []
        self.current_ep_reward = 0
        self.stage_start_time = time.time()
    
    def _on_training_start(self):
        """Set initial curriculum stage."""
        print(f"Starting curriculum stage {self.stage}: Basic goal-directed behavior")
        self.stage_start_time = time.time()
    
    def _on_step(self) -> bool:
        """Check if curriculum stage should be updated."""
        # Get the environment
        env = self.model.env.envs[0]
        
        # Track reward
        rewards = self.locals.get('rewards', [0])
        self.current_ep_reward += rewards[0]
        
        # Check episode termination
        dones = self.locals.get('dones', [False])
        infos = self.locals.get('infos', [{}])
        
        if dones[0]:
            self.stage_rewards.append(self.current_ep_reward)
            self.stage_lengths.append(self.locals.get('n_steps', 0))
            self.current_ep_reward = 0
            
            # Check success (reached goal without collision)
            info = infos[0]
            terminal_reason = info.get('terminal_reason', None)
            
            if terminal_reason == 'goal_reached':
                self.success_count += 1
            else:
                # Reset success counter if failure occurs
                self.success_count = max(0, self.success_count - 1)
            
            # Check for stage progression
            if self.success_count >= self.required_successes:
                self._progress_curriculum()
        
        return True
    
    def _progress_curriculum(self):
        """Progress to the next curriculum stage."""
        elapsed_time = time.time() - self.stage_start_time
        avg_reward = np.mean(self.stage_rewards[-100:]) if self.stage_rewards else 0
        
        print(f"\nCompleted curriculum stage {self.stage}:")
        print(f"  - Time taken: {elapsed_time:.2f} seconds")
        print(f"  - Average reward (last 100 episodes): {avg_reward:.2f}")
        
        self.stage += 1
        if self.stage > self.max_stage:
            print("\nCurriculum completed! Using full complexity environment.")
            return
        
        # Reset counters
        self.success_count = 0
        self.stage_rewards = []
        self.stage_lengths = []
        self.stage_start_time = time.time()
        
        # Update environment parameters based on curriculum stage
        env = self.model.env.envs[0]
        
        if self.stage == 2:
            print("\nProgressing to stage 2: Basic collision avoidance")
            # Increase obstacle count but keep goal ambiguity low
            env.random_seed_probability = 0.4
        
        elif self.stage == 3:
            print("\nProgressing to stage 3: Challenging obstacle configurations")
            # More complex obstacle patterns
            env.random_seed_probability = 0.6
        
        elif self.stage == 4:
            print("\nProgressing to stage 4: Goal ambiguity with multiple targets")
            # Introduce more goal ambiguity
            env.random_seed_probability = 0.8
        
        elif self.stage == 5:
            print("\nProgressing to stage 5: Full complexity")
            # Full complexity - random scenarios
            env.random_seed_probability = 1.0
    
    def on_training_end(self):
        """Log final curriculum statistics."""
        print("\nCurriculum learning completed.")
        print(f"Final stage reached: {self.stage}/{self.max_stage}")

###############################################################################
# RENDER CALLBACK FOR LIVE VIEW
###############################################################################
class RenderCallback(BaseCallback):
    """Simple callback to render the environment during training."""
    def __init__(self, render_freq=1, verbose=0):
        super().__init__(verbose)
        self.render_freq = render_freq
    
    def _on_step(self) -> bool:
        """Render the environment periodically."""
        if self.n_calls % self.render_freq == 0:
            self.model.env.envs[0].render()
        return True

###############################################################################
# BAYESIAN MODEL TRAINING FUNCTION
###############################################################################
def train_bayesian_model(dataset_path=None, epochs=5, visualize=False):
    """
    Pre-train the Bayesian goal inference model on a dataset.
    
    Args:
        dataset_path (str): Path to trajectory dataset
        epochs (int): Number of epochs for training
        visualize (bool): Whether to visualize the training process
        
    Returns:
        BayesianGoalInference: Trained model
    """
    print("Training Bayesian goal inference model...")
    
    # Create a model with default parameters
    model = BayesianGoalInference(beta=10.0, w_theta=0.7, w_d=0.3)
    
    # If no dataset is provided, create a synthetic one
    if dataset_path is None:
        print("No dataset provided. Creating synthetic data...")
        
        # Create a synthetic dataset
        n_trajectories = 100
        trajectories = []
        env = SharedAutonomyEnv(visualize=False)
        
        for i in range(n_trajectories):
            print(f"Generating trajectory {i+1}/{n_trajectories}", end="\r")
            
            # Reset environment
            obs, _ = env.reset()
            
            # Extract relevant information
            goal_pos = env.goal_pos
            true_goal_idx = env.true_goal_idx
            goals = env.goals
            
            # Generate a trajectory
            trajectory = {
                "positions": [env.dot_pos.copy()],
                "inputs": [],
                "goal_pos": goal_pos.copy(),
                "true_goal_idx": true_goal_idx,
                "goals": [g.copy() for g in goals],
            }
            
            # Simple policy to generate human-like trajectories
            done = False
            while not done:
                # Compute direction to goal with noise
                to_goal = goal_pos - env.dot_pos
                dist = np.linalg.norm(to_goal)
                if dist > 0:
                    direction = to_goal / dist
                else:
                    direction = np.zeros(2)
                
                # Add noise to direction
                noise_scale = 0.2
                noise = np.random.normal(0, noise_scale, size=2)
                noisy_direction = direction + noise
                
                # Normalize
                norm = np.linalg.norm(noisy_direction)
                if norm > 0:
                    noisy_direction = noisy_direction / norm
                
                # Store human input
                trajectory["inputs"].append(noisy_direction.copy())
                
                # Step environment with full human control (gamma=0)
                obs, _, done, _, _ = env.step([-1.0])  # -1.0 maps to gamma=0
                
                # Store position
                trajectory["positions"].append(env.dot_pos.copy())
                
                if len(trajectory["positions"]) > 300:  # Prevent infinite loops
                    break
            
            trajectories.append(trajectory)
        
        env.close()
        print("\nGenerated synthetic dataset with", len(trajectories), "trajectories")
        
        # Save dataset
        os.makedirs("datasets", exist_ok=True)
        with open("datasets/synthetic_trajectories.pkl", "wb") as f:
            pickle.dump(trajectories, f)
        print("Saved synthetic dataset to datasets/synthetic_trajectories.pkl")
        
        dataset = trajectories
    else:
        # Load real dataset
        try:
            with open(dataset_path, "rb") as f:
                dataset = pickle.load(f)
            print(f"Loaded dataset from {dataset_path} with {len(dataset)} trajectories")
        except:
            raise ValueError(f"Could not load dataset from {dataset_path}")
    
    # Perform grid search for hyperparameter optimization
    print("\nPerforming hyperparameter grid search...")
    best_params = None
    best_accuracy = 0.0
    
    # Define parameter search space
    beta_values = [5.0, 10.0, 15.0]
    w_theta_values = [0.5, 0.7, 0.9]
    decay_values = [0.8, 0.85, 0.9]
    
    for beta in beta_values:
        for w_theta in w_theta_values:
            for decay in decay_values:
                w_d = 1.0 - w_theta  # Ensure weights sum to 1
                
                # Set model parameters
                model.beta = beta
                model.w_theta = w_theta
                model.w_d = w_d
                model.decay_rate = decay
                
                # Evaluate on validation set
                total_acc = 0.0
                n_valid = min(30, len(dataset))  # Use a subset for validation
                
                for i in range(n_valid):
                    traj = dataset[i]
                    goals = traj["goals"]
                    true_idx = traj["true_goal_idx"]
                    
                    # Initialize model with trajectory goals
                    model.initialize_goals(goals)
                    
                    # Process trajectory
                    positions = traj["positions"]
                    inputs = traj["inputs"]
                    n_steps = min(len(inputs), len(positions) - 1)
                    n_correct = 0
                    
                    for t in range(n_steps):
                        pos = positions[t]
                        h_input = inputs[t]
                        
                        # Update model
                        goal_probs = model.update(pos, h_input)
                        
                        # Check if prediction is correct
                        if len(goal_probs) > 0:
                            pred_idx = np.argmax(goal_probs)
                            if pred_idx == true_idx:
                                n_correct += 1
                    
                    # Calculate accuracy for this trajectory
                    accuracy = n_correct / n_steps if n_steps > 0 else 0
                    total_acc += accuracy
                
                # Average accuracy across trajectories
                avg_acc = total_acc / n_valid
                
                print(f"  beta={beta}, w_theta={w_theta}, decay={decay}: accuracy={avg_acc:.3f}")
                
                # Update best parameters
                if avg_acc > best_accuracy:
                    best_accuracy = avg_acc
                    best_params = (beta, w_theta, 1.0-w_theta, decay)
    
    if best_params:
        print(f"\nBest parameters: beta={best_params[0]}, w_theta={best_params[1]}, w_d={best_params[2]}, decay={best_params[3]}")
        print(f"Best accuracy: {best_accuracy:.3f}")
        
        # Set model to best parameters
        model.beta = best_params[0]
        model.w_theta = best_params[1]
        model.w_d = best_params[2]
        model.decay_rate = best_params[3]
    
    # Train confidence calibration
    print("\nTraining confidence calibration...")
    confidences = []
    accuracies = []
    
    n_calib = min(50, len(dataset))
    for i in range(n_calib):
        traj = dataset[i]
        goals = traj["goals"]
        true_idx = traj["true_goal_idx"]
        
        # Initialize model with trajectory goals
        model.initialize_goals(goals)
        
        # Process trajectory
        positions = traj["positions"]
        inputs = traj["inputs"]
        n_steps = min(len(inputs), len(positions) - 1)
        
        for t in range(n_steps):
            pos = positions[t]
            h_input = inputs[t]
            
            # Update model
            goal_probs = model.update(pos, h_input)
            
            # Record confidence and accuracy
            if len(goal_probs) > 0:
                max_prob = np.max(goal_probs)
                pred_idx = np.argmax(goal_probs)
                is_correct = int(pred_idx == true_idx)
                
                confidences.append(max_prob)
                accuracies.append(is_correct)
    
    # Train calibrator
    model.train_calibrator(confidences, accuracies)
    
    # Save the trained model
    os.makedirs("models", exist_ok=True)
    with open("models/bayesian_model.pkl", "wb") as f:
        pickle.dump(model, f)
    print("Saved trained Bayesian model to models/bayesian_model.pkl")
    
    # Evaluate on test set
    print("\nEvaluating model on test set...")
    eval_results = {}
    completion_points = [0.25, 0.5, 0.75, 1.0]
    for cp in completion_points:
        eval_results[cp] = []
    
    n_test = min(100, len(dataset))
    for i in range(n_test):
        traj = dataset[i]
        goals = traj["goals"]
        true_idx = traj["true_goal_idx"]
        
        # Initialize model with trajectory goals
        model.initialize_goals(goals)
        
        # Process trajectory
        positions = traj["positions"]
        inputs = traj["inputs"]
        n_steps = min(len(inputs), len(positions) - 1)
        
        for cp in completion_points:
            step_idx = int(cp * n_steps) - 1
            if step_idx < 0:
                continue
            
            # Reset model
            model.reset()
            
            # Process up to completion point
            for t in range(step_idx + 1):
                pos = positions[t]
                h_input = inputs[t]
                goal_probs = model.update(pos, h_input)
            
            # Check prediction
            if len(goal_probs) > 0:
                pred_idx = np.argmax(goal_probs)
                is_correct = int(pred_idx == true_idx)
                eval_results[cp].append(is_correct)
    
    # Print evaluation results
    print("\nModel performance at different path completion percentages:")
    for cp in completion_points:
        if eval_results[cp]:
            accuracy = np.mean(eval_results[cp])
            print(f"  {int(cp*100)}% completion: {accuracy:.3f} accuracy")
    
    if visualize:
        # Visualize model performance on a random trajectory
        print("\nVisualizing model performance...")
        
        # Create a visualization environment
        vis_env = SharedAutonomyEnv(visualize=True)
        
        for _ in range(3):  # Show 3 trajectories
            # Reset environment
            obs, _ = vis_env.reset()
            
            done = False
            while not done:
                # Get noisy human input
                to_goal = vis_env.goal_pos - vis_env.dot_pos
                dist = np.linalg.norm(to_goal)
                if dist > 0:
                    direction = to_goal / dist
                else:
                    direction = np.zeros(2)
                
                # Add noise
                noise_scale = 0.2
                noise = np.random.normal(0, noise_scale, size=2)
                h_input = direction + noise
                
                # Normalize
                norm = np.linalg.norm(h_input)
                if norm > 0:
                    h_input = h_input / norm
                
                # Step environment with full human control
                obs, _, done, _, _ = vis_env.step([-1.0])  # -1.0 maps to gamma=0
                
                # Visualize
                vis_env.render()
                time.sleep(0.1)
        
        vis_env.close()
    
    return model

###############################################################################
# TRAINING FUNCTION
###############################################################################
def train_model(total_timesteps=500_000, visualize=False, use_joint_optimization=True,
               bayesian_model_path=None, load_pretrained=False):
    """
    Train the shared autonomy model.
    
    Args:
        total_timesteps (int): Total number of timesteps for training
        visualize (bool): Whether to visualize training
        use_joint_optimization (bool): Whether to use joint optimization
        bayesian_model_path (str): Path to pre-trained Bayesian model
        load_pretrained (bool): Whether to load a pre-trained model
        
    Returns:
        stable_baselines3.PPO: Trained model
    """
    print("Initializing shared autonomy environment...")
    env = SharedAutonomyEnv(visualize=visualize, use_joint_optimization=use_joint_optimization)
    env = DummyVecEnv([lambda: env])
    
    # Set up callbacks
    metrics_callback = EnhancedMetricsCallback()
    curriculum_callback = CurriculumCallback()
    callbacks = [metrics_callback, curriculum_callback]
    
    if visualize:
        render_callback = RenderCallback(render_freq=1)
        callbacks.append(render_callback)
    
    # Load pre-trained Bayesian model if provided
    if bayesian_model_path:
        try:
            with open(bayesian_model_path, "rb") as f:
                bayesian_model = pickle.load(f)
            print(f"Loaded Bayesian model from {bayesian_model_path}")
            
            # Set the model in the environment
            env.envs[0].bayesian_model = bayesian_model
        except:
            print(f"Could not load Bayesian model from {bayesian_model_path}")
    
    print("Initializing PPO model with dual-head architecture...")
    
    if load_pretrained:
        # Load pre-trained model
        try:
            model = PPO.load("models/shared_autonomy_model")
            model.set_env(env)
            print("Loaded pre-trained model from models/shared_autonomy_model")
        except:
            print("Could not load pre-trained model. Training from scratch.")
            model = create_new_model(env)
    else:
        model = create_new_model(env)
    
    print(f"Starting training for {total_timesteps} timesteps...")
    model.learn(total_timesteps=total_timesteps, callback=CallbackList(callbacks), log_interval=1)
    
    # Save the trained model
    os.makedirs("models", exist_ok=True)
    model_path = os.path.join("models", "shared_autonomy_model")
    model.save(model_path)
    print(f"Model saved to {model_path}")
    
    # Save metrics
    metrics_callback.save_metrics(save_dir="training_metrics")
    print("Metrics saved to 'training_metrics/'")
    
    env.close()
    return model

def create_new_model(env):
    """Create a new PPO model with the appropriate architecture."""
    # Create PPO model with custom network architecture
    model = PPO(
        "MlpPolicy",  # We'll use the default policy structure
        env,
        learning_rate=3e-4,
        n_steps=1024,
        batch_size=256,
        n_epochs=5,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        clip_range_vf=0.2,
        normalize_advantage=True,
        ent_coef=0.01,
        verbose=1,
        tensorboard_log="./tensorboard_logs/",
        policy_kwargs={
            "net_arch": [{"pi": [256, 256], "vf": [256, 256]}],
            "activation_fn": nn.ReLU,
            "ortho_init": True
        }
    )
    return model

###############################################################################
# EVALUATION FUNCTIONS
###############################################################################
def evaluate_model(model_path="models/shared_autonomy_model", n_episodes=50, visualize=True):
    """
    Evaluate a trained model.
    
    Args:
        model_path (str): Path to the trained model
        n_episodes (int): Number of episodes for evaluation
        visualize (bool): Whether to visualize evaluation
        
    Returns:
        dict: Evaluation metrics
    """
    print(f"Evaluating model {model_path}...")
    
    # Create evaluation environment
    env = SharedAutonomyEnv(visualize=visualize)
    
    try:
        # Load the model
        model = PPO.load(model_path)
        print(f"Loaded model from {model_path}")
    except:
        print(f"Could not load model from {model_path}")
        return None
    
    # Initialize metrics
    metrics = {
        "success_rate": 0.0,
        "collision_rate": 0.0,
        "timeout_rate": 0.0,
        "avg_reward": 0.0,
        "avg_steps": 0.0,
        "avg_gamma": 0.0,
        "goal_inference_accuracy": 0.0,
        "path_efficiency": 0.0,
        "completion_accuracy": {0.25: 0.0, 0.5: 0.0, 0.75: 0.0}
    }
    
    # Run evaluation episodes
    for episode in range(n_episodes):
        print(f"Episode {episode+1}/{n_episodes}", end="\r")
        
        # Reset environment
        obs, _ = env.reset()
        
        # Episode tracking
        done = False
        truncated = False
        episode_reward = 0.0
        episode_steps = 0
        episode_gammas = []
        trajectory = []
        goal_predictions = []
        
        # Track initial position and goal for path efficiency
        start_pos = env.dot_pos.copy()
        goal_pos = env.goal_pos.copy()
        
        # Track path completion
        completion_recorded = {0.25: False, 0.5: False, 0.75: False}
        completion_correct = {0.25: False, 0.5: False, 0.75: False}
        
        while not (done or truncated):
            # Predict action
            action, _ = model.predict(obs, deterministic=True)
            
            # Map action to gamma
            gamma = 0.5 * (action[0] + 1.0)
            episode_gammas.append(gamma)
            
            # Save for rendering
            env.last_gamma = gamma
            
            # Step environment
            obs, reward, done, truncated, info = env.step(action)
            
            # Update tracking
            episode_reward += reward
            episode_steps += 1
            trajectory.append(env.dot_pos.copy())
            
            # Track goal inference
            if hasattr(env.bayesian_model, "get_goal_probs"):
                probs = env.bayesian_model.get_goal_probs()
                if len(probs) > 0:
                    pred_goal = np.argmax(probs)
                    goal_predictions.append(pred_goal)
                    
                    # Check path completion accuracy
                    if len(trajectory) > 1:
                        # Calculate path completion
                        total_dist = distance(start_pos, goal_pos)
                        current_dist = distance(env.dot_pos, goal_pos)
                        completion = 1.0 - (current_dist / total_dist) if total_dist > 0 else 0.0
                        
                        # Check completion thresholds
                        for threshold in [0.25, 0.5, 0.75]:
                            if completion >= threshold and not completion_recorded[threshold]:
                                completion_recorded[threshold] = True
                                is_correct = (pred_goal == env.true_goal_idx)
                                completion_correct[threshold] = is_correct
            
            # Visualize
            if visualize:
                env.render()
                time.sleep(0.01)
        
        # Update metrics
        terminal_reason = info.get("terminal_reason", None)
        
        if terminal_reason == "goal_reached":
            metrics["success_rate"] += 1.0
        elif terminal_reason == "collision":
            metrics["collision_rate"] += 1.0
        elif terminal_reason == "timeout":
            metrics["timeout_rate"] += 1.0
        
        metrics["avg_reward"] += episode_reward
        metrics["avg_steps"] += episode_steps
        
        if episode_gammas:
            metrics["avg_gamma"] += np.mean(episode_gammas)
        
        # Calculate path efficiency (ratio of direct distance to actual path length)
        direct_distance = distance(start_pos, goal_pos)
        path_length = 0.0
        for i in range(1, len(trajectory)):
            path_length += distance(trajectory[i-1], trajectory[i])
        
        if path_length > 0:
            path_efficiency = direct_distance / path_length
            metrics["path_efficiency"] += path_efficiency
        
        # Goal inference accuracy
        if goal_predictions and hasattr(env, "true_goal_idx"):
            accuracy = np.mean([1.0 if p == env.true_goal_idx else 0.0 for p in goal_predictions])
            metrics["goal_inference_accuracy"] += accuracy
        
        # Path completion accuracy
        for threshold in [0.25, 0.5, 0.75]:
            if completion_recorded[threshold]:
                metrics["completion_accuracy"][threshold] += float(completion_correct[threshold])
    
    # Normalize metrics
    metrics["success_rate"] /= n_episodes
    metrics["collision_rate"] /= n_episodes
    metrics["timeout_rate"] /= n_episodes
    metrics["avg_reward"] /= n_episodes
    metrics["avg_steps"] /= n_episodes
    metrics["avg_gamma"] /= n_episodes
    metrics["goal_inference_accuracy"] /= n_episodes
    metrics["path_efficiency"] /= n_episodes
    
    for threshold in [0.25, 0.5, 0.75]:
        recorded_count = sum(1 for e in range(n_episodes) if completion_recorded[threshold])
        if recorded_count > 0:
            metrics["completion_accuracy"][threshold] /= recorded_count
    
    # Print evaluation results
    print("\nEvaluation Results:")
    print(f"  Success Rate: {metrics['success_rate']:.3f}")
    print(f"  Collision Rate: {metrics['collision_rate']:.3f}")
    print(f"  Timeout Rate: {metrics['timeout_rate']:.3f}")
    print(f"  Average Reward: {metrics['avg_reward']:.3f}")
    print(f"  Average Steps: {metrics['avg_steps']:.1f}")
    print(f"  Average Gamma: {metrics['avg_gamma']:.3f}")
    print(f"  Goal Inference Accuracy: {metrics['goal_inference_accuracy']:.3f}")
    print(f"  Path Efficiency: {metrics['path_efficiency']:.3f}")
    print("  Goal Accuracy at Path Completion:")
    for threshold in [0.25, 0.5, 0.75]:
        print(f"    {int(threshold*100)}%: {metrics['completion_accuracy'][threshold]:.3f}")
    
    # Save metrics to file
    os.makedirs("evaluation_results", exist_ok=True)
    with open("evaluation_results/metrics.json", "w") as f:
        import json
        json.dump(metrics, f, indent=2)
    print("Saved evaluation metrics to evaluation_results/metrics.json")
    
    env.close()
    return metrics

def compare_models(baseline_path="models/baseline_model", 
                  adaptive_path="models/shared_autonomy_model",
                  n_episodes=30, visualize=True):
    """
    Compare baseline and adaptive models.
    
    Args:
        baseline_path (str): Path to the baseline model
        adaptive_path (str): Path to the adaptive model
        n_episodes (int): Number of episodes for evaluation
        visualize (bool): Whether to visualize comparison
    """
    print("Comparing models...")
    
    # Evaluate baseline model
    print("\nEvaluating baseline model:")
    baseline_metrics = evaluate_model(baseline_path, n_episodes, visualize)
    
    # Evaluate adaptive model
    print("\nEvaluating adaptive model:")
    adaptive_metrics = evaluate_model(adaptive_path, n_episodes, visualize)
    
    if baseline_metrics is None or adaptive_metrics is None:
        print("Could not compare models.")
        return
    
    # Calculate improvements
    improvements = {}
    for key in baseline_metrics:
        if key == "completion_accuracy":
            improvements[key] = {}
            for threshold in baseline_metrics[key]:
                baseline_val = baseline_metrics[key][threshold]
                adaptive_val = adaptive_metrics[key][threshold]
                if baseline_val > 0:
                    rel_improvement = (adaptive_val - baseline_val) / baseline_val
                    improvements[key][threshold] = rel_improvement
        else:
            baseline_val = baseline_metrics[key]
            adaptive_val = adaptive_metrics[key]
            if baseline_val > 0:
                rel_improvement = (adaptive_val - baseline_val) / baseline_val
                improvements[key] = rel_improvement
    
    # Print comparison
    print("\nModel Comparison (Relative Improvement):")
    for key in improvements:
        if key == "completion_accuracy":
            print("  Goal Accuracy at Path Completion:")
            for threshold in improvements[key]:
                print(f"    {int(threshold*100)}%: {improvements[key][threshold]:.2%}")
        else:
            # For metrics where higher is better
            if key in ["success_rate", "avg_reward", "goal_inference_accuracy", "path_efficiency"]:
                print(f"  {key}: {improvements[key]:.2%}")
            # For metrics where lower is better (show negative improvement as positive)
            elif key in ["collision_rate", "timeout_rate"]:
                print(f"  {key}: {-improvements[key]:.2%}")
            else:
                print(f"  {key}: {improvements[key]:.2%}")
    
    # Plot comparison
    if visualize:
        plt.figure(figsize=(12, 8))
        
        # Metrics to compare (excluding completion_accuracy)
        metrics = ["success_rate", "collision_rate", "goal_inference_accuracy", "path_efficiency"]
        
        x = np.arange(len(metrics))
        width = 0.35
        
        baseline_vals = [baseline_metrics[m] for m in metrics]
        adaptive_vals = [adaptive_metrics[m] for m in metrics]
        
        plt.bar(x - width/2, baseline_vals, width, label='Baseline')
        plt.bar(x + width/2, adaptive_vals, width, label='Adaptive')
        
        plt.xlabel('Metrics')
        plt.ylabel('Values')
        plt.title('Comparison of Baseline vs. Adaptive Models')
        plt.xticks(x, metrics)
        plt.legend()
        plt.grid(True, axis='y')
        
        plt.tight_layout()
        os.makedirs("evaluation_results", exist_ok=True)
        plt.savefig("evaluation_results/model_comparison.png")
        plt.close()
        
        # Plot completion accuracy
        plt.figure(figsize=(8, 6))
        
        thresholds = [0.25, 0.5, 0.75]
        x = np.arange(len(thresholds))
        width = 0.35
        
        baseline_vals = [baseline_metrics["completion_accuracy"][t] for t in thresholds]
        adaptive_vals = [adaptive_metrics["completion_accuracy"][t] for t in thresholds]
        
        plt.bar(x - width/2, baseline_vals, width, label='Baseline')
        plt.bar(x + width/2, adaptive_vals, width, label='Adaptive')
        
        plt.xlabel('Path Completion')
        plt.ylabel('Goal Inference Accuracy')
        plt.title('Goal Inference Accuracy at Different Path Completions')
        plt.xticks(x, [f"{int(t*100)}%" for t in thresholds])
        plt.ylim(0, 1.1)
        plt.legend()
        plt.grid(True, axis='y')
        
        plt.tight_layout()
        plt.savefig("evaluation_results/completion_accuracy_comparison.png")
        plt.close()
    
    print("Saved comparison plots to evaluation_results/")

###############################################################################
# SEQUENTIAL BASELINES EVALUATION
###############################################################################
def evaluate_sequential_approach():
    """
    Evaluate a sequential approach (separate goal inference and control).
    This implements the sequential baseline described in the manuscript
    for validating the joint optimization theorem.
    """
    print("Evaluating sequential approach...")
    
    # Create environment
    env = SharedAutonomyEnv(visualize=False, use_joint_optimization=False)
    
    # Train sequential approach
    # Step 1: Use pre-trained Bayesian model
    bayesian_model = None
    if os.path.exists("models/bayesian_model.pkl"):
        try:
            with open("models/bayesian_model.pkl", "rb") as f:
                bayesian_model = pickle.load(f)
            print("Loaded pre-trained Bayesian model")
        except:
            print("Could not load pre-trained Bayesian model")
    
    if bayesian_model is None:
        # Train Bayesian model
        bayesian_model = train_bayesian_model(epochs=3)
    
    # Step 2: Train PPO with fixed Bayesian model
    env.bayesian_model = bayesian_model
    
    # Create sequential PPO model
    sequential_model = PPO(
        "MlpPolicy",
        DummyVecEnv([lambda: env]),
        learning_rate=3e-4,
        n_steps=1024,
        batch_size=256,
        n_epochs=5,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        verbose=1,
        tensorboard_log="./tensorboard_logs/sequential/",
    )
    
    # Train
    print("Training sequential PPO model...")
    sequential_model.learn(total_timesteps=300000)
    
    # Save model
    os.makedirs("models", exist_ok=True)
    sequential_model.save("models/sequential_model")
    print("Saved sequential model to models/sequential_model")
    
    # Evaluate
    metrics = evaluate_model("models/sequential_model", n_episodes=30, visualize=True)
    
    return metrics

###############################################################################
# ABLATION STUDIES
###############################################################################
def run_ablation_studies(n_episodes=20):
    """
    Run ablation studies to evaluate each component's contribution.
    
    This evaluates:
    1. MAP Selection: Using only the most likely goal
    2. Fixed Threshold: Using fixed blending parameter thresholds
    3. Independent Networks: Separate networks for goal inference and control
    4. Partial Info: Limited goal information for the policy
    5. Joint Optimization (Full model)
    """
    print("Running ablation studies...")
    
    # Define ablation conditions
    ablations = [
        ("MAP Selection", "models/map_selection_model"),
        ("Fixed Threshold", "models/fixed_threshold_model"),
        ("Independent Networks", "models/independent_networks_model"),
        ("Partial Info", "models/partial_info_model"),
        ("Joint Optimization", "models/shared_autonomy_model")
    ]
    
    # Track metrics
    ablation_metrics = {}
    
    for name, model_path in ablations:
        print(f"\nEvaluating {name}:")
        
        # Train model if it doesn't exist (simplified here)
        if not os.path.exists(f"{model_path}.zip") and name != "Joint Optimization":
            print(f"Training {name} model...")
            
            # Create appropriate environment based on ablation
            if name == "MAP Selection":
                # Modified environment that only uses the most likely goal
                env = SharedAutonomyEnv(visualize=False)
                env.use_joint_optimization = False
                
                # Override compute_expert_recommendation to use only MAP
                original_recommend = env.bayesian_model.compute_expert_recommendation
                def map_recommendation(self, agent_pos, obstacles):
                    most_likely_goal, _ = self.get_most_likely_goal()
                    if most_likely_goal is None:
                        return np.zeros(2, dtype=np.float32)
                    return potential_field_dir(agent_pos, most_likely_goal, obstacles)
                
                env.bayesian_model.compute_expert_recommendation = map_recommendation.__get__(
                    env.bayesian_model, type(env.bayesian_model))
                
            elif name == "Fixed Threshold":
                # Modified environment that uses fixed gamma thresholds
                env = SharedAutonomyEnv(visualize=False)
                
                # Override step method to use fixed thresholds
                original_step = env.step
                def fixed_threshold_step(self, action):
                    # Get goal probability
                    goal_probs = self.bayesian_model.get_goal_probs()
                    max_prob = np.max(goal_probs) if len(goal_probs) > 0 else 0.0
                    
                    # Simple fixed threshold policy:
                    # High certainty (>0.8): gamma = 0.8
                    # Medium certainty (0.5-0.8): gamma = 0.5
                    # Low certainty (<0.5): gamma = 0.2
                    if max_prob > 0.8:
                        gamma = 0.8
                    elif max_prob > 0.5:
                        gamma = 0.5
                    else:
                        gamma = 0.2
                    
                    # Convert to action space
                    action_fixed = np.array([2.0 * gamma - 1.0])
                    
                    # Call original step with fixed action
                    return original_step(self, action_fixed)
                
                env.step = fixed_threshold_step.__get__(env, type(env))
                
            elif name == "Independent Networks":
                # Use standard environment but with separate networks
                env = SharedAutonomyEnv(visualize=False)
                env.use_joint_optimization = False
                
            elif name == "Partial Info":
                # Modified environment with limited goal information
                env = SharedAutonomyEnv(visualize=False)
                
                # Override _get_obs to provide only maximum goal probability
                original_get_obs = env._get_obs
                def limited_obs(self):
                    # Get full observation
                    full_obs = original_get_obs(self)
                    
                    # Extract basic features (first 11 elements)
                    basic_features = full_obs[:11]
                    
                    # Replace goal probabilities with just the maximum probability
                    goal_probs = self.bayesian_model.get_goal_probs()
                    max_prob = np.max(goal_probs) if len(goal_probs) > 0 else 0.0
                    
                    # Return limited observation
                    return np.append(basic_features, [max_prob])
                
                env._get_obs = limited_obs.__get__(env, type(env))
                
                # Update observation space
                env.observation_space = spaces.Box(
                    low=np.array([0, 0, -1, -1, 0, 0, -1, -1, 0, 0, 0, 0], dtype=np.float32),
                    high=np.array([FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1, 
                                  FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1, 
                                  1, 1, np.log(8), 1], dtype=np.float32),
                    shape=(12,), dtype=np.float32
                )
            
            # Create model
            model = PPO(
                "MlpPolicy",
                DummyVecEnv([lambda: env]),
                learning_rate=3e-4,
                n_steps=1024,
                batch_size=256,
                n_epochs=5,
                gamma=0.99,
                gae_lambda=0.95,
                clip_range=0.2,
                verbose=1,
                tensorboard_log=f"./tensorboard_logs/{name.lower().replace(' ', '_')}/",
            )
            
            # Train model (with simplified settings for ablation)
            model.learn(total_timesteps=200000)
            
            # Save model
            model.save(model_path)
            print(f"Saved {name} model to {model_path}")
        
        # Evaluate model
        metrics = evaluate_model(model_path, n_episodes=n_episodes, visualize=True)
        
        if metrics is not None:
            ablation_metrics[name] = metrics
    
    # Compare ablation results
    print("\nAblation Study Results:")
    
    # Success Rate
    print("\nSuccess Rate:")
    for name in ablation_metrics:
        print(f"  {name}: {ablation_metrics[name]['success_rate']:.3f}")
    
    # Goal Inference Accuracy
    print("\nGoal Inference Accuracy:")
    for name in ablation_metrics:
        print(f"  {name}: {ablation_metrics[name]['goal_inference_accuracy']:.3f}")
    
    # Path Efficiency
    print("\nPath Efficiency:")
    for name in ablation_metrics:
        print(f"  {name}: {ablation_metrics[name]['path_efficiency']:.3f}")
    
    # Create comparison plot
    if len(ablation_metrics) > 1:
        plt.figure(figsize=(14, 8))
        
        metrics_to_plot = ["success_rate", "collision_rate", "goal_inference_accuracy", "path_efficiency"]
        x = np.arange(len(metrics_to_plot))
        width = 0.8 / len(ablation_metrics)
        
        for i, name in enumerate(ablation_metrics):
            values = [ablation_metrics[name][m] for m in metrics_to_plot]
            plt.bar(x + i*width - 0.4 + width/2, values, width, label=name)
        
        plt.xlabel('Metrics')
        plt.ylabel('Values')
        plt.title('Ablation Study Results')
        plt.xticks(x, metrics_to_plot)
        plt.legend()
        plt.grid(True, axis='y')
        
        plt.tight_layout()
        os.makedirs("evaluation_results", exist_ok=True)
        plt.savefig("evaluation_results/ablation_study.png")
        plt.close()
        
        print("Saved ablation study plot to evaluation_results/ablation_study.png")
    
    return ablation_metrics

###############################################################################
# GENERATE GAMMA HEATMAP
###############################################################################
def generate_gamma_heatmap(model_path="models/shared_autonomy_model", resolution=20):
    """
    Generate a heatmap showing gamma values across the environment space.
    
    Args:
        model_path (str): Path to the trained model
        resolution (int): Resolution of the heatmap grid
    """
    print("Generating gamma heatmap...")
    
    # Create environment
    env = SharedAutonomyEnv(visualize=False)
    
    try:
        # Load model
        model = PPO.load(model_path)
        print(f"Loaded model from {model_path}")
    except:
        print(f"Could not load model from {model_path}")
        return
    
    # Create grid
    x_grid = np.linspace(0, FULL_VIEW_SIZE[0], resolution)
    y_grid = np.linspace(0, FULL_VIEW_SIZE[1], resolution)
    gamma_values = np.zeros((resolution, resolution))
    
    # Reset environment to get obstacles and goals
    obs, _ = env.reset()
    
    # Cache some data
    goal_pos = env.goal_pos
    obstacles = env.obstacles
    original_pos = env.dot_pos.copy()
    
    # Compute gamma values at each grid point
    for i, x in enumerate(x_grid):
        for j, y in enumerate(y_grid):
            # Temporarily set agent position
            env.dot_pos = np.array([x, y], dtype=np.float32)
            
            # Get observation at this position
            obs = env._get_obs()
            
            # Predict action
            action, _ = model.predict(obs, deterministic=True)
            
            # Map action to gamma
            gamma = 0.5 * (action[0] + 1.0)
            
            # Store gamma value
            gamma_values[j, i] = gamma  # Note: j, i for correct orientation
    
    # Restore original position
    env.dot_pos = original_pos
    
    # Create heatmap
    plt.figure(figsize=(12, 8))
    plt.imshow(gamma_values, origin='lower', extent=[0, FULL_VIEW_SIZE[0], 0, FULL_VIEW_SIZE[1]], 
              cmap='viridis', vmin=0, vmax=1)
    plt.colorbar(label='Gamma')
    
    # Plot obstacles and goals
    for obs_pos in obstacles:
        circle = plt.Circle((obs_pos[0], obs_pos[1]), OBSTACLE_RADIUS, color='gray')
        plt.gcf().gca().add_artist(circle)
    
    for i, g_pos in enumerate(env.goals):
        circle = plt.Circle((g_pos[0], g_pos[1]), TARGET_RADIUS, color='yellow')
        plt.gcf().gca().add_artist(circle)
    
    # Mark true goal
    circle = plt.Circle((goal_pos[0], goal_pos[1]), TARGET_RADIUS+2, color='black', fill=False)
    plt.gcf().gca().add_artist(circle)
    
    plt.title('Gamma Values Across Environment')
    plt.xlabel('X Position')
    plt.ylabel('Y Position')
    plt.grid(False)
    
    # Save heatmap
    os.makedirs("evaluation_results", exist_ok=True)
    plt.savefig("evaluation_results/gamma_heatmap.png")
    plt.close()
    
    print("Saved gamma heatmap to evaluation_results/gamma_heatmap.png")

###############################################################################
# APPLY TO ROBOTIC ARM DOMAIN
###############################################################################
def apply_to_robotic_arm():
    """
    Demonstrate domain transfer to a simulated robotic arm environment.
    This is a simplified demonstration of the transfer described in the paper.
    """
    print("This function would implement the robotic arm domain transfer.")
    print("Currently a placeholder for the concept described in the manuscript.")
    
    # This would require additional robotics libraries and a more complex environment
    # The key elements would be:
    # 1. Create a robotic arm environment (e.g., using PyBullet, MuJoCo, or Isaac Gym)
    # 2. Adapt the observation space to include joint angles, end-effector positions, etc.
    # 3. Modify the reward function for robotic manipulation tasks
    # 4. Adapt the Bayesian goal inference for 3D space
    # 5. Re-train or fine-tune the model for the new domain
    
    print("Refer to the manuscript for details on how this transfer would be implemented.")

###############################################################################
# SEQUENTIAL EXPERIMENT RUNNER - USE THIS IN JUPYTER NOTEBOOKS
###############################################################################
def run_full_experiment(visualize=False, timesteps=300000):
    """
    Run the complete experiment sequence:
    1. Train Bayesian model
    2. Train full adaptive shared autonomy model 
    3. Train baseline (sequential) model
    4. Run ablation studies
    5. Generate evaluation visualizations
    
    Args:
        visualize (bool): Whether to visualize during training/evaluation
        timesteps (int): Number of timesteps for training
    """
    print("="*70)
    print("CONTEXT-ADAPTIVE SHARED AUTONOMY EXPERIMENT")
    print("="*70)
    
    # Step 1: Train Bayesian model
    print("\n[1/5] Training Bayesian goal inference model...")
    bayesian_model = train_bayesian_model(epochs=3, visualize=visualize)
    
    # Step 2: Train adaptive model with joint optimization
    print("\n[2/5] Training adaptive shared autonomy model...")
    adaptive_model = train_model(
        total_timesteps=timesteps, 
        visualize=visualize,
        use_joint_optimization=True,
        bayesian_model_path="models/bayesian_model.pkl"
    )
    
    # Step 3: Train baseline model (sequential approach)
    print("\n[3/5] Training baseline sequential model...")
    baseline_model = evaluate_sequential_approach()
    
    # Step 4: Run ablation studies
    print("\n[4/5] Running ablation studies...")
    ablation_results = run_ablation_studies(n_episodes=20)
    
    # Step 5: Generate visualizations
    print("\n[5/5] Generating comparative visualizations...")
    # Compare models
    compare_models(
        baseline_path="models/sequential_model",
        adaptive_path="models/shared_autonomy_model",
        n_episodes=30,
        visualize=True
    )
    
    # Generate gamma heatmap
    generate_gamma_heatmap("models/shared_autonomy_model", resolution=20)
    
    print("\nExperiment complete! Results are saved in:")
    print("- training_metrics/ (training curves)")
    print("- models/ (trained models)")
    print("- evaluation_results/ (comparisons and heatmaps)")
    
    return {
        "bayesian_model": bayesian_model,
        "adaptive_model": adaptive_model,
        "baseline_model": baseline_model,
        "ablation_results": ablation_results
    }

###############################################################################
# MAIN - DON'T USE THIS IN JUPYTER NOTEBOOKS (use run_full_experiment instead)
###############################################################################
###############################################################################
# MAIN - WORKS IN BOTH SCRIPT AND JUPYTER NOTEBOOK ENVIRONMENTS
###############################################################################
if __name__ == "__main__":
    # Check if running in Jupyter
    try:
        # This will only exist in IPython/Jupyter environments
        get_ipython
        in_jupyter = True
    except NameError:
        in_jupyter = False
    
    if in_jupyter:
        print("Running in Jupyter environment")
        print("To run the experiment, use the following in a new cell:")
        print("results = run_full_experiment(visualize=False, timesteps=50000)")
    else:
        # Only use argument parser in script mode
        import argparse
        
        parser = argparse.ArgumentParser(description="Context-Adaptive Shared Autonomy")
        parser.add_argument("--train", action="store_true", help="Train the model")
        parser.add_argument("--evaluate", action="store_true", help="Evaluate the model")
        parser.add_argument("--compare", action="store_true", help="Compare models")
        parser.add_argument("--ablation", action="store_true", help="Run ablation studies")
        parser.add_argument("--heatmap", action="store_true", help="Generate gamma heatmap")
        parser.add_argument("--timesteps", type=int, default=500000, help="Training timesteps")
        parser.add_argument("--visualize", action="store_true", help="Visualize training/evaluation")
        parser.add_argument("--ui", action="store_true", help="Launch experiment UI")
        
        args = parser.parse_args()
        
        # Run the full experiment
        run_full_experiment(visualize=args.visualize, timesteps=args.timesteps)

Running in Jupyter environment
To run the experiment, use the following in a new cell:
results = run_full_experiment(visualize=False, timesteps=50000)


In [5]:
# Run the complete experiment with your desired settings
results = run_full_experiment(visualize=False, timesteps=50000)

CONTEXT-ADAPTIVE SHARED AUTONOMY EXPERIMENT

[1/5] Training Bayesian goal inference model...
Training Bayesian goal inference model...
No dataset provided. Creating synthetic data...
Generating trajectory 100/100
Generated synthetic dataset with 100 trajectories
Saved synthetic dataset to datasets/synthetic_trajectories.pkl

Performing hyperparameter grid search...
  beta=5.0, w_theta=0.5, decay=0.8: accuracy=0.781
  beta=5.0, w_theta=0.5, decay=0.85: accuracy=0.786
  beta=5.0, w_theta=0.5, decay=0.9: accuracy=0.793
  beta=5.0, w_theta=0.7, decay=0.8: accuracy=0.895
  beta=5.0, w_theta=0.7, decay=0.85: accuracy=0.896
  beta=5.0, w_theta=0.7, decay=0.9: accuracy=0.896
  beta=5.0, w_theta=0.9, decay=0.8: accuracy=0.972
  beta=5.0, w_theta=0.9, decay=0.85: accuracy=0.973
  beta=5.0, w_theta=0.9, decay=0.9: accuracy=0.979
  beta=10.0, w_theta=0.5, decay=0.8: accuracy=0.739
  beta=10.0, w_theta=0.5, decay=0.85: accuracy=0.743
  beta=10.0, w_theta=0.5, decay=0.9: accuracy=0.753
  beta=10.0, 

AttributeError: 'SharedAutonomyEnv' object has no attribute 'observation_space'