# Advanced Training Pipeline

This notebook implements:
1. Stronger and more varied augmentation, including class-specific oversampling.
2. Model-level adjustments: gradual unfreezing, EfficientNet-B0/B3, label smoothing, focal loss/class-weighted loss.
3. Training strategies: early stopping, checkpoint ensembles, and k-fold cross-validation.


In [13]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models

from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix, classification_report

In [14]:
# Configuration
DATA_DIR = "data"  # Update this path
IMG_SIZE = 224
BATCH_SIZE = 8
NUM_CLASSES = 7
LR = 1e-4
NUM_EPOCHS = 20
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
KFOLDS = 5

# Seed for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


In [15]:
class StoolDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}
        for idx, class_name in enumerate(sorted(os.listdir(root_dir))):
            class_path = os.path.join(root_dir, class_name)
            if os.path.isdir(class_path):
                self.class_to_idx[class_name] = idx
                for fname in os.listdir(class_path):
                    if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                        self.samples.append((os.path.join(class_path, fname), idx))
        
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label


In [16]:
# Stronger and more varied augmentations
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),  # random crop + resize
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.2),
    transforms.RandomApply([transforms.Lambda(lambda img: img.filter(ImageFilter.FIND_EDGES))], p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [17]:
# Label Smoothing Loss (CrossEntropy with label_smoothing)
criterion_smooth = nn.CrossEntropyLoss(label_smoothing=0.1)

# Focal Loss Implementation
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * ((1 - pt) ** self.gamma) * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [18]:
def create_model(backbone='efficientnet_b0', num_classes=NUM_CLASSES, freeze_until_layer=None):
    # Load pretrained EfficientNet
    if backbone == 'efficientnet_b0':
        model = models.efficientnet_b0(pretrained=True)
    elif backbone == 'efficientnet_b3':
        model = models.efficientnet_b3(pretrained=True)
    else:
        raise ValueError('Invalid backbone')
    
    # Replace classifier head
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(inplace=True),
        nn.Dropout(0.4),
        nn.Linear(512, num_classes)
    )
    
    # Freeze layers if specified
    if freeze_until_layer:
        for name, param in model.named_parameters():
            param.requires_grad = False
            if name.startswith(freeze_until_layer):
                break
        # Unfreeze subsequent layers
        unfreeze = False
        for name, param in model.named_parameters():
            if unfreeze:
                param.requires_grad = True
            if name.startswith(freeze_until_layer):
                unfreeze = True
    
    return model.to(DEVICE)


