# Settings script

In [None]:
# Model selection (exactly one must be True)
MODEL_SCRATCH          = False      # No pretrained weights, trained from scratch on Food-101
MODEL_BASELINE         = False      # Pretrained EfficientNet, but NOT trained on Food-101
MODEL_FINETUNE_PARTIAL = False      # Pretrained EfficientNet, THEN fine-tuned on Food-101 (early feature blocks frozen)
MODEL_FINETUNE         = True       # Pretrained EfficientNet, THEN fine-tuned on Food-101 (all layers trainable)

# Augmentation settings (exactly one must be True)
AUGMENTATION_MIN      = False
AUGMENTATION_MANUAL   = False
AUGMENTATION_AUTO     = False
AUGMENTATION_IMAGENET = True

# Training settings (Pick one)
REGULAR_TRAINING = True
RAYTUNE_TRAINING = False
LOGGING_ACTIVE = REGULAR_TRAINING   # Only perform logging with regular training

# Other settings
USE_FULL_DATASET = False            # If false, use small dataset instead (10K pictures, 100 per class)
NUM_EPOCHS       = 15
BATCH_SIZE       = 32
LEARN_RATE       = 1e-3

# Sanity checks
assert sum([MODEL_BASELINE, MODEL_SCRATCH, MODEL_FINETUNE_PARTIAL, MODEL_FINETUNE]) == 1, \
    "Exactly one model must be selected."

# The transform selection is an if/elif chain, so several active flags would
# silently resolve to the first one; make the ambiguity an explicit error.
assert sum([AUGMENTATION_MIN, AUGMENTATION_MANUAL, AUGMENTATION_AUTO, AUGMENTATION_IMAGENET]) == 1, \
    "Exactly one augmentation mode must be selected."

assert sum([REGULAR_TRAINING, RAYTUNE_TRAINING]) == 1, \
    "Exactly one training mode must be active."

assert not (RAYTUNE_TRAINING and LOGGING_ACTIVE), \
    "LOGGING_ACTIVE should be False during Ray Tune."

# The baseline builder returns no optimizer, so it cannot be tuned. It stays
# allowed with REGULAR_TRAINING, where the training loop is skipped and the
# model is only evaluated. (Forbidding it for both modes, while requiring one
# mode to be active, made the baseline impossible to select.)
assert not (MODEL_BASELINE and RAYTUNE_TRAINING), \
    "MODEL_BASELINE can't be used with Ray Tune (there is nothing to train)."

# Install dependencies
Run the command below in a terminal.
Before running the notebook, make sure it is using the kernel (environment) in which the dependencies were installed.

In [None]:
#pip install -r requirements.txt

# Setup

## Imports

In [None]:
# Standard library
import os
import pathlib
import time
from collections import Counter
import copy
import tempfile
from math import isnan

# ML
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# TorchVision
from torchvision import datasets, transforms
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from torchvision.transforms import AutoAugment, AutoAugmentPolicy

# W&B
import wandb
import cv2
import matplotlib.pyplot as plt

# Ray Tune
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

## Set GPU variable

In [None]:
# The notebook is GPU-only: stop right away when no CUDA device is visible.
if torch.cuda.is_available():
    device = "cuda"
    print("Device:", device)
else:
    raise RuntimeError("CUDA is not available.")

# Data

## Data paths

In [None]:
# All data paths are resolved relative to the current working directory.
ROOT = pathlib.Path().resolve()

# The full dataset keeps its pictures in an "images" subfolder,
# the small one directly in its own folder.
IMAGE_DIR = (
    ROOT / "data" / "food-101-big" / "images"
    if USE_FULL_DATASET
    else ROOT / "data" / "food-101-small"
)

print("Using IMAGE_DIR =", IMAGE_DIR)

# Abort early with a readable message instead of failing inside ImageFolder.
if not IMAGE_DIR.exists():
    raise FileNotFoundError(f"Missing dataset folder:\n{IMAGE_DIR}")

## Data preprocessing

In [None]:
img_size = 224  # Width and height the images are brought to for EfficientNet

# Per-channel normalisation; only the AUGMENTATION_IMAGENET pipelines apply it.
imagenet_norm = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)


def _square_resize():
    """Return a transform that resizes to img_size x img_size (aspect ratio not kept)."""
    return transforms.Resize((img_size, img_size))


def _eval_pipeline():
    """Build the deterministic pipeline shared by the validation and test splits."""
    if AUGMENTATION_IMAGENET:
        steps = [
            transforms.Resize(256),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            imagenet_norm,
        ]
    else:
        steps = [_square_resize(), transforms.ToTensor()]
    return transforms.Compose(steps)


# Training pipeline: the first active augmentation flag wins.
if AUGMENTATION_MIN:
    train_steps = [_square_resize(), transforms.ToTensor()]
