# In [None]:
# =============================================================================
# Kaggle NFL 2026 - API Submission Script for Model v4.9_2D (Joint Prediction)
#
# 基于 trans_v4_9_2D_copy.py 的训练脚本生成。
#
# 变更日志:
# - [架构] 从两个独立的 x, y 模型切换为单一的联合预测 (dx, dy) 模型。
# - [特征] 完全同步了 v4.9_2D 脚本中所有最新的特征工程逻辑。
# - [推理] 简化了 5 折集成流程，现在对 (dx, dy) 向量直接进行平均。
# - [配置] 更新了所有模型超参数以匹配新的训练配置。
# - [修正] 修复了 invert_to_original_direction 函数中的一个变量名错误。
# =============================================================================

# =============================================================================
# 1. 全局导入与环境设置
# =============================================================================
import os
import gc
import time
import pickle
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings("ignore")

import polars as pl
import kaggle_evaluation.nfl_inference_server

# =============================================================================
# 2. 核心逻辑定义 (与 trans_v4_9_2D_copy.py 保持一致)
# =============================================================================

# -------------------------------
# 配置类
# -------------------------------
class Config:
    """Inference-time settings for the v4.9_2D ensemble.

    The model hyper-parameters below have to agree with the configuration the
    checkpoints were trained with, otherwise the saved weights will not load.
    """

    # ---- I/O and runtime ----------------------------------------------------
    # [IMPORTANT] Must point at the uploaded dataset containing the models and scalers.
    MODEL_DIR = Path("/kaggle/input/trans-v4-9-2d-copy/outputs_v4_9_2D_copy")
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    INFERENCE_BATCH_SIZE = 256
    N_FOLDS = 5

    # ---- Sequence geometry --------------------------------------------------
    WINDOW_SIZE = 25          # frames of history fed to the model per player
    MAX_FUTURE_HORIZON = 94
    MAX_PLAYERS = 24          # player slots per sample (target + context players)

    # ---- Transformer backbone -----------------------------------------------
    D_MODEL = 192
    NHEAD = 6
    NUM_ENCODER_LAYERS = 3
    DIM_FEEDFORWARD = 768
    TRANSFORMER_DROPOUT = 0.15

    # ---- Static-feature encoders and output head ----------------------------
    PLAYER_STATIC_HIDDEN_DIM = 64
    PLAYER_STATIC_DROPOUT = 0.2
    GNN_HIDDEN_DIM = 64
    GNN_DROPOUT = 0.2
    HEAD_DROPOUT = 0.2

    # ---- Neighbour (GNN-lite) feature construction --------------------------
    K_NEIGH = 8       # nearest neighbours kept per player
    RADIUS = 30.0     # neighbours farther than this are ignored
    TAU = 8.0         # distance decay scale of the exp(-dist / TAU) weights

# -------------------------------
# 辅助函数
# -------------------------------
YARDS_TO_METERS = 0.9144
FPS = 10.0
FIELD_LENGTH, FIELD_WIDTH = 120.0, 53.3

def wrap_angle_deg(s):
    """Wrap angle(s) given in degrees into the interval [-180, 180).

    Operates element-wise, so scalars, NumPy arrays and pandas Series all work.
    """
    shifted = s + 180.0
    return shifted % 360.0 - 180.0

def wrap_angle_rad(angle):
    """Wrap angle(s) given in radians into the interval [-pi, pi).

    Element-wise for NumPy arrays and pandas Series.
    """
    full_turn = 2 * np.pi
    return (angle + np.pi) % full_turn - np.pi

def unify_left_direction(df: pd.DataFrame) -> pd.DataFrame:
    """Express every play as if it moved left by mirroring the right-moving rows.

    For rows whose ``play_direction`` is ``'right'``:
      * ``x`` and ``ball_land_x`` are reflected about ``FIELD_LENGTH``;
      * ``y`` and ``ball_land_y`` are reflected about ``FIELD_WIDTH``;
      * the ``dir`` and ``o`` headings are rotated by 180 degrees (mod 360).
    Columns that are absent are skipped. When ``play_direction`` itself is
    missing the input is returned as-is; otherwise a modified copy is returned.
    """
    if 'play_direction' not in df.columns:
        return df

    out = df.copy()
    flip = out['play_direction'].eq('right')

    # Positional columns paired with the extent they are mirrored about.
    reflections = (
        ('x', FIELD_LENGTH),
        ('y', FIELD_WIDTH),
        ('ball_land_x', FIELD_LENGTH),
        ('ball_land_y', FIELD_WIDTH),
    )
    for col, extent in reflections:
        if col in out.columns:
            out.loc[flip, col] = extent - out.loc[flip, col]

    # Headings: turning around the play direction turns each heading by half a circle.
    for col in ('dir', 'o'):
        if col in out.columns:
            out.loc[flip, col] = (out.loc[flip, col].astype(float) + 180.0) % 360.0

    return out

def invert_to_original_direction(x_u, y_u, play_dir_right: bool):
    """Map a point from the unified (left-moving) frame back to the play's own frame.

    Undoes the x/y reflection applied by ``unify_left_direction``: for plays
    moving right the point is reflected about ``FIELD_LENGTH``/``FIELD_WIDTH``,
    otherwise it passes through. Always returns a ``(float, float)`` pair.
    """
    # [Fix] uses FIELD_WIDTH (the training code mistakenly referenced WIDTH).
    if play_dir_right:
        return float(FIELD_LENGTH - x_u), float(FIELD_WIDTH - y_u)
    return float(x_u), float(y_u)

