In [1]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal
import matplotlib.pyplot as plt
from datetime import datetime

# Gaussian policy network for continuous action spaces.
class Actor(nn.Module):
    """Map a state to the mean and standard deviation of a diagonal Gaussian.

    The mean is squashed into [-1, 1] with tanh. The log standard deviation is
    clamped to [-20, 2] before exponentiation, so the returned std always lies
    in [exp(-20), exp(2)].
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden = 256

        # Shared feature trunk: two fully connected ReLU layers.
        trunk_layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*trunk_layers)

        # Independent linear heads for the mean and the log-std.
        self.mean_layer = nn.Linear(hidden, action_dim)
        self.log_std_layer = nn.Linear(hidden, action_dim)

        # Re-initialise every Linear layer: orthogonal weights, zero biases.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Give ``nn.Linear`` modules orthogonal (gain 1.0) weights and zero biases."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.orthogonal_(module.weight, gain=1.0)
        nn.init.zeros_(module.bias)

    def forward(self, state):
        """Return ``(mean, std)`` of the action distribution for ``state``."""
        features = self.net(state)
        mean = self.mean_layer(features).tanh()
        # Clamp keeps the std away from 0 (degenerate) and from exploding.
        log_std = self.log_std_layer(features).clamp(min=-20, max=2)
        return mean, log_std.exp()

# State-value network V(s).
class Critic(nn.Module):
    """Estimate the scalar state value V(s) with a two-hidden-layer MLP."""

    def __init__(self, state_dim):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),  # single scalar value per state
        ]
        self.net = nn.Sequential(*layers)

        # Orthogonal weights / zero biases for every Linear layer.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Give ``nn.Linear`` modules orthogonal (gain 1.0) weights and zero biases."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.orthogonal_(module.weight, gain=1.0)
        nn.init.zeros_(module.bias)

    def forward(self, state):
        """Return V(state) with a trailing dimension of size 1."""
        value = self.net(state)
        return value

class PPOAgent:
    """Proximal Policy Optimization agent with a Gaussian actor and a value critic.

    Transitions are collected one full episode at a time by
    :meth:`train_episode` into per-step Python lists. They are consumed, then
    cleared, by :meth:`update`.
    """

    def __init__(self, state_dim, action_dim):
        self.actor = Actor(state_dim, action_dim)
        self.critic = Critic(state_dim)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        # PPO hyperparameters
        self.gamma = 0.99
        self.gae_lambda = 0.95
        self.clip_epsilon = 0.2
        self.entropy_coef = 0.01
        # NOTE: value clipping is not applied in update(); this attribute is
        # kept only so existing code that reads/sets it keeps working.
        self.value_clip_range = 0.2
        self.max_grad_norm = 0.5

        # Initialize buffers
        self.reset_buffers()

        # Training metrics; appended to by the caller (e.g. the training loop).
        self.rewards_history = []
        self.avg_rewards_history = []

    def reset_buffers(self):
        """Clear all per-step rollout buffers."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        # True on the last step of an episode (terminated OR truncated).
        self.dones = []
        # Value of the state reached after each step, used only where dones[t]
        # is True: 0.0 for a real termination, V(next_state) for a truncation.
        self.bootstrap_values = []

    def get_action(self, state):
        """Sample an action for a single observation.

        Returns:
            ``(action, value, log_prob)``:
            - ``action``: a numpy array clipped to [-1, 1];
            - ``value``: the critic's V(state) as a float;
            - ``log_prob``: the summed log-density of the clipped action as a
              float. The clipped action is also what gets stored and later
              re-scored in :meth:`update`, so old and new log-probs are
              evaluated at the same point.
        """
        with torch.no_grad():
            state = torch.FloatTensor(state)
            mean, std = self.actor(state)
            dist = Normal(mean, std)
            action = dist.sample()
            action = torch.clamp(action, -1.0, 1.0)
            log_prob = dist.log_prob(action).sum(dim=-1)
            value = self.critic(state)
            return action.numpy(), value.item(), log_prob.item()

    def _state_value(self, state):
        """Return the critic's estimate V(state) for one observation as a float."""
        with torch.no_grad():
            return self.critic(torch.FloatTensor(state)).item()

    def train_episode(self, env, render=False):
        """Run one full episode in ``env`` and store every transition.

        Args:
            env: Gymnasium-style env. ``reset()`` returns ``(obs, info)`` and
                ``step(a)`` returns ``(obs, reward, terminated, truncated, info)``.
            render: Unused; kept for backward compatibility.

        Returns:
            The undiscounted sum of rewards collected in the episode.
        """
        state, _ = env.reset()
        done = False
        total_reward = 0

        while not done:
            action, value, log_prob = self.get_action(state)

            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # A time-limit truncation is not a true terminal state. Bootstrap
            # from the critic so the return is not cut off at zero there.
            if truncated and not terminated:
                bootstrap = self._state_value(next_state)
            else:
                bootstrap = 0.0

            self.states.append(state)
            self.actions.append(action)
            self.rewards.append(reward)
            self.values.append(value)
            self.log_probs.append(log_prob)
            self.dones.append(done)
            self.bootstrap_values.append(bootstrap)

            state = next_state
            total_reward += reward

        return total_reward

    def _compute_gae(self):
        """Compute GAE(lambda) advantages and lambda-returns for the buffered steps.

        Episode boundaries (``dones[t]``) stop the advantage recursion. At a
        boundary, the next-state value is taken from ``bootstrap_values[t]``,
        which is 0.0 for a termination. The last buffered step is treated as a
        boundary even if it is not marked done.

        Returns:
            ``(advantages, returns)`` as 1-D float tensors, where
            ``returns = advantages + values`` are the critic regression targets.
        """
        n = len(self.rewards)
        advantages = torch.zeros(n)
        running_advantage = 0.0

        for t in reversed(range(n)):
            if self.dones[t] or t == n - 1:
                next_value = self.bootstrap_values[t]
                running_advantage = 0.0  # do not chain into the next episode
            else:
                next_value = self.values[t + 1]

            delta = self.rewards[t] + self.gamma * next_value - self.values[t]
            running_advantage = delta + self.gamma * self.gae_lambda * running_advantage
            advantages[t] = running_advantage

        returns = advantages + torch.FloatTensor(self.values)
        return advantages, returns

    def update(self, batch_size=64, n_epochs=10):
        """Run clipped-PPO updates on the buffered data, then clear the buffers.

        Args:
            batch_size: Mini-batch size for each gradient step.
            n_epochs: Number of passes over the buffered data.
        """
        if not self.rewards:
            # Nothing collected: avoid NaN statistics and just reset.
            self.reset_buffers()
            return

        states = torch.FloatTensor(np.array(self.states))
        actions = torch.FloatTensor(np.array(self.actions))
        old_log_probs = torch.FloatTensor(self.log_probs)

        advantages, returns = self._compute_gae()

        # Normalize advantages. std() of a single sample is NaN, so skip that case.
        if len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        num_samples = len(states)
        for _ in range(n_epochs):
            indices = torch.randperm(num_samples)

            for start in range(0, num_samples, batch_size):
                batch_indices = indices[start:start + batch_size]

                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_log_probs = old_log_probs[batch_indices]
                batch_advantages = advantages[batch_indices]
                batch_returns = returns[batch_indices]

                # Current policy distribution at the stored actions.
                means, stds = self.actor(batch_states)
                dist = Normal(means, stds)
                current_log_probs = dist.log_prob(batch_actions).sum(dim=-1)
                entropy = dist.entropy().mean()

                # Clipped surrogate objective.
                ratios = torch.exp(current_log_probs - batch_log_probs)
                surr1 = ratios * batch_advantages
                surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                # squeeze(-1) keeps a 1-D shape even for a 1-sample batch.
                values = self.critic(batch_states).squeeze(-1)
                value_loss = 0.5 * ((values - batch_returns) ** 2).mean()

                loss = policy_loss + 0.5 * value_loss - self.entropy_coef * entropy

                self.actor_optimizer.zero_grad()
                self.critic_optimizer.zero_grad()
                loss.backward()

                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)

                self.actor_optimizer.step()
                self.critic_optimizer.step()

        self.reset_buffers()

