# Advanced Training Pipeline

This notebook implements:
1. Stronger and more varied augmentation, including class-specific oversampling.
2. Model-level adjustments: gradual unfreezing, EfficientNet-B0/B3, label smoothing, focal loss/class-weighted loss.
3. Training strategies: early stopping, checkpoint ensembles, and k-fold cross-validation.


In [25]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models

from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix, classification_report

In [26]:
# Configuration
DATA_DIR = "../data"  # Update this path
IMG_SIZE = 224
BATCH_SIZE = 16 # Adjust based on your GPU memory
NUM_CLASSES = 7
LR = 1e-4
NUM_EPOCHS = 20
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
KFOLDS = 5

# Freeze strategy: 'features.4', 'features.5', 'features.6', etc.
FREEZE_UNTIL_BLOCK = 'features.3'

# Loss function: 'focal' or 'smooth'
LOSS_TYPE = 'focal'

# Seed for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


In [27]:
class StoolDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}
        for idx, class_name in enumerate(sorted(os.listdir(root_dir))):
            class_path = os.path.join(root_dir, class_name)
            if os.path.isdir(class_path):
                self.class_to_idx[class_name] = idx
                for fname in os.listdir(class_path):
                    if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                        self.samples.append((os.path.join(class_path, fname), idx))
        
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label


In [28]:
# Stronger and more varied augmentations
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),  # random crop + resize
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.2),
    transforms.RandomApply([transforms.Lambda(lambda img: img.filter(ImageFilter.FIND_EDGES))], p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [29]:
# Label Smoothing Loss (CrossEntropy with label_smoothing)
criterion_smooth = nn.CrossEntropyLoss(label_smoothing=0.1)

# Focal Loss Implementation
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * ((1 - pt) ** self.gamma) * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [30]:
def create_model(backbone='efficientnet_b0', num_classes=NUM_CLASSES, freeze_until_layer=None):
    # Load pretrained EfficientNet
    if backbone == 'efficientnet_b0':
        model = models.efficientnet_b0(pretrained=True)
    elif backbone == 'efficientnet_b3':
        model = models.efficientnet_b3(pretrained=True)
    else:
        raise ValueError('Invalid backbone')
    
    # Replace classifier head
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(inplace=True),
        nn.Dropout(0.4),
        nn.Linear(512, num_classes)
    )
    
    # Freeze layers if specified
    if freeze_until_layer:
        for name, param in model.named_parameters():
            param.requires_grad = False
            if name.startswith(freeze_until_layer):
                break
        # Unfreeze subsequent layers
        unfreeze = False
        for name, param in model.named_parameters():
            if unfreeze:
                param.requires_grad = True
            if name.startswith(freeze_until_layer):
                unfreeze = True
    
    return model.to(DEVICE)


In [31]:
def evaluate_model(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return None, None, all_preds, all_labels


In [32]:
def train_validate(model, train_loader, val_loader, criterion, optimizer, num_epochs, fold_idx):
    best_acc = 0.0
    best_weights = None
    patience = 3
    counter = 0
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

    for epoch in range(num_epochs):
        # Training
        model.train()
        running_loss = 0.0
        running_corrects = 0
        total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            running_corrects += (preds == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / total
        epoch_acc = running_corrects / total

        # Validation
        model.eval()
        val_loss = 0.0
        val_corrects = 0
        val_total = 0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                val_corrects += (preds == labels).sum().item()
                val_total += labels.size(0)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_loss_epoch = val_loss / val_total
        val_acc_epoch = val_corrects / val_total
        lr_scheduler.step(val_acc_epoch)

        print(f"Fold {fold_idx}, Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f} - "
              f"Val Loss: {val_loss_epoch:.4f}, Val Acc: {val_acc_epoch:.4f}")

        # Early stopping
        if val_acc_epoch > best_acc:
            best_acc = val_acc_epoch
            best_weights = model.state_dict().copy()
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Load best weights
    model.load_state_dict(best_weights)

    # Final validation metrics
    _, _, preds, labels = evaluate_model(model, val_loader)
    print("\nClassification Report for Fold {}:".format(fold_idx))
    print(classification_report(labels, preds, target_names=sorted(os.listdir(DATA_DIR))))

    return model, best_acc


In [36]:
!find ../data -name ".DS_Store" -type f -delete

In [37]:
# Prepare full dataset indices for k-fold
full_dataset = StoolDataset(DATA_DIR, transform=None)
indices = list(range(len(full_dataset)))

print(f"Total samples in full dataset: {len(full_dataset)}")
print(f"They are these classes: {sorted(os.listdir(DATA_DIR))}")

# Calculate class weights for full dataset
all_labels_full = [label for _, label in full_dataset]
class_counts = np.bincount(all_labels_full)
class_weights = 1.0 / class_counts


print('class_counts:', class_counts)
print('class_weights:', class_weights)


Total samples in full dataset: 1006
They are these classes: ['type-1', 'type-2', 'type-3', 'type-4', 'type-5', 'type-6', 'type-7']
class_counts: [122 126 326 304  28  25  75]
class_weights: [0.00819672 0.00793651 0.00306748 0.00328947 0.03571429 0.04
 0.01333333]


In [None]:
# Prepare full dataset indices for k-fold
full_dataset = StoolDataset(DATA_DIR, transform=None)
indices = list(range(len(full_dataset)))

# Calculate class weights for full dataset
all_labels_full = [label for _, label in full_dataset]
class_counts = np.bincount(all_labels_full)
class_weights = 1.0 / class_counts
weights_full = [class_weights[label] for label in all_labels_full]

kf = KFold(n_splits=KFOLDS, shuffle=True, random_state=SEED)
fold_models = []
fold_accuracies = []

for fold_idx, (train_idx, val_idx) in enumerate(kf.split(indices), 1):
    print(f"======= Fold {fold_idx} =======")
    # Subset transforms
    train_ds = torch.utils.data.Subset(StoolDataset(DATA_DIR, transform=train_transforms), train_idx)
    val_ds = torch.utils.data.Subset(StoolDataset(DATA_DIR, transform=val_transforms), val_idx)

    # Create weighted sampler for train_ds
    train_labels_fold = [train_ds.dataset.samples[i][1] for i in train_idx]
    class_sample_count_fold = np.array([train_labels_fold.count(i) for i in range(NUM_CLASSES)])
    class_weights_fold = 1.0 / class_sample_count_fold
    sample_weights_fold = np.array([class_weights_fold[label] for label in train_labels_fold])
    sample_weights_fold = torch.from_numpy(sample_weights_fold.astype(np.double))
    sampler_fold = WeightedRandomSampler(sample_weights_fold, num_samples=len(sample_weights_fold), replacement=True)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler_fold)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    # Create model and freeze initial layers
    # model = create_model(backbone='efficientnet_b0')
    model = create_model(backbone='efficientnet_b3', freeze_until_layer=FREEZE_UNTIL_BLOCK)
    # Optionally freeze until a certain layer name, e.g., 'features.4'
    # model = create_model(backbone='efficientnet_b0', freeze_until_layer='features.4')

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)
    # Choose loss: label smoothing or focal
    # criterion = criterion_smooth
    criterion = FocalLoss(alpha=1, gamma=2)

    # Train and validate
    best_model, best_acc = train_validate(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, fold_idx)
    fold_models.append(best_model)
    fold_accuracies.append(best_acc)

