# Lab 09 – Policy Gradient & Actor-Critic Methods Starter Notebook

## Overview
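This lab introduces policy gradient techniques for problems where value-based methods struggle, such as continuous or very large action spaces and tasks whose optimal policy is stochastic. Students implement REINFORCE, then add a baseline or a simple actor-critic variant to reduce the variance of the gradient estimate.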

## Objectives
- Derive and implement the REINFORCE algorithm with softmax policies.
- Incorporate baseline functions to reduce gradient variance.
- Extend the implementation to an actor-critic architecture in which the policy and value heads share parameters.
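
As a reference point for the first two objectives, the quantity that REINFORCE estimates from sampled episodes can be written as below. The notation is an assumption, since this handout fixes none: $\theta$ denotes the policy parameters, $r_k$ the reward received after action $a_k$, $G_t$ the discounted return from step $t$, $b$ a baseline, and $h_\theta$ the action preferences fed to the softmax.

$$
\nabla_\theta J(\theta) \approx \mathbb{E}_{\pi_\theta}\left[\sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\,\bigl(G_t - b(s_t)\bigr)\right],
\qquad
G_t = \sum_{k=t}^{T-1} \gamma^{\,k-t}\, r_k
$$

$$
\pi_\theta(a \mid s) = \frac{\exp h_\theta(s, a)}{\sum_{a'} \exp h_\theta(s, a')},
\qquad
\nabla_\theta \log \pi_\theta(a \mid s) = \nabla_\theta h_\theta(s, a) - \sum_{a'} \pi_\theta(a' \mid s)\, \nabla_\theta h_\theta(s, a')
$$

Subtracting $b(s_t)$ leaves the expectation unchanged because $\sum_a \pi_\theta(a \mid s)\, \nabla_\theta \log \pi_\theta(a \mid s) = \nabla_\theta \sum_a \pi_\theta(a \mid s) = 0$, so any baseline that does not depend on the action only affects variance. In the discounted episodic setting the exact gradient carries an additional factor $\gamma^t$ inside the sum, which is why the first relation is written as an approximation; practical implementations commonly drop that factor.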

## Pre-Lab Review
- Study [`old content/RL_Section8_pdf.pdf`](../../old%20content/RL_Section8_pdf.pdf) for policy gradient theorem derivations.
- Review any policy gradient code snippets archived in [`old content/ALL_WEEKS_V5 - Student.ipynb`](../../old%20content/ALL_WEEKS_V5%20-%20Student.ipynb).

## In-Lab Exercises
1. Implement the REINFORCE algorithm for a discrete-action environment (e.g., CartPole with a softmax policy).
2. Add a learned baseline or value function to create a variance-reduced estimator.
3. Explore an actor-critic setup with simultaneous policy and value updates.
4. Compare learning curves with and without baselines.
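
Exercises 2 and 4 rely on two quantities: the discounted return at each step and that return minus a baseline. The following is a minimal, dependency-free sketch of both, not the lab's reference solution. The function names `discounted_returns` and `baseline_advantages` are hypothetical, and the baseline values are plain numbers here; in the lab they would presumably come from a learned value function.

```python
def discounted_returns(rewards, gamma=0.99):
    """Return [G_0, ..., G_{T-1}] where G_t = r_t + gamma * G_{t+1}."""
    returns = []
    running = 0.0
    # Accumulate backwards in time so each return reuses the one after it.
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def baseline_advantages(returns, baseline_values):
    """Subtract a per-step baseline b(s_t) from each return G_t."""
    if len(returns) != len(baseline_values):
        raise ValueError("returns and baseline_values must have equal length")
    return [g - b for g, b in zip(returns, baseline_values)]


# Toy episode: three steps with reward 1 each and gamma = 0.5.
rewards = [1.0, 1.0, 1.0]
returns = discounted_returns(rewards, gamma=0.5)
# Assumed constant baseline of 1.5 for every state (illustrative only).
advantages = baseline_advantages(returns, [1.5, 1.5, 1.5])
print(returns)     # [1.75, 1.5, 1.0]
print(advantages)  # [0.25, 0.0, -0.5]
```

In a variance-reduced estimator, the advantages take the place of the raw returns as the weights on the log-probabilities, while the baseline itself is trained separately to predict the returns.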

## Deliverables
- Notebook containing REINFORCE and actor-critic implementations, plus comparison plots.
- A short written reflection on the sources of gradient variance and strategies for stabilizing training.
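
Per-episode rewards from policy gradient methods are noisy, so comparison plots are usually easier to read after smoothing. One possible helper is a trailing moving average, sketched below. The name `moving_average` and the default window of 20 episodes are assumptions chosen for illustration, not requirements of the lab.

```python
def moving_average(values, window=20):
    """Trailing mean over the most recent `window` values.

    The first few outputs average over fewer than `window` values, so the
    smoothed curve has the same length as the input.
    """
    if window < 1:
        raise ValueError("window must be at least 1")
    smoothed = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        if i >= window:
            # Drop the value that just left the window.
            running_sum -= values[i - window]
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed


print(moving_average([1, 2, 3, 4], window=2))  # [1.0, 1.5, 2.5, 3.5]
```

Applying the same window to the curves with and without a baseline keeps the comparison fair. Averaging over several random seeds before smoothing gives a more reliable picture than a single run.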

## Resources
- [`old content/optimal.png`](../../old%20content/optimal.png) for a visual discussion of optimal policies.
- Suggested reading: Sutton & Barto Chapter 13 or equivalent policy gradient tutorials.

### Policy Gradient (REINFORCE) Skeleton
Connect this template with the policy-gradient reading in `old content/RL_Section8_pdf.pdf`.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

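try:
    import gymnasium as gym
except ImportError as exc:
    # Fail early with an actionable message and keep the original error chained.
    raise ImportError("Install gymnasium to run REINFORCE (e.g. `pip install gymnasium`).") from exc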

class PolicyNetwork(nn.Module):
    """Two-layer MLP that maps an observation to action probabilities via softmax."""
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
            nn.Linear(128, act_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        return self.net(x)

def reinforce(env_name='CartPole-v1', episodes=500, gamma=0.99):
    """Train a softmax policy with Monte Carlo REINFORCE and return the trained network."""
    env = gym.make(env_name)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    policy = PolicyNetwork(obs_dim, act_dim)
    optimizer = optim.Adam(policy.parameters(), lr=1e-3)

    for episode in range(episodes):
        log_probs = []
        rewards = []
        state, _ = env.reset()
        done = False
        while not done:
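            state_tensor = torch.as_tensor(state, dtype=torch.float32)
            action_probs = policy(state_tensor)
            # Sample from the softmax policy and keep the log-probability,
            # which carries the gradient used at the end of the episode.
            dist = torch.distributions.Categorical(probs=action_probs)
            action = dist.sample()
            log_probs.append(dist.log_prob(action))
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            rewards.append(reward)
            # The episode ends on a terminal state or on the time limit.
            done = terminated or truncated
            state = next_state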

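        # Compute discounted returns G_t by accumulating rewards backwards in time.
        returns = []
        G = 0.0
        for reward in reversed(rewards):
            G = reward + gamma * G
            returns.append(G)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32)
        # Standardize returns to reduce variance; skip single-step episodes,
        # where the sample standard deviation is undefined (NaN).
        if len(returns) > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)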

        # REINFORCE loss: negative log-probability of each action, weighted by its return.
        loss = -(torch.stack(log_probs) * returns).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (episode + 1) % 20 == 0:
            print(f"Episode {episode + 1}, total reward: {sum(rewards):.1f}")

    env.close()
    return policy

# trained_policy = reinforce(episodes=200)
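
For Exercise 3, the core of a one-step actor-critic update is sketched below: the critic's TD error replaces the Monte Carlo return as the weight on the log-probability gradient. The sketch is deliberately tabular and dependency-free, so it is not a drop-in extension of the PyTorch skeleton above. The function names, the list-based `theta` and `values` tables, and the learning rates are all assumptions made for illustration.

```python
import math


def softmax(preferences):
    """Numerically stable softmax over a list of action preferences."""
    peak = max(preferences)
    exps = [math.exp(p - peak) for p in preferences]
    total = sum(exps)
    return [e / total for e in exps]


def actor_critic_step(theta, values, s, a, r, s_next, done,
                      gamma=0.99, actor_lr=0.1, critic_lr=0.1):
    """Apply one one-step actor-critic update in place; return the TD error.

    theta[s][k] holds tabular action preferences (the actor);
    values[s] holds the state-value estimate (the critic).
    """
    # Do not bootstrap from the next state once the episode has terminated.
    bootstrap = 0.0 if done else values[s_next]
    td_error = r + gamma * bootstrap - values[s]

    # Critic: move V(s) toward the one-step TD target.
    values[s] += critic_lr * td_error

    # Actor: for a tabular softmax, d log pi(a|s) / d theta[s][k]
    # equals 1[k == a] - pi(k|s).
    probs = softmax(theta[s])
    for k in range(len(theta[s])):
        indicator = 1.0 if k == a else 0.0
        theta[s][k] += actor_lr * td_error * (indicator - probs[k])
    return td_error


# Two states, two actions, everything initialized to zero.
theta = [[0.0, 0.0], [0.0, 0.0]]
values = [0.0, 0.0]
delta = actor_critic_step(theta, values, s=0, a=1, r=1.0, s_next=1,
                          done=False, gamma=0.5, actor_lr=0.5, critic_lr=0.5)
print(delta)   # 1.0
print(values)  # [0.5, 0.0]
print(theta)   # [[-0.25, 0.25], [0.0, 0.0]]
```

When adapting this idea to a Gymnasium environment, it is usually more accurate to pass `done=terminated` rather than `terminated or truncated`, because a time-limit truncation does not mean the next state has zero value.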
