# Chat Completion: Run Llama 2 Models in SageMaker JumpStart

---
This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

---

---
In this demo notebook, we demonstrate how to use the SageMaker Python SDK to deploy a JumpStart model for Text Generation using the Llama 2 fine-tuned model optimized for dialogue use cases.

To perform inference on these models, you need to pass custom_attributes='accept_eula=true' as part of header. This means you have read and accept the end-user-license-agreement (EULA) of the model. EULA can be found in model card description or from https://ai.meta.com/resources/models-and-libraries/llama-downloads/. By default, this notebook sets custom_attributes='accept_eula=false', so all inference requests will fail until you explicitly change this custom attribute.

Note: Custom_attributes used to pass EULA are key/value pairs. The key and value are separated by '=' and pairs are separated by ';'. If the user passes the same key more than once, the last value is kept and passed to the script handler (i.e., in this case, used for conditional logic). For example, if 'accept_eula=false; accept_eula=true' is passed to the server, then 'accept_eula=true' is kept and passed to the script handler.

---

## Setup

***

In [2]:
%pip install --upgrade --quiet sagemaker datasets

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
sparkmagic 0.21.0 requires pandas<2.0.0,>=0.17.1, but you have pandas 2.0.3 which is incompatible.[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.


***
You can continue with the default model or choose a different model: this notebook will run with the following model IDs :
- `meta-textgeneration-llama-2-7b-f`
- `meta-textgeneration-llama-2-13b-f`
- `meta-textgeneration-llama-2-70b-f`
***

In [3]:
model_id, model_version = "meta-textgeneration-llama-2-7b-f", "2.*"

## Deploy model

***
You can now deploy the model using SageMaker JumpStart.
***

In [4]:
from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id=model_id, model_version=model_version)
predictor = model.deploy()

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml


For forward compatibility, pin to model_version='2.*' in your JumpStartModel or JumpStartEstimator definitions. Note that major version upgrades may have different EULA acceptance terms and input/output signatures.
Using vulnerable JumpStart model 'meta-textgeneration-llama-2-7b-f' and version '2.0.4'.
Using model 'meta-textgeneration-llama-2-7b-f' with wildcard version identifier '2.*'. You can pin to version '2.0.4' for more stable results. Note that models may have different input/output signatures after a major version upgrade.


-------------------!

## Invoke the endpoint

***
### Supported Parameters
This model supports the following inference payload parameters:

* **max_new_tokens:** Model generates text until the output length (excluding the input context length) reaches max_new_tokens. If specified, it must be a positive integer.
* **temperature:** Controls the randomness in the output. Higher temperature results in output sequence with low-probability words and lower temperature results in output sequence with high-probability words. If `temperature` -> 0, it results in greedy decoding. If specified, it must be a positive float.
* **top_p:** In each step of text generation, sample from the smallest possible set of words with cumulative probability `top_p`. If specified, it must be a float between 0 and 1.

You may specify any subset of the parameters mentioned above while invoking an endpoint. 

***
### Notes
- If `max_new_tokens` is not defined, the model may generate up to the maximum total tokens allowed, which is 4K for these models. This may result in endpoint query timeout errors, so it is recommended to set `max_new_tokens` when possible. For 7B, 13B, and 70B models, we recommend to set `max_new_tokens` no greater than 1500, 1000, and 500 respectively, while keeping the total number of tokens less than 4K.
- In order to support a 4k context length, this model has restricted query payloads to only utilize a batch size of 1. Payloads with larger batch sizes will receive an endpoint error prior to inference.
- This model only supports 'system', 'user' and 'assistant' roles, starting with 'system', then 'user' and alternating (u/a/u/a/u...).

***

In [5]:
def print_dialog(payload, response):
    dialog = payload["inputs"][0]
    for msg in dialog:
        print(f"{msg['role'].capitalize()}: {msg['content']}\n")
    print(
        f">>>> {response[0]['generation']['role'].capitalize()}: {response[0]['generation']['content']}"
    )
    print("\n==================================\n")

