In [None]:
# ======================== IMPORTS & CONFIG ========================
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import polars as pl
from lightgbm import LGBMRegressor
import xgboost as xgb
import joblib

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

# Root directory of the competition data (train CSVs live under DATA_DIR/train,
# see load_train).
DATA_DIR = "/kaggle/input/nfl-big-data-bowl-2026-prediction"

# Inference server setup: kaggle_evaluation is imported after DATA_DIR has been
# appended to sys.path. When the package cannot be found, NFLInferenceServer is
# set to None and HAS_EVAL_SERVER to False so later code can branch on the flag.
sys.path.append(DATA_DIR)
try:
    from kaggle_evaluation.nfl_inference_server import NFLInferenceServer
    HAS_EVAL_SERVER = True
except ModuleNotFoundError:
    NFLInferenceServer = None
    HAS_EVAL_SERVER = False
    print(
        "WARNING: kaggle_evaluation not found. "
  
    )

# Random seeds: one shared seed for NumPy, PyTorch and the models below.
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

# Torch device string used by the GNN: GPU when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Startup banner.
print("=" * 70)
print("NFL BIG DATA BOWL 2026 - LGBM + XGB + GNN (rich features, Eval API)")
print("=" * 70)
print("DEVICE:", DEVICE)

# ======================== FEATURE LISTS ===========================
# Numeric feature columns fed to the tree models (and, via GNN_NUM_FEATS, to the
# GNN). The derived columns are produced by create_features().
FEATURES = [
    # geometry of the current (last observed) state
    "x_last", "y_last",
    "s", "a", "o", "dir",

    # components of speed and acceleration along the field axes,
    # plus sin/cos encodings of the two angles

    "vx", "vy",
    "ax_comp", "ay_comp",
    "dir_sin", "dir_cos",
    "o_sin", "o_cos",

    # temporal features (frame index and remaining flight time)

    "frame_offset", "time_offset",
    "num_frames_output",
    "frac_of_flight",
    "frames_left",
    "time_to_land",
    "remaining_flight_frac",

    # relation to the ball landing point

    "dist_to_ball_land",
    "angle_to_ball_land",
    "dist_to_ball_land_per_frame",
    "cos_dir_to_ball",
    "cos_orient_to_ball",

    # relative coordinates and velocity projections toward the ball

    "x_rel_ball",
    "y_rel_ball",
    "v_toward_ball",
    "v_across_ball",

    # x standardized by play direction (x -> 120 - x for plays going left)

    "x_std",
    "ball_land_x_std",
    "dx_to_land_std",
    "dy_to_land",

    # position by field width and length

    "dist_to_sideline",
    "dist_to_center",
    "yardline_100",
    "yardline_norm",
    "dist_to_endzone",

    # target receiver and velocity projections toward him
    "dist_to_target_last",
    "dx_to_target_last",
    "dy_to_target_last",
    "angle_to_target",
    "cos_dir_to_target",
    "cos_orient_to_target",
    "v_toward_target",
    "v_across_target",
    "is_target",

    # play / player context

    "absolute_yardline_number",
    "player_height", "player_weight",
    "bmi",

    # pairwise context (approximation of attention/GNN in tabular form),
    # computed by add_team_distance_features()

    "min_dist_teammate",
    "mean_dist_teammate",
    "min_dist_opponent",
    "mean_dist_opponent",
]

# Categorical columns; the tree models receive them as pandas "category" dtype.
CAT_FEATS = ["player_role", "player_side", "play_direction"]

# The GNN uses the same numeric features as the tree models.
# NOTE: this is an alias of FEATURES (the same list object), not a copy.

GNN_NUM_FEATS = FEATURES
# Names of the integer-code columns derived from CAT_FEATS ("<name>_cat").
GNN_CAT_CODE_COLS = [f"{c}_cat" for c in CAT_FEATS]

# Columns pulled from the last observation before the pass and merged onto the
# output frames in prepare_train(); columns absent from the data are skipped there.

BASE_COLS = [
    "game_id", "play_id", "nfl_id",
    "x_last", "y_last",
    "s", "a", "o", "dir",
    "player_role", "player_side",
    "num_frames_output",
    "ball_land_x", "ball_land_y",
    "target_last_x", "target_last_y", "target_nfl_id",
    "play_direction",
    "absolute_yardline_number",
    "player_height", "player_weight",
    "player_to_predict",
    # pairwise features from last_obs
    "min_dist_teammate",
    "mean_dist_teammate",
    "min_dist_opponent",
    "mean_dist_opponent",
]

# ======================== GLOBAL STATE ============================
# LGBM ensemble (separate model lists for the dx and dy targets)
LGBM_MODELS_DX = []
LGBM_MODELS_DY = []

# XGBoost (separate model lists for the dx and dy targets)
XGB_MODELS_DX = []
XGB_MODELS_DY = []

# GNN (set by train_gnn_model)
GNN_MODEL = None

# Categorical dictionaries for embeddings/encoding, keyed by CAT_FEATS name

CAT_CATEGORY_MAPS = {}
CAT_CARD_SIZES = {}

# Normalization of numerical features for GNN: (x - mean) / std
GNN_NUM_MEAN = None
GNN_NUM_STD = None

# Target-player weights in training.
# The per-row sample weight is 1.0 + TARGET_WEIGHT_* * is_target.

TARGET_WEIGHT_TREE = 1.0   # LGBM/XGB weight for target receiver

TARGET_WEIGHT_GNN = 1.0    # GNN weight for target receiver

# LGBM ensemble size
LGBM_N_MODELS = 3

# GNN hyperparameters

TRAIN_GNN = True           # can be set to False to speed up training, since the GNN has not proven to help our approach
GNN_EPOCHS = 5
GNN_BATCH_SIZE = 512
GNN_HIDDEN_DIM = 256
GNN_EMB_DIM = 8
GNN_GAT_HEADS = 2
GNN_LR = 1e-3
GNN_DROPOUT = 0.2

# Weights for the ensemble (LGBM / XGB / GNN). These were used for the final
# predictions; in other words, the final ensemble is LightGBM only.
ENSEMBLE_WEIGHTS = {"lgbm": 1.0, "xgb": 0.0, "gnn": 0.0}

# ======================== DATA LOADING ===========================
def load_train(data_dir: str):
    """
    Load the weekly training CSVs found under ``<data_dir>/train``.

    Weeks 1..18 are probed as ``input_2023_wNN.csv`` / ``output_2023_wNN.csv``;
    a week is used only when both files exist, otherwise it is skipped.

    Returns:
        (df_in, df_out): the concatenated input and output frames.

    Raises:
        FileNotFoundError: if no week had both files present.
    """
    train_dir = os.path.join(data_dir, "train")
    inputs, outputs = [], []

    print("\n[1/4] Loading training inputs/outputs...")
    for week in range(1, 19):
        in_path = os.path.join(train_dir, f"input_2023_w{week:02d}.csv")
        out_path = os.path.join(train_dir, f"output_2023_w{week:02d}.csv")

        if not (os.path.exists(in_path) and os.path.exists(out_path)):
            print(f" Week {week:02d}: files not found, skipping")
            continue

        week_in = pd.read_csv(in_path)
        week_out = pd.read_csv(out_path)
        inputs.append(week_in)
        outputs.append(week_out)
        print(f" Week {week:02d}: input {week_in.shape}, output {week_out.shape}")

    if not inputs or not outputs:
        raise FileNotFoundError(
            f"No train CSV files found in {train_dir}. "
        )

    df_in = pd.concat(inputs, ignore_index=True)
    df_out = pd.concat(outputs, ignore_index=True)
    print("Train inputs:", df_in.shape, "train outputs:", df_out.shape)
    return df_in, df_out


# ======================== FEATURE ENGINEERING HELPERS =============
def height_to_inches(ht):
    """
    Convert a height string in 'feet-inches' form (e.g. '6-2') into inches.

    Returns ``6 * 12 + 2 = 74`` for '6-2'. Anything that is not a string
    containing exactly one '-' between two integer fields yields ``np.nan``.
    """
    if not isinstance(ht, str) or "-" not in ht:
        return np.nan
    try:
        feet, inches = ht.split("-")
        return int(feet) * 12 + int(inches)
    except ValueError:
        # Raised both by int() on a non-numeric field and by the tuple
        # unpacking when the string has more than one '-'.
        return np.nan


def _masked_min_mean(dist: np.ndarray, mask: np.ndarray):
    """Row-wise (min, mean) of ``dist`` over entries where ``mask`` is True; 0.0 for rows with none."""
    min_d = np.where(mask, dist, np.inf).min(axis=1)
    min_d[np.isinf(min_d)] = 0.0

    total = np.where(mask, dist, 0.0).sum(axis=1)
    count = mask.sum(axis=1)
    mean_d = np.divide(
        total,
        np.maximum(count, 1),
        out=np.zeros_like(total),
        where=count > 0,
    )
    return min_d, mean_d


