# In [None]:
# notebooks/02_data_preprocessing_and_dataset.ipynb
# This notebook will demonstrate the usage of your TrashDetectionDataset and DataLoader.

import os
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from PIL import Image

# Import configurations
from config.dataset_config import TRASH_ICRA19_CLASSES, TACO_ANNOTATION_PATH, TACO_IMAGES_DIR, \
                                   TRASH_ICRA19_IMAGES_DIR, TRASH_ICRA19_ANNOTATIONS_DIR, \
                                   TRASH_ICRA19_CLASS_NAMES_PATH, NUM_CLASSES_ICRA19

# Import custom dataset and collate_fn
from datasets.trash_detection_dataset import TrashDetectionDataset, collate_fn

# Import the DetectionTransforms from datasets/trash_detection_dataset.py for consistency
# For a real project, this would be moved to `utils/data_augmentations.py`
# but for now, we'll keep it near the Dataset class to be self-contained for testing.
from datasets.trash_detection_dataset import DetectionTransforms # Re-import or define here for clarity

print("--- 02_data_preprocessing_and_dataset.ipynb ---")

# --- 1. Define Transforms and Augmentations ---
# This is where the augmentation strategy is configured. For now both the
# training and the validation/test pipelines use the same basic
# `DetectionTransforms`; more advanced augmentations are planned for
# `utils/data_augmentations.py`.
#
# Typical practice: training gets more aggressive (random) augmentations,
# while validation/test get only resizing and normalization.

# Target size passed to DetectionTransforms. Train and val/test currently use
# the SAME size; change it in one place to keep both pipelines consistent.
IMAGE_SIZE = (800, 800)

train_transforms = DetectionTransforms(size=IMAGE_SIZE)
val_test_transforms = DetectionTransforms(size=IMAGE_SIZE)

# --- 2. Instantiate Datasets ---

print("\n--- Instantiating Datasets ---")

# TACO: primary training pool. TACO ships without predefined train/val/test
# splits, so the full mapped dataset is loaded as a single "train" set; carve
# out a validation subset manually if a TACO-specific val set is needed.
taco_train_dataset = TrashDetectionDataset(
    dataset_name='taco',
    transforms=train_transforms,
    taco_json_path=TACO_ANNOTATION_PATH,
    taco_images_dir=TACO_IMAGES_DIR,
)
print(f"TACO training dataset size: {len(taco_train_dataset)} images.")

# Trash-ICRA19: fine-tuning and evaluation. Every split reads from the same
# on-disk locations; only `split` and the transform pipeline differ.
_icra19_paths = {
    'icra19_images_dir': TRASH_ICRA19_IMAGES_DIR,
    'icra19_annotations_dir': TRASH_ICRA19_ANNOTATIONS_DIR,
    'icra19_class_names_path': TRASH_ICRA19_CLASS_NAMES_PATH,
}

icra19_train_dataset = TrashDetectionDataset(
    dataset_name='trash_icra19', split='train', transforms=train_transforms, **_icra19_paths
)
print(f"Trash-ICRA19 training dataset size: {len(icra19_train_dataset)} images.")

icra19_val_dataset = TrashDetectionDataset(
    dataset_name='trash_icra19', split='val', transforms=val_test_transforms, **_icra19_paths
)
print(f"Trash-ICRA19 validation dataset size: {len(icra19_val_dataset)} images.")

icra19_test_dataset = TrashDetectionDataset(
    dataset_name='trash_icra19', split='test', transforms=val_test_transforms, **_icra19_paths
)
print(f"Trash-ICRA19 test dataset size: {len(icra19_test_dataset)} images.")


# --- 3. Create DataLoaders ---

print("\n--- Creating DataLoaders ---")

BATCH_SIZE = 2  # Kept small for a quick demonstration.
# 0 loads batches in the main process, which is easiest to debug in a
# notebook; raise it for faster loading in production.
NUM_WORKERS = 0

# Options shared by every loader; only the shuffle flag differs.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, collate_fn=collate_fn)

# Training loaders shuffle; evaluation loaders keep a deterministic order.
taco_train_loader = DataLoader(taco_train_dataset, shuffle=True, **_loader_kwargs)
icra19_train_loader = DataLoader(icra19_train_dataset, shuffle=True, **_loader_kwargs)
icra19_val_loader = DataLoader(icra19_val_dataset, shuffle=False, **_loader_kwargs)
icra19_test_loader = DataLoader(icra19_test_dataset, shuffle=False, **_loader_kwargs)

print("DataLoaders created. Ready to iterate through batches.")

# --- 4. Iterate and Verify a Batch ---

print("\n--- Verifying Batches ---")

