In [34]:
# ============================================================
# NFL Game Predictor — Cross-Attention + Pairwise Interaction + SHAP
# ============================================================

import pandas as pd
import numpy as np
import shap
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from ipywidgets import widgets, Output
from IPython.display import display, clear_output

In [35]:
# ------------------------------------------------------------
# 1️⃣ Load Model and Cross-Attention Embeddings
# ------------------------------------------------------------
# File locations for the trained interaction model and the per-matchup embeddings.
MODEL_PATH = "XAtt_interaction_model.keras"
EMBEDDINGS_PATH = "crossattn_gru_offseason_embeddings.csv"

interaction_model = load_model(MODEL_PATH)
crossattn_df = pd.read_csv(EMBEDDINGS_PATH)

# Matchup lookups filter on these key columns; fail early with a clear message if any are absent.
_missing_key_cols = {"away_team", "home_team", "week"} - set(crossattn_df.columns)
assert not _missing_key_cols, f"embeddings file is missing columns: {sorted(_missing_key_cols)}"

attn_cols_A = [c for c in crossattn_df.columns if c.startswith("A_attn_")]
attn_cols_B = [c for c in crossattn_df.columns if c.startswith("B_attn_")]

# The difference/product interaction features are element-wise, so the two embeddings
# must have the same (non-zero) number of columns.
assert attn_cols_A and len(attn_cols_A) == len(attn_cols_B), (
    f"expected matching A_attn_/B_attn_ columns, got {len(attn_cols_A)} vs {len(attn_cols_B)}"
)

team_list  = sorted(set(crossattn_df["away_team"]).union(set(crossattn_df["home_team"])))
week_list  = sorted(crossattn_df["week"].unique())

# Interaction layout is [A, B, A-B, A*B, cosine] → 4*k + 1 features; warn up front if the
# model was trained on a different width (the input builder will pad/truncate to fit).
_n_interaction_features = 4 * len(attn_cols_A) + 1
if _n_interaction_features != interaction_model.input_shape[1]:
    print(f"⚠️ Embeddings yield {_n_interaction_features} features but model expects "
          f"{interaction_model.input_shape[1]}")

print(f"✅ Model input dim: {interaction_model.input_shape[1]}")
print(f"✅ Loaded {len(crossattn_df)} matchups from cross-attention embeddings")


✅ Model input dim: 129
✅ Loaded 272 matchups from cross-attention embeddings


In [36]:
# ------------------------------------------------------------
# 2️⃣ Build Input Vector with Interaction Features
# ------------------------------------------------------------
def build_input_vector(team_a, team_b, week):
    """Assemble the interaction-model input for one away-vs-home matchup.

    Looks up the row of ``crossattn_df`` with ``away_team == team_a``,
    ``home_team == team_b`` and ``week == week`` and builds the feature vector
    ``[A, B, A - B, A * B, cos(A, B)]`` from its ``A_attn_*`` / ``B_attn_*`` columns.
    If that length differs from ``interaction_model.input_shape[1]`` the vector is
    zero-padded or truncated to fit, and the feature names are adjusted to match.

    Parameters
    ----------
    team_a : away-team code as stored in ``crossattn_df["away_team"]``.
    team_b : home-team code as stored in ``crossattn_df["home_team"]``.
    week : week key as stored in ``crossattn_df["week"]``.

    Returns
    -------
    (vec, fnames)
        ``vec`` has shape ``(1, input_dim)`` and ``fnames`` has ``input_dim`` entries;
        ``(None, None)`` when no matching row exists.
    """
    rows = crossattn_df.query("away_team == @team_a & home_team == @team_b & week == @week")
    if rows.empty:
        print(f"❌ No matchup found for {team_a} vs {team_b} in week {week}")
        return None, None
    if len(rows) > 1:
        # Mixing several rows would corrupt the vector; use exactly one.
        print(f"⚠️ {len(rows)} rows match {team_a} vs {team_b} in week {week}; using the first")

    A = rows[attn_cols_A].iloc[0].to_numpy(dtype=float)
    B = rows[attn_cols_B].iloc[0].to_numpy(dtype=float)

    diff = A - B
    prod = A * B
    denom = np.linalg.norm(A) * np.linalg.norm(B)
    # Cosine is undefined for a zero-norm embedding; use 0.0 so the model input stays finite.
    cos = float(np.dot(A, B) / denom) if denom > 0 else 0.0

    vec = np.concatenate([A, B, diff, prod, [cos]])
    fnames = (
        [f"A_{c}" for c in attn_cols_A] +
        [f"B_{c}" for c in attn_cols_B] +
        [f"DIFF_{c}" for c in attn_cols_A] +
        [f"PROD_{c}" for c in attn_cols_A] +
        ["COSINE_SIM"]
    )

    expected_dim = interaction_model.input_shape[1]
    if vec.size != expected_dim:
        print(f"⚠️ Resizing from {vec.size} → {expected_dim}")
        if vec.size < expected_dim:
            n_pad = expected_dim - vec.size
            vec = np.pad(vec, (0, n_pad))
            # Keep names aligned with values so SHAP labels stay correct.
            fnames = fnames + [f"PAD_{i}" for i in range(n_pad)]
        else:
            vec = vec[:expected_dim]
            fnames = fnames[:expected_dim]
    return vec.reshape(1, -1), fnames


