# Fine-tune Gemma on SageMaker JumpStart

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

---

In this demo notebook, we demonstrate how to use the SageMaker Python SDK to fine-tune Gemma and further deploy the fine-tuned model on SageMaker JumpStart.

The notebook is organized as follows.

1. [Setup](#1.-Setup)
2. [Deploy model](#2.-Deploy-model)
3. [Fine-tune model](#3.-Fine-tune-model)
4. [Evaluate the pre-trained and fine-tuned model](#4.-Evaluate-the-pre-trained-and-fine-tuned-model)

The notebook requires you to specify the following variables before you start.
* Specify `model_id` (default value: `huggingface-llm-gemma-7b-instruct`).
* Set the `accept_eula` argument to `True` in `model.deploy()` to accept the end-user license agreement (EULA) before deploying the model to an endpoint, because the Gemma model is gated.
* Specify `{"accept_eula": "true"}` in the `environment` argument of `JumpStartEstimator` to accept the EULA before fine-tuning.

## 1. Setup
First, upgrade to the latest sagemaker SDK to ensure all available models are deployable.

In [None]:
%pip install --quiet --upgrade sagemaker jmespath datasets

Select the desired model to deploy. The provided dropdown filters all text generation models available in SageMaker JumpStart.

In [None]:
from ipywidgets import Dropdown
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models


try:
    dropdown = Dropdown(
        options=list_jumpstart_models("search_keywords includes Text Generation"),
        value="huggingface-llm-gemma-7b-instruct",
        description="Select a JumpStart text generation model:",
        style={"description_width": "initial"},
        layout={"width": "max-content"},
    )
    display(dropdown)
except Exception:
    # Fall back to the default model ID if the widget cannot be built or displayed.
    dropdown = None

In [None]:
if dropdown:
    model_id = dropdown.value
else:
    model_id = "huggingface-llm-gemma-7b-instruct"
model_version = "*"

## 2. Deploy model

Create a `JumpStartModel` object, which initializes default model configurations conditioned on the selected instance type. JumpStart already sets a default instance type, but you can deploy the model on other instance types by passing `instance_type` to the `JumpStartModel` class.

In [None]:
from sagemaker.jumpstart.model import JumpStartModel


model = JumpStartModel(model_id=model_id, model_version=model_version)

You can now deploy the model using SageMaker JumpStart. If the selected model is gated, you need to accept the end-user license agreement (EULA) before deployment by passing `accept_eula=True` to the `deploy` method. The deployment might take a few minutes.

In [None]:
predictor = model.deploy(
    accept_eula=False
)  # Change `accept_eula` to `True` to accept the EULA.

### Invoke the endpoint

This section demonstrates how to invoke the endpoint using example payloads that are retrieved programmatically from the `JumpStartModel` object. You can replace these example payloads with your own. The supported inference payload parameters are listed below.

* **max_new_tokens:** Model generates text until the output length (excluding the input context length) reaches max_new_tokens. If specified, it must be a positive integer.
* **temperature:** Controls the randomness of the output. A higher temperature makes low-probability words more likely to appear in the output sequence, while a lower temperature favors high-probability words. As `temperature` approaches 0, sampling approaches greedy decoding. If specified, it must be a positive float.
* **top_p:** In each step of text generation, sample from the smallest possible set of words with cumulative probability `top_p`. If specified, it must be a float between 0 and 1.

You may specify any subset of the parameters mentioned above while invoking an endpoint. 
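
As a minimal sketch of how these parameters fit into a request body, the helper below builds a payload and checks the constraints listed above. The `inputs` / `parameters` layout is an assumption based on the common Hugging Face text generation request shape, and `build_payload` is a hypothetical helper rather than part of the SageMaker SDK. The example payloads retrieved from the `JumpStartModel` object remain the authoritative reference for a given model.

```python
def build_payload(prompt, max_new_tokens=None, temperature=None, top_p=None):
    """Build an inference payload, enforcing the documented parameter constraints.

    Hypothetical helper: the "inputs"/"parameters" layout is an assumption.
    """
    parameters = {}
    if max_new_tokens is not None:
        if not isinstance(max_new_tokens, int) or max_new_tokens <= 0:
            raise ValueError("max_new_tokens must be a positive integer")
        parameters["max_new_tokens"] = max_new_tokens
    if temperature is not None:
        if temperature <= 0:
            raise ValueError("temperature must be a positive float")
        parameters["temperature"] = float(temperature)
    if top_p is not None:
        # Inclusive bounds are an assumption; the text only says "between 0 and 1".
        if not 0.0 <= top_p <= 1.0:
            raise ValueError("top_p must be a float between 0 and 1")
        parameters["top_p"] = float(top_p)
    # Only the parameters that were explicitly given are sent.
    return {"inputs": prompt, "parameters": parameters}


payload = build_payload("Write a haiku about the sea.", max_new_tokens=64, top_p=0.9)
```

Because any subset of the parameters may be supplied, the helper only adds the keys that were passed in.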

JumpStart stores model-specific default example payloads in its SDK. You can retrieve and view them using the following code.

In [None]:
example_payloads = model.retrieve_all_examples()

In [None]:
import jmespath


for payload in example_payloads:
    response = predictor.predict(payload.body)
    generated_text = jmespath.search(payload.raw_payload["output_keys"]["generated_text"], response)
    print("Input:\n", payload.body[payload.prompt_key])
    print("Output:\n", generated_text.strip())
    print("\n===============\n")

## 3. Fine-tune model

You can fine-tune the model on a dataset in domain adaptation format, instruction tuning format, or chat (conversational) format. For details, see the section [Dataset formatting instruction for training](#1.-Dataset-formatting-instruction-for-training). In this demo, we use a subset of OpenAssistant's TOP-1 Conversation Threads as an example dataset for chat fine-tuning. It can be downloaded from [Hugging Face](https://huggingface.co/datasets/OpenAssistant/oasst_top1_2023-08-25). It contains roughly 13,000 samples of conversations between the assistant and the user.


Training data should be formatted in JSON lines (.jsonl) format, where each line is a dictionary representing a set of conversations. Here is an example of a line in the training file. The key has to be one of `dialog`, `messages`, and `conversations` to be compatible with `chatml` format in HuggingFace SFTTrainer.

```
{"dialog": [{"content":"what is the height of the empire state building","role":"user"},{"content":"381 meters, or 1,250 feet, is the height of the Empire State Building. If you also account for the antenna, it brings up the total height to 443 meters, or 1,454 feet","role":"assistant"},{"content":"Some people need to pilot an aircraft above it and need to know.\nSo what is the answer in feet?","role":"user"},{"content":"1454 feet","role":"assistant"}]}
```
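
Before uploading a training file, it can help to check each line against the shape described above. The following is a minimal sketch of such a check; `validate_chat_line` is a hypothetical helper, and the rule that every turn carries `role` and `content` keys is inferred from the example line rather than taken from the training script.

```python
import json

# Keys accepted for the conversation list, as stated above.
ALLOWED_KEYS = ("dialog", "messages", "conversations")


def validate_chat_line(line):
    """Parse one JSON line and return its list of turns, or raise ValueError."""
    record = json.loads(line)
    present = [key for key in ALLOWED_KEYS if key in record]
    if len(present) != 1:
        raise ValueError(f"expected exactly one of {ALLOWED_KEYS}, found {present}")
    turns = record[present[0]]
    if not isinstance(turns, list) or not turns:
        raise ValueError("the conversation must be a non-empty list of turns")
    for turn in turns:
        # Assumption: each turn is a dict with 'role' and 'content' keys.
        if not isinstance(turn, dict) or "role" not in turn or "content" not in turn:
            raise ValueError("each turn needs 'role' and 'content' keys")
    return turns


sample_line = (
    '{"dialog": [{"content": "hi", "role": "user"}, '
    '{"content": "hello", "role": "assistant"}]}'
)
turns = validate_chat_line(sample_line)
```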

In [None]:
from datasets import load_dataset
import re

# Load the dataset
dataset = load_dataset("OpenAssistant/oasst_top1_2023-08-25")

In [None]:
dataset["train"][0]

In [None]:
# Transform one raw conversation string into a list of role/content turns
def transform_conversation(example):
    conversation_text = example["text"]

    # Splitting on the delimiters yields: "", user turn, "\n", assistant turn, "\n", ...
    segments = re.split(r"<\|im_start\|>|<\|im_end\|>", conversation_text)
    dialog_list = []

    # User turns sit at indices 1, 5, 9, ...; the matching assistant turn is at i + 2
    for i in range(1, len(segments) - 1, 4):
        # Remove only the leading role tag, not later occurrences of the word
        human_text = segments[i].strip().replace("user", "", 1).strip()
        dialog_list.append({"role": "user", "content": human_text})

        # Append the assistant reply only if the conversation contains one
        if i + 2 < len(segments):
            assistant_text = segments[i + 2].strip().replace("assistant", "", 1).strip()
            dialog_list.append({"role": "assistant", "content": assistant_text})
    return {"dialog": dialog_list}
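
To see what the delimiter split produces, the sketch below applies the same pattern to a short hand-written string in the `<|im_start|>` / `<|im_end|>` layout. The sample string is made up for illustration, and the exact whitespace in the real dataset may differ.

```python
import re

# Hand-written sample in the delimiter layout (an assumption, not a dataset row).
sample = (
    "<|im_start|>user\nHello<|im_end|>\n"
    "<|im_start|>assistant\nHi there<|im_end|>\n"
)

segments = re.split(r"<\|im_start\|>|<\|im_end\|>", sample)
print(segments)
# ['', 'user\nHello', '\n', 'assistant\nHi there', '\n']

# User turns land at indices 1, 5, 9, ... and assistant turns two positions later.
user_turn = segments[1]
assistant_turn = segments[3]
```

Each user turn is followed by a whitespace-only segment and then the assistant turn, which is why the loop advances in steps of four and reads the reply at `i + 2`.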

In [None]:
transformed_dataset = dataset.map(transform_conversation).remove_columns("text")

In [None]:
transformed_dataset["train"].select(range(5000)).to_json("train.jsonl")

### Upload dataset to S3

In [None]:
from sagemaker.s3 import S3Uploader
import sagemaker

output_bucket = sagemaker.Session().default_bucket()
local_data_file = "train.jsonl"
train_data_location = f"s3://{output_bucket}/oasst_top1"
S3Uploader.upload(local_data_file, train_data_location)
print(f"Training data: {train_data_location}")

### Retrieve and customize hyperparameters

In [None]:
from sagemaker import hyperparameters

my_hyperparameters = hyperparameters.retrieve_default(
    model_id=model_id, model_version=model_version
)

print(my_hyperparameters)

Under the hood, the JumpStart training scripts use [HuggingFace SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) with [QLoRA](https://arxiv.org/abs/2305.14314).

* For chat fine-tuning, specify `chat_dataset` to be `True` and `instruction_tuned` to be `False`.
* For instruction fine-tuning, specify `chat_dataset` to be `False` and `instruction_tuned` to be `True`.
* For domain adaptation fine-tuning, specify `chat_dataset` to be `False` and `instruction_tuned` to be `False`.
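
The three combinations above can be captured in a small lookup table, as in the sketch below. The mode names and the `apply_mode` helper are hypothetical, and the values are written as strings on the assumption that JumpStart hyperparameters are passed as strings.

```python
# Flag combinations for each fine-tuning mode, taken from the list above.
FINE_TUNING_MODES = {
    "chat": {"chat_dataset": "True", "instruction_tuned": "False"},
    "instruction": {"chat_dataset": "False", "instruction_tuned": "True"},
    "domain_adaptation": {"chat_dataset": "False", "instruction_tuned": "False"},
}


def apply_mode(hyperparameters, mode):
    """Return a copy of the hyperparameters with the flags for the given mode."""
    if mode not in FINE_TUNING_MODES:
        raise ValueError(f"unknown mode {mode!r}, expected one of {sorted(FINE_TUNING_MODES)}")
    updated = dict(hyperparameters)  # leave the caller's dictionary untouched
    updated.update(FINE_TUNING_MODES[mode])
    return updated


example = apply_mode({"epoch": "1"}, "chat")
```

Returning a copy keeps the defaults retrieved earlier intact, so the same starting dictionary can be reused for a different mode.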

In [None]:
my_hyperparameters["epoch"] = "1"
print(my_hyperparameters)

hyperparameters.validate(
    model_id=model_id, model_version=model_version, hyperparameters=my_hyperparameters
)

In [None]:
from sagemaker.jumpstart.estimator import JumpStartEstimator


estimator = JumpStartEstimator(
    model_id=model_id,
    model_version=model_version,
    hyperparameters=my_hyperparameters,
    environment={
        "accept_eula": "false"
    },  # please change `accept_eula` to be `true` to accept EULA.
)

estimator.fit({"training": train_data_location})

### Deploy the fine-tuned model
Next, we deploy the fine-tuned model so that we can compare its performance with that of the pre-trained model.

If the model is fine-tuned on a chat dataset, we need to deploy it with the [Messages API feature](https://huggingface.co/docs/text-generation-inference/en/messages_api#amazon-sagemaker) enabled. The Messages API is integrated with SageMaker endpoints: any endpoint that uses Text Generation Inference with an LLM that has a chat template can use it. For details, see the [documentation](https://huggingface.co/docs/text-generation-inference/en/messages_api#amazon-sagemaker).
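
As a rough sketch of the request and response shapes involved, the helpers below build a chat payload and pull the reply out of a response. The response layout mirrors the way `print_dialog` reads it in the evaluation section; `build_chat_payload`, `extract_reply`, and the mock response are hypothetical and do not call an endpoint.

```python
def build_chat_payload(model_id, messages, max_tokens=256):
    """Assemble a Messages API style request body from a list of turns."""
    return {"model": model_id, "messages": messages, "max_tokens": max_tokens}


def extract_reply(response):
    """Return (role, content) of the first choice in a chat-style response."""
    message = response["choices"][0]["message"]
    return message["role"], message["content"]


payload = build_chat_payload(
    "huggingface-llm-gemma-7b-instruct",
    [{"role": "user", "content": "What is the height of the Empire State Building in feet?"}],
)

# Mock response used only to illustrate the layout; a real endpoint returns more fields.
mock_response = {"choices": [{"message": {"role": "assistant", "content": "1454 feet"}}]}
role, content = extract_reply(mock_response)
```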

In [None]:
finetuned_predictor = estimator.deploy(
    env={"MESSAGES_API_ENABLED": my_hyperparameters.get("chat_dataset", "false").lower()}
)

## 4. Evaluate the pre-trained and fine-tuned model
Next, we use the test data to evaluate the performance of the fine-tuned model and compare it with the pre-trained model. 


In [None]:
def print_dialog(payload, response):
    dialog = payload["messages"]
    for msg in dialog:
        print(f"{msg['role'].capitalize()}: {msg['content']}\n")
    print(
        f">>>> {response['choices'][0]['message']['role'].capitalize()}: {response['choices'][0]['message']['content']}"
    )
    print("\n==================================\n")

In [None]:
test_dataset = transformed_dataset["test"]

for i, datapoint in enumerate(test_dataset.select(range(10))):
    payload = {
        "model": model_id,  # this is currently required by the message API
        "messages": datapoint["dialog"][:-1],
        "max_tokens": 600,
        "top_p": 0.9,
        "temperature": 0.4,
        "top_k": 50,
    }
    print("Fine-tuned model response:")
    response = finetuned_predictor.predict(payload)
    print_dialog(payload, response)

    print("Pre-trained model response:")
    response = predictor.predict(payload)
    print_dialog(payload, response)
    print("Ground Truth Response:")
    print(
        f">>>> {datapoint['dialog'][-1]['role'].capitalize()}: {datapoint['dialog'][-1]['content']}"
    )
    print(f"\n============End of Example {i+1} ======================\n")

### Clean up the endpoint
Don't forget to clean up resources when finished to avoid unnecessary charges.

In [None]:
predictor.delete_predictor()
finetuned_predictor.delete_predictor()

## Appendix

### 1. Dataset formatting instruction for training

#### Fine-tune the Model on a New Dataset
We currently offer three types of fine-tuning: chat fine-tuning, domain adaptation fine-tuning, and instruction fine-tuning. You can switch between them by setting the `chat_dataset` and `instruction_tuned` hyperparameters to 'True' or 'False'.

#### 1.1. Chat fine-tuning


The text generation model can be fine-tuned on a chat dataset, provided that the data is in the expected format. The resulting chat model can be further deployed for inference. Below are the instructions for how the training data should be formatted for input to the model.

- **Input:** A train and an optional validation directory. Train and validation directories should contain one or multiple JSON lines (.jsonl) formatted files. All training data must be in a single folder; however, it can be saved in multiple .jsonl files. The .jsonl file extension is mandatory.
    - The training data must be formatted in a JSON lines (.jsonl) format, where each line is a dictionary representing a single data sample. Each dictionary holds a list of conversation turns between the user and the assistant model. This model only supports 'system', 'user' and 'assistant' roles, starting with 'system', then 'user' and alternating (u/a/u/a/u...).
- **Output:**  A trained model that can be deployed for inference.

The best model is selected according to the validation loss, calculated at the end of each epoch. If a validation set is not given, an (adjustable) percentage of the training data is automatically split off and used for validation.
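
To illustrate the idea of holding out part of the training data, here is a minimal sketch of a percentage-based split. It is not the JumpStart training script's implementation; `split_train_validation`, the `validation_fraction` argument, and the fixed seed are assumptions made for this example.

```python
import random


def split_train_validation(samples, validation_fraction=0.2, seed=0):
    """Shuffle the samples and hold out a fraction of them for validation."""
    if not 0 < validation_fraction < 1:
        raise ValueError("validation_fraction must be strictly between 0 and 1")
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible
    n_validation = max(1, int(len(shuffled) * validation_fraction))
    return shuffled[n_validation:], shuffled[:n_validation]


train_samples, validation_samples = split_train_validation(range(10), validation_fraction=0.2)
```

With ten samples and a fraction of 0.2, two samples are held out and eight remain for training.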

Here is an example of a line in the training file:

{"dialog": [{"content":"what is the height of the empire state building","role":"user"},{"content":"381 meters, or 1,250 feet, is the height of the Empire State Building. If you also account for the antenna, it brings up the total height to 443 meters, or 1,454 feet","role":"assistant"},{"content":"Some people need to pilot an aircraft above it and need to know.\nSo what is the answer in feet?","role":"user"},{"content":"1454 feet","role":"assistant"}]}



#### 1.2. Domain adaptation fine-tuning
The text generation model can also be fine-tuned on any domain-specific dataset. After being fine-tuned on the domain-specific dataset, the model is expected to generate domain-specific text and solve various NLP tasks in that specific domain with **few-shot prompting**.

Below are the instructions for how the training data should be formatted for input to the model.

- **Input:** A train and an optional validation directory. Each directory contains a CSV/JSON/TXT file. 
  - For CSV/JSON files, the train or validation data is used from the column called 'text' or the first column if no column called 'text' is found.
  - The train directory and the validation directory (if provided) should each contain exactly one file.
- **Output:** A trained model that can be deployed for inference. 
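
The column rule for CSV files can be sketched as follows. This is an illustration of the stated rule only, not the training script's code; `read_text_column` is a hypothetical helper, and the sketch assumes the CSV file has a header row.

```python
import csv
import io


def read_text_column(csv_text):
    """Return the values of the 'text' column, or of the first column if absent."""
    reader = csv.DictReader(io.StringIO(csv_text))
    fieldnames = reader.fieldnames  # reads the header row
    if not fieldnames:
        raise ValueError("the CSV data has no header row")
    column = "text" if "text" in fieldnames else fieldnames[0]
    return [row[column] for row in reader]


with_text = read_text_column("id,text\n1,first passage\n2,second passage\n")
without_text = read_text_column("body,id\nalpha,1\nbeta,2\n")
```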

Below is an example of a TXT file for fine-tuning the text generation model. The TXT file contains Amazon's SEC filings from 2021 to 2022.

```
Note About Forward-Looking Statements
This report includes estimates, projections, statements relating to our
business plans, objectives, and expected operating results that are “forward-
looking statements” within the meaning of the Private Securities Litigation
Reform Act of 1995, Section 27A of the Securities Act of 1933, and Section 21E
of the Securities Exchange Act of 1934. Forward-looking statements may appear
throughout this report, including the following sections: “Business” (Part I,
Item 1 of this Form 10-K), “Risk Factors” (Part I, Item 1A of this Form 10-K),
and “Management’s Discussion and Analysis of Financial Condition and Results
of Operations” (Part II, Item 7 of this Form 10-K). These forward-looking
statements generally are identified by the words “believe,” “project,”
“expect,” “anticipate,” “estimate,” “intend,” “strategy,” “future,”
“opportunity,” “plan,” “may,” “should,” “will,” “would,” “will be,” “will
continue,” “will likely result,” and similar expressions. Forward-looking
statements are based on current expectations and assumptions that are subject
to risks and uncertainties that may cause actual results to differ materially.
We describe risks and uncertainties that could cause actual results and events
to differ materially in “Risk Factors,” “Management’s Discussion and Analysis
of Financial Condition and Results of Operations,” and “Quantitative and
Qualitative Disclosures about Market Risk” (Part II, Item 7A of this Form
10-K). Readers are cautioned not to place undue reliance on forward-looking
statements, which speak only as of the date they are made. We undertake no
obligation to update or revise publicly any forward-looking statements,
whether because of new information, future events, or otherwise.
GENERAL
Embracing Our Future ...
```


#### 1.3. Instruction fine-tuning
The text generation model can be instruction-tuned on any text data, provided that the data is in the expected format. The instruction-tuned model can be further deployed for inference. Below are the instructions for how the training data should be formatted for input to the model.

- **Input:** A train and an optional validation directory. Train and validation directories should contain one or multiple JSON lines (`.jsonl`) formatted files. The train directory can also contain an optional `*.json` file describing the input and output formats.
  - The best model is selected according to the validation loss, calculated at the end of each epoch. If a validation set is not given, an (adjustable) percentage of the training data is automatically split off and used for validation.
  - The training data must be formatted in a JSON lines (`.jsonl`) format, where each line is a dictionary representing a single data sample. All training data must be in a single folder; however, it can be saved in multiple `.jsonl` files. The `.jsonl` file extension is mandatory. The training folder can also contain a `template.json` file describing the input and output formats. If no template file is given, the following template will be used:
  ```json
  {
    "prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{context}",
    "completion": "{response}"
  }
  ```
  - In this case, the data in the JSON lines entries must include `instruction`, `context` and `response` fields. If a custom template is provided, it must also use `prompt` and `completion` keys to define the input and output templates. Below is a sample custom template:

  ```json
  {
    "prompt": "question: {question} context: {context}",
    "completion": "{answer}"
  }
  ```
  Here, the data in the JSON lines entries must include `question`, `context` and `answer` fields.
- **Output:** A trained model that can be deployed for inference. 
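
To show how a template and a data record fit together, the sketch below fills the sample custom template with one record using Python string formatting. How the training script applies the template internally is not described here, so `render_example` is a hypothetical helper and the record values are made up.

```python
# The sample custom template from above.
custom_template = {
    "prompt": "question: {question} context: {context}",
    "completion": "{answer}",
}


def render_example(template, record):
    """Fill the prompt and completion templates with the fields of one record."""
    try:
        return {
            "prompt": template["prompt"].format(**record),
            "completion": template["completion"].format(**record),
        }
    except KeyError as missing:
        # A placeholder in the template has no matching field in the record.
        raise ValueError(f"record is missing field {missing}") from None


record = {
    "question": "Who wrote it?",
    "context": "The report was written by the finance team.",
    "answer": "The finance team",
}
rendered = render_example(custom_template, record)
print(rendered["prompt"])
# question: Who wrote it? context: The report was written by the finance team.
```

A record that lacks one of the placeholder fields raises an error, which matches the requirement that every JSON lines entry include all fields named in the template.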

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|gemma-fine-tuning.ipynb)
