# Fine-tune Llama 3.2 Vision Language Models on SageMaker JumpStart

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

---
This demo notebook illustrates how to use the Amazon SageMaker Python SDK to fine-tune the Llama 3.2 vision language models, including the 11B and 90B base and instruct versions, on your custom dataset for a visual question answering task.

---

### Model License information
---
To perform inference on these models, you need to pass custom_attributes='accept_eula=true' as part of the request header. This means you have read and accept the end-user license agreement (EULA) of the model. The EULA can be found in the model card description or at https://ai.meta.com/resources/models-and-libraries/llama-downloads/. By default, this notebook sets custom_attributes='accept_eula=false', so all inference requests will fail until you explicitly change this custom attribute.

Note: The `custom_attributes` used to pass the EULA are key/value pairs. The key and value are separated by '=' and pairs are separated by ';'. If the user passes the same key more than once, the last value is kept and passed to the script handler (in this case, where it is used for conditional logic). For example, if 'accept_eula=false; accept_eula=true' is passed to the server, then 'accept_eula=true' is kept and passed to the script handler.
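
The "last value wins" rule can be illustrated with a few lines of Python. This is only a sketch of the behavior described above, not the server's actual parser; the function name `parse_custom_attributes` is hypothetical.

```python
def parse_custom_attributes(raw: str) -> dict[str, str]:
    """Parse 'k1=v1;k2=v2' into a dict; a repeated key keeps its last value."""
    parsed: dict[str, str] = {}
    for pair in raw.split(";"):
        pair = pair.strip()
        if not pair:
            continue  # tolerate empty segments such as a trailing ';'
        key, sep, value = pair.partition("=")
        if not sep:
            raise ValueError(f"Expected key=value, got {pair!r}")
        # Later assignments overwrite earlier ones, so the last value is kept.
        parsed[key.strip()] = value.strip()
    return parsed


print(parse_custom_attributes("accept_eula=false; accept_eula=true"))
# prints {'accept_eula': 'true'}
```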

---

### Set up

---
We begin by installing and upgrading necessary packages. Restart the kernel after executing the cell below for the first time.

---

In [None]:
!pip install --upgrade sagemaker datasets


Take a look at this [link](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html#built-in-algorithms-with-pre-trained-model-table) for all available models.

If you wish to fine-tune Llama 3.2 Vision Base/Instruct models, you can set model_id to: 

| Model | Model ID | Supported instance types for fine-tuning |
| - | - | - |
| Llama 3.2 11B | meta-vlm-llama-3-2-11b-vision | ml.g5.48xlarge, ml.p4de.24xlarge, ml.p4d.24xlarge |
| Llama 3.2 11B Instruct | meta-vlm-llama-3-2-11b-vision-instruct | ml.g5.48xlarge, ml.p4de.24xlarge   |
| Llama 3.2 90B | meta-vlm-llama-3-2-90b-vision | ml.g5.48xlarge|
| Llama 3.2 90B Instruct | meta-vlm-llama-3-2-90b-vision-instruct | ml.g5.48xlarge |


In [None]:
model_id, model_version = "meta-vlm-llama-3-2-11b-vision-instruct", "*"

## Dataset preparation for the visual question answering task

---
We currently offer fine-tuning for the visual question answering (VQA) task. The vision language model can be fine-tuned on an image-text dataset, provided that the data is in the expected format. The resulting model can then be deployed for inference. Below are the instructions for how the training data should be formatted for input to the model.

- **Input:** 
  - A train and an optional validation directory. Train and validation directories should contain one directory named `images` hosting all the image data and one JSON lines (`.jsonl`) file named `metadata.jsonl`.  
  - In the `metadata.jsonl` file, each example is a dictionary which contains three keys named `file_name`, `prompt`, and `completion`. The `file_name` defines the path to image data. `prompt` defines the text input prompt and `completion` defines the text completion corresponding to the input prompt.
  Below is an example of the contents in the `metadata.jsonl` file.


```
{"file_name": "images/img_0.jpg", "prompt": "what is the date mentioned in this letter?", "completion": "1/8/93"}
{"file_name": "images/img_1.jpg", "prompt": "what is the contact person name mentioned in letter?", "completion": "P. Carter"}
{"file_name": "images/img_2.jpg", "prompt": "Which part of Virginia is this letter sent from", "completion": "Richmond"}
```

- **Output:** A trained model that can be deployed for inference. 

We provide a subset of [DocVQA](https://www.docvqa.org/datasets) as an example dataset to demonstrate fine-tuning. The full dataset can be downloaded from [here](https://www.docvqa.org/datasets).
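
Before uploading a dataset, it can help to check that a split directory follows the layout described above. The following is a minimal sketch under that assumption; the helper `validate_split` is hypothetical and is not part of the SageMaker SDK.

```python
import json
from pathlib import Path

REQUIRED_KEYS = {"file_name", "prompt", "completion"}


def validate_split(split_dir):
    """Return a list of human-readable problems found in one split directory."""
    split_dir = Path(split_dir)
    metadata_path = split_dir / "metadata.jsonl"
    if not metadata_path.is_file():
        return [f"missing {metadata_path}"]

    problems = []
    with open(metadata_path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            try:
                example = json.loads(line)
            except json.JSONDecodeError as err:
                problems.append(f"line {line_no}: invalid JSON ({err.msg})")
                continue
            if not isinstance(example, dict):
                problems.append(f"line {line_no}: expected a JSON object")
                continue
            missing = REQUIRED_KEYS - set(example)
            if missing:
                problems.append(f"line {line_no}: missing keys {sorted(missing)}")
                continue
            # file_name is resolved relative to the split directory.
            if not (split_dir / example["file_name"]).is_file():
                problems.append(f"line {line_no}: image not found: {example['file_name']}")
    return problems
```

Once the processing cells in this notebook have created `./docvqa/train`, calling `validate_split("./docvqa/train")` should return an empty list if every example has the three keys and a matching image file.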

---

In [None]:
import os
import json
from tqdm import tqdm
from datasets import load_dataset
import base64
from io import BytesIO
from sagemaker.jumpstart.types import JumpStartSerializablePayload

dataset_name = "HuggingFaceM4/DocumentVQA"
data = load_dataset(
    dataset_name, cache_dir="./"
)  # make sure your instance has enough disk space to download the dataset

### Process the dataset into the format described above

In [None]:
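def process_data(data, output_dir, num_ex):
    """Write the first `num_ex` examples of `data` to `output_dir`.

    Creates `metadata.jsonl` and saves each image under `images/`;
    the `images/` directory must already exist.
    """
    local_data_file = f"{output_dir}/metadata.jsonl"
    with open(local_data_file, "w") as f:
        for i in tqdm(range(num_ex)):
            each = data[i]
            q = each["question"]
            each_img = each["image"]
            a = each["answers"][0]  # keep only the first reference answer

            example = {"file_name": f"images/img_{i}.jpg", "prompt": q, "completion": a}
            json.dump(example, f)
            f.write("\n")

            each_img.save(f"{output_dir}/images/img_{i}.jpg")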

In [None]:
for split, num in [("train", 1000), ("validation", 20)]:
    os.makedirs(f"docvqa/{split}", exist_ok=True)
    os.makedirs(f"docvqa/{split}/images", exist_ok=True)
    process_data(data=data[split], output_dir=f"./docvqa/{split}", num_ex=num)

### Upload the dataset to S3

Because the dataset contains images, the upload can take a while, depending on how many examples you processed.
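
To get a rough idea of how much data will be uploaded, you can total the file sizes in the local directory first. This is a small sketch using only the standard library; the helper name `directory_size_bytes` is hypothetical.

```python
import os


def directory_size_bytes(root):
    """Return (file_count, total_bytes) for all files under root."""
    total = 0
    file_count = 0
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            total += os.path.getsize(os.path.join(dirpath, name))
            file_count += 1
    return file_count, total


# Only report a size if the directory created earlier in this notebook exists.
if os.path.isdir("./docvqa/train"):
    n_files, n_bytes = directory_size_bytes("./docvqa/train")
    print(f"{n_files} files, {n_bytes / 1e6:.1f} MB to upload")
```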

In [None]:
from sagemaker.s3 import S3Uploader
import sagemaker

output_bucket = sagemaker.Session().default_bucket()

# Upload only the train split; the validation split is read locally for evaluation later.
local_data_dir = "./docvqa/train/"
train_data_location = f"s3://{output_bucket}/docvqa-1000-20"
S3Uploader.upload(local_data_dir, train_data_location)
print(f"Training data: {train_data_location}")

## Train the model
---
Next, we fine-tune the Llama 3.2 Vision 11B model on a subset of the [DocVQA dataset](https://www.docvqa.org/datasets). For the fine-tuning method, we currently support parameter-efficient fine-tuning with [LoRA](https://arxiv.org/abs/2106.09685).
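
To see why LoRA is called parameter-efficient, consider a single weight matrix of shape `d_out x d_in`. LoRA freezes it and trains two low-rank factors of shapes `d_out x rank` and `rank x d_in` instead. The sketch below compares the parameter counts; the sizes and the rank are illustrative assumptions only, not the actual model dimensions or the JumpStart defaults.

```python
def lora_trainable_params(d_out, d_in, rank):
    """Parameters in the two low-rank factors B (d_out x rank) and A (rank x d_in)."""
    return rank * (d_out + d_in)


def full_params(d_out, d_in):
    """Parameters in the full weight matrix that LoRA leaves frozen."""
    return d_out * d_in


d_out, d_in, rank = 4096, 4096, 8  # illustrative sizes only
lora = lora_trainable_params(d_out, d_in, rank)
full = full_params(d_out, d_in)
print(f"LoRA: {lora:,} trainable vs. {full:,} full ({lora / full:.2%})")
```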

---

### Retrieve and modify the hyperparameters

In [None]:
from sagemaker import hyperparameters

my_hyperparameters = hyperparameters.retrieve_default(
    model_id=model_id, model_version=model_version
)
print(my_hyperparameters)

In [None]:
my_hyperparameters["epoch"] = "1"

In [None]:
hyperparameters.validate(
    model_id=model_id, model_version=model_version, hyperparameters=my_hyperparameters
)

In [None]:
from sagemaker.jumpstart.estimator import JumpStartEstimator


estimator = JumpStartEstimator(
    model_id=model_id,
    model_version=model_version,
    environment={"accept_eula": "true"},  # "true" means you have read and accept the model EULA
    disable_output_compression=True,
    instance_type="ml.p4de.24xlarge",
    hyperparameters=my_hyperparameters,
)
estimator.fit({"training": train_data_location})

Studio kernel dying issue: If your Studio kernel dies and you lose the reference to the estimator object, see the section [2. Studio Kernel Dead/Creating JumpStart Model from the training Job](#2.-Studio-Kernel-Dead/Creating-JumpStart-Model-from-the-training-Job) for how to deploy an endpoint using the training job name and the model ID.


### Deploy the fine-tuned model
---
Next, we deploy the fine-tuned model so that we can evaluate its performance.

---

Before deployment, we need to know whether the model was fine-tuned with a chat template. If so, we need to deploy it with `MESSAGES_API_ENABLED`, and the input and output signatures are different.
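
The decision described above can be written as one small helper. This is a sketch, and `deployment_environment` is a hypothetical name. Unlike the cells in this notebook, it treats a missing `chat_template` key as `False` and compares case-insensitively; both of those choices are assumptions.

```python
def deployment_environment(hyperparameters):
    """Return (is_chat_template, environment) for a dict of string hyperparameters."""
    # Hyperparameter values are strings such as "True" or "False".
    flag = str(hyperparameters.get("chat_template", "False")).strip().lower()
    is_chat_template = flag == "true"
    # The messages API is only enabled for models trained with a chat template.
    environment = {"MESSAGES_API_ENABLED": "true"} if is_chat_template else {}
    return is_chat_template, environment


print(deployment_environment({"chat_template": "True"}))
print(deployment_environment({"chat_template": "False"}))
```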

In [None]:
is_chat_template = my_hyperparameters["chat_template"] == "True"

In [None]:
if is_chat_template:
    estimator.environment = {"MESSAGES_API_ENABLED": "true"}

In [None]:
finetuned_predictor = estimator.deploy()

### Evaluate the fine-tuned model
---
Next, we evaluate the fine-tuned model on examples from the validation set of the DocVQA dataset. We also use the [prompt template suggested by Meta for this task](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/eval_details.md#docvqa).

Note: If the hyperparameter `chat_template` is False (that is, you trained the model in text completion format rather than chat completion format), we need to append `### Response:\n\n` to the end of the prompt. See the `instruct` argument of the function `formulate_payload` below for details.
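
This notebook inspects the responses by printing them next to the ground truth. If you want a numeric score as well, one option is ANLS (Average Normalized Levenshtein Similarity), the metric commonly reported for DocVQA. The sketch below is an illustrative pure-Python implementation and is not part of this notebook's workflow; the function names are hypothetical, and the 0.5 threshold is the value usually used for this metric.

```python
def levenshtein(a, b):
    """Edit distance between two strings (insert, delete, substitute cost 1)."""
    if len(a) < len(b):
        a, b = b, a
    previous = list(range(len(b) + 1))
    for i, ch_a in enumerate(a, start=1):
        current = [i]
        for j, ch_b in enumerate(b, start=1):
            cost = 0 if ch_a == ch_b else 1
            current.append(min(previous[j] + 1, current[j - 1] + 1, previous[j - 1] + cost))
        previous = current
    return previous[-1]


def anls_score(prediction, answers, threshold=0.5):
    """Best normalized similarity against any accepted answer; 0 past the threshold."""
    pred = prediction.strip().lower()
    best = 0.0
    for answer in answers:
        gt = answer.strip().lower()
        longest = max(len(pred), len(gt))
        distance = levenshtein(pred, gt) / longest if longest else 0.0
        best = max(best, 1.0 - distance if distance < threshold else 0.0)
    return best


def mean_anls(predictions, references):
    """Average the per-example scores; references is a list of answer lists."""
    scores = [anls_score(p, refs) for p, refs in zip(predictions, references)]
    return sum(scores) / len(scores) if scores else 0.0


print(mean_anls(["1/8/93", "P Carter"], [["1/8/93"], ["P. Carter"]]))
```

Because the processing step in this notebook keeps only the first reference answer per question, each entry of `references` would be a one-element list such as `[each["completion"]]`.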

---

In [None]:
import base64
from PIL import Image
import io


def display_base64_image(base64_string):
    """
    Displays a base64 encoded image.
    """

    # Decode the base64 string
    image_data = base64.b64decode(base64_string)

    # Open the image using Pillow
    image = Image.open(io.BytesIO(image_data))
    new_size = (700, 500)  # Specify the desired width and height
    resized_image = image.resize(new_size)
    # Display the image
    resized_image.show()


def get_image_decode_64base(image_path):
    with open(image_path, "rb") as f:
        image = base64.b64encode(f.read()).decode("utf-8")
    return image


def formulate_payload(q, image, instruct):
    img_path = f"data:image/jpg;base64,{image}"
    if instruct:
        payload = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"Read the text in the image carefully and answer the question with the text as seen exactly in the image. For yes/no questions, just respond Yes or No. If the answer is numeric, just respond with the number and nothing else. If the answer has multiple words, just respond with the words and absolutely nothing else. Never respond in a sentence or a phrase.\n Question: {q}",
                        },
                        {"type": "image_url", "image_url": {"url": img_path}},
                    ],
                }
            ],
            "max_tokens": 512,
            "logprobs": False,
        }
    else:
        prompt = f"![]({img_path})<|image|><|begin_of_text|>Read the text in the image carefully and answer the question with the text as seen exactly in the image. For yes/no questions, just respond Yes or No. If the answer is numeric, just respond with the number and nothing else. If the answer has multiple words, just respond with the words and absolutely nothing else. Never respond in a sentence or a phrase.\n Question: {q}### Response:\n\n"
        payload = {
            "body": {
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": 512,
                    "return_full_text": False,
                    # "temperature": 0.1,
                    "do_sample": False,
                    # "top_p": 0.97
                },
            },
            "content_type": "application/json",
            "accept": "application/json",
        }
    return payload


start = "\033[1m"
end = "\033[0;0m"

In [None]:
with open("./docvqa/validation/metadata.jsonl") as f:
    data = [json.loads(line) for line in f]

In [None]:
for i, each in enumerate(data[:5]):  # take first 5 observations
    q, a, image = (
        each["prompt"],
        each["completion"],
        get_image_decode_64base(image_path=f"./docvqa/validation/{each['file_name']}"),
    )

    payload = formulate_payload(q=q, image=image, instruct=is_chat_template)
    # print(payload)
    ft_response = finetuned_predictor.predict(payload)
    if is_chat_template:
        ft_response = ft_response["choices"][0]["message"]["content"]
    else:
        ft_response = ft_response[0]["generated_text"]
    print(f"==========={start}Example {i}{end}=============")
    print(f"{start}Prompt{end}: {q}")
    print(f"{start}FT response{end}: {ft_response}\n")
    print(f"{start}GT{end}: {a}\n")
    display_base64_image(image)

### Clean up resources

In [None]:
# Delete resources
finetuned_predictor.delete_model()
finetuned_predictor.delete_endpoint()

# Appendix

### 1. Supported Instance types for fine-tuning Llama 3.2 Vision

---
We have tested our scripts on the following instance types for fine-tuning Llama 3.2 Vision:

| Model | Model ID | Supported instance types for fine-tuning |
| - | - | - |
| Llama 3.2 11B | meta-vlm-llama-3-2-11b-vision | ml.g5.48xlarge, ml.p4de.24xlarge, ml.p4d.24xlarge |
| Llama 3.2 11B Instruct | meta-vlm-llama-3-2-11b-vision-instruct | ml.g5.48xlarge, ml.p4de.24xlarge   |
| Llama 3.2 90B | meta-vlm-llama-3-2-90b-vision | ml.g5.48xlarge|
| Llama 3.2 90B Instruct | meta-vlm-llama-3-2-90b-vision-instruct | ml.g5.48xlarge |


---

### 2. Studio Kernel Dead/Creating JumpStart Model from the training Job
---
Because of the size of the Llama 3.2 Vision models, a training job may take several hours, and the Studio kernel may die during the training phase. The training job itself keeps running in SageMaker, so you can still deploy the endpoint using the training job name with the following code:

To find the training job name, go to Console -> SageMaker -> Training -> Training Jobs, identify the training job name, and substitute it in the following cell.
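
As an alternative to the console, the most recent training jobs can also be listed programmatically. The sketch below wraps the boto3 SageMaker client's `list_training_jobs` call; the helper name `recent_training_job_names` is hypothetical, and running it requires AWS credentials with permission to list training jobs.

```python
def recent_training_job_names(sagemaker_client, name_contains=None, max_results=5):
    """Return (name, status) pairs for the most recently created training jobs."""
    kwargs = {
        "SortBy": "CreationTime",
        "SortOrder": "Descending",
        "MaxResults": max_results,
    }
    if name_contains:
        # Optional substring filter on the training job name.
        kwargs["NameContains"] = name_contains
    response = sagemaker_client.list_training_jobs(**kwargs)
    return [
        (job["TrainingJobName"], job["TrainingJobStatus"])
        for job in response["TrainingJobSummaries"]
    ]


# Usage with real AWS credentials (not executed here):
#   import boto3
#   print(recent_training_job_names(boto3.client("sagemaker")))
```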

---

In [None]:
from sagemaker.jumpstart.estimator import JumpStartEstimator

training_job_name = "<<Replace this with Training Job Name>>"

attached_estimator = JumpStartEstimator.attach(training_job_name, model_id)
attached_estimator.logs()
finetuned_predictor = attached_estimator.deploy()

If you lose the reference to your deployed endpoint, you can retrieve it with the following code. Go to AWS console -> SageMaker -> Endpoints (left panel) to get the endpoint name.

In [None]:
from sagemaker.predictor import retrieve_default

finetuned_predictor = retrieve_default(
    model_id=model_id, model_version="*", endpoint_name="<<Replace this with Endpoint Name>>"
)

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-3-finetuning.ipynb)