# Football Reception Prediction Analysis

This notebook uses a Graph Attention Network (GAT) to predict which player is most likely to receive the ball in a football match, and to inspect how individual defenders influence those predictions. It covers data loading, graph construction, model loading, and interactive visualization.

## Overview
- **Data Source**: PFF tracking data for game analysis
- **Model**: Graph Attention Network for reception prediction
- **Visualization**: Interactive pitch visualization

## 1. Import Required Libraries

In [None]:
# Standard library imports
import bz2
import json
import os
import pickle
import gc

# Data processing imports
import pandas as pd
import numpy as np

# PyTorch imports
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool, GENConv, SAGEConv, GATv2Conv
from torch_geometric.data import Data, DataLoader

# Scikit-learn imports
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import MinMaxScaler

# Custom module imports
import convert_tracking as ct
import plot_functions as pf
import create_graph as cg
import scale_graph as sg
import visualisation
from GNNs.custom_GAT import myGATv2Conv
import GNNs.model_training as mt
import GNNs.convert_data as cd
from GNNs.GNN import ReceptionPredictionGNN
from GNNs.GAT import GATReceptionPredictor

## 2. Configuration and Data Paths

In [None]:
# Configuration parameters
GAME_ID = 13335
DATA_DIR = 'Data'
XT_GRID_PATH = 'xT_grid.csv'
# NOTE(review): this path is opened in a later cell, *after* the chdir below, so
# (unless something else changes the working directory) it resolves to
# <launch dir>/Data/Data/graphs/... — the same holds for the 'Data/...' paths used
# in later cells. Confirm the on-disk layout really is nested.
MODEL_PATH = 'Data/graphs/gat_model2.pth'

# Remember the directory the notebook was launched from. Guarding on globals()
# keeps this cell idempotent: re-running it must not descend into Data/Data, and
# the xT grid must still be found relative to the launch directory.
if 'NOTEBOOK_ROOT' not in globals():
    NOTEBOOK_ROOT = os.getcwd()

# Load Expected Threat (xT) grid (relative to the launch directory)
xT_grid = pd.read_csv(os.path.join(NOTEBOOK_ROOT, XT_GRID_PATH), header=None)

# Change to data directory (absolute target, so repeated runs land in the same place)
os.chdir(os.path.join(NOTEBOOK_ROOT, DATA_DIR))

# Set up file paths
# NOTE(review): `filepath` is not referenced by any later cell shown here.
filepath = f'{GAME_ID}.jsonl.bz2'

print(f"Configuration loaded for Game ID: {GAME_ID}")
print(f"Data directory: {DATA_DIR}")
print(f"xT grid shape: {xT_grid.shape}")

## 3. Data Loading and Preprocessing

In [None]:
# Load game metadata. ct.get_metadata returns a 15-element tuple: team ids and
# names, whether the home team starts on the left, the two rosters, three
# lookup dicts per team, and the x/y pitch adjustments.
(home_team_id, away_team_id,
 home_team_name, away_team_name,
 home_team_start_left,
 rosters_for_game_home, rosters_for_game_away,
 roster_game_home_name_dict, roster_game_home_team_name_dict, roster_game_home_pos_dict,
 roster_game_away_name_dict, roster_game_away_team_name_dict, roster_game_away_pos_dict,
 pitch_x_adjustment, pitch_y_adjustment) = ct.get_metadata(GAME_ID)

# Quick sanity output so the reader can see which match is being analysed
print(f"Game: {home_team_name} vs {away_team_name}")
print(f"Home team starts left: {home_team_start_left}")

In [None]:
# Load the per-game tables: every table lives in
# Data/game_dataframes/<GAME_ID>/<GAME_ID>_<table>_df.csv
data_path = f'Data/game_dataframes/{GAME_ID}/'
csv_paths = {
    table: f'{data_path}{GAME_ID}_{table}_df.csv'
    for table in ('balls', 'events', 'players')
}

balls_df = pd.read_csv(csv_paths['balls'])
events_df = pd.read_csv(csv_paths['events'])
players_df = pd.read_csv(csv_paths['players'])

# Augment the ball tracking table with velocities (computed by convert_tracking)
balls_df = ct.calculate_ball_velocities(balls_df)

print("Data loaded successfully:")
for label, table_df in (
    ("Ball tracking points", balls_df),
    ("Events", events_df),
    ("Player tracking points", players_df),
):
    print(f"- {label}: {len(table_df)}")

## 4. Graph Creation and Processing

In [None]:
# Build one directed, normalised graph per event frame. Frames for which the
# builder returns None are skipped, so len(graphs) can be < number of events.
event_frames = events_df['frameNum'].values
n_event_frames = len(event_frames)
print(f"Creating graphs for {n_event_frames} event frames...")