def add_team_distance_features(df_last: pd.DataFrame) -> pd.DataFrame:
    """
    Pairwise features for players within (game_id, play_id):
    - minimum/mean distance to teammates (same ``player_side``, excluding self),
    - minimum/mean distance to opponents (different ``player_side``).

    Distances are Euclidean, computed from ``x_last`` / ``y_last``. A player
    with no teammates (or no opponents) in the play gets 0.0 for those columns.

    Returns a new frame; the input is not modified. When ``player_side`` is
    present the result is the concatenation of the per-play groups in sorted
    key order with a fresh 0..n-1 index. When ``player_side`` is missing, a
    copy of the input with all four columns set to 0.0 is returned.
    """
    dist_cols = [
        "min_dist_teammate",
        "mean_dist_teammate",
        "min_dist_opponent",
        "mean_dist_opponent",
    ]

    if "player_side" not in df_last.columns:
        # Work on a copy so the caller's frame is left untouched.
        out = df_last.copy()
        for col in dist_cols:
            out[col] = 0.0
        return out

    groups = []
    for _, play in df_last.groupby(["game_id", "play_id"]):
        play = play.copy()
        xs = play["x_last"].to_numpy()
        ys = play["y_last"].to_numpy()
        sides = play["player_side"].astype("category").cat.codes.to_numpy()

        dx = xs[:, None] - xs[None, :]
        dy = ys[:, None] - ys[None, :]
        dist = np.sqrt(dx * dx + dy * dy)

        same = sides[:, None] == sides[None, :]
        # A player is not their own teammate: drop the diagonal from the
        # teammate mask so the self-distance never enters the min or the sum.
        teammates = same & ~np.eye(len(play), dtype=bool)
        opponents = ~same

        min_tm, mean_tm = _masked_min_mean(dist, teammates)
        min_op, mean_op = _masked_min_mean(dist, opponents)

        play["min_dist_teammate"] = min_tm
        play["mean_dist_teammate"] = mean_tm
        play["min_dist_opponent"] = min_op
        play["mean_dist_opponent"] = mean_op

        groups.append(play)

    if not groups:
        # Empty input (or no usable group keys): pd.concat([]) would raise.
        out = df_last.iloc[0:0].copy()
        for col in dist_cols:
            out[col] = np.zeros(len(out))
        return out

    return pd.concat(groups, ignore_index=True)


def prepare_last_obs(df: pd.DataFrame) -> pd.DataFrame:
    """
    Collapse the tracking frames to one row per (game_id, play_id, nfl_id).

    Rows are ordered by frame_id and reduced with ``GroupBy.last()``, then
    x / y are renamed to x_last / y_last, ``player_height`` is converted to
    inches (or filled with NaN when the column is absent), and the pairwise
    distance features are appended.
    """
    player_keys = ["game_id", "play_id", "nfl_id"]

    ordered = df.sort_values(player_keys + ["frame_id"])
    # NOTE: GroupBy.last() takes the last non-null value of each column, so a
    # NaN in the final frame is filled from an earlier frame of the same player.
    df_last = ordered.groupby(player_keys, as_index=False).last()
    df_last = df_last.rename(columns={"x": "x_last", "y": "y_last"})

    if "player_height" not in df_last.columns:
        df_last["player_height"] = np.nan
    else:
        df_last["player_height"] = df_last["player_height"].apply(height_to_inches)

    return add_team_distance_features(df_last)


def add_target_info(df_last: pd.DataFrame) -> pd.DataFrame:
    """
    For each play, add the targeted receiver's id and last position.

    Rows whose ``player_role`` equals "Targeted Receiver" supply
    ``target_nfl_id``, ``target_last_x`` and ``target_last_y``; these are
    left-merged onto every row of the same (game_id, play_id). Plays without
    a targeted receiver get NaN in the three columns.

    If ``player_role`` is missing, the three columns are added as NaN. If a
    play has more than one targeted-receiver row, only the first is used so
    the merge never multiplies rows. Returns a new frame.
    """
    target_cols = ["target_last_x", "target_last_y", "target_nfl_id"]
    play_keys = ["game_id", "play_id"]

    if "player_role" not in df_last.columns:
        # Without roles there is no target to look up; comparing a fallback
        # scalar would yield a plain bool, which .loc cannot use as a mask.
        out = df_last.copy()
        for col in target_cols:
            out[col] = np.nan
        return out

    is_target = df_last["player_role"] == "Targeted Receiver"
    targets = (
        df_last.loc[is_target, play_keys + ["nfl_id", "x_last", "y_last"]]
        .rename(
            columns={
                "nfl_id": "target_nfl_id",
                "x_last": "target_last_x",
                "y_last": "target_last_y",
            }
        )
        # One target per play keeps the left merge one-to-one on the left side.
        .drop_duplicates(subset=play_keys, keep="first")
    )

    return df_last.merge(
        targets[play_keys + target_cols],
        on=play_keys,
        how="left",
    )


