## Note
The code is mostly based on [NFL Big Data Bowl 2026 - Geometry GNN [LB: .586]](https://www.kaggle.com/code/atl132/nfl-big-data-bowl-2026-geometry-gnn-lb-586); some functions are copied from [NFL2026 Prediction|Openmind on Hsiaosuan Exp❥(^_-)](https://www.kaggle.com/code/atl132/nfl2026-prediction-openmind-on-hsiaosuan-exp).

In [None]:
import random
import os,sys
import pickle
from pathlib import Path
import warnings
import math
from typing import List

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LRScheduler
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold,StratifiedGroupKFold
from sklearn.cluster import KMeans

# warnings.filterwarnings('ignore')

In [None]:
# ============================================================================
# CONFIG
# ============================================================================

class Config:
    """Namespace of notebook-wide settings, read as class attributes (never instantiated here)."""

    # ---- paths / run mode -------------------------------------------------
    DATA_DIR = Path("/kaggle/input/nfl-big-data-bowl-2026-prediction/")
    OUTPUT_DIR = Path("./outputs")
    # Side effect at class-definition time: the output directory is created.
    OUTPUT_DIR.mkdir(exist_ok=True)
    DEBUG = True

    # ---- training loop (presumably consumed by code outside this section) --
    SEED = 1  # also used as the KMeans random_state for route clustering
    N_FOLDS = 5
    BATCH_SIZE = 256
    EPOCHS = 2000
    PATIENCE = 30
    LEARNING_RATE = 2e-4
    AUG = False
    Split = "sgkf"
    scheduler = "ReduceLROnPlateau"

    # ---- model / tensor sizes ---------------------------------------------
    Model_Name = "STTransformer"  # alternative: "STSeqModel"
    MAX_PLAYERS = 17  # player axis of every sequence is padded up to this size
    MAX_PLAYERS2Predict = 9
    WINDOW_SIZE = 10
    DMODEL = 256
    MAX_FUTURE_HORIZON = 60

    # ---- field extent used to clip geometric endpoints --------------------
    FIELD_X_MIN = 0.0
    FIELD_X_MAX = 120.0
    FIELD_Y_MIN = 0.0
    FIELD_Y_MAX = 53.3
    Target_Scale_x = 1.
    Target_Scale_y = 1.

    # ---- neighbour-graph and route-clustering features --------------------
    K_NEIGH = 6  # default number of nearest neighbours kept per player
    RADIUS = 30.0  # default neighbour cut-off distance
    TAU = 8.0  # default length scale of the exp(-dist / tau) neighbour weights
    N_ROUTE_CLUSTERS = 7  # number of KMeans route clusters
    
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Also exports PYTHONHASHSEED and switches cuDNN to deterministic mode with
    benchmarking turned off.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Run on the first CUDA device when PyTorch reports one, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): `enable_nested_tensor` is documented as a constructor argument of
# nn.TransformerEncoder; confirm that assigning it on torch.backends.cuda has any
# effect, or pass enable_nested_tensor=False where the encoder is built.
torch.backends.cuda.enable_nested_tensor = False
# ============================================================================
# GEOMETRIC BASELINE - THE BREAKTHROUGH
# ============================================================================

   
def compute_geometric_endpoint(df, field_x_bounds=None, field_y_bounds=None):
    """
    Compute a rule-based ("geometric") end position for every row.

    Rules, applied in order (later rules overwrite earlier ones):
      0. Default: constant-velocity extrapolation, x + velocity_x * t.
      1. Rows with player_role == 'Targeted Receiver' end at the ball landing point.
      2. Rows with player_role == 'Defensive Coverage' whose mirrored receiver is
         closer than 15 end at the ball landing point plus their current offset
         to that receiver.
    Rules 1 and 2 only apply when the frame has a 'ball_land_x' column; rule 2
    additionally needs 'mirror_offset_x', 'mirror_offset_y' and 'mirror_wr_dist'.
    When those mirror columns are missing the defenders simply keep the
    momentum estimate (previously this case raised AttributeError).

    Args:
        df: DataFrame with at least 'x', 'y', 'velocity_x', 'velocity_y' and,
            when 'ball_land_x' is present, 'ball_land_y' and 'player_role'.
        field_x_bounds: optional (min, max) used to clip 'geo_endpoint_x';
            defaults to (Config.FIELD_X_MIN, Config.FIELD_X_MAX).
        field_y_bounds: optional (min, max) used to clip 'geo_endpoint_y';
            defaults to (Config.FIELD_Y_MIN, Config.FIELD_Y_MAX).

    Returns:
        A copy of ``df`` with 'time_to_endpoint', 'geo_endpoint_x' and
        'geo_endpoint_y' added. The input frame is not modified.
    """
    df = df.copy()

    # Resolve the bounds lazily so Config is only needed when no override is given.
    if field_x_bounds is None:
        field_x_bounds = (Config.FIELD_X_MIN, Config.FIELD_X_MAX)
    if field_y_bounds is None:
        field_y_bounds = (Config.FIELD_Y_MIN, Config.FIELD_Y_MAX)

    # Time to play end: output frame count / 10 (presumably 10 frames per
    # second - confirm), or a fixed 3.0 when the count is unknown.
    if 'num_frames_output' in df.columns:
        t_total = df['num_frames_output'] / 10.0
    else:
        t_total = 3.0
    df['time_to_endpoint'] = t_total

    # Rule 0: momentum (default for everyone)
    df['geo_endpoint_x'] = df['x'] + df['velocity_x'] * t_total
    df['geo_endpoint_y'] = df['y'] + df['velocity_y'] * t_total

    if 'ball_land_x' in df.columns:
        # Rule 1: targeted receivers converge to the ball
        receiver_mask = df['player_role'] == 'Targeted Receiver'
        df.loc[receiver_mask, 'geo_endpoint_x'] = df.loc[receiver_mask, 'ball_land_x']
        df.loc[receiver_mask, 'geo_endpoint_y'] = df.loc[receiver_mask, 'ball_land_y']

        # Rule 2: defenders mirror receivers (keep their current offset)
        mirror_cols = ('mirror_offset_x', 'mirror_offset_y', 'mirror_wr_dist')
        if all(col in df.columns for col in mirror_cols):
            defender_mask = df['player_role'] == 'Defensive Coverage'
            has_mirror = df['mirror_offset_x'].notna() & (df['mirror_wr_dist'] < 15)
            coverage_mask = defender_mask & has_mirror

            df.loc[coverage_mask, 'geo_endpoint_x'] = (
                df.loc[coverage_mask, 'ball_land_x']
                + df.loc[coverage_mask, 'mirror_offset_x'].fillna(0)
            )
            df.loc[coverage_mask, 'geo_endpoint_y'] = (
                df.loc[coverage_mask, 'ball_land_y']
                + df.loc[coverage_mask, 'mirror_offset_y'].fillna(0)
            )

    # Clip to field
    df['geo_endpoint_x'] = df['geo_endpoint_x'].clip(*field_x_bounds)
    df['geo_endpoint_y'] = df['geo_endpoint_y'].clip(*field_y_bounds)

    return df

def add_geometric_features(df):
    """Derive kinematic features relative to the rule-based endpoint.

    Runs compute_geometric_endpoint() first and then adds the displacement to
    that endpoint, the velocity/acceleration needed to cover it in the
    remaining time, the mismatch with the current velocity, an alignment
    score and two role-weighted terms. Expects 'is_receiver' and
    'is_coverage' columns on the input. Returns the augmented frame.
    """
    df = compute_geometric_endpoint(df)

    vel_x = df['velocity_x']
    vel_y = df['velocity_y']

    # Displacement from the current position to the geometric endpoint
    to_end_x = df['geo_endpoint_x'] - df['x']
    to_end_y = df['geo_endpoint_y'] - df['y']
    df['geo_vector_x'] = to_end_x
    df['geo_vector_y'] = to_end_y
    df['geo_distance'] = np.sqrt(to_end_x**2 + to_end_y**2)

    # Velocity required to arrive in time (0.1 keeps the divisor positive)
    horizon = df['time_to_endpoint'] + 0.1
    df['geo_required_vx'] = to_end_x / horizon
    df['geo_required_vy'] = to_end_y / horizon

    # Gap between required and current velocity
    gap_x = df['geo_required_vx'] - vel_x
    gap_y = df['geo_required_vy'] - vel_y
    df['geo_velocity_error_x'] = gap_x
    df['geo_velocity_error_y'] = gap_y
    df['geo_velocity_error'] = np.sqrt(gap_x**2 + gap_y**2)

    # Required constant acceleration from rest (a = 2*dx/t^2), limited to [-10, 10]
    horizon_sq = horizon * horizon
    df['geo_required_ax'] = (2 * to_end_x / horizon_sq).clip(-10, 10)
    df['geo_required_ay'] = (2 * to_end_y / horizon_sq).clip(-10, 10)

    # Cosine-like alignment between current velocity and the path to the endpoint
    speed = np.sqrt(vel_x**2 + vel_y**2)
    softened_dist = df['geo_distance'] + 0.1
    dir_x = to_end_x / softened_dist
    dir_y = to_end_y / softened_dist
    projected = vel_x * dir_x + vel_y * dir_y
    df['geo_alignment'] = projected / (speed + 0.1)

    # Role-specific terms; note the urgency divisor adds a second 0.1 on top of `horizon`
    df['geo_receiver_urgency'] = df['is_receiver'] * df['geo_distance'] / (horizon + 0.1)
    mirror_dist = df.get('mirror_wr_dist', 50)
    df['geo_defender_coupling'] = df['is_coverage'] * (1.0 / (mirror_dist + 1.0))

    return df

