# Pre-train Llama 3 models with torchtitan on Amazon SageMaker

[torchtitan](https://github.com/pytorch/torchtitan/tree/main) is a reference architecture for large-scale LLM training using native PyTorch. It aims to showcase PyTorch’s latest distributed training features in a clean, minimal code base. The library is designed to be simple to understand, use, and extend for different training purposes, with minimal changes required to the model code when applying various parallel processing techniques.


In this notebook, we show how the torchtitan library accelerates and simplifies the pre-training of Meta Llama 3-like model architectures. We cover the key torchtitan features that improve training efficiency: FSDP2, torch.compile integration, and FP8 support. We pre-train a Meta Llama 3 8B model architecture using torchtitan on Amazon SageMaker on p5.48xlarge instances, each equipped with 8 NVIDIA H100 GPUs.

### Prerequisites




Run the notebook [**Step 1-Build your Custom Container Jupyter Notebook**](https://github.com/aws-samples/sagemaker-distributed-training-workshop/blob/main/14_torchtitan/Step%201%20-Build%20Custom%20Container.ipynb) to build the torchtitan custom container for training your model. Optionally, to use your own dataset, follow the instructions in the [**Step 2: Prepare your Dataset Jupyter Notebook**](https://github.com/aws-samples/sagemaker-distributed-training-workshop/blob/main/14_torchtitan/(Optional)%20Step%202%20-Prepare%20Dataset.ipynb), which shows how to download a sample dataset to S3.

### Amazon SageMaker Initialization


Upgrade the SageMaker SDK to the latest version.

**Note:** This step might require a kernel restart.

In [None]:
!pip install sagemaker --upgrade 

Run the following cells to import the SageMaker modules and retrieve information about your current SageMaker environment, such as your AWS account ID, the AWS Region, and the ARN of your Amazon SageMaker execution role.

In [None]:
%%time
import os

import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch

role = (
    get_execution_role()
)  # provide a pre-existing role ARN as an alternative to creating a new role
print(f"SageMaker Execution Role: {role}")

client = boto3.client("sts")
account = client.get_caller_identity()["Account"]
print(f"AWS account: {account}")

session = boto3.session.Session()
region = session.region_name
print(f"AWS region: {region}")

sm_boto_client = boto3.client("sagemaker")
sagemaker_session = sagemaker.session.Session(boto_session=session)

# get default bucket
default_bucket = sagemaker_session.default_bucket()
print("Default bucket for this session: ", default_bucket)

# No S3 data channels by default; the optional custom-dataset section sets them
data_channels = None

### Clone the torchtitan repository

In [None]:
!git clone https://github.com/pytorch/torchtitan.git

Next, we create a source directory that contains the training source code and the dependencies required to run the training. We also move the required dependencies from the torchtitan directory to the source directory. Refer to the documentation to learn about the [Default Paths for Training Storage Locations](https://docs.aws.amazon.com/sagemaker/latest/dg/model-train-storage.html#model-train-storage-env-var-summary).

In [None]:
!mkdir torchtitan/src
!mv  torchtitan/torchtitan/ torchtitan/train_configs/ torchtitan/train.py  torchtitan/src/

### Downloading a tokenizer

We need the Llama 3 tokenizer to convert the dataset into tokens. Provide your Hugging Face token in the command below as the value of **--hf_token**.

In [None]:
!mkdir torchtitan/src/llama-3-tokenizer

In [None]:
!python torchtitan/src/torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --local_dir torchtitan/src/llama-3-tokenizer --tokenizer_path "original" --hf_token="<your-hf-token>"


### Update the Llama 3 8B TOML configuration file

Training options in torchtitan are configured through TOML files. In this tutorial, we work with the Llama 3 8B TOML file located in torchtitan/src/train_configs/. We need to modify the sections below:

**1. Enable TensorBoard metrics logging.**


In [None]:
[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "/opt/ml/output/tensorboard"

**2. Enable torch.compile.** torch.compile is a PyTorch feature that can improve model performance with minimal code changes. Through just-in-time (JIT) compilation, it analyzes PyTorch code and transforms it into more efficient kernels. torchtitan supports torch.compile, which can deliver speedups, especially for large models and complex architectures, through techniques such as operator fusion, memory planning, and automatic kernel selection.


In [None]:
compile = true

**3. Enable FP8 linear operations.** torchtitan supports FP8 (8-bit floating point) computation, which reduces memory footprint and improves performance in large language model training. FP8 has two formats, E4M3 and E5M2, each suited to a different part of training. E4M3 offers higher precision, making it a good fit for forward propagation, while E5M2, with its larger dynamic range, is better suited to backpropagation.

In [None]:
enable_float8_linear = true
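
The precision versus range trade-off between the two formats can be checked with a little arithmetic. The sketch below assumes the commonly used FP8 encodings (E4M3 with exponent bias 7, where only the all-ones bit pattern is NaN, and E5M2 with exponent bias 15 and an IEEE-style reserved top exponent). The helper `fp8_format_summary` is a hypothetical name for illustration and is not part of torchtitan.

```python
def fp8_format_summary(exp_bits, man_bits, bias, top_exponent_reserved):
    """Return (largest finite value, relative step) for a simple FP8 layout.

    top_exponent_reserved=True:  the all-ones exponent encodes inf/NaN (E5M2).
    top_exponent_reserved=False: only all-ones exponent AND mantissa is NaN (E4M3).
    """
    max_exp_field = (1 << exp_bits) - 1
    if top_exponent_reserved:
        # Largest usable exponent field is one below all-ones; mantissa may be all ones.
        exponent = max_exp_field - 1 - bias
        significand = 2.0 - 2.0 ** (-man_bits)
    else:
        # The all-ones exponent is usable, but the all-ones mantissa there is NaN.
        exponent = max_exp_field - bias
        significand = 2.0 - 2.0 ** (1 - man_bits)
    largest = significand * 2.0 ** exponent
    relative_step = 2.0 ** (-man_bits)  # gap between neighbours, relative to 2**exponent
    return largest, relative_step


e4m3 = fp8_format_summary(exp_bits=4, man_bits=3, bias=7, top_exponent_reserved=False)
e5m2 = fp8_format_summary(exp_bits=5, man_bits=2, bias=15, top_exponent_reserved=True)
print("E4M3:", e4m3)  # (448.0, 0.125): finer steps, smaller range
print("E5M2:", e5m2)  # (57344.0, 0.25): coarser steps, larger range
```

Under these assumptions, E4M3 resolves values twice as finely as E5M2, while E5M2 covers a range 128 times larger.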

**4. Enable FP8 all-gather.** This feature communicates FP8 tensors across GPUs, which reduces network bandwidth compared to bfloat16 all-gather operations. The float8 cast happens before the all-gather, so the message size is smaller. Key to its efficiency is the combined AMAX (absolute maximum) all-reduce, which calculates AMAX for all float8 parameters in a single operation after the optimizer step and so avoids many small all-reduces. As with FP8 linear support, this has no impact on model accuracy.

In [None]:
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
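
To make the two ideas concrete, the sketch below shows per-tensor dynamic scaling driven by AMAX and the size of the gathered tensor in each precision. It is a simplified pure-Python illustration, not the torchao implementation. The constant `E4M3_MAX = 448.0` assumes the common E4M3 encoding, and `dynamic_scale` and `gathered_bytes` are hypothetical helper names.

```python
E4M3_MAX = 448.0  # assumption: largest finite value of the E4M3 format


def dynamic_scale(values, fp8_max=E4M3_MAX):
    """Scale factor that maps the tensor's absolute maximum onto the FP8 range."""
    amax = max(abs(v) for v in values)
    return fp8_max / amax if amax > 0 else 1.0


def gathered_bytes(num_params, bytes_per_param):
    """Approximate size in bytes of the fully gathered parameter tensor."""
    return num_params * bytes_per_param


weights = [0.5, -2.0, 1.25]
scale = dynamic_scale(weights)         # 448 / 2.0
scaled = [w * scale for w in weights]  # values now fit the FP8 range before casting

params = 1_000_000                     # illustrative parameter count
bf16_bytes = gathered_bytes(params, 2)  # bfloat16 uses 2 bytes per parameter
fp8_bytes = gathered_bytes(params, 1)   # float8 uses 1 byte per parameter
print(scale, max(abs(s) for s in scaled), fp8_bytes / bf16_bytes)
```

Casting before the all-gather halves the bytes on the wire for the affected parameters. Computing every parameter's AMAX in one combined reduction is what replaces the many small all-reduces mentioned above.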

Below is the full updated configuration with the above optimizations:

In [None]:
%%writefile torchtitan/src/train_configs/llama3_8b_optimisations.toml
# torchtitan Config.toml

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "/opt/ml/output/tensorboard"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
tokenizer_path = "./llama-3-tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 1
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = true
dataset = "c4"

[experimental]
pipeline_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_float8_linear = true
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true


### Configure TensorBoard for the estimator

Next, to monitor our training progress, we'll set up TensorBoard output. This will allow us to visualize our training metrics in real-time, providing valuable insights into how our model is learning.

In [None]:
from sagemaker.debugger import TensorBoardOutputConfig

LOG_DIR="/opt/ml/output/tensorboard"
tensorboard_output_config = TensorBoardOutputConfig(
    s3_output_path=f"s3://sagemaker-{region}-{account}/tensorboard/",
    container_local_output_path=LOG_DIR
)


### Create the SageMaker estimator for the training

Before launching the training job, we need to modify torchtitan/src/train.py so that it accepts the path to the TOML configuration file as a hyperparameter of the estimator. Update the main block in torchtitan/src/train.py as follows:

In [None]:
"""
import argparse
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_file", type=str, default="", help="Model config file")
    
    args = parser.parse_args()
    # Load the configuration from the downloaded file
    config = JobConfig()
    config.parse_args(["--job.config_file", args.config_file])
    main(config)
    
    torch.distributed.destroy_process_group()
"""

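In SageMaker script mode, each entry of the estimator's hyperparameters dictionary reaches the entry point as a `--key value` command-line argument, which is why the modified main block reads `--config_file`. A minimal sketch of that hand-off follows, assuming the container uses the SageMaker training toolkit. The `to_cli_args` helper is hypothetical and only mimics the conversion for illustration.

```python
import argparse


def to_cli_args(hyperparameters):
    """Mimic how a hyperparameters dict is flattened into argv-style tokens."""
    args = []
    for key, value in hyperparameters.items():
        args.extend([f"--{key}", str(value)])
    return args


def parse_config_path(argv):
    """Parse --config_file the same way as the modified train.py main block."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_file", type=str, default="", help="Model config file")
    return parser.parse_args(argv).config_file


argv = to_cli_args({"config_file": "train_configs/llama3_8b_optimisations.toml"})
config_path = parse_config_path(argv)
print(argv, config_path)
```

The relative path resolves against the source directory, which is why the TOML file is written under torchtitan/src/train_configs/.
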
Next, update **image_uri** in the estimator to point to the custom training image built in [**Step 1-Build your Custom Container Jupyter Notebook**](https://github.com/aws-samples/sagemaker-distributed-training-workshop/blob/main/14_torchtitan/Step%201%20-Build%20Custom%20Container.ipynb). We also pass the path to the TOML configuration file created above as a hyperparameter.

In [None]:
import os

from time import gmtime, strftime

hyperparameters = {
    "config_file": "train_configs/llama3_8b_optimisations.toml"
}
env = {}
env['HF_HUB_ETAG_TIMEOUT'] = '500'

timestamp = strftime("%Y-%m-%d-%H-%M", gmtime())


smp_estimator = PyTorch(
    base_job_name=f'llama3-8b-compile-fp8-fp8-comms-{timestamp}',
    entry_point="train.py",
    image_uri="<provide-path-to-image-uri>",
    source_dir=os.path.join(os.getcwd(), "torchtitan/src"),
    role=role,
    instance_type="ml.p5.48xlarge",
    volume_size=800,
    instance_count=4,
    environment=env,
    hyperparameters=hyperparameters,
    sagemaker_session=sagemaker_session,
    tensorboard_output_config=tensorboard_output_config,
    distribution={"torch_distributed": {"enabled": True}},
)

Finally, we launch the training job.

In [None]:
smp_estimator.fit(inputs=data_channels)
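
As a rough sizing check for this job, the arithmetic below derives the number of GPUs and the tokens processed per optimizer step. It assumes that `data_parallel_degree = -1` lets torchtitan use all remaining GPUs for data parallelism and that `batch_size` is the per-rank batch size. Both are our reading of the configuration, not something this notebook states.

```python
instance_count = 4            # from the estimator
gpus_per_instance = 8         # ml.p5.48xlarge has 8 H100 GPUs
tensor_parallel_degree = 1    # from the [training] section of the TOML file
pipeline_parallel_degree = 1  # from the [experimental] section
batch_size = 1                # assumed to be per data-parallel rank
seq_len = 8192

world_size = instance_count * gpus_per_instance
# Assumption: -1 means "use every GPU not claimed by tensor or pipeline parallelism".
data_parallel_degree = world_size // (tensor_parallel_degree * pipeline_parallel_degree)
tokens_per_step = batch_size * seq_len * data_parallel_degree
print(world_size, data_parallel_degree, tokens_per_step)
```

Under these assumptions the job runs on 32 GPUs and consumes 262,144 tokens per step.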


## (Optional) Training with your own dataset

In the previous training run, we used allenai/c4, the default dataset pre-configured for the torchtitan samples, which is streamed directly from the Hugging Face Hub during training. If your dataset resides in S3, configure the input data channels below to point to it. The [**Step 2: Prepare your Dataset Jupyter Notebook**](https://github.com/aws-samples/sagemaker-distributed-training-workshop/blob/main/14_torchtitan/(Optional)%20Step%202%20-Prepare%20Dataset.ipynb) shows how to download the allenai/c4 dataset to S3 to simulate a dataset residing in S3.

First, we need to add a utility function to torchtitan/src/torchtitan/datasets/hf_datasets.py that loads our dataset.

We also need to add our custom dataset to the supported datasets configuration in torchtitan/src/torchtitan/datasets/hf_datasets.py and set its path to "/opt/ml/input/data/train/", where SageMaker makes the train data channel available.

Finally, we add a condition that handles our dataset to the "def __init__()" function of "class HuggingFaceDataset(IterableDataset, Stateful):" in torchtitan/src/torchtitan/datasets/hf_datasets.py.
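
A minimal sketch of what these three changes might look like, assuming the dataset was uploaded to S3 as gzipped JSON Lines files (`*.json.gz`). The names `list_data_files`, `load_custom_dataset`, and `custom_datasets` are hypothetical and are not taken from torchtitan; adapt them to the structure of hf_datasets.py at the commit you use.

```python
import glob
import os

# Assumption: SageMaker mounts the "train" channel at this path inside the container.
TRAIN_CHANNEL_DIR = "/opt/ml/input/data/train"


def list_data_files(channel_dir=TRAIN_CHANNEL_DIR, pattern="*.json.gz"):
    """Return the sorted data files found anywhere under the channel directory."""
    return sorted(glob.glob(os.path.join(channel_dir, "**", pattern), recursive=True))


def load_custom_dataset(channel_dir=TRAIN_CHANNEL_DIR):
    """Stream local JSON Lines files as a Hugging Face dataset (hypothetical helper)."""
    from datasets import load_dataset  # lazy import; requires the `datasets` package

    files = list_data_files(channel_dir)
    if not files:
        raise FileNotFoundError(f"No data files found under {channel_dir}")
    return load_dataset("json", data_files=files, split="train", streaming=True)


# Hypothetical registry entry mapping the custom dataset name to the channel path.
custom_datasets = {"c4_custom": TRAIN_CHANNEL_DIR}

# Inside HuggingFaceDataset.__init__, the extra condition could then read:
#     elif dataset_name == "c4_custom":
#         ds = load_custom_dataset(dataset_path)
```

Streaming keeps memory use flat because records are read from the local files as training consumes them.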

**Important:** Remember to update the dataset entry in your TOML configuration file to the name of your custom dataset. For this example, set dataset = "c4_custom" to match the steps above.

Next, we set up the data channels for SageMaker training by creating a TrainingInput object from the S3 path of the training dataset.

In [None]:
training_dataset_location = "<path-to-S3-dataset>"

s3_train_bucket = training_dataset_location

if s3_train_bucket is not None:
    train = sagemaker.inputs.TrainingInput(s3_train_bucket, distribution="FullyReplicated", s3_data_type="S3Prefix")
    data_channels = {"train": train}

We can then launch the training job with the same estimator definition as in the previous steps.

In [None]:
import os

from time import gmtime, strftime

hyperparameters = {
    "config_file": "train_configs/llama3_8b_optimisations.toml"
}
env = {}
env['HF_HUB_ETAG_TIMEOUT'] = '500'

timestamp = strftime("%Y-%m-%d-%H-%M", gmtime())


smp_estimator = PyTorch(
    base_job_name=f'llama3-8b-compile-fp8-fp8-comms-{timestamp}',
    entry_point="train.py",
    image_uri="<provide-path-to-image-uri>",
    source_dir=os.path.join(os.getcwd(), "torchtitan/src"),
    role=role,
    instance_type="ml.p5.48xlarge",
    volume_size=800,
    instance_count=4,
    environment=env,
    hyperparameters=hyperparameters,
    sagemaker_session=sagemaker_session,
    tensorboard_output_config=tensorboard_output_config,
    distribution={"torch_distributed": {"enabled": True}},
)

In [None]:
smp_estimator.fit(inputs=data_channels)

### Performance Comparison with TensorBoard

To effectively evaluate the performance speedup from the optimization techniques, consider the following approach:

- Create a baseline training job without the optimizations.
- Run subsequent jobs, adding one optimization at a time: first torch.compile, then FP8, and lastly FP8 all-gather.
- Compare the throughput (tokens per second) of each job to assess the impact of the optimizations.

You can then visualize the results in [TensorBoard](https://docs.aws.amazon.com/sagemaker/latest/dg/tensorboard-on-sagemaker.html) to compare the performance and the corresponding loss curves.
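
Once each job has reported its throughput, the comparison reduces to a percentage change against the baseline. The helper below is a small sketch of that calculation. `relative_speedup` is a hypothetical name, and the tokens-per-second values are placeholders chosen for easy arithmetic, not measured results.

```python
def relative_speedup(throughputs, baseline="baseline"):
    """Percentage throughput change of each run relative to the baseline run."""
    base = throughputs[baseline]
    return {
        name: round(100.0 * (tps - base) / base, 1)
        for name, tps in throughputs.items()
        if name != baseline
    }


# Placeholder tokens-per-second values, NOT measured results.
example_runs = {
    "baseline": 1000.0,
    "compile": 1100.0,
    "compile+fp8": 1250.0,
    "compile+fp8+fp8_all_gather": 1300.0,
}
speedups = relative_speedup(example_runs)
print(speedups)
```

Replace the placeholder values with the throughput that your own jobs log to TensorBoard.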

Package versions used in this tutorial:

- torchtitan commit hash: ac90c36e39c6274f9beaf76922627665b6553905
- torch==2.5.0.dev20240906+cu121
- torchao==0.6.0.dev20240907+cu121