# 🔍 CropShield AI - Pre-Training Diagnostic Check

**Purpose**: Verify dataset loading, GPU availability, and data pipeline before CNN model training

This notebook checks:
1. ✅ DataLoader functionality (FastImageFolder or WebDataset)
2. ✅ Tensor shapes, dtypes, and label mappings
3. ✅ Image visualization with class names
4. ✅ Loading performance (throughput measurement)
5. ✅ GPU accessibility and CUDA configuration

---

## 📦 Import Libraries

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import time
from pathlib import Path

# Confirm the import cell ran and record which PyTorch build is active
print(
    "‚úÖ Libraries imported successfully",
    f"PyTorch version: {torch.__version__}",
    sep="\n",
)

## 🖥️ Check GPU Availability

In [None]:
# Report CUDA availability, describe the active GPU and its memory state, and
# select the `device` used by the rest of the notebook.
print("="*80)
print("GPU DIAGNOSTICS")
print("="*80)

cuda_available = torch.cuda.is_available()
print(f"\nüéÆ CUDA Available: {cuda_available}")

if cuda_available:
    # Describe the device PyTorch will actually use rather than assuming
    # index 0, so every figure below matches the "Current Device" line.
    gpu_index = torch.cuda.current_device()
    print(f"   GPU Device: {torch.cuda.get_device_name(gpu_index)}")
    print(f"   CUDA Version: {torch.version.cuda}")
    print(f"   Current Device: cuda:{gpu_index}")

    # Memory info (all values converted from bytes to GiB)
    bytes_per_gib = 1024**3
    total_memory = torch.cuda.get_device_properties(gpu_index).total_memory / bytes_per_gib
    allocated = torch.cuda.memory_allocated(gpu_index) / bytes_per_gib
    cached = torch.cuda.memory_reserved(gpu_index) / bytes_per_gib

    # Ask the driver for free memory. `total - reserved` only accounts for
    # this process's caching allocator and over-reports free space whenever
    # another process (or the display) is also using the GPU.
    free_bytes, _ = torch.cuda.mem_get_info(gpu_index)
    free_memory = free_bytes / bytes_per_gib

    print(f"\nüíæ GPU Memory:")
    print(f"   Total: {total_memory:.2f} GB")
    print(f"   Allocated: {allocated:.3f} GB")
    print(f"   Cached: {cached:.3f} GB")
    print(f"   Free: {free_memory:.2f} GB")

    device = torch.device('cuda')
    print("\n‚úÖ GPU ready for training!")
else:
    device = torch.device('cpu')
    print("\n‚ö†Ô∏è  GPU not available, using CPU")
    print("   Training will be slower without GPU")

print(f"\nüéØ Using device: {device}")
print("="*80)

## 📁 Choose DataLoader Type

Select which optimized loader to use:

In [None]:
# Settings for this diagnostic run
USE_WEBDATASET = True  # True -> WebDataset shards, False -> FastImageFolder
BATCH_SIZE = 32
NUM_WORKERS = 0  # 0 on Windows; 12 on Linux/Mac

# Echo the chosen settings as one framed report
print("\n".join([
    "="*80,
    "DATALOADER CONFIGURATION",
    "="*80,
    f"\nüìä Loader type: {'WebDataset' if USE_WEBDATASET else 'FastImageFolder'}",
    f"   Batch size: {BATCH_SIZE}",
    f"   Num workers: {NUM_WORKERS}",
    "="*80,
]))

## 🔄 Load DataLoader

In [None]:
print("\nüì¶ Creating DataLoader...")

if not USE_WEBDATASET:
    # FastImageFolder path: needs the pre-resized image directory
    if not Path('Database_resized/').exists():
        print("‚ùå ERROR: Database_resized/ directory not found!")
        print("   Run: python scripts/resize_images.py")
        raise FileNotFoundError("Database_resized/ directory not found")

    from fast_dataset import make_loaders

    loaders = make_loaders(
        data_dir='Database_resized/',
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS
    )
    train_loader, val_loader, test_loader, class_to_idx, classes = loaders

    num_classes = len(classes)