# ============================================================================
# PROVEN FEATURE ENGINEERING 
# ============================================================================

def get_velocity(speed, direction_deg):
    """Split a speed into (x, y) components for a heading given in degrees.

    The x component uses the sine and the y component the cosine of the
    heading, so 0 degrees points along +y. Accepts scalars or arrays.
    """
    heading_rad = np.deg2rad(direction_deg)
    vx = speed * np.sin(heading_rad)
    vy = speed * np.cos(heading_rad)
    return vx, vy

def height_to_feet(height_str, default=6.0):
    """Convert a 'feet-inches' string such as '6-2' to decimal feet.

    Args:
        height_str: value whose string form is '<feet>-<inches>'.
        default: value returned when the input cannot be parsed (missing
            values, wrong number of parts, non-integer parts).

    Returns:
        feet + inches / 12 as a float, or ``default`` for unparseable input.
    """
    try:
        ft, inches = map(int, str(height_str).split('-'))
    except ValueError:
        # Covers both non-integer parts and a wrong number of parts. Anything
        # else (e.g. KeyboardInterrupt) is deliberately allowed to propagate,
        # which the previous bare `except:` prevented.
        return default
    return ft + inches / 12

def get_opponent_features(input_df):
    """Per-player opponent-proximity and mirrored-receiver features.

    For every (game_id, play_id) the last available value of each column per
    nfl_id is used as that player's snapshot. Plays with fewer than two
    players are skipped. Each remaining player gets one output row holding:
      * distance to the nearest opponent and opponent counts within 3 and 5,
      * for 'Defensive Coverage' players: velocity of, distance to and
        positional offset from the closest 'Targeted Receiver' /
        'Other Route Runner'.
    Fields that cannot be computed keep the defaults set below
    ('closing_speed' is never updated and always stays 0.0).

    Returns a DataFrame with one row per (game_id, play_id, nfl_id).
    """
    rows = []
    plays = input_df.groupby(['game_id', 'play_id'])

    for (gid, pid), play in tqdm(plays, desc="üèà Opponents", leave=False):
        snapshot = play.sort_values('frame_id').groupby('nfl_id').last()
        if len(snapshot) < 2:
            continue

        xy = snapshot[['x', 'y']].values
        side_arr = snapshot['player_side'].values
        speed_arr = snapshot['s'].values
        dir_arr = snapshot['dir'].values
        role_arr = snapshot['player_role'].values

        # Row positions of every route runner in this play (loop invariant)
        runner_idx = np.where(
            np.isin(role_arr, ['Targeted Receiver', 'Other Route Runner'])
        )[0]

        for i, nid in enumerate(snapshot.index):
            feat = {
                'game_id': gid, 'play_id': pid, 'nfl_id': nid,
                'nearest_opp_dist': 50.0, 'closing_speed': 0.0,
                'num_nearby_opp_3': 0, 'num_nearby_opp_5': 0,
                'mirror_wr_vx': 0.0, 'mirror_wr_vy': 0.0,
                'mirror_offset_x': 0.0, 'mirror_offset_y': 0.0,
                'mirror_wr_dist': 50.0,
            }

            is_opponent = side_arr != side_arr[i]
            if is_opponent.any():
                opp_dist = np.sqrt(((xy[i] - xy[is_opponent]) ** 2).sum(axis=1))
                feat['nearest_opp_dist'] = opp_dist[opp_dist.argmin()]
                feat['num_nearby_opp_3'] = (opp_dist < 3.0).sum()
                feat['num_nearby_opp_5'] = (opp_dist < 5.0).sum()

                if role_arr[i] == 'Defensive Coverage' and len(runner_idx) > 0:
                    runner_xy = xy[runner_idx]
                    runner_dist = np.sqrt(((xy[i] - runner_xy) ** 2).sum(axis=1))
                    nearest = runner_dist.argmin()
                    target = runner_idx[nearest]

                    rec_vx, rec_vy = get_velocity(speed_arr[target], dir_arr[target])

                    feat['mirror_wr_vx'] = rec_vx
                    feat['mirror_wr_vy'] = rec_vy
                    feat['mirror_wr_dist'] = runner_dist[nearest]
                    feat['mirror_offset_x'] = xy[i][0] - runner_xy[nearest][0]
                    feat['mirror_offset_y'] = xy[i][1] - runner_xy[nearest][1]

            rows.append(feat)

    return pd.DataFrame(rows)

def extract_route_patterns(input_df, kmeans=None, scaler=None, fit=True):
    """Summarise each player's recent trajectory and assign a route cluster.

    For every (game_id, play_id, nfl_id) the last five frames are summarised
    (players with fewer than three frames are skipped): straightness, max and
    mean turn, |dx| ("depth"), |dy| ("width"), mean speed and speed change.
    The features are standardised and clustered with KMeans.

    Turn angles are wrapped into [0, pi] before aggregation. Headings come
    from arctan2 and therefore jump between +pi and -pi for motion along -x;
    without wrapping a straight run in that direction reported a turn of
    about 2*pi.

    Args:
        input_df: tracking frame with 'game_id', 'play_id', 'nfl_id',
            'frame_id', 'x', 'y' and 's'.
        kmeans: fitted KMeans, required when ``fit`` is False.
        scaler: fitted StandardScaler, required when ``fit`` is False.
        fit: when True, fit a new scaler and KMeans on the extracted features.

    Returns:
        ``(route_df, kmeans, scaler)`` when ``fit`` is True, otherwise
        ``route_df`` only. ``route_df`` has one row per qualifying player with
        the summary features and a 'route_pattern' cluster id.

    Raises:
        ValueError: if ``fit`` is False and no fitted models are supplied, or
            if ``fit`` is True and no trajectory has at least three frames.
    """
    key_cols = ['game_id', 'play_id', 'nfl_id']
    feat_cols = ['traj_straightness', 'traj_max_turn', 'traj_mean_turn',
                 'traj_depth', 'traj_width', 'speed_mean', 'speed_change']

    if not fit and (kmeans is None or scaler is None):
        raise ValueError("extract_route_patterns(fit=False) requires fitted kmeans and scaler")

    route_features = []

    for (gid, pid, nid), group in tqdm(input_df.groupby(['game_id', 'play_id', 'nfl_id']), 
                                        desc="üõ£Ô∏è  Routes", leave=False):
        traj = group.sort_values('frame_id').tail(5)

        if len(traj) < 3:
            continue

        positions = traj[['x', 'y']].values
        speeds = traj['s'].values

        step_dx = np.diff(positions[:, 0])
        step_dy = np.diff(positions[:, 1])
        total_dist = np.sum(np.sqrt(step_dx**2 + step_dy**2))

        dx = positions[-1, 0] - positions[0, 0]
        dy = positions[-1, 1] - positions[0, 1]
        displacement = np.sqrt(dx**2 + dy**2)
        straightness = displacement / (total_dist + 0.1)

        # At least 3 positions -> at least 2 headings -> at least 1 turn.
        headings = np.arctan2(step_dy, step_dx)
        # Wrap each heading change to [-pi, pi) so crossing the +/-pi branch
        # cut is not mistaken for a near-full rotation.
        turns = np.abs((np.diff(headings) + np.pi) % (2 * np.pi) - np.pi)
        max_turn = np.max(turns)
        mean_turn = np.mean(turns)

        route_features.append({
            'game_id': gid, 'play_id': pid, 'nfl_id': nid,
            'traj_straightness': straightness,
            'traj_max_turn': max_turn,
            'traj_mean_turn': mean_turn,
            'traj_depth': abs(dx),
            'traj_width': abs(dy),
            'speed_mean': speeds.mean(),
            'speed_change': speeds[-1] - speeds[0],
        })

    # Explicit columns keep the frame well-formed even when nothing qualified.
    route_df = pd.DataFrame(route_features, columns=key_cols + feat_cols)

    if route_df.empty:
        if fit:
            raise ValueError("no trajectory with at least 3 frames to fit route clusters on")
        # Nothing to predict: return an empty, merge-compatible frame.
        route_df = route_df.astype({col: input_df[col].dtype for col in key_cols})
        route_df['route_pattern'] = np.array([], dtype=np.int32)
        return route_df

    X = route_df[feat_cols].fillna(0)

    if fit:
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
        kmeans = KMeans(n_clusters=Config.N_ROUTE_CLUSTERS, random_state=Config.SEED, n_init=10)
        route_df['route_pattern'] = kmeans.fit_predict(X_scaled)
        return route_df, kmeans, scaler
    else:
        X_scaled = scaler.transform(X)
        route_df['route_pattern'] = kmeans.predict(X_scaled)
        return route_df

