In [4]:
import os
from typing import Any, Dict, Optional, Tuple, Union

import glob
import json
import os
from itertools import chain
from typing import Any, Dict, List, Optional, Set, Union

import numpy as np
import pandas as pd
from tokenizers import Tokenizer, models, pre_tokenizers
from transformers import (
    BatchEncoding,
    PreTrainedTokenizerFast,
    AutoTokenizer,
    MambaConfig,
    MambaModel,
    MambaForCausalLM,
)

import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from torch import nn, optim
from torch.cuda.amp import autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput

# from mamba_ssm.models.mixer_seq_simple import MambaConfig, MambaLMHeadModel

# Repository root that the relative data, vocab and checkpoint paths used in this
# notebook are resolved against. Set the ODYSSEY_ROOT environment variable to run
# from a different checkout; the original author's path stays the default.
ROOT = os.environ.get("ODYSSEY_ROOT", "/h/afallah/odyssey/odyssey")
os.chdir(ROOT)

from odyssey.models.embeddings import *
from odyssey.data.dataset import PretrainDataset, PretrainDatasetDecoder
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.cehr_bert.model import BertPretrain
from odyssey.models.cehr_big_bird.model import BigBirdPretrain
from odyssey.models.model_utils import (
    get_run_id,
    load_config,
    load_pretrain_data,
    load_finetune_data,
)
from odyssey.utils.utils import seed_everything

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class args:
    """Static notebook configuration, used as a plain attribute namespace."""

    # Sequence handling
    max_len = 2048  # forwarded as max_len when the dataset is built
    mask_prob = 0.15  # not referenced by the cells visible in this notebook

    # Directories (relative to the current working directory)
    vocab_dir = "odyssey/data/vocab"
    data_dir = "odyssey/data/bigbird_data"

    # File names handed to the data-loading helper
    sequence_file = "patient_sequences_2048.parquet"
    id_file = "dataset_2048_multi.pkl"

In [6]:
# Setup tokenizer
# The tokenizer is pointed at args.vocab_dir and then fitted; fit_on_vocab()
# presumably builds the token<->id mapping from the files found there — confirm
# in ConceptTokenizer.
tokenizer = ConceptTokenizer(data_dir=args.vocab_dir)
tokenizer.fit_on_vocab()


# Setup data
# The pretraining split is currently disabled (kept below for reference); only the
# fine-tuning test split is loaded and wrapped for inspection.
# pre_data = load_pretrain_data(
#         args.data_dir,
#         f'patient_sequences/{args.sequence_file}',
#         f'patient_id_dict/{args.id_file}',
# )
# train_dataset = PretrainDatasetDecoder(
#         data=pre_data,
#         tokenizer=tokenizer,
#         max_len=args.max_len,
# )


# Only the second element of the returned pair is kept; the first is discarded.
# NOTE(review): the meaning of the "few_shot" and "all" arguments is defined by
# load_finetune_data and is not visible here — confirm against that helper.
_, fine_test = load_finetune_data(
    args.data_dir, args.sequence_file, args.id_file, "few_shot", "all"
)
# Wrap the split in the decoder-style dataset; indexing it yields a dict of tensors
# (see the sample printed further down, which has 'concept_ids' and 'labels').
test_dataset = PretrainDatasetDecoder(
    data=fine_test,
    tokenizer=tokenizer,
    max_len=args.max_len,
)

In [9]:
# Architecture of the causal Mamba language model. Vocabulary size and the special
# token ids are read from the fitted tokenizer so model and tokenizer stay in sync.
config = MambaConfig(
    vocab_size=tokenizer.get_vocab_size(),
    hidden_size=768,
    state_size=16,
    num_hidden_layers=32,
    # Follow the dataset's max_len instead of repeating the literal 2048, so the
    # two values cannot drift apart.
    # NOTE(review): max_seq_length is not among MambaConfig's documented
    # arguments; it appears to be kept only as an extra config attribute — confirm.
    max_seq_length=args.max_len,
    pad_token_id=tokenizer.get_pad_token_id(),
    bos_token_id=tokenizer.token_to_id("[CLS]"),
    # The pad token id is reused as the end-of-sequence id.
    eos_token_id=tokenizer.get_pad_token_id(),
)

# Custom CEHR embeddings are not wired in yet:
# embeddings = MambaEmbeddingsForCEHR(
#     config=config
# )

model = MambaForCausalLM(config=config)
# model.backbone.embeddings = embeddings
model.to(device)

