In [1]:
import os; ROOT = '/fs01/home/afallah/odyssey/slurm'; os.chdir(ROOT)
import sys
import scipy, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import MinMaxScaler, StandardScaler, MaxAbsScaler
from sklearn.model_selection import train_test_split, cross_val_predict, StratifiedKFold
from sklearn.metrics import accuracy_score, balanced_accuracy_score, precision_score, recall_score
from sklearn.metrics import f1_score, roc_curve, auc, precision_recall_curve, roc_auc_score, average_precision_score
from scipy.sparse import csr_matrix, hstack, vstack, save_npz, load_npz

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.functional import relu, leaky_relu, sigmoid
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR
from torch.nn.utils.rnn import pack_padded_sequence

from models.cehr_bert.data import PretrainDataset, FinetuneDataset
from models.cehr_bert.model import BertPretrain
# from models.cehr_bert.tokenizer import ConceptTokenizer
from models.cehr_bert.embeddings import Embeddings

import glob, json, random, glob
from random import randint
from typing import Sequence, Union
# from keras.preprocessing.text import Tokenizer
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from tokenizers import Tokenizer, decoders, models, normalizers, pre_tokenizers, processors, trainers
from itertools import chain

from tqdm import tqdm
%matplotlib inline

DATA_ROOT = f'{ROOT}/data'
DATA_PATH = f'{DATA_ROOT}/patient_sequences.parquet'
special_tokens = ["[UNK]", "[PAD]", "[CLS]", "[REG]", "[MASK]", "[VS]", "[VE]"] +\
                 [f'[W_{i}]' for i in range(0, 4)] + [f'[M_{i}]' for i in range(0, 13)] + ['[LT]']

2024-02-01 12:28:43.625583: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-02-01 12:28:43.680182: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-01 12:28:43.680232: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-01 12:28:43.682207: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-01 12:28:43.693093: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-02-01 12:28:43.694537: I tensorflow/core/platform/cpu_feature_guard.cc:1

In [2]:
class config:
    seed = 23
    data_dir = DATA_ROOT
    test_size = 0.2
    batch_size = 64
    num_workers = 2
    vocab_size = None
    embedding_size = 128
    time_embeddings_size = 16
    max_len = 512
    device = torch.device('cuda')

In [3]:
# # Load data
# data = pd.read_parquet(DATA_PATH)
# data.rename(columns={'event_tokens': 'event_tokens_untruncated', 'event_tokens_updated': 'event_tokens'}, inplace=True)
# data

In [4]:
### Define random dataset using the actual vocabulary ###
lab_dict = ["50995_4", "51429_0", "51009_4", "50908_4", "51564_0", "52301_3", "51044_3", "50841_4", "52178_0", "51512_0"]
procedure_dict = ["00641601310", "00069419068", "00078037742", "63323027205", "00078059620", "00007550040", "51079017220", "68084061221"]
diagnosis_dict = ["02H64JZ", "7906", "8222", "0JC00ZZ", "0YHM0YZ", "9652", "03743D6", "100", "0B998ZZ", "8127"]
time_dict = ['[W_1]', '[W_3]', '[M_2]', '[M_5]', '[LT]']


def generate_random_events(num_visits, event_dict, time_tokens):
    patient = []
    length_visits = [randint(1, 6) for _ in range(num_visits)]

    for i in range(num_visits):
        patient.append('[VS]')
        length_visit = length_visits[i]
        random_events = [event_dict[randint(0, len(event_dict)-1)] for _ in range(length_visit)]
        patient += random_events
        patient.append('[VE]')

        if i < num_visits - 1:
            patient.append(time_tokens[i])

        patient.append('[REG]')

    return patient


def generate_random_patient(lab_dict, procedure_dict, diagnosis_dict, time_dict):
    num_visits = randint(1, 5)
    time_tokens = [time_dict[randint(0, len(time_dict)-1)] for _ in range(num_visits)]

    random_lab = generate_random_events(num_visits, lab_dict, time_tokens)
    random_procedure = generate_random_events(num_visits, procedure_dict, time_tokens)
    random_diagnosis = ['[CLS]'] + generate_random_events(num_visits, diagnosis_dict, time_tokens)

    prior_vs_diagnosis = [diagnosis_dict[randint(0, len(diagnosis_dict)-1)] for _ in range(randint(0, 5))]
    random_diagnosis = [random_diagnosis[0]] + prior_vs_diagnosis + random_diagnosis[1:]

    return {"diagnosis":random_diagnosis, "procedure":random_procedure, "lab":random_lab}