def compute_neighbor_embeddings(input_df, k_neigh=Config.K_NEIGH, 
                                radius=Config.RADIUS, tau=Config.TAU):
    """Distance-weighted neighbour aggregates for each player's last frame.

    For every (game_id, play_id, nfl_id) the player's last frame is paired
    with all other players recorded at that same frame of the play. Pairs are
    filtered by ``radius`` and limited to the ``k_neigh`` closest, then
    weighted by exp(-dist / tau) normalised over the retained neighbours.

    Args:
        input_df: frame with 'game_id', 'play_id', 'nfl_id', 'frame_id', 'x',
            'y', 'velocity_x', 'velocity_y' and 'player_side'.
        k_neigh: keep at most this many nearest neighbours (None = no limit).
        radius: drop neighbours farther than this (None = no limit).
        tau: length scale of the exponential distance weighting.

    Returns:
        DataFrame keyed by (game_id, play_id, nfl_id) with weighted relative
        position/velocity sums for allies and opponents ('gnn_*_mean'),
        neighbour counts, min/mean distances per side and the three smallest
        neighbour distances 'gnn_d1'..'gnn_d3'. Missing distances are filled
        with ``radius`` (30.0 when ``radius`` is None). Players with no
        retained neighbour do not appear in the result.
    """
    print("üï∏Ô∏è  graph features...")

    cols_needed = ["game_id", "play_id", "nfl_id", "frame_id", "x", "y", 
                   "velocity_x", "velocity_y", "player_side"]
    src = input_df[cols_needed].copy()

    # Last recorded frame of every player
    last = (src.sort_values(["game_id", "play_id", "nfl_id", "frame_id"])
               .groupby(["game_id", "play_id", "nfl_id"], as_index=False)
               .tail(1)
               .rename(columns={"frame_id": "last_frame_id"})
               .reset_index(drop=True))

    # Pair each player with everyone present in the same play at that frame
    tmp = last.merge(
        src.rename(columns={
            "frame_id": "nb_frame_id", "nfl_id": "nfl_id_nb",
            "x": "x_nb", "y": "y_nb", 
            "velocity_x": "vx_nb", "velocity_y": "vy_nb", 
            "player_side": "player_side_nb"
        }),
        left_on=["game_id", "play_id", "last_frame_id"],
        right_on=["game_id", "play_id", "nb_frame_id"],
        how="left"
    )

    # .copy() after each filter: the columns added below must land on an
    # independent frame, not on a slice of the previous one.
    tmp = tmp[tmp["nfl_id_nb"] != tmp["nfl_id"]].copy()
    tmp["dx"] = tmp["x_nb"] - tmp["x"]
    tmp["dy"] = tmp["y_nb"] - tmp["y"]
    tmp["dvx"] = tmp["vx_nb"] - tmp["velocity_x"]
    tmp["dvy"] = tmp["vy_nb"] - tmp["velocity_y"]
    tmp["dist"] = np.sqrt(tmp["dx"]**2 + tmp["dy"]**2)

    tmp = tmp[np.isfinite(tmp["dist"]) & (tmp["dist"] > 1e-6)]
    if radius is not None:
        tmp = tmp[tmp["dist"] <= radius]
    tmp = tmp.copy()

    tmp["is_ally"] = (tmp["player_side_nb"] == tmp["player_side"]).astype(np.float32)

    keys = ["game_id", "play_id", "nfl_id"]
    tmp["rnk"] = tmp.groupby(keys)["dist"].rank(method="first")
    if k_neigh is not None:
        tmp = tmp[tmp["rnk"] <= float(k_neigh)].copy()

    # Normalised exponential distance weights per focal player
    tmp["w"] = np.exp(-tmp["dist"] / float(tau))
    sum_w = tmp.groupby(keys)["w"].transform("sum")
    tmp["wn"] = np.where(sum_w > 0, tmp["w"] / sum_w, 0.0)

    tmp["wn_ally"] = tmp["wn"] * tmp["is_ally"]
    tmp["wn_opp"] = tmp["wn"] * (1.0 - tmp["is_ally"])

    for col in ["dx", "dy", "dvx", "dvy"]:
        tmp[f"{col}_ally_w"] = tmp[col] * tmp["wn_ally"]
        tmp[f"{col}_opp_w"] = tmp[col] * tmp["wn_opp"]

    tmp["dist_ally"] = np.where(tmp["is_ally"] > 0.5, tmp["dist"], np.nan)
    tmp["dist_opp"] = np.where(tmp["is_ally"] < 0.5, tmp["dist"], np.nan)

    # 'gnn_opp_dy_mean' is intentionally not produced (disabled upstream).
    ag = tmp.groupby(keys).agg(
        gnn_ally_dx_mean=("dx_ally_w", "sum"),
        gnn_ally_dy_mean=("dy_ally_w", "sum"),
        gnn_ally_dvx_mean=("dvx_ally_w", "sum"),
        gnn_ally_dvy_mean=("dvy_ally_w", "sum"),
        gnn_opp_dx_mean=("dx_opp_w", "sum"),
        # gnn_opp_dy_mean=("dy_opp_w", "sum"),
        gnn_opp_dvx_mean=("dvx_opp_w", "sum"),
        gnn_opp_dvy_mean=("dvy_opp_w", "sum"),
        gnn_ally_cnt=("is_ally", "sum"),
        gnn_opp_cnt=("is_ally", lambda s: float(len(s) - s.sum())),
        gnn_ally_dmin=("dist_ally", "min"),
        gnn_ally_dmean=("dist_ally", "mean"),
        gnn_opp_dmin=("dist_opp", "min"),
        gnn_opp_dmean=("dist_opp", "mean"),
    ).reset_index()

    # Distances to the 1st / 2nd / 3rd nearest neighbour as separate columns
    near = tmp.loc[tmp["rnk"] <= 3, keys + ["rnk", "dist"]].copy()
    if len(near) > 0:
        near["rnk"] = near["rnk"].astype(int)
        dwide = near.pivot_table(index=keys, columns="rnk", values="dist", aggfunc="first")
        dwide = dwide.rename(columns={1: "gnn_d1", 2: "gnn_d2", 3: "gnn_d3"}).reset_index()
        ag = ag.merge(dwide, on=keys, how="left")

    # A rank that no player reaches (few players, tight radius, or no
    # neighbours at all) leaves its column out of the pivot; create it so the
    # fill below and downstream feature selection always find it.
    for c in ["gnn_d1", "gnn_d2", "gnn_d3"]:
        if c not in ag.columns:
            ag[c] = np.nan

    for c in ["gnn_ally_dx_mean", "gnn_ally_dy_mean", "gnn_ally_dvx_mean", "gnn_ally_dvy_mean",
              "gnn_opp_dx_mean",  "gnn_opp_dvx_mean", "gnn_opp_dvy_mean"]:
        ag[c] = ag[c].fillna(0.0)
    for c in ["gnn_ally_cnt", "gnn_opp_cnt"]:
        ag[c] = ag[c].fillna(0.0)
    for c in ["gnn_ally_dmin", "gnn_opp_dmin", "gnn_ally_dmean", "gnn_opp_dmean", 
              "gnn_d1", "gnn_d2", "gnn_d3"]:
        ag[c] = ag[c].fillna(radius if radius is not None else 30.0)

    return ag

# ============================================================================
# SEQUENCE PREPARATION WITH GEOMETRIC FEATURES
# ============================================================================
    
def _stack_player_windows(window_df, player_ids, feature_cols, window_size):
    """Return one play's features as a (window_size, n_players, n_features) array.

    ``window_df`` holds the play's rows restricted to the input window, with
    each player's rows in ascending frame order. Players are laid out along
    axis 1 in the order of ``player_ids``. A player with fewer than
    ``window_size`` rows is left-padded by repeating their earliest row.
    """
    per_player = {
        nfl: rows[feature_cols].to_numpy(dtype=np.float64)
        for nfl, rows in window_df.groupby('nfl_id', sort=False)
    }

    stacked = []
    for nfl in player_ids:
        feats = per_player.get(nfl)
        if feats is None or len(feats) == 0:
            raise ValueError(f"player {nfl} has no rows inside the input window")
        feats = feats[-window_size:]
        missing = window_size - len(feats)
        if missing > 0:
            feats = np.concatenate([np.repeat(feats[:1], missing, axis=0), feats], axis=0)
        stacked.append(feats)

    # (N, T, F) -> (T, N, F)
    return np.ascontiguousarray(np.stack(stacked, axis=0).transpose(1, 0, 2))


