In [1]:
import os
import sys
import pandas as pd
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

In [2]:
DATA_PATH = "./Data/mimic/"
pids = pickle.load(open(os.path.join(DATA_PATH, 'pids.pkl'), 'rb'))
vids = pickle.load(open(os.path.join(DATA_PATH, 'vids.pkl'), 'rb'))
seqs = pickle.load(open(os.path.join(DATA_PATH, 'seqs.pkl'), 'rb'))
types = pickle.load(open(os.path.join(DATA_PATH, 'types.pkl'), 'rb'))
rtypes = pickle.load(open(os.path.join(DATA_PATH, 'rtypes.pkl'), 'rb'))

In [3]:
rtypes_list = pd.Series(rtypes)
diag_codes = rtypes_list[rtypes_list.str.startswith('DIAG_')]
diag_codes.reset_index(inplace=True, drop=True)
drug_codes = rtypes_list[rtypes_list.str.startswith('DRUG_')]
drug_codes.reset_index(inplace=True, drop=True)

In [4]:
diag_codes.head()

0    DIAG_423
1    DIAG_511
2    DIAG_785
3    DIAG_458
4    DIAG_311
dtype: object

In [5]:
drugs = pd.Series(list(range(len(drug_codes))), index=drug_codes.values).to_dict()
diags = pd.Series(list(range(len(diag_codes))), index=diag_codes.values).to_dict()

In [6]:
def eventlabel2eventid(seqs):
    new_seqs, diag_seqs, drug_seqs = [], [], []
    for patient in seqs:
        events, diag_visits, drug_visits = [], [], []
        for visit in patient:
            diag_ids, drug_ids = [], []
            for event_label in visit:
                event_id = rtypes[event_label]
                if event_id in diags:
                    diag_ids.append(diags[event_id])
                else:
                    drug_ids.append(drugs[event_id])
            diag_visits.append(diag_ids)
            drug_visits.append(drug_ids)
            vis_events = [diag_ids, drug_ids]
            events.append(vis_events)
        diag_seqs.append(diag_visits)
        drug_seqs.append(drug_visits)
        new_seqs.append(events)
    return new_seqs, diag_seqs, drug_seqs


new_seqs, diag_seqs, drug_seqs = eventlabel2eventid(seqs)
len(drug_seqs[0])

2

In [7]:
C_d = len(diags)
C_t = len(drugs)

In [8]:
C_d

855

In [9]:
C_t

424

In [10]:
len(new_seqs[9])

2

In [11]:
class Embeddings(nn.Module):
    def __init__(self, code_dim, emb_dim=256):
        super().__init__()
        self.linear = nn.Linear(code_dim, emb_dim)
        self.relu = nn.ReLU()

    def forward(self, visit):
        return self.relu(self.linear(visit))

In [12]:
class BRNN(torch.nn.Module):
    def __init__(self, emb_dim, hid_dim=128):
        super().__init__()
        self.hid_dim = hid_dim
        self.brnn = nn.GRU(emb_dim, hid_dim, bidirectional=True, batch_first=True)

    def forward(self, emb_visits):
        output, _ = self.brnn(emb_visits)
        result = output[:, :, : self.hid_dim] + output[:, :, self.hid_dim:]
        return result

In [13]:
class Attention(nn.Module):
    def __init__(self, hid_dim=128):
        super().__init__()
        self.att = nn.Linear(hid_dim, 1)

    def forward(self, com):
        return F.softmax(self.att(com), dim=1)

