# Fit a global model on the data

- For any given year of data, the outcome of a game can be described by a Bayesian model of the number of points each team scores.

In [1]:
import os
# Restrict JAX to the CPU backend. The variable is set in this first cell so it
# is already in the environment when the next cell imports jax.
os.environ["JAX_PLATFORMS"]="cpu"

In [2]:
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Tuple, Iterable, TypedDict, Optional

import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Enable 64-bit precision in JAX for better numerical stability in the Poisson log-rates.
jax.config.update("jax_enable_x64", True)
# Expose a single host device to NumPyro; fit_hierarchal_model runs its chains
# with chain_method="sequential", so it does not need one device per chain.
numpyro.set_host_device_count(1)

## Utils

In [3]:
class SeasonDF(TypedDict):
    """Row schema for a season results table.

    The field names match the columns of the CSV loaded later in this
    notebook (which additionally carries an unnamed index column).

    NOTE(review): this type is not referenced anywhere else in the visible
    notebook, and the displayed frame shows ``gameID`` and the scores as
    floats (e.g. ``80.0``), so the ``int`` annotations describe the intended
    types rather than what ``read_csv`` produces -- confirm before relying on them.
    """
    gameID: int
    date: str
    home_team: str
    away_team: str
    home_score: int
    away_score: int
    finalMessage: str
    start_time: str
    url: str
    conference_home: str
    conference_away: str

@dataclass(frozen=True)
class EncodedSeason:
    """Array-encoded season, as produced by ``encode_season``.

    Games are stored column-wise: position ``g`` in each array refers to the
    same game. The dataclass is frozen, so fields cannot be reassigned after
    construction.
    """
    home_idx: jnp.ndarray     # shape (n_games,), int32 -- home team id per game
    away_idx: jnp.ndarray     # shape (n_games,), int32 -- away team id per game
    y_home: jnp.ndarray       # shape (n_games,), int32 -- home team score per game
    y_away: jnp.ndarray       # shape (n_games,), int32 -- away team score per game
    n_teams: int              # number of entries in team_to_id
    id_to_team: List[str]     # team name stored at the position given by its id
    team_to_id: Dict[str, int]  # normalized team name -> integer id


In [4]:
def _normalize_team_name(name: str) -> str:
    # light normalization to avoid accidental dupes
    return " ".join(name.strip().split())

def build_team_indexer(df: pd.DataFrame) -> Tuple[Dict[str, int], List[str]]:
    """Assign a stable integer id to every team that appears in ``df``.

    Team names from the ``home_team`` and ``away_team`` columns are cast to
    ``str``, normalized with ``_normalize_team_name``, de-duplicated and
    sorted; a team's id is its position in that sorted list.

    Args:
        df: Game table containing ``home_team`` and ``away_team`` columns.

    Returns:
        ``(team_to_id, id_to_team)`` where ``team_to_id`` maps a normalized
        name to its id and ``id_to_team`` is the sorted list of names, so
        ``id_to_team[team_to_id[name]] == name``.

    Raises:
        AssertionError: If either team column is missing.
    """
    assert {"home_team", "away_team"}.issubset(df.columns), "Missing required columns."

    unique_names = set()
    for column in ("home_team", "away_team"):
        for raw_name in df[column].astype(str).tolist():
            unique_names.add(_normalize_team_name(raw_name))

    # Sorting makes the id assignment independent of row order.
    id_to_team: List[str] = sorted(unique_names)
    team_to_id = {team: position for position, team in enumerate(id_to_team)}
    return team_to_id, id_to_team

