Necessary imports for the whole notebook

In [None]:
# Optional dependency installation: uncomment the lines below to install the
# packages this notebook imports into the kernel's environment.
#%pip install torch 
#%pip install numpy pandas
#%pip install matplotlib tqdm
#%pip install opencv-python scikit-image
# List the current working directory (IPython magic); later cells open files
# through paths relative to it.
%ls

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn

import matplotlib
import matplotlib.pyplot as plt
import cv2

import numpy as np
import pandas as pd
from scipy.special import expit

from tqdm.notebook import tqdm, trange

#from numba import jit

## Load the configurable components

First, open the config file, which describes all the necessary configurable components.
It contains mostly data generators, as the other components still need to be rewritten.

In [None]:
import json

# Parse the JSON configuration into `ts_config`; the cells below read its
# "named_configs" entry when building the data generators.
with open('prunning_teacher_student/prunning_ts_config.json') as ts_config_file:
    ts_config = json.load(ts_config_file)

Load the basic data generators.

In [None]:
from rationai.utils.config import build_from_config, parse_configs_recursively


# Build the component registered as "datagen_bool" in the named configs and let it
# create its generators from its template.
datagen_bool = build_from_config(parse_configs_recursively("datagen_bool", cfg_store=ts_config["named_configs"]))
generators_dict_bool = datagen_bool.build_from_template()

# Pick the training and validation generators out of the returned mapping.
train_generator_bool = generators_dict_bool['train_gen']
valid_generator_bool = generators_dict_bool['valid_gen']

# The analogous "datagen_ts" generators are currently disabled.
#datagen_ts = build_from_config(parse_configs_recursively("datagen_ts", cfg_store=ts_config["named_configs"]))
#generators_dict_ts = datagen_ts.build_from_template()

#train_generator_ts = generators_dict_ts['train_gen']
#valid_generator_ts = generators_dict_ts['valid_gen']


In [None]:
# Use a batch size of one for both active generators.
batch_size = 1
train_generator_bool.set_batch_size(batch_size)
valid_generator_bool.set_batch_size(batch_size)
# The "ts" generators are disabled above, so their batch size is not set either.
# train_generator_ts.set_batch_size(batch_size)
# valid_generator_ts.set_batch_size(batch_size)

In [None]:
from rationai.utils.config import build_from_config, parse_configs_recursively

# Build the "datagen_seq" component and take its single generator.
datagen_seq = build_from_config(parse_configs_recursively("datagen_seq", cfg_store=ts_config["named_configs"]))
generators_dict_seq = datagen_seq.build_from_template()
seq_generator = generators_dict_seq['just_gen_seq']
seq_generator.set_batch_size(1)
# Advance the generator's sampler four times, then trigger the epoch-end hook.
# NOTE(review): presumably this moves the generator to the slide of interest —
# confirm against the sampler implementation.
for i in range(4):
    seq_generator.sampler.next()
seq_generator.on_epoch_end()
print(len(seq_generator))  # 874
# indexes around 59 contain the critical segment
special_input_, special_output_ = seq_generator[59]
# Drop singleton dimensions and move axis 0 last to get a channel-last array for plotting.
# NOTE(review): assumes the input is a (1, C, H, W) tensor — confirm.
special_img = special_input_.squeeze().permute([1,2,0]).numpy()


## Modules setup

Let's import a GradCam module, defined in a separate file (for simplicity). The module should be able to be used with an arbitrary network topology that contains convolutional layers.
The idea of Grad-CAM is that it looks into a model at a particular layer and weights the feature-map activations of that layer by the gradients flowing back from the following layers. Averaged together, these weighted activation maps create a visual representation of what the network pays attention to.



In [None]:
from prunning_gradcam.grad_cam import GradCam

To make it easier, the GradCam module is written in such a way that it takes two modules on initialisation, which represent the two parts of the already dissected model. The model dissection has to be performed manually,
though I have provided a method for dissecting any torch.nn.Sequential model on an arbitrary layer index.


In [None]:
from prunning_gradcam.models import SequentialVGG16
from prunning_gradcam.grad_cam_tools import _load_params, dissect_sequential_model

import logging as log

# initialize the VGG model
# initialize the VGG model
vgg = SequentialVGG16()

# load weights from a file
_load_params(vgg, source_state_dict='transplanted-model.chkpt')

# dissect model into two parts at layer index 29
first_part, remaining_part = dissect_sequential_model(vgg, 29)
# split the tail once more at index -1 and keep only the leading piece
# NOTE(review): `remaining_part_without_sigmoid` is not referenced by any other cell
# shown in this notebook; GradCam below receives `remaining_part`. The Math section
# talks about the score before the last activation — confirm which tail is intended.
remaining_part_without_sigmoid, _ = dissect_sequential_model(remaining_part, -1)

# GradCam takes the two halves of the dissected model
grad_cam = GradCam(first_part, remaining_part)


## Math

In the original paper, the authors calculate the average weighted activation in a particular layer. The neuron importance factor $\alpha_k^c$ is obtained as the average gradient of the class score $\xi^c$ (taken before the last activation function is applied) with respect to the activation map $A^k \in \mathbb{R}^{u\times v}$, and each activation map is then weighted by it.

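$$
\alpha_k^c = \frac{1}{Z}\sum_{i=1}^{u}\sum_{j=1}^{v} \frac{\partial \xi^c}{\partial A^k_{ij}}, \qquad Z = uv
$$
$$
L^c_{Grad-CAM}=ReLU\left(\sum_{k}\alpha^c_kA^k\right)
$$

By the chain rule through the pooled output $y^k$ of channel $k$:

$$
\frac{\partial \xi^c}{\partial A^k_{ij}} = \frac{\partial \xi^c}{\partial y^k}  \frac{\partial y^k}{\partial A^k_{ij}}
$$

For global max pooling, $y^k = \max_{ij} A^k_{ij}$, so

$$
\frac{\partial y^k}{\partial A^k_{ij}} = \begin{cases} 1 & \text{if } (i,j) = \arg\max_{i'j'} A^k_{i'j'} \\ 0 & \text{otherwise} \end{cases}
$$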

Since our VGG16 uses global max pooling (GMP), the importance coefficient $\alpha_k^c$ yields the same ranking of the channels, up to the constant factor $Z = uv$, whether the gradients are average-pooled or max-pooled, which allows us to also calculate

$$
\alpha_k^c = \max_{ij}\frac{\partial \xi^c}{\partial A^k_{ij}}
$$

which is a little faster.



## Visualisation tools

For the purpose of visualizing the GradCAM outputs we have to have a way of plotting the heatmaps, overlaying them over the original images or somehow visualizing the activation regions.

I have coded a function that can overlay several heatmaps over a single image in one step, next to each other in a grid. 
The last column in the grid shows all the heatmaps over each other clipped.

The code is split into multiple helper functions for modularity.

In [None]:
from typing import List, Tuple

from skimage.morphology import dilation, disk


def _COLORMAP_JETFROMBLACK():
    """Build the OpenCV JET lookup table with its lowest quarter faded in from black.

    As a side effect the finished table is drawn with matplotlib.

    Returns:
        numpy.ndarray: uint8 lookup table of shape (256, 1, 3), in the channel
        order produced by cv2.applyColorMap (no BGR->RGB conversion is applied).
    """
    # grey ramp 0..255 replicated over three channels, shape (256, 1, 3)
    grey_ramp = np.repeat(np.arange(256, dtype=np.uint8)[:, np.newaxis, np.newaxis], 3, axis=2)
    lut = cv2.applyColorMap(grey_ramp, cv2.COLORMAP_JET)
    #lut = cv2.cvtColor(lut, cv2.COLOR_BGR2RGB)

    # entry i of the first 64 entries is multiplied by i/64, so entry 0 becomes black
    fade_in = (np.arange(64) * (1/64))[:, np.newaxis]
    lut[:64, 0] = (lut[:64, 0].astype(np.float64) * fade_in).astype(np.uint8)

    # preview: widen the one-pixel strip to ten pixels and lay it out horizontally
    preview = np.swapaxes(np.repeat(lut, 10, axis=1), 0, 1)
    plt.imshow(preview)
    plt.title("JETFROMBLACK")
    plt.show()
    return lut

def _COLORMAP_TWILIGHTSYMETRIC():
    """Build a modified OpenCV TWILIGHT lookup table and preview it.

    Two ranges of the stock table are re-weighted: in the first 64 entries
    channels 1 and 2 are ramped up from zero (channel 0 is left untouched), and
    entries 96..159 are darkened towards the centre with a sqrt(|x|) profile.
    As a side effect the finished table is drawn with matplotlib.

    Returns:
        numpy.ndarray: uint8 lookup table of shape (256, 1, 3), in the channel
        order produced by cv2.applyColorMap (no BGR->RGB conversion is applied).
    """
    # grey ramp 0..255 replicated over three channels, shape (256, 1, 3)
    grey_ramp = np.repeat(np.arange(256, dtype=np.uint8)[:, np.newaxis, np.newaxis], 3, axis=2)
    lut = cv2.applyColorMap(grey_ramp, cv2.COLORMAP_TWILIGHT)
    #lut = cv2.cvtColor(lut, cv2.COLOR_BGR2RGB)

    # per-channel weights for the first 64 entries: (1, i/64, i/64)
    ramp = np.arange(64) * (1/64)
    head_weights = np.column_stack((np.ones(64), ramp, ramp))
    lut[:64, 0] = (lut[:64, 0].astype(np.float64) * head_weights).astype(np.uint8)

    # same weight on all channels for entries 96..159; it reaches ~0 in the middle of the range
    centre_dip = np.sqrt(np.absolute(np.linspace(-1, 1, 64)))[:, np.newaxis]
    lut[96:160, 0] = (lut[96:160, 0].astype(np.float64) * centre_dip).astype(np.uint8)

    #print(lut[127:130])
    preview = np.swapaxes(np.repeat(lut, 10, axis=1), 0, 1)
    plt.imshow(preview)
    plt.title("TWILIGHTSYMETRIC")
    plt.show()
    return lut

def _COLORMAP_GRADIENT_BIPOLAR_GREEN_YELLOW():
    """Build a two-sided lookup table that is black in the middle and preview it.

    Entries 0..127 fade channel 2 from 254 down to 0; entries 128..255 raise
    channel 1 from 0 up to 254. As a side effect the table is drawn with matplotlib.

    NOTE(review): [0, 0, i] is red in OpenCV's BGR order but blue in the RGB order
    matplotlib renders — confirm which colour is intended for the lower half.

    Returns:
        numpy.ndarray: uint8 lookup table of shape (256, 1, 3).
    """
    ramp = np.arange(0, 256, 2, dtype=np.uint8)   # 128 steps: 0, 2, ..., 254
    zeros = np.zeros_like(ramp)

    # shape (128, 1, 3) each: only channel 2 / only channel 1 carries the ramp
    lut_red = np.stack([zeros, zeros, ramp], axis=1)[:, np.newaxis, :]
    lut_yellow = np.stack([zeros, ramp, zeros], axis=1)[:, np.newaxis, :]

    # mirrored first half so that both halves meet at zero in the centre
    lut = np.concatenate((lut_red[::-1], lut_yellow), axis=0)

    preview = np.swapaxes(np.repeat(lut, 10, axis=1), 0, 1)
    plt.imshow(preview)
    plt.title("GRADIENT_BIPOLAR_GREEN_RED")
    plt.show()
    return lut



# Precompute the colour lookup tables once; each call also draws a preview of its table.
# NOTE: the first two assignments rebind the builder functions' own names to the
# arrays they return, so after this point those names refer to the tables, not to
# the functions. The third table is stored under a new name (…_GREEN_RED).
_COLORMAP_JETFROMBLACK = _COLORMAP_JETFROMBLACK()
_COLORMAP_TWILIGHTSYMETRIC = _COLORMAP_TWILIGHTSYMETRIC()
_COLORMAP_GRADIENT_BIPOLAR_GREEN_RED = _COLORMAP_GRADIENT_BIPOLAR_GREEN_YELLOW()


def scaleshift(img: np.ndarray, shift:float=None, scale:float=None):
    """Shift and then scale the values of an array, returning a new array.

    With both defaults the result is min-max normalised to the range [0..1].
    The input array is never modified.

    Args:
        img (np.ndarray): Array to transform. Must be non-empty when `shift` is None
            (taking the minimum of an empty array raises ValueError).
        shift (float, optional): Added to every element before scaling.
            Defaults to None, meaning ``-img.min()``.
        scale (float, optional): Factor applied after shifting.
            Defaults to None, meaning ``1 / max`` of the shifted array.

    Returns:
        np.ndarray: New array holding ``(img + shift) * scale``. If the maximum of
        the shifted array is exactly 0, the scaling step is skipped and the shifted
        array is returned as is (this also avoids dividing by zero).
    """
    if shift is None:
        shift = -img.min()

    # out-of-place addition creates the copy, so the caller's array stays untouched
    shifted = img + shift
    peak = shifted.max()

    # nothing to scale by; also guards the default 1/peak against division by zero
    if peak == 0:
        return shifted

    # the default scale can only be derived once the shifted maximum is known
    if scale is None:
        scale = 1 / peak

    # out-of-place multiply so integer inputs are promoted instead of raising
    return shifted * scale


def resize_and_color(bitmap: np.ndarray, dims:Tuple[int], colormap_lut: np.ndarray=_COLORMAP_JETFROMBLACK) -> np.ndarray:
    """Resize a single-channel bitmap and colour it through a lookup table.

    Args:
        bitmap (np.ndarray): 2-D array with values expected in [0..1].
        dims (Tuple[int]): Target size, passed straight to cv2.resize, which
            interprets it as (width, height).
        colormap_lut (np.ndarray): 256-entry, 3-channel lookup table applied with cv2.LUT.

    Returns:
        np.ndarray: uint8 image with three channels and the requested size.
    """
    overlay = cv2.resize(bitmap, dims)
    # clip before the uint8 cast: values outside [0..1] would otherwise wrap around
    overlay = np.clip(255 * overlay, 0, 255).astype(np.uint8)
    # replicate the grey channel three times so every LUT channel has an input
    overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2RGB)
    return cv2.LUT(overlay, colormap_lut)


def superimpose(image: np.ndarray, overlay: np.ndarray, strategy:str='ceil'):
    """Combine a base image with a coloured overlay and return uint8 pixels.

    Both arrays are treated as being on a 0..255 scale and must broadcast
    against each other.

    Strategies:
        'ceil': add 37.5 % of the overlay and cap the result at 255.
        'sub': subtract 50 % of the overlay and floor the result at 0.
        'linear_combination': equal-weight average of image and overlay.
        'outline': paint black the ring obtained by dilating the non-zero overlay
            area with a radius-3 disk; the overlay colours are not used.
        'black_as_alpha': alpha-blend, where alpha is the overlay's per-pixel
            channel maximum divided by 255 (black overlay pixels are transparent).

    Raises:
        NotImplementedError: if `strategy` is none of the names above.
    """
    if strategy == 'ceil':
        return np.minimum(image + overlay*0.375, 255).astype(np.uint8)

    if strategy == 'sub':
        return np.maximum(image - overlay*0.5, 0).astype(np.uint8)

    if strategy == 'linear_combination':
        return (image*0.5 + overlay*0.5).astype(np.uint8)

    if strategy == 'outline':
        covered = overlay.sum(axis=2) > 0
        # pixels reached by the dilation that were not covered before form the ring
        ring = dilation(covered, disk(3)) & ~covered
        painted = image.copy()
        painted[ring] = [0, 0, 0]
        return painted.astype(np.uint8)

    if strategy == 'black_as_alpha':
        opacity = overlay.max(axis=2, keepdims=True) / 255.
        return (image*(1-opacity) + overlay*opacity).astype(np.uint8)

    raise NotImplementedError(f'There is no strategy with name {strategy}!')




def plot_bitmap_overlays(bitmaps: List[np.ndarray], base_image: np.ndarray, save_file:str=None, colormap_lut: np.ndarray=_COLORMAP_GRADIENT_BIPOLAR_GREEN_RED, strategy='black_as_alpha', heatmaps_titles=None):
    """
    Creates a two-row grid plot from a list of bitmaps and a base image.

    Each bitmap is a 2-D array with values in [0..1]; the base image is a
    channel-last array of shape [H, W, C] with values in [0..1]. Bitmaps do not
    need to match the image size, they are stretched to it.

    Layout: the top-left panel shows the base image, the bottom-left panel shows
    the colour scale of the lookup table. Every further column belongs to one
    bitmap: the coloured overlay on top, the overlay superimposed on the base
    image below. (A final "all overlays stacked" column is currently disabled.)

    Args:
        bitmaps (List[np.ndarray]): The bitmaps that are going to be transformed into overlays
        base_image (np.ndarray): The underlying image
        save_file (str, optional): If given, the figure is saved to this path and closed
            instead of being shown. Defaults to None.
        colormap_lut (np.ndarray, optional): Lookup table used to colour the bitmaps.
        strategy (str, optional): Blending strategy forwarded to `superimpose`.
        heatmaps_titles (optional): One title per bitmap, shown above the overlay panels.
    """
    n_columns = len(bitmaps) + 1
    # squeeze=False keeps `ax` two-dimensional even when `bitmaps` is empty
    fig, ax = plt.subplots(nrows=2, ncols=n_columns, sharex=True, sharey=True, figsize=(10*n_columns, 10*2), squeeze=False)

    # mark the row indices for flexibility
    overlay_row = 0
    superimposed_row = 1

    # cv2.resize expects (width, height) while ndarray.shape starts with (height, width);
    # passing shape[:2] directly is only correct for square images
    img_height, img_width = base_image.shape[:2]
    target_dims = (img_width, img_height)

    # top-left panel: the untouched base image
    ax[overlay_row,0].imshow(base_image)
    ax[overlay_row,0].title.set_text('Base image')
    #base_image = scaleshift_to_unit_range(base_image)

    # the LUT is a one-pixel strip; widen it and lay it out horizontally
    colorscale = np.swapaxes(np.repeat(colormap_lut, 10, axis=1), 0, 1)

    # detach the bottom-left axes from the shared tick groups so it can carry its own ticks
    # NOTE(review): this mutates the grouper returned by get_shared_x_axes(); newer
    # Matplotlib releases return an immutable view here — confirm against the installed version.
    axLB = ax[superimposed_row,0]
    axLB.get_shared_x_axes().remove(axLB)
    axLB.get_shared_y_axes().remove(axLB)
    #axLB.clear()

    # Create and assign new ticker
    xticker = matplotlib.axis.Ticker()
    yticker = matplotlib.axis.Ticker()
    axLB.xaxis.major = xticker
    axLB.yaxis.major = yticker

    # The new ticker needs new locator and formatters
    xloc = matplotlib.ticker.AutoLocator()
    xfmt = matplotlib.ticker.StrMethodFormatter('{x:,.2f}')

    yloc = matplotlib.ticker.AutoLocator()
    yfmt = matplotlib.ticker.ScalarFormatter()

    # Assign the locators and formatters to the axes
    axLB.xaxis.set_major_locator(xloc)
    axLB.xaxis.set_major_formatter(xfmt)
    axLB.yaxis.set_major_locator(yloc)
    axLB.yaxis.set_major_formatter(yfmt)

    axLB.imshow(cv2.resize(colorscale, target_dims))
    axLB.title.set_text('Heatmap overlay colorscale')

    # ten evenly spaced ticks across the image width (512 for the images used here),
    # labelled with expit() of evenly spaced values in [-1, 1)
    axLB.set_xticks(np.linspace(0, img_width, 10, endpoint=False, dtype=np.int32))
    axLB.set_xticklabels(expit(np.linspace(-1, 1, 10, endpoint=False)))

    #set yticks empty
    axLB.set_yticks([])

    # superimpose() works on a 0..255 scale
    base_image_255 = (base_image*255)

    # for each bitmap get the overlay and plot it
    for column, bitmap in enumerate(bitmaps, start=1):
        overlay = resize_and_color(bitmap, target_dims, colormap_lut=colormap_lut)
        ax[overlay_row][column].imshow(overlay)  # show the separate overlay on the first row
        if heatmaps_titles is not None:
            ax[overlay_row][column].title.set_text(heatmaps_titles[column-1])

        # impose the overlay on top of the base image
        ax[superimposed_row][column].imshow(superimpose(base_image_255, overlay, strategy))  # show the superimposed image on the second row

    # Disabled: extra column with all bitmaps summed and clipped to [0..1]
    # overlay = resize_and_color(np.minimum(sum(bitmaps), 1.0), target_dims, colormap_lut=colormap_lut)
    # ax[overlay_row][-1].imshow(overlay)  # show the separate overlay on the first row
    # ax[overlay_row][-1].title.set_text('All overlays stacked and clipped to [0..1]')
    # ax[superimposed_row][-1].imshow(superimpose(base_image_255, overlay, strategy))  # show the superimposed image on the second row

    if save_file is not None:
        fig.savefig(save_file, bbox_inches='tight')
        # close the figure so repeated calls in a loop do not accumulate open figures
        plt.close(fig)
    else:
        plt.show()
    




In [None]:
from typing import Union


# average value counter
class AVG:
    """Running average of recorded values.

    Works for scalars and for objects that support `+` and `/`, such as
    torch.Tensors (and numpy.ndarrays). Calling the instance returns
    ``sum of recorded values / sum of recorded weights``. Note that `weight` only
    contributes to the denominator; `value` is added to the sum unscaled.
    """
    sum_: Union[float, torch.Tensor]
    count: Union[float, torch.Tensor]
    value: Union[float, torch.Tensor]

    def __init__(self, sum_=None, count=None):
        """Start empty, or resume from a previously accumulated `sum_` and `count`."""
        self.sum_ = sum_
        self.count = count
        # cached result of the last division; None means "needs recomputing"
        self.value = None

    def record(self, value, weight=1):
        """Add `value` to the running sum and `weight` to the running count."""
        if self.sum_ is None:
            self.sum_ = value
            self.count = weight
        else:
            # out-of-place addition: `+=` on a tensor/array works in place and would
            # silently modify the first recorded object, which `sum_` still refers to
            self.sum_ = self.sum_ + value
            self.count = self.count + weight
        # invalidate the cached average
        self.value = None

    def __call__(self):
        """Return the current average.

        Raises:
            TypeError: if nothing has been recorded yet (sum and count are both None).
        """
        if self.value is None:
            self.value = self.sum_ / self.count
        return self.value
        
# index of the CUDA device used for the model and its inputs
device_ = 0
# set the evaluation mode
grad_cam.eval()
grad_cam.cuda(device_)

epochs_ = 1
max_examples = 3

# per-example results, filled by the loop below and consumed by the plotting cell
images_list = []
activations_list = []
gradients_list = []
predictions_list = []


generator = seq_generator
for epoch_ in range(epochs_):
    print('Generator has', len(generator), 'examples in this epoch.')
    for batch_idx in tqdm(range(min(len(generator), max_examples))):

        # offset of 58: with max_examples = 3 this reads items 58..60, i.e. the
        # neighbourhood of index 59 mentioned where seq_generator is created
        # NOTE(review): the range above is bounded by len(generator), not by
        # len(generator) - 58, so short generators could be indexed past their end.
        input_, target = generator[batch_idx + 58]
        input_ = input_.type(torch.float)

        # transform to a 3 channel image shape expected by pyplot
        # (x + 1) / 2 maps [-1, 1] onto [0, 1]
        # NOTE(review): assumes the generator yields inputs normalised to [-1, 1] — confirm.
        img = (input_.squeeze().permute([1,2,0]).numpy() + 1.) / 2.
        images_list.append(img)
        
        input_ = input_.cuda(device_)
        pred = grad_cam(input_)

        # check the model decision
        is_cancer = pred > 0.5
        is_cancer = is_cancer.cpu()
        
        # build a two-letter label: T/F (decision matches target or not) + P/N (decision)
        if is_cancer == target: 
            _t = 'T'
        else:
            _t = 'F'
        
        _t += 'P' if is_cancer else 'N'
        predictions_list.append(_t)
        
        # backpropagate from the prediction so the gradients can be read from the GradCam module
        pred.backward()


        with torch.no_grad():
            # pull the gradients out of the model
            gradients = grad_cam.get_activations_gradient().cpu().squeeze(0)
            gradients_list.append(gradients)

            # get the activations of the last convolutional layer
            activations = grad_cam.get_activations().cpu().squeeze(0)
            activations_list.append(activations)

            # Pooling, channel selection and plotting are not done here; the next
            # cell performs them on the lists collected above.

In [None]:
# For every collected example: rank the channels by their pooled gradient, plot the
# ten highest and ten lowest ranked activation maps plus a Grad-CAM map, and save
# the figure to 'grad_cam_fmwise_special_blaa_<i>.jpg'.
with torch.no_grad():
    for i in trange(len(activations_list)):
        img = images_list[i]
        activations = activations_list[i]
        gradients = gradients_list[i]
        pred_checked = predictions_list[i]

        print(activations.shape, gradients.shape, pred_checked)

        # importance of each channel: its gradient averaged over both spatial axes
        pooled_gradients = torch.mean(gradients, dim=[1, 2])
        #pooled_gradients = torch.amax(gradients, dim=(0, 2, 3))

        # channel indices ordered from highest to lowest pooled gradient
        sorted_gradient_indices = torch.argsort(pooled_gradients, dim=-1, descending=True).numpy()

        # weight the channels by corresponding gradients through broadcasting multiplication
        weighted_activations = (activations * pooled_gradients.unsqueeze(1).unsqueeze(2)).numpy()
        activations = activations.numpy()

        # ten highest and ten lowest ranked channels; the heat-maps show their raw
        # (unweighted) activations
        selection = [*sorted_gradient_indices[:10], *sorted_gradient_indices[-10:]]
        heatmaps = activations[selection, :, :]

        print(heatmaps.shape)

        heatmaps_titles = [f'Transformed with logit, rank {pooled_gradients[selection[i]]:.4E}' for i in range(len(selection))]
        # squash the activations into (0, 1) with the logistic function
        heatmaps = expit(heatmaps)

        # Grad-CAM map: ReLU of the gradient-weighted sum of activation maps,
        # expression (2) in https://arxiv.org/pdf/1610.02391.pdf
        # NOTE(review): the sum runs over the 20 selected channels only, not over all
        # channels as in the paper — confirm this is intended.
        relued_gradcam = np.maximum(np.sum(weighted_activations[selection], axis=0, keepdims=True), 0.)
        gradcam_peak = relued_gradcam.max()
        if gradcam_peak > 0:
            normalized_gradcam = relued_gradcam / gradcam_peak
        else:
            # nothing positive survived the ReLU: keep the all-zero map instead of
            # dividing by zero and plotting NaNs
            normalized_gradcam = relued_gradcam
        heatmaps = np.append(heatmaps, normalized_gradcam, axis=0)

        heatmaps_titles.append('Original GradCAM')

        # ioff() keeps matplotlib from displaying the figure while it is being built
        with plt.ioff():
            plot_bitmap_overlays(heatmaps, img, 
                save_file=f'grad_cam_fmwise_special_blaa_{i}.jpg',
                colormap_lut=_COLORMAP_GRADIENT_BIPOLAR_GREEN_RED,
                strategy='black_as_alpha',
                heatmaps_titles=heatmaps_titles)
            # other strategies accepted by plot_bitmap_overlays: 'outline', 'ceil', 'sub'
            




In [None]:
from prunning_gradcam.models import GMaxPool2d

from prunning_gradcam.grad_cam_tools import _load_params, dissect_sequential_model

# Sanity check of the GradCam module on a tiny hand-made network with known weights.

# 1x1x5x5 input: an "H"-shaped pattern in the centre
inp = torch.as_tensor([[
    [
        [0, 0, 0, 0, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 1, 1, 0],
        [0, 1, 0, 1, 0],
        [0, 0, 0, 0, 0],
    ]
]], dtype=torch.float32)
# the same pattern shifted one column to the left (not used further in this cell)
inp2 = torch.as_tensor([[
    [
        [0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 0, 1, 0, 0],
        [0, 0, 0, 0, 0],
    ]
]], dtype=torch.float32)
print(inp.size())
# Hand-written parameters; the key prefixes are the layer positions inside
# `test_model` below: '0' is the Conv2d, '3' is the Linear layer.
state_dict_ = {
    # three 3x3 filters: four corners, middle row, middle column
    '0.weight': torch.as_tensor([[
        [[1, 0, 1],
         [0, 0, 0],
         [1, 0, 1]]
    ],[
        [[0, 0, 0],
         [1, 1, 1],
         [0, 0, 0]]
    ],[
        [[0, 1, 0],
         [0, 1, 0],
         [0, 1, 0]]
    ]], dtype=torch.float32),
    '0.bias': torch.as_tensor([ -3, -2, -2], dtype=torch.float32),
    # disabled: weights for the optional second Conv2d that is commented out in test_model
    # '3.weight': torch.as_tensor([[
    #     [[1, 1],
    #      [1, 1]],
    #     [[1, 1],
    #      [1, 1]],
    #     [[1, 1],
    #      [1, 1]]
    # ],[
    #     [[0, 0],
    #      [1, 1]],
    #     [[0, 0],
    #      [1, 1]],
    #     [[0, 0],
    #      [1, 1]]
    # ],[
    #     [[1, 1],
    #      [0, 0]],
    #     [[1, 1],
    #      [0, 0]],
    #     [[1, 1],
    #      [0, 0]]
    # ]], dtype=torch.float32),
    # '3.bias' : torch.as_tensor([0, 0, 0]),
    '3.weight': torch.as_tensor([[1, 1,1]], dtype=torch.float32),
    '3.bias': torch.as_tensor([1], dtype=torch.float32)
}

# conv -> ReLU -> global max pool -> linear -> sigmoid
test_model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, stride=1, padding=1),
    nn.ReLU(inplace=True),
    # nn.MaxPool2d(kernel_size=2, stride=2),
    # nn.Conv2d(in_channels=3, out_channels=3, kernel_size=2, padding=0),
    # nn.ReLU(inplace=True),
    GMaxPool2d(),
    nn.Linear(in_features=3, out_features=1),
    nn.Sigmoid()
)

# round-trip the parameters through a file in the working directory, because
# _load_params is given a path here
path_ = 'temp_test_state_dict'
torch.save(state_dict_, path_)

_load_params(test_model, path_)

# split at index 1, then split the tail at index -1 and keep its leading piece
# NOTE(review): judging by the names, `after_no_sig` is the tail without the final
# Sigmoid — confirm against dissect_sequential_model.
first, after = dissect_sequential_model(test_model, 1)
after_no_sig, sigm = dissect_sequential_model(after, -1)
 
gmodel = GradCam(first, after_no_sig)
pred = gmodel(inp)
print('PRED', pred)
# channel-last copy of the input, used as the base image for plotting
img = inp.squeeze(0).permute([1,2,0]).numpy()
        
pred.backward()
with torch.no_grad():
    # pull the gradients out of the model
    gradients = gmodel.get_activations_gradient().detach().cpu()
    print("GRADS", gradients)

    # pool the gradients across the feature maps and batch
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
    #pooled_gradients = torch.amax(gradients, dim=(0, 2, 3))
    
    
    #get the sorted indices of the gradients flowing back
    sorted_gradient_indices = torch.argsort(pooled_gradients, dim=-1, descending=True).numpy()
    

    # get the activations of the last convolutional layer
    activations = gmodel.get_activations().detach().cpu().squeeze(0)
    print('ACTS', activations)
    
    # NOTE: the gradient weighting is disabled here (see the trailing comment), so
    # the raw activations are what gets plotted
    weighted_activations = activations.numpy() # (activations * pooled_gradients.unsqueeze(1).unsqueeze(2)).numpy()
    
    # take up to 20 channels in rank order (this toy model only has 3)
    selection = sorted_gradient_indices[:20]
    heatmaps = weighted_activations[selection, :, :]

    # relu on top of the heatmap, expression (2) in https://arxiv.org/pdf/1610.02391.pdf
    heatmaps = np.maximum(heatmaps, 0)

    # normalize every heat-map by its own maximum (despite its name, `row_sums`
    # holds per-map maxima); maps whose maximum is 0 are left untouched
    row_sums = heatmaps.max(axis=(1, 2), keepdims=True)
    nonzero_idx = (row_sums != 0).squeeze()
    heatmaps[nonzero_idx] /= row_sums[nonzero_idx]

    # each `factor` is a (1, 1) array because of keepdims=True, so the title
    # contains the array's printed form rather than a bare number
    heatmaps_titles = [f'Scaled by {1/factor}' if factor > 0 else 'Empty' for factor in row_sums]
    
    with plt.ioff():
        plot_bitmap_overlays(heatmaps, img, save_file="grad_cam_test.jpg", strategy='linear_combination', heatmaps_titles=heatmaps_titles)
                
#print(inp, output)

In [None]:
import seaborn as sns

# Box plot with one box per DataFrame column, saved to a JPEG.
# NOTE(review): `grad_list_fro_boxplot` is not defined in any cell shown in this
# notebook, so this cell fails on a fresh kernel unless it is created elsewhere.
# It is stacked along a new leading axis, so it must be a sequence of equally
# shaped tensors.
grads_for_boxplots = torch.stack(grad_list_fro_boxplot, 0)
df = pd.DataFrame(grads_for_boxplots.numpy())
# columns ordered by ascending median
# NOTE(review): `sorted_index` is not used below although the output file name says
# "sorted" — confirm whether the boxes were meant to be drawn in this order.
sorted_index = df.median().sort_values().index

# very wide figure so that the many boxes stay readable
f = plt.figure()
f.set_figwidth(128)
f.set_figheight(10)

sns.boxplot(data=df)
plt.ylabel("Average gradient size", size=18)
# horizontal grid lines only
axes = plt.gca()
axes.yaxis.grid()
plt.savefig("Gradients_per_example_sorted_30.jpg")

In [None]:
# Scratch check of the ranking step on a small hand-made vector.
# NOTE: this rebinds `pooled_gradients`, which earlier cells also assign.
pooled_gradients = torch.as_tensor([0, 5, 2, 3, 4])
#pooled_gradients = torch.amax(gradients, dim=(0, 2, 3))


#get the sorted indices of the gradients flowing back
sorted_gradient_indices = torch.argsort(pooled_gradients, dim=-1, descending=True).numpy()

# indexing with the sorted indices yields the values in descending order: 5, 4, 3, 2, 0
print(pooled_gradients[sorted_gradient_indices])

In [None]:
# Scratch check of max-pooling over the batch and spatial axes.
# `ten` has shape (1, 2, 3, 3): one example with two 3x3 channels.
ten = torch.Tensor([
    [[[2, 1, 2],
     [3, 1, 2],
     [4, 1, 2]],
     [[5, 1, 2],
     [6, 1, 2],
     [7, 1, 2]]]
    ])
print(ten.size())
#pls = torch.mean(ten, dim=[0, 2, 3])
# reducing over dims 0, 2 and 3 leaves one maximum per channel: 4 and 7
pls = torch.amax(ten, (0,2,3))
pls

In [None]:
import pandas as pd

def load_h5_store_pandas(file_path: str):
    """Open the HDF5 file at `file_path` in read-only mode as a pandas HDFStore.

    The store is returned open; closing it is left to the caller.
    """
    return pd.HDFStore(file_path, mode="r")

  
# Open the HDF5 store inspected in the next cell; the other two stores are disabled.
# NOTE(review): absolute, machine-specific path — the cell only runs where this file exists.
#hdfs = load_h5_store_pandas('/mnt/data/home/bajger/NN_pruning/histopat/experiment_output/transfer_learning/predictions.h5')
hdfs2 = load_h5_store_pandas('/mnt/data/home/bajger/NN_pruning/histopat/datasets/hdfs_output/hdfs_output.h5')
#hdfs3 = load_h5_store_pandas('/mnt/data/crc_ml/data/processed/Prostata/level1/r512px/c256px/t512px/no_overlap/datasets/1602530237.h5')



In [None]:
# Walk through the tables of the store, printing "<position>(<row count>)" for each,
# and stop at the table of interest to report its position.
#print(hdfs3)
print(len(hdfs2))
for index_, table_name in enumerate(hdfs2.keys()):
    # read each table once: every hdfs2[...] access loads the object from the file again
    table_len = len(hdfs2[table_name])
    print(f'{index_}({table_len})', end=' ')

    if table_name == '/test/TP-2019_2824-01-1':
        print(table_len, end=' ')
        print(f'The dataset is on the {index_}th position.')
        break
    #print(hdfs3[table_name])

