From ef9678307d7bf9513f623bfeb0754487aea8ed2c Mon Sep 17 00:00:00 2001 From: Satish Pasumarthi Date: Thu, 14 Oct 2021 16:43:51 -0700 Subject: [PATCH 1/2] feature: Add support for Pytorch Vanilla DDP --- src/sagemaker/pytorch/estimator.py | 37 +++++++++++++++++++++++++++--- tests/unit/test_pytorch.py | 18 +++++++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 44d5cfeb98..a2f251299f 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -38,6 +38,8 @@ class PyTorch(Framework): """Handle end-to-end training and deployment of custom PyTorch code.""" _framework_name = "pytorch" + LAUNCH_TORCH_DDP_ENV_NAME = "sagemaker_torch_ddp_enabled" + TORCH_DDP_NUM_PROCESSES_PER_HOST = "sagemaker_torch_dpp_num_of_processes_per_host" def __init__( self, @@ -114,7 +116,14 @@ def __init__( "enabled": True } } + To enable vanilla Torch DDP: + .. code:: python + { + "torch_ddp": { + "enabled": True + } + } To enable MPI: .. code:: python @@ -186,12 +195,34 @@ def __init__( ) self.distribution = distribution or {} + def _pytorch_distribution_configuration(self): + """Returns a dict of distribution config + + Args: + None + + Returns: + dict containing torch ddp config + """ + distribution_config = {} + if "torch_ddp" in self.distribution: + torch_ddp_dict = self.distribution["torch_ddp"] + torch_ddp_enabled = self.distribution.get("torch_ddp").get("enabled", False) + distribution_config[self.LAUNCH_TORCH_DDP_ENV_NAME] = torch_ddp_enabled + + if torch_ddp_dict.get("processes_per_host"): + distribution_config[self.TORCH_DDP_NUM_PROCESSES_PER_HOST] = torch_ddp_dict.get( + "processes_per_host" + ) + else: + distribution_config = self._distribution_configuration(distribution=self.distribution) + return distribution_config + def hyperparameters(self): """Return hyperparameters used by your custom PyTorch code during model training.""" hyperparameters = super(PyTorch, self).hyperparameters() - additional_hyperparameters = self._distribution_configuration( - distribution=self.distribution - ) + additional_hyperparameters = self._pytorch_distribution_configuration() + hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters)) return hyperparameters diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 5e5046fd6f..dead797105 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -56,6 +56,8 @@ "TrialComponentDisplayName": "tc", } +DISTRIBUTION_TORCH_DDP_ENABLED = {"torch_ddp": {"enabled": True, "processes_per_host": 2}} + @pytest.fixture(name="sagemaker_session") def fixture_sagemaker_session(): @@ -691,3 +693,19 @@ def test_custom_image_estimator_deploy( pytorch.fit(inputs="s3://mybucket/train", job_name="new_name") model = pytorch.create_model(image_uri=custom_image) assert model.image_uri == custom_image + + +def test_torch_ddp_distribution_configuration( + sagemaker_session, pytorch_training_version, pytorch_training_py_version +): + pytorch = _pytorch_estimator( + sagemaker_session, + framework_version=pytorch_training_version, + py_version=pytorch_training_py_version, + distribution=DISTRIBUTION_TORCH_DDP_ENABLED + ) + actual_torch_ddp = pytorch._pytorch_distribution_configuration() + expected_torch_ddp = { + "sagemaker_torch_ddp_enabled": True, + "sagemaker_torch_dpp_num_of_processes_per_host": 2} + assert actual_torch_ddp == expected_torch_ddp From 25f9b99c578475542486ebc1c640db523fb930b6 Mon Sep 17 00:00:00 2001 From: Satish Pasumarthi Date: Thu, 14 Oct 2021 16:53:43 -0700 Subject: [PATCH 2/2] Fix black formatting issues --- tests/unit/test_pytorch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index dead797105..34fc0b26fb 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -702,10 +702,11 @@ def test_torch_ddp_distribution_configuration( sagemaker_session, framework_version=pytorch_training_version, py_version=pytorch_training_py_version, - distribution=DISTRIBUTION_TORCH_DDP_ENABLED + distribution=DISTRIBUTION_TORCH_DDP_ENABLED, ) actual_torch_ddp = pytorch._pytorch_distribution_configuration() expected_torch_ddp = { "sagemaker_torch_ddp_enabled": True, - "sagemaker_torch_dpp_num_of_processes_per_host": 2} + "sagemaker_torch_dpp_num_of_processes_per_host": 2, + } assert actual_torch_ddp == expected_torch_ddp