def encode_season(df: pd.DataFrame, team_to_id: Dict[str, int]) -> EncodedSeason:
    """Turn a table of game results into the array form consumed by the model.

    Args:
        df: Game table with ``home_team``, ``away_team``, ``home_score`` and
            ``away_score`` columns.
        team_to_id: Mapping from normalized team name to integer id.

    Returns:
        An ``EncodedSeason`` holding, per game, the home/away team ids and
        scores as int32 JAX arrays, together with ``n_teams`` and both team
        lookup tables.

    Raises:
        ValueError: If any of the required columns is absent from ``df``.
        KeyError: If a (normalized) team name in ``df`` is not a key of
            ``team_to_id``.
    """
    required = {"home_team","away_team","home_score","away_score"}
    absent = required - set(df.columns)
    if absent:
        raise ValueError(f"DataFrame missing columns: {absent}")

    def _team_id(raw_name):
        # Names are normalized the same way as when the index was built.
        return team_to_id[_normalize_team_name(raw_name)]

    def _team_column_to_ids(column):
        return df[column].astype(str).map(_team_id).astype(np.int32).to_numpy()

    def _score_column_to_ints(column):
        return df[column].astype(np.int32).to_numpy()

    home_ids = _team_column_to_ids("home_team")
    away_ids = _team_column_to_ids("away_team")
    home_points = _score_column_to_ints("home_score")
    away_points = _score_column_to_ints("away_score")

    # Invert the name -> id mapping into a list addressed by id.
    team_count = len(team_to_id)
    names_by_id = [None] * team_count  # type: ignore
    for team_name, team_id in team_to_id.items():
        names_by_id[team_id] = team_name

    return EncodedSeason(
        home_idx=jnp.array(home_ids),
        away_idx=jnp.array(away_ids),
        y_home=jnp.array(home_points),
        y_away=jnp.array(away_points),
        n_teams=team_count,
        id_to_team=names_by_id,  # type: ignore
        team_to_id=team_to_id,
    )


In [5]:
def hierarchal_model(
    home_idx: jnp.ndarray,
    away_idx: jnp.ndarray,
    y_home: jnp.ndarray,
    y_away: jnp.ndarray,
    n_teams: int,
):
    """NumPyro model: hierarchical Poisson regression on home and away scores.

    For each game the two scores are modelled as

        y_home ~ Poisson(exp(alpha + offense[home] - defense[away] + h[home]))
        y_away ~ Poisson(exp(alpha + offense[away] - defense[home]))

    ``offense``, ``defense`` and ``h`` (home advantage) are per-team effects.
    ``offense`` and ``defense`` are centred on zero with learned scales
    ``sigma_off`` / ``sigma_def``; ``h`` is centred on the league-level mean
    ``h_mu`` with learned scale ``tau_h``.

    Args:
        home_idx: Per-game index of the home team; used to index the
            per-team effect vectors.
        away_idx: Per-game index of the away team.
        y_home: Observed home scores, passed as ``obs`` to the "y_home" site.
        y_away: Observed away scores, passed as ``obs`` to the "y_away" site.
        n_teams: Size of the "team" plate, i.e. how many per-team effects
            are sampled.
    """
    # Intercept on the log-rate scale (it is exponentiated below): weakly-informative
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 5.0))

    # Hierarchy scales (positive): spread of the per-team effects around their means
    sigma_off = numpyro.sample("sigma_off", dist.HalfNormal(1.0))
    sigma_def = numpyro.sample("sigma_def", dist.HalfNormal(1.0))
    tau_h     = numpyro.sample("tau_h",     dist.HalfNormal(1.0))

    # League-level home advantage mean
    h_mu = numpyro.sample("h_mu", dist.Normal(0.0, 1.0))

    with numpyro.plate("team", n_teams):
        # One draw per team for each effect, sharing the scales sampled above
        h       = numpyro.sample("h",       dist.Normal(h_mu, tau_h))
        offense = numpyro.sample("offense", dist.Normal(0.0, sigma_off))
        defense = numpyro.sample("defense", dist.Normal(0.0, sigma_def))

    # Linear predictors (log expected points). Only the home side receives the
    # home-advantage term; the opponent's defense is subtracted from the scorer's offense.
    eta_home = alpha + offense[home_idx] - defense[away_idx] + h[home_idx]
    eta_away = alpha + offense[away_idx] - defense[home_idx]

    # Likelihood: Poisson counts with a log link
    numpyro.sample("y_home", dist.Poisson(jnp.exp(eta_home)), obs=y_home)
    numpyro.sample("y_away", dist.Poisson(jnp.exp(eta_away)), obs=y_away)

