In [1]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal, Categorical
import gymnasium as gym
from gymnasium import spaces
import matplotlib.pyplot as plt
from matplotlib import animation
import seaborn as sns
import pandas as pd
from tqdm import tqdm
import random
from collections import deque

# Set seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)

# Visualization and training parameters
VISUALIZE = False  # Set to False to disable visualization
VISUALIZE_INTERVAL = 2000  # Show visualization every n steps
SAVE_VIDEOS = False  # Save videos of the trained agent
NUM_EPISODES = 2000  # Increased from 1000 to allow more learning
EVAL_INTERVAL = 100  # Evaluate every n episodes

# Environment parameters
GOAL_RADIUS = 0.05
NUM_GOALS = 3  # Number of potential goals

# PPO parameters - IMPROVED
GAMMA = 0.99
GAE_LAMBDA = 0.95
CLIP_EPSILON = 0.2
ENTROPY_COEF = 0.02  # Increased for more exploration
VALUE_COEF = 0.5
LR = 3e-4
BATCH_SIZE = 256  # Increased from 64 for better stability
STEPS_PER_UPDATE = 512  # Fixed update schedule instead of variable
PPO_EPOCHS = 5  # Reduced from 10 to prevent overfitting
TARGET_KL = 0.02  # Increased slightly

# Bayesian inference parameters - IMPROVED
BETA = 5.0  # Increased for sharper belief updates
ANGULAR_WEIGHT = 0.6  # Adjusted for better balance
DISTANCE_WEIGHT = 0.4  # Adjusted for better balance
BELIEF_SMOOTH_ALPHA = 0.5  # Reduced further for faster belief updates

# Reward function weights - CONSISTENT USE
COLLISION_WEIGHT = 10.0
PROXIMITY_WEIGHT = 2.5
FAR_WEIGHT = 1.5
PROGRESS_WEIGHT = 3.0
AUTONOMY_WEIGHT = 1.5
GOAL_WEIGHT = 15.0  # Increased to prioritize goal achievement

# Pretraining parameters - IMPROVED
PRETRAIN = True
PRETRAIN_EPOCHS = 20  # Increased for better pretraining
PRETRAIN_BATCH_SIZE = 256
PRETRAIN_TRAJECTORIES = 200  # Increased for better generalization
PRETRAIN_TRAJECTORY_LENGTH = 50
PRETRAIN_LR = 1e-3

# Early stopping patience - INCREASED
EARLY_STOP_PATIENCE = 10  # Increased from 2 for better chance at converging

# Create results directory
os.makedirs("results", exist_ok=True)

# Check if CUDA is available
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
print(f"Using device: {DEVICE}")

# Custom Reacher 2D Environment (No MuJoCo dependency)
class CustomReacher2D(gym.Env):
    """Custom 2D Reacher Environment that doesn't use MuJoCo"""
    
    def __init__(self, render_mode=None):
        super(CustomReacher2D, self).__init__()
        
        # Action space: 2D continuous actions for controlling the arm
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        
        # Observation space: joint angles, joint velocities, end-effector position
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(6,), dtype=np.float32)
        
        # Arm parameters
        self.link_lengths = [0.1, 0.11]  # Length of arm segments
        self.max_velocity = 1.0  # Maximum joint velocity
        self.dt = 0.05  # Time step
        
        # State
        self.joint_angles = np.array([0.0, 0.0])  # Two joints, in radians
        self.joint_velocities = np.array([0.0, 0.0])
        
        # Rendering setup
        self.render_mode = render_mode
        self.screen = None
        self.clock = None
        self.window_size = 500  # pixels
        self.pygame = None
        
        # For visualization
        if render_mode == "rgb_array":
            # Import pygame only if rendering is needed
            try:
                import pygame
                self.pygame = pygame
                self.screen = pygame.Surface((self.window_size, self.window_size))
                self.clock = pygame.time.Clock()
            except ImportError:
                self.render_mode = None
                print("Warning: Pygame not available. Rendering disabled.")
    
    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        
        # Reset joint state
        self.joint_angles = np.array([np.random.uniform(-np.pi/2, np.pi/2), 
                                       np.random.uniform(-np.pi/2, np.pi/2)])
        self.joint_velocities = np.array([0.0, 0.0])
        
        # Get current observation
        observation = self._get_obs()
        info = {}
        
        return observation, info
    
    def step(self, action):
        # Clip action to action space
        action = np.clip(action, -1, 1)
        
        # Convert actions to joint accelerations
        accelerations = action * 8.0  # Scale factor for accelerations
        
        # Update velocities using accelerations
        self.joint_velocities += accelerations * self.dt
        
        # Clip velocities
        self.joint_velocities = np.clip(self.joint_velocities, -self.max_velocity, self.max_velocity)
        
        # Update joint angles using velocities
        self.joint_angles += self.joint_velocities * self.dt
        
        # Wrap angles to [-pi, pi]
        self.joint_angles = np.mod(self.joint_angles + np.pi, 2 * np.pi) - np.pi
        
        # Get new observation
        observation = self._get_obs()
        
        # Default reward and done (to be overridden by wrapper)
        reward = 0.0
        terminated = False
        truncated = False
        info = {}
        
        return observation, reward, terminated, truncated, info
    
    def _get_obs(self):
        """Get current observation (joint angles, velocities, and end-effector position)"""
        ee_pos = self._get_end_effector_position()
        obs = np.concatenate([
            np.cos(self.joint_angles),
            np.sin(self.joint_angles),
            ee_pos
        ])
        return obs
    
    def _get_end_effector_position(self):
        """Compute the position of the end effector using forward kinematics"""
        theta1, theta2 = self.joint_angles
        l1, l2 = self.link_lengths
        
        # Position of first joint is at origin (0,0)
        # Position of second joint
        x1 = l1 * np.cos(theta1)
        y1 = l1 * np.sin(theta1)
        
        # Position of end effector
        x2 = x1 + l2 * np.cos(theta1 + theta2)
        y2 = y1 + l2 * np.sin(theta1 + theta2)
        
        return np.array([x2, y2])
    
    def render(self):
        if self.render_mode != "rgb_array" or self.pygame is None:
            return None
        
        # Clear the screen
        self.screen.fill((255, 255, 255))
        
        # Convert from world coordinates to screen coordinates
        def world_to_screen(point):
            scale = self.window_size / 2.5  # Scale to make the arm visible
            screen_x = int(point[0] * scale + self.window_size / 2)
            screen_y = int(-point[1] * scale + self.window_size / 2)  # Negative because screen y is inverted
            return (screen_x, screen_y)
        
        # Draw the arm
        # Base position (origin)
        base_pos = world_to_screen((0, 0))
        
        # First joint position
        theta1 = self.joint_angles[0]
        l1 = self.link_lengths[0]
        joint1_x = l1 * np.cos(theta1)
        joint1_y = l1 * np.sin(theta1)
        joint1_pos = world_to_screen((joint1_x, joint1_y))
        
        # End effector position
        ee_pos = self._get_end_effector_position()
        ee_screen_pos = world_to_screen(ee_pos)
        
        # Draw the links
        self.pygame.draw.line(self.screen, (0, 0, 0), base_pos, joint1_pos, 6)
        self.pygame.draw.line(self.screen, (0, 0, 0), joint1_pos, ee_screen_pos, 6)
        
        # Draw the joints
        self.pygame.draw.circle(self.screen, (255, 0, 0), base_pos, 10)
        self.pygame.draw.circle(self.screen, (0, 255, 0), joint1_pos, 8)
        self.pygame.draw.circle(self.screen, (0, 0, 255), ee_screen_pos, 8)
        
        return np.transpose(np.array(self.pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2))
    
    def close(self):
        if self.screen is not None:
            self.screen = None


