In [11]:
import sys
from pathlib import Path

# Go from Scripts/ up to root/.
# NOTE(review): 'NBA_Model' is a relative path, so resolve() anchors it at the
# current working directory and parents[1] is the directory above the CWD
# (barring symlinks). This presumably expects the kernel to be started inside
# Scripts/ -- confirm before running from anywhere else.
ROOT = Path('NBA_Model').resolve().parents[1]

# Point to the folder that contains load_nba_data.py
DOWNLOAD_DIR = ROOT / "Data" / "nba_data" / "function_for_download"

# Add that folder to Python's import path. Guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
if str(DOWNLOAD_DIR) not in sys.path:
    sys.path.append(str(DOWNLOAD_DIR))

# Project-local helper; only importable once DOWNLOAD_DIR is on sys.path.
from load_nba_data import load_nba_data

# Load the play-by-play table used by the rest of the notebook.
# NOTE(review): presumably seasontype="rg" means regular season and
# use_pandas=True returns a pandas DataFrame -- confirm against
# load_nba_data's own documentation.
pbp_df = load_nba_data(
    seasons=2024,
    data="nbastats",
    seasontype="rg",
    in_memory=True,
    use_pandas=True,
)

In [18]:
import pandas as pd

# EVENTMSGTYPE codes used by this aggregation.
_FG_MADE = 1
_FG_MISSED = 2
_FREE_THROW = 3

# Label of the pooled overtime bucket when group_overtime=True.
_OVERTIME_LABEL = "5+"

# Column layout of the returned frame (also used for the empty result).
_OUTPUT_COLUMNS = [
    "GAME_ID", "PERIOD", "TEAM_ID", "TEAM_NAME", "IS_VISITOR",
    "FG2_MADE", "FG2_MISSED", "FG3_MADE", "FG3_MISSED",
    "FT_MADE", "FT_MISSED",
]


def _has_text(descriptions: pd.Series) -> pd.Series:
    """Return True where a description cell holds non-blank text."""
    return descriptions.notna() & descriptions.astype(str).str.strip().ne("")


def _most_common_team(team_ids: pd.Series):
    """Return the most frequent team id as an int, or None for an empty series."""
    if team_ids.empty:
        return None
    # mode() returns its values sorted, so ties resolve deterministically.
    return int(team_ids.mode().iloc[0])


def _infer_home_and_visitor(shots: pd.DataFrame):
    """Infer (home_id, visitor_id) for one game from its shooting events.

    Only rows with text on exactly one side are used as evidence. Some events
    (for example a blocked shot) carry text in both HOMEDESCRIPTION and
    VISITORDESCRIPTION, so "description is present" alone does not identify
    the team of PLAYER1. Either id may be None when it cannot be inferred.
    """
    home_text = _has_text(shots["HOMEDESCRIPTION"])
    visitor_text = _has_text(shots["VISITORDESCRIPTION"])

    home_id = _most_common_team(shots.loc[home_text & ~visitor_text, "TEAM_ID"])
    visitor_id = _most_common_team(shots.loc[visitor_text & ~home_text, "TEAM_ID"])

    # If only one side could be identified and the game has exactly two
    # teams, the remaining team must be the other side.
    team_ids = shots["TEAM_ID"].unique()
    if len(team_ids) == 2:
        if home_id is not None and visitor_id is None:
            visitor_id = int(next(t for t in team_ids if t != home_id))
        elif visitor_id is not None and home_id is None:
            home_id = int(next(t for t in team_ids if t != visitor_id))

    return home_id, visitor_id


def _period_sort_rank(period) -> int:
    """Numeric sort rank for a PERIOD value: the period itself, '5+' ranks as 5."""
    if isinstance(period, str):
        return 5 if period == _OVERTIME_LABEL else 999
    return int(period)