def fit_hierarchal_model(encoded: EncodedSeason, seed: int = 0, num_chains: int = 2, num_warmup: int = 100, num_samples: int = 1000):
    """Run NUTS on ``hierarchal_model`` for one encoded season.

    Args:
        encoded: Season arrays and team count to condition the model on.
        seed: Integer used to build the JAX PRNG key for the sampler.
        num_chains: Number of chains; they are run one after another
            (``chain_method="sequential"``).
        num_warmup: Warmup iterations per chain.
        num_samples: Posterior draws kept per chain.

    Returns:
        ``(mcmc, samples)``: the fitted ``MCMC`` object and the result of
        ``mcmc.get_samples()``.
    """
    # Everything the model function takes, passed to the sampler by keyword.
    model_inputs = dict(
        home_idx=jnp.array(encoded.home_idx),
        away_idx=jnp.array(encoded.away_idx),
        y_home=jnp.array(encoded.y_home),
        y_away=jnp.array(encoded.y_away),
        n_teams=encoded.n_teams,
    )

    sampler = numpyro.infer.MCMC(
        numpyro.infer.NUTS(hierarchal_model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=True,
        chain_method="sequential"
    )
    sampler.run(jax.random.PRNGKey(seed), **model_inputs)

    posterior = sampler.get_samples()
    return sampler, posterior

## Global prior

In [6]:
# Load the season results and keep only games that have both team names and
# both scores; the model cannot use a row missing any of them.
# NOTE(review): this is an absolute, machine-specific path -- point it at your
# own copy of the CSV before re-running.
df = (
    pd.read_csv("/Users/arhamhabib/Projects/PostPick/data/ncaab_2021_men_d1.csv")
    .dropna(subset=['home_team', 'away_team', 'home_score', 'away_score'])
)
df

Unnamed: 0,gameID,date,home_team,away_team,home_score,away_score,finalMessage,start_time,url,conference_home,conference_away
0,1976847.0,2021-11-01,Florida,Embry-Riddle (FL),80.0,57.0,FINAL,12:00AM ET,5896325,SEC,Sunshine State
1,1976852.0,2021-11-01,Alcorn,Xavier (LA),67.0,45.0,FINAL,01:00AM ET,5896329,SWAC,NON-NCAA ORG
2,1976853.0,2021-11-01,Texas,Texas Lutheran,96.0,33.0,FINAL,01:00AM ET,5896330,Big 12,SCAC
3,1976854.0,2021-11-01,Oklahoma,Rogers St.,106.0,57.0,FINAL,01:00AM ET,5896331,Big 12,Mid-America Intercollegiate
4,1994143.0,2021-11-01,UAB,AUM,101.0,70.0,FINAL,01:00AM ET,5907144,C-USA,Gulf South
...,...,...,...,...,...,...,...,...,...,...,...
6114,2070074.0,2022-03-31,Texas A&M,Xavier,72.0,73.0,FINAL,07:00PM ET,5996538,SEC,Big East
6115,2076366.0,2022-04-01,Coastal Carolina,Fresno St.,74.0,85.0,FINAL,06:00PM ET,5996410,Sun Belt,Mountain West
6116,2070510.0,2022-04-02,Kansas,Villanova,81.0,65.0,FINAL,06:09PM ET,5958444,Big 12,Big East
6117,2070511.0,2022-04-02,Duke,North Carolina,77.0,81.0,FINAL,08:51PM ET,5958443,ACC,ACC


In [7]:
# Build the global team index from every team that appears in df, then encode
# the season into team-id and score arrays for the model. df should hold the
# FULL season; the fit itself happens in the next cell.
team_to_id, id_to_team = build_team_indexer(df)
encoded = encode_season(df, team_to_id)

In [9]:
# Fit with the function defaults: seed 0, 2 chains run sequentially, 100 warmup
# and 1000 kept iterations per chain.
# NOTE(review): the recorded progress output shows 1023 steps per iteration and
# step sizes of roughly 2e-3 to 4e-3, with each chain taking over 10 minutes.
# That looks like the sampler reaching its tree-depth limit -- check the
# diagnostics and consider a longer warmup or a non-centred parameterisation.
mcmc, samples = fit_hierarchal_model(encoded)

sample: 100%|██████████| 1100/1100 [10:37<00:00,  1.73it/s, 1023 steps of size 4.05e-03. acc. prob=0.90]
sample: 100%|██████████| 1100/1100 [12:13<00:00,  1.50it/s, 1023 steps of size 2.27e-03. acc. prob=0.91]


In [11]:
# Print a single number per sampled site. .mean() is called without an axis, so
# for array-valued sites (offense, defense, h) the average runs over every
# element of the sample array -- presumably all draws and all teams; confirm
# the shapes before reading these as per-team estimates.
for param, sample in samples.items():
    print(param, sample.mean())

alpha 4.242628477075259
defense -9.889852346326115e-05
h 0.04431974025561814
h_mu 0.04429140438909339
offense 0.00023130194294340034
sigma_def 0.14980605874463465
sigma_off 0.14204310031692552
tau_h 0.027227108447182707


In [None]:
# Display the fitted MCMC object returned by fit_hierarchal_model.
mcmc