In [None]:
# -------------------------------
# Global imports + cuDF accelerator
# -------------------------------
import os
USE_CUDF = False
try:
    # zero/low-code GPU acceleration for DataFrame ops
    os.environ["CUDF_PANDAS_BACKEND"] = "cudf"
    import pandas as pd
    import numpy as np
    import cupy as cp  # optional (not strictly required below)
    USE_CUDF = True
    print("using cuda_backend pandas for faster parallel data processing")
except Exception:
    print("cuda df not used")
    import pandas as pd
    import numpy as np

import torch
import torch.nn as nn
from pathlib import Path
from tqdm.auto import tqdm
from sklearn.preprocessing import StandardScaler

from sklearn.model_selection import GroupKFold
import warnings
warnings.filterwarnings("ignore")
from glob import glob
import json, pickle, re
import torch.nn.functional as F

# ===============================
# RUN MODE FLAGS
# ===============================
# Both flags are read from the environment as integers ("1" = on, "0" = off);
# exactly one of them must be enabled.
TRAIN = int(os.environ.get("TRAIN", "0"))   # 1=train, 0=off
SUB   = int(os.environ.get("SUB",   "1"))   # 1=submission(infer), 0=off
# Explicit check instead of `assert` so the guard survives `python -O`; the set
# comparison also rejects non-0/1 combinations such as TRAIN=2, SUB=-1.
if {TRAIN, SUB} != {0, 1}:
    raise ValueError("Set exactly one of TRAIN=1 or SUB=1")




# -------------------------------
# Constants & helpers
# -------------------------------
YARDS_TO_METERS = 0.9144
FPS = 10.0  # frames per second; frame-to-frame diffs are multiplied by this to get per-second rates
FIELD_LENGTH, FIELD_WIDTH = 120.0, 53.3  # field extent used when mirroring plays

def set_seed(seed=42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU + all CUDA devices).

    Also exports PYTHONHASHSEED. The interpreter reads that variable only at
    start-up, so it affects processes launched afterwards, not this one.
    """
    import random
    seed_text = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    os.environ['PYTHONHASHSEED'] = seed_text
print("environment set up!")
def wrap_angle_deg(s):
    """Wrap angle(s) in degrees into the half-open interval [-180, 180).

    Works element-wise on Python scalars, NumPy arrays and pandas Series,
    since it only uses ``+``, ``%`` and ``-``. Because ``%`` is a floor-mod,
    +180 maps to -180: the interval is closed on the left, open on the right.
    """
    return ((s + 180.0) % 360.0) - 180.0

def unify_left_direction(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* in which every ``play_direction == 'right'`` row is mirrored.

    The mirror is a point reflection through the field centre:
    x -> FIELD_LENGTH - x, y -> FIELD_WIDTH - y (likewise ball_land_x/ball_land_y),
    and the angle columns ``dir`` / ``o`` are rotated by 180 degrees (mod 360).
    Columns that are absent are skipped. Without a ``play_direction`` column
    the input object itself is returned (no copy).
    """
    if 'play_direction' not in df.columns:
        return df

    out = df.copy()
    is_right = out['play_direction'].eq('right')
    present = set(out.columns)

    # Positions.
    for name, extent in (('x', FIELD_LENGTH), ('y', FIELD_WIDTH)):
        if name in present:
            out.loc[is_right, name] = extent - out.loc[is_right, name]

    # Angles (degrees): a point reflection is a 180-degree rotation.
    for name in ('dir', 'o'):
        if name in present:
            out.loc[is_right, name] = (out.loc[is_right, name] + 180.0) % 360.0

    # Ball landing spot.
    for name, extent in (('ball_land_x', FIELD_LENGTH), ('ball_land_y', FIELD_WIDTH)):
        if name in present:
            out.loc[is_right, name] = extent - out.loc[is_right, name]

    return out

def invert_to_original_direction(x_u, y_u, play_dir_right: bool):
    """Map a point from the unified ('left') frame back to the play's own direction.

    Returns ``(x, y)`` as Python floats. Left plays pass through unchanged.
    Right plays get the same point reflection used by ``unify_left_direction``,
    which is its own inverse.
    """
    if play_dir_right:
        return float(FIELD_LENGTH - x_u), float(FIELD_WIDTH - y_u)
    return float(x_u), float(y_u)

# -------------------------------
# Config
# -------------------------------
class Config:
    """Notebook-wide settings, exposed as class attributes (never instantiated)."""

    # ---- paths ----
    DATA_DIR = Path("/kaggle/input/nfl-big-data-bowl-2026-prediction/")
    OUTPUT_DIR = Path("./outputs")
    OUTPUT_DIR.mkdir(exist_ok=True)  # side effect at class-definition time

    MODEL_BUNDLE_DIR_TRAIN = OUTPUT_DIR / "bundle"  # written during training

    # Bundle read at submission time: $MODEL_BUNDLE_DIR when set and non-empty,
    # otherwise the default Kaggle dataset path.
    env_bundle = os.environ.get("MODEL_BUNDLE_DIR")
    MODEL_BUNDLE_DIR_SUB = (
        Path(env_bundle)
        if env_bundle
        else Path("/kaggle/input/trans-my-test/outputs/bundle")
    )

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- training loop ----
    SEED = 42
    N_FOLDS = 5
    BATCH_SIZE = 256
    EPOCHS = 200
    PATIENCE = 30
    LEARNING_RATE = 1e-3

    # ---- sequence / model shape ----
    WINDOW_SIZE = 10
    HIDDEN_DIM = 128
    MAX_FUTURE_HORIZON = 94  # do NOT change this!!!

    # ---- K-nearest neighbours ----
    N_NEIGHBORS = 7            # K nearest neighbours per frame (excluding self) -> fallback
    NEIGHBOR_RADIUS = 12.0     # neighbour radius (yards)
    ADAPTIVE_NEIGHBORS = True  # enable radius-first adaptive adjacency

    # ---- physics-prior loss weights ----
    W_BOUNDARY = 1e-4  # out-of-bounds penalty
    W_SPEED    = 1e-4  # over-speed penalty (soft)
    W_JERK     = 1e-4  # second-order difference smoothness (jerk)

    MAX_SPEED_YPS = 12.0  # maximum displacement per second (yards/second)

set_seed(seed=Config.SEED)  # seed every RNG once, right after the config is defined

In [None]:
# ===============================
# I/O helpers for model bundle
# ===============================
def _mkdir(p: Path):
    p.mkdir(parents=True, exist_ok=True)

def _save_json(obj, path: Path):
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)

def _load_json(path: Path):
    with open(path, "r") as f:
        return json.load(f)

def _save_pickle(obj, path: Path):
    with open(path, "wb") as f:
        pickle.dump(obj, f)

def _load_pickle(path: Path):
    with open(path, "rb") as f:
        return pickle.load(f)

def ensure_bundle_dirs(cfg):
    """Create ``models/``, ``scalers/`` and ``meta/`` under ``cfg.MODEL_BUNDLE_DIR_TRAIN``.

    Idempotent; missing parents are created too. Returns the bundle root.
    """
    bundle_root = cfg.MODEL_BUNDLE_DIR_TRAIN
    for sub in ("models", "scalers", "meta"):
        (bundle_root / sub).mkdir(parents=True, exist_ok=True)
    return bundle_root

def save_fold_artifacts(cfg, seed, fold, model_x, model_y, scaler):
    """Persist one (seed, fold) result into the training bundle.

    Writes the state_dicts of *model_x* / *model_y* as
    ``models/dx_seed{seed}_fold{fold}.pth`` / ``models/dy_seed{seed}_fold{fold}.pth``
    and pickles *scaler* to ``scalers/scaler_seed{seed}_fold{fold}.pkl``.
    """
    bundle_root = ensure_bundle_dirs(cfg)
    models_dir = bundle_root / "models"
    scalers_dir = bundle_root / "scalers"

    torch.save(model_x.state_dict(), models_dir / f"dx_seed{seed}_fold{fold}.pth")
    torch.save(model_y.state_dict(), models_dir / f"dy_seed{seed}_fold{fold}.pth")

    scaler_path = scalers_dir / f"scaler_seed{seed}_fold{fold}.pkl"
    _save_pickle(scaler, scaler_path)

def save_meta(cfg, feature_cols, dir_map, fold_rmse_list, fold_assign=None):
    """Write bundle-level metadata under ``<bundle>/meta/``.

    * ``dir_map.parquet``: *dir_map* with its index reset into columns.
    * ``meta.json``: feature columns, horizon/window sizes, seeds
      (``cfg.SEEDS`` or ``[cfg.SEED]``), fold count, per-fold RMSE, model type
      and neighbour count (``cfg.N_NEIGHBORS`` or 7).
      ``discover_bundle_for_sub`` reads this file back.
    * ``train_folds.parquet``: only when *fold_assign* is given.
    """
    meta_dir = ensure_bundle_dirs(cfg) / "meta"

    dir_map.reset_index().to_parquet(meta_dir / "dir_map.parquet", index=False)

    meta = {
        "feature_cols": feature_cols,
        "MAX_FUTURE_HORIZON": cfg.MAX_FUTURE_HORIZON,
        "WINDOW_SIZE": cfg.WINDOW_SIZE,
        "SEEDS": getattr(cfg, "SEEDS", [cfg.SEED]),
        "N_FOLDS": cfg.N_FOLDS,
        "rmse_per_fold": fold_rmse_list,
        "model_type": "st",                              # added for the ST model
        "n_neighbors": getattr(cfg, "N_NEIGHBORS", 7),   # added for the ST model
    }
    _save_json(meta, meta_dir / "meta.json")

    if fold_assign is None:
        return
    fold_assign.to_parquet(meta_dir / "train_folds.parquet", index=False)


def _parse_sf(p):  # p 是 Path('.../dx_seed42_fold5.pth')
    m = re.search(r"seed(\d+)_fold(\d+)", p.name)
    return (int(m.group(1)), int(m.group(2))) if m else (-1, -1)

def discover_bundle_for_sub(cfg):
    """Locate the trained bundle for submission and pair its model/scaler files.

    Reads ``meta/meta.json`` and lists ``models/dx_seed*_fold*.pth``,
    ``models/dy_seed*_fold*.pth`` and ``scalers/scaler_seed*_fold*.pkl``,
    each sorted by (seed, fold). The returned lists are index-aligned:
    entry ``i`` of each list belongs to the same (seed, fold).

    Returns:
        (root, meta, feat_cols, model_x_paths, model_y_paths, scaler_paths)

    Raises:
        FileNotFoundError: ``cfg.MODEL_BUNDLE_DIR_SUB`` is missing.
        ValueError: no artifacts, differing counts, or (seed, fold) keys that
            do not line up across the three lists.
    """
    root = cfg.MODEL_BUNDLE_DIR_SUB
    # Explicit checks instead of `assert`, so they still run under `python -O`.
    if root is None or not root.exists():
        raise FileNotFoundError(f"MODEL_BUNDLE_DIR_SUB 无效: {root}")
    meta = _load_json(root / "meta" / "meta.json")
    feat_cols = meta["feature_cols"]

    # List all (dx, dy, scaler) artifacts, each sorted by (seed, fold).
    model_x_paths = sorted((root/"models").glob("dx_seed*_fold*.pth"), key=_parse_sf)
    model_y_paths = sorted((root/"models").glob("dy_seed*_fold*.pth"), key=_parse_sf)
    scaler_paths  = sorted((root/"scalers").glob("scaler_seed*_fold*.pkl"), key=_parse_sf)

    if not (len(model_x_paths) == len(model_y_paths) == len(scaler_paths) > 0):
        raise ValueError("bundle 内文件不完整")

    # Pairing is positional, so equal counts are not enough: the (seed, fold)
    # keys must match element-wise, or a model would be paired with the wrong
    # fold's scaler.
    keys_x = [_parse_sf(p) for p in model_x_paths]
    keys_y = [_parse_sf(p) for p in model_y_paths]
    keys_s = [_parse_sf(p) for p in scaler_paths]
    if keys_x != keys_y or keys_x != keys_s:
        raise ValueError(
            f"bundle (seed, fold) keys do not line up: dx={keys_x}, dy={keys_y}, scaler={keys_s}"
        )
    return root, meta, feat_cols, model_x_paths, model_y_paths, scaler_paths

