# In[ ]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import (
    convnext_base, ConvNeXt_Base_Weights,
    densenet161, DenseNet161_Weights,
    efficientnet_b3, EfficientNet_B3_Weights,
    mobilenet_v3_large, MobileNet_V3_Large_Weights,
    vit_b_16, ViT_B_16_Weights
)
from torch.cuda.amp import autocast, GradScaler
from transformers import get_cosine_schedule_with_warmup
from PIL import Image
from tqdm import tqdm
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report
)
from contextlib import nullcontext

# In[ ]:
# Experiment configuration: where the data lives, which splits and backbones
# to sweep over, and the optimisation hyper-parameters shared by every run.
base_data_dir = "splits_images/"
splits = [
    "broken", "capsules", "daylewis", "double",
    "minor_major", "oval_round_oblong", "tablets",
]
model_names = [
    "convnext_base", "densenet161", "efficientnet_b3",
    "mobilenet_v3_large", "vit_b_16",
]
batch_size = 32
num_epochs = 15
lr = 1e-4
warmup_steps = 250          # scheduler warm-up length, in optimizer steps
use_augmentation = True     # False -> train on the plain (un-augmented) pipeline

# Runtime: use the GPU when one is available; speed-related toggles.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
use_compile = True
use_channels_last = True
use_amp = True

os.makedirs("saved_models", exist_ok=True)

# Per-channel mean/std normalisation shared by both pipelines below
# (the commonly used ImageNet statistics).
_imagenet_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

# Training pipeline: light geometric/photometric jitter, then tensor + normalise.
train_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
    transforms.ToTensor(),
    _imagenet_normalize,
])

# Evaluation pipeline (also used for training when use_augmentation is False).
# NOTE(review): neither pipeline resizes; vit_b_16 in particular expects a fixed
# input size, so the images are presumably stored pre-resized — confirm.
plain = transforms.Compose([transforms.ToTensor(), _imagenet_normalize])

