In [None]:
!pip install torch opencv-python pandas
!pip install kaggle

In [None]:
# Download the SportsMOT archive from Kaggle, then extract it into ./SportsMOT.
# NOTE(review): the kaggle CLI presumably needs API credentials configured
# before this cell can run — confirm for your environment.
!kaggle datasets download -d ayushspai/sportsmot
!unzip sportsmot.zip -d SportsMOT

In [None]:
import os
import cv2
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize

import matplotlib.pyplot as plt
import matplotlib.patches as patches



## Preprocessing

In [None]:
def xywh_to_x1y1x2y2(boxes):
    """Convert boxes from (x, y, w, h) to corner format (x1, y1, x2, y2).

    Args:
        boxes (torch.Tensor): Tensor whose last dimension starts with
            ``x, y, w, h`` (a corner plus width/height), e.g. shape
            ``(N, 4)``. A single box of shape ``(4,)`` and inputs with
            extra leading dimensions are accepted as well.

    Returns:
        torch.Tensor: Tensor of shape ``(..., 4)`` holding
        ``x1, y1, x2, y2`` with ``x2 = x + w`` and ``y2 = y + h``.
    """
    # Index the last axis so any number of leading dimensions is supported.
    x1 = boxes[..., 0]
    y1 = boxes[..., 1]
    x2 = x1 + boxes[..., 2]
    y2 = y1 + boxes[..., 3]

    return torch.stack([x1, y1, x2, y2], dim=-1)
def filter_boxes(boxes, img_width=1280, img_height=720):
    """Keep only the boxes that lie completely inside the image.

    Args:
        boxes: 2-D array/tensor whose first four columns are x, y, w, h.
        img_width: Largest allowed value of ``x + w``.
        img_height: Largest allowed value of ``y + h``.

    Returns:
        The rows of ``boxes`` with ``x >= 0``, ``y >= 0``,
        ``x + w <= img_width`` and ``y + h <= img_height``.
    """
    left, top = boxes[:, 0], boxes[:, 1]
    right = left + boxes[:, 2]
    bottom = top + boxes[:, 3]

    starts_inside = (left >= 0) & (top >= 0)
    ends_inside = (right <= img_width) & (bottom <= img_height)
    return boxes[starts_inside & ends_inside]

class BallEnhancer:
    """Callable that emphasises ball regions in a frame.

    When at least one annotation has the ball class id, a copy of the frame
    with boosted contrast/brightness inside the ball boxes is blended back
    onto the frame (70% original, 30% enhanced). Frames without a ball are
    returned unchanged.

    Args:
        ball_class_id (int): Class id treated as "ball". Defaults to 3,
            which is an assumption carried over from the original code.
        gain (float): Multiplicative contrast factor applied inside the
            ball boxes.
        bias (float): Additive brightness offset applied inside the ball
            boxes.
    """

    def __init__(self, ball_class_id=3, gain=1.5, bias=20.0):
        self.ball_class_id = ball_class_id
        self.gain = gain
        self.bias = bias

    def __call__(self, image, boxes, classes):
        """Return ``image`` with ball regions emphasised.

        Args:
            image: Frame to process; presumably an HxWxC uint8 array as
                produced by cv2 — confirm against the caller.
            boxes: Per-object boxes in (x, y, w, h) order, indexable by a
                boolean mask built from ``classes``.
            classes: Per-object class ids (array/tensor supporting ``==``
                and ``.any()``).

        Returns:
            The blended frame if a ball is present, otherwise ``image``
            itself (same object).
        """
        ball_mask = (classes == self.ball_class_id)  # Assuming class 3 is ball
        if ball_mask.any():
            ball_img = self._enhance_ball(image, boxes[ball_mask])
            image = cv2.addWeighted(image, 0.7, ball_img, 0.3, 0)
        return image

    def _enhance_ball(self, image, ball_boxes):
        """Return a copy of ``image`` with the ball boxes contrast-boosted.

        Each box is clipped to the image bounds; boxes that end up empty are
        skipped. The input image is not modified.
        """
        enhanced = image.copy()
        img_h, img_w = image.shape[:2]

        for box in ball_boxes:
            x, y, w, h = (float(v) for v in box[:4])
            x1 = max(int(round(x)), 0)
            y1 = max(int(round(y)), 0)
            x2 = min(int(round(x + w)), img_w)
            y2 = min(int(round(y + h)), img_h)
            if x2 <= x1 or y2 <= y1:
                continue  # box lies outside the frame or has no area

            roi = enhanced[y1:y2, x1:x2].astype("float32")
            # Clipping to 0..255 assumes an 8-bit image — TODO confirm.
            boosted = (roi * self.gain + self.bias).clip(0, 255)
            enhanced[y1:y2, x1:x2] = boosted.astype(image.dtype)

        return enhanced