In [19]:
def evaluate_model(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return None, None, all_preds, all_labels


In [20]:
def train_validate(model, train_loader, val_loader, criterion, optimizer, num_epochs, fold_idx):
    best_acc = 0.0
    best_weights = None
    patience = 3
    counter = 0
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

    for epoch in range(num_epochs):
        # Training
        model.train()
        running_loss = 0.0
        running_corrects = 0
        total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            running_corrects += (preds == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / total
        epoch_acc = running_corrects / total

        # Validation
        model.eval()
        val_loss = 0.0
        val_corrects = 0
        val_total = 0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                val_corrects += (preds == labels).sum().item()
                val_total += labels.size(0)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_loss_epoch = val_loss / val_total
        val_acc_epoch = val_corrects / val_total
        lr_scheduler.step(val_acc_epoch)

        print(f"Fold {fold_idx}, Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f} - "
              f"Val Loss: {val_loss_epoch:.4f}, Val Acc: {val_acc_epoch:.4f}")

        # Early stopping
        if val_acc_epoch > best_acc:
            best_acc = val_acc_epoch
            best_weights = model.state_dict().copy()
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Load best weights
    model.load_state_dict(best_weights)

    # Final validation metrics
    _, _, preds, labels = evaluate_model(model, val_loader)
    print("\nClassification Report for Fold {}:".format(fold_idx))
    print(classification_report(labels, preds, target_names=sorted(os.listdir(DATA_DIR))))

    return model, best_acc


In [21]:
# Prepare full dataset indices for k-fold
full_dataset = StoolDataset(DATA_DIR, transform=None)
indices = list(range(len(full_dataset)))

# Calculate class weights for full dataset
all_labels_full = [label for _, label in full_dataset]
class_counts = np.bincount(all_labels_full)
class_weights = 1.0 / class_counts
weights_full = [class_weights[label] for label in all_labels_full]

kf = KFold(n_splits=KFOLDS, shuffle=True, random_state=SEED)
fold_models = []
fold_accuracies = []

for fold_idx, (train_idx, val_idx) in enumerate(kf.split(indices), 1):
    print(f"======= Fold {fold_idx} =======")
    # Subset transforms
    train_ds = torch.utils.data.Subset(StoolDataset(DATA_DIR, transform=train_transforms), train_idx)
    val_ds = torch.utils.data.Subset(StoolDataset(DATA_DIR, transform=val_transforms), val_idx)

    # Create weighted sampler for train_ds
    train_labels_fold = [train_ds.dataset.samples[i][1] for i in train_idx]
    class_sample_count_fold = np.array([train_labels_fold.count(i) for i in range(NUM_CLASSES)])
    class_weights_fold = 1.0 / class_sample_count_fold
    sample_weights_fold = np.array([class_weights_fold[label] for label in train_labels_fold])
    sample_weights_fold = torch.from_numpy(sample_weights_fold.astype(np.double))
    sampler_fold = WeightedRandomSampler(sample_weights_fold, num_samples=len(sample_weights_fold), replacement=True)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler_fold)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    # Create model and freeze initial layers
    model = create_model(backbone='efficientnet_b0')
    # Optionally freeze until a certain layer name, e.g., 'features.4'
    # model = create_model(backbone='efficientnet_b0', freeze_until_layer='features.4')

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)
    # Choose loss: label smoothing or focal
    # criterion = criterion_smooth
    criterion = FocalLoss(alpha=1, gamma=2)

    # Train and validate
    best_model, best_acc = train_validate(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, fold_idx)
    fold_models.append(best_model)
    fold_accuracies.append(best_acc)

# Summary of fold accuracies
print("\nFold Accuracies:", fold_accuracies)
print("Mean Accuracy:", np.mean(fold_accuracies))






Fold 1, Epoch 1/20 - Train Loss: 1.1125, Train Acc: 0.3797 - Val Loss: 0.8411, Val Acc: 0.4624
Fold 1, Epoch 2/20 - Train Loss: 0.6697, Train Acc: 0.5781 - Val Loss: 0.5630, Val Acc: 0.5663
Fold 1, Epoch 3/20 - Train Loss: 0.5102, Train Acc: 0.6652 - Val Loss: 0.4892, Val Acc: 0.6022
Fold 1, Epoch 4/20 - Train Loss: 0.4670, Train Acc: 0.6598 - Val Loss: 0.5216, Val Acc: 0.6057
Fold 1, Epoch 5/20 - Train Loss: 0.4371, Train Acc: 0.6903 - Val Loss: 0.4070, Val Acc: 0.6416
Fold 1, Epoch 6/20 - Train Loss: 0.3503, Train Acc: 0.7540 - Val Loss: 0.4111, Val Acc: 0.6738
Fold 1, Epoch 7/20 - Train Loss: 0.3465, Train Acc: 0.7639 - Val Loss: 0.3932, Val Acc: 0.6631
Fold 1, Epoch 8/20 - Train Loss: 0.2964, Train Acc: 0.7926 - Val Loss: 0.3575, Val Acc: 0.7061
Fold 1, Epoch 9/20 - Train Loss: 0.2945, Train Acc: 0.7792 - Val Loss: 0.3599, Val Acc: 0.6989
Fold 1, Epoch 10/20 - Train Loss: 0.2711, Train Acc: 0.8043 - Val Loss: 0.3719, Val Acc: 0.6846
Fold 1, Epoch 11/20 - Train Loss: 0.2365, Train A



