In [1]:
import os
from typing import Any, Dict, Optional, Tuple, Union

import glob
import json
import os
from itertools import chain
from typing import Any, Dict, List, Optional, Set, Union

import numpy as np
import pandas as pd
from tokenizers import Tokenizer, models, pre_tokenizers
from transformers import (
    BatchEncoding,
    PreTrainedTokenizerFast,
    AutoTokenizer,
    MambaConfig,
    MambaModel,
    MambaForCausalLM,
    MambaPreTrainedModel,
)

import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from torch import nn, optim
from torch.cuda.amp import autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput

# from mamba_ssm.models.mixer_seq_simple import MambaConfig, MambaLMHeadModel

# Repository root. `os.chdir` makes every relative path used below (checkpoints/,
# odyssey/data/...) resolve against it. Set ODYSSEY_ROOT to point at a different checkout
# instead of editing the notebook; the original author's path remains the default.
ROOT = os.environ.get("ODYSSEY_ROOT", "/h/afallah/odyssey/odyssey")
os.chdir(ROOT)

from odyssey.models.embeddings import *
from odyssey.data.dataset import PretrainDataset, PretrainDatasetDecoder
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.cehr_bert.model import BertPretrain
from odyssey.models.cehr_big_bird.model import BigBirdPretrain
from odyssey.models.model_utils import (
    get_run_id,
    load_config,
    load_pretrain_data,
    load_finetune_data,
)
from odyssey.utils.utils import seed_everything

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

"""
New Training:
1. Num parameters
2. Epochs
3. Overfitting
    4. Embeddings
    5. Label balance
    6. Dataset
"""

'\nNew Training:\n1. Num parameters\n2. Epochs\n3. Overfitting\n    4. Embeddings\n    5. Label balance\n    6. Dataset\n'

In [2]:
# Inspect the checkpoint of the run trained with CEHR embeddings. `torch.load` returns the
# saved checkpoint object (presumably a dict with a "state_dict" entry, like the checkpoint
# loaded in the pretrained-weights cell), not a model, so it gets its own name instead of
# `model`. `map_location` lets a GPU-saved checkpoint load on a CPU-only machine too.
checkpoint_with_embeddings = torch.load(
    "checkpoints/mamba_pretrain_with_embeddings/best.ckpt", map_location=device
)

In [2]:
class args:
    """Static experiment settings, used as a simple attribute namespace (never instantiated)."""

    # Data locations (relative to the repository root set above).
    data_dir = "odyssey/data/bigbird_data"
    sequence_file = "patient_sequences_2048_multi.parquet"
    id_file = "dataset_2048_multi.pkl"
    vocab_dir = "odyssey/data/vocab"

    # Sequence length and masking probability.
    max_len = 2048
    mask_prob = 0.15

    # Fine-tuning tasks; every task gets the same 0.5 balance target, in task order.
    tasks = ["mortality_1month", "los_1week", "c0", "c1", "c2"]
    balance_guide = dict.fromkeys(tasks, 0.5)

In [3]:
# Concept tokenizer built from the vocabulary files in args.vocab_dir.
tokenizer = ConceptTokenizer(data_dir=args.vocab_dir)
tokenizer.fit_on_vocab()

# Pretraining sequences plus the patient-id dictionary, wrapped in the decoder dataset.
sequence_path = f"patient_sequences/{args.sequence_file}"
id_path = f"patient_id_dict/{args.id_file}"
pre_data = load_pretrain_data(args.data_dir, sequence_path, id_path)
train_dataset = PretrainDatasetDecoder(
    data=pre_data,
    tokenizer=tokenizer,
    max_len=args.max_len,
)

# NOTE(review): this fine-tuning split is disabled, yet the FinetuneDatasetDecoder cell
# further down still passes `fine_test`. Re-enable it before running that cell, and check
# the paths, which lack the patient_sequences/ and patient_id_dict/ prefixes used above.
# _, fine_test = load_finetune_data(
#     args.data_dir, args.sequence_file, args.id_file, "few_shot", "all"
# )
# test_dataset = PretrainDatasetDecoder(
#     data=fine_test,
#     tokenizer=tokenizer,
#     max_len=args.max_len,
# )

In [4]:
# Mamba causal-LM configuration. The context length comes from args.max_len, so it always
# matches the max_len the datasets above pad/truncate to, instead of a second literal 2048.
# NOTE(review): eos_token_id is set to the pad token id, which presumably makes `generate`
# stop as soon as it emits the pad token — confirm this is intended.
config = MambaConfig(
    vocab_size=tokenizer.get_vocab_size(),
    hidden_size=768,
    state_size=16,
    num_hidden_layers=32,
    max_seq_length=args.max_len,
    pad_token_id=tokenizer.get_pad_token_id(),
    bos_token_id=tokenizer.token_to_id("[CLS]"),
    eos_token_id=tokenizer.get_pad_token_id(),
)