elif AUGMENTATION_MANUAL:
    train_steps = [
        _square_resize(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ]
elif AUGMENTATION_AUTO:
    train_steps = [
        _square_resize(),
        AutoAugment(policy=AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
    ]
elif AUGMENTATION_IMAGENET:
    train_steps = [
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        imagenet_norm,
    ]
else:
    raise ValueError("No augmentation mode selected.")

train_transform = transforms.Compose(train_steps)

# Validation and test use identical, but separately constructed, pipelines.
validation_transform = _eval_pipeline()
test_transform = _eval_pipeline()

# Index the image folder once without a transform; the split-specific
# transforms are attached to copies of this dataset later on.
full_dataset = datasets.ImageFolder(root=IMAGE_DIR, transform=None)
class_names = full_dataset.classes
num_classes = len(class_names)

print("Total images:", len(full_dataset))
print("Classes:", num_classes)

## Split data into sets

In [None]:
targets = np.asarray(full_dataset.targets)   # class label of every sample
indices = np.arange(targets.shape[0])        # one position per sample

# Step 1: stratified 70% train / 30% hold-out.
train_idx, temp_idx = train_test_split(
    indices, test_size=0.30, stratify=targets, random_state=42
)

# Step 2: halve the hold-out (0.50 * 30% = 15% each) into validation and test,
# stratified on the labels of the hold-out samples.
holdout_targets = targets[temp_idx]
val_idx, test_idx = train_test_split(
    temp_idx, test_size=0.50, stratify=holdout_targets, random_state=42
)

In [None]:
def _copy_with_transform(transform):
    """Return an independent copy of full_dataset that applies ``transform``."""
    dataset_copy = copy.deepcopy(full_dataset)
    dataset_copy.transform = transform
    return dataset_copy


# One dataset copy per split, so each split can carry its own transform.
train_dataset = _copy_with_transform(train_transform)
val_dataset = _copy_with_transform(validation_transform)
test_dataset = _copy_with_transform(test_transform)

# Restrict every copy to the indices of its split.
train_ds = torch.utils.data.Subset(train_dataset, train_idx)
val_ds = torch.utils.data.Subset(val_dataset, val_idx)
test_ds = torch.utils.data.Subset(test_dataset, test_idx)

In [None]:
def check_split_balance(subset, name):
    """Print whether the classes occurring in ``subset`` are equally frequent.

    Args:
        subset: Object with an ``indices`` attribute holding positions into
            ``full_dataset.targets`` (e.g. a ``torch.utils.data.Subset``).
        name: Label used as the prefix of the printed report.

    Only classes that occur at least once in the subset are counted.
    """
    per_class = Counter(full_dataset.targets[i] for i in subset.indices)
    distinct_sizes = set(per_class.values())

    if len(distinct_sizes) != 1:
        # Different class sizes (or no samples at all): dump the raw counts.
        print(f"{name} distribution is NOT balanced:")
        print(per_class)
        return

    # Exactly one distinct size means every present class is equally large.
    (size,) = distinct_sizes
    print(f"{name} distribution: {size} per class")

# Report the class balance of every split.
for split_subset, split_label in ((train_ds, "Train"), (val_ds, "Val"), (test_ds, "Test")):
    check_split_balance(split_subset, split_label)

## Create DataLoaders

In [None]:
if REGULAR_TRAINING:
    # Settings shared by all three loaders. The test loader now also uses
    # pin_memory=True, matching the train/val loaders (it is iterated several
    # times by the evaluation cells further down).
    loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)

    # Only the training data is shuffled; val/test keep a fixed order.
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader   = DataLoader(val_ds,   shuffle=False, **loader_kwargs)
    test_loader  = DataLoader(test_ds,  shuffle=False, **loader_kwargs)

    print("Train samples:", len(train_ds))
    print("Val samples:  ", len(val_ds))
    print("Test samples: ", len(test_ds))

# Model

## Define Model

In [None]:
# Pretrained model but NOT trained (baseline)
def build_model_baseline(lr, num_classes):
    """Build a fully frozen, ImageNet-pretrained EfficientNet-B0.

    The classifier's final layer is replaced by a new ``nn.Linear`` with
    ``num_classes`` outputs; afterwards every parameter (including that new
    layer) has ``requires_grad`` switched off.

    Args:
        lr: Unused; accepted so all builders share the same signature.
        num_classes: Number of output classes of the new head.

    Returns:
        Tuple ``(model, None, criterion)`` — there is no optimizer because
        nothing is trainable; ``criterion`` is ``nn.CrossEntropyLoss``.
    """
    net = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)

    head_inputs = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(head_inputs, num_classes)

    # Freeze every parameter of the network -> no training
    net.requires_grad_(False)

    return net, None, nn.CrossEntropyLoss()


# Train from scratch
def build_model_scratch(lr, num_classes):
    """Build an EfficientNet-B0 without pretrained weights, fully trainable.

    Args:
        lr: Learning rate for the Adam optimizer.
        num_classes: Number of outputs of the replaced classifier layer.

    Returns:
        Tuple ``(model, optimizer, criterion)``: the network, an Adam
        optimizer over all of its parameters, and ``nn.CrossEntropyLoss``.
    """
    net = efficientnet_b0(weights=None)

    head_inputs = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(head_inputs, num_classes)

    adam = torch.optim.Adam(net.parameters(), lr=lr)
    return net, adam, nn.CrossEntropyLoss()

def build_model_finetune_partial(lr, num_classes):
    """Build a pretrained EfficientNet-B0 for partial fine-tuning.

    Entries 0-4 of ``model.features`` are frozen; entries 5 and above and the
    classifier stay trainable. The optimizer only receives the trainable
    parameters.

    Args:
        lr: Learning rate for the Adam optimizer.
        num_classes: Number of outputs of the replaced classifier layer.

    Returns:
        Tuple ``(model, optimizer, criterion)`` with an Adam optimizer and
        ``nn.CrossEntropyLoss``.
    """
    net = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)

    # Swap in a classification head of the right size
    head_inputs = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(head_inputs, num_classes)

    # Feature entries 0-4 are frozen, everything from index 5 on is trainable
    for position, stage in enumerate(net.features):
        stage.requires_grad_(position > 4)

    # The classifier is always trained
    net.classifier.requires_grad_(True)

    trainable_params = [p for p in net.parameters() if p.requires_grad]
    adam = torch.optim.Adam(trainable_params, lr=lr)

    return net, adam, nn.CrossEntropyLoss()


