In [None]:
# Step 0: Create output folders for all stages, including interpretability results

import os

KAGGLE_WORKING = '/kaggle/working'
OUTPUT_ROOT = os.path.join(KAGGLE_WORKING, "outputs")

# One sub-folder of OUTPUT_ROOT per pipeline stage.
FEATURES_DIR   = os.path.join(OUTPUT_ROOT, "features")            # masked regions extracted in Step 4
SPLITS_DIR     = os.path.join(OUTPUT_ROOT, "splits")              # train/val/test copies (Step 6)
MODELS_DIR     = os.path.join(OUTPUT_ROOT, "models")              # .pth checkpoints
LOGS_DIR       = os.path.join(OUTPUT_ROOT, "logs")                # training histories, ablation CSV
METRICS_DIR    = os.path.join(OUTPUT_ROOT, "metrics")             # ablation summary tables
VIS_DIR        = os.path.join(OUTPUT_ROOT, "visualizations")      # general plots
BEST_MODEL_DIR = os.path.join(OUTPUT_ROOT, "best_model_results")  # best-model metrics and figures
INTERP_DIR     = os.path.join(OUTPUT_ROOT, "interpretability")    # CAM overlays
STATS_DIR      = os.path.join(OUTPUT_ROOT, "stats")               # significance-test tables

# exist_ok keeps this cell safe to re-run.
for stage_dir in (OUTPUT_ROOT, FEATURES_DIR, SPLITS_DIR, MODELS_DIR, LOGS_DIR,
                  METRICS_DIR, VIS_DIR, BEST_MODEL_DIR, INTERP_DIR, STATS_DIR):
    os.makedirs(stage_dir, exist_ok=True)

# Per-split class folders; populated in Step 6.
TRAIN_DIR, VAL_DIR, TEST_DIR = (os.path.join(SPLITS_DIR, split) for split in ("train", "val", "test"))

In [None]:
# Step 1: Set Global Seed for Reproducibility

import random
import numpy as np
import torch

def set_seed(seed=42):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch on the CPU and on every
    GPU, in that order.

    Args:
        seed: integer seed shared by all libraries.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning speed for reproducible convolution algorithm choices.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(42)

In [None]:
# Step 2: Install essential libraries for training, evaluation, and model analysis (Kaggle: may skip if preinstalled)
!pip install torch torchvision
!pip install timm transformers roboflow pycocotools
!pip install grad-cam torchinfo

In [None]:
# Step 3: Download the COCO-format segmentation dataset from Roboflow

from roboflow import Roboflow
import os
import json

# Never hard-code API keys in a notebook: anyone who can read it (or a shared copy)
# can use the key. Any key that has appeared here in plain text should be rotated.
# Provide it via the ROBOFLOW_API_KEY environment variable, or type it when prompted.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")
if not ROBOFLOW_API_KEY:
    from getpass import getpass
    ROBOFLOW_API_KEY = getpass("Roboflow API key: ")

rf = Roboflow(api_key=ROBOFLOW_API_KEY)
project = rf.workspace("urban-lake-wastef").project("another_approach_try")
version = project.version(4)
dataset = version.download("coco-segmentation")

# Only the 'train' export of the dataset is used; Step 6 re-splits it into train/val/test.
DATA_ROOT = dataset.location
TRAIN_JSON = os.path.join(DATA_ROOT, 'train', '_annotations.coco.json')
IMG_DIR = os.path.join(DATA_ROOT, 'train')

In [None]:
# Step 4: Extract foreground objects from COCO polygon masks and save them as full-size masked images (background zeroed, not cropped) in FEATURES_DIR

from PIL import Image, ImageDraw
import numpy as np

with open(TRAIN_JSON) as f:
    ann_data = json.load(f)

cat_map = {c['id']: c['name'] for c in ann_data['categories']}
# O(1) image lookup instead of scanning ann_data['images'] once per annotation.
images_by_id = {img['id']: img for img in ann_data['images']}
MIN_MASK_PIXELS = 100  # masks covering fewer pixels than this are skipped

# Decoded pixels of the most recently used image, reused by consecutive annotations of it.
cached_image_id, cached_image = None, None

for ann in ann_data['annotations']:
    img_info = images_by_id[ann['image_id']]
    width, height = img_info['width'], img_info['height']

    seg = ann['segmentation']
    if not isinstance(seg, list):
        # COCO encodes crowd annotations as RLE dicts; this polygon rasteriser cannot use them.
        print(f"⚠️ Skipping annotation {ann['id']}: non-polygon (RLE) segmentation")
        continue

    # Rasterise all polygons of this annotation into a single binary mask (their union).
    mask_img = Image.new('L', (width, height), 0)
    drawer = ImageDraw.Draw(mask_img)
    for poly in seg:
        if len(poly) < 6:  # fewer than 3 (x, y) vertices cannot enclose an area
            continue
        pts = np.array(poly).reshape(-1, 2)
        drawer.polygon([tuple(p) for p in pts], outline=1, fill=1)
    mask = np.array(mask_img, dtype=np.uint8)

    if mask.sum() < MIN_MASK_PIXELS:  # Skip very small masks
        continue

    if img_info['id'] != cached_image_id:
        # `with` closes the file handle once the pixels have been copied out.
        with Image.open(os.path.join(IMG_DIR, img_info['file_name'])) as src:
            cached_image = np.array(src.convert('RGB'))
        cached_image_id = img_info['id']

    # Zero the background; the saved image keeps the full original size.
    region_img = Image.fromarray(cached_image * mask[:, :, None])

    label = cat_map[ann['category_id']]
    out_dir = os.path.join(FEATURES_DIR, label)
    os.makedirs(out_dir, exist_ok=True)
    base = os.path.splitext(img_info['file_name'])[0]
    # File name "<image stem>_<annotation id>.png"; Step 5 relies on this to find the original.
    out_path = os.path.join(out_dir, f"{base}_{ann['id']}.png")
    region_img.save(out_path)

