In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import json
import os
from PIL import Image
import numpy as np

In [None]:
!pip install SoccerNet

In [None]:
# Mount Google Drive at /content/drive (Colab only).
# NOTE(review): no other cell reads from /content/drive — the data is downloaded and unzipped
# under /content — so this mount looks optional; confirm before removing it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Download the "jersey-2023" task archives (train and test splits) with the SoccerNet downloader.
# LocalDirectory is a relative path, so it resolves against the kernel's working directory.
# The unzip cell reads /content/path/to/SoccerNet/..., i.e. it assumes that working directory
# is /content (Colab's default — confirm when running elsewhere).
from SoccerNet.Downloader import SoccerNetDownloader as SNdl
mySNdl = SNdl(LocalDirectory="path/to/SoccerNet")
mySNdl.downloadDataTask(task="jersey-2023", split=["train","test"])

In [None]:
# Extract the zip file to drive
!unzip /content/path/to/SoccerNet/jersey-2023/test.zip -d /content
!unzip /content/path/to/SoccerNet/jersey-2023/train.zip -d /content

In [None]:
## Go through the ground-truth JSON files and figure out the unique classes (jersey numbers)

# Root of the extracted dataset: /content matches the unzip cell and main(). When running
# outside Colab, point it at the local checkout instead, e.g.
# "/home/sahilc/Sports-Analysis/Jersey Number Recognition".
GT_ROOT = "/content"
GT_FILES = (
    f"{GT_ROOT}/train/train_gt.json",
    f"{GT_ROOT}/test/test_gt.json",
)

classes = []
for gt_path in GT_FILES:
    with open(gt_path) as f:
        data = json.load(f)
    classes.extend(data.values())
    # After the first file this is the train-only class set; after the second it is the
    # train+test union (the same union create_datasets() turns into class indices).
    classes = sorted(set(classes))
    num_classes = len(classes)
    print(classes)
    print(num_classes)

