In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, JambaForSequenceClassification, JambaConfig
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
import evaluate
from datasets import load_dataset
import torch
import time

In [None]:
# Load the IMDB dataset from Hugging Face
imdb = load_dataset("imdb")
# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)  # padding left to collator
tokenized_imdb = imdb.map(preprocess_function, batched=True)

tokenized_imdb = tokenized_imdb.remove_columns(["text"])
tokenized_imdb.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
#outputs = model.generate(input_ids, max_new_tokens=216)


In [5]:
train_dataloader = DataLoader(
    tokenized_imdb["train"],
    shuffle=True,
    batch_size=8,
    collate_fn=data_collator
)

In [6]:
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

In [7]:
next(iter(train_dataloader))

{'input_ids': tensor([[    0,     0,     0,  ...,  4007,  2908, 62959],
        [    0,     0,     0,  ...,  1808,  2686, 62959],
        [    0,     0,     0,  ..., 63009,  3910, 62959],
        ...,
        [    1,  3921,  2100,  ..., 19953,  2482, 25691],
        [    0,     0,     0,  ...,  6868,  1895, 62959],
        [    0,     0,     0,  ...,  1831, 17135, 62959]]), 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]]), 'labels': tensor([0, 0, 0, 0, 1, 0, 0, 1])}

In [None]:
jamba_config = {
    "vocab_size": len(tokenizer.vocab),
    "hidden_size": 256,
    "intermediate_size": 14336,
    "num_hidden_layers": 2,
    "num_attention_heads": 4,
    "num_key_value_heads": 4,

    # Expert / MoE (optional)
    "num_experts_per_tok": 1,
    "num_experts": 2,
    "expert_layer_offset": 1,

    # Attention layer config
    "attn_layer_period": 1,
    "attn_layer_offset": 1,

    # Mamba-specific config
    "use_mamba_kernels": False,
    "mamba_d_state": 16,
    "mamba_d_conv": 4,
    "mamba_expand": 2,
}
# config = JambaConfig(
#     vocab_size=50257,
#     dim=512,
#     num_hidden_layers=8, # number of hidden layers in transformer block

#     hidden_size=256,
#     attn_layer_period = 8, # every 4 layer there is a attention layer
#     attention_offset = 0, # offset for attention layer
#     num_attention_heads=8,
#     num_key_value_heads = 8, # if equal with num_attention_heads, then use MultiheadAttention

#     d_conv=4,
#     d_state=256,
#     num_experts_per_tok = 2,  # a router choosing 2 experts per token
#     num_experts=2, # total number of experts
#     expert_layer_period =4, # every 4 layer there is a expert layer
# )
config = JambaConfig(**jamba_config)
model  = JambaForSequenceClassification(config)


In [9]:
model

