# Inference & Visualization Notebook

This notebook loads the trained hybrid melanoma model, runs inference on the held-out test split, reports classification and regression metrics, and generates several visuals (ROC curve, predicted-probability histogram, regression scatter plot, confusion matrix, and sample images annotated with their predictions). The final cell saves the per-sample predictions to a CSV file. Update the configuration cell below to match the checkpoint and model variant you want to evaluate.


In [None]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    roc_curve,
    confusion_matrix,
    mean_absolute_error,
    mean_squared_error,
    r2_score,
)

# The parent of the current working directory is treated as the project root
# and appended to sys.path (only once) so the `src.*` imports below resolve.
PROJECT_ROOT = Path.cwd().resolve().parent
_project_root_entry = str(PROJECT_ROOT)
if _project_root_entry not in sys.path:
    sys.path.append(_project_root_entry)

from src.utils import load_emb_data, split_dataset, StructuredPreprocessor  # noqa: E402
from src.evaluate import EvalDataset  # noqa: E402
from src.models import create_dino_hybrid_model, create_hybrid_model  # noqa: E402



In [None]:
# ---- Configuration ---- #
# Every value in this cell is read by the cells that follow; edit them here
# rather than further down.

# Reproducibility: seed NumPy and PyTorch up front.
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Inputs: metadata table, image folder and the weights to evaluate.
metadata_path = PROJECT_ROOT / "data/merged_dataset.csv"
image_dir = PROJECT_ROOT / "data/images"
checkpoint_path = PROJECT_ROOT / "outputs/train2/checkpoints/best.pt"  # update as needed

# Task setup.
multitask = True  # set False for classification-only models
task = "classification"  # "classification" or "regression" (ignored when multitask=True)
loss_alpha = 0.5  # used for combined metrics in multitask mode

# Architecture options forwarded to the model factory.
# NOTE(review): presumably these must match the run that produced the
# checkpoint - confirm against the training configuration.
model_type = "dino"  # "dino" or "resnet"
vit_arch = "vit_b_32"  # only used when model_type == "dino"
use_tokens = True
fusion_type = "cross_attention"
num_clin_tokens = 4
freeze_backbone_layers = 7

# Runtime: input resolution, batch size and compute device (GPU if present).
image_size = (384, 384)
batch_size = 16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Feature handling.
remove_ajcc = True  # drop AJCC/stage columns to avoid leakage with thickness

# Visualization.
num_image_samples = 6  # number of images to visualize with predictions



In [None]:
# ---- Load & preprocess dataset ---- #
df = load_emb_data(metadata_path, image_dir)
splits = split_dataset(df, val_size=0.15, test_size=0.15, stratify=True)

# Choose the column copied into "target": prefer an existing "target" column
# (or "label" for classification-only runs) and fall back to "label" /
# "thickness" when the preferred column is absent from the training split.
target_col = "label" if (not multitask and task == "classification") else "target"
if target_col not in splits.train.columns:
    if multitask:
        target_col = "label"
    else:
        target_col = "thickness" if task == "regression" else "label"

cast_target_to_float = multitask or task == "classification"
for part in (splits.train, splits.val, splits.test):
    target_values = part[target_col]
    if cast_target_to_float:
        target_values = target_values.astype(float)
    # Assign with part[...] rather than part.loc[:, ...]: since pandas 2.0 the
    # .loc[:, col] form writes into the existing column in place and keeps its
    # old dtype, which silently undoes the float cast.
    part["target"] = target_values

# Structured features: every numeric column that is not an identifier/target.
# NOTE(review): "thickness" is not excluded, so when it is numeric it becomes
# an input feature and is overwritten with its preprocessed values below.
# Confirm this matches the training pipeline before changing it, because the
# checkpoint expects the same feature count.
exclude = {"image_path", "image", "label", "target", "cathegory", "category"}
feature_cols = [
    c
    for c in splits.train.columns
    if c not in exclude and pd.api.types.is_numeric_dtype(splits.train[c])
]

if remove_ajcc:
    ajcc_cols = [c for c in feature_cols if ("ajcc" in c.lower()) or c.lower().startswith("stage")]
    feature_cols = [c for c in feature_cols if c not in ajcc_cols]
    if ajcc_cols:
        print(f"Removed AJCC columns: {ajcc_cols}")

# Fit the preprocessor on the training split only, then apply it to val/test.
sp = StructuredPreprocessor(feature_names=feature_cols)
train_struct = sp.fit_transform(splits.train[feature_cols].to_numpy())
val_struct = sp.transform(splits.val[feature_cols].to_numpy())
test_struct = sp.transform(splits.test[feature_cols].to_numpy())