# -------------------------------
# 特征工程 (v4.9_2D)
# -------------------------------
class FeatureEngineer:
    """Per-frame feature engineering mirroring the v4.9_2D training script.

    ``transform`` copies the frames, sorts them by game/play/player/frame, always
    adds the basic features, then runs every group named in
    ``feature_groups_to_create`` (in the given order) that has an entry in
    ``feature_creators``; unknown names are silently ignored. All time-based
    operations (diff / shift / rolling / cumsum / cumcount) are done within each
    (game_id, play_id, nfl_id) trajectory.

    NOTE(review): these formulas must stay numerically identical to the training
    script, otherwise the saved scalers/models receive shifted inputs.
    """
    def __init__(self, feature_groups_to_create):
        # Keys identifying one player's trajectory within a play.
        self.gcols = ['game_id', 'play_id', 'nfl_id']
        self.active_groups = feature_groups_to_create
        # Group name -> method returning (df, list_of_new_feature_names).
        self.feature_creators = {
            'distance_rate': self._create_distance_rate_features,
            'target_alignment': self._create_target_alignment_features,
            'multi_window_rolling': self._create_multi_window_rolling_features,
            'extended_lags': self._create_extended_lag_features,
            'velocity_changes': self._create_velocity_change_features,
            'field_position': self._create_field_position_features,
            'role_specific': self._create_role_specific_features,
            'time_features': self._create_time_features,
            'jerk_features': self._create_jerk_features,
            'curvature_land_features': self._create_curvature_land_features,
        }
        # Registered feature names. NOTE: this list keeps growing if ``transform``
        # is called more than once on the same instance (duplicates are removed
        # by the set() in ``transform``).
        self.created_feature_cols = []

    def _height_to_feet(self, height_str):
        """Convert a 'feet-inches' string such as '6-2' to decimal feet; 6.0 if unparsable."""
        try:
            ft, inches = map(int, str(height_str).split('-'))
            return ft + inches / 12
        except Exception:
            return 6.0

    def _create_basic_features(self, df):
        """Add kinematic, physical, role and ball-relative features; returns a new frame.

        Velocity/acceleration are decomposed along ``dir`` (missing ``dir`` treated
        as 0 degrees). Ball-relative features are only created when
        ``ball_land_x``/``ball_land_y`` exist. Only names present in ``df`` are
        registered in ``created_feature_cols``.
        """
        df = df.copy()
        df['player_height_feet'] = df['player_height'].apply(self._height_to_feet)
        dir_rad = np.deg2rad(df['dir'].fillna(0.0).astype('float32'))
        df['velocity_x'] = df['s'] * np.cos(dir_rad)
        df['velocity_y'] = df['s'] * np.sin(dir_rad)
        df['acceleration_x'] = df['a'] * np.cos(dir_rad)
        df['acceleration_y'] = df['a'] * np.sin(dir_rad)
        df['is_offense']  = (df['player_side'] == 'Offense').astype(np.int8)
        df['is_defense']  = (df['player_side'] == 'Defense').astype(np.int8)
        df['is_receiver'] = (df['player_role'] == 'Targeted Receiver').astype(np.int8)
        df['is_coverage'] = (df['player_role'] == 'Defensive Coverage').astype(np.int8)
        # player_weight (missing -> 200) divided by 2.20462, i.e. pounds -> kg.
        mass_kg = df['player_weight'].fillna(200.0) / 2.20462
        v_ms = df['s'] * YARDS_TO_METERS
        df['momentum_x'] = mass_kg * df['velocity_x'] * YARDS_TO_METERS
        df['momentum_y'] = mass_kg * df['velocity_y'] * YARDS_TO_METERS
        df['kinetic_energy'] = 0.5 * mass_kg * (v_ms ** 2)
        df['momentum_magnitude'] = np.hypot(df['momentum_x'], df['momentum_y'])
        df['momentum_direction'] = np.arctan2(df['momentum_y'], df['momentum_x'])
        height_m = df['player_height_feet'] * 0.3048
        df['player_bmi'] = mass_kg / (height_m ** 2 + 1e-6)
        df['power_weight_ratio'] = df['kinetic_energy'] / (mass_kg + 1e-6)

        if {'ball_land_x','ball_land_y'}.issubset(df.columns):
            ball_dx = df['ball_land_x'] - df['x']
            ball_dy = df['ball_land_y'] - df['y']
            dist = np.hypot(ball_dx, ball_dy)
            df['distance_to_ball'] = dist
            inv = 1.0 / (dist + 1e-6)
            # Unit vector from the player towards the ball landing spot.
            df['ball_direction_x'] = ball_dx * inv
            df['ball_direction_y'] = ball_dy * inv
            # Velocity component along that unit vector.
            df['closing_speed'] = (df['velocity_x'] * df['ball_direction_x'] + df['velocity_y'] * df['ball_direction_y'])
            df['ball_angle'] = np.arctan2(ball_dy, ball_dx)
            df['approach_angle'] = wrap_angle_rad(df['ball_angle'] - np.arctan2(df['velocity_y'], df['velocity_x']))
            df['distance_to_ball_squared'] = dist ** 2
            df['ball_distance_log'] = np.log1p(dist)

        # NOTE(review): 'x' is registered as a model feature but 'y' is not —
        # presumably mirrors the training script; confirm before changing.
        base = ['x','s','a','dir','frame_id','ball_land_x','ball_land_y','player_height_feet','player_weight',
                'velocity_x','velocity_y','acceleration_x','acceleration_y','momentum_x','momentum_y','kinetic_energy',
                'momentum_magnitude','momentum_direction','player_bmi','power_weight_ratio','is_offense','is_defense',
                'is_receiver','is_coverage','distance_to_ball','ball_direction_x','ball_direction_y',
                'closing_speed','ball_angle','approach_angle','distance_to_ball_squared','ball_distance_log']
        self.created_feature_cols.extend([c for c in base if c in df.columns])
        return df

    def _create_distance_rate_features(self, df):
        """Per-second 1st/2nd/3rd derivatives of distance_to_ball plus derived urgency terms.

        No-op (empty name list) when ``distance_to_ball`` was not created.
        """
        new_cols = []
        if 'distance_to_ball' in df.columns:
            d = df.groupby(self.gcols)['distance_to_ball'].diff()
            df['d2ball_dt']  = d.fillna(0.0) * FPS
            df['d2ball_ddt'] = df.groupby(self.gcols)['d2ball_dt'].diff().fillna(0.0) * FPS
            # Distance / |rate|, clipped to [0, 10].
            df['time_to_intercept'] = (df['distance_to_ball'] / (df['d2ball_dt'].abs() + 1e-3)).clip(0, 10)
            df['d2ball_d3t'] = df.groupby(self.gcols)['d2ball_ddt'].diff().fillna(0.0) * FPS
            df['closing_efficiency'] = df['d2ball_dt'] / (df['s'] + 1e-3)
            df['intercept_urgency'] = 1.0 / (df['time_to_intercept'] + 0.1)
            df['distance_ema3'] = df.groupby(self.gcols)['distance_to_ball'].transform(lambda x: x.ewm(span=3, adjust=False).mean())
            new_cols = ['d2ball_dt','d2ball_ddt','time_to_intercept','d2ball_d3t','closing_efficiency','intercept_urgency','distance_ema3']
        return df, new_cols

    def _create_target_alignment_features(self, df):
        """Project velocity/acceleration onto the ball direction and its perpendicular.

        NOTE: ``velocity_alignment`` uses the same formula as ``closing_speed``
        from ``_create_basic_features``.
        """
        new_cols = []
        if {'ball_direction_x','ball_direction_y','velocity_x','velocity_y'}.issubset(df.columns):
            df['velocity_alignment'] = df['velocity_x']*df['ball_direction_x'] + df['velocity_y']*df['ball_direction_y']
            df['velocity_perpendicular'] = df['velocity_x']*(-df['ball_direction_y']) + df['velocity_y']*df['ball_direction_x']
            new_cols.extend(['velocity_alignment','velocity_perpendicular'])
            if {'acceleration_x','acceleration_y'}.issubset(df.columns):
                df['accel_alignment'] = df['acceleration_x']*df['ball_direction_x'] + df['acceleration_y']*df['ball_direction_y']
                df['accel_perpendicular'] = df['acceleration_x']*(-df['ball_direction_y']) + df['acceleration_y']*df['ball_direction_x']
                new_cols.extend(['accel_alignment','accel_perpendicular'])
        return df, new_cols

    def _create_multi_window_rolling_features(self, df):
        """Rolling mean / std / deviation-from-mean of velocity and speed over 3, 5, 10 and 20 frames."""
        new_cols = []
        for window in (3, 5, 10, 20):
            for col in ('velocity_x','velocity_y','s'):
                if col in df.columns:
                    # reset_index drops the group-key levels so the result realigns on
                    # df's original row index when assigned back.
                    r_mean = df.groupby(self.gcols)[col].rolling(window, min_periods=1).mean().reset_index(level=list(range(len(self.gcols))), drop=True)
                    r_std  = df.groupby(self.gcols)[col].rolling(window, min_periods=1).std().reset_index(level=list(range(len(self.gcols))), drop=True)
                    df[f'{col}_roll{window}'] = r_mean
                    # std is NaN for single-element windows; treated as 0.
                    df[f'{col}_std{window}']  = r_std.fillna(0.0)
                    df[f'{col}_dev{window}'] = df[col] - r_mean
                    new_cols.extend([f'{col}_roll{window}', f'{col}_std{window}', f'{col}_dev{window}'])
        if 's_roll3' in df.columns and 's_roll20' in df.columns:
            df['speed_trend_ratio'] = df['s_roll3'] / (df['s_roll20'] + 1e-3)
            new_cols.append('speed_trend_ratio')
        return df, new_cols

    def _create_extended_lag_features(self, df):
        """Lagged velocity/speed values (lags 1, 2, 3, 5, 10) and differences for lags <= 3.

        Missing lags are filled with the player's first value; the differences are
        0 where the lag is unavailable.
        """
        new_cols = []
        for lag in (1, 2, 3, 5, 10):
            for col in ('velocity_x','velocity_y','s'):
                if col in df.columns:
                    g = df.groupby(self.gcols)[col]
                    lagv = g.shift(lag)
                    df[f'{col}_lag{lag}'] = lagv.fillna(g.transform('first'))
                    new_cols.append(f'{col}_lag{lag}')
                    if lag <= 3:
                        df[f'{col}_diff_lag{lag}'] = df[col] - lagv.fillna(df[col])
                        new_cols.append(f'{col}_diff_lag{lag}')
        return df, new_cols

    def _create_velocity_change_features(self, df):
        """Frame-to-frame changes of velocity, speed, direction (wrapped degrees) and orientation rate.

        Only the presence of ``velocity_x`` is checked; ``velocity_y``, ``s``,
        ``dir`` and ``o`` are read unconditionally.
        """
        new_cols = []
        if 'velocity_x' in df.columns:
            df['velocity_x_change'] = df.groupby(self.gcols)['velocity_x'].diff().fillna(0.0)
            df['velocity_y_change'] = df.groupby(self.gcols)['velocity_y'].diff().fillna(0.0)
            df['speed_change']      = df.groupby(self.gcols)['s'].diff().fillna(0.0)
            d = df.groupby(self.gcols)['dir'].diff().fillna(0.0)
            df['direction_change']  = wrap_angle_deg(d)
            # Orientation change per second (per-frame diff * FPS).
            df['o_change_rate'] = wrap_angle_deg(df.groupby(self.gcols)['o'].diff().fillna(0.0)) * FPS
            new_cols = ['velocity_x_change','velocity_y_change','speed_change','direction_change', 'o_change_rate']
        return df, new_cols

    def _create_field_position_features(self, df):
        """Distances to field boundaries plus coarse zone indices.

        ``dist_from_left``/``dist_from_right`` are created but not registered as
        features. NOTE(review): ``astype(int)`` raises if ``x`` or ``y`` contains NaN.
        """
        df['dist_from_left'] = df['y']
        df['dist_from_right'] = FIELD_WIDTH - df['y']
        df['dist_from_sideline'] = np.minimum(df['dist_from_left'], df['dist_from_right'])
        df['dist_from_endzone']  = np.minimum(df['x'], FIELD_LENGTH - df['x'])
        # 5 bins along x and 3 bins along y.
        df['field_zone_x'] = (df['x'] / FIELD_LENGTH * 5).astype(int).clip(0, 4)
        df['field_zone_y'] = (df['y'] / FIELD_WIDTH * 3).astype(int).clip(0, 2)
        df['in_red_zone'] = (df['dist_from_endzone'] < 20).astype(np.int8)
        df['near_sideline'] = (df['dist_from_sideline'] < 5).astype(np.int8)
        df['dist_from_center'] = np.hypot(df['x'] - FIELD_LENGTH / 2, df['y'] - FIELD_WIDTH / 2)
        return df, ['dist_from_sideline','dist_from_endzone','field_zone_x','field_zone_y','in_red_zone','near_sideline','dist_from_center']

    def _create_role_specific_features(self, df):
        """Alignment/pressure features gated by the receiver / coverage role flags."""
        new_cols = []
        if {'is_receiver','velocity_alignment'}.issubset(df.columns):
            df['receiver_optimality'] = df['is_receiver'] * df['velocity_alignment']
            df['receiver_deviation']  = df['is_receiver'] * np.abs(df.get('velocity_perpendicular', 0.0))
            # NOTE(review): normalised by the max speed of the frame being
            # transformed; at inference that is only the current input, which may
            # differ from the normalisation seen during training.
            df['receiver_speed_usage'] = df['is_receiver'] * df['s'] / (df['s'].max() + 1e-3)
            new_cols.extend(['receiver_optimality','receiver_deviation','receiver_speed_usage'])
        if {'is_coverage','closing_speed'}.issubset(df.columns):
            df['defender_closing_speed'] = df['is_coverage'] * df['closing_speed']
            df['defender_pressure'] = df['is_coverage'] / (df.get('distance_to_ball', 10.0) + 1e-3)
            new_cols.extend(['defender_closing_speed', 'defender_pressure'])
        return df, new_cols

    def _create_time_features(self, df):
        """0-based frame index within each player's (already sorted) trajectory."""
        df['frames_elapsed']  = df.groupby(self.gcols).cumcount()
        return df, ['frames_elapsed']

    def _create_jerk_features(self, df):
        """Per-second rate of change of acceleration, its smoothing, magnitude, direction and running sum.

        The x-component of the acceleration-vector jerk is only used as a
        temporary and not stored as a column.
        """
        new_cols = []
        if 'a' in df.columns:
            df['jerk'] = df.groupby(self.gcols)['a'].diff().fillna(0.0) * FPS
            df['jerk_smoothed'] = df.groupby(self.gcols)['jerk'].rolling(3, min_periods=1).mean().reset_index(level=list(range(len(self.gcols))), drop=True)
            new_cols.extend(['jerk','jerk_smoothed'])
        if {'acceleration_x','acceleration_y'}.issubset(df.columns):
            df['jerk_y'] = df.groupby(self.gcols)['acceleration_y'].diff().fillna(0.0) * FPS
            _jerk_x_temp = df.groupby(self.gcols)['acceleration_x'].diff().fillna(0.0) * FPS
            df['jerk_magnitude'] = np.hypot(_jerk_x_temp, df['jerk_y'])
            df['jerk_direction'] = np.arctan2(df['jerk_y'], _jerk_x_temp)
            df['cumulative_jerk'] = df.groupby(self.gcols)['jerk_magnitude'].cumsum()
            new_cols.extend(['jerk_y','jerk_magnitude','jerk_direction','cumulative_jerk'])
        return df, new_cols

    def _create_curvature_land_features(self, df):
        """Bearing/lateral offset to the landing spot, path curvature and tangential acceleration.

        Reads ``acceleration_x``/``acceleration_y`` unconditionally, so it relies
        on ``_create_basic_features`` having run first (``transform`` guarantees it).
        """
        if {'ball_land_x','ball_land_y'}.issubset(df.columns):
            dx, dy = df['ball_land_x'] - df['x'], df['ball_land_y'] - df['y']
            bearing = np.arctan2(dy, dx)
            a_dir = np.deg2rad(df['dir'].fillna(0.0).values)
            # Signed angle (degrees, wrapped) from the heading to the landing spot.
            df['bearing_to_land_signed'] = np.rad2deg(np.arctan2(np.sin(bearing - a_dir), np.cos(bearing - a_dir)))
            ux, uy = np.cos(a_dir), np.sin(a_dir)
            # 2D cross product of the heading unit vector with the offset to the landing spot.
            df['land_lateral_offset'] = dy*ux - dx*uy

        ddir = df.groupby(self.gcols)['dir'].diff().fillna(0.0)
        ddir = ((ddir + 180.0) % 360.0) - 180.0
        # Heading change (rad) per distance covered in one frame (s * 0.1 = s / FPS
        # with FPS = 10); zero speed gives NaN, which is then filled with 0.
        curvature_val = np.deg2rad(ddir).astype('float32') / (df['s'].replace(0, np.nan).astype('float32') * 0.1 + 1e-6)
        df['curvature_abs'] = curvature_val.fillna(0.0).abs()
        # level=[0,1,2] is equivalent to dropping the three self.gcols levels.
        r2 = df.groupby(self.gcols)['curvature_abs'].rolling(3, min_periods=1).mean().reset_index(level=[0,1,2], drop=True)
        df['curv_abs_roll3'] = r2

        # NOTE(review): acceleration_x/y are built from `dir` in
        # _create_basic_features, so this angle is 0 or pi and a_tangential
        # evaluates to |a| (up to float error). Kept for parity with training.
        accel_angle_rad = np.arctan2(df['acceleration_y'], df['acceleration_x'])
        dir_rad = np.deg2rad(df['dir'].fillna(0.0))
        delta_angle_rad = wrap_angle_rad(accel_angle_rad - dir_rad)
        df['a_tangential'] = df['a'] * np.cos(delta_angle_rad)

        new_cols = ['bearing_to_land_signed','land_lateral_offset', 'curvature_abs', 'curv_abs_roll3', 'a_tangential']
        return df, [c for c in new_cols if c in df.columns]

    def transform(self, df):
        """Run all feature groups; return (feature frame, sorted unique feature names).

        The returned frame is a sorted copy; the input ``df`` is not modified.
        """
        df = df.copy().sort_values(['game_id','play_id','nfl_id','frame_id'])
        df = self._create_basic_features(df)
        for group_name in self.active_groups:
            if group_name in self.feature_creators:
                df, new_cols = self.feature_creators[group_name](df)
                self.created_feature_cols.extend(new_cols)
        # Sorted so the column order is deterministic (the scalers/models depend on it).
        return df, sorted(list(set(self.created_feature_cols)))

