In [1]:
import os
import math
import random
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import entropy
from sklearn.isotonic import IsotonicRegression
import pickle
import gymnasium as gym
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import DummyVecEnv

import torch
import torch.nn as nn

# Optional visualization
try:
    import pygame
except ImportError:
    pygame = None

###############################################################################
# CONSTANTS & UTILS
###############################################################################
FULL_VIEW_SIZE = (800, 600)
DOT_RADIUS = 15
TARGET_RADIUS = 10
OBSTACLE_RADIUS = 10
COLLISION_BUFFER = 5
MAX_SPEED = 3
NOISE_MAGNITUDE = 2.5
RENDER_FPS = 30

START_POS = np.array([FULL_VIEW_SIZE[0]//2, FULL_VIEW_SIZE[1]//2], dtype=np.float32)

# Colors
WHITE = (255, 255, 255)
GRAY = (128, 128, 128)
YELLOW = (255, 255, 0)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)

def distance(a, b):
    """Calculate Euclidean distance between two points."""
    return math.hypot(a[0] - b[0], a[1] - b[1])

def check_line_collision(start, end, center, radius):
    """Check if a line segment intersects with a circle."""
    dx = end[0] - start[0]
    dy = end[1] - start[1]
    fx = center[0] - start[0]
    fy = center[1] - start[1]
    l2 = dx*dx + dy*dy
    if l2 < 1e-9:
        return distance(start, center) <= radius
    t = max(0, min(1, (fx*dx + fy*dy) / l2))
    px = start[0] + t*dx
    py = start[1] + t*dy
    return distance((px, py), center) <= radius

def line_collision(pos, new_pos, obstacles):
    """Check if a line segment collides with any obstacle."""
    for obs in obstacles:
        if check_line_collision(pos, new_pos, obs, OBSTACLE_RADIUS + COLLISION_BUFFER):
            return True
    return False

def inside_obstacle(pos, obstacles):
    """Check if a point is inside any obstacle."""
    for obs in obstacles:
        if distance(pos, obs) <= (OBSTACLE_RADIUS + DOT_RADIUS):
            return True
    return False

def potential_field_dir(pos, goal, obstacles):
    """Get a normalized direction vector using potential field approach."""
    # Attractive force toward goal
    gx = goal[0] - pos[0]
    gy = goal[1] - pos[1]
    dg = math.hypot(gx, gy)
    if dg < 1e-6:
        return np.zeros(2, dtype=np.float32)
    att = np.array([gx / dg, gy / dg], dtype=np.float32)

    # Repulsive force from obstacles
    repulse_x = 0.0
    repulse_y = 0.0
    repulsion_radius = 50.0
    repulsion_gain = 30000.0

    for obs in obstacles:
        dx = pos[0] - obs[0]
        dy = pos[1] - obs[1]
        dobs = math.hypot(dx, dy)
        if dobs < 1e-9:
            continue
        if dobs < repulsion_radius:
            pushx = dx / dobs
            pushy = dy / dobs
            strength = repulsion_gain / (dobs**2)
            repulse_x += pushx * strength
            repulse_y += pushy * strength

    # Combined direction
    px = att[0] + repulse_x
    py = att[1] + repulse_y
    mg = math.hypot(px, py)
    if mg < 1e-9:
        return np.zeros(2, dtype=np.float32)
    return np.array([px / mg, py / mg], dtype=np.float32)

###############################################################################
# BAYESIAN GOAL INFERENCE MODEL
###############################################################################
class BayesianGoalInference:
    """Bayesian goal inference model that infers goal from human inputs."""
    
    def __init__(self, beta=10.0, w_theta=0.7, w_d=0.3, decay_rate=0.85):
        """
        Initialize the Bayesian goal inference model.
        
        Args:
            beta: Rationality parameter for human behavior
            w_theta: Weight for angular deviation cost
            w_d: Weight for distance deviation cost
            decay_rate: Temporal smoothing parameter
        """
        self.beta = beta
        self.w_theta = w_theta
        self.w_d = w_d
        self.decay_rate = decay_rate
        self.goals = []
        self.priors = None
        self.goal_probs = None
        self.calibrator = None
        self.history = []
        self.max_hist_len = 30
    
    def initialize_goals(self, goals):
        """Set up potential goals and initialize uniform priors."""
        self.goals = goals
        n_goals = len(goals)
        # Uniform prior
        self.priors = np.ones(n_goals) / n_goals
        self.goal_probs = self.priors.copy()
        self.history = []
    
    def load_calibrator(self, calibrator_path):
        """Load a pre-trained confidence calibrator."""
        try:
            with open(calibrator_path, 'rb') as f:
                self.calibrator = pickle.load(f)
            print(f"Loaded calibrator from {calibrator_path}")
        except:
            print(f"Could not load calibrator from {calibrator_path}")
            self.calibrator = None
    
    def train_calibrator(self, confidences, accuracies):
        """Train a confidence calibrator using isotonic regression."""
        if len(confidences) < 10:
            print("Not enough data to train calibrator")
            return
        
        self.calibrator = IsotonicRegression(out_of_bounds='clip')
        self.calibrator.fit(confidences, accuracies)
        
        # Save the calibrator
        os.makedirs("models", exist_ok=True)
        with open("models/calibrator.pkl", 'wb') as f:
            pickle.dump(self.calibrator, f)
        print("Calibrator trained and saved")
    
    def compute_cost(self, human_input, agent_pos, goal_pos):
        """Compute cost of human input based on deviation from optimal."""
        # Optimal direction to goal
        goal_dir = goal_pos - agent_pos
        dist = np.linalg.norm(goal_dir)
        if dist < 1e-6:
            return 0.0
        
        goal_dir = goal_dir / dist  # Normalize
        
        # Angular deviation (radians)
        dot_product = np.clip(np.dot(human_input, goal_dir), -1.0, 1.0)
        theta_dev = abs(np.arccos(dot_product))
        
        # Distance magnitude deviation
        h_magnitude = np.linalg.norm(human_input)
        
        # Optimal magnitude (slow down when approaching target)
        min_dist = TARGET_RADIUS + DOT_RADIUS
        if dist < min_dist:
            opt_magnitude = 0.0
        else:
            dist_factor = min(1.0, (dist - min_dist) / 200)
            opt_magnitude = dist_factor * 1.0
        
        d_dev = abs(1.0 - h_magnitude / max(opt_magnitude, 1e-6))
        
        # Combined cost
        cost = self.w_theta * theta_dev + self.w_d * d_dev
        return cost
    
    def update(self, agent_pos, human_input, obstacles=None):
        """Update goal probabilities based on observed human input."""
        if len(self.goals) == 0:
            return np.array([])
        
        # Calculate likelihoods for each goal
        likelihoods = np.zeros(len(self.goals))
        costs = np.zeros(len(self.goals))
        
        for i, goal in enumerate(self.goals):
            cost = self.compute_cost(human_input, agent_pos, goal)
            costs[i] = cost
            # Noisy rational model: P(action|goal) ∝ exp(-β * cost)
            likelihoods[i] = np.exp(-self.beta * cost)
        
        # Normalize likelihoods
        if np.sum(likelihoods) > 0:
            likelihoods = likelihoods / np.sum(likelihoods)
        else:
            likelihoods = np.ones_like(likelihoods) / len(likelihoods)
        
        # Bayesian update: P(goal|action) ∝ P(action|goal) * P(goal)
        raw_posteriors = likelihoods * self.goal_probs
        
        # Normalize
        if np.sum(raw_posteriors) > 0:
            posteriors = raw_posteriors / np.sum(raw_posteriors)
        else:
            posteriors = self.priors.copy()
        
        # Apply temporal smoothing
        self.goal_probs = self.decay_rate * self.goal_probs + (1 - self.decay_rate) * posteriors
        
        # Normalize again after smoothing
        if np.sum(self.goal_probs) > 0:
            self.goal_probs = self.goal_probs / np.sum(self.goal_probs)
        
        # Apply calibration if available
        if self.calibrator is not None:
            max_prob = np.max(self.goal_probs)
            calibrated_max = self.calibrator.predict([max_prob])[0]
            
            # Adjust probabilities proportionally
            if max_prob > 0:
                scale_factor = calibrated_max / max_prob
                self.goal_probs = self.goal_probs * scale_factor
                remainder = 1.0 - np.sum(self.goal_probs)
                if remainder > 0:
                    # Distribute remainder proportionally
                    indices = np.arange(len(self.goal_probs))
                    max_idx = np.argmax(self.goal_probs)
                    other_indices = indices[indices != max_idx]
                    if len(other_indices) > 0:
                        other_sum = np.sum(self.goal_probs[other_indices])
                        if other_sum > 0:
                            for idx in other_indices:
                                self.goal_probs[idx] += remainder * (self.goal_probs[idx] / other_sum)
                        else:
                            for idx in other_indices:
                                self.goal_probs[idx] += remainder / len(other_indices)
        
        # Track history
        self.history.append(self.goal_probs.copy())
        if len(self.history) > self.max_hist_len:
            self.history.pop(0)
        
        return self.goal_probs
    
    def get_goal_probs(self):
        """Get current goal probabilities."""
        return self.goal_probs
    
    def get_most_likely_goal(self):
        """Get the most likely goal and its probability."""
        if len(self.goals) == 0:
            return None, 0.0
        
        max_idx = np.argmax(self.goal_probs)
        return self.goals[max_idx], self.goal_probs[max_idx]
    
    def get_entropy(self):
        """Calculate entropy of the goal distribution."""
        return entropy(self.goal_probs)
    
    def compute_expert_recommendation(self, agent_pos, obstacles):
        """Compute expert recommendation based on goal probabilities."""
        if len(self.goals) == 0:
            return np.zeros(2, dtype=np.float32)
        
        # Compute weighted direction based on goal probabilities
        weighted_dir = np.zeros(2, dtype=np.float32)
        
        for i, goal in enumerate(self.goals):
            # Get expert direction for this goal
            expert_dir = potential_field_dir(agent_pos, goal, obstacles)
            # Weight by goal probability
            weighted_dir += self.goal_probs[i] * expert_dir
        
        # Normalize
        magnitude = np.linalg.norm(weighted_dir)
        if magnitude > 1e-6:
            weighted_dir = weighted_dir / magnitude
        
        return weighted_dir
    
    def reset(self):
        """Reset to initial state with uniform priors."""
        if len(self.goals) > 0:
            self.goal_probs = self.priors.copy()
        self.history = []

###############################################################################
# METRICS CALLBACK
###############################################################################
class MetricsCallback(BaseCallback):
    """Callback to track training metrics."""
    
    def __init__(self, verbose=0):
        super().__init__(verbose)
        # Episode metrics
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_gammas = []
        self.goal_inference_accuracy = []
        self.goal_entropy = []
        
        # Current episode tracking
        self.total_reward = 0.0
        self.ep_length = 0
        self.current_gammas = []
        self.current_goal_probs = []
        self.current_true_goal_idx = None
        
        # Collision tracking
        self.n_collisions = 0
        self.n_episodes = 0

    def _on_step(self) -> bool:
        """Called at each step of training."""
        actions = self.locals['actions']
        rewards = self.locals['rewards']
        dones = self.locals['dones']
        infos = self.locals['infos']
        
        # Get environment
        env = self.model.env.envs[0]
        
        # Compute gamma from action (mapping [-1,1] -> [0,1])
        gamma_val = 0.5 * (actions[0][0] + 1.0)
        self.current_gammas.append(gamma_val)
        
        # Track reward and episode length
        r = float(rewards[0])
        self.total_reward += r
        self.ep_length += 1
        
        # Track goal inference
        if hasattr(env, 'bayesian_model') and hasattr(env, 'true_goal_idx'):
            goal_probs = env.bayesian_model.get_goal_probs()
            self.current_goal_probs.append(goal_probs.copy())
            
            if self.current_true_goal_idx is None:
                self.current_true_goal_idx = env.true_goal_idx
        
        if dones[0]:
            # Calculate goal inference accuracy
            if self.current_goal_probs and self.current_true_goal_idx is not None:
                accuracies = []
                entropies = []
                
                for probs in self.current_goal_probs:
                    pred_goal = np.argmax(probs)
                    accuracies.append(int(pred_goal == self.current_true_goal_idx))
                    entropies.append(entropy(probs))
                
                self.goal_inference_accuracy.append(np.mean(accuracies))
                self.goal_entropy.append(np.mean(entropies))
            
            # Store episode statistics
            self.episode_rewards.append(self.total_reward)
            self.episode_lengths.append(self.ep_length)
            
            if self.current_gammas:
                self.episode_gammas.append(np.mean(self.current_gammas))
            
            # Reset episode tracking
            self.total_reward = 0.0
            self.ep_length = 0
            self.current_gammas = []
            self.current_goal_probs = []
            self.current_true_goal_idx = None
            
            # Increment episode counter
            self.n_episodes += 1
            if 'terminal_reason' in infos[0] and infos[0]['terminal_reason'] == 'collision':
                self.n_collisions += 1
        
        return True

    def _moving_average(self, data, window=10):
        """Calculate moving average of data."""
        if len(data) < window:
            return np.array(data)
        return np.convolve(data, np.ones(window)/window, mode='valid')

    def save_metrics(self, save_dir="training_metrics"):
        """Save metrics plots to the specified directory."""
        os.makedirs(save_dir, exist_ok=True)
        
        # Episode Rewards
        if self.episode_rewards:
            plt.figure(figsize=(10, 6))
            plt.plot(self.episode_rewards, label="Episode Reward", alpha=0.6)
            ma_rewards = self._moving_average(self.episode_rewards, 10)
            if len(ma_rewards):
                plt.plot(range(10 - 1, 10 - 1 + len(ma_rewards)), 
                         ma_rewards, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Episode")
            plt.ylabel("Reward")
            plt.title("Episode Rewards")
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "episode_rewards.png"))
            plt.close()
        
        # Average Gamma per Episode
        if self.episode_gammas:
            plt.figure(figsize=(10, 6))
            plt.plot(self.episode_gammas, label="Average Gamma", alpha=0.6)
            ma_gamma = self._moving_average(self.episode_gammas, 10)
            if len(ma_gamma):
                plt.plot(range(10 - 1, 10 - 1 + len(ma_gamma)), 
                         ma_gamma, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Episode")
            plt.ylabel("Gamma (avg)")
            plt.title("Average Gamma per Episode")
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "average_gamma.png"))
            plt.close()
        
        # Goal Inference Accuracy
        if self.goal_inference_accuracy:
            plt.figure(figsize=(10, 6))
            plt.plot(self.goal_inference_accuracy, label="Goal Inference Accuracy", alpha=0.6)
            ma_acc = self._moving_average(self.goal_inference_accuracy, 10)
            if len(ma_acc):
                plt.plot(range(10 - 1, 10 - 1 + len(ma_acc)), 
                         ma_acc, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Episode")
            plt.ylabel("Accuracy")
            plt.title("Goal Inference Accuracy per Episode")
            plt.ylim(0, 1.05)
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "goal_inference_accuracy.png"))
            plt.close()
        
        # Goal Distribution Entropy
        if self.goal_entropy:
            plt.figure(figsize=(10, 6))
            plt.plot(self.goal_entropy, label="Goal Distribution Entropy", alpha=0.6)
            ma_ent = self._moving_average(self.goal_entropy, 10)
            if len(ma_ent):
                plt.plot(range(10 - 1, 10 - 1 + len(ma_ent)), 
                         ma_ent, label="MA(10)", color='red', linewidth=2)
            plt.xlabel("Episode")
            plt.ylabel("Entropy")
            plt.title("Goal Distribution Entropy per Episode")
            plt.legend()
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, "goal_entropy.png"))
            plt.close()
        
        # Summary text
        with open(os.path.join(save_dir, "summary.txt"), "w") as f:
            f.write(f"Total Episodes: {len(self.episode_rewards)}\n")
            if self.episode_rewards:
                f.write(f"Mean Episode Reward: {np.mean(self.episode_rewards):.3f}\n")
            f.write(f"Collisions Rate: {self.n_collisions/max(1, self.n_episodes):.3f}\n")
            if self.episode_gammas:
                f.write(f"Mean Gamma: {np.mean(self.episode_gammas):.3f}\n")
            if self.goal_inference_accuracy:
                f.write(f"Mean Goal Inference Accuracy: {np.mean(self.goal_inference_accuracy):.3f}\n")

