In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, JambaForSequenceClassification, JambaConfig
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
import evaluate
from datasets import load_dataset
import torch
import time

  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd


In [2]:
# Load the IMDB dataset from Hugging Face
imdb = load_dataset("imdb")
# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)  # padding left to collator
tokenized_imdb = imdb.map(preprocess_function, batched=True)

tokenized_imdb = tokenized_imdb.remove_columns(["text"])
tokenized_imdb.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
data_collator = DataCollatorWithPadding(tokenizer=tokenizer,padding=True, max_length=512)
#outputs = model.generate(input_ids, max_new_tokens=216)


You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.


In [3]:
train_dataloader = DataLoader(
    tokenized_imdb["train"],
    shuffle=True,
    batch_size=16,
    collate_fn=data_collator
)

In [4]:
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

In [5]:
next(iter(train_dataloader))



{'input_ids': tensor([[    0,     0,     0,  ..., 21569,  2031, 62959],
        [    0,     0,     0,  ...,  6882,  1806, 62959],
        [    0,     0,     0,  ..., 62965, 62963, 62959],
        ...,
        [    0,     0,     0,  ...,  1857, 20003, 62959],
        [    1,  4531, 14766,  ...,  1874,  1876, 62959],
        [    0,     0,     0,  ...,  1808,  2541, 62959]]), 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]]), 'labels': tensor([0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1])}

In [6]:
jamba_config = {
        "vocab_size": len(tokenizer.vocab),
        "hidden_size": 128,
        "num_hidden_layers": 2,
        "num_attention_heads": 8,
        "num_key_value_heads": 2,
        
        "num_experts_per_tok": 1,
        "num_experts": 1,
        "expert_layer_offset": 1,
        
        "attn_layer_period": 2,
        "attn_layer_offset": 1,
        
        "use_mamba_kernels": False, 
        "mamba_d_state": 16,
        "mamba_d_conv": 4,
        "mamba_expand": 2,
        "output_router_logits":False,
    }
# config = JambaConfig(
#     vocab_size=50257,
#     dim=512,
#     num_hidden_layers=8, # number of hidden layers in transformer block

#     hidden_size=256,
#     attn_layer_period = 8, # every 4 layer there is a attention layer
#     attention_offset = 0, # offset for attention layer
#     num_attention_heads=8,
#     num_key_value_heads = 8, # if equal with num_attention_heads, then use MultiheadAttention

#     d_conv=4,
#     d_state=256,
#     num_experts_per_tok = 2,  # a router choosing 2 experts per token
#     num_experts=2, # total number of experts
#     expert_layer_period =4, # every 4 layer there is a expert layer
# )
config = JambaConfig(**jamba_config)
model  = JambaForSequenceClassification(config)


In [7]:
model

