In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from dataset import PretrainDataset
from transformers import Mamba2Config
from modeling import Mamba2ForEHRModeling

# Device string used to place the model (and, later, each training batch).
device = "cuda:0"

# Architecture hyperparameters handed to Mamba2Config.
config = Mamba2Config(vocab_size=16384, hidden_size=768, num_heads=24, num_hidden_layers=32)

# Build the model, cast it to bfloat16, then move it onto the device.
model = Mamba2ForEHRModeling(config)
model = model.to(torch.bfloat16)
model = model.to(device)

def pad_sequences(sequences, padding_value=0):
    """Right-pad every sequence to the length of the longest one.

    Args:
        sequences: iterable of sequences (lists, tuples, ...). May be empty.
        padding_value: value appended to sequences shorter than the longest.

    Returns:
        A new list of new lists, all of equal length. The inputs are not
        modified. An empty input yields an empty list.
    """
    # Materialise once so generators work and each item becomes a plain list
    # (the padding below relies on list concatenation).
    rows = [list(seq) for seq in sequences]
    if not rows:
        # max() over an empty iterable would raise ValueError.
        return []

    max_length = max(len(row) for row in rows)
    return [row + [padding_value] * (max_length - len(row)) for row in rows]

def collate(batch, max_length=4096):
    """Truncate, pad and tensorize a list of samples into a single batch.

    Args:
        batch: list of samples; each sample is indexed by the five keys
            listed in ``keys`` below, and each value is sliced and padded.
        max_length: maximum number of positions kept per sequence. Defaults
            to 4096, the value that was previously hard-coded.

    Returns:
        dict mapping each key to a ``torch.Tensor`` built from the padded
        sequences of that key.
    """
    keys = ("concept_ids", "age_ids", "time_ids", "segment_ids", "visit_order_ids")

    # NOTE: reads the module-level `train_dataset`, which must exist by the
    # time the DataLoader first calls this function.
    concept_pad_id = train_dataset.tokenizer.pad_token_id

    result = {}
    for key in keys:
        truncated = [sample[key][:max_length] for sample in batch]
        # concept_ids are padded with the tokenizer's pad id; every other
        # field is padded with -1.
        pad_id = concept_pad_id if key == "concept_ids" else -1
        result[key] = torch.tensor(pad_sequences(truncated, pad_id))

    return result
# Number of samples per batch produced by the DataLoader.
batch_size = 4

# Dataset read from ./dataset; collate() also reads its tokenizer for the pad id.
train_dataset = PretrainDataset(directory="./dataset")

# No shuffle argument is passed, so the DataLoader's default ordering applies.
trainloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    collate_fn=collate,
)

  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
Loading/Tokenizing: 100%|███████████████████████████████████████████████████████████████| 14/14 [03:07<00:00, 13.40s/it]


In [2]:
from transformers import get_linear_schedule_with_warmup
import wandb

# Training hyperparameters, defined first so they can be recorded in wandb.
epochs = 15
gradient_accumulation_steps = 1
lr = 5e-5
betas = (0.9, 0.95)

# Number of scheduler steps over the whole run; the first 10% are warmup.
total_steps = (epochs * len(trainloader)) // (gradient_accumulation_steps)
warmup_steps = int(total_steps * 0.1)

# Record the run's hyperparameters with the wandb run instead of an empty config.
wandb.init(
    project="mamba-ehr-modeling",
    config={
        "epochs": epochs,
        "batch_size": batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "lr": lr,
        "betas": betas,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    },
)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas)

# Linear warmup for warmup_steps, then linear decay until total_steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

[34m[1mwandb[0m: Currently logged in as: [33manothy[0m ([33manothy1[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
from tqdm import tqdm
# Counts processed batches; used to throttle wandb logging to every 10th step.
step_counter = 0

#model = torch.compile(model)

for epoch in range(epochs):
    for batch in tqdm(trainloader, desc=f"Epoch {epoch + 1}"):
        # Move every tensor of the collated batch onto the training device.
        batch = {key: tensor.to(device) for key, tensor in batch.items()}

        # NOTE(review): labels are the padded concept_ids themselves; whether
        # pad positions are excluded from the loss depends on
        # Mamba2ForEHRModeling - confirm.
        output = model(
            input_ids=batch["concept_ids"],
            age_ids=batch["age_ids"],
            time_ids=batch["time_ids"],
            segment_ids=batch["segment_ids"],
            visit_order_ids=batch["visit_order_ids"],
            labels=batch["concept_ids"]
        )
        loss = output.loss
        # Read the scalar once; it is reused for the check and the logging below.
        loss_value = loss.item()

        # Debug aid: show batches that produce an unusually low loss.
        if loss_value < 1:
            print(loss, batch["concept_ids"])

        # Clear gradients left over from the previous batch (or from an
        # interrupted earlier run of this cell); without this, backward()
        # keeps adding into the existing .grad buffers.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()

        step_counter += 1

        if step_counter % 10 == 0:
            wandb.log({"train/loss": loss_value, "lr": optimizer.param_groups[-1]['lr']})

Epoch 1:   0%|                                                                    | 57/71694 [00:57<20:12:25,  1.02s/it]


KeyboardInterrupt: 