In [1]:
import pygame
import numpy as np
import torch
import math
from collections import deque
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import os
from datetime import datetime

class SharedControlEnv:
    """Shared-control point-navigation environment for RL training.

    A dot moves inside a ``window_size`` rectangle. On every step the agent
    supplies ``gamma`` in [0, 1], which blends the direction commanded by the
    human with the straight-line direction to the current target.

    pygame is only touched when ``render_mode == 'human'``; with any other
    render mode the environment runs headless.

    NOTE(review): ``reset`` generates three targets, but nothing in this class
    advances ``current_target_idx``, so an episode ends as soon as the first
    target is reached.
    """

    def __init__(self, window_size=(1200, 800), render_mode=None):
        """Create the environment and reset it to an initial state.

        Args:
            window_size: ``(width, height)`` of the play field in pixels.
            render_mode: ``'human'`` to open a pygame window, anything else
                for headless operation.
        """
        self.window_size = window_size
        self.render_mode = render_mode

        # Constants from original environment
        self.max_speed = 3
        self.dot_radius = 30
        self.target_radius = 10
        self.goal_detection_radius = self.dot_radius + self.target_radius

        # Rendering resources. ``_pygame_ready`` records whether this
        # instance initialised pygame so ``close`` can always shut it down,
        # even if rendering was disabled later because of an error.
        self._pygame_ready = False
        self._font = None

        # Initialize pygame if rendering
        if self.render_mode == 'human':
            try:
                pygame.init()
                self._pygame_ready = True
                self.screen = pygame.display.set_mode(window_size)
                self.clock = pygame.time.Clock()
            except pygame.error as e:
                print(f"Failed to initialize pygame: {e}")
                self.render_mode = None

        # State space components
        self.num_lidar_rays = 16  # Number of rays for simplified lidar
        self.state_history_len = 5  # Number of past states to keep

        # Environment state
        self.dot_pos = None
        self.targets = None
        self.current_target_idx = None
        self.reached_goal = False
        self.state_history = deque(maxlen=self.state_history_len)

        # Define observation and action spaces
        self.observation_dim = (self.num_lidar_rays +  # Lidar readings
                                2 +  # dot position
                                2 +  # current target position
                                2 +  # human input
                                1)   # previous gamma
        self.action_dim = 1  # gamma value

        self.reset()

    def get_state(self, human_input):
        """Convert environment state to an RL observation.

        Args:
            human_input: Length-2 list, tuple or array ``[dx, dy]``.

        Returns:
            1-D ``float32`` array of length ``observation_dim`` laid out as
            lidar readings, normalised dot position, normalised target
            position, normalised human input, and the current gamma.

        Raises:
            ValueError: If ``human_input`` is not a length-2 sequence.
        """
        # Validate human input (tuples are as valid as lists here)
        if (not isinstance(human_input, (list, tuple, np.ndarray))
                or len(human_input) != 2):
            raise ValueError("Human input must be a list or array of length 2")

        # Simulate lidar readings
        lidar_readings = self._get_lidar_readings()

        width, height = self.window_size[0], self.window_size[1]

        # Normalize positions to [0,1]
        norm_dot_pos = [self.dot_pos[0] / width, self.dot_pos[1] / height]

        curr_target = self.targets[self.current_target_idx]
        norm_target_pos = [curr_target[0] / width, curr_target[1] / height]

        # Normalize and clip human input
        norm_human_input = [
            np.clip(human_input[0] / self.max_speed, -1, 1),
            np.clip(human_input[1] / self.max_speed, -1, 1)
        ]

        # Combine all state components
        state = np.concatenate([
            lidar_readings,
            norm_dot_pos,
            norm_target_pos,
            norm_human_input,
            [self.current_gamma]
        ])

        return state.astype(np.float32)

    def _get_lidar_readings(self):
        """Return wall distances along ``num_lidar_rays`` evenly spaced rays.

        Distances are divided by the longer window side. Rays are bounded by
        the window diagonal, so a reading can be slightly above 1.
        """
        scale = max(self.window_size)
        readings = [
            self._cast_ray(2 * math.pi * i / self.num_lidar_rays) / scale
            for i in range(self.num_lidar_rays)
        ]
        return np.array(readings, dtype=np.float32)

    def _cast_ray(self, angle, max_dist=None):
        """Cast a ray from the dot and return the distance to the nearest wall.

        Args:
            angle: Ray direction in radians.
            max_dist: Upper bound on the returned distance; defaults to the
                window diagonal.
        """
        if max_dist is None:
            max_dist = math.hypot(*self.window_size)

        dir_x = math.cos(angle)
        dir_y = math.sin(angle)

        # Start from dot position
        start_x, start_y = self.dot_pos

        wall_dist = max_dist

        # A ray can only hit the vertical wall it is heading towards.
        # Left wall
        if dir_x < 0:
            t = -start_x / dir_x
            y = start_y + t * dir_y
            if 0 <= y <= self.window_size[1]:
                wall_dist = min(wall_dist, abs(t))

        # Right wall
        elif dir_x > 0:
            t = (self.window_size[0] - start_x) / dir_x
            y = start_y + t * dir_y
            if 0 <= y <= self.window_size[1]:
                wall_dist = min(wall_dist, abs(t))

        # Likewise for the horizontal walls.
        # Top wall
        if dir_y < 0:
            t = -start_y / dir_y
            x = start_x + t * dir_x
            if 0 <= x <= self.window_size[0]:
                wall_dist = min(wall_dist, abs(t))

        # Bottom wall
        elif dir_y > 0:
            t = (self.window_size[1] - start_y) / dir_y
            x = start_x + t * dir_x
            if 0 <= x <= self.window_size[0]:
                wall_dist = min(wall_dist, abs(t))

        return wall_dist

    def step(self, action, human_input):
        """Execute one environment step.

        Args:
            action: gamma value; clipped to [0, 1].
            human_input: ``[dx, dy]`` from keyboard/joystick.

        Returns:
            ``(state, reward, done, info)`` where ``done`` is True once the
            dot is within ``goal_detection_radius`` of the current target.

        Raises:
            ValueError: If ``human_input`` is not a 2-vector.
        """
        # Validate inputs
        action = np.clip(float(action), 0.0, 1.0)
        human_input = np.array(human_input, dtype=np.float32)
        if human_input.shape != (2,):
            raise ValueError("Human input must be a 2D vector")

        self.current_gamma = action

        # Unit direction commanded by the human (zero vector if no input)
        h_dx, h_dy = human_input
        h_mag = math.hypot(h_dx, h_dy)
        h_dir = [h_dx / h_mag, h_dy / h_mag] if h_mag > 0 else [0, 0]

        # Unit direction straight to the current target
        target_pos = self.targets[self.current_target_idx]
        w_dx = target_pos[0] - self.dot_pos[0]
        w_dy = target_pos[1] - self.dot_pos[1]
        w_mag = math.hypot(w_dx, w_dy)
        w_dir = [w_dx / w_mag, w_dy / w_mag] if w_mag > 0 else [0, 0]

        # The human input magnitude sets the speed, capped at max_speed
        step_size = self.max_speed * min(max(h_mag / self.max_speed, 0), 1)

        # Calculate movement components
        w_move = [
            self.current_gamma * w_dir[0] * step_size,
            self.current_gamma * w_dir[1] * step_size
        ]

        h_move = [
            (1 - self.current_gamma) * h_dir[0] * step_size,
            (1 - self.current_gamma) * h_dir[1] * step_size
        ]

        # Update position
        new_pos = [
            self.dot_pos[0] + w_move[0] + h_move[0],
            self.dot_pos[1] + w_move[1] + h_move[1]
        ]

        # Clip to window boundaries
        self.dot_pos = [
            max(0, min(self.window_size[0], new_pos[0])),
            max(0, min(self.window_size[1], new_pos[1]))
        ]

        # Check if goal reached
        dist_to_goal = math.hypot(
            self.dot_pos[0] - target_pos[0],
            self.dot_pos[1] - target_pos[1]
        )
        self.reached_goal = dist_to_goal < self.goal_detection_radius

        # The reward must be computed before the new state is appended:
        # the smoothness term reads the previous gamma from state_history.
        reward = self._compute_reward(dist_to_goal)

        # Get new state
        state = self.get_state(human_input)
        self.state_history.append(state)

        # Check if episode is done
        done = self.reached_goal

        # Additional info
        info = {
            'distance_to_goal': dist_to_goal,
            'reached_goal': self.reached_goal,
            'gamma': self.current_gamma
        }

        return state, reward, done, info

    def _compute_reward(self, dist_to_goal):
        """Compute the weighted safety / smoothness / goal reward.

        The safety term is 0 at or beyond 0.8, falls linearly to -1 at 0.5
        and stays at -1 below that, so it is continuous and never leaves
        [-1, 0].

        NOTE(review): the 0.8 / 0.5 thresholds are applied to the pixel
        distance to the *goal*, while episodes end at
        ``goal_detection_radius`` pixels; confirm the intended quantity and
        units for the safety term.
        """
        safe_dist = 0.8
        danger_dist = 0.5

        # Safety reward (normalized to [-1, 0])
        if dist_to_goal >= safe_dist:
            rsafe = 0.0
        elif dist_to_goal > danger_dist:
            rsafe = -(safe_dist - dist_to_goal) / (safe_dist - danger_dist)
        else:
            rsafe = -1.0

        # Smoothness reward (normalized to [-1, 0]); the last element of the
        # most recent stored state is the gamma applied on the previous step.
        if len(self.state_history) > 1:
            prev_gamma = self.state_history[-1][-1]
            rsm = -abs(self.current_gamma - prev_gamma)
        else:
            rsm = 0

        # Goal reward: bonus on arrival, small per-step penalty otherwise
        rgoal = 1.0 if self.reached_goal else -0.1

        # Weight the components
        return 0.4 * rsafe + 0.2 * rsm + 0.4 * rgoal

    def reset(self):
        """Reset the environment and return the initial observation."""
        # Reset dot position to center
        self.dot_pos = [self.window_size[0] // 2, self.window_size[1] // 2]

        # Generate new targets, kept away from the edges
        margin = 100
        self.targets = [
            [
                random.randint(margin, self.window_size[0] - margin),
                random.randint(margin, self.window_size[1] - margin)
            ]
            for _ in range(3)  # 3 targets like original environment
        ]

        self.current_target_idx = 0
        self.reached_goal = False
        self.current_gamma = 0.2  # Initial gamma value

        # Clear history
        self.state_history.clear()

        # Get initial state
        state = self.get_state([0, 0])  # No initial human input
        self.state_history.append(state)

        return state

    def render(self):
        """Render the environment if render_mode is 'human'.

        On a pygame error the message is printed and rendering is disabled
        for the rest of the run.
        """
        if self.render_mode != 'human':
            return

        try:
            # Service the event queue once per frame; without this the
            # operating system may flag the window as unresponsive.
            pygame.event.pump()

            self.screen.fill((255, 255, 255))

            # Draw targets
            for i, target in enumerate(self.targets):
                color = (255, 255, 0) if i > self.current_target_idx else (200, 200, 200)
                pygame.draw.circle(self.screen, color,
                                   (int(target[0]), int(target[1])),
                                   self.target_radius)

            # Highlight current target
            current_target = self.targets[self.current_target_idx]
            pygame.draw.circle(self.screen, (0, 0, 0),
                               (int(current_target[0]), int(current_target[1])),
                               self.target_radius + 2, 2)

            # Draw dot
            pygame.draw.circle(self.screen, (0, 0, 0),
                               (int(self.dot_pos[0]), int(self.dot_pos[1])),
                               self.dot_radius, 2)

            # Draw gamma value (the font is created once and reused)
            if self._font is None:
                self._font = pygame.font.Font(None, 36)
            gamma_text = self._font.render(f'γ: {self.current_gamma:.2f}', True, (0, 0, 0))
            self.screen.blit(gamma_text, (10, 10))

            pygame.display.flip()
            self.clock.tick(60)

        except pygame.error as e:
            print(f"Render error: {e}")
            self.render_mode = None

    def close(self):
        """Shut pygame down if this environment initialised it."""
        if self._pygame_ready:
            pygame.quit()
            self._pygame_ready = False
            self._font = None

class ActorCritic(nn.Module):
    """Actor-critic network with a shared two-layer ReLU trunk.

    The actor head yields the mean of a Normal distribution (squashed by a
    sigmoid) together with a state-independent learned log standard
    deviation; the critic head yields a scalar value per state.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()

        # Trunk shared by both heads
        trunk_layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.feature_net = nn.Sequential(*trunk_layers)

        # Policy head: per-state mean, global log-std parameter
        self.actor_mean = nn.Linear(hidden_dim, action_dim)
        self.actor_log_std = nn.Parameter(torch.zeros(1, action_dim))

        # Value head
        self.critic = nn.Linear(hidden_dim, 1)

        # Re-initialise every linear layer
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Orthogonal weights (gain sqrt(2)) and zero bias for linear layers."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
        if module.bias is not None:
            module.bias.data.zero_()

    def forward(self, state):
        """Return ``(action_mean, action_std, value)`` for ``state``."""
        hidden = self.feature_net(state)

        # The sigmoid keeps the mean of gamma inside (0, 1)
        mean = torch.sigmoid(self.actor_mean(hidden))
        std = self.actor_log_std.exp()

        state_value = self.critic(hidden)
        return mean, std, state_value

    def get_action_distribution(self, state):
        """Return the Normal action distribution the policy assigns to ``state``."""
        mean, std, _unused_value = self(state)
        return torch.distributions.Normal(mean, std)

class PPOSharedControl:
    """PPO agent that learns the blending factor gamma for shared control."""

    def __init__(self, state_dim, action_dim, hidden_dim=128, lr=3e-4, gamma=0.99,
                 epsilon=0.2, c1=1.0, c2=0.01):
        """
        Initialize PPO agent for shared control

        Args:
            state_dim: Dimension of state space
            action_dim: Dimension of action space
            hidden_dim: Hidden layer dimension
            lr: Learning rate
            gamma: Discount factor
            epsilon: PPO clipping parameter
            c1: Value function loss coefficient
            c2: Entropy bonus coefficient
        """
        self.actor_critic = ActorCritic(state_dim, action_dim, hidden_dim)
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=lr)

        self.gamma = gamma
        self.epsilon = epsilon
        self.c1 = c1
        self.c2 = c2

    def get_action(self, state):
        """Sample an action in [0, 1] from the current policy.

        Args:
            state: State tensor accepted by the actor-critic network.

        Returns:
            ``(action, log_prob)``. The log-probability is evaluated at the
            clamped action that is returned, so it matches what ``update``
            recomputes for the same stored action (the PPO ratio starts at 1).
        """
        with torch.no_grad():
            dist = self.actor_critic.get_action_distribution(state)

            # Clip action to [0, 1] since it represents gamma
            action = torch.clamp(dist.sample(), 0.0, 1.0)
            log_prob = dist.log_prob(action)

        return action, log_prob

    def get_value(self, state):
        """Get value estimate for state (no gradient tracking)."""
        with torch.no_grad():
            _, _, value = self.actor_critic(state)
        return value

    def update(self, states, actions, old_log_probs, returns, advantages,
               epochs=10, batch_size=64):
        """Update policy and value function using clipped PPO.

        Args:
            states: Array/tensor of shape ``(N, state_dim)``.
            actions: ``N`` actions, shape ``(N,)`` or ``(N, action_dim)``.
            old_log_probs: Log-probabilities recorded at collection time,
                shape ``(N,)`` or ``(N, action_dim)``.
            returns: ``N`` value targets.
            advantages: ``N`` advantage estimates.
            epochs: Number of passes over the data.
            batch_size: Mini-batch size.
        """
        states = torch.as_tensor(states, dtype=torch.float32)
        num_samples = states.shape[0]
        if num_samples == 0:
            return

        # Bring everything to explicit, matching shapes. Leaving actions as
        # (N,) while the policy emits (N, action_dim) would silently
        # broadcast log-probs and surrogate terms to (N, N) matrices.
        actions = torch.as_tensor(actions, dtype=torch.float32).reshape(num_samples, -1)
        old_log_probs = (torch.as_tensor(old_log_probs, dtype=torch.float32)
                         .reshape(num_samples, -1).sum(dim=-1))
        returns = torch.as_tensor(returns, dtype=torch.float32).detach().reshape(num_samples)
        advantages = torch.as_tensor(advantages, dtype=torch.float32).detach().reshape(num_samples)

        # Normalize advantages; the std of a single sample is NaN, so a
        # one-sample batch is left unnormalised.
        if num_samples > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # PPO update over multiple epochs
        for _ in range(epochs):
            # Generate random mini-batches
            indices = torch.randperm(num_samples)
            for start_idx in range(0, num_samples, batch_size):
                idx = indices[start_idx:start_idx + batch_size]
                batch_advantages = advantages[idx]

                # One forward pass provides both the policy and the value
                action_mean, action_std, values = self.actor_critic(states[idx])
                dist = torch.distributions.Normal(action_mean, action_std)

                # Joint log-probability per sample, shape (B,)
                new_log_probs = dist.log_prob(actions[idx]).sum(dim=-1)
                ratio = torch.exp(new_log_probs - old_log_probs[idx])

                # PPO policy loss
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1.0 - self.epsilon, 1.0 + self.epsilon) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                # Value function loss; squeeze only the trailing value axis
                # so a batch of one keeps shape (1,)
                value_loss = F.mse_loss(values.squeeze(-1), returns[idx])

                # Entropy bonus for exploration
                entropy = dist.entropy().sum(dim=-1).mean()

                # Total loss
                total_loss = policy_loss + self.c1 * value_loss - self.c2 * entropy

                # Update network
                self.optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 0.5)
                self.optimizer.step()

    def save(self, path):
        """Save model and optimizer state to ``path``.

        The parent directory is created when ``path`` has one; a bare
        filename is written to the current working directory.
        """
        directory = os.path.dirname(path)
        if directory:  # os.makedirs('') raises FileNotFoundError
            os.makedirs(directory, exist_ok=True)
        torch.save({
            'actor_critic_state_dict': self.actor_critic.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load(self, path):
        """Load model and optimizer state previously written by ``save``."""
        # Map tensors onto the device the network currently lives on
        device = next(self.actor_critic.parameters()).device
        checkpoint = torch.load(path, map_location=device)
        self.actor_critic.load_state_dict(checkpoint['actor_critic_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

def train_ppo(env, episodes=1000, steps_per_episode=1000, checkpoint_freq=100,
              solved_threshold=800):
    """
    Train PPO agent on shared control environment

    Args:
        env: SharedControlEnv instance
        episodes: Number of training episodes
        steps_per_episode: Maximum steps per episode
        checkpoint_freq: Save checkpoint every n episodes
        solved_threshold: Training stops early once the average reward over
            the last 100 episodes exceeds this value. The default keeps the
            historical value of 800; tune it to the reward scale of ``env``.

    Returns:
        ``(agent, episode_rewards)``: the trained agent and the list of
        per-episode total rewards.
    """
    state_dim = env.observation_dim
    action_dim = env.action_dim

    # Initialize PPO agent
    agent = PPOSharedControl(state_dim, action_dim)

    # Create a timestamped directory for checkpoints
    checkpoint_dir = f'checkpoints_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Training metrics
    best_reward = float('-inf')
    episode_rewards = []

    # Training loop
    for episode in range(episodes):
        state = env.reset()
        episode_reward = 0

        # Storage for episode data
        states = []
        actions = []
        rewards = []
        log_probs = []

        for step in range(steps_per_episode):
            # Simulate human input (can be replaced with real human data)
            human_input = simulate_human_input(env)

            # Get action from policy
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action, log_prob = agent.get_action(state_tensor)

            # Step environment
            next_state, reward, done, info = env.step(action.item(), human_input)

            # Store transition
            states.append(state)
            actions.append(action.squeeze().cpu().numpy())
            rewards.append(reward)
            log_probs.append(log_prob.squeeze().cpu().numpy())

            episode_reward += reward
            state = next_state

            if env.render_mode == 'human':
                env.render()

            if done:
                break

        # Store episode reward
        episode_rewards.append(episode_reward)

        # Convert lists to arrays for batch processing
        states = np.array(states, dtype=np.float32)
        actions = np.array(actions, dtype=np.float32)
        rewards = np.array(rewards, dtype=np.float32)
        log_probs = np.array(log_probs, dtype=np.float32)

        # Calculate returns and advantages
        returns = compute_returns(rewards, agent.gamma)
        advantages = compute_advantages(returns, states, agent)

        # Update policy
        agent.update(states, actions, log_probs, returns, advantages)

        # Save checkpoint if best performance
        if episode_reward > best_reward:
            best_reward = episode_reward
            agent.save(os.path.join(checkpoint_dir, 'best_model.pth'))

        # Regular checkpoint saving
        if (episode + 1) % checkpoint_freq == 0:
            agent.save(os.path.join(checkpoint_dir, f'checkpoint_{episode+1}.pth'))

        # Print progress; slicing already copes with fewer than 100 episodes
        avg_reward = np.mean(episode_rewards[-100:])
        print(f"Episode {episode+1}/{episodes}, Reward: {episode_reward:.2f}, Avg Reward (100): {avg_reward:.2f}")

        # Early stopping if solved (needs a full 100-episode window)
        if avg_reward > solved_threshold and len(episode_rewards) >= 100:
            print("Environment solved!")
            agent.save(os.path.join(checkpoint_dir, 'solved_model.pth'))
            break

    return agent, episode_rewards

def simulate_human_input(env):
    """Return a noisy, full-speed command pointing at the current target.

    Gaussian noise (std 0.2) is added to each component of the dot-to-target
    offset before the vector is rescaled to ``env.max_speed``. A zero-length
    vector is returned unchanged.
    """
    goal = env.targets[env.current_target_idx]

    # Offset to the target, perturbed to mimic imperfect human input
    dx = goal[0] - env.dot_pos[0]
    dy = goal[1] - env.dot_pos[1]
    dx += np.random.normal(0, 0.2)
    dy += np.random.normal(0, 0.2)

    # Rescale to max_speed unless the vector has no direction
    length = math.hypot(dx, dy)
    if length > 0:
        dx, dy = dx / length * env.max_speed, dy / length * env.max_speed

    return np.array([dx, dy], dtype=np.float32)

def compute_returns(rewards, gamma):
    """Compute normalised discounted returns for one episode.

    Args:
        rewards: Sequence of per-step rewards in time order.
        gamma: Discount factor.

    Returns:
        1-D ``float32`` tensor with one entry per reward. With two or more
        entries the returns are standardised to zero mean and unit std;
        shorter sequences are returned unnormalised because their standard
        deviation is undefined (it would turn the result into NaN).
    """
    discounted = []
    running = 0
    # Accumulate back-to-front, then flip once (avoids O(n^2) front inserts)
    for r in reversed(rewards):
        running = r + gamma * running
        discounted.append(running)
    discounted.reverse()

    returns = torch.tensor(discounted, dtype=torch.float32)
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-5)
    return returns

def compute_advantages(returns, states, agent):
    """Return ``returns`` minus the agent's value estimates for ``states``."""
    state_batch = torch.FloatTensor(states)
    # Detach the value estimates and drop their singleton axes so they
    # line up with the 1-D returns
    baseline = agent.get_value(state_batch).detach().squeeze()
    return returns - baseline

if __name__ == "__main__":
    import traceback

    # Create environment
    env = SharedControlEnv(render_mode='human')

    try:
        # Train agent
        trained_agent, rewards_history = train_ppo(env)

        # Final save
        trained_agent.save('final_model.pth')

        # Save training history
        np.save('training_rewards.npy', np.array(rewards_history))

    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    except Exception as e:
        # Top-level boundary: report the failure together with its
        # traceback so the failing call can be located.
        print(f"Error during training: {e}")
        traceback.print_exc()
    finally:
        # Always release the environment's resources
        env.close()

pygame 2.6.0 (SDL 2.28.4, Python 3.10.9)
Hello from the pygame community. https://www.pygame.org/contribute.html
Episode 1/1000, Reward: -19.29, Avg Reward (100): -19.29
Episode 2/1000, Reward: -12.90, Avg Reward (100): -16.10
Episode 3/1000, Reward: -19.15, Avg Reward (100): -17.11
Episode 4/1000, Reward: -16.77, Avg Reward (100): -17.03
Episode 5/1000, Reward: -10.19, Avg Reward (100): -15.66

Training interrupted by user
