# NFL Big Data Bowl 2026 — Spatio‑Temporal GNN (ST‑GNN) Notebook (Explained)

**What you’ll see:**  
- Data loading (train/test)
- Feature building & graph construction (KNN + role‑aware edges + ball node)
- ST‑Encoder (edge‑attention, GRU) and Temporal Decoder
- Training loop with physics‑aware regularization
- Validation, Inference, and `submission.csv` generation

> Tip: If you run this on Kaggle, the notebook will automatically use the Kaggle dataset path when detected.


## 1. Environment, Imports, and Global Config

**EN:** We import all libraries and define the global configuration (`CFG`) as a dataclass. The `_resolve_paths()` function tries multiple data locations (local, env, Kaggle). AMP and multi‑GPU are supported.

**JP:** ライブラリの読み込みと `CFG`（設定）を定義します。`_resolve_paths()` がデータの場所を自動検出します。AMP/複数GPU対応です。

In [None]:

# ============================================================
# NFL Big Data Bowl 2026 - ST-GNN 改訂版 (AMP安全/全体書き換え)
# - 直近Lフレームの時空間要約 + Edge Attention (FP32計算でAMP衝突回避)
# - Ballノード注入 / 物理残差補助 / 終端重み / 速度違反ペナルティ
# - 学習→検証→推論→submission まで一気通貫
# ============================================================

import os, gc, math, warnings, random
from pathlib import Path
from glob import glob
from dataclasses import dataclass
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings("ignore")

# ------------------
# Config
# ------------------
@dataclass
class CFG:
    """Global configuration for the ST-GNN notebook.

    The class is instantiated once right below its definition and the name
    ``CFG`` is rebound to that instance, so the rest of the file reads (and in
    ``_resolve_paths`` writes) plain attributes such as ``CFG.DEVICE``.
    """
    # --- Reproducibility / device ---
    SEED: int = 42
    USE_CUDA: bool = torch.cuda.is_available()
    DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Data / graph construction ---
    TRAIN_WEEKS: tuple = tuple(range(1, 19))   # week numbers probed by load_train_input_output
    L_HIST: int = 5                            # number of most recent input frames kept per player
    K_NEIGHBORS: int = 6                       # k passed to knn_edges
    K_ROLE: int = 2                            # k passed to role_knn_edges (cross-side edges)
    BATCH_PLAYS: int = 8                       # plays per DataLoader batch
    CAP_T: int = 60                            # upper bound on the number of target frames
    FPS: float = 10.0                          # frames per second used to convert frames <-> time

    # --- Field bounds used by clip_to_field, and speed limit used by the speed penalty ---
    FIELD_X_MIN: float = 0.0
    FIELD_X_MAX: float = 120.0
    FIELD_Y_MIN: float = 0.0
    FIELD_Y_MAX: float = 53.3
    MAX_SPEED: float = 12.0

    # --- Model sizes ---
    NODE_HID: int = 128
    EDGE_HID: int = 64
    MP_LAYERS: int = 2                         # number of EdgeAttentionLayer blocks in STEncoder
    DEC_HID: int = 128
    ROLE_EMB: int = 8
    SIDE_EMB: int = 4
    BALL_EDGE_WEIGHT: float = 1.0              # forwarded to build_ball_edges

    # --- Optimisation ---
    EPOCHS: int = 8
    LR: float = 7e-4
    WEIGHT_DECAY: float = 1e-4
    GRAD_CLIP: float|None = 0.5                # falsy value disables gradient clipping
    AMP: bool = True                           # enables autocast / GradScaler
    ACCUM_STEPS: int = 4                       # gradient accumulation steps

    # --- Loss terms (see train_one_epoch) ---
    END_WEIGHTED: bool = True
    PHYS_RESID_ALPHA: float = 0.2
    SPEED_PENALTY: float = 0.01

    # Not referenced in the visible part of the file; presumably a Gaussian
    # smoothing sigma for post-processing -- TODO confirm against the rest of main().
    SMOOTH_SIGMA: float = 0.6

    # --- Paths: environment variables take precedence; _resolve_paths() fills the gaps ---
    BASE_DIR: Path = Path.cwd()
    TRAIN_DIR: Path|None = Path(os.getenv("NFL_TRAIN_DIR", "")) if os.getenv("NFL_TRAIN_DIR") else None
    TEST_DIR: Path|None  = Path(os.getenv("NFL_TEST_DIR",  "")) if os.getenv("NFL_TEST_DIR")  else None
    DATA_DIR: Path|None  = Path(os.getenv("NFL_DATA_DIR",  "")) if os.getenv("NFL_DATA_DIR")  else None
    OUT_DIR: Path = Path(os.getenv("NFL_OUT_DIR", "")) if os.getenv("NFL_OUT_DIR") else (Path.cwd() / "runs")

# Rebind the name to a single shared instance; from here on ``CFG`` is an object, not the class.
CFG = CFG()

def _resolve_paths():
    """Create the output directory and pick the train/test directories.

    Candidates are tried in order: directories already set on ``CFG`` (from
    environment variables), ``./train`` + ``./test`` under ``CFG.BASE_DIR``,
    and finally the Kaggle competition dataset path. The first existing
    candidate wins and is written back to ``CFG``.

    Raises:
        FileNotFoundError: if none of the candidate locations exists.
    """
    CFG.OUT_DIR.mkdir(parents=True, exist_ok=True)

    # 1) Explicitly configured directories.
    env_train, env_test = CFG.TRAIN_DIR, CFG.TEST_DIR
    if env_train and env_test and env_train.exists() and env_test.exists():
        print(f"[Path] Using TRAIN_DIR={CFG.TRAIN_DIR}, TEST_DIR={CFG.TEST_DIR}")
        return

    # 2) Local ./train and ./test next to the working directory.
    local_train = CFG.BASE_DIR / "train"
    local_test = CFG.BASE_DIR / "test"
    if local_train.exists() and local_test.exists():
        CFG.TRAIN_DIR = local_train
        CFG.TEST_DIR = local_test
        print("[Path] Using ./train and ./test")
        return

    # 3) Kaggle dataset mount; test files live at the dataset root.
    kaggle_root = Path("/kaggle/input/nfl-big-data-bowl-2026-prediction")
    if kaggle_root.exists():
        CFG.DATA_DIR = kaggle_root
        CFG.TRAIN_DIR = kaggle_root / "train"
        CFG.TEST_DIR = kaggle_root
        print(f"[Path] Using Kaggle dataset under {kaggle_root}")
        return

    raise FileNotFoundError("データが見つかりません。env or ./train ./test or Kaggle path を用意してください。")

# Resolve data/output directories at cell execution time (raises FileNotFoundError if no data is found).
_resolve_paths()
# Seed the Python, NumPy and PyTorch RNGs from CFG.SEED.
random.seed(CFG.SEED); np.random.seed(CFG.SEED); torch.manual_seed(CFG.SEED)
# Also seed every CUDA device when CUDA is available.
if CFG.USE_CUDA: torch.cuda.manual_seed_all(CFG.SEED)


## 2. Small Utilities and Physics Helpers

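**EN:** Utility helpers for unit conversion, field boundary clipping, and a simple temporal positional encoding. `maybe_flip_direction` is a hook for optional field (left/right) normalization; it is currently a no-op placeholder that returns its input unchanged.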

**JP:** 単位変換・フィールド境界クリップ・時間PEなど、補助的な関数群です。

In [None]:

# ------------------
# Utils / Physics
# ------------------
def height_ftin_to_in(h, default=70):
    """Convert a ``"feet-inches"`` height string (e.g. ``"6-2"``) to inches.

    Args:
        h: Height value; anything that is not a string containing ``"-"``
            is treated as missing.
        default: Value returned for missing or malformed input.

    Returns:
        Height in inches as an int, or ``default`` when ``h`` cannot be parsed.
    """
    if not isinstance(h, str) or "-" not in h:
        return default
    try:
        ft, inch = h.split("-")
        return int(ft) * 12 + int(inch)
    except ValueError:
        # Covers both a wrong number of "-"-separated parts (unpacking) and
        # non-numeric parts (int()); unlike a bare ``except`` this does not
        # swallow KeyboardInterrupt / SystemExit.
        return default

def maybe_flip_direction(df):
    """Hook for left/right play-direction normalization.

    Currently a placeholder: the input is returned unchanged (same object).
    Implement the flip here if direction normalization is needed.
    """
    return df

def clip_to_field(xy):
    """Clamp coordinates to the field rectangle configured in ``CFG``.

    Args:
        xy: Tensor whose last dimension holds ``(x, y)`` at indices 0 and 1.

    Returns:
        Tensor with last dimension 2 containing the clamped x and y.
    """
    x_clamped = xy[..., 0].clamp(CFG.FIELD_X_MIN, CFG.FIELD_X_MAX)
    y_clamped = xy[..., 1].clamp(CFG.FIELD_Y_MIN, CFG.FIELD_Y_MAX)
    return torch.stack((x_clamped, y_clamped), dim=-1)

def time_pos_encoding(t_norm):
    """Sinusoidal encoding of a normalized time tensor.

    A trailing axis is appended to ``t_norm`` and filled with four features:
    ``sin(2*pi*t), cos(2*pi*t), sin(4*pi*t), cos(4*pi*t)``.

    Returns:
        Tensor of shape ``t_norm.shape + (4,)``.
    """
    t = t_norm.unsqueeze(-1)
    features = []
    for multiplier in (2, 4):
        angle = multiplier * math.pi * t
        features.append(torch.sin(angle))
        features.append(torch.cos(angle))
    return torch.cat(features, dim=-1)


## 3. CSV Loading (Train/Test)

**EN:** We search for weekly `input_*.csv`/`output_*.csv` files; fall back to globs if the week‑pattern is not present. For test, we look for `test_input.csv`, `test.csv`, and `sample_submission.csv`.

**JP:** 週別の `input_*.csv` / `output_*.csv` を優先し、なければグロブで探索。テストは `test_input.csv`・`test.csv`・`sample_submission.csv` を読み込みます。

In [None]:

# ------------------
# Loaders
# ------------------
def _week_candidates(train_dir, w):
    return [
        train_dir / f"input_2023_w{w:02d}.csv",
        train_dir / f"input_w{w:02d}.csv",
        train_dir / f"input_week{w:02d}.csv",
    ], [
        train_dir / f"output_2023_w{w:02d}.csv",
        train_dir / f"output_w{w:02d}.csv",
        train_dir / f"output_week{w:02d}.csv",
    ]

def load_train_input_output():
    """Read and concatenate all training input and output CSVs.

    Weekly file names from ``_week_candidates`` are probed first for every week
    in ``CFG.TRAIN_WEEKS``. If no weekly file of a kind exists, that kind falls
    back to a sorted ``input_*.csv`` / ``output_*.csv`` glob in ``CFG.TRAIN_DIR``.

    Returns:
        ``(train_input, train_output)`` as two DataFrames with fresh indices.

    Raises:
        FileNotFoundError: if no input files or no output files are found.
    """
    train_dir = CFG.TRAIN_DIR
    input_paths, output_paths = [], []
    for week in CFG.TRAIN_WEEKS:
        in_cands, out_cands = _week_candidates(train_dir, week)
        input_paths.extend(p for p in in_cands if p.exists())
        output_paths.extend(p for p in out_cands if p.exists())

    # Fall back to plain globs when the weekly naming scheme is not present.
    if not input_paths:
        input_paths = [Path(p) for p in sorted(glob(str(train_dir / "input_*.csv")))]
    if not output_paths:
        output_paths = [Path(p) for p in sorted(glob(str(train_dir / "output_*.csv")))]
    if not (input_paths and output_paths):
        raise FileNotFoundError("[Load] input/output CSV not found")

    print(f"[Load] input={len(input_paths)} files, output={len(output_paths)} files")
    input_frames = [pd.read_csv(p) for p in tqdm(input_paths, desc="read input")]
    tin = pd.concat(input_frames, ignore_index=True)
    output_frames = [pd.read_csv(p) for p in tqdm(output_paths, desc="read output")]
    tout = pd.concat(output_frames, ignore_index=True)
    return tin, tout

def load_test_input_and_templates():
    """Load the test input, the test template and the sample submission.

    Each file is looked up first in ``CFG.TEST_DIR`` and then in
    ``CFG.BASE_DIR``; the first existing path is read.

    Returns:
        ``(test_input, test_template, sample_submission)`` DataFrames.

    Raises:
        FileNotFoundError: if any of the three files is missing in both places.
    """
    test_dir = CFG.TEST_DIR

    def _read_first_existing(candidates):
        # Read the first candidate that exists; None when there is none.
        hit = next((c for c in candidates if c.exists()), None)
        return None if hit is None else pd.read_csv(hit)

    file_names = ("test_input.csv", "test.csv", "sample_submission.csv")
    test_in, test_tpl, sub = (
        _read_first_existing([test_dir / name, CFG.BASE_DIR / name]) for name in file_names
    )
    if any(frame is None for frame in (test_in, test_tpl, sub)):
        raise FileNotFoundError("[Load] test_input/test.csv/sample_submission.csv not found")
    return test_in, test_tpl, sub


## 4. Graph Construction (KNN, Role‑aware Edges, Ball Edges)

**EN:** We form:  
- **KNN** edges among all players  
- **Role/side‑aware** cross‑team edges (top‑k per player)  
- **Ball node edges** to/from all players

Edge features include relative positions, velocity diffs, distances, and simple team indicators.

**JP:** KNNエッジ、役割/サイド跨ぎのエッジ、ボールとの全双方向エッジを構築します。エッジ特徴量は相対位置/速度/距離/チーム指標などを含みます。

In [None]:

# ------------------
# Graph utils
# ------------------
def knn_edges(xy, k):
    """Build directed k-nearest-neighbour edges between points.

    Args:
        xy: Float tensor of shape ``(N, D)`` with point coordinates.
        k: Requested number of neighbours per point; capped at ``N - 1``.

    Returns:
        ``(edge_index, dist)`` where ``edge_index`` is a ``(2, N * k_eff)`` long
        tensor whose first row is the query point and second row its neighbour,
        and ``dist`` holds the matching Euclidean distances. Both are empty
        when ``N <= 1`` or ``k <= 0``.
    """
    N = xy.size(0)
    k_eff = min(k, N - 1)
    if N <= 1 or k_eff <= 0:
        dev = xy.device
        return (torch.empty(2, 0, dtype=torch.long, device=dev),
                torch.empty(0, dtype=xy.dtype, device=dev))

    d = torch.cdist(xy, xy, p=2)
    # Exclude self-matches explicitly. Taking top-(k+1) and dropping the first
    # column is only correct if each point is strictly closest to itself, which
    # fails for coincident points (ties) or an inexact cdist diagonal and would
    # yield self-loops while discarding a real neighbour.
    masked = d.clone()
    masked.fill_diagonal_(float("inf"))
    dist, idx = torch.topk(masked, k_eff, dim=1, largest=False)

    rows = torch.arange(N, device=xy.device).unsqueeze(1).expand(N, k_eff)
    eidx = torch.stack([rows.reshape(-1), idx.reshape(-1)], dim=0)
    return eidx, dist.reshape(-1)

def role_knn_edges(xy, roles, sides, k_role):
    """Build cross-side nearest-neighbour edges.

    For side ids 0 and 1 in turn, every node on that side is connected to its
    ``k_role`` closest nodes whose side id differs.

    Args:
        xy: ``(N, D)`` float tensor of positions.
        roles: Accepted for interface compatibility; not used by the function.
        sides: ``(N,)`` tensor of side ids.
        k_role: Neighbours per node, capped at the number of opposing nodes.

    Returns:
        ``(edge_index, dist)``: a ``(2, E)`` long tensor (source, target) and
        the ``(E,)`` Euclidean distances; both empty if no edge can be formed.
    """
    n_nodes = xy.size(0)
    device, dtype = xy.device, xy.dtype

    def _no_edges():
        return (torch.empty(2, 0, dtype=torch.long, device=device),
                torch.empty(0, dtype=dtype, device=device))

    if n_nodes <= 1:
        return _no_edges()

    sources, targets, dists = [], [], []
    for side in (0, 1):
        own = torch.nonzero(sides == side, as_tuple=False).squeeze(-1)
        opp = torch.nonzero(sides != side, as_tuple=False).squeeze(-1)
        if own.numel() == 0 or opp.numel() == 0:
            continue
        pair_d = torch.cdist(xy[own], xy[opp], p=2)
        k_eff = min(k_role, opp.size(0))
        nearest = torch.topk(-pair_d, k_eff, dim=1).indices
        sources.append(own.unsqueeze(1).expand_as(nearest).reshape(-1))
        targets.append(opp[nearest].reshape(-1))
        dists.append(torch.gather(pair_d, 1, nearest).reshape(-1))

    if not sources:
        return _no_edges()
    edge_index = torch.stack([torch.cat(sources), torch.cat(targets)], dim=0)
    return edge_index, torch.cat(dists)

def build_ball_edges(N, ball_idx, weight=1.0, device="cpu", dtype=torch.float32):
    """Connect the ball node to every other node in both directions.

    Args:
        N: Total number of nodes (ball included).
        ball_idx: Index of the ball node.
        weight: Constant weight assigned to every edge.
        device, dtype: Placement of the returned tensors (``dtype`` applies to
            the weights only).

    Returns:
        ``(edge_index, edge_w)``: a ``(2, 2*(N-1))`` index tensor listing all
        ball->node edges followed by all node->ball edges, and the matching
        constant weight vector.
    """
    node_ids = torch.arange(N, device=device)
    others = node_ids[node_ids != ball_idx]
    ball = torch.full_like(others, ball_idx)

    ball_to_others = torch.stack([ball, others], dim=0)
    others_to_ball = torch.stack([others, ball], dim=0)
    edge_index = torch.cat([ball_to_others, others_to_ball], dim=1)

    n_edges = edge_index.size(1)
    edge_w = torch.full((n_edges,), float(weight), device=device, dtype=dtype)
    return edge_index, edge_w


## 5. Dataset and Collate

**EN:** `STPlayDataset` collects the last **L** frames per player and computes static per‑player features at the final observed frame. When training labels are present, it builds the (N × T × 2) future displacement tensor and mask.

**JP:** 最終観測フレームに対する特徴を作成し、学習時は将来位置差分 `(N, T, 2)` とマスクを組み立てます。

In [None]:

# ------------------
# Dataset
# ------------------
class STPlayDataset(Dataset):
    """Per-play dataset: one item holds every player of one (game_id, play_id).

    Each item contains the last ``L`` input frames per player, static features
    taken from the per-player last observation, and -- when ``for_train`` --
    the future displacement targets relative to the last observed position
    together with a validity mask.

    Args:
        train_input: Input tracking DataFrame (one row per player and frame).
        train_output: Output DataFrame with future ``x``/``y`` per player and
            frame; only used when ``for_train`` is true.
        L: Number of most recent frames kept per player.
        for_train: Whether to build targets.
    """

    _PLAY_COLS = ["game_id", "play_id"]

    def __init__(self, train_input, train_output, L=5, for_train=True):
        self.L = L
        self.for_train = for_train
        df = maybe_flip_direction(train_input.copy())
        df["height_inches"] = df["player_height"].apply(height_ftin_to_in)
        self.keys = df.groupby(["game_id","play_id"]).size().reset_index()[["game_id","play_id"]].values.tolist()
        self.tin = df.sort_values(["game_id","play_id","nfl_id","frame_id"]).reset_index(drop=True)
        self.tout = train_output.sort_values(["game_id","play_id","nfl_id","frame_id"]).reset_index(drop=True) if for_train else None
        self.role2id = {k:i for i,k in enumerate(sorted(self.tin["player_role"].fillna("Unknown").unique()))}
        self.side2id = {k:i for i,k in enumerate(sorted(self.tin["player_side"].fillna("Unknown").unique()))}
        self.role_dim = len(self.role2id); self.side_dim = len(self.side2id)

        # Positional row indices per play, computed once. Without this every
        # __getitem__ call scans the complete input and output frames with a
        # boolean mask, which makes an epoch O(num_plays * num_rows).
        self._tin_rows = self.tin.groupby(self._PLAY_COLS, sort=False).indices
        self._tout_rows = self.tout.groupby(self._PLAY_COLS, sort=False).indices if for_train else {}

    def __len__(self): return len(self.keys)

    @staticmethod
    def _select_play(frame, rows_by_play, gid, pid):
        """Return the rows of ``frame`` for one play, in their sorted order."""
        rows = rows_by_play.get((gid, pid))
        if rows is None:
            # Key not present in the index (e.g. a play without output rows):
            # fall back to the mask-based selection, which yields the same rows.
            return frame[(frame.game_id==gid)&(frame.play_id==pid)]
        return frame.iloc[rows]

    def _stack_L(self, g):
        """Stack the last ``L`` frames of one player into a ``(6, L)`` float32 array.

        Rows are ``x, y, vx, vy, s, a`` with velocity components derived from
        ``s`` and ``dir`` (degrees). Missing columns are treated as zeros and
        shorter histories are left-padded with zeros.
        """
        g = g.sort_values("frame_id")
        def get(col): return g[col].values if col in g.columns else np.zeros(len(g), np.float32)
        x,y = get("x"),get("y")
        s = get("s"); a=get("a")
        dir_deg = get("dir")
        vx,vy = s*np.cos(np.deg2rad(dir_deg)), s*np.sin(np.deg2rad(dir_deg))
        arr = np.stack([x[-self.L:], y[-self.L:], vx[-self.L:], vy[-self.L:], s[-self.L:], a[-self.L:]], 0)
        if arr.shape[1] < self.L:
            pad = np.zeros((arr.shape[0], self.L-arr.shape[1]), np.float32)
            arr = np.concatenate([pad, arr], 1)
        return arr.astype(np.float32)

    def __getitem__(self, idx):
        gid, pid = self.keys[idx]
        tin_gp = self._select_play(self.tin, self._tin_rows, gid, pid)
        # NOTE: GroupBy.last() takes the last non-null value per column.
        last = tin_gp.groupby("nfl_id", as_index=False).last()
        nfl_ids = last["nfl_id"].values; N=len(nfl_ids)

        # Split the play by player once instead of masking it once per player.
        frames_by_player = {nid: g for nid, g in tin_gp.groupby("nfl_id")}
        seq = [self._stack_L(frames_by_player[nid]) for nid in nfl_ids]
        seq = np.stack(seq, 0)  # (N,6,L)

        final_xy = last[["x","y"]].values.astype(np.float32)
        s = last["s"].fillna(0).values.astype(np.float32) if "s" in last else np.zeros(N,np.float32)
        a = last["a"].fillna(0).values.astype(np.float32) if "a" in last else np.zeros(N,np.float32)
        dir_deg = last["dir"].fillna(0).values.astype(np.float32) if "dir" in last else np.zeros(N,np.float32)
        o_deg   = last["o"].fillna(0).values.astype(np.float32) if "o" in last else np.zeros(N,np.float32)
        ds, dc = np.sin(np.deg2rad(dir_deg)), np.cos(np.deg2rad(dir_deg))
        os_, oc= np.sin(np.deg2rad(o_deg)),   np.cos(np.deg2rad(o_deg))
        height = last["height_inches"].values.astype(np.float32)
        weight = last["player_weight"].fillna(200).values.astype(np.float32)
        bmi = (weight/(np.clip(height,1,300)**2))*703.0
        yardln = last["absolute_yardline_number"].fillna(50).values.astype(np.float32) if "absolute_yardline_number" in last else np.full(N,50.0,np.float32)
        # Without ball landing columns, fall back to the mean player position.
        ball_xy = last[["ball_land_x","ball_land_y"]].fillna(0.0).values.astype(np.float32) if {"ball_land_x","ball_land_y"}.issubset(last.columns) else np.tile(final_xy.mean(0,keepdims=True),(N,1)).astype(np.float32)

        role_id = np.array([self.role2id.get(v if isinstance(v,str) else "Unknown",0) for v in last.get("player_role", pd.Series(["Unknown"]*N))], np.int64)
        side_id = np.array([self.side2id.get(v if isinstance(v,str) else "Unknown",0) for v in last.get("player_side", pd.Series(["Unknown"]*N))], np.int64)

        out={}
        if self.for_train:
            tout = self._select_play(self.tout, self._tout_rows, gid, pid).sort_values(["nfl_id","frame_id"])
            targets_by_player = {nid: g for nid, g in tout.groupby("nfl_id")}
            per=[]; Tmax=0
            for i, nid in enumerate(nfl_ids):
                g = targets_by_player.get(nid)
                # Players without output rows get an empty target (mask stays 0).
                if g is None: per.append(np.zeros((0,2),np.float32)); continue
                arr = g[["x","y"]].values.astype(np.float32)
                # Displacement relative to this player's last observed position;
                # nfl_ids are unique, so row i of final_xy belongs to nid.
                dxy = arr - final_xy[i][None,:]
                per.append(dxy); Tmax=max(Tmax,dxy.shape[0])
            if CFG.CAP_T and Tmax>CFG.CAP_T: Tmax=CFG.CAP_T
            tgt=np.zeros((N,Tmax,2),np.float32); msk=np.zeros((N,Tmax),np.float32)
            for i,d in enumerate(per):
                t=min(len(d),Tmax)
                if t>0: tgt[i,:t,:]=d[:t]; msk[i,:t]=1.0
            out["target_dxy"]=torch.from_numpy(tgt); out["tgt_mask"]=torch.from_numpy(msk); out["T_max"]=Tmax

        batch = {
            "game_id": int(gid), "play_id": int(pid), "nfl_ids": nfl_ids,
            "seq": torch.from_numpy(seq),
            "final_xy": torch.from_numpy(final_xy),
            "s": torch.from_numpy(s), "a": torch.from_numpy(a),
            "dir_sin": torch.from_numpy(ds), "dir_cos": torch.from_numpy(dc),
            "o_sin": torch.from_numpy(os_), "o_cos": torch.from_numpy(oc),
            "height": torch.from_numpy(height), "weight": torch.from_numpy(weight), "bmi": torch.from_numpy(bmi),
            "yardln": torch.from_numpy(yardln),
            "ball_xy": torch.from_numpy(ball_xy),
            "role_id": torch.from_numpy(role_id), "side_id": torch.from_numpy(side_id),
        }
        batch.update(out); return batch

def st_collate_fn(blist):
    """Collate per-play items into one flat batch of players.

    Per-player tensors of all plays are concatenated along dim 0. ``metas``
    keeps ``(game_id, play_id, nfl_ids)`` per play. When targets are present
    they are right-padded with zeros along the time axis to the longest
    horizon in the batch, and ``T_max`` is set to that horizon.
    """
    node_keys = ["seq","final_xy","s","a","dir_sin","dir_cos","o_sin","o_cos","height","weight","bmi","yardln","ball_xy","role_id","side_id"]
    batch = {key: torch.cat([item[key] for item in blist], 0) for key in node_keys}
    batch["metas"] = [(item["game_id"], item["play_id"], item["nfl_ids"]) for item in blist]

    if "target_dxy" not in blist[0]:
        return batch

    horizon = max(item["T_max"] for item in blist)
    targets, masks = [], []
    for item in blist:
        tgt, msk = item["target_dxy"], item["tgt_mask"]
        if tgt.shape[1] < horizon:
            # F.pad pads from the last dim backwards: (xy: 0,0) then (time: 0, missing).
            tgt = F.pad(tgt, (0, 0, 0, horizon - tgt.shape[1]))
            msk = F.pad(msk, (0, horizon - msk.shape[1]))
        targets.append(tgt)
        masks.append(msk)

    batch["target_dxy"] = torch.cat(targets, 0)
    batch["tgt_mask"] = torch.cat(masks, 0)
    batch["T_max"] = horizon
    return batch


## 6. Model: Edge‑Attention Encoder and Temporal Decoder

**EN:** The encoder:  
- Embeds role/side  
- Applies **EdgeAttentionLayer** (GATv2‑like) with FP32 inside to avoid AMP instabilities  
- Aggregates **L** steps via a GRU  

The decoder:  
- Conditions on encoder’s last state and predicts future `(x,y)` displacements

**JP:** エンコーダは役割/サイド埋め込み→エッジ注意→GRUで時系列要約。デコーダは将来の `(x,y)` 差分を生成します。

In [None]:

# ------------------
# Model
# ------------------
class EdgeAttentionLayer(nn.Module):
    """GATv2-style edge-attention message passing.

    The whole forward computation runs in FP32 with autocast disabled to avoid
    dtype conflicts under AMP; the result is cast back to the input dtype.

    Args:
        node_dim: Size of the node hidden state (input and output).
        edge_in: Size of the raw edge feature vector.
        edge_hid: Size of the encoded edge feature vector.
    """
    def __init__(self, node_dim, edge_in, edge_hid):
        super().__init__()
        # Encodes raw edge features.
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_in, edge_hid), nn.ReLU(),
            nn.Linear(edge_hid, edge_hid), nn.ReLU()
        )
        # Builds a message from [h_src, h_dst, encoded_edge].
        self.msg_mlp = nn.Sequential(
            nn.Linear(node_dim*2 + edge_hid, node_dim), nn.ReLU(),
            nn.Linear(node_dim, node_dim)
        )
        # Scalar attention logit per message.
        self.attn = nn.Linear(node_dim, 1)

    def forward(self, h, edge_index, edge_feat):
        """Aggregate attention-weighted messages onto destination nodes.

        Args:
            h: Node states; indexed by the node ids in ``edge_index``.
            edge_index: Pair ``(src, dst)`` of index tensors (row 0 / row 1).
            edge_feat: One feature row per edge.

        Returns:
            Tensor with the same number of rows and dtype as ``h``; nodes that
            receive no edge get a zero row.
        """
        src, dst = edge_index
        # --- FP32 only from here on ---
        with torch.cuda.amp.autocast(False):
            h32 = h.float()
            e32 = edge_feat.float()
            e = self.edge_mlp(e32)                          # (E, Eh)
            pair = torch.cat([h32[src], h32[dst], e], -1)   # (E, 2H+Eh)
            m_ij = self.msg_mlp(pair)                       # (E, H)
            score = self.attn(torch.tanh(m_ij))             # (E, 1)

            # Softmax over the incoming edges of each destination node,
            # stabilised by subtracting the per-destination maximum logit.
            N = h32.size(0)
            max_per_dst = torch.full((N,1), -float('inf'), dtype=torch.float32, device=h.device)
            max_per_dst.index_reduce_(0, dst, score, reduce="amax")

            norm = torch.exp(score - max_per_dst[dst])
            sum_per_dst = torch.zeros(N,1, dtype=torch.float32, device=h.device)
            sum_per_dst.index_add_(0, dst, norm)
            alpha = norm / (sum_per_dst[dst] + 1e-9)        # (E,1)

            # Weighted sum of messages per destination node.
            agg32 = torch.zeros(N, h32.size(1), dtype=torch.float32, device=h.device)
            agg32.index_add_(0, dst, alpha * m_ij)
        # Cast the result back to the caller's dtype.
        return agg32.to(h.dtype)

class STEncoder(nn.Module):
    """Spatio-temporal encoder: per-step message passing followed by a GRU.

    For each of the ``L`` time slices the node features are concatenated with
    role/side embeddings, projected by an MLP and refined by
    ``CFG.MP_LAYERS`` residual ``EdgeAttentionLayer`` blocks. The resulting
    ``(N, L, H)`` sequence is summarised by a single-layer GRU.
    """

    def __init__(self, node_in, role_dim, side_dim, hid, edge_in, edge_hid, L):
        super().__init__()
        self.emb_role = nn.Embedding(role_dim, CFG.ROLE_EMB)
        self.emb_side = nn.Embedding(side_dim, CFG.SIDE_EMB)
        fused_in = node_in + CFG.ROLE_EMB + CFG.SIDE_EMB
        self.pre = nn.Sequential(
            nn.Linear(fused_in, hid),
            nn.ReLU(),
            nn.Linear(hid, hid),
            nn.ReLU(),
        )
        attention_blocks = [EdgeAttentionLayer(hid, edge_in, edge_hid) for _ in range(CFG.MP_LAYERS)]
        self.layers = nn.ModuleList(attention_blocks)
        self.gru = nn.GRU(input_size=hid, hidden_size=hid, num_layers=1, batch_first=True)
        self.L = L

    def forward(self, node_feat_t_list, role_id, side_id, edges_t_list, efeat_t_list):
        """Encode ``L`` graph snapshots.

        Returns:
            ``(last_state, all_states)`` with shapes ``(N, H)`` and ``(N, L, H)``.
        """
        role_vec = self.emb_role(role_id)
        side_vec = self.emb_side(side_id)

        step_states = []
        for t in range(self.L):
            fused = torch.cat([node_feat_t_list[t], role_vec, side_vec], dim=-1)
            state = self.pre(fused)
            edges, edge_feat = edges_t_list[t], efeat_t_list[t]
            for block in self.layers:
                # Residual connection around each message-passing block.
                state = state + block(state, edges, edge_feat)
            step_states.append(state)

        sequence = torch.stack(step_states, dim=1)   # (N, L, H)
        gru_out, _ = self.gru(sequence)              # (N, L, H)
        return gru_out[:, -1, :], gru_out

class TemporalDecoder(nn.Module):
    """GRU decoder that maps a conditioning sequence to 2-D outputs per step.

    The GRU's initial hidden state is a linear projection of the encoder's
    last state; a two-layer MLP head turns every GRU output into ``(dx, dy)``.
    """

    def __init__(self, hid, dec_hid):
        super().__init__()
        self.h0 = nn.Linear(hid, dec_hid)
        self.gru = nn.GRU(input_size=hid, hidden_size=dec_hid, num_layers=1, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(dec_hid, dec_hid),
            nn.ReLU(),
            nn.Linear(dec_hid, 2),
        )

    def forward(self, cond_seq, H_last):
        """Decode ``cond_seq`` (N, T, hid) given ``H_last`` (N, hid) -> (N, T, 2)."""
        initial_hidden = self.h0(H_last).unsqueeze(0)   # (1, N, dec_hid)
        decoded, _ = self.gru(cond_seq, initial_hidden)
        return self.head(decoded)

class STGNNModel(nn.Module):
    """ST-GNN that predicts future per-player displacements.

    Node features are built from the last observed frame plus a small 1-D CNN
    summary of the recent motion; a graph with KNN, cross-side and ball edges
    is encoded by ``STEncoder`` and decoded by ``TemporalDecoder``.

    The encoder and decoder are created lazily on the first forward pass
    because their input sizes depend on the data, so the optimizer must be
    built after one warm-up forward call.

    Graphs are built per play: a collated batch concatenates the players of
    several plays, and linking them through one KNN graph / one shared ball
    node would connect players that never shared a field. Play boundaries are
    taken from ``batch["metas"]`` when available.
    """

    def __init__(self):
        super().__init__()
        # Coarse motion embedding of the last L frames: (6, L) -> 16.
        self.seq_conv = nn.Sequential(
            nn.Conv1d(6, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )
        # Edge feature width; must match what build_edges_one_t emits.
        self.edge_in = 10
        self.encoder=None; self.decoder=None  # lazily initialised in forward()

    def build_node_feat(self, batch):
        """Assemble the per-player node feature matrix ``(N, Din)``."""
        seq = batch["seq"].to(CFG.DEVICE)                 # (N,6,L)
        emb = self.seq_conv(seq).squeeze(-1)              # (N,16)
        final_xy = batch["final_xy"].to(CFG.DEVICE)
        s = batch["s"].to(CFG.DEVICE).unsqueeze(-1)
        a = batch["a"].to(CFG.DEVICE).unsqueeze(-1)
        ds = batch["dir_sin"].to(CFG.DEVICE).unsqueeze(-1)
        dc = batch["dir_cos"].to(CFG.DEVICE).unsqueeze(-1)
        os_ = batch["o_sin"].to(CFG.DEVICE).unsqueeze(-1)
        oc = batch["o_cos"].to(CFG.DEVICE).unsqueeze(-1)
        height = batch["height"].to(CFG.DEVICE).unsqueeze(-1)
        weight = batch["weight"].to(CFG.DEVICE).unsqueeze(-1)
        bmi    = batch["bmi"].to(CFG.DEVICE).unsqueeze(-1)
        yardln = batch["yardln"].to(CFG.DEVICE).unsqueeze(-1)
        ball_xy= batch["ball_xy"].to(CFG.DEVICE)

        # Offset, distance and unit direction from each player to the ball position.
        ball_vec = ball_xy - final_xy
        ball_dist = ball_vec.norm(dim=-1, keepdim=True).clamp(min=1e-6)
        ball_dir = ball_vec / ball_dist
        bdx, bdy = ball_vec[...,0:1], ball_vec[...,1:2]
        bds, bdc = ball_dir[...,0:1], ball_dir[...,1:2]

        # Time positional encoding of the final time step (t = 1.0), identical
        # for every node. time_pos_encoding emits 4 features -> (N, 4).
        pe = time_pos_encoding(torch.tensor([1.0], device=final_xy.device)).repeat(final_xy.size(0),1)

        node = torch.cat([
            final_xy, s, a, ds, dc, os_, oc, height, weight, bmi, yardln,
            bdx, bdy, ball_dist, bds, bdc, emb, pe
        ], dim=-1)
        return node  # (N, Din)

    def build_edges_one_t(self, final_xy, s, dir_sin, dir_cos, side_id, role_id):
        """Build KNN + cross-side edges and their 10-dim features for one graph.

        Returns:
            ``(edge_index, edge_feat)`` on ``CFG.DEVICE`` with shapes ``(2, E)``
            and ``(E, 10)``.
        """
        e_knn, d_knn = knn_edges(final_xy, CFG.K_NEIGHBORS)
        e_role, d_role = role_knn_edges(final_xy, role_id, side_id, CFG.K_ROLE)
        if e_role.numel()>0:
            eidx = torch.cat([e_knn, e_role], dim=1); dist = torch.cat([d_knn, d_role], dim=0)
        else:
            eidx, dist = e_knn, d_knn

        vx = s*dir_cos; vy = s*dir_sin
        src, dst = eidx
        dx = (final_xy[src,0] - final_xy[dst,0]).unsqueeze(-1)
        dy = (final_xy[src,1] - final_xy[dst,1]).unsqueeze(-1)
        dvx = (vx[src] - vx[dst]).unsqueeze(-1)
        dvy = (vy[src] - vy[dst]).unsqueeze(-1)
        distv = dist.unsqueeze(-1).clamp(min=1e-6)

        # Cosine between the relative position and the destination's velocity.
        rel = torch.cat([dx,dy], -1)
        rel_n = rel / (rel.norm(dim=-1, keepdim=True).clamp(min=1e-6))
        vdst = torch.cat([vx[dst].unsqueeze(-1), vy[dst].unsqueeze(-1)], -1)
        vdst_n = vdst / (vdst.norm(dim=-1, keepdim=True).clamp(min=1e-6))
        cos_pos = (rel_n * vdst_n).sum(-1, keepdim=True)

        # Cosine between source and destination velocities.
        vsrc = torch.cat([vx[src].unsqueeze(-1), vy[src].unsqueeze(-1)], -1)
        vsrc_n = vsrc / (vsrc.norm(dim=-1, keepdim=True).clamp(min=1e-6))
        cos_vel = (vsrc_n * vdst_n).sum(-1, keepdim=True)

        # Team-relationship indicators derived from the side ids.
        same_team = (side_id[src]==side_id[dst]).float().unsqueeze(-1)
        off_to_def = ((side_id[src]==1)&(side_id[dst]==0)).float().unsqueeze(-1)
        def_to_off = ((side_id[src]==0)&(side_id[dst]==1)).float().unsqueeze(-1)

        efeat = torch.cat([dx,dy,dvx,dvy,distv,cos_pos,cos_vel,same_team,off_to_def,def_to_off], -1)
        return eidx.to(CFG.DEVICE), efeat.to(CFG.DEVICE)

    def add_ball_node(self, node, final_xy):
        """Append a single ball node placed at the mean player position.

        Kept for callers that handle one play at a time; ``forward`` adds one
        ball node per play itself.
        """
        ball_xy = final_xy.mean(dim=0, keepdim=True)
        ball_feat = torch.zeros_like(node[:1])
        ball_feat[:, :2] = ball_xy
        node2 = torch.cat([node, ball_feat], dim=0)
        ball_idx = node2.size(0)-1
        return node2, ball_idx

    @staticmethod
    def _play_sizes(batch, n_nodes):
        """Return the number of players per play, in batch order.

        Sizes come from ``batch["metas"]`` (third tuple entry holds the player
        ids). If that is missing or does not add up to ``n_nodes`` -- e.g. a
        hand-built single-play batch, or a tensor chunk whose metas were not
        split accordingly -- the whole batch is treated as one play, which is
        the previous behaviour.
        """
        metas = batch.get("metas") if hasattr(batch, "get") else None
        if metas:
            try:
                sizes = [len(m[2]) for m in metas]
            except (TypeError, IndexError):
                sizes = None
            if sizes and sum(sizes) == n_nodes:
                return sizes
        return [n_nodes]

    def forward(self, batch, T):
        """Predict displacements for ``T`` future steps.

        Args:
            batch: Collated batch dict (see ``st_collate_fn``).
            T: Number of future steps to decode.

        Returns:
            Tensor ``(N, T, 2)`` of displacements relative to each player's
            last observed position (ball nodes are dropped).
        """
        role_id = batch["role_id"].to(CFG.DEVICE)
        side_id = batch["side_id"].to(CFG.DEVICE)
        dir_sin = batch["dir_sin"].to(CFG.DEVICE)
        dir_cos = batch["dir_cos"].to(CFG.DEVICE)
        s = batch["s"].to(CFG.DEVICE)
        final_xy = batch["final_xy"].to(CFG.DEVICE)

        node = self.build_node_feat(batch)              # (N, Din)
        N, Din = node.size(0), node.size(-1)

        sizes = self._play_sizes(batch, N)
        B = len(sizes)

        # One ball node per play, appended after all players at index N + b.
        ball_feat = node.new_zeros((B, Din))
        eidx_parts, efeat_parts, ball_parts = [], [], []
        start = 0
        for b, n in enumerate(sizes):
            stop = start + n
            if n > 0:
                sl = slice(start, stop)
                e, f = self.build_edges_one_t(final_xy[sl], s[sl], dir_sin[sl], dir_cos[sl],
                                              side_id[sl], role_id[sl])
                eidx_parts.append(e + start)            # local -> batch node indices
                efeat_parts.append(f)
                # Ball position: mean of this play's players (as in add_ball_node).
                ball_feat[b, :2] = final_xy[sl].mean(dim=0)
                players = torch.arange(start, stop, device=node.device)
                ball = torch.full_like(players, N + b)
                ball_parts.append(torch.stack([ball, players], dim=0))   # ball -> players
                ball_parts.append(torch.stack([players, ball], dim=0))   # players -> ball
            start = stop

        e_ball = torch.cat(ball_parts, dim=1)
        # Ball edges carry all-zero edge features.
        efeat_ball = torch.zeros((e_ball.size(1), self.edge_in), device=node.device, dtype=node.dtype)

        eidx = torch.cat(eidx_parts + [e_ball], dim=1)
        efeat = torch.cat(efeat_parts + [efeat_ball], dim=0)
        node_ball = torch.cat([node, ball_feat], dim=0)            # (N+B, Din)

        # The same snapshot is reused for every one of the L slices (simplification);
        # extend to per-frame graphs if needed.
        node_t_list = [node_ball for _ in range(CFG.L_HIST)]
        edges_t_list = [eidx for _ in range(CFG.L_HIST)]
        efeat_t_list = [efeat for _ in range(CFG.L_HIST)]

        # Lazy initialisation; vocabulary sizes get one extra slot.
        # NOTE(review): sizes are derived from the ids seen in the FIRST batch
        # only, so a later batch with a larger id would index out of range.
        if self.encoder is None:
            node_in = Din
            self.encoder = STEncoder(node_in=node_in,
                                     role_dim=int(role_id.max().item())+2,
                                     side_dim=int(side_id.max().item())+2,
                                     hid=CFG.NODE_HID, edge_in=self.edge_in, edge_hid=CFG.EDGE_HID, L=CFG.L_HIST).to(CFG.DEVICE)
            self.decoder = TemporalDecoder(hid=CFG.NODE_HID, dec_hid=CFG.DEC_HID).to(CFG.DEVICE)

        # Ball nodes use role/side id 0.
        # NOTE(review): id 0 is also a real role/side id produced by the dataset,
        # so the ball shares that embedding while the extra slot stays unused.
        role_ext = torch.cat([role_id, torch.zeros(B, device=role_id.device, dtype=role_id.dtype)], dim=0)
        side_ext = torch.cat([side_id, torch.zeros(B, device=side_id.device, dtype=side_id.dtype)], dim=0)

        H_last, _ = self.encoder(node_t_list, role_ext, side_ext, edges_t_list, efeat_t_list)

        # The decoder is conditioned on the encoder state repeated over T steps.
        # A time positional encoding could be concatenated here, but that also
        # requires widening the decoder GRU input accordingly.
        cond_seq = H_last.unsqueeze(1).repeat(1, T, 1)             # (N+B, T, H)

        dxy = self.decoder(cond_seq, H_last)  # (N+B, T, 2)
        return dxy[:N]  # drop the ball nodes


## 7. Losses and Training / Evaluation Loops

**EN:**  
- Base loss: MSE on future displacements (optionally end‑weighted).  
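- Physics residual (optional): intended to compare against linear motion extrapolated from the current velocity. Note that in the current code the linear baseline is subtracted from both the prediction and the target, so it cancels and this term reduces to an extra unweighted MSE.  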
- Speed penalty (optional): penalizes speeds above `MAX_SPEED`.

**JP:**  
- 基本損失は将来差分のMSE（終端重み付け可）  
- 物理残差（線形速度との差）と速度違反ペナルティを加算可能  


In [None]:

# ------------------
# Loss & Train/Eval
# ------------------
def end_weights(T, dtype=torch.float32):
    """Linearly increasing per-step weights ``1/T .. 1``, rescaled to mean 1.

    The tensor is created on ``CFG.DEVICE`` and has shape ``(T,)``.
    """
    ramp = torch.arange(1, T+1, device=CFG.DEVICE, dtype=dtype)
    ramp = ramp / T
    return ramp / ramp.mean()

def velocity_clip(pred_dxy, fps, max_speed):
    """Mean amount by which the implied per-frame speed exceeds ``max_speed``.

    ``pred_dxy`` holds displacements relative to the last observed position
    (cumulative over time), so speed must come from frame-to-frame differences.
    Using the displacement norm itself would flag any player that has moved
    more than ``max_speed / fps`` in total, however slowly.

    Args:
        pred_dxy: Tensor ``(..., T, 2)`` of cumulative displacements; the
            implicit displacement at step 0 is zero.
        fps: Frames per second.
        max_speed: Speed limit in distance units per second.

    Returns:
        Scalar tensor: mean of ``max(speed - max_speed, 0)`` over all steps
        (zero for an empty input).
    """
    if pred_dxy.numel() == 0:
        return pred_dxy.new_zeros(())
    # Step 1 moves from the origin; later steps are consecutive differences.
    first_step = pred_dxy[..., :1, :]
    later_steps = pred_dxy[..., 1:, :] - pred_dxy[..., :-1, :]
    step = torch.cat([first_step, later_steps], dim=-2)
    v = step.norm(dim=-1) * fps
    over = (v - max_speed).clamp(min=0.0)
    return over.mean()

def train_one_epoch(model, loader, optim, scaler):
    """Run one training epoch with AMP and gradient accumulation.

    Args:
        model: Model called as ``model(batch, T)`` returning ``(N, T, 2)``.
        loader: DataLoader yielding batches collated by ``st_collate_fn``.
        optim: Optimizer over the model parameters.
        scaler: ``GradScaler`` used for loss scaling.

    Returns:
        Mean (unscaled) loss per batch.
    """
    tensor_keys = ("seq","final_xy","s","a","dir_sin","dir_cos","o_sin","o_cos",
                   "height","weight","bmi","yardln","ball_xy","role_id","side_id",
                   "target_dxy","tgt_mask")
    model.train(); total=0.0
    n_steps = len(loader)
    optim.zero_grad(set_to_none=True)
    for step, batch in enumerate(tqdm(loader, leave=False), start=1):
        for k in tensor_keys:
            batch[k] = batch[k].to(CFG.DEVICE, non_blocking=True)
        T = batch["T_max"]
        if CFG.CAP_T and T>CFG.CAP_T:
            batch["target_dxy"]=batch["target_dxy"][:,:CFG.CAP_T,...]
            batch["tgt_mask"]=batch["tgt_mask"][:,:CFG.CAP_T,...]
            T = CFG.CAP_T

        with torch.cuda.amp.autocast(enabled=CFG.AMP):
            pred = model(batch, T)                    # (N,T,2)
            msk = batch["tgt_mask"].unsqueeze(-1)    # (N,T,1)
            denom = msk.sum()*2 + 1e-6                # number of valid (x, y) elements

            # Masked MSE, optionally weighted towards later time steps.
            diff = (pred - batch["target_dxy"]) * msk
            sq = diff**2
            if CFG.END_WEIGHTED:
                w = end_weights(T, dtype=pred.dtype).view(1,-1,1)
                sq = sq*w
            core = sq.sum() / denom

            if CFG.PHYS_RESID_ALPHA>0:
                vx = batch["s"]*batch["dir_cos"]; vy=batch["s"]*batch["dir_sin"]
                vxy = torch.stack([vx,vy], dim=-1)           # (N,2)
                ts = torch.arange(1, T+1, device=pred.device, dtype=pred.dtype)
                base = vxy.unsqueeze(1) * (ts.view(1,-1,1)/CFG.FPS)
                # NOTE(review): ``base`` is subtracted from both prediction and
                # target, so it cancels and this term equals the unweighted
                # masked MSE. Left as-is because the intended formulation
                # cannot be determined from this file.
                resid = (((pred - base) - (batch["target_dxy"] - base))*msk)**2
                resid = resid.sum() / denom
                core = core + CFG.PHYS_RESID_ALPHA*resid

            if CFG.SPEED_PENALTY>0:
                pen = velocity_clip(pred, fps=CFG.FPS, max_speed=CFG.MAX_SPEED)
                core = core + CFG.SPEED_PENALTY*pen

        scaler.scale(core/CFG.ACCUM_STEPS).backward()
        # Step every ACCUM_STEPS batches and also on the final batch. Otherwise
        # gradients of a trailing partial group are never applied and are thrown
        # away by the zero_grad at the start of the next epoch. The partial
        # group keeps the 1/ACCUM_STEPS scaling, i.e. a slightly smaller step.
        if step % CFG.ACCUM_STEPS == 0 or step == n_steps:
            if CFG.GRAD_CLIP:
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(optim)
                nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)
            scaler.step(optim); scaler.update()
            optim.zero_grad(set_to_none=True)
        total += core.item()
    return total / max(n_steps,1)

@torch.no_grad()
def evaluate(model, loader, max_batches=30):
    """Compute the validation RMSE and collect sample predictions.

    The RMSE is taken over all valid (masked) x/y elements of the whole
    loader. Averaging per-batch RMSE values instead would weight batches
    equally regardless of how many valid targets they hold, and the mean of
    square roots is not the root of the mean.

    Args:
        model: Model called as ``model(batch, T)``.
        loader: DataLoader yielding batches collated by ``st_collate_fn``.
        max_batches: Number of leading batches whose targets/predictions are
            returned for inspection.

    Returns:
        ``(rmse, y_true, y_pred)`` where the arrays have shape ``(K, 2)`` and
        ``rmse`` is 0.0 when there is no valid target.
    """
    tensor_keys = ("seq","final_xy","s","a","dir_sin","dir_cos","o_sin","o_cos",
                   "height","weight","bmi","yardln","ball_xy","role_id","side_id",
                   "target_dxy","tgt_mask")
    model.eval()
    sq_err_sum = 0.0   # running sum of squared errors over valid elements
    n_elems = 0.0      # running count of valid x/y elements
    ys=[]; ps=[]
    for bi, batch in enumerate(tqdm(loader, leave=False)):
        for k in tensor_keys:
            batch[k] = batch[k].to(CFG.DEVICE, non_blocking=True)
        T = batch["T_max"]
        if CFG.CAP_T and T>CFG.CAP_T:
            batch["target_dxy"]=batch["target_dxy"][:,:CFG.CAP_T,...]
            batch["tgt_mask"]=batch["tgt_mask"][:,:CFG.CAP_T,...]
            T = CFG.CAP_T
        pred = model(batch, T)
        mask = batch["tgt_mask"]
        diff = (pred - batch["target_dxy"]) * mask.unsqueeze(-1)
        sq_err_sum += (diff.float()**2).sum().item()
        n_elems += mask.sum().item()*2
        if bi < max_batches:
            keep = mask.bool()
            ys.append(batch["target_dxy"][keep].view(-1,2).detach().cpu().numpy())
            ps.append(pred[keep].view(-1,2).detach().cpu().numpy())
    y = np.concatenate(ys,0) if ys else np.zeros((0,2))
    p = np.concatenate(ps,0) if ps else np.zeros((0,2))
    rmse = float(np.sqrt(sq_err_sum / n_elems)) if n_elems > 0 else 0.0
    return rmse, y, p


## 8. Train/Validation Split and Dataloaders

**EN:** Randomly split plays into train/validation and build PyTorch dataloaders with the custom collate.  

**JP:** プレイ単位でランダム分割し、`st_collate_fn` を用いた DataLoader を作成します。

In [None]:

# ------------------
# Split / Loader
# ------------------
def split_train_valid_keys(all_keys, valid_ratio=0.1):
    """Randomly split keys into train and validation lists.

    The shuffle is seeded with ``CFG.SEED``; at least one key always goes to
    validation. Both returned lists keep the original relative order.

    Returns:
        ``(train_keys, valid_keys)``.
    """
    order = np.arange(len(all_keys))
    shuffler = np.random.RandomState(CFG.SEED)
    shuffler.shuffle(order)

    n_valid = max(1, int(len(order) * valid_ratio))
    valid_positions = set(order[:n_valid].tolist())

    train_keys = [key for pos, key in enumerate(all_keys) if pos not in valid_positions]
    valid_keys = [key for pos, key in enumerate(all_keys) if pos in valid_positions]
    return train_keys, valid_keys

class SubsetPlay(Dataset):
    """View of a base dataset restricted to a list of keys.

    ``base`` must expose ``keys`` (a sequence of key sequences) and integer
    indexing; items are looked up through a key -> base-position map.
    """

    def __init__(self, base, keys):
        self.base = base
        self.keys = keys
        self.map = {tuple(key): pos for pos, key in enumerate(base.keys)}

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, i):
        base_pos = self.map[tuple(self.keys[i])]
        return self.base[base_pos]

def make_loader(ds, bsz, shuffle):
    """Wrap a play dataset in a DataLoader using ``st_collate_fn``.

    Loading happens in the main process (``num_workers=0``); pinned memory is
    enabled when CUDA is in use.
    """
    loader_kwargs = dict(
        batch_size=bsz,
        shuffle=shuffle,
        collate_fn=st_collate_fn,
        num_workers=0,
        pin_memory=CFG.USE_CUDA,
    )
    return DataLoader(ds, **loader_kwargs)


## 9. Main: Training → Validation → Test Inference → Submission

**EN:**  
- Warm forward pass triggers lazy initialization (embeddings depend on observed role/side vocab).  
- Optional `DataParallel` when ≥2 GPUs are available.  
- Test phase follows the template to produce `submission.csv` (also saved under `runs/`).

**JP:**  
- ウォーム実行で遅延初期化（役割/サイド辞書サイズが観測に依存）  
- 2GPU以上なら自動DP化  
- テストはテンプレ構造に従って `submission.csv` を生成します。

In [None]:

# ------------------
# Train → Eval → Predict → Submission
# ------------------
def _save_submission(submission):
    """Write the ``id``/``x``/``y`` columns of ``submission`` to disk.

    Two copies are written: ``/kaggle/working/submission.csv`` and
    ``CFG.OUT_DIR / "submission.csv"``. The Kaggle copy is best-effort, so a
    machine where ``/kaggle/working`` cannot be created or written still gets
    the local copy instead of losing the whole run.
    """
    out = submission[["id", "x", "y"]]
    saved = []

    kaggle_out = "/kaggle/working/submission.csv"
    try:
        Path("/kaggle/working").mkdir(parents=True, exist_ok=True)
        out.to_csv(kaggle_out, index=False)
        saved.append(kaggle_out)
    except OSError as err:
        print(f"⚠️ Skipped {kaggle_out}: {err}")

    local_out = CFG.OUT_DIR / "submission.csv"
    Path(local_out).parent.mkdir(parents=True, exist_ok=True)
    out.to_csv(local_out, index=False)
    saved.append(local_out)

    print(f"✅ Saved: {' / '.join(str(p) for p in saved)}  shape={submission.shape}")


def main():
    """Run the pipeline end to end: load, train, validate, predict, save.

    Steps:
      1. Load train/test frames, build the play dataset, and split its keys
         90/10 into train/validation loaders.
      2. Run one warm forward pass on a training batch, then wrap the model
         in ``nn.DataParallel`` when two or more GPUs are visible.
      3. Train for ``CFG.EPOCHS`` epochs, printing the train loss and the
         validation RMSE after each epoch.
      4. For every play in the test template, build a single-play batch from
         the test input, add the model output to each player's last observed
         position, clip to the field, smooth along time, and emit one row per
         template entry.
      5. Left-join the predictions onto the sample submission, fill gaps with
         the column mean (or the field centre if a column is entirely
         missing), clip to the field, and write the CSV files.
    """
    print("Loading train/test...")
    train_in, train_out = load_train_input_output()
    test_in, test_tpl, sample_sub = load_test_input_and_templates()

    full = STPlayDataset(train_in, train_out, L=CFG.L_HIST, for_train=True)
    tr_keys, va_keys = split_train_valid_keys(full.keys, valid_ratio=0.1)
    tr_ds, va_ds = SubsetPlay(full, tr_keys), SubsetPlay(full, va_keys)
    tr_loader = make_loader(tr_ds, CFG.BATCH_PLAYS, True)
    va_loader = make_loader(va_ds, CFG.BATCH_PLAYS, False)

    print(f"roles={full.role_dim}, sides={full.side_dim}")
    warm = next(iter(tr_loader))

    model = STGNNModel().to(CFG.DEVICE)
    # Warm forward (lazy initialization)
    with torch.cuda.amp.autocast(enabled=CFG.AMP):
        _ = model({k: (v.to(CFG.DEVICE) if torch.is_tensor(v) else v) for k, v in warm.items()}, T=warm["T_max"])

    if torch.cuda.device_count() >= 2:
        print(f"Using {torch.cuda.device_count()} GPUs (DP)")
        model = nn.DataParallel(model)

    optim = torch.optim.AdamW(model.parameters(), lr=CFG.LR, weight_decay=CFG.WEIGHT_DECAY)
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.AMP)

    print("Start training...")
    for ep in range(1, CFG.EPOCHS + 1):
        tr_loss = train_one_epoch(model, tr_loader, optim, scaler)
        # evaluate() also returns sampled targets/predictions; only the RMSE is used here.
        val_rmse, _, _ = evaluate(model, va_loader, max_batches=20)
        print(f"[Epoch {ep}] train={tr_loss:.5f}  valid_RMSE={val_rmse:.5f}")
        gc.collect()
        if CFG.USE_CUDA:
            torch.cuda.empty_cache()

    # ---- Test inference ----
    print("Preparing test...")
    # NOTE(review): inputs go through maybe_flip_direction, but no inverse
    # transform is applied to the predictions in this function - confirm
    # that is intended.
    tin = maybe_flip_direction(test_in.copy())
    tin["height_inches"] = tin["player_height"].apply(height_ftin_to_in) if "player_height" in tin.columns else 70
    pred_rows = []

    # role/side IDs (unknown values on the test side fall back to 0).
    # The vocabularies depend only on the whole test frame, so they are built
    # once here rather than once per play.
    # NOTE(review): these IDs come from the test frame's own sorted vocabulary;
    # verify they line up with the IDs the model saw during training.
    if "player_role" in tin:
        role_map = {k: i for i, k in enumerate(sorted(tin["player_role"].fillna("Unknown").unique()))}
    else:
        role_map = {"Unknown": 0}
    if "player_side" in tin:
        side_map = {k: i for i, k in enumerate(sorted(tin["player_side"].fillna("Unknown").unique()))}
    else:
        side_map = {"Unknown": 0}

    model.eval()
    L = CFG.L_HIST

    # One entry per play: (game_id, play_id, last frame_id requested by the template).
    plays = test_tpl.groupby(["game_id", "play_id"])["frame_id"].max().reset_index().values.tolist()
    for gid, pid, Tmax in tqdm(plays, desc="test plays"):
        g = tin[(tin.game_id == gid) & (tin.play_id == pid)].sort_values(["nfl_id", "frame_id"])
        if g.empty:
            # No input for this play; its rows are filled with the mean later.
            continue
        last = g.groupby("nfl_id", as_index=False).last()
        nfl_ids = last["nfl_id"].values
        N = len(nfl_ids)

        # Per-player history of the last L frames: rows are x, y, vx, vy, s, a.
        seq = []
        for nid in nfl_ids:
            gg = g[g.nfl_id == nid].sort_values("frame_id")
            x, y = gg["x"].values, gg["y"].values
            hist_s = gg["s"].values if "s" in gg else np.zeros(len(gg), np.float32)
            hist_a = gg["a"].values if "a" in gg else np.zeros(len(gg), np.float32)
            hist_dir = gg["dir"].values if "dir" in gg else np.zeros(len(gg), np.float32)
            vx = hist_s * np.cos(np.deg2rad(hist_dir))
            vy = hist_s * np.sin(np.deg2rad(hist_dir))
            arr = np.stack([x[-L:], y[-L:], vx[-L:], vy[-L:], hist_s[-L:], hist_a[-L:]], 0)
            if arr.shape[1] < L:
                # Left-pad short histories with zeros so every player has L columns.
                pad = np.zeros((arr.shape[0], L - arr.shape[1]), np.float32)
                arr = np.concatenate([pad, arr], 1)
            seq.append(arr.astype(np.float32))
        seq = np.stack(seq, 0)

        # Static per-player features taken from each player's last row.
        final_xy = last[["x", "y"]].values.astype(np.float32)
        s = last["s"].fillna(0).values.astype(np.float32) if "s" in last else np.zeros(N, np.float32)
        a = last["a"].fillna(0).values.astype(np.float32) if "a" in last else np.zeros(N, np.float32)
        dir_deg = last["dir"].fillna(0).values.astype(np.float32) if "dir" in last else np.zeros(N, np.float32)
        o_deg = last["o"].fillna(0).values.astype(np.float32) if "o" in last else np.zeros(N, np.float32)
        ds, dc = np.sin(np.deg2rad(dir_deg)), np.cos(np.deg2rad(dir_deg))
        os_, oc = np.sin(np.deg2rad(o_deg)), np.cos(np.deg2rad(o_deg))
        height = last["height_inches"].values.astype(np.float32) if "height_inches" in last else np.full(N, 70, np.float32)
        weight = last["player_weight"].fillna(200).values.astype(np.float32) if "player_weight" in last else np.full(N, 200, np.float32)
        bmi = (weight / (np.clip(height, 1, 300) ** 2)) * 703.0
        yardln = last["absolute_yardline_number"].fillna(50).values.astype(np.float32) if "absolute_yardline_number" in last else np.full(N, 50, np.float32)
        if {"ball_land_x", "ball_land_y"}.issubset(last.columns):
            ball_xy = last[["ball_land_x", "ball_land_y"]].fillna(0.0).values.astype(np.float32)
        else:
            # Without a ball landing point, use the mean player position for every row.
            ball_xy = np.tile(final_xy.mean(0, keepdims=True), (N, 1)).astype(np.float32)

        role_id = np.array([role_map.get(v if isinstance(v, str) else "Unknown", 0) for v in last.get("player_role", pd.Series(["Unknown"] * N))], np.int64)
        side_id = np.array([side_map.get(v if isinstance(v, str) else "Unknown", 0) for v in last.get("player_side", pd.Series(["Unknown"] * N))], np.int64)

        features = {
            "seq": seq,
            "final_xy": final_xy,
            "s": s,
            "a": a,
            "dir_sin": ds,
            "dir_cos": dc,
            "o_sin": os_,
            "o_cos": oc,
            "height": height,
            "weight": weight,
            "bmi": bmi,
            "yardln": yardln,
            "ball_xy": ball_xy,
            "role_id": role_id,
            "side_id": side_id,
        }
        batch = {name: torch.from_numpy(value).to(CFG.DEVICE) for name, value in features.items()}

        T = int(Tmax)
        if CFG.CAP_T and T > CFG.CAP_T:
            T = CFG.CAP_T
        with torch.no_grad():
            dxy = model(batch, T)              # (N,T,2)
        abs_xy = clip_to_field(torch.from_numpy(final_xy).to(CFG.DEVICE).unsqueeze(1) + dxy).cpu().numpy()
        # Smooth each player's x and y trajectories along the time axis.
        for i in range(abs_xy.shape[0]):
            abs_xy[i, :, 0] = gaussian_filter1d(abs_xy[i, :, 0], sigma=CFG.SMOOTH_SIGMA)
            abs_xy[i, :, 1] = gaussian_filter1d(abs_xy[i, :, 1], sigma=CFG.SMOOTH_SIGMA)

        gtpl = test_tpl[(test_tpl.game_id == gid) & (test_tpl.play_id == pid)].sort_values(["nfl_id", "frame_id"])
        mapping = {nid: i for i, nid in enumerate(nfl_ids)}
        for _, r in gtpl.iterrows():
            i = mapping.get(r["nfl_id"], None)
            if i is None:
                # Player is in the template but not in the input: fall back to
                # the first player's last observed position.
                x, y = final_xy[0, 0], final_xy[0, 1]
            else:
                # frame_id is 1-based; frames past the predicted horizon reuse the last step.
                f = int(r["frame_id"]) - 1
                f = max(0, min(f, abs_xy.shape[1] - 1))
                x, y = abs_xy[i, f, 0], abs_xy[i, f, 1]
            rid = f"{int(r['game_id'])}_{int(r['play_id'])}_{int(r['nfl_id'])}_{int(r['frame_id'])}"
            pred_rows.append((rid, x, y))

    pred_df = pd.DataFrame(pred_rows, columns=["id", "x", "y"])
    submission = sample_sub[["id"]].merge(pred_df, on="id", how="left")
    # Rows without a prediction get the mean prediction, or the field centre
    # when the column has no predictions at all.
    mean_x = float(submission["x"].mean(skipna=True)) if submission["x"].notna().any() else (CFG.FIELD_X_MIN + CFG.FIELD_X_MAX) / 2
    mean_y = float(submission["y"].mean(skipna=True)) if submission["y"].notna().any() else (CFG.FIELD_Y_MIN + CFG.FIELD_Y_MAX) / 2
    submission["x"] = submission["x"].fillna(mean_x).clip(CFG.FIELD_X_MIN, CFG.FIELD_X_MAX)
    submission["y"] = submission["y"].fillna(mean_y).clip(CFG.FIELD_Y_MIN, CFG.FIELD_Y_MAX)
    _save_submission(submission)

# Run the full pipeline only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()


---

### 10. Notes / Tips

- **Reproducibility:** Seeds are set for Python, NumPy, and Torch.  
- **AMP Safety:** Edge‑attention runs in FP32 internally to avoid AMP over/underflow artifacts.  
- **Scalability:** You can increase `BATCH_PLAYS` and `ACCUM_STEPS` based on GPU memory.  
- **Physics priors:** `PHYS_RESID_ALPHA` and `SPEED_PENALTY` can be tuned per roster/play style.

**JP補足:** 物理正則化や速度ペナルティは汎用的な安全バイアスです。ドメインに合わせて係数を微調整してください。