def mirror_raw(df_raw: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of the frame reflected across the field's width.

    Every y-like column that is present (y_last, y, ball_land_y,
    target_last_y) becomes ``53.3 - value``; the angle columns ``dir`` and
    ``o`` become ``(-angle) mod 360``. The input frame is left untouched.
    """
    field_width = 53.3
    mirrored = df_raw.copy()

    for y_col in ("y_last", "y", "ball_land_y", "target_last_y"):
        if y_col in mirrored.columns:
            mirrored[y_col] = field_width - mirrored[y_col]

    for angle_col in ("dir", "o"):
        if angle_col not in mirrored.columns:
            continue
        mirrored[angle_col] = (-mirrored[angle_col]) % 360.0

    return mirrored


def create_features(df: pd.DataFrame, is_train: bool = True) -> pd.DataFrame:
    """
    Build the model feature columns on a copy of ``df``.

    Required columns: s, a, o, dir, x_last, y_last, ball_land_x, ball_land_y,
    target_last_x, target_last_y, nfl_id, target_nfl_id,
    absolute_yardline_number, player_height, player_weight.
    Optional columns: frame_id (default frame offset 0), num_frames_output
    (default flight fraction / frames left 0.0), play_direction (default
    "right", also used for missing values).

    Args:
        df: one row per (player, output frame) with the columns above.
        is_train: when True, also adds the targets ``dx = x - x_last`` and
            ``dy = y - y_last`` (requires columns x and y).

    Returns:
        A new DataFrame with the original columns plus the feature columns.
    """
    df = df.copy()

    # -------- Basic quantities and angles --------
    s = df["s"].fillna(0.0)
    a = df["a"].fillna(0.0)
    # NOTE(review): dir / o are treated as angles measured from the +x axis
    # (cos -> x component, sin -> y component); confirm this matches the
    # angle convention of the tracking data.
    dir_rad = np.deg2rad(df["dir"].fillna(0.0))
    o_rad = np.deg2rad(df["o"].fillna(0.0))

    # -------- Components of speed and acceleration --------
    df["vx"] = s * np.cos(dir_rad)
    df["vy"] = s * np.sin(dir_rad)
    df["ax_comp"] = a * np.cos(dir_rad)
    df["ay_comp"] = a * np.sin(dir_rad)

    df["dir_sin"] = np.sin(dir_rad)
    df["dir_cos"] = np.cos(dir_rad)
    df["o_sin"] = np.sin(o_rad)
    df["o_cos"] = np.cos(o_rad)

    # -------- Flight time / frame --------
    if "frame_id" in df.columns:
        df["frame_offset"] = df["frame_id"]
    else:
        df["frame_offset"] = 0

    df["time_offset"] = df["frame_offset"] / 10.0  # 10 frames per second

    if "num_frames_output" in df.columns:
        # Zero frame counts become NaN so the divisions below cannot blow up.
        nfo = df["num_frames_output"].replace(0, np.nan)
        df["frac_of_flight"] = (df["frame_offset"] / nfo).clip(lower=0, upper=1)
        df["frac_of_flight"] = df["frac_of_flight"].fillna(0.0)
        df["frames_left"] = (nfo - df["frame_offset"]).clip(lower=0).fillna(0.0)
    else:
        df["frac_of_flight"] = 0.0
        df["frames_left"] = 0.0

    df["time_to_land"] = df["frames_left"] / 10.0
    df["remaining_flight_frac"] = (1.0 - df["frac_of_flight"]).clip(lower=0.0, upper=1.0)

    # -------- Geometry relative to the ball landing point --------
    df["dist_to_ball_land"] = np.sqrt(
        (df["ball_land_x"] - df["x_last"]) ** 2 +
        (df["ball_land_y"] - df["y_last"]) ** 2
    )
    df["angle_to_ball_land"] = np.arctan2(
        df["ball_land_y"] - df["y_last"],
        df["ball_land_x"] - df["x_last"],
    )

    # Distance still to cover per remaining frame; 0.0 when no frames are left.
    frames_left_safe = df["frames_left"].replace(0, np.nan)
    df["dist_to_ball_land_per_frame"] = df["dist_to_ball_land"] / frames_left_safe
    df["dist_to_ball_land_per_frame"] = (
        df["dist_to_ball_land_per_frame"]
        .replace([np.inf, -np.inf], np.nan)
        .fillna(0.0)
    )

    df["cos_dir_to_ball"] = np.cos(df["angle_to_ball_land"] - dir_rad)
    df["cos_orient_to_ball"] = np.cos(df["angle_to_ball_land"] - o_rad)

    # -------- Standardize coordinates by the direction of the play --------
    if "play_direction" in df.columns:
        # Cast to object first so fillna also works on a categorical column
        # whose categories do not include "right".
        play_dir = df["play_direction"].astype(object).fillna("right")
    else:
        # A scalar fallback has no .fillna(); build a full column instead.
        play_dir = pd.Series("right", index=df.index)
    is_left = (play_dir == "left").to_numpy()

    df["x_std"] = np.where(is_left, 120.0 - df["x_last"], df["x_last"])
    df["ball_land_x_std"] = np.where(
        is_left, 120.0 - df["ball_land_x"], df["ball_land_x"]
    )

    df["dx_to_land_std"] = df["ball_land_x_std"] - df["x_std"]
    df["dy_to_land"] = df["ball_land_y"] - df["y_last"]

    # -------- Position by field width/length --------
    df["dist_to_sideline"] = np.minimum(df["y_last"], 53.3 - df["y_last"])
    df["dist_to_center"] = np.abs(df["y_last"] - 53.3 / 2.0)

    yard = df["absolute_yardline_number"].fillna(50.0)
    yard_100 = yard.clip(lower=0.0, upper=100.0)
    df["yardline_100"] = yard_100
    df["yardline_norm"] = yard_100 / 100.0
    df["dist_to_endzone"] = 100.0 - yard_100

    # -------- Target receiver --------
    df["dist_to_target_last"] = np.sqrt(
        (df["target_last_x"] - df["x_last"]) ** 2 +
        (df["target_last_y"] - df["y_last"]) ** 2
    )

    df["dx_to_target_last"] = df["target_last_x"] - df["x_last"]
    df["dy_to_target_last"] = df["target_last_y"] - df["y_last"]
    df["angle_to_target"] = np.arctan2(
        df["target_last_y"] - df["y_last"],
        df["target_last_x"] - df["x_last"],
    )

    df["cos_dir_to_target"] = np.cos(df["angle_to_target"] - dir_rad)
    df["cos_orient_to_target"] = np.cos(df["angle_to_target"] - o_rad)

    df["is_target"] = (df["nfl_id"] == df["target_nfl_id"]).astype(int)

    # -------- Relative coordinates and velocity projections --------
    df["x_rel_ball"] = df["x_last"] - df["ball_land_x"]
    df["y_rel_ball"] = df["y_last"] - df["ball_land_y"]

    vx = df["vx"]
    vy = df["vy"]

    # Project the velocity onto the unit vector toward the ball (toward) and
    # onto its perpendicular (across).
    ball_cos = np.cos(df["angle_to_ball_land"])
    ball_sin = np.sin(df["angle_to_ball_land"])
    df["v_toward_ball"] = vx * ball_cos + vy * ball_sin
    df["v_across_ball"] = vx * (-ball_sin) + vy * ball_cos

    tgt_cos = np.cos(df["angle_to_target"])
    tgt_sin = np.sin(df["angle_to_target"])
    df["v_toward_target"] = vx * tgt_cos + vy * tgt_sin
    df["v_across_target"] = vx * (-tgt_sin) + vy * tgt_cos

    proj_cols = ["v_toward_ball", "v_across_ball", "v_toward_target", "v_across_target"]
    df[proj_cols] = df[proj_cols].replace([np.inf, -np.inf], np.nan).fillna(0.0)

    # -------- Player physics --------
    h = df["player_height"].replace(0, np.nan)
    w = df["player_weight"].replace(0, np.nan)
    df["bmi"] = 703.0 * (w / (h ** 2))
    df["bmi"] = df["bmi"].replace([np.inf, -np.inf], np.nan).fillna(0.0)

    # -------- Targets only in train --------
    if is_train:
        df["dx"] = df["x"] - df["x_last"]
        df["dy"] = df["y"] - df["y_last"]

    return df


def prepare_train(df_in: pd.DataFrame, df_out: pd.DataFrame):
    """
    Assemble the two training tables from the raw input/output frames.

    Each output row is joined (left merge on game_id, play_id, nfl_id) with
    that player's last pre-pass observation, optionally filtered to
    ``player_to_predict`` rows, and passed through create_features().

    Returns:
        (train_tree, train_gnn):
        - train_tree: the featured rows followed by their mirrored copies
          (symmetric augmentation, for LGBM/XGB),
        - train_gnn: the featured rows only, without augmentation (for GNN).
    """
    print("\n[2/4] Preparing training features...")

    player_keys = ["game_id", "play_id", "nfl_id"]

    outputs = df_out.copy()
    if "frame_id" not in outputs.columns:
        # No explicit frame index: number each player's rows in file order.
        outputs["frame_id"] = outputs.groupby(player_keys).cumcount()

    last_obs = add_target_info(prepare_last_obs(df_in))

    # Only request the base columns that actually exist in the data.
    available_cols = [col for col in BASE_COLS if col in last_obs.columns]
    train_raw = outputs.merge(last_obs[available_cols], on=player_keys, how="left")

    if "player_to_predict" in train_raw.columns:
        before = len(train_raw)
        train_raw = train_raw[train_raw["player_to_predict"].astype(bool)].copy()
        after = len(train_raw)
        print(f" Filtered to player_to_predict==True: {before} -> {after} rows")

    train_main = create_features(train_raw, is_train=True)

    # Symmetric augmentation: the same rows reflected by mirror_raw().
    train_mirror = create_features(mirror_raw(train_raw), is_train=True)

    train_tree = pd.concat([train_main, train_mirror], ignore_index=True)
    print(f" After symmetry augmentation (tree): {len(train_main)} -> {len(train_tree)} rows")

    # The GNN trains on the un-augmented rows.
    train_gnn = train_main.copy()

    return train_tree, train_gnn


# ======================== GNN DATASET ============================
class GraphDataset(Dataset):
    """
    Row-level dataset for the GNN.

    Each item is one DataFrame row returned as
    ``(x_num, x_cat_or_None, y, weight, graph_id)``, where ``y`` holds
    (dx, dy), ``weight`` is ``1 + TARGET_WEIGHT_GNN * is_target`` and
    ``graph_id`` is an integer code for the row's (game_id, play_id[,
    frame_id]) key.
    """

    def __init__(self, df: pd.DataFrame, num_cols, cat_code_cols):
        self.num_cols = list(num_cols)
        self.cat_code_cols = list(cat_code_cols)

        self.X_num = df[self.num_cols].to_numpy(np.float32)
        self.X_cat = (
            df[self.cat_code_cols].to_numpy(np.int64) if self.cat_code_cols else None
        )
        self.y = df[["dx", "dy"]].to_numpy(np.float32)

        is_target = df["is_target"].to_numpy(np.float32)
        self.w = (1.0 + TARGET_WEIGHT_GNN * is_target).astype(np.float32)

        # Graph key: game and play, plus the frame when that column exists.
        key_cols = ["game_id", "play_id"]
        if "frame_id" in df.columns:
            key_cols.append("frame_id")

        gkeys = df[key_cols[0]].astype(str)
        for col in key_cols[1:]:
            gkeys = gkeys + "_" + df[col].astype(str)

        self.graph_ids, _ = pd.factorize(gkeys)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # The categorical slot is None when the dataset has no code columns.
        cat_row = None if self.X_cat is None else self.X_cat[idx]
        return (
            self.X_num[idx],
            cat_row,
            self.y[idx],
            self.w[idx],
            int(self.graph_ids[idx]),
        )


def collate_graph(batch):
    """
    Collate GraphDataset items into batch tensors.

    Returns ``(X_num, X_cat, y, w, g_ids)``; ``X_cat`` is None when the first
    item carries no categorical codes, and ``g_ids`` is a long tensor.
    """
    num_rows, cat_rows, targets, weights, graph_ids = zip(*batch)

    has_cat = cat_rows[0] is not None
    return (
        torch.from_numpy(np.stack(num_rows)),
        torch.from_numpy(np.stack(cat_rows)) if has_cat else None,
        torch.from_numpy(np.stack(targets)),
        torch.from_numpy(np.stack(weights)),
        torch.tensor(graph_ids, dtype=torch.long),
    )


# ======================== GNN LAYERS =============================
class GraphAttentionLayer(nn.Module):
    """
    Single-head dense graph attention layer.

    For nodes i, j the raw score is ``LeakyReLU(a([W h_i || W h_j]))``. Scores
    of pairs with ``adj[i, j] == 0`` are masked out, the rest are
    softmax-normalised over j, and the output for node i is the
    attention-weighted sum of ``W h_j``.

    Args:
        in_features: size of each input node vector.
        out_features: size of each output node vector.
        alpha: negative slope of the LeakyReLU applied to the raw scores.
    """

    def __init__(self, in_features, out_features, alpha=0.2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.W = nn.Linear(in_features, out_features, bias=False)
        self.a = nn.Linear(2 * out_features, 1, bias=False)
        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, h, adj):
        """
        Args:
            h: node features, [N, in_features].
            adj: [N, N] matrix; zero entries mark pairs that must not attend.
                A row that is entirely zero yields NaN after the softmax, so
                every node needs at least one non-zero entry.

        Returns:
            Updated node features, [N, out_features].
        """
        Wh = self.W(h)  # [N, F_out]

        # a([Wh_i || Wh_j]) is linear without bias, so it splits into
        # a_src . Wh_i + a_dst . Wh_j. Computing the two projections once and
        # broadcasting avoids materialising the [N, N, 2 * F_out] pair tensor.
        f_out = self.out_features
        a_src = self.a.weight[:, :f_out]  # [1, F_out]
        a_dst = self.a.weight[:, f_out:]  # [1, F_out]
        score_src = Wh @ a_src.t()  # [N, 1]
        score_dst = Wh @ a_dst.t()  # [N, 1]

        e = self.leakyrelu(score_src + score_dst.t())  # [N, N]

        e = e.masked_fill(adj == 0, float("-inf"))
        attention = torch.softmax(e, dim=1)

        return attention @ Wh


class MultiHeadGAT(nn.Module):
    """
    Multi-head wrapper around GraphAttentionLayer.

    ``out_features`` is split evenly across ``num_heads`` heads (it must be
    divisible). Head outputs are concatenated along the feature axis, passed
    through a linear projection and a ReLU.
    """

    def __init__(self, in_features, out_features, num_heads=2, alpha=0.2):
        super().__init__()
        assert out_features % num_heads == 0
        per_head = out_features // num_heads

        heads = []
        for _ in range(num_heads):
            heads.append(GraphAttentionLayer(in_features, per_head, alpha=alpha))
        self.heads = nn.ModuleList(heads)

        self.out_proj = nn.Linear(out_features, out_features)

    def forward(self, h, adj):
        # Every head sees the same inputs; their outputs are joined feature-wise.
        merged = torch.cat([head(h, adj) for head in self.heads], dim=1)
        projected = self.out_proj(merged)
        return F.relu(projected)


class GCNLayer(nn.Module):
    """
    Dense graph-convolution layer: ``ReLU(D^-1/2 (A + I) D^-1/2 · W h)``,
    where D is the diagonal degree matrix of ``A + I``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.W = nn.Linear(in_features, out_features, bias=False)

    def forward(self, h, adj):
        """
        Args:
            h: node features, [N, in_features].
            adj: [N, N] adjacency matrix; the identity is added to it here.

        Returns:
            Updated node features, [N, out_features], non-negative.
        """
        n_nodes = adj.size(0)
        a_hat = adj + torch.eye(n_nodes, device=adj.device)

        deg = a_hat.sum(dim=1)
        deg_inv_sqrt = deg.pow(-0.5)
        # Zero-degree rows give inf; treat them as isolated (scale 0).
        deg_inv_sqrt = torch.where(
            torch.isinf(deg_inv_sqrt), torch.zeros_like(deg_inv_sqrt), deg_inv_sqrt
        )

        # Scaling rows and columns by broadcasting equals D^-1/2 A_hat D^-1/2
        # without building two dense diagonal matrices and multiplying by them.
        a_norm = deg_inv_sqrt.unsqueeze(1) * a_hat * deg_inv_sqrt.unsqueeze(0)

        return F.relu(a_norm @ self.W(h))


