# Quick subset + train pipeline (Capstone-Lazarus)

This notebook inspects `data/`, creates a small balanced subset, and trains a head-only transfer-learning model (timm EfficientNet-B0). Designed for fast experimentation on a laptop.

## Features:
- 🎯 Balanced stratified subset creation
- ⚡ Fast training with frozen backbone
- 🚀 AMP (Automatic Mixed Precision) support
- 💾 Automatic checkpointing
- 📊 Real-time metrics tracking

**Prerequisites:** Run this notebook from the repository root, where the `data/` directory exists.

In [None]:
# Install required libs (run once in notebook)
!pip install --upgrade pip
!pip install torch torchvision timm albumentations pillow scikit-learn tqdm

In [None]:
# Check CUDA availability and set device
import torch
import sys
from pathlib import Path

# Select the compute device: use CUDA when PyTorch can see a GPU, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

if device.type != 'cuda':
    print("CUDA not available, using CPU")
else:
    # Report the first visible GPU and the CUDA version PyTorch was built with.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

# Record framework and interpreter versions alongside the run output.
print(f"PyTorch version: {torch.__version__}")
print(f"Python version: {sys.version}")

In [None]:
# Inspect original data structure
from pathlib import Path

# Count the images in every class folder directly under data/ and print a summary.
DATA_DIR = Path('data')
assert DATA_DIR.exists(), "Run this notebook from repo root where `data/` exists."

# Lower-cased file suffixes that are counted as images.
IMAGE_SUFFIXES = ('.jpg','.jpeg','.png','.bmp','.tif','.tiff','.webp')

print("🔍 Inspecting original dataset structure:")
print("=" * 50)

# Maps class folder name -> number of image files found in that folder.
class_counts = {}
for class_dir in sorted(d for d in DATA_DIR.iterdir() if d.is_dir()):
    image_count = sum(1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES)
    class_counts[class_dir.name] = image_count
    print(f"{class_dir.name:<40} {image_count:>6} images")

total_images = sum(class_counts.values())
num_classes = len(class_counts)

# Fail with an actionable message instead of a ZeroDivisionError in the
# average below when data/ contains no class sub-directories.
assert num_classes > 0, "No class sub-directories found in `data/`."

print("=" * 50)
print(f"Total classes: {num_classes}")
print(f"Total images: {total_images}")
print(f"Average per class: {total_images/num_classes:.1f}")

In [None]:
# Create balanced subset using our script
# Parameters: adjust samples_per_class small for quick runs
import subprocess
import sys

SAMPLES_PER_CLASS = 50   # try 30-100 for quick experiments
VAL_RATIO = 0.2

print(f"🎯 Creating balanced subset:")
print(f"   Samples per class: {SAMPLES_PER_CLASS}")
print(f"   Validation ratio: {VAL_RATIO}")
print(f"   Expected total: ~{SAMPLES_PER_CLASS * num_classes} images")
print("-" * 50)

# Run the script with the kernel's own interpreter (a bare `python` on PATH may
# be a different environment) and pass arguments as a list so no shell quoting
# is involved.
subset_cmd = [
    sys.executable, 'scripts/create_subset.py',
    '--data-dir', 'data',
    '--out-dir', 'data_subset',
    '--samples-per-class', str(SAMPLES_PER_CLASS),
    '--val-ratio', str(VAL_RATIO),
    '--seed', '42',
    '--symlink', 'true',
]

# Output is captured and echoed after the script finishes, so it shows up in
# the notebook cell rather than only on the Jupyter server's console.
subset_run = subprocess.run(subset_cmd, capture_output=True, text=True, errors='replace')
print(subset_run.stdout, end='')
if subset_run.stderr:
    print(subset_run.stderr, file=sys.stderr, end='')

# A `!command` line ignores a non-zero exit status; stop here instead so the
# following cells never run against a missing or stale data_subset/.
if subset_run.returncode != 0:
    raise RuntimeError(
        f"scripts/create_subset.py failed with exit code {subset_run.returncode}"
    )

In [None]:
# Create data loaders with Albumentations transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import torch