else:
    # WebDataset path: needs the pre-built shard directory
    if not Path('shards/').exists():
        print("‚ùå ERROR: shards/ directory not found!")
        print("   Run: python scripts/create_webdataset_shards.py")
        raise FileNotFoundError("shards/ directory not found")

    from webdataset_loader import make_webdataset_loaders

    loaders = make_webdataset_loaders(
        shards_dir='shards/',
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS
    )
    train_loader, val_loader, test_loader, class_info = loaders

    # Class metadata is returned as a dict on this path
    classes = class_info['classes']
    num_classes = class_info['num_classes']

# Summarize what was loaded; only the ends of the class list are shown
print(f"\n‚úÖ DataLoader created successfully!")
print(f"\nüìä Dataset Statistics:")
print(f"   Number of classes: {num_classes}")
print(f"   Classes (first 5): {classes[:5]}")
print(f"   Classes (last 5): {classes[-5:]}")

## 🧪 Load One Batch and Analyze

In [None]:
print("="*80)
print("BATCH ANALYSIS")
print("="*80)

# Time the first batch. perf_counter() is a monotonic, high-resolution clock
# meant for measuring intervals; time.time() is wall-clock time and can be
# coarse or jump when the system clock is adjusted.
print("\n‚è±Ô∏è  Loading batch...")
start_time = time.perf_counter()

# Get first batch. The interval includes creating the loader iterator, so
# this is a cold-start figure rather than steady-state throughput.
first_batch = next(iter(train_loader), None)

load_time = time.perf_counter() - start_time

# An empty loader would otherwise surface as a bare StopIteration
if first_batch is None:
    raise RuntimeError("train_loader produced no batches; check the dataset location and split")

images, labels = first_batch

print(f"‚úÖ Batch loaded in {load_time:.3f} seconds")

# Analyze batch
print(f"\nüì¶ Tensor Information:")
print(f"   Images shape: {images.shape}")
print(f"   Images dtype: {images.dtype}")
print(f"   Labels shape: {labels.shape}")
print(f"   Labels dtype: {labels.dtype}")

print(f"\nüìä Value Ranges:")
print(f"   Image min: {images.min():.3f}")
print(f"   Image max: {images.max():.3f}")
print(f"   Image mean: {images.mean():.3f}")
print(f"   Image std: {images.std():.3f}")

print(f"\nüè∑Ô∏è  Label Information:")
print(f"   Label min: {labels.min()}")
print(f"   Label max: {labels.max()}")
print(f"   Unique labels in batch: {torch.unique(labels).tolist()}")

# Calculate throughput from the number of images actually received
batch_size_actual = images.size(0)
# Guard the division in case the clock reports a zero-length interval
throughput = batch_size_actual / load_time if load_time > 0 else float('inf')

print(f"\nüöÄ Performance:")
print(f"   Batch size: {batch_size_actual}")
print(f"   Load time: {load_time:.3f}s")
print(f"   Throughput: {throughput:.1f} images/second")

print("="*80)

## 🗺️ Sample Label Mapping

Show the class names for the first 8 labels in the batch:

In [None]:
print("="*80)
print("LABEL MAPPING (First 8 samples)")
print("="*80)

# Walk the leading labels (at most 8) and resolve each index to its class name
for sample_no, label in enumerate(labels[:8], start=1):
    label_idx = label.item()
    print(f"   Sample {sample_no}: Label {label_idx:2d} ‚Üí {classes[label_idx]}")

print("="*80)

## 🖼️ Visualize Sample Images

Display a grid of 8 images with their class names:

In [None]:
def denormalize(tensor):
    """
    Undo ImageNet mean/std normalization and clamp the result to [0, 1].

    Args:
        tensor: Normalized image tensor whose channel axis has size 3 and is
            third from last, e.g. (3, H, W) or (N, 3, H, W). May live on any
            device.

    Returns:
        Tensor of the same shape on the same device, with values in [0, 1].
    """
    # Create the per-channel statistics on the input's device. Building them
    # on the CPU unconditionally raises a device-mismatch error as soon as the
    # input is a CUDA tensor.
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(3, 1, 1)

    # Invert (x - mean) / std, then clip so the result is a displayable image
    return torch.clamp(tensor * std + mean, 0, 1)