### Example 1

In [25]:
%%time

payload = {
    "inputs": [
        [
            {"role": "user", "content": "Now I will give you a text message summarize the text massage in to following json format: {event: xxx; time: xxx; location: xxx}, Here is txt:I will have dinner with mom today afternoon at 7pm near downtown"},
        ]
    ],
    "parameters": {"max_new_tokens": 512, "top_p": 0.9, "temperature": 0.6},
}
try:
    response = predictor.predict(payload, custom_attributes="accept_eula=true")
    print_dialog(payload, response)
except Exception as e:
    print(e)

User: Now I will give you a text message summarize the text massage in to following json format: {event: xxx; time: xxx; location: xxx}, Here is txt:I will have dinner with mom today afternoon at 7pm near downtown

>>>> Assistant:  Sure! Here is the JSON format for the text message you provided:

{
"event": "dinner with mom",
"time": "7pm",
"location": "near downtown"
}

Note that I've used the singular form of "event" since you mentioned having dinner with one person (mom). If you were having dinner with multiple people, you could use the plural form of "event" (e.g. { "event": "dinner with mom and dad",... }).


CPU times: user 15.9 ms, sys: 0 ns, total: 15.9 ms
Wall time: 3.49 s


### Example 2

In [18]:
%%time

payload = {
    "inputs": [
        [
            {"role": "user", "content": "I am going to Paris, what should I see?"},
            {
                "role": "assistant",
                "content": """\
Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:

1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.
2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.
3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.

These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.""",
            },
            {"role": "user", "content": "What is so great about #1?"},
        ]
    ],
    "parameters": {"max_new_tokens": 512, "top_p": 0.9, "temperature": 0.6},
}
try:
    response = predictor.predict(payload, custom_attributes="accept_eula=true")
    print_dialog(payload, response)
except Exception as e:
    print(e)

User: I am going to Paris, what should I see?

Assistant: Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:

1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.
2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.
3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.

These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.

User: What is so great about #1?

>>>> Assistant:  The Eiffel Tower is considered one 

### Example 3

In [19]:
%%time

payload = {
    "inputs": [
        [
            {"role": "system", "content": "Always answer with Haiku"},
            {"role": "user", "content": "I am going to Paris, what should I see?"},
        ]
    ],
    "parameters": {"max_new_tokens": 512, "top_p": 0.9, "temperature": 0.6},
}
try:
    response = predictor.predict(payload, custom_attributes="accept_eula=true")
    print_dialog(payload, response)
except Exception as e:
    print(e)

System: Always answer with Haiku

User: I am going to Paris, what should I see?

>>>> Assistant:  Eiffel Tower high
Love locks on Seine river bank
City of Light shines


CPU times: user 4.52 ms, sys: 0 ns, total: 4.52 ms
Wall time: 643 ms


### Example 4

In [20]:
%%time

payload = {
    "inputs": [
        [
            {
                "role": "system",
                "content": "Always answer with emojis",
            },
            {"role": "user", "content": "How to go from Beijing to NY?"},
        ]
    ],
    "parameters": {"max_new_tokens": 512, "top_p": 0.9, "temperature": 0.6},
}
try:
    response = predictor.predict(payload, custom_attributes="accept_eula=true")
    print_dialog(payload, response)
except Exception as e:
    print(e)

System: Always answer with emojis

User: How to go from Beijing to NY?

>>>> Assistant:  Here's how to go from Beijing to NY 🛬🗽:

1. Fly 🛬: The fastest way to reach New York from Beijing is by flying. There are direct flights available from Beijing Capital International Airport (PEK) to John F. Kennedy International Airport (JFK) or LaGuardia Airport (LGA) in New York.
2. Train 🚂: You can also travel by train from Beijing to New York, but it's not a direct route. You'll need to take a train from Beijing to Moscow, then connect to a train to New York. This option can take longer and may involve multiple transfers.
3. Bus 🚌: Taking a bus is another option for traveling from Beijing to New York. There are several bus companies that offer this service, but it can take longer than flying or taking the train.
4. Drive 🚗: If you prefer to drive, you can rent a car in Beijing and drive to New York. This option allows for more flexibility in your itinerary, but it can also be more expensive and

