# Image Classification

A comprehensive guide to image classification using PyTorch, covering:
- Types of classification (single-label, multi-label, fine-grained, zero-shot)
- Training loop mechanics
- Loss functions and optimizers
- Full inference pipeline
- Evaluation metrics

**Dataset:** Car Corner Classification (12 classes - vehicle viewing angles)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms, datasets, models
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
from collections import Counter

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report
)

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator used here so runs are reproducible.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), and configures cuDNN for deterministic kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    import random  # local import: the notebook's import cell does not include it

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # deterministic=True alone is not enough: benchmark mode autotunes conv
    # algorithms per run and can pick different ones, breaking reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Pick the fastest available backend: CUDA first, then Apple MPS, then CPU
if torch.cuda.is_available():
    _device_name = 'cuda'
elif torch.backends.mps.is_available():
    _device_name = 'mps'
else:
    _device_name = 'cpu'
device = torch.device(_device_name)
print(f"Using device: {device}")

Using device: mps


## 2. Types of Image Classification

### Overview

| Type | Description | Output | Example |
|------|-------------|--------|---------|
| **Single-label** | One class per image | Softmax → argmax | Cat vs Dog |
| **Multi-label** | Multiple classes per image | Sigmoid → threshold | Tags: sunny, beach, people |
| **Fine-grained** | Distinguish subtle differences within a category | Specialized architectures | Bird species, car models |
| **Zero-shot** | Classify without training on specific classes | CLIP-style embeddings | Classify using text descriptions |

In [2]:
# Example: Single-label vs Multi-label output layers

class SingleLabelClassifier(nn.Module):
    """Single-label: one class per image (classes are mutually exclusive).

    Wraps an ImageNet-pretrained MobileNetV3-Small whose last classifier layer
    is replaced by a ``num_classes``-way linear layer. ``forward`` returns raw
    logits, to be trained with ``nn.CrossEntropyLoss``.

    Args:
        num_classes: Number of output classes.
    """
    def __init__(self, num_classes=12):
        super().__init__()
        self.backbone = models.mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
        # Read the head's input width instead of hard-coding it, so the
        # replacement layer always matches the backbone (as create_model does).
        in_features = self.backbone.classifier[-1].in_features
        self.backbone.classifier[-1] = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.backbone(x)  # Raw logits → use CrossEntropyLoss


class MultiLabelClassifier(nn.Module):
    """Multi-label: several classes may apply to one image (not mutually exclusive).

    Wraps an ImageNet-pretrained MobileNetV3-Small with a ``num_classes``-way
    linear head; every output unit is an independent binary decision.

    Args:
        num_classes: Number of independent labels.
        apply_sigmoid: If True (default, the original behaviour), ``forward``
            returns per-class probabilities in [0, 1] for ``nn.BCELoss``. If
            False, it returns raw logits for ``nn.BCEWithLogitsLoss``, which
            fuses the sigmoid into the loss and is numerically more stable.
    """
    def __init__(self, num_classes=12, apply_sigmoid=True):
        super().__init__()
        self.apply_sigmoid = apply_sigmoid
        self.backbone = models.mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
        # Derive the head width from the backbone rather than hard-coding it
        in_features = self.backbone.classifier[-1].in_features
        self.backbone.classifier[-1] = nn.Linear(in_features, num_classes)

    def forward(self, x):
        logits = self.backbone(x)
        if self.apply_sigmoid:
            return torch.sigmoid(logits)  # Independent probabilities → use BCELoss
        return logits  # Raw logits → use BCEWithLogitsLoss


# Demonstration: interpret the same random logits both ways
dummy_logits = torch.randn(1, 12)
softmax_probs = torch.softmax(dummy_logits, dim=1)
sigmoid_probs = torch.sigmoid(dummy_logits)
top_class = torch.argmax(dummy_logits, dim=1).item()
multi_hot = (sigmoid_probs > 0.5).squeeze().tolist()

print("Single-label output (softmax):")
print(f"  Logits shape: {dummy_logits.shape}")
print(f"  After softmax: {softmax_probs.shape} (sums to 1)")
print(f"  Prediction: class {top_class}")

print("\nMulti-label output (sigmoid):")
print(f"  After sigmoid: {sigmoid_probs.shape} (each independent 0-1)")
print(f"  Predictions (threshold=0.5): {multi_hot}")

Single-label output (softmax):
  Logits shape: torch.Size([1, 12])
  After softmax: torch.Size([1, 12]) (sums to 1)
  Prediction: class 6

Multi-label output (sigmoid):
  After sigmoid: torch.Size([1, 12]) (each independent 0-1)
  Predictions (threshold=0.5): [True, True, True, True, False, False, True, False, True, True, True, True]