def generate_random_dataset(lab_dict, procedure_dict, diagnosis_dict, time_dict, num_patients=10):
    patients = []

    for i in range(num_patients):
        patient = generate_random_patient(lab_dict, procedure_dict, diagnosis_dict, time_dict)
        patients.append(patient)

    return patients


# Assume these are already truncated
patients = generate_random_dataset(lab_dict, procedure_dict, diagnosis_dict, time_dict)
patients = pd.DataFrame(patients)
patients

Unnamed: 0,diagnosis,procedure,lab
0,"[[CLS], 02H64JZ, 8222, [VS], 100, 0JC00ZZ, 100...","[[VS], 51079017220, 00069419068, 63323027205, ...","[[VS], 51044_3, 51429_0, 52301_3, 52178_0, 509..."
1,"[[CLS], 8222, 8127, [VS], 8127, [VE], [LT], [R...","[[VS], 00078059620, 00641601310, [VE], [LT], [...","[[VS], 51044_3, 50908_4, 50908_4, 50841_4, [VE..."
2,"[[CLS], 0JC00ZZ, 0YHM0YZ, [VS], 8222, 7906, [V...","[[VS], 00078059620, 68084061221, [VE], [W_3], ...","[[VS], 51009_4, 51564_0, 50995_4, 50995_4, [VE..."
3,"[[CLS], 0JC00ZZ, 8127, [VS], 9652, 0JC00ZZ, 81...","[[VS], 63323027205, 00641601310, 00069419068, ...","[[VS], 51009_4, 52301_3, 51564_0, 51044_3, [VE..."
4,"[[CLS], 7906, 7906, [VS], 0YHM0YZ, 100, 0YHM0Y...","[[VS], 68084061221, 51079017220, 00641601310, ...","[[VS], 52301_3, 50841_4, 50908_4, 51009_4, 510..."
5,"[[CLS], 7906, 0JC00ZZ, 7906, 100, [VS], 7906, ...","[[VS], 00078059620, [VE], [M_2], [REG], [VS], ...","[[VS], 52301_3, 51512_0, [VE], [M_2], [REG], [..."
6,"[[CLS], [VS], 0B998ZZ, 02H64JZ, 0JC00ZZ, [VE],...","[[VS], 00078059620, 68084061221, 51079017220, ...","[[VS], 51044_3, 50841_4, 51512_0, 51429_0, 509..."
7,"[[CLS], 9652, 9652, 0YHM0YZ, 0JC00ZZ, [VS], 79...","[[VS], 00078037742, 68084061221, 00007550040, ...","[[VS], 51429_0, 51429_0, 51512_0, 51044_3, [VE..."
8,"[[CLS], 100, [VS], 0JC00ZZ, 9652, [VE], [M_2],...","[[VS], 00641601310, [VE], [M_2], [REG], [VS], ...","[[VS], 50995_4, 50841_4, 51564_0, 51429_0, 508..."
9,"[[CLS], 8127, [VS], 7906, [VE], [W_1], [REG], ...","[[VS], 00069419068, 00078037742, 00078059620, ...","[[VS], 50908_4, 50908_4, 51009_4, 52301_3, 509..."


In [23]:
# Vertical Alignment For Single Example
patient = patients.iloc[9]

procedure_ref = []
next_iter_procedure_ref = []
new_procedure = [[]]

d = 0; p = 0
while patient['diagnosis'][d] != '[VS]':
    new_procedure[-1].append('[PAD]')
    d += 1


while d < len(patient['diagnosis']):

    if patient['procedure'][p] == '[VE]' and patient['diagnosis'][d] != '[VE]':
        new_procedure[-1].append('[PAD]')
        d += 1
        continue

    elif patient['procedure'][p] != '[VE]' and patient['diagnosis'][d] == '[VE]':
        vs_index = len(new_procedure[-1]) - new_procedure[-1][::-1].index('[VS]') - 1
        procedure_ref.append((p, vs_index))

        while patient['procedure'][p] != '[VE]':
            p += 1

    new_procedure[-1].append(patient['procedure'][p])
    d += 1; p += 1



