In [1]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal
import matplotlib.pyplot as plt
import time
import os

# Actor and Critic Network definitions remain the same
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.LayerNorm(128)  # Layer normalization for stable learning
        )
        self.mean_layer = nn.Linear(128, action_dim)
        self.log_std_layer = nn.Linear(128, action_dim)
        
    def forward(self, state):
        x = self.net(state)
        mean = torch.tanh(self.mean_layer(x))
        log_std = self.log_std_layer(x)
        log_std = torch.clamp(log_std, -20, 2)
        return mean, log_std.exp()

class Critic(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.LayerNorm(128)  # Layer normalization
        )
        self.value_layer = nn.Linear(128, 1)
        
    def forward(self, state):
        x = self.net(state)
        value = self.value_layer(x)
        return value


class PPO:
    def __init__(self, state_dim, action_dim):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.actor = Actor(state_dim, action_dim).to(self.device)
        self.critic = Critic(state_dim).to(self.device)
        
        # Optimizers with initial learning rates
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-3)

        # PPO hyperparameters with adjustments
        self.gamma = 0.98
        self.gae_lambda = 0.95
        self.clip_ratio = 0.1  # Reduce to encourage policy stability
        self.entropy_coef = 0.005  # Decreased entropy coefficient for later stages
        self.max_grad_norm = 0.5

        # Memory management
        self.max_memory_size = 2048

        # Learning rate scheduler
        self.scheduler_actor = torch.optim.lr_scheduler.ExponentialLR(self.actor_optimizer, gamma=0.99)
        self.scheduler_critic = torch.optim.lr_scheduler.ExponentialLR(self.critic_optimizer, gamma=0.99)

        
    def get_action(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device)
            mean, std = self.actor(state)
            dist = Normal(mean, std)
            action = dist.sample()
            action = torch.clamp(action, -1.0, 1.0)
            log_prob = dist.log_prob(action).sum(dim=-1)
            value = self.critic(state)
            return action.cpu().numpy(), value.item(), log_prob.item()

    def update(self, memory, batch_size=64):
        states = torch.FloatTensor(memory['states']).to(self.device)
        actions = torch.FloatTensor(memory['actions']).to(self.device)
        old_log_probs = torch.FloatTensor(memory['log_probs']).to(self.device)
        returns = torch.FloatTensor(memory['returns']).to(self.device)
        advantages = torch.FloatTensor(memory['advantages']).to(self.device)
        
        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        # Mini-batch updates
        indices = torch.randperm(states.size(0))
        
        for start_idx in range(0, states.size(0), batch_size):
            idx = indices[start_idx:start_idx + batch_size]
            
            batch_states = states[idx]
            batch_actions = actions[idx]
            batch_old_log_probs = old_log_probs[idx]
            batch_returns = returns[idx]
            batch_advantages = advantages[idx]
            
            # Get current policy distribution
            means, stds = self.actor(batch_states)
            dist = Normal(means, stds)
            new_log_probs = dist.log_prob(batch_actions).sum(dim=-1)
            entropy = dist.entropy().mean()
            
            # Calculate ratios and losses
            ratios = torch.exp(new_log_probs - batch_old_log_probs)
            surr1 = ratios * batch_advantages
            surr2 = torch.clamp(ratios, 1 - self.clip_ratio, 1 + self.clip_ratio) * batch_advantages
            
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_values = self.critic(batch_states).squeeze()
            critic_loss = 0.5 * ((critic_values - batch_returns) ** 2).mean()
            
            # Total loss
            loss = actor_loss + 0.5 * critic_loss - self.entropy_coef * entropy
            
            # Update networks
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            loss.backward()
            
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
            
            self.actor_optimizer.step()
            self.critic_optimizer.step()