In [14]:
class COAM(nn.Module):
    def __init__(self, diag_dim=C_d, drug_dim=C_t, emb_dim=256, hid_dim=128):
        super().__init__()
        self.emb_d = Embeddings(diag_dim, emb_dim)
        self.emb_t = Embeddings(drug_dim, emb_dim)
        self.brnn_d = BRNN(emb_dim, hid_dim)
        self.brnn_t = BRNN(emb_dim, hid_dim)
        self.com = nn.Linear(2 * hid_dim, hid_dim)
        self.att_a = Attention(hid_dim)
        self.att_b = Attention(hid_dim)
        self.p = nn.Linear(2 * hid_dim, hid_dim)
        self.output = nn.Linear(hid_dim, diag_dim + drug_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, diag_vis, drug_vis):
        d_visits, t_visits = [], []
        for i in range(len(diag_vis)):
            d_emb = self.emb_d(diag_vis[i])
            t_emb = self.emb_t(drug_vis[i])
            d_visits.append(d_emb)
            t_visits.append(t_emb)
        h = self.brnn_d(torch.cat(d_visits).unsqueeze(dim=0))
        g = self.brnn_t(torch.cat(t_visits).unsqueeze(dim=0))
        com = self.com(torch.cat((h, g), 2))
        alpha = self.att_a(com)
        beta = self.att_b(com)
        h_tilde = torch.mul(beta, h).sum(dim=1)
        g_tilde = torch.mul(alpha, g).sum(dim=1)
        tilde_cat = torch.cat((h_tilde, g_tilde), 1)
        p = self.p(tilde_cat)
        result = self.output(p)
        result = self.sigmoid(result)
        return result

In [15]:
class COAMa(nn.Module):
    def __init__(self, diag_dim=C_d, drug_dim=C_t, emb_dim=256, hid_dim=128):
        super().__init__()
        self.emb_d = Embeddings(diag_dim, emb_dim)
        self.emb_t = Embeddings(drug_dim, emb_dim)
        self.brnn_d = BRNN(emb_dim, hid_dim)
        self.brnn_t = BRNN(emb_dim, hid_dim)
        self.com = nn.Linear(2 * hid_dim, hid_dim)
        self.att_a = Attention(hid_dim)
        self.att_b = Attention(hid_dim)
        self.p = nn.Linear(2 * hid_dim, hid_dim)
        self.output = nn.Linear(hid_dim, diag_dim + drug_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, diag_vis, drug_vis):
        d_visits, t_visits = [], []
        for i in range(len(diag_vis)):
            d_emb = self.emb_d(diag_vis[i])
            t_emb = self.emb_t(drug_vis[i])
            d_visits.append(d_emb)
            t_visits.append(t_emb)
        h = self.brnn_d(torch.cat(d_visits).unsqueeze(dim=0))
        g = self.brnn_t(torch.cat(t_visits).unsqueeze(dim=0))
        alpha = self.att_a(h)
        beta = self.att_b(g)
        h_tilde = torch.mul(beta, h).sum(dim=1)
        g_tilde = torch.mul(alpha, g).sum(dim=1)
        tilde_cat = torch.cat((h_tilde, g_tilde), 1)
        p = self.p(tilde_cat)
        result = self.output(p)
        result = self.sigmoid(result)
        return result

In [16]:
class COAMb(nn.Module):
    def __init__(self, diag_dim=C_d, drug_dim=C_t, emb_dim=256, hid_dim=128):
        super().__init__()
        self.emb_d = Embeddings(diag_dim, emb_dim)
        self.emb_t = Embeddings(drug_dim, emb_dim)
        self.brnn_d = BRNN(emb_dim, hid_dim)
        self.brnn_t = BRNN(emb_dim, hid_dim)
        self.com = nn.Linear(2 * hid_dim, hid_dim)
        self.att = Attention(hid_dim)
        self.p = nn.Linear(2 * hid_dim, hid_dim)
        self.output = nn.Linear(hid_dim, diag_dim + drug_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, diag_vis, drug_vis):
        d_visits, t_visits = [], []
        for i in range(len(diag_vis)):
            d_emb = self.emb_d(diag_vis[i])
            t_emb = self.emb_t(drug_vis[i])
            d_visits.append(d_emb)
            t_visits.append(t_emb)
        h = self.brnn_d(torch.cat(d_visits).unsqueeze(dim=0))
        g = self.brnn_t(torch.cat(t_visits).unsqueeze(dim=0))
        com = self.com(torch.cat((h, g), 2))
        att = self.att(com)
        h_tilde = torch.mul(att, h).sum(dim=1)
        g_tilde = torch.mul(att, g).sum(dim=1)
        tilde_cat = torch.cat((h_tilde, g_tilde), 1)
        p = self.p(tilde_cat)
        result = self.output(p)
        result = self.sigmoid(result)
        return result

