<a href="https://colab.research.google.com/github/Chaudhari-Amar/econ8310-assignment-baseball-amar/blob/main/dataloader.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive so files stored
# on Drive can be read through ordinary filesystem paths.
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Root folder of the baseball detection dataset on the mounted Drive.
DATA_PATH = "/content/drive/MyDrive/Baseball Detection video"

In [None]:
import os
import glob
import math
import xml.etree.ElementTree as ET
from typing import Dict, List, Tuple, Optional

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.transforms import functional as TF

# ---- Dataset locations (training/inference scripts may override these) ----
# The dataset root comes from the DATA_PATH environment variable when it is set;
# otherwise the Colab Drive folder below is used.
DATA_PATH = os.getenv("DATA_PATH", "/content/drive/MyDrive/Baseball Detection video")
# Sub-folders of the root holding the raw clips and their annotation files.
VIDEOS_DIR, ANN_DIR = (
    os.path.join(DATA_PATH, subdir) for subdir in ("Raw Videos", "Annotations")
)


class BaseballCvatFrameDataset(Dataset):
    """
    Frame-level dataset built from CVAT XML annotations.

    - Each item is a single video frame that has at least one visible bounding box
      for label "baseball".
    - Target is a single bounding box per frame (the first visible box found for
      that frame in the XML).
    - Box format returned is (x_c, y_c, w, h), with x/w divided by the frame width
      and y/h divided by the frame height (so values are in [0,1] whenever the
      annotated box lies inside the frame).

    Args:
        videos_dir: Directory scanned (non-recursively) for .mov/.mp4 files.
            Treated as a literal path, not a glob pattern.
        ann_dir: Directory holding the CVAT ``*.xml`` files. Treated as a literal path.
        transform: Optional callable applied to the float (C, H, W) frame tensor.
        max_frames_per_video: If given, keep at most this many annotated frames
            per video (applied after subsampling).
        frame_subsample: Keep every n-th annotated frame; values below 1 are
            treated as 1.

    Raises:
        FileNotFoundError: If no video could be paired with an annotation file.
    """

    # Extensions accepted as videos; compared against the lower-cased suffix so
    # ".MOV", ".Mp4", etc. are all recognised.
    _VIDEO_EXTS = (".mov", ".mp4")

    def __init__(
        self,
        videos_dir: str = VIDEOS_DIR,
        ann_dir: str = ANN_DIR,
        transform=None,
        max_frames_per_video: Optional[int] = None,
        frame_subsample: int = 1,
    ):
        super().__init__()
        self.videos_dir = videos_dir
        self.ann_dir = ann_dir
        self.transform = transform
        self.max_frames_per_video = max_frames_per_video
        self.frame_subsample = max(1, int(frame_subsample))

        # List of (video_path, xml_path) pairs.
        self.pairs: List[Tuple[str, str]] = []
        # glob returns entries in filesystem order; sort so that pairing and the
        # resulting sample order are reproducible across machines and runs.
        video_pattern = os.path.join(glob.escape(videos_dir), "*")
        for vp in sorted(glob.glob(video_pattern)):
            stem_path, ext = os.path.splitext(vp)
            if ext.lower() not in self._VIDEO_EXTS:
                continue
            xml_path = self._find_annotation(os.path.basename(stem_path))
            if xml_path is not None:
                self.pairs.append((vp, xml_path))

        if not self.pairs:
            raise FileNotFoundError(
                f"No video/annotation pairs found. Looked in {videos_dir} and {ann_dir}."
            )

        # Build index: list of (video_path, frame_idx, bbox_norm, (H, W))
        self.index: List[Tuple[str, int, Tuple[float, float, float, float], Tuple[int, int]]] = []
        # Per video, the frame indices referenced by the index; __getitem__ uses
        # this to know which frames are worth keeping in memory.
        self._frames_by_video: Dict[str, List[int]] = {}
        for vp, xp in self.pairs:
            # The video is decoded here only to learn its frame count and size.
            video, _, _ = torchvision.io.read_video(vp, pts_unit="sec")
            # video: (T, H, W, C) uint8
            T, H, W, _ = video.shape
            # Drop the decoded frames right away; only the shape is needed here.
            del video

            frame_to_box = self._parse_cvat_xml_boxes(xp)
            # Keep annotated frames that exist in the video, then subsample/limit.
            # The lower bound stops a negative id from silently indexing from the end.
            frame_ids = sorted(f for f in frame_to_box if 0 <= f < T)[:: self.frame_subsample]
            if self.max_frames_per_video is not None:
                frame_ids = frame_ids[: self.max_frames_per_video]
            self._frames_by_video[vp] = frame_ids
            for f in frame_ids:
                x1, y1, x2, y2 = frame_to_box[f]
                # Convert corner pixels to (cx, cy, w, h) relative to the frame size.
                cx = ((x1 + x2) / 2.0) / W
                cy = ((y1 + y2) / 2.0) / H
                bw = (x2 - x1) / W
                bh = (y2 - y1) / H
                self.index.append((vp, f, (cx, cy, bw, bh), (H, W)))

        # video_path -> {frame_idx: uint8 (H, W, C) frame}. Only the annotated
        # frames are retained: holding every fully decoded video (in every
        # DataLoader worker) grows without bound and can exhaust memory.
        self._video_cache: Dict[str, Dict[int, torch.Tensor]] = {}

    def _find_annotation(self, stem: str) -> Optional[str]:
        """Return the annotation XML path for a video stem, or None if there is none.

        Preference order: the exact ``<stem>.xml``; then the alphabetically first
        ``<stem>*.xml`` (annotation carries a suffix such as ``_dusty``); then the
        alphabetically first ``<prefix>*.xml`` where ``prefix`` is the part of the
        stem before its first underscore (video carries the suffix).
        """
        exact = os.path.join(self.ann_dir, f"{stem}.xml")
        if os.path.isfile(exact):
            return exact
        ann_root = glob.escape(self.ann_dir)
        # dict.fromkeys de-duplicates while keeping order (stem without "_" gives
        # the same prefix twice).
        for prefix in dict.fromkeys((stem, stem.split("_")[0])):
            # Escape the prefix so characters like "[" in file names are literal.
            pattern = os.path.join(ann_root, f"{glob.escape(prefix)}*.xml")
            candidates = sorted(glob.glob(pattern))
            if candidates:
                return candidates[0]
        return None

    def _parse_cvat_xml_boxes(self, xml_path: str) -> Dict[int, Tuple[float, float, float, float]]:
        """Return a mapping frame_idx -> (x1, y1, x2, y2) in pixel coords.

        Only tracks labelled "baseball" are read and boxes flagged outside="1"
        are skipped. If several visible boxes exist for the same frame, the first
        one encountered in the XML wins.
        """
        root = ET.parse(xml_path).getroot()
        frames: Dict[int, Tuple[float, float, float, float]] = {}
        for track in root.findall("track"):
            if track.get("label") != "baseball":
                continue
            for box in track.findall("box"):
                if box.get("outside") == "1":
                    continue  # not visible in this frame
                f = int(box.get("frame"))
                if f in frames:
                    continue  # keep the first visible box for this frame
                frames[f] = (
                    float(box.get("xtl")),
                    float(box.get("ytl")),
                    float(box.get("xbr")),
                    float(box.get("ybr")),
                )
        return frames

    def __len__(self):
        """Number of annotated frames in the index."""
        return len(self.index)

    def __getitem__(self, idx):
        """Return the sample dict for index entry ``idx``.

        Keys: "image" is the float (C, H, W) frame scaled to [0,1] (after
        ``transform`` if one was given), "target" is the float32 (cx, cy, w, h)
        box, and "meta" carries the video path, frame index and (H, W).
        """
        vp, frame_idx, bbox_norm, (H, W) = self.index[idx]
        frames = self._video_cache.get(vp)
        if frames is None or frame_idx not in frames:
            # First access to this video in this process: decode it once and
            # keep only the frames the index refers to.
            vid, _, _ = torchvision.io.read_video(vp, pts_unit="sec")  # (T, H, W, C)
            wanted = set(self._frames_by_video.get(vp, ()))
            wanted.add(frame_idx)
            # clone() detaches each frame from the full video's storage so the
            # decoded video can actually be freed.
            frames = {f: vid[f].clone() for f in sorted(wanted)}
            self._video_cache[vp] = frames
            del vid
        frame = frames[frame_idx]  # (H, W, C) uint8
        frame = frame.permute(2, 0, 1).float() / 255.0  # -> (C, H, W) in [0,1]

        if self.transform is not None:
            frame = self.transform(frame)

        target = torch.tensor(bbox_norm, dtype=torch.float32)
        sample = {
            "image": frame,
            "target": target,  # (cx, cy, w, h) normalized
            "meta": {
                "video_path": vp,
                "frame_idx": frame_idx,
                "size": (H, W),
            },
        }
        return sample


def create_dataloader(
    batch_size: int = 8,
    shuffle: bool = True,
    num_workers: int = 2,
    **dataset_kwargs,
) -> DataLoader:
    """Build a BaseballCvatFrameDataset and wrap it in a DataLoader.

    Args:
        batch_size: Passed to the DataLoader as ``batch_size``.
        shuffle: Passed to the DataLoader as ``shuffle``.
        num_workers: Passed to the DataLoader as ``num_workers``.
        **dataset_kwargs: Forwarded unchanged to the BaseballCvatFrameDataset
            constructor.

    Returns:
        The DataLoader over the newly created dataset.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
    }
    dataset = BaseballCvatFrameDataset(**dataset_kwargs)
    return DataLoader(dataset, **loader_options)


if __name__ == "__main__":
    # Smoke test: build a small loader, pull one batch and print the shapes of
    # its image and target tensors.
    loader = create_dataloader(batch_size=2, max_frames_per_video=16, frame_subsample=2)
    first_batch = next(iter(loader))
    images, targets = first_batch["image"], first_batch["target"]
    print(images.shape, targets.shape)



RuntimeError: DataLoader worker (pid(s) 35840) exited unexpectedly