In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
from collections import deque

# --- Hyperparameters ---
# Module-level constants read by PPOAgent and main() below.
LEARNING_RATE = 3e-4  # Step size passed to the Adam optimizer
GAMMA = 0.99  # Discount factor for future rewards
GAE_LAMBDA = 0.95  # Lambda for Generalized Advantage Estimation (GAE)
PPO_EPOCHS = 10  # How many times to iterate over the collected batch
CLIP_EPSILON = 0.2  # Clipping parameter for PPO loss
VALUE_COEF = 0.5  # Coefficient for the value function loss
ENTROPY_COEF = 0.01  # Coefficient for the entropy term (encourages exploration)
BATCH_SIZE = 64  # Mini-batch size used for each gradient step inside a PPO epoch
TIMESTEPS_PER_BATCH = 2048  # How many environment steps to collect before a PPO update
MAX_TIMESTEPS = 1000000  # Training loop stops once this many environment steps were taken

# --- 1. Actor-Critic Network Architecture ---
class ActorCritic(nn.Module):
    """
    Two-headed network: a shared MLP trunk feeds both a policy head
    (Actor, action probabilities) and a value head (Critic, V(s)).
    """
    def __init__(self, obs_dim, action_dim):
        super().__init__()

        hidden_size = 64

        # Trunk shared by both heads: two ReLU-activated hidden layers.
        trunk_layers = [
            nn.Linear(obs_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        ]
        self.shared_layer = nn.Sequential(*trunk_layers)

        # Actor head: one logit per action, normalised into probabilities
        # over the last dimension.
        policy_layers = [
            nn.Linear(hidden_size, action_dim),
            nn.Softmax(dim=-1),
        ]
        self.actor_head = nn.Sequential(*policy_layers)

        # Critic head: a single scalar state-value estimate per input row.
        self.critic_head = nn.Sequential(nn.Linear(hidden_size, 1))

    def forward(self, x):
        """Return `(action_probs, value)` for the observation tensor `x`."""
        features = self.shared_layer(x)
        return self.actor_head(features), self.critic_head(features)

# --- 2. PPO Agent Class ---
class PPOAgent:
    """
    Proximal Policy Optimization agent built around an `ActorCritic` model.

    The agent reads `env.observation_space.shape[0]` and `env.action_space.n`,
    so it expects a flat observation vector and a discrete action space.
    Transitions are accumulated with `store_transition` and consumed (then
    cleared) by `learn`.
    """
    def __init__(self, env):
        self.env = env
        self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n

        # Initialize the Actor-Critic model and optimizer
        self.model = ActorCritic(self.obs_dim, self.action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)

        # Store data for the training batch
        self.batch_data = self._empty_batch()

    @staticmethod
    def _empty_batch():
        """Return a fresh, empty rollout buffer (one list per recorded field)."""
        return {
            'states': [], 'actions': [], 'rewards': [],
            'log_probs': [], 'dones': [], 'values': []
        }

    def select_action(self, state):
        """
        Sample an action from the current policy.

        Args:
            state: NumPy observation array for a single environment step.

        Returns:
            Tuple `(action, log_prob, value)` of plain Python scalars: the
            sampled action index, its log-probability under the current
            policy, and the critic's value estimate for `state`.
        """
        # Convert NumPy state array to a (1, obs_dim) float tensor
        state_tensor = torch.from_numpy(state).float().unsqueeze(0)

        # No gradients are needed while collecting experience
        with torch.no_grad():
            action_probs, value = self.model(state_tensor)

        # Sample an action from the categorical distribution defined by action_probs
        action_dist = torch.distributions.Categorical(action_probs)
        action = action_dist.sample()

        # Log-probability of the sampled action (the "old" log-prob used by the PPO ratio)
        log_prob = action_dist.log_prob(action)

        return action.item(), log_prob.item(), value.item()

    def store_transition(self, state, action, reward, log_prob, done, value):
        """
        Append one step of experience to the rollout buffer.

        `done` must be the episode-end flag produced by the step that was
        taken from `state` with `action`; the advantage computation relies
        on that alignment.
        """
        self.batch_data['states'].append(state)
        self.batch_data['actions'].append(action)
        self.batch_data['rewards'].append(reward)
        self.batch_data['log_probs'].append(log_prob)
        self.batch_data['dones'].append(done)
        self.batch_data['values'].append(value)

    def _compute_advantages_and_returns(self, next_value):
        """
        Compute normalized GAE advantages and bootstrapped discounted returns.

        Args:
            next_value: Critic estimate for the state that follows the last
                stored transition. It is ignored (masked out) when that last
                transition ended an episode.

        Returns:
            Tuple `(advantages, returns)` of float64 NumPy arrays, one entry
            per stored transition. `advantages` is standardized to zero mean
            and unit standard deviation; `returns` is the critic target.
        """
        # Force float arrays so integer rewards cannot truncate the results.
        rewards = np.asarray(self.batch_data['rewards'], dtype=np.float64)
        values = np.asarray(self.batch_data['values'], dtype=np.float64)
        # non_terminal[t] == 0 when the step taken at t ended its episode.
        non_terminal = 1.0 - np.asarray(self.batch_data['dones'], dtype=np.float64)

        num_steps = len(rewards)
        advantages = np.zeros(num_steps, dtype=np.float64)
        returns = np.zeros(num_steps, dtype=np.float64)
        if num_steps == 0:
            # Nothing collected: avoid mean/std of an empty array (NaN).
            return advantages, returns

        # Quantities belonging to step t+1, seeded with the bootstrap value.
        following_value = next_value
        following_return = next_value
        last_gae_lambda = 0.0

        # Iterate backwards through the trajectory to compute GAE and returns
        for t in reversed(range(num_steps)):
            # dones[t] belongs to the transition taken at t, so it (and not
            # dones[t+1]) decides whether step t may bootstrap from what
            # follows. When it is set, values[t+1] is the first state of a
            # *new* episode and must not leak into this one.
            mask = non_terminal[t]

            # Temporal-difference error
            delta = rewards[t] + GAMMA * following_value * mask - values[t]
            # GAE is an exponentially weighted sum of the TD errors
            last_gae_lambda = delta + GAMMA * GAE_LAMBDA * mask * last_gae_lambda
            advantages[t] = last_gae_lambda
            # Discounted return, bootstrapped with next_value at the batch edge
            returns[t] = rewards[t] + GAMMA * following_return * mask

            following_value = values[t]
            following_return = returns[t]

        # Advantages are normalized to stabilize training
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        return advantages, returns

    def learn(self, next_state):
        """
        Run the PPO update on the collected batch, then clear the batch.

        Args:
            next_state: NumPy observation that follows the last stored
                transition; its critic value bootstraps the final step.
        """
        print("--- PPO Update Cycle Started ---")

        # Get the value of the state that follows the batch
        with torch.no_grad():
            next_state_tensor = torch.from_numpy(next_state).float().unsqueeze(0)
            _, next_value = self.model(next_state_tensor)
            next_value = next_value.item()

        # 1. Calculate Advantages (GAE) and Returns (Targets)
        advantages, returns = self._compute_advantages_and_returns(next_value)

        # 2. Convert collected data to Tensors
        states = torch.tensor(np.array(self.batch_data['states']), dtype=torch.float)
        actions = torch.tensor(np.array(self.batch_data['actions']), dtype=torch.long)
        old_log_probs = torch.tensor(np.array(self.batch_data['log_probs']), dtype=torch.float)
        # Column vectors so they line up with the critic's (N, 1) output
        returns = torch.tensor(returns, dtype=torch.float).unsqueeze(1)
        advantages = torch.tensor(advantages, dtype=torch.float).unsqueeze(1)

        # 3. PPO Epochs
        data_size = len(states)
        indices = np.arange(data_size)

        for _ in range(PPO_EPOCHS):
            # Shuffle indices for stochastic gradient descent
            np.random.shuffle(indices)

            # Iterate through mini-batches (the last one may be smaller)
            for start in range(0, data_size, BATCH_SIZE):
                end = start + BATCH_SIZE
                batch_indices = indices[start:end]

                # Extract mini-batch
                b_states = states[batch_indices]
                b_actions = actions[batch_indices]
                b_old_log_probs = old_log_probs[batch_indices].unsqueeze(1)
                b_returns = returns[batch_indices]
                b_advantages = advantages[batch_indices]

                # Get new action probabilities and state values
                action_probs, new_values = self.model(b_states)

                # Calculate new log probabilities and entropy
                dist = torch.distributions.Categorical(action_probs)
                new_log_probs = dist.log_prob(b_actions).unsqueeze(1)
                entropy = dist.entropy().mean()

                # --- PPO Loss Calculation ---

                # Probability ratio r_t = pi_new(a|s) / pi_old(a|s)
                ratio = torch.exp(new_log_probs - b_old_log_probs)

                # 1. Actor Loss (Policy Loss)
                surr1 = ratio * b_advantages
                surr2 = torch.clamp(ratio, 1.0 - CLIP_EPSILON, 1.0 + CLIP_EPSILON) * b_advantages
                # Taking the minimum penalizes large policy updates
                policy_loss = -torch.min(surr1, surr2).mean()

                # 2. Critic Loss (Value Loss)
                # Mean squared error between the predicted value and the
                # bootstrapped discounted return target
                value_loss = (new_values - b_returns).pow(2).mean()

                # 3. Total Loss
                # The entropy bonus encourages exploration and prevents premature convergence.
                total_loss = policy_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy

                # Optimization step
                self.optimizer.zero_grad()
                total_loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 0.5) # Gradient clipping for stability
                self.optimizer.step()

        # 4. Clear the batch data after the update
        self.batch_data = self._empty_batch()

        print("--- PPO Update Cycle Finished ---")


