# Quick subset + train pipeline (Capstone-Lazarus)

This notebook inspects `data/`, creates a small balanced subset, and trains a head-only transfer-learning model (timm EfficientNet-B0). Designed for fast experimentation on a laptop.

## Features:
- 🎯 Balanced stratified subset creation
- ⚡ Fast training with frozen backbone
- 🚀 AMP (Automatic Mixed Precision) support
- 💾 Automatic checkpointing
- 📊 Real-time metrics tracking

**Prerequisites:** Run this notebook from the repository root, where the `data/` directory exists.

In [None]:
# Install required libs (run once in notebook)
!pip install --upgrade pip
!pip install torch torchvision timm albumentations pillow scikit-learn tqdm

In [None]:
# Check CUDA availability and set device
import torch
import sys
from pathlib import Path

# Query CUDA once and derive everything else from that single answer.
cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ok else 'cpu')
print(f"Using device: {device}")

if not cuda_ok:
    print("CUDA not available, using CPU")
else:
    # Report which GPU (index 0) and which CUDA build torch was compiled against.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

# Environment versions, useful when comparing runs across machines.
print(f"PyTorch version: {torch.__version__}")
print(f"Python version: {sys.version}")

In [None]:
# Inspect original data structure
from pathlib import Path

DATA_DIR = Path('data')
# Raise explicitly instead of using `assert`, which is stripped under `python -O`.
if not DATA_DIR.is_dir():
    raise FileNotFoundError("Run this notebook from repo root where `data/` exists.")

# File suffixes (lower-cased) that are counted as images.
IMAGE_EXTS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'})

print("🔍 Inspecting original dataset structure:")
print("=" * 50)

# One sub-directory of data/ per class; map class name -> number of image files.
class_counts = {}
for p in sorted(d for d in DATA_DIR.iterdir() if d.is_dir()):
    image_count = sum(1 for f in p.iterdir() if f.suffix.lower() in IMAGE_EXTS)
    class_counts[p.name] = image_count
    print(f"{p.name:<40} {image_count:>6} images")

total_images = sum(class_counts.values())
num_classes = len(class_counts)
print("=" * 50)
print(f"Total classes: {num_classes}")
print(f"Total images: {total_images}")
if num_classes:
    print(f"Average per class: {total_images/num_classes:.1f}")
else:
    # Avoid a ZeroDivisionError when data/ contains no class directories.
    print("Average per class: n/a (no class directories found)")

In [None]:
# Create balanced subset using our script
# Parameters: adjust samples_per_class small for quick runs
SAMPLES_PER_CLASS = 50   # try 30-100 for quick experiments
VAL_RATIO = 0.2

print(f"🎯 Creating balanced subset:")
print(f"   Samples per class: {SAMPLES_PER_CLASS}")
print(f"   Validation ratio: {VAL_RATIO}")
print(f"   Expected total: ~{SAMPLES_PER_CLASS * num_classes} images")
print("-" * 50)

!python scripts/create_subset.py --data-dir data --out-dir data_subset --samples-per-class {SAMPLES_PER_CLASS} --val-ratio {VAL_RATIO} --seed 42 --symlink true

In [None]:
# Create data loaders with Albumentations transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import torch

class ImageFolderAlb(Dataset):
    """Image-folder dataset that applies an Albumentations transform.

    The expected layout is ``root/<class_name>/<image file>``. Every immediate
    sub-directory of ``root`` is a class; class indices follow the sorted
    directory order.

    Attributes:
        root: Dataset root as a ``Path``.
        classes: Class names, in index order.
        class_to_idx: Mapping from class name to integer label.
        samples: List of ``(image_path, class_name)`` pairs, in a
            deterministic (sorted) order.
        transform: Optional callable invoked as ``transform(image=array)``;
            its ``'image'`` entry is returned as the sample.
    """

    # File suffixes (lower-cased) that are treated as images.
    IMAGE_EXTS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'})

    def __init__(self, root, transform=None):
        self.root = Path(root)
        self.transform = transform

        class_dirs = sorted(p for p in self.root.iterdir() if p.is_dir())
        self.classes = [d.name for d in class_dirs]
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}

        self.samples = []
        for cls_dir in class_dirs:
            # iterdir() order is filesystem-dependent; sort so that sample
            # indices are reproducible across machines and runs.
            for img in sorted(cls_dir.iterdir()):
                # is_file() skips directories that merely carry an image-like suffix.
                if img.is_file() and img.suffix.lower() in self.IMAGE_EXTS:
                    self.samples.append((img, cls_dir.name))

        print(f"Found {len(self.samples)} images in {len(class_dirs)} classes")

    def __len__(self):
        """Return the number of discovered images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is decoded as RGB into a NumPy array and, when a transform
        is set, replaced by ``transform(image=...)['image']``. ``label`` is the
        integer class index.
        """
        path, cls_name = self.samples[idx]
        # The context manager closes the underlying file handle deterministically.
        with Image.open(path) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        if self.transform:
            img = self.transform(image=img)['image']
        return img, self.class_to_idx[cls_name]

