In [2]:
import os
import random
from PIL import Image, ImageOps
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import datasets

  from .autonotebook import tqdm as notebook_tqdm


# 1. Define Main Clothing Items

In [None]:
# Category names treated as "main" clothing items. The spelling of each entry
# must match the dataset's label names exactly, because the names are looked up
# verbatim when they are mapped to numeric category IDs.
main_item_names = [
    # Upper-body garments
    "shirt, blouse", "top, t-shirt, sweatshirt", "sweater", "cardigan",
    "jacket", "vest",
    # Lower-body garments
    "pants", "shorts", "skirt",
    # Full-body garments and outerwear
    "coat", "dress", "jumpsuit", "cape",
    # Head and neck accessories
    "glasses", "hat", "headband, head covering, hair accessory", "tie",
    # Hand, wrist and waist accessories
    "glove", "watch", "belt",
    # Leg and foot wear
    "leg warmer", "tights, stockings", "sock", "shoe",
    # Carried items and miscellaneous
    "bag, wallet", "scarf", "umbrella",
]

# 2. Load and Split the Dataset

In [None]:
dataset = datasets.load_dataset("detection-datasets/fashionpedia", split="train")

# Split into 90% training and 10% validation (contiguous, unshuffled ranges).
num_total = len(dataset)
split_index = int(0.9 * num_total)  # computed once so both ranges share the same boundary
train_dataset = dataset.select(range(split_index))
val_dataset = dataset.select(range(split_index, num_total))

# Get the official label list from the dataset (for the "objects.category" field).
label_list = dataset.features["objects"].feature["category"].names

# Report names that do not exist in the dataset instead of dropping them
# silently: a typo here would otherwise remove a whole category (and all of its
# annotations) from the filtered data without any visible sign.
missing_names = [name for name in main_item_names if name not in label_list]
if missing_names:
    print("Warning: main item names not found in dataset labels:", missing_names)

# Map the main item names to their numeric IDs (unknown names are skipped).
main_item_ids = [label_list.index(name) for name in main_item_names if name in label_list]
print("Main item IDs:", main_item_ids)

# 3. Filter Annotations for Main Items

In [None]:
def has_main_item(example):
    """Return True when at least one annotated object belongs to a main item category.

    Reads the category IDs from ``example["objects"]["category"]`` and checks
    them against the module-level ``main_item_ids``.
    """
    for category_id in example["objects"]["category"]:
        if category_id in main_item_ids:
            # One match is enough; no need to look at the remaining objects.
            return True
    return False

def filter_main_items(example):
    """Remove bounding boxes and related annotations that are not main items.

    Every per-object list stored under ``example["objects"]`` (any list whose
    length equals the number of categories, e.g. ``category``, ``bbox``,
    ``area`` and, when present, ``bbox_id``) is filtered with the same mask, so
    all annotation fields stay aligned index-for-index.

    Args:
        example: A sample dict whose ``"objects"`` entry is a dict of parallel
            lists that includes a ``"category"`` list of category IDs.

    Returns:
        The same ``example`` object, with ``example["objects"]`` updated in place.
    """
    objects = example["objects"]
    # One keep/drop flag per annotated object, based on the module-level main_item_ids.
    keep_mask = [cat in main_item_ids for cat in objects["category"]]

    # Snapshot the items so the dict can be reassigned while iterating.
    for field, values in list(objects.items()):
        # Only touch fields that are parallel to the category list; anything
        # else (scalars, lists of a different length) is left unchanged.
        if isinstance(values, (list, tuple)) and len(values) == len(keep_mask):
            objects[field] = [value for value, keep in zip(values, keep_mask) if keep]
    return example

In [None]:
# For each split: first drop images that contain no main item at all, then
# strip the non-main annotations from the images that remain.
train_dataset = train_dataset.filter(has_main_item).map(filter_main_items)
val_dataset = val_dataset.filter(has_main_item).map(filter_main_items)

# 4. Data Augmentation & Transformation

In [None]:
def augment_sample(example, target_size=(224, 224)):
    """
    For training: apply random horizontal flip (with bbox adjustment),
    color jitter, resize to a fixed size, and convert to tensor.

    Args:
        example: Sample dict with an ``"image"`` (PIL image or array accepted by
            ``Image.fromarray``) and ``example["objects"]["bbox"]``, a list of
            ``[x1, y1, x2, y2]`` boxes in pixel coordinates of that image.
        target_size: ``(height, width)`` of the output image. Defaults to
            ``(224, 224)``.

    Returns:
        The same ``example``, with ``"image"`` replaced by a float tensor of
        shape ``(3, height, width)`` and the boxes rescaled (and mirrored when
        the flip was applied) to the resized image.
    """
    # Ensure image is a PIL image.
    image = example["image"]
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Force three channels: grayscale / CMYK / RGBA inputs would otherwise give
    # tensors with a different channel count and could not be stacked in a batch.
    image = image.convert("RGB")
    orig_width, orig_height = image.size
    bboxes = example["objects"]["bbox"]

    # Random horizontal flip with probability 0.5.
    if random.random() < 0.5:
        image = ImageOps.mirror(image)
        # Mirroring swaps the roles of the left and right edges of each box.
        bboxes = [[orig_width - x2, y1, orig_width - x1, y2] for x1, y1, x2, y2 in bboxes]

    new_height, new_width = target_size
    # Color jitter, resize and conversion to tensor.
    aug_transforms = transforms.Compose([
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.Resize((new_height, new_width)),
        transforms.ToTensor()
    ])
    example["image"] = aug_transforms(image)

    # Rescale the bounding boxes to the new image size.
    scale_x = new_width / orig_width
    scale_y = new_height / orig_height
    example["objects"]["bbox"] = [
        [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y]
        for x1, y1, x2, y2 in bboxes
    ]

    return example

