In [16]:
import kagglehub
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix
from torch import nn, optim
import torch.nn.functional as F
from collections import Counter

# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

class SVM(nn.Module):
    """Small fully connected network that emits one raw logit per input row.

    The layout is input_dim -> 64 -> 32 -> 1 with ReLU after the two hidden
    layers and dropout (p=0.3) applied after each ReLU. No activation is
    applied to the final layer, so callers must apply a sigmoid themselves
    to obtain probabilities.
    """

    def __init__(self, input_dim):
        """Build the three linear layers for `input_dim` input features."""
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return the raw logits for a batch of feature rows."""
        hidden = self.dropout(F.relu(self.fc1(x)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)

class BasketballMomentumSVM:
    def __init__(self, num_previous_plays=5, momentum_threshold=6, learning_rate=0.001,
                 epochs=150, reg_param=0.01, batch_size=64, class_weight=10.0,
                 verbose=True):
        """
        Record the hyper-parameters and create the (still unfitted) scaler

        Parameters:
        -----------
        num_previous_plays : int
            Size of the window of earlier plays used to build each sample
        momentum_threshold : int
            Unanswered points a team must score for a run to count as a
            momentum shift
        learning_rate : float
            Step size handed to the Adam optimizer
        epochs : int
            Number of passes over the training data
        reg_param : float
            Passed to Adam as its weight_decay
        batch_size : int
            Mini-batch size used while training
        class_weight : float
            Weight of the positive class (momentum shift starts) in the loss
        verbose : bool
            Whether to print detailed debug information
        """
        # State that only exists after prepare_data() / train() have run
        self.model = None
        self.input_dim = None
        self.scaler = StandardScaler()

        # How samples and labels are built
        self.momentum_threshold = momentum_threshold
        self.num_previous_plays = num_previous_plays

        # How the network is optimised
        self.learning_rate, self.epochs = learning_rate, epochs
        self.reg_param, self.batch_size = reg_param, batch_size
        self.class_weight = class_weight

        self.verbose = verbose

    def _safe_numeric_convert(self, value, default=0.0):
        """Safely convert a value to float"""
        try:
            return float(value)
        except (ValueError, TypeError):
            return default

    def preprocess_data(self, data):
        """
        Preprocess the input data to ensure numeric types
        """
        # Columns that should be numeric
        numeric_columns = [
            'play_id', 'half', 'time_remaining_half', 'secs_remaining',
            'secs_remaining_absolute', 'home_score', 'away_score', 'score_diff',
            'play_length', 'scoring_play', 'foul', 'win_prob', 'naive_win_prob',
            'home_time_out_remaining', 'away_time_out_remaining', 'home_favored_by',
            'total_line', 'attendance', 'three_pt', 'free_throw'
        ]

        # Create a copy to avoid modifying the original
        processed_data = data.copy()

        # Convert each column to numeric
        for col in numeric_columns:
            if col in processed_data.columns:
                processed_data[col] = pd.to_numeric(processed_data[col], errors='coerce')
                processed_data[col] = processed_data[col].fillna(0)

        return processed_data

    def find_momentum_shifts(self, df):
        """
        Find momentum shifts using the user's exact algorithm

        Parameters:
        -----------
        df : DataFrame
            Full dataset with all plays

        Returns:
        --------
        df : DataFrame
            Updated dataframe with 'run_label' column
        runs_df : DataFrame
            DataFrame containing all identified runs
        """
        # Sort data by game and play ID
        df = df.sort_values(['game_id', 'play_id'], ascending=[True, True]).reset_index(drop=True)
        runs = []

        for game_id, game_df in df.groupby('game_id'):
            home = game_df.iloc[0]['home']
            away = game_df.iloc[0]['away']

            # Track cumulative points scored
            home_score = 0
            away_score = 0

            # Track potential runs
            run_points = {home: 0, away: 0}
            run_start_index = None
            run_start_time = None

            for idx, row in game_df.iterrows():
                new_home_score = self._safe_numeric_convert(row['home_score'])
                new_away_score = self._safe_numeric_convert(row['away_score'])

                # Determine who scored (if anyone)
                if new_home_score != home_score or new_away_score != away_score:
                    # Someone scored
                    scoring_team = home if new_home_score != home_score else away
                    points_scored = (new_home_score - home_score) if scoring_team == home else (new_away_score - away_score)

                    if run_points[scoring_team] == 0:
                        # First points in the run
                        run_start_index = idx
                        run_start_time = row['time_remaining_half']

                    run_points[scoring_team] += points_scored
                    # Reset the opponent's run points if they score
                    opponent = away if scoring_team == home else home
                    run_points[opponent] = 0

                    # Check if run is 6+ to 0
                    if run_points[scoring_team] >= self.momentum_threshold and run_points[opponent] == 0:
                        runs.append({
                            'game_id': game_id,
                            'team': scoring_team,
                            'start_play_id': game_df.loc[run_start_index]['play_id'],
                            'start_index': run_start_index,  # Store the DataFrame index too
                            'start_time': run_start_time,
                            'end_play_id': row['play_id'],
                            'end_index': idx,  # Store the DataFrame index too
                            'end_time': row['time_remaining_half'],
                            'points_scored': run_points[scoring_team]
                        })

                        # After finding a run, reset
                        run_points[scoring_team] = 0
                        run_start_index = None
                        run_start_time = None

                home_score = new_home_score
                away_score = new_away_score

        # Turn the runs list into a DataFrame
        runs_df = pd.DataFrame(runs)

        # Add the run_label column to the original DataFrame if it doesn't exist
        if 'run_label' not in df.columns:
            df['run_label'] = 0

        # Loop through each scoring run
        for _, run in runs_df.iterrows():
            game_id = run['game_id']
            start_play_id = run['start_play_id']
            end_play_id = run['end_play_id']

            # Apply mask: match game_id and play_id in range
            mask = (
                (df['game_id'] == game_id) &
                (df['play_id'] >= start_play_id) &
                (df['play_id'] <= end_play_id)
            )

            # Apply the label
            df.loc[mask, 'run_label'] = 1

        if self.verbose:
            print(f"Found {len(runs_df)} momentum shifts (runs of {self.momentum_threshold}+ points)")
            print(f"Total plays in runs: {df['run_label'].sum()}")
            print(f"Percentage of plays in runs: {df['run_label'].mean()*100:.2f}%")

        return df, runs_df

    def find_momentum_shift_starts(self, df, runs_df):
        """
        Identify the starting plays of momentum shifts based on runs_df

        Parameters:
        -----------
        df : DataFrame
            Full dataset with all plays
        runs_df : DataFrame
            DataFrame containing all identified runs

        Returns:
        --------
        df : DataFrame
            Updated dataframe with 'momentum_start' column
        """
        # Add the momentum_start column to the original DataFrame
        df['momentum_start'] = 0

        # Loop through each scoring run and mark only the starting play
        for _, run in runs_df.iterrows():
            game_id = run['game_id']
            start_play_id = run['start_play_id']

            # Apply mask: match game_id and exact start play_id
            mask = (
                (df['game_id'] == game_id) &
                (df['play_id'] == start_play_id)
            )

            # Apply the label to the starting play only
            df.loc[mask, 'momentum_start'] = 1

        if self.verbose:
            print(f"Identified {df['momentum_start'].sum()} momentum shift starting plays")
            print(f"Percentage of momentum shift starts: {df['momentum_start'].mean()*100:.2f}%")

        return df

    def prepare_data(self, game_data, fit_scaler=True):
        """
        Turn raw play-by-play rows into scaled feature/label tensors

        The data is cleaned, labelled with momentum-shift starts, and then a
        sliding window of `num_previous_plays` plays is moved over every game.
        Each play that has a full window before it yields one sample.

        Parameters:
        -----------
        game_data : DataFrame
            Raw play-by-play data (not modified)
        fit_scaler : bool
            When True (default) the scaler is fitted on this data. Pass False
            to reuse the statistics of an earlier call, e.g. when preparing
            new games for prediction.

        Returns:
        --------
        X_tensor : torch.Tensor
            Scaled float features on `device`, one row per sample
        y_tensor : torch.Tensor
            Float 0/1 labels on `device`, shaped (n_samples, 1)
        play_indices : list
            For each sample, the index label its play had in `game_data`

        Raises:
        -------
        ValueError
            If no play could be turned into a sample
        """
        data = self.preprocess_data(game_data)

        # find_momentum_shifts() re-sorts and renumbers the rows, so remember
        # each row's label in the caller's frame in a column that survives it.
        source_col = '_source_index'
        data[source_col] = data.index

        # Label runs, then mark only the play that opens each run
        data, runs_df = self.find_momentum_shifts(data)
        data = self.find_momentum_shift_starts(data, runs_df)

        X = []
        y = []
        play_indices = []
        skipped = 0

        # sort=False keeps games in order of appearance
        for game_id, game_plays in data.groupby('game_id', sort=False):
            game_plays = game_plays.sort_values(by=['play_id']).reset_index(drop=True)
            source_labels = game_plays[source_col].tolist()

            for i in range(self.num_previous_plays, len(game_plays)):
                try:
                    prev_plays = game_plays.iloc[i - self.num_previous_plays:i]
                    current_play = game_plays.iloc[i]
                    features = self._extract_features(prev_plays, current_play)
                    is_momentum_start = current_play['momentum_start']
                except Exception as e:
                    # Best effort: a malformed play is skipped, not fatal
                    skipped += 1
                    if self.verbose:
                        print(f"Error processing play in game {game_id}, index {i}: {e}")
                    continue

                X.append(features)
                y.append(is_momentum_start)
                play_indices.append(source_labels[i])

        if skipped:
            # Reported even when not verbose so dropped samples are never silent
            print(f"Skipped {skipped} plays that could not be processed")

        X_np = np.array(X)
        y_np = np.array(y)

        if len(y_np) == 0:
            raise ValueError("No valid plays were processed. Check your data.")

        positives = y_np.sum()
        print(f"Total plays processed: {len(X_np)}")
        print(f"Momentum shift starts detected: {positives} out of {len(y_np)} plays ({positives/len(y_np)*100:.2f}%)")

        print("Class distribution:", Counter(y_np))

        # Re-fitting on prediction data would shift the features relative to
        # what the network was trained on, hence the opt-out.
        if fit_scaler:
            X_scaled = self.scaler.fit_transform(X_np)
        else:
            X_scaled = self.scaler.transform(X_np)

        self.input_dim = X_scaled.shape[1]

        X_tensor = torch.FloatTensor(X_scaled).to(device)
        y_tensor = torch.FloatTensor(y_np).view(-1, 1).to(device)

        return X_tensor, y_tensor, play_indices

    def _extract_features(self, prev_plays, current_play):
        """Extract relevant features from a sequence of plays"""
        features = []

        # Home team and away team
        home_team = current_play['home']
        away_team = current_play['away']

        # Game Context Features
        # 1. Score differential and raw scores
        score_diff = self._safe_numeric_convert(current_play['score_diff'])
        home_score = self._safe_numeric_convert(current_play['home_score'])
        away_score = self._safe_numeric_convert(current_play['away_score'])
        features.append(score_diff)
        features.append(home_score)
        features.append(away_score)

        # 2. Score differential trend (change over last few plays)
        if len(prev_plays) >= 3:
            score_diff_3plays_ago = self._safe_numeric_convert(prev_plays.iloc[-3]['score_diff'])
            score_diff_trend = score_diff - score_diff_3plays_ago
            features.append(score_diff_trend)
        else:
            features.append(0)

        # 3. Time remaining (normalized by half)
        time_remaining = self._safe_numeric_convert(current_play['time_remaining_half'])
        features.append(time_remaining / 1200.0)  # 20 minutes per half

        # 4. Half
        features.append(self._safe_numeric_convert(current_play['half']))

        # 5. Point spread (home_favored_by)
        features.append(self._safe_numeric_convert(current_play['home_favored_by']))

        # 6. Win probability (if available)
        features.append(self._safe_numeric_convert(current_play['win_prob']))
        features.append(self._safe_numeric_convert(current_play['naive_win_prob']))

        # Momentum Features
        # 7. Recent scoring runs
        home_recent_points = 0
        away_recent_points = 0
        consecutive_home_points = 0
        consecutive_away_points = 0

        last_team = None

        for i, play in prev_plays.iterrows():
            if self._safe_numeric_convert(play['scoring_play']) == 1:
                points = 0
                if self._safe_numeric_convert(play['free_throw']) == 1:
                    points = 1
                elif self._safe_numeric_convert(play['three_pt']) == 1:
                    points = 3
                else:
                    points = 2

                if play['action_team'] == home_team:
                    home_recent_points += points
                    if last_team == home_team:
                        consecutive_home_points += points
                    else:
                        consecutive_home_points = points
                    consecutive_away_points = 0
                    last_team = home_team
                elif play['action_team'] == away_team:
                    away_recent_points += points
                    if last_team == away_team:
                        consecutive_away_points += points
                    else:
                        consecutive_away_points = points
                    consecutive_home_points = 0
                    last_team = away_team

        features.append(home_recent_points)
        features.append(away_recent_points)
        features.append(consecutive_home_points)
        features.append(consecutive_away_points)

        # 8. Which team has momentum currently
        # Calculate momentum based on recent scoring
        if consecutive_home_points > consecutive_away_points:
            momentum_team = 1  # Home team has momentum
        elif consecutive_away_points > consecutive_home_points:
            momentum_team = -1  # Away team has momentum
        else:
            momentum_team = 0  # Neither team has clear momentum
        features.append(momentum_team)

        # 9. Recent scoring frequency and efficiency
        home_scoring_plays = 0
        away_scoring_plays = 0
        home_shot_attempts = 0
        away_shot_attempts = 0

        for _, play in prev_plays.iterrows():
            # Count scoring plays
            if self._safe_numeric_convert(play['scoring_play']) == 1:
                if play['action_team'] == home_team:
                    home_scoring_plays += 1
                elif play['action_team'] == away_team:
                    away_scoring_plays += 1

            # Count shot attempts
            if play.get('shot_team') == home_team:
                home_shot_attempts += 1
            elif play.get('shot_team') == away_team:
                away_shot_attempts += 1

        features.append(home_scoring_plays / max(1, len(prev_plays)))
        features.append(away_scoring_plays / max(1, len(prev_plays)))
        features.append(home_scoring_plays / max(1, home_shot_attempts) if home_shot_attempts > 0 else 0)
        features.append(away_scoring_plays / max(1, away_shot_attempts) if away_shot_attempts > 0 else 0)

        # Shot Type Features
        # 10. Recent three-point percentage
        home_three_attempts = 0
        home_three_makes = 0
        away_three_attempts = 0
        away_three_makes = 0

        for _, play in prev_plays.iterrows():
            if play.get('shot_team') == home_team and self._safe_numeric_convert(play['three_pt']) == 1:
                home_three_attempts += 1
                if play.get('shot_outcome') == 'made':
                    home_three_makes += 1
            elif play.get('shot_team') == away_team and self._safe_numeric_convert(play['three_pt']) == 1:
                away_three_attempts += 1
                if play.get('shot_outcome') == 'made':
                    away_three_makes += 1

        features.append(home_three_makes / max(1, home_three_attempts))
        features.append(away_three_makes / max(1, away_three_attempts))

        # 11. Recent free throw percentage
        home_ft_attempts = 0
        home_ft_makes = 0
        away_ft_attempts = 0
        away_ft_makes = 0

        for _, play in prev_plays.iterrows():
            if play.get('shot_team') == home_team and self._safe_numeric_convert(play['free_throw']) == 1:
                home_ft_attempts += 1
                if play.get('shot_outcome') == 'made':
                    home_ft_makes += 1
            elif play.get('shot_team') == away_team and self._safe_numeric_convert(play['free_throw']) == 1:
                away_ft_attempts += 1
                if play.get('shot_outcome') == 'made':
                    away_ft_makes += 1

        features.append(home_ft_makes / max(1, home_ft_attempts))
        features.append(away_ft_makes / max(1, away_ft_attempts))

        # 12. Recent two-point percentage
        home_two_attempts = 0
        home_two_makes = 0
        away_two_attempts = 0
        away_two_makes = 0

        for _, play in prev_plays.iterrows():
            if (play.get('shot_team') == home_team and
                self._safe_numeric_convert(play['three_pt']) == 0 and
                self._safe_numeric_convert(play['free_throw']) == 0):
                home_two_attempts += 1
                if play.get('shot_outcome') == 'made':
                    home_two_makes += 1
            elif (play.get('shot_team') == away_team and
                  self._safe_numeric_convert(play['three_pt']) == 0 and
                  self._safe_numeric_convert(play['free_throw']) == 0):
                away_two_attempts += 1
                if play.get('shot_outcome') == 'made':
                    away_two_makes += 1

        features.append(home_two_makes / max(1, home_two_attempts))
        features.append(away_two_makes / max(1, away_two_attempts))

        # Foul Features
        # 13. Recent foul trouble
        home_fouls = sum(1 for _, play in prev_plays.iterrows()
                     if self._safe_numeric_convert(play['foul']) == 1
                     and play['action_team'] == home_team)
        away_fouls = sum(1 for _, play in prev_plays.iterrows()
                     if self._safe_numeric_convert(play['foul']) == 1
                     and play['action_team'] == away_team)

        features.append(home_fouls)
        features.append(away_fouls)

        # Timeout Features
        # 14. Timeout usage
        features.append(self._safe_numeric_convert(current_play['home_time_out_remaining']))
        features.append(self._safe_numeric_convert(current_play['away_time_out_remaining']))

        # Current play type
        # 15. Current play attributes
        features.append(1 if self._safe_numeric_convert(current_play['scoring_play']) == 1 else 0)
        features.append(1 if self._safe_numeric_convert(current_play['foul']) == 1 else 0)
        features.append(1 if self._safe_numeric_convert(current_play['three_pt']) == 1 else 0)
        features.append(1 if self._safe_numeric_convert(current_play['free_throw']) == 1 else 0)

        # 16. Who has possession
        features.append(1 if current_play.get('possession_before') == home_team else 0)

        return features

    def train(self, X, y):
        """
        Train a fresh network on the given tensors

        A new model is created on every call, so earlier training is
        discarded. The loss is printed on the first epoch and on every 10th;
        when verbose, training accuracy and the number of positive
        predictions are printed on every 30th epoch.

        Parameters:
        -----------
        X : torch.Tensor
            Feature matrix, one row per sample
        y : torch.Tensor
            Float 0/1 targets; expected to be shaped (n_samples, 1) so they
            line up with the network output

        Raises:
        -------
        ValueError
            If X contains no samples
        """
        if len(X) == 0:
            raise ValueError("Cannot train on an empty dataset.")

        # The positive-class weight is the fixed class_weight hyper-parameter
        # (it is not derived from the label distribution); a non-positive
        # value switches re-weighting off.
        weight = float(self.class_weight) if self.class_weight > 0 else 1.0
        pos_weight = torch.tensor(weight).to(device)

        # prepare_data() normally records the feature count; fall back to the
        # width of X when training on tensors that were built elsewhere.
        if self.input_dim is None:
            self.input_dim = X.shape[1]

        self.model = SVM(self.input_dim).to(device)

        # Weighted binary cross-entropy on raw logits; reg_param is applied
        # as Adam's weight decay
        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=self.reg_param)

        dataset = torch.utils.data.TensorDataset(X, y)
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self.model.train()
        for epoch in range(self.epochs):
            running_loss = 0.0

            for batch_X, batch_y in data_loader:
                outputs = self.model(batch_X)
                loss = criterion(outputs, batch_y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            avg_loss = running_loss / len(data_loader)

            if (epoch+1) % 10 == 0 or epoch == 0:
                print(f'Epoch [{epoch+1}/{self.epochs}], Loss: {avg_loss:.4f}')

                if self.verbose and (epoch+1) % 30 == 0:
                    # Dropout must be off while measuring, otherwise the
                    # reported accuracy changes randomly from call to call.
                    self.model.eval()
                    with torch.no_grad():
                        train_outputs = self.model(X)
                        train_predictions = (torch.sigmoid(train_outputs) >= 0.5).float()
                        train_accuracy = (train_predictions == y).float().mean()
                        print(f'Training Accuracy: {train_accuracy:.4f}')

                        n_positive = int(train_predictions.sum().item())
                        n_total = train_predictions.size(0)
                        print(f'Predicting {n_positive} positive out of {n_total} ({n_positive/n_total*100:.2f}%)')
                    # Back to training mode for the remaining epochs
                    self.model.train()

        print("Training complete!")

    def predict(self, X):
        """Make predictions"""
        # Ensure X is properly formatted
        if isinstance(X, np.ndarray):
            X = self.scaler.transform(X)
            X = torch.FloatTensor(X).to(device)

        # Set model to evaluation mode
        self.model.eval()

        # Get predictions
        with torch.no_grad():
            outputs = self.model(X)
            predicted = torch.sigmoid(outputs)
            predicted = (predicted >= 0.5).float()

        return predicted.cpu().numpy()

    def predict_proba(self, X):
        """Predict probability estimates"""
        # Ensure X is properly formatted
        if isinstance(X, np.ndarray):
            X = self.scaler.transform(X)
            X = torch.FloatTensor(X).to(device)

        # Set model to evaluation mode
        self.model.eval()

        # Get probabilities
        with torch.no_grad():
            outputs = self.model(X)
            probs = torch.sigmoid(outputs)

        return probs.cpu().numpy()

    def evaluate(self, X_test, y_test):
        """
        Score the model on held-out samples

        Parameters:
        -----------
        X_test : numpy.ndarray or torch.Tensor
            Features; a NumPy array is scaled with the fitted scaler first,
            a tensor is used as it is
        y_test : numpy.ndarray, torch.Tensor or sequence
            True 0/1 labels in any shape that flattens to one value per sample

        Returns:
        --------
        dict
            'confusion_matrix': 2x2 array with rows = true class and
            columns = predicted class, in the order [0, 1]
            'classification_report': the report as a nested dict
        """
        if isinstance(X_test, np.ndarray):
            X_test = self.scaler.transform(X_test)
            X_test = torch.FloatTensor(X_test).to(device)

        y_pred = self.predict(X_test).flatten()

        # The labels are only needed on the host, so they are never sent to
        # the device. float32 matches what predict() returns, which keeps the
        # class keys of the report the same as before.
        if isinstance(y_test, torch.Tensor):
            y_true = y_test.cpu().numpy().flatten()
        else:
            y_true = np.asarray(y_test, dtype=np.float32).flatten()

        # Fixing the label set keeps the matrix 2x2 even when one class is
        # absent from both the truth and the predictions.
        conf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
        class_report = classification_report(y_true, y_pred, output_dict=True)

        return {
            'confusion_matrix': conf_matrix,
            'classification_report': class_report
        }


def main():
    """Download the play-by-play data, train the model and print test metrics"""
    dataset_dir = kagglehub.dataset_download("robbypeery/college-basketball-pbp-23-24")
    # Alternative file in the same dataset: '/ALLTourneyPBP2324.csv'
    csv_name = '/Colorado_pbp.csv'
    print("Loading data...")
    plays = pd.read_csv(dataset_dir + csv_name)

    print("Initializing model...")
    hyper_params = dict(
        num_previous_plays=5,
        momentum_threshold=6,  # a run is 6 or more unanswered points
        learning_rate=0.001,
        epochs=150,
        batch_size=64,
        class_weight=10.0,
        verbose=True,
    )
    model = BasketballMomentumSVM(**hyper_params)

    print("Preparing data with momentum shift labels...")
    X, y, play_indices = model.prepare_data(plays)

    print(f"Data shape: X = {X.shape}, y = {y.shape}")

    # scikit-learn splits NumPy arrays, so leave the device for a moment
    features_np = X.cpu().numpy()
    labels_np = y.cpu().numpy().flatten()
    split = train_test_split(features_np, labels_np, test_size=0.2, random_state=42, stratify=labels_np)
    X_train_np, X_test_np, y_train_np, y_test_np = split

    def as_features(array):
        return torch.FloatTensor(array).to(device)

    def as_labels(array):
        return torch.FloatTensor(array).view(-1, 1).to(device)

    X_train = as_features(X_train_np)
    y_train = as_labels(y_train_np)
    X_test = as_features(X_test_np)
    y_test = as_labels(y_test_np)

    print("Training model to predict momentum shift starts...")
    model.train(X_train, y_train)

    print("Evaluating model...")
    metrics = model.evaluate(X_test, y_test)

    print("\nConfusion Matrix:")
    print(metrics['confusion_matrix'])

    print("\nClassification Report:")
    for label, values in metrics['classification_report'].items():
        # Scalar entries such as the overall accuracy are not per-class rows
        if isinstance(values, dict):
            print(f"Class {label}:")
            print(f"  Precision: {values['precision']:.3f}")
            print(f"  Recall: {values['recall']:.3f}")
            print(f"  F1-Score: {values['f1-score']:.3f}")

    print("\nModel training complete!")
    print("This model now predicts whether a play will START a momentum shift!")

if __name__ == "__main__":
    main()

Using device: cuda
Loading data...
Initializing model...
Preparing data with momentum shift labels...
Found 201 momentum shifts (runs of 6+ points)
Total plays in runs: 2499
Percentage of plays in runs: 23.02%
Identified 201 momentum shift starting plays
Percentage of momentum shift starts: 1.85%
Total plays processed: 10688
Momentum shift starts detected: 194 out of 10688 plays (1.82%)
Class distribution: Counter({np.int64(0): 10494, np.int64(1): 194})
Data shape: X = torch.Size([10688, 33]), y = torch.Size([10688, 1])
Training model to predict momentum shift starts...
Epoch [1/150], Loss: 0.5202
Epoch [10/150], Loss: 0.2542
Epoch [20/150], Loss: 0.2455
Epoch [30/150], Loss: 0.2445
Training Accuracy: 0.9497
Predicting 469.0 positive out of 8550 (5.49%)
Epoch [40/150], Loss: 0.2387
Epoch [50/150], Loss: 0.2387
Epoch [60/150], Loss: 0.2366
Training Accuracy: 0.9309
Predicting 662.0 positive out of 8550 (7.74%)
Epoch [70/150], Loss: 0.2357
Epoch [80/150], Loss: 0.2389
Epoch [90/150], Los