In [None]:
# Step 5: Visualize 10 original vs masked crops for journal-quality reporting and save in VIS_DIR

import matplotlib.pyplot as plt
from PIL import Image
import os
import random

def show_before_after_masked(original_dir, masked_dir, split_name, num_samples=10, vis_dir=None):
    """Plot randomly chosen (original image, masked region) pairs side by side.

    Each sample picks a random non-empty class folder of ``masked_dir``, then a random
    file in it (with replacement). Masked files are named
    ``<original stem>_<annotation id>.png`` by Step 4, so the original is found in
    ``original_dir`` by dropping the last ``_`` token and trying common image extensions.

    Args:
        original_dir: folder containing the source images.
        masked_dir: folder with one sub-folder of masked images per class.
        split_name: label used in the printout and in saved file names.
        num_samples: number of pairs to draw.
        vis_dir: if given, each figure is also saved there as a PNG.
    """
    print(f"\n🔍 Showing {num_samples} {split_name} images: original vs masked")

    # Gather the non-empty class folders up front so sampling always terminates,
    # even when masked_dir has no usable images. Sorting keeps seeded draws reproducible.
    class_files = {}
    for class_name in sorted(os.listdir(masked_dir)):
        class_dir = os.path.join(masked_dir, class_name)
        if os.path.isdir(class_dir):
            files = sorted(os.listdir(class_dir))
            if files:
                class_files[class_name] = files
    if not class_files:
        print(f"⚠️ No masked images found under {masked_dir}")
        return

    class_names = list(class_files)
    selected_images = []
    for _ in range(num_samples):
        chosen_class = random.choice(class_names)
        selected_images.append((chosen_class, random.choice(class_files[chosen_class])))

    possible_exts = (".jpg", ".jpeg", ".png")
    for idx, (class_name, file_name) in enumerate(selected_images):
        masked_path = os.path.join(masked_dir, class_name, file_name)
        original_basename_base = "_".join(file_name.split("_")[:-1])
        candidates = (os.path.join(original_dir, original_basename_base + ext) for ext in possible_exts)
        original_path = next((c for c in candidates if os.path.exists(c)), None)
        if original_path is None:
            print(f"⚠️ Original not found for: {original_basename_base}")
            continue

        with Image.open(original_path) as src:
            original_img = src.convert("RGB")
        with Image.open(masked_path) as src:
            masked_img = src.convert("RGB")

        fig, axs = plt.subplots(1, 2, figsize=(8, 4))
        axs[0].imshow(original_img)
        axs[0].set_title("Original")
        axs[0].axis("off")
        axs[1].imshow(masked_img)
        axs[1].set_title("Masked (Extracted)")
        axs[1].axis("off")
        fig.suptitle(f"🟢 Class: {class_name} | 📄 File: {file_name}", fontsize=13)
        fig.tight_layout()
        if vis_dir:
            fig.savefig(os.path.join(vis_dir, f"show_before_after_{split_name}_{idx+1}.png"))
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across iterations

show_before_after_masked(
    original_dir=os.path.join(DATA_ROOT, "train"),
    masked_dir=FEATURES_DIR,
    split_name="Train Set",
    vis_dir=VIS_DIR
)

In [None]:
# Step 6: Split dataset into 60% train, 20% val, 20% test in SPLITS_DIR

import shutil
from sklearn.model_selection import train_test_split

SPLIT_NAMES = ('train', 'val', 'test')

for split in SPLIT_NAMES:
    split_root = os.path.join(SPLITS_DIR, split)
    # Start from empty split folders so re-running this cell never mixes in stale files.
    shutil.rmtree(split_root, ignore_errors=True)
    os.makedirs(split_root, exist_ok=True)


def _split_60_20_20(files, seed=42):
    """Split a file list into (train, val, test) ~60/20/20, deterministically for `seed`.

    Lists with fewer than 3 items cannot be split twice by train_test_split (each side
    needs at least one sample), so they go entirely to train.
    """
    if len(files) < 3:
        return list(files), [], []
    train_files, temp_files = train_test_split(files, test_size=0.4, random_state=seed)
    val_files, test_files = train_test_split(temp_files, test_size=0.5, random_state=seed)
    return train_files, val_files, test_files


