In [1]:
# Automatically reload edited Python modules (e.g. `utils`, `dataloader`) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Work from the parent of the directory the notebook was started in, so that
# relative paths used later ('wandb/...', 'notebooks/figures/...') resolve there.
# The starting directory is remembered on the first run, so re-running this
# cell does not keep climbing one level per execution.
if 'NOTEBOOK_START_DIR' not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))

In [3]:
import argparse
import logging
import os
from collections import OrderedDict
from glob import glob

import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns
import plotly.express as px
import numpy as np
import torch
import gzip

import utils

In [4]:
import re
import torch
from tqdm import tqdm
import utils
import yaml
from glob import glob
from functools import partial

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

from dataloader.data import MIMICDataset, get_tables
from dataloader.labels import get_labels
from dataloader.utils import BinnedEvent, get_vocab
from utils import prepare_batch, load_class, load_model, load_config

In [6]:
# Run configuration.
DEVICE = 'cuda:0'    # torch device for the model and the batches
RUN_ID = '2xdwyub7'  # wandb run id: selects the saved config and the run's output files
# Root of the benchmark listfiles. Set MIMIC3_LISTFILE_ROOT instead of editing
# this machine-specific default path.
LISTFILE_ROOT = os.environ.get('MIMIC3_LISTFILE_ROOT', '/home/ashankar/mimic3_data/data/')
LISTFILE = 'test_listfile'  # listfile whose saved predictions are analysed at the end
THRES = 0.92                # decision threshold for in-hospital-mortality probabilities

# NOTE(review): load_config emits a warning pointing at `c = yaml.load(f)`,
# presumably yaml.load without an explicit Loader; yaml.safe_load in utils
# would silence it and avoid executing arbitrary YAML tags.
params = load_config(RUN_ID)
# Show which patient-encoder / model classes the run was trained with.
params['patient_modelcls'], params['modelcls']

  c = yaml.load(f)


('models.PatientChannelRNNMaxPoolEncoder', 'models.MultitaskFinetune')

In [7]:
# Vocabulary, event tables and task labels for this run's configuration.
joint_vocab = get_vocab(**params)
tables = get_tables(load=True, event_class=BinnedEvent, vocab=joint_vocab, **params)
labels = get_labels(DEVICE)

# Both splits share every option except the listfile and `limit`
# (presumably a cap on the number of stays loaded; None means no cap).
_shared_dataset_kwargs = dict(mode='EVAL', tables=tables, labels=labels, numericalize=True)
train_set = MIMICDataset(datalist_file='train_listfile.csv', limit=None, **_shared_dataset_kwargs)
val_set = MIMICDataset(datalist_file='val_listfile.csv', limit=128, **_shared_dataset_kwargs)

In [8]:
def _collate(limit):
    """Return the padding collate function (utils.pad_batch) bound to this run's tables/labels."""
    return partial(utils.pad_batch, tables=tables, labels=labels, limit=limit)


# Training batches: shuffled, padded with limit=48, incomplete last batch dropped.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=params['batch_size'],
    collate_fn=_collate(48),
    shuffle=True,
    num_workers=0,
    pin_memory=False,
    drop_last=True,
)

# Validation batches: fixed order, padded with limit=None, loaded by 4 workers.
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=params['batch_size'],
    collate_fn=_collate(None),
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [9]:
# Build the model for this run via utils.load_model from its config, vocabulary and tables, targeting DEVICE.
model = load_model(params, joint_vocab, tables, DEVICE)

# Collect patient embeddings

In [10]:
# Sanity check: which device the event encoder's weight tensor ended up on.
model.timestep_encoder.event_encoder.encoder.weight.device

device(type='cuda', index=0)

In [11]:
# Run the model over the validation loader. For every stay whose
# in-hospital-mortality label is not masked out, keep the patient
# representation plus per-task labels and predictions as numpy arrays.
tok_acts = []
pheno_labels = []
ihm_labels = []
los_labels = []
decomp_labels = []
pheno_preds = []
decomp_preds = []
ihm_preds = []
los_preds = []
patients = []


def _to_numpy(tensor):
    """Detach `tensor` from the autograd graph and return it as a host numpy array."""
    return tensor.detach().cpu().numpy()