# embeddings = MambaEmbeddingsForCEHR(
#     config=config
# )

model = MambaForCausalLM(config=config)
# model.backbone.embeddings = embeddings
model.to(device)
model

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(20600, 768)
    (layers): ModuleList(
      (0-31): 32 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
          (act): SiLU()
          (in_proj): Linear(in_features=768, out_features=3072, bias=False)
          (x_proj): Linear(in_features=1536, out_features=80, bias=False)
          (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=20600, bias=False)
)

In [None]:
# Load the pretrained weights into `model`. The checkpoint's state_dict keys carry a
# leading "model." prefix (presumably the attribute holding the network in the training
# module). Strip only that leading prefix: `str.replace` would also rewrite any "model."
# occurring later inside a key and silently produce mismatched names.
checkpoint = torch.load("checkpoints/mamba_pretrain/best.ckpt", map_location=device)
prefix = "model."
state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in checkpoint["state_dict"].items()
}
model.load_state_dict(state_dict)
model

In [None]:
# NOTE(review): `decoder_dataset` is created by the FinetuneDatasetDecoder cell further
# down, so this cell raises NameError on a fresh kernel until that cell has run.
train_loader = DataLoader(decoder_dataset, batch_size=3, shuffle=False)  # or: test_dataset, train_dataset

# One example (alternatives: test_dataset[8765], train_dataset[0]). Its "task" entry is set
# aside (presumably not a tensor) while every other field gains a batch dimension of 1
# and is moved to `device`; the task is then put back under the same key.
raw_sample = decoder_dataset[2323]
task = raw_sample.pop("task")
sample = {name: value.unsqueeze(0).to(device) for name, value in raw_sample.items()}
sample["task"] = task

# Batched alternative:
# sample = next(iter(train_loader))
# sample = {key: tensor.to(device) for key, tensor in sample.items()}

sample

In [None]:
# Token ids of the sample, cut at the first pad token. The pad id comes from the tokenizer
# rather than a hard-coded 0. A sequence that fills max_len without any padding is kept
# whole; `list.index` would otherwise raise ValueError for it.
pad_token_id = tokenizer.get_pad_token_id()
input_ids = sample["concept_ids"].squeeze().tolist()
if pad_token_id in input_ids:
    input_ids = input_ids[: input_ids.index(pad_token_id)]
print(tokenizer.decode(input_ids))

In [None]:
# The final ten tokens of the unpadded sequence — the part the generation cell leaves out
# of its prompt (it prompts with input_ids[:-10]).
held_out_ids = input_ids[-10:]
tokenizer.decode(held_out_ids)

In [None]:
# Continue the sample: prompt with everything except its last `num_new_tokens` tokens,
# then decode only what the model appended. Ids are built as int64 (torch.long), the
# conventional dtype for Hugging Face token ids. Decoding from the end of the prompt,
# rather than the last N positions of `output`, avoids pulling prompt tokens back in
# when generation stops early at the EOS token.
num_new_tokens = 10
prompt_ids = torch.tensor(input_ids[:-num_new_tokens], dtype=torch.long).unsqueeze(0).to(device)
output = model.generate(prompt_ids, max_new_tokens=num_new_tokens)

tokenizer.decode(output[0, prompt_ids.shape[-1]:].tolist())

In [None]:
# import torch
# from transformers import AutoTokenizer, MambaForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
# model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

# inputs = tokenizer(["Hello, my dog is cute", "NO", "Go to Sumeru"], padding=True, return_tensors="pt")
# outputs = model(inputs['input_ids'], labels=inputs["input_ids"])
# loss = outputs.loss
# logits = outputs.logits

In [None]:
# model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
# inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# model.backbone.embeddings.cache_input(
#     token_type_ids_batch = sample['type_ids'],
#     position_ids_batch = None,
#     inputs_embeds = None,
#     time_stamps = sample['time_stamps'],
#     ages = sample['ages'],
#     visit_orders = sample['visit_orders'],
#     visit_segments = sample['visit_segments']
# )

# outputs = model(
#     input_ids=sample["concept_ids"], labels=sample["concept_ids"], return_dict=True
# )

# loss = outputs.loss
# logits = outputs.logits

