In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from pathlib import Path
from typing import List, Tuple, Dict, Optional
from tqdm.auto import tqdm
import json
import time

# ================================================================================
# CONFIGURATION SEARCH SYSTEM
# ================================================================================

class ConfigSearch:
    """Catalogue of candidate configurations plus a reader for search results.

    ``CONFIGS_TO_TEST`` holds one dict per candidate. Apart from ``name``,
    every key matches an attribute name on ``Config``. The code that trains
    each candidate and writes ``config_search_results.json`` is not part of
    this class.
    """

    CONFIGS_TO_TEST = [
        # Baseline
        {
            'name': 'baseline',
            'HIDDEN_DIM': 128,
            'NUM_GNN_LAYERS': 2,
            'NUM_ATTENTION_HEADS': 4,
            'DROPOUT': 0.2,
            'LEARNING_RATE': 1e-4,
            'FEATURE_FUSION_DIM': 96
        },
        # Larger model
        {
            'name': 'large_model',
            'HIDDEN_DIM': 192,
            'NUM_GNN_LAYERS': 3,
            'NUM_ATTENTION_HEADS': 8,
            'DROPOUT': 0.2,
            'LEARNING_RATE': 1e-4,
            'FEATURE_FUSION_DIM': 128
        },
        # Higher dropout
        {
            'name': 'high_dropout',
            'HIDDEN_DIM': 128,
            'NUM_GNN_LAYERS': 2,
            'NUM_ATTENTION_HEADS': 4,
            'DROPOUT': 0.3,
            'LEARNING_RATE': 1e-4,
            'FEATURE_FUSION_DIM': 96
        },
        # Deeper GNN
        {
            'name': 'deep_gnn',
            'HIDDEN_DIM': 128,
            'NUM_GNN_LAYERS': 4,
            'NUM_ATTENTION_HEADS': 4,
            'DROPOUT': 0.2,
            'LEARNING_RATE': 1e-4,
            'FEATURE_FUSION_DIM': 96
        },
    ]

    # Not read inside this class; presumably the per-candidate budget used by
    # the search loop (TODO confirm against the training code).
    SEARCH_EPOCHS = 5
    SEARCH_DATA_FRACTION = 0.3

    @staticmethod
    def get_results_summary():
        """Print a ranked table of all tested configs and return the best one.

        Reads ``config_search_results.json`` from the current working
        directory. The file is expected to hold a list of dicts with the keys
        ``config_name``, ``best_val_loss``, ``final_train_loss``,
        ``num_parameters`` and ``training_time``.

        Returns:
            The result dict with the lowest ``best_val_loss``, or ``None``
            when the file does not exist or contains no results.
        """
        # Keep the try body minimal: only opening the file can raise
        # FileNotFoundError.
        try:
            with open('config_search_results.json', 'r') as f:
                results = json.load(f)
        except FileNotFoundError:
            print("No search results found yet.")
            return None

        # An empty results list used to crash on sorted_results[0]; treat it
        # the same as a missing file.
        if not results:
            print("No search results found yet.")
            return None

        print("\n" + "="*80)
        print("CONFIG SEARCH RESULTS SUMMARY")
        print("="*80)

        # Lower validation loss ranks first.
        sorted_results = sorted(results, key=lambda x: x['best_val_loss'])

        print(f"\n{'Rank':<6} {'Config':<15} {'Val Loss':<12} {'Train Loss':<12} {'Params':<12} {'Time(s)':<10}")
        print("-"*80)

        for i, result in enumerate(sorted_results, 1):
            print(f"{i:<6} {result['config_name']:<15} {result['best_val_loss']:<12.4f} "
                  f"{result['final_train_loss']:<12.4f} {result['num_parameters']:<12,} "
                  f"{result['training_time']:<10.1f}")

        best = sorted_results[0]
        print("\n" + "="*80)
        print(f"BEST CONFIG: {best['config_name']}")
        print(f"Best Validation Loss: {best['best_val_loss']:.4f}")
        print("="*80 + "\n")

        return best

# ================================================================================
# DYNAMIC CONFIGURATION
# ================================================================================

class Config:
    """Global settings held as class attributes (never instantiated here)."""

    # --- model sizes ---
    PLAYER_EMBED_DIM = 32
    SPATIAL_EMBED_DIM = 16
    HIDDEN_DIM = 128
    NUM_ATTENTION_HEADS = 4
    NUM_TRANSFORMER_LAYERS = 2
    NUM_GNN_LAYERS = 2
    DROPOUT = 0.2

    # --- engineered-feature fusion ---
    ENGINEERED_FEATURES_DIM = 108
    FEATURE_FUSION_DIM = 96

    # --- prediction targets ---
    MAX_FUTURE_FRAMES = 94
    PREDICTION_HORIZON = 25
    MULTI_HORIZON_TARGETS = [5, 10, 15, 20, 25]

    # --- optimisation ---
    BATCH_SIZE = 16
    LEARNING_RATE = 1e-4
    EPOCHS = 25
    PATIENCE = 8
    GRADIENT_ACCUMULATION_STEPS = 4
    USE_AMP = True
    USE_COMPILE = False

    # --- input windowing ---
    WINDOW_SIZE = 10
    ADAPTIVE_FRAME_SAMPLING = True

    # --- data location and field geometry ---
    DATA_DIR = Path("/kaggle/input/nfl-big-data-bowl-2026-prediction/")
    FIELD_X_MIN = 0.0
    FIELD_X_MAX = 120.0
    FIELD_Y_MIN = 0.0
    FIELD_Y_MAX = 53.3
    FIELD_HASH_LEFT = 17.5
    FIELD_HASH_RIGHT = 35.8

    SEARCH_MODE = False

    @classmethod
    def update_from_dict(cls, config_dict):
        """Overwrite existing class attributes with values from ``config_dict``.

        Keys that are not already attributes of the class (for example the
        ``name`` entry of a search config) are skipped.
        """
        recognised = {
            attr_name: new_value
            for attr_name, new_value in config_dict.items()
            if hasattr(cls, attr_name)
        }
        for attr_name, new_value in recognised.items():
            setattr(cls, attr_name, new_value)

# ================================================================================
# EFFICIENT COMPONENTS
# ================================================================================

class EfficientPositionalEncoding(nn.Module):
    """Learned linear map from 2-component positions to ``d_model`` features."""

    def __init__(self, d_model: int):
        super().__init__()
        # One affine layer: 2 input features -> d_model output features.
        self.proj = nn.Linear(in_features=2, out_features=d_model)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        """Project ``positions`` (last dimension of size 2) to ``d_model``."""
        embedded = self.proj(positions)
        return embedded

class EfficientPlayerEncoder(nn.Module):
    """Embed categorical and continuous player attributes into one vector.

    The role and side ids each get an embedding of width ``embed_dim // 4``.
    The continuous features are projected to the remaining width so that the
    concatenation is always exactly ``embed_dim`` wide.
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        categorical_dim = embed_dim // 4
        # Use the leftover width rather than embed_dim // 2: the two are equal
        # when embed_dim is a multiple of 4, but for other sizes the old split
        # no longer summed to embed_dim and output_proj failed in forward().
        continuous_dim = embed_dim - 2 * categorical_dim

        self.role_embed = nn.Embedding(5, categorical_dim)
        self.side_embed = nn.Embedding(2, categorical_dim)
        self.continuous_proj = nn.Linear(8, continuous_dim)
        self.output_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, player_features: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode a dict of player features.

        Args:
            player_features: mapping with the keys ``'role'`` (integer ids in
                ``[0, 5)``), ``'side'`` (integer ids in ``[0, 2)``) and
                ``'continuous'`` (float tensor whose last dimension is 8).

        Returns:
            Tensor whose last dimension is ``embed_dim``.
        """
        role_emb = self.role_embed(player_features['role'])
        side_emb = self.side_embed(player_features['side'])
        cont_emb = self.continuous_proj(player_features['continuous'])
        combined = torch.cat([role_emb, side_emb, cont_emb], dim=-1)
        return self.output_proj(combined)

# ================================================================================
# FEATURE INJECTION MODULE
# ================================================================================