model.eval()  # set once; nothing in the loop switches back to train mode
with torch.no_grad():
    for batch in tqdm(val_loader):
        x, y_trues, extras = prepare_batch(batch, DEVICE)
        # Skip stays whose IHM mask (label channel 0, as used by the selection
        # cells below) is 0. Only element [0, 0] is checked, so this assumes
        # batch_size == 1 (128 iterations over a limit=128 val set suggest it is).
        if y_trues['in_hospital_mortality'][0, 0] == 0.:
            continue

        y_preds, extra = model(*x.values())
        patients.append(_to_numpy(extra['patient']))

        pheno_labels.append(_to_numpy(y_trues['phenotyping']))
        ihm_labels.append(_to_numpy(y_trues['in_hospital_mortality']))
        los_labels.append(_to_numpy(y_trues['length_of_stay_classification']))
        decomp_labels.append(_to_numpy(y_trues['decompensation']))

        # Predictions come from y_preds. Taking them from y_trues made the
        # decompensation "thresholded prediction == label" filter trivially true.
        # NOTE(review): later cells index decomp_preds like the label array
        # (decomp_preds[0, 1]); confirm the prediction tensor's layout matches.
        decomp_preds.append(_to_numpy(y_preds['decompensation']))
        pheno_preds.append(_to_numpy(y_preds['phenotyping']))
        ihm_preds.append(_to_numpy(y_preds['in_hospital_mortality']))
        los_preds.append(_to_numpy(y_preds['length_of_stay_classification']))


128it [00:10, 12.02it/s]


In [None]:
# Concatenate the per-batch arrays collected by the inference loop.
# Index 47 is the 48th step: LoS labels (channel 1) and the argmax LoS class
# are read there. Stays with 47 or fewer steps are dropped for LoS.
# The guard makes re-running this cell a no-op. Without it, the names would
# already hold concatenated arrays and would be re-concatenated into garbage.
if isinstance(ihm_labels, list):
    ihm_labels = np.concatenate(ihm_labels, 0)
    los_labels = np.concatenate([l[:, 1, 47] for l in los_labels if l.shape[2] > 47], 0)
    decomp_labels = np.concatenate(decomp_labels, -1)
    decomp_preds = np.concatenate(decomp_preds, -1)
    ihm_preds = np.concatenate(ihm_preds, 0)
    los_preds = np.concatenate([l[:, 47].argmax(-1) for l in los_preds if l.shape[1] > 47], 0)

In [None]:
# Shape of the concatenated decompensation label array.
decomp_labels.shape

In [None]:
# Shape of the concatenated decompensation prediction array.
decomp_preds.shape

In [None]:
# Join every stay's patient tensor along axis 1, then keep batch row 0.
pats = np.concatenate(patients, axis=1)[0]

In [None]:
# Each stay's representation at index 47 of axis 1 (presumably the 48th time step), stacked over stays.
ihm = np.concatenate([patient[:, 47] for patient in patients], axis=0)

In [None]:
# Shape of the decompensation selection mask: positions where label channel 0
# is 1 and label channel 1 equals the prediction thresholded at 0.9973.
np.shape(
    (decomp_labels[0, 0] == 1.)
    & (decomp_labels[0, 1] == (decomp_preds[0, 1] > 0.9973))
)

In [None]:
# Shape of the stacked patient representations.
pats.shape

In [None]:
# Select representations of examples each task handles correctly. Label
# channel 0 is used as a mask (== 1) and channel 1 as the target. The IHM cut-off
# comes from the config constant THRES rather than a copied literal.
DECOMP_THRES = 0.9973  # decision threshold for decompensation predictions
ihm_pos = ihm[(ihm_labels[:, 0] == 1.) & (ihm_labels[:, 1] == (ihm_preds[:, 0] > THRES))]
decomp_pos = pats[(decomp_labels[0, 0] == 1.)
                  & (decomp_labels[0, 1] == (decomp_preds[0, 1] > DECOMP_THRES))]
# LoS: stays whose argmax class at step 47 equals the label.
los_pos = ihm[los_labels == los_preds]

In [None]:
# Mean representation over the selected IHM rows.
acts = ihm_pos.mean(axis=0)

