In [1]:
ROOT = "/fs01/home/afallah/odyssey/odyssey"

from typing import Any, Tuple

import torch
import pickle
import os; os.chdir(ROOT)
import pandas as pd
from torch.utils.data import DataLoader

from models.big_bird_cehr.model import BigBirdPretrain, BigBirdFinetune
from models.big_bird_cehr.data import PretrainDataset, FinetuneDataset
from models.big_bird_cehr.tokenizer import HuggingFaceConceptTokenizer


DATA_ROOT = f"{ROOT}/data/slurm_data/2048/one_month"
DATA_PATH = f"{DATA_ROOT}/fine_test.parquet"
NEW_DATA_PATH = f"{ROOT}/data/bigbird_data/patient_sequences_2048_labeled.parquet"

In [2]:
data = pd.read_parquet("/h/afallah/odyssey/odyssey/data/bigbird_data/patient_sequences_2048_labeled.parquet")
patient_ids = pickle.load(open('/h/afallah/odyssey/odyssey/data/bigbird_data/dataset_2048_mortality_1month.pkl', 'rb'))
pre_data = data.loc[data['patient_id'].isin(patient_ids['test'])]
pre_data.rename(columns={'label_mortality_1month': 'label'}, inplace=True)
# Train Tokenizer
tokenizer = HuggingFaceConceptTokenizer(data_dir="/h/afallah/odyssey/odyssey/data/vocab")
tokenizer.fit_on_vocab()

# Load datasets
finetune_dataset = FinetuneDataset(
    data=pre_data,
    tokenizer=tokenizer,
    max_len=2048,
)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  pre_data.rename(columns={'label_mortality_1month': 'label'}, inplace=True)


In [3]:
train_loader = DataLoader(
        finetune_dataset,
        batch_size=2,
    )

In [4]:
finetune_dataset[0]

{'concept_ids': tensor([    5,     3, 10580,  ...,     0,     0,     0]),
 'type_ids': tensor([1, 2, 7,  ..., 0, 0, 0]),
 'ages': tensor([ 0, 76, 76,  ...,  0,  0,  0]),
 'time_stamps': tensor([   0, 5670, 5670,  ...,    0,    0,    0]),
 'visit_orders': tensor([   0,    2,    2,  ..., 2049, 2049, 2049]),
 'visit_segments': tensor([0, 2, 2,  ..., 0, 0, 0]),
 'labels': tensor(0),
 'attention_mask': tensor([1, 1, 1,  ..., 0, 0, 0])}

In [5]:
args = pd.DataFrame({'batch_size':[2], 'gpus':[1], 'max_epochs':[5]})

In [6]:
pretrained_model = BigBirdPretrain(
        args=args,
        dataset_len=len(finetune_dataset),
        vocab_size=tokenizer.get_vocab_size(),
        padding_idx=tokenizer.get_pad_token_id(),
    )
pretrained_model.load_state_dict(torch.load('/h/afallah/odyssey/odyssey/checkpoints/bigbird_pretraining_a100/best.ckpt')["state_dict"])

  self.warmup = int(0.1 * grad_steps)
  self.decay = int(0.9 * grad_steps)


<All keys matched successfully>

In [7]:
model = BigBirdFinetune(
        args=args,
        dataset_len=len(finetune_dataset),
        pretrained_model=pretrained_model,
    )

  self.warmup = int(0.1 * grad_steps)
  self.decay = int(0.9 * grad_steps)


In [8]:
batch = next(iter(train_loader))
at = batch.pop('attention_mask')
labels = batch.pop('labels')
batch

{'concept_ids': tensor([[    5,     3, 10580,  ...,     0,     0,     0],
         [    5,     3, 14013,  ...,     0,     0,     0]]),
 'type_ids': tensor([[1, 2, 7,  ..., 0, 0, 0],
         [1, 2, 7,  ..., 0, 0, 0]]),
 'ages': tensor([[ 0, 76, 76,  ...,  0,  0,  0],
         [ 0, 47, 47,  ...,  0,  0,  0]]),
 'time_stamps': tensor([[   0, 5670, 5670,  ...,    0,    0,    0],
         [   0, 5857, 5857,  ...,    0,    0,    0]]),
 'visit_orders': tensor([[   0,    2,    2,  ..., 2049, 2049, 2049],
         [   0,    2,    2,  ..., 2049, 2049, 2049]]),
 'visit_segments': tensor([[0, 2, 2,  ..., 0, 0, 0],
         [0, 2, 2,  ..., 0, 0, 0]])}

