From 6a30167704094f66b8f9eb247c1a04e974b108fc Mon Sep 17 00:00:00 2001 From: vishwakaria Date: Wed, 7 Sep 2022 17:11:59 +0000 Subject: [PATCH 01/10] Add use_accl option to pytorchddp distribution * Set to true by default * Update pt ddp supported versions Add validation for instance type, fw and py versions. Clean up logs and comments Add integ test for ACCL disabled Fix integ test params Fix formatting errors in tests Update logic to set accl_enabled to none for default on cases --- doc/frameworks/pytorch/using_pytorch.rst | 2 +- src/sagemaker/fw_utils.py | 64 ++++++++++++- src/sagemaker/pytorch/estimator.py | 46 ++++++++- tests/conftest.py | 4 +- tests/integ/test_pytorchddp.py | 26 +++++ tests/unit/test_fw_utils.py | 86 ++++++++++++++++- tests/unit/test_pytorch.py | 115 ++++++++++++++++++++++- 7 files changed, 326 insertions(+), 17 deletions(-) diff --git a/doc/frameworks/pytorch/using_pytorch.rst b/doc/frameworks/pytorch/using_pytorch.rst index f56085f756..b0eb56edd7 100644 --- a/doc/frameworks/pytorch/using_pytorch.rst +++ b/doc/frameworks/pytorch/using_pytorch.rst @@ -212,7 +212,7 @@ with the ``pytorchddp`` option as the distribution strategy. .. note:: This PyTorch DDP support is available - in the SageMaker PyTorch Deep Learning Containers v1.12 and later. + in the SageMaker PyTorch Deep Learning Containers v1.11 and later. Adapt Your Training Script -------------------------- diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 59204fbc99..2d91634a06 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -133,9 +133,6 @@ } PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [ - "1.10", - "1.10.0", - "1.10.2", "1.11", "1.11.0", "1.12", @@ -149,6 +146,12 @@ TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"] +# TODO: Change to 1.12.1 before merging +ACCL_SUPPORTED_FRAMEWORK_VERSIONS = ( + "1.12", + "1.12.0", +) +ACCL_SUPPORTED_INSTANCE_TYPES = ("ml.p4d.24xlarge",) SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"] @@ -1060,7 +1063,6 @@ def validate_torch_distributed_distribution( if not torch_distributed_enabled: # Distribution strategy other than torch_distributed is selected return - err_msg = "" if not image_uri: # ignore framework_version and py_version if image_uri is set @@ -1098,6 +1100,60 @@ def validate_torch_distributed_distribution( if err_msg: raise ValueError(err_msg) +def validate_accl_support( + use_accl, framework_version, py_version, image_uri, instance_type, instance_count +): + """Check if ACCL is supported for current invocation. + + Args: + framework_version (str): A string representing the framework version selected. + py_version (str): A string representing the python version selected. + image_uri (str): A string representing a Docker image URI. + instance_type (str): SageMaker instance type. + instance_count (int): Number of training instances to use. + + Raises: + ValueError: if `use_accl` is set to true and validation fails, i.e. + `instance_type` is not in ACCL_SUPPORTED_INSTANCE_TYPES or + `py_version` is not python3 or + `framework_version` is not in ACCL_SUPPORTED_FRAMEWORK_VERSIONS or + `instance_count` is not greater than 1 + """ + err_msg = "" + if not image_uri: + # ignore framework_version and py_version if image_uri is set + # in case image_uri is not set, then both are mandatory + if framework_version not in ACCL_SUPPORTED_FRAMEWORK_VERSIONS: + err_msg += ( + f"Provided framework_version {framework_version} is not supported by" + " ACCL.\n" + "Please specify one of the supported framework versions:" + f" {ACCL_SUPPORTED_FRAMEWORK_VERSIONS}.\n" + ) + if "py3" not in py_version: + err_msg += ( + f"Provided py_version {py_version} is not supported by ACCL.\n" + "Please specify py_version>=py3.\n" + ) + if instance_type not in ACCL_SUPPORTED_INSTANCE_TYPES: + err_msg += ( + f"Provided instance_type {instance_type} is not supported by ACCL.\n" + "Please specify one of the supported instance types:" + f"{ACCL_SUPPORTED_INSTANCE_TYPES}.\n" + ) + if instance_count == 1: + # ACCL is not supported for single-node jobs + err_msg += ( + "ACCL is not supported for single-node jobs.\n" + "Please increase instance_count to be greater than 1.\n" + ) + if not err_msg: + return True + if use_accl: + raise ValueError(f"Could not enable ACCL.\n {err_msg}") + logger.warning("Could not enable ACCL.\n %s", err_msg) + return False + def python_deprecation_warning(framework, latest_supported_version): """Placeholder docstring""" diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 29e254662f..d6c402ccf8 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -25,6 +25,7 @@ python_deprecation_warning, validate_version_or_image_args, validate_distribution, + validate_accl_support, ) from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel @@ -42,6 +43,7 @@ class PyTorch(Framework): LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled" LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled" INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type" + ACCL_ENABLED_ENV_NAME = "sagemaker_accl_enabled" def __init__( self, @@ -308,15 +310,18 @@ def _pytorch_distribution_configuration(self, distribution): pytorch_ddp_enabled = False torch_distributed_enabled = False - if "pytorchddp" in distribution: - pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False) + pytorch_ddp_dict = distribution.get("pytorchddp") + if pytorch_ddp_dict: + pytorch_ddp_enabled = pytorch_ddp_dict.get("enabled", False) elif "torch_distributed" in distribution: torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False) - + if pytorch_ddp_enabled: distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled if self.instance_type is not None: distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type + is_accl_enabled = self._get_accl_enabled(pytorch_ddp_dict) + distribution_config[self.ACCL_ENABLED_ENV_NAME] = is_accl_enabled elif torch_distributed_enabled: distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled if self.instance_type is not None: @@ -326,6 +331,41 @@ def _pytorch_distribution_configuration(self, distribution): return distribution_config + def _get_accl_enabled(self, pytorch_ddp_dict): + """Evaluates if ACCL should be enabled for current training jobs. + + Case 1: Customer explicitly disables ACCL by setting use_accl to False. + Return false. + + Case 2: Customer explicitly enables ACCL by setting use_accl to True. + Test if configuration is supported for ACCL. + If yes, return true. If not, throw an error. + + Case 3: Customer does not specify use_accl. We try to enable by default. + Test if configuration is supported for ACCL. + If not, we return false. + + Args: + pytorch_ddp_dict (dict): A dictionary with options for pytorchddp distribution. + Returns: + A boolean that indicates whether to enable ACCL + """ + use_accl = pytorch_ddp_dict.get("use_accl") + is_accl_supported = validate_accl_support( + use_accl, + self.framework_version, + self.py_version, + self.image_uri, + self.instance_type, + self.instance_count, + ) + + if use_accl is False or not is_accl_supported: + return False + if use_accl and is_accl_supported: + return True + return use_accl + def hyperparameters(self): """Return hyperparameters used by your custom PyTorch code during model training.""" hyperparameters = super(PyTorch, self).hyperparameters() diff --git a/tests/conftest.py b/tests/conftest.py index f6682ebb8c..94621a547a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -476,9 +476,7 @@ def pytorch_ddp_py_version(): return "py3" -@pytest.fixture( - scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"] -) +@pytest.fixture(scope="module", params=["1.11", "1.11.0", "1.12", "1.12.0"]) def pytorch_ddp_framework_version(request): return request.param diff --git a/tests/integ/test_pytorchddp.py b/tests/integ/test_pytorchddp.py index c580fdebc2..6a886d71f2 100644 --- a/tests/integ/test_pytorchddp.py +++ b/tests/integ/test_pytorchddp.py @@ -51,3 +51,29 @@ def test_pytorchddp_pt_mnist( with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name) + + +@pytest.mark.skip( + reason="This test is skipped for now due ML capacity error." + "This test should be re-enabled later." +) +@pytest.mark.skipif( + integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS, + reason="Only allow this test to run in IAD and CMH to limit usage of p3.16xlarge", +) +def test_pytorchddp_pt_mnist_accl_disabled(sagemaker_session): + job_name = sagemaker.utils.unique_name_from_base("pt-pytorch-ddp") + estimator = PyTorch( + entry_point="mnist_pt.py", + role="SageMakerRole", + source_dir=pytorchddp_dir, + instance_count=2, + instance_type="ml.p3.16xlarge", + sagemaker_session=sagemaker_session, + framework_version="1.12", + py_version="py3", + distribution={"pytorchddp": {"enabled": True, "use_accl": False}}, + ) + + with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): + estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 528af4ada0..cb785a59a5 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -959,9 +959,6 @@ def test_validate_pytorchddp_not_raises(): # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions pytorchddp_enabled = {"pytorchddp": {"enabled": True}} pytorchddp_supported_fw_versions = [ - "1.10", - "1.10.0", - "1.10.2", "1.11", "1.11.0", "1.12", @@ -985,7 +982,7 @@ def test_validate_pytorchddp_raises(): fw_utils.validate_pytorch_distribution( distribution=pytorchddp_enabled, framework_name="pytorch", - framework_version="1.8", + framework_version="1.10", py_version="py3", image_uri=None, ) @@ -1099,3 +1096,84 @@ def test_instance_type_supports_profiler(): assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False assert fw_utils._instance_type_supports_profiler("local") is False + + +def test_validate_accl_support_true(): + # Framework is PyTorch, Distribution is PyTorchDDP enabled, all supported parameters + accl_supported_fw_versions = [ + "1.12", + "1.12.0", + ] + for framework_version in accl_supported_fw_versions: + assert ( + fw_utils.validate_accl_support( + use_accl=True, + framework_version=framework_version, + py_version="py3", + image_uri=None, + instance_type="ml.p4d.24xlarge", + instance_count=2, + ) + is True + ) + + +def test_validate_accl_support_false(): + # Framework is PyTorch, Distribution is PyTorchDDP enabled, all supported parameters + assert ( + fw_utils.validate_accl_support( + use_accl=False, + framework_version="1.11", + py_version="py3", + image_uri=None, + instance_type="ml.p3dn.24xlarge", + instance_count=1, + ) + is False + ) + + +def test_validate_accl_support_error(): + # Case 1: Unsupported framework version + with pytest.raises(ValueError): + fw_utils.validate_accl_support( + use_accl=True, + framework_version="1.10", + py_version="py3", + image_uri=None, + instance_type="ml.p4d.24xlarge", + instance_count=2, + ) + + # Case 2: Unsupported Py version + with pytest.raises(ValueError): + fw_utils.validate_accl_support( + use_accl=True, + framework_version="1.10", + py_version="py2", + image_uri=None, + instance_type="ml.p4d.24xlarge", + instance_count=2, + ) + + # Case 3: Unsupported Instance Type + with pytest.raises(ValueError): + fw_utils.validate_accl_support( + use_accl=True, + framework_version="1.10", + py_version="py2", + image_uri=None, + instance_type="ml.p3.16xlarge", + instance_count=2, + ) + + # Case 4: Unsupported Instance Count + with pytest.raises(ValueError): + fw_utils.validate_accl_support( + use_accl=True, + framework_version="1.10", + py_version="py2", + image_uri=None, + instance_type="ml.p4d.24xlarge", + instance_count=1, + ) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index a11738e25c..1b3cc329d9 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -59,6 +59,8 @@ } DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}} +DISTRIBUTION_PYTORCH_ACCL_DISABLED = {"pytorchddp": {"enabled": True, "use_accl": False}} +DISTRIBUTION_PYTORCH_ACCL_ENABLED = {"pytorchddp": {"enabled": True, "use_accl": True}} @pytest.fixture(name="sagemaker_session") @@ -101,6 +103,7 @@ def _pytorch_estimator( framework_version, py_version, instance_type=None, + instance_count=None, base_job_name=None, **kwargs, ): @@ -110,7 +113,7 @@ def _pytorch_estimator( py_version=py_version, role=ROLE, sagemaker_session=sagemaker_session, - instance_count=INSTANCE_COUNT, + instance_count=instance_count if instance_count else INSTANCE_COUNT, instance_type=instance_type if instance_type else INSTANCE_TYPE, base_job_name=base_job_name, **kwargs, @@ -767,7 +770,7 @@ def test_register_pytorch_model_auto_infer_framework( def test_pytorch_ddp_distribution_configuration( sagemaker_session, pytorch_ddp_framework_version, pytorch_ddp_py_version ): - test_instance_type = "ml.p4d.24xlarge" + test_instance_type = "ml.g5.24xlarge" pytorch = _pytorch_estimator( sagemaker_session, framework_version=pytorch_ddp_framework_version, @@ -781,6 +784,7 @@ def test_pytorch_ddp_distribution_configuration( expected_torch_ddp = { "sagemaker_pytorch_ddp_enabled": True, "sagemaker_instance_type": test_instance_type, + "sagemaker_accl_enabled": False, } assert actual_pytorch_ddp == expected_torch_ddp @@ -797,3 +801,110 @@ def test_pytorch_ddp_distribution_configuration_unsupported(sagemaker_session): ) assert (f"framework_version {unsupported_framework_version} is not supported") in str(error) assert (f"py_version {unsupported_py_version} is not supported") in str(error) + + +def test_pytorch_ddp_accl_disabled(sagemaker_session): + valid_framework_version = "1.12" + valid_py_version = "py3" + pytorch = _pytorch_estimator( + sagemaker_session, + framework_version=valid_framework_version, + py_version=valid_py_version, + distribution=DISTRIBUTION_PYTORCH_ACCL_DISABLED, + ) + actual_pytorch_ddp = pytorch._pytorch_distribution_configuration( + distribution=pytorch.distribution + ) + expected_torch_ddp = { + "sagemaker_pytorch_ddp_enabled": True, + "sagemaker_instance_type": INSTANCE_TYPE, + "sagemaker_accl_enabled": False, + } + assert actual_pytorch_ddp == expected_torch_ddp + + +def test_pytorch_ddp_accl_explicit_enabled_and_supported(sagemaker_session): + valid_instance_type = "ml.p4d.24xlarge" + valid_instance_count = 2 + valid_framework_version = "1.12" + valid_py_version = "py3" + pytorch = _pytorch_estimator( + sagemaker_session, + framework_version=valid_framework_version, + py_version=valid_py_version, + distribution=DISTRIBUTION_PYTORCH_ACCL_ENABLED, + instance_type=valid_instance_type, + instance_count=valid_instance_count, + ) + actual_pytorch_ddp = pytorch._pytorch_distribution_configuration( + distribution=pytorch.distribution + ) + expected_torch_ddp = { + "sagemaker_pytorch_ddp_enabled": True, + "sagemaker_instance_type": valid_instance_type, + "sagemaker_accl_enabled": True, + } + assert actual_pytorch_ddp == expected_torch_ddp + + +def test_pytorch_ddp_accl_explicit_enabled_and_unsupported(sagemaker_session): + unsupported_framework_version = "1.11" + unsupported_instance_count = 1 + pytorch = _pytorch_estimator( + sagemaker_session, + framework_version=unsupported_framework_version, + py_version="py3", + distribution=DISTRIBUTION_PYTORCH_ACCL_ENABLED, + instance_count=unsupported_instance_count, + ) + with pytest.raises(ValueError) as error: + pytorch._pytorch_distribution_configuration(distribution=pytorch.distribution) + assert (f"framework_version {unsupported_framework_version} is not supported") in str(error) + assert ("ACCL is not supported for single-node jobs.") in str(error) + + +def test_pytorch_ddp_accl_default_on_and_supported(sagemaker_session): + valid_instance_type = "ml.p4d.24xlarge" + valid_instance_count = 2 + valid_framework_version = "1.12" + valid_py_version = "py3" + pytorch = _pytorch_estimator( + sagemaker_session, + framework_version=valid_framework_version, + py_version=valid_py_version, + distribution=DISTRIBUTION_PYTORCH_DDP_ENABLED, + instance_type=valid_instance_type, + instance_count=valid_instance_count, + ) + actual_pytorch_ddp = pytorch._pytorch_distribution_configuration( + distribution=pytorch.distribution + ) + expected_torch_ddp = { + "sagemaker_pytorch_ddp_enabled": True, + "sagemaker_instance_type": valid_instance_type, + "sagemaker_accl_enabled": None, + } + assert actual_pytorch_ddp == expected_torch_ddp + + +def test_pytorch_ddp_accl_default_on_and_unsupported(sagemaker_session): + unsupported_framework_version = "1.11" + unsupported_instance_type = "ml.g5.24xlarge" + unsupported_instance_count = 1 + pytorch = _pytorch_estimator( + sagemaker_session, + framework_version=unsupported_framework_version, + py_version="py3", + distribution=DISTRIBUTION_PYTORCH_DDP_ENABLED, + instance_type=unsupported_instance_type, + instance_count=unsupported_instance_count, + ) + actual_pytorch_ddp = pytorch._pytorch_distribution_configuration( + distribution=pytorch.distribution + ) + expected_torch_ddp = { + "sagemaker_pytorch_ddp_enabled": True, + "sagemaker_instance_type": unsupported_instance_type, + "sagemaker_accl_enabled": False, + } + assert actual_pytorch_ddp == expected_torch_ddp From 6c1afa59079bf6d9cb8902141d4ea870eed9e27c Mon Sep 17 00:00:00 2001 From: vishwakaria Date: Thu, 24 Nov 2022 01:19:00 +0000 Subject: [PATCH 02/10] Update API spec to use communication_options --- src/sagemaker/fw_utils.py | 46 +++++++++------------ src/sagemaker/pytorch/estimator.py | 52 ++++++++++++----------- tests/integ/test_pytorchddp.py | 16 ++++---- tests/unit/test_fw_utils.py | 66 ++++++++++-------------------- 4 files changed, 79 insertions(+), 101 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 2d91634a06..0195ffa272 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -147,11 +147,8 @@ TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"] # TODO: Change to 1.12.1 before merging -ACCL_SUPPORTED_FRAMEWORK_VERSIONS = ( - "1.12", - "1.12.0", -) -ACCL_SUPPORTED_INSTANCE_TYPES = ("ml.p4d.24xlarge",) +SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS = ("1.12.1",) +SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES = ("ml.p4d.24xlarge",) SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"] @@ -1100,10 +1097,11 @@ def validate_torch_distributed_distribution( if err_msg: raise ValueError(err_msg) -def validate_accl_support( - use_accl, framework_version, py_version, image_uri, instance_type, instance_count + +def validate_smddp_collectives_support( + framework_version, py_version, image_uri, instance_type, instance_count ): - """Check if ACCL is supported for current invocation. + """Check if SMDDP collective backend is supported for current invocation. Args: framework_version (str): A string representing the framework version selected. @@ -1112,46 +1110,42 @@ def validate_accl_support( instance_type (str): SageMaker instance type. instance_count (int): Number of training instances to use. - Raises: - ValueError: if `use_accl` is set to true and validation fails, i.e. - `instance_type` is not in ACCL_SUPPORTED_INSTANCE_TYPES or - `py_version` is not python3 or - `framework_version` is not in ACCL_SUPPORTED_FRAMEWORK_VERSIONS or - `instance_count` is not greater than 1 + Returns false if: + `instance_type` is not in SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES or + `py_version` is not python3 or + `framework_version` is not in SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS or + `instance_count` is not greater than 1 """ err_msg = "" if not image_uri: # ignore framework_version and py_version if image_uri is set # in case image_uri is not set, then both are mandatory - if framework_version not in ACCL_SUPPORTED_FRAMEWORK_VERSIONS: + if framework_version not in SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS: err_msg += ( - f"Provided framework_version {framework_version} is not supported by" - " ACCL.\n" + f"Provided framework_version {framework_version} is not supported.\n" "Please specify one of the supported framework versions:" - f" {ACCL_SUPPORTED_FRAMEWORK_VERSIONS}.\n" + f" {SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS}.\n" ) if "py3" not in py_version: err_msg += ( - f"Provided py_version {py_version} is not supported by ACCL.\n" + f"Provided py_version {py_version} is not supported.\n" "Please specify py_version>=py3.\n" ) - if instance_type not in ACCL_SUPPORTED_INSTANCE_TYPES: + if instance_type not in SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES: err_msg += ( - f"Provided instance_type {instance_type} is not supported by ACCL.\n" + f"Provided instance_type {instance_type} is not supported.\n" "Please specify one of the supported instance types:" - f"{ACCL_SUPPORTED_INSTANCE_TYPES}.\n" + f"{SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES}.\n" ) if instance_count == 1: # ACCL is not supported for single-node jobs err_msg += ( - "ACCL is not supported for single-node jobs.\n" + "SMDDP Collective backend is not supported for single-node jobs.\n" "Please increase instance_count to be greater than 1.\n" ) if not err_msg: return True - if use_accl: - raise ValueError(f"Could not enable ACCL.\n {err_msg}") - logger.warning("Could not enable ACCL.\n %s", err_msg) + logger.warning("Could not enable SMDDP Collectives.\n %s", err_msg) return False diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index d6c402ccf8..ca4ed8dad7 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -25,7 +25,7 @@ python_deprecation_warning, validate_version_or_image_args, validate_distribution, - validate_accl_support, + validate_smddp_collectives_support, ) from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel @@ -43,7 +43,9 @@ class PyTorch(Framework): LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled" LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled" INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type" - ACCL_ENABLED_ENV_NAME = "sagemaker_accl_enabled" + COMMUNICATION_BACKEND_ENV_NAME = "sagemaker_communication_backend" + COMMUNICATION_BACKEND_AUTO = "auto" + COMMUNICATION_BACKEND_NCCL = "nccl" def __init__( self, @@ -315,13 +317,13 @@ def _pytorch_distribution_configuration(self, distribution): pytorch_ddp_enabled = pytorch_ddp_dict.get("enabled", False) elif "torch_distributed" in distribution: torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False) - + if pytorch_ddp_enabled: distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled if self.instance_type is not None: distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type - is_accl_enabled = self._get_accl_enabled(pytorch_ddp_dict) - distribution_config[self.ACCL_ENABLED_ENV_NAME] = is_accl_enabled + comm_backend = self._get_communication_backend(pytorch_ddp_dict) + distribution_config[self.COMMUNICATION_BACKEND_ENV_NAME] = comm_backend elif torch_distributed_enabled: distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled if self.instance_type is not None: @@ -331,40 +333,42 @@ def _pytorch_distribution_configuration(self, distribution): return distribution_config - def _get_accl_enabled(self, pytorch_ddp_dict): - """Evaluates if ACCL should be enabled for current training jobs. - - Case 1: Customer explicitly disables ACCL by setting use_accl to False. - Return false. + def _get_communication_backend(self, pytorch_ddp_dict): + """Sets the collective communication backend to be used for the current training job. - Case 2: Customer explicitly enables ACCL by setting use_accl to True. - Test if configuration is supported for ACCL. - If yes, return true. If not, throw an error. + Return `nccl` if: + * SMDDP collectives are not supported OR + * communication_options is specified and backend is set to `nccl`. - Case 3: Customer does not specify use_accl. We try to enable by default. - Test if configuration is supported for ACCL. - If not, we return false. + Return `auto` if: + * communication_options is not specified OR + * communication_options is specified but backend is not set OR + * communication_options is specified and backend is set to `auto`. Args: pytorch_ddp_dict (dict): A dictionary with options for pytorchddp distribution. Returns: A boolean that indicates whether to enable ACCL """ - use_accl = pytorch_ddp_dict.get("use_accl") - is_accl_supported = validate_accl_support( - use_accl, + is_smddp_coll_backend_supported = validate_smddp_collectives_support( self.framework_version, self.py_version, self.image_uri, self.instance_type, self.instance_count, ) + if not is_smddp_coll_backend_supported: + return self.COMMUNICATION_BACKEND_NCCL + + comm_options = pytorch_ddp_dict.get("communication_options") + if not comm_options: + return self.COMMUNICATION_BACKEND_AUTO + + comm_backend = comm_options.get("backend") + if not comm_backend or comm_backend == self.COMMUNICATION_BACKEND_AUTO: + return self.COMMUNICATION_BACKEND_AUTO - if use_accl is False or not is_accl_supported: - return False - if use_accl and is_accl_supported: - return True - return use_accl + return self.COMMUNICATION_BACKEND_NCCL def hyperparameters(self): """Return hyperparameters used by your custom PyTorch code during model training.""" diff --git a/tests/integ/test_pytorchddp.py b/tests/integ/test_pytorchddp.py index 6a886d71f2..1c4d8e13ad 100644 --- a/tests/integ/test_pytorchddp.py +++ b/tests/integ/test_pytorchddp.py @@ -36,7 +36,7 @@ def test_pytorchddp_pt_mnist( pytorch_ddp_framework_version, pytorch_ddp_py_version, ): - job_name = sagemaker.utils.unique_name_from_base("pt-pytorch-ddp") + job_name = sagemaker.utils.unique_name_from_base("pytorch-ddp") estimator = PyTorch( entry_point="mnist_pt.py", role="SageMakerRole", @@ -46,7 +46,9 @@ def test_pytorchddp_pt_mnist( sagemaker_session=sagemaker_session, framework_version=pytorch_ddp_framework_version, py_version=pytorch_ddp_py_version, - distribution={"pytorchddp": {"enabled": True}}, + distribution={ + "pytorchddp": {"enabled": True, "communication_options": {"backend": "nccl"}} + }, ) with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): @@ -61,18 +63,18 @@ def test_pytorchddp_pt_mnist( integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS, reason="Only allow this test to run in IAD and CMH to limit usage of p3.16xlarge", ) -def test_pytorchddp_pt_mnist_accl_disabled(sagemaker_session): - job_name = sagemaker.utils.unique_name_from_base("pt-pytorch-ddp") +def test_pytorchddp_pt_mnist_smddp_coll(sagemaker_session): + job_name = sagemaker.utils.unique_name_from_base("pytorch-ddp-smddp") estimator = PyTorch( entry_point="mnist_pt.py", role="SageMakerRole", source_dir=pytorchddp_dir, instance_count=2, - instance_type="ml.p3.16xlarge", + instance_type="ml.p4d.24xlarge", sagemaker_session=sagemaker_session, - framework_version="1.12", + framework_version="1.12.1", py_version="py3", - distribution={"pytorchddp": {"enabled": True, "use_accl": False}}, + distribution={"pytorchddp": {"enabled": True}}, ) with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index cb785a59a5..a631dd9308 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -958,13 +958,7 @@ def test_validate_pytorchddp_not_raises(): ) # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions pytorchddp_enabled = {"pytorchddp": {"enabled": True}} - pytorchddp_supported_fw_versions = [ - "1.11", - "1.11.0", - "1.12", - "1.12.0", - "1.12.1", - ] + pytorchddp_supported_fw_versions = fw_utils.PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS for framework_version in pytorchddp_supported_fw_versions: fw_utils.validate_pytorch_distribution( distribution=pytorchddp_enabled, @@ -999,7 +993,6 @@ def test_validate_pytorchddp_raises(): def test_validate_torch_distributed_not_raises(): - # Case 1: Framework is PyTorch, but distribution is not torch_distributed torch_distributed_disabled = {"torch_distributed": {"enabled": False}} fw_utils.validate_torch_distributed_distribution( @@ -1098,16 +1091,12 @@ def test_instance_type_supports_profiler(): assert fw_utils._instance_type_supports_profiler("local") is False -def test_validate_accl_support_true(): +def test_validate_smddp_collectives_supported(): # Framework is PyTorch, Distribution is PyTorchDDP enabled, all supported parameters - accl_supported_fw_versions = [ - "1.12", - "1.12.0", - ] - for framework_version in accl_supported_fw_versions: + smddp_coll_supported_fw_versions = fw_utils.SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS + for framework_version in smddp_coll_supported_fw_versions: assert ( - fw_utils.validate_accl_support( - use_accl=True, + fw_utils.validate_smddp_collectives_support( framework_version=framework_version, py_version="py3", image_uri=None, @@ -1118,62 +1107,51 @@ def test_validate_accl_support_true(): ) -def test_validate_accl_support_false(): - # Framework is PyTorch, Distribution is PyTorchDDP enabled, all supported parameters - assert ( - fw_utils.validate_accl_support( - use_accl=False, - framework_version="1.11", - py_version="py3", - image_uri=None, - instance_type="ml.p3dn.24xlarge", - instance_count=1, - ) - is False - ) - - -def test_validate_accl_support_error(): +def test_validate_smddp_collectives_not_supported(): # Case 1: Unsupported framework version - with pytest.raises(ValueError): - fw_utils.validate_accl_support( - use_accl=True, + assert ( + fw_utils.validate_smddp_collectives_support( framework_version="1.10", py_version="py3", image_uri=None, instance_type="ml.p4d.24xlarge", instance_count=2, ) + is False + ) # Case 2: Unsupported Py version - with pytest.raises(ValueError): - fw_utils.validate_accl_support( - use_accl=True, + assert ( + fw_utils.validate_smddp_collectives_support( framework_version="1.10", py_version="py2", image_uri=None, instance_type="ml.p4d.24xlarge", instance_count=2, ) + is False + ) # Case 3: Unsupported Instance Type - with pytest.raises(ValueError): - fw_utils.validate_accl_support( - use_accl=True, + assert ( + fw_utils.validate_smddp_collectives_support( framework_version="1.10", py_version="py2", image_uri=None, instance_type="ml.p3.16xlarge", instance_count=2, ) + is False + ) # Case 4: Unsupported Instance Count - with pytest.raises(ValueError): - fw_utils.validate_accl_support( - use_accl=True, + assert ( + fw_utils.validate_smddp_collectives_support( framework_version="1.10", py_version="py2", image_uri=None, instance_type="ml.p4d.24xlarge", instance_count=1, ) + is False + ) From 890723d6ce070f9f976952637f5fc84948b59a10 Mon Sep 17 00:00:00 2001 From: vishwakaria Date: Mon, 28 Nov 2022 19:08:59 +0000 Subject: [PATCH 03/10] Remove extra whitespace in logs --- src/sagemaker/fw_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 0195ffa272..2ecb17274a 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -1145,7 +1145,7 @@ def validate_smddp_collectives_support( ) if not err_msg: return True - logger.warning("Could not enable SMDDP Collectives.\n %s", err_msg) + logger.warning("Could not enable SMDDP Collectives.\n%s", err_msg) return False From d42482cdc22ca8b6733515d1772762dcbeb87a3f Mon Sep 17 00:00:00 2001 From: vishwakaria Date: Mon, 28 Nov 2022 19:34:00 +0000 Subject: [PATCH 04/10] Fix unit tests --- src/sagemaker/fw_utils.py | 1 - tests/unit/test_pytorch.py | 72 +++++++++++++++----------------------- 2 files changed, 29 insertions(+), 44 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 2ecb17274a..0c67293eee 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -146,7 +146,6 @@ TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"] -# TODO: Change to 1.12.1 before merging SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS = ("1.12.1",) SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES = ("ml.p4d.24xlarge",) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 1b3cc329d9..06eefb0ab5 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -59,8 +59,12 @@ } DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}} -DISTRIBUTION_PYTORCH_ACCL_DISABLED = {"pytorchddp": {"enabled": True, "use_accl": False}} -DISTRIBUTION_PYTORCH_ACCL_ENABLED = {"pytorchddp": {"enabled": True, "use_accl": True}} +DISTRIBUTION_PYTORCH_SMDDP_COLL_DISABLED = { + "pytorchddp": {"enabled": True, "communication_options": {"backend": "nccl"}} +} +DISTRIBUTION_PYTORCH_SMDDP_COLL_ENABLED = { + "pytorchddp": {"enabled": True, "communication_options": {"backend": "auto"}} +} @pytest.fixture(name="sagemaker_session") @@ -784,7 +788,7 @@ def test_pytorch_ddp_distribution_configuration( expected_torch_ddp = { "sagemaker_pytorch_ddp_enabled": True, "sagemaker_instance_type": test_instance_type, - "sagemaker_accl_enabled": False, + "sagemaker_communication_backend": "nccl", } assert actual_pytorch_ddp == expected_torch_ddp @@ -803,14 +807,14 @@ def test_pytorch_ddp_distribution_configuration_unsupported(sagemaker_session): assert (f"py_version {unsupported_py_version} is not supported") in str(error) -def test_pytorch_ddp_accl_disabled(sagemaker_session): - valid_framework_version = "1.12" +def test_pytorch_ddp_comm_back_nccl(sagemaker_session): + valid_framework_version = "1.12.1" valid_py_version = "py3" pytorch = _pytorch_estimator( sagemaker_session, framework_version=valid_framework_version, py_version=valid_py_version, - distribution=DISTRIBUTION_PYTORCH_ACCL_DISABLED, + distribution=DISTRIBUTION_PYTORCH_SMDDP_COLL_DISABLED, ) actual_pytorch_ddp = pytorch._pytorch_distribution_configuration( distribution=pytorch.distribution @@ -818,21 +822,21 @@ def test_pytorch_ddp_accl_disabled(sagemaker_session): expected_torch_ddp = { "sagemaker_pytorch_ddp_enabled": True, "sagemaker_instance_type": INSTANCE_TYPE, - "sagemaker_accl_enabled": False, + "sagemaker_communication_backend": "nccl", } assert actual_pytorch_ddp == expected_torch_ddp -def test_pytorch_ddp_accl_explicit_enabled_and_supported(sagemaker_session): +def test_pytorch_ddp_comm_backend_auto_smddp_supported(sagemaker_session): valid_instance_type = "ml.p4d.24xlarge" valid_instance_count = 2 - valid_framework_version = "1.12" + valid_framework_version = "1.12.1" valid_py_version = "py3" pytorch = _pytorch_estimator( sagemaker_session, framework_version=valid_framework_version, py_version=valid_py_version, - distribution=DISTRIBUTION_PYTORCH_ACCL_ENABLED, + distribution=DISTRIBUTION_PYTORCH_SMDDP_COLL_ENABLED, instance_type=valid_instance_type, instance_count=valid_instance_count, ) @@ -842,31 +846,36 @@ def test_pytorch_ddp_accl_explicit_enabled_and_supported(sagemaker_session): expected_torch_ddp = { "sagemaker_pytorch_ddp_enabled": True, "sagemaker_instance_type": valid_instance_type, - "sagemaker_accl_enabled": True, + "sagemaker_communication_backend": "auto", } assert actual_pytorch_ddp == expected_torch_ddp -def test_pytorch_ddp_accl_explicit_enabled_and_unsupported(sagemaker_session): +def test_pytorch_ddp_comm_backend_auto_smddp_unsupported(sagemaker_session): unsupported_framework_version = "1.11" unsupported_instance_count = 1 pytorch = _pytorch_estimator( sagemaker_session, framework_version=unsupported_framework_version, py_version="py3", - distribution=DISTRIBUTION_PYTORCH_ACCL_ENABLED, + distribution=DISTRIBUTION_PYTORCH_SMDDP_COLL_ENABLED, instance_count=unsupported_instance_count, ) - with pytest.raises(ValueError) as error: - pytorch._pytorch_distribution_configuration(distribution=pytorch.distribution) - assert (f"framework_version {unsupported_framework_version} is not supported") in str(error) - assert ("ACCL is not supported for single-node jobs.") in str(error) + actual_pytorch_ddp = pytorch._pytorch_distribution_configuration( + distribution=pytorch.distribution + ) + expected_torch_ddp = { + "sagemaker_pytorch_ddp_enabled": True, + "sagemaker_instance_type": INSTANCE_TYPE, + "sagemaker_communication_backend": "nccl", + } + assert actual_pytorch_ddp == expected_torch_ddp -def test_pytorch_ddp_accl_default_on_and_supported(sagemaker_session): +def test_pytorch_ddp_no_comm_options_smddp_supported(sagemaker_session): valid_instance_type = "ml.p4d.24xlarge" valid_instance_count = 2 - valid_framework_version = "1.12" + valid_framework_version = "1.12.1" valid_py_version = "py3" pytorch = _pytorch_estimator( sagemaker_session, @@ -882,29 +891,6 @@ def test_pytorch_ddp_accl_default_on_and_supported(sagemaker_session): expected_torch_ddp = { "sagemaker_pytorch_ddp_enabled": True, "sagemaker_instance_type": valid_instance_type, - "sagemaker_accl_enabled": None, - } - assert actual_pytorch_ddp == expected_torch_ddp - - -def test_pytorch_ddp_accl_default_on_and_unsupported(sagemaker_session): - unsupported_framework_version = "1.11" - unsupported_instance_type = "ml.g5.24xlarge" - unsupported_instance_count = 1 - pytorch = _pytorch_estimator( - sagemaker_session, - framework_version=unsupported_framework_version, - py_version="py3", - distribution=DISTRIBUTION_PYTORCH_DDP_ENABLED, - instance_type=unsupported_instance_type, - instance_count=unsupported_instance_count, - ) - actual_pytorch_ddp = pytorch._pytorch_distribution_configuration( - distribution=pytorch.distribution - ) - expected_torch_ddp = { - "sagemaker_pytorch_ddp_enabled": True, - "sagemaker_instance_type": unsupported_instance_type, - "sagemaker_accl_enabled": False, + "sagemaker_communication_backend": "auto", } assert actual_pytorch_ddp == expected_torch_ddp From b78b4046af42a9082fef4a33f0b4e5fab3c6d8c6 Mon Sep 17 00:00:00 2001 From: vishwakaria Date: Mon, 28 Nov 2022 19:41:17 +0000 Subject: [PATCH 05/10] Update comments --- src/sagemaker/fw_utils.py | 2 +- src/sagemaker/pytorch/estimator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 0c67293eee..aa3c4a2e2d 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -1137,7 +1137,7 @@ def validate_smddp_collectives_support( f"{SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES}.\n" ) if instance_count == 1: - # ACCL is not supported for single-node jobs + # Communication backend auto is not supported for single-node jobs err_msg += ( "SMDDP Collective backend is not supported for single-node jobs.\n" "Please increase instance_count to be greater than 1.\n" diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index ca4ed8dad7..cba90a7ef7 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -348,7 +348,7 @@ def _get_communication_backend(self, pytorch_ddp_dict): Args: pytorch_ddp_dict (dict): A dictionary with options for pytorchddp distribution. Returns: - A boolean that indicates whether to enable ACCL + A boolean that indicates whether to enable SMDDP via communication backend auto """ is_smddp_coll_backend_supported = validate_smddp_collectives_support( self.framework_version, From 2df87248bcd5dc6721efc884fc45d6d5c5fbc8ac Mon Sep 17 00:00:00 2001 From: vishwakaria Date: Tue, 29 Nov 2022 19:14:17 +0000 Subject: [PATCH 06/10] Update logs --- src/sagemaker/fw_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index aa3c4a2e2d..95534a193f 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -1121,30 +1121,31 @@ def validate_smddp_collectives_support( # in case image_uri is not set, then both are mandatory if framework_version not in SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS: err_msg += ( - f"Provided framework_version {framework_version} is not supported.\n" + f"Provided framework_version {framework_version} is not supported." "Please specify one of the supported framework versions:" f" {SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS}.\n" ) if "py3" not in py_version: err_msg += ( - f"Provided py_version {py_version} is not supported.\n" + f"Provided py_version {py_version} is not supported." "Please specify py_version>=py3.\n" ) if instance_type not in SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES: err_msg += ( - f"Provided instance_type {instance_type} is not supported.\n" + f"Provided instance_type {instance_type} is not supported." "Please specify one of the supported instance types:" f"{SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES}.\n" ) if instance_count == 1: # Communication backend auto is not supported for single-node jobs err_msg += ( - "SMDDP Collective backend is not supported for single-node jobs.\n" + "SMDDP Collective backend is not supported for single-node jobs." "Please increase instance_count to be greater than 1.\n" ) if not err_msg: return True - logger.warning("Could not enable SMDDP Collectives.\n%s", err_msg) + logger.warning("Could not enable SMDDP Collectives for the training job.\n%s", err_msg) + logger.warning("Continuing training with NCCL collective backend.\n") return False From d8dcbd5c998531ee3506f2ea5312f6bd996c4811 Mon Sep 17 00:00:00 2001 From: vishwakaria Date: Mon, 5 Dec 2022 17:20:44 +0000 Subject: [PATCH 07/10] Update warning logs --- doc/frameworks/pytorch/using_pytorch.rst | 1 - src/sagemaker/fw_utils.py | 16 ++++++++++------ tests/conftest.py | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/doc/frameworks/pytorch/using_pytorch.rst b/doc/frameworks/pytorch/using_pytorch.rst index b0eb56edd7..d479610e71 100644 --- a/doc/frameworks/pytorch/using_pytorch.rst +++ b/doc/frameworks/pytorch/using_pytorch.rst @@ -238,7 +238,6 @@ but you can also overwrite them. **Supported backends:** -- ``gloo`` and ``tcp`` for CPU instances - ``gloo`` and ``nccl`` for GPU instances Launching a Distributed Training Job diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 95534a193f..b3d89e7164 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -1121,31 +1121,35 @@ def validate_smddp_collectives_support( # in case image_uri is not set, then both are mandatory if framework_version not in SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS: err_msg += ( - f"Provided framework_version {framework_version} is not supported." + f"Provided framework_version {framework_version} is not supported. " "Please specify one of the supported framework versions:" f" {SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS}.\n" ) if "py3" not in py_version: err_msg += ( - f"Provided py_version {py_version} is not supported." + f"Provided py_version {py_version} is not supported. " "Please specify py_version>=py3.\n" ) if instance_type not in SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES: err_msg += ( - f"Provided instance_type {instance_type} is not supported." + f"Provided instance_type {instance_type} is not supported. " "Please specify one of the supported instance types:" f"{SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES}.\n" ) if instance_count == 1: # Communication backend auto is not supported for single-node jobs err_msg += ( - "SMDDP Collective backend is not supported for single-node jobs." + "SMDDP Collective backend is not supported for single-node jobs. " "Please increase instance_count to be greater than 1.\n" ) if not err_msg: return True - logger.warning("Could not enable SMDDP Collectives for the training job.\n%s", err_msg) - logger.warning("Continuing training with NCCL collective backend.\n") + logger.warning( + "The system is not compatible or not configured to run SMDDP collectives optimized" + " for AWS infrastructure.\n%s", + err_msg, + ) + logger.warning("Continuing model training with default NCCL communication backend.\n") return False diff --git a/tests/conftest.py b/tests/conftest.py index 94621a547a..3f43b643eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -476,7 +476,7 @@ def pytorch_ddp_py_version(): return "py3" -@pytest.fixture(scope="module", params=["1.11", "1.11.0", "1.12", "1.12.0"]) +@pytest.fixture(scope="module", params=["1.11", "1.11.0", "1.12", "1.12.0", "1.12.1"]) def pytorch_ddp_framework_version(request): return request.param From 094ee86c5b59b98b4b56a9f0f2b4a78bd7bea97f Mon Sep 17 00:00:00 2001 From: vishwakaria Date: Mon, 12 Dec 2022 23:39:41 +0000 Subject: [PATCH 08/10] Fix nit --- src/sagemaker/fw_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index b3d89e7164..bb4fbfb0d5 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -146,8 +146,8 @@ TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"] -SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS = ("1.12.1",) -SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES = ("ml.p4d.24xlarge",) +SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS = ["1.12.1"] +SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES = ["ml.p4d.24xlarge"] SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"] From 87579b8ffbc62bcaf7f05b41481ad15c97685e4c Mon Sep 17 00:00:00 2001 From: vishwakaria Date: Fri, 20 Jan 2023 07:48:01 +0000 Subject: [PATCH 09/10] Update supported fw version to 1.13.1 --- src/sagemaker/fw_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index bb4fbfb0d5..4532989886 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -146,7 +146,7 @@ TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"] -SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS = ["1.12.1"] +SMDDP_COLLECTIVES_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1"] SMDDP_COLLECTIVES_SUPPORTED_INSTANCE_TYPES = ["ml.p4d.24xlarge"] SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"] From 86f8155a45e6d61682a471f99d134075e7af1099 Mon Sep 17 00:00:00 2001 From: vishwakaria Date: Fri, 20 Jan 2023 09:47:06 +0000 Subject: [PATCH 10/10] Update fw version in unit tests --- tests/unit/test_pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 06eefb0ab5..932d2cede4 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -830,7 +830,7 @@ def test_pytorch_ddp_comm_back_nccl(sagemaker_session): def test_pytorch_ddp_comm_backend_auto_smddp_supported(sagemaker_session): valid_instance_type = "ml.p4d.24xlarge" valid_instance_count = 2 - valid_framework_version = "1.12.1" + valid_framework_version = "1.13.1" valid_py_version = "py3" pytorch = _pytorch_estimator( sagemaker_session, @@ -875,7 +875,7 @@ def test_pytorch_ddp_comm_backend_auto_smddp_unsupported(sagemaker_session): def test_pytorch_ddp_no_comm_options_smddp_supported(sagemaker_session): valid_instance_type = "ml.p4d.24xlarge" valid_instance_count = 2 - valid_framework_version = "1.12.1" + valid_framework_version = "1.13.1" valid_py_version = "py3" pytorch = _pytorch_estimator( sagemaker_session,