In [17]:
import random
random.seed(seed)

training_index = random.sample(range(len(new_seqs)), k=int(len(new_seqs)*0.8))
test_index = list(set(range(len(new_seqs))) - set(training_index))

In [18]:
len(training_index)

6001

In [19]:
len(test_index)

1501

In [20]:
def dataFormatter(patient_list):
    diag_vis, drug_vis = [], []
    for visit in patient_list:
        diag_codes = torch.zeros((1, C_d))
        diag_codes[0, visit[0]] = 1
        diag_vis.append(diag_codes)
        drug_codes = torch.zeros((1, C_t))
        drug_codes[0, visit[1]] = 1
        drug_vis.append(drug_codes)
    target = torch.cat((diag_vis[-1], drug_vis[-1]), 1)

    return diag_vis[:-1], drug_vis[:-1], torch.FloatTensor(target)

In [21]:
from sklearn.metrics import precision_recall_fscore_support


def eval_model(model, seqs, data_index):
    model.eval()
    p, r, f = np.array([]), np.array([]), np.array([])
    model.eval()
    for p_index in data_index:
        patient = seqs[p_index]
        for idx, visit in enumerate(patient):
            if idx > 0:
                diag_vis, drug_vis, y = dataFormatter(patient[:idx+1])
                y_hat = model(diag_vis, drug_vis)
                y_hat = (y_hat > 0.5).int()
                y = y.squeeze()
                y_hat = y_hat.squeeze()
                new_p, new_r, new_f, _ = precision_recall_fscore_support(y, y_hat, average='binary', zero_division=0)
                p, r, f = np.append(p, new_p), np.append(r, new_r), np.append(f, new_f)
    
    return np.mean(p), np.mean(r), np.mean(f)

In [41]:
def eval_model_top_k(model, seqs, data_index, k):
    model.eval()
    y_pred = torch.LongTensor()
    y_true = torch.LongTensor()
    scores = torch.LongTensor()
    model.eval()
    for p_index in data_index:
        patient = seqs[p_index]
        for idx, visit in enumerate(patient):
            if idx > 0:
                diag_vis, drug_vis, y = dataFormatter(patient[:idx+1])
                y_hat = model(diag_vis, drug_vis)
                y_hat = (y_hat > 0.5).int()
                scores = torch.cat((scores, torch.mul(y_hat, y).sum(dim=1)), 0)
                y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
                y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    
    top_k = torch.topk(scores, k).indices
    y_true = torch.reshape(y_true[top_k, :], (-1,))
    y_pred = torch.reshape(y_pred[top_k, :], (-1,))
    result = torch.mul(y_true, y_pred).sum(0) / y_true.sum(dim=0)
    
    return result

In [22]:
from sklearn.metrics import accuracy_score

def train(model, optimizer, training_index, seqs, val_index):
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for p_index in training_index:
#             print('p_index is:', p_index)
            loss = 0
            patient = seqs[p_index]
            for idx, visit in enumerate(patient):
                if idx > 0:
                    diag_vis, drug_vis, target = dataFormatter(patient[:idx+1])
                    pred = model(diag_vis, drug_vis)
                    loss += F.binary_cross_entropy_with_logits(pred, target)
                    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(training_index)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f = eval_model(model, seqs, val_index)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'
              .format(epoch+1, p, r, f))

In [23]:
coam = COAM()
optimizer = torch.optim.Adam(coam.parameters(), lr=1e-3)

In [24]:
n_epochs = 20
train(coam, optimizer, training_index, new_seqs, training_index)