# --- 3. Main Training Loop ---
def main():
    """
    Train a `PPOAgent` on CartPole-v1.

    Alternates between collecting up to `TIMESTEPS_PER_BATCH` environment
    steps and running one PPO update, until either `MAX_TIMESTEPS` steps
    have been taken or the mean score of the last 100 completed episodes
    reaches the environment's reward threshold.
    """
    # Use a standard Gym environment
    env = gym.make("CartPole-v1")
    agent = PPOAgent(env)

    # Track episode scores for monitoring
    scores = deque(maxlen=100)

    # Prefer the threshold registered with the environment; 475.0 is the
    # value Gymnasium registers for CartPole-v1 (195.0 belongs to v0).
    spec = getattr(env, "spec", None)
    solved_score = getattr(spec, "reward_threshold", None)
    if solved_score is None:
        solved_score = 475.0

    total_timesteps = 0
    i_episode = 0
    solved = False

    print("Starting PPO training on CartPole-v1...")

    try:
        # Reset once: episodes carry over batch boundaries instead of being
        # discarded whenever a new batch starts.
        state, _ = env.reset()
        episode_reward = 0

        while total_timesteps < MAX_TIMESTEPS and not solved:
            for _ in range(TIMESTEPS_PER_BATCH):
                # 1. Interact with the environment
                action, log_prob, value = agent.select_action(state)

                next_state, reward, terminated, truncated, info = env.step(action)
                # NOTE: time-limit truncation is treated like termination,
                # so the value of a truncated state is not bootstrapped.
                done = terminated or truncated

                # 2. Store the transition
                agent.store_transition(state, action, reward, log_prob, done, value)

                state = next_state
                episode_reward += reward
                total_timesteps += 1

                # If episode ends, record the score and reset the environment
                if done:
                    i_episode += 1
                    scores.append(episode_reward)
                    mean_score = np.mean(scores)

                    print(f"Episode {i_episode:5d} | Timestep {total_timesteps:7d} | Score: {episode_reward:5.1f} | 100-Avg: {mean_score:5.2f}")

                    # Reset environment for the next episode
                    state, _ = env.reset()
                    episode_reward = 0

                    # Only declare success once the window really holds
                    # 100 episodes; a single lucky episode is not enough.
                    if len(scores) == scores.maxlen and mean_score >= solved_score:
                        print(f"\n*** ENVIRONMENT SOLVED! Average score >= {solved_score:.1f} over 100 episodes. ***\n")
                        # Leave the collection loop; one last PPO update
                        # runs below and the outer loop then stops.
                        solved = True
                        break

            # 3. Perform the PPO Update after collecting the batch of data.
            # `state` is the state the next interaction step would start from.
            agent.learn(state)
    finally:
        # Release the environment even if training raises or is interrupted.
        env.close()