graphs = []
for frame_idx, frame_num in enumerate(event_frames):
    if not frame_idx % 200:  # progress indicator every 200 frames
        print(f"Processing frame {frame_idx + 1}/{n_event_frames}")

    frame_graph = cg.create_normalized_graph_directed(
        players_df, balls_df, events_df, frame_num, home_team_name
    )
    if frame_graph is None:
        continue
    graphs.append(frame_graph)

print(f"Successfully created {len(graphs)} graphs")

In [None]:
# Load pre-scaled graphs from the cache; fall back to scaling the graphs built above.
scaled_graphs_path = f'Data/graphs_scaled_version/{GAME_ID}_graphs.pkl'

try:
    # NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
    with open(scaled_graphs_path, 'rb') as f:
        loaded_graphs = pickle.load(f)
except FileNotFoundError:
    print(f"Pre-scaled graphs not found at {scaled_graphs_path}")
    print("Creating scaled graphs...")

    # NOTE(review): this scaler is fit on this single game's graphs only; if the
    # model was trained on graphs scaled with a scaler fit on other data, the
    # feature scaling here may not match what the model expects — confirm.
    graph_scaler = sg.GraphFeatureScaler()
    graph_scaler.fit(graphs)

    loaded_graphs = list(map(graph_scaler.transform, graphs))
    print(f"Created {len(loaded_graphs)} scaled graphs")
else:
    print(f"Loaded {len(loaded_graphs)} pre-scaled graphs")

## 5. Model Definition and Loading

In [None]:
# NOTE(review): this notebook-local class shadows the GATReceptionPredictor imported
# from GNNs.GAT in the imports cell; the model-loading cell uses this definition once
# this cell has run. Keep the two in sync or remove one of them.
class GATReceptionPredictor(torch.nn.Module):
    """
    Graph Attention Network for predicting football reception probabilities.

    Node and edge features are linearly encoded, passed through two multi-head
    ``myGATv2Conv`` layers, and the GAT output is concatenated with the encoded
    input features before an MLP produces one sigmoid probability per node.

    In ``test_mode`` the attention weights on the edge ``mask_node -> target_node``
    are zeroed in both GAT layers, so the influence of a single (e.g. defending)
    player on another player's prediction can be measured.
    """

    def __init__(self, num_node_features, num_edge_features, hidden_channels, edge_hidden_channels, num_heads=16):
        """
        Args:
            num_node_features: Number of input features per node.
            num_edge_features: Number of input features per edge.
            hidden_channels: Width of the node embeddings. Each head outputs
                ``hidden_channels // num_heads`` channels and the MLP input width
                assumes the heads are concatenated back to ``hidden_channels``.
            edge_hidden_channels: Width of the encoded edge features given to the GAT layers.
            num_heads: Number of attention heads per GAT layer.

        Raises:
            ValueError: If ``hidden_channels`` is not a positive multiple of ``num_heads``
                (the layer widths would otherwise not line up at forward time).
        """
        super().__init__()
        if num_heads < 1 or hidden_channels < num_heads or hidden_channels % num_heads:
            raise ValueError(
                f"hidden_channels ({hidden_channels}) must be a positive multiple "
                f"of num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.hidden_channels = hidden_channels
        head_channels = hidden_channels // num_heads

        # Feature encoders
        self.node_encoder = torch.nn.Linear(num_node_features, hidden_channels)
        self.edge_encoder = torch.nn.Linear(num_edge_features, edge_hidden_channels)

        # GAT layers
        self.gat1 = myGATv2Conv(
            hidden_channels, head_channels, heads=num_heads,
            edge_dim=edge_hidden_channels, add_self_loops=True
        )
        self.gat2 = myGATv2Conv(
            hidden_channels, head_channels, heads=num_heads,
            edge_dim=edge_hidden_channels, add_self_loops=True
        )

        # Final prediction layer: input is [encoded features, GAT output]
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels * 2, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, 1)
        )

    @staticmethod
    def _mask_edge_attention(edge_index, attn_weights, source_node, target_node):
        """Zero (in place) the attention rows of every edge ``source_node -> target_node``.

        Args:
            edge_index: Edge index returned by the GAT layer; row 0 holds source
                nodes and row 1 target nodes.
            attn_weights: Attention weights aligned with the columns of ``edge_index``.
            source_node: Index of the node whose outgoing attention is removed.
            target_node: Index of the node receiving that attention.

        Returns:
            ``attn_weights`` (the same tensor, modified in place).
        """
        edge_mask = (edge_index[0] == source_node) & (edge_index[1] == target_node)
        attn_weights[edge_mask] = 0
        return attn_weights

    @staticmethod
    def _masked_gat_layer(conv, x, edge_index, edge_features, source_node, target_node):
        """Run ``conv`` with the ``source_node -> target_node`` attention removed.

        The layer is evaluated once to obtain its attention weights. The relevant
        edge is zeroed using the edge index *this* layer returned, and the layer is
        re-run with ``test=True`` and the modified weights.

        Returns:
            Whatever the second ``conv`` call returns: ``(x, (edge_index, alpha), pre_softmax)``.
        """
        _, (layer_edge_index, attn_weights), _ = conv(
            x, edge_index, edge_attr=edge_features, return_attention_weights=True
        )
        attn_weights = GATReceptionPredictor._mask_edge_attention(
            layer_edge_index, attn_weights, source_node, target_node
        )
        return conv(
            x, edge_index, edge_attr=edge_features, return_attention_weights=True,
            test=True, alpha_updated=attn_weights
        )

    def forward(self, x, edge_index, edge_attr, batch, test_mode=False,
                mask_node_name=None, target_node_name=None, graph=None):
        """
        Forward pass through the GAT model.

        Args:
            x: Node features.
            edge_index: Edge connectivity.
            edge_attr: Edge features.
            batch: Batch information (not used by this model; kept for interface compatibility).
            test_mode: If True, remove the attention from ``mask_node_name`` to
                ``target_node_name`` in both GAT layers (defender masking).
            mask_node_name: Name of the node whose attention is removed in test mode.
            target_node_name: Name of the node that receives that attention.
            graph: NetworkX graph used to map node names to tensor indices (test mode only).

        Returns:
            ``(probs, (attention_weights1, attention_weights2))``. ``probs`` is the
            sigmoid of the MLP output squeezed on the last dim. Each
            ``attention_weights*`` is the ``(edge_index, alpha)`` pair returned by
            the corresponding GAT layer.

        Raises:
            ValueError: If ``test_mode`` is set without ``graph``, or if a node name
                is not a node of ``graph``.
        """
        # Encode features
        x = self.node_encoder(x)
        x_original = x  # skip connection into the MLP
        edge_features = self.edge_encoder(edge_attr)

        mask_node = target_node = None
        if test_mode:
            if graph is None:
                raise ValueError("test_mode=True requires `graph` to map node names to indices")
            # assumes the i-th row of x corresponds to the i-th node of `graph` — TODO confirm
            graph_names = [name for name, _ in graph.nodes(data=True)]
            mask_node = graph_names.index(mask_node_name)
            target_node = graph_names.index(target_node_name)

        # First GAT layer
        if test_mode:
            x, attention_weights1, _ = self._masked_gat_layer(
                self.gat1, x, edge_index, edge_features, mask_node, target_node
            )
        else:
            x, attention_weights1, _ = self.gat1(
                x, edge_index, edge_attr=edge_features, return_attention_weights=True
            )
        x = F.relu(x)
        x = F.dropout(x, p=0.2, training=self.training)

        # Second GAT layer (masking uses this layer's own returned edge index)
        if test_mode:
            x, attention_weights2, _ = self._masked_gat_layer(
                self.gat2, x, edge_index, edge_features, mask_node, target_node
            )
        else:
            x, attention_weights2, _ = self.gat2(
                x, edge_index, edge_attr=edge_features, return_attention_weights=True
            )
        x = F.relu(x)
        x = F.dropout(x, p=0.2, training=self.training)

        # Final prediction
        x_combined = torch.cat([x_original, x], dim=1)
        reception_logits = self.mlp(x_combined)

        return torch.sigmoid(reception_logits).squeeze(-1), (attention_weights1, attention_weights2)

