# Advanced Training Pipeline

This notebook implements:
1. Stronger and more varied augmentation, including class-specific oversampling.
2. Model-level adjustments: gradual unfreezing, EfficientNet-B0/B3, label smoothing, focal loss/class-weighted loss.
3. Training strategies: early stopping, checkpoint ensembles, and k-fold cross-validation.


In [1]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models

from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix, classification_report

In [2]:
# Configuration
DATA_DIR = "data"  # Update this path
IMG_SIZE = 224
BATCH_SIZE = 8
NUM_CLASSES = 7
LR = 1e-4
NUM_EPOCHS = 20
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
KFOLDS = 5

# Seed for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


In [3]:
class StoolDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}
        for idx, class_name in enumerate(sorted(os.listdir(root_dir))):
            class_path = os.path.join(root_dir, class_name)
            if os.path.isdir(class_path):
                self.class_to_idx[class_name] = idx
                for fname in os.listdir(class_path):
                    if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                        self.samples.append((os.path.join(class_path, fname), idx))
        
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label


In [4]:
# Stronger and more varied augmentations
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),  # random crop + resize
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.2),
    transforms.RandomApply([transforms.Lambda(lambda img: img.filter(ImageFilter.FIND_EDGES))], p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [5]:
# Label Smoothing Loss (CrossEntropy with label_smoothing)
criterion_smooth = nn.CrossEntropyLoss(label_smoothing=0.1)

# Focal Loss Implementation
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * ((1 - pt) ** self.gamma) * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [6]:
def create_model(backbone='efficientnet_b0', num_classes=NUM_CLASSES, freeze_until_layer=None):
    # Load pretrained EfficientNet
    if backbone == 'efficientnet_b0':
        model = models.efficientnet_b0(pretrained=True)
    elif backbone == 'efficientnet_b3':
        model = models.efficientnet_b3(pretrained=True)
    else:
        raise ValueError('Invalid backbone')
    
    # Replace classifier head
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(inplace=True),
        nn.Dropout(0.4),
        nn.Linear(512, num_classes)
    )
    
    # Freeze layers if specified
    if freeze_until_layer:
        for name, param in model.named_parameters():
            param.requires_grad = False
            if name.startswith(freeze_until_layer):
                break
        # Unfreeze subsequent layers
        unfreeze = False
        for name, param in model.named_parameters():
            if unfreeze:
                param.requires_grad = True
            if name.startswith(freeze_until_layer):
                unfreeze = True
    
    return model.to(DEVICE)


In [9]:
def evaluate_model(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return None, None, all_preds, all_labels


In [10]:
def train_validate(model, train_loader, val_loader, criterion, optimizer, num_epochs, fold_idx):
    best_acc = 0.0
    best_weights = None
    patience = 3
    counter = 0
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

    for epoch in range(num_epochs):
        # Training
        model.train()
        running_loss = 0.0
        running_corrects = 0
        total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            running_corrects += (preds == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / total
        epoch_acc = running_corrects / total

        # Validation
        model.eval()
        val_loss = 0.0
        val_corrects = 0
        val_total = 0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                val_corrects += (preds == labels).sum().item()
                val_total += labels.size(0)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_loss_epoch = val_loss / val_total
        val_acc_epoch = val_corrects / val_total
        lr_scheduler.step(val_acc_epoch)

        print(f"Fold {fold_idx}, Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f} - "
              f"Val Loss: {val_loss_epoch:.4f}, Val Acc: {val_acc_epoch:.4f}")

        # Early stopping
        if val_acc_epoch > best_acc:
            best_acc = val_acc_epoch
            best_weights = model.state_dict().copy()
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Load best weights
    model.load_state_dict(best_weights)

    # Final validation metrics
    _, _, preds, labels = evaluate_model(model, val_loader)
    print("\nClassification Report for Fold {}:".format(fold_idx))
    print(classification_report(labels, preds, target_names=sorted(os.listdir(DATA_DIR))))

    return model, best_acc


In [11]:
# Prepare full dataset indices for k-fold
full_dataset = StoolDataset(DATA_DIR, transform=None)
indices = list(range(len(full_dataset)))

# Calculate class weights for full dataset
all_labels_full = [label for _, label in full_dataset]
class_counts = np.bincount(all_labels_full)
class_weights = 1.0 / class_counts
weights_full = [class_weights[label] for label in all_labels_full]

kf = KFold(n_splits=KFOLDS, shuffle=True, random_state=SEED)
fold_models = []
fold_accuracies = []

for fold_idx, (train_idx, val_idx) in enumerate(kf.split(indices), 1):
    print(f"======= Fold {fold_idx} =======")
    # Subset transforms
    train_ds = torch.utils.data.Subset(StoolDataset(DATA_DIR, transform=train_transforms), train_idx)
    val_ds = torch.utils.data.Subset(StoolDataset(DATA_DIR, transform=val_transforms), val_idx)

    # Create weighted sampler for train_ds
    train_labels_fold = [train_ds.dataset.samples[i][1] for i in train_idx]
    class_sample_count_fold = np.array([train_labels_fold.count(i) for i in range(NUM_CLASSES)])
    class_weights_fold = 1.0 / class_sample_count_fold
    sample_weights_fold = np.array([class_weights_fold[label] for label in train_labels_fold])
    sample_weights_fold = torch.from_numpy(sample_weights_fold.astype(np.double))
    sampler_fold = WeightedRandomSampler(sample_weights_fold, num_samples=len(sample_weights_fold), replacement=True)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler_fold)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    # Create model and freeze initial layers
    model = create_model(backbone='efficientnet_b0')
    # Optionally freeze until a certain layer name, e.g., 'features.4'
    # model = create_model(backbone='efficientnet_b0', freeze_until_layer='features.4')

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)
    # Choose loss: label smoothing or focal
    # criterion = criterion_smooth
    criterion = FocalLoss(alpha=1, gamma=2)

    # Train and validate
    best_model, best_acc = train_validate(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, fold_idx)
    fold_models.append(best_model)
    fold_accuracies.append(best_acc)