class GATGCNModel(nn.Module):
    """
    Node-level regressor: input projection -> 2 x multi-head GAT -> 2 x GCN
    -> MLP head with 2 outputs per node.

    Numeric features go through a linear layer. Each categorical code column
    gets its own embedding of width ``min(emb_dim, (cardinality + 1) // 2)``;
    the concatenated embeddings are projected to ``hidden_dim`` and added to
    the numeric projection before the first ReLU. Dropout follows every graph
    layer.
    """

    def __init__(
        self,
        num_num_features: int,
        num_cat_features: int,
        cat_cardinalities,
        hidden_dim: int = 384,
        emb_dim: int = 16,
        gat_heads: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_num_features = num_num_features
        self.num_cat_features = num_cat_features

        # One embedding table per categorical column (none if there are no
        # categorical features).
        self.emb_layers = nn.ModuleList()
        emb_widths = []
        if num_cat_features > 0:
            for cardinality in cat_cardinalities:
                width = min(emb_dim, (cardinality + 1) // 2)
                self.emb_layers.append(nn.Embedding(cardinality, width))
                emb_widths.append(width)
        total_emb_dim = sum(emb_widths)

        self.num_linear = nn.Linear(num_num_features, hidden_dim)
        if total_emb_dim > 0:
            self.cat_linear = nn.Linear(total_emb_dim, hidden_dim)
        else:
            self.cat_linear = None

        self.gat1 = MultiHeadGAT(hidden_dim, hidden_dim, num_heads=gat_heads)
        self.gat2 = MultiHeadGAT(hidden_dim, hidden_dim, num_heads=gat_heads)
        self.gcn1 = GCNLayer(hidden_dim, hidden_dim)
        self.gcn2 = GCNLayer(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)

        self.out_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x_num, x_cat, adj):
        """
        Args:
            x_num: numeric node features, [N, num_num_features].
            x_cat: integer category codes, one column per embedding, or None.
            adj: [N, N] adjacency matrix passed to every graph layer.

        Returns:
            Tensor of shape [N, 2].
        """
        h = self.num_linear(x_num)

        if self.emb_layers and x_cat is not None:
            cat_emb = torch.cat(
                [emb(x_cat[:, i]) for i, emb in enumerate(self.emb_layers)],
                dim=1,
            )
            h_cat = self.cat_linear(cat_emb) if self.cat_linear is not None else 0.0
            h = h + h_cat
        h = F.relu(h)

        # Graph stack; dropout is applied after each layer.
        for graph_layer in (self.gat1, self.gat2, self.gcn1, self.gcn2):
            h = self.dropout(graph_layer(h, adj))

        return self.out_mlp(h)


def train_gnn_model(train_gnn: pd.DataFrame):
    """
    Train the GAT+GCN model on (dx, dy) and store it in ``GNN_MODEL``.

    Side effects on ``train_gnn`` (modified in place): missing numeric feature
    columns are added as 0.0 and a ``<name>_cat`` integer code column is added
    for every entry of CAT_FEATS (unknown values map to ``len(categories)``).

    Numeric features are standardised with GNN_NUM_MEAN / GNN_NUM_STD; the loss
    is the per-row mean squared error over (dx, dy), weighted by the dataset's
    per-row weight and averaged over the batch.

    Returns:
        The trained model (also assigned to the global GNN_MODEL).
    """
    global GNN_MODEL

    print(" Training GAT+GCN model (dx, dy)...")

    # Guarantee every numeric feature column exists.
    absent = [col for col in GNN_NUM_FEATS if col not in train_gnn.columns]
    for col in absent:
        train_gnn[col] = 0.0

    # Integer-encode categoricals against the stored category lists.
    for cat_name in CAT_FEATS:
        known = CAT_CATEGORY_MAPS[cat_name]
        codes = pd.Categorical(train_gnn[cat_name], categories=known).codes
        codes = np.where(codes < 0, len(known), codes)
        train_gnn[f"{cat_name}_cat"] = codes.astype("int64")

    # Standardise on a copy; non-finite results become 0.0.
    train_gnn_scaled = train_gnn.copy()
    standardized = (train_gnn_scaled[GNN_NUM_FEATS] - GNN_NUM_MEAN) / GNN_NUM_STD
    train_gnn_scaled[GNN_NUM_FEATS] = standardized.replace(
        [np.inf, -np.inf], np.nan
    ).fillna(0.0)

    dataset = GraphDataset(train_gnn_scaled, GNN_NUM_FEATS, GNN_CAT_CODE_COLS)
    # NOTE(review): shuffle=True draws individual rows, so rows of the same
    # graph land in the same batch only by chance and the per-batch adjacency
    # below is likely close to the identity -- confirm this is intended.
    loader = DataLoader(
        dataset,
        batch_size=GNN_BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_graph,
        drop_last=False,
    )

    model = GATGCNModel(
        num_num_features=len(GNN_NUM_FEATS),
        num_cat_features=len(GNN_CAT_CODE_COLS),
        cat_cardinalities=[CAT_CARD_SIZES[cat_name] for cat_name in CAT_FEATS],
        hidden_dim=GNN_HIDDEN_DIM,
        emb_dim=GNN_EMB_DIM,
        gat_heads=GNN_GAT_HEADS,
        dropout=GNN_DROPOUT,
    ).to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=GNN_LR, weight_decay=1e-4)
    loss_fn = nn.MSELoss(reduction="none")

    model.train()
    for epoch in range(GNN_EPOCHS):
        loss_total = 0.0
        rows_seen = 0

        for batch in loader:
            X_num, X_cat, y, w, g_ids = batch
            X_num = X_num.to(DEVICE)
            if X_cat is not None:
                X_cat = X_cat.to(DEVICE)
            y = y.to(DEVICE)
            w = w.to(DEVICE)
            g_ids = g_ids.to(DEVICE)

            # Rows are connected iff they share a graph id (self-pairs included).
            adj = (g_ids.unsqueeze(1) == g_ids.unsqueeze(0)).float()

            optimizer.zero_grad()
            pred = model(X_num, X_cat, adj)
            row_loss = loss_fn(pred, y).mean(dim=1)
            loss = (row_loss * w).mean()
            loss.backward()
            optimizer.step()

            batch_rows = y.size(0)
            loss_total += loss.item() * batch_rows
            rows_seen += batch_rows

        avg_loss = loss_total / max(1, rows_seen)
        print(f"  GNN epoch {epoch + 1}/{GNN_EPOCHS}: loss = {avg_loss:.6f}")

    GNN_MODEL = model
    print("✓ GNN model trained")
    return model


