# Settings script

In [58]:
# --- MODEL SELECTION ---
MODEL_SCRATCH = False     # No pretrained weights, trained from scratch
MODEL_BASELINE = False    # Pretrained EfficientNet, but NOT trained on Food-101
MODEL_FINETUNE = True     # Pretrained EfficientNet, THEN fine-tuned on Food-101

# --- OTHER SETTINGS ---
RUN_RAYTUNE = False
USE_FULL_DATASET = False    # If false, use small dataset instead (10K pictures, 100 per class)
NUM_EPOCHS = 10
MANUAL_DATA_AUGMENTATION = False
LOGGING_FILENAME = "efficientnet_b0_food101_run_9"

# --- Sanity check: EXACTLY one model active ---
assert sum([MODEL_BASELINE, MODEL_SCRATCH, MODEL_FINETUNE]) == 1, \
    "Exactly one model must be selected."

# Install dependencies

In [59]:
#pip install -r requirements.txt

# Imports

In [None]:
# Standard library
# Standard library
import os
import pathlib
import datetime

# ML
import numpy as np
import pathlib

# ML
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# TorchVision
from torchvision import datasets, transforms
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from torchvision.transforms import AutoAugment, AutoAugmentPolicy
from torchvision.transforms import AutoAugment, AutoAugmentPolicy

# W&B
import wandb
import cv2
import matplotlib.pyplot as plt

# Ray Tune
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler


# Set GPU variable

In [61]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

Device: cuda


# Data paths

In [62]:
ROOT = pathlib.Path().resolve()

if USE_FULL_DATASET:
    IMAGE_DIR = ROOT / "data" / "food-101-big"
else:
    IMAGE_DIR = ROOT / "data" / "food-101-small"

print("Using IMAGE_DIR =", IMAGE_DIR)

if not IMAGE_DIR.exists():
    raise FileNotFoundError(f"Missing dataset folder:\n{IMAGE_DIR}")

Using IMAGE_DIR = C:\Users\EG\OneDrive - ITU\Kandidat\3. semester\AML - Computer Vision\IMTL\data\food-101-small


# Data preprocessing

In [63]:
img_size = 224 # Needed width and height dimensions for EfficientNet

# Use AutoAugment for data augmentation
if MANUAL_DATA_AUGMENTATION:
    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor()
    ])    

else:
    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        AutoAugment(policy=AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
    ])

test_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor()
])

full_dataset = datasets.ImageFolder(IMAGE_DIR, transform=train_transform, allow_empty=True)
class_names = full_dataset.classes
num_classes = len(class_names)

print("Total images:", len(full_dataset))
print("Classes:", num_classes)

Total images: 10100
Classes: 101


# Split data into sets

In [64]:
train_size = int(0.70 * len(full_dataset))
val_size   = int(0.15 * len(full_dataset))
train_size = int(0.70 * len(full_dataset))
val_size   = int(0.15 * len(full_dataset))
test_size  = len(full_dataset) - train_size - val_size

train_ds, val_ds, test_ds = random_split(full_dataset, [train_size, val_size, test_size])

# Create DataLoaders

In [65]:
val_ds.dataset.transform = test_transform
test_ds.dataset.transform = test_transform

batch_size_default = 32

train_loader = DataLoader(train_ds, batch_size=batch_size_default, shuffle=True, num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size_default, shuffle=False, num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=batch_size_default, shuffle=False)

# Define Model

In [66]:
# Pretrained model but NOT trained (baseline)
def build_model_baseline(lr, num_classes):
    from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
    import torch.nn as nn

    model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

    # Freeze all parameters → no training
    for p in model.parameters():
        p.requires_grad = False

    optimizer = None 
    criterion = nn.CrossEntropyLoss()

    return model, optimizer, criterion


# Train from scratch
def build_model_scratch(lr, num_classes):
    from torchvision.models import efficientnet_b0
    import torch.nn as nn
    import torch

    model = efficientnet_b0(weights=None)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    return model, optimizer, criterion


# Pretrained → fine-tuned
def build_model_finetune(lr, num_classes):
    from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
    import torch.nn as nn
    import torch

    model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    return model, optimizer, criterion



