In [None]:
import os


# NOTE(review): absolute, user-specific path — the notebook only runs on a
# machine where this directory exists; consider making it configurable.
ROOT = "/fs01/home/afallah/odyssey/slurm"
# Changing the working directory means every relative path used later in the
# notebook is resolved against ROOT.
os.chdir(ROOT)
# from models.cehr_bert.tokenizer import ConceptTokenizer
import glob
import json
import random
from itertools import chain
from random import randint
from typing import Sequence, Union

import pandas as pd
import torch
from tokenizers import Tokenizer, models, pre_tokenizers

# from keras.preprocessing.text import Tokenizer
from transformers import PreTrainedTokenizerFast


%matplotlib inline

# Data lives in a "data" folder directly under ROOT; the patient sequences are
# stored there as a single parquet file.
DATA_ROOT = ROOT + "/data"
DATA_PATH = DATA_ROOT + "/patient_sequences.parquet"
# Fixed token inventory, in id order: the seven structural tokens, then
# "[W_0]".."[W_3]", then "[M_0]".."[M_12]", and finally "[LT]".
special_tokens = [
    "[UNK]",
    "[PAD]",
    "[CLS]",
    "[REG]",
    "[MASK]",
    "[VS]",
    "[VE]",
    *(f"[W_{idx}]" for idx in range(4)),
    *(f"[M_{idx}]" for idx in range(13)),
    "[LT]",
]

In [None]:
class config:
    """Notebook-wide settings, read directly as class attributes."""

    seed = 23
    # Directory that is searched for the "*_vocab.json" files.
    data_dir = DATA_ROOT
    test_size = 0.2
    batch_size = 64
    num_workers = 2
    # Unknown until a tokenizer has been fit; assigned later in the notebook.
    vocab_size = None
    embedding_size = 128
    time_embeddings_size = 16
    max_len = 512
    # Prefer the GPU, but fall back to CPU so the notebook still runs end to end
    # on a machine without a usable CUDA device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# # Load data
# data = pd.read_parquet(DATA_PATH)
# data.rename(columns={'event_tokens': 'event_tokens_untruncated', 'event_tokens_updated': 'event_tokens'}, inplace=True)
# data

In [None]:
### Define random dataset using the actual vocabulary ###
# Small hand-picked samples of codes; the generators below draw from these
# lists uniformly at random.
lab_dict = [
    "50995_4", "51429_0", "51009_4", "50908_4", "51564_0",
    "52301_3", "51044_3", "50841_4", "52178_0", "51512_0",
]  # fmt: skip
procedure_dict = [
    "00641601310", "00069419068", "00078037742", "63323027205",
    "00078059620", "00007550040", "51079017220", "68084061221",
]  # fmt: skip
diagnosis_dict = [
    "02H64JZ", "7906", "8222", "0JC00ZZ", "0YHM0YZ",
    "9652", "03743D6", "100", "0B998ZZ", "8127",
]  # fmt: skip
# Interval tokens that may be placed between two consecutive visits.
time_dict = ["[W_1]", "[W_3]", "[M_2]", "[M_5]", "[LT]"]


def generate_random_events(num_visits, event_dict, time_tokens):
    """Build one synthetic token sequence consisting of ``num_visits`` visits.

    Every visit is emitted as ``"[VS]"``, between 1 and 6 events sampled
    uniformly (with replacement) from ``event_dict``, ``"[VE]"``, then
    ``time_tokens[i]`` for every visit except the last one, and finally
    ``"[REG]"``.

    Args:
        num_visits: Number of visit blocks to generate.
        event_dict: Indexable collection of event codes to sample from.
        time_tokens: Interval tokens; entry ``i`` follows visit ``i``.

    Returns:
        The flat list of tokens for all visits.
    """
    # All visit sizes are drawn first, before any event is sampled.
    visit_sizes = [randint(1, 6) for _ in range(num_visits)]
    final_visit = num_visits - 1

    sequence = []
    for visit_idx, visit_size in enumerate(visit_sizes):
        sampled_events = [
            event_dict[randint(0, len(event_dict) - 1)] for _ in range(visit_size)
        ]
        sequence.extend(["[VS]", *sampled_events, "[VE]"])

        # No interval token after the final visit.
        if visit_idx != final_visit:
            sequence.append(time_tokens[visit_idx])

        sequence.append("[REG]")

    return sequence


