In [77]:
import os
from typing import Any, Dict, Optional, Tuple, Union

import glob
import json
import os
from itertools import chain
from typing import Any, Dict, List, Optional, Set, Union

import numpy as np
import pandas as pd
from tokenizers import Tokenizer, models, pre_tokenizers
from transformers import (
    BatchEncoding,
    PreTrainedTokenizerFast,
    AutoTokenizer,
    MambaConfig,
    MambaModel,
    MambaForCausalLM,
    MambaPreTrainedModel
)

import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from torch import nn, optim
from torch.cuda.amp import autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput

# from mamba_ssm.models.mixer_seq_simple import MambaConfig, MambaLMHeadModel

ROOT = "/h/afallah/odyssey/odyssey"
os.chdir(ROOT)

from odyssey.models.embeddings import *
from odyssey.data.dataset import PretrainDataset, PretrainDatasetDecoder
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.cehr_bert.model import BertPretrain
from odyssey.models.cehr_big_bird.model import BigBirdPretrain
from odyssey.models.model_utils import (
    get_run_id,
    load_config,
    load_pretrain_data,
    load_finetune_data,
)
from odyssey.utils.utils import seed_everything

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class args:
    data_dir = 'odyssey/data/bigbird_data'
    sequence_file = 'patient_sequences_2048_multi.parquet'
    id_file = 'dataset_2048_multi.pkl'
    vocab_dir = 'odyssey/data/vocab'
    max_len = 2048
    mask_prob = 0.15
    tasks = ['mortality_1month', 'los_1week', 'c0', 'c1', 'c2']
    balance_guide = {'mortality_1month': 0.5, 'los_1week': 0.5, 'c0': 0.5, 'c1': 0.5, 'c2': 0.5}

In [3]:
# Setup tokenizer
tokenizer = ConceptTokenizer(data_dir=args.vocab_dir)
tokenizer.fit_on_vocab()


# Setup data
# pre_data = load_pretrain_data(
#         args.data_dir,
#         f'patient_sequences/{args.sequence_file}',
#         f'patient_id_dict/{args.id_file}',
# )
# train_dataset = PretrainDatasetDecoder(
#         data=pre_data,
#         tokenizer=tokenizer,
#         max_len=args.max_len,
# )


_, fine_test = load_finetune_data(
    args.data_dir, args.sequence_file, args.id_file, "few_shot", "all"
)
test_dataset = PretrainDatasetDecoder(
    data=fine_test,
    tokenizer=tokenizer,
    max_len=args.max_len,
)

In [4]:
config = MambaConfig(
    vocab_size=tokenizer.get_vocab_size(),
    hidden_size=768,
    state_size=16,
    num_hidden_layers=32,
    max_seq_length=2048,
    pad_token_id=tokenizer.get_pad_token_id(),
    bos_token_id=tokenizer.token_to_id("[CLS]"),
    eos_token_id=tokenizer.get_pad_token_id(),
)

# embeddings = MambaEmbeddingsForCEHR(
#     config=config
# )

model = MambaForCausalLM(config=config)
# model.backbone.embeddings = embeddings
model.to(device)

model

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(20600, 768)
    (layers): ModuleList(
      (0-31): 32 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
          (act): SiLU()
          (in_proj): Linear(in_features=768, out_features=3072, bias=False)
          (x_proj): Linear(in_features=1536, out_features=80, bias=False)
          (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=20600, bias=False)
)

In [5]:
# Load pretrained model
checkpoint = torch.load("checkpoints/mamba_pretrain/best.ckpt", map_location=device)
state_dict = checkpoint["state_dict"]
state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(20600, 768)
    (layers): ModuleList(
      (0-31): 32 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
          (act): SiLU()
          (in_proj): Linear(in_features=768, out_features=3072, bias=False)
          (x_proj): Linear(in_features=1536, out_features=80, bias=False)
          (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=20600, bias=False)
)

In [35]:
train_loader = DataLoader(
        decoder_dataset, #test_dataset, #train_dataset
        batch_size=3,
        shuffle=False,
    )

sample = decoder_dataset[2323] #test_dataset[8765] #train_dataset[0]
task = sample.pop('task')
sample = {key:tensor.unsqueeze(0).to(device) for key, tensor in sample.items()}
sample['task'] = task

# sample = next(iter(train_loader))
# sample = {key:tensor.to(device) for key, tensor in sample.items()}

sample