## 3. Dataset & DataLoaders

Using the **Car Corner Classification** dataset with 12 viewing angle classes.

In [7]:
# Dataset paths
DATA_ROOT = Path("../data/car-corner")
TRAIN_DIR = DATA_ROOT / "train"
VAL_DIR = DATA_ROOT / "val"
TEST_DIR = DATA_ROOT / "test"

# Image transforms
IMG_SIZE = 384

# ImageNet channel statistics, matching the pretrained backbone weights
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# NOTE: deliberately no RandomHorizontalFlip. The classes encode which side of
# the car is visible (the Vietnamese class names pair 'trai' = left with
# 'phai' = right, e.g. '45_trai_truoc' vs '45_phai_truoc'). A horizontal mirror
# turns a left-side view into a right-side view while keeping the original
# label, which silently corrupts the training targets.
train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# Deterministic preprocessing for validation/test: resize + normalize only
val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# One ImageFolder per split; only the training split receives augmentation
train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=train_transforms)
val_dataset = datasets.ImageFolder(VAL_DIR, transform=val_transforms)
test_dataset = datasets.ImageFolder(TEST_DIR, transform=val_transforms)

# The label vocabulary is taken from the training split
class_names = train_dataset.classes
num_classes = len(class_names)
print(f"Number of classes: {num_classes}")
print(f"Classes: {class_names}")