def prepare_sequences_geometric(input_df, output_df=None, test_template=None, 
                                is_training=True, window_size=10,
                                route_kmeans=None, route_scaler=None):
    """Engineer per-frame features and build one multi-player sequence per play.

    Pipeline: base kinematics -> opponent / route / neighbour features ->
    lag, rolling and EMA features -> time features -> geometric endpoint
    features -> one array per play of shape
    (window_size, Config.MAX_PLAYERS, n_features).

    Layout of each sequence: axis 0 is time (oldest to newest frame of the
    input window), axis 1 is players with the players to predict first
    (the same order as the target rows), then the remaining players, then
    copies of the last real player as padding up to Config.MAX_PLAYERS.

    NOTE: earlier revisions reshaped player-major rows directly into
    (T, N, F), which interleaved the player and time axes and used a
    different player order for short plays. Models trained on that layout
    must be retrained.

    Args:
        input_df: tracking rows before the pass (one row per player and frame).
        output_df: future positions; required when ``is_training`` is True.
        test_template: frame listing the (game_id, play_id) to build when
            ``is_training`` is False.
        is_training: build targets and fit the route clustering when True.
        window_size: number of most recent input frames per sequence.
        route_kmeans, route_scaler: fitted route models, used when
            ``is_training`` is False.

    Returns:
        Training: (sequences, targets_dx, targets_dy, targets_frame_ids,
        sequence_ids, route_kmeans, route_scaler, feature_cols,
        stratify_labels, masks).
        Inference: (sequences, sequence_ids, masks).
        ``masks[i]`` is 1 for real players and 0 for padded slots.
    """
    
    print(f"\n{'='*80}")
    print(f"PREPARING GEOMETRIC SEQUENCES")
    print(f"{'='*80}")
    
    input_df = input_df.copy()
    input_df = input_df.sort_values(['game_id', 'play_id', 'nfl_id', 'frame_id'])
    # duplicates='drop' keeps qcut from raising when the frame counts have too
    # few distinct values (e.g. a single play at inference time).
    if 'num_frames_output' in input_df.columns:
        input_df["frames_bin"] = pd.qcut(
            input_df["num_frames_output"], q=3, labels=False, duplicates='drop'
        )
    else:
        input_df["frames_bin"] = 0
    input_df["stratify_label"] = (
        input_df["player_role"].astype(str) + "_" +
        input_df["player_side"].astype(str) + "_" +
        input_df["play_direction"].astype(str) + "_" +
        input_df["frames_bin"].astype(str)
    )
    print("Step 1: Base features...")
    
    dir_rad = np.deg2rad(input_df['dir'].fillna(0))
    input_df['velocity_x'] = input_df['s'] * np.sin(dir_rad)
    input_df['velocity_y'] = input_df['s'] * np.cos(dir_rad)
    # NOTE(review): acceleration uses cos for x and sin for y, the opposite of
    # the velocity decomposition above. Left unchanged so feature values stay
    # the same; confirm which convention is intended.
    input_df['acceleration_x'] = input_df['a'] * np.cos(dir_rad)
    input_df['acceleration_y'] = input_df['a'] * np.sin(dir_rad)
    input_df['speed_squared'] = input_df['s'] ** 2

    input_df['momentum_x'] = input_df['velocity_x'] * input_df['player_weight']
    input_df['momentum_y'] = input_df['velocity_y'] * input_df['player_weight']
    input_df['kinetic_energy'] = 0.5 * input_df['player_weight'] * (input_df['s'] ** 2)
    
    input_df['orientation_diff'] = np.abs(input_df['o'] - input_df['dir'])
    input_df['orientation_diff'] = np.minimum(input_df['orientation_diff'], 360 - input_df['orientation_diff'])
    
    input_df['is_offense'] = (input_df['player_side'] == 'Offense').astype(int)
    input_df['is_defense'] = (input_df['player_side'] == 'Defense').astype(int)
    input_df['is_receiver'] = (input_df['player_role'] == 'Targeted Receiver').astype(int)
    input_df['is_coverage'] = (input_df['player_role'] == 'Defensive Coverage').astype(int)
    input_df['is_passer'] = (input_df['player_role'] == 'Passer').astype(int)
    input_df['player_to_predict'] = input_df['player_to_predict'].astype(int)
    
    if 'ball_land_x' in input_df.columns:
        ball_dx = input_df['ball_land_x'] - input_df['x']
        ball_dy = input_df['ball_land_y'] - input_df['y']
        input_df['dist_to_ball'] = np.sqrt(ball_dx**2 + ball_dy**2)
        input_df['angle_to_ball'] = np.arctan2(ball_dy, ball_dx)
        input_df['ball_direction_x'] = ball_dx / (input_df['dist_to_ball'] + 1e-6)
        input_df['ball_direction_y'] = ball_dy / (input_df['dist_to_ball'] + 1e-6)
        input_df['closing_speed_ball'] = (
            input_df['velocity_x'] * input_df['ball_direction_x'] +
            input_df['velocity_y'] * input_df['ball_direction_y']
        )
        input_df['velocity_toward_ball'] = (
            input_df['velocity_x'] * np.cos(input_df['angle_to_ball']) + 
            input_df['velocity_y'] * np.sin(input_df['angle_to_ball'])
        )
        # NOTE(review): angle_to_ball comes from arctan2(dy, dx) while dir_rad
        # feeds sin->x / cos->y above; the two angles appear to use different
        # zero directions. Confirm before relying on these two features.
        input_df['velocity_alignment'] = np.cos(input_df['angle_to_ball'] - dir_rad)
        input_df['angle_diff'] = np.abs(input_df['o'] - np.degrees(input_df['angle_to_ball']))
        input_df['angle_diff'] = np.minimum(input_df['angle_diff'], 360 - input_df['angle_diff'])
        input_df['dist_squared'] = input_df['dist_to_ball'] ** 2
        input_df['is_out'] = ((input_df['ball_land_x'] < 0) | (input_df['ball_land_x'] > 120) |
                      (input_df['ball_land_y'] < 0) | (input_df['ball_land_y'] > 53.3)).astype(int)

    
    print("Step 2: Advanced features...")  
    opp_features = get_opponent_features(input_df)
    input_df = input_df.merge(opp_features, on=['game_id', 'play_id', 'nfl_id'], how='left')
    
    if is_training:
        route_features, route_kmeans, route_scaler = extract_route_patterns(input_df)
    else:
        route_features = extract_route_patterns(input_df, route_kmeans, route_scaler, fit=False)
    input_df = input_df.merge(route_features, on=['game_id', 'play_id', 'nfl_id'], how='left')
    
    gnn_features = compute_neighbor_embeddings(input_df)
    input_df = input_df.merge(gnn_features, on=['game_id', 'play_id', 'nfl_id'], how='left')
    
    if 'nearest_opp_dist' in input_df.columns:
        input_df['pressure'] = 1 / np.maximum(input_df['nearest_opp_dist'], 0.5)
        # input_df['under_pressure'] = (input_df['nearest_opp_dist'] < 3).astype(int)
        input_df['pressure_x_speed'] = input_df['pressure'] * input_df['s']
    
    if 'mirror_wr_vx' in input_df.columns:
        input_df['mirror_offset_dist'] = np.sqrt(
            input_df['mirror_offset_x']**2 + input_df['mirror_offset_y']**2
        )
    
    print("Step 3: Temporal features...")   
    gcols = ['game_id', 'play_id', 'nfl_id']
    
    for lag in [1, 2, 3, 4, 5]:
        for col in ['x', 'y', 'velocity_x', 'velocity_y', 's', 'a']:
            if col in input_df.columns:
                input_df[f'{col}_lag{lag}'] = input_df.groupby(gcols)[col].shift(lag)
    
    for window in [3, 5]:
        for col in ['x', 'y', 'velocity_x', 'velocity_y', 's']:
            if col in input_df.columns:
                input_df[f'{col}_rolling_mean_{window}'] = (
                    input_df.groupby(gcols)[col]
                      .rolling(window, min_periods=1).mean()
                      .reset_index(level=[0,1,2], drop=True)
                )
                input_df[f'{col}_rolling_std_{window}'] = (
                    input_df.groupby(gcols)[col]
                      .rolling(window, min_periods=1).std()
                      .reset_index(level=[0,1,2], drop=True)
                )
    
    for col in ['velocity_x', 'velocity_y']:
        if col in input_df.columns:
            input_df[f'{col}_delta'] = input_df.groupby(gcols)[col].diff()
    
    input_df['velocity_x_ema'] = input_df.groupby(gcols)['velocity_x'].transform(
        lambda x: x.ewm(alpha=0.3, adjust=False).mean()
    )
    input_df['velocity_y_ema'] = input_df.groupby(gcols)['velocity_y'].transform(
        lambda x: x.ewm(alpha=0.3, adjust=False).mean()
    )

    
    print("Step 4: Time features...")
    
    if 'num_frames_output' in input_df.columns:
        max_frames = input_df['num_frames_output']
        
        input_df['max_play_duration'] = max_frames / 10.0
        input_df['frame_time'] = input_df['frame_id'] / 10.0
        input_df['progress_ratio'] = input_df['frame_id'] / np.maximum(max_frames, 1)
        input_df['time_remaining'] = (max_frames - input_df['frame_id']) / 10.0
        
        input_df['expected_x_at_ball'] = input_df['x'] + input_df['velocity_x'] * input_df['frame_time']
        input_df['expected_y_at_ball'] = input_df['y'] + input_df['velocity_y'] * input_df['frame_time']
        
        if 'ball_land_x' in input_df.columns:
            input_df['error_from_ball_x'] = input_df['expected_x_at_ball'] - input_df['ball_land_x']
            input_df['error_from_ball_y'] = input_df['expected_y_at_ball'] - input_df['ball_land_y']
            input_df['error_from_ball'] = np.sqrt(
                input_df['error_from_ball_x']**2 + input_df['error_from_ball_y']**2
            )
            
            input_df['weighted_dist_by_time'] = input_df['dist_to_ball'] / (input_df['frame_time'] + 0.1)
            input_df['dist_scaled_by_progress'] = input_df['dist_to_ball'] * (1 - input_df['progress_ratio'])
        

        input_df['velocity_x_progress'] = input_df['velocity_x'] * input_df['progress_ratio']
        input_df['velocity_y_progress'] = input_df['velocity_y'] * input_df['progress_ratio']
        input_df['speed_scaled_by_time_left'] = input_df['s'] * input_df['time_remaining']
        

        input_df['length_ratio'] = max_frames / 30.0
    
    # Geometric endpoint features
    print("Step 5: üéØ Geometric endpoint features...")
    input_df = add_geometric_features(input_df)
    
    print("Step 6: Building feature list...")
    
    # Basic features. Names that are never created in this function
    # ('player_height_feet', 'bmi', 'mirror_alignment', 'actual_play_length')
    # are dropped by the column filter below unless input_df already has them.
    feature_cols = [
        'x', 'y', 's', 'a', 'o', 'dir', 'frame_id', 'ball_land_x', 'ball_land_y',
        'player_height_feet', 'bmi',
        'velocity_x', 'velocity_y', 'acceleration_x', 'acceleration_y',
        'momentum_x', 'momentum_y', 'kinetic_energy',
         'orientation_diff',
        'is_offense', 'is_defense', 'is_receiver', 'is_coverage', 'is_passer',
        'player_to_predict','is_out',
        'speed_squared','dist_squared',
        'dist_to_ball',  'angle_to_ball', 
        'ball_direction_x', 'ball_direction_y', 'closing_speed_ball',
        'velocity_toward_ball', 'velocity_alignment', 'angle_diff',
        'nearest_opp_dist',  'num_nearby_opp_3', 'num_nearby_opp_5',
        'mirror_wr_vx', 'mirror_wr_vy', 'mirror_offset_x', 'mirror_offset_y',
        'pressure',  'pressure_x_speed', 
        'mirror_offset_dist', 'mirror_alignment',
        'route_pattern', 'traj_straightness', 'traj_max_turn', 'traj_mean_turn',
        'traj_depth', 'traj_width', 'speed_mean', 'speed_change',
        'gnn_ally_dx_mean', 'gnn_ally_dy_mean', 'gnn_ally_dvx_mean', 'gnn_ally_dvy_mean',
        'gnn_opp_dx_mean',  'gnn_opp_dvx_mean', 'gnn_opp_dvy_mean',
        'gnn_ally_cnt', 'gnn_opp_cnt',
        'gnn_ally_dmin', 'gnn_ally_dmean', 'gnn_opp_dmin', 'gnn_opp_dmean',
        'gnn_d1', 'gnn_d2', 'gnn_d3',
    ]
    
    for lag in [1, 2, 3, 4, 5]:
        for col in ['x', 'y', 'velocity_x', 'velocity_y', 's', 'a']:
            feature_cols.append(f'{col}_lag{lag}')
    
    for window in [3, 5]:
        for col in ['x', 'y', 'velocity_x', 'velocity_y', 's']:
            feature_cols.append(f'{col}_rolling_mean_{window}')
            feature_cols.append(f'{col}_rolling_std_{window}')
    
    feature_cols.extend(['velocity_x_delta', 'velocity_y_delta'])
    feature_cols.extend(['velocity_x_ema', 'velocity_y_ema'])    
    feature_cols.extend([
        'max_play_duration', 'frame_time', 
        'expected_x_at_ball', 'expected_y_at_ball', 
        'error_from_ball_x', 'error_from_ball_y', 'error_from_ball',
         'weighted_dist_by_time', 
        'velocity_x_progress', 'velocity_y_progress', 'dist_scaled_by_progress',
        'speed_scaled_by_time_left', 'actual_play_length', 'length_ratio',
    ])
    
    # Geometric features (12 columns)
    feature_cols.extend([
        'geo_endpoint_x', 'geo_endpoint_y',
        'geo_vector_x', 'geo_vector_y', 
        'geo_required_vx', 'geo_required_vy',
        'geo_velocity_error_x', 'geo_velocity_error_y', 'geo_velocity_error',
        'geo_required_ax', 'geo_required_ay',
        'geo_alignment',
    ])
    
    feature_cols = [c for c in feature_cols if c in input_df.columns]
    print(f"‚úì Using {len(feature_cols)} features ÔºÅÔºÅ")
    
    print("Step 7: Creating sequences...")

    input_df.set_index(['game_id', 'play_id'], inplace=True)

    # group plays
    grouped = input_df.groupby(level=['game_id', 'play_id'])
    target_rows = output_df if is_training else test_template
    target_groups = target_rows[['game_id', 'play_id']].drop_duplicates()

    # Index the future trajectories once instead of scanning output_df for
    # every player of every play.
    future_by_player = {}
    if is_training:
        ordered_out = output_df.sort_values('frame_id')
        for out_key, out_rows in ordered_out.groupby(['game_id', 'play_id', 'nfl_id'], sort=False):
            future_by_player[out_key] = (
                out_rows['x'].values,
                out_rows['y'].values,
                out_rows['frame_id'].values,
            )

    sequences, targets_dx, targets_dy, targets_frame_ids, sequence_ids, stratify_labels = [], [], [], [], [], []
    masks = []
    for row in tqdm(target_groups.itertuples(index=False), total=len(target_groups),
                    desc="Creating sequences", dynamic_ncols=True):
        key = (row.game_id, row.play_id)
        group_df = grouped.get_group(key)

        # Players to predict come first so sequence slots line up with targets.
        nfls = list(group_df[group_df.player_to_predict == 1].nfl_id.unique())
        nfls_bg = list(group_df[group_df.player_to_predict != 1].nfl_id.unique())
        nfls_all = nfls + nfls_bg
        num_players = len(nfls_all)

        last_frame = group_df['frame_id'].max()
        start_id = max(1, last_frame - window_size + 1)
        input_window = group_df[group_df['frame_id'] >= start_id]

        # Rows are player-major (sorted by nfl_id, then frame_id), so the
        # window is assembled per player and only then laid out as (T, N, F).
        seq_3d = _stack_player_windows(input_window, nfls_all, feature_cols, window_size)
        has_nan = np.isnan(seq_3d).any()

        if num_players < Config.MAX_PLAYERS:
            pad_players = Config.MAX_PLAYERS - num_players
            last_player = seq_3d[:, -1:, :]                 # (T, 1, F)
            padding = np.tile(last_player, (1, pad_players, 1))  # (T, pad, F)
            seq_3d = np.concatenate([seq_3d, padding], axis=1)   # (T, MAX, F)
        mask = np.zeros(Config.MAX_PLAYERS, dtype=np.int32)
        mask[:num_players] = 1
        stratify_label = input_window["stratify_label"].values[0]

        if has_nan:
            # Same as before: windows containing NaN go through nan_to_num,
            # which zeroes NaN (and also replaces +/-inf with finite values).
            seq_3d = np.nan_to_num(seq_3d, nan=0.0)

        if is_training:
            if not nfls:
                raise ValueError(f"play {key} has no rows with player_to_predict == 1")
            dxs, dys = [], []
            frame_ids = None
            for nfl in nfls:
                future = future_by_player.get((key[0], key[1], nfl))
                if future is None:
                    raise ValueError(f"no output rows for player {nfl} in play {key}")
                out_x, out_y, frame_ids = future

                # Targets are displacements from the player's last input position.
                player_rows = group_df[group_df["nfl_id"] == nfl]
                last_x = player_rows['x'].iloc[-1]
                last_y = player_rows['y'].iloc[-1]

                dxs.append(out_x - last_x)
                dys.append(out_y - last_y)
            targets_dx.append(np.stack(dxs, axis=0))
            targets_dy.append(np.stack(dys, axis=0))
            # Frame ids of the last predicted player, as before.
            targets_frame_ids.append(frame_ids)

        sequences.append(seq_3d)
        masks.append(mask)
        stratify_labels.append(stratify_label)
        sequence_ids.append({
            'game_id': key[0],
            'play_id': key[1],
        })
    
    print(f"‚úì Created {len(sequences)} sequences")
    
    if is_training:
        return (sequences, targets_dx, targets_dy, targets_frame_ids, sequence_ids, 
                 route_kmeans, route_scaler,feature_cols,stratify_labels,masks)
    return sequences, sequence_ids,masks