while procedure_ref or next_iter_procedure_ref:

    if next_iter_procedure_ref:
        procedure_ref = next_iter_procedure_ref.copy()
        next_iter_procedure_ref = []

    new_procedure.append([])

    for i, (ref, vs_index) in enumerate(procedure_ref):

        if len(new_procedure[-1]) == 0:
            n = 0
            while n <= vs_index:
                current_token = new_procedure[0][n]
                if current_token in special_tokens:
                    new_procedure[-1].append(current_token)
                else:
                    new_procedure[-1].append('[PAD]')
                n += 1

        n = vs_index + 1
        p = ref

        while patient['procedure'][p] != '[VE]' and new_procedure[0][n] != '[VE]':
            new_procedure[-1].append(patient['procedure'][p])
            n += 1; p += 1

        if patient['procedure'][p] != '[VE]' and new_procedure[0][n] == '[VE]':
            next_iter_procedure_ref.append((p, vs_index))

        procedure_ref.remove((ref, vs_index))

        if len(procedure_ref) == 0:
            next_vs_index = len(new_procedure[0]) - 1
        else:
            next_vs_index = procedure_ref[i][1]

        while n <= next_vs_index:
            current_token = new_procedure[0][n]
            if current_token in special_tokens:
                new_procedure[-1].append(current_token)
            else:
                new_procedure[-1].append('[PAD]')
            n += 1



print(f"Diagnosis:\n{patient['diagnosis']} \n\nOld Procedure:\n{patient['procedure']} \n\nNew Procedure:\n{new_procedure}\n")
print(f"Len check: {len(new_procedure[0]) == len(patient['diagnosis'])}")

Diagnosis:
['[CLS]', '0JC00ZZ', '[VS]', '100', '03743D6', '9652', '0JC00ZZ', '0B998ZZ', '7906', '[VE]', '[M_5]', '[REG]', '[VS]', '0JC00ZZ', '8222', '[VE]', '[LT]', '[REG]', '[VS]', '100', '0YHM0YZ', '03743D6', '02H64JZ', '03743D6', '[VE]', '[M_5]', '[REG]', '[VS]', '0YHM0YZ', '7906', '0YHM0YZ', '[VE]', '[M_2]', '[REG]', '[VS]', '8222', '7906', '8127', '[VE]', '[REG]'] 

Old Procedure:
['[VS]', '00641601310', '00007550040', '63323027205', '[VE]', '[M_5]', '[REG]', '[VS]', '00078037742', '[VE]', '[LT]', '[REG]', '[VS]', '63323027205', '68084061221', '00007550040', '63323027205', '00078037742', '00078037742', '[VE]', '[M_5]', '[REG]', '[VS]', '63323027205', '00069419068', '51079017220', '[VE]', '[M_2]', '[REG]', '[VS]', '00069419068', '63323027205', '00078037742', '63323027205', '51079017220', '63323027205', '[VE]', '[REG]'] 