In [None]:
def compute_val_rmse(pred_dx, pred_dy, y_dx_list, y_dy_list, max_h):
    """
    Pooled RMSE over both displacement axes:

        RMSE = sqrt( (1/(2N)) * sum( (x_true-x_pred)^2 + (y_true-y_pred)^2 ) )

    N is the total number of valid time steps across all samples. Targets are
    padded to ``max_h`` steps by ``prepare_targets`` (defined elsewhere), whose
    masks mark the valid steps. A step counts only if both the dx and the dy
    mask mark it valid. Returns NaN when there are no valid steps.
    """
    sq_err_sum = 0.0
    n_valid = 0

    for i, px_all in enumerate(pred_dx):
        # Pad the targets to the horizon and get their validity masks.
        dx_pad, dx_mask = prepare_targets([y_dx_list[i]], max_h)
        dy_pad, dy_mask = prepare_targets([y_dy_list[i]], max_h)

        # Convert to NumPy and intersect the two masks to be safe.
        valid = np.logical_and(
            dx_mask[0].cpu().numpy().astype(bool),
            dy_mask[0].cpu().numpy().astype(bool),
        )

        true_x = dx_pad[0].cpu().numpy()[valid]
        true_y = dy_pad[0].cpu().numpy()[valid]
        pred_x = px_all[valid]
        pred_y = pred_dy[i][valid]

        sq_err_sum += np.sum((pred_x - true_x)**2 + (pred_y - true_y)**2)
        n_valid += int(valid.sum())

    return float("nan") if n_valid == 0 else float(np.sqrt(sq_err_sum / (2.0 * n_valid)))


In [None]:
# Feature groups for the spatio-temporal (ST) model. Only the specified combination
# is kept (no interaction / alignment / rolling / curvature groups).
ST_FEATURE_GROUPS = ['time_features']  # adds only normalized time on top of the basic features

# Strict whitelist: these are the only columns allowed into the model.
FEATURE_WHITELIST = {
    # basic kinematics
    'x', 'y', 's', 'a',
    'velocity_x', 'velocity_y',
    'acceleration_x', 'acceleration_y',
    # player roles
    'is_offense', 'is_defense', 'is_receiver', 'is_coverage', 'is_passer',
    # geometry relative to the ball landing point
    'distance_to_ball', 'ball_direction_x', 'ball_direction_y',
    # time (one of two options; normalized time chosen as it is more stable across plays)
    'normalized_time',
}