# ======================== TUNING & SAVING HELPERS =================
def rmse_xy(x_true, y_true, x_pred, y_pred) -> float:
    """
    Euclidean RMSE in (x, y): ``sqrt(mean(dx**2 + dy**2))``.

    The squared 2-D errors are averaged *before* the square root. Taking the
    mean of the per-point distances instead would give the mean Euclidean
    error, which is not an RMSE and is never larger than this value. (A
    per-coordinate RMSE would be this value divided by sqrt(2).)

    Args:
        x_true, y_true: true coordinates (array-likes of equal length).
        x_pred, y_pred: predicted coordinates (same length).

    Returns:
        The RMSE as a Python float.
    """
    dx = np.asarray(x_pred, dtype=float) - np.asarray(x_true, dtype=float)
    dy = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)
    return float(np.sqrt(np.mean(dx ** 2 + dy ** 2)))


def make_holdout_split(df, val_frac=0.2, random_state=RANDOM_STATE):
    """
    Simple random holdout split for tuning.

    Args:
        df: frame to split (not modified).
        val_frac: fraction of rows sampled into the validation part.
        random_state: seed forwarded to ``DataFrame.sample``.

    Returns:
        (train, val), both with a fresh 0..n-1 index. Every input row ends up
        in exactly one of the two frames.
    """
    # Drop the incoming index first: DataFrame.drop is label-based, so with
    # duplicate labels it would also remove unsampled rows that share a label
    # with a sampled one. Row selection depends only on the frame length and
    # the seed, so the sampled rows are unchanged.
    pool = df.reset_index(drop=True)
    val = pool.sample(frac=val_frac, random_state=random_state)
    train = pool.drop(index=val.index)
    return train.reset_index(drop=True), val.reset_index(drop=True)


def tune_lgbm_hyperparams(train_tree: pd.DataFrame, max_train_rows: int = 300_000):
    """
    Light LGBM tuning on a random holdout of ``train_tree``.

    At most ``max_train_rows`` rows are sampled, split 80/20 with
    make_holdout_split(), and each configuration of a small fixed grid over
    (num_leaves, learning_rate, n_estimators) is fitted as a dx model and a dy
    model. Configurations are ranked by rmse_xy() between the true (x, y) and
    ``(x_last + pred_dx, y_last + pred_dy)`` on the holdout.

    NOTE(review): the holdout is a row-level random split. If ``train_tree``
    contains mirrored copies of its rows (as prepare_train builds it) or many
    frames of the same play, near-duplicates of validation rows are in the
    training part, so the reported score is likely optimistic.

    Args:
        train_tree: featured training frame with x, y, x_last, y_last, dx, dy
            and is_target columns. It is not modified.
        max_train_rows: cap on the number of rows used for tuning.

    Returns:
        The best configuration dict from the grid.
    """
    print("\n[2.5] Light hyperparameter tuning for LGBM...")

    # Subsample for speed
    if len(train_tree) > max_train_rows:
        work_df = train_tree.sample(n=max_train_rows, random_state=RANDOM_STATE).reset_index(drop=True)
        print(f"  Subsampled train_tree: {len(train_tree)} -> {len(work_df)} rows for tuning")
    else:
        work_df = train_tree.reset_index(drop=True)

    # Ensure feature / cat columns *before* splitting, so the train and
    # validation parts share one categorical dtype (identical category lists).
    for col in FEATURES:
        if col not in work_df.columns:
            work_df[col] = 0.0
    for c in CAT_FEATS:
        if c not in work_df.columns:
            work_df[c] = "unknown"
        work_df[c] = work_df[c].astype("category")

    train_df, val_df = make_holdout_split(work_df, val_frac=0.2, random_state=RANDOM_STATE)

    X_tr = train_df[FEATURES + CAT_FEATS].copy()
    X_va = val_df[FEATURES + CAT_FEATS].copy()

    y_tr_dx = train_df["dx"].values
    y_tr_dy = train_df["dy"].values

    # Per-row sample weight: 1 for ordinary rows, 1 + TARGET_WEIGHT_TREE for the target.
    w_tr = (1.0 + TARGET_WEIGHT_TREE * train_df["is_target"].values.astype(np.float32)).astype(np.float32)

    base_params = dict(
        objective="regression",
        boosting_type="gbdt",
        min_data_in_leaf=50,
        feature_fraction=0.9,
        bagging_fraction=0.9,
        bagging_freq=1,
        verbosity=-1,
        random_state=RANDOM_STATE,
    )

    # A baseline configuration plus a few nearby ones
    search_space = [
        # Baseline config
        {"num_leaves": 63,  "learning_rate": 0.05, "n_estimators": 800},
        # Previous best
        {"num_leaves": 127, "learning_rate": 0.05, "n_estimators": 1000},
        # Slightly faster LR, fewer trees
        {"num_leaves": 127, "learning_rate": 0.07, "n_estimators": 800},
        # More leaves, more capacity
        {"num_leaves": 255, "learning_rate": 0.05, "n_estimators": 1200},
        # More leaves, smaller LR (more boosting steps)
        {"num_leaves": 255, "learning_rate": 0.03, "n_estimators": 1600},
        # Fewer leaves, smaller LR
        {"num_leaves": 63,  "learning_rate": 0.03, "n_estimators": 1400},
    ]

    best_rmse = float("inf")
    best_cfg = None

    x_last_va = val_df["x_last"].values
    y_last_va = val_df["y_last"].values
    x_true_va = val_df["x"].values
    y_true_va = val_df["y"].values

    for cfg in search_space:
        params = dict(base_params)
        params.update(cfg)
        print(f"  Trying LGBM params: {cfg}")

        mdl_dx = LGBMRegressor(**params)
        mdl_dx.fit(X_tr, y_tr_dx, categorical_feature=CAT_FEATS, sample_weight=w_tr)

        mdl_dy = LGBMRegressor(**params)
        mdl_dy.fit(X_tr, y_tr_dy, categorical_feature=CAT_FEATS, sample_weight=w_tr)

        pred_dx = mdl_dx.predict(X_va)
        pred_dy = mdl_dy.predict(X_va)

        # Models predict displacements; add them to the last known position.
        x_pred = x_last_va + pred_dx
        y_pred = y_last_va + pred_dy
        rmse_val = rmse_xy(x_true_va, y_true_va, x_pred, y_pred)
        print(f"    -> val RMSE: {rmse_val:.4f}")

        if rmse_val < best_rmse:
            best_rmse = rmse_val
            best_cfg = cfg

    print(f"  Best LGBM params: {best_cfg}, val RMSE: {best_rmse:.4f}")
    return best_cfg



def tune_xgb_hyperparams(train_tree: pd.DataFrame, max_train_rows: int = 300_000):
    """
    Light XGBoost tuning on a random holdout of ``train_tree``.

    At most ``max_train_rows`` rows are sampled, split 80/20 with
    make_holdout_split(), and each configuration of a small fixed grid is
    fitted as a dx model and a dy model. Configurations are ranked by
    rmse_xy() between the true (x, y) and
    ``(x_last + pred_dx, y_last + pred_dy)`` on the holdout. Non-finite
    numeric feature values are replaced by 0.0 before fitting.

    NOTE(review): the holdout is a row-level random split. If ``train_tree``
    contains mirrored copies of its rows (as prepare_train builds it) or many
    frames of the same play, near-duplicates of validation rows are in the
    training part, so the reported score is likely optimistic.

    Args:
        train_tree: featured training frame with x, y, x_last, y_last, dx, dy
            and is_target columns. It is not modified.
        max_train_rows: cap on the number of rows used for tuning.

    Returns:
        The best configuration dict from the grid.
    """
    print("\n[2.6] Light hyperparameter tuning for XGBoost...")

    if len(train_tree) > max_train_rows:
        work_df = train_tree.sample(n=max_train_rows, random_state=RANDOM_STATE).reset_index(drop=True)
        print(f"  Subsampled train_tree: {len(train_tree)} -> {len(work_df)} rows for tuning")
    else:
        work_df = train_tree.reset_index(drop=True)

    # Fill missing columns and cast categoricals *before* splitting: casting
    # the two parts separately can give them different category lists, and
    # therefore different integer codes for the same label.
    for col in FEATURES:
        if col not in work_df.columns:
            work_df[col] = 0.0
    for c in CAT_FEATS:
        if c not in work_df.columns:
            work_df[c] = "unknown"
        work_df[c] = work_df[c].astype("category")

    train_df, val_df = make_holdout_split(work_df, val_frac=0.2, random_state=RANDOM_STATE)

    X_tr = train_df[FEATURES + CAT_FEATS].copy()
    X_va = val_df[FEATURES + CAT_FEATS].copy()

    # numeric cleaning
    X_tr[FEATURES] = X_tr[FEATURES].replace([np.inf, -np.inf], np.nan).fillna(0.0)
    X_va[FEATURES] = X_va[FEATURES].replace([np.inf, -np.inf], np.nan).fillna(0.0)

    y_tr_dx = train_df["dx"].values
    y_tr_dy = train_df["dy"].values

    # Per-row sample weight: 1 for ordinary rows, 1 + TARGET_WEIGHT_TREE for the target.
    w_tr = (1.0 + TARGET_WEIGHT_TREE * train_df["is_target"].values.astype(np.float32)).astype(np.float32)

    base_params = dict(
        objective="reg:squarederror",
        tree_method="hist",
        enable_categorical=True,
        n_jobs=-1,
        reg_lambda=1.0,
        random_state=RANDOM_STATE,
    )

    # A slightly richer but still small grid
    search_space = [
        {"max_depth": 6,  "learning_rate": 0.05, "n_estimators": 400, "subsample": 0.9, "colsample_bytree": 0.9},
        {"max_depth": 8,  "learning_rate": 0.05, "n_estimators": 600, "subsample": 0.9, "colsample_bytree": 0.9},
        {"max_depth": 8,  "learning_rate": 0.07, "n_estimators": 500, "subsample": 0.9, "colsample_bytree": 0.9},
        {"max_depth": 10, "learning_rate": 0.05, "n_estimators": 700, "subsample": 0.85, "colsample_bytree": 0.8},
        {"max_depth": 6,  "learning_rate": 0.03, "n_estimators": 800, "subsample": 0.9, "colsample_bytree": 0.9},
    ]

    best_rmse = float("inf")
    best_cfg = None

    x_last_va = val_df["x_last"].values
    y_last_va = val_df["y_last"].values
    x_true_va = val_df["x"].values
    y_true_va = val_df["y"].values

    for cfg in search_space:
        params = dict(base_params)
        params.update(cfg)
        print(f"  Trying XGB params: {cfg}")

        mdl_dx = xgb.XGBRegressor(**params)
        mdl_dx.fit(X_tr, y_tr_dx, sample_weight=w_tr)

        mdl_dy = xgb.XGBRegressor(**params)
        mdl_dy.fit(X_tr, y_tr_dy, sample_weight=w_tr)

        pred_dx = mdl_dx.predict(X_va)
        pred_dy = mdl_dy.predict(X_va)

        # Models predict displacements; add them to the last known position.
        x_pred = x_last_va + pred_dx
        y_pred = y_last_va + pred_dy
        rmse_val = rmse_xy(x_true_va, y_true_va, x_pred, y_pred)
        print(f"    -> val RMSE: {rmse_val:.4f}")

        if rmse_val < best_rmse:
            best_rmse = rmse_val
            best_cfg = cfg

    print(f"  Best XGB params: {best_cfg}, val RMSE: {best_rmse:.4f}")
    return best_cfg


