# Setting up the dataset

In [None]:
from datasets import load_dataset

# Load the Fashionpedia detection dataset by its Hub repository id. The result
# is indexed by split name further down (dataset["train"]).
dataset = load_dataset('detection-datasets/fashionpedia')
# Bare expression as the last line of the cell: the notebook displays its repr.
dataset

In [None]:
import torch
from torchvision import transforms

def collate_fn(batch):
    """Combine per-example dicts into ``(images, targets)`` for a DataLoader.

    Args:
        batch: Sequence of example dicts, each with a ``"pixel_values"`` tensor
            and an ``"objects"`` entry.

    Returns:
        A tuple ``(images, targets)``: ``images`` is the ``torch.stack`` of all
        ``"pixel_values"`` tensors along a new leading batch dimension, and
        ``targets`` is a list of the ``"objects"`` entries in batch order,
        passed through untouched.
    """
    pixel_tensors = []
    annotations = []
    for sample in batch:
        pixel_tensors.append(sample["pixel_values"])
        annotations.append(sample["objects"])

    # torch.stack requires every tensor to have the same shape, so the images
    # must already share one size by the time they reach this function.
    return torch.stack(pixel_tensors), annotations

def _rescale_objects(objects, scale_x, scale_y):
    """Return a shallow copy of *objects* with its ``"bbox"`` entries scaled.

    Anything that is not a dict carrying a ``"bbox"`` key is returned
    unchanged, so unexpected annotation layouts pass through untouched.
    """
    if not isinstance(objects, dict) or "bbox" not in objects:
        return objects

    rescaled = dict(objects)
    # NOTE(review): assumes each box is four absolute-pixel numbers with the
    # horizontal values at positions 0 and 2 and the vertical ones at 1 and 3
    # (true for both x1,y1,x2,y2 and x,y,w,h layouts) -- confirm against the
    # dataset card. Other size-dependent fields (e.g. an "area" column, if
    # present) are left in original-image units.
    rescaled["bbox"] = [
        [a * scale_x, b * scale_y, c * scale_x, d * scale_y]
        for a, b, c, d in objects["bbox"]
    ]
    return rescaled


def transform_fn(examples, size=(512, 512)):
    """Resize a batch of images to a fixed size and rescale their boxes to match.

    Written for ``Dataset.with_transform``, which passes a dict mapping column
    names to lists of per-example values.

    Args:
        examples: Batch dict with an ``"image"`` column of PIL images. When an
            ``"objects"`` column is present and its entries are dicts holding
            a ``"bbox"`` list, those boxes are rescaled too.
        size: Target ``(height, width)`` passed to ``transforms.Resize``.

    Returns:
        The same dict with a ``"pixel_values"`` column added (one tensor per
        image, converted to RGB and resized) and, where applicable,
        ``"objects"`` replaced by copies whose ``"bbox"`` values are expressed
        in resized-image pixels.
    """
    target_h, target_w = size
    to_tensor = transforms.Compose([
        transforms.Resize((target_h, target_w)),
        transforms.ToTensor(),
    ])

    source_images = examples["image"]
    examples["pixel_values"] = [
        to_tensor(img.convert("RGB")) for img in source_images
    ]

    # Resizing changes the pixel grid, so boxes left in original-image
    # coordinates would no longer line up with the tensors. The scale is
    # computed per image because source images may differ in size.
    if "objects" in examples:
        examples["objects"] = [
            _rescale_objects(obj, target_w / img.width, target_h / img.height)
            for img, obj in zip(source_images, examples["objects"])
        ]
    return examples

# Attach transform_fn to the "train" split. with_transform applies it on the
# fly when examples are accessed rather than rewriting the stored dataset.
transformed_dataset = dataset["train"].with_transform(transform_fn)

from torch.utils.data import DataLoader

# Batches of 32, reshuffled each epoch. The custom collate_fn stacks the image
# tensors and keeps the per-image "objects" entries as a plain list.
train_dataloader = DataLoader(
    transformed_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn
)

# Smoke test: pull one batch from a fresh iterator (shuffle=True, so which
# examples it contains varies between runs) and inspect what collate_fn returned.
batch = next(iter(train_dataloader))
images, targets = batch
# With batch_size=32 and the 512x512 RGB resize, the expected shape is
# [32, 3, 512, 512].
print(f"Batch images shape: {images.shape}")
# One "objects" entry per image, so this should equal the batch size.
print(f"Number of target dicts: {len(targets)}")