# -------------------------------
# Feature Engineering
# -------------------------------
class FeatureEngineer:
    """
    Modular, ablation-friendly feature builder (pandas or cuDF pandas-API).

    ``transform(df)`` sorts by (game_id, play_id, nfl_id, frame_id), always adds
    the basic features, then runs every creator named in
    ``feature_groups_to_create`` (keys of ``self.feature_creators``). Each
    creator adds columns to the DataFrame it receives and returns
    ``(df, new_cols)``. Only created columns that also appear in the
    module-level ``FEATURE_WHITELIST`` are returned as model inputs.
    """
    def __init__(self, feature_groups_to_create):
        # Per-player time-series key: diffs / rolling windows / cumcounts are taken within it.
        self.gcols = ['game_id', 'play_id', 'nfl_id']
        self.active_groups = feature_groups_to_create
        self.feature_creators = {
            'distance_rate': self._create_distance_rate_features,
            'target_alignment': self._create_target_alignment_features,
            'multi_window_rolling': self._create_multi_window_rolling_features,
            'extended_lags': self._create_extended_lag_features,
            'velocity_changes': self._create_velocity_change_features,
            'field_position': self._create_field_position_features,
            'role_specific': self._create_role_specific_features,
            'time_features': self._create_time_features,
            'jerk_features': self._create_jerk_features,
            'curvature_land_features': self._create_curvature_land_features,
            'interaction': self._create_interaction_features,
        }
        # Candidate model-input columns reported by the creators (whitelist-filtered in transform()).
        self.created_feature_cols = []

    def _height_to_feet(self, height_str):
        """Convert a 'feet-inches' string (e.g. '6-2') to feet; returns 6.0 if it cannot be parsed."""
        try:
            ft, inches = map(int, str(height_str).split('-'))
            return ft + inches / 12
        except Exception:
            return 6.0

    def _create_basic_features(self, df):
        """Add kinematic, role and ball-landing features to a copy of *df* and return it.

        Registers the candidate input columns in ``self.created_feature_cols``.
        """
        print("Step 1/3: Adding basic features (slim)...")
        df = df.copy()

        # Kinematics: project speed `s` and acceleration `a` onto x/y using `dir`.
        # NOTE(review): this treats `dir` as degrees counter-clockwise from +x
        # (vx = s*cos, vy = s*sin). If the tracking data measures `dir` clockwise
        # from +y, sin/cos would be swapped. TODO confirm against the dataset docs
        # (trained bundles depend on the current convention, so change with care).
        dir_rad = np.deg2rad(df['dir'].fillna(0.0).astype('float32'))
        df['velocity_x']     = df['s'] * np.cos(dir_rad)
        df['velocity_y']     = df['s'] * np.sin(dir_rad)
        df['acceleration_x'] = df['a'] * np.cos(dir_rad)
        df['acceleration_y'] = df['a'] * np.sin(dir_rad)

        # Player roles as 0/1 flags.
        df['is_offense']  = (df['player_side'] == 'Offense').astype(np.int8)
        df['is_defense']  = (df['player_side'] == 'Defense').astype(np.int8)
        df['is_receiver'] = (df['player_role'] == 'Targeted Receiver').astype(np.int8)
        df['is_coverage'] = (df['player_role'] == 'Defensive Coverage').astype(np.int8)
        df['is_passer']   = (df['player_role'] == 'Passer').astype(np.int8)

        # Ball-landing geometry (momentum/energy/closing_speed are no longer computed).
        if {'ball_land_x','ball_land_y'}.issubset(df.columns):
            ball_dx = df['ball_land_x'] - df['x']
            ball_dy = df['ball_land_y'] - df['y']
            dist = np.hypot(ball_dx, ball_dy)
            df['distance_to_ball'] = dist
            inv = 1.0 / (dist + 1e-6)  # epsilon avoids division by zero at the landing spot
            df['ball_direction_x'] = ball_dx * inv
            df['ball_direction_y'] = ball_dy * inv

        # Register only columns that may become model inputs (the whitelist is applied later).
        base = [
            'x','y','s','a','dir','frame_id',
            'velocity_x','velocity_y','acceleration_x','acceleration_y',
            'is_offense','is_defense','is_receiver','is_coverage','is_passer',
            'distance_to_ball','ball_direction_x','ball_direction_y',
        ]
        self.created_feature_cols.extend([c for c in base if c in df.columns])
        return df


    # ---- feature groups ----
    def _create_distance_rate_features(self, df):
        """Per-second rate / second rate of distance_to_ball and a clipped time-to-intercept proxy."""
        new_cols = []
        if 'distance_to_ball' in df.columns:
            d = df.groupby(self.gcols)['distance_to_ball'].diff()
            df['d2ball_dt']  = d.fillna(0.0) * FPS
            df['d2ball_ddt'] = df.groupby(self.gcols)['d2ball_dt'].diff().fillna(0.0) * FPS
            df['time_to_intercept'] = (df['distance_to_ball'] /
                                       (df['d2ball_dt'].abs() + 1e-3)).clip(0, 10)
            new_cols = ['d2ball_dt','d2ball_ddt','time_to_intercept']
        return df, new_cols

    def _create_target_alignment_features(self, df):
        """Project velocity (and acceleration) onto / across the unit vector towards the ball landing spot."""
        new_cols = []
        if {'ball_direction_x','ball_direction_y','velocity_x','velocity_y'}.issubset(df.columns):
            df['velocity_alignment'] = df['velocity_x']*df['ball_direction_x'] + df['velocity_y']*df['ball_direction_y']
            df['velocity_perpendicular'] = df['velocity_x']*(-df['ball_direction_y']) + df['velocity_y']*df['ball_direction_x']
            new_cols.extend(['velocity_alignment','velocity_perpendicular'])
            if {'acceleration_x','acceleration_y'}.issubset(df.columns):
                df['accel_alignment'] = df['acceleration_x']*df['ball_direction_x'] + df['acceleration_y']*df['ball_direction_y']
                new_cols.append('accel_alignment')
        return df, new_cols

    def _create_multi_window_rolling_features(self, df):
        """Per-player rolling mean / std (windows 3, 5, 10) of velocity_x, velocity_y, s and a."""
        # Vectorized groupby-rolling, kept simple for cuDF pandas-API compatibility.
        new_cols = []
        for window in (3, 5, 10):
            for col in ('velocity_x','velocity_y','s','a'):
                if col in df.columns:
                    r_mean = df.groupby(self.gcols)[col].rolling(window, min_periods=1).mean()
                    r_std  = df.groupby(self.gcols)[col].rolling(window, min_periods=1).std()
                    # Drop the group-key index levels so results align with df's own index.
                    r_mean = r_mean.reset_index(level=list(range(len(self.gcols))), drop=True)
                    r_std  = r_std.reset_index(level=list(range(len(self.gcols))), drop=True)
                    df[f'{col}_roll{window}'] = r_mean
                    df[f'{col}_std{window}']  = r_std.fillna(0.0)  # std of a single value is NaN
                    new_cols.extend([f'{col}_roll{window}', f'{col}_std{window}'])
        return df, new_cols

    def _create_extended_lag_features(self, df):
        """Lags 1-5 of x, y, velocity_x, velocity_y within each player's sequence."""
        new_cols = []
        for lag in (1,2,3,4,5):
            for col in ('x','y','velocity_x','velocity_y'):
                if col in df.columns:
                    g = df.groupby(self.gcols)[col]
                    lagv = g.shift(lag)
                    # Fill the first frames with the player's first value (no "future" leakage).
                    df[f'{col}_lag{lag}'] = lagv.fillna(g.transform('first'))
                    new_cols.append(f'{col}_lag{lag}')
        return df, new_cols

    def _create_velocity_change_features(self, df):
        """Frame-to-frame changes of velocity components, speed and (wrapped) direction.

        Skipped (no columns added) unless velocity_x, velocity_y, s and dir are all present.
        """
        new_cols = []
        # Guard on every column read below, not just velocity_x.
        if {'velocity_x','velocity_y','s','dir'}.issubset(df.columns):
            df['velocity_x_change'] = df.groupby(self.gcols)['velocity_x'].diff().fillna(0.0)
            df['velocity_y_change'] = df.groupby(self.gcols)['velocity_y'].diff().fillna(0.0)
            df['speed_change']      = df.groupby(self.gcols)['s'].diff().fillna(0.0)
            d = df.groupby(self.gcols)['dir'].diff().fillna(0.0)
            df['direction_change']  = wrap_angle_deg(d)
            new_cols = ['velocity_x_change','velocity_y_change','speed_change','direction_change']
        return df, new_cols

    def _create_field_position_features(self, df):
        """Distances to the nearest sideline and nearest end of the field (only those two are returned)."""
        df['dist_from_left'] = df['y']
        df['dist_from_right'] = FIELD_WIDTH - df['y']
        df['dist_from_sideline'] = np.minimum(df['dist_from_left'], df['dist_from_right'])
        df['dist_from_endzone']  = np.minimum(df['x'], FIELD_LENGTH - df['x'])
        return df, ['dist_from_sideline','dist_from_endzone']

    def _create_role_specific_features(self, df):
        """Role-gated alignment features (receiver) and closing speed (coverage, if provided)."""
        new_cols = []
        if {'is_receiver','velocity_alignment'}.issubset(df.columns):
            df['receiver_optimality'] = df['is_receiver'] * df['velocity_alignment']
            df['receiver_deviation']  = df['is_receiver'] * np.abs(df.get('velocity_perpendicular', 0.0))
            new_cols.extend(['receiver_optimality','receiver_deviation'])
        # Note: closing_speed is not created by _create_basic_features; this branch
        # only fires if the input already carries that column.
        if {'is_coverage','closing_speed'}.issubset(df.columns):
            df['defender_closing_speed'] = df['is_coverage'] * df['closing_speed']
            new_cols.append('defender_closing_speed')
        return df, new_cols

    def _create_time_features(self, df):
        """Add ``frames_elapsed`` (0-based index within each player's sequence) and
        ``normalized_time`` = frames_elapsed / max(frames_elapsed) per player.

        Relies on rows being in frame order within each group (transform() sorts
        first). A single-frame sequence gets normalized_time 0.
        """
        df['frames_elapsed'] = df.groupby(self.gcols).cumcount()
        # Vectorized per-group max instead of a Python lambda per group; the
        # arithmetic (x / (max + 1e-9)) is unchanged.
        max_elapsed = df.groupby(self.gcols)['frames_elapsed'].transform('max')
        df['normalized_time'] = df['frames_elapsed'] / (max_elapsed + 1e-9)
        # Only normalized_time is a model input (frames_elapsed is an intermediate).
        return df, ['normalized_time']


    def _create_jerk_features(self, df):
        """Per-second change of acceleration: scalar jerk and its x/y components."""
        new_cols = []
        if 'a' in df.columns:
            df['jerk'] = df.groupby(self.gcols)['a'].diff().fillna(0.0) * FPS
            new_cols.append('jerk')
        if {'acceleration_x','acceleration_y'}.issubset(df.columns):
            df['jerk_x'] = df.groupby(self.gcols)['acceleration_x'].diff().fillna(0.0) * FPS
            df['jerk_y'] = df.groupby(self.gcols)['acceleration_y'].diff().fillna(0.0) * FPS
            new_cols.extend(['jerk_x','jerk_y'])
        return df, new_cols
    def _create_curvature_land_features(self, df):
        """
        - land_lateral_offset: signed sideways offset of the landing point relative to
          the current direction of motion, lateral = cross(u_dir, vector_to_land)
          (> 0 means the landing point is to the left of the motion direction).
        - bearing_to_land_signed: signed angle (degrees) between the motion direction
          and the bearing to the landing point.
        - Speed-normalized curvature: wrap(delta_dir) / (s * dt) with dt = 0.1 s,
          plus rolling means (windows 3/5) of the signed and absolute values.
        """
        # Lateral offset & bearing to the landing point.
        if {'ball_land_x','ball_land_y'}.issubset(df.columns):
            dx = df['ball_land_x'] - df['x']
            dy = df['ball_land_y'] - df['y']
            bearing = np.arctan2(dy, dx)
            a_dir = np.deg2rad(df['dir'].fillna(0.0).values)
            # Signed bearing difference, wrapped via atan2(sin, cos).
            df['bearing_to_land_signed'] = np.rad2deg(np.arctan2(np.sin(bearing - a_dir), np.cos(bearing - a_dir)))
            # Lateral offset: u x d (2D cross product, z component).
            ux, uy = np.cos(a_dir), np.sin(a_dir)
            df['land_lateral_offset'] = dy*ux - dx*uy  # > 0: landing point on the left

        # Curvature along each player's sequence; zero speed -> NaN -> 0.
        ddir = df.groupby(self.gcols)['dir'].diff().fillna(0.0)
        ddir = ((ddir + 180.0) % 360.0) - 180.0
        curvature = np.deg2rad(ddir).astype('float32') / (df['s'].replace(0, np.nan).astype('float32') * 0.1 + 1e-6)
        df['curvature_signed'] = curvature.fillna(0.0)
        df['curvature_abs'] = df['curvature_signed'].abs()

        # Rolling means (windows 3/5); drop the 3 group-key levels to realign.
        for w in (3,5):
            r = df.groupby(self.gcols)['curvature_signed'].rolling(w, min_periods=1).mean().reset_index(level=[0,1,2], drop=True)
            df[f'curv_signed_roll{w}'] = r
            r2 = df.groupby(self.gcols)['curvature_abs'].rolling(w, min_periods=1).mean().reset_index(level=[0,1,2], drop=True)
            df[f'curv_abs_roll{w}'] = r2

        new_cols = ['bearing_to_land_signed','land_lateral_offset',
                    'curvature_signed','curvature_abs','curv_signed_roll3','curv_abs_roll3',
                    'curv_signed_roll5','curv_abs_roll5']
        return df, [c for c in new_cols if c in df.columns]

    def _create_interaction_features(self, df, speed_eps=0.5):
        """
        Lightweight offense/defense interaction features (K=1 nearest opponent per frame):
          - opp_dmin       : distance to the nearest opponent (clipped to [0, 30])
          - opp_close_rate : relative velocity projected on the opponent->self direction
                             (clipped to [-10, 10])
          - opp_leverage   : leverage sign in {-1, 0, 1}; 0 when own speed <= speed_eps
        Only rows with player_to_predict == True are filled (when that column exists);
        other rows keep NaN (or any pre-existing value). Loops over every
        (game, play, frame) in Python, so it is slow on large inputs.
        """
        need = ['x','y','velocity_x','velocity_y','player_side','frame_id']
        if any(c not in df.columns for c in need):
            return df, []

        out_cols = ['opp_dmin','opp_close_rate','opp_leverage']
        for c in out_cols:
            if c not in df.columns:
                df[c] = np.nan

        key = ['game_id','play_id','frame_id']
        use_mask_global = ('player_to_predict' in df.columns)

        for _, g in df.groupby(key, sort=False):
            if len(g) <= 1:
                continue
            idx = g.index.values
            pos = g[['x','y']].values.astype('float32')
            vel = g[['velocity_x','velocity_y']].values.astype('float32')
            side_off = (g['player_side'].values == 'Offense')
            side_def = ~side_off
            tgt_mask = g['player_to_predict'].astype(bool).values if use_mask_global else np.ones(len(g), bool)

            def _assign(A_mask, B_mask):
                # For each target player in A, find the nearest player in B and write features.
                A_mask = A_mask & tgt_mask
                A_idx = np.where(A_mask)[0]
                B_idx = np.where(B_mask)[0]
                if len(A_idx)==0 or len(B_idx)==0:
                    return
                Apos, Bpos = pos[A_idx], pos[B_idx]
                Avel, Bvel = vel[A_idx], vel[B_idx]

                dx = Apos[:,None,0] - Bpos[None,:,0]
                dy = Apos[:,None,1] - Bpos[None,:,1]
                D  = np.sqrt(dx*dx + dy*dy) + 1e-6
                j  = np.argmin(D, axis=1)

                dmin = np.clip(D[np.arange(len(A_idx)), j], 0.0, 30.0)

                r   = Apos - Bpos[j]                      # opponent -> self
                u   = r / (np.linalg.norm(r, axis=1, keepdims=True) + 1e-6)
                v_rel = Bvel[j] - Avel
                close = np.clip(np.einsum('ij,ij->i', v_rel, u), -10.0, 10.0)

                speed   = np.linalg.norm(Avel, axis=1)
                to_opp  = Bpos[j] - Apos                  # self -> opponent
                cross_z = to_opp[:,0]*Avel[:,1] - to_opp[:,1]*Avel[:,0]
                lever   = np.where(speed > speed_eps, np.sign(cross_z), 0).astype('int8')

                rows = idx[A_idx]
                df.loc[rows, 'opp_dmin']       = dmin
                df.loc[rows, 'opp_close_rate'] = close
                df.loc[rows, 'opp_leverage']   = lever

            _assign(side_off, side_def)   # Offense w.r.t Defense
            _assign(side_def, side_off)   # Defense w.r.t Offense

        return df, out_cols


    def transform(self, df):
        """Build all features for *df*; returns ``(processed_df, final_cols)``.

        ``final_cols`` is the sorted, de-duplicated list of created columns that
        are in ``FEATURE_WHITELIST``. The input DataFrame is not modified.
        """
        # sort_values already returns a new DataFrame, so no extra copy is needed.
        df = df.sort_values(['game_id','play_id','nfl_id','frame_id'])
        df = self._create_basic_features(df)

        print("\nStep 2/3: Adding selected advanced features...")
        for group_name in self.active_groups:
            if group_name in self.feature_creators:
                creator = self.feature_creators[group_name]
                df, new_cols = creator(df)
                self.created_feature_cols.extend(new_cols)
                print(f"  [+] Added '{group_name}' ({len(new_cols)} cols)")
            else:
                print(f"  [!] Unknown feature group: {group_name}")

        final_cols = sorted(set(self.created_feature_cols))
        # Whitelist filter: only the allowed columns reach the model.
        final_cols = [c for c in final_cols if c in FEATURE_WHITELIST]
        print(f"\nTotal features used (whitelist): {len(final_cols)} -> {final_cols}")
        return df, final_cols


In [None]:
# -------------------------------
# Sequence builder (unified frame + safe targets)
# -------------------------------
def build_play_direction_map(df_in: pd.DataFrame) -> pd.Series:
    """
    Map each (game_id, play_id) to its ``play_direction`` value ('left'/'right').

    Built from the unique (game_id, play_id, play_direction) rows of *df_in*,
    so the result is a Series named ``play_direction`` with a
    (game_id, play_id) MultiIndex, usable with pandas and the cuDF pandas-API.
    A play carrying two different directions would appear twice (no per-key
    deduplication is performed).
    """
    unique_rows = df_in[['game_id','play_id','play_direction']].drop_duplicates()
    keyed = unique_rows.set_index(['game_id','play_id'])
    return keyed['play_direction']