In [None]:
# ============================================================================
# MODEL ARCHITECTURE 
# ============================================================================

class STSeqModel(nn.Module):
    """
    Joint spatio-temporal encoder: every (frame, player) pair is one token.

    input:
        x -> (B, T, N, input_dim)
        mask -> (B, N)   # non-zero / True -> real player, zero / False -> padded slot
    output:
        out -> (B, N, H, 2)
    structure:
        [frame-player token mixture] -> TransformerEncoder -> last-frame pooling -> predict x & y
    """
    def __init__(
        self,
        input_dim: int,
        horizon: int,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        ff_multiplier: int = 4,
        dropout: float = 0.1,
        max_len: int = 10,       # maximum input window size (frames)
        max_players: int = 17,    # maximum number of input players
    ):
        super().__init__()
        self.horizon = horizon
        self.d_model = d_model
        self.max_players = max_players

        # input projection: (input_dim -> d_model)
        self.in_proj = nn.Linear(input_dim, d_model)

        # learnable position embeddings, one per frame index and one per player slot
        self.time_emb = nn.Parameter(torch.zeros(1, max_len, 1, d_model))
        self.player_emb = nn.Parameter(torch.zeros(1, 1, max_players, d_model))
        nn.init.trunc_normal_(self.time_emb, std=0.02)
        nn.init.trunc_normal_(self.player_emb, std=0.02)

        # Transformer encoder over the flattened T*N token sequence
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff_multiplier * d_model,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        self.post_ln = nn.LayerNorm(d_model)

        # separate MLP heads for the x and y components, one output per horizon step
        def make_head():
            return nn.Sequential(
                nn.Linear(d_model, d_model),
                nn.GELU(),
                nn.Dropout(0.2),
                nn.Linear(d_model, horizon),
            )
        self.head_dx = make_head()
        self.head_dy = make_head()

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """
        x: (B, T, N, input_dim) with T <= max_len and N <= max_players
        mask: (B, N), non-zero / True -> real player (float or bool accepted)
        return: out -> (B, N, horizon, 2)
        """
        B, T, N, _ = x.shape

        # Step 1. projection & position (embeddings sliced to the actual T and N)
        h = self.in_proj(x)  # (B, T, N, d_model)
        h = h + self.time_emb[:, :T, :, :] + self.player_emb[:, :, :N, :]

        # Step 2. flatten as Transformer input: (B, T*N, d_model)
        h = h.reshape(B, T * N, self.d_model)

        # Step 3. build the key-padding mask (True -> ignore this token).
        # It must be boolean: a float mask would be *added* to the attention
        # logits instead of excluding the padded players.
        pad = ~mask.bool()  # (B, N)
        mask_flat = pad.unsqueeze(1).expand(-1, T, -1).reshape(B, T * N)

        # Step 4. encoder
        h = self.encoder(h, src_key_padding_mask=mask_flat)
        h = self.post_ln(h)

        # Step 5. back to (B, T, N, d_model)
        h = h.reshape(B, T, N, self.d_model)

        # Step 6. pooling: use the last frame's tokens as per-player context
        ctx = h[:, -1, :, :]   # (B, N, d_model)

        # Step 7. prediction: head outputs are accumulated along the horizon
        dx = self.head_dx(ctx)              # (B, N, H)
        dy = self.head_dy(ctx)              # (B, N, H)
        dx = torch.cumsum(dx, dim=2)
        dy = torch.cumsum(dy, dim=2)

        # Step 8. to (B, N, H, 2)
        out = torch.stack([dx, dy], dim=-1)
        return out