def tune_ensemble_weights(train_tree: pd.DataFrame, max_rows: int = 50_000):
    """
    Grid-search blend weights (w_lgbm, w_xgb, w_gnn) on a subset of train_tree.

    Steps:
    - Predict dx/dy with every available model family; a family that is not
      available contributes zeros and is excluded from the search.
    - Print the single-model RMSE of each available family.
    - Search the weight simplex: step 0.05 for any pair of families, step 0.1
      when all three are available. With fewer than two families the best
      single model is selected.

    Args:
        train_tree: frame holding the FEATURES / CAT_FEATS columns (missing
            ones are filled with defaults on a working copy) plus "x_last",
            "y_last", "x" and "y". "game_id" and "play_id" are required when
            the GNN is scored; "frame_id" refines the graph key when present.
        max_rows: if train_tree has more rows than this, a random sample of
            this size (seeded with RANDOM_STATE) is used instead.

    Returns:
        dict with keys "lgbm", "xgb", "gnn". The same dict is stored in the
        global ENSEMBLE_WEIGHTS.

    NOTE(review): rows are sampled individually, so GNN graphs rebuilt from
    the sample are presumably partial (often a single node) - confirm that
    scoring the GNN this way is intended.
    NOTE(review): train_models calls this with the rows the tree models were
    fitted on, so the RMSEs here are in-sample for those models.
    """
    global ENSEMBLE_WEIGHTS

    print("\n[4/4] Tuning ensemble weights (LGBM / XGB / GNN)...")

    # Subsample for speed
    if len(train_tree) > max_rows:
        work_df = train_tree.sample(n=max_rows, random_state=RANDOM_STATE).reset_index(drop=True)
        print(f"  Subsampled train_tree: {len(train_tree)} -> {len(work_df)} rows for ensemble tuning")
    else:
        work_df = train_tree.reset_index(drop=True)

    # Ensure features + cats
    for col in FEATURES:
        if col not in work_df.columns:
            work_df[col] = 0.0
    for c in CAT_FEATS:
        if c not in work_df.columns:
            work_df[c] = "unknown"
        work_df[c] = work_df[c].astype("category")

    # A family takes part in the search only if its predictions can really be
    # computed. The GNN additionally needs its normalization statistics;
    # otherwise an all-zero placeholder would be scored and blended.
    has_lgb = bool(LGBM_MODELS_DX)
    has_xgb = bool(XGB_MODELS_DX)
    has_gnn = GNN_MODEL is not None and GNN_NUM_MEAN is not None and GNN_NUM_STD is not None

    n_rows = len(work_df)

    # ----- LGBM predictions -----
    if has_lgb:
        X_tree = work_df[FEATURES + CAT_FEATS].copy()
        for c in CAT_FEATS:
            X_tree[c] = X_tree[c].astype("category")
        pred_dx_lgb = np.mean([m.predict(X_tree) for m in LGBM_MODELS_DX], axis=0)
        pred_dy_lgb = np.mean([m.predict(X_tree) for m in LGBM_MODELS_DY], axis=0)
    else:
        pred_dx_lgb = np.zeros(n_rows, dtype=np.float32)
        pred_dy_lgb = np.zeros(n_rows, dtype=np.float32)

    # ----- XGB predictions -----
    if has_xgb:
        num_cols = FEATURES
        X_xgb = work_df[FEATURES + CAT_FEATS].copy()
        # Same cleaning as at XGB training time: inf/NaN -> 0.
        X_xgb[num_cols] = X_xgb[num_cols].replace([np.inf, -np.inf], np.nan).fillna(0.0)
        for c in CAT_FEATS:
            X_xgb[c] = X_xgb[c].astype("category")
        pred_dx_xgb = np.mean([m.predict(X_xgb) for m in XGB_MODELS_DX], axis=0)
        pred_dy_xgb = np.mean([m.predict(X_xgb) for m in XGB_MODELS_DY], axis=0)
    else:
        pred_dx_xgb = np.zeros(n_rows, dtype=np.float32)
        pred_dy_xgb = np.zeros(n_rows, dtype=np.float32)

    # ----- GNN predictions -----
    pred_dx_gnn = np.zeros(n_rows, dtype=np.float32)
    pred_dy_gnn = np.zeros(n_rows, dtype=np.float32)

    if has_gnn:
        df_g = work_df.copy()
        for c in CAT_FEATS:
            cats = CAT_CATEGORY_MAPS[c]
            codes = pd.Categorical(df_g[c], categories=cats).codes
            # Values outside the stored category list get the extra index len(cats).
            codes = np.where(codes < 0, len(cats), codes)
            df_g[f"{c}_cat"] = codes.astype("int64")

        for col in GNN_NUM_FEATS:
            if col not in df_g.columns:
                df_g[col] = 0.0

        X_num_df = df_g[GNN_NUM_FEATS].copy()
        X_num_df = (
            (X_num_df - GNN_NUM_MEAN) / GNN_NUM_STD
        ).replace([np.inf, -np.inf], np.nan).fillna(0.0)

        X_num = X_num_df.to_numpy(np.float32)
        X_cat = df_g[GNN_CAT_CODE_COLS].to_numpy(np.int64)

        # One graph per (game, play, frame) when frame_id exists, else per (game, play).
        gkeys = df_g["game_id"].astype(str) + "_" + df_g["play_id"].astype(str)
        if "frame_id" in df_g.columns:
            gkeys = gkeys + "_" + df_g["frame_id"].astype(str)
        g_ids, _ = pd.factorize(gkeys)

        # Group row positions per graph in one pass. A stable sort keeps the
        # original row order inside every graph, so the model sees exactly the
        # same node order as with a boolean mask, without rescanning all rows
        # once per graph.
        order = np.argsort(g_ids, kind="stable")
        cuts = np.flatnonzero(np.diff(g_ids[order])) + 1
        graph_rows = [rows for rows in np.split(order, cuts) if rows.size]

        GNN_MODEL.eval()
        with torch.no_grad():
            for rows in graph_rows:
                X_num_batch = torch.from_numpy(X_num[rows]).to(DEVICE)
                X_cat_batch = torch.from_numpy(X_cat[rows]).to(DEVICE)

                n_nodes = X_num_batch.size(0)
                adj_batch = torch.ones(n_nodes, n_nodes, device=DEVICE)  # fully connected

                pred_nn_batch = GNN_MODEL(X_num_batch, X_cat_batch, adj_batch).cpu().numpy()

                pred_dx_gnn[rows] = pred_nn_batch[:, 0]
                pred_dy_gnn[rows] = pred_nn_batch[:, 1]

        print(f"  Processed {len(graph_rows)} graphs for GNN predictions")

    x_last = work_df["x_last"].values
    y_last = work_df["y_last"].values
    x_true = work_df["x"].values
    y_true = work_df["y"].values

    def blend_rmse(w_l, w_x, w_g):
        """RMSE of the weighted blend; unavailable families are all-zero arrays."""
        dx = w_l * pred_dx_lgb + w_x * pred_dx_xgb + w_g * pred_dx_gnn
        dy = w_l * pred_dy_lgb + w_x * pred_dy_xgb + w_g * pred_dy_gnn
        return rmse_xy(x_true, y_true, x_last + dx, y_last + dy)

    # ---- First: individual model RMSEs ----
    def rmse_for(dx, dy, label):
        rm = rmse_xy(x_true, y_true, x_last + dx, y_last + dy)
        print(f"  Single-model RMSE [{label}] = {rm:.4f}")
        return rm

    rm_lgb = rm_xgb = rm_gnn = None

    if has_lgb:
        rm_lgb = rmse_for(pred_dx_lgb, pred_dy_lgb, "LGBM")
    if has_xgb:
        rm_xgb = rmse_for(pred_dx_xgb, pred_dy_xgb, "XGB")
    if has_gnn:
        rm_gnn = rmse_for(pred_dx_gnn, pred_dy_gnn, "GNN")

    # ---- Now: systematic grid search over weights ----
    best_rmse = float("inf")
    best_w = (1.0, 0.0, 0.0)

    # Candidate (w_lgbm, w_xgb, w_gnn) triples, in the order they are tried;
    # on ties the first candidate wins (strict "<" below).
    candidates = []
    pair_grid = np.linspace(0.0, 1.0, 21)  # step 0.05

    if has_lgb and has_xgb and has_gnn:
        # 3-model simplex: weights in {0.0, 0.1, ..., 1.0} with w_l + w_x + w_g = 1
        print("  Searching weights for 3-model ensemble (LGBM+XGB+GNN) with step=0.1 ...")
        step = 0.1
        grid_vals = np.arange(0.0, 1.0 + 1e-9, step)
        # if GNN is very bad, don't waste a lot of time experimenting
        gnn_is_poor = rm_gnn > rm_lgb + 0.5
        for w_l in grid_vals:
            for w_x in grid_vals:
                w_g = 1.0 - w_l - w_x
                if w_g < -1e-9:
                    continue
                # Float round-off can leave w_g marginally below zero.
                w_g = max(w_g, 0.0)
                if gnn_is_poor and w_g > 0.3:
                    continue
                candidates.append((w_l, w_x, w_g))

    elif has_lgb and has_xgb:
        # 2-model simplex: w_l in [0,1] with step 0.05, w_x = 1 - w_l
        print("  Searching weights for 2-model ensemble (LGBM+XGB) with step=0.05 ...")
        candidates = [(w, 1.0 - w, 0.0) for w in pair_grid]

    elif has_lgb and has_gnn:
        # LGBM+GNN only
        print("  Searching weights for 2-model ensemble (LGBM+GNN) with step=0.05 ...")
        candidates = [(w, 0.0, 1.0 - w) for w in pair_grid]

    elif has_xgb and has_gnn:
        # XGB+GNN only
        print("  Searching weights for 2-model ensemble (XGB+GNN) with step=0.05 ...")
        candidates = [(0.0, w, 1.0 - w) for w in pair_grid]

    else:
        # Fallback: just use LGBM if available, otherwise XGB, otherwise GNN
        print("  Not enough models for ensemble grid search; using best single model.")
        if rm_lgb is not None and (rm_xgb is None or rm_lgb <= rm_xgb) and (rm_gnn is None or rm_lgb <= rm_gnn):
            best_rmse = rm_lgb
            best_w = (1.0, 0.0, 0.0)
        elif rm_xgb is not None and (rm_gnn is None or rm_xgb <= rm_gnn):
            best_rmse = rm_xgb
            best_w = (0.0, 1.0, 0.0)
        elif rm_gnn is not None:
            best_rmse = rm_gnn
            best_w = (0.0, 0.0, 1.0)

    for weights in candidates:
        rm = blend_rmse(*weights)
        if rm < best_rmse:
            best_rmse = rm
            best_w = weights

    ENSEMBLE_WEIGHTS = {"lgbm": best_w[0], "xgb": best_w[1], "gnn": best_w[2]}
    print(
        f"  Best ensemble weights: LGBM={best_w[0]:.2f}, "
        f"XGB={best_w[1]:.2f}, GNN={best_w[2]:.2f}, RMSE={best_rmse:.4f}"
    )
    return ENSEMBLE_WEIGHTS



