In [None]:
"""
File: Bi-LSTM.ipynb
Code to train and evaluate a bi-directional LSTM model on MIMIC-IV FHIR dataset.
"""

def Project():
    """
    __Objectives__
    >>> 0. Import data and tokenizer
    >>> 1. Train the tokenizer on all sequences of the dataset
    >>> 2. Tokenize different sequences and join them together
    >>> 3. Prepare actual labels for one, six, twelve month death after discharge
    >>> 4. Define the model architecture for bidrectional LSTM
    >>> 5. Train Bi-LSTM model and evaluate on test dataset
    >>> 6. Compare performance across new tasks to XGBoost

    __Questions__
    0.

    __Extra__
    Careful with tokenizing all sequences as it could tricky!

    """
    return ProjectObjectives.__doc__

In [None]:
ROOT = 'E:\Vector Institute\odyssey'
import os; os.chdir(ROOT)
import scipy, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import MinMaxScaler, StandardScaler, MaxAbsScaler
from sklearn.model_selection import train_test_split, cross_val_predict, StratifiedKFold
from sklearn.metrics import accuracy_score, balanced_accuracy_score, precision_score, recall_score
from sklearn.metrics import f1_score, roc_curve, auc, precision_recall_curve, roc_auc_score, average_precision_score
from scipy.sparse import csr_matrix, hstack, vstack, save_npz, load_npz

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.functional import relu, leaky_relu
from torch.utils.data import Dataset, TensorDataset, DataLoader

from models.cehr_bert.data import PretrainDataset, FinetuneDataset
from models.cehr_bert.model import BertPretrain
from models.cehr_bert.tokenizer import ConceptTokenizer
from models.cehr_bert.embeddings import Embeddings

from tqdm import tqdm
%matplotlib inline

DATA_ROOT = f'{ROOT}/data'
DATA_PATH = f'{DATA_ROOT}/patient_sequences.parquet'
SAMPLE_DATA_PATH = f'{DATA_ROOT}/CEHR-BERT_sample_patient_sequence.parquet'
FREQ_DF_PATH = f'{DATA_ROOT}/patient_feature_freq.csv'
FREQ_MATRIX_PATH = f'{DATA_ROOT}/patient_freq_matrix.npz'

In [23]:
class config:
    seed = 23
    data_dir = DATA_ROOT
    test_size = 0.2
    max_len = 500
    batch_size = 16
    num_workers = 2
    vocab_size = None
    embedding_size = 128
    time_embeddings_size = 16
    max_seq_length = 512
    device = torch.device('cuda')