def plot_training_progress(rewards, avg_rewards, title="Training Progress",
                           save_path='bipedal_training_progress.png'):
    """Save a line plot of per-episode rewards and their running average.

    Args:
        rewards: Per-episode rewards (drawn semi-transparent).
        avg_rewards: Running-average rewards, one per episode.
        title: Figure title.
        save_path: Output image path; defaults to the original file name.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(rewards, alpha=0.5, label='Rewards')
        ax.plot(avg_rewards, label='Average Rewards')
        ax.set_title(title)
        ax.set_xlabel('Episode')
        ax.set_ylabel('Reward')
        ax.legend()
        ax.grid(True)
        fig.savefig(save_path)
    finally:
        # Close exactly this figure (not whatever is "current") so repeated
        # calls during training never accumulate open figures.
        plt.close(fig)

def _rollout_reward(agent, env):
    """Run one episode with the agent's sampled actions, WITHOUT storing transitions; return its total reward."""
    state, _ = env.reset()
    total_reward = 0.0
    done = False
    while not done:
        action, _, _ = agent.get_action(state)
        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        total_reward += reward
    return total_reward


def train(env_name='BipedalWalker-v3', max_episodes=1000, render_freq=10):
    """Train a PPOAgent, performing one policy update after each episode.

    Args:
        env_name: Gymnasium environment id.
        max_episodes: Maximum number of training episodes.
        render_freq: Every ``render_freq`` episodes, run one rendered episode
            in a separate ``render_mode='human'`` env for display/logging.
            That rollout is NOT added to the training buffers. Use 0 or None
            to disable rendering (no render window is created then).

    Side effects: saves a progress plot every 10 episodes and writes
    'best_bipedal_model.pth' whenever the 100-episode average improves.

    Returns:
        The agent, also when training is interrupted with Ctrl-C.
    """
    env = gym.make(env_name)
    eval_env = gym.make(env_name, render_mode='human') if render_freq else None

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    agent = PPOAgent(state_dim, action_dim)

    best_reward = float('-inf')

    try:
        for episode in range(max_episodes):
            # Training episode
            total_reward = agent.train_episode(env)
            agent.rewards_history.append(total_reward)

            # Moving average over the last (up to) 100 episodes
            avg_reward = np.mean(agent.rewards_history[-100:])
            agent.avg_rewards_history.append(avg_reward)

            # Update policy (consumes and clears the agent's buffers)
            agent.update()

            # Print progress
            print(f"Episode {episode + 1}")
            print(f"Reward: {total_reward:.2f}")
            print(f"Average Reward: {avg_reward:.2f}")
            print("-" * 50)

            # Plot progress
            if episode % 10 == 0:
                plot_training_progress(agent.rewards_history, agent.avg_rewards_history)

            # Rendered evaluation episode; must not feed the next update().
            if render_freq and episode % render_freq == 0:
                eval_reward = _rollout_reward(agent, eval_env)
                print(f"Evaluation Reward: {eval_reward:.2f}")

            # Save best model
            if avg_reward > best_reward:
                best_reward = avg_reward
                torch.save({
                    'actor_state_dict': agent.actor.state_dict(),
                    'critic_state_dict': agent.critic.state_dict(),
                    'reward': best_reward
                }, 'best_bipedal_model.pth')

            # Early stopping criterion
            if avg_reward > 300:  # BipedalWalker is considered solved at 300
                print(f"Environment solved in {episode + 1} episodes!")
                break

    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    finally:
        # Always release the environments (and the render window).
        env.close()
        if eval_env is not None:
            eval_env.close()

    return agent