def transform_sample(example, target_size=(224, 224)):
    """
    For validation: apply only resizing and conversion to tensor.

    Args:
        example: Sample dict with an ``"image"`` (PIL image or array accepted by
            ``Image.fromarray``) and ``example["objects"]["bbox"]``, a list of
            ``[x1, y1, x2, y2]`` boxes in pixel coordinates of that image.
        target_size: ``(height, width)`` of the output image. Defaults to
            ``(224, 224)``.

    Returns:
        The same ``example``, with ``"image"`` replaced by a float tensor of
        shape ``(3, height, width)`` and the boxes rescaled to the resized image.
    """
    image = example["image"]
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Force three channels: grayscale / CMYK / RGBA inputs would otherwise give
    # tensors with a different channel count and could not be stacked in a batch.
    image = image.convert("RGB")
    orig_width, orig_height = image.size

    new_height, new_width = target_size
    trans = transforms.Compose([
        transforms.Resize((new_height, new_width)),
        transforms.ToTensor()
    ])
    example["image"] = trans(image)

    # Rescale the bounding boxes to the new image size.
    scale_x = new_width / orig_width
    scale_y = new_height / orig_height
    example["objects"]["bbox"] = [
        [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y]
        for x1, y1, x2, y2 in example["objects"]["bbox"]
    ]

    return example

In [None]:
# Apply augmentations to the training set and a simpler transform to the
# validation set. The transforms are attached lazily with `with_transform`
# rather than baked in with `.map`, for two reasons:
#   * `.map` would run the random flip / color jitter once and cache the result,
#     so every epoch would see the exact same "augmented" images;
#   * `.map` serialises its output to Arrow, so image tensors would come back as
#     nested Python lists instead of tensors.
def _apply_per_example(sample_fn, batch):
    """Run a per-example function over the dict-of-lists batch a dataset transform receives."""
    columns = list(batch.keys())
    if not columns:
        return batch
    num_rows = len(batch[columns[0]])
    # Rebuild one dict per row, transform it, then regroup the results by column.
    examples = [
        sample_fn({column: batch[column][row] for column in columns})
        for row in range(num_rows)
    ]
    return {column: [ex[column] for ex in examples] for column in columns}


def _augment_batch(batch):
    """Batched adapter around `augment_sample` (training)."""
    return _apply_per_example(augment_sample, batch)


def _transform_batch(batch):
    """Batched adapter around `transform_sample` (validation)."""
    return _apply_per_example(transform_sample, batch)


train_dataset = train_dataset.with_transform(_augment_batch)
val_dataset = val_dataset.with_transform(_transform_batch)

# 5. Create a Custom Collate Function for PyTorch

In [None]:
def collate_fn(batch):
    """
    Custom collate_fn to stack images (which have uniform shape after resizing)
    and leave bounding boxes and labels as lists (since they vary in number per image).

    Args:
        batch: List of samples. Each sample has an ``"image"`` (a tensor, or
            nested lists / an array convertible to one) and an ``"objects"``
            dict with ``"bbox"`` (list of 4-element boxes) and ``"category"``
            (list of integer labels).

    Returns:
        Dict with ``"images"`` (float32 tensor of shape ``(batch_size, C, H, W)``),
        ``"bboxes"`` (list of float32 tensors of shape ``(num_objects, 4)``) and
        ``"labels"`` (list of int64 tensors of shape ``(num_objects,)``).
    """
    # Images may arrive as nested lists rather than tensors (e.g. after being
    # stored in and read back from an Arrow-backed dataset); as_tensor accepts
    # both and is a no-op for tensors that are already float32.
    images = torch.stack(
        [torch.as_tensor(sample["image"], dtype=torch.float32) for sample in batch]
    )
    # reshape(-1, 4) keeps a consistent (num_objects, 4) shape even when a
    # sample has no boxes, where a bare torch.tensor([]) would have shape (0,).
    bboxes = [
        torch.tensor(sample["objects"]["bbox"], dtype=torch.float32).reshape(-1, 4)
        for sample in batch
    ]
    labels = [torch.tensor(sample["objects"]["category"], dtype=torch.int64) for sample in batch]

    return {
        "images": images,  # Tensor of shape (batch_size, C, H, W)
        "bboxes": bboxes,  # List of tensors (variable length per image)
        "labels": labels   # List of tensors (variable length per image)
    }

# 6. Create PyTorch DataLoaders

In [None]:
# Training loader: sample order is reshuffled on every pass.
train_loader = DataLoader(
    train_dataset,
    collate_fn=collate_fn,
    batch_size=4,
    num_workers=4,
    shuffle=True,
)
# Validation loader: samples are served in dataset order.
val_loader = DataLoader(
    val_dataset,
    collate_fn=collate_fn,
    batch_size=4,
    num_workers=4,
    shuffle=False,
)

In [None]:
# Smoke test: pull a single batch from each loader and report its shapes.
for shape_message, loader in (
    ("Train Batch Images Shape:", train_loader),
    ("Validation Batch Images Shape:", val_loader),
):
    for batch in loader:
        # Images are expected as (batch_size, C, 224, 224); each entry of
        # batch["bboxes"] is a tensor of shape (num_objects, 4).
        print(shape_message, batch["images"].shape)
        print("Number of Bounding Box Sets in Batch:", len(batch["bboxes"]))
        break  # only the first batch is needed