In [5]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
import gymnasium as gym
from gymnasium import spaces
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm import tqdm
import random
from collections import deque

# Set seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)

# Visualization and training parameters
VISUALIZE = False  # Set to False to disable visualization
VISUALIZE_INTERVAL = 2000  # Show visualization every n steps
SAVE_VIDEOS = False  # Save videos of the trained agent
NUM_EPISODES = 2000  # Number of training episodes

# Environment parameters
GOAL_RADIUS = 0.05
NUM_GOALS = 3  # Number of potential goals

# PPO parameters - Tuned for stable learning
GAMMA = 0.995  # Increased for better long-term rewards
GAE_LAMBDA = 0.95
CLIP_EPSILON = 0.2
ENTROPY_COEF = 0.015  # Slightly increased for better exploration
VALUE_COEF = 0.5
LR = 2e-4  # Slightly lower for more stable learning
BATCH_SIZE = 512  # Larger batch size for better statistics
STEPS_PER_UPDATE = 256  # Frequent updates
PPO_EPOCHS = 12  # More epochs per update for better learning
TARGET_KL = 0.015  # Slightly increased to allow more policy change

# Beliefs will be added in Phase 2
USE_BELIEF_MODULE = False  # Will be toggled to True in Phase 2
BETA = 2.0  # Higher temperature for sharper belief distributions

# Reward function parameters - carefully balanced
COLLISION_PENALTY = -10.0
GOAL_REWARD = 25.0  # Increased goal reward for stronger signal
PROGRESS_REWARD = 1.5  # Increased to encourage moving toward goal
ACTION_PENALTY = 0.05  # Reduced to allow more movement
PROXIMITY_REWARD = 1.5  # Increased to create better gradient near goal

# Training parameters
EVAL_INTERVAL = 20  # Evaluate less frequently but more thoroughly
EARLY_STOP_PATIENCE = 25  # More patience to find good solutions
IMPROVEMENT_THRESHOLD = 0.03  # Lower threshold - 3% improvement is significant
MIN_TRAINING_EPISODES = 200  # Ensure minimum training before stopping

# Create results directory
os.makedirs("results", exist_ok=True)

# Check if CUDA is available
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
print(f"Using device: {DEVICE}")


# Custom Reacher 2D Environment
class CustomReacher2D(gym.Env):
    """Custom 2D Reacher Environment with realistic physics"""
    
    def __init__(self, render_mode=None):
        super(CustomReacher2D, self).__init__()
        
        # Action space: 2D continuous actions for controlling the arm
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        
        # Observation space: joint angles, joint velocities, end-effector position
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(6,), dtype=np.float32)
        
        # Arm parameters
        self.link_lengths = [0.1, 0.11]  # Length of arm segments
        self.max_velocity = 1.0  # Maximum joint velocity
        self.dt = 0.05  # Time step
        
        # State
        self.joint_angles = np.array([0.0, 0.0])  # Two joints, in radians
        self.joint_velocities = np.array([0.0, 0.0])
        
        # Rendering setup
        self.render_mode = render_mode
        self.screen = None
        self.clock = None
        self.window_size = 500  # pixels
        self.pygame = None
        
        # For visualization
        if render_mode == "rgb_array":
            # Import pygame only if rendering is needed
            try:
                import pygame
                self.pygame = pygame
                self.screen = pygame.Surface((self.window_size, self.window_size))
                self.clock = pygame.time.Clock()
            except ImportError:
                self.render_mode = None
                print("Warning: Pygame not available. Rendering disabled.")
    
    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        
        # Reset joint state
        self.joint_angles = np.array([np.random.uniform(-np.pi/2, np.pi/2), 
                                       np.random.uniform(-np.pi/2, np.pi/2)])
        self.joint_velocities = np.array([0.0, 0.0])
        
        # Get current observation
        observation = self._get_obs()
        info = {}
        
        return observation, info
    
    def step(self, action):
        # Clip action to action space
        action = np.clip(action, -1, 1)
        
        # Convert actions to joint accelerations - more realistic physics
        accelerations = action * 8.0  # Scale factor for accelerations
        
        # Update velocities using accelerations
        self.joint_velocities += accelerations * self.dt
        
        # Clip velocities
        self.joint_velocities = np.clip(self.joint_velocities, -self.max_velocity, self.max_velocity)
        
        # Update joint angles using velocities
        self.joint_angles += self.joint_velocities * self.dt
        
        # Wrap angles to [-pi, pi]
        self.joint_angles = np.mod(self.joint_angles + np.pi, 2 * np.pi) - np.pi
        
        # Get new observation
        observation = self._get_obs()
        
        # Default reward and done (to be overridden by wrapper)
        reward = 0.0
        terminated = False
        truncated = False
        info = {}
        
        return observation, reward, terminated, truncated, info
    
    def _get_obs(self):
        """Get current observation (joint angles, velocities, and end-effector position)"""
        ee_pos = self._get_end_effector_position()
        obs = np.concatenate([
            np.cos(self.joint_angles),
            np.sin(self.joint_angles),
            ee_pos
        ])
        return obs
    
    def _get_end_effector_position(self):
        """Compute the position of the end effector using forward kinematics"""
        theta1, theta2 = self.joint_angles
        l1, l2 = self.link_lengths
        
        # Position of first joint is at origin (0,0)
        # Position of second joint
        x1 = l1 * np.cos(theta1)
        y1 = l1 * np.sin(theta1)
        
        # Position of end effector
        x2 = x1 + l2 * np.cos(theta1 + theta2)
        y2 = y1 + l2 * np.sin(theta1 + theta2)
        
        return np.array([x2, y2])
    
    def render(self):
        if self.render_mode != "rgb_array" or self.pygame is None:
            return None
        
        # Clear the screen
        self.screen.fill((255, 255, 255))
        
        # Convert from world coordinates to screen coordinates
        def world_to_screen(point):
            scale = self.window_size / 2.5  # Scale to make the arm visible
            screen_x = int(point[0] * scale + self.window_size / 2)
            screen_y = int(-point[1] * scale + self.window_size / 2)  # Negative because screen y is inverted
            return (screen_x, screen_y)
        
        # Draw the arm
        # Base position (origin)
        base_pos = world_to_screen((0, 0))
        
        # First joint position
        theta1 = self.joint_angles[0]
        l1 = self.link_lengths[0]
        joint1_x = l1 * np.cos(theta1)
        joint1_y = l1 * np.sin(theta1)
        joint1_pos = world_to_screen((joint1_x, joint1_y))
        
        # End effector position
        ee_pos = self._get_end_effector_position()
        ee_screen_pos = world_to_screen(ee_pos)
        
        # Draw the links
        self.pygame.draw.line(self.screen, (0, 0, 0), base_pos, joint1_pos, 6)
        self.pygame.draw.line(self.screen, (0, 0, 0), joint1_pos, ee_screen_pos, 6)
        
        # Draw the joints
        self.pygame.draw.circle(self.screen, (255, 0, 0), base_pos, 10)
        self.pygame.draw.circle(self.screen, (0, 255, 0), joint1_pos, 8)
        self.pygame.draw.circle(self.screen, (0, 0, 255), ee_screen_pos, 8)
        
        return np.transpose(np.array(self.pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2))
    
    def close(self):
        if self.screen is not None:
            self.screen = None


