In [53]:
# Notebook setup: make ../src importable, refresh the local helper modules,
# import third-party libraries, then call set_seed() and report the device.
import sys
import os

# The path is resolved relative to the notebook's working directory. Append it
# only when missing so that re-running this cell does not accumulate duplicate
# sys.path entries.
src_path = os.path.abspath("../src")
if src_path not in sys.path:
    sys.path.append(src_path)

import importlib
import data_utils, model_utils, shap_utils, ensemble_utils

# Reload the local modules so on-disk edits are picked up without restarting
# the kernel. This has to run before the `from ... import ...` lines below so
# the imported names come from the freshly reloaded module objects.
for local_module in (data_utils, model_utils, shap_utils, ensemble_utils):
    importlib.reload(local_module)

from config import TARGET_VAR, set_seed, device
from data_utils import load_batting_years, build_feature_dataset
from model_utils import PlayerMLP
from shap_utils import explain_shap, get_top_shap_features
from ensemble_utils import load_ensemble_and_predict

import torch
import joblib
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error
import matplotlib.pyplot as plt

# set_seed comes from config; presumably it seeds the RNGs used for training
# (its implementation is not visible in this notebook).
set_seed()
print("Device:", device)


Device: mps


In [None]:
# Load the batting data and build the modelling dataset for TARGET_VAR.
years = load_batting_years()
dataset = build_feature_dataset(years, target_var=TARGET_VAR)

# Unpack the pieces used by the rest of the notebook. The train/test entries
# are used as torch tensors further down (they are moved with `.to(device)`).
X_train = dataset['X_train']
y_train = dataset['y_train']
X_test = dataset['X_test']
y_test = dataset['y_test']
# Key fixed: it previously contained a stray space ('fe atures'), which does
# not match the naming of the other keys and would raise KeyError.
features = dataset['features']
scaler = dataset['scaler']


✅ Using build_feature_dataset with extended feature engineering
Processing year 2000...
Processing year 2001...
Processing year 2002...
❌ Skipping 2002 due to error: Reindexing only valid with uniquely valued Index objects
Processing year 2003...
Processing year 2004...
❌ Skipping 2004 due to error: Reindexing only valid with uniquely valued Index objects
Processing year 2005...
Processing year 2006...
Processing year 2007...
Processing year 2008...
Processing year 2009...
Processing year 2010...
Processing year 2011...
Processing year 2012...
Processing year 2013...
Processing year 2014...
Processing year 2015...
Processing year 2016...
Processing year 2017...
Processing year 2018...
Processing year 2019...
Processing year 2020...
Processing year 2021...
Processing year 2022...
Processing year 2023...


In [55]:
# --- Hyperparameters for this single MLP run ---
hidden_dims = [128, 64]
dropout = 0.3
activation = 'relu'
lr = 0.001
epochs = 100
batch_size = 64

# --- Model, optimiser and loss ---
n_inputs = X_train.shape[1]
model = PlayerMLP(n_inputs, hidden_dims, dropout=dropout, activation=activation)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = torch.nn.MSELoss()

# --- Training ---
# Mini-batches are taken as consecutive slices in the stored order; there is
# no shuffling between epochs.
n_train = len(X_train)
for epoch in range(epochs):
    model.train()
    for start in range(0, n_train, batch_size):
        stop = start + batch_size
        batch_x = X_train[start:stop].to(device)
        batch_y = y_train[start:stop].to(device)
        optimizer.zero_grad()
        # NOTE(review): if the model output is (batch, 1) while y is (batch,),
        # MSELoss would broadcast to (batch, batch) — confirm the shapes agree.
        batch_pred = model(batch_x)
        loss = loss_fn(batch_pred, batch_y)
        loss.backward()
        optimizer.step()

# --- Evaluation on the held-out test set ---
model.eval()
with torch.no_grad():
    test_output = model(X_test.to(device))

# Bring everything back to flat NumPy arrays for the sklearn metrics.
preds = test_output.cpu().numpy().flatten()
y_true = y_test.cpu().numpy().flatten()
mae = mean_absolute_error(y_true, preds)
rmse = mean_squared_error(y_true, preds) ** 0.5

print(f"MLP for Total Bases → MAE: {mae:.4f}, RMSE: {rmse:.4f}")


MLP for Total Bases → MAE: 116.9014, RMSE: 502.2217


In [57]:
# Explain the trained model on the test set and export the SHAP summary.
# The model is moved to the CPU and the test inputs are passed as a NumPy array.
model = model.to('cpu')
shap_values = explain_shap(model, X_test.cpu().numpy(), features)
top_features, shap_df = get_top_shap_features(shap_values, features, top_n=20)

# Make sure the output directory exists before writing, mirroring how the
# model directory is created before saving; to_csv does not create missing
# parent directories.
os.makedirs("../output", exist_ok=True)
shap_df.to_csv(f"../output/shap_summary_{TARGET_VAR}.csv", index=False)
print("Top SHAP Features:", top_features)


Permutation explainer: 340it [03:42,  1.46it/s]                         

Top SHAP Features: ['Pos_3.7', 'Pos_11.0', 'Age Rng_recency', 'Age_recency', 'wRC+_w2_recency', 'Pos_8.8', 'Pos_3.8', 'Pos_10.3', 'Pos_10.6', 'wOBA_y1', 'Pos_8.4', 'wRC+_y1', 'R_mean', 'SL-Z (sc)_recency', 'wRC+_w1_recency', 'Pos_', 'CH-Z (sc)_recency', 'SLG_std', 'Hard%_w1_recency', 'CU-Z (sc)_recency']





In [58]:
# Persist the trained weights together with the scaler and the feature list.
# All three artefacts go into the same directory.
save_dir = "../models/total_bases"
os.makedirs(save_dir, exist_ok=True)

weights_path = save_dir + "/mlp_model_0.pt"
torch.save(model.state_dict(), weights_path)

joblib.dump(scaler, save_dir + "/scaler.joblib")
# Kept as the last expression so the cell still displays joblib's return value.
joblib.dump(features, save_dir + "/features.joblib")


['../models/total_bases/features.joblib']

In [59]:
# Load the saved model(s) from disk via the ensemble helper and predict on the
# test inputs.
X_input = X_test.cpu().numpy()

# One configuration tuple, built from the hyperparameters set in the training
# cell. The meaning of each position (including the literal 0.2 and 'plateau')
# is defined by load_ensemble_and_predict, which is not visible here.
single_config = (hidden_dims, lr, epochs, batch_size, dropout, 0.2, activation, 'plateau')

predictions = load_ensemble_and_predict(
    X_input_np=X_input,
    device=device,
    model_class=PlayerMLP,
    configs=[single_config],
    model_dir="../models/total_bases",
    feature_count=X_input.shape[1],
)

print("Loaded predictions:", predictions[:5])


Loaded predictions: [199.14896   99.272995 260.773    288.02908  217.83435 ]