# -------------------------------
# GNN-lite 特征函数 (v4.9_2D)
# -------------------------------
def compute_neighbor_embeddings(input_df, cfg):
    """Aggregate "GNN-lite" neighbourhood features for every player at its last frame.

    Each (game_id, play_id, nfl_id) is anchored at its last observed frame and
    paired with every other player of the same play at that frame. Pairs farther
    than ``cfg.RADIUS`` are dropped, the ``cfg.K_NEIGH`` nearest are kept, and
    relative positions / velocities are combined with weights
    ``exp(-dist / cfg.TAU)`` normalised over each player's kept neighbours
    (allies included), restricted to opponents and further to opposing
    coverage defenders ("dc") / targeted receivers ("tr"). The ``*_mean``
    features are therefore weighted sums using those normalised weights.

    Args:
        input_df: frame-level DataFrame containing the columns in ``cols_needed``
            (``velocity_x``/``velocity_y`` come from the feature engineering).
        cfg: object exposing ``K_NEIGH``, ``RADIUS`` and ``TAU``; ``K_NEIGH`` and
            ``RADIUS`` may be ``None`` to disable the corresponding filter.

    Returns:
        DataFrame with game_id/play_id/nfl_id, the aggregated ``gnn_*`` columns
        and ``gnn_d1``..``gnn_d3`` (distances to the three nearest neighbours),
        one row per player with at least one neighbour. ``gnn_d1``..``gnn_d3``
        are always present, even when no player has a 2nd/3rd neighbour, so the
        downstream feature vector keeps a fixed width. Counts are filled with 0,
        distances with the radius (30.0 if ``RADIUS`` is None), the rest with 0.
    """
    k_neigh, radius, tau = cfg.K_NEIGH, cfg.RADIUS, cfg.TAU
    keys = ["game_id", "play_id", "nfl_id"]
    # Neighbour rank -> output column; all three columns are always emitted.
    rank_names = {1: "gnn_d1", 2: "gnn_d2", 3: "gnn_d3"}
    cols_needed = ["game_id", "play_id", "nfl_id", "frame_id", "x", "y", "velocity_x", "velocity_y", "player_side", "player_role"]
    src = input_df[cols_needed].copy()

    # Each player's last observed frame is the anchor of its neighbourhood.
    last = (
        src.sort_values(["game_id", "play_id", "nfl_id", "frame_id"])
        .groupby(keys, as_index=False)
        .tail(1)
        .rename(columns={"frame_id": "last_frame_id"})
        .reset_index(drop=True)
    )

    # Pair each anchor with all players (columns suffixed "_nb") of the same play at that frame.
    nb_cols = {c: f"{c}_nb" for c in src.columns if c not in ["game_id", "play_id"]}
    nb_cols["frame_id"] = "nb_frame_id"
    tmp = last.merge(
        src.rename(columns=nb_cols),
        left_on=["game_id", "play_id", "last_frame_id"],
        right_on=["game_id", "play_id", "nb_frame_id"],
        how="left",
    )
    tmp = tmp[tmp["nfl_id_nb"] != tmp["nfl_id"]]  # drop self-pairs

    tmp["dx"] = tmp["x_nb"] - tmp["x"]
    tmp["dy"] = tmp["y_nb"] - tmp["y"]
    tmp["dvx"] = tmp["velocity_x_nb"] - tmp["velocity_x"]
    tmp["dvy"] = tmp["velocity_y_nb"] - tmp["velocity_y"]
    tmp["dist"] = np.sqrt(tmp["dx"]**2 + tmp["dy"]**2)
    # Also removes unmatched left-join rows (NaN distance) and co-located players.
    tmp = tmp[np.isfinite(tmp["dist"]) & (tmp["dist"] > 1e-6)]
    if radius is not None:
        tmp = tmp[tmp["dist"] <= radius]

    tmp["is_ally"] = (tmp["player_side_nb"] == tmp["player_side"]).astype(np.float32)
    tmp['is_dc_nb'] = (tmp['player_role_nb'] == 'Defensive Coverage').astype(np.float32)
    tmp['is_tr_nb'] = (tmp['player_role_nb'] == 'Targeted Receiver').astype(np.float32)

    # Keep the k nearest neighbours; ties are broken by row order (method="first").
    tmp["rnk"] = tmp.groupby(keys)["dist"].rank(method="first")
    if k_neigh is not None:
        tmp = tmp[tmp["rnk"] <= float(k_neigh)]

    # Distance-decay weights normalised to sum to 1 over each player's kept neighbours.
    tmp["w"] = np.exp(-tmp["dist"] / float(tau))
    sum_w = tmp.groupby(keys)["w"].transform("sum")
    tmp["wn"] = np.where(sum_w > 0, tmp["w"] / sum_w, 0.0)
    # Zero the weights of allies, then of non-coverage / non-target opponents.
    tmp["wn_opp"] = tmp["wn"] * (1.0 - tmp["is_ally"])
    tmp["wn_opp_dc"] = tmp["wn_opp"] * tmp['is_dc_nb']
    tmp["wn_opp_tr"] = tmp["wn_opp"] * tmp['is_tr_nb']
    for col in ["dx", "dy", "dvx", "dvy"]:
        tmp[f"{col}_opp_w"] = tmp[col] * tmp["wn_opp"]
        tmp[f"{col}_opp_dc_w"] = tmp[col] * tmp["wn_opp_dc"]
        tmp[f"{col}_opp_tr_w"] = tmp[col] * tmp["wn_opp_tr"]
    tmp["dist_opp_dc"] = np.where((tmp["is_ally"] < 0.5) & (tmp['is_dc_nb'] > 0.5), tmp["dist"], np.nan)
    tmp["dist_opp_tr"] = np.where((tmp["is_ally"] < 0.5) & (tmp['is_tr_nb'] > 0.5), tmp["dist"], np.nan)

    agg_dict = {
        'gnn_opp_dx_mean': ('dx_opp_w', 'sum'), 'gnn_opp_dy_mean': ('dy_opp_w', 'sum'),
        'gnn_opp_dvx_mean': ('dvx_opp_w', 'sum'), 'gnn_opp_dvy_mean': ('dvy_opp_w', 'sum'),
        'gnn_dc_dx_mean': ('dx_opp_dc_w', 'sum'), 'gnn_dc_dy_mean': ('dy_opp_dc_w', 'sum'),
        'gnn_dc_dvx_mean': ('dvx_opp_dc_w', 'sum'), 'gnn_dc_dvy_mean': ('dvy_opp_dc_w', 'sum'),
        # Counts of *opposing* coverage defenders / targeted receivers among the kept neighbours.
        'gnn_dc_dmin': ('dist_opp_dc', 'min'), 'gnn_dc_cnt': ('is_dc_nb', lambda s: (s * (1 - tmp.loc[s.index, 'is_ally'])).sum()),
        'gnn_tr_dx_mean': ('dx_opp_tr_w', 'sum'), 'gnn_tr_dy_mean': ('dy_opp_tr_w', 'sum'),
        'gnn_tr_dmin': ('dist_opp_tr', 'min'), 'gnn_tr_cnt': ('is_tr_nb', lambda s: (s * (1 - tmp.loc[s.index, 'is_ally'])).sum())
    }
    ag = tmp.groupby(keys).agg(**agg_dict).reset_index()

    # Distances to the 1st..3rd nearest neighbours as separate columns.
    near = tmp.loc[tmp["rnk"] <= 3, keys + ["rnk", "dist"]].copy()
    near["rnk"] = near["rnk"].astype(int)
    if near.empty:
        for name in rank_names.values():
            ag[name] = np.nan
    else:
        # pivot_table only creates columns for ranks that occur; reindex so that
        # gnn_d2/gnn_d3 exist (NaN -> filled below) even if no player has that
        # many neighbours, keeping the feature vector width fixed.
        dwide = (
            near.pivot_table(index=keys, columns="rnk", values="dist", aggfunc="first")
            .reindex(columns=list(rank_names))
            .rename(columns=rank_names)
            .reset_index()
        )
        ag = ag.merge(dwide, on=keys, how="left")

    fill_val = radius if radius is not None else 30.0
    for c in ag.columns:
        if c.startswith("gnn_"):
            if "cnt" in c:
                ag[c] = ag[c].fillna(0.0)
            elif "dmin" in c or c in rank_names.values():
                # Missing distance = "nothing within range".
                ag[c] = ag[c].fillna(fill_val)
            else:
                ag[c] = ag[c].fillna(0.0)
    return ag

