In [56]:
import pybaseball
import pandas as pd
import numpy as np
from datetime import date, timedelta, datetime, timezone
import pytz
import time
import logging
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Tuple, Callable, Optional
from sklearn.preprocessing import StandardScaler # Or other scaler
from collections import defaultdict
from tqdm.notebook import tqdm

import importlib
import src.pybaseball_dataset as pybaseball_dataset  # Import the module with an alias
import src.pytorch_dataset as pytorch_dataset      # Import the module with an alias
import src.model as model                          # Import the module with an alias
import src.train_eval as train_eval

from src.pybaseball_dataset import *
from src.pytorch_dataset import *
from src.model import *
from src.train_eval import *

importlib.reload(pybaseball_dataset)
importlib.reload(pytorch_dataset)
importlib.reload(model)
importlib.reload(train_eval)

from src.pybaseball_dataset import *
from src.pytorch_dataset import *
from src.model import *
from src.train_eval import *


# Setup basic logging
log_file = f"results/{datetime.now().astimezone(pytz.timezone('America/Los_Angeles')).strftime('%Y-%m-%d_%H-%M-%S')}"
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(levelname)s - %(message)s',
                   handlers=[
                        logging.FileHandler(log_file),
                        logging.StreamHandler()
                    ])

# Define cache directory (can be passed as argument later)
CACHE_DIR = "dataset"
os.makedirs(CACHE_DIR, exist_ok=True) # Create cache directory if it doesn't exist