# NOTE(review): the split is per file, so different objects cut from the same source image
# (files "<image>_<annId>.png") can land in different splits. If such crops are correlated,
# consider a group-aware split on the "<image>" prefix to avoid train/test leakage.
for class_name in sorted(os.listdir(FEATURES_DIR)):
    class_path = os.path.join(FEATURES_DIR, class_name)
    if not os.path.isdir(class_path):
        continue
    # Sorted so the seeded split does not depend on the filesystem's listing order.
    files = sorted(os.listdir(class_path))
    if len(files) < 3:
        print(f"⚠️ Class '{class_name}' has only {len(files)} file(s); all go to train.")
    for split, split_files in zip(SPLIT_NAMES, _split_60_20_20(files)):
        # Created even when empty so every split has the same class folders (and label ids).
        split_class_dir = os.path.join(SPLITS_DIR, split, class_name)
        os.makedirs(split_class_dir, exist_ok=True)
        for f in split_files:
            shutil.copy2(os.path.join(class_path, f), os.path.join(split_class_dir, f))

In [None]:
# Step 7: Define a PyTorch Dataset class with augmentation and weighted sampling

from torch.utils.data import Dataset, WeightedRandomSampler
from torchvision import transforms
from collections import defaultdict
from PIL import Image

class WasteRegionDataset(Dataset):
    """Image-folder dataset laid out as ``root_dir/<class_name>/<image files>``.

    Args:
        root_dir: directory whose (non-hidden) sub-directories are the class names.
        transform: callable applied to each PIL image. Defaults to the training pipeline
            below, which includes random flips, rotation and colour jitter.
        label2id: optional fixed class-name -> index mapping (e.g. ``train_ds.label2id``)
            so several splits share identical indices. When omitted, indices follow the
            sorted sub-directory names.

    Attributes:
        samples: list of ``(path, label_index)`` in sorted class / file order.
        label2id: class-name -> index mapping used for ``samples``.
        sample_weights: per-sample ``1 / count(label)``, for a WeightedRandomSampler.

    Raises:
        ValueError: if ``label2id`` is given but lacks a class folder found in root_dir.
    """

    def __init__(self, root_dir, transform=None, label2id=None):
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.03),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # Only visible sub-directories are classes; stray files and hidden folders
        # (e.g. ".ipynb_checkpoints") are ignored.
        classes = sorted(
            entry for entry in os.listdir(root_dir)
            if not entry.startswith('.') and os.path.isdir(os.path.join(root_dir, entry))
        )
        if label2id is None:
            label2id = {c: i for i, c in enumerate(classes)}
        else:
            unknown = [c for c in classes if c not in label2id]
            if unknown:
                raise ValueError(f"Classes {unknown} in {root_dir} are missing from label2id")
        self.label2id = dict(label2id)

        self.samples = []
        class_counts = defaultdict(int)
        for c in classes:
            class_dir = os.path.join(root_dir, c)
            label = self.label2id[c]
            # Sorted so sample order is identical across filesystems (reproducible runs).
            for f in sorted(os.listdir(class_dir)):
                if f.startswith('.'):
                    continue
                self.samples.append((os.path.join(class_dir, f), label))
                class_counts[label] += 1
        self.sample_weights = [1.0 / class_counts[label] for _, label in self.samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert('RGB')
        img = self.transform(img)
        return img, label

In [None]:
# Step 8: Set device and initialize DataLoaders with journal-quality split paths

from timm import create_model
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Deterministic preprocessing for validation/test: the same resize and normalisation as
# training, but none of the random flips/rotation/colour jitter, so evaluation metrics are
# repeatable and measured on unaugmented images.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_ds = WasteRegionDataset(TRAIN_DIR)
val_ds   = WasteRegionDataset(VAL_DIR, transform=eval_transform)
test_ds  = WasteRegionDataset(TEST_DIR, transform=eval_transform)

# Each split derives label ids from its own folder names; they must agree or labels silently shift.
assert val_ds.label2id == train_ds.label2id == test_ds.label2id, "class folders differ between splits"
num_classes = len(train_ds.label2id)

from torch.utils.data import DataLoader, WeightedRandomSampler

# Sampling with replacement by 1/class_count gives class-balanced training batches.
train_sampler = WeightedRandomSampler(train_ds.sample_weights, len(train_ds.sample_weights), replacement=True)
train_loader = DataLoader(train_ds, batch_size=16, sampler=train_sampler)
val_loader   = DataLoader(val_ds, batch_size=16, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=16, shuffle=False)

In [None]:
# Step 9: Train multiple configurations of model and save each model uniquely in MODELS_DIR/LOGS_DIR
# This version supports resuming: if a checkpoint for a config exists, it skips to the next config!
# If interrupted, just rerun -- it will not repeat finished jobs.
# Added: Modular checkpoint/resume function, partial checkpoint detection/resume, and periodic partial saves.

import os
import csv
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from timm import create_model
from sklearn.utils.class_weight import compute_class_weight
from transformers import get_cosine_schedule_with_warmup
from torch.cuda.amp import GradScaler
import pickle

# The `stepN_*` modules are not part of this notebook (presumably from a script version of
# the pipeline). Inside the notebook, set_seed / train_loader / val_loader / train_ds are
# already defined by Steps 1, 7 and 8, so a missing module is not an error here.
try:
    from step0_setup_folders_and_seed import set_seed
    from step6_dataset_and_loader import train_loader, val_loader, train_ds
except ImportError:
    _missing = [name for name in ("set_seed", "train_loader", "val_loader", "train_ds")
                if name not in globals()]
    if _missing:
        raise NameError(f"Run Steps 1, 7 and 8 before Step 9; undefined: {_missing}")

MODELS_DIR = '/kaggle/working/outputs/models'
LOGS_DIR = '/kaggle/working/outputs/logs'
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(LOGS_DIR, exist_ok=True)
num_classes = len(train_ds.label2id)

def get_resume_checkpoint(run_id, model_dir):
    """
    Check for full or partial checkpoints for a run_id.
    Returns (resume_path, is_partial) or (None, False) if nothing found.

    The partial checkpoint is checked FIRST. train_one_model writes `<run_id>_best.pth`
    whenever validation loss improves (from the first epoch on), so that file also exists
    for interrupted runs. `<run_id>_partial.pth` is only deleted once training finishes,
    so its absence is what marks a run as complete.
    """
    full_ckpt = os.path.join(model_dir, f"{run_id}_best.pth")
    part_ckpt = os.path.join(model_dir, f"{run_id}_partial.pth")
    if os.path.isfile(part_ckpt):
        print(f"⏸️ Resuming {run_id}: Partial checkpoint found.")
        return (part_ckpt, True)
    if os.path.isfile(full_ckpt):
        print(f"✅ Skipping {run_id}: Full checkpoint exists.")
        return (full_ckpt, False)
    return (None, False)

def train_one_model(model_name, use_mixup=True, label_smooth=0.1, seed=42, patience=15, save_as=None, resume_ckpt=None, num_epochs=5):
    """Fine-tune one timm model for a single ablation configuration.

    Args:
        model_name: timm architecture name (created with pretrained=True).
        use_mixup: train on mixup-blended batches; train accuracy is then not tracked (None).
        label_smooth: label_smoothing of the class-weighted cross-entropy loss.
        seed: passed to set_seed before the model is built.
        patience: epochs without val-loss improvement before early stopping.
        save_as: run id used in file names; defaults to model_name.
        resume_ckpt: optional state_dict path to initialise the weights from. Only the
            weights are restored; optimizer, LR schedule, epoch counter and history restart.
        num_epochs: maximum number of training epochs (also sets the LR schedule length).

    Side effects:
        Writes MODELS_DIR/<run>_best.pth on every val-loss improvement,
        MODELS_DIR/<run>_partial.pth after every epoch (deleted when training finishes),
        and LOGS_DIR/<run>_history.csv at the end.

    Returns:
        (model, history): model holds the LAST epoch's weights (the best ones are on disk);
        history maps 'train_loss', 'val_loss', 'val_acc', 'train_acc' to per-epoch lists.
    """
    set_seed(seed)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    use_amp = device == 'cuda'  # autocast/GradScaler are disabled on CPU
    model = create_model(model_name, pretrained=True, num_classes=num_classes).to(device)

    if resume_ckpt is not None:
        print(f"Loading weights from {resume_ckpt}")
        model.load_state_dict(torch.load(resume_ckpt, map_location=device))

    # NOTE(review): train_loader already draws class-balanced batches (WeightedRandomSampler
    # in Step 8); also weighting the loss compensates for class imbalance twice — confirm intended.
    train_labels = [label for _, label in train_ds.samples]
    class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=label_smooth)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-3)
    steps_per_epoch = len(train_loader)
    # The schedule is expressed in optimizer steps (batches): it is advanced once per batch
    # below, and its length matches the epochs actually run. Stepping it once per epoch
    # would keep the LR near zero because the warmup alone spans 3 epochs of batches.
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=3 * steps_per_epoch,
        num_training_steps=num_epochs * steps_per_epoch,
    )
    scaler = GradScaler(enabled=use_amp)

    def mixup_data(x, y, alpha=0.4):
        """Blend each sample with a random partner; returns mixed x, both targets and lambda."""
        lam = np.random.beta(alpha, alpha)
        index = torch.randperm(x.size(0)).to(x.device)
        mixed_x = lam * x + (1 - lam) * x[index]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam

    def mixup_criterion(criterion, pred, y_a, y_b, lam):
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'train_acc': []}
    best_val_loss = float('inf')
    no_improve_epochs = 0
    run_name = save_as or model_name
    save_path = os.path.join(MODELS_DIR, f"{run_name}_best.pth")
    partial_path = os.path.join(MODELS_DIR, f"{run_name}_partial.pth")

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        train_correct = train_total = 0

        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()

            if use_mixup:
                imgs, y_a, y_b, lam = mixup_data(imgs, labels)

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(imgs)
                if use_mixup:
                    loss = mixup_criterion(criterion, outputs, y_a, y_b, lam)
                else:
                    loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()  # per batch: the schedule is defined in steps
            total_loss += loss.item()

            # Training accuracy is only meaningful without mixup (mixed targets).
            if not use_mixup:
                preds = outputs.argmax(dim=1)
                train_correct += (preds == labels).sum().item()
                train_total += labels.size(0)

        avg_train_loss = total_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)
        train_acc = train_correct / train_total if (not use_mixup and train_total > 0) else None
        history['train_acc'].append(train_acc)

        model.eval()
        val_loss = 0.0
        correct = total = 0
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                val_loss += criterion(outputs, labels).item()
                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        avg_val_loss = val_loss / len(val_loader)
        val_acc = correct / total
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_acc)

        print(f"[{model_name}] Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {avg_val_loss:.4f} | "
              f"Val Acc: {val_acc:.4f}" +
              (f" | Train Acc: {train_acc:.4f}" if train_acc is not None else ""))

        # Saved every epoch so an interrupted run can be resumed (see get_resume_checkpoint).
        torch.save(model.state_dict(), partial_path)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            no_improve_epochs = 0
            torch.save(model.state_dict(), save_path)
            print("🔸 New best model saved.")
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print("🛑 Early stopping triggered.")
                break

    pd.DataFrame(history).to_csv(os.path.join(LOGS_DIR, f"{run_name}_history.csv"), index=False)

    # Removing the partial checkpoint marks this run as finished for future resumes.
    try:
        if os.path.exists(partial_path):
            os.remove(partial_path)
    except OSError as e:
        print(f"Warning: could not remove partial checkpoint: {e}")

    return model, history