# -------------------------------
# 序列构建与特征划分 (v4.9_2D)
# -------------------------------
def build_play_direction_map(df_in: pd.DataFrame) -> pd.Series:
    """Return ``play_direction`` indexed by (game_id, play_id).

    Built from the distinct (game_id, play_id, play_direction) rows of
    ``df_in``; a play carrying more than one direction value would appear more
    than once in the index.
    """
    unique_rows = df_in[['game_id', 'play_id', 'play_direction']].drop_duplicates()
    by_play = unique_rows.set_index(['game_id', 'play_id'])
    return by_play['play_direction']

def split_features_v5(feature_cols: list) -> (list, list, list):
    """Partition feature names into (dynamic, player-static, gnn-static) lists.

    * player-static: a fixed set of per-player / per-play attributes;
    * gnn-static: every name starting with ``gnn_``;
    * dynamic: everything else.
    Each returned list preserves the order of ``feature_cols``.
    """
    static_names = {
        'ball_land_x', 'ball_land_y', 'player_height_feet', 'player_weight', 'player_bmi',
        'power_weight_ratio', 'is_offense', 'is_defense', 'is_receiver', 'is_coverage',
    }
    dynamic_cols, player_static_cols, gnn_cols = [], [], []
    for col in feature_cols:
        if col in static_names:
            player_static_cols.append(col)
        elif col.startswith('gnn_'):
            gnn_cols.append(col)
        else:
            dynamic_cols.append(col)
    return dynamic_cols, player_static_cols, gnn_cols