def get_transforms(img_size=224, split='train'):
    """Build the Albumentations pipeline for one data split.

    Args:
        img_size: Side length (pixels) of the square output image.
        split: ``'train'`` adds random augmentation; any other value yields
            the deterministic validation pipeline.

    Returns:
        An ``A.Compose`` ending in Normalize + ``ToTensorV2``.
    """
    # Steps shared by both splits: resize first, normalize + tensorize last.
    resize = A.Resize(img_size, img_size)
    finalize = [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]

    if split != 'train':
        return A.Compose([resize, *finalize])

    # Random augmentations, applied between the resize and the normalization.
    # NOTE(review): RandomResizedCrop is called with positional (height, width);
    # newer albumentations releases appear to expect size=(h, w) — confirm
    # against the installed version.
    augment = [
        A.HorizontalFlip(p=0.5),
        A.RandomResizedCrop(img_size, img_size, scale=(0.7, 1.0), p=0.6),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02, p=0.5),
    ]
    return A.Compose([resize, *augment, *finalize])

# Configuration - adjust for your hardware
IMG_SIZE = 160      # Small size for laptop-friendly training
BATCH_SIZE = 16     # Adjust based on GPU memory
NUM_WORKERS = 4     # Adjust based on CPU cores
# Pinned host memory only helps host->GPU copies; requesting it without CUDA
# just triggers a DataLoader warning, so enable it only when a GPU is present.
PIN_MEMORY = torch.cuda.is_available()
# NOTE(review): on platforms that spawn worker processes (Windows/macOS), a
# Dataset class defined inside the notebook may not be picklable by workers —
# set NUM_WORKERS = 0 if the loaders hang or fail there.

print(f"🔄 Setting up data loaders:")
print(f"   Image size: {IMG_SIZE}x{IMG_SIZE}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Num workers: {NUM_WORKERS}")

# Create datasets and loaders (train is shuffled, val keeps a fixed order)
train_ds = ImageFolderAlb('data_subset/train', transform=get_transforms(IMG_SIZE, 'train'))
val_ds = ImageFolderAlb('data_subset/val', transform=get_transforms(IMG_SIZE, 'val'))

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"\n✅ Data loaders ready:")
print(f"   Train: {len(train_ds)} images, {len(train_loader)} batches")
print(f"   Val: {len(val_ds)} images, {len(val_loader)} batches")
print(f"   Classes: {len(train_ds.classes)}")
# Show at most the first five class names to keep the output short.
if len(train_ds.classes) > 5:
    print(f"   Class names: {train_ds.classes[:5]}...")
else:
    print(f"   Class names: {train_ds.classes}")

In [None]:
# Initialize model with frozen backbone (transfer learning)
import timm
import torch.nn as nn

# Model configuration
MODEL_NAME = 'tf_efficientnet_b0'  # Efficient and fast
num_classes = len(train_ds.class_to_idx)

print(f"🏗️ Initializing model:")
print(f"   Architecture: {MODEL_NAME}")
print(f"   Number of classes: {num_classes}")
print(f"   Device: {device}")

# Create pre-trained model
model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=num_classes)

# Freeze every parameter first; only the classifier head is re-enabled below.
print("\n🧊 Freezing backbone parameters...")
for param in model.parameters():
    param.requires_grad = False

# Reset the classifier head and keep a handle on the head module itself.
print("🎯 Setting up classifier head...")
try:
    model.reset_classifier(num_classes)
    head = model.get_classifier()
except Exception:
    # Deliberately broad best-effort fallback for models that do not provide
    # the timm reset_classifier/get_classifier API.
    if hasattr(model, 'classifier'):
        in_features = model.classifier.in_features
        model.classifier = nn.Linear(in_features, num_classes)
        head = model.classifier
    elif hasattr(model, 'fc'):
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, num_classes)
        head = model.fc
    else:
        raise RuntimeError("Could not find classifier layer")

# Unfreeze exactly the head module(s). Matching parameter *names* by substring
# (e.g. 'head') would also unfreeze backbone layers such as EfficientNet's
# `conv_head`, so the head is addressed through the module object instead.
head_modules = head if isinstance(head, (tuple, list)) else (head,)
for module in head_modules:
    for param in module.parameters():
        param.requires_grad = True

