3 changes: 1 addition & 2 deletions doc/frameworks/pytorch/using_pytorch.rst
@@ -212,7 +212,7 @@ with the ``pytorchddp`` option as the distribution strategy.
 .. note::

    This PyTorch DDP support is available
-   in the SageMaker PyTorch Deep Learning Containers v1.12 and later.
+   in the SageMaker PyTorch Deep Learning Containers v1.11 and later.

 Adapt Your Training Script
 --------------------------
@@ -238,7 +238,6 @@ but you can also overwrite them.

 **Supported backends:**

-- ``gloo`` and ``tcp`` for CPU instances
 - ``gloo`` and ``nccl`` for GPU instances

 Launching a Distributed Training Job
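For context, a minimal sketch of the usage this doc section describes (not part of the diff; the entry point, role, and instance settings are hypothetical):

from sagemaker.pytorch import PyTorch

# Hypothetical script and role; any GPU instance type supported by the
# SageMaker PyTorch DLC v1.11+ works with the "pytorchddp" strategy.
estimator = PyTorch(
    entry_point="train.py",
    role="SageMakerRole",
    framework_version="1.12.1",
    py_version="py38",
    instance_count=2,
    instance_type="ml.p4d.24xlarge",
    distribution={"pytorchddp": {"enabled": True}},
)
estimator.fit()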
62 changes: 58 additions & 4 deletions src/sagemaker/fw_utils.py
@@ -133,9 +133,6 @@
 }

 PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [
-    "1.10",
-    "1.10.0",
-    "1.10.2",
     "1.11",
     "1.11.0",
     "1.12",
@@ -149,6 +146,8 @@

 TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]

+SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1"]
+SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES = ["ml.p4d.24xlarge"]

 SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]

@@ -1060,7 +1059,6 @@ def validate_torch_distributed_distribution(
     if not torch_distributed_enabled:
         # Distribution strategy other than torch_distributed is selected
         return
-
     err_msg = ""
     if not image_uri:
         # ignore framework_version and py_version if image_uri is set
@@ -1099,6 +1097,62 @@
         raise ValueError(err_msg)


+def validate_smddp_collectives_support(
+    framework_version, py_version, image_uri, instance_type, instance_count
+):
+    """Check if the SMDDP collectives backend is supported for the current invocation.
+
+    Args:
+        framework_version (str): A string representing the framework version selected.
+        py_version (str): A string representing the python version selected.
+        image_uri (str): A string representing a Docker image URI.
+        instance_type (str): SageMaker instance type.
+        instance_count (int): Number of training instances to use.
+
+    Returns False if:
+        `instance_type` is not in SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES or
+        `py_version` is not python3 or
+        `framework_version` is not in SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS or
+        `instance_count` is not greater than 1
+    """
+    err_msg = ""
+    if not image_uri:
+        # ignore framework_version and py_version if image_uri is set;
+        # if image_uri is not set, then both are mandatory
+        if framework_version not in SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS:
+            err_msg += (
+                f"Provided framework_version {framework_version} is not supported. "
+                "Please specify one of the supported framework versions:"
+                f" {SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS}.\n"
+            )
+        if "py3" not in py_version:
+            err_msg += (
+                f"Provided py_version {py_version} is not supported. "
+                "Please specify py_version>=py3.\n"
+            )
+    if instance_type not in SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES:
+        err_msg += (
+            f"Provided instance_type {instance_type} is not supported. "
+            "Please specify one of the supported instance types:"
+            f" {SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES}.\n"
+        )
+    if instance_count == 1:
+        # Communication backend auto is not supported for single-node jobs
+        err_msg += (
+            "SMDDP Collective backend is not supported for single-node jobs. "
+            "Please increase instance_count to be greater than 1.\n"
+        )
+    if not err_msg:
+        return True
+    logger.warning(
+        "The system is not compatible or not configured to run SMDDP collectives optimized"
+        " for AWS infrastructure.\n%s",
+        err_msg,
+    )
+    logger.warning("Continuing model training with default NCCL communication backend.\n")
+    return False
+
+
 def python_deprecation_warning(framework, latest_supported_version):
     """Placeholder docstring"""
     return PYTHON_2_DEPRECATION_WARNING.format(
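To make the new gating concrete, a short sketch of how the helper behaves (not part of the diff; it mirrors the unit tests added at the bottom of this PR):

from sagemaker import fw_utils

# All checks pass: supported framework version, py3, ml.p4d.24xlarge, multi-node.
assert fw_utils.validate_smddp_collectives_support(
    framework_version="1.13.1",
    py_version="py3",
    image_uri=None,
    instance_type="ml.p4d.24xlarge",
    instance_count=2,
)

# A single-node job logs a warning and returns False, i.e. training falls
# back to the default NCCL communication backend.
assert not fw_utils.validate_smddp_collectives_support(
    framework_version="1.13.1",
    py_version="py3",
    image_uri=None,
    instance_type="ml.p4d.24xlarge",
    instance_count=1,
)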
48 changes: 46 additions & 2 deletions src/sagemaker/pytorch/estimator.py
@@ -25,6 +25,7 @@
     python_deprecation_warning,
     validate_version_or_image_args,
     validate_distribution,
+    validate_smddp_collectives_support,
 )
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch.model import PyTorchModel
@@ -42,6 +43,9 @@ class PyTorch(Framework):
     LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
     LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
     INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
+    COMMUNICATION_BACKEND_ENV_NAME = "sagemaker_communication_backend"
+    COMMUNICATION_BACKEND_AUTO = "auto"
+    COMMUNICATION_BACKEND_NCCL = "nccl"

     def __init__(
         self,
@@ -308,15 +312,18 @@ def _pytorch_distribution_configuration(self, distribution):
         pytorch_ddp_enabled = False
         torch_distributed_enabled = False

-        if "pytorchddp" in distribution:
-            pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
+        pytorch_ddp_dict = distribution.get("pytorchddp")
+        if pytorch_ddp_dict:
+            pytorch_ddp_enabled = pytorch_ddp_dict.get("enabled", False)
         elif "torch_distributed" in distribution:
             torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)

         if pytorch_ddp_enabled:
             distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
             if self.instance_type is not None:
                 distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
+            comm_backend = self._get_communication_backend(pytorch_ddp_dict)
+            distribution_config[self.COMMUNICATION_BACKEND_ENV_NAME] = comm_backend
         elif torch_distributed_enabled:
             distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
             if self.instance_type is not None:
@@ -326,6 +333,43 @@

         return distribution_config

+    def _get_communication_backend(self, pytorch_ddp_dict):
+        """Select the collective communication backend for the current training job.
+
+        Return `nccl` if:
+        * SMDDP collectives are not supported OR
+        * communication_options is specified and backend is set to `nccl`.
+
+        Return `auto` if:
+        * communication_options is not specified OR
+        * communication_options is specified but backend is not set OR
+        * communication_options is specified and backend is set to `auto`.
+
+        Args:
+            pytorch_ddp_dict (dict): A dictionary with options for pytorchddp distribution.
+        Returns:
+            A string, either `auto` or `nccl`, naming the communication backend to use.
+        """
+        is_smddp_coll_backend_supported = validate_smddp_collectives_support(
+            self.framework_version,
+            self.py_version,
+            self.image_uri,
+            self.instance_type,
+            self.instance_count,
+        )
+        if not is_smddp_coll_backend_supported:
+            return self.COMMUNICATION_BACKEND_NCCL
+
+        comm_options = pytorch_ddp_dict.get("communication_options")
+        if not comm_options:
+            return self.COMMUNICATION_BACKEND_AUTO
+
+        comm_backend = comm_options.get("backend")
+        if not comm_backend or comm_backend == self.COMMUNICATION_BACKEND_AUTO:
+            return self.COMMUNICATION_BACKEND_AUTO
+
+        return self.COMMUNICATION_BACKEND_NCCL
+
     def hyperparameters(self):
         """Return hyperparameters used by your custom PyTorch code during model training."""
         hyperparameters = super(PyTorch, self).hyperparameters()
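Taken together, these estimator changes let a user opt into SMDDP collectives via communication_options. A hedged sketch of the intended end-to-end configuration (entry point and role are placeholders):

from sagemaker.pytorch import PyTorch

# Backend "auto" requests SMDDP collectives; on any unsupported configuration
# the SDK warns and resolves the backend to "nccl" instead.
estimator = PyTorch(
    entry_point="train.py",
    role="SageMakerRole",
    framework_version="1.13.1",
    py_version="py39",
    instance_count=2,
    instance_type="ml.p4d.24xlarge",
    distribution={
        "pytorchddp": {
            "enabled": True,
            "communication_options": {"backend": "auto"},
        }
    },
)
# The resolved backend is surfaced to the training job via the
# "sagemaker_communication_backend" setting defined above.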
4 changes: 1 addition & 3 deletions tests/conftest.py
@@ -476,9 +476,7 @@ def pytorch_ddp_py_version():
     return "py3"


-@pytest.fixture(
-    scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"]
-)
+@pytest.fixture(scope="module", params=["1.11", "1.11.0", "1.12", "1.12.0", "1.12.1"])
 def pytorch_ddp_framework_version(request):
     return request.param

30 changes: 29 additions & 1 deletion tests/integ/test_pytorchddp.py
@@ -36,7 +36,7 @@ def test_pytorchddp_pt_mnist(
     pytorch_ddp_framework_version,
     pytorch_ddp_py_version,
 ):
-    job_name = sagemaker.utils.unique_name_from_base("pt-pytorch-ddp")
+    job_name = sagemaker.utils.unique_name_from_base("pytorch-ddp")
     estimator = PyTorch(
         entry_point="mnist_pt.py",
         role="SageMakerRole",
@@ -46,6 +46,34 @@
         sagemaker_session=sagemaker_session,
         framework_version=pytorch_ddp_framework_version,
         py_version=pytorch_ddp_py_version,
+        distribution={
+            "pytorchddp": {"enabled": True, "communication_options": {"backend": "nccl"}}
+        },
     )

     with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
         estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name)


+@pytest.mark.skip(
+    reason="This test is skipped for now due to an ML capacity error. "
+    "This test should be re-enabled later."
+)
+@pytest.mark.skipif(
+    integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS,
+    reason="Only allow this test to run in IAD and CMH to limit usage of p3.16xlarge",
+)
+def test_pytorchddp_pt_mnist_smddp_coll(sagemaker_session):
+    job_name = sagemaker.utils.unique_name_from_base("pytorch-ddp-smddp")
+    estimator = PyTorch(
+        entry_point="mnist_pt.py",
+        role="SageMakerRole",
+        source_dir=pytorchddp_dir,
+        instance_count=2,
+        instance_type="ml.p4d.24xlarge",
+        sagemaker_session=sagemaker_session,
+        framework_version="1.12.1",
+        py_version="py3",
+        distribution={"pytorchddp": {"enabled": True}},
+    )
80 changes: 68 additions & 12 deletions tests/unit/test_fw_utils.py
@@ -958,16 +958,7 @@ def test_validate_pytorchddp_not_raises():
     )
     # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions
     pytorchddp_enabled = {"pytorchddp": {"enabled": True}}
-    pytorchddp_supported_fw_versions = [
-        "1.10",
-        "1.10.0",
-        "1.10.2",
-        "1.11",
-        "1.11.0",
-        "1.12",
-        "1.12.0",
-        "1.12.1",
-    ]
+    pytorchddp_supported_fw_versions = fw_utils.PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
     for framework_version in pytorchddp_supported_fw_versions:
         fw_utils.validate_pytorch_distribution(
             distribution=pytorchddp_enabled,
@@ -985,7 +976,7 @@ def test_validate_pytorchddp_raises():
     fw_utils.validate_pytorch_distribution(
         distribution=pytorchddp_enabled,
         framework_name="pytorch",
-        framework_version="1.8",
+        framework_version="1.10",
         py_version="py3",
         image_uri=None,
     )
@@ -1002,7 +993,6 @@


 def test_validate_torch_distributed_not_raises():
-
     # Case 1: Framework is PyTorch, but distribution is not torch_distributed
     torch_distributed_disabled = {"torch_distributed": {"enabled": False}}
     fw_utils.validate_torch_distributed_distribution(
@@ -1099,3 +1089,69 @@ def test_instance_type_supports_profiler():
     assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True
     assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False
     assert fw_utils._instance_type_supports_profiler("local") is False
+
+
+def test_validate_smddp_collectives_supported():
+    # Framework is PyTorch, Distribution is PyTorchDDP enabled, all supported parameters
+    smddp_coll_supported_fw_versions = fw_utils.SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS
+    for framework_version in smddp_coll_supported_fw_versions:
+        assert (
+            fw_utils.validate_smddp_collectives_support(
+                framework_version=framework_version,
+                py_version="py3",
+                image_uri=None,
+                instance_type="ml.p4d.24xlarge",
+                instance_count=2,
+            )
+            is True
+        )
+
+
+def test_validate_smddp_collectives_not_supported():
+    # Case 1: Unsupported framework version
+    assert (
+        fw_utils.validate_smddp_collectives_support(
+            framework_version="1.10",
+            py_version="py3",
+            image_uri=None,
+            instance_type="ml.p4d.24xlarge",
+            instance_count=2,
+        )
+        is False
+    )
+
+    # Case 2: Unsupported Py version
+    assert (
+        fw_utils.validate_smddp_collectives_support(
+            framework_version="1.10",
+            py_version="py2",
+            image_uri=None,
+            instance_type="ml.p4d.24xlarge",
+            instance_count=2,
+        )
+        is False
+    )
+
+    # Case 3: Unsupported Instance Type
+    assert (
+        fw_utils.validate_smddp_collectives_support(
+            framework_version="1.10",
+            py_version="py2",
+            image_uri=None,
+            instance_type="ml.p3.16xlarge",
+            instance_count=2,
+        )
+        is False
+    )
+
+    # Case 4: Unsupported Instance Count
+    assert (
+        fw_utils.validate_smddp_collectives_support(
+            framework_version="1.10",
+            py_version="py2",
+            image_uri=None,
+            instance_type="ml.p4d.24xlarge",
+            instance_count=1,
+        )
+        is False
+    )