In [None]:
import os

# Move from the notebook's folder to its parent so that relative paths and
# project-local imports used further down resolve.
# The starting directory is remembered the first time this cell runs; a bare
# os.chdir('..') would climb one more level on every re-execution of the cell.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, '..'))

In [None]:
import re
import torch
import utils
import yaml
from glob import glob

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

from dataloader.data import MIMICDataset, get_tables
from dataloader.labels import get_labels
from dataloader.utils import BinnedEvent, get_vocab
from utils import prepare_batch, load_class, load_model, load_params

In [None]:
# --- Notebook configuration -------------------------------------------------
# Root folder handed to MIMICDataset below (relative to the working directory).
data_path = 'data/multitask'
# Device string passed to get_labels, load_model and prepare_batch.
DEVICE = 'cpu'

In [None]:
# Fetch the stored parameter dict for identifier '26yyp9cl'
# (presumably an experiment/run id -- confirm against utils.load_params).
params = load_params('26yyp9cl')

# Display the two model-class entries of the configuration as a tuple.
tuple(params[key] for key in ('patient_modelcls', 'modelcls'))

In [None]:
# Notebook-local overrides applied on top of the loaded configuration.
# Both assignments write constants, so re-running this cell is harmless.
# (The former bare `params` statement at the top of this cell was a no-op:
# only the last expression of a cell is displayed.)
params['batch_size'] = 2
# Forwarded to get_vocab(**params) in the next cell.
params['vocab_file'] = 'embeddings/sentences.mimic3.hourly.random.binned.train.counts'

# Optional override, currently disabled:
# params['normalize'] = False

In [None]:
# Vocabulary built from the (overridden) run parameters.
joint_vocab = get_vocab(**params)

# Event tables requested from get_tables, all sharing the joint vocabulary.
table_names = ['CHARTEVENTS', 'LABEVENTS', 'OUTPUTEVENTS', 'dem']
tables = get_tables(table_names, load=True, event_class=BinnedEvent, vocab=joint_vocab)

labels = get_labels(DEVICE)

# Keyword options for the training split; gathered in one dict so the
# constructor call below stays short.
train_set_options = dict(
    datalist_file='train_listfile.csv',
    mode='TRAIN',
    tables=tables,
    labels=labels,
    limit=None,
    use_cache=False,
    numericalize=True,
)
train_set = MIMICDataset(data_path, 'train', **train_set_options)

In [None]:
# Build the model described by `params` on DEVICE.
model = load_model(params, joint_vocab, tables, DEVICE)

# This notebook only inspects activations, so switch to evaluation mode to keep
# train-time layers (dropout, batch-norm statistics updates) from perturbing the
# plots. NOTE(review): assumes load_model returns a torch.nn.Module -- confirm;
# this is a no-op if load_model already calls eval().
# The trailing semicolon suppresses the (large) module repr in the cell output.
model.eval();

In [None]:
from functools import partial
from samplers import AgeSubjectRandomSampler

# Sampler deciding the order in which items of train_set are drawn.
sampler = AgeSubjectRandomSampler(train_set)

# Collate function: utils.min_batch with the table/label definitions bound.
# NOTE(review): limit=720 is presumably a cap on the sequence length -- confirm
# against utils.min_batch.
collate = partial(utils.min_batch, tables=tables, labels=labels, limit=720)

train_loader = torch.utils.data.DataLoader(
    train_set,
    # Single source of truth: the value set on `params` earlier in the notebook
    # (2), instead of repeating the literal here.
    batch_size=params['batch_size'],
    collate_fn=collate,
    sampler=sampler,
    num_workers=0,
    # Pinned host memory only helps host-to-GPU copies; on a CPU-only run it is
    # useless and makes PyTorch emit a warning.
    pin_memory=torch.device(DEVICE).type == 'cuda',
    drop_last=True,
)

In [None]:
# Grab the first batch from the loader.
# next(iter(...)) raises StopIteration right here if the loader is empty; the
# previous `for ...: break` form would silently leave `sample` undefined (or
# stale from an earlier execution of this cell).
sample = next(iter(train_loader))

In [None]:
# Split the collated batch into model inputs, targets and extras on DEVICE.
x, y, extra = prepare_batch(sample, DEVICE)

# Inference only: every consumer below detaches the results, so skip building
# the autograd graph for the forward pass.
with torch.no_grad():
    predictions, outputs = model(*x)

patient, timesteps = outputs['patient'], outputs['timesteps']

patient_timesteps = patient  # N, L, C

