diff --git a/doc/frameworks/pytorch/using_pytorch.rst b/doc/frameworks/pytorch/using_pytorch.rst
index f56085f756..d479610e71 100644
--- a/doc/frameworks/pytorch/using_pytorch.rst
+++ b/doc/frameworks/pytorch/using_pytorch.rst
@@ -212,7 +212,7 @@ with the ``pytorchddp`` option as the distribution strategy.
 .. note::
 
     This PyTorch DDP support is available
-    in the SageMaker PyTorch Deep Learning Containers v1.12 and later.
+    in the SageMaker PyTorch Deep Learning Containers v1.11 and later.
 
 Adapt Your Training Script
 --------------------------
@@ -238,7 +238,6 @@ but you can also overwrite them.
 
 **Supported backends:**
 
-- ``gloo`` and ``tcp`` for CPU instances
 - ``gloo`` and ``nccl`` for GPU instances
 
 Launching a Distributed Training Job
diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py
index 59204fbc99..4532989886 100644
--- a/src/sagemaker/fw_utils.py
+++ b/src/sagemaker/fw_utils.py
@@ -133,9 +133,6 @@
 }
 
 PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [
-    "1.10",
-    "1.10.0",
-    "1.10.2",
     "1.11",
     "1.11.0",
     "1.12",
@@ -149,6 +146,8 @@
 
 TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
 
+SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1"]
+SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES = ["ml.p4d.24xlarge"]
 
 SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
 
@@ -1099,6 +1098,62 @@ def validate_torch_distributed_distribution(
         raise ValueError(err_msg)
 
 
+def validate_smddp_collectives_support(
+    framework_version, py_version, image_uri, instance_type, instance_count
+):
+    """Check if the SMDDP collectives backend is supported for the current invocation.
+
+    Args:
+        framework_version (str): A string representing the framework version selected.
+        py_version (str): A string representing the python version selected.
+        image_uri (str): A string representing a Docker image URI.
+        instance_type (str): SageMaker instance type.
+        instance_count (int): Number of training instances to use.
+
+    Returns False if:
+        `instance_type` is not in SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES or
+        `py_version` is not python3 or
+        `framework_version` is not in SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS or
+        `instance_count` is not greater than 1
+    """
+    err_msg = ""
+    if not image_uri:
+        # ignore framework_version and py_version if image_uri is set
+        # in case image_uri is not set, then both are mandatory
+        if framework_version not in SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS:
+            err_msg += (
+                f"Provided framework_version {framework_version} is not supported. "
+                "Please specify one of the supported framework versions:"
+                f" {SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS}.\n"
+            )
+        if "py3" not in py_version:
+            err_msg += (
+                f"Provided py_version {py_version} is not supported. "
+                "Please specify py_version>=py3.\n"
+            )
+    if instance_type not in SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES:
+        err_msg += (
+            f"Provided instance_type {instance_type} is not supported. "
+            "Please specify one of the supported instance types:"
+            f" {SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES}.\n"
+        )
+    if instance_count == 1:
+        # Communication backend auto is not supported for single-node jobs
+        err_msg += (
+            "SMDDP Collective backend is not supported for single-node jobs. "
+            "Please increase instance_count to be greater than 1.\n"
+        )
+    if not err_msg:
+        return True
+    logger.warning(
+        "The system is not compatible or not configured to run SMDDP collectives optimized"
+        " for AWS infrastructure.\n%s",
+        err_msg,
+    )
+    logger.warning("Continuing model training with default NCCL communication backend.\n")
+    return False
+
+
 def python_deprecation_warning(framework, latest_supported_version):
     """Placeholder docstring"""
     return PYTHON_2_DEPRECATION_WARNING.format(
diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py
index 29e254662f..cba90a7ef7 100644
--- a/src/sagemaker/pytorch/estimator.py
+++ b/src/sagemaker/pytorch/estimator.py
@@ -25,6 +25,7 @@
     python_deprecation_warning,
     validate_version_or_image_args,
     validate_distribution,
+    validate_smddp_collectives_support,
 )
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch.model import PyTorchModel
@@ -42,6 +43,9 @@ class PyTorch(Framework):
     LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
     LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
     INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
+    COMMUNICATION_BACKEND_ENV_NAME = "sagemaker_communication_backend"
+    COMMUNICATION_BACKEND_AUTO = "auto"
+    COMMUNICATION_BACKEND_NCCL = "nccl"
 
     def __init__(
         self,
@@ -308,8 +312,9 @@ def _pytorch_distribution_configuration(self, distribution):
         pytorch_ddp_enabled = False
         torch_distributed_enabled = False
 
-        if "pytorchddp" in distribution:
-            pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
+        pytorch_ddp_dict = distribution.get("pytorchddp")
+        if pytorch_ddp_dict:
+            pytorch_ddp_enabled = pytorch_ddp_dict.get("enabled", False)
         elif "torch_distributed" in distribution:
             torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
 
@@ -317,6 +322,8 @@ def _pytorch_distribution_configuration(self, distribution):
             distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
             if self.instance_type is not None:
                 distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
+            comm_backend = self._get_communication_backend(pytorch_ddp_dict)
+            distribution_config[self.COMMUNICATION_BACKEND_ENV_NAME] = comm_backend
         elif torch_distributed_enabled:
             distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
             if self.instance_type is not None:
@@ -326,6 +333,43 @@ def _pytorch_distribution_configuration(self, distribution):
 
         return distribution_config
 
+    def _get_communication_backend(self, pytorch_ddp_dict):
+        """Gets the collective communication backend to be used for the current training job.
+
+        Return `nccl` if:
+        * SMDDP collectives are not supported OR
+        * communication_options is specified and backend is set to `nccl`.
+
+        Return `auto` if:
+        * communication_options is not specified OR
+        * communication_options is specified but backend is not set OR
+        * communication_options is specified and backend is set to `auto`.
+
+        Args:
+            pytorch_ddp_dict (dict): A dictionary with options for pytorchddp distribution.
+
+        Returns:
+            str: The communication backend to use, either `auto` or `nccl`.
+        """
+        is_smddp_coll_backend_supported = validate_smddp_collectives_support(
+            self.framework_version,
+            self.py_version,
+            self.image_uri,
+            self.instance_type,
+            self.instance_count,
+        )
+        if not is_smddp_coll_backend_supported:
+            return self.COMMUNICATION_BACKEND_NCCL
+
+        comm_options = pytorch_ddp_dict.get("communication_options")
+        if not comm_options:
+            return self.COMMUNICATION_BACKEND_AUTO
+
+        comm_backend = comm_options.get("backend")
+        if not comm_backend or comm_backend == self.COMMUNICATION_BACKEND_AUTO:
+            return self.COMMUNICATION_BACKEND_AUTO
+
+        return self.COMMUNICATION_BACKEND_NCCL
+
     def hyperparameters(self):
         """Return hyperparameters used by your custom PyTorch code during model training."""
         hyperparameters = super(PyTorch, self).hyperparameters()
diff --git a/tests/conftest.py b/tests/conftest.py
index f6682ebb8c..3f43b643eb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -476,9 +476,7 @@ def pytorch_ddp_py_version():
     return "py3"
 
 
-@pytest.fixture(
-    scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"]
-)
+@pytest.fixture(scope="module", params=["1.11", "1.11.0", "1.12", "1.12.0", "1.12.1"])
 def pytorch_ddp_framework_version(request):
     return request.param
 
diff --git a/tests/integ/test_pytorchddp.py b/tests/integ/test_pytorchddp.py
index c580fdebc2..1c4d8e13ad 100644
--- a/tests/integ/test_pytorchddp.py
+++ b/tests/integ/test_pytorchddp.py
@@ -36,7 +36,7 @@ def test_pytorchddp_pt_mnist(
     pytorch_ddp_framework_version,
     pytorch_ddp_py_version,
 ):
-    job_name = sagemaker.utils.unique_name_from_base("pt-pytorch-ddp")
+    job_name = sagemaker.utils.unique_name_from_base("pytorch-ddp")
     estimator = PyTorch(
         entry_point="mnist_pt.py",
         role="SageMakerRole",
@@ -46,6 +46,34 @@ def test_pytorchddp_pt_mnist(
         sagemaker_session=sagemaker_session,
         framework_version=pytorch_ddp_framework_version,
         py_version=pytorch_ddp_py_version,
+        distribution={
+            "pytorchddp": {"enabled": True, "communication_options": {"backend": "nccl"}}
+        },
+    )
+
+    with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
+        estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name)
+
+
+@pytest.mark.skip(
+    reason="This test is skipped for now due to an ML capacity error. "
+    "This test should be re-enabled later."
+)
+@pytest.mark.skipif(
+    integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS,
+    reason="Only allow this test to run in IAD and CMH to limit usage of ml.p4d.24xlarge",
+)
+def test_pytorchddp_pt_mnist_smddp_coll(sagemaker_session):
+    job_name = sagemaker.utils.unique_name_from_base("pytorch-ddp-smddp")
+    estimator = PyTorch(
+        entry_point="mnist_pt.py",
+        role="SageMakerRole",
+        source_dir=pytorchddp_dir,
+        instance_count=2,
+        instance_type="ml.p4d.24xlarge",
+        sagemaker_session=sagemaker_session,
+        framework_version="1.13.1",
+        py_version="py3",
         distribution={"pytorchddp": {"enabled": True}},
     )
 
diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py
index 528af4ada0..a631dd9308 100644
--- a/tests/unit/test_fw_utils.py
+++ b/tests/unit/test_fw_utils.py
@@ -958,16 +958,7 @@ def test_validate_pytorchddp_not_raises():
     )
     # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions
     pytorchddp_enabled = {"pytorchddp": {"enabled": True}}
-    pytorchddp_supported_fw_versions = [
-        "1.10",
-        "1.10.0",
-        "1.10.2",
-        "1.11",
-        "1.11.0",
-        "1.12",
-        "1.12.0",
-        "1.12.1",
-    ]
+    pytorchddp_supported_fw_versions = fw_utils.PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
     for framework_version in pytorchddp_supported_fw_versions:
         fw_utils.validate_pytorch_distribution(
             distribution=pytorchddp_enabled,
@@ -985,7 +976,7 @@ def test_validate_pytorchddp_raises():
         fw_utils.validate_pytorch_distribution(
             distribution=pytorchddp_enabled,
             framework_name="pytorch",
-            framework_version="1.8",
+            framework_version="1.10",
             py_version="py3",
             image_uri=None,
         )
@@ -1002,7 +993,6 @@
 
 def test_validate_torch_distributed_not_raises():
-
     # Case 1: Framework is PyTorch, but distribution is not torch_distributed
     torch_distributed_disabled = {"torch_distributed": {"enabled": False}}
     fw_utils.validate_torch_distributed_distribution(
@@ -1099,3 +1089,69 @@ def test_instance_type_supports_profiler():
     assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True
     assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False
     assert fw_utils._instance_type_supports_profiler("local") is False
+
+
+def test_validate_smddp_collectives_supported():
+    # Framework is PyTorch, Distribution is PyTorchDDP enabled, all supported parameters
+    smddp_coll_supported_fw_versions = fw_utils.SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS
+    for framework_version in smddp_coll_supported_fw_versions:
+        assert (
+            fw_utils.validate_smddp_collectives_support(
+                framework_version=framework_version,
+                py_version="py3",
+                image_uri=None,
+                instance_type="ml.p4d.24xlarge",
+                instance_count=2,
+            )
+            is True
+        )
+
+
+def test_validate_smddp_collectives_not_supported():
+    # Case 1: Unsupported framework version
+    assert (
+        fw_utils.validate_smddp_collectives_support(
+            framework_version="1.10",
+            py_version="py3",
+            image_uri=None,
+            instance_type="ml.p4d.24xlarge",
+            instance_count=2,
+        )
+        is False
+    )
+
+    # Case 2: Unsupported Py version
+    assert (
+        fw_utils.validate_smddp_collectives_support(
+            framework_version="1.13.1",
+            py_version="py2",
+            image_uri=None,
+            instance_type="ml.p4d.24xlarge",
+            instance_count=2,
+        )
+        is False
+    )
+
+    # Case 3: Unsupported Instance Type
+    assert (
+        fw_utils.validate_smddp_collectives_support(
+            framework_version="1.13.1",
+            py_version="py3",
+            image_uri=None,
+            instance_type="ml.p3.16xlarge",
+            instance_count=2,
+        )
+        is False
+    )
+
+    # Case 4: Unsupported Instance Count
+    assert (
+        fw_utils.validate_smddp_collectives_support(
+            framework_version="1.13.1",
+            py_version="py3",
+            image_uri=None,
+            instance_type="ml.p4d.24xlarge",
+            instance_count=1,
+        )
+        is False
+    )
diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py
index a11738e25c..932d2cede4 100644
--- a/tests/unit/test_pytorch.py
+++ b/tests/unit/test_pytorch.py
@@ -59,6 +59,12 @@
 }
 
 DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}}
+DISTRIBUTION_PYTORCH_SMDDP_COLL_DISABLED = {
+    "pytorchddp": {"enabled": True, "communication_options": {"backend": "nccl"}}
+}
+DISTRIBUTION_PYTORCH_SMDDP_COLL_ENABLED = {
+    "pytorchddp": {"enabled": True, "communication_options": {"backend": "auto"}}
+}
 
 
 @pytest.fixture(name="sagemaker_session")
@@ -101,6 +107,7 @@ def _pytorch_estimator(
     framework_version,
     py_version,
     instance_type=None,
+    instance_count=None,
     base_job_name=None,
     **kwargs,
 ):
@@ -110,7 +117,7 @@ def _pytorch_estimator(
         py_version=py_version,
         role=ROLE,
         sagemaker_session=sagemaker_session,
-        instance_count=INSTANCE_COUNT,
+        instance_count=instance_count if instance_count else INSTANCE_COUNT,
         instance_type=instance_type if instance_type else INSTANCE_TYPE,
         base_job_name=base_job_name,
         **kwargs,
@@ -767,7 +774,7 @@ def test_register_pytorch_model_auto_infer_framework(
 def test_pytorch_ddp_distribution_configuration(
     sagemaker_session, pytorch_ddp_framework_version, pytorch_ddp_py_version
 ):
-    test_instance_type = "ml.p4d.24xlarge"
+    test_instance_type = "ml.g5.24xlarge"
     pytorch = _pytorch_estimator(
         sagemaker_session,
         framework_version=pytorch_ddp_framework_version,
@@ -781,6 +788,7 @@ def test_pytorch_ddp_distribution_configuration(
     expected_torch_ddp = {
         "sagemaker_pytorch_ddp_enabled": True,
         "sagemaker_instance_type": test_instance_type,
+        "sagemaker_communication_backend": "nccl",
     }
     assert actual_pytorch_ddp == expected_torch_ddp
 
@@ -797,3 +805,92 @@ def test_pytorch_ddp_distribution_configuration_unsupported(sagemaker_session):
     )
     assert (f"framework_version {unsupported_framework_version} is not supported") in str(error)
     assert (f"py_version {unsupported_py_version} is not supported") in str(error)
+
+
+def test_pytorch_ddp_comm_backend_nccl(sagemaker_session):
+    valid_framework_version = "1.12.1"
+    valid_py_version = "py3"
+    pytorch = _pytorch_estimator(
+        sagemaker_session,
+        framework_version=valid_framework_version,
+        py_version=valid_py_version,
+        distribution=DISTRIBUTION_PYTORCH_SMDDP_COLL_DISABLED,
+    )
+    actual_pytorch_ddp = pytorch._pytorch_distribution_configuration(
+        distribution=pytorch.distribution
+    )
+    expected_torch_ddp = {
+        "sagemaker_pytorch_ddp_enabled": True,
+        "sagemaker_instance_type": INSTANCE_TYPE,
+        "sagemaker_communication_backend": "nccl",
+    }
+    assert actual_pytorch_ddp == expected_torch_ddp
+
+
+def test_pytorch_ddp_comm_backend_auto_smddp_supported(sagemaker_session):
+    valid_instance_type = "ml.p4d.24xlarge"
+    valid_instance_count = 2
+    valid_framework_version = "1.13.1"
+    valid_py_version = "py3"
+    pytorch = _pytorch_estimator(
+        sagemaker_session,
+        framework_version=valid_framework_version,
+        py_version=valid_py_version,
+        distribution=DISTRIBUTION_PYTORCH_SMDDP_COLL_ENABLED,
+        instance_type=valid_instance_type,
+        instance_count=valid_instance_count,
+    )
+    actual_pytorch_ddp = pytorch._pytorch_distribution_configuration(
+        distribution=pytorch.distribution
+    )
+    expected_torch_ddp = {
+        "sagemaker_pytorch_ddp_enabled": True,
+        "sagemaker_instance_type": valid_instance_type,
+        "sagemaker_communication_backend": "auto",
+    }
+    assert actual_pytorch_ddp == expected_torch_ddp
+
+
+def test_pytorch_ddp_comm_backend_auto_smddp_unsupported(sagemaker_session):
+    unsupported_framework_version = "1.11"
+    unsupported_instance_count = 1
+    pytorch = _pytorch_estimator(
+        sagemaker_session,
+        framework_version=unsupported_framework_version,
+        py_version="py3",
+        distribution=DISTRIBUTION_PYTORCH_SMDDP_COLL_ENABLED,
+        instance_count=unsupported_instance_count,
+    )
+    actual_pytorch_ddp = pytorch._pytorch_distribution_configuration(
+        distribution=pytorch.distribution
+    )
+    expected_torch_ddp = {
+        "sagemaker_pytorch_ddp_enabled": True,
+        "sagemaker_instance_type": INSTANCE_TYPE,
+        "sagemaker_communication_backend": "nccl",
+    }
+    assert actual_pytorch_ddp == expected_torch_ddp
+
+
+def test_pytorch_ddp_no_comm_options_smddp_supported(sagemaker_session):
+    valid_instance_type = "ml.p4d.24xlarge"
+    valid_instance_count = 2
+    valid_framework_version = "1.13.1"
+    valid_py_version = "py3"
+    pytorch = _pytorch_estimator(
+        sagemaker_session,
+        framework_version=valid_framework_version,
+        py_version=valid_py_version,
+        distribution=DISTRIBUTION_PYTORCH_DDP_ENABLED,
+        instance_type=valid_instance_type,
+        instance_count=valid_instance_count,
+    )
+    actual_pytorch_ddp = pytorch._pytorch_distribution_configuration(
+        distribution=pytorch.distribution
+    )
+    expected_torch_ddp = {
+        "sagemaker_pytorch_ddp_enabled": True,
+        "sagemaker_instance_type": valid_instance_type,
+        "sagemaker_communication_backend": "auto",
+    }
+    assert actual_pytorch_ddp == expected_torch_ddp
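
As a quick reference for the new option, here is a minimal usage sketch based on the tests in this diff; ``train.py``, ``SageMakerRole``, and the S3 path are placeholders. With ``backend`` set to ``auto``, the SDK emits ``sagemaker_communication_backend: auto`` only for the configuration validated above (PyTorch ``1.13.1``, ``py3``, ``ml.p4d.24xlarge``, ``instance_count`` greater than 1); any other combination logs a warning and falls back to ``nccl``.

.. code:: python

    from sagemaker.pytorch import PyTorch

    # Opt in to SMDDP collectives through the new communication_options key.
    # "auto" lets SageMaker pick the optimized backend when the job qualifies;
    # passing {"backend": "nccl"} instead pins plain NCCL explicitly.
    estimator = PyTorch(
        entry_point="train.py",  # placeholder training script
        role="SageMakerRole",  # placeholder IAM role
        framework_version="1.13.1",
        py_version="py3",
        instance_count=2,
        instance_type="ml.p4d.24xlarge",
        distribution={
            "pytorchddp": {
                "enabled": True,
                "communication_options": {"backend": "auto"},
            }
        },
    )

    estimator.fit("s3://my-bucket/my-training-data")  # placeholder S3 input

Omitting ``communication_options`` altogether behaves the same as ``{"backend": "auto"}``, as exercised by ``test_pytorch_ddp_no_comm_options_smddp_supported``.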