# Data Module

> Dataset utilities for object detection

In [None]:
#| default_exp data

In [None]:
#| export
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
import random
import albumentations as A
from typing import Dict, List, Tuple, Union, Optional, Callable
from PIL import Image

from objdetect.core import box_xyxy_to_cxcywh, box_cxcywh_to_xyxy, plot_boxes

In [None]:
#| hide
from nbdev.showdoc import *

## Data Transforms

In [None]:
#| export
class Compose:
    """Chain detection transforms into a single callable.

    Every transform is called as ``transform(image, target)`` and must return
    the updated ``(image, target)`` pair. That pair is then handed to the next
    transform, in list order.
    """
    def __init__(self, transforms):
        # Sequence of callables, applied in order on every call.
        self.transforms = transforms

    def __call__(self, image, target):
        current_image, current_target = image, target
        for transform in self.transforms:
            current_image, current_target = transform(current_image, current_target)
        return current_image, current_target

In [None]:
#| export
class Normalize:
    """Standardize an image channel-wise with the given mean and std.

    Delegates to torchvision's functional ``normalize``, so the image should
    already be a tensor (in this module's pipeline it runs after ``ToTensor``).
    The target is returned untouched.
    """
    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, image, target):
        normalized = F.normalize(image, mean=self.mean, std=self.std)
        return normalized, target

In [None]:
#| export
class ToTensor:
    """Convert the image with torchvision's ``to_tensor``; pass the target through unchanged."""
    def __call__(self, image, target):
        return F.to_tensor(image), target

In [None]:
#| export
class RandomHorizontalFlip:
    """Randomly mirror an image left-right, mirroring its bounding boxes too.

    Boxes in ``target["boxes"]`` are ``(x1, y1, x2, y2)`` rows. By default they
    are treated as absolute pixel coordinates, the same convention ``Resize``
    relies on when it rescales boxes by the new/old pixel size. Pass
    ``normalized=True`` if boxes are expressed in ``[0, 1]`` instead.

    Args:
        prob: Probability of applying the flip on each call.
        normalized: Whether box coordinates are normalized to ``[0, 1]``.
    """
    def __init__(self, prob=0.5, normalized=False):
        self.prob = prob
        self.normalized = normalized

    @staticmethod
    def _flip_boxes(boxes, width):
        """Return a horizontally mirrored copy of xyxy ``boxes`` for an image ``width`` wide."""
        flipped = boxes.clone()
        # New x1 mirrors old x2 (and vice versa) so x1 <= x2 still holds.
        flipped[:, 0] = width - boxes[:, 2]
        flipped[:, 2] = width - boxes[:, 0]
        return flipped

    def __call__(self, image, target):
        if random.random() < self.prob:
            # Tensors are (..., H, W); PIL images expose .width directly.
            width = image.shape[-1] if torch.is_tensor(image) else image.width
            image = F.hflip(image)

            if "boxes" in target:
                extent = 1.0 if self.normalized else width
                target["boxes"] = self._flip_boxes(target["boxes"], extent)

        return image, target

In [None]:
#| export
class Resize:
    """Resize an image to a fixed size and rescale its bounding boxes to match.

    Args:
        size: Target size. An int means a square ``(size, size)`` output; a
            list/tuple is used as ``(height, width)``.
    """
    def __init__(self, size):
        self.size = size if isinstance(size, (list, tuple)) else (size, size)

    @staticmethod
    def _scale_boxes(boxes, scale_x, scale_y):
        """Return a float copy of xyxy ``boxes`` with x scaled by ``scale_x`` and y by ``scale_y``.

        Works on a copy so the caller's tensor is never modified. Integer boxes
        are promoted to float32, because an in-place float scale on an integer
        tensor would raise.
        """
        scaled = boxes.clone() if boxes.is_floating_point() else boxes.float()
        scaled[:, [0, 2]] *= scale_x
        scaled[:, [1, 3]] *= scale_y
        return scaled

    def __call__(self, image, target):
        orig_height, orig_width = image.shape[-2:] if torch.is_tensor(image) else (image.height, image.width)
        image = F.resize(image, self.size)

        if "boxes" in target and len(target["boxes"]):
            # self.size is (height, width), so index 1 drives x and index 0 drives y.
            scale_x = self.size[1] / orig_width
            scale_y = self.size[0] / orig_height
            target["boxes"] = self._scale_boxes(target["boxes"], scale_x, scale_y)

        return image, target

## Base Dataset

