In [None]:
# -*- coding: utf-8 -*-
from __future__ import annotations

import json
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import StandardScaler
from tqdm.auto import tqdm

# Optional GPU tabular libs
try:
    import xgboost as xgb  # type: ignore
except Exception:  # pragma: no cover
    xgb = None
try:
    from catboost import CatBoostRegressor  # type: ignore
except Exception:  # pragma: no cover
    CatBoostRegressor = None

import joblib

# ---------------------------------------------------------------------------
# Logging & constants
# ---------------------------------------------------------------------------
logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.INFO)
log = logging.getLogger("BDB2026")

# Unit conversion factor and tracking sample rate (frames per second); FPS turns
# per-frame differences into per-second rates in the feature code.
YARDS_TO_METERS = 0.9144
FPS = 10.0
# Field extent used when mirroring coordinates: x runs along FIELD_LENGTH,
# y across FIELD_WIDTH.
FIELD_LENGTH = 120.0
FIELD_WIDTH = 53.3

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
@dataclass
class Config:
    """Run configuration: paths, run mode, model choice, feature groups and hyper-parameters.

    Path fields accept ``pathlib.Path`` or plain strings / path-like objects;
    ``__post_init__`` normalises them to ``Path`` and creates the output and
    model directories.
    """

    # paths
    # Root of the Kaggle competition dataset; adjust to your own data location.
    DATA_DIR: Path = Path("/kaggle/input/nfl-big-data-bowl-2026-prediction")
    OUTPUT_DIR: Path = Path("/kaggle/working/")
    MODELS_DIR: Path = Path("/kaggle/input/model-cv5/ref_models")

    # run mode
    TRAIN: bool = False
    SUB: bool = True

    # model choice: GRU_RES | DYN | XGB | CAT
    MODEL_NAME: str = "GRU_RES"

    # feature groups (lean default)
    FEATURE_GROUPS: List[str] = field(
        default_factory=lambda: [
            "distance_rate",
            "target_alignment",
            "multi_window_rolling",
            "extended_lags",
            "velocity_changes",
            "field_position",
            "time_features",
            # NOTE: removed by default: role_specific, jerk_features, curvature_land_features
        ]
    )

    # learning
    SEED: int = 42
    N_FOLDS: int = 5
    BATCH_SIZE: int = 256
    EPOCHS: int = 125
    PATIENCE: int = 30
    LEARNING_RATE: float = 1e-3

    # sequence
    WINDOW_SIZE: int = 10
    HIDDEN_DIM: int = 128
    # Prediction horizon (model output length); presumably has to match saved models.
    MAX_FUTURE_HORIZON: int = 94  # do not change

    # torch
    DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __post_init__(self):
        # Coerce str / os.PathLike to Path: the mkdir calls below (and `/` joins
        # elsewhere) would otherwise fail with AttributeError on plain strings.
        self.DATA_DIR = Path(self.DATA_DIR)
        self.OUTPUT_DIR = Path(self.OUTPUT_DIR)
        self.MODELS_DIR = Path(self.MODELS_DIR)
        self.OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
        # NOTE(review): the default MODELS_DIR lives under /kaggle/input, which is
        # typically read-only on Kaggle; with exist_ok=True this is a no-op when
        # the directory already exists.
        self.MODELS_DIR.mkdir(exist_ok=True, parents=True)


# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------