def save_trained_models(output_dir: str = "models"):
    """
    Write every trained artifact held in module globals to output_dir.

    The directory is created when missing. Files written:
      - lgbm_dx_<i>.pkl / lgbm_dy_<i>.pkl for each LGBM model,
      - xgb_dx_<i>.pkl / xgb_dy_<i>.pkl for each XGB model,
      - gnn_model.pth with the GNN state_dict (only when GNN_MODEL is set),
      - meta.pkl with feature lists, category maps, GNN normalization stats,
        ENSEMBLE_WEIGHTS and the target-weight constants.
    """
    out_dir = Path(output_dir)
    os.makedirs(out_dir, exist_ok=True)

    # --- LightGBM models (joblib pickles) ---
    for i, booster in enumerate(LGBM_MODELS_DX):
        joblib.dump(booster, out_dir / f"lgbm_dx_{i}.pkl")
    for i, booster in enumerate(LGBM_MODELS_DY):
        joblib.dump(booster, out_dir / f"lgbm_dy_{i}.pkl")
    print(f"  Saved {len(LGBM_MODELS_DX)} LGBM dx and {len(LGBM_MODELS_DY)} LGBM dy models to {out_dir}")

    # --- XGBoost models (joblib pickles) ---
    for i, booster in enumerate(XGB_MODELS_DX):
        joblib.dump(booster, out_dir / f"xgb_dx_{i}.pkl")
    for i, booster in enumerate(XGB_MODELS_DY):
        joblib.dump(booster, out_dir / f"xgb_dy_{i}.pkl")
    print(f"  Saved {len(XGB_MODELS_DX)} XGB dx and {len(XGB_MODELS_DY)} XGB dy models to {out_dir}")

    # --- GNN weights: only the state_dict is stored, not the module itself ---
    if GNN_MODEL is not None:
        torch.save(GNN_MODEL.state_dict(), out_dir / "gnn_model.pth")
        print("  Saved GNN model state_dict to gnn_model.pth")

    # --- Metadata; normalization stats are stored as plain dicts (or None) ---
    gnn_mean = None if GNN_NUM_MEAN is None else GNN_NUM_MEAN.to_dict()
    gnn_std = None if GNN_NUM_STD is None else GNN_NUM_STD.to_dict()
    meta = {
        "FEATURES": FEATURES,
        "CAT_FEATS": CAT_FEATS,
        "GNN_NUM_FEATS": GNN_NUM_FEATS,
        "CAT_CATEGORY_MAPS": CAT_CATEGORY_MAPS,
        "CAT_CARD_SIZES": CAT_CARD_SIZES,
        "GNN_NUM_MEAN": gnn_mean,
        "GNN_NUM_STD": gnn_std,
        "ENSEMBLE_WEIGHTS": ENSEMBLE_WEIGHTS,
        "TARGET_WEIGHT_TREE": TARGET_WEIGHT_TREE,
        "TARGET_WEIGHT_GNN": TARGET_WEIGHT_GNN,
    }
    joblib.dump(meta, out_dir / "meta.pkl")
    print("  Saved metadata to meta.pkl")


