In [None]:
import os


# Repository root used as the base for every path in this notebook.
# Defaults to the original cluster checkout; set the ODYSSEY_ROOT environment
# variable to run the notebook from a different checkout without editing it.
ROOT = os.environ.get("ODYSSEY_ROOT", "/fs01/home/afallah/odyssey/odyssey")
# Make the repository root the working directory for the rest of the session
# (presumably so project-relative imports and paths resolve — confirm).
os.chdir(ROOT)
# from models.cehr_bert.tokenizer import ConceptTokenizer
import glob
import json
from itertools import chain

import pandas as pd
import torch
from tokenizers import Tokenizer, models, pre_tokenizers

# from keras.preprocessing.text import Tokenizer
from transformers import PreTrainedTokenizerFast


%matplotlib inline

# Data directory for this run: the pretraining parquet, the "*_vocab.json" files
# and the saved tokenizer all live directly inside it.
# NOTE(review): "2048" and "one_month" look like dataset-variant identifiers — confirm.
DATA_ROOT = f"{ROOT}/data/slurm_data/2048/one_month"
# Parquet file that is loaded into the `patients` DataFrame.
DATA_PATH = f"{DATA_ROOT}/pretrain.parquet"
# File the tokenizer is saved to and later reloaded from.
TOKENIZER_PATH = f"{DATA_ROOT}/tokenizer.json"
# Reserved tokens, in the order they are placed at the front of the vocabulary:
# six control tokens, then [W_0]..[W_3], then [M_0]..[M_12], then [LT].
# [VS] / [VE] are the tokens later registered as bos / eos on the fast tokenizer.
special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[MASK]", "[VS]", "[VE]"]
special_tokens.extend(f"[W_{i}]" for i in range(4))
special_tokens.extend(f"[M_{i}]" for i in range(13))
special_tokens.append("[LT]")

# TODO: add a [REG] token to the special tokens.

In [None]:
class config:
    """Notebook-wide settings, read as class attributes (e.g. ``config.data_dir``).

    Only ``data_dir`` is used by the visible cells; the remaining attributes are
    presumably consumed by later model/training code — confirm before changing.
    """

    # presumably the random seed for later stochastic steps; not applied in the visible cells
    seed = 23
    # Directory searched for the "*_vocab.json" files.
    data_dir = DATA_ROOT
    # presumably the held-out fraction for a train/test split — confirm
    test_size = 0.2
    # presumably DataLoader settings — confirm
    batch_size = 64
    num_workers = 2
    # Placeholder; not assigned anywhere in the visible cells.
    vocab_size = None
    embedding_size = 128
    time_embeddings_size = 16
    # NOTE(review): the sample tokenizer call in this notebook uses max_length=2048,
    # not this value — confirm which limit is intended.
    max_len = 512
    # Prefer the GPU, but fall back to CPU so the notebook still runs on machines
    # without CUDA instead of failing the first time the device is used.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Read the pretraining parquet file into a DataFrame and display it
patients = pd.read_parquet(path=DATA_PATH)
patients

In [None]:
# Create dictionary of all possible medical concepts.
# Each "*_vocab.json" file in the data directory is loaded and stored under its
# file name up to the first "." (e.g. "x_vocab.json" -> "x_vocab").
vocab_dict = {}

# glob.glob() returns matches in arbitrary (filesystem-dependent) order, and that
# order decides the order of `combined_vocab` and therefore the token ids built
# from it. Sorting makes the ids reproducible across runs and machines.
# NOTE(review): this may reassign ids relative to a tokenizer saved before this
# change — regenerate or verify any artefact that depends on the old ids.
vocab_json_files = sorted(glob.glob(os.path.join(config.data_dir, "*_vocab.json")))
for file in vocab_json_files:
    # Context manager so the file handle is closed right after parsing
    # (the previous bare open() left it to the garbage collector).
    with open(file, "r") as vocab_file:
        vocab = json.load(vocab_file)

    vocab_type = os.path.basename(file).split(".")[0]
    vocab_dict[vocab_type] = vocab

# Flatten the per-file vocabularies into one list, in file order.
combined_vocab = list(chain.from_iterable(vocab_dict.values()))

In [None]:
# Create the tokenizer dictionary.
# Special tokens go first so they receive the lowest ids. dict.fromkeys() removes
# repeated tokens while keeping first-seen order, which:
#   * makes this cell safe to re-run — previously every re-run prepended the
#     special tokens again, moving [PAD], [UNK], ... away from ids 0, 1, ...;
#   * keeps ids contiguous when the same token appears in more than one vocab
#     file — previously the duplicate took the later index and left a gap.
combined_vocab = list(dict.fromkeys(special_tokens + combined_vocab))
tokenizer_vocab = {token: i for i, token in enumerate(combined_vocab)}

# Create the tokenizer object: a WordPiece model over the fixed vocabulary above,
# with unknown words mapped to [UNK].
tokenizer_object = Tokenizer(
    models.WordPiece(
        vocab=tokenizer_vocab,
        unk_token="[UNK]",
        # presumably raised so long concept codes are not rejected for length — confirm
        max_input_chars_per_word=1000,
    ),
)
# Split input on whitespace only, so each space-separated concept is one word.
tokenizer_object.pre_tokenizer = pre_tokenizers.WhitespaceSplit()

In [None]:
# Test Examples
# Build two whitespace-separated sample strings from patient rows, each followed
# by a few special tokens.
# list(...) makes `+` a list concatenation whatever container the column holds.
# NOTE(review): list-typed parquet columns commonly load as NumPy arrays, in which
# case `array + [str]` would broadcast element-wise instead of appending — confirm
# the column types.
example = " ".join(
    list(patients.iloc[5]["diagnosis"]) + ["[UNK] [PAD] [PAD] [PAD] [PAD]"]
)
example2 = " ".join(list(patients.iloc[1]["lab"]) + ["[UNK] [PAD] [PAD]"])

# Sanity check of the id -> token mapping: ids 0, 1, 2 are the first three
# vocabulary entries. Despite its name, `encoding` holds the decoded string.
encoding = tokenizer_object.decode([0, 1, 2])
print(encoding)

In [None]:
# Write the tokenizer definition to TOKENIZER_PATH so it can be reloaded from disk
tokenizer_object.save(TOKENIZER_PATH)

In [None]:
# Create tokenizer: reload the saved tokenizer file as a Hugging Face fast
# tokenizer and register which vocabulary entries play each special role.
special_token_roles = {
    "bos_token": "[VS]",
    "eos_token": "[VE]",
    "unk_token": "[UNK]",
    # "sep_token": "[SEP]",
    "pad_token": "[PAD]",
    "cls_token": "[CLS]",
    "mask_token": "[MASK]",
}
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=TOKENIZER_PATH,
    **special_token_roles,
)

# Encode both sample strings as one padded, truncated batch of PyTorch tensors.
# The call stays the last expression so the notebook displays the result.
encode_options = {
    "return_attention_mask": True,
    "return_token_type_ids": True,
    "truncation": True,
    "padding": True,
    "max_length": 2048,
    "return_tensors": "pt",
}
tokenizer([example, example2], **encode_options)