In [983]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
import os
import pandas as pd
from collections import defaultdict

from load_data import DATA_DIR, TRANSFORMS, P53_CLASS_NAMES, \
    convert_presence_probs_to_status_probs
from resnet import ResNetModel, ResNetModelDoubleBinary
# from resnet_patch import ResNetModelDoubleBinary as ResNetModelDoubleBinaryPatch
from pl_clam import CLAM_MB, CLAM_db

# Short codes for the four p53 classes.
# NOTE(review): presumably in the same order as P53_CLASS_NAMES (WT, overexpression,
# null mutation, double clone) -- confirm against load_data.
P53_CLASS_CODES = ["WT", "OE", "NM", "DC"]

# Sibling dataset folders located next to DATA_DIR.
BOLERO_DIR = os.path.join(DATA_DIR, '..', 'BOLERO')
PATHXL_DIR = os.path.join(DATA_DIR, '..', 'p53_consensus_study')

# Dataset name -> root directory. The helpers below read images from
# "<root>/biopsies/<name>.png" and label CSVs from the root itself.
BASE_DIR = {
    'test': DATA_DIR,
    'bolero': BOLERO_DIR,
    'pathxl': PATHXL_DIR
}

# Output/input folders two levels above DATA_DIR.
# NOTE(review): save_results() and the inference loop write to VIS_DIR, while load_results()
# reads from RESULTS_DIR -- confirm which of the two is the intended location.
RESULTS_DIR = os.path.join(DATA_DIR, '..', '..', 'results')
os.makedirs(RESULTS_DIR, exist_ok=True)
VIS_DIR = os.path.join(DATA_DIR, '..', '..', 'visualizations')
MODELS_DIR = os.path.join(DATA_DIR, '..', '..', 'models')

# Dataset name -> file with precomputed patch latents, consumed by load_patch_latents().
# NOTE(review): the "test" entry points at a "__backup" file, unlike the other two -- confirm.
bag_latent_paths = {
    "test":   os.path.join(BASE_DIR["test"], "bag_latents_gs256_retccl__backup.pt"),
    "bolero": os.path.join(BASE_DIR["bolero"], "bag_latents_gs256_retccl.pt"),
    "pathxl": os.path.join(BASE_DIR["pathxl"], "bag_latents_gs256_retccl.pt"),
}

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
print("Device: {}".format(device))

# Named plotting palette; the "_p" suffix marks the pale variant of a colour.
color_dict = {
    "r": '#DA4C4C', # Red
    "o": '#E57439', # Orange
    "y": '#EDB732', # Yellow
    "g": '#479A5F', # Green
    "lb": '#5BC5DB', # Light blue
    "b": '#5387DD', # Blue
    "p": '#7D54B2', # Purple
    "pi": '#E87B9F', # Pink
#  '#229487', # Dark green/Turquoise
#  '#C565C7', # Lilac
    "r_p": '#E89393', # Pale red
    "o_p": '#EFAB88', # Pale orange
    "y_p": '#F4D384', # Pale yellow
    "g_p": '#90C29F', # Pale green
    "lb_p":'#9CDCE9', # Pale light blue
    "b_p": '#98B7EA', # Pale blue
    "p_p": "#B198D0", # Pale purple
}
# Same colours as a plain list, in the insertion order of color_dict.
colors = list(color_dict.values())

Device: cuda


# Load Models

In [219]:
# fb: full biopsy, db: double binary, gs: grid spacing
# Registry of every evaluated model: model name -> settings dict. Each entry ends up with
# "model_class", "checkpoint_paths" and either "gs" (CLAM variants) or "spacing" (fb variants).
model_kwargs = {
    "CLAM":     {"model_class": CLAM_MB, "gs": 256},
    "CLAM_db":  {"model_class": CLAM_db, "gs": 256},
    "CLAM_m":   {"model_class": CLAM_MB, "gs": 256},
    "CLAM_db_m":{"model_class": CLAM_db, "gs": 256},
}
for name in ["fb_db", "fb"]:
    for spacing in [2, 4, 8, 16, 32, 64, 128, 256]:
        if "db" not in name and spacing == 2: # Skipped this one because it's too slow
            continue
        model_name = f"{name}_spacing{spacing}"
        model_kwargs[model_name] = {"spacing": spacing}

for model_name in model_kwargs:
    # "fb_db" must be tested before "fb" because the latter is a substring of the former.
    if "fb_db" in model_name:
        model_kwargs[model_name]["model_class"] = ResNetModelDoubleBinary
    elif "fb" in model_name:
        model_kwargs[model_name]["model_class"] = ResNetModel
    checkpoint_dir = os.path.join(MODELS_DIR, model_name)
    # Creating the folder keeps os.listdir from failing for models without checkpoints yet.
    os.makedirs(checkpoint_dir, exist_ok=True)
    # os.listdir returns entries in arbitrary order; sort so that the checkpoint order (and
    # the index printed by the inference loop) is the same on every run and machine.
    model_kwargs[model_name]["checkpoint_paths"] = sorted(
        os.path.join(checkpoint_dir, f) for f in os.listdir(checkpoint_dir) if f.endswith(".ckpt")
    )

def load_model(model_class, checkpoint_path):
    """Instantiate ``model_class`` from a checkpoint and prepare it for inference.

    Args:
        model_class: Model class to build. Classes whose name contains "CLAM" are
            constructed with default arguments and receive the checkpoint's
            ``"state_dict"``; every other class is restored through its
            ``load_from_checkpoint`` classmethod.
        checkpoint_path: Path to the ``.ckpt`` file.

    Returns:
        The model in eval mode, moved to the notebook-level ``device``.
    """
    if "CLAM" in model_class.__name__: # For some reason pl can't load these models with load_from_checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # The constructor kwargs stored in the checkpoint are deliberately not used here:
        # model = model_class(**checkpoint["model_kwargs"])
        model = model_class()
        model.load_state_dict(checkpoint["state_dict"])
    else:
        # Pass map_location like the CLAM branch does, so checkpoints saved on a GPU can
        # also be restored when this notebook runs on a CPU-only machine.
        model = model_class.load_from_checkpoint(checkpoint_path, map_location=device)
    model.eval()
    model.to(device)
    return model