def generate_random_patient(lab_dict, procedure_dict, diagnosis_dict, time_dict):
    """Generate one synthetic patient with lab, procedure and diagnosis sequences.

    The three sequences share the same number of visits (1-5) and the same
    interval tokens, but each samples its own events. The diagnosis sequence
    additionally starts with ``"[CLS]"`` followed by 0-5 diagnosis codes placed
    before the first visit.

    Returns:
        A dict with the keys ``"diagnosis"``, ``"procedure"`` and ``"lab"``,
        each mapping to a list of tokens.
    """
    num_visits = randint(1, 5)
    last_time_idx = len(time_dict) - 1
    time_tokens = [time_dict[randint(0, last_time_idx)] for _ in range(num_visits)]

    # Sampling order is lab -> procedure -> diagnosis -> pre-visit diagnoses.
    lab_events = generate_random_events(num_visits, lab_dict, time_tokens)
    procedure_events = generate_random_events(num_visits, procedure_dict, time_tokens)
    diagnosis_events = generate_random_events(num_visits, diagnosis_dict, time_tokens)

    num_prior = randint(0, 5)
    prior_diagnoses = [
        diagnosis_dict[randint(0, len(diagnosis_dict) - 1)] for _ in range(num_prior)
    ]

    return {
        "diagnosis": ["[CLS]", *prior_diagnoses, *diagnosis_events],
        "procedure": procedure_events,
        "lab": lab_events,
    }


def generate_random_dataset(
    lab_dict,
    procedure_dict,
    diagnosis_dict,
    time_dict,
    num_patients=10,
):
    """Generate ``num_patients`` independent synthetic patients.

    Returns:
        A list with one ``generate_random_patient`` result per patient.
    """
    return [
        generate_random_patient(lab_dict, procedure_dict, diagnosis_dict, time_dict)
        for _ in range(num_patients)
    ]


# Seed the RNG so the synthetic cohort — and every later cell that indexes into
# it by position — is identical on each Restart & Run All.
random.seed(config.seed)

# Assume these are already truncated
patient_records = generate_random_dataset(
    lab_dict,
    procedure_dict,
    diagnosis_dict,
    time_dict,
)
# One row per patient; columns are "diagnosis", "procedure" and "lab".
patients = pd.DataFrame(patient_records)
patients

In [None]:
# Vertical Alignment For Single Example
#
# Lays the procedure sequence out in rows that line up, position by position,
# with the diagnosis sequence:
#   * row 0 holds, for every visit, as many procedure events as the diagnosis
#     visit has slots ("[PAD]" where the procedure visit is shorter);
#   * each further row holds the events that did not fit into the previous row,
#     placed right after the "[VS]" of the visit they belong to.
patient = patients.iloc[9]
diagnosis_seq = patient["diagnosis"]
procedure_seq = patient["procedure"]


def _special_or_pad(token):
    """Keep structural tokens unchanged; replace any other token with "[PAD]"."""
    return token if token in special_tokens else "[PAD]"


# (procedure index of the first event that did not fit, index of that visit's "[VS]")
procedure_ref = []
next_iter_procedure_ref = []
new_procedure = [[]]

d = 0
p = 0
# Everything before the first "[VS]" of the diagnosis sequence has no procedure
# counterpart, so row 0 is padded there.
while d < len(diagnosis_seq) and diagnosis_seq[d] != "[VS]":
    new_procedure[-1].append("[PAD]")
    d += 1