class ImageFolderAlb(Dataset):
    """Image-folder dataset that applies an Albumentations-style transform.

    ``root`` must contain one sub-directory per class. Every file in a class
    directory whose (lower-cased) suffix is a known image extension becomes
    one sample labelled with that directory's name.

    Attributes:
        root: Dataset root as a ``Path``.
        classes: Class names, sorted by directory path.
        class_to_idx: Mapping from class name to integer label.
        samples: List of ``(image_path, class_name)`` pairs, ordered by class
            and then by file path.
        transform: Optional callable invoked as ``transform(image=array)`` and
            expected to return a mapping with an ``'image'`` entry.
    """

    def __init__(self, root, transform=None):
        """Index the image files below ``root``.

        Args:
            root: Directory holding one sub-directory per class.
            transform: Optional transform applied to each image in
                ``__getitem__``.
        """
        self.root = Path(root)
        self.transform = transform
        exts = {'.jpg','.jpeg','.png','.bmp','.tif','.tiff','.webp'}

        class_dirs = sorted(p for p in self.root.iterdir() if p.is_dir())
        self.classes = [d.name for d in class_dirs]
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}

        # iterdir() yields entries in arbitrary, filesystem-dependent order;
        # sorting makes the sample order (and thus index -> image) reproducible.
        self.samples = [
            (img, class_dir.name)
            for class_dir in class_dirs
            for img in sorted(class_dir.iterdir())
            if img.suffix.lower() in exts
        ]
        print(f"Found {len(self.samples)} images in {len(class_dirs)} classes")

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is read as RGB into a NumPy array; if a transform is set,
        the returned image is whatever the transform stores under ``'image'``.
        ``label`` is the integer index of the sample's class.
        """
        path, cls_name = self.samples[idx]
        # Close the underlying file as soon as the pixels have been copied out.
        with Image.open(path) as handle:
            img = np.array(handle.convert('RGB'))
        if self.transform:
            img = self.transform(image=img)['image']
        return img, self.class_to_idx[cls_name]

def _random_resized_crop(img_size, scale, p):
    """Build ``A.RandomResizedCrop`` for either Albumentations API generation."""
    try:
        # Newer releases take the output size as a single (height, width) tuple.
        return A.RandomResizedCrop(size=(img_size, img_size), scale=scale, p=p)
    except TypeError:
        # Older releases do not know `size` and take height/width separately.
        return A.RandomResizedCrop(height=img_size, width=img_size, scale=scale, p=p)


def get_transforms(img_size=224, split='train'):
    """Get Albumentations transforms for train/val.

    Args:
        img_size: Side length, in pixels, of the square output image.
        split: ``'train'`` selects the augmenting pipeline; any other value
            selects the deterministic resize + normalize pipeline.

    Returns:
        An ``A.Compose`` ending in ``ToTensorV2``. Both pipelines normalize
        with the standard ImageNet channel mean/std.
    """
    if split == 'train':
        return A.Compose([
            A.Resize(img_size, img_size),
            A.HorizontalFlip(p=0.5),
            _random_resized_crop(img_size, scale=(0.7, 1.0), p=0.6),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02, p=0.5),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
    else:  # validation
        return A.Compose([
            A.Resize(img_size, img_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

# Configuration - adjust for your hardware
IMG_SIZE = 160      # Small size for laptop-friendly training
BATCH_SIZE = 16     # Adjust based on GPU memory
NUM_WORKERS = 4     # Adjust based on CPU cores
# Pinned host memory only speeds up host-to-GPU copies, so request it only
# when a CUDA device is actually available.
PIN_MEMORY = torch.cuda.is_available()

print(f"🔄 Setting up data loaders:")
print(f"   Image size: {IMG_SIZE}x{IMG_SIZE}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Num workers: {NUM_WORKERS}")

# Create datasets and loaders
train_ds = ImageFolderAlb('data_subset/train', transform=get_transforms(IMG_SIZE, 'train'))
val_ds = ImageFolderAlb('data_subset/val', transform=get_transforms(IMG_SIZE, 'val'))

# Each split derives its label indices from its own folder listing, so the
# class lists must match or the same integer would mean different classes in
# train and val.
assert train_ds.classes == val_ds.classes, (
    "Train and val splits contain different class folders; "
    "label indices would not line up."
)

# Shuffle only the training split; validation order stays fixed.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"\n✅ Data loaders ready:")
print(f"   Train: {len(train_ds)} images, {len(train_loader)} batches")
print(f"   Val: {len(val_ds)} images, {len(val_loader)} batches")
print(f"   Classes: {len(train_ds.classes)}")
# Show at most the first five class names to keep the output short.
print(f"   Class names: {train_ds.classes[:5]}..." if len(train_ds.classes) > 5 else f"   Class names: {train_ds.classes}")

In [None]:
# Initialize model with frozen backbone (transfer learning)
import timm
import torch.nn as nn

# Model configuration
MODEL_NAME = 'tf_efficientnet_b0'  # Efficient and fast
num_classes = len(train_ds.class_to_idx)

print(f"🏗️ Initializing model:")
print(f"   Architecture: {MODEL_NAME}")
print(f"   Number of classes: {num_classes}")
print(f"   Device: {device}")

# Create pre-trained model
model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=num_classes)

# Freeze every parameter first; only the classifier is re-enabled below.
print("\n🧊 Freezing backbone parameters...")
for param in model.parameters():
    param.requires_grad = False

# Reset classifier head and ensure it's trainable
print("🎯 Setting up classifier head...")
try:
    model.reset_classifier(num_classes)
except Exception:
    # Fallback for different model architectures
    if hasattr(model, 'classifier'):
        in_features = model.classifier.in_features
        model.classifier = nn.Linear(in_features, num_classes)
    elif hasattr(model, 'fc'):
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, num_classes)
    else:
        raise RuntimeError("Could not find classifier layer")


def _classifier_module(net):
    """Return the classification-head module of ``net``.

    Prefers ``net.get_classifier()`` when it exists and returns a module,
    then falls back to the ``classifier``, ``fc`` and ``head`` attributes.

    Raises:
        RuntimeError: If no classifier module can be located.
    """
    if hasattr(net, 'get_classifier'):
        head = net.get_classifier()
        if isinstance(head, nn.Module):
            return head
    for attr in ('classifier', 'fc', 'head'):
        head = getattr(net, attr, None)
        if isinstance(head, nn.Module):
            return head
    raise RuntimeError("Could not find classifier layer")


# Unfreeze the head by module identity. Matching parameter names by substring
# ('head', 'fc', 'ln') also catches backbone layers such as `conv_head`, which
# would silently make more than the classifier trainable.
head_module = _classifier_module(model)
for param in head_module.parameters():
    param.requires_grad = True

total_params = sum(p.numel() for p in model.parameters())
head_params = sum(p.numel() for p in head_module.parameters())

# NOTE(review): frozen BatchNorm layers still update their running statistics
# while the model is in train() mode — confirm the training loop handles this.
model = model.to(device)

print(f"✅ Model ready:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {head_params:,}")
print(f"   Frozen parameters: {total_params - head_params:,}")
print(f"   Training only: {(head_params/total_params)*100:.1f}% of parameters")

# Show model summary
print(f"\n📋 Model architecture:")
print(model)