In [2]:
%load_ext autoreload
%autoreload 2
from shared_modules.data_module import DataModule
from shared_modules.utils import load_config
from shared_modules.plotting import slice_comparison_multi
from shared_modules.xai import normalize_and_clamp, compute_ablation_cam_3d, compute_ablation_cam_3d_direct

from trainer import LitModel
import torch
from tqdm import tqdm
from monai.transforms import ScaleIntensityRangePercentiles
from captum.attr import Occlusion
from captum.attr import Saliency
from pathlib import Path

`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.


If you have questions or suggestions, feel free to open an issue at https://github.com/DIAGNijmegen/picai_eval



Please cite the following paper when using Report Guided Annotations:

Bosma, J.S., et al. "Semi-supervised learning with report-guided lesion annotation for deep learning-based prostate cancer detection in bpMRI" to be submitted


If you have questions or suggestions, feel free to open an issue at https://github.com/DIAGNijmegen/Report-Guided-Annotation



In [3]:
# Output folder for saving images
OUTPUT_DIR = Path("xai_outputs")
OUTPUT_DIR.mkdir(exist_ok=True)


# Settings:
SAVE_PREDS=False
SAVE_PROB_MAPS=False
dataset="picai" 
label_key = "pca"
config = load_config("config.yaml")
gpu = 0
config.gpus = [gpu]
config.cache_rate = 1.0
config.transforms.label_keys = ["pca", "prostate_pred", "zones"]
checkpoint_path = "/cluster/home/bragehk/U-MambaMTL-XAI/gc_algorithms/base_container/models/umamba_mtl/weights/f0.ckpt"
model = LitModel.load_from_checkpoint(checkpoint_path, config=config)

model = model.eval()
model.to(gpu)

LitModel(
  (model): UMambaBotMTL(
    (encoder): UNetResEncoder(
      (stem): Sequential(
        (0): BasicResBlock(
          (conv1): Conv3d(3, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
          (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
          (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv3d(3, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
            (norm): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            

In [4]:
case_id = 13
dm = DataModule(
    config=config,
    debug_index=case_id
)
dm.setup("debug")
dl = dm.debug_dataloader()

monai.transforms.spatial.dictionary Orientationd.__init__:labels: Current default value of argument `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` was changed in version None from `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` to `labels=None`. Default value changed to None meaning that the transform now uses the 'space' of a meta-tensor, if applicable, to determine appropriate axis labels.


[{'image': ['/cluster/projects/vc/data/mic/open/Prostate/PI-CAI-V2.0/workdir/nnUNet_raw_data/Task2203_picai_baseline/imagesTr/10040_1000040_0000.nii.gz', '/cluster/projects/vc/data/mic/open/Prostate/PI-CAI-V2.0/workdir/nnUNet_raw_data/Task2203_picai_baseline/imagesTr/10040_1000040_0001.nii.gz', '/cluster/projects/vc/data/mic/open/Prostate/PI-CAI-V2.0/workdir/nnUNet_raw_data/Task2203_picai_baseline/imagesTr/10040_1000040_0002.nii.gz'], 'prostate': '/cluster/projects/vc/data/mic/open/Prostate/PI-CAI-V2.0/picai_labels/anatomical_delineations/whole_gland/AI/Bosma22b/10040_1000040.nii.gz', 'zones': '/cluster/projects/vc/data/mic/open/Prostate/PI-CAI-V2.0/picai_labels/anatomical_delineations/zonal_pz_tz/AI/Yuan23/10040_1000040.nii.gz', 'pca': '/cluster/projects/vc/data/mic/open/Prostate/PI-CAI-V2.0/workdir/nnUNet_raw_data/Task2203_picai_baseline/labelsTr/10040_1000040.nii.gz', 'case_pca': 1, 'prostate_pred': '/cluster/projects/vc/data/mic/open/Prostate/PI-CAI-V2.0/UmambaBot_prostate/preds/10

Loading dataset: 100%|██████████| 1/1 [00:01<00:00,  1.16s/it]


In [None]:
def agg_segmentation_wrapper(inp):
    model_out = model(inp)  
    out_max = model_out.argmax(dim=1, keepdim=True)
    selected_inds = torch.zeros_like(model_out).scatter_(1, out_max, 1)
    aggregated_logits = (model_out * selected_inds).sum(dim=(2, 3, 4))  
    return aggregated_logits


occlusion = Occlusion(agg_segmentation_wrapper)
attribute_fn = Saliency(agg_segmentation_wrapper)
size = 22

sliding_window_shapes = (1, size, size, 2) 
strides = (1, size, size, 1)               
baselines = 0                          
perturbations_per_eval = 1

In [7]:
import torch

x = torch.tensor([2,3,4])
x = (x.repeat(3,3,3))
print(torch.sum(x))

tensor(243)


In [31]:
tz = 0
pz = 0

logits = None

for batch in tqdm(dl):
    if batch["pca"].max() == 0:
        print("Not any PCa here!")
        import sys
        sys.exit(0)
        continue
    
    print("To gpu")
    x = batch["image"].to(gpu)
    print("Model inference")
    logits = model(x)
    print("Check if False negative")
    if (torch.sigmoid(logits[:, 1]) > 0.5).any().item():
        print("True positive")
        import time
        t0 = time.time()
        attention_map = attribute_fn.attribute(x, target=1, abs=True)
        print(f"Saliency took {time.time() - t0:.2f}s")
        t0 = time.time()
        occlusion_map = occlusion.attribute(
            x,
            sliding_window_shapes=sliding_window_shapes,
            strides=strides,
            baselines=baselines,
            target=1,
            perturbations_per_eval=perturbations_per_eval,
            show_progress=True
        )
        print(f"Occlusion took {time.time() - t0:.2f}s")
    else:
        print("False negative..")
        print((torch.sigmoid(logits) > 0.5).int()[0][1][None].to("cpu")[0,1].max().item())

        
    pca_in_pz = (batch["pca"] * batch["zones"] == 1).sum()
    pca_in_tz = (batch["pca"] * batch["zones"] == 2).sum()
    
    break

  0%|          | 0/1 [00:00<?, ?it/s]

To gpu
Model inference
Check if False negative
True positive
Saliency took 0.04s


Occlusion attribution:   0%|          | 0/3932161 [00:00<?, ?it/s]

  0%|          | 0/1 [06:00<?, ?it/s]

Occlusion took 360.29s





In [32]:
# Save
torch.save(occlusion_map, OUTPUT_DIR / f"{case_id}_occlusion_map.pt")
torch.save(attention_map, OUTPUT_DIR / f"{case_id}_attention_map.pt")

In [27]:
# Load
occlusion_map = torch.load(OUTPUT_DIR / f"{case_id}_occlusion_map.pt", weights_only=False)
attention_map = torch.load(OUTPUT_DIR / f"{case_id}_attention_map.pt", weights_only=False)

In [17]:
for batch in tqdm(dl):
    print("To gpu")
    x = batch["image"].to(gpu)
    print("Model inference")
    logits = model(x)
    break

  0%|          | 0/1 [00:00<?, ?it/s]

To gpu
Model inference





In [34]:
img = batch["image"][0]
gt = batch["pca"][0]
pred = (torch.sigmoid(logits) > 0.5).int()[0][1][None].to("cpu").float()
logit = logits[0][1][None].to("cpu")

print(f"gt shape: {gt.shape}, max: {gt.max()}, min: {gt.min()}")
print(f"pred shape: {pred.shape}, max: {pred.max()}, min: {pred.min()}")

activation = attention_map[0].to("cpu")

occlusion = occlusion_map[0].to("cpu")

occ = ScaleIntensityRangePercentiles(lower=.1, upper=99.9, b_min=-1, b_max=1, clip=True)(occlusion)
acc = ScaleIntensityRangePercentiles(lower=.1, upper=99.9, b_min=-1, b_max=1, clip=True)(activation)

#print("Activation map") #print("max", acc.max()) #print("min", acc.min())
acc = normalize_and_clamp(acc)
occ = normalize_and_clamp(occ)
#print("Acc After normalization") #print("max", acc.max()) #print("min", acc.min()) #print("Occlusion map") #print("max", occ.max()) #print("min", occ.min()) occ = normalize_and_clamp(occ) #print("occ After normalization") #print("max", occ.max()) #print("min", occ.min()) # Find slice with most label pixels (minimum 10 pixels)
min_pixels = 10
slice_pixel_counts = gt[0].sum(dim=(0, 1))  # Sum over H, W for each slice
best_slice_idx = slice_pixel_counts.argmax().item()
if slice_pixel_counts[best_slice_idx] < min_pixels:
    print(f"Warning: Best slice only has {slice_pixel_counts[best_slice_idx]} label pixels")
print(f"Using slice {best_slice_idx} with {slice_pixel_counts[best_slice_idx].item()} label pixels")

print(f"More pca in pz: {pca_in_pz > pca_in_tz}") 
print(f"pca in tz: {pca_in_tz}") 
print(f"pca in pz: {pca_in_pz}")
case_id = batch["image"].meta["filename_or_obj"][0].split("/")[-1]
print(f"Case ID: {case_id}")
confidence = round(torch.sigmoid(logits)[0,1].max().item() * 100, 2)
print(f"PCa confidence: {confidence}%")

slice_comparison_multi(image=ScaleIntensityRangePercentiles(lower=0, upper=100, b_min=0, b_max=1)(img), labels=[gt, pred, acc, occ], titles=["original", "Ground Truth","Prediction", "Saliency", "Occlusion"])

gt shape: torch.Size([1, 256, 256, 20]), max: 1.0, min: 0.0
pred shape: torch.Size([1, 256, 256, 20]), max: 1.0, min: 0.0
Using slice 7 with 840.0 label pixels
More pca in pz: True
pca in tz: 0
pca in pz: 2451
Case ID: 10040_1000040_0000.nii.gz
PCa confidence: 89.19%


interactive(children=(IntSlider(value=0, description='Slice Index:', max=19), Output()), _dom_classes=('widget…

In [35]:
from shared_modules.plotting import slice_comparison_multi_gif

slice_comparison_multi_gif(image=ScaleIntensityRangePercentiles(lower=0, upper=100, b_min=0, b_max=1)(img), labels=[gt, pred, acc, occ], titles=["original", "Ground Truth","Prediction", "Saliency", "Occlusion"])

Saved pendulum GIF (38 frames) to slice_comparison.gif


'slice_comparison.gif'

In [None]:
# Then your function call
cam_volume = compute_ablation_cam_3d_direct(
    model=model,
    input_tensor=x,
    target_category=1,
    target_mask=pred[0].cpu().numpy(),
    device=gpu,
    ablation_size=(16, 16, 5),
    stride=(16, 16, 1),
)


AblationCAM 3D: 100%|██████████| 16/16 [05:18<00:00, 19.91s/it]


In [29]:
cam_volume.shape

(256, 256, 20)

In [40]:
img.shape

torch.Size([3, 256, 256, 20])

In [None]:
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt

def view_volume(volume, title="Volume", cmap="viridis", vmin=None, vmax=None):
    """Interactive viewer for 3D volume with slider."""
    if vmin is None:
        vmin = volume.min()
    if vmax is None:
        vmax = volume.max()
    
    @interact(slice_idx=IntSlider(min=0, max=volume.shape[-1]-1, value=volume.shape[-1]//2, description='Slice:'))
    def compare_slices(slice_idx):
        fig, axes = plt.subplots(1, 4, figsize=(15, 5))
    
        # Original image (e.g., T2W channel)
        axes[0].imshow(img[0, :, :, slice_idx].cpu(), cmap='gray')
        axes[0].set_title('T2W')
    
        # Ground truth
        axes[1].imshow(gt[0, :, :, slice_idx].cpu(), cmap='Reds', vmin=0, vmax=1)
        axes[1].set_title('Ground Truth')
        
        # Prediction
        axes[2].imshow(pred[0, :, :, slice_idx].cpu(), cmap='Reds', vmin=0, vmax=1)
        axes[2].set_title('Prediction')
    
        # AblationCAM
        im = axes[3].imshow(cam_volume[:, :, slice_idx], cmap='hot', vmin=vmin, vmax=vmax)
        axes[3].set_title('AblationCAM')
        plt.colorbar(im, ax=axes[3])
    
        for ax in axes:
            ax.axis('off')
        plt.tight_layout()
        plt.show()


view_volume(cam_volume, title="AblationCAM", cmap="hot")


interactive(children=(IntSlider(value=10, description='Slice:', max=19), Output()), _dom_classes=('widget-inte…

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import io

def create_volume_gif(img, gt, pred, cam_volume, output_path="volume.gif", fps=5, dpi=100, ping_pong=False):
    """
    Create a GIF animation of the volume slices.
    
    Parameters:
    -----------
    img : tensor
        Input image tensor (channels, H, W, D)
    gt : tensor
        Ground truth tensor (channels, H, W, D)
    pred : tensor
        Prediction tensor (channels, H, W, D), values 0-1
    cam_volume : array
        CAM volume array (H, W, D)
    output_path : str
        Output path for the GIF file
    fps : int
        Frames per second for the GIF
    dpi : int
        Resolution of each frame
    ping_pong : bool
        If True, play forward then backward for smooth looping
    """
    # Get global min/max for consistent color scaling
    vmin = cam_volume.min()
    vmax = cam_volume.max()
    
    num_slices = cam_volume.shape[-1]
    frames = []
    
    for slice_idx in range(num_slices):
        fig, axes = plt.subplots(1, 4, figsize=(20, 5))
        
        # Original image (e.g., T2W channel)
        axes[0].imshow(img[0, :, :, slice_idx].cpu(), cmap='gray')
        axes[0].set_title('T2W')
        
        # Ground truth
        axes[1].imshow(gt[0, :, :, slice_idx].cpu(), cmap='Reds', vmin=0, vmax=1)
        axes[1].set_title('Ground Truth')
        
        # Prediction
        axes[2].imshow(pred[0, :, :, slice_idx].cpu(), cmap='Reds', vmin=0, vmax=1)
        axes[2].set_title('Prediction')
        
        # AblationCAM
        im = axes[3].imshow(cam_volume[:, :, slice_idx], cmap='hot', vmin=vmin, vmax=vmax)
        axes[3].set_title('AblationCAM')
        plt.colorbar(im, ax=axes[3])
        
        for ax in axes:
            ax.axis('off')
        
        # Add slice indicator
        fig.suptitle(f'Slice {slice_idx + 1}/{num_slices}', fontsize=12)
        plt.tight_layout()
        
        # Convert figure to PIL Image
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight')
        buf.seek(0)
        frame = Image.open(buf).convert('RGB')
        frames.append(frame.copy())
        buf.close()
        plt.close(fig)
    
    # Add reverse frames for ping-pong effect
    if ping_pong:
        frames = frames + frames[-2:0:-1]
    
    # Save as GIF
    duration = int(1000 / fps)  # Duration per frame in milliseconds
    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0  # 0 means infinite loop
    )
    
    print(f"GIF saved to {output_path} ({len(frames)} frames)")
    return output_path

In [46]:
create_volume_gif(img, gt, pred, cam_volume, output_path="ablation_cam.gif", fps=5, ping_pong=True)

GIF saved to ablation_cam.gif (38 frames)


'ablation_cam.gif'

In [41]:
cam_volume_2 = compute_ablation_cam_3d(
    model=model,
    input_tensor=x,
    target_category=1,  # PCa channel
    target_mask=pred[0].cpu().numpy(),  # Focus on predicted region
    device=gpu
)

Computing AblationCAM per slice:  20%|██        | 4/20 [00:01<00:05,  2.78it/s]



Computing AblationCAM per slice:  30%|███       | 6/20 [00:01<00:03,  3.62it/s]



Computing AblationCAM per slice: 100%|██████████| 20/20 [00:02<00:00,  9.41it/s]

made  20  slices





In [None]:
view_volume(cam_volume_2)

interactive(children=(IntSlider(value=10, description='Slice:', max=19), Output()), _dom_classes=('widget-inte…

In [67]:
from shared_modules.xai import compute_xai_metrics_segmentation

metrics = compute_xai_metrics_segmentation(
    model=model,
    inputs=x,
    attributions=attention_map,
    target=1,  # PCa class
    aggregation="predicted",  # or "masked" with mask=your_mask
    n_perturb_samples=100,
    agg_wrapper=agg_segmentation_wrapper,
    perturb_radius=0.2
)

print(metrics)

{'infidelity': np.float64(19.025664772666417), 'sensitivity': 0.4467165797449239}


{'infidelity': 109.40544798374177, 'sensitivity': np.float64(0.009333162949272234)}