def split_dynamic_features(dynamic_cols: list) -> (list, list):
    """Split dynamic feature names into (kinematic, contextual) lists.

    A name is contextual when it starts with any prefix in
    ``CONTEXTUAL_BASE_FEATURES`` (prefix match, so e.g. ``'receiver_'`` covers
    every ``receiver_*`` column); all other names are kinematic. Order is
    preserved. A warning is emitted when no contextual name is found.
    """
    CONTEXTUAL_BASE_FEATURES = ['distance_to_ball','ball_direction','closing_speed','time_to_intercept','velocity_alignment',
                                'velocity_perpendicular','accel_alignment', 'accel_perpendicular', 'dist_from_','receiver_',
                                'defender_','frames_elapsed','bearing_to_land','land_lateral_offset',
                                'ball_angle', 'approach_angle', 'd2ball_d3t', 'closing_efficiency', 'intercept_urgency', 'distance_ema3',
                                'field_zone', 'in_red_zone', 'near_sideline', 'dist_from_center']
    # str.startswith accepts a tuple and matches if any prefix matches.
    prefixes = tuple(CONTEXTUAL_BASE_FEATURES)
    contextual_cols = [col for col in dynamic_cols if col.startswith(prefixes)]
    kinematic_cols = [col for col in dynamic_cols if not col.startswith(prefixes)]
    if not contextual_cols:
        # Message means: "Warning: the contextual feature list is empty!"
        warnings.warn("警告：上下文特征列表为空！")
    return kinematic_cols, contextual_cols

def _fill_missing_ball_landing(input_df_u):
    """Add ``ball_land_x/y`` using each player's own last observed x/y position.

    Remaining NaNs (players with no observed position) are forward-filled, then
    back-filled across the whole frame (not per player), then set to 0.
    """
    keys = ['game_id', 'play_id', 'nfl_id']
    last_pos_self = (
        input_df_u.groupby(keys)[['x', 'y']].last().reset_index()
        .rename(columns={'x': 'ball_land_x', 'y': 'ball_land_y'})
    )
    merged = input_df_u.merge(last_pos_self, on=keys, how='left')
    land_cols = ['ball_land_x', 'ball_land_y']
    # Assign back instead of chained `df[col].fillna(..., inplace=True)`, which does
    # not update `df` under pandas Copy-on-Write; `fillna(method=...)` is deprecated.
    merged[land_cols] = merged[land_cols].ffill().bfill().fillna(0)
    return merged


def _encode_play_players(play_df, kinematic_cols, contextual_cols, player_static_cols, gnn_static_cols, cfg):
    """Build the padded windows and static vectors of every player in one play.

    Returns ``(players_in_play, kin, ctx, p_sta, g_sta)`` where the four dicts are
    keyed by nfl_id. Windows hold the last ``cfg.WINDOW_SIZE`` frames, zero-padded
    at the front; static vectors come from the window's last frame. NaNs -> 0.
    """
    players_in_play = play_df['nfl_id'].unique()
    kin, ctx, p_sta, g_sta = {}, {}, {}, {}
    for nfl_id in players_in_play:
        player_df = play_df[play_df['nfl_id'] == nfl_id].sort_values('frame_id')
        input_window = player_df.tail(cfg.WINDOW_SIZE)
        pad_len = cfg.WINDOW_SIZE - len(input_window)

        kin_window_np = input_window[kinematic_cols].values
        ctx_window_np = input_window[contextual_cols].values
        if pad_len > 0:
            kin_window_np = np.vstack([np.zeros((pad_len, len(kinematic_cols))), kin_window_np])
            ctx_window_np = np.vstack([np.zeros((pad_len, len(contextual_cols))), ctx_window_np])

        kin[nfl_id] = np.nan_to_num(kin_window_np, nan=0.0)
        ctx[nfl_id] = np.nan_to_num(ctx_window_np, nan=0.0)
        p_sta[nfl_id] = np.nan_to_num(input_window[player_static_cols].iloc[-1].values, nan=0.0)
        g_sta[nfl_id] = np.nan_to_num(input_window[gnn_static_cols].iloc[-1].values, nan=0.0)
    return players_in_play, kin, ctx, p_sta, g_sta