class STTransformer(nn.Module):
    """
    Factorised spatio-temporal transformer with a query-based decoder.

    input:
        x -> (B, T, N, input_dim)
        mask -> (B, N)   # non-zero / True -> real player, zero / False -> padded slot
    output:
        out -> (B, M, H, 2)   # M = out_players, taken from the first M player slots
    structure:
        per-frame attention across players -> per-player attention across frames
        -> TransformerDecoder with learned (player, step) queries -> predict x & y
    """
    def __init__(
        self,
        input_dim: int,
        horizon: int,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        num_decoder_layers: int = 2,
        ff_multiplier: int = 4,
        dropout: float = 0.5,
        max_len: int = 10,        # maximum input window size (frames)
        max_players: int = 17,    # maximum number of input players
        out_players=None,         # players to predict; None -> Config.MAX_PLAYERS2Predict
    ):
        super().__init__()
        self.horizon = horizon
        self.d_model = d_model
        # Fall back to the notebook-wide setting only when the caller does not
        # choose explicitly, so the module can be built without the global config.
        if out_players is None:
            out_players = Config.MAX_PLAYERS2Predict
        self.out_players = out_players

        # projection: (input_dim -> d_model)
        self.in_proj = nn.Linear(input_dim, d_model)

        # learnable position embeddings
        self.time_emb = nn.Parameter(torch.zeros(1, max_len, 1, d_model))
        self.player_emb = nn.Parameter(torch.zeros(1, 1, max_players, d_model))
        nn.init.trunc_normal_(self.time_emb, std=0.02)
        nn.init.trunc_normal_(self.player_emb, std=0.02)

        # Transformer encoders: one across players, one across frames
        enc_layer1 = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=ff_multiplier * d_model,
            dropout=dropout, batch_first=True, activation="gelu", norm_first=True
        )
        self.spatial_encoder = nn.TransformerEncoder(enc_layer1, num_layers=num_layers)

        enc_layer2 = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=ff_multiplier * d_model,
            dropout=dropout, batch_first=True, activation="gelu", norm_first=True
        )
        self.temporal_encoder = nn.TransformerEncoder(enc_layer2, num_layers=num_layers)

        # LayerNorms applied after each encoder
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # ============  Decoder ============
        # one learned query per (predicted player, horizon step)
        self.query_embed = nn.Parameter(torch.zeros(self.out_players, horizon, d_model))
        nn.init.trunc_normal_(self.query_embed, std=0.02)

        # step embedding shared by all predicted players
        self.time_emb_dec = nn.Parameter(torch.zeros(1, 1, horizon, d_model))
        nn.init.trunc_normal_(self.time_emb_dec, std=0.02)

        dec_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=ff_multiplier * d_model,
            dropout=dropout, batch_first=True, activation="gelu", norm_first=True
        )
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_decoder_layers)

        self.out_proj = nn.Linear(d_model, 2)
        # ===================================================

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """
        x: (B, T, N, input_dim) with T <= max_len and out_players <= N <= max_players
        mask: (B, N), non-zero / True -> real player
        return: out -> (B, M, horizon, 2)

        Raises ValueError when the input carries fewer than ``out_players`` player slots.
        """
        B, T, N, _ = x.shape
        M = self.out_players
        if N < M:
            raise ValueError(
                f"input has {N} player slots but the model predicts {M} players"
            )

        # 1. projection + positional embeddings
        h = self.in_proj(x)  # (B, T, N, d_model)
        h = h + self.player_emb[:, :, :N, :] + self.time_emb[:, :T, :, :]

        # 2. spatial interaction between players, independently for every frame
        h_spatial = h.reshape(B * T, N, self.d_model)
        # boolean key-padding mask: True -> padded player slot, ignored as a key
        pad = ~mask.bool()  # (B, N)
        spatial_mask_pad = pad.unsqueeze(1).repeat(1, T, 1).reshape(B * T, N)

        h_spatial = self.spatial_encoder(h_spatial, src_key_padding_mask=spatial_mask_pad)
        h_spatial = self.ln1(h_spatial)
        h = h_spatial.reshape(B, T, N, self.d_model)

        # 3. keep only the first M player slots (the ones to predict)
        ctx = h[:, :, :M, :]  # (B, T, M, D)

        # 4. temporal attention, independently for every kept player
        ctx_temporal = ctx.permute(0, 2, 1, 3).reshape(B * M, T, self.d_model)
        ctx_temporal = self.temporal_encoder(ctx_temporal)
        ctx_temporal = self.ln2(ctx_temporal)

        # 5. decoder memory
        memory = ctx_temporal  # (B*M, T, D)

        # 6. queries: (M, H, D) -> (B, M, H, D) -> (B*M, H, D)
        queries = self.query_embed.unsqueeze(0).expand(B, -1, -1, -1)
        queries = queries + self.time_emb_dec
        queries = queries.reshape(B * M, self.horizon, self.d_model)

        # 7. causal mask: a horizon step may not attend to later steps
        causal_mask = torch.triu(
            torch.ones(self.horizon, self.horizon, dtype=torch.bool, device=x.device),
            diagonal=1
        )

        # 8. Transformer decoder
        decoded = self.decoder(
            tgt=queries,  # (B*M, H, D)
            memory=memory,  # (B*M, T, D)
            tgt_mask=causal_mask,  # (H, H)
        )

        # 9. output projection: (B*M, H, D) -> (B, M, H, D) -> (B, M, H, 2)
        decoded = decoded.reshape(B, M, self.horizon, self.d_model)
        out = self.out_proj(decoded)

        # 10. accumulate the per-step outputs along the horizon
        dx = torch.cumsum(out[..., 0], dim=2)
        dy = torch.cumsum(out[..., 1], dim=2)
        out = torch.stack([dx, dy], dim=-1)
        return out
        