class FeatureInjectionModule(nn.Module):
    """Blend hand-engineered features into a learned representation.

    When ``feature_dim`` is positive, engineered features are encoded to
    ``fusion_dim``, concatenated with the learned features and projected back
    to ``hidden_dim``. Otherwise the module passes learned features through.
    """

    def __init__(self, feature_dim: int, hidden_dim: int, fusion_dim: int):
        super().__init__()
        self.feature_dim = feature_dim

        # No engineered features configured: nothing to encode or fuse.
        if not feature_dim > 0:
            self.feature_encoder = None
            self.fusion = nn.Identity()
            return

        encoder_layers = [
            nn.Linear(feature_dim, fusion_dim),
            nn.LayerNorm(fusion_dim),
            nn.ReLU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(fusion_dim, fusion_dim),
        ]
        self.feature_encoder = nn.Sequential(*encoder_layers)

        fusion_layers = [
            nn.Linear(hidden_dim + fusion_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
        ]
        self.fusion = nn.Sequential(*fusion_layers)

    def forward(self, learned_features: torch.Tensor,
                engineered_features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return fused features, or ``learned_features`` unchanged.

        The input is returned as-is when the module was built with
        ``feature_dim <= 0`` or when no engineered features are supplied.
        """
        can_fuse = self.feature_dim > 0 and engineered_features is not None
        if not can_fuse:
            return learned_features

        encoded = self.feature_encoder(engineered_features)
        stacked = torch.cat((learned_features, encoded), dim=-1)
        return self.fusion(stacked)

# ================================================================================
# EFFICIENT GRAPH ATTENTION - FULLY ROBUST FOR AMP
# ================================================================================

class EfficientGraphAttention(nn.Module):
    """Multi-head attention over graph edges, aggregated at the destination node.

    Each head scores every edge from the sum of its two projected endpoint
    features, normalises the scores with a softmax over the edges that share
    a destination node, and uses them to weight the full projected source
    feature vector. The per-head aggregates are then averaged.
    """

    def __init__(self, in_features: int, out_features: int, num_heads: int = 2):
        super().__init__()
        self.num_heads = num_heads
        # Stored on the module but not read anywhere else in this class.
        self.head_dim = out_features // num_heads
        
        # Node feature projection, and the edge scorer (one logit per head).
        self.W = nn.Linear(in_features, out_features)
        self.att = nn.Linear(out_features, num_heads)
        
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, 
                edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run one attention-weighted message-passing step.

        Args:
            x: node features with nodes along dimension 0.
            edge_index: ``edge_index[0]`` holds source node indices and
                ``edge_index[1]`` destination node indices, one per edge.
            edge_attr: accepted for interface compatibility; not used here.

        Returns:
            Tensor of shape ``(num_nodes, out_features)`` with dtype ``x.dtype``.
        """
        num_nodes = x.shape[0]
        
        # All scatter buffers are allocated in the input dtype. The layer
        # outputs are cast to it explicitly below, presumably because they can
        # come back in a different dtype under autocast (per the section
        # banner) -- TODO confirm.
        target_dtype = x.dtype
        
        h = self.W(x)
        
        src_idx, dst_idx = edge_index[0], edge_index[1]
        h_src = h[src_idx]
        h_dst = h[dst_idx]
        
        # One logit per edge and head.
        att_logits = self.att(h_src + h_dst)
        
        att_max = torch.zeros(num_nodes, self.num_heads, device=x.device, dtype=target_dtype)
        
        att_logits = att_logits.to(target_dtype)
        # Per-destination maximum logit; include_self=False keeps the zero
        # initial value out of the reduction.
        att_max.index_reduce_(0, dst_idx, att_logits, 'amax', include_self=False)
        
        # Subtract the per-destination max before exponentiating (stable softmax).
        att_exp = torch.exp(att_logits - att_max[dst_idx])
        
        att_sum = torch.zeros(num_nodes, self.num_heads, device=x.device, dtype=target_dtype)
        att_sum.index_add_(0, dst_idx, att_exp)
        
        # Normalise over the edges sharing a destination; epsilon guards the division.
        alpha = att_exp / (att_sum[dst_idx] + 1e-8)
        
        # (edges, 1, features) * (edges, heads, 1): every head weights the
        # whole source feature vector.
        h_src_weighted = (h_src.unsqueeze(1) * alpha.unsqueeze(-1)).to(target_dtype)
        
        out = torch.zeros(num_nodes, self.num_heads, h.shape[1], device=x.device, dtype=target_dtype)
        out.index_add_(0, dst_idx, h_src_weighted)
        
        # Average the heads.
        return out.mean(dim=1)

class EfficientGNN(nn.Module):
    """Stack of graph-attention layers, each wrapped as residual + LayerNorm + ReLU."""

    def __init__(self, hidden_dim: int, num_layers: int = 2):
        super().__init__()
        attention_stack = []
        for _ in range(num_layers):
            attention_stack.append(
                EfficientGraphAttention(hidden_dim, hidden_dim, num_heads=2)
            )
        self.layers = nn.ModuleList(attention_stack)

        norm_stack = []
        for _ in range(num_layers):
            norm_stack.append(nn.LayerNorm(hidden_dim))
        self.norms = nn.ModuleList(norm_stack)

    def forward(self, x, edge_index, edge_attr=None):
        """Apply every layer in order and return the final node features."""
        for attention, layer_norm in zip(self.layers, self.norms):
            message = attention(x, edge_index, edge_attr)
            # Cast the skip connection to the layer output's dtype before adding.
            x = F.relu(layer_norm(x.to(message.dtype) + message))
        return x

# ================================================================================
# PARALLEL PLAYER ATTENTION POOLING
# ================================================================================

class ParallelPlayerAttention(nn.Module):
    """Temporal self-attention run for every player at once.

    Each player's sequence of frame features is attended over independently,
    and the representation at the last time step is kept as that player's
    state.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.norm = nn.LayerNorm(d_model)

    def _split_heads(self, projected: torch.Tensor) -> torch.Tensor:
        """Reshape ``(rows, steps, d_model)`` to ``(rows, heads, steps, head_dim)``."""
        rows, steps, _ = projected.shape
        return projected.view(rows, steps, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, temporal_features: torch.Tensor) -> torch.Tensor:
        """Pool ``(batch, seq, players, hidden)`` into ``(batch, players, hidden)``."""
        batch_size, seq_len, num_players, hidden_dim = temporal_features.shape
        rows = batch_size * num_players

        # Fold players into the batch dimension so each one has its own sequence.
        per_player = temporal_features.permute(0, 2, 1, 3).reshape(rows, seq_len, hidden_dim)

        queries = self._split_heads(self.query_proj(per_player))
        keys = self._split_heads(self.key_proj(per_player))
        values = self._split_heads(self.value_proj(per_player))

        # Scaled dot-product attention across time steps.
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        weights = F.softmax(logits, dim=-1)

        attended = torch.matmul(weights, values)
        merged = attended.transpose(1, 2).contiguous().view(rows, seq_len, hidden_dim)
        projected = self.out_proj(merged)

        # Keep only the final time step as the player's state.
        last_step = projected[:, -1, :].view(batch_size, num_players, hidden_dim)
        return self.norm(last_step)

# ================================================================================
# MULTI-HORIZON DECODER
# ================================================================================

class MultiHorizonDecoder(nn.Module):
    """Shared trunk followed by one small MLP head per prediction horizon.

    The head for horizon ``h`` emits ``h * 2`` values per player, reshaped to
    ``(h, 2)``.
    """

    def __init__(self, hidden_dim: int, horizons: List[int]):
        super().__init__()
        self.horizons = horizons

        trunk_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(Config.DROPOUT),
        ]
        self.shared = nn.Sequential(*trunk_layers)

        heads = []
        for steps in horizons:
            heads.append(
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim // 2),
                    nn.ReLU(),
                    nn.Linear(hidden_dim // 2, steps * 2),
                )
            )
        self.horizon_heads = nn.ModuleList(heads)

    def forward(self, player_states: torch.Tensor, target_horizon: int = None):
        """Decode player states into per-horizon predictions.

        Args:
            player_states: tensor of shape ``(batch, players, hidden_dim)``.
            target_horizon: when given, only that horizon's head is run; it
                must be one of ``self.horizons`` (``list.index`` raises
                ``ValueError`` otherwise). When ``None``, every head is run.

        Returns:
            Dict mapping each decoded horizon to a tensor of shape
            ``(batch, players, horizon, 2)``.
        """
        batch_size, num_players, _ = player_states.shape

        trunk_out = self.shared(player_states)

        if target_horizon is None:
            selected = list(zip(self.horizons, self.horizon_heads))
        else:
            head_position = self.horizons.index(target_horizon)
            selected = [(target_horizon, self.horizon_heads[head_position])]

        return {
            steps: head(trunk_out).view(batch_size, num_players, steps, 2)
            for steps, head in selected
        }

# ================================================================================
# ENHANCED MAIN MODEL
# ================================================================================

class EnhancedSpatioTemporalPredictor(nn.Module):
    """Per-frame graph encoder, temporal attention pooling and multi-horizon decoder.

    For each selected frame the model builds node features from the player
    embedding, a position embedding and motion features, optionally fuses
    engineered features, and runs a GNN over a k-nearest-neighbour graph. The
    per-frame outputs are pooled over time per player and decoded into future
    positions. All layer sizes are read from ``Config`` at construction time.
    """

    def __init__(self):
        super().__init__()

        self.player_encoder = EfficientPlayerEncoder(Config.PLAYER_EMBED_DIM)
        self.spatial_encoder = EfficientPositionalEncoding(Config.SPATIAL_EMBED_DIM)

        # +3 covers the motion features concatenated in forward():
        # two velocity components and one acceleration value.
        self.feature_fusion = nn.Sequential(
            nn.Linear(Config.PLAYER_EMBED_DIM + Config.SPATIAL_EMBED_DIM + 3, Config.HIDDEN_DIM),
            nn.ReLU()
        )

        self.feature_injection = FeatureInjectionModule(
            Config.ENGINEERED_FEATURES_DIM,
            Config.HIDDEN_DIM,
            Config.FEATURE_FUSION_DIM
        )

        self.gnn = EfficientGNN(Config.HIDDEN_DIM, num_layers=Config.NUM_GNN_LAYERS)

        self.parallel_attention = ParallelPlayerAttention(
            Config.HIDDEN_DIM,
            Config.NUM_ATTENTION_HEADS
        )

        self.decoder = MultiHorizonDecoder(
            Config.HIDDEN_DIM,
            Config.MULTI_HORIZON_TARGETS
        )

    def build_graph(self, positions, batch_indices, k=8):
        """Connect every node to its nearest neighbours within the same sample.

        Args:
            positions: ``(num_nodes, 2)`` coordinates for all nodes in the batch.
            batch_indices: ``(num_nodes,)`` sample id of each node.
            k: maximum number of neighbours per node.

        Returns:
            ``edge_index`` of shape ``(2, num_edges)``, where row 0 is the node
            and row 1 one of its neighbours, and ``edge_attr`` of shape
            ``(num_edges, 3)`` holding the offset from row-0 node to row-1
            node followed by its Euclidean length. A node has fewer than ``k``
            edges when its sample contains fewer than ``k`` other nodes.
        """
        device = positions.device
        num_players = positions.shape[0]

        dist_matrix = torch.cdist(positions, positions)

        # A valid neighbour is in the same sample and is not the node itself.
        # Masking the diagonal explicitly (instead of dropping topk's first
        # column) stays correct when several nodes share one position.
        same_sample = batch_indices.unsqueeze(1) == batch_indices.unsqueeze(0)
        is_self = torch.eye(num_players, dtype=torch.bool, device=device)
        dist_matrix = dist_matrix.masked_fill(~same_sample | is_self, float('inf'))

        k_actual = min(k, num_players)
        neighbor_dist, neighbor_idx = torch.topk(dist_matrix, k=k_actual, dim=1, largest=False)

        # topk pads with masked (infinite-distance) entries when a sample has
        # too few players; dropping them prevents edges between samples.
        valid = torch.isfinite(neighbor_dist)

        src = torch.arange(num_players, device=device).unsqueeze(1).expand(-1, k_actual)
        edge_index = torch.stack([src[valid], neighbor_idx[valid]], dim=0)

        src_pos = positions[edge_index[0]]
        dst_pos = positions[edge_index[1]]
        offset = dst_pos - src_pos
        edge_attr = torch.cat([
            offset,
            torch.norm(offset, dim=1, keepdim=True)
        ], dim=1)

        return edge_index, edge_attr

    def select_frames(self, seq_len: int) -> List[int]:
        """Choose which frame indices of a length-``seq_len`` window to process.

        Without adaptive sampling (or for windows of at most 2 frames) only
        the first and last frame are used. Windows of 3 or 4 frames use every
        frame. Longer windows use the first and last frame plus those at one
        third and two thirds of the window.
        """
        if not Config.ADAPTIVE_FRAME_SAMPLING or seq_len <= 2:
            return [0, seq_len - 1] if seq_len > 1 else [0]

        if seq_len <= 4:
            return list(range(seq_len))

        return [0, seq_len // 3, 2 * seq_len // 3, seq_len - 1]

    def forward(self, batch_data, target_horizon=None, engineered_features=None):
        """Predict future positions for every player.

        Args:
            batch_data: dict with ``'positions'`` and ``'velocities'`` of shape
                ``(batch, seq, players, 2)``, ``'accelerations'`` with one
                value per player and frame, and ``'player_features'``, the dict
                consumed by the player encoder. The encoder output is
                concatenated with per-node tensors, so it is assumed to have
                ``batch * players`` rows -- TODO confirm against the data loader.
            target_horizon: forwarded to the decoder; ``None`` decodes all
                horizons.
            engineered_features: optional tensor indexed as ``[:, frame]`` and
                reshaped to ``Config.ENGINEERED_FEATURES_DIM`` columns per node.

        Returns:
            The decoder's dict mapping horizon to predictions.
        """
        batch_size, seq_len, num_players, _ = batch_data['positions'].shape
        device = batch_data['positions'].device

        player_emb = self.player_encoder(batch_data['player_features'])

        frames_to_process = self.select_frames(seq_len)

        # Sample id of every flattened node; identical for all frames, so it
        # is built once instead of once per frame.
        batch_indices = torch.arange(batch_size, device=device).repeat_interleave(num_players)

        temporal_features = []

        for t in frames_to_process:
            pos_t = batch_data['positions'][:, t].reshape(-1, 2)
            vel_t = batch_data['velocities'][:, t].reshape(-1, 2)
            acc_t = batch_data['accelerations'][:, t].reshape(-1, 1)

            spatial_emb = self.spatial_encoder(pos_t)
            motion_features = torch.cat([vel_t, acc_t], dim=-1)
            combined = torch.cat([player_emb, spatial_emb, motion_features], dim=-1)
            node_features = self.feature_fusion(combined)

            if engineered_features is not None:
                eng_feat_t = engineered_features[:, t].reshape(-1, Config.ENGINEERED_FEATURES_DIM)
                node_features = self.feature_injection(node_features, eng_feat_t)

            # The graph is rebuilt per frame because it depends on positions.
            edge_index, edge_attr = self.build_graph(pos_t, batch_indices)

            node_features = self.gnn(node_features, edge_index, edge_attr)
            node_features = node_features.reshape(batch_size, num_players, -1)
            temporal_features.append(node_features)

        # (batch, selected_frames, players, hidden)
        temporal_features = torch.stack(temporal_features, dim=1)

        player_states = self.parallel_attention(temporal_features)

        predictions = self.decoder(player_states, target_horizon)

        return predictions

# ================================================================================
# FEATURE ENGINEERING (keeping all original functions)
# ================================================================================

def compute_rolling_statistics(positions_history, velocities_history, accelerations_history):
    """Summarise a window of motion history as 8 scaled statistics.

    Layout: [0-1] mean/std of speed, [2-3] mean/std of acceleration,
    [4-5] standard deviation of position along each axis, [6-7] mean/std of
    the absolute frame-to-frame heading change. Returns zeros when fewer than
    two positions are available.
    """
    stats = np.zeros(8)
    if len(positions_history) < 2:
        return stats

    pos_arr = np.array(positions_history)
    vel_arr = np.array(velocities_history)
    acc_arr = np.array(accelerations_history)

    vx = vel_arr[:, 0]
    vy = vel_arr[:, 1]
    speeds = np.sqrt(vx**2 + vy**2)
    stats[0] = np.mean(speeds) / 10.0
    stats[1] = np.std(speeds) / 5.0

    stats[2] = np.mean(acc_arr) / 5.0
    stats[3] = np.std(acc_arr) / 3.0

    axis_spread = np.sqrt(np.var(pos_arr, axis=0))
    stats[4] = axis_spread[0] / 5.0
    stats[5] = axis_spread[1] / 5.0

    if len(vel_arr) > 1:
        headings = np.arctan2(vy, vx)
        raw_turns = np.diff(headings)
        # Wrap each heading change into [-pi, pi] before taking its magnitude.
        turns = np.abs(np.arctan2(np.sin(raw_turns), np.cos(raw_turns)))
        stats[6] = np.mean(turns) / np.pi
        stats[7] = np.std(turns) / np.pi

    return stats

def classify_route(player_pos, ball_land_pos, player_vel, player_role, player_side):
    """Describe a player's route relative to the ball landing spot (6 features).

    Only players with side 0 and role 0 or 1 are classified; everyone else
    gets zeros. Layout: [0-2] depth bucket flags (> 20, 10..20, < 10),
    [3] target lies toward the field centre from the player's half,
    [4] otherwise, lateral offset larger than 3, [5] alignment of the
    velocity with the direction to the landing spot (1 = aligned).
    """
    route = np.zeros(6)
    eligible = player_side == 0 and player_role in [0, 1]
    if not eligible:
        return route

    depth = ball_land_pos[0] - player_pos[0]
    lateral = ball_land_pos[1] - player_pos[1]

    route[0] = float(depth > 20)
    route[1] = float(10 <= depth <= 20)
    route[2] = float(depth < 10)

    midline = Config.FIELD_Y_MAX / 2
    player_y = player_pos[1]
    toward_middle = (player_y < midline and lateral > 0) or \
                    (player_y > midline and lateral < 0)
    if toward_middle:
        route[3] = 1.0
    else:
        route[4] = float(abs(lateral) > 3)

    speed = np.sqrt(player_vel[0]**2 + player_vel[1]**2)
    if speed > 0:
        heading = np.arctan2(player_vel[1], player_vel[0])
        bearing = np.arctan2(lateral, depth)
        gap = abs(heading - bearing)
        # Fold the difference into [0, pi].
        gap = 2 * np.pi - gap if gap > np.pi else gap
        route[5] = 1.0 - (gap / np.pi)

    return route

def compute_pursuit_angles(player_pos, player_vel, player_side, player_role, 
                          ball_land_pos, targeted_receiver_pos):
    """Describe how a side-1 player is closing on its target (6 features).

    The target is the targeted receiver when one is given, otherwise the ball
    landing spot. Players whose side is not 1 get zeros. Layout:
    [0] heading error / pi, [1] cosine of the heading error floored at 0,
    [2] clipped closing speed, [3] capped time to reach the target,
    [4] distance to the target, [5] distance to the landing spot.
    ``player_role`` is accepted but not used.
    """
    pursuit = np.zeros(6)
    if player_side != 1:
        return pursuit

    if targeted_receiver_pos is not None:
        aim = targeted_receiver_pos
    else:
        aim = ball_land_pos

    vx, vy = player_vel[0], player_vel[1]
    dx = aim[0] - player_pos[0]
    dy = aim[1] - player_pos[1]
    ideal_heading = np.arctan2(dy, dx)

    speed = np.sqrt(vx**2 + vy**2)
    moving = speed > 0.1
    if moving:
        heading_error = abs(np.arctan2(vy, vx) - ideal_heading)
        if heading_error > np.pi:
            heading_error = 2 * np.pi - heading_error
        pursuit[0] = heading_error / np.pi
        pursuit[1] = max(0, np.cos(heading_error))
    else:
        # Direction is undefined for a (near-)stationary player.
        pursuit[0], pursuit[1] = 0.5, 0.0

    gap = np.sqrt(dx**2 + dy**2)
    if gap > 0.1:
        closing_speed = (vx * dx + vy * dy) / gap
        pursuit[2] = np.clip(closing_speed / 10.0, -1.0, 1.0)
        pursuit[3] = min(gap / speed / 5.0, 1.0) if moving else 1.0

    pursuit[4] = gap / 30.0

    ball_dx = ball_land_pos[0] - player_pos[0]
    ball_dy = ball_land_pos[1] - player_pos[1]
    pursuit[5] = np.sqrt(ball_dx**2 + ball_dy**2) / 40.0

    return pursuit

def compute_formation_features(all_positions, all_sides, all_roles, ball_land_pos):
    """Build an ``(num_players, 8)`` table of offensive-formation features.

    Columns 0-5 are formation-wide and identical on every row: counts of
    role-0/1 players on either side of the field centre, counts of those
    players well ahead of / behind their mean x, a flag for any role-3
    offensive player, and a reference x (the first role-3 player's x, else the
    offence's mean x). Columns 6-7 are filled for side-0 rows only: which half
    the player is in and the normalised distance from the centre line.
    Returns zeros when no side-0 player exists. ``ball_land_pos`` is not used.
    """
    formation = np.zeros((len(all_positions), 8))

    on_offense = all_sides == 0
    offense_xy = all_positions[on_offense]
    offense_role_ids = all_roles[on_offense]
    if len(offense_xy) == 0:
        return formation

    midline = Config.FIELD_Y_MAX / 2
    is_receiver = (offense_role_ids == 0) | (offense_role_ids == 1)
    n_left = np.sum((offense_xy[:, 1] < midline) & is_receiver)
    n_right = np.sum((offense_xy[:, 1] > midline) & is_receiver)

    receiver_xy = offense_xy[is_receiver]
    n_deep = 0
    n_shallow = 0
    if len(receiver_xy) > 0:
        receiver_x = receiver_xy[:, 0]
        n_deep = np.sum(receiver_x > np.mean(receiver_x) + 5)
        n_shallow = np.sum(receiver_x < np.mean(receiver_x) - 5)

    # Role id 3 drives both the presence flag and the reference depth.
    has_role3 = offense_role_ids == 3
    role3_present = np.any(has_role3)
    role3_xy = offense_xy[has_role3]
    if len(role3_xy) > 0:
        reference_depth = role3_xy[0, 0]
    else:
        reference_depth = np.mean(offense_xy[:, 0])

    # Formation-wide columns are the same for every player.
    formation[:, 0] = n_left / 5.0
    formation[:, 1] = n_right / 5.0
    formation[:, 2] = n_deep / 3.0
    formation[:, 3] = n_shallow / 3.0
    formation[:, 4] = 1.0 if role3_present else 0.0
    formation[:, 5] = reference_depth / Config.FIELD_X_MAX

    # Per-player columns apply to offensive rows only.
    offense_y = all_positions[on_offense, 1]
    formation[on_offense, 6] = np.where(offense_y < midline, 1.0, 0.0)
    formation[on_offense, 7] = np.abs(offense_y - midline) / (Config.FIELD_Y_MAX / 2)

    return formation

def compute_coverage_features(all_positions, all_sides, all_roles, ball_land_pos):
    """Build an ``(num_players, 6)`` table of coverage features for side-1 players.

    Rows of players whose side is not 1 stay zero, and the whole table is zero
    when no side-1 player exists. Layout: [0] mean x of role-2 defenders (or
    the landing x when there are none), [1] count of role-2 defenders more
    than 5 beyond that mean, [2]/[3] nearest side-0 player closer than 3 /
    farther than 8, [4] player is more than 3 short of the mean x,
    [5] mean distance to the other defenders.
    """
    coverage = np.zeros((len(all_positions), 6))

    on_defense = all_sides == 1
    defense_xy = all_positions[on_defense]
    defense_role_ids = all_roles[on_defense]
    if len(defense_xy) == 0:
        return coverage

    offense_xy = all_positions[all_sides == 0]

    role2_xy = defense_xy[defense_role_ids == 2]
    if len(role2_xy) > 0:
        mean_depth = np.mean(role2_xy[:, 0])
        n_deep = np.sum(role2_xy[:, 0] > mean_depth + 5)
    else:
        mean_depth = ball_land_pos[0]
        n_deep = 0

    has_offense = len(offense_xy) > 0
    has_teammates = len(defense_xy) > 1

    for row in np.flatnonzero(on_defense):
        here = all_positions[row]

        coverage[row, 0] = mean_depth / Config.FIELD_X_MAX
        coverage[row, 1] = n_deep / 4.0

        if has_offense:
            nearest = np.min(np.sqrt(np.sum((offense_xy - here)**2, axis=1)))
            coverage[row, 2] = float(nearest < 3.0)
            coverage[row, 3] = float(nearest > 8.0)

        coverage[row, 4] = float(all_positions[row, 0] < mean_depth - 3)

        if has_teammates:
            mate_dist = np.sqrt(np.sum((defense_xy - here)**2, axis=1))
            # Zero distances (the player itself) are excluded from the mean.
            mate_dist = mate_dist[mate_dist > 0]
            if len(mate_dist) > 0:
                coverage[row, 5] = np.mean(mate_dist) / 15.0

    return coverage

def compute_player_interaction_features(all_positions, all_velocities, all_sides, all_roles, player_idx):
    """Describe one player's surroundings as 15 scaled features.

    Layout: [0] nearest-opponent distance, [1-3] opponents within 3 / 5 / 10,
    [4] closing speed of the nearest opponent, [5] distance-weighted pressure,
    [6-7] nearest / mean teammate distance, [8-9] teammate spread per axis,
    [10] mean distance of the three nearest opponents, [11] relative speed to
    the nearest opponent, [12] nearest role-2 side-1 player (side-0 players
    only), [13] nearest role-0/1 side-0 player (other players only),
    [14] number of other players within 10.
    """
    interact = np.zeros(15)
    me_pos = all_positions[player_idx]
    me_vel = all_velocities[player_idx]
    my_side = all_sides[player_idx]

    is_opponent = all_sides != my_side
    opp_xy = all_positions[is_opponent]
    mate_xy = all_positions[all_sides == my_side]

    if len(opp_xy) > 0:
        opp_dist = np.linalg.norm(opp_xy - me_pos, axis=1)
        interact[0] = np.min(opp_dist) / 10.0

        crowd_bands = ((3.0, 5.0), (5.0, 7.0), (10.0, 11.0))
        for slot, (radius, scale) in enumerate(crowd_bands, start=1):
            interact[slot] = np.sum(opp_dist < radius) / scale

        closest = np.argmin(opp_dist)
        # Map the position within the opponent subset back to the full roster.
        closest_vel = all_velocities[np.where(is_opponent)[0][closest]]
        toward_me = (me_pos - opp_xy[closest]) / (opp_dist[closest] + 1e-6)
        interact[4] = np.clip(np.dot(closest_vel, toward_me) / 10.0, -1, 1)

        interact[5] = np.sum(np.exp(-opp_dist / 5.0)) / 10.0

        ranked = np.sort(opp_dist)
        interact[10] = np.mean(ranked[:min(3, len(ranked))]) / 10.0

        interact[11] = np.linalg.norm(me_vel - closest_vel) / 10.0

    if len(mate_xy) > 1:
        mate_dist = np.linalg.norm(mate_xy - me_pos, axis=1)
        # Drop the player itself (and anyone within 0.1 of it).
        mate_dist = mate_dist[mate_dist > 0.1]
        if len(mate_dist) > 0:
            interact[6] = np.min(mate_dist) / 10.0
            interact[7] = np.mean(mate_dist) / 15.0
            interact[8] = np.std(mate_xy[:, 0]) / 10.0
            interact[9] = np.std(mate_xy[:, 1]) / 8.0

    if my_side == 0:
        marker_xy = all_positions[(all_sides == 1) & (all_roles == 2)]
        if len(marker_xy) > 0:
            interact[12] = np.min(np.linalg.norm(marker_xy - me_pos, axis=1)) / 10.0
    else:
        target_xy = all_positions[(all_sides == 0) & ((all_roles == 0) | (all_roles == 1))]
        if len(target_xy) > 0:
            interact[13] = np.min(np.linalg.norm(target_xy - me_pos, axis=1)) / 10.0

    everyone_dist = np.linalg.norm(all_positions - me_pos, axis=1)
    # Subtract one so the player does not count itself.
    interact[14] = (np.sum(everyone_dist < 10.0) - 1) / 15.0

    return interact

def compute_trajectory_features(positions_history, velocities_history, accelerations_history, 
                                ball_land_pos, time_to_ball):
    """Compare a player's recent motion with the path to the landing spot.

    Returns 12 scaled features, or zeros when fewer than two positions are
    available. Layout: [0-1] miss distance of the constant-velocity /
    constant-acceleration extrapolation, [2] speed needed to arrive on time,
    [3] gap between current and needed velocity, [4] angular mismatch between
    them, [5] bend of the last two steps, [6-7] change / spread of the last
    three speeds, [8] deviation of the middle position from a straight line,
    [9] acceleration needed to arrive on time, [10] change in acceleration,
    [11] speed component toward the landing spot.
    """
    traj = np.zeros(12)

    n_positions = len(positions_history)
    if n_positions < 2:
        return traj

    pos_now = positions_history[-1]
    vel_now = velocities_history[-1]
    acc_now = accelerations_history[-1]

    coasting = pos_now + vel_now * time_to_ball
    traj[0] = np.linalg.norm(coasting - ball_land_pos) / 50.0

    if len(accelerations_history) > 0:
        accelerating = coasting + 0.5 * acc_now * (time_to_ball ** 2)
        traj[1] = np.linalg.norm(accelerating - ball_land_pos) / 50.0

    if time_to_ball > 0:
        to_ball = ball_land_pos - pos_now

        needed_vel = to_ball / time_to_ball
        needed_speed = np.linalg.norm(needed_vel)
        speed_now = np.linalg.norm(vel_now)
        traj[2] = needed_speed / 15.0
        traj[3] = np.linalg.norm(vel_now - needed_vel) / 10.0
        if speed_now > 0.1 and needed_speed > 0.1:
            cos_heading = np.dot(vel_now, needed_vel) / (speed_now * needed_speed)
            traj[4] = (1 - cos_heading) / 2

        needed_acc = 2 * (to_ball - vel_now * time_to_ball) / (time_to_ball ** 2)
        traj[9] = np.linalg.norm(needed_acc) / 10.0

        # Epsilon keeps the unit vector finite when the player is on the spot.
        unit_to_ball = to_ball / (np.linalg.norm(to_ball) + 1e-6)
        traj[11] = np.dot(vel_now, unit_to_ball) / 10.0

    if n_positions >= 3:
        oldest, middle, newest = np.array(positions_history[-3:])
        first_step = middle - oldest
        second_step = newest - middle
        first_len = np.linalg.norm(first_step)
        second_len = np.linalg.norm(second_step)
        if first_len > 0.1 and second_len > 0.1:
            cos_bend = np.dot(first_step, second_step) / (first_len * second_len)
            traj[5] = (1 - cos_bend) / 2

        straight_mid = (positions_history[-3] + positions_history[-1]) / 2
        traj[8] = np.linalg.norm(positions_history[-2] - straight_mid) / 5.0

    if len(velocities_history) >= 3:
        last_speeds = [np.linalg.norm(v) for v in velocities_history[-3:]]
        traj[6] = (last_speeds[-1] - last_speeds[0]) / 10.0
        traj[7] = np.std(last_speeds) / 3.0

    if len(accelerations_history) >= 2:
        acc_delta = accelerations_history[-1] - accelerations_history[-2]
        traj[10] = np.linalg.norm(acc_delta) / 5.0

    return traj

def compute_zone_and_route_features(player_pos, ball_land_pos, all_positions, all_sides, 
                                    all_roles, player_idx):
    """Encode field zones of the player and the landing spot (14 features).

    Layout: [0-2] one-hot lateral band of the player (below the left hash,
    above the right hash, between), [3-5] the same for the landing spot,
    [6] player is in an outside band and the landing spot is in a different
    band, [7] landing x is beyond 100 or below 20, [8-10] depth bucket of the
    landing spot relative to the player (< 5, 5..15, >= 15), [11-12] distance
    of player / landing spot to the nearer sideline, [13] number of role-0/1
    side-0 players within 10 laterally (side-0 players only).
    """
    zone = np.zeros(14)

    player_side = all_sides[player_idx]
    player_role = all_roles[player_idx]  # looked up but not used below

    def lateral_band(y):
        """Return 0 below the left hash, 1 above the right hash, else 2."""
        if y < Config.FIELD_HASH_LEFT:
            return 0
        if y > Config.FIELD_HASH_RIGHT:
            return 1
        return 2

    player_band = lateral_band(player_pos[1])
    ball_band = lateral_band(ball_land_pos[1])
    zone[player_band] = 1.0
    zone[3 + ball_band] = 1.0

    # Set only when the player starts in an outside band and the ball lands
    # in a different band.
    crosses = (player_band == 0 and ball_band in (1, 2)) or \
              (player_band == 1 and ball_band in (0, 2))
    zone[6] = 1.0 if crosses else 0.0

    ball_x = ball_land_pos[0]
    zone[7] = float(ball_x > 100 or ball_x < 20)

    depth = ball_x - player_pos[0]
    zone[8] = float(depth < 5)
    zone[9] = float(5 <= depth < 15)
    zone[10] = float(depth >= 15)

    width = Config.FIELD_Y_MAX
    zone[11] = min(player_pos[1], width - player_pos[1]) / width
    zone[12] = min(ball_land_pos[1], width - ball_land_pos[1]) / width

    if player_side == 0:
        is_receiver = (all_sides == 0) & ((all_roles == 0) | (all_roles == 1))
        receiver_xy = all_positions[is_receiver]
        if len(receiver_xy) > 0:
            nearby = sum(1 for spot in receiver_xy if abs(spot[1] - player_pos[1]) < 10)
            zone[13] = nearby / 5.0

    return zone

def compute_physics_features(player_pos, player_vel, player_acc, player_weight, ball_land_pos, time_to_ball):
    """Build the 13-element kinematics feature vector for one player.

    Layout of the returned array:
        0   speed / 10
        1   speed squared / 100
        2   norm of the acceleration / 5
        3-4 weight * velocity component (x, y) / 2000
        5   0.5 * weight * speed squared / 10000
        6-7 constant-velocity extrapolated position after time_to_ball,
            divided by FIELD_X_MAX / FIELD_Y_MAX   (only if time_to_ball > 0)
        8   distance from that extrapolated position to the landing spot / 50
            (only if time_to_ball > 0)
        9   distance to the landing spot / (time_to_ball + 0.1) / 10
            (only if time_to_ball > 0)
        10  velocity component along the direction to the landing spot / 10
            (only if that distance exceeds 0.1)
        11  time_to_ball squared / 25
        12  distance to the landing spot squared / 2500

    Returns:
        np.ndarray of shape (13,); entries whose condition is not met stay 0.
    """
    features = np.zeros(13)

    speed = np.linalg.norm(player_vel)
    speed_sq = speed ** 2
    features[0] = speed / 10.0
    features[1] = speed_sq / 100.0
    features[2] = np.linalg.norm(player_acc) / 5.0

    features[3] = player_weight * player_vel[0] / 2000.0
    features[4] = player_weight * player_vel[1] / 2000.0
    features[5] = 0.5 * player_weight * speed_sq / 10000.0

    # Offset and distance to the landing spot are shared by several features
    offset_to_ball = ball_land_pos - player_pos
    distance_to_ball = np.linalg.norm(offset_to_ball)

    if time_to_ball > 0:
        expected_pos = player_pos + player_vel * time_to_ball
        features[6] = expected_pos[0] / Config.FIELD_X_MAX
        features[7] = expected_pos[1] / Config.FIELD_Y_MAX
        features[8] = np.linalg.norm(expected_pos - ball_land_pos) / 50.0
        features[9] = distance_to_ball / (time_to_ball + 0.1) / 10.0

    if distance_to_ball > 0.1:
        unit_to_ball = offset_to_ball / distance_to_ball
        features[10] = np.dot(player_vel, unit_to_ball) / 10.0

    features[11] = time_to_ball ** 2 / 25.0
    features[12] = distance_to_ball ** 2 / 2500.0

    return features

def compute_comprehensive_features(player_idx, positions_window, velocities_window, 
                                  accelerations_window, player_sides, player_roles,
                                  ball_land_pos, all_positions_current, all_velocities_current,
                                  player_weight, time_to_ball):
    """Assemble the 108-element engineered feature vector for one player.

    The player's state is read from the last entry of each window; the full
    windows supply the per-player history passed to the helper functions.

    Slice layout of the returned array:
        0:8     compute_rolling_statistics
        8:14    classify_route
        14:20   compute_pursuit_angles
        20:28   compute_formation_features (row for this player)
        28:34   compute_coverage_features (row for this player)
        34:42   ball distance/angle, nearest opponent, crowding, boundary
                margins, role flag, side-relative x
        42:57   compute_player_interaction_features
        57:69   compute_trajectory_features
        69:83   compute_zone_and_route_features
        83:96   compute_physics_features
        96:108  normalised position, ball offset, velocity alignment,
                distance transforms, opponent statistics, inverse time

    Returns:
        np.ndarray of shape (108,).
    """
    features = np.zeros(108)

    # --- current state of this player -------------------------------------
    player_pos = positions_window[-1][player_idx]
    player_vel = velocities_window[-1][player_idx]
    player_acc = accelerations_window[-1][player_idx]
    player_side = player_sides[player_idx]
    player_role = player_roles[player_idx]

    # --- per-frame history of this player ---------------------------------
    positions_history = [frame[player_idx] for frame in positions_window]
    velocities_history = [frame[player_idx] for frame in velocities_window]
    accelerations_history = [frame[player_idx] for frame in accelerations_window]

    features[0:8] = compute_rolling_statistics(
        positions_history, velocities_history, accelerations_history
    )

    features[8:14] = classify_route(
        player_pos, ball_land_pos, player_vel, player_role, player_side
    )

    # First role-0 player, if any, is handed to the pursuit helper
    role_zero_positions = all_positions_current[player_roles == 0]
    if len(role_zero_positions) > 0:
        targeted_receiver_pos = role_zero_positions[0]
    else:
        targeted_receiver_pos = None

    features[14:20] = compute_pursuit_angles(
        player_pos, player_vel, player_side, player_role,
        ball_land_pos, targeted_receiver_pos
    )

    per_player_formation = compute_formation_features(
        all_positions_current, player_sides, player_roles, ball_land_pos
    )
    features[20:28] = per_player_formation[player_idx]

    per_player_coverage = compute_coverage_features(
        all_positions_current, player_sides, player_roles, ball_land_pos
    )
    features[28:34] = per_player_coverage[player_idx]

    # --- geometry relative to the ball landing spot -----------------------
    dx_ball = ball_land_pos[0] - player_pos[0]
    dy_ball = ball_land_pos[1] - player_pos[1]
    distance_to_ball = np.sqrt(dx_ball**2 + dy_ball**2)
    features[34] = distance_to_ball / 50.0
    features[35] = np.arctan2(dy_ball, dx_ball) / np.pi

    # --- opponents: nearest (36), mean distance (105), count within 5 (106)
    opponent_positions = all_positions_current[player_sides == (1 - player_side)]
    if len(opponent_positions) > 0:
        distances_to_opponents = np.sqrt(np.sum((opponent_positions - player_pos)**2, axis=1))
        features[36] = np.min(distances_to_opponents) / 20.0
        features[105] = np.mean(distances_to_opponents) / 20.0
        features[106] = np.sum(distances_to_opponents < 5.0) / 5.0
    else:
        features[36] = 1.0

    # Players within 5 of this one; the -1 removes the player itself
    all_distances = np.sqrt(np.sum((all_positions_current - player_pos)**2, axis=1))
    features[37] = (np.count_nonzero(all_distances < 5.0) - 1) / 10.0

    # --- position on the field --------------------------------------------
    x_fraction = player_pos[0] / Config.FIELD_X_MAX
    y_fraction = player_pos[1] / Config.FIELD_Y_MAX

    features[38] = min(player_pos[1], Config.FIELD_Y_MAX - player_pos[1]) / Config.FIELD_Y_MAX
    features[39] = min(player_pos[0], Config.FIELD_X_MAX - player_pos[0]) / Config.FIELD_X_MAX
    features[40] = 1.0 if player_role == 0 else 0.0
    if player_side == 0:
        features[41] = x_fraction
    else:
        features[41] = 1.0 - x_fraction

    features[42:57] = compute_player_interaction_features(
        all_positions_current, all_velocities_current, player_sides, player_roles, player_idx
    )

    features[57:69] = compute_trajectory_features(
        positions_history, velocities_history, accelerations_history, ball_land_pos, time_to_ball
    )

    features[69:83] = compute_zone_and_route_features(
        player_pos, ball_land_pos, all_positions_current, player_sides, player_roles, player_idx
    )

    features[83:96] = compute_physics_features(
        player_pos, player_vel, player_acc, player_weight, ball_land_pos, time_to_ball
    )

    features[96] = x_fraction
    features[97] = y_fraction
    features[98] = (Config.FIELD_X_MAX - player_pos[0]) / Config.FIELD_X_MAX
    features[99] = (Config.FIELD_Y_MAX - player_pos[1]) / Config.FIELD_Y_MAX

    features[100] = dx_ball / Config.FIELD_X_MAX
    features[101] = dy_ball / Config.FIELD_Y_MAX

    # Cosine between velocity and direction to the ball; left at 0 when either
    # vector is too short to give a stable direction
    vel_magnitude = np.linalg.norm(player_vel)
    if vel_magnitude > 0.1 and distance_to_ball > 0.1:
        direction_to_ball = np.array([dx_ball, dy_ball]) / distance_to_ball
        features[102] = np.dot(player_vel / vel_magnitude, direction_to_ball)

    features[103] = np.log1p(distance_to_ball) / 5.0
    features[104] = np.sqrt(distance_to_ball) / 10.0

    features[107] = 1.0 / (time_to_ball + 1.0)

    return features

# ================================================================================
# TRAINING
# ================================================================================

def _batch_to_model_inputs(batch, device):
    """Convert one prepared play dict into (model input dict, engineered-feature tensor)."""
    def batched_float(key):
        # Float tensor with a leading batch dimension of size 1
        return torch.FloatTensor(batch[key]).unsqueeze(0).to(device)

    batch_dict = {
        'positions': batched_float('positions'),
        'velocities': batched_float('velocities'),
        'accelerations': batched_float('accelerations'),
        'player_features': {
            'role': torch.LongTensor(batch['player_roles']).to(device),
            'side': torch.LongTensor(batch['player_sides']).to(device),
            'continuous': torch.FloatTensor(batch['player_continuous']).to(device)
        }
    }
    return batch_dict, batched_float('engineered_features')


def _horizon_weighted_mse(predictions, targets):
    """Sum the per-horizon MSE losses, each weighted by horizon / longest horizon."""
    longest_horizon = max(predictions.keys())
    loss = 0
    for horizon, pred in predictions.items():
        target_slice = targets[:, :, :horizon, :]
        loss = loss + (horizon / longest_horizon) * F.mse_loss(pred, target_slice)
    return loss


def _apply_optimizer_step(model, optimizer, scaler):
    """Clip gradients to norm 1.0, step the optimizer (via the scaler under AMP), zero the grads."""
    if scaler is not None:
        # Gradients must be unscaled before clipping so the threshold applies to true magnitudes
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
    optimizer.zero_grad()


def train_model_enhanced(model, train_data, val_data, epochs=30, config_name="default"):
    """Train `model` with gradient accumulation, optional CUDA AMP, LR decay on plateau and early stopping.

    Each element of `train_data` / `val_data` is a play dict providing the keys
    'positions', 'velocities', 'accelerations', 'player_roles', 'player_sides',
    'player_continuous', 'engineered_features' and 'targets'. Plays are fed to
    the model one at a time with a batch dimension of 1.

    Outside search mode the best-validation weights are written to
    'best_{config_name}_model.pt' and loaded back into the model before
    returning, so the returned model matches `best_val_loss` rather than
    whatever the last epoch produced.

    Args:
        model: module called as model(batch_dict, engineered_features=...) and
            returning a dict mapping horizon -> prediction tensor.
        train_data: iterable of play dicts used for optimisation.
        val_data: iterable of play dicts used for validation.
        epochs: maximum number of epochs.
        config_name: label used in log output and in the checkpoint file name.

    Returns:
        Tuple (model, best_val_loss, mean_train_loss, training_time), where
        mean_train_loss is from the last epoch run (NaN if no epoch ran) and
        training_time is wall-clock seconds.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nTraining {config_name}")
    print(f"Using device: {device}")
    
    model = model.to(device)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)
    
    # The file imports autocast/GradScaler from torch.cuda.amp, whose constructors do
    # not take a device string: GradScaler('cuda') would set init_scale='cuda' and
    # autocast('cuda') would set enabled='cuda'. Use the device-aware torch APIs instead.
    use_amp = bool(Config.USE_AMP) and torch.cuda.is_available()
    scaler = None
    if use_amp:
        if hasattr(torch.amp, 'GradScaler'):
            scaler = torch.amp.GradScaler('cuda')
        else:
            # Older PyTorch releases only ship the CUDA-specific scaler
            scaler = torch.cuda.amp.GradScaler()
    
    accum_steps = Config.GRADIENT_ACCUMULATION_STEPS
    
    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0
    patience = 3 if Config.SEARCH_MODE else Config.PATIENCE
    mean_train_loss = float('nan')  # stays NaN only when epochs == 0
    
    start_time = time.time()
    
    for epoch in range(epochs):
        model.train()
        train_losses = []
        optimizer.zero_grad()
        pending_grads = False
        
        pbar = tqdm(train_data, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
        for batch_idx, batch in enumerate(pbar):
            batch_dict, engineered_features = _batch_to_model_inputs(batch, device)
            targets = torch.FloatTensor(batch['targets']).unsqueeze(0).to(device)
            
            with torch.autocast(device_type=device.type, enabled=use_amp):
                predictions = model(batch_dict, engineered_features=engineered_features)
                loss = _horizon_weighted_mse(predictions, targets) / accum_steps
            
            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            pending_grads = True
            
            if (batch_idx + 1) % accum_steps == 0:
                _apply_optimizer_step(model, optimizer, scaler)
                pending_grads = False
            
            # Undo the accumulation scaling so the logged value is the per-play loss
            train_losses.append(loss.item() * accum_steps)
            pbar.set_postfix({'loss': f"{np.mean(train_losses):.4f}"})
        
        # Apply the trailing partial accumulation group instead of discarding it
        # when the number of plays is not a multiple of accum_steps.
        if pending_grads:
            _apply_optimizer_step(model, optimizer, scaler)
        
        model.eval()
        val_losses = []
        
        with torch.no_grad():
            for batch in val_data:
                batch_dict, engineered_features = _batch_to_model_inputs(batch, device)
                targets = torch.FloatTensor(batch['targets']).unsqueeze(0).to(device)
                
                with torch.autocast(device_type=device.type, enabled=use_amp):
                    predictions = model(batch_dict, engineered_features=engineered_features)
                    loss = _horizon_weighted_mse(predictions, targets)
                
                val_losses.append(loss.item())
        
        mean_train_loss = np.mean(train_losses)
        mean_val_loss = np.mean(val_losses)
        
        print(f"Epoch {epoch+1}/{epochs}: Train={mean_train_loss:.4f}, Val={mean_val_loss:.4f}")
        
        scheduler.step(mean_val_loss)
        
        if mean_val_loss < best_val_loss:
            best_val_loss = mean_val_loss
            patience_counter = 0
            if not Config.SEARCH_MODE:
                torch.save(model.state_dict(), f'best_{config_name}_model.pt')
                # In-memory copy so the best weights can be restored before returning
                best_state = {name: tensor.detach().clone()
                              for name, tensor in model.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
    
    training_time = time.time() - start_time
    
    if best_state is not None:
        model.load_state_dict(best_state)
    
    return model, best_val_loss, mean_train_loss, training_time

# ================================================================================
# DATA LOADING
# ================================================================================

def load_data():
    """Read the weekly training CSVs and the two test CSVs from Config.DATA_DIR.

    Weekly files are looked up for weeks 1 through 18 under the 'train'
    sub-directory; files that do not exist are skipped, independently for the
    input and the output series.

    Returns:
        Tuple (train_input, train_output, test_input, test_template) of
        DataFrames; the two training frames are concatenations of the weekly
        files with a fresh index.
    """
    print("Loading data...")
    weeks = range(1, 19)

    input_paths = []
    output_paths = []
    for w in weeks:
        input_paths.append(Config.DATA_DIR / f"train/input_2023_w{w:02d}.csv")
        output_paths.append(Config.DATA_DIR / f"train/output_2023_w{w:02d}.csv")

    # Keep only the weeks that are actually present on disk
    input_paths = [path for path in input_paths if path.exists()]
    output_paths = [path for path in output_paths if path.exists()]

    weekly_inputs = [pd.read_csv(path) for path in input_paths]
    train_input = pd.concat(weekly_inputs, ignore_index=True)

    weekly_outputs = [pd.read_csv(path) for path in output_paths]
    train_output = pd.concat(weekly_outputs, ignore_index=True)

    test_input = pd.read_csv(Config.DATA_DIR / "test_input.csv")
    test_template = pd.read_csv(Config.DATA_DIR / "test.csv")

    return train_input, train_output, test_input, test_template

def _parse_height_inches(raw_height, default=70):
    """Convert a 'feet-inches' string such as '6-2' to inches; return `default` if it cannot be parsed."""
    try:
        parts = raw_height.split('-')
        return float(parts[0]) * 12 + float(parts[1])
    except (AttributeError, IndexError, TypeError, ValueError):
        # Non-string values (e.g. NaN), a missing inches part, or non-numeric text
        return default


def prepare_graph_data(input_df, output_df=None, window_size=8, is_training=True):
    """Turn tracking rows into one sample dict per (game_id, play_id).

    For every play, the last `window_size` frames of each player are collected
    (front-padded when a player has fewer frames), engineered features are
    computed for every frame/player pair, and, when training with an
    `output_df`, targets are built as future positions relative to the
    player's last input position.

    Args:
        input_df: DataFrame of input tracking rows; read columns include
            game_id, play_id, nfl_id, frame_id, x, y, s, a, o, dir,
            ball_land_x, ball_land_y, player_role, player_side, player_height,
            player_weight, absolute_yardline_number and (optionally)
            num_frames_output.
        output_df: DataFrame of future positions (game_id, play_id, nfl_id,
            frame_id, x, y), or None.
        window_size: number of trailing input frames kept per player.
        is_training: when True and `output_df` is given, targets are built.

    Returns:
        List of dicts with keys 'game_id', 'play_id', 'player_ids',
        'positions' (window, players, 2), 'velocities' (window, players, 2),
        'accelerations' (window, players, 1), 'engineered_features'
        (window, players, Config.ENGINEERED_FEATURES_DIM), 'player_roles',
        'player_sides', 'player_continuous' (players, 8), 'targets'
        ((players, Config.PREDICTION_HORIZON, 2) or None), 'ball_land_x' and
        'ball_land_y'.
    """
    role_map = {'Targeted Receiver': 0, 'Other Route Runner': 1, 'Defensive Coverage': 2, 'Passer': 3}
    data_list = []
    
    # Group the outputs once up front; filtering output_df with a boolean mask
    # inside the play loop rescans the whole frame for every play.
    output_by_play = None
    if is_training and output_df is not None:
        output_by_play = {key: group for key, group in output_df.groupby(['game_id', 'play_id'])}
    
    for (game_id, play_id), play_group in tqdm(input_df.groupby(['game_id', 'play_id']), desc="Processing"):
        players = play_group['nfl_id'].unique()
        num_players = len(players)
        
        if num_players == 0:
            continue
        
        ball_land_x = play_group['ball_land_x'].iloc[0]
        ball_land_y = play_group['ball_land_y'].iloc[0]
        
        # Without a landing spot, fall back to the mean tracked position of the play
        if pd.isna(ball_land_x) or pd.isna(ball_land_y):
            ball_land_x = play_group['x'].mean()
            ball_land_y = play_group['y'].mean()
        
        positions = np.zeros((window_size, num_players, 2))
        velocities = np.zeros((window_size, num_players, 2))
        accelerations = np.zeros((window_size, num_players, 1))
        engineered_features = np.zeros((window_size, num_players, Config.ENGINEERED_FEATURES_DIM))
        player_roles = np.zeros(num_players, dtype=np.int64)
        player_sides = np.zeros(num_players, dtype=np.int64)
        player_continuous = np.zeros((num_players, 8))
        player_weights = np.zeros(num_players)
        
        for p_idx, nfl_id in enumerate(players):
            player_data = play_group[play_group['nfl_id'] == nfl_id].sort_values('frame_id').tail(window_size)
            
            # Front-pad short histories with NaN rows so the real frames stay at the end
            if len(player_data) < window_size:
                pad_df = pd.DataFrame(np.nan, index=range(window_size - len(player_data)), columns=player_data.columns)
                player_data = pd.concat([pad_df, player_data], ignore_index=True)
            
            # Padded positions take the first real position; remaining gaps become 0
            positions[:, p_idx, 0] = player_data['x'].bfill().fillna(0).values
            positions[:, p_idx, 1] = player_data['y'].bfill().fillna(0).values
            
            # Speed and direction (degrees) -> velocity components; padded frames get 0
            dir_rad = np.deg2rad(player_data['dir'].fillna(0))
            velocities[:, p_idx, 0] = (player_data['s'] * np.cos(dir_rad)).fillna(0).values
            velocities[:, p_idx, 1] = (player_data['s'] * np.sin(dir_rad)).fillna(0).values
            accelerations[:, p_idx, 0] = player_data['a'].fillna(0).values
            
            last_frame = player_data.iloc[-1]
            player_roles[p_idx] = role_map.get(last_frame['player_role'], 4)
            player_sides[p_idx] = 0 if last_frame['player_side'] == 'Offense' else 1
            
            height = _parse_height_inches(last_frame['player_height'])
            
            weight = last_frame['player_weight'] if pd.notna(last_frame['player_weight']) else 200
            player_weights[p_idx] = weight
            
            player_continuous[p_idx] = [
                height,
                weight,
                last_frame['s'] if pd.notna(last_frame['s']) else 0,
                last_frame['a'] if pd.notna(last_frame['a']) else 0,
                last_frame['o'] if pd.notna(last_frame['o']) else 0,
                last_frame['dir'] if pd.notna(last_frame['dir']) else 0,
                last_frame['absolute_yardline_number'] if pd.notna(last_frame['absolute_yardline_number']) else 50,
                0
            ]
        
        # Default to 25 frames when the column is absent or its value is missing,
        # so a NaN cannot leak into every time-dependent feature.
        num_frames_output = 25
        if 'num_frames_output' in play_group.columns:
            raw_frames = play_group['num_frames_output'].iloc[0]
            if pd.notna(raw_frames):
                num_frames_output = raw_frames
        time_to_ball = num_frames_output / 10.0
        
        ball_land_pos = np.array([ball_land_x, ball_land_y])
        
        for t in range(window_size):
            # Features at frame t only see frames 0..t
            positions_window = positions[:t+1]
            velocities_window = velocities[:t+1]
            accelerations_window = accelerations[:t+1, :, 0]
            
            all_positions_current = positions[t]
            all_velocities_current = velocities[t]
            
            for p_idx in range(num_players):
                engineered_features[t, p_idx] = compute_comprehensive_features(
                    p_idx, positions_window, velocities_window, accelerations_window,
                    player_sides, player_roles, ball_land_pos, all_positions_current,
                    all_velocities_current, player_weights[p_idx], time_to_ball
                )
        
        targets = None
        if output_by_play is not None:
            targets = np.zeros((num_players, Config.PREDICTION_HORIZON, 2))
            play_output = output_by_play.get((game_id, play_id))
            
            if play_output is not None:
                for p_idx, nfl_id in enumerate(players):
                    player_output = play_output[play_output['nfl_id'] == nfl_id].sort_values('frame_id')
                    if len(player_output) > 0:
                        horizon = min(len(player_output), Config.PREDICTION_HORIZON)
                        # Targets are displacements from the last input position
                        targets[p_idx, :horizon, 0] = player_output['x'].values[:horizon] - positions[-1, p_idx, 0]
                        targets[p_idx, :horizon, 1] = player_output['y'].values[:horizon] - positions[-1, p_idx, 1]
        
        data_list.append({
            'game_id': game_id,
            'play_id': play_id,
            'player_ids': players,
            'positions': positions,
            'velocities': velocities,
            'accelerations': accelerations,
            'engineered_features': engineered_features,
            'player_roles': player_roles,
            'player_sides': player_sides,
            'player_continuous': player_continuous,
            'targets': targets,
            'ball_land_x': ball_land_x,
            'ball_land_y': ball_land_y
        })
    
    return data_list

# ================================================================================
# FIXED PREDICTION FUNCTION
# ================================================================================

def make_predictions(model, test_data, test_template):
    """Run inference on prepared test plays and build the submission rows.

    For every player of every play, one row is produced per frame listed for
    that (game_id, play_id, nfl_id) in `test_template`, in ascending frame_id
    order. Frames inside the prediction horizon use the model's predicted
    displacement added to the last input position; later frames continue from
    the final predicted position with the last input velocity at 0.1 per
    frame. Coordinates are clipped to the field limits in Config.

    Args:
        model: module called as model(batch_dict, target_horizon=...,
            engineered_features=...) and returning a dict keyed by horizon.
        test_data: iterable of play dicts as produced by prepare_graph_data.
        test_template: DataFrame with game_id, play_id, nfl_id and frame_id.

    Returns:
        DataFrame with columns 'id' ("game_play_nfl_frame"), 'x' and 'y'.
        Template rows whose player does not appear in `test_data` get no row.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()
    
    # autocast is imported from torch.cuda.amp at file level, whose first positional
    # parameter is `enabled`, not a device string; use the device-aware torch.autocast.
    use_amp = bool(Config.USE_AMP) and torch.cuda.is_available()
    horizon = Config.PREDICTION_HORIZON
    
    # Index the template once; filtering it with a three-column mask per player
    # rescans the whole frame for every player of every play.
    frames_by_player = {
        key: group['frame_id'].sort_values().to_numpy()
        for key, group in test_template.groupby(['game_id', 'play_id', 'nfl_id'])
    }
    
    all_predictions = []
    
    with torch.no_grad():
        for batch in tqdm(test_data, desc="Inference"):
            batch_dict = {
                'positions': torch.FloatTensor(batch['positions']).unsqueeze(0).to(device),
                'velocities': torch.FloatTensor(batch['velocities']).unsqueeze(0).to(device),
                'accelerations': torch.FloatTensor(batch['accelerations']).unsqueeze(0).to(device),
                'player_features': {
                    'role': torch.LongTensor(batch['player_roles']).to(device),
                    'side': torch.LongTensor(batch['player_sides']).to(device),
                    'continuous': torch.FloatTensor(batch['player_continuous']).to(device)
                }
            }
            
            engineered_features = torch.FloatTensor(batch['engineered_features']).unsqueeze(0).to(device)
            
            with torch.autocast(device_type=device.type, enabled=use_amp):
                predictions = model(batch_dict, target_horizon=horizon,
                                    engineered_features=engineered_features)
            
            # Cast to float32 before leaving torch: autocast may yield a reduced-precision tensor
            pred_array = predictions[horizon].float().cpu().numpy()[0]
            last_positions = batch['positions'][-1]
            last_velocities = batch['velocities'][-1]
            
            for p_idx, nfl_id in enumerate(batch['player_ids']):
                frame_ids = frames_by_player.get((batch['game_id'], batch['play_id'], nfl_id), ())
                
                for idx, frame_id in enumerate(frame_ids):
                    if idx < horizon:
                        # Use model predictions
                        pred_x = last_positions[p_idx, 0] + pred_array[p_idx, idx, 0]
                        pred_y = last_positions[p_idx, 1] + pred_array[p_idx, idx, 1]
                    else:
                        # Extrapolate beyond prediction horizon using constant velocity
                        frames_beyond = idx - horizon + 1
                        pred_x = (last_positions[p_idx, 0] + pred_array[p_idx, -1, 0] + 
                                 last_velocities[p_idx, 0] * frames_beyond * 0.1)
                        pred_y = (last_positions[p_idx, 1] + pred_array[p_idx, -1, 1] + 
                                 last_velocities[p_idx, 1] * frames_beyond * 0.1)
                    
                    # Clip to field boundaries
                    pred_x = float(np.clip(pred_x, Config.FIELD_X_MIN, Config.FIELD_X_MAX))
                    pred_y = float(np.clip(pred_y, Config.FIELD_Y_MIN, Config.FIELD_Y_MAX))
                    
                    all_predictions.append({
                        'id': f"{int(batch['game_id'])}_{int(batch['play_id'])}_{int(nfl_id)}_{int(frame_id)}",
                        'x': pred_x,
                        'y': pred_y
                    })
    
    return pd.DataFrame(all_predictions)

# ================================================================================
# CONFIG SEARCH
# ================================================================================

def run_config_search(train_graph_data, val_graph_data):
    """Train every candidate in ConfigSearch.CONFIGS_TO_TEST briefly and report the outcome.

    Switches Config.SEARCH_MODE on, subsamples both datasets to
    ConfigSearch.SEARCH_DATA_FRACTION (fixed random_state), trains a fresh
    model per candidate for ConfigSearch.SEARCH_EPOCHS epochs, writes all
    results to 'config_search_results.json', and returns whatever
    ConfigSearch.get_results_summary() returns.
    """
    Config.SEARCH_MODE = True
    
    from sklearn.model_selection import train_test_split

    def take_fraction(samples):
        # Keep a reproducible SEARCH_DATA_FRACTION share of the samples
        kept, _ = train_test_split(samples,
                                   train_size=ConfigSearch.SEARCH_DATA_FRACTION,
                                   random_state=42)
        return kept

    search_train = take_fraction(train_graph_data)
    search_val = take_fraction(val_graph_data)
    
    candidates = ConfigSearch.CONFIGS_TO_TEST
    total = len(candidates)
    results = []
    
    print("\n" + "="*80)
    print("STARTING CONFIGURATION SEARCH")
    print(f"Testing {total} configurations")
    print(f"Using {len(search_train)} training samples, {len(search_val)} validation samples")
    print(f"Training for {ConfigSearch.SEARCH_EPOCHS} epochs each")
    print("="*80 + "\n")
    
    for position, candidate in enumerate(candidates, start=1):
        label = candidate['name']
        
        print(f"\n{'='*80}")
        print(f"CONFIG {position}/{total}: {label}")
        print(f"{'='*80}")
        
        # The model reads its hyper-parameters from Config, so update it first
        Config.update_from_dict(candidate)
        
        model = EnhancedSpatioTemporalPredictor()
        num_params = sum(p.numel() for p in model.parameters())
        print(f"Parameters: {num_params:,}")
        
        trained_model, best_val_loss, final_train_loss, training_time = train_model_enhanced(
            model, search_train, search_val,
            epochs=ConfigSearch.SEARCH_EPOCHS,
            config_name=label
        )
        
        results.append({
            'config_name': label,
            'config': candidate,
            'best_val_loss': best_val_loss,
            'final_train_loss': final_train_loss,
            'num_parameters': num_params,
            'training_time': training_time
        })
        
        print(f"\nResults: Val={best_val_loss:.4f}, Train={final_train_loss:.4f}, Time={training_time:.1f}s")
        
        # Drop the references and release cached GPU memory before the next candidate
        del model, trained_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    with open('config_search_results.json', 'w') as f:
        json.dump(results, f, indent=2)
    
    return ConfigSearch.get_results_summary()

# ================================================================================
# MAIN
# ================================================================================

if __name__ == "__main__":
    # End-to-end run: load CSVs, build per-play samples, search configurations,
    # train the final model with the selected configuration, write submission.csv.
    train_input, train_output, test_input, test_template = load_data()
    
    train_graph_data = prepare_graph_data(
        train_input, output_df=train_output, window_size=Config.WINDOW_SIZE, is_training=True
    )
    test_graph_data = prepare_graph_data(
        test_input, window_size=Config.WINDOW_SIZE, is_training=False
    )
    
    from sklearn.model_selection import train_test_split
    train_data, val_data = train_test_split(train_graph_data, test_size=0.2, random_state=42)
    
    for label, samples in (("Total samples", train_graph_data),
                           ("Training samples", train_data),
                           ("Validation samples", val_data)):
        print(f"{label}: {len(samples)}")
    
    best_config = run_config_search(train_data, val_data)
    
    print("\n" + "="*80)
    print("TRAINING FINAL MODEL WITH BEST CONFIGURATION")
    print("="*80)
    
    # Leave search mode before building the final model from the selected configuration
    Config.SEARCH_MODE = False
    Config.update_from_dict(best_config['config'])
    
    final_model = EnhancedSpatioTemporalPredictor()
    final_param_count = sum(p.numel() for p in final_model.parameters())
    print(f"Model parameters: {final_param_count:,}")
    
    # Only the trained model is needed from the returned tuple
    trained_model = train_model_enhanced(
        final_model, train_data, val_data,
        epochs=Config.EPOCHS,
        config_name="final"
    )[0]
    
    submission = make_predictions(trained_model, test_graph_data, test_template)
    
    # Verify submission matches test_template
    expected_rows = len(test_template)
    actual_rows = len(submission)
    print(f"\nExpected rows: {expected_rows}")
    print(f"Actual rows: {actual_rows}")
    print(f"Match: {expected_rows == actual_rows}")
    
    submission.to_csv('submission.csv', index=False)
    print(f"Submission saved: {len(submission)} predictions")
    
    print("\n" + "="*80)
    print("COMPLETE!")
    print("="*80)