In [1]:
"""
File: Bi-LSTM.ipynb
Code to train and evaluate a bi-directional LSTM model on MIMIC-IV FHIR dataset.
"""

def Project():
    """
    __Objectives__
    0. Import data and tokenizer
    1. Train the tokenizer on all sequences of the dataset
    2. Tokenize different sequences and join them together
    >>> 3. Prepare actual labels for one, six, twelve month death after discharge
    >>> 4. Define the model architecture for bidirectional LSTM
    >>> 5. Train Bi-LSTM model and evaluate on test dataset
    >>> 6. Compare performance across new tasks to XGBoost

    __Questions__
    0.

    __Extra__
    Careful with tokenizing all sequences as it could be tricky!

    """
    # Return this function's own docstring (the project plan above). The previous
    # code referenced an undefined name and raised NameError on every call.
    return Project.__doc__

In [2]:
import os

# Project root directory. The raw string keeps the Windows backslashes literal:
# '\V' and '\o' are not valid escape sequences and trigger a warning (an error in
# future Python versions) in a normal string literal. The default value is unchanged.
# Set the ODYSSEY_ROOT environment variable to run from a different checkout.
ROOT = os.environ.get('ODYSSEY_ROOT', r'E:\Vector Institute\odyssey')
# Work from the project root; presumably this is what lets the `models.*` imports
# further down in this cell resolve.
os.chdir(ROOT)
import scipy, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import MinMaxScaler, StandardScaler, MaxAbsScaler
from sklearn.model_selection import train_test_split, cross_val_predict, StratifiedKFold
from sklearn.metrics import accuracy_score, balanced_accuracy_score, precision_score, recall_score
from sklearn.metrics import f1_score, roc_curve, auc, precision_recall_curve, roc_auc_score, average_precision_score
from scipy.sparse import csr_matrix, hstack, vstack, save_npz, load_npz

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.functional import relu, leaky_relu, sigmoid
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR

from models.cehr_bert.data import PretrainDataset, FinetuneDataset
from models.cehr_bert.model import BertPretrain
from models.cehr_bert.tokenizer import ConceptTokenizer
from models.cehr_bert.embeddings import Embeddings

from tqdm import tqdm
%matplotlib inline

# Locations of the dataset files, all kept in the `data` folder under ROOT.
DATA_ROOT = ROOT + '/data'
DATA_PATH = DATA_ROOT + '/patient_sequences.parquet'
SAMPLE_DATA_PATH = DATA_ROOT + '/CEHR-BERT_sample_patient_sequence.parquet'
FREQ_DF_PATH = DATA_ROOT + '/patient_feature_freq.csv'
FREQ_MATRIX_PATH = DATA_ROOT + '/patient_freq_matrix.npz'



In [3]:
# save parameters and configurations
class config:
    """Notebook-wide settings, read as class attributes (e.g. ``config.seed``)."""
    seed = 23                       # used by seed_all() and as train_test_split's random_state
    data_dir = DATA_ROOT            # directory handed to ConceptTokenizer(data_dir=...)
    test_size = 0.2                 # fraction of rows held out by train_test_split
    max_len = 500                   # max_len given to FinetuneDataset
    batch_size = 8                  # DataLoader batch size
    num_workers = 2                 # DataLoader worker processes
    vocab_size = None               # filled in once the tokenizer has been fitted
    embedding_size = 128            # embedding width; also reused as the LSTM hidden size
    time_embeddings_size = 16       # time_embedding_size given to Embeddings
    # NOTE(review): max_seq_length (512) differs from max_len (500) above — confirm this is intended.
    max_seq_length = 512            # max_len given to Embeddings
    # Use the GPU when one is available; otherwise fall back to the CPU so that
    # `.to(config.device)` does not fail on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def seed_all(seed):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's built-in ``random``, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), and switches cuDNN to deterministic mode so that
    repeated runs with the same seed are reproducible.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import keeps the function self-contained; `random` is not imported
    # at the top of the notebook.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN auto-tuning for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_all(config.seed)
# Display the name of the active GPU. torch.cuda.current_device() raises on a
# CPU-only machine, so report 'cpu' there instead of failing the whole cell.
torch.cuda.get_device_name(torch.cuda.current_device()) if torch.cuda.is_available() else 'cpu'

'NVIDIA GeForce GTX 1650 Ti'