In [37]:
# ------------------------------------------------------------
# 3️⃣ Predict Matchup
# ------------------------------------------------------------
def predict_matchup(team_a, team_b, week):
    """Score one matchup with the interaction model.

    Returns ``(winner, prob, (x, fnames))``:
      * ``prob`` is the model's first output for the matchup (the UI labels it the
        away / ``team_a`` win probability),
      * ``winner`` is ``team_a`` when ``prob >= 0.5`` and ``team_b`` otherwise,
      * ``(x, fnames)`` is the feature matrix and names, passed along for explanation.
    Returns ``(None, None, None)`` when no input vector can be built for the matchup.
    """
    features, names = build_input_vector(team_a, team_b, week)
    if features is None:
        return None, None, None

    away_prob = interaction_model.predict(features, verbose=0)[0, 0]
    if away_prob >= 0.5:
        predicted = team_a
    else:
        predicted = team_b
    return predicted, away_prob, (features, names)

In [46]:
# ------------------------------------------------------------
# 4️⃣ SHAP Waterfall Plot
# ------------------------------------------------------------
def _shap_values_1d(shap_values, n_features):
    """Reduce KernelExplainer output for one instance to a 1-D array for model output 0."""
    # Older SHAP returns a list (one array of shape (n_samples, n_features) per output);
    # newer SHAP returns one ndarray shaped (n_samples, n_features, n_outputs).
    arr = np.asarray(shap_values[0] if isinstance(shap_values, list) else shap_values, dtype=float)
    if arr.ndim == 3:
        arr = arr[..., 0]
    flat = arr.reshape(-1)
    if flat.size != n_features:
        raise ValueError(f"expected {n_features} SHAP values, got {flat.size}")
    return flat


def _scalar_base_value(base_value):
    """Return the explainer's expected value for model output 0 as a plain float."""
    # expected_value may be a scalar, a list, or an ndarray of shape (n_outputs,);
    # shap.plots.waterfall requires a scalar base value.
    return float(np.ravel(np.asarray(base_value, dtype=float))[0])


def plot_shap_waterfall(model, x, feature_names, title, background=None, nsamples=100,
                        max_display=15):
    """Explain a single prediction with KernelSHAP and draw a waterfall plot.

    Parameters
    ----------
    model : object with ``predict(z, verbose=0)`` (e.g. a Keras model).
    x : the single instance to explain; reshaped to ``(1, n_features)``.
    feature_names : labels for the features. If the count does not match ``n_features``
        it is truncated, or padded with generic ``f{i}`` names.
    title : plot title.
    background : optional reference rows for KernelExplainer, reshaped to
        ``(-1, n_features)``. Defaults to a single all-zeros row (the previous fixed
        behaviour); real embedding rows give a more meaningful baseline.
    nsamples : number of KernelSHAP coalition samples (default 100).
    max_display : maximum number of features shown in the waterfall (default 15).

    Raises
    ------
    ValueError
        If SHAP returns a number of values different from ``n_features``.
    """
    x = np.asarray(x, dtype=float).reshape(1, -1)
    n_features = x.shape[1]

    if background is None:
        background = np.zeros((1, n_features))
    background = np.asarray(background, dtype=float).reshape(-1, n_features)

    explainer = shap.KernelExplainer(lambda z: model.predict(z, verbose=0), background)
    shap_values = explainer.shap_values(x, nsamples=nsamples)

    names = list(feature_names) if feature_names is not None else []
    # Keep labels aligned with values even if the input was padded or truncated upstream.
    names = names[:n_features] + [f"f{i}" for i in range(len(names), n_features)]

    exp = shap.Explanation(
        values=_shap_values_1d(shap_values, n_features),
        base_values=_scalar_base_value(explainer.expected_value),
        data=x.reshape(-1),
        feature_names=names,
    )

    plt.figure(figsize=(10, 7))
    shap.plots.waterfall(exp, max_display=max_display, show=False)
    plt.title(title, fontsize=12, pad=20)
    plt.tight_layout()
    plt.show()