while d < len(diagnosis_seq):
    if procedure_seq[p] == "[VE]" and diagnosis_seq[d] != "[VE]":
        # Procedure visit ended first: pad until the diagnosis visit ends too.
        new_procedure[-1].append("[PAD]")
        d += 1
        continue

    elif procedure_seq[p] != "[VE]" and diagnosis_seq[d] == "[VE]":
        # Diagnosis visit ended first: remember where the overflow starts and
        # skip the procedure cursor ahead to this visit's "[VE]".
        vs_index = len(new_procedure[-1]) - new_procedure[-1][::-1].index("[VS]") - 1
        procedure_ref.append((p, vs_index))

        while procedure_seq[p] != "[VE]":
            p += 1

    new_procedure[-1].append(procedure_seq[p])
    d += 1
    p += 1


# One extra row per pass. The list being iterated is never mutated: removing
# entries while enumerating it skipped every second overflow visit and could
# index past the end of the list when looking up the next visit.
while procedure_ref:
    next_iter_procedure_ref = []
    new_procedure.append([])

    for i, (ref, vs_index) in enumerate(procedure_ref):
        if len(new_procedure[-1]) == 0:
            # First overflow visit of this row: mirror row 0 up to its "[VS]".
            n = 0
            while n <= vs_index:
                new_procedure[-1].append(_special_or_pad(new_procedure[0][n]))
                n += 1

        n = vs_index + 1
        p = ref

        # Copy overflow events until either the procedure visit or the slots of
        # this visit (as laid out in row 0) run out.
        while procedure_seq[p] != "[VE]" and new_procedure[0][n] != "[VE]":
            new_procedure[-1].append(procedure_seq[p])
            n += 1
            p += 1

        if procedure_seq[p] != "[VE]" and new_procedure[0][n] == "[VE]":
            # Still more events than slots: carry the rest over to the next row.
            next_iter_procedure_ref.append((p, vs_index))

        # Mirror row 0 up to the "[VS]" of the next overflow visit, or to the
        # end of the row when this was the last one.
        if i + 1 < len(procedure_ref):
            next_vs_index = procedure_ref[i + 1][1]
        else:
            next_vs_index = len(new_procedure[0]) - 1

        while n <= next_vs_index:
            new_procedure[-1].append(_special_or_pad(new_procedure[0][n]))
            n += 1

    procedure_ref = next_iter_procedure_ref


print(
    f"Diagnosis:\n{patient['diagnosis']} \n\nOld Procedure:\n{patient['procedure']} \n\nNew Procedure:\n{new_procedure}\n",
)
print(f"Len check: {len(new_procedure[0]) == len(patient['diagnosis'])}")

In [None]:
# Load every "*_vocab.json" file in the data directory, keyed by the file name
# up to its first "." (e.g. "lab_vocab.json" -> "lab_vocab").
vocab_dict = {}

# NOTE(review): glob returns matches in arbitrary order, so the order of
# `combined_vocab` (and any ids derived from it) may vary between file systems.
vocab_json_files = glob.glob(os.path.join(config.data_dir, "*_vocab.json"))
for file in vocab_json_files:
    # The context manager closes the handle as soon as the file is parsed.
    with open(file, "r") as vocab_file:
        vocab = json.load(vocab_file)

    vocab_type = file.split("/")[-1].split(".")[0]
    vocab_dict[vocab_type] = vocab

# Flatten the per-file vocabularies into a single list.
combined_vocab = list(chain.from_iterable(vocab_dict.values()))

In [None]:
# Special tokens go first so they receive the lowest ids. dict.fromkeys drops
# repeated tokens while keeping first-seen order, which makes this cell safe to
# re-run (the special tokens are not prepended a second time) and guarantees
# that the ids below are contiguous.
combined_vocab = list(dict.fromkeys(special_tokens + combined_vocab))
tokenizer_vocab = {token: i for i, token in enumerate(combined_vocab)}