# Display the module tree as the cell output.
model

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(20600, 768)
    (layers): ModuleList(
      (0-31): 32 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
          (act): SiLU()
          (in_proj): Linear(in_features=768, out_features=3072, bias=False)
          (x_proj): Linear(in_features=1536, out_features=80, bias=False)
          (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=20600, bias=False)
)

In [10]:
# Load pretrained model
# The checkpoint keys carry a "model." prefix (presumably because the network was
# saved from a wrapper that stores it under a `model` attribute), while the bare
# MambaForCausalLM built above expects keys without it.
checkpoint = torch.load("checkpoints/mamba_pretrain/best.ckpt", map_location=device)

# Strip only the leading prefix. str.replace("model.", "") would also rewrite any
# later "model." occurrence inside a parameter name and corrupt that key.
state_dict = {
    key.removeprefix("model."): weight
    for key, weight in checkpoint["state_dict"].items()
}
model.load_state_dict(state_dict)

# Display the module tree as the cell output.
model

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(20600, 768)
    (layers): ModuleList(
      (0-31): 32 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
          (act): SiLU()
          (in_proj): Linear(in_features=768, out_features=3072, bias=False)
          (x_proj): Linear(in_features=1536, out_features=80, bias=False)
          (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=20600, bias=False)
)

In [60]:
# Batched, unshuffled iterator. Despite its name it currently wraps test_dataset,
# because the training split is not loaded in this notebook.
train_loader = DataLoader(test_dataset, batch_size=3, shuffle=False)

# Pick one example (index 97) and give every tensor a leading batch axis of size 1
# on the target device.
sample = dict(
    (key, tensor.unsqueeze(0).to(device)) for key, tensor in test_dataset[97].items()
)

# Batched alternative:
# sample = next(iter(train_loader))
# sample = {key:tensor.to(device) for key, tensor in sample.items()}

sample

{'concept_ids': tensor([[   5,    3, 3637,  ...,    0,    0,    0]], device='cuda:0'),
 'labels': tensor([[   5,    3, 3637,  ...,    0,    0,    0]], device='cuda:0')}

In [61]:
# Recover the unpadded token sequence of the single sample and show it as text.
pad_id = tokenizer.get_pad_token_id()

# squeeze(0) drops only the batch axis added earlier, so a length-1 sequence still
# becomes a list rather than a bare int.
token_ids = sample["concept_ids"].squeeze(0).tolist()

# A sequence that fills the whole window contains no padding; keep it whole instead
# of letting list.index raise ValueError.
if pad_id in token_ids:
    token_ids = token_ids[: token_ids.index(pad_id)]
input_ids = token_ids

print(tokenizer.decode(input_ids))

[CLS] [VS] 0FC78ZZ 0FC98ZZ 49281041688 00338011704 00641607825 51484_2 51491_3 51498_2 52010_3 52009_4 52008_0 52007_2 52004_4 52005_3 52006_1 51250_0 51265_3 51277_0 51279_4 51301_3 50861_4 50863_3 50868_3 50878_3 50882_1 50883_4 50884_3 50885_4 50893_4 50902_2 50912_2 50931_2 50960_3 50970_2 50971_2 50983_3 51006_0 51221_4 51222_4 51248_3 50912_1 50931_2 50971_2 50868_2 50983_3 51006_0 50878_2 50882_3 51237_2 50861_4 51274_1 50863_3 50885_4 50902_1 [VE] [REG]


In [62]:
# Show the last 10 tokens of the unpadded sequence; the next cell withholds exactly
# these tokens from the prompt, so this is the reference to compare against.
tokenizer.decode(input_ids[-10:])

'50878_2 50882_3 51237_2 50861_4 51274_1 50863_3 50885_4 50902_1 [VE] [REG]'

In [65]:
# Prompt = the sequence without its final 10 tokens, shaped as a batch of one.
prompt_ids = torch.tensor([input_ids[:-10]], dtype=torch.int32).to(device)

# Ask the model to continue the prompt for up to 10 tokens.
output = model.generate(prompt_ids, max_new_tokens=10)

# Decode the trailing 10 ids of the returned sequence for comparison with the
# reference tokens shown in the previous cell.
tokenizer.decode(output.squeeze().tolist()[-10:])

'50882_1 50885_3 50902_3 51221_4 51222_4 51248_1 51250_0 51265_3 51277_0 51279_4'

In [None]:
# import torch
# from transformers import AutoTokenizer, MambaForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
# model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

# inputs = tokenizer(["Hello, my dog is cute", "NO", "Go to Sumeru"], padding=True, return_tensors="pt")
# outputs = model(inputs['input_ids'], labels=inputs["input_ids"])
# loss = outputs.loss
# logits = outputs.logits

