Necessary imports for the whole notebook

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn

import matplotlib.pyplot as plt
import cv2

import numpy as np
import pandas as pd

from tqdm.notebook import tqdm, trange

from numba import jit

Load the data generators

In [None]:
import json
from rationai.datagens.generators import BaseGeneratorPytorch
from rationai.utils.config import build_from_config, parse_configs_recursively


def _build_generators(config_name, cfg_store):
    """Build the datagen registered under `config_name` and the generators it produces.

    Returns the datagen object together with the result of its `build_from_template()`.
    """
    datagen = build_from_config(parse_configs_recursively(config_name, cfg_store=cfg_store))
    return datagen, datagen.build_from_template()


# the JSON file carries the named configurations of both data generators
with open('prunning_teacher_student/prunning_ts_config.json') as ts_config_file:
    ts_config = json.load(ts_config_file)

named_configs = ts_config["named_configs"]

# generators built from the "datagen_bool" configuration
datagen_bool, generators_dict_bool = _build_generators("datagen_bool", named_configs)
train_generator_bool = generators_dict_bool['train_gen']
valid_generator_bool = generators_dict_bool['valid_gen']

# generators built from the "datagen_ts" configuration
datagen_ts, generators_dict_ts = _build_generators("datagen_ts", named_configs)
train_generator_ts = generators_dict_ts['train_gen']
valid_generator_ts = generators_dict_ts['valid_gen']


In [None]:
# every generator is switched to the same batch size
batch_size = 1
for generator_ in (train_generator_bool, valid_generator_bool, train_generator_ts, valid_generator_ts):
    generator_.set_batch_size(batch_size)

## Modules setup

Define a GradCam module. The module should be usable with an arbitrary network topology that includes convolutional layers.
The idea of Grad-CAM is to look into the model at a particular layer and to weight that layer's activations with the gradients flowing back to it from the following layers.


In [None]:
import torch.nn as nn
from collections import OrderedDict


class GradCam(nn.Module):
    """
    This Module performs Grad-CAM computation for a model consisting of two Modules applied in order.
    To use it on an arbitrary model, the model has to be manually dissected first.
    The dissection of a convolutional network is usually performed in such a way that the convolutional
    layers are represented by the first module and the dense layers are left in the second Module.
    The dissection can be arbitrary as long as the output of the first module has a shape of [N, C, W, H],
    meaning it has to be a batch of size N, each element containing C feature maps of shape W x H (2D matrices)
    so that the spatial information can be transformed into a heatmap overlay for the input picture.

    Typical use: run ``forward``, call ``backward()`` on the output, then read
    ``get_activations()`` and ``get_activations_gradient()``.
    """
    def __init__(self, model_before_cam: nn.Module, model_after_cam: nn.Module, *args, **kwargs):
        """Initialize the GradCam for a dissected model, split into two parts.
        The parts are expected to split at the layer of interest and together have to make the full model.

        Args:
            model_before_cam (nn.Module): Part of the inspected model containing up to the convolutional layer we want to inspect
            model_after_cam (nn.Module): Remaining layers of the inspected model
        """
        super().__init__(*args, **kwargs)
        self.model_before_cam = model_before_cam
        self.model_after_cam = model_after_cam

        # filled by the backward hook / the forward pass respectively
        self.gradients = None
        self.activations = None

    def activations_hook(self, grad):
        """Backward hook: store the gradient that arrives at the inspected activations."""
        self.gradients = grad

    def forward(self, x):
        """Run both model parts on `x`, remembering the activations in between.

        Returns:
            The output of ``model_after_cam`` applied to the stored activations.
        """
        # remember the activation
        self.activations = self.model_before_cam(x)

        # a gradient captured during an earlier pass does not belong to these activations
        self.gradients = None

        # a hook can only be registered on a tensor that takes part in autograd;
        # under torch.no_grad() (plain inference) there is nothing to hook into
        if self.activations.requires_grad:
            self.activations.register_hook(self.activations_hook)

        # apply the remaining layers
        return self.model_after_cam(self.activations)

    def get_activations_gradient(self):
        """Return the gradient stored by the last backward pass (None if there was none)."""
        return self.gradients

    def get_activations(self, x=None):
        """Return the activations stored by the last forward pass.

        `x` is not used; it is accepted only to stay compatible with existing callers.
        """
        return self.activations


To make it easier, the GradCam module takes two modules on initialisation, which represent the two parts of the already dissected model. The model dissection has to be performed manually,
though I have provided a helper for dissecting any torch.nn.Sequential model at an arbitrary layer index.


