# Baseball Pitch Detection - Assignment 3

1. **Custom data loader** for baseball videos
2. **PyTorch neural network** (Faster R-CNN) training
3. **Model save/load** functionality for evaluation


In [None]:
import os
import cv2
import torch
import torchvision
import numpy as np
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
from collections import defaultdict


## Custom Data Loader


In [None]:
def video_to_tensor(video_path, resize=(640, 480), frame_skip=1):
    """Convert video file to PyTorch tensor."""
    print(f"Loading: {os.path.basename(video_path)}")
    cap = cv2.VideoCapture(video_path)
    frames = []

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if resize:
            frame = cv2.resize(frame, resize)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)

    cap.release()

    if not frames:
        raise ValueError(f"No frames in {video_path}")

    frames = np.stack(frames)[::frame_skip]
    frames = torch.from_numpy(frames).float().permute(0, 3, 1, 2) / 255.0
    print(f"  -> {frames.shape[0]} frames loaded\\n")
    return frames


def parse_cvat_xml(xml_path, frame_skip=1):
    """Parse CVAT XML annotation file."""
    tree = ET.parse(xml_path)
    root = tree.getroot()
    annotations = defaultdict(list)

    for track in root.findall("track"):
        label = track.attrib["label"]
        for box in track.findall("box"):
            frame = int(box.attrib["frame"])
            outside = int(box.attrib["outside"])

            if outside != 0:
                continue

            xtl = float(box.attrib["xtl"])
            ytl = float(box.attrib["ytl"])
            xbr = float(box.attrib["xbr"])
            ybr = float(box.attrib["ybr"])

            moving_attr = box.find("attribute[@name='moving']")
            if moving_attr is None:
                continue

            moving_flag = 1 if moving_attr.text.lower() == "true" else 0

            if frame % frame_skip == 0:
                adjusted_frame = frame // frame_skip
                annotations[adjusted_frame].append({
                    "label": label,
                    "bbox": [xtl, ytl, xbr, ybr],
                    "moving": moving_flag
                })

    return annotations


class BaseballVideoDataset(Dataset):
    """Custom Dataset for baseball pitch videos with annotations."""

    def __init__(self, video_dir, xml_dir, resize=(640, 480), frame_skip=1):
        self.video_dir = video_dir
        self.xml_dir = xml_dir
        self.resize = resize
        self.frame_skip = frame_skip
        self.video_tensors = {}
        self.index_map = []

        video_files = [f for f in os.listdir(video_dir)
                      if f.lower().endswith(('.mp4', '.mov', '.avi'))]

        print(f"Found {len(video_files)} video files\\n")

        for video_file in video_files:
            stem = os.path.splitext(video_file)[0]
            video_path = os.path.join(video_dir, video_file)
            xml_path = os.path.join(xml_dir, f"{stem}.xml")

            if not os.path.exists(xml_path):
                print(f"Skipping {video_file}: no XML\\n")
                continue

            try:
                video_tensor = video_to_tensor(video_path, resize=resize, frame_skip=frame_skip)
                annotations = parse_cvat_xml(xml_path, frame_skip=frame_skip)
                self.video_tensors[video_path] = video_tensor

                for frame_idx, ann_list in annotations.items():
                    if len(ann_list) == 0 or frame_idx >= len(video_tensor):
                        continue

                    boxes = torch.tensor([a["bbox"] for a in ann_list], dtype=torch.float32)
                    moving = torch.tensor([a["moving"] for a in ann_list], dtype=torch.int64)

                    self.index_map.append((video_path, frame_idx, {
                        "boxes": boxes,
                        "moving": moving
                    }))

                print(f"Indexed {len(annotations)} frames from {video_file}\\n")

            except Exception as e:
                print(f"Error: {e}\\n")
                continue

        print(f"✓ Dataset: {len(self.index_map)} frames from {len(self.video_tensors)} videos\\n")

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        video_path, frame_idx, target = self.index_map[idx]
        frame = self.video_tensors[video_path][frame_idx]
        labels = target["moving"].clone() + 1

        return frame, {
            "boxes": target["boxes"],
            "labels": labels
        }


def collate_fn(batch):
    """Custom collate function."""
    frames = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    return frames, targets


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