In [None]:
# model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
# inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# model.backbone.embeddings.cache_input(
#     token_type_ids_batch = sample['type_ids'],
#     position_ids_batch = None,
#     inputs_embeds = None,
#     time_stamps = sample['time_stamps'],
#     ages = sample['ages'],
#     visit_orders = sample['visit_orders'],
#     visit_segments = sample['visit_segments']
# )

# Forward pass with the inputs reused as labels, so the returned object carries a
# loss alongside the logits.
outputs = model(
    labels=sample["concept_ids"],
    input_ids=sample["concept_ids"],
    return_dict=True,
)

loss, logits = outputs.loss, outputs.logits

In [None]:
class MambaPretrain(pl.LightningModule):
    """Lightning wrapper that pretrains a Mamba causal language model.

    Parameters
    ----------
    vocab_size : int
        Number of tokens in the vocabulary.
    embedding_size : int
        Hidden size of the model.
    state_size : int
        State size passed to the Mamba configuration.
    num_hidden_layers : int
        Number of Mamba blocks.
    expand : int
        Expansion factor passed to the Mamba configuration.
    conv_kernel : int
        Convolution kernel size passed to the Mamba configuration.
    learning_rate : float
        Peak learning rate given to AdamW.
    dropout_prob : float
        Stored on the module; it is not forwarded to the Mamba configuration.
    padding_idx : int
        Id of the padding token; also used as the end-of-sequence id.
    cls_idx : int
        Id of the token used as the beginning-of-sequence id.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_size: int = 768,
        state_size: int = 16,
        num_hidden_layers: int = 32,
        expand: int = 2,
        conv_kernel: int = 4,
        learning_rate: float = 5e-5,
        dropout_prob: float = 0.1,
        padding_idx: int = 0,
        cls_idx: int = 5,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.state_size = state_size
        self.num_hidden_layers = num_hidden_layers
        self.expand = expand
        self.conv_kernel = conv_kernel
        self.learning_rate = learning_rate
        self.dropout_prob = dropout_prob
        self.padding_idx = padding_idx
        self.cls_idx = cls_idx

        self.config = MambaConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.embedding_size,
            state_size=self.state_size,
            num_hidden_layers=self.num_hidden_layers,
            expand=self.expand,
            conv_kernel=self.conv_kernel,
            pad_token_id=self.padding_idx,
            bos_token_id=self.cls_idx,
            eos_token_id=self.padding_idx,
        )

        # Build from self.config. The bare name `config` would resolve to whatever
        # module-level variable happens to exist (or raise NameError on a fresh
        # kernel) and silently ignore every argument of this constructor.
        self.model = MambaForCausalLM(config=self.config)

    def forward(
        self,
        input_ids: torch.Tensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor, ...], "MambaCausalLMOutput"]:
        """Run the causal language model with the inputs reused as labels.

        Parameters
        ----------
        input_ids : torch.Tensor
            Token ids; they are passed both as ``input_ids`` and as ``labels``.
        output_hidden_states : bool, optional
            Forwarded to the underlying model.
        return_dict : bool, optional
            Forwarded to the underlying model.

        Returns
        -------
        Union[Tuple[torch.Tensor, ...], MambaCausalLMOutput]
            Whatever the wrapped ``MambaForCausalLM`` returns. The annotation is a
            string so that it does not need the output class to be imported.
        """
        return self.model(
            input_ids=input_ids,
            labels=input_ids,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def _compute_loss(self, batch: Dict[str, Any]) -> torch.Tensor:
        """Return the loss for the ``concept_ids`` of one batch under autocast."""
        # Ensure use of mixed precision
        with autocast():
            return self(batch["concept_ids"], return_dict=True).loss

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:
        """Compute and log the training loss and current learning rate."""
        loss = self._compute_loss(batch)

        # The scheduler exposes one learning rate per parameter group; a single
        # group is configured, hence the one-element unpacking.
        (current_lr,) = self.lr_schedulers().get_last_lr()
        self.log_dict(
            dictionary={"train_loss": loss, "lr": current_lr},
            on_step=True,
            prog_bar=True,
        )

        return loss

    def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:
        """Compute and log the validation loss and current learning rate."""
        loss = self._compute_loss(batch)

        (current_lr,) = self.lr_schedulers().get_last_lr()
        self.log_dict(
            dictionary={"val_loss": loss, "lr": current_lr},
            on_step=True,
            prog_bar=True,
            sync_dist=True,
        )
        return loss

    def configure_optimizers(
        self,
    ) -> Tuple[list[Any], list[dict[str, SequentialLR | str]]]:
        """Configure AdamW with a linear warm-up followed by a linear decay.

        The first 10% of the estimated optimisation steps ramp the learning rate
        from 1% to 100% of ``learning_rate``; the following 90% decay it back to 1%.
        The scheduler is stepped once per optimisation step.
        """
        optimizer = optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
        )

        n_steps = self.trainer.estimated_stepping_batches
        n_warmup_steps = int(0.1 * n_steps)
        n_decay_steps = int(0.9 * n_steps)

        warmup = LinearLR(
            optimizer,
            start_factor=0.01,
            end_factor=1.0,
            total_iters=n_warmup_steps,
        )
        decay = LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=0.01,
            total_iters=n_decay_steps,
        )
        # Switch from warm-up to decay once the warm-up steps are exhausted.
        scheduler = SequentialLR(
            optimizer=optimizer,
            schedulers=[warmup, decay],
            milestones=[n_warmup_steps],
        )

        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

In [None]:
"""
1. Embeddings -> Not now
2. Padding order -> Done automatically
"""

In [None]:
class MambaEmbeddingsForCEHR(nn.Module):
    """Construct CEHR input embeddings for a Mamba model.

    Concept (word) embeddings are concatenated with time-stamp and age embeddings,
    projected back to ``hidden_size`` and summed with token-type, position,
    visit-order and visit-segment embeddings.

    ``forward`` only receives ``input_ids``; the remaining per-token inputs must be
    supplied beforehand with ``cache_input``. They are consumed by the next
    ``forward`` call, which resets the cache when it finishes.
    """

    # Adapted from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
    def __init__(
        self,
        config: MambaConfig,
        max_position_embeddings: int = 2048,
        type_vocab_size: int = 8,
        time_embeddings_size: int = 16,
        visit_order_size: int = 3,
        layer_norm_eps: float = 1e-12,
        hidden_dropout_prob: float = 0.1,
    ) -> None:
        """Initiate wrapper class for embeddings used in Mamba CEHR classes.

        Parameters
        ----------
        config : MambaConfig
            Supplies ``vocab_size``, ``hidden_size`` and ``pad_token_id``.
        max_position_embeddings : int
            Size of the position and visit-order embedding tables.
        type_vocab_size : int
            Size of the token-type embedding table.
        time_embeddings_size : int
            Width of each of the time-stamp and age embeddings.
        visit_order_size : int
            Passed to ``VisitEmbedding`` as ``visit_order_size``.
        layer_norm_eps : float
            Epsilon of the final layer normalisation.
        hidden_dropout_prob : float
            Dropout probability applied to the summed embeddings.
        """
        super().__init__()
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.layer_norm_eps = layer_norm_eps
        self.hidden_dropout_prob = hidden_dropout_prob
        self.hidden_size = config.hidden_size

        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id,
        )
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            config.hidden_size,
        )
        self.token_type_embeddings = nn.Embedding(
            self.type_vocab_size,
            config.hidden_size,
        )
        self.visit_order_embeddings = nn.Embedding(
            self.max_position_embeddings,
            config.hidden_size,
        )
        self.time_embeddings = TimeEmbeddingLayer(
            embedding_size=time_embeddings_size,
            is_time_delta=True,
        )
        self.age_embeddings = TimeEmbeddingLayer(
            embedding_size=time_embeddings_size,
        )
        self.visit_segment_embeddings = VisitEmbedding(
            visit_order_size=visit_order_size,
            embedding_size=config.hidden_size,
        )
        # Projects [word | time | age] (hidden + 2 * time width) back to hidden.
        self.scale_back_concat_layer = nn.Linear(
            config.hidden_size + 2 * time_embeddings_size,
            config.hidden_size,
        )

        # Every cached input starts out empty so that forward() can detect a
        # missing cache_input() call instead of failing on an undefined attribute.
        self.token_type_ids_batch: Optional[torch.Tensor] = None
        self.position_ids_batch: Optional[torch.Tensor] = None
        self.inputs_embeds: Optional[torch.Tensor] = None
        self.time_stamps: Optional[torch.Tensor] = None
        self.ages: Optional[torch.Tensor] = None
        self.visit_orders: Optional[torch.Tensor] = None
        self.visit_segments: Optional[torch.Tensor] = None

        # self.LayerNorm is not snake-cased to stick with TensorFlow model
        # variable name and be able to load any TensorFlow checkpoint file.
        self.tanh = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory.
        self.position_embedding_type = getattr(
            config,
            "position_embedding_type",
            "absolute",
        )
        self.register_buffer(
            "position_ids",
            torch.arange(self.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long),
            persistent=False,
        )
        # End copy

    def cache_input(
        self,
        token_type_ids_batch: Optional[torch.Tensor] = None,
        position_ids_batch: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        time_stamps: Optional[torch.Tensor] = None,
        ages: Optional[torch.Tensor] = None,
        visit_orders: Optional[torch.Tensor] = None,
        visit_segments: Optional[torch.Tensor] = None,
    ) -> None:
        """Cache the extra inputs consumed by the next forward pass.

        Every call overwrites all seven cached values, including those left at
        their ``None`` default.

        Parameters
        ----------
        token_type_ids_batch : torch.Tensor
            The token type IDs of the input data.
        position_ids_batch : torch.Tensor
            The position IDs of the input data.
        inputs_embeds : torch.Tensor
            The embeddings of the input data.
        time_stamps : torch.Tensor
            Time stamps of the input data.
        ages : torch.Tensor
            Ages of the input data.
        visit_orders : torch.Tensor
            Visit orders of the input data.
        visit_segments : torch.Tensor
            Visit segments of the input data.
        """
        self.token_type_ids_batch = token_type_ids_batch
        self.position_ids_batch = position_ids_batch
        self.inputs_embeds = inputs_embeds
        self.time_stamps = time_stamps
        self.ages = ages
        self.visit_orders = visit_orders
        self.visit_segments = visit_segments

    def clear_cache(self) -> None:
        """Reset every value stored by ``cache_input`` to ``None``.

        The attributes are kept (not deleted) so that a later ``forward`` call can
        report missing inputs explicitly.
        """
        self.token_type_ids_batch = None
        self.position_ids_batch = None
        self.inputs_embeds = None
        self.time_stamps = None
        self.ages = None
        self.visit_orders = None
        self.visit_segments = None

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Return the final embeddings of concept ids using input and cached values.

        Parameters
        ----------
        input_ids : torch.Tensor, optional
            Concept ids. May be omitted when ``inputs_embeds`` was cached.
        past_key_values_length : int
            Offset applied to the default position ids.

        Returns
        -------
        torch.Tensor
            The combined embeddings after dropout and layer normalisation.

        Raises
        ------
        ValueError
            If a required cached input is missing, or if neither ``input_ids`` nor
            cached ``inputs_embeds`` is available.
        """
        required = {
            "time_stamps": self.time_stamps,
            "ages": self.ages,
            "visit_orders": self.visit_orders,
            "visit_segments": self.visit_segments,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError(
                "cache_input() must be called before forward(); missing: "
                + ", ".join(missing)
            )

        # Work on locals; the cached attributes are only read here and reset at
        # the end of the call.
        inputs_embeds = self.inputs_embeds
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError(
                "Either input_ids or cached inputs_embeds must be provided."
            )
        batch_size, seq_length = input_shape[0], input_shape[1]

        position_ids = self.position_ids_batch
        if position_ids is None:
            position_ids = self.position_ids[
                :,
                past_key_values_length : seq_length + past_key_values_length,
            ]

        # Default token types come from the all-zero buffer registered in the
        # constructor (it is always present, so no further fallback is needed).
        token_type_ids = self.token_type_ids_batch
        if token_type_ids is None:
            token_type_ids = self.token_type_ids[:, :seq_length].expand(
                batch_size,
                seq_length,
            )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Using cached values from a prior cache_input call
        time_stamps_embeds = self.time_embeddings(self.time_stamps)
        ages_embeds = self.age_embeddings(self.ages)
        visit_segments_embeds = self.visit_segment_embeddings(self.visit_segments)
        visit_order_embeds = self.visit_order_embeddings(self.visit_orders)

        position_embeds = self.position_embeddings(position_ids)
        token_type_embeds = self.token_type_embeddings(token_type_ids)

        # Concatenate along the feature axis, then project back to hidden size.
        inputs_embeds = torch.cat(
            (inputs_embeds, time_stamps_embeds, ages_embeds),
            dim=-1,
        )
        inputs_embeds = self.tanh(self.scale_back_concat_layer(inputs_embeds))

        embeddings = inputs_embeds + token_type_embeds
        embeddings += position_embeds
        embeddings += visit_order_embeds
        embeddings += visit_segments_embeds

        embeddings = self.dropout(embeddings)
        embeddings = self.LayerNorm(embeddings)

        # Clear the cache for next forward call
        self.clear_cache()

        return embeddings