# RL DQN Training

This notebook:
1. Loads offline RL tensors from `data/processed/rl_tensors_*.npz`
2. Loads DQN hyperparameters from `configs/model.yaml`
3. Trains a dueling DQN using `src/rl/dqn.py::train_dqn`
4. Saves the trained model to `models/`
5. Runs offline evaluation using `src/ope/offline_eval.py`:
    - Mean Squared TD Error (MSTE)
    - Direct Q-based value estimate of greedy policy
    - Action agreement with logged (behavior) policy


## 1. Imports & Paths

In [1]:
import os
import sys
from pathlib import Path
import torch

PROJECT_ROOT = Path(os.getcwd()).resolve().parent
sys.path.append(str(PROJECT_ROOT))

print("PROJECT_ROOT:", PROJECT_ROOT)

PROJECT_ROOT: /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy


In [None]:
from src.rl.dqn import (
    load_dqn_training_config,
    train_dqn,
    BullpenOfflineDataset,
    RLDatasetConfig,
)
from src.ope.offline_eval import (
    OfflineEvalConfig,
    load_model_and_dataset,
    evaluate_td_error_full_mse,
    direct_policy_value_estimate,
    compute_policy_behavior_stats,
    compute_q_distributions,
    summarize_policy_behavior_stats,
    summarize_q_distributions,
)

## 2. Configurations

In [3]:
DATA_DIR = PROJECT_ROOT / "data"
PROC_DIR = DATA_DIR / "processed"
CONFIG_DIR = PROJECT_ROOT / "configs"
MODELS_DIR = PROJECT_ROOT / "models"

MODELS_DIR.mkdir(parents=True, exist_ok=True)

YEAR_TAG = "2022_2023"  # must match 01_build_dataset YEARS range
RL_TENSORS_PATH = PROC_DIR / f"rl_tensors_{YEAR_TAG}.npz"
MODEL_CFG_PATH = CONFIG_DIR / "model.yaml"
MODEL_OUT_PATH = MODELS_DIR / f"dqn_bullpen_{YEAR_TAG}.pt"

print("RL tensors:", RL_TENSORS_PATH)
print("Model config:", MODEL_CFG_PATH)
print("Model output:", MODEL_OUT_PATH)

RL tensors: /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/processed/rl_tensors_2022_2023.npz
Model config: /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/configs/model.yaml
Model output: /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/models/dqn_bullpen_2022_2023.pt


## 3. Load Dataset & Build Model

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

train_cfg = load_dqn_training_config(
    model_config_path=MODEL_CFG_PATH,
    data_path=RL_TENSORS_PATH,
    device=device,
)

train_cfg

Using device: cpu


DQNTrainingConfig(data_path=PosixPath('/Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/processed/rl_tensors_2022_2023.npz'), device='cpu', batch_size=512, lr=0.001, gamma=0.99, max_steps=50000, target_update_interval=1000, log_interval=1000, val_fraction=0.2, early_stopping_patience=10, early_stopping_min_delta=0.01, hidden_size=258, num_layers=3, dropout=0.1, yaml_num_actions=11)

In [5]:
ds = BullpenOfflineDataset(
    RLDatasetConfig(
        data_path=train_cfg.data_path,
        device=train_cfg.device,
    )
)

print("Dataset size:", len(ds))
print("State dim:", ds.state_dim)
print("Num actions:", ds.num_actions)
print("H (next hitters window):", ds.H)
print("R (max relievers per team):", ds.R)

Dataset size: 407660
State dim: 208
Num actions: 11
H (next hitters window): 5
R (max relievers per team): 10


## 4. Create Dueling DQN Model + Trainer

This calls `train_dqn(train_cfg)`, which:
 - loads `BullpenOfflineDataset` from `train_cfg.data_path`
 - splits into train/val by `train_cfg.val_fraction`
 - trains a dueling DQN with a target network
 - logs TD-error periodically using `evaluate_td_error` in `dqn.py`
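`dqn.py` is not reproduced in this notebook, so the following is only a minimal NumPy sketch of the two mechanisms named above. It assumes the common dueling formulation (mean-subtracted advantages) and a hard target-network sync every `target_update_interval` steps. `dueling_q` and `hard_update` are hypothetical helpers. In PyTorch, such a sync would typically be written as `target_net.load_state_dict(online_net.state_dict())`.

```python
import numpy as np


def dueling_q(value: np.ndarray, advantage: np.ndarray) -> np.ndarray:
    """Combine dueling streams: Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a').

    value:     (batch, 1) state-value stream
    advantage: (batch, num_actions) advantage stream
    """
    # Subtracting the mean advantage makes V and A identifiable.
    return value + advantage - advantage.mean(axis=1, keepdims=True)


def hard_update(target_params: dict, online_params: dict) -> None:
    """Copy online weights into the target network (a 'hard' sync)."""
    for name, weights in online_params.items():
        target_params[name] = weights.copy()


# Toy check: 2 states, 2 actions -> q == [[0.0, 2.0], [2.0, 2.0]]
q = dueling_q(np.array([[1.0], [2.0]]), np.array([[1.0, 3.0], [0.0, 0.0]]))
```

Between syncs, the target network stays frozen, so the bootstrap targets do not shift on every gradient step.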

In [6]:
dqn_model = train_dqn(train_cfg)

[DQN] step=0 loss=557.81396
      val_td_error=929.83026
      (new best val TD: 929.83026)
[DQN] step=1000 loss=875.70972
      val_td_error=595.31433
      (new best val TD: 595.31433)
[DQN] step=2000 loss=634.22729
      val_td_error=398.42147
      (new best val TD: 398.42147)
[DQN] step=3000 loss=358.03110
      val_td_error=236.46870
      (new best val TD: 236.46870)
[DQN] step=4000 loss=122.72629
      val_td_error=137.71283
      (new best val TD: 137.71283)
[DQN] step=5000 loss=156.93430
      val_td_error=81.05525
      (new best val TD: 81.05525)
[DQN] step=6000 loss=74.21658
      val_td_error=66.47167
      (new best val TD: 66.47167)
[DQN] step=7000 loss=80.28516
      val_td_error=34.00092
      (new best val TD: 34.00092)
[DQN] step=8000 loss=41.26678
      val_td_error=23.20855
      (new best val TD: 23.20855)
[DQN] step=9000 loss=33.77260
      val_td_error=16.76977
      (new best val TD: 16.76977)
[DQN] step=10000 loss=23.40019
      val_td_error=11.11499
      (new best val TD: 11.11499)
      ...
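The log prints `(new best val TD: ...)` whenever the validation TD error improves, and the config sets `early_stopping_patience=10` and `early_stopping_min_delta=0.01`. The rule below is a hypothetical sketch of how such settings are commonly used, assuming patience is counted in evaluations (every `log_interval` steps). The actual logic in `dqn.py` may differ, and `early_stopping_trace` is not part of the project.

```python
def early_stopping_trace(val_errors, patience=10, min_delta=0.01):
    """Return (best_index, stop_index) for a sequence of validation TD errors.

    A new best is recorded only when the error beats the current best by more
    than min_delta. Training stops after `patience` evaluations without such
    an improvement. stop_index is None if early stopping never triggers.
    """
    best = float("inf")
    best_idx = None
    since_best = 0
    for i, err in enumerate(val_errors):
        if err < best - min_delta:
            best, best_idx, since_best = err, i, 0
        else:
            since_best += 1
            if since_best >= patience:
                return best_idx, i
    return best_idx, None


# Validation TD errors from the log above (steps 0 to 10000): each improves on
# the previous best by far more than min_delta, so no stop is triggered.
logged = [929.83026, 595.31433, 398.42147, 236.46870, 137.71283, 81.05525,
          66.47167, 34.00092, 23.20855, 16.76977, 11.11499]
print(early_stopping_trace(logged))  # (10, None)
```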

## 5. Save trained model weights

In [7]:
torch.save(dqn_model.state_dict(), MODEL_OUT_PATH)
MODEL_OUT_PATH

PosixPath('/Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/models/dqn_bullpen_2022_2023.pt')

## Offline Policy Evaluation (OPE)

Now we use `src/ope/offline_eval.py` to:
- load the saved model and dataset
- compute:
    - Mean Squared TD Error (MSTE)
    - Direct Q-based value of the greedy policy
    - Action agreement with the logged policy

In [6]:
ope_cfg = OfflineEvalConfig(
    model_config_path=MODEL_CFG_PATH,
    model_path=MODEL_OUT_PATH,
    tensors_path=RL_TENSORS_PATH,
    device=device,
    batch_size=2048,
    gamma=train_cfg.gamma,
)

eval_model, eval_ds, eval_loader = load_model_and_dataset(ope_cfg)

print("Eval dataset size:", len(eval_ds))
print("State dim:", eval_ds.state_dim)
print("Num actions:", eval_ds.num_actions)

Eval dataset size: 407660
State dim: 208
Num actions: 11


## 6. Mean Squared TD Error (MSTE)

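This is the mean squared Bellman (TD) residual over the full dataset.
It reuses `evaluate_td_error` from `dqn.py` under the hood,
passing `model` as both the online and the target network.
Because the evaluation covers all 407,660 transitions, including the 80% used for training (`val_fraction=0.2`), MSTE is partly an in-sample measure rather than a held-out error.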
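For reference, here is a minimal NumPy sketch of what MSTE measures. It assumes a one-step target with a `done` flag and a max over valid next actions. Whether `evaluate_td_error` uses exactly this target (rather than, for example, SMDP discounting by `gamma**k` over several plate appearances) is an assumption. The function name and arguments are hypothetical.

```python
import numpy as np


def mean_squared_td_error(q_sa, rewards, dones, next_q, next_valid_mask, gamma):
    """MSTE = mean of (Q(s, a) - y)^2 with y = r + gamma * (1 - done) * max_valid Q(s', a').

    The same network supplies both q_sa and next_q, mirroring the notebook's
    choice of passing `model` as both the online and the target network.
    """
    # Greedy bootstrap value over actions available in the next state.
    next_v = np.where(next_valid_mask, next_q, -np.inf).max(axis=1)
    # Terminal transitions have no bootstrap term.
    next_v = np.where(dones, 0.0, next_v)
    targets = rewards + gamma * next_v
    return float(np.mean((q_sa - targets) ** 2))


mste_toy = mean_squared_td_error(
    q_sa=np.array([2.0, 1.0]),
    rewards=np.array([0.0, 1.0]),
    dones=np.array([False, True]),
    next_q=np.array([[2.0, 1.0], [9.0, 9.0]]),
    next_valid_mask=np.array([[True, True], [True, True]]),
    gamma=0.5,
)
print(mste_toy)  # 0.5
```

When the network serves as its own target, a low MSTE shows only self-consistency with the Bellman equation on logged transitions. It does not show that the Q-values are correct for actions the managers rarely took.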


In [7]:
mste = evaluate_td_error_full_mse(
    model=eval_model,
    loader=eval_loader,
    gamma=ope_cfg.gamma,
    device=ope_cfg.device,
)

print(f"Mean Squared TD Error (MSTE): {mste:.6f}")

Mean Squared TD Error (MSTE): 0.961571


## 7. Direct Q-based value estimate (FQE-style Direct Method)

For each state `s`:
- compute `Q(s, a)` for all actions
- mask unavailable actions
- take the greedy action `a* = argmax_a Q(s, a)`
- define `V_hat(s) = Q(s, a*)`

Then average `V_hat(s)` across the dataset as an estimate of `V(pi_greedy)`.
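The procedure above can be sketched in a few lines of NumPy. This illustration assumes `direct_policy_value_estimate` operates on a per-state Q matrix and a boolean availability mask. The helper `greedy_direct_value` is hypothetical.

```python
import numpy as np


def greedy_direct_value(q_values, valid_mask):
    """Average of V_hat(s) = Q(s, a*) with a* = argmax over valid actions of Q(s, a).

    q_values:   (N, num_actions) Q-estimates
    valid_mask: (N, num_actions) boolean availability mask
    Returns (value_estimate, greedy_actions).
    """
    masked = np.where(valid_mask, q_values, -np.inf)    # unavailable actions can never win
    greedy = masked.argmax(axis=1)                      # a* for each state
    v_hat = masked[np.arange(masked.shape[0]), greedy]  # V_hat(s) = Q(s, a*)
    return float(v_hat.mean()), greedy


value, actions = greedy_direct_value(
    np.array([[1.0, 5.0, 2.0], [4.0, 3.0, 9.0]]),
    np.array([[True, False, True], [True, True, True]]),
)
print(value, actions)  # 5.5 [2 2]
```

Because `V_hat(s)` takes a maximum over noisy Q-estimates, its average is biased upward even if each individual estimate is unbiased. This is one reason a direct estimate like this can overstate the greedy policy's true value.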

In [8]:
dm_value = direct_policy_value_estimate(
    model=eval_model,
    loader=eval_loader,
    device=ope_cfg.device,
)

print(f"Direct Q-based value estimate (V(pi_greedy)): {dm_value:.6f}")

Direct Q-based value estimate (V(pi_greedy)): 5.507400


## 8. Policy Behavior Stats and Q distributions

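How often does the greedy DQN action (respecting the availability mask) match the logged (historical) action from the dataset? How often does each policy make a pitching change, and how are the Q-values distributed between staying and pulling?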
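Below is a minimal NumPy sketch of the agreement and pull-rate metrics. It assumes action index 0 means "keep the current pitcher", which is consistent with the output further down: index 0 accounts for 383,628 of 407,660 logged actions, matching the 5.90% behavior pull rate. The helper name is hypothetical and is not the API of `compute_policy_behavior_stats`.

```python
import numpy as np


def policy_behavior_summary(greedy_actions, behavior_actions, stay_action=0):
    """Agreement and pull rates for greedy vs logged actions.

    Assumes index `stay_action` (0) means 'keep the current pitcher' and every
    other index means bringing in a reliever.
    """
    greedy = np.asarray(greedy_actions)
    behavior = np.asarray(behavior_actions)
    return {
        "agreement": float(np.mean(greedy == behavior)),
        "policy_pull_rate": float(np.mean(greedy != stay_action)),
        "behavior_pull_rate": float(np.mean(behavior != stay_action)),
    }


print(policy_behavior_summary([0, 3, 3, 0], [0, 0, 3, 0]))
# {'agreement': 0.75, 'policy_pull_rate': 0.5, 'behavior_pull_rate': 0.25}
```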


In [10]:
# Policy-vs-behavior action statistics and Q-value distributions over the full dataset
policy_stats = compute_policy_behavior_stats(eval_model, eval_loader, device=ope_cfg.device)

q_stats = compute_q_distributions(eval_model, eval_loader, device=ope_cfg.device)
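The summary output reports a `q_stay_minus_best_pull` statistic. Presumably it is a per-state difference like the one below. A mean below zero would mean the network usually values some available reliever above staying, which is consistent with a high greedy pull rate. This NumPy sketch and its name are hypothetical and do not reproduce `compute_q_distributions`.

```python
import numpy as np


def stay_minus_best_pull(q_values, valid_mask, stay_action=0):
    """Per-state Q(s, stay) - max over valid pull actions of Q(s, a).

    Negative values mean the network prefers some available reliever to
    staying. States with no valid pull action get NaN.
    """
    pull_mask = np.asarray(valid_mask, dtype=bool).copy()
    pull_mask[:, stay_action] = False                      # pull actions only
    best_pull = np.where(pull_mask, q_values, -np.inf).max(axis=1)
    diff = q_values[:, stay_action] - best_pull
    return np.where(np.isfinite(best_pull), diff, np.nan)


d = stay_minus_best_pull(
    np.array([[5.0, 6.0, 4.0], [5.0, 1.0, 2.0]]),
    np.array([[True, True, True], [True, True, True]]),
)
print(d, np.mean(d < 0))  # [-1.  3.] 0.5
```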

## 9. Summary

In [12]:
print("========= FINAL DQN EVALUATION RESULTS =========")
print(f"TD Error (MSTE):              {mste:.6f}")
print(f"Direct Q-based V(pi_greedy):  {dm_value:.6f}")
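# NOTE: `agreement` is not defined in any cell shown in this notebook; it must come
# from an earlier action-agreement cell. The agreement rate is also reported by
# summarize_policy_behavior_stats below.
print(f"Action agreement rate:        {agreement:.3%}")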
summarize_policy_behavior_stats(policy_stats)
summarize_q_distributions(q_stats)

TD Error (MSTE):              0.961571
Direct Q-based V(pi_greedy):  5.507400
Action agreement rate:        8.160%
=== Policy vs Behavior Stats ===
Num samples:   407660
Num actions:   11

Behavior pull rate: 5.90%
Policy pull rate:   92.10%
Action agreement:   8.14%

Behavior action counts (per action index):
[383628   4016   3546   3018   2587   2369   2204   1923   1662   1423
   1284]
Policy action counts (per action index):
[ 32206  17818  18402   3646   1047  10901  24804 213138      0   6137
  79561]
Valid action counts (per action index):
[407660 368480 357208 365623 361313 371378 361377 375523 364592 364655
 373403]
=== Q Distribution Stats ===
q_all_valid: n=4071212, mean=5.367, std=1.123, min=-0.868, max=122.772
q_stay: n=407660, mean=5.232, std=1.360, min=-0.868, max=81.848
q_best_pull: n=407660, mean=5.501, std=1.233, min=3.093, max=122.772
q_stay_minus_best_pull: n=407660, mean=-0.269, std=0.365, min=-40.924, max=2.015


In [None]:
import numpy as np

# Reuse the tensors path defined in Section 2 instead of a hard-coded relative path.
npz = np.load(RL_TENSORS_PATH)

for key in ["reward_folded"]:
    x = npz[key]
    print(key, "shape:", x.shape)
    print(
        key,
        "mean:", float(x.mean()),
        "std:", float(x.std()),
        "min:", float(x.min()),
        "max:", float(x.max()),
    )

reward_folded shape: (407660,)
reward_folded mean: -0.010023725219070911 std: 0.6859701871871948 min: -7.711379528045654 max: 1.1493159532546997


After tuning the model (dropout = 0.10, hidden size = 258 per the printed training config, 3 layers, validation fraction = 0.2, no weight decay), we obtained:


- TD Error (MSTE): 0.961571
- Direct Q-based V(pi_greedy): 5.507400
- Action agreement rate: 8.160%

1. TD Error (MSTE ≈ 0.96)

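A TD error of 0.96 corresponds to an RMSE of about 0.98. For scale, the folded rewards in the offline data have mean ≈ –0.01, standard deviation ≈ 0.69, and range from –7.7 to +1.15. The Q-values over valid actions have mean ≈ 5.37 and standard deviation ≈ 1.12. The TD residual is therefore larger than the per-decision reward spread and comparable to the spread of Q-values, but small relative to the reward range and the typical Q-value magnitude. The Q-network fits the Bellman targets reasonably, though not tightly, and this is substantially better than earlier runs, where TD errors were much larger.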
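Here is a quick arithmetic check of the scale comparison. It uses only numbers printed in this notebook; the Q-value spread comes from the `q_all_valid` line of the Q-distribution output.

```python
import math

# Values printed earlier in this notebook.
mste_reported = 0.961571             # Mean Squared TD Error
reward_std = 0.6859701871871948      # std of reward_folded
q_std = 1.123                        # std of q_all_valid

rmse = math.sqrt(mste_reported)
print(f"TD RMSE:             {rmse:.3f}")               # 0.981
print(f"RMSE / reward std:   {rmse / reward_std:.2f}")  # 1.43
print(f"RMSE / Q-value std:  {rmse / q_std:.2f}")       # 0.87
```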

2. Direct Q-based Value Estimate (≈ 5.51)

The estimated value of the greedy policy induced by the DQN is about 5.5. This is higher than is realistically achievable, given that the maximum folded reward per decision is about 1.15 and the SMDP horizon is short (no more than 3 plate appearances). It indicates overestimation, a common failure mode of offline Q-learning caused by extrapolation error: the greedy policy selects actions that are rare or unseen in the logged data, where the Q-network's estimates are least constrained. Even so, this value is much more reasonable than those from earlier runs (e.g., above 16).

3. Action Agreement Rate (≈ 8.16%)

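This is the fraction of decisions in which the DQN's greedy action matches the action MLB managers actually took. The low rate mainly reflects how differently the two policies use the bullpen. Managers keep the current pitcher in about 94% of logged decisions (behavior pull rate 5.90%), whereas the greedy policy makes a pitching change in about 92% of them (policy pull rate 92.10%). The model is clearly not copying historical decisions: even a trivial always-stay policy would agree with managers about 94% of the time. Agreement alone, however, cannot show that the policy is non-random. With roughly 10 valid actions per state, choosing uniformly at random would agree at least about 10% of the time, slightly more than the observed 8%. The evidence that the policy is structured comes instead from its action distribution, which is highly concentrated rather than uniform: action index 7 accounts for about 52% of greedy choices, and index 8 is never chosen.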
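The baselines quoted above can be reproduced from the counts in the policy-vs-behavior output. This calculation assumes that index 0 is the stay action and that the logged action is always among the valid actions. The random-agreement figure is a lower bound from Jensen's inequality: for the number of valid actions k, E[1/k] ≥ 1/E[k].

```python
# Counts copied from the policy-vs-behavior output above (action 0 = stay).
n_samples = 407660
behavior_stay = 383628
top_policy_count = 213138  # greedy picks of action index 7
valid_counts = [407660, 368480, 357208, 365623, 361313, 371378,
                361377, 375523, 364592, 364655, 373403]

mean_valid = sum(valid_counts) / n_samples
always_stay_agreement = behavior_stay / n_samples
random_agreement_lower_bound = 1 / mean_valid  # E[1/k] >= 1/E[k]
top_policy_share = top_policy_count / n_samples

print(f"mean valid actions per state:    {mean_valid:.2f}")                      # 9.99
print(f"always-stay agreement:           {always_stay_agreement:.1%}")           # 94.1%
print(f"uniform-random agreement >=      {random_agreement_lower_bound:.1%}")    # 10.0%
print(f"most common greedy action share: {top_policy_share:.1%}")                # 52.3%
```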

Overall Assessment

These results indicate that:
- The Q-function is relatively stable and fits the offline Bellman targets reasonably well.
- The greedy policy remains optimistic, a known limitation of offline DQN.
- The model produces nontrivial decisions rather than cloning managers' behavior.
- These results reflect well-known challenges of offline Q-learning and provide a baseline for comparison with more conservative methods such as Conservative Q-Learning (CQL).