# def build_model_1(lr, num_classes):
#     """EfficientNet-B0 model"""
#     weights = EfficientNet_B0_Weights.IMAGENET1K_V1
#     model = efficientnet_b0(weights=weights)

#     model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

#     for p in model.features.parameters():
#         p.requires_grad = False

#     model = model.to(device)
#     optimizer = optim.Adam(model.parameters(), lr=lr)
#     criterion = nn.CrossEntropyLoss()

#     return model, optimizer, criterion


# def build_model_2(lr, num_classes):
#     """Placeholder model — intentionally blank"""
#     raise NotImplementedError("MODEL_2 not implemented yet.")

In [67]:
if MODEL_BASELINE:
    build_model = build_model_baseline
elif MODEL_SCRATCH:
    build_model = build_model_scratch
elif MODEL_FINETUNE:
    build_model = build_model_finetune


# Define training and evaluation functions

## Base training

In [142]:
def train_one_epoch(model, loader, optimizer, criterion, device, epoch, log_interval=100):
def train_one_epoch(model, loader, optimizer, criterion, device, epoch, log_interval=100):
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    total_loss, correct, total = 0.0, 0, 0

    pbar = tqdm(loader, desc=f"Training Epoch {epoch+1}", leave=False)

    for batch_idx, (imgs, labels) in enumerate(pbar):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

        avg_loss_so_far = total_loss / total
        avg_acc_so_far = correct / total
        avg_loss_so_far = total_loss / total
        avg_acc_so_far = correct / total

        pbar.set_postfix({
            "loss": f"{avg_loss_so_far:.4f}",
            "acc": f"{100 * avg_acc_so_far:.2f}%"
        })

        # Optional: batch-level logging to W&B
        if (batch_idx + 1) % log_interval == 0:
            wandb.log(
                {
                    "batch/train_loss": loss.item(),
                    "batch/lr": optimizer.param_groups[0]["lr"],
                },
                step=epoch
            )

    avg_loss = total_loss / total
    avg_acc = correct / total

    # Epoch-level logging to W&B
    wandb.log({
        "train/loss": avg_loss,
        "train/accuracy": avg_acc,
        "train/error_rate": 1 - avg_acc,
    }, step=epoch)

    return avg_loss, avg_acc