# Pretrained → fine-tuned
def build_model_finetune(lr, num_classes):
    """Build an ImageNet-pretrained EfficientNet-B0 with all layers trainable.

    Args:
        lr: Learning rate for the Adam optimizer.
        num_classes: Number of outputs of the replaced classifier layer.

    Returns:
        Tuple ``(model, optimizer, criterion)``: the network, an Adam
        optimizer over all of its parameters, and ``nn.CrossEntropyLoss``.
    """
    net = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)

    head_inputs = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(head_inputs, num_classes)

    adam = torch.optim.Adam(net.parameters(), lr=lr)
    return net, adam, nn.CrossEntropyLoss()



# def build_model_1(lr, num_classes):
#     """EfficientNet-B0 model"""
#     weights = EfficientNet_B0_Weights.IMAGENET1K_V1
#     model = efficientnet_b0(weights=weights)

#     model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

#     for p in model.features.parameters():
#         p.requires_grad = False

#     model = model.to(device)
#     optimizer = optim.Adam(model.parameters(), lr=lr)
#     criterion = nn.CrossEntropyLoss()

#     return model, optimizer, criterion


# def build_model_2(lr, num_classes):
#     """Placeholder model — intentionally blank"""
#     raise NotImplementedError("MODEL_2 not implemented yet.")

## Initialize model

In [None]:
# Pick the builder of the first active model flag (same priority order as before).
for flag_is_set, candidate_builder in (
    (MODEL_BASELINE, build_model_baseline),
    (MODEL_SCRATCH, build_model_scratch),
    (MODEL_FINETUNE_PARTIAL, build_model_finetune_partial),
    (MODEL_FINETUNE, build_model_finetune),
):
    if flag_is_set:
        build_model = candidate_builder
        break

# Ray Tune builds its own model per trial, so only build one here for regular training.
if REGULAR_TRAINING:
    model, optimizer, criterion = build_model(LEARN_RATE, num_classes)

    # Move network and loss to the GPU
    model, criterion = model.to(device), criterion.to(device)

# Setup Logging

In [None]:
if LOGGING_ACTIVE:
    wandb.login()

    batch_size = BATCH_SIZE
    learning_rate = LEARN_RATE

    # Describe the selected model variant (used in the run name and config).
    if MODEL_SCRATCH:
        model_name, PRETRAINED, FINETUNED = "scratch", False, False
    elif MODEL_BASELINE:
        model_name, PRETRAINED, FINETUNED = "baseline", True, False
    elif MODEL_FINETUNE_PARTIAL:
        model_name, PRETRAINED, FINETUNED = "finetune_partial", True, True
    elif MODEL_FINETUNE:
        model_name, PRETRAINED, FINETUNED = "finetune", True, True
    else:
        # Not reachable while the settings assertions hold; the flags are
        # still defined so building run_name cannot raise a NameError.
        model_name, PRETRAINED, FINETUNED = "", False, False

    # augmentation_string keeps its historic values so run names stay
    # comparable; augmentation_label is the descriptive value for the config.
    # The priority order mirrors the transform selection.
    if AUGMENTATION_MIN:
        augmentation_string = "noaugment_"
        augmentation_label = "resize_only"
    elif AUGMENTATION_MANUAL:
        augmentation_string = "manualaugment_"
        augmentation_label = "manual_flip_rotation"
    elif AUGMENTATION_AUTO:
        augmentation_string = "autoaugment_"
        augmentation_label = "AutoAugment_IMAGENET"
    else:
        augmentation_string = ""
        augmentation_label = (
            "random_resized_crop_flip_imagenet_norm" if AUGMENTATION_IMAGENET else "unknown"
        )

    def log_metric(data: dict, step=None):
        """Log a dict of metrics to W&B, optionally at an explicit step."""
        wandb.log(data, step=step)

    def log_batch_metric(loss_value, optimizer, epoch):
        """Log a batch loss and the first param group's learning rate.

        The W&B step is the epoch index, so all batch logs of one epoch
        share the same step.
        """
        wandb.log(
            {
                "batch/train_loss": loss_value,
                "batch/lr": optimizer.param_groups[0]["lr"],
            },
            step=epoch
        )

    def log_accuracy_vs_epoch(split: str, acc_value: float, epoch: int):
        """Log ``acc_value`` under ``accuracy_vs_epoch/<split>`` at step ``epoch``."""
        wandb.log({f"accuracy_vs_epoch/{split}": acc_value}, step=epoch)


    run_name = (
        f"model{model_name}_"
        f"lr{learning_rate}_bs{batch_size}_"
        f"{'finetuned_' if FINETUNED else ''}"
        f"{'pretrained_' if PRETRAINED else ''}"
        f"{augmentation_string}"
    )

    # The config is derived from the active settings. It used to hard-code
    # "transfer_learning": True, "feature_extractor_frozen": True and
    # "augmentation": "AutoAugment_IMAGENET" for every run.
    wandb_config = {
        "model_name": "efficientnet_b0",
        "model_variant": model_name,
        "img_size": img_size,
        "batch_size": batch_size,
        "optimizer": "Adam",
        "learning_rate": learning_rate,
        "num_epochs": NUM_EPOCHS,
        "train_size": len(train_ds),
        "val_size": len(val_ds),
        "test_size": len(test_ds),
        "num_classes": num_classes,
        "transfer_learning": PRETRAINED,
        "feature_extractor_frozen": MODEL_BASELINE,
        "feature_extractor_partially_frozen": MODEL_FINETUNE_PARTIAL,
        "augmentation": augmentation_label,
    }

    run = wandb.init(
        project="IMTL",
        config=wandb_config,
        name=run_name
    )

    wandb.watch(model, log="all", log_freq=100)

# Train the Model