###############################################################################
# SHARED AUTONOMY ENVIRONMENT
###############################################################################
class SharedAutonomyEnv(gym.Env):
    """Environment for training shared autonomy."""
    
    metadata = {"render_modes": ["human"], "render_fps": RENDER_FPS}
    
    def __init__(self, visualize=False, use_joint_optimization=True):
        """Initialize the environment."""
        super().__init__()
        self.visualize = visualize
        self.use_joint_optimization = use_joint_optimization
        
        # Action space: gamma ∈ [-1, 1] (mapped to [0, 1])
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
        
        # Default observation space (will be updated in reset)
        # Assuming a default of 8 goals for initial observation space
        n_goals = 8
        low = np.array(
            [0, 0, -1, -1, 0, 0, -1, -1, 0, 0, 0] + [0] * n_goals,
            dtype=np.float32
        )
        high = np.array(
            [FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1,
            FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1, 1, 1, np.log(n_goals)] + [1] * n_goals,
            dtype=np.float32
        )
        self.observation_space = spaces.Box(low=low, high=high, shape=(11 + n_goals,), dtype=np.float32)
        
        # Initialize state
        self.dot_pos = None
        self.goal_pos = None
        self.obstacles = []
        self.goals = []
        self.true_goal_idx = None
        self.step_count = 0
        self.max_steps = 200
        self.episode_reward = 0.0
        self.max_dist = math.hypot(FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1])
        
        # Bayesian model for goal inference
        self.bayesian_model = BayesianGoalInference(beta=10.0, w_theta=0.7, w_d=0.3)
        
        # Set up visualization
        if self.visualize and pygame is not None:
            pygame.init()
            self.window = pygame.display.set_mode(FULL_VIEW_SIZE)
            pygame.display.set_caption("Shared Autonomy")
            self.clock = pygame.time.Clock()
        else:
            self.window = None
            self.clock = None
    
    def reset(self, seed=None, options=None):
        """Reset the environment."""
        super().reset(seed=seed)
        
        # Create the environment layout
        self.randomize_env(seed if seed is not None else random.randint(0, 9999))
        
        # Reset internal state
        self.step_count = 0
        self.episode_reward = 0.0
        self.dot_pos = START_POS.copy()
        
        # Choose a random goal
        if self.goals:
            self.true_goal_idx = random.randint(0, len(self.goals) - 1)
            self.goal_pos = self.goals[self.true_goal_idx].copy()
        else:
            self.goal_pos = np.array([random.uniform(0.2*FULL_VIEW_SIZE[0], 0.8*FULL_VIEW_SIZE[0]),
                                      random.uniform(0.2*FULL_VIEW_SIZE[1], 0.8*FULL_VIEW_SIZE[1])],
                                     dtype=np.float32)
            self.true_goal_idx = 0
        
        # Reset Bayesian model
        self.bayesian_model.initialize_goals(self.goals)
        self.bayesian_model.reset()
        
        # Update observation space based on number of goals
        n_goals = len(self.goals)
        low = np.array(
            [0, 0, -1, -1, 0, 0, -1, -1, 0, 0, 0] + [0] * n_goals,
            dtype=np.float32
        )
        high = np.array(
            [FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1,
             FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1, 1, 1, np.log(n_goals)] + [1] * n_goals,
            dtype=np.float32
        )
        self.observation_space = spaces.Box(low=low, high=high, shape=(11 + n_goals,), dtype=np.float32)
        
        return self._get_obs(), {}
    
    def randomize_env(self, seed):
        """Create a randomized environment layout."""
        random.seed(seed)
        np.random.seed(seed)
        
        margin = 50
        n_goals = 8
        n_obstacles = 5
        min_goal_distance = 200
        
        # Generate goals
        new_goals = []
        attempts = 0
        while len(new_goals) < n_goals and attempts < 1000:
            x = random.uniform(margin, FULL_VIEW_SIZE[0] - margin)
            y = random.uniform(margin, FULL_VIEW_SIZE[1] - margin)
            candidate = np.array([x, y], dtype=np.float32)
            
            # Ensure goal is far enough from start position
            if distance(candidate, START_POS) >= min_goal_distance:
                new_goals.append(candidate)
            attempts += 1
        
        self.goals = new_goals[:n_goals]
        
        # Generate obstacles
        new_obstacles = []
        for _ in range(n_obstacles):
            x = random.uniform(margin, FULL_VIEW_SIZE[0] - margin)
            y = random.uniform(margin, FULL_VIEW_SIZE[1] - margin)
            candidate = np.array([x, y], dtype=np.float32)
            
            # Check validity
            valid = True
            if distance(candidate, START_POS) < (DOT_RADIUS + OBSTACLE_RADIUS + 10):
                valid = False
            for goal in self.goals:
                if distance(candidate, goal) < (TARGET_RADIUS + OBSTACLE_RADIUS + 20):
                    valid = False
            for obs in new_obstacles:
                if distance(candidate, obs) < (2*OBSTACLE_RADIUS + 10):
                    valid = False
            
            if valid:
                new_obstacles.append(candidate)
        
        self.obstacles = new_obstacles
    
    def step(self, action):
        """Execute one environment step."""
        # Map action to gamma
        gamma_val = 0.5 * (action[0] + 1.0)
        
        self.step_count += 1
        
        # Compute world direction using potential field
        w_dir = potential_field_dir(self.dot_pos, self.goal_pos, self.obstacles)
        
        # Add noise for human direction (simulating imperfect human input)
        noise = np.random.normal(0, NOISE_MAGNITUDE, size=2)
        h_dir = w_dir + noise
        hm = np.hypot(h_dir[0], h_dir[1])
        if hm > 1e-6:
            h_dir /= hm
        
        # Update Bayesian model with human input
        goal_probs = self.bayesian_model.update(self.dot_pos, h_dir, self.obstacles)
        
        # Get expert recommendation
        if self.use_joint_optimization:
            # Use weighted expert recommendation
            w_dir = self.bayesian_model.compute_expert_recommendation(self.dot_pos, self.obstacles)
        
        # Combine directions using gamma
        c_dir = gamma_val * w_dir + (1 - gamma_val) * h_dir
        cm = np.hypot(c_dir[0], c_dir[1])
        if cm > 1e-6:
            c_dir /= cm
        
        # Compute new position
        move_vec = c_dir * MAX_SPEED
        new_pos = self.dot_pos + move_vec
        
        # Check for collision
        if not line_collision(self.dot_pos, new_pos, self.obstacles):
            new_pos[0] = np.clip(new_pos[0], 0, FULL_VIEW_SIZE[0])
            new_pos[1] = np.clip(new_pos[1], 0, FULL_VIEW_SIZE[1])
            self.dot_pos = new_pos
        
        # Check if inside obstacle
        collided = inside_obstacle(self.dot_pos, self.obstacles)
        info = {}
        
        if collided:
            reward = -2.0
            done = True
            info["terminal_reason"] = "collision"
        else:
            # Check if reached goal
            if distance(self.dot_pos, self.goal_pos) <= (TARGET_RADIUS + DOT_RADIUS):
                reward = 5.0
                done = True
                info["terminal_reason"] = "goal_reached"
            else:
                # Progress reward
                prev_pos = self.dot_pos - move_vec
                prev_dist = distance(prev_pos, self.goal_pos)
                curr_dist = distance(self.dot_pos, self.goal_pos)
                progress = prev_dist - curr_dist
                
                reward = 0.1 * progress
                done = False
                info["terminal_reason"] = None
        
        # Check for timeout
        truncated = (self.step_count >= self.max_steps)
        if truncated and not done:
            info["terminal_reason"] = "timeout"
        
        # Get observation
        obs = self._get_obs()
        
        return obs, float(reward), done, truncated, info
    
    def _get_obs(self):
        """Get the current observation."""
        # Distance to goal
        to_g = self.goal_pos - self.dot_pos
        d = math.hypot(to_g[0], to_g[1])
        dist_ratio = d / self.max_dist
        
        # Compute directions
        w_dir = potential_field_dir(self.dot_pos, self.goal_pos, self.obstacles)
        noise = np.random.normal(0, NOISE_MAGNITUDE, size=2)
        h_dir = w_dir + noise
        hm = np.hypot(h_dir[0], h_dir[1])
        if hm > 1e-6:
            h_dir /= hm
        
        # Distance to closest obstacle
        if self.obstacles:
            min_obs_distance = min(distance(self.dot_pos, obs) for obs in self.obstacles)
        else:
            min_obs_distance = self.max_dist
        obs_dist_ratio = min_obs_distance / self.max_dist
        
        # Get goal probabilities and entropy
        goal_probs = self.bayesian_model.get_goal_probs()
        goal_ent = entropy(goal_probs) if len(goal_probs) > 0 else np.log(len(self.goals))
        
        # Combine all features
        base_obs = np.array([
            self.dot_pos[0], self.dot_pos[1],
            h_dir[0], h_dir[1],
            self.goal_pos[0], self.goal_pos[1],
            w_dir[0], w_dir[1],
            dist_ratio, obs_dist_ratio,
            goal_ent
        ], dtype=np.float32)
        
        # Append goal probabilities
        obs = np.concatenate([base_obs, goal_probs]).astype(np.float32)
        
        return obs
    
    def render(self):
        """Render the environment."""
        if not self.visualize or (pygame is None):
            return
        
        # Handle user events
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                return
        
        # Clear the screen
        self.window.fill(WHITE)
        
        # Draw obstacles
        for obs in self.obstacles:
            pygame.draw.circle(self.window, GRAY, (int(obs[0]), int(obs[1])), OBSTACLE_RADIUS)
        
        # Draw all potential targets
        for i, gpos in enumerate(self.goals):
            # Color intensity based on goal probability
            goal_probs = self.bayesian_model.get_goal_probs()
            if i < len(goal_probs):
                prob = goal_probs[i]
                # Interpolate color from yellow to green based on probability
                color = (int(255 * (1 - prob)), int(255), 0)
            else:
                color = YELLOW
            
            # Draw the target
            pygame.draw.circle(self.window, color, (int(gpos[0]), int(gpos[1])), TARGET_RADIUS)
            
            # Label the target
            font = pygame.font.SysFont(None, 24)
            label = font.render(str(i), True, BLACK)
            self.window.blit(label, (int(gpos[0]) - 5, int(gpos[1]) - 8))
        
        # Draw the true goal with a black outline
        pygame.draw.circle(self.window, BLACK, 
                          (int(self.goal_pos[0]), int(self.goal_pos[1])), 
                          TARGET_RADIUS+2, width=2)
        
        # Draw the agent
        pygame.draw.circle(self.window, RED, 
                          (int(self.dot_pos[0]), int(self.dot_pos[1])), 
                          DOT_RADIUS)
        
        # Display gamma value
        if hasattr(self, 'last_gamma'):
            font = pygame.font.SysFont(None, 30)
            gamma_text = font.render(f"γ: {self.last_gamma:.2f}", True, BLACK)
            self.window.blit(gamma_text, (10, 10))
        
        # Update the display
        pygame.display.flip()
        self.clock.tick(RENDER_FPS)
    
    def close(self):
        """Close the environment."""
        if self.visualize and pygame is not None:
            pygame.quit()
        super().close()