# Report how many images each split holds
print("\nDataset sizes:")
for split_name, split_ds in (("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)):
    print(f"  {split_name}: {len(split_ds)}")

Number of classes: 12
Classes: ['45_phai_sau', '45_phai_truoc', '45_trai_sau', '45_trai_truoc', 'phai_sau_toan_canh', 'phai_toan_canh', 'phai_truoc_toan_canh', 'sau_toan_canh', 'trai_sau_toan_canh', 'trai_toan_canh', 'trai_truoc_toan_canh', 'truoc_toan_canh']

Dataset sizes:
  Train: 600
  Val: 600
  Test: 468


In [None]:
# Analyze class distribution
def analyze_dataset_distribution(dataset, class_names, title="Dataset Distribution"):
    """Print per-class image counts for ``dataset`` and plot them as a bar chart.

    Args:
        dataset: ImageFolder-style dataset whose ``samples`` are (path, label) pairs.
        class_names: Class names indexed by label id.
        title: Prefix for the printed header and the plot title.

    Returns:
        Counter mapping label id -> number of samples (absent classes are not keys).
    """
    class_counts = Counter(label for _, label in dataset.samples)
    n_total = len(dataset)
    n_classes = len(class_names)
    # Classes without samples are reported and plotted as 0
    counts = [class_counts.get(idx, 0) for idx in range(n_classes)]

    print(f"\n{title} - Total images: {n_total}")
    print("-" * 50)
    for name, count in zip(class_names, counts):
        print(f"{name}: {count} images ({count / n_total * 100:.1f}%)")

    positions = range(n_classes)
    plt.figure(figsize=(12, 6))
    bars = plt.bar(positions, counts, color='skyblue')
    plt.xlabel('Classes')
    plt.ylabel('Number of Images')
    plt.title(f'{title} - Class Distribution')
    plt.xticks(positions, class_names, rotation=45, ha='right')

    # Annotate each bar with its count just above the bar top
    for bar, count in zip(bars, counts):
        x_center = bar.get_x() + bar.get_width() / 2
        plt.text(x_center, bar.get_y() + count + 1, str(count),
                 ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.show()

    return class_counts

# Analyze distributions
print("CLASS DISTRIBUTION ANALYSIS")
print("=" * 60)

train_counts = analyze_dataset_distribution(train_dataset, class_names, "Training Set")
val_counts = analyze_dataset_distribution(val_dataset, class_names, "Validation Set")
test_counts = analyze_dataset_distribution(test_dataset, class_names, "Test Set")

# Summary statistics
total_train = len(train_dataset)
total_val = len(val_dataset)
total_test = len(test_dataset)

print("\nSUMMARY STATISTICS")
print("=" * 30)
print(f"Total Training Images:   {total_train}")
print(f"Total Validation Images: {total_val}")
print(f"Total Test Images:       {total_test}")
print(f"Grand Total:             {total_train + total_val + total_test}")
print(f"Train/Val/Test Split:    {total_train}:{total_val}:{total_test}")

# Check for class imbalance.
# A Counter only has keys for labels that occur, so look up every class id
# explicitly: a class with no training images must count as 0 (infinite
# imbalance) rather than silently dropping out of the min/max.
per_class_train = [train_counts.get(i, 0) for i in range(len(class_names))]
max_count = max(per_class_train, default=0)
min_count = min(per_class_train, default=0)
imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')

print("\nClass Imbalance Analysis:")
print(f"Max class count: {max_count}")
print(f"Min class count: {min_count}")
print(f"Imbalance ratio: {imbalance_ratio:.2f}")

if imbalance_ratio > 2.0:
    print("‚ö†Ô∏è  Dataset has significant class imbalance - consider using weighted loss or data augmentation")
else:
    print("‚úÖ Dataset is relatively balanced")

In [None]:
# DataLoaders
BATCH_SIZE = 32

# Pinned (page-locked) host memory only speeds up host-to-CUDA copies. On CPU
# it is wasted work, and on MPS PyTorch reports it as unsupported, so enable it
# for CUDA only.
PIN_MEMORY = device.type == 'cuda'

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=PIN_MEMORY)

# Visualize sample batch
def show_batch(dataloader, class_names, n=8):
    """Display the first ``n`` images of one batch, titled with their class names.

    Images are un-normalized with the ImageNet mean/std used by the transforms
    so colours look natural. ``n`` is capped at the batch size, and images are
    laid out four per row.

    Args:
        dataloader: Yields (images, labels) batches of normalized tensors.
        class_names: Class names indexed by label id.
        n: Maximum number of images to show.
    """
    images, labels = next(iter(dataloader))
    # A short batch (or n larger than the batch) must not index past the end
    n = min(n, images.size(0))
    if n <= 0:
        return

    # Denormalize for visualization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    images_denorm = images[:n] * std + mean

    n_cols = 4
    n_rows = -(-n // n_cols)  # ceil(n / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows), squeeze=False)
    for idx, ax in enumerate(axes.flat):
        ax.axis('off')  # also hides empty cells of a partially filled last row
        if idx < n:
            img = images_denorm[idx].permute(1, 2, 0).numpy()
            ax.imshow(np.clip(img, 0, 1))
            ax.set_title(class_names[int(labels[idx])], fontsize=10)
    plt.tight_layout()
    plt.show()

show_batch(train_loader, class_names)

## 4. Model Architecture

Using **MobileNetV3-Small** with transfer learning (pretrained on ImageNet).

In [None]:
def create_model(num_classes, pretrained=True, freeze_backbone=False):
    """
    Build a MobileNetV3-Small classifier with a ``num_classes``-way output layer.

    Args:
        num_classes: Size of the new final linear layer.
        pretrained: Start from the default ImageNet weights (otherwise random init).
        freeze_backbone: Disable gradients for ``model.features`` (feature-extraction
            mode); the classifier layers remain trainable.
    """
    weights = MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
    model = models.mobilenet_v3_small(weights=weights)

    if freeze_backbone:
        # Turns off requires_grad for every parameter of the feature extractor
        model.features.requires_grad_(False)

    # Replace only the last classifier layer, keeping its input width
    old_head = model.classifier[-1]
    model.classifier[-1] = nn.Linear(old_head.in_features, num_classes)

    return model

# Instantiate the fine-tuning model (pretrained, all layers trainable) on the device
model = create_model(num_classes=num_classes, pretrained=True, freeze_backbone=False).to(device)

# Parameter counts for the summary
all_params = list(model.parameters())
total_params = sum(p.numel() for p in all_params)
trainable_params = sum(p.numel() for p in all_params if p.requires_grad)

# Model summary
print(f"Model: MobileNetV3-Small")
print(f"Input size: {IMG_SIZE}x{IMG_SIZE}")
print(f"Output classes: {num_classes}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 5. Loss Functions

### Key Loss Functions for Classification

| Loss Function | Use Case | Formula |
|--------------|----------|---------|
| **CrossEntropyLoss** | Single-label classification | $-\sum_i y_i \log(\hat{y}_i)$ |
| **BCELoss** | Multi-label (after sigmoid) | $-\frac{1}{N}\sum[y\log(\hat{y}) + (1-y)\log(1-\hat{y})]$ |
| **BCEWithLogitsLoss** | Multi-label (raw logits) | Sigmoid + BCE fused into one numerically stable op |
| **Focal Loss** | Imbalanced datasets | $-\alpha_t(1-p_t)^\gamma \log(p_t)$, where $p_t$ is the predicted probability of the true class |
| **Weighted CrossEntropy** | Class imbalance | CrossEntropy with class weights |

In [None]:
# =====================
# Loss Function Implementations
# =====================

# Single-label: CrossEntropyLoss takes raw logits and integer class targets
criterion_ce = nn.CrossEntropyLoss()

# Multi-label: BCEWithLogitsLoss applies the sigmoid internally, on raw logits
criterion_bce = nn.BCEWithLogitsLoss(reduction='mean')

# 3. Weighted CrossEntropyLoss (for imbalanced datasets)
def compute_class_weights(dataset, num_classes=None):
    """Compute normalized inverse-frequency weights, one per class.

    Each class gets ``total / count``. The vector is then rescaled so its
    mean is 1. A class with no samples gets weight 0 instead of raising
    ZeroDivisionError, and it keeps its slot, so index i always refers to class i.

    Args:
        dataset: Object whose ``samples`` are (path, label) pairs, e.g. ImageFolder.
        num_classes: Number of classes. Defaults to ``len(dataset.classes)``
            when available, otherwise to the largest observed label + 1.

    Returns:
        1-D float tensor of length ``num_classes``.
    """
    labels = [label for _, label in dataset.samples]
    class_counts = Counter(labels)
    if num_classes is None:
        classes = getattr(dataset, 'classes', None)
        if classes is not None:
            num_classes = len(classes)
        else:
            num_classes = max(class_counts) + 1 if class_counts else 0
    total = len(labels)
    # Iterate over every class id, not just the ones present in the split
    weights = torch.tensor([
        total / class_counts[i] if class_counts[i] else 0.0
        for i in range(num_classes)
    ])
    return weights / weights.sum() * len(weights)  # Normalize to mean 1

# Inverse-frequency weights from the training split drive a weighted CE criterion
class_weights = compute_class_weights(train_dataset)
print(f"Class weights: {class_weights}")
weights_on_device = class_weights.to(device)
criterion_weighted = nn.CrossEntropyLoss(weight=weights_on_device)


# 4. Focal Loss (reduces impact of easy samples, good for imbalanced data)
class FocalLoss(nn.Module):
    """
    Focal Loss for addressing class imbalance.

    FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)

    p_t is the softmax probability assigned to each sample's target class.
    The (1 - p_t)^γ factor down-weights well-classified samples. With γ = 0
    and α = 1, the loss equals plain cross-entropy.

    Args:
        alpha: Weighting factor α_t (default: 1). Either a scalar applied to
            every sample, or a per-class list/tuple/1-D tensor of length
            num_classes, from which each sample's weight is picked by its target.
        gamma: Focusing parameter γ (default: 2).
        reduction: 'mean' or 'sum' reduce over the batch; any other value
            returns the per-sample losses.
    """
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def _alpha_t(self, inputs, targets):
        """Return α as given if scalar, else the per-class α gathered per sample."""
        alpha = self.alpha
        is_per_class = isinstance(alpha, (list, tuple)) or (
            isinstance(alpha, torch.Tensor) and alpha.dim() > 0
        )
        if is_per_class:
            alpha = torch.as_tensor(alpha, dtype=inputs.dtype, device=inputs.device)
            return alpha[targets]
        return alpha

    def forward(self, inputs, targets):
        # Per-sample cross-entropy is -log(p_t), so exp(-ce) recovers p_t
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self._alpha_t(inputs, targets) * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

criterion_focal = FocalLoss(alpha=1, gamma=2)

# Demonstration: compare both losses on random logits for 4 samples
print("\n--- Loss Function Demo ---")
dummy_logits = torch.randn(4, num_classes)
dummy_targets = torch.arange(4)

ce_value = criterion_ce(dummy_logits, dummy_targets)
focal_value = criterion_focal(dummy_logits, dummy_targets)
print(f"CrossEntropy Loss: {ce_value:.4f}")
print(f"Focal Loss (Œ≥=2): {focal_value:.4f}")

## 6. Optimizers

### Optimizer Comparison

| Optimizer | Characteristics | Best For |
|-----------|----------------|----------|
| **SGD** | Simple, strong generalization, needs LR scheduling | CNNs, final fine-tuning |
| **Adam** | Adaptive LR, fast convergence | Quick experiments, RNNs |
| **AdamW** | Adam + decoupled weight decay | Transformers, ViTs |

In [None]:
# =====================
# Optimizer Configurations
# =====================

LR = 1e-3
ADAM_BETAS = (0.9, 0.999)

# Three common choices side by side (only `optimizer` further below is used)
optimizer_sgd = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=1e-4)      # SGD + momentum
optimizer_adam = optim.Adam(model.parameters(), lr=LR, betas=ADAM_BETAS, weight_decay=1e-4)  # coupled L2 decay
optimizer_adamw = optim.AdamW(model.parameters(), lr=LR, betas=ADAM_BETAS, weight_decay=0.01)  # decoupled decay

# Learning-rate schedulers (demo only, all attached to optimizer_adam)
scheduler_step = optim.lr_scheduler.StepLR(optimizer_adam, step_size=10, gamma=0.1)        # x0.1 every 10 steps
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer_adam, T_max=30, eta_min=1e-6)  # cosine decay
scheduler_plateau = optim.lr_scheduler.ReduceLROnPlateau(optimizer_adam, mode='min', factor=0.5, patience=5)

# Configuration actually used for training: AdamW + cosine annealing.
# NOTE(review): T_max counts scheduler.step() calls (train() steps once per
# epoch), so with fewer epochs than T_max the LR covers only part of the curve.
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-6)

print(f"Optimizer: AdamW (lr={LR}, weight_decay=0.01)")
print(f"Scheduler: CosineAnnealingLR (T_max=30)")

## 7. Training Loop

The core training loop consists of:
1. **Forward pass** - Input → Model → Predictions
2. **Compute loss** - Compare predictions to targets
3. **Backpropagation** - Compute gradients
4. **Update weights** - Optimizer step

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    Per batch: zero gradients, forward pass, loss, backward pass, then an
    optimizer step. The loss is averaged per sample over the whole epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(dataloader, desc="Training")
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)            # 1. forward pass
        loss = criterion(outputs, labels)  # 2. loss
        loss.backward()                    # 3. backpropagation
        optimizer.step()                   # 4. weight update

        batch_loss = loss.item()
        loss_sum += batch_loss * images.size(0)
        n_seen += labels.size(0)
        n_correct += outputs.max(dim=1).indices.eq(labels).sum().item()

        progress.set_postfix({'loss': batch_loss, 'acc': 100. * n_correct / n_seen})

    return loss_sum / n_seen, 100. * n_correct / n_seen