def prepare_sequences_for_multistream(input_df, cfg):
    """Turn raw tracking frames into model-ready samples, one per (game_id, play_id, nfl_id).

    Plays are first mirrored into the left-moving frame, features are engineered
    (``FeatureEngineer``) and neighbour features are merged in
    (``compute_neighbor_embeddings``). For each player present in the input, a
    sample stacks its own window in slot 0 followed by the other players of the
    play (first-appearance order, truncated to ``cfg.MAX_PLAYERS - 1``).

    Returns a 10-tuple:
        kin_seqs: list of float32 arrays ``(MAX_PLAYERS, WINDOW_SIZE, n_kinematic)``;
        ctx_seqs: list of float32 arrays ``(MAX_PLAYERS, WINDOW_SIZE, n_contextual)``;
        target player-static vectors; target gnn-static vectors;
        play masks: float32 ``(MAX_PLAYERS,)`` with 1.0 for filled slots;
        seq_meta: dicts with game_id, play_id, nfl_id and the original play_direction;
        kinematic, contextual, player-static and gnn-static column name lists.
    Ten empty lists are returned for empty input.
    """
    if len(input_df) == 0:
        return [], [], [], [], [], [], [], [], [], []

    feature_groups = ['distance_rate','target_alignment','multi_window_rolling','extended_lags','velocity_changes','field_position','role_specific','time_features','jerk_features',"curvature_land_features"]

    # Direction lookup is taken before mirroring so metadata keeps the original direction.
    dir_map = build_play_direction_map(input_df)
    input_df_u = unify_left_direction(input_df)

    # At inference 'ball_land_x/y' may be missing and must be synthesised.
    if 'ball_land_x' not in input_df_u.columns:
        input_df_u = _fill_missing_ball_landing(input_df_u)

    fe = FeatureEngineer(feature_groups)
    processed_df, feature_cols = fe.transform(input_df_u)

    gnn_features_df = compute_neighbor_embeddings(processed_df, cfg)
    processed_df = processed_df.merge(gnn_features_df, on=['game_id', 'play_id', 'nfl_id'], how='left')
    gnn_cols = list(gnn_features_df.columns.drop(['game_id', 'play_id', 'nfl_id']))
    feature_cols.extend(gnn_cols)

    dynamic_cols, player_static_cols, gnn_static_cols = split_features_v5(feature_cols)
    kinematic_cols, contextual_cols = split_dynamic_features(dynamic_cols)

    plays_grouped = processed_df.groupby(['game_id', 'play_id'])

    all_player_kin_seqs, all_player_ctx_seqs = [], []
    target_p_sta_feats, target_g_sta_feats, play_masks = [], [], []
    seq_meta = []

    # Per-player encodings depend only on the play, not on which player is the
    # target, so each play is encoded once and reused for all of its targets
    # (rows are grouped by play after FeatureEngineer's sort).
    cached_play_key, cached_play = None, None

    players_to_predict = processed_df[['game_id', 'play_id', 'nfl_id']].drop_duplicates()

    for _, row in players_to_predict.iterrows():
        gid, pid, target_nfl_id = row['game_id'], row['play_id'], row['nfl_id']
        if (gid, pid) != cached_play_key:
            try:
                play_df = plays_grouped.get_group((gid, pid))
            except KeyError:
                continue
            cached_play = _encode_play_players(
                play_df, kinematic_cols, contextual_cols, player_static_cols, gnn_static_cols, cfg
            )
            cached_play_key = (gid, pid)

        players_in_play, play_kin_features, play_ctx_features, play_p_sta_features, play_g_sta_features = cached_play

        all_kin_seq_np = np.zeros((cfg.MAX_PLAYERS, cfg.WINDOW_SIZE, len(kinematic_cols)), dtype=np.float32)
        all_ctx_seq_np = np.zeros((cfg.MAX_PLAYERS, cfg.WINDOW_SIZE, len(contextual_cols)), dtype=np.float32)
        play_mask_np = np.zeros(cfg.MAX_PLAYERS, dtype=np.float32)

        # Slot 0 is always the target player.
        all_kin_seq_np[0, :, :] = play_kin_features[target_nfl_id]
        all_ctx_seq_np[0, :, :] = play_ctx_features[target_nfl_id]
        play_mask_np[0] = 1.0

        # Remaining slots: other players of the play; extras beyond MAX_PLAYERS - 1 are dropped.
        context_players = [p for p in players_in_play if p != target_nfl_id]
        for slot, context_nfl_id in enumerate(context_players[:cfg.MAX_PLAYERS - 1], start=1):
            all_kin_seq_np[slot, :, :] = play_kin_features[context_nfl_id]
            all_ctx_seq_np[slot, :, :] = play_ctx_features[context_nfl_id]
            play_mask_np[slot] = 1.0

        all_player_kin_seqs.append(all_kin_seq_np)
        all_player_ctx_seqs.append(all_ctx_seq_np)
        target_p_sta_feats.append(play_p_sta_features[target_nfl_id])
        target_g_sta_feats.append(play_g_sta_features[target_nfl_id])
        play_masks.append(play_mask_np)

        play_dir_val = dir_map.loc[(gid, pid)]
        seq_meta.append({'game_id': gid, 'play_id': pid, 'nfl_id': target_nfl_id, 'play_direction': play_dir_val})

    return (all_player_kin_seqs, all_player_ctx_seqs, target_p_sta_feats, target_g_sta_feats, play_masks,
            seq_meta, kinematic_cols, contextual_cols, player_static_cols, gnn_static_cols)

# -------------------------------
# 模型架构 (v4.9_2D)
# -------------------------------
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate over the channel axis of a ``(batch, channels, length)`` tensor.

    Channels are average-pooled over the length axis, passed through a
    bottleneck MLP (``in_channels // reduction_ratio`` hidden units) ending in a
    sigmoid, and the resulting per-channel gate rescales the input. Submodule
    names (``squeeze``, ``excitation``) are part of the checkpoint keys.
    """

    def __init__(self, in_channels, reduction_ratio=8):
        super().__init__()
        bottleneck = in_channels // reduction_ratio
        self.squeeze = nn.AdaptiveAvgPool1d(1)
        self.excitation = nn.Sequential(
            nn.Linear(in_channels, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        pooled = self.squeeze(x).squeeze(-1)           # (batch, channels)
        gate = self.excitation(pooled).unsqueeze(-1)   # (batch, channels, 1)
        return x * gate

class ConvBNGELU(nn.Module):
    """Conv1d (no bias) -> BatchNorm1d -> GELU.

    Padding is ``(kernel_size - 1) // 2``, which keeps the sequence length for
    odd kernel sizes. The layers live in ``self.main`` (checkpoint keys ``main.*``).
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        same_pad = (kernel_size - 1) // 2
        conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=same_pad, bias=False)
        norm = nn.BatchNorm1d(out_channels)
        self.main = nn.Sequential(conv, norm, nn.GELU())

    def forward(self, x):
        return self.main(x)

class ResidualSECNNBlock(nn.Module):
    """Two ConvBNGELU layers + squeeze-excitation with a residual connection.

    The shortcut is the identity when channel counts match, otherwise a 1x1
    Conv1d + BatchNorm1d projection. After the residual add a GELU is applied,
    followed by optional max-pooling (``pool_size > 1``) and dropout. Attribute
    names are part of the checkpoint keys and must not change.
    """

    def __init__(self, in_channels, out_channels, kernel_size, pool_size=1, dropout=0.2):
        super().__init__()
        self.conv1 = ConvBNGELU(in_channels, out_channels, kernel_size)
        self.conv2 = ConvBNGELU(out_channels, out_channels, kernel_size)
        self.se = SEBlock(out_channels)
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm1d(out_channels),
            )
        self.pool = nn.MaxPool1d(pool_size) if pool_size > 1 else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.final_act = nn.GELU()

    def forward(self, x):
        skip = self.shortcut(x)
        out = self.se(self.conv2(self.conv1(x)))
        out = self.final_act(out + skip)
        return self.dropout(self.pool(out))