## Define Regular training and evaluation functions

In [None]:
if REGULAR_TRAINING:
    
    def train_one_epoch(model, loader, optimizer, criterion, device, epoch, log_interval=100):
        """Run one optimisation pass over ``loader`` and log the results.

        Args:
            model: Network to train; switched to train mode here.
            loader: Iterable yielding ``(images, labels)`` batches.
            optimizer: Optimizer stepped once per batch.
            criterion: Loss applied to ``(outputs, labels)``.
            device: Device the batches are moved to.
            epoch: Zero-based epoch index; used as the logging step and shown
                one-based in the progress bar.
            log_interval: Batch loss and learning rate are logged every
                ``log_interval`` batches.

        Returns:
            Tuple ``(avg_loss, avg_acc)``: the sample-weighted mean loss and
            the fraction of correctly classified samples.
        """
        model.train()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        progress = tqdm(loader, desc=f"Training Epoch {epoch+1}", leave=False)

        for batch_no, (imgs, labels) in enumerate(progress, start=1):
            imgs = imgs.to(device)
            labels = labels.to(device)

            # Standard step: reset grads, forward, backward, update
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch mean is per sample
            batch_loss = loss.item()
            loss_sum += batch_loss * imgs.size(0)
            preds = outputs.max(1).indices
            n_correct += preds.eq(labels).sum().item()
            n_seen += labels.size(0)

            progress.set_postfix({
                "loss": f"{loss_sum / n_seen:.4f}",
                "acc": f"{100 * (n_correct / n_seen):.2f}%"
            })

            if batch_no % log_interval == 0:
                log_batch_metric(batch_loss, optimizer, epoch)

        avg_loss = loss_sum / n_seen
        avg_acc = n_correct / n_seen

        log_metric({"loss/train": avg_loss, "accuracy/train": avg_acc}, step=epoch)
        log_accuracy_vs_epoch("train", avg_acc, epoch)

        return avg_loss, avg_acc


    def evaluate(model, loader, criterion, device, epoch=None, split="val"):
        """Evaluate ``model`` on ``loader`` without gradients and log the metrics.

        Args:
            model: Network to evaluate; switched to eval mode here.
            loader: Iterable yielding ``(images, labels)`` batches.
            criterion: Loss applied to ``(outputs, labels)``.
            device: Device the batches are moved to.
            epoch: Optional zero-based epoch index. When given it is used as
                the logging step, logged under ``"epoch"``, and shown
                one-based in the progress bar (like ``train_one_epoch``).
            split: Prefix of the logged metric names (e.g. ``"val"``, ``"test"``).

        Returns:
            Tuple ``(avg_loss, avg_acc)``: the sample-weighted mean loss and
            the fraction of correctly classified samples.
        """
        model.eval()
        total_loss, correct, total = 0.0, 0, 0

        desc = f"Evaluating ({split})"
        if epoch is not None:
            # One-based for display, consistent with the training progress bar
            # (this used to show the raw zero-based index).
            desc += f" Epoch {epoch + 1}"

        with torch.no_grad():
            pbar = tqdm(loader, desc=desc, leave=False)

            for imgs, labels in pbar:
                imgs, labels = imgs.to(device), labels.to(device)

                outputs = model(imgs)
                loss = criterion(outputs, labels)

                # Weight by batch size so the final mean is per sample
                total_loss += loss.item() * imgs.size(0)
                _, preds = outputs.max(1)
                correct += preds.eq(labels).sum().item()
                total += labels.size(0)

                pbar.set_postfix({
                    "loss": f"{total_loss / total:.4f}",
                    "acc": f"{100 * (correct / total):.2f}%"
                })

        avg_loss = total_loss / total
        avg_acc = correct / total

        metrics = {
            f"{split}/loss": avg_loss,
            f"{split}/accuracy": avg_acc,
            f"{split}/error_rate": 1 - avg_acc,
        }
        if epoch is not None:
            metrics["epoch"] = epoch
        log_metric(metrics, step=epoch)

        # The accuracy-vs-epoch curve only exists for per-epoch validation runs
        if split == "val" and epoch is not None:
            log_accuracy_vs_epoch("val", avg_acc, epoch)

        return avg_loss, avg_acc

## Define RayTune training and evaluation functions