# --- ablation runs with modular checkpoint skip/resume and partial detection ---
model_list = ['efficientnet_b0']  # You can add more models
seeds = [42, 123, 777]
mixup_options = [True, False]
label_smoothings = [0.1, 0.0]

all_histories = {}
results_csv_path = os.path.join(LOGS_DIR, "ablation_results.csv")
all_histories_path = os.path.join(LOGS_DIR, "all_histories.pkl")
# A zero-byte CSV has no header row yet, so it is treated as a fresh file.
resume = os.path.isfile(results_csv_path) and os.path.getsize(results_csv_path) > 0

# If resuming, load all_histories so far. This pickle is written by this loop itself;
# never point it at an untrusted file (unpickling can execute arbitrary code).
if os.path.exists(all_histories_path):
    with open(all_histories_path, "rb") as f:
        all_histories = pickle.load(f)

# CSV exists? Open in append mode. Else, write header.
csv_mode = "a" if resume else "w"
with open(results_csv_path, csv_mode, newline='') as f:
    writer = csv.writer(f)
    if not resume:
        writer.writerow(["Model", "Seed", "Mixup", "Label Smoothing", "Val Accuracy", "Checkpoint"])
        f.flush()
    for model_name in model_list:
        for mixup in mixup_options:
            for smooth in label_smoothings:
                for seed in seeds:
                    run_id = f"{model_name}_mixup{mixup}_smooth{smooth}_seed{seed}"
                    checkpoint_path = os.path.join(MODELS_DIR, f"{run_id}_best.pth")
                    history_path = os.path.join(LOGS_DIR, f"{run_id}_history.csv")

                    resume_ckpt, is_partial = get_resume_checkpoint(run_id, MODELS_DIR)
                    if resume_ckpt and not is_partial:
                        # Finished run: reuse its saved history, don't retrain.
                        if os.path.isfile(history_path):
                            all_histories[run_id] = pd.read_csv(history_path).to_dict('list')
                        continue

                    print(f"🔁 Training: {run_id}")
                    model, history = train_one_model(
                        model_name=model_name,
                        use_mixup=mixup,
                        label_smooth=smooth,
                        seed=seed,
                        patience=15,
                        save_as=run_id,
                        resume_ckpt=resume_ckpt if is_partial else None
                    )
                    # Record the validation accuracy of the epoch whose weights are in
                    # `<run_id>_best.pth` (lowest val loss, first on ties, as train_one_model
                    # saves on strict improvement), so the accuracy used to pick the best run
                    # matches the checkpoint that is evaluated later.
                    best_epoch = int(np.argmin(history['val_loss']))
                    best_val_acc = history['val_acc'][best_epoch]
                    writer.writerow([model_name, seed, mixup, smooth, best_val_acc, checkpoint_path])
                    # Flush now: a finished run is skipped on resume, so a row still sitting in
                    # the write buffer when the kernel dies would never be written.
                    f.flush()
                    all_histories[run_id] = history
                    # Immediately save histories for resume safety
                    with open(all_histories_path, "wb") as hf:
                        pickle.dump(all_histories, hf)