def apply_direction_to_df(df: pd.DataFrame, dir_map: pd.Series) -> pd.DataFrame:
    """
    Ensure *df* has ``play_direction`` and return it unified to the 'left' frame.

    When the column is missing it is merged in from *dir_map* (the MultiIndex
    Series from ``build_play_direction_map``) with a left join on
    (game_id, play_id), validated as many-to-one.
    """
    has_direction = 'play_direction' in df.columns
    if not has_direction:
        df = df.merge(
            dir_map.reset_index(),  # columns: game_id, play_id, play_direction
            on=['game_id','play_id'],
            how='left',
            validate='many_to_one',
        )
    return unify_left_direction(df)

def prepare_sequences_with_advanced_features(
        input_df, output_df=None, test_template=None, 
        is_training=True, window_size=10, feature_groups=None):
    """Build fixed-length per-player input windows in the unified ('left') frame.

    Args:
        input_df: tracking rows with game_id, play_id, nfl_id, frame_id,
            play_direction and the columns FeatureEngineer reads.
        output_df: training only; future rows (x, y, frame_id) per player,
            converted to the unified frame.
        test_template: inference only; rows to predict. play_direction is
            merged in from input_df when missing.
        is_training: selects training vs. inference behavior (see Returns).
        window_size: number of most recent input frames per sequence.
        feature_groups: FeatureEngineer groups; defaults to ST_FEATURE_GROUPS.

    Returns:
        Training: (sequences, targets_dx, targets_dy, targets_fids, seq_meta,
        feature_cols, dir_map). Targets are unified-frame displacements from
        the last input (x, y), and all lists are index-aligned.
        Inference: (sequences, seq_meta, feature_cols, dir_map).

    Training skips players with fewer than ``window_size`` frames, NaNs left
    after imputation, or no output rows. Inference front-pads short windows
    with NaN, imputes with the player's column means, then zero-fills.
    """
    print(f"\n{'='*80}")
    print(f"PREPARING SEQUENCES WITH ADVANCED FEATURES (UNIFIED FRAME)")
    print(f"{'='*80}")
    print(f"Window size: {window_size}")

    if feature_groups is None:
        feature_groups = ST_FEATURE_GROUPS   # only time_features

    key_cols = ['game_id', 'play_id', 'nfl_id']

    # Direction map from the raw input, then unify the input to 'left'.
    dir_map = build_play_direction_map(input_df)
    input_df_u = unify_left_direction(input_df)

    if is_training:
        out_u = apply_direction_to_df(output_df, dir_map)
        target_rows = out_u
        target_groups = out_u[key_cols].drop_duplicates()
    else:
        # Ensure test_template has play_direction via a validated merge.
        if 'play_direction' not in test_template.columns:
            dir_df = dir_map.reset_index()
            test_template = test_template.merge(dir_df, on=['game_id','play_id'], how='left', validate='many_to_one')
        target_rows = test_template
        target_groups = target_rows[key_cols + ['play_direction']].drop_duplicates()

    # Every target row must have received a play_direction.
    assert target_rows[['game_id','play_id','play_direction']].isna().sum().sum() == 0, \
        "play_direction merge failed; check (game_id, play_id) coverage"
    print("play_direction merge OK:", target_rows['play_direction'].value_counts(dropna=False).to_dict())

    # --- FE ---
    fe = FeatureEngineer(feature_groups)
    processed_df, feature_cols = fe.transform(input_df_u)

    # --- Build sequences ---
    print("\nStep 3/3: Creating sequences...")
    processed_df = processed_df.set_index(key_cols).sort_index()
    grouped = processed_df.groupby(level=key_cols)

    # Positions of the unified x / y columns inside a sequence row.
    idx_x = feature_cols.index('x')
    idx_y = feature_cols.index('y')

    # Index the training outputs once per player (a single pass over the rows)
    # instead of boolean-filtering the whole frame for every player, which is
    # O(players * rows).
    out_by_key = {}
    if is_training:
        out_by_key = {
            key: grp.sort_values('frame_id')
            for key, grp in target_rows.groupby(key_cols, sort=False)
        }

    sequences, targets_dx, targets_dy, targets_fids, seq_meta = [], [], [], [], []

    rows = list(target_groups.itertuples(index=False))
    for row in tqdm(rows, total=len(target_groups), desc="Creating sequences"):
        gid, pid, nid = row[0], row[1], row[2]
        play_dir = row[3] if (not is_training and len(row) >= 4) else None
        key = (gid, pid, nid)

        try:
            group_df = grouped.get_group(key)
        except KeyError:
            continue

        input_window = group_df.tail(window_size)
        if len(input_window) < window_size:
            if is_training:
                continue
            pad_len = window_size - len(input_window)
            pad_df = pd.DataFrame(np.nan, index=range(pad_len), columns=input_window.columns)
            input_window = pd.concat([pad_df, input_window], ignore_index=True)

        # Simple imputation with the player's own column means.
        input_window = input_window.fillna(group_df.mean(numeric_only=True))
        seq = input_window[feature_cols].values

        if np.isnan(seq).any():
            if is_training:
                continue
            seq = np.nan_to_num(seq, nan=0.0)

        if is_training:
            # Resolve the targets BEFORE appending anything, so a player without
            # output rows cannot leave an orphan entry in `sequences` and
            # desynchronize it from the target / meta lists.
            out_grp = out_by_key.get(key)
            if out_grp is None or len(out_grp) == 0:
                continue

            # Targets: displacement from the last unified input position.
            last_x = seq[-1, idx_x]
            last_y = seq[-1, idx_y]
            dx = out_grp['x'].values - last_x
            dy = out_grp['y'].values - last_y

            targets_dx.append(dx.astype(np.float32))
            targets_dy.append(dy.astype(np.float32))
            targets_fids.append(out_grp['frame_id'].values.astype(np.int32))

        sequences.append(seq)
        seq_meta.append({
            'game_id': gid,
            'play_id': pid,
            'nfl_id': nid,
            'frame_id': int(input_window.iloc[-1]['frame_id']),
            'play_direction': (None if is_training else play_dir),
        })

    print(f"Created {len(sequences)} sequences with {len(feature_cols)} features each")

    if is_training:
        return sequences, targets_dx, targets_dy, targets_fids, seq_meta, feature_cols, dir_map
    return sequences, seq_meta, feature_cols, dir_map

def prepare_sequences_spatiotemporal(
        input_df, output_df=None, test_template=None,
        is_training=True, window_size=10, feature_groups=None,
        K_neighbors=7, neighbor_radius=None, adaptive_neighbors=True):
    """
    Build (T, 1+K, D) spatio-temporal windows in the unified (left-facing) frame.

    For every target ``(game_id, play_id, nfl_id)`` the last ``window_size``
    frames of that player are taken from the engineered features. In each
    frame, token 0 is the target player and tokens 1..K are its K nearest
    other players in that same frame; missing neighbour slots are zero-padded.

    Args:
        input_df: raw tracking input; play direction is unified before
            feature engineering.
        output_df: future positions (training only), used for dx/dy targets.
        test_template: rows to predict (inference only). ``play_direction`` is
            merged in from ``input_df`` when the column is absent.
        is_training: training drops targets with fewer than ``window_size``
            frames or without future rows; inference instead pads the window
            start by repeating the first available frame.
        window_size: number of frames T per sequence.
        feature_groups: passed to ``FeatureEngineer`` (defaults to
            ``ST_FEATURE_GROUPS``).
        K_neighbors: number of neighbour tokens K per frame.
        neighbor_radius: optional radius for adaptive neighbour selection.
        adaptive_neighbors: when True and ``neighbor_radius`` is set, prefer
            neighbours inside the radius and top up with the nearest others.

    Returns:
        Training: (sequences, spatial_masks, targets_dx, targets_dy,
        targets_fids, seq_meta, feature_cols, dir_map).
        Inference: (sequences, spatial_masks, seq_meta, feature_cols, dir_map).
        ``sequences[i]`` is float32 (T, 1+K, D); ``spatial_masks[i]`` is
        float32 (T, 1+K) with 1 = real token and 0 = padding. dx/dy targets
        are future unified x/y minus the target's last observed x/y.
    """
    print(f"\n{'='*80}")
    print(f"PREPARING SPATIOTEMPORAL SEQUENCES (UNIFIED FRAME)")
    print(f"{'='*80}")
    print(f"Window size: {window_size}, K_neighbors: {K_neighbors}")

    if feature_groups is None:
        feature_groups = ST_FEATURE_GROUPS  # compact feature set defined for the ST model

    # Unify play direction (everything mirrored to "left"), then engineer features.
    dir_map = build_play_direction_map(input_df)
    input_df_u = unify_left_direction(input_df)

    if is_training:
        out_u = apply_direction_to_df(output_df, dir_map)
        target_rows = out_u
        target_groups = out_u[['game_id','play_id','nfl_id']].drop_duplicates()
    else:
        if 'play_direction' not in test_template.columns:
            dir_df = dir_map.reset_index()
            test_template = test_template.merge(dir_df, on=['game_id','play_id'], how='left', validate='many_to_one')
        target_rows = test_template
        target_groups = target_rows[['game_id','play_id','nfl_id','play_direction']].drop_duplicates()

    assert target_rows[['game_id','play_id','play_direction']].isna().sum().sum() == 0, \
        "play_direction merge failed; check (game_id, play_id) coverage"
    print("play_direction merge OK:", target_rows['play_direction'].value_counts(dropna=False).to_dict())

    fe = FeatureEngineer(feature_groups)
    processed_df, feature_cols = fe.transform(input_df_u)

    sequences, spatial_masks = [], []
    targets_dx, targets_dy, targets_fids, seq_meta = [], [], [], []

    # x and y are required to rank neighbours by distance.
    assert 'x' in feature_cols and 'y' in feature_cols, "x,y 必须在特征列中用于最近邻选择"

    def build_tokens_for_frame(play_df, frame_id, target_id, cols, K, radius, adaptive=True):
        """
        Return ``(tokens, mask)`` for one frame.

        ``tokens`` is float32 (1+K, len(cols)): the target's row first, then up
        to K neighbour rows, zero-padded. ``mask`` is (1+K,) with 1 for real
        rows. Both are all-zero when the frame or the target is missing.
        """
        fdf = play_df[play_df['frame_id'] == frame_id]
        if len(fdf) == 0:
            return np.zeros((1+K, len(cols)), np.float32), np.zeros((1+K,), np.float32)

        tgt = fdf[fdf['nfl_id'] == target_id]
        if len(tgt) == 0:
            return np.zeros((1+K, len(cols)), np.float32), np.zeros((1+K,), np.float32)
        tgt_row = tgt.iloc[-1]
        tx, ty = float(tgt_row['x']), float(tgt_row['y'])

        # Candidate neighbours: every other player in this frame.
        others = fdf[fdf['nfl_id'] != target_id]
        sel = []
        if len(others) > 0:
            dx = others['x'].astype('float32').values - tx
            dy = others['y'].astype('float32').values - ty
            d  = np.sqrt(dx*dx + dy*dy)

            if adaptive and (radius is not None):
                within = np.where(d <= float(radius))[0]
                if len(within) >= K:
                    # Enough players inside the radius: keep the K closest of them.
                    order = within[np.argsort(d[within])[:K]]
                elif len(within) > 0:
                    # Keep every in-radius player and top up with the nearest
                    # outside ones (boolean isin indexing keeps distance order).
                    rest = np.argsort(d)
                    rest = rest[~np.isin(rest, within)][: (K - len(within))]
                    order = np.concatenate([within, rest.astype(int)])
                else:
                    order = np.argsort(d)[:K]
            else:
                order = np.argsort(d)[:K]

            sel = others.iloc[order]

        tokens = [tgt_row[cols].values.astype('float32')]
        if len(sel) > 0:
            tokens.append(sel[cols].values.astype('float32'))
        tokens = np.vstack(tokens) if len(tokens) > 1 else np.array(tokens, dtype='float32')

        # Count real tokens BEFORE padding so the mask marks only them.
        n_valid = tokens.shape[0]

        if tokens.shape[0] < 1+K:
            pad = np.zeros((1+K - tokens.shape[0], tokens.shape[1]), dtype='float32')
            tokens = np.vstack([tokens, pad])

        mask = np.zeros((1+K,), dtype='float32')
        mask[:n_valid] = 1.0
        return tokens, mask

    it = target_groups.itertuples(index=False)
    it = tqdm(list(it), total=len(target_groups), desc="Creating ST sequences")

    # Group the large tables once: boolean-filtering the full processed_df /
    # target_rows for every target costs O(n_targets * n_rows).
    play_groups = processed_df.groupby(['game_id', 'play_id'], sort=False)
    target_row_groups = (
        target_rows.groupby(['game_id', 'play_id', 'nfl_id'], sort=False) if is_training else None
    )

    # Only the current play's imputed frame is cached. The imputation is
    # deterministic, so a revisited play yields identical data, and an
    # unbounded per-play cache would end up duplicating processed_df in memory.
    cached_key, play_df = None, None
    T = window_size
    D = len(feature_cols)
    N = 1 + K_neighbors
    for row in it:
        gid = row[0]; pid = row[1]; nid = row[2]
        play_dir = row[3] if (not is_training and len(row) >= 4) else None

        key_play = (gid, pid)
        if key_play != cached_key:
            try:
                pdf = play_groups.get_group(key_play)
            except KeyError:
                continue  # no input rows for this play -> nothing to build
            # Impute missing values with per-play column means.
            play_df = pdf.fillna(pdf.mean(numeric_only=True))
            cached_key = key_play

        # The target's own frames (gives the last window_size frame ids).
        try:
            tgt_series = play_df[play_df['nfl_id']==nid].sort_values('frame_id')
        except KeyError:
            continue
        if len(tgt_series) == 0:
            continue

        frames = tgt_series['frame_id'].values
        if len(frames) < window_size:
            if is_training:
                continue
            # Inference: pad the start by repeating the first available frame.
            pad_len = window_size - len(frames)
            frames = np.concatenate([np.full(pad_len, frames[0], dtype=frames.dtype), frames])
        frames = frames[-window_size:]

        # Training needs future rows; skip targets without them before doing
        # the per-frame neighbour work.
        out_grp = None
        if is_training:
            try:
                out_grp = target_row_groups.get_group((gid, pid, nid)).sort_values('frame_id')
            except KeyError:
                continue
            if len(out_grp) == 0:
                continue

        seq = np.zeros((T, N, D), dtype='float32')
        msk = np.zeros((T, N), dtype='float32')
        for t_idx, fid in enumerate(frames):
            seq[t_idx], msk[t_idx] = build_tokens_for_frame(
                play_df, fid, nid, feature_cols, K_neighbors,
                radius=neighbor_radius, adaptive=adaptive_neighbors
            )

        # Per-play mean imputation cannot fill a column that is entirely NaN
        # within the play; zero any leftover NaNs so they never reach the
        # scaler or the model (a NaN input propagates to the output and loss).
        seq[np.isnan(seq)] = 0.0

        sequences.append(seq)
        spatial_masks.append(msk)

        if is_training:
            # Targets: true future unified positions relative to the last observed one.
            last_x = tgt_series['x'].values[-1]
            last_y = tgt_series['y'].values[-1]
            dx = out_grp['x'].values - last_x
            dy = out_grp['y'].values - last_y

            targets_dx.append(dx.astype(np.float32))
            targets_dy.append(dy.astype(np.float32))
            targets_fids.append(out_grp['frame_id'].values.astype(np.int32))

        seq_meta.append({
            'game_id': gid,
            'play_id': pid,
            'nfl_id': nid,
            'frame_id': int(frames[-1]),
            'play_direction': (None if is_training else play_dir),
        })

    print(f"Created {len(sequences)} ST sequences with shape (T={window_size}, N={1+K_neighbors}, D={len(feature_cols)})")

    if is_training:
        return sequences, spatial_masks, targets_dx, targets_dy, targets_fids, seq_meta, feature_cols, dir_map
    return sequences, spatial_masks, seq_meta, feature_cols, dir_map


