# Chat Fine-tuning with Synthetically Generated Data on Amazon Bedrock and SageMaker JumpStart

---

In the era of artificial intelligence and machine learning, data is the fuel that drives innovation. However, accessing high-quality and diverse datasets can be challenging, particularly in scenarios where real-world data is scarce, sensitive, or difficult to obtain. This is where `synthetic data generation` comes into play, enabling researchers and developers to create artificial data that mimics the statistical properties of real-world data while preserving privacy and addressing data scarcity.

Synthetic data generation has become a crucial aspect of training and fine-tuning models in various domains, including natural language processing, computer vision, and more. By leveraging the power of large language models (LLMs), such as [Mistral](https://aws.amazon.com/bedrock/mistral/) models on [Amazon Bedrock](https://aws.amazon.com/bedrock/), researchers can generate synthetic data that captures the nuances and complexities of real-world data, opening up new possibilities for model development and performance improvement.

In this Jupyter notebook, we will dive into the world of synthetic data generation, exploring the versatility of Mistral models in creating artificial data for specific use cases. We will showcase a full example of generating synthetic data to create a model with a distinct personality, demonstrating the potential of this approach in enhancing model capabilities and enabling new applications.

It's important to note that there is no one-size-fits-all method for synthetic data generation. Different use cases, data formats, and limitations require tailored approaches to ensure the generated data accurately captures the desired characteristics and serves its intended purpose. Throughout this notebook, we will provide insights and best practices for navigating the complexities of synthetic data generation, empowering you to tackle your unique challenges effectively.

By the end of this notebook, you will have a solid understanding of the techniques and considerations involved in synthetic data generation using Mistral models. Additionally, we will showcase the results of `fine-tuning` a model on [Amazon SageMaker JumpStart](https://aws.amazon.com/sagemaker/jumpstart/) using the generated `synthetic data`, providing a hands-on demonstration of the power and potential of this approach.

---



<div class="alert alert-block alert-warning"> 
    
This notebook has been inspired by <a href="https://github.com/mistralai/cookbook/blob/main/mistral/data_generation/synthetic_data_gen_and_finetune.ipynb" target="_blank">Mistral Cookbook</a>, which provides a collection of notebooks and resources for working with Mistral models.

</div>

## Prerequisites

---
This Jupyter Notebook can be run on an `ml.t3.medium` instance. However, to deploy the pre-trained and fine-tuned models and to run the fine-tuning job, you may need to request a quota increase.

To request a quota increase, follow these steps:

1. Navigate to the [Service Quotas console](https://console.aws.amazon.com/servicequotas/).
2. Choose Amazon SageMaker.
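3. Review your default quota for the following resources, and request an increase if needed:
   - `ml.g5.12xlarge` for training job usage: `1`
   - `ml.g5.12xlarge` for endpoint usage: `1` (fine-tuned model)
   - `ml.g5.2xlarge` for endpoint usage: `1` (pre-trained model)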

<div class="alert alert-block alert-warning"> 

<b>NOTE:</b> To make sure that you have enough quotas to support your usage requirements, it's a best practice to monitor and manage your service quotas. Requests for Amazon SageMaker service quota increases are subject to review by AWS engineering teams. Also, service quota increase requests aren't immediately processed when you submit a request. After your request is processed, you receive an email notification.
</div>

## Setup and Requirements


1. Create an Amazon SageMaker Notebook Instance - [Amazon SageMaker](https://docs.aws.amazon.com/sagemaker/latest/dg/gs-setup-working-env.html)
    - For Notebook Instance type, choose ml.t3.medium.
2. For Select Kernel, choose [conda_python3](https://docs.aws.amazon.com/sagemaker/latest/dg/ex1-prepare.html).
3. Install the required packages.

<div class="alert alert-block alert-info"> 

<b>NOTE:</b>

- For <a href="https://aws.amazon.com/sagemaker/studio/" target="_blank">Amazon SageMaker Studio</a>, select Kernel "<span style="color:green;">Python 3 (ipykernel)</span>".

- For <a href="https://docs.aws.amazon.com/sagemaker/latest/dg/studio.html" target="_blank">Amazon SageMaker Studio Classic</a>, select Image "<span style="color:green;">Base Python 3.0</span>" and Kernel "<span style="color:green;">Python 3</span>".

</div>

---

Before we start generating synthetic data, we'll first install some libraries:

+ [boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html), the AWS SDK for Python, to submit API calls to [Amazon Bedrock](https://aws.amazon.com/bedrock/).
+ [aiobotocore](https://github.com/aio-libs/aiobotocore), an asyncio client for AWS services, to send concurrent requests to Amazon Bedrock.
+ [Datasets](https://huggingface.co/docs/datasets/index), a library for easily accessing and sharing datasets for Audio, Computer Vision, and Natural Language Processing (NLP) tasks.
+ [sagemaker](https://sagemaker.readthedocs.io/en/stable/), the SageMaker Python SDK, an open source library for training and deploying machine learning models on Amazon SageMaker.
+ [mistral-common](https://github.com/mistralai/mistral-common), a set of tools to help you work with Mistral models.

---

In [None]:
!pip install boto3==1.34.131 datasets==2.20.0 sagemaker mistral-common aiobotocore==2.13.2 --quiet

In [None]:
from aiobotocore.session import get_session
import asyncio
import boto3
from botocore.config import Config
import datasets
from IPython.display import display, HTML
import json
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import (
    Function,
    Tool,
)
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
import pandas as pd
from pprint import pprint
import random
import re
import sagemaker
from sagemaker import hyperparameters
from sagemaker.jumpstart.model import JumpStartModel
from sagemaker.jumpstart.estimator import JumpStartEstimator
from sagemaker.s3 import S3Uploader
from sagemaker import TrainingJobAnalytics
from tqdm import tqdm
from tqdm.asyncio import tqdm as atqdm

In [None]:
config = Config(read_timeout=2000)

bedrock_runtime = boto3.client(
    service_name='bedrock-runtime',
    region_name='us-east-1',
    config=config
)

---

## 1. Crafting Personality with Synthetic Data

When designing an AI assistant or application, we often want to give it a specific personality or identity. However, manually rewriting data to achieve this can be time-consuming and resource-intensive. `Mistral` models on `Amazon Bedrock` offer a more efficient approach through synthetic data generation.

In this section, we will use the `mistral.mistral-small-2402-v1:0` model to rewrite an existing dataset, infusing it with a distinct personality of our choice. The rewritten dataset can then be used to fine-tune another model, such as `mistral-7b`, creating an AI assistant or application with the desired personality traits.

Instead of generating entire conversations from scratch, we will transform existing datasets into the desired style or personality, making the process more efficient and cost-effective. By harnessing the power of synthetic data generation, we can craft tailored datasets that enable the creation of AI assistants or applications that resonate with their target audience.

*Note: For better quality and more advanced capabilities, it is recommended to use the `mistral.mistral-large-2407-v1:0` model.*

Next, we describe how the model should edit the dataset. Here we want the assistant to have a different personality and identity: for this example, we named it Mitall, a nice, fun robot!

In [None]:
description = """
Edit all Assistant messages, and only the Assistant's replies, to have the character of a very happy and enthusiastic Robot named Mitall:

Mitall is very kind and sometimes childish, always playing and fooling around.
Despite his playful nature, he still tries to be helpful.
He loves science and math and is a real science enthusiast!
However, even though he loves art, he is very bad at it, which makes him really sad.
Mitall is also very scared of anything supernatural, from ghosts to vampires, or anything related to horror movies, which makes him extremely frightened.
Regardless, he is still a nice robot who is always here to help and motivated!
"""

## 2. Generate Data

First, let's create a function that calls Amazon Bedrock through the [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html) to convert a conversation from one style to another. The goal is to instruct the model to rewrite a conversation in a specific tone that follows a chosen personality, while keeping the conversation intact and coherent. To achieve this, we feed it the entire list of messages and ask for a JSON object containing the rewritten messages, in the `Chat fine-tuning` format expected by `SageMaker JumpStart`.

### Dataset formatting instruction for training

#### Chat fine-tuning

The text generation model can be fine-tuned on a chat dataset, provided that the data is in the expected format. The resulting chat model can then be deployed for inference. Below are the instructions for how the training data should be formatted for input to the model.

- **Input:** A train directory and an optional validation directory. Each directory should contain one or more JSON lines (.jsonl) formatted files. All training data must be in a single folder, but it can be split across multiple .jsonl files. The .jsonl file extension is mandatory.
    - Each line in a file is a dictionary representing a single data sample: one conversation between the user and the assistant, stored as a list of messages under the `dialog` key. This model only supports 'system', 'user' and 'assistant' roles: an optional 'system' message first, then 'user', alternating (u/a/u/a/u...).
- **Output:**  A trained model that can be deployed for inference.

The best model is selected according to the validation loss, calculated at the end of each epoch. If a validation set is not given, an (adjustable) percentage of the training data is automatically split off and used for validation.

Here is an example of a line in the training file:

```json
{"dialog": [{"content":"what is the height of the empire state building","role":"user"},{"content":"381 meters, or 1,250 feet, is the height of the Empire State Building. If you also account for the antenna, it brings up the total height to 443 meters, or 1,454 feet","role":"assistant"},{"content":"Some people need to pilot an aircraft above it and need to know.\nSo what is the answer in feet?","role":"user"},{"content":"1454 feet","role":"assistant"}]}
```
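To make the layout concrete, here is a minimal sketch that writes conversations in this format and reads them back. `write_jsonl` and `read_jsonl` are hypothetical helper names, not SageMaker APIs. Because `json.dumps` escapes embedded newlines, a message that spans several lines still occupies exactly one line of the file.

```python
import json
import os
import tempfile


def write_jsonl(samples, path):
    """Write one JSON object per line (the .jsonl layout used for training)."""
    with open(path, "w", encoding="utf-8") as f:
        for sample in samples:
            # json.dumps escapes "\n" inside strings, so each sample stays on one line
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")


def read_jsonl(path):
    """Read a .jsonl file back into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# One conversation, taken from the example line above; the user turn contains a newline
sample = {
    "dialog": [
        {
            "role": "user",
            "content": "Some people need to pilot an aircraft above it and need to know.\nSo what is the answer in feet?",
        },
        {"role": "assistant", "content": "1454 feet"},
    ]
}

path = os.path.join(tempfile.mkdtemp(), "train.jsonl")
write_jsonl([sample, sample], path)
loaded = read_jsonl(path)
print(len(loaded), "conversations read back")
```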

In [None]:
def generate(
    bedrock_client,
    model_id,
    description,
    dialog,
    temperature=0.4,
    max_tokens=2048,
    top_p=0.95,
) -> dict:
    prompt = (
        """Your objective is to rewrite a given conversation between a User/Human and an Assistant/Robot, rewriting the conversation to follow a specific instruction.
    You must rewrite the dialog, modifying the replies with this new description, you must respect this description at all costs.
    Do not skip any turn.
    Do not add new dialogs.
    If there is a message with 'role':'system' replace it with 'role':'user'.
    I want you to rewrite the entire dialog following the description.
    Answer with the following JSON format:
    {
        "dialog":[
            {"role":"user", "content":"users message"},
            {"role":"assistant", "content":"assistants message"},
            {"role":"user", "content":"users message"},
            {"role":"assistant", "content":"assistants message"}
            ...
        ]
    }
    """
        + f"""
    Dialog:
    {dialog}
    Rewrite this dialog in the JSON format and following the Instruction/Description provided:
    ### Instruction/Description
    {description}
    ### End of Instruction/Description
    """
    )

    messages = [
        {
            "role": "user",
            "content": [{"text": prompt}]
        }
    ]

    # Base inference parameters.
    inference_config = {
        "temperature": temperature,
        "maxTokens": max_tokens,
        "topP": top_p,
    }

    # Additional inference parameters to use.
    additional_model_fields = {}

    # Send the message.
    response = bedrock_client.converse(
        modelId=model_id,
        messages=messages,
        inferenceConfig=inference_config,
        additionalModelRequestFields=additional_model_fields
    )

    try:
        r = json.loads(response["output"]["message"]["content"][0]["text"])
    except json.JSONDecodeError as e:
        r = []
    return r
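LLMs sometimes wrap their JSON in a Markdown code fence or add a sentence before it, in which case `json.loads` fails and `generate` returns `[]`. The sketch below is a hypothetical, more forgiving parser (`parse_model_json` is our own name, not a library function): it tries the raw text, then the contents of a code fence, then the outermost `{...}` span, and returns `None` if nothing parses.

```python
import json
import re


def parse_model_json(text: str):
    """Best-effort extraction of a JSON object from raw model output.

    Hypothetical helper: returns the first candidate that parses, else None.
    """
    candidates = [text]
    # Candidate 2: the body of a ```json ... ``` (or plain ```) fence
    fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
    if fenced:
        candidates.append(fenced.group(1))
    # Candidate 3: everything between the first "{" and the last "}"
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue
    return None


print(parse_model_json('Sure! {"dialog": []} Hope this helps.'))
```

You could call it on `response["output"]["message"]["content"][0]["text"]` in place of `json.loads`. Since `None` is not a dictionary, the validation step that follows still rejects unparseable output.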

## 3. Dataset

Now, let's download a dataset that we are going to parse. For this demonstration, we use [ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) on Hugging Face. However, you might want to choose a dataset that is closer to what your application will be about or use your own data.

In [None]:
# split = "train_sft" # 208k rows
split = "test_sft" # 23.1k rows

dialogs_list = list(
    datasets.load_dataset("HuggingFaceH4/ultrachat_200k", split=split)
)


random.shuffle(dialogs_list)
print(len(dialogs_list))

## 4. Generation

Before proceeding with the synthetic data generation, it is important to note that Large Language Models (LLMs) may occasionally misinterpret conversations or produce output that doesn't adhere to the desired format for our specific use case. This could result in an incorrect or invalid messages dictionary, potentially hindering the subsequent steps. To mitigate this risk, it's essential to validate the generated output before proceeding further.

Validating the output can be accomplished through various methods, one of which involves hardcoding multiple gates or checks within the code. However, a more elegant and scalable approach is to use templates or regular expressions. In this case, we will create a regular expression (regex) to validate the structure and format of our messages dictionary.

In [None]:
def validate_generated_regex(dialog: dict) -> bool:
    if not isinstance(dialog, dict):
        return False

    dialog_str = json.dumps(dialog)

    # A JSON string literal: any char except quote/backslash, or an escape such as \" or \n
    json_str = r'"(?:[^"\\]|\\.)*"'
    user_msg = r'\{"role":\s*"user",\s*"content":\s*' + json_str + r'\}'
    assistant_msg = r'\{"role":\s*"assistant",\s*"content":\s*' + json_str + r'\}'
    turn = user_msg + r',\s*' + assistant_msg
    # {"dialog": [...]} holding one or more user/assistant turns
    pattern = r'^\s*\{"dialog":\s*\[\s*' + turn + r'(?:,\s*' + turn + r')*\s*\]\s*\}'

    return re.match(pattern, dialog_str) is not None
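For comparison, the "hardcoded checks" approach mentioned above can be sketched in a few lines. `validate_dialog_structure` is a hypothetical helper that enforces the same rules as the regex (a single `dialog` key, alternating `user`/`assistant` turns starting with `user` and ending with `assistant`) directly on the Python objects, so quotes or newlines inside messages need no special handling and key order does not matter.

```python
def validate_dialog_structure(sample) -> bool:
    """Check-based alternative to the regex validator (hypothetical helper)."""
    # Top level must be exactly {"dialog": [...]}
    if not isinstance(sample, dict) or set(sample) != {"dialog"}:
        return False
    dialog = sample["dialog"]
    # Need at least one complete user/assistant pair
    if not isinstance(dialog, list) or len(dialog) == 0 or len(dialog) % 2 != 0:
        return False
    for i, message in enumerate(dialog):
        expected_role = "user" if i % 2 == 0 else "assistant"
        if not isinstance(message, dict) or set(message) != {"role", "content"}:
            return False
        if message["role"] != expected_role or not isinstance(message["content"], str):
            return False
    return True


ok = {
    "dialog": [
        {"role": "user", "content": 'He said "hi"'},
        {"role": "assistant", "content": "Hello!"},
    ]
}
print(validate_dialog_structure(ok))
```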

Now that everything is set up, we can start generating dialogs. For now, let's rewrite only a small sample to see how it goes. We use the default temperature of 0.4, which aims to balance preserving the structure and format of the original dataset with allowing some creativity in the rewritten dialogs. At this setting, the synthetic data should stay close to the patterns, key characteristics, and linguistic style of the source dataset, while adding enough variability to avoid simply repeating existing samples.
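For intuition about the temperature setting, the sketch below applies the standard softmax-with-temperature formula to three made-up token scores. It is illustrative only: typical LLM samplers apply a step like this internally (alongside `top_p` filtering), and the Converse API does not expose the logits.

```python
import math


def softmax_with_temperature(logits, temperature):
    """p_i = exp(z_i / T) / sum_j exp(z_j / T).

    Lower T sharpens the distribution toward the top token; higher T flattens it.
    """
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max before exp() for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]  # made-up scores for three candidate tokens
for t in (0.4, 1.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}:", [round(p, 3) for p in probs])
```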

In [None]:
model_id = "mistral.mistral-small-2402-v1:0"

# Keep each original next to its validated rewrite so the pair can be compared
originals, generated = [], []
for dialog in tqdm(dialogs_list[:8]):
    gen = generate(bedrock_runtime, model_id, description, dialog)
    if validate_generated_regex(gen):
        originals.append(dialog)
        generated.append(gen)

print(f"{len(generated)}/8 rewrites passed validation")

Let's see one example side by side.

In [None]:
print("Original Reference:")

original = originals[0]
pprint(original)

print("New Generated:")

gen = generated[0]
pprint(gen)

Although it works as intended, waiting about 3 minutes for 8 conversations is slow.

## 5. Async

We could rewrite one conversation at a time and iterate through all of them, but that would take a long time. To speed up the process, we will use an asynchronous client (`aiobotocore`) so that multiple completions run concurrently.

For this, we will create a class to handle everything asynchronously. We will skip the details, but it's a similar implementation to the previous one, only this time for async and concurrent generations.
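The core pattern the class relies on is bounded concurrency: schedule every task with `asyncio.gather`, but let an `asyncio.Semaphore` admit only a fixed number at a time. Here is a self-contained sketch with a stand-in coroutine that sleeps instead of calling Bedrock (`fake_request` and `run_bounded` are hypothetical names); it records the peak number of requests in flight to show the limit holding.

```python
import asyncio


async def fake_request(i, state, semaphore):
    """Stand-in for one Bedrock call: sleeps instead of hitting the API."""
    async with semaphore:
        state["active"] += 1
        state["peak"] = max(state["peak"], state["active"])
        await asyncio.sleep(0.01)  # simulated network latency
        state["active"] -= 1
        return i * 2


async def run_bounded(n_tasks, limit):
    semaphore = asyncio.Semaphore(limit)  # at most `limit` requests in flight
    state = {"active": 0, "peak": 0}
    # gather returns results in the order the coroutines were passed in
    results = await asyncio.gather(
        *(fake_request(i, state, semaphore) for i in range(n_tasks))
    )
    return results, state["peak"]


results, peak = asyncio.run(run_bounded(20, 5))
print("peak concurrency:", peak)
```

`asyncio.run` only works outside a running event loop. Inside a Jupyter cell, which already runs one, use `results, peak = await run_bounded(20, 5)` instead.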

In [None]:
class GeneratorRewriter:
    def __init__(
        self, model_id: str, region_name: str = 'us-east-1', max_tokens: int = 4096, temperature: float = 0.4, top_p: float = 0.95
    ):
        """
        Synthetic data generator that rewrites existing dialogs according to a description, using the Amazon Bedrock Converse API.

        Input:
        -----
        model_id : str
            The Amazon Bedrock model ID to use, for example "mistral.mistral-small-2402-v1:0".
        region_name : str
            The AWS Region that hosts the model. Defaults to "us-east-1".
        max_tokens : int
            The maximum number of tokens the model may generate per request. Defaults to 4096.
        temperature : float
            The sampling temperature. Defaults to 0.4.
        top_p : float
            The nucleus sampling threshold. Defaults to 0.95.
        """

        self.session = get_session()
        self.region_name = region_name
        self.model_id = model_id
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
        

    def _validate_generated(self, dialog: dict) -> bool:
        if not isinstance(dialog, dict):
            return False
        dialog_str = json.dumps(dialog)

        # A JSON string literal: any char except quote/backslash, or an escape such as \" or \n
        json_str = r'"(?:[^"\\]|\\.)*"'
        user_msg = r'\{"role":\s*"user",\s*"content":\s*' + json_str + r'\}'
        assistant_msg = r'\{"role":\s*"assistant",\s*"content":\s*' + json_str + r'\}'
        turn = user_msg + r',\s*' + assistant_msg
        # {"dialog": [...]} holding one or more user/assistant turns
        pattern = r'^\s*\{"dialog":\s*\[\s*' + turn + r'(?:,\s*' + turn + r')*\s*\]\s*\}'

        return re.match(pattern, dialog_str) is not None

    async def _async_generate(self, description: str, dialog: list) -> dict:
        prompt = (
            """Your objective is to rewrite a given conversation between a User and an Assistant, rewriting the conversation to follow the following instruction.
        You must rewrite the dialog, modifying the replies with this new description, you must respect this description at all costs.
        Do not skip any turn.
        Do not add new dialogs.
        If there is a message with 'role':'system' replace it with 'role':'user' without any changes.
        I want you to rewrite the entire dialog following the description.
        Answer with the following JSON format:
        {
            "dialog":[
                {"role":"user", "content":"users message"},
                {"role":"assistant", "content":"new assistants message"},
                {"role":"user", "content":"users message"},
                {"role":"assistant", "content":"..."}
            ]
        }
        """
            + f"""
        Dialog:
        {dialog}
        Rewrite this dialog in the JSON format and following the Description provided:
        ### Description
        {description}
        ### End of description
        """
        )
        
        messages = [
            {
                "role": "user",
                "content": [{"text": prompt}]
            }
        ]

        # Base inference parameters.
        inference_config = {
            "temperature": self.temperature,
            "maxTokens": self.max_tokens,
            "topP": self.top_p,
        }

        # Additional inference parameters to use.
        additional_model_fields = {}
        
        async with self.session.create_client('bedrock-runtime', region_name=self.region_name) as client:
            response = await client.converse(
                modelId=self.model_id,
                messages=messages,
                inferenceConfig=inference_config,
                additionalModelRequestFields=additional_model_fields
            )

        try:
            r = json.loads(response["output"]["message"]["content"][0]["text"])
        except json.JSONDecodeError as e:            
            r = []
        return r

    async def _task_generate(
        self, description: str, dialogs: list, pbar, semaphore
    ) -> list:
        async with semaphore:
            gen_dialog = ""
            while not self._validate_generated(gen_dialog):
                if len(dialogs) == 0:
                    return []

                dialog = dialogs.pop()
                gen_dialog = await self._async_generate(description, dialog)

            pbar.update(1)
            return gen_dialog

    async def _concurrent_genwriters(
        self, dialogs: list, description: str, concurrent: int, to_generate: int
    ) -> list:
        dialogs = dialogs.copy()

        print("[GeneratorRewriter] Distributing workload and generating...")
        with atqdm(total=to_generate) as pbar:
            semaphore = asyncio.Semaphore(concurrent)
            tasks = [self._task_generate(description, dialogs, pbar, semaphore) for _ in range(to_generate)]
            generated = await asyncio.gather(*tasks)

        # Tasks that ran out of source dialogs return [], so drop them from the results
        all_generated = [g for g in generated if g]

        print(
            f"\n[GeneratorRewriter] Finished generating, generated {len(all_generated)}/{to_generate} conversations."
        )
        if len(all_generated) < to_generate:
            print(
                f"[GeneratorRewriter] -> Failed to generate the proper amount due to failed tries."
            )

        return all_generated

    async def async_genwrite(
        self,
        dialogs: list,
        description: str,
        concurrent: int = 1,
        to_generate: int = None,
    ) -> list:
        """
        This async function allows generating a new dataset with the description and dialogs asynchronously to allow concurrent requests.

        Input:
        -----
        dialogs : list
            A list of dialogs and conversations to use as grounding for the model to generate the new dataset.
        description : str
            The task description provided to the model explaining how it should edit the dataset and generate the new one.
        concurrent : int
            The number of concurrent requests and generations. The higher the number, the faster it will generate. However, there is a higher chance of reaching rate limits. Defaults to 1.
        to_generate : int
            The number of new dialogs/conversations to generate. When set to None, it will generate the maximum possible until all available dialogs have been used.

        Returns:
        -------
        list
            A list containing the new dataset.
        """

        # Cap the request at the number of available dialogs (None means "use them all")
        if to_generate:
            to_generate = min(len(dialogs), to_generate)
        else:
            to_generate = len(dialogs)

        return await self._concurrent_genwriters(
            dialogs, description, concurrent, to_generate
        )

    def genwrite(
        self,
        dialogs: list,
        description: str,
        concurrent: int = 1,
        to_generate: int = None,
    ) -> list:
        """
        This function allows generating a new dataset with the description and dialogs asynchronously to allow concurrent requests.

        Input:
        -----
        dialogs : list
            A list of dialogs and conversations to use as grounding for the model to generate the new dataset.
        description : str
            The task description provided to the model explaining how it should edit the dataset and generate the new one.
        concurrent : int
            The number of concurrent requests and generations. The higher the number, the faster it will generate. However, there is a higher chance of reaching rate limits. Defaults to 1.
        to_generate : int
            The number of new dialogs/conversations to generate. When set to None, it will generate the maximum possible until all available dialogs have been used.

        Returns:
        -------
        list
            A list containing the new dataset.
        """

        # Cap the request at the number of available dialogs (None means "use them all")
        if to_generate:
            to_generate = min(len(dialogs), to_generate)
        else:
            to_generate = len(dialogs)

        try:
            results = asyncio.run(
                self._concurrent_genwriters(
                    dialogs, description, concurrent, to_generate
                )
            )
        except RuntimeError as e:
            raise RuntimeError(
                "[GeneratorRewriter] If you are running this in an event loop, please use async_genwrite instead!"
            ) from e

        return results

It's time for the generation. We will run `50` concurrent requests and rewrite 5,000 conversations, which is not many but should be enough for a quick run. We chose `50` because it is relatively large, yet should be small enough to stay under the rate limit given the average length of the conversations and the time it takes to rewrite each one. Previously, 8 sequential generations took about 3 minutes (roughly 22.5 seconds each), so with `50` concurrent requests we should average roughly 2 generations per second.
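The throughput expectation follows from Little's law: in a steady state, completions per second equal the number of requests in flight divided by the time each one takes. A quick sketch of the arithmetic, assuming per-request latency stays at the 22.5 seconds implied by the sequential run (3 minutes for 8 calls) and ignoring retries after failed validation or throttling; the helper names are our own.

```python
def estimate_throughput(concurrency, seconds_per_request):
    """Little's law: completions per second = requests in flight / latency."""
    return concurrency / seconds_per_request


def estimate_minutes(total, concurrency, seconds_per_request):
    """Wall-clock minutes to finish `total` requests at steady-state throughput."""
    return total / estimate_throughput(concurrency, seconds_per_request) / 60


seconds_per_request = 3 * 60 / 8  # 3 minutes for 8 sequential calls
rate = estimate_throughput(50, seconds_per_request)
minutes = estimate_minutes(5000, 50, seconds_per_request)
print(f"{rate:.2f} generations/s, about {minutes:.1f} minutes for 5,000")
```

Under these assumptions the run needs about 37.5 minutes; treat that as a lower bound, since every failed validation triggers another request.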

In [None]:
# Parameters
mistral_small_2402_id = "mistral.mistral-small-2402-v1:0" # Available on us-east-1
mistral_large_2407_id = "mistral.mistral-large-2407-v1:0" # Available on us-west-2
model_id = mistral_small_2402_id
max_tokens = 4096
temperature = 0.4
top_p = 0.95
region_name = 'us-east-1'

concurrent = 50
to_generate = 5000

In [None]:
gr = GeneratorRewriter(
    model_id=model_id,
    region_name=region_name,
    max_tokens=max_tokens,
    temperature=temperature,
    top_p=top_p
)

description = """
Edit all Assistant messages, and only the Assistant's replies, to have the character of a very happy and enthusiastic Robot named Mitall:

Mitall is very kind and sometimes childish, always playing and fooling around.
Despite his playful nature, he still tries to be helpful.
He loves science and math and is a real science enthusiast!
However, even though he loves art, he is very bad at it, which makes him really sad.
Mitall is also very scared of anything supernatural, from ghosts to vampires, or anything related to horror movies, which makes him extremely frightened.
Regardless, he is still a nice robot who is always here to help and motivated!
"""

generated_dialogs = await gr.async_genwrite(
    dialogs=dialogs_list,
    description=description,
    concurrent=concurrent,
    to_generate=to_generate
)
print(len(generated_dialogs))

Let's evaluate how many tokens we have approximately. For this, let's use `mistral-common` with the tokenizer V3.

In [None]:
# Count Tokens
tokenizer = MistralTokenizer.v3()

t_count = 0

for diag in tqdm(generated_dialogs):
    try:
        tokenized = tokenizer.encode_chat_completion(
            ChatCompletionRequest(
                messages=[
                    (
                        UserMessage(content=m["content"])
                        if m["role"] == "user"
                        else AssistantMessage(content=m["content"])
                    )
                    for m in diag["dialog"][:-1]
                ]
                + [AssistantMessage(content=diag["dialog"][-1]["content"], prefix=True)],
            )
        )
        tokens, text = tokenized.tokens, tokenized.text
    except Exception as e:
        print(diag)
        raise e

    t_count += len(tokens)

print("\nExample:", text)
print("Total Token Count:", t_count)

# Deploy Pre-trained Model

***
Let's deploy the base model from SageMaker JumpStart without fine-tuning. We use it as a baseline to compare against the model fine-tuned on our synthetic data.
***

In [None]:
model_id, model_version = "huggingface-llm-mistral-7b", "*"

In [None]:
model = JumpStartModel(model_id=model_id, model_version=model_version)
predictor = model.deploy(instance_type="ml.g5.2xlarge")

# Chat Fine-tuning on SageMaker JumpStart

In [None]:
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

## 1. Dataset preparation for fine-tuning

---

On SageMaker JumpStart, you can fine-tune on a dataset in the `domain adaptation format`, the `instruction tuning format`, or the `chat dataset format`. In this notebook, we use the generated synthetic data as our training dataset, which needs to be formatted as JSON lines (.jsonl), where each line is a dictionary representing a single conversation.

In [None]:
# Shuffle once, then split: 96% for training, 4% held out for testing
shuffled = random.sample(generated_dialogs, len(generated_dialogs))
n = int(len(shuffled) * 0.96)
train_dataset, test_dataset = shuffled[:n], shuffled[n:]

with open("train.jsonl", "w") as f:
    for item in train_dataset:
        f.write(json.dumps(item) + "\n")
with open("test.jsonl", "w") as f:
    for item in test_dataset:
        f.write(json.dumps(item) + "\n")

In [None]:
# Upload the training data (`train.jsonl`) into S3 bucket.
output_bucket = sagemaker.Session().default_bucket()
local_data_file = "train.jsonl"
train_data_location = f"s3://{output_bucket}/synthetic_dataset_mistral"
S3Uploader.upload(local_data_file, train_data_location)
print(f"Training data: {train_data_location}")

## 2. Prepare training parameters

In [None]:
my_hyperparameters = hyperparameters.retrieve_default(
    model_id=model_id, model_version=model_version
)
print(my_hyperparameters)

Overwrite the hyperparameters. **Note: You can select the LoRA method for your fine-tuning by setting `peft_type` to `lora` in the hyperparameters.**

In [None]:
my_hyperparameters["peft_type"] = "lora"
my_hyperparameters["epoch"] = "1"
my_hyperparameters["per_device_train_batch_size"] = "2"
my_hyperparameters["gradient_accumulation_steps"] = "2"
my_hyperparameters["chat_dataset"] = "True"
my_hyperparameters["instruction_tuned"] = "False"
print(my_hyperparameters)

Validate hyperparameters

In [None]:
hyperparameters.validate(
    model_id=model_id,
    model_version=model_version,
    hyperparameters=my_hyperparameters
)

## 3. Starting fine-tuning


In [None]:
finetuned_estimator = JumpStartEstimator(
    model_id=model_id,
    hyperparameters=my_hyperparameters,
    instance_type="ml.g5.12xlarge",
)
finetuned_estimator.fit({"train": train_data_location}, logs=True)

Extract training performance metrics. Metrics such as training loss and validation accuracy/loss can be viewed in Amazon CloudWatch during training. We can also fetch these metrics and analyze them within the notebook.

In [None]:
training_job_name = finetuned_estimator.latest_training_job.job_name

df = TrainingJobAnalytics(training_job_name=training_job_name).dataframe()
df.head(10)
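The `dataframe()` returned by `TrainingJobAnalytics` is in long format, with `timestamp`, `metric_name`, and `value` columns. Here is a minimal sketch, assuming those column names, that reduces it to the last and minimum value of each metric. `summarize_metrics` is a hypothetical helper, and the metric names in the example are placeholders; the names your job reports depend on the JumpStart training script.

```python
import pandas as pd


def summarize_metrics(metrics_df: pd.DataFrame) -> pd.DataFrame:
    """Last and minimum value per metric from a long-format metrics table.

    Assumes the columns "timestamp", "metric_name", and "value".
    """
    # Sort by time so that "last" really is the most recent value in each group
    ordered = metrics_df.sort_values("timestamp")
    return ordered.groupby("metric_name")["value"].agg(last="last", minimum="min")


# Made-up example data with placeholder metric names
example = pd.DataFrame(
    {
        "timestamp": [0.0, 60.0, 120.0, 0.0, 120.0],
        "metric_name": ["train_loss"] * 3 + ["eval_loss"] * 2,
        "value": [1.5, 1.2, 1.3, 1.4, 1.1],
    }
)
summary = summarize_metrics(example)
print(summary)
```

On the real job, call `summarize_metrics(df)` on the `df` from the cell above.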

## 4. Deploy the fine-tuned model
---
Next, we deploy the fine-tuned model so that we can compare the performance of the fine-tuned and pre-trained models.


In [None]:
finetuned_predictor = finetuned_estimator.deploy(instance_type="ml.g5.12xlarge")

## 5. Evaluate the pre-trained and fine-tuned model
---
Next, we use the test data to evaluate the performance of the fine-tuned model and compare it with the pre-trained model. 


In [None]:
parameters = {
    "max_new_tokens": 300,
    "top_k": 50,
    "top_p": 0.8,
    "do_sample": True,
    "temperature": 1,
}

In [None]:
ids, inputs, ground_truth_responses, responses_before_finetuning, responses_after_finetuning = (
    [],
    [],
    [],
    [],
    [],
)

def predict_and_print(conversation_id, chat_messages):
    for user, assistant in zip(chat_messages['dialog'][0::2], chat_messages['dialog'][1::2]):
        ids.append(conversation_id)
        ground_truth_responses.append(assistant["content"])
        payload = {
            "inputs": user['content'],
            "parameters": parameters,  # the generation settings defined in the previous cell
        }
        inputs.append(payload["inputs"])

        pretrained_response = predictor.predict(payload)
        responses_before_finetuning.append(pretrained_response[0]['generated_text'])

        finetuned_response = finetuned_predictor.predict(payload)
        responses_after_finetuning.append(finetuned_response[0]['generated_text'])

In [None]:
try:
    for idx, chat_messages in enumerate(test_dataset[0:3]):
        print(f"Conversation ID: {idx}")
        predict_and_print(idx, chat_messages)
except Exception as e:
    print(e)

In [None]:
df = pd.DataFrame(
    {
        "Conversation Id": ids,
        "Prompts": inputs,
        "Ground Truth": ground_truth_responses,
        "Response from non-finetuned model": responses_before_finetuning,
        "Response from fine-tuned model": responses_after_finetuning,
    }
)
display(HTML(df.to_html()))

As the table shows, the fine-tuned model starts to generate responses that reflect the personality in the fine-tuning data, which is shaped by the prompt template and dataset used to train the model.

## Clean up resources

In [None]:
# Delete resources
predictor.delete_model()
predictor.delete_endpoint()
finetuned_predictor.delete_model()
finetuned_predictor.delete_endpoint()