# ============================================================================
# LOSS (YOUR PROVEN TEMPORAL HUBER)
# ============================================================================
class TemporalHuber(nn.Module):
    """Masked Huber loss with exponential down-weighting of later horizon steps.

    The loss is computed separately for every player slot (masked Huber sum
    divided by that slot's number of valid entries) and the per-slot values are
    summed. Slots without any valid entry contribute nothing.
    """

    def __init__(self, delta=0.5, time_decay=0.03):
        super().__init__()
        self.delta = delta            # switch point between quadratic and linear regime
        self.time_decay = time_decay  # step t is weighted by exp(-time_decay * t); <= 0 disables

    def forward(self, preds, target, masks):
        '''
        preds, target --> (BS, num_players, num_frames, 2)
        masks         --> (BS, num_players, num_frames, 2), non-zero / True marks valid entries

        Only the first ``target.size(1)`` player slots of ``preds`` are scored.
        Always returns a 0-dim tensor, even when every entry is masked out.
        '''
        n = target.size(1)
        preds = preds[:, :n]
        weights = masks[:, :n].to(preds.dtype)

        err = preds - target
        abs_err = torch.abs(err)
        huber = torch.where(abs_err <= self.delta, 0.5 * err * err,
                            self.delta * (abs_err - 0.5 * self.delta))
        if self.time_decay > 0:
            L = huber.size(2)
            t = torch.arange(L, device=huber.device, dtype=huber.dtype)
            huber = huber * torch.exp(-self.time_decay * t).view(1, 1, L, 1)

        # Reduce everything except the player axis in one pass; this avoids a
        # Python loop with a host/device sync per player slot.
        per_player_sum = (huber * weights).sum(dim=(0, 2, 3))
        per_player_cnt = weights.sum(dim=(0, 2, 3))
        has_data = per_player_cnt >= 1
        per_player_loss = torch.where(
            has_data,
            per_player_sum / per_player_cnt.clamp(min=1),  # clamp only guards the skipped slots
            torch.zeros_like(per_player_sum),
        )
        return per_player_loss.sum()

In [None]:
# ============================================================================
# TRAINING
# ============================================================================
def random_time_mask(bx, p=0.1, max_width=3):
    """Overwrite a short run of frames with a neighbouring frame (in place).

    For each sample, with probability ``p``, a run of ``w`` consecutive frames
    (``1 <= w <= max_width``) is replaced by the frame just before the run, or
    by the frame just after it when the run starts at frame 0. The last frame
    is never overwritten.

    bx: tensor whose first two axes are (batch, time); any trailing axes are
        allowed, e.g. (B, T, F) or (B, T, N, F).
    Returns ``bx`` itself; the tensor is modified in place.
    """
    if p <= 0 or max_width <= 0:
        return bx
    B, T = bx.shape[:2]
    if T <= 1:
        return bx
    # A run can cover at most T-1 frames: one frame must remain as the source
    # and the last frame must stay untouched.
    width_cap = min(max_width, T - 1)
    for i in range(B):
        if random.random() < p:
            w = random.randint(1, width_cap)
            s = random.randint(0, T - 1 - w)
            if s > 0:
                bx[i, s:s+w] = bx[i, s-1].unsqueeze(0)
            else:
                bx[i, s:s+w] = bx[i, s+w].unsqueeze(0)
    return bx

def flip_context_keep_last(bx, p=0.1):
    """Reverse the order of all frames except the last one (in place).

    Each sample is selected independently with probability ``p``; for selected
    samples frames ``0 .. T-2`` are time-reversed and frame ``T-1`` is kept.

    bx: tensor whose first two axes are (batch, time); any trailing axes are
        allowed, e.g. (B, T, F) or (B, T, N, F).
    Returns ``bx`` itself; the tensor is modified in place.
    """
    if p <= 0:
        return bx
    B, T = bx.shape[:2]
    if T <= 1:
        return bx
    mask = torch.rand(B, device=bx.device) < p
    if mask.any():
        # Boolean indexing already yields copies, so the pieces can be
        # reassembled and written back without aliasing bx.
        ctx = bx[mask, :-1].flip(1)
        bx[mask] = torch.cat([ctx, bx[mask, -1:]], dim=1)
    return bx

def add_random_gaussian(bx, sigma_max=0.02):
    """Return ``bx`` plus zero-mean Gaussian noise of a randomly drawn scale.

    One noise scale is drawn uniformly from ``[0, sigma_max)`` per call and
    shared by the whole tensor. With ``sigma_max <= 0`` the input is returned
    untouched; otherwise a new tensor is returned and ``bx`` is not modified.
    """
    if not sigma_max > 0:
        return bx
    # draw the scale first, then the noise, so the RNG stream order is fixed
    scale = torch.rand(1, device=bx.device)
    sigma = sigma_max * scale
    noise = torch.randn_like(bx)
    return bx + noise * sigma
    
def prepare_targets(batch_dx, batch_dy, max_h):
    """Pad per-play target arrays into dense tensors plus a validity mask.

    batch_dx, batch_dy: sequences of 2-D arrays shaped (num_players, num_frames),
        one pair per play.
    max_h: number of horizon steps in the padded output.

    Returns (targets, mask), both shaped
    (B, Config.MAX_PLAYERS2Predict, max_h, 2); the last axis is (dx, dy) and
    ``mask`` is True where a real target value was written.

    Plays with more players than Config.MAX_PLAYERS2Predict or more frames than
    ``max_h`` are truncated to fit instead of failing on a shape mismatch.
    """
    B = len(batch_dx)
    max_p = Config.MAX_PLAYERS2Predict
    padded_dx = torch.zeros(B, max_p, max_h)
    padded_dy = torch.zeros(B, max_p, max_h)
    mask = torch.zeros(B, max_p, max_h, dtype=torch.bool)

    for i, (dx, dy) in enumerate(zip(batch_dx, batch_dy)):
        # (num_p, num_f), clipped to the padded extent
        dx = torch.as_tensor(dx, dtype=padded_dx.dtype)[:max_p, :max_h]
        dy = torch.as_tensor(dy, dtype=padded_dy.dtype)[:max_p, :max_h]

        num_p, num_f = dx.shape
        padded_dx[i, :num_p, :num_f] = dx
        padded_dy[i, :num_p, :num_f] = dy

        # build the mask: real (non-padded) positions are True
        mask[i, :num_p, :num_f] = True
    targets = torch.stack([padded_dx, padded_dy], dim=-1)
    mask = torch.stack([mask, mask], dim=-1)
    return targets, mask