In [None]:
# -------------------------------
# Model & training (same spirit as your version)
# -------------------------------
class TemporalHuber(nn.Module):
    """
    Masked Huber loss with optional exponential time-decay weighting and
    optional soft physics penalties (speed cap, smoothness, field boundary).

    Args:
        delta: Huber threshold between the quadratic and linear regimes.
        time_decay: rate of the per-step weight ``exp(-time_decay * t)``;
            0 disables the weighting.
    """
    def __init__(self, delta=0.5, time_decay=0.03):
        super().__init__()
        self.delta = delta
        self.time_decay = time_decay

    def forward(self, pred, target, mask,
                last_pos=None, lower=None, upper=None,
                step_cap=None, w_boundary=0.0, w_speed=0.0, w_jerk=0.0):
        """
        Args:
            pred, target: (B, L) per-step values (the models in this file
                predict cumulative displacement from the last observed position).
            mask: (B, L) 1 for real steps, 0 for padding.
            last_pos: (B,) position that ``pred`` is relative to (boundary term).
            lower, upper: allowed range of the absolute position ``last_pos + pred``.
            step_cap: per-step displacement above which the speed penalty applies.
            w_boundary, w_speed, w_jerk: penalty weights; 0 disables a term.

        Returns:
            Scalar: time-weighted mean Huber plus the enabled penalties. When
            ``time_decay > 0`` the penalties are also normalised with the
            time-weighted mask.
        """
        # --- main Huber term ---
        err = pred - target
        abs_err = torch.abs(err)
        huber = torch.where(abs_err <= self.delta,
                            0.5 * err * err,
                            self.delta * (abs_err - 0.5 * self.delta))
        if self.time_decay > 0:
            L = pred.size(1)
            t = torch.arange(L, device=pred.device, dtype=pred.dtype)
            w = torch.exp(-self.time_decay * t).view(1, L)
            # Apply the decay exactly once, through the mask, so the main term
            # is a proper weighted mean sum(h*w*m) / sum(w*m). Also scaling
            # `huber` by w would weight the numerator by w**2 but the
            # denominator only by w.
            mask = mask * w
        main_loss = (huber * mask).sum() / (mask.sum() + 1e-8)

        # --- soft speed cap on per-step displacement ---
        speed_pen = 0.0
        if w_speed > 0.0:
            # First step is pred[:, 0] itself (displacement from the last position).
            step = torch.cat([pred[:, :1], pred[:, 1:] - pred[:, :-1]], dim=1)
            step_mask = torch.zeros_like(mask)
            step_mask[:, 0] = mask[:, 0]
            step_mask[:, 1:] = mask[:, 1:] * mask[:, :-1]
            if step_cap is not None:
                overflow = F.relu(torch.abs(step) - float(step_cap))
                speed_pen = (overflow.pow(2) * step_mask).sum() / (step_mask.sum() + 1e-8)
                speed_pen = w_speed * speed_pen

        # --- smoothness: second difference of the predicted path ---
        jerk_pen = 0.0
        if w_jerk > 0.0 and pred.size(1) > 2:
            step = torch.cat([pred[:, :1], pred[:, 1:] - pred[:, :-1]], dim=1)   # (B,L)
            jerk1 = step[:, 1:] - step[:, :-1]                                   # (B,L-1)
            jerk  = jerk1[:, 1:]                                                 # (B,L-2)
            jerk_mask = mask[:, 2:] * mask[:, 1:-1] * mask[:, :-2]               # (B,L-2)
            jerk_pen = (jerk.pow(2) * jerk_mask).sum() / (jerk_mask.sum() + 1e-8)
            jerk_pen = w_jerk * jerk_pen

        # --- boundary penalty on the absolute position ---
        boundary_pen = 0.0
        if w_boundary > 0.0 and (last_pos is not None) and (lower is not None) and (upper is not None):
            abs_pos = last_pos.unsqueeze(1) + pred
            lower_vi = F.relu(float(lower) - abs_pos)
            upper_vi = F.relu(abs_pos - float(upper))
            boundary_pen = ((lower_vi.pow(2) + upper_vi.pow(2)) * mask).sum() / (mask.sum() + 1e-8)
            boundary_pen = w_boundary * boundary_pen

        return main_loss + speed_pen + jerk_pen + boundary_pen



class RelPosSpatialSelfAttn(nn.Module):
    """
    Multi-head self-attention over the N tokens of one frame, with an additive
    attention bias computed from pairwise edge features.

    forward inputs:
      x: (B, N, H) token embeddings (target + neighbours)
      edge_attr: (B, N, N, E) features of each (i, j) pair
      key_padding_mask: (B, N), 1 = valid, 0 = pad
    Returns (B, N, H).
    """
    def __init__(self, H, heads=4, edge_dim=8, dropout=0.1):
        super().__init__()
        if H % heads != 0:
            raise ValueError(f"hidden size H={H} must be divisible by heads={heads}")
        self.H, self.heads = H, heads
        self.dk = H // heads
        self.q = nn.Linear(H, H, bias=False)
        self.k = nn.Linear(H, H, bias=False)
        self.v = nn.Linear(H, H, bias=False)
        # Edge features -> one scalar attention bias per (i, j), shared by all heads.
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_dim, H // 4),
            nn.GELU(),
            nn.Linear(H // 4, 1)
        )
        self.out = nn.Linear(H, H)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_attr, key_padding_mask=None):
        B, N, H = x.shape
        q = self.q(x).reshape(B, N, self.heads, self.dk).transpose(1, 2)  # (B,heads,N,dk)
        k = self.k(x).reshape(B, N, self.heads, self.dk).transpose(1, 2)  # (B,heads,N,dk)
        v = self.v(x).reshape(B, N, self.heads, self.dk).transpose(1, 2)  # (B,heads,N,dk)

        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (self.dk ** 0.5)  # (B,heads,N,N)

        # Edge bias: (B,N,N,E) -> (B,N,N), broadcast over heads.
        bias = self.edge_mlp(edge_attr).squeeze(-1)
        attn_logits = attn_logits + bias.unsqueeze(1)

        if key_padding_mask is not None:
            # Padded keys (mask == 0) get the most negative FINITE value rather
            # than -inf. With at least one valid key, exp(min - max) underflows
            # to exactly 0, so results are unchanged. A fully padded row no
            # longer turns into NaN (softmax over all -inf), which would
            # poison the sample's output and its gradients.
            pad = (key_padding_mask == 0).unsqueeze(1).unsqueeze(2)  # (B,1,1,N)
            attn_logits = attn_logits.masked_fill(pad, torch.finfo(attn_logits.dtype).min)

        attn = F.softmax(attn_logits, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)                      # (B,heads,N,dk)
        out = out.transpose(1, 2).reshape(B, N, H)       # (B,N,H)
        return self.out(out)