In [None]:
# Step 10: Summarize ablation results and print/save summary

import pandas as pd

df = pd.read_csv(results_csv_path)
print("\n📝 All ablation results:\n")
print(df)

# Validation accuracy grouped by configuration; the seed is the only thing varying inside a group.
accuracy_by_config = df.groupby(["Model", "Mixup", "Label Smoothing"])["Val Accuracy"]
summary = accuracy_by_config.agg(['mean', 'std']).reset_index()
print("\n📊 Ablation Results (mean ± std across seeds):\n")
print(summary)

# Single best run by validation accuracy (first row wins ties); Step 12 evaluates its checkpoint.
best_index = df["Val Accuracy"].idxmax()
best_row = df.loc[best_index]
print("\n\n🏆 Best configuration:\n\n", best_row)

# Persist the per-configuration summary with the other metrics.
summary_path = os.path.join(METRICS_DIR, "ablation_summary.csv")
summary.to_csv(summary_path, index=False)

In [None]:
"""
  #Step 11: Statistical Significance Testing
- Prints a clean table with short config labels for easy reading.
- Includes judgment about whether each difference is statistically significant (p < 0.05).
- Includes a summary at the end to help beginners interpret the results.
"""

import os
import pandas as pd
import numpy as np
import itertools
from scipy.stats import ttest_rel, wilcoxon

OUTPUT_ROOT = '/kaggle/working/outputs'
LOGS_DIR, STATS_DIR = (os.path.join(OUTPUT_ROOT, sub) for sub in ("logs", "stats"))
results_csv_path = os.path.join(LOGS_DIR, "ablation_results.csv")
stats_csv_path, stats_md_path = (
    os.path.join(STATS_DIR, f"statistical_tests.{ext}") for ext in ("csv", "md")
)

