# Fine-tune Llama 3.1 models using torchtune on Amazon SageMaker

In this notebook, we use Meta’s torchtune library to fine-tune the Llama 3.1 8B model with LoRA on Amazon SageMaker training.

**torchtune** is a PyTorch-native library that aims to democratize and streamline the fine-tuning process for LLMs, making it easier for researchers, developers, and organizations to adapt these models to their specific needs and constraints.

This use case walks through an end-to-end example of how you can fine-tune a Llama 3.1 8B model with LoRA, run generation in memory, and optionally quantize and evaluate the model using torchtune and SageMaker training.

Recipes, prompt templates, configs, and datasets are fully configurable, which allows you to align torchtune with your requirements. To demonstrate this, we use a custom prompt template with the open-source dataset Samsung/samsum from the Hugging Face Hub.
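
To make the idea of a prompt template concrete, here is a minimal sketch in plain Python. The template text and the function name `format_samsum_prompt` are hypothetical and do not reproduce the template shipped in `./custom_template`; the only thing taken from the dataset is that each samsum record has a `dialogue` field.

```python
from typing import Dict

# Hypothetical template text, for illustration only. The template actually used
# by the training job is the one stored in the ./custom_template folder.
SUMMARY_TEMPLATE = (
    "Summarize the following dialogue.\n\n"
    "Dialogue:\n{dialogue}\n\n"
    "Summary:\n"
)


def format_samsum_prompt(sample: Dict[str, str]) -> str:
    """Fill the template with the `dialogue` field of one samsum record."""
    return SUMMARY_TEMPLATE.format(dialogue=sample["dialogue"].strip())


sample = {"dialogue": "Amanda: I baked cookies. Do you want some?\nJerry: Sure"}
print(format_samsum_prompt(sample))
```

In torchtune the template is wired in through the config file, so a function like this is only useful for previewing what the model would see.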

We fine-tune using torchtune's multi-device LoRA recipe (lora_finetune_distributed) with a SageMaker-customized version of the default Llama 3.1 8B config (llama3_1/8B_lora).

## 1. Setup Development Environment

Our first step is to install the SageMaker SDK and the other libraries we need on the client to prepare our dataset and start our training and evaluation jobs.

In [None]:
!pip uninstall -y "autogluon-multimodal" "aiobotocore" "amazon-sagemaker-sql-magic" "autogluon-core" "autogluon-features" "autogluon-tabular" "autogluon-timeseries" "langchain-aws" "sparkmagic" "virtualenv" --quiet

In [None]:
!pip install "sagemaker" "boto3" "datasets" "py7zr" --upgrade --quiet

If you are going to use SageMaker in a local environment, you need access to an IAM role with the required permissions for SageMaker. You can find out more about it [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html).



In [None]:
import sagemaker, boto3, time, json
from sagemaker.pytorch import PyTorch
from typing import Dict, Any
from pprint import pprint
from sagemaker.inputs import FileSystemInput

sagemaker_session = sagemaker.Session()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sagemaker_session is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sagemaker_session.default_bucket()

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sagemaker_session_bucket}")
print(f"sagemaker session region: {sagemaker_session.boto_region_name}")

## 1.1 Pre-Process data

In [None]:
from datasets import load_dataset

dataset = load_dataset("Samsung/samsum", trust_remote_code=True)

In [None]:
dataset_sample=dataset['train'].select(range(100))

dataset_sample

In [None]:
dataset_sample.to_json('./dataset/samsum_train.json')
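
To the best of my understanding, `Dataset.to_json` writes JSON Lines by default (one JSON object per line), so the exported file can be inspected with the standard library alone. The sketch below assumes that layout; `read_json_lines` is a hypothetical helper, and a temporary file stands in for `./dataset/samsum_train.json` so the example runs on its own.

```python
import json
import os
import tempfile


def read_json_lines(path):
    """Return a list with one dict per non-empty line of a JSON Lines file."""
    with open(path, encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


# Illustrative records with the samsum field names (id, dialogue, summary).
records = [
    {"id": "1", "dialogue": "A: hi\r\nB: hello", "summary": "A and B greet each other."},
    {"id": "2", "dialogue": "C: lunch?\r\nD: sure", "summary": "C and D agree to have lunch."},
]

# A temporary file stands in for ./dataset/samsum_train.json.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "samsum_train.json")
    with open(demo_path, "w", encoding="utf-8") as handle:
        for record in records:
            handle.write(json.dumps(record) + "\n")
    loaded = read_json_lines(demo_path)

print(len(loaded), sorted(loaded[0]))
```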

## 1.2 Define Variables

In the following cells, we retrieve the information needed to configure the PyTorch Estimator in step 1.5. As part of the setup for this workshop, we have created an S3 bucket, an EFS shared file system, and a VPC with subnets for the training job. Because these need to be specified in the Estimator, we retrieve their details by querying the CloudFormation stack that was deployed.

In [None]:
#1. CF Stack Name
stack_name='cf'

#2. Region name
region= sagemaker_session.boto_region_name

#3. Model that we will fine-tune
model_id="meta-llama/Meta-Llama-3.1-8B"


In [None]:
# Get the EFS file system ID, subnet IDs and security group for the fine-tuning step

def get_stack_outputs(stack_name, region='us-west-2'):
    """
    Retrieves all outputs from a CloudFormation stack.
    
    :param stack_name: Name of the CloudFormation stack
    :param region: AWS region where the stack is deployed (default is 'us-west-2')
    :return: Dictionary of stack outputs
    """
    cfn_client = boto3.client('cloudformation', region_name=region)
    
    try:
        response = cfn_client.describe_stacks(StackName=stack_name)
        stack_outputs = response['Stacks'][0]['Outputs']
        
        # Convert the list of outputs to a dictionary for easier access
        outputs_dict = {output['OutputKey']: output['OutputValue'] for output in stack_outputs}

        return outputs_dict
    
    except Exception as e:
        print(f"Error retrieving stack outputs: {str(e)}")
        return None

outputs = get_stack_outputs(stack_name, region)


In [None]:
outputs
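
Because `get_stack_outputs` returns `None` on failure and the later cells index into `outputs` directly, it can help to check up front that every expected key is present. The following is a minimal sketch; the helper names are hypothetical, the key names are the ones read in steps 1.3 and 1.4, and the values are placeholders.

```python
from typing import Dict, List, Optional

# Keys that the later cells read from the stack outputs.
REQUIRED_OUTPUT_KEYS = ["S3ModelUri", "SecurityGroup", "EFSFileSystemId"] + [
    f"SubnetID{i}" for i in range(1, 7)
]


def outputs_to_dict(stack_outputs: List[Dict[str, str]]) -> Dict[str, str]:
    """Turn the CloudFormation 'Outputs' list into a key/value dictionary."""
    return {item["OutputKey"]: item["OutputValue"] for item in stack_outputs}


def missing_output_keys(outputs: Optional[Dict[str, str]]) -> List[str]:
    """Return the required keys that are absent or empty in `outputs`."""
    if outputs is None:
        return list(REQUIRED_OUTPUT_KEYS)
    return [key for key in REQUIRED_OUTPUT_KEYS if not outputs.get(key)]


# Placeholder values shaped like a describe_stacks response.
example_outputs = outputs_to_dict(
    [
        {"OutputKey": "S3ModelUri", "OutputValue": "s3://example-bucket/models/"},
        {"OutputKey": "SecurityGroup", "OutputValue": "sg-xxxx"},
        {"OutputKey": "SubnetID1", "OutputValue": "subnet-xxxx"},
    ]
)
print(missing_output_keys(example_outputs))
```

Running `missing_output_keys(outputs)` right after the query surfaces a misnamed stack or a partially deployed template before any job is launched.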

## 1.3 Define S3 Bucket

In [None]:
#4. S3 url with model weights
s3_model_artifacts=outputs["S3ModelUri"]

## 1.4 Define EFS and networking

In [None]:
# Define one-time network configuration for VPC to use EFS
# This example has been optimized and tested on EFS. If you want to use S3, please change the config files to match S3 directory path


use_efs=True

# VPC config
network_config={

   "subnets": [outputs['SubnetID1'], 
               outputs['SubnetID2'], 
               outputs['SubnetID3'], 
               outputs['SubnetID4'], 
               outputs['SubnetID5'], 
               outputs['SubnetID6']],
   "security_group_ids": [outputs['SecurityGroup']] # e.g. ["sg-xxxx"]
}

# EFS file system id 
efs_file_system_id=outputs['EFSFileSystemId'] # e.g. 'fs-xxxx'

In [None]:
network_config, efs_file_system_id
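
The subnet list above is written out key by key. If a stack exposes a different number of subnets, the same list can be derived from the output keys instead. This is a sketch under the assumption that subnet outputs follow the `SubnetID<n>` naming used above; `collect_subnet_ids` is a hypothetical helper and the values are placeholders.

```python
import re
from typing import Dict, List


def collect_subnet_ids(outputs: Dict[str, str]) -> List[str]:
    """Collect the values of SubnetID<n> keys, ordered by their number."""
    pattern = re.compile(r"^SubnetID(\d+)$")
    numbered = []
    for key, value in outputs.items():
        match = pattern.match(key)
        if match and value:
            numbered.append((int(match.group(1)), value))
    return [value for _, value in sorted(numbered)]


example_outputs = {
    "SubnetID2": "subnet-bbbb",
    "SubnetID1": "subnet-aaaa",
    "SecurityGroup": "sg-xxxx",
    "SubnetID10": "subnet-jjjj",
}
print(collect_subnet_ids(example_outputs))
```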

## 1.5 Define PyTorch Estimator

The Estimator handles end-to-end SageMaker training and is the core element of a SageMaker Training Job. The following cells configure all the settings for the job, including which PyTorch version to use and which training script to execute. For more information on the SageMaker Estimator and a detailed overview of all its arguments, please refer to the [documentation](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html).

In [None]:
use_downloaded_model = "true"

def create_pytorch_estimator(**kwargs: Any) -> PyTorch:
    """
    Create a PyTorch estimator for SageMaker training with dynamic configuration.

    Args:
    **kwargs: Arbitrary keyword arguments for PyTorch estimator configuration.

    Returns:
    PyTorch: Configured PyTorch estimator.

    Raises:
    KeyError: If required parameters are missing in kwargs.
    """        
    
    job_name = f'torchtune-{kwargs["hyperparameters"]["tune_action"]}'
    
    # Upload configs to S3 folder
    inputs = sagemaker_session.upload_data(path="config", bucket=sagemaker_session_bucket, key_prefix="config")
    templates = sagemaker_session.upload_data(path="custom_template", bucket=sagemaker_session_bucket, key_prefix="templates")
    dataset = sagemaker_session.upload_data(path="dataset", bucket=sagemaker_session_bucket, key_prefix="dataset")

    print("torchtune configs uploaded to:{} \n".format(inputs))
    print("and to:{} \n".format(templates))

    env_var = {
        "SAGEMAKER_REQUIREMENTS": "requirements.txt",
    }

    # Default configuration
    estimator_config = {
        "entry_point": "launcher.py",
        "source_dir": "./scripts",
        "base_job_name": job_name,
        "max_run": 86400,
        "framework_version": "2.4.0",
        "py_version": "py311",
        "disable_output_compression": True,
        "keep_alive_period_in_seconds": 1800,
        "env": env_var,
        "role": role,
        "sagemaker_session": sagemaker_session,
        "disable_profiler":True,
        "debugger_hook_config":False
    }

    # Update with provided kwargs
    estimator_config.update(kwargs)

    # Ensure required parameters are present
    required_params = ['instance_type', 'instance_count', 'hyperparameters']
    for param in required_params:
        if param not in estimator_config:
            raise KeyError(f"Missing required parameter: {param}")

    # Configure EFS if specified
    if use_efs:
        required_keys = {'subnets', 'security_group_ids'}
        missing_keys = set(required_keys) - set(network_config.keys())
        
        if missing_keys:
            raise ValueError(f"Missing required keys: {', '.join(missing_keys)}")
    
        for key, value in network_config.items():
            if value is None or len(value) == 0:
                raise ValueError(f"Missing required value for {key}: {value}")
                
        estimator_config.update(network_config)
        
    # Drop 'use_efs' from the hyperparameters if present; it is a notebook-level switch, not a training hyperparameter
    estimator_config["hyperparameters"].pop('use_efs', None)
    
    global use_downloaded_model
    use_downloaded_model = estimator_config["hyperparameters"]["use_downloaded_model"]
    use_downloaded_model=bool(use_downloaded_model) and use_downloaded_model.lower() not in ('false', '0', 'no', 'n', 'off')
        
    print("SageMaker PyTorch Estimator: \n")
    pprint(estimator_config)

    return PyTorch(**estimator_config)
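
The `use_downloaded_model` hyperparameter is passed as a string, and the function above converts it to a boolean inline. That inline expression calls `.lower()` on the value, so it would raise an `AttributeError` if someone passed a Python `True` instead of `"true"`. A small standalone helper, sketched here under the hypothetical name `parse_flag`, handles both cases.

```python
# Strings treated as "false"; the empty string is included so that an
# unset value also counts as false.
FALSE_STRINGS = {"false", "0", "no", "n", "off", ""}


def parse_flag(value) -> bool:
    """Interpret a string or boolean hyperparameter as a boolean flag."""
    if isinstance(value, bool):
        return value
    if value is None:
        return False
    return str(value).strip().lower() not in FALSE_STRINGS


for candidate in ("true", "False", " off ", "", True, None):
    print(repr(candidate), parse_flag(candidate))
```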


Now we define a helper function which will execute the SageMaker Job by calling `estimator.fit()`. This will eventually kick off the SageMaker Job.

In [None]:
def execute_task(estimator):
    """
    Execute the task using the provided estimator and input data channels.

    The channels are built from the notebook-level variables defined above
    (bucket name, EFS file system ID and model artifact location).

    Args:
    estimator (sagemaker.estimator.Estimator): The SageMaker estimator to use for training.
    """
        
    if use_efs:
        if efs_file_system_id is None or len(efs_file_system_id) == 0:
            raise ValueError(f"Missing required value for efs_file_system_id: {efs_file_system_id}")
        
        # Define the EFS input
        efs_input = FileSystemInput(
            file_system_id=efs_file_system_id,
            file_system_type='EFS',
            directory_path='/',
            file_system_access_mode='rw'
        )
    else:
        s3 = boto3.client('s3')
        s3.put_object(Bucket=sagemaker_session_bucket, Key="artifacts")
    
    s3_config_bucket = f"s3://{sagemaker_session_bucket}/config"
    s3_custom_template = f"s3://{sagemaker_session_bucket}/templates"
    s3_model_store = f"s3://{sagemaker_session_bucket}/artifacts"
    s3_dataset = f"s3://{sagemaker_session_bucket}/dataset"

    # Define the data channels
    data_channels = {
        "config": s3_config_bucket,
        "model": efs_input if use_efs else s3_model_store,
        "templates":s3_custom_template,
        "dataset":s3_dataset,
        "model_artifacts": s3_model_artifacts
    }
    
    if use_downloaded_model:
        data_channels.pop('model_artifacts', None)

    print(f'data_channels:{data_channels}')
    
    # Fit the estimator with the input data channels
    estimator.fit(inputs=data_channels)
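
The channel selection in `execute_task` depends on several notebook-level variables. The same logic can be sketched as a pure function, which makes it easy to see which channels a given combination of settings produces. `build_data_channels` and its parameters are hypothetical names, and the bucket and URI values are placeholders.

```python
from typing import Any, Dict


def build_data_channels(
    bucket: str,
    model_channel: Any,
    model_artifacts_uri: str,
    use_downloaded_model: bool,
) -> Dict[str, Any]:
    """Build the channel mapping that would be handed to estimator.fit(inputs=...)."""
    channels = {
        "config": f"s3://{bucket}/config",
        "model": model_channel,  # an EFS input object or an S3 URI
        "templates": f"s3://{bucket}/templates",
        "dataset": f"s3://{bucket}/dataset",
    }
    # The model weights only need their own channel when the job has to
    # fetch them, i.e. when a previously downloaded copy is not reused.
    if not use_downloaded_model:
        channels["model_artifacts"] = model_artifacts_uri
    return channels


first_run = build_data_channels(
    "example-bucket", "s3://example-bucket/artifacts", "s3://example-bucket/models/", False
)
later_run = build_data_channels(
    "example-bucket", "s3://example-bucket/artifacts", "s3://example-bucket/models/", True
)
print(sorted(first_run))
print(sorted(later_run))
```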

## 1.6 Define SageMaker Tasks

In [None]:
# Set common parameters
hyperparam_common_values={}

hyperparam_common_values["model_id"]=model_id

Define SageMaker tasks for every specific model customization lifecycle step. Each task defines the configuration of the compute cluster that SageMaker will spin up to run the specific torchtune recipe.  


In [None]:
prompt=r'{"dialogue":"Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure \r\nAmanda: I will bring you tomorrow :-)"}'

sagemaker_tasks={}

# Define the SageMaker tasks; each one creates a specific SageMaker PyTorch estimator for a torchtune recipe
# Make sure all keys are defined in the same format
sagemaker_tasks={
    "fine-tune":{
        "hyperparameters":{
            "tune_config_name":"config_l3.1_8b_qlora.yaml",
            "tune_action":"fine-tune",
            "use_downloaded_model":"false",
            "tune_recipe":"lora_finetune_distributed" # check torchtune documentation or run "tune ls" to find all recipes available
            },
        "instance_count":1,
        "instance_type":"ml.g5.2xlarge",  
    },
    "generate_inference_on_trained":{
        "hyperparameters":{
            "tune_config_name":"config_l3.1_8b_gen_trained.yaml",
            "tune_action":"generate-trained",
            "use_downloaded_model":"true",
            #"prompt":json.dumps(prompt)
            },
        "instance_count":1,
        "instance_type":"ml.g5.2xlarge",
    },
    "generate_inference_on_original":{
        "hyperparameters":{
            "tune_config_name":"config_l3.1_8b_gen_orig.yaml",
            "tune_action":"generate-original",
            "use_downloaded_model":"true",
            #"prompt":json.dumps(prompt)
            },
        "instance_count":1,
        "instance_type":"ml.g5.2xlarge",
    },
    "evaluate_trained_model":{
        "hyperparameters":{
            "tune_config_name":"config_l3.1_8b_eval_trained.yaml",
            "tune_action":"run-eval",
            "use_downloaded_model":"true",
            "prompt":json.dumps(prompt)
            },
        "instance_count":1,
        "instance_type":"ml.g5.2xlarge",
    }
}

for k,v in sagemaker_tasks.items():
    sagemaker_tasks[k]["hyperparameters"].update(hyperparam_common_values)

In [None]:
sagemaker_tasks
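
Later cells adjust a task by assigning into `sagemaker_tasks[Task]['hyperparameters']`, which changes the shared dictionary for the rest of the session. When an override is meant for a single run only, working on a copy avoids that side effect. The helper below is a sketch with a hypothetical name, `task_with_overrides`, and a shortened task dictionary.

```python
import copy
from typing import Any, Dict


def task_with_overrides(
    tasks: Dict[str, Dict[str, Any]], name: str, **hyperparameter_overrides: Any
) -> Dict[str, Any]:
    """Return a deep copy of one task with some hyperparameters replaced."""
    task = copy.deepcopy(tasks[name])
    task["hyperparameters"].update(hyperparameter_overrides)
    return task


# Shortened stand-in for the sagemaker_tasks dictionary defined above.
example_tasks = {
    "evaluate_trained_model": {
        "hyperparameters": {
            "tune_config_name": "config_l3.1_8b_eval_trained.yaml",
            "tune_action": "run-eval",
        },
        "instance_count": 1,
        "instance_type": "ml.g5.2xlarge",
    }
}

one_off = task_with_overrides(
    example_tasks,
    "evaluate_trained_model",
    tune_config_name="config_l3.1_8b_eval_original.yaml",
)
print(one_off["hyperparameters"]["tune_config_name"])
print(example_tasks["evaluate_trained_model"]["hyperparameters"]["tune_config_name"])
```

The returned dictionary can be unpacked into `create_pytorch_estimator(**one_off)` in the same way as an entry of `sagemaker_tasks`.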

## 2. Fine Tune Tasks

Now that we have set everything up and defined the tasks, it is time to execute the fine-tuning job.

Before the job is executed by SageMaker, you can take a look at the torchtune recipe which we are using for the Fine-Tuning task. There you will see all configurations which are used in this task, including the `model`, `tokenizer`, `checkpointer`, `profiler`, `optimizer` and additional specifications for the training run.

In [None]:
#Display torchtune config .yaml file
!pygmentize ./config/config_l3.1_8b_qlora.yaml

In [None]:
"""  *** TASK for the job. Select one of the tasks defined in sagemaker_tasks: ***
  {fine-tune, generate_inference_on_trained, generate_inference_on_original,
   evaluate_trained_model} """
    
Task="fine-tune"

# Optionally print or override the task dictionary
#pprint(sagemaker_tasks[Task])

estimator=create_pytorch_estimator(**sagemaker_tasks[Task])

In [None]:
execute_task(estimator)

While the job is executed by SageMaker, you will see output logs being printed. Please continue to the next step once the training has finished successfully!
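
`estimator.fit()` blocks and streams the logs by default. If the notebook kernel is interrupted, or the job was started with `wait=False`, the job status can still be polled through the SageMaker API. The sketch below relies on the boto3 `describe_training_job` call and its `TrainingJobStatus` field; the helper `wait_for_training_job` and the stand-in client are hypothetical, and the stand-in exists only so the example runs without AWS access.

```python
import time

# Statuses after which a training job no longer changes state.
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(sm_client, job_name, poll_seconds=30.0, max_polls=1000):
    """Poll describe_training_job until the job reaches a terminal status."""
    for _ in range(max_polls):
        description = sm_client.describe_training_job(TrainingJobName=job_name)
        status = description["TrainingJobStatus"]
        if status in TERMINAL_STATUSES:
            return status
        time.sleep(poll_seconds)
    raise TimeoutError(f"{job_name} did not finish after {max_polls} polls")


class FakeSageMakerClient:
    """Stand-in for a boto3 SageMaker client that replays a list of statuses."""

    def __init__(self, statuses):
        self._statuses = iter(statuses)

    def describe_training_job(self, TrainingJobName):
        return {
            "TrainingJobName": TrainingJobName,
            "TrainingJobStatus": next(self._statuses),
        }


fake_client = FakeSageMakerClient(["InProgress", "InProgress", "Completed"])
final_status = wait_for_training_job(fake_client, "torchtune-fine-tune-demo", poll_seconds=0)
print(final_status)
```

With real resources, the call would look like `wait_for_training_job(sagemaker_session.sagemaker_client, estimator.latest_training_job.name)`.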

## 3.1 Generate Trained Model Inference

Now that the model is fine-tuned, let's run inference with the trained model.

Again, before the job is executed by SageMaker, you can take a look at the torchtune recipe which we are using for the Inference task. There you will see all configurations which are used in this task, including the `model`, `tokenizer`, `checkpointer`, `profiler`, `optimizer` and additional specifications for the inference run.

If you wish to do so, please un-comment the code in the next cell. Otherwise, please continue to the next step to run the Inference.

In [None]:
#Display torchtune config .yaml file

#!pygmentize ./config/config_l3.1_8b_gen_trained.yaml

In [None]:
"""  *** TASK for the job. Select one of the tasks defined in sagemaker_tasks: ***
  {fine-tune, generate_inference_on_trained, generate_inference_on_original,
   evaluate_trained_model} """

Task="generate_inference_on_trained" 

# You can overwrite any parameters in the SageMaker task as you see fit for your experimentation
prompt=r'{"dialogue":"Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure \r\nAmanda: I will bring you tomorrow :-)"}'

sagemaker_tasks[Task]['hyperparameters']['prompt']=json.dumps(prompt)

estimator=create_pytorch_estimator(**sagemaker_tasks[Task])

In [None]:
execute_task(estimator)
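
The prompt is a raw string that already contains JSON text, and `json.dumps(prompt)` wraps it once more as a JSON string literal. The receiving side therefore has to decode twice to get back to a dictionary. How `launcher.py` actually decodes the value is not shown in this notebook, so the round trip below only illustrates the encoding itself.

```python
import json

prompt = r'{"dialogue":"Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure \r\nAmanda: I will bring you tomorrow :-)"}'

# First layer: the JSON text is wrapped as a single JSON string literal.
encoded = json.dumps(prompt)

# Undo the wrapping to recover the original JSON text.
decoded_text = json.loads(encoded)

# Parse the JSON text itself; the \r\n escapes become real line breaks here.
payload = json.loads(decoded_text)

print(decoded_text == prompt)
print(payload["dialogue"].splitlines())
```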

**Congratulations, you have successfully completed this workshop and learnt how to fine-tune Llama 3.1 models on Amazon SageMaker with limited data and compute resources.**

Optionally, you can also run inference and evaluation on the original, non-fine-tuned Llama 3.1 model.

## 3.2 Generate Original Model Inference **[OPTIONAL]**

Again, before the job is executed by SageMaker, you can take a look at the torchtune recipe which we are using for the Inference task. There you will see all configurations which are used in this task, including the `model`, `tokenizer`, `checkpointer`, `profiler`, `optimizer` and additional specifications for the inference run.

If you wish to do so, please un-comment the code in the next cell. Otherwise, please continue to the next step to run the Inference job.

In [None]:
#Display torchtune config .yaml file

#!pygmentize ./config/config_l3.1_8b_gen_orig.yaml

In [None]:
"""  *** TASK for the job. Select one of the tasks defined in sagemaker_tasks: ***
  {fine-tune, generate_inference_on_trained, generate_inference_on_original,
   evaluate_trained_model} """

Task="generate_inference_on_original" 

# You can overwrite any parameters in the SageMaker task as you see fit for your experimentation
prompt=r'{"dialogue":"Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure \r\nAmanda: I will bring you tomorrow :-)"}'

sagemaker_tasks[Task]['hyperparameters']['prompt']=json.dumps(prompt)

#pprint(sagemaker_tasks[Task])

estimator=create_pytorch_estimator(**sagemaker_tasks[Task])

In [None]:
execute_task(estimator)

## 4.1 Evaluate Trained Model **[OPTIONAL]**

After getting an intuitive impression of the fine-tuned model's performance in the previous step, let's now evaluate it objectively by creating an evaluation job for the fine-tuned model.

Again, before the job is executed by SageMaker, you can take a look at the torchtune recipe which we are using for the Evaluation task. There you will see all configurations which are used in this task, including the `model`, `tokenizer`, `checkpointer`, `profiler`, `optimizer` and additional specifications for the evaluation run. 

If you wish to do so, please un-comment the code in the next cell. Otherwise, please continue to the next step to run the Evaluation job.

In [None]:
#Display torchtune config .yaml file

#!pygmentize ./config/config_l3.1_8b_eval_trained.yaml

In [None]:
"""  *** TASK for the job. Select one of the tasks defined in sagemaker_tasks: ***
  {fine-tune, generate_inference_on_trained, generate_inference_on_original,
   evaluate_trained_model} """

Task="evaluate_trained_model" 

sagemaker_tasks[Task]['hyperparameters']['tune_config_name']='config_l3.1_8b_eval_trained.yaml'

estimator=create_pytorch_estimator(**sagemaker_tasks[Task])

In [None]:
execute_task(estimator)

## 4.2 Evaluate Original Model **[OPTIONAL]**

Again, before the job is executed by SageMaker, you can take a look at the torchtune recipe which we are using for the Evaluation task. There you will see all configurations which are used in this task, including the `model`, `tokenizer`, `checkpointer`, `profiler`, `optimizer` and additional specifications for the evaluation run. 

If you wish to do so, please un-comment the code in the next cell. Otherwise, please continue to the next step to run the Evaluation job.

In [None]:
#Display torchtune config .yaml file

#!pygmentize ./config/config_l3.1_8b_eval_original.yaml

In [None]:
"""  *** TASK for the job. Select one of the tasks defined in sagemaker_tasks: ***
  {fine-tune, generate_inference_on_trained, generate_inference_on_original,
   evaluate_trained_model} """

# There is no separate task for the original model: reuse the evaluation task
# and point it at the config for the original weights.
Task="evaluate_trained_model"

sagemaker_tasks[Task]['hyperparameters']['tune_config_name']='config_l3.1_8b_eval_original.yaml'

estimator=create_pytorch_estimator(**sagemaker_tasks[Task])

In [None]:
execute_task(estimator)