class WarmupCosineScheduler(LRScheduler):
    """Linear warmup followed by cosine decay down to ``min_lr``.

    While the step counter is below ``warmup_steps`` each base learning rate is
    scaled by ``step / warmup_steps``. Afterwards the rate follows half a cosine
    from the base rate to ``min_lr``, reaching ``min_lr`` at ``total_steps``
    and staying there.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        total_steps: int,
        min_lr: float = 0.0,
        last_epoch: int = -1,
    ):
        # These must exist before the base constructor runs, because it
        # performs an initial step that calls get_lr().
        self.min_lr = min_lr
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        step = self.last_epoch

        if step >= self.warmup_steps:
            # cosine annealing phase; progress is clamped once total_steps is passed
            decay_span = max(1, self.total_steps - self.warmup_steps)
            progress = min((step - self.warmup_steps) / decay_span, 1.0)
            cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
            return [
                self.min_lr + (base_lr - self.min_lr) * cosine
                for base_lr in self.base_lrs
            ]

        # linear warmup phase
        factor = step / max(1, self.warmup_steps)
        return [base_lr * factor for base_lr in self.base_lrs]
        
def _build_batches(X, y_dx, y_dy, player_masks, batch_size, horizon):
    """Pre-collate one data split into CPU batches of (x, player_mask, targets, target_mask)."""
    batches = []
    for start in range(0, len(X), batch_size):
        end = min(start + batch_size, len(X))
        bx = torch.tensor(np.stack(X[start:end]).astype(np.float32))
        bmask = torch.tensor(np.stack(player_masks[start:end]).astype(np.float32))
        by, bm = prepare_targets(
            [y_dx[j] for j in range(start, end)],
            [y_dy[j] for j in range(start, end)],
            horizon
        )
        batches.append((bx, bmask, by, bm))
    return batches


def train_model(X_train, y_train_dx, y_train_dy,mask_tr, X_val, y_val_dx, y_val_dy, mask_va,
                input_dim, horizon, config):
    """Train one model on a train/validation split with early stopping.

    X_train / X_val: per-play input arrays (stacked into batches here).
    y_*_dx / y_*_dy: per-play target arrays, padded by ``prepare_targets``.
    mask_tr / mask_va: per-play player-validity masks passed to the model.
    input_dim, horizon: forwarded to the model constructor.
    config: object providing Model_Name, DMODEL, scheduler, AUG, LEARNING_RATE,
        BATCH_SIZE, EPOCHS and PATIENCE.

    Returns (model, val_batches): the model restored to the weights with the
    lowest validation loss, and the collated validation batches.

    Raises NotImplementedError for an unknown model or scheduler name.
    """
    if config.Model_Name == "STTransformer":
        model = STTransformer(input_dim, horizon, config.DMODEL).to(DEVICE)
    elif config.Model_Name == "STSeqModel":
        model = STSeqModel(input_dim, horizon, config.DMODEL).to(DEVICE)
    else:
        raise NotImplementedError

    criterion = TemporalHuber(delta=0.5, time_decay=0.03)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=1e-5)

    # ReduceLROnPlateau is driven by the validation loss once per epoch; the
    # warmup-cosine schedule is expressed in optimizer steps and must be
    # advanced once per batch without a metric argument.
    step_on_plateau = config.scheduler == "ReduceLROnPlateau"
    if step_on_plateau:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5,)
    elif config.scheduler == "WarmupCosineScheduler":
        WARMUP_STEPS = 1000
        TOTAL_STEPS = 50000
        scheduler = WarmupCosineScheduler(
            optimizer,
            warmup_steps=WARMUP_STEPS,
            total_steps=TOTAL_STEPS,
            min_lr=1e-6
        )
    else:
        raise NotImplementedError

    train_batches = _build_batches(
        X_train, y_train_dx, y_train_dy, mask_tr, config.BATCH_SIZE, horizon
    )
    val_batches = _build_batches(
        X_val, y_val_dx, y_val_dy, mask_va, config.BATCH_SIZE, horizon
    )

    best_loss, best_state, bad = float('inf'), None, 0
    for epoch in range(1, config.EPOCHS + 1):
        model.train()
        train_losses = []
        for bx, bmask, by, bm in train_batches:
            bx, bmask, by, bm = bx.to(DEVICE), bmask.to(DEVICE), by.to(DEVICE), bm.to(DEVICE)
            if config.AUG and random.random() < 0.2:
                # add_random_gaussian returns a fresh tensor, so the in-place
                # augmentations that follow do not touch the cached batch.
                bx = add_random_gaussian(bx, sigma_max=0.01)
                bx = random_time_mask(bx, p=0.10, max_width=3)
                bx = flip_context_keep_last(bx, p=0.10)
            pred = model(bx, bmask)
            loss = criterion(pred, by, bm)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            if not step_on_plateau:
                scheduler.step()
            train_losses.append(loss.item())

        model.eval()
        val_losses = []
        with torch.no_grad():
            for bx, bmask, by, bm in val_batches:
                bx, bmask, by, bm = bx.to(DEVICE), bmask.to(DEVICE), by.to(DEVICE), bm.to(DEVICE)
                pred = model(bx, bmask)
                val_losses.append(criterion(pred, by, bm).item())

        train_loss, val_loss = np.mean(train_losses), np.mean(val_losses)
        if step_on_plateau:
            scheduler.step(val_loss)

        if epoch % 10 == 0 :
            print(f"  Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}")

        if val_loss < best_loss:
            best_loss = val_loss
            # snapshot on CPU so the best weights survive further training
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= config.PATIENCE:
                print(f"  Early stop at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model, val_batches

# ============================================================================
# EVALUATION
# ============================================================================
def evaluate(model,val_batches):
    """Accumulate the masked squared error of ``model`` over ``val_batches``.

    val_batches: iterable of (x, player_mask, targets, target_mask) tuples.

    Returns (err, num): the sum of squared masked residuals and the sum of the
    target-mask entries. Both are the int ``0`` when ``val_batches`` is empty.
    """
    model.eval()
    err = 0
    num = 0
    with torch.inference_mode():
        for batch in val_batches:
            feats, player_mask, target, target_mask = batch
            # counted before the device transfer, exactly like the stored batch mask
            num += target_mask.sum()

            feats = feats.to(DEVICE)
            player_mask = player_mask.to(DEVICE)
            target = target.to(DEVICE)
            target_mask = target_mask.to(DEVICE)

            # score only as many player slots as the targets provide
            pred = model(feats, player_mask)[:, :target.size(1), :, :]
            residual = target_mask * (pred - target)
            err += torch.sum(residual ** 2)

        #num*=2
    return err,num

In [None]:
def main():
    """Run the full pipeline: load data, build sequences, train with grouped CV.

    For every fold a StandardScaler is fitted on the training features, a model
    is trained, and the model weights and the scaler are written to
    ``Config.OUTPUT_DIR``. Per-fold and pooled RMSE are printed at the end.
    """
    config = Config()
    set_seed(Config.SEED)
    # Load
    print("[1/4] Loading data...")

    # DEBUG restricts loading to the first weekly file to keep runs short
    weeks = range(1, 2) if Config.DEBUG else range(1, 19)
    train_input_files = [config.DATA_DIR / f"train/input_2023_w{w:02d}.csv" for w in weeks]
    train_output_files = [config.DATA_DIR / f"train/output_2023_w{w:02d}.csv" for w in weeks]
    train_input = pd.concat([pd.read_csv(f) for f in train_input_files if f.exists()])
    train_output = pd.concat([pd.read_csv(f) for f in train_output_files if f.exists()])

    # Exclude exactly one play. The previous `(game_id != g) & (play_id != p)`
    # filter removed the entire game *and* every play numbered p in any game.
    # NOTE(review): presumably an outlier play; the reason is not recorded here.
    excl_game, excl_play = 2023091100, 3167
    train_input = train_input[
        ~((train_input.game_id == excl_game) & (train_input.play_id == excl_play))
    ]
    train_output = train_output[
        ~((train_output.game_id == excl_game) & (train_output.play_id == excl_play))
    ]
    print(f"‚úì Train input: {train_input.shape}, Train output: {train_output.shape}")

    # Prepare
    print("[2/4] Preparing geometric sequences...")

    result = prepare_sequences_geometric(
        train_input, train_output, is_training=True, window_size=config.WINDOW_SIZE
    )
    sequences, targets_dx, targets_dy, targets_frame_ids, sequence_ids, route_kmeans, route_scaler,feature_cols,stratify_labels,masks  = result

    sequences = list(sequences)
    targets_dx = list(targets_dx)
    targets_dy = list(targets_dy)

    # Train
    print("[3/4] Training geometric models...")
    # group by game so that plays of one game never straddle train and validation
    groups = np.array([d['game_id'] for d in sequence_ids])
    if Config.Split=="gkf":
        gkf = GroupKFold(n_splits=config.N_FOLDS)
        splits=gkf.split(sequences, groups=groups)
    else:
        sgkf = StratifiedGroupKFold(n_splits=config.N_FOLDS, shuffle=True, random_state=12)
        splits=sgkf.split(sequences, y=stratify_labels, groups=groups)

    cv_result=[]
    cv_err,cv_num=0,0
    for fold, (tr, va) in enumerate(splits, 1):
        print(f"{'='*60}")
        print(f"Fold {fold}/{config.N_FOLDS}")
        X_tr = np.array([sequences[i] for i in tr])
        X_va = np.array([sequences[i] for i in va])
        y_tr_dx = [targets_dx[i] for i in tr]
        y_va_dx = [targets_dx[i] for i in va]
        y_tr_dy = [targets_dy[i] for i in tr]
        y_va_dy = [targets_dy[i] for i in va]
        mask_tr =[masks[i] for i in tr]
        mask_va =[masks[i] for i in va]

        # Standardise features: flatten every leading axis so each row is one
        # feature vector, fit on the training fold only, then restore shapes.
        scaler = StandardScaler()
        n_samples_tr = X_tr.shape[0] * X_tr.shape[1] * X_tr.shape[2]
        n_samples_va = X_va.shape[0] * X_va.shape[1] * X_va.shape[2]
        X_tr_reshaped = X_tr.reshape(n_samples_tr, -1)
        X_va_reshaped = X_va.reshape(n_samples_va, -1)
        scaler.fit(X_tr_reshaped)
        X_tr_sc = scaler.transform(X_tr_reshaped).reshape(X_tr.shape)
        X_va_sc = scaler.transform(X_va_reshaped).reshape(X_va.shape)

        input_dim=X_tr[0].shape[-1]
        print("input_dim:",input_dim)
        model, val_batches = train_model(
            X_tr_sc, y_tr_dx, y_tr_dy,mask_tr,
            X_va_sc, y_va_dx, y_va_dy,mask_va,
            input_dim, config.MAX_FUTURE_HORIZON, config
        )
        err,num=evaluate(model,val_batches)
        rsme=torch.sqrt(err/num).item()
        cv_result.append(rsme)
        print(f"fold_{fold}:{rsme}")
        # pooled error/count give an RMSE over all folds' validation entries
        cv_err+=err
        cv_num+=num
        torch.save(model.state_dict(), f'{Config.OUTPUT_DIR}/model_{fold}.pth')
        torch.save(scaler, f'{Config.OUTPUT_DIR}/scaler_{fold}.pth')
    cv_rsme=torch.sqrt(cv_err/cv_num).item()
    cv_results={"cv_rsme":cv_rsme,"avg":np.mean(cv_result),"std":np.std(cv_result),
        "folds":cv_result}
    print("CV Results:"+"-+"*30)
    print(cv_results)

In [None]:
# Notebook entry point: run the training / cross-validation pipeline defined in main().
main()