# (upper bound, verdict) bands checked in increasing order; the first band p falls under wins.
_P_VALUE_BANDS = (
    (0.01, "Highly significant (p < 0.01)"),
    (0.05, "Significant (p < 0.05)"),
    (0.10, "Suggestive (p < 0.10)"),
)

def interpret_p(p):
    """Return interpretation string for a p-value.

    NaN compares false against every bound, so it is reported as not significant.
    """
    for upper_bound, verdict in _P_VALUE_BANDS:
        if p < upper_bound:
            return verdict
    return "Not significant (p ≥ 0.10)"

if os.path.exists(results_csv_path):
    os.makedirs(STATS_DIR, exist_ok=True)
    df = pd.read_csv(results_csv_path)
    group_cols = ["Model", "Mixup", "Label Smoothing"]
    records = []

    # A resumed or re-run ablation can append the same (config, seed) twice; keep the latest row.
    df = df.drop_duplicates(subset=group_cols + ["Seed"], keep="last")

    # Get all unique ablation config pairs (ignoring seed)
    configs = df[group_cols].drop_duplicates().to_dict('records')
    pairs = list(itertools.combinations(configs, 2))

    def config_short(cfg):
        return f"{cfg['Model']}, Mixup={cfg['Mixup']}, LS={cfg['Label Smoothing']}"

    def accuracy_by_seed(cfg):
        """Validation accuracies of one configuration, indexed by seed."""
        mask = (df[group_cols] == pd.Series(cfg)).all(axis=1)
        return df.loc[mask].set_index("Seed")["Val Accuracy"]

    def fmt(value, spec=""):
        """Format a statistic; 'n/a' when the test could not be computed."""
        return "n/a" if value is None else format(value, spec)

    for cfg_a, cfg_b in pairs:
        acc_a, acc_b = accuracy_by_seed(cfg_a), accuracy_by_seed(cfg_b)
        # Paired tests must compare runs that share a seed. CSV row order is not a reliable
        # pairing (e.g. a resumed run is appended after later seeds).
        common_seeds = acc_a.index.intersection(acc_b.index).sort_values()
        if len(common_seeds) < 2:
            continue
        vals_a = acc_a.loc[common_seeds].to_numpy()
        vals_b = acc_b.loc[common_seeds].to_numpy()

        t_stat, t_p = ttest_rel(vals_a, vals_b)
        try:
            w_stat, w_p = wilcoxon(vals_a, vals_b)
        except Exception:
            # Best effort: e.g. some scipy versions raise when every paired difference is zero.
            w_stat, w_p = None, None

        records.append({
            "ConfigA": config_short(cfg_a),
            "ConfigB": config_short(cfg_b),
            "n_pairs": len(common_seeds),
            "meanA": np.mean(vals_a),
            "meanB": np.mean(vals_b),
            "t_stat": t_stat,
            "t_p": t_p,
            "t_judgement": interpret_p(t_p),
            "wilcoxon_stat": w_stat,
            "wilcoxon_p": w_p,
            "wilcoxon_judgement": interpret_p(w_p if w_p is not None else 1.0),
        })

    # Save results as CSV
    pd.DataFrame(records).to_csv(stats_csv_path, index=False)

    # One table rendering shared by the Markdown file and the console output.
    header = ("| ConfigA | ConfigB | meanA | meanB | t_stat | t_p | t_judgement "
              "| wilcoxon_stat | wilcoxon_p | wilcoxon_judgement |")
    divider = ("|---------|---------|-------|-------|--------|-----|-------------"
               "|---------------|------------|--------------------|")
    table_rows = [
        f"| {r['ConfigA']} | {r['ConfigB']} | {r['meanA']:.4f} | {r['meanB']:.4f} "
        f"| {fmt(r['t_stat'], '.3f')} | {fmt(r['t_p'], '.4f')} | {r['t_judgement']} "
        f"| {fmt(r['wilcoxon_stat'])} | {fmt(r['wilcoxon_p'], '.4f')} | {r['wilcoxon_judgement']} |"
        for r in records
    ]
    table_md = "\n".join([header, divider, *table_rows])

    with open(stats_md_path, "w") as f:
        f.write(table_md + "\n")

    print(f"🧪 Statistical significance test results saved to:\n  {stats_csv_path}\n  {stats_md_path}")

    print("\nStatistical significance test results:\n")
    print(table_md)

    # Beginner-friendly summary
    print("\nHow to interpret these results:")
    print("- 't_p' and 'wilcoxon_p' are p-values. If they are less than 0.05, the difference between the two configurations is considered statistically significant (unlikely due to chance).")
    print("- The 't_judgement' and 'wilcoxon_judgement' columns say if the difference is significant.")
    print("- 'meanA' and 'meanB' show the average validation accuracy for each configuration. Higher is better.")
    print("- If all p-values are 'Not significant', then the changes you made in configuration likely did not have a meaningful effect on performance.")
    print("- If you see 'Significant' or 'Highly significant', the difference is likely real and not random. Prefer the config with higher mean accuracy.")
    print("- Caveat: runs are paired by seed. With only 3 seeds the two-sided Wilcoxon test cannot go below p = 0.25, and no correction is applied for the many pairwise comparisons, so treat isolated 'Significant' results with caution.")

