# TODO: Change to 1.12.1 before merging
ACCL_SUPPORTED_FRAMEWORK_VERSIONS = (
    "1.12",
    "1.12.0",
)
ACCL_SUPPORTED_INSTANCE_TYPES = ("ml.p4d.24xlarge",)


def validate_accl_support(
    use_accl, framework_version, py_version, image_uri, instance_type, instance_count
):
    """Check if ACCL is supported for current invocation.

    Every failed check is appended to one message, so a single error or
    warning lists all problems at once.

    Args:
        use_accl (bool): The caller's ``use_accl`` setting. When truthy, an
            unsupported configuration raises; otherwise it is only logged.
        framework_version (str): A string representing the framework version selected.
            Ignored when ``image_uri`` is set.
        py_version (str): A string representing the python version selected.
            Ignored when ``image_uri`` is set.
        image_uri (str): A string representing a Docker image URI.
        instance_type (str): SageMaker instance type.
        instance_count (int): Number of training instances to use.

    Returns:
        bool: True if every check passes; False if a check fails and
        ``use_accl`` is not truthy (a warning is logged in that case).

    Raises:
        ValueError: if `use_accl` is truthy and validation fails, i.e.
            `instance_type` is not in ACCL_SUPPORTED_INSTANCE_TYPES or
            `py_version` does not contain "py3" or
            `framework_version` is not in ACCL_SUPPORTED_FRAMEWORK_VERSIONS or
            `instance_count` is 1
    """
    err_msg = ""
    if not image_uri:
        # ignore framework_version and py_version if image_uri is set
        # in case image_uri is not set, then both are mandatory
        if framework_version not in ACCL_SUPPORTED_FRAMEWORK_VERSIONS:
            err_msg += (
                f"Provided framework_version {framework_version} is not supported by"
                " ACCL.\n"
                "Please specify one of the supported framework versions:"
                f" {ACCL_SUPPORTED_FRAMEWORK_VERSIONS}.\n"
            )
        # A missing py_version is reported like any other unsupported value
        # instead of failing with a TypeError on the substring test.
        if not py_version or "py3" not in py_version:
            err_msg += (
                f"Provided py_version {py_version} is not supported by ACCL.\n"
                "Please specify py_version>=py3.\n"
            )
    if instance_type not in ACCL_SUPPORTED_INSTANCE_TYPES:
        err_msg += (
            f"Provided instance_type {instance_type} is not supported by ACCL.\n"
            "Please specify one of the supported instance types:"
            f" {ACCL_SUPPORTED_INSTANCE_TYPES}.\n"
        )
    if instance_count == 1:
        # ACCL is not supported for single-node jobs
        err_msg += (
            "ACCL is not supported for single-node jobs.\n"
            "Please increase instance_count to be greater than 1.\n"
        )
    if not err_msg:
        return True
    if use_accl:
        # The customer explicitly asked for ACCL, so an unsupported setup is an error.
        raise ValueError(f"Could not enable ACCL.\n {err_msg}")
    logger.warning("Could not enable ACCL.\n %s", err_msg)
    return False