In [None]:
class JerseyNumberDataset(Dataset):
    """Image dataset of player crops labelled with jersey numbers from a ground-truth JSON.

    Each JSON key is resolved relative to ``<root_dir>/images``. A key that names a
    directory contributes every ``.jpg``/``.jpeg``/``.png`` file inside it, all labelled
    with that key's number. A key that names an image file contributes just that file.
    Keys that resolve to neither are skipped.
    """

    # File suffixes (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, json_path, transform=None, is_train=True, number_mapping=None):
        """
        Args:
            root_dir (str): Directory containing the ``images`` folder.
            json_path (str): Path to the ground-truth JSON file, mapping a path under
                ``images`` (an image or a directory of images) to its jersey number.
            transform: Optional transform applied to each PIL image in ``__getitem__``.
            is_train (bool): Stored as ``self.is_train``. It is not used by this class;
                augmentation is controlled entirely by ``transform``.
            number_mapping (dict): Optional mapping from jersey numbers to class indices.
                When omitted, the numbers found here are indexed in sorted order.

        Raises:
            ValueError: If ``<root_dir>/images`` is missing, if no images are found, or if
                ``number_mapping`` lacks an entry for a jersey number present in the data.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.is_train = is_train

        # Load annotations
        print(f"Loading annotations from {json_path}")
        with open(json_path, 'r') as f:
            self.annotations = json.load(f)

        self.samples = []
        images_dir = os.path.join(root_dir, "images")

        print(f"Looking for images in {images_dir}")
        if not os.path.exists(images_dir):
            raise ValueError(f"Images directory not found at {images_dir}")

        # Numbers that actually have at least one image on disk.
        all_numbers = set()

        for img_path, number in self.annotations.items():
            player_dir = os.path.join(images_dir, img_path)
            if os.path.isdir(player_dir):
                # sorted() keeps sample order independent of the filesystem's listing order.
                for img_file in sorted(os.listdir(player_dir)):
                    if img_file.lower().endswith(self.IMAGE_EXTENSIONS):
                        all_numbers.add(number)
                        self.samples.append({
                            'img_path': os.path.join(player_dir, img_file),
                            'number': number
                        })
            elif os.path.isfile(player_dir) and player_dir.lower().endswith(self.IMAGE_EXTENSIONS):
                all_numbers.add(number)
                self.samples.append({
                    'img_path': player_dir,
                    'number': number
                })

        if number_mapping is None:
            self.number_to_idx = {num: idx for idx, num in enumerate(sorted(all_numbers))}
        else:
            # Fail here rather than with a KeyError deep inside a DataLoader worker.
            missing = all_numbers - set(number_mapping)
            if missing:
                raise ValueError(
                    f"number_mapping has no class index for jersey numbers: {sorted(missing)}"
                )
            self.number_to_idx = number_mapping

        print(f"Found {len(self.samples)} valid images")
        if len(self.samples) == 0:
            raise ValueError(f"No valid samples found in {images_dir}")

        # Print class distribution
        numbers = [sample['number'] for sample in self.samples]
        class_dist = {}
        for num in numbers:
            class_dist[num] = class_dist.get(num, 0) + 1
        print("\nClass distribution:")
        for num, count in sorted(class_dist.items()):
            print(f"Number {num}: {count} samples ({count/len(numbers)*100:.2f}%)")

        print("\nClass mapping:")
        for num, idx in sorted(self.number_to_idx.items()):
            print(f"Jersey number {num} -> Class index {idx}")

    def __len__(self):
        """Number of image samples discovered in ``__init__``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for sample ``idx``.

        If the image cannot be opened, the error is printed and the following samples are
        tried in turn, wrapping around the end. A few corrupt files therefore do not abort
        an epoch, and the number of attempts is bounded, so recursion depth is never a concern.

        Raises:
            RuntimeError: If none of the dataset's images can be opened.
        """
        n_samples = len(self.samples)
        for offset in range(n_samples):
            sample = self.samples[(idx + offset) % n_samples]
            try:
                # The context manager closes the file once the pixels are converted.
                with Image.open(sample['img_path']) as img:
                    image = img.convert('RGB')
            except Exception as e:
                print(f"Error loading image {sample['img_path']}: {str(e)}")
                continue

            if self.transform:
                image = self.transform(image)

            return image, self.number_to_idx[sample['number']]

        raise RuntimeError(f"Could not load any image (started at index {idx})")


In [None]:

def create_datasets(train_dir, test_dir):
    """Create train and validation datasets that share one jersey-number -> class-index mapping.

    The mapping is built from the sorted union of the numbers in ``train_gt.json`` and
    ``test_gt.json``, so both datasets (and the model's output layer) agree on class indices.

    NOTE(review): the test split doubles as the validation set, and train_model() keeps the
    checkpoint with the best accuracy on it, so that accuracy is optimistically biased.

    Args:
        train_dir (str): Directory containing ``train_gt.json`` and an ``images`` folder.
        test_dir (str): Directory containing ``test_gt.json`` and an ``images`` folder.

    Returns:
        tuple: ``(train_dataset, val_dataset, num_classes)``.
    """
    # Gather every jersey number that appears in either split.
    all_numbers = set()
    for split_dir, gt_name in ((train_dir, 'train_gt.json'), (test_dir, 'test_gt.json')):
        with open(os.path.join(split_dir, gt_name), 'r') as f:
            all_numbers.update(json.load(f).values())

    number_mapping = {num: idx for idx, num in enumerate(sorted(all_numbers))}

    print(f"\nTotal unique jersey numbers found: {len(number_mapping)}")

    # No RandomHorizontalFlip: mirroring a jersey turns digits into glyphs that are not
    # digits (a "12" becomes a mirrored "21"), which corrupts the label it is trained against.
    train_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create datasets with shared mapping
    train_dataset = JerseyNumberDataset(
        root_dir=train_dir,
        json_path=os.path.join(train_dir, 'train_gt.json'),
        transform=train_transform,
        number_mapping=number_mapping
    )

    val_dataset = JerseyNumberDataset(
        root_dir=test_dir,
        json_path=os.path.join(test_dir, 'test_gt.json'),
        transform=val_transform,
        number_mapping=number_mapping,
        is_train=False
    )

    return train_dataset, val_dataset, len(number_mapping)


In [None]:
class JerseyNumberNet(nn.Module):
    """ResNet-18 classifier with a 1x1-conv sigmoid attention gate before global pooling.

    The attention map (one channel, values in (0, 1)) rescales the ``layer4`` features
    spatially. The gated features are then average-pooled and passed through an MLP head.

    Args:
        num_classes (int): Number of output logits. The default of 101 reflects the original
            "100 numbers + 1 for no number" assumption; main() passes the class count found
            in the data instead.
        pretrained (bool): Load torchvision's pretrained ResNet-18 weights for the backbone
            (downloaded on first use). Pass False to build a randomly initialised network.
    """

    def __init__(self, num_classes=101, pretrained=True):  # 100 numbers + 1 for no number
        super(JerseyNumberNet, self).__init__()

        # Use ResNet18 as backbone.
        # conv1 is deliberately left as-is. The stock conv1 already accepts 3-channel inputs
        # of any spatial size, and replacing it with an identically-shaped fresh Conv2d only
        # discarded its pretrained weights (while bn1 kept pretrained statistics). The
        # parameter names and shapes are unchanged, so existing checkpoints still load.
        self.backbone = models.resnet18(pretrained=pretrained)

        # Replace the final FC layer with a small MLP head.
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

        # Spatial attention over the 512-channel layer4 feature map.
        self.attention = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for an image batch ``x``."""
        # Extract features
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        features = self.backbone.layer4(x)

        # Apply attention (broadcast over channels).
        att = self.attention(features)
        features = features * att

        # Global average pooling and classification
        x = F.adaptive_avg_pool2d(features, (1, 1))
        x = torch.flatten(x, 1)
        x = self.backbone.fc(x)

        return x


In [None]:

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=1, device='cuda',
                save_path='best_jersey_model.pth'):
    """Train ``model`` and checkpoint the weights that reach the best validation accuracy.

    Prints the mean training loss, training accuracy and validation accuracy once per epoch.

    Args:
        model (nn.Module): Classifier producing logits of shape ``(batch, num_classes)``.
        train_loader: Iterable of ``(images, labels)`` batches used for optimisation.
        val_loader: Iterable of ``(images, labels)`` batches used for model selection.
        criterion: Loss taking ``(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs (int): Number of passes over ``train_loader``.
        device (str | torch.device): Device the model and every batch are moved to.
        save_path (str): Where the ``state_dict`` is written whenever validation accuracy improves.

    Returns:
        float: Best validation accuracy in percent (0.0 if it never rose above 0).
    """
    model.to(device)
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        # Guards keep an empty loader from raising ZeroDivisionError.
        train_loss = running_loss / max(len(train_loader), 1)
        train_acc = 100. * correct / total if total else 0.0

        # Validation phase
        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                predicted = model(images).argmax(dim=1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

        val_acc = 100. * val_correct / val_total if val_total else 0.0

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Training Loss: {train_loss:.4f}')
        print(f'Training Accuracy: {train_acc:.2f}%')
        print(f'Validation Accuracy: {val_acc:.2f}%')

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), save_path)

    return best_val_acc