def show_batch(images, labels, classes, num_images=8):
    """
    Display a grid of images with class names.

    Args:
        images: Batch of normalized image tensors; each images[i] is passed
            to denormalize() and then transposed from (C, H, W) to (H, W, C).
        labels: Tensor of class indices; labels[i].item() indexes `classes`.
        classes: Sequence mapping a class index to its display name.
        num_images: Maximum number of images to draw. Capped at the batch
            size; the grid grows in rows of four to fit the request.
    """
    num_images = max(0, min(num_images, len(images)))

    # Size the grid from the number of images instead of a fixed 2x4 layout,
    # which raised IndexError for more than 8 images. Eight images still give
    # the original 2x4 grid at 16x8 inches.
    n_cols = 4
    n_rows = max(1, (num_images + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
    )
    # squeeze=False always returns a 2-D array, so flatten() is safe for one row
    axes = axes.flatten()

    for i in range(num_images):
        label_idx = labels[i].item()
        class_name = classes[label_idx]

        # Move to the CPU before denormalizing so batches that already live
        # on the GPU can be displayed as well
        img = denormalize(images[i].detach().cpu())

        # Convert to numpy and transpose to (H, W, C) for imshow
        img_np = img.numpy().transpose(1, 2, 0)

        axes[i].imshow(img_np)
        axes[i].set_title(f"{class_name}\n(Label: {label_idx})", fontsize=10)
        axes[i].axis('off')

    # Hide grid cells that received no image so they do not show empty frames
    for unused_ax in axes[num_images:]:
        unused_ax.axis('off')

    plt.suptitle('Sample Images from Training Batch', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Render the first 8 samples of the training batch as a labelled grid
print("\nüñºÔ∏è  Displaying sample images...\n")
show_batch(images=images, labels=labels, classes=classes, num_images=8)

## ⚡ GPU Transfer Test

Test moving data to GPU (if available):

In [None]:
print("="*80)
print("GPU TRANSFER TEST")
print("="*80)

if cuda_available:
    print("\n‚è±Ô∏è  Testing GPU transfer speed...")

    # Warm up first: the first CUDA allocation in a process pays one-off
    # start-up costs that would otherwise be counted as transfer time.
    torch.zeros(1, device=device)
    torch.cuda.synchronize()

    # Time the host-to-device copy. non_blocking=True can only overlap with
    # host work when the source tensor is in pinned memory, so the pinned
    # state is reported alongside the result.
    start_time = time.perf_counter()
    images_gpu = images.to(device, non_blocking=True)
    labels_gpu = labels.to(device, non_blocking=True)
    torch.cuda.synchronize()  # Wait for transfer to complete
    transfer_time = time.perf_counter() - start_time

    print(f"‚úÖ Transfer complete in {transfer_time*1000:.2f}ms")
    print(f"   Source tensor pinned: {images.is_pinned()}")

    # Check memory usage
    allocated = torch.cuda.memory_allocated(0) / (1024**3)
    cached = torch.cuda.memory_reserved(0) / (1024**3)

    print(f"\nüíæ GPU Memory After Transfer:")
    print(f"   Allocated: {allocated:.3f} GB")
    print(f"   Cached: {cached:.3f} GB")

    # Verify data on GPU
    print(f"\n‚úÖ Data on GPU:")
    print(f"   Images device: {images_gpu.device}")
    print(f"   Labels device: {labels_gpu.device}")

    # Estimate batch transfer speed from the image payload (labels are
    # negligible by comparison); guard against a zero-length interval
    batch_size_mb = (images.numel() * images.element_size()) / (1024**2)
    transfer_speed = batch_size_mb / transfer_time if transfer_time > 0 else float('inf')

    print(f"\nüöÄ Transfer Performance:")
    print(f"   Batch size: {batch_size_mb:.2f} MB")
    print(f"   Transfer speed: {transfer_speed:.1f} MB/s")

    # Clean up GPU memory
    del images_gpu, labels_gpu
    torch.cuda.empty_cache()

else:
    print("\n‚ö†Ô∏è  No GPU available, skipping transfer test")

print("="*80)

## 📊 Multi-Batch Loading Test

Test loading multiple batches to measure sustained throughput:

In [None]:
print("="*80)
print("MULTI-BATCH LOADING TEST")
print("="*80)

NUM_TEST_BATCHES = 10
print(f"\n‚è±Ô∏è  Loading {NUM_TEST_BATCHES} batches...")

batch_times = []
total_images = 0
batches_loaded = 0

# Each interval runs from the end of the previous batch to the end of the
# current one, so it covers fetching the batch from the loader as well as the
# device copy. Starting the clock inside the loop body would miss the fetch
# entirely, because the batch has already been produced by then. The first
# interval also includes creating the loader iterator.
start_time = time.perf_counter()
batch_start = start_time

for batch_images, batch_labels in train_loader:
    # Simulate processing (move to GPU if available)
    if cuda_available:
        batch_images = batch_images.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)
        # A non_blocking copy may return before the data has arrived; wait so
        # the interval reflects the completed transfer.
        torch.cuda.synchronize()

    batch_end = time.perf_counter()
    batch_times.append(batch_end - batch_start)
    batch_start = batch_end

    total_images += batch_images.size(0)
    batches_loaded += 1

    if batches_loaded >= NUM_TEST_BATCHES:
        break

total_time = time.perf_counter() - start_time

# Without at least one batch none of the statistics below are defined
if batches_loaded == 0:
    raise RuntimeError("train_loader produced no batches; cannot measure throughput")

avg_batch_time = np.mean(batch_times)
std_batch_time = np.std(batch_times)
# Guard the divisions in case the clock reports a zero-length interval
throughput = total_images / total_time if total_time > 0 else float('inf')
batches_per_second = batches_loaded / total_time if total_time > 0 else float('inf')

# Report the number of batches actually consumed; the loader may hold fewer
# than NUM_TEST_BATCHES
print(f"\n‚úÖ Loaded {batches_loaded} batches ({total_images} images)")

print(f"\nüìä Loading Statistics:")
print(f"   Total time: {total_time:.3f}s")
print(f"   Avg batch time: {avg_batch_time*1000:.1f}ms ¬± {std_batch_time*1000:.1f}ms")
print(f"   Min batch time: {min(batch_times)*1000:.1f}ms")
print(f"   Max batch time: {max(batch_times)*1000:.1f}ms")

print(f"\nüöÄ Sustained Throughput:")
print(f"   {throughput:.1f} images/second")
print(f"   {batches_per_second:.2f} batches/second")

# Estimate epoch time. NOTE(review): both sample counts are hard-coded for one
# particular dataset split; update them if the dataset changes.
if USE_WEBDATASET:
    train_samples = 17901  # From metadata
else:
    train_samples = 17909  # Approximate (80% of 22387)

epoch_time = train_samples / throughput
print(f"\n‚è±Ô∏è  Estimated Training Time:")
print(f"   Samples per epoch: {train_samples}")
print(f"   Time per epoch: {epoch_time:.1f}s ({epoch_time/60:.2f} minutes)")
print(f"   Time for 50 epochs: {epoch_time*50/60:.1f} minutes ({epoch_time*50/3600:.2f} hours)")

print("="*80)

## ✅ Final Diagnostic Summary

In [None]:
print("\n" + "="*80)
print("DIAGNOSTIC SUMMARY")
print("="*80)

# Status markers; the failure marker is the one the loader cell uses for errors
PASS_MARK, WARN_MARK, FAIL_MARK = "‚úÖ", "‚ö†Ô∏è ", "‚ùå"
mark = {True: PASS_MARK, False: FAIL_MARK}

# Evaluate every check from the tensors loaded earlier instead of reporting a
# hard-coded pass, so a broken pipeline cannot produce a green summary.
batch_ok = images.size(0) > 0
# Expect a (N, 3, H, W) image batch with one label per image
shapes_ok = (
    images.ndim == 4
    and images.shape[1] == 3
    and labels.ndim == 1
    and labels.shape[0] == images.shape[0]
)
# Labels must be integer class indices inside [0, num_classes)
labels_ok = (
    labels.numel() > 0
    and not labels.is_floating_point()
    and 0 <= int(labels.min())
    and int(labels.max()) < num_classes
)
# Normalized images should be floating point with no NaN/Inf values
values_ok = images.is_floating_point() and bool(torch.isfinite(images).all())
throughput_ok = throughput > 0 and throughput != float('inf')

label_detail = (
    f"Range: {labels.min()}-{labels.max()}" if labels.numel() > 0 else "no labels"
)

# Check all systems. A missing GPU is a warning, not a failure: training
# still works on the CPU, only slower.
checks = [
    (PASS_MARK, "PyTorch installed", f"v{torch.__version__}"),
    (PASS_MARK if cuda_available else WARN_MARK, "GPU available",
     f"{torch.cuda.get_device_name(0)}" if cuda_available else "CPU only"),
    (PASS_MARK, "DataLoader created", f"{'WebDataset' if USE_WEBDATASET else 'FastImageFolder'}"),
    (mark[batch_ok], "Batch loading works", f"{images.size(0)} images/batch"),
    (mark[shapes_ok], "Tensor shapes correct", f"{images.shape}"),
    (mark[labels_ok], "Labels valid", label_detail),
    (mark[values_ok], "Image normalization", f"Mean: {images.mean():.3f}, Std: {images.std():.3f}"),
    (mark[throughput_ok], "Throughput measured", f"{throughput:.1f} img/s"),
]
failed_checks = [check for status, check, _ in checks if status == FAIL_MARK]

print("\nüìã System Checks:")
for status, check, detail in checks:
    print(f"   {status} {check:<30} {detail}")

# Report the image size from the loaded batch rather than a fixed 224x224x3
if images.ndim >= 3:
    image_size = "√ó".join(
        str(dim) for dim in (images.shape[-2], images.shape[-1], images.shape[-3])
    )
else:
    image_size = str(tuple(images.shape))

print("\nüìä Dataset Configuration:")
print(f"   Number of classes: {num_classes}")
print(f"   Training samples: ~{train_samples:,}")
print(f"   Image size: {image_size}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Num workers: {NUM_WORKERS}")

print("\nüöÄ Performance:")
print(f"   Data loading: {throughput:.1f} img/s")
print(f"   Epoch time: ~{epoch_time/60:.2f} minutes")
print(f"   50 epochs: ~{epoch_time*50/60:.1f} minutes")

if cuda_available:
    print(f"\nüíæ GPU Memory:")
    total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"   Total: {total_memory:.2f} GB")
    # Rough planning figure: assumes about 10% of memory is unavailable
    print(f"   Available for training: ~{total_memory*0.9:.2f} GB")

print("\n" + "="*80)
if failed_checks:
    # Name the failing checks so the problem is visible without scrolling
    print(f"‚ùå DIAGNOSTICS FAILED: {', '.join(failed_checks)}")
    print("="*80)
    print("\nResolve the failed checks above before starting training.")
else:
    print("üéâ ALL DIAGNOSTICS PASSED!")
    print("="*80)

    print("\n‚úÖ Ready for CNN Model Training!\n")
    print("Next steps:")
    print("   1. Build CNN model architecture (ResNet50, EfficientNet, or custom)")
    print("   2. Create training script with this optimized DataLoader")
    print("   3. Add mixed precision training (AMP) for faster GPU utilization")
    print("   4. Implement learning rate scheduling and early stopping")
    print("   5. Monitor training with TensorBoard")
    print("\nüöÄ Let's build that model!")