# Environment wrapper for multiple goals
class MultiGoalReacherEnv(gym.Wrapper):
    def __init__(self, num_goals=3, goal_radius=0.05, render_mode=None):
        # Create the base environment
        self.env = CustomReacher2D(render_mode=render_mode)
        super().__init__(self.env)
        
        # Potential goal positions (normalized to [-1, 1])
        self.num_goals = num_goals
        self.goal_radius = goal_radius
        self.goals = []
        self.true_goal = None
        self.true_goal_idx = None
        self.obstacles = []
        
        # Override observation space to include goal information
        obs_dim = self.env.observation_space.shape[0]
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(obs_dim + num_goals * 2,)
        )
    
    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        
        # Generate random goal positions - IMPROVED goal placement
        self.goals = []
        # Start with a wider distribution of goals
        angles = np.linspace(0, 2*np.pi, self.num_goals, endpoint=False)
        angles += np.random.uniform(0, 2*np.pi/self.num_goals)  # Add randomness
        
        for i in range(self.num_goals):
            # Place goals in a circular pattern with some randomness
            radius = np.random.uniform(0.12, 0.18)
            goal_x = radius * np.cos(angles[i])
            goal_y = radius * np.sin(angles[i])
            self.goals.append(np.array([goal_x, goal_y]))
        
        # Choose one goal as the true goal
        self.true_goal_idx = np.random.randint(0, self.num_goals)
        self.true_goal = self.goals[self.true_goal_idx]
        
        # Generate random obstacles (simplified as positions to avoid)
        self.obstacles = []
        for _ in range(2):  # 2 obstacles
            obs_x = np.random.uniform(-0.15, 0.15)
            obs_y = np.random.uniform(-0.15, 0.15)
            radius = np.random.uniform(0.02, 0.04)
            
            # Ensure obstacles don't overlap with goals
            valid = True
            for goal in self.goals:
                if np.linalg.norm(goal - np.array([obs_x, obs_y])) < radius + self.goal_radius:
                    valid = False
                    break
            
            if valid:
                self.obstacles.append((np.array([obs_x, obs_y]), radius))
        
        # Track previous distance for progress reward
        self.prev_distance = np.linalg.norm(self._get_end_effector_position() - self.true_goal)
        
        # Augment observation with goal positions
        augmented_obs = self._augment_observation(obs)
        return augmented_obs, info
    
    def step(self, action):
        obs, _, terminated, truncated, info = self.env.step(action)
        
        # Get current end effector position
        ee_pos = self._get_end_effector_position()
        
        # Check if the end effector reached the true goal
        distance_to_goal = np.linalg.norm(ee_pos - self.true_goal)
        goal_reached = distance_to_goal < self.goal_radius
        
        # Check for collision with obstacles
        collision = False
        for obs_pos, obs_radius in self.obstacles:
            if np.linalg.norm(ee_pos - obs_pos) < obs_radius:
                collision = True
                break
        
        # Custom reward function
        reward = self._compute_reward(ee_pos, action, goal_reached, collision)
        
        # Override termination conditions
        done = goal_reached or collision
        
        # Augment observation with goal positions
        augmented_obs = self._augment_observation(obs)
        
        # Update info dictionary
        info['true_goal'] = self.true_goal
        info['goal_reached'] = goal_reached
        info['collision'] = collision
        info['distance_to_goal'] = distance_to_goal
        
        return augmented_obs, reward, done, truncated, info
    
    def _get_end_effector_position(self):
        """Get the end effector position from the environment"""
        return self.env._get_end_effector_position()
    
    def _compute_reward(self, ee_pos, action, goal_reached, collision):
        """Compute reward based on components - IMPROVED"""
        # Initialize reward
        reward = 0
        
        # 1. Collision penalty
        if collision:
            reward -= COLLISION_WEIGHT
            return reward  # Early return on collision
        
        # 2. Distance to true goal
        distance_to_goal = np.linalg.norm(ee_pos - self.true_goal)
        
        # 3. Goal reached bonus - now using GOAL_WEIGHT
        if goal_reached:
            reward += GOAL_WEIGHT
        
        # 4. Progress reward (encourage moving toward the goal)
        progress = self.prev_distance - distance_to_goal
        reward += PROGRESS_WEIGHT * progress
        
        # Store current distance for next step progress calculation
        self.prev_distance = distance_to_goal
        
        # 5. Action penalty (to encourage smooth actions)
        action_penalty = 0.1 * np.square(action).sum()
        reward -= action_penalty
        
        # 6. Proximity reward (higher reward when closer to goal)
        proximity_reward = PROXIMITY_WEIGHT * np.exp(-5.0 * distance_to_goal)
        reward += proximity_reward
        
        # 7. Far penalty (discourage being too far from goal)
        if distance_to_goal > 0.2:
            far_penalty = FAR_WEIGHT * (distance_to_goal - 0.2)
            reward -= far_penalty
        
        # 8. Autonomy reward (consistent with defined constants)
        if distance_to_goal < 0.15:  # More autonomy as we get closer to goal
            reward += AUTONOMY_WEIGHT * (0.15 - distance_to_goal) / 0.15
        
        # Normalize reward to avoid extreme values
        reward = np.clip(reward, -20.0, 20.0)
        
        return reward
    
    def _augment_observation(self, obs):
        """Concatenate goal positions to the observation"""
        goal_info = np.concatenate([goal for goal in self.goals])
        return np.concatenate([obs, goal_info])
    
    def render(self):
        frame = self.env.render()
        if frame is None:
            return None
        
        if self.env.pygame is None:
            return frame
        
        # Draw goals and obstacles on the frame
        def world_to_screen(point):
            scale = self.env.window_size / 2.5
            screen_x = int(point[0] * scale + self.env.window_size / 2)
            screen_y = int(-point[1] * scale + self.env.window_size / 2)
            return (screen_x, screen_y)
        
        # Draw goals
        for i, goal in enumerate(self.goals):
            goal_pos = world_to_screen(goal)
            color = (255, 215, 0) if i == self.true_goal_idx else (200, 200, 200)
            self.env.pygame.draw.circle(self.env.screen, color, goal_pos, int(self.goal_radius * self.env.window_size / 2.5))
        
        # Draw obstacles
        for obs_pos, obs_radius in self.obstacles:
            obs_screen_pos = world_to_screen(obs_pos)
            self.env.pygame.draw.circle(
                self.env.screen,
                (100, 100, 100),
                obs_screen_pos,
                int(obs_radius * self.env.window_size / 2.5)
            )
        
        return np.transpose(np.array(self.env.pygame.surfarray.pixels3d(self.env.screen)), axes=(1, 0, 2))


# IMPROVED: Completely redesigned Bayesian Inference Module with better numerical stability
class BayesianInferenceModule(nn.Module):
    def __init__(self, num_goals=3, beta=5.0, angular_weight=0.6, distance_weight=0.4, smooth_alpha=0.5):
        super(BayesianInferenceModule, self).__init__()
        self.num_goals = num_goals
        
        # Define trainable parameters with more stable initialization
        self.beta = nn.Parameter(torch.tensor(float(beta), dtype=torch.float32))
        self.angular_weight = nn.Parameter(torch.tensor(float(angular_weight), dtype=torch.float32))
        self.distance_weight = nn.Parameter(torch.tensor(float(distance_weight), dtype=torch.float32))
        self.smooth_alpha = nn.Parameter(torch.tensor(float(smooth_alpha), dtype=torch.float32))
        
        # Prior probabilities (initialized as uniform)
        self.register_buffer('prior', torch.ones(num_goals, dtype=torch.float32) / num_goals)
        
        # Store belief history
        self.beliefs = None
        
        # Add confidence network for better belief updates
        self.confidence_net = nn.Sequential(
            nn.Linear(4, 16),  # Increased network capacity
            nn.ReLU(),
            nn.LayerNorm(16),  # Added layer normalization for stability
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.LayerNorm(8),
            nn.Linear(8, 1),
            nn.Sigmoid()  # Output between 0 and 1
        )
        
        # Initialize weights properly
        for m in self.confidence_net.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=1.0)  # Better initialization
                nn.init.zeros_(m.bias)
    
    def forward(self, state, action, goals, prev_belief=None):
        """
        Update belief over goals based on observed state and action.
        
        Args:
            state: Current state (end effector position)
            action: Human action
            goals: List of potential goal positions
            prev_belief: Previous belief distribution
            
        Returns:
            Updated belief over goals
        """
        # Ensure all input tensors are on the same device
        device = state.device
        
        if prev_belief is None:
            belief = self.prior.clone().to(device)
        else:
            belief = prev_belief.clone().to(device)
        
        # Constrain parameters to valid ranges to avoid numerical issues
        constrained_beta = torch.sigmoid(self.beta) * 10.0 + 1.0  # Range [1.0, 11.0]
        constrained_angular_weight = torch.sigmoid(self.angular_weight)
        constrained_distance_weight = 1.0 - constrained_angular_weight  # Ensure weights sum to 1
        constrained_smooth_alpha = torch.sigmoid(self.smooth_alpha) * 0.8  # Range [0, 0.8]
        
        # Calculate action magnitude for confidence computation
        action_magnitude = torch.norm(action)
        
        # Calculate likelihoods for each goal
        likelihoods = torch.zeros(self.num_goals, device=device)
        
        for i in range(self.num_goals):
            goal = goals[i].to(device)
            
            # Calculate optimal action toward this goal
            goal_direction = goal - state
            goal_direction_norm = torch.norm(goal_direction) + 1e-8  # Avoid division by zero
            
            # Calculate confidence based on distance and action magnitude
            distance_to_goal = goal_direction_norm
            confidence_input = torch.tensor([
                distance_to_goal.item(), 
                action_magnitude.item(),
                distance_to_goal.item() * action_magnitude.item(),
                1.0 / (distance_to_goal.item() + 0.1)  # Inverse distance feature
            ], device=device).float()
            
            # Use confidence network
            confidence = self.confidence_net(confidence_input).squeeze()
            
            # Normalize goal direction
            optimal_action = goal_direction / goal_direction_norm
            
            # Calculate cost components (deviation from optimal action)
            angular_cost = self._angular_deviation(action, optimal_action)
            distance_cost = self._distance_deviation(action, optimal_action, goal_direction_norm)
            
            # Use constrained weights
            total_cost = (constrained_angular_weight * angular_cost + 
                         constrained_distance_weight * distance_cost)
            
            # Scale by confidence - more confidence means cost matters more
            scaled_cost = total_cost * confidence
            
            # Calculate likelihood using noisy-rational model with constrained beta
            # Clamp the cost to avoid numerical issues
            clamped_cost = torch.clamp(-constrained_beta * scaled_cost, min=-20.0, max=20.0)
            likelihoods[i] = torch.exp(clamped_cost)
        
        # Apply softmax for better numerical stability
        likelihoods = F.softmax(likelihoods, dim=0)
        
        # Bayesian update
        posterior = belief * likelihoods
        posterior_sum = torch.sum(posterior)
        
        if posterior_sum > 1e-10:
            posterior = posterior / posterior_sum
        else:
            # If update fails, use a weighted combination instead
            posterior = 0.5 * belief + 0.5 * likelihoods
        
        # Apply temporal smoothing with constrained alpha
        if prev_belief is not None:
            # Calculate adaptive smoothing factor based on action magnitude
            # Less smoothing (lower alpha) when action is strong
            adaptive_alpha = constrained_smooth_alpha * (1.0 - torch.tanh(action_magnitude * 0.5))
            posterior = (1.0 - adaptive_alpha) * posterior + adaptive_alpha * prev_belief
        
        # Check for NaN values and replace with prior if needed
        if torch.isnan(posterior).any() or torch.isinf(posterior).any():
            posterior = self.prior.clone().to(device)
        
        return posterior
    
    def _angular_deviation(self, actual, optimal):
        """Calculate angular deviation between actual and optimal action."""
        actual_norm = torch.norm(actual) + 1e-8  # Avoid division by zero
        optimal_norm = torch.norm(optimal) + 1e-8
        
        # Normalize vectors
        actual_norm_vec = actual / actual_norm
        optimal_norm_vec = optimal / optimal_norm
        
        # Calculate cosine similarity
        cos_sim = torch.dot(actual_norm_vec, optimal_norm_vec)
        # Clamp to avoid numerical errors
        cos_sim = torch.clamp(cos_sim, -1.0 + 1e-6, 1.0 - 1e-6)
        
        # Convert to angle
        angle = torch.acos(cos_sim)
        
        # Normalize to [0, 1]
        return angle / np.pi
    
    def _distance_deviation(self, actual, optimal, distance_to_goal):
        """Calculate deviation in action magnitude."""
        # Optimal magnitude decreases as we get closer to the goal
        # IMPROVED: Better scaling of optimal magnitude
        optimal_magnitude = torch.clamp(distance_to_goal * 2.0, 0.1, 1.0)
        
        # Calculate actual magnitude
        actual_magnitude = torch.norm(actual)
        
        # Normalized deviation - IMPROVED
        # Scale by factor relative to distance
        scale_factor = 1.0 / (optimal_magnitude + 0.1)
        deviation = torch.abs(actual_magnitude - optimal_magnitude) * scale_factor
        
        return torch.clamp(deviation, 0.0, 1.0)
    
    def reset(self):
        """Reset beliefs to prior."""
        self.beliefs = [self.prior.clone()]
    
    def update_belief(self, state, action, goals):
        """Update belief and store history."""
        if self.beliefs is None:
            self.reset()
        
        # Convert inputs to torch tensors if they're not already
        if not isinstance(state, torch.Tensor):
            state = torch.tensor(state, dtype=torch.float32)
        if not isinstance(action, torch.Tensor):
            action = torch.tensor(action, dtype=torch.float32)
        if not isinstance(goals[0], torch.Tensor):
            goals = [torch.tensor(g, dtype=torch.float32) for g in goals]
        
        # Ensure all tensors are on the same device
        device = self.beta.device  # Use a parameter's device
        state = state.to(device)
        action = action.to(device)
        goals = [g.to(device) for g in goals]
        
        # Update belief
        new_belief = self.forward(state, action, goals, self.beliefs[-1].to(device))
        self.beliefs.append(new_belief)
        
        return new_belief