# Which step-predictor pair to inspect; used below both as the offset between
# the two sequences and (minus one) as the index into the predictor lists.
prediction_step = 1

In [None]:
# The predictor lists are indexed from zero while prediction_step counts from one.
predictor_idx = prediction_step - 1

# Offset the two sequences against each other along dim 1: drop the last
# `prediction_step` entries on the patient side and the first `prediction_step`
# entries on the timestep side.
_patient_timesteps = patient_timesteps[:, :-prediction_step]
_timesteps = timesteps[:, prediction_step:]

# Project each side with its step predictor; results are detached from autograd.
pat_predictor = model.step_predictors_pat[predictor_idx]
pat_prediction = pat_predictor(_patient_timesteps).detach()

ts_predictor = model.step_predictors_ts[predictor_idx]
ts_prediction = ts_predictor(_timesteps).detach()

In [None]:
# Shapes of the patient-side and timestep-side projections, in that order.
tuple(projection.shape for projection in (pat_prediction, ts_prediction))

In [None]:
# One heatmap per entry along the first dimension of _timesteps; each entry is
# transposed before plotting.
for sample_timesteps in _timesteps:
    fig, ax = plt.subplots(figsize=(10, 5), dpi=150)
    sns.heatmap(sample_timesteps.detach().T, cmap='Greys', ax=ax)
    plt.show()

In [None]:
# Position along dim 1 to inspect; -1 selects the last one.
step = -1

# Patient-side projection of the first batch element at `step`. Indexing with
# the list [step] keeps that dimension, so the result stays 2-D for heatmap().
_prediction = pat_prediction[0, [step]]

fig, ax = plt.subplots(figsize=(15, 1))
sns.heatmap(_prediction, cmap='Greys', ax=ax)

In [None]:
# Display the shape of the 'contrastive' entry of the model's predictions.
# The trailing comment is an index expression left over from experimentation
# (it is not executed).
predictions['contrastive'].shape#[:, step, prediction_step]

In [None]:
# Compare the patient-side projection (`_prediction`, batch element 0) with the
# timestep-side projection of the same element (positive) and of another batch
# element (negative), all taken at position `step`.
if ts_prediction.size(0) < 2:
    # np.random.randint(1, 1) would fail with an opaque "low >= high" message.
    raise ValueError("need at least two samples in the batch to pick a negative")

pos_ts = ts_prediction[0, [step]]
neg_index = np.random.randint(1, ts_prediction.size(0))  # any element except 0
neg_ts = ts_prediction[neg_index, [step]]

# Rows of the heatmap, top to bottom:
#   0: pos * prediction   1: pos   2: prediction   3: neg   4: neg * prediction
# Stacked with torch.cat and converted explicitly so this also works when the
# tensors do not live on the CPU.
match_rows = torch.cat(
    [pos_ts * _prediction, pos_ts, _prediction, neg_ts, neg_ts * _prediction], 0
).cpu().numpy()

plt.figure(figsize=(5, 3), dpi=150)
sns.heatmap(match_rows, cmap='Greys')

# Dot-product similarity of each candidate with the prediction.
pos_sim = (pos_ts @ _prediction.T).item()
neg_sim = (neg_ts @ _prediction.T).item()

# Label the elementwise-product rows (row 0 and row 4) with their similarity.
# The anchor is the horizontal middle of the heatmap, at the row's vertical
# centre, nudged 30 points to the right.
for label, similarity, row_center in (("pos", pos_sim, .5), ("neg", neg_sim, 4.5)):
    plt.annotate(f"{label} sim={similarity:.4f}", xy=(pos_ts.size(1) * .5, row_center),
                 xytext=(30, 0), textcoords="offset points",
                 va="center", ha="left",
                 bbox=dict(boxstyle="round4", fc="w"))

plt.show()

In [None]:
# Matrix product of the positive row with the transposed negative row.
torch.matmul(pos_ts, neg_ts.T)

In [None]:
# Cosine similarity between the positive timestep row and the prediction
# (dim=1 is the function's default, spelled out here).
torch.nn.functional.cosine_similarity(pos_ts, _prediction, dim=1)

In [None]:
# Cosine similarity between the negative timestep row and the prediction
# (dim=1 is the function's default, spelled out here).
torch.nn.functional.cosine_similarity(neg_ts, _prediction, dim=1)

In [None]:
# Dot-product similarity with the prediction as plain floats: (positive, negative).
tuple((candidate @ _prediction.T).item() for candidate in (pos_ts, neg_ts))