In [51]:
import os
import sys
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# Define data path
DATA_PATH = "../data/"
#os.getcwd()

In [52]:
pids = pickle.load(open('pids.pkl', 'rb'))
vids = pickle.load(open('vids.pkl', 'rb'))
seqs = pickle.load(open('seqs.pkl', 'rb'))
types = pickle.load(open('types.pkl', 'rb'))
rtypes = pickle.load(open('rtypes.pkl', 'rb'))

In [53]:
len(types)

1279

In [33]:
class RNN(nn.Module):
    def __init__(self, vocab_size, emb_dim=256, hid_dim=128):
        super(RNN, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim, batch_first=True)
        self.fc = nn.Linear(hid_dim, vocab_size, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, codes):
        emb_list = []
        for code in codes:
            emb = self.embeddings(torch.tensor(code))
            emb_mean = emb.mean(dim=0).unsqueeze(dim=0)
            emb_list.append(emb_mean)
        emb_seq = torch.cat(emb_list, dim=0).unsqueeze(dim=0)
        output, _ = self.rnn(emb_seq)
        result = F.relu(output)
        result = self.fc(result)
        return result[:, -1, :]

In [34]:
import random
random.seed(seed)

training_index = random.sample(range(len(seqs)), k=int(len(seqs)*0.8))
test_index = list(set(range(len(seqs))) - set(training_index))

In [35]:
vocab_med = 1279

In [36]:
model = RNN(vocab_size=vocab_med)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

In [37]:
def dataFormatter(patient_list):

    x = []
    for codes in patient_list:
        x.append(codes)
    target = np.zeros((1, vocab_med))
    target[0, x[-1]] = 1
    return x[:-1], torch.FloatTensor(target)

In [38]:
from sklearn.metrics import precision_recall_fscore_support


def eval_model(model, data_index):
    model.eval()
    p, r, f = np.array([]), np.array([]), np.array([])
    model.eval()
    for p_index in data_index:
        patient = seqs[p_index]
        for idx, visit in enumerate(patient):
            if idx > 0:
                x, y = dataFormatter(patient[:idx+1])
                y_hat = model(x)
                y_hat = (y_hat > 0.5).int()
                y = y.squeeze()
                y_hat = y_hat.squeeze()
                new_p, new_r, new_f, _ = precision_recall_fscore_support(y, y_hat, average='binary', zero_division=1)
                p, r, f = np.append(p, new_p), np.append(r, new_r), np.append(f, new_f)

    return np.mean(p), np.mean(r), np.mean(f)

In [39]:
from sklearn.metrics import accuracy_score

def train(training_index, val_index):
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for p_index in training_index:
            loss = 0
            patient = seqs[p_index]
            for idx, visit in enumerate(patient):
                if idx > 0:
                    x, target = dataFormatter(patient[:idx+1])
                    multi_target = np.full((1, vocab_med), -1)
                    for idx, item in enumerate(visit):
                        multi_target[0][idx] = item
                    multi_target = torch.LongTensor(multi_target)
                    pred = model(x)

                    loss += F.binary_cross_entropy_with_logits(pred, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss = loss.item()
        train_loss = train_loss / len(training_index)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f = eval_model(model, val_index)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'
              .format(epoch+1, p, r, f))

In [40]:
def eval_model_top_k(model, data_index, k):
    model.eval()
    y_pred = torch.LongTensor()
    y_true = torch.LongTensor()
    scores = torch.LongTensor()
    model.eval()
    for p_index in data_index:
        patient = seqs[p_index]
        for idx, visit in enumerate(patient):
            if idx > 0:
                x, y = dataFormatter(patient[:idx+1])
                y_hat = model(x)
                y_hat = (y_hat > 0.5).int()
                scores = torch.cat((scores, torch.mul(y_hat, y).sum(dim=1)), 0)
                y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
                y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    
    top_k = torch.topk(scores, k).indices
    y_true = torch.reshape(y_true[top_k, :], (-1,))
    y_pred = torch.reshape(y_pred[top_k, :], (-1,))
    result = torch.mul(y_true, y_pred).sum(0) / y_true.sum(dim=0)
    
    return result

In [41]:
n_epochs = 20
train(training_index, training_index)

Epoch: 1 	 Training Loss: 0.000038
Epoch: 1 	 Validation p: 0.77, r:0.20, f: 0.30
Epoch: 2 	 Training Loss: 0.000038
Epoch: 2 	 Validation p: 0.77, r:0.25, f: 0.35
Epoch: 3 	 Training Loss: 0.000038
Epoch: 3 	 Validation p: 0.77, r:0.27, f: 0.37
Epoch: 4 	 Training Loss: 0.000038
Epoch: 4 	 Validation p: 0.78, r:0.28, f: 0.39
Epoch: 5 	 Training Loss: 0.000038
Epoch: 5 	 Validation p: 0.79, r:0.29, f: 0.40
Epoch: 6 	 Training Loss: 0.000037
Epoch: 6 	 Validation p: 0.79, r:0.30, f: 0.41
Epoch: 7 	 Training Loss: 0.000037
Epoch: 7 	 Validation p: 0.80, r:0.31, f: 0.42
Epoch: 8 	 Training Loss: 0.000037
Epoch: 8 	 Validation p: 0.80, r:0.32, f: 0.43
Epoch: 9 	 Training Loss: 0.000037
Epoch: 9 	 Validation p: 0.81, r:0.33, f: 0.44
Epoch: 10 	 Training Loss: 0.000036
Epoch: 10 	 Validation p: 0.81, r:0.34, f: 0.45
Epoch: 11 	 Training Loss: 0.000037
Epoch: 11 	 Validation p: 0.81, r:0.35, f: 0.47
Epoch: 12 	 Training Loss: 0.000036
Epoch: 12 	 Validation p: 0.82, r:0.36, f: 0.48
Epoch: 13 

In [42]:
p, r, f = eval_model(model, test_index)
print('Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'.format(p, r, f))

Validation p: 0.65, r:0.33, f: 0.40


In [48]:
%%time
for k in [5, 10, 15, 20]:
    acc_k = eval_model_top_k(model, test_index, k)
    print('Validation Accuracy@' + str(k) + ': {:.2f}'.format(acc_k))

Validation Accuracy@5: 0.55
Validation Accuracy@10: 0.53
Validation Accuracy@15: 0.53
Validation Accuracy@20: 0.55
Wall time: 1min 6s


In [49]:
acc_k = eval_model_top_k(model, test_index, 3)
print('Validation acc_k: {:.2f}'.format(acc_k))

Validation acc_k: 0.59