In [57]:
def run_training(
    # Data params
    start_year: int,
    end_year: int,
    cache_dir: str = CACHE_DIR,
    # Feature params - Define based on preprocess output
    target_action_col: str = 'target_action',
    target_ev_la_cols: List[str] = ['target_ev', 'target_la'],
    bip_action_index: int = 3,
    seq_length: int = 10,
    # Model Hyperparameters
    model_params: Dict = { # Default example parameters
        'd_model': 128, 'nhead': 8, 'num_encoder_layers': 4,
        'dim_feedforward': 512, 'dropout': 0.1, 'num_mdn_components': 5
    },
    # Training params
    batch_size: int = 64,
    num_epochs: int = 20,
    learning_rate: float = 1e-4,
    val_split_size: float = 0.15,
    test_split_size: float = 0.10, # Set to 0 if no test set needed now
    device_str: str = 'cuda',
    early_stopping_patience: int = 3
    ) -> None:
    """
    Orchestrates the entire process: data fetching (with raw cache),
    preprocessing (with processed cache), dataset/loader creation,
    model initialization, training, and evaluation.
    """
    logging.info("Starting training orchestration with caching...")
    start_time = time.time()

    # --- Setup ---
    if device_str == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
        logging.info("Using CUDA device.")
    else:
        device = torch.device('cpu')
        logging.info("Using CPU device.")

    all_processed_dfs = []
    # --- Data Fetching & Preprocessing Loop (with Caching) ---
    for year in range(start_year, end_year + 1):
        logging.info(f"\n--- Processing Year: {year} ---")
        processed_cache_file = os.path.join(cache_dir, f"processed_statcast_{year}.parquet")

        # Check for processed data cache
        if os.path.exists(processed_cache_file):
            try:
                logging.info(f"Loading cached processed data for {year} from {processed_cache_file}")
                processed_df = pd.read_parquet(processed_cache_file)
                # Basic validation
                if 'target_action' in processed_df.columns and 'at_bat_id' in processed_df.columns:
                    logging.info(f"Loaded {len(processed_df)} processed pitches from cache for {year}.")
                    all_processed_dfs.append(processed_df)
                    continue # Skip to next year
                else:
                    logging.warning(f"Cached processed file {processed_cache_file} invalid. Re-processing.")
            except Exception as e:
                logging.error(f"Error loading processed cache {processed_cache_file}: {e}. Re-processing.")

        # If processed cache doesn't exist or is invalid, process raw data
        logging.info(f"Processing raw data for {year}...")
        # Fetch raw data (uses its own cache, except for current year)
        raw_df = fetch_statcast_data_for_year(year, cache_dir)

        if raw_df.empty:
            logging.warning(f"No raw data found or fetched for {year}. Skipping.")
            continue

        # Preprocess the raw data for the year
        processed_df = preprocess_and_feature_engineer(raw_df)

        if processed_df.empty:
            logging.warning(f"Preprocessing failed or resulted in empty data for {year}. Skipping.")
            continue

        # Save the processed data to cache
        try:
            logging.info(f"Saving processed data for {year} to {processed_cache_file}")
            processed_df.to_parquet(processed_cache_file, index=False)
        except Exception as e:
            logging.error(f"Error saving processed data cache file {processed_cache_file}: {e}")

        all_processed_dfs.append(processed_df)

    # --- Combine Data and Proceed ---
    if not all_processed_dfs:
        logging.error("No processed data available after checking all years. Exiting.")
        return

    logging.info("Concatenating processed data from all years...")
    combined_processed_data = pd.concat(all_processed_dfs, ignore_index=True)
    logging.info(f"Total processed pitches combined: {len(combined_processed_data)}")


    # --- Feature Definition (after loading/processing all data) ---
    # Re-define feature lists based on the final combined dataframe, just to be safe
    numerical_features_base = ['release_speed', 'release_pos_x', 'release_pos_z', 'release_extension', 'release_spin_rate', 'spin_axis', 'pfx_x', 'pfx_z', 'inning', 'outs_when_up', 'score_diff', 'balls', 'strikes']
    runner_features = ['on_1b_flag', 'on_2b_flag', 'on_3b_flag']
    categorical_features_base = ['pitch_type', 'stand', 'p_throws', 'inning_topbot_numeric']
    numerical_features = [f for f in numerical_features_base if f in combined_processed_data.columns] + runner_features
    categorical_features = [f for f in categorical_features_base if f in combined_processed_data.columns]
    logging.info(f"Final Numerical Features for Training: {numerical_features}")
    logging.info(f"Final Categorical Features for Training: {categorical_features}")
    num_actions = combined_processed_data[target_action_col].nunique() # Get num actions from data


    # --- Data Splitting (by Game ID on Combined Data) ---
    if 'game_pk' not in combined_processed_data.columns:
         raise ValueError("'game_pk' needed for splitting but not found.")
    game_ids = combined_processed_data['game_pk'].unique()
    logging.info(f"Splitting combined data based on {len(game_ids)} unique games.")
    train_val_splitter = GroupShuffleSplit(n_splits=1, test_size=val_split_size + test_split_size, random_state=42)
    train_idx, val_test_idx = next(train_val_splitter.split(game_ids, groups=game_ids))
    train_game_ids = game_ids[train_idx]; val_test_game_ids = game_ids[val_test_idx]
    if test_split_size > 0 and val_split_size > 0:
         relative_test_size = test_split_size / (val_split_size + test_split_size)
         val_test_splitter = GroupShuffleSplit(n_splits=1, test_size=relative_test_size, random_state=42)
         val_idx_rel, test_idx_rel = next(val_test_splitter.split(val_test_game_ids, groups=val_test_game_ids))
         val_game_ids = val_test_game_ids[val_idx_rel]; test_game_ids = val_test_game_ids[test_idx_rel]
    elif val_split_size > 0: # Only validation split needed
         val_game_ids = val_test_game_ids; test_game_ids = np.array([])
    else: # No val or test split (use all for training - not recommended)
         train_game_ids = game_ids; val_game_ids = np.array([]); test_game_ids = np.array([])

    train_data = combined_processed_data[combined_processed_data['game_pk'].isin(train_game_ids)].copy()
    val_data = combined_processed_data[combined_processed_data['game_pk'].isin(val_game_ids)].copy()
    test_data = combined_processed_data[combined_processed_data['game_pk'].isin(test_game_ids)].copy()
    logging.info(f"Data split: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}")


    # --- Dataset & DataLoader Creation ---
    if train_data.empty or val_data.empty:
         logging.error("Training or Validation data is empty after split. Cannot proceed.")
         return

    train_loader, val_loader, fitted_scaler, fitted_cat_mappings, fitted_cat_vocab_sizes = create_dataloaders(
        train_data=train_data, val_data=val_data, numerical_features=numerical_features,
        categorical_features=categorical_features, target_action_col=target_action_col,
        target_ev_la_cols=target_ev_la_cols, seq_length=seq_length, batch_size=batch_size, device=device )
    if not fitted_cat_vocab_sizes: logging.error("Cat vocab sizes missing."); return


    # --- Model Initialization ---
    num_numerical = len(numerical_features)
    cat_indices = list(range(num_numerical, num_numerical + len(categorical_features)))
    # Ensure num_actions derived from data is used if different from default
    model_params_updated = model_params.copy()
    model_params_updated['num_actions'] = num_actions

    model = BatterActionTransformer(
        num_numerical_features=num_numerical, cat_feature_indices=cat_indices,
        cat_vocab_sizes=fitted_cat_vocab_sizes, seq_length=seq_length,
        **model_params_updated # Pass hyperparameters dict
    ).to(device)
    logging.info(f"Model initialized ({sum(p.numel() for p in model.parameters() if p.requires_grad):,} params).")


    # --- Optimizer & Loss ---
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    loss_fn = mdn_loss_fn # Use the defined function

    # --- Training Loop ---
    best_val_loss = float('inf')
    epochs_no_improve = 0
    history = {'train_loss': [], 'val_loss': [], 'val_accuracy': [], 'val_mae_ev': [], 'val_mae_la': []}

    for epoch in range(1, num_epochs + 1):
        logging.info(f"\n--- Epoch {epoch}/{num_epochs} ---")
        train_loss = train_epoch(model, train_loader, loss_fn, optimizer, device, bip_action_index)
        history['train_loss'].append(train_loss)
        logging.info(f"Epoch {epoch} Train Loss: {train_loss:.4f}")

        val_metrics = evaluate_model(model, val_loader, loss_fn, device, bip_action_index)
        history['val_loss'].append(val_metrics['loss'])
        history['val_accuracy'].append(val_metrics['accuracy'])
        history['val_mae_ev'].append(val_metrics['mae_ev_bip'])
        history['val_mae_la'].append(val_metrics['mae_la_bip'])
        logging.info(f"Epoch {epoch} Val Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, MAE_EV: {val_metrics['mae_ev_bip']:.2f}, MAE_LA: {val_metrics['mae_la_bip']:.2f}")

        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            epochs_no_improve = 0
            # torch.save(model.state_dict(), os.path.join(cache_dir, 'best_model_checkpoint.pth')) # Save in cache dir
            # logging.info("Val loss improved, saving model.")
        else:
            epochs_no_improve += 1
            logging.info(f"Val loss did not improve for {epochs_no_improve} epoch(s).")
        if epochs_no_improve >= early_stopping_patience:
            logging.info(f"Early stopping triggered after {epoch} epochs.")
            break

    # --- Final Steps ---
    total_time = time.time() - start_time
    logging.info(f"\nTraining finished in {total_time / 60:.2f} minutes.")
    logging.info(f"Best Validation Loss: {best_val_loss:.4f}")

    # --- (Optional) Test Set Evaluation ---
    if not test_data.empty:
         logging.info("\nEvaluating on Test Set...")
         # Load best model if saved
         # if os.path.exists(os.path.join(cache_dir, 'best_model_checkpoint.pth')):
         #     model.load_state_dict(torch.load(os.path.join(cache_dir, 'best_model_checkpoint.pth'), map_location=device))

         test_dataset = PitchSequenceDataset(
              test_data, numerical_features, categorical_features,
              target_action_col, target_ev_la_cols, seq_length,
              scaler=fitted_scaler, # Use scaler from train
              cat_mappings=fitted_cat_mappings, # Use mappings from train
              cat_vocab_sizes=fitted_cat_vocab_sizes,
              device=device
          )
         test_loader = DataLoader(test_dataset, batch_size=batch_size*2, shuffle=False)
         test_metrics = evaluate_model(model, test_loader, loss_fn, device, bip_action_index)
         logging.info(f"Test Loss: {test_metrics['loss']:.4f}, Acc: {test_metrics['accuracy']:.4f}, MAE_EV: {test_metrics['mae_ev_bip']:.2f}, MAE_LA: {test_metrics['mae_la_bip']:.2f}")


    logging.info("run_training completed.")