In [None]:

def main():
    """Build the datasets and loaders, then train a JerseyNumberNet with a class-weighted loss.

    Runs on CUDA when it is available and falls back to CPU otherwise. Any exception is
    printed together with its traceback instead of propagating, so the notebook keeps running.
    """
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Create datasets with consistent class mapping
        train_dataset, val_dataset, num_classes = create_datasets(
            train_dir='/content/train',
            test_dir='/content/test'
        )

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

        # Initialize model with correct number of classes
        print(f"\nInitializing model with {num_classes} classes")
        model = JerseyNumberNet(num_classes=num_classes)

        # Inverse-frequency class weights. Numbers that occur only in the test split have a
        # training count of 0. The epsilon keeps their weight finite, and with CrossEntropyLoss
        # a class weight only enters the loss for samples whose target is that class.
        class_indices = torch.tensor(
            [train_dataset.number_to_idx[sample['number']] for sample in train_dataset.samples],
            dtype=torch.long,
        )
        class_counts = torch.bincount(class_indices, minlength=num_classes).float()
        class_weights = 1.0 / (class_counts + 1e-6)
        class_weights = class_weights / class_weights.sum()
        criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        # Train the model on the same device the loss weights live on.
        train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=device)

    except Exception as e:
        print(f"Error during training: {str(e)}")
        import traceback
        traceback.print_exc()

In [None]:

# In a Jupyter/IPython kernel __name__ is "__main__", so executing this cell starts training.
if __name__ == "__main__":
    main()

In [None]:

## Test the model on a single image

# Local (non-Colab) paths used by this cell.
LOCAL_ROOT = '/home/sahilc/Sports-Analysis/Jersey Number Recognition'
checkpoint_path = os.path.join(LOCAL_ROOT, 'best_jersey_model (1).pth')
image_path = os.path.join(LOCAL_ROOT, 'train', 'images', '1', '1_15.jpg')

# The network predicts a class index, not a jersey number. Rebuild the index -> number
# mapping exactly as create_datasets() does: the sorted union of both ground-truth files.
jersey_numbers = set()
for split in ('train', 'test'):
    with open(os.path.join(LOCAL_ROOT, split, f'{split}_gt.json')) as f:
        jersey_numbers.update(json.load(f).values())
idx_to_number = dict(enumerate(sorted(jersey_numbers)))

# Load the model. The class count comes from the mapping (the original cell hard-coded 55),
# so a mismatched checkpoint fails loudly in load_state_dict. map_location lets a checkpoint
# saved during CUDA training load on a CPU-only machine.
model = JerseyNumberNet(num_classes=len(idx_to_number))
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
model.eval()

# Same preprocessing as the validation transform in create_datasets().
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = transform(Image.open(image_path).convert('RGB')).unsqueeze(0)

with torch.no_grad():
    output = model(image)
    predicted = output.argmax(dim=1)

predicted_idx = predicted.item()
print(f"Predicted class index: {predicted_idx}")
print(f"Predicted jersey number: {idx_to_number[predicted_idx]}")




In [None]:
## View the image to confirm the prediction

import matplotlib.pyplot as plt

# PIL's Image and numpy are already imported in the first cell. The context manager closes
# the file handle once the pixels are copied out. A separate name avoids overwriting `image`,
# which holds the preprocessed tensor from the previous cell.
with Image.open(image_path) as raw_image:
    pixels = np.asarray(raw_image)

fig, ax = plt.subplots()
ax.imshow(pixels)
ax.set_title(os.path.basename(image_path))
ax.axis('off')  # Turn off axis labels
plt.show()
