# Data Pipeline Verification

This notebook verifies that the custom dataset and data loaders work correctly.

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from src.data import create_dataloaders, get_class_names

# Notebook-wide default figure size as (width, height); individual plots
# below may still override it with their own figsize.
plt.rcParams.update({'figure.figsize': (15, 10)})

## 1. Create Data Loaders

In [None]:
# Single source of truth for the config file, so the loaders and the class
# names are guaranteed to be read from the same place.
CONFIG_PATH = '../configs/config.yaml'

# Create the train / validation / test data loaders
train_loader, val_loader, test_loader = create_dataloaders(
    config_path=CONFIG_PATH,
    batch_size=4,
    num_workers=0  # Use 0 for debugging, increase for training
)

# Class names are indexed by integer class id in the cells below
class_names = get_class_names(CONFIG_PATH)
print(f"\nClass names: {class_names}")

## 2. Visualize Sample Batch

In [None]:
def denormalize(tensor):
    """Undo per-channel normalization so an image tensor can be displayed.

    Computes ``tensor * std + mean`` using the fixed per-channel statistics
    below, broadcast over the two trailing (spatial) dimensions.

    Args:
        tensor: Image tensor whose channel dimension has size 3 and sits
            third from the end, e.g. ``(3, H, W)`` or ``(N, 3, H, W)``.
            Assumes the pipeline normalized with these same statistics --
            TODO confirm against the dataset transforms.

    Returns:
        A new tensor on the same device as ``tensor``. Values are not
        clipped; callers clip to ``[0, 1]`` before display if needed.
    """
    # Build the statistics on the input's device: constants created on the
    # CPU would raise a device-mismatch error when the input lives on CUDA.
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(3, 1, 1)
    return tensor * std + mean

