In [1]:
import gym
import torch
import numpy as np
from torch import nn, optim

In [2]:
# Initialize the CartPole environment
env = gym.make("CartPole-v1", render_mode="rgb_array")

state, info = env.reset()
action = env.action_space.sample()

state, reward, done, truncated, info = env.step(action)
print(f"State: {state}, Reward: {reward}, Done: {done}, Truncated: {truncated}, Info: {info}")  


State: [ 0.01330669 -0.19223367  0.0312921   0.25859132], Reward: 1.0, Done: False, Truncated: False, Info: {}


  if not isinstance(terminated, (bool, np.bool8)):


In [3]:
# Define the PPO agent's neural network
class PPOAgent(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PPOAgent, self).__init__()
        
        # Common hidden layers
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )
        
        # Policy output: generates probabilities for actions
        self.policy = nn.Linear(64, output_dim)
        
        # Value output: predicts the expected return for a state
        self.value = nn.Linear(64, 1)

    def forward(self, x):
        x = self.fc(x)
        
        # Use softmax to ensure the policy output is a valid probability distribution
        policy_probs = torch.softmax(self.policy(x), dim=-1)
        
        # Value output is a single scalar
        value = self.value(x)
        
        return policy_probs, value

# Set the input and output dimensions based on CartPole
input_dim = 4
output_dim = 2

# Initialize the PPO agent
agent = PPOAgent(input_dim, output_dim)

# Test the network with a sample input
sample_state = torch.tensor(env.reset()[0], dtype=torch.float32)
policy_probs, value = agent(sample_state)
print("Policy probabilities:", policy_probs)
print("State value:", value)

# Define the optimizer
optimizer = optim.Adam(agent.parameters(), lr=1e-4)

# Define the PPO hyperparameters
clip_epsilon = 0.15  # Controls how much the policy can change
value_coef = 0.5  # Weight for value loss
entropy_coef = 0.01  # Weight for entropy bonus
batch_epochs = 3  # Number of update epochs per batch
batch_size = 128  # Size of mini-batches


Policy probabilities: tensor([0.5139, 0.4861], grad_fn=<SoftmaxBackward0>)
State value: tensor([0.1304], grad_fn=<ViewBackward0>)


In [4]:
def compute_advantages(rewards, values, gamma=0.99, lam=0.95):
    advantages = []
    gae = 0
    # Add a final value of 0 if episode ended (or use last state's value if it didn't)
    values = values + [0]
    for i in reversed(range(len(rewards))):
        delta = rewards[i] + gamma * values[i + 1] - values[i]
        gae = delta + gamma * lam * gae
        advantages.insert(0, gae)
    
    return torch.tensor(advantages, dtype=torch.float32)



In [5]:
# Hyperparameters
STEPS_PER_EPOCH = 2048
EPOCHS = 1000

# Training loop
for epoch in range(EPOCHS):
    batch_data = []
    avg_life = []
    
    # Collect data for a batch
    while len(batch_data) < STEPS_PER_EPOCH:
        state = env.reset()[0]
        done = False
        
        # Collect data for a single episode
        lifetime = 0
        while not done:
            state_tensor = torch.tensor(state, dtype=torch.float32)
            policy_probs, value = agent(state_tensor)
            
            action_dist = torch.distributions.Categorical(policy_probs)
            action = action_dist.sample().item()
            next_state, reward, done, _, _ = env.step(action)
            
            batch_data.append((state, action, reward, next_state, value, done))
            state = next_state
            
            lifetime += 1
        
        avg_life.append(lifetime)
    
    # Process the batch data
    print(f"Processing epoch {epoch} data, avg life: {np.mean(avg_life)}")
    
    # Extract data
    states = torch.tensor(np.array([data[0] for data in batch_data]), dtype=torch.float32)
    actions = torch.tensor([data[1] for data in batch_data], dtype=torch.int64)
    rewards = [data[2] for data in batch_data]
    values = [data[4] for data in batch_data]

    # Compute advantages
    advantages = compute_advantages(rewards, values)
    #print("Advantages:", advantages)
    
    # Convert values and advantages to tensor
    values = torch.tensor(values, dtype=torch.float32)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    #print("Normalized advantages:", advantages)
    
    for _ in range(batch_epochs):
        # Randomly shuffle the data
        indices = np.arange(len(states))
        np.random.shuffle(indices)
        
        states = states[indices]
        actions = actions[indices]
        advantages = advantages[indices]
        values = values[indices]
        
        for idx_start in range(0, len(states), batch_size):
            idx_end = idx_start + batch_size
            
            # Create a mini-batch
            batch_states = states[idx_start:idx_end]
            batch_actions = actions[idx_start:idx_end]
            batch_advantages = advantages[idx_start:idx_end]
            batch_values = values[idx_start:idx_end]
            
            # Get new policy and value predictions
            new_policy_probs, new_values = agent(batch_states)
            new_action_dist = torch.distributions.Categorical(new_policy_probs)
            
            # Calculate the log probabilities of the actions in the batch
            log_probs = new_action_dist.log_prob(batch_actions)
            old_log_probs = new_action_dist.log_prob(batch_actions).detach()
            
            # Policy ratio for PPO's clipped objective
            policy_ratio = torch.exp(log_probs - old_log_probs)
            surr1 = policy_ratio * batch_advantages
            surr2 = torch.clamp(policy_ratio, 1 - clip_epsilon, 1 + clip_epsilon) * batch_advantages
            policy_loss = -torch.min(surr1, surr2).mean()
            
            # Value loss
            value_loss = value_coef * (new_values.squeeze() - batch_values).pow(2).mean()

            # Entropy bonus to encourage exploration
            entropy = -torch.sum(new_policy_probs * torch.log(new_policy_probs + 1e-8), dim=-1).mean()
            entropy_loss = entropy_coef * entropy

            # Combine losses
            loss = policy_loss + value_loss + entropy_loss
            
            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

Processing epoch 0 data, avg life: 22.445652173913043
Processing epoch 1 data, avg life: 20.349514563106798
Processing epoch 2 data, avg life: 20.78787878787879
Processing epoch 3 data, avg life: 22.33695652173913
Processing epoch 4 data, avg life: 23.247191011235955
Processing epoch 5 data, avg life: 23.930232558139537
Processing epoch 6 data, avg life: 24.746987951807228
Processing epoch 7 data, avg life: 21.78723404255319
Processing epoch 8 data, avg life: 24.5
Processing epoch 9 data, avg life: 24.738095238095237
Processing epoch 10 data, avg life: 24.97590361445783
Processing epoch 11 data, avg life: 23.06741573033708
Processing epoch 12 data, avg life: 26.075949367088608
Processing epoch 13 data, avg life: 26.1125
Processing epoch 14 data, avg life: 25.132530120481928
Processing epoch 15 data, avg life: 26.31645569620253
Processing epoch 16 data, avg life: 26.037974683544302
Processing epoch 17 data, avg life: 26.911392405063292
Processing epoch 18 data, avg life: 26.701298701298

KeyboardInterrupt: 