In [None]:
if RAYTUNE_TRAINING:
    
    def train_one_epoch_raytune(model, loader, optimizer, criterion, device):
        """Train ``model`` for one pass over ``loader`` (no W&B logging).

        Args:
            model: Network to train; switched to train mode here.
            loader: Iterable yielding ``(images, labels)`` batches.
            optimizer: Optimizer stepped once per batch.
            criterion: Loss applied to ``(outputs, labels)``.
            device: Device the batches are moved to.

        Returns:
            Tuple ``(mean_loss, accuracy)`` over all samples of the pass.
        """
        model.train()
        loss_sum = 0
        n_correct = 0
        n_seen = 0

        for imgs, labels in tqdm(loader, desc="Training", leave=False):
            imgs = imgs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Sample-weighted running totals
            loss_sum += loss.item() * imgs.size(0)
            preds = outputs.max(1).indices
            n_correct += preds.eq(labels).sum().item()
            n_seen += labels.size(0)

        mean_loss = loss_sum / n_seen
        accuracy = n_correct / n_seen
        return mean_loss, accuracy

    def evaluate_raytune(model, loader, criterion, device):
        """Evaluate ``model`` on ``loader`` without gradients (no W&B logging).

        Args:
            model: Network to evaluate; switched to eval mode here.
            loader: Iterable yielding ``(images, labels)`` batches.
            criterion: Loss applied to ``(outputs, labels)``.
            device: Device the batches are moved to.

        Returns:
            Tuple ``(avg_loss, avg_acc)`` over all samples of the loader.
        """
        model.eval()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        with torch.no_grad():
            progress = tqdm(loader, desc="Evaluating (raytune)", leave=False)

            for imgs, labels in progress:
                imgs = imgs.to(device)
                labels = labels.to(device)

                outputs = model(imgs)
                batch_loss = criterion(outputs, labels).item()

                loss_sum += batch_loss * imgs.size(0)
                preds = outputs.max(1).indices
                n_correct += preds.eq(labels).sum().item()
                n_seen += labels.size(0)

                progress.set_postfix({
                    "loss": f"{loss_sum / n_seen:.4f}",
                    "acc": f"{100 * (n_correct / n_seen):.2f}%"
                })

        return loss_sum / n_seen, n_correct / n_seen

    def tune_train(config):
        """Trainable for Ray Tune: train and validate one sampled configuration.

        Args:
            config: Dict with the keys ``"batch_size"``, ``"lr"`` and
                ``"epochs"``.

        After every epoch the validation loss and accuracy, the training time
        of the epoch, the peak allocated GPU memory and the current learning
        rate are reported through ``tune.report``.
        """
        lr = config["lr"]
        batch_size = config["batch_size"]

        def make_loader(dataset, shuffle):
            return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)

        train_loader = make_loader(train_ds, True)
        val_loader = make_loader(val_ds, False)

        model, optimizer, criterion = build_model(lr, num_classes)

        # Move to GPU
        model = model.to(device)
        criterion = criterion.to(device)

        for _ in range(config["epochs"]):
            # Only the training pass is timed
            started = time.time()
            train_one_epoch_raytune(model, train_loader, optimizer, criterion, device)
            epoch_time = time.time() - started

            val_loss, val_acc = evaluate_raytune(model, val_loader, criterion, device)

            # Peak GPU memory (GB) since the last reset; reset for the next epoch
            if torch.cuda.is_available():
                gpu_mem = torch.cuda.max_memory_allocated() / 1e9
                torch.cuda.reset_peak_memory_stats()
            else:
                gpu_mem = 0.0

            current_lr = optimizer.param_groups[0]["lr"]

            tune.report({
                "loss": float(val_loss),
                "accuracy": float(val_acc),
                "epoch_time": float(epoch_time),
                "gpu_mem": float(gpu_mem),
                "lr": float(current_lr),
            })

## Regular training (If toggled on)

In [None]:
if REGULAR_TRAINING:
    if optimizer is None:
        # The baseline builder returns no optimizer: nothing to train.
        print("=== Baseline model selected — skipping training ===")

    else:
        # Accuracy after every epoch (used by the accuracy plot further down)
        train_acc_history = []
        val_acc_history   = []

        for epoch in range(NUM_EPOCHS):
            print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
            train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch)

            # Pass the epoch so the validation metrics are logged at the same
            # step as the training metrics. Without it, evaluate() never logs
            # the "epoch" key or the accuracy_vs_epoch/val curve.
            val_loss, val_acc = evaluate(model, val_loader, criterion, device, epoch=epoch, split="val")

            train_acc_history.append(train_acc)
            val_acc_history.append(val_acc)

            print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
            print(f"Val   Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")


## Raytune training (If toggled on)

In [None]:
if RAYTUNE_TRAINING:
    os.environ["RAY_DISABLE_METRICS_EXPORT"] = "1"

    # Hyperparameters sampled for every trial
    search_space = dict(
        lr=tune.loguniform(3e-5, 3e-3),   # earlier range: tune.loguniform(1e-5, 1e-2)
        batch_size=tune.choice([16, 32, 64]),
        epochs=tune.randint(3, 15),
    )

    # Columns shown in the console progress table
    reporter = CLIReporter(
        metric_columns=["loss", "accuracy", "epoch_time", "gpu_mem"]
    )

    # ASHA scheduler, ranking trials by validation accuracy
    scheduler = ASHAScheduler(
        metric="accuracy",
        mode="max",
        max_t=15,
        grace_period=2,
        reduction_factor=3,
    )

    # Every trial gets 4 CPUs and, when CUDA is available, one GPU
    trial_resources = {"cpu": 4, "gpu": int(torch.cuda.is_available())}
    trainable = tune.with_resources(tune_train, resources=trial_resources)

    tuner = tune.Tuner(
        trainable,
        param_space=search_space,
        tune_config=tune.TuneConfig(scheduler=scheduler, num_samples=6),
        run_config=tune.RunConfig(progress_reporter=reporter),
    )

    results = tuner.fit()
    best = results.get_best_result(metric="accuracy", mode="max")
    best_config = best.config

    print("Best config:", best_config)


### Write RayTune result to file