In [4]:
# Load data
# Read the patient sequences and swap the token column names in one chained step:
# the original 'event_tokens' column is kept as 'event_tokens_untruncated', and
# 'event_tokens_updated' takes over the name 'event_tokens'.
data = pd.read_parquet(DATA_PATH).rename(
    columns={
        'event_tokens': 'event_tokens_untruncated',
        'event_tokens_updated': 'event_tokens',
    }
)
data

Unnamed: 0,patient_id,num_visits,label,death_after_start,death_after_end,length,token_length,new_start,event_tokens_untruncated,event_tokens,age_tokens,time_tokens,visit_tokens,position_tokens
0,be7990af-3829-5df0-b552-c397a71d46fe,3,0,,,217,225,,"[VS, 4443, 00338004304, 00006473900, 000935211...","[VS, 4443, 00338004304, 00006473900, 000935211...","[66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 6...","[8000, 8000, 8000, 8000, 8000, 8000, 8000, 800...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1,877d281b-676b-53ab-9911-1e4677989f6f,1,0,,,18,20,,"[VS, 741, 00182864389, 00904585461, 0070345020...","[VS, 741, 00182864389, 00904585461, 0070345020...","[37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 3...","[5085, 5085, 5085, 5085, 5085, 5085, 5085, 508...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,65ae1ba2-dede-53a4-80be-3d0666b27e87,1,0,,,40,42,,"[VS, 51248_2, 51736_2, 51244_3, 51222_4, 51737...","[VS, 51248_2, 51736_2, 51244_3, 51222_4, 51737...","[24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 2...","[8787, 8787, 8787, 8787, 8787, 8787, 8787, 878...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3,aa1446f6-dbc4-5734-9645-a1e01a7ba6f0,1,0,,,18,20,,"[VS, 0689, 33332000801, 00056017075, 655970103...","[VS, 0689, 33332000801, 00056017075, 655970103...","[77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 7...","[4853, 4853, 4853, 4853, 4853, 4853, 4853, 485...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,b3c303cc-df8c-5789-80f0-83f1c319b813,1,1,22.0,17.0,81,83,,"[VS, 7935, 00338067104, 00054855324, 009045165...","[VS, 7935, 00338067104, 00054855324, 009045165...","[62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 6...","[6037, 6037, 6037, 6037, 6037, 6037, 6037, 603...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90273,88ae054e-0173-5049-b067-a67bad1aeee9,1,0,,,36,38,,"[VS, 7931, 00075062041, 00023050601, 005363381...","[VS, 7931, 00075062041, 00023050601, 005363381...","[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 3...","[8788, 8788, 8788, 8788, 8788, 8788, 8788, 878...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
90274,3b6ec88d-59a8-5833-8977-48e8b58211b1,1,0,,,17,19,,"[VS, 00338067104, 51079045620, 66553000401, 00...","[VS, 00338067104, 51079045620, 66553000401, 00...","[68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 6...","[5353, 5353, 5353, 5353, 5353, 5353, 5353, 535...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
90275,b883470b-664e-5f0e-b38c-717cd5b07b84,1,1,3.0,0.0,152,154,,"[VS, 5503, 00338004904, 00006494300, 001828447...","[VS, 5503, 00338004904, 00006494300, 001828447...","[81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 8...","[7450, 7450, 7450, 7450, 7450, 7450, 7450, 745...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
90277,c946654b-2765-5dc1-8cd4-9865d3c84d30,2,0,,,46,51,,"[VS, 51079088120, 51079088120, 68084025401, 00...","[VS, 51079088120, 51079088120, 68084025401, 00...","[45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 4...","[4850, 4850, 4850, 4850, 4850, 4850, 4850, 485...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


In [5]:
# Define custom labels: here, death within 12 months.
# label is 1 when death_after_end lies strictly between 0 and 365, otherwise 0;
# rows with a missing death_after_end compare False and therefore get 0.
# NOTE(review): a death_after_end of exactly 0 is labelled 0 because of the strict
# lower bound — confirm that excluding those rows from the positives is intended.
data['label'] = (
    data['death_after_end'].gt(0) & data['death_after_end'].lt(365)
).astype(int)

In [6]:
# Fit tokenizer on .json vocab files
# Create the tokenizer pointed at config.data_dir and fit it on the vocabulary
# (presumably the .json vocab files stored in that directory — see ConceptTokenizer).
tokenizer = ConceptTokenizer(data_dir=config.data_dir)
tokenizer.fit_on_vocab()
# Record the vocabulary size on config; it must be set before BiLSTMModel is built,
# because the model passes config.vocab_size to its Embeddings layer.
config.vocab_size = tokenizer.get_vocab_size()
# Last expression of the cell: displays the tokenizer object.
tokenizer

<models.cehr_bert.tokenizer.ConceptTokenizer at 0x1c337e71f10>

In [61]:
# Get training and test datasets

# Debug-sized run: only the first N_SAMPLES rows are used. The subset is taken once
# so that the features and the stratification labels always cover the same rows.
N_SAMPLES = 100
data_subset = data.iloc[:N_SAMPLES]

# Stratified split so both parts keep the same proportion of positive labels.
train_data, test_data = train_test_split(
    data_subset,
    test_size=config.test_size,
    random_state=config.seed,
    stratify=data_subset['label'],
)

train_dataset = FinetuneDataset(
    data=train_data,
    tokenizer=tokenizer,
    max_len=config.max_len,
)

test_dataset = FinetuneDataset(
    data=test_data,
    tokenizer=tokenizer,
    max_len=config.max_len,
)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=True,
    pin_memory=True,
)

# Evaluation does not benefit from shuffling; a fixed order keeps the test pass
# reproducible and lets predictions be matched back to rows.
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=False,
    pin_memory=True,
)

Unnamed: 0,patient_id,num_visits,label,death_after_start,death_after_end,length,token_length,new_start,event_tokens_untruncated,event_tokens,age_tokens,time_tokens,visit_tokens,position_tokens
0,be7990af-3829-5df0-b552-c397a71d46fe,3,0,,,217,225,,"[VS, 4443, 00338004304, 00006473900, 000935211...","[VS, 4443, 00338004304, 00006473900, 000935211...","[66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 6...","[8000, 8000, 8000, 8000, 8000, 8000, 8000, 800...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1,877d281b-676b-53ab-9911-1e4677989f6f,1,0,,,18,20,,"[VS, 741, 00182864389, 00904585461, 0070345020...","[VS, 741, 00182864389, 00904585461, 0070345020...","[37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 3...","[5085, 5085, 5085, 5085, 5085, 5085, 5085, 508...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,65ae1ba2-dede-53a4-80be-3d0666b27e87,1,0,,,40,42,,"[VS, 51248_2, 51736_2, 51244_3, 51222_4, 51737...","[VS, 51248_2, 51736_2, 51244_3, 51222_4, 51737...","[24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 2...","[8787, 8787, 8787, 8787, 8787, 8787, 8787, 878...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3,aa1446f6-dbc4-5734-9645-a1e01a7ba6f0,1,0,,,18,20,,"[VS, 0689, 33332000801, 00056017075, 655970103...","[VS, 0689, 33332000801, 00056017075, 655970103...","[77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 7...","[4853, 4853, 4853, 4853, 4853, 4853, 4853, 485...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,b3c303cc-df8c-5789-80f0-83f1c319b813,1,1,22.0,17.0,81,83,,"[VS, 7935, 00338067104, 00054855324, 009045165...","[VS, 7935, 00338067104, 00054855324, 009045165...","[62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 6...","[6037, 6037, 6037, 6037, 6037, 6037, 6037, 603...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
101,79740e0e-6da1-5483-bf25-225354e249cc,3,0,,,1198,1206,695.0,"[VS, 0DBN0ZZ, 0D1M0Z4, 0D1M0Z4, 0DN80ZZ, 00338...","[VS, 51237_3, 51274_3, 51275_4, 50868_1, 50882...","[78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 7...","[7066, 7066, 7066, 7066, 7066, 7066, 7066, 706...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
102,e7cc1bc6-9e86-54b2-9d00-e23eae8d4459,1,0,,,53,55,,"[VS, 00781305714, 00406055262, 33332001101, 51...","[VS, 00781305714, 00406055262, 33332001101, 51...","[56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 5...","[8704, 8704, 8704, 8704, 8704, 8704, 8704, 870...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
103,c1c54a3d-43e7-5e63-841a-c349708a548d,3,0,,,245,253,,"[VS, 00904224461, 00182844789, 00074568113, 00...","[VS, 00904224461, 00182844789, 00074568113, 00...","[83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 8...","[7519, 7519, 7519, 7519, 7519, 7519, 7519, 751...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
104,1a52e3fa-4453-5cc7-b50d-c9276032b6bd,1,0,,,598,600,89.0,"[VS, 009U3ZX, 00904198861, 49999002812, 003784...","[VS, 50960_0, 50931_0, 50912_4, 50902_0, 50893...","[43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 4...","[7153, 7153, 7153, 7153, 7153, 7153, 7153, 715...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


In [55]:
# Define model architecture

class BiLSTMModel(nn.Module):
    """Bi-directional LSTM classifier on top of the CEHR-BERT ``Embeddings`` layer.

    Pipeline: Embeddings -> stacked bidirectional LSTM -> BatchNorm1d -> Dropout
    -> Linear + ReLU -> Linear. The model returns raw logits (no sigmoid is applied).

    Args:
        embedding_dim: Width of the embeddings fed to the LSTM.
        hidden_size: Hidden size of each LSTM direction; the two directions are
            concatenated, so the classification head sees ``2 * hidden_size`` features.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of output logits per sample.
        dropout_rate: Dropout probability used between LSTM layers and before the head.
    """

    def __init__(self, embedding_dim, hidden_size, num_layers, output_size, dropout_rate=0.5):
        super().__init__()

        # The embedding layer is sized with embedding_dim (not config.embedding_size) so
        # the embedding width and the LSTM input size cannot drift apart.
        # NOTE(review): assumes Embeddings emits vectors of width `embedding_size` — confirm.
        self.embeddings = Embeddings(
            vocab_size=config.vocab_size,
            embedding_size=embedding_dim,
            time_embedding_size=config.time_embeddings_size,
            max_len=config.max_seq_length)

        self.lstm = nn.LSTM(input_size=embedding_dim,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=True,
                            # Inter-layer dropout only exists with more than one layer;
                            # PyTorch warns if it is requested for a single layer.
                            dropout=dropout_rate if num_layers > 1 else 0.0)

        self.batch_norm = nn.BatchNorm1d(hidden_size * 2)
        self.dropout = nn.Dropout(dropout_rate)
        self.linear1 = nn.Linear(hidden_size * 2, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)


    def forward(self, inputs):
        """Compute logits for a batch.

        Args:
            inputs: Tuple of tensors that is unpacked into the embedding layer
                (see ``get_inputs_labels`` for the order used in this notebook).

        Returns:
            Tensor of logits with ``output_size`` values per sample.
        """
        x = self.embeddings(*inputs)
        _, (h_n, _) = self.lstm(x)
        # h_n stacks the final hidden state of every layer and direction; its last two
        # entries are the top layer's forward and backward final states. Taking
        # lstm_out[:, -1, :] instead would use the backward direction after it has
        # seen only a single timestep.
        # NOTE(review): if sequences are right-padded, the forward state still ends on
        # padding; handling that needs sequence lengths, which are not available here.
        output = torch.cat((h_n[-2], h_n[-1]), dim=1)
        output = self.batch_norm(output)
        output = self.dropout(output)
        output = relu(self.linear1(output))
        output = self.linear2(output)
        return output


    @staticmethod
    def get_inputs_labels(batch):
        """Move a batch to ``config.device`` and split it into model inputs and labels.

        Args:
            batch: Mapping with the keys 'labels', 'concept_ids', 'time_stamps',
                'ages', 'visit_orders' and 'visit_segments', each holding a tensor.

        Returns:
            ``(inputs, labels)`` where ``inputs`` is a tuple of the five input tensors
            in the order listed above and ``labels`` is a float tensor reshaped to
            one column.
        """
        device = config.device
        labels = batch['labels'].view(-1, 1).to(device)
        input_keys = ('concept_ids', 'time_stamps', 'ages', 'visit_orders', 'visit_segments')
        inputs = tuple(batch[key].to(device) for key in input_keys)

        return inputs, labels.float()

In [56]:
# Set hyperparameters for the Bi-LSTM model and the training loop
input_size = config.embedding_size              # embedding_dim passed to BiLSTMModel (LSTM input size)
hidden_size = config.embedding_size             # hidden size of each LSTM direction
num_layers = 5                                  # number of stacked LSTM layers
output_size = 1                                 # Binary classification, so output size is 1
dropout_rate = 0.5                              # Dropout rate for regularization

epochs = 10                                     # number of passes over train_loader
learning_rate = 0.002                           # initial Adam learning rate (decayed by the scheduler each epoch)

In [60]:
# Training Loop

def evaluate_balanced_accuracy(model, loader, desc):
    """Return the balanced accuracy of `model` over all batches of `loader`.

    Predictions and labels are gathered across the whole loader and scored once;
    averaging per-batch scores would be distorted by small batches that contain a
    single class.

    Args:
        model: Model exposing ``get_inputs_labels(batch)`` and returning logits.
        loader: DataLoader to evaluate on.
        desc: Text shown on the progress bar.

    Returns:
        Balanced accuracy as a float.
    """
    model.eval()
    all_labels, all_predictions = [], []
    with torch.no_grad():
        for batch in tqdm(loader, total=len(loader), desc=desc, unit=' batch'):
            inputs, labels = model.get_inputs_labels(batch)
            predictions = torch.sigmoid(model(inputs)) >= 0.5
            # Move to the CPU before handing the values to scikit-learn.
            all_labels.append(labels.cpu())
            all_predictions.append(predictions.cpu())

    y_true = torch.cat(all_labels).int().numpy().ravel()
    y_pred = torch.cat(all_predictions).int().numpy().ravel()
    return balanced_accuracy_score(y_true, y_pred)


model = BiLSTMModel(input_size, hidden_size, num_layers, output_size, dropout_rate).to(config.device)
loss_fcn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# `verbose=True` is deprecated in recent PyTorch releases; the learning rate is
# printed explicitly at the end of every epoch instead.
scheduler = ExponentialLR(optimizer, gamma=0.8)


for epoch in range(epochs):
    train_total_loss = 0

    # evaluate_balanced_accuracy() leaves the model in eval mode, so switch back.
    model.train()
    for batch in tqdm(train_loader, total=len(train_loader),
                      desc=f'Epoch {epoch+1}/{epochs}', unit=' batch'):

        inputs, labels = model.get_inputs_labels(batch)
        outputs = model(inputs)
        loss = loss_fcn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_total_loss += loss.item()

    # Score the model after each epoch so the reported accuracies are real values
    # rather than the placeholder zeros printed before.
    train_accuracy = evaluate_balanced_accuracy(model, train_loader, f'Train Evaluation {epoch+1}/{epochs}')
    test_accuracy = evaluate_balanced_accuracy(model, test_loader, f'Test Evaluation {epoch+1}/{epochs}')

    print(f'Average Train Loss: {train_total_loss/len(train_loader):.5f}  |  Last Batch Train Loss: {loss.item()}  |  '
          f'Train Accuracy: {train_accuracy}  |  Test Accuracy: {test_accuracy}')
    scheduler.step()
    print(f'Learning rate for next epoch: {scheduler.get_last_lr()[0]:.6f}')

({'concept_ids': tensor([    1,  5116,  5975,  4757,  6787,  4757,  2980,  5975,  1743,  2566,
            542,  2909,  2329,   634,  1753,  1204,   433,   188,  1625,   722,
            896,   135,   379,  1711,  1571,  2786,  1640,   742,   605,  2882,
              1, 20586,     1,  6269,  4726,  3688,  5428,  4450,  3494,  5975,
           3639,  5715,  6889,  3345,  5369,  5255,  4450,  5979,  3529,  4944,
           2980,  4651,  1352,  2369,  1069,    56,   722,   865,   896,   390,
           1261,  1711,  1571,  1990,  1749,   891,   605,   196,   866,  2174,
           1064,  1804,  1976,   105,   245,  2666,  1905,   188,   134,  2912,
           2592,   379,  1711,  1571,  1749,  1442,   891,  1131,  1905,   188,
            685,  2230,  1198,  1064,  1804,  2132,   105,   245,  1204,  2268,
           1428,  1601,  1348,   896,  1716,   379,  1711,   480,  1749,  2234,
           1725,   605,   196,  2230,  2174,  1064,  1804,  2132,   105,   245,
           1204,     1, 2

In [13]:
# jupyter
# TODO: re-evaluate XGBoost with the new label and push it to GitHub

AttributeError: 'DatasetWithTokenLength' object has no attribute 'sort'