In [None]:
# Initialize and load the trained model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Model hyper-parameters — these should match the configuration MODEL_PATH was
# trained with; mismatched parameter shapes make load_state_dict raise.
MODEL_CONFIG = {
    'num_node_features': 15,
    'num_edge_features': 6,
    'hidden_channels': 32,
    'edge_hidden_channels': 16,
    'num_heads': 16
}

# Create and load model
loaded_model = GATReceptionPredictor(**MODEL_CONFIG).to(device)
# NOTE(review): torch.load unpickles the checkpoint — only load files from a trusted source.
model_state = torch.load(MODEL_PATH, map_location=device)
loaded_model.load_state_dict(model_state)
# Switch to inference mode: the model applies dropout whenever self.training is
# True, which would make the predictions shown in the visualization random.
loaded_model.eval()

print("Model loaded successfully!")
print(f"Model parameters: {sum(p.numel() for p in loaded_model.parameters())}")

## 6. Interactive Visualization

This section creates an interactive visualization that allows you to:
- Navigate through different game frames
- Analyze defender influences on reception probabilities
- Visualize attention weights between players
- Compare different game states

In [None]:
# Launch interactive visualization.
# Inputs: the loaded model, the unscaled graphs built in section 4, the scaled
# graphs (loaded from cache or created in section 4) and the xT grid.
# Presumably `graphs` and `loaded_graphs` must be index-aligned (entry i describing
# the same event frame in both) — TODO confirm against
# visualisation.create_simple_visualization.
# Kept as the cell's last expression so Jupyter displays whatever it returns.
visualisation.create_simple_visualization(loaded_model, graphs, loaded_graphs, xT_grid)