# Write the transformed values back so the dataset reads preprocessed features.
for part, transformed in (
    (splits.train, train_struct),
    (splits.val, val_struct),
    (splits.test, test_struct),
):
    for i, col in enumerate(feature_cols):
        part[col] = transformed[:, i]

print(f"Structured feature count: {len(feature_cols)}")
# A fresh 0..n-1 index lets later cells attach predictions by position.
test_df = splits.test.reset_index(drop=True)



In [None]:
# ---- Build dataloader ---- #
# Wrap the test dataframe in the evaluation dataset and iterate it in order:
# shuffle stays off because later cells attach predictions to test_df by row
# position.
test_dataset = EvalDataset(test_df, feature_cols=feature_cols, image_size=image_size)
loader_options = dict(
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)
test_loader = DataLoader(test_dataset, **loader_options)

# ---- Create & load model ---- #
# Build the architecture selected in the configuration cell; any model_type
# other than "dino" falls through to the default hybrid model.
if model_type == "dino":
    model = create_dino_hybrid_model(
        num_structured_features=len(feature_cols),
        task=task,
        multitask=multitask,
        arch=vit_arch,
        pretrained=True,
        dino_checkpoint=None,
        use_tokens=use_tokens,
        fusion_type=fusion_type,
        num_clin_tokens=num_clin_tokens,
        freeze_backbone_layers=freeze_backbone_layers,
    )
else:
    model = create_hybrid_model(
        num_structured_features=len(feature_cols),
        task=task,
        multitask=multitask,
    )

# NOTE(security): torch.load unpickles the file, so only open checkpoints from
# a trusted source.
state = torch.load(checkpoint_path, map_location=device)
# Training checkpoints may wrap the weights under a "model_state" key.
if isinstance(state, dict) and "model_state" in state:
    state = state["model_state"]

# strict=False tolerates key mismatches, so report them instead of discarding
# the result: a wrong configuration would otherwise evaluate a model whose
# unmatched layers were never loaded from the checkpoint.
load_report = model.load_state_dict(state, strict=False)
missing_keys = list(getattr(load_report, "missing_keys", []))
unexpected_keys = list(getattr(load_report, "unexpected_keys", []))

model = model.to(device)
model.eval()

print(f"Loaded checkpoint from {checkpoint_path}")
if missing_keys:
    print(f"WARNING: {len(missing_keys)} model parameters were not found in the checkpoint: {missing_keys}")
if unexpected_keys:
    print(f"WARNING: {len(unexpected_keys)} checkpoint entries were not used by the model: {unexpected_keys}")



In [None]:
# ---- Run inference ---- #
def _as_flat_numpy(tensor):
    """Return ``tensor`` as a flat NumPy array in host memory, detached from autograd."""
    return tensor.detach().cpu().numpy().reshape(-1)


cls_probs, cls_labels = [], []
reg_preds, reg_targets = [], []

with torch.no_grad():
    for batch in test_loader:
        imgs = batch["img"].to(device)
        tabs = batch["tab"].to(device)
        outputs = model(imgs, tabs)

        # A dict output may carry a "cls" and/or a "reg" head; a bare tensor is
        # treated as the classification output.
        # NOTE(review): a regression-only model (task="regression",
        # multitask=False) would also land in the classification lists here -
        # confirm whether that configuration needs its own branch.
        is_multi_head = isinstance(outputs, dict)
        cls_out = outputs.get("cls") if is_multi_head else outputs

        if cls_out is not None:
            # NOTE(review): assumes the model already emits probabilities (the
            # metrics threshold them at 0.5) - confirm the head applies a sigmoid.
            cls_probs.append(_as_flat_numpy(cls_out))
            # Without ground-truth labels an all-zero placeholder is used. It is
            # created on the model's device, so it must go through the same
            # detach/cpu conversion; calling .numpy() directly fails on CUDA.
            labels = batch["label"] if "label" in batch else torch.zeros_like(cls_out)
            cls_labels.append(_as_flat_numpy(labels))

        if is_multi_head and "reg" in outputs and "thickness" in batch:
            reg_preds.append(_as_flat_numpy(outputs["reg"]))
            reg_targets.append(_as_flat_numpy(batch["thickness"]))

# Merge the per-batch arrays; a head that produced nothing becomes None so the
# following cells can skip it.
cls_probs = np.concatenate(cls_probs) if cls_probs else None
cls_labels = np.concatenate(cls_labels) if cls_labels else None
reg_preds = np.concatenate(reg_preds) if reg_preds else None
reg_targets = np.concatenate(reg_targets) if reg_targets else None

