# Tabular Q-Learning Training Notebook
This notebook trains and evaluates the tabular Q-learning model.

## 1. Imports & Paths

In [6]:
import os
import sys
from pathlib import Path
import torch

# Set project root: the parent of the current working directory.
# NOTE(review): assumes the notebook is launched from a direct subdirectory of
# the repository (so that `src` sits under PROJECT_ROOT) — TODO confirm.
PROJECT_ROOT = Path(os.getcwd()).resolve().parent

# Make project packages importable. Guarded so that re-running this cell does
# not keep appending duplicate entries to sys.path.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

print("PROJECT_ROOT:", PROJECT_ROOT)


PROJECT_ROOT: /Users/scottyang/MLB-Bullpen-Strategy


In [18]:
from src.rl.tabular_q_agent import (
    load_tabular_q_config,
    train_tabular_q_agent,
    TabularOfflineDataset,
)

ImportError: cannot import name 'load_tabular_q_config' from 'src.rl.tabular_q_agent' (/Users/scottyang/MLB-Bullpen-Strategy/src/rl/tabular_q_agent.py)

## 2. Configurations

In [8]:
# Directory layout, every entry anchored at PROJECT_ROOT.
DATA_DIR = PROJECT_ROOT.joinpath("data")
PROC_DIR = DATA_DIR.joinpath("processed")
CONFIG_DIR = PROJECT_ROOT.joinpath("configs")
MODELS_DIR = PROJECT_ROOT.joinpath("models")
# Create the models directory (and any missing parents); no error if it already exists.
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Must match your dataset name
YEAR_TAG = "2022_2023"

# File locations derived from the directories and the year tag above.
RL_TENSORS_PATH = PROC_DIR.joinpath(f"rl_tensors_{YEAR_TAG}.npz")
MODEL_CFG_PATH = CONFIG_DIR.joinpath("model.yaml")
MODEL_OUT_PATH = MODELS_DIR.joinpath(f"tabular_q_{YEAR_TAG}.npy")

# Echo the resolved paths so a wrong layout is visible before anything is loaded.
print("RL tensors:", RL_TENSORS_PATH)
print("Model config:", MODEL_CFG_PATH)
print("Model output:", MODEL_OUT_PATH)


RL tensors: /Users/scottyang/MLB-Bullpen-Strategy/data/processed/rl_tensors_2022_2023.npz
Model config: /Users/scottyang/MLB-Bullpen-Strategy/configs/model.yaml
Model output: /Users/scottyang/MLB-Bullpen-Strategy/models/tabular_q_2022_2023.npy


In [9]:
## 3. Load Dataset & Build Model

In [10]:
# Use the GPU when PyTorch reports one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# Load the training configuration, passing the model-config path, the
# RL-tensor path and the chosen device.
train_cfg = load_tabular_q_config(
    model_config_path=MODEL_CFG_PATH,
    data_path=RL_TENSORS_PATH,
    device=device,
)

# Bare last expression: display the resulting config as the cell output.
train_cfg

Using device: cpu


NameError: name 'load_tabular_q_config' is not defined

In [11]:
# Build the offline dataset from the data path and device stored on the
# training config.
ds = TabularOfflineDataset(
    data_path=train_cfg.data_path,
    device=train_cfg.device,
)

# Sanity check: number of items in the dataset and its reported action count.
print("Dataset size:", len(ds))
print("Num actions:", ds.num_actions)

NameError: name 'TabularOfflineDataset' is not defined

## 5. Save trained model weights

In [12]:
# NOTE(review): `tabular_agent` is never assigned in the visible cells, and the
# imported `train_tabular_q_agent` is never called — presumably the training
# cell (section 4) is missing; TODO restore it before this cell.
# Persist the agent to MODEL_OUT_PATH through its own `save` method.
tabular_agent.save(MODEL_OUT_PATH)
# Bare last expression: display the output path as the cell result.
MODEL_OUT_PATH

NameError: name 'tabular_agent' is not defined

## 6. Offline TD error (Bellman Residual)



In [13]:
import numpy as np