###############################################################################
# TRAINING FUNCTIONS
###############################################################################
def train_bayesian_model(visualize=False):
    """Train and evaluate the Bayesian goal inference model."""
    print("Training Bayesian goal inference model...")
    
    # Create a model with default parameters
    model = BayesianGoalInference(beta=10.0, w_theta=0.7, w_d=0.3)
    
    # Create a synthetic dataset
    n_trajectories = 50
    trajectories = []
    env = SharedAutonomyEnv(visualize=False)
    
    for i in range(n_trajectories):
        print(f"Generating trajectory {i+1}/{n_trajectories}", end="\r")
        
        # Reset environment
        obs, _ = env.reset()
        
        # Extract information
        goal_pos = env.goal_pos
        true_goal_idx = env.true_goal_idx
        goals = env.goals
        
        # Generate a trajectory
        trajectory = {
            "positions": [env.dot_pos.copy()],
            "inputs": [],
            "goal_pos": goal_pos.copy(),
            "true_goal_idx": true_goal_idx,
            "goals": [g.copy() for g in goals],
        }
        
        # Simple policy to generate trajectories
        done = False
        truncated = False
        while not (done or truncated):
            # Direction to goal with noise
            to_goal = goal_pos - env.dot_pos
            dist = np.linalg.norm(to_goal)
            if dist > 0:
                direction = to_goal / dist
            else:
                direction = np.zeros(2)
            
            # Add noise
            noise_scale = 0.2
            noise = np.random.normal(0, noise_scale, size=2)
            noisy_direction = direction + noise
            
            # Normalize
            norm = np.linalg.norm(noisy_direction)
            if norm > 0:
                noisy_direction = noisy_direction / norm
            
            # Store human input
            trajectory["inputs"].append(noisy_direction.copy())
            
            # Step environment with full human control
            obs, _, done, truncated, _ = env.step([-1.0])  # -1.0 maps to gamma=0
            
            # Store position
            trajectory["positions"].append(env.dot_pos.copy())
            
            if len(trajectory["positions"]) > 200:  # Prevent infinite loops
                break
        
        trajectories.append(trajectory)
    
    env.close()
    print("\nGenerated synthetic dataset with", len(trajectories), "trajectories")
    
    # Save dataset
    os.makedirs("datasets", exist_ok=True)
    with open("datasets/synthetic_trajectories.pkl", "wb") as f:
        pickle.dump(trajectories, f)
    print("Saved synthetic dataset to datasets/synthetic_trajectories.pkl")
    
    # Train confidence calibration
    print("\nTraining confidence calibration...")
    confidences = []
    accuracies = []
    
    n_calib = min(40, len(trajectories))
    for i in range(n_calib):
        traj = trajectories[i]
        goals = traj["goals"]
        true_idx = traj["true_goal_idx"]
        
        # Initialize model with goals
        model.initialize_goals(goals)
        
        # Process trajectory
        positions = traj["positions"]
        inputs = traj["inputs"]
        n_steps = min(len(inputs), len(positions) - 1)
        
        for t in range(n_steps):
            pos = positions[t]
            h_input = inputs[t]
            
            # Update model
            goal_probs = model.update(pos, h_input)
            
            # Record confidence and accuracy
            if len(goal_probs) > 0:
                max_prob = np.max(goal_probs)
                pred_idx = np.argmax(goal_probs)
                is_correct = int(pred_idx == true_idx)
                
                confidences.append(max_prob)
                accuracies.append(is_correct)
    
    # Train calibrator
    model.train_calibrator(confidences, accuracies)
    
    # Save the trained model
    os.makedirs("models", exist_ok=True)
    with open("models/bayesian_model.pkl", "wb") as f:
        pickle.dump(model, f)
    print("Saved trained Bayesian model to models/bayesian_model.pkl")
    
    # Evaluate on test set
    print("\nEvaluating model on test set...")
    eval_results = {}
    completion_points = [0.25, 0.5, 0.75, 1.0]
    for cp in completion_points:
        eval_results[cp] = []
    
    n_test = min(10, len(trajectories))
    for i in range(n_test):
        traj = trajectories[i]
        goals = traj["goals"]
        true_idx = traj["true_goal_idx"]
        
        # Initialize model with goals
        model.initialize_goals(goals)
        
        # Process trajectory
        positions = traj["positions"]
        inputs = traj["inputs"]
        n_steps = min(len(inputs), len(positions) - 1)
        
        for cp in completion_points:
            step_idx = int(cp * n_steps) - 1
            if step_idx < 0:
                continue
            
            # Reset model
            model.reset()
            
            # Process up to completion point
            for t in range(step_idx + 1):
                pos = positions[t]
                h_input = inputs[t]
                goal_probs = model.update(pos, h_input)
            
            # Check prediction
            if len(goal_probs) > 0:
                pred_idx = np.argmax(goal_probs)
                is_correct = int(pred_idx == true_idx)
                eval_results[cp].append(is_correct)
    
    # Print evaluation results
    print("\nModel performance at different path completion percentages:")
    for cp in completion_points:
        if eval_results[cp]:
            accuracy = np.mean(eval_results[cp])
            print(f"  {int(cp*100)}% completion: {accuracy:.3f} accuracy")
    
    # Plot accuracy by path completion
    plt.figure(figsize=(10, 6))
    x = []
    y = []
    for cp in completion_points:
        if eval_results[cp]:
            x.append(int(cp*100))
            y.append(np.mean(eval_results[cp]))
    
    if x and y:
        plt.bar(x, y)
        plt.xlabel('Path Completion (%)')
        plt.ylabel('Accuracy')
        plt.title('Goal Inference Accuracy vs. Path Completion')
        plt.ylim(0, 1.1)
        plt.grid(True, axis='y')
        os.makedirs("bayesian_results", exist_ok=True)
        plt.savefig("bayesian_results/accuracy_vs_completion.png")
        plt.close()
    
    # Visualize model if requested
    if visualize and pygame is not None:
        print("\nVisualizing model performance...")
        
        # Create visualization environment
        vis_env = SharedAutonomyEnv(visualize=True)
        
        for _ in range(2):  # Show 2 trajectories
            # Reset environment
            obs, _ = vis_env.reset()
            
            done = False
            truncated = False
            while not (done or truncated):
                # Get noisy human input
                to_goal = vis_env.goal_pos - vis_env.dot_pos
                dist = np.linalg.norm(to_goal)
                if dist > 0:
                    direction = to_goal / dist
                else:
                    direction = np.zeros(2)
                
                # Add noise
                noise_scale = 0.2
                noise = np.random.normal(0, noise_scale, size=2)
                h_input = direction + noise
                
                # Normalize
                norm = np.linalg.norm(h_input)
                if norm > 0:
                    h_input = h_input / norm
                
                # Step environment with full human control
                obs, _, done, truncated, _ = vis_env.step([-1.0])
                
                # Visualize
                vis_env.render()
                import time
                time.sleep(0.05)
        
        vis_env.close()
    
    return model