JambaForSequenceClassification(
  (model): JambaModel(
    (embed_tokens): Embedding(65536, 128, padding_idx=0)
    (layers): ModuleList(
      (0): JambaMambaDecoderLayer(
        (mamba): JambaMambaMixer(
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
          (act): SiLU()
          (in_proj): Linear(in_features=128, out_features=512, bias=False)
          (x_proj): Linear(in_features=256, out_features=40, bias=False)
          (dt_proj): Linear(in_features=8, out_features=256, bias=True)
          (out_proj): Linear(in_features=256, out_features=128, bias=False)
          (dt_layernorm): JambaRMSNorm((8,), eps=1e-06)
          (b_layernorm): JambaRMSNorm((16,), eps=1e-06)
          (c_layernorm): JambaRMSNorm((16,), eps=1e-06)
        )
        (feed_forward): JambaMLP(
          (gate_proj): Linear(in_features=128, out_features=14336, bias=False)
          (up_proj): Linear(in_features=128, out_features=14336, bias=False)
          (

### Standard Jamba Config

In [8]:
# # standard jamba config
# from transformers import JambaConfig, JambaModel
# jamba_config = {
#     "vocab_size": 65536,
#     "hidden_size": 4096,
#     "intermediate_size": 14336,
#     "num_hidden_layers": 8,
#     "num_attention_heads": 32,
#     "num_key_value_heads": 8,

#     # Expert / MoE (optional)
#     "num_experts_per_tok": 2,
#     "num_experts": 2,
#     "expert_layer_offset": 1,

#     # Attention layer config
#     "attn_layer_period": 8,
#     "attn_layer_offset": 1,

#     # Mamba-specific config
#     "use_mamba_kernels": True,
#     "mamba_d_state": 16,
#     "mamba_d_conv": 4,
#     "mamba_expand": 2,
#     "mamba_dt_rank": "auto",
#     "mamba_conv_bias": True,
#     "mamba_proj_bias": False,
# }

# config = JambaConfig(**jamba_config)
# model = JambaModel(config)
# model

In [9]:
model.parameters()

<generator object Module.parameters at 0x7f9238e1f5a0>

In [10]:
import numpy as np
accuracy_metric = evaluate.load("accuracy")

In [11]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_metric.compute(predictions=predictions, references=labels)

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
def get_model_size_in_mb(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
        
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
        
    total_size = param_size + buffer_size  # in bytes
    return total_size / (1024 ** 2)  # convert to MB


size_mb = get_model_size_in_mb(model)
print(f"Model size on GPU: {size_mb:.2f} MB")

Model size on GPU: 74.60 MB


In [14]:
# Count total number of parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Print in millions
print(f"Total Parameters: {total_params:,} ({total_params / 1e6:.2f}M)")
print(f"Trainable Parameters: {trainable_params:,} ({trainable_params / 1e6:.2f}M)")

# Approximate model size in MB (assuming 32-bit floats = 4 bytes per parameter)
model_size_mb = trainable_params * 4 / (1024 ** 2)
print(f"Approximate Model Size: {model_size_mb:.2f} MB")

Total Parameters: 19,557,032 (19.56M)
Trainable Parameters: 19,557,032 (19.56M)
Approximate Model Size: 74.60 MB


In [15]:
def get_tensor_size_in_mb(tensor):
    return tensor.nelement() * tensor.element_size() / (1024 ** 2)


In [16]:
import torch.nn as nn 
import torch.optim as optim

In [17]:
torch.cuda.empty_cache()
model.to(device)
# Loss function and optimizer

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop

epochs = 1
totaltime, endtotal = 0,0
for epoch in range(epochs):
    totaltime = time.time()
    model.train() 
    total_loss = 0
    len_data = 0
    for batch in train_dataloader:
        start_time = time.time()
        optimizer.zero_grad()  # Zero the gradients
        inputs = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        targets = batch['labels'].to(device)
        #print("Data_size", get_tensor_size_in_mb(inputs)+ get_tensor_size_in_mb(attention_mask) + get_tensor_size_in_mb(targets))
        # Forward pass
        outputs = model(input_ids = inputs, attention_mask=attention_mask,labels = targets) 
        loss = outputs.loss
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        end_time = time.time()
        total_loss += loss
        print("Time taken for batch: ", end_time - start_time)

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_dataloader)}")
    endtotal = time.time()
    print("Total time taken for epoch: ", endtotal - totaltime)
print("Training complete!")

Jamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was provided, so no cache will be returned.


Time taken for batch:  11.768513917922974
Time taken for batch:  2.6130831241607666


OutOfMemoryError: CUDA out of memory. Tried to allocate 1014.00 MiB. GPU 0 has a total capacity of 10.75 GiB of which 818.69 MiB is free. Including non-PyTorch memory, this process has 9.94 GiB memory in use. Of the allocated memory 9.43 GiB is allocated by PyTorch, and 327.96 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
"https://stackoverflow.com/questions/56805951/valueerror-error-initializing-torch-distributed-using-env-rendezvous-enviro"

SyntaxError: invalid syntax (2771961923.py, line 1)