# RL CQL Training

This notebook:
1. Loads offline RL tensors from `data/processed/rl_tensors_*.npz`
2. Loads CQL hyperparameters from `configs/model.yaml` and `configs/training.yaml`
3. Trains a Conservative Q-Learning (CQL) model using `src/rl/cql.py::train_cql`
4. Saves the trained model to `models/`
5. Runs offline evaluation using `src/ope/offline_eval.py`:
    - Mean Squared TD Error (MSTE)
    - Direct Q-based value estimate of greedy policy
    - Action agreement with logged (behavior) policy

In [1]:
pip install "numpy<2.0"

Note: you may need to restart the kernel to use updated packages.


## 1. Imports & Paths

In [2]:
import os
import sys
from pathlib import Path
import torch
import yaml

PROJECT_ROOT = Path(os.getcwd()).resolve().parent
sys.path.append(str(PROJECT_ROOT))

print("PROJECT_ROOT:", PROJECT_ROOT)

PROJECT_ROOT: /Users/matthewsaccone/Desktop/MLB-Bullpen-Strategy


In [10]:
from src.rl.cql import build_config_from_yamls, train_cql, BullpenOfflineDataset

from src.ope.offline_eval_cql import (
    OfflineEvalConfig,
    load_model_and_dataset,
    evaluate_td_error_full_mse,
    direct_policy_value_estimate,
    compute_action_agreement,
)

## 2. Configurations

In [11]:
DATA_DIR = PROJECT_ROOT / "data"
PROC_DIR = DATA_DIR / "processed"
CONFIG_DIR = PROJECT_ROOT / "configs"
MODELS_DIR = PROJECT_ROOT / "models"

MODELS_DIR.mkdir(parents=True, exist_ok=True)

YEAR_TAG = "2022_2023"  # must match 01_build_dataset YEARS range
RL_TENSORS_PATH = PROC_DIR / f"rl_tensors_{YEAR_TAG}.npz"
MODEL_CFG_PATH = CONFIG_DIR / "model.yaml"
DATA_CFG_PATH = PROJECT_ROOT / "configs/data.yaml"
TRAIN_CFG_PATH = CONFIG_DIR / "training.yaml"
MODEL_OUT_PATH = MODELS_DIR / f"cql_bullpen_{YEAR_TAG}.pt"

print("RL tensors:", RL_TENSORS_PATH)
print("Model config:", MODEL_CFG_PATH)
print("Training config:", TRAIN_CFG_PATH)
print("Model output:", MODEL_OUT_PATH)

RL tensors: /Users/matthewsaccone/Desktop/MLB-Bullpen-Strategy/data/processed/rl_tensors_2022_2023.npz
Model config: /Users/matthewsaccone/Desktop/MLB-Bullpen-Strategy/configs/model.yaml
Training config: /Users/matthewsaccone/Desktop/MLB-Bullpen-Strategy/configs/training.yaml
Model output: /Users/matthewsaccone/Desktop/MLB-Bullpen-Strategy/models/cql_bullpen_2022_2023.pt


## 3. Load Dataset & Build Config

In [12]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

cfg = build_config_from_yamls(
    model_yaml=str(MODEL_CFG_PATH),
    training_yaml=str(TRAIN_CFG_PATH),
    data_yaml=str(DATA_CFG_PATH),     # real YAML only
    env_yaml=None,
    inference_yaml=None
)

cfg

Using device: cpu


CQLConfig(device='cpu', seed=42, input_dim=None, hidden_size=256, num_layers=3, dropout=0.05, num_actions=11, gamma=0.99, lr=0.0005, batch_size=512, max_steps=30, log_interval=200, target_update_interval=5000, tau=0.005, cql_alpha=1.0, cql_min_q_weight=None, cql_temp=1.0, l2_reg=1e-06, checkpoint_dir='checkpoints/', checkpoint_name='cql_model.pth', use_wandb=False)

In [None]:
ds = BullpenOfflineDataset(
    RLDatasetConfig(
        data_path=cfg.data.data_path,
        device=device,
    )
)

print("Dataset size:", len(ds))
print("State dim:", ds.state_dim)
print("Num actions:", ds.num_actions)
print("H (next hitters window):", ds.H)
print("R (max relievers per team):", ds.R)

AttributeError: 'CQLConfig' object has no attribute 'data'

## 4. Train CQL Model

This calls `train_cql(cfg)`, which:
 - loads `BullpenOfflineDataset`
 - trains Conservative Q-Learning
 - logs TD-error periodically

In [None]:
cql_model = train_cql(cfg)

## 5. Save trained model weights

In [None]:
torch.save(cql_model.state_dict(), MODEL_OUT_PATH)
MODEL_OUT_PATH

## 6. Offline Policy Evaluation (OPE)

In [None]:
ope_cfg = OfflineEvalConfig(
    model_config_path=MODEL_CFG_PATH,
    model_path=MODEL_OUT_PATH,
    tensors_path=RL_TENSORS_PATH,
    device=device,
    batch_size=2048,
    gamma=cfg.training.gamma,
)

eval_model, eval_ds, eval_loader = load_model_and_dataset(ope_cfg)

print("Eval dataset size:", len(eval_ds))
print("State dim:", eval_ds.state_dim)
print("Num actions:", eval_ds.num_actions)

In [None]:
# 7. Mean Squared TD Error (MSTE)
mste = evaluate_td_error_full_mse(
    model=eval_model,
    loader=eval_loader,
    gamma=ope_cfg.gamma,
    device=ope_cfg.device,
)
print(f"Mean Squared TD Error (MSTE): {mste:.6f}")

In [None]:
# 8. Direct Q-based value estimate
dm_value = direct_policy_value_estimate(
    model=eval_model,
    loader=eval_loader,
    device=ope_cfg.device,
)
print(f"Direct Q-based value estimate (V(pi_greedy)): {dm_value:.6f}")

In [None]:
# 9. Action agreement with logged policy
agreement = compute_action_agreement(
    model=eval_model,
    loader=eval_loader,
    device=ope_cfg.device,
)
print(f"Action agreement (logged vs greedy): {agreement:.3%}")

In [None]:
# 10. Summary
print("========= FINAL CQL EVALUATION RESULTS =========")
print(f"TD Error (MSTE):              {mste:.6f}")
print(f"Direct Q-based V(pi_greedy):  {dm_value:.6f}")
print(f"Action agreement rate:        {agreement:.3%}")