def train_ppo_model(total_timesteps=10000, visualize=False, use_joint_optimization=True):
    """Train the PPO shared autonomy model."""
    print("Initializing shared autonomy environment...")
    env = SharedAutonomyEnv(visualize=visualize, use_joint_optimization=use_joint_optimization)
    env = DummyVecEnv([lambda: env])
    
    # Set up callback
    metrics_callback = MetricsCallback()
    
    # Load pre-trained Bayesian model if available
    bayesian_model_path = "models/bayesian_model.pkl"
    if os.path.exists(bayesian_model_path):
        try:
            with open(bayesian_model_path, "rb") as f:
                bayesian_model = pickle.load(f)
            print(f"Loaded Bayesian model from {bayesian_model_path}")
            
            # Set the model in the environment
            env.envs[0].bayesian_model = bayesian_model
        except:
            print(f"Could not load Bayesian model from {bayesian_model_path}")
    
    print("Initializing PPO model...")
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=3e-4,
        n_steps=1024,
        batch_size=64,
        n_epochs=5,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        verbose=1,
        tensorboard_log="./tensorboard_logs/",
        policy_kwargs={
            "net_arch": [dict(pi=[128, 128], vf=[128, 128])],
            "activation_fn": nn.ReLU
        }
    )
    
    print(f"Starting training for {total_timesteps} timesteps...")
    model.learn(total_timesteps=total_timesteps, callback=metrics_callback)
    
    # Save the trained model
    os.makedirs("models", exist_ok=True)
    model_path = os.path.join("models", "shared_autonomy_model")
    model.save(model_path)
    print(f"Model saved to {model_path}")
    
    # Save metrics
    metrics_callback.save_metrics(save_dir="training_metrics")
    print("Metrics saved to 'training_metrics/'")
    
    env.close()
    return model

