# Basketball OOB Detection

In [2]:
"""
OOB Detection with LSTM + MobileNet
4-Fold Cross-Validation with Jersey Color Metadata
"""

import os
import torch
import torch.nn as nn
import cv2
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import random
from torchvision import transforms, models
import warnings
warnings.filterwarnings('ignore')

RANDOM_SEED = 1
# Set all random seeds for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(RANDOM_SEED)

print("OOB Detection - 4-Fold Cross Validation")
print(f"Seed: {RANDOM_SEED}")

# Jersey color mapping for each clip
JERSEY_COLORS = {
    'clip 1': 'white',
    'clip 2': 'white',
    'clip 3': 'blue',
    'clip 4': 'blue',
    'clip 5': 'blue',
    'clip 6': 'blue',
    'clip 7': 'blue',
    'clip 8': 'blue',
    'clip 9': 'blue',
    'clip 10': 'black',
    'clip 11': 'black',
    'clip 12': 'black',
    'clip 13': 'black',
    'clip 14': 'blue',
    'clip 15': 'blue',
    'clip 16': 'blue',
    'clip 17': 'blue',
    'clip 18': 'white',
    'clip 19': 'black',
    'clip 20': 'black'
}

# Color to one-hot encoding
COLOR_ENCODING = {
    'white': [1, 0, 0],
    'blue': [0, 1, 0],
    'black': [0, 0, 1]
}

# Dataset with Jersey Color
class OOBDatasetWithColor(Dataset):
    """Dataset with jersey color metadata"""

    def __init__(self, video_paths, labels, jersey_colors, is_train=True, num_frames=8, augment=True):
        self.video_paths = video_paths
        self.labels = labels
        self.jersey_colors = jersey_colors
        self.num_frames = num_frames
        self.is_train = is_train
        self.augment = augment and is_train

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(p=0.5 if self.augment else 0),
            transforms.ColorJitter(brightness=0.2, contrast=0.2) if self.augment else transforms.Lambda(lambda x: x),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.video_paths) * (3 if self.augment else 1)

    def __getitem__(self, idx):
        actual_idx = idx % len(self.video_paths)
        aug_version = idx // len(self.video_paths)

        video_path = self.video_paths[actual_idx]
        label = self.labels[actual_idx]
        jersey_color = self.jersey_colors[actual_idx]

        frames = self.extract_frames(video_path, aug_version, seed=idx)
        color_vec = torch.tensor(COLOR_ENCODING[jersey_color], dtype=torch.float32)

        return {
            'frames': frames,
            'label': torch.tensor(label, dtype=torch.long),
            'color': color_vec,
            'video_name': Path(video_path).name
        }

    def extract_frames(self, video_path, aug_version=0, seed=0):
        """Extract frames with temporal augmentation"""
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if total_frames < self.num_frames:
            indices = np.array([i % total_frames for i in range(self.num_frames)])
        else:
            if self.augment and aug_version > 0:
                if aug_version == 1:
                    # Focus on first 2/3 of video
                    max_frame = int(total_frames * 0.66)
                    indices = np.linspace(0, max_frame, self.num_frames, dtype=int)
                else:  # aug_version == 2
                    # Focus on last 2/3 of video
                    min_frame = int(total_frames * 0.33)
                    indices = np.linspace(min_frame, total_frames-1, self.num_frames, dtype=int)
            else:
                # Standard: evenly spaced across entire video
                indices = np.linspace(0, total_frames-1, self.num_frames, dtype=int)

        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = self.transform(frame)
                frames.append(frame)

        cap.release()
        frames = torch.stack(frames) if frames else torch.zeros(self.num_frames, 3, 224, 224)
        return frames

