# Fine-tune Mixtral 8x7B with PyTorch FSDP and QLoRA on Amazon SageMaker

This notebook explains how to fine-tune the Mixtral 8x7B model using PyTorch FSDP and QLoRA with the help of Hugging Face [TRL](https://huggingface.co/docs/trl/index), [Transformers](https://huggingface.co/docs/transformers/index), [PEFT](https://huggingface.co/docs/peft/index), and [Datasets](https://huggingface.co/docs/datasets/index) on Amazon SageMaker.

**This notebook is validated and optimized to run on `ml.p4d.24xlarge` instances.**

**FSDP + QLoRA background**

Hugging Face supports combining QLoRA with PyTorch FSDP (Fully Sharded Data Parallel), which lets you fine-tune Llama- and Mistral-like architectures. The core logic lives in Hugging Face PEFT; read more about it in the [PEFT documentation](https://huggingface.co/docs/peft/v0.10.0/en/accelerate/fsdp).

* [PyTorch FSDP](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) is a data/model parallelism technique that shards model parameters, gradients, and optimizer states across GPUs, reducing per-GPU memory requirements and enabling more efficient training of larger models.
* QLoRA is a fine-tuning method that quantizes the frozen base model (typically to 4-bit) and trains small Low-Rank Adapters (LoRA) on top of it, reducing computational requirements and memory footprint.
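
To see why both techniques help, here is a back-of-envelope estimate of the memory needed just to hold the frozen base weights. It assumes roughly 46.7B total parameters for Mixtral 8x7B and 8 GPUs per `ml.p4d.24xlarge`, and it ignores activations, LoRA weights, optimizer states, and quantization overhead, so treat the numbers as rough lower bounds rather than measurements.

```python
# Back-of-envelope estimate of the memory needed to store the frozen base weights.
# Assumptions: ~46.7B total parameters for Mixtral 8x7B, 8 GPUs per node.
# Activations, LoRA weights, optimizer states and quantization constants are ignored.
NUM_PARAMS = 46.7e9
NUM_GPUS = 8


def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Memory needed to store the weights, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


for label, bits in [("bf16", 16), ("4-bit (QLoRA)", 4)]:
    total = weight_memory_gb(NUM_PARAMS, bits)
    # FSDP full sharding splits the weights evenly across all GPUs
    per_gpu = total / NUM_GPUS
    print(f"{label:>14}: {total:6.1f} GB total, {per_gpu:5.1f} GB per GPU with FSDP")
```

Under these assumptions, unsharded bf16 weights alone (about 93 GB) would not fit on a single 40GB A100, while 4-bit weights sharded across 8 GPUs need only about 3 GB per GPU, leaving room for activations and adapter training.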

This notebook walks you through how to fine-tune open LLMs from Hugging Face using Amazon SageMaker.

## 1. Setup Development Environment

Our first step is to install the Hugging Face libraries we need on the client to prepare our dataset and start our training and evaluation jobs.

In [None]:
!pip install transformers "datasets[s3]==2.18.0" "sagemaker>=2.190.0" "py7zr" "peft==0.12.0" matplotlib --upgrade --quiet

If you are going to use SageMaker in a local environment, you need access to an IAM role with the required permissions for SageMaker. You can find more about it [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html).



In [None]:
import sagemaker
import boto3
from datasets import load_dataset
from sagemaker.pytorch import PyTorch
import matplotlib.pyplot as plt
from sagemaker.s3 import S3Downloader
import os

sess = sagemaker.Session()
sagemaker_session_bucket=None

if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

In [None]:
# HF dataset that we will be working with
dataset_name = "GEM/viggo"

# Provide your own tokens here; do not commit real tokens to a shared notebook.
# hf_token is required to access the Mixtral model; wandb_token is passed to the training script through args.yaml.
os.environ['hf_token'] = "<YOUR_HF_TOKEN>"
os.environ['wandb_token'] = "<YOUR_WANDB_TOKEN>"
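
Hardcoding tokens in a notebook makes it easy to leak them when the notebook is shared or committed. As an alternative, the hypothetical helper `set_token` below (not part of any library) reuses an existing environment variable or prompts for the value with Python's `getpass`. It keeps the `hf_token` and `wandb_token` variable names used in this notebook, so the `%%bash` cell that writes `args.yaml` can still read them from the environment.

```python
import os
from getpass import getpass


def set_token(env_var: str, prompt: str) -> str:
    """Make a secret available as an environment variable without hardcoding it.

    Reuses the value if env_var is already set (for example, exported before
    starting Jupyter); otherwise prompts for it without echoing the input.
    """
    value = os.environ.get(env_var) or getpass(prompt)
    os.environ[env_var] = value
    return value


# Example usage (uncomment to run interactively):
# set_token("hf_token", "Hugging Face token: ")
# set_token("wandb_token", "Weights & Biases token: ")
```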

## 2. Create and prepare the dataset

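In this example, we use the GEM/viggo dataset from Hugging Face, a data-to-text generation dataset in the video game domain. The dataset is clean and well organized, with about 5,000 training examples, and its responses are more conversational than information seeking. Each example pairs a natural-language sentence (`target`) with a structured meaning representation (`meaning_representation`): a single function such as `inform` or `request` with attributes and attribute values. In this notebook we use it in the reverse direction: given a sentence, the model learns to generate its meaning representation. This kind of structured extraction is useful beyond video games; for example, an e-commerce site like Amazon could fine-tune a model on a similarly formatted dataset to extract product attributes and gauge interest in products from customer reviews. This makes the dataset a good candidate for fine-tuning LLMs. To learn more about the ViGGO dataset, check out this research paper.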
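
The meaning representations follow a simple `function(attribute[value], ...)` pattern. A small parser like the sketch below can turn a generated meaning representation into a Python structure, for example to compare model outputs with the references field by field. The example string is illustrative and written in the ViGGO style rather than copied from the dataset, and the regular expressions assume values never contain square brackets.

```python
import re

# Matches "function(...)" and captures the function name and the attribute list
_MR_PATTERN = re.compile(r"\s*(\w+)\((.*)\)\s*", re.DOTALL)
# Matches one "attribute[value]" pair; values may contain commas and spaces
_ATTR_PATTERN = re.compile(r"(\w+)\[([^\]]*)\]")


def parse_meaning_representation(mr: str) -> tuple[str, dict[str, str]]:
    """Split a meaning representation into its function name and attributes."""
    match = _MR_PATTERN.fullmatch(mr)
    if match is None:
        raise ValueError(f"Not a meaning representation: {mr!r}")
    function, body = match.groups()
    attributes = dict(_ATTR_PATTERN.findall(body))
    return function, attributes


# Illustrative example in the ViGGO style (not taken from the dataset)
example = "inform(name[Portal 2], release_year[2011], genres[puzzle, platformer])"
print(parse_meaning_representation(example))
```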

In [None]:
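def generate_and_tokenize_prompt(data_point):
    # Builds the full training text: instructions, the target sentence, and the expected meaning representation.
    # Despite the name, no tokenization happens here; the training script tokenizes the "prompt" column later.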
    full_prompt = f"""
    Given a target sentence, construct the underlying meaning representation of the input sentence as a single function with attributes and attribute values.
    This function should describe the target string accurately and the function must be one of the following:
    ['inform', 'request', 'give_opinion', 'confirm', 'verify_attribute', 'suggest', 'request_explanation', 'recommend', 'request_attribute']

    The attributes must be one of the following:
    ['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating', 'genres', 'player_perspective', 'has_multiplayer', 'platforms', 'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier']

    ### Target sentence:
    {data_point["target"]}

    ### Meaning representation:
    {data_point["meaning_representation"]}
    """
    return {"prompt": full_prompt.strip()}

# Load dataset from the hub
train_set = load_dataset(dataset_name, split="train")
test_set = load_dataset(dataset_name, split="test")

# Drop the original columns so that only the new "prompt" column remains
columns_to_remove = list(train_set.features)

train_dataset = train_set.map(
    generate_and_tokenize_prompt,
    remove_columns=columns_to_remove,
    batched=False
)

test_dataset = test_set.map(
    generate_and_tokenize_prompt,
    remove_columns=columns_to_remove,
    batched=False
)

In [None]:
# Review dataset
train_dataset, train_dataset[0]
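
Each `prompt` contains the expected answer after the `### Meaning representation:` header, which is what supervised fine-tuning needs. For evaluation you usually want the opposite split: everything up to the header as model input and the remainder as the reference answer. A minimal sketch, assuming the header appears exactly once per prompt; `split_prompt` is a hypothetical helper and is not used by the training script.

```python
RESPONSE_HEADER = "### Meaning representation:"


def split_prompt(full_prompt: str, header: str = RESPONSE_HEADER) -> tuple[str, str]:
    """Return (model_input, reference_answer) for a training prompt."""
    model_input, found, answer = full_prompt.partition(header)
    if not found:
        raise ValueError("Response header not found in prompt")
    # Keep the header in the input so the model knows what to generate next
    return model_input + header, answer.strip()


# Usage with the processed dataset:
# model_input, reference = split_prompt(test_dataset[0]["prompt"])
```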

After processing the datasets, we use the [FileSystem integration](https://huggingface.co/docs/datasets/filesystems) to upload them to S3. We use `sess.default_bucket()`; adjust this if you want to store the datasets in a different S3 bucket. We pass these S3 paths to the training job later.

In [None]:
# Save train_dataset and test_dataset to S3 using our SageMaker session's default bucket
input_path = f's3://{sess.default_bucket()}/datasets/mixtral'
train_dataset_s3_path = f"{input_path}/train/dataset.json"
test_dataset_s3_path = f"{input_path}/test/dataset.json"

# Save the full train and test splits to S3
train_dataset.to_json(train_dataset_s3_path, orient="records")
test_dataset.to_json(test_dataset_s3_path, orient="records")

print("Training data uploaded to:")
print(train_dataset_s3_path)
print(test_dataset_s3_path)
print(f"https://s3.console.aws.amazon.com/s3/buckets/{sess.default_bucket()}/?region={sess.boto_region_name}&prefix={input_path.split('/', 3)[-1]}/")
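
With `orient="records"`, `Dataset.to_json` writes JSON Lines by default: one JSON object per line, each holding a single `prompt` field here. The stdlib-only sketch below writes and reads that layout in a local temporary directory, which can help when checking what the training script will load; it uses placeholder records rather than the real dataset.

```python
import json
import tempfile
from pathlib import Path


def write_jsonl(records: list[dict], path: Path) -> None:
    """Write one JSON object per line (JSON Lines)."""
    with path.open("w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(path: Path) -> list[dict]:
    """Read a JSON Lines file back into a list of dicts, skipping blank lines."""
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Placeholder records with the same single "prompt" column as our datasets
records = [{"prompt": "example prompt 1"}, {"prompt": "example prompt 2"}]
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "dataset.json"
    write_jsonl(records, path)
    print(path.read_text(encoding="utf-8"))
    assert read_jsonl(path) == records
```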

### Measure input length

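When passing a dataset to the LLM for fine-tuning, it's important that the inputs fit within the model's maximum sequence length and can be brought to a uniform length. To check this, we first visualize the distribution of input lengths (or, alternatively, directly find the maximum length). Based on these results, we choose the maximum input length (`max_seq_length` in the training config) and use padding to bring shorter inputs to the same length. Note that the helper below counts whitespace-separated words, which is only a rough proxy for the number of tokens; the tokenizer typically produces more tokens than words.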

In [None]:
def count_words(text):
    return len(text.split())

In [None]:
def plot_data_lengths(train_dataset, test_dataset):
    # Word counts of the prompts in both splits
    lengths1 = [count_words(x["prompt"]) for x in train_dataset]
    lengths2 = [count_words(x["prompt"]) for x in test_dataset]
    lengths = lengths1 + lengths2
    
    plt.figure(figsize=(10,6))
    plt.hist(lengths, bins=20, alpha=0.7, color="blue")
    plt.xlabel("Prompt length (words)")
    plt.ylabel("Frequency")
    plt.title("Distribution of prompt lengths")
    plt.show()

In [None]:
plot_data_lengths(train_dataset, test_dataset)

In [None]:
# Print out the max prompt length in words (a rough proxy for tokens)
lengths1 = [count_words(x["prompt"]) for x in train_dataset]
lengths2 = [count_words(x["prompt"]) for x in test_dataset]
lengths = lengths1 + lengths2

max(lengths)
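
Because the counts above are words rather than tokens, they understate the sequence length the model actually sees. One way to pick `max_seq_length` is to convert a high percentile of the word counts into an estimated token count and round up. The sketch below is a heuristic: the `tokens_per_word` ratio of 1.3 is an assumption for English text, not a property of the Mixtral tokenizer, so verify the result by tokenizing a sample with the actual tokenizer.

```python
import math


def suggest_max_seq_length(
    word_lengths: list[int],
    tokens_per_word: float = 1.3,  # assumed average; measure it with the real tokenizer
    coverage: float = 0.99,        # fraction of examples that should fit without truncation
    multiple_of: int = 64,         # round up to a hardware-friendly size
) -> int:
    """Estimate a max_seq_length that covers `coverage` of the examples."""
    if not word_lengths:
        raise ValueError("word_lengths must not be empty")
    ordered = sorted(word_lengths)
    # Index of the smallest length at or above the requested percentile
    index = min(len(ordered) - 1, max(0, math.ceil(coverage * len(ordered)) - 1))
    estimated_tokens = ordered[index] * tokens_per_word
    return math.ceil(estimated_tokens / multiple_of) * multiple_of


# Usage with the `lengths` list computed above:
# print(suggest_max_seq_length(lengths))
```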

## 3. Fine-tune Mixtral 8x7B on Amazon SageMaker

We are now ready to fine-tune our model. We use the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) from `trl`, which makes it straightforward to run supervised fine-tuning of open LLMs. The `SFTTrainer` is a subclass of the `Trainer` from `transformers`. We prepared a script, [launch_fsdp_qlora.py](./scripts/launch_fsdp_qlora.py), which loads the dataset from disk, prepares the model and tokenizer, and starts the training with the `SFTTrainer`.

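For configuration, we use `TrlParser`, which allows us to provide hyperparameters in a YAML file. This YAML file is uploaded and provided to Amazon SageMaker in the same way as our datasets. Below is the config file for fine-tuning Mixtral 8x7B on an `ml.p4d.24xlarge` instance (8 x NVIDIA A100 40GB GPUs). We save the config file as `args.yaml` and upload it to S3. Note that the heredoc expands `${hf_token}` and `${wandb_token}` from the environment variables set earlier, so the uploaded `args.yaml` contains your tokens in plain text; restrict access to the bucket accordingly.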


In [None]:
%%bash

cat > ./args.yaml <<EOF
hf_token: "${hf_token}" # HF token used to log in to Hugging Face and access the Mixtral 8x7B model
wandb_token: "${wandb_token}"
model_id: "mistralai/Mixtral-8x7B-v0.1"       # Hugging Face model id
max_seq_length: 256  #512 # 2048              # max sequence length for model and packing of the dataset
# sagemaker specific parameters
train_dataset_path: "/opt/ml/input/data/train/" # path to where SageMaker saves train dataset
test_dataset_path: "/opt/ml/input/data/test/"   # path to where SageMaker saves test dataset

output_dir: "/opt/ml/model/mixtral/adapter"              # path to where SageMaker will upload the model 
# training parameters
report_to: "tensorboard"               # report metrics to tensorboard
learning_rate: 0.0003                  # learning rate 3e-4
lr_scheduler_type: "constant"          # learning rate scheduler
num_train_epochs: 1                  # number of training epochs
per_device_train_batch_size: 100       # batch size per device during training
per_device_eval_batch_size: 8         # batch size for evaluation
gradient_accumulation_steps: 1        # number of forward/backward steps to accumulate before an optimizer update
optim: adamw_torch                     # use torch adamw optimizer
logging_steps: 10                      # log every 10 steps
save_strategy: epoch                   # save checkpoint every epoch
evaluation_strategy: epoch             # evaluate every epoch
max_grad_norm: 0.3                     # max gradient norm
warmup_ratio: 0.03                     # warmup ratio
bf16: true                             # use bfloat16 precision
tf32: true                             # use tf32 precision
gradient_checkpointing: true           # use gradient checkpointing to save memory

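weight_decay: 0.01
warmup_steps: 100                      # overrides warmup_ratio; neither applies with the "constant" lr_scheduler_type
# FSDP settings: https://huggingface.co/docs/transformers/main/en/fsdp
fsdp: "full_shard auto_wrap" # add "offload" (e.g. "full_shard auto_wrap offload") to offload parameters to CPU if GPU memory is insufficient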
fsdp_config:
  backward_prefetch: "backward_pre"
  forward_prefetch: "false"
  use_orig_params: "false"
EOF
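
A quick sanity check on these hyperparameters is to compute the global batch size and the number of optimizer steps. The sketch assumes one process per GPU (8 on `ml.p4d.24xlarge`) and no packing of examples, and it approximates the `Trainer`'s step count with ceiling division; the dataset size of 5,000 is a round placeholder, so substitute `len(train_dataset)`.

```python
import math


def optimizer_steps(
    num_examples: int,
    per_device_batch_size: int,
    gradient_accumulation_steps: int,
    num_gpus: int,
    num_epochs: int = 1,
) -> tuple[int, int]:
    """Return (global batch size, approximate total optimizer steps)."""
    global_batch_size = per_device_batch_size * gradient_accumulation_steps * num_gpus
    steps_per_epoch = math.ceil(num_examples / global_batch_size)
    return global_batch_size, steps_per_epoch * num_epochs


# Values from args.yaml; 5,000 is a placeholder for len(train_dataset)
global_batch, total_steps = optimizer_steps(
    num_examples=5_000, per_device_batch_size=100, gradient_accumulation_steps=1, num_gpus=8
)
print(f"global batch size: {global_batch}, optimizer steps: {total_steps}")
```

If the total is smaller than `logging_steps`, the run logs few or no intermediate training losses; in that case consider lowering `logging_steps` or the per-device batch size.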

Let's upload the config file to S3.

In [None]:
from sagemaker.s3 import S3Uploader

# upload the training config yaml file to s3
model_yaml = "args.yaml"
train_config_s3_path = S3Uploader.upload(local_path=model_yaml, desired_s3_uri=f"{input_path}/config")

print("Training config uploaded to:")
print(train_config_s3_path)

### Fine-tune the LoRA adapter

The estimator below trains the model with QLoRA and saves the LoRA adapter to S3.

In [None]:
# Create SageMaker PyTorch Estimator

# define Training Job Name 
job_name = f'mixtral-8-7b-finetune'

pytorch_estimator = PyTorch(
    entry_point= 'launch_fsdp_qlora.py',
    source_dir="./scripts",
    job_name=job_name,
    base_job_name=job_name,
    max_run=5800,
    role=role,
    framework_version="2.2.0",
    py_version="py310",
    instance_count=1,
    instance_type="ml.p4d.24xlarge",
    sagemaker_session=sess,
    disable_output_compression=True,
    keep_alive_period_in_seconds=1800,
    distribution={"torch_distributed": {"enabled": True}},
    hyperparameters={
        "config": "/opt/ml/input/data/config/args.yaml" # path to TRL config which was uploaded to s3
    }
)

_Note: When using QLoRA, we only train adapters, not the full model. The [launch_fsdp_qlora.py](./scripts/launch_fsdp_qlora.py) script saves the `adapter` at the end of training to the Amazon SageMaker S3 bucket (sagemaker-<region name>-<account_id>)._

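We can now start our training job with the `.fit()` method, passing the S3 paths of our datasets and config file as input channels. SageMaker downloads each channel to `/opt/ml/input/data/<channel_name>` inside the training container, which is why `args.yaml` refers to `/opt/ml/input/data/train/` and the `config` hyperparameter points to `/opt/ml/input/data/config/args.yaml`.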

In [None]:
# define a data input dictionary with our uploaded S3 URIs
data = {
  'train': train_dataset_s3_path,
  'test': test_dataset_s3_path,
  'config': train_config_s3_path
  }

# Check input channels configured 
data
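
Since every channel lands under `/opt/ml/input/data/<channel_name>`, a typo in a channel name only shows up once the job is running. The hypothetical helper below (not a SageMaker API) checks that each container path referenced by the config falls under a channel defined in `data`, so you can catch such mistakes before paying for an instance.

```python
from pathlib import PurePosixPath

SAGEMAKER_INPUT_ROOT = PurePosixPath("/opt/ml/input/data")


def paths_without_channel(container_paths: list[str], channels: set[str]) -> list[str]:
    """Return the container paths that no input channel will populate."""
    missing = []
    for path in container_paths:
        try:
            relative = PurePosixPath(path).relative_to(SAGEMAKER_INPUT_ROOT)
        except ValueError:
            missing.append(path)  # not under /opt/ml/input/data at all
            continue
        if not relative.parts or relative.parts[0] not in channels:
            missing.append(path)
    return missing


# Paths used in args.yaml and in the estimator's "config" hyperparameter
config_paths = [
    "/opt/ml/input/data/train/",
    "/opt/ml/input/data/test/",
    "/opt/ml/input/data/config/args.yaml",
]
# Usage with the `data` dictionary above:
# print(paths_without_channel(config_paths, set(data)))
```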

In [None]:
# starting the train job with our uploaded datasets as input
pytorch_estimator.fit(data, wait=True)

In [None]:
# Find the job name of the last run (or browse the SageMaker console)
latest_run_job_name= pytorch_estimator.latest_training_job.job_name
latest_run_job_name

## 4. Merge the base model with the fine-tuned adapter in fp16

The next estimator runs `merge_model_adapter.py`, which performs the following steps:
1. Load the base model in fp16 precision.
2. Convert the adapter saved in the previous step from fp32 to fp16.
3. Merge the adapter into the base model.
4. Run inference on both the base model and the merged model for comparison.
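
Merging folds the adapter into the base weights. In PEFT's standard LoRA formulation (without rsLoRA), a layer with base weight W (shape out x in) and adapter matrices A (r x in) and B (out x r) computes W x + (lora_alpha / r) · B A x, so the merged weight is W + (lora_alpha / r) · B A. The pure-Python sketch below demonstrates this on tiny matrices; it illustrates the math only and is not the implementation inside `merge_model_adapter.py`.

```python
def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    """Naive matrix product of two nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def merge_lora(w, a, b, lora_alpha: float, r: int):
    """Return W + (lora_alpha / r) * B @ A, the merged LoRA weight."""
    scaling = lora_alpha / r
    delta = matmul(b, a)  # (out x r) @ (r x in) -> (out x in)
    return [
        [w_ij + scaling * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]


# Toy example: 2x2 base weight with a rank-1 adapter
lora_alpha, r = 2, 1
scaling = lora_alpha / r
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]    # r x in
B = [[3.0], [4.0]]  # out x r
merged = merge_lora(W, A, B, lora_alpha=lora_alpha, r=r)

# The merged layer gives the same output as the base path plus the adapter path
x = [[1.0], [1.0]]
base_plus_adapter = [
    [wx[0] + scaling * bax[0]] for wx, bax in zip(matmul(W, x), matmul(B, matmul(A, x)))
]
print(merged)
assert matmul(merged, x) == base_plus_adapter
```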

In [None]:
# Find the S3 path of the last training job that ran successfully (also available in the SageMaker console)

# *** Get a job name from the AWS console for the last training run or from the above cell
job_name = latest_run_job_name

def get_s3_path_from_job_name(job_name):
    # Create a Boto3 SageMaker client
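    # Create a Boto3 SageMaker client in the same region as our SageMaker session
    sagemaker_client = boto3.client('sagemaker', region_name=sess.boto_region_name)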
    
    # Describe the training job
    response = sagemaker_client.describe_training_job(TrainingJobName=job_name)
    
    # Extract the model artifacts S3 path
    model_artifacts_s3_path = response['ModelArtifacts']['S3ModelArtifacts']
    
    # Extract the output path (this is the general output location)
    output_path = response['OutputDataConfig']['S3OutputPath']
    
    return model_artifacts_s3_path, output_path


model_artifacts, output_path = get_s3_path_from_job_name(job_name)


print(f"Model artifacts S3 path: {model_artifacts}")

In [None]:
# Strip any trailing slash from the artifacts path to avoid "//" in the S3 prefix
adapter_dir_path = f"{model_artifacts.rstrip('/')}/mixtral/adapter/"

print(f'\nAdapter S3 Dir path:{adapter_dir_path} \n')

!aws s3 ls {adapter_dir_path}
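
If the `aws` CLI is not available in your environment, you can list the adapter files with boto3 instead. The sketch below splits an S3 URI into bucket and prefix and pages through `list_objects_v2`; `split_s3_uri` and `list_s3_keys` are illustrative helpers, not part of the SageMaker SDK. For a PEFT adapter you would typically expect files such as `adapter_config.json` alongside the adapter weights.

```python
from urllib.parse import urlparse


def split_s3_uri(uri: str) -> tuple[str, str]:
    """Split 's3://bucket/some/prefix/' into ('bucket', 'some/prefix/')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


def list_s3_keys(uri: str) -> list[str]:
    """List all object keys under an S3 prefix, following pagination."""
    import boto3  # imported here so split_s3_uri works without boto3 installed

    bucket, prefix = split_s3_uri(uri)
    paginator = boto3.client("s3").get_paginator("list_objects_v2")
    keys = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        keys.extend(obj["Key"] for obj in page.get("Contents", []))
    return keys


# Usage: print(list_s3_keys(adapter_dir_path))
```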

In [None]:
# Create SageMaker PyTorch Estimator

# Define Training Job Name 
job_name = 'mixtral-8-7b-merge-adapter'

pytorch_estimator_adapter = PyTorch(
    entry_point= 'merge_model_adapter.py',
    source_dir="./scripts",
    job_name=job_name,
    base_job_name=job_name,
    max_run=5800,
    role=role,
    framework_version="2.2.0",
    py_version="py310",
    instance_count=1,
    instance_type="ml.p4d.24xlarge",
    sagemaker_session=sess,
    disable_output_compression=True,
    keep_alive_period_in_seconds=1800,
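    hyperparameters={
        "model_id": "mistralai/Mixtral-8x7B-v0.1",  # Hugging Face model id
        # Read the token set earlier instead of hardcoding it; note that hyperparameters are visible in the training job description
        "hf_token": os.environ["hf_token"],
        "dataset_name": dataset_name
    }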
)

In [None]:
# define a data input dictionary with our uploaded S3 URIs
data = {
  'adapter': adapter_dir_path,
  'testdata': test_dataset_s3_path
  }

data

In [None]:
# start the merge job with the adapter and data channels as input
pytorch_estimator_adapter.fit(data, wait=True)