def set_seed(seed: int = 42) -> None:
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs.

    ``PYTHONHASHSEED`` is exported as well. Python reads that variable at
    interpreter start-up, so it only affects child processes launched later,
    not string hashing in the current process.
    """
    import random

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)


def wrap_angle_deg(s: pd.Series | np.ndarray) -> pd.Series | np.ndarray:
    """Map angles in degrees onto the half-open interval [-180, 180)."""
    shifted = s + 180.0
    return shifted % 360.0 - 180.0


def unify_left_direction(df: pd.DataFrame) -> pd.DataFrame:
    """Mirror rightward plays so every sample is 'left' oriented.

    For rows whose ``play_direction`` equals ``"right"``: ``x``/``ball_land_x``
    are reflected across ``FIELD_LENGTH``, ``y``/``ball_land_y`` across
    ``FIELD_WIDTH``, and ``dir``/``o`` are rotated by 180 degrees (mod 360).
    Absent columns are skipped. Without a ``play_direction`` column the input
    object is returned as-is (no copy); otherwise a modified copy is returned.
    """
    if "play_direction" not in df.columns:
        return df

    mirrored = df.copy()
    is_right = mirrored["play_direction"].eq("right")

    reflections = (
        ("x", FIELD_LENGTH),
        ("y", FIELD_WIDTH),
        ("ball_land_x", FIELD_LENGTH),
        ("ball_land_y", FIELD_WIDTH),
    )
    for name, extent in reflections:
        if name in mirrored.columns:
            mirrored.loc[is_right, name] = extent - mirrored.loc[is_right, name]

    # A point reflection through the field centre is a 180-degree rotation,
    # so headings rotate by 180 as well.
    for name in ("dir", "o"):
        if name in mirrored.columns:
            mirrored.loc[is_right, name] = (mirrored.loc[is_right, name] + 180.0) % 360.0

    return mirrored


def invert_to_original_direction(x_u: float, y_u: float, play_dir_right: bool) -> Tuple[float, float]:
    """Map a point from the unified left-oriented frame back to the play's own frame.

    Rightward plays are reflected across the field length/width (the inverse of
    ``unify_left_direction`` for coordinates); leftward plays pass through.
    Both coordinates are returned as Python floats.
    """
    if play_dir_right:
        return float(FIELD_LENGTH - x_u), float(FIELD_WIDTH - y_u)
    return float(x_u), float(y_u)


# ---------------------------------------------------------------------------
# Feature engineering (leaned down)
# ---------------------------------------------------------------------------
class FeatureEngineer:
    """Modular, ablation-friendly feature builder.

    ``transform`` always adds base kinematic, role and ball-landing features,
    then runs every requested group from ``feature_creators``. Time-series
    groups work per player track, grouped by ``gcols`` (game_id, play_id,
    nfl_id), on rows sorted by those keys and ``frame_id``. Group creators
    add columns to the frame they receive and return ``(df, new_columns)``.
    """

    def __init__(self, feature_groups_to_create: Sequence[str]):
        # Per-player-track grouping keys shared by all time-series features.
        self.gcols = ["game_id", "play_id", "nfl_id"]
        self.active_groups = list(feature_groups_to_create)
        # Group name -> creator method.
        self.feature_creators = {
            "distance_rate": self._create_distance_rate_features,
            "target_alignment": self._create_target_alignment_features,
            "multi_window_rolling": self._create_multi_window_rolling_features,
            "extended_lags": self._create_extended_lag_features,
            "velocity_changes": self._create_velocity_change_features,
            "field_position": self._create_field_position_features,
            "time_features": self._create_time_features,
        }
        # Feature columns produced by the most recent ``transform`` call.
        self.created_feature_cols: List[str] = []

    @staticmethod
    def _height_to_feet(height_str) -> float:
        """Convert a "feet-inches" string such as "6-2" to feet; 6.0 when unparsable."""
        try:
            ft, inches = map(int, str(height_str).split("-"))
            return ft + inches / 12
        except (TypeError, ValueError):
            return 6.0

    def _create_basic_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with base kinematics, role flags and ball-landing geometry."""
        log.info("[FE] Base kinematics & roles…")
        out = df.copy()
        out["player_height_feet"] = out["player_height"].apply(self._height_to_feet)

        # Direction: dir from +x CCW
        # NOTE(review): NFL tracking data has historically documented `dir` as
        # clockwise from +y (i.e. vx = s*sin, vy = s*cos). Confirm against this
        # dataset's spec before changing: saved models depend on these features.
        dir_rad = np.deg2rad(out["dir"].fillna(0.0).astype("float32"))
        out["velocity_x"] = out["s"] * np.cos(dir_rad)
        out["velocity_y"] = out["s"] * np.sin(dir_rad)
        out["acceleration_x"] = out["a"] * np.cos(dir_rad)
        out["acceleration_y"] = out["a"] * np.sin(dir_rad)

        # Roles (kept as raw flags; no engineered role-specific features by default)
        out["is_offense"] = (out["player_side"] == "Offense").astype(np.int8)
        out["is_defense"] = (out["player_side"] == "Defense").astype(np.int8)
        out["is_receiver"] = (out["player_role"] == "Targeted Receiver").astype(np.int8)
        out["is_coverage"] = (out["player_role"] == "Defensive Coverage").astype(np.int8)
        out["is_passer"] = (out["player_role"] == "Passer").astype(np.int8)

        # Ball landing geometry: distance, unit vector towards the landing spot,
        # and the velocity component along that vector.
        if {"ball_land_x", "ball_land_y"}.issubset(out.columns):
            dx = out["ball_land_x"] - out["x"]
            dy = out["ball_land_y"] - out["y"]
            dist = np.hypot(dx, dy)
            out["distance_to_ball"] = dist
            inv = 1.0 / (dist + 1e-6)
            out["ball_direction_x"] = dx * inv
            out["ball_direction_y"] = dy * inv
            out["closing_speed"] = (
                out["velocity_x"] * out["ball_direction_x"]
                + out["velocity_y"] * out["ball_direction_y"]
            )

        base = [
            "x",
            "y",
            "s",
            "a",
            "o",
            "dir",
            "frame_id",
            "ball_land_x",
            "ball_land_y",
            "player_height_feet",
            "player_weight",
            "velocity_x",
            "velocity_y",
            "acceleration_x",
            "acceleration_y",
            "is_offense",
            "is_defense",
            "is_receiver",
            "is_coverage",
            "is_passer",
            "distance_to_ball",
            "ball_direction_x",
            "ball_direction_y",
            "closing_speed",
        ]
        self.created_feature_cols.extend([c for c in base if c in out.columns])
        return out

    # ---- feature groups ----
    def _create_distance_rate_features(self, df: pd.DataFrame):
        """Per-track first/second time derivatives of distance_to_ball, plus a clipped time-to-intercept."""
        new_cols: List[str] = []
        if "distance_to_ball" in df.columns:
            d = df.groupby(self.gcols)["distance_to_ball"].diff()
            df["d2ball_dt"] = d.fillna(0.0) * FPS
            df["d2ball_ddt"] = df.groupby(self.gcols)["d2ball_dt"].diff().fillna(0.0) * FPS
            df["time_to_intercept"] = (
                (df["distance_to_ball"] / (df["d2ball_dt"].abs() + 1e-3)).clip(0, 10)
            )
            new_cols = ["d2ball_dt", "d2ball_ddt", "time_to_intercept"]
        return df, new_cols

    def _create_target_alignment_features(self, df: pd.DataFrame):
        """Velocity components parallel and perpendicular to the ball-landing direction."""
        new_cols: List[str] = []
        need = {"ball_direction_x", "ball_direction_y", "velocity_x", "velocity_y"}
        if need.issubset(df.columns):
            df["velocity_alignment"] = (
                df["velocity_x"] * df["ball_direction_x"]
                + df["velocity_y"] * df["ball_direction_y"]
            )
            df["velocity_perpendicular"] = (
                df["velocity_x"] * (-df["ball_direction_y"]) + df["velocity_y"] * df["ball_direction_x"]
            )
            new_cols.extend(["velocity_alignment", "velocity_perpendicular"])
        return df, new_cols

    def _create_multi_window_rolling_features(self, df: pd.DataFrame):
        """Per-track rolling mean/std (windows 3, 5, 10; min_periods=1) of vx, vy, s and a.

        The std of a single-row window is undefined (NaN) and is filled with 0.
        Results are aligned back by row index, so ``df`` is assumed to have a
        unique index.
        """
        new_cols: List[str] = []
        key_levels = list(range(len(self.gcols)))
        for window in (3, 5, 10):
            for col in ("velocity_x", "velocity_y", "s", "a"):
                if col not in df.columns:
                    continue
                # Build the grouped rolling window once and reuse it for both statistics.
                roller = df.groupby(self.gcols)[col].rolling(window, min_periods=1)
                # groupby().rolling() prepends the group keys to the index; drop
                # them so the results align with ``df`` on its original row index.
                r_mean = roller.mean().reset_index(level=key_levels, drop=True)
                r_std = roller.std().reset_index(level=key_levels, drop=True)
                df[f"{col}_roll{window}"] = r_mean
                df[f"{col}_std{window}"] = r_std.fillna(0.0)
                new_cols.extend([f"{col}_roll{window}", f"{col}_std{window}"])
        return df, new_cols

    def _create_extended_lag_features(self, df: pd.DataFrame):
        """Per-track lags 1-5 of x, y, vx, vy; leading gaps take the track's first value."""
        new_cols: List[str] = []
        for lag in (1, 2, 3, 4, 5):
            for col in ("x", "y", "velocity_x", "velocity_y"):
                if col in df.columns:
                    g = df.groupby(self.gcols)[col]
                    lagv = g.shift(lag)
                    df[f"{col}_lag{lag}"] = lagv.fillna(g.transform("first"))
                    new_cols.append(f"{col}_lag{lag}")
        return df, new_cols

    def _create_velocity_change_features(self, df: pd.DataFrame):
        """Per-track frame-to-frame changes of vx, vy, speed and (wrapped) direction."""
        new_cols: List[str] = []
        if "velocity_x" in df.columns:
            df["velocity_x_change"] = df.groupby(self.gcols)["velocity_x"].diff().fillna(0.0)
            df["velocity_y_change"] = df.groupby(self.gcols)["velocity_y"].diff().fillna(0.0)
            df["speed_change"] = df.groupby(self.gcols)["s"].diff().fillna(0.0)
            d = df.groupby(self.gcols)["dir"].diff().fillna(0.0)
            # Wrap so that e.g. 359 -> 1 degree reads as +2, not -358.
            df["direction_change"] = wrap_angle_deg(d)
            new_cols = [
                "velocity_x_change",
                "velocity_y_change",
                "speed_change",
                "direction_change",
            ]
        return df, new_cols

    def _create_field_position_features(self, df: pd.DataFrame):
        """Distances to the nearest sideline and nearest end line.

        ``dist_from_left``/``dist_from_right`` are written as helper columns but
        are not reported as features.
        """
        df["dist_from_left"] = df["y"]
        df["dist_from_right"] = FIELD_WIDTH - df["y"]
        df["dist_from_sideline"] = np.minimum(df["dist_from_left"], df["dist_from_right"])
        df["dist_from_endzone"] = np.minimum(df["x"], FIELD_LENGTH - df["x"])
        return df, ["dist_from_sideline", "dist_from_endzone"]

    def _create_time_features(self, df: pd.DataFrame):
        """Frame position within each track, raw (0-based) and normalised to [0, 1]."""
        df["frames_elapsed"] = df.groupby(self.gcols).cumcount()
        df["normalized_time"] = df.groupby(self.gcols)["frames_elapsed"].transform(lambda x: x / (x.max() + 1e-9))
        return df, ["frames_elapsed", "normalized_time"]

    def transform(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
        """Build all features for ``df`` without modifying it.

        Returns:
            The feature frame (sorted by game_id, play_id, nfl_id, frame_id) and
            the sorted, de-duplicated list of feature column names.
        """
        # Reset per call: otherwise a second transform() (e.g. train then test)
        # would also report columns that only existed in an earlier frame.
        self.created_feature_cols = []
        # sort_values returns a new frame, so the caller's df is left untouched.
        df = df.sort_values(["game_id", "play_id", "nfl_id", "frame_id"])
        df = self._create_basic_features(df)

        log.info("[FE] Adding selected groups…")
        for group_name in self.active_groups:
            creator = self.feature_creators.get(group_name)
            if creator is None:
                log.warning("[FE] Unknown group: %s", group_name)
                continue
            df, new_cols = creator(df)
            self.created_feature_cols.extend(new_cols)
            log.info("  [+] %s (+%d cols)", group_name, len(new_cols))

        final_cols = sorted(set(self.created_feature_cols))
        log.info("[FE] Total features: %d", len(final_cols))
        return df, final_cols


# ---------------------------------------------------------------------------
# Sequence builder
# ---------------------------------------------------------------------------

def build_play_direction_map(df_in: pd.DataFrame) -> pd.Series:
    """Return a Series mapping (game_id, play_id) to ``play_direction``.

    Built from the distinct (game_id, play_id, play_direction) rows of
    ``df_in``; a play with several direction values would therefore appear
    more than once in the index.
    """
    play_keys = ["game_id", "play_id"]
    distinct = df_in[play_keys + ["play_direction"]].drop_duplicates()
    indexed = distinct.set_index(play_keys)
    return indexed["play_direction"]


def apply_direction_to_df(df: pd.DataFrame, dir_map: pd.Series) -> pd.DataFrame:
    """Attach ``play_direction`` from ``dir_map`` if missing, then mirror into the left-oriented frame.

    The attach step is a left merge on (game_id, play_id) validated as
    many-to-one, so a duplicated key in ``dir_map`` makes pandas raise.
    """
    has_direction = "play_direction" in df.columns
    if not has_direction:
        lookup = dir_map.reset_index()
        df = df.merge(lookup, on=["game_id", "play_id"], how="left", validate="many_to_one")
    return unify_left_direction(df)


def prepare_sequences(
    input_df: pd.DataFrame,
    output_df: pd.DataFrame | None = None,
    test_template: pd.DataFrame | None = None,
    is_training: bool = True,
    window_size: int = 10,
    feature_groups: Sequence[str] | None = None,
) -> Tuple:
    """Build fixed-length input windows (and, for training, targets) per player.

    All coordinates are first mirrored into the unified left-oriented frame.
    For every (game_id, play_id, nfl_id) target group, the last
    ``window_size`` frames of engineered input features form one sequence.

    Args:
        input_df: input tracking frames; must contain ``play_direction``.
        output_df: future positions to predict; required when ``is_training``.
        test_template: rows to predict; required when not ``is_training``.
        is_training: build targets and drop incomplete/NaN windows when True.
        window_size: number of trailing input frames per sequence.
        feature_groups: FeatureEngineer groups; ``None`` selects the lean default.

    Returns:
        Training: ``(sequences, targets_dx, targets_dy, targets_fids, seq_meta,
        feature_cols, dir_map)``. Targets are displacements from the last
        input ``x``/``y`` for each output frame, in the unified frame.
        Inference: ``(sequences, seq_meta, feature_cols, dir_map)``. Short
        windows are front-padded and NaNs are filled with the track's column
        means, then 0.
        In both cases the i-th entries of all returned lists refer to the same
        sample.

    Raises:
        AssertionError: if the required output/template frame is missing or
        the play_direction merge leaves NaNs.
    """
    log.info("=" * 80)
    log.info("PREPARING SEQUENCES (unified left frame)")
    log.info("=" * 80)
    log.info("window_size=%d", window_size)

    if feature_groups is None:
        feature_groups = [
            "distance_rate",
            "target_alignment",
            "multi_window_rolling",
            "extended_lags",
            "velocity_changes",
            "field_position",
            "time_features",
        ]

    key_cols = ["game_id", "play_id", "nfl_id"]
    dir_map = build_play_direction_map(input_df)
    input_df_u = unify_left_direction(input_df)

    if is_training:
        assert output_df is not None
        target_rows = apply_direction_to_df(output_df, dir_map)
        target_groups = target_rows[key_cols].drop_duplicates()
    else:
        assert test_template is not None
        if "play_direction" not in test_template.columns:
            dir_df = dir_map.reset_index()
            test_template = test_template.merge(dir_df, on=["game_id", "play_id"], how="left", validate="many_to_one")
        target_rows = test_template
        target_groups = target_rows[key_cols + ["play_direction"]].drop_duplicates()

    assert (
        target_rows[["game_id", "play_id", "play_direction"]].isna().sum().sum() == 0
    ), "play_direction merge failed"
    log.info("play_direction merge OK: %s", target_rows["play_direction"].value_counts(dropna=False).to_dict())

    fe = FeatureEngineer(feature_groups)
    processed_df, feature_cols = fe.transform(input_df_u)

    # ---- build sliding window per (gid,pid,nid)
    processed_df = processed_df.set_index(key_cols).sort_index()
    grouped = processed_df.groupby(level=key_cols)
    # Group the training targets once. Filtering all of target_rows with a
    # boolean mask for every player made this loop O(rows x players).
    target_grouped = target_rows.groupby(key_cols, sort=False) if is_training else None

    idx_x = feature_cols.index("x")
    idx_y = feature_cols.index("y")

    sequences: List[np.ndarray] = []
    targets_dx: List[np.ndarray] = []
    targets_dy: List[np.ndarray] = []
    targets_fids: List[np.ndarray] = []
    seq_meta: List[Dict] = []

    it = tqdm(list(target_groups.itertuples(index=False)), total=len(target_groups), desc="Create seqs")
    for row in it:
        gid, pid, nid = row[0], row[1], row[2]
        play_dir = row[3] if (not is_training and len(row) >= 4) else None
        key = (gid, pid, nid)
        try:
            group_df = grouped.get_group(key)
        except KeyError:
            continue

        input_window = group_df.tail(window_size)
        if len(input_window) < window_size:
            if is_training:
                continue
            # Inference: front-pad short tracks with NaN rows (filled below).
            pad_len = window_size - len(input_window)
            pad_df = pd.DataFrame(np.nan, index=range(pad_len), columns=input_window.columns)
            input_window = pd.concat([pad_df, input_window], ignore_index=True)

        # Fill gaps (including padding rows) with the track's per-column means.
        input_window = input_window.fillna(group_df.mean(numeric_only=True))
        seq = input_window[feature_cols].values
        if np.isnan(seq).any():
            if is_training:
                continue
            seq = np.nan_to_num(seq, nan=0.0)

        if is_training:
            try:
                out_grp = target_grouped.get_group(key)
            except KeyError:
                continue
            if out_grp.empty:
                continue
            out_grp = out_grp.sort_values("frame_id")
            last_x = seq[-1, idx_x]
            last_y = seq[-1, idx_y]
            targets_dx.append((out_grp["x"].values - last_x).astype(np.float32))
            targets_dy.append((out_grp["y"].values - last_y).astype(np.float32))
            targets_fids.append(out_grp["frame_id"].values.astype(np.int32))

        # Record the sample only once it is complete. Appending the sequence
        # before the target check (as before) could misalign sequences with
        # the targets and seq_meta.
        sequences.append(seq)
        seq_meta.append(
            {
                "game_id": gid,
                "play_id": pid,
                "nfl_id": nid,
                "frame_id": int(input_window.iloc[-1]["frame_id"]),
                "play_direction": (None if is_training else play_dir),
            }
        )

    log.info("Created %d sequences with %d features", len(sequences), len(feature_cols))

    if is_training:
        return sequences, targets_dx, targets_dy, targets_fids, seq_meta, feature_cols, dir_map
    return sequences, seq_meta, feature_cols, dir_map


# ---------------------------------------------------------------------------
# Torch models & training
# ---------------------------------------------------------------------------
class TemporalHuber(nn.Module):
    """Masked Huber loss with optional exponential down-weighting of later horizon steps.

    Computes the weighted mean ``sum(w * m * huber) / sum(w * m)``, where
    ``m`` is the validity mask and ``w_t = exp(-time_decay * t)`` for step
    ``t`` along dim 1 (``w = 1`` when ``time_decay <= 0``).
    """

    def __init__(self, delta: float = 0.5, time_decay: float = 0.03):
        super().__init__()
        self.delta = delta
        self.time_decay = time_decay

    def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Return the scalar loss.

        Assumes ``pred``, ``target`` and ``mask`` are (batch, horizon) with
        ``mask`` 1 on real steps and 0 on padding.
        """
        err = pred - target
        abs_err = torch.abs(err)
        huber = torch.where(
            abs_err <= self.delta, 0.5 * err * err, self.delta * (abs_err - 0.5 * self.delta)
        )
        weights = mask
        if self.time_decay > 0:
            L = pred.size(1)
            t = torch.arange(L, device=pred.device).float()
            # Fold the decay into the weights only. Scaling both huber and mask
            # by w (as before) weighted the numerator by w**2 but the
            # denominator by w, which is not a weighted mean.
            weights = mask * torch.exp(-self.time_decay * t).view(1, L)
        return (huber * weights).sum() / (weights.sum() + 1e-8)


class Residual(nn.Module):
    """Skip connection computing ``dropout(mod(x)) + proj(x)``.

    ``proj`` is the identity when the widths match and a learned linear map
    from ``dim_in`` to ``dim_out`` otherwise, so the two branches can be summed.
    """

    def __init__(self, mod: nn.Module, dim_in: int, dim_out: int, drop_prob: float = 0.0):
        super().__init__()
        self.mod = mod
        if dim_in == dim_out:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Linear(dim_in, dim_out)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.dropout(self.mod(x))
        shortcut = self.proj(x)
        return branch + shortcut


class GRUResidualModel(nn.Module):
    """Two-layer GRU encoder pooled by a learned attention query, plus a linear skip path.

    The head output and a linear readout of the last input step are summed
    per horizon step and then cumulatively summed, so the model predicts
    per-step increments and returns their running total.
    """

    def __init__(self, input_dim: int, hidden: int, horizon: int):
        super().__init__()
        # Submodule names/order are part of the saved state_dict: keep them.
        self.gru = nn.GRU(input_dim, hidden, num_layers=2, batch_first=True, dropout=0.1)
        self.ln = nn.LayerNorm(hidden)
        self.attn = nn.MultiheadAttention(hidden, num_heads=4, batch_first=True)
        self.q = nn.Parameter(torch.randn(1, 1, hidden))
        self.head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden, horizon),
        )
        # Skip path: direct linear map from the last time step's raw features.
        self.res_proj = nn.Linear(input_dim, horizon)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        encoded, _ = self.gru(x)
        normed = self.ln(encoded)
        query = self.q.expand(encoded.size(0), -1, -1)
        pooled, _ = self.attn(query, normed, normed)
        increments = self.head(pooled.squeeze(1)) + self.res_proj(x[:, -1, :])
        return torch.cumsum(increments, dim=1)


# --- Flexible dynamic-space NN (inspired by your baseline2) ---
class RNNBlock(nn.Module):
    """Batch-first GRU or LSTM that returns the full output sequence.

    ``rnn == "gru"`` (case-insensitive) selects a GRU; any other value an LSTM.
    Inter-layer dropout is only passed through when ``num_layers > 1``, and
    ``out_dim`` doubles when the RNN is bidirectional.
    """

    def __init__(self, input_dim: int, hidden_dim: int, rnn: str = "gru", num_layers: int = 1, dropout: float = 0.1, bidirectional: bool = False):
        super().__init__()
        cell = nn.GRU if rnn.lower() == "gru" else nn.LSTM
        layer_dropout = 0.0 if num_layers <= 1 else dropout
        self.rnn = cell(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=layer_dropout,
            bidirectional=bidirectional,
        )
        n_directions = 2 if bidirectional else 1
        self.out_dim = hidden_dim * n_directions

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        sequence_out = self.rnn(x)[0]
        return sequence_out


class Conv1DBlock(nn.Module):
    """Pre-norm depthwise-separable temporal convolution on (batch, time, channels) input.

    LayerNorm -> depthwise dilated Conv1d -> GELU -> pointwise Conv1d -> dropout.
    NOTE(review): padding is ``(kernel_size - 1) * dilation // 2``; when that
    product is odd the output is one step shorter than the input, which
    would break a residual sum. Odd kernels with any dilation are safe.
    """

    def __init__(self, dim: int, kernel_size: int = 3, dilation: int = 1, dropout: float = 0.1):
        super().__init__()
        same_pad = (kernel_size - 1) * dilation // 2
        self.pre_ln = nn.LayerNorm(dim)
        self.dw = nn.Conv1d(dim, dim, kernel_size, padding=same_pad, dilation=dilation, groups=dim)
        self.pw = nn.Conv1d(dim, dim, 1)
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv1d wants channels on dim 1, so swap time/channel axes around the convs.
        channels_first = self.pre_ln(x).transpose(1, 2)
        mixed = self.drop(self.pw(self.act(self.dw(channels_first))))
        return mixed.transpose(1, 2)


class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder layer (self-attention + GELU feed-forward).

    Both sub-layers carry their own residual connection inside this block.
    """

    def __init__(self, dim: int, nhead: int = 4, ff_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        inner = ff_mult * dim
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads=nhead, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, inner),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.ln1(x)
        attended = x + self.attn(normed, normed, normed)[0]
        return attended + self.ff(self.ln2(attended))


class SEBlock(nn.Module):
    """Squeeze-and-excitation over channels of a (batch, time, channels) tensor.

    Averages over dim 1, passes the result through a bottleneck MLP ending in
    a sigmoid, and rescales every time step by the resulting channel gates.
    """

    def __init__(self, dim: int, r: int = 4):
        super().__init__()
        bottleneck = max(1, dim // r)
        self.net = nn.Sequential(
            nn.Linear(dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, dim),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gates = self.net(x.mean(dim=1))
        return x * gates.unsqueeze(1)


class TransformerBlockWrapper(nn.Module):
    """Pass-through container that stores a module under the name ``block``.

    It adds no computation; the extra level only changes state_dict key names
    (``block.*``).
    """

    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        wrapped_out = self.block(x)
        return wrapped_out


class FlexibleSeqModel(nn.Module):
    """Configurable sequence encoder (stack of residual blocks) + pooling + MLP head.

    ``block_specs`` is a list of dicts whose ``"type"`` is one of ``rnn``,
    ``tcn``, ``transformer`` or ``se`` (case-insensitive); other keys tune the
    block. Each block is wrapped in ``Residual``. ``pooling`` is ``"attn"``
    (learned-query attention), ``"mean"``, or anything else (last time step).
    With ``predict_mode == "steps"`` the head's outputs are cumulatively summed.

    Raises:
        ValueError: for an unknown block type.
    """

    def __init__(
        self,
        input_dim: int,
        horizon: int,
        block_specs: List[Dict],
        dropout: float = 0.2,
        pooling: str = "attn",
        predict_mode: str = "steps",
        attn_pool_heads: int = 4,
    ):
        super().__init__()
        self.horizon = horizon
        self.predict_mode = predict_mode
        self.pooling = pooling

        width = input_dim
        stack: List[nn.Module] = []
        for spec in block_specs:
            kind = spec["type"].lower()
            res_drop = spec.get("res_dropout", 0.0)
            if kind == "rnn":
                inner = RNNBlock(
                    input_dim=width,
                    hidden_dim=spec.get("hidden", 128),
                    rnn=spec.get("rnn", "gru"),
                    num_layers=spec.get("layers", 1),
                    dropout=spec.get("dropout", 0.1),
                    bidirectional=spec.get("bidirectional", False),
                )
                stack.append(Residual(inner, width, inner.out_dim, drop_prob=res_drop))
                # Only RNN blocks can change the feature width.
                width = inner.out_dim
            elif kind == "tcn":
                inner = Conv1DBlock(
                    width,
                    kernel_size=spec.get("kernel", 3),
                    dilation=spec.get("dilation", 1),
                    dropout=spec.get("dropout", 0.1),
                )
                stack.append(Residual(inner, width, width, drop_prob=res_drop))
            elif kind == "transformer":
                inner = TransformerBlock(
                    width,
                    nhead=spec.get("nhead", 4),
                    ff_mult=spec.get("ff_mult", 4),
                    dropout=spec.get("dropout", 0.1),
                )
                stack.append(Residual(TransformerBlockWrapper(inner), width, width, drop_prob=res_drop))
            elif kind == "se":
                inner = SEBlock(width, r=spec.get("r", 4))
                stack.append(Residual(inner, width, width, drop_prob=res_drop))
            else:
                raise ValueError(f"Unknown block type: {kind}")
        self.blocks = nn.ModuleList(stack)

        # Every pooling mode normalises; only "attn" adds the attention pooler
        # and its learned query vector.
        self.pool_ln = nn.LayerNorm(width)
        if pooling == "attn":
            self.pool_attn = nn.MultiheadAttention(width, num_heads=attn_pool_heads, batch_first=True)
            self.pool_vec = nn.Parameter(torch.randn(1, 1, width))

        self.head = nn.Sequential(nn.Linear(width, 128), nn.GELU(), nn.Dropout(dropout), nn.Linear(128, horizon))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x
        for layer in self.blocks:
            h = layer(h)

        if self.pooling == "attn":
            batch, _, _ = h.shape
            query = self.pool_vec.expand(batch, -1, -1)
            keys = self.pool_ln(h)
            pooled = self.pool_attn(query, keys, keys)[0].squeeze(1)
        elif self.pooling == "mean":
            pooled = self.pool_ln(h).mean(dim=1)
        else:
            pooled = self.pool_ln(h[:, -1, :])

        prediction = self.head(pooled)
        return torch.cumsum(prediction, dim=1) if self.predict_mode == "steps" else prediction


# ---------------------------------------------------------------------------
# Training helpers
# ---------------------------------------------------------------------------
def _masked_mse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> float:
    # pred/target/mask: (B, L)
    se = (pred - target) ** 2
    se = se * mask
    denom = mask.sum().clamp_min(1e-8)
    return (se.sum() / denom).item()

def masked_rmse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> float:
    """Root of ``_masked_mse``: RMSE over the masked-in positions, as a Python float."""
    mse = _masked_mse(pred, target, mask)
    return float(np.sqrt(mse))

def kaggle_combo_rmse(
    pred_dx: torch.Tensor, pred_dy: torch.Tensor,
    true_dx: torch.Tensor, true_dy: torch.Tensor,
    mask: torch.Tensor
) -> float:
    """Combined masked RMSE over both axes, as used for competition scoring.

    Official scoring: sqrt(0.5 * (MSE_x + MSE_y)). The two MSEs are averaged
    first and the square root is taken afterwards; this is not the mean of
    the per-axis RMSEs.
    """
    mse_x = _masked_mse(pred_dx, true_dx, mask)
    mse_y = _masked_mse(pred_dy, true_dy, mask)
    mean_mse = 0.5 * (mse_x + mse_y)
    return float(np.sqrt(mean_mse))


def prepare_targets(batch_axis: Sequence[np.ndarray], max_h: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Zero-pad variable-length 1-D target arrays to ``max_h`` and build validity masks.

    Args:
        batch_axis: one 1-D target array per sample, each of length <= ``max_h``.
        max_h: padded length (the prediction horizon).

    Returns:
        ``(targets, masks)``: float32 tensors of shape ``(len(batch_axis), max_h)``.
        Masks are 1.0 on real steps and 0.0 on padding. An empty ``batch_axis``
        yields two ``(0, max_h)`` tensors; previously ``torch.stack([])`` raised.

    Raises:
        ValueError: if an array is longer than ``max_h``.
    """
    n_samples = len(batch_axis)
    # Preallocate once instead of padding and converting each row separately.
    targets = np.zeros((n_samples, max_h), dtype=np.float32)
    masks = np.zeros((n_samples, max_h), dtype=np.float32)
    for row, arr in enumerate(batch_axis):
        values = np.asarray(arr)
        length = len(values)
        if length > max_h:
            raise ValueError(f"target of length {length} exceeds horizon max_h={max_h}")
        targets[row, :length] = values
        masks[row, :length] = 1.0
    return torch.from_numpy(targets), torch.from_numpy(masks)


def build_batches(X: List[np.ndarray], Y: List[np.ndarray], batch_size: int, horizon: int) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Split ``(X, Y)`` into consecutive mini-batches of at most ``batch_size`` samples.

    Each batch is ``(inputs, targets, masks)``: inputs are the stacked
    windows as float32; targets and masks come from ``prepare_targets`` padded
    to ``horizon``. Sample order is preserved (no shuffling), and batching is
    driven by ``len(X)``.
    """
    n_samples = len(X)

    def _make_batch(start: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        stop = min(start + batch_size, n_samples)
        inputs = torch.tensor(np.stack(X[start:stop]).astype(np.float32))
        targets, masks = prepare_targets(list(Y[start:stop]), horizon)
        return inputs, targets, masks

    return [_make_batch(start) for start in range(0, n_samples, batch_size)]


def train_torch_axis_model(
    X_tr: np.ndarray,
    y_tr: List[np.ndarray],
    X_va: np.ndarray,
    y_va: List[np.ndarray],
    input_dim: int,
    horizon: int,
    cfg: Config,
    model_name: str,
    axis_name: str = "dx",   # label used only in log messages
) -> Tuple[nn.Module, float, float, float]:
    """Train one per-axis sequence model (GRU_RES or DYN) with early stopping.

    Targets are ragged per-sample arrays padded to ``horizon`` with masks.
    Training uses TemporalHuber, AdamW and a CosineAnnealingWarmRestarts
    schedule stepped every batch. After training, the weights with the lowest
    validation Huber loss are restored.

    Returns:
        ``(model, best_val_huber, last_epoch_train_loss, last_lr)``.
        ``best_val_huber`` stays ``inf`` if no epoch produced a finite validation loss.

    Raises:
        ValueError: if ``model_name`` is neither "GRU_RES" nor "DYN".
    """
    device = cfg.DEVICE
    if model_name == "GRU_RES":
        model = GRUResidualModel(input_dim=input_dim, hidden=cfg.HIDDEN_DIM, horizon=horizon).to(device)
    elif model_name == "DYN":
        block_specs = [
            {"type": "rnn", "rnn": "gru", "hidden": 128, "layers": 1},
            {"type": "tcn", "kernel": 3, "dilation": 1},
            {"type": "transformer", "nhead": 4, "ff_mult": 4},
            {"type": "se", "r": 4},
        ]
        model = FlexibleSeqModel(input_dim=input_dim, horizon=horizon, block_specs=block_specs, pooling="attn").to(device)
    else:
        raise ValueError("Torch training only for GRU_RES or DYN")

    criterion = TemporalHuber(delta=0.5, time_decay=0.03)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.LEARNING_RATE, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)

    # NOTE: batches are built once, in input order, and reused every epoch (no shuffling).
    tr_batches = build_batches(X_tr, y_tr, cfg.BATCH_SIZE, horizon)
    va_batches = build_batches(X_va, y_va, cfg.BATCH_SIZE, horizon)
    n_tr_steps = max(1, len(tr_batches))

    best_loss, best_state, bad = float("inf"), None, 0
    last_train_loss, last_lr = float("inf"), cfg.LEARNING_RATE

    for epoch in range(1, cfg.EPOCHS + 1):
        model.train()
        train_losses = []
        for step, (bx, by, bm) in enumerate(tr_batches):
            bx, by, bm = bx.to(device), by.to(device), bm.to(device)
            pred = model(bx)
            loss = criterion(pred, by, bm)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # A fractional epoch makes the cosine schedule advance every batch.
            scheduler.step(epoch - 1 + step / n_tr_steps)
            train_losses.append(loss.item())

        # ---- validation: Huber loss + masked RMSE for this axis
        model.eval()
        val_losses = []
        sq_err_sum, valid_count = 0.0, 0.0
        with torch.no_grad():
            for bx, by, bm in va_batches:
                bx, by, bm = bx.to(device), by.to(device), bm.to(device)
                pred = model(bx)
                val_losses.append(criterion(pred, by, bm).item())
                # Accumulate over the whole validation set. Averaging per-batch
                # RMSEs (as before) is not the set RMSE: it weights the smaller
                # last batch's samples more and averages square roots.
                sq_err_sum += (((pred - by) ** 2) * bm).sum().item()
                valid_count += bm.sum().item()

        trl = float(np.mean(train_losses))
        val = float(np.mean(val_losses))
        vrmse = float(np.sqrt(sq_err_sum / max(valid_count, 1e-8)))
        last_train_loss = trl
        last_lr = optimizer.param_groups[0]["lr"]

        # Log every epoch.
        log.info("  [%s] Epoch %03d | Train Loss=%.4f | Val Huber=%.4f | Val RMSE=%.4f | LR=%.2e",
                 axis_name, epoch, trl, val, vrmse, last_lr)

        # Early stopping on validation Huber loss; keep a CPU copy of the best weights.
        if val < best_loss:
            best_loss, bad = val, 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= cfg.PATIENCE:
                log.info("  [%s] Early stop at epoch %d", axis_name, epoch)
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, best_loss, last_train_loss, last_lr

# ---------------------------------------------------------------------------
# Tabular (XGB / CAT) training helpers – flatten sequences
# ---------------------------------------------------------------------------

def flatten_sequences(X: List[np.ndarray]) -> np.ndarray:
    """Flatten each window (e.g. window x features) into one row of a float32 matrix.

    Returns an array of shape ``(len(X), window.size)``; flattening is row-major (C order).
    """
    rows = [np.ravel(window) for window in X]
    return np.stack(rows).astype(np.float32)


def train_xgb_axis_model(X_tr: np.ndarray, y_tr: List[np.ndarray], X_va: np.ndarray, y_va: List[np.ndarray], cfg: Config, horizon: int):
    """Train a multi-output XGBoost model for one displacement axis (dx or dy).

    One GPU ``XGBRegressor`` is fitted per future step through sklearn's
    ``MultiOutputRegressor``.

    Args:
        X_tr: training features, one flattened sequence per row.
        y_tr: per-sequence 1-D targets for this axis; lengths may differ.
        X_va: validation features, same layout as ``X_tr``.
        y_va: validation targets, same layout as ``y_tr``.
        cfg: run configuration (``cfg.SEED`` is used as the random state).
        horizon: number of output steps. Every target is truncated or padded
            to this length, so the model always predicts ``horizon`` columns.

    Returns:
        ``(model, rmse)``. ``rmse`` is the validation RMSE over the real
        (non-padded) target steps, or NaN if there are none.

    Raises:
        ImportError: if xgboost is not installed.
    """
    if xgb is None:
        raise ImportError("xgboost is not available in this environment")
    from sklearn.multioutput import MultiOutputRegressor

    def _targets_to_matrix(ys: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        # Targets can differ in length per sequence (the NN path and the CV
        # metrics handle this with masks), so a plain np.stack fails on ragged
        # input. Pad to `horizon` by repeating the last real value, and return
        # a mask marking the real entries.
        mat = np.zeros((len(ys), horizon), dtype=np.float32)
        valid = np.zeros((len(ys), horizon), dtype=bool)
        for row, y in enumerate(ys):
            y = np.asarray(y, dtype=np.float32).reshape(-1)[:horizon]
            n = y.shape[0]
            if n:
                mat[row, :n] = y
                mat[row, n:] = y[-1]
                valid[row, :n] = True
        return mat, valid

    # NOTE(review): on xgboost>=2.0 `gpu_hist` / `predictor` are legacy names
    # (prefer device="cuda", tree_method="hist"). Kept for compatibility with
    # older versions; they may emit deprecation/unused-parameter warnings.
    model = MultiOutputRegressor(
        xgb.XGBRegressor(
            max_depth=8,
            n_estimators=1200,
            subsample=0.8,
            colsample_bytree=0.8,
            learning_rate=0.03,
            tree_method="gpu_hist",
            predictor="gpu_predictor",
            reg_lambda=1.0,
            reg_alpha=0.0,
            random_state=cfg.SEED,
        )
    )
    Ytr, _ = _targets_to_matrix(y_tr)
    Yva, valid_va = _targets_to_matrix(y_va)
    model.fit(X_tr, Ytr)

    # Validation RMSE restricted to real (unpadded) target steps.
    pred = np.asarray(model.predict(X_va)).reshape(Yva.shape)
    err = (pred - Yva)[valid_va]
    rmse = float(np.sqrt(np.mean(err ** 2))) if err.size else float("nan")
    return model, rmse


def train_cat_axis_model(X_tr: np.ndarray, y_tr: List[np.ndarray], X_va: np.ndarray, y_va: List[np.ndarray], cfg: Config, horizon: int):
    """Train a single multi-target CatBoost model (MultiRMSE) for one axis.

    Args:
        X_tr: training features, one flattened sequence per row.
        y_tr: per-sequence 1-D targets for this axis; lengths may differ.
        X_va: validation features, same layout as ``X_tr``. Also passed as
            CatBoost's ``eval_set``.
        y_va: validation targets, same layout as ``y_tr``.
        cfg: run configuration (``cfg.SEED`` is used as the random state).
        horizon: number of output steps. Every target is truncated or padded
            to this length, so the model always predicts ``horizon`` columns.

    Returns:
        ``(model, rmse)``. ``rmse`` is the validation RMSE over the real
        (non-padded) target steps, or NaN if there are none.

    Raises:
        ImportError: if catboost is not installed.
    """
    if CatBoostRegressor is None:
        raise ImportError("catboost is not available in this environment")

    def _targets_to_matrix(ys: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        # Targets can differ in length per sequence (the NN path and the CV
        # metrics handle this with masks), so a plain np.stack fails on ragged
        # input. Pad to `horizon` by repeating the last real value, and return
        # a mask marking the real entries.
        mat = np.zeros((len(ys), horizon), dtype=np.float32)
        valid = np.zeros((len(ys), horizon), dtype=bool)
        for row, y in enumerate(ys):
            y = np.asarray(y, dtype=np.float32).reshape(-1)[:horizon]
            n = y.shape[0]
            if n:
                mat[row, :n] = y
                mat[row, n:] = y[-1]
                valid[row, :n] = True
        return mat, valid

    Ytr, _ = _targets_to_matrix(y_tr)
    Yva, valid_va = _targets_to_matrix(y_va)
    # NOTE(review): confirm the installed CatBoost supports MultiRMSE with task_type="GPU".
    model = CatBoostRegressor(
        loss_function="MultiRMSE",
        iterations=2500,
        learning_rate=0.03,
        depth=8,
        # CatBoost's default GPU bootstrap (Bayesian) rejects `subsample`;
        # Bernoulli bootstrap supports it.
        bootstrap_type="Bernoulli",
        subsample=0.8,
        # `rsm=0.8` was removed: on GPU CatBoost only supports rsm for pairwise losses.
        random_state=cfg.SEED,
        task_type="GPU",
        verbose=False,
    )
    model.fit(X_tr, Ytr, eval_set=(X_va, Yva), verbose=False)

    # Validation RMSE restricted to real (unpadded) target steps.
    pred = np.asarray(model.predict(X_va)).reshape(Yva.shape)
    err = (pred - Yva)[valid_va]
    rmse = float(np.sqrt(np.mean(err ** 2))) if err.size else float("nan")
    return model, rmse


# ---------------------------------------------------------------------------
# Model save/load helpers
# ---------------------------------------------------------------------------

def fold_dir(cfg: Config, fold: int) -> Path:
    """Return the artifact directory for one CV fold, creating it on demand.

    The directory is ``cfg.MODELS_DIR / "fold<fold>"``. Missing parents are
    created, and an existing directory is left as is.
    """
    path = cfg.MODELS_DIR.joinpath("fold" + str(fold))
    path.mkdir(exist_ok=True, parents=True)
    return path


def save_fold_artifacts(
    fold: int,
    scaler: StandardScaler,
    model_dx,
    model_dy,
    cfg: Config,
):
    """Persist one fold's scaler and its dx/dy models into the fold directory.

    The scaler always goes to ``scaler.pkl`` via joblib. For the torch model
    families (GRU_RES / DYN) only the ``state_dict`` is written (``model_d?.pt``).
    Any other model type is pickled whole with joblib (``model_d?.pkl``).
    """
    out_dir = fold_dir(cfg, fold)
    joblib.dump(scaler, out_dir / "scaler.pkl")

    use_torch = cfg.MODEL_NAME in {"GRU_RES", "DYN"}
    targets = (
        (model_dx, "model_dx.pt", "model_dx.pkl"),
        (model_dy, "model_dy.pt", "model_dy.pkl"),
    )
    for model, pt_name, pkl_name in targets:
        if use_torch:
            torch.save(model.state_dict(), out_dir / pt_name)
        else:
            joblib.dump(model, out_dir / pkl_name)


def write_meta(cfg: Config, feature_cols: List[str]):
    """Write ``meta.json`` describing a training run into ``cfg.MODELS_DIR``.

    The file records what inference needs to rebuild models and features
    consistently: model name, fold count, ordered feature columns, window
    size, prediction horizon and feature groups.

    Args:
        cfg: run configuration.
        feature_cols: ordered feature column names produced by feature engineering.
    """
    meta = {
        "model_name": cfg.MODEL_NAME,
        "n_folds": cfg.N_FOLDS,
        "feature_cols": feature_cols,
        "window_size": cfg.WINDOW_SIZE,
        "max_future_horizon": cfg.MAX_FUTURE_HORIZON,
        "feature_groups": cfg.FEATURE_GROUPS,
        "version": 2,
    }
    # This is called before any fold directory is created, so MODELS_DIR
    # itself may not exist yet on a fresh run.
    cfg.MODELS_DIR.mkdir(parents=True, exist_ok=True)
    with open(cfg.MODELS_DIR / "meta.json", "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    log.info("[META] wrote meta.json to %s", cfg.MODELS_DIR)


def load_cv(cfg: Config):
    """Load every fold's scaler and dx/dy models saved by a TRAIN run.

    ``meta.json`` in ``cfg.MODELS_DIR`` determines the model family, fold
    count, feature count and horizon used to rebuild the networks.

    Args:
        cfg: run configuration (``MODELS_DIR``, ``MODEL_NAME`` as fallback,
            ``HIDDEN_DIM`` and ``DEVICE`` are read).

    Returns:
        ``(models_x, models_y, scalers, meta)``: per-fold dx models, dy models
        and scalers in fold order, plus the parsed metadata dict. Torch models
        are returned in eval mode on ``cfg.DEVICE``.

    Raises:
        FileNotFoundError: if ``meta.json`` or a fold directory is missing.
    """
    meta_path = cfg.MODELS_DIR / "meta.json"
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not meta_path.exists():
        raise FileNotFoundError(f"meta.json not found: {meta_path}")
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    feature_cols = meta["feature_cols"]
    horizon = int(meta["max_future_horizon"])
    n_folds = int(meta["n_folds"])
    model_name = meta.get("model_name", cfg.MODEL_NAME)
    n_features = len(feature_cols)

    def _build_torch_model():
        # Fresh, uninitialised network matching the saved architecture.
        if model_name == "GRU_RES":
            return GRUResidualModel(n_features, cfg.HIDDEN_DIM, horizon).to(cfg.DEVICE)
        # NOTE(review): these DYN block specs must match the ones used at training time — confirm.
        block_specs = [
            {"type": "rnn", "rnn": "gru", "hidden": 128, "layers": 1},
            {"type": "tcn", "kernel": 3, "dilation": 1},
            {"type": "transformer", "nhead": 4, "ff_mult": 4},
            {"type": "se", "r": 4},
        ]
        return FlexibleSeqModel(n_features, horizon, block_specs, pooling="attn").to(cfg.DEVICE)

    models_x, models_y, scalers = [], [], []
    for fold in range(1, n_folds + 1):
        # Build the path directly rather than via fold_dir(), which calls mkdir.
        # Loading must not write to MODELS_DIR, whose default lives under
        # /kaggle/input (typically a read-only mount).
        sdir = cfg.MODELS_DIR / f"fold{fold}"
        if not sdir.is_dir():
            raise FileNotFoundError(f"fold directory not found: {sdir}")
        # NOTE: joblib.load unpickles; only load artifacts you produced yourself.
        scalers.append(joblib.load(sdir / "scaler.pkl"))
        if model_name in {"GRU_RES", "DYN"}:
            mx = _build_torch_model()
            my = _build_torch_model()
            mx.load_state_dict(torch.load(sdir / "model_dx.pt", map_location=cfg.DEVICE))
            my.load_state_dict(torch.load(sdir / "model_dy.pt", map_location=cfg.DEVICE))
            mx.eval()
            my.eval()
        else:
            mx = joblib.load(sdir / "model_dx.pkl")
            my = joblib.load(sdir / "model_dy.pkl")
        models_x.append(mx)
        models_y.append(my)

    log.info("[LOAD] loaded %d-fold models from %s", len(models_x), cfg.MODELS_DIR)
    return models_x, models_y, scalers, meta


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _fold_val_metrics(pred_dx, pred_dy, tdx_va, tdy_va, horizon: int) -> Dict[str, float]:
    """Return masked per-axis RMSEs and the combined Kaggle RMSE for one validation fold.

    ``pred_dx`` / ``pred_dy`` are CPU torch tensors of predicted displacements.
    ``tdx_va`` / ``tdy_va`` are the raw per-sequence targets; they are converted
    to ``(targets, mask)`` with ``prepare_targets`` at ``horizon`` steps.
    """
    dx_va_t, m_va = prepare_targets(tdx_va, horizon)
    dy_va_t, _ = prepare_targets(tdy_va, horizon)
    return {
        "rmse_dx": float(masked_rmse(pred_dx, dx_va_t, m_va)),
        "rmse_dy": float(masked_rmse(pred_dy, dy_va_t, m_va)),
        "rmse_kaggle": float(kaggle_combo_rmse(pred_dx, pred_dy, dx_va_t, dy_va_t, m_va)),
    }


def _summarize_cv(cfg: Config, fold_metrics: List[Dict[str, float]]) -> None:
    """Log mean/std/var of the per-fold RMSEs and save fold_metrics.csv and cv_metrics.json."""
    if not fold_metrics:
        log.warning("[CV SUMMARY] No fold metrics collected — please check data splits or masks.")
        return

    mdf = pd.DataFrame(fold_metrics).sort_values("fold")

    def _spread(col: str) -> Dict[str, float]:
        # Sample statistics (ddof=1); NaN if only a single fold is present.
        return {
            "mean": float(mdf[col].mean()),
            "std": float(mdf[col].std(ddof=1)),
            "var": float(mdf[col].var(ddof=1)),
        }

    summary = {
        "model_name": cfg.MODEL_NAME,
        "n_folds": int(len(mdf)),
        "rmse_dx": _spread("rmse_dx"),
        "rmse_dy": _spread("rmse_dy"),
        "kaggle_rmse": _spread("rmse_kaggle"),
    }
    log.info("[CV SUMMARY] Folds=%d | Kaggle RMSE: mean=%.4f, std=%.4f, var=%.6f | "
             "dx mean=%.4f std=%.4f | dy mean=%.4f std=%.4f",
             summary["n_folds"],
             summary["kaggle_rmse"]["mean"], summary["kaggle_rmse"]["std"], summary["kaggle_rmse"]["var"],
             summary["rmse_dx"]["mean"], summary["rmse_dx"]["std"],
             summary["rmse_dy"]["mean"], summary["rmse_dy"]["std"])
    # Persist for traceability.
    mdf.to_csv(cfg.MODELS_DIR / "fold_metrics.csv", index=False)
    with open(cfg.MODELS_DIR / "cv_metrics.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)


def _main_train(cfg: Config) -> None:
    """TRAIN mode: build sequences, run game-grouped K-fold CV, save fold artifacts and CV metrics."""
    # 1) Load the weekly train input/output CSVs that are present.
    log.info("[1/3] Loading train data…")
    train_input_files = [cfg.DATA_DIR / f"train/input_2023_w{w:02d}.csv" for w in range(1, 19)]
    train_output_files = [cfg.DATA_DIR / f"train/output_2023_w{w:02d}.csv" for w in range(1, 19)]
    input_frames = [pd.read_csv(f) for f in train_input_files if f.exists()]
    output_frames = [pd.read_csv(f) for f in train_output_files if f.exists()]
    if not input_frames or not output_frames:
        # Without this, pd.concat([]) fails with an unhelpful "No objects to concatenate".
        raise FileNotFoundError(
            f"No weekly train input/output CSVs found under {cfg.DATA_DIR / 'train'}"
        )
    train_input = pd.concat(input_frames, ignore_index=True)
    train_output = pd.concat(output_frames, ignore_index=True)

    # 2) Feature engineering and windowed sequences.
    log.info("[2/3] Feature engineering & sequences…")
    seqs, tdx, tdy, _tfids, seq_meta, feat_cols, _dir_map = prepare_sequences(
        train_input,
        output_df=train_output,
        is_training=True,
        window_size=cfg.WINDOW_SIZE,
        feature_groups=cfg.FEATURE_GROUPS,
    )
    write_meta(cfg, feat_cols)

    # 3) K-fold CV grouped by game_id, so no game appears in both train and validation.
    log.info("[3/3] %d-fold CV training (%s)…", cfg.N_FOLDS, cfg.MODEL_NAME)
    groups = np.array([d["game_id"] for d in seq_meta])
    gkf = GroupKFold(n_splits=cfg.N_FOLDS)
    fold_metrics: List[Dict[str, float]] = []
    H = cfg.MAX_FUTURE_HORIZON

    for fold, (tr, va) in enumerate(gkf.split(seqs, groups=groups), 1):
        log.info("-" * 60)
        log.info("Fold %d/%d", fold, cfg.N_FOLDS)
        X_tr_raw = [seqs[i] for i in tr]
        X_va_raw = [seqs[i] for i in va]
        tdx_tr = [tdx[i] for i in tr]
        tdx_va = [tdx[i] for i in va]
        tdy_tr = [tdy[i] for i in tr]
        tdy_va = [tdy[i] for i in va]

        # The scaler is fit on the training fold only.
        scaler = StandardScaler()
        scaler.fit(np.vstack(X_tr_raw))
        X_tr_sc = np.stack([scaler.transform(s) for s in X_tr_raw]).astype(np.float32)
        X_va_sc = np.stack([scaler.transform(s) for s in X_va_raw]).astype(np.float32)

        if cfg.MODEL_NAME in {"GRU_RES", "DYN"}:
            # One network per axis: ΔX and ΔY.
            mx, _val_huber_x, tr_huber_x, _lr_x = train_torch_axis_model(
                X_tr_sc, tdx_tr, X_va_sc, tdx_va,
                X_tr_sc.shape[-1], H, cfg, cfg.MODEL_NAME, axis_name="dx"
            )
            my, _val_huber_y, tr_huber_y, lr_y = train_torch_axis_model(
                X_tr_sc, tdy_tr, X_va_sc, tdy_va,
                X_tr_sc.shape[-1], H, cfg, cfg.MODEL_NAME, axis_name="dy"
            )
            # Predictions on the validation fold.
            with torch.no_grad():
                Xv = torch.tensor(X_va_sc, device=cfg.DEVICE)
                pred_dx = mx(Xv).cpu()
                pred_dy = my(Xv).cpu()
            metrics = _fold_val_metrics(pred_dx, pred_dy, tdx_va, tdy_va, H)
            # Train loss reported as the mean of the two axes' final train Huber loss.
            train_loss_avg = 0.5 * (tr_huber_x + tr_huber_y)
            log.info("Train Loss=%.4f | Val RMSE (dx=%.4f, dy=%.4f) | Kaggle RMSE=%.4f | LR=%.2e",
                     train_loss_avg, metrics["rmse_dx"], metrics["rmse_dy"],
                     metrics["rmse_kaggle"], lr_y)

        elif cfg.MODEL_NAME in {"XGB", "CAT"}:
            # Tree models consume each window flattened into a single feature row.
            trainer = train_xgb_axis_model if cfg.MODEL_NAME == "XGB" else train_cat_axis_model
            X_tr_flat = flatten_sequences(X_tr_sc)
            X_va_flat = flatten_sequences(X_va_sc)
            mx, _ = trainer(X_tr_flat, tdx_tr, X_va_flat, tdx_va, cfg, H)
            my, _ = trainer(X_tr_flat, tdy_tr, X_va_flat, tdy_va, cfg, H)
            # np -> torch so the shared metric helpers can be used.
            pred_dx = torch.tensor(mx.predict(X_va_flat), dtype=torch.float32)
            pred_dy = torch.tensor(my.predict(X_va_flat), dtype=torch.float32)
            metrics = _fold_val_metrics(pred_dx, pred_dy, tdx_va, tdy_va, H)
            log.info("[VAL] fold %d (%s): RMSE(dx=%.4f, dy=%.4f) | Kaggle RMSE=%.4f",
                     fold, cfg.MODEL_NAME, metrics["rmse_dx"], metrics["rmse_dy"],
                     metrics["rmse_kaggle"])

        else:
            raise ValueError(f"Unknown MODEL_NAME={cfg.MODEL_NAME}")

        fold_metrics.append({"fold": float(fold), **metrics})
        # Save this fold's scaler and both axis models.
        save_fold_artifacts(fold=fold, scaler=scaler, model_dx=mx, model_dy=my, cfg=cfg)

    _summarize_cv(cfg, fold_metrics)
    log.info("=" * 80)
    log.info("COMPLETE (TRAIN). Models saved under: %s", cfg.MODELS_DIR)


def _main_submit(cfg: Config) -> None:
    """SUB mode: average the saved fold models on the test set and write submission.csv."""
    log.info("[1/3] Loading test data…")
    test_input = pd.read_csv(cfg.DATA_DIR / "test_input.csv")
    test_template = pd.read_csv(cfg.DATA_DIR / "test.csv")

    log.info("[2/3] Loading CV models & meta…")
    models_x, models_y, scalers, meta = load_cv(cfg)
    saved_feature_cols = meta["feature_cols"]
    saved_window = int(meta["window_size"])
    model_name = meta.get("model_name", cfg.MODEL_NAME)
    H = int(meta["max_future_horizon"])

    log.info("[3/3] Building test sequences & predicting…")
    test_seqs, test_meta, feat_cols_t, _ = prepare_sequences(
        test_input,
        test_template=test_template,
        is_training=False,
        window_size=saved_window,
        feature_groups=meta.get("feature_groups", cfg.FEATURE_GROUPS),
    )
    # Explicit check (not `assert`, which `python -O` strips). The saved scalers
    # and models expect exactly the training-time feature columns, in order.
    if feat_cols_t != saved_feature_cols:
        raise ValueError(
            f"Feature mismatch! train:{len(saved_feature_cols)} vs test:{len(feat_cols_t)}"
        )

    # Last observed (unscaled, unified-direction) position of each sequence.
    # Predictions are displacements added to it.
    idx_x = feat_cols_t.index("x")
    idx_y = feat_cols_t.index("y")
    X_test_raw = list(test_seqs)
    x_last_uni = np.array([s[-1, idx_x] for s in X_test_raw], dtype=np.float32)
    y_last_uni = np.array([s[-1, idx_y] for s in X_test_raw], dtype=np.float32)

    all_preds_dx, all_preds_dy = [], []
    for mx, my, sc in zip(models_x, models_y, scalers):
        X_sc = np.stack([sc.transform(s) for s in X_test_raw]).astype(np.float32)

        if model_name in {"GRU_RES", "DYN"}:
            X_t = torch.tensor(X_sc).to(cfg.DEVICE)
            with torch.no_grad():
                pred_dx = mx(X_t).cpu().numpy()
                pred_dy = my(X_t).cpu().numpy()
        elif model_name in {"XGB", "CAT"}:
            X_flat = flatten_sequences(X_sc)
            pred_dx = mx.predict(X_flat)
            pred_dy = my.predict(X_flat)
        else:
            raise ValueError(f"Unknown model_name {model_name}")

        all_preds_dx.append(pred_dx)
        all_preds_dy.append(pred_dy)

    # Equal-weight average across folds.
    ens_dx = np.mean(all_preds_dx, axis=0)
    ens_dy = np.mean(all_preds_dy, axis=0)

    # Assemble the submission, mapping unified coordinates back to each play's original direction.
    rows = []
    n_unmatched = 0
    tt_idx = test_template.set_index(["game_id", "play_id", "nfl_id"]).sort_index()

    for i, meta_row in enumerate(test_meta):
        gid = meta_row["game_id"]
        pid = meta_row["play_id"]
        nid = meta_row["nfl_id"]
        play_is_right = (meta_row["play_direction"] == "right")

        try:
            fids = tt_idx.loc[(gid, pid, nid), "frame_id"]
        except KeyError:
            # No template rows for this player: nothing to predict, but keep count.
            n_unmatched += 1
            continue
        if isinstance(fids, pd.Series):
            fids = fids.sort_values().tolist()
        else:
            fids = [int(fids)]

        for t, fid in enumerate(fids):
            # Frames past the model horizon reuse the last predicted step.
            tt = min(t, H - 1)
            x_uni = float(np.clip(x_last_uni[i] + ens_dx[i, tt], 0, FIELD_LENGTH))
            y_uni = float(np.clip(y_last_uni[i] + ens_dy[i, tt], 0, FIELD_WIDTH))
            x_out, y_out = invert_to_original_direction(x_uni, y_uni, play_is_right)
            rows.append({"id": f"{gid}_{pid}_{nid}_{int(fid)}", "x": x_out, "y": y_out})

    if n_unmatched:
        log.warning("[SUB] %d test sequences had no rows in test.csv and were skipped", n_unmatched)
    submission = pd.DataFrame(rows)
    if len(submission) != len(test_template):
        # A mismatch can mean some template players had no input sequence.
        log.warning("[SUB] submission has %d rows but test.csv has %d rows",
                    len(submission), len(test_template))
    submission.to_csv("submission.csv", index=False)
    log.info("=" * 80)
    log.info("COMPLETE (SUB). Saved submission.csv  |  Rows: %d", len(submission))


def main():
    """Entry point: run CV training (``cfg.TRAIN``) or build submission.csv (``cfg.SUB``).

    TRAIN takes precedence when both flags are set.

    Raises:
        ValueError: if neither TRAIN nor SUB is enabled.
    """
    cfg = Config()
    set_seed(cfg.SEED)

    log.info("=" * 80)
    log.info("RUN MODE: TRAIN=%s | SUB=%s | MODEL=%s", cfg.TRAIN, cfg.SUB, cfg.MODEL_NAME)
    log.info("=" * 80)

    if cfg.TRAIN:
        _main_train(cfg)
        return

    if cfg.SUB:
        _main_submit(cfg)
        return

    raise ValueError("Please set mode in Config: TRAIN=True/SUB=False or TRAIN=False/SUB=True")


# Run the pipeline when this file/cell is executed directly (Jupyter kernels also set __name__ to "__main__").
if __name__ == "__main__":
    main()