# Model with Jersey Color Input
class OOBModelWithColor(nn.Module):
    """Model that uses both video frames and jersey color metadata"""

    def __init__(self, num_classes=2):
        super().__init__()

        mobilenet = models.mobilenet_v2(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(mobilenet.children())[:-1])

        for param in list(self.feature_extractor.parameters())[:-10]:
            param.requires_grad = False

        self.lstm = nn.LSTM(
            input_size=1280,
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            dropout=0.3
        )

        self.color_embedding = nn.Sequential(
            nn.Linear(3, 16),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(256 + 16, 128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )

    def forward(self, frames, color):
        batch_size, time_steps = frames.size(0), frames.size(1)

        x = frames.view(-1, frames.size(2), frames.size(3), frames.size(4))
        features = self.feature_extractor(x)
        features = features.mean(dim=[2, 3])
        features = features.view(batch_size, time_steps, -1)

        lstm_out, _ = self.lstm(features)
        final_features = lstm_out[:, -1, :]

        color_features = self.color_embedding(color)

        combined = torch.cat([final_features, color_features], dim=1)
        output = self.classifier(combined)
        return output

# Load and Prepare Data
def prepare_data():
    """Load all videos and prepare for cross-validation"""
    video_dir = Path('/content/')

    video_paths = []
    labels = []
    jersey_colors = []

    for video_file in sorted(video_dir.glob("clip*.mp4")):
        filename = video_file.stem
        clip_num = filename.split(' - ')[0].lower()

        if "away" in filename.lower():
            label = 0
        elif "home" in filename.lower():
            label = 1
        else:
            continue

        jersey_color = JERSEY_COLORS.get(clip_num, 'blue')

        video_paths.append(str(video_file))
        labels.append(label)
        jersey_colors.append(jersey_color)

    print(f"Found {len(video_paths)} videos")
    print(f"  Home (1): {sum(labels)}, Away (0): {len(labels) - sum(labels)}")
    print(f"  White jerseys: {jersey_colors.count('white')}")
    print(f"  Blue jerseys: {jersey_colors.count('blue')}")
    print(f"  Black jerseys: {jersey_colors.count('black')}")

    return video_paths, labels, jersey_colors

# Training Function
def train_fold(train_loader, val_loader, fold_num, epochs=5):
    """Train one fold"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = OOBModelWithColor().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.5)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = 0
    best_predictions = []

    print(f"\n{'-'*60}")
    print(f"Training Fold {fold_num}")
    print(f"{'-'*60}")

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        for batch in train_loader:
            frames = batch['frames'].to(device)
            labels = batch['label'].to(device)
            colors = batch['color'].to(device)

            optimizer.zero_grad()
            outputs = model(frames, colors)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        val_predictions = []

        with torch.no_grad():
            for batch in val_loader:
                frames = batch['frames'].to(device)
                labels = batch['label'].to(device)
                colors = batch['color'].to(device)

                outputs = model(frames, colors)
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

                for i in range(len(predicted)):
                    val_predictions.append({
                        'video': batch['video_name'][i],
                        'pred': predicted[i].item(),
                        'true': labels[i].item(),
                        'correct': predicted[i].item() == labels[i].item()
                    })

        train_acc = 100. * train_correct / train_total
        val_acc = 100. * val_correct / val_total
        avg_loss = train_loss / len(train_loader)

        print(f"Epoch {epoch+1}/{epochs}: Train Loss={avg_loss:.3f}, Train Acc={train_acc:.1f}%, Val Acc={val_acc:.1f}%")

        scheduler.step(avg_loss)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_predictions = val_predictions.copy()

    print(f"\nFold {fold_num} Best Predictions (Best Val Acc: {best_val_acc:.1f}%):")
    for pred in best_predictions:
        status = "Correct" if pred['correct'] else "False"
        pred_label = "Home" if pred['pred'] == 1 else "Away"
        true_label = "Home" if pred['true'] == 1 else "Away"
        print(f"  {status} {pred['video'][:30]:30} | Pred: {pred_label:4} | True: {true_label:4}")

    return best_val_acc, best_predictions

# 4-Fold Cross Validation
def cross_validate():
    """Perform 4-fold cross-validation"""

    video_paths, labels, jersey_colors = prepare_data()

    if len(video_paths) != 20:
        print(f"Expected 20 videos, found {len(video_paths)}")

    n_folds = 4
    fold_size = len(video_paths) // n_folds

    indices = list(range(len(video_paths)))
    random.seed(RANDOM_SEED)
    random.shuffle(indices)

    fold_results = []
    all_predictions = []

    print(f"\nCreating {n_folds} folds with {fold_size} videos each")

    for fold in range(n_folds):
        set_seed(RANDOM_SEED + fold)

        val_start = fold * fold_size
        val_end = val_start + fold_size
        val_indices = indices[val_start:val_end]
        train_indices = [i for i in indices if i not in val_indices]

        train_paths = [video_paths[i] for i in train_indices]
        train_labels = [labels[i] for i in train_indices]
        train_colors = [jersey_colors[i] for i in train_indices]

        val_paths = [video_paths[i] for i in val_indices]
        val_labels = [labels[i] for i in val_indices]
        val_colors = [jersey_colors[i] for i in val_indices]

        print(f"\nFold {fold+1}: Train on {len(train_paths)} videos, Validate on {len(val_paths)} videos")
        print(f"Val videos: {[Path(p).name[:15] for p in val_paths]}")

        train_dataset = OOBDatasetWithColor(train_paths, train_labels, train_colors,
                                           is_train=True, augment=True)
        val_dataset = OOBDatasetWithColor(val_paths, val_labels, val_colors,
                                         is_train=False, augment=False)

        train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

        val_acc, predictions = train_fold(train_loader, val_loader, fold+1, epochs=5)

        fold_results.append(val_acc)
        all_predictions.extend(predictions)

    print(f"\n{'-'*60}")
    print("CROSS-VALIDATION RESULTS")
    print(f"{'-'*60}")

    for i, acc in enumerate(fold_results):
        print(f"Fold {i+1}: {acc:.1f}%")

    mean_acc = np.mean(fold_results)
    std_acc = np.std(fold_results)

    overall_correct = sum(p['correct'] for p in all_predictions)
    overall_acc = 100 * overall_correct / len(all_predictions)
    print(f"Overall Accuracy: {overall_acc:.1f}% ({overall_correct}/{len(all_predictions)})")

    home_preds = [p for p in all_predictions if p['true'] == 1]
    away_preds = [p for p in all_predictions if p['true'] == 0]

    home_acc = 100 * sum(p['correct'] for p in home_preds) / len(home_preds) if home_preds else 0
    away_acc = 100 * sum(p['correct'] for p in away_preds) / len(away_preds) if away_preds else 0

    print(f"\nPerformance by Label:")
    print(f"  Home team ball: {home_acc:.1f}% ({sum(p['correct'] for p in home_preds)}/{len(home_preds)})")
    print(f"  Away team ball: {away_acc:.1f}% ({sum(p['correct'] for p in away_preds)}/{len(away_preds)})")

    return fold_results, all_predictions

# Run Cross-Validation
if __name__ == "__main__":
    if len(list(Path('/content/').glob('clip*.mp4'))) < 20:
        print("Please upload all 20 video clips first!")
        print("Use the Files sidebar to upload your clips to /content/")
    else:
        fold_results, predictions = cross_validate()


OOB Detection - 4-Fold Cross Validation
Seed: 1
Found 20 videos
  Home (1): 7, Away (0): 13
  White jerseys: 3
  Blue jerseys: 11
  Black jerseys: 6

Creating 4 folds with 5 videos each

Fold 1: Train on 15 videos, Validate on 5 videos
Val videos: ['clip 2 - Away t', 'clip 14 - Away ', 'clip 7 - Home t', 'clip 9 - Home t', 'clip 18 - Home ']
Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth


100%|██████████| 13.6M/13.6M [00:00<00:00, 66.9MB/s]



------------------------------------------------------------
Training Fold 1
------------------------------------------------------------
Epoch 1/5: Train Loss=0.589, Train Acc=66.7%, Val Acc=40.0%
Epoch 2/5: Train Loss=0.525, Train Acc=73.3%, Val Acc=40.0%
Epoch 3/5: Train Loss=0.292, Train Acc=86.7%, Val Acc=40.0%
Epoch 4/5: Train Loss=0.305, Train Acc=93.3%, Val Acc=80.0%
Epoch 5/5: Train Loss=0.442, Train Acc=82.2%, Val Acc=60.0%

Fold 1 Best Predictions (Best Val Acc: 80.0%):
  Correct clip 2 - Away team ball.mp4    | Pred: Away | True: Away
  False clip 14 - Away team ball.mp4   | Pred: Home | True: Away
  Correct clip 7 - Home team ball.mp4    | Pred: Home | True: Home
  Correct clip 9 - Home team ball.mp4    | Pred: Home | True: Home
  Correct clip 18 - Home team ball.mp4   | Pred: Home | True: Home

Fold 2: Train on 15 videos, Validate on 5 videos
Val videos: ['clip 1 - Away t', 'clip 6 - Away t', 'clip 10 - Away ', 'clip 5 - Away t', 'clip 15 - Away ']

---------------------