In [None]:
import pygame
import numpy as np
import torch
import math
from collections import deque
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import os
from datetime import datetime

##############################
# 1. Select device (GPU or CPU)
##############################
# Models and tensors in this file are placed on this device: CUDA whenever
# PyTorch reports a usable GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

class SharedControlEnv:
    """Simplified shared-control environment without LIDAR or obstacles, single goal.

    A dot starts at the window centre and must reach a randomly placed goal.
    Each step, the dot's motion blends a direct-to-goal "autopilot" direction
    (weight ``gamma``) with the human's input direction (weight ``1 - gamma``).
    The step length is the human input's magnitude capped at ``max_speed``, so
    a zero human input produces no motion regardless of ``gamma``. The agent's
    action is ``gamma``.

    Args:
        window_size: (width, height) of the arena in pixels.
        render_mode: ``'human'`` to open a pygame window; anything else runs
            headless.
    """

    def __init__(self, window_size=(1200, 800), render_mode=None):
        self.window_size = window_size
        self.render_mode = render_mode

        # Environment parameters (pixels, and pixels per step for max_speed)
        self.max_speed = 3
        self.dot_radius = 30
        self.target_radius = 10
        # The goal counts as reached once the two circles overlap.
        self.goal_detection_radius = self.dot_radius + self.target_radius

        # Rendering state: the font is created on first render and reused;
        # _pygame_started records whether pygame.init() must be undone in close().
        self._font = None
        self._pygame_started = False

        # Initialize pygame if rendering
        if self.render_mode == 'human':
            try:
                pygame.init()
                self._pygame_started = True
                self.screen = pygame.display.set_mode(window_size)
                self.clock = pygame.time.Clock()
            except pygame.error as e:
                print(f"Failed to initialize pygame: {e}")
                # Release whatever pygame managed to initialise and fall back
                # to headless mode (close() also clears render_mode).
                self.close()

        # Recent states, appended by reset() and step(); not read elsewhere here.
        self.state_history_len = 5
        self.state_history = deque(maxlen=self.state_history_len)

        # The state is: (dot_x, dot_y, goal_x, goal_y, human_dx, human_dy, gamma)
        self.observation_dim = 7

        # One-dimensional action: gamma in [0, 1]
        self.action_dim = 1

        # Environment state
        self.dot_pos = None
        self.goal_pos = None
        self.reached_goal = False
        self.current_gamma = 0.2

        self.reset()

    def get_state(self, human_input):
        """
        Construct the state vector:
        [dot_x, dot_y, goal_x, goal_y, h_in_x, h_in_y, gamma].

        Positions are divided by the window size (so lie in [0, 1] while the
        dot stays on screen). Human input is divided by max_speed and clipped
        to [-1, 1]. Returns a float32 NumPy array of length 7.
        """
        norm_dot_pos = [
            self.dot_pos[0] / self.window_size[0],
            self.dot_pos[1] / self.window_size[1]
        ]
        norm_goal_pos = [
            self.goal_pos[0] / self.window_size[0],
            self.goal_pos[1] / self.window_size[1]
        ]

        norm_human_input = [
            np.clip(human_input[0] / self.max_speed, -1, 1),
            np.clip(human_input[1] / self.max_speed, -1, 1)
        ]

        state = np.array([
            norm_dot_pos[0],
            norm_dot_pos[1],
            norm_goal_pos[0],
            norm_goal_pos[1],
            norm_human_input[0],
            norm_human_input[1],
            self.current_gamma
        ], dtype=np.float32)

        return state

    def step(self, action, human_input):
        """
        Step the environment forward by one.

        Args:
            action: gamma; clipped to [0, 1].
            human_input: [dx, dy] human command (possibly noisy).

        Returns:
            (state, reward, done, info). ``done`` is True only when the goal
            is reached; ``info`` holds 'distance_to_goal', 'reached_goal' and
            'gamma'.
        """
        action = np.clip(float(action), 0.0, 1.0)
        self.current_gamma = action

        # Unit direction of the human input (zero vector if no input)
        h_dx, h_dy = human_input
        h_mag = math.hypot(h_dx, h_dy)
        h_dir = [h_dx / h_mag, h_dy / h_mag] if h_mag > 0 else [0, 0]

        # Unit direction towards the goal (zero vector if already on it)
        w_dx = self.goal_pos[0] - self.dot_pos[0]
        w_dy = self.goal_pos[1] - self.dot_pos[1]
        w_mag = math.hypot(w_dx, w_dy)
        w_dir = [w_dx / w_mag, w_dy / w_mag] if w_mag > 0 else [0, 0]

        # Step length follows the human input magnitude, capped at max_speed.
        step_size = self.max_speed * min(max(h_mag / self.max_speed, 0), 1)

        # Weighted movement: gamma for autopilot, (1 - gamma) for human
        w_move = [
            self.current_gamma * w_dir[0] * step_size,
            self.current_gamma * w_dir[1] * step_size
        ]
        h_move = [
            (1 - self.current_gamma) * h_dir[0] * step_size,
            (1 - self.current_gamma) * h_dir[1] * step_size
        ]

        new_pos = [
            self.dot_pos[0] + w_move[0] + h_move[0],
            self.dot_pos[1] + w_move[1] + h_move[1]
        ]

        # Clip to window boundaries
        self.dot_pos = [
            max(0, min(self.window_size[0], new_pos[0])),
            max(0, min(self.window_size[1], new_pos[1]))
        ]

        dist_to_goal = math.hypot(
            self.dot_pos[0] - self.goal_pos[0],
            self.dot_pos[1] - self.goal_pos[1]
        )
        self.reached_goal = (dist_to_goal < self.goal_detection_radius)

        reward = self._compute_reward(dist_to_goal)

        state = self.get_state(human_input)
        self.state_history.append(state)

        done = self.reached_goal

        info = {
            'distance_to_goal': dist_to_goal,
            'reached_goal': self.reached_goal,
            'gamma': self.current_gamma
        }

        return state, reward, done, info

    def _compute_reward(self, dist_to_goal):
        """
        Simple reward function:
        - +1 if goal is reached
        - -0.01 * dist_to_goal if not reached (dist_to_goal is in pixels, so
          a dot 400 px away costs -4 per step)
        - minus 0.05 * |gamma - 0.5| (small penalty for deviating from 0.5)
        """
        reward = 0.0
        if self.reached_goal:
            reward += 1.0
        else:
            reward -= 0.01 * dist_to_goal
        reward -= 0.05 * abs(self.current_gamma - 0.5)
        return reward

    def reset(self):
        """Reset the environment: dot to centre, new random goal; return the initial state."""
        self.dot_pos = [
            self.window_size[0] // 2,
            self.window_size[1] // 2
        ]
        # Keep the goal at least `margin` pixels away from every edge.
        margin = 100
        self.goal_pos = [
            random.randint(margin, self.window_size[0] - margin),
            random.randint(margin, self.window_size[1] - margin)
        ]
        self.reached_goal = False
        self.current_gamma = 0.2
        self.state_history.clear()

        init_state = self.get_state([0, 0])
        self.state_history.append(init_state)
        return init_state

    def render(self):
        """Draw the current frame if render_mode == 'human'; otherwise do nothing.

        Pending window events are processed every frame. Closing the window
        shuts pygame down and disables rendering, so the caller can keep
        running headless. A pygame error also disables rendering instead of
        propagating.
        """
        if self.render_mode != 'human':
            return
        try:
            # Drain the event queue; without this the window stops updating
            # and is flagged "not responding", and the close button is ignored.
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.close()
                    return

            self.screen.fill((255, 255, 255))

            # Draw goal (filled yellow disc with a black outline)
            pygame.draw.circle(
                self.screen,
                (255, 255, 0),
                (int(self.goal_pos[0]), int(self.goal_pos[1])),
                self.target_radius
            )
            pygame.draw.circle(
                self.screen,
                (0, 0, 0),
                (int(self.goal_pos[0]), int(self.goal_pos[1])),
                self.target_radius + 2, 2
            )

            # Draw dot (black ring)
            pygame.draw.circle(
                self.screen,
                (0, 0, 0),
                (int(self.dot_pos[0]), int(self.dot_pos[1])),
                self.dot_radius, 2
            )

            # Draw gamma; the font is built once and reused across frames.
            if self._font is None:
                self._font = pygame.font.Font(None, 36)
            gamma_text = self._font.render(f'γ: {self.current_gamma:.2f}', True, (0, 0, 0))
            self.screen.blit(gamma_text, (10, 10))

            pygame.display.flip()
            self.clock.tick(60)
        except pygame.error as e:
            print(f"Render error: {e}")
            self.render_mode = None

    def close(self):
        """Shut down pygame if this environment started it, and disable rendering.

        Safe to call more than once.
        """
        if self._pygame_started:
            pygame.quit()
            self._pygame_started = False
        self.render_mode = None
        self._font = None


