# Baseball States GPT-2 Training on Google Colab

This notebook trains a GPT-2 model to predict baseball game states.

## Setup
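1. Mount Google Drive
2. Clone the repository (or pull the latest changes if it is already cloned)
3. Install dependencies
4. Load config and override paths for Colab environment
5. Train model
6. (Optional) View training in TensorBoard and run a quick inference test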

In [None]:
# Configuration -- edit these values before running the rest of the notebook.

# GitHub account and repository that hold the training code
GITHUB_USERNAME = 'angerami'
REPO_NAME = 'baseball-states'

# Label for this training run; it names the checkpoint folder
RUN_NUMBER = 'run-001'

# Config template, relative to repo root
CONFIG_PATH = 'configs/colab_config.yaml'

# Per-repository data folder under the Google Drive mount point
DATA_PATH = f'/content/drive/MyDrive/Data/{REPO_NAME}'

In [None]:
import os
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Load GitHub token
with open('/content/drive/MyDrive/secrets/colab-github') as f:
    github_token = f.read().strip()

# Clone repository (or pull if it exists)
if not os.path.exists(REPO_NAME):
    !git clone https://{github_token}@github.com/{GITHUB_USERNAME}/{REPO_NAME}.git
else:
    print(f"Repository {REPO_NAME} already exists, pulling latest changes...")
    !cd {REPO_NAME} && git pull

# Add repo to Python path
import sys
sys.path.insert(0, f'/content/{REPO_NAME}')

In [None]:
# Install dependencies if needed
!pip install -q pyyaml transformers datasets tensorboard

In [None]:
import os

from baseball_states.training import ModelConfig, train_model

# Load config from file.
# The path is anchored at /content/{REPO_NAME} -- the same directory that was
# put on sys.path for the import above -- so it does not depend on the
# notebook's current working directory.
config = ModelConfig.from_file(f'/content/{REPO_NAME}/{CONFIG_PATH}')

# Override paths for this specific run (config serves as template)
config.data_path = f"{DATA_PATH}/data/tokens_inning"
config.output_dir = f"{DATA_PATH}/checkpoints/{RUN_NUMBER}"

# Optional: Override other parameters for this run
# config.num_epochs = 3
# config.batch_size = 32

# Display the config
print("Training configuration:")
print(config)
print()

# Train the model. `model` and `tokenizer` are reused by the optional
# inference cell further down.
model, tokenizer = train_model(config)

# Save the actual config used for this run.
# The directory is created defensively: nothing visible here guarantees that
# train_model has already created config.output_dir.
os.makedirs(config.output_dir, exist_ok=True)
config.to_file(f"{config.output_dir}/config_used.yaml")

In [None]:
# Optional: View training in TensorBoard
# Loads the TensorBoard notebook extension, then points it at the "logs"
# folder under this run's output directory. Requires `config` from the
# training cell to be defined.
# NOTE(review): assumes train_model writes its TensorBoard event files to
# <output_dir>/logs -- confirm against baseball_states.training.
%load_ext tensorboard
%tensorboard --logdir {config.output_dir}/logs

In [None]:
# Optional: Quick inference test
# Requires `model` and `tokenizer` returned by train_model in the training cell.
import torch

# Generate a sample prediction
model.eval()

# The model may be on the GPU after training, while the tokenizer returns a
# CPU tensor. Move the input to the model's device so generate() does not
# fail with a device-mismatch error.
device = next(model.parameters()).device

with torch.no_grad():
    # Create a simple test input (modify as needed)
    test_input = tokenizer.encode("<BOS>", return_tensors="pt").to(device)
    output = model.generate(test_input, max_length=20, num_return_sequences=1)
    print("Sample generation:")
    print(tokenizer.decode(output[0]))