# IMPROVED: Data generator for pretraining the Bayesian inference module
class TrajectoryGenerator:
    def __init__(self, num_goals=3, noise_levels=[0.1, 0.2, 0.3, 0.4, 0.5]):
        self.num_goals = num_goals
        self.noise_levels = noise_levels
    
    def generate_trajectories(self, num_trajectories, trajectory_length):
        """Generate trajectories for pretraining the Bayesian inference module."""
        trajectories = []
        
        # Generate trajectories with varying noise levels for better generalization
        noise_batch_size = num_trajectories // len(self.noise_levels)
        
        for noise_idx, noise_level in enumerate(self.noise_levels):
            for _ in range(noise_batch_size):
                # Generate random goals with increased spread
                goals = []
                angles = np.linspace(0, 2*np.pi, self.num_goals, endpoint=False)
                angles += np.random.uniform(0, 2*np.pi/self.num_goals)
                
                for i in range(self.num_goals):
                    # Place goals in a circular pattern with randomness
                    radius = np.random.uniform(0.12, 0.18)
                    goal_x = radius * np.cos(angles[i])
                    goal_y = radius * np.sin(angles[i])
                    goals.append(np.array([goal_x, goal_y]))
                
                # Choose true goal
                true_goal_idx = np.random.randint(0, self.num_goals)
                true_goal = goals[true_goal_idx]
                
                # Generate random starting position
                pos = np.random.uniform(-0.15, 0.15, size=2)
                
                # Generate trajectory
                states = []
                actions = []
                
                # Add waypoints for more realistic paths
                num_waypoints = np.random.randint(1, 3)
                waypoints = [true_goal]
                if num_waypoints > 1:
                    for _ in range(num_waypoints - 1):
                        # Random waypoint near the path to the goal
                        wp = pos + np.random.uniform(0.3, 0.7) * (true_goal - pos)
                        wp += np.random.normal(0, 0.05, size=2)  # Add small randomness
                        waypoints.insert(0, wp)  # Insert at beginning
                
                current_waypoint_idx = 0
                
                for _ in range(trajectory_length):
                    # Get current target waypoint
                    if current_waypoint_idx < len(waypoints):
                        target = waypoints[current_waypoint_idx]
                    else:
                        target = true_goal
                    
                    # Generate optimal action toward current waypoint
                    goal_dir = target - pos
                    goal_dist = np.linalg.norm(goal_dir)
                    
                    if goal_dist > 0:
                        optimal_action = goal_dir / goal_dist
                        
                        # Scale action magnitude based on distance
                        magnitude = min(1.0, goal_dist * 2)
                        optimal_action *= magnitude
                    else:
                        optimal_action = np.zeros_like(goal_dir)
                        current_waypoint_idx += 1
                    
                    # Add noise to simulate human control
                    # Add directional bias to noise
                    bias_dir = np.random.normal(0, 1, size=2)
                    bias_dir = bias_dir / np.linalg.norm(bias_dir) if np.linalg.norm(bias_dir) > 0 else bias_dir
                    
                    # Blend random noise with directional bias
                    noise = np.random.normal(0, noise_level, size=2) + bias_dir * noise_level * 0.5
                    
                    # Action has a small probability of being completely random
                    if np.random.random() < 0.05:  # 5% chance of random action
                        action = np.random.uniform(-1, 1, size=2)
                    else:
                        action = optimal_action + noise
                    
                    # Clip to action space
                    action = np.clip(action, -1, 1)
                    
                    # Store state and action
                    states.append(pos)
                    actions.append(action)
                    
                    # Update position with some damping
                    new_pos = pos + 0.05 * action
                    pos = 0.9 * new_pos + 0.1 * pos  # Add inertia for more realistic movement
                    
                    # Clip position to domain bounds
                    pos = np.clip(pos, -0.3, 0.3)
                    
                    # Check if we've reached the current waypoint
                    if current_waypoint_idx < len(waypoints):
                        if np.linalg.norm(pos - waypoints[current_waypoint_idx]) < 0.05:
                            current_waypoint_idx += 1
                
                trajectories.append({
                    'states': states,
                    'actions': actions,
                    'goals': goals,
                    'true_goal_idx': true_goal_idx
                })
        
        # Fill in any remaining trajectories due to integer division
        remaining = num_trajectories - (noise_batch_size * len(self.noise_levels))
        if remaining > 0:
            # Use medium noise level for remaining trajectories
            noise_level = self.noise_levels[len(self.noise_levels) // 2]
            
            # Generate additional trajectories with medium noise
            for _ in range(remaining):
                # Same trajectory generation logic as above...
                # Generate random goals with increased spread
                goals = []
                angles = np.linspace(0, 2*np.pi, self.num_goals, endpoint=False)
                angles += np.random.uniform(0, 2*np.pi/self.num_goals)
                
                for i in range(self.num_goals):
                    radius = np.random.uniform(0.12, 0.18)
                    goal_x = radius * np.cos(angles[i])
                    goal_y = radius * np.sin(angles[i])
                    goals.append(np.array([goal_x, goal_y]))
                
                true_goal_idx = np.random.randint(0, self.num_goals)
                true_goal = goals[true_goal_idx]
                pos = np.random.uniform(-0.15, 0.15, size=2)
                
                states = []
                actions = []
                
                # Add waypoints
                num_waypoints = np.random.randint(1, 3)
                waypoints = [true_goal]
                if num_waypoints > 1:
                    for _ in range(num_waypoints - 1):
                        wp = pos + np.random.uniform(0.3, 0.7) * (true_goal - pos)
                        wp += np.random.normal(0, 0.05, size=2)
                        waypoints.insert(0, wp)
                
                current_waypoint_idx = 0
                
                for _ in range(trajectory_length):
                    if current_waypoint_idx < len(waypoints):
                        target = waypoints[current_waypoint_idx]
                    else:
                        target = true_goal
                    
                    goal_dir = target - pos
                    goal_dist = np.linalg.norm(goal_dir)
                    
                    if goal_dist > 0:
                        optimal_action = goal_dir / goal_dist
                        magnitude = min(1.0, goal_dist * 2)
                        optimal_action *= magnitude
                    else:
                        optimal_action = np.zeros_like(goal_dir)
                        current_waypoint_idx += 1
                    
                    bias_dir = np.random.normal(0, 1, size=2)
                    bias_dir = bias_dir / np.linalg.norm(bias_dir) if np.linalg.norm(bias_dir) > 0 else bias_dir
                    
                    noise = np.random.normal(0, noise_level, size=2) + bias_dir * noise_level * 0.5
                    
                    if np.random.random() < 0.05:
                        action = np.random.uniform(-1, 1, size=2)
                    else:
                        action = optimal_action + noise
                    
                    action = np.clip(action, -1, 1)
                    
                    states.append(pos)
                    actions.append(action)
                    
                    new_pos = pos + 0.05 * action
                    pos = 0.9 * new_pos + 0.1 * pos
                    pos = np.clip(pos, -0.3, 0.3)
                    
                    if current_waypoint_idx < len(waypoints):
                        if np.linalg.norm(pos - waypoints[current_waypoint_idx]) < 0.05:
                            current_waypoint_idx += 1
                
                trajectories.append({
                    'states': states,
                    'actions': actions,
                    'goals': goals,
                    'true_goal_idx': true_goal_idx
                })
            
        return trajectories


# Improved pretraining function with curriculum learning
def pretrain_belief_module(module, num_trajectories=1000, trajectory_length=50, 
                           batch_size=128, num_epochs=100, lr=1e-3):
    """Pretrain the Bayesian inference module."""
    # Generate pretraining data
    print("Generating pretraining data...")
    generator = TrajectoryGenerator(num_goals=module.num_goals, 
                                   noise_levels=[0.1, 0.2, 0.3, 0.4])
    trajectories = generator.generate_trajectories(num_trajectories, trajectory_length)
    
    # Move module to device
    device = DEVICE  # Use the global device
    module = module.to(device)
    
    # Prepare optimizer with weight decay to prevent parameter explosion
    optimizer = optim.Adam(module.parameters(), lr=lr, weight_decay=1e-4)
    
    # Learning rate scheduler to reduce learning rate over time
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5, verbose=True)
    
    # Training loop with curriculum learning
    print(f"Pretraining Bayesian inference module for {num_epochs} epochs...")
    
    # Split trajectories into difficulty levels based on noise
    easy_trajectories = trajectories[:num_trajectories//3]
    medium_trajectories = trajectories[num_trajectories//3:2*num_trajectories//3]
    hard_trajectories = trajectories[2*num_trajectories//3:]
    
    # Curriculum schedule
    curriculum_trajectories = [
        easy_trajectories,  # First third of epochs: train on easy trajectories
        easy_trajectories + medium_trajectories,  # Second third: add medium 
        trajectories  # Final third: all trajectories
    ]
    
    # Define epoch ranges for curriculum stages
    epoch_stage1 = num_epochs // 3
    epoch_stage2 = 2 * num_epochs // 3
    
    best_loss = float('inf')
    best_state_dict = None
    
    for epoch in range(num_epochs):
        # Select trajectories based on curriculum stage
        if epoch < epoch_stage1:
            current_trajectories = curriculum_trajectories[0]
            stage_name = "easy"
        elif epoch < epoch_stage2:
            current_trajectories = curriculum_trajectories[1]
            stage_name = "medium"
        else:
            current_trajectories = curriculum_trajectories[2]
            stage_name = "hard"
        
        # Shuffle trajectories
        random.shuffle(current_trajectories)
        
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        valid_losses = 0  # To avoid division by zero if all losses are invalid
        
        # Process each trajectory
        for traj in tqdm(current_trajectories, desc=f"Epoch {epoch+1}/{num_epochs} ({stage_name})", leave=False):
            # Reset belief to prior
            belief = module.prior.clone().to(device)
            
            # Convert goals to tensors
            goals = [torch.tensor(g, dtype=torch.float32, device=device) for g in traj['goals']]
            
            # Process each step in the trajectory
            for t in range(len(traj['states'])):
                state = torch.tensor(traj['states'][t], dtype=torch.float32, device=device)
                action = torch.tensor(traj['actions'][t], dtype=torch.float32, device=device)
                
                # Update belief using the module
                belief = module.forward(state, action, goals, belief)
                
                # Skip initial steps for loss calculation (let beliefs stabilize)
                if t < 5:
                    continue
                
                # Occasionally introduce random goals to improve robustness
                true_goal_idx = traj['true_goal_idx']
                if random.random() < 0.1:  # 10% of the time
                    # Randomly swap the true goal
                    true_goal_idx = random.randint(0, module.num_goals-1)
                
                # Ensure belief is valid
                if torch.isnan(belief).any() or torch.isinf(belief).any():
                    belief = module.prior.clone().to(device)
                    continue
                
                # Smooth target distribution (label smoothing)
                target = torch.ones_like(belief) * 0.01 / (module.num_goals - 1)
                target[true_goal_idx] = 0.99
                
                # Cross entropy loss
                epsilon = 1e-10  # Small epsilon to avoid log(0)
                loss = -torch.sum(target * torch.log(belief + epsilon))
                
                # Skip backpropagation if loss is NaN or Inf
                if torch.isnan(loss) or torch.isinf(loss):
                    continue
                
                # Backpropagate
                optimizer.zero_grad()
                loss.backward()
                
                # Gradient clipping for stability
                torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm=1.0)
                
                optimizer.step()
                
                total_loss += loss.item()
                valid_losses += 1
                
                # Track accuracy
                pred_goal_idx = torch.argmax(belief).item()
                correct_predictions += (pred_goal_idx == true_goal_idx)
                total_predictions += 1
                
                # Update belief for next step (detach from computation graph)
                belief = belief.detach()
        
        # Update learning rate based on validation loss
        avg_loss = total_loss / valid_losses if valid_losses > 0 else float('inf')
        scheduler.step(avg_loss)
        
        # Save best model
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_state_dict = module.state_dict().copy()
        
        # Print progress
        accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    
    # Load best model
    if best_state_dict:
        module.load_state_dict(best_state_dict)
    
    print("Pretraining complete!")
    return module


# Actor-Critic Networks for PPO - IMPROVED
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256, belief_dim=None):
        super(ActorCritic, self).__init__()
        
        # Feature extraction layers
        self.feature_network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),  # Added normalization for stability
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),  # Added normalization for stability
            nn.ReLU()
        )
        
        # Separate belief processing if provided
        self.belief_network = None
        if belief_dim is not None:
            self.belief_network = nn.Sequential(
                nn.Linear(belief_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),  # Added normalization
                nn.ReLU()
            )
            
            # Fusion layer for state and belief features
            self.fusion_network = nn.Sequential(
                nn.Linear(hidden_dim + hidden_dim // 2, hidden_dim),
                nn.LayerNorm(hidden_dim),  # Added normalization
                nn.ReLU()
            )
        
        # Actor network (policy)
        self.actor_mean = nn.Linear(hidden_dim, action_dim)
        self.actor_log_std = nn.Parameter(torch.zeros(action_dim))
        
        # Critic network (value function)
        self.critic = nn.Linear(hidden_dim, 1)
        
        # Initialize weights
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    
    def forward(self, state, belief=None):
        # Extract features from state
        features = self.feature_network(state)
        
        # Process belief and fuse with state features if available
        if belief is not None and self.belief_network is not None:
            belief_features = self.belief_network(belief)
            features = torch.cat([features, belief_features], dim=-1)
            features = self.fusion_network(features)
        
        # Actor: Get action distribution
        action_mean = self.actor_mean(features)
        
        # Variable action std based on training progress
        action_std = torch.exp(torch.clamp(self.actor_log_std, min=-20, max=2))
        
        # Critic: Get state value
        value = self.critic(features)
        
        return action_mean, action_std, value
    
    def get_action(self, state, belief=None, deterministic=False):
        # Move data to the model's device
        device = next(self.parameters()).device
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        if belief is not None:
            belief = torch.FloatTensor(belief).unsqueeze(0).to(device)
        
        with torch.no_grad():
            action_mean, action_std, _ = self.forward(state, belief)
            
            if deterministic:
                action = action_mean
            else:
                normal = Normal(action_mean, action_std)
                action = normal.sample()
                
            return action.squeeze(0).cpu().numpy()
    
    def evaluate_actions(self, states, actions, beliefs=None):
        action_mean, action_std, values = self.forward(states, beliefs)
        
        # Create normal distribution
        dist = Normal(action_mean, action_std)
        
        # Get log probabilities
        log_probs = dist.log_prob(actions).sum(dim=-1, keepdim=True)
        
        # Get entropy
        entropy = dist.entropy().sum(dim=-1, keepdim=True)
        
        return log_probs, entropy, values


# IMPROVED PPO Agent
class PPOAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=256, num_goals=3,
                 lr=3e-4, gamma=0.99, gae_lambda=0.95, clip_epsilon=0.2,
                 value_coef=0.5, entropy_coef=0.01, target_kl=0.01):
        
        # Initialize actor-critic
        self.actor_critic = ActorCritic(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dim=hidden_dim,
            belief_dim=num_goals
        )
        
        # Initialize Bayesian inference module
        self.bayesian_module = BayesianInferenceModule(
            num_goals=num_goals,
            beta=BETA,
            angular_weight=ANGULAR_WEIGHT,
            distance_weight=DISTANCE_WEIGHT,
            smooth_alpha=BELIEF_SMOOTH_ALPHA
        )
        
        # Move models to the appropriate device
        self.device = DEVICE
        self.actor_critic.to(self.device)
        self.bayesian_module.to(self.device)
        
        # Initialize optimizer with L2 regularization
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=lr, weight_decay=1e-5)
        
        # Add learning rate scheduler
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=1000, eta_min=lr/10)
        
        # Store hyperparameters
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.target_kl = target_kl
        
        # Memory buffer for experience
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.beliefs = []
        self.dones = []
        
        # Tracking variables
        self.steps_since_update = 0
        self.epochs_trained = 0
        self.update_count = 0
    
    def select_action(self, state, goals, deterministic=False):
        """Select action based on current state and belief."""
        # Get end effector position
        ee_pos = state[-4:-2]  # Position is at the end of the observation before goal info
        
        # Extract goals from state if necessary
        if goals is None:
            # Assume the state includes goal information as per the environment wrapper
            goals_data = state[state.shape[0] - 2*NUM_GOALS:]
            goals = [goals_data[i:i+2] for i in range(0, len(goals_data), 2)]
        
        # Initialize belief if needed
        if not hasattr(self, 'current_belief') or self.current_belief is None:
            self.current_belief = self.bayesian_module.prior.clone().to(self.device)
        
        # Handle case where we don't have previous action yet (e.g., at start of episode)
        if not hasattr(self, 'prev_action') or self.prev_action is None:
            # Use a default no-movement action for the first update
            self.prev_action = torch.zeros(2, device=self.device)
        
        # Update belief based on observed state and previous action
        with torch.no_grad():  # No need to track gradients for inference
            updated_belief = self.bayesian_module.forward(
                state=torch.tensor(ee_pos, dtype=torch.float32, device=self.device),
                action=self.prev_action,
                goals=[torch.tensor(g, dtype=torch.float32, device=self.device) for g in goals],
                prev_belief=self.current_belief
            )
            
            # Check for NaN values in updated belief
            if torch.isnan(updated_belief).any() or torch.isinf(updated_belief).any():
                updated_belief = self.bayesian_module.prior.clone().to(self.device)
            
            self.current_belief = updated_belief
        
        # Select action using actor network
        action = self.actor_critic.get_action(
            state=state,
            belief=self.current_belief.cpu().numpy(),
            deterministic=deterministic
        )
        
        # Store action for next belief update
        self.prev_action = torch.tensor(action, device=self.device)
        
        # Check for NaN actions
        if np.isnan(action).any() or np.isinf(action).any():
            action = np.zeros_like(action)
            self.prev_action = torch.zeros(2, device=self.device)
        
        return action, self.current_belief.cpu().numpy()
    
    def remember(self, state, action, reward, value, log_prob, belief, done):
        """Store experience in memory."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        self.beliefs.append(belief)
        self.dones.append(done)
        
        # Track steps since last update
        self.steps_since_update += 1
    
    def compute_gae(self, next_value):
        """Compute Generalized Advantage Estimation."""
        values = self.values + [next_value]
        gae = 0
        returns = []
        
        # GAE calculation
        for step in reversed(range(len(self.rewards))):
            delta = (self.rewards[step] + 
                     self.gamma * values[step + 1] * (1 - self.dones[step]) - 
                     values[step])
            gae = delta + self.gamma * self.gae_lambda * (1 - self.dones[step]) * gae
            returns.insert(0, gae + values[step])
        
        return returns
    
    def update(self):
        """Update policy using PPO algorithm."""
        # Skip update if not enough data
        if len(self.states) < BATCH_SIZE:
            return 0.0
            
        # Get next value for GAE
        with torch.no_grad():
            if len(self.states) > 0:
                state = torch.FloatTensor(self.states[-1]).unsqueeze(0).to(self.device)
                belief = torch.FloatTensor(self.beliefs[-1]).unsqueeze(0).to(self.device)
                _, _, next_value = self.actor_critic(state, belief)
                next_value = next_value.squeeze().cpu().item()
            else:
                next_value = 0
        
        # Compute returns using GAE
        returns = self.compute_gae(next_value)
        
        # Convert lists to tensors
        states = torch.FloatTensor(np.array(self.states)).to(self.device)
        actions = torch.FloatTensor(np.array(self.actions)).to(self.device)
        returns = torch.FloatTensor(returns).unsqueeze(1).to(self.device)
        old_log_probs = torch.FloatTensor(np.array(self.log_probs)).unsqueeze(1).to(self.device)
        beliefs = torch.FloatTensor(np.array(self.beliefs)).to(self.device)
        values = torch.FloatTensor(np.array(self.values)).unsqueeze(1).to(self.device)
        
        # Calculate advantages
        advantages = returns - values
        
        # Normalize advantages
        if len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        total_loss = 0
        kl_divs = []
        
        # Mini-batch training
        batch_size = min(BATCH_SIZE, len(self.states))
        indices = np.arange(len(self.states))
        
        # Calculate adaptive clip epsilon based on training progress
        # Start with higher clip (more exploration) and reduce over time
        adaptive_clip = max(0.1, self.clip_epsilon * (1.0 - self.epochs_trained / 2000.0))
        
        # IMPROVED: Better entropy scheduling
        # Start with higher entropy coefficient and decay over time
        adaptive_entropy_coef = max(0.001, self.entropy_coef * (1.0 - self.epochs_trained / 1500.0))
        
        # Randomize and shuffle batches
        np.random.shuffle(indices)
        batches = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
        
        for _ in range(PPO_EPOCHS):
            # Shuffle batches
            np.random.shuffle(batches)
            
            for batch_indices in batches:
                # Evaluate actions again
                new_log_probs, entropy, values = self.actor_critic.evaluate_actions(
                    states[batch_indices],
                    actions[batch_indices],
                    beliefs[batch_indices]
                )
                
                # Compute ratio
                ratio = torch.exp(new_log_probs - old_log_probs[batch_indices])
                
                # Clamp ratio to avoid numerical instability
                ratio = torch.clamp(ratio, 0.0, 10.0)
                
                # Compute surrogate losses
                surrogate1 = ratio * advantages[batch_indices]
                surrogate2 = torch.clamp(ratio, 1.0 - adaptive_clip, 1.0 + adaptive_clip) * advantages[batch_indices]
                
                # Value loss
                value_loss = F.mse_loss(values, returns[batch_indices])
                
                # Policy loss
                policy_loss = -torch.min(surrogate1, surrogate2).mean()
                
                # Entropy bonus
                entropy_loss = -entropy.mean()
                
                # Total loss with value weighting that starts lower and increases over time
                adaptive_value_coef = min(1.0, self.value_coef * (1.0 + self.epochs_trained / 1000.0))
                loss = policy_loss + adaptive_value_coef * value_loss + adaptive_entropy_coef * entropy_loss
                
                # Update parameters
                self.optimizer.zero_grad()
                loss.backward()
                # Gradient clipping for stability
                nn.utils.clip_grad_norm_(self.actor_critic.parameters(), max_norm=0.5)
                self.optimizer.step()
                
                total_loss += loss.item()
                
                # Check KL divergence for early stopping
                with torch.no_grad():
                    approx_kl = ((old_log_probs[batch_indices] - new_log_probs).exp() - 1 - 
                                (old_log_probs[batch_indices] - new_log_probs)).mean().item()
                    kl_divs.append(approx_kl)
                
                if approx_kl > self.target_kl:
                    break
            
            # Early stopping based on average KL divergence
            if np.mean(kl_divs) > self.target_kl:
                break
        
        # Update learning rate periodically
        self.scheduler.step()
        
        # Increment epoch count
        self.epochs_trained += 1
        self.update_count += 1
        
        # Reset memory buffer
        self.clear_memory()
        
        # Return average loss
        return total_loss / (len(batches) * PPO_EPOCHS)
    
    def clear_memory(self):
        """Clear memory buffer after updates."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.beliefs = []
        self.dones = []
        self.steps_since_update = 0
    
    def reset(self):
        """Reset agent state between episodes."""
        self.current_belief = None
        self.prev_action = None
        self.bayesian_module.reset()
    
    def save(self, path):
        """Save model parameters."""
        torch.save({
            'actor_critic_state_dict': self.actor_critic.state_dict(),
            'bayesian_module_state_dict': self.bayesian_module.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'epochs_trained': self.epochs_trained
        }, path)
    
    def load(self, path):
        """Load model parameters."""
        checkpoint = torch.load(path, map_location=self.device)
        self.actor_critic.load_state_dict(checkpoint['actor_critic_state_dict'])
        self.bayesian_module.load_state_dict(checkpoint['bayesian_module_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.epochs_trained = checkpoint.get('epochs_trained', 0)


# IMPROVED: Performance tracking with moving averages
class PerformanceTracker:
    def __init__(self):
        self.episode_rewards = []
        self.episode_lengths = []
        self.success_rates = []
        self.collision_rates = []
        self.belief_accuracy = []
        self.assistance_levels = []
        self.losses = []
        self.avg_returns = []
        
        # For visualization
        self.frames = []
    
    def add_episode_metrics(self, episode_reward, episode_length, success, collision, belief_accuracy, avg_assistance):
        self.episode_rewards.append(episode_reward)
        self.episode_lengths.append(episode_length)
        self.success_rates.append(1 if success else 0)
        self.collision_rates.append(1 if collision else 0)
        self.belief_accuracy.append(belief_accuracy)
        self.assistance_levels.append(avg_assistance)
    
    def add_training_metrics(self, loss, avg_return):
        self.losses.append(loss)
        self.avg_returns.append(avg_return)
    
    def add_frame(self, frame):
        self.frames.append(frame)
    
    def plot_learning_curves(self, window=15):
        """Plot learning curves with moving averages."""
        sns.set(style="darkgrid")
        
        # Create a figure with subplots
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        
        # Episode rewards
        self._plot_smoothed_curve(axes[0, 0], self.episode_rewards, window, 
                                 "Episode Rewards", "Episode", "Reward")
        
        # Success rate
        self._plot_smoothed_curve(axes[0, 1], self.success_rates, window,
                                 "Success Rate", "Episode", "Success Rate", 
                                 is_rate=True)
        
        # Collision rate
        self._plot_smoothed_curve(axes[0, 2], self.collision_rates, window,
                                 "Collision Rate", "Episode", "Collision Rate", 
                                 is_rate=True)
        
        # Belief accuracy
        self._plot_smoothed_curve(axes[1, 0], self.belief_accuracy, window,
                                 "Belief Accuracy", "Episode", "Accuracy")
        
        # Assistance level
        self._plot_smoothed_curve(axes[1, 1], self.assistance_levels, window,
                                 "Average Assistance Level", "Episode", "Assistance γ")
        
        # Episode length
        self._plot_smoothed_curve(axes[1, 2], self.episode_lengths, window,
                                 "Episode Length", "Episode", "Steps")
        
        # Adjust layout
        plt.tight_layout()
        plt.savefig("results/learning_curves.png", dpi=300)
        plt.close()
        
        # Plot additional training metrics
        self._plot_training_metrics(window)
    
    def _smooth_data(self, data, window):
        """Apply moving average smoothing to data."""
        if len(data) < window:
            return np.array(data)
        
        smoothed = np.convolve(data, np.ones(window)/window, mode='valid')
        return smoothed
    
    def _plot_smoothed_curve(self, ax, data, window, title, xlabel, ylabel, is_rate=False):
        """Helper to plot smoothed curve with confidence interval."""
        if not data:
            return
            
        x = np.arange(len(data))
        y = np.array(data)
        
        # Plot raw data with low alpha
        ax.plot(x, y, alpha=0.2, color='blue', label='Raw')
        
        # Apply smoothing for moving average if sufficient data
        if len(x) > window:
            smoothed_y = self._smooth_data(y, window)
            smoothed_x = np.arange(window-1, len(x))
            
            # Plot moving average
            ax.plot(smoothed_x, smoothed_y, alpha=1.0, color='blue', linewidth=2, label=f'Moving Avg (window={window})')
            
            # Compute rolling std for confidence interval
            rolling_std = [np.std(y[max(0, i-window):i+1]) for i in range(window-1, len(y))]
            rolling_std = np.array(rolling_std)
            
            # Plot confidence intervals
            ax.fill_between(smoothed_x, 
                           np.maximum(0, smoothed_y - rolling_std), 
                           np.minimum(1 if is_rate else float('inf'), smoothed_y + rolling_std), 
                           alpha=0.2, color='blue')
        
        # Add title and labels
        ax.set_title(title, fontsize=14)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        
        # Set y-axis limits for rate plots
        if is_rate:
            ax.set_ylim([-0.05, 1.05])
        
        # Add legend if we have both raw and smoothed data
        if len(x) > window:
            ax.legend(loc='best')
    
    def _plot_training_metrics(self, window=15):
        """Plot additional training metrics."""
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        
        # Losses
        self._plot_smoothed_curve(axes[0], self.losses, window,
                                 "Training Loss", "Update", "Loss")
        
        # Average returns
        self._plot_smoothed_curve(axes[1], self.avg_returns, window,
                                 "Average Returns", "Update", "Return")
        
        plt.tight_layout()
        plt.savefig("results/training_metrics.png", dpi=300)
        plt.close()
    
    def plot_ablation_study(self, results_dict):
        """Plot ablation study results."""
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        
        # Process data for plotting
        methods = list(results_dict.keys())
        success_rates = [results_dict[m]['success_rate'] for m in methods]
        episode_rewards = [results_dict[m]['episode_reward'] for m in methods]
        completion_times = [results_dict[m]['completion_time'] for m in methods]
        
        # Plot success rates
        axes[0].bar(methods, success_rates, color='skyblue')
        axes[0].set_title('Success Rate by Method', fontsize=14)
        axes[0].set_ylabel('Success Rate', fontsize=12)
        axes[0].set_ylim([0, 1.0])
        
        # Plot episode rewards
        axes[1].bar(methods, episode_rewards, color='lightgreen')
        axes[1].set_title('Average Episode Reward by Method', fontsize=14)
        axes[1].set_ylabel('Reward', fontsize=12)
        
        # Plot completion times
        axes[2].bar(methods, completion_times, color='salmon')
        axes[2].set_title('Average Completion Time by Method', fontsize=14)
        axes[2].set_ylabel('Steps', fontsize=12)
        
        # Add values on top of bars
        for ax, data in zip(axes, [success_rates, episode_rewards, completion_times]):
            for i, v in enumerate(data):
                ax.text(i, v, f'{v:.2f}', ha='center', va='bottom', fontsize=10)
        
        plt.tight_layout()
        plt.savefig("results/ablation_study.png", dpi=300)
        plt.close()
    
    def save_video(self, filename="episode", fps=30):
        """Save frames as a video."""
        if not self.frames:
            return
            
        # Create writer
        frames = [frame for frame in self.frames]
        
        # Check if frames are valid
        if not frames or frames[0] is None:
            print("No valid frames to save.")
            return
            
        height, width, _ = frames[0].shape
        
        # Create video writer
        video_path = f"results/{filename}.mp4"
        
        # Save using matplotlib animation
        fig = plt.figure(figsize=(width/100, height/100), dpi=100)
        ax = fig.add_subplot(111)
        ax.set_axis_off()
        
        images = [[ax.imshow(frame)] for frame in frames]
        anim = animation.ArtistAnimation(fig, images, interval=1000/fps, blit=True)
        
        try:
            # Use Pillow writer instead of ffmpeg
            anim.save(video_path, writer='pillow', fps=fps)
            print(f"Video saved to {video_path}")
        except Exception as e:
            print(f"Error saving video: {e}")
            # Try to save individual frames instead
            try:
                os.makedirs(f"results/{filename}_frames", exist_ok=True)
                for i, frame in enumerate(frames):
                    plt.imsave(f"results/{filename}_frames/frame_{i:04d}.png", frame)
                print(f"Saved {len(frames)} individual frames to results/{filename}_frames folder")
            except Exception as e2:
                print(f"Error saving frames: {e2}")
        
        plt.close()
    
    def clear_frames(self):
        """Clear saved frames to free memory."""
        self.frames = []


# IMPROVED: Training loop with better monitoring and stability
def train(env, agent, num_episodes=5000, max_steps=1000, eval_interval=200, visualize=False, visualize_interval=2000):
    """Train the agent using PPO."""
    tracker = PerformanceTracker()
    total_steps = 0
    update_counter = 0
    best_success_rate = 0
    best_reward = -float('inf')
    episodes_without_improvement = 0
    early_stop_patience = EARLY_STOP_PATIENCE  # Increased patience
    
    # Initialize evaluation metrics
    eval_metrics = {
        'success_rate': [],
        'episode_length': [],
        'episode_reward': [],
        'belief_accuracy': [],
        'collision_rate': []
    }
    
    # For visualization
    if visualize:
        print(f"Training with visualization every {visualize_interval} steps")
    
    # Training loop
    for episode in tqdm(range(1, num_episodes + 1), desc="Training"):
        state, _ = env.reset()
        agent.reset()
        
        episode_reward = 0
        episode_length = 0
        episode_beliefs = []
        episode_true_goals = []
        episode_assistance = []
        
        # Record frames for visualization
        record_video = visualize and (episode % eval_interval == 0 or 
                                     (total_steps % visualize_interval < max_steps))
        
        # Extract goals from the environment
        goals = []
        for i in range(NUM_GOALS):
            goal_idx = state.shape[0] - 2 * NUM_GOALS + 2 * i
            goals.append(state[goal_idx:goal_idx + 2])
        
        # Episode loop
        for step in range(max_steps):
            # Render for visualization
            if record_video:
                frame = env.render()
                if frame is not None:
                    tracker.add_frame(frame)
            
            # Select action
            action, belief = agent.select_action(state, goals)
            
            # Check for NaN values in action
            if np.isnan(action).any() or np.isinf(action).any():
                print(f"Warning: NaN or Inf detected in action! Using zero action instead.")
                action = np.zeros_like(action)
            
            # Get action log probability and value
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
                belief_tensor = torch.FloatTensor(belief).unsqueeze(0).to(agent.device)
                action_tensor = torch.FloatTensor(action).unsqueeze(0).to(agent.device)
                
                log_prob, _, value = agent.actor_critic.evaluate_actions(
                    state_tensor, action_tensor, belief_tensor
                )
                log_prob = log_prob.squeeze().cpu().item()
                value = value.squeeze().cpu().item()
            
            # Take action in environment
            next_state, reward, done, truncated, info = env.step(action)
            
            # Store experience
            agent.remember(state, action, reward, value, log_prob, belief, done)
            
            # Update state and counters
            state = next_state
            episode_reward += reward
            episode_length += 1
            total_steps += 1
            
            # Store belief accuracy and assistance level
            true_goal_idx = env.true_goal_idx
            
            # Compute belief accuracy
            predicted_goal_idx = np.argmax(belief)
            episode_beliefs.append(predicted_goal_idx == true_goal_idx)
            episode_true_goals.append(true_goal_idx)
            
            # Estimate assistance level (gamma) as max belief probability
            max_belief = np.max(belief)
            episode_assistance.append(max_belief)
            
            # Update policy if enough steps have been taken
            if agent.steps_since_update >= STEPS_PER_UPDATE:
                loss = agent.update()
                tracker.add_training_metrics(loss, np.mean(agent.rewards) if agent.rewards else 0)
                update_counter += 1
            
            # Check if episode is done
            if done or truncated:
                break
        
        # Calculate episode metrics
        success = info.get('goal_reached', False)
        collision = info.get('collision', False)
        belief_accuracy = np.mean(episode_beliefs) if episode_beliefs else 0
        avg_assistance = np.mean(episode_assistance) if episode_assistance else 0
        
        # Add episode metrics to tracker
        tracker.add_episode_metrics(
            episode_reward, episode_length, success, collision, 
            belief_accuracy, avg_assistance
        )
        
        # Log progress
        if episode % 100 == 0:
            mean_reward = np.mean(tracker.episode_rewards[-100:]) if tracker.episode_rewards else 0
            mean_success = np.mean(tracker.success_rates[-100:]) if tracker.success_rates else 0
            mean_belief = np.mean(tracker.belief_accuracy[-100:]) if tracker.belief_accuracy else 0
            
            print(f"Episode {episode}, Steps: {total_steps}, Avg reward: {mean_reward:.2f}, " 
                  f"Success rate: {mean_success:.2f}, Belief accuracy: {mean_belief:.2f}")
            
            # Plot current learning curves periodically
            tracker.plot_learning_curves()
        
        # Evaluate agent periodically
        if episode % eval_interval == 0:
            eval_results = evaluate(env, agent, num_episodes=20, max_steps=max_steps)
            
            for key in eval_metrics:
                eval_metrics[key].append(eval_results[key])
            
            # Check for improvement
            if eval_results['success_rate'] > best_success_rate:
                best_success_rate = eval_results['success_rate']
                agent.save("results/best_model_success.pt")
                episodes_without_improvement = 0
                print(f"New best success rate: {best_success_rate:.4f}")
            elif eval_results['episode_reward'] > best_reward:
                best_reward = eval_results['episode_reward']
                agent.save("results/best_model_reward.pt")
                episodes_without_improvement = 0
                print(f"New best reward: {best_reward:.4f}")
            else:
                episodes_without_improvement += 1
                print(f"No improvement for {episodes_without_improvement} evaluations")
                
            # Early stopping check
            if episodes_without_improvement >= early_stop_patience:
                print(f"No improvement for {early_stop_patience} evaluations. Stopping training.")
                break
                
            # Save video of the best evaluation episode
            if visualize and tracker.frames:
                tracker.save_video(f"episode_{episode}")
                tracker.clear_frames()
    
    # Final evaluation and visualization
    print("Training complete. Running final evaluation...")
    final_eval = evaluate(env, agent, num_episodes=30, max_steps=max_steps, visualize=visualize)
    
    # Save final model
    agent.save("results/final_model.pt")
    
    # Plot final learning curves
    tracker.plot_learning_curves()
    
    # Run ablation study
    run_ablation_study(env, agent, tracker, visualize=visualize)
    
    # Print final results
    print("\nFinal Evaluation Results:")
    for key, value in final_eval.items():
        print(f"{key}: {value:.4f}")
    
    return tracker, final_eval


# IMPROVED: Evaluation function with deterministic policy
def evaluate(env, agent, num_episodes=10, max_steps=1000, visualize=False):
    """Evaluate the agent's performance."""
    eval_rewards = []
    eval_lengths = []
    eval_successes = []
    eval_collisions = []
    eval_belief_accuracy = []
    
    tracker = PerformanceTracker()
    
    for episode in range(num_episodes):
        state, _ = env.reset()
        agent.reset()
        
        episode_reward = 0
        episode_length = 0
        episode_beliefs = []
        
        # Extract goals from the environment
        goals = []
        for i in range(NUM_GOALS):
            goal_idx = state.shape[0] - 2 * NUM_GOALS + 2 * i
            goals.append(state[goal_idx:goal_idx + 2])
        
        # Episode loop
        for step in range(max_steps):
            # Record video if requested
            if visualize:
                frame = env.render()
                if frame is not None:
                    tracker.add_frame(frame)
            
            # Select action (deterministic for evaluation)
            action, belief = agent.select_action(state, goals, deterministic=True)
            
            # Check for NaN values in action
            if np.isnan(action).any() or np.isinf(action).any():
                action = np.zeros_like(action)
            
            # Take action
            next_state, reward, done, truncated, info = env.step(action)
            
            # Update state and counters
            state = next_state
            episode_reward += reward
            episode_length += 1
            
            # Track belief accuracy
            true_goal_idx = env.true_goal_idx
            predicted_goal_idx = np.argmax(belief)
            episode_beliefs.append(predicted_goal_idx == true_goal_idx)
            
            if done or truncated:
                break
        
        # Calculate episode metrics
        success = info.get('goal_reached', False)
        collision = info.get('collision', False)
        belief_accuracy = np.mean(episode_beliefs) if episode_beliefs else 0
        
        # Store metrics
        eval_rewards.append(episode_reward)
        eval_lengths.append(episode_length)
        eval_successes.append(1 if success else 0)
        eval_collisions.append(1 if collision else 0)
        eval_belief_accuracy.append(belief_accuracy)
    
    # Save video of the evaluation
    if visualize and tracker.frames:
        tracker.save_video("evaluation")
        tracker.clear_frames()
    
    # Calculate final metrics
    results = {
        'episode_reward': np.mean(eval_rewards) if eval_rewards else 0,
        'episode_length': np.mean(eval_lengths) if eval_lengths else 0,
        'success_rate': np.mean(eval_successes) if eval_successes else 0,
        'collision_rate': np.mean(eval_collisions) if eval_collisions else 0,
        'belief_accuracy': np.mean(eval_belief_accuracy) if eval_belief_accuracy else 0,
        'completion_time': np.mean([l for l, s in zip(eval_lengths, eval_successes) if s]) if any(eval_successes) else float('inf')
    }
    
    return results


# IMPROVED: Ablation study with more realistic human simulation
def run_ablation_study(env, agent, tracker, visualize=False):
    """Run ablation studies to compare different approaches."""
    print("Running ablation studies...")
    
    # Define methods to compare
    methods = {
        'Full Model': agent,  # Our integrated approach
        'No Assistance': None,  # Manual control only
        'Fixed (0.5)': None,  # Fixed blending parameter
        'MAP Selection': None,  # Assist toward most likely goal
    }
    
    # Create human simulation agent for more realistic tests
    class HumanSimulation:
        def __init__(self, noise_level=0.2, optimal_ratio=0.8, random_action_prob=0.05):
            self.noise_level = noise_level  # Noise in human control
            self.optimal_ratio = optimal_ratio  # How closely human follows optimal path
            self.random_action_prob = random_action_prob  # Probability of random action
            
        def get_human_action(self, state, goal):
            """Generate realistic simulated human action toward the goal."""
            # Extract end effector position
            ee_pos = state[-4:-2]  # Position is at the end of the observation before goal info
            
            # Calculate direction to goal
            goal_dir = goal - ee_pos
            goal_dist = np.linalg.norm(goal_dir)
            
            # Random action with small probability
            if np.random.random() < self.random_action_prob:
                return np.random.uniform(-1, 1, size=2)
            
            # Normalize direction if non-zero
            if goal_dist > 0:
                optimal_dir = goal_dir / goal_dist
            else:
                optimal_dir = np.zeros(2)
            
            # Scale action magnitude based on distance
            optimal_magnitude = min(1.0, goal_dist * 2.0)
            optimal_action = optimal_dir * optimal_magnitude
            
            # Add noise to simulate human variability
            noise = np.random.normal(0, self.noise_level, size=2)
            
            # Add directional bias to simulate hand tremor
            bias_dir = np.random.normal(0, 1, size=2)
            bias_dir = bias_dir / np.linalg.norm(bias_dir) if np.linalg.norm(bias_dir) > 0 else bias_dir
            bias = bias_dir * self.noise_level * 0.3
            
            # Sometimes take suboptimal paths
            if np.random.random() > self.optimal_ratio:
                # Add deviation to simulate strategic errors
                angle = np.random.uniform(-np.pi/4, np.pi/4)
                cos_angle, sin_angle = np.cos(angle), np.sin(angle)
                rotation = np.array([[cos_angle, -sin_angle], [sin_angle, cos_angle]])
                optimal_dir = np.dot(rotation, optimal_dir)
                optimal_action = optimal_dir * optimal_magnitude
            
            # Combine optimal action with noise and bias
            action = optimal_action + noise + bias
            
            # Add some tremor when close to goal (nervousness)
            if goal_dist < 0.1:
                tremor = np.random.normal(0, 0.1 + 0.3 * (1 - goal_dist/0.1), size=2)
                action += tremor
            
            # Clip to action space
            return np.clip(action, -1, 1)
    
    # Create human simulation
    human_sim = HumanSimulation(noise_level=0.2)
    
    # 1. No Assistance: γ = 0
    class NoAssistanceAgent:
        def __init__(self, base_agent, human_sim):
            self.base_agent = base_agent
            self.human_sim = human_sim
            self.device = base_agent.device
        
        def select_action(self, state, goals, deterministic=False):
            # Update belief as normal
            _, belief = self.base_agent.select_action(state, goals, deterministic)
            
            # Generate simulated human action (noisy optimal action toward true goal)
            human_action = self.human_sim.get_human_action(state, env.true_goal)
            
            # Return pure human action (no assistance)
            return human_action, belief
        
        def reset(self):
            self.base_agent.reset()
    
    # 2. Fixed blending: γ = 0.5
    class FixedBlendingAgent:
        def __init__(self, base_agent, human_sim, gamma=0.5):
            self.base_agent = base_agent
            self.human_sim = human_sim
            self.gamma = gamma
            self.device = base_agent.device
        
        def select_action(self, state, goals, deterministic=False):
            # Update belief and get AI action
            ai_action, belief = self.base_agent.select_action(state, goals, deterministic)
            
            # Generate human action
            human_action = self.human_sim.get_human_action(state, env.true_goal)
            
            # Blend with fixed parameter
            action = (1 - self.gamma) * human_action + self.gamma * ai_action
            
            # Check for NaN values
            if np.isnan(action).any() or np.isinf(action).any():
                action = human_action  # Fall back to human action
            
            return action, belief
        
        def reset(self):
            self.base_agent.reset()
    
    # 3. MAP Selection: Assist toward most likely goal only
    class MAPSelectionAgent:
        def __init__(self, base_agent, human_sim):
            self.base_agent = base_agent
            self.human_sim = human_sim
            self.device = base_agent.device
        
        def select_action(self, state, goals, deterministic=False):
            # Get belief from base agent
            _, belief = self.base_agent.select_action(state, goals, deterministic)
            
            # Generate human action
            human_action = self.human_sim.get_human_action(state, env.true_goal)
            
            # Get most likely goal
            map_goal_idx = np.argmax(belief)
            map_goal = goals[map_goal_idx]
            
            # Generate expert action toward MAP goal
            ee_pos = state[-4:-2]  # Position is at the end of the observation before goal info
            goal_dir = map_goal - ee_pos
            if np.linalg.norm(goal_dir) > 0:
                goal_dist = np.linalg.norm(goal_dir)
                expert_dir = goal_dir / goal_dist
                # Scale magnitude based on distance
                magnitude = min(1.0, goal_dist * 2.0)
                expert_action = expert_dir * magnitude
            else:
                expert_action = np.zeros(2)
            
            # Adaptive blending based on maximum belief
            gamma = np.max(belief)
            action = (1 - gamma) * human_action + gamma * expert_action
            
            # Check for NaN values
            if np.isnan(action).any() or np.isinf(action).any():
                action = human_action  # Fall back to human action
            
            return action, belief
        
        def reset(self):
            self.base_agent.reset()
    
    # Improved Full Model: Better handles the human simulation
    class FullModelAgent:
        def __init__(self, base_agent, human_sim):
            self.base_agent = base_agent
            self.human_sim = human_sim
            self.device = base_agent.device
        
        def select_action(self, state, goals, deterministic=False):
            # Get AI action and belief from base agent
            ai_action, belief = self.base_agent.select_action(state, goals, deterministic)
            
            # Generate human action
            human_action = self.human_sim.get_human_action(state, env.true_goal)
            
            # Blend actions based on confidence and proximity to goal
            ee_pos = state[-4:-2]
            max_belief_idx = np.argmax(belief)
            confidence = belief[max_belief_idx]
            
            # Increase AI influence when close to goal or when confidence is high
            distance_to_goal = np.linalg.norm(env.true_goal - ee_pos)
            
            # Adaptive assistance based on belief and distance
            gamma = min(0.8, confidence * (1.0 + 0.5 * (1.0 - min(1.0, distance_to_goal / 0.2))))
            
            # Blend with adaptive gamma
            action = (1 - gamma) * human_action + gamma * ai_action
            
            # Check for NaN values
            if np.isnan(action).any() or np.isinf(action).any():
                action = human_action  # Fall back to human action
            
            return action, belief
        
        def reset(self):
            self.base_agent.reset()
    
    # Create instances of each ablated agent
    methods['No Assistance'] = NoAssistanceAgent(agent, human_sim)
    methods['Fixed (0.5)'] = FixedBlendingAgent(agent, human_sim, gamma=0.5)
    methods['MAP Selection'] = MAPSelectionAgent(agent, human_sim)
    methods['Full Model'] = FullModelAgent(agent, human_sim)  # Replace with improved version
    
    # Evaluate each method
    results = {}
    for method_name, method_agent in methods.items():
        print(f"Evaluating: {method_name}")
        
        # Skip if no agent is provided
        if method_agent is None:
            continue
        
        # Evaluate the method
        eval_results = evaluate(env, method_agent, num_episodes=20, visualize=visualize and method_name=='Full Model')
        
        # Store results
        results[method_name] = eval_results
        
        print(f"  Success rate: {eval_results['success_rate']:.4f}")
        print(f"  Episode reward: {eval_results['episode_reward']:.4f}")
        print(f"  Completion time: {eval_results['completion_time']:.4f}")
        print(f"  Belief accuracy: {eval_results['belief_accuracy']:.4f}")
    
    # Plot ablation study results
    tracker.plot_ablation_study(results)
    
    return results


def main():
    # Check if visualization is possible
    render_mode = None
    visualization_enabled = VISUALIZE
    
    if visualization_enabled:
        try:
            import pygame
            pygame.init()
            render_mode = "rgb_array"
        except ImportError:
            print("Warning: Pygame not found. Visualization will be disabled.")
            visualization_enabled = False
    
    # Create environment
    env = MultiGoalReacherEnv(num_goals=NUM_GOALS, render_mode=render_mode)
    
    # Get dimensions
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    
    # Create agent
    agent = PPOAgent(
        state_dim=state_dim,
        action_dim=action_dim,
        hidden_dim=256,
        num_goals=NUM_GOALS,
        lr=LR,
        gamma=GAMMA,
        gae_lambda=GAE_LAMBDA,
        clip_epsilon=CLIP_EPSILON,
        value_coef=VALUE_COEF,
        entropy_coef=ENTROPY_COEF,
        target_kl=TARGET_KL
    )
    
    # Pretrain the Bayesian inference module if enabled
    if PRETRAIN:
        print("Pretraining Bayesian inference module...")
        agent.bayesian_module = pretrain_belief_module(
            agent.bayesian_module,
            num_trajectories=PRETRAIN_TRAJECTORIES,
            trajectory_length=PRETRAIN_TRAJECTORY_LENGTH,
            batch_size=PRETRAIN_BATCH_SIZE,
            num_epochs=PRETRAIN_EPOCHS,
            lr=PRETRAIN_LR
        )
        print("Pretraining complete!")
    
    # Train agent
    tracker, final_eval = train(
        env=env,
        agent=agent,
        num_episodes=NUM_EPISODES,
        eval_interval=EVAL_INTERVAL,
        visualize=visualization_enabled,
        visualize_interval=VISUALIZE_INTERVAL
    )
    
    # Close environment
    env.close()
    
    print("Training and evaluation completed!")


if __name__ == "__main__":
    main()

Using device: cuda
Pretraining Bayesian inference module...
Generating pretraining data...




Pretraining Bayesian inference module for 20 epochs...


                                                                  

Epoch 1/20, Loss: 0.4847, Accuracy: 0.9037


                                                                  

Epoch 2/20, Loss: 0.4919, Accuracy: 0.9172


                                                                  

Epoch 3/20, Loss: 0.4591, Accuracy: 0.9182


                                                                  

Epoch 4/20, Loss: 0.4251, Accuracy: 0.9286


                                                                  

Epoch 5/20, Loss: 0.4676, Accuracy: 0.9205


                                                                  

Epoch 6/20, Loss: 0.4711, Accuracy: 0.9158


                                                                      

Epoch 7/20, Loss: 0.4974, Accuracy: 0.9061


                                                                      

Epoch 8/20, Loss: 0.4719, Accuracy: 0.9093


                                                                      

Epoch 9/20, Loss: 0.4847, Accuracy: 0.9069


                                                                       

Epoch 10/20, Loss: 0.4880, Accuracy: 0.9068


                                                                       

Epoch 11/20, Loss: 0.4698, Accuracy: 0.9165


                                                                       

Epoch 12/20, Loss: 0.4868, Accuracy: 0.9049


                                                                       

Epoch 13/20, Loss: 0.4856, Accuracy: 0.9054


                                                                     

Epoch 14/20, Loss: 0.4931, Accuracy: 0.8969


                                                                     

Epoch 15/20, Loss: 0.4784, Accuracy: 0.9003


                                                                     

Epoch 16/20, Loss: 0.4906, Accuracy: 0.8971


                                                                     

Epoch 17/20, Loss: 0.4801, Accuracy: 0.8992


                                                                     

Epoch 18/20, Loss: 0.4614, Accuracy: 0.9030


                                                                     

Epoch 19/20, Loss: 0.4889, Accuracy: 0.9000


                                                                     

Epoch 20/20, Loss: 0.4617, Accuracy: 0.9039
Pretraining complete!
Pretraining complete!


Training:   5%|▍         | 99/2000 [02:31<45:41,  1.44s/it]  

Episode 100, Steps: 22672, Avg reward: 172.44, Success rate: 0.61, Belief accuracy: 0.32


Training:   5%|▌         | 100/2000 [02:51<3:40:50,  6.97s/it]

New best success rate: 0.5500


Training:  10%|▉         | 199/2000 [04:34<52:06,  1.74s/it]  

Episode 200, Steps: 38297, Avg reward: 136.83, Success rate: 0.67, Belief accuracy: 0.38


Training:  10%|█         | 200/2000 [05:35<9:42:00, 19.40s/it]

New best reward: 451.4554


Training:  15%|█▍        | 298/2000 [07:38<59:17,  2.09s/it]  

Episode 300, Steps: 57092, Avg reward: 138.76, Success rate: 0.51, Belief accuracy: 0.33


Training:  15%|█▌        | 300/2000 [08:08<3:49:31,  8.10s/it]

No improvement for 1 evaluations


Training:  20%|█▉        | 399/2000 [10:08<26:08,  1.02it/s]  

Episode 400, Steps: 75501, Avg reward: 154.14, Success rate: 0.58, Belief accuracy: 0.30


Training:  20%|██        | 400/2000 [10:41<4:14:59,  9.56s/it]

No improvement for 2 evaluations


Training:  25%|██▍       | 499/2000 [13:00<26:02,  1.04s/it]  

Episode 500, Steps: 96705, Avg reward: 195.82, Success rate: 0.62, Belief accuracy: 0.36


Training:  25%|██▌       | 500/2000 [13:14<1:52:16,  4.49s/it]

No improvement for 3 evaluations


Training:  30%|██▉       | 599/2000 [15:57<53:33,  2.29s/it]  

Episode 600, Steps: 122492, Avg reward: 257.81, Success rate: 0.46, Belief accuracy: 0.39


Training:  30%|███       | 600/2000 [16:57<6:56:52, 17.87s/it]

New best reward: 612.7537


Training:  35%|███▍      | 699/2000 [22:47<1:01:53,  2.85s/it]

Episode 700, Steps: 157431, Avg reward: 367.17, Success rate: 0.43, Belief accuracy: 0.29


Training:  35%|███▌      | 700/2000 [24:54<14:02:43, 38.90s/it]

No improvement for 1 evaluations


Training:  40%|███▉      | 799/2000 [33:14<55:55,  2.79s/it]   

Episode 800, Steps: 187206, Avg reward: 302.08, Success rate: 0.47, Belief accuracy: 0.29


Training:  40%|████      | 800/2000 [35:02<11:27:28, 34.37s/it]

No improvement for 2 evaluations


Training:  45%|████▍     | 899/2000 [41:37<1:49:08,  5.95s/it] 

Episode 900, Steps: 210509, Avg reward: 245.16, Success rate: 0.34, Belief accuracy: 0.29


Training:  45%|████▌     | 900/2000 [42:50<7:45:56, 25.42s/it]

No improvement for 3 evaluations


Training:  50%|████▉     | 999/2000 [50:54<1:31:32,  5.49s/it]

Episode 1000, Steps: 241750, Avg reward: 319.05, Success rate: 0.37, Belief accuracy: 0.26


Training:  50%|█████     | 1000/2000 [52:16<7:03:22, 25.40s/it]

No improvement for 4 evaluations


Training:  55%|█████▍    | 1099/2000 [58:27<24:42,  1.65s/it]  

Episode 1100, Steps: 266135, Avg reward: 243.83, Success rate: 0.44, Belief accuracy: 0.33


Training:  55%|█████▌    | 1100/2000 [1:00:30<9:30:08, 38.01s/it]

No improvement for 5 evaluations


Training:  60%|█████▉    | 1199/2000 [1:06:31<10:14,  1.30it/s]  

Episode 1200, Steps: 287151, Avg reward: 216.61, Success rate: 0.44, Belief accuracy: 0.29


Training:  60%|██████    | 1200/2000 [1:07:34<4:17:14, 19.29s/it]

New best success rate: 0.6000


Training:  65%|██████▍   | 1299/2000 [1:15:34<47:59,  4.11s/it]  

Episode 1300, Steps: 316606, Avg reward: 306.38, Success rate: 0.42, Belief accuracy: 0.33


Training:  65%|██████▌   | 1300/2000 [1:16:06<2:19:54, 11.99s/it]

No improvement for 1 evaluations


Training:  70%|██████▉   | 1399/2000 [1:24:20<43:42,  4.36s/it]  

Episode 1400, Steps: 343939, Avg reward: 277.53, Success rate: 0.36, Belief accuracy: 0.31


Training:  70%|███████   | 1400/2000 [1:26:01<5:07:16, 30.73s/it]

No improvement for 2 evaluations


Training:  75%|███████▍  | 1499/2000 [1:33:14<1:12:42,  8.71s/it]

Episode 1500, Steps: 371479, Avg reward: 302.53, Success rate: 0.48, Belief accuracy: 0.36


Training:  75%|███████▌  | 1500/2000 [1:35:48<7:04:52, 50.99s/it]

New best reward: 678.8909


Training:  80%|███████▉  | 1599/2000 [1:43:03<10:26,  1.56s/it]  

Episode 1600, Steps: 398675, Avg reward: 293.63, Success rate: 0.50, Belief accuracy: 0.29


Training:  80%|████████  | 1601/2000 [1:44:00<1:24:39, 12.73s/it]

No improvement for 1 evaluations


Training:  85%|████████▍ | 1699/2000 [1:51:13<33:46,  6.73s/it]  

Episode 1700, Steps: 427648, Avg reward: 309.22, Success rate: 0.37, Belief accuracy: 0.30


Training:  85%|████████▌ | 1700/2000 [1:51:52<1:20:59, 16.20s/it]

No improvement for 2 evaluations


Training:  90%|████████▉ | 1799/2000 [1:54:36<05:17,  1.58s/it]  

Episode 1800, Steps: 453187, Avg reward: 270.28, Success rate: 0.40, Belief accuracy: 0.42


Training:  90%|█████████ | 1800/2000 [1:55:10<31:24,  9.42s/it]

No improvement for 3 evaluations


Training:  95%|█████████▍| 1899/2000 [2:00:03<12:39,  7.52s/it]

Episode 1900, Steps: 484178, Avg reward: 362.11, Success rate: 0.53, Belief accuracy: 0.32


Training:  95%|█████████▌| 1901/2000 [2:02:10<49:26, 29.97s/it]  

No improvement for 4 evaluations


Training: 100%|█████████▉| 1999/2000 [2:05:11<00:01,  1.53s/it]

Episode 2000, Steps: 508945, Avg reward: 270.37, Success rate: 0.41, Belief accuracy: 0.39


Training: 100%|██████████| 2000/2000 [2:05:55<00:00,  3.78s/it]

No improvement for 5 evaluations
Training complete. Running final evaluation...





Running ablation studies...
Evaluating: Full Model
  Success rate: 0.2000
  Episode reward: 370.3006
  Completion time: 10.0000
  Belief accuracy: 0.2242
Evaluating: No Assistance
  Success rate: 0.6500
  Episode reward: 226.5941
  Completion time: 169.5385
  Belief accuracy: 0.2924
Evaluating: Fixed (0.5)
  Success rate: 0.5000
  Episode reward: 336.4776
  Completion time: 181.9000
  Belief accuracy: 0.2739
Evaluating: MAP Selection
  Success rate: 0.3500
  Episode reward: 303.2553
  Completion time: 112.1429
  Belief accuracy: 0.3222

Final Evaluation Results:
episode_reward: 611.6992
episode_length: 490.6000
success_rate: 0.2667
collision_rate: 0.3333
belief_accuracy: 0.2851
completion_time: 22.2500
Training and evaluation completed!
