In [21]:
import os
import sys
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)


In [22]:
seqs = pickle.load(open('seqs.pkl', 'rb'))

In [23]:
class AlphaAttention(torch.nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        alpha = torch.softmax(self.a_att(g), dim=1)
        return alpha

In [24]:
class BetaAttention(torch.nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.b_att = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, h):
        beta = torch.tanh(self.b_att(h))
        return beta

In [25]:
class RETAIN(nn.Module):
    def __init__(self, num_codes, embedding_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(num_codes, embedding_dim)
        self.rnn_a = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        self.rnn_b = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        self.att_a = AlphaAttention(embedding_dim)
        self.att_b = BetaAttention(embedding_dim)
        self.fc = nn.Linear(embedding_dim, num_codes)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        rev_x = list(reversed(x))
        emb_list = []
        for code in rev_x:
            emb = self.embedding(torch.tensor(code))
            emb_mean = emb.mean(dim=0).unsqueeze(dim=0)
            emb_list.append(emb_mean)
        rev_x = torch.cat(emb_list, dim=0).unsqueeze(dim=0)

        g, _ = self.rnn_a(rev_x)
        h, _ = self.rnn_b(rev_x)
        alpha = self.att_a(g)
        beta = self.att_b(h)
        weights = torch.mul(alpha, beta)
        c = torch.mul(weights, rev_x).sum(dim=1)
        logits = self.fc(c)
        return self.sigmod(logits)

In [26]:
import random
random.seed(seed)

training_index = random.sample(range(len(seqs)), k=int(len(seqs)*0.8))
test_index = list(set(range(len(seqs))) - set(training_index))

In [27]:
vocab_med = 1279

In [28]:
model = RETAIN(num_codes=vocab_med)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [29]:
def dataFormatter(patient_list):
    x = []
    for codes in patient_list:
        x.append(codes)
    target = np.zeros((1, vocab_med))
    target[0, x[-1]] = 1
    return x[:-1], torch.FloatTensor(target)

In [30]:
# from sklearn.metrics import precision_recall_fscore_support


# def eval_model(model, data_index):
#     model.eval()
#     y_pred = torch.LongTensor()
#     y_true = torch.LongTensor()
    
#     model.eval()
#     for p_index in data_index:
#         patient = seqs[p_index]
#         for idx, visit in enumerate(patient):
#             if idx > 0:
#                 x, y = dataFormatter(patient[:idx+1])
#                 y_hat = model(x)
#                 y_hat = (y_hat > 0.5).int()
#                 y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
#                 y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
#     p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='samples')

#     return p, r, f

In [31]:
from sklearn.metrics import precision_recall_fscore_support


def eval_model(model, data_index):
    model.eval()
    p, r, f = np.array([]), np.array([]), np.array([])
    model.eval()
    for p_index in data_index:
        patient = seqs[p_index]
        for idx, visit in enumerate(patient):
            if idx > 0:
                x, y = dataFormatter(patient[:idx+1])
                y_hat = model(x)
                y_hat = (y_hat > 0.5).int()
                y = y.squeeze()
                y_hat = y_hat.squeeze()
                new_p, new_r, new_f, _ = precision_recall_fscore_support(y, y_hat, average='binary', zero_division=1)
                p, r, f = np.append(p, new_p), np.append(r, new_r), np.append(f, new_f)

    return np.mean(p), np.mean(r), np.mean(f)

In [32]:
from sklearn.metrics import accuracy_score

def train(training_index, val_index):
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for p_index in training_index:
            loss = 0
            patient = seqs[p_index]
            for idx, visit in enumerate(patient):
                if idx > 0:
                    x, target = dataFormatter(patient[:idx+1])
                    multi_target = np.full((1, vocab_med), -1)
                    for idx, item in enumerate(visit):
                        multi_target[0][idx] = item
                    multi_target = torch.LongTensor(multi_target)
                    pred = model(x)
                    loss += F.binary_cross_entropy_with_logits(pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss = loss.item()
        train_loss = train_loss / len(training_index)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f = eval_model(model, val_index)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'
              .format(epoch+1, p, r, f))

In [33]:
def eval_model_top_k(model, data_index, k):
    model.eval()
    y_pred = torch.LongTensor()
    y_true = torch.LongTensor()
    scores = torch.LongTensor()
    model.eval()
    for p_index in data_index:
        patient = seqs[p_index]
        for idx, visit in enumerate(patient):
            if idx > 0:
                x, y = dataFormatter(patient[:idx+1])
                y_hat = model(x)
                y_hat = (y_hat > 0.5).int()
                scores = torch.cat((scores, torch.mul(y_hat, y).sum(dim=1)), 0)
                y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
                y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    
    top_k = torch.topk(scores, k).indices
    y_true = torch.reshape(y_true[top_k, :], (-1,))
    y_pred = torch.reshape(y_pred[top_k, :], (-1,))
    result = torch.mul(y_true, y_pred).sum(0) / y_true.sum(dim=0)
    
    return result

In [34]:
n_epochs = 20
train(training_index, training_index)

Epoch: 1 	 Training Loss: 0.000345
Epoch: 1 	 Validation p: 0.73, r:0.18, f: 0.27
Epoch: 2 	 Training Loss: 0.000345
Epoch: 2 	 Validation p: 0.76, r:0.18, f: 0.28
Epoch: 3 	 Training Loss: 0.000345
Epoch: 3 	 Validation p: 0.77, r:0.18, f: 0.28
Epoch: 4 	 Training Loss: 0.000345
Epoch: 4 	 Validation p: 0.78, r:0.19, f: 0.29
Epoch: 5 	 Training Loss: 0.000345
Epoch: 5 	 Validation p: 0.78, r:0.20, f: 0.30
Epoch: 6 	 Training Loss: 0.000345
Epoch: 6 	 Validation p: 0.79, r:0.20, f: 0.30
Epoch: 7 	 Training Loss: 0.000345
Epoch: 7 	 Validation p: 0.79, r:0.21, f: 0.31
Epoch: 8 	 Training Loss: 0.000345
Epoch: 8 	 Validation p: 0.79, r:0.21, f: 0.32
Epoch: 9 	 Training Loss: 0.000345
Epoch: 9 	 Validation p: 0.79, r:0.22, f: 0.32
Epoch: 10 	 Training Loss: 0.000345
Epoch: 10 	 Validation p: 0.80, r:0.22, f: 0.32
Epoch: 11 	 Training Loss: 0.000345
Epoch: 11 	 Validation p: 0.80, r:0.23, f: 0.33
Epoch: 12 	 Training Loss: 0.000345
Epoch: 12 	 Validation p: 0.80, r:0.23, f: 0.33
Epoch: 13 

In [35]:
p, r, f = eval_model(model, test_index)
print('Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'.format(p, r, f))

Validation p: 0.74, r:0.22, f: 0.32


In [38]:
%%time
for k in [5, 10, 15, 20]:
    acc_k = eval_model_top_k(model, test_index, k)
    print('Validation Accuracy@' + str(k) + ': {:.2f}'.format(acc_k))

Validation Accuracy@5: 0.36
Validation Accuracy@10: 0.32
Validation Accuracy@15: 0.32
Validation Accuracy@20: 0.32
Wall time: 1min 8s


In [39]:
acc_k = eval_model_top_k(model, test_index, 3)
print('Validation acc_k: {:.2f}'.format(acc_k))

Validation acc_k: 0.34