# ======================== MODEL TRAINING ==========================
def train_models(train_tree: pd.DataFrame, train_gnn: pd.DataFrame):
    """
    Run the whole training pipeline and publish the results via module globals.

    Stages, in order:
      1) light LGBM hyperparameter tuning (tune_lgbm_hyperparams),
      2) light XGBoost hyperparameter tuning (tune_xgb_hyperparams),
      3) LGBM_N_MODELS seeded LGBM dx/dy models fitted on all of train_tree,
      4) one XGBoost dx model and one dy model fitted on all of train_tree,
      5) GNN normalization stats from train_gnn, then GNN training if TRAIN_GNN,
      6) ensemble-weight tuning (tune_ensemble_weights),
      7) saving of models + metadata into "models".

    Both input frames are modified in place: missing FEATURES / CAT_FEATS
    columns are added with defaults, and the categorical columns of
    train_tree are converted to the pandas "category" dtype.
    """
    global CAT_CATEGORY_MAPS, CAT_CARD_SIZES
    global GNN_NUM_MEAN, GNN_NUM_STD
    global LGBM_MODELS_DX, LGBM_MODELS_DY
    global XGB_MODELS_DX, XGB_MODELS_DY
    global ENSEMBLE_WEIGHTS

    print("\n[3/4] Training models...")

    def _fill_missing(frame, columns, default):
        # Add every absent column, filled with a constant default.
        for name in columns:
            if name not in frame.columns:
                frame[name] = default

    # Make sure both frames expose every expected column.
    _fill_missing(train_tree, FEATURES, 0.0)
    _fill_missing(train_gnn, FEATURES, 0.0)
    _fill_missing(train_tree, CAT_FEATS, "unknown")
    _fill_missing(train_gnn, CAT_FEATS, "unknown")

    # Category dtype for the tree models; the category lists are also kept for
    # the GNN, with one extra slot reserved for unknown values.
    for cat_col in CAT_FEATS:
        train_tree[cat_col] = train_tree[cat_col].astype("category")
        levels = list(train_tree[cat_col].cat.categories)
        CAT_CATEGORY_MAPS[cat_col] = levels
        CAT_CARD_SIZES[cat_col] = len(levels) + 1

    # ---- 1) Light LGBM tuning ----
    best_cfg_lgbm = tune_lgbm_hyperparams(train_tree)
    if best_cfg_lgbm is None:
        best_cfg_lgbm = {"num_leaves": 63, "learning_rate": 0.05, "n_estimators": 1200}
        print("  WARNING: LGBM tuning did not return best_cfg; falling back to default:", best_cfg_lgbm)

    # ---- 2) Light XGB tuning ----
    best_cfg_xgb = tune_xgb_hyperparams(train_tree)
    if best_cfg_xgb is None:
        best_cfg_xgb = {
            "max_depth": 8,
            "learning_rate": 0.05,
            "n_estimators": 600,
            "subsample": 0.9,
            "colsample_bytree": 0.9,
        }
        print("  WARNING: XGB tuning did not return best_cfg; falling back to default:", best_cfg_xgb)

    # Targets and per-row weights shared by both tree families:
    # weight = 1 + TARGET_WEIGHT_TREE * is_target.
    model_cols = FEATURES + CAT_FEATS
    y_dx = train_tree["dx"].values
    y_dy = train_tree["dy"].values
    sample_weight = (1.0 + TARGET_WEIGHT_TREE * train_tree["is_target"].values).astype(np.float32)

    # ---- 3) Train LGBM ensemble on FULL train_tree ----
    X_tree = train_tree[model_cols].copy()
    for cat_col in CAT_FEATS:
        X_tree[cat_col] = X_tree[cat_col].astype("category")

    base_params_lgbm = {
        "objective": "regression",
        "boosting_type": "gbdt",
        "n_estimators": best_cfg_lgbm["n_estimators"],
        "learning_rate": best_cfg_lgbm["learning_rate"],
        "num_leaves": best_cfg_lgbm["num_leaves"],
        "min_data_in_leaf": 50,
        "feature_fraction": 0.9,
        "bagging_fraction": 0.9,
        "bagging_freq": 1,
        "verbosity": -1,
    }

    LGBM_MODELS_DX = []
    LGBM_MODELS_DY = []

    print(f"\n  Training LGBM ensemble with tuned params: {base_params_lgbm}")
    for m in range(LGBM_N_MODELS):
        # Ensemble members differ only by their random seed.
        params = {**base_params_lgbm, "random_state": RANDOM_STATE + m}

        print(f"   -> LGBM model {m + 1}/{LGBM_N_MODELS} for dx...")
        lgbm_dx = LGBMRegressor(**params)
        lgbm_dx.fit(X_tree, y_dx, categorical_feature=CAT_FEATS, sample_weight=sample_weight)
        LGBM_MODELS_DX.append(lgbm_dx)

        print(f"   -> LGBM model {m + 1}/{LGBM_N_MODELS} for dy...")
        lgbm_dy = LGBMRegressor(**params)
        lgbm_dy.fit(X_tree, y_dy, categorical_feature=CAT_FEATS, sample_weight=sample_weight)
        LGBM_MODELS_DY.append(lgbm_dy)

    print("✓ LGBM ensemble trained on full dataset")

    # ---- 4) Train XGBoost models on FULL train_tree ----
    X_xgb_full = train_tree[model_cols].copy()
    # Numeric inputs are cleaned for XGBoost: +/-inf and NaN become 0.
    cleaned_numeric = X_xgb_full[FEATURES].replace([np.inf, -np.inf], np.nan).fillna(0.0)
    X_xgb_full[FEATURES] = cleaned_numeric
    for cat_col in CAT_FEATS:
        X_xgb_full[cat_col] = X_xgb_full[cat_col].astype("category")

    # Fixed settings first, then the tuned values (tuned keys win on conflict).
    base_params_xgb = {
        "objective": "reg:squarederror",
        "tree_method": "hist",
        "enable_categorical": True,
        "n_jobs": -1,
        "reg_lambda": 1.0,
        "random_state": RANDOM_STATE,
        **best_cfg_xgb,
    }

    print(f"\n  Training XGBoost models with tuned params: {base_params_xgb}")
    XGB_MODELS_DX = []
    XGB_MODELS_DY = []

    # A single model per target coordinate.
    print("   -> XGB model for dx...")
    xgb_dx = xgb.XGBRegressor(**base_params_xgb)
    xgb_dx.fit(X_xgb_full, y_dx, sample_weight=sample_weight)
    XGB_MODELS_DX.append(xgb_dx)

    print("   -> XGB model for dy...")
    xgb_dy = xgb.XGBRegressor(**base_params_xgb)
    xgb_dy.fit(X_xgb_full, y_dy, sample_weight=sample_weight)
    XGB_MODELS_DY.append(xgb_dy)

    print("✓ XGBoost models trained on full dataset")

    # ---- 5) GNN normalization stats + training ----
    gnn_numeric = train_gnn[GNN_NUM_FEATS]
    GNN_NUM_MEAN = gnn_numeric.mean()
    # Zero or undefined spread is replaced by 1 so standardization never divides by 0/NaN.
    GNN_NUM_STD = gnn_numeric.std().replace(0, 1.0).fillna(1.0)

    if not TRAIN_GNN:
        print("  TRAIN_GNN=False -> skipping GNN training")
    else:
        train_gnn_model(train_gnn)

    # ---- 6) Tune ENSEMBLE_WEIGHTS ----
    tune_ensemble_weights(train_tree)

    # ---- 7) Save all models + metadata ----
    save_trained_models(output_dir="models")


# ======================== INFERENCE HELPERS ======================
def prepare_inference_batch(test_pd: pd.DataFrame, test_input_pd: pd.DataFrame) -> pd.DataFrame:
    """
    Build the feature rows for one inference batch.

    Pipeline:
    - reduce test_input_pd with prepare_last_obs and enrich it with
      add_target_info (presumably the last observation per
      (game_id, play_id, nfl_id) plus target-receiver info - see those helpers),
    - keep only the BASE_COLS columns that are actually present,
    - left-join them onto test_pd by (game_id, play_id, nfl_id), so every row
      of test_pd is kept,
    - run create_features in inference mode (is_train=False).

    Returns the feature frame produced by create_features.
    """
    latest = add_target_info(prepare_last_obs(test_input_pd))

    # Tolerate BASE_COLS entries that the helpers did not produce.
    available = [name for name in BASE_COLS if name in latest.columns]

    joined = pd.merge(
        test_pd,
        latest[available],
        on=["game_id", "play_id", "nfl_id"],
        how="left",
    )
    return create_features(joined, is_train=False)


# ======================== PREDICT API (EVAL) =====================

# ======================== MAIN (TRAIN + SERVER) ==================
if __name__ == "__main__":
    # Stage 1: load the public training data, build both training frames, fit everything.
    df_in, df_out = load_train(DATA_DIR)
    train_tree_df, train_gnn_df = prepare_train(df_in, df_out)
    train_models(train_tree_df, train_gnn_df)

    # Stage 2: serve predictions, which requires the kaggle_evaluation package.
    if not HAS_EVAL_SERVER or NFLInferenceServer is None:
        print(
            "\nNFLInferenceServer unavailable (kaggle_evaluation not found). "
            "Locally you can train the model and debug features, "
            "but to generate submission.csv you need to run the code in Kaggle "
            "with the dataset 'nfl-big-data-bowl-2026-prediction' connected in the Data tab."
        )
    else:
        # NOTE(review): `predict` is not defined in the visible part of this
        # file - confirm it exists before this line runs.
        inference_server = NFLInferenceServer(predict)

        if not os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
            # Outside a competition rerun: drive the server through the local gateway.
            print("\n[LOCAL] Running local gateway to generate submission.csv on public mock test...")
            inference_server.run_local_gateway((DATA_DIR,))
            print("✓ submission.csv should now be created in the working directory")
        else:
            inference_server.serve()


NFL BIG DATA BOWL 2026 - LGBM + XGB + GNN (rich features, Eval API)
DEVICE: cuda

[1/4] Loading training inputs/outputs...
 Week 01: input (285714, 23), output (32088, 6)
 Week 02: input (288586, 23), output (32180, 6)
 Week 03: input (297757, 23), output (36080, 6)
 Week 04: input (272475, 23), output (30147, 6)
 Week 05: input (254779, 23), output (29319, 6)
 Week 06: input (270676, 23), output (31162, 6)
 Week 07: input (233597, 23), output (27443, 6)
 Week 08: input (281011, 23), output (33017, 6)
 Week 09: input (252796, 23), output (28291, 6)
 Week 10: input (260372, 23), output (29008, 6)
 Week 11: input (243413, 23), output (27623, 6)
 Week 12: input (294940, 23), output (32156, 6)
 Week 13: input (233755, 23), output (29568, 6)
 Week 14: input (279972, 23), output (32873, 6)
 Week 15: input (281820, 23), output (32715, 6)
 Week 16: input (316417, 23), output (36508, 6)
 Week 17: input (277582, 23), output (33076, 6)
 Week 18: input (254917, 23), output (29682, 6)
Train inputs: