# Train Bloom with Hugging Face Trainer + the SageMaker Model Parallelism Library with Sharded Data Parallelism

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

---

This notebook shows how to use the SMP Trainer as a drop-in replacement for the Hugging Face Transformers `Trainer`. This lets you enable sharded data parallelism with the SageMaker model parallelism (SMP) library and train a Bloom model on a text dataset.

The Bloom model was proposed by BigScience in the paper [BLOOM: A 176B-Parameter Open-Access Multilingual Language Model](https://arxiv.org/pdf/2211.05100.pdf). The original Bloom is a large transformer-based language model with 176 billion parameters. This notebook experiments with the 560-million-parameter version, `bigscience/bloom-560m`. It uses the [Hugging Face Transformers Bloom](https://bigscience.huggingface.co/blog/bloom) implementation with SageMaker model parallel integration.

Sharded data parallelism is a memory-saving distributed training technique that splits the state of a model (model parameters, gradients, and optimizer states) across the GPUs in a data parallel group. This has two main benefits: you can fit larger models, which would otherwise run out of memory with standard data parallelism, on fewer GPUs; or you can use the freed-up GPU memory to increase the batch size.
- To learn more about sharded data parallelism, see [Sharded Data Parallelism](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-extended-features-pytorch-sharded-data-parallelism.html).
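To make the memory savings concrete, the following back-of-the-envelope sketch estimates per-GPU model-state memory when parameters, gradients, and optimizer states are all split across `sdp_degree` GPUs. It assumes mixed-precision training with the Adam optimizer and borrows the ZeRO paper's accounting of roughly 16 bytes of model state per parameter; activations and temporary communication buffers are ignored, so treat the numbers as a lower bound rather than a measurement.

```python
# Rough per-GPU model-state memory under sharded data parallelism.
# Assumption (ZeRO-style mixed-precision Adam accounting, not from this notebook):
#   2 bytes fp16 parameter + 2 bytes fp16 gradient
#   + 12 bytes fp32 optimizer state (master weights, momentum, variance)
# Activations and communication buffers are NOT included.
BYTES_PER_PARAM = 2 + 2 + 12  # 16 bytes per parameter


def model_state_gb_per_gpu(num_params: int, sdp_degree: int) -> float:
    """Approximate model-state memory per GPU in GB (1 GB = 1e9 bytes)."""
    if sdp_degree < 1:
        raise ValueError("sdp_degree must be >= 1")
    # All three kinds of model state are split evenly across the sharding group.
    return num_params * BYTES_PER_PARAM / sdp_degree / 1e9


num_params = 560_000_000  # bloom-560m, rounded
for degree in (1, 2, 8):
    print(f"sdp_degree={degree}: ~{model_state_gb_per_gpu(num_params, degree):.2f} GB per GPU")
# sdp_degree=1: ~8.96 GB per GPU
# sdp_degree=2: ~4.48 GB per GPU
# sdp_degree=8: ~1.12 GB per GPU
```

Doubling `sdp_degree` halves this estimate, which is the headroom you can spend on a larger model or a larger batch.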

This notebook requires the following prerequisites:
- `run_clm.py`: The entry point script, which is the example training script for the SageMaker Hugging Face estimator. This script handles end-to-end training of the Bloom model.
- `requirements.txt`: Lists additional Python library dependencies that SageMaker installs automatically. This file must be in the same directory as your entry point script.
- `smp_trainer.py`: Inherits from the Hugging Face Transformers `Trainer` and adds support for sharded data parallelism through the SageMaker model parallelism library.

**Note**: To run this example training job, you must be in `us-west-2` because the container image used is located in this region. If your AWS Region is different from `us-west-2`, change the region code throughout this notebook accordingly.

### Additional Resources
If you are a new user of Amazon SageMaker, you may find the following helpful to learn more about SMP and using SageMaker with PyTorch.

- To learn more about the SageMaker model parallelism library, see [Model Parallel Distributed Training with SageMaker Distributed](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel.html).

- To learn more about using the SageMaker Python SDK with PyTorch, see [Using PyTorch with the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html).

- To learn more about launching a training job in Amazon SageMaker with your own training image, see [Use Your Own Training Algorithms](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html).

### Adapt your entrypoint script to use SMP Trainer
To switch your entry point script from the Hugging Face Trainer to the SMP Trainer, you only need to change the `Trainer` import.

Replace the following line:  
  `from transformers import Trainer`  
with the SMP Trainer version:  
  `from smp_trainer import SMPTrainer as Trainer`  

You can now follow the rest of this notebook for details on how to enable sharded data parallelism.

## Install and Upgrade Libraries

This notebook requires the SageMaker Python SDK v2. Run the following cell to install or upgrade it to the latest 2.x release.

**Note:** To finish applying the changes, you must restart the kernel.

In [None]:
# # run once, restart kernel, then comment out this cell
# # update sagemaker to the latest 2.x version
# ! pip3 install -qU pip
! pip3 install -qU "sagemaker>=2,<3"

# import IPython
# IPython.Application.instance().kernel.do_shutdown(True)

Import the SageMaker Python SDK and print its version to check that the upgrade succeeded.

In [None]:
import sagemaker

print(sagemaker.__version__)

## Amazon SageMaker Initialization

Throughout this example, you'll use a training script of the Bloom model and a text dataset.

Run the following cell to import SageMaker modules and retrieve information about your current SageMaker work environment: your AWS account ID, the AWS Region you are using to run the notebook, and the ARN of your Amazon SageMaker execution role.

In [None]:
%%time
import os

from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch
import boto3

role = (
    get_execution_role()
)  # provide a pre-existing role ARN as an alternative to creating a new role
print(f"SageMaker Execution Role:{role}")

client = boto3.client("sts")
account = client.get_caller_identity()["Account"]
print(f"AWS account:{account}")

session = boto3.session.Session()
region = session.region_name
print(f"AWS region:{region}")

sm_boto_client = boto3.client("sagemaker")
sagemaker_session = sagemaker.session.Session(boto_session=session)

# get default bucket
default_bucket = sagemaker_session.default_bucket()
print()
print("Default bucket for this session: ", default_bucket)

# You also need to specify an Amazon S3 bucket to store output data such as training artifacts.
# The following line uses the default S3 bucket paired with the current SageMaker session. You can modify this as needed.
s3_output_bucket = f"s3://{default_bucket}/output"
print(f"Your output data will be stored in: {s3_output_bucket}")

## Set Up Hyperparameters, Metric Definitions, and MPI Options
The following `hyperparameters` dictionary passes arguments to the training script (`run_clm.py`) when the training job is created. The model parallel configuration itself is passed separately through the estimator's `distribution` argument.

Note that `run_clm.py` has been modified to work with SageMaker. If you want to run your own script, add the relevant lines from `run_clm.py`; you can find them quickly by searching for `SageMaker Support`.

You can also add custom MPI flags. By default, `--mca btl_vader_single_copy_mechanism none` is set to suppress unnecessary logs.

Next, we add a base metric definition to upload the training metrics for SageMaker Experiments. You can also add custom metric definitions.

In [None]:
save_steps = 60  # Set the interval for saving checkpoints
max_steps = 100  # Set the total number of steps you want to run

hyperparameters = {
    "model_name_or_path": "bigscience/bloom-560m",
    "output_dir": "/opt/ml/checkpoints",
    "overwrite_output_dir": "",
    "learning_rate": 0.0002,
    "do_train": True,
    "do_eval": True,
    "save_steps": save_steps,
    "max_steps": max_steps,
    "max_eval_samples": 50,
    "preprocessing_num_workers": 1,
    "gradient_accumulation_steps": 2,
    "eval_accumulation_steps": 2,
    "logging_steps": 1,
    "dataloader_drop_last": True,
}
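SageMaker script mode hands these hyperparameters to the entry point script as command-line arguments, which `run_clm.py` then parses. The sketch below approximates that conversion with a hypothetical helper, `to_cli_args`, so you can see roughly what the script receives. It assumes that an empty string value (as used for `overwrite_output_dir`) stands for a bare flag; SageMaker's actual serialization of booleans and empty values may differ.

```python
def to_cli_args(hparams: dict) -> list:
    """Hypothetical approximation of how a hyperparameters dict becomes argv.

    Assumption: each key becomes "--key", non-empty values follow as a string
    token, and an empty string produces a bare flag.
    """
    args = []
    for key, value in hparams.items():
        args.append(f"--{key}")
        if value != "":
            # Non-empty values are passed as a separate string token
            args.append(str(value))
    return args


example = {
    "model_name_or_path": "bigscience/bloom-560m",
    "do_train": True,
    "overwrite_output_dir": "",
}
print(to_cli_args(example))
# ['--model_name_or_path', 'bigscience/bloom-560m', '--do_train', 'True', '--overwrite_output_dir']
```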

## Specify a Hugging Face Dataset

In this step, you specify the Hugging Face dataset that you want to train on. This example uses the `sst2` configuration of the `glue` dataset. Note that larger datasets take longer to download and process.

In [None]:
# You can use any dataset available from Hugging Face
# Modify these parameters as needed
dataset_name = "glue"

if dataset_name == "glue":
    hyperparameters["dataset_name"] = "glue"
    hyperparameters["dataset_config_name"] = "sst2"
else:
    raise RuntimeError("Unknown HuggingFace dataset")
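If you want to support more than one dataset, a lookup table keeps the dataset selection in one place instead of a growing `if`/`else` chain. In this minimal sketch, `SUPPORTED_DATASETS` and `set_dataset` are hypothetical names, and the `wikitext` / `wikitext-2-raw-v1` entry is an added example; check that `run_clm.py` can preprocess any dataset you add before launching a job.

```python
# Hypothetical mapping of dataset_name -> dataset_config_name.
# The "wikitext" entry is an illustrative assumption, not part of the original notebook.
SUPPORTED_DATASETS = {
    "glue": "sst2",
    "wikitext": "wikitext-2-raw-v1",
}


def set_dataset(hparams: dict, name: str) -> dict:
    """Add dataset_name and dataset_config_name to hparams, or fail loudly."""
    if name not in SUPPORTED_DATASETS:
        raise RuntimeError(f"Unknown HuggingFace dataset: {name}")
    hparams["dataset_name"] = name
    hparams["dataset_config_name"] = SUPPORTED_DATASETS[name]
    return hparams


demo_hyperparameters = {}
set_dataset(demo_hyperparameters, "glue")
print(demo_hyperparameters)
# {'dataset_name': 'glue', 'dataset_config_name': 'sst2'}
```

In the notebook itself you would call `set_dataset(hyperparameters, dataset_name)` in place of the `if`/`else` block.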

Set the model configuration below or define your own.

You can also set training parameters here, such as the sharded data parallel degree, batch size, and fp16. These settings determine whether your model fits on your instance configuration.

For more information on these parameters and how to use them, please visit [SageMaker Distributed Training](https://docs.aws.amazon.com/sagemaker/latest/dg/distributed-training.html).

Note: If you train a different model size, you may need to adjust parameters such as `sdp_degree` and `per_device_train_batch_size`.

In [None]:
model_config = "bloom-560m"

if model_config == "bloom-560m":
    # 560M parameters
    hyperparameters["per_device_train_batch_size"] = 4
    hyperparameters["per_device_eval_batch_size"] = 2
    sdp_degree = 2
    microbatches = 1
    fp16 = False
    hyperparameters["fp16"] = fp16
    prescaled_batch = True
    shard_optimizer_state = False
    pp_degree = 1
else:
    raise RuntimeError("Unknown model config")

## Specify Essential Parameters for a SageMaker Training Job

Next, you will use the [SageMaker Estimator API](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html) to define a SageMaker training job, passing values through the following parameters, such as the training job name, the number of EC2 instances, the instance type, and the size of the volume attached to the instances.

* `instance_count`
* `instance_type`
* `volume_size`
* `base_job_name`

### Update the Type and Number of EC2 Instances to Use

The instance type and the number of instances you specify to the `instance_type` and `instance_count` parameters, respectively, will determine the total number of GPUs (world size).

$$ \text{world size} = (\text{number of GPUs per instance}) \times (\text{number of instances}) $$

In [None]:
# Set the instance_type here
instance_type = "ml.p4d.24xlarge"

# Set to the number of instances you want to use
# bloom-560m needs at least one p4d instance
instance_count = 1

# Set to the number of GPUs on that instance
# ml.p3dn.24xlarge and ml.p4d.24xlarge instances have 8 GPUs each
processes_per_host = 8
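The world size formula above can be checked before launching a job. The following sketch computes the world size and verifies that the sharded data parallel degree fits into it. It hard-codes the notebook's defaults (8 GPUs per instance, one instance, `sdp_degree = 2`) so it runs on its own. The divisibility requirement is an assumption for this setup (no tensor parallelism, `pp_degree = 1`); confirm the exact constraint in the SMP documentation.

```python
def world_size(gpus_per_instance: int, num_instances: int) -> int:
    """World size = GPUs per instance x number of instances."""
    return gpus_per_instance * num_instances


def check_sdp_degree(sdp_degree: int, world: int) -> int:
    """Return the number of sharded data parallel groups, or raise if sdp_degree does not fit.

    Assumption: sdp_degree must evenly divide the world size.
    """
    if sdp_degree < 1 or world % sdp_degree != 0:
        raise ValueError(f"sdp_degree={sdp_degree} does not evenly divide world size {world}")
    return world // sdp_degree


world = world_size(gpus_per_instance=8, num_instances=1)  # one ml.p4d.24xlarge
groups = check_sdp_degree(sdp_degree=2, world=world)
print(f"world size = {world}, sharded data parallel groups = {groups}")
# world size = 8, sharded data parallel groups = 4
```

In the notebook, you could call `check_sdp_degree(sdp_degree, world_size(processes_per_host, instance_count))` after setting both cells.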

To look up the number of GPUs for different instance types, see [Amazon EC2 Instance Types](https://aws.amazon.com/ec2/instance-types/). The **Accelerated Computing** section lists the general purpose GPU instances. Each EC2 instance type has a SageMaker counterpart with an `ml.` prefix; for example, `p4d.24xlarge` corresponds to `ml.p4d.24xlarge` in SageMaker.
For SageMaker supported `ml` instances and cost information, see [Amazon SageMaker Pricing](https://aws.amazon.com/sagemaker/pricing/).

### Attach an EBS Volume to the Training Instance
The volume size you specify in `volume_size` must be larger than your input data size. In this example, the volume size is set to 500 GB.

In [None]:
volume_size = 500

### Define the SMP Configuration and a Base Job Name

In [None]:
SM_HP_MP_PARAMETERS = {
    "microbatches": microbatches,
    "optimize": "speed",
    "pipeline": "interleaved",
    "placement_strategy": "cluster",
    "partitions": pp_degree,
    "prescaled_batch": prescaled_batch,
    "shard_optimizer_state": shard_optimizer_state,
    "fp16": fp16,
    "sharded_data_parallel_degree": sdp_degree,  # Add this to activate sharded data parallelism
}

machine_str = instance_type.split(".")[1] + instance_type.split(".")[2][:3]

base_job_name = f'smp-trainer-{model_config}-{machine_str}-sdp{sdp_degree}-bs{hyperparameters["per_device_train_batch_size"]}'
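The cell above builds `base_job_name` from the model and instance settings. The sketch below wraps the same string construction in a hypothetical helper, `build_base_job_name`, and validates the result against the SageMaker training job naming rule before you submit. The pattern and the 63-character limit are assumptions based on the CreateTrainingJob API constraints; note that the SageMaker Python SDK appends a timestamp to `base_job_name` to form the final job name.

```python
import re

# Assumed SageMaker training job naming rule: letters, digits, and hyphens,
# starting and ending with a letter or digit, at most 63 characters.
JOB_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$")
MAX_JOB_NAME_LENGTH = 63


def build_base_job_name(model_config, instance_type, sdp_degree, batch_size):
    """Hypothetical helper reproducing the base_job_name logic from the cell above."""
    # "ml.p4d.24xlarge" -> "p4d" + "24x" -> "p4d24x"
    parts = instance_type.split(".")
    machine_str = parts[1] + parts[2][:3]
    name = f"smp-trainer-{model_config}-{machine_str}-sdp{sdp_degree}-bs{batch_size}"
    if len(name) > MAX_JOB_NAME_LENGTH or not JOB_NAME_PATTERN.match(name):
        raise ValueError(f"Invalid SageMaker job name: {name!r}")
    return name


print(build_base_job_name("bloom-560m", "ml.p4d.24xlarge", 2, 4))
# smp-trainer-bloom-560m-p4d24x-sdp2-bs4
```

Underscores are a common mistake here: a model config such as `bloom_560m` would fail the check.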

In [None]:
mpioptions = "-x NCCL_DEBUG=WARN -x SMDEBUG_LOG_LEVEL=ERROR "
if instance_type in ["ml.p3dn.24xlarge", "ml.p4d.24xlarge"]:
    mpioptions += "-x FI_EFA_USE_DEVICE_RDMA=1 -x FI_PROVIDER=efa -x RDMAV_FORK_SAFE=1 "
if SM_HP_MP_PARAMETERS["partitions"] > 1:
    mpioptions += "-x SMP_ENABLE_CROSS_NODE_D2D=1 "
# Uncomment the following line if you want to save the full model checkpoint
# instead of the partial model checkpoint.
# Setting the variable to any value enables the full-model save logic.
# mpioptions += "-x HF_TRAINER_SMP_SDP_SAVE_FULL_MODEL=1 "

metric_definitions = [
    {"Name": "base_metric", "Regex": "<><><><><><>"}
]  # Add your custom metric definitions
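The `base_metric` entry above is only a placeholder. The following sketch shows what custom metric definitions could look like and tests the regexes locally against a sample log line. It assumes `run_clm.py` prints Hugging Face Trainer log dictionaries such as `{'loss': 2.8431, 'learning_rate': 0.0002, 'epoch': 0.02}`; the metric names and the sample line are hypothetical, so adjust the regexes to whatever your script actually writes to stdout. SageMaker uses the first capture group of each `Regex` as the metric value.

```python
import re

# Hypothetical metric definitions, assuming Trainer-style log dictionaries on stdout.
custom_metric_definitions = [
    {"Name": "train:loss", "Regex": r"'loss': ([0-9.eE+-]+)"},
    {"Name": "eval:loss", "Regex": r"'eval_loss': ([0-9.eE+-]+)"},
]

# Check the regexes locally against a sample log line before launching a job.
sample = "{'loss': 2.8431, 'learning_rate': 0.0002, 'epoch': 0.02}"
for metric in custom_metric_definitions:
    match = re.search(metric["Regex"], sample)
    print(metric["Name"], match.group(1) if match else None)
# train:loss 2.8431
# eval:loss None
```

Once the regexes match your logs, you could append them with `metric_definitions += custom_metric_definitions`.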

### Resume Training from a Previous Checkpoint

Here, you can choose to resume training from a previous checkpoint saved with the Hugging Face Trainer.
Set `resume_from_checkpoint` to `True` and specify the bucket where the checkpoint is stored. For convenience, this example uses the same bucket to load checkpoints and to save output artifacts, but you can set your own. You can also choose whether to load a partial or a full checkpoint; a full checkpoint is only available if it was saved with the `HF_TRAINER_SMP_SDP_SAVE_FULL_MODEL` option enabled.

Note: The checkpoint path (`checkpoint_s3_uri`) is not unique per job, so modify it as needed for different runs.

In [None]:
resume_from_checkpoint = False
# Set `resume_from_full_checkpoint` to True to load a full checkpoint instead of a partial one.
# Note: This requires that the HF_TRAINER_SMP_SDP_SAVE_FULL_MODEL MPI option above
#   was uncommented when the checkpoint was saved.
resume_from_full_checkpoint = False


# We label our job with the model configuration and the number of nodes
job_name = f"{model_config}_nodes-{instance_count}"
# Here, we use the same bucket for both checkpoints and outputs
checkpoint_bucket = s3_output_bucket
# If you want to resume training, set checkpoint_s3_uri to the same checkpoint_s3_uri path as a previous job.
checkpoint_s3_uri = f"{checkpoint_bucket}/{job_name}/checkpoints"

# The previous checkpoint to load must have the same model config.
if resume_from_checkpoint:
    # The checkpoint step you want to resume training from.
    # Here, we set it to the first checkpoint saved, but you can set it to any.
    checkpoint_step = save_steps
    checkpoint_dir = f"/opt/ml/checkpoints/checkpoint-{checkpoint_step}"
    hyperparameters["resume_from_checkpoint"] = checkpoint_dir
    if resume_from_full_checkpoint:
        hyperparameters["load_full"] = True
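Instead of hard-coding `checkpoint_step`, you may want to resume from the most recent checkpoint. The following sketch uses a hypothetical helper, `latest_checkpoint_step`, that picks the highest step from directory names following the `checkpoint-<step>` pattern used in `checkpoint_dir` above. How you obtain the directory names (for example, by listing `checkpoint_s3_uri`) is left out and is up to you.

```python
import re

# Matches names like "checkpoint-60" or "checkpoint-60/" (trailing slash as in S3 prefixes).
CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)/?$")


def latest_checkpoint_step(dir_names):
    """Return the largest step among names like 'checkpoint-<step>', or None if there are none."""
    steps = []
    for name in dir_names:
        match = CHECKPOINT_RE.match(name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps) if steps else None


step = latest_checkpoint_step(["checkpoint-60", "checkpoint-120/", "runs"])
print(step)  # 120
if step is not None:
    print(f"/opt/ml/checkpoints/checkpoint-{step}")
```

Comparing steps as integers matters: sorted as strings, `checkpoint-60` would come after `checkpoint-120`.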

### Create a SageMaker PyTorch Estimator

The following cell constructs a `PyTorch` estimator using the parameters defined above. To see how the SageMaker model parallelism library is applied in the training script, see `run_clm.py` and `smp_trainer.py`. This example uses PyTorch 1.13.1 with Transformers 4.21.0.

In [None]:
kwargs = {}

smp_estimator = PyTorch(
    entry_point="run_clm.py",
    source_dir=os.getcwd(),  # copies your current working directory to S3 for SageMaker
    role=role,
    instance_type=instance_type,
    volume_size=volume_size,
    instance_count=instance_count,
    sagemaker_session=sagemaker_session,
    distribution={
        "mpi": {
            "enabled": True,
            "processes_per_host": processes_per_host,
            "custom_mpi_options": mpioptions,
        },
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                "parameters": {
                    "ddp": True,
                    # `partitions` is required by the current SageMaker Python SDK, so it must be passed;
                    # it maps to the same setting as `pipeline_parallel_degree`.
                    "partitions": SM_HP_MP_PARAMETERS["partitions"],
                    "microbatches": SM_HP_MP_PARAMETERS["microbatches"],
                    "shard_optimizer_state": SM_HP_MP_PARAMETERS["shard_optimizer_state"],
                    "prescaled_batch": SM_HP_MP_PARAMETERS["prescaled_batch"],
                    "fp16": True,
                    "optimize": SM_HP_MP_PARAMETERS["optimize"],
                    "auto_partition": True,
                    "default_partition": 0,
                    "sharded_data_parallel_degree": SM_HP_MP_PARAMETERS[
                        "sharded_data_parallel_degree"
                    ],  # Add this to activate sharded data parallelism
                    "sdp_reduce_bucket_size": int(5e8),  # Optional
                    "sdp_param_persistence_threshold": int(1e6),  # Optional
                    "sdp_max_live_parameters": int(1e9),  # Optional
                    "sdp_gradient_clipping": 1.0,  # Optional
                },
            }
        },
    },
    py_version="py39",
    output_path=s3_output_bucket,
    checkpoint_s3_uri=checkpoint_s3_uri,
    metric_definitions=metric_definitions,
    hyperparameters=hyperparameters,
    framework_version="1.13.1",
    debugger_hook_config=False,
    disable_profiler=True,
    base_job_name=base_job_name,
    **kwargs,
)

Finally, run the estimator to launch the SageMaker training job of the Bloom model.

In [None]:
smp_estimator.fit(
    logs=True,
)

## Accessing the Training Logs

You can access the training logs using [Amazon CloudWatch](https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/WhatIsCloudWatch.html). Make sure to look at the logs of **algo-1**, which is the main node whose output stream has the entire training job logs.

You can use CloudWatch to track SageMaker GPU and memory utilization during training and inference. To view the metrics and logs that SageMaker writes to CloudWatch, see **Processing Job, Training Job, Batch Transform Job, and Endpoint Instance Metrics** in [Monitor Amazon SageMaker with Amazon CloudWatch](https://docs.aws.amazon.com/sagemaker/latest/dg/monitoring-cloudwatch.html).

If you are a new user of Amazon CloudWatch, see [Getting Started with Amazon CloudWatch](https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/GettingStarted.html).

For additional information about monitoring and analyzing Amazon SageMaker training jobs, see [Monitor and Analyze Training Jobs Using Metrics](https://docs.aws.amazon.com/sagemaker/latest/dg/training-metrics.html).

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/training|distributed_training|pytorch|model_parallel|bloom_smp_trainer|submit_smp_trainer.ipynb)