In [9]:
pretrained_model(inputs=tuple(batch.values()), attention_mask=at, labels=None).logits.shape

torch.Size([2, 2048, 20592])

In [15]:
model(inputs=tuple(batch.values()), attention_mask=at, labels=labels, return_dict=True)

SequenceClassifierOutput(loss=tensor(0.7985, grad_fn=<NllLossBackward0>), logits=tensor([[-0.0548,  0.3878],
        [-0.0134, -0.0844]], grad_fn=<AddmmBackward0>), hidden_states=(tensor([[[-6.7883e-01, -2.1554e-01, -1.0080e+00,  ..., -4.0627e-02,
          -1.9321e-02, -2.8757e+00],
         [-3.4303e-01,  1.1852e+00, -1.3026e-01,  ...,  1.3585e+00,
          -7.3872e-01,  2.7645e-01],
         [-3.5166e-02,  1.1023e+00,  8.1303e-02,  ...,  2.3114e-02,
           5.4960e-01,  2.8621e-01],
         ...,
         [-1.1241e+00, -3.8402e-01, -1.1593e+00,  ..., -1.2536e-03,
           1.1522e-02, -3.2608e+00],
         [-5.1603e-01,  4.8416e-01, -9.5265e-01,  ..., -2.5094e+00,
           1.4269e-01, -3.4499e+00],
         [-1.8325e-01, -1.2168e-01,  3.0352e-01,  ..., -2.6697e+00,
           9.0409e-01, -2.5874e+00]],

        [[-6.7366e-01, -2.1113e-01, -1.0023e+00,  ..., -3.6522e-02,
           1.4872e-02,  2.7641e-02],
         [ 7.2113e-01, -4.1033e-02, -3.6217e-01,  ...,  6.0504e-01,
 

In [3]:
def mask_tokens(self, sequence: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Mask the tokens in the sequence using vectorized operations."""
    mask_token_id = self.tokenizer.get_mask_token_id()

    masked_sequence = sequence.clone()

    # Ignore [PAD], [UNK], [MASK] tokens
    prob_matrix = torch.full(masked_sequence.shape, self.mask_prob)
    prob_matrix[torch.where(masked_sequence <= mask_token_id)] = 0
    selected = torch.bernoulli(prob_matrix).bool()

    # 80% of the time, replace masked input tokens with respective mask tokens
    replaced = torch.bernoulli(torch.full(selected.shape, 0.8)).bool() & selected
    masked_sequence[replaced] = mask_token_id

    # 10% of the time, we replace masked input tokens with random vector.
    randomized = torch.bernoulli(torch.full(selected.shape, 0.1)).bool() & selected & ~replaced
    random_idx = torch.randint(low=self.tokenizer.get_first_token_index(),
                               high=self.tokenizer.get_last_token_index(),
                               size=prob_matrix.shape, dtype=torch.long)
    masked_sequence[randomized] = random_idx[randomized]

    labels = torch.where(selected, sequence, -100)

    return masked_sequence, labels

In [14]:
len(set(train_dataset[0]['type_ids'].tolist()))

9

In [None]:
patients = pd.read_parquet(NEW_DATA_PATH)
patients

In [3]:
tokenizer = HuggingFaceConceptTokenizer(data_dir=DATA_ROOT)
tokenizer.fit_on_vocab()

train_dataset = PretrainDataset(
    data=patients,
    tokenizer=tokenizer,
    max_len=2048,
    mask_prob=0.15,
)

In [6]:
e1 = "[CLS] [VS] 00054853516 00245008201 00338004904 00008084199 00045152510 00006003121"
e2 = "[CLS] [VS] 00054853516 00245008201"

In [38]:
tokenizer(patients["event_tokens_2048"].iloc[0])

{'input_ids': tensor([[    5,     0,     0,  ...,     0,     0,     0],
        [    3,     0,     0,  ...,     0,     0,     0],
        [12809,     0,     0,  ...,     0,     0,     0],
        ...,
        [ 1352,     0,     0,  ...,     0,     0,     0],
        [    4,     0,     0,  ...,     0,     0,     0],
        [    6,     0,     0,  ...,     0,     0,     0]]), 'attention_mask': tensor([[1, 0, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0],
        ...,
        [1, 0, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0]])}

In [None]:
patients = patients[patients["event_tokens_2048"].notnull()]

tokenizer = ConceptTokenizer(data_dir=DATA_ROOT)
tokenizer.fit_on_vocab()

train_dataset = PretrainDataset(
    data=patients,
    tokenizer=tokenizer,
    max_len=2048,
    mask_prob=1,
)

In [None]:
tokenizer.decode([[3]])

In [None]:
patients.iloc[1]

In [None]:
train_dataset[0]

In [None]:
print(train_dataset[0]["attention_mask"])

In [None]:
tokenizer.get_mask_token_id()

In [None]:
print(list(train_dataset[110]["concept_ids"]).count(20569))

In [None]:
print(len(train_dataset[110]["concept_ids"]))

In [None]:
tokenizer.get_pad_token_id()

In [None]:
tokenizer.encode(["[PAD]"])

In [None]:
patients.iloc[0]["event_tokens_2048"]

In [None]:
ROOT = "/fs01/home/afallah/odyssey/odyssey"

import os


os.chdir(ROOT)
import numpy as np
import pandas as pd
from tqdm import tqdm

from models.big_bird_cehr.data import PretrainDataset
from models.big_bird_cehr.tokenizer import ConceptTokenizer


DATA_ROOT = f"{ROOT}/data/slurm_data/2048/one_month"
DATA_PATH = f"{DATA_ROOT}/fine_test.parquet"
patients = pd.read_parquet(DATA_PATH)
patients
# Find the unique set of all possible tokens, including special tokens
unique_event_tokens = set()

for patient_event_tokens in tqdm(
        patients["event_tokens_2048"].values, desc="Loading Tokens", unit=" Patients",
):
    for event_token in patient_event_tokens:
        unique_event_tokens.add(event_token)

unique_event_tokens = list(unique_event_tokens)
unique_event_tokens.sort(reverse=True)

print(
    f"Complete list of unique event tokens\nLength: {len(unique_event_tokens)}\nHead: {unique_event_tokens[:30]}...",
)
special_tokens = [
    "[CLS]",
    "[PAD]",
    # "[VS]",
    "[VE]",
    "[W_0]",
    "[W_1]",
    "[W_2]",
    "[W_3]",
    *[f"[M_{i}]" for i in range(0, 13)],
    "[LT]",
]

feature_event_tokens = [token for token in unique_event_tokens if token not in special_tokens]

print(len(feature_event_tokens), feature_event_tokens[:20])
patients_event_tokens = patients["event_tokens_2048"]
len_vocab = len(feature_event_tokens)
token2id = {token: i for i, token in enumerate(feature_event_tokens)}
token_correlations = np.zeros(shape=(len_vocab, len_vocab))
token_frequencies = []

for curr_token in tqdm(feature_event_tokens, desc="Analyzing... ", unit=" Tokens"):
    curr_token_id = token2id[curr_token]
    token_freq = 0

    for _, patient in enumerate(patients_event_tokens):

        vs_id = np.where(patient == "[VS]")[0]
        ve_id = np.where(patient == "[VE]")[0]

        for vs, ve in zip(vs_id, ve_id):
            curr_visit = patient[vs:ve]

            if curr_token not in curr_visit:
                continue

            token_freq += 1
            for visit_token in curr_visit:
                token_correlations[curr_token_id][token2id[visit_token]] += 1

    token_frequencies.append(token_freq)
patients = patients[patients["event_tokens_2048"].notnull()]

tokenizer = ConceptTokenizer(data_dir=DATA_ROOT)
tokenizer.fit_on_vocab()

train_dataset = PretrainDataset(
    data=patients,
    tokenizer=tokenizer,
    max_len=2048,
    mask_prob=1,
)