JambaForSequenceClassification(
  (model): JambaModel(
    (embed_tokens): Embedding(65536, 256, padding_idx=0)
    (layers): ModuleList(
      (0): JambaMambaDecoderLayer(
        (mamba): JambaMambaMixer(
          (conv1d): Conv1d(512, 512, kernel_size=(4,), stride=(1,), padding=(3,), groups=512)
          (act): SiLU()
          (in_proj): Linear(in_features=256, out_features=1024, bias=False)
          (x_proj): Linear(in_features=512, out_features=48, bias=False)
          (dt_proj): Linear(in_features=16, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=256, bias=False)
          (dt_layernorm): JambaRMSNorm((16,), eps=1e-06)
          (b_layernorm): JambaRMSNorm((16,), eps=1e-06)
          (c_layernorm): JambaRMSNorm((16,), eps=1e-06)
        )
        (feed_forward): JambaMLP(
          (gate_proj): Linear(in_features=256, out_features=14336, bias=False)
          (up_proj): Linear(in_features=256, out_features=14336, bias=False)
        

### Standard Jamba Config

In [None]:
# # standard jamba config
# from transformers import JambaConfig, JambaModel
# jamba_config = {
#     "vocab_size": 65536,
#     "hidden_size": 4096,
#     "intermediate_size": 14336,
#     "num_hidden_layers": 8,
#     "num_attention_heads": 32,
#     "num_key_value_heads": 8,

#     # Expert / MoE (optional)
#     "num_experts_per_tok": 2,
#     "num_experts": 2,
#     "expert_layer_offset": 1,

#     # Attention layer config
#     "attn_layer_period": 8,
#     "attn_layer_offset": 1,

#     # Mamba-specific config
#     "use_mamba_kernels": True,
#     "mamba_d_state": 16,
#     "mamba_d_conv": 4,
#     "mamba_expand": 2,
#     "mamba_dt_rank": "auto",
#     "mamba_conv_bias": True,
#     "mamba_proj_bias": False,
# }

# config = JambaConfig(**jamba_config)
# model = JambaModel(config)
# model

JambaModel(
  (embed_tokens): Embedding(65536, 4096, padding_idx=0)
  (layers): ModuleList(
    (0): JambaMambaDecoderLayer(
      (mamba): JambaMambaMixer(
        (conv1d): Conv1d(8192, 8192, kernel_size=(4,), stride=(1,), padding=(3,), groups=8192)
        (act): SiLU()
        (in_proj): Linear(in_features=4096, out_features=16384, bias=False)
        (x_proj): Linear(in_features=8192, out_features=288, bias=False)
        (dt_proj): Linear(in_features=256, out_features=8192, bias=True)
        (out_proj): Linear(in_features=8192, out_features=4096, bias=False)
        (dt_layernorm): JambaRMSNorm((256,), eps=1e-06)
        (b_layernorm): JambaRMSNorm((16,), eps=1e-06)
        (c_layernorm): JambaRMSNorm((16,), eps=1e-06)
      )
      (feed_forward): JambaMLP(
        (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (down_proj): Linear(in_features=14336, out_features=4096, bias

In [10]:
import numpy as np
accuracy_metric = evaluate.load("accuracy")

In [11]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_metric.compute(predictions=predictions, references=labels)

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
def get_model_size_in_mb(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
        
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
        
    total_size = param_size + buffer_size  # in bytes
    return total_size / (1024 ** 2)  # convert to MB


size_mb = get_model_size_in_mb(model)
print(f"Model size on GPU: {size_mb:.2f} MB")

Model size on GPU: 193.35 MB


In [14]:
def get_tensor_size_in_mb(tensor):
    return tensor.nelement() * tensor.element_size() / (1024 ** 2)


In [15]:
import torch.nn as nn 
import torch.optim as optim

In [16]:
torch.cuda.empty_cache()
model.to(device)
# Loss function and optimizer

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 5
for epoch in range(epochs):
    model.train() 
    total_loss = 0
    len_data = 0
    for batch in train_dataloader:
        start_time = time.time()
        optimizer.zero_grad()  # Zero the gradients
        inputs = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        targets = batch['labels'].to(device)
        #print("Data_size", get_tensor_size_in_mb(inputs)+ get_tensor_size_in_mb(attention_mask) + get_tensor_size_in_mb(targets))
        # Forward pass
        outputs = model(input_ids = inputs, attention_mask=attention_mask,labels = targets) 
        loss = outputs.loss
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        end_time = time.time()
        total_loss += loss
        print("Time taken for batch: ", end_time - start_time)
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_dataloader)}")
print("Training complete!")

Data_size 0.07501220703125


Jamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was provided, so no cache will be returned.


Time taken for batch:  20.96748375892639
Data_size 0.12152099609375
Time taken for batch:  19.315913438796997
Data_size 0.07318115234375
Time taken for batch:  7.181234836578369
Data_size 0.03680419921875
Time taken for batch:  1.962125539779663
Data_size 0.06378173828125
Time taken for batch:  5.347221374511719
Data_size 0.06329345703125
Time taken for batch:  5.368344306945801
Data_size 0.14129638671875
Time taken for batch:  26.02852749824524
Data_size 0.09588623046875
Time taken for batch:  12.141685009002686
Data_size 0.12298583984375
Time taken for batch:  19.72804594039917
Data_size 0.07293701171875
Time taken for batch:  7.188098907470703
Data_size 0.04217529296875
Time taken for batch:  2.4103729724884033
Data_size 0.06085205078125
Time taken for batch:  4.853060245513916
Data_size 0.08587646484375
Time taken for batch:  9.708363056182861
Data_size 0.08660888671875
Time taken for batch:  9.91413140296936
Data_size 0.08880615234375
Time taken for batch:  10.418190479278564
Data

OutOfMemoryError: CUDA out of memory. Tried to allocate 884.00 MiB. GPU 0 has a total capacity of 10.75 GiB of which 524.69 MiB is free. Including non-PyTorch memory, this process has 10.23 GiB memory in use. Of the allocated memory 9.35 GiB is allocated by PyTorch, and 710.24 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

25000