if __name__ == '__main__':
    # Seed the PyTorch and NumPy global RNGs. PyTorch's RNG drives weight
    # initialisation and action sampling; NumPy's drives mini-batch shuffling.
    # NOTE(review): main() calls env.reset() without a seed, so environment
    # start states are not covered by these seeds — confirm whether full
    # reproducibility is required.
    torch.manual_seed(42)
    np.random.seed(42)

    # Run the training loop (data collection alternating with PPO updates).
    main()

Starting PPO training on CartPole-v1...
Episode     1 | Timestep      20 | Score:  20.0 | 100-Avg: 20.00
Episode     1 | Timestep      34 | Score:  14.0 | 100-Avg: 17.00
Episode     1 | Timestep      50 | Score:  16.0 | 100-Avg: 16.67
Episode     1 | Timestep      86 | Score:  36.0 | 100-Avg: 21.50
Episode     1 | Timestep     133 | Score:  47.0 | 100-Avg: 26.60
Episode     1 | Timestep     147 | Score:  14.0 | 100-Avg: 24.50
Episode     1 | Timestep     167 | Score:  20.0 | 100-Avg: 23.86
Episode     1 | Timestep     180 | Score:  13.0 | 100-Avg: 22.50
Episode     1 | Timestep     229 | Score:  49.0 | 100-Avg: 25.44
Episode     1 | Timestep     241 | Score:  12.0 | 100-Avg: 24.10
Episode     1 | Timestep     267 | Score:  26.0 | 100-Avg: 24.27
Episode     1 | Timestep     287 | Score:  20.0 | 100-Avg: 23.92
Episode     1 | Timestep     299 | Score:  12.0 | 100-Avg: 23.00
Episode     1 | Timestep     336 | Score:  37.0 | 100-Avg: 24.00
Episode     1 | Timestep     376 | Score:  40.0 | 