In [None]:
import os

# Run from the repository root: later cells use root-relative paths such as
# 'notebooks/decomp-timesteps.pdf'. Only step up when that directory is not
# visible from here, so re-running this cell does not keep climbing the tree.
if not os.path.isdir('notebooks'):
    os.chdir('..')

In [None]:
import re
import torch
import utils
import yaml
from glob import glob

import seaborn as sns
import plotly.express as px
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pandas as pd

from dataloader.data import MIMICDataset, get_tables, JointTabularFeature
from dataloader.labels import get_labels
from dataloader.utils import BinnedEvent, get_vocab
from utils import prepare_batch, load_class, load_model, load_config

In [None]:
# Use the first GPU when available; otherwise fall back to CPU so the notebook
# still executes on machines without CUDA.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
# Run identifier handed to load_config (presumably a W&B run id).
RUN_ID = '28zyblao'
params = load_config(RUN_ID)

In [None]:
params['wandb_id']  # display the loaded config's 'wandb_id' entry as a sanity check

In [None]:
# Later cells index `[0]` into every batch output (one patient at a time), so
# force a batch size of 1 regardless of the loaded config.
params['batch_size'] = 1
# Optional overrides used in earlier experiments:
# params['min_word_count'] = 10000
# params['vocab_file'] = 'embeddings/sentences.mimic3.hourly.random.binned.train.counts'

# Display the (updated) config. A bare `params` anywhere but the last line of a
# cell is a no-op and shows nothing.
params

In [None]:
params['joint_tables']  # display the config's 'joint_tables' entry (forwarded to get_vocab/get_tables via **params)

In [None]:
# Vocabulary and event tables, both built from the loaded config.
joint_vocab = get_vocab(**params)
tables = get_tables(
    vocab=joint_vocab,
    load=True,
    event_class=BinnedEvent,
    **params
)

# Task label definitions (get_labels receives the target DEVICE).
labels = get_labels(DEVICE)

# Validation patients, numericalized against the vocabulary above.
# NOTE(review): the validation list is opened with mode='TRAIN' — confirm intended.
val_set = MIMICDataset(
    datalist_file='val_listfile.csv',
    mode='TRAIN',
    tables=tables,
    labels=labels,
    limit=None,
    numericalize=True,
)

In [None]:
model = utils.load_model(params, joint_vocab, tables, DEVICE)
# This notebook only runs inference / attribution, so switch layers such as
# dropout to evaluation behaviour (a no-op if load_model already did this).
model.eval()

# Epoch number(s) parsed from the checkpoint path: a list of digit strings,
# empty if the path contains no `checkpoint_<N>_` segment.
loaded_epoch = re.findall(r'checkpoint_(\d+)_', params['model_path'])

# Timesteps

In [None]:
from functools import partial

# Caps forwarded to utils.min_batch as `limit` and `event_limit`.
MIN_BATCH_LIMIT = 720
MIN_BATCH_EVENT_LIMIT = 300

collate = partial(
    utils.min_batch,
    tables=tables,
    labels=labels,
    limit=MIN_BATCH_LIMIT,
    event_limit=MIN_BATCH_EVENT_LIMIT,
)

# Deterministic order (no shuffling) over the validation set.
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=params['batch_size'],
    collate_fn=collate,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
)

In [None]:
# Collect embeddings for the first N_PATIENTS validation patients whose
# decompensation labels contain at least one positive timestep.
N_PATIENTS = 1

patient_embeddings = []
timestep_embeddings = []
filenames = []
targets = []

# Pure inference: skip autograd bookkeeping to save memory.
with torch.no_grad():
    for batch in val_loader:
        x, y_true, extra = prepare_batch(batch, DEVICE)

        # Keep only patients whose label row y_true['decompensation'][0, 1] contains a 1.
        if 1 not in y_true['decompensation'][0, 1]:
            continue

        preds, outputs = model(*x.values())

        # Index 0 selects the single patient in the batch (batch_size is set to 1).
        patient_embeddings.append(outputs['patient'][0].detach().cpu().numpy())
        timestep_embeddings.append(outputs['timesteps'][0].detach().cpu().numpy())
        # Stack model output (first) and ground-truth labels (second) along axis 0.
        # NOTE(review): the attribution cell applies torch.sigmoid to this same head's
        # output, so these values may be logits rather than probabilities — confirm
        # before plotting them as p_d.
        targets.append(np.concatenate([preds['decompensation'].detach().cpu().numpy(),
                                       y_true['decompensation'][:, 1].detach().cpu().numpy()],
                                      0))
        filenames.append(extra['filename'])

        if len(timestep_embeddings) >= N_PATIENTS:
            break

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Event tables whose embedding slices are plotted; assumed to match the order of the
# per-table slices inside the timestep embedding — TODO confirm against the model.
TABLES = ['CHARTEVENTS', 'LABEVENTS', 'OUTPUTEVENTS']
HOURS_PER_DAY = 24  # assumes one heatmap row per hour — TODO confirm

N_sections = len(TABLES)
# NOTE(review): assumes equal-width per-table slices; any remainder columns
# (shape[1] % N_sections) are not plotted.
d_section = timestep_embeddings[0].shape[1] // N_sections

fig, axes = plt.subplots(1, N_sections + 1, figsize=(8, 4),
                         gridspec_kw={'width_ratios': [0.1] + [1] * N_sections})

