In [None]:
import numpy as np
import pandas as pd
import string
import random
from tqdm import tqdm
from typing import *

import torch

import pytorch_lightning as pl

from loguru import logger

In [None]:
import matplotlib
import matplotlib.pyplot as plt

In [None]:
from src.experiments.experiments import get_experiment

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Ask the project's `get_experiment` helper for the experiment identified by 'resnet'.
experiment = get_experiment('resnet')

In [None]:
# Lightning trainer on one GPU with 16-bit precision (not referenced again in the visible cells).
trainer = pl.Trainer(precision=16, gpus=1)

# Data locations and hyper-parameters forwarded to the checkpoint loader.
restore_options = {
    'train_data_path': '../.data/medical_data_train',
    'val_data_path': '../.data/medical_data_val',
    'learning_rate': 0.003,
    'batch_size': 64,
    'sequence_length': 16384,
}

# Replace the freshly built experiment with the one restored from the checkpoint.
experiment = experiment.load_from_checkpoint(
    '../.data/resnet50_k_v.ckpt',
    **restore_options,
)

## training


In [None]:
from grad_cam import *
import os.path as osp
import matplotlib.cm as cm
import cv2
from dataclasses import dataclass
from src.datasets.medical_data import PadData
def save_gradcam(filename, gcam, raw_image, paper_cmap=False):
    """Blend a Grad-CAM map with ``raw_image`` and return the overlay.

    Despite the name, nothing is written to disk: the ``cv2.imwrite`` call
    that used ``filename`` is disabled, so ``filename`` is currently unused
    and only kept so existing call sites keep working.

    Args:
        filename: Unused (see above).
        gcam: Object exposing ``.cpu().numpy()`` (e.g. a torch tensor); the
            resulting array is fed to the reversed jet colormap.
        raw_image: NumPy array blended with the colour-mapped values; it must
            broadcast against the colormap output.
        paper_cmap: When true, ``gcam`` is used as a per-element alpha that
            mixes colormap and image; otherwise the two are averaged.

    Returns:
        Tuple ``(overlay, cmap)``: the blend cast to ``uint8`` and the raw
        colormap output.
    """
    heat = gcam.cpu().numpy()
    # NOTE(review): matplotlib colormaps return RGBA floats in [0, 1]; upstream
    # Grad-CAM code usually slices to RGB and scales by 255 before blending.
    # Confirm whether dropping that step here was intentional.
    cmap = cm.jet_r(heat)
    if paper_cmap:
        alpha = heat[..., None]
        blended = alpha * cmap + (1 - alpha) * raw_image
    else:
        # `np.float` was only an alias of the builtin `float`; it was deprecated
        # in NumPy 1.20 and removed in 1.24, so use the builtin directly.
        blended = (cmap.astype(float) + raw_image.astype(float)) / 2

    return np.uint8(blended), cmap

def colorize(words, color_array):
    """Build an HTML string with one coloured ``<span>`` per word.

    Words and colours are paired positionally; pairing stops at the shorter
    of the two sequences. Each word is wrapped in ``&nbsp`` padding and the
    colour is inserted as the span's ``background-color``.

    Args:
        words: Iterable of strings to render.
        color_array: Iterable of CSS colour strings, one per word.

    Returns:
        The concatenated span markup (empty string when there are no pairs).
    """
    span = '<span class="barcode"; style="color: black; background-color: {}">{}</span>'
    return ''.join(
        span.format(background, '&nbsp' + token + '&nbsp')
        for token, background in zip(words, color_array)
    )

def get_colors(inp, colormap, vmin=None, vmax=None):
    """Normalise ``inp`` and pass it through ``colormap``.

    Args:
        inp: Values to colour.
        colormap: Callable applied to the normalised values.
        vmin: Lower bound handed to ``plt.Normalize`` (``None`` by default).
        vmax: Upper bound handed to ``plt.Normalize`` (``None`` by default).

    Returns:
        Whatever ``colormap`` returns for the normalised input.
    """
    scale = plt.Normalize(vmin, vmax)
    scaled = scale(inp)
    return colormap(scaled)

@dataclass
class Sample:
    """One Grad-CAM result: a token sequence and its map for a single layer."""
    # Name of the model layer the map was generated for.
    target_layer: str
    # Tokens of the input sequence, as produced by the tokeniser.
    words: list
    # Grad-CAM output row for this sequence (assumed one value per token; TODO confirm).
    sensitivity: torch.Tensor
    
        