def generate_gamma_heatmap(env, model, output_path="gamma_heatmap.png", resolution=20):
    """Generate a heatmap showing gamma values across the environment space."""
    print("Generating gamma heatmap...")
    
    # Create grid
    x_grid = np.linspace(0, FULL_VIEW_SIZE[0], resolution)
    y_grid = np.linspace(0, FULL_VIEW_SIZE[1], resolution)
    gamma_values = np.zeros((resolution, resolution))
    
    # Cache some data
    goal_pos = env.goal_pos
    obstacles = env.obstacles
    original_pos = env.dot_pos.copy()
    
    # Compute gamma values at each grid point
    for i, x in enumerate(x_grid):
        for j, y in enumerate(y_grid):
            # Temporarily set agent position
            env.dot_pos = np.array([x, y], dtype=np.float32)
            
            # Get observation at this position
            obs = env._get_obs()
            
            # Predict action
            action, _ = model.predict(obs, deterministic=True)
            
            # Map action to gamma
            gamma = 0.5 * (action[0] + 1.0)
            
            # Store gamma value
            gamma_values[j, i] = gamma  # Note: j, i for correct orientation
    
    # Restore original position
    env.dot_pos = original_pos
    
    # Create heatmap
    plt.figure(figsize=(12, 8))
    plt.imshow(gamma_values, origin='lower', extent=[0, FULL_VIEW_SIZE[0], 0, FULL_VIEW_SIZE[1]], 
              cmap='viridis', vmin=0, vmax=1)
    plt.colorbar(label='Gamma')
    
    # Plot obstacles and goals
    for obs_pos in obstacles:
        circle = plt.Circle((obs_pos[0], obs_pos[1]), OBSTACLE_RADIUS, color='gray')
        plt.gcf().gca().add_artist(circle)
    
    for i, g_pos in enumerate(env.goals):
        circle = plt.Circle((g_pos[0], g_pos[1]), TARGET_RADIUS, color='yellow')
        plt.gcf().gca().add_artist(circle)
    
    # Mark true goal
    circle = plt.Circle((goal_pos[0], goal_pos[1]), TARGET_RADIUS+2, color='black', fill=False)
    plt.gcf().gca().add_artist(circle)
    
    plt.title('Gamma Values Across Environment')
    plt.xlabel('X Position')
    plt.ylabel('Y Position')
    plt.grid(False)
    
    # Save heatmap
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    plt.savefig(output_path)
    plt.close()
    
    print(f"Saved gamma heatmap to {output_path}")