In [None]:
# Average the selected representations per task.
# NOTE(review): the indexing below (p[0, 47], p[0, 4:]) assumes each element
# of ihm_pos / decomp_pos / los_pos is at least 2-D; check it against the
# shapes displayed above before relying on these numbers.
ihm_acts = np.mean([p[0, 47] for p in ihm_pos], 0)
decomp_acts = np.mean([p[0, 4:].mean(0) for p in decomp_pos], 0)
# Kept separate from `los_pos`, so re-running this cell still sees the selected rows.
los_acts = np.mean([p[0, 4:].mean(0) for p in los_pos], 0)
# Average of each stay's last step over all collected stays.
phen_acts = np.mean([p[0, -1] for p in patients], 0)

In [None]:
# Shapes of the per-task average representations.
tuple(np.shape(task_acts) for task_acts in (ihm_acts, decomp_acts, phen_acts))

In [None]:
def sum_pos(x):
    """Return the sum of the strictly positive entries of `x` (0 when there are none)."""
    positives = x[x > 0.]
    return positives.sum()


def sum_neg(x):
    """Return the sum of the strictly negative entries of `x` (0 when there are none)."""
    negatives = x[x < 0.]
    return negatives.sum()

In [None]:
# Pooling description for each of the 450 dimensions of one event table
# (time pooling outermost, then event pooling, 50 dimensions each), plus 20 'dem' entries.
_pool_ops = ['max-pool', 'avg-pool', 'sum-pool']
table_pooling = [
    f'{time_op} over time, {event_op} over events'
    for time_op in _pool_ops
    for event_op in _pool_ops
    for _ in range(50)
]
table_dem = ['dem'] * 20

In [None]:
def get_source_dist(acts, y_axis='Mean activation'):
    """Group a per-dimension vector by source table and pooling type, summing + and - parts.

    `acts` is read as consecutive blocks: CHARTEVENTS, LABEVENTS, OUTPUTEVENTS,
    PRESCRIPTIONS and INPUTEVENTS (450 entries each) followed by 'dem'
    (20 entries), i.e. entries 0-2249 are the tables and 2250-2269 are 'dem'.
    Each 450-entry block is labelled as three runs of 150 (max, avg, sum time
    pooling), and each run is split into three runs of 50 (max, avg, sum
    event pooling).

    Args:
        acts: 1-D array of values, e.g. mean activations or activation * weight.
        y_axis: name of the value column in the result.

    Returns:
        DataFrame with columns 'time-pool', 'Event pooling', 'Source' and
        `y_axis`. Each (time pooling, event pooling, table) group has one row
        holding the sum of its positive values (event label such as '$max^+$')
        and one holding the sum of its negative values ('$max^-$'). The 'dem'
        block yields one positive and one negative row whose pooling columns
        are 'dem'.

    Raises:
        ValueError: if `acts` has fewer entries than the blocks require.
    """
    pool_names = ['max', 'avg', 'sum']
    # (source table, width). Offsets are the running sum of the widths, so
    # blocks never overlap. A hard-coded 1820 start for 'dem' used to overlap
    # the INPUTEVENTS block (1800-2249).
    blocks = [('CHARTEVENTS', 450), ('LABEVENTS', 450), ('OUTPUTEVENTS', 450),
              ('PRESCRIPTIONS', 450), ('INPUTEVENTS', 450), ('dem', 20)]
    total = sum(width for _, width in blocks)
    if len(acts) < total:
        raise ValueError(f'expected at least {total} values, got {len(acts)}')

    dfs = []
    start = 0
    for tabl, width in blocks:
        df = pd.DataFrame(acts[start:start + width], columns=[y_axis])
        start += width
        df['Source'] = tabl
        if tabl != 'dem':
            df['time-pool'] = np.repeat(pool_names, 150).tolist()
            df['Event pooling'] = np.tile(np.repeat(pool_names, 50), 3).tolist()
            keys = ['time-pool', 'Event pooling', 'Source']
            df_pos = df.groupby(keys)[y_axis].agg(sum_pos).reset_index()
            df_neg = df.groupby(keys)[y_axis].agg(sum_neg).reset_index()
            # groupby sorts its keys (avg, max, sum). Build each signed label
            # from the row's own key; assigning labels by position swapped max/avg.
            df_pos['Event pooling'] = '$' + df_pos['Event pooling'] + '^+$'
            df_neg['Event pooling'] = '$' + df_neg['Event pooling'] + '^-$'
        else:
            df['time-pool'] = 'dem'
            df_pos = df.groupby(['Source'])[y_axis].agg(sum_pos).reset_index()
            df_neg = df.groupby(['Source'])[y_axis].agg(sum_neg).reset_index()
            for part in (df_pos, df_neg):
                part['time-pool'] = 'dem'
                part['Event pooling'] = 'dem'

        # `axis` must be passed by keyword to pd.concat in pandas >= 2.0.
        dfs.append(pd.concat([df_pos, df_neg], axis=0))

    return pd.concat(dfs, axis=0, ignore_index=True)