if __name__ == '__main__':
    # Run Grad-CAM over the validation set and collect, for every kept
    # sequence and every target layer, the tokens plus their sensitivity row.
    results: List[Sample] = []

    # Class indices (not referenced again in this cell).
    classes = [0, 1]

    # Model
    model = experiment.model

    # The four residual layers
    target_layers = [
        "features.6.0.residual_block",
        "features.7.0.residual_block",
        "features.8.0.residual_block",
        "features.9.0.residual_block",
    ]
    # NOTE(review): gradients are taken for class 0 while the mask below keeps
    # samples labelled and predicted as class 1 - confirm this is intended.
    target_class = 0

    output_dir = '.'

    test_dataloader = torch.utils.data.DataLoader(
        experiment.data_module.val_dataset,
        batch_size=experiment.batch_size,
        collate_fn=PadData(
            pad_to_length=experiment.data_module.sequence_length,
            pad_val=experiment.data_module.pad_token_id,
        ),
        num_workers=2,
        pin_memory=True,
        shuffle=True,
    )

    gcam = GradCAM(model=model)

    # Built once instead of once per sample, and deliberately NOT named `cm`:
    # this code runs at notebook top level, so rebinding `cm` would replace the
    # `matplotlib.cm` module that `save_gradcam` relies on.
    colormap = plt.get_cmap('plasma', lut=8)

    i = 0
    # Iterate the loader directly so tqdm can show the total number of batches.
    for batch in tqdm(test_dataloader):
        model.zero_grad()

        x, y, y_hat, loss = experiment.step(batch, "eval")
        gcam.device = x.device
        y_hat = torch.softmax(y_hat, dim=1)

        gcam.logits = y_hat
        gcam.image_shape = x.shape

        confidences, predicted = y_hat.max(1)

        # only look at correct predictions (label 1 predicted as 1)
        mask = (y == 1) & (predicted == 1)
        x = x[mask]
        predicted = predicted[mask]
        y = y[mask]

        if len(x) == 0:
            # Nothing was selected in this batch: skip the backward pass.
            print(len(results))
            continue

        # NOTE(review): `gcam.logits` / `gcam.image_shape` were set from the
        # unfiltered batch, while `ids_` and the `regions[j]` lookups below use
        # positions in the filtered `x`. Verify against the GradCAM
        # implementation that row j of `regions` really belongs to `x[j]`.
        ids_ = torch.full(
            (len(x), 1), target_class, dtype=torch.long, device=x.device
        )
        gcam.backward(ids=ids_)

        for target_layer in target_layers:
            # Grad-CAM
            regions = gcam.generate(target_layer=target_layer).squeeze(0).squeeze(0)

            for j in range(len(x)):
                words = experiment.model.tokeniser.convert_ids_to_tokens(x[j, :].tolist())

                color_array = [
                    matplotlib.colors.to_hex(color)
                    for color in get_colors(regions[j, :], colormap, 0, 1)
                ]

                # HTML rendering of the sample; writing it to
                # ../.data/out/sample{i}-....html is currently disabled.
                s = colorize(words, color_array)

                i += 1

                results.append(Sample(
                    target_layer=target_layer,
                    words=words,
                    sensitivity=regions[j, :],
                ))
        print(len(results))

In [None]:
# Number of Sample records collected (one per kept sequence per target layer).
len(results)

In [None]:
def parse_words(words, sensitivites):
    """Group a token sequence into ``(key, value, score)`` triples.

    A token longer than five characters starts a new key; every shorter token
    is appended to the value of the current key. A pair is emitted when the
    next key begins or the input ends, so the last pair is not lost.

    Fixes relative to the previous version: the function now reads its
    ``words`` argument (it used the global ``results[0].words``), each value is
    paired with the key that precedes it rather than the one that follows, and
    scores may be plain numbers as well as 0-d tensors.

    Args:
        words: Sequence of token strings.
        sensitivites: Per-token scores aligned with ``words``; anything
            ``float()`` accepts. Pairing stops at the shorter sequence.

    Returns:
        List of ``(key, value, score)`` tuples in order of appearance. Keys
        that never receive a value are dropped; value tokens seen before any
        key are reported under the empty key ``''``.
    """
    k_v_store = []

    key = ''
    value = ''
    acc = 0.0

    def flush():
        # NOTE(review): the divisor counts the characters of `value` plus one
        # for the key, which equals the token count only when value tokens are
        # single characters - confirm against the tokeniser.
        k_v_store.append((key, value, acc / (len(value) + 1)))

    for rune, sensitivity in zip(words, sensitivites):
        score = float(sensitivity)
        if len(rune) > 5:
            # A new key closes the pair collected so far (if it has a value).
            if value:
                flush()
            key = rune
            value = ''
            acc = score
        else:
            value += rune
            acc += score

    if value:
        flush()

    return k_v_store

