# Imports

In [2]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
from torchvision import transforms, models

from sklearn.model_selection import train_test_split

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix, classification_report
from tabulate import tabulate

import copy
from torch.optim.lr_scheduler import OneCycleLR

In [3]:
SEED = 4


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's global random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Also seed the CUDA generator explicitly when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


_seed_everything(SEED)

In [None]:
def format_confusion_matrix(cm):
    """Render a confusion matrix as a ``tabulate`` grid table.

    Rows are labelled ``True i`` and columns ``Pred j``. Cells on the diagonal
    (correct predictions) are wrapped in asterisks so they stand out.

    Args:
        cm: 2-D matrix of counts indexed as ``cm[true][pred]``, e.g. a nested
            list or the array returned by ``sklearn.metrics.confusion_matrix``.

    Returns:
        The formatted table as a string.
    """
    n_pred_cols = len(cm[0])
    header_row = [""] + ["Pred {}".format(col) for col in range(n_pred_cols)]
    body = [
        ["True {}".format(row_idx)]
        + [f"*{cell}*" if row_idx == col_idx else str(cell) for col_idx, cell in enumerate(row_vals)]
        for row_idx, row_vals in enumerate(cm)
    ]
    return tabulate(body, headers=header_row, tablefmt="grid")

In [None]:
all_logs = []  # every message emitted by log_and_store, in call order


def log_and_store(*msgs, table_format=False, is_confmat=False):
    """Print a message and append the printed text to ``all_logs``.

    Args:
        *msgs: Objects to log. By default each one is converted with ``str``
            and the results are joined with single spaces.
        table_format: When True and exactly one list, tuple or NumPy array is
            passed, render it as a grid table with ``tabulate``.
        is_confmat: When True and exactly one list, tuple or NumPy array is
            passed, render it with ``format_confusion_matrix``. Checked before
            ``table_format``.
    """
    # Accept NumPy arrays as well as lists: sklearn's confusion_matrix returns
    # an ndarray, which would otherwise be logged as its raw repr.
    single_table = len(msgs) == 1 and isinstance(msgs[0], (list, tuple, np.ndarray))
    if is_confmat and single_table:
        msg = format_confusion_matrix(msgs[0])
    elif table_format and single_table:
        msg = tabulate(msgs[0], tablefmt="grid")
    else:
        msg = " ".join(str(m) for m in msgs)

    print(msg)
    all_logs.append(msg)


def get_logs():
    """Return the list of all logged messages.

    This is the live module-level list, not a copy; mutating it changes the log.
    """
    return all_logs


def clear_logs():
    """Remove all logged messages (in place)."""
    all_logs.clear()


def save_logs_to_file(filename):
    """Write every logged message to *filename*, one message per line.

    The file is overwritten and encoded as UTF-8.
    """
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(log + '\n' for log in all_logs)

# Advanced Training Pipeline

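This notebook implements:
1. Stronger and more varied augmentation, plus class-balanced oversampling of the training set via a `WeightedRandomSampler`.
2. Model-level adjustments: freezing the backbone up to a chosen layer (gradual unfreezing is not implemented yet), EfficientNet-B0/B3 and MobileNet backbones, label smoothing, and focal loss (loss-level class weighting is not used in the runs below).
3. Training strategies: restoring the checkpoint with the lowest validation loss, and stratified k-fold cross-validation. Early stopping is not enabled in the runs below, and checkpoint ensembles are not implemented yet.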


In [None]:
# Configuration
DATA_DIR = "../data-pools/data-consistence"  # root folder with one sub-folder per class -- update this path
IMG_SIZE = 224  # side length (pixels) of the square images fed to the network
NUM_CLASSES = 3

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"


In [None]:
class StoolDataset(Dataset):
    """Image-folder dataset with one sub-directory of *root_dir* per class.

    Class indices are assigned to the sub-directories in sorted order, counting
    only directories, so stray files (e.g. ``.DS_Store``) can neither shift the
    labels nor leave gaps in the label range. Files inside each class directory
    are visited in sorted order so ``samples`` -- and therefore any seeded,
    index-based split built on it -- is the same on every file system.

    Attributes:
        classes: Class (sub-directory) names in label order.
        class_to_idx: Mapping from class name to integer label.
        samples: List of ``(image_path, label)`` tuples.
    """

    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}
        self.samples = []
        for class_name in self.classes:
            class_path = os.path.join(root_dir, class_name)
            label = self.class_to_idx[class_name]
            for fname in sorted(os.listdir(class_path)):
                # Case-insensitive extension check, so ".JPG" is accepted too.
                if fname.lower().endswith(self.IMG_EXTENSIONS):
                    self.samples.append((os.path.join(class_path, fname), label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``; *image* is transformed when a transform is set."""
        img_path, label = self.samples[idx]
        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded copy, so it stays valid after the close.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label


In [None]:
# Stronger and more varied augmentations
# Standard ImageNet per-channel mean/std, used to normalise inputs for the
# pretrained backbones.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _find_edges(img):
    """Apply PIL's FIND_EDGES filter to *img* (a named function, unlike a lambda, can be pickled)."""
    return img.filter(ImageFilter.FIND_EDGES)


# Training: random crop/resize, flips, small rotations, colour jitter and
# occasional blur or edge-map augmentation, then tensor conversion + normalisation.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.2),
    transforms.RandomApply([transforms.Lambda(_find_edges)], p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Validation: deterministic resize to IMG_SIZE x IMG_SIZE, no augmentation.
val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

In [None]:
# Label-smoothed cross-entropy: each one-hot target keeps 0.9 of its mass and
# the remaining 0.1 is spread uniformly over all classes.
criterion_smooth = nn.CrossEntropyLoss(reduction="mean", label_smoothing=0.1)

# Focal Loss Implementation
class FocalLoss(nn.Module):
    """Focal loss for multi-class logits.

    Per sample the loss is ``alpha_t * (1 - p_t) ** gamma * CE``, where ``CE`` is
    the unreduced cross-entropy and ``p_t = exp(-CE)`` is the predicted
    probability of the true class. With ``alpha=1`` and ``gamma=0`` it reduces
    to plain cross-entropy.

    Args:
        alpha: Either a scalar weight applied to every sample, or a per-class
            weight sequence/tensor (length = number of classes) that is indexed
            by each sample's target label, giving a class-weighted focal loss.
        gamma: Focusing exponent; larger values down-weight well-classified
            samples.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            per-sample losses unreduced.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        if not isinstance(alpha, (int, float)):
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
            if alpha.ndim == 0:
                alpha = alpha.item()  # a 0-d tensor behaves like a plain scalar
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        alpha = self.alpha
        if isinstance(alpha, torch.Tensor):
            # Per-class weights: pick each sample's weight by its target label,
            # on the same device/dtype as the loss.
            alpha = alpha.to(device=ce_loss.device, dtype=ce_loss.dtype)[targets]
        focal_loss = alpha * ((1 - pt) ** self.gamma) * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

    def __name__(self):
        # Kept for backward compatibility. Because it is a method, callers must
        # call ``loss.__name__()`` to get the string.
        return "FocalLoss"

In [None]:
def _freeze_features_until(features, layer_name):
    """Freeze every parameter of *features* up to and including sub-module *layer_name*.

    *layer_name* is a dotted path relative to ``model.features`` (e.g. ``"4"``
    or ``"4.0"``); a leading ``"features."`` prefix is accepted and stripped.
    Matching is on whole path components, so ``"1"`` does not match ``"0.1"``.

    Raises:
        ValueError: if no parameter lives under *layer_name*.
    """
    prefix = layer_name[len("features."):] if layer_name.startswith("features.") else layer_name
    named = list(features.named_parameters())
    hits = [
        i for i, (name, _) in enumerate(named)
        if name == prefix or name.startswith(prefix + ".")
    ]
    if not hits:
        raise ValueError(
            f"freeze_until_layer={layer_name!r} does not match any layer in model.features"
        )
    # Parameters are registered in module order, so everything up to the last
    # parameter of the matched layer belongs to it or to an earlier layer.
    for _, param in named[: hits[-1] + 1]:
        param.requires_grad = False


def create_model(backbone, num_classes=NUM_CLASSES, freeze_until_layer=None):
    """Build an ImageNet-pretrained backbone with a fresh ``num_classes`` head.

    Args:
        backbone: One of ``'mobilenet_v3_small'``, ``'mobilenet_v3_large'``,
            ``'mobilenet_v2'``, ``'efficientnet_b0'``, ``'efficientnet_b3'``.
        num_classes: Output size of the replacement final ``nn.Linear`` layer.
        freeze_until_layer: Optional layer path inside ``model.features``
            (e.g. ``'4'`` or ``'features.4'``). All feature parameters up to and
            including that layer get ``requires_grad = False``.

    Returns:
        The model, moved to ``DEVICE``.

    Raises:
        ValueError: for an unknown backbone or an unmatched freeze_until_layer.
    """
    builders = {
        'mobilenet_v3_small': models.mobilenet_v3_small,
        'mobilenet_v3_large': models.mobilenet_v3_large,
        'mobilenet_v2': models.mobilenet_v2,
        'efficientnet_b0': models.efficientnet_b0,
        'efficientnet_b3': models.efficientnet_b3,
    }
    if backbone not in builders:
        raise ValueError('Invalid backbone')

    # "IMAGENET1K_V1" is the checkpoint the deprecated `pretrained=True` flag
    # resolves to for each of these models, so the loaded weights are unchanged.
    model = builders[backbone](weights="IMAGENET1K_V1")

    if freeze_until_layer:
        _freeze_features_until(model.features, freeze_until_layer)

    # For every supported backbone the classifier is an nn.Sequential whose last
    # module is the Linear output layer; replace it with a num_classes head.
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(in_features, num_classes)

    return model.to(DEVICE)

In [None]:
def evaluate_model(model, loader):
    """Run *model* over *loader* in eval mode and collect its predictions.

    Returns:
        A 4-tuple ``(None, None, preds, labels)``. The first two slots are
        always ``None`` placeholders; ``preds`` (arg-max class per sample) and
        ``labels`` are flat lists in loader order, built from the CPU NumPy
        arrays of each batch.
    """
    model.eval()
    predictions = []
    targets = []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)
            _, batch_pred = torch.max(model(batch_x), 1)
            predictions.extend(batch_pred.cpu().numpy())
            targets.extend(batch_y.cpu().numpy())
    return None, None, predictions, targets


In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, scheduler):
    """Run one optimisation pass over *loader*; return ``(mean_loss, accuracy)``."""
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # OneCycleLR is sized as epochs * steps_per_epoch, i.e. in batches, so it
        # must advance after every optimiser step -- not once per epoch.
        scheduler.step()

        loss_sum += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        n_correct += (preds == labels).sum().item()
        n_seen += labels.size(0)
    return loss_sum / n_seen, n_correct / n_seen


def _validate_one_epoch(model, loader, criterion):
    """Evaluate *model* on *loader* without gradients; return ``(mean_loss, accuracy)``."""
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            n_correct += (preds == labels).sum().item()
            n_seen += labels.size(0)
    return loss_sum / n_seen, n_correct / n_seen


def _class_names_from_dir(data_dir):
    """Return the sorted names of the sub-directories (not files) of *data_dir*."""
    return sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )


def train_validate(model, train_loader, val_loader, criterion, optimizer, num_epochs, fold_idx,
                   patience=None, class_names=None):
    """Train with a one-cycle LR schedule and restore the lowest-val-loss weights.

    After training, the restored model is evaluated on *val_loader* and a
    classification report is logged with ``log_and_store``.

    Args:
        model: Network to train (updated in place).
        train_loader / val_loader: Training and validation DataLoaders.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimiser; its first param group's ``lr`` times 10 is used as
            the one-cycle peak learning rate.
        num_epochs: Maximum number of epochs.
        fold_idx: Identifier used in printed/logged headers.
        patience: If given, stop once the validation loss has not improved for
            this many consecutive epochs. ``None`` (default) runs all epochs.
        class_names: Class names in label order for the report. Defaults to the
            sorted class sub-directories of ``DATA_DIR``.

    Returns:
        ``(model, best_acc)``: the same model object with the best weights
        loaded, and the validation accuracy of the epoch with the lowest
        validation loss.
    """
    best_acc = 0.0
    best_loss = float('inf')
    best_weights = copy.deepcopy(model.state_dict())
    epochs_since_improvement = 0

    # One-cycle schedule peaking at 10x the optimiser's configured LR.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=optimizer.param_groups[0]['lr'] * 10,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader),
    )

    for epoch in range(num_epochs):
        epoch_loss, epoch_acc = _train_one_epoch(model, train_loader, criterion, optimizer, scheduler)
        val_loss_epoch, val_acc_epoch = _validate_one_epoch(model, val_loader, criterion)

        # Keep a snapshot of the best-ever validation loss for the final restore.
        if val_loss_epoch < best_loss:
            best_loss = val_loss_epoch
            best_acc = val_acc_epoch
            best_weights = copy.deepcopy(model.state_dict())
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1

        print(f"Fold {fold_idx}, Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f} - "
              f"Val Loss: {val_loss_epoch:.4f}, Val Acc: {val_acc_epoch:.4f}")

        if patience is not None and epochs_since_improvement >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    model.load_state_dict(best_weights)

    # Final validation metrics. Only directories are classes, so stray files in
    # DATA_DIR (e.g. .DS_Store) must not end up in target_names.
    if class_names is None:
        class_names = _class_names_from_dir(DATA_DIR)
    _, _, preds, labels = evaluate_model(model, val_loader)
    log_and_store("\nClassification Report for Fold {}:".format(fold_idx))
    log_and_store(classification_report(
        labels,
        preds,
        labels=list(range(len(class_names))),  # keep rows aligned even if a class is absent
        target_names=class_names,
    ))

    return model, best_acc


In [None]:
# Choose device
# Derived from the DEVICE string in the configuration cell, so code that uses
# either `device` or `DEVICE` always targets the same hardware.
device = torch.device(DEVICE)
print(f"Using device: {device}")

Using device: mps


# Training functions

In [None]:
def run_kfold_training(
    data_dir,
    backbone='mobilenet_v2',
    freeze_until_layer=None,
    criterion_fn=None,
    num_epochs=10,
    lr=1e-4,
    k_folds=3,
    batch_size=32,
    seed=42,
    num_classes=NUM_CLASSES,
):
    """Run stratified k-fold cross-validation, training one model per fold.

    For each fold, training batches are drawn with a class-balanced
    ``WeightedRandomSampler``. A fresh model from ``create_model`` is trained
    with Adam via ``train_validate``. The fold's confusion matrix and
    classification report are then logged with ``log_and_store``.

    Args:
        data_dir: Root folder with one sub-directory per class.
        backbone: Backbone name accepted by ``create_model``.
        freeze_until_layer: Forwarded to ``create_model``.
        criterion_fn: Loss module; defaults to ``nn.CrossEntropyLoss()``.
        num_epochs, lr, batch_size: Training hyper-parameters per fold.
        k_folds: Number of stratified folds.
        seed: ``random_state`` for the fold shuffling.
        num_classes: Number of output classes of each model.

    Returns:
        ``(fold_models, fold_accuracies)``: the best model of each fold and the
        accuracy ``train_validate`` returned for it.
    """
    # One dataset per transform pipeline; every fold only indexes into them.
    train_base = StoolDataset(data_dir, transform=train_transforms)
    val_base = StoolDataset(data_dir, transform=val_transforms)
    # Labels come from the file index, so no image has to be decoded to stratify.
    all_labels_full = [label for _, label in train_base.samples]
    indices = list(range(len(all_labels_full)))

    kf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=seed)
    fold_models = []
    fold_accuracies = []

    for fold_idx, (train_idx, val_idx) in enumerate(kf.split(indices, all_labels_full), 1):
        print(f"\n======= Fold {fold_idx} =======")

        train_ds = Subset(train_base, train_idx)
        val_ds = Subset(val_base, val_idx)

        # Weighted sampler: weight 1/count per sample gives every class the
        # same total sampling weight.
        train_labels_fold = [all_labels_full[i] for i in train_idx]
        class_sample_count_fold = np.bincount(train_labels_fold, minlength=num_classes)
        print(f"Class sample counts: {class_sample_count_fold}")
        class_weights_fold = 1.0 / class_sample_count_fold
        sample_weights_fold = torch.from_numpy(
            class_weights_fold[train_labels_fold].astype(np.double)
        )
        sampler_fold = WeightedRandomSampler(
            sample_weights_fold, num_samples=len(sample_weights_fold), replacement=True
        )

        train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler_fold)
        val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

        model = create_model(
            backbone=backbone, num_classes=num_classes, freeze_until_layer=freeze_until_layer
        )
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
        criterion = criterion_fn if criterion_fn is not None else nn.CrossEntropyLoss()

        best_model, best_acc = train_validate(
            model, train_loader, val_loader, criterion, optimizer, num_epochs, fold_idx
        )
        fold_models.append(best_model)
        fold_accuracies.append(best_acc)

        # Evaluate and log
        _, _, all_preds, all_labels = evaluate_model(best_model, val_loader)
        cm = confusion_matrix(all_labels, all_preds)
        crpt = classification_report(all_labels, all_preds, digits=4)

        log_and_store(f"\n--- Fold {fold_idx} Confusion Matrix ---")
        # Pass a nested list: the form log_and_store's confusion-matrix formatting accepts.
        log_and_store(cm.tolist(), is_confmat=True)

        log_and_store(f"\n--- Fold {fold_idx} Classification Report ---")
        log_and_store(crpt)

    log_and_store("\nFold Models:", [f"Fold {i}" for i in range(1, len(fold_models) + 1)])
    log_and_store("Fold Accuracies:", fold_accuracies)
    log_and_store("Mean Accuracy:", np.mean(fold_accuracies))

    return fold_models, fold_accuracies

In [None]:
def train_single_split(
    data_dir,
    model_name='efficientnet_b3',
    freeze_until=None,
    criterion='focal',  # or 'smooth'
    batch_size=32,
    lr=1e-4,
    num_epochs=10,
    val_split=0.2,
    seed=42,
):
    """Train and evaluate one model on a single stratified train/validation split.

    Args:
        data_dir: Root folder with one sub-directory per class.
        model_name: Backbone name accepted by ``create_model``.
        freeze_until: Forwarded to ``create_model`` as ``freeze_until_layer``.
        criterion: ``'focal'`` (``FocalLoss(alpha=1, gamma=2)``) or ``'smooth'``
            (``criterion_smooth``).
        batch_size, lr, num_epochs: Training hyper-parameters.
        val_split: Fraction of samples held out for validation.
        seed: ``random_state`` for the split.

    Returns:
        ``(best_model, best_acc)`` as returned by ``train_validate``.

    Raises:
        ValueError: if *criterion* is not ``'focal'`` or ``'smooth'``.
    """
    print("======= Single Split Training =======")

    # Resolve the loss first so a bad argument fails before any model is built.
    if criterion == 'focal':
        loss_fn = FocalLoss(alpha=1, gamma=2)
    elif criterion == 'smooth':
        loss_fn = criterion_smooth
    else:
        raise ValueError("Unsupported loss function")

    # One dataset per transform pipeline; labels come from the file index so no
    # image has to be decoded to stratify the split.
    train_base = StoolDataset(data_dir, transform=train_transforms)
    val_base = StoolDataset(data_dir, transform=val_transforms)
    all_labels = [label for _, label in train_base.samples]
    indices = list(range(len(all_labels)))

    train_idx, val_idx = train_test_split(
        indices, test_size=val_split, random_state=seed, stratify=all_labels
    )
    train_ds = Subset(train_base, train_idx)
    val_ds = Subset(val_base, val_idx)

    # Weighted sampler for class imbalance: weight 1/count per sample gives
    # every class the same total sampling weight.
    train_labels = [all_labels[i] for i in train_idx]
    class_sample_counts = np.bincount(train_labels, minlength=NUM_CLASSES)
    print(f"Class sample counts: {class_sample_counts}")
    class_weights = 1.0 / class_sample_counts
    sample_weights = torch.from_numpy(class_weights[train_labels].astype(np.double))
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    model = create_model(backbone=model_name, freeze_until_layer=freeze_until)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

    best_model, best_acc = train_validate(
        model, train_loader, val_loader, loss_fn, optimizer, num_epochs, fold_idx=None
    )

    # Confusion matrix and classification report on the validation split
    _, _, all_preds, all_labels_eval = evaluate_model(best_model, val_loader)
    cm = confusion_matrix(all_labels_eval, all_preds)
    cr = classification_report(all_labels_eval, all_preds, digits=4)

    log_and_store("--- Confusion Matrix ---")
    # Pass a nested list: the form log_and_store's confusion-matrix formatting accepts.
    log_and_store(cm.tolist(), is_confmat=True)
    log_and_store("--- Classification Report ---")
    log_and_store(cr)

    return best_model, best_acc

# Category Model

In [None]:
# 5-fold cross-validation on DATA_DIR with an EfficientNet-B3 backbone and focal loss.
kfold_config = dict(
    backbone='efficientnet_b3',  # alternatives: 'mobilenet_v3_small', 'mobilenet_v3_large', 'mobilenet_v2', 'efficientnet_b0'
    freeze_until_layer=None,  # or a layer name under model.features (see create_model) to freeze the backbone up to it
    criterion_fn=FocalLoss(alpha=1, gamma=2),
    num_epochs=20,
    lr=1e-4,
    k_folds=5,
    batch_size=16,
    seed=44,
    num_classes=3,
)
kf_models, accs = run_kfold_training(data_dir=DATA_DIR, **kfold_config)

# Constipated Model

# Normal Model

# Loose Model

# Accuracy Computing

# Save Models