# Train Llama2 model using SageMaker Distributed Data Parallel Library (SMDDP) and DeepSpeed

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

---

In this tutorial, we will show how to train or fine-tune the [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b) model. We will use [DeepSpeed](https://www.deepspeed.ai/training/) ZeRO stage 3, a sharded data parallelism technique that partitions model parameters, gradients, and optimizer states across GPUs. This lets us keep the benefits of data parallelism when training over a large dataset while working within limited GPU memory.

In addition, we will use the **SMDDP library**, a SageMaker feature that accelerates training by speeding up communication between GPUs across p4d/p4de instances. The SMDDP library is compatible with ml.p4d.24xlarge and ml.p4de.24xlarge instances. For this example, we will use 2 ml.p4d.24xlarge instances, each with 8 NVIDIA A100 40GB GPUs (16 GPUs in total).
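To see why sharding matters on 40 GB GPUs, here is a back-of-the-envelope estimate of per-GPU model-state memory for mixed-precision AdamW training. It is a sketch that assumes the standard ZeRO accounting of 16 bytes per parameter (2 for bf16 weights, 2 for bf16 gradients, 12 for fp32 master weights plus the two AdamW moments) and roughly 6.74 billion parameters for Llama2-7B; activations, communication buffers, and fragmentation are ignored, so real usage is higher.

```python
# Back-of-the-envelope model-state memory for mixed-precision AdamW training,
# following the ZeRO accounting of 16 bytes per parameter. Activations,
# communication buffers, and memory fragmentation are deliberately ignored.

BYTES_PER_PARAM = {
    "bf16_weights": 2,
    "bf16_gradients": 2,
    "fp32_master_weights": 4,
    "fp32_adam_momentum": 4,
    "fp32_adam_variance": 4,
}


def model_state_gb(num_params: float, num_gpus: int, zero_stage: int) -> float:
    """Approximate per-GPU model-state memory in GB (1 GB = 1e9 bytes)."""
    weights = BYTES_PER_PARAM["bf16_weights"] * num_params
    grads = BYTES_PER_PARAM["bf16_gradients"] * num_params
    optim = (
        BYTES_PER_PARAM["fp32_master_weights"]
        + BYTES_PER_PARAM["fp32_adam_momentum"]
        + BYTES_PER_PARAM["fp32_adam_variance"]
    ) * num_params
    if zero_stage >= 1:  # stage 1 shards the optimizer states
        optim /= num_gpus
    if zero_stage >= 2:  # stage 2 also shards the gradients
        grads /= num_gpus
    if zero_stage >= 3:  # stage 3 also shards the weights themselves
        weights /= num_gpus
    return (weights + grads + optim) / 1e9


LLAMA2_7B_PARAMS = 6.74e9  # approximate parameter count (assumption)
NUM_GPUS = 2 * 8  # 2 x ml.p4d.24xlarge, 8 GPUs each

for stage in (0, 3):
    gb = model_state_gb(LLAMA2_7B_PARAMS, NUM_GPUS, stage)
    print(f"ZeRO stage {stage}: ~{gb:.1f} GB of model states per GPU")
```

Without sharding (stage 0), the model states alone come to roughly 108 GB, well beyond a single 40 GB A100; with stage 3 across 16 GPUs they drop to roughly 6.7 GB per GPU, leaving headroom for activations.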

*Note 1: For the purpose of this example, we will use a synthetic dummy dataset so that we do not need the access token required to initialize a Llama2 tokenizer. You can easily modify this example to use your own dataset if you have a Llama2 access token.*

*Note 2: The SMDDP library for accelerated sharded data parallel training is compatible with the SageMaker Deep Learning Containers for PyTorch 2.0 and later. Ensure you are using PyTorch >=2.0 for this example.*


## Files

All training and helper scripts are stored in the `code/` folder:
* `dsconfig.json` - DeepSpeed config file
* `requirements.txt` - Dependencies for this example, installed in the container when the training job is launched.
* `train.py` - Entry point training script
* `utils.py` - Defines dummy dataset and constructs dataloaders for the training job
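`utils.py` is not reproduced in this notebook. As a hedged illustration only, a dummy causal language modeling dataset can be as simple as deterministic random token IDs. The class and function names below, the sequence length of 512, and the vocabulary size of 32000 (the size of the Llama 2 tokenizer vocabulary) are assumptions of this sketch; `shard_indices` mirrors how PyTorch's `DistributedSampler` splits indices across ranks when `shuffle=False` and `drop_last=False`.

```python
import math
import random
from typing import Dict, List


class DummyCausalLMDataset:
    """Map-style dataset of random token IDs, standing in for tokenized text.

    Any object with __len__ and __getitem__ can be wrapped by
    torch.utils.data.DataLoader, so this sketch needs no torch import.
    """

    def __init__(self, num_samples: int = 64, seq_len: int = 512,
                 vocab_size: int = 32000, seed: int = 0):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.seed = seed

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
        if not 0 <= idx < self.num_samples:
            raise IndexError(idx)
        # Seed per index so every rank gets identical data for the same index.
        rng = random.Random(self.seed * 1_000_003 + idx)
        input_ids = [rng.randrange(self.vocab_size) for _ in range(self.seq_len)]
        # Causal LM labels equal the inputs; Hugging Face causal LM models
        # shift them by one position internally when computing the loss.
        return {
            "input_ids": input_ids,
            "attention_mask": [1] * self.seq_len,
            "labels": list(input_ids),
        }


def shard_indices(num_samples: int, rank: int, world_size: int) -> List[int]:
    """Indices read by one rank, mirroring DistributedSampler(shuffle=False,
    drop_last=False): wrap around to pad to a multiple of world_size, then
    take every world_size-th index starting at rank."""
    total = math.ceil(num_samples / world_size) * world_size
    padded = [i % num_samples for i in range(total)]
    return padded[rank:total:world_size]
```

In a real training script you would typically pass the dataset to `torch.utils.data.DataLoader` together with a `DistributedSampler` rather than shard indices by hand; the helper only makes the per-rank split visible.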

### How optimized GPU communication is enabled with SMDDP in DeepSpeed
Enabling the SMDDP library in an existing DeepSpeed training script takes only two small changes. As shown in `train.py`, the only code modifications required are:
* Importing the library: `import smdistributed.dataparallel.torch.torch_smddp`
* Creating the process group with `"smddp"` backend: `deepspeed.init_distributed(dist_backend="smddp")`
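The two changes above can be sketched as follows. Importing `smdistributed.dataparallel.torch.torch_smddp` registers the `"smddp"` backend with `torch.distributed`, after which `deepspeed.init_distributed(dist_backend="smddp")` creates the process group with it. The fallback to NCCL when the library is not importable (for example, outside a SageMaker PyTorch DLC) is an assumption added by this sketch so the same script can also run elsewhere; it is not taken from `train.py`, and the helper names are hypothetical.

```python
import importlib


def choose_dist_backend(preferred: str = "smddp", fallback: str = "nccl") -> str:
    """Return the process-group backend name to use.

    Importing smdistributed.dataparallel.torch.torch_smddp registers the
    "smddp" backend. If the library is missing, fall back to another
    backend (this fallback is an assumption of the sketch).
    """
    if preferred != "smddp":
        return preferred
    try:
        importlib.import_module("smdistributed.dataparallel.torch.torch_smddp")
    except ImportError:
        return fallback
    return "smddp"


def init_distributed() -> str:
    """Create the process group through DeepSpeed with the chosen backend."""
    import deepspeed  # imported lazily so choose_dist_backend works without DeepSpeed

    backend = choose_dist_backend()
    deepspeed.init_distributed(dist_backend=backend)
    return backend
```

Keep in mind that SMDDP itself only supports ml.p4d.24xlarge and ml.p4de.24xlarge instances, so on other instance types the NCCL path would be the one actually used.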

## Setting up


If you are going to use SageMaker in a local environment, you need access to an IAM role with the required permissions for SageMaker. For more information, see [SageMaker roles](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html).



In [None]:
import os

os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import sagemaker
import boto3

sagemaker_session = sagemaker.Session()

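try:
    # Works inside SageMaker notebook instances and Studio.
    role = sagemaker.get_execution_role()
except ValueError:
    # Outside SageMaker (e.g. locally), look up an existing IAM role by name.
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]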

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sagemaker_session.default_bucket()}")
print(f"sagemaker session region: {sagemaker_session.boto_region_name}")

Note that by default SageMaker uses the [AWS Deep Learning Container (DLC)](https://aws.amazon.com/machine-learning/containers/) that matches the `framework_version` and `py_version` set on the estimator. If you want to use your own container image instead, set the `use_ecr_image` flag to `True` and set the `ecr_image` variable. Also note that if you used FSx when launching the SageMaker notebook instance, you need to use the same subnet and security group for the training job by setting `subnet_config` and `security_group_config`.

In [None]:
use_ecr_image = False
use_fsx = False
kwargs = {}

if use_ecr_image:
    ecr_image = "<ECR_IMAGE_URI>"
    kwargs["image_uri"] = ecr_image

if use_fsx:
    subnet_config = ["<SUBNET_CONFIG_ID>"]
    security_group_config = ["<SECURITY_GROUP_CONFIG>"]
    kwargs["subnets"] = subnet_config
    kwargs["security_group_ids"] = security_group_config

## Configuring the Training Job

We will now set the hyperparameters and define the estimator object for our training job. Because we are using DeepSpeed, we must also provide a DeepSpeed config JSON file, which is located in the `code/` folder.
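The contents of `code/dsconfig.json` are not shown in this notebook. As an illustrative sketch only, a ZeRO stage 3 config with bf16 enabled could look like the dictionary below; the gradient accumulation and clipping values and the specific `zero_optimization` options are assumptions, and the real file may differ. It also shows the batch-size arithmetic DeepSpeed checks: `train_batch_size` must equal micro batch per GPU × gradient accumulation steps × number of GPUs.

```python
import json

# Illustrative ZeRO stage 3 config; the real code/dsconfig.json may differ.
MICRO_BATCH_PER_GPU = 1  # matches per_device_train_batch_size in the hyperparameters
GRAD_ACCUM_STEPS = 1  # assumption
WORLD_SIZE = 2 * 8  # 2 instances x 8 GPUs

ds_config = {
    "train_micro_batch_size_per_gpu": MICRO_BATCH_PER_GPU,
    "gradient_accumulation_steps": GRAD_ACCUM_STEPS,
    # DeepSpeed requires train_batch_size == micro batch * grad accum * world size.
    "train_batch_size": MICRO_BATCH_PER_GPU * GRAD_ACCUM_STEPS * WORLD_SIZE,
    "bf16": {"enabled": True},
    "gradient_clipping": 1.0,  # assumption
    "zero_optimization": {
        "stage": 3,  # shard optimizer states, gradients, and parameters
        "overlap_comm": True,  # overlap communication with computation
        "contiguous_gradients": True,  # reduce memory fragmentation of gradients
        "stage3_gather_16bit_weights_on_model_save": True,  # save a full 16-bit model
    },
}

print(json.dumps(ds_config, indent=2))
```

Writing this dictionary to `code/dsconfig.json` with `json.dump` would produce a file the training script can pass to DeepSpeed.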

We will use the `PyTorch` estimator class and configure it with the `torch_distributed` distribution, which launches the training job on each instance using `torchrun`, starting one process per GPU. `torchrun` is a popular launcher for PyTorch-based distributed training jobs.
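`torchrun` tells each worker process where it sits in the job through environment variables, including `RANK`, `LOCAL_RANK`, and `WORLD_SIZE`. The sketch below reads them into a small dataclass; the names `DistEnv` and `read_torchrun_env` are hypothetical, and `train.py` may read these values differently (DeepSpeed also reads them itself during `init_distributed`).

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DistEnv:
    rank: int  # global rank across all instances (0 .. world_size - 1)
    local_rank: int  # GPU index on this instance (0 .. 7 on ml.p4d.24xlarge)
    world_size: int  # total number of processes (16 for 2 x ml.p4d.24xlarge)


def read_torchrun_env(env: Optional[Mapping[str, str]] = None) -> DistEnv:
    """Read the rank information torchrun exports to every worker process."""
    env = os.environ if env is None else env
    return DistEnv(
        rank=int(env["RANK"]),
        local_rank=int(env["LOCAL_RANK"]),
        world_size=int(env["WORLD_SIZE"]),
    )
```

A training script would typically call `torch.cuda.set_device(read_torchrun_env().local_rank)` so that each process uses its own GPU before the process group is created.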

In [None]:
hyperparameters = {
    "model_id": "meta-llama/Llama-2-7b-chat-hf",
    "gradient_checkpointing": True,
    "bf16": True,
    "optimizer": "adamw_torch",
    "per_device_train_batch_size": 1,
    "epochs": 1,
    "max_steps": 50,
    "deepspeed_config": "dsconfig.json",
}

from sagemaker.pytorch import PyTorch

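estimator = PyTorch(
    entry_point="train.py",
    max_run=1800,  # stop the job after at most 30 minutes
    job_name="llama2-training-smddp",
    role=role,
    source_dir="./code",  # uploaded with the job; requirements.txt here is pip-installed
    framework_version="2.0.1",
    py_version="py310",
    instance_count=2,
    instance_type="ml.p4d.24xlarge",
    sagemaker_session=sagemaker_session,
    disable_output_compression=True,  # upload model artifacts without compressing them
    hyperparameters=hyperparameters,
    keep_alive_period_in_seconds=600,  # keep the instances in a warm pool for 10 minutes
    distribution={"torch_distributed": {"enabled": True}},  # launch train.py with torchrun
    debugger_hook_config=False,  # disable the SageMaker Debugger hook
    **kwargs,
)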
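SageMaker passes the `hyperparameters` dictionary to `train.py` as command-line arguments such as `--max_steps 50`, so boolean values arrive as strings. A minimal sketch of parsing them with `argparse` follows; it assumes the booleans arrive as strings like `"True"` (the parser accepts any casing to be safe), and the helper names are hypothetical rather than taken from `train.py`.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse boolean hyperparameters, which arrive on the command line as strings."""
    lowered = str(value).strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def parse_args(argv=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--gradient_checkpointing", type=str2bool, default=False)
    parser.add_argument("--bf16", type=str2bool, default=False)
    parser.add_argument("--optimizer", type=str, default="adamw_torch")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--max_steps", type=int, default=50)
    parser.add_argument("--deepspeed_config", type=str, default="dsconfig.json")
    # Tolerate extra arguments that this sketch does not define.
    args, _unknown = parser.parse_known_args(argv)
    return args
```

Using `type=bool` instead of `str2bool` would be a bug: `bool("False")` is `True` because any non-empty string is truthy. With 16 GPUs and `per_device_train_batch_size=1`, each optimizer step processes 16 sequences (assuming no gradient accumulation), so `max_steps=50` covers 800 sequences.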

## Executing the training job
We can now start our training job, with the `.fit()` method.

In [None]:
# Start the training job. No input channels are passed because the training script generates a synthetic dataset.
estimator.fit(wait=True)

## Expected Output
You should see output similar to the following in the SageMaker job logs after initialization and once training begins:

```
Processing training batch 0
Processing training batch 1
******epoch=0: train_ppl=tensor(71973.6484, device='cuda:0') train_loss=tensor(11.1841, device='cuda:0')******
Performing validation on training batch 1
Performing validation on training batch 1
*******epoch=0: eval_ppl=tensor(70934.4062, device='cuda:0') eval_loss=tensor(11.1695, device='cuda:0')*******
Training done!
```
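The logged perplexities are consistent with perplexity being the exponential of the mean cross-entropy loss. A quick check using the values from the expected output (the printed losses are rounded to four decimals, so agreement is only to within that rounding):

```python
import math

# Values copied from the expected log output (rounded as printed).
train_loss, train_ppl = 11.1841, 71973.6484
eval_loss, eval_ppl = 11.1695, 70934.4062

# Perplexity = exp(mean per-token cross-entropy loss).
for name, loss, ppl in [("train", train_loss, train_ppl), ("eval", eval_loss, eval_ppl)]:
    print(f"{name}: exp(loss) = {math.exp(loss):.1f}, logged ppl = {ppl:.1f}")
```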



*Note: If the DeepSpeed 0.9.2 pip installation fails, you may need to first install `pydantic==1.10.13` in your Docker image.*

## Terminate the warm pool cluster if no longer needed

Once you are finished experimenting, you can terminate the warm pool cluster so that you are no longer billed for the idle instances.

In [None]:
# Setting KeepAlivePeriodInSeconds to 0 releases the warm pool instances immediately.
sagemaker_session.update_training_job(
    estimator.latest_training_job.job_name, resource_config={"KeepAlivePeriodInSeconds": 0}
)

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/training|distributed_training|pytorch|data_parallel|deepspeed|llama2|smddp_deepspeed_example.ipynb)