In [None]:
# Collect every parsed score under its (key, value) pair and count all entries.
most_k_v: Dict[Tuple[str, str], List[float]] = {}
num_entries = 0
for result in tqdm(results):
    parsed = parse_words(result.words, result.sensitivity)
    ranked = sorted(parsed, key=lambda tup: tup[-1], reverse=True)

    for field, value, score in ranked:
        pair = (field, value)
        most_k_v.setdefault(pair, []).append(score)
        num_entries += 1

In [None]:
# Reduce each pair's list of scores to its sum.
results_most_k_v = {
    pair: np.array(scores).sum()
    for pair, scores in most_k_v.items()
}

In [None]:
# Show the 30 pairs with the largest summed score.
# NOTE(review): the percentage printed as "was present" is the pair's summed
# score divided by the total entry count, not an occurrence frequency - confirm
# the wording matches what is meant.
ranked_pairs = sorted(results_most_k_v.items(), key=lambda tup: tup[-1], reverse=True)
for (field, value), total in ranked_pairs[:30]:
    print(f"({field}) = \"{value}\" \t\twas present {(total/num_entries)*100:0.2f}%")

In [None]:
"""
(('script_item_active.med_active_ingr', '1'), 1552.4205774287411)
(('observation_active.observation_value', 'O2SAT'), 1101.7255629818114)
(('encounter_reason_active.reason', 'RESULTS DISCUSSED'), 1037.018368516322)
(('script_item_active.dose', '20mg'), 1018.9674180358202)
(('script_item_active.strength', '3'), 1013.4701462443155)
(('script_item_active.repeats', '25'), 1009.5255802461905)
(('script_item_active.quantity', 'PIROXICAM'), 999.4156962497607)
(('script_item_active.med_name', 'PIROXICAM'), 983.5153352558156)
(('script_item_active.frequency', '1'), 976.3854911817947)
(('encounter_reason_active.reason', 'ELEVATED PSA'), 939.2458728121164)
(('observation_active.observation_value', 'PULSE'), 930.0018636126786)
(('script_item_active.repeats', '1'), 791.5672112685506)
(('observation_active.observation_name', '98'), 791.0192584013294)
(('script_item_active.strength', '0'), 726.3520064711704)
(('script_item_active.dose', '0.1%'), 619.3899866717061)
(('script_item_active.quantity', 'ADVANTAN'), 610.9945505812882)
(('script_item_active.med_name', 'METHYLPREDNISOLONE ACEPONATE'), 594.2070541277367)
(('script_item_active.frequency', 'Topical'), 572.4618293443946)
(('script_item_active.strength', '2'), 521.2934443520304)
(('script_item_active.med_active_ingr', '8'), 440.6237507160537)
(('script_item_active.repeats', '20'), 343.97818737079285)
(('script_item_active.quantity', 'TRAMADOL'), 341.27871390079196)
(('encounter_reason_active.reason', 'ADVICE AND LISTENING'), 340.3012212191011)
(('script_item_active.frequency', '50-100mg'), 336.9993179654589)
(('script_item_active.med_name', 'TRAMADOL HYDROCHLORIDE'), 336.2271525825771)
(('observation_active.observation_value', 'TEMP'), 319.9377679076771)
(('encounter_reason_active.reason', 'DERMATITIS'), 280.79228743576834)
(('encounter_reason_active.reason', 'COUNSELLING'), 203.64822078745178)
(('encounter_reason_active.reason', 'WOUND REVIEW'), 192.2942505436592)
(('immunisation_active.vaccine_name', 'FLUQUADRI'), 182.15081978951784)
​

"""