In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability, report Python/PyTorch versions, and seed the RNGs.
# NOTE: nothing is installed here; torch and numpy must already be importable.

import torch
import sys

# Check GPU: select the CUDA device when one is visible, otherwise fall back to CPU.
# NOTE(review): `device` is not referenced by the later cells visible in this notebook;
# the networks below are created on PyTorch's default device.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; dividing by 1e9 prints (decimal) gigabytes.
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

# sys.version starts with the interpreter version number; keep only that first token.
print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

# One seed value shared by Python's `random`, NumPy's legacy global RNG, and PyTorch.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Also seed every CUDA device's generator when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Actor-Critic Methods: The Student-Teacher Framework

*Part 3 of the Vizuara series on Policy Gradient Methods*
*Estimated time: 55 minutes*

## 1. Why Does This Matter?

In REINFORCE with baseline, we used the mean return as a simple baseline. But we can do much better. What if the baseline could **learn** to predict how good each state is?

This is the Actor-Critic architecture — one of the most important ideas in modern reinforcement learning. The "actor" is the policy (it picks actions), and the "critic" is a value network (it evaluates how good each state is). Together, they form a powerful learning system.

Actor-Critic methods are the foundation of PPO, the algorithm used to fine-tune ChatGPT. They are also used in robotics (learning to walk, manipulate objects) and game playing.

By the end of this notebook, you will:
- Build an Actor-Critic agent with separate policy and value networks
- Train both networks simultaneously on CartPole
- Compare Actor-Critic against vanilla REINFORCE
- Visualize the learned value function
- See training converge significantly faster

## 2. Building Intuition

Think of a student and a teacher. The student (actor) takes an exam and gives answers. The teacher (critic) grades the exam — but not just right/wrong. The teacher says "this answer was 3 points above average" or "this was 2 points below average."

This relative feedback is far more useful than raw scores. If every student scores between 80-100, saying "you got 90" does not tell you much. But saying "you were 5 points above the class average" tells you exactly how well you did relative to expectations.

The advantage function is this relative feedback:

$A(s, a) = Q(s, a) - V(s)$

"How much better (or worse) was action $a$ compared to what we expected in state $s$?"

The critic learns $V(s)$ (what to expect). The actor uses the advantage (actual outcome minus expectation) to improve.

### Think About This

In CartPole, a state where the pole is nearly vertical is "good" (high V). A state where the pole has tilted close to the failure angle is "bad" (low V), because the episode ends once the pole leans more than about 12° from vertical. If the actor takes an action that keeps the pole balanced from a bad state, what sign would the advantage have? What would this tell the actor?

## 3. The Mathematics

### 3.1 The Actor-Critic Update

The actor update uses the advantage:

$$\theta \leftarrow \theta + \alpha_a \sum_t \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot A_t$$

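where $A_t = G_t - V_\phi(s_t)$. Here the sampled return $G_t$ stands in for $Q(s_t, a_t)$ from the definition above: it is a single-episode (Monte Carlo) estimate of it.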

The critic update minimizes the prediction error:

$$\phi \leftarrow \phi - \alpha_c \nabla_\phi \sum_t (G_t - V_\phi(s_t))^2$$

Let us plug in numbers. Suppose at time step $t$:
- The actual return is $G_t = 15.0$
- The critic predicts $V(s_t) = 12.0$
- So the advantage is $A_t = 15.0 - 12.0 = 3.0$

This positive advantage tells the actor: "This action was 3 units better than expected. Increase its probability."

The critic sees its error: it predicted 12 but the truth was 15. So it adjusts to predict higher values for similar states.

### 3.2 Why Two Networks?

The actor and critic have different jobs:
- **Actor** outputs a probability distribution over actions (softmax output)
- **Critic** outputs a single scalar value estimate

They share the same input (state) but produce fundamentally different outputs. Using two networks lets each specialize.

## 4. Let's Build It — Component by Component

### 4.1 The Actor Network

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym

# Re-seed PyTorch and NumPy with 42 (the same value as SEED in the setup cell) so the
# randomly initialised networks and random test inputs below start from a fixed RNG state.
torch.manual_seed(42)
np.random.seed(42)

class Actor(nn.Module):
    """Policy network: maps a state to a probability distribution over actions.

    Architecture: Linear(state_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> n_actions),
    followed by a softmax over the last dimension.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        # Layers are created input-side first so weight initialisation draws from the
        # RNG in the same order as the layers appear in the network.
        hidden_layer = nn.Linear(state_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, n_actions)
        self.net = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, state):
        """Return action probabilities (softmax over the last dimension) for `state`."""
        return F.softmax(self.net(state), dim=-1)

    def sample_action(self, state):
        """Sample one action for a single state.

        Returns:
            (action, log_prob): the sampled action as a Python int, and the log
            probability of that action as a tensor that stays attached to the
            autograd graph so it can be used in the policy-gradient loss.
        """
        policy = torch.distributions.Categorical(self.forward(state))
        chosen = policy.sample()
        # .item() requires a single-element tensor, i.e. an unbatched state.
        return chosen.item(), policy.log_prob(chosen)

# Test actor: feed one random 4-dimensional state through a freshly initialised
# policy and confirm the two action probabilities sum to 1.
actor = Actor(state_dim=4, n_actions=2)
test_state = torch.randn(4)
probs = actor(test_state)
# detach() drops the autograd graph so the probabilities can be converted to NumPy.
print(f"Actor output (action probs): {probs.detach().numpy().round(4)}")
print(f"Sum: {probs.sum().item():.4f}")

### 4.2 The Critic Network

In [None]:
class Critic(nn.Module):
    """Value network: maps a state to a single scalar estimate V(s).

    Architecture: Linear(state_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> 1).
    """

    def __init__(self, state_dim, hidden_dim=128):
        super().__init__()
        # Layers are created input-side first so weight initialisation draws from the
        # RNG in the same order as the layers appear in the network.
        hidden_layer = nn.Linear(state_dim, hidden_dim)
        value_head = nn.Linear(hidden_dim, 1)
        self.net = nn.Sequential(hidden_layer, nn.ReLU(), value_head)

    def forward(self, state):
        """Return V(state) with the trailing size-1 output dimension removed.

        A single state yields a 0-dim tensor; a batch of states yields one value per row.
        """
        raw_value = self.net(state)
        return raw_value.squeeze(-1)

# Test critic: evaluate one random 4-dimensional state with a freshly initialised
# value network. The module-level `critic` created here is reused by the advantage
# example in section 4.3.
critic = Critic(state_dim=4)
test_state = torch.randn(4)
value = critic(test_state)
print(f"Critic output (state value): {value.item():.4f}")

In [None]:
# Visualization: Actor and Critic architecture
def _draw_network_diagram(ax, layer_labels, layer_colors, title):
    """Draw one network as a left-to-right row of labelled boxes joined by arrows.

    Args:
        ax: Matplotlib axes to draw on.
        layer_labels: Text for each box, in left-to-right order.
        layer_colors: Face colour for each box (paired with layer_labels via zip).
        title: Title shown above the diagram.
    """
    last_index = len(layer_labels) - 1
    for i, (label, color) in enumerate(zip(layer_labels, layer_colors)):
        left = i * 2  # boxes are 1.5 wide and placed every 2 units, leaving a 0.5 gap
        rect = plt.Rectangle((left, 0), 1.5, 1, facecolor=color, edgecolor='black', linewidth=1.5)
        ax.add_patch(rect)
        ax.text(left + 0.75, 0.5, label, ha='center', va='center', fontsize=10, fontweight='bold')
        if i < last_index:
            # Arrow across the gap: from this box's right edge to the next box's left edge.
            ax.annotate('', xy=(left + 2, 0.5), xytext=(left + 1.5, 0.5),
                        arrowprops=dict(arrowstyle='->', lw=2, color='gray'))
    ax.set_xlim(-0.5, 6.5)
    ax.set_ylim(-0.5, 1.5)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.axis('off')


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Actor diagram
layers_actor = ['State\n(4 dims)', 'Hidden\n(128 ReLU)', 'Action Probs\n(2 softmax)']
colors_actor = ['#f97316', '#fb923c', '#fdba74']
_draw_network_diagram(ax1, layers_actor, colors_actor, 'Actor Network (Policy)')

# Critic diagram
layers_critic = ['State\n(4 dims)', 'Hidden\n(128 ReLU)', 'Value\n(1 scalar)']
colors_critic = ['#3b82f6', '#60a5fa', '#93c5fd']
_draw_network_diagram(ax2, layers_critic, colors_critic, 'Critic Network (Value)')

plt.tight_layout()
plt.show()
print("The Actor selects actions. The Critic evaluates states.")

### 4.3 Computing Returns and Advantages

In [None]:
# Discount factor applied to future rewards.
GAMMA = 0.99

def compute_returns(rewards, gamma=GAMMA):
    """Compute discounted returns G_t for each timestep.

    Uses the backward recursion G_t = r_t + gamma * G_{t+1}, with the return after
    the final step taken to be 0.

    Args:
        rewards: Sequence of per-step rewards in chronological order.
        gamma: Discount factor.

    Returns:
        A list the same length as `rewards`, where entry t is G_t. An empty
        `rewards` sequence yields an empty list.
    """
    returns = []
    G = 0
    # Walk the episode backwards, appending as we go, then flip once at the end.
    # This is linear time, whereas inserting at the front of the list on every
    # step makes the whole computation quadratic in episode length.
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    return returns

def compute_advantages(returns, states, critic):
    """Compute advantages A_t = G_t - V(s_t).

    Args:
        returns: Per-step discounted returns G_t.
        states: Per-step state observations, one entry per return.
        critic: Callable mapping a float32 batch of states to one value per state.

    Returns:
        (advantages, returns_t, values): the advantages, the returns as a float32
        tensor, and the critic's predictions. None of them carries gradients from
        the critic, because the predictions are computed under torch.no_grad().
    """
    state_batch = torch.tensor(np.array(states), dtype=torch.float32)
    target_returns = torch.tensor(returns, dtype=torch.float32)

    # The baseline is treated as a constant here, so no autograd graph is recorded.
    with torch.no_grad():
        baseline = critic(state_batch)

    return target_returns - baseline, target_returns, baseline

# Example: three hand-picked returns paired with three identical all-zero states.
# Relies on the module-level `critic` created in the "Test critic" cell above; since
# the states are identical, every entry of `vals` is the same number.
dummy_returns = [5.0, 4.0, 3.0]
dummy_states = [np.zeros(4) for _ in range(3)]
advs, rets, vals = compute_advantages(dummy_returns, dummy_states, critic)
print(f"Returns:     {rets.numpy().round(2)}")
print(f"Values V(s): {vals.numpy().round(2)}")
print(f"Advantages:  {advs.numpy().round(2)}")

## 5. Your Turn

### TODO: Implement the Actor-Critic Training Step

In [None]:
def actor_critic_step(actor, critic, actor_optimizer, critic_optimizer,
                      states, log_probs, rewards):
    """
    Perform one Actor-Critic update step after collecting an episode.

    Args:
        actor: Actor network
        critic: Critic network
        actor_optimizer: Optimizer for actor
        critic_optimizer: Optimizer for critic
        states: list of state observations
        log_probs: list of log probabilities (from action sampling)
        rewards: list of rewards

    Returns:
        actor_loss: float, the actor loss value
        critic_loss: float, the critic loss value

    Raises:
        NotImplementedError: if the TODO section below has not been completed,
            i.e. `actor_loss` / `critic_loss` were never assigned.
    """
    # Step 1: Compute returns
    returns = compute_returns(rewards)
    returns_t = torch.tensor(returns, dtype=torch.float32)
    states_t = torch.tensor(np.array(states), dtype=torch.float32)

    # Placeholders that your TODO code must overwrite. They let the function report
    # a clear "exercise not finished" error instead of an opaque NameError.
    actor_loss = None
    critic_loss = None

    # ============ TODO ============
    # Step 2: Get value predictions from critic
    # values = critic(states_t)
    #
    # Step 3: Compute advantages (returns - values.detach())
    # advantages = ???
    #
    # Step 4: Actor loss = -sum(log_probs * advantages)
    # log_probs_t = torch.stack(log_probs)
    # actor_loss = ???
    #
    # Step 5: Critic loss = MSE(returns, values)
    # critic_loss = ???
    #
    # Step 6: Update actor
    # actor_optimizer.zero_grad()
    # actor_loss.backward()
    # actor_optimizer.step()
    #
    # Step 7: Update critic
    # critic_optimizer.zero_grad()
    # critic_loss.backward()
    # critic_optimizer.step()
    # ==============================

    if actor_loss is None or critic_loss is None:
        raise NotImplementedError(
            "actor_critic_step is incomplete: fill in the TODO section so that "
            "actor_loss and critic_loss are computed."
        )

    return actor_loss.item(), critic_loss.item()

In [None]:
# Verification: run actor_critic_step once on a short rollout and check that it
# returns two Python floats. This cell only passes after the TODO in
# actor_critic_step has been completed.
env = gym.make("CartPole-v1")
test_actor = Actor(4, 2)
test_critic = Critic(4)
test_a_opt = torch.optim.Adam(test_actor.parameters(), lr=0.01)
test_c_opt = torch.optim.Adam(test_critic.parameters(), lr=0.01)

# Collect a test episode (at most 10 steps, fewer if the episode ends early).
# NOTE: `actions` is initialised but never filled; actor_critic_step does not take it.
states, actions, rewards, log_probs = [], [], [], []
state, _ = env.reset()
for _ in range(10):
    st = torch.as_tensor(state, dtype=torch.float32)
    action, lp = test_actor.sample_action(st)
    ns, r, term, trunc, _ = env.step(action)
    # Store the state the action was taken FROM, not the resulting next state.
    states.append(state)
    log_probs.append(lp)
    rewards.append(r)
    state = ns
    if term or trunc:
        break

a_loss, c_loss = actor_critic_step(test_actor, test_critic, test_a_opt, test_c_opt,
                                    states, log_probs, rewards)
assert isinstance(a_loss, float), "Actor loss should be a float"
assert isinstance(c_loss, float), "Critic loss should be a float"
print(f"Actor loss: {a_loss:.4f}")
print(f"Critic loss: {c_loss:.4f}")
print("Correct! Actor-Critic training step works.")
env.close()

### TODO: Implement Episode Collection with Value Tracking

In [None]:
def collect_episode_with_values(env, actor, critic):
    """
    Collect an episode and also record the critic's value estimates.

    Args:
        env: Environment whose reset() returns (state, info) and whose step(action)
            returns (next_state, reward, terminated, truncated, info).
        actor: Policy used to sample actions.
        critic: Value network used to score each visited state.

    Returns:
        states, actions, rewards, log_probs, values

    Raises:
        NotImplementedError: until the TODO section below has been implemented.
    """
    states, actions, rewards, log_probs, values = [], [], [], [], []
    state, _ = env.reset()
    done = False

    while not done:
        state_t = torch.as_tensor(state, dtype=torch.float32)

        # ============ TODO ============
        # Step 1: Get action and log_prob from actor
        # Step 2: Get value estimate from critic (detached)
        # Step 3: Step the environment
        # Step 4: Store everything
        # ==============================

        # YOUR CODE HERE: replace the `raise` below with your implementation and make
        # sure it updates `state` and `done`. Failing loudly here matters: a bare
        # `pass` would leave `done` False forever and hang the kernel in this loop.
        raise NotImplementedError(
            "collect_episode_with_values is incomplete: implement the TODO section."
        )

    return states, actions, rewards, log_probs, values

In [None]:
# Verification: collect one episode and check that all five returned lists have the
# same length and that every recorded value is a plain Python float.
# NOTE: this rebinds the module-level `actor` and `critic` to fresh, untrained networks.
env = gym.make("CartPole-v1")
actor = Actor(4, 2)
critic = Critic(4)

s, a, r, lp, v = collect_episode_with_values(env, actor, critic)
assert len(s) == len(a) == len(r) == len(lp) == len(v)
assert all(isinstance(vi, float) for vi in v)
print(f"Collected episode with {len(s)} steps")
print(f"First 5 values: {[f'{vi:.2f}' for vi in v[:5]]}")
print("Correct! Episode collection with value tracking works.")
env.close()

## 6. Putting It All Together

In [None]:
def train_actor_critic(env_name="CartPole-v1", num_episodes=500, lr_actor=0.01, lr_critic=0.01):
    """Full Actor-Critic training loop.

    Each iteration rolls out one complete episode with the current policy, computes
    discounted returns, and then performs one gradient step on the actor and one on
    the critic.

    Args:
        env_name: Gymnasium environment id. The code assumes a 1-D observation space
            and a discrete action space (it reads `observation_space.shape[0]` and
            `action_space.n`).
        num_episodes: Number of episodes (and therefore updates) to run.
        lr_actor: Adam learning rate for the actor.
        lr_critic: Adam learning rate for the critic.

    Returns:
        (reward_history, actor_losses, critic_losses, actor, critic): per-episode
        total rewards, per-episode actor and critic loss values, and the two
        trained networks.
    """
    env = gym.make(env_name)
    # try/finally guarantees the environment is released even if a step raises
    # or the run is interrupted.
    try:
        state_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n

        actor = Actor(state_dim, n_actions)
        critic = Critic(state_dim)
        actor_opt = torch.optim.Adam(actor.parameters(), lr=lr_actor)
        critic_opt = torch.optim.Adam(critic.parameters(), lr=lr_critic)

        reward_history = []
        actor_losses = []
        critic_losses = []

        for episode in range(num_episodes):
            # Collect episode
            states, rewards, log_probs = [], [], []
            state, _ = env.reset()
            done = False

            while not done:
                state_t = torch.as_tensor(state, dtype=torch.float32)
                action, log_prob = actor.sample_action(state_t)
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                # Record the state the action was taken from, then advance.
                states.append(state)
                rewards.append(reward)
                log_probs.append(log_prob)
                state = next_state

            # Compute returns and update
            returns = compute_returns(rewards)
            returns_t = torch.tensor(returns, dtype=torch.float32)
            states_t = torch.tensor(np.array(states), dtype=torch.float32)

            # Critic values and advantages. The advantage uses a detached copy of the
            # values so the actor loss does not backpropagate into the critic.
            values = critic(states_t)
            advantages = (returns_t - values.detach())

            # Normalize advantages (skipped for one-step episodes, where the
            # standard deviation is undefined).
            if len(advantages) > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            # Actor loss: negative sign turns gradient ascent on the objective into
            # a loss to minimise.
            log_probs_t = torch.stack(log_probs)
            actor_loss = -(log_probs_t * advantages).sum()

            # Critic loss: regress predicted values onto the observed returns.
            critic_loss = F.mse_loss(values, returns_t)

            # Update actor
            actor_opt.zero_grad()
            actor_loss.backward()
            actor_opt.step()

            # Update critic
            critic_opt.zero_grad()
            critic_loss.backward()
            critic_opt.step()

            ep_reward = sum(rewards)
            reward_history.append(ep_reward)
            actor_losses.append(actor_loss.item())
            critic_losses.append(critic_loss.item())

            # Progress report every 100 episodes, averaged over the last 100.
            if (episode + 1) % 100 == 0:
                avg = np.mean(reward_history[-100:])
                print(f"Ep {episode+1:4d} | Avg Reward: {avg:.1f} | "
                      f"Actor Loss: {np.mean(actor_losses[-100:]):.2f} | "
                      f"Critic Loss: {np.mean(critic_losses[-100:]):.2f}")
    finally:
        env.close()

    return reward_history, actor_losses, critic_losses, actor, critic

# Train Actor-Critic for 500 episodes. The reward/loss histories feed the plots in
# section 7, and the trained networks feed the heatmaps in section 8.
print("=" * 60)
print("Training: Actor-Critic")
print("=" * 60)
ac_rewards, ac_actor_losses, ac_critic_losses, trained_actor, trained_critic = \
    train_actor_critic(num_episodes=500)

## 7. Training and Results

In [None]:
# Compare all three methods
# First, train REINFORCE variants for comparison
def train_reinforce_simple(num_episodes=500, use_baseline=False,
                           env_name="CartPole-v1", lr=0.01):
    """Train a REINFORCE policy, optionally with a mean-return baseline.

    Args:
        num_episodes: Number of episodes (one policy update per episode).
        use_baseline: If True, subtract the episode's mean return from every
            return before computing the loss.
        env_name: Gymnasium environment id. The code assumes a 1-D observation
            space and a discrete action space.
        lr: Adam learning rate for the policy.

    Returns:
        List of total (undiscounted) reward per episode.
    """
    env = gym.make(env_name)
    # try/finally guarantees the environment is released even if a step raises.
    try:
        # Size the policy from the environment instead of hard-coding CartPole's 4/2.
        state_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n
        policy = Actor(state_dim, n_actions)  # Reuse Actor class
        opt = torch.optim.Adam(policy.parameters(), lr=lr)
        rewards = []

        for ep in range(num_episodes):
            rews, log_probs = [], []
            state, _ = env.reset()
            done = False
            while not done:
                st = torch.as_tensor(state, dtype=torch.float32)
                action, lp = policy.sample_action(st)
                ns, r, term, trunc, _ = env.step(action)
                done = term or trunc
                rews.append(r)
                log_probs.append(lp)
                state = ns

            rets = compute_returns(rews)
            rets_t = torch.tensor(rets, dtype=torch.float32)
            if use_baseline:
                rets_t = rets_t - rets_t.mean()

            # Negative sign turns gradient ascent on expected return into a loss.
            loss = -(torch.stack(log_probs) * rets_t).sum()
            opt.zero_grad()
            loss.backward()
            opt.step()
            rewards.append(sum(rews))
    finally:
        env.close()
    return rewards

# Train both REINFORCE variants for the same 500 episodes used for Actor-Critic so
# the three reward curves can be plotted on a common episode axis.
print("Training REINFORCE (no baseline)...")
reinforce_rewards = train_reinforce_simple(500, use_baseline=False)
print("Training REINFORCE + Baseline...")
baseline_rewards = train_reinforce_simple(500, use_baseline=True)

In [None]:
# Grand comparison plot: moving-average episode reward for the three methods.
fig, ax = plt.subplots(figsize=(12, 6))
# Width (in episodes) of the moving-average smoothing window.
window = 30

for data, label, color, ls in [
    (reinforce_rewards, 'REINFORCE', '#ef4444', '-'),
    (baseline_rewards, 'REINFORCE + Baseline', '#f59e0b', '--'),
    (ac_rewards, 'Actor-Critic', '#3b82f6', '-'),
]:
    # Curves shorter than the window are skipped rather than plotted unsmoothed.
    if len(data) >= window:
        # 'valid' mode yields len(data) - window + 1 points, so the x-axis starts at
        # episode index window-1 (the first episode with a full window behind it).
        smoothed = np.convolve(data, np.ones(window)/window, mode='valid')
        ax.plot(range(window-1, len(data)), smoothed, label=label,
                color=color, linewidth=2.5, linestyle=ls)

# Reference line at 500, labelled 'Max Reward' (presumably CartPole-v1's per-episode
# cap; revisit this constant and the y-limit below if the environment changes).
ax.axhline(y=500, color='gray', linestyle=':', alpha=0.5, label='Max Reward')
ax.set_xlabel('Episode', fontsize=13)
ax.set_ylabel('Episode Reward', fontsize=13)
ax.set_title('Policy Gradient Methods: Head-to-Head Comparison', fontsize=15)
ax.legend(fontsize=12, loc='lower right')
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 550)
plt.tight_layout()
plt.show()

# Print statistics: mean reward over each method's final 100 episodes.
print("\n--- Final 100-Episode Average ---")
for name, data in [("REINFORCE", reinforce_rewards),
                    ("+ Baseline", baseline_rewards),
                    ("Actor-Critic", ac_rewards)]:
    avg = np.mean(data[-100:])
    print(f"  {name:14s}: {avg:.1f}")

In [None]:
# Actor-Critic loss curves: moving averages of the per-episode actor and critic losses.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
# Width (in episodes) of the moving-average smoothing window.
# NOTE(review): unlike the comparison plot, there is no length guard here; the x-range
# below only matches the smoothed series when at least `window` episodes were logged.
window = 30

# Actor loss
smoothed_actor = np.convolve(ac_actor_losses, np.ones(window)/window, mode='valid')
ax1.plot(range(window-1, len(ac_actor_losses)), smoothed_actor, color='#f97316', linewidth=2)
ax1.set_xlabel('Episode', fontsize=12)
ax1.set_ylabel('Actor Loss', fontsize=12)
ax1.set_title('Actor Loss Over Training', fontsize=14)
ax1.grid(True, alpha=0.3)

# Critic loss
smoothed_critic = np.convolve(ac_critic_losses, np.ones(window)/window, mode='valid')
ax2.plot(range(window-1, len(ac_critic_losses)), smoothed_critic, color='#3b82f6', linewidth=2)
ax2.set_xlabel('Episode', fontsize=12)
ax2.set_ylabel('Critic Loss (MSE)', fontsize=12)
ax2.set_title('Critic Loss Over Training', fontsize=14)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
# NOTE(review): these lines describe the expected trend; they are printed
# unconditionally and are not derived from the plotted data.
print("Watch how the critic loss decreases as V(s) gets more accurate,")
print("which in turn provides better advantage estimates for the actor.")

## 8. Final Output

In [None]:
# Visualize the learned value function and policy on a 2-D slice of the state space.
print("=" * 60)
print("VISUALIZING THE LEARNED VALUE FUNCTION")
print("=" * 60)

# Create a grid of states varying cart position and pole angle
# (50 evenly spaced samples along each axis).
cart_positions = np.linspace(-2.4, 2.4, 50)
pole_angles = np.linspace(-0.2, 0.2, 50)

# Rows (first index) follow pole angle and columns follow cart position, matching
# how imshow lays out the y and x axes below.
value_grid = np.zeros((50, 50))
for i, cp in enumerate(cart_positions):
    for j, pa in enumerate(pole_angles):
        # Only components 0 and 2 of the state vary; components 1 and 3 are held at 0.0.
        state = torch.tensor([cp, 0.0, pa, 0.0], dtype=torch.float32)
        with torch.no_grad():
            value_grid[j, i] = trained_critic(state).item()

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Value function heatmap
# origin='lower' puts the first row (most negative pole angle) at the bottom.
im = axes[0].imshow(value_grid, origin='lower', aspect='auto',
                     extent=[-2.4, 2.4, -0.2, 0.2],
                     cmap='RdYlGn')
axes[0].set_xlabel('Cart Position', fontsize=12)
axes[0].set_ylabel('Pole Angle (rad)', fontsize=12)
axes[0].set_title('Learned Value Function V(s)', fontsize=14)
plt.colorbar(im, ax=axes[0], label='Estimated Value')
# Dashed crosshair marking zero pole angle and zero cart position.
axes[0].axhline(y=0, color='white', linestyle='--', alpha=0.5)
axes[0].axvline(x=0, color='white', linestyle='--', alpha=0.5)

# Policy heatmap (probability of pushing right), evaluated on the same grid of states.
policy_grid = np.zeros((50, 50))
for i, cp in enumerate(cart_positions):
    for j, pa in enumerate(pole_angles):
        state = torch.tensor([cp, 0.0, pa, 0.0], dtype=torch.float32)
        with torch.no_grad():
            probs = trained_actor(state)
            policy_grid[j, i] = probs[1].item()  # P(push right)

# vmin/vmax pin the colour scale to the full probability range [0, 1].
im2 = axes[1].imshow(policy_grid, origin='lower', aspect='auto',
                      extent=[-2.4, 2.4, -0.2, 0.2],
                      cmap='coolwarm', vmin=0, vmax=1)
axes[1].set_xlabel('Cart Position', fontsize=12)
axes[1].set_ylabel('Pole Angle (rad)', fontsize=12)
axes[1].set_title('Learned Policy: P(push right)', fontsize=14)
plt.colorbar(im2, ax=axes[1], label='P(right)')
axes[1].axhline(y=0, color='black', linestyle='--', alpha=0.5)
axes[1].axvline(x=0, color='black', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

# NOTE(review): the two "plot" lines below describe the pattern a well-trained agent is
# expected to show; they are printed unconditionally and not checked against the grids.
print("\nLeft plot: States near center (cart at 0, pole upright) have highest value.")
print("Right plot: Policy pushes right when pole tilts right, left when pole tilts left.")
print("\nCongratulations! You have built Actor-Critic from scratch!")
print("This is the foundation of PPO, the algorithm used to fine-tune ChatGPT.")

## 9. Reflection and Next Steps

### Reflection Questions
1. Why do we normalize advantages before computing the actor loss? What happens if advantages are all large and positive?
2. The critic loss is MSE between predicted values and actual returns. Why is this a good loss function? What alternatives exist?
3. In Actor-Critic, the actor and critic share no parameters. What would happen if they shared the early layers of the network? What are the trade-offs?

### Optional Challenges
1. Implement a shared backbone: one network with shared hidden layers and two separate output heads (policy and value).
2. Add entropy regularization to encourage exploration: subtract $\beta H(\pi)$ from the actor loss.
3. Try the LunarLander-v3 environment (4 discrete actions, 8-dimensional state). Does Actor-Critic still outperform REINFORCE?