# RL DQN Training

This notebook:
1. Loads offline RL tensors from `data/processed/rl_tensors_*.npz`
2. Loads DQN hyperparameters from `configs/model.yaml`
3. Trains a dueling DQN using `src/rl/dqn.py::train_dqn`
4. Saves the trained model to `models/`
5. Runs offline evaluation using `src/ope/offline_eval.py`:
    - Mean Squared TD Error (MSTE)
    - Direct Q-based value estimate of greedy policy
    - Action agreement with logged (behavior) policy


## 1. Imports & Paths

In [None]:
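import sys
from pathlib import Path
import torch

# Assumes the notebook is run from a directory one level below the project root,
# so that the parent of the working directory contains `src/`, `configs/`, and `data/`.
PROJECT_ROOT = Path.cwd().resolve().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))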

print("PROJECT_ROOT:", PROJECT_ROOT)

In [None]:
from src.rl.dqn import (
    load_dqn_training_config,
    train_dqn,
    BullpenOfflineDataset,
    RLDatasetConfig,
)
from src.ope.offline_eval import (
    OfflineEvalConfig,
    load_model_and_dataset,
    evaluate_td_error_full_mse,
    direct_policy_value_estimate,
    compute_action_agreement,
)

## 2. Configurations

In [None]:
DATA_DIR = PROJECT_ROOT / "data"
PROC_DIR = DATA_DIR / "processed"
CONFIG_DIR = PROJECT_ROOT / "configs"
MODELS_DIR = PROJECT_ROOT / "models"

MODELS_DIR.mkdir(parents=True, exist_ok=True)

YEAR_TAG = "2022_2023"  # must match 01_build_dataset YEARS range
RL_TENSORS_PATH = PROC_DIR / f"rl_tensors_{YEAR_TAG}.npz"
MODEL_CFG_PATH = CONFIG_DIR / "model.yaml"
MODEL_OUT_PATH = MODELS_DIR / f"dqn_bullpen_{YEAR_TAG}.pt"

print("RL tensors:", RL_TENSORS_PATH)
print("Model config:", MODEL_CFG_PATH)
print("Model output:", MODEL_OUT_PATH)

## 3. Load Training Config & Dataset

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

train_cfg = load_dqn_training_config(
    model_config_path=MODEL_CFG_PATH,
    data_path=RL_TENSORS_PATH,
    device=device,
)

train_cfg

DQNTrainingConfig(data_path=PosixPath('../data/processed/rl_tensors_2022_2023.npz'), device='cpu', batch_size=512, lr=0.001, gamma=0.99, max_steps=50000, target_update_interval=1000, log_interval=1000, val_fraction=0.1, hidden_size=256, num_layers=3, dropout=0.05, yaml_num_actions=11)

In [None]:
ds = BullpenOfflineDataset(
    RLDatasetConfig(
        data_path=train_cfg.data_path,
        device=train_cfg.device,
    )
)

print("Dataset size:", len(ds))
print("State dim:", ds.state_dim)
print("Num actions:", ds.num_actions)
print("H (next hitters window):", ds.H)
print("R (max relievers per team):", ds.R)

Dataset size: 407660
State dim: 208
Num actions: 10


## 4. Train the Dueling DQN

This calls `train_dqn(train_cfg)`, which:
 - loads `BullpenOfflineDataset` from `train_cfg.data_path`
 - splits into train/val by `train_cfg.val_fraction`
 - trains a dueling DQN with a target network
 - logs TD-error periodically using `evaluate_td_error` in `dqn.py`
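
The two ingredients named above, a dueling head and a target network, can be sketched in NumPy as follows. This is an illustration only and not the code in `src/rl/dqn.py`: the function names `dueling_q` and `td_target`, the mean-subtracted aggregation, and the masking of unavailable next-state actions are all assumptions.

```python
import numpy as np


def dueling_q(value, advantage):
    """Combine V(s), shape (B, 1), and A(s, a), shape (B, A), into Q(s, a)."""
    # Subtracting the mean advantage pins down the split between V and A.
    return value + advantage - advantage.mean(axis=1, keepdims=True)


def td_target(reward, done, next_q_target, next_mask, gamma):
    """One-step target r + gamma * max_a' Q_target(s', a') over available actions.

    `next_q_target` holds the target network's Q-values for s', shape (B, A);
    `next_mask` is a boolean array of the same shape (True = action available).
    """
    masked = np.where(next_mask, next_q_target, -np.inf)
    best_next = masked.max(axis=1)
    # Rows with no available action contribute no bootstrap value.
    best_next = np.where(np.isfinite(best_next), best_next, 0.0)
    # Terminal transitions (done == 1) do not bootstrap.
    return reward + gamma * (1.0 - done) * best_next
```

During training the online network would be regressed toward `td_target(...)`, with the target network's weights refreshed from the online network every `target_update_interval` steps.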

In [None]:
dqn_model = train_dqn(train_cfg)

## 5. Save trained model weights

In [None]:
torch.save(dqn_model.state_dict(), MODEL_OUT_PATH)
MODEL_OUT_PATH

## Offline Policy Evaluation (OPE)

Now we use `src/ope/offline_eval.py` to:
- load the saved model and dataset
- compute:
    - Mean Squared TD Error (MSTE)
    - Direct Q-based value of the greedy policy
    - Action agreement with the logged policy

In [None]:
ope_cfg = OfflineEvalConfig(
    model_config_path=MODEL_CFG_PATH,
    model_path=MODEL_OUT_PATH,
    tensors_path=RL_TENSORS_PATH,
    device=device,
    batch_size=2048,
    gamma=train_cfg.gamma,
)

eval_model, eval_ds, eval_loader = load_model_and_dataset(ope_cfg)

print("Eval dataset size:", len(eval_ds))
print("State dim:", eval_ds.state_dim)
print("Num actions:", eval_ds.num_actions)

## 6. Mean Squared TD Error (MSTE)

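This is the mean squared one-step TD error (an empirical Bellman residual) over the full dataset.
It reuses `evaluate_td_error` from `dqn.py` under the hood,
passing `model` as both the online and target networks.
Because the same network supplies both the prediction and the bootstrap target, the value measures how self-consistent the trained Q-function is; it is not a measure of policy quality on its own.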
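
As a rough sketch of what "over the full dataset" means when the data arrives in batches, the helper below accumulates squared residuals and divides once at the end. The name `full_dataset_mse` and the `(q_taken, target)` batch format are hypothetical and are not taken from `offline_eval.py`.

```python
import numpy as np


def full_dataset_mse(batches):
    """Mean squared residual over every row of every batch.

    `batches` is an iterable of (q_taken, target) pairs, where `q_taken` is
    Q(s, a) for the logged action and `target` is the one-step TD target.
    """
    total_sq, total_n = 0.0, 0
    for q_taken, target in batches:
        residual = np.asarray(q_taken, dtype=float) - np.asarray(target, dtype=float)
        total_sq += float(np.sum(residual ** 2))
        total_n += residual.size
    # Guard against an empty loader.
    return total_sq / max(total_n, 1)
```

Summing first and dividing once avoids the small bias that averaging per-batch means introduces when the final batch is shorter than the others.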


In [None]:
mste = evaluate_td_error_full_mse(
    model=eval_model,
    loader=eval_loader,
    gamma=ope_cfg.gamma,
    device=ope_cfg.device,
)

print(f"Mean Squared TD Error (MSTE): {mste:.6f}")

## 7. Direct Q-based value estimate (FQE-style Direct Method)

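For each state `s`:
- compute `Q(s, a)` for all actions
- mask unavailable actions
- take the greedy action `a* = argmax_a Q(s, a)` over the available actions
- define `V_hat(s) = Q(s, a*)`

Then average `V_hat(s)` over all states in the dataset as an estimate of `V(pi_greedy)`.
Because the estimate relies entirely on the learned Q-function, any overestimation in `Q` carries straight through to it.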
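
The steps above can be sketched in NumPy as follows. The function name `greedy_value_estimate` and the boolean `action_mask` layout are assumptions made for illustration; the project's `direct_policy_value_estimate` may differ in detail.

```python
import numpy as np


def greedy_value_estimate(q_values, action_mask):
    """Average masked-greedy value over a set of states.

    `q_values` has shape (N, A); `action_mask` is boolean with the same shape
    (True = action available). Every row is assumed to have at least one
    available action.
    """
    # Unavailable actions can never win the argmax.
    masked = np.where(action_mask, q_values, -np.inf)
    greedy_actions = masked.argmax(axis=1)
    v_hat = masked.max(axis=1)  # V_hat(s) = Q(s, a*)
    return float(v_hat.mean()), greedy_actions
```

For a full dataset the same computation would be applied batch by batch, summing `v_hat` and dividing by the total number of states.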

In [None]:
dm_value = direct_policy_value_estimate(
    model=eval_model,
    loader=eval_loader,
    device=ope_cfg.device,
)

print(f"Direct Q-based value estimate (V(pi_greedy)): {dm_value:.6f}")

## 8. Action agreement with logged policy

How often does the greedy DQN action (respecting availability mask) match the logged (historical) action from the dataset?
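
A minimal NumPy sketch of this metric, assuming Q-values, a boolean availability mask, and integer logged actions are already available as arrays. The name `action_agreement` is hypothetical and is not the signature of `compute_action_agreement`.

```python
import numpy as np


def action_agreement(q_values, action_mask, logged_actions):
    """Fraction of states where the masked-greedy action equals the logged one."""
    # Unavailable actions are excluded before taking the argmax.
    masked = np.where(action_mask, q_values, -np.inf)
    greedy = masked.argmax(axis=1)
    return float(np.mean(greedy == np.asarray(logged_actions)))
```

A high agreement rate indicates that the learned policy stays close to historical decisions; it says nothing by itself about whether those decisions were good.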


In [None]:
agreement = compute_action_agreement(
    model=eval_model,
    loader=eval_loader,
    device=ope_cfg.device,
)

print(f"Action agreement (logged vs greedy): {agreement:.3%}")

## 9. Summary

In [None]:
print("========= FINAL DQN EVALUATION RESULTS =========")
print(f"TD Error (MSTE):              {mste:.6f}")
print(f"Direct Q-based V(pi_greedy):  {dm_value:.6f}")
print(f"Action agreement rate:        {agreement:.3%}")