# Pre-train Llama-3 8B model using FSDP2 with torchtitan on Amazon SageMaker

In this notebook, you will learn how to accelerate distributed training of Llama-3 models with the torchtitan library on Amazon SageMaker training jobs, using FSDP2 together with `torch.compile` and FP8.

### Prerequisites




Before you start, run the **Step 1-Build your Custom Container Jupyter Notebook** to build the torchtitan custom container used to train your model. If you want to use your own dataset, also follow the instructions in the **Step 2: Prepare your Dataset Jupyter Notebook** to download your dataset(s) to Amazon S3.

### Amazon SageMaker Initialization


Run the following cells to upgrade the SageMaker Python SDK to the latest version, import the SageMaker modules, and retrieve information about your current SageMaker environment, such as your AWS account ID, the AWS Region, and the ARN of your Amazon SageMaker execution role.

NOTE: This step might require a kernel restart.

In [None]:
%pip install --upgrade "sagemaker>=2.224"
%pip install sagemaker-experiments

In [None]:
%%time
import os

import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch

# Uses the execution role attached to this notebook; outside SageMaker, replace with a pre-existing role ARN
role = get_execution_role()
print(f"SageMaker Execution Role: {role}")

client = boto3.client("sts")
account = client.get_caller_identity()["Account"]
print(f"AWS account: {account}")

session = boto3.session.Session()
region = session.region_name
print(f"AWS region: {region}")

sm_boto_client = boto3.client("sagemaker")
sagemaker_session = sagemaker.session.Session(boto_session=session)

# get default bucket
default_bucket = sagemaker_session.default_bucket()
print("Default bucket for this session: ", default_bucket)

# No input data channels by default; set them in the optional dataset section if you use your own S3 dataset
data_channels = None

### Clone the torchtitan repository

In [None]:
!git clone https://github.com/pytorch/torchtitan.git

Next, we create a source directory that contains the training code, its dependencies, and the files required to run the training. We then move the required dependencies from the torchtitan directory into this source directory.

In [None]:
!mkdir torchtitan/src
!cp -r torchtitan/torchtitan/ torchtitan/train_configs/ torchtitan/train.py  torchtitan/src/

In [None]:
!pwd

In [None]:
!yes | rm -r torchtitan/torchtitan/ torchtitan/train_configs/ torchtitan/train.py 


In [None]:
# List the training configurations copied into the source directory
!ls torchtitan/src/train_configs/

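### Download the Llama-3 tokenizer

The Llama-3 tokenizer is used to tokenize the dataset during pre-processing. Meta-Llama-3-8B is a gated model on the Hugging Face Hub, so make sure your Hugging Face account has been granted access, then set `--hf_token` in the command below to your Hugging Face access token.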

In [None]:
!mkdir torchtitan/src/llama-3-tokenizer

In [None]:
!python torchtitan/src/torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --local_dir torchtitan/src/llama-3-tokenizer  --tokenizer_path "original" --hf_token=""
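Before moving on, it can help to confirm that the tokenizer landed where the training configuration expects it. The TOML file written later sets `tokenizer_path = "./llama-3-tokenizer/original/tokenizer.model"`, which we assume resolves against the uploaded `torchtitan/src` directory. The sketch below also assumes the download script writes `tokenizer.model` under `<local_dir>/original/`; the helper name `tokenizer_ready` is hypothetical.

```python
from pathlib import Path


def tokenizer_ready(src_dir: str,
                    relative_path: str = "llama-3-tokenizer/original/tokenizer.model") -> bool:
    """Return True if the tokenizer file exists and is non-empty under src_dir.

    Assumes the layout produced by download_tokenizer.py when called with
    --local_dir <src_dir>/llama-3-tokenizer and --tokenizer_path "original".
    """
    tokenizer_file = Path(src_dir) / relative_path
    return tokenizer_file.is_file() and tokenizer_file.stat().st_size > 0


if __name__ == "__main__":
    if tokenizer_ready("torchtitan/src"):
        print("Tokenizer found.")
    else:
        print("Tokenizer missing: check your Hugging Face token and model access.")
```

An empty or missing file usually means the download was rejected, which is cheaper to catch here than after a multi-node job has started.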


### Update the Llama-3 8B TOML configuration file

Training options for torchtitan are configured through TOML files. In this tutorial, we start from the `llama3_8b.toml` file in `torchtitan/src/train_configs/` and modify the sections below. The snippets show each change for reference; the complete configuration is written to a new file at the end of this section.

1. Enable TensorBoard logging in the `[metrics]` section:


In [None]:
[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "/opt/ml/output/tensorboard"

2. Enable `torch.compile` in the `[training]` section:


In [None]:
[training]
compile = true

3. Enable FP8 training for the linear layers in the `[float8]` section:

In [None]:
[float8]
enable_float8_linear = true

4. Enable FP8 all-gather for FSDP and precompute the dynamic FP8 scales in the `[float8]` section:

In [None]:
[float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
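To build some intuition for the FP8 settings above, here is a pure-Python sketch of dynamic tensorwise scaling, one of the recipes torchao (which torchtitan relies on for float8) supports: the scale is derived from the tensor's current absolute maximum (amax) so that the largest value maps to 448, the largest finite float8 e4m3 value, and each value is then rounded to 3 mantissa bits. It only simulates the numerics on a Python list; the real casts run in torchao kernels on the GPU and the exact recipe may differ by version. The second helper estimates the bytes each rank receives during an all-gather, showing why gathering weights in FP8 instead of bf16 roughly halves that traffic for the converted layers. All function names here are illustrative, not torchao APIs.

```python
import math

E4M3_MAX = 448.0            # largest finite value of float8 e4m3
E4M3_MANTISSA_BITS = 3
E4M3_MIN_NORMAL_EXP = -6    # below this exponent, e4m3 values are subnormal


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest e4m3-representable value, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    x = max(-E4M3_MAX, min(E4M3_MAX, x))  # saturate instead of overflowing
    exp = max(math.floor(math.log2(abs(x))), E4M3_MIN_NORMAL_EXP)
    quantum = 2.0 ** (exp - E4M3_MANTISSA_BITS)  # spacing between representable values
    rounded = round(x / quantum) * quantum
    return max(-E4M3_MAX, min(E4M3_MAX, rounded))


def quantize_dynamic(values):
    """Tensorwise dynamic scaling: scale = E4M3_MAX / amax, then round to e4m3."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [round_to_e4m3(v * scale) for v in values], scale


def dequantize(q, scale):
    return [v / scale for v in q]


def all_gather_bytes_received_per_rank(num_params, bytes_per_param, world_size):
    """Each rank must receive the (world_size - 1) shards it does not own."""
    total = num_params * bytes_per_param
    return total * (world_size - 1) / world_size


weights = [0.02, -0.5, 1.25, 3.0, -0.0007]
q, scale = quantize_dynamic(weights)
restored = dequantize(q, scale)
max_rel_err = max(abs(a - b) / abs(a) for a, b in zip(weights, restored))
print(f"scale={scale:.3f}, max relative error={max_rel_err:.4f}")

# Illustrative upper bound: pretends all 8B parameters are gathered in FP8 across 32 GPUs.
bf16 = all_gather_bytes_received_per_rank(8e9, 2, 32)
fp8 = all_gather_bytes_received_per_rank(8e9, 1, 32)
print(f"bf16: {bf16 / 1e9:.2f} GB, fp8: {fp8 / 1e9:.2f} GB received per rank per full gather")
```

In practice only the linear-layer weights are cast to FP8, so embeddings and norms still travel in higher precision and the real saving is smaller than this bound.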

Below is the full updated configuration with all of the above optimisations:

In [None]:
%%writefile torchtitan/src/train_configs/llama3_8b_optimisations.toml
# torchtitan Config.toml

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "/opt/ml/output/tensorboard"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
tokenizer_path = "./llama-3-tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 1
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = true
dataset = "c4"

[experimental]
pipeline_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_float8_linear = true
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
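Before launching a multi-node job, it can be worth sanity-checking the file that was just written. The sketch below parses the TOML with Python's standard `tomllib` (Python 3.11+; `tomli` is a drop-in backport for older versions) and reports which of the optimisations above are not switched on. The helper name `check_optimisations` is hypothetical, and it only checks the keys used in this notebook.

```python
from pathlib import Path

try:
    import tomllib  # Python 3.11+
except ModuleNotFoundError:  # older Python: pip install tomli
    import tomli as tomllib


def check_optimisations(toml_text: str) -> list:
    """Return a list of problems found in a torchtitan TOML config string."""
    cfg = tomllib.loads(toml_text)
    problems = []
    if not cfg.get("training", {}).get("compile", False):
        problems.append("[training] compile is not enabled")
    float8 = cfg.get("float8", {})
    for key in ("enable_float8_linear",
                "enable_fsdp_float8_all_gather",
                "precompute_float8_dynamic_scale_for_fsdp"):
        if not float8.get(key, False):
            problems.append(f"[float8] {key} is not enabled")
    metrics = cfg.get("metrics", {})
    # TensorBoard files must be written where SageMaker collects them
    if metrics.get("enable_tensorboard") and metrics.get("save_tb_folder") != "/opt/ml/output/tensorboard":
        problems.append("[metrics] save_tb_folder is not /opt/ml/output/tensorboard")
    return problems


CONFIG_PATH = Path("torchtitan/src/train_configs/llama3_8b_optimisations.toml")
if CONFIG_PATH.exists():
    print(check_optimisations(CONFIG_PATH.read_text()) or "All optimisations enabled.")
```

An empty list means every flag this notebook relies on is set; for the baseline runs described at the end of the notebook, you would expect some entries.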


### Configure TensorBoard output for the estimator

In [None]:
from sagemaker.debugger import TensorBoardOutputConfig

# Must match save_tb_folder in the [metrics] section of the TOML configuration
LOG_DIR = "/opt/ml/output/tensorboard"
tensorboard_output_config = TensorBoardOutputConfig(
    s3_output_path=f"s3://sagemaker-{region}-{account}/tensorboard/",
    container_local_output_path=LOG_DIR
)


### (Optional) Configure path to the training dataset

We use the default c4 dataset, which torchtitan is pre-configured to load. However, if your own dataset resides in Amazon S3, configure the input data channels below to point to it. The sample Jupyter notebook in Step 2 shows how to download the c4 dataset to Amazon S3 and can serve as a guide for preparing your own dataset.

Next, we set up the data channel for SageMaker training by creating a `TrainingInput` object from the S3 path of the training dataset. Skip the next cell if you are using the default c4 dataset.

In [None]:
# Set to the S3 URI of your dataset from the Step 2 notebook, e.g. "s3://<bucket>/<prefix>/";
# leave as None to use the default c4 dataset
training_dataset_location = None

s3_train_bucket = training_dataset_location

if s3_train_bucket is not None:
    train = sagemaker.inputs.TrainingInput(
        s3_train_bucket, distribution="FullyReplicated", s3_data_type="S3Prefix"
    )
    data_channels = {"train": train}
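Because `TrainingInput` accepts any string, a typo or an unreplaced placeholder only surfaces once the job starts. A small guard like the hypothetical `parse_s3_uri` helper below (not part of the SageMaker SDK) fails fast in the notebook instead; you could call it on `s3_train_bucket` inside the `if` block before creating the `TrainingInput`.

```python
def parse_s3_uri(uri: str):
    """Split 's3://bucket/prefix' into (bucket, prefix); raise ValueError otherwise."""
    scheme = "s3://"
    if not uri or not uri.startswith(scheme):
        raise ValueError(f"Expected an s3:// URI, got: {uri!r}")
    bucket, _, prefix = uri[len(scheme):].partition("/")
    if not bucket:
        raise ValueError(f"Missing bucket name in: {uri!r}")
    return bucket, prefix


# Example with a made-up bucket name
print(parse_s3_uri("s3://my-example-bucket/datasets/c4/"))  # ('my-example-bucket', 'datasets/c4/')
```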

You will also need to add a utility function to `torchtitan/src/torchtitan/datasets/hf_datasets.py` that loads your dataset.
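Inside the training container, SageMaker makes the `train` channel available as a local directory and exposes its path through the `SM_CHANNEL_TRAIN` environment variable (by default `/opt/ml/input/data/train`). A loader added to hf_datasets.py would therefore read from that directory instead of the Hugging Face Hub. The sketch below only collects the data files. The function name `find_train_files`, the JSON Lines file format (optionally gzipped), and the `datasets.load_dataset("json", ...)` call shown in the comment are assumptions about how your dataset was written in Step 2, not the notebook's own utility function.

```python
import os
from pathlib import Path


def find_train_files(channel_dir=None, suffixes=(".json", ".jsonl", ".json.gz")):
    """Return sorted data files found under the SageMaker 'train' channel directory.

    Hypothetical helper: falls back to the default channel mount point when
    SM_CHANNEL_TRAIN is not set (for example, when testing outside SageMaker).
    """
    root = Path(channel_dir or os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    if not root.is_dir():
        return []
    # Match on the full file name so multi-dot names like "c4-train.00000-of-01024.json.gz" work
    return sorted(str(p) for p in root.rglob("*") if p.is_file() and p.name.endswith(suffixes))


# Inside hf_datasets.py, the file list could then be passed to Hugging Face datasets, e.g.:
#   from datasets import load_dataset
#   ds = load_dataset("json", data_files=find_train_files(), split="train", streaming=True)
print(find_train_files())
```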

Lastly, register your custom dataset name (in this example, `c4_custom`) in `torchtitan/src/torchtitan/datasets/hf_datasets.py`, and set `dataset = "c4_custom"` in the `[training]` section of your TOML configuration.

### Create the SageMaker PyTorch estimator for training

In [None]:
!pwd

In [None]:
import os

from time import gmtime, strftime

hyperparameters = {
    "config_file": "train_configs/llama3_8b_optimisations.toml"
}

env = {}
# Seconds to wait for Hugging Face Hub metadata requests before timing out
env["HF_HUB_ETAG_TIMEOUT"] = "500"

timestamp = strftime("%Y-%m-%d-%H-%M", gmtime())

smp_estimator = PyTorch(
    base_job_name=f"llama3-8b-compile-fp8-fp8-comms-{timestamp}",
    entry_point="train.py",
    image_uri="<path to image uri>",  # custom container image built in Step 1
    source_dir=os.path.join(os.getcwd(), "torchtitan/src"),
    role=role,
    instance_type="ml.p5.48xlarge",  # 8 x NVIDIA H100 GPUs per instance; FP8 needs FP8-capable GPUs
    volume_size=800,  # EBS volume size in GB
    instance_count=4,
    environment=env,
    hyperparameters=hyperparameters,
    use_spot_instances=False,
    keep_alive_period_in_seconds=3600,  # keep instances in a warm pool for 1 hour for follow-up runs
    sagemaker_session=sagemaker_session,
    tensorboard_output_config=tensorboard_output_config,
    distribution={
        "torch_distributed": {"enabled": True},  # launch train.py with torchrun on every instance
    },
)
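It helps to know how much data one run consumes. With `data_parallel_degree = -1`, all GPUs not used by tensor or pipeline parallelism become FSDP data-parallel ranks, so 4 instances with 8 GPUs each give 32 ranks. The arithmetic below assumes that `batch_size` in the TOML file is the per-rank (local) batch size, as in the torchtitan versions this notebook targets; `tokens_per_step` is a hypothetical helper.

```python
def tokens_per_step(instance_count, gpus_per_instance, local_batch_size, seq_len,
                    tensor_parallel=1, pipeline_parallel=1):
    """Tokens processed per optimizer step, assuming data_parallel_degree = -1."""
    world_size = instance_count * gpus_per_instance
    dp_degree = world_size // (tensor_parallel * pipeline_parallel)
    return dp_degree * local_batch_size * seq_len


# Values from the estimator (4 x ml.p5.48xlarge) and the [training] section of the TOML file
per_step = tokens_per_step(instance_count=4, gpus_per_instance=8, local_batch_size=1, seq_len=8192)
total = per_step * 1000  # steps = 1000
print(f"{per_step:,} tokens per step, {total:,} tokens in total")
# 262,144 tokens per step, 262,144,000 tokens in total
```

Roughly 0.26B tokens is far below a full pre-training budget, which is expected for a run meant to compare optimisations; scale `steps` up for real training.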

Finally, launch the training job:

In [None]:
smp_estimator.fit(inputs=data_channels)


### Performance Comparison with TensorBoard

To compare the optimisations, start with a baseline training job and apply the optimisations incrementally in subsequent runs. You can then visualise the throughput speedup and the loss curves of each run in [TensorBoard](https://docs.aws.amazon.com/sagemaker/latest/dg/tensorboard-on-sagemaker.html).
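When comparing runs, two numbers are handy: the relative throughput gain over the baseline, and model FLOPs utilization (MFU). The sketch below uses the common approximation of 6 × N FLOPs per token for a model with N parameters, which ignores attention FLOPs and so slightly underestimates MFU at long sequence lengths. The peak of about 989 TFLOPS is the dense BF16 figure for an H100 SXM GPU; check the value for your hardware. The helper names are hypothetical, and the throughput values are placeholders, not measured results.

```python
def speedup(baseline_tps: float, optimized_tps: float) -> float:
    """Relative throughput gain; 0.25 means 25% more tokens per second."""
    return optimized_tps / baseline_tps - 1.0


def approx_mfu(num_params: float, tokens_per_sec_per_gpu: float,
               peak_flops_per_gpu: float = 989e12) -> float:
    """Model FLOPs utilization using the 6 * N FLOPs-per-token approximation."""
    return 6 * num_params * tokens_per_sec_per_gpu / peak_flops_per_gpu


baseline_tps = 1000.0   # placeholder: per-GPU tokens/sec from your baseline run
optimized_tps = 1250.0  # placeholder: per-GPU tokens/sec from your optimised run
print(f"speedup: {speedup(baseline_tps, optimized_tps):.1%}")
print(f"approx. MFU: {approx_mfu(8e9, optimized_tps):.1%}")
```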