# VLG-CBM CBL Analysis Notebook
Use this notebook to load a trained VLG-CBM run, visualize the learned Concept Bottleneck Layer (CBL), and reproduce the experiments from Sections 5.2–5.4 of the paper (top activations, decision interpretability, and NEC pruning). Update `RUN_DIR` below to point to your model artifacts.

## Overview
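1. Load the saved artifacts (`config.json`, `concepts.txt`, `concept_layer_best.pt` or `concept_layer.pt`, `W_g.pt`, `b_g.pt`, and the concept normalization statistics)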
2. Recompute concept activations on the validation split and cache image paths
3. Plot the top-5 activated images per concept as in Figure 4
4. Show the top concept contributions for a sample decision, then measure how many predictions change after pruning the final layer to NEC=5 (five concepts per class)


In [None]:
import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from PIL import Image

from dataset import CheXpertDataset, get_transforms
from vlg_cbm_lib.datasets import BackboneWithConcepts, ConceptLayer
from models import get_model, XRV_WEIGHTS

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class XRVDenseNetBackbone(torch.nn.Module):
    """Adapt a model wrapper so that ``forward`` returns ``wrapper.get_features(x)``.

    Defined at module level (rather than inside ``load_backbone_from_config``)
    so one class object is reused across calls and modules built from it can
    be pickled by reference.
    """

    def __init__(self, wrapper):
        super().__init__()
        self.wrapper = wrapper

    def forward(self, x):
        return self.wrapper.get_features(x)


def load_backbone_from_config(config, labels, device):
    """Build the feature-extractor backbone described by a run config.

    Args:
        config: Run configuration dict. Reads ``"backbone"`` (model name) and,
            if present and truthy, ``"backbone_ckpt"`` (checkpoint path).
        labels: Target label names. Their count sets ``num_classes``; for
            backbones listed in ``XRV_WEIGHTS`` they are also passed as
            ``target_labels``.
        device: Device the checkpoint is mapped onto and the backbone moved to.

    Returns:
        ``(backbone, feature_dim)``: the feature extractor on ``device`` and the
        feature dimension handed to the concept layer.
    """
    use_xrv_backbone = config["backbone"] in XRV_WEIGHTS
    backbone_kwargs = {"target_labels": labels} if use_xrv_backbone else {}
    backbone_model = get_model(
        config["backbone"],
        num_classes=len(labels),
        pretrained=True,
        **backbone_kwargs,
    )
    if config.get("backbone_ckpt"):
        import inspect
        import warnings

        load_kwargs = {"map_location": device}
        if "weights_only" in inspect.signature(torch.load).parameters:
            # NOTE: weights_only=False lets torch.load unpickle arbitrary
            # objects; only point "backbone_ckpt" at checkpoints you trust.
            load_kwargs["weights_only"] = False
        ckpt = torch.load(config["backbone_ckpt"], **load_kwargs)
        state = ckpt.get("model_state_dict", ckpt)
        # strict=False tolerates partial checkpoints, but it would also accept
        # a checkpoint whose keys do not match at all (e.g. a "module." prefix)
        # without any error, so surface whatever did not line up.
        result = backbone_model.load_state_dict(state, strict=False)
        missing = list(getattr(result, "missing_keys", []) or [])
        unexpected = list(getattr(result, "unexpected_keys", []) or [])
        if missing or unexpected:
            warnings.warn(
                f"Backbone checkpoint {config['backbone_ckpt']!r} loaded with "
                f"{len(missing)} missing and {len(unexpected)} unexpected keys "
                f"(e.g. missing={missing[:3]}, unexpected={unexpected[:3]})",
                stacklevel=2,
            )
    if config["backbone"] == "densenet121":
        feature_dim = 1024
        backbone = backbone_model.backbone.features
    elif config["backbone"] == "resnet50":
        feature_dim = 2048
        # Keep every child module except the last one (the classification
        # head for torchvision-style ResNets).
        backbone = torch.nn.Sequential(*list(backbone_model.backbone.children())[:-1])
    else:
        feature_dim = getattr(backbone_model, "feature_dim", 1024)
        backbone = XRVDenseNetBackbone(backbone_model)
    return backbone.to(device), feature_dim

class IndexedCheXpertDataset(CheXpertDataset):
    """CheXpert dataset that also yields each sample's dataset index."""

    def __getitem__(self, idx):
        # Returning the index lets callers map cached activations back to
        # dataset entries (e.g. via get_image_path).
        sample_image, sample_label = super().__getitem__(idx)
        index_tensor = torch.tensor(idx, dtype=torch.long)
        return sample_image, sample_label, index_tensor

def load_run_artifacts(run_dir):
    """Load a trained VLG-CBM run from ``run_dir``.

    Reads ``config.json`` and ``concepts.txt``, rebuilds the backbone and the
    concept layer (preferring ``concept_layer_best.pt`` over
    ``concept_layer.pt``), and loads the final-layer weights and the concept
    normalization statistics.

    Args:
        run_dir: Directory containing the run artifacts (str or Path).

    Returns:
        ``(config, concepts, labels, model, W, b, mean, std)``. ``model`` is the
        backbone + concept layer in eval mode on ``DEVICE``. ``W``, ``b``,
        ``mean`` and ``std`` are the objects stored in ``W_g.pt``, ``b_g.pt``,
        ``concept_mean.pt`` and ``concept_std.pt``, mapped onto the CPU.
    """
    run_dir = Path(run_dir)
    config = json.loads((run_dir / "config.json").read_text(encoding="utf-8"))
    config["output"] = str(run_dir)
    concept_lines = (run_dir / "concepts.txt").read_text(encoding="utf-8").splitlines()
    concepts = [line.strip() for line in concept_lines if line.strip()]
    labels = config["labels"]
    backbone, feature_dim = load_backbone_from_config(config, labels, DEVICE)
    concept_layer = ConceptLayer(
        feature_dim, len(concepts), num_hidden=config.get("cbl_hidden_layers", 1)
    )
    ckpt_path = run_dir / "concept_layer_best.pt"
    if not ckpt_path.exists():
        ckpt_path = run_dir / "concept_layer.pt"
    concept_layer.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    model = BackboneWithConcepts(backbone, concept_layer).to(DEVICE)
    model.eval()
    # The final layer is applied in this notebook to concept activations that
    # were moved to the CPU. Without map_location, tensors saved from a GPU are
    # restored onto the GPU (device mismatch) or fail on a CPU-only machine.
    W = torch.load(run_dir / "W_g.pt", map_location="cpu")
    b = torch.load(run_dir / "b_g.pt", map_location="cpu")
    mean = torch.load(run_dir / "concept_mean.pt", map_location="cpu")
    std = torch.load(run_dir / "concept_std.pt", map_location="cpu")
    return config, concepts, labels, model, W, b, mean, std

def make_dataset(config, split="valid", split_csv=None):
    """Build an indexed CheXpert dataset for one split with evaluation transforms.

    Args:
        config: Run configuration. Reads ``"data_dir"`` and ``"labels"``, and
            optionally ``"uncertain_strategy"`` and ``"frontal_only"``.
        split: Split name; the CSV ``<data_dir>/<split>.csv`` is used unless
            ``split_csv`` is given.
        split_csv: Explicit CSV path that overrides ``split``.

    Returns:
        An ``IndexedCheXpertDataset`` whose image root is the parent of
        ``data_dir``.
    """
    data_dir = Path(config["data_dir"])
    csv_path = data_dir / f"{split}.csv" if split_csv is None else Path(split_csv)
    return IndexedCheXpertDataset(
        csv_path=str(csv_path),
        img_root=str(data_dir.parent),
        transform=get_transforms(224, is_training=False),
        labels=config["labels"],
        uncertain_strategy=config.get("uncertain_strategy", "ones"),
        frontal_only=config.get("frontal_only", True),
    )

def _stack_collate(batch):
    """Stack a list of ``(image, label, index)`` tensor triples into batch tensors."""
    images, labels, idxs = zip(*batch)
    return torch.stack(images), torch.stack(labels), torch.stack(idxs)


def make_loader(dataset, batch_size=64, num_workers=4):
    """Wrap ``dataset`` in a deterministic (non-shuffled) DataLoader.

    Args:
        dataset: Map-style dataset yielding ``(image, label, index)`` tensors.
        batch_size: Samples per batch.
        num_workers: Worker processes for loading. Set to 0 to load in the
            main process, e.g. in notebooks on platforms that start workers
            with "spawn", where objects defined in the notebook cannot be
            re-imported by the workers.

    Returns:
        A ``DataLoader`` yielding ``(images, labels, idxs)`` stacked tensors.
    """
    # The collate function is module-level (not a closure) so it can be
    # pickled by reference when worker processes need it.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=_stack_collate,
    )

def compute_activations(model, loader):
    """Run ``model`` over ``loader`` and gather its outputs on the CPU.

    Args:
        model: Module applied to each image batch (moved to ``DEVICE`` first);
            put into eval mode and run without gradient tracking.
        loader: Iterable of ``(images, labels, idxs)`` batches.

    Returns:
        ``(outputs, labels, idxs)``, each concatenated along the batch
        dimension in loader order; ``outputs`` are moved to the CPU.
    """
    model.eval()
    output_chunks, label_chunks, index_chunks = [], [], []
    with torch.no_grad():
        for batch_images, batch_labels, batch_idxs in loader:
            output_chunks.append(model(batch_images.to(DEVICE)).cpu())
            label_chunks.append(batch_labels)
            index_chunks.append(batch_idxs)
    return (
        torch.cat(output_chunks),
        torch.cat(label_chunks),
        torch.cat(index_chunks),
    )

def normalize_concepts(concepts, mean, std):
    """Standardize concept activations as ``(concepts - mean) / std``.

    ``std`` is floored at 1e-6 so constant concepts do not cause a division
    by zero.
    """
    safe_std = std.clamp(min=1e-6)
    centered = concepts - mean
    return centered / safe_std

def final_logits(concepts, mean, std, W, b):
    """Apply the final linear layer to standardized concept activations.

    Computes ``normalize_concepts(concepts, mean, std) @ W.t() + b``, so each
    row of ``W`` holds one class's concept weights.
    """
    standardized = normalize_concepts(concepts, mean, std)
    weighted = standardized.matmul(W.t())
    return weighted + b

def compute_contributions(concepts, W):
    """Return per-class, per-concept contribution scores.

    Each entry is ``sigmoid(concept activation) * weight``. For 2-D
    ``concepts`` (samples x concepts) and ``W`` (classes x concepts) the
    result has shape (samples, classes, concepts).

    Note: ``final_logits`` feeds standardized activations (not sigmoid scores)
    into ``W``, so these scores do not sum to the model's logits.
    """
    concept_scores = concepts.sigmoid()
    per_sample = concept_scores.unsqueeze(1)
    per_class = W.unsqueeze(0)
    return per_sample * per_class

def prune_weights(W, topk=5):
    """Keep only the ``topk`` largest-magnitude weights in each row of ``W``.

    Args:
        W: 2-D weight matrix; each row (one class, for this notebook's final
            layer) is pruned independently.
        topk: Entries to keep per row. Values larger than the number of
            columns keep every entry instead of raising.

    Returns:
        A new tensor shaped like ``W`` containing the kept weights (signs
        preserved) and zeros elsewhere. ``W`` itself is not modified.

    Raises:
        ValueError: If ``topk`` is negative.
    """
    if topk < 0:
        raise ValueError(f"topk must be non-negative, got {topk}")
    pruned = torch.zeros_like(W)
    k = min(topk, W.size(1))
    if k == 0:
        return pruned
    # Rank by magnitude, then copy the signed original values into place.
    keep = W.abs().topk(k, dim=1).indices
    pruned.scatter_(1, keep, W.gather(1, keep))
    return pruned

def display_images(paths, titles=None, figsize=(14, 3)):
    """Show images side by side in a single row of axes.

    Args:
        paths: Image file paths, displayed left to right. Nothing is drawn
            when empty.
        titles: Optional per-image titles. If fewer titles than paths are
            given, the remaining images are shown untitled rather than dropped.
        figsize: Figure size passed to ``plt.subplots``.
    """
    paths = list(paths)
    if not paths:
        # plt.subplots cannot create a row with zero columns.
        return
    titles = list(titles or [])
    titles += [None] * (len(paths) - len(titles))
    # squeeze=False always yields a 2-D axes array, even for a single image.
    _, axes = plt.subplots(1, len(paths), figsize=figsize, squeeze=False)
    for ax, path, title in zip(axes[0], paths, titles):
        # Close the underlying file once the pixels have been converted.
        with Image.open(path) as img:
            ax.imshow(img.convert("RGB"))
        ax.axis("off")
        if title:
            ax.set_title(title)
    plt.tight_layout()

def plot_top_images_for_concept(concept_idx, activations, idxs, dataset, concept_names, topk=5):
    """Display the ``topk`` images whose activation for one concept is highest.

    Args:
        concept_idx: Column of ``activations`` (and entry of ``concept_names``).
        activations: Activation matrix with one row per sample.
        idxs: Dataset index for each row of ``activations``.
        dataset: Dataset providing ``get_image_path(index)``.
        concept_names: Names used in the image titles.
        topk: Number of images to show (fewer if there are fewer rows).
    """
    scores = activations[:, concept_idx]
    chosen = scores.argsort(descending=True)[:topk]
    image_paths = [dataset.get_image_path(int(idxs[row])) for row in chosen]
    captions = [f"{concept_names[concept_idx]} ({scores[row].item():.3f})" for row in chosen]
    display_images(image_paths, captions)


In [None]:
# Directory of a finished training run: it must contain config.json, concepts.txt,
# concept_layer_best.pt or concept_layer.pt, W_g.pt, b_g.pt, concept_mean.pt and concept_std.pt.
RUN_DIR = Path("saved_models/vlg_cbm_exp5")  # update to your checkpoint directory
config, concepts, labels, model, W, b, mean, std = load_run_artifacts(RUN_DIR)
# Validation split, iterated in a fixed order (the loader does not shuffle).
dataset = make_dataset(config, split="valid")
loader = make_loader(dataset, batch_size=64)
# Raw CBL outputs for every validation image, plus labels and dataset indices;
# the indices map activation rows back to image paths.
activations, pathology_labels, idxs = compute_activations(model, loader)
# Final-layer logits and per-label probabilities.
logits = final_logits(activations, mean, std, W, b)
probs = torch.sigmoid(logits)
# Per-concept prediction scores (sigmoid of the raw CBL outputs).
concept_predictions = torch.sigmoid(activations)
print("Loaded run:", RUN_DIR)
print("Activation matrix shape:", activations.shape)
print("Dataset size:", len(dataset))


## Top-5 Activated Images per Concept
Pick a concept of interest (e.g., one that matches a VinDr tag) and inspect the five images with the highest CBL activation for it.

In [None]:
# Index into `concepts` (and the matching column of `activations`).
concept_idx = 0  # update with the concept you want to inspect
# Shows the five validation images with the highest raw CBL activation for this concept.
plot_top_images_for_concept(concept_idx, activations, idxs, dataset, concepts, topk=5)
print("Concept name:", concepts[concept_idx])


## Decision Interpretability Case Study
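Examine the top-5 concept contributions to the predicted class of a validation image. Each contribution is the concept prediction score (the sigmoid of the CBL activation) multiplied by the corresponding weight in `W_g`.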


In [None]:
def top_contributions(sample_idx, topk=5):
    """Return the strongest concept contributions to one sample's top prediction.

    Uses the notebook globals ``activations``, ``W`` and ``probs``. The
    predicted class is the label with the highest probability for the sample.

    Args:
        sample_idx: Row of ``activations`` / ``probs`` to explain.
        topk: Number of contributions to return, ranked by absolute value.

    Returns:
        ``(predicted_class, contributions, concept_indices)``.
    """
    predicted_class = torch.argmax(probs[sample_idx]).item()
    # Only this sample's row is needed. Passing all of `activations` would
    # build a (samples x classes x concepts) tensor just to read one slice.
    sample_contribs = compute_contributions(activations[sample_idx].unsqueeze(0), W)[0]
    class_contribs = sample_contribs[predicted_class]
    topk_idxs = torch.argsort(class_contribs.abs(), descending=True)[:topk]
    return predicted_class, class_contribs[topk_idxs], topk_idxs

sample_idx = 0  # change to another sample to inspect
pred_class, contribs, contrib_idxs = top_contributions(sample_idx)
# Look the image path up once and reuse it for printing and plotting.
sample_path = dataset.get_image_path(int(idxs[sample_idx]))
print("Sample image:", sample_path)
display_images([sample_path], [f"Predicted class: {labels[pred_class]}"], figsize=(5, 5))
for rank, (concept_i, score) in enumerate(zip(contrib_idxs.tolist(), contribs.tolist()), 1):
    print(f"{rank}. {concepts[concept_i]} -> contribution {score:.4f}")


## Top-5 Pruning Experiment
Prune the final weight matrix `W_g` so that each class keeps only its five largest-magnitude concept weights. Then measure the percentage of validation samples whose binarized predictions (probability ≥ 0.5) change for at least one label (Section 5.4).

In [None]:
# Keep only the five largest-magnitude weights in each class row of W_g; the bias b is unchanged.
pruned_W = prune_weights(W, topk=5)
pruned_logits = final_logits(activations, mean, std, pruned_W, b)
# Binarize every label probability at 0.5 for both the pruned and the original layer.
pruned_preds = (torch.sigmoid(pruned_logits) >= 0.5)
orig_preds = (probs >= 0.5)
# A sample counts as changed if any of its label predictions flips; reported as a percentage.
changed = (orig_preds != pruned_preds).any(dim=1).float().mean().item() * 100
print(f"% of samples with changed decisions after pruning to top-5 concepts: {changed:.2f}%")