# Summary of fold accuracies
print("\nFold Accuracies:", fold_accuracies)
print("Mean Accuracy:", np.mean(fold_accuracies))


Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /Users/sebastianapelgren/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


100%|██████████| 20.5M/20.5M [00:00<00:00, 32.9MB/s]


Fold 1, Epoch 1/20 - Train Loss: 1.4219, Train Acc: 0.2153 - Val Loss: 1.3573, Val Acc: 0.2703
Fold 1, Epoch 2/20 - Train Loss: 1.3040, Train Acc: 0.4236 - Val Loss: 1.2128, Val Acc: 0.5405
Fold 1, Epoch 3/20 - Train Loss: 1.2324, Train Acc: 0.3819 - Val Loss: 0.9646, Val Acc: 0.6216
Fold 1, Epoch 4/20 - Train Loss: 1.0800, Train Acc: 0.5069 - Val Loss: 0.7871, Val Acc: 0.7568
Fold 1, Epoch 5/20 - Train Loss: 0.9351, Train Acc: 0.5903 - Val Loss: 0.6440, Val Acc: 0.7027
Fold 1, Epoch 6/20 - Train Loss: 0.7867, Train Acc: 0.6042 - Val Loss: 0.4696, Val Acc: 0.7297
Fold 1, Epoch 7/20 - Train Loss: 0.6370, Train Acc: 0.6528 - Val Loss: 0.3705, Val Acc: 0.7568
Early stopping at epoch 7

Classification Report for Fold 1:
              precision    recall  f1-score   support

      type-1       0.33      1.00      0.50         1
      type-2       1.00      0.40      0.57         5
      type-3       0.00      0.00      0.00         2
      type-4       0.80      1.00      0.89         8
   



Fold 2, Epoch 1/20 - Train Loss: 1.3986, Train Acc: 0.2276 - Val Loss: 1.3529, Val Acc: 0.2500
Fold 2, Epoch 2/20 - Train Loss: 1.3113, Train Acc: 0.3172 - Val Loss: 1.2609, Val Acc: 0.3611
Fold 2, Epoch 3/20 - Train Loss: 1.1733, Train Acc: 0.4897 - Val Loss: 1.1020, Val Acc: 0.4444
Fold 2, Epoch 4/20 - Train Loss: 1.0684, Train Acc: 0.5103 - Val Loss: 0.8978, Val Acc: 0.4444
Fold 2, Epoch 5/20 - Train Loss: 0.8869, Train Acc: 0.6000 - Val Loss: 0.7227, Val Acc: 0.5556
Fold 2, Epoch 6/20 - Train Loss: 0.7931, Train Acc: 0.5379 - Val Loss: 0.5694, Val Acc: 0.6389
Fold 2, Epoch 7/20 - Train Loss: 0.6368, Train Acc: 0.6759 - Val Loss: 0.5024, Val Acc: 0.6944
Fold 2, Epoch 8/20 - Train Loss: 0.5630, Train Acc: 0.6759 - Val Loss: 0.4310, Val Acc: 0.6667
Fold 2, Epoch 9/20 - Train Loss: 0.5117, Train Acc: 0.6759 - Val Loss: 0.4832, Val Acc: 0.6389
Fold 2, Epoch 10/20 - Train Loss: 0.4473, Train Acc: 0.7586 - Val Loss: 0.4938, Val Acc: 0.6111
Early stopping at epoch 10

Classification Report