# Summary of fold accuracies
print("\nFold Accuracies:", fold_accuracies)
print("Mean Accuracy:", np.mean(fold_accuracies))




Downloading: "https://download.pytorch.org/models/efficientnet_b3_rwightman-b3899882.pth" to /Users/sebastianapelgren/.cache/torch/hub/checkpoints/efficientnet_b3_rwightman-b3899882.pth


100%|██████████| 47.2M/47.2M [00:00<00:00, 70.4MB/s]


Fold 1, Epoch 1/20 - Train Loss: 1.3135, Train Acc: 0.3085 - Val Loss: 1.1509, Val Acc: 0.3960
Fold 1, Epoch 2/20 - Train Loss: 0.9320, Train Acc: 0.5025 - Val Loss: 0.7318, Val Acc: 0.5099
Fold 1, Epoch 3/20 - Train Loss: 0.5344, Train Acc: 0.6604 - Val Loss: 0.5696, Val Acc: 0.5248
Fold 1, Epoch 4/20 - Train Loss: 0.4653, Train Acc: 0.6704 - Val Loss: 0.4294, Val Acc: 0.6139
Fold 1, Epoch 5/20 - Train Loss: 0.4191, Train Acc: 0.7139 - Val Loss: 0.3900, Val Acc: 0.6139
Fold 1, Epoch 6/20 - Train Loss: 0.3656, Train Acc: 0.7289 - Val Loss: 0.3904, Val Acc: 0.6535
Fold 1, Epoch 7/20 - Train Loss: 0.3487, Train Acc: 0.7326 - Val Loss: 0.4454, Val Acc: 0.6040
Fold 1, Epoch 8/20 - Train Loss: 0.3155, Train Acc: 0.7550 - Val Loss: 0.3866, Val Acc: 0.6535
Fold 1, Epoch 9/20 - Train Loss: 0.2821, Train Acc: 0.7960 - Val Loss: 0.3385, Val Acc: 0.6733
Fold 1, Epoch 10/20 - Train Loss: 0.2689, Train Acc: 0.7898 - Val Loss: 0.3435, Val Acc: 0.6238
Fold 1, Epoch 11/20 - Train Loss: 0.2377, Train A



Fold 2, Epoch 1/20 - Train Loss: 1.3220, Train Acc: 0.3006 - Val Loss: 1.1902, Val Acc: 0.3433
Fold 2, Epoch 2/20 - Train Loss: 0.9311, Train Acc: 0.5280 - Val Loss: 0.7945, Val Acc: 0.4975
Fold 2, Epoch 3/20 - Train Loss: 0.6303, Train Acc: 0.6124 - Val Loss: 0.6596, Val Acc: 0.4428


In [None]:
# Summary of fold accuracies
print("\nFold Accuracies:", fold_accuracies)
print("Mean Accuracy:", np.mean(fold_accuracies))

In [None]:
# Example: Ensemble inference on a test image
def ensemble_predict(models, image_path, transform):
    image = Image.open(image_path).convert('RGB')
    img_t = transform(image).unsqueeze(0).to(DEVICE)
    probs = []
    for model in models:
        model.eval()
        with torch.no_grad():
            outputs = model(img_t)
            probs.append(F.softmax(outputs, dim=1).cpu().numpy())
    avg_probs = np.mean(np.vstack(probs), axis=0)
    pred_class = np.argmax(avg_probs)
    return sorted(os.listdir(DATA_DIR))[pred_class], np.max(avg_probs)

# Usage example (replace 'some_image.jpg')
# label, confidence = ensemble_predict(fold_models, 'some_image.jpg', val_transforms)
# print(f"Ensembled Prediction: {label}, Confidence: {confidence:.2f}")


In [None]:
# 1. Define where to save
SAVE_PATH = "../api/stool_model.pth"

# 2. Save the state_dict
#torch.save(model.state_dict(), SAVE_PATH)
print(f"Model weights saved to {SAVE_PATH}")

Model weights saved to stool_model.pth
