In [21]:
"""feature_ablation_study.py

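This script extends the original train.py to perform feature-group ablations:
1. Train with only ball trajectory features (no static features)
2. Train with only static features (no ball trajectory)
The combined (ball trajectory + static) model is kept as the baseline that the
two ablated models are compared against.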
"""

# Imports
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import ast
import matplotlib.pyplot as plt
import os
import logging
import pickle
import time
from datetime import datetime
from google.colab import drive

# Mount Google Drive so the dataset and output folders below are reachable.
drive.mount('/content/drive')

# Define paths - modify these to match your environment
data_path = '/content/drive/MyDrive/smai_project/dataset/dataset10/'
output_path = '/content/drive/MyDrive/smai_project/output_ablation_study/'
model_path = '/content/drive/MyDrive/smai_project/model_ablation_study/'

# Create output and model directories if they don't exist.
# exist_ok=True makes this a single idempotent call (no check-then-create race,
# safe to re-run the cell).
for path in [output_path, model_path]:
    os.makedirs(path, exist_ok=True)

# Set up logging: every record goes both to a log file in output_path and to
# the cell output.
log_file_path = os.path.join(output_path, 'ablation_study.log')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    handlers=[logging.FileHandler(log_file_path), logging.StreamHandler()],
    # basicConfig() is a silent no-op when the root logger already has handlers
    # (e.g. the cell is re-run, or the host environment pre-configured logging).
    # force=True (Python >= 3.8) removes and closes the existing handlers first
    # so this configuration always takes effect.
    force=True,
)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [22]:
# Feature Extraction Function (same as original)
def extract_sequence_features(df):
    """Add engineered per-shot features to ``df`` and return it.

    The frame is modified in place (new columns are assigned on the object
    that was passed in) and the same object is returned.

    Expected input columns:
        ball_seq     -- list of ``[x, y, z]`` ball positions
        players_seq  -- list of frames; each frame is a list of 5-element
                        player records ``[team_id, player_id, x, y, <unused>]``
        personId     -- id of the shooter, matched against ``player_id``
        basket_x, basket_y -- basket coordinates

    Columns written:
        shooter_x, shooter_y, initial_height, max_height, traj_length,
        shotDistance, traj_curvature, release_angle, defender_proximity,
        teammate_x, teammate_y, defender_x, defender_y,
        avg_vel_x/y/z, avg_acc_x/y/z, entry_angle, arc_height_ratio

    Only the first frame of ``players_seq`` is used for player-based features.
    Malformed or missing data falls back to zeros instead of raising.
    """
    def first_frame_of(players_seq):
        """Return the first player frame, or None when there is nothing usable."""
        if not players_seq or not isinstance(players_seq, list) or len(players_seq) == 0:
            return None
        # If the first element is not itself a list, the whole sequence is
        # treated as a single frame.
        frame = players_seq[0] if isinstance(players_seq[0], list) else players_seq
        if not frame or not isinstance(frame, list):
            return None
        return frame

    def infer_shooter_team(frame, shooter_x, shooter_y):
        """Team id of the player nearest to the shooter position, or None.

        The shooter's team is not looked up by id here; it is taken from
        whichever well-formed player record is closest (first one wins ties).
        """
        team_of_nearest = None
        nearest_distance = float('inf')
        for player in frame:
            if len(player) != 5:
                continue
            team_id, _, x, y, _ = player
            distance = np.sqrt((shooter_x - x)**2 + (shooter_y - y)**2)
            if distance < nearest_distance:
                nearest_distance = distance
                team_of_nearest = team_id
        return team_of_nearest

    def get_shooter_position(ball_seq, players_seq, personId):
        """Shooter (x, y): from the player record matching ``personId`` in the
        first frame, else from the first ball position, else (0, 0)."""
        shooter_x, shooter_y = None, None
        frame = first_frame_of(players_seq)
        if frame is not None:
            for player in frame:
                if len(player) != 5:
                    continue
                _, player_id, x, y, _ = player
                if player_id == personId:
                    shooter_x, shooter_y = x, y
                    break

        if shooter_x is None or shooter_y is None:
            # Fallback: use where the ball starts.
            if not ball_seq or not isinstance(ball_seq, list) or len(ball_seq) == 0:
                return 0, 0
            first_pos = ball_seq[0]
            if not isinstance(first_pos, list) or len(first_pos) != 3:
                return 0, 0
            shooter_x, shooter_y, _ = first_pos
        return shooter_x, shooter_y

    df['shooter_x'], df['shooter_y'] = zip(*df.apply(
        lambda row: get_shooter_position(row['ball_seq'], row['players_seq'], row['personId']), axis=1
    ))

    df['initial_height'] = df['ball_seq'].apply(lambda seq: seq[0][2] if len(seq) > 0 and len(seq[0]) == 3 else 0)
    df['max_height'] = df['ball_seq'].apply(lambda seq: max([point[2] for point in seq]) if len(seq) > 0 and all(len(point) == 3 for point in seq) else 0)
    df['traj_length'] = df['ball_seq'].apply(len)
    df['shotDistance'] = df.apply(
        lambda row: np.sqrt((row['shooter_x'] - row['basket_x'])**2 + (row['shooter_y'] - row['basket_y'])**2), axis=1
    )
    # A zero distance is replaced by 1 so the division below never divides by zero.
    df['shotDistance'] = df['shotDistance'].replace(0, 1)
    df['traj_curvature'] = df['max_height'] / df['shotDistance']

    # Angle (radians) of the basket->shooter vector in the x/y plane.
    df['release_angle'] = df.apply(
        lambda row: np.arctan2(row['shooter_y'] - row['basket_y'], row['shooter_x'] - row['basket_x']), axis=1
    )

    def calculate_defender_proximity(shooter_x, shooter_y, players_seq):
        """Distance from the shooter to the nearest player of the other team
        in the first frame; 0 when it cannot be determined."""
        frame = first_frame_of(players_seq)
        if frame is None:
            return 0

        shooter_team_id = infer_shooter_team(frame, shooter_x, shooter_y)
        if shooter_team_id is None:
            logging.warning(f"Could not determine shooter team for position ({shooter_x}, {shooter_y})")
            return 0

        min_distance = float('inf')
        for player in frame:
            if len(player) != 5:
                continue
            team_id, _, x, y, _ = player
            if team_id == shooter_team_id:
                continue
            distance = np.sqrt((shooter_x - x)**2 + (shooter_y - y)**2)
            min_distance = min(min_distance, distance)
        return min_distance if min_distance != float('inf') else 0

    df['defender_proximity'] = df.apply(
        lambda row: calculate_defender_proximity(row['shooter_x'], row['shooter_y'], row['players_seq']), axis=1
    )

    def extract_player_positions(players_seq, shooter_x, shooter_y):
        """Return ``[teammate_x, teammate_y, defender_x, defender_y]`` for the
        closest teammate and closest defender in the first frame.

        Always returns a 4-element list; missing values are 0.
        """
        frame = first_frame_of(players_seq)
        if frame is None:
            return [0, 0, 0, 0]

        shooter_team_id = infer_shooter_team(frame, shooter_x, shooter_y)
        if shooter_team_id is None:
            logging.warning(f"Could not determine shooter team for position ({shooter_x}, {shooter_y})")
            return [0, 0, 0, 0]

        teammates = []
        defenders = []
        for player in frame:
            if len(player) != 5:
                continue
            team_id, _, x, y, _ = player
            if team_id == shooter_team_id:
                teammates.append([x, y])
            else:
                defenders.append([x, y])

        closest_teammate = [0, 0]
        if teammates:
            teammate_distances = sorted(
                (np.sqrt((shooter_x - tx)**2 + (shooter_y - ty)**2), tx, ty) for tx, ty in teammates
            )
            # The nearest same-team record is normally the shooter, so the
            # second-nearest is used when there is one; with a single record
            # that record itself is used.
            nearest = teammate_distances[1] if len(teammate_distances) > 1 else teammate_distances[0]
            # Build a list explicitly: the previous tuple slice made the
            # list concatenation below raise TypeError in the single-record case.
            closest_teammate = [nearest[1], nearest[2]]

        closest_defender = [0, 0]
        if defenders:
            defender_distances = sorted(
                (np.sqrt((shooter_x - dx)**2 + (shooter_y - dy)**2), dx, dy) for dx, dy in defenders
            )
            closest_defender = [defender_distances[0][1], defender_distances[0][2]]

        return closest_teammate + closest_defender

    df[['teammate_x', 'teammate_y', 'defender_x', 'defender_y']] = df.apply(
        lambda row: extract_player_positions(row['players_seq'], row['shooter_x'], row['shooter_y']), axis=1, result_type='expand'
    )

    def extract_ball_dynamics(ball_seq):
        """Return mean per-step velocity and acceleration of the ball as
        ``[vx, vy, vz, ax, ay, az]`` (zeros when there are too few points)."""
        if not ball_seq or len(ball_seq) < 3:
            return [0, 0, 0, 0, 0, 0]

        # First differences between consecutive well-formed points.
        velocities = []
        for i in range(1, len(ball_seq)):
            if len(ball_seq[i]) != 3 or len(ball_seq[i-1]) != 3:
                continue
            velocities.append([ball_seq[i][axis] - ball_seq[i-1][axis] for axis in range(3)])

        if not velocities:
            return [0, 0, 0, 0, 0, 0]

        # Differences between consecutive entries of the velocity list.
        accelerations = [
            [velocities[i][axis] - velocities[i-1][axis] for axis in range(3)]
            for i in range(1, len(velocities))
        ]

        avg_velocity = [np.mean([v[axis] for v in velocities]) for axis in range(3)]
        if not accelerations:
            return avg_velocity + [0, 0, 0]

        avg_acceleration = [np.mean([a[axis] for a in accelerations]) for axis in range(3)]
        return avg_velocity + avg_acceleration

    df[['avg_vel_x', 'avg_vel_y', 'avg_vel_z', 'avg_acc_x', 'avg_acc_y', 'avg_acc_z']] = df['ball_seq'].apply(
        lambda seq: extract_ball_dynamics(seq)
    ).apply(pd.Series)

    def extract_shot_arc(ball_seq, basket_x, basket_y):
        """Return ``[entry_angle, arc_height_ratio]``.

        entry_angle is the elevation angle in degrees of the last segment of
        the trajectory (90 when that segment has no horizontal extent);
        arc_height_ratio is the peak z divided by the horizontal distance from
        the first ball position to the basket. Any failure yields ``[0, 0]``.
        """
        if not ball_seq or len(ball_seq) < 3:
            return [0, 0]

        try:
            # Tuple-unpacking (rather than indexing) is deliberate: a point that
            # does not have exactly 3 components raises here and is reported below.
            release_x, release_y, _ = ball_seq[0]
            peak_idx = max(range(len(ball_seq)), key=lambda i: ball_seq[i][2] if len(ball_seq[i]) == 3 else 0)
            _, _, peak_z = ball_seq[peak_idx]

            # Last well-formed point, searching backwards but never reaching index 0.
            last_valid_idx = -1
            for i in range(len(ball_seq)-1, 0, -1):
                if len(ball_seq[i]) == 3:
                    last_valid_idx = i
                    break

            if last_valid_idx <= 0:
                return [0, 0]

            p2 = ball_seq[last_valid_idx]
            p1 = ball_seq[max(0, last_valid_idx-1)]

            if len(p1) != 3 or len(p2) != 3:
                return [0, 0]

            horizontal_dist = np.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2)
            entry_angle = 90 if horizontal_dist == 0 else np.degrees(np.arctan2(p2[2] - p1[2], horizontal_dist))

            total_dist = np.sqrt((basket_x - release_x)**2 + (basket_y - release_y)**2)
            arc_height_ratio = 0 if total_dist == 0 else peak_z / total_dist

            return [entry_angle, arc_height_ratio]
        except Exception as e:
            # Best-effort feature: log and fall back to zeros rather than abort
            # the whole extraction for one bad row.
            logging.error(f"Error extracting shot arc: {e}")
            return [0, 0]

    df[['entry_angle', 'arc_height_ratio']] = df.apply(
        lambda row: extract_shot_arc(row['ball_seq'], row['basket_x'], row['basket_y']),
        axis=1, result_type='expand'
    )

    return df

# Sequence Processing Function (same as original)
def process_ball_sequences(df, seq_len=37):
    """Turn each row's ball trajectory into a fixed-length sequence for the LSTM.

    Args:
        df: DataFrame with a ``ball_seq`` column (list of ``[x, y, z]`` points)
            and a ``shotResult`` column (the label). ``df`` is not modified.
        seq_len: number of time steps per output sequence (default 37, the
            value this function previously hard-coded).

    Returns:
        ``(sequences, labels)`` as NumPy arrays. For a non-empty frame
        ``sequences`` has shape ``(len(df), seq_len, 3)``.

    Longer trajectories are truncated to the first ``seq_len`` points; shorter
    ones are padded by repeating their last point (``[0, 0, 0]`` if empty).
    Any point that is not a 3-element list is replaced by ``[0, 0, 0]``.
    """
    sequences = []
    labels = []

    for _, row in df.iterrows():
        # Slicing yields a new list, so padding below never touches the list
        # object stored in the DataFrame (appending to it in place used to
        # lengthen the caller's data as a side effect).
        ball_seq = list(row['ball_seq'][:seq_len])

        missing = seq_len - len(ball_seq)
        if missing > 0:
            pad_point = ball_seq[-1] if ball_seq else [0, 0, 0]
            ball_seq.extend([pad_point] * missing)

        sequences.append([
            moment if isinstance(moment, list) and len(moment) == 3 else [0, 0, 0]
            for moment in ball_seq
        ])
        labels.append(row['shotResult'])

    return np.array(sequences), np.array(labels)

# PyTorch Dataset modified to handle ablation studies
class ShotDataset(Dataset):
    """Dataset of basketball shots that can expose either or both feature groups.

    Depending on which inputs are supplied, each item is:
        both given            -> (ball_sequence, static_features, label)
        ball_sequences only   -> (ball_sequence, label)
        static_features only  -> (static_features, label)

    All inputs are converted to ``float32`` tensors.

    Raises:
        ValueError: if neither ``ball_sequences`` nor ``static_features`` is
            given, or if a supplied input does not have one entry per label.
    """
    def __init__(self, ball_sequences=None, static_features=None, labels=None):
        # Fail at construction time: previously a dataset with no inputs was
        # created successfully and only broke on the first __getitem__ call.
        if ball_sequences is None and static_features is None:
            raise ValueError("ShotDataset requires ball_sequences, static_features, or both")

        # For only-ball scenario, static_features is None
        # For only-static scenario, ball_sequences is None
        self.ball_sequences = None if ball_sequences is None else torch.tensor(ball_sequences, dtype=torch.float32)
        self.static_features = None if static_features is None else torch.tensor(static_features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

        # Flags read by __getitem__ to decide the tuple layout.
        self.use_ball = self.ball_sequences is not None
        self.use_static = self.static_features is not None

        # Every sample needs a label and vice versa.
        n_labels = len(self.labels)
        for input_name, tensor in (("ball_sequences", self.ball_sequences),
                                   ("static_features", self.static_features)):
            if tensor is not None and len(tensor) != n_labels:
                raise ValueError(
                    f"{input_name} has {len(tensor)} samples but labels has {n_labels}"
                )

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` in the layout described on the class."""
        if self.use_ball and self.use_static:
            return self.ball_sequences[idx], self.static_features[idx], self.labels[idx]
        elif self.use_ball:
            return self.ball_sequences[idx], self.labels[idx]
        else:  # use_static only
            return self.static_features[idx], self.labels[idx]

In [23]:
# Define PyTorch Models for different ablation scenarios
class BallOnlyPredictor(nn.Module):
    """Shot-outcome classifier that sees nothing but the ball trajectory.

    Two stacked LSTMs encode the sequence; the encoding at the final time step
    goes through a small dense head ending in a sigmoid, so the output is one
    probability per sample with shape ``(batch, 1)``.
    """
    def __init__(self, ball_input_size=3, hidden_size=64):
        super().__init__()
        second_lstm_width = hidden_size // 2

        # Sequence encoder: LSTM -> dropout -> narrower LSTM -> dropout.
        self.lstm1 = nn.LSTM(ball_input_size, hidden_size, batch_first=True)
        self.dropout1 = nn.Dropout(0.2)
        self.lstm2 = nn.LSTM(hidden_size, second_lstm_width, batch_first=True)
        self.dropout2 = nn.Dropout(0.2)

        # Classification head.
        self.fc = nn.Linear(second_lstm_width, 32)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.fc_out = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, ball_seq):
        """Map ``ball_seq`` (batch-first sequence tensor) to probabilities."""
        encoded, _ = self.lstm1(ball_seq)
        encoded, _ = self.lstm2(self.dropout1(encoded))
        # Dropout is applied to the whole sequence first, then only the final
        # time step is kept as the trajectory summary.
        final_step = self.dropout2(encoded)[:, -1, :]

        hidden = self.dropout(self.relu(self.fc(final_step)))
        return self.sigmoid(self.fc_out(hidden))

class StaticOnlyPredictor(nn.Module):
    """Shot-outcome classifier that sees nothing but the static feature vector.

    A two-hidden-layer MLP (64 then 32 units, ReLU + dropout after each)
    followed by a sigmoid; output shape is ``(batch, 1)``.
    """
    def __init__(self, static_input_size=20):
        super().__init__()
        # First hidden layer.
        self.fc1 = nn.Linear(static_input_size, 64)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.2)

        # Second hidden layer.
        self.fc2 = nn.Linear(64, 32)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.2)

        # Probability output.
        self.fc_out = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, static_features):
        """Map a batch of static feature vectors to probabilities."""
        hidden = self.dropout1(self.relu1(self.fc1(static_features)))
        hidden = self.dropout2(self.relu2(self.fc2(hidden)))
        return self.sigmoid(self.fc_out(hidden))

class CombinedPredictor(nn.Module):
    """Baseline classifier that fuses the ball trajectory with static features.

    The trajectory branch (two stacked LSTMs, final time step) and the static
    branch (one dense layer) are concatenated and passed through a dense head
    ending in a sigmoid; output shape is ``(batch, 1)``.
    """
    def __init__(self, ball_input_size=3, static_input_size=20, hidden_size=64):
        super().__init__()
        trajectory_width = hidden_size // 2
        static_width = 32

        # Trajectory branch.
        self.lstm1 = nn.LSTM(ball_input_size, hidden_size, batch_first=True)
        self.dropout1 = nn.Dropout(0.2)
        self.lstm2 = nn.LSTM(hidden_size, trajectory_width, batch_first=True)
        self.dropout2 = nn.Dropout(0.2)

        # Static-feature branch.
        self.static_fc1 = nn.Linear(static_input_size, static_width)
        self.static_relu = nn.ReLU()
        self.static_dropout = nn.Dropout(0.2)

        # Fusion head over the concatenated branch outputs.
        self.fc_combined = nn.Linear(trajectory_width + static_width, 32)
        self.relu_combined = nn.ReLU()
        self.dropout_combined = nn.Dropout(0.2)
        self.fc_out = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, ball_seq, static_features):
        """Map a (trajectory, static vector) batch to probabilities."""
        # Trajectory branch: keep only the last time step after dropout.
        seq_out, _ = self.lstm1(ball_seq)
        seq_out, _ = self.lstm2(self.dropout1(seq_out))
        trajectory_vec = self.dropout2(seq_out)[:, -1, :]

        # Static branch.
        static_vec = self.static_dropout(self.static_relu(self.static_fc1(static_features)))

        # Fuse along the feature dimension and classify.
        fused = torch.cat((trajectory_vec, static_vec), dim=1)
        hidden = self.dropout_combined(self.relu_combined(self.fc_combined(fused)))
        return self.sigmoid(self.fc_out(hidden))

In [24]:
# Training function (with its private helpers)
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch; without one the model is put in eval mode and gradients are
    disabled. Each batch is a sequence of tensors whose last element is the
    labels and whose leading elements are passed to the model positionally.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(is_training):
        for batch in loader:
            *inputs, labels = batch
            inputs = [tensor.to(device) for tensor in inputs]
            labels = labels.to(device)

            if is_training:
                optimizer.zero_grad()

            # Drop only the trailing size-1 dimension of the model output.
            # A bare .squeeze() would also remove the batch dimension when a
            # batch holds a single sample, and the loss would then reject the
            # shape mismatch with the labels.
            outputs = model(*inputs).squeeze(-1)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # criterion returns the batch mean, so weight it by batch size to
            # get a correct per-sample average over the epoch.
            loss_sum += loss.item() * batch_size
            predicted = (outputs > 0.5).float()
            correct += (predicted == labels).sum().item()
            total += batch_size

    return loss_sum / total, correct / total


def _save_metric_plot(model_type, metric, train_values, val_values):
    """Save a train-vs-validation curve for ``metric`` as a PNG in ``output_path``."""
    plt.figure(figsize=(10, 5))
    plt.plot(train_values, label=f'Training {metric}')
    plt.plot(val_values, label=f'Validation {metric}')
    plt.title(f'{model_type.replace("_", " ").title()} Model: {metric} Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel(metric)
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_path, f'{model_type}_{metric.lower()}_over_epochs.png'))
    plt.close()


def train_model(model_type, train_loader, val_loader, device, static_features, num_epochs=20):
    """Train the model selected by ``model_type`` and record its learning curves.

    Args:
        model_type: ``"ball_only"``, ``"static_only"``; any other value selects
            the combined model.
        train_loader, val_loader: loaders whose batches match the chosen model
            (inputs first, labels last).
        device: torch device the model and batches are moved to.
        static_features: sized collection of static features; only its length
            is used, as the static input width of the model.
        num_epochs: number of training epochs.

    Returns:
        ``(model, final_train_accuracy, final_val_accuracy)`` where the
        accuracies are those of the last epoch.

    Side effects: logs per-epoch metrics, saves the model's ``state_dict``
    under ``model_path`` with a timestamped name, and writes accuracy and loss
    plots to ``output_path``.
    """
    if model_type == "ball_only":
        model = BallOnlyPredictor().to(device)
    elif model_type == "static_only":
        model = StaticOnlyPredictor(static_input_size=len(static_features)).to(device)
    else:  # combined
        model = CombinedPredictor(ball_input_size=3, static_input_size=len(static_features)).to(device)

    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(num_epochs):
        epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)

        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        logging.info(f'[{model_type}] Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    # Save model weights under a timestamped name so runs do not overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_filename = f"{model_type}_model_{timestamp}.pth"
    torch.save(model.state_dict(), os.path.join(model_path, model_filename))

    # Plot and save training metrics
    _save_metric_plot(model_type, 'Accuracy', train_accuracies, val_accuracies)
    _save_metric_plot(model_type, 'Loss', train_losses, val_losses)

    return model, train_accuracies[-1], val_accuracies[-1]

In [25]:
from sklearn.metrics import precision_score, recall_score, f1_score
import logging
import torch

def evaluate_model(model, model_type, test_loader, device, static_features):
    """Evaluate ``model`` on ``test_loader`` and log/return classification metrics.

    Args:
        model: trained model producing one probability per sample.
        model_type: label used in the log messages.
        test_loader: loader whose batches hold the model inputs first and the
            labels last.
        device: torch device the batches are moved to.
        static_features: unused here; kept so existing call sites keep working.

    Returns:
        ``(accuracy, precision, recall, f1)``. A sample is predicted positive
        when its probability is strictly greater than 0.5; precision, recall
        and F1 are reported as 0 when undefined (``zero_division=0``).
    """
    model.eval()
    correct = 0
    total = 0

    all_labels = []
    all_preds = []

    with torch.no_grad():
        for batch in test_loader:
            *inputs, labels = batch
            inputs = [tensor.to(device) for tensor in inputs]
            labels = labels.to(device)

            # squeeze(-1) keeps the batch dimension even for a single-sample
            # batch; a bare .squeeze() produced a 0-d tensor there, which
            # cannot be iterated when collecting predictions below.
            outputs = model(*inputs).squeeze(-1)

            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(predicted.cpu().tolist())

    test_acc = correct / total
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)

    logging.info(f'[{model_type}] Test Accuracy: {test_acc:.4f}')
    logging.info(f'[{model_type}] Precision: {precision:.4f}')
    logging.info(f'[{model_type}] Recall: {recall:.4f}')
    logging.info(f'[{model_type}] F1 Score: {f1:.4f}')

    return test_acc, precision, recall, f1


In [26]:
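# # SUPERSEDED DRAFT - kept commented out for reference only; the live
# # run_ablation_study is defined in the next cell.
# # Note: this draft stores the whole tuple returned by evaluate_model
# # (accuracy, precision, recall, f1) under "Test Accuracy" and then formats it
# # with ':.4f', so it would fail if re-enabled unchanged.
# # Main function to run the ablation study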
# def run_ablation_study():
#     """Run the complete ablation study."""
#     # Load and preprocess data
#     print("Loading datasets...")
#     train_df = pd.read_csv(data_path + 'train.csv')
#     val_df = pd.read_csv(data_path + 'val.csv')
#     test_df = pd.read_csv(data_path + 'test.csv')

#     # Parse sequences
#     for df in [train_df, val_df, test_df]:
#         df['players_seq'] = df['players_seq'].apply(ast.literal_eval)
#         df['ball_seq'] = df['ball_seq'].apply(ast.literal_eval)

#     # Filter shots and encode labels
#     for df in [train_df, val_df, test_df]:
#         df.drop(df[~df['shotResult'].isin(['Made Shot', 'Missed Shot'])].index, inplace=True)
#         df['shotResult'] = df['shotResult'].map({'Made Shot': 1, 'Missed Shot': 0})

#     # Extract features
#     print("Extracting features...")
#     train_df = extract_sequence_features(train_df)
#     val_df = extract_sequence_features(val_df)
#     test_df = extract_sequence_features(test_df)

#     static_features = [
#         'shooter_x', 'shooter_y', 'release_angle', 'initial_height', 'max_height',
#         'traj_length', 'traj_curvature', 'defender_proximity',
#         'teammate_x', 'teammate_y', 'defender_x', 'defender_y',
#         'avg_vel_x', 'avg_vel_y', 'avg_vel_z', 'avg_acc_x', 'avg_acc_y', 'avg_acc_z',
#         'entry_angle', 'arc_height_ratio'
#     ]

#     # Clean up NaN or inf values
#     for df in [train_df, val_df, test_df]:
#         df[static_features] = df[static_features].replace([np.inf, -np.inf], 0).fillna(0)

#     # Process ball sequences
#     print("Processing ball sequences...")
#     X_train_ball, y_train = process_ball_sequences(train_df)
#     X_val_ball, y_val = process_ball_sequences(val_df)
#     X_test_ball, y_test = process_ball_sequences(test_df)

#     # Reshape for LSTM
#     X_train_ball = X_train_ball.reshape(X_train_ball.shape[0], 37, 3)
#     X_val_ball = X_val_ball.reshape(X_val_ball.shape[0], 37, 3)
#     X_test_ball = X_test_ball.reshape(X_test_ball.shape[0], 37, 3)

#     # Normalize ball sequences
#     scaler_ball = MinMaxScaler()
#     X_train_ball_2d = X_train_ball.reshape(-1, 3)
#     X_train_ball_2d = scaler_ball.fit_transform(X_train_ball_2d)
#     X_train_ball = X_train_ball_2d.reshape(X_train_ball.shape)

#     X_val_ball_2d = X_val_ball.reshape(-1, 3)
#     X_val_ball_2d = scaler_ball.transform(X_val_ball_2d)
#     X_val_ball = X_val_ball_2d.reshape(X_val_ball.shape)

#     X_test_ball_2d = X_test_ball.reshape(-1, 3)
#     X_test_ball_2d = scaler_ball.transform(X_test_ball_2d)
#     X_test_ball = X_test_ball_2d.reshape(X_test_ball.shape)

#     # Process static features
#     X_train_static = train_df[static_features].values
#     X_val_static = val_df[static_features].values
#     X_test_static = test_df[static_features].values

#     # Normalize static features
#     scaler_static = MinMaxScaler()
#     X_train_static = scaler_static.fit_transform(X_train_static)
#     X_val_static = scaler_static.transform(X_val_static)
#     X_test_static = scaler_static.transform(X_test_static)

#     # Save scalers for future use
#     with open(os.path.join(model_path, 'scaler_ball.pkl'), 'wb') as f:
#         pickle.dump(scaler_ball, f)
#     with open(os.path.join(model_path, 'scaler_static.pkl'), 'wb') as f:
#         pickle.dump(scaler_static, f)

#     # Create datasets for each ablation scenario
#     batch_size = 32
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#     # 1. Combined model (baseline)
#     print("\n=== Training Combined Model (Baseline) ===")
#     combined_train_dataset = ShotDataset(X_train_ball, X_train_static, y_train)
#     combined_val_dataset = ShotDataset(X_val_ball, X_val_static, y_val)
#     combined_test_dataset = ShotDataset(X_test_ball, X_test_static, y_test)

#     combined_train_loader = DataLoader(combined_train_dataset, batch_size=batch_size, shuffle=True)
#     combined_val_loader = DataLoader(combined_val_dataset, batch_size=batch_size, shuffle=False)
#     combined_test_loader = DataLoader(combined_test_dataset, batch_size=batch_size, shuffle=False)

#     combined_model, combined_train_acc, combined_val_acc = train_model(
#         "combined", combined_train_loader, combined_val_loader, device, static_features
#     )
#     combined_test_acc = evaluate_model(combined_model, "combined", combined_test_loader, device, static_features)

#     # 2. Ball-only model
#     print("\n=== Training Ball-Only Model ===")
#     ball_train_dataset = ShotDataset(ball_sequences=X_train_ball, static_features=None, labels=y_train)
#     ball_val_dataset = ShotDataset(ball_sequences=X_val_ball, static_features=None, labels=y_val)
#     ball_test_dataset = ShotDataset(ball_sequences=X_test_ball, static_features=None, labels=y_test)

#     ball_train_loader = DataLoader(ball_train_dataset, batch_size=batch_size, shuffle=True)
#     ball_val_loader = DataLoader(ball_val_dataset, batch_size=batch_size, shuffle=False)
#     ball_test_loader = DataLoader(ball_test_dataset, batch_size=batch_size, shuffle=False)

#     ball_model, ball_train_acc, ball_val_acc = train_model(
#         "ball_only", ball_train_loader, ball_val_loader, device, static_features
#     )
#     ball_test_acc = evaluate_model(ball_model, "ball_only", ball_test_loader, device, static_features)

#     # 3. Static-only model
#     print("\n=== Training Static-Only Model ===")
#     static_train_dataset = ShotDataset(ball_sequences=None, static_features=X_train_static, labels=y_train)
#     static_val_dataset = ShotDataset(ball_sequences=None, static_features=X_val_static, labels=y_val)
#     static_test_dataset = ShotDataset(ball_sequences=None, static_features=X_test_static, labels=y_test)

#     static_train_loader = DataLoader(static_train_dataset, batch_size=batch_size, shuffle=True)
#     static_val_loader = DataLoader(static_val_dataset, batch_size=batch_size, shuffle=False)
#     static_test_loader = DataLoader(static_test_dataset, batch_size=batch_size, shuffle=False)

#     static_model, static_train_acc, static_val_acc = train_model(
#         "static_only", static_train_loader, static_val_loader, device, static_features
#     )
#     static_test_acc = evaluate_model(static_model, "static_only", static_test_loader, device, static_features)

#     # Summary of results
#     results = {
#         "Combined Model": {
#             "Train Accuracy": combined_train_acc,
#             "Validation Accuracy": combined_val_acc,
#             "Test Accuracy": combined_test_acc
#         },
#         "Ball-Only Model": {
#             "Train Accuracy": ball_train_acc,
#             "Validation Accuracy": ball_val_acc,
#             "Test Accuracy": ball_test_acc
#         },
#         "Static-Only Model": {
#             "Train Accuracy": static_train_acc,
#             "Validation Accuracy": static_val_acc,
#             "Test Accuracy": static_test_acc
#         }
#     }

#     # Print summary table
#     print("\n===== ABLATION STUDY RESULTS =====")
#     print(f"{'Model':<20} {'Train Acc':<15} {'Val Acc':<15} {'Test Acc':<15}")
#     print("-" * 65)
#     for model_name, metrics in results.items():
#         print(f"{model_name:<20} {metrics['Train Accuracy']:<15.4f} {metrics['Validation Accuracy']:<15.4f} {metrics['Test Accuracy']:<15.4f}")

#     # Save results to file
#     with open(os.path.join(output_path, 'ablation_results.txt'), 'w') as f:
#         f.write("===== ABLATION STUDY RESULTS =====\n")
#         f.write(f"{'Model':<20} {'Train Acc':<15} {'Val Acc':<15} {'Test Acc':<15}\n")
#         f.write("-" * 65 + "\n")
#         for model_name, metrics in results.items():
#             f.write(f"{model_name:<20} {metrics['Train Accuracy']:<15.4f} {metrics['Validation Accuracy']:<15.4f} {metrics['Test Accuracy']:<15.4f}\n")

#     # Create comparison bar plot
#     plt.figure(figsize=(12, 8))
#     models = list(results.keys())
#     train_accs = [results[m]["Train Accuracy"] for m in models]
#     val_accs = [results[m]["Validation Accuracy"] for m in models]
#     test_accs = [results[m]["Test Accuracy"] for m in models]

#     x = np.arange(len(models))
#     width = 0.25

#     plt.bar(x - width, train_accs, width, label='Train Accuracy')
#     plt.bar(x, val_accs, width, label='Validation Accuracy')
#     plt.bar(x + width, test_accs, width, label='Test Accuracy')

#     plt.xlabel('Models')
#     plt.ylabel('Accuracy')
#     plt.title('Comparison of Model Performance Across Ablation Studies')
#     plt.xticks(x, models)
#     plt.legend()
#     plt.grid(axis='y', alpha=0.3)

#     plt.tight_layout()
#     plt.savefig(os.path.join(output_path, 'ablation_comparison.png'))
#     plt.close()

#     print(f"Results saved to {output_path}")
#     return results

In [27]:
# Main function to run the ablation study

# One row per ablation scenario:
#   (mode string passed to train_model / evaluate_model,
#    key used in the returned results dict,
#    title shown in the "=== Training ... ===" banner,
#    feed ball sequences?, feed static features?)
_ABLATION_SCENARIOS = (
    ("combined", "Combined Model", "Combined Model (Baseline)", True, True),
    ("ball_only", "Ball-Only Model", "Ball-Only Model", True, False),
    ("static_only", "Static-Only Model", "Static-Only Model", False, True),
)


def _load_shot_split(csv_name):
    """Read one split CSV from data_path, keep made/missed shots, and parse its sequences."""
    split_df = pd.read_csv(os.path.join(data_path, csv_name))
    # Filter first so literal_eval is not spent on rows that are discarded anyway.
    is_shot = split_df['shotResult'].isin(['Made Shot', 'Missed Shot'])
    split_df = split_df[is_shot].copy()
    split_df['shotResult'] = split_df['shotResult'].map({'Made Shot': 1, 'Missed Shot': 0})
    # The sequence columns are stored as Python-literal strings in the CSV.
    for seq_col in ('players_seq', 'ball_seq'):
        split_df[seq_col] = split_df[seq_col].apply(ast.literal_eval)
    return split_df


def _scale_ball_sequences(scaler, sequences, fit=False):
    """Scale a 3-D sequence array column-wise over its last axis, preserving its shape."""
    # MinMaxScaler works on 2-D input, so flatten (samples, steps) into rows first.
    flat = sequences.reshape(-1, sequences.shape[-1])
    flat = scaler.fit_transform(flat) if fit else scaler.transform(flat)
    return flat.reshape(sequences.shape)


def _format_results_table(results):
    """Return the summary table (title, header, rule, one row per model) as a list of lines."""
    lines = [
        "===== ABLATION STUDY RESULTS =====",
        f"{'Model':<20} {'Train Acc':<15} {'Val Acc':<15} {'Test Acc':<15}",
        "-" * 65,
    ]
    for model_name, metrics in results.items():
        lines.append(
            f"{model_name:<20} {metrics['Train Accuracy']:<15.4f} "
            f"{metrics['Validation Accuracy']:<15.4f} {metrics['Test Accuracy']:<15.4f}"
        )
    return lines


def _plot_accuracy_comparison(results, figure_path):
    """Save a grouped bar chart of train/validation/test accuracy per model to figure_path."""
    models = list(results.keys())
    x = np.arange(len(models))
    width = 0.25

    fig, ax = plt.subplots(figsize=(12, 8))
    # One bar group per model: train on the left, validation centred, test on the right.
    metric_offsets = (
        ("Train Accuracy", -width),
        ("Validation Accuracy", 0),
        ("Test Accuracy", width),
    )
    for metric, offset in metric_offsets:
        ax.bar(x + offset, [results[m][metric] for m in models], width, label=metric)

    ax.set_xlabel('Models')
    ax.set_ylabel('Accuracy')
    ax.set_title('Comparison of Model Performance Across Ablation Studies')
    ax.set_xticks(x)
    ax.set_xticklabels(models)
    ax.legend()
    ax.grid(axis='y', alpha=0.3)

    fig.tight_layout()
    fig.savefig(figure_path)
    plt.close(fig)  # release the figure so repeated runs do not accumulate open figures


def run_ablation_study(batch_size=32):
    """Run the complete ablation study.

    Loads the train/val/test CSVs from ``data_path``, extracts and scales the
    ball-trajectory and static features, then trains and evaluates one model per
    scenario (combined, ball-only, static-only). The fitted scalers are pickled
    into ``model_path``; the results table and comparison plot are written to
    ``output_path``.

    Args:
        batch_size: Batch size used for every DataLoader (default 32).

    Returns:
        dict mapping each model's display name to a dict with the keys
        "Train Accuracy", "Validation Accuracy" and "Test Accuracy".
    """
    split_names = ('train', 'val', 'test')  # 'train' must stay first: scalers are fitted on it

    # Load and preprocess data
    print("Loading datasets...")
    frames = {split: _load_shot_split(f"{split}.csv") for split in split_names}

    # Extract features
    print("Extracting features...")
    frames = {split: extract_sequence_features(df) for split, df in frames.items()}

    # Define static features
    static_features = [
        'shooter_x', 'shooter_y', 'release_angle', 'initial_height', 'max_height',
        'traj_length', 'traj_curvature', 'defender_proximity',
        'teammate_x', 'teammate_y', 'defender_x', 'defender_y',
        'avg_vel_x', 'avg_vel_y', 'avg_vel_z', 'avg_acc_x', 'avg_acc_y', 'avg_acc_z',
        'entry_angle', 'arc_height_ratio'
    ]

    # Clean up NaN or inf values
    for df in frames.values():
        df[static_features] = df[static_features].replace([np.inf, -np.inf], 0).fillna(0)

    # Process ball sequences and reshape them for the LSTM.
    # NOTE(review): assumes process_ball_sequences returns flattened rows of
    # seq_len * n_coords values per shot -- confirm against its definition.
    print("Processing ball sequences...")
    seq_len, n_coords = 37, 3
    ball, labels = {}, {}
    for split, df in frames.items():
        flat_ball, labels[split] = process_ball_sequences(df)
        ball[split] = flat_ball.reshape(flat_ball.shape[0], seq_len, n_coords)

    # Normalize ball sequences: fit on train only, then reuse the fit for val/test.
    scaler_ball = MinMaxScaler()
    for split in split_names:
        ball[split] = _scale_ball_sequences(scaler_ball, ball[split], fit=(split == 'train'))

    # Normalize static features with the same fit-on-train policy.
    scaler_static = MinMaxScaler()
    static = {}
    for split in split_names:
        values = frames[split][static_features].values
        if split == 'train':
            static[split] = scaler_static.fit_transform(values)
        else:
            static[split] = scaler_static.transform(values)

    # Save scalers for future use
    for file_name, scaler in (('scaler_ball.pkl', scaler_ball), ('scaler_static.pkl', scaler_static)):
        with open(os.path.join(model_path, file_name), 'wb') as f:
            pickle.dump(scaler, f)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Train and evaluate every scenario with identical data handling; a scenario
    # differs only in which feature group is withheld (passed as None).
    results = {}
    for mode, result_key, banner_title, use_ball, use_static in _ABLATION_SCENARIOS:
        print(f"\n=== Training {banner_title} ===")

        loaders = {}
        for split in split_names:
            dataset = ShotDataset(
                ball_sequences=ball[split] if use_ball else None,
                static_features=static[split] if use_static else None,
                labels=labels[split],
            )
            # Only the training data is shuffled; val/test keep a fixed order.
            loaders[split] = DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train'))

        model, train_acc, val_acc = train_model(
            mode, loaders['train'], loaders['val'], device, static_features
        )
        test_metrics = evaluate_model(model, mode, loaders['test'], device, static_features)

        results[result_key] = {
            "Train Accuracy": train_acc,
            "Validation Accuracy": val_acc,
            "Test Accuracy": test_metrics[0],  # first element is the test accuracy
        }

    # Print summary table
    table_lines = _format_results_table(results)
    print("\n" + "\n".join(table_lines))

    # Save results to file
    with open(os.path.join(output_path, 'ablation_results.txt'), 'w') as f:
        f.write("\n".join(table_lines) + "\n")

    # Create comparison bar plot
    _plot_accuracy_comparison(results, os.path.join(output_path, 'ablation_comparison.png'))

    print(f"Results saved to {output_path}")
    return results

In [28]:
# Entry point: launch the full ablation run when this cell/script is executed directly.
if __name__ == "__main__":
    run_ablation_study()

Loading datasets...
Extracting features...
Processing ball sequences...

=== Training Combined Model (Baseline) ===

=== Training Ball-Only Model ===

=== Training Static-Only Model ===

===== ABLATION STUDY RESULTS =====
Model                Train Acc       Val Acc         Test Acc       
-----------------------------------------------------------------
Combined Model       0.8132          0.8272          0.8217         
Ball-Only Model      0.7914          0.8281          0.7840         
Static-Only Model    0.7728          0.7658          0.7623         
Results saved to /content/drive/MyDrive/smai_project/output_ablation_study/