def evaluate_model(model_path="models/shared_autonomy_model", n_episodes=10, visualize=True):
    """Evaluate a trained model."""
    print(f"Evaluating model {model_path}...")
    
    # Create evaluation environment
    env = SharedAutonomyEnv(visualize=visualize)
    
    try:
        # Load the model
        model = PPO.load(model_path)
        print(f"Loaded model from {model_path}")
    except:
        print(f"Could not load model from {model_path}")
        return None
    
    # Initialize metrics
    metrics = {
        "success_rate": 0.0,
        "collision_rate": 0.0,
        "timeout_rate": 0.0,
        "avg_reward": 0.0,
        "avg_steps": 0.0,
        "avg_gamma": 0.0,
        "goal_inference_accuracy": 0.0
    }
    
    # Run evaluation episodes
    for episode in range(n_episodes):
        print(f"Episode {episode+1}/{n_episodes}", end="\r")
        
        # Reset environment
        obs, _ = env.reset()
        
        # Episode tracking
        done = False
        truncated = False
        episode_reward = 0.0
        episode_steps = 0
        episode_gammas = []
        goal_predictions = []
        
        while not (done or truncated):
            # Predict action
            action, _ = model.predict(obs, deterministic=True)
            
            # Map action to gamma
            gamma = 0.5 * (action[0] + 1.0)
            episode_gammas.append(gamma)
            
            # Save for rendering
            env.last_gamma = gamma
            
            # Step environment
            obs, reward, done, truncated, info = env.step(action)
            
            # Update tracking
            episode_reward += reward
            episode_steps += 1
            
            # Track goal inference
            if hasattr(env.bayesian_model, "get_goal_probs"):
                probs = env.bayesian_model.get_goal_probs()
                if len(probs) > 0:
                    pred_goal = np.argmax(probs)
                    goal_predictions.append(pred_goal)
            
            # Visualize
            if visualize:
                env.render()
                import time
                time.sleep(0.01)
        
        # Update metrics
        terminal_reason = info.get("terminal_reason", None)
        
        if terminal_reason == "goal_reached":
            metrics["success_rate"] += 1.0
        elif terminal_reason == "collision":
            metrics["collision_rate"] += 1.0
        elif terminal_reason == "timeout":
            metrics["timeout_rate"] += 1.0
        
        metrics["avg_reward"] += episode_reward
        metrics["avg_steps"] += episode_steps
        
        if episode_gammas:
            metrics["avg_gamma"] += np.mean(episode_gammas)
        
        # Goal inference accuracy
        if goal_predictions and hasattr(env, "true_goal_idx"):
            accuracy = np.mean([1.0 if p == env.true_goal_idx else 0.0 for p in goal_predictions])
            metrics["goal_inference_accuracy"] += accuracy
    
    # Normalize metrics
    for key in metrics:
        metrics[key] /= n_episodes
    
    # Print evaluation results
    print("\nEvaluation Results:")
    print(f"  Success Rate: {metrics['success_rate']:.3f}")
    print(f"  Collision Rate: {metrics['collision_rate']:.3f}")
    print(f"  Timeout Rate: {metrics['timeout_rate']:.3f}")
    print(f"  Average Reward: {metrics['avg_reward']:.3f}")
    print(f"  Average Steps: {metrics['avg_steps']:.1f}")
    print(f"  Average Gamma: {metrics['avg_gamma']:.3f}")
    print(f"  Goal Inference Accuracy: {metrics['goal_inference_accuracy']:.3f}")
    
    # Save metrics to file
    os.makedirs("evaluation_results", exist_ok=True)
    with open("evaluation_results/metrics.json", "w") as f:
        import json
        json.dump(metrics, f, indent=2)
    print("Saved evaluation metrics to evaluation_results/metrics.json")
    
    # Generate gamma heatmap
    generate_gamma_heatmap(env, model, "evaluation_results/gamma_heatmap.png")
    
    env.close()
    return metrics