# Count parameters from the actual requires_grad flags so the report is exact.
total_params = sum(p.numel() for p in model.parameters())
head_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

model = model.to(device)

print(f"✅ Model ready:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {head_params:,}")
print(f"   Frozen parameters: {total_params - head_params:,}")
print(f"   Training only: {(head_params/total_params)*100:.1f}% of parameters")

# Show model summary
print(f"\n📋 Model architecture:")
print(model)

In [None]:
# Training loop with AMP and checkpointing
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from tqdm.notebook import tqdm
import numpy as np
from sklearn.metrics import accuracy_score
import time

# Hyper-parameters (kept small so a run finishes quickly)
EPOCHS = 4          # Small number for quick experiments
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer: only parameters that still require gradients are handed to AdamW.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# AMP gradient scaler: created only when CUDA is available, otherwise None.
scaler = None
if torch.cuda.is_available():
    scaler = GradScaler()

print(f"🚀 Training setup:")
print(f"   Epochs: {EPOCHS}")
print(f"   Learning rate: {LEARNING_RATE}")
print(f"   Weight decay: {WEIGHT_DECAY}")
print(f"   AMP enabled: {scaler is not None}")
print(f"   Optimizer: AdamW")
def train_one_epoch(model, loader, epoch):
    """Train ``model`` for one pass over ``loader``.

    Uses the module-level ``optimizer``, ``criterion``, ``scaler`` and
    ``device``. When ``scaler`` is not None the forward pass runs under
    autocast and the backward pass goes through the gradient scaler.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats: the mean of the per-batch
        losses and the fraction of correctly classified samples. Both are NaN
        when ``loader`` yields no batches.
    """
    # NOTE(review): train mode also applies to frozen BatchNorm layers, whose
    # running statistics keep updating — confirm this is intended for
    # head-only training.
    model.train()
    use_amp = scaler is not None

    # Running totals keep the per-batch progress update O(1); re-scoring every
    # accumulated prediction on each batch would be quadratic in epoch length.
    loss_sum = 0.0
    num_batches = 0
    correct = 0
    seen = 0

    loop = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for imgs, labels in loop:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Forward pass (mixed precision when AMP is enabled)
        with autocast(enabled=use_amp):
            outputs = model(imgs)
            loss = criterion(outputs, labels)

        # Backward pass and optimizer step
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Track metrics
        loss_sum += loss.item()
        num_batches += 1
        correct += (outputs.argmax(1) == labels).sum().item()
        seen += labels.size(0)

        # Update progress bar with the running averages
        loop.set_postfix({
            'loss': f'{loss_sum / num_batches:.4f}',
            'acc': f'{correct / seen:.4f}'
        })

    if num_batches == 0:
        return float('nan'), float('nan')
    return loss_sum / num_batches, correct / seen

@torch.no_grad()
def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Uses the module-level ``criterion`` and ``device``.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        accuracy over all samples seen.
    """
    model.eval()
    batch_losses = []
    predicted = []
    expected = []

    progress = tqdm(loader, desc="Validating", leave=False)
    for imgs, labels in progress:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(imgs)
        batch_losses.append(criterion(logits, labels).item())

        # Collect per-sample class indices on the CPU for scoring at the end.
        predicted.extend(logits.argmax(1).cpu().numpy())
        expected.extend(labels.cpu().numpy())

    mean_loss = np.mean(batch_losses)
    accuracy = accuracy_score(expected, predicted)
    return mean_loss, accuracy

# Training loop
print("\n" + "="*60)
print("🎯 STARTING TRAINING")
print("="*60)

best_acc = 0.0
training_history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
start_time = time.time()

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print("-" * 30)

    # Training
    train_loss, train_acc = train_one_epoch(model, train_loader, epoch)

    # Validation
    val_loss, val_acc = validate(model, val_loader)

    # Store plain Python floats: this keeps NumPy scalar objects out of the
    # checkpoint, so it can also be read with torch.load(weights_only=True).
    train_loss, train_acc = float(train_loss), float(train_acc)
    val_loss, val_acc = float(val_loss), float(val_acc)

    # Save metrics
    training_history['train_loss'].append(train_loss)
    training_history['train_acc'].append(train_acc)
    training_history['val_loss'].append(val_loss)
    training_history['val_acc'].append(val_acc)

    # Print results
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}")

    # Save best model. The first epoch is always saved so that a checkpoint
    # exists even if validation accuracy never rises above 0.0; the
    # evaluation cell loads 'best_head_only.pth' unconditionally.
    if epoch == 0 or val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_acc': best_acc,
            'training_history': training_history
        }, 'best_head_only.pth')
        print(f"💾 Saved best model! (Val Acc: {best_acc:.4f})")

    print()

elapsed_time = time.time() - start_time
print("="*60)
print("🏁 TRAINING COMPLETED")
print(f"⏱️ Total time: {elapsed_time:.1f}s ({elapsed_time/60:.1f}m)")
print(f"🎯 Best validation accuracy: {best_acc:.4f}")
print("="*60)

In [None]:
# Results visualization and final evaluation
import matplotlib.pyplot as plt

def plot_training_history(history):
    """Draw loss and accuracy curves side by side and show the figure.

    Args:
        history: Mapping with the keys ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'`` and ``'val_acc'``, each a per-epoch sequence.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # One row per panel:
    # (train key, val key, train label, val label, title, y-axis label)
    panels = (
        ('train_loss', 'val_loss', 'Training Loss', 'Validation Loss',
         'Model Loss', 'Loss'),
        ('train_acc', 'val_acc', 'Training Accuracy', 'Validation Accuracy',
         'Model Accuracy', 'Accuracy'),
    )

    for ax, (train_key, val_key, train_label, val_label, title, ylabel) in zip(axes, panels):
        # Training curve in blue, validation curve in red.
        ax.plot(history[train_key], 'b-', label=train_label, linewidth=2)
        ax.plot(history[val_key], 'r-', label=val_label, linewidth=2)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Plot training curves