class SeqModel(nn.Module):
    """
    Temporal model for flat (B, seq_len, input_dim) sequences.

    input projection -> 1-D temporal conv (residual) + learned positional
    encoding -> 3-layer Transformer encoder -> mean + attention pooling ->
    MLP head -> cumulative sum over the horizon.

    Args:
        input_dim: features per time step.
        horizon: number of predicted future steps.
        hidden_dim: model width; must be divisible by the 4 attention heads.
        max_len: length of the learned positional encoding, i.e. the longest
            accepted ``seq_len``. The defaults reproduce the previous fixed
            architecture and parameter shapes, so saved weights still load.
    """
    def __init__(self, input_dim, horizon, hidden_dim=128, max_len=10):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_proj = nn.Linear(input_dim, self.hidden_dim)

        # Temporal 1-D convolution for local patterns.
        self.temporal_conv = nn.Sequential(
            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm1d(self.hidden_dim)
        )

        # Learned positional encoding; its length bounds the accepted seq_len.
        self.pos_encoding = nn.Parameter(torch.randn(1, max_len, self.hidden_dim) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_dim,
            nhead=4,
            dim_feedforward=512,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=3)

        # Dual pooling: global mean + learned-query attention.
        self.pool_ln = nn.LayerNorm(self.hidden_dim)
        self.pool_attn = nn.MultiheadAttention(
            self.hidden_dim,
            num_heads=4,
            batch_first=True,
            dropout=0.1
        )
        self.pool_query = nn.Parameter(torch.randn(1, 1, self.hidden_dim))

        # Fuses the two pooled vectors.
        self.fusion = nn.Linear(self.hidden_dim * 2, self.hidden_dim)

        self.head = nn.Sequential(
            nn.Linear(self.hidden_dim, 256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(128, horizon)
        )

    def forward(self, x):
        """x: (B, seq_len, input_dim) -> (B, horizon) cumulative predictions."""
        B, seq_len, _ = x.shape
        max_len = self.pos_encoding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding length {max_len}; "
                f"construct SeqModel with a larger max_len"
            )

        x = self.input_proj(x)                                              # (B,seq_len,H)
        x_conv = self.temporal_conv(x.transpose(1, 2)).transpose(1, 2)     # (B,seq_len,H)
        # Residual + positional encoding.
        x = x + x_conv + self.pos_encoding[:, :seq_len, :]

        h = self.transformer(x)                                             # (B,seq_len,H)

        global_pool = h.mean(dim=1)                                         # (B,H)
        q = self.pool_query.expand(B, -1, -1)                               # (B,1,H)
        h_norm = self.pool_ln(h)
        attn_pool, _ = self.pool_attn(q, h_norm, h_norm)                    # (B,1,H)
        attn_pool = attn_pool.squeeze(1)                                    # (B,H)

        ctx = self.fusion(torch.cat([global_pool, attn_pool], dim=1))       # (B,H)
        out = self.head(ctx)                                                # (B,horizon)
        # Per-step increments -> cumulative displacement.
        return torch.cumsum(out, dim=1)

class STModel(nn.Module):
    """
    Spatio-temporal trajectory model.

    Per frame: in_proj + edge-biased spatial self-attention over the N tokens
    (target + neighbours). Then the target token's (index 0) embeddings over
    time go through a temporal Transformer, mean + attention pooling, an MLP
    head and a cumulative sum over the horizon.

    Input channel layout expected by ``forward``: the LAST two channels of
    ``x`` are read as raw (unscaled) x, y positions, and channel
    ``idx_offense`` (when given and in range) as a raw offense flag. These
    channels are used to build the edge features.
    """
    def __init__(self, input_dim, horizon, hidden_dim=128, nheads=4, dropout=0.1, T_max=200, idx_offense=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.in_proj = nn.Linear(input_dim, hidden_dim)

        # Per-frame spatial attention with an additive edge bias (6 edge features).
        self.spatial_attn = RelPosSpatialSelfAttn(hidden_dim, heads=nheads, edge_dim=6, dropout=dropout)

        # Learned temporal positions + temporal Transformer.
        self.time_pos = nn.Parameter(torch.randn(1, T_max, hidden_dim) * 0.02)
        enc = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nheads,
                                         dim_feedforward=hidden_dim*4, dropout=dropout,
                                         activation='gelu', batch_first=True, norm_first=True)
        self.temporal_encoder = nn.TransformerEncoder(enc, num_layers=2)

        self.pool_ln = nn.LayerNorm(hidden_dim)
        self.pool_attn = nn.MultiheadAttention(hidden_dim, num_heads=nheads, batch_first=True, dropout=dropout)
        self.pool_query = nn.Parameter(torch.randn(1, 1, hidden_dim))

        self.head = nn.Sequential(
            nn.Linear(hidden_dim, 256), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(256, 128), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(128, horizon)
        )
        self.idx_offense = idx_offense

    def _edge_features(self, tokens):
        """
        Pairwise edge features for a batch of frames.

        tokens: (M, N, D) RAW per-frame inputs (before in_proj). The last two
            channels are read as x, y; channel ``self.idx_offense`` (if valid)
            as the offense flag.
        Returns (M, N, N, 6):
            [dx/30, dy/30, clamp(dist, 0, 30)/30, ux, uy, same_side],
            with dx = x_i - x_j and (ux, uy) the unit direction.
        """
        M, N, _ = tokens.shape
        xy = tokens[..., -2:]                                   # (M,N,2)
        dx = xy[:, :, None, 0] - xy[:, None, :, 0]              # (M,N,N)
        dy = xy[:, :, None, 1] - xy[:, None, :, 1]              # (M,N,N)
        dist = torch.sqrt(dx * dx + dy * dy + 1e-6)
        ux = dx / (dist + 1e-6)
        uy = dy / (dist + 1e-6)

        if (self.idx_offense is not None) and (0 <= self.idx_offense < tokens.shape[-1]):
            is_off = tokens[..., self.idx_offense]              # (M,N)
            same_side = (is_off[:, :, None] == is_off[:, None, :]).to(tokens.dtype).unsqueeze(-1)
        else:
            # No side information: treat every pair as same side.
            same_side = torch.ones(M, N, N, 1, device=tokens.device, dtype=tokens.dtype)

        # Clip/normalise distances for stability.
        d_feat = torch.clamp(dist, 0.0, 30.0) / 30.0
        return torch.cat([dx.unsqueeze(-1) / 30.0, dy.unsqueeze(-1) / 30.0,
                          d_feat.unsqueeze(-1), ux.unsqueeze(-1), uy.unsqueeze(-1),
                          same_side], dim=-1)

    def forward(self, x, spatial_mask=None):
        """
        x: (B, T, N, D) input windows.
        spatial_mask: optional (B, T, N), 1 = valid token, 0 = pad.
        Returns (B, horizon) cumulative predictions.
        """
        B, T, N, D = x.shape
        h = self.in_proj(x)  # (B,T,N,H)

        # Spatial attention is independent per frame, so fold time into the
        # batch and process all B*T frames in one call instead of a loop over T.
        frames = x.reshape(B * T, N, D)
        edge_attr = self._edge_features(frames)                                   # (B*T,N,N,6)
        m = spatial_mask.reshape(B * T, N) if spatial_mask is not None else None  # (B*T,N)
        z = self.spatial_attn(h.reshape(B * T, N, -1), edge_attr, key_padding_mask=m)
        z = z.reshape(B, T, N, -1)                                                # (B,T,N,H)
        tgt_seq = z[:, :, 0, :]                                                   # target token over time

        pos = self.time_pos[:, :T, :]
        zt = self.temporal_encoder(tgt_seq + pos)

        # Pooling: mean + learned-query attention, averaged.
        global_pool = zt.mean(dim=1)
        q = self.pool_query.expand(B, -1, -1)
        z_norm = self.pool_ln(zt)
        attn_pool, _ = self.pool_attn(q, z_norm, z_norm)
        attn_pool = attn_pool.squeeze(1)

        ctx = 0.5 * (global_pool + attn_pool)
        out = self.head(ctx)
        return torch.cumsum(out, dim=1)



def prepare_targets(batch_axis, max_h):
    """
    Zero-pad variable-length 1-D target sequences to a fixed horizon.

    Args:
        batch_axis: iterable of 1-D arrays, one per sample.
        max_h: horizon to pad to. Longer sequences are truncated to ``max_h``
            (np.pad would otherwise raise on a negative pad width).

    Returns:
        (targets, masks): float32 tensors of shape (n_samples, max_h); the mask
        is 1.0 on real steps and 0.0 on padding.
    """
    arrays = [np.asarray(a, dtype=np.float32)[:max_h] for a in batch_axis]
    targets = np.zeros((len(arrays), max_h), dtype=np.float32)
    masks = np.zeros((len(arrays), max_h), dtype=np.float32)
    for i, arr in enumerate(arrays):
        targets[i, :len(arr)] = arr
        masks[i, :len(arr)] = 1.0
    return torch.from_numpy(targets), torch.from_numpy(masks)