def aggregate_by_period(
    pbp: pd.DataFrame,
    group_overtime: bool = True
) -> pd.DataFrame:
    """
    Aggregate play-by-play into per-game, per-period, per-team shooting totals:
        GAME_ID | PERIOD | TEAM_ID | TEAM_NAME | IS_VISITOR |
        FG2_MADE | FG2_MISSED | FG3_MADE | FG3_MISSED | FT_MADE | FT_MISSED

    Requirements in `pbp`:
        - GAME_ID
        - PERIOD
        - EVENTMSGTYPE (1=FG made, 2=FG missed, 3=FT)
        - HOMEDESCRIPTION
        - VISITORDESCRIPTION
        - PLAYER1_TEAM_ID
        - PLAYER1_TEAM_NICKNAME

    Classification rules:
        - A field goal is a 3-pointer when either description contains "3PT"
          (case-insensitive); every other field goal counts as a 2-pointer.
        - A free throw is missed when the description contains the whole word
          "MISS" (case-insensitive). Whole-word matching keeps made free
          throws by players whose name merely contains "miss" from being
          counted as misses.
        - Events without a PLAYER1_TEAM_ID are dropped.
        - Home/visitor is inferred per game from events whose text appears on
          exactly one side; a team that cannot be classified gets IS_VISITOR=0.

    Parameters
    ----------
    pbp : pd.DataFrame
        Full play-by-play table. It is not modified.
    group_overtime : bool, default True
        If True, all periods >=5 are grouped into a single '5+' bucket.
        If False, each overtime period keeps its own number.

    Returns
    -------
    pd.DataFrame
        One row per (game, period, team), sorted by GAME_ID, then period
        (1-4, then overtime), then home before visitor. Columns:
        - GAME_ID
        - PERIOD              (1,2,3,4 and '5+' if group_overtime=True)
        - TEAM_ID
        - TEAM_NAME
        - IS_VISITOR          (1 = visitor, 0 = home)
        - FG2_MADE
        - FG2_MISSED
        - FG3_MADE
        - FG3_MISSED
        - FT_MADE
        - FT_MISSED
        An empty frame with these columns is returned when there are no
        shooting events.
    """
    per_game_frames = []

    # Home/visitor roles differ per game, so games are processed one by one.
    for _, game in pbp.groupby("GAME_ID"):
        # Only keep events that matter for shooting stats. The filtered copy
        # is what gets mutated below, so `pbp` itself is never touched.
        shots = game[
            game["EVENTMSGTYPE"].isin([_FG_MADE, _FG_MISSED, _FREE_THROW])
        ].copy()
        if shots.empty:
            continue

        shots["PERIOD"] = shots["PERIOD"].astype(int)

        # Period grouping: 1-4 as-is, 5+ optionally merged into '5+'.
        if group_overtime:
            shots["PERIOD_GROUP"] = shots["PERIOD"].apply(
                lambda p: p if p <= 4 else _OVERTIME_LABEL
            )
        else:
            shots["PERIOD_GROUP"] = shots["PERIOD"]

        # Combined description (used for 3PT / FT miss detection).
        desc = (
            shots["HOMEDESCRIPTION"].fillna("")
            + " "
            + shots["VISITORDESCRIPTION"].fillna("")
        )

        # --- Field goals: split into 2s vs 3s ---
        fg_made = shots["EVENTMSGTYPE"] == _FG_MADE
        fg_missed = shots["EVENTMSGTYPE"] == _FG_MISSED
        is_three = (fg_made | fg_missed) & desc.str.contains(
            "3PT", case=False, na=False
        )

        shots["FG2_MADE_FLAG"] = (fg_made & ~is_three).astype(int)
        shots["FG2_MISSED_FLAG"] = (fg_missed & ~is_three).astype(int)
        shots["FG3_MADE_FLAG"] = (fg_made & is_three).astype(int)
        shots["FG3_MISSED_FLAG"] = (fg_missed & is_three).astype(int)

        # --- Free throws: made vs missed ---
        is_ft = shots["EVENTMSGTYPE"] == _FREE_THROW
        # Whole-word match: a bare substring test would also hit player names
        # such as "Missi" and turn their made free throws into misses.
        ft_missed = is_ft & desc.str.contains(
            r"\bMISS\b", case=False, regex=True, na=False
        )
        shots["FT_MISSED_FLAG"] = ft_missed.astype(int)
        shots["FT_MADE_FLAG"] = (is_ft & ~ft_missed).astype(int)

        # Team info for each shooting event; rows without a team are dropped.
        shots = shots[shots["PLAYER1_TEAM_ID"].notna()].copy()
        if shots.empty:
            continue
        shots["TEAM_ID"] = shots["PLAYER1_TEAM_ID"].astype(int)

        # Map TEAM_ID -> TEAM_NAME using the first non-null nickname per team.
        nicknames = shots.groupby("TEAM_ID")["PLAYER1_TEAM_NICKNAME"].first()
        shots["TEAM_NAME"] = shots["TEAM_ID"].map(nicknames)

        # ---- Infer home vs visitor ----
        _, visitor_id = _infer_home_and_visitor(shots)
        if visitor_id is None:
            shots["IS_VISITOR"] = 0  # default if unsure
        else:
            shots["IS_VISITOR"] = shots["TEAM_ID"].eq(visitor_id).astype(int)

        # --- Aggregate by game, period, team ---
        game_totals = (
            shots
            .groupby(
                ["GAME_ID", "PERIOD_GROUP", "TEAM_ID", "TEAM_NAME", "IS_VISITOR"],
                dropna=False
            )
            .agg(
                FG2_MADE=("FG2_MADE_FLAG", "sum"),
                FG2_MISSED=("FG2_MISSED_FLAG", "sum"),
                FG3_MADE=("FG3_MADE_FLAG", "sum"),
                FG3_MISSED=("FG3_MISSED_FLAG", "sum"),
                FT_MADE=("FT_MADE_FLAG", "sum"),
                FT_MISSED=("FT_MISSED_FLAG", "sum"),
            )
            .reset_index()
            .rename(columns={"PERIOD_GROUP": "PERIOD"})
        )

        per_game_frames.append(game_totals)

    if not per_game_frames:
        return pd.DataFrame(columns=_OUTPUT_COLUMNS)

    result = pd.concat(per_game_frames, ignore_index=True)

    # Sort by GAME_ID, then period (1-4, then overtime), then home/visitor.
    # PERIOD may mix ints with the '5+' label, so sort on a numeric rank
    # held in a temporary column instead of on PERIOD itself.
    result = (
        result
        .assign(__key__=result["PERIOD"].map(_period_sort_rank))
        .sort_values(["GAME_ID", "__key__", "IS_VISITOR"])
        .drop(columns="__key__")
    )

    return result