print("📊 TRAINING RESULTS VISUALIZATION")
print("="*60)
plot_training_history(training_history)

# Load best model for final evaluation
print("\n🔍 FINAL MODEL EVALUATION")
print("="*60)
# map_location makes the checkpoint loadable regardless of the device it was
# saved from. weights_only=False is required because the checkpoint also holds
# the training-history dict (which may contain NumPy scalars); that is
# acceptable here only because the file was written by this notebook — do not
# load untrusted checkpoints this way.
checkpoint = torch.load('best_head_only.pth', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✅ Loaded best model from epoch {checkpoint['epoch']+1}")

# Final validation
final_val_loss, final_val_acc = validate(model, val_loader)
print(f"🎯 Final Validation Accuracy: {final_val_acc:.4f}")
print(f"📉 Final Validation Loss: {final_val_loss:.4f}")

# Class-wise performance analysis
@torch.no_grad()
def detailed_evaluation(model, loader, class_names):
    """Print per-class and overall accuracy of ``model`` on ``loader``.

    Uses the module-level ``device``.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        class_names: Sequence where position ``i`` names class index ``i``.

    Returns:
        ``(all_preds, all_targets)`` as 1-D NumPy arrays of class indices.
    """
    model.eval()
    all_preds, all_targets = [], []

    for imgs, labels in tqdm(loader, desc="Detailed evaluation"):
        imgs = imgs.to(device)
        preds = model(imgs).argmax(1)

        all_preds.extend(preds.cpu().numpy())
        # .cpu() is a no-op for CPU tensors and keeps this safe if a loader
        # ever yields labels that already live on the GPU.
        all_targets.extend(labels.cpu().numpy())

    # Convert to numpy arrays
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)

    # Class-wise accuracy (classes with no samples are skipped)
    print("\n📋 CLASS-WISE PERFORMANCE:")
    print("-" * 50)
    for i, class_name in enumerate(class_names):
        class_mask = (all_targets == i)
        if class_mask.sum() > 0:
            class_acc = (all_preds[class_mask] == all_targets[class_mask]).mean()
            print(f"{class_name:<30} | Acc: {class_acc:.4f} | Samples: {class_mask.sum()}")

    # Guard the empty case: the mean of an empty array is NaN plus a warning.
    overall_acc = (all_preds == all_targets).mean() if len(all_targets) else float('nan')
    print("-" * 50)
    print(f"{'OVERALL':<30} | Acc: {overall_acc:.4f} | Samples: {len(all_targets)}")

    return all_preds, all_targets

# Perform detailed evaluation.
# Class names come from the validation dataset, whose class_to_idx produced
# the labels yielded by val_loader.
pred_labels, true_labels = detailed_evaluation(model, val_loader, val_ds.classes)

print("\n" + "="*60)
print("🎉 SUBSET TRAINING COMPLETED SUCCESSFULLY!")
print("="*60)
# SAMPLES_PER_CLASS is the constant defined in the subset-creation cell.
print(f"📁 Subset size: {SAMPLES_PER_CLASS} samples per class")
print(f"🔥 Best accuracy: {best_acc:.4f}")
print(f"⏱️ Total training time: {elapsed_time:.1f}s")
print(f"💾 Best model saved as: best_head_only.pth")
print("\n✨ Ready for full dataset training or further experiments!")