## Dataset preparation for fine-tuning

---

You can fine-tune on the dataset with domain adaptation format or instruction tuning format or the chat dataset format. Please find more details in the section [Dataset formatting instruction for training](#1.-Dataset-formatting-instruction-for-training). In this demo, we will use a subset of OpenAssistant's TOP-1 Conversation Threads as an example dataset. It can be downloaded from [here](https://huggingface.co/datasets/OpenAssistant/oasst_top1_2023-08-25). It contains roughly 13,000 samples of conversations between the Assistant and the user. 


Training data is formatted in JSON lines (.jsonl) format, where each line is a dictionary representing a single set of conversation.

To train your model on a collection of unstructured dataset (text files), please see the section [1.3. Example fine-tuning with Domain-Adaptation dataset format](#1.3.-Example-fine-tuning-with-Domain-Adaptation-dataset-format) in the Appendix. To train your model on instruction tuning dataset format, please see the section [ 1.4. Example fine-tuning with Instruction tuning dataset format](#1.4.-Example-fine-tuning-with-Instruction-tuning-dataset-format).

---

In [10]:
from datasets import load_dataset
import re

# Load the dataset
dataset = load_dataset("OpenAssistant/oasst_top1_2023-08-25")

Downloading readme:   0%|          | 0.00/512 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/31.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.61M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/12947 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/690 [00:00<?, ? examples/s]

In [11]:
# Define a function to transform the data
def transform_conversation(example):
    conversation_text = example["text"]

    segments = re.split("<\|im_start\|>|<\|im_end\|>", conversation_text)
    reformatted_segments = []
    dialog_list = []

    # Iterate over pairs of segments
    for i in range(1, len(segments) - 1, 4):
        human_text = segments[i].strip().replace("user", "").strip()

        # Check if there is a corresponding assistant segment before processing
        if i + 1 < len(segments):
            assistant_text = segments[i + 2].strip().replace("assistant", "").strip()
            dialog_list.append({"role": "user", "content": human_text})
            dialog_list.append({"role": "assistant", "content": assistant_text})

        else:
            dialog_list.append({"role": "user", "content": human_text})
    return {"dialog": dialog_list}

In [12]:
transformed_dataset = dataset.map(transform_conversation).remove_columns("text")

Map:   0%|          | 0/12947 [00:00<?, ? examples/s]

Map:   0%|          | 0/690 [00:00<?, ? examples/s]

In [13]:
transformed_dataset["train"].select(range(5000)).to_json("train.jsonl")

Creating json from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

11908580

In [14]:
from sagemaker.s3 import S3Uploader
import sagemaker
import random

output_bucket = sagemaker.Session().default_bucket()
local_data_file = "train.jsonl"
train_data_location = f"s3://{output_bucket}/oasst_top1"
S3Uploader.upload(local_data_file, train_data_location)
print(f"Training data: {train_data_location}")

Training data: s3://sagemaker-us-east-1-637423340224/oasst_top1


## Train the model
---
Next, we fine-tune the LLaMA v2 7B model on the summarization dataset from Dolly. Finetuning scripts are based on scripts provided by [this repo](https://github.com/facebookresearch/llama-recipes/tree/main). To learn more about the fine-tuning scripts, please checkout section [5. Few notes about the fine-tuning method](#5.-Few-notes-about-the-fine-tuning-method). For a list of supported hyper-parameters and their default values, please see section [3. Supported Hyper-parameters for fine-tuning](#3.-Supported-Hyper-parameters-for-fine-tuning).

---

In [15]:
from sagemaker.jumpstart.estimator import JumpStartEstimator


estimator = JumpStartEstimator(
    model_id=model_id,
    environment={"accept_eula": "true"},
    disable_output_compression=True,  # For Llama-2-70b, add instance_type = "ml.g5.48xlarge"
)

estimator.set_hyperparameters(
    chat_dataset="True", instruction_tuned="False", epoch="1", max_input_length="1024"
)
estimator.fit({"training": train_data_location})

Using model 'meta-textgeneration-llama-2-7b-f' with wildcard version identifier '*'. You can pin to version '4.1.1' for more stable results. Note that models may have different input/output signatures after a major version upgrade.
INFO:sagemaker:Creating training-job with name: meta-textgeneration-llama-2-7b-f-2024-06-17-17-47-15-235


ResourceLimitExceeded: An error occurred (ResourceLimitExceeded) when calling the CreateTrainingJob operation: The account-level service limit 'ml.g5.12xlarge for training job usage' is 0 Instances, with current utilization of 0 Instances and a request delta of 1 Instances. Please use AWS Service Quotas to request an increase for this quota. If AWS Service Quotas is not available, contact AWS support to request an increase for this quota.

Studio Kernel Dying issue:  If your studio kernel dies and you lose reference to the estimator object, please see section [6. Studio Kernel Dead/Creating JumpStart Model from the training Job](#6.-Studio-Kernel-Dead/Creating-JumpStart-Model-from-the-training-Job) on how to deploy endpoint using the training job name and the model id. 


### Deploy the fine-tuned model
---
Next, we deploy fine-tuned model. We will compare the performance of fine-tuned and pre-trained model.

---

In [None]:
finetuned_predictor = estimator.deploy()

### Evaluate the pre-trained and fine-tuned model
---
Next, we use the test data to evaluate the performance of the fine-tuned model and compare it with the pre-trained model. 

---

In [None]:
test_dataset = transformed_dataset["test"]

try:
    for i, datapoint in enumerate(test_dataset.select(range(5))):
        payload = {
            "inputs": [datapoint["dialog"][:-1]],
            "parameters": {"max_new_tokens": 512, "top_p": 0.9, "temperature": 0.6},
        }
        response = finetuned_predictor.predict(payload, custom_attributes="accept_eula=false")
        print_dialog(payload, response)
        print("Ground Truth Response:")
        print(
            f">>>> {datapoint['dialog'][-1]['role'].capitalize()}: {datapoint['dialog'][-1]['content']}"
        )
        print(f"\n============End of Example {i+1} ======================\n")
except Exception as e:
    print(e)

### Clean up resources

In [None]:
# Delete resources
predictor.delete_model()
predictor.delete_endpoint()
finetuned_predictor.delete_model()
finetuned_predictor.delete_endpoint()

# Appendix

### 1. Dataset formatting instruction for training

---

####  Fine-tune the Model on a New Dataset
We currently offer two types of fine-tuning: instruction fine-tuning and domain adaption fine-tuning. You can easily switch to one of the training 
methods by specifying parameter `instruction_tuned` being 'True' or 'False'.

#### 1.1. Chat fine-tuning


The Text generation model can be fine-tuned on the chat dataset, provided that the data is in the expected format. The resulting chat model can be further deployed for inference. Below are the instructions for how the training data should be formatted for input to the model.

Below are the instructions for how the training data should be formatted for input to the model.

- **Input:** A train and an optional validation directory. Train and validation directories should contain one or multiple JSON lines (.jsonl) formatted files. All training data must be in a single folder, however it can be saved in multiple jsonl files. The .jsonl file extension is mandatory.
    - The training data must be formatted in a JSON lines (.jsonl) format, where each line is a dictionary representing a single data sample. Each line in the file is a list of conversations between the user and the assistant model. This model only supports 'system', 'user' and 'assistant' roles, starting with 'system', then 'user' and alternating (u/a/u/a/u...).
- **Output:**  A trained model that can be deployed for inference.

The best model is selected according to the validation loss, calculated at the end of each epoch. If a validation set is not given, an (adjustable) percentage of the training data is automatically split and used for validation.The training data must be formatted in a JSON lines (.jsonl) format, where each line is a dictionary representing a single data sample.   

Here is an example of a line in the training file:

{"dialog": [{"content":"what is the height of the empire state building","role":"user"},{"content":"381 meters, or 1,250 feet, is the height of the Empire State Building. If you also account for the antenna, it brings up the total height to 443 meters, or 1,454 feet","role":"assistant"},{"content":"Some people need to pilot an aircraft above it and need to know.\nSo what is the answer in feet?","role":"user"},{"content":"1454 feet","role":"assistant"}]}



#### 1.2. Domain adaptation fine-tuning
The Text Generation model can also be fine-tuned on any domain specific dataset. After being fine-tuned on the domain specific dataset, the model
is expected to generate domain specific text and solve various NLP tasks in that specific domain with **few shot prompting**.

Below are the instructions for how the training data should be formatted for input to the model.

- **Input:** A train and an optional validation directory. Each directory contains a CSV/JSON/TXT file. 
  - For CSV/JSON files, the train or validation data is used from the column called 'text' or the first column if no column called 'text' is found.
  - The number of files under train and validation (if provided) should equal to one, respectively. 
- **Output:** A trained model that can be deployed for inference. 

Below is an example of a TXT file for fine-tuning the Text Generation model. The TXT file is SEC filings of Amazon from year 2021 to 2022.

```Note About Forward-Looking Statements
This report includes estimates, projections, statements relating to our
business plans, objectives, and expected operating results that are “forward-
looking statements” within the meaning of the Private Securities Litigation
Reform Act of 1995, Section 27A of the Securities Act of 1933, and Section 21E
of the Securities Exchange Act of 1934. Forward-looking statements may appear
throughout this report, including the following sections: “Business” (Part I,
Item 1 of this Form 10-K), “Risk Factors” (Part I, Item 1A of this Form 10-K),
and “Management’s Discussion and Analysis of Financial Condition and Results
of Operations” (Part II, Item 7 of this Form 10-K). These forward-looking
statements generally are identified by the words “believe,” “project,”
“expect,” “anticipate,” “estimate,” “intend,” “strategy,” “future,”
“opportunity,” “plan,” “may,” “should,” “will,” “would,” “will be,” “will
continue,” “will likely result,” and similar expressions. Forward-looking
statements are based on current expectations and assumptions that are subject
to risks and uncertainties that may cause actual results to differ materially.
We describe risks and uncertainties that could cause actual results and events
to differ materially in “Risk Factors,” “Management’s Discussion and Analysis
of Financial Condition and Results of Operations,” and “Quantitative and
Qualitative Disclosures about Market Risk” (Part II, Item 7A of this Form
10-K). Readers are cautioned not to place undue reliance on forward-looking
statements, which speak only as of the date they are made. We undertake no
obligation to update or revise publicly any forward-looking statements,
whether because of new information, future events, or otherwise.
GENERAL
Embracing Our Future ...
```


#### 1.3. Instruction fine-tuning
The Text generation model can be instruction-tuned on any text data provided that the data 
is in the expected format. The instruction-tuned model can be further deployed for inference. 
Below are the instructions for how the training data should be formatted for input to the 
model.

Below are the instructions for how the training data should be formatted for input to the model.

- **Input:** A train and an optional validation directory. Train and validation directories should contain one or multiple JSON lines (`.jsonl`) formatted files. In particular, train directory can also contain an optional `*.json` file describing the input and output formats. 
  - The best model is selected according to the validation loss, calculated at the end of each epoch.
  If a validation set is not given, an (adjustable) percentage of the training data is
  automatically split and used for validation.
  - The training data must be formatted in a JSON lines (`.jsonl`) format, where each line is a dictionary
representing a single data sample. All training data must be in a single folder, however
it can be saved in multiple jsonl files. The `.jsonl` file extension is mandatory. The training
folder can also contain a `template.json` file describing the input and output formats. If no
template file is given, the following template will be used:
  ```json
  {
    "prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{context}",
    "completion": "{response}"
  }
  ```
  - In this case, the data in the JSON lines entries must include `instruction`, `context` and `response` fields. If a custom template is provided it must also use `prompt` and `completion` keys to define
  the input and output templates.
  Below is a sample custom template:

  ```json
  {
    "prompt": "question: {question} context: {context}",
    "completion": "{answer}"
  }
  ```
Here, the data in the JSON lines entries must include `question`, `context` and `answer` fields. 
- **Output:** A trained model that can be deployed for inference. 

---

#### 1.3. Example fine-tuning with Domain-Adaptation dataset format
---
We provide a subset of SEC filings data of Amazon in domain adaptation dataset format. It is downloaded from publicly available [EDGAR](https://www.sec.gov/edgar/searchedgar/companysearch). Instruction of accessing the data is shown [here](https://www.sec.gov/os/accessing-edgar-data).

License: [Creative Commons Attribution-ShareAlike License (CC BY-SA 4.0)](https://creativecommons.org/licenses/by-sa/4.0/legalcode).

Please uncomment the following code to fine-tune the model on dataset in domain adaptation format.

---

In [None]:
# import boto3
# model_id = "meta-textgeneration-llama-2-7b"

# estimator = JumpStartEstimator(model_id=model_id,  environment={"accept_eula": "true"},instance_type = "ml.g5.24xlarge")
# estimator.set_hyperparameters(instruction_tuned="False", epoch="5")
# estimator.fit({"training": f"s3://jumpstart-cache-prod-{boto3.Session().region_name}/training-datasets/sec_amazon"})

#### 1.4. Example fine-tuning with Instruction tuning dataset format
---
Next, we fine-tune the LLaMA v2 7B model on the summarization dataset from Dolly dataset.


---

# Dataset format

from datasets import load_dataset

dolly_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

# To train for question answering/information extraction, you can replace the assertion in next line to example["category"] == "closed_qa"/"information_extraction".
summarization_dataset = dolly_dataset.filter(lambda example: example["category"] == "summarization")
summarization_dataset = summarization_dataset.remove_columns("category")

# We split the dataset into two where test data is used to evaluate at the end.
train_and_test_dataset = summarization_dataset.train_test_split(test_size=0.1)

# Dumping the training data to a local file to be used for training.
train_and_test_dataset["train"].to_json("train.jsonl")



import json

template = {
    "prompt": "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{context}\n\n",
    "completion": " {response}",
}
with open("template.json", "w") as f:
    json.dump(template, f)
    

from sagemaker.s3 import S3Uploader
import sagemaker
import random

output_bucket = sagemaker.Session().default_bucket()
local_data_file = "train.jsonl"
train_data_location = f"s3://{output_bucket}/dolly_dataset"
S3Uploader.upload(local_data_file, train_data_location)
S3Uploader.upload("template.json", train_data_location)
print(f"Training data: {train_data_location}")
    

from sagemaker.jumpstart.estimator import JumpStartEstimator


estimator = JumpStartEstimator(
    model_id=model_id,
    environment={"accept_eula": "true"},
    disable_output_compression=True,  # For Llama-2-70b, add instance_type = "ml.g5.48xlarge"
)
# By default, instruction tuning is set to false. Thus, to use instruction tuning dataset you use
estimator.set_hyperparameters(instruction_tuned="True", "chat_dataset"="False", epoch="5", max_input_length="1024")
estimator.fit({"training": train_data_location})

### 2. Supported Hyper-parameters for fine-tuning
---
- epoch: The number of passes that the fine-tuning algorithm takes through the training dataset. Must be an integer greater than 1. Default: 5
- learning_rate: The rate at which the model weights are updated after working through each batch of training examples. Must be a positive float greater than 0. Default: 1e-4.
- instruction_tuned: Whether to instruction-train the model or not. Must be 'True' or 'False'. Default: 'False'
- chat_dataset: If True, dataset is assumed to be in chat format. At most one of instruction_tuned and chat_dataset can be True.
- add_input_output_demarcation_key: For instruction tuned dataset, if this is True a demarcation key(\"### Response:\\n\") is added between the prompt and completion before training. Default: 'True'.
- per_device_train_batch_size: The batch size per GPU core/CPU for training. Must be a positive integer. Default: 4.
- per_device_eval_batch_size: The batch size per GPU core/CPU for evaluation. Must be a positive integer. Default: 1
- max_train_samples: For debugging purposes or quicker training, truncate the number of training examples to this value. Value -1 means using all of training samples. Must be a positive integer or -1. Default: -1. 
- max_val_samples: For debugging purposes or quicker training, truncate the number of validation examples to this value. Value -1 means using all of validation samples. Must be a positive integer or -1. Default: -1. 
- max_input_length: Maximum total input sequence length after tokenization. Sequences longer than this will be truncated. If -1, max_input_length is set to the minimum of 1024 and the maximum model length defined by the tokenizer. If set to a positive value, max_input_length is set to the minimum of the provided value and the model_max_length defined by the tokenizer. Must be a positive integer or -1. Default: -1. 
- validation_split_ratio: If validation channel is none, ratio of train-validation split from the train data. Must be between 0 and 1. Default: 0.2. 
- train_data_split_seed: If validation data is not present, this fixes the random splitting of the input training data to training and validation data used by the algorithm. Must be an integer. Default: 0.
- preprocessing_num_workers: The number of processes to use for the preprocessing. If None, main process is used for preprocessing. Default: "None"
- lora_r: Lora R. Must be a positive integer. Default: 8.
- lora_alpha: Lora Alpha. Must be a positive integer. Default: 32
- lora_dropout: Lora Dropout. must be a positive float between 0 and 1. Default: 0.05. 
- int8_quantization: If True, model is loaded with 8 bit precision for training. Default for 7B/13B: False. Default for 70B: True.
- enable_fsdp: If True, training uses Fully Sharded Data Parallelism. Default for 7B/13B: True. Default for 70B: False.

Note 1: int8_quantization is not supported with FSDP. Also, int8_quantization = 'False' and enable_fsdp = 'False' is not supported due to CUDA memory issues for any of the g5 family instances. Thus, we recommend setting exactly one of int8_quantization or enable_fsdp to be 'True'
Note 2: Due to the size of the model, 70B model can not be fine-tuned with enable_fsdp = 'True' for any of the supported instance types.

---

### 3. Supported Instance types

---
We have tested our scripts on the following instances types:

- 7B, 7B-F: ml.g5.12xlarge, nl.g5.24xlarge, ml.g5.48xlarge, ml.p3dn.24xlarge
- 13B, 13B-F: ml.g5.24xlarge, ml.g5.48xlarge, ml.p3dn.24xlarge
- 70B, 70B-F: ml.g5.48xlarge

Other instance types may also work to fine-tune. Note: When using p3 instances, training will be done with 32 bit precision as bfloat16 is not supported on these instances. Thus, training job would consume double the amount of CUDA memory when training on p3 instances compared to g5 instances.

---

### 4. Few notes about the fine-tuning method

---
- Fine-tuning scripts are based on [this repo](https://github.com/facebookresearch/llama-recipes/tree/main). 
- Instruction tuning dataset is first converted into domain adaptation dataset format before fine-tuning. 
- Fine-tuning scripts utilize Fully Sharded Data Parallel (FSDP) as well as Low Rank Adaptation (LoRA) method fine-tuning the models

---

### 5. Studio Kernel Dead/Creating JumpStart Model from the training Job
---
Due to the size of the Llama 70B model, training job may take several hours and the studio kernel may die during the training phase. However, during this time, training is still running in SageMaker. If this happens, you can still deploy the endpoint using the training job name with the following code:

How to find the training job name? Go to Console -> SageMaker -> Training -> Training Jobs -> Identify the training job name and substitute in the following cell. 

---

In [None]:
# from sagemaker.jumpstart.estimator import JumpStartEstimator
# training_job_name = <<training_job_name>>

# attached_estimator = JumpStartEstimator.attach(training_job_name, model_id)
# attached_estimator.logs()
# attached_estimator.deploy()

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)