else:
    print(f"ERROR: ablation_results.csv not found at {results_csv_path}")

In [None]:
# Step 12: Evaluate Best-Performing Model on Test Set and save metrics/plots in BEST_MODEL_DIR
# Also plot and save the loss and accuracy curves for the best model.

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, confusion_matrix, roc_auc_score
)
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import json
import os

# Load best model
best_checkpoint_path = best_row["Checkpoint"]
best_model_name = best_row["Model"]

model = create_model(best_model_name, pretrained=False, num_classes=num_classes).to(device)
# map_location lets a checkpoint written in a GPU session load in a CPU-only one.
model.load_state_dict(torch.load(best_checkpoint_path, map_location=device))

def evaluate_model_detailed(model, dataloader, device, label2id, show_roc_auc=True, save_metrics_dir=None, save_vis_dir=None):
    """Evaluate a classifier on a dataloader; print, plot and optionally save the results.

    Args:
        model: classifier returning one logit per class.
        dataloader: yields (images, integer labels) batches.
        device: device the model lives on.
        label2id: class-name -> index mapping used for the report and axis labels.
        show_roc_auc: also compute ROC AUC from softmax probabilities.
        save_metrics_dir: if set, write test_metrics.json (and test_roc_auc.txt) there.
        save_vis_dir: if set, save the raw and row-normalised confusion matrices there.

    Returns:
        dict with "accuracy", weighted "precision"/"recall"/"f1", and "roc_auc" when computed.
    """
    model.eval()
    all_preds, all_probs, all_labels = [], [], []
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            probs = torch.softmax(outputs, dim=1)
            preds = outputs.argmax(dim=1)
            all_probs.extend(probs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)

    id2label = {v: k for k, v in label2id.items()}
    class_ids = sorted(id2label)
    target_names = [id2label[i] for i in class_ids]
    # Explicit labels= keeps the report and matrices aligned with target_names even if a
    # class is absent from this split; zero_division=0 gives the same value as sklearn's
    # default ('warn') without the warning.
    report_kwargs = dict(labels=class_ids, target_names=target_names, digits=4, zero_division=0)

    print("\n🔍 Classification Report (on Test Set):\n")
    report_str = classification_report(all_labels, all_preds, **report_kwargs)
    print(report_str)
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, labels=class_ids, average='weighted', zero_division=0)
    rec = recall_score(all_labels, all_preds, labels=class_ids, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, labels=class_ids, average='weighted', zero_division=0)
    print(f"Test Accuracy:  {acc:.4f}")
    print(f"Test Precision: {prec:.4f}")
    print(f"Test Recall:    {rec:.4f}")
    print(f"Test F1-Score:  {f1:.4f}")
    results = {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1}

    if save_metrics_dir:
        metrics = dict(results)
        metrics["classification_report"] = classification_report(
            all_labels, all_preds, output_dict=True, **report_kwargs
        )
        with open(os.path.join(save_metrics_dir, "test_metrics.json"), "w") as f:
            json.dump(metrics, f, indent=2)

    def _plot_confusion(matrix, fmt, cmap, title, filename):
        """Draw one confusion-matrix heatmap, optionally save it, then free the figure."""
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(matrix, annot=True, fmt=fmt, cmap=cmap,
                    xticklabels=target_names, yticklabels=target_names, ax=ax)
        ax.set_title(title)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
        fig.tight_layout()
        if save_vis_dir:
            fig.savefig(os.path.join(save_vis_dir, filename))
        plt.show()
        plt.close(fig)

    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    _plot_confusion(cm, 'd', 'Blues', "Confusion Matrix – Raw (Test Set)", "confusion_matrix_raw.png")

    # Row-normalised (per actual class); rows with no samples stay 0 instead of dividing by zero.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)
    _plot_confusion(cm_norm, '.2f', 'Oranges', "Confusion Matrix – Normalized (Test Set)",
                    "confusion_matrix_normalized.png")

    if show_roc_auc:
        try:
            from sklearn.preprocessing import label_binarize
            if len(class_ids) == 2:
                # label_binarize returns a single column for two classes; score the positive class.
                auc = roc_auc_score(all_labels, all_probs[:, 1])
            else:
                y_true_bin = label_binarize(all_labels, classes=class_ids)
                auc = roc_auc_score(y_true_bin, all_probs, multi_class='ovr')
            print(f"Test ROC AUC (Multiclass OVR): {auc:.4f}")
            results["roc_auc"] = auc
            if save_metrics_dir:
                with open(os.path.join(save_metrics_dir, "test_roc_auc.txt"), "w") as f:
                    f.write(f"{auc:.6f}\n")
        except Exception as e:
            # Best effort: AUC is undefined e.g. when a class is missing from the test labels.
            print(f"⚠️ ROC-AUC computation failed: {e}")

    return results

# --- Evaluate on test set and save metrics/plots ---
evaluate_model_detailed(
    model,
    test_loader,
    device,
    train_ds.label2id,
    save_metrics_dir=BEST_MODEL_DIR,
    save_vis_dir=BEST_MODEL_DIR,
)

# --- Plot and save the loss and accuracy curves of the best model ---