class PyTorch(Framework): _framework_name = "pytorch" LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled" INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type" + ACCL_ENABLED_ENV_NAME = "sagemaker_accl_enabled" def __init__( self, @@ -50,7 +52,7 @@ def __init__( hyperparameters=None, image_uri=None, distribution=None, - **kwargs + **kwargs, ): """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment. @@ -242,18 +244,56 @@ def _pytorch_distribution_configuration(self, distribution): """ distribution_config = {} pytorch_ddp_enabled = False - if "pytorchddp" in distribution: - pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False) + pytorch_ddp_dict = distribution.get("pytorchddp") + if pytorch_ddp_dict: + pytorch_ddp_enabled = pytorch_ddp_dict.get("enabled", False) if pytorch_ddp_enabled: distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled if self.instance_type is not None: distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type + is_accl_enabled = self._get_accl_enabled(pytorch_ddp_dict) + distribution_config[self.ACCL_ENABLED_ENV_NAME] = is_accl_enabled else: distribution_config = self._distribution_configuration(distribution=distribution) return distribution_config + def _get_accl_enabled(self, pytorch_ddp_dict): + """Evaluates if ACCL should be enabled for current training jobs. + + Case 1: Customer explicitly disables ACCL by setting use_accl to False. + Return false. + + Case 2: Customer explicitly enables ACCL by setting use_accl to True. + Test if configuration is supported for ACCL. + If yes, return true. If not, throw an error. + + Case 3: Customer does not specify use_accl. We try to enable by default. + Test if configuration is supported for ACCL. + If not, we return false. + + Args: + pytorch_ddp_dict (dict): A dictionary with options for pytorchddp distribution. 
+ Returns: + A boolean that indicates whether to enable ACCL + """ + use_accl = pytorch_ddp_dict.get("use_accl") + is_accl_supported = validate_accl_support( + use_accl, + self.framework_version, + self.py_version, + self.image_uri, + self.instance_type, + self.instance_count, + ) + + if use_accl is False or not is_accl_supported: + return False + if use_accl and is_accl_supported: + return True + return use_accl + def hyperparameters(self): """Return hyperparameters used by your custom PyTorch code during model training.""" hyperparameters = super(PyTorch, self).hyperparameters() @@ -273,7 +313,7 @@ def create_model( entry_point=None, source_dir=None, dependencies=None, - **kwargs + **kwargs, ): """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``. @@ -324,7 +364,7 @@ def create_model( sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), dependencies=(dependencies or self.dependencies), - **kwargs + **kwargs, ) @classmethod diff --git a/tests/conftest.py b/tests/conftest.py index 59397ec9af..632e532000 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -440,9 +440,7 @@ def pytorch_ddp_py_version(): return "py3" -@pytest.fixture( - scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"] -) +@pytest.fixture(scope="module", params=["1.11", "1.11.0", "1.12", "1.12.0"]) def pytorch_ddp_framework_version(request): return request.param diff --git a/tests/integ/test_pytorchddp.py b/tests/integ/test_pytorchddp.py index c580fdebc2..6a886d71f2 100644 --- a/tests/integ/test_pytorchddp.py +++ b/tests/integ/test_pytorchddp.py @@ -51,3 +51,29 @@ def test_pytorchddp_pt_mnist( with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name) + + +@pytest.mark.skip( + reason="This test is skipped for now due ML capacity error." + "This test should be re-enabled later." 
+) +@pytest.mark.skipif( + integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS, + reason="Only allow this test to run in IAD and CMH to limit usage of p3.16xlarge", +) +def test_pytorchddp_pt_mnist_accl_disabled(sagemaker_session): + job_name = sagemaker.utils.unique_name_from_base("pt-pytorch-ddp") + estimator = PyTorch( + entry_point="mnist_pt.py", + role="SageMakerRole", + source_dir=pytorchddp_dir, + instance_count=2, + instance_type="ml.p3.16xlarge", + sagemaker_session=sagemaker_session, + framework_version="1.12", + py_version="py3", + distribution={"pytorchddp": {"enabled": True, "use_accl": False}}, + ) + + with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): + estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index e378b7a0a2..47e2394721 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -880,9 +880,6 @@ def test_validate_pytorchddp_not_raises(): # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions pytorchddp_enabled = {"pytorchddp": {"enabled": True}} pytorchddp_supported_fw_versions = [ - "1.10", - "1.10.0", - "1.10.2", "1.11", "1.11.0", "1.12", @@ -905,7 +902,7 @@ def test_validate_pytorchddp_raises(): fw_utils.validate_pytorch_distribution( distribution=pytorchddp_enabled, framework_name="pytorch", - framework_version="1.8", + framework_version="1.10", py_version="py3", image_uri=None, ) @@ -919,3 +916,84 @@ def test_validate_pytorchddp_raises(): py_version="py2", image_uri=None, ) + + +def test_validate_accl_support_true(): + # Framework is PyTorch, Distribution is PyTorchDDP enabled, all supported parameters + accl_supported_fw_versions = [ + "1.12", + "1.12.0", + ] + for framework_version in accl_supported_fw_versions: + assert ( + fw_utils.validate_accl_support( + use_accl=True, + framework_version=framework_version, + py_version="py3", + 
def test_validate_accl_support_error():
    """Each unsupported parameter, on its own, makes an explicit ACCL request raise.

    Every case starts from a fully supported configuration and changes exactly
    one argument, so a failure points at the check for that argument.
    """
    supported_args = {
        "framework_version": "1.12",
        "py_version": "py3",
        "image_uri": None,
        "instance_type": "ml.p4d.24xlarge",
        "instance_count": 2,
    }
    unsupported_overrides = (
        {"framework_version": "1.10"},  # Case 1: Unsupported framework version
        {"py_version": "py2"},  # Case 2: Unsupported Py version
        {"instance_type": "ml.p3.16xlarge"},  # Case 3: Unsupported Instance Type
        {"instance_count": 1},  # Case 4: Unsupported Instance Count
    )
    for override in unsupported_overrides:
        with pytest.raises(ValueError):
            fw_utils.validate_accl_support(use_accl=True, **{**supported_args, **override})
def test_pytorch_ddp_accl_disabled(sagemaker_session):
    """``use_accl=False`` yields ``sagemaker_accl_enabled: False`` in the DDP config."""
    estimator = _pytorch_estimator(
        sagemaker_session,
        framework_version="1.12",
        py_version="py3",
        distribution=DISTRIBUTION_PYTORCH_ACCL_DISABLED,
    )
    distribution_config = estimator._pytorch_distribution_configuration(
        distribution=estimator.distribution
    )
    assert distribution_config == {
        "sagemaker_pytorch_ddp_enabled": True,
        "sagemaker_instance_type": INSTANCE_TYPE,
        "sagemaker_accl_enabled": False,
    }
def test_pytorch_ddp_accl_explicit_enabled_and_unsupported(sagemaker_session):
    """Explicit ``use_accl=True`` on an unsupported setup raises and reports the problems."""
    unsupported_framework_version = "1.11"
    unsupported_instance_count = 1
    pytorch = _pytorch_estimator(
        sagemaker_session,
        framework_version=unsupported_framework_version,
        py_version="py3",
        distribution=DISTRIBUTION_PYTORCH_ACCL_ENABLED,
        instance_count=unsupported_instance_count,
    )
    with pytest.raises(ValueError) as error:
        pytorch._pytorch_distribution_configuration(distribution=pytorch.distribution)
    # Check the exception's own message; str() of the ExceptionInfo wrapper is
    # a pytest-version-dependent representation, not the message itself.
    message = str(error.value)
    assert f"framework_version {unsupported_framework_version} is not supported" in message
    assert "ACCL is not supported for single-node jobs." in message
def test_pytorch_ddp_accl_default_on_and_unsupported(sagemaker_session):
    """With ``use_accl`` unset and an unsupported setup, ACCL is reported as disabled."""
    instance_type = "ml.g5.24xlarge"
    estimator = _pytorch_estimator(
        sagemaker_session,
        framework_version="1.11",
        py_version="py3",
        distribution=DISTRIBUTION_PYTORCH_DDP_ENABLED,
        instance_type=instance_type,
        instance_count=1,
    )
    distribution_config = estimator._pytorch_distribution_configuration(
        distribution=estimator.distribution
    )
    assert distribution_config == {
        "sagemaker_pytorch_ddp_enabled": True,
        "sagemaker_instance_type": instance_type,
        "sagemaker_accl_enabled": False,
    }