**Warning:** this notebook requires pandas >= 1.3.0 to style its outputs.

In [None]:
# Pin pandas to the version this notebook was written against (see the warning above).
# NOTE: `-n bdh-project` installs into that named conda environment, so this only
# affects the notebook if its kernel runs from that environment. Restart the kernel
# after installing so the newly installed pandas is the one that gets imported.
! conda install -yn bdh-project pandas==1.3.0

In [1]:
import pickle
import os
import sys

import torch
from torch.utils.data import DataLoader
from gensim.models import Word2Vec
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Show DataFrames in full (all rows, all columns, no line wrapping limit and no
# truncation of cell text) so wide tables such as the attention table are not elided.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
# `None` means "no limit". The old `-1` sentinel was deprecated in pandas 1.0
# in favour of `None`, and later pandas versions reject negative values.
pd.set_option('display.max_colwidth', None)

# Make the project's model-training module importable. The path is relative to
# the current working directory (normally this notebook's folder).
sys.path.append('../src/models')
# Dataset loader and batch collate function from the training code (presumably
# the same ones used for training, so dev batches are built the same way).
from train_model import load_dataset, document_collate_function

  from ipykernel import kernelapp as app


In [2]:
# Styler.background_gradient(gmap=...), used below to colour the attention
# table, was added in pandas 1.3.0, so fail fast on an older install instead
# of erroring at the very last cell.
pandas_major_minor = tuple(int(part) for part in pd.__version__.split('.')[:2])
assert pandas_major_minor >= (1, 3), f"pandas >= 1.3.0 is required, found {pd.__version__}"
pd.__version__

'1.3.0'

In [3]:
# Configuration: input/output locations (relative to the current working
# directory, normally this notebook's folder) and the dev batch size.

PROCESSED_DATA_PATH = '../data/processed/dev_50.p'  # dev split, read by train_model.load_dataset
MODEL_OUTPUT_PATH = '../models/HLANModel-interpretable-epoch9.pth'  # full model object, loaded with torch.load
EMBEDDING_MODEL_PATH = '../src/features/saved_embedding_models/word_embeddings.model'  # gensim Word2Vec model
CODEMAP_PATH = '../data/processed/code_map.p'  # pickled dict: code -> label index (inverted below)
BATCH_SIZE = 32  # documents per dev batch; only the first batch is visualized

# Dev-set loader. shuffle=False keeps the batch order fixed, so the "first
# batch" inspected further down always contains the same documents.
dev_dataset = load_dataset(PROCESSED_DATA_PATH)
dev_loader = DataLoader(
    dev_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,  # load batches in the main process
    collate_fn=document_collate_function,
)

# The checkpoint is a whole model object (it is used directly as a module below),
# not a state_dict. Map it onto the CPU: the input batches come straight from the
# DataLoader (CPU), and the attention tensors are converted with `.numpy()`,
# which only works on CPU tensors. Without this, a checkpoint saved from a GPU
# run would fail here.
# NOTE(review): torch >= 2.6 defaults to `weights_only=True`, which cannot load a
# full pickled module; on those versions pass `weights_only=False`, and only for
# trusted checkpoints.
model = torch.load(MODEL_OUTPUT_PATH, map_location='cpu')
model.eval()  # inference mode (disables dropout / batch-norm updates)

embedding_model = Word2Vec.load(EMBEDDING_MODEL_PATH)

# gensim 4 renamed `wv.index2word` to `wv.index_to_key`. Accept either, so the
# notebook works with both gensim 3.x and 4.x.
vocab_words = getattr(embedding_model.wv, 'index_to_key', None)
if vocab_words is None:
    vocab_words = embedding_model.wv.index2word

# Reverse lookup from an exact embedding vector (as a hashable tuple) to its word.
# If two words share an identical vector, the later one in the vocabulary wins.
embedding_to_word_map = {tuple(embedding_model.wv[word]): word for word in vocab_words}

def get_word_from_embedding(embedding, word_map=None):
    """Map an embedding vector back to its vocabulary word by exact match.

    Args:
        embedding: a 1-D ``torch.Tensor`` (on any device, with or without grad),
            a 1-D ``np.ndarray``, or a tuple of floats.
        word_map: dict from vector tuple to word. Defaults to the
            notebook-level ``embedding_to_word_map``.

    Returns:
        The word whose vector equals ``embedding`` element-for-element, or
        ``None`` when no vocabulary vector matches.
    """
    if word_map is None:
        word_map = embedding_to_word_map
    if isinstance(embedding, torch.Tensor):
        # .cpu() so tensors that live on a GPU can also be converted to numpy.
        embedding = tuple(embedding.detach().cpu().numpy())
    elif isinstance(embedding, np.ndarray):
        embedding = tuple(embedding)
    return word_map.get(embedding)

# `codemap` maps each code to its label (column) index. `reverse_codemap`
# inverts it, so a label index from the model output can be turned back into a code.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open(CODEMAP_PATH, mode='rb') as codemap_file:
    codemap = pickle.load(codemap_file)
reverse_codemap = dict(zip(codemap.values(), codemap.keys()))

INFO - 2022-11-30 08:14:30,308 - train_model.py - Loading data from file ../data/processed/dev_50.p...
INFO - 2022-11-30 08:14:34,606 - utils.py - loading Word2Vec object from ../src/features/saved_embedding_models/word_embeddings.model
INFO - 2022-11-30 08:14:34,998 - utils.py - loading wv recursively from ../src/features/saved_embedding_models/word_embeddings.model.wv.* with mmap=None
INFO - 2022-11-30 08:14:34,999 - utils.py - loading vectors from ../src/features/saved_embedding_models/word_embeddings.model.wv.vectors.npy with mmap=None
INFO - 2022-11-30 08:14:35,033 - utils.py - setting ignored attribute vectors_norm to None
INFO - 2022-11-30 08:14:35,034 - utils.py - loading vocabulary recursively from ../src/features/saved_embedding_models/word_embeddings.model.vocabulary.* with mmap=None
INFO - 2022-11-30 08:14:35,034 - utils.py - loading trainables recursively from ../src/features/saved_embedding_models/word_embeddings.model.trainables.* with mmap=None
INFO - 2022-11-30 08:14:3

In [31]:
# Run the model on the first dev batch only.
samples = next(iter(dev_loader))
labels = samples[1]  # (batch_size, num_labels)
to_probabilities = torch.nn.Sigmoid()
raw_scores = model(samples[0])
outputs = to_probabilities(raw_scores)
# Attention weights are read off the model after the forward pass above.
sentence_attention = model.p_attention_sent
word_attention = model.p_attention_word
# axis 2 of the sentence attention is indexed per sentence further below
num_sentences = sentence_attention.shape[2]
outputs

tensor([[8.6095e-01, 9.6447e-02, 8.2078e-03,  ..., 4.3510e-02, 9.3843e-03,
         2.0846e-04],
        [8.9579e-02, 6.7908e-02, 9.6941e-01,  ..., 1.8098e-01, 1.1136e-02,
         7.8409e-04],
        [4.9423e-02, 1.8387e-01, 9.8905e-03,  ..., 9.4176e-03, 1.7204e-02,
         1.4745e-05],
        ...,
        [8.8130e-01, 5.8519e-01, 4.9931e-02,  ..., 8.7787e-02, 2.0624e-02,
         1.6198e-04],
        [1.2631e-01, 1.7652e-01, 1.1699e-02,  ..., 1.8942e-01, 1.3930e-02,
         1.5242e-04],
        [2.1006e-01, 1.7027e-02, 3.2877e-02,  ..., 1.2051e-01, 3.3459e-02,
         2.1134e-03]], grad_fn=<SigmoidBackward0>)

In [32]:
# First, pick the document to visualize: a row of the first batch,
# with 0 <= index < BATCH_SIZE.
index = 0
# Positions of the labels marked positive for this document.
positive_label_indexes = np.flatnonzero(labels[index].numpy())


positive_label_indexes

array([ 0,  3,  4,  9, 29])

In [40]:
# Finally, pick a label by its position in the array printed above:
# 0 <= label_index < len(positive_label_indexes)
label_index = 3

input_words = samples[0][0][index].detach().numpy()

# sentence weights for this label
positive_label = positive_label_indexes[label_index]
code = reverse_codemap[positive_label]
# .item() turns the 0-d tensor into a plain float for formatting
y_pred = outputs[index, positive_label].item()
sentence_weights = sentence_attention[positive_label, index, :].detach().numpy()

# Word weights for this document. The second axis of word_attention holds the
# sentences of every document in the batch back to back (hence the offset),
# so take this document's slice.
start_idx = index * num_sentences
end_idx = start_idx + num_sentences
# NOTE(review): this line originally indexed word_attention's first axis with the
# document `index`, while sentence_attention's first axis is the label axis.
# When both first axes have the same length, word attention is treated as
# label-wise and indexed with `positive_label`. Otherwise the original
# indexing is kept. Confirm against the model's forward().
if word_attention.shape[0] == sentence_attention.shape[0]:
    word_attention_row = positive_label
else:
    word_attention_row = index
word_weights = word_attention[word_attention_row, start_idx:end_idx, :].detach().numpy()

# word weights, each sentence's row scaled by that sentence's weight
weighted_word_weights = word_weights * sentence_weights[:, np.newaxis]

# One table row per sentence: the sentence's attention weight followed by its
# words (unmatched positions shown as ''). `colors` mirrors the table
# cell-for-cell and holds the value used to shade each cell: the raw sentence
# weight, then the sentence-scaled word weights.
viz_table = []
colors = []
num_words = word_weights.shape[1]
for sentence_idx, sent_weight in enumerate(sentence_weights):
    words_in_sentence = []
    for word_idx in range(num_words):
        decoded = get_word_from_embedding(input_words[sentence_idx][word_idx])
        words_in_sentence.append('' if decoded is None else decoded)
    viz_table.append([str(sent_weight)] + words_in_sentence)
    colors.append([sent_weight] + list(weighted_word_weights[sentence_idx][:num_words]))
    
def format_display(styler):
    """Caption and colour the attention table; returns the same Styler.

    Every cell is shaded with the YlGnBu colormap on one shared scale
    (axis=None), using the parallel ``colors`` grid as the gradient map. The
    caption names the selected code and its predicted score.
    """
    caption = f"Attention Weights for code {code} - Predicted score {y_pred:.2%}"
    return (
        styler
        .set_caption(caption)
        .background_gradient(axis=None, cmap="YlGnBu", gmap=colors)
    )
    
# Render the word table with format_display's caption and colour gradient.
df = pd.DataFrame(viz_table)
format_display(df.style)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25
0,0.0021717846,admission,date,discharge,date,date,of,birth,sex,f,service,surgery,allergies,patient,recorded,as,having,no,known,allergies,to,drugs,attending,first,name3,lf
1,0.099645376,chief,complaint,60f,on,coumadin,was,found,slightly,drowsy,tonight,then,fell,down,stairs,paramedic,found,her,unconscious,and,she,was,intubated,w,o,any
2,0.0032811803,medication,head,ct,shows,multiple,iph,transferred,to,hospital1,for,further,eval,major,surgical,or,invasive,procedure,none,past,medical,history,her,medical,history,is
3,0.01101851,significant,for,hypertension,osteoarthritis,involving,bilateral,knee,joints,with,a,dependence,on,cane,for,ambulation,chronic,back,pain,she,also,has,a,history,of,a
4,0.005356998,right,lung,cancer,requiring,right,lobectomy,in,no,metastasis,was,known,and,she,has,since,recovered,well,and,is,considered,cured,social,history,unknown,family
5,0.8566249,history,nc,physical,exam,physical,exam,intubated,non,sedated,received,no,paralytic,medication,no,eye,opening,pupil,rt,mm,lt,mm,both,non,reactive,corneal
6,0.008655145,bilat,extends,both,ue,to,stim,min,withdrawal,triple,flexion,both,le,upgoing,toes,bilat,brief,hospital,course,ct,scan,revealed,very,severe,iph,given
7,0.0027144165,her,poor,prognosis,with,fixed,pupils,and,posturing,patient,was,made,cmo,by,family,she,expired,shortly,after,arrival,to,hospital,medications,on,admission,unknown
8,0.0048937094,discharge,medications,expired,discharge,disposition,expired,discharge,diagnosis,iph,discharge,condition,expired,discharge,instructions,none,followup,instructions,none,first,name11,name,pattern1,last,name,namepattern4
9,0.0012678097,md,md,number,completed,by,,,,,,,,,,,,,,,,,,,,