In [47]:
# ------------------------------------------------------------
# 5️⃣ UI Elements
# ------------------------------------------------------------
# Team pickers share the same sorted team list; the week picker shows "YYYY Week N"
# labels while keeping the raw week key as the selected value.
team_a_widget = widgets.Dropdown(
    description="Away Team:",
    options=team_list,
)
team_b_widget = widgets.Dropdown(
    description="Home Team:",
    options=team_list,
)
week_widget = widgets.Dropdown(
    description="Week:",
    options=[(f"{code // 100} Week {code % 100}", code) for code in week_list],
    value=week_list[0],
)

predict_button = widgets.Button(
    button_style="success",
    description="Predict Game",
)
output = Output()  # shared area the prediction callback writes into

In [48]:
# ------------------------------------------------------------
# 6️⃣ Button Action
# ------------------------------------------------------------
def on_predict_click(b):
    """Button callback: predict the selected matchup and render its SHAP waterfall.

    Everything is written into the shared ``output`` widget, which is cleared on
    each click. ``b`` is the clicked button, passed in by ipywidgets; it is unused.
    """
    with output:
        clear_output()
        away, home, week = team_a_widget.value, team_b_widget.value, week_widget.value
        print(f"🏈 Predicting {away} (Away) vs {home} (Home) — Week {week}")
        winner, prob, payload = predict_matchup(away, home, week)

        # predict_matchup returns (None, None, None) for a missing matchup, so the
        # feature payload can only be unpacked after this check.
        if winner is None:
            print("❌ Could not make a prediction (data missing).")
            return
        x, fnames = payload

        print(f"✅ Predicted Winner: {winner}")
        print(f"📊 Away Win Probability: {prob:.3f}")

        print("\n🔍 SHAP Explanation (Feature Contributions):")
        try:
            plot_shap_waterfall(interaction_model, x, fnames,
                                f"{away} vs {home} — Week {week}")
        except Exception as e:
            # Best-effort: a SHAP failure must not hide the prediction already printed.
            print(f"SHAP explanation failed: {e}")


# Re-running this cell would otherwise stack another handler on the same button
# (each click would then predict twice), so unregister the previous one first.
if "_registered_predict_handler" in globals():
    predict_button.on_click(_registered_predict_handler, remove=True)
predict_button.on_click(on_predict_click)
_registered_predict_handler = on_predict_click

In [49]:
# ------------------------------------------------------------
# 7️⃣ Display Widgets
# ------------------------------------------------------------
# Controls first, then the shared Output area that the click handler writes into.
display(*(team_a_widget, team_b_widget, week_widget, predict_button), output)

Dropdown(description='Away Team:', options=('ATL', 'BUF', 'CAR', 'CHI', 'CIN', 'CLE', 'CLT', 'CRD', 'DAL', 'DE…

Dropdown(description='Home Team:', options=('ATL', 'BUF', 'CAR', 'CHI', 'CIN', 'CLE', 'CLT', 'CRD', 'DAL', 'DE…

Dropdown(description='Week:', options=(('2024 Week 1', np.int64(202401)), ('2024 Week 2', np.int64(202402)), (…

Button(button_style='success', description='Predict Game', style=ButtonStyle())

Output()