In [None]:
# Cell 1 – Imports & Configuration
import os
import sys
from pathlib import Path

import torch
import clip
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

# Optional: seaborn style for nicer plots
import seaborn as sns
sns.set(style="whitegrid")

# Device: run on the GPU whenever torch reports one, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Paths (adapt to your project structure)
DATA_ROOT = Path("/content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed")
EXPERIMENT = "dataset_9f30917e/experiments/20250712091227"
# Folder that is searched for .jpg patches (point it wherever yours live)
PATCH_DIR = DATA_ROOT.joinpath(EXPERIMENT, "simclr/fold0/training/patches")
# Relative to the current working directory; created up front so later writes succeed
OUTPUT_DIR = Path("explainability_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Echo the resolved configuration so it is visible in the cell output
print(f"Using device:    {DEVICE}")
print(f"Patches folder:  {PATCH_DIR}")
print(f"Results to:      {OUTPUT_DIR}")


In [None]:
# Cell 2 – Load CLIP & (Optional) Classification Model
# clip.load returns the model together with its matching image transform;
# both are reused by every later cell for concept matching.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=DEVICE)
# Inference only: keep the model in evaluation mode
clip_model = clip_model.eval()

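# If you also want to overlay GradCAM on your SSL classifier
# (assumes the checkpoint stores the whole module under the "model" key and sits
#  under DATA_ROOT / EXPERIMENT like the patches do — adjust if yours differs):
# from torchvision import models
# ckpt_path = DATA_ROOT / EXPERIMENT / "simclr/fold0/training/simclr_bestepoch002_fold0.pt"
# clf = torch.load(ckpt_path, map_location=DEVICE)["model"]
# clf = clf.to(DEVICE).eval()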


In [None]:
# Cell 3 – Define & Encode Pathology Concepts
# Free-text descriptors compared against every patch. The order matters:
# row i of `concept_embeddings` belongs to CONCEPTS[i].
CONCEPTS = [
    "necrosis", "clear cytoplasm", "mitotic figures", "nuclear atypia",
    "fibrovascular stroma", "inflammatory infiltrate", "blood vessels", "cellular pleomorphism",
    "stromal cells", "tumor cell clusters", "normal kidney tissue", "glomerulus",
    "tubules", "fibrosis", "calcification", "hemorrhage",
]

# Text side of CLIP: tokenize, encode, then scale each row to unit length so a
# dot product with a unit-length image embedding is a cosine similarity.
with torch.no_grad():
    tokens = clip.tokenize(CONCEPTS).to(DEVICE)
    concept_embeddings = clip_model.encode_text(tokens)
    concept_embeddings = concept_embeddings / concept_embeddings.norm(dim=-1, keepdim=True)

print(f"→ Encoded {len(CONCEPTS)} concepts into CLIP embeddings")


In [None]:
# Cell 4 – Patch‐Level Concept Explanation
def explain_patch_concepts(
    patch: Image.Image,
    clip_model,
    clip_preprocess,
    concept_embeddings: torch.Tensor,
    concepts: list[str],
    top_k: int = 3,
) -> list[tuple[str, float]]:
    """
    Rank text concepts by similarity to one image patch.

    Args:
        patch: Image handed unchanged to ``clip_preprocess``.
        clip_model: Object exposing ``encode_image(batch)``.
        clip_preprocess: Callable turning ``patch`` into an unbatched tensor.
        concept_embeddings: One text embedding per row, aligned with
            ``concepts``. Rows are expected to be unit-length; otherwise the
            returned scores are dot products rather than cosines.
        concepts: Concept names, one per embedding row.
        top_k: Maximum number of results; values <= 0 yield an empty list.

    Returns:
        Up to ``top_k`` ``(concept, score)`` pairs, highest score first.

    Raises:
        ValueError: If ``concepts`` and ``concept_embeddings`` disagree in length.
    """
    if len(concepts) != concept_embeddings.shape[0]:
        raise ValueError(
            f"Got {len(concepts)} concept names for "
            f"{concept_embeddings.shape[0]} concept embeddings"
        )
    # A negative top_k would otherwise slice "all but the last k" below.
    if top_k <= 0:
        return []

    # Follow the embeddings' device instead of a notebook-level global so the
    # function works with whatever tensors it is actually given.
    device = concept_embeddings.device
    x = clip_preprocess(patch).unsqueeze(0).to(device)

    with torch.no_grad():
        img_emb = clip_model.encode_image(x)
        img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
        # (1, n_concepts) -> (n_concepts,)
        sims = (img_emb @ concept_embeddings.T).squeeze(0)

    sims = sims.cpu().numpy()
    # argsort is ascending: reverse for best-first, then keep the head.
    idxs = np.argsort(sims)[::-1][:top_k]
    return [(concepts[i], float(sims[i])) for i in idxs]


In [None]:
# Cell 5 – Install & Import GradCAM++
!pip install grad-cam --quiet

from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image


In [None]:
# Cell 6 – GradCAM++ Explanation Function

def explain_patch_gradcam(
    model: torch.nn.Module,
    target_layer: torch.nn.Module,
    patch: Image.Image,
    preprocess_fn,
    target_class: int | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute a GradCAM++ heatmap for one patch.

    Args:
        model: Network to explain; it is moved to ``DEVICE``.
        target_layer: Layer of ``model`` whose activations drive the CAM.
        patch: Input image.
        preprocess_fn: Callable turning ``patch`` into an unbatched tensor.
        target_class: Class index to explain. ``None`` passes no targets and
            leaves the choice to grad-cam.

    Returns:
        ``(grayscale_cam, overlay)``: the H×W heatmap for the single input and
        the image produced by ``show_cam_on_image`` for plotting. H and W are
        the spatial size of the preprocessed tensor.
    """
    # preprocess
    img_tensor = preprocess_fn(patch).unsqueeze(0).to(DEVICE)
    height, width = img_tensor.shape[-2:]

    # Background for the overlay: forced to 3-channel float32 in [0, 1].
    # NOTE(review): this is a plain resize; if preprocess_fn also crops, the
    # overlay is only approximately aligned with the heatmap — confirm.
    rgb = np.asarray(patch.convert("RGB").resize((width, height)), dtype=np.float32) / 255.0

    # Recent grad-cam releases no longer accept `use_cuda` and run on the
    # model's own device, so the model is placed explicitly; this works with
    # older releases as well.
    model.to(DEVICE)

    targets = None if target_class is None else [ClassifierOutputTarget(target_class)]

    # The context manager releases the hooks registered on target_layer, so
    # repeated calls do not pile hooks up on the model.
    with GradCAMPlusPlus(model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(input_tensor=img_tensor, targets=targets)[0]

    # overlay on RGB (for plotting)
    overlay = show_cam_on_image(rgb, grayscale_cam, use_rgb=True)
    return grayscale_cam, overlay


In [None]:
# Cell 7 – Batch‐Process All Patches & Save CSV + Heatmaps

def batch_explain_patches(
    patch_dir: Path,
    clip_model, clip_preprocess,
    concept_embeddings, concepts: list[str],
    classifier: torch.nn.Module | None,
    cam_layer: torch.nn.Module | None,
    top_k: int = 3,
    classifier_preprocess=None,
):
    """
    Explain every ``*.jpg`` under ``patch_dir`` and write the results to disk.

    For each patch the top concepts are stored as ``concept_i`` / ``score_i``
    columns. When both ``classifier`` and ``cam_layer`` are given, the predicted
    class is recorded and a GradCAM++ heatmap for it is saved as
    ``OUTPUT_DIR / "<stem>_gradcam.npy"``.

    Args:
        patch_dir: Folder searched recursively for ``*.jpg`` files.
        clip_model, clip_preprocess: CLIP model and its image transform.
        concept_embeddings, concepts: Text embeddings and their names.
        classifier: Optional patch classifier for GradCAM++.
        cam_layer: Layer of ``classifier`` to explain.
        top_k: Number of concepts kept per patch.
        classifier_preprocess: Transform applied before the classifier.
            Defaults to ``clip_preprocess`` (the previous behaviour); pass the
            classifier's own transform if it differs.

    Returns:
        DataFrame with one row per successfully processed patch; the same
        table is written to ``OUTPUT_DIR / "patch_explanations.csv"``.
    """
    use_gradcam = classifier is not None and cam_layer is not None
    if use_gradcam:
        # Loop-invariant setup, done once instead of per patch.
        classifier.eval()
        cam_preprocess = clip_preprocess if classifier_preprocess is None else classifier_preprocess
    elif classifier is not None or cam_layer is not None:
        print(
            "⚠️ GradCAM++ needs both `classifier` and `cam_layer`; heatmaps are skipped.",
            file=sys.stderr,
        )

    records = []
    # Sorted so row order is reproducible across runs and file systems.
    # NOTE(review): patch_id is the bare file stem; identical stems in
    # different sub-folders would share an id and heatmap file — confirm.
    for img_path in sorted(patch_dir.rglob("*.jpg")):
        pid = img_path.stem
        try:
            # Close the file handle deterministically, even if decoding fails.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB")

            # concept match
            top_concepts = explain_patch_concepts(
                img, clip_model, clip_preprocess, concept_embeddings, concepts, top_k
            )
            rec = {"patch_id": pid}
            for rank, (concept, score) in enumerate(top_concepts, 1):
                rec[f"concept_{rank}"] = concept
                rec[f"score_{rank}"] = score

            # GradCAM++ (optional)
            if use_gradcam:
                # Explain the class the classifier itself predicts.
                with torch.no_grad():
                    inp = cam_preprocess(img).unsqueeze(0).to(DEVICE)
                    cls = int(classifier(inp).argmax(dim=-1).item())
                cam_map, _ = explain_patch_gradcam(
                    classifier, cam_layer, img, cam_preprocess, cls
                )
                # save heatmap as .npy
                np.save(OUTPUT_DIR / f"{pid}_gradcam.npy", cam_map)
                rec["predicted_class"] = cls
            records.append(rec)

        except Exception as e:
            # Deliberate best-effort: one bad patch must not abort the batch.
            print(f"⚠️ Skipping {pid}: {e}", file=sys.stderr)

    if not records:
        print(f"⚠️ No patches were processed from {patch_dir}", file=sys.stderr)

    df = pd.DataFrame(records)
    csv_path = OUTPUT_DIR / "patch_explanations.csv"
    df.to_csv(csv_path, index=False)
    print(f"✔️ Saved CSV to {csv_path}")
    return df

# Example usage (if you have a classifier and know its target conv layer):
# df = batch_explain_patches(PATCH_DIR, clip_model, clip_preprocess,
#                            concept_embeddings, CONCEPTS,
#                            classifier=clf, cam_layer=clf.layer4[-1], top_k=3)

# Concept-only run: with no classifier / CAM layer the GradCAM++ step is skipped
df = batch_explain_patches(
    patch_dir=PATCH_DIR,
    clip_model=clip_model,
    clip_preprocess=clip_preprocess,
    concept_embeddings=concept_embeddings,
    concepts=CONCEPTS,
    classifier=None,
    cam_layer=None,
    top_k=3,
)
