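# WSOL (Weakly Supervised Object Localization) – Class Activation Maps (CAM)

Visualise Grad-CAM++ and Smooth Grad-CAM++ maps of a trained classifier on example AD, MCI and CN scans, and inspect the intensity distribution of the averaged CAM.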
---

In [1]:
# Re-import edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2
import sys
# Make the parent directory importable so the `src.*` project modules resolve.
if '..' not in sys.path: sys.path.append("..")

In [2]:
# Clear user variables so later cells cannot rely on leftover kernel state
# (the sys.path change and autoreload from the previous cell still persist).
%reset -f
%matplotlib inline

# Imports

In [3]:
from src.display.plot import imshow
from src.files.preprocess import to_grid
from src.display.cmap import parula_map
from src.utils.decorator import HiddenPrints
from concurrent.futures import ThreadPoolExecutor
from src.files.preprocess import image2axial
from src.cam import CAM
from src.agent import Agent,ThreadSafeReloadedModel
from src.files import preprocess


import seaborn as sns
import sys,os
import warnings
import nibabel as nib
import numpy as np
import itertools
import torchcam
import torch
import matplotlib.pyplot as plt

# Variables

In [4]:
# Example scans, one per diagnostic group; the parent folder name (AD / MCI / CN) is the group.
# Set WSOL_DATA_DIR to point at a different copy of the dataset without editing this cell.
DATA_DIR = os.environ.get("WSOL_DATA_DIR", "../data/SPM_categorised/AIH")
FILE_NAMES = [
    f"{DATA_DIR}/AD/AD_ADNI_2489.nii",
    f"{DATA_DIR}/MCI/MCI_ADNI_1389.nii",
    f"{DATA_DIR}/CN/CN_ADNI_0442.nii",
]

# Trained checkpoint to explain. Layout: .../<architecture>/<run>/version_N/checkpoints/<file>.ckpt
# Set WSOL_MODEL_CHECKPOINT to override this machine-specific absolute path.
MODEL_CHECKPOINT = os.environ.get(
    "WSOL_MODEL_CHECKPOINT",
    "/var/metrics/codetests/logs/tb/final/resnet18_brew2/20210524192415/version_0/checkpoints/epoch=52-step=2808.ckpt",
)

NUM_WORKERS = 4                # NOTE(review): not referenced by the cells shown — confirm before removing
NUM_TRIALS = 10                # CAM trials averaged per image
OBSERVED_CLASSES = [0, 1, 2]   # class indices whose CAMs are extracted and plotted

TARGET_LAYER = "model.layer4"  # NOTE(review): not referenced by the cells shown — confirm before removing

# Shared Grad-CAM++ extractor; the evaluation cells call it to obtain a fresh CAM extractor per trial.
THREADSAFE_CAM_EXTRACTOR = ThreadSafeReloadedModel(MODEL_CHECKPOINT, cam_type=torchcam.cams.GradCAMpp)
SMOOTHGRADCAMPP_KWARGS = {"std": 1}  # extra keyword arguments for torchcam's SmoothGradCAMpp

Global seed set to 420


# Evaluation

## Average CAM per Class

In [5]:
def trial_classes(extractor, images, classes=(0, 1, 2), num_trials=10, grid_kwgs=None):
    """Average per-class CAM grids over repeated trials and over all ``images``.

    For every image, ``num_trials`` fresh CAM extractors are obtained by calling
    ``extractor()``. Each trial scores the image once and builds one gridded
    activation map per class index in ``classes``. The per-trial stacks from
    *all* images are then averaged with the extractor's ``average_image``.

    Args:
        extractor: zero-argument callable returning a CAM extractor that exposes
            ``class_score``, ``activations`` and ``average_image``.
        images: sequence of preprocessed images.
        classes: class indices to extract activations for (one map each).
        num_trials: number of trials per image.
        grid_kwgs: keyword arguments forwarded to ``preprocess.to_grid``.

    Returns:
        The result of ``average_image`` over the list of per-trial class stacks.

    Raises:
        ValueError: if there is nothing to average (no images or ``num_trials`` < 1).
    """
    grid_kwgs = {} if grid_kwgs is None else grid_kwgs
    total = num_trials * len(images)
    if total < 1:
        raise ValueError("trial_classes needs at least one image and one trial")

    # Accumulate across ALL images; resetting per image would average only the last one.
    masks = []
    iteration = 0
    for image in images:
        for _ in range(num_trials):
            iteration += 1
            print(f"Running trial: {iteration:2.0f}/{total:2.0f}", end="\r")

            # A fresh extractor per trial, as requested from the factory.
            cam_extractor = extractor()
            class_scores, _ = cam_extractor.class_score(image)

            masks.append(torch.stack([
                preprocess.to_grid(
                    preprocess.preprocess_image(cam_extractor.activations(idx, class_scores)),
                    **grid_kwgs,
                )
                for idx in classes
            ]))
    return cam_extractor.average_image(masks)

## Grad-CAM++ vs Smooth Grad-CAM++

Per-class Grad-CAM++ maps for the AD scan, averaged over repeated trials.

In [8]:
# Preprocess the AD scan once; it is both the plot background and the CAM input.
ad_image = CAM.preprocess(FILE_NAMES[0])
fig = CAM.plot(
    images=CAM.repeat_stack(ad_image, repeat=len(OBSERVED_CLASSES)),
    masks=trial_classes(
        ThreadSafeReloadedModel(MODEL_CHECKPOINT, cam_type=torchcam.cams.GradCAMpp),
        [ad_image],
        classes=OBSERVED_CLASSES,
        num_trials=NUM_TRIALS,
        grid_kwgs={'max_num_slices': 16, 'nrow': 4},
    ),
    labels=OBSERVED_CLASSES,
    class_label="AD",
    # Label with the architecture folder of the checkpoint actually loaded
    # (.../<architecture>/<run>/version_N/checkpoints/<file>.ckpt).
    architecture=MODEL_CHECKPOINT.split("/")[-5],
);

FileNotFoundError: No such file or no access: '../data/SPM_categorised/AIH/AD/AD_ADNI_2489.nii'

The same AD scan with Smooth Grad-CAM++ (noise `std = 1`, see `SMOOTHGRADCAMPP_KWARGS`).

In [None]:
# Preprocess the AD scan once; it is both the plot background and the CAM input.
ad_image = CAM.preprocess(FILE_NAMES[0])
fig = CAM.plot(
    images=CAM.repeat_stack(ad_image, repeat=len(OBSERVED_CLASSES)),
    masks=trial_classes(
        ThreadSafeReloadedModel(
            MODEL_CHECKPOINT,
            cam_type=torchcam.cams.SmoothGradCAMpp,
            cam_kwargs=SMOOTHGRADCAMPP_KWARGS,
        ),
        [ad_image],
        classes=OBSERVED_CLASSES,
        num_trials=NUM_TRIALS,
        grid_kwgs={'max_num_slices': 16, 'nrow': 4},
    ),
    labels=OBSERVED_CLASSES,
    class_label="AD",
    # Label with the architecture folder of the checkpoint actually loaded
    # (.../<architecture>/<run>/version_N/checkpoints/<file>.ckpt).
    architecture=MODEL_CHECKPOINT.split("/")[-5],
);

## With AD, MCI and CN scans

In [None]:
# One figure per example scan. Each scan's averaged CAMs are drawn over that SAME scan
# (the previous copy-pasted version drew the CN masks over the MCI image).
for file_name in FILE_NAMES:
    # The parent folder name (AD / MCI / CN) is the diagnostic group of the scan.
    group_label = os.path.basename(os.path.dirname(file_name))
    scan = CAM.preprocess(file_name)
    fig = CAM.plot(
        images=CAM.repeat_stack(scan, repeat=len(OBSERVED_CLASSES)),
        masks=trial_classes(
            THREADSAFE_CAM_EXTRACTOR,
            [scan],
            classes=OBSERVED_CLASSES,
            num_trials=NUM_TRIALS,
            grid_kwgs={'max_num_slices': 16, 'nrow': 4},
        ),
        labels=OBSERVED_CLASSES,
        class_label=group_label,
        # Architecture folder of the checkpoint actually loaded
        # (.../<architecture>/<run>/version_N/checkpoints/<file>.ckpt).
        architecture=MODEL_CHECKPOINT.split("/")[-5],
    )

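## Evaluate the CAM intensity distribution

Average the CAMs of the validation images (for the class index returned by `class_score` in each trial) and look at the density of the resulting intensities.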

In [None]:
def trial_average(extractor, images: list = None, num_trials: int = 10):
    """Average CAMs over repeated trials for every image, then preprocess the result.

    Each trial obtains a fresh CAM extractor by calling ``extractor()``, scores
    the image, and keeps the activations for the class index that
    ``class_score`` returned. All trials of all images are averaged with
    ``CAM.average_image`` and passed through ``preprocess.preprocess_image``.

    Args:
        extractor: zero-argument callable returning a CAM extractor that exposes
            ``class_score`` and ``activations``.
        images: sequence of preprocessed images (``None`` is treated as empty).
        num_trials: number of trials per image.

    Returns:
        The preprocessed average activation map.

    Raises:
        ValueError: if there is nothing to average (no images or ``num_trials`` < 1).
    """
    images = [] if images is None else images
    total = num_trials * len(images)
    if total < 1:
        raise ValueError("trial_average needs at least one image and one trial")

    masks = []
    iteration = 0
    for image in images:
        for _ in range(num_trials):
            iteration += 1
            print(f"Running trial: {iteration:2.0f}/{total:2.0f}", end="\r")

            cam_extractor = extractor()
            class_scores, class_idx = cam_extractor.class_score(image)
            masks.append(cam_extractor.activations(class_idx, class_scores))

    return preprocess.preprocess_image(CAM.average_image(masks))

In [None]:
# Validation items of the observed classes; element [1] of each item is used as the image
# (NOTE(review): presumably (label, image, ...) tuples — confirm against get_validation_images).
validation_items = THREADSAFE_CAM_EXTRACTOR.get_validation_images(observe_classes=OBSERVED_CLASSES)
IMAGES = [item[1].squeeze(0) for item in validation_items]

# One averaged CAM over every validation image and trial.
averaged_mask = trial_average(THREADSAFE_CAM_EXTRACTOR, IMAGES)

# Show it as a single 4-column grid of up to 16 slices.
averaged_grid = preprocess.to_grid(averaged_mask, max_num_slices=16, nrow=4)
fig = CAM.plot(masks=torch.stack([averaged_grid]));

In [None]:
# Density of every intensity in the averaged CAM. `ravel` flattens any rank;
# the former double np.hstack only fully flattened 3-D input.
fig, ax = plt.subplots()
sns.kdeplot(np.asarray(averaged_mask).ravel(), ax=ax)
ax.set(xlabel="Intensity", ylabel="Density", title="Density vs intensity plot on CAM");

In [None]:
# Show density of each image: one KDE curve per entry along the first axis of averaged_mask,
# all drawn on the same axes.
fig, ax = plt.subplots()
for mask_slice in averaged_mask:
    sns.kdeplot(np.asarray(mask_slice).ravel(), ax=ax)
ax.set(xlabel="Intensity", ylabel="Density", title="Density vs intensity plot on CAM");