##  Neural Network Model & Training


In [None]:
def get_model(num_classes, pretrained=True):
    """Create Faster R-CNN model."""
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=pretrained)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )
    return model


def train_one_epoch(model, dataloader, optimizer, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0.0

    for batch_idx, (images, targets) in enumerate(dataloader):
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        total_loss += losses.item()

        if batch_idx % 10 == 0:
            print(f"  Batch {batch_idx}/{len(dataloader)} | Loss: {losses.item():.4f}")

    return total_loss / len(dataloader)


@torch.no_grad()
def validate(model, dataloader, device):
    """Validate model."""
    model.train()
    total_loss = 0.0

    for images, targets in dataloader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        total_loss += losses.item()

    return total_loss / len(dataloader)


def train_model(train_dataset, val_dataset, num_classes=3, epochs=3, lr=5e-5, batch_size=2):
    """Train the model."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}\\n")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    model = get_model(num_classes, pretrained=True).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        print("-" * 40)

        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_loss = validate(model, val_loader, device)

        print(f"Train: {train_loss:.4f} | Val: {val_loss:.4f}\\n")

    print("✓ Training complete!\\n")
    return model


## Model Save/Load Functions


In [None]:
def save_model(model, save_path="baseball_detector.pth"):
    """Save trained model weights."""
    torch.save(model.state_dict(), save_path)
    print(f"✓ Model saved to '{save_path}'")


def load_model(weights_path, num_classes=3, device=None):
    """
    IMPORT SCRIPT: Load trained model without retraining.
    Use this function to evaluate the model.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Loading model on: {device}")

    # Recreate architecture
    model = get_model(num_classes, pretrained=False)

    # Load saved weights
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)

    # Set to evaluation mode
    model.to(device)
    model.eval()

    print(f"✓ Model loaded from '{weights_path}'\\n")
    return model


@torch.no_grad()
def predict(model, video_path, resize=(640, 480), confidence_threshold=0.5):
    """Run inference on a video."""
    device = next(model.parameters()).device
    model.eval()

    video_tensor = video_to_tensor(video_path, resize=resize)
    predictions = []

    for frame in video_tensor:
        frame = frame.to(device)
        pred = model([frame])[0]

        mask = pred['scores'] > confidence_threshold
        predictions.append({
            'boxes': pred['boxes'][mask].cpu(),
            'labels': pred['labels'][mask].cpu(),
            'scores': pred['scores'][mask].cpu()
        })

    return predictions


## 4. Run Training


In [23]:
# Set paths to your data
VIDEO_DIR = "/content/sample_data/OneDrive_2_11-12-2025"
XML_DIR = "/content/sample_data/OneDrive_1_11-12-2025/"

# Load dataset
print("=" * 50)
print("LOADING DATASET")
print("=" * 50 + "\\n")

dataset = BaseballVideoDataset(
    video_dir=VIDEO_DIR,
    xml_dir=XML_DIR,
    resize=(640, 480),
    frame_skip=1
)

# Split train/validation (80/20)
n = len(dataset)
train_size = int(0.8 * n)
val_size = n - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

print(f"Train: {train_size} samples")
print(f"Val: {val_size} samples\\n")

# Train model
print("=" * 50)
print("TRAINING")
print("=" * 50)

model = train_model(
    train_dataset,
    val_dataset,
    num_classes=3,  # background, not-moving, moving
    epochs=3,
    lr=5e-5,
    batch_size=2
)

# Save model
print("=" * 50)
print("SAVING")
print("=" * 50 + "\\n")

save_model(model, "baseball_detector.pth")

print("\\n✓ DONE!")


LOADING DATASET
Found 12 video files\n
Loading: IMG_7942_dusty.mov
  -> 57 frames loaded\n
Indexed 57 frames from IMG_7942_dusty.mov\n
Loading: IMG_7997_khem.mov
  -> 37 frames loaded\n
Indexed 37 frames from IMG_7997_khem.mov\n
Loading: IMG_7919_dusty.mov
  -> 50 frames loaded\n
Indexed 50 frames from IMG_7919_dusty.mov\n
Loading: IMG_7998_khem.mov
  -> 55 frames loaded\n
