# 🖼️ DataLoader Batch Visualization

**Purpose**: Visually confirm training augmentations look realistic

**What to check**:
- ✅ Images are properly denormalized (visible colors)
- ✅ Augmentations appear natural (rotations, brightness, flips)
- ✅ Class labels are correct
- ✅ No artifacts or corruption
- ✅ Training vs validation differences visible

In [None]:
# Import libraries
import torch
import matplotlib.pyplot as plt
from visualize_batch import (
    visualize_batch, 
    check_augmentation_statistics,
    visualize_augmentation_comparison,
    denormalize
)
from fast_dataset import make_loaders

# Configure matplotlib for better display
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 150
plt.rcParams['font.size'] = 10

print("✅ Libraries loaded!")

## 1️⃣ Load DataLoaders

In [None]:
# Create DataLoaders with MODERATE augmentation
train_loader, val_loader, test_loader, class_names, info = make_loaders(
    data_dir='Database_resized/',
    batch_size=32,
    augmentation_mode='moderate'  # Try: conservative, moderate, aggressive
)

print("\n✅ DataLoaders created!")
print(f"   Classes: {info['num_classes']}")
print(f"   Training: {info['train_size']} images ({len(train_loader)} batches)")
print(f"   Validation: {info['val_size']} images ({len(val_loader)} batches)")
print(f"   Test: {info['test_size']} images ({len(test_loader)} batches)")
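
The batch counts printed above can be cross-checked against the dataset sizes. The following is a minimal sketch, assuming the loaders use PyTorch's default `drop_last=False` (this notebook does not show how `make_loaders` configures it). `expected_batches` is a hypothetical helper, not part of `fast_dataset`, and the sizes in the example are made up.

```python
import math


def expected_batches(num_images: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields for a dataset of num_images."""
    if drop_last:
        # The final partial batch is discarded.
        return num_images // batch_size
    # The final partial batch is kept.
    return math.ceil(num_images / batch_size)


# Example with made-up sizes: 1000 images at batch_size=32.
print(expected_batches(1000, 32))                  # 32
print(expected_batches(1000, 32, drop_last=True))  # 31
```

If `len(train_loader)` differs from `expected_batches(info['train_size'], 32)`, the loader may be using `drop_last=True` or a custom sampler.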

## 2️⃣ Visualize Training Batch (WITH Augmentations)

This shows your training data **with** augmentations applied.

**What to look for**:
- Natural-looking rotations (no extreme distortions)
- Realistic color variations (brightness, contrast)
- Proper flips (horizontal/vertical)
- Disease features still visible

In [None]:
# Visualize training batch
fig, axes = visualize_batch(
    train_loader, 
    class_names, 
    num_images=10,
    figsize=(20, 10),
    title="Training Batch with MODERATE Augmentations"
)
plt.show()

print("\n👆 Inspect the images above:")
print("   • Do rotations look natural?")
print("   • Are colors realistic (not oversaturated)?")
print("   • Can you still identify disease symptoms?")
print("   • Are class labels correct?")

## 3️⃣ Visualize Validation Batch (NO Augmentations)

This shows validation data **without** augmentations.

**Expected**: Clean, centered crops with no randomness.

In [None]:
# Visualize validation batch
fig, axes = visualize_batch(
    val_loader, 
    class_names, 
    num_images=10,
    figsize=(20, 10),
    title="Validation Batch (No Augmentations)"
)
plt.show()

print("\n👆 Compare with training batch:")
print("   • Validation images should look more consistent")
print("   • No random rotations or flips")
print("   • Centered crops")
print("   • Standard brightness/contrast")

## 4️⃣ Check Augmentation Variety

This cell fetches the **same images multiple times** to show how the random augmentations vary between draws.

**Expected**: Each row shows the same source image (and therefore the same class) with different transforms applied.

In [None]:
# Compare augmentation variety
fig, axes = visualize_augmentation_comparison(
    train_loader,
    class_names,
    num_samples=3,
    figsize=(20, 12)
)
plt.show()

print("\n👆 Each row should show variety:")
print("   • Different rotations")
print("   • Different brightness levels")
print("   • Different crops")
print("   • But SAME underlying class")
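
The variety comes from the transforms being re-sampled every time an item is fetched. The sketch below illustrates that idea with a toy NumPy dataset. `ToyAugmentedDataset` and its flip and brightness parameters are hypothetical and are not how `fast_dataset` or `visualize_augmentation_comparison` are implemented.

```python
import numpy as np


class ToyAugmentedDataset:
    """Hypothetical stand-in for a training dataset with random augmentation."""

    def __init__(self, images, labels, seed=0):
        self.images = images  # array of shape (N, H, W, C), values in [0, 1]
        self.labels = labels
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx].copy()
        # Random horizontal flip with probability 0.5.
        if self.rng.random() < 0.5:
            img = img[:, ::-1, :]
        # Random brightness scaling, clipped back into [0, 1].
        factor = self.rng.uniform(0.8, 1.2)
        img = np.clip(img * factor, 0.0, 1.0)
        return img, self.labels[idx]


# Synthetic images stand in for real data.
images = np.random.default_rng(1).random((4, 8, 8, 3))
labels = [0, 1, 2, 3]
dataset = ToyAugmentedDataset(images, labels)

# Fetch index 0 three times: the label is unchanged, the pixels are not.
views = [dataset[0] for _ in range(3)]
print([label for _, label in views])
print(np.array_equal(views[0][0], views[1][0]))
```

If repeated fetches of the same index ever return identical pixels for a training loader, the augmentation pipeline is probably not being applied.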

## 5️⃣ Statistical Validation

Verify that augmented batches have reasonable statistics: pixel mean and standard deviation, value range after denormalization, and class coverage.

In [None]:
# Check statistics
check_augmentation_statistics(train_loader, num_batches=5)

print("\n✅ If all checks passed, your augmentations are working correctly!")
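
`check_augmentation_statistics` is imported from `visualize_batch`, and its internals are not shown in this notebook. As a rough illustration of the kind of summary such a check might compute, here is a sketch that operates on a NumPy array of already-denormalized images. `summarize_batch` is a hypothetical helper, and the synthetic batch is only a placeholder for real data.

```python
import numpy as np


def summarize_batch(images: np.ndarray) -> dict:
    """Summarize a denormalized batch of shape (N, C, H, W)."""
    return {
        "mean": float(images.mean()),
        "std": float(images.std()),
        "min": float(images.min()),
        "max": float(images.max()),
        # Per-channel means help reveal color casts.
        "channel_mean": images.mean(axis=(0, 2, 3)).tolist(),
        # True when every pixel lies in the displayable range [0, 1].
        "in_range": bool(images.min() >= 0.0 and images.max() <= 1.0),
    }


# Synthetic batch standing in for denormalized images.
rng = np.random.default_rng(0)
batch = rng.uniform(0.0, 1.0, size=(32, 3, 64, 64))
stats = summarize_batch(batch)
print(stats["in_range"], round(stats["mean"], 2), round(stats["std"], 2))
```

Values outside [0, 1] after denormalization usually point to mismatched mean/std constants rather than to the augmentations themselves.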

## 6️⃣ Interactive Exploration (Optional)

Manually inspect individual images and their transformations.

In [None]:
# Get a single batch
images, labels = next(iter(train_loader))

# Pick an image to inspect
idx = 0  # Change this to inspect other images in the batch (0-31 for batch_size=32)

# Denormalize and display
img_denorm = denormalize(images[idx])
img_np = img_denorm.permute(1, 2, 0).cpu().numpy()

# Get label
label_idx = labels[idx].item()
class_name = class_names[label_idx]

# Display
plt.figure(figsize=(8, 8))
plt.imshow(img_np)
plt.title(f"Image {idx}: {class_name} (Class {label_idx})", fontsize=14, fontweight='bold')
plt.axis('off')
plt.tight_layout()
plt.show()

# Print statistics
print("\n📊 Image Statistics:")
print(f"   Shape: {img_np.shape}")
print(f"   Min: {img_np.min():.4f}")
print(f"   Max: {img_np.max():.4f}")
print(f"   Mean: {img_np.mean():.4f}")
print(f"   Std: {img_np.std():.4f}")
print(f"\n   Per-channel mean: R={img_np[:,:,0].mean():.4f}, G={img_np[:,:,1].mean():.4f}, B={img_np[:,:,2].mean():.4f}")
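
The `denormalize` helper used above comes from `visualize_batch`, and its implementation is not shown here. A minimal sketch of what such a function typically does, assuming the images were normalized with the common ImageNet mean and standard deviation (an assumption; the real pipeline may use different constants). `denormalize_sketch` is a hypothetical name.

```python
import numpy as np

# Assumed normalization constants (ImageNet); the real pipeline may use others.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def denormalize_sketch(img_chw: np.ndarray) -> np.ndarray:
    """Invert (x - mean) / std for one (C, H, W) image and clip to [0, 1]."""
    restored = img_chw * STD[:, None, None] + MEAN[:, None, None]
    return np.clip(restored, 0.0, 1.0)


# Round trip: normalize a random image, then undo the normalization.
original = np.random.default_rng(0).random((3, 4, 4)).astype(np.float32)
normalized = (original - MEAN[:, None, None]) / STD[:, None, None]
recovered = denormalize_sketch(normalized)
print(np.allclose(recovered, original, atol=1e-5))
```

Clipping keeps the result displayable with `plt.imshow`, which expects float images in [0, 1].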

## 7️⃣ Compare Multiple Augmentation Modes (Optional)

Compare conservative vs moderate vs aggressive augmentation strengths.

In [None]:
# Load different augmentation modes
modes = ['conservative', 'moderate', 'aggressive']
loaders = {}

for mode in modes:
    print(f"\nLoading {mode.upper()} mode...")
    train_l, _, _, _, _ = make_loaders(
        data_dir='Database_resized/',
        batch_size=32,
        augmentation_mode=mode
    )
    loaders[mode] = train_l

# Visualize each mode
for mode, loader in loaders.items():
    print(f"\n{'='*60}")
    print(f"{mode.upper()} MODE")
    print('='*60)
    
    fig, axes = visualize_batch(
        loader, 
        class_names, 
        num_images=8,
        figsize=(20, 8),
        title=f"{mode.capitalize()} Augmentation Mode"
    )
    plt.show()

print("\n✅ Comparison complete!")
print("\n💡 Which mode looks best?")
print("   • CONSERVATIVE: Minimal changes, safe choice")
print("   • MODERATE: Balanced, recommended for most cases")
print("   • AGGRESSIVE: Strong transforms, use if overfitting")

## ✅ Summary

### What You Should See:

**Training Batch**:
- ✅ Natural-looking rotations (±15°)
- ✅ Realistic brightness/contrast variations
- ✅ Random flips (horizontal/vertical)
- ✅ Disease features still recognizable
- ✅ Colors in valid range [0, 1]

**Validation Batch**:
- ✅ Consistent, centered crops
- ✅ No random transformations
- ✅ Standard appearance

**Statistics**:
- ✅ Mean ≈ 0.3-0.5 (natural images)
- ✅ Std ≈ 0.15-0.25 (good variety)
- ✅ Values in [0, 1] after denormalization
- ✅ Multiple classes present

### Next Steps:

1. **If augmentations look good** → Proceed to model training
2. **If too strong** → Switch to `conservative` mode
3. **If too weak** → Switch to `aggressive` mode
4. **If artifacts present** → Check transform pipeline in `transforms.py`