In [None]:
# WordPiece model over the combined vocabulary; input is split on whitespace
# only.
wordpiece_model = models.WordPiece(
    vocab=tokenizer_vocab,
    unk_token="[UNK]",
    max_input_chars_per_word=1000,
)
tokenizer_object = Tokenizer(wordpiece_model)
tokenizer_object.pre_tokenizer = pre_tokenizers.WhitespaceSplit()

# Two whitespace-joined example sequences with a few extra tokens appended.
example = " ".join([*patients.iloc[5]["diagnosis"], "[UNK] [PAD] [PAD] [PAD] [PAD]"])
example2 = " ".join([*patients.iloc[1]["lab"], "[UNK] [PAD] [PAD]"])
# Despite its name, `encoding` holds the text decoded from ids 0, 1 and 2.
encoding = tokenizer_object.decode([0, 1, 2])
print(encoding)

In [None]:
# Map the notebook's structural tokens onto the tokenizer's special-token roles.
special_token_roles = {
    "bos_token": "[VS]",
    "eos_token": "[VE]",
    "unk_token": "[UNK]",
    # "sep_token": "[SEP]",
    "pad_token": "[PAD]",
    "cls_token": "[CLS]",
    "mask_token": "[MASK]",
}
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer_object,
    **special_token_roles,
)

# Encode both examples as one padded batch of tensors.
batch_options = {
    "return_attention_mask": True,
    "return_token_type_ids": True,
    "truncation": False,
    "padding": True,
    "max_length": 2048,
    "return_tensors": "pt",
}
tokenizer([example, example2], **batch_options)