Indexed 55 frames from IMG_7998_khem.mov\n
Loading: IMG_9435_hugo.mov
  -> 66 frames loaded\n
Indexed 66 frames from IMG_9435_hugo.mov\n
Loading: IMG_7943_khem.mov
  -> 51 frames loaded\n
Indexed 54 frames from IMG_7943_khem.mov\n
Loading: IMG_7917_dusty.mov
  -> 60 frames loaded\n
Indexed 60 frames from IMG_7917_dusty.mov\n
Loading: IMG_7918_dusty.mov
  -> 76 frames loaded\n
Indexed 76 frames from IMG_7918_dusty.mov\n
Loading: dusty_1.mov
  -> 76 frames loaded\n
Indexed 76 frames from dusty_1.mov\n
Loading: IMG_9197_hugo.mov
  -> 70 frames loaded\n
Indexed 70 frames from IMG_9197_hugo.mov\n
Loading: IMG_9199_hugo.mo



Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth


100%|██████████| 160M/160M [00:00<00:00, 177MB/s]


Epoch 1/3
----------------------------------------
  Batch 0/312 | Loss: 115.9895
  Batch 10/312 | Loss: 60.8186
  Batch 20/312 | Loss: 84.5531
  Batch 30/312 | Loss: 54.7157
  Batch 40/312 | Loss: 30.7799
  Batch 50/312 | Loss: 58.1665
  Batch 60/312 | Loss: 57.9091
  Batch 70/312 | Loss: 24.9072
  Batch 80/312 | Loss: 29.5750
  Batch 90/312 | Loss: 80.0623
  Batch 100/312 | Loss: 17.1348
  Batch 110/312 | Loss: 21.9460
  Batch 120/312 | Loss: 29.8280
  Batch 130/312 | Loss: 29.7437
  Batch 140/312 | Loss: 26.9574
  Batch 150/312 | Loss: 40.5491
  Batch 160/312 | Loss: 26.0527
  Batch 170/312 | Loss: 26.7784
  Batch 180/312 | Loss: 27.2215
  Batch 190/312 | Loss: 24.4175
  Batch 200/312 | Loss: 22.9246
  Batch 210/312 | Loss: 59.6585
  Batch 220/312 | Loss: 28.3562
  Batch 230/312 | Loss: 24.9417
  Batch 240/312 | Loss: 38.6133
  Batch 250/312 | Loss: 22.1172
  Batch 260/312 | Loss: 21.5881
  Batch 270/312 | Loss: 74.0798
  Batch 280/312 | Loss: 39.3794
  Batch 290/312 | Loss: 42.4798

## Evalution



In [26]:

# Load the trained model
loaded_model = load_model("/content/baseball_detector.pth", num_classes=3)

# Test on a video
test_video = "/content/sample_data/OneDrive_2_11-12-2025/dusty_1.mov"
predictions = predict(loaded_model, test_video, confidence_threshold=0.5)

print(f"✓ Predicted {len(predictions)} frames")
for i, frame_preds in enumerate(predictions):
    print(f"Frame {i} predictions: {len(frame_preds['boxes'])} detections")

Loading model on: cuda
✓ Model loaded from '/content/baseball_detector.pth'\n
Loading: dusty_1.mov
  -> 76 frames loaded\n
✓ Predicted 76 frames
Frame 0 predictions: 0 detections
Frame 1 predictions: 1 detections
Frame 2 predictions: 0 detections
Frame 3 predictions: 0 detections
Frame 4 predictions: 0 detections
Frame 5 predictions: 0 detections
Frame 6 predictions: 1 detections
Frame 7 predictions: 0 detections
Frame 8 predictions: 0 detections
Frame 9 predictions: 1 detections
Frame 10 predictions: 1 detections
Frame 11 predictions: 0 detections
Frame 12 predictions: 1 detections
Frame 13 predictions: 1 detections
Frame 14 predictions: 2 detections
Frame 15 predictions: 2 detections
Frame 16 predictions: 0 detections
Frame 17 predictions: 0 detections
Frame 18 predictions: 3 detections
Frame 19 predictions: 2 detections
Frame 20 predictions: 1 detections
Frame 21 predictions: 2 detections
Frame 22 predictions: 0 detections
Frame 23 predictions: 1 detections
Frame 24 predictions: 0 d