In [1]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal
import matplotlib.pyplot as plt
import time
import os

# Set environment variables to handle PyGame display better
os.environ['SDL_VIDEODRIVER'] = 'windib'  # Use Windows driver
os.environ['SDL_WINDOW_CENTERED'] = '1'

# Actor Network
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )
        self.mean_layer = nn.Linear(256, action_dim)
        self.log_std_layer = nn.Linear(256, action_dim)
        
    def forward(self, state):
        x = self.net(state)
        mean = torch.tanh(self.mean_layer(x))
        log_std = self.log_std_layer(x)
        log_std = torch.clamp(log_std, -20, 2)
        return mean, log_std.exp()

# Critic Network
class Critic(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        
    def forward(self, state):
        return self.net(state)

class PPO:
    def __init__(self, state_dim, action_dim):
        self.actor = Actor(state_dim, action_dim)
        self.critic = Critic(state_dim)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
        
        # PPO hyperparameters
        self.gamma = 0.99
        self.gae_lambda = 0.95
        self.clip_ratio = 0.2
        self.clip_rewards = True
        self.entropy_coef = 0.01
        self.max_grad_norm = 0.5
        
        # Training history
        self.rewards_history = []
        self.avg_rewards_history = []
        
    def get_action(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state)
            mean, std = self.actor(state)
            dist = Normal(mean, std)
            action = dist.sample()
            action = torch.clamp(action, -1.0, 1.0)
            log_prob = dist.log_prob(action).sum(dim=-1)
            value = self.critic(state)
            return action.numpy(), value.item(), log_prob.item()
    
    def compute_returns(self, rewards, dones, values):
        returns = []
        advantages = []
        advantage = 0
        next_value = 0
        
        for r, d, v in zip(reversed(rewards), reversed(dones), reversed(values)):
            td_error = r + self.gamma * next_value * (1 - d) - v
            advantage = td_error + self.gamma * self.gae_lambda * (1 - d) * advantage
            next_value = v
            
            returns.insert(0, advantage + v)
            advantages.insert(0, advantage)
            
        return torch.FloatTensor(returns), torch.FloatTensor(advantages)
    
    def update(self, states, actions, log_probs, returns, advantages, batch_size=64):
        states = torch.FloatTensor(states)
        actions = torch.FloatTensor(actions)
        old_log_probs = torch.FloatTensor(log_probs)
        
        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        for _ in range(10):  # PPO epochs
            # Generate random indices
            indices = torch.randperm(states.size(0))
            
            for start_idx in range(0, states.size(0), batch_size):
                idx = indices[start_idx:start_idx + batch_size]
                
                batch_states = states[idx]
                batch_actions = actions[idx]
                batch_old_log_probs = old_log_probs[idx]
                batch_returns = returns[idx]
                batch_advantages = advantages[idx]
                
                # Get current policy distribution
                means, stds = self.actor(batch_states)
                dist = Normal(means, stds)
                new_log_probs = dist.log_prob(batch_actions).sum(dim=-1)
                entropy = dist.entropy().mean()
                
                # Calculate ratios
                ratios = torch.exp(new_log_probs - batch_old_log_probs)
                surr1 = ratios * batch_advantages
                surr2 = torch.clamp(ratios, 1 - self.clip_ratio, 1 + self.clip_ratio) * batch_advantages
                
                # Calculate losses
                actor_loss = -torch.min(surr1, surr2).mean()
                critic_values = self.critic(batch_states).squeeze()
                critic_loss = 0.5 * ((critic_values - batch_returns) ** 2).mean()
                
                # Total loss
                loss = actor_loss + 0.5 * critic_loss - self.entropy_coef * entropy
                
                # Update networks
                self.actor_optimizer.zero_grad()
                self.critic_optimizer.zero_grad()
                loss.backward()
                
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
                
                self.actor_optimizer.step()
                self.critic_optimizer.step()

def train(render_every=20, total_episodes=1000):
    # Create environments
    env = gym.make('BipedalWalker-v3')
    
    # Initialize agent
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    agent = PPO(state_dim, action_dim)
    
    # Training loop
    best_reward = float('-inf')
    episode_rewards = []
    
    try:
        for episode in range(total_episodes):
            states, actions, rewards, dones, log_probs, values = [], [], [], [], [], []
            
            # Reset environment
            state, _ = env.reset()
            episode_reward = 0
            
            # Create display environment if needed
            if episode % render_every == 0:
                display_env = gym.make('BipedalWalker-v3', render_mode='human')
                display_state, _ = display_env.reset()
            
            while True:
                # Get action from agent
                action, value, log_prob = agent.get_action(state)
                
                # Take step in environment
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                
                # Store transition
                states.append(state)
                actions.append(action)
                rewards.append(reward)
                dones.append(done)
                log_probs.append(log_prob)
                values.append(value)
                
                episode_reward += reward
                state = next_state
                
                # Display if needed
                if episode % render_every == 0:
                    display_action, _, _ = agent.get_action(display_state)
                    display_state, _, terminated, truncated, _ = display_env.step(display_action)
                    if terminated or truncated:
                        display_env.close()
                        break
                
                if done:
                    break
            
            # Update policy
            returns, advantages = agent.compute_returns(rewards, dones, values)
            agent.update(states, actions, log_probs, returns, advantages)
            
            # Store reward
            episode_rewards.append(episode_reward)
            avg_reward = np.mean(episode_rewards[-100:])
            
            # Print progress
            print(f"Episode {episode + 1}")
            print(f"Reward: {episode_reward:.2f}")
            print(f"Average Reward (last 100): {avg_reward:.2f}")
            print("-" * 50)
            
            # Plot progress
            if (episode + 1) % 10 == 0:
                plt.figure(figsize=(10, 5))
                plt.plot(episode_rewards)
                plt.title("Training Progress")
                plt.xlabel("Episode")
                plt.ylabel("Reward")
                plt.savefig("training_progress.png")
                plt.close()
            
            # Save best model
            if avg_reward > best_reward:
                best_reward = avg_reward
                torch.save({
                    'actor_state_dict': agent.actor.state_dict(),
                    'critic_state_dict': agent.critic.state_dict(),
                    'reward': best_reward
                }, 'best_model.pth')
            
            # Early stopping
            if avg_reward >= 300:
                print("Environment solved!")
                break
    
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    
    finally:
        env.close()
        if 'display_env' in locals():
            display_env.close()
    
    return agent

def evaluate(agent, episodes=5):
    env = gym.make('BipedalWalker-v3', render_mode='human')
    
    try:
        for episode in range(episodes):
            state, _ = env.reset()
            total_reward = 0
            done = False
            
            while not done:
                action, _, _ = agent.get_action(state)
                state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                total_reward += reward
                time.sleep(0.01)  # Slow down visualization
            
            print(f"Episode {episode + 1} Reward: {total_reward:.2f}")
    
    finally:
        env.close()

if __name__ == "__main__":
    # Set random seeds
    torch.manual_seed(42)
    np.random.seed(42)
    
    # Train agent
    print("Starting training...")
    agent = train(render_every=20)  # Show every 20 episodes
    
    # Evaluate agent
    print("\nEvaluating trained agent...")
    evaluate(agent)

Starting training...


  states = torch.FloatTensor(states)


Episode 1
Reward: -124.50
Average Reward (last 100): -124.50
--------------------------------------------------
Episode 2
Reward: -116.26
Average Reward (last 100): -120.38
--------------------------------------------------
Episode 3
Reward: -100.28
Average Reward (last 100): -113.68
--------------------------------------------------
Episode 4
Reward: -111.77
Average Reward (last 100): -113.20
--------------------------------------------------
Episode 5
Reward: -109.13
Average Reward (last 100): -112.39
--------------------------------------------------
Episode 6
Reward: -111.60
Average Reward (last 100): -112.26
--------------------------------------------------
Episode 7
Reward: -112.87
Average Reward (last 100): -112.34
--------------------------------------------------
Episode 8
Reward: -109.17
Average Reward (last 100): -111.95
--------------------------------------------------
Episode 9
Reward: -111.70
Average Reward (last 100): -111.92
-------------------------------------------

: 