class MultiStreamCrossAttentionModel(nn.Module):
    """Joint (dx, dy) trajectory model with per-player sequence encoders and target cross-attention.

    Each player's kinematic and contextual sequences go through separate
    projection + 1-D CNN branches. The kinematic branch gets ``int(0.75 * D_MODEL)``
    channels and the contextual branch the rest. The branches are concatenated,
    given a learned positional encoding, self-attended, and mean-pooled over time.
    The target player (slot 0) then cross-attends to all non-padded players. The
    result is fused with the target's static and GNN-static encodings, and an MLP
    head emits ``horizon`` per-step (dx, dy) pairs, which are cumulatively summed.

    Args:
        kin_dim: per-frame kinematic feature width.
        ctx_dim: per-frame contextual feature width.
        player_sta_dim: width of the target player's static feature vector.
        gnn_sta_dim: width of the GNN-static feature vector.
        horizon: number of future steps the head predicts.
        cfg: Config-like object providing D_MODEL, NHEAD, NUM_ENCODER_LAYERS,
            DIM_FEEDFORWARD, WINDOW_SIZE and the dropout / hidden-size settings.

    NOTE: submodule attribute names and the layer order inside every
    nn.Sequential define the state_dict keys loaded from the fold checkpoints.
    Do not rename or reorder them.
    """

    def __init__(self, kin_dim, ctx_dim, player_sta_dim, gnn_sta_dim, horizon, cfg):
        super().__init__()
        self.cfg = cfg
        # Keep the horizon the head was built with. forward() must reshape with
        # this value, not cfg.MAX_FUTURE_HORIZON, which could differ.
        self.horizon = horizon

        kin_branch_dim = int(cfg.D_MODEL * 0.75)
        self.kin_proj = nn.Linear(kin_dim, kin_branch_dim)
        self.kin_cnn = nn.Sequential(
            ResidualSECNNBlock(kin_branch_dim, kin_branch_dim, kernel_size=3, dropout=0.1),
            ResidualSECNNBlock(kin_branch_dim, kin_branch_dim, kernel_size=5, dropout=0.1),
        )
        ctx_branch_dim = cfg.D_MODEL - kin_branch_dim
        self.ctx_proj = nn.Linear(ctx_dim, ctx_branch_dim)
        self.ctx_cnn = nn.Sequential(ConvBNGELU(ctx_branch_dim, ctx_branch_dim, kernel_size=3), nn.Dropout(0.1))

        # Learned positional encoding for up to WINDOW_SIZE time steps.
        self.pos_encoding = nn.Parameter(torch.randn(1, cfg.WINDOW_SIZE, cfg.D_MODEL) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=cfg.D_MODEL, nhead=cfg.NHEAD, dim_feedforward=cfg.DIM_FEEDFORWARD,
            dropout=cfg.TRANSFORMER_DROPOUT, activation='gelu', batch_first=True, norm_first=True,
        )
        self.self_attention_encoder = nn.TransformerEncoder(encoder_layer, num_layers=cfg.NUM_ENCODER_LAYERS)

        self.cross_attention = nn.MultiheadAttention(
            embed_dim=cfg.D_MODEL, num_heads=cfg.NHEAD, dropout=cfg.TRANSFORMER_DROPOUT, batch_first=True,
        )
        self.cross_attn_norm = nn.LayerNorm(cfg.D_MODEL)
        self.cross_attn_ffn = nn.Sequential(
            nn.Linear(cfg.D_MODEL, cfg.D_MODEL * 2), nn.GELU(), nn.Dropout(cfg.TRANSFORMER_DROPOUT),
            nn.Linear(cfg.D_MODEL * 2, cfg.D_MODEL),
        )
        self.cross_attn_ffn_norm = nn.LayerNorm(cfg.D_MODEL)

        self.player_static_encoder = nn.Sequential(
            nn.Linear(player_sta_dim, cfg.PLAYER_STATIC_HIDDEN_DIM), nn.LayerNorm(cfg.PLAYER_STATIC_HIDDEN_DIM),
            nn.GELU(), nn.Dropout(cfg.PLAYER_STATIC_DROPOUT),
        )
        self.gnn_static_encoder = nn.Sequential(
            nn.Linear(gnn_sta_dim, cfg.GNN_HIDDEN_DIM), nn.LayerNorm(cfg.GNN_HIDDEN_DIM),
            nn.GELU(), nn.Dropout(cfg.GNN_DROPOUT),
        )

        fusion_input_dim = cfg.D_MODEL + cfg.PLAYER_STATIC_HIDDEN_DIM + cfg.GNN_HIDDEN_DIM
        self.head = nn.Sequential(
            nn.Linear(fusion_input_dim, 256), nn.GELU(), nn.Dropout(cfg.HEAD_DROPOUT),
            nn.Linear(256, 128), nn.GELU(), nn.Dropout(cfg.HEAD_DROPOUT),
            nn.Linear(128, horizon * 2),
        )

    def forward(self, all_kin_seq, all_ctx_seq, player_sta_feat, gnn_sta_feat, play_mask):
        """Predict cumulative (dx, dy) offsets for the target player.

        Args:
            all_kin_seq: (B, N_players, T, kin_dim). Player slot 0 is the target.
            all_ctx_seq: (B, N_players, T, ctx_dim).
            player_sta_feat: (B, player_sta_dim).
            gnn_sta_feat: (B, gnn_sta_dim).
            play_mask: (B, N_players). Entries equal to 0 mark padded player
                slots, which cross-attention ignores.

        Returns:
            (B, horizon, 2) tensor: the running sum of the head's per-step (dx, dy).
        """
        batch_size, n_players, seq_len, _ = all_kin_seq.shape
        # Fold players into the batch so each player's sequence is encoded independently.
        # reshape (not view) also accepts non-contiguous inputs.
        kin_flat = all_kin_seq.reshape(batch_size * n_players, seq_len, -1)
        ctx_flat = all_ctx_seq.reshape(batch_size * n_players, seq_len, -1)

        # The CNN branches expect (batch, channels, time): transpose in, then back out.
        kin_features = self.kin_cnn(self.kin_proj(kin_flat).transpose(1, 2)).transpose(1, 2)
        ctx_features = self.ctx_cnn(self.ctx_proj(ctx_flat).transpose(1, 2)).transpose(1, 2)
        dyn_embed = torch.cat([kin_features, ctx_features], dim=-1) + self.pos_encoding[:, :seq_len, :]

        # Self-attention over time, then mean-pool to one vector per player.
        player_repr_flat = self.self_attention_encoder(dyn_embed).mean(dim=1)
        all_player_repr = player_repr_flat.reshape(batch_size, n_players, -1)
        target_repr = all_player_repr[:, 0:1, :]

        # key_padding_mask: True means "ignore this key", i.e. padded player slots.
        cross_attn_out, _ = self.cross_attention(
            query=target_repr, key=all_player_repr, value=all_player_repr,
            key_padding_mask=(play_mask == 0),
        )
        fused_dyn_repr = self.cross_attn_norm(target_repr + cross_attn_out)
        fused_dyn_repr = fused_dyn_repr + self.cross_attn_ffn(fused_dyn_repr)
        fused_dyn_repr = self.cross_attn_ffn_norm(fused_dyn_repr).squeeze(1)

        player_static_repr = self.player_static_encoder(player_sta_feat)
        gnn_static_repr = self.gnn_static_encoder(gnn_sta_feat)
        final_repr = torch.cat([fused_dyn_repr, player_static_repr, gnn_static_repr], dim=1)

        per_step = self.head(final_repr).reshape(-1, self.horizon, 2)
        return torch.cumsum(per_step, dim=1)

# =============================================================================
# 3. 全局资产加载
# =============================================================================
print("--- [API] 启动全局加载流程 ---")
start_time = time.time()
cfg = Config()

# Per-fold ensemble members. predict() refuses to run unless both are non-empty.
MODELS_JOINT, SCALERS = [], []
# Model input widths, taken from the fold-1 scalers' n_features_in_.
KIN_DIM, CTX_DIM, P_STA_DIM, G_STA_DIM = 0, 0, 0, 0

try:
    print(f"从 {cfg.MODEL_DIR} 加载 Scalers 和模型...")

    for fold in range(1, cfg.N_FOLDS + 1):
        # Read this fold's four scalers exactly once. Their order (kinematic,
        # contextual, player-static, GNN-static) is the tuple order predict() unpacks.
        # NOTE: pickle.load executes code; only load files from the uploaded
        # model dataset referenced by Config.MODEL_DIR.
        fold_scalers = []
        for scaler_name in ("scaler_kin", "scaler_ctx", "scaler_psta", "scaler_gsta"):
            with open(cfg.MODEL_DIR / f"{scaler_name}_fold{fold}.pkl", 'rb') as f:
                fold_scalers.append(pickle.load(f))

        if fold == 1:
            # Derive the feature dimensions from the first fold's scalers.
            KIN_DIM, CTX_DIM, P_STA_DIM, G_STA_DIM = (sc.n_features_in_ for sc in fold_scalers)
            print(f"检测到特征维度: Kinematic={KIN_DIM}, Contextual={CTX_DIM}, PlayerStatic={P_STA_DIM}, GNN={G_STA_DIM}")

        model = MultiStreamCrossAttentionModel(KIN_DIM, CTX_DIM, P_STA_DIM, G_STA_DIM, cfg.MAX_FUTURE_HORIZON, cfg)
        # Load the weights onto the CPU first, then move the whole model to the
        # target device and switch to eval mode.
        model.load_state_dict(torch.load(cfg.MODEL_DIR / f"model_joint_fold{fold}.pth", map_location='cpu'))
        model.to(cfg.DEVICE).eval()

        MODELS_JOINT.append(model)
        SCALERS.append(tuple(fold_scalers))

    print(f"✅ 成功加载 {len(MODELS_JOINT)} 折模型和 scalers。耗时: {time.time() - start_time:.2f} 秒。")

except Exception as e:
    print(f"!!!!!! [错误] 全局加载失败: {e} !!!!!!")
    import traceback
    traceback.print_exc()
    # Discard any partially loaded folds so predict() falls back to zero
    # predictions instead of running an incomplete ensemble.
    MODELS_JOINT, SCALERS = [], []

