# 🧬 PyTorch Skin Condition Classification (ISIC 2019) - Code Explanation Notebook

This notebook provides a **detailed explanation** of the PyTorch-based skin condition classification pipeline using the **ISIC 2019 dataset**. Instead of training a model, we'll focus on understanding *what each component does*, *why architectural and optimization choices were made*, and *how the code implements best practices* for medical image classification.

## 📋 Table of Contents

1. [Architecture Overview](#1-architecture-overview)
2. [Dataset & Data Loading](#2-dataset--data-loading)
3. [Model Architecture](#3-model-architecture)
4. [Training Strategy](#4-training-strategy)
5. [Optimization Techniques](#5-optimization-techniques)
6. [Evaluation & Visualization](#6-evaluation--visualization)


---

## 1. Architecture Overview

Let's first examine the overall structure of the PyTorch implementation:

```python
# Key Components:
# 1. SkinConditionDataset - Custom PyTorch Dataset for medical images
# 2. EfficientNetB0PyTorch - Modified EfficientNet-B0 with custom head
# 3. SkinConditionTrainer - Main training orchestration class
# 4. Specialized components: FocalLoss, AddGaussianNoise, Gradual Unfreezing
```

### Why PyTorch for Medical Imaging?

- **Dynamic Computation Graphs**: Flexible for experimentation
- **Fine-grained Control**: Better for custom architectures and training loops
- **Research-Friendly**: Widely used in academic medical imaging research
- **GPU Optimization**: Excellent CUDA integration for large datasets


---

## 2. Dataset & Data Loading

### 2.1 Custom Dataset Class
```python
from torch.utils.data import Dataset
import os
from tqdm import tqdm
class SkinConditionDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        
        # Pre-filter valid files during initialization
        self.valid_indices = []
        
        print("Verifying image files...")
        for idx in tqdm(range(len(self.df)), desc="Checking files"):
            image_path = self.df.iloc[idx]['filepath']
            if os.path.exists(image_path):
                self.valid_indices.append(idx)
            else:
                print(f"Warning: File not found - {image_path}")
```

**Key Design Choices:**

1. **Pre-filtering**: Validates all files during initialization to avoid runtime errors
2. **Memory Efficiency**: Stores only valid indices, not entire dataset in memory
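3. **Error Handling**: Missing files are logged and skipped at initialization; images that fail to load later yield dummy tensors in `__getitem__` (not shown in the excerpt above)
4. **Pre-resizing**: Images are resized to 224×224 during loading (also in `__getitem__`) to reduce transform overhead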
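
The excerpt above stops before the methods that actually serve samples. The following is a minimal sketch of how the error handling and pre-resizing described in the list could look. `SafeImageDataset` and its constructor arguments are hypothetical stand-ins, not the original `SkinConditionDataset` code, and it takes plain lists instead of a DataFrame to stay self-contained.

```python
import os

import torch
from PIL import Image
from torch.utils.data import Dataset


class SafeImageDataset(Dataset):
    """Hypothetical stand-in illustrating pre-filtering, pre-resizing, and dummy tensors."""

    def __init__(self, filepaths, labels, transform=None, size=224):
        # Keep only samples whose file exists (mirrors the pre-filtering step)
        self.samples = [(p, y) for p, y in zip(filepaths, labels) if os.path.exists(p)]
        self.transform = transform
        self.size = size

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        try:
            # Pre-resize while loading so later transforms work on small images
            image = Image.open(path).convert("RGB").resize((self.size, self.size))
        except OSError:
            # Unreadable or corrupt file: return an all-zero dummy tensor
            return torch.zeros(3, self.size, self.size), label
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```

Note that without a `transform` this sketch returns a PIL image on success and a tensor on failure; in practice the transform pipeline ends in `ToTensor()`, so both paths produce tensors.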


### 2.2 Data Transforms

```python
import torch
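class AddGaussianNoise:
    """Add i.i.d. Gaussian noise to a tensor image (apply after ToTensor)."""
    def __init__(self, mean=0., std=0.02):
        self.mean = mean
        self.std = std
    
    def __call__(self, tensor):
        # Sample N(mean, std^2) noise with the same shape as the input
        return tensor + torch.randn_like(tensor) * self.std + self.mean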
```
**Augmentation Strategy:**
```python
from torchvision import transforms
# Training transforms
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomRotation(degrees=3),          # Mild rotation (±3°)
    transforms.RandomResizedCrop(224, scale=(0.92, 1.0)),  # Zoom variation
    transforms.ColorJitter(brightness=0.05, contrast=0.05, 
                         saturation=0.05, hue=0.05),  # Color variations
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics
                       std=[0.229, 0.224, 0.225]),
    AddGaussianNoise(std=0.02)                     # Sensor noise simulation
])
```
**Why These Specific Augmentations?**

- **Mild Augmentations**: Medical images require preservation of diagnostic features
- **ImageNet Normalization**: Pretrained weights expect ImageNet statistics
- **Gaussian Noise**: Models real-world sensor noise and compression artifacts
- **Color Jitter**: Accounts for varying lighting conditions and skin tones



### 2.3 DataLoader Configuration
```python
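import os
from torch.utils.data import DataLoader

# os.cpu_count() can return None, so fall back to a single worker
num_workers = min(6, os.cpu_count() or 1)
persistent_workers = True   # Requires num_workers > 0
prefetch_factor = 3         # Batches prefetched per worker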

train_loader = DataLoader(
    train_dataset, 
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=persistent_workers,
    prefetch_factor=prefetch_factor,
    drop_last=True,
    pin_memory_device=str(device) if device.type == 'cuda' else ''
)
```
**Optimization Decisions:**

1. **`persistent_workers=True`**: Keeps worker processes alive between epochs
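2. **`prefetch_factor=3`**: Each worker keeps 3 batches loaded ahead of time to reduce GPU idle time
3. **`pin_memory=True`**: Uses page-locked host memory, which speeds up host-to-GPU transfers
4. **`drop_last=True`**: Drops the final incomplete batch so batch normalization never sees an unusually small batch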
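
One practical detail: `persistent_workers` and `prefetch_factor` are only valid when `num_workers > 0`, so the configuration above raises an error on a machine where loading runs in the main process. The sketch below guards against that; `make_loader` is a hypothetical helper, not part of the original trainer.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size, num_workers, train=True):
    """Hypothetical helper: build a DataLoader with worker options set only when valid."""
    kwargs = dict(
        batch_size=batch_size,
        shuffle=train,        # shuffle only the training split
        drop_last=train,      # keep full-size batches during training
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    if num_workers > 0:
        # These two options are rejected when loading happens in the main process
        kwargs.update(persistent_workers=True, prefetch_factor=3)
    return DataLoader(dataset, **kwargs)


# Toy dataset: 10 samples, batch size 4, so drop_last leaves 2 full batches
toy = TensorDataset(torch.randn(10, 3), torch.zeros(10, dtype=torch.long))
loader = make_loader(toy, batch_size=4, num_workers=0)
batch_sizes = [x.shape[0] for x, _ in loader]
```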


--- 
## 3. Model Architecture

### 3.1 Modified EfficientNet-B0

```python
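import torch.nn as nn
from torchvision import models

class EfficientNetB0PyTorch(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        # Load EfficientNet-B0, optionally with ImageNet weights
        # (the `weights=` argument replaces the deprecated `pretrained=` flag)
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        self.base_model = models.efficientnet_b0(weights=weights)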
        
        # Initially freeze ALL layers
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        # Modify classifier head
        in_features = self.base_model.classifier[1].in_features
        self.base_model.classifier = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(256, num_classes)
        )
        
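        # Always keep classifier trainable
        for param in self.base_model.classifier.parameters():
            param.requires_grad = True

    def forward(self, x):
        # Assumed forward pass (not in the original excerpt): the backbone
        # already ends in the custom classifier head
        return self.base_model(x)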
```

**Architectural Choices:**

1. **Transfer Learning**: Uses pretrained EfficientNet-B0 (ImageNet weights)
2. **Initial Freezing**: All backbone layers frozen at start (helps prevent catastrophic forgetting)
3. **Enhanced Classifier Head**:
   - Two fully-connected layers (`in_features` → 256 → `num_classes`)
   - Batch normalization for stability
   - High dropout (0.5) for regularization
   - ReLU activation for non-linearity
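
To make the head's shapes concrete, the sketch below rebuilds the same layer stack on its own and pushes a random batch through it. It assumes 1280 input features (the width of EfficientNet-B0's final feature vector in torchvision) and a hypothetical 8 output classes; `build_head` and `count_parameters` are illustrative helpers, not part of the original code.

```python
import torch
import torch.nn as nn


def build_head(in_features=1280, num_classes=8):
    """Same layer order as the classifier head above (num_classes=8 is an assumption)."""
    return nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.BatchNorm1d(256),
        nn.Dropout(p=0.5),
        nn.Linear(256, num_classes),
    )


def count_parameters(module):
    """Return (trainable, frozen) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    return trainable, frozen


head = build_head()
head.eval()  # disable dropout and use BatchNorm running statistics
with torch.no_grad():
    logits = head(torch.randn(4, 1280))  # one row of logits per sample
trainable, frozen = count_parameters(head)
```

Running `count_parameters` on the full model before and after each unfreezing stage is a quick way to confirm that the intended layers are actually trainable.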

### 3.2 Focal Loss for Class Imbalance

```python
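import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha, gamma=2.0):
        super().__init__()
        # alpha: 1-D tensor of per-class weights; a buffer follows .to(device)
        self.register_buffer('alpha', alpha)
        self.gamma = gamma

    def forward(self, logits, targets):
        # Unweighted CE, so pt = exp(-ce) is the true-class probability
        ce = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce)
        # Focal modulation, then per-class weighting
        loss = self.alpha[targets] * (1 - pt) ** self.gamma * ce
        return loss.mean()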
```

**Why Focal Loss?**

- **Class Imbalance**: The ISIC dataset has highly imbalanced classes (NV ≫ DF, VASC)
- **Hard Example Focus**: Down-weights easy examples, focuses on hard misclassifications
- **Adaptive Weighting**: γ=2.0 provides moderate focusing and is the most commonly used default
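
The effect of γ is easiest to see numerically: the modulating factor is `(1 - pt) ** gamma`, so with γ=2 a confidently correct sample (`pt = 0.9`) is scaled by 0.01 while a hard one (`pt = 0.1`) keeps 0.81 of its loss. The sketch below also shows one common way to derive the per-class `alpha` weights from label counts; inverse-frequency weighting is an assumption here, since the source does not say how `alpha` is computed.

```python
import torch


def inverse_frequency_weights(labels, num_classes):
    """Assumed weighting scheme: weight_c = N / (num_classes * count_c)."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    return counts.sum() / (num_classes * counts.clamp(min=1))


def focal_factor(pt, gamma=2.0):
    """Modulating factor applied to the cross-entropy of each sample."""
    return (1.0 - pt) ** gamma


labels = torch.tensor([0, 0, 0, 1])           # toy labels: class 0 is 3x more frequent
alpha = inverse_frequency_weights(labels, num_classes=2)
easy = focal_factor(torch.tensor(0.9))        # well-classified sample
hard = focal_factor(torch.tensor(0.1))        # badly misclassified sample
```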


---
## 4. Training Strategy

### 4.1 Gradual Unfreezing Schedule

```python
class SkinConditionTrainer:
    def __init__(self, config):
        # Gradual unfreezing configuration
        self.unfreezing_schedule = [
            (0, 7, 1.0),    # Start: freeze most blocks (7/8)
            (20, 5, 0.5),   # Epoch 20: unfreeze some layers
            (40, 3, 0.2),   # Epoch 40: unfreeze more
            (60, 1, 0.1),   # Epoch 60: almost all unfrozen
            (80, 0, 0.05),  # Epoch 80: all layers trainable
        ]
```

**Gradual Unfreezing Logic:**

Each tuple appears to encode `(start_epoch, frozen_blocks, lr_multiplier)`, which gives five stages:

1. **Stage 1 (Epochs 0-19)**: 7 backbone blocks frozen, so training is dominated by the classifier head; LR multiplier 1.0
2. **Stage 2 (Epochs 20-39)**: 5 blocks frozen; LR multiplier 0.5
3. **Stage 3 (Epochs 40-59)**: 3 blocks frozen; LR multiplier 0.2
4. **Stage 4 (Epochs 60-79)**: 1 block frozen; LR multiplier 0.1
5. **Stage 5 (Epochs 80+)**: All layers trainable; LR multiplier 0.05

**Why Gradual Unfreezing?**

- **Stable Training**: Prevents disruption of pretrained features
- **Progressive Specialization**: Allows network to adapt slowly to medical domain
- **Learning Rate Annealing**: The LR multiplier shrinks as more of the backbone is unfrozen, so pretrained layers receive smaller updates
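
A schedule like this can be applied with two small helpers, sketched below. This assumes each tuple means `(start_epoch, frozen_blocks, lr_multiplier)` and that the frozen blocks are the earliest ones in the backbone; `stage_for_epoch` and `apply_freeze` are hypothetical names, and the toy `blocks` list stands in for the real EfficientNet feature blocks.

```python
import torch.nn as nn

SCHEDULE = [(0, 7, 1.0), (20, 5, 0.5), (40, 3, 0.2), (60, 1, 0.1), (80, 0, 0.05)]


def stage_for_epoch(schedule, epoch):
    """Return (frozen_blocks, lr_multiplier) of the latest stage that has started."""
    frozen, multiplier = schedule[0][1], schedule[0][2]
    for start_epoch, n_frozen, lr_mult in schedule:
        if epoch >= start_epoch:
            frozen, multiplier = n_frozen, lr_mult
    return frozen, multiplier


def apply_freeze(blocks, n_frozen):
    """Freeze the first n_frozen blocks (assumed earliest layers), unfreeze the rest."""
    for index, block in enumerate(blocks):
        for param in block.parameters():
            param.requires_grad = index >= n_frozen


# Toy stand-in for the backbone's sequence of blocks
blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(8)])
frozen, multiplier = stage_for_epoch(SCHEDULE, epoch=25)
apply_freeze(blocks, frozen)
```

At the start of each epoch the trainer would call both helpers and, when the stage changes, scale the optimizer's learning rate by the returned multiplier.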

### 4.2 Optimizer Configuration

```python
import torch.optim as optim

optimizer = optim.AdamW(
    model.parameters(),
    lr=config['lr'],           # 5e-5
    weight_decay=config['weight_decay']  # 1e-5
)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='max',                # Monitor validation accuracy
    factor=0.5,                # Halve LR when plateau
    patience=5,                # Wait 5 epochs
    min_lr=1e-6                # Minimum learning rate
)
```

**Optimizer Selection Rationale:**

- **AdamW**: Better generalization than Adam (decoupled weight decay)
- **Learning Rate**: Small initial LR (5e-5) for fine-tuning
- **Weight Decay**: Light regularization (1e-5)
- **ReduceLROnPlateau**: Adaptive learning rate based on validation performance
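
The plateau behavior can be checked on a toy model. In the sketch below the validation accuracy is simulated and never improves, so with `patience=5` the learning rate is halved once the metric has stalled for more than five epochs. Passing only trainable parameters to the optimizer is an assumption layered on top of the original configuration, which passes `model.parameters()`.

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
model.weight.requires_grad = False  # pretend part of the model is frozen

# Only hand trainable parameters to the optimizer (assumed variation)
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable, lr=5e-5, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-6
)

lrs = []
for epoch in range(7):
    val_acc = 0.80                  # simulated metric that never improves
    scheduler.step(val_acc)         # the scheduler needs the monitored value
    lrs.append(optimizer.param_groups[0]['lr'])
```

If parameters are filtered this way, layers that are unfrozen later must be registered with the optimizer (for example via `optimizer.add_param_group`) or the optimizer must be rebuilt, otherwise they will never be updated.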


---
## 5. Optimization Techniques

### 5.1 Mixed Precision Training

```python
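from torch.cuda.amp import GradScaler, autocast  # Newer releases also expose these via torch.amp

self.scaler = GradScaler()  # Scales the loss to avoid FP16 gradient underflow

# In training loop:
optimizer.zero_grad(set_to_none=True)  # Clear gradients from the previous step
with autocast():  # Run the forward pass in mixed precision
    outputs = model(inputs)
    loss = criterion(outputs, targets)

self.scaler.scale(loss).backward()  # Backpropagate the scaled loss
self.scaler.step(optimizer)         # Unscales gradients; skips the step on inf/NaN
self.scaler.update()                # Adjust the scale factor for the next step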
```

**Benefits of Mixed Precision:**

- **Memory Efficiency**: FP16 activations take half the memory of FP32
- **Speed**: Faster computation on modern GPUs (Tensor Cores)
- **Comparable Accuracy**: Loss scaling keeps small FP16 gradients from underflowing, so training stays stable
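
For reference, here is a self-contained version of one mixed-precision training step. It is a sketch on a toy linear model with random data, not the original trainer; the `use_amp` flag is an assumption that lets the same code fall back to plain FP32 when no GPU is available.

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = device.type == 'cuda'  # fall back to plain FP32 on CPU

model = nn.Linear(16, 3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # no-op when disabled

inputs = torch.randn(8, 16, device=device)
targets = torch.randint(0, 3, (8,), device=device)
before = model.weight.detach().clone()

optimizer.zero_grad(set_to_none=True)
with torch.autocast(device_type=device.type, enabled=use_amp):
    loss = criterion(model(inputs), targets)   # forward pass in mixed precision
scaler.scale(loss).backward()                  # backward on the scaled loss
scaler.step(optimizer)                         # unscale, then optimizer step
scaler.update()                                # adapt the scale factor
```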

### 5.2 GPU Optimization Settings

```python
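import torch

# Enable PyTorch optimizations
torch.backends.cudnn.benchmark = True  # Auto-tune conv algorithms (best with fixed input sizes)
torch.backends.cudnn.enabled = True    # Already the default; stated explicitly
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch.set_float32_matmul_precision('high')  # Allow TF32 matmuls on Ampere+ GPUs

# DataLoader optimizations (keyword arguments to DataLoader, see Section 2.3)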
pin_memory=True
persistent_workers=True
prefetch_factor=3
```

**Performance Optimizations:**

1. **cuDNN Benchmarking**: Auto-tunes for optimal convolution algorithms
2. **Tensor Core Utilization**: The `'high'` setting lets FP32 matrix multiplications use TF32 on Ampere and newer GPUs, trading a little precision for speed
3. **Data Loading Pipeline**: Overlapping data loading with computation


---
## 6. Evaluation & Visualization

### 6.1 Comprehensive Metrics Tracking

```python
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'frozen_layers': [], 'lr_multiplier': []  # Tracks training dynamics
}
```

**Monitoring Strategy:**

- **Training Metrics**: Loss and accuracy per epoch
- **Validation Metrics**: Early stopping based on validation accuracy
- **Training Dynamics**: Frozen layers count and LR multiplier over time
- **GPU Utilization**: Memory usage monitoring
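
The source mentions early stopping on validation accuracy without showing it. A minimal sketch of that idea follows; the `EarlyStopping` class, its `patience` value, and the simulated accuracies are all hypothetical.

```python
class EarlyStopping:
    """Hypothetical helper: stop when validation accuracy stops improving."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('-inf')
        self.bad_epochs = 0

    def step(self, val_acc):
        """Record one epoch; return True when training should stop."""
        if val_acc > self.best + self.min_delta:
            self.best = val_acc
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


history = {'val_acc': []}
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_acc in enumerate([0.60, 0.70, 0.69, 0.70, 0.68, 0.75]):
    history['val_acc'].append(val_acc)   # same bookkeeping as the history dict
    if stopper.step(val_acc):
        stopped_at = epoch
        break
```

With gradual unfreezing, the patience should be chosen with the stage length in mind, since accuracy often dips briefly right after new layers are unfrozen.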

### 6.2 Visualization Functions

```python
def plot_results(self, history, y_true, y_pred, class_names):
    # 1. Training curves (accuracy & loss)
    # 2. Unfreezing progression
    # 3. Learning rate schedule
    # 4. Confusion matrix
    # 5. Classification report
    ...  # Body omitted here; the steps above outline what it draws
```

**Visualization Components:**

1. **Training Curves**: Monitor overfitting and convergence
2. **Unfreezing Timeline**: Visualize gradual unfreezing strategy
3. **Learning Rate Schedule**: Shows how the LR changes across stages and plateaus
4. **Confusion Matrix**: Per-class performance analysis
5. **Classification Report**: Precision, recall, F1-scores
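
The numbers behind the confusion matrix and classification report can be computed in a few lines. The sketch below uses plain NumPy on toy labels rather than the plotting code, which the source does not show; `confusion_matrix` and `per_class_report` are illustrative helpers.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes."""
    matrix = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        matrix[t, p] += 1
    return matrix


def per_class_report(matrix):
    """Return per-class precision, recall, and F1 from a confusion matrix."""
    tp = np.diag(matrix).astype(float)
    precision = tp / np.maximum(matrix.sum(axis=0), 1)  # column sums = predicted counts
    recall = tp / np.maximum(matrix.sum(axis=1), 1)     # row sums = true counts
    f1 = 2 * precision * recall / np.maximum(precision + recall, 1e-12)
    return precision, recall, f1


y_true = [0, 0, 1, 1, 2, 2]   # toy ground truth
y_pred = [0, 1, 1, 1, 2, 0]   # toy predictions
cm = confusion_matrix(y_true, y_pred, num_classes=3)
precision, recall, f1 = per_class_report(cm)
```

On an imbalanced dataset, per-class recall is usually more informative than overall accuracy, because a model can score well by favoring the majority class.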

### 6.3 Model Saving Strategy

```python
import joblib
import torch

# Save checkpoints
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'val_acc': val_acc,
    'frozen_layers': current_freeze,
    'config': config
}, 'best_model_gradual.pth')

# Save training history
joblib.dump(history, 'pytorch_history_gradual.joblib')
```

**Checkpoint Contents:**

1. **Model Weights**: For inference or resuming training
2. **Optimizer State**: Maintains momentum and adaptive learning rates
3. **Training Metadata**: Epoch, validation accuracy, frozen layer count
4. **Configuration**: Reproducibility of training parameters
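
Restoring from such a checkpoint mirrors the save call. The sketch below does a full round trip on a toy model, assuming the same dictionary keys as above; the temporary path and the concrete values are placeholders.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Save a checkpoint with the same keys as in the original snippet
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')
torch.save({
    'epoch': 12,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'val_acc': 0.5,
    'config': {'lr': 5e-5},
}, path)

# Restore into freshly constructed objects
checkpoint = torch.load(path, map_location='cpu')
restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint['model_state_dict'])
restored_optimizer = torch.optim.AdamW(restored.parameters(), lr=checkpoint['config']['lr'])
restored_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1   # resume with the next epoch
```

When resuming a gradually unfrozen model, the saved `frozen_layers` value should be re-applied before training continues so that `requires_grad` flags match the checkpointed stage.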


---
## 🎯 Key Takeaways

### Architectural Decisions:
- **EfficientNet-B0**: Strong accuracy/efficiency trade-off for a compact backbone
- **Custom Head**: Enhanced capacity for medical domain adaptation
- **Gradual Unfreezing**: Stable transfer learning strategy

### Optimization Choices:
- **Mixed Precision**: Memory and speed benefits
- **Focal Loss**: Handles class imbalance effectively
- **AdamW Optimizer**: Better generalization than standard Adam

### Medical Imaging Considerations:
- **Mild Augmentations**: Preserve diagnostic features
- **Class Weights**: Address dataset imbalance
- **Careful Normalization**: Use ImageNet statistics for pretrained models

### PyTorch-Specific Best Practices:
- **Custom Dataset Class**: Efficient data loading with pre-filtering
- **DataLoader Optimizations**: `pin_memory`, `persistent_workers`, `prefetch_factor`
- **Modular Design**: Separated trainer, model, dataset for maintainability

This PyTorch implementation demonstrates a **production-ready, research-informed approach** to medical image classification, balancing performance, accuracy, and practical considerations for real-world deployment.