class ActorCritic(nn.Module):
    """Actor-critic network with a shared trunk, for use with PPO.

    A two-layer ReLU MLP feeds two heads:
      * actor  - a linear layer squashed by a sigmoid (mean in [0, 1]) plus a
                 learnable, state-independent log standard deviation;
      * critic - a linear layer producing one value per state.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()

        # Shared trunk. Attribute names and creation order fix the state_dict
        # keys and the order in which layers consume the RNG at init.
        trunk_layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.feature_net = nn.Sequential(*trunk_layers)

        # Policy head: mean projection and one global log-std per action dim.
        self.actor_mean = nn.Linear(hidden_dim, action_dim)
        self.actor_log_std = nn.Parameter(torch.zeros(1, action_dim))

        # Value head.
        self.critic = nn.Linear(hidden_dim, 1)

        # Re-initialise every Linear layer (orthogonal weights, zero biases).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Give nn.Linear modules orthogonal weights (gain sqrt(2)) and zero biases."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, state):
        """Return (action_mean, action_std, value) for a state or batch of states."""
        hidden = self.feature_net(state)
        mean = self.actor_mean(hidden).sigmoid()  # squashed into [0, 1]
        std = self.actor_log_std.exp()
        return mean, std, self.critic(hidden)

    def get_action_distribution(self, state):
        """Build the Gaussian policy Normal(mean, std) for the given state(s)."""
        mean, std, _unused_value = self(state)
        return torch.distributions.Normal(mean, std)


class PPOSharedControl:
    """Proximal Policy Optimization (clipped objective) for the shared-control environment.

    Holds an ActorCritic network placed on the module-level ``device`` and an
    Adam optimizer over its parameters.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        hidden_dim=128,
        lr=1e-4,        # lower LR
        gamma=0.99,
        epsilon=0.2,
        c1=1.0,
        c2=0.01
    ):
        """
        Args:
            state_dim: length of the observation vector.
            action_dim: length of the action vector.
            hidden_dim: width of the hidden layers.
            lr: Adam learning rate.
            gamma: discount factor (stored for the training loop).
            epsilon: PPO clipping range for the probability ratio.
            c1: weight of the value-function loss.
            c2: weight of the entropy bonus.
        """
        self.actor_critic = ActorCritic(state_dim, action_dim, hidden_dim).to(device)
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=lr)

        self.gamma = gamma
        self.epsilon = epsilon
        self.c1 = c1
        self.c2 = c2

    @staticmethod
    def _as_tensor(data):
        """Convert a NumPy array or tensor to a float32 tensor on ``device``."""
        return torch.as_tensor(data, dtype=torch.float32, device=device)

    def get_action(self, state):
        """Sample an action (gamma) from the current policy.

        Returns:
            (action, log_prob). The action is clamped to [0, 1], and log_prob
            is evaluated at that clamped action.
        """
        state = state.to(device)

        with torch.no_grad():
            dist = self.actor_critic.get_action_distribution(state)
            raw_action = dist.sample()
            action = torch.clamp(raw_action, 0.0, 1.0)
            # Score the *clamped* action: it is what callers store and what
            # update() re-evaluates. Scoring the raw sample made the PPO ratio
            # differ from 1 on the first epoch for every clipped sample.
            # (Treating the clipped point as a density point is an
            # approximation, but old and new log-probs are now consistent.)
            log_prob = dist.log_prob(action)
        return action, log_prob

    def get_value(self, state):
        """Get the critic's value estimate for a state (or batch), without gradients."""
        state = state.to(device)
        with torch.no_grad():
            _, _, value = self.actor_critic(state)
        return value

    def update(self, states, actions, old_log_probs, returns, advantages,
               epochs=10, batch_size=64):
        """Run ``epochs`` passes of minibatch PPO updates over one rollout.

        Args:
            states: (N, state_dim) states.
            actions: actions taken, shape (N,) or (N, action_dim).
            old_log_probs: log-probabilities recorded at collection time,
                same layout as ``actions``.
            returns: (N,) value targets.
            advantages: (N,) advantage estimates; standardised here.
            epochs: number of passes over the rollout.
            batch_size: minibatch size.
        """
        states = self._as_tensor(states)
        n = len(states)
        if n == 0:
            return  # empty rollout: nothing to learn from

        # Give actions an explicit action axis, (N, action_dim), so log_prob
        # lines up with the policy's (N, action_dim) mean. A flat (N,) tensor
        # broadcasts against (B, 1) into a (B, B) matrix, pairing every action
        # with every state.
        actions = self._as_tensor(actions).reshape(n, -1)
        # Joint log-prob per sample: sum over action dimensions -> (N,).
        old_log_probs = self._as_tensor(old_log_probs).reshape(n, -1).sum(dim=-1)
        returns = self._as_tensor(returns).reshape(-1)
        advantages = self._as_tensor(advantages).reshape(-1)

        # Standardise advantages. The std of a single sample is NaN, so with
        # too little data (or no spread) only centre them.
        advantages = advantages - advantages.mean()
        if advantages.numel() > 1:
            adv_std = advantages.std()
            if adv_std > 1e-8:
                advantages = advantages / (adv_std + 1e-8)

        for _ in range(epochs):
            indices = torch.randperm(n)
            for start_idx in range(0, n, batch_size):
                idx = indices[start_idx:start_idx + batch_size]

                batch_states = states[idx]
                batch_actions = actions[idx]
                batch_old_log_probs = old_log_probs[idx]
                batch_returns = returns[idx]
                batch_advantages = advantages[idx]

                # One forward pass feeds both the policy and the value loss.
                action_mean, action_std, values = self.actor_critic(batch_states)
                dist = torch.distributions.Normal(action_mean, action_std)

                new_log_probs = dist.log_prob(batch_actions).sum(dim=-1)
                ratio = torch.exp(new_log_probs - batch_old_log_probs)

                # PPO clipped surrogate objective
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(
                    ratio,
                    1.0 - self.epsilon,
                    1.0 + self.epsilon
                ) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                value_loss = F.mse_loss(values.view(-1), batch_returns)

                # Entropy bonus (summed over action dims, averaged over batch)
                entropy = dist.entropy().sum(dim=-1).mean()

                total_loss = policy_loss + self.c1 * value_loss - self.c2 * entropy

                self.optimizer.zero_grad()
                total_loss.backward()
                # Stronger grad clipping
                torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), max_norm=0.1)
                self.optimizer.step()

    def save(self, path):
        """Save network and optimizer state to ``path``, creating parent directories."""
        directory = os.path.dirname(path)
        # A bare filename has no directory part; os.makedirs('') would raise.
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save({
            'actor_critic_state_dict': self.actor_critic.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load(self, path):
        """Restore network and optimizer state written by save()."""
        checkpoint = torch.load(path, map_location=device)
        self.actor_critic.load_state_dict(checkpoint['actor_critic_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


def simulate_human_input(env, noise_std=10.0):
    """Simulate a noisy human command aimed roughly at the goal.

    Gaussian noise is added to each component of the dot-to-goal displacement
    (in pixels), and the result is rescaled to length ``env.max_speed``. A
    zero vector stays zero. Because the noise is added before normalisation,
    its effect on the direction grows as the dot approaches the goal.

    Args:
        env: object exposing ``goal_pos``, ``dot_pos`` and ``max_speed``
            (e.g. SharedControlEnv).
        noise_std: standard deviation of the per-component noise, in pixels.
            Defaults to 10.0 (the previously hard-coded value).

    Returns:
        float32 NumPy array ``[dx, dy]``.
    """
    dx = env.goal_pos[0] - env.dot_pos[0]
    dy = env.goal_pos[1] - env.dot_pos[1]

    dx += np.random.normal(0, noise_std)
    dy += np.random.normal(0, noise_std)

    mag = math.hypot(dx, dy)
    if mag > 0:
        dx = dx / mag * env.max_speed
        dy = dy / mag * env.max_speed

    return np.array([dx, dy], dtype=np.float32)


def compute_returns(rewards, gamma):
    """
    Compute discounted returns, then safely standardise them.

    G_t = r_t + gamma * G_{t+1}; the result is (G - mean) / (std + 1e-5).
    With fewer than two returns, or a std below 1e-8, the returns are only
    mean-centred. The sample std of a single element is NaN, and dividing by
    it would push NaN into the PPO update.

    Args:
        rewards: sequence of per-step rewards in time order.
        gamma: discount factor.

    Returns:
        1-D float32 torch tensor with one entry per reward.
    """
    discounted = []
    running = 0
    for r in reversed(rewards):
        running = r + gamma * running
        discounted.append(running)
    discounted.reverse()  # accumulated back-to-front; restore time order
    returns = torch.tensor(discounted, dtype=torch.float32)

    if returns.numel() == 0:
        return returns

    centred = returns - returns.mean()
    if returns.numel() < 2:
        return centred

    returns_std = returns.std()
    if returns_std < 1e-8:
        return centred
    return centred / (returns_std + 1e-5)


def compute_advantages(returns, states, agent):
    """
    Return advantages as ``returns - V(states)``, a CPU tensor.

    Args:
        returns: 1-D tensor of returns, one per state.
        states: array-like batch of states, converted with torch.FloatTensor.
        agent: object whose ``get_value`` maps a batch of states to value
            estimates (presumably shape (N, 1); singleton dims are squeezed).

    No normalisation happens here; PPOSharedControl.update standardises the
    advantages itself.
    """
    batch = torch.FloatTensor(states)
    value_estimates = agent.get_value(batch)
    # Drop the graph, flatten the trailing singleton axis, and bring the
    # values back to the CPU so they can be combined with `returns`.
    value_estimates = value_estimates.detach().squeeze().cpu()
    return returns - value_estimates


def _collect_episode(env, agent, steps_per_episode):
    """Roll out one episode with the current policy and simulated human input.

    Returns (states, actions, rewards, log_probs, episode_reward); the four
    per-step sequences are converted to float32 NumPy arrays.
    """
    state = env.reset()
    episode_reward = 0
    states, actions, rewards, log_probs = [], [], [], []

    for _ in range(steps_per_episode):
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        human_input = simulate_human_input(env)
        action, log_prob = agent.get_action(state_tensor)

        # Convert action to a float for the environment step
        next_state, reward, done, _info = env.step(action.item(), human_input)

        states.append(state)
        actions.append(action.squeeze().cpu().numpy())
        rewards.append(reward)
        log_probs.append(log_prob.squeeze().cpu().numpy())

        episode_reward += reward
        state = next_state

        if env.render_mode == 'human':
            env.render()

        if done:
            break

    return (
        np.array(states, dtype=np.float32),
        np.array(actions, dtype=np.float32),
        np.array(rewards, dtype=np.float32),
        np.array(log_probs, dtype=np.float32),
        episode_reward,
    )


def train_ppo(env, episodes=500, steps_per_episode=300, checkpoint_freq=50,
              solved_threshold=200):
    """
    Train a PPO agent on the simplified SharedControlEnv.

    Each episode: roll out up to ``steps_per_episode`` steps, compute
    standardised returns and advantages, and run one PPO update. Checkpoints
    go to a fresh timestamped directory:
      * ``best_model.pth`` whenever an episode's reward beats the best so far;
      * ``checkpoint_<k>.pth`` every ``checkpoint_freq`` episodes;
      * ``solved_model.pth`` when training stops early.

    Args:
        env: a SharedControlEnv (or compatible) instance.
        episodes: maximum number of episodes.
        steps_per_episode: step cap per episode.
        checkpoint_freq: interval, in episodes, between regular checkpoints.
        solved_threshold: stop early once the mean reward of the last 100
            episodes exceeds this (and at least 100 episodes have run).
            NOTE(review): under SharedControlEnv._compute_reward, non-goal
            steps are never positive and the goal step adds at most +1, so
            episode totals stay below about +1. The default of 200 is
            therefore unreachable; pick a value on the reward's scale.

    Returns:
        (agent, episode_rewards)
    """
    agent = PPOSharedControl(env.observation_dim, env.action_dim)

    # Create directory for checkpoints
    checkpoint_dir = f'checkpoints_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
    os.makedirs(checkpoint_dir, exist_ok=True)

    best_reward = float('-inf')
    episode_rewards = []

    for episode in range(episodes):
        states, actions, rewards, log_probs, episode_reward = _collect_episode(
            env, agent, steps_per_episode
        )
        episode_rewards.append(episode_reward)

        # A zero-length rollout (steps_per_episode == 0) has nothing to learn from.
        if len(states) > 0:
            returns = compute_returns(rewards, agent.gamma)
            advantages = compute_advantages(returns, states, agent)
            agent.update(states, actions, log_probs, returns, advantages)

        # Checkpoint if best
        if episode_reward > best_reward:
            best_reward = episode_reward
            agent.save(os.path.join(checkpoint_dir, 'best_model.pth'))

        # Regular checkpoint
        if (episode + 1) % checkpoint_freq == 0:
            agent.save(os.path.join(checkpoint_dir, f'checkpoint_{episode+1}.pth'))

        # The slice already covers runs shorter than 100 episodes.
        avg_reward = np.mean(episode_rewards[-100:])
        print(f"Episode {episode+1}/{episodes}, Reward: {episode_reward:.2f}, "
              f"Avg Reward (last 100): {avg_reward:.2f}")

        # Early stop if solved
        if avg_reward > solved_threshold and len(episode_rewards) >= 100:
            print("Environment solved!")
            agent.save(os.path.join(checkpoint_dir, 'solved_model.pth'))
            break

    return agent, episode_rewards


if __name__ == "__main__":
    import traceback

    env = SharedControlEnv(render_mode='human')
    try:
        agent, rewards_history = train_ppo(env)
        # Explicit "./" prefix: save() creates os.path.dirname(path), which is
        # '' for a bare filename and would make os.makedirs fail.
        agent.save(os.path.join(os.curdir, 'final_model.pth'))
        np.save('training_rewards.npy', np.array(rewards_history))
    except KeyboardInterrupt:
        print("\nTraining interrupted by user.")
    except Exception as e:
        print(f"Error during training: {e}")
        # Keep the stack trace; the message alone hides where training failed.
        traceback.print_exc()
    finally:
        env.close()


pygame 2.6.1 (SDL 2.28.4, Python 3.12.3)
Hello from the pygame community. https://www.pygame.org/contribute.html
Using device: cuda
Episode 1/500, Reward: -191.15, Avg Reward (last 100): -191.15
Episode 2/500, Reward: -382.16, Avg Reward (last 100): -286.66
Episode 3/500, Reward: -63.07, Avg Reward (last 100): -212.13
Episode 4/500, Reward: -79.46, Avg Reward (last 100): -178.96
Episode 5/500, Reward: -180.01, Avg Reward (last 100): -179.17
Episode 6/500, Reward: -254.00, Avg Reward (last 100): -191.64
Episode 7/500, Reward: -85.35, Avg Reward (last 100): -176.46
Episode 8/500, Reward: -153.69, Avg Reward (last 100): -173.61
Episode 9/500, Reward: -365.54, Avg Reward (last 100): -194.94
Episode 10/500, Reward: -338.55, Avg Reward (last 100): -209.30
Episode 11/500, Reward: -481.17, Avg Reward (last 100): -234.01
Episode 12/500, Reward: -200.40, Avg Reward (last 100): -231.21
Episode 13/500, Reward: -118.10, Avg Reward (last 100): -222.51
Episode 14/500, Reward: -469.54, Avg Reward (las