In [1]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt

class ActorCritic(nn.Module):
    """Shared-trunk actor-critic network with a Gaussian policy head.

    Two ReLU-activated fully connected layers feed both a linear layer that
    outputs the action mean and a linear layer that outputs a scalar state
    value. The action standard deviation is a learned parameter stored in log
    space; it does not depend on the input state.

    Args:
        state_dim: Size of the observation vector.
        action_dim: Size of the action vector.
        hidden_dim: Width of both hidden layers. Defaults to 256, the width
            that used to be hard-coded, so existing callers and saved
            state dicts keep working.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the order in which the RNG is consumed at init.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.actor = nn.Linear(hidden_dim, action_dim)
        self.critic = nn.Linear(hidden_dim, 1)
        # log(std) initialised to zero, i.e. std == 1 for every action dimension.
        self.log_std = nn.Parameter(torch.zeros(1, action_dim))

    def forward(self, x):
        """Compute the policy parameters and the state value for ``x``.

        Args:
            x: Tensor whose last dimension is ``state_dim``.

        Returns:
            Tuple ``(mean, std, value)``: ``mean`` and ``std`` share the same
            shape (last dimension ``action_dim``) and ``value`` has last
            dimension 1.
        """
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        mean = self.actor(hidden)
        # Broadcast the single learned std row to every row of the batch.
        std = self.log_std.exp().expand_as(mean)
        value = self.critic(hidden)
        return mean, std, value

def run_ppo(episodes, is_training, render=False):
    """Train or evaluate a PPO agent on Gymnasium's ``Pendulum-v1``.

    Every episode is rolled out with the current Gaussian policy. In training
    mode the episode is then used for several epochs of clipped-surrogate PPO
    updates; the weights are written to ``ppo_pendulum.pth`` every 10 episodes
    and once more at the end if the last episode was not checkpointed. In
    evaluation mode the weights are loaded from that file and never updated.

    A reward-per-episode plot is saved before returning: ``ppo_pendulum.png``
    for training runs and ``ppo_pendulum_eval.png`` for evaluation runs, so
    that evaluating does not overwrite the training curve.

    Args:
        episodes: Number of episodes to run.
        is_training: If true, update and checkpoint the model; otherwise load
            ``ppo_pendulum.pth`` and only act.
        render: If true, the environment is created with
            ``render_mode='human'``.

    Returns:
        None.
    """
    # Hyperparameters
    gamma = 0.99
    lr = 3e-4
    epsilon = 0.2
    epochs = 10
    batch_size = 64
    max_steps = 200
    checkpoint_path = 'ppo_pendulum.pth'

    env = gym.make('Pendulum-v1', render_mode='human' if render else None)
    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        action_bound = float(env.action_space.high[0])

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = ActorCritic(state_dim, action_dim).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)

        if not is_training:
            # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            model.eval()

        rewards = []

        for episode in range(episodes):
            states = []
            actions = []
            rewards_ep = []
            log_probs = []

            state, _ = env.reset()
            ep_reward = 0

            # Collect trajectory
            for _ in range(max_steps):
                state_tensor = torch.as_tensor(
                    state, dtype=torch.float32, device=device
                ).unsqueeze(0)
                # Rollout values are only ever used as constants in the update,
                # so no autograd graph is needed here.
                with torch.no_grad():
                    mean, std, _ = model(state_tensor)
                    dist = Normal(mean, std)
                    action = dist.sample()
                    log_prob = dist.log_prob(action).sum(-1)

                sampled_action = action.cpu().numpy().flatten()
                # Only the action sent to the environment is clipped.
                env_action = np.clip(sampled_action, -action_bound, action_bound)

                next_state, reward, terminated, truncated, _ = env.step(env_action)

                states.append(state)
                # Store the *unclipped* sample: log_prob above was computed for
                # it, and the PPO ratio must compare log-probs of the same action.
                actions.append(sampled_action)
                rewards_ep.append(reward)
                log_probs.append(log_prob)

                state = next_state
                ep_reward += reward

                if terminated or truncated:
                    break

            rewards.append(ep_reward)

            if is_training:
                states_t = torch.as_tensor(
                    np.array(states), dtype=torch.float32, device=device
                )
                actions_t = torch.as_tensor(
                    np.array(actions), dtype=torch.float32, device=device
                )
                old_log_probs = torch.cat(log_probs)

                # NOTE: the return is not bootstrapped from the critic when the
                # episode stops at the time limit; the tail is treated as zero.
                returns = torch.as_tensor(
                    _discounted_returns(rewards_ep, gamma),
                    dtype=torch.float32,
                    device=device,
                )

                # Normalize returns. std() of a single element is NaN, which
                # would poison the weights, so one-step episodes are left as-is.
                if returns.numel() > 1:
                    returns = (returns - returns.mean()) / (returns.std() + 1e-8)

                _ppo_update(
                    model, optimizer, states_t, actions_t, old_log_probs,
                    returns, epochs, batch_size, epsilon,
                )

            if (episode + 1) % 10 == 0:
                avg_reward = np.mean(rewards[-10:])
                print(f'Episode {episode+1}/{episodes}, Reward: {ep_reward:.2f}, Avg Reward: {avg_reward:.2f}')

                if is_training:
                    torch.save(model.state_dict(), checkpoint_path)

        # Keep the final weights when the run length is not a multiple of 10.
        if is_training and rewards and len(rewards) % 10 != 0:
            torch.save(model.state_dict(), checkpoint_path)

        # Plot results
        fig, ax = plt.subplots()
        ax.plot(rewards)
        if is_training:
            ax.set_title('PPO Training - Pendulum-v1')
            plot_path = 'ppo_pendulum.png'
        else:
            ax.set_title('PPO Evaluation - Pendulum-v1')
            plot_path = 'ppo_pendulum_eval.png'
        ax.set_xlabel('Episode')
        ax.set_ylabel('Total Reward')
        ax.grid(True)
        fig.savefig(plot_path)
        plt.close(fig)
    finally:
        # Release the environment (and any render window) even on failure.
        env.close()


def _discounted_returns(rewards, gamma):
    """Return the discounted reward-to-go for each step of a single episode."""
    returns = []
    running = 0.0
    for reward in reversed(rewards):
        running = float(reward) + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def _ppo_update(model, optimizer, states, actions, old_log_probs, returns,
                epochs, batch_size, epsilon):
    """Run `epochs` passes of shuffled-minibatch clipped-surrogate PPO updates."""
    n_samples = states.shape[0]
    for _ in range(epochs):
        # Shuffle indices
        order = torch.randperm(n_samples, device=states.device)

        for start in range(0, n_samples, batch_size):
            idx = order[start:start + batch_size]
            batch_returns = returns[idx]

            # Get current policy
            mean, std, values_pred = model(states[idx])
            # squeeze(-1) keeps the batch dimension even for a minibatch of one.
            values_pred = values_pred.squeeze(-1)
            dist = Normal(mean, std)
            new_log_probs = dist.log_prob(actions[idx]).sum(-1)
            entropy = dist.entropy().mean()

            # Probability ratio between the current and the rollout policy.
            ratios = torch.exp(new_log_probs - old_log_probs[idx])

            # The (normalized) returns are the critic target and, minus the
            # detached value estimate, the advantage.
            advantages = batch_returns - values_pred.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = F.mse_loss(values_pred, batch_returns)

            # Total loss
            loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy

            # Update
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()



In [3]:
# Train for 500 episodes with rendering disabled.
run_ppo(is_training=True, episodes=500, render=False)

# Run 5 rendered episodes with is_training=False, which loads 'ppo_pendulum.pth'.
run_ppo(is_training=False, episodes=5, render=True)

Episode 10/500, Reward: -1342.64, Avg Reward: -1265.56
Episode 20/500, Reward: -1101.37, Avg Reward: -1430.97
Episode 30/500, Reward: -1200.13, Avg Reward: -1308.93
Episode 40/500, Reward: -1509.21, Avg Reward: -1305.01
Episode 50/500, Reward: -981.37, Avg Reward: -1379.17
Episode 60/500, Reward: -880.88, Avg Reward: -1298.77
Episode 70/500, Reward: -1555.10, Avg Reward: -1360.12
Episode 80/500, Reward: -950.37, Avg Reward: -1163.48
Episode 90/500, Reward: -1176.81, Avg Reward: -1172.25
Episode 100/500, Reward: -1170.19, Avg Reward: -1366.35
Episode 110/500, Reward: -1513.01, Avg Reward: -1264.77
Episode 120/500, Reward: -1192.74, Avg Reward: -1293.14
Episode 130/500, Reward: -1213.30, Avg Reward: -1328.06
Episode 140/500, Reward: -1474.82, Avg Reward: -1351.27
Episode 150/500, Reward: -1326.70, Avg Reward: -1356.41
Episode 160/500, Reward: -1360.80, Avg Reward: -1389.85
Episode 170/500, Reward: -1354.80, Avg Reward: -1355.64
Episode 180/500, Reward: -1337.48, Avg Reward: -1432.70
Epis

In [4]:
# Evaluation only: 5 rendered episodes using the weights in 'ppo_pendulum.pth'.
run_ppo(is_training=False, episodes=5, render=True)