In [None]:
# Final-layer hidden states for the sample. MambaForCausalLM's forward output has no
# `last_hidden_state` field (it carries loss/logits/cache_params/hidden_states). So call
# the backbone directly rather than reassigning `model`, which would break the
# generation and LM cells.
outputs = model.backbone(input_ids=sample["concept_ids"], return_dict=True)

last_hidden_states = outputs.last_hidden_state
last_hidden_states.shape

In [None]:
# The `hidden_act` setting stored on the Mamba config.
getattr(config, "hidden_act")

In [None]:
# Throwaway, randomly initialised 2-output linear layer applied at every sequence position.
classifier = torch.nn.Linear(
    in_features=config.hidden_size,
    out_features=2,
    bias=False,
).to(device)
logits = classifier(last_hidden_states)
logits

In [None]:
# Token id at position 204 of the single (batch size 1) sample.
sample["concept_ids"][0, 204]

In [None]:
# Index of the last non-pad token in each row: position of the first pad token minus one.
# When a row contains no padding, argmax returns 0 and the raw result is -1; the modulo
# maps that explicitly to the final position instead of relying on negative indexing.
# The pad id comes from the tokenizer rather than a hard-coded 0.
pad_token_id = tokenizer.get_pad_token_id()
sample_ids = sample["concept_ids"]
sequence_lengths = (
    torch.eq(sample_ids, pad_token_id).int().argmax(-1) - 1
) % sample_ids.shape[-1]
sequence_lengths

In [None]:
# For every row in the batch, take the logits at that row's last non-pad position.
# The row index covers the whole batch instead of assuming a batch size of 1.
batch_indices = torch.arange(logits.shape[0], device=device)
pooled_logits = logits[batch_indices, sequence_lengths]
pooled_logits

In [None]:
# Pool first, classify second: take each row's hidden state at its last non-pad position,
# then apply the linear classifier. The row index spans the whole batch rather than
# assuming a batch size of 1.
row_indices = torch.arange(last_hidden_states.shape[0], device=device)
pooled_last_hidden_states = last_hidden_states[row_indices, sequence_lengths]
classifier(pooled_last_hidden_states)

In [None]:
import copy

# Classification-head experiment on the pooled hidden states. The deep copy keeps the
# classifier_dropout setting off the shared `config` used by the language model.
# NOTE(review): MambaClassificationHead is neither defined nor explicitly imported in the
# visible cells; presumably it comes from the `odyssey.models.embeddings` star import — confirm.
config_copy = copy.deepcopy(config)
setattr(config_copy, "classifier_dropout", 0.1)
head = MambaClassificationHead(config_copy)
head = head.to(device)
head(pooled_last_hidden_states)

In [None]:
# Cross-entropy of the pooled logits against a placeholder target of class 0 for every row;
# the sample's real label is not wired in yet. The class count and batch size come from
# `pooled_logits` instead of being hard-coded to 2 classes and a single row.
num_classes = pooled_logits.shape[-1]
placeholder_targets = torch.zeros(pooled_logits.shape[0], dtype=torch.long, device=device)
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, num_classes), placeholder_targets)
loss

In [None]:
# Display whatever forward-pass result is currently bound to `outputs`.
outputs

In [None]:
# Shape of the input ids fed to the backbone forward pass. `inputs` is only defined by the
# commented-out Hugging Face example and is rebound to a tuple in a later cell, so
# `inputs["input_ids"]` fails on a fresh kernel.
sample["concept_ids"].shape

In [None]:
# Shape of the hidden state at the first position of every sequence in the batch.
last_hidden_states[:, 0].shape

In [None]:
from odyssey.models.cehr_mamba.model import MambaPretrain
from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss

In [None]:
"""
1. Embeddings -> Not now
2. Padding order -> Done automatically

---
Finetuning Approach:
    1. Replace the first and last REG token with the class token
2. Use the last hidden state of the last token for class prediction
3. Ourselves!

4. Dataset refactoring (inheritance, what to return, etc)
"""

In [None]:
import random
from typing import Any, Dict, List, Optional, Tuple, Union

import pandas as pd
import torch
from torch.utils.data import Dataset

from odyssey.data.tokenizer import ConceptTokenizer, truncate_and_pad

# Index constants; not used by any visible cell (presumably positions within dataset rows).
TASK_INDEX = 1
LABEL_INDEX = 2
CUTOFF_INDEX = 3


# FinetuneDatasetDecoder was never defined or imported in this notebook; the debugging copy
# this cell refers to appears to have been removed. Import it from odyssey.data.dataset,
# where PretrainDatasetDecoder is imported from (assumed to live there as well — confirm).
from odyssey.data.dataset import FinetuneDatasetDecoder