Epoch: 1 	 Training Loss: 1.142964
Epoch: 1 	 Validation p: 0.70, r:0.10, f: 0.17
Epoch: 2 	 Training Loss: 1.141942
Epoch: 2 	 Validation p: 0.68, r:0.14, f: 0.23
Epoch: 3 	 Training Loss: 1.141666
Epoch: 3 	 Validation p: 0.71, r:0.14, f: 0.23
Epoch: 4 	 Training Loss: 1.141600
Epoch: 4 	 Validation p: 0.69, r:0.15, f: 0.24
Epoch: 5 	 Training Loss: 1.141556
Epoch: 5 	 Validation p: 0.71, r:0.15, f: 0.24
Epoch: 6 	 Training Loss: 1.141522
Epoch: 6 	 Validation p: 0.72, r:0.14, f: 0.23
Epoch: 7 	 Training Loss: 1.141506
Epoch: 7 	 Validation p: 0.71, r:0.14, f: 0.23
Epoch: 8 	 Training Loss: 1.141467
Epoch: 8 	 Validation p: 0.68, r:0.16, f: 0.26
Epoch: 9 	 Training Loss: 1.141393
Epoch: 9 	 Validation p: 0.68, r:0.18, f: 0.27
Epoch: 10 	 Training Loss: 1.141360
Epoch: 10 	 Validation p: 0.66, r:0.18, f: 0.27
Epoch: 11 	 Training Loss: 1.141315
Epoch: 11 	 Validation p: 0.66, r:0.17, f: 0.27
Epoch: 12 	 Training Loss: 1.141404
Epoch: 12 	 Validation p: 0.66, r:0.17, f: 0.26
Epoch: 13 

In [25]:
%%time
p, r, f = eval_model(coam, new_seqs, test_index)
print('Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'.format(p, r, f))

Validation p: 0.67, r:0.17, f: 0.26
CPU times: user 7.34 s, sys: 21.2 s, total: 28.6 s
Wall time: 4.02 s


In [48]:
%%time
for k in [5, 10, 15, 20]:
    acc_k = eval_model_top_k(coam, new_seqs, test_index, k)
    print('Validation Accuracy@' + str(k) + ': {:.2f}'.format(acc_k))

Validation Accuracy@5: 0.26
Validation Accuracy@10: 0.22
Validation Accuracy@15: 0.22
Validation Accuracy@20: 0.23
CPU times: user 1min 21s, sys: 2min 14s, total: 3min 35s
Wall time: 29.2 s


In [47]:
p = eval_model_top_k(coam, new_seqs, test_index, 10)
print('Validation p: {:.2f}'.format(p))

Validation p: 0.22


In [26]:
coama = COAMa()
optimizera = torch.optim.Adam(coama.parameters(), lr=1e-3)

In [27]:
%%time
n_epochs = 20
train(coama, optimizera, training_index, new_seqs, training_index)

Epoch: 1 	 Training Loss: 1.143495
Epoch: 1 	 Validation p: 0.72, r:0.06, f: 0.10
Epoch: 2 	 Training Loss: 1.142636
Epoch: 2 	 Validation p: 0.72, r:0.06, f: 0.10
Epoch: 3 	 Training Loss: 1.142405
Epoch: 3 	 Validation p: 0.73, r:0.09, f: 0.15
Epoch: 4 	 Training Loss: 1.142299
Epoch: 4 	 Validation p: 0.73, r:0.09, f: 0.15
Epoch: 5 	 Training Loss: 1.142299
Epoch: 5 	 Validation p: 0.73, r:0.09, f: 0.15
Epoch: 6 	 Training Loss: 1.142299
Epoch: 6 	 Validation p: 0.73, r:0.09, f: 0.15
Epoch: 7 	 Training Loss: 1.142299
Epoch: 7 	 Validation p: 0.73, r:0.09, f: 0.15
Epoch: 8 	 Training Loss: 1.142299
Epoch: 8 	 Validation p: 0.73, r:0.09, f: 0.15
Epoch: 9 	 Training Loss: 1.142299
Epoch: 9 	 Validation p: 0.73, r:0.09, f: 0.15
Epoch: 10 	 Training Loss: 1.142299
Epoch: 10 	 Validation p: 0.73, r:0.09, f: 0.15
Epoch: 11 	 Training Loss: 1.142299
Epoch: 11 	 Validation p: 0.73, r:0.09, f: 0.15
Epoch: 12 	 Training Loss: 1.142299
Epoch: 12 	 Validation p: 0.73, r:0.09, f: 0.15
Epoch: 13 