# Environment wrapper for multiple goals
class MultiGoalReacherEnv(gym.Wrapper):
    def __init__(self, num_goals=3, goal_radius=0.05, render_mode=None):
        # Create the base environment
        self.env = CustomReacher2D(render_mode=render_mode)
        super().__init__(self.env)
        
        # Potential goal positions (normalized to [-1, 1])
        self.num_goals = num_goals
        self.goal_radius = goal_radius
        self.goals = []
        self.true_goal = None
        self.true_goal_idx = None
        self.obstacles = []
        
        # Override observation space to include goal information
        obs_dim = self.env.observation_space.shape[0]
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(obs_dim + num_goals * 2,)
        )
    
    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        
        # Generate random goal positions - better distribution
        self.goals = []
        angles = np.linspace(0, 2*np.pi, self.num_goals, endpoint=False)
        angles += np.random.uniform(0, 2*np.pi/self.num_goals)  # Add randomness
        
        for i in range(self.num_goals):
            # Place goals in a circular pattern with some randomness
            radius = np.random.uniform(0.12, 0.18)
            goal_x = radius * np.cos(angles[i])
            goal_y = radius * np.sin(angles[i])
            self.goals.append(np.array([goal_x, goal_y]))
        
        # Choose one goal as the true goal
        self.true_goal_idx = np.random.randint(0, self.num_goals)
        self.true_goal = self.goals[self.true_goal_idx]
        
        # Generate random obstacles (simplified as positions to avoid)
        self.obstacles = []
        for _ in range(2):  # 2 obstacles
            obs_x = np.random.uniform(-0.15, 0.15)
            obs_y = np.random.uniform(-0.15, 0.15)
            radius = np.random.uniform(0.02, 0.04)
            
            # Ensure obstacles don't overlap with goals
            valid = True
            for goal in self.goals:
                if np.linalg.norm(goal - np.array([obs_x, obs_y])) < radius + self.goal_radius:
                    valid = False
                    break
            
            if valid:
                self.obstacles.append((np.array([obs_x, obs_y]), radius))
        
        # Track previous distance for progress reward
        self.prev_distance = np.linalg.norm(self._get_end_effector_position() - self.true_goal)
        
        # Augment observation with goal positions
        augmented_obs = self._augment_observation(obs)
        return augmented_obs, info
    
    def step(self, action):
        obs, _, terminated, truncated, info = self.env.step(action)
        
        # Get current end effector position
        ee_pos = self._get_end_effector_position()
        
        # Check if the end effector reached the true goal
        distance_to_goal = np.linalg.norm(ee_pos - self.true_goal)
        goal_reached = distance_to_goal < self.goal_radius
        
        # Check for collision with obstacles
        collision = False
        for obs_pos, obs_radius in self.obstacles:
            if np.linalg.norm(ee_pos - obs_pos) < obs_radius:
                collision = True
                break
        
        # Simplified reward function
        reward = self._compute_reward(ee_pos, action, goal_reached, collision)
        
        # Override termination conditions
        done = goal_reached or collision
        
        # Augment observation with goal positions
        augmented_obs = self._augment_observation(obs)
        
        # Update info dictionary
        info['true_goal'] = self.true_goal
        info['goal_reached'] = goal_reached
        info['collision'] = collision
        info['distance_to_goal'] = distance_to_goal
        
        return augmented_obs, reward, done, truncated, info
    
    def _get_end_effector_position(self):
        """Get the end effector position from the environment"""
        return self.env._get_end_effector_position()
    
    def _compute_reward(self, ee_pos, action, goal_reached, collision):
        """Improved reward function with better shaped rewards"""
        # Distance to true goal
        distance_to_goal = np.linalg.norm(ee_pos - self.true_goal)
        
        # Base reward starts at 0
        reward = 0.0
        
        # Terminal states
        if collision:
            return -COLLISION_PENALTY
        
        if goal_reached:
            # Bonus for faster completion - encourages efficiency
            reward += GOAL_REWARD + GOAL_REWARD * 0.5 * (1.0 - min(1.0, distance_to_goal))
            return reward
        
        # Progress reward (encourage moving toward the goal)
        progress = self.prev_distance - distance_to_goal
        
        # Enhanced progress reward - higher gradient when making good progress
        if progress > 0:
            # Reward more for making progress when closer to goal
            progress_scaling = 1.0 + 2.0 * (1.0 - min(1.0, distance_to_goal / 0.2))
            reward += PROGRESS_REWARD * progress * progress_scaling
        else:
            # Small penalty for moving away from goal
            reward += 0.5 * PROGRESS_REWARD * progress
        
        # Store current distance for next step
        self.prev_distance = distance_to_goal
        
        # Action smoothness reward - quadratic penalty scaled by distance
        # Less penalty when far, more when close (encourages precision near goal)
        action_smoothness = ACTION_PENALTY * np.square(action).sum()
        action_scale = min(1.0, distance_to_goal / 0.1)  # Scale penalty by distance
        reward -= action_smoothness * (0.5 + 0.5 * (1.0 - action_scale))
        
        # Proximity reward (higher when closer to goal)
        # Exponential scaling provides better gradient
        proximity_reward = PROXIMITY_REWARD * np.exp(-5.0 * distance_to_goal)
        reward += proximity_reward
        
        # Add time penalty to encourage faster completion
        reward -= 0.01
        
        # Check if near obstacles and add avoidance reward
        for obs_pos, obs_radius in self.obstacles:
            dist_to_obs = np.linalg.norm(ee_pos - obs_pos)
            if dist_to_obs < obs_radius * 3:
                # Encourage staying away from obstacles
                obstacle_margin = dist_to_obs - obs_radius
                if obstacle_margin > 0:
                    # Reward for maintaining safe distance
                    reward += 0.5 * np.exp(-5.0 * obstacle_margin)
        
        return reward
    
    def _augment_observation(self, obs):
        """Concatenate goal positions to the observation"""
        goal_info = np.concatenate([goal for goal in self.goals])
        return np.concatenate([obs, goal_info])
    
    def render(self):
        frame = self.env.render()
        if frame is None:
            return None
        
        if self.env.pygame is None:
            return frame
        
        # Draw goals and obstacles on the frame
        def world_to_screen(point):
            scale = self.env.window_size / 2.5
            screen_x = int(point[0] * scale + self.env.window_size / 2)
            screen_y = int(-point[1] * scale + self.env.window_size / 2)
            return (screen_x, screen_y)
        
        # Draw goals
        for i, goal in enumerate(self.goals):
            goal_pos = world_to_screen(goal)
            color = (255, 215, 0) if i == self.true_goal_idx else (200, 200, 200)
            self.env.pygame.draw.circle(self.env.screen, color, goal_pos, int(self.goal_radius * self.env.window_size / 2.5))
        
        # Draw obstacles
        for obs_pos, obs_radius in self.obstacles:
            obs_screen_pos = world_to_screen(obs_pos)
            self.env.pygame.draw.circle(
                self.env.screen,
                (100, 100, 100),
                obs_screen_pos,
                int(obs_radius * self.env.window_size / 2.5)
            )
        
        return np.transpose(np.array(self.env.pygame.surfarray.pixels3d(self.env.screen)), axes=(1, 0, 2))


# Simple Belief Module - To be used in Phase 2
class SimpleBeliefModule(nn.Module):
    def __init__(self, num_goals=3, state_dim=2, action_dim=2, beta=1.0):
        super(SimpleBeliefModule, self).__init__()
        self.num_goals = num_goals
        
        # Just one temperature parameter for simplicity
        self.beta = nn.Parameter(torch.tensor(float(beta), dtype=torch.float32))
        
        # Prior probabilities (initialized as uniform)
        self.register_buffer('prior', torch.ones(num_goals, dtype=torch.float32) / num_goals)
        
        # Feature extractor for better state representation
        in_dim = state_dim + action_dim
        self.feature_net = nn.Sequential(
            nn.Linear(in_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU()
        )

        
        # Goal comparator network
        self.comparison_net = nn.Sequential(
            nn.Linear(32 + 2, 32),  # Features + goal
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1)
        )
        
        # Initialize weights
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.orthogonal_(m.weight, gain=1.0)
            nn.init.zeros_(m.bias)
    
    def forward(self, state, action, goals, prev_belief=None):
        """
        Update belief over goals based on observed state and action.
        
        Args:
            state: Current state (end effector position) [2]
            action: Agent action [2]
            goals: List of potential goal positions [num_goals x 2]
            prev_belief: Previous belief distribution [num_goals]
            
        Returns:
            Updated belief over goals [num_goals]
        """
        # Ensure all input tensors are on the same device
        device = self.beta.device
        
        if prev_belief is None:
            prev_belief = self.prior.clone().to(device)
        
        # Ensure state and action are correctly shaped
        if isinstance(state, np.ndarray):
            state = torch.tensor(state, dtype=torch.float32, device=device)
        if isinstance(action, np.ndarray):
            action = torch.tensor(action, dtype=torch.float32, device=device)
        
        # Ensure we have flat vectors (no batch dimension)
        if state.dim() > 1:
            state = state.reshape(-1)
        if action.dim() > 1:
            action = action.reshape(-1)
        
        # Concatenate state and action to get a tensor of shape (4,)
        state_action = torch.cat([state, action])
        
        # Add batch dimension: (4,) -> (1, 4)
        state_action = state_action.unsqueeze(0)
        
        # Extract features from state-action pair
        features = self.feature_net(state_action)
        
        # Calculate likelihoods for each goal
        likelihoods = torch.zeros(self.num_goals, device=device)
        
        for i in range(self.num_goals):
            goal = goals[i]
            if isinstance(goal, np.ndarray):
                goal = torch.tensor(goal, dtype=torch.float32, device=device)
            
            # Ensure goal is a flat vector (no batch dimension)
            if goal.dim() > 1:
                goal = goal.reshape(-1)
            
            # Add batch dimension to goal: (2,) -> (1, 2)
            goal = goal.unsqueeze(0)
            
            # Concatenate features with goal
            # features has shape (1, 32) and goal has shape (1, 2)
            feature_goal = torch.cat([features, goal], dim=1)
            
            # Score this goal based on state-action
            score = self.comparison_net(feature_goal).squeeze()
            
            # Store the score
            likelihoods[i] = score
        
        # Apply softmax with temperature
        beta = F.softplus(self.beta) + 0.1  # Ensure positive
        likelihoods = F.softmax(beta * likelihoods, dim=0)
        
        # Bayesian update
        posterior = prev_belief * likelihoods
        posterior_sum = torch.sum(posterior)
        
        # Check for numerical stability
        if posterior_sum > 1e-8:
            posterior = posterior / posterior_sum
        else:
            # If update fails, use a mixture
            posterior = 0.5 * prev_belief + 0.5 * likelihoods
        
        # Check for NaN values
        if torch.isnan(posterior).any():
            posterior = prev_belief.clone()
        
        return posterior
    
    def reset(self):
        """Reset belief to prior."""
        return self.prior.clone()