In [220]:
# Print how many checkpoints we have for each model
for model_name, settings in model_kwargs.items():
    # One summary line per registered model.
    n_checkpoints = len(settings['checkpoint_paths'])
    print(f"{model_name}: {n_checkpoints} checkpoints")

CLAM: 5 checkpoints
CLAM_db: 5 checkpoints
CLAM_m: 5 checkpoints
CLAM_db_m: 5 checkpoints
fb_db_spacing2: 5 checkpoints
fb_db_spacing4: 5 checkpoints
fb_db_spacing8: 4 checkpoints
fb_db_spacing16: 5 checkpoints
fb_db_spacing32: 5 checkpoints
fb_db_spacing64: 5 checkpoints
fb_db_spacing128: 5 checkpoints
fb_db_spacing256: 5 checkpoints
fb_spacing4: 5 checkpoints
fb_spacing8: 5 checkpoints
fb_spacing16: 5 checkpoints
fb_spacing32: 5 checkpoints
fb_spacing64: 5 checkpoints
fb_spacing128: 5 checkpoints
fb_spacing256: 5 checkpoints


# Inference functions

In [221]:
def load_img(img_name, data_name):
    """Read ``<BASE_DIR[data_name]>/biopsies/<img_name>.png`` with ``plt.imread``.

    Args:
        img_name: File name of the biopsy image without the ``.png`` extension.
        data_name: Key into ``BASE_DIR`` selecting the dataset root.

    Returns:
        The image array as returned by ``plt.imread``.
    """
    biopsy_dir = os.path.join(BASE_DIR[data_name], 'biopsies')
    return plt.imread(os.path.join(biopsy_dir, f"{img_name}.png")) # (H, W, C) float32
    
def load_patch_latents(img_name, data_name, bag_latents):
    """Look up the precomputed patch latents of one biopsy.

    Args:
        img_name: Biopsy identifier. For "bolero" it has the form ``"<slide>_<biopsy>"``
            with both parts parseable as integers.
        data_name: One of "test", "bolero" or "pathxl"; selects how ``bag_latents``
            is indexed.
        bag_latents: Container loaded from the dataset's latent file. For "test" and
            "pathxl" it is indexed by ``img_name``; for "bolero" by the integer slide id.

    Returns:
        The latents of the requested biopsy.

    Raises:
        ValueError: If ``data_name`` is not one of the supported datasets (previously
            this returned None and failed later with an unrelated AttributeError).
    """
    if data_name in ("test", "pathxl"):
        # Stored with a singleton axis at position 1, which is dropped here.
        return bag_latents[img_name].squeeze(1) # (N, 2048)
    if data_name == "bolero":
        slide_name, biopsy_name = img_name.split("_")
        slide_latents = bag_latents[int(slide_name)] # (n_biopsies, N, 2048)
        # NOTE(review): the +1 offset is kept from the original code; confirm how the
        # biopsy number in the file name maps onto the first axis of the slide latents.
        return slide_latents[int(biopsy_name)+1] # (N, 2048)
    raise ValueError(f"Unknown data_name {data_name!r}; expected 'test', 'bolero' or 'pathxl'")


def call_constant_size(model, img_name, data_name, size, **kwargs):
    """Run ``model`` on one biopsy resized to a fixed ``size`` x ``size`` square.

    The image is loaded with ``load_img``, moved to channel-first layout with a batch
    axis, bilinearly resized, passed through ``TRANSFORMS['normalize']`` and fed to the
    model without gradient tracking. Extra keyword arguments are accepted and ignored.

    Returns:
        The model output as a NumPy array.
    """
    pixels = torch.tensor(load_img(img_name, data_name))
    batch = pixels.permute(2, 0, 1).unsqueeze(0)
    batch = torch.nn.functional.interpolate(batch, size=(size, size), mode='bilinear')
    batch = TRANSFORMS['normalize'](batch)
    with torch.no_grad():
        output = model(batch.to(device))
    return output.cpu().detach().numpy()