def seed_all(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # pl.seed_everything(seed)

seed_all(config.seed)
torch.cuda.get_device_name(torch.cuda.current_device())

'NVIDIA GeForce GTX 1650 Ti'

In [24]:
# Load data
data = pd.read_parquet(DATA_PATH)
data.rename(columns={'event_tokens': 'event_tokens_untruncated', 'event_tokens_updated': 'event_tokens'}, inplace=True)
data

Unnamed: 0,patient_id,num_visits,label,death_after_start,death_after_end,length,token_length,new_start,event_tokens_untruncated,event_tokens,age_tokens,time_tokens,visit_tokens,position_tokens
0,be7990af-3829-5df0-b552-c397a71d46fe,3,0,,,217,225,,"[VS, 4443, 00338004304, 00006473900, 000935211...","[VS, 4443, 00338004304, 00006473900, 000935211...","[66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 6...","[8000, 8000, 8000, 8000, 8000, 8000, 8000, 800...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,877d281b-676b-53ab-9911-1e4677989f6f,1,0,,,18,20,,"[VS, 741, 00182864389, 00904585461, 0070345020...","[VS, 741, 00182864389, 00904585461, 0070345020...","[37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 3...","[5085, 5085, 5085, 5085, 5085, 5085, 5085, 508...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,65ae1ba2-dede-53a4-80be-3d0666b27e87,1,0,,,40,42,,"[VS, 51248_2, 51736_2, 51244_3, 51222_4, 51737...","[VS, 51248_2, 51736_2, 51244_3, 51222_4, 51737...","[24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 2...","[8787, 8787, 8787, 8787, 8787, 8787, 8787, 878...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,aa1446f6-dbc4-5734-9645-a1e01a7ba6f0,1,0,,,18,20,,"[VS, 0689, 33332000801, 00056017075, 655970103...","[VS, 0689, 33332000801, 00056017075, 655970103...","[77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 7...","[4853, 4853, 4853, 4853, 4853, 4853, 4853, 485...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,b3c303cc-df8c-5789-80f0-83f1c319b813,1,1,22.0,17.0,81,83,,"[VS, 7935, 00338067104, 00054855324, 009045165...","[VS, 7935, 00338067104, 00054855324, 009045165...","[62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 6...","[6037, 6037, 6037, 6037, 6037, 6037, 6037, 603...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90273,88ae054e-0173-5049-b067-a67bad1aeee9,1,0,,,36,38,,"[VS, 7931, 00075062041, 00023050601, 005363381...","[VS, 7931, 00075062041, 00023050601, 005363381...","[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 3...","[8788, 8788, 8788, 8788, 8788, 8788, 8788, 878...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
90274,3b6ec88d-59a8-5833-8977-48e8b58211b1,1,0,,,17,19,,"[VS, 00338067104, 51079045620, 66553000401, 00...","[VS, 00338067104, 51079045620, 66553000401, 00...","[68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 6...","[5353, 5353, 5353, 5353, 5353, 5353, 5353, 535...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
90275,b883470b-664e-5f0e-b38c-717cd5b07b84,1,0,3.0,0.0,152,154,,"[VS, 5503, 00338004904, 00006494300, 001828447...","[VS, 5503, 00338004904, 00006494300, 001828447...","[81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 8...","[7450, 7450, 7450, 7450, 7450, 7450, 7450, 745...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
90277,c946654b-2765-5dc1-8cd4-9865d3c84d30,2,0,,,46,51,,"[VS, 51079088120, 51079088120, 68084025401, 00...","[VS, 51079088120, 51079088120, 68084025401, 00...","[45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 4...","[4850, 4850, 4850, 4850, 4850, 4850, 4850, 485...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [25]:
# data['position_tokens'] = data['position_tokens'].apply(lambda x: np.zeros_like(x))
data['label'] = ((data['death_after_end'] > 0) & (data['death_after_end'] < 365)).astype(int)

In [26]:
tokenizer = ConceptTokenizer(data_dir=config.data_dir)
tokenizer.fit_on_vocab()
config.vocab_size = tokenizer.get_vocab_size()
tokenizer

<models.cehr_bert.tokenizer.ConceptTokenizer at 0x26cf3fdf650>

In [29]:
train_data, test_data = train_test_split(
            data,
            test_size=config.test_size,
            random_state=config.seed,
            stratify=data['label']
        )

train_dataset = FinetuneDataset(
        data=train_data,
        tokenizer=tokenizer,
        max_len=config.max_len,
)

test_dataset = FinetuneDataset(
        data=test_data,
        tokenizer=tokenizer,
        max_len=config.max_len,
)

train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        shuffle=True,
        pin_memory=True,
)

test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        shuffle=True,
        pin_memory=True,
)

In [30]:
class BiLSTMModel(nn.Module):

    def __init__(self, embedding_dim, hidden_size, num_layers, output_size, dropout_rate=0.5):
        super(BiLSTMModel, self).__init__()

        self.embeddings = Embeddings(
            vocab_size=config.vocab_size,
            embedding_size=config.embedding_size,
            time_embedding_size=config.time_embeddings_size,
            max_len=config.max_seq_length)

        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.batch_norm = nn.BatchNorm1d(hidden_size * 2)
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(hidden_size * 2, output_size)  # Multiply by 2 for bidirectional


    def forward(self, inputs):
        x = self.embeddings(*inputs)
        lstm_out, _ = self.lstm(x)
        output = lstm_out[:, -1, :]
        output = self.batch_norm(output)
        output = self.dropout(output)
        output = self.linear(output)
        return output

In [31]:
# Set hyperparameters for Bi-LSTM model
input_size = config.embedding_size              # embedding_dim
hidden_size = config.embedding_size // 2
num_layers = 5                                  # Number of LSTM layers
output_size = 1                                 # Binary classification, so output size is 1
dropout_rate = 0.5                              # Dropout rate for regularization

num_epochs = 10
learning_rate = 0.005

In [32]:
model = BiLSTMModel(input_size, hidden_size, num_layers, output_size, dropout_rate).to(config.device)
loss_fcn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


for epoch in range(num_epochs):
    train_total_loss = 0.0

    for batch_no, batch in tqdm(enumerate(train_loader),
                                           total=len(train_loader), desc=f'Epoch {epoch+1}/{num_epochs}', unit=' batch'):

        labels = batch['labels'].view(-1, 1).to(config.device)
        inputs = batch['concept_ids'].to(config.device),\
                 batch['time_stamps'].to(config.device),\
                 batch['ages'].to(config.device),\
                 batch['visit_orders'].to(config.device),\
                 batch['visit_segments'].to(config.device)

        outputs = model(inputs)
        loss = loss_fcn(outputs, labels.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_total_loss += loss.item()
        # print(f'Batch Loss: {loss.item():.4f}', end='\r')

    print(f'Epoch {epoch+1}/{num_epochs}  |  Average Train Loss: {train_total_loss/len(train_loader):.5f}\n')

Epoch 1/10: 100%|██████████| 8684/8684 [09:59<00:00, 14.49 batch/s]


Epoch 1/10  |  Average Train Loss: 0.30950



Epoch 2/10: 100%|██████████| 8684/8684 [12:10<00:00, 11.90 batch/s]


Epoch 2/10  |  Average Train Loss: 0.27158



KeyboardInterrupt: 

In [None]:
# Test check in loop
# jupyter + cluster setup
# reevaluate xgboost with new label and send it github