In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import cv2
import numpy as np
import os
import glob
from tqdm import tqdm

# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")

Using device: cpu


In [28]:
# Experiment configuration -- every tunable value lives in this one cell.
video_dir = './SportsData/'   # directory that is scanned for video clips
frame_size = (112, 112)       # target size handed to cv2.resize; small for speed
sequence_length = 8           # number of sampled frames per clip
frame_skip = 3                # keep one frame out of every `frame_skip` decoded
batch_size = 4                # clips per mini-batch
num_epochs = 10               # passes over the training split
learning_rate = 1e-3          # step size given to the optimizer


In [29]:
class SimpleBasketballDataset(Dataset):
    """Video-clip dataset whose labels are derived from the file names.

    Every ``*.mp4`` / ``*.avi`` / ``*.mov`` file directly inside ``video_dir``
    is inspected: a (lower-cased) file name containing ``'hit'`` gets label 1,
    one containing ``'miss'`` gets label 0, and any other file is ignored.

    Args:
        video_dir: Directory that holds the video files.
        frame_size: ``(width, height)`` tuple passed to ``cv2.resize``.
        sequence_length: Number of frames returned per clip (must be >= 1).
        frame_skip: Keep one frame out of every ``frame_skip`` decoded frames
            (must be >= 1).

    Raises:
        ValueError: If ``sequence_length`` or ``frame_skip`` is smaller than 1.
    """

    def __init__(self, video_dir, frame_size, sequence_length, frame_skip):
        # Fail fast: frame_skip == 0 would otherwise surface as a
        # ZeroDivisionError only when the first clip is loaded.
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        if frame_skip < 1:
            raise ValueError(f"frame_skip must be >= 1, got {frame_skip}")

        self.video_dir = video_dir
        self.frame_size = frame_size
        self.sequence_length = sequence_length
        self.frame_skip = frame_skip

        # Collect candidate files. glob returns entries in arbitrary,
        # filesystem-dependent order, so sort to make sample indices stable
        # from run to run.
        candidates = []
        for ext in ['*.mp4', '*.avi', '*.mov']:
            candidates.extend(glob.glob(os.path.join(video_dir, ext)))
        candidates.sort()

        # Keep only the files whose name reveals the label; video_files and
        # labels are built together so they always stay index-aligned.
        self.video_files = []
        self.labels = []
        for video_path in candidates:
            filename = os.path.basename(video_path).lower()
            if 'hit' in filename:
                label = 1
            elif 'miss' in filename:
                label = 0
            else:
                continue
            self.video_files.append(video_path)
            self.labels.append(label)

        print(f"Found {len(self.video_files)} videos")
        print(f"Hit: {sum(self.labels)}, Miss: {len(self.labels) - sum(self.labels)}")

    def __len__(self):
        """Return the number of labelled videos."""
        return len(self.video_files)

    def __getitem__(self, idx):
        """Return ``(frames, label)`` for the video at position ``idx``."""
        video_path = self.video_files[idx]
        label = self.labels[idx]
        frames = self.extract_frames(video_path)
        return frames, label

    def extract_frames(self, video_path):
        """Decode a video into a fixed-length float tensor.

        Every ``frame_skip``-th frame (starting with the first) is resized to
        ``frame_size`` and converted from BGR to RGB until ``sequence_length``
        frames have been collected. Shorter videos are padded by repeating the
        last collected frame; if no frame could be read at all, the clip is
        filled with black frames.

        Returns:
            A ``torch.FloatTensor`` of shape ``(3, sequence_length, H, W)``
            with values scaled to ``[0, 1]``.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        frame_count = 0

        try:
            # Stop decoding as soon as enough frames are collected: only the
            # first `sequence_length` sampled frames are ever used.
            while len(frames) < self.sequence_length:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % self.frame_skip == 0:
                    frame = cv2.resize(frame, self.frame_size)
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(frame)

                frame_count += 1
        finally:
            # Always free the capture handle, even if resizing/conversion fails.
            cap.release()

        if not frames:
            # cv2.resize takes (width, height) but yields arrays shaped
            # (height, width, 3); the placeholder must use that same layout.
            width, height = self.frame_size
            frames.append(np.zeros((height, width, 3), dtype=np.uint8))

        # Pad short clips by repeating the last available frame.
        while len(frames) < self.sequence_length:
            frames.append(frames[-1])

        # (T, H, W, C) in [0, 255]  ->  (C, T, H, W) in [0, 1]
        clip = np.array(frames, dtype=np.float32) / 255.0
        clip = np.transpose(clip, (3, 0, 1, 2))

        return torch.FloatTensor(clip)

In [30]:
# Create dataset
dataset = SimpleBasketballDataset(video_dir, frame_size, sequence_length, frame_skip)

total_size = len(dataset)
if total_size == 0:
    # Fail here with a clear message instead of deep inside the DataLoader sampler.
    raise ValueError(f"No labelled videos found in {video_dir!r}")

# Simple train/test split (70/30)
train_size = int(0.7 * total_size)
test_size = total_size - train_size

# Seed the split so the same permutation is drawn on every run and the
# reported test accuracy is comparable across runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator
)

print(f"Training samples: {len(train_dataset)}")
print(f"Testing samples: {len(test_dataset)}")

# Shuffle only the training batches; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Found 72 videos
Hit: 56, Miss: 16
Training samples: 50
Testing samples: 22


In [31]:
class Simple3DCNN(nn.Module):
    """Small 3D-convolutional classifier producing two logits (Hit / Miss).

    Three Conv3d + ReLU + MaxPool3d stages are followed by two fully
    connected layers. The first pooling stage halves only the spatial
    dimensions; the second and third also halve the temporal dimension.

    Args:
        input_shape: ``(channels, frames, height, width)`` of one clip,
            without the batch dimension. It determines the number of input
            channels of the first convolution and the size of the first
            fully connected layer.
    """

    def __init__(self, input_shape=(3, 8, 112, 112)):
        super().__init__()

        # Take the channel count from input_shape rather than hard-coding 3,
        # so the stated input shape and the first layer cannot disagree.
        in_channels = input_shape[0]

        # 3D convolutions
        self.conv1 = nn.Conv3d(in_channels, 32, kernel_size=(3, 3, 3), padding=1)
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2))

        self.conv2 = nn.Conv3d(32, 64, kernel_size=(3, 3, 3), padding=1)
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv3 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=1)
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        # Flattened feature count after the conv stack, measured on a dummy input
        self.feature_size = self._get_conv_output_size(input_shape)

        # Fully connected layers
        self.fc1 = nn.Linear(self.feature_size, 256)
        self.fc2 = nn.Linear(256, 2)  # 2 classes: Hit or Miss

    def _extract_features(self, x):
        """Run the conv/pool stack; shared by forward() and the size probe."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        return x

    def _get_conv_output_size(self, input_shape):
        """Return the flattened feature count for one sample of ``input_shape``."""
        dummy_input = torch.zeros(1, *input_shape)
        with torch.no_grad():
            features = self._extract_features(dummy_input)
        return features.view(1, -1).size(1)

    def forward(self, x):
        """Map a batch shaped ``(N, C, T, H, W)`` to logits shaped ``(N, 2)``."""
        x = self._extract_features(x)

        # Flatten everything except the batch dimension, then classify
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x


