In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
import os
import json
import pandas as pd
import torchvision

from scipy.stats import rankdata
from collections import defaultdict

from load_data import DATA_DIR, TRANSFORMS, P53_CLASS_NAMES, \
    convert_presence_probs_to_status_probs
from resnet import ResNetModel, ResNetModelDoubleBinary
# from resnet_patch import ResNetModelDoubleBinary as ResNetModelDoubleBinaryPatch
from pl_clam import CLAM_MB, CLAM_db

# Short codes for the four p53 classes.
P53_CLASS_CODES = ["WT", "OE", "NM", "DC"]

# Dataset roots; BOLERO lives next to DATA_DIR.
BOLERO_DIR = os.path.join(DATA_DIR, '..', 'BOLERO')
# PATHXL_DIR = os.path.join(DATA_DIR, '..', 'p53_consensus_study')

# data_name -> dataset root (add 'pathxl': PATHXL_DIR to re-enable that dataset).
BASE_DIR = dict(test=DATA_DIR, bolero=BOLERO_DIR)

# Output / model folders, two levels above DATA_DIR. Only RESULTS_DIR is created here.
RESULTS_DIR, VIS_DIR, MODELS_DIR = (
    os.path.join(DATA_DIR, '..', '..', folder)
    for folder in ('results', 'visualizations', 'models')
)
os.makedirs(RESULTS_DIR, exist_ok=True)

# data_name -> file with the pre-extracted bag latents of that dataset.
bag_latent_paths = {
    dataset: os.path.join(dataset_dir, "bag_latents_gs256_retccl.pt")
    for dataset, dataset_dir in BASE_DIR.items()
}

# Device configuration: use CUDA when torch can see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device: {}".format(device))

# Plot palette keyed by a short colour code; the "*_p" entries are the pale variants.
color_dict = dict(
    r='#DA4C4C',     # Red
    o='#E57439',     # Orange
    y='#EDB732',     # Yellow
    g='#479A5F',     # Green
    lb='#5BC5DB',    # Light blue
    b='#5387DD',     # Blue
    p='#7D54B2',     # Purple
    pi='#E87B9F',    # Pink
    r_p='#E89393',   # Pale red
    o_p='#EFAB88',   # Pale orange
    y_p='#F4D384',   # Pale yellow
    g_p='#90C29F',   # Pale green
    lb_p='#9CDCE9',  # Pale light blue
    b_p='#98B7EA',   # Pale blue
    p_p="#B198D0",   # Pale purple
)
# Same colours as a plain list, in the order declared above.
colors = [*color_dict.values()]

# Load Models

In [2]:
# fb: full biopsy, db: double binary, gs: grid spacing
# CLAM variants first: each gets its model class and a grid spacing of 256.
model_kwargs = {
    clam_name: {"model_class": clam_class, "gs": 256}
    for clam_name, clam_class in [
        ("CLAM", CLAM_MB),
        ("CLAM_db", CLAM_db),
        ("CLAM_m", CLAM_MB),
        ("CLAM_db_m", CLAM_db),
    ]
}

# Full-biopsy variants, one entry per (name, spacing) pair.
# Further spacings that could be enabled: 8, 16, 32, 64, 128, 256
for name in ["fb_db", "fb"]:
    for spacing in [2, 4]:
        # The plain (non double-binary) model at spacing 2 is skipped because it's too slow
        if spacing == 2 and "db" not in name:
            continue
        model_name = f"{name}_spacing{spacing}"
        model_kwargs[model_name] = {"spacing": spacing}