In [58]:
# Define number of actions based on mapping
num_actions = 4 # Take, S&M, Foul, BIP

# Define other model params based on choices above
model_constructor_params = {
    'd_model': 128,
    'nhead': 8,
    'num_encoder_layers': 4,
    'dim_feedforward': 512,
    'dropout': 0.1,
    'num_actions': num_actions,
    'num_mdn_components': 5,
    # seq_length will be passed separately
    # feature dims/indices/vocabs determined after data loading
}

run_training(
    start_year=2022, # Shortened range for faster testing
    cache_dir="dataset",
    end_year=2023,
    seq_length=10,
    model_params=model_constructor_params, # Pass dict here (will be updated internally)
    batch_size=128, # Increased batch size
    num_epochs=5, # Reduced epochs for faster testing
    learning_rate=5e-5, # Adjusted learning rate
    device_str='cuda' if torch.cuda.is_available() else 'cpu'
)

2025-04-11 23:05:01,151 - INFO - Starting training orchestration with caching...
2025-04-11 23:05:01,154 - INFO - Using CUDA device.
2025-04-11 23:05:01,156 - INFO - 
--- Processing Year: 2022 ---
2025-04-11 23:05:01,157 - INFO - Loading cached processed data for 2022 from dataset/processed_statcast_2022.parquet
2025-04-11 23:05:01,525 - INFO - Loaded 754158 processed pitches from cache for 2022.
2025-04-11 23:05:01,526 - INFO - 
--- Processing Year: 2023 ---
2025-04-11 23:05:01,528 - INFO - Loading cached processed data for 2023 from dataset/processed_statcast_2023.parquet
2025-04-11 23:05:01,872 - INFO - Loaded 751775 processed pitches from cache for 2023.
2025-04-11 23:05:01,873 - INFO - Concatenating processed data from all years...
2025-04-11 23:05:02,094 - INFO - Total processed pitches combined: 1505933
2025-04-11 23:05:02,096 - INFO - Final Numerical Features for Training: ['release_speed', 'release_pos_x', 'release_pos_z', 'release_extension', 'release_spin_rate', 'spin_axis',

Training Epoch:   0%|          | 0/8820 [00:00<?, ?it/s]

2025-04-11 23:13:11,130 - INFO - Epoch 1 Train Loss: 12.7456


Evaluating:   0%|          | 0/884 [00:00<?, ?it/s]

ValueError: Expected parameter logits (Tensor of shape (41, 5)) of distribution Categorical(logits: torch.Size([41, 5])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([[nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan]], device='cuda:0')