def evaluate(agent, env_name='BipedalWalker-v3', episodes=5):
    """Run ``episodes`` rendered episodes and print each episode's total reward.

    Actions are sampled with ``agent.get_action``, and transitions are not
    stored in the agent's buffers. The environment is closed even if an
    episode raises or the user interrupts rendering.

    Args:
        agent: Object providing ``get_action(state) -> (action, value, log_prob)``.
        env_name: Gymnasium environment id.
        episodes: Number of episodes to run.
    """
    env = gym.make(env_name, render_mode='human')
    try:
        for episode in range(episodes):
            state, _ = env.reset()
            total_reward = 0
            done = False

            while not done:
                action, _, _ = agent.get_action(state)
                state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                total_reward += reward

            print(f"Evaluation Episode {episode + 1}: Reward = {total_reward:.2f}")
    finally:
        env.close()

if __name__ == "__main__":
    # Seed the global PyTorch and NumPy RNGs. The environments are reset
    # without a seed, so episodes themselves are not fully reproducible.
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(0)

    # Train with a rendered episode every 10 training episodes,
    # then watch the resulting policy.
    agent = train(render_freq=10)
    evaluate(agent)
    

Episode 1
Reward: -107.81
Average Reward: -107.81
--------------------------------------------------
Evaluation Reward: -116.40
Episode 2
Reward: -127.53
Average Reward: -117.67
--------------------------------------------------
Episode 3
Reward: -113.95
Average Reward: -116.43
--------------------------------------------------
Episode 4
Reward: -106.04
Average Reward: -113.83
--------------------------------------------------
Episode 5
Reward: -107.97
Average Reward: -112.66
--------------------------------------------------
Episode 6
Reward: -100.51
Average Reward: -110.63
--------------------------------------------------
Episode 7
Reward: -129.39
Average Reward: -113.31
--------------------------------------------------
Episode 8
Reward: -113.95
Average Reward: -113.39
--------------------------------------------------
Episode 9
Reward: -104.02
Average Reward: -112.35
--------------------------------------------------
Episode 10
Reward: -99.52
Average Reward: -111.07
--------------

: 