In [None]:
if RAYTUNE_TRAINING:

    summary_path = ROOT / "raytune" / "raytune_summary.txt"
    # Create the output folder up front; opening the file below would
    # otherwise fail with FileNotFoundError on a fresh checkout.
    summary_path.parent.mkdir(parents=True, exist_ok=True)

    best_result = results.get_best_result(metric="accuracy", mode="max")
    # NOTE(review): _experiment_analysis is a private attribute of the result
    # object and may change between Ray versions.
    analysis = results._experiment_analysis
    trials = analysis.trials  # raw Trial objects

    def fmt(x, width):
        """Left-justify ``x`` to ``width`` characters.

        Floats are shown as "nan", in scientific notation (|x| < 1e-3) or
        with four decimals; every other value is shown via ``str``.
        """
        if isinstance(x, float):
            if isnan(x):
                return "nan".ljust(width)
            # scientific notation for small lr values, otherwise 4 decimals
            if abs(x) < 1e-3:
                return f"{x:.2e}".ljust(width)
            return f"{x:.4f}".ljust(width)
        return str(x).ljust(width)


    # Column widths
    W_NAME = 24
    W_STATUS = 12
    W_BATCH = 12
    W_LR = 12
    W_LOSS = 10
    W_ACC = 12
    W_TIME = 12
    W_MEM = 10

    # Prepare table header
    header = (
        f"| {'Trial name'.ljust(W_NAME)}"
        f"| {'status'.ljust(W_STATUS)}"
        f"| {'batch_size'.ljust(W_BATCH)}"
        f"| {'lr'.ljust(W_LR)}"
        f"| {'loss'.ljust(W_LOSS)}"
        f"| {'accuracy'.ljust(W_ACC)}"
        f"| {'epoch_time'.ljust(W_TIME)}"
        f"| {'gpu_mem'.ljust(W_MEM)} |\n"
    )

    # Horizontal rule matching the header width (minus newline and corners)
    separator = "+" + "-"*(len(header)-3) + "+\n"

    def _trial_row(trial):
        """Format one trial as a table row; missing values are shown as '-'."""
        # A trial that never reported may have no result; treat that as empty.
        r = trial.last_result or {}
        return (
            f"| {trial.trial_id.ljust(W_NAME)}"
            f"| {trial.status.ljust(W_STATUS)}"
            f"| {str(trial.config.get('batch_size', '-')).ljust(W_BATCH)}"
            f"| {fmt(trial.config.get('lr', '-'), W_LR)}"
            f"| {fmt(r.get('loss', '-'), W_LOSS)}"
            f"| {fmt(r.get('accuracy', '-'), W_ACC)}"
            f"| {fmt(r.get('epoch_time', '-'), W_TIME)}"
            f"| {fmt(r.get('gpu_mem', '-'), W_MEM)} |\n"
        )

    # Build the ASCII table: rule, header, rule, one row per trial, rule
    table = "".join(
        [separator, header, separator, *(_trial_row(trial) for trial in trials), separator]
    )

    # Save everything to the file
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("=== Ray Tune Summary (ASCII Table) ===\n\n")
        f.write(f"Result logdir: {best_result.path}\n")
        f.write(f"Number of trials: {len(trials)}\n\n")
        f.write(table)
        f.write("\nBest config:\n")
        f.write(str(best_result.config))

    print(f"\nSaved pretty ASCII summary to: {summary_path}")
    print(table)

# Test the Model

In [None]:
if REGULAR_TRAINING:
    # Final evaluation on the held-out test split (metrics logged under "test/")
    test_loss, test_acc = evaluate(model, test_loader, criterion, device, split="test")

    print(
        "\n=== Test Results ===",
        f"Test Loss:     {test_loss:.4f}",
        f"Test Accuracy: {test_acc:.4f}",
        sep="\n",
    )

# Evaluation results

## Epoch plot: Train / Validation accuracy