print(f"Collected predictions for {len(test_df)} samples")



In [None]:
# ---- Compute metrics ---- #
# `metrics` always exists (possibly empty) because the ROC cell reads it.
metrics = {}
if not (cls_probs is None or cls_labels is None):
    # Hard decisions at the 0.5 probability threshold.
    y_pred_binary = (cls_probs >= 0.5).astype(int)
    metrics["accuracy"] = accuracy_score(cls_labels, y_pred_binary)
    for metric_name, scorer in (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    ):
        metrics[metric_name] = scorer(cls_labels, y_pred_binary, zero_division=0)

    # AUROC uses the raw scores; a ValueError from the scorer is reported as NaN.
    try:
        auroc_value = roc_auc_score(cls_labels, cls_probs)
    except ValueError:
        auroc_value = np.nan
    metrics["auroc"] = auroc_value

    print("Classification metrics:")
    for metric_name, metric_value in metrics.items():
        print(f"  {metric_name:10s}: {metric_value:.4f}")

if not (reg_preds is None or reg_targets is None):
    mae_value = mean_absolute_error(reg_targets, reg_preds)
    rmse_value = np.sqrt(mean_squared_error(reg_targets, reg_preds))
    r2_value = r2_score(reg_targets, reg_preds)
    reg_metrics = {"mae": mae_value, "rmse": rmse_value, "r2": r2_value}

    print("\nRegression metrics:")
    for metric_name, metric_value in reg_metrics.items():
        print(f"  {metric_name:10s}: {metric_value:.4f}")



In [None]:
# ---- Visuals: ROC curve & probability histogram ---- #
if not (cls_probs is None or cls_labels is None):
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    roc_ax, hist_ax = axes

    # Left panel: ROC curve against the chance diagonal. If the curve cannot be
    # computed (ValueError) a placeholder message is drawn instead.
    try:
        fpr, tpr, _ = roc_curve(cls_labels, cls_probs)
        roc_ax.plot(fpr, tpr, label=f"AUROC = {metrics.get('auroc', float('nan')):.3f}")
        roc_ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
        roc_ax.set(
            xlabel="False Positive Rate",
            ylabel="True Positive Rate",
            title="ROC Curve",
        )
        roc_ax.legend(loc="lower right")
    except ValueError:
        roc_ax.text(0.5, 0.5, "ROC undefined", ha="center")

    # Right panel: predicted probabilities per true class, with the 0.5
    # decision threshold marked.
    for class_value, class_name in ((0, "Benign"), (1, "Malignant")):
        hist_ax.hist(cls_probs[cls_labels == class_value], bins=20, alpha=0.6, label=class_name)
    hist_ax.axvline(0.5, color="red", linestyle="--", label="Threshold")
    hist_ax.set(
        xlabel="Predicted probability",
        ylabel="Count",
        title="Prediction Distribution",
    )
    hist_ax.legend()

    plt.tight_layout()
    plt.show()



In [None]:
# ---- Visuals: Regression scatter ---- #
if reg_preds is not None and reg_targets is not None:
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(reg_targets, reg_preds, alpha=0.6)

    # Identity line spanning the combined range of targets and predictions;
    # points above it are overestimates, points below it underestimates.
    min_val = min(reg_targets.min(), reg_preds.min())
    max_val = max(reg_targets.max(), reg_preds.max())
    ax.plot([min_val, max_val], [min_val, max_val], linestyle="--", color="red")

    ax.set_xlabel("True thickness (mm)")
    ax.set_ylabel("Predicted thickness (mm)")
    # The superscript two is written as an escape so the title cannot be
    # corrupted by a file re-encoding (it previously rendered as mojibake).
    r2_value = r2_score(reg_targets, reg_preds)
    ax.set_title(f"Regression Scatter (R\u00b2={r2_value:.3f})")
    plt.show()



In [None]:
# ---- Assemble results dataframe ---- #
# Collect whichever prediction columns are available, then attach them to a
# copy of the test dataframe (rows line up by position).
prediction_columns = {}
if cls_probs is not None:
    prediction_columns["pred_prob"] = cls_probs
    prediction_columns["pred_label"] = (cls_probs >= 0.5).astype(int)
if reg_preds is not None:
    prediction_columns["pred_thickness"] = reg_preds

results = test_df.copy()
for column_name, column_values in prediction_columns.items():
    results[column_name] = column_values

results.head()