# Attach the model class of the fb variants and collect every model's checkpoints.
for model_name, entry in model_kwargs.items():
    if "fb_db" in model_name:
        entry["model_class"] = ResNetModelDoubleBinary
    elif "fb" in model_name:
        entry["model_class"] = ResNetModel
    checkpoint_dir = os.path.join(MODELS_DIR, model_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    # NOTE(review): os.listdir returns entries in arbitrary order, yet this list is later
    # indexed with the fixed positions in best_checkpoint_dict. Confirm the order is stable
    # on every machine (or key the best checkpoint by file name instead).
    entry["checkpoint_paths"] = [
        os.path.join(checkpoint_dir, file_name)
        for file_name in os.listdir(checkpoint_dir)
        if file_name.endswith(".ckpt")
    ]

def load_model(model_class, checkpoint_path):
    """
    Instantiate `model_class` from a checkpoint and return it in eval mode on `device`.

    Classes whose name contains "CLAM" are built with their default constructor and
    receive the checkpoint's "state_dict" (the original author noted that
    `load_from_checkpoint` does not work for them). Every other class is restored
    through its `load_from_checkpoint` classmethod.

    Both branches map the stored tensors onto `device`, so a checkpoint written on a
    GPU machine also loads on a CPU-only one.
    """
    if "CLAM" in model_class.__name__:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # model = model_class(**checkpoint["model_kwargs"])
        model = model_class()
        model.load_state_dict(checkpoint["state_dict"])
    else:
        # map_location added for consistency with the CLAM branch above
        model = model_class.load_from_checkpoint(checkpoint_path, map_location=device)
    model.eval()
    model.to(device)
    return model

In [None]:
# Report the number of checkpoint files found for every registered model
for model_name, entry in model_kwargs.items():
    print(f"{model_name}: {len(entry['checkpoint_paths'])} checkpoints")

In [4]:
# model name -> position of the preferred checkpoint. The value is used as a list index
# into model_kwargs[name]["checkpoint_paths"] and into the result files listed by
# load_attention_maps, so it is only meaningful for the listing order of those folders.
best_checkpoint_dict = dict(
    CLAM=4,
    CLAM_db=0,
    CLAM_m=3,
    CLAM_db_m=2,
    fb_spacing4=4,
    fb_db_spacing4=3,
)

# Inference functions

In [5]:
def load_img(img_name, data_name):
    """
    Read `<img_name>.png` from the `biopsies` folder of dataset `data_name`.

    Returns the array produced by `plt.imread` (per the original note: (H, W, C) float32).
    """
    biopsy_dir = os.path.join(BASE_DIR[data_name], 'biopsies')
    return plt.imread(os.path.join(biopsy_dir, f"{img_name}.png"))
    
def load_patch_latents(img_name, data_name, bag_latents):
    """
    Look up the patch latents of one biopsy in the pre-loaded `bag_latents` container.

    - "test" / "pathxl": `bag_latents[img_name]` with its second axis squeezed away.
    - "bolero": `img_name` is "<slide>_<biopsy>"; the slide entry is indexed with the
      integer slide number and the biopsy with `biopsy + 1`.
      NOTE(review): the +1 offset is taken over unchanged from the original code;
      presumably entry 0 of a slide is not a biopsy - confirm against the latent files.

    Returns (N, 2048) tensor (shape as stated by the original author).

    Raises:
        ValueError: for any other `data_name` (previously None was returned silently,
            which only failed later inside the caller).
    """
    if data_name in ("test", "pathxl"):
        return bag_latents[img_name].squeeze(1) # (N, 2048)
    if data_name == "bolero":
        slide_name, biopsy_name = img_name.split("_")
        slide_latents = bag_latents[int(slide_name)] # (n_biopsies, N, 2048)
        return slide_latents[int(biopsy_name) + 1] # (N, 2048)
    raise ValueError(f"Unknown data_name {data_name!r} for patch latents")
    
def get_img_const_spacing(img_name, data_name, spacing):
    """
    Load a biopsy, shrink both sides by the integer factor `spacing` with bilinear
    interpolation and apply the shared 'normalize' transform.

    Returns (1, C, H, W) tensor with constant spacing (spacing x spacing pixels per grid cell)
    """
    pixels = torch.tensor(load_img(img_name, data_name))  # (H, W, C)
    target_size = (pixels.shape[0] // spacing, pixels.shape[1] // spacing)
    batch = pixels.permute(2, 0, 1).unsqueeze(0)          # (1, C, H, W)
    resized = torch.nn.functional.interpolate(batch, size=target_size, mode='bilinear')
    return TRANSFORMS['normalize'](resized)

def call_constant_spacing(model, img_name, data_name, spacing, **kwargs):
    """
    Run a full-biopsy model on one image resized with `spacing` and return the raw
    model output as a numpy array. Extra keyword arguments are accepted and ignored.
    """
    batch = get_img_const_spacing(img_name, data_name, spacing).to(device)
    with torch.no_grad():
        output = model(batch)
    return output.cpu().detach().numpy()

def call_CLAM(model, img_name, data_name, bag_latents, **kwargs):
    """
    Run a CLAM model on the patch latents of one image.

    Returns the pair (Y_prob, A_raw) as numpy arrays; the model's other three outputs
    are discarded. Extra keyword arguments are accepted and ignored.
    """
    bag = load_patch_latents(img_name, data_name, bag_latents).unsqueeze(0)
    with torch.no_grad():
        outputs = model(bag.to(device))
    _logits, Y_prob, _Y_hat, A_raw, _results_dict = outputs
    return Y_prob.cpu().detach().numpy(), A_raw.cpu().detach().numpy()
    

# Pick the inference function of every model: entries that carry a "spacing" are
# full-biopsy models, the remaining supported ones are CLAM models.
for model_name, entry in model_kwargs.items():
    if "spacing" in entry:
        call_fn = call_constant_spacing
    elif "CLAM" in model_name:
        call_fn = call_CLAM
    else:
        raise ValueError(f"Unknown call type for {model_name}")
    entry["call"] = call_fn

# Visualization functions

In [6]:
def occlusion(model, img, patch_size=128, step_size=64, mode="inclusion", print_progress=True,
              canvas_size=1024):
    """
    Build two coarse sensitivity heatmaps by perturbing `img` one patch at a time.

    The image is zero-padded on the top/left to a multiple of `patch_size` and scanned
    on a grid with stride `step_size`. For every grid cell one model input is built:

    - "inclusion": only the patch, pasted near the centre of an all-zero
      `canvas_size` x `canvas_size` canvas.
    - "patch":     only the patch, centred in an all-zero `patch_size` x `patch_size` image.
    - "occlusion": the whole padded image with the patch set to zero.

    Each input is generated right before its forward pass, so only one perturbed image
    is held in memory at a time (the previous version materialised all of them).

    Model outputs with 4 values are soft-maxed and reduced to entries 1 and 2; other
    outputs are used as returned. Scores per grid cell:

    - "occlusion": (prediction on the full image - prediction) * 2 + 0.5
    - otherwise:   prediction minus the per-output minimum over all cells

    Args:
        model: callable mapping a (1, 3, H, W) tensor to the prediction tensor.
        img: (1, 3, H, W) tensor.
        patch_size: side of the square patch in pixels.
        step_size: grid stride in pixels; must divide `patch_size`.
        mode: "inclusion", "patch" or "occlusion".
        print_progress: wrap the loop in tqdm.
        canvas_size: side of the zero canvas used by "inclusion" (default 1024, the
            value that used to be hard-coded).

    Returns:
        (overexpression_heatmap, nullmutation_heatmap): numpy arrays with one value per
        grid step, averaged over the patches covering the step, with the padded steps
        cropped away.

    Raises:
        ValueError: unknown `mode`, or `step_size` does not divide `patch_size`.
    """
    if mode not in ("inclusion", "patch", "occlusion"):
        raise ValueError(f"Unknown mode {mode!r}; expected 'inclusion', 'patch' or 'occlusion'")
    if step_size <= 0 or patch_size % step_size != 0:
        raise ValueError(f"patch_size ({patch_size}) must be a positive multiple of step_size ({step_size})")

    # Make sure the image is divisible by the patch size
    h, w = img.shape[-2], img.shape[-1]
    h_padded = int(np.ceil(h / patch_size)) * patch_size
    w_padded = int(np.ceil(w / patch_size)) * patch_size
    pad_h, pad_w = h_padded - h, w_padded - w
    img = torch.nn.functional.pad(img, (pad_w, 0, pad_h, 0)) # the order is left, right, top, bottom
    h_steps = h_padded // step_size
    w_steps = w_padded // step_size

    def _to_scores(raw_output):
        # Reduce a raw model output to the two scores tracked by the heatmaps.
        preds = raw_output.cpu().detach().squeeze()
        if len(preds) == 4:
            preds = torch.nn.functional.softmax(preds, dim=0)
            preds = preds[[1, 2]]
        return preds

    def _perturbed_input(top, left):
        # Build the model input for the patch whose top-left corner is (top, left).
        if mode == "occlusion":
            occluded = img.clone()
            occluded[:, :, top:top+patch_size, left:left+patch_size] = 0
            return occluded
        patch = img[:, :, top:top+patch_size, left:left+patch_size]
        ph, pw = patch.shape[-2], patch.shape[-1]  # smaller than patch_size at the far border
        if mode == "inclusion":
            canvas = torch.zeros(1, 3, canvas_size, canvas_size)
            y0 = canvas_size // 2 - patch_size // 2
            x0 = canvas_size // 2 - patch_size // 2
        else:  # "patch"
            canvas = torch.zeros(1, 3, patch_size, patch_size)
            y0 = patch_size // 2 - ph // 2
            x0 = patch_size // 2 - pw // 2
        canvas[:, :, y0:y0+ph, x0:x0+pw] = patch
        return canvas

    # Get the model output for each perturbed image (one at a time, necessary for memory)
    n_cells = h_steps * w_steps
    diff = torch.zeros(n_cells, 2)
    full_prediction = None
    with torch.no_grad():
        if mode == "occlusion":
            full_prediction = _to_scores(model(img.to(device)))
        iterator = range(n_cells)
        if print_progress:
            iterator = tqdm(iterator)
        for idx in iterator:
            row, col = divmod(idx, w_steps)  # row-major grid order
            current_img = _perturbed_input(row * step_size, col * step_size).to(device)
            diff[idx] = _to_scores(model(current_img))

    if mode != "occlusion":
        # Threshold the difference at the minimum value
        diff[:, 0] = diff[:, 0] - diff[:, 0].min()
        diff[:, 1] = diff[:, 1] - diff[:, 1].min()
    else:
        # Drop of the prediction caused by the occlusion, centered at 0.5
        diff[:, 0] = (full_prediction[0] - diff[:, 0]) * 2 + 0.5
        diff[:, 1] = (full_prediction[1] - diff[:, 1]) * 2 + 0.5

    # Make heatmap as grid of patch outputs
    overexpression_heatmap = np.zeros((h_steps, w_steps))
    nullmutation_heatmap = np.zeros((h_steps, w_steps))
    # Number of patches that overlap each grid step
    mask = np.zeros((h_steps, w_steps))

    steps_per_patch = patch_size // step_size
    for idx in range(n_cells):
        row, col = divmod(idx, w_steps)
        rows = slice(row, row + steps_per_patch)
        cols = slice(col, col + steps_per_patch)
        overexpression_heatmap[rows, cols] += diff[idx, 0].item()
        nullmutation_heatmap[rows, cols] += diff[idx, 1].item()
        mask[rows, cols] += 1

    # Crop to the original image size (the padding was added on the top/left)
    pad_steps_h = pad_h // step_size
    pad_steps_w = pad_w // step_size
    overexpression_heatmap = overexpression_heatmap[pad_steps_h:, pad_steps_w:]
    nullmutation_heatmap = nullmutation_heatmap[pad_steps_h:, pad_steps_w:]
    mask = mask[pad_steps_h:, pad_steps_w:]

    # Average over the overlapping patches
    overexpression_heatmap /= mask
    nullmutation_heatmap /= mask

    return overexpression_heatmap, nullmutation_heatmap

def vis_constant_spacing(model, img_name, data_name, spacing, **kwargs):
    """
    Compute the occlusion-style heatmaps of a full-biopsy model for one image.

    The image is loaded at the given `spacing`; all remaining keyword arguments are
    forwarded to `occlusion`. Returns the two heatmaps as a list.
    """
    batch = get_img_const_spacing(img_name, data_name, spacing) # (1, 3, H, W)
    first_map, second_map = occlusion(model, batch, **kwargs)
    return [first_map, second_map]


def load_attention_maps(model_name, data_name):
    """
    Load the stored predictions and raw attention of a CLAM model's preferred checkpoint.

    Returns a nested mapping `results[key][img_name]` with the keys "status_probs" and
    "A_raw", plus "presence_probs" for the double-binary ("CLAM_db") models.

    A `data_name` starting with "test+pathxl" merges the "test" and "pathxl" results.
    Raises ValueError for model names that do not contain "CLAM".
    """
    results = defaultdict(lambda: defaultdict(dict))
    if data_name.startswith("test+pathxl"):
        for part in ("test", "pathxl"):
            for key, per_image in load_attention_maps(model_name, part).items():
                results[key].update(per_image)
        return results

    results_dir = os.path.join(RESULTS_DIR, model_name)
    best_checkpoint_idx = best_checkpoint_dict[model_name]
    # NOTE(review): os.listdir order is arbitrary, so this positional pick is only
    # reproducible if the folder listing is stable - confirm.
    matching_files = [f for f in os.listdir(results_dir) if f.startswith(data_name)]
    result_content = torch.load(os.path.join(results_dir, matching_files[best_checkpoint_idx]))

    is_double_binary = "CLAM_db" in model_name
    if not is_double_binary and "CLAM" not in model_name:
        raise ValueError(f"Unsupported model type {model_name}")

    for img_name, (probs, A_raw) in result_content.items():
        if is_double_binary:
            # Stored values are presence probabilities; derive the status probabilities too
            results["presence_probs"][img_name] = probs
            results["status_probs"][img_name] = convert_presence_probs_to_status_probs(torch.tensor(probs)).numpy()
        else:
            results["status_probs"][img_name] = probs
        results["A_raw"][img_name] = A_raw
    return results

def vis_CLAM(model, img_name, data_name, attention_maps, non_empty_patch_indices, biopsy_dims):
    """
    Turn the raw CLAM attention of one image into patch-grid heatmaps.

    Args:
        model: unused (kept for a uniform `vis` signature).
        img_name: key of the image.
        data_name: unused (kept for a uniform `vis` signature).
        attention_maps: mapping with "status_probs"[img_name] (class probabilities) and
            "A_raw"[img_name] (attention of shape (K, N)).
        non_empty_patch_indices: img_name -> indices of the N non-empty patches inside
            the flattened patch grid.
        biopsy_dims: str(img_name) -> (height, width) of the biopsy in pixels.

    Behaviour:
        - K == 4: only the attention branch of the arg-max class is shown (1 heatmap).
        - K == 2: both branches are shown (2 heatmaps).
        - Empty patches receive the minimum attention value.
        - Every heatmap is shifted by its own minimum and divided by
          (global maximum - own minimum). The global maximum is taken ONCE before the
          loop, so all branches share one scale; previously it was re-read after
          earlier rows had already been rescaled in place.
        - A flat map (zero range) becomes all zeros instead of NaN.

    Returns:
        List of (patch_rows, patch_cols) numpy arrays.

    Raises:
        ValueError: if K is neither 2 nor 4.
    """
    prediction = np.round(attention_maps["status_probs"][img_name], 2)
    pred_argmax = int(np.argmax(prediction))
    print("Prediction:", prediction)

    # float32 so the indexed assignment below never hits a dtype mismatch
    attention = torch.as_tensor(attention_maps["A_raw"][img_name], dtype=torch.float32) # (K, N)
    K = attention.shape[0]
    non_empty_indices = non_empty_patch_indices[img_name] # (N,)
    img_h, img_w = biopsy_dims[str(img_name)]

    patch_size = 256
    patch_rows = max(round(img_h / patch_size), 1)
    patch_cols = max(round(img_w / patch_size), 1)

    if K == 4:
        shown = attention[pred_argmax].unsqueeze(0)  # (1, N): predicted class only
    elif K == 2:
        shown = attention                            # (2, N): both branches
    else:
        raise ValueError(f"Unsupported number of attention branches: {K}")

    # The patch attention is a 1D array and corresponds to the non-empty indices
    all_patch_attention = torch.full((shown.shape[0], patch_rows * patch_cols), float(shown.min()))
    all_patch_attention[:, non_empty_indices] = shown

    global_max = all_patch_attention.max()
    for i in range(len(all_patch_attention)):
        row_min = all_patch_attention[i].min()
        value_range = global_max - row_min
        if value_range > 0:
            all_patch_attention[i] = (all_patch_attention[i] - row_min) / value_range
        else:
            all_patch_attention[i] = 0

    heatmaps = [a.numpy() for a in all_patch_attention.reshape(-1, patch_rows, patch_cols)]
    return heatmaps


def load_mask(img_name, data_name):
    """
    Read `<img_name>.png` from the dataset's `masks` folder as a grayscale image.

    Returns None when the file does not exist.
    """
    mask_path = os.path.join(BASE_DIR[data_name], 'masks', f"{img_name}.png")
    return cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) if os.path.exists(mask_path) else None

def plot_heatmap(heatmaps, img_name, data_name, downsample=0.25, in_one=False):
    """
    Show a biopsy next to (or underneath) its heatmaps.

    Args:
        heatmaps: sequence of 2D arrays; each one is resized to the displayed image.
            The sequence itself is NOT modified (the previous version overwrote the
            caller's list entries with the resized arrays).
        img_name, data_name: identify the biopsy image and its optional mask.
        downsample: scale factor applied to the image before plotting.
        in_one: False -> one panel for the image plus one panel per heatmap;
                True  -> a single panel with all heatmaps layered over the image.

    When a mask file exists its outer contours are drawn in black: on the second panel
    for two heatmaps, on the third/fourth panel otherwise (side-by-side layout), or on
    the image itself (single-panel layout).
    """
    img = load_img(img_name, data_name) # (H, W, C) float32
    img = cv2.resize(img, (0, 0), fx=downsample, fy=downsample)
    target_size = (img.shape[1], img.shape[0])  # cv2.resize expects (width, height)

    # For upsampling, cv2.resize uses linear interpolation by default
    heatmaps = [cv2.resize(heatmap, target_size) for heatmap in heatmaps]

    mask = load_mask(img_name, data_name)
    contours = None
    if mask is not None:
        mask = cv2.resize(mask, target_size)
        # [-2] selects the contour list for both the 2-value and the 3-value return
        # layout of cv2.findContours
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    if not in_one:
        # Plot the img and then the heatmaps overlayed one by one to the side
        fig, ax = plt.subplots(1, len(heatmaps) + 1, figsize=(20, 10))
        ax[0].imshow(img)
        ax[0].axis('off')
        outlined_panels = [False, True] if len(heatmaps) == 2 else [False, False, True, True]
        for i, heatmap in enumerate(heatmaps):
            show_img = img.copy()
            if contours is not None and i < len(outlined_panels) and outlined_panels[i]:
                cv2.drawContours(show_img, contours, -1, (0, 0, 0), 2)
            ax[i+1].imshow(show_img)
            ax[i+1].imshow(heatmap, alpha=0.3, cmap='jet', vmin=0, vmax=1)
            ax[i+1].axis('off')
    else:
        # Plot the img and then the heatmaps overlayed on top of each other
        heatmap_colors = ['Oranges', 'Blues', 'Purples']
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        base_img = img.copy()
        if contours is not None:
            cv2.drawContours(base_img, contours, -1, (0, 0, 0), 2)
        ax.imshow(base_img)
        ax.axis('off')
        for i, heatmap in enumerate(heatmaps):
            # cycle the colour maps so more than three heatmaps cannot raise IndexError
            ax.imshow(heatmap, alpha=0.3, cmap=heatmap_colors[i % len(heatmap_colors)], vmin=0, vmax=1)
    plt.show()

# Pick the visualization function of every model: entries that carry a "spacing" are
# full-biopsy models, the remaining supported ones are CLAM models.
for model_name, entry in model_kwargs.items():
    if "spacing" in entry:
        vis_fn = vis_constant_spacing
    elif "CLAM" in model_name:
        vis_fn = vis_CLAM
    else:
        raise ValueError(f"Unknown call type for {model_name}")
    entry["vis"] = vis_fn

In [None]:


# Load the preferred checkpoint of the double-binary full-biopsy model.
model_name = "fb_db_spacing4"
model = load_model(model_kwargs[model_name]["model_class"], 
                   model_kwargs[model_name]["checkpoint_paths"][best_checkpoint_dict[model_name]])

# img_name = 34
# NOTE: this value is overwritten by the loop at the end of the cell; once the cell has
# run, the global `img_name` holds the last loop item ("5_0"), and later cells read it.
img_name = 1133
data_name = "test"
spacing = 4
# heatmaps = vis_constant_spacing(model, img_name, data_name, spacing, patch_size=128, step_size=64, 
#                      mode="inclusion")
# plot_heatmap(heatmaps, img_name, data_name, downsample=0.25, in_one=False)
# heatmaps = vis_constant_spacing(model, img_name, data_name, spacing, patch_size=128, step_size=64, 
#                      mode="patch")
# plot_heatmap(heatmaps, img_name, data_name, downsample=0.25, in_one=False)

# img_name -> [heatmap, heatmap] computed in "occlusion" mode.
heatmaps = {}
# for img_name in [34, 1133, 192, 1416, 7]:
# NOTE(review): "0_0" / "5_0" have the "<slide>_<biopsy>" form that load_patch_latents
# parses for the bolero dataset, while data_name is "test" here - confirm these images
# exist in the test biopsies folder.
for img_name in ["0_0", "5_0"]:
    heatmaps[img_name] = vis_constant_spacing(model, img_name, data_name, spacing, patch_size=128, step_size=64, 
                        mode="occlusion")

In [None]:
# Print the value range of every stored heatmap and plot it. Copies are passed to
# plot_heatmap so the arrays kept in `heatmaps` stay untouched.
for img_name, hms in heatmaps.items():
    oe_heatmap, nm_heatmap = hms
    new_heatmaps = []
    for heatmap in (oe_heatmap, nm_heatmap):
        heatmap = heatmap.copy()
        print("min, max")
        print(heatmap.min(), heatmap.max())
        new_heatmaps.append(heatmap)
    plot_heatmap(new_heatmaps, img_name, data_name, downsample=0.25, in_one=False)

In [None]:

# Per-image lookup tables of the dataset selected by the global `data_name`
# (assigned in an earlier cell).
non_empty_patch_indices = torch.load(os.path.join(BASE_DIR[data_name], "non_empty_patch_indices_gs256.pt"))
with open(os.path.join(BASE_DIR[data_name], "biopsy_dims.json"), "r") as f:
    biopsy_dims = json.load(f)

# Plot the stored attention of every CLAM variant for one image.
# NOTE(review): `img_name` is not assigned in this cell; it is whatever an earlier cell
# left behind - set it explicitly here if a specific image is intended.
for model_name in ["CLAM", "CLAM_m", "CLAM_db", "CLAM_db_m"]:
    print(model_name)
    attention_maps = load_attention_maps(model_name, data_name)
    # vis_CLAM does not read its `model` and `data_name` parameters, hence the None values
    heatmaps = vis_CLAM(None, img_name, None, attention_maps, non_empty_patch_indices, biopsy_dims)

    plot_heatmap(heatmaps, img_name, data_name, downsample=0.25, in_one=False)

In [None]:
# Load CLAM_m model and call it on the image
# (live inference instead of the stored results used in the previous cell).
model_name = "CLAM_m"
model = load_model(model_kwargs[model_name]["model_class"], 
                   model_kwargs[model_name]["checkpoint_paths"][best_checkpoint_dict[model_name]])
img_name = 1133
data_name = "test"
bag_latents = torch.load(bag_latent_paths[data_name])
Y_prob, A_raw = call_CLAM(model, img_name, data_name, bag_latents)
# Wrap the fresh outputs in the {key: {img_name: value}} layout that vis_CLAM reads.
# `non_empty_patch_indices` and `biopsy_dims` come from the previous cell, i.e. they were
# loaded for whatever `data_name` was active there.
heatmaps = vis_CLAM(None, img_name, None, {"status_probs":{img_name:Y_prob}, "A_raw":{img_name:A_raw}}, non_empty_patch_indices, biopsy_dims)

plot_heatmap(heatmaps, img_name, data_name, downsample=0.25, in_one=False)

In [11]:
def get_last_conv_layer(model):
    """
    Return `(name, layer)` of the last `torch.nn.Conv2d` in `model.named_modules()` order.

    Raises ValueError when the model contains no Conv2d module.
    """
    conv_layers = [
        (name, layer)
        for name, layer in model.named_modules()
        if isinstance(layer, torch.nn.Conv2d)
    ]
    if not conv_layers:
        raise ValueError("No convolutional layer found in the model")
    return conv_layers[-1]

class GradCam:
    """
    Grad-CAM style maps for every output of `model`, taken at `target_layer`.

    A forward hook stores the layer's output and a backward hook stores the gradient
    with respect to that output. Calling the instance on a single-image batch returns
    the model output and one map per output entry.

    Hook handling: the hooks are registered in the constructor (and re-registered on
    demand at the start of every call) and are always removed again when a call
    finishes. Creating many instances for the same model therefore no longer stacks
    hooks on the layer.
    """

    def __init__(self, model, target_layer, target_layer_name):
        self.model = model
        self.target_layer = target_layer
        self.target_layer_name = target_layer_name
        self.model.eval()
        self.feature_grad = None
        self.feature_map = None
        # Handles returned by the register_* calls; None while no hook is installed
        self._forward_handle = None
        self._backward_handle = None
        self.hook_feature_map()
        self.hook_feature_grad()

    def hook_feature_map(self):
        """Install the forward hook that records the target layer's output (idempotent)."""
        if self._forward_handle is not None:
            return

        def hook_fn(module, input, output):
            self.feature_map = output
        self._forward_handle = self.target_layer.register_forward_hook(hook_fn)

    def hook_feature_grad(self):
        """Install the backward hook that records the gradient w.r.t. the layer output (idempotent)."""
        if self._backward_handle is not None:
            return

        def hook_fn(module, grad_input, grad_output):
            self.feature_grad = grad_output[0]
        # register_backward_hook is deprecated in favour of the "full" variant
        self._backward_handle = self.target_layer.register_full_backward_hook(hook_fn)

    def remove_hooks(self):
        """Detach both hooks from the target layer; safe to call repeatedly."""
        for attr in ("_forward_handle", "_backward_handle"):
            handle = getattr(self, attr)
            if handle is not None:
                handle.remove()
                setattr(self, attr, None)

    def forward(self, x):
        return self.model(x)

    def backward(self, output, target_class):
        """Back-propagate a one-hot gradient for `target_class`; the graph is kept for further classes."""
        self.model.zero_grad()
        one_hot_output = torch.zeros((1, output.size()[-1]), dtype=torch.float32, device=output.device)
        one_hot_output[0][target_class] = 1
        output.backward(gradient=one_hot_output, retain_graph=True)

    def __call__(self, x):
        """
        Compute the maps for input `x` (batch of one image).

        Returns:
            (output, cams): the model output and a list with one 2D numpy array per
            output entry. Each map is the gradient-weighted sum of the feature maps,
            divided by its value range (when non-zero) and shifted by 0.5, so a raw
            value of 0 ends up at 0.5.

        Raises:
            ValueError: if the hooks did not capture data or a map contains NaN.
        """
        # Hooks must be in place before the forward pass
        self.hook_feature_map()
        self.hook_feature_grad()
        try:
            cams = []
            output = self.forward(x)
            n_outputs = output.size()[-1]
            for target_class in range(n_outputs):
                self.feature_grad = None  # so a gradient left over from the previous class is detected
                self.backward(output, target_class)
                if self.feature_grad is None or self.feature_map is None:
                    raise ValueError("Feature gradients or feature maps are not set. Check hooks.")

                # Channel weights: spatial mean of the gradient of the single batch item
                weights = torch.mean(self.feature_grad, dim=(2, 3))[0]
                cam = torch.tensordot(weights, self.feature_map[0], dims=([0], [0]))
                # cam = torch.nn.functional.relu(cam)

                # Avoid NaNs in normalization
                if torch.isnan(cam).any():
                    raise ValueError("CAM contains NaN values before normalization.")
                cam = cam.detach().cpu().numpy()

                cam_min, cam_max = cam.min(), cam.max()
                if cam_max != cam_min:
                    cam /= (cam_max - cam_min)
                cam += 0.5
                cams.append(cam)
        finally:
            self.remove_hooks()
        return output, cams

def get_grad_cam(model, img_name, data_name, spacing=4):
    """
    Compute Grad-CAM maps of `model` for one image at the model's last Conv2d layer.

    Args:
        model: network to explain.
        img_name, data_name: identify the biopsy image.
        spacing: down-sampling factor used to load the image. Defaults to 4, the value
            that used to be hard-coded, so existing calls behave as before.

    Returns:
        (output, cams): the model output and two maps. When the model produces more
        than two maps only entries 1 and 2 are kept.
    """
    img = get_img_const_spacing(img_name, data_name, spacing)
    last_conv_layer_name, last_conv_layer = get_last_conv_layer(model)
    grad_cam = GradCam(model, last_conv_layer, last_conv_layer_name)
    output, cams = grad_cam(img.to(device))
    cams = cams if len(cams) == 2 else cams[1:3]
    return output, cams

In [None]:
# Load the preferred checkpoint of the multi-class full-biopsy model ...
model_name = "fb_spacing4"
model = load_model(model_kwargs[model_name]["model_class"], 
                   model_kwargs[model_name]["checkpoint_paths"][best_checkpoint_dict[model_name]])

# ... and show the module stored in its `.model` attribute (notebook rich display).
display(model.model)

# Eval functions

In [12]:
def get_img_names(data_name='test'):
    """
    List the image names of a dataset.

    - names starting with "test+pathxl": the "test" names followed by the "pathxl" names
    - "test": the "id" column of the dataset's test.csv
    - anything else: every file name in the dataset's `biopsies` folder, cut at the first dot
    """
    if data_name.startswith("test+pathxl"):
        names = []
        for part in ("test", "pathxl"):
            names = names + get_img_names(part)
        return names

    base_dir = BASE_DIR[data_name]
    if data_name == "test":
        test_table = pd.read_csv(os.path.join(base_dir, 'test.csv'))
        return test_table["id"].tolist()

    names = []
    for file_name in os.listdir(os.path.join(base_dir, 'biopsies')):
        names.append(file_name.split('.')[0])
    return names

def get_results(model, mdl_kwargs, data_name='test', bag_latents=None):
    """
    Run the model's registered "call" function on every image of `data_name`.

    The full `mdl_kwargs` entry is forwarded as keyword arguments, together with
    `bag_latents`. Returns a dict img_name -> call result.
    """
    return {
        img_name: mdl_kwargs["call"](model, img_name, data_name, bag_latents=bag_latents, **mdl_kwargs)
        for img_name in tqdm(get_img_names(data_name))
    }

def save_results(results, model_name, data_name, checkpoint_name, base_dir=None):
    """
    Store `results` with torch.save as `<base_dir>/<model_name>/<data_name>_<checkpoint_name>.pt`.

    Args:
        results: object to serialise.
        model_name: sub-folder name (created when missing).
        data_name, checkpoint_name: parts of the file name.
        base_dir: root folder. Defaults to VIS_DIR, which keeps the previous behaviour.

    NOTE(review): `load_results` and `load_attention_maps` read files with this naming
    scheme from RESULTS_DIR, not VIS_DIR. Pass `base_dir=RESULTS_DIR` if the saved file
    is meant to be picked up by them - confirm which location is intended.
    """
    if base_dir is None:
        base_dir = VIS_DIR
    results_dir = os.path.join(base_dir, model_name)
    os.makedirs(results_dir, exist_ok=True)
    results_path = os.path.join(results_dir, f"{data_name}_{checkpoint_name}.pt")
    torch.save(results, results_path)

In [None]:
# Run visualizations for fb models.
# Cache layout: {data_name: {model_name: {img_name: {heatmap_type: heatmaps}}}}
# The cache file is reused when present and started from scratch otherwise, so the cell
# also works on a fresh checkout (previously torch.load failed when the file was missing).
vis_save_path = os.path.join(VIS_DIR, "fb_models.pt")
os.makedirs(VIS_DIR, exist_ok=True)  # VIS_DIR is not created anywhere else
if os.path.exists(vis_save_path):
    model_heatmaps = torch.load(vis_save_path)
else:
    model_heatmaps = {}

data_name = "test"
for model_name in ["fb_spacing4", "fb_db_spacing4"]:
    print(model_name)
    # Create missing cache levels instead of assuming they were stored before
    model_cache = model_heatmaps.setdefault(data_name, {}).setdefault(model_name, defaultdict(dict))
    best_checkpoint_idx = best_checkpoint_dict[model_name]
    model = load_model(model_kwargs[model_name]["model_class"], model_kwargs[model_name]["checkpoint_paths"][best_checkpoint_idx])
    img_names = get_img_names(data_name)
    for img_name in tqdm(img_names):
        img_cache = model_cache.setdefault(img_name, {})
        for heatmap_type in ["inclusion", "patch", "occlusion", "gradcam"]:
            # Skip if we already have the results
            if heatmap_type in img_cache:
                continue

            if heatmap_type == "gradcam":
                output, heatmaps = get_grad_cam(model, img_name, data_name)
            else:
                heatmaps = vis_constant_spacing(model, img_name, data_name, model_kwargs[model_name]["spacing"],
                                                patch_size=128, step_size=64, mode=heatmap_type, print_progress=False)
            img_cache[heatmap_type] = heatmaps

    # Persist after every model so an interruption does not discard finished work
    torch.save(model_heatmaps, vis_save_path)

# Get Metrics
Load results and labels

In [14]:
"""
fb_db   has a dict like idx: shape (1,2) with the two mutation probabilities
fb      has a dict like idx: shape (1,4) with the four class probabilities
CLAM_db has a dict like idx: tuple of: (
            shape (1,2) with the two mutation probabilities,
            shape (2, n_patches) Attention map 
    )
CLAM    has a dict like idx: tuple of: (
            shape (1,4) with the four class probabilities,
            shape (2, n_patches) Attention map 
        )
"""
# Model-type marker -> names of the values stored per image for that type.
# Markers are tested in declaration order, so the more specific "*_db" markers must
# stay in front of their shorter counterparts.
model_type_result_keys = dict(
    fb_db=["presence_probs"],
    fb=["status_probs"],
    CLAM_db=["presence_probs", "A_raw"],
    CLAM=["status_probs", "A_raw"],
)

def get_result_keys(model_name):
    """Return the result keys of the first marker contained in `model_name`, or None if none matches."""
    matching = (
        result_keys
        for marker, result_keys in model_type_result_keys.items() # The order is important
        if marker in model_name
    )
    return next(matching, None)

def load_results(model_name, data_name):
    """
    Load the stored results of ALL checkpoints of a model for one dataset.

    Returns a nested mapping `results[key][img_name][checkpoint_name]` where key is
    "status_probs" and, depending on the model type, "presence_probs" and/or "A_raw".
    The checkpoint name is the result file name without the "<data_name>_" part and
    without ".pt".

    Per model type:
      - "CLAM_db": stored (presence_probs, A_raw); status probabilities are derived.
      - "CLAM":    stored (status_probs, A_raw).
      - "fb_db":   stored presence_probs; status probabilities are derived.
      - "fb":      stored values are passed through a softmax over dim 1.

    A `data_name` starting with "test+pathxl" merges the "test" and "pathxl" results.
    Raises ValueError for any other model name (once a result file has been read).
    """
    results = defaultdict(lambda: defaultdict(dict))
    if data_name.startswith("test+pathxl"):
        for part in ("test", "pathxl"):
            for key, per_image in load_results(model_name, part).items():
                results[key].update(per_image)
        return results

    results_dir = os.path.join(RESULTS_DIR, model_name)
    for file_name in os.listdir(results_dir):
        if not file_name.startswith(data_name):
            continue
        result_content = torch.load(os.path.join(results_dir, file_name))
        checkpoint_name = file_name.replace(f"{data_name}_", "").replace(".pt", "")

        # Resolve the model type; the more specific markers are tested first
        if "CLAM_db" in model_name:
            has_attention, is_double_binary = True, True
        elif "CLAM" in model_name:
            has_attention, is_double_binary = True, False
        elif "fb_db" in model_name:
            has_attention, is_double_binary = False, True
        elif "fb" in model_name:
            has_attention, is_double_binary = False, False
        else:
            raise ValueError(f"Unsupported model type {model_name}")

        for img_name, content in result_content.items():
            if has_attention:
                probs, A_raw = content
            else:
                probs = content

            if is_double_binary:
                results["presence_probs"][img_name][checkpoint_name] = probs
                results["status_probs"][img_name][checkpoint_name] = convert_presence_probs_to_status_probs(torch.tensor(probs)).numpy()
            elif has_attention:
                results["status_probs"][img_name][checkpoint_name] = probs
            else:
                results["status_probs"][img_name][checkpoint_name] = torch.nn.functional.softmax(torch.tensor(probs), dim=1).numpy()

            if has_attention:
                results["A_raw"][img_name][checkpoint_name] = A_raw
    return results

def get_labels(data_name):
    """Return the ground-truth labels of a dataset as a ``{key: class index}`` dict.

    Supported values of ``data_name``:

    - ``"test"``: reads ``test.csv`` and maps the ``id`` column to the ``label`` column.
    - ``"bolero"``: reads ``P53_BOLERO_T.csv``, sorts by ``Case ID`` and maps the row
      position after sorting to the remapped ``GS`` value.
    - ``"pathxl"`` / ``"pathxl-100"`` / any other name starting with ``"pathxl"``:
      reads ``labels.csv`` and maps ``"<id>_<biopsy_nr>"`` to the class index. ``"pathxl"``
      keeps rows with ``concordance % >= 75``, ``"pathxl-100"`` keeps rows with
      ``concordance % == 100``, other ``pathxl*`` names are not filtered.
    - ``"test+pathxl"`` / ``"test+pathxl-100"``: the test labels updated with the
      corresponding pathxl labels.

    Raises:
        ValueError: if ``data_name`` is none of the above (previously this case
            silently returned ``None``).
    """
    if data_name == "test":
        return pd.read_csv(os.path.join(BASE_DIR[data_name], 'test.csv')).set_index("id").to_dict(orient='dict')['label']

    if data_name == "bolero":
        labels = pd.read_csv(os.path.join(BASE_DIR[data_name], 'P53_BOLERO_T.csv'))
        labels = labels.sort_values(by="Case ID")
        labels = labels.reset_index(drop=True)
        # Map GS to {1:0, 2:1, 3:2, 4:4} where 4 is unknown
        labels["GS"] = labels["GS"].map({1:0, 2:1, 3:2, 4:4})
        # Only keep GS column; the dict is keyed by the row position after sorting
        return labels[["GS"]].to_dict(orient='dict')["GS"]

    if data_name.startswith("pathxl"):
        labels = pd.read_csv(os.path.join(BASE_DIR["pathxl"], 'labels.csv'))
        # idx is id column and biopsy_nr column separated by _
        labels["idx"] = labels["id"].astype(str) + "_" + labels["biopsy_nr"].astype(str)
        labels = labels.set_index("idx")
        # Sort by id primarily and biopsy_nr secondarily
        labels = labels.sort_values(by=["id", "biopsy_nr"])
        # Map the textual label to its class index
        mapping = {"WT":0, "Overexpression":1, "Null":2, "Double clones":3}
        labels["label"] = labels["label"].map(mapping)
        if data_name == "pathxl": # Filter out any concordance % < 75
            labels = labels[labels["concordance %"] >= 75]
        elif data_name == "pathxl-100":
            labels = labels[labels["concordance %"] == 100]
        return labels[["label"]].to_dict(orient='dict')["label"]

    if data_name in ("test+pathxl", "test+pathxl-100"):
        # Combined sets: start from the test labels and add the pathxl variant named after the "+"
        labels = get_labels("test")
        labels.update(get_labels(data_name.split("+", 1)[1]))
        return labels

    raise ValueError(f"Unsupported data name {data_name}")
    
# load_results("fb_spacing4", "test")["status_probs"]
# get_labels("test+pathxl-100")

# Visualizations

## Full Biopsy

In [None]:
# Load the stored dimensions of every biopsy in the test set
with open(os.path.join(BASE_DIR["test"], "biopsy_dims.json"), "r") as f:
    biopsy_dims = json.load(f)

img_names = get_img_names("test")
print(img_names[:5])

# Rank the test-set biopsies by their smallest dimension, largest first
biopsy_dims = {str(k): biopsy_dims[str(k)] for k in img_names}
sorted_biopsies = sorted(biopsy_dims.items(), key=lambda x: min(x[1]), reverse=True)
display(sorted_biopsies[:10])

# Show little thumbnails of these biopsies downsampled by 8x
n = 4
n_thumb_cols = 5
# squeeze=False keeps ax two-dimensional for every value of n
fig, ax = plt.subplots(n, n_thumb_cols, figsize=(20, 5*n), squeeze=False)
# Hide every axis up front so unused cells stay blank when there are fewer biopsies than cells
for thumb_ax in ax.flat:
    thumb_ax.axis('off')
labels = get_labels("test")
for i, (img_name, _dims) in enumerate(sorted_biopsies[:n*n_thumb_cols]):
    thumb_ax = ax[i // n_thumb_cols, i % n_thumb_cols]
    # NOTE(review): the last argument (8) is presumably the downsampling/spacing factor -- confirm
    img = get_img_const_spacing(img_name, "test", 8)
    thumb_ax.imshow(img.squeeze().permute(1, 2, 0))
    label = labels[img_name]
    thumb_ax.set_title(img_name + f"\n({label})")

In [18]:
# Long display names. These used to be assigned to model_name_mapping and were then
# immediately overwritten by the short names below; they are kept under their own name
# so they remain available for figures with room for full titles.
model_name_mapping_long = {
    "fb_spacing4": "Full-Biopsy Multiclass",
    "fb_db_spacing4": "Full-Biopsy Double-Binary",
    "CLAM": "CLAM",
    "CLAM_db": "CLAM Double-Binary",
    "CLAM_m": "CLAM +DC",
    "CLAM_db_m": "CLAM Double-Binary +DC",
}
# Short display names: this is the mapping the plotting cells use for titles
model_name_mapping = {
    "fb_spacing4": "FB",
    "fb_db_spacing4": "FBdb",
    "CLAM": "CLAM",
    "CLAM_db": "CLAMdb",
    "CLAM_m": "CLAM+DC",
    "CLAM_db_m": "CLAMdb+DC",
}

In [None]:
# Biopsies shown in the comparison figures, mapped to the side that gets cropped away
# (see cut_img). The commented-out entries are earlier selections.
cuts_dict = {
    # # 1416: "right",
    # 1133: "top",
    # # 240: "bottom",
    # 192: "right",
    # # 2: "left",
    # 7: "right",
    # 34: "left",
    "168_3": "left",
    "195_5": "top",
    "265_2": "top",
    "492_1": "bottom",
}

# Parallel lists: indices[i] is the biopsy name and cuts[i] the side cropped for it
indices = list(cuts_dict)
cuts = list(cuts_dict.values())
data_name = "test"
labels = get_labels(data_name)

def cut_img(img, cut):
    """Crop ``img`` along one axis, dropping pixels on the ``cut`` side.

    With ``h, w = img.shape[:2]``:

    - ``"top"`` keeps the last ``w`` rows, ``"bottom"`` keeps the first ``w`` rows;
    - ``"left"`` keeps the last ``h`` columns, ``"right"`` keeps the first ``h`` columns.

    The result is square when the cropped axis is the longer one; if it is already the
    shorter one the slice covers the whole axis and nothing is removed. Any other value
    of ``cut`` returns ``img`` unchanged.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    if cut == "top":
        return img[-n_cols:]
    if cut == "bottom":
        return img[:n_cols]
    if cut == "left":
        return img[:, -n_rows:]
    if cut == "right":
        return img[:, :n_rows]
    return img

# Make plot to compare the different heatmap methods for fb and fb_db
model_names = ["fb_spacing4", "fb_db_spacing4"]

heatmap_types = ["gradcam", "occlusion", "inclusion", "patch"]
heatmap_type_mapping = {
    "gradcam": "Grad-CAM",
    "occlusion": "occlusion",
    "inclusion": "inclusion",
    "patch": "patch prediction",
}
n_cols = len(heatmap_types) + 1 # +1 for the original image
n_rows = len(indices)
# Make separate plots for each model (the [:1] slice currently restricts this to the first model)
for model_name in model_names[:1]:
    print(model_name)
    results = load_results(model_name, data_name)
    best_checkpoint_idx = best_checkpoint_dict[model_name]
    best_checkpoint_name = list(results["status_probs"][indices[0]])[best_checkpoint_idx]
    # squeeze=False keeps ax two-dimensional even when a single biopsy is selected
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(12, n_rows*4.5), squeeze=False)
    for i, img_idx in enumerate(indices):
        img_name = img_idx
        img = load_img(img_name, data_name)
        # Resize to be 0.25 of the original size
        img = cv2.resize(img, (0,0), fx=0.25, fy=0.25)

        # h, w are the size BEFORE cropping: masks and heatmaps are resized to it and then
        # cropped with the same cut, so they stay aligned with the cropped image
        h, w = img.shape[0], img.shape[1]
        img = cut_img(img, cuts[i])
        # Height of one cropped panel; differs from h for "top"/"bottom" cuts
        panel_h = img.shape[0]

        # Draw the annotation outlines on copies of the image: first the mask stored under
        # the biopsy name (paired with the NM heatmap below), then the "<name>OE" mask
        outlined = []
        for mask_name in (img_name, f"{img_name}OE"):
            canvas = img.copy()
            mask = load_mask(mask_name, data_name)
            if mask is not None:
                mask = cut_img(cv2.resize(mask, (w,h)), cuts[i])
                # [-2] selects the contour list on every OpenCV version
                # (3.x returns (image, contours, hierarchy), others (contours, hierarchy))
                contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
                cv2.drawContours(canvas, contours, -1, (0, 0, 0), 2)
            outlined.append(canvas)
        img_with_mask, img_with_OE_mask = outlined

        # First column: the plain image with a panel letter in the top left corner
        ax[i, 0].imshow(img)
        ax[i, 0].text(5, 5, chr(ord('a') + i), fontsize=20, color='black', verticalalignment='top', horizontalalignment='left')
        ax[i, 0].axis('off')

        label = P53_CLASS_CODES[labels[img_name]]
        prediction = P53_CLASS_CODES[np.argmax(results["status_probs"][img_name][best_checkpoint_name])]
        title = f"Label {label}\nPrediction {prediction}"
        ax[i, 0].set_title(title,
                        #    align text right
                            loc='right',
                           )

        # Set two y ticks for the second column, one for overexpression and one for nullmutation.
        # The heatmap grid stacks two cropped panels, so the ticks sit at the centre of each panel.
        tick_locs = [y * 2*panel_h for y in [0.25, 0.75]]
        ax[i, 1].set_yticks(tick_locs)
        ax[i, 1].set_yticklabels(["OE heatmap", "NM heatmap"], rotation=90, verticalalignment='center', horizontalalignment='right')

        # The underlay (OE outline on top, NM outline below) is the same for every heatmap
        # type, so build the 2x1 grid once per biopsy
        img_underlay = [torch.tensor(img_with_OE_mask).permute(2, 0, 1),
                        torch.tensor(img_with_mask).permute(2, 0, 1)]
        img_underlay = torchvision.utils.make_grid(img_underlay, nrow=1, normalize=True).permute(1, 2, 0).numpy()

        for j, heatmap_type in enumerate(heatmap_types):
            heatmaps = model_heatmaps[data_name][model_name][img_name][heatmap_type] # two maps: OE first, NM second

            if heatmap_type == "occlusion":
                # Halve the deviation from the neutral value 0.5
                heatmaps = [(hm - 0.5)/2 + 0.5 for hm in heatmaps]

            # Resize the heatmaps to the pre-crop image size, then cut them like the image
            heatmaps = [cut_img(cv2.resize(hm, (w, h)), cuts[i]) for hm in heatmaps]

            # Stack the two heatmaps into the same 2x1 grid layout as the underlay
            heatmap = torchvision.utils.make_grid(torch.tensor(np.stack(heatmaps)).unsqueeze(1), nrow=1, normalize=False).permute(1, 2, 0).numpy()[:,:,0]

            ax[i, j+1].imshow(img_underlay)
            ax[i, j+1].imshow(heatmap, cmap='jet', vmin=0, vmax=1, alpha=0.3)
            ax[i, j+1].set_title(heatmap_type_mapping[heatmap_type])
            ax[i, j+1].set_xticks([])
            if j > 0:
                # Only the first heatmap column keeps the OE/NM row labels
                ax[i, j+1].set_yticks([])
    plt.tight_layout()
    # Set suptitle
    # fig.suptitle(model_name_mapping[model_name], fontsize=16, y=1.02)
    plt.show()

In [None]:
# Turn each "#RRGGBB" entry of color_dict into an (R, G, B) tuple with channels scaled to [0, 1]
print(color_dict)
color_dict_RGB = {
    name: tuple(int(hex_code[pos:pos + 2], 16) / 255 for pos in (1, 3, 5))
    for name, hex_code in color_dict.items()
}
print(color_dict_RGB)

In [None]:
# Load biopsy dims
with open(os.path.join(BASE_DIR["test"], "biopsy_dims.json"), "r") as f:
    biopsy_dims = json.load(f)

# Load non-empty patch indices
non_empty_patch_indices = torch.load(os.path.join(BASE_DIR["test"], "non_empty_patch_indices_gs256.pt"))

# Compare one heatmap per model: the inclusion heatmap for the full-biopsy models and the
# output of vis_CLAM for the CLAM models
# model_names = list(reversed(["CLAM", "CLAM_db", "CLAM_m", "CLAM_db_m", "fb_spacing4", "fb_db_spacing4"]))
model_names = ["fb_spacing4", "fb_db_spacing4", "CLAM", "CLAM_db", "CLAM_m", "CLAM_db_m"]

heatmap_type = "inclusion"
n_cols = len(model_names) + 1 # +1 for the original image
n_rows = len(indices)
# squeeze=False keeps ax two-dimensional even when a single biopsy is selected
fig, ax = plt.subplots(n_rows, n_cols, figsize=(12, n_rows*4.5), squeeze=False)
for j, model_name in enumerate(model_names):
    print(model_name)
    results = load_results(model_name, data_name)
    if "CLAM" in model_name:
        attention_maps = load_attention_maps(model_name, data_name)
    best_checkpoint_idx = best_checkpoint_dict[model_name]
    best_checkpoint_name = list(results["status_probs"][indices[0]])[best_checkpoint_idx]
    for i, img_idx in enumerate(indices):
        img_name = img_idx
        img = load_img(img_name, data_name)
        # Resize to be 0.25 of the original size
        img = cv2.resize(img, (0,0), fx=0.25, fy=0.25)

        # h, w are the size BEFORE cropping: masks and heatmaps are resized to it and then
        # cropped with the same cut, so they stay aligned with the cropped image
        h, w = img.shape[0], img.shape[1]
        img = cut_img(img, cuts[i])
        # Height of one cropped panel; differs from h for "top"/"bottom" cuts
        panel_h = img.shape[0]

        # Load both annotation masks and extract their outlines once.
        # [-2] selects the contour list on every OpenCV version
        # (3.x returns (image, contours, hierarchy), others (contours, hierarchy)).
        NMmask = load_mask(img_name, data_name)
        OEmask = load_mask(f"{img_name}OE", data_name)
        NM_contours = None
        if NMmask is not None:
            mask = cut_img(cv2.resize(NMmask, (w,h)), cuts[i])
            NM_contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        OE_contours = None
        if OEmask is not None:
            mask = cut_img(cv2.resize(OEmask, (w,h)), cuts[i])
            OE_contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

        # Underlays: NM outline only, OE outline only, and both outlines in colour
        img_with_mask = img.copy()
        if NM_contours is not None:
            cv2.drawContours(img_with_mask, NM_contours, -1, (0, 0, 0), 2)
        img_with_OE_mask = img.copy()
        if OE_contours is not None:
            cv2.drawContours(img_with_OE_mask, OE_contours, -1, (0, 0, 0), 2)
        img_with_both_masks = img.copy()
        if NM_contours is not None and OE_contours is not None:
            cv2.drawContours(img_with_both_masks, NM_contours, -1, color_dict_RGB["b"], 3)
            cv2.drawContours(img_with_both_masks, OE_contours, -1, color_dict_RGB["o"], 3)

        # All-ones placeholder panel for grid slots without a heatmap
        blanco = np.ones_like(img)

        # First column: the plain image with a panel letter in the top left corner
        ax[i, 0].imshow(img)
        ax[i, 0].text(5, 5, chr(ord('a') + i), fontsize=20, color='black', verticalalignment='top', horizontalalignment='left')
        ax[i, 0].axis('off')

        label = P53_CLASS_CODES[labels[img_name]]
        pred = np.argmax(results["status_probs"][img_name][best_checkpoint_name])
        prediction = P53_CLASS_CODES[pred]
        title = f"Label {label}"
        ax[i, 0].set_title(title,
                        #    align text right
                            loc='right',
                           )

        # Set two y ticks on the last column, one for overexpression and one for nullmutation.
        # The heatmap grid stacks two cropped panels, so the ticks sit at the centre of each panel.
        tick_locs = [y * 2*panel_h for y in [0.25, 0.75]]
        ax[i, -1].set_yticks(tick_locs)
        ax[i, -1].set_yticklabels(["OE heatmap", "NM heatmap"], rotation=90, verticalalignment='center', horizontalalignment='left')
        # Place these on the right side of the image
        ax[i, -1].yaxis.tick_right()

        if model_name in ["fb_spacing4", "fb_db_spacing4"]:
            heatmaps = model_heatmaps[data_name][model_name][img_name][heatmap_type] # two maps: OE first, NM second
        else:
            heatmaps = vis_CLAM(None, img_name, None, attention_maps, non_empty_patch_indices, biopsy_dims)
            # if len(heatmaps) != 2:
            #     heatmaps = [heatmaps[pred]]

        # Resize the heatmaps to the pre-crop image size, then cut them like the image
        heatmaps = [cut_img(cv2.resize(hm, (w, h)), cuts[i]) for hm in heatmaps]

        # Pick the underlay panels that go with the heatmaps (stacked top to bottom)
        if len(heatmaps) == 2:
            underlay_panels = [img_with_OE_mask, img_with_mask]
        elif len(heatmaps) == 1 and prediction in ("OE", "NM", "DC"):
            # A single heatmap is placed in the slot of the predicted class; the other
            # slots get a blank panel and an all-zero heatmap
            only_heatmap = heatmaps[0]
            no_heatmap = np.zeros_like(only_heatmap)
            if prediction == "OE":
                underlay_panels = [img_with_OE_mask, blanco]
                heatmaps = [only_heatmap, no_heatmap]
            elif prediction == "NM":
                underlay_panels = [blanco, img_with_mask]
                heatmaps = [no_heatmap, only_heatmap]
            else:
                # DC: centre the panel with both outlines between two blank panels
                underlay_panels = [blanco, img_with_both_masks, blanco]
                heatmaps = [no_heatmap, only_heatmap, no_heatmap]
        else:
            # Previously this case fell through and re-plotted the previous panel's
            # img_underlay/heatmap under the wrong title (or failed with an unrelated error)
            raise ValueError(
                f"Cannot lay out {len(heatmaps)} heatmap(s) for model {model_name}, "
                f"image {img_name}, prediction {prediction}"
            )

        img_underlay = [torch.tensor(panel).permute(2, 0, 1) for panel in underlay_panels]
        img_underlay = torchvision.utils.make_grid(img_underlay, nrow=1, normalize=True).permute(1, 2, 0).numpy()
        heatmap = torchvision.utils.make_grid(torch.tensor(np.stack(heatmaps)).unsqueeze(1), nrow=1, normalize=False).permute(1, 2, 0).numpy()[:,:,0]

        if len(underlay_panels) == 3:
            # Cut half a panel from the top and bottom so the three-panel grid is as tall
            # as the two-panel grids in the other columns
            half_panel = panel_h // 2
            img_underlay = img_underlay[half_panel:img_underlay.shape[0] - half_panel]
            heatmap = heatmap[half_panel:heatmap.shape[0] - half_panel]

        ax[i, j+1].imshow(img_underlay)
        ax[i, j+1].imshow(heatmap, cmap='jet', vmin=0, vmax=1, alpha=0.3)
        title = f"{model_name_mapping[model_name]}\nPrediction {prediction}"
        ax[i, j+1].set_title(title)
        ax[i, j+1].set_xticks([])
        if j < len(model_names) - 1:
            # Only the last column keeps the OE/NM row labels
            ax[i, j+1].set_yticks([])
plt.tight_layout()
# Set suptitle
# fig.suptitle("heatmap comparison between models", fontsize=16, y=1.02)

# Decrease the space between the subplots
plt.subplots_adjust(hspace=-0.4, wspace=0.05)

plt.show()

# Visualize the hard cases

most importantly:
- Why is FB better at NM than CLAM? Receptive field? Feature extractor? Also test the CLAM model with ResNet18 encoder
- Why is CLAM suddenly better at NM and OE in BOLERO? (is it just better at generalizing to new data or is it actually better at this data than LANS, and FB worse at it?)
    - See CLAM heatmaps of the 8 failed NM from LANS and the 4 NM from BOLERO
    - See FB heatmaps of those same cases

also:
- Why do the FBdb DC still go wrong? Difference in NM in DC, maybe less of it or looks different?
- Why is CLAMdb+DC not as good as CLAM+DC? (is it the process-of-elimination thing? Could test this looking at the predictions of CLAM+DC and see if it's indeed 0 if not the class with DC being constant)

Tested CLAM with ResNet18_tuned, and also with retccl_tuned and retccl_tuned_hybrid

ResNet18_tuned performed better on the NM cases in LANS and was otherwise identical to CLAM. The cases it underclassifies resemble those of the CLAM models, with NM a bit better (and similar minor trouble with OE).
On BOLERO, however, it looks almost exactly like the FB model, likely because it uses the same feature extractor.

retccl_tuned performed even better on NM but worse on OE. On LANS it looks more like the FB models than the CLAM models, even though it didn't inherit anything from the FB models (except that its feature extractor was finetuned on their labels).
On BOLERO, it also looks very much like the FB models.

The retccl_tuned_hybrid was similar in every way to the retccl_tuned, but with a slightly better OE performance on LANS.

This suggests that finetuning is what helps achieve better NM performance on LANS, but also decreases performance on BOLERO due to overfitting to the training domain. Spacing and patch size don't seem to play a role in this.

In [None]:
# Get img_names for the NM biopsies from the pathxl data that were classified as WT by CLAM
# NOTE(review): the "pathxl" entry of BASE_DIR is commented out in the config cell, so
# get_labels("pathxl") raises KeyError until that entry is re-enabled.
# Load results
model_name = "CLAM"
data_name = "pathxl"
results = load_results(model_name, data_name)
best_checkpoint_idx = best_checkpoint_dict[model_name]
# Read the checkpoint names off the first biopsy in the results instead of a hard-coded
# biopsy id (assumes every biopsy was evaluated with the same checkpoints, in the same order)
checkpoint_names = list(next(iter(results["status_probs"].values())))
best_checkpoint_name = checkpoint_names[best_checkpoint_idx]
# Load labels
labels = get_labels(data_name)
# Biopsies that have both a label and an image, sorted so the output order is reproducible
labelled_img_names = sorted(set(labels.keys()).intersection(get_img_names(data_name)))
# Keep the NM biopsies (label 2) that CLAM predicted as WT (class 0)
img_names = [
    img_name for img_name in labelled_img_names
    if labels[img_name] == 2
    and np.argmax(results["status_probs"][img_name][best_checkpoint_name]) == 0
]

img_names


# Visualize NM and DC of test set

In [None]:
# Visualize all NM and DC biopsies of the test set with the heatmap of one model
# (the inclusion heatmap for the full-biopsy models, the output of vis_CLAM for CLAM models)
# model_name = "fb_spacing4"
model_name = "CLAM_db_m"
data_name = "test"
heatmap_type = "inclusion"
results = load_results(model_name, data_name)
# Reload the labels for this data set: `labels` may still hold another data set's labels
# from a previously executed cell
labels = get_labels(data_name)

# First all NM (label 2), then all DC (label 3)
test_img_names = get_img_names(data_name)
indices = [img_name for img_name in test_img_names if labels[img_name] == 2]
indices += [img_name for img_name in test_img_names if labels[img_name] == 3]

# Look the checkpoint up only after `indices` has been rebuilt for this cell
best_checkpoint_idx = best_checkpoint_dict[model_name]
best_checkpoint_name = list(results["status_probs"][indices[0]])[best_checkpoint_idx]

# The attention maps do not depend on the image, so load them once instead of per image
if "CLAM" in model_name:
    attention_maps = load_attention_maps(model_name, data_name)

# Three biopsies per row, with as many rows as needed (10 rows for 15 NM + 15 DC)
n_grid_cols = 3
n_grid_rows = -(-len(indices) // n_grid_cols)
fig, ax = plt.subplots(n_grid_rows, n_grid_cols, figsize=(12, n_grid_rows*4.5), squeeze=False)
for i, img_idx in enumerate(indices):
    ax_row = i // n_grid_cols
    ax_col = i % n_grid_cols

    img_name = img_idx
    img = load_img(img_name, data_name)
    # Resize to be 0.25 of the original size
    img = cv2.resize(img, (0,0), fx=0.25, fy=0.25)

    # h, w are the size BEFORE cropping: masks and heatmaps are resized to it and then
    # cropped with the same cut, so they stay aligned with the cropped image
    h, w = img.shape[0], img.shape[1]
    # Every biopsy is cropped on the same side here instead of using a per-image cut
    cut = "right"
    img = cut_img(img, cut)

    # Draw the annotation outlines on copies of the image.
    # [-2] selects the contour list on every OpenCV version
    # (3.x returns (image, contours, hierarchy), others (contours, hierarchy)).
    img_with_mask = img.copy()
    NMmask = load_mask(img_name, data_name)
    if NMmask is not None:
        mask = cut_img(cv2.resize(NMmask, (w,h)), cut)
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        cv2.drawContours(img_with_mask, contours, -1, (0, 0, 0), 2)
    img_with_OE_mask = img.copy()
    OEmask = load_mask(f"{img_name}OE", data_name)
    if OEmask is not None:
        mask = cut_img(cv2.resize(OEmask, (w,h)), cut)
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        cv2.drawContours(img_with_OE_mask, contours, -1, (0, 0, 0), 2)

    if model_name in ["fb_spacing4", "fb_db_spacing4"]:
        heatmaps = model_heatmaps[data_name][model_name][img_name][heatmap_type] # two maps: OE first, NM second
    elif "CLAM" in model_name:
        # NOTE(review): non_empty_patch_indices and biopsy_dims come from an earlier cell
        # that loaded them for the test set -- confirm that cell has been run
        heatmaps = vis_CLAM(None, img_name, None, attention_maps, non_empty_patch_indices, biopsy_dims)
    else:
        # Without this, `heatmaps` would silently be whatever an earlier cell left behind
        raise ValueError(f"Unsupported model type {model_name}")

    # Resize the heatmaps to the pre-crop image size, then cut them like the image
    heatmaps = [cut_img(cv2.resize(hm, (w, h)), cut) for hm in heatmaps]

    # Make a 1x2 grid (OE left, NM right) of the underlays and of the heatmaps
    img_underlay = [torch.tensor(img_with_OE_mask).permute(2, 0, 1),
                    torch.tensor(img_with_mask).permute(2, 0, 1)]
    img_underlay = torchvision.utils.make_grid(img_underlay, nrow=2, normalize=True).permute(1, 2, 0).numpy()
    heatmap = torchvision.utils.make_grid(torch.tensor(np.stack(heatmaps)).unsqueeze(1), nrow=2, normalize=False).permute(1, 2, 0).numpy()[:,:,0]

    ax[ax_row, ax_col].imshow(img_underlay)
    ax[ax_row, ax_col].imshow(heatmap, cmap='jet', vmin=0, vmax=1, alpha=0.3)

    # Patch aggregate prediction
    # Take max for both OE and NM
    scores = [hm.max() for hm in heatmaps]

    label = P53_CLASS_CODES[labels[img_name]]
    pred = results["status_probs"][img_name][best_checkpoint_name][0]
    title = f"Label {label}\nPred {[round(p,2) for p in pred]}\nScores {[round(s,2) for s in scores]}"
    ax[ax_row, ax_col].set_title(title,
                    #    align text right
                        loc='right',
                          )

    ax[ax_row, ax_col].set_xticks([])
    ax[ax_row, ax_col].set_yticks([])
# Hide the cells of the last row that did not receive a biopsy
for unused_ax in ax.flat[len(indices):]:
    unused_ax.axis('off')
plt.tight_layout()
# Set suptitle for the model
fig.suptitle(model_name_mapping[model_name], fontsize=16, y=.88)

# Decrease the space between the subplots
plt.subplots_adjust(hspace=-0.8, wspace=0.1)

# Save the figure (the config cell only creates RESULTS_DIR, so make sure VIS_DIR exists)
os.makedirs(VIS_DIR, exist_ok=True)
plt.savefig(os.path.join(VIS_DIR, f"{model_name}_inclusion_NM_DC_heatmap.png"))

### GradCAM

In [None]:
# Visualize gradient-based class activation maps (Grad-CAM)
#
# Grad-CAM is a technique to visualize the regions of the image that are important for the model's
# prediction. It does this by computing the gradients of the output class with respect to the
# feature maps of the last convolutional layer of the model. The gradients are then used to compute
# a weighted sum of the feature maps, where the weights are the gradients. The resulting heatmap
# is then overlaid on the original image to visualize the important regions.
#
# The code below is adapted from the PyTorch Grad-CAM tutorial:
# https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#grad-cam
def get_last_conv_layer(model):
    """Return ``(name, layer)`` of the last ``torch.nn.Conv2d`` in ``model.named_modules()``.

    Raises:
        ValueError: if the model contains no ``Conv2d`` module.
    """
    last_conv = None
    # Walk the modules in registration order and remember the most recent Conv2d
    for module_name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            last_conv = (module_name, module)
    if last_conv is None:
        raise ValueError("No convolutional layer found in the model")
    return last_conv

class GradCam:
    """Gradient-weighted class activation maps for one layer of a model.

    A forward hook stores the feature map of ``target_layer`` and a backward hook stores
    the gradient of the backpropagated output with respect to that feature map. Calling
    the object runs the model once and returns one map per output unit.

    Args:
        model: the model to explain; it is switched to eval mode.
        target_layer: the module whose output is used as feature map.
        target_layer_name: name of ``target_layer`` (stored, not otherwise used).
        verbose: if True (default), print the model output and the raw value range of
            every map, as before.

    The hooks stay registered on ``target_layer`` until ``remove_hooks()`` is called.
    """

    def __init__(self, model, target_layer, target_layer_name, verbose=True):
        self.model = model
        self.target_layer = target_layer
        self.target_layer_name = target_layer_name
        self.verbose = verbose
        self.model.eval()
        self.feature_grad = None
        self.feature_map = None
        # Handles of the registered hooks, kept so that remove_hooks() can detach them
        self._hook_handles = []
        self.hook_feature_map()
        self.hook_feature_grad()

    def hook_feature_map(self):
        """Register a forward hook that stores the output of the target layer."""
        def hook_fn(module, input, output):
            self.feature_map = output
        self._hook_handles.append(self.target_layer.register_forward_hook(hook_fn))

    def hook_feature_grad(self):
        """Register a backward hook that stores the gradient w.r.t. the target layer's output."""
        def hook_fn(module, grad_input, grad_output):
            self.feature_grad = grad_output[0]
        self._hook_handles.append(self.target_layer.register_backward_hook(hook_fn))

    def remove_hooks(self):
        """Detach the hooks from the target layer so repeated use does not accumulate them."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        return self.model(x)

    def backward(self, output, target_class):
        """Backpropagate a one-hot gradient for ``target_class`` through ``output``.

        ``output`` must have shape ``(1, n_outputs)``. The graph is retained so the same
        output can be backpropagated once per class.
        """
        self.model.zero_grad()
        one_hot_output = torch.zeros((1, output.size()[-1]), dtype=torch.float32, device=output.device)
        one_hot_output[0][target_class] = 1
        output.backward(gradient=one_hot_output, retain_graph=True)

    def __call__(self, x):
        """Return ``(output, cams)`` for a single-image batch ``x``.

        ``cams`` holds one numpy array per output unit, with the spatial size of the
        target layer's feature map. Each map is divided by its largest absolute value
        (unless it is constant) and then shifted by 0.5.

        Raises:
            ValueError: if the hooks did not fire or a map contains NaN values.
        """
        cams = []
        output = self.forward(x)
        if self.verbose:
            print(output)
        n_outputs = output.size()[-1]
        for target_class in range(n_outputs):
            self.backward(output, target_class)
            if self.feature_grad is None or self.feature_map is None:
                raise ValueError("Feature gradients or feature maps are not set. Check hooks.")
            if self.feature_grad.sum() == 0:
                print("No gradient for target class", target_class)
            if self.feature_map.sum() == 0:
                print("No feature map for target class", target_class)

            # Channel weights are the spatially averaged gradients. squeeze(0) drops only
            # the batch dimension, so a layer with a single channel still works.
            weights = torch.mean(self.feature_grad, dim=(2, 3)).squeeze(0)
            cam = torch.tensordot(weights, self.feature_map.squeeze(0), dims=([0], [0]))
            # cam = torch.nn.functional.relu(cam)

            # Avoid NaNs in normalization
            if torch.isnan(cam).any():
                raise ValueError("CAM contains NaN values before normalization.")
            cam = cam.detach().cpu().numpy()

            cam_min, cam_max = cam.min(), cam.max()
            if self.verbose:
                print(cam_min, cam_max)
            if cam_max != cam_min:
                # Scale by the largest magnitude; a constant map is left unscaled
                cam /= np.abs(cam).max()
            # Shift so that a zero activation maps to 0.5
            cam += 0.5
            cams.append(cam)
        return output, cams

def get_grad_cam(model, img_name, data_name):
    """Compute Grad-CAM maps of the last convolutional layer of ``model`` for one biopsy.

    Returns ``(output, cams)``. If the model has exactly two outputs both maps are
    returned; otherwise only the maps of outputs 1 and 2 are kept (OE and NM in the
    order of P53_CLASS_CODES).
    """
    img = get_img_const_spacing(img_name, data_name, 4)
    last_conv_layer_name, last_conv_layer = get_last_conv_layer(model)
    grad_cam = GradCam(model, last_conv_layer, last_conv_layer_name)
    try:
        output, cams = grad_cam(img.to(device))
    finally:
        # Always detach the hooks; otherwise every call leaves two more hooks on the layer
        grad_cam.remove_hooks()
    cams = cams if len(cams) == 2 else cams[1:3]
    return output, cams


# Load the best checkpoint of the full-biopsy double-binary model.
# model_kwargs, best_checkpoint_dict and load_model are defined in earlier cells.
model_name = "fb_db_spacing4"
checkpoint_path = model_kwargs[model_name]["checkpoint_paths"][best_checkpoint_dict[model_name]]
model = load_model(model_kwargs[model_name]["model_class"], checkpoint_path)
model.eval()

# Compute the Grad-CAM maps for a single test biopsy.
# NOTE(review): img_name is the integer 7 here, while the other cells use string ids
# such as "168_3" -- confirm this id still exists in the test set.
img_name = 7
data_name = "test"
output, cams = get_grad_cam(model, img_name, data_name)

# Plot the maps with plot_heatmap (defined in an earlier cell)
plot_heatmap(cams, img_name, data_name, downsample=0.25, in_one=False)

## Receptive Field

In [None]:
# Visualize how large ResNet18s receptive field is for different image sizes
from resnet import ResNetModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt

# Create a dummy input image: all zeros except for two single pixels set to 1
img_size = 2048
img = torch.zeros(1, 3, img_size, img_size)
img[:, :, img_size//2, img_size//2] = 1
img[:, :, img_size//4, img_size//4] = 1
# img[:, :, img_size//4:3*img_size//4, img_size//4:3*img_size//4] = 1

plt.imshow(img.squeeze().numpy().transpose(1, 2, 0))
plt.show()

# Load the model
# models.resnet18() without weights is randomly initialised, so the
# initialisation is seeded to make the figure below identical on every run.
# fork_rng restores the global CPU RNG state on exit (devices=[] means no CUDA
# generators are forked), so other cells are not affected by the seed.
with torch.random.fork_rng(devices=[]):
    torch.manual_seed(0)
    model = models.resnet18()
# Set all parameters to 1
# for param in model.parameters():
#     param.data.fill_(1)
# Cut off before the pooling layer (drops the last two children of the network)
# model = nn.Sequential(*list(list(model.children())[0].children())[:-2])
model = nn.Sequential(*list(model.children())[:-2])
# display(model)
model.eval()

# Get the output of the model
with torch.no_grad():
    output = model(img)

# Plot the output: per-location maximum over the channel axis
output = output.squeeze().numpy().transpose(1, 2, 0).max(axis=2)
# display(output)
plt.imshow(output, cmap='gray', vmin=0, vmax=1)
plt.show()

# Demo

In [None]:
# Find biopsies with highest predicted probability for WT, OE, NM and DC respectively, specifically for the CLAM_db_m model
# Depends on names from earlier cells: load_results, get_labels, best_checkpoint_dict,
# load_attention_maps, load_img, load_mask and vis_CLAM.
model_name = "CLAM_db_m"
data_name = "test"
results = load_results(model_name, data_name)
labels = get_labels(data_name)
best_checkpoint_idx = best_checkpoint_dict[model_name]

# Keys of all biopsies in the results. This is defined before its first use so
# the cell does not silently pick up an `indices` list left over from another cell.
indices = [k for k in results["status_probs"]]
# Checkpoint names are read from the first biopsy's entry
best_checkpoint_name = list(results["status_probs"][indices[0]])[best_checkpoint_idx]

attention_maps = load_attention_maps(model_name, data_name)
non_empty_patch_indices = torch.load(os.path.join(BASE_DIR[data_name], "non_empty_patch_indices_gs256.pt"))
with open(os.path.join(BASE_DIR[data_name], "biopsy_dims.json"), "r") as f:
    biopsy_dims = json.load(f)

# Get the indices of the biopsies with the highest predicted probability for each class
top_indices = {}
for class_idx in range(4):
    top_indices[class_idx] = sorted(indices, key=lambda x: results["status_probs"][x][best_checkpoint_name][0][class_idx], reverse=True)[:5]
    print(top_indices[class_idx])

# Visualize the biopsies with the highest predicted probability for each class
# (one row per class, one column per biopsy)
n_cols = 5
n_rows = 4
fig, ax = plt.subplots(n_rows, n_cols, figsize=(12, n_rows*4.5))
for i, class_idx in enumerate(top_indices):
    for j, img_idx in enumerate(top_indices[class_idx]):
        img_name = img_idx
        # Names containing "_" are loaded from "pathxl", all others from "test".
        # A separate variable is used so the cell-level `data_name` is not overwritten.
        img_data_name = "pathxl" if "_" in str(img_name) else "test"
        img = load_img(img_name, img_data_name)
        # Resize to be 0.25 of the original size
        img = cv2.resize(img, (0,0), fx=0.25, fy=0.25)
        h, w = img.shape[0], img.shape[1]

        # Open the two masks and draw their outer contours on separate copies of the image.
        # NOTE(review): masks are always loaded from "test", even when the image
        # came from "pathxl" -- confirm this is intended.
        img_with_mask = img.copy()
        mask = load_mask(img_name, "test")
        if mask is not None:
            mask = cv2.resize(mask, (w,h))
            # Element 0 of the findContours result is passed on as the contour list
            contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(img_with_mask, contours[0], -1, (0, 0, 0), 2)
        img_with_OE_mask = img.copy()
        mask = load_mask(f"{img_name}OE", "test")
        if mask is not None:
            mask = cv2.resize(mask, (w,h))
            contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(img_with_OE_mask, contours[0], -1, (0, 0, 0), 2)

        # Load heatmaps
        heatmaps = vis_CLAM(None, img_name, None, attention_maps, non_empty_patch_indices, biopsy_dims)

        # Resize the heatmaps to the (downscaled) image size
        heatmaps = [cv2.resize(hm, (w, h)) for hm in heatmaps]

        # Stack the underlays and the heatmaps vertically with make_grid (nrow=1),
        # so both grids get identical padding and stay aligned
        img_underlay = [torch.tensor(img_with_OE_mask).permute(2, 0, 1),
                        torch.tensor(img_with_mask).permute(2, 0, 1)]
        img_underlay = torchvision.utils.make_grid(img_underlay, nrow=1, normalize=True).permute(1, 2, 0).numpy()
        # np.stack first: building a tensor straight from a list of ndarrays is slow and warns
        heatmap_stack = torch.as_tensor(np.stack(heatmaps)).unsqueeze(1)
        heatmap = torchvision.utils.make_grid(heatmap_stack, nrow=1, normalize=False).permute(1, 2, 0).numpy()[:,:,0]

        ax[i, j].imshow(img_underlay)
        ax[i, j].imshow(heatmap, cmap='jet', vmin=0, vmax=1, alpha=0.3)
        # ax[i, j].imshow(img)

        ax[i, j].axis('off')

        label = P53_CLASS_CODES[labels[img_name]]
        prediction = P53_CLASS_CODES[np.argmax(results["status_probs"][img_name][best_checkpoint_name])]
        # The number shown is the probability of this row's class, not of the predicted class
        title = f"({img_name}) Label {label}\nPrediction {prediction} {results['status_probs'][img_name][best_checkpoint_name][0][class_idx]:.2f}"
        ax[i, j].set_title(title, 
                        #    align text right
                            loc='right',
                            )
plt.tight_layout()

In [None]:
# Hand-picked biopsies and, for each, the side used when cropping with cut_img
# Depends on names from earlier cells: load_img, load_mask, cut_img, vis_CLAM,
# color_dict_RGB, attention_maps, non_empty_patch_indices and biopsy_dims.
indices = [192, 34, 0]
cuts = ["right", "left", "left"]


def _draw_mask_outline(mask, size, cut, targets):
    """Outline `mask` on every image in `targets`.

    `mask` is resized to `size` (width, height) and cropped with `cut_img(mask, cut)`
    so it lines up with the cropped image. `targets` is a list of
    `(image, color, thickness)` tuples; each image is drawn on in place.
    Does nothing when `mask` is None.
    """
    if mask is None:
        return
    mask = cut_img(cv2.resize(mask, size), cut)
    # Element 0 of the findContours result is passed on as the contour list
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
    for target, color, thickness in targets:
        cv2.drawContours(target, contours, -1, color, thickness)


# Plot one row per biopsy with three columns:
# column 0 is the image with both mask contours, column 1 the first heatmap over
# the image with the OE contour, and column 2 the second heatmap over the image
# with the NM contour
n_cols = 3
n_rows = len(indices)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(12, n_rows*4.5))
for i, img_idx in enumerate(indices):
    img_name = img_idx
    img = load_img(img_name, "test")
    # Resize to be 0.25 of the original size
    img = cv2.resize(img, (0,0), fx=0.25, fy=0.25)

    # Remember the pre-crop size (masks and heatmaps are resized to it), then crop on the cuts[i] side
    h, w = img.shape[0], img.shape[1]
    img = cut_img(img, cuts[i])

    # Open masks and draw contours: thin black on the per-class images, colored on the combined one
    img_with_NM_mask = img.copy()
    img_with_OE_mask = img.copy()
    img_with_both_masks = img.copy()
    _draw_mask_outline(load_mask(img_name, "test"), (w, h), cuts[i],
                       [(img_with_NM_mask, (0,0,0), 1),
                        (img_with_both_masks, color_dict_RGB["b"], 2)])
    _draw_mask_outline(load_mask(f"{img_name}OE", "test"), (w, h), cuts[i],
                       [(img_with_OE_mask, (0,0,0), 1),
                        (img_with_both_masks, color_dict_RGB["o"], 2)])

    # Label each row with a panel letter in the top left corner
    ax[i, 0].imshow(img_with_both_masks)
    ax[i, 0].text(5, 5, ['a','b','c','d'][i], fontsize=20, color='black', verticalalignment='top', horizontalalignment='left')
    ax[i, 0].axis('off')

    # Load heatmaps
    heatmaps = vis_CLAM(None, img_name, None, attention_maps, non_empty_patch_indices, biopsy_dims)

    # Resize the heatmaps to the pre-crop image size
    heatmaps = [cv2.resize(hm, (w, h)) for hm in heatmaps]

    # Cut the heatmaps to the same size as the image
    heatmaps = [cut_img(hm, cuts[i]) for hm in heatmaps]

    # Plot the img with the heatmaps overlaid
    ax[i, 1].imshow(img_with_OE_mask)
    ax[i, 1].imshow(heatmaps[0], cmap='jet', vmin=0, vmax=1, alpha=0.3)
    ax[i, 1].axis('off')

    ax[i, 2].imshow(img_with_NM_mask)
    ax[i, 2].imshow(heatmaps[1], cmap='jet', vmin=0, vmax=1, alpha=0.3)
    ax[i, 2].axis('off')
plt.tight_layout()

# Make subplots close to each other
plt.subplots_adjust(hspace=-0.25, wspace=0.05)

ax[0, 0].set_title("biopsy with annotations")
ax[0, 1].set_title("overexpression heatmap")
ax[0, 2].set_title("null-mutation heatmap")

# Legend colors mirror the contour colors used on the combined image ("o" for the OE mask, "b" for the other mask)
ax[0, 0].legend(handles=[
    plt.Rectangle((0,0),1,1,fc=color_dict["o"], linewidth=1, label="overexpression"),
    plt.Rectangle((0,0),1,1,fc=color_dict["b"], linewidth=1, label="null-mutation"),
], loc='lower left')

plt.show()

# Example patches

In [None]:
# Find heatmaps with highest predicted probability for OE and NM respectively
# Depends on names from earlier cells: load_results, get_labels, best_checkpoint_dict,
# model_heatmaps, load_img and load_mask.
model_name = "fb_db_spacing4"
data_name = "test"
results = load_results(model_name, data_name)
labels = get_labels(data_name)
best_checkpoint_idx = best_checkpoint_dict[model_name]

# Keys of all images in the results. This is defined before its first use so
# the cell does not silently pick up an `indices` list left over from another cell.
indices = [k for k in results["status_probs"]]
# Checkpoint names are read from the first image's entry
best_checkpoint_name = list(results["status_probs"][indices[0]])[best_checkpoint_idx]

# Get the indices of the heatmaps with the highest max for each class
top_indices = {}
heatmap_type = "inclusion"
heatmaps = {img_name: model_heatmaps[data_name][model_name][img_name][heatmap_type] for img_name in indices}
for class_idx in range(2):
    top_indices[class_idx] = sorted(indices, key=lambda x: max([hm.max() for hm in heatmaps[x][class_idx]]), reverse=True)[:5]
    print(top_indices[class_idx])

# Also get indices of the heatmaps with the lowest min.
# NOTE(review): only the class-0 heatmap is considered here, not both classes.
bottom_indices = sorted(indices, key=lambda x: min([hm.min() for hm in heatmaps[x][0]]), reverse=False)[:5]

# Manually curated selection; replaces the automatic ranking printed above
top_indices = {
    0: [32, 1031, 222],
    1: [1416, 240, 1489],
}

# Side length (in pixels) of the square patch cut out around each heatmap maximum
patch_size = 256

# Visualize the patch around the heatmap maximum for each selected image.
# The grid is sized from the selection so that no empty panels are left over.
n_rows = len(top_indices)
n_cols = max(len(img_names) for img_names in top_indices.values())
fig, ax = plt.subplots(n_rows, n_cols, figsize=(12, n_rows*4.5), squeeze=False)
for panel in ax.ravel():
    panel.axis('off')
for i, class_idx in enumerate(top_indices):
    for j, img_idx in enumerate(top_indices[class_idx]):
        img_name = img_idx
        img = load_img(img_name, data_name)
        h, w = img.shape[0], img.shape[1]

        # For class 1 restrict the search to the annotated mask; otherwise
        # (or when no mask is available) search the whole image
        mask = load_mask(img_name, data_name) if class_idx == 1 else None
        if mask is not None:
            mask = cv2.resize(mask, (w,h)) / 255
        else:
            mask = np.ones((h, w))

        # NOTE(review): cv2.resize needs a NumPy array -- confirm the stored heatmap is not a torch tensor
        heatmap = model_heatmaps[data_name][model_name][img_name][heatmap_type][class_idx]
        # Resize the heatmap to the original image size
        heatmap = cv2.resize(heatmap, (w, h))
        # Zero out the heatmap outside the mask
        heatmap = heatmap * mask
        # Get the indices of the max value in the heatmap
        max_idx = np.unravel_index(heatmap.argmax(), heatmap.shape)
        # Take the patch centred on the max value, shifting the window back inside
        # the image when the maximum lies near a border (a negative slice start
        # would otherwise give an empty or wrapped-around crop)
        top = int(np.clip(max_idx[0] - patch_size // 2, 0, max(h - patch_size, 0)))
        left = int(np.clip(max_idx[1] - patch_size // 2, 0, max(w - patch_size, 0)))
        img = img[top:top+patch_size, left:left+patch_size]
        hm = heatmap[top:top+patch_size, left:left+patch_size]
        
        ax[i, j].imshow(img)
        # ax[i, j].imshow(hm, cmap='jet', vmin=0, vmax=1, alpha=0.3)

        label = P53_CLASS_CODES[labels[img_name]]
        prediction = P53_CLASS_CODES[np.argmax(results["status_probs"][img_name][best_checkpoint_name])]
        # The number shown is the heatmap maximum inside the patch
        title = f"({img_name}) Label {label}\nPrediction {prediction} {hm.max():.2f}"
        ax[i, j].set_title(title, 
                        #    align text right
                            loc='right',
                            )
plt.tight_layout()