New Procedure:
[['[PAD]', '[PAD]', '[VS]', '00641601310', '00007550040', '63323027205', '[PAD]', '[PAD]', '[PAD]', '[VE]', '[M_5]', '[REG]', '[VS]', '00078037742',

In [5]:
vocab_dict = {}

vocab_json_files = glob.glob(os.path.join(config.data_dir, '*_vocab.json'))
for file in vocab_json_files:
    vocab = json.load(open(file, 'r'))

    vocab_type = file.split("/")[-1].split(".")[0]
    vocab_dict[vocab_type] = vocab

combined_vocab = list(chain.from_iterable(list(vocab_dict.values())))

In [6]:
combined_vocab = special_tokens + combined_vocab
tokenizer_vocab = {token: i for i, token in enumerate(combined_vocab)}

In [7]:
tokenizer_object = Tokenizer(models.WordPiece(vocab=tokenizer_vocab, unk_token="[UNK]", max_input_chars_per_word=1000))
tokenizer_object.pre_tokenizer = pre_tokenizers.WhitespaceSplit()

example = " ".join(patients.iloc[5]['diagnosis'] + ["[UNK] [PAD] [PAD] [PAD] [PAD]"])
example2 = " ".join(patients.iloc[1]['lab'] + ["[UNK] [PAD] [PAD]"])
encoding = tokenizer_object.decode([0, 1, 2])
print(encoding)

[UNK] [PAD] [CLS]


In [15]:
tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer_object,
        bos_token="[VS]",
        eos_token="[VE]",
        unk_token="[UNK]",
        # sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
    )

tokenizer(
    [example, example2],
    return_attention_mask=True,
    return_token_type_ids=True,
    truncation=False,
    padding=True,
    max_length=2048,
    return_tensors="pt",
)



{'input_ids': tensor([[   2, 3064, 3066, 3064, 3070,    5, 3064, 3066, 3072,    6,   13,    3,
            5, 3070, 3072, 3067, 3070,    6,   24,    3,    5, 3064, 3068, 3072,
            6,    3,    0,    1,    1,    1,    1],
        [   5,   31,   28,   28,   32,    6,   24,    3,    5,   34,    6,    3,
            0,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1]])}

In [151]:
class ConceptTokenizer:
    """Tokenizer for event concepts."""
    def __init__(
        self,
        pad_token: str = "[PAD]",
        mask_token: str = "[MASK]",
        start_token: str = "[VS]",
        end_token: str = "[VE]",
        oov_token='-1',
        data_dir: str = 'data_files'
    ):
        self.tokenizer = Tokenizer(oov_token=oov_token, filters='', lower=False)
        self.mask_token = mask_token
        self.pad_token = pad_token
        self.special_tokens = [pad_token, mask_token, start_token, end_token] + \
            [f'W_{i}' for i in range(0, 4)] + [f'M_{i}' for i in range(0, 13)] + ['LT']
        self.data_dir = data_dir

    def fit_on_vocab(self) -> None:
        """Fit the tokenizer on the vocabulary."""
        vocab_json_files = glob.glob(os.path.join(self.data_dir, '*_vocab.json'))
        for file in vocab_json_files:
            vocab = json.load(open(file, 'r'))
            self.tokenizer.fit_on_texts(vocab)
        self.tokenizer.fit_on_texts(self.special_tokens)

    def encode(
        self,
        concept_sequences: Union[str, Sequence[str]],
        is_generator: bool=False
    ) -> Union[int, Sequence[int]]:
        """Encode the concept sequences into token ids."""
        return self.tokenizer.texts_to_sequences_generator(
            concept_sequences) if is_generator else self.tokenizer.texts_to_sequences(
            concept_sequences)

    def decode(
        self, concept_sequence_token_ids: Union[int, Sequence[int]]
    ) -> Sequence[str]:
        """Decode the concept sequence token ids into concepts."""
        return self.tokenizer.sequences_to_texts(concept_sequence_token_ids)

    def get_all_token_indexes(self) -> set:
        all_keys = set(self.tokenizer.index_word.keys())

        if self.tokenizer.oov_token is not None:
            all_keys.remove(self.tokenizer.word_index[self.tokenizer.oov_token])

        if self.special_tokens is not None:
            excluded = set(
                [self.tokenizer.word_index[special_token] for special_token in self.special_tokens])
            all_keys = all_keys - excluded
        return all_keys

    def get_first_token_index(self) -> int:
        return min(self.get_all_token_indexes())

    def get_last_token_index(self) -> int:
        return max(self.get_all_token_indexes())

    def get_vocab_size(self) -> int:
        # + 1 because oov_token takes the index 0
        return len(self.tokenizer.index_word) + 1

    def get_pad_token_id(self):
        pad_token_id = self.encode(self.pad_token)
        while isinstance(pad_token_id, list):
            pad_token_id = pad_token_id[0]
        return pad_token_id

    def get_mask_token_id(self):
        mask_token_id = self.encode(self.mask_token)
        while isinstance(mask_token_id, list):
            mask_token_id = mask_token_id[0]
        return mask_token_id

    def get_special_token_ids(self):
        special_ids = self.encode(self.special_tokens)
        flat_special_ids = [item[0] for item in special_ids]
        return flat_special_ids

In [152]:
tokenizer = ConceptTokenizer(data_dir=config.data_dir)
tokenizer.fit_on_vocab()
config.vocab_size = tokenizer.get_vocab_size()

In [153]:
tokenizer.decode([[0, 1, 2, 3, 4, 5]])

['-1 -1 50995_4 51429_0 51009_4 50908_4']

In [154]:
tokenizer.encode(["[PAD] [UNKoieri] [CLS] [VS]"])

[[20569, 1, 1, 20571]]

In [67]:
tokenizer.encode(patients.iloc[5])

[[4, 3071, 3062, 3069, 3065, 3064, 3069, 5, 23, 4, 3070, 3066, 3063, 3068, 5],
 [4, 15727, 5, 23, 4, 15731, 15731, 15729, 15731, 15731, 5],
 [4, 32, 5, 23, 4, 31, 5]]