@torch.no_grad()
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Returns:
        (mean loss per sample, accuracy in percent, predicted labels as a
        NumPy array, true labels as a NumPy array)
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    predictions, targets = [], []

    for images, labels in tqdm(dataloader, desc="Validating"):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        batch_loss = criterion(outputs, labels).item()
        preds = outputs.max(dim=1).indices

        loss_sum += batch_loss * images.size(0)
        n_seen += labels.size(0)
        n_correct += preds.eq(labels).sum().item()

        predictions.extend(preds.cpu().numpy())
        targets.extend(labels.cpu().numpy())

    return loss_sum / n_seen, 100. * n_correct / n_seen, np.array(predictions), np.array(targets)

In [None]:
def train(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=10, save_path="best_model.pth"):
    """
    Full training loop with validation and checkpointing.

    Each epoch trains on ``train_loader``, evaluates on ``val_loader``, steps
    the LR scheduler, and saves a checkpoint to ``save_path`` whenever the
    validation accuracy improves.

    Args:
        model, criterion, optimizer, device: Passed to ``train_one_epoch`` and ``validate``.
        train_loader, val_loader: (images, labels) loaders.
        scheduler: Stepped once per epoch, or None for a constant LR.
            ``ReduceLROnPlateau`` is stepped with the validation loss because
            its ``step`` requires the monitored metric.
        num_epochs: Number of epochs to run.
        save_path: Checkpoint file holding epoch, model/optimizer state and val_acc.

    Returns:
        history dict of per-epoch lists: 'train_loss', 'train_acc', 'val_loss',
        'val_acc', and 'lr' (the LR after that epoch's scheduler step).
    """
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'lr': []}
    # -inf (not 0.0) guarantees that the first epoch writes a checkpoint.
    # Otherwise a run that never beats 0% accuracy saves nothing, and loading
    # save_path afterwards fails.
    best_val_acc = float('-inf')

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 40)

        # Training
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)

        # Validation
        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, device)

        # Learning rate scheduler step
        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['lr'].append(current_lr)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        print(f"Learning Rate: {current_lr:.6f}")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, save_path)
            print(f"‚úì Saved best model (val_acc: {val_acc:.2f}%)")

    return history