In [None]:
#| export
class ObjectDetectionDataset(Dataset):
    """Base dataset for object detection.

    Item ``idx`` is an ``(image, target)`` pair:

    - ``image`` is the file at ``img_files[idx]``, loaded as RGB.
    - ``target`` is a dict built from ``annotations[idx]``. ``"boxes"`` is a
      float32 tensor (``(0, 4)`` when absent or empty). ``"labels"`` is an
      int64 tensor (``(0,)`` when absent or empty).

    If ``transforms`` is set, it is applied as ``transforms(image, target)``.
    """

    def __init__(self,
                 img_files: List[str],
                 annotations: List[Dict],
                 transforms: Optional[Callable] = None,
                 class_names: Optional[List[str]] = None):
        """
        Args:
            img_files: List of image file paths
            annotations: List of annotation dictionaries with optional keys
                'boxes' and 'labels', aligned index-by-index with ``img_files``
            transforms: Optional callable ``(image, target) -> (image, target)``
            class_names: List of class names; when empty, ``num_classes`` is
                inferred from the labels
        """
        self.img_files = img_files
        self.annotations = annotations
        self.transforms = transforms
        self.class_names = class_names or []

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = self.img_files[idx]
        # Image.open reads lazily and may keep the file open (e.g. multi-frame
        # formats). The context manager releases the handle once the RGB copy
        # has been made.
        with Image.open(img_path) as pil_img:
            img = pil_img.convert("RGB")

        annotation = self.annotations[idx]
        target = {}

        # Convert annotations to tensors, using empty tensors when missing.
        if "boxes" in annotation and len(annotation["boxes"]):
            target["boxes"] = torch.tensor(annotation["boxes"], dtype=torch.float32)
        else:
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)

        if "labels" in annotation and len(annotation["labels"]):
            target["labels"] = torch.tensor(annotation["labels"], dtype=torch.int64)
        else:
            target["labels"] = torch.zeros((0,), dtype=torch.int64)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    @property
    def num_classes(self):
        """Number of classes.

        Uses ``len(class_names)`` when names were given. Otherwise returns the
        largest label seen plus one (0 if there are no labels). That inference
        assumes labels are 0-based integer ids.
        """
        if self.class_names:
            return len(self.class_names)
        max_label = max(
            (max(anno["labels"]) for anno in self.annotations
             if "labels" in anno and len(anno["labels"]) > 0),
            default=None,
        )
        return 0 if max_label is None else max_label + 1

    def collate_fn(self, batch):
        """Regroup a list of ``(image, target)`` samples into ``(images, targets)`` tuples.

        Nothing is stacked, so per-image box counts may differ.
        """
        images, targets = list(zip(*batch))
        return images, targets

    @classmethod
    def from_coco(cls, coco_path, splits=None, transforms=None):
        """Create dataset from COCO format annotations (NOT YET IMPLEMENTED).

        This is currently a stub. It emits a warning and returns an empty
        dataset.

        Args:
            coco_path: Path to COCO dataset directory
            splits: Iterable of splits to include ('train', 'val', etc.);
                defaults to ``['train']``
            transforms: Optional transforms

        Returns:
            ObjectDetectionDataset (currently always empty)
        """
        import warnings  # local import: only this stub needs it

        splits = ['train'] if splits is None else list(splits)
        warnings.warn(
            f"ObjectDetectionDataset.from_coco is a stub; COCO parsing is not implemented, "
            f"so an empty dataset is returned for {Path(coco_path)} (splits={splits}).",
            stacklevel=2,
        )
        # TODO: parse the COCO instance JSON for each split into
        # img_files / annotations / class_names.
        return cls([], [], transforms, [])

    def show_sample(self, idx, figsize=(10, 10)):
        """Show a sample from the dataset with annotations."""
        img, target = self[idx]

        if isinstance(img, torch.Tensor):
            # ToPILImage multiplies float tensors by 255 and casts to uint8, so
            # values outside [0, 1] (e.g. after Normalize) would overflow.
            # Clamp float images first.
            if img.is_floating_point():
                img = img.clamp(0, 1)
            img = torchvision.transforms.ToPILImage()(img)

        boxes = target["boxes"]
        labels = target["labels"]
        class_names = self.class_names if self.class_names else None

        return plot_boxes(img, boxes, labels, class_names=class_names, figsize=figsize)

## Data Utilities

In [None]:
#| export
def get_detection_transforms(train=True, size=640, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the standard detection preprocessing pipeline.

    Args:
        train: When true, prepend a random horizontal flip for augmentation.
        size: Target size for ``Resize`` (an int for a square output, or ``(height, width)``).
        mean: Per-channel mean handed to ``Normalize``.
        std: Per-channel standard deviation handed to ``Normalize``.

    Returns:
        A ``Compose`` that applies, in order: [flip,] resize, to-tensor, normalize.
    """
    augmentations = [RandomHorizontalFlip()] if train else []
    preprocessing = [Resize(size), ToTensor(), Normalize(mean=mean, std=std)]
    return Compose(augmentations + preprocessing)