def plot_batch(images, targets, class_names, title="Batch Visualization"):
    """Plot up to four images of a batch with their bounding boxes.

    Images are drawn on a fixed 2x2 grid; any images beyond the fourth are
    not shown, and grid cells without an image are hidden.

    Args:
        images: Batched image tensor; ``images[i]`` is a normalized
            ``(C, H, W)`` image that ``denormalize`` maps back to display range.
        targets: 2-D tensor with one row per object, laid out as
            ``[batch_idx, class_id, x_center, y_center, width, height]``,
            where the four box values are fractions of the image size.
        class_names: Indexable by integer class id, yielding the label text.
        title: Figure-level title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    n_shown = min(images.shape[0], 4)
    fig, axes = plt.subplots(2, 2, figsize=(15, 15))
    axes = axes.flatten()

    # One color per class id, reused cyclically when there are more classes
    colors = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow']

    for i in range(n_shown):
        ax = axes[i]

        # Detach and move to CPU first so GPU / grad-tracking batches work,
        # then denormalize and convert CHW -> HWC for imshow.
        img = denormalize(images[i].detach().cpu()).permute(1, 2, 0).numpy()
        img = np.clip(img, 0, 1)

        ax.imshow(img)

        # Rows of `targets` that belong to this image (column 0 = batch index)
        img_targets = targets[targets[:, 0] == i]

        # Draw bounding boxes
        h, w = img.shape[:2]
        for target in img_targets:
            class_id = int(target[1].item())
            color = colors[class_id % len(colors)]
            x_center, y_center, box_w, box_h = target[2:].tolist()

            # Convert normalized center/size to a pixel-space width/height
            # and top-left corner, which is what Rectangle expects.
            box_w *= w
            box_h *= h
            x1 = x_center * w - box_w / 2
            y1 = y_center * h - box_h / 2

            ax.add_patch(patches.Rectangle(
                (x1, y1), box_w, box_h,
                linewidth=2,
                edgecolor=color,
                facecolor='none'
            ))

            # Class label just above the box's top-left corner
            ax.text(
                x1, y1 - 5,
                class_names[class_id],
                color='white',
                fontsize=8,
                bbox=dict(facecolor=color, alpha=0.7)
            )

        ax.set_title(f"Image {i+1} - {len(img_targets)} objects")
        ax.axis('off')

    # Hide grid cells that received no image (batches smaller than four);
    # otherwise they render as empty framed axes with ticks.
    for unused_ax in axes[n_shown:]:
        unused_ax.axis('off')

    fig.suptitle(title, fontsize=16, fontweight='bold')
    fig.tight_layout()
    plt.show()

# Pull the first batch the training loader yields and report its layout
images, targets, metadata = next(iter(train_loader))
print(
    f"Batch shape: {images.shape}",
    f"Targets shape: {targets.shape}",
    f"Number of objects in batch: {len(targets)}",
    sep="\n",
)

# Draw the batch with its bounding boxes
plot_batch(
    images,
    targets,
    class_names,
    title="Training Batch with Augmentations",
)

## 3. Check Augmentations

In [None]:
# Draw three training batches in a row so differences between augmented
# samples can be compared visually.
print("Loading 3 batches to show augmentation variety...\n")

for batch_idx in range(3):
    # A fresh iterator is built on every pass, so each draw takes the first
    # batch of a newly started loader.
    # NOTE(review): whether that batch holds the same source images each
    # time depends on the loader's shuffle setting -- confirm.
    images, targets, metadata = next(iter(train_loader))
    plot_batch(
        images,
        targets,
        class_names,
        title=f"Augmented Batch {batch_idx + 1}",
    )

## 4. Verify Validation Loader (No Augmentation)

In [None]:
# Take the first batch from the validation loader and draw it the same way
# as the training batches, for side-by-side comparison.
val_images, val_targets, val_metadata = next(iter(val_loader))
plot_batch(
    val_images,
    val_targets,
    class_names,
    title="Validation Batch (No Augmentation)",
)

## 5. Check Class Distribution in Batches

In [None]:
from collections import Counter
from itertools import islice

# Count class ids over the first 10 training batches.
class_counter = Counter()

# islice stops after exactly 10 batches; an enumerate + break loop would
# fetch (and augment) an 11th batch only to throw it away.
for images, targets, metadata in islice(train_loader, 10):
    if len(targets):
        # Column 1 of each target row is the class id; .long() truncates the
        # same way int(value.item()) does, but for the whole batch at once.
        class_counter.update(targets[:, 1].long().tolist())

total_objects = sum(class_counter.values())
class_ids = sorted(class_counter)

print(f"Class distribution in first 10 batches ({total_objects} objects):\n")
# With no objects at all, class_ids is empty, so the division below is
# never reached with total_objects == 0.
for class_id in class_ids:
    count = class_counter[class_id]
    percentage = (count / total_objects) * 100
    print(f"{class_names[class_id]:20s}: {count:4d} ({percentage:5.2f}%)")

# Visualize the counts as a bar chart, one bar per class seen
classes = [class_names[i] for i in class_ids]
counts = [class_counter[i] for i in class_ids]

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(classes, counts, color='steelblue', edgecolor='black')
ax.set_xlabel('Algae Type', fontweight='bold')
ax.set_ylabel('Count', fontweight='bold')
ax.set_title('Class Distribution in Sample Batches', fontweight='bold')
# Rotate the tick labels so long class names do not overlap
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
fig.tight_layout()
plt.show()

## 6. Verify Tensor Shapes and Data Types

In [None]:
# Fetch one training batch and print a summary of what the pipeline yields
images, targets, metadata = next(iter(train_loader))

print("=" * 60)
print("DATA PIPELINE VERIFICATION SUMMARY")
print("=" * 60)
print("\nImages:")
print(f"  Shape: {images.shape}")
print(f"  Dtype: {images.dtype}")
print(f"  Device: {images.device}")
print(f"  Min value: {images.min():.3f}")
print(f"  Max value: {images.max():.3f}")

print("\nTargets:")
print(f"  Shape: {targets.shape}")
print(f"  Dtype: {targets.dtype}")
print("  Format: [batch_idx, class_id, x_center, y_center, width, height]")
# A batch may contain no objects at all; indexing targets[0] would then
# raise IndexError and abort the summary.
sample_target = targets[0] if len(targets) else "n/a (no objects in this batch)"
print(f"  Sample target: {sample_target}")

print("\nMetadata:")
print(f"  Batch size: {len(metadata)}")
print(f"  Sample metadata: {metadata[0]}")

print("\n" + "=" * 60)
# The check mark was stored as mojibake (UTF-8 bytes decoded as cp1252);
# restored to the intended character.
print("✅ DATA PIPELINE VERIFICATION COMPLETE!")
print("=" * 60)