In [None]:
# Only plot when logging is on and a training run actually happened
# (the baseline has no optimizer and therefore no accuracy history).
if LOGGING_ACTIVE and optimizer is not None:
    epochs_range = range(1, NUM_EPOCHS + 1)

    fig, ax = plt.subplots(figsize=(16, 10))
    for history, curve_label in (
        (train_acc_history, "Train accuracy"),
        (val_acc_history, "Validation accuracy"),
    ):
        ax.plot(epochs_range, history, marker="o", label=curve_label)

    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_title("Accuracy over epochs")
    ax.grid(True)
    ax.legend()

    # Make sure the plots folder exists
    plots_dir = ROOT / "plots"
    plots_dir.mkdir(exist_ok=True)

    # The run name makes the file name unique per configuration
    plot_path = plots_dir / f"{run_name}_accuracy.png"
    fig.savefig(plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    print(f"Saved accuracy plot to: {plot_path}")

    # Also attach the saved image to the W&B run
    wandb.log({"plots/accuracy_curve": wandb.Image(str(plot_path))})

In [None]:
# --- Collect test-set predictions for the analyses below ---
if LOGGING_ACTIVE:
    model.eval()   # inference mode

    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            preds = model(imgs).max(1).indices

            pred_chunks.extend(preds.cpu().numpy())
            label_chunks.extend(labels.cpu().numpy())

    # One entry per test sample, in loader order
    all_preds = np.array(pred_chunks)
    all_labels = np.array(label_chunks)

## Confusion matrix

In [None]:
if LOGGING_ACTIVE:
    # Fix the label set so row/column i always corresponds to class_names[i],
    # even when a class occurs in neither the labels nor the predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=np.arange(len(class_names)))

    # Top-N confused class pairs
    N = 20

    # Off-diagonal copy: correct predictions are not "confusions"
    off = cm.copy()
    np.fill_diagonal(off, 0)

    # (true, predicted) index pairs, sorted by descending confusion count
    indices = np.dstack(
        np.unravel_index(np.argsort(off.ravel())[::-1], off.shape)
    )[0][:N]

    # Pairs that were never confused are skipped so the table holds real errors only
    rows = [
        [class_names[t], class_names[p], int(cm[t, p])]
        for (t, p) in indices
        if cm[t, p] > 0
    ]

    table = wandb.Table(
        columns=["true_class", "predicted_class", "count"],
        data=rows
    )

    wandb.log({"top_confusions": table})

    # Log Confusion Matrix Figure in W&B
    fig = plt.figure(figsize=(30, 30), dpi=300)
    plt.imshow(cm, interpolation="nearest", cmap="viridis")
    plt.colorbar()

    # The title reflects the actual matrix size instead of a hard-coded 101×101
    plt.title(f"Confusion Matrix ({cm.shape[0]}×{cm.shape[1]})", fontsize=28)
    plt.xlabel("Predicted", fontsize=24)
    plt.ylabel("True", fontsize=24)
    plt.xticks([])
    plt.yticks([])

    plt.tight_layout(pad=0.3)

    # Reserve a temporary file name and close the handle before matplotlib
    # writes to that path.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        tmp_path = tmp.name

    try:
        fig.savefig(tmp_path, dpi=300, bbox_inches="tight")

        # Log to W&B
        wandb.log({"confusion_matrix_image": wandb.Image(tmp_path)})
    finally:
        # Always release the figure and delete the temporary file,
        # even when saving or logging fails.
        plt.close(fig)
        os.remove(tmp_path)

## Per-class accuracy (sorted)

In [None]:
# --- Per-class accuracy sorted (best -> worst) ---
if LOGGING_ACTIVE:
    num_classes = len(class_names)

    # Samples per true class, and correctly predicted samples per true class
    true_ids = all_labels.astype(int)
    class_total = np.bincount(true_ids, minlength=num_classes)
    class_correct = np.bincount(true_ids[all_labels == all_preds], minlength=num_classes)

    # Classes without test samples get accuracy 0.0 instead of NaN from a
    # division by zero; NaN values would be sorted to the top as "best".
    # Their "total" column shows 0.
    class_accuracy = np.divide(
        class_correct,
        class_total,
        out=np.zeros(num_classes, dtype=float),
        where=class_total > 0,
    )

    # build sorted list
    sorted_idx = np.argsort(class_accuracy)[::-1]  # best first

    rows = [
        [
            class_names[idx],
            float(class_accuracy[idx]),
            int(class_correct[idx]),
            int(class_total[idx]),
        ]
        for idx in sorted_idx
    ]

    acc_table = wandb.Table(
        columns=["class", "accuracy", "correct", "total"],
        data=rows
    )

    wandb.log({"per_class_accuracy_sorted": acc_table})


## Top 10 most misclassified classes

In [None]:
if LOGGING_ACTIVE:
    # Count, per true class, how many test samples exist and how many were hit
    num_classes   = len(class_names)
    class_correct = np.zeros(num_classes, dtype=int)
    class_total   = np.zeros(num_classes, dtype=int)

    for true_cls, pred_cls in zip(all_labels, all_preds):
        class_total[true_cls] += 1
        class_correct[true_cls] += int(true_cls == pred_cls)

    # Classes ordered by their number of misclassified samples (most first)
    class_errors = class_total - class_correct
    sorted_err_idx = np.argsort(class_errors)[::-1]

    top_k = 10
    top_mis_classes = sorted_err_idx[:top_k]

    # Up to 10 misclassified example images for each of these classes
    images_by_class = {cls: [] for cls in top_mis_classes}

    model.eval()
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            preds = model(imgs).max(1).indices
            wrong = preds != labels

            for img, pred, true in zip(imgs[wrong], preds[wrong], labels[wrong]):
                true_id = true.item()
                bucket = images_by_class.get(true_id)
                # Skip classes outside the top list and classes that are full
                if bucket is None or len(bucket) >= 10:
                    continue
                bucket.append(
                    wandb.Image(
                        img.cpu(),
                        caption=f"true={class_names[true_id]}, pred={class_names[pred.item()]}"
                    )
                )

            # Stop scanning once every class has its 10 images
            if all(len(collected) >= 10 for collected in images_by_class.values()):
                break

    # One W&B image list per class
    log_dict = {
        f"misclassified/{class_names[cls]}": images_by_class[cls]
        for cls in top_mis_classes
    }

    wandb.log(log_dict)

## GradCam for wrong predictions

In [None]:
# The code in this cell was generated by Chat-GPT

if LOGGING_ACTIVE:
    
    # 1. Grad-CAM class
    class GradCAM:
        """Grad-CAM heat maps from the activations and gradients of one layer.

        Hooks registered on ``target_layer`` store its most recent forward
        output and the gradient with respect to that output.
        """

        def __init__(self, model, target_layer):
            self.model = model
            self.target_layer = target_layer

            # Filled in by the hooks below
            self.gradients = None
            self.activations = None

            target_layer.register_forward_hook(self._save_activation)
            target_layer.register_full_backward_hook(self._save_gradient)

        def _save_activation(self, module, inp, out):
            """Forward hook: remember the layer output."""
            self.activations = out

        def _save_gradient(self, module, grad_in, grad_out):
            """Backward hook: remember the gradient w.r.t. the layer output."""
            self.gradients = grad_out[0]

        def __call__(self, x, class_idx):
            """Return the heat map of ``x`` for class ``class_idx``.

            The map is scaled so its maximum is 1; it is all zeros when the
            raw map has no positive entry.
            """
            self.model.zero_grad()
            class_score = self.model(x)[:, class_idx]
            class_score.backward()

            # Channel weights = spatial mean of the gradients
            channel_weights = self.gradients.mean(dim=(2, 3), keepdim=True)
            cam = F.relu((channel_weights * self.activations).sum(dim=1))

            # Shift to a zero minimum, then scale (guarding against a zero maximum)
            cam = cam - cam.min()
            peak = cam.max()
            if peak > 0:
                return cam / peak
            return torch.zeros_like(cam)

    # 2. Layer whose activations are visualised
    # NOTE(review): presumably a convolution inside the last feature stage of
    # EfficientNet-B0 — confirm against the torchvision model definition.
    target_layer = model.features[7][0].block[1][0]

    # Gradients must be enabled on this layer so its backward hook can fire,
    # even when the rest of the model is frozen.
    target_layer.requires_grad_(True)

    gradcam = GradCAM(model, target_layer)

    # 3. Overlay helper
    def overlay_cam(img_tensor, cam):
        """Blend a Grad-CAM heat map over an image.

        Args:
            img_tensor: CPU image tensor in (channels, height, width) layout.
            cam: 2-D heat map array; values are expected in [0, 1] — the
                GradCAM class above produces that range.

        Returns:
            Float array of shape (height, width, 3) in [0, 1]: a 50/50 blend
            of the min-max scaled image and the JET-coloured heat map.
        """
        img = img_tensor.numpy().transpose(1, 2, 0)

        # Min-max scale for display; a constant image would divide by zero
        low, high = img.min(), img.max()
        if high > low:
            img = (img - low) / (high - low)
        else:
            img = np.zeros_like(img)

        cam = cv2.resize(cam, (img.shape[1], img.shape[0]))
        heatmap = cv2.applyColorMap(np.uint8(cam * 255), cv2.COLORMAP_JET)
        # applyColorMap returns BGR; convert to RGB like the colour legend
        # does, otherwise red and blue are swapped in the overlay.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        heatmap = heatmap.astype(np.float32) / 255.0

        overlay = img * 0.5 + heatmap * 0.5
        return np.clip(overlay, 0, 1)

    # Heat Legend (colorbar)
    def generate_colorbar_with_text(height=300, width=60):
        """Render the legend for the Grad-CAM overlays.

        Args:
            height: Height of the legend in pixels.
            width: Width of the colour strip in pixels; the canvas is 120
                pixels wider to make room for the labels.

        Returns:
            A (height, width + 120, 3) float32 RGB image with values in
            [0, 1]: a vertical JET gradient (1.0 at the top, 0.0 at the
            bottom) followed by the labels 'High (1.0)', 'Mid (0.5)' and
            'Low (0.0)' on a white background.
        """
        # Create vertical gradient
        gradient = np.linspace(1, 0, height).reshape(height, 1)
        gradient = np.repeat(gradient, width, axis=1)
        gradient_uint8 = np.uint8(gradient * 255)
        jet = cv2.applyColorMap(gradient_uint8, cv2.COLORMAP_JET)
        jet = cv2.cvtColor(jet, cv2.COLOR_BGR2RGB)

        # White canvas, slightly wider than the strip, to hold the text
        canvas = np.ones((height, width + 120, 3), dtype=np.float32)
        canvas[:, :width] = jet / 255.0

        # Add text using cv2. The colour is black: the canvas is white, so the
        # previous white text (1, 1, 1) was invisible.
        font = cv2.FONT_HERSHEY_SIMPLEX
        text_color = (0, 0, 0)

        cv2.putText(canvas, "High (1.0)", (width + 10, 20), font, 0.5, text_color, 1, cv2.LINE_AA)           # High (red)
        cv2.putText(canvas, "Mid (0.5)", (width + 10, height // 2), font, 0.5, text_color, 1, cv2.LINE_AA)   # Middle
        cv2.putText(canvas, "Low (0.0)", (width + 10, height - 10), font, 0.5, text_color, 1, cv2.LINE_AA)   # Low (blue)

        return canvas

    # Generate + log legend
    colorbar_img = generate_colorbar_with_text()
    wandb.log({"gradcam/color_scale": wandb.Image(colorbar_img)})

    # 4. Grad-CAM on misclassified images
    gradcam_results = []
    max_gradcam_images = 20
    model.eval()

    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Finding the misclassified samples needs no gradients (like every
        # other inference loop in this notebook); only the per-image
        # Grad-CAM call below runs with autograd enabled.
        with torch.no_grad():
            _, preds = model(imgs).max(1)

        mismatch_idx = torch.where(preds != labels)[0]

        for idx in mismatch_idx:
            img = imgs[idx].unsqueeze(0)
            class_idx = preds[idx].item()

            # Heat map for the (wrong) predicted class of this single image
            cam = gradcam(img, class_idx=class_idx)[0].detach().cpu().numpy()

            overlay = overlay_cam(imgs[idx].cpu(), cam)

            gradcam_results.append(
                wandb.Image(
                    overlay,
                    caption=f"true={class_names[labels[idx].item()]}, pred={class_names[class_idx]}"
                )
            )

            if len(gradcam_results) >= max_gradcam_images:
                break

        # Leave the batch loop as well once enough examples are collected
        if len(gradcam_results) >= max_gradcam_images:
            break

    # 5. Log all Grad-CAM images
    wandb.log({"gradcam/misclassified": gradcam_results})

## Misclassified examples

In [None]:
if LOGGING_ACTIVE:
    # Parallel lists: image, predicted class id, true class id
    misclassified_images, misclassified_preds, misclassified_labels = [], [], []

    model.eval()
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            preds = model(imgs).max(1).indices
            wrong = preds != labels

            for img, pred, true in zip(imgs[wrong], preds[wrong], labels[wrong]):
                misclassified_images.append(img.cpu())
                misclassified_preds.append(pred.item())
                misclassified_labels.append(true.item())

                # keep it small: 32 examples are enough
                if len(misclassified_images) >= 32:
                    break

            if len(misclassified_images) >= 32:
                break

    # Log to W&B as captioned images
    example_images = []
    for img, p, t in zip(misclassified_images, misclassified_preds, misclassified_labels):
        example_images.append(
            wandb.Image(
                img,
                caption=f"pred: {class_names[p]}, true: {class_names[t]}"
            )
        )

    wandb.log({"test/misclassified_examples": example_images})

In [None]:
# Close the W&B run that the logging setup opened (only exists when logging is active).
if LOGGING_ACTIVE:
    wandb.finish()