# Train Llama2 model using SageMaker Distributed Data Parallel Library (SMDDP) and DeepSpeed

In this tutorial, we will show how to train or fine-tune the new [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b) model. We will use DeepSpeed ZeRO stage 3, a sharded data parallelism technique that alleviates the memory bottleneck when training large models.

In addition, we will use the **SMDDP library**, a SageMaker feature that accelerates training by speeding up GPU communication between nodes. We will use 2 p4d.24xlarge instances, each of which comes with 8 NVIDIA A100 40GB GPUs.
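To make the memory argument concrete, here is a rough back-of-the-envelope sketch. It assumes the common mixed-precision Adam accounting of about 16 bytes per parameter (2 for half-precision weights, 2 for gradients, 12 for fp32 optimizer states) and rounds the model to 7 billion parameters. It ignores activations and temporary buffers, so treat the numbers as an estimate rather than a measurement.

```python
def zero3_model_state_gb(num_params: float, num_gpus: int, bytes_per_param: int = 16) -> float:
    """Approximate per-GPU memory (in GB) for model states sharded with ZeRO stage 3.

    bytes_per_param=16 is an assumed figure for mixed-precision Adam:
    2 (weights) + 2 (gradients) + 12 (fp32 master weights, momentum, variance).
    """
    total_bytes = num_params * bytes_per_param
    return total_bytes / num_gpus / 1e9


NUM_PARAMS = 7e9   # rounded parameter count (assumption)
NUM_GPUS = 2 * 8   # 2 p4d.24xlarge instances with 8 GPUs each

unsharded = zero3_model_state_gb(NUM_PARAMS, 1)
sharded = zero3_model_state_gb(NUM_PARAMS, NUM_GPUS)

print(f"model states without sharding: ~{unsharded:.0f} GB")
print(f"model states per GPU with ZeRO-3 on {NUM_GPUS} GPUs: ~{sharded:.0f} GB")
```

Under these assumptions the unsharded model states alone would exceed the 40GB of a single A100, while the sharded states fit comfortably, leaving room for activations.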

*Note: For the purpose of this example, we will use a dummy synthetic dataset to avoid dealing with the access token required to initialize a Llama2 tokenizer. This example can be easily modified to supply your own dataset if you have a Llama2 access token.*

The entry point for the training script is `code/train.py`, which is where we initialize the SMDDP library.
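The training script itself is not reproduced here, so the following is only a sketch of what the initialization might look like. It assumes the script registers the `smddp` backend by importing `smdistributed.dataparallel.torch.torch_smddp` and then hands that backend name to `deepspeed.init_distributed`. The helper names `select_backend` and `init_distributed` are hypothetical, and the fallback to `nccl` is an addition for running outside a SageMaker container.

```python
import importlib


def select_backend() -> str:
    """Return "smddp" if the SMDDP PyTorch backend can be imported, else "nccl"."""
    try:
        # Importing this module registers the "smddp" backend with torch.distributed.
        importlib.import_module("smdistributed.dataparallel.torch.torch_smddp")
        return "smddp"
    except ImportError:
        # Assumed fallback for environments without the SMDDP library.
        return "nccl"


def init_distributed() -> None:
    """Initialize the process group through DeepSpeed with the selected backend."""
    import deepspeed  # imported lazily so the sketch loads without DeepSpeed installed

    deepspeed.init_distributed(dist_backend=select_backend())


if __name__ == "__main__":
    print(f"collective backend: {select_backend()}")
```

In a real run, `init_distributed()` would be called once near the top of the script, before the model and the DeepSpeed engine are created.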

## Setting up


If you are going to use SageMaker in a local environment, you need access to an IAM role with the required permissions for SageMaker. You can find out more about it [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html).



In [None]:
import os
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import sagemaker
import boto3

sagemaker_session = sagemaker.Session()

# Use the notebook's execution role when available; when running locally,
# fall back to looking up a named IAM role.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sagemaker_session.default_bucket()}")
print(f"sagemaker session region: {sagemaker_session.boto_region_name}")

Note that SageMaker by default uses the latest [AWS Deep Learning Container (DLC)](https://aws.amazon.com/machine-learning/containers/), so you can comment out the `ecr_image` variable (and the `image_uri` argument that references it) if you don't want to use your own custom image built from a DLC. Also note that if you used FSx when launching the SageMaker notebook instance, the training job must use the same `subnet_config` and `security_group_config`.

In [None]:

ecr_image = "<ECR_IMAGE_URI>"
subnet_config = ["<SUBNET_CONFIG_ID>"]
security_group_config = ["<SECURITY_GROUP_CONFIG>"]

## Configuring Training Job

We will now set the hyperparameters and define the estimator object for our training job. Since we are using DeepSpeed, we must provide a DeepSpeed config JSON file, which is located in the `code` folder. We will use the `PyTorch` estimator class and configure it to use the `torch_distributed` distribution, which will launch the training job using `torchrun`. This launcher kicks off the training script as a distributed training job on SageMaker and is the recommended launcher for sharded data parallel training jobs.
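The contents of the DeepSpeed config file are not shown in this notebook. As an illustration only, a minimal ZeRO stage 3 configuration could look like the sketch below. The keys are standard DeepSpeed options, but the specific values are assumptions and may differ from the `dsconfig.json` shipped in the `code` folder.

```python
import json

# Hypothetical ZeRO stage 3 settings; not the actual contents of code/dsconfig.json.
ds_config = {
    "bf16": {"enabled": True},              # matches the 'bf16' hyperparameter
    "zero_optimization": {
        "stage": 3,                         # shard parameters, gradients, and optimizer states
        "overlap_comm": True,               # overlap communication with computation
        "contiguous_gradients": True,       # reduce memory fragmentation
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "gradient_clipping": 1.0,
    "train_micro_batch_size_per_gpu": 1,    # matches 'per_device_train_batch_size'
    "gradient_accumulation_steps": 1,
}

config_text = json.dumps(ds_config, indent=2)
print(config_text)
```

If the training script builds on a trainer that also sets the batch size or precision, those values need to agree with the ones in the config file.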

In [None]:
hyperparameters={
  'model_id': 'meta-llama/Llama-2-7b-chat-hf',
  'gradient_checkpointing': True,
  'bf16': True,
  'optimizer': "adamw_torch",
  'per_device_train_batch_size': 1,
  'epochs': 1,
  'max_steps': 50,
  'deepspeed_config': 'dsconfig.json'
}

from sagemaker.pytorch import PyTorch
estimator = PyTorch(
  entry_point="train.py",
  base_job_name="llama2-training-smddp",
  role=role,
  image_uri=ecr_image,
  source_dir="code",
  instance_count=2,
  instance_type="ml.p4d.24xlarge",
  sagemaker_session=sagemaker_session,
  subnets=subnet_config,
  hyperparameters=hyperparameters,
  security_group_ids=security_group_config,
  keep_alive_period_in_seconds=600,
  distribution={"torch_distributed": {"enabled": True}},
  debugger_hook_config=False
)
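SageMaker passes each entry of `hyperparameters` to the entry point as a `--key value` command-line argument. The sketch below shows one way a script could read them. It is an assumption about how `train.py` might be written, not a copy of it, and `str2bool` and `parse_args` are hypothetical helpers. Boolean values arrive as the strings `"True"` or `"False"`, so they are converted explicitly.

```python
import argparse


def str2bool(value: str) -> bool:
    """Convert the string form of a boolean hyperparameter to a bool."""
    return str(value).lower() in ("true", "1", "yes")


def parse_args(argv=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--gradient_checkpointing", type=str2bool, default=True)
    parser.add_argument("--bf16", type=str2bool, default=True)
    parser.add_argument("--optimizer", type=str, default="adamw_torch")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--max_steps", type=int, default=50)
    parser.add_argument("--deepspeed_config", type=str, default="dsconfig.json")
    # Ignore any extra arguments the launcher may append.
    args, _ = parser.parse_known_args(argv)
    return args


if __name__ == "__main__":
    print(parse_args())
```

For example, `parse_args(["--bf16", "False", "--max_steps", "10"])` yields a namespace with `bf16=False` and `max_steps=10`, with the remaining fields left at their defaults.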

## Executing the training job
We can now start our training job with the `.fit()` method.

In [None]:
# Start the training job; no input channels are passed because the script uses a synthetic dataset
estimator.fit(wait=True)

### Terminate the warm pool cluster if no longer needed

In [None]:
sagemaker_session.update_training_job(estimator.latest_training_job.job_name, resource_config={"KeepAlivePeriodInSeconds":0})