In [None]:
from prunning_gradcam.models import SequentialVGG16
from prunning_gradcam.grad_cam_tools import _load_params, dissect_sequential_model

import logging as log

# file holding the weights that are loaded into the network
checkpoint_path = 'transplanted-model.chkpt'
# layer index at which the sequential model is cut in two
# (presumably right after the last convolutional block of VGG16 — confirm)
cam_split_index = 30

# build the network and fill it with the stored weights
vgg = SequentialVGG16()
_load_params(vgg, source_state_dict_path=checkpoint_path)

# cut the model at the inspected layer and wrap both halves in the Grad-CAM module
first_part, remaining_part = dissect_sequential_model(vgg, cam_split_index)
grad_cam = GradCam(first_part, remaining_part)


## Visualisation tools

For the purpose of visualizing the GradCAM outputs we need a way of plotting the heatmaps, overlaying them on the original images, or otherwise visualizing the activation regions.

For that purpose I have coded a function that can overlay several heatmaps on a single image in one step, next to each other in a grid.
The last column in the grid shows all the heatmaps summed together and clipped.

In [None]:
from typing import List, Tuple

from skimage.morphology import dilation, disk



def scaleshift_to_unit_range(img: np.ndarray):
    """Shift and scale an array so that its values span the range [0, 1].

    The minimum is moved to 0 and, unless the array is constant, the maximum is scaled to 1.
    A constant array comes back as all zeros. The input array itself is never modified.

    Args:
        img (np.ndarray): Array of any shape. Floating-point arrays keep their dtype;
            integer and boolean arrays are converted to float64 first, because the
            rescaled values are fractional.

    Returns:
        np.ndarray: A new array of the same shape with values in [0, 1].

    Raises:
        ValueError: If `img` is empty (there is no minimum to subtract).
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        # an in-place multiply by a float factor cannot be stored in an integer array
        img = img.astype(np.float64)
    img = img - img.min()
    img_max = img.max()
    if img_max != 0:
        img *= 1/img_max
    return img


def resize_and_color(bitmap: np.ndarray, dims:Tuple[int]) -> np.ndarray:
    """Stretch a single-channel bitmap to `dims` and colour it through the precomputed LUT.

    Args:
        bitmap (np.ndarray): Single-channel bitmap. Values are multiplied by 255 and cast to
            uint8, so they are presumably expected in [0, 1] — confirm against callers.
        dims (Tuple[int]): Target size, handed to ``cv2.resize`` unchanged
            (OpenCV reads this argument as (width, height)).

    Returns:
        np.ndarray: 3-channel uint8 image produced by the LUT lookup.
    """
    stretched = cv2.resize(bitmap, dims)
    gray_u8 = np.uint8(255 * stretched)
    # the LUT has three channels, so the gray image is replicated to three channels first
    gray_3ch = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
    return cv2.LUT(gray_3ch, _COLORMAP_JETFROMBLACK)


def superimpose(image: np.ndarray, overlay: np.ndarray, strategy:str='ceil'):
    """Combine a base image with a colour overlay of the same size.

    Strategies:
        'ceil':               image + 0.375 * overlay, capped at 255
        'sub':                image - 0.5 * overlay, floored at 0
        'linear_combination': equal-weight mix of image and overlay
        'outline':            the image with a black contour drawn around the non-zero
                              area of the overlay (the overlay colours are not used)

    Args:
        image (np.ndarray): Base image, values on the 0..255 scale.
        overlay (np.ndarray): Overlay with the channel axis last.
        strategy (str): One of the names listed above.

    Returns:
        np.ndarray: The combined image cast to uint8.

    Raises:
        NotImplementedError: If `strategy` is not one of the known names.
    """
    if strategy == 'outline':
        # pixels where any overlay channel is non-zero
        covered = overlay.sum(axis=2) > 0
        # grow the covered area by a disk of radius 3; the ring that was added is the contour
        grown = dilation(covered, disk(3))
        blended = image.copy()
        blended[grown & ~covered] = [0, 0, 0]
    elif strategy == 'linear_combination':
        blended = image*0.5 + overlay*0.5
    elif strategy == 'sub':
        blended = np.maximum(image - overlay*0.5, 0)
    elif strategy == 'ceil':
        blended = np.minimum(image + overlay*0.375, 255)
    else:
        raise NotImplementedError(f'There is no strategy with name {strategy}!')

    return blended.astype(np.uint8)


def _COLORMAP_JETFROMBLACK():
    """Build a 256-entry colour lookup table: OpenCV's JET map whose first 64 entries fade in from black.

    Returns:
        np.ndarray: uint8 table of shape (256, 1, 3). The channel order is left exactly as
        ``cv2.applyColorMap`` returns it (no BGR -> RGB conversion is applied).
    """
    # gray ramp 0..255 replicated into three channels, shape (256, 1, 3)
    ramp = np.arange(256).reshape(256, 1, 1)
    lut = np.repeat(ramp, 3, axis=2).astype(np.uint8)
    lut = cv2.applyColorMap(lut, cv2.COLORMAP_JET)

    # linear fade 0/64, 1/64, ..., 63/64 applied to all channels of the first 64 entries
    fade = np.arange(64) * (1/64)
    faded = lut[:64, 0].astype(np.float64) * fade[:, np.newaxis]
    lut[:64, 0] = faded.astype(np.uint8)
    return lut


def plot_bitmap_overlays(bitmaps: List[np.ndarray], base_image: np.ndarray, save_file:str=None, strategy='outline', heatmaps_titles=None):
    """
    Creates a two-row grid plot from a list of bitmaps and a base image.

    The bitmaps are expected to be 2D arrays and the image a channel-last array of shape [H, W, 3]:
    its first two axes are used as the spatial size and it is handed to ``imshow`` as is.
    The bitmaps do not have to match the image size; each one is stretched to it.

    Layout: the leftmost column contains the base image (top: negative values clipped, bottom: rescaled
    to [0..1]), the rightmost column contains all bitmaps summed and clipped to 1, and the columns in
    between show each bitmap separately. The top row shows the coloured overlay alone, the bottom row
    shows it superimposed on the rescaled base image.

    Args:
        bitmaps (List[np.ndarray]): The bitmaps that are going to be transformed into overlays
        base_image (np.ndarray): The underlying image
        save_file (str, optional): If there is a string present, the figure is saved into a file. Defaults to None.
        strategy (str, optional): Name of the `superimpose` strategy used on the bottom row. Defaults to 'outline'.
        heatmaps_titles (optional): Titles for the separate overlays on the top row, one per bitmap. Defaults to None.
    """
    n_columns = len(bitmaps) + 2
    fig, ax = plt.subplots(nrows=2, ncols=n_columns, sharex=True, sharey=True, figsize=(5*n_columns, 5*2))

    # mark the row indices for flexibility
    overlay_row = 0
    superimposed_row = 1

    # in the first column, just print the original images. The upper is clipped, the lower is rescaled, both to [0..1]
    # (only the lower bound is clipped here; values above the displayable range are left to imshow)
    ax[overlay_row][0].imshow(np.maximum(base_image, 0))
    ax[overlay_row][0].title.set_text('Input image clipped to [0..1]')
    base_image = scaleshift_to_unit_range(base_image)
    ax[superimposed_row][0].imshow(base_image)
    ax[superimposed_row][0].title.set_text('Input image rescaled to [0..1]')

    # kept as float; superimpose casts its result to uint8
    base_image_255 = base_image*255

    # cv2.resize takes the target size as (width, height), while ndarray.shape is (height, width, ...);
    # passing shape[:2] directly would transpose the overlay of every non-square image
    image_height, image_width = base_image.shape[:2]
    overlay_dims = (image_width, image_height)

    # for each bitmap get the overlay and show it
    for column, bitmap in enumerate(bitmaps, start=1):
        overlay = resize_and_color(bitmap, overlay_dims)
        ax[overlay_row][column].imshow(overlay)  # show the separate overlay on the first row
        if heatmaps_titles is not None:
            ax[overlay_row][column].title.set_text(heatmaps_titles[column-1])

        # impose the overlay on top of the base image
        ax[superimposed_row][column].imshow(superimpose(base_image_255, overlay, strategy))  # show the superimposed image on the second row

    # with no bitmaps there is nothing to stack, the last column stays empty
    if len(bitmaps) > 0:
        # get the total overlay from all the bitmaps: summed and clipped to 1
        overlay = resize_and_color(np.minimum(sum(bitmaps), 1.0), overlay_dims)
        ax[overlay_row][-1].imshow(overlay)  # show the separate overlay on the first row
        ax[overlay_row][-1].title.set_text('All overlays stacked and clipped to [0..1]')

        # impose the overlay on top of the base image
        ax[superimposed_row][-1].imshow(superimpose(base_image_255, overlay, strategy))  # show the superimposed image on the second row

    if save_file is not None:
        fig.savefig(save_file, bbox_inches='tight')
    else:
        fig.show()
    # release the figure so that repeated calls in a loop do not pile up open figures
    plt.close(fig)

# precomputed stuff like color LUTs and so on
# NOTE: this rebinds the name from the builder function to the table it returns, so from here on
# _COLORMAP_JETFROMBLACK is the lookup table itself (this is the object resize_and_color hands to cv2.LUT).
# Running this line a second time without re-running the `def` above fails, because the name
# no longer refers to a function.
_COLORMAP_JETFROMBLACK = _COLORMAP_JETFROMBLACK()


In [None]:
from typing import Union


# average value counter
class AVG:
    """Running average counter.

    Works for scalars and for torch.Tensors (anything supporting `+` and `/` should do).
    Values are accumulated out-of-place, so objects passed to `record` are never modified.

    Attributes are documented in comments next to their annotations.
    """
    # running sum of the recorded values (None until something is recorded)
    sum_: Union[float, torch.Tensor]
    # running sum of the recorded weights (None until something is recorded)
    count: Union[float, torch.Tensor]
    # cached result of sum_ / count; reset to None by every `record`
    value: Union[float, torch.Tensor]
    def __init__(self, sum_=None, count=None):
        """Start empty, or continue from an existing `sum_` and/or `count`."""
        self.sum_ = sum_
        self.count = count
        self.value = None

    def record(self, value, weight=1):
        """Add `value` to the running sum and `weight` to the running count.

        `value` is added as given; it is not multiplied by `weight`.
        """
        # `+` instead of `+=`: an in-place add would write into the first recorded
        # tensor/array, which still belongs to the caller
        self.sum_ = value if self.sum_ is None else self.sum_ + value
        self.count = weight if self.count is None else self.count + weight
        # the cached average is stale now
        self.value = None

    def __call__(self):
        """Return the current average (sum_ / count), computing it once per recorded state.

        Calling this before anything has been recorded fails with a TypeError,
        because both operands are still None.
        """
        if self.value is None:
            self.value = self.sum_ / self.count
        return self.value
        
# set the evaluation mode
grad_cam.eval()
grad_cam.cuda(1)

epochs_ = 1
max_examples = 3
# index of the first generator item that gets visualised
# NOTE(review): the loop bound below does not account for this offset — confirm the generator is long enough
first_example_idx = 200
# number of highest-ranked feature maps plotted for every example
top_k_feature_maps = 20

accuracy_avg = AVG()
gradients_avg = AVG()

# pooled per-channel gradients of every processed example; the box-plot cell further down stacks this list
# NOTE(review): the pooled values are per-channel maxima, while that plot is labelled "Average" — confirm which is intended
grad_list_fro_boxplot = []


generator = train_generator_bool
for epoch_ in range(epochs_):
    print('Generator has', len(generator), 'examples in this epoch.')
    for batch_idx in tqdm(range(min(len(generator), max_examples))):

        input_, target = generator[batch_idx + first_example_idx]
        input_ = input_.type(torch.float)

        # transform to a 3 channel image shape expected by pyplot
        img = input_.squeeze().permute([1,2,0]).numpy()

        input_ = input_.cuda(1)

        # parameter gradients of the previous example are not needed; without this
        # every backward() call would keep accumulating into them
        grad_cam.zero_grad()
        pred = grad_cam(input_)

        # check the model decision
        is_cancer = pred > 0.5
        is_cancer = is_cancer.cpu()

        if is_cancer == target:
            accuracy_avg.record(1)
            print("Correctly predicted:", is_cancer)
        else:
            accuracy_avg.record(0)
            print("Wrongly predicted:", is_cancer)


        # backpropagate from the output; the hook inside grad_cam stores the gradient of the inspected activations
        pred.backward()


        with torch.no_grad():
            # pull the gradients out of the model
            gradients = grad_cam.get_activations_gradient().detach().cpu()

            # pool the gradients across the feature maps and batch (maximum per channel)
            #pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
            pooled_gradients = torch.amax(gradients, dim=(0, 2, 3))
            grad_list_fro_boxplot.append(pooled_gradients)


            # get the sorted indices of the gradients flowing back
            sorted_gradient_indices = torch.argsort(pooled_gradients, dim=-1, descending=True).numpy()


            # get the activations of the last convolutional layer
            activations = grad_cam.get_activations(input_).detach().cpu().squeeze(0)

            # weight the channels by corresponding gradients through broadcasting multiplication
            weighted_activations = (activations * pooled_gradients.unsqueeze(1).unsqueeze(2)).numpy()

            # select the weighted activations of the channels with the largest pooled gradient
            selection = sorted_gradient_indices[:top_k_feature_maps]
            heatmaps = weighted_activations[selection, :, :]

            # relu on top of the heatmap, expression (2) in https://arxiv.org/pdf/1610.02391.pdf
            heatmaps = np.maximum(heatmaps, 0)

            # normalize every heatmap by its own maximum; all-zero heatmaps are left untouched
            heatmap_maxima = heatmaps.max(axis=(1, 2), keepdims=True)
            # reshape(-1) keeps the mask 1D even when a single heatmap is selected (squeeze() would make it 0-d)
            nonzero_idx = (heatmap_maxima != 0).reshape(-1)
            heatmaps[nonzero_idx] /= heatmap_maxima[nonzero_idx]

            # plain floats, so the titles read 'Scaled by 0.5' rather than 'Scaled by [[0.5]]'
            heatmaps_titles = [
                f'Scaled by {1/float(factor)}' if factor > 0 else 'Empty'
                for factor in heatmap_maxima.reshape(-1)
            ]

            with plt.ioff():
                plot_bitmap_overlays(heatmaps, img, f'grad_cam_fmwise_linc_{batch_idx}.jpg', strategy='linear_combination', heatmaps_titles=heatmaps_titles)
                plot_bitmap_overlays(heatmaps, img, f'grad_cam_fmwise_outl_{batch_idx}.jpg', strategy='outline', heatmaps_titles=heatmaps_titles)
                plot_bitmap_overlays(heatmaps, img, f'grad_cam_fmwise_ceil_{batch_idx}.jpg', strategy='ceil', heatmaps_titles=heatmaps_titles)
                # plot_bitmap_overlays(heatmaps, img, f'grad_cam_fmwise_subt_{batch_idx}.png', strategy='sub')


    print(accuracy_avg())
    

# send the individual feature maps with the highest rank next to each other
# created an overlay of the individual layers as a check

In [None]:
import seaborn as sns

# expects grad_list_fro_boxplot to be a list of equally shaped 1D tensors
# (presumably one vector of per-channel gradients per example — confirm);
# after stacking, every list item is one row and every position in the vector one column
grads_for_boxplots = torch.stack(grad_list_fro_boxplot, 0)
df = pd.DataFrame(grads_for_boxplots.numpy())
# column labels ordered by the column median, ascending
sorted_index = df.median().sort_values().index

# very wide figure: one box is drawn for every column
f = plt.figure()
f.set_figwidth(128)
f.set_figheight(10)

# draw the boxes in median order; the order was computed above but never applied,
# although the figure is saved under a "sorted" file name
sns.boxplot(data=df, order=list(sorted_index))
plt.ylabel("Average gradient size", size=18)
axes = plt.gca()
axes.yaxis.grid()
plt.savefig("Gradients_per_example_sorted_30.jpg")

In [None]:
import numpy as np

# scratch check: which dtype results from subtracting a float array from a uint8 array
as_uint8 = np.uint8(np.asarray([50, -25, -7, 50, 50]))
float_offsets = np.asarray([0.6863, 0.2846, 1.987,0,0])
nums = as_uint8 - float_offsets
nums.dtype

In [None]:
# scratch check of the pooling used above: one batch item with two 3x3 feature maps
feature_maps = [
    [[2, 1, 2],
     [3, 1, 2],
     [4, 1, 2]],
    [[5, 1, 2],
     [6, 1, 2],
     [7, 1, 2]],
]
ten = torch.Tensor([feature_maps])
print(ten.size())
# maximum over the batch and both spatial axes leaves one value per feature map
# (the mean over the same axes was the alternative tried earlier)
pls = ten.amax(dim=(0, 2, 3))
pls

In [None]:
import pandas as pd

def load_h5_store_pandas(file_path: str):
    """Open the HDF5 file at `file_path` as a read-only pandas HDFStore.

    The store is returned open; closing it is left to the caller.
    """
    return pd.HDFStore(file_path, mode="r")

  
# Open two HDF5 stores read-only: `hdfs` from the predictions file, `hdfs2` from the dataset file.
# NOTE(review): both paths are absolute and machine-specific, and neither store is closed
# in the cells visible here.
hdfs = load_h5_store_pandas('/mnt/data/home/bajger/NN_pruning/histopat/experiment_output/transfer_learning/predictions.h5')
hdfs2 = load_h5_store_pandas('/mnt/data/home/bajger/NN_pruning/histopat/datasets/hdfs_output/hdfs_output.h5')
# not read by any cell visible here
i = 0


In [None]:
# the keys are taken from the predictions store (hdfs), the tables are read from hdfs2
# NOTE(review): presumably both stores share their table names — confirm the cross-store lookup is intended
for table_name in hdfs.keys():
    table = hdfs2[table_name]
    print(table)