def train_model(X_train, y_train, X_val, y_val, input_dim, horizon, config):
    """
    Train a ``SeqModel`` with TemporalHuber loss, LR-on-plateau and early stopping.

    Args:
        X_train, X_val: sequences of (seq_len, input_dim) arrays.
        y_train, y_val: per-sample 1-D targets, padded/masked to ``horizon``.
        input_dim, horizon: model input width and number of predicted steps.
        config: needs DEVICE, LEARNING_RATE, BATCH_SIZE, EPOCHS, PATIENCE.

    Returns:
        (model, best_val_loss) with the best-validation weights loaded.
    """
    device = config.DEVICE
    model = SeqModel(input_dim, horizon).to(device)
    criterion = TemporalHuber(delta=0.5, time_decay=0.03)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=1e-5)
    # `verbose` is deprecated since PyTorch 2.2 (removed in newer releases);
    # the default is already silent, so it is not passed.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

    def build_batches(X, Y):
        """Pre-build CPU batches of (inputs, padded targets, target masks)."""
        batches = []
        B = config.BATCH_SIZE
        for i in range(0, len(X), B):
            end = min(i + B, len(X))
            xs = torch.tensor(np.stack(X[i:end]).astype(np.float32))
            ys, ms = prepare_targets([Y[j] for j in range(i, end)], horizon)
            batches.append((xs, ys, ms))
        return batches

    # NOTE: batches are built once and visited in the same order every epoch.
    tr_batches = build_batches(X_train, y_train)
    va_batches = build_batches(X_val,   y_val)

    best_loss, best_state, bad = float('inf'), None, 0
    for epoch in range(1, config.EPOCHS + 1):
        model.train()
        train_losses = []
        for bx, by, bm in tr_batches:
            bx, by, bm = bx.to(device), by.to(device), bm.to(device)
            pred = model(bx)
            loss = criterion(pred, by, bm)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_losses.append(loss.item())

        model.eval()
        val_losses = []
        with torch.no_grad():
            for bx, by, bm in va_batches:
                bx, by, bm = bx.to(device), by.to(device), bm.to(device)
                pred = model(bx)
                val_losses.append(criterion(pred, by, bm).item())

        trl, val = float(np.mean(train_losses)), float(np.mean(val_losses))
        scheduler.step(val)
        if epoch % 10 == 0:
            print(f"  Epoch {epoch}: train={trl:.4f}, val={val:.4f}")

        if val < best_loss:
            best_loss, bad = val, 0
            # Snapshot on CPU so later steps cannot overwrite the best weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= config.PATIENCE:
                print(f"  Early stop at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, best_loss


def train_model_st(X_train, M_train, y_train, lastpos_train,
                   X_val,   M_val,   y_val,   lastpos_val,
                   input_dim, horizon, config, axis='x', idx_offense=None):
    """
    Train one ``STModel`` for a single coordinate axis with early stopping.

    Args:
        X_train, X_val: sequences of (T, N, D) input windows.
        M_train, M_val: matching (T, N) spatial masks (1 = real token).
        y_train, y_val: per-sample 1-D displacement targets, padded/masked to
            ``horizon`` by ``prepare_targets``.
        lastpos_train, lastpos_val: (n,) last observed coordinate per sample,
            used by the loss's boundary penalty.
        input_dim, horizon: model input width and number of predicted steps.
        config: needs DEVICE, HIDDEN_DIM, LEARNING_RATE, BATCH_SIZE, EPOCHS,
            PATIENCE, MAX_SPEED_YPS, W_BOUNDARY, W_SPEED, W_JERK.
        axis: 'x' uses the field-length bounds, anything else the field width.
        idx_offense: channel of the raw offense flag, forwarded to STModel.

    Returns:
        (model, best_val_loss) with the best-validation weights loaded.
    """
    device = config.DEVICE
    model = STModel(input_dim, horizon, hidden_dim=config.HIDDEN_DIM, idx_offense=idx_offense).to(device)
    criterion = TemporalHuber(delta=0.5, time_decay=0.03)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=1e-5)
    # `verbose` is deprecated since PyTorch 2.2 (removed in newer releases);
    # the default is already silent, so it is not passed.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

    # Field bounds for the boundary penalty and per-frame displacement cap.
    if axis == 'x':
        lower, upper = 0.0, FIELD_LENGTH
    else:
        lower, upper = 0.0, FIELD_WIDTH
    step_cap = config.MAX_SPEED_YPS / FPS

    def build_batches(X, M, Y, L):
        """Pre-build CPU batches of (inputs, spatial masks, targets, target masks, last positions)."""
        batches = []
        B = config.BATCH_SIZE
        for i in range(0, len(X), B):
            end = min(i + B, len(X))
            xs = torch.tensor(np.stack(X[i:end]).astype(np.float32))          # (b,T,N,D)
            ms = torch.tensor(np.stack(M[i:end]).astype(np.float32))          # (b,T,N)
            ys, mm = prepare_targets([Y[j] for j in range(i, end)], horizon)  # (b,L),(b,L)
            lp = torch.tensor(L[i:end], dtype=torch.float32)                  # (b,)
            batches.append((xs, ms, ys, mm, lp))
        return batches

    # NOTE: batches are built once and visited in the same order every epoch.
    tr_batches = build_batches(X_train, M_train, y_train, lastpos_train)
    va_batches = build_batches(X_val,   M_val,   y_val,   lastpos_val)

    def batch_loss(bx, bm_sp, by, bm, blp):
        """Move one batch to the device and return the penalised loss."""
        bx, bm_sp, by, bm, blp = bx.to(device), bm_sp.to(device), by.to(device), bm.to(device), blp.to(device)
        pred = model(bx, spatial_mask=bm_sp)
        return criterion(
            pred, by, bm,
            last_pos=blp, lower=lower, upper=upper, step_cap=step_cap,
            w_boundary=config.W_BOUNDARY, w_speed=config.W_SPEED, w_jerk=config.W_JERK
        )

    best_loss, best_state, bad = float('inf'), None, 0
    for epoch in range(1, config.EPOCHS + 1):
        model.train()
        train_losses = []
        for batch in tr_batches:
            loss = batch_loss(*batch)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_losses.append(loss.item())

        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in va_batches:
                val_losses.append(batch_loss(*batch).item())

        trl, val = float(np.mean(train_losses)), float(np.mean(val_losses))
        scheduler.step(val)
        if epoch % 10 == 0:
            print(f"  Epoch {epoch}: train={trl:.4f}, val={val:.4f}")

        if val < best_loss:
            best_loss, bad = val, 0
            # Snapshot on CPU so later steps cannot overwrite the best weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= config.PATIENCE:
                print(f"  Early stop at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, best_loss


In [None]:
# -------------------------------
# Main pipeline (modified for a multi-seed ensemble)
# -------------------------------
class CFG(Config):
    """Run configuration: ``Config`` plus the seeds trained as an ensemble."""

    # Every seed gets a full K-fold run; add or change seeds here to resize the ensemble.
    SEEDS = [42, 19, 89, 64]

def train_mode(cfg):
    """
    Train the spatio-temporal ensemble over every seed x GroupKFold fold.

    Loads the weekly train CSVs and builds ST sequences. Then, per (seed, fold):
    fits a StandardScaler on the training fold, scales the (T, N, D) windows,
    appends the unscaled is_offense / x / y channels that STModel uses for its
    edge features, trains separate dx and dy models, reports validation RMSE
    and saves the fold's models + scaler. Finally writes the bundle metadata
    (feature columns, direction map, fold RMSEs, fold assignment).
    """
    print("="*80)
    print(f"TRAIN MODE: seeds={getattr(cfg,'SEEDS',[cfg.SEED])}, folds={cfg.N_FOLDS}")
    print("="*80)

    # [1] Load the full training set (only train mode needs it).
    train_input_files  = [cfg.DATA_DIR / f"train/input_2023_w{w:02d}.csv"  for w in range(1, 19)]
    train_output_files = [cfg.DATA_DIR / f"train/output_2023_w{w:02d}.csv" for w in range(1, 19)]
    train_input  = pd.concat([pd.read_csv(f) for f in train_input_files  if f.exists()], ignore_index=True)
    train_output = pd.concat([pd.read_csv(f) for f in train_output_files if f.exists()], ignore_index=True)

    # [2] Feature engineering + sequence construction.
    seqs, smasks, tdx, tdy, tfids, seq_meta, feature_cols, dir_map = prepare_sequences_spatiotemporal(
        train_input, output_df=train_output, is_training=True,
        window_size=cfg.WINDOW_SIZE, K_neighbors=cfg.N_NEIGHBORS,
        neighbor_radius=cfg.NEIGHBOR_RADIUS, adaptive_neighbors=cfg.ADAPTIVE_NEIGHBORS
    )
    sequences = list(seqs); targets_dx = list(tdx); targets_dy = list(tdy)

    idx_x = feature_cols.index('x')
    idx_y = feature_cols.index('y')
    idx_off = feature_cols.index('is_offense') if 'is_offense' in feature_cols else None

    # Last observed position of the target token (index 0) in the unified
    # frame; train_model_st feeds it to the loss's boundary penalty.
    last_x_all = np.array([s[-1, 0, idx_x] for s in seqs], dtype=np.float32)
    last_y_all = np.array([s[-1, 0, idx_y] for s in seqs], dtype=np.float32)

    # --- per-fold helpers (defined once instead of on every fold) ---
    def _stack_for_fit(X):
        """Flatten a list of (T, N, D) windows into one (sum of T*N, D) matrix."""
        return np.vstack([s.reshape(s.shape[0] * s.shape[1], s.shape[2]) for s in X])

    def _transform(X, scaler):
        """Scale every token of each (T, N, D) window; returns (n, T, N, D) float32."""
        out = []
        for s in X:
            T, N, D = s.shape
            out.append(scaler.transform(s.reshape(T * N, D)).reshape(T, N, D).astype(np.float32))
        return np.stack(out)

    def _augment_with_raw(X_sc, X_raw):
        """
        Append unscaled channels: raw is_offense (when present), then raw x, y,
        so the last two channels are always raw x, y. Returns
        (X_aug, index of the raw is_offense channel or None).
        """
        out, raw_off_idx = [], None
        for s_sc, s_raw in zip(X_sc, X_raw):
            parts = [s_sc]
            if idx_off is not None:
                parts.append(s_raw[..., [idx_off]].astype(np.float32))   # (T,N,1)
                raw_off_idx = s_sc.shape[-1]
            parts.append(s_raw[..., [idx_x, idx_y]].astype(np.float32))  # (T,N,2)
            out.append(np.concatenate(parts, axis=-1))
        return np.stack(out), raw_off_idx

    def _predict_batched(model, X, M):
        """Run ``model`` over (X, M) in chunks of cfg.BATCH_SIZE so a large fold need not fit on the device at once."""
        preds = []
        with torch.no_grad():
            for i in range(0, len(X), cfg.BATCH_SIZE):
                xb = torch.tensor(X[i:i + cfg.BATCH_SIZE]).to(cfg.DEVICE)
                mb = torch.tensor(np.stack(M[i:i + cfg.BATCH_SIZE]).astype(np.float32)).to(cfg.DEVICE)
                preds.append(model(xb, spatial_mask=mb).cpu().numpy())
        return np.concatenate(preds, axis=0)

    groups = np.array([d['game_id'] for d in seq_meta])
    fold_assign_rows = []
    fold_rmse_list = []
    seeds = getattr(cfg, "SEEDS", [cfg.SEED])
    for seed in seeds:
        set_seed(seed)
        gkf = GroupKFold(n_splits=cfg.N_FOLDS)
        for fold, (tr, va) in enumerate(gkf.split(sequences, groups=groups), 1):
            print(f"\n=== Seed {seed} | Fold {fold}/{cfg.N_FOLDS} ===")
            # Record sample -> fold only once (first seed).
            if seed == seeds[0]:
                fold_assign_rows.extend(
                    {
                        "game_id": seq_meta[idx]['game_id'],
                        "play_id": seq_meta[idx]['play_id'],
                        "nfl_id": seq_meta[idx]['nfl_id'],
                        "fold": fold,
                    }
                    for idx in va
                )

            X_tr, X_va = [seqs[i] for i in tr], [seqs[i] for i in va]
            M_tr, M_va = [smasks[i] for i in tr], [smasks[i] for i in va]

            # NOTE: the scaler is fit on all T*N token rows of the training
            # fold, including zero-padded neighbour slots.
            scaler = StandardScaler().fit(_stack_for_fit(X_tr))
            X_tr_sc = _transform(X_tr, scaler)
            X_va_sc = _transform(X_va, scaler)

            X_tr_aug, idx_off_raw = _augment_with_raw(X_tr_sc, X_tr)
            X_va_aug, _           = _augment_with_raw(X_va_sc, X_va)

            mx, loss_x = train_model_st(
                X_tr_aug, M_tr, [tdx[i] for i in tr], last_x_all[tr],
                X_va_aug, M_va, [tdx[i] for i in va], last_x_all[va],
                X_tr_aug.shape[-1], cfg.MAX_FUTURE_HORIZON, cfg, axis='x',
                idx_offense=idx_off_raw  # channel of the RAW is_offense flag
            )
            my, loss_y = train_model_st(
                X_tr_aug, M_tr, [tdy[i] for i in tr], last_y_all[tr],
                X_va_aug, M_va, [tdy[i] for i in va], last_y_all[va],
                X_tr_aug.shape[-1], cfg.MAX_FUTURE_HORIZON, cfg, axis='y',
                idx_offense=idx_off_raw
            )

            # Validation RMSE, predicted in batches.
            mx.eval(); my.eval()
            pred_dx = _predict_batched(mx, X_va_aug, M_va)
            pred_dy = _predict_batched(my, X_va_aug, M_va)

            y_va_dx = [targets_dx[i] for i in va]
            y_va_dy = [targets_dy[i] for i in va]
            rmse = compute_val_rmse(pred_dx, pred_dy, y_va_dx, y_va_dy, cfg.MAX_FUTURE_HORIZON)
            fold_rmse_list.append(rmse)
            print(f"[Fold {fold} | Seed {seed}] val_huber: dx={loss_x:.4f}, dy={loss_y:.4f} | RMSE={rmse:.5f}")

            # Save this fold/seed's models and scaler.
            save_fold_artifacts(cfg, seed, fold, mx, my, scaler)

    print("\n" + "="*80)
    print("VALID RMSE by fold (all seeds × folds)")
    print("="*80)
    for i, r in enumerate(fold_rmse_list, 1):
        print(f"{i:02d}: RMSE = {r:.5f}")
    rmse_mean = float(np.mean(fold_rmse_list)) if len(fold_rmse_list) else float("nan")
    rmse_std  = float(np.std(fold_rmse_list))  if len(fold_rmse_list) else float("nan")
    print("-"*80)
    print(f"Mean RMSE: {rmse_mean:.5f} | Std: {rmse_std:.5f}")
    print("="*80)

    # Save bundle metadata.
    fold_assign_df = pd.DataFrame(fold_assign_rows) if fold_assign_rows else None
    save_meta(cfg, feature_cols, dir_map, fold_rmse_list, fold_assign=fold_assign_df)

    print("\nBundle 已写出到:", cfg.MODEL_BUNDLE_DIR_TRAIN.resolve())

def sub_mode(cfg):
    """Inference-only mode: ensemble every bundled (dx-model, dy-model, scaler)
    triple on the test set and write ``submission.csv``.

    Steps:
      1. read ``test_input.csv`` and ``test.csv`` from ``cfg.DATA_DIR``;
      2. load bundle metadata and the per-member model / scaler paths;
      3. build spatio-temporal sequences of shape (T, N, D) plus spatial
         masks, one per entry of ``test_meta``;
      4. for each ensemble member: standardise with that member's scaler,
         append the raw (unscaled) channels, run both displacement models;
      5. average the members, add the predicted displacements to the last
         observed position of token 0 (unified frame), clip to the field,
         map back to the original play direction and emit one row per
         ``frame_id`` requested in ``test.csv``.

    Args:
        cfg: run configuration. Reads ``DATA_DIR``, ``DEVICE``, ``HIDDEN_DIM``,
            ``N_NEIGHBORS``, ``NEIGHBOR_RADIUS`` and ``ADAPTIVE_NEIGHBORS``,
            plus whatever ``discover_bundle_for_sub`` reads.

    Raises:
        FileNotFoundError: the bundle contains no models.
        ValueError: the bundle's model/scaler lists differ in length, no test
            sequences were built, or the test feature columns do not match
            the bundle's ``feature_cols``.
    """
    print("="*80)
    print("SUB MODE: 仅推理（从 bundle 读取模型与scaler）")
    print("="*80)

    # [1] Read the test data.
    test_input    = pd.read_csv(cfg.DATA_DIR / "test_input.csv")
    test_template = pd.read_csv(cfg.DATA_DIR / "test.csv")

    # [2] Bundle metadata plus the per-member model / scaler file lists.
    root, meta, feat_cols, model_x_paths, model_y_paths, scaler_paths = discover_bundle_for_sub(cfg)
    n_members = len(model_x_paths)
    if n_members == 0:
        raise FileNotFoundError(f"No models found in bundle: {root}")
    # zip() below would silently drop ensemble members if these disagree.
    if len(model_y_paths) != n_members or len(scaler_paths) != n_members:
        raise ValueError(
            f"Bundle lists differ in length: {n_members} dx-models, "
            f"{len(model_y_paths)} dy-models, {len(scaler_paths)} scalers"
        )

    # [3] Build (T, N, D) sequences and spatial masks with the same builder as training.
    test_seqs, test_smasks, test_meta, feat_cols_t, _dir_map_test = prepare_sequences_spatiotemporal(
        test_input, test_template=test_template, is_training=False,
        window_size=meta["WINDOW_SIZE"], K_neighbors=cfg.N_NEIGHBORS,
        neighbor_radius=cfg.NEIGHBOR_RADIUS, adaptive_neighbors=cfg.ADAPTIVE_NEIGHBORS
    )
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if feat_cols_t != feat_cols:
        raise ValueError("Train/Test 特征列不一致，请检查 bundle 的 feature_cols")
    if len(test_seqs) == 0:
        raise ValueError("No test sequences were built from test_input.csv")

    # [3.1] Feature indices, and the last observed position of token 0
    # (index 0 on the N axis) in the unified frame.
    idx_x = feat_cols.index('x'); idx_y = feat_cols.index('y')
    idx_off = feat_cols.index('is_offense') if 'is_offense' in feat_cols else None

    X_raw = np.stack(list(test_seqs))  # (B, T, N, D); all sequences must share one shape
    B, T, N, D = X_raw.shape
    x_last_uni = X_raw[:, -1, 0, idx_x].astype(np.float32)
    y_last_uni = X_raw[:, -1, 0, idx_y].astype(np.float32)

    # [4] Member-independent inputs, built once instead of once per model.
    # Unscaled channels appended after the standardised features:
    #   [raw is_offense (only if present)] + [raw x, raw y]
    # so raw x / raw y are always the last two channels.
    raw_parts = [X_raw[..., [idx_off]]] if idx_off is not None else []
    raw_parts.append(X_raw[..., [idx_x, idx_y]])
    X_raw_tail = np.concatenate(raw_parts, axis=-1).astype(np.float32)
    # Channel index of the raw is_offense column inside the augmented input.
    idx_off_raw = D if idx_off is not None else None
    input_dim = D + X_raw_tail.shape[-1]
    horizon = meta["MAX_FUTURE_HORIZON"]
    M_t = torch.tensor(np.stack(test_smasks).astype(np.float32)).to(cfg.DEVICE)
    # Standardise all tokens of all timesteps together (one row per token).
    X_flat = X_raw.reshape(B * T * N, D)

    # [4.1] Run every (model_dx, model_dy, scaler) member and collect predictions.
    all_preds_dx, all_preds_dy = [], []
    for mx_path, my_path, sc_path in zip(model_x_paths, model_y_paths, scaler_paths):
        # NOTE: unpickling executes code; only load bundles you produced yourself.
        scaler = _load_pickle(sc_path)
        X_sc = scaler.transform(X_flat).reshape(B, T, N, D).astype(np.float32)
        X_t = torch.from_numpy(np.concatenate([X_sc, X_raw_tail], axis=-1)).to(cfg.DEVICE)

        m_dx = STModel(input_dim, horizon, hidden_dim=cfg.HIDDEN_DIM, idx_offense=idx_off_raw).to(cfg.DEVICE)
        m_dy = STModel(input_dim, horizon, hidden_dim=cfg.HIDDEN_DIM, idx_offense=idx_off_raw).to(cfg.DEVICE)
        m_dx.load_state_dict(torch.load(mx_path, map_location=cfg.DEVICE))
        m_dy.load_state_dict(torch.load(my_path, map_location=cfg.DEVICE))
        m_dx.eval(); m_dy.eval()
        with torch.no_grad():
            all_preds_dx.append(m_dx(X_t, spatial_mask=M_t).cpu().numpy())
            all_preds_dy.append(m_dy(X_t, spatial_mask=M_t).cpu().numpy())

    # Plain mean over ensemble members; result is indexed as [sample, step].
    ens_dx = np.mean(all_preds_dx, axis=0)
    ens_dy = np.mean(all_preds_dy, axis=0)
    H = ens_dx.shape[1]

    # [5] Assemble the submission, mapping unified coordinates back to the
    # original play direction.
    rows = []
    n_missing = 0
    tt_idx = test_template.set_index(['game_id','play_id','nfl_id']).sort_index()
    for i, meta_row in enumerate(test_meta):
        gid = meta_row['game_id']; pid = meta_row['play_id']; nid = meta_row['nfl_id']
        play_is_right = (meta_row['play_direction'] == 'right')
        try:
            fids = tt_idx.loc[(gid, pid, nid), 'frame_id']
        except KeyError:
            n_missing += 1
            continue
        fids = fids.sort_values().tolist() if isinstance(fids, pd.Series) else [int(fids)]
        for t, fid in enumerate(fids):
            # Frames beyond the model horizon reuse the last predicted step.
            tt = min(t, H - 1)
            x_uni = np.clip(x_last_uni[i] + ens_dx[i, tt], 0, FIELD_LENGTH)
            y_uni = np.clip(y_last_uni[i] + ens_dy[i, tt], 0, FIELD_WIDTH)
            x_out, y_out = invert_to_original_direction(x_uni, y_uni, play_is_right)
            rows.append({'id': f"{gid}_{pid}_{nid}_{int(fid)}", 'x': x_out, 'y': y_out})

    if n_missing:
        print(f"[warn] {n_missing} test sequence(s) had no matching "
              f"(game_id, play_id, nfl_id) in test.csv and were skipped")

    sub = pd.DataFrame(rows)
    sub.to_csv("submission.csv", index=False)
    print(f"submission.csv 生成完毕，行数={len(sub)}")




def select_bundle_dir(candidates):
    """Pick the model-bundle directory from a collection of input directories.

    Selection is deterministic (candidates are sorted first):
      * the first candidate whose name contains "bundle" (case-insensitive);
      * otherwise the only candidate, if there is exactly one.

    Args:
        candidates: iterable of ``str`` / ``Path`` directory paths.

    Returns:
        The selected directory as a ``Path``.

    Raises:
        FileNotFoundError: there are no candidates, or several candidates and
            none of them is recognisable as a bundle.
    """
    paths = sorted(Path(c) for c in candidates)
    if not paths:
        raise FileNotFoundError(
            "No input directories found; set MODEL_BUNDLE_DIR to the bundle directory"
        )
    bundles = [p for p in paths if "bundle" in p.name.lower()]
    if bundles:
        return bundles[0]
    if len(paths) == 1:
        return paths[0]
    raise FileNotFoundError(
        "Cannot auto-detect the bundle directory among "
        f"{[str(p) for p in paths]}; set MODEL_BUNDLE_DIR explicitly"
    )


def main():
    """Entry point: run training or submission inference according to the
    module-level TRAIN / SUB flags (exactly one must be 1)."""
    cfg = CFG()
    if TRAIN == 1 and SUB == 0:
        cfg.MODEL_BUNDLE_DIR_TRAIN = cfg.OUTPUT_DIR / "bundle"
        train_mode(cfg)
    elif TRAIN == 0 and SUB == 1:
        # Attach the uploaded bundle dataset via Kaggle "Add Input" and point
        # MODEL_BUNDLE_DIR at it (or hard-code it in CFG).
        if not cfg.MODEL_BUNDLE_DIR_SUB:
            # Fallback: auto-detect among the directories under /kaggle/input.
            input_dirs = [p for p in glob("/kaggle/input/*") if Path(p).is_dir()]
            cfg.MODEL_BUNDLE_DIR_SUB = select_bundle_dir(input_dirs)
            print("自动选择 MODEL_BUNDLE_DIR_SUB =", cfg.MODEL_BUNDLE_DIR_SUB)
        sub_mode(cfg)
    else:
        raise ValueError("Set exactly one of TRAIN=1 or SUB=1")


if __name__ == "__main__":
    main()

