# Train Falcon model using SageMaker Distributed Data Parallel Library (SMDDP) and PyTorch Fully Sharded Data Parallelism (FSDP)

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

---

In this tutorial, we will show how to train or fine-tune [Falcon-7B-instruct](https://huggingface.co/tiiuae/falcon-7b-instruct) on the [GLUE/SST2](https://huggingface.co/datasets/glue/viewer/sst2/train) dataset. We will use two p4d.24xlarge instances, each with 8 NVIDIA A100 40GB GPUs, together with the PyTorch Fully Sharded Data Parallelism (FSDP) technique to train this large model efficiently with limited GPU memory.

To speed up training, we will also use the **SageMaker Distributed Data Parallel Library (SMDDP)**, which accelerates GPU communication across P4d instances during sharded data parallel training.

## Files

* `scripts/train.py` - The entry point for the training script where we initialize the SMDDP library.
* `scripts/utils.py` - Helper script for defining dataloaders.
* `scripts/requirements.txt` - List of dependencies required for this example to train on SageMaker

*Note: The SMDDP library for accelerated sharded data parallel training is compatible with deep learning containers from PyTorch 2.0 onwards.  Ensure you are using PyTorch >=2.0 for this example.*

### How optimized GPU communication is enabled with SMDDP in FSDP
Enabling the SMDDP library in an existing FSDP training script takes only two changes. As shown in `train.py`, the only code modifications required are:
* Importing the library: `import smdistributed.dataparallel.torch.torch_smddp`
* Creating the process group with `"smddp"` backend: `torch.distributed.init_process_group("smddp")`
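
The two changes above can be sketched as follows. This is a minimal illustration rather than the contents of `train.py`: the helper names `select_backend` and `init_distributed`, and the fallback to another backend when SMDDP is not installed, are assumptions added for this example.

```python
import importlib


def select_backend(fallback="nccl"):
    """Return "smddp" if the SMDDP PyTorch plugin can be imported, else `fallback`."""
    try:
        # Importing this module registers the "smddp" backend with torch.distributed.
        importlib.import_module("smdistributed.dataparallel.torch.torch_smddp")
        return "smddp"
    except ImportError:
        # Assumption for this sketch: outside a SageMaker DLC, use another backend.
        return fallback


def init_distributed():
    """Create the default process group using the selected backend."""
    # Imported lazily so that select_backend() can be used without PyTorch installed.
    import torch.distributed as dist

    backend = select_backend()
    dist.init_process_group(backend)
    return backend
```

Inside a training job launched with `torchrun`, calling `init_distributed()` once at the start of the script would be enough; the rest of the FSDP code stays unchanged.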

## 1. Getting started

First, we'll install some dependencies in our current environment.

In [None]:
!pip install "transformers" "datasets[s3]" "sagemaker" "boto3" --upgrade --quiet

In [None]:
!pip install -r scripts/requirements.txt

If you are going to use SageMaker in a local environment, you need access to an IAM role with the required permissions for SageMaker. You can find out more about it [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html).



In [None]:
import sagemaker, boto3

from sagemaker.pytorch import PyTorch

sagemaker_session = sagemaker.Session()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]


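# Optionally set a bucket name here; if left as None, the session's default bucket is used.
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sagemaker_session is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sagemaker_session.default_bucket()

# Recreate the session so that it uses the chosen bucket as its default.
sagemaker_session = sagemaker.Session(default_bucket=sagemaker_session_bucket)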
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sagemaker_session.default_bucket()}")
print(f"sagemaker session region: {sagemaker_session.boto_region_name}")

## 2. Load and prepare the dataset

As the base dataset, we will use the [GLUE/SST2](https://huggingface.co/datasets/glue/viewer/sst2/train) dataset, but before training the model, we need to preprocess the data. We will create chunks of `2048` tokens (the [model max length](https://huggingface.co/tiiuae/falcon-7b)) to avoid unnecessary padding and computation.

The first step is to load our dataset from Hugging Face.

In [None]:
model_id = "tiiuae/falcon-7b"
dataset_name = "glue"
dataset_config = "sst2"

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load dataset from huggingface.co
dataset = load_dataset(dataset_name, dataset_config)

# Shuffle every split with a fixed seed for reproducibility
dataset = dataset.shuffle(seed=42)

#### Split dataset into Train and Validation.

In [None]:
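if "validation" not in dataset.keys():
    # Fall back to a 5%/95% split of the training set when no validation split exists.
    dataset["validation"] = load_dataset(dataset_name, dataset_config, split="train[:5%]")
    dataset["train"] = load_dataset(dataset_name, dataset_config, split="train[5%:]")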

The last step of the data preparation is to tokenize and chunk our dataset. Tokenizing converts our inputs (text) into token IDs that the model can understand. Additionally, we concatenate our dataset samples into chunks of `2048` tokens to avoid unnecessary padding.

In [None]:
from itertools import chain
from functools import partial


def group_texts(examples, block_size=2048):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder. We could pad instead if the model supported it;
    # customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


column_names = dataset["train"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]

lm_dataset = dataset.map(
    lambda sample: tokenizer(sample[text_column_name]),
    batched=True,
    remove_columns=list(column_names),
    desc="Running tokenizer on dataset",
).map(
    partial(group_texts, block_size=2048),
    batched=True,
)
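
To make the chunking concrete, here is a small self-contained sketch that applies the same concatenate-and-split logic to toy token IDs with `block_size=4`. The toy batch and the standalone function name `group_texts_demo` are made up for illustration; real batches come from the tokenizer.

```python
from itertools import chain


def group_texts_demo(examples, block_size=4):
    # Concatenate the per-sample lists of each column into one long list.
    concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the remainder that does not fill a whole block.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # As in the cell above, the labels are a copy of the input IDs.
    result["labels"] = result["input_ids"].copy()
    return result


# Three toy "tokenized sentences" with 3 + 4 + 3 = 10 tokens in total.
toy_batch = {
    "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9, 10]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1], [1, 1, 1]],
}
chunks = group_texts_demo(toy_batch)
print(chunks["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]; tokens 9 and 10 are dropped
```

Sentence boundaries are ignored: a chunk can start in the middle of one sample and end in the middle of another.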

We will start by saving the tokenized data locally.

In [None]:
# save data locally
training_input_path = "processed/data/"
lm_dataset.save_to_disk(training_input_path)

print(f"Saved data to: {training_input_path}")

## 3. Train the Falcon model using FSDP and SMDDP on Amazon SageMaker

We will begin by uploading the tokenized data to S3, from where it will be downloaded to the training cluster during training.

After we process the datasets, we are going to use the new [FileSystem integration](https://huggingface.co/docs/datasets/filesystems) to upload our dataset to S3. We are using `sagemaker_session.default_bucket()`; adjust this if you want to store the dataset in a different S3 bucket. We will use the S3 path later in our training script.

In [None]:
training_input_path = f"s3://{sagemaker_session.default_bucket()}/processed/data/"
print(f"Saving training dataset to: {training_input_path}")
lm_dataset.save_to_disk(training_input_path)

print(f"uploaded data to: {training_input_path}")

As mentioned at the beginning, we will use Amazon SageMaker and PyTorch FSDP to train our model. Amazon SageMaker makes it easy to create a multi-node cluster to train our model in a distributed manner. The `sagemaker` Python SDK supports running training jobs with `torchrun` to distribute the script across multiple nodes and GPUs.

To use `torchrun` to execute our scripts, we only have to define the `distribution` parameter in our Estimator and set it to `"torch_distributed": {"enabled": True}`.
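
As a rough illustration of what `torchrun` sets up for this job, the sketch below computes the number of worker processes and their global ranks, assuming one process per GPU and no gradient accumulation. The function names are hypothetical; the instance count, GPU count, and per-device batch size are the ones used in this notebook.

```python
def world_size(num_nodes, gpus_per_node):
    # One worker process per GPU across all nodes.
    return num_nodes * gpus_per_node


def global_rank(node_rank, local_rank, gpus_per_node):
    # Ranks are numbered node by node, then by GPU index within the node.
    return node_rank * gpus_per_node + local_rank


# Two ml.p4d.24xlarge instances with 8 GPUs each.
total = world_size(num_nodes=2, gpus_per_node=8)
print(total)  # 16
print(global_rank(node_rank=1, local_rank=0, gpus_per_node=8))  # 8

# With a per-device batch size of 1, each step processes one sequence per process.
global_batch_size = total * 1
print(global_batch_size)  # 16
```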

In our example, we will use full sharding and use `FalconDecoderLayer` in the auto-wrap policy. If you run this example and change the model, make sure to also adjust the transformer layer policy in `scripts/train.py` as this is a model dependent configuration.
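
To see why full sharding matters for a model of this size, here is a back-of-the-envelope estimate of the per-GPU memory needed for model state. The numbers are assumptions, not measurements: roughly 7 billion parameters, and 16 bytes per parameter (fp32 weights, fp32 gradients, and two fp32 AdamW moment buffers). Activations and temporarily gathered layers are not included.

```python
def model_state_gib(num_params, bytes_per_param=16, num_shards=1):
    """Approximate memory (GiB) for weights, gradients and optimizer state per shard."""
    total_bytes = num_params * bytes_per_param
    return total_bytes / num_shards / 2**30


num_params = 7e9  # assumed parameter count for a "7B" model

unsharded = model_state_gib(num_params)  # every GPU holds the full model state
sharded = model_state_gib(num_params, num_shards=16)  # full sharding over 2 x 8 GPUs

print(f"unsharded: {unsharded:.1f} GiB per GPU")
print(f"sharded:   {sharded:.1f} GiB per GPU")
```

Under these assumptions, the unsharded state (about 104 GiB) would not fit on a single 40 GB A100, while the sharded state (about 6.5 GiB per GPU) leaves room for activations.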

Within the entry point of the training script, we also import SMDDP and use it as the backend for the process group. This enables faster communication between P4d instances than the open source NVIDIA Collective Communications Library (NCCL) and ultimately speeds up training.

To create a SageMaker training job, we create a `PyTorch` Estimator and provide all our information. SageMaker takes care of starting and managing all the required EC2 instances for us, uploads the provided scripts, and downloads the data from our S3 bucket into the container at `/opt/ml/input/data`. Then, it starts the training job.

Note that SageMaker by default uses the latest [AWS Deep Learning Container (DLC)](https://aws.amazon.com/machine-learning/containers/), so you can leave `use_ecr_image` set to `False` unless you want to use your own custom image built from a DLC. Also note that if you use FSx, you will need to set `use_fsx` to `True` and use the same subnet and security group as the file system in `subnet_config` and `security_group_config`.

In [None]:
import time

job_name = f'huggingface-fsdp-{time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())}'
use_ecr_image = False
use_fsx = False
kwargs = {}

if use_ecr_image:
    ecr_image = "<ECR_IMAGE_URI>"
    kwargs["image_uri"] = ecr_image

if use_fsx:
    subnet_config = ["<SUBNET_CONFIG_ID>"]
    security_group_config = ["<SECURITY_GROUP_CONFIG>"]
    kwargs["subnets"] = subnet_config
    kwargs["security_group_ids"] = security_group_config

# hyperparameters, which are passed into the training job
hyperparameters = {
    "model_id": model_id,  # model id from huggingface.co/models
    "dataset_path": "/opt/ml/input/data/train",  # path where sagemaker will save training dataset
    "valid_path": "/opt/ml/input/data/valid",
    "gradient_checkpointing": True,  # enable gradient checkpointing
    "bf16": True,  # enable mixed precision training
    "optimizer": "adamw_torch",  # optimizer
    "per_device_train_batch_size": 1,  # batch size per device during training
    "epochs": 1,  # number of epochs to train
    "fsdp": '"full_shard auto_wrap"',  # fully sharded data parallelism
    "cache_dir": "/opt/ml/sagemaker/warmpoolcache",  # change this to /tmp if not using warmpools
    "max_steps": 30,
}

# estimator
estimator = PyTorch(
    entry_point="train.py",
    max_run=1800,
    job_name=job_name,
    role=role,
    framework_version="2.0.1",
    py_version="py310",
    source_dir="./scripts",
    instance_count=2,
    instance_type="ml.p4d.24xlarge",
    sagemaker_session=sagemaker_session,
    disable_output_compression=True,
    distribution={"torch_distributed": {"enabled": True}},
    keep_alive_period_in_seconds=600,
    hyperparameters=hyperparameters,
    **kwargs,
)
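
SageMaker passes the `hyperparameters` dictionary to the entry point as command-line arguments (`--key value`). The sketch below shows how a script could parse a few of them with `argparse`. It is an illustration only and not the actual argument handling in `scripts/train.py`; the `str2bool` helper and the default values are assumptions.

```python
import argparse


def str2bool(value):
    # Boolean hyperparameters arrive as strings such as "True" or "False".
    return str(value).lower() in ("true", "1", "yes")


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="tiiuae/falcon-7b")
    parser.add_argument("--dataset_path", type=str, default="/opt/ml/input/data/train")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--max_steps", type=int, default=30)
    parser.add_argument("--bf16", type=str2bool, default=True)
    parser.add_argument("--gradient_checkpointing", type=str2bool, default=True)
    # Ignore arguments this sketch does not know about.
    args, _unknown = parser.parse_known_args(argv)
    return args


args = parse_args(["--model_id", "tiiuae/falcon-7b", "--bf16", "True", "--max_steps", "30"])
print(args.model_id, args.bf16, args.max_steps)  # tiiuae/falcon-7b True 30
```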

We can now start our training job by calling the `.fit()` method and passing our S3 path to the training script.

In [None]:
# define a data input dictionary with our uploaded S3 URIs
data = {"train": training_input_path}

# starting the train job with our uploaded datasets as input
estimator.fit(data, wait=True)
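
Each key of the `data` dictionary becomes an input channel. SageMaker places the channel's data under `/opt/ml/input/data/<channel>` in the container and exposes that path through an `SM_CHANNEL_<CHANNEL>` environment variable. The helper below is a hypothetical illustration of that mapping and is not part of the SageMaker SDK.

```python
def channel_paths(channels):
    """Map channel names to (container directory, environment variable name)."""
    return {
        name: (f"/opt/ml/input/data/{name}", f"SM_CHANNEL_{name.upper()}")
        for name in channels
    }


# The bucket name is a placeholder; only the channel names matter for the mapping.
paths = channel_paths({"train": "s3://<bucket>/processed/data/"})
print(paths["train"])  # ('/opt/ml/input/data/train', 'SM_CHANNEL_TRAIN')
```

Only a `train` channel is passed in this notebook, so only `/opt/ml/input/data/train` is populated inside the container.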

## 4. Expected Output

After training begins, you should see output similar to the following:

```
4%|▍         | 1/25 [00:08<03:29,  8.72s/it]
8%|▊         | 2/25 [00:10<01:41,  4.39s/it]
12%|█▏        | 3/25 [00:11<01:06,  3.01s/it]
16%|█▌        | 4/25 [00:12<00:49,  2.35s/it]
20%|██        | 5/25 [00:14<00:39,  1.99s/it]
24%|██▍       | 6/25 [00:15<00:33,  1.77s/it]
28%|██▊       | 7/25 [00:16<00:29,  1.64s/it]
32%|███▏      | 8/25 [00:18<00:26,  1.55s/it]
36%|███▌      | 9/25 [00:19<00:23,  1.48s/it]
40%|████      | 10/25 [00:20<00:21,  1.44s/it]
44%|████▍     | 11/25 [00:22<00:19,  1.41s/it]
48%|████▊     | 12/25 [00:23<00:18,  1.40s/it]
52%|█████▏    | 13/25 [00:24<00:16,  1.38s/it]
56%|█████▌    | 14/25 [00:26<00:15,  1.37s/it]
60%|██████    | 15/25 [00:27<00:13,  1.37s/it]
64%|██████▍   | 16/25 [00:29<00:12,  1.36s/it]
68%|██████▊   | 17/25 [00:30<00:10,  1.36s/it]
72%|███████▏  | 18/25 [00:31<00:09,  1.35s/it]
76%|███████▌  | 19/25 [00:33<00:08,  1.35s/it]
80%|████████  | 20/25 [00:34<00:06,  1.35s/it]
84%|████████▍ | 21/25 [00:35<00:05,  1.35s/it]
88%|████████▊ | 22/25 [00:37<00:04,  1.35s/it]
92%|█████████▏| 23/25 [00:38<00:02,  1.35s/it]
96%|█████████▌| 24/25 [00:39<00:01,  1.35s/it]
100%|██████████| 25/25 [00:41<00:00,  1.35s/it]
100%|██████████| 25/25 [00:41<00:00,  1.65s/it]
******epoch=0: train_ppl=tensor(43260.7148, device='cuda:0') train_loss=tensor(10.6750, device='cuda:0')******
0it [00:00, ?it/s]
0it [00:00, ?it/s]
*******epoch=0: eval_ppl=tensor(nan, device='cuda:0') eval_loss=tensor(nan, device='cuda:0')*******
Training done!
```
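
The reported training perplexity is consistent with the exponential of the reported training loss, which is the usual definition of perplexity for a language model. A quick check with the two values printed in the log, assuming the loss is a mean cross-entropy in nats:

```python
import math

train_loss = 10.6750  # value printed in the log
reported_ppl = 43260.7148  # value printed in the log

ppl = math.exp(train_loss)
print(f"{ppl:.1f}")

# The remaining gap is likely due to the loss being rounded to four decimals in the log.
relative_error = abs(ppl - reported_ppl) / reported_ppl
print(relative_error < 1e-3)  # True
```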

## 5. Terminate the warm pool cluster if no longer needed

You can terminate the warm pool cluster once you have finished experimenting to reduce billed time.

In [None]:
sagemaker_session.update_training_job(
    estimator.latest_training_job.job_name, resource_config={"KeepAlivePeriodInSeconds": 0}
)

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/training|distributed_training|pytorch|data_parallel|fully_sharded_data_parallel|falcon|smddp_fsdp_example.ipynb)