# Per-channel statistics used to undo input normalization for display.
# NOTE(review): these must match the Normalize step inside DetectionTransforms
# (ImageNet statistics assumed) — TODO confirm.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def show_image_with_boxes(img_tensor, boxes_cxcywh_norm, labels, title=""):
    """Display one normalized image tensor with its bounding boxes drawn on top.

    Args:
        img_tensor: Normalized ``[3, H, W]`` float image tensor.
        boxes_cxcywh_norm: ``[N, 4]`` boxes as (cx, cy, w, h), normalized to
            the image width/height.
        labels: ``[N]`` integer class ids. Names are looked up by position in
            ``TRASH_ICRA19_CLASSES`` keys (assumes key order matches label ids
            — TODO confirm against the dataset's label mapping).
        title: Figure title.
    """
    # Undo normalization and clamp to [0, 1]: out-of-range values would
    # otherwise wrap (8-bit conversion) or be clipped with a warning (imshow).
    img = (img_tensor.detach().cpu() * std + mean).clamp(0.0, 1.0)
    img_np = img.permute(1, 2, 0).numpy()  # CHW -> HWC, as matplotlib expects
    img_height, img_width = img_np.shape[:2]

    _, ax = plt.subplots(1, figsize=(10, 10))
    ax.imshow(img_np)
    ax.set_title(title)
    ax.axis('off')

    class_names = list(TRASH_ICRA19_CLASSES.keys())
    for bbox_norm, label_id in zip(boxes_cxcywh_norm, labels):
        # Convert normalized cxcywh to the absolute top-left corner + size.
        cx, cy, w, h = bbox_norm.tolist()
        xmin = (cx - w / 2) * img_width
        ymin = (cy - h / 2) * img_height
        ax.add_patch(patches.Rectangle((xmin, ymin), w * img_width, h * img_height,
                                       linewidth=2, edgecolor='r', facecolor='none'))

        idx = int(label_id)
        # Guard against ids outside the class list instead of raising IndexError.
        class_name = class_names[idx] if 0 <= idx < len(class_names) else f"<unknown id {idx}>"
        ax.text(xmin, ymin - 5, class_name, color='red', fontsize=12,
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=0.5))
    plt.show()


def _print_batch_summary(images, targets):
    """Print shapes/dtypes of one collated batch and of each per-image target dict."""
    if torch.is_tensor(images):
        print(f"  Images batch shape: {images.shape}")  # Should be [BATCH_SIZE, 3, H, W]
    else:
        # NOTE(review): collate_fn is expected to stack images; tolerate a list just in case.
        print(f"  Image shapes: {[tuple(img.shape) for img in images]}")
    print(f"  Number of targets in batch: {len(targets)}")  # Should be BATCH_SIZE
    for i, target in enumerate(targets):
        print(f"    Target {i+1}:")
        print(f"      Boxes (normalized cxcywh): {target['boxes'].shape} {target['boxes'].dtype}")
        print(f"      Labels: {target['labels'].shape} {target['labels'].dtype} -> {target['labels'].tolist()}")
        print(f"      Image ID: {target['image_id'].item()}")
        print(f"      Area: {target['area'].shape}")
        print(f"      iscrowd: {target['iscrowd'].shape}")
        print(f"      Original Size: {target['orig_size'].tolist()}")  # Original H, W
        print(f"      Current Size: {target['size'].tolist()}")  # Resized H, W


def _inspect_first_batch(loader, title, no_boxes_msg):
    """Summarize the loader's first batch and plot its first image, if it has boxes."""
    first_batch = next(iter(loader), None)
    if first_batch is None:
        print("Loader is empty; nothing to verify.")
        return
    images, targets = first_batch
    print("Batch 1:")
    _print_batch_summary(images, targets)
    if len(images) > 0 and len(targets[0]['boxes']) > 0:
        show_image_with_boxes(images[0], targets[0]['boxes'], targets[0]['labels'], title=title)
    else:
        print(no_boxes_msg)


# Verify TACO Batch
print("\n--- Sample Batch from TACO Training Loader ---")
_inspect_first_batch(taco_train_loader,
                     "TACO Sample with Transformed Annotations",
                     "No bounding boxes found in first TACO batch for visualization.")

# Verify Trash-ICRA19 Batch
print("\n--- Sample Batch from Trash-ICRA19 Validation Loader ---")
_inspect_first_batch(icra19_val_loader,
                     "Trash-ICRA19 Val Sample with Transformed Annotations",
                     "No bounding boxes found in first Trash-ICRA19 batch for visualization.")

# Closing banner: both lines are emitted by one print call, separated by a newline.
print(
    "\n--- Preprocessing and Dataset Setup Complete for Phase 1 ---",
    "You are now ready to proceed to Phase 2: ViT Object Detection Training & Evaluation.",
    sep="\n",
)