# Stage 1 — DQN on CartPole-v1

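This notebook analyses a vanilla DQN agent trained on the classic CartPole-v1 environment: it loads the best saved checkpoint, plots the training curve, evaluates the greedy policy, and inspects a single episode and the learned Q-value landscape.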

**Key concepts:**
- Experience Replay for breaking temporal correlations
- Target Network for stable TD targets
- Epsilon-greedy exploration with multiplicative decay
- Huber loss (Smooth L1) for robustness to outliers

In [None]:
import sys
from pathlib import Path

# Make the repository root importable so the `src.*` imports below resolve,
# whether the kernel was launched from the repo root or from `notebooks/`.
PROJECT_ROOT = Path.cwd()
if PROJECT_ROOT.name == "notebooks":
    PROJECT_ROOT = PROJECT_ROOT.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import numpy as np
import torch
import matplotlib.pyplot as plt
import gymnasium as gym

from src.agents.dqn_agent import DQNAgent
from src.environments.wrappers import EpisodeStatsWrapper
from src.training.evaluator import Evaluator
from src.utils.config_loader import load_config, get_device
from src.utils.plotting import plot_training_curves

print(f"PyTorch {torch.__version__}")
# Report the best accelerator this PyTorch build can see (CUDA first, then MPS).
# The device the agent actually uses is resolved from the config via
# get_device(config) in the next section and may differ from this.
if torch.cuda.is_available():
    available_backend = "cuda"
elif torch.backends.mps.is_available():
    available_backend = "mps"
else:
    available_backend = "cpu"
print(f"Best available backend: {available_backend}")

## 1. Load Configuration & Build Agent

In [None]:
# Load the experiment YAML and resolve the torch device it requests.
config = load_config(PROJECT_ROOT / "config" / "cartpole_dqn.yaml")
device = get_device(config)

print(f"Experiment: {config['experiment']['name']}")
print(f"Device: {device}")
print(
    "Agent config: "
    f"hidden_dims={config['agent']['hidden_dims']}, "
    f"lr={config['agent']['learning_rate']}, "
    f"gamma={config['agent']['gamma']}"
)

# Construct the agent, then show the online Q-network and its parameter count.
agent = DQNAgent(config, device)
print(f"\nOnline network:\n{agent.online_net}")
n_params = sum(map(lambda param: param.numel(), agent.online_net.parameters()))
print(f"\nTotal parameters: {n_params:,}")

## 2. Load Trained Checkpoint (if available)

In [None]:
# Restore the best checkpoint from a previous training run, if one exists.
checkpoint_path = PROJECT_ROOT / "outputs" / "models" / "cartpole" / "checkpoint_best.pt"
if checkpoint_path.exists():
    # NOTE(review): assumes agent.load returns the checkpoint dict; `or {}` keeps the
    # reporting below from crashing if it returns None instead.
    ckpt = agent.load(checkpoint_path) or {}
    episode_rewards = ckpt.get("episode_rewards", [])
    print(f"Loaded checkpoint from episode {ckpt.get('episode', '?')}")
    print(f"Best eval reward: {ckpt.get('best_eval_reward', 'N/A')}")
    print(f"Training episodes recorded: {len(episode_rewards)}")
else:
    episode_rewards = []
    print("No checkpoint found. Run training first:")
    print("  python train.py --config config/cartpole_dqn.yaml")
    # Without a checkpoint the agent keeps the weights it was constructed with, so
    # make it explicit that the later analysis sections describe an untrained network.
    print("Sections 4–6 will run on the freshly constructed, untrained network.")

## 3. Training Curves

In [None]:
# Plot the per-episode training returns recovered from the checkpoint, if any.
# window=20 is passed through to plot_training_curves (presumably a smoothing width).
if not episode_rewards:
    print("No training data available yet.")
else:
    fig = plot_training_curves(
        rewards=episode_rewards,
        window=20,
        title="DQN — CartPole-v1 Training Progress",
    )
    plt.show()

## 4. Greedy Evaluation

In [None]:
# Score the greedy policy on a fresh CartPole env wrapped for episode statistics.
eval_env = EpisodeStatsWrapper(gym.make("CartPole-v1"))
try:
    evaluator = Evaluator(eval_env, config)
    # Greedy: disable epsilon exploration. This assignment persists on `agent`.
    agent.epsilon = 0.0
    result = evaluator.evaluate(agent)
finally:
    # Release the environment even if evaluation raises.
    eval_env.close()

print("Evaluation (greedy policy):")
for k, v in result.items():
    try:
        v_str = f"{v:.2f}"
    except (TypeError, ValueError):
        # NOTE(review): Evaluator's result schema isn't visible here; print any
        # non-numeric entry as-is rather than aborting the whole report.
        v_str = str(v)
    print(f"  {k:<15}: {v_str}")

## 5. Visualise a Single Episode

In [None]:
# Roll out one episode with the agent's eval-mode policy from a fixed seed and
# record the state it acted from at every step, plus the running return.
env = gym.make("CartPole-v1")
# Bound the loop by the env's registered time limit instead of a magic constant;
# fall back to 500 (the original bound) if the spec carries no limit.
max_steps = (env.spec.max_episode_steps if env.spec is not None else None) or 500

positions, velocities, angles, rewards_ep = [], [], [], []
total_reward = 0.0

try:
    state, _ = env.reset(seed=42)
    for step in range(max_steps):
        action = agent.select_action(state, eval_mode=True)
        next_state, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        # state layout as used throughout this notebook:
        # [cart position, cart velocity, pole angle (rad), ...]
        positions.append(state[0])
        velocities.append(state[1])
        angles.append(np.degrees(state[2]))
        rewards_ep.append(total_reward)
        state = next_state
        if terminated or truncated:
            break
finally:
    env.close()

fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)
axes[0].plot(positions, color="steelblue")
axes[0].set_ylabel("Cart Position")
axes[0].set_title(f"DQN CartPole Episode (Total Reward: {total_reward:.0f})")

axes[1].plot(velocities, color="darkorange")
axes[1].set_ylabel("Cart Velocity")

axes[2].plot(angles, color="tomato")
axes[2].set_ylabel("Pole Angle (°)")

axes[3].plot(rewards_ep, color="green")
axes[3].set_ylabel("Cumulative Reward")
axes[3].set_xlabel("Step")

fig.tight_layout()
plt.show()

## 6. Q-Value Landscape

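Visualise Q-values as a function of pole angle and pole angular velocity
(with cart position and cart velocity fixed at zero). The right-hand panel shows $Q(s,\text{right}) - Q(s,\text{left})$; its sign tells you which action the greedy policy picks.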

In [None]:
# Evaluate the online Q-network on a 50x50 grid over (pole angle, pole angular
# velocity), holding cart position and cart velocity at zero.
angles_range = np.linspace(-0.25, 0.25, 50)   # pole angle, rad
ang_vel_range = np.linspace(-2.0, 2.0, 50)    # pole angular velocity, rad/s
AA, VV = np.meshgrid(angles_range, ang_vel_range)

# Build every grid state at once, in the same layout as before:
# [cart pos = 0, cart vel = 0, pole angle, pole angular velocity].
grid_states = np.zeros((AA.size, 4), dtype=np.float32)
grid_states[:, 2] = AA.ravel()
grid_states[:, 3] = VV.ravel()

# One batched forward pass instead of one network call per grid point.
with torch.no_grad():
    grid_q = agent.online_net(torch.as_tensor(grid_states, device=device)).cpu().numpy()

# ravel/reshape are both C-order, so entry [i, j] lines up with AA[i, j], VV[i, j].
q_left = grid_q[:, 0].reshape(AA.shape)
q_right = grid_q[:, 1].reshape(AA.shape)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for ax, data, title, zero_centred in [
    (axes[0], q_left, "Q(s, left)", False),
    (axes[1], q_right, "Q(s, right)", False),
    (axes[2], q_right - q_left, "Advantage (right − left)", True),
]:
    levels = 30
    if zero_centred:
        # Symmetric levels put the neutral colour of RdBu_r at 0:
        # red = right preferred, blue = left preferred.
        lim = float(np.abs(data).max())
        if lim > 0:
            levels = np.linspace(-lim, lim, 31)
    im = ax.contourf(AA, VV, data, levels=levels, cmap="RdBu_r")
    if zero_centred and data.min() < 0 < data.max():
        # Greedy decision boundary: where the preferred action flips.
        ax.contour(AA, VV, data, levels=[0.0], colors="k", linewidths=1.0)
    ax.set_xlabel("Pole Angle (rad)")
    ax.set_ylabel("Angular Velocity (rad/s)")
    ax.set_title(title)
    fig.colorbar(im, ax=ax)

fig.suptitle("Q-Value Landscape (pos=0, vel=0)", fontsize=14, y=1.02)
fig.tight_layout()
plt.show()