def call_constant_spacing(model, img_name, data_name, spacing, **kwargs):
    """Run ``model`` on one biopsy downsampled by an integer ``spacing`` factor.

    The target size is the original height and width floor-divided by ``spacing``.
    The image is loaded with ``load_img``, moved to channel-first layout with a batch
    axis, bilinearly resized, passed through ``TRANSFORMS['normalize']`` and fed to the
    model without gradient tracking. Extra keyword arguments are accepted and ignored.

    Returns:
        The model output as a NumPy array.
    """
    pixels = load_img(img_name, data_name)
    target_size = (pixels.shape[0] // spacing, pixels.shape[1] // spacing)
    batch = torch.tensor(pixels).permute(2, 0, 1).unsqueeze(0)
    batch = torch.nn.functional.interpolate(batch, size=target_size, mode='bilinear')
    batch = TRANSFORMS['normalize'](batch)
    with torch.no_grad():
        output = model(batch.to(device))
    return output.cpu().detach().numpy()

def call_CLAM(model, img_name, data_name, bag_latents, **kwargs):
    """Run a CLAM model on the precomputed patch latents of one biopsy.

    The latents are fetched with ``load_patch_latents``, given a leading batch axis and
    passed to the model without gradient tracking. Extra keyword arguments are ignored.

    Returns:
        Tuple ``(Y_prob, A_raw)`` of NumPy arrays taken from the model's five outputs.
    """
    bag = load_patch_latents(img_name, data_name, bag_latents).unsqueeze(0)
    with torch.no_grad():
        outputs = model(bag.to(device))
    # The model returns (logits, Y_prob, Y_hat, A_raw, results_dict); only two are kept.
    _, Y_prob, _, A_raw, _ = outputs
    return Y_prob.cpu().detach().numpy(), A_raw.cpu().detach().numpy()
    

# Attach to every registry entry the function that knows how to feed that model.
for model_name, settings in model_kwargs.items():
    if "size" in settings:
        call_fn = call_constant_size
    elif "spacing" in settings:
        call_fn = call_constant_spacing
    elif "CLAM" in model_name:
        call_fn = call_CLAM
    else:
        raise ValueError(f"Unknown call type for {model_name}")
    settings["call"] = call_fn

# Eval functions

In [1121]:
def get_img_names(data_name='test'):
    """List the biopsy identifiers that belong to a dataset.

    Args:
        data_name: "test" reads the ``id`` column of ``test.csv``; any name starting with
            "test+pathxl" concatenates the "test" and "pathxl" lists; every other name
            must be a key of ``BASE_DIR`` and is resolved by listing its ``biopsies``
            folder.

    Returns:
        List of image names without file extension.
    """
    if data_name.startswith("test+pathxl"):
        return get_img_names("test") + get_img_names("pathxl")
    if data_name == "test":
        return pd.read_csv(os.path.join(BASE_DIR[data_name], 'test.csv'))["id"].tolist()
    biopsy_dir = os.path.join(BASE_DIR[data_name], 'biopsies')
    # load_img() always opens "<name>.png", so only PNG entries are valid names. splitext
    # keeps names that contain dots intact, and sorting makes the order reproducible
    # (os.listdir returns entries in arbitrary order).
    return sorted(
        os.path.splitext(file_name)[0]
        for file_name in os.listdir(biopsy_dir)
        if file_name.endswith('.png')
    )

def get_results(model, mdl_kwargs, data_name='test', bag_latents=None):
    """Run one model over every image of a dataset.

    Args:
        model: Loaded model to evaluate.
        mdl_kwargs: Registry entry of the model; its ``"call"`` function is invoked per
            image and the whole dict is forwarded to it as keyword arguments.
        data_name: Dataset whose image names are obtained from ``get_img_names``.
        bag_latents: Precomputed patch latents, forwarded to the call function.

    Returns:
        Dict mapping image name to whatever the call function returned for it.
    """
    results = {}
    for img_name in tqdm(get_img_names(data_name)):
        output = mdl_kwargs["call"](model, img_name, data_name, bag_latents=bag_latents, **mdl_kwargs)
        results[img_name] = output
    return results

def save_results(results, model_name, data_name, checkpoint_name):
    """Serialise ``results`` to ``<VIS_DIR>/<model_name>/<data_name>_<checkpoint_name>.pt``.

    The model folder is created when it does not exist yet.
    """
    # NOTE(review): load_results() reads from RESULTS_DIR, not VIS_DIR -- confirm which
    # directory is intended before relying on a save/load round trip.
    target_dir = os.path.join(VIS_DIR, model_name)
    os.makedirs(target_dir, exist_ok=True)
    torch.save(results, os.path.join(target_dir, f"{data_name}_{checkpoint_name}.pt"))

In [502]:
# Inference driver: for every dataset, model and checkpoint, compute the per-image outputs
# and cache them on disk. A (dataset, model, checkpoint) combination whose result file
# already exists is skipped, so the cell can be re-run to resume an interrupted sweep.
for data_name in ["test", "bolero", "pathxl"]:
    print(f"\nRUNNING {data_name.upper()}")
    # Loaded once per dataset and passed to every model; only call_CLAM actually reads it.
    bag_latents = torch.load(bag_latent_paths[data_name], map_location=device)
    for model_name, mdl_kwargs in model_kwargs.items():
        for i, checkpoint_path in enumerate(mdl_kwargs["checkpoint_paths"][:]):
            checkpoint_name = os.path.basename(checkpoint_path).replace(".ckpt", "")
            # Must match the path built inside save_results().
            results_path = os.path.join(VIS_DIR, model_name, f"{data_name}_{checkpoint_name}.pt")

            if os.path.exists(results_path):
                # The message names only the model, so it repeats once per checkpoint.
                print(f"Already done {model_name}")
                continue
            print(f"Running {model_name} {i}")
            # NOTE(review): `model` and `results` stay behind as notebook globals after the
            # loop; later cells that use a bare `model` may be picking up this leftover.
            model = load_model(mdl_kwargs["model_class"], checkpoint_path)
            results = get_results(model, mdl_kwargs, data_name=data_name, bag_latents=bag_latents)
            save_results(results, model_name, data_name=data_name, checkpoint_name=checkpoint_name)


RUNNING TEST
Already done CLAM
Already done CLAM
Already done CLAM
Already done CLAM
Already done CLAM
Already done CLAM_db
Already done CLAM_db
Already done CLAM_db
Already done CLAM_db
Already done CLAM_db
Already done CLAM_m
Already done CLAM_m
Already done CLAM_m
Already done CLAM_m
Already done CLAM_m
Already done CLAM_db_m
Already done CLAM_db_m
Already done CLAM_db_m
Already done CLAM_db_m
Already done CLAM_db_m
Already done fb_db_spacing2
Already done fb_db_spacing2
Already done fb_db_spacing2
Already done fb_db_spacing2
Already done fb_db_spacing2
Already done fb_db_spacing4
Already done fb_db_spacing4
Already done fb_db_spacing4
Already done fb_db_spacing4
Already done fb_db_spacing4
Already done fb_db_spacing8
Already done fb_db_spacing8
Already done fb_db_spacing8
Already done fb_db_spacing8
Already done fb_db_spacing16
Already done fb_db_spacing16
Already done fb_db_spacing16
Already done fb_db_spacing16
Already done fb_db_spacing16
Already done fb_db_spacing32
Already do

# Get Metrics
Load results and labels

In [1125]:
"""
fb_db   has a dict like idx: shape (1,2) with the two mutation probabilities
fb      has a dict like idx: shape (1,4) with the four class probabilities
CLAM_db has a dict like idx: tuple of: (
            shape (1,2) with the two mutation probabilities,
            shape (2, n_patches) Attention map 
    )
CLAM    has a dict like idx: tuple of: (
            shape (1,4) with the four class probabilities,
            shape (2, n_patches) Attention map 
        )
"""
# Model-type tag -> names of the values stored per image in that model's result files.
model_type_result_keys = {
    "fb_db": ["presence_probs"],
    "fb": ["status_probs"],
    "CLAM_db": ["presence_probs", "A_raw"],
    "CLAM": ["status_probs", "A_raw"],
}
def get_result_keys(model_name):
    """Return the result keys of the first model-type tag contained in ``model_name``.

    Tags are tried in the insertion order of ``model_type_result_keys``; the order matters
    because "fb" is a substring of "fb_db" and "CLAM" of "CLAM_db". Returns None when no
    tag matches.
    """
    matching = (keys for tag, keys in model_type_result_keys.items() if tag in model_name)
    return next(matching, None)

def load_results(model_name, data_name):
    """Load the cached outputs of every checkpoint of one model on one dataset.

    Args:
        model_name: Registry name of the model; also the sub-folder of ``RESULTS_DIR``
            that is scanned. Its substring ("CLAM_db", "CLAM", "fb_db" or "fb") decides
            how each file's content is unpacked.
        data_name: Dataset name. Any name starting with "test+pathxl" merges the results
            of "test" and "pathxl"; otherwise files whose name starts with ``data_name``
            are read.

    Returns:
        Nested mapping ``results[key][img_name][checkpoint_name] -> array`` where ``key``
        is "status_probs" and, depending on the model type, "presence_probs" and "A_raw".

    Raises:
        ValueError: If ``model_name`` matches none of the supported model types.
    """
    if data_name.startswith("test+pathxl"):
        merged = defaultdict(lambda: defaultdict(dict))
        # A separate loop variable avoids overwriting the ``data_name`` parameter.
        for part_name in ["test", "pathxl"]:
            part_results = load_results(model_name, part_name)
            for key in part_results:
                merged[key].update(part_results[key])
        return merged

    # NOTE(review): save_results() writes to VIS_DIR while this reads RESULTS_DIR --
    # confirm which directory is intended.
    results_dir = os.path.join(RESULTS_DIR, model_name)
    prefix, suffix = f"{data_name}_", ".pt"
    results = defaultdict(lambda: defaultdict(dict))
    # Sorted so that checkpoints are always visited in the same order.
    for file_name in sorted(os.listdir(results_dir)):
        if not file_name.startswith(data_name):
            continue
        result_content = torch.load(os.path.join(results_dir, file_name))
        # Strip only the leading dataset prefix and the trailing extension; str.replace
        # would also remove these substrings from the middle of a checkpoint name.
        checkpoint_name = file_name
        if checkpoint_name.startswith(prefix):
            checkpoint_name = checkpoint_name[len(prefix):]
        if checkpoint_name.endswith(suffix):
            checkpoint_name = checkpoint_name[:-len(suffix)]
        # The *_db checks come first because "CLAM" / "fb" are substrings of them.
        if "CLAM_db" in model_name:
            for img_name, (presence_probs, A_raw) in result_content.items():
                results["presence_probs"][img_name][checkpoint_name] = presence_probs
                results["status_probs"][img_name][checkpoint_name] = convert_presence_probs_to_status_probs(torch.tensor(presence_probs)).numpy()
                results["A_raw"][img_name][checkpoint_name] = A_raw
        elif "CLAM" in model_name:
            for img_name, (status_probs, A_raw) in result_content.items():
                results["status_probs"][img_name][checkpoint_name] = status_probs
                results["A_raw"][img_name][checkpoint_name] = A_raw
        elif "fb_db" in model_name:
            for img_name, presence_probs in result_content.items():
                results["presence_probs"][img_name][checkpoint_name] = presence_probs
                results["status_probs"][img_name][checkpoint_name] = convert_presence_probs_to_status_probs(torch.tensor(presence_probs)).numpy()
        elif "fb" in model_name:
            for img_name, status_probs in result_content.items():
                # The stored values are passed through a softmax over the class axis
                # before being used as status probabilities.
                results["status_probs"][img_name][checkpoint_name] = torch.nn.functional.softmax(torch.tensor(status_probs), dim=1).numpy()
        else:
            raise ValueError(f"Unsupported model type {model_name}")
    return results

def get_labels(data_name):
    """Return the ground-truth labels of a dataset as a dict.

    Args:
        data_name: One of
            * "test": ``id`` -> ``label`` from ``test.csv``;
            * "bolero": row position (after sorting by "Case ID") -> remapped "GS" value;
            * "pathxl": ``"<id>_<biopsy_nr>"`` -> class index, keeping rows with
              "concordance %" >= 75;
            * "pathxl-100": as "pathxl" but keeping only rows with "concordance %" == 100;
            * any other name starting with "pathxl": as "pathxl" without a concordance
              filter;
            * "test+pathxl" / "test+pathxl-100": the "test" labels updated with the
              corresponding pathxl labels.

    Returns:
        Dict mapping sample key to integer label.

    Raises:
        ValueError: If ``data_name`` is none of the above (previously this silently
            returned None).
    """
    if data_name == "test":
        return pd.read_csv(os.path.join(BASE_DIR[data_name], 'test.csv')).set_index("id").to_dict(orient='dict')['label']
    if data_name == "bolero":
        labels = pd.read_csv(os.path.join(BASE_DIR[data_name], 'P53_BOLERO_T.csv'))
        labels = labels.sort_values(by="Case ID")
        labels = labels.reset_index(drop=True)
        # Map GS to {1:0, 2:1, 3:2, 4:4} where 4 is unknown
        labels["GS"] = labels["GS"].map({1:0, 2:1, 3:2, 4:4})
        # Only keep GS column; the dict keys are the 0-based row positions after sorting
        return labels[["GS"]].to_dict(orient='dict')["GS"]
    if data_name.startswith("pathxl"):
        labels = pd.read_csv(os.path.join(BASE_DIR["pathxl"], 'labels.csv'))
        # idx is id column and biopsy_nr column separated by _
        labels["idx"] = labels["id"].astype(str) + "_" + labels["biopsy_nr"].astype(str)
        labels = labels.set_index("idx")
        # Sort by id primarily and biopsy_nr secondarily
        labels = labels.sort_values(by=["id", "biopsy_nr"])
        # Map label names to class indices
        mapping = {"WT":0, "Overexpression":1, "Null":2, "Double clones":3}
        labels["label"] = labels["label"].map(mapping)
        if data_name == "pathxl": # Filter out any concordance % < 75
            labels = labels[labels["concordance %"] >= 75]
        elif data_name == "pathxl-100":
            labels = labels[labels["concordance %"] == 100]
        return labels[["label"]].to_dict(orient='dict')["label"]
    if data_name in ("test+pathxl", "test+pathxl-100"):
        # Drop the "test+" prefix to obtain the pathxl variant to merge in.
        labels = get_labels("test")
        labels.update(get_labels(data_name[len("test+"):]))
        return labels
    raise ValueError(f"Unknown data_name {data_name!r}")
    
# load_results("fb_spacing4", "test")["status_probs"]
# get_labels("test+pathxl-100")

# Visualizations

## CLAM

In [None]:
# Overlay the CLAM patch-attention scores of one test biopsy on the biopsy image.
# NOTE(review): this cell uses `test_dataset`, which is not defined anywhere in the visible
# part of the notebook -- it fails on a fresh kernel unless defined elsewhere.
import json
with open("../CLAM/results/test_s1/clam_outputs.json", "r") as f:
    clam_outputs = json.load(f)

# Per-biopsy indices of the patches that were kept (non-empty), used to place the
# attention values back onto the full patch grid below.
non_empty_patch_indices_by_biopsy = torch.load(os.path.join(DATA_DIR, "non_empty_patch_indices_gs256_relaxed.pt"))

# Position in clam_outputs / test_dataset.labels of the sample to visualise.
test_idx = 19
output = clam_outputs[test_idx]
label = output['labels']
pred = output['preds']
class_probs = output['class_probs']
patch_attention = output['patch_attention']

# Biopsy id (file name stem) of the selected sample.
idx = test_dataset.labels[test_idx][0]
img = plt.imread(os.path.join(DATA_DIR, "biopsies", f"{idx}.png")) # shape (h, w, 3)
patch_size = 256
# _, non_empty_indices = process_image(torch.tensor(img).permute(2,0,1), patch_size)
non_empty_indices = non_empty_patch_indices_by_biopsy[idx]
# Patch grid size; rounded rather than floored, with at least one row/column.
# NOTE(review): assumes the same grid convention as the code that produced the indices.
patch_rows = max(round(img.shape[0] / patch_size), 1)
patch_cols = max(round(img.shape[1] / patch_size), 1)
# The patch attention is a 1D array and corresponds to the non-empty indices
all_patch_attention = torch.zeros(patch_rows * patch_cols)
all_patch_attention[non_empty_indices] = torch.tensor(patch_attention)
all_patch_attention = all_patch_attention.reshape(patch_rows, patch_cols)
# cv2.resize takes (width, height), hence the swapped shape indices.
heatmap = cv2.resize(all_patch_attention.numpy(), (img.shape[1], img.shape[0]))
print(heatmap.min(), heatmap.max())
# heatmap = torch.nn.functional.sigmoid(torch.tensor(heatmap-heatmap.min())/10).numpy()

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(img)
ax[0].set_title(f"Label: {P53_CLASS_NAMES[label]}, Pred: {P53_CLASS_NAMES[pred]}")

# For label 2 a mask image is available; its outline is drawn onto the image.
if label == 2:
    mask = plt.imread(os.path.join(DATA_DIR, "masks", f"{idx}.png"))
    # Draw black contour on the img
    contours, _ = cv2.findContours((mask > 0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # drawContours modifies `img` in place.
    cv2.drawContours(img, contours, -1, (0, 0, 0), 5)

ax[1].imshow(img)
ax[1].imshow(heatmap, alpha=0.3, cmap='jet'
            #  , vmin=5, vmax=10
             )
ax[1].set_title(f"Label: {P53_CLASS_NAMES[label]}, Pred: {P53_CLASS_NAMES[pred]}")

plt.show()

## Full-Biopsy

### GradCAM for non-db

In [None]:
# Visualize gradient-based class activation maps (Grad-CAM)
#
# Grad-CAM is a technique to visualize the regions of the image that are important for the model's
# prediction. It does this by computing the gradients of the output class with respect to the
# feature maps of the last convolutional layer of the model. The gradients are averaged over the
# spatial dimensions to give one weight per channel, and the heatmap is the ReLU of the
# resulting weighted sum of the feature maps. The heatmap is then resized and overlaid on the
# original image to visualize the important regions.
#
# The code below is adapted from the PyTorch Grad-CAM tutorial:
# https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#grad-cam
def get_last_conv_layer(model):
    """Return ``(name, module)`` of the last ``Conv2d`` in ``model.named_modules()`` order.

    Raises:
        ValueError: If the model contains no ``torch.nn.Conv2d`` module.
    """
    conv_layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Conv2d)
    ]
    if not conv_layers:
        raise ValueError("No convolutional layer found in the model")
    return conv_layers[-1]

class GradCam:
    """Grad-CAM for a single target layer of a model.

    On construction the model is put in eval mode and two hooks are registered on
    ``target_layer``: a forward hook that stores the layer output and a backward hook that
    stores the gradient with respect to that output. Call ``remove_hooks()`` when the
    object is no longer needed, otherwise the hooks stay attached to the layer.
    """

    def __init__(self, model, target_layer, target_layer_name):
        self.model = model
        self.target_layer = target_layer
        self.target_layer_name = target_layer_name
        self.model.eval()
        self.feature_grad = None
        self.feature_map = None
        # Handles of the registered hooks, kept so that they can be detached again.
        self._hook_handles = []
        self.hook_feature_map()
        self.hook_feature_grad()

    def hook_feature_map(self):
        """Register the forward hook that records the target layer's output."""
        def hook_fn(module, input, output):
            self.feature_map = output
        self._hook_handles.append(self.target_layer.register_forward_hook(hook_fn))

    def hook_feature_grad(self):
        """Register the backward hook that records the gradient of the layer output."""
        def hook_fn(module, grad_input, grad_output):
            self.feature_grad = grad_output[0]
        # register_backward_hook is deprecated; the "full" variant is its replacement.
        self._hook_handles.append(self.target_layer.register_full_backward_hook(hook_fn))

    def remove_hooks(self):
        """Detach every hook this object registered on the target layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        return self.model(x)

    def backward(self, output, target_class):
        """Back-propagate a one-hot gradient that selects ``target_class`` of ``output``."""
        self.model.zero_grad()
        one_hot_output = torch.zeros((1, output.size()[-1]), dtype=torch.float32, device=output.device)
        one_hot_output[0][target_class] = 1
        output.backward(gradient=one_hot_output, retain_graph=True)

    def __call__(self, x, target_class, img_size):
        """Compute the Grad-CAM heatmap of ``target_class`` for a single-image batch.

        Args:
            x: Input batch containing one image.
            target_class: Index of the model output to explain.
            img_size: Output resolution; an int gives a square map, a ``(height, width)``
                tuple or list gives a rectangular one.

        Returns:
            2-D tensor of the requested size, min-max scaled to [0, 1] unless the map is
            constant (then it is all zeros and a message is printed).

        Raises:
            ValueError: If the hooks did not fire, or the map contains NaNs.
        """
        # Forget values from a previous call so the check below detects hooks that did
        # not fire on *this* pass.
        self.feature_grad = None
        self.feature_map = None

        output = self.forward(x)
        self.backward(output, target_class)
        if self.feature_grad is None or self.feature_map is None:
            raise ValueError("Feature gradients or feature maps are not set. Check hooks.")

        # Channel weights are the spatially averaged gradients. squeeze(0) drops only the
        # batch axis, so a single-channel layer still yields a 1-D weight vector.
        weights = torch.mean(self.feature_grad, dim=(2, 3)).squeeze(0)
        # Detached: the heatmap is a visualisation and should not keep the graph alive.
        feature_map = self.feature_map.detach().squeeze(0)
        cam = torch.tensordot(weights, feature_map, dims=([0], [0]))
        cam = torch.nn.functional.relu(cam)

        # Avoid NaNs in normalization
        if torch.isnan(cam).any():
            raise ValueError("CAM contains NaN values before normalization.")

        if isinstance(img_size, (tuple, list)):
            size = tuple(img_size)
        else:
            size = (img_size, img_size)
        cam = torch.nn.functional.interpolate(cam.unsqueeze(0).unsqueeze(0), size=size, mode='bilinear', align_corners=False)
        cam = cam.squeeze(0).squeeze(0)

        cam_min, cam_max = cam.min(), cam.max()
        cam = cam - cam_min
        if cam_max == cam_min:
            print("CAM has uniform values. Check model and target class.")
        else:
            cam = cam / (cam_max - cam_min)

        return cam

def get_grad_cam(model, img, target_class, img_size):
    """Compute a Grad-CAM heatmap on the last ``Conv2d`` layer of ``model``.

    The temporary hooks are removed again before returning, so repeated calls do not
    accumulate hooks on the layer.
    """
    model.eval()
    last_conv_layer_name, last_conv_layer = get_last_conv_layer(model)
    grad_cam = GradCam(model, last_conv_layer, last_conv_layer_name)
    try:
        return grad_cam(img, target_class, img_size)
    finally:
        grad_cam.remove_hooks()

    
# Visualize the Grad-CAM heatmap for the predicted class
# Doubleclone idx: 56, 46, 53, 54
# NOTE(review): `test_dataset` and `model` are not defined in this cell; they must already
# exist in the kernel (e.g. `model` left over from the inference loop) -- confirm.
idx = np.random.randint(len(test_dataset))
# The random draw above is immediately overridden by a fixed sample index.
idx = 19
print("Image index: ", idx)
img, label = test_dataset[idx]
# Only the last dimension is read, so the image is treated as square.
img_size = img.size(-1)
# Channel-last copy for matplotlib.
show_img = img.permute(1, 2, 0)
img = img.unsqueeze(0).to(device)
outputs = model(img)
_, predicted = torch.max(outputs, 1)
print("Predicted: ", predicted.item())
# Heatmaps are computed for output indices 1 and 2, titled "Overexpression" and
# "Null mutation" below, independent of the predicted class.
grad_cam_class1 = get_grad_cam(model, img, 1, img_size).cpu().detach().numpy().squeeze()
grad_cam_class2 = get_grad_cam(model, img, 2, img_size).cpu().detach().numpy().squeeze()

# Normalize the img
show_img = (show_img - show_img.min()) / (show_img.max() - show_img.min())

# Plot the original image and the Grad-CAM heatmap, and the original image with the heatmap overlayed
fig, ax = plt.subplots(1, 3, figsize=(30, 30))
ax[0].imshow(show_img)
ax[0].set_title("Original image\nLabel: {}\nPredicted: {}".format(P53_CLASS_NAMES[label], P53_CLASS_NAMES[predicted.item()]))
ax[1].imshow(show_img)
ax[1].imshow(grad_cam_class1, vmin=0, vmax=1, alpha=0.5, cmap='jet')
ax[1].set_title("Grad-CAM heatmap\nOverexpression")
ax[2].imshow(show_img)
ax[2].imshow(grad_cam_class2, vmin=0, vmax=1, alpha=0.5, cmap='jet')
ax[2].set_title("Grad-CAM heatmap\nNull mutation")
ax[0].axis("off")
ax[1].axis("off")
ax[2].axis("off")
plt.show()

### GradCAM for db

In [None]:
# Visualize Class Activation Mapping (CAM) For Double Binary Model
#
# Keep in mind that in this case, we have a separate heatmap for the first and second binary classifier head
# The first heatmap will show the regions of the image that are important for the Overexpression prediction
# The second heatmap will show the regions of the image that are important for the Nullmutation prediction

# Visualize the CAM heatmap for the predicted class
# NOTE(review): `test_dataset`, `model_db` and `get_cam` are not defined in the visible part
# of the notebook; this cell needs them to exist in the kernel -- confirm.
# A random sample is drawn on every run (no seed is set in this cell).
idx = np.random.randint(len(test_dataset))
# idx = 19
print("Image index: ", idx)
img, label = test_dataset[idx]
# Only the last dimension is read, so the image is treated as square.
img_size = img.size(-1)
# Channel-last copy for matplotlib.
show_img = img.permute(1, 2, 0)
img = img.unsqueeze(0).to(device)
outputs = model_db(img).cpu().detach().numpy().squeeze()
# Combine the two thresholded head outputs into one index: first head adds 1, second adds 2.
# NOTE(review): presumably matches the order of P53_CLASS_NAMES (0 none, 1 first head only,
# 2 second head only, 3 both) -- confirm.
prediction = int(outputs[0] > 0.5) + 2 * int(outputs[1] > 0.5)
# In this case, the output is a tensor of shape (1, 2) with the independent predictions for each binary classifier
# Therefore we don't apply the argmax function to get the predicted class
print("Predicted: ", P53_CLASS_NAMES[prediction])
grad_cam_class1 = get_grad_cam(model_db, img, 0, img_size).cpu().detach().numpy().squeeze()
grad_cam_class2 = get_grad_cam(model_db, img, 1, img_size).cpu().detach().numpy().squeeze()
cam_class1 = get_cam(model_db, img, 0).cpu().detach().numpy().squeeze()
cam_class2 = get_cam(model_db, img, 1).cpu().detach().numpy().squeeze()

# Normalize the img
show_img = (show_img - show_img.min()) / (show_img.max() - show_img.min())

# Plot the original image and the grad-CAM and CAM heatmap
# Top row: original image and Grad-CAM overlays; bottom row: CAM overlays.
# The bottom-left axes (ax[1, 0]) is never drawn on.
fig, ax = plt.subplots(2, 3, figsize=(15, 10))
ax[0, 0].imshow(show_img)
ax[0, 0].set_title("Original image\nLabel: {}\nPredicted: {}".format(P53_CLASS_NAMES[label], P53_CLASS_NAMES[prediction]))
ax[0, 1].imshow(show_img)
ax[0, 1].imshow(grad_cam_class1, vmin=0, vmax=1, alpha=0.5, cmap='jet')
ax[0, 1].set_title(f"Grad-CAM heatmap\nOverexpression {outputs[0]:.2f}")
ax[0, 2].imshow(show_img)
ax[0, 2].imshow(grad_cam_class2, vmin=0, vmax=1, alpha=0.5, cmap='jet')
ax[0, 2].set_title(f"Grad-CAM heatmap\nNullmutation {outputs[1]:.2f}")
ax[1, 1].imshow(show_img)
ax[1, 1].imshow(cam_class1, cmap='jet', vmin=0, vmax=1, alpha=0.5)
ax[1, 1].set_title(f"CAM heatmap\nOverexpression {outputs[0]:.2f}")
ax[1, 2].imshow(show_img)
ax[1, 2].imshow(cam_class2, cmap='jet', vmin=0, vmax=1, alpha=0.5)
ax[1, 2].set_title(f"CAM heatmap\nNullmutation {outputs[1]:.2f}")
plt.show()

### Occlusion sensitivity

### Inclusion sensitivity

In [None]:
def calc_heatmap(idx):
    """Compute "inclusion sensitivity" heatmaps for one biopsy with the double-binary model.

    The biopsy is resized to 1024x1024 and cut into overlapping 128-pixel patches with a
    stride of 64. Each patch is pasted on its own into the centre of an otherwise uniform
    image filled with a per-channel mean colour, and ``model_db`` is evaluated on that
    image. The two model outputs per patch are shifted so their minimum is zero and then
    averaged over overlapping patches on a coarse grid.

    Relies on the notebook globals ``DATA_DIR``, ``device``, ``test_transform`` and
    ``model_db``.
    NOTE(review): ``test_transform`` and ``model_db`` are not defined in the visible part
    of the notebook -- confirm they exist before this is called.

    Args:
        idx: Biopsy id; the image is read from ``<DATA_DIR>/biopsies/<idx>.png``.

    Returns:
        Tuple ``(overexpression_heatmap, nullmutation_heatmap)`` of 2-D NumPy arrays
        (16x16 with the constants used here), built from model outputs 0 and 1.
    """
    img = plt.imread(os.path.join(DATA_DIR, 'biopsies', f'{idx}.png')) # Shape: (H, W, 3)
    img_size = 1024
    img = torch.tensor(img).permute(2, 0, 1).to(device) # Shape: (3, H, W)
    # Resize image
    img = torch.nn.functional.interpolate(img.unsqueeze(0), size=(img_size, img_size), mode='bilinear', align_corners=False)
    img = test_transform(img) # Shape: (1, 3, H, W)

    # Do the same as above but this time with overlap between the patches, to make the heatmap smoother
    # So for example with patch size 256, but step size 64, we will have 16 times more patches
    patch_size = 128
    step_size = 64

    # Pad the image to accommodate the step size when sliding the patches
    pad = (patch_size - step_size)
    img = torch.nn.functional.pad(img, (pad, 0, pad, 0)) # the order is left, right, top, bottom
    # Background canvas, created at the unpadded size (img_size is still 1024 here).
    WT_img = torch.zeros(3, img_size, img_size)
    # From here on img_size is the padded size (1024 + pad).
    img_size = img.shape[-1]
    # Fill with mean color of the image
    # NOTE(review): `img` is already padded at this point, so the zero-valued padding is
    # included in the channel means -- confirm this is intended.
    pixels = img.squeeze().cpu()
    WT_img[0] = pixels[0].mean()
    WT_img[1] = pixels[1].mean()
    WT_img[2] = pixels[2].mean()
    # Pad the canvas the same way so it matches the padded image size.
    WT_img = torch.nn.functional.pad(WT_img, (pad, 0, pad, 0))

    patches_added = WT_img.unsqueeze(0).clone() # Shape: (1, 3, img_size, img_size)
    steps = img_size // step_size
    # Top-left corner at which a full-size patch is centred on the canvas.
    middle = img_size // 2 - patch_size // 2
    # One full canvas per patch position: shape (steps, steps, 3, img_size, img_size).
    # With the constants above this is 17*17 float32 images of 3x1088x1088 (roughly 4 GB).
    patches_added = patches_added.repeat(steps, steps, 1, 1, 1)
    for i in range(0, img_size, step_size):
        for j in range(0, img_size, step_size):
            # Place patch in the middle of patches_added
            patch = img[:, :, i:i+patch_size, j:j+patch_size]
            # Note: `w` is read from the row axis and `h` from the column axis; they are
            # used consistently in that order in the assignment below.
            w, h = patch.shape[-2], patch.shape[-1] # This is necessary because the patch can be smaller than patch_size at the edges
            patches_added[i // step_size, j // step_size, :, middle:middle+w, middle:middle+h] = img[:, :, i:i+patch_size, j:j+patch_size]

    # Reshape the tensor to (B, 3, img_size, img_size)
    patches_added = patches_added.view(-1, 3, img_size, img_size)

    # Crop to the original image size
    patches_added = patches_added[:, :, pad:, pad:]

    # Get the model output for each image
    # Two columns: model output 0 and model output 1 for every patch position.
    diff = torch.zeros(patches_added.shape[0], 2)
    # for i in tqdm(range(patches_added.shape[0]), desc="Generating Heatmap"): # necessary for CUDA memory
    for i in range(patches_added.shape[0]): # necessary for CUDA memory
        current_img = patches_added[i].unsqueeze(0).to(device)
        with torch.no_grad():
            diff[i] = model_db(current_img).cpu().detach().squeeze()
    # Normalize the difference
    # Only the minimum is subtracted; the values are not rescaled to a fixed maximum.
    diff[:, 0] = (diff[:, 0] - diff[:, 0].min())
    diff[:, 1] = (diff[:, 1] - diff[:, 1].min())

    # Make heatmap as grid of patch outputs
    overexpression_heatmap = np.zeros((steps, steps))
    nullmutation_heatmap = np.zeros((steps, steps))

    # Make a mask to keep track of the number of patches that overlap in each pixel
    mask = np.zeros((steps, steps))

    # Number of grid cells one patch spans along each axis.
    steps_per_patch = patch_size // step_size
    for i in range(diff.shape[0]):
        # Flat index -> grid position, matching the (steps, steps) layout used above.
        row = i // steps
        col = i % steps
        overexpression_heatmap[row:row+steps_per_patch, col:col+steps_per_patch] += diff[i, 0].item()
        nullmutation_heatmap[row:row+steps_per_patch, col:col+steps_per_patch] += diff[i, 1].item()
        mask[row:row+steps_per_patch, col:col+steps_per_patch] += 1

    # Normalize the img by dividing by the mask
    overexpression_heatmap /= mask
    nullmutation_heatmap /= mask

    # Crop to the original image size
    # Drop the grid rows/columns that correspond to the top/left padding.
    pad_steps = pad // step_size
    overexpression_heatmap = overexpression_heatmap[pad_steps:, pad_steps:]
    nullmutation_heatmap = nullmutation_heatmap[pad_steps:, pad_steps:]

    return overexpression_heatmap, nullmutation_heatmap


def plot_heatmap(img, overexpression_heatmap, nullmutation_heatmap, mask=None):
    """Show a biopsy next to its overexpression and null-mutation heatmap overlays.

    Args:
        img: Image array of shape (H, W, C). It is not modified.
        overexpression_heatmap: 2-D array, resized to the image size and overlaid on the
            second panel with a colour range of [0, 1].
        nullmutation_heatmap: 2-D array, resized and overlaid on the third panel.
        mask: Optional 2-D array; the outline of its positive region is drawn in black on
            the third panel's image.
    """
    height, width = img.shape[0], img.shape[1]

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(img)
    ax[0].set_title("Original image")

    # Overlay the heatmap on the image, rescaled to the image size.
    # cv2.resize takes (width, height).
    ax[1].imshow(img)
    overexpression_heatmap = cv2.resize(overexpression_heatmap, (width, height))
    ax[1].imshow(overexpression_heatmap, alpha=0.3, cmap='jet', vmin=0, vmax=1)
    ax[1].set_title("Overexpression heatmap")

    # If mask is provided, draw its contours on the null mutation panel
    null_img = img
    if mask is not None:
        # Binarise with > 0, as the CLAM visualisation cell does: casting a float mask
        # with values in [0, 1) straight to uint8 would turn every pixel into 0.
        binary_mask = (np.asarray(mask) > 0).astype(np.uint8)
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # drawContours paints in place, so work on a copy instead of the caller's array.
        null_img = np.array(img, order='C')
        cv2.drawContours(null_img, contours, -1, (0,0,0), 5)

    ax[2].imshow(null_img)
    nullmutation_heatmap = cv2.resize(nullmutation_heatmap, (width, height))
    ax[2].imshow(nullmutation_heatmap, alpha=0.3, cmap='jet', vmin=0, vmax=1)
    ax[2].set_title("Null mutation heatmap")

    plt.show()

## Receptive Field

In [None]:
# Visualize how large ResNet18s receptive field is for different image sizes
from resnet import ResNetModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt

# Create a dummy input image
# All zeros except two single pixels, so the response around each pixel in the output
# feature map shows how far one input pixel spreads through the network.
img_size = 1024
img = torch.zeros(1, 3, img_size, img_size)
img[:, :, img_size//2, img_size//2] = 1
img[:, :, img_size//4, img_size//4] = 1
# img[:, :, img_size//4:3*img_size//4, img_size//4:3*img_size//4] = 1

plt.imshow(img.squeeze().numpy().transpose(1, 2, 0))
plt.show()

# Load the model
# No weights are passed, so this is an untrained, randomly initialised network and the
# output changes between runs unless a seed is set.
# NOTE(review): this rebinds the notebook-global `model`, which other cells also use.
model = models.resnet18()
# Set all parameters to 1
# for param in model.parameters():
#     param.data.fill_(1)
# Cut off before the pooling layer
# model = nn.Sequential(*list(list(model.children())[0].children())[:-2])
# Dropping the last two children keeps only the convolutional part of the network.
model = nn.Sequential(*list(model.children())[:-2])
# display(model)
model.eval()

# Get the output of the model
with torch.no_grad():
    output = model(img)

# Plot the output
# Channel-last layout, then the maximum over channels gives one 2-D response map.
output = output.squeeze().numpy().transpose(1, 2, 0).max(axis=2)
# display(output)
plt.imshow(output)
plt.show()