## Dataset

In [None]:
class SportsMOTDataset(Dataset):
    """Per-frame dataset for a single sequence stored in MOT-style layout.

    Images are read from ``<root_dir>/<sequence>/img1/<frame_id:06d>.jpg``
    and annotations from the header-less, comma-separated file
    ``<root_dir>/<sequence>/gt/gt.txt``. Only frames that have at least one
    row in ``gt.txt`` are indexed.
    """

    # Column names assigned, in order, to the header-less gt.txt file.
    GT_COLUMNS = [
        "frame_id", "object_id", "x", "y", "w", "h",
        "conf", "class_id", "visibility"
    ]

    def __init__(self, root_dir, sequence="train/seq1", transform=None):
        """
        Args:
            root_dir (str): Root directory of dataset
            sequence (str): Subdirectory path (e.g., "train/seq1")
            transform (callable, optional): Applied to the RGB image in
                place of the default ToTensor + Normalize pipeline.
        """
        self.sequence_path = os.path.join(root_dir, sequence)
        self.img_dir = os.path.join(self.sequence_path, "img1")
        self.gt_file = os.path.join(self.sequence_path, "gt", "gt.txt")
        self.transform = transform

        # Load annotations and index the frames that have any
        self.annotations = self._load_annotations()
        self.frame_ids = sorted(self.annotations)

        # Default transforms, used when no transform is supplied
        self.default_transform = Compose([
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _load_annotations(self):
        """Load annotations into ``{frame_id: {...}}`` format.

        Returns:
            dict: Maps each integer frame id to a dict with ``"boxes"``
            (float32 array of x, y, w, h rows), ``"object_ids"`` (int32)
            and ``"classes"`` (int32).

        Raises:
            ValueError: If the file does not have exactly the expected
                number of columns.
        """
        df = pd.read_csv(self.gt_file, header=None)
        if df.shape[1] != len(self.GT_COLUMNS):
            raise ValueError(
                f"{self.gt_file}: expected {len(self.GT_COLUMNS)} columns "
                f"({', '.join(self.GT_COLUMNS)}), found {df.shape[1]}"
            )
        df.columns = self.GT_COLUMNS

        annotations = {}
        for frame_id, group in df.groupby("frame_id"):
            # Store plain ints so frame ids format and serialise predictably.
            annotations[int(frame_id)] = {
                "boxes": group[["x", "y", "w", "h"]].values.astype("float32"),
                "object_ids": group["object_id"].values.astype("int32"),
                "classes": group["class_id"].values.astype("int32")
            }
        return annotations

    def __len__(self):
        """Number of annotated frames in the sequence."""
        return len(self.frame_ids)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th annotated frame.

        Returns:
            tuple: The transformed image and a dict with ``"boxes"``,
            ``"object_ids"`` and ``"classes"`` tensors plus the integer
            ``"frame_id"``.

        Raises:
            FileNotFoundError: If the frame image cannot be read.
        """
        frame_id = self.frame_ids[idx]
        img_path = os.path.join(self.img_dir, f"{frame_id:06d}.jpg")

        # cv2.imread signals a missing/unreadable file by returning None
        # rather than raising, so check before converting colours.
        bgr = cv2.imread(img_path)
        if bgr is None:
            raise FileNotFoundError(
                f"Could not read image for frame {frame_id}: {img_path}"
            )
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Get annotations
        ann = self.annotations[frame_id]

        # Apply the user transform if given, otherwise the default one
        transform = self.transform if self.transform is not None else self.default_transform
        image = transform(image)

        return image, {
            "boxes": torch.tensor(ann["boxes"]),
            "object_ids": torch.tensor(ann["object_ids"]),
            "classes": torch.tensor(ann["classes"]),
            "frame_id": frame_id
        }

## Data Loader

In [None]:
def create_dataloader(root_dir, sequence, batch_size=4, shuffle=False,
                      num_workers=4, transform=None):
    """Build a DataLoader over one SportsMOT sequence.

    Args:
        root_dir (str): Root directory of the dataset.
        sequence (str): Sequence sub-directory (e.g. "train/seq1").
        batch_size (int): Number of frames per batch.
        shuffle (bool): Whether to shuffle frames each epoch.
        num_workers (int): Worker processes used for loading. Pass 0 to
            load in the main process, e.g. when debugging or where worker
            processes cannot import notebook-defined functions.
        transform (callable, optional): Image transform; defaults to
            ToTensor followed by ImageNet Normalize.

    Returns:
        DataLoader: Yields ``(images, targets)`` batches as assembled by
        ``custom_collate``.
    """
    if transform is None:
        transform = Compose([
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    dataset = SportsMOTDataset(
        root_dir=root_dir,
        sequence=sequence,
        transform=transform
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=custom_collate,
        num_workers=num_workers
    )

def custom_collate(batch):
    """Collate ``(image, target)`` samples whose targets differ in size.

    Images are stacked into a single tensor along a new leading dimension;
    targets are kept as a plain list (one entry per sample) because the
    number of objects varies between frames.
    """
    frames = []
    targets = []
    for sample in batch:
        frames.append(sample[0])
        targets.append(sample[1])
    return torch.stack(frames), targets

In [None]:
if __name__ == "__main__":
    # Stand-alone dataset instance for the sequence.
    # NOTE(review): "/path/to/SportsMOT" looks like a placeholder — point it
    # at the real dataset root before running.
    dataset = SportsMOTDataset("/path/to/SportsMOT", "train/seq1")

    # Batched loader over the same sequence
    dataloader = create_dataloader(
        "/path/to/SportsMOT", "train/seq1", batch_size=4, shuffle=False
    )

    # Print a short summary of batches 0, 1 and 2, then stop
    for batch_idx, (images, targets) in enumerate(dataloader):
        print(
            f"Batch {batch_idx}:",
            f"Images shape: {images.shape}",
            f"Number of targets: {len(targets)}",
            "-" * 50,
            sep="\n",
        )

        if batch_idx == 2:
            break

In [None]:
from torchvision.transforms import RandomHorizontalFlip, ColorJitter
from torch.utils.data import random_split

# Image-only augmentation pipeline.
# RandomHorizontalFlip is deliberately NOT applied: the dataset transform
# receives only the image, so a flip would leave the (x, y, w, h) boxes
# pointing at the un-flipped positions. Add it back only together with a
# transform that flips the boxes as well.
train_transform = Compose([
    ToTensor(),
    ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = SportsMOTDataset(
    root_dir="/path/to/SportsMOT",
    sequence="train/seq1",
    transform=train_transform
)

# Split dataset 80/20; the fixed generator seed makes the split reproducible
# across kernel restarts.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

# Create DataLoaders
# NOTE(review): both subsets wrap the same dataset object, so validation
# frames also go through train_transform (including ColorJitter) — confirm
# whether that is intended.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=custom_collate)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=custom_collate)

## Visualization

In [None]:
def visualize_sample(image, boxes, object_ids, mean=None, std=None):
    """Visualize a single frame with bounding boxes and object IDs.

    Args:
        image (torch.Tensor): Image in (C, H, W) layout.
        boxes: Iterable of boxes in (x, y, w, h) order.
        object_ids: Iterable of object ids, one per box.
        mean (sequence of float, optional): Per-channel mean used when the
            image was normalised. If given together with ``std``, the
            normalisation is undone before plotting.
        std (sequence of float, optional): Per-channel std used when the
            image was normalised.
    """
    display = image
    if mean is not None and std is not None:
        # Undo Normalize so imshow receives values in its valid [0, 1] range.
        mean_t = torch.as_tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)
        display = (image * std_t + mean_t).clamp(0, 1)

    fig, ax = plt.subplots(1)
    ax.imshow(display.permute(1, 2, 0))  # Convert from (C, H, W) to (H, W, C)

    for box, obj_id in zip(boxes, object_ids):
        # Plain Python numbers keep matplotlib away from 0-d tensors and make
        # the label read "ID: 3" instead of "ID: tensor(3, dtype=torch.int32)".
        x, y, w, h = (float(v) for v in box)
        label_id = obj_id.item() if hasattr(obj_id, "item") else obj_id
        rect = patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        ax.text(x, y, f"ID: {label_id}", color='r', fontsize=8, backgroundcolor='white')

    plt.show()

# Load a sample
sample_image, sample_targets = dataset[0]

# The dataset transform ends with Normalize(mean, std), so undo it before
# plotting; otherwise imshow clips the values and the colours are wrong.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
sample_display = (sample_image * imagenet_std + imagenet_mean).clamp(0, 1)

visualize_sample(sample_display, sample_targets["boxes"], sample_targets["object_ids"])