# =============================================================================
# 4. 'predict' 函数
# =============================================================================
def _zero_prediction_frame(n_rows: int) -> pd.DataFrame:
    """Return an all-zero ``(x, y)`` frame with ``n_rows`` rows (predict's fallback answer)."""
    return pd.DataFrame({'x': [0.0] * n_rows, 'y': [0.0] * n_rows})


def _scale_sequences(scaler, sequences) -> np.ndarray:
    """Scale a batch of per-player feature sequences with a single ``transform`` call.

    Args:
        scaler: fitted scaler exposing ``transform`` (presumably the pickled
            StandardScaler of the current fold; TODO confirm for all artifacts).
        sequences: one entry per sample, each a collection of per-player 2-D
            ``(time, features)`` arrays of identical shape.

    Returns:
        float32 array of shape ``(samples, players, time, features)``.

    Flattening every row into one ``(rows, features)`` matrix gives the same
    values as scaling each player sequence separately, because the scaler
    transforms each feature column independently (true for StandardScaler).
    """
    stacked = np.stack([np.stack(sample) for sample in sequences])
    scaled = scaler.transform(stacked.reshape(-1, stacked.shape[-1]))
    return np.asarray(scaled).reshape(stacked.shape).astype(np.float32)


def _ensemble_joint_predictions(all_kin, all_ctx, tar_p_sta, tar_g_sta, play_m) -> np.ndarray:
    """Run every loaded fold model in mini-batches and average their outputs.

    Each fold uses its own scalers. Returns an array of shape
    ``(num_samples, horizon, 2)`` holding the fold-mean cumulative (dx, dy)
    offsets, in input sample order.
    """
    num_samples = len(all_kin)
    batch_size = cfg.INFERENCE_BATCH_SIZE
    device = cfg.DEVICE
    per_fold = []
    with torch.no_grad():
        # Pair each model with its own fold's scalers, and iterate only over
        # what was actually loaded.
        for model, (sc_kin, sc_ctx, sc_ps, sc_gs) in zip(MODELS_JOINT, SCALERS):
            chunks = []
            for start in range(0, num_samples, batch_size):
                stop = min(start + batch_size, num_samples)
                kin = _scale_sequences(sc_kin, all_kin[start:stop])
                ctx = _scale_sequences(sc_ctx, all_ctx[start:stop])
                p_sta = sc_ps.transform(np.array(tar_p_sta[start:stop])).astype(np.float32)
                g_sta = sc_gs.transform(np.array(tar_g_sta[start:stop])).astype(np.float32)
                mask = np.stack(play_m[start:stop])
                out = model(
                    torch.tensor(kin).to(device),
                    torch.tensor(ctx).to(device),
                    torch.tensor(p_sta).to(device),
                    torch.tensor(g_sta).to(device),
                    torch.tensor(mask).to(device),
                )
                chunks.append(out.cpu().numpy())
            per_fold.append(np.vstack(chunks))
    return np.mean(per_fold, axis=0)


def _last_observed_positions(test_input_df: pd.DataFrame) -> pd.DataFrame:
    """Return each player's last observed (x, y) after ``unify_left_direction``.

    Output columns: game_id, play_id, nfl_id, x, y (one row per player).
    """
    unified = unify_left_direction(test_input_df.copy())
    # groupby().last() follows row order, so sort by frame first. That way "last"
    # is the final observed frame even if the input rows arrive unordered.
    # A stable sort leaves already-ordered data unchanged.
    if 'frame_id' in unified.columns:
        unified = unified.sort_values('frame_id', kind='stable')
    return unified.groupby(['game_id', 'play_id', 'nfl_id'])[['x', 'y']].last().reset_index()


def predict(test_df_pl: pl.DataFrame, test_input_df_pl: pl.DataFrame) -> pd.DataFrame:
    """Inference-server callback: predict (x, y) for every row of ``test_df_pl``.

    Args:
        test_df_pl: rows to predict. Must contain game_id, play_id, nfl_id and frame_id.
        test_input_df_pl: observed tracking rows used to build the model inputs.

    Returns:
        pandas DataFrame with columns ``x``, ``y`` and exactly one row per row of
        ``test_df_pl``, in the same order. If assets are missing, inputs are
        empty, or inference raises, the result is all-zero predictions instead
        of an exception.
    """
    if not all([MODELS_JOINT, SCALERS]):
        print("错误: predict 被调用，但全局模型/scalers未加载。返回零预测。")
        return _zero_prediction_frame(len(test_df_pl))

    try:
        # --- 1. Data preparation ---
        test_df = test_df_pl.to_pandas()
        test_input_df = test_input_df_pl.to_pandas()
        if test_input_df.empty or test_df.empty:
            return _zero_prediction_frame(len(test_df))

        # --- 2. Feature engineering & sequence construction ---
        (all_kin, all_ctx, tar_p_sta, tar_g_sta, play_m, seq_meta,
         _, _, _, _) = prepare_sequences_for_multistream(test_input_df, cfg=cfg)
        if not all_kin:
            return _zero_prediction_frame(len(test_df))

        # --- 3. Fold-ensemble prediction of cumulative (dx, dy) offsets ---
        avg_xy = _ensemble_joint_predictions(all_kin, all_ctx, tar_p_sta, tar_g_sta, play_m)

        # --- 4. Post-processing: offsets -> absolute coordinates in the original direction ---
        meta_df = pd.DataFrame(seq_meta).merge(
            _last_observed_positions(test_input_df), on=['game_id', 'play_id', 'nfl_id'], how='left'
        )
        # A left merge keeps the left row order, so meta_df row i matches avg_xy[i].
        pred_map = {
            (int(gid), int(pid), int(nid)): (avg_xy[i], last_x, last_y, direction == 'right')
            for i, (gid, pid, nid, last_x, last_y, direction) in enumerate(zip(
                meta_df['game_id'], meta_df['play_id'], meta_df['nfl_id'],
                meta_df['x'], meta_df['y'], meta_df['play_direction'],
            ))
        }

        xs, ys = [], []
        for gid, pid, nid, frame_id in zip(
            test_df['game_id'], test_df['play_id'], test_df['nfl_id'], test_df['frame_id']
        ):
            entry = pred_map.get((int(gid), int(pid), int(nid)))
            if entry is None:
                # No sequence was built for this player, so fall back to (0, 0).
                xs.append(0.0)
                ys.append(0.0)
                continue
            dxy, last_x, last_y, play_dir_right = entry
            # frame_id is treated as 1-based (frame 1 -> first predicted step),
            # clamped into [0, horizon - 1] so out-of-range ids never wrap around.
            step = min(max(int(frame_id) - 1, 0), len(dxy) - 1)
            pred_dx, pred_dy = dxy[step]
            x_orig, y_orig = invert_to_original_direction(last_x + pred_dx, last_y + pred_dy, play_dir_right)
            xs.append(x_orig)
            ys.append(y_orig)

        # Exactly one output row was appended per test_df row above.
        return pd.DataFrame({'x': xs, 'y': ys})

    except Exception as e:
        print(f"!!!!!! [错误] 'predict' 函数执行时发生异常: {e} !!!!!!")
        import traceback
        traceback.print_exc()
        # On any error, return a correctly shaped all-zero prediction frame.
        return _zero_prediction_frame(len(test_df_pl))

# =============================================================================
# 5. API 服务器启动
# =============================================================================
inference_server = kaggle_evaluation.nfl_inference_server.NFLInferenceServer(predict)

# A non-empty KAGGLE_IS_COMPETITION_RERUN selects the real submission rerun.
# Otherwise the server is exercised locally through the gateway for debugging.
if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    print("--- [API] 启动本地网关模拟 ---")
    # Competition data directory on the Kaggle platform.
    inference_server.run_local_gateway(('/kaggle/input/nfl-big-data-bowl-2026-prediction/',))
    print("--- [API] 本地网关模拟完成。")
else:
    print("--- [API] 启动 NFLInferenceServer (竞赛重跑模式) ---")
    inference_server.serve()