In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, JambaForSequenceClassification, JambaConfig
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
import evaluate
from datasets import load_dataset
import torch
import time

  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd


In [2]:
# Load the IMDB dataset from Hugging Face
imdb = load_dataset("imdb")
# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)  # padding left to collator
tokenized_imdb = imdb.map(preprocess_function, batched=True)

tokenized_imdb = tokenized_imdb.remove_columns(["text"])
tokenized_imdb.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
#outputs = model.generate(input_ids, max_new_tokens=216)


You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.


Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [3]:
train_dataloader = DataLoader(
    tokenized_imdb["train"],
    shuffle=True,
    batch_size=8,
    collate_fn=data_collator
)

In [4]:
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

In [5]:
next(iter(train_dataloader))

{'input_ids': tensor([[    0,     0,     0,  ...,  2527,  9621, 62959],
        [    1,  6245, 62966,  ...,  2793,  8439, 13645],
        [    0,     0,     0,  ..., 62936,  2196, 62959],
        ...,
        [    0,     0,     0,  ...,  1876,  2799, 62959],
        [    0,     0,     0,  ...,  1865,  1870, 62959],
        [    0,     0,     0,  ...,  1836,  1895, 63052]]), 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]]), 'labels': tensor([1, 1, 0, 1, 0, 1, 1, 0])}

In [None]:
jamba_config = {
    "vocab_size": len(tokenizer.vocab),
    "hidden_size": 256,
    "intermediate_size": 14336,
    "num_hidden_layers": 2,
    "num_attention_heads": 4,
    "num_key_value_heads": 4,

    # Expert / MoE (optional)
    "num_experts_per_tok": 1,
    "num_experts": 2,
    "expert_layer_offset": 1,

    # Attention layer config
    "attn_layer_period": 1,
    "attn_layer_offset": 1,

    # Mamba-specific config
    "use_mamba_kernels": False,
    "mamba_d_state": 16,
    "mamba_d_conv": 4,
    "mamba_expand": 2,
    "output_router_logits":False,
    "num_labels":2,
}
# config = JambaConfig(
#     vocab_size=50257,
#     dim=512,
#     num_hidden_layers=8, # number of hidden layers in transformer block

#     hidden_size=256,
#     attn_layer_period = 8, # every 4 layer there is a attention layer
#     attention_offset = 0, # offset for attention layer
#     num_attention_heads=8,
#     num_key_value_heads = 8, # if equal with num_attention_heads, then use MultiheadAttention

#     d_conv=4,
#     d_state=256,
#     num_experts_per_tok = 2,  # a router choosing 2 experts per token
#     num_experts=2, # total number of experts
#     expert_layer_period =4, # every 4 layer there is a expert layer
# )
config = JambaConfig(**jamba_config)
model  = JambaForSequenceClassification(config, num_labels=2)


TypeError: JambaForSequenceClassification.__init__() got an unexpected keyword argument 'num_labels'

In [7]:
model