In [19]:
# Shooting totals per game, period and team; overtime periods are pooled
# into the single '5+' bucket (the function's default, spelled out here).
period_stats = aggregate_by_period(pbp=pbp_df, group_overtime=True)

In [20]:
# Display the frame computed in the previous cell. The earlier name
# `period_stats_2024` is not assigned anywhere in the visible notebook, so it
# would raise NameError on a fresh kernel.
period_stats

Unnamed: 0,GAME_ID,PERIOD,TEAM_ID,TEAM_NAME,IS_VISITOR,FG2_MADE,FG2_MISSED,FG3_MADE,FG3_MISSED,FT_MADE,FT_MISSED
1,22400001,1,1610612738,Celtics,0,7,2,3,9,8,1
0,22400001,1,1610612737,Hawks,1,10,8,2,6,3,4
3,22400001,2,1610612738,Celtics,0,5,3,6,7,6,0
2,22400001,2,1610612737,Hawks,1,7,5,3,8,2,2
5,22400001,3,1610612738,Celtics,0,6,1,3,7,5,1
...,...,...,...,...,...,...,...,...,...,...,...
9954,22401230,2,1610612745,Rockets,1,8,9,2,8,0,1
9957,22401230,3,1610612760,Thunder,0,7,9,5,0,5,0
9956,22401230,3,1610612745,Rockets,1,4,3,5,7,4,2
9959,22401230,4,1610612760,Thunder,0,6,4,4,3,12,0