###############################################################################
# MAIN EXPERIMENT FUNCTION
###############################################################################
def run_experiment(visualize=False):
    """Run the complete shared autonomy experiment."""
    print("=" * 70)
    print("SHARED AUTONOMY EXPERIMENT")
    print("=" * 70)
    
    # Set up directories
    os.makedirs("models", exist_ok=True)
    os.makedirs("training_metrics", exist_ok=True)
    os.makedirs("evaluation_results", exist_ok=True)
    os.makedirs("bayesian_results", exist_ok=True)
    
    # Step 1: Train Bayesian model
    print("\n[1/3] Training Bayesian goal inference model...")
    bayesian_model = train_bayesian_model(visualize=visualize)
    
    # Step 2: Train PPO model
    print("\n[2/3] Training shared autonomy PPO model...")
    # Smaller timesteps for faster execution, increase for better performance
    ppo_model = train_ppo_model(
        total_timesteps=10000,
        visualize=visualize,
        use_joint_optimization=True
    )
    
    # Step 3: Evaluate model
    print("\n[3/3] Evaluating model...")
    metrics = evaluate_model(
        model_path="models/shared_autonomy_model",
        n_episodes=10,
        visualize=visualize
    )
    
    print("\nExperiment complete!")
    print("Results are saved in:")
    print("- bayesian_results/ (Bayesian model performance)")
    print("- training_metrics/ (training curves)")
    print("- models/ (trained models)")
    print("- evaluation_results/ (evaluation results and heatmaps)")
    
    return {
        "bayesian_model": bayesian_model,
        "ppo_model": ppo_model,
        "evaluation_metrics": metrics
    }

