# Fine-Tune Jamba Base on Amazon SageMaker

This notebook shows how to fine-tune Jamba Base (`ai21labs/Jamba-v0.1`) on Amazon SageMaker. It was tested on an ml.p4d.24xlarge SageMaker notebook instance. Make sure the notebook instance has at least 200 GB of storage.
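Before downloading anything, you can confirm the volume size from inside the notebook. The sketch below uses only the standard library and assumes the notebook instance's EBS volume is mounted at `~/SageMaker` (the SageMaker default); `disk_gb` is a hypothetical helper name, and the 200 GB threshold is the one stated above.

```python
import os
import shutil

def disk_gb(path: str) -> tuple[float, float]:
    """Return (total, free) size in GB of the filesystem that holds `path`."""
    path = os.path.abspath(os.path.expanduser(path))
    # Walk up to the nearest existing directory so the check works even if `path` does not exist yet.
    while not os.path.exists(path):
        parent = os.path.dirname(path)
        if parent == path:
            break
        path = parent
    usage = shutil.disk_usage(path)
    return usage.total / 1e9, usage.free / 1e9

total_gb, free_gb = disk_gb("~/SageMaker")
print(f"Volume size: {total_gb:.0f} GB, free: {free_gb:.0f} GB")
if total_gb < 200:
    print("Warning: this volume is smaller than the recommended 200 GB.")
```

Free space matters as much as total size here, since the model download and the training checkpoints in `./results` both land on this volume.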

In [None]:
!pip3 install --root-user-action=ignore  setuptools==69.5.1
!pip3 install --root-user-action=ignore trl==0.9.4
!pip3 install --root-user-action=ignore peft==0.5.0
!pip3 install --root-user-action=ignore transformers==4.42.3
!pip3 install --root-user-action=ignore tensorboard==2.17.0
!pip3 install --root-user-action=ignore deepspeed==0.14.4
!pip3 install --root-user-action=ignore accelerate==0.32.1
!pip install mamba-ssm "causal-conv1d>=1.2.0"
!pip install -U bitsandbytes
!pip install flash-attn --no-build-isolation

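Because the Jamba Base model is large (~100 GB), we move the Hugging Face cache off the default `~/.cache` location and onto the EBS volume that backs the SageMaker notebook instance (mounted at `~/SageMaker`). Run this cell before importing `transformers` or `datasets`, because those libraries read the cache variables when they are imported.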
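The ~100 GB figure follows from the parameter count: Jamba v0.1 has roughly 52B total parameters (about 12B active per token), and bfloat16 stores 2 bytes per parameter. The arithmetic below is a rough sketch that counts the weights only (no activations, gradients, or optimizer state); `weights_gb` is a hypothetical helper name.

```python
def weights_gb(num_params: float, bytes_per_param: float) -> float:
    """Approximate size of the model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

JAMBA_TOTAL_PARAMS = 52e9  # Jamba v0.1: ~52B total parameters

for dtype, nbytes in [("bfloat16", 2), ("int8", 1)]:
    print(f"{dtype}: ~{weights_gb(JAMBA_TOTAL_PARAMS, nbytes):.0f} GB of weights")
```

This is also why the model is spread across the eight 40 GB A100 GPUs of an ml.p4d.24xlarge with `device_map="auto"`. With the optional 8-bit configuration, the Mamba blocks listed in `llm_int8_skip_modules` stay in bfloat16, so the real footprint lands above the int8 estimate.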

In [1]:
import os
# Set the path for Hugging Face cache
hf_home_path = os.path.expanduser('~/SageMaker/.cache/huggingface')
os.environ['HF_HOME'] = hf_home_path
print(f"HF_HOME set to: {hf_home_path}")

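# Set the path for the Transformers cache (TRANSFORMERS_CACHE is deprecated in favor of HF_HOME in recent transformers releases, but 4.42.3 still honors it)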
tf_home_path = os.path.expanduser('~/SageMaker/.cache/transformers')
os.environ['TRANSFORMERS_CACHE'] = tf_home_path
print(f"TRANSFORMERS_CACHE set to: {tf_home_path}")

# Optional: Set the datasets cache
datasets_cache_path = os.path.expanduser('~/SageMaker/.cache/huggingface_datasets')
os.environ['HF_DATASETS_CACHE'] = datasets_cache_path
print(f"HF_DATASETS_CACHE set to: {datasets_cache_path}")


HF_HOME set to: /home/ec2-user/SageMaker/.cache/huggingface
TRANSFORMERS_CACHE set to: /home/ec2-user/SageMaker/.cache/transformers
HF_DATASETS_CACHE set to: /home/ec2-user/SageMaker/.cache/huggingface_datasets
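To see which directory model downloads will actually use, the sketch below models the precedence that, as we understand it, `huggingface_hub` and `transformers` 4.x apply: `HF_HUB_CACHE` (or the legacy `HUGGINGFACE_HUB_CACHE`) wins, otherwise `$HF_HOME/hub`, and `transformers` lets `TRANSFORMERS_CACHE` override that. It is a simplified illustration (older variables such as `PYTORCH_TRANSFORMERS_CACHE` are ignored), not the libraries' code; `resolve_hub_cache` and `resolve_transformers_cache` are hypothetical names.

```python
import os

def resolve_hub_cache(env: dict) -> str:
    """Simplified model of how huggingface_hub picks its download cache directory."""
    default_home = os.path.join(os.path.expanduser("~"), ".cache")
    hf_home = os.path.expanduser(
        env.get("HF_HOME", os.path.join(env.get("XDG_CACHE_HOME", default_home), "huggingface"))
    )
    legacy = env.get("HUGGINGFACE_HUB_CACHE", os.path.join(hf_home, "hub"))
    return env.get("HF_HUB_CACHE", legacy)

def resolve_transformers_cache(env: dict) -> str:
    """transformers 4.x: the deprecated TRANSFORMERS_CACHE overrides the hub cache."""
    return env.get("TRANSFORMERS_CACHE", resolve_hub_cache(env))

print("hub cache:         ", resolve_hub_cache(dict(os.environ)))
print("transformers cache:", resolve_transformers_cache(dict(os.environ)))
```

With the variables set in the previous cell, `from_pretrained` downloads go to `~/SageMaker/.cache/transformers`. Both locations sit on the EBS volume, which is what matters for the ~100 GB download.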


## Perform fine-tuning
Next, we fine-tune the model with LoRA (low-rank adaptation) using TRL's `SFTTrainer`. Training for 3 epochs took about 12 minutes on an ml.p4d.24xlarge notebook instance. To load the model in 8-bit instead of bfloat16, uncomment the `BitsAndBytesConfig` lines and the `quantization_config` argument.
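LoRA freezes each targeted weight matrix W and trains only a low-rank update B·A, scaled by `lora_alpha / r`. The pure-Python sketch below illustrates that forward pass on toy matrices; it is not PEFT's implementation. PEFT initializes B to zero, so training starts from the unmodified model, but it uses a different initialization for A than the small Gaussian used here. `matvec`, `lora_forward`, and `lora_param_count` are hypothetical helper names.

```python
import random

def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]

def lora_forward(W, A, B, x, alpha, r):
    """h = W x + (alpha / r) * B (A x); W is frozen, only A and B are trained."""
    base = matvec(W, x)               # frozen pretrained projection, d_out values
    update = matvec(B, matvec(A, x))  # low-rank path: d_in -> r -> d_out
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]

def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_out x d_in weight matrix."""
    return r * (d_in + d_out)

# Toy sizes for illustration only.
d_in, d_out, r, alpha = 6, 4, 2, 2
rng = random.Random(0)
W = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[rng.gauss(0.0, 0.02) for _ in range(d_in)] for _ in range(r)]  # simplified init for A
B = [[0.0] * r for _ in range(d_out)]  # B starts at zero: the adapter is a no-op before training
x = [rng.uniform(-1, 1) for _ in range(d_in)]

print("unchanged at init:", lora_forward(W, A, B, x, alpha, r) == matvec(W, x))
# A hypothetical 4096 x 4096 projection with r=8, the rank used in this notebook:
print("LoRA params:", lora_param_count(4096, 4096, 8), "vs full:", 4096 * 4096)
```

This prints `True` and `65536 vs 16777216`. With `r=8` as in the `LoraConfig` below and PEFT's default `lora_alpha` of 8, the scale factor works out to 1.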

In [None]:
import time

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer

# Not used directly: importing the mamba-ssm kernels here fails fast if mamba-ssm or
# causal-conv1d did not build correctly, instead of failing later inside the model.
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update

# Optional 8-bit quantization: uncomment these lines and the quantization_config
# argument below. The Mamba blocks are left unquantized.
#quantization_config = BitsAndBytesConfig(load_in_8bit=True,
#                                         llm_int8_skip_modules=["mamba"])

# Load Jamba in bfloat16 with FlashAttention-2, distributed across the available GPUs.
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             #quantization_config=quantization_config,
                                             device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config = LoraConfig(
    r=8,
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)

# Time the training run
start_time = time.time()
trainer.train()
end_time = time.time()
print(f"Training took {(end_time - start_time) / 60:.1f} minutes")
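To relate the measured time to the work done, you can estimate how many optimizer steps the run takes. The sketch below assumes one training example per dataset row (SFTTrainer without packing), no dropped final batch, and an effective batch size equal to `per_device_train_batch_size`. That last point is an assumption based on our understanding that the Trainer runs a `device_map="auto"` model as a single model-parallel replica. `total_optimizer_steps` is a hypothetical helper, and the dataset size below is a placeholder, not the real size of `Abirate/english_quotes`.

```python
import math

def total_optimizer_steps(num_examples, batch_size, epochs, grad_accum=1):
    """Optimizer steps for a run; with grad_accum=1 (the default) this is just the batch count."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return steps_per_epoch * epochs

# Placeholder dataset size; in the notebook, use len(dataset) instead.
num_examples = 2500
steps = total_optimizer_steps(num_examples, batch_size=4, epochs=3)
print(f"{steps} optimizer steps, about {steps // 10} logged loss values with logging_steps=10")
```

Dividing `end_time - start_time` by this step count gives a rough per-step time, which is a more portable number than total minutes when you change the dataset or the number of epochs.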