Fold 2, Epoch 1/20 - Train Loss: 1.1755, Train Acc: 0.3519 - Val Loss: 0.9603, Val Acc: 0.4265
Fold 2, Epoch 2/20 - Train Loss: 0.7392, Train Acc: 0.5296 - Val Loss: 0.6347, Val Acc: 0.5627
Fold 2, Epoch 3/20 - Train Loss: 0.5227, Train Acc: 0.6490 - Val Loss: 0.5715, Val Acc: 0.5950
Fold 2, Epoch 4/20 - Train Loss: 0.4916, Train Acc: 0.6553 - Val Loss: 0.4679, Val Acc: 0.6380
Fold 2, Epoch 5/20 - Train Loss: 0.4373, Train Acc: 0.6840 - Val Loss: 0.4478, Val Acc: 0.6774
Fold 2, Epoch 6/20 - Train Loss: 0.3794, Train Acc: 0.7217 - Val Loss: 0.4941, Val Acc: 0.6452
Fold 2, Epoch 7/20 - Train Loss: 0.3360, Train Acc: 0.7379 - Val Loss: 0.4000, Val Acc: 0.6774
Fold 2, Epoch 8/20 - Train Loss: 0.3140, Train Acc: 0.7720 - Val Loss: 0.3532, Val Acc: 0.6953
Fold 2, Epoch 9/20 - Train Loss: 0.2784, Train Acc: 0.7971 - Val Loss: 0.4191, Val Acc: 0.7061
Fold 2, Epoch 10/20 - Train Loss: 0.2689, Train Acc: 0.7998 - Val Loss: 0.3672, Val Acc: 0.6882
Fold 2, Epoch 11/20 - Train Loss: 0.2602, Train A



Fold 3, Epoch 1/20 - Train Loss: 1.1385, Train Acc: 0.3959 - Val Loss: 0.7297, Val Acc: 0.5591
Fold 3, Epoch 2/20 - Train Loss: 0.6558, Train Acc: 0.5871 - Val Loss: 0.5332, Val Acc: 0.6129
Fold 3, Epoch 3/20 - Train Loss: 0.5223, Train Acc: 0.6688 - Val Loss: 0.4361, Val Acc: 0.6523
Fold 3, Epoch 4/20 - Train Loss: 0.4822, Train Acc: 0.6885 - Val Loss: 0.4385, Val Acc: 0.6344
Fold 3, Epoch 5/20 - Train Loss: 0.4364, Train Acc: 0.6903 - Val Loss: 0.3596, Val Acc: 0.7133
Fold 3, Epoch 6/20 - Train Loss: 0.3715, Train Acc: 0.7433 - Val Loss: 0.3408, Val Acc: 0.6918
Fold 3, Epoch 7/20 - Train Loss: 0.3347, Train Acc: 0.7504 - Val Loss: 0.2935, Val Acc: 0.7133
Fold 3, Epoch 8/20 - Train Loss: 0.3512, Train Acc: 0.7496 - Val Loss: 0.3425, Val Acc: 0.6953
Early stopping at epoch 8

Classification Report for Fold 3:
              precision    recall  f1-score   support

      type-1       0.84      0.88      0.86        56
      type-2       0.67      0.53      0.59        45
      type-3    