In [28]:
%%time
p, r, f = eval_model(coama, new_seqs, test_index)
print('Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'.format(p, r, f))

Validation p: 0.72, r:0.08, f: 0.15
CPU times: user 7 s, sys: 21.3 s, total: 28.3 s
Wall time: 3.86 s


In [49]:
%%time
for k in [5, 10, 15, 20]:
    acc_k = eval_model_top_k(coama, new_seqs, test_index, k)
    print('Validation Accuracy@' + str(k) + ': {:.2f}'.format(acc_k))

Validation Accuracy@5: 0.09
Validation Accuracy@10: 0.09
Validation Accuracy@15: 0.08
Validation Accuracy@20: 0.09
CPU times: user 1min 20s, sys: 2min 15s, total: 3min 36s
Wall time: 29 s


In [29]:
coamb = COAMb()
optimizerb = torch.optim.Adam(coamb.parameters(), lr=1e-3)

In [30]:
%%time
n_epochs = 20
train(coamb, optimizerb, training_index, new_seqs, training_index)

Epoch: 1 	 Training Loss: 1.143332
Epoch: 1 	 Validation p: 0.72, r:0.06, f: 0.10
Epoch: 2 	 Training Loss: 1.142651
Epoch: 2 	 Validation p: 0.67, r:0.09, f: 0.15
Epoch: 3 	 Training Loss: 1.142749
Epoch: 3 	 Validation p: 0.67, r:0.09, f: 0.15
Epoch: 4 	 Training Loss: 1.142749
Epoch: 4 	 Validation p: 0.67, r:0.09, f: 0.15
Epoch: 5 	 Training Loss: 1.142749
Epoch: 5 	 Validation p: 0.67, r:0.09, f: 0.15
Epoch: 6 	 Training Loss: 1.142749
Epoch: 6 	 Validation p: 0.67, r:0.09, f: 0.15
Epoch: 7 	 Training Loss: 1.142749
Epoch: 7 	 Validation p: 0.67, r:0.09, f: 0.15
Epoch: 8 	 Training Loss: 1.142749
Epoch: 8 	 Validation p: 0.67, r:0.09, f: 0.15
Epoch: 9 	 Training Loss: 1.142749
Epoch: 9 	 Validation p: 0.67, r:0.09, f: 0.15
Epoch: 10 	 Training Loss: 1.142749
Epoch: 10 	 Validation p: 0.67, r:0.09, f: 0.15
Epoch: 11 	 Training Loss: 1.142749
Epoch: 11 	 Validation p: 0.67, r:0.09, f: 0.15
Epoch: 12 	 Training Loss: 1.142749
Epoch: 12 	 Validation p: 0.67, r:0.09, f: 0.15
Epoch: 13 

In [31]:
%%time
p, r, f = eval_model(coamb, new_seqs, test_index)
print('Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'.format(p, r, f))

Validation p: 0.65, r:0.09, f: 0.15
CPU times: user 7.18 s, sys: 20.7 s, total: 27.9 s
Wall time: 3.92 s


In [50]:
%%time
for k in [5, 10, 15, 20]:
    acc_k = eval_model_top_k(coamb, new_seqs, test_index, k)
    print('Validation Accuracy@' + str(k) + ': {:.2f}'.format(acc_k))

Validation Accuracy@5: 0.07
Validation Accuracy@10: 0.09
Validation Accuracy@15: 0.09
Validation Accuracy@20: 0.09
CPU times: user 1min 21s, sys: 2min 9s, total: 3min 30s
Wall time: 29.8 s