def evaluate(model, loader, criterion, device, epoch=None, split="val"):
    """
    Generic evaluation for val/test.
    split: "val" or "test" (used in W&B metric names).
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0

    desc = f"Evaluating ({split})"
    if epoch is not None:
        desc += f" Epoch {epoch}"
    total_loss, correct, total = 0.0, 0, 0

    desc = f"Evaluating ({split})"
    if epoch is not None:
        desc += f" Epoch {epoch}"

    with torch.no_grad():
        pbar = tqdm(loader, desc=desc, leave=False)

        for imgs, labels in pbar:
        pbar = tqdm(loader, desc=desc, leave=False)

        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)


            outputs = model(imgs)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * imgs.size(0)
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)

            avg_loss_so_far = total_loss / total
            avg_acc_so_far = correct / total

            pbar.set_postfix({
                "loss": f"{avg_loss_so_far:.4f}",
                "acc": f"{100 * avg_acc_so_far:.2f}%"
            })

    avg_loss = total_loss / total
    avg_acc = correct / total

    # Log to W&B if epoch is known (for val) or just once (for test)
    metric_prefix = f"{split}"
    log_data = {
        f"{metric_prefix}/loss": avg_loss,
        f"{metric_prefix}/accuracy": avg_acc,
        f"{metric_prefix}/error_rate": 1 - avg_acc,
    }
    if epoch is not None:
        log_data["epoch"] = epoch

    wandb.log(log_data, step=epoch if epoch is not None else wandb.run.step)

    return avg_loss, avg_acc

## RayTune training (Hyperparameter Optimization)

In [69]:
def train_one_epoch_raytune(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, correct, total = 0, 0, 0

    for imgs, labels in tqdm(loader, desc="Training", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    return total_loss / total, correct / total

def train_one_epoch_raytune(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, correct, total = 0, 0, 0

    for imgs, labels in tqdm(loader, desc="Training", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    return total_loss / total, correct / total

def tune_train(config):
    batch_size = config["batch_size"]
    lr = config["lr"]

    # local transforms
    train_ds.dataset.transform = train_transform
    val_ds.dataset.transform   = test_transform

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)

    model, optimizer, criterion = build_model(lr, num_classes)

    for epoch in range(config["epochs"]):
        train_one_epoch_raytune(model, train_loader, optimizer, criterion, device)
        train_one_epoch_raytune(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        tune.report({"loss": float(val_loss), "accuracy": float(val_acc)})

# Run RayTune (If toggled on)

In [70]:
if RUN_RAYTUNE:
    os.environ["RAY_DISABLE_METRICS_EXPORT"] = "1"

    search_space = {
        "lr": tune.loguniform(1e-5, 1e-2),
        "batch_size": tune.choice([16, 32, 64]),
        "epochs": 3
    }

    scheduler = ASHAScheduler(metric="accuracy", mode="max")
    reporter  = CLIReporter(metric_columns=["loss", "accuracy"])

    tuner = tune.Tuner(
        tune.with_resources(
            tune_train,
            resources={"cpu": 4, "gpu": 1 if torch.cuda.is_available() else 0},
        ),
        param_space=search_space,
        tune_config=tune.TuneConfig(
            scheduler=scheduler,
            num_samples=6,
        ),
        run_config=tune.RunConfig(progress_reporter=reporter),
    )

    results = tuner.fit()
    best = results.get_best_result(metric="accuracy", mode="max")

    print("Best config:", best.config)

    best_config = best.config

# Initialize model

In [145]:
if not RUN_RAYTUNE:
    model, optimizer, criterion = build_model(1e-3, num_classes)

# Logging

In [None]:
if not RUN_RAYTUNE:
    # W&B init
    wandb.login()  # will prompt you the first time in this environment

    batch_size = 32
    learning_rate = 1e-3

    def get_run_name(lr, batch_size):
        lr_string = str(lr)
        batch_size_string = str(batch_size)
        date_string = str(datetime.datetime.now())
        return f"lr{lr_string}_bs{batch_size_string}_{date_string}"

    wandb_config = {
        "model_name": "efficientnet_b0",
        "img_size": img_size,
        "batch_size": batch_size,
        "optimizer": "Adam",
        "learning_rate": learning_rate,
        "train_size": len(train_ds),
        "val_size": len(val_ds),
        "test_size": len(test_ds),
        "num_classes": num_classes,
        "transfer_learning": True,      
        "feature_extractor_frozen": True,
        "augmentation": "AutoAugment_IMAGENET"
    }

    run = wandb.init(
        project="IMTL",
        config=wandb_config,
        name=get_run_name(learning_rate, batch_size),
    )

    wandb.watch(model, log="all", log_freq=100)

# Initialize model

In [71]:
if not RUN_RAYTUNE:
    model, optimizer, criterion = build_model(1e-3, num_classes)

    model = model.to(device)
    criterion = criterion.to(device)


--- Epoch 1/5 ---


                                                                                            

Train Loss: 3.8417, Acc: 0.2143
Val   Loss: 3.1640, Acc: 0.3426

--- Epoch 2/5 ---


                                                                                            

Train Loss: 2.6542, Acc: 0.4614
Val   Loss: 2.6948, Acc: 0.4026

--- Epoch 3/5 ---


                                                                                            

Train Loss: 2.1690, Acc: 0.5366
Val   Loss: 2.4803, Acc: 0.4271

--- Epoch 4/5 ---


                                                                                            

Train Loss: 1.9051, Acc: 0.5777
Val   Loss: 2.3912, Acc: 0.4422

--- Epoch 5/5 ---


                                                                                            

Train Loss: 1.6965, Acc: 0.6232
Val   Loss: 2.3381, Acc: 0.4462




# Logging

In [None]:
if not RUN_RAYTUNE:
    # W&B init
    wandb.login()  # will prompt you the first time in this environment

    batch_size = 32
    learning_rate = 1e-3

    def get_run_name(lr, batch_size):
        lr_string = str(lr)
        batch_size_string = str(batch_size)
        date_string = str(datetime.datetime.now())
        return f"lr{lr_string}_bs{batch_size_string}_{date_string}"

    wandb_config = {
        "model_name": "efficientnet_b0",
        "img_size": img_size,
        "batch_size": batch_size,
        "optimizer": "Adam",
        "learning_rate": learning_rate,
        "train_size": len(train_ds),
        "val_size": len(val_ds),
        "test_size": len(test_ds),
        "num_classes": num_classes,
        "transfer_learning": True,      
        "feature_extractor_frozen": True,
        "augmentation": "AutoAugment_IMAGENET"
    }

    run = wandb.init(
        project="IMTL",
        config=wandb_config,
        name=get_run_name(learning_rate, batch_size),
    )

    wandb.watch(model, log="all", log_freq=100)

# Train the Model

In [73]:
if optimizer is None:
    print("\n=== Baseline model selected — skipping training ===")
    # Run validation ONCE so you still see how bad the logits are
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

else:
    epochs = NUM_EPOCHS
    for epoch in range(epochs):
        print(f"\n--- Epoch {epoch+1}/{epochs} ---")
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch)
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch)
        val_loss, val_acc     = evaluate(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")



--- Epoch 1/10 ---


                                                                                            

Train Loss: 3.0849, Acc: 0.2813
Val   Loss: 2.3874, Acc: 0.3974

--- Epoch 2/10 ---


                                                                                            

Train Loss: 1.7448, Acc: 0.5436
Val   Loss: 1.9461, Acc: 0.5162

--- Epoch 3/10 ---


                                                                                            

Train Loss: 1.0933, Acc: 0.6999
Val   Loss: 1.9143, Acc: 0.5366

--- Epoch 4/10 ---


                                                                                            

Train Loss: 0.6933, Acc: 0.8013
Val   Loss: 1.9575, Acc: 0.5498

--- Epoch 5/10 ---


                                                                                            

Train Loss: 0.5028, Acc: 0.8491
Val   Loss: 2.0760, Acc: 0.5518

--- Epoch 6/10 ---


                                                                                            

Train Loss: 0.3579, Acc: 0.8965
Val   Loss: 2.1434, Acc: 0.5228

--- Epoch 7/10 ---


                                                                                             

Train Loss: 0.3018, Acc: 0.9074
Val   Loss: 2.2364, Acc: 0.5452

--- Epoch 8/10 ---


                                                                                            

Train Loss: 0.2560, Acc: 0.9231
Val   Loss: 2.3704, Acc: 0.5380

--- Epoch 9/10 ---


                                                                                            

Train Loss: 0.2504, Acc: 0.9290
Val   Loss: 2.5851, Acc: 0.5380

--- Epoch 10/10 ---


                                                                                            

Train Loss: 0.2454, Acc: 0.9245
Val   Loss: 2.5062, Acc: 0.5281




# Test the Model

In [None]:
if not RUN_RAYTUNE:
    test_loss, test_acc = evaluate(model, test_loader, criterion, device, split="test")

    print("\n=== Test Results ===")
    print(f"Test Loss:     {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")

    

                                                                                          


=== Test Results ===
Test Loss:     2.4314
Test Accuracy: 0.5347




# Evaluation results

In [75]:
# --- Evaluation + Top-N Confusion Logging ---

model.eval()   # IMPORTANT: inference mode

all_preds = []
all_labels = []

with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        outputs = model(imgs)
        _, preds = outputs.max(1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

## Confusion matrix

In [76]:
# --- Confusion matrix ---
cm = confusion_matrix(all_labels, all_preds)

# --- Top-N confused class pairs (better than 101×101 heatmap) ---
N = 20  # adjust as needed

off = cm.copy()
np.fill_diagonal(off, 0)

indices = np.dstack(np.unravel_index(
    np.argsort(off.ravel())[::-1], off.shape
))[0][:N]

rows = []
for (t, p) in indices:
    rows.append([
        class_names[t],
        class_names[p],
        int(cm[t, p])
    ])

table = wandb.Table(
    columns=["true_class", "predicted_class", "count"],
    data=rows
)

wandb.log({"top_confusions": table})

## Per-class accuracy (sorted)

In [77]:
# --- Per-class accuracy sorted (best -> worst) ---

num_classes = len(class_names)
class_correct = np.zeros(num_classes, dtype=int)
class_total = np.zeros(num_classes, dtype=int)

for t, p in zip(all_labels, all_preds):
    class_total[t] += 1
    if t == p:
        class_correct[t] += 1

class_accuracy = (class_correct / class_total)

# build sorted list
sorted_idx = np.argsort(class_accuracy)[::-1]  # best first

rows = []
for idx in sorted_idx:
    rows.append([
        class_names[idx],
        float(class_accuracy[idx]),
        int(class_correct[idx]),
        int(class_total[idx])
    ])

acc_table = wandb.Table(
    columns=["class", "accuracy", "correct", "total"],
    data=rows
)

wandb.log({"per_class_accuracy_sorted": acc_table})


## Top 10 most misclassified classes

In [78]:
# --- Find top misclassified classes ---

num_classes = len(class_names)
class_correct = np.zeros(num_classes, dtype=int)
class_total   = np.zeros(num_classes, dtype=int)

for t, p in zip(all_labels, all_preds):
    class_total[t] += 1
    if t == p:
        class_correct[t] += 1

class_errors = class_total - class_correct

# sort classes by number of misclassifications
sorted_err_idx = np.argsort(class_errors)[::-1]

top_k = 10
top_mis_classes = sorted_err_idx[:top_k]

# --- Collect misclassified images for these classes ---

images_by_class = {cls: [] for cls in top_mis_classes}

model.eval()
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        outputs = model(imgs)
        _, preds = outputs.max(1)

        mismatch = preds != labels

        for img, pred, true in zip(imgs[mismatch], preds[mismatch], labels[mismatch]):
            t = true.item()
            if t in images_by_class and len(images_by_class[t]) < 10:
                images_by_class[t].append(
                    wandb.Image(
                        img.cpu(),
                        caption=f"true={class_names[t]}, pred={class_names[pred.item()]}"
                    )
                )

        # stop early if all 10 classes reached 10 images
        if all(len(v) >= 10 for v in images_by_class.values()):
            break

# --- Log grouped images to W&B ---

log_dict = {}
for cls in top_mis_classes:
    name = class_names[cls]
    log_dict[f"misclassified/{name}"] = images_by_class[cls]

wandb.log(log_dict)


## GradCam for wrong predictions

In [79]:
# ----------------------------
# 1. Grad-CAM class
# ----------------------------
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer

        self.gradients = None
        self.activations = None

        target_layer.register_forward_hook(self._save_activation)
        target_layer.register_backward_hook(self._save_gradient)

    def _save_activation(self, module, inp, out):
        self.activations = out

    def _save_gradient(self, module, grad_in, grad_out):
        self.gradients = grad_out[0]

    def __call__(self, x, class_idx):
        self.model.zero_grad()

        out = self.model(x)
        loss = out[:, class_idx]
        loss.backward()

        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1)

        cam = F.relu(cam)
        cam = cam - cam.min()
        cam = cam / cam.max()

        return cam


# ----------------------------
# 2. Correct EfficientNet-B0 target layer
# ----------------------------
target_layer = model.features[7][0].block[1][0]

# enable gradients only on this layer
for p in target_layer.parameters():
    p.requires_grad = True

gradcam = GradCAM(model, target_layer)


# ----------------------------
# 3. Overlay helper
# ----------------------------
def overlay_cam(img_tensor, cam):
    img = img_tensor.numpy().transpose(1, 2, 0)
    img = (img - img.min()) / (img.max() - img.min())

    cam = cv2.resize(cam, (img.shape[1], img.shape[0]))
    heatmap = cv2.applyColorMap(np.uint8(cam * 255), cv2.COLORMAP_JET)
    heatmap = heatmap.astype(np.float32) / 255.0

    overlay = img * 0.5 + heatmap * 0.5
    overlay = np.clip(overlay, 0, 1)
    return overlay


# ----------------------------
# 3.5  SINGLE HEAT LEGEND (colorbar)
# ----------------------------
# ----------------------------
# 3.5  SINGLE HEAT LEGEND (colorbar with text)
# ----------------------------
def generate_colorbar_with_text(height=300, width=60):
    """
    Returns a (H,W,3) numpy image containing:
    - JET colormap
    - Labels 'High', 'Mid', 'Low'
    - Numeric ticks 1.0, 0.5, 0.0
    """
    # Create vertical gradient
    gradient = np.linspace(1, 0, height).reshape(height, 1)
    gradient = np.repeat(gradient, width, axis=1)
    gradient_uint8 = np.uint8(gradient * 255)
    jet = cv2.applyColorMap(gradient_uint8, cv2.COLORMAP_JET)
    jet = cv2.cvtColor(jet, cv2.COLOR_BGR2RGB)

    # Create a canvas for text (slightly wider)
    canvas = np.ones((height, width + 120, 3), dtype=np.float32)
    canvas[:, :width] = jet / 255.0

    # Add text using cv2
    font = cv2.FONT_HERSHEY_SIMPLEX

    cv2.putText(canvas, "High (1.0)", (width + 10, 20), font, 0.5, (1, 1, 1), 1, cv2.LINE_AA)           # High (red)
    cv2.putText(canvas, "Mid (0.5)", (width + 10, height // 2), font, 0.5, (1, 1, 1), 1, cv2.LINE_AA)   # Middle
    cv2.putText(canvas, "Low (0.0)", (width + 10, height - 10), font, 0.5, (1, 1, 1), 1, cv2.LINE_AA)   # Low (blue)

    return canvas

# Generate + log legend
colorbar_img = generate_colorbar_with_text()
wandb.log({"gradcam/color_scale": wandb.Image(colorbar_img)})


# ----------------------------
# 4. Grad-CAM on misclassified images
# ----------------------------
gradcam_results = []
model.eval()

for imgs, labels in test_loader:
    imgs = imgs.to(device)
    labels = labels.to(device)

    outputs = model(imgs)
    _, preds = outputs.max(1)

    mismatch = preds != labels
    mismatch_idx = torch.where(mismatch)[0]

    for idx in mismatch_idx:
        img = imgs[idx].unsqueeze(0)
        class_idx = preds[idx].item()

        cam = gradcam(img, class_idx=class_idx)[0].detach().cpu().numpy()

        overlay = overlay_cam(imgs[idx].cpu(), cam)

        gradcam_results.append(
            wandb.Image(
                overlay,
                caption=f"true={class_names[labels[idx].item()]}, pred={class_names[preds[idx].item()]}"
            )
        )

        if len(gradcam_results) >= 20:
            break

    if len(gradcam_results) >= 20:
        break

# ----------------------------
# 5. Log all Grad-CAM images
# ----------------------------
wandb.log({"gradcam/misclassified": gradcam_results})

  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  heatmap = cv2.applyColorMap(np.uint8(cam * 255), cv2.COLORMAP_JET)


## Misclassified examples

In [80]:
misclassified_images = []
misclassified_preds = []
misclassified_labels = []

model.eval()
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        outputs = model(imgs)
        _, preds = outputs.max(1)

        mismatch = preds != labels
        if mismatch.any():
            for img, pred, true in zip(imgs[mismatch], preds[mismatch], labels[mismatch]):
                misclassified_images.append(img.cpu())
                misclassified_preds.append(pred.cpu().item())
                misclassified_labels.append(true.cpu().item())

                # keep it small
                if len(misclassified_images) >= 32:
                    break
        if len(misclassified_images) >= 32:
            break

# Log to W&B as images
wandb.log({
    "test/misclassified_examples": [
        wandb.Image(
            img,
            caption=f"pred: {class_names[p]}, true: {class_names[t]}"
        )
        for img, p, t in zip(misclassified_images, misclassified_preds, misclassified_labels)
    ]
})


In [81]:
wandb.finish()

0,1
batch/lr,▁▁▁▁▁▁▁▁▁▁
batch/train_loss,█▆▅▄▃▂▁▁▂▁
train/accuracy,▁▄▆▇▇█████
train/error_rate,█▅▃▂▂▁▁▁▁▁
train/loss,█▅▃▂▂▁▁▁▁▁
val/accuracy,▁▆▇██▇█▇▇▇
val/error_rate,█▃▂▁▁▂▁▂▂▂
val/loss,▆▁▁▁▃▃▄▆█▆

0,1
batch/lr,0.001
batch/train_loss,0.27734
train/accuracy,0.92447
train/error_rate,0.07553
train/loss,0.2454
val/accuracy,0.53465
val/error_rate,0.46535
val/loss,2.4314