import plotly.graph_objects as go

def plot_source_dist(df, y_axis='Mean activation'):
    """Draw a `get_source_dist` frame as a faceted, relative (stacked +/-) plotly bar chart.

    One facet column per 'Source', bars per 'time-pool', coloured by the
    signed 'Event pooling' label. The legend sits horizontally below the plot.
    Facet 6 (the 'dem' facet for frames built by get_source_dist) gets an
    independent y axis and no x tick labels.

    Args:
        df: frame with 'time-pool', 'Event pooling', 'Source' and `y_axis` columns.
        y_axis: name of the value column to plot.

    Returns:
        The plotly figure (not shown).
    """
    fig = px.bar(df, x='time-pool', y=y_axis, color='Event pooling',
                 width=800, height=400, barmode='relative', facet_col='Source')

    def _strip_facet_prefix(annotation):
        # Facet titles default to "Source=<name>"; keep only "<name>".
        annotation.update(text=annotation.text.split("=")[-1])

    fig.for_each_annotation(_strip_facet_prefix)
    fig.update_layout(legend=dict(orientation='h', yanchor='top', xanchor='center', x=0.5, y=-0.3))

    fig.update_yaxes(col=1, title_text='Mean contribution')
    fig.update_yaxes(overwrite=True, matches=None, row=1, col=6)

    # x-axis updates are applied in sequence; later calls refine earlier ones.
    fig.update_xaxes(overwrite=True, matches=None)
    fig.update_xaxes(overwrite=True, nticks=3, dtick=1)
    fig.update_xaxes(overwrite=True, showgrid=True, title_text='', showticklabels=True)
    fig.update_xaxes(overwrite=True, showticklabels=False, row=1, col=6)
    fig.update_xaxes(col=3, title_text='Time pooling')
    return fig

In [None]:
# Unweighted mean decompensation activations, grouped by source table and pooling type.
df = get_source_dist(decomp_acts, 'Mean activation')
df.to_csv('exp_act.csv')
fig = plot_source_dist(df).update_layout(title_text='Patient activation')
fig.show()

In [None]:
# In-hospital mortality: per-dimension contribution = IHM mean activation times
# row 0 of the IHM decision layer's weight, grouped by source table.
# This variant keeps the legend; the following cell renders the same plot without one.
weights = model.predictor.decision_mlps['in_hospital_mortality'].weight.detach().cpu().numpy()[0]

df = get_source_dist(ihm_acts * weights)
df.to_csv('exp_ihm.csv')
fig = plot_source_dist(df, 'Mean activation')
fig.update_layout(title_text='In-Hospital Mortality', showlegend=True)
fig.update_yaxes(range=[-5, 5], overwrite=True)
os.makedirs('notebooks/figures', exist_ok=True)  # write_image does not create missing directories
fig.write_image('notebooks/figures/ihm-legend.pdf')
fig

In [None]:
# In-hospital mortality contribution plot (IHM mean activation times row 0 of
# the IHM decision layer's weight), rendered without a legend.
weights = model.predictor.decision_mlps['in_hospital_mortality'].weight.detach().cpu().numpy()[0]
df = get_source_dist(ihm_acts * weights)
df.to_csv('exp_ihm.csv')
fig = plot_source_dist(df, 'Mean activation')
fig.update_layout(title_text='In-Hospital Mortality', showlegend=False)
fig.update_yaxes(range=[-5, 5], overwrite=True)
os.makedirs('notebooks/figures', exist_ok=True)  # write_image does not create missing directories
fig.write_image('notebooks/figures/ihm.pdf')
fig

