# Baseball States GPT-2 Training on Google Colab

This notebook trains a GPT-2 model to predict baseball game states.

## Setup
1. Mount Google Drive
2. Clone the repository
3. Install the package and its dependencies
4. Load the config and override its paths for the Colab environment
5. Train the model

In [3]:
# Configuration
GITHUB_USERNAME = 'angerami'
REPO_NAME = 'baseball-states'
CONFIG_FILE = 'configs/colab_config.yaml'  # Relative to repo root
RUN_NUMBER = 'run-001'
GDRIVE_PATH = '/Users/angerami/Google Drive/My Drive'

In [4]:
import os
import sys

def get_colab_environment():
    # 1. Check if we are even in a Colab-managed runtime
    if 'google.colab' not in sys.modules:
        return "Local/Other Jupyter"

    # 2. Check for VS Code specific environment variables
    if os.environ.get('TERM_PROGRAM') == 'vscode' or 'VSCODE_PID' in os.environ:
        return "VS Code (Colab Extension)"

    # 3. Check for the browser-exclusive 'COLAB_RELEASE_TAG'
    if 'COLAB_RELEASE_TAG' in os.environ:
        return "Browser (colab.research.google.com)"

    return "Unknown Colab Environment"

env = get_colab_environment()
print(f"Running in: {env}")


drive_available = (env == "Browser (colab.research.google.com)")
print(f"Google Drive mounting supported: {drive_available}")
if drive_available:
    from google.colab import drive
    # Mount Google Drive
    drive.mount('/content/drive')
    GDRIVE_PATH = '/content/drive/MyDrive'

SECRETS_PATH = f'{GDRIVE_PATH}/secrets/colab-github'
DATA_PATH = f'{GDRIVE_PATH}/Data/{REPO_NAME}'
# Load GitHub token
with open(SECRETS_PATH) as f:
    github_token = f.read().strip()

Running in: Browser (colab.research.google.com)
Google Drive mounting supported: True
Mounted at /content/drive


In [6]:
# Clone repository (or pull if it exists)
if not os.path.exists(REPO_NAME):
    !git clone https://{github_token}@github.com/{GITHUB_USERNAME}/{REPO_NAME}.git
else:
    print(f"Repository {REPO_NAME} already exists, pulling latest changes...")
    !cd {REPO_NAME} && git pull

# Add repo to Python path
import sys
sys.path.insert(0, f'/content/{REPO_NAME}')

Repository baseball-states already exists, pulling latest changes...
Already up to date.


In [8]:
# Install dependencies if needed
!pip install -e $REPO_NAME

Obtaining file:///content/baseball-states
  Preparing metadata (setup.py) ... done
Collecting pybaseball (from baseball-states==0.1.0)
  Downloading pybaseball-2.2.7-py3-none-any.whl.metadata (11 kB)
Collecting streamlit (from baseball-states==0.1.0)
  Downloading streamlit-1.52.2-py3-none-any.whl.metadata (9.8 kB)
Collecting nox (from baseball-states==0.1.0)
  Downloading nox-2025.11.12-py3-none-any.whl.metadata (5.1 kB)
Collecting argcomplete<4,>=1.9.4 (from nox->baseball-states==0.1.0)
  Downloading argcomplete-3.6.3-py3-none-any.whl.metadata (16 kB)
Collecting colorlog<7,>=2.6.1 (from nox->baseball-states==0.1.0)
  Downloading colorlog-6.10.1-py3-none-any.whl.metadata (11 kB)
Collecting dependency-groups>=1.1 (from nox->baseball-states==0.1.0)
  Downloading dependency_groups-1.3.1-py3-none-any.whl.metadata (2.3 kB)
Collecting virtualenv>=20.15 (from nox->baseball-states==0.1.0)
  Downloading virtualenv-20.35.4-py3-none-any.whl.metadata (4.6 kB)
Collecting pygithub>=1.51

In [10]:
from baseball_states.training import ModelConfig, train_model

# Load config from file
config = ModelConfig.from_file(f'{REPO_NAME}/{CONFIG_FILE}')

# Override paths for this specific run (config serves as template)
config.data_path = f"{DATA_PATH}/data/tokens_inning"
config.output_dir = f"{DATA_PATH}/checkpoints/{RUN_NUMBER}"

# Optional: Override other parameters for this run
# config.num_epochs = 3
# config.batch_size = 32

# Display the config
print("Training configuration:")
print(config)
print()

# Train the model
model, tokenizer = train_model(config)

# Save the actual config used for this run
config.to_file(f"{config.output_dir}/config_used.yaml")

Training configuration:
ModelConfig(
  n_embd=64,
  n_layer=6,
  n_head=4,
  n_positions=32,
  batch_size=64,
  num_epochs=1,
  learning_rate=0.0005,
  warmup_steps=10,
  weight_decay=0.01,
  data_path='/content/drive/MyDrive/Data/baseball-states/data/tokens_inning',
  val_split=0.1,
  max_length=32,
  train_fraction=1.0,
  use_packing=False,
  output_dir='/content/drive/MyDrive/Data/baseball-states/checkpoints/run-001',
  save_steps=500,
  eval_steps=100,
  logging_steps=10,
  save_initial_checkpoint=True,
  resume_from_checkpoint='auto',
  seed=42,
)

Using device: cuda
Train samples: 343975
Eval samples: 38220
Model parameters: 303,936
Saving initial checkpoint to /content/drive/MyDrive/Data/baseball-states/checkpoints/run-001/checkpoint-initial
Initial checkpoint saved!
Run name: gpt2-train-2025-1223-221255
No checkpoint found, starting training from scratch
Starting training...
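
As a rough sanity check on the numbers logged above, the sketch below estimates how many optimizer steps one epoch takes and confirms that the train/eval split matches `val_split=0.1`. It assumes a single device, no gradient accumulation, and that the last partial batch is kept; none of these are stated in the log, so treat the step count as an estimate.

```python
import math

# Values copied from the config printout and the log above.
train_samples = 343_975   # "Train samples"
eval_samples = 38_220     # "Eval samples"
batch_size = 64           # config.batch_size
eval_steps = 100          # config.eval_steps
save_steps = 500          # config.save_steps

# Assumption: one device, no gradient accumulation, last partial batch kept.
steps_per_epoch = math.ceil(train_samples / batch_size)

# Number of periodic evaluations and checkpoint saves that fit in one epoch.
evals_per_epoch = steps_per_epoch // eval_steps
saves_per_epoch = steps_per_epoch // save_steps

# Fraction of all samples held out for evaluation.
val_fraction = eval_samples / (train_samples + eval_samples)

print(f"steps per epoch: {steps_per_epoch}")
print(f"evaluations per epoch: {evals_per_epoch}")
print(f"checkpoint saves per epoch: {saves_per_epoch}")
print(f"validation fraction: {val_fraction:.4f}")
```

Under these assumptions a single epoch is 5375 steps, so with `num_epochs=1` the run should produce about 53 evaluation rows and 10 periodic checkpoints.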

Training started at 2025-12-23 22:12:56



`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss,Validation Loss,Accuracy,Top3 Accuracy,Top5 Accuracy,Perplexity,Illegal Prob,Runs Scored
100,1.1488,1.138857,0.190224,0.696579,0.773959,14.069635,0.086288,27.856228
200,1.1164,1.077359,0.19018,0.690145,0.837054,14.642381,0.089302,26.915541
300,1.0598,1.067977,0.190219,0.695509,0.800366,15.869293,0.090578,26.910807
400,1.0644,1.057088,0.19018,0.687219,0.842189,14.377877,0.090248,27.905704
500,1.0531,1.054341,0.19018,0.687199,0.830043,13.40766,0.091158,26.948011
600,1.1322,1.052911,0.19018,0.690468,0.829505,12.931553,0.091988,27.905312
700,1.0553,1.046559,0.19018,0.681024,0.834785,14.454954,0.091228,27.110701
800,0.9795,1.053356,0.190185,0.673008,0.830281,15.054455,0.09082,27.591993
900,1.0521,1.047457,0.19018,0.708028,0.832904,14.704562,0.08995,27.905861
1000,1.0617,1.041401,0.19018,0.686925,0.835218,14.262701,0.090571,27.989195

There were missing keys in the checkpoint model loaded: ['lm_head.weight'].
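
The `lm_head.weight` message is what one would typically expect when the output head shares its weights with the token embedding, so the tensor is stored only once. The reported total of 303,936 parameters is consistent with that reading. The sketch below counts parameters for a standard GPT-2 layout (4x MLP expansion, biases everywhere, tied input/output embeddings) and solves for the vocabulary size. The layout and the resulting vocabulary size are inferences from the logged total, not values stated in the notebook, and `gpt2_param_count` is a hypothetical helper.

```python
def gpt2_param_count(n_embd, n_layer, n_positions, vocab_size):
    """Count parameters of a standard GPT-2 stack with tied embeddings.

    Assumes the usual layout: fused QKV projection, 4x MLP expansion,
    biases on all linear layers, and lm_head sharing the token embedding.
    n_head does not appear because it does not change the parameter count.
    """
    d = n_embd
    per_block = (
        2 * d                 # ln_1 weight and bias
        + d * 3 * d + 3 * d   # fused query/key/value projection
        + d * d + d           # attention output projection
        + 2 * d               # ln_2 weight and bias
        + d * 4 * d + 4 * d   # MLP up-projection
        + 4 * d * d + d       # MLP down-projection
    )
    embeddings = vocab_size * d + n_positions * d  # token + position tables
    final_ln = 2 * d                               # ln_f weight and bias
    return n_layer * per_block + embeddings + final_ln


reported = 303_936  # "Model parameters" from the log

# Everything except the token embedding, using the printed config values.
without_vocab = gpt2_param_count(n_embd=64, n_layer=6, n_positions=32, vocab_size=0)

# The remainder must be vocab_size * n_embd if the embeddings are tied.
remainder = reported - without_vocab
implied_vocab = remainder // 64

print(f"parameters excluding token embedding: {without_vocab}")
print(f"implied vocabulary size: {implied_vocab} (remainder {remainder % 64})")
```

The remainder divides evenly, which suggests a vocabulary of 29 tokens under these assumptions. An untied `lm_head` would add another `vocab_size * n_embd` parameters and would not match the logged total.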




Training completed!
Total time: 0:28:49
End time: 2025-12-23 22:41:46

Saving model to /content/drive/MyDrive/Data/baseball-states/checkpoints/run-001
Running final evaluation...



Final eval loss: 1.0307
Final perplexity: 2.8030
Final accuracy: 0.1902
Config saved to /content/drive/MyDrive/Data/baseball-states/checkpoints/run-001/config_used.yaml
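
The final perplexity follows directly from the final eval loss: for a mean per-token cross-entropy loss measured in nats, perplexity is `exp(loss)`. A minimal check, using the rounded loss from the log above:

```python
import math

final_eval_loss = 1.0307  # "Final eval loss" from the log (already rounded)

# Perplexity is the exponential of the mean cross-entropy loss in nats.
perplexity = math.exp(final_eval_loss)

print(f"perplexity: {perplexity:.4f}")
```

This reproduces the logged 2.8030. The `Perplexity` column in the per-step table does not satisfy the same relation with `Validation Loss` (for example, `exp(1.041401)` is about 2.83, not 14.26), so that column is presumably computed in a different way; the notebook does not say how.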


In [None]:
# Optional: View training in TensorBoard
%load_ext tensorboard
%tensorboard --logdir {config.output_dir}/logs

In [None]:
# Optional: Quick inference test
import torch

# Generate a sample prediction
model.eval()
with torch.no_grad():
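    # Create a simple test input (modify as needed) and move it to the
    # model's device; the model was trained on CUDA, so a CPU tensor would fail
    test_input = tokenizer.encode("<BOS>", return_tensors="pt").to(model.device)
    output = model.generate(test_input, max_length=20, num_return_sequences=1)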
    print("Sample generation:")
    print(tokenizer.decode(output[0]))