Fold 3, Epoch 1/20 - Train Loss: 1.4228, Train Acc: 0.1586 - Val Loss: 1.3651, Val Acc: 0.3889
Fold 3, Epoch 2/20 - Train Loss: 1.3085, Train Acc: 0.3379 - Val Loss: 1.2560, Val Acc: 0.3889
Fold 3, Epoch 3/20 - Train Loss: 1.2477, Train Acc: 0.4138 - Val Loss: 1.1360, Val Acc: 0.6389
Fold 3, Epoch 4/20 - Train Loss: 1.1176, Train Acc: 0.4966 - Val Loss: 0.9524, Val Acc: 0.6389
Fold 3, Epoch 5/20 - Train Loss: 0.9071, Train Acc: 0.5862 - Val Loss: 0.7884, Val Acc: 0.5833
Fold 3, Epoch 6/20 - Train Loss: 0.7782, Train Acc: 0.5655 - Val Loss: 0.6192, Val Acc: 0.7778
Fold 3, Epoch 7/20 - Train Loss: 0.6410, Train Acc: 0.6828 - Val Loss: 0.6300, Val Acc: 0.6389
Fold 3, Epoch 8/20 - Train Loss: 0.4789, Train Acc: 0.7310 - Val Loss: 0.5351, Val Acc: 0.6389
Fold 3, Epoch 9/20 - Train Loss: 0.4803, Train Acc: 0.7172 - Val Loss: 0.5335, Val Acc: 0.7500
Early stopping at epoch 9

Classification Report for Fold 3:
              precision    recall  f1-score   support

      type-1       0.33      



Fold 4, Epoch 1/20 - Train Loss: 1.4235, Train Acc: 0.1793 - Val Loss: 1.3611, Val Acc: 0.3056
Fold 4, Epoch 2/20 - Train Loss: 1.3388, Train Acc: 0.3931 - Val Loss: 1.2873, Val Acc: 0.5556
Fold 4, Epoch 3/20 - Train Loss: 1.2244, Train Acc: 0.5310 - Val Loss: 1.1810, Val Acc: 0.6667
Fold 4, Epoch 4/20 - Train Loss: 1.0826, Train Acc: 0.5586 - Val Loss: 1.0372, Val Acc: 0.6667
Fold 4, Epoch 5/20 - Train Loss: 0.9187, Train Acc: 0.5517 - Val Loss: 0.9071, Val Acc: 0.5278
Fold 4, Epoch 6/20 - Train Loss: 0.7750, Train Acc: 0.6138 - Val Loss: 0.7379, Val Acc: 0.6389
Early stopping at epoch 6

Classification Report for Fold 4:
              precision    recall  f1-score   support

      type-1       0.40      0.50      0.44         4
      type-2       0.20      0.50      0.29         2
      type-3       0.33      0.67      0.44         3
      type-4       1.00      0.40      0.57         5
      type-5       0.80      0.57      0.67         7
      type-6       0.75      1.00      0.86 



Fold 5, Epoch 1/20 - Train Loss: 1.3691, Train Acc: 0.2414 - Val Loss: 1.2152, Val Acc: 0.5556
Fold 5, Epoch 2/20 - Train Loss: 1.2807, Train Acc: 0.3655 - Val Loss: 1.0594, Val Acc: 0.6389
Fold 5, Epoch 3/20 - Train Loss: 1.1597, Train Acc: 0.4207 - Val Loss: 0.8640, Val Acc: 0.7222
Fold 5, Epoch 4/20 - Train Loss: 0.9830, Train Acc: 0.5793 - Val Loss: 0.7263, Val Acc: 0.7500
Fold 5, Epoch 5/20 - Train Loss: 0.8691, Train Acc: 0.5862 - Val Loss: 0.6467, Val Acc: 0.6944
Fold 5, Epoch 6/20 - Train Loss: 0.7272, Train Acc: 0.6138 - Val Loss: 0.5472, Val Acc: 0.7500
Fold 5, Epoch 7/20 - Train Loss: 0.5871, Train Acc: 0.7310 - Val Loss: 0.5147, Val Acc: 0.6667
Early stopping at epoch 7

Classification Report for Fold 5:
              precision    recall  f1-score   support

      type-1       0.25      1.00      0.40         1
      type-2       0.50      0.33      0.40         3
      type-3       0.20      1.00      0.33         1
      type-4       1.00      0.43      0.60         7
   

In [12]:
# Example: Ensemble inference on a test image
def ensemble_predict(models, image_path, transform):
    image = Image.open(image_path).convert('RGB')
    img_t = transform(image).unsqueeze(0).to(DEVICE)
    probs = []
    for model in models:
        model.eval()
        with torch.no_grad():
            outputs = model(img_t)
            probs.append(F.softmax(outputs, dim=1).cpu().numpy())
    avg_probs = np.mean(np.vstack(probs), axis=0)
    pred_class = np.argmax(avg_probs)
    return sorted(os.listdir(DATA_DIR))[pred_class], np.max(avg_probs)

# Usage example (replace 'some_image.jpg')
# label, confidence = ensemble_predict(fold_models, 'some_image.jpg', val_transforms)
# print(f"Ensembled Prediction: {label}, Confidence: {confidence:.2f}")