#dataset
class CustomImageDataset(Dataset):
    """Binary image dataset whose labels are derived from file names.

    Every entry directly inside ``folder_path`` whose name ends with
    ``.png``/``.jpg``/``.jpeg`` (case-insensitive) is one sample, in sorted
    path order. A sample whose lower-cased base name contains ``"proper"`` is
    labelled 0; every other sample is labelled 1.

    Args:
        folder_path: Directory to scan (not recursive).
        transform: Optional callable applied to the loaded RGB image.

    Attributes:
        image_paths: Sorted list of sample paths.
        labels: Label for each entry of ``image_paths`` (same order), available
            without loading any image.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, folder_path, transform=None):
        self.image_paths = sorted(
            os.path.join(folder_path, f)
            for f in os.listdir(folder_path)
            if f.lower().endswith(self._IMAGE_EXTENSIONS)
        )
        # Labels depend only on file names, so compute them once up front;
        # callers can read them without decoding (or augmenting) any image.
        self.labels = [self._label_from_path(p) for p in self.image_paths]
        self.transform = transform

    @staticmethod
    def _label_from_path(path):
        """Return 0 if the base name contains "proper" (case-insensitive), else 1."""
        # NOTE(review): substring match — a name such as "improper_x.png" also
        # contains "proper" and would be labelled 0; confirm the naming scheme.
        return 0 if 'proper' in os.path.basename(path).lower() else 1

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # Open the file in a context manager so its handle is closed once the
        # pixels are decoded; convert() returns a new, fully loaded image.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]

def _pretrained(builder, weights):
    """Return a zero-argument factory that calls ``builder(weights=weights)``."""
    return lambda: builder(weights=weights)


# Zero-argument factories keyed like ``model_names``; each call constructs a
# fresh pretrained torchvision model.
model_loader = {
    "convnext_base": _pretrained(convnext_base, ConvNeXt_Base_Weights.DEFAULT),
    "densenet161": _pretrained(densenet161, DenseNet161_Weights.DEFAULT),
    "efficientnet_b3": _pretrained(efficientnet_b3, EfficientNet_B3_Weights.DEFAULT),
    "mobilenet_v3_large": _pretrained(mobilenet_v3_large, MobileNet_V3_Large_Weights.DEFAULT),
    # pinned explicitly to IMAGENET1K_V1 rather than DEFAULT
    "vit_b_16": _pretrained(vit_b_16, ViT_B_16_Weights.IMAGENET1K_V1),
}

def replace_head(model: nn.Module, name: str, *, num_classes: int = 2) -> nn.Module:
    """Swap the final classification layer of a torchvision backbone.

    The new layer is a freshly initialised ``nn.Linear`` with the same input
    width as the layer it replaces and ``num_classes`` outputs.

    Args:
        model: Backbone whose head is replaced in place.
        name: Backbone key; the family is recognised by prefix
            (``convnext``, ``densenet``, ``efficientnet``, ``mobilenet``, ``vit``).
        num_classes: Output width of the new head (default 2, binary task).

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: If ``name`` matches none of the supported families, rather
            than silently returning a model whose head was never replaced.
    """
    if name.startswith(("convnext", "efficientnet", "mobilenet")):
        # These families keep the head as the last entry of an indexable `classifier`.
        old_head = model.classifier[-1]
        model.classifier[-1] = nn.Linear(old_head.in_features, num_classes)
    elif name.startswith("densenet"):
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif name.startswith("vit"):
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    else:
        raise ValueError(f"replace_head: unsupported model name {name!r}")
    return model

def _label_from_filename(path):
    """Return 0 if the base name contains "proper" (case-insensitive), else 1.

    Mirrors the labelling rule CustomImageDataset applies per file, so labels
    can be read from paths without loading images.
    """
    return 0 if 'proper' in os.path.basename(path).lower() else 1


def _snapshot_state(module):
    """Return a CPU copy of ``module.state_dict()``.

    ``state_dict()`` hands back references to the live tensors, so storing it
    directly would let later optimizer steps overwrite the "best" weights.
    """
    return {k: v.detach().to("cpu", copy=True) for k, v in module.state_dict().items()}


def _predict(model, loader):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        (y_true, y_prob) as NumPy arrays: the loader's labels and the softmax
        probability of class 1 for each sample.
    """
    model.eval()
    y_true, y_prob = [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            if use_channels_last:
                images = images.to(memory_format=torch.channels_last)
            with autocast() if use_amp else nullcontext():
                outputs = model(images)
            # softmax in float32 whatever dtype autocast produced
            probs = torch.softmax(outputs.float(), dim=1)[:, 1]
            y_true.extend(labels.numpy())
            y_prob.extend(probs.cpu().numpy())
    return np.asarray(y_true), np.asarray(y_prob)


# Candidate decision thresholds on P(class 1); rounded so the chosen value is
# recorded as e.g. 0.15 rather than 0.15000000000000002.
_THRESHOLDS = np.round(np.arange(0.05, 0.95, 0.05), 2)


def _best_threshold(y_true, y_prob):
    """Return (f1, threshold) maximising F1 over ``_THRESHOLDS``.

    The first threshold reaching the maximum wins; if F1 is 0 at every
    threshold the result is (0.0, 0.5).
    """
    best_f1, best_t = 0.0, 0.5
    for t in _THRESHOLDS:
        f1 = f1_score(y_true, (y_prob >= t).astype(int), zero_division=0)
        if f1 > best_f1:
            best_f1, best_t = f1, float(t)
    return best_f1, best_t


def _safe_ratio(num, den):
    """Return num / den as a float, or NaN when the denominator is 0."""
    return float(num) / float(den) if den else float("nan")


# Sweep every backbone over every split: fine-tune, pick the epoch/threshold
# with the best validation F1, evaluate on test, and collect metrics.
results = []

for model_key in model_names:
    for split in tqdm(splits, desc=f"{model_key} - Splits"):
        prefix = "aug_" if use_augmentation else "base_"
        model_name = f"{prefix}{model_key}_{split}"
        data_dir = os.path.join(base_data_dir, split)

        train_dataset = CustomImageDataset(
            os.path.join(data_dir, "train"),
            transform=train_aug if use_augmentation else plain,
        )
        val_dataset = CustomImageDataset(os.path.join(data_dir, "val"), transform=plain)
        test_dataset = CustomImageDataset(os.path.join(data_dir, "test"), transform=plain)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)
        test_loader = DataLoader(test_dataset, batch_size=batch_size)

        # Class weights need only the labels, which come from file names; read
        # them from the paths instead of iterating the dataset (which would
        # decode and augment every training image just to discard it).
        train_labels = [_label_from_filename(p) for p in train_dataset.image_paths]
        class_weights = compute_class_weight(
            class_weight='balanced', classes=np.unique(train_labels), y=train_labels
        )
        class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)

        # `net` is the plain nn.Module; `model` is what gets called (compiled or
        # not). Checkpoints are taken from `net`, so their keys do not carry the
        # torch.compile wrapper's "_orig_mod." prefix.
        net = replace_head(model_loader[model_key](), model_key).to(device)
        if use_channels_last:
            net = net.to(memory_format=torch.channels_last)
        model = torch.compile(net) if use_compile else net

        criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
        optimizer = optim.AdamW(net.parameters(), lr=lr)
        total_steps = len(train_loader) * num_epochs
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
        )
        scaler = GradScaler() if use_amp else None

        best_f1 = 0.0
        best_threshold = 0.5
        best_model_state = None

        for epoch in range(num_epochs):
            model.train()
            total_loss = 0.0
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                if use_channels_last:
                    images = images.to(memory_format=torch.channels_last)

                optimizer.zero_grad()
                with autocast() if use_amp else nullcontext():
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                if use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()

                scheduler.step()
                total_loss += loss.item()

            # Validate every epoch: tune the threshold and keep a copy of the
            # weights from the epoch with the best validation F1.
            y_val, p_val = _predict(model, val_loader)
            epoch_f1, epoch_t = _best_threshold(y_val, p_val)
            if epoch_f1 > best_f1:
                best_f1, best_threshold = epoch_f1, epoch_t
                best_model_state = _snapshot_state(net)
            tqdm.write(
                f"{model_name} epoch {epoch + 1}/{num_epochs}: "
                f"train loss {total_loss / max(len(train_loader), 1):.4f}, "
                f"val F1 {epoch_f1:.4f} @ {epoch_t:.2f}"
            )

        if best_model_state is None:
            # F1 was 0 at every threshold in every epoch: fall back to the final
            # weights and the default threshold instead of saving/loading None.
            tqdm.write(f"{model_name}: validation F1 was 0 throughout; using final weights")
            best_model_state = _snapshot_state(net)

        # save the selected weights and evaluate them on the test split
        torch.save(best_model_state, f"saved_models/{model_name}.pt")
        net.load_state_dict(best_model_state)

        y_true, y_probs = _predict(model, test_loader)
        y_pred = (y_probs >= best_threshold).astype(int)
        # labels=[0, 1] forces a 2x2 matrix even if the test split (or the
        # predictions) contain only one class, so the unpacking below is safe.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()

        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        tnr = _safe_ratio(tn, tn + fp)
        fpr = _safe_ratio(fp, fp + tn)
        fnr = _safe_ratio(fn, fn + tp)

        results.append({
            "model": model_name,
            "cnn_model": model_key,
            "split": split,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1_score": f1,
            "tnr": tnr,
            "fpr": fpr,
            "fnr": fnr,
            "confusion_matrix": cm.tolist(),
            "threshold": best_threshold,
            "augmented": use_augmentation,
            "num_samples": len(test_dataset)
        })

        # Release this run's model/optimizer before building the next one to
        # reduce peak device memory across the sweep.
        del model, net, optimizer, scheduler, scaler, best_model_state
        if device.type == "cuda":
            torch.cuda.empty_cache()

#save all results to CSV
output_csv = "metrics_fine_tuning.csv" if use_augmentation else "metrics_fine_tuning_no_aug.csv"
df = pd.DataFrame(results)
df.to_csv(output_csv, index=False)