In [32]:
# Assemble the training objects: the network (sized from the configured clip
# shape and moved to the selected device), the loss and the optimizer.
input_shape = (3, sequence_length, *frame_size[:2])
model = Simple3DCNN(input_shape)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [33]:
# Training loop: one optimisation pass over train_loader per epoch, followed
# by a one-line summary of that epoch's mean loss and training accuracy.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0  # sum of per-sample losses seen this epoch
    correct = 0
    total = 0

    for data, targets in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
        data, targets = data.to(device), targets.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, targets)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        n_samples = targets.size(0)
        # `criterion` returns the batch mean; scale it back to a sum so the
        # shorter final batch is not over-weighted in the epoch average.
        total_loss += loss.item() * n_samples
        predicted = outputs.argmax(dim=1)
        total += n_samples
        correct += (predicted == targets).sum().item()

    # Print epoch results (max(..., 1) avoids dividing by zero on an empty loader)
    avg_loss = total_loss / max(total, 1)
    accuracy = 100. * correct / max(total, 1)
    print(f'Epoch {epoch+1}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.2f}%')


Epoch 1: 100%|██████████| 13/13 [00:24<00:00,  1.86s/it]


Epoch 1: Loss = 0.7958, Accuracy = 68.00%


Epoch 2: 100%|██████████| 13/13 [00:23<00:00,  1.83s/it]


Epoch 2: Loss = 0.5467, Accuracy = 76.00%


Epoch 3: 100%|██████████| 13/13 [00:23<00:00,  1.82s/it]


Epoch 3: Loss = 0.5380, Accuracy = 76.00%


Epoch 4: 100%|██████████| 13/13 [00:23<00:00,  1.81s/it]


Epoch 4: Loss = 0.5390, Accuracy = 76.00%


Epoch 5: 100%|██████████| 13/13 [00:23<00:00,  1.81s/it]


Epoch 5: Loss = 0.4687, Accuracy = 78.00%


Epoch 6: 100%|██████████| 13/13 [00:23<00:00,  1.81s/it]


Epoch 6: Loss = 0.4339, Accuracy = 80.00%


Epoch 7: 100%|██████████| 13/13 [00:23<00:00,  1.81s/it]


Epoch 7: Loss = 0.4962, Accuracy = 82.00%


Epoch 8: 100%|██████████| 13/13 [00:23<00:00,  1.84s/it]


Epoch 8: Loss = 0.3327, Accuracy = 88.00%


Epoch 9: 100%|██████████| 13/13 [00:25<00:00,  1.96s/it]


Epoch 9: Loss = 0.3594, Accuracy = 84.00%


Epoch 10: 100%|██████████| 13/13 [00:25<00:00,  1.93s/it]

Epoch 10: Loss = 0.5718, Accuracy = 78.00%





In [37]:
# Evaluate on the held-out split: no gradients are tracked, and every
# prediction/target pair is kept for later inspection.
model.eval()
test_correct = test_total = 0
all_predictions, all_targets = [], []

with torch.no_grad():
    for data, targets in tqdm(test_loader):
        data = data.to(device)
        targets = targets.to(device)

        outputs = model(data)
        # Index of the larger logit is the predicted class
        predicted = outputs.argmax(dim=1)

        test_total += targets.size(0)
        test_correct += predicted.eq(targets).sum().item()

        all_predictions.extend(predicted.cpu().numpy())
        all_targets.extend(targets.cpu().numpy())

test_accuracy = 100. * test_correct / test_total
print(f'Test Accuracy: {test_accuracy:.2f}%')

100%|██████████| 6/6 [00:06<00:00,  1.04s/it]

Test Accuracy: 81.82%