# targets[0] stacks the model output (row 0) over the labels (row 1); after .T
# these are columns. Indexing with the list [0] returns a copy, so the sign flip
# below does not modify targets[0].
death = targets[0].T[:, [0]]
# Negate the output at positive-label timesteps so they fall on the opposite side
# of the diverging 'bwr' colormap.
# NOTE(review): values outside [-1, 1] (e.g. logits) are clipped by vmin/vmax.
death[targets[0].T[:, [1]] == 1] *= -1

# One tick per completed day. Labels are derived from the tick positions so both
# lists always have equal length; range(1, len(death)//24) came up one short
# whenever len(death) was not a multiple of 24, which matplotlib rejects.
day_ticks = list(range(HOURS_PER_DAY, len(death), HOURS_PER_DAY))
day_labels = [tick // HOURS_PER_DAY for tick in day_ticks]

# rasterized=True removes the hairlines between cells in the PDF export
sns.heatmap(death, ax=axes[0], cmap='bwr', vmax=1, vmin=-1, cbar=False, rasterized=True)
axes[0].set_yticks(day_ticks)
axes[0].set_yticklabels(day_labels)
axes[0].set_xticks([])
axes[0].set_title('$p_d$')

for i, table in enumerate(TABLES):
    ax = axes[i + 1]
    sns.heatmap(timestep_embeddings[0][:, i * d_section:(i + 1) * d_section],
                ax=ax,
                cbar=False,
                rasterized=True)
    ax.set_yticks(day_ticks)
    ax.set_yticklabels([])
    ax.set_xticks([])
    ax.set_title(table)

fig.tight_layout()
fig.savefig('notebooks/decomp-timesteps.pdf')

In [None]:
from captum.attr import IntegratedGradients, TokenReferenceBase, visualization

# NOTE(review): `mimic_dataset` is not defined in any earlier cell. The loop below
# unpacks its items as `(batch, (x, y_trues, extras))`, i.e. the shape returned by
# to_batch — presumably it maps to_batch over val_loader. TODO confirm / define it.
dataiterator = mimic_dataset(val_loader)

# Prediction head whose output is attributed below.
TASK = 'decompensation'

def to_batch(x, y_trues, extras, limit=None):
    """Wrap one prepared sample as a ``Batch`` for the current TASK.

    Args:
        x: dict of model inputs; its values are used positionally (0..3).
        y_trues: dict of label tensors keyed by task name.
        extras: passed through unchanged.
        limit: optional cut-off applied as ``[:limit]`` to inputs 1-3 (None keeps all).

    Returns:
        ``(batch, (values, y_trues, extras))`` where ``values`` is ``list(x.values())``.
    """
    values = list(x.values())
    truncated = tuple(values[k][:limit] for k in (1, 2, 3))
    batch = Batch(
        inputs=(values[0],) + truncated,
        labels=(y_trues[TASK],),
        # NOTE(review): input 3 is passed as-is (untruncated, not a 1-tuple);
        # use `(values[3],)` if a tuple is expected here.
        additional_args=values[3],
    )
    return batch, (values, y_trues, extras)

def task_forward(*inputs):
    """Run the model and return only the TASK head's predictions (IG's forward function)."""
    all_preds, _insights = model.forward(*inputs)
    return all_preds[TASK]

# Integrated Gradients over the TASK head only.
ig = IntegratedGradients(forward_func=task_forward)

# (filename, create_attribution_df result) pairs collected by the loop below.
attribution_dfs = list()
def forward_with_sigmoid(inputs):
    """Call the model on an unpacked inputs tuple and return its raw output.

    NOTE(review): despite the name, no sigmoid is applied, and no visible cell
    calls this function.
    """
    return model(*inputs)

# NOTE(review): this loop relies on names that no earlier cell defines:
# `tqdm`, `input_text_transform`, `THRES`, `create_attribution_df` and
# `extract_most_attr`. Import or define them before running this cell.
most_attributions = []  # appended to below; previously never initialised

for batch, (x, y_trues, extras) in tqdm(dataiterator):
    # skip masked
    # NOTE(review): batch.labels[0][1] indexes position 1 of the label tensor's first
    # dim, while the checks below use batch.labels[0][0, 1]; confirm which is intended.
    if batch.labels[0][1] == 0.:
        continue

    model.zero_grad()

    # Keep input 0 unchanged and pass the remaining inputs through input_text_transform.
    inputs = tuple([batch.inputs[0]] + [input_text_transform(inp) for inp in batch.inputs[1:]])
    out, _insight = model(*inputs)
    pred = torch.sigmoid(out[TASK]).item()

    try:
        # No baselines are passed, so captum's default (zero) baseline is used.
        # NOTE(review): n_steps=1 is a single-step, very coarse IG approximation.
        attr = ig.attribute(inputs=inputs, n_steps=1)
    except Exception as err:
        # Skip samples IG cannot handle, but report them. A bare `except:` would also
        # swallow KeyboardInterrupt and make the loop impossible to interrupt.
        print(f"skipping {extras['filename']}: {err!r}")
        continue

    label = batch.labels[0][0, 1].item()
    if (pred > THRES) and label:
        df = create_attribution_df(attr[1:], batch.inputs[1:])
        attribution_dfs.append((extras['filename'], df))
        if len(attribution_dfs) > 5:
            break
    most_attributions.append(extract_most_attr(extras['filename'][0], batch,
                                               attr, 10, pred, label))
    if 1 in y_trues['decompensation'][0, 1]:
        break