Fold 4, Epoch 1/20 - Train Loss: 1.1668, Train Acc: 0.3345 - Val Loss: 0.8813, Val Acc: 0.4640
Fold 4, Epoch 2/20 - Train Loss: 0.7155, Train Acc: 0.5372 - Val Loss: 0.5356, Val Acc: 0.5935
Fold 4, Epoch 3/20 - Train Loss: 0.5578, Train Acc: 0.6377 - Val Loss: 0.4128, Val Acc: 0.6727
Fold 4, Epoch 4/20 - Train Loss: 0.4775, Train Acc: 0.6735 - Val Loss: 0.3815, Val Acc: 0.6978
Fold 4, Epoch 5/20 - Train Loss: 0.4275, Train Acc: 0.6978 - Val Loss: 0.3626, Val Acc: 0.7086
Fold 4, Epoch 6/20 - Train Loss: 0.3780, Train Acc: 0.7238 - Val Loss: 0.3412, Val Acc: 0.7302
Fold 4, Epoch 7/20 - Train Loss: 0.3421, Train Acc: 0.7489 - Val Loss: 0.3174, Val Acc: 0.7518
Fold 4, Epoch 8/20 - Train Loss: 0.3112, Train Acc: 0.7821 - Val Loss: 0.3420, Val Acc: 0.7230
Fold 4, Epoch 9/20 - Train Loss: 0.3249, Train Acc: 0.7596 - Val Loss: 0.3092, Val Acc: 0.7482
Fold 4, Epoch 10/20 - Train Loss: 0.2704, Train Acc: 0.7848 - Val Loss: 0.3069, Val Acc: 0.7698
Fold 4, Epoch 11/20 - Train Loss: 0.2768, Train A



Fold 5, Epoch 1/20 - Train Loss: 1.1859, Train Acc: 0.3480 - Val Loss: 0.8639, Val Acc: 0.4892
Fold 5, Epoch 2/20 - Train Loss: 0.6907, Train Acc: 0.5731 - Val Loss: 0.5321, Val Acc: 0.6007
Fold 5, Epoch 3/20 - Train Loss: 0.5604, Train Acc: 0.6260 - Val Loss: 0.4777, Val Acc: 0.6511
Fold 5, Epoch 4/20 - Train Loss: 0.4978, Train Acc: 0.6655 - Val Loss: 0.3676, Val Acc: 0.7122
Fold 5, Epoch 5/20 - Train Loss: 0.3973, Train Acc: 0.7103 - Val Loss: 0.3663, Val Acc: 0.6906
Fold 5, Epoch 6/20 - Train Loss: 0.3983, Train Acc: 0.7085 - Val Loss: 0.3319, Val Acc: 0.7302
Fold 5, Epoch 7/20 - Train Loss: 0.3447, Train Acc: 0.7354 - Val Loss: 0.3253, Val Acc: 0.7410
Fold 5, Epoch 8/20 - Train Loss: 0.2760, Train Acc: 0.7785 - Val Loss: 0.3416, Val Acc: 0.7302
Fold 5, Epoch 9/20 - Train Loss: 0.2650, Train Acc: 0.8000 - Val Loss: 0.3328, Val Acc: 0.7554
Fold 5, Epoch 10/20 - Train Loss: 0.2530, Train Acc: 0.7928 - Val Loss: 0.3040, Val Acc: 0.7734
Fold 5, Epoch 11/20 - Train Loss: 0.2650, Train A

In [22]:
# Example: Ensemble inference on a test image
def ensemble_predict(models, image_path, transform):
    image = Image.open(image_path).convert('RGB')
    img_t = transform(image).unsqueeze(0).to(DEVICE)
    probs = []
    for model in models:
        model.eval()
        with torch.no_grad():
            outputs = model(img_t)
            probs.append(F.softmax(outputs, dim=1).cpu().numpy())
    avg_probs = np.mean(np.vstack(probs), axis=0)
    pred_class = np.argmax(avg_probs)
    return sorted(os.listdir(DATA_DIR))[pred_class], np.max(avg_probs)

# Usage example (replace 'some_image.jpg')
# label, confidence = ensemble_predict(fold_models, 'some_image.jpg', val_transforms)
# print(f"Ensembled Prediction: {label}, Confidence: {confidence:.2f}")


In [None]:

# 1. Define where to save
SAVE_PATH = "stool_model.pth"

# 2. Save the state_dict
torch.save(model.state_dict(), SAVE_PATH)
print(f"Model weights saved to {SAVE_PATH}")

Model weights saved to stool_model.pth