In [None]:
# Decompensation: per-dimension contribution = decompensation mean activation
# times row 0 of the decompensation decision layer's weight, grouped by source table.
weights = model.predictor.decision_mlps['decompensation'].weight.detach().cpu().numpy()[0]

df = get_source_dist(decomp_acts * weights, 'Mean activation')
fig = plot_source_dist(df, 'Mean activation')
fig.update_layout(title_text='Decompensation', showlegend=False)
fig.update_yaxes(range=[-5, 5], overwrite=True)
os.makedirs('notebooks/figures', exist_ok=True)  # write_image does not create missing directories
fig.write_image('notebooks/figures/decomp.pdf')
fig

In [None]:
# Length-of-stay regression: per-dimension contribution = mean activation times
# row 0 of the LoS regression head's weight, grouped by source table.
# NOTE(review): this uses decomp_acts (the decompensation selection), not an
# LoS-specific activation average. Confirm that is intended.
weights = model.predictor.decision_mlps['length_of_stay_regression'].weight.detach().cpu().numpy()[0]
df = get_source_dist(decomp_acts * weights, 'Mean activation')
df.to_csv('exp_los.csv')
fig = plot_source_dist(df, 'Mean activation')
fig.update_layout(title_text='Length-of-stay regression', showlegend=False)
fig.update_yaxes(range=[-5, 5], overwrite=True)
os.makedirs('notebooks/figures', exist_ok=True)  # write_image does not create missing directories
fig.write_image('notebooks/figures/los.pdf')
fig

In [None]:
# Phenotyping, class `p_index`: per-dimension contribution = mean activation
# times that class's row of the phenotyping head's weight.
# NOTE(review): uses decomp_acts although phen_acts (average of each stay's
# last step) is computed earlier. Confirm which one is intended here.
p_index = 8
weights = model.predictor.decision_mlps['phenotyping'].weight.detach().cpu().numpy()[p_index]
df = get_source_dist(decomp_acts * weights, 'Mean activation')
fig = plot_source_dist(df, 'Mean activation')
fig.update_layout(title_text=f'Phenotyping: {labels["phenotyping"].classes[p_index]}',
                  showlegend=False)
fig.update_yaxes(range=[-.11, .11], overwrite=True)
os.makedirs('notebooks/figures', exist_ok=True)  # write_image does not create missing directories
fig.write_image(f'notebooks/figures/pheno{p_index}.pdf')
fig

In [None]:
# Phenotyping, class `p_index`: per-dimension contribution = mean activation
# times that class's row of the phenotyping head's weight.
# NOTE(review): uses decomp_acts although phen_acts (average of each stay's
# last step) is computed earlier. Confirm which one is intended here.
p_index = 0
weights = model.predictor.decision_mlps['phenotyping'].weight.detach().cpu().numpy()[p_index]
df = get_source_dist(decomp_acts * weights, 'Mean activation')
fig = plot_source_dist(df, 'Mean activation')
fig.update_layout(title_text=f'Phenotyping: {labels["phenotyping"].classes[p_index]}',
                  showlegend=False)
fig.update_yaxes(range=[-.11, .11], overwrite=True)
os.makedirs('notebooks/figures', exist_ok=True)  # write_image does not create missing directories
fig.write_image(f'notebooks/figures/pheno{p_index}.pdf')
fig

# Save activated features

In [None]:
# Gzipped TSV activation dumps written by this run, sorted so the load order is
# reproducible (glob order is filesystem-dependent).
act_paths = sorted(glob(f"wandb/run-*{RUN_ID}/insight*/*.tsv.gz"))
assert act_paths, f'no activation dumps found under wandb/run-*{RUN_ID}/insight*/'
act_paths

In [None]:
import pandas as pd
import re

