# Stage 2 — DQN Variants Comparison

Comparative analysis of three DQN variants — vanilla DQN and two of its improvements — on CartPole-v1:

| Variant | Key Idea |
|---------|----------|
| **Vanilla DQN** | Experience Replay + Target Network |
| **Double DQN** | Decouple selection & evaluation to reduce overestimation |
| **Dueling DQN** | Separate Value and Advantage streams |

Any variant can optionally use **Prioritized Experience Replay** (PER); in this comparison PER is paired with Double DQN and Dueling DQN.

In [None]:
import sys
from pathlib import Path

# When the notebook is launched from notebooks/, the repository root (which
# contains the ``src`` package) is one directory up; otherwise use the cwd.
_cwd = Path.cwd()
PROJECT_ROOT = _cwd.parent if _cwd.name == "notebooks" else _cwd

# Put the repo root first on the import path exactly once so ``src.*`` resolves.
_root_str = str(PROJECT_ROOT)
if _root_str not in sys.path:
    sys.path.insert(0, _root_str)

import numpy as np
import torch
import matplotlib.pyplot as plt
import gymnasium as gym

from src.agents import DQNAgent, DDQNAgent, DuelingDQNAgent
from src.environments.wrappers import EpisodeStatsWrapper
from src.training.trainer import Trainer
from src.training.evaluator import Evaluator
from src.utils.config_loader import load_config, get_device, merge_configs
from src.utils.logger import RLLogger
from src.utils.plotting import plot_comparison

print("PyTorch", torch.__version__)  # report the torch build in use for reproducibility notes

## 1. Configuration

We use the DDQN config as a base and, for each variant, override the agent type, the replay-buffer type, and the number of training episodes.

In [None]:
base_config = load_config(PROJECT_ROOT / "config" / "cartpole_ddqn.yaml")
device = get_device(base_config)
print(f"Device: {device}")

# Short training budget per variant so the whole comparison stays demo-sized.
N_EPISODES = 300

# One row per variant: (display label, agent type, replay-buffer type).
_VARIANT_SPECS = (
    ("DQN", "dqn", "standard"),
    ("DDQN", "ddqn", "standard"),
    ("DDQN + PER", "ddqn", "prioritized"),
    ("Dueling DQN", "dueling_dqn", "standard"),
    ("Dueling + PER", "dueling_dqn", "prioritized"),
)

# Config overrides merged onto ``base_config`` for each variant, keyed by label.
variants = {
    label: {"agent": {"type": agent_type, "buffer_type": buffer, "n_episodes": N_EPISODES}}
    for label, agent_type, buffer in _VARIANT_SPECS
}
print(f"Will compare {len(variants)} variants over {N_EPISODES} episodes each.")

## 2. Train All Variants

⚠️ This cell trains 5 agents sequentially — takes a few minutes.

In [None]:
# Train every variant in ``variants`` sequentially and collect its per-episode
# reward history in ``all_results`` (label -> rewards).
import contextlib
import random

from src.agents.dqn_agent import DQNAgent
from src.agents.ddqn_agent import DDQNAgent
from src.agents.dueling_dqn_agent import DuelingDQNAgent

ENV_ID = "CartPole-v1"
AGENT_CLS = {"dqn": DQNAgent, "ddqn": DDQNAgent, "dueling_dqn": DuelingDQNAgent}
all_results = {}

for name, overrides in variants.items():
    print(f"\n{'='*50}")
    print(f"Training: {name}")
    print(f"{'='*50}")

    cfg = merge_configs(base_config, overrides)
    cfg["experiment"]["name"] = name.lower().replace(" ", "_").replace("+", "")

    # Every variant gets the same seed so the comparison starts from equal
    # footing. The stdlib RNG is seeded too — presumably replay buffers or
    # exploration may draw from it (TODO confirm against src.agents).
    seed = cfg["experiment"].get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # ExitStack guarantees the envs and logger are closed even if agent
    # construction or training raises part-way through.
    with contextlib.ExitStack() as stack:
        train_env = EpisodeStatsWrapper(gym.make(ENV_ID))
        stack.callback(train_env.close)
        eval_env = EpisodeStatsWrapper(gym.make(ENV_ID))
        stack.callback(eval_env.close)

        agent_cls = AGENT_CLS[cfg["agent"]["type"]]
        agent = agent_cls(cfg, device)
        logger = RLLogger(cfg["experiment"]["name"], use_tensorboard=False)
        stack.callback(logger.close)

        evaluator = Evaluator(eval_env, cfg)
        trainer = Trainer(agent, train_env, cfg, logger, evaluator)
        history = trainer.train()

        all_results[name] = history["rewards"]

print("\nAll variants trained!")

## 3. Reward Comparison

In [None]:
# Overlay every variant's reward curve, smoothed with a 20-episode window.
smoothing_window = 20
fig = plot_comparison(results=all_results, title="DQN Variants — CartPole-v1", window=smoothing_window)
plt.show()

## 4. Statistical Summary

In [None]:
# CartPole-v1 counts as "solved" once the mean return over 100 consecutive
# episodes reaches 475.
SOLVE_THRESHOLD = 475.0
SOLVE_WINDOW = 100
TAIL_EPISODES = 50


def first_solved_episode(rewards, threshold=SOLVE_THRESHOLD, window=SOLVE_WINDOW):
    """Return the 1-based episode at which the trailing mean first reaches *threshold*.

    The mean is taken over the *window* episodes ending at that episode.
    Returns ``None`` when there are fewer than *window* episodes or the
    threshold is never reached.
    """
    returns = np.asarray(rewards, dtype=np.float64)
    if returns.size < window:
        return None
    # Row k of the view covers 0-based episodes k .. k + window - 1, so a hit
    # at row k corresponds to 1-based episode k + window.
    window_means = np.lib.stride_tricks.sliding_window_view(returns, window).mean(axis=1)
    hits = np.flatnonzero(window_means >= threshold)
    return int(hits[0]) + window if hits.size else None


header = f"{'Variant':<25} {'Mean(last50)':>12} {'Std':>8} {'Max':>8} {'Solved@':>10}"
print(header)
# Size the rule from the header (67 chars); a hard-coded 65 fell short.
print("─" * len(header))

for name, rewards in all_results.items():
    returns = np.asarray(rewards, dtype=np.float64)
    if returns.size == 0:
        # np.max raises on empty input; show the row as missing instead.
        print(f"{name:<25} {'N/A':>12} {'N/A':>8} {'N/A':>8} {'N/A':>10}")
        continue

    tail = returns[-TAIL_EPISODES:]  # slicing already handles < 50 episodes
    solved = first_solved_episode(returns)
    solved_ep = "N/A" if solved is None else str(solved)
    print(
        f"{name:<25} {tail.mean():>12.1f} {tail.std():>8.1f} "
        f"{returns.max():>8.0f} {solved_ep:>10}"
    )

## 5. Learning Efficiency — Area Under the Curve

In [None]:
# Area under each raw reward curve (unit spacing per episode): a rough
# "reward collected while learning" score — larger means faster learning.
# NumPy 2.0 renamed np.trapz to np.trapezoid and deprecated the old name,
# so use whichever this NumPy provides.
_trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

fig, ax = plt.subplots(figsize=(10, 5))

names = list(all_results.keys())
aucs = [_trapezoid(all_results[n]) for n in names]
colours = plt.cm.tab10.colors[:len(names)]

ax.barh(names, aucs, color=colours)
ax.invert_yaxis()  # list variants top-to-bottom in the same order as the summary table
ax.set_xlabel("Area Under Reward Curve (higher = faster learning)")
ax.set_title("Learning Efficiency Comparison")
fig.tight_layout()
plt.show()

## 6. Overestimation Analysis

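Vanilla DQN tends to overestimate Q-values: its target takes a max over the same noisy estimates it uses to pick the greedy action, whereas Double DQN decouples action selection from evaluation. Let's compare the max Q-values
predicted by DQN and DDQN on a set of random states.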

In [None]:
# Briefly train one DQN and one DDQN agent (short ``small_cfg`` budget), then
# compare the max Q-values each predicts on the same batch of random states.
# Previously the agents were compared straight after construction, i.e. with
# random weights, which says nothing about overestimation.
import contextlib

from src.agents.dqn_agent import DQNAgent
from src.agents.ddqn_agent import DDQNAgent

small_cfg = merge_configs(base_config, {"agent": {"n_episodes": 100, "buffer_type": "standard"}})

# Generate random states from a dedicated, seeded RNG so reruns probe the
# same states and the global NumPy RNG is left untouched.
n_states = 200
state_rng = np.random.default_rng(base_config["experiment"].get("seed", 42))
random_states = state_rng.uniform(
    low=[-2.4, -3.0, -0.25, -3.0],
    high=[2.4, 3.0, 0.25, 3.0],
    size=(n_states, 4),
).astype(np.float32)
states_t = torch.tensor(random_states, device=device)

dqn_cfg = merge_configs(small_cfg, {"agent": {"type": "dqn"}})
ddqn_cfg = merge_configs(small_cfg, {"agent": {"type": "ddqn"}})


def _train_probe_agent(agent_cls, cfg):
    """Train a fresh ``agent_cls`` with *cfg* on CartPole-v1 and return it.

    Environments and the logger are always closed, even if training fails.
    """
    cfg["experiment"]["name"] = f"overestimation_{cfg['agent']['type']}"
    seed = cfg["experiment"].get("seed", 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    with contextlib.ExitStack() as stack:
        train_env = EpisodeStatsWrapper(gym.make("CartPole-v1"))
        stack.callback(train_env.close)
        eval_env = EpisodeStatsWrapper(gym.make("CartPole-v1"))
        stack.callback(eval_env.close)
        agent = agent_cls(cfg, device)
        logger = RLLogger(cfg["experiment"]["name"], use_tensorboard=False)
        stack.callback(logger.close)
        Trainer(agent, train_env, cfg, logger, Evaluator(eval_env, cfg)).train()
    return agent


dqn_agent = _train_probe_agent(DQNAgent, dqn_cfg)
ddqn_agent = _train_probe_agent(DDQNAgent, ddqn_cfg)

with torch.no_grad():
    dqn_q = dqn_agent.online_net(states_t).max(dim=1).values.cpu().numpy()
    ddqn_q = ddqn_agent.online_net(states_t).max(dim=1).values.cpu().numpy()

print(f"Mean max-Q  DQN: {dqn_q.mean():.2f}   DDQN: {ddqn_q.mean():.2f}")

fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(dqn_q, bins=30, alpha=0.6, label="DQN max Q", color="tomato")
ax.hist(ddqn_q, bins=30, alpha=0.6, label="DDQN max Q", color="steelblue")
ax.set_xlabel("Max Q-value")
ax.set_ylabel("Count")
ax.set_title(
    f"Q-value Distribution (random states, after {small_cfg['agent']['n_episodes']} training episodes)"
)
ax.legend()
plt.show()