JambaForSequenceClassification(
  (model): JambaModel(
    (embed_tokens): Embedding(65536, 256, padding_idx=0)
    (layers): ModuleList(
      (0): JambaMambaDecoderLayer(
        (mamba): JambaMambaMixer(
          (conv1d): Conv1d(512, 512, kernel_size=(4,), stride=(1,), padding=(3,), groups=512)
          (act): SiLU()
          (in_proj): Linear(in_features=256, out_features=1024, bias=False)
          (x_proj): Linear(in_features=512, out_features=48, bias=False)
          (dt_proj): Linear(in_features=16, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=256, bias=False)
          (dt_layernorm): JambaRMSNorm((16,), eps=1e-06)
          (b_layernorm): JambaRMSNorm((16,), eps=1e-06)
          (c_layernorm): JambaRMSNorm((16,), eps=1e-06)
        )
        (feed_forward): JambaMLP(
          (gate_proj): Linear(in_features=256, out_features=14336, bias=False)
          (up_proj): Linear(in_features=256, out_features=14336, bias=False)
        

### Standard Jamba Config

In [None]:
# # standard jamba config
# from transformers import JambaConfig, JambaModel
# jamba_config = {
#     "vocab_size": 65536,
#     "hidden_size": 4096,
#     "intermediate_size": 14336,
#     "num_hidden_layers": 8,
#     "num_attention_heads": 32,
#     "num_key_value_heads": 8,

#     # Expert / MoE (optional)
#     "num_experts_per_tok": 2,
#     "num_experts": 2,
#     "expert_layer_offset": 1,

#     # Attention layer config
#     "attn_layer_period": 8,
#     "attn_layer_offset": 1,

#     # Mamba-specific config
#     "use_mamba_kernels": True,
#     "mamba_d_state": 16,
#     "mamba_d_conv": 4,
#     "mamba_expand": 2,
#     "mamba_dt_rank": "auto",
#     "mamba_conv_bias": True,
#     "mamba_proj_bias": False,
# }

# config = JambaConfig(**jamba_config)
# model = JambaModel(config)
# model

JambaModel(
  (embed_tokens): Embedding(65536, 4096, padding_idx=0)
  (layers): ModuleList(
    (0): JambaMambaDecoderLayer(
      (mamba): JambaMambaMixer(
        (conv1d): Conv1d(8192, 8192, kernel_size=(4,), stride=(1,), padding=(3,), groups=8192)
        (act): SiLU()
        (in_proj): Linear(in_features=4096, out_features=16384, bias=False)
        (x_proj): Linear(in_features=8192, out_features=288, bias=False)
        (dt_proj): Linear(in_features=256, out_features=8192, bias=True)
        (out_proj): Linear(in_features=8192, out_features=4096, bias=False)
        (dt_layernorm): JambaRMSNorm((256,), eps=1e-06)
        (b_layernorm): JambaRMSNorm((16,), eps=1e-06)
        (c_layernorm): JambaRMSNorm((16,), eps=1e-06)
      )
      (feed_forward): JambaMLP(
        (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (down_proj): Linear(in_features=14336, out_features=4096, bias

In [10]:
import numpy as np
accuracy_metric = evaluate.load("accuracy")

In [11]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_metric.compute(predictions=predictions, references=labels)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
def get_model_size_in_mb(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
        
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
        
    total_size = param_size + buffer_size  # in bytes
    return total_size / (1024 ** 2)  # convert to MB


size_mb = get_model_size_in_mb(model)
print(f"Model size on GPU: {size_mb:.2f} MB")

Model size on GPU: 193.35 MB


In [14]:
def get_tensor_size_in_mb(tensor):
    return tensor.nelement() * tensor.element_size() / (1024 ** 2)


In [11]:
import torch.nn as nn 
import torch.optim as optim

In [12]:
torch.cuda.empty_cache()
model.to(device)
# Loss function and optimizer

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 5
for epoch in range(epochs):
    model.train() 
    total_loss = 0
    len_data = 0
    for batch in train_dataloader:
        start_time = time.time()
        optimizer.zero_grad()  # Zero the gradients
        inputs = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        targets = batch['labels'].to(device)
        #print("Data_size", get_tensor_size_in_mb(inputs)+ get_tensor_size_in_mb(attention_mask) + get_tensor_size_in_mb(targets))
        # Forward pass
        outputs = model(input_ids = inputs, attention_mask=attention_mask,labels = targets) 
        print(outputs)
        loss = outputs.loss
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        end_time = time.time()
        total_loss += loss
        print("Time taken for batch: ", end_time - start_time)
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_dataloader)}")
print("Training complete!")

Jamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was provided, so no cache will be returned.


SequenceClassifierOutputWithPast(loss=tensor(0.7621, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[ 0.2880, -0.7416],
        [ 0.2785, -0.7359],
        [ 0.2454, -0.6716],
        [ 0.2198, -0.7384],
        [ 0.1910, -0.7401],
        [ 0.0114, -0.4809],
        [ 0.0415,  0.4010],
        [ 0.0848, -0.5266]], device='cuda:0', grad_fn=<IndexBackward0>), past_key_values=None, hidden_states=None, attentions=None)
Time taken for batch:  20.169976472854614
SequenceClassifierOutputWithPast(loss=tensor(2.9234, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[-2.5542,  2.6499],
        [ 0.0822,  0.1846],
        [-2.7989,  2.4625],
        [ 0.2692,  0.4268],
        [-0.7409, -0.2041],
        [ 0.0714,  0.1226],
        [-2.5638,  2.6394],
        [-2.5653,  2.6405]], device='cuda:0', grad_fn=<IndexBackward0>), past_key_values=None, hidden_states=None, attentions=None)


KeyboardInterrupt: 

25000

In [8]:
from torch.distributed.device_mesh import init_device_mesh

# creating a device mesh that connnets 8 GPUs in a host
tp_mesh = init_device_mesh("cuda",(8,))

ValueError: Error initializing torch.distributed using env:// rendezvous: environment variable RANK expected, but not set

In [None]:
"https://stackoverflow.com/questions/56805951/valueerror-error-initializing-torch-distributed-using-env-rendezvous-enviro"

SyntaxError: invalid syntax (2771961923.py, line 1)

In [2]:
from datetime import datetime

In [3]:
datetime.now().strftime("%Y%m%d-%H%M%S")

'20250421-162858'