# all_histories (Step 9) is keyed by run id. Without a "run_id" column in the results it is
# recovered from the checkpoint file name "<run_id>_best.pth".
best_run_id = (
    best_row["run_id"]
    if "run_id" in best_row
    else os.path.basename(best_checkpoint_path).replace("_best.pth", "")
)

if best_run_id not in all_histories:
    print(f"Best run's training history not found for plotting (run_id: {best_run_id})")
else:
    history = all_histories[best_run_id]
    epochs = range(1, len(history['train_loss']) + 1)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(16, 5))

    loss_ax.plot(epochs, history['train_loss'], label='Train Loss', marker='o')
    loss_ax.plot(epochs, history['val_loss'], label='Val Loss', marker='x')
    loss_ax.set_title("Loss Curve")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.legend()
    loss_ax.grid(True)

    # Train accuracy is None for mixup runs; those entries are drawn as gaps (NaN).
    has_train_acc = 'train_acc' in history and any(x is not None for x in history['train_acc'])
    if has_train_acc:
        acc_ax.plot(epochs, [np.nan if x is None else x for x in history['train_acc']],
                    label='Train Accuracy', marker='d', color='blue')
    acc_ax.plot(epochs, history['val_acc'], label='Val Accuracy', marker='s', color='green')
    acc_ax.set_title("Accuracy Curve")
    acc_ax.set_xlabel("Epoch")
    acc_ax.set_ylabel("Accuracy")
    acc_ax.legend()
    acc_ax.grid(True)

    fig.suptitle(f"Best Model Training Progress – {best_run_id}", fontsize=14)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(os.path.join(BEST_MODEL_DIR, f"curve_{best_run_id}.png"))
    plt.show()

In [None]:
# Step 13: Visualize side-by-side: original, Grad-CAM++, Score-CAM for best model predictions

print(f"Loaded model for interpretability: {best_checkpoint_path}")
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus, ScoreCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# --- Configuration ---
N_SAMPLES = 10  # how many test images receive CAM overlays
interp_save_dir = INTERP_DIR
os.makedirs(interp_save_dir, exist_ok=True)

# `model` is the best checkpoint loaded in Step 12 (test_loader / train_ds come from Step 8);
# switch it to inference mode before computing CAMs.
model.eval()

# Helper: get sample images and predictions from test set
def get_sample_images_and_preds(model, dataloader, device, n=10):
    """Collect the first ``n`` images of ``dataloader`` with their true and predicted labels.

    Returns:
        List of ``(image_tensor_on_cpu, true_label_int, predicted_label_int)``; shorter than
        ``n`` if the dataloader runs out. The size check happens after each append, so
        ``n <= 0`` still yields one sample when data exists.
    """
    collected = []
    with torch.no_grad():
        for batch_imgs, batch_labels in dataloader:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)
            batch_preds = model(batch_imgs).argmax(dim=1)
            for img, lbl, prd in zip(batch_imgs, batch_labels, batch_preds):
                collected.append((img.cpu(), lbl.cpu().item(), prd.cpu().item()))
                if len(collected) >= n:
                    return collected
    return collected

samples = get_sample_images_and_preds(model, test_loader, device, n=N_SAMPLES)

# Layer the CAMs are computed on: the model's `conv_head` when it has one, otherwise the
# last nn.Conv2d in module order.
try:
    target_layer = model.conv_head
except AttributeError:
    import torch.nn as nn
    target_layer = next(
        (module for module in reversed(list(model.modules())) if isinstance(module, nn.Conv2d)),
        None,
    )
    if target_layer is None:
        raise RuntimeError("Could not find a Conv2d layer for CAM visualization.")

gradcampp = GradCAMPlusPlus(model=model, target_layers=[target_layer])
scorecam = ScoreCAM(model=model, target_layers=[target_layer])

# Visualization loop
# Mean/std of the dataset's Normalize transform, needed to undo it for display.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
id2label = {v: k for k, v in train_ds.label2id.items()}

for idx, (img_tensor, label, pred) in enumerate(samples):
    # samples hold CPU tensors; move the single-image batch to the model's device.
    input_tensor = img_tensor.unsqueeze(0).to(device)
    # show_cam_on_image expects an RGB float image in [0, 1].
    img_np = np.clip(img_tensor.permute(1, 2, 0).numpy() * std + mean, 0, 1)
    # Explain the predicted class (not the ground truth).
    targets = [ClassifierOutputTarget(pred)]

    grayscale_cam_pp = gradcampp(input_tensor=input_tensor, targets=targets)[0]
    campp_img = show_cam_on_image(img_np, grayscale_cam_pp, use_rgb=True)

    grayscale_cam_score = scorecam(input_tensor=input_tensor, targets=targets)[0]
    camscore_img = show_cam_on_image(img_np, grayscale_cam_score, use_rgb=True)

    gt_class, pred_class = id2label[label], id2label[pred]
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    panels = ((img_np, "Original"), (campp_img, "Grad-CAM++"), (camscore_img, "Score-CAM"))
    for ax, (panel_img, panel_title) in zip(axs, panels):
        ax.imshow(panel_img)
        ax.set_title(panel_title)
        ax.axis("off")
    fig.suptitle(f"GT: {gt_class} | Pred: {pred_class}", fontsize=14)
    fig.tight_layout()
    save_path = os.path.join(interp_save_dir, f"interp_{idx+1}_gt_{gt_class}_pred_{pred_class}.png")
    fig.savefig(save_path)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop