In [1]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import copy
import torch.nn.functional as F
from einops import rearrange
import os
# Switch the working directory to the project checkout before importing from it.
# NOTE(review): the path is relative, so this only succeeds when the kernel starts in
# the directory that contains 'project/'; re-running this cell without restarting the
# kernel attempts the chdir again from the new location.
os.chdir('project/learning_causal_discovery/')
# presumably a project-local package that is only importable after the chdir above — TODO confirm
from run.shallowmind.api.infer import prepare_inference
from pytorch_grad_cam import GradCAM


# Apply matplotlib's 'bmh' style sheet to the figures drawn after this point.
plt.style.use('bmh')

In [2]:
# Relative directory prefix that is prepended to every config / checkpoint path in this notebook.
root_dir = 'work_dir/'

In [3]:
# helper function
def normalization(x, eps=1e-6):
    """Rescale the magnitude of ``x`` into [0, 1], independently for every sample.

    The absolute value of ``x`` is taken first; then, for each index along
    dimension 0, the minimum over all remaining elements is subtracted and the
    result is divided by the (post-shift) maximum plus ``eps``.

    Args:
        x: tensor whose first dimension indexes samples. Any rank >= 1 is
            accepted; the statistics are computed over all other dimensions.
        eps: small constant added to the denominator so that a constant
            sample does not cause a division by zero.

    Returns:
        A tensor with the same shape as ``x`` and values in [0, 1).
    """
    x = abs(x)
    batch = x.size(0)
    # (batch, 1, ..., 1): lets the per-sample statistics broadcast against x
    # regardless of its rank (a fixed (batch, 1, 1) shape is only right for 3-D input).
    stat_shape = (batch,) + (1,) * (x.dim() - 1)

    # shift every sample so that its smallest magnitude becomes 0
    lowest, _ = x.reshape(batch, -1).min(1)
    x = x - lowest.reshape(stat_shape)

    # scale every sample so that its largest magnitude becomes (almost) 1
    highest, _ = x.reshape(batch, -1).max(1)
    return x / (highest.reshape(stat_shape) + eps)

def get_input_gradient(model, objective, x):
    """Differentiate ``objective`` with respect to the input tensor ``x``.

    Args:
        model: module whose parameter gradients are reset (``zero_grad``) first.
        objective: scalar tensor that was computed from ``x``.
        x: input tensor that is part of ``objective``'s autograd graph.

    Returns:
        The gradient of ``objective`` w.r.t. ``x`` (same shape as ``x``).
    """
    # drop whatever parameter gradients are left over from earlier passes
    model.zero_grad()
    # autograd.grad returns one gradient per requested input; only x is requested
    (grad_wrt_input,) = torch.autograd.grad(outputs=objective, inputs=x)
    return grad_wrt_input

def get_gradient_saliency(model, sample):
    """Compute a normalized input-gradient saliency map for one dataset sample.

    Args:
        model: network called as ``model(batch)`` with a dict of batched CUDA
            tensors; its output is fed to ``F.nll_loss`` (presumably
            log-probabilities or logits — confirm against the model).
        sample: ``(x, y)`` pair where ``x`` is a dict containing a ``'seq'``
            entry and ``y`` is a tensor label. The sample is deep-copied, so
            the caller's object is left untouched.

    Returns:
        dict with keys ``'seq'`` (the batched input that gradients were taken
        for), ``'confidence'`` (softmax score of class 1 for the sample),
        ``'gt'`` (the label) and ``'gradient'`` (normalized gradient
        magnitudes reshaped to ``(-1, 2)``).
    """
    features, y = copy.deepcopy(sample)

    # add a leading batch dimension to every entry and move it to the GPU
    batch = {}
    for key, value in features.items():
        batch[key] = torch.tensor(value).unsqueeze(0).cuda()

    seq = batch['seq']
    seq.requires_grad_()

    model.eval()
    prob = model(batch)

    # maximise the (negated) NLL of the ground-truth class
    nll = F.nll_loss(prob, y.cuda().flatten(), reduction='sum')
    objective = -1. * nll
    raw_grad = get_input_gradient(model, objective, seq)
    # (the "gradient * input" variant, raw_grad * seq, is intentionally not applied)

    channels_first = rearrange(raw_grad, 'b l c -> b c l')
    scaled = normalization(channels_first).cpu()
    gradient = rearrange(scaled, 'b c l -> b l c')

    return {
        'seq': seq,
        'confidence': prob.softmax(dim=-1)[0, 1],
        'gt': y,
        'gradient': gradient.squeeze().reshape(-1, 2),
    }

class GradCAM1d(GradCAM):
    """Grad-CAM adapted to sequence inputs: the CAM is a 1-D curve, not an image.

    The overrides assume activations/gradients laid out as
    (batch, length, channels) — TODO confirm against the hooked layer.
    """

    def get_cam_weights(self, input_tensor, target_layer, target_category, activations, grads):
        """Average the gradients over axis 1 to obtain one weight per remaining index."""
        return grads.mean(axis=1)

    def scale_cam_image(self, cam, target_size=None):
        """Min-max scale every CAM row and optionally resample it to ``target_size`` points.

        Returns a float32 array with one scaled (and possibly resampled) row per input row.
        """
        rescaled = []
        for row in cam:
            shifted = row - np.min(row)
            unit = shifted / (1e-7 + np.max(shifted))
            if target_size is not None:
                # linear interpolation from the CAM grid onto a grid with target_size points
                new_grid = np.linspace(0, target_size, target_size)
                old_grid = np.linspace(0, target_size, len(unit))
                unit = np.interp(new_grid, old_grid, unit)
            rescaled.append(unit)
        return np.float32(rescaled)

    def get_target_length(self, input_tensor):
        """Return the size of the second-to-last dimension of ``input_tensor``."""
        return input_tensor.size(-2)

    def get_cam_image(self,
                      input_tensor,
                      target_layer,
                      targets,
                      activations,
                      grads,
                      eigen_smooth=False):
        """Weight the activations by the gradient-derived weights and sum over the last axis.

        ``eigen_smooth`` is accepted for interface compatibility but not used.
        """
        weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
        # bring the weighted axis to position 1 so the weights broadcast over the other axis
        swapped = np.swapaxes(activations, 1, 2)
        weighted = weights[:, :, None] * swapped
        return np.swapaxes(weighted, 1, 2).sum(axis=-1)

    def compute_cam_per_layer(
            self,
            input_tensor,
            targets,
            eigen_smooth):
        """Build one rectified, rescaled CAM per target layer.

        Returns a list with one array per target layer, each with a singleton
        axis inserted at position 1.
        """
        recorder = self.activations_and_grads
        activations_list = [act.cpu().data.numpy() for act in recorder.activations]
        grads_list = [grad.cpu().data.numpy() for grad in recorder.gradients]
        target_size = self.get_target_length(input_tensor)

        def _pick(items, position):
            # a layer without a recorded entry is represented by None
            return items[position] if position < len(items) else None

        cam_per_target_layer = []
        for position, target_layer in enumerate(self.target_layers):
            cam = self.get_cam_image(input_tensor,
                                     target_layer,
                                     targets,
                                     _pick(activations_list, position),
                                     _pick(grads_list, position),
                                     eigen_smooth)
            # keep only the positive evidence before rescaling
            cam = np.maximum(cam, 0)
            scaled = self.scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer

def _shift_effect_channel(input_tensor, direction, steps):
    """Shift channel 1 of a (1, length, channels) tensor along the length axis, in place.

    'forward' delays the channel by ``steps`` positions and pads the start with
    its first value; any other direction advances it and pads the end with its
    last value. ``steps`` is clamped to [0, length].
    """
    effect = input_tensor[:, :, 1]
    length = effect.size(1)
    steps = max(0, min(int(steps), length))
    if steps == 0:
        return input_tensor
    if direction == 'forward':
        pad = effect[:, :1].expand(-1, steps)
        shifted = torch.cat((pad, effect[:, :length - steps]), dim=1)
    else:
        pad = effect[:, -1:].expand(-1, steps)
        shifted = torch.cat((effect[:, steps:], pad), dim=1)
    # torch.cat allocated new storage, so writing back does not alias the source
    input_tensor[:, :, 1] = shifted
    return input_tensor


def get_cam_plots(target_idxs, mi, dl, axes=None, shift=False, default_shift='backward',
                  save_dir=None, targets=None, shift_steps=200, plot_stride=100):
    """Plot the two input channels and the Grad-CAM curve for selected test samples.

    Reads the notebook-level global ``pos_samples`` to derive each sample's label.

    Args:
        target_idxs: dataset indices to visualise, one panel per index.
        mi: model; the CAM is hooked on ``mi.model.backbone.trm.layers[3].norm1``.
        dl: dataloader whose ``dataset[idx][0]['seq']`` holds the input sequence
            (numpy array or tensor).
        axes: optional matplotlib Axes (array or sequence) to draw on. When
            omitted a 3-column grid large enough for all targets is created.
            Panels are paired with targets in order; surplus axes are switched off.
        shift: if True, channel 1 of the model input is shifted before inference.
        default_shift: 'forward' or 'backward'; direction of that shift.
        save_dir: if given, the figure is written to ``f"{save_dir}.svg"``.
        targets: Grad-CAM targets forwarded to the CAM call. Defaults to
            ``[ClassifierOutputTarget(1)]``, i.e. the class whose softmax score
            is reported as the confidence.
        shift_steps: number of positions channel 1 is shifted by.
        plot_stride: sub-sampling stride used when drawing the curves.

    Raises:
        ValueError: if ``shift`` is True and ``default_shift`` is not
            'forward' or 'backward'.
    """
    if shift and default_shift not in ('forward', 'backward'):
        raise ValueError('default_shift must be either forward or backward')

    if targets is None:
        # local import: the notebook's import cell only pulls in GradCAM itself
        from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
        targets = [ClassifierOutputTarget(1)]

    target_layers = [mi.model.backbone.trm.layers[3].norm1]
    # NOTE(review): newer pytorch-grad-cam releases reportedly dropped `use_cuda` —
    # confirm against the installed version.
    cam = GradCAM1d(model=mi, target_layers=target_layers, use_cuda=True)

    if axes is None:
        # ceil, not round: round() under-allocates rows (e.g. 4 targets -> 1 row)
        n_rows = max(1, int(np.ceil(len(target_idxs) / 3)))
        _, axes = plt.subplots(n_rows, 3, figsize=(24, n_rows * 4), squeeze=False)
    axes_flat = np.atleast_1d(axes).ravel()

    cls = ['#DE1334', '#6752FF', '#0B2735']
    for idx, ax in zip(target_idxs, axes_flat):
        # fetch the sample once so the plotted series and the model input agree
        seq = dl.dataset[idx][0]['seq']
        # independent copy (works for ndarray and tensor) with a batch dimension
        input_tensor = torch.as_tensor(seq).detach().clone().unsqueeze(0).cuda()
        if shift:
            _shift_effect_channel(input_tensor, default_shift, shift_steps)

        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

        # NOTE(review): the curves drawn are the unshifted dataset series even when
        # shift=True (unchanged from the previous behaviour) — confirm this is intended.
        ax.plot(seq[::plot_stride, 0], c=cls[0], linestyle='-', label='cause')
        ax.plot(seq[::plot_stride, 1], c=cls[1], linestyle='--', label='result')
        ax.plot(grayscale_cam.squeeze()[::plot_stride], c=cls[2], label='grad_cam')

        confidence = mi(input_tensor).softmax(dim=-1)[:, 1].detach().cpu().numpy()[0]
        label = int(idx in pos_samples)
        if shift:
            ax.set_title(f'Original Label: {label} | Current Label: {int(label==0)} | Prediction: {int(confidence>0.5)}', fontdict={'family': 'Serif'})
        else:
            ax.set_title(f'Label: {label} | Prediction: {int(confidence>0.5)}', fontdict={'family': 'Serif'})
        ax.set_xticks([])
        ax.set_yticks([])
        ax.legend(loc='lower right', prop={'family': 'Serif'}, labelcolor='black')

    # blank out panels that have no sample assigned
    for spare in axes_flat[len(target_idxs):]:
        spare.axis('off')

    if save_dir is not None:
        # save the figure that owns the panels rather than whichever figure is "current"
        figure = axes_flat[0].figure if axes_flat.size else plt.gcf()
        figure.savefig(f"{save_dir}.svg")


### Gradient Saliency

In [None]:
# Load two trained models together with their test dataloaders.

# --- first experiment -> di / mi / dl ---
cfgs = [
    root_dir + 'patch_transformer_128_50ep_cosine_adamw_1e-3lr_256bs_0.05wd_fold_seed_42/patch_transformer_128_50ep_cosine_adamw_1e-3lr_256bs_0.05wd.py',
]
ckpts = [
    root_dir + 'patch_transformer_128_50ep_cosine_adamw_1e-3lr_256bs_0.05wd_fold_seed_42/ckpts/exp_name=patch_transformer_128_50ep_cosine_adamw_1e-3lr_256bs_0.05wd-cfg=patch_transformer_128_50ep_cosine_adamw_1e-3lr_256bs_0.05wd-bs=256-val_auroc=0.9380.ckpt',
]

di, mi = prepare_inference(cfgs[0], ckpts[0])
mi = mi.cuda()
di.setup()
dl = di.test_dataloader()

# --- second experiment (named "0.5noise") -> di_noise / mi_noise / dl_noise ---
cfgs = [
    root_dir + 'patch_transformer_128_0.5noise_fold_seed_42/patch_transformer_128_w0.5noise.py',
]
ckpts = [
    root_dir + 'patch_transformer_128_0.5noise_fold_seed_42/ckpts/exp_name=patch_transformer_128_0.5noise-cfg=patch_transformer_128_w0.5noise-bs=256-val_auroc=0.9493.ckpt',
]

di_noise, mi_noise = prepare_inference(cfgs[0], ckpts[0])
mi_noise = mi_noise.cuda()
di_noise.setup()
dl_noise = di_noise.test_dataloader()

In [6]:
# Row positions of the positive samples (label == 1) in the test split CSV
df = pd.read_csv(di.hparams.data.test.split)
pos_samples = np.flatnonzero(df['label'].eq(1).to_numpy())

### GradCAM

In [None]:
# 6x3 grid: the top three rows are drawn with `mi`, the bottom three with `mi_noise`
fig, axes = plt.subplots(nrows=6, ncols=3, figsize=(24, 24))
# every 5th entry among the first 45 positive samples (at most 9 indices)
target_idxs = pos_samples[:45:5].tolist()
get_cam_plots(target_idxs, mi, dl, axes=axes[:3].flatten())
get_cam_plots(target_idxs, mi_noise, dl_noise, axes=axes[3:].flatten())
fig.savefig('/home/charon/project/nmos_inference/figures/Figure 5. CAM/cam_supplement.svg')

In [None]:
# Same samples again, this time with the shifted input (shift=True), saved as an SVG
get_cam_plots(
    target_idxs,
    mi,
    dl,
    shift=True,
    save_dir='/home/charon/project/nmos_inference/figures/Figure 5. CAM/reverse',
)