{'concept_ids': tensor([[20592,     3, 17326,  ...,     0,     0,     0]], device='cuda:0'),
 'labels': tensor([1], device='cuda:0'),
 'task': 'mortality_1month'}

In [36]:
input_ids = sample["concept_ids"].squeeze().tolist()
input_ids = input_ids[: input_ids.index(0)]
print(tokenizer.decode(input_ids))

[MOR_1M] [VS] 58160087546 00904516561 00182853489 00574705050 00121054410 66553000401 00310027539 00006003121 00456320563 62856024541 00310027539 00172531210 00338069104 51006_2 50983_4 50971_2 50970_2 50960_3 50931_1 50912_2 50902_4 50893_3 50882_2 50868_2 51301_1 51279_4 51277_1 51265_2 51250_0 51248_1 51222_4 51221_4 00172531110 10432017002 10432017002 51006_2 50983_4 50971_3 50970_2 50960_3 50931_4 50912_3 50902_3 50893_4 50882_3 50868_3 51301_1 51279_4 51277_1 51265_1 51250_0 51248_1 51222_4 51221_4 00904224461 51301_2 51279_4 51277_1 51265_2 51250_0 51248_1 51222_4 51221_4 51006_2 50983_4 50971_3 50970_2 50960_2 50931_2 50912_3 50902_2 50893_3 50882_3 50868_4 [VE] [REG] [W_0] [VS] 8938 8838 8744 51006_2 50983_4 50971_2 50931_1 50912_2 50902_4 50882_3 50868_3 51301_1 51279_4 51277_1 51265_2 51256_1 51254_3 51250_1 51248_1 51244_3 51222_4 51221_4 51200_3 51146_2 63323026201 00713016550 00182844789 62856024541 00182853489 00904516561 00006003121 00121065721 10432017002 00182844789 6

In [28]:
tokenizer.decode(input_ids[-10:])

'50931_1 50912_2 50902_4 50893_3 50882_3 50868_1 60505258500 00310027539 [VE] [MOR_1M]'

In [29]:
output = model.generate(
    torch.tensor(input_ids[:-10], dtype=torch.int32).unsqueeze(0).to(device),
    max_new_tokens=10,
)

tokenizer.decode(output.squeeze().tolist()[-10:])

'50970_1 50960_2 50931_0 50912_2 50902_4 50893_3 50882_3 50868_1 [VE] [REG]'

In [None]:
# import torch
# from transformers import AutoTokenizer, MambaForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
# model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

# inputs = tokenizer(["Hello, my dog is cute", "NO", "Go to Sumeru"], padding=True, return_tensors="pt")
# outputs = model(inputs['input_ids'], labels=inputs["input_ids"])
# loss = outputs.loss
# logits = outputs.logits

In [None]:
# model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
# inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# model.backbone.embeddings.cache_input(
#     token_type_ids_batch = sample['type_ids'],
#     position_ids_batch = None,
#     inputs_embeds = None,
#     time_stamps = sample['time_stamps'],
#     ages = sample['ages'],
#     visit_orders = sample['visit_orders'],
#     visit_segments = sample['visit_segments']
# )

# outputs = model(
#     input_ids=sample["concept_ids"], labels=sample["concept_ids"], return_dict=True
# )

# loss = outputs.loss
# logits = outputs.logits

In [37]:
# model = model.backbone
outputs = model(
    input_ids=sample["concept_ids"], return_dict=True
)

last_hidden_states = outputs.last_hidden_state
last_hidden_states.shape

torch.Size([1, 2048, 768])

In [55]:
config.hidden_act

'silu'

In [38]:
classifier = torch.nn.Linear(config.hidden_size, 2, bias=False).to(device)
logits = classifier(last_hidden_states)
logits

