# Lab 08 – Function Approximation & Deep Q-Networks Starter Notebook

## Overview
Move beyond tabular methods by implementing a lightweight Deep Q-Network (DQN) for environments where state spaces are large or continuous. Students will focus on network design, replay buffers, and training stability.

## Objectives
- Configure a neural network to approximate action-value functions.
- Implement experience replay and target network updates.
- Monitor training curves and diagnose divergence or instability.
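
As a reference point for these objectives, the one-step target and loss that a basic DQN minimizes can be written as below. The notation is an assumption chosen to mirror the scaffold later in this notebook ($\theta$ for the online network, $\theta^-$ for the target network, $d$ for the terminal flag, $\mathcal{D}$ for the replay buffer). Treat it as a sketch of the usual formulation, not a definition taken from the course slides.

```latex
y = r + \gamma \,(1 - d)\, \max_{a'} Q_{\theta^-}(s', a'),
\qquad
\mathcal{L}(\theta) = \mathbb{E}_{(s,a,r,s',d) \sim \mathcal{D}}
  \left[ \big( Q_\theta(s,a) - y \big)^2 \right]
```

Because $\theta^-$ is only copied from $\theta$ periodically, the target $y$ stays fixed between syncs, which is the main reason the target network helps with stability.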

## Pre-Lab Review
- Watch a short DQN explainer video (instructor curated) to preview architecture choices.
- Review the comparative figure [`old content/DQN_vs_Q.png`](../../old%20content/DQN_vs_Q.png) to recall tabular vs. deep approaches.

## In-Lab Exercises
1. Set up the CartPole-v1 environment (or similar) using Gymnasium.
2. Implement a minimal DQN with replay buffer, target network, and ε-greedy exploration.
3. Train the agent (the scaffold below defaults to 200 episodes) while logging the loss, episode rewards, and the ε schedule.
4. Experiment with hyperparameters (learning rate, batch size, target update frequency) and document findings.
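
Before logging the ε schedule in step 3, it can help to see how quickly a per-episode multiplicative decay actually shrinks ε. The sketch below assumes the same defaults as the scaffold (start 1.0, floor 0.05, decay 0.995); the helper names `epsilon_at` and `episodes_to_floor` are hypothetical and not part of any library.

```python
import math


def epsilon_at(episode, start=1.0, floor=0.05, decay=0.995):
    """Epsilon in effect during `episode` (0-indexed) when epsilon is
    multiplied by `decay` after every episode and clipped at `floor`."""
    return max(floor, start * decay ** episode)


def episodes_to_floor(start=1.0, floor=0.05, decay=0.995):
    """Smallest episode index at which epsilon has reached `floor`."""
    return math.ceil(math.log(floor / start) / math.log(decay))


if __name__ == "__main__":
    for ep in (0, 50, 100, 200):
        print(f"episode {ep:4d}: epsilon = {epsilon_at(ep):.3f}")
    print("episodes until floor:", episodes_to_floor())
```

With these values ε is still roughly 0.37 after 200 episodes and needs close to 600 episodes to reach the 0.05 floor, so a short run stays fairly exploratory. That is worth keeping in mind when comparing hyperparameter settings in step 4.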

## Deliverables
- Notebook detailing the DQN implementation, training curves, and tuning experiments.
- Reflection on stabilization tricks (replay, target networks) and remaining challenges.
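
Raw per-episode rewards from DQN are usually noisy, so the training curves in the notebook tend to be easier to read after smoothing. One possible approach is a trailing moving average, sketched here with the standard library only. The function name `moving_average` and the default window of 20 are assumptions for illustration.

```python
from collections import deque


def moving_average(values, window=20):
    """Trailing mean over the most recent `window` values.

    Returns a list the same length as `values`; early entries average
    over however many points are available so far.
    """
    recent = deque(maxlen=window)
    smoothed = []
    for v in values:
        recent.append(v)
        smoothed.append(sum(recent) / len(recent))
    return smoothed


if __name__ == "__main__":
    rewards = [10, 12, 9, 30, 25, 40]  # made-up numbers, not real results
    print(moving_average(rewards, window=3))
```

Plotting the raw and smoothed series together keeps sudden collapses visible, which is useful when diagnosing divergence.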

## Resources
- [`old content/optimal.png`](../../old%20content/optimal.png) to connect deep RL performance with optimal policy intuition.
- Links to PyTorch/TensorFlow quickstarts for neural network setup.

### Deep Q-Network Scaffold
Start from this lightweight DQN loop and align it with the DQN discussion in `old content/DQN_vs_Q.png` and related slides.

In [None]:
import collections
import random

try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import gymnasium as gym
except ImportError as exc:
    raise ImportError("Install torch and gymnasium to run the DQN example.") from exc

class DQN(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, act_dim),
        )

    def forward(self, x):
        return self.net(x)

class ReplayBuffer:
    def __init__(self, capacity=50000):
        self.buffer = collections.deque(maxlen=capacity)

    def push(self, transition):
        self.buffer.append(transition)

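    def sample(self, batch_size):
        # Draw a uniform random minibatch of stored transitions.
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # Stack observations one by one; calling torch.tensor on a tuple of
        # NumPy arrays works but is slow and emits a warning.
        states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in states])
        next_states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in next_states])
        actions = torch.tensor(actions, dtype=torch.long)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)
        return states, actions, rewards, next_states, dones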

    def __len__(self):
        return len(self.buffer)

def dqn_train(env_name='CartPole-v1', episodes=200):
    env = gym.make(env_name)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    policy_net = DQN(obs_dim, act_dim)
    target_net = DQN(obs_dim, act_dim)
    target_net.load_state_dict(policy_net.state_dict())
    optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)

    buffer = ReplayBuffer()
    gamma = 0.99
    batch_size = 64
    epsilon = 1.0
    epsilon_min = 0.05
    epsilon_decay = 0.995

    for episode in range(episodes):
        state, _ = env.reset()
        state = torch.tensor(state, dtype=torch.float32)
        done = False
        total_reward = 0
        while not done:
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    q_values = policy_net(state)
                    action = int(torch.argmax(q_values))
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            next_state_tensor = torch.tensor(next_state, dtype=torch.float32)
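            # Store `terminated`, not `done`: a time-limit truncation is not a true
            # terminal state, so the target should still bootstrap from next_state.
            buffer.push((state.numpy(), action, reward, next_state, float(terminated)))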
            state = next_state_tensor
            total_reward += reward

            if len(buffer) >= batch_size:
                states, actions, rewards, next_states, dones = buffer.sample(batch_size)
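                # Q(s, a) for the actions that were actually taken.
                q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
                # Bootstrap from the target network; no gradient should flow through it.
                with torch.no_grad():
                    next_q_values = target_net(next_states).max(1)[0]
                    targets = rewards + gamma * next_q_values * (1 - dones)
                loss = nn.functional.mse_loss(q_values, targets)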
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        # Hard update: copy the online weights into the target network every 10 episodes.
        if episode % 10 == 0:
            target_net.load_state_dict(policy_net.state_dict())
            print(f"Episode {episode}, reward {total_reward:.1f}, epsilon {epsilon:.3f}")

    env.close()
    return policy_net

# trained_policy = dqn_train(episodes=50)
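
The rewards printed during training include ε-greedy exploration, so they understate how well the learned policy performs. A separate evaluation pass with no exploration is one way to check this. The sketch below assumes an environment that follows the Gymnasium `reset`/`step` API; `evaluate_greedy` and `select_action` are hypothetical names, not part of the scaffold or of Gymnasium.

```python
def evaluate_greedy(env, select_action, episodes=5):
    """Run `episodes` rollouts without exploration and return each episode's return.

    `env` is assumed to follow the Gymnasium API: reset() -> (obs, info) and
    step(action) -> (obs, reward, terminated, truncated, info).
    `select_action` maps an observation to an action.
    """
    returns = []
    for _ in range(episodes):
        obs, _ = env.reset()
        done = False
        total = 0.0
        while not done:
            obs, reward, terminated, truncated, _ = env.step(select_action(obs))
            total += float(reward)
            # For evaluation, either kind of episode end stops the rollout.
            done = terminated or truncated
        returns.append(total)
    return returns


# Example wiring with the scaffold (commented out, like the training call above):
# trained_policy = dqn_train(episodes=50)
# env = gym.make("CartPole-v1")
# act = lambda obs: int(trained_policy(torch.as_tensor(obs, dtype=torch.float32)).argmax())
# print(evaluate_greedy(env, act))
```

Reporting the mean of these returns alongside the training curve makes it easier to separate the effect of exploration from the quality of the learned Q-function.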