# Improved Actor-Critic Networks for PPO
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256, belief_dim=None, use_belief=False):
        super(ActorCritic, self).__init__()
        
        self.use_belief = use_belief
        
        # Feature extraction layers
        self.feature_network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )
        
        # Separate belief processing if provided and enabled
        self.belief_network = None
        if use_belief and belief_dim is not None:
            self.belief_network = nn.Sequential(
                nn.Linear(belief_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),
                nn.ReLU()
            )
            
            # Fusion layer for state and belief features
            self.fusion_network = nn.Sequential(
                nn.Linear(hidden_dim + hidden_dim // 2, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU()
            )
        
        # Actor network (policy)
        self.actor_mean = nn.Linear(hidden_dim, action_dim)
        self.actor_log_std = nn.Parameter(torch.zeros(action_dim))
        
        # Critic network (value function)
        self.critic = nn.Linear(hidden_dim, 1)
        
        # Initialize weights
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    
    def forward(self, state, belief=None):
        # Extract features from state
        features = self.feature_network(state)
        
        # Process belief and fuse with state features if available
        if self.use_belief and belief is not None and self.belief_network is not None:
            belief_features = self.belief_network(belief)
            features = torch.cat([features, belief_features], dim=-1)
            features = self.fusion_network(features)
        
        # Actor: Get action distribution
        action_mean = self.actor_mean(features)
        action_std = torch.exp(torch.clamp(self.actor_log_std, min=-20, max=2))
        
        # Critic: Get state value
        value = self.critic(features)
        
        return action_mean, action_std, value
    
    def get_action(self, state, belief=None, deterministic=False):
        # Move data to the model's device
        device = next(self.parameters()).device
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        if belief is not None and self.use_belief:
            belief = torch.FloatTensor(belief).unsqueeze(0).to(device)
        
        with torch.no_grad():
            action_mean, action_std, _ = self.forward(state, belief)
            
            if deterministic:
                action = action_mean
            else:
                normal = Normal(action_mean, action_std)
                action = normal.sample()
                
            return action.squeeze(0).cpu().numpy()
    
    def evaluate_actions(self, states, actions, beliefs=None):
        action_mean, action_std, values = self.forward(states, beliefs)
        
        # Create normal distribution
        dist = Normal(action_mean, action_std)
        
        # Get log probabilities
        log_probs = dist.log_prob(actions).sum(dim=-1, keepdim=True)
        
        # Get entropy
        entropy = dist.entropy().sum(dim=-1, keepdim=True)
        
        return log_probs, entropy, values


# Improved PPO Agent
class PPOAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=256, num_goals=3, lr=3e-4, 
                 gamma=0.99, gae_lambda=0.95, clip_epsilon=0.2, value_coef=0.5, 
                 entropy_coef=0.01, target_kl=0.01, use_belief=False):
        
        # Initialize actor-critic
        self.actor_critic = ActorCritic(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dim=hidden_dim,
            belief_dim=num_goals if use_belief else None,
            use_belief=use_belief
        )
        
        # Initialize belief module if needed
        self.belief_module = None
        if use_belief:
            self.belief_module = SimpleBeliefModule(num_goals=num_goals, beta=BETA)
        
        self.use_belief = use_belief
        self.num_goals = num_goals
        
        # Move models to the appropriate device
        self.device = DEVICE
        self.actor_critic.to(self.device)
        if self.belief_module:
            self.belief_module.to(self.device)
        
        # Initialize optimizer - joint optimization of both networks
        if self.belief_module:
            self.optimizer = optim.Adam(
                list(self.actor_critic.parameters()) + list(self.belief_module.parameters()), 
                lr=lr
            )
        else:
            self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=lr)
            
        # Add learning rate scheduler
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=1000, eta_min=lr/10)
        
        # Store hyperparameters
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.target_kl = target_kl
        
        # Memory buffer for experience
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.beliefs = []
        self.dones = []
        
        # Current belief and previous action
        self.current_belief = None
        self.prev_action = None
        
        # Tracking variables
        self.steps_since_update = 0
        self.epochs_trained = 0
        self.update_count = 0
    
    def select_action(self, state, goals=None, deterministic=False):
        """Select action based on current state and belief."""
        # Extract goals from state if not provided
        if goals is None:
            # Assume the state includes goal information as per the environment wrapper
            goals_data = state[state.shape[0] - 2*self.num_goals:]
            goals = [goals_data[i:i+2] for i in range(0, len(goals_data), 2)]
        
                # Get end effector position# Get end effector position - corrected to account for goal information
        ee_pos = state[-2-2*self.num_goals:-2*self.num_goals]  # Position right before goal info
        
        # Update belief if using belief module
        belief = None
        if self.use_belief and self.belief_module:
            # Initialize belief if needed
            if self.current_belief is None:
                self.current_belief = self.belief_module.reset().to(self.device)
            
            # Handle case where we don't have previous action yet
            if self.prev_action is None:
                self.prev_action = torch.zeros(2, device=self.device)
            
            # Update belief based on observed state and previous action
            with torch.no_grad():
                updated_belief = self.belief_module.forward(
                    state=torch.tensor(ee_pos, dtype=torch.float32, device=self.device),
                    action=self.prev_action,
                    goals=[torch.tensor(g, dtype=torch.float32, device=self.device) for g in goals],
                    prev_belief=self.current_belief
                )
                
                # Safety check
                if torch.isnan(updated_belief).any():
                    updated_belief = self.belief_module.reset().to(self.device)
                
                self.current_belief = updated_belief
                belief = self.current_belief.cpu().numpy()
        
        # Select action using actor network
        action = self.actor_critic.get_action(
            state=state,
            belief=belief if self.use_belief else None,
            deterministic=deterministic
        )
        
        # Store action for next belief update
        if self.use_belief:
            self.prev_action = torch.tensor(action, device=self.device)
        
        # Safety check for NaN actions
        if np.isnan(action).any():
            action = np.zeros_like(action)
            if self.use_belief:
                self.prev_action = torch.zeros(2, device=self.device)
        
        return action, belief if self.use_belief else None
    
    def remember(self, state, action, reward, value, log_prob, belief, done):
        """Store experience in memory."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        if belief is not None:
            self.beliefs.append(belief)
        self.dones.append(done)
        
        # Track steps since last update
        self.steps_since_update += 1
    
    def compute_gae(self, next_value):
        """Compute Generalized Advantage Estimation."""
        values = self.values + [next_value]
        gae = 0
        returns = []
        
        # GAE calculation
        for step in reversed(range(len(self.rewards))):
            delta = (self.rewards[step] + 
                     self.gamma * values[step + 1] * (1 - self.dones[step]) - 
                     values[step])
            gae = delta + self.gamma * self.gae_lambda * (1 - self.dones[step]) * gae
            returns.insert(0, gae + values[step])
        
        return returns
    
    def update(self):
        """Update policy using PPO algorithm."""
        # Skip update if not enough data
        if len(self.states) < BATCH_SIZE:
            return 0.0
            
        # Get next value for GAE
        with torch.no_grad():
            if len(self.states) > 0:
                state = torch.FloatTensor(self.states[-1]).unsqueeze(0).to(self.device)
                belief = None
                if self.use_belief and len(self.beliefs) > 0:
                    belief = torch.FloatTensor(self.beliefs[-1]).unsqueeze(0).to(self.device)
                _, _, next_value = self.actor_critic(state, belief)
                next_value = next_value.squeeze().cpu().item()
            else:
                next_value = 0
        
        # Compute returns using GAE
        returns = self.compute_gae(next_value)
        
        # Convert lists to tensors
        states = torch.FloatTensor(np.array(self.states)).to(self.device)
        actions = torch.FloatTensor(np.array(self.actions)).to(self.device)
        returns = torch.FloatTensor(returns).unsqueeze(1).to(self.device)
        old_log_probs = torch.FloatTensor(np.array(self.log_probs)).unsqueeze(1).to(self.device)
        values = torch.FloatTensor(np.array(self.values)).unsqueeze(1).to(self.device)
        
        # Prepare beliefs if used
        beliefs = None
        if self.use_belief and len(self.beliefs) > 0:
            beliefs = torch.FloatTensor(np.array(self.beliefs)).to(self.device)
        
        # Calculate advantages
        advantages = returns - values
        
        # Normalize advantages
        if len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        total_loss = 0
        policy_losses = []
        value_losses = []
        entropy_losses = []
        kl_divs = []
        
        # Mini-batch training
        batch_size = min(BATCH_SIZE, len(self.states))
        indices = np.arange(len(self.states))
        
        # Randomize and shuffle batches
        np.random.shuffle(indices)
        batches = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
        
        for _ in range(PPO_EPOCHS):
            # Shuffle batches
            np.random.shuffle(batches)
            
            for batch_indices in batches:
                # Get batch data
                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_advantages = advantages[batch_indices]
                batch_returns = returns[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                
                batch_beliefs = None
                if self.use_belief and beliefs is not None:
                    batch_beliefs = beliefs[batch_indices]
                
                # Evaluate actions again
                new_log_probs, entropy, values = self.actor_critic.evaluate_actions(
                    batch_states,
                    batch_actions,
                    batch_beliefs
                )
                
                # Compute ratio
                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                
                # Clamp ratio to avoid numerical instability
                ratio = torch.clamp(ratio, 0.0, 10.0)
                
                # Compute surrogate losses
                surrogate1 = ratio * batch_advantages
                surrogate2 = torch.clamp(ratio, 1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon) * batch_advantages
                
                # Value loss
                value_loss = F.mse_loss(values, batch_returns)
                
                # Policy loss
                policy_loss = -torch.min(surrogate1, surrogate2).mean()
                
                # Entropy bonus
                entropy_loss = -entropy.mean()
                
                # Total loss
                loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss
                
                # Store losses for logging
                policy_losses.append(policy_loss.item())
                value_losses.append(value_loss.item())
                entropy_losses.append(entropy_loss.item())
                
                # Update parameters
                self.optimizer.zero_grad()
                loss.backward()
                # Gradient clipping for stability
                nn.utils.clip_grad_norm_(self.actor_critic.parameters(), max_norm=0.5)
                if self.use_belief and self.belief_module:
                    nn.utils.clip_grad_norm_(self.belief_module.parameters(), max_norm=0.5)
                self.optimizer.step()
                
                total_loss += loss.item()
                
                # Check KL divergence for early stopping
                with torch.no_grad():
                    # Compute approx KL for early stopping
                    log_ratio = new_log_probs - batch_old_log_probs
                    approx_kl = ((log_ratio.exp() - 1) - log_ratio).mean().item()
                    kl_divs.append(approx_kl)
                
                if approx_kl > self.target_kl:
                    break
            
            # Early stopping based on average KL divergence
            if np.mean(kl_divs) > self.target_kl:
                break
        
        # Update learning rate periodically
        self.scheduler.step()
        
        # Increment epoch count
        self.epochs_trained += 1
        self.update_count += 1
        
        # Compute average losses for logging
        avg_policy_loss = np.mean(policy_losses) if policy_losses else 0
        avg_value_loss = np.mean(value_losses) if value_losses else 0
        avg_entropy_loss = np.mean(entropy_losses) if entropy_losses else 0
        avg_kl = np.mean(kl_divs) if kl_divs else 0
        
        # Reset memory buffer
        self.clear_memory()
        
        # Return loss components for logging
        loss_info = {
            'total_loss': total_loss / (len(batches) * PPO_EPOCHS),
            'policy_loss': avg_policy_loss,
            'value_loss': avg_value_loss,
            'entropy_loss': avg_entropy_loss,
            'approx_kl': avg_kl
        }
        
        return loss_info
    
    def clear_memory(self):
        """Clear memory buffer after updates."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.beliefs = []
        self.dones = []
        self.steps_since_update = 0
    
    def reset(self):
        """Reset agent state between episodes."""
        self.current_belief = None if self.use_belief else None
        self.prev_action = None if self.use_belief else None
    
    def save(self, path):
        """Save model parameters."""
        save_dict = {
            'actor_critic_state_dict': self.actor_critic.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'epochs_trained': self.epochs_trained
        }
        
        if self.use_belief and self.belief_module:
            save_dict['belief_module_state_dict'] = self.belief_module.state_dict()
            
        torch.save(save_dict, path)
    
    def load(self, path):
        """Load model parameters."""
        checkpoint = torch.load(path, map_location=self.device)
        self.actor_critic.load_state_dict(checkpoint['actor_critic_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.epochs_trained = checkpoint.get('epochs_trained', 0)
        
        if self.use_belief and self.belief_module and 'belief_module_state_dict' in checkpoint:
            self.belief_module.load_state_dict(checkpoint['belief_module_state_dict'])


# Performance tracking
class PerformanceTracker:
    def __init__(self, log_dir="results"):
        self.episode_rewards = []
        self.episode_lengths = []
        self.success_rates = []
        self.collision_rates = []
        self.belief_accuracy = []
        self.losses = []
        
        # Training metrics
        self.policy_losses = []
        self.value_losses = []
        self.entropy_losses = []
        self.kl_divs = []
        
        # For visualization
        self.frames = []
        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)
    
    def add_episode_metrics(self, episode_reward, episode_length, success, collision, belief_accuracy=None):
        self.episode_rewards.append(episode_reward)
        self.episode_lengths.append(episode_length)
        self.success_rates.append(1 if success else 0)
        self.collision_rates.append(1 if collision else 0)
        if belief_accuracy is not None:
            self.belief_accuracy.append(belief_accuracy)
    
    def add_training_metrics(self, loss_dict):
        """Add training metrics, handling both dictionary and float loss formats."""
        if loss_dict is None:
            return
            
        # Handle the case where loss_dict is actually a float (total loss only)
        if isinstance(loss_dict, (int, float)):
            self.losses.append(loss_dict)
            self.policy_losses.append(0)
            self.value_losses.append(0)
            self.entropy_losses.append(0)
            self.kl_divs.append(0)
        else:
            # Normal case where we have a dictionary with detailed metrics
            self.losses.append(loss_dict.get('total_loss', 0))
            self.policy_losses.append(loss_dict.get('policy_loss', 0))
            self.value_losses.append(loss_dict.get('value_loss', 0))
            self.entropy_losses.append(loss_dict.get('entropy_loss', 0))
            self.kl_divs.append(loss_dict.get('approx_kl', 0))
    
    def add_frame(self, frame):
        self.frames.append(frame)
    
    def get_recent_metrics(self, window=100):
        """Get metrics from recent episodes."""
        reward = np.mean(self.episode_rewards[-window:]) if self.episode_rewards else 0
        success = np.mean(self.success_rates[-window:]) if self.success_rates else 0
        belief = np.mean(self.belief_accuracy[-window:]) if self.belief_accuracy else 0
        collision = np.mean(self.collision_rates[-window:]) if self.collision_rates else 0
        length = np.mean(self.episode_lengths[-window:]) if self.episode_lengths else 0
        
        return {
            'reward': reward,
            'success_rate': success,
            'belief_accuracy': belief,
            'collision_rate': collision,
            'episode_length': length
        }
    
    def plot_learning_curves(self, window=15):
        """Plot learning curves with moving averages."""
        sns.set(style="darkgrid")
        
        # Create a figure with subplots
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        
        # Episode rewards
        self._plot_smoothed_curve(axes[0, 0], self.episode_rewards, window, 
                                 "Episode Rewards", "Episode", "Reward")
        
        # Success rate
        self._plot_smoothed_curve(axes[0, 1], self.success_rates, window,
                                 "Success Rate", "Episode", "Success Rate", 
                                 is_rate=True)
        
        # Collision rate
        self._plot_smoothed_curve(axes[0, 2], self.collision_rates, window,
                                 "Collision Rate", "Episode", "Collision Rate", 
                                 is_rate=True)
        
        # Episode length
        self._plot_smoothed_curve(axes[1, 0], self.episode_lengths, window,
                                 "Episode Length", "Episode", "Steps")
        
        # Training loss
        self._plot_smoothed_curve(axes[1, 1], self.losses, window,
                                 "Training Loss", "Update", "Loss")
        
        # Belief accuracy (if available)
        if self.belief_accuracy:
            self._plot_smoothed_curve(axes[1, 2], self.belief_accuracy, window,
                                     "Belief Accuracy", "Episode", "Accuracy")
        else:
            axes[1, 2].set_title("Belief Accuracy (N/A)")
            axes[1, 2].set_xlabel("Episode")
            axes[1, 2].set_ylabel("Accuracy")
        
        # Adjust layout
        plt.tight_layout()
        plt.savefig(f"{self.log_dir}/learning_curves.png", dpi=300)
        plt.close()
        
        # Plot detailed training metrics
        self._plot_training_metrics(window)
    
    def _smooth_data(self, data, window):
        """Apply moving average smoothing to data."""
        if len(data) < window:
            return np.array(data)
        
        smoothed = np.convolve(data, np.ones(window)/window, mode='valid')
        return smoothed
    
    def _plot_smoothed_curve(self, ax, data, window, title, xlabel, ylabel, is_rate=False):
        """Helper to plot smoothed curve with confidence interval."""
        if not data:
            return
            
        x = np.arange(len(data))
        y = np.array(data)
        
        # Plot raw data with low alpha
        ax.plot(x, y, alpha=0.2, color='blue', label='Raw')
        
        # Apply smoothing for moving average if sufficient data
        if len(x) > window:
            smoothed_y = self._smooth_data(y, window)
            smoothed_x = np.arange(window-1, len(x))
            
            # Plot moving average
            ax.plot(smoothed_x, smoothed_y, alpha=1.0, color='blue', linewidth=2, label=f'Moving Avg (window={window})')
            
            # Compute rolling std for confidence interval
            rolling_std = [np.std(y[max(0, i-window):i+1]) for i in range(window-1, len(y))]
            rolling_std = np.array(rolling_std)
            
            # Plot confidence intervals
            ax.fill_between(smoothed_x, 
                           np.maximum(0, smoothed_y - rolling_std), 
                           np.minimum(1 if is_rate else float('inf'), smoothed_y + rolling_std), 
                           alpha=0.2, color='blue')
        
        # Add title and labels
        ax.set_title(title, fontsize=14)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        
        # Set y-axis limits for rate plots
        if is_rate:
            ax.set_ylim([-0.05, 1.05])
        
        # Add legend if we have both raw and smoothed data
        if len(x) > window:
            ax.legend(loc='best')
    
    def _plot_training_metrics(self, window=15):
        """Plot additional training metrics."""
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        
        # Policy loss
        self._plot_smoothed_curve(axes[0, 0], self.policy_losses, window,
                                 "Policy Loss", "Update", "Loss")
        
        # Value loss
        self._plot_smoothed_curve(axes[0, 1], self.value_losses, window,
                                 "Value Loss", "Update", "Loss")
        
        # Entropy loss
        self._plot_smoothed_curve(axes[1, 0], self.entropy_losses, window,
                                 "Entropy Loss", "Update", "Loss")
        
        # KL divergence
        self._plot_smoothed_curve(axes[1, 1], self.kl_divs, window,
                                 "Approx. KL Divergence", "Update", "KL")
        
        plt.tight_layout()
        plt.savefig(f"{self.log_dir}/training_metrics.png", dpi=300)
        plt.close()
    
    def plot_comparison(self, results_dict):
        """Plot comparison of different methods."""
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        
        # Process data for plotting
        methods = list(results_dict.keys())
        success_rates = [results_dict[m]['success_rate'] for m in methods]
        episode_rewards = [results_dict[m]['episode_reward'] for m in methods]
        completion_times = [results_dict[m]['completion_time'] for m in methods]
        
        # Plot success rates
        axes[0].bar(methods, success_rates, color='skyblue')
        axes[0].set_title('Success Rate by Method', fontsize=14)
        axes[0].set_ylabel('Success Rate', fontsize=12)
        axes[0].set_ylim([0, 1.0])
        
        # Plot episode rewards
        axes[1].bar(methods, episode_rewards, color='lightgreen')
        axes[1].set_title('Average Episode Reward by Method', fontsize=14)
        axes[1].set_ylabel('Reward', fontsize=12)
        
        # Plot completion times
        axes[2].bar(methods, completion_times, color='salmon')
        axes[2].set_title('Average Completion Time by Method', fontsize=14)
        axes[2].set_ylabel('Steps', fontsize=12)
        
        # Add values on top of bars
        for ax, data in zip(axes, [success_rates, episode_rewards, completion_times]):
            for i, v in enumerate(data):
                ax.text(i, v, f'{v:.2f}', ha='center', va='bottom', fontsize=10)
        
        plt.tight_layout()
        plt.savefig(f"{self.log_dir}/method_comparison.png", dpi=300)
        plt.close()
    
    def save_video(self, filename="episode", fps=30):
        """Save frames as a video."""
        if not self.frames:
            return
            
        # Create writer
        frames = [frame for frame in self.frames]
        
        # Check if frames are valid
        if not frames or frames[0] is None:
            print("No valid frames to save.")
            return
            
        height, width, _ = frames[0].shape
        
        # Create video writer
        video_path = f"{self.log_dir}/{filename}.mp4"
        
        # Save using matplotlib animation
        fig = plt.figure(figsize=(width/100, height/100), dpi=100)
        ax = fig.add_subplot(111)
        ax.set_axis_off()
        
        images = [[ax.imshow(frame)] for frame in frames]
        anim = animation.ArtistAnimation(fig, images, interval=1000/fps, blit=True)
        
        try:
            # Use Pillow writer
            anim.save(video_path, writer='pillow', fps=fps)
            print(f"Video saved to {video_path}")
        except Exception as e:
            print(f"Error saving video: {e}")
            # Try to save individual frames instead
            try:
                os.makedirs(f"{self.log_dir}/{filename}_frames", exist_ok=True)
                for i, frame in enumerate(frames):
                    plt.imsave(f"{self.log_dir}/{filename}_frames/frame_{i:04d}.png", frame)
                print(f"Saved {len(frames)} individual frames to {self.log_dir}/{filename}_frames folder")
            except Exception as e2:
                print(f"Error saving frames: {e2}")
        
        plt.close()
    
    def clear_frames(self):
        """Clear saved frames to free memory."""
        self.frames = []


# Evaluation function
def evaluate(env, agent, num_episodes=20, max_steps=1000, visualize=False, render_mode=None):
    """Evaluate the agent's performance."""
    eval_rewards = []
    eval_lengths = []
    eval_successes = []
    eval_collisions = []
    eval_belief_accuracy = []
    
    tracker = PerformanceTracker(log_dir="results/eval")
    
    for episode in range(num_episodes):
        state, _ = env.reset()
        agent.reset()
        
        episode_reward = 0
        episode_length = 0
        episode_beliefs = []
        
        # Extract goals from the environment
        goals = []
        for i in range(NUM_GOALS):
            goal_idx = state.shape[0] - 2 * NUM_GOALS + 2 * i
            goals.append(state[goal_idx:goal_idx + 2])
        
        # Episode loop
        for step in range(max_steps):
            # Record video if requested
            if visualize:
                frame = env.render()
                if frame is not None:
                    tracker.add_frame(frame)
            
            # Select action (deterministic for evaluation)
            action, belief = agent.select_action(state, goals, deterministic=True)
            
            # Safety check
            if np.isnan(action).any():
                action = np.zeros_like(action)
            
            # Take action
            next_state, reward, done, truncated, info = env.step(action)
            
            # Update state and counters
            state = next_state
            episode_reward += reward
            episode_length += 1
            
            # Track belief accuracy if using beliefs
            if belief is not None:
                true_goal_idx = env.true_goal_idx
                predicted_goal_idx = np.argmax(belief)
                episode_beliefs.append(predicted_goal_idx == true_goal_idx)
            
            if done or truncated:
                break
        
        # Calculate episode metrics
        success = info.get('goal_reached', False)
        collision = info.get('collision', False)
        belief_accuracy = np.mean(episode_beliefs) if episode_beliefs else None
        
        # Store metrics
        eval_rewards.append(episode_reward)
        eval_lengths.append(episode_length)
        eval_successes.append(1 if success else 0)
        eval_collisions.append(1 if collision else 0)
        if belief_accuracy is not None:
            eval_belief_accuracy.append(belief_accuracy)
    
    # Save video of the evaluation
    if visualize and tracker.frames:
        tracker.save_video("evaluation")
        tracker.clear_frames()
    
    # Calculate final metrics
    success_rate = np.mean(eval_successes)
    successful_episode_lengths = [l for l, s in zip(eval_lengths, eval_successes) if s]
    
    results = {
        'episode_reward': np.mean(eval_rewards),
        'episode_length': np.mean(eval_lengths),
        'success_rate': success_rate,
        'collision_rate': np.mean(eval_collisions),
        'completion_time': np.mean(successful_episode_lengths) if successful_episode_lengths else float('inf'),
    }
    
    if eval_belief_accuracy:
        results['belief_accuracy'] = np.mean(eval_belief_accuracy)
    
    return results


# Human simulation for assistance evaluation
class HumanSimulator:
    def __init__(self, noise_level=0.2, random_action_prob=0.05):
        self.noise_level = noise_level  # Noise in human control
        self.random_action_prob = random_action_prob  # Probability of random action
        
    def get_action(self, state, true_goal):
        """Generate realistic simulated human action toward the goal."""
                # Extract end effector position# Get end effector position - corrected to account for goal information
        ee_pos = state[-2-2*self.num_goals:-2*self.num_goals]  # Position right before goal info
        
        # Calculate direction to goal
        goal_dir = true_goal - ee_pos
        goal_dist = np.linalg.norm(goal_dir)
        
        # Random action with small probability
        if np.random.random() < self.random_action_prob:
            return np.random.uniform(-1, 1, size=2)
        
        # Normalize direction if non-zero
        if goal_dist > 0:
            optimal_dir = goal_dir / goal_dist
        else:
            optimal_dir = np.zeros(2)
        
        # Scale action magnitude based on distance
        optimal_magnitude = min(1.0, goal_dist * 2.0)
        optimal_action = optimal_dir * optimal_magnitude
        
        # Add noise to simulate human variability
        noise = np.random.normal(0, self.noise_level, size=2)
        
        # Add some tremor when close to goal (simulating precision difficulty)
        if goal_dist < 0.1:
            tremor = np.random.normal(0, 0.1 + 0.3 * (1 - goal_dist/0.1), size=2)
            noise += tremor
        
        # Final action with noise
        action = optimal_action + noise
        
        # Clip to action space
        return np.clip(action, -1, 1)


# Assistance blending methods
class AssistanceMethod:
    def __init__(self, agent, human):
        self.agent = agent
        self.human = human
        self.device = agent.device
    
    def reset(self):
        self.agent.reset()
    
    def select_action(self, state, goals, true_goal, deterministic=False):
        # This method should be overridden by subclasses
        raise NotImplementedError


class NoAssistance(AssistanceMethod):
    def select_action(self, state, goals, true_goal, deterministic=False):
        # Just return human action, ignore agent
        human_action = self.human.get_action(state, true_goal)
        
        # Get belief for tracking purposes only
        _, belief = self.agent.select_action(state, goals, deterministic)
        
        return human_action, belief


class FixedBlending(AssistanceMethod):
    def __init__(self, agent, human, blend_ratio=0.5):
        super().__init__(agent, human)
        self.blend_ratio = blend_ratio  # Fixed blend ratio
    
    def select_action(self, state, goals, true_goal, deterministic=False):
        # Get agent action
        agent_action, belief = self.agent.select_action(state, goals, deterministic)
        
        # Get human action
        human_action = self.human.get_action(state, true_goal)
        
        # Blend with fixed ratio
        gamma = self.blend_ratio
        action = (1 - gamma) * human_action + gamma * agent_action
        
        # Safety check
        if np.isnan(action).any():
            action = human_action
        
        return action, belief


class AdaptiveBlending(AssistanceMethod):
    def __init__(self, agent, human, min_gamma=0.1, max_gamma=0.8):
        super().__init__(agent, human)
        self.min_gamma = min_gamma
        self.max_gamma = max_gamma
        
        # Maintain a history of actions and goal distances
        self.action_history = []
        self.max_history = 10
        
        # Performance metrics for adaptation
        self.success_counter = 0
        self.failure_counter = 0
    
    def select_action(self, state, goals, true_goal, deterministic=False):
        # Get agent action
        agent_action, belief = self.agent.select_action(state, goals, deterministic)
        
        # Get human action
        human_action = self.human.get_action(state, true_goal)
        
        # Extract end effector position
                # Get end effector position - corrected to account for goal information
        ee_pos = state[-2-2*self.num_goals:-2*self.num_goals]  # Position right before goal info
                
        # Calculate distance to true goal
        dist_to_goal = np.linalg.norm(ee_pos - true_goal)
        
        # Store history for action consistency
        self.action_history.append((human_action, agent_action, dist_to_goal))
        if len(self.action_history) > self.max_history:
            self.action_history.pop(0)
        
        # Calculate adaptive blending ratio based on multiple factors
        if belief is not None:
            # 1. Confidence-based component
            max_belief_idx = np.argmax(belief)
            confidence = belief[max_belief_idx]
            confidence_factor = confidence
            
            # 2. Correct goal component (if agent believes in correct goal)
            goal_correctness = 1.0 if max_belief_idx == np.where(goals == true_goal)[0][0] else 0.3
            
            # 3. Distance-based component (more assistance when closer to goal)
            proximity_factor = np.clip(0.3 / (dist_to_goal + 0.3), 0, 1.0)
            
            # 4. Performance-based component
            success_rate = self.success_counter / max(1, self.success_counter + self.failure_counter)
            performance_factor = np.clip(success_rate, 0.3, 1.0)
            
            # 5. Action consistency component (prevent sudden changes)
            consistency_factor = 1.0
            if len(self.action_history) > 2:
                # Check if human actions are consistent
                recent_human = np.array([h for h, _, _ in self.action_history[-3:]])
                human_std = np.std(recent_human, axis=0).mean()
                if human_std < 0.2:  # Human is consistent, reduce AI intervention
                    consistency_factor = 0.7
            
            # Combine all factors
            gamma_factors = [
                confidence_factor * 0.3,
                goal_correctness * 0.2,
                proximity_factor * 0.2,
                performance_factor * 0.2,
                consistency_factor * 0.1
            ]
            
            gamma_raw = sum(gamma_factors)
            
            # Scale to desired range
            gamma = self.min_gamma + (self.max_gamma - self.min_gamma) * gamma_raw
            
            # Special case: detect if stuck or making no progress
            if len(self.action_history) > 5:
                recent_distances = [d for _, _, d in self.action_history[-5:]]
                if all(abs(d - recent_distances[0]) < 0.01 for d in recent_distances):
                    # If stuck, alternate between more human and more AI
                    if np.random.random() < 0.7:
                        gamma = max(0.05, gamma * 0.5)  # Let human try more
                    else:
                        gamma = min(0.95, gamma * 1.5)  # Let AI try more
        else:
            # Default to minimum assistance if no belief
            gamma = self.min_gamma
        
        # Safety checks before blending
        agent_magnitude = np.linalg.norm(agent_action)
        human_magnitude = np.linalg.norm(human_action)
        
        # If agent action is much larger than human, reduce its influence
        if agent_magnitude > 2.0 * human_magnitude and human_magnitude > 0.1:
            gamma *= 0.7
        
        # If agent and human actions point in opposite directions, reduce AI influence
        dot_product = np.dot(agent_action, human_action)
        agent_norm = np.linalg.norm(agent_action)
        human_norm = np.linalg.norm(human_action)
        if agent_norm > 0.1 and human_norm > 0.1:
            cosine = dot_product / (agent_norm * human_norm)
            if cosine < -0.5:  # Actions disagree significantly
                gamma *= 0.5
        
        # Blend with adaptive ratio
        action = (1 - gamma) * human_action + gamma * agent_action
        
        # Safety check
        if np.isnan(action).any():
            action = human_action
        
        return action, belief
    
    def reset(self):
        super().reset()
        self.action_history = []
        
    def record_outcome(self, success):
        """Record outcome for adaptation."""
        if success:
            self.success_counter += 1
        else:
            self.failure_counter += 1


class PhasedTraining:
    def __init__(self, env, state_dim, action_dim, log_dir="results"):
        self.env = env
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.log_dir = log_dir
        
        # Create separate directories for phase logs
        self.phase1_dir = f"{log_dir}/phase1"
        self.phase2_dir = f"{log_dir}/phase2"
        self.phase3_dir = f"{log_dir}/phase3"
        self.comparison_dir = f"{log_dir}/comparison"
        
        for dir_path in [self.phase1_dir, self.phase2_dir, self.phase3_dir, self.comparison_dir]:
            os.makedirs(dir_path, exist_ok=True)
    
    def train_phase1(self, num_episodes=2000):
        """Phase 1: Train pure PPO without beliefs to establish a baseline."""
        print("\n=== Phase 1: Training Pure PPO ===")
        
        # Create pure PPO agent without belief module
        agent = PPOAgent(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            hidden_dim=256,
            num_goals=NUM_GOALS,
            lr=LR,
            gamma=GAMMA,
            gae_lambda=GAE_LAMBDA,
            clip_epsilon=CLIP_EPSILON,
            value_coef=VALUE_COEF,
            entropy_coef=ENTROPY_COEF,
            target_kl=TARGET_KL,
            use_belief=False
        )
        
        # Train the pure PPO agent - ensure it learns the basic task first
        tracker = self._train_agent(
            agent=agent,
            num_episodes=num_episodes,
            early_stop_target=0.75,  # Target 75% success rate
            log_dir=self.phase1_dir
        )
        
        # Final evaluation to verify performance
        print("Performing detailed evaluation of Phase 1 agent...")
        eval_results = evaluate(self.env, agent, num_episodes=100, max_steps=500)
        print(f"Pure PPO Success Rate: {eval_results['success_rate']:.4f}")
        
        # Save the final model
        agent.save(f"{self.phase1_dir}/final_model.pt")
        
        return agent
    
    def train_phase2(self, base_agent, num_episodes=2000):
        """Phase 2: Add belief module jointly trained with PPO."""
        print("\n=== Phase 2: Training with Belief Module ===")
        
        # Create agent with belief module (transfer learning from base agent)
        belief_agent = PPOAgent(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            hidden_dim=256,
            num_goals=NUM_GOALS,
            lr=LR * 0.5,  # Lower learning rate for fine-tuning
            gamma=GAMMA,
            gae_lambda=GAE_LAMBDA,
            clip_epsilon=CLIP_EPSILON * 0.8,  # Tighter clips for fine-tuning
            value_coef=VALUE_COEF,
            entropy_coef=ENTROPY_COEF * 0.7,  # Lower entropy to exploit learned policy
            target_kl=TARGET_KL,
            use_belief=True  # Now with belief module
        )
        
        # Copy parameters from base agent to new agent
        belief_agent.actor_critic.feature_network.load_state_dict(
            base_agent.actor_critic.feature_network.state_dict()
        )
        belief_agent.actor_critic.actor_mean.load_state_dict(
            base_agent.actor_critic.actor_mean.state_dict()
        )
        belief_agent.actor_critic.actor_log_std.data = base_agent.actor_critic.actor_log_std.data.clone()
        belief_agent.actor_critic.critic.load_state_dict(
            base_agent.actor_critic.critic.state_dict()
        )
        
        # Phase 2A: Train belief module first with a higher learning rate
        print("Phase 2A: Training belief module...")
        for param in belief_agent.actor_critic.parameters():
            param.requires_grad = False  # Freeze actor-critic
            
        # Create a separate optimizer for the belief module
        belief_optimizer = optim.Adam(belief_agent.belief_module.parameters(), lr=LR*2)
        
        # Train for a few hundred episodes with only belief module learning
        belief_pretraining_episodes = 300
        
        # Custom training loop for belief module only
        tracker = self._train_agent(
            agent=belief_agent,
            num_episodes=belief_pretraining_episodes,
            early_stop_target=0.5,  # Lower goal for this phase
            log_dir=f"{self.phase2_dir}/belief_only",
            custom_optimizer=belief_optimizer
        )
        
        # Phase 2B: Unfreeze actor-critic and train jointly
        print("Phase 2B: Joint training of belief and policy...")
        for param in belief_agent.actor_critic.parameters():
            param.requires_grad = True  # Unfreeze actor-critic
            
        # Reset the regular optimizer
        belief_agent.optimizer = optim.Adam(
            list(belief_agent.actor_critic.parameters()) + list(belief_agent.belief_module.parameters()),
            lr=LR * 0.5
        )
        
        # Continue training with everything unfrozen
        tracker = self._train_agent(
            agent=belief_agent,
            num_episodes=num_episodes - belief_pretraining_episodes,
            early_stop_target=0.65,  # Target for full agent
            log_dir=self.phase2_dir
        )
        
        # Final detailed evaluation
        print("Performing detailed evaluation of Phase 2 agent...")
        eval_results = evaluate(self.env, belief_agent, num_episodes=100, max_steps=500)
        print(f"Belief-Enabled Agent Success Rate: {eval_results['success_rate']:.4f}")
        print(f"Belief Accuracy: {eval_results.get('belief_accuracy', 'N/A')}")
        
        # Save the final model
        belief_agent.save(f"{self.phase2_dir}/final_model.pt")
        
        return belief_agent
    
    def evaluate_assistance_methods(self, agent, num_episodes=50):
        """Compare different assistance methods using the trained agent."""
        print("\n=== Evaluating Assistance Methods ===")
        
        # Create human simulator
        human = HumanSimulator(noise_level=0.2)
        
        # Create evaluation environment
        eval_env = MultiGoalReacherEnv(num_goals=NUM_GOALS)
        
        # Define assistance methods to evaluate
        methods = {
            'No Assistance': NoAssistance(agent, human),
            'Fixed 50-50': FixedBlending(agent, human, blend_ratio=0.5),
            'Adaptive': AdaptiveBlending(agent, human, min_gamma=0.1, max_gamma=0.8)
        }
        
        # Evaluate each method
        results = {}
        for name, method in methods.items():
            print(f"\nEvaluating: {name}")
            method_results = self._evaluate_assistance(
                eval_env, method, num_episodes=num_episodes
            )
            results[name] = method_results
            
            # Print key metrics
            print(f"  Success rate: {method_results['success_rate']:.4f}")
            print(f"  Episode reward: {method_results['episode_reward']:.4f}")
            print(f"  Completion time: {method_results['completion_time']:.4f}")
            if 'belief_accuracy' in method_results:
                print(f"  Belief accuracy: {method_results['belief_accuracy']:.4f}")
        
        # Plot comparison
        tracker = PerformanceTracker(log_dir=self.comparison_dir)
        tracker.plot_comparison(results)
        
        return results
    
    def _train_agent(self, agent, num_episodes, early_stop_target, log_dir, custom_optimizer=None):
        """Train an agent with early stopping based on success rate."""
        tracker = PerformanceTracker(log_dir=log_dir)
        
        # Define parameters
        max_steps = 500
        eval_interval = EVAL_INTERVAL
        patience = EARLY_STOP_PATIENCE
        
        best_success_rate = 0
        best_reward = -float('inf')
        episodes_without_improvement = 0
        
        # For tracking progress
        total_steps = 0
        
        # Minimum episodes to train before considering early stopping
        min_episodes = MIN_TRAINING_EPISODES  # Using constant from parameters
        # Minimum number of evaluation episodes to trust the results
        min_eval_episodes = 50  # Increased for more reliable evaluation
        # Number of consecutive evaluations above target to consider it reached
        consecutive_target_reached = 0
        required_consecutive = 3  # Require meeting target 3 times in a row
        
        # Use custom optimizer if provided (for specialized training phases)
        if custom_optimizer is not None:
            optimizer = custom_optimizer
        else:
            optimizer = agent.optimizer
        
        # Training loop
        for episode in tqdm(range(1, num_episodes + 1), desc="Training"):
            state, _ = self.env.reset()
            agent.reset()
            
            episode_reward = 0
            episode_length = 0
            episode_beliefs = []
            
            # Extract goals from the environment
            goals = []
            for i in range(NUM_GOALS):
                goal_idx = state.shape[0] - 2 * NUM_GOALS + 2 * i
                goals.append(state[goal_idx:goal_idx + 2])
            
            # Episode loop
            for step in range(max_steps):
                # Select action
                action, belief = agent.select_action(state, goals)
                
                # Check for NaN values in action
                if np.isnan(action).any():
                    action = np.zeros_like(action)
                
                # Get action log probability and value
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
                    belief_tensor = None
                    if belief is not None:
                        belief_tensor = torch.FloatTensor(belief).unsqueeze(0).to(agent.device)
                    action_tensor = torch.FloatTensor(action).unsqueeze(0).to(agent.device)
                    
                    log_prob, _, value = agent.actor_critic.evaluate_actions(
                        state_tensor, action_tensor, belief_tensor
                    )
                    log_prob = log_prob.squeeze().cpu().item()
                    value = value.squeeze().cpu().item()
                
                # Take action in environment
                next_state, reward, done, truncated, info = self.env.step(action)
                
                # Store experience for learning
                agent.remember(state, action, reward, value, log_prob, belief, done)
                
                # Update state and counters
                state = next_state
                episode_reward += reward
                episode_length += 1
                total_steps += 1
                
                # Store belief accuracy
                if belief is not None:
                    true_goal_idx = self.env.true_goal_idx
                    predicted_goal_idx = np.argmax(belief)
                    episode_beliefs.append(predicted_goal_idx == true_goal_idx)
                
                # Update policy if enough steps have been taken
                if agent.steps_since_update >= STEPS_PER_UPDATE:
                    # If using custom optimizer
                    if custom_optimizer is not None:
                        # Manual update for specialized training (e.g., belief module only)
                        with torch.no_grad():
                            if len(agent.states) > 0:
                                state = torch.FloatTensor(agent.states[-1]).unsqueeze(0).to(agent.device)
                                belief = None
                                if agent.use_belief and len(agent.beliefs) > 0:
                                    belief = torch.FloatTensor(agent.beliefs[-1]).unsqueeze(0).to(agent.device)
                                _, _, next_value = agent.actor_critic(state, belief)
                                next_value = next_value.squeeze().cpu().item()
                            else:
                                next_value = 0
                        
                        # Compute returns
                        returns = agent.compute_gae(next_value)
                        
                        # Convert lists to tensors
                        states = torch.FloatTensor(np.array(agent.states)).to(agent.device)
                        actions = torch.FloatTensor(np.array(agent.actions)).to(agent.device)
                        returns = torch.FloatTensor(returns).unsqueeze(1).to(agent.device)
                        
                        # Prepare beliefs if used
                        beliefs = None
                        if agent.use_belief and len(agent.beliefs) > 0:
                            beliefs = torch.FloatTensor(np.array(agent.beliefs)).to(agent.device)
                        
                        # Custom update for belief module
                        if agent.use_belief and agent.belief_module:
                            # Batch size for updates
                            batch_size = min(BATCH_SIZE, len(agent.states))
                            indices = np.arange(len(agent.states))
                            np.random.shuffle(indices)
                            
                            # Mini-batch training for belief module
                            total_loss = 0
                            policy_loss = 0
                            belief_loss = 0
                            
                            for idx in range(0, len(indices), batch_size):
                                batch_indices = indices[idx:idx + batch_size]
                                batch_states = states[batch_indices]
                                batch_actions = actions[batch_indices]
                                
                                if beliefs is not None:
                                    batch_beliefs = beliefs[batch_indices]
                                    
                                    # Extract end effector positions and goals
                                    ee_positions = batch_states[:, -4:-2]
                                    goal_indices = []
                                    for i in range(NUM_GOALS):
                                        goal_idx = batch_states.shape[1] - 2 * NUM_GOALS + 2 * i
                                        goal_indices.append(goal_idx)
                                    
                                    # Collect belief module losses
                                    belief_loss_total = 0
                                    
                                    # Update belief module here
                                    custom_optimizer.zero_grad()
                                    
                                    # Special belief training code
                                    batch_size = batch_states.shape[0]
                                    for i in range(batch_size):
                                        state = ee_positions[i].view(-1)
                                        action = batch_actions[i].view(-1)
                                        
                                        # Extract goals for this sample
                                        goals = []
                                        for g_idx in range(NUM_GOALS):
                                            goal_start = goal_indices[g_idx]
                                            goals.append(batch_states[i, goal_start:goal_start+2])
                                        
                                        # Assuming the belief network always outputs higher probability for correct goal
                                        # This is a simple auxiliary loss that may not be optimal
                                        if random.random() < 0.5:  # Only use half the samples for efficiency
                                            updated_belief = agent.belief_module(state, action, goals, batch_beliefs[i])
                                            
                                            # Create target distribution (one-hot for true goal)
                                            # In a real scenario, this would come from human feedback
                                            # For now, we'll use a heuristic based on rewards or progress
                                            target_belief = torch.zeros_like(updated_belief)
                                            
                                            # Find goal closest to the action direction
                                            action_dir = action / (torch.norm(action) + 1e-6)
                                            goal_costs = []
                                            
                                            for goal in goals:
                                                goal_dir = goal - state
                                                goal_dist = torch.norm(goal_dir) + 1e-6
                                                goal_dir = goal_dir / goal_dist
                                                alignment = torch.dot(action_dir, goal_dir)
                                                goal_costs.append(-alignment)  # Negative cost for better alignment
                                            
                                            # Select most likely goal based on action alignment
                                            best_goal_idx = torch.argmin(torch.tensor(goal_costs)).item()
                                            target_belief[best_goal_idx] = 1.0
                                            
                                            # Cross entropy loss
                                            loss = F.kl_div(updated_belief.log(), target_belief, reduction='batchmean')
                                            belief_loss_total += loss
                                    
                                    if belief_loss_total > 0:
                                        belief_loss_total.backward()
                                        torch.nn.utils.clip_grad_norm_(agent.belief_module.parameters(), 1.0)
                                        custom_optimizer.step()
                                        
                                        belief_loss += belief_loss_total.item()
                                    
                            # Report metrics
                            loss_info = {
                                'total_loss': belief_loss,
                                'belief_loss': belief_loss,
                                'policy_loss': 0,
                                'value_loss': 0,
                                'entropy_loss': 0,
                                'approx_kl': 0
                            }
                            
                            tracker.add_training_metrics(loss_info)
                        
                        # Clear memory
                        agent.clear_memory()
                    else:
                        # Normal PPO update
                        loss_info = agent.update()
                        tracker.add_training_metrics(loss_info)
                
                # Check if episode is done
                if done or truncated:
                    break
            
            # Calculate episode metrics
            success = info.get('goal_reached', False)
            collision = info.get('collision', False)
            belief_accuracy = np.mean(episode_beliefs) if episode_beliefs else None
            
            # Add episode metrics to tracker
            tracker.add_episode_metrics(
                episode_reward, episode_length, success, collision, belief_accuracy
            )
            
            # Log progress
            if episode % 50 == 0:
                metrics = tracker.get_recent_metrics(window=min(50, episode))
                print(f"Episode {episode}, Steps: {total_steps}, Reward: {metrics['reward']:.2f}, "
                      f"Success: {metrics['success_rate']:.2f}, "
                      f"Length: {metrics['episode_length']:.2f}")
                
                # Plot learning curves
                tracker.plot_learning_curves()
            
            # Evaluate agent periodically - but only after minimum training
            if episode % eval_interval == 0 and episode >= min_episodes:
                eval_results = evaluate(self.env, agent, num_episodes=min_eval_episodes, max_steps=max_steps)
                
                # Check if target is reached consistently
                if eval_results['success_rate'] >= early_stop_target:
                    consecutive_target_reached += 1
                    print(f"Target success rate reached ({consecutive_target_reached}/{required_consecutive})")
                else:
                    consecutive_target_reached = 0
                
                # Check for improvement
                improved = False
                if eval_results['success_rate'] > best_success_rate + IMPROVEMENT_THRESHOLD:
                    best_success_rate = eval_results['success_rate']
                    agent.save(f"{log_dir}/best_model_success.pt")
                    print(f"New best success rate: {best_success_rate:.4f}")
                    improved = True
                
                if eval_results['episode_reward'] > best_reward + IMPROVEMENT_THRESHOLD * 100:
                    best_reward = eval_results['episode_reward']
                    if not improved:  # Only save if not already saved for success rate
                        agent.save(f"{log_dir}/best_model_reward.pt")
                        print(f"New best reward: {best_reward:.4f}")
                    improved = True
                
                if not improved:
                    episodes_without_improvement += 1
                    print(f"No improvement for {episodes_without_improvement} evaluations")
                else:
                    episodes_without_improvement = 0
                
                # Early stopping based on patience
                if episodes_without_improvement >= patience:
                    print(f"Early stopping after {episode} episodes due to no improvement")
                    break
                
                # Early stopping based on consistently reaching target
                if consecutive_target_reached >= required_consecutive:
                    print(f"Target success rate {early_stop_target} reached consistently!")
                    break
        
        # Final evaluation with more episodes for reliable results
        final_eval = evaluate(self.env, agent, num_episodes=50, max_steps=max_steps)
        
        # Save final model
        agent.save(f"{log_dir}/final_model.pt")
        
        # Print final evaluation results
        print("\nFinal Evaluation Results:")
        for key, value in final_eval.items():
            print(f"{key}: {value:.4f}")
        
        return tracker
    
    def _evaluate_assistance(self, env, method, num_episodes=50, max_steps=500):
        """Evaluate an assistance method."""
        rewards = []
        lengths = []
        successes = []
        collisions = []
        belief_accuracies = []
        
        for episode in range(num_episodes):
            state, _ = env.reset()
            method.reset()
            
            episode_reward = 0
            episode_length = 0
            episode_beliefs = []
            
            # Extract goals
            goals = []
            for i in range(NUM_GOALS):
                goal_idx = state.shape[0] - 2 * NUM_GOALS + 2 * i
                goals.append(state[goal_idx:goal_idx + 2])
            
            # Episode loop
            for step in range(max_steps):
                # Get assisted action
                action, belief = method.select_action(
                    state, goals, env.true_goal, deterministic=True
                )
                
                # Take action
                next_state, reward, done, truncated, info = env.step(action)
                
                # Update state and counters
                state = next_state
                episode_reward += reward
                episode_length += 1
                
                # Track belief accuracy
                if belief is not None:
                    true_goal_idx = env.true_goal_idx
                    predicted_goal_idx = np.argmax(belief)
                    episode_beliefs.append(predicted_goal_idx == true_goal_idx)
                
                if done or truncated:
                    break
            
            # Calculate episode metrics
            success = info.get('goal_reached', False)
            collision = info.get('collision', False)
            
            # Store metrics
            rewards.append(episode_reward)
            lengths.append(episode_length)
            successes.append(1 if success else 0)
            collisions.append(1 if collision else 0)
            
            if episode_beliefs:
                belief_accuracies.append(np.mean(episode_beliefs))
        
        # Calculate aggregate metrics
        success_rate = np.mean(successes)
        successful_episode_lengths = [l for l, s in zip(lengths, successes) if s]
        
        results = {
            'episode_reward': np.mean(rewards),
            'episode_length': np.mean(lengths),
            'success_rate': success_rate,
            'collision_rate': np.mean(collisions),
            'completion_time': np.mean(successful_episode_lengths) if successful_episode_lengths else float('inf')
        }
        
        if belief_accuracies:
            results['belief_accuracy'] = np.mean(belief_accuracies)
        
        return results


def main():
    # Check if visualization is possible
    render_mode = None
    visualization_enabled = VISUALIZE
    
    if visualization_enabled:
        try:
            import pygame
            pygame.init()
            render_mode = "rgb_array"
        except ImportError:
            print("Warning: Pygame not found. Visualization will be disabled.")
            visualization_enabled = False
    
    # Create environment with reproducible seeds
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    random.seed(SEED)
    
    env = MultiGoalReacherEnv(num_goals=NUM_GOALS, render_mode=render_mode)
    
    # Get dimensions
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    
    # Run phased training
    trainer = PhasedTraining(env, state_dim, action_dim)
    
    # Phase 1: Train pure PPO without beliefs - longer training
    print("\n" + "="*50)
    print("PHASE 1: Establishing solid PPO baseline")
    print("="*50)
    ppo_agent = trainer.train_phase1(num_episodes=2000)
    
    # Verify baseline performance before proceeding
    eval_results = evaluate(env, ppo_agent, num_episodes=100)
    if eval_results['success_rate'] < 0.6:
        print("\nWARNING: Pure PPO agent performance is below target.")
        print("Consider running Phase 1 again with different hyperparameters.")
        print("Proceeding anyway, but results may be suboptimal.")
    else:
        print(f"\nPure PPO baseline established successfully!")
        print(f"Success rate: {eval_results['success_rate']:.2f}")
    
    # Phase 2: Add belief module jointly trained with PPO
    print("\n" + "="*50)
    print("PHASE 2: Training with belief module")
    print("="*50)
    belief_agent = trainer.train_phase2(ppo_agent, num_episodes=2000)
    
    # Evaluate assistance methods with more episodes for reliable results
    print("\n" + "="*50)
    print("FINAL PHASE: Evaluating assistance methods")
    print("="*50)
    results = trainer.evaluate_assistance_methods(belief_agent, num_episodes=100)
    
    # Print summary of results
    print("\n" + "="*50)
    print("SUMMARY OF RESULTS")
    print("="*50)
    
    # Compare assistance methods
    methods = sorted(results.keys(), key=lambda x: results[x]['success_rate'], reverse=True)
    print(f"Best method: {methods[0]} with {results[methods[0]]['success_rate']:.2f} success rate")
    
    # Print all methods ordered by success rate
    print("\nAll methods ranked by success rate:")
    for i, method in enumerate(methods, 1):
        print(f"{i}. {method}: {results[method]['success_rate']:.2f} success, " + 
              f"{results[method]['episode_reward']:.1f} reward, " +
              f"{results[method]['completion_time']:.1f} completion time")
    
    # Report belief accuracy if available
    if 'belief_accuracy' in results[methods[0]]:
        print(f"\nBelief accuracy: {results[methods[0]]['belief_accuracy']:.2f}")
    
    # Close environment
    env.close()
    
    print("\nTraining and evaluation completed!")


if __name__ == "__main__":
    main()

Using device: cuda

PHASE 1: Establishing solid PPO baseline

=== Phase 1: Training Pure PPO ===


Training:   2%|▏         | 49/2000 [00:07<04:43,  6.87it/s]

Episode 50, Steps: 3680, Reward: 73.53, Success: 0.68, Length: 73.60


Training:   5%|▍         | 99/2000 [00:20<15:18,  2.07it/s]

Episode 100, Steps: 8623, Reward: 80.50, Success: 0.68, Length: 98.86


Training:   7%|▋         | 148/2000 [00:31<04:26,  6.95it/s]

Episode 150, Steps: 12849, Reward: 73.29, Success: 0.72, Length: 84.52


Training:  10%|▉         | 198/2000 [00:42<04:24,  6.81it/s]

Episode 200, Steps: 17256, Reward: 76.88, Success: 0.64, Length: 88.14


Training:  10%|█         | 201/2000 [00:50<26:29,  1.13it/s]

New best success rate: 0.5000


Training:  11%|█         | 220/2000 [00:56<24:46,  1.20it/s]

No improvement for 1 evaluations


Training:  12%|█▏        | 240/2000 [01:02<21:01,  1.39it/s]

No improvement for 2 evaluations


Training:  12%|█▏        | 249/2000 [01:03<07:26,  3.92it/s]

Episode 250, Steps: 21484, Reward: 73.67, Success: 0.60, Length: 84.56


Training:  13%|█▎        | 261/2000 [01:12<18:05,  1.60it/s]

New best success rate: 0.6600


Training:  14%|█▍        | 280/2000 [01:19<19:52,  1.44it/s]

New best reward: 111.0942


Training:  15%|█▍        | 297/2000 [01:22<04:33,  6.22it/s]

Episode 300, Steps: 26020, Reward: 77.77, Success: 0.66, Length: 90.72


Training:  15%|█▌        | 300/2000 [01:28<31:00,  1.09s/it]

No improvement for 1 evaluations


Training:  16%|█▌        | 320/2000 [01:36<26:27,  1.06it/s]

New best reward: 129.1350


Training:  17%|█▋        | 341/2000 [01:45<21:34,  1.28it/s]

No improvement for 1 evaluations


Training:  17%|█▋        | 349/2000 [01:47<08:12,  3.35it/s]

Episode 350, Steps: 31561, Reward: 86.48, Success: 0.72, Length: 110.82


Training:  18%|█▊        | 360/2000 [01:56<29:08,  1.07s/it]

No improvement for 2 evaluations


Training:  19%|█▉        | 381/2000 [02:05<22:45,  1.19it/s]

New best reward: 134.9475


Training:  20%|█▉        | 399/2000 [02:08<05:47,  4.61it/s]

Episode 400, Steps: 36592, Reward: 83.51, Success: 0.70, Length: 100.62


Training:  20%|██        | 400/2000 [02:18<57:03,  2.14s/it]

New best reward: 138.3319


Training:  21%|██        | 420/2000 [02:29<31:10,  1.18s/it]

New best reward: 156.6859


Training:  22%|██▏       | 440/2000 [02:39<35:06,  1.35s/it]

No improvement for 1 evaluations


Training:  22%|██▏       | 448/2000 [02:44<21:32,  1.20it/s]

Episode 450, Steps: 43817, Reward: 98.04, Success: 0.52, Length: 144.50


Training:  23%|██▎       | 460/2000 [02:54<28:47,  1.12s/it]

No improvement for 2 evaluations


Training:  24%|██▍       | 480/2000 [03:04<37:40,  1.49s/it]

No improvement for 3 evaluations


Training:  25%|██▍       | 499/2000 [03:11<12:09,  2.06it/s]

Episode 500, Steps: 50627, Reward: 103.11, Success: 0.66, Length: 136.20


Training:  25%|██▌       | 500/2000 [03:21<1:24:54,  3.40s/it]

No improvement for 4 evaluations


Training:  26%|██▌       | 520/2000 [03:37<50:50,  2.06s/it]  

New best reward: 160.8415


Training:  27%|██▋       | 540/2000 [03:45<20:56,  1.16it/s]

No improvement for 1 evaluations


Training:  27%|██▋       | 548/2000 [03:46<06:34,  3.68it/s]

Episode 550, Steps: 58796, Reward: 108.14, Success: 0.48, Length: 163.38


Training:  28%|██▊       | 560/2000 [03:59<33:48,  1.41s/it]

No improvement for 2 evaluations


Training:  29%|██▉       | 580/2000 [04:06<23:26,  1.01it/s]

No improvement for 3 evaluations


Training:  30%|██▉       | 599/2000 [04:12<08:02,  2.90it/s]

Episode 600, Steps: 65116, Reward: 93.36, Success: 0.56, Length: 126.40


Training:  30%|███       | 601/2000 [04:21<40:25,  1.73s/it]

No improvement for 4 evaluations


Training:  31%|███       | 620/2000 [04:28<21:39,  1.06it/s]

No improvement for 5 evaluations


Training:  32%|███▏      | 640/2000 [04:38<19:56,  1.14it/s]

No improvement for 6 evaluations


Training:  32%|███▏      | 649/2000 [04:42<11:49,  1.90it/s]

Episode 650, Steps: 71864, Reward: 94.59, Success: 0.48, Length: 134.96


Training:  33%|███▎      | 660/2000 [04:57<37:29,  1.68s/it]

No improvement for 7 evaluations


Training:  34%|███▍      | 680/2000 [05:10<47:00,  2.14s/it]

New best reward: 173.8065


Training:  35%|███▍      | 698/2000 [05:17<06:37,  3.27it/s]

Episode 700, Steps: 79820, Reward: 105.31, Success: 0.52, Length: 159.12


Training:  35%|███▌      | 700/2000 [05:30<51:25,  2.37s/it]

New best reward: 177.3420


Training:  36%|███▌      | 720/2000 [05:43<44:31,  2.09s/it]

No improvement for 1 evaluations


Training:  37%|███▋      | 740/2000 [05:57<22:35,  1.08s/it]

No improvement for 2 evaluations


Training:  37%|███▋      | 749/2000 [06:00<07:55,  2.63it/s]

Episode 750, Steps: 89818, Reward: 127.98, Success: 0.46, Length: 199.96


Training:  38%|███▊      | 761/2000 [06:16<25:57,  1.26s/it]

No improvement for 3 evaluations


Training:  39%|███▉      | 780/2000 [06:26<19:12,  1.06it/s]

No improvement for 4 evaluations


Training:  40%|███▉      | 798/2000 [06:31<05:32,  3.61it/s]

Episode 800, Steps: 97210, Reward: 113.08, Success: 0.64, Length: 147.84


Training:  40%|████      | 800/2000 [06:43<46:26,  2.32s/it]

No improvement for 5 evaluations


Training:  41%|████      | 820/2000 [06:52<17:28,  1.12it/s]

No improvement for 6 evaluations


Training:  42%|████▏     | 840/2000 [07:07<38:51,  2.01s/it]

No improvement for 7 evaluations


Training:  42%|████▏     | 849/2000 [07:11<08:35,  2.23it/s]

Episode 850, Steps: 105648, Reward: 109.60, Success: 0.38, Length: 168.76


Training:  43%|████▎     | 860/2000 [07:28<30:14,  1.59s/it]

No improvement for 8 evaluations


Training:  44%|████▍     | 880/2000 [07:42<37:26,  2.01s/it]

No improvement for 9 evaluations


Training:  45%|████▍     | 899/2000 [07:51<09:37,  1.91it/s]

Episode 900, Steps: 115517, Reward: 128.14, Success: 0.44, Length: 197.38


Training:  45%|████▌     | 900/2000 [08:04<1:16:40,  4.18s/it]

No improvement for 10 evaluations


Training:  46%|████▌     | 920/2000 [08:18<35:14,  1.96s/it]  

No improvement for 11 evaluations


Training:  47%|████▋     | 940/2000 [08:29<19:56,  1.13s/it]

No improvement for 12 evaluations


Training:  47%|████▋     | 949/2000 [08:31<04:10,  4.19it/s]

Episode 950, Steps: 124939, Reward: 123.01, Success: 0.40, Length: 188.44


Training:  48%|████▊     | 960/2000 [08:48<22:03,  1.27s/it]

No improvement for 13 evaluations


Training:  49%|████▉     | 981/2000 [09:00<13:18,  1.28it/s]

No improvement for 14 evaluations


Training:  50%|████▉     | 997/2000 [09:05<06:42,  2.49it/s]

Episode 1000, Steps: 133700, Reward: 121.15, Success: 0.56, Length: 175.22


Training:  50%|█████     | 1000/2000 [09:17<31:09,  1.87s/it]

No improvement for 15 evaluations


Training:  51%|█████     | 1020/2000 [09:26<25:04,  1.54s/it]

No improvement for 16 evaluations


Training:  52%|█████▏    | 1040/2000 [09:39<25:02,  1.57s/it]

No improvement for 17 evaluations


Training:  52%|█████▏    | 1048/2000 [09:43<09:39,  1.64it/s]

Episode 1050, Steps: 141567, Reward: 100.89, Success: 0.44, Length: 157.34


Training:  53%|█████▎    | 1060/2000 [10:00<32:30,  2.07s/it]

No improvement for 18 evaluations


Training:  54%|█████▍    | 1080/2000 [10:12<16:20,  1.07s/it]

No improvement for 19 evaluations


Training:  55%|█████▍    | 1099/2000 [10:19<04:48,  3.12it/s]

Episode 1100, Steps: 149529, Reward: 100.32, Success: 0.48, Length: 159.24


Training:  55%|█████▌    | 1103/2000 [10:33<28:07,  1.88s/it]

No improvement for 20 evaluations


Training:  56%|█████▌    | 1120/2000 [10:39<09:31,  1.54it/s]

No improvement for 21 evaluations


Training:  57%|█████▋    | 1141/2000 [10:54<15:33,  1.09s/it]

No improvement for 22 evaluations


Training:  57%|█████▋    | 1148/2000 [10:56<05:38,  2.52it/s]

Episode 1150, Steps: 158459, Reward: 115.96, Success: 0.48, Length: 178.60


Training:  58%|█████▊    | 1160/2000 [11:14<19:28,  1.39s/it]

No improvement for 23 evaluations


Training:  59%|█████▉    | 1180/2000 [11:26<14:38,  1.07s/it]

No improvement for 24 evaluations


Training:  60%|█████▉    | 1199/2000 [11:29<02:28,  5.39it/s]

Episode 1200, Steps: 165530, Reward: 103.36, Success: 0.50, Length: 141.42


Training:  60%|█████▉    | 1199/2000 [11:44<07:50,  1.70it/s]

No improvement for 25 evaluations
Early stopping after 1200 episodes due to no improvement






Final Evaluation Results:
episode_reward: 155.0599
episode_length: 212.7400
success_rate: 0.3400
collision_rate: 0.3400
completion_time: 71.5294
Performing detailed evaluation of Phase 1 agent...
Pure PPO Success Rate: 0.3300

Consider running Phase 1 again with different hyperparameters.
Proceeding anyway, but results may be suboptimal.

PHASE 2: Training with belief module

=== Phase 2: Training with Belief Module ===
Phase 2A: Training belief module...


  state=torch.tensor(ee_pos, dtype=torch.float32, device=self.device),
  goals=[torch.tensor(g, dtype=torch.float32, device=self.device) for g in goals],
Training:   1%|          | 2/300 [00:01<03:52,  1.28it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x2 and 4x32)