# Load every activation dump into one frame. Each row is tagged with the table
# name parsed from its file path; paths the pattern cannot parse raise an
# explicit error instead of a TypeError on `None[1]`.
dfs = []
for path in act_paths:
    m = re.match('.*/(.*?[A-Z]+).*_activations.*', path)
    if m is None:
        raise ValueError(f'cannot parse table name from activation file path: {path}')
    with gzip.open(path, 'rt') as f:
        df = pd.read_csv(f, sep='\t')
    df['table'] = m[1]
    dfs.append(df)

if not dfs:
    raise ValueError('no activation files to load (act_paths is empty)')
df = pd.concat(dfs)

In [None]:
# Inputs for the saved-prediction analysis, kept in an argparse-like namespace.
# Candidates are sorted so the choice is deterministic when several files
# match; the lexicographically last one is used.
_ihm_prediction_files = sorted(
    glob(f'wandb/*{RUN_ID}/{LISTFILE}_predictions/in_hospital_mortality*.csv'))
if not _ihm_prediction_files:
    raise FileNotFoundError(
        f'no in_hospital_mortality prediction CSVs for run {RUN_ID} and {LISTFILE}')


class args:
    prediction = _ihm_prediction_files[-1]
    test_listfile = f'{LISTFILE_ROOT}/in-hospital-mortality/{LISTFILE}.csv'


pred_df = pd.read_csv(args.prediction, index_col=False)

In [None]:
# Drop padding tokens, then attach the saved IHM predictions per stay. The
# right join keeps every stay present in pred_df.
df = (
    df[~df['token'].isin(['<pad>'])]
    .merge(pred_df, on='stay', how='right')
)
assert not df['y_true'].isna().any()

In [None]:
# Keep only rows whose thresholded IHM prediction agrees with the label, and
# print the row count before and after filtering.
# NOTE(review): the threshold here is 0.95, while THRES in the config cell is 0.92.
# Confirm which one is intended.
n_rows_before = len(df)
df = df[df['y_true'] == (df['prediction'] > 0.95)]
print(n_rows_before)
print(len(df))

In [None]:
def _feature_prefix(token):
    """Return the part of `token` before its first '_' ('' if it contains none)."""
    return token.split('_', 1)[0] if '_' in token else ''


# `df` is a filtered slice of an earlier frame. Build a new frame with assign()
# instead of setting a column on the slice (SettingWithCopyWarning). Missing
# tokens, possible after the right join on pred_df, stay NaN instead of
# raising TypeError.
df = df.assign(feature=df['token'].map(_feature_prefix, na_action='ignore'))

In [None]:
# Activation statistics per (feature, token, dimension, table) group.
sorted_df = (
    df.groupby(['feature', 'token', 'dim', 'table'])[['activation']]
    .agg(['mean', 'sum', 'count', 'std'])
)

In [None]:
# Frames with up to 100 rows are shown in full; longer ones are truncated to 50 rows.
pd.set_option('display.max_rows', 100)
pd.set_option('display.min_rows', 50)

In [None]:
# Order groups by total activation (largest first) and move the group keys
# back into columns. reset_index runs only while the keys are still in the
# (Multi)index, so re-running this cell does not add a stray 'index' column.
sorted_df = sorted_df.sort_values(('activation', 'sum'), ascending=False)
if isinstance(sorted_df.index, pd.MultiIndex):
    sorted_df = sorted_df.reset_index()

In [None]:
# Top 25 groups by total activation, indexed and ordered by table, dimension and token.
sorted_df.head(25).set_index(['table', 'dim', 'token']).sort_index()

In [None]:
# Activation rows for the tokens of the 10 highest-ranked groups.
top10_tokens = sorted_df['token'].iloc[:10]
filtered_df = df[df['token'].isin(top10_tokens)]

In [None]:
#px.box(filtered_df, y='activation', x='token', color='table', width=1000, height=400)

In [None]:
# Distinct tokens of groups seen more than once with total activation above 1.
is_repeated = sorted_df[('activation', 'count')] > 1
has_large_sum = sorted_df[('activation', 'sum')] > 1
tokens = sorted_df[is_repeated & has_large_sum].reset_index()['token'].unique()
len(tokens)

In [None]:
# from datasets.utils import feature_string
# import pandas as pd

# with open('our_features_train', 'wt') as f:
#     for feature in tokens:
#         f.write(feature_string(feature) + '\n')