In [None]:
# ---- Sample predictions with visuals ---- #
def _sample_caption(row):
    """Build the multi-line subplot title (file name, labels, predictions) for one row."""
    lines = [Path(row["image_path"]).name]
    # A missing (NaN) label cannot be converted to int, so it is left out.
    if "label" in row and pd.notna(row["label"]):
        lines.append(f"GT label: {int(row['label'])}")
    if "pred_prob" in row:
        lines.append(f"Pred prob: {row['pred_prob']:.2f}")
    if "thickness" in row:
        lines.append(f"GT thick: {row['thickness']:.2f}mm")
    if "pred_thickness" in row:
        lines.append(f"Pred thick: {row['pred_thickness']:.2f}mm")
    return "\n".join(lines)


def show_samples(df, title, selector):
    """Plot a grid of sample images annotated with ground truth and predictions.

    Args:
        df: DataFrame with an "image_path" column. The "label", "pred_prob",
            "thickness" and "pred_thickness" columns are shown when present.
        title: Figure title, also used in the "no samples" message.
        selector: Callable that receives ``df`` and returns the ordered frame
            to sample from; its first ``num_image_samples`` rows are plotted.

    Returns:
        None. Prints a message instead of plotting when nothing can be shown.
    """
    subset = selector(df)
    if subset.empty:
        print(f"No samples available for {title}")
        return

    # A zero or negative sample count would otherwise divide by zero below.
    samples = subset.head(max(int(num_image_samples), 0))
    if samples.empty:
        print(f"num_image_samples={num_image_samples}; nothing to display for {title}")
        return

    # Size the grid from the rows actually available (at most three columns).
    n_cols = min(3, len(samples))
    n_rows = int(np.ceil(len(samples) / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.ravel()

    for ax, (_, row) in zip(axes, samples.iterrows()):
        ax.axis("off")
        # Deliberate best-effort: one unreadable image or malformed row must not
        # abort the whole figure, so the error is shown in that cell instead.
        try:
            # The context manager closes the file handle; convert() returns an
            # independent in-memory image.
            with Image.open(row["image_path"]) as handle:
                img = handle.convert("RGB")
            caption = _sample_caption(row)
        except Exception as exc:
            # Axes coordinates keep the message centred in the cell.
            ax.text(0.5, 0.5, str(exc), ha="center", va="center", transform=ax.transAxes, wrap=True)
            continue
        ax.imshow(img)
        ax.set_title(caption, fontsize=9)

    # Blank any grid cells left over after the last sample.
    for ax in axes[len(samples):]:
        ax.axis("off")

    fig.suptitle(title, fontsize=14)
    plt.tight_layout()
    plt.show()

# Each selector does the ordering itself, so `results` is passed as-is; sorting
# or reindexing it beforehand only repeated that work.
if cls_probs is not None:
    show_samples(
        results,
        "Top malignant predictions",
        lambda df: df.sort_values("pred_prob", ascending=False),
    )
    show_samples(
        results,
        "Most confidently benign",
        lambda df: df.sort_values("pred_prob", ascending=True),
    )

if reg_preds is not None:
    # Signed error (prediction minus ground truth), largest overestimate first.
    show_samples(
        results,
        "Largest thickness overestimates",
        lambda df: df.assign(err=df["pred_thickness"] - df["thickness"]).sort_values("err", ascending=False),
    )



In [None]:
# ---- Confusion matrix ---- #
if cls_probs is not None and cls_labels is not None:
    # Class order matches the histogram cell: 0 = Benign, 1 = Malignant.
    class_names = ["Benign", "Malignant"]
    # labels=[0, 1] pins the matrix to 2x2 even when only one class occurs in
    # the split; otherwise it collapses to 1x1 and no longer matches the ticks.
    cm = confusion_matrix(
        np.asarray(cls_labels).astype(int),
        (cls_probs >= 0.5).astype(int),
        labels=[0, 1],
    )

    fig, ax = plt.subplots(figsize=(4, 4))
    im = ax.imshow(cm, cmap="Blues")

    # Use white text on the darker (high-count) cells so counts stay readable.
    contrast_cutoff = cm.max() / 2.0
    for (i, j), count in np.ndenumerate(cm):
        ax.text(
            j,
            i,
            count,
            ha="center",
            va="center",
            color="white" if count > contrast_cutoff else "black",
        )

    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix")
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(class_names)
    ax.set_yticklabels(class_names)
    plt.colorbar(im, ax=ax)
    plt.show()



In [None]:
# ---- Optional: Save results to disk ---- #
results_path = PROJECT_ROOT / "outputs" / "inference_results.csv"
# Create the output folder if needed; to_csv does not create missing parents.
results_path.parent.mkdir(parents=True, exist_ok=True)
results.to_csv(results_path, index=False)
print(f"Saved detailed predictions to {results_path}")