###############################################################################
# MAIN SCRIPT
###############################################################################
if __name__ == "__main__":
    # Run the experiment with visualization (set to False to run headless)
    run_experiment(visualize=True)

SHARED AUTONOMY EXPERIMENT

[1/3] Training Bayesian goal inference model...
Training Bayesian goal inference model...
Generating trajectory 50/50
Generated synthetic dataset with 50 trajectories
Saved synthetic dataset to datasets/synthetic_trajectories.pkl

Training confidence calibration...
Calibrator trained and saved
Saved trained Bayesian model to models/bayesian_model.pkl

Evaluating model on test set...

Model performance at different path completion percentages:
  25% completion: 0.800 accuracy
  50% completion: 0.500 accuracy
  75% completion: 0.600 accuracy
  100% completion: 0.600 accuracy

Visualizing model performance...

[2/3] Training shared autonomy PPO model...
Initializing shared autonomy environment...
Loaded Bayesian model from models/bayesian_model.pkl
Initializing PPO model...
Using cuda device




Starting training for 10000 timesteps...
Logging to ./tensorboard_logs/PPO_1
-----------------------------
| time/              |      |
|    fps             | 522  |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 1024 |
-----------------------------
--------------------------------------
| time/                   |          |
|    fps                  | 485      |
|    iterations           | 2        |
|    time_elapsed         | 4        |
|    total_timesteps      | 2048     |
| train/                  |          |
|    approx_kl            | 90.83863 |
|    clip_fraction        | 0.969    |
|    clip_range           | 0.2      |
|    entropy_loss         | -1.42    |
|    explained_variance   | 0.93     |
|    learning_rate        | 0.0003   |
|    loss                 | 1.13     |
|    n_updates            | 5        |
|    policy_gradient_loss | 0.301    |
|    std                  | 1        |
|    value_loss           | 7.5      |
-----------