# Single-label task: CrossEntropyLoss on the model's raw logits
criterion = nn.CrossEntropyLoss()

print("Training configuration:")
for setting in ("Criterion: CrossEntropyLoss", "Optimizer: AdamW", "Scheduler: CosineAnnealingLR"):
    print(f"  {setting}")

In [None]:
# Train the model (raise NUM_EPOCHS for better results)
NUM_EPOCHS = 5

train_config = {
    'model': model,
    'train_loader': train_loader,
    'val_loader': val_loader,
    'criterion': criterion,
    'optimizer': optimizer,
    'scheduler': scheduler,
    'device': device,
    'num_epochs': NUM_EPOCHS,
    'save_path': "car_corner_best.pth",
}
history = train(**train_config)

In [None]:
# Plot training history
def plot_history(history):
    """Plot the loss, accuracy and learning-rate curves recorded by ``train``.

    Args:
        history: dict of per-epoch lists under 'train_loss', 'val_loss',
            'train_acc', 'val_acc' and 'lr'.
    """
    def epochs_for(series):
        # train() prints epochs as 1..N, so number the x-axis the same way
        return range(1, len(series) + 1)

    fig, (ax_loss, ax_acc, ax_lr) = plt.subplots(1, 3, figsize=(15, 4))

    # Loss and accuracy share the same train-vs-val layout
    for ax, metric, ylabel, title in (
        (ax_loss, 'loss', 'Loss', 'Loss'),
        (ax_acc, 'acc', 'Accuracy (%)', 'Accuracy'),
    ):
        for split, label in (('train', 'Train'), ('val', 'Val')):
            series = history[f'{split}_{metric}']
            ax.plot(epochs_for(series), series, label=label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    # Learning Rate
    ax_lr.plot(epochs_for(history['lr']), history['lr'])
    ax_lr.set_xlabel('Epoch')
    ax_lr.set_ylabel('Learning Rate')
    ax_lr.set_title('Learning Rate Schedule')
    ax_lr.grid(True)

    plt.tight_layout()
    plt.show()

plot_history(history)

## 8. Full Inference Pipeline

Complete end-to-end pipeline:
1. **Input image** → Load from file
2. **Preprocessing** → Resize, normalize, convert to tensor
3. **Feature extraction** → Forward through backbone
4. **Classification head** → Final linear layer
5. **Prediction** → Softmax → argmax

In [None]:
class ImageClassificationPipeline:
    """
    Complete inference pipeline for image classification.

    Pipeline stages:
    1. Input image (file path or PIL Image)
    2. Preprocessing (RGB conversion, resize, normalize)
    3. Feature extraction (backbone forward pass)
    4. Classification head (final layer)
    5. Prediction (softmax + top-k)

    Args:
        model: Classifier returning one row of logits per input image.
        class_names: Class names indexed by the model's output positions.
        device: Device the model lives on; inputs are moved there.
        img_size: Square side length images are resized to (should match training).
    """

    def __init__(self, model, class_names, device, img_size=384):
        self.model = model
        self.class_names = class_names
        self.device = device
        self.img_size = img_size

        # Preprocessing transform (same as the validation transform used in training)
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.model.eval()

    @staticmethod
    def _load_rgb(image):
        """Return ``image`` as an RGB PIL image, opening it first if it is a path."""
        if isinstance(image, (str, Path)):
            # The context manager closes the file; convert() returns a loaded copy
            with Image.open(image) as img:
                return img.convert('RGB')
        # Grayscale/RGBA/palette images would break the 3-channel Normalize
        if image.mode != 'RGB':
            return image.convert('RGB')
        return image

    def preprocess(self, image):
        """Stage 2: Turn a path or PIL image into a normalized (1, 3, H, W) tensor."""
        return self.transform(self._load_rgb(image)).unsqueeze(0)  # Add batch dimension

    @torch.no_grad()
    def predict(self, image, top_k=5):
        """
        Run full inference pipeline.

        Args:
            image: File path (str or Path) or PIL Image.
            top_k: Number of ranked predictions to return (clamped to
                1..number of classes).

        Returns:
            dict with 'predicted_class', 'confidence', 'top_k' (list of
            (class_name, probability)) and 'original_image' (RGB PIL image).
        """
        # Guard against the model having been put back into train mode
        self.model.eval()

        # 1. Input image
        original_image = self._load_rgb(image)

        # 2. Preprocessing
        input_tensor = self.preprocess(original_image).to(self.device)

        # 3 & 4. Feature extraction + Classification head (forward pass)
        logits = self.model(input_tensor)

        # 5. Prediction (softmax + top-k); k >= 1 so the best class always exists
        probabilities = torch.softmax(logits, dim=1)
        k = max(1, min(top_k, len(self.class_names)))
        top_probs, top_indices = torch.topk(probabilities, k=k)

        # Get results
        predicted_class = self.class_names[top_indices[0, 0].item()]
        confidence = top_probs[0, 0].item()

        top_k_results = [
            (self.class_names[idx.item()], prob.item())
            for idx, prob in zip(top_indices[0], top_probs[0])
        ]

        return {
            'predicted_class': predicted_class,
            'confidence': confidence,
            'top_k': top_k_results,
            'original_image': original_image
        }

    def visualize_prediction(self, result):
        """Show the image next to a horizontal bar chart of its top-k predictions."""
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Show image
        axes[0].imshow(result['original_image'])
        axes[0].set_title(f"Predicted: {result['predicted_class']}\nConfidence: {result['confidence']:.2%}")
        axes[0].axis('off')

        # Top-k bar chart (reversed so the best prediction is drawn on top)
        classes = [x[0] for x in result['top_k']]
        probs = [x[1] for x in result['top_k']]
        colors = ['green' if i == 0 else 'steelblue' for i in range(len(classes))]

        axes[1].barh(classes[::-1], probs[::-1], color=colors[::-1])
        axes[1].set_xlabel('Probability')
        # The title reflects the actual k rather than a hard-coded 5
        axes[1].set_title(f'Top-{len(classes)} Predictions')
        axes[1].set_xlim(0, 1)

        for i, prob in enumerate(probs[::-1]):
            axes[1].text(prob + 0.02, i, f'{prob:.2%}', va='center')

        plt.tight_layout()
        plt.show()


# Restore the weights of the best epoch saved during training, then wrap them
checkpoint = torch.load("car_corner_best.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
loaded_val_acc = checkpoint['val_acc']
print(f"Loaded best model (val_acc: {loaded_val_acc:.2f}%)")

pipeline = ImageClassificationPipeline(model=model, class_names=class_names, device=device)

In [None]:
# Test pipeline on a few sample images.
# sorted() makes the selection reproducible (glob order depends on the
# filesystem), and the suffix filter also finds .jpeg/.png and upper-case
# extensions, not just lowercase .jpg.
SAMPLE_SUFFIXES = {'.jpg', '.jpeg', '.png'}
sample_images = sorted(
    path for path in TEST_DIR.glob("*/*") if path.suffix.lower() in SAMPLE_SUFFIXES
)[:3]

for img_path in sample_images:
    result = pipeline.predict(str(img_path))
    print(f"\nImage: {img_path.name}")
    print(f"Predicted: {result['predicted_class']} ({result['confidence']:.2%})")
    pipeline.visualize_prediction(result)

## 9. Evaluation Metrics

### Classification Metrics Summary

| Metric | Description | Formula |
|--------|-------------|---------|
| **Accuracy** | Overall correctness (fraction of correct predictions) | $\frac{\text{correct}}{\text{total}}$ (binary case: $\frac{TP + TN}{TP + TN + FP + FN}$) |
| **Precision** | Of predicted positives, how many are correct | $\frac{TP}{TP + FP}$ |
| **Recall** | Of actual positives, how many were found | $\frac{TP}{TP + FN}$ |
| **F1 Score** | Harmonic mean of precision and recall | $2 \cdot \frac{P \cdot R}{P + R}$ |

### Multi-class Averaging

| Type | Description |
|------|-------------|
| **Macro** | Average metrics across all classes (equal weight) |
| **Weighted** | Average weighted by class support |
| **Top-5 Accuracy** | Correct if true label in top 5 predictions |

In [None]:
def compute_top_k_accuracy(outputs, targets, k=5):
    """
    Compute top-k accuracy in percent.

    A sample is correct if its true label is among the ``k`` highest scores
    in its row of ``outputs``.

    Args:
        outputs: 2-D scores/logits, one row per sample, one column per class.
        targets: 1-D integer class labels, one per row of ``outputs``.
        k: Number of top predictions to consider. It is capped at the number
            of classes, so e.g. top-5 on a 3-class problem does not raise.

    Returns:
        Accuracy as a float in [0, 100].
    """
    k = min(k, outputs.size(1))
    _, top_k_preds = outputs.topk(k, dim=1)
    targets = targets.to(top_k_preds.device).view(-1, 1)
    # Broadcasting (N, 1) against (N, k) compares each target with all k guesses
    correct = top_k_preds.eq(targets)
    return correct.any(dim=1).float().mean().item() * 100


@torch.no_grad()
def evaluate_model(model, dataloader, class_names, device):
    """
    Comprehensive model evaluation with all metrics.

    Runs ``model`` in eval mode over ``dataloader`` and computes accuracy,
    top-5 accuracy, macro/weighted precision, recall and F1, the confusion
    matrix, and a per-class report.

    Args:
        model: Classifier producing one row of logits per image.
        dataloader: Yields (images, labels) batches.
        class_names: Names indexed by class id; also fixes the label set
            used for the confusion matrix and the report.
        device: Device the model runs on.

    Returns:
        dict of scalar metrics (in percent), plus the confusion matrix, the
        report string, and predictions/labels as NumPy arrays.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_outputs = []

    for images, labels in tqdm(dataloader, desc="Evaluating"):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)

        _, predicted = outputs.max(1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_outputs.append(outputs.cpu())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_outputs = torch.cat(all_outputs, dim=0)

    # Pin the label set to every class id. Otherwise, a class missing from
    # both labels and predictions shrinks the confusion matrix (its rows and
    # columns no longer line up with class_names), and classification_report
    # raises because target_names is longer than the labels it found.
    label_ids = list(range(len(class_names)))

    # Basic metrics
    accuracy = accuracy_score(all_labels, all_preds) * 100

    # Per-class and averaged metrics
    precision_macro = precision_score(all_labels, all_preds, average='macro', zero_division=0) * 100
    precision_weighted = precision_score(all_labels, all_preds, average='weighted', zero_division=0) * 100

    recall_macro = recall_score(all_labels, all_preds, average='macro', zero_division=0) * 100
    recall_weighted = recall_score(all_labels, all_preds, average='weighted', zero_division=0) * 100

    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0) * 100
    f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0) * 100

    # Top-k accuracy; k is capped at the number of classes so topk cannot fail
    top5_acc = compute_top_k_accuracy(all_outputs, torch.tensor(all_labels), k=min(5, all_outputs.size(1)))

    # Confusion matrix (rows = true class, columns = predicted class)
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    # Classification report
    report = classification_report(
        all_labels, all_preds, labels=label_ids, target_names=class_names, zero_division=0
    )

    results = {
        'accuracy': accuracy,
        'top5_accuracy': top5_acc,
        'precision_macro': precision_macro,
        'precision_weighted': precision_weighted,
        'recall_macro': recall_macro,
        'recall_weighted': recall_weighted,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'confusion_matrix': cm,
        'classification_report': report,
        'predictions': all_preds,
        'labels': all_labels
    }

    return results


def print_metrics(results):
    """Pretty-print the metrics dict produced by ``evaluate_model``."""
    banner = "=" * 50
    print(banner)
    print("EVALUATION METRICS")
    print(banner)

    print(f"\nüìä Overall Metrics:")
    print(f"   Accuracy:      {results['accuracy']:.2f}%")
    print(f"   Top-5 Accuracy: {results['top5_accuracy']:.2f}%")

    # Precision, recall and F1 share the same macro/weighted layout
    for header, key in (("\nüìà Precision:", "precision"),
                        ("\nüìâ Recall:", "recall"),
                        ("\nüéØ F1 Score:", "f1")):
        print(header)
        print(f"   Macro:    {results[key + '_macro']:.2f}%")
        print(f"   Weighted: {results[key + '_weighted']:.2f}%")

    print(f"\nüìã Per-Class Report:")
    print(results['classification_report'])

In [None]:
# Evaluate on the held-out test split and print the metrics summary
results = evaluate_model(model=model, dataloader=test_loader, class_names=class_names, device=device)
print_metrics(results)

In [None]:
def plot_confusion_matrix(cm, class_names, figsize=(12, 10)):
    """Plot a row-normalized confusion matrix as a heatmap.

    Each row is divided by its total, so cell (i, j) is the fraction of true
    class i predicted as class j. Rows with no samples (a class absent from
    the evaluated split) are drawn as zeros instead of NaN.

    Args:
        cm: Square count matrix, rows = true class, columns = predicted class.
        class_names: Tick labels for both axes.
        figsize: Matplotlib figure size.
    """
    plt.figure(figsize=figsize)

    # Normalize confusion matrix, avoiding 0/0 for empty rows
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    sns.heatmap(
        cm_normalized,
        annot=True,
        fmt='.2f',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        square=True
    )
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Normalized Confusion Matrix')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

plot_confusion_matrix(results['confusion_matrix'], class_names)

## Summary

### Key Takeaways

| Topic | Key Points |
|-------|------------|
| **Classification Types** | Single-label (softmax), Multi-label (sigmoid), Fine-grained, Zero-shot |
| **Training Loop** | Forward → Loss → Backward → Update |
| **Loss Functions** | CrossEntropy (single), BCE (multi), Focal (imbalanced) |
| **Optimizers** | SGD (generalization), Adam (fast), AdamW (Transformers) |
| **Pipeline** | Input → Preprocess → Feature Extract → Classify → Predict |
| **Metrics** | Accuracy, Precision, Recall, F1, Confusion Matrix, Top-k |

### Next Steps
- Try different backbones (ResNet, EfficientNet, ViT)
- Experiment with data augmentation (Albumentations)
- Implement learning rate finder
- Add mixed precision training (`torch.amp` autocast + GradScaler; formerly `torch.cuda.amp`)
- Export model for deployment (ONNX, TorchScript)