def train(render_every=20, total_episodes=10000):
    # Create environment
    env = gym.make('BipedalWalker-v3')
    
    # Initialize agent
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    agent = PPO(state_dim, action_dim)
    
    # Training variables
    best_reward = float('-inf')  # Initialize best_reward to a very low value
    episode_rewards = []
    display_env = None
    
    try:
        for episode in range(total_episodes):
            # Initialize episode memory
            memory = {
                'states': [],
                'actions': [],
                'rewards': [],
                'dones': [],
                'log_probs': [],
                'values': []
            }
            
            state, _ = env.reset()
            episode_reward = 0
            
            # Create display environment if needed
            if episode % render_every == 0:
                if display_env is not None:
                    display_env.close()
                display_env = gym.make('BipedalWalker-v3', render_mode='human')
                display_state, _ = display_env.reset()
            
            # Episode loop
            while True:
                action, value, log_prob = agent.get_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                
                # Store transition
                memory['states'].append(state)
                memory['actions'].append(action)
                memory['rewards'].append(reward)
                memory['dones'].append(done)
                memory['log_probs'].append(log_prob)
                memory['values'].append(value)
                
                episode_reward += reward
                state = next_state
                
                # Display if needed
                if episode % render_every == 0 and display_env is not None:
                    display_action, _, _ = agent.get_action(display_state)
                    display_state, _, terminated, truncated, _ = display_env.step(display_action)
                    if terminated or truncated:
                        break
                
                if done:
                    break
                
                # Check memory limit
                if len(memory['states']) >= agent.max_memory_size:
                    break
            
            # Compute returns and advantages
            returns, advantages = compute_returns(
                memory['rewards'],
                memory['dones'],
                memory['values'],
                agent.gamma,
                agent.gae_lambda
            )
            
            # Update memory
            memory['returns'] = returns
            memory['advantages'] = advantages
            
            # Update policy
            agent.update(memory)
            
            # Store and print progress
            episode_rewards.append(episode_reward)
            avg_reward = np.mean(episode_rewards[-100:])
            
            print(f"Episode {episode + 1}")
            print(f"Reward: {episode_reward:.2f}")
            print(f"Average Reward (last 100): {avg_reward:.2f}")
            print("-" * 50)
            
            # Save best model
            if avg_reward > best_reward:
                best_reward = avg_reward
                torch.save({
                    'actor_state_dict': agent.actor.state_dict(),
                    'critic_state_dict': agent.critic.state_dict(),
                    'reward': best_reward
                }, 'best_model.pth')
            
            # Clear memory
            del memory
            
            # Early stopping
            if avg_reward >= 300:
                print("Environment solved!")
                break
    
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    
    finally:
        env.close()
        if display_env is not None:
            display_env.close()
    
    return agent



def compute_returns(rewards, dones, values, gamma, gae_lambda):
    returns = []
    advantages = []
    advantage = 0
    next_value = 0
    
    for r, d, v in zip(reversed(rewards), reversed(dones), reversed(values)):
        td_error = r + gamma * next_value * (1 - d) - v
        advantage = td_error + gamma * gae_lambda * (1 - d) * advantage
        next_value = v
        
        returns.insert(0, advantage + v)
        advantages.insert(0, advantage)
    
    return returns, advantages

if __name__ == "__main__":
    # Set random seeds
    torch.manual_seed(42)
    np.random.seed(42)
    
    # Train the agent
    print("Starting training...")
    agent = train(render_every=20)

Starting training...


  states = torch.FloatTensor(memory['states']).to(self.device)


Episode 1
Reward: -4.64
Average Reward (last 100): -4.64
--------------------------------------------------
Episode 2
Reward: -112.80
Average Reward (last 100): -58.72
--------------------------------------------------
Episode 3
Reward: -113.42
Average Reward (last 100): -76.95
--------------------------------------------------
Episode 4
Reward: -152.56
Average Reward (last 100): -95.86
--------------------------------------------------
Episode 5
Reward: -111.26
Average Reward (last 100): -98.94
--------------------------------------------------
Episode 6
Reward: -113.68
Average Reward (last 100): -101.39
--------------------------------------------------
Episode 7
Reward: -108.10
Average Reward (last 100): -102.35
--------------------------------------------------
Episode 8
Reward: -176.60
Average Reward (last 100): -111.63
--------------------------------------------------
Episode 9
Reward: -184.56
Average Reward (last 100): -119.74
--------------------------------------------------
