In [None]:
import pathlib

import torch
from PIL import Image
from torchvision import tv_tensors
from torchvision.transforms import v2

# Dataset layout: <yolo_base>/images/<subdir>/*.jpg and <yolo_base>/labels/<subdir>/*.txt
yolo_base = pathlib.Path("data_yolo")
images_dir = yolo_base.joinpath("images")
labels_dir = yolo_base.joinpath("labels")
# Sorted so that both listings come out in a stable, reproducible order.
image_files = sorted(images_dir.glob("*/*.jpg"))
label_files = sorted(labels_dir.glob("*/*.txt"))

# Augmentation pipeline; the steps run in the order listed here.
augmentation_steps = [
    # Random crop covering 80-100% of the area, resized to 640x640.
    v2.RandomResizedCrop(size=(640, 640), scale=(0.8, 1.0), ratio=(0.75, 1.33)),
    # Photometric jitter.
    v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    # Mirror left/right half of the time.
    v2.RandomHorizontalFlip(p=0.5),
    # Random rotation with the angle drawn from [-90, 90] degrees.
    v2.RandomRotation(degrees=90),
]
transform = v2.Compose(augmentation_steps)

# Per-image results. Entry k of the first three lists belongs to the k-th
# image that was actually kept; all_batch_idx holds one entry per kept box,
# pointing at the position of its image in all_transformed_images.
all_transformed_images = []
all_transformed_bboxes = []
all_classes = []
all_batch_idx = []

# Pair each image with its label file by (sub-directory, file stem) instead
# of zipping two independently sorted listings: a single missing or extra
# file would otherwise shift every following image onto the wrong labels.
labels_by_key = {(path.parent.name, path.stem): path for path in label_files}

# Build the PIL -> tensor converter once rather than once per image.
to_image = v2.ToImage()

# Process all images
for img_path in image_files:
    label_path = labels_by_key.get((img_path.parent.name, img_path.stem))
    if label_path is None:
        continue  # no label file for this image

    # Load labels first so images that end up skipped are never decoded.
    # Blank lines are ignored; they would otherwise yield an empty row and
    # fail on row[0].
    with open(label_path, encoding="utf-8") as f:
        label_data = [line.split() for line in f if line.strip()]

    if not label_data:
        continue  # skip empty labels

    # NOTE(review): rows are assumed to be YOLO style `class cx cy w h` with
    # coordinates normalized to [0, 1] -- confirm against the dataset.
    classes = torch.tensor([int(row[0]) for row in label_data])
    bboxes = torch.tensor([[float(x) for x in row[1:]] for row in label_data])

    # Load image
    img = Image.open(img_path).convert("RGB")
    img_tensor = to_image(img)

    # The v2 transforms only move boxes that are tv_tensors.BoundingBoxes;
    # a plain tensor passed next to the image is handed back untouched.
    # BoundingBoxes live in pixel space, so scale the normalized values up
    # to the source image size before wrapping them.
    height, width = img_tensor.shape[-2:]
    to_pixels = torch.tensor([width, height, width, height], dtype=bboxes.dtype)
    boxes_px = tv_tensors.BoundingBoxes(
        bboxes * to_pixels,
        format="CXCYWH",
        canvas_size=(height, width),
    )

    # Apply the same random transform to the image and its boxes.
    transformed_img, transformed_boxes_px = transform(img_tensor, boxes_px)

    # Store plain tensors, with boxes back in normalized CXCYWH relative to
    # the transformed image.
    # NOTE(review): normalized CXCYWH is what the previous version passed on
    # downstream; confirm that plot_with_bboxes expects this format.
    out_height, out_width = transformed_img.shape[-2:]
    to_normalized = torch.tensor(
        [out_width, out_height, out_width, out_height], dtype=bboxes.dtype
    )
    transformed_img = transformed_img.as_subclass(torch.Tensor)
    transformed_bboxes = transformed_boxes_px.as_subclass(torch.Tensor) / to_normalized

    # Drop boxes left without any area (e.g. cropped or rotated out of the
    # frame) together with their class ids.
    keep = (transformed_bboxes[:, 2] > 0) & (transformed_bboxes[:, 3] > 0)
    transformed_bboxes = transformed_bboxes[keep]
    classes = classes[keep]

    # The batch index is the position this image will have in the stacked
    # tensor. A running loop counter would drift after the first skipped image.
    batch_index = len(all_transformed_images)
    all_transformed_images.append(transformed_img)
    all_transformed_bboxes.append(transformed_bboxes)
    all_classes.append(classes)
    all_batch_idx.extend([batch_index] * len(transformed_bboxes))

# torch.stack / torch.cat fail with an unhelpful message on empty lists.
if not all_transformed_images:
    raise RuntimeError(
        "No image with a matching, non-empty label file was found; nothing to stack."
    )

# Stack all into batch tensors
all_images_tensor = torch.stack(all_transformed_images)
all_bboxes_tensor = torch.cat(all_transformed_bboxes)
all_classes_tensor = torch.cat(all_classes)
# Explicit dtype keeps the index tensor integral even when no box survived.
all_batch_idx_tensor = torch.tensor(all_batch_idx, dtype=torch.long)

# Example: visualize the first transformed image (index=0).
# plot_with_bboxes is not defined in this cell; it must already be in scope.
plot_with_bboxes(
    all_images_tensor,
    all_bboxes_tensor,
    all_classes_tensor,
    batch_idx=all_batch_idx_tensor,
    index=0,
)
