# SageMaker JumpStart Foundation Models - Fine-tuning the Llama-3 8B text generation model on a domain-specific dataset

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

---

In this demo notebook, we demonstrate how to use the SageMaker Python SDK to fine-tune the Llama-3 8B foundation model and deploy the trained model for inference. The model performs text generation: it takes a text string as input and predicts the next words in the sequence.

Additionally, this notebook demonstrates how you can use the [Amazon SageMaker JumpStart Industry SDK](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart-industry.html) to prepare training data from US SEC filings.

This notebook contains the following sections:

1. [Domain adaptation fine-tuning](#1.-Domain-adaptation-fine-tuning)
   * [1.1. Set up](#1.1-Set-Up)
   * [1.2. Preparing training data](#1.2-Preparing-training-data)
   * [1.3. Starting training](#1.3-Starting-Training)
   * [1.4. Deploying inference endpoints](#1.4-Deploying-Inference-Endpoints)
   * [1.5. Clean up](#1.5-Clean-Up)

## 1. Domain adaptation fine-tuning
The text generation model can also be fine-tuned on any domain-specific dataset. After fine-tuning on such a dataset, the model is expected to generate domain-specific text and solve various NLP tasks in that domain with **few-shot prompting**.

Below are the instructions for how the training data should be formatted for input to the model.

- **Input:** A train directory and an optional validation directory. Each directory contains a CSV, JSON, or TXT file.
  - For CSV/JSON files, the training or validation data is taken from the column named `text`, or from the first column if no `text` column is found.
  - The train directory and the validation directory (if provided) must each contain exactly one file.
- **Output:** A trained model that can be deployed for inference. 
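The directory rules above can be checked locally before uploading anything to S3. The sketch below assumes the rules are exactly as listed (one CSV, JSON, or TXT file per directory; for CSV, the `text` column or else the first column). `check_domain_adaptation_dir`, `select_text_column`, and `read_csv_texts` are hypothetical helpers, not part of the SageMaker SDK, and only the CSV column rule is implemented; how the training container reads JSON files may differ.

```python
import csv
import os
import tempfile

# Hypothetical: file types accepted for domain adaptation, per the rules above
ALLOWED_EXTENSIONS = {".csv", ".json", ".txt"}


def check_domain_adaptation_dir(directory):
    """Return the path of the single data file in `directory`, or raise ValueError."""
    files = [
        name
        for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    ]
    if len(files) != 1:
        raise ValueError(f"Expected exactly one file in {directory}, found {len(files)}")
    name = files[0]
    if os.path.splitext(name)[1].lower() not in ALLOWED_EXTENSIONS:
        raise ValueError(f"Unsupported file type: {name}")
    return os.path.join(directory, name)


def select_text_column(columns):
    """Pick the 'text' column if present, otherwise the first column."""
    if not columns:
        raise ValueError("No columns found")
    return "text" if "text" in columns else columns[0]


def read_csv_texts(path):
    """Read the training texts from a CSV file using the column rule above."""
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        column = select_text_column(reader.fieldnames or [])
        return [row[column] for row in reader]


# Example: a train directory holding one CSV file with a 'text' column
with tempfile.TemporaryDirectory() as train_dir:
    csv_path = os.path.join(train_dir, "data.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "text"])
        writer.writerow(["1", "Net sales increased."])
    print(read_csv_texts(check_domain_adaptation_dir(train_dir)))  # ['Net sales increased.']
```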

Below is an excerpt from a TXT file for fine-tuning the text generation model. The TXT file contains Amazon's 2024 Form 10-K filing with the SEC.

```
Note About Forward-Looking Statements
This Annual Report on Form 10-K includes forward-looking statements within the
meaning of the Private Securities Litigation Reform Act of 1995. All
statements other than statements of historical fact, including statements
regarding guidance, industry prospects, or future results of operations or
financial position, made in this Annual Report on Form 10-K are forward-
looking. We use words such as anticipates, believes, expects, future, intends,
and similar expressions to identify forward-looking statements. Forward-
looking statements reflect management’s current expectations and are
inherently uncertain. Actual results and outcomes could differ materially for
a variety of reasons, including, among others, fluctuations in foreign
exchange rates, changes in global economic conditions and customer demand and
spending, inflation, interest rates, regional labor market constraints, world
events, the rate of growth of the internet, online commerce, cloud services,
and new and emerging technologies, the amount that Amazon.com invests in new
business opportunities and the timing of those investments, the mix of
products and services sold to customers, the mix of net sales derived from
products as compared with services, the extent to which we owe income or other
taxes, competition, management of growth, potential fluctuations in operating
results, international growth and expansion, the outcomes of claims,
litigation, government investigations, and other proceedings, fulfillment,
sortation, delivery, and data center optimization, risks of inventory
management, variability in demand, the degree to which we enter into,
maintain, and develop commercial agreements, proposed and completed
acquisitions and strategic transactions, payments risks, and risks of
fulfillment throughput and productivity. In addition, global economic and
geopolitical conditions and additional or unforeseen circumstances,
developments, or events may give rise to or amplify many of these risks. These
risks and uncertainties, as well as other risks and uncertainties that could
cause our actual results or outcomes to differ significantly from management’s
expectations, are described in greater detail in Item 1A of Part I, “Risk
Factors.”

GENERAL
Embracing Our Future ...
```


#### Instruction fine-tuning
The text generation model can be instruction-tuned on any text data, provided that the data is in the expected format. The instruction-tuned model can then be deployed for inference.

Below are the instructions for how the training data should be formatted for input to the model.

- **Input:** A train directory and an optional validation directory. Both should contain one or more JSON Lines (`.jsonl`) files. The train directory can also contain an optional `*.json` file describing the input and output formats.
  - The best model is selected according to the validation loss, calculated at the end of each epoch. If a validation set is not given, an (adjustable) percentage of the training data is automatically split and used for validation.
  - The training data must be in JSON Lines (`.jsonl`) format, where each line is a dictionary representing a single data sample. All training data must be in a single folder, but it can be split across multiple `.jsonl` files. The `.jsonl` file extension is mandatory. The training folder can also contain a `template.json` file describing the input and output formats. If no template file is given, the following template is used:
  ```json
  {
    "prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{context}",
    "completion": "{response}"
  }
  ```
  - In this case, the JSON Lines entries must include `instruction`, `context`, and `response` fields. If a custom template is provided, it must also use the `prompt` and `completion` keys to define the input and output templates. Below is a sample custom template:

  ```json
  {
    "prompt": "question: {question} context: {context}",
    "completion": "{answer}"
  }
  ```
  Here, the JSON Lines entries must include `question`, `context`, and `answer` fields.
- **Output:** A trained model that can be deployed for inference. 
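To make the template mechanism concrete, the sketch below fills the default template with Python's `str.format`, assuming the container performs a plain placeholder substitution of this kind; the exact rendering inside the JumpStart training container may differ. `render_example` and `write_instruction_dataset` are hypothetical helpers, the file name `train.jsonl` is an illustrative choice (any `.jsonl` name should work per the rules above), and the record content is made up.

```python
import json
import os
import tempfile

# Default template from the description above
DEFAULT_TEMPLATE = {
    "prompt": (
        "Below is an instruction that describes a task, paired with an input that "
        "provides further context. Write a response that appropriately completes "
        "the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{context}"
    ),
    "completion": "{response}",
}


def render_example(record, template=DEFAULT_TEMPLATE):
    """Fill the prompt and completion templates with the fields of one record."""
    return {
        "prompt": template["prompt"].format(**record),
        "completion": template["completion"].format(**record),
    }


def write_instruction_dataset(directory, records, template=None):
    """Write records as JSON Lines, plus template.json when a custom template is given."""
    os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, "train.jsonl"), "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")
    if template is not None:
        with open(os.path.join(directory, "template.json"), "w", encoding="utf-8") as f:
            json.dump(template, f)


records = [
    {
        "instruction": "Summarize the risk.",
        "context": "Revenue depends on foreign exchange rates.",
        "response": "Currency fluctuations can affect revenue.",
    }
]
example = render_example(records[0])
print(example["completion"])  # Currency fluctuations can affect revenue.
```

Extra fields in a record are ignored by `str.format`, so the `response` field does not disturb the prompt, and the same helper renders the custom `question`/`context`/`answer` template.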

---

### 1.1 Set Up
The following cells install the required libraries and select the model to fine-tune.

In [None]:
# Install required libraries
!pip install sagemaker smjsindustry sec-edgar-downloader

In [None]:
# JumpStart model ID and version of the Llama-3 8B base model
modelid = "meta-textgeneration-llama-3-8b"
modelversion = "2.*"

### 1.2 Preparing training data

In this section, we prepare a domain-specific fine-tuning dataset step by step from the 10-K reports that Amazon filed in 2023 and 2024.

The training data is a subset of Amazon's SEC filings, prepared in the domain adaptation dataset format. It is downloaded from the publicly available [EDGAR](https://www.sec.gov/edgar/searchedgar/companysearch) database; instructions for accessing the data are [here](https://www.sec.gov/os/accessing-edgar-data).

License: [Creative Commons Attribution-ShareAlike License (CC BY-SA 4.0)](https://creativecommons.org/licenses/by-sa/4.0/legalcode).


In [None]:
# Import required packages
import boto3
import os
import pandas as pd
import sagemaker
import smjsindustry
import shutil
from smjsindustry.finance.processor import DataLoader, SECXMLFilingParser
from sagemaker.jumpstart.estimator import JumpStartEstimator

#### Prepare the SageMaker session's default S3 bucket and a folder to store processed data

In [None]:
session = sagemaker.Session()
bucket = session.default_bucket()
sec_processed_folder = "amazon_sec_filing_data"
default_bucket_prefix = session.default_bucket_prefix

# If a default bucket prefix is specified, append it to the s3 path
if default_bucket_prefix:
    sec_processed_folder = f"{default_bucket_prefix}/{sec_processed_folder}"
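The prefix handling above determines where every artifact in this notebook lands in S3. The hypothetical helper `build_s3_uri` below reproduces that logic as a pure function so the resulting URIs can be checked without AWS credentials; the bucket name `my-bucket` and the prefix `team` in the examples are made up.

```python
def build_s3_uri(bucket, folder, *parts, default_bucket_prefix=None):
    """Build s3://bucket/[prefix/]folder/part1/part2/... (hypothetical helper)."""
    if default_bucket_prefix:
        # Mirrors the notebook: prepend the default bucket prefix when one is set
        folder = f"{default_bucket_prefix}/{folder}"
    return "s3://" + "/".join([bucket, folder, *parts])


print(build_s3_uri("my-bucket", "amazon_sec_filing_data", "output", "combined.txt"))
# s3://my-bucket/amazon_sec_filing_data/output/combined.txt
```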

#### Create local directories for data preprocessing

In [None]:
# Create local directories for data preprocessing (no error if they already exist)
os.makedirs("./rawfiles", exist_ok=True)
os.makedirs("./parsedata", exist_ok=True)

#### Download 10-K reports from SEC database

In [None]:
from sec_edgar_downloader import Downloader

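# The Downloader identifies itself to SEC EDGAR with a company name and contact email
dl = Downloader("Amazon", "companyinfo@amazon.com")
# Get the two most recent 10-K filings for Amazon (ticker AMZN)
dl.get("10-K", "AMZN", limit=2)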

In [None]:
## Move the downloaded files into the 'rawfiles' folder

directory = "./sec-edgar-filings/AMZN/10-K"


def rename_files_in_directory(directorypath):
    """Rename every file to <directory-name>_<NNN><extension>."""
    dir_name = os.path.basename(directorypath)
    for count, filename in enumerate(sorted(os.listdir(directorypath)), start=1):
        new_filename = f"{dir_name}_{count:03d}{os.path.splitext(filename)[1]}"
        os.rename(
            os.path.join(directorypath, filename),
            os.path.join(directorypath, new_filename),
        )


# Rename the files in every filing folder that contains a .txt submission,
# so that names stay unique once all files are merged into ./rawfiles
for dirpath, dirnames, filenames in os.walk(directory):
    if any(name.endswith(".txt") for name in filenames):
        rename_files_in_directory(dirpath)

# Copy the renamed files into ./rawfiles
for root, dirs, files in os.walk(directory):
    for file in files:
        shutil.copy2(os.path.join(root, file), "./rawfiles")
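Renaming matters because, assuming the `sec-edgar-downloader` layout `sec-edgar-filings/AMZN/10-K/<accession-number>/full-submission.txt`, every filing is saved under the same file name, so copying them into one flat folder would overwrite all but one. The sketch below builds a throwaway copy of that layout in a temporary directory and shows that prefixing each file with its parent directory name keeps the names unique. `flatten_into` is a hypothetical helper that keeps the original file name instead of the numeric suffix used above, and the accession numbers are made up.

```python
import os
import shutil
import tempfile


def flatten_into(source_root, target_dir):
    """Copy every file under source_root into target_dir as <parent-dir>_<name>."""
    os.makedirs(target_dir, exist_ok=True)
    copied = []
    for root, _dirs, files in os.walk(source_root):
        for name in sorted(files):
            new_name = f"{os.path.basename(root)}_{name}"
            shutil.copy2(os.path.join(root, name), os.path.join(target_dir, new_name))
            copied.append(new_name)
    return sorted(copied)


with tempfile.TemporaryDirectory() as tmp:
    # Simulate two filings that share the same file name
    source = os.path.join(tmp, "sec-edgar-filings", "AMZN", "10-K")
    for accession in ["0000000000-23-000001", "0000000000-24-000001"]:
        os.makedirs(os.path.join(source, accession))
        with open(os.path.join(source, accession, "full-submission.txt"), "w") as f:
            f.write(accession)
    print(flatten_into(source, os.path.join(tmp, "rawfiles")))
    # ['0000000000-23-000001_full-submission.txt', '0000000000-24-000001_full-submission.txt']
```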

#### Parse the raw files using SEC Parser

In [None]:
%%time
parser = SECXMLFilingParser(
    role=sagemaker.get_execution_role(),
    instance_count=1,
    instance_type="ml.c5.2xlarge",
    sagemaker_session=sagemaker.Session(),
)
parser.parse(
    "rawfiles",
    "s3://{}/{}/{}".format(bucket, sec_processed_folder, "output"),
)

#### Collate parsed data in S3 and download for local view

In [None]:
s3client = boto3.client("s3")
output_prefix = "{}/{}".format(sec_processed_folder, "output")
filename = "combined.txt"

# Read every parsed file under the output prefix and concatenate the contents
file_contents = []
list_response = s3client.list_objects_v2(Bucket=bucket, Prefix=f"{output_prefix}/")

for obj in list_response.get("Contents", []):
    key = obj["Key"]
    # Skip a combined file left over from a previous run so it is not merged into itself
    if key.endswith(filename):
        continue
    try:
        obj_response = s3client.get_object(Bucket=bucket, Key=key)
        file_contents.append(obj_response["Body"].read().decode("utf-8"))
        file_contents.append("\n\n")
    except s3client.exceptions.NoSuchKey:
        print(f"The file {key} does not exist in the bucket {bucket}.")

# Convert the collated string into bytes
combined_bytes = "".join(file_contents).encode("utf-8")

# Upload the collated content to Amazon S3
try:
    s3client.put_object(Body=combined_bytes, Bucket=bucket, Key=f"{output_prefix}/{filename}")
except Exception as e:
    print("Failed to write")
    print(e)

# Download a local copy for inspection
try:
    s3client.download_file(bucket, f"{output_prefix}/{filename}", f"./parsedata/{filename}")
except Exception as e:
    print("Failed to download")
    print(e)
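The collation step reduces to concatenating the parsed documents with blank lines between them. The sketch below isolates that logic from S3 as a hypothetical `collate_documents` function that takes a mapping of object keys to bytes (standing in for the `get_object` bodies) and skips a previously written combined file. Note also that a single `list_objects_v2` call returns at most 1,000 keys; if the parser ever produces more files than that, listing through `s3client.get_paginator("list_objects_v2")` would be needed.

```python
def collate_documents(objects, combined_name="combined.txt", separator="\n\n"):
    """Concatenate decoded documents in key order, skipping the combined output file."""
    parts = []
    for key in sorted(objects):
        if key.endswith(combined_name):
            continue  # avoid merging a previous run's output into itself
        parts.append(objects[key].decode("utf-8"))
        parts.append(separator)
    return "".join(parts).encode("utf-8")


sample = {
    "prefix/output/doc_001.txt": b"Risk Factors",
    "prefix/output/doc_002.txt": b"Legal Proceedings",
    "prefix/output/combined.txt": b"stale",
}
print(collate_documents(sample))
# b'Risk Factors\n\nLegal Proceedings\n\n'
```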

### 1.3 Starting Training
***
We start by creating the estimator object with all the required assets and then launch the training job. Default hyperparameter values are model-specific, so call `estimator.hyperparameters()` to view the defaults for your selected model.
***
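JumpStart hyperparameters are passed as strings, and a misspelled key (for example `epochs` instead of `epoch`) is easy to miss. The hypothetical helper `merge_hyperparameters` below merges overrides into a dictionary of defaults, rejects unknown keys, and converts values to strings. In the notebook, the defaults would come from `estimator.hyperparameters()` (or from `sagemaker.hyperparameters.retrieve_default(model_id=modelid, model_version=modelversion)`, assuming that utility is available in your SDK version); the defaults in the example are illustrative, not the real Llama-3 values.

```python
def merge_hyperparameters(defaults, overrides):
    """Return defaults updated with overrides; all values are converted to strings."""
    unknown = sorted(set(overrides) - set(defaults))
    if unknown:
        raise KeyError(f"Unknown hyperparameters: {unknown}")
    merged = {key: str(value) for key, value in defaults.items()}
    merged.update({key: str(value) for key, value in overrides.items()})
    return merged


# Illustrative defaults only; inspect estimator.hyperparameters() for the real ones
defaults = {"epoch": "1", "per_device_train_batch_size": "1", "learning_rate": "0.0001"}
print(merge_hyperparameters(defaults, {"epoch": 5, "per_device_train_batch_size": 4}))
# {'epoch': '5', 'per_device_train_batch_size': '4', 'learning_rate': '0.0001'}
```

The merged dictionary has the same shape as the `hyperparameters` argument passed to `JumpStartEstimator` in the next cell.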

In [None]:
domain_training_data_location = "s3://{}/{}/{}/{}".format(
    bucket, sec_processed_folder, "output", "combined.txt"
)

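estimator = JumpStartEstimator(
    model_id=modelid,
    model_version=modelversion,
    # Setting accept_eula to "true" accepts the Llama 3 end-user license agreement
    environment={"accept_eula": "true"},
    instance_type="ml.g5.12xlarge",
    hyperparameters={"epoch": "5", "per_device_train_batch_size": "4"},
)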

estimator.fit({"training": domain_training_data_location})

### 1.4 Deploying Inference Endpoints
A trained model does nothing on its own. Next, we deploy the fine-tuned model to a SageMaker endpoint and use it to perform inference.

In [None]:
# You can deploy the fine-tuned model to an endpoint directly from the estimator.
domain_fine_tuned_predictor = estimator.deploy()

In [None]:
# If the kernel restarts or the reference to the training estimator is lost,
# uncomment the lines below to re-attach to the completed training job and deploy it.
# from sagemaker.jumpstart.estimator import JumpStartEstimator

# training_job_name = "Your estimator training job name"
# model_id = modelid
# attached_estimator = JumpStartEstimator.attach(training_job_name, model_id)
# attached_estimator.logs()
# domain_fine_tuned_predictor = attached_estimator.deploy()

In [None]:
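parameters = {
    "max_new_tokens": 300,
    "top_k": 50,
    "top_p": 0.8,
    "do_sample": True,
    # Sampling needs a strictly positive temperature; some serving containers
    # (for example, Hugging Face TGI) reject temperature=0
    "temperature": 0.1,
}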

payload = {"inputs": "Risk factors highlighted in this 10-K report", "parameters": parameters}
domain_fine_tuned_predictor.predict(payload)
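The predictor returns the model's JSON response deserialized into Python objects. Depending on the serving container, text generation endpoints commonly return either a dictionary with a `generated_text` field or a list of such dictionaries; which shape this model version returns is an assumption to verify against your endpoint's output. The hypothetical helpers below build a payload in the same `inputs`/`parameters` format as the cell above and extract the generated text from either shape.

```python
def build_payload(prompt, **parameters):
    """Build a request in the {"inputs": ..., "parameters": ...} format used above."""
    return {"inputs": prompt, "parameters": parameters}


def extract_generated_text(response):
    """Return generated text from a dict or a list-of-dicts response (assumed shapes)."""
    if isinstance(response, dict):
        return response["generated_text"]
    if isinstance(response, list) and response and isinstance(response[0], dict):
        return response[0]["generated_text"]
    raise ValueError(f"Unexpected response format: {type(response).__name__}")


example_payload = build_payload(
    "Risk factors highlighted in this 10-K report", max_new_tokens=300, top_p=0.8
)
# With a live endpoint:
# print(extract_generated_text(domain_fine_tuned_predictor.predict(example_payload)))
print(extract_generated_text([{"generated_text": "Currency risk."}]))  # Currency risk.
```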

### 1.5 Clean Up

In [None]:
## Remove local directories
!rm -rf ./rawfiles/
!rm -rf ./parsedata/
!rm -rf ./sec-edgar-filings/

In [None]:
## Remove Model & Endpoints
domain_fine_tuned_predictor.delete_model()
domain_fine_tuned_predictor.delete_endpoint()

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2, which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/generative_ai|sm-jumpstart_foundation_llama_3_8b_domain_adaption_finetuning.ipynb)