def tabular_td_error(agent, ds, gamma):
    """Mean squared one-step TD error of ``agent.Q`` over a logged dataset.

    For each transition ``(s, a, r, ns, done, mask)`` the target is ``r`` when
    ``done`` is truthy and ``r + gamma * max(Q[ns])`` otherwise. The squared
    difference between ``Q[s][a]`` and that target is averaged over ``ds``.
    The ``mask`` element is unpacked but not used.

    Args:
        agent: Object exposing ``_s(state)`` (maps a state to a key of ``Q``)
            and ``Q`` (key -> per-action values).
        ds: Indexable dataset supporting ``len()``; every item unpacks to
            ``(s, a, r, ns, done, mask)``.
        gamma: Discount factor applied to the bootstrapped next-state value.

    Returns:
        The mean squared TD error as a Python ``float``, or ``nan`` when
        ``ds`` is empty.
    """
    n = len(ds)
    if n == 0:
        # A mean over zero transitions is undefined; report NaN rather than
        # failing with ZeroDivisionError.
        return float("nan")

    total = 0.0
    for i in range(n):
        s, a, r, ns, done, _mask = ds[i]

        s_key = agent._s(s)
        ns_key = agent._s(ns)

        # Coerce to plain Python scalars so the accumulator stays a float
        # whatever scalar type the dataset yields.
        q_sa = float(agent.Q[s_key][int(a)])
        reward = float(r)

        if bool(done):
            target = reward
        else:
            # NOTE(review): this max runs over every action of the next state;
            # the unpacked mask appears to describe the current state, so no
            # availability mask is applied here — confirm this is intended.
            target = reward + gamma * float(np.max(agent.Q[ns_key]))

        total += (q_sa - target) ** 2
    return float(total / n)

# Score the agent on the dataset, discounting with the gamma stored on the
# training config.
mste = tabular_td_error(tabular_agent, ds, gamma=train_cfg.gamma)
print("Mean Squared TD Error:", mste)


NameError: name 'tabular_agent' is not defined

## 7. Direct Q-Based Estimate of Greedy Policy



In [14]:
def tabular_direct_value(agent, ds):
    """Average of ``max_a Q(s, a)`` over the dataset's states, with masked-out
    actions excluded from the max.

    For each item the Q-values of the state are copied, entries where the mask
    is false are overwritten with ``-1e9``, and the maximum is recorded. The
    agent's own Q-table is not modified.

    Args:
        agent: Object exposing ``_s(state)`` (maps a state to a key of ``Q``)
            and ``Q`` (key -> array of per-action values supporting ``.copy()``).
        ds: Indexable dataset supporting ``len()``; every item unpacks to
            ``(s, a, r, ns, done, mask)``, of which only ``s`` and ``mask``
            are used.

    Returns:
        The mean of the per-state maxima, or ``nan`` when ``ds`` is empty.
    """
    vals = []
    for i in range(len(ds)):
        s, _a, _r, _ns, _d, mask = ds[i]

        q = agent.Q[agent._s(s)].copy()
        # Force a boolean mask first: `~` on a 0/1 integer mask is a bitwise
        # NOT (yielding -1/-2), which NumPy would read as positions to
        # overwrite instead of as a mask.
        available = np.asarray(mask, dtype=bool)
        q[~available] = -1e9
        vals.append(np.max(q))  # V_pi(s) = max_a Q(s,a)

    if not vals:
        # Nothing to average; return NaN without the empty-mean warning.
        return float("nan")
    return np.mean(vals)

# Average masked max-Q value of the agent over the dataset's states.
dm_value = tabular_direct_value(tabular_agent, ds)
print("Direct Q-based greedy value:", dm_value)


NameError: name 'tabular_agent' is not defined

## 8. Action agreement with logged policy

How often does the greedy Tabular Q-Learning action (respecting availability mask) match the logged (historical) action from the dataset?


In [15]:
def tabular_action_agreement(agent, ds):
    """Fraction of dataset items where ``agent.act(s, mask)`` equals the
    logged action.

    Args:
        agent: Object exposing ``act(state, mask)`` returning an action.
        ds: Indexable dataset supporting ``len()``; every item unpacks to six
            elements, of which the state (1st), logged action (2nd) and mask
            (6th) are used.

    Returns:
        Matches divided by dataset size, or ``nan`` when ``ds`` is empty.
    """
    n = len(ds)
    if n == 0:
        # An agreement rate over zero samples is undefined; report NaN rather
        # than failing with ZeroDivisionError.
        return float("nan")

    matches = 0
    for i in range(n):
        s, a_logged, _, _, _, mask = ds[i]
        matches += int(agent.act(s, mask) == a_logged)
    return matches / n

# Share of dataset items where the agent's action equals the logged one,
# printed as a percentage with three decimals.
agreement = tabular_action_agreement(tabular_agent, ds)
print(f"Action agreement with MLB policy: {agreement:.3%}")


NameError: name 'tabular_agent' is not defined

In [None]:
## 9. Summary

In [16]:
# Final report: banner plus one line per metric, emitted by a single print
# call with newline separators.
print(
    "=========== FINAL TABULAR Q EVALUATION ===========",
    f"TD Error (MSTE):              {mste:.6f}",
    f"Direct value V(pi_greedy):    {dm_value:.6f}",
    f"Action agreement rate:        {agreement:.3%}",
    sep="\n",
)




NameError: name 'mste' is not defined