tensor([[[ 0.1596, -0.0591],
         [ 0.2433, -1.0078],
         [-0.3050, -0.3962],
         ...,
         [ 0.4917, -0.3203],
         [ 0.4917, -0.3203],
         [ 0.4917, -0.3203]]], device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [44]:
sample["concept_ids"].squeeze()[204]

tensor(20592, device='cuda:0')

In [40]:
sequence_lengths = torch.eq(sample['concept_ids'], 0).int().argmax(-1) - 1
sequence_lengths        

tensor([204], device='cuda:0')

In [54]:
pooled_logits = logits[torch.arange(1, device=device), sequence_lengths]
pooled_logits

tensor([[ 0.1697, -0.7979]], device='cuda:0', grad_fn=<IndexBackward0>)

In [53]:
pooled_last_hidden_states = last_hidden_states[torch.arange(1, device=device), sequence_lengths]
classifier(pooled_last_hidden_states)

tensor([[ 0.1697, -0.7979]], device='cuda:0', grad_fn=<MmBackward0>)

In [72]:
import copy
config_copy = copy.deepcopy(config)
config_copy.classifier_dropout = 0.1
head = MambaClassificationHead(config_copy).to(device)
head(pooled_last_hidden_states)

tensor([[-0.0512, -0.2630]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [46]:
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1,2), torch.tensor([0]).to(device).view(-1))
loss

tensor(0.3221, device='cuda:0', grad_fn=<NllLossBackward0>)

In [None]:
outputs

In [None]:
inputs['input_ids'].shape

In [None]:
last_hidden_states[:, 0, :].shape

In [None]:
from odyssey.models.cehr_mamba.model import MambaPretrain
from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss



In [None]:
"""
1. Emebeddings -> Not now
2. Padding order -> Done automatically

---
Finetuning Approach:
    1. Replace the first and last REG token with the class token
2. Use the last hiddent state of the last token for class prediction
3. Ourselves!

4. Dataset refactoring (inheritance, what to return, etc)
"""

In [12]:
import random
from typing import Any, Dict, List, Optional, Tuple, Union

import pandas as pd
import torch
from torch.utils.data import Dataset

from odyssey.data.tokenizer import ConceptTokenizer, truncate_and_pad

TASK_INDEX = 1
LABEL_INDEX = 2
CUTOFF_INDEX = 3


class FinetuneDatasetDecoder(Dataset):
    """Dataset for finetuning a decoder-based model.

    Parameters
    ----------
    data : pd.DataFrame
        The input data containing sequences to be tokenized and masked.
    tokenizer : ConceptTokenizer
        An instance of the ConceptTokenizer class used for tokenizing sequences.
    tasks : List[str]
        A list of tasks (labels) that need to be predicted.
    balance_guide : Optional[Dict[str, float]], optional
        A dictionary containing the desired positive ratios for each task,
        by default None.
    max_len : int, optional
        The maximum length of the tokenized sequences, by default 2048.
    nan_indicator : int, optional
        Value used to represent missing labels in the dataset, by default -1.

    Attributes
    ----------
    data : pd.DataFrame
        Stores the input data.
    tokenizer : ConceptTokenizer
        Tokenizer used for tokenizing sequences.
    tasks : List[str]
        A list of tasks (labels) that need to be predicted.
    balance_guide : Optional[Dict[str, float]]
        A dictionary containing the desired positive ratios for each task.
    max_len : int
        Maximum length of the tokenized sequences.
    nan_indicator : int
        Value used to represent missing labels in the dataset.
    task_to_index : Dict[str, List[Tuple[int, str, int, Optional[int]]]]
        A dictionary mapping each task to a list of tuples containing the
        index, task, label, and cutoff.
    index_mapper : List[Tuple[int, str, int, Optional[int]]]
        A list of all datapoints to be used by __getitem__.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: ConceptTokenizer,
        tasks: List[str],
        balance_guide: Optional[Dict[str, float]] = None,
        max_len: int = 2048,
        nan_indicator: int = -1,
    ):
        """Initiate the class."""
        super().__init__()

        self.data = data
        self.tokenizer = tokenizer
        self.tasks = tasks  # List of tasks for which the model is being finetuned.
        self.balance_guide = balance_guide
        self.max_len = max_len
        self.nan_indicator = (
            nan_indicator  # Value used to indicate missing data in labels.
        )

        # Precompute indices for quick mapping in __getitem__ that
        # exclude missing labels.
        # This helps in filtering out entries where the label is missing
        # for the specified tasks.
        self.task_to_index = {task: [] for task in self.tasks}
        self.data.reset_index(drop=True, inplace=True)

        for patient in self.data.itertuples():
            index = patient.Index

            for task in self.tasks:
                label_col = f"label_{task}"
                # Skip this task for the current patient if the label is missing.
                if getattr(patient, label_col) == self.nan_indicator:
                    continue

                label = getattr(patient, label_col)
                # Check for the existence of a task-specific cutoff in the data,
                # else use None.
                if f"cutoff_{task}" in self.data.columns:
                    cutoff = getattr(patient, f"cutoff_{task}")
                else:
                    cutoff = None
                # Append a tuple containing the necessary information
                # for training to index_mapper.
                datapoint = (index, task, label, cutoff)
                self.task_to_index[task].append(datapoint)

        # Balance labels for specified tasks
        if self.balance_guide:
            for task in self.balance_guide:
                self.balance_labels(task=task, positive_ratio=self.balance_guide[task])

        # Create a list of all datapoints to be used by __getitem__
        self.index_mapper = [
            datapoints
            for task_data in self.task_to_index.values()
            for datapoints in task_data
        ]
        del self.task_to_index

    def __len__(self) -> int:
        """Return the length of dataset."""
        return len(self.index_mapper)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get data at corresponding index.

        Parameters
        ----------
        idx : int
            The index of the data to be retrieved.

        Returns
        -------
        Dict[str, Any]
            A dictionary containing all different token sequences along with labels.
        """
        index, task, labels, cutoff = self.index_mapper[idx]
        data = self.data.iloc[index]

        # Swap the first and last token with the task token.
        data[f"event_tokens_{self.max_len}"][0] = self.tokenizer.task_to_token(task)
        data[f"event_tokens_{self.max_len}"][-1] = self.tokenizer.task_to_token(task)

        # Truncate and pad the data to the specified cutoff.
        data = truncate_and_pad(data, cutoff, self.max_len)

        # Prepare model input
        tokenized_input = self.tokenize_data(data[f"event_tokens_{self.max_len}"])
        concept_ids = tokenized_input["input_ids"].squeeze()
        labels = torch.tensor(labels)

        return {
            "concept_ids": concept_ids,
            "labels": labels,
            "task": task
        }

    def tokenize_data(self, sequence: Union[str, List[str]]) -> Any:
        """Tokenize the sequence and return input_ids and attention mask.

        Parameters
        ----------
        sequence : Union[str, List[str]]
            The sequence to be tokenized.

        Returns
        -------
        Any
            A dictionary containing input_ids and attention_mask.

        """
        return self.tokenizer(sequence, max_length=self.max_len)

    def balance_labels(self, task: str, positive_ratio: float) -> None:
        """Balance the labels for the specified task in the dataset.

        This function modifies the dataset to ensure that the ratio of positive samples
        to the total number of samples matches the specified positive_ratio,
        while keeping all positive data points.

        Parameters
        ----------
        task : str
            The task for which the labels need to be balanced.
        positive_ratio : float
            The desired positive ratio for the task.

        """
        # Separate positive and negative datapoints
        datapoints = self.task_to_index[task]
        positives = [data for data in datapoints if data[LABEL_INDEX] == 1]
        negatives = [data for data in datapoints if data[LABEL_INDEX] == 0]

        # Calculate the total number of samples needed to achieve the
        # desired positive ratio
        num_positives = len(positives)
        total_needed = int(num_positives / positive_ratio) - num_positives
        num_negatives_to_keep = min(len(negatives), total_needed)

        # Randomly select the negatives to keep
        negatives_kept = random.sample(negatives, num_negatives_to_keep)

        # Combine the kept negatives with all positives
        self.task_to_index[task] = positives + negatives_kept


decoder_dataset = FinetuneDatasetDecoder(
        data=fine_test,
        tokenizer=tokenizer,
        max_len=args.max_len,
        tasks=args.tasks,
        balance_guide=args.balance_guide,
)
decoder_dataset[12112]

{'concept_ids': tensor([20593,     3, 13054,  ...,     0,     0,     0]),
 'labels': tensor(0),
 'task': 'los_1week'}

In [None]:
class MambaEmbeddingsForCEHR(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
    def __init__(
        self,
        config: MambaConfig,
        max_position_embeddings: int = 2048,
        type_vocab_size: int = 8,
        time_embeddings_size: int = 16,
        visit_order_size: int = 3,
        layer_norm_eps: float = 1e-12,
        hidden_dropout_prob: float = 0.1,
    ) -> None:
        """Initiate wrapper class for embeddings used in BigBird CEHR classes."""
        super().__init__()
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.layer_norm_eps = layer_norm_eps
        self.hidden_dropout_prob = hidden_dropout_prob
        self.hidden_size = config.hidden_size

        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id,
        )
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            config.hidden_size,
        )
        self.token_type_embeddings = nn.Embedding(
            self.type_vocab_size,
            config.hidden_size,
        )
        self.visit_order_embeddings = nn.Embedding(
            self.max_position_embeddings,
            config.hidden_size,
        )
        self.time_embeddings = TimeEmbeddingLayer(
            embedding_size=time_embeddings_size,
            is_time_delta=True,
        )
        self.age_embeddings = TimeEmbeddingLayer(
            embedding_size=time_embeddings_size,
        )
        self.visit_segment_embeddings = VisitEmbedding(
            visit_order_size=visit_order_size,
            embedding_size=config.hidden_size,
        )
        self.scale_back_concat_layer = nn.Linear(
            config.hidden_size + 2 * time_embeddings_size,
            config.hidden_size,
        )

        self.time_stamps: Optional[torch.Tensor] = None
        self.ages: Optional[torch.Tensor] = None
        self.visit_orders: Optional[torch.Tensor] = None
        self.visit_segments: Optional[torch.Tensor] = None

        # self.LayerNorm is not snake-cased to stick with TensorFlow model
        # variable name and be able to load any TensorFlow checkpoint file.
        self.tanh = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory.
        self.position_embedding_type = getattr(
            config,
            "position_embedding_type",
            "absolute",
        )
        self.register_buffer(
            "position_ids",
            torch.arange(self.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long),
            persistent=False,
        )
        # End copy

    def cache_input(
        self,
        token_type_ids_batch: Optional[torch.Tensor] = None,
        position_ids_batch: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        time_stamps: Optional[torch.Tensor] = None,
        ages: Optional[torch.Tensor] = None,
        visit_orders: Optional[torch.Tensor] = None,
        visit_segments: Optional[torch.Tensor] = None,
    ) -> None:
        """Cache values for time_stamps, ages, visit_orders & visit_segments.

        These values will be used by the forward pass to change the final embedding.

        Parameters
        ----------
        token_type_ids_batch : torch.Tensor
            The token type IDs of the input data.
        position_ids_batch : torch.Tensor
            The position IDs of the input data.
        inputs_embeds : torch.Tensor
            The embeddings of the input data.
        time_stamps : torch.Tensor
            Time stamps of the input data.
        ages : torch.Tensor
            Ages of the input data.
        visit_orders : torch.Tensor
            Visit orders of the input data.
        visit_segments : torch.Tensor
            Visit segments of the input data.
        """
        self.token_type_ids_batch = token_type_ids_batch
        self.position_ids_batch = position_ids_batch
        self.inputs_embeds = inputs_embeds
        self.time_stamps = time_stamps
        self.ages = ages
        self.visit_orders = visit_orders
        self.visit_segments = visit_segments

    def clear_cache(self) -> None:
        """Delete the tensors cached by cache_input method."""
        del (
            self.token_type_ids_batch,
            self.position_ids_batch,
            self.inputs_embeds,
            self.time_stamps,
            self.ages,
            self.visit_orders,
            self.visit_segments,
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values_length: int = 0,
    ) -> Any:
        """Return the final embeddings of concept ids using input and cached values."""
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = self.inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if self.position_ids_batch is None:
            self.position_ids_batch = self.position_ids[
                :,
                past_key_values_length : seq_length + past_key_values_length,
            ]

        # Setting the token_type_ids to the registered buffer in constructor
        if self.token_type_ids_batch is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
                    input_shape[0],
                    seq_length,
                )
                self.token_type_ids_batch = buffered_token_type_ids_expanded
            else:
                self.token_type_ids_batch = torch.zeros(
                    input_shape,
                    dtype=torch.long,
                    device=self.position_ids.device,
                )

        if self.inputs_embeds is None:
            self.inputs_embeds = self.word_embeddings(input_ids)

        # Using cached values from a prior cache_input call
        time_stamps_embeds = self.time_embeddings(self.time_stamps)
        ages_embeds = self.age_embeddings(self.ages)
        visit_segments_embeds = self.visit_segment_embeddings(self.visit_segments)
        visit_order_embeds = self.visit_order_embeddings(self.visit_orders)

        position_embeds = self.position_embeddings(self.position_ids_batch)
        token_type_embeds = self.token_type_embeddings(self.token_type_ids_batch)

        self.inputs_embeds = torch.cat(
            (self.inputs_embeds, time_stamps_embeds, ages_embeds),
            dim=-1,
        )
        print(self.inputs_embeds.shape)
        self.inputs_embeds = self.tanh(self.scale_back_concat_layer(self.inputs_embeds))
        embeddings = self.inputs_embeds + token_type_embeds
        embeddings += position_embeds
        embeddings += visit_order_embeds
        embeddings += visit_segments_embeds

        embeddings = self.dropout(embeddings)
        embeddings = self.LayerNorm(embeddings)

        # Clear the cache for next forward call
        self.clear_cache()

        return embeddings