# Fine-tune Jamba-v0.1 on a single A100 (40GB VRAM) using QLoRA

In [None]:
# Print the attached GPU model, driver/CUDA versions and current memory usage.
! nvidia-smi

Fri Mar 29 00:49:59 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0              46W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

### Install Flash Attention 2

In [None]:
! pip install ninja packaging
! pip install flash-attn --no-build-isolation



### Install Required Dependencies

In [None]:
!pip install -U "transformers>=4.39.0"
!pip install mamba-ssm "causal-conv1d>=1.2.0"
!pip install peft trl bitsandbytes

Collecting transformers>=4.39.0
  Downloading transformers-4.39.2-py3-none-any.whl (8.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m37.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.38.2
    Uninstalling transformers-4.38.2:
      Successfully uninstalled transformers-4.38.2
Successfully installed transformers-4.39.2


In [None]:
# Hub repository id of the base checkpoint; used below to load both the tokenizer and the model.
model_id = "ai21labs/Jamba-v0.1"

In [None]:
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig

In [None]:
# 4-bit bitsandbytes quantization settings used when loading the base model.
#
# The module skip-list option is called `llm_int8_skip_modules` even for 4-bit
# loading. `llm_int4_skip_modules` is not a BitsAndBytesConfig argument: it is
# swallowed as an unused keyword, so the Mamba modules would be quantized anyway.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_skip_modules=["mamba"],  # leave the Mamba modules unquantized
)

In [None]:
# The tokenizer is fetched from the same Hub repository as the weights.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the base model using the quantization settings defined earlier.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
    device_map="auto",
    # The repository ships its own modeling code (modeling_jamba.py), which is
    # downloaded and executed; review it before trusting it.
    trust_remote_code=True,
    # Mamba kernels are switched off because they triggered a recurring error
    # for the notebook author.
    use_mamba_kernels=False,
)

modeling_jamba.py:   0%|          | 0.00/99.8k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/ai21labs/Jamba-v0.1:
- modeling_jamba.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors.index.json:   0%|          | 0.00/107k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/21 [00:00<?, ?it/s]

model-00001-of-00021.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00021.safetensors:   0%|          | 0.00/4.88G [00:00<?, ?B/s]

model-00003-of-00021.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00004-of-00021.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00005-of-00021.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00006-of-00021.safetensors:   0%|          | 0.00/4.88G [00:00<?, ?B/s]

model-00007-of-00021.safetensors:   0%|          | 0.00/4.88G [00:00<?, ?B/s]

model-00008-of-00021.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00009-of-00021.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00010-of-00021.safetensors:   0%|          | 0.00/4.88G [00:00<?, ?B/s]

model-00011-of-00021.safetensors:   0%|          | 0.00/4.88G [00:00<?, ?B/s]

model-00012-of-00021.safetensors:   0%|          | 0.00/4.88G [00:00<?, ?B/s]

model-00013-of-00021.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00014-of-00021.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00015-of-00021.safetensors:   0%|          | 0.00/4.88G [00:00<?, ?B/s]

model-00016-of-00021.safetensors:   0%|          | 0.00/4.88G [00:00<?, ?B/s]

model-00017-of-00021.safetensors:   0%|          | 0.00/4.91G [00:00<?, ?B/s]

model-00018-of-00021.safetensors:   0%|          | 0.00/4.91G [00:00<?, ?B/s]

model-00019-of-00021.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00020-of-00021.safetensors:   0%|          | 0.00/4.88G [00:00<?, ?B/s]

model-00021-of-00021.safetensors:   0%|          | 0.00/4.65G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/21 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [None]:
# Write the model and the tokenizer side by side into one directory.
# NOTE(review): the path looks like a Colab Google Drive mount, but no mount
# step is visible in this notebook — confirm the location exists before running.
save_dir = "/content/drive/MyDrive/jamba"
model.save_pretrained(save_dir)
# Last expression: the tuple of written tokenizer files is shown as the cell output.
tokenizer.save_pretrained(save_dir)

('/content/drive/MyDrive/jamba/tokenizer_config.json',
 '/content/drive/MyDrive/jamba/special_tokens_map.json',
 '/content/drive/MyDrive/jamba/tokenizer.model',
 '/content/drive/MyDrive/jamba/added_tokens.json',
 '/content/drive/MyDrive/jamba/tokenizer.json')

In [None]:
# Training data: the "train" split of the english_quotes dataset from the Hub.
dataset = load_dataset(
    path="Abirate/english_quotes",
    split="train",
)

Downloading readme:   0%|          | 0.00/5.55k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/647k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# Trainer hyperparameters: 3 epochs, per-device batch size 4, loss logged every
# 10 steps, checkpoints under ./results and logs under ./logs.
#
# The learning rate is lowered from 2e-3 to 2e-4. With 2e-3 the recorded
# training loss climbed steadily instead of falling (1.71 at step 10, 4.50 at
# step 100, 6.42 averaged over the run), i.e. the fine-tuning diverged.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-4,
)

In [None]:
# LoRA adapter settings for causal language modelling: rank 8, bias mode "none".
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    bias="none",
    # Module names that receive adapters.
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
)

In [None]:
# Supervised fine-tuning trainer: combines the loaded model and tokenizer with
# the LoRA and training configurations, and reads text from the dataset's
# "quote" field with a maximum sequence length of 256.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="quote",
    max_seq_length=256,
    peft_config=lora_config,
    args=training_args,
)

Map:   0%|          | 0/2508 [00:00<?, ? examples/s]

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [None]:
# Run the fine-tuning loop; the returned TrainOutput is displayed as the cell result.
trainer.train()



Step,Training Loss
10,1.7129
20,2.214
30,2.7171
40,2.7197
50,2.1063
60,3.2467
70,3.6394
80,4.0368
90,4.4508
100,4.4956




Step,Training Loss
10,1.7129
20,2.214
30,2.7171
40,2.7197
50,2.1063
60,3.2467
70,3.6394
80,4.0368
90,4.4508
100,4.4956




TrainOutput(global_step=7524, training_loss=6.416611096758845, metrics={'train_runtime': 10959.7683, 'train_samples_per_second': 0.687, 'train_steps_per_second': 0.687, 'total_flos': 8.398573367643648e+16, 'train_loss': 6.416611096758845, 'epoch': 3.0})