In [None]:
class ConceptTokenizer:
    """Tokenizer for event concepts.

    Thin wrapper around a word-level tokenizer object that is fit on the
    ``*_vocab.json`` files found in ``data_dir`` plus a fixed set of special
    tokens.

    NOTE(review): the calls made on ``self.tokenizer`` (``fit_on_texts``,
    ``texts_to_sequences``, ``word_index``, ``index_word``, ``oov_token``) look
    like the Keras ``keras.preprocessing.text.Tokenizer`` API, but the only
    ``Tokenizer`` imported in this notebook comes from ``tokenizers`` (the Keras
    import is commented out) — confirm which class is meant.
    """

    def __init__(
        self,
        pad_token: str = "[PAD]",
        mask_token: str = "[MASK]",
        start_token: str = "[VS]",
        end_token: str = "[VE]",
        oov_token="-1",
        data_dir: str = "data_files",
    ):
        """Create an unfitted tokenizer.

        Args:
            pad_token: Token used for padding.
            mask_token: Token used for masking.
            start_token: Token that opens a visit.
            end_token: Token that closes a visit.
            oov_token: Token substituted for out-of-vocabulary words.
            data_dir: Directory searched for ``*_vocab.json`` files.
        """
        self.tokenizer = Tokenizer(oov_token=oov_token, filters="", lower=False)
        self.mask_token = mask_token
        self.pad_token = pad_token
        # Interval tokens are bracketed, matching the form used by the
        # module-level `special_tokens` list and by the generated sequences.
        self.special_tokens = (
            [pad_token, mask_token, start_token, end_token]
            + [f"[W_{i}]" for i in range(0, 4)]
            + [f"[M_{i}]" for i in range(0, 13)]
            + ["[LT]"]
        )
        self.data_dir = data_dir

    def fit_on_vocab(self) -> None:
        """Fit the tokenizer on the vocabulary files and the special tokens."""
        vocab_json_files = glob.glob(os.path.join(self.data_dir, "*_vocab.json"))
        for file in vocab_json_files:
            # Close each vocabulary file as soon as it has been parsed.
            with open(file, "r") as vocab_file:
                vocab = json.load(vocab_file)
            self.tokenizer.fit_on_texts(vocab)
        self.tokenizer.fit_on_texts(self.special_tokens)

    def encode(
        self,
        concept_sequences: Union[str, Sequence[str]],
        is_generator: bool = False,
    ) -> Union[int, Sequence[int]]:
        """Encode the concept sequences into token ids.

        A single string is treated as one sequence. Without the wrapping below
        the underlying tokenizer would iterate over the string character by
        character and encode every character as its own text.

        Args:
            concept_sequences: One sequence, or a collection of sequences.
            is_generator: When true, return a lazy generator instead of a list.

        Returns:
            One list of token ids per input sequence.
        """
        if isinstance(concept_sequences, str):
            concept_sequences = [concept_sequences]
        return (
            self.tokenizer.texts_to_sequences_generator(concept_sequences)
            if is_generator
            else self.tokenizer.texts_to_sequences(concept_sequences)
        )

    def decode(
        self,
        concept_sequence_token_ids: Union[int, Sequence[int]],
    ) -> Sequence[str]:
        """Decode the concept sequence token ids into concepts."""
        return self.tokenizer.sequences_to_texts(concept_sequence_token_ids)

    def get_all_token_indexes(self) -> set:
        """Return every token id except the OOV id and the special-token ids."""
        all_keys = set(self.tokenizer.index_word.keys())

        if self.tokenizer.oov_token is not None:
            all_keys.remove(self.tokenizer.word_index[self.tokenizer.oov_token])

        if self.special_tokens is not None:
            excluded = {
                self.tokenizer.word_index[special_token]
                for special_token in self.special_tokens
            }
            all_keys = all_keys - excluded
        return all_keys

    def get_first_token_index(self) -> int:
        """Return the smallest non-special, non-OOV token id."""
        return min(self.get_all_token_indexes())

    def get_last_token_index(self) -> int:
        """Return the largest non-special, non-OOV token id."""
        return max(self.get_all_token_indexes())

    def get_vocab_size(self) -> int:
        """Return the number of ids an embedding table must cover."""
        # + 1 for index 0.
        # NOTE(review): with a Keras-style tokenizer index 0 is reserved and
        # never assigned to a word (the OOV token gets a regular index), so the
        # largest id equals len(index_word) — confirm for the tokenizer in use.
        return len(self.tokenizer.index_word) + 1

    def _single_token_id(self, token: str):
        """Encode one token and unwrap the nested result down to a bare id."""
        token_id = self.encode(token)
        while isinstance(token_id, list):
            token_id = token_id[0]
        return token_id

    def get_pad_token_id(self):
        """Return the id of the padding token."""
        return self._single_token_id(self.pad_token)

    def get_mask_token_id(self):
        """Return the id of the mask token."""
        return self._single_token_id(self.mask_token)

    def get_special_token_ids(self):
        """Return the ids of all special tokens, in ``special_tokens`` order."""
        special_ids = self.encode(self.special_tokens)
        flat_special_ids = [item[0] for item in special_ids]
        return flat_special_ids

In [None]:
# Rebinds `tokenizer` (until now the PreTrainedTokenizerFast instance) to a
# ConceptTokenizer fit on the vocabulary files in the data directory, then
# records the resulting vocabulary size on the shared config.
tokenizer = ConceptTokenizer(data_dir=config.data_dir)
tokenizer.fit_on_vocab()
config.vocab_size = tokenizer.get_vocab_size()

In [None]:
# Decode ids 0-5 to inspect which tokens the lowest ids map to.
tokenizer.decode([[0, 1, 2, 3, 4, 5]])

In [None]:
# Encode one text that mixes special tokens with "[UNKoieri]" — presumably a
# deliberately unknown token to see how out-of-vocabulary input is handled.
tokenizer.encode(["[PAD] [UNKoieri] [CLS] [VS]"])

In [None]:
# Encode the row at position 5 of `patients`; the row is passed as-is, so each
# of its values (the diagnosis, procedure and lab token lists) is encoded as a
# separate sequence.
tokenizer.encode(patients.iloc[5])