# RL DQN Training

This notebook:
1. Loads offline RL tensors from `data/processed/rl_tensors_*.npz`
2. Loads DQN hyperparameters from `configs/model.yaml`
3. Trains a dueling DQN using `src/rl/dqn.py::train_dqn`
4. Saves the trained model to `models/`
5. Runs offline evaluation using `src/ope/offline_eval.py`:
    - Mean Squared TD Error (MSTE)
    - Direct Q-based value estimate of the greedy policy
    - Action agreement with logged (behavior) policy


## 1. Imports & Paths

In [1]:
import os
import sys
from pathlib import Path
import torch
from dataclasses import replace
from collections import OrderedDict

PROJECT_ROOT = Path(os.getcwd()).resolve().parent
sys.path.append(str(PROJECT_ROOT))

print("PROJECT_ROOT:", PROJECT_ROOT)

PROJECT_ROOT: /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy


In [2]:
from src.rl.dqn import (
    load_dqn_training_config,
    train_dqn,
    BullpenOfflineDataset,
    RLDatasetConfig,
)
from src.ope.offline_eval import (
    evaluate_td_error_full_mse,
    direct_policy_value_estimate,
    compute_policy_behavior_stats,
    compute_q_distributions,
    summarize_policy_behavior_stats,
    summarize_q_distributions,
)

## 2. Configurations

In [3]:
DATA_DIR = PROJECT_ROOT / "data"
PROC_DIR = DATA_DIR / "processed"
CONFIG_DIR = PROJECT_ROOT / "configs"
MODELS_DIR = PROJECT_ROOT / "models"

MODELS_DIR.mkdir(parents=True, exist_ok=True)

YEAR_TAG = "2022_2023"  # must match 01_build_dataset YEARS range
RL_TENSORS_PATH = PROC_DIR / f"rl_tensors_{YEAR_TAG}.npz"
MODEL_CFG_PATH = CONFIG_DIR / "model.yaml"


def outpath(model: str) -> Path:
    """Return the checkpoint path for the given model variant name."""
    return MODELS_DIR / f"{model}_dqn_bullpen_{YEAR_TAG}.pt"

print("RL tensors:", RL_TENSORS_PATH)
print("Model config:", MODEL_CFG_PATH)

RL tensors: /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/processed/rl_tensors_2022_2023.npz
Model config: /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/configs/model.yaml


## 3. Load Dataset & Build Model

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

train_cfg = load_dqn_training_config(
    model_config_path=MODEL_CFG_PATH,
    data_path=RL_TENSORS_PATH,
    device=device,
)

train_cfg

Using device: cpu


DQNTrainingConfig(data_path=PosixPath('/Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/processed/rl_tensors_2022_2023.npz'), device='cpu', batch_size=512, weight_decay=0.0, lr=0.001, gamma=0.99, max_steps=50000, target_update_interval=1000, log_interval=1000, val_fraction=0.2, early_stopping_patience=10, early_stopping_min_delta=0.01, hidden_size=258, num_layers=3, dropout=0.1, grad_clip_max_norm=0.0, yaml_num_actions=11)

In [5]:
# We treat train_cfg as a base config coming from configs/model.yaml
base_cfg = train_cfg

# 1) Shallow model: smaller net, fewer layers, light regularization
cfg_shallow = replace(
    base_cfg,
    hidden_size=128,          # fewer hidden units
    num_layers=2,             # shallower
    dropout=0.05,             # small dropout
    weight_decay=0.0,         # no L2
    grad_clip_max_norm=0.0,   # no grad clipping
)

# 2) Deeper model: larger network, same regularization as base
cfg_deep = replace(
    base_cfg,
    hidden_size=256,
    num_layers=4,
    dropout=0.10,
    weight_decay=0.0,
    grad_clip_max_norm=0.0,
)

# 3) Constrained model: like deep, but with weight decay + grad clipping
cfg_constrained = replace(
    base_cfg,
    hidden_size=256,
    num_layers=4,
    dropout=0.10,
    weight_decay=1e-4,        # L2 regularization
    grad_clip_max_norm=5.0,   # clip gradients by global norm
)

cfg_shallow, cfg_deep, cfg_constrained

(DQNTrainingConfig(data_path=PosixPath('/Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/processed/rl_tensors_2022_2023.npz'), device='cpu', batch_size=512, weight_decay=0.0, lr=0.001, gamma=0.99, max_steps=50000, target_update_interval=1000, log_interval=1000, val_fraction=0.2, early_stopping_patience=10, early_stopping_min_delta=0.01, hidden_size=128, num_layers=2, dropout=0.05, grad_clip_max_norm=0.0, yaml_num_actions=11),
 DQNTrainingConfig(data_path=PosixPath('/Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/processed/rl_tensors_2022_2023.npz'), device='cpu', batch_size=512, weight_decay=0.0, lr=0.001, gamma=0.99, max_steps=50000, target_update_interval=1000, log_interval=1000, val_fraction=0.2, early_stopping_patience=10, early_stopping_min_delta=0.01, hidden_size=256, num_layers=4, dropout=0.1, grad_clip_max_norm=0.0, yaml_num_actions=11),
 DQNTrainingConfig(data_path=PosixPath('/Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/processed/rl_tensors_2022_2023.npz'),

In [6]:
ds = BullpenOfflineDataset(
    RLDatasetConfig(
        data_path=base_cfg.data_path,
        device=base_cfg.device,
    )
)

print("Dataset size:", len(ds))
print("State dim:", ds.state_dim)
print("Num actions:", ds.num_actions)
print("H (next hitters window):", ds.H)
print("R (max relievers per team):", ds.R)

Dataset size: 407660
State dim: 208
Num actions: 11
H (next hitters window): 5
R (max relievers per team): 10


## 4. Create Dueling DQN Model + Trainer

Each training cell below passes one of the config variants derived from `train_cfg` to `train_dqn`, which:
 - loads `BullpenOfflineDataset` from `train_cfg.data_path`
 - splits into train/val by `train_cfg.val_fraction`
 - trains a dueling DQN with a target network
 - logs TD-error periodically using `evaluate_td_error` in `dqn.py`
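
The two core computations in the steps above can be sketched in plain Python. This is a minimal illustration of the standard dueling aggregation and a one-step TD target, not the code in `src/rl/dqn.py`. The function names `dueling_q` and `td_target` are hypothetical, masking invalid next actions is an assumption, and any SMDP-specific discounting is left out.

```python
def dueling_q(value, advantages):
    """Dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a')."""
    mean_adv = sum(advantages) / len(advantages)
    return [value + adv - mean_adv for adv in advantages]


def td_target(reward, next_q, next_valid, gamma, done):
    """One-step TD target using target-network Q-values for the next state.

    Assumption: invalid next actions are excluded from the max.
    """
    if done:
        return reward
    best_next = max(q for q, ok in zip(next_q, next_valid) if ok)
    return reward + gamma * best_next


# Toy example with three actions; the middle next action is invalid.
q = dueling_q(value=1.0, advantages=[0.5, -0.5, 0.0])
target = td_target(
    reward=0.2,
    next_q=[1.0, 3.0, 2.0],
    next_valid=[True, False, True],
    gamma=0.99,
    done=False,
)
print(q, target)
```

In a real training loop, `next_q` would come from the periodically synced target network (every `target_update_interval` steps in the config) rather than from the online network.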

Training is repeated for three models of differing complexity:
1. Shallow DQN Model
2. Deep DQN Model
3. Deep DQN Model with Gradient Clipping + Weight Decay
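
The training logs report either a new best validation TD error or no improvement at each evaluation, and the config carries `early_stopping_patience=10` and `early_stopping_min_delta=0.01`. One common way such a rule works is sketched below. The helper `early_stopping_index` is hypothetical, and treating `min_delta` as an absolute improvement threshold is an assumption; the actual rule in `train_dqn` may differ.

```python
def early_stopping_index(val_errors, patience=10, min_delta=0.01):
    """Return the index of the evaluation at which training would stop, or None.

    An evaluation counts as an improvement only if it beats the best
    validation error so far by more than min_delta (assumed absolute).
    """
    best = float("inf")
    evals_without_improvement = 0
    for i, err in enumerate(val_errors):
        if err < best - min_delta:
            best = err
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return i
    return None


# Toy example: the error plateaus at 4.0, so training stops after two flat evaluations.
stop_at = early_stopping_index([5.0, 4.0, 4.0, 4.0], patience=2)
never = early_stopping_index([3.0, 2.0, 1.0], patience=2)
print(stop_at, never)
```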

In [7]:
print("=== Training SHALLOW DQN model ===")
shallow_dqn = train_dqn(cfg_shallow)

=== Training SHALLOW DQN model ===
[DQN] step=0 loss=14947.88672
      val_td_error=10213.48000
      (new best val TD: 10213.48000)
[DQN] step=1000 loss=5245.73730
      val_td_error=4542.58110
      (new best val TD: 4542.58110)
[DQN] step=2000 loss=4347.56445
      val_td_error=3995.27182
      (new best val TD: 3995.27182)
[DQN] step=3000 loss=4682.87256
      val_td_error=2884.55904
      (new best val TD: 2884.55904)
[DQN] step=4000 loss=2148.10547
      val_td_error=2171.13028
      (new best val TD: 2171.13028)
[DQN] step=5000 loss=3191.11377
      val_td_error=1689.13913
      (new best val TD: 1689.13913)
[DQN] step=6000 loss=1882.16577
      val_td_error=1585.43341
      (new best val TD: 1585.43341)
[DQN] step=7000 loss=1254.97729
      val_td_error=1283.68585
      (new best val TD: 1283.68585)
[DQN] step=8000 loss=1300.65356
      val_td_error=1180.03556
      (new best val TD: 1180.03556)
[DQN] step=9000 loss=1202.85095
      val_td_error=1194.82069
      (no improvement

In [8]:
print("=== Training DEEP DQN model ===")
deep_dqn = train_dqn(cfg_deep)

=== Training DEEP DQN model ===
[DQN] step=0 loss=280.62946
      val_td_error=478.21731
      (new best val TD: 478.21731)
[DQN] step=1000 loss=298.60492
      val_td_error=255.84395
      (new best val TD: 255.84395)
[DQN] step=2000 loss=266.53717
      val_td_error=135.70832
      (new best val TD: 135.70832)
[DQN] step=3000 loss=177.20477
      val_td_error=79.65328
      (new best val TD: 79.65328)
[DQN] step=4000 loss=75.76599
      val_td_error=39.50763
      (new best val TD: 39.50763)
[DQN] step=5000 loss=34.35710
      val_td_error=22.11051
      (new best val TD: 22.11051)
[DQN] step=6000 loss=26.93233
      val_td_error=11.57459
      (new best val TD: 11.57459)
[DQN] step=7000 loss=13.69288
      val_td_error=7.12578
      (new best val TD: 7.12578)
[DQN] step=8000 loss=6.37473
      val_td_error=4.75054
      (new best val TD: 4.75054)
[DQN] step=9000 loss=4.63814
      val_td_error=3.13497
      (new best val TD: 3.13497)
[DQN] step=10000 loss=3.63013
      val_td_error=

In [9]:
print("=== Training CONSTRAINED DQN model (weight decay + grad clipping) ===")
constrained_dqn = train_dqn(cfg_constrained)

=== Training CONSTRAINED DQN model (weight decay + grad clipping) ===
[DQN] step=0 loss=107.75221
      val_td_error=454.41000
      (new best val TD: 454.41000)
[DQN] step=1000 loss=389.82013
      val_td_error=178.24252
      (new best val TD: 178.24252)
[DQN] step=2000 loss=238.07759
      val_td_error=109.96997
      (new best val TD: 109.96997)
[DQN] step=3000 loss=109.37428
      val_td_error=57.40383
      (new best val TD: 57.40383)
[DQN] step=4000 loss=49.18357
      val_td_error=34.61348
      (new best val TD: 34.61348)
[DQN] step=5000 loss=30.85622
      val_td_error=22.31415
      (new best val TD: 22.31415)
[DQN] step=6000 loss=21.81225
      val_td_error=9.68241
      (new best val TD: 9.68241)
[DQN] step=7000 loss=13.36385
      val_td_error=5.45047
      (new best val TD: 5.45047)
[DQN] step=8000 loss=5.71718
      val_td_error=3.24512
      (new best val TD: 3.24512)
[DQN] step=9000 loss=5.75646
      val_td_error=2.68200
      (new best val TD: 2.68200)
[DQN] step=10

## 5. Save Trained Model Weights

In [10]:
shallow_outpath = outpath('shallow')
deep_outpath = outpath('deep')
constrained_outpath = outpath('constrained')


torch.save(shallow_dqn.state_dict(), shallow_outpath)
torch.save(deep_dqn.state_dict(), deep_outpath)
torch.save(constrained_dqn.state_dict(), constrained_outpath)

## 6. Offline Policy Evaluation (OPE)

Now we use `src/ope/offline_eval.py` to:
- load the saved model and dataset
- compute:
    - Mean Squared TD Error (MSTE)
    - Direct Q-based value of the greedy policy
    - Action agreement with the logged policy
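
To make the three metrics concrete, here is a minimal pure-Python sketch on a toy batch. The definitions are the usual ones and are assumptions about what `src/ope/offline_eval.py` computes. The helper names `greedy_action` and `offline_metrics` are hypothetical, and the TD targets are passed in precomputed.

```python
def greedy_action(q_row, valid_row):
    """Index of the highest-Q action among the valid ones."""
    valid_actions = [a for a, ok in enumerate(valid_row) if ok]
    return max(valid_actions, key=lambda a: q_row[a])


def offline_metrics(q, valid, actions, td_targets):
    """Return (MSTE, direct value estimate, agreement rate) for a batch."""
    n = len(q)
    greedy = [greedy_action(q[i], valid[i]) for i in range(n)]
    # MSTE: squared gap between Q(s, logged action) and its TD target.
    mste = sum((q[i][actions[i]] - td_targets[i]) ** 2 for i in range(n)) / n
    # Direct value estimate: mean Q-value of the greedy action.
    dm_value = sum(q[i][greedy[i]] for i in range(n)) / n
    # Agreement: how often the greedy action equals the logged action.
    agreement = sum(g == a for g, a in zip(greedy, actions)) / n
    return mste, dm_value, agreement


q = [[1.0, 2.0, 0.5], [0.0, -1.0, 3.0]]
valid = [[True, True, True], [True, True, False]]
actions = [1, 1]           # logged (behavior) actions
td_targets = [1.5, -1.0]   # precomputed r + gamma * max Q_target(s', .)

mste, dm_value, agreement = offline_metrics(q, valid, actions, td_targets)
print(mste, dm_value, agreement)
```

The direct value estimate averages the model's own Q-values, so it inherits any overestimation in the Q-function and should be read as an optimistic figure.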

In [13]:
eval_ds = BullpenOfflineDataset(
    RLDatasetConfig(
        data_path=RL_TENSORS_PATH,
        device=device,
    )
)
eval_loader = torch.utils.data.DataLoader(eval_ds, batch_size=2048, shuffle=False)

print("Eval dataset size:", len(eval_ds))
print("State dim:", eval_ds.state_dim)
print("Num actions:", eval_ds.num_actions)

Eval dataset size: 407660
State dim: 208
Num actions: 11


In [14]:
# Ensure models are on correct device
dqn_shallow = shallow_dqn.to(device)
dqn_deep = deep_dqn.to(device)
dqn_constrained = constrained_dqn.to(device)

gamma = cfg_constrained.gamma

models_to_eval = OrderedDict(
    [
        ("shallow", dqn_shallow),
        ("deep", dqn_deep),
        ("constrained", dqn_constrained),
    ]
)

all_results = {}

for name, model in models_to_eval.items():
    print(f"\n==================== {name.upper()} MODEL ====================")
    model.eval()

    # 1) TD Error (MSTE)
    mste_i = evaluate_td_error_full_mse(
        model=model,
        loader=eval_loader,
        gamma=gamma,
        device=device,
    )

    # 2) Direct Q-based value estimate
    dm_value_i = direct_policy_value_estimate(
        model=model,
        loader=eval_loader,
        device=device,
    )

    # 3) Policy vs behavior stats
    policy_stats_i = compute_policy_behavior_stats(
        model=model,
        loader=eval_loader,
        device=device,
    )

    # 4) Q-value distribution stats
    q_stats_i = compute_q_distributions(
        model=model,
        loader=eval_loader,
        device=device,
    )

    all_results[name] = {
        "mste": mste_i,
        "dm_value": dm_value_i,
        "policy_stats": policy_stats_i,
        "q_stats": q_stats_i,
    }

    agreement_rate_i = policy_stats_i["agreement_rate"]
    behavior_pull_rate_i = policy_stats_i["behavior_pull_rate"]
    policy_pull_rate_i = policy_stats_i["policy_pull_rate"]

    print(f"TD Error (MSTE):              {mste_i:.6f}")
    print(f"Direct Q-based V(pi_greedy):  {dm_value_i:.6f}")
    print(f"Action agreement rate:        {agreement_rate_i*100:.3f}%")
    print(f"Behavior pull rate:           {behavior_pull_rate_i*100:.3f}%")
    print(f"Policy pull rate:             {policy_pull_rate_i*100:.3f}%")

    print("\n[Policy vs Behavior stats]")
    summarize_policy_behavior_stats(policy_stats_i)

    print("\n[Q-value distribution stats]")
    summarize_q_distributions(q_stats_i)
    print("======================================================")


TD Error (MSTE):              50.840302
Direct Q-based V(pi_greedy):  65.003840
Action agreement rate:        30.793%
Behavior pull rate:           5.895%
Policy pull rate:             67.621%

[Policy vs Behavior stats]
=== Policy vs Behavior Stats ===
Num samples:   407660
Num actions:   11

Behavior pull rate: 5.90%
Policy pull rate:   67.62%
Action agreement:   30.79%

Behavior action counts (per action index):
[383628   4016   3546   3018   2587   2369   2204   1923   1662   1423
   1284]
Policy action counts (per action index):
[131997  64027   2027   8960  70358    340  25648  52188   5460  29572
  17083]
Valid action counts (per action index):
[407660 368480 357208 365623 361313 371378 361377 375523 364592 364655
 373403]

[Q-value distribution stats]
=== Q Distribution Stats ===
q_all_valid: n=4071212, mean=63.769, std=47.255, min=-33.747, max=1537.623
q_stay: n=407660, mean=62.601, std=47.499, min=-33.747, max=1476.034
q_best_pull: n=407660, mean=64.708, std=49.733, min=1.13

In [15]:
import numpy as np

# Reuse the path from the configuration cell instead of a hard-coded relative path
npz = np.load(RL_TENSORS_PATH)

for key in ["reward_folded"]:
    x = npz[key]
    print(key, "shape:", x.shape)
    print(
        key,
        "mean:", float(x.mean()),
        "std:", float(x.std()),
        "min:", float(x.min()),
        "max:", float(x.max()),
    )

reward_folded shape: (407660,)
reward_folded mean: -0.010023725219070911 std: 0.6859701871871948 min: -7.711379528045654 max: 1.1493159532546997


After tuning the model (dropout = 0.10, hidden size = 256, 3 layers, validation fraction = 0.2, no weight decay), we obtained:


TD Error (MSTE):              0.961571
Direct Q-based V(pi_greedy):  5.507400
Action agreement rate:        8.160%

1. TD Error (MSTE ≈ 0.96)

An MSTE of 0.96 corresponds to an RMSE of about 0.98. For scale, the rewards in our offline data have mean around -0.01, standard deviation around 0.69, and range from -7.7 to +1.15, so the RMSE is roughly 1.4 reward standard deviations. The Q-network therefore fits the Bellman targets reasonably well but not tightly, and it does so far better than earlier runs, where TD errors were much larger (for example, the MSTE of 50.84 in the evaluation output above).
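
The scale comparison can be checked with a few lines of arithmetic. The inputs are copied from the outputs printed in this notebook. Using the reward standard deviation as the yardstick is a rough convention, not a formal criterion.

```python
import math

mste = 0.961571                    # MSTE reported for the tuned model
reward_std = 0.6859701871871948    # std of reward_folded printed above

rmse = math.sqrt(mste)
rmse_in_reward_stds = rmse / reward_std

print(f"RMSE: {rmse:.4f}")
print(f"RMSE / reward std: {rmse_in_reward_stds:.2f}")
```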

2. Direct Q-based Value Estimate (≈ 5.51)

The estimated value of the greedy policy induced by the DQN is around 5.5. This number is higher than what is realistically achievable given that the maximum folded reward per decision is about 1.15 and the SMDP horizon is short (no more than 3 plate appearances). This indicates overestimation, a common behavior in offline Q-learning caused by extrapolation error and the model evaluating actions that are rare or unseen in the logged data. However, this value is much more reasonable than prior extreme values (e.g., above 16).

3. Action Agreement Rate (≈ 8.16%)

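This is the frequency with which the DQN chooses the same action as actual MLB managers in the dataset. An 8% agreement rate shows that the model is not simply copying historical decisions: managers keep the current pitcher (action 0) in about 94% of logged samples, so a behavior-cloned policy would agree far more often. Agreement alone cannot show that the learned policy is sensible, however, because a policy choosing uniformly at random among the 11 actions would agree with the log about 9% of the time (1/11). The model is learning a policy that differs sharply from managers, and the other metrics are needed to judge whether it is grounded in the data.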
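
Two reference points help when reading an agreement rate, and both can be computed from the behavior action counts in the evaluation output above. The sketch below assumes that action index 0 is "stay", which is consistent with the reported 5.895% behavior pull rate, and it ignores validity masks for the uniform-random baseline.

```python
# Behavior action counts copied from the evaluation output (index 0 assumed to be "stay").
behavior_counts = [383628, 4016, 3546, 3018, 2587, 2369, 2204, 1923, 1662, 1423, 1284]
n = sum(behavior_counts)
num_actions = len(behavior_counts)

# A policy that always stays agrees exactly when the manager stayed.
always_stay_agreement = behavior_counts[0] / n

# A uniform-random policy matches any logged action with probability 1 / num_actions
# (assumption: every action is treated as valid).
uniform_agreement = sum((c / n) * (1 / num_actions) for c in behavior_counts)

print(f"always-stay agreement:    {always_stay_agreement:.3%}")
print(f"uniform-random agreement: {uniform_agreement:.3%}")
```

With validity masks applied, the uniform-random baseline would be slightly higher, since each sample then has at most 11 candidate actions.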

Overall Assessment

These results indicate that:
- The Q-function is relatively stable and fits the offline data well.
- The greedy policy remains optimistic due to known limitations of offline DQN.
- The model is producing meaningful, nontrivial decisions instead of behavior cloning.
- The behavior aligns with well-known challenges of offline Q-learning and sets a strong baseline for comparison with more conservative methods like CQL.