# NOTE(review): `fine_test` is only produced by the commented-out load_finetune_data call in
# the data-setup cell; re-enable it there before running this cell.
decoder_dataset = FinetuneDatasetDecoder(
    data=fine_test,
    tokenizer=tokenizer,
    max_len=args.max_len,
    tasks=args.tasks,
    balance_guide=args.balance_guide,
)
decoder_dataset[12112]

In [7]:
from odyssey.models.embeddings import *


# CEHR input-embedding module for Mamba; its extra hyperparameters are gathered in one dict.
embedding_kwargs = dict(
    type_vocab_size=9,
    max_num_visits=512,
    time_embeddings_size=32,
    visit_order_size=3,
    hidden_dropout_prob=0.1,
)
embeddings = MambaEmbeddingsForCEHR(config=config, **embedding_kwargs)

In [8]:
# One training example with a leading batch dimension of size 1 added to every field.
batch = {
    field: values[None]
    for field, values in train_dataset[51020].items()
}
batch

{'concept_ids': tensor([[    5,     3, 18896,  ...,  1712,     4,     6]]),
 'type_ids': tensor([[1, 2, 6,  ..., 5, 3, 8]]),
 'ages': tensor([[ 0, 77, 77,  ..., 78, 78, 78]]),
 'time_stamps': tensor([[   0, 8928, 8928,  ..., 8981, 8981, 8981]]),
 'visit_orders': tensor([[0, 1, 1,  ..., 8, 8, 8]]),
 'visit_segments': tensor([[0, 2, 2,  ..., 1, 1, 1]]),
 'labels': tensor([[    5,     3, 18896,  ...,  1712,     4,     6]])}

In [9]:
# Distinct visit-order values present in the (single-row) example.
visit_order_values = set(batch["visit_orders"].squeeze(0).tolist())
print(visit_order_values)

{0, 1, 2, 3, 4, 5, 6, 7, 8}


In [10]:
# Gather the model inputs from the batch as a tuple, then unpack them by name.
input_fields = (
    "concept_ids",
    "type_ids",
    "time_stamps",
    "ages",
    "visit_orders",
    "visit_segments",
)
inputs = tuple(batch[field] for field in input_fields)
concept_ids, type_ids, time_stamps, ages, visit_orders, visit_segments = inputs
labels = batch["labels"]

# Combine concept ids with their type/time/age/visit features into input embeddings.
inputs_embeds = embeddings(
    input_ids=concept_ids,
    token_type_ids_batch=type_ids,
    time_stamps=time_stamps,
    ages=ages,
    visit_orders=visit_orders,
    visit_segments=visit_segments,
)
inputs_embeds

tensor([[[ 0.0452, -0.9307,  0.3723,  ..., -1.1524,  0.5854, -0.3397],
         [ 0.0355, -0.5227, -3.0730,  ...,  0.0355, -0.9060, -0.9247],
         [ 0.5338, -0.9962, -1.5450,  ...,  1.6476, -1.0616, -1.5152],
         ...,
         [ 0.1515,  1.6252, -0.2081,  ..., -1.0206,  0.8621,  1.3194],
         [-0.0812,  2.3611, -0.1516,  ..., -1.0140, -0.0978,  1.6653],
         [ 0.4165,  0.8611, -0.5180,  ...,  0.3821,  1.1638,  1.4207]]],
       grad_fn=<NativeLayerNormBackward0>)

In [13]:
# Causal-LM forward pass driven by the precomputed CEHR embeddings instead of token ids;
# passing `labels` makes the model also return the LM loss.
forward_kwargs = {
    "inputs_embeds": inputs_embeds.to(device),
    "labels": labels.to(device),
    "output_hidden_states": False,
    "return_dict": True,
}
outputs = model(**forward_kwargs)
outputs

MambaCausalLMOutput(loss=tensor(14.1312, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[[-1.7062, -5.0347,  2.0794,  ...,  2.3091,  1.0115, -1.9545],
         [-1.0205, -2.8787, -4.8018,  ..., -3.2100, -5.3467, -1.1486],
         [ 1.2367, -4.0578, -3.7514,  ..., -0.0644,  0.9085,  0.9692],
         ...,
         [ 2.3067,  3.5723,  1.9051,  ..., -0.0123, -2.5649, -0.4133],
         [ 4.0669,  4.1643,  3.6506,  ...,  2.8866, -5.4374, -0.8073],
         [ 1.8762,  5.5222, -0.6316,  ...,  0.1687, -7.1170, -4.6202]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>), cache_params=None, hidden_states=None)