Skip to content

Commit ef96783

Browse files
author
Satish Pasumarthi
committed
feature: Add support for Pytorch Vanilla DDP
1 parent 9212529 commit ef96783

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

src/sagemaker/pytorch/estimator.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class PyTorch(Framework):
3838
"""Handle end-to-end training and deployment of custom PyTorch code."""
3939

4040
_framework_name = "pytorch"
41+
LAUNCH_TORCH_DDP_ENV_NAME = "sagemaker_torch_ddp_enabled"
42+
TORCH_DDP_NUM_PROCESSES_PER_HOST = "sagemaker_torch_dpp_num_of_processes_per_host"
4143

4244
def __init__(
4345
self,
@@ -114,7 +116,14 @@ def __init__(
114116
"enabled": True
115117
}
116118
}
119+
To enable vanilla Torch DDP:
117120
121+
.. code:: python
122+
{
123+
"torch_ddp": {
124+
"enabled": True
125+
}
126+
}
118127
To enable MPI:
119128
120129
.. code:: python
@@ -186,12 +195,34 @@ def __init__(
186195
)
187196
self.distribution = distribution or {}
188197

198+
def _pytorch_distribution_configuration(self):
    """Build the distribution hyperparameters for this estimator.

    Reads ``self.distribution``:

    * Without a ``torch_ddp`` entry, the generic
      ``_distribution_configuration`` result is returned unchanged.
    * With a ``torch_ddp`` entry, the result carries the torch DDP launch
      flag (the entry's ``enabled`` value, ``False`` when missing) and,
      when a truthy ``processes_per_host`` is given, the per-host process
      count.
    * With a ``torch_ddp`` entry that is not enabled, any *other* strategies
      present in ``self.distribution`` are additionally passed to
      ``_distribution_configuration`` so they are not silently dropped.

    Returns:
        dict: distribution settings keyed by hyperparameter name.
    """
    if "torch_ddp" not in self.distribution:
        return self._distribution_configuration(distribution=self.distribution)

    # ``or {}`` tolerates an explicit ``"torch_ddp": None`` entry.
    torch_ddp_options = self.distribution["torch_ddp"] or {}
    torch_ddp_enabled = torch_ddp_options.get("enabled", False)

    distribution_config = {}
    if not torch_ddp_enabled:
        # torch DDP is switched off, so the remaining strategies still apply.
        # The generic helper is skipped when nothing else is configured so
        # the result stays exactly the torch DDP keys in that case.
        other_strategies = {
            name: options
            for name, options in self.distribution.items()
            if name != "torch_ddp"
        }
        if other_strategies:
            distribution_config.update(
                self._distribution_configuration(distribution=other_strategies)
            )

    distribution_config[self.LAUNCH_TORCH_DDP_ENV_NAME] = torch_ddp_enabled

    processes_per_host = torch_ddp_options.get("processes_per_host")
    if processes_per_host:
        distribution_config[self.TORCH_DDP_NUM_PROCESSES_PER_HOST] = processes_per_host

    return distribution_config
220+
189221
def hyperparameters(self):
    """Return hyperparameters used by your custom PyTorch code during model training.

    The parent class's hyperparameters are extended with the JSON-encoded
    output of ``_pytorch_distribution_configuration``.
    """
    merged = super(PyTorch, self).hyperparameters()
    distribution_settings = self._pytorch_distribution_configuration()
    encoded_settings = Framework._json_encode_hyperparameters(distribution_settings)
    merged.update(encoded_settings)
    return merged
197228

tests/unit/test_pytorch.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
"TrialComponentDisplayName": "tc",
5757
}
5858

59+
# Distribution argument with torch DDP enabled and two processes per host.
DISTRIBUTION_TORCH_DDP_ENABLED = {
    "torch_ddp": {
        "enabled": True,
        "processes_per_host": 2,
    },
}
60+
5961

6062
@pytest.fixture(name="sagemaker_session")
6163
def fixture_sagemaker_session():
@@ -691,3 +693,19 @@ def test_custom_image_estimator_deploy(
691693
pytorch.fit(inputs="s3://mybucket/train", job_name="new_name")
692694
model = pytorch.create_model(image_uri=custom_image)
693695
assert model.image_uri == custom_image
696+
697+
698+
def test_torch_ddp_distribution_configuration(
    sagemaker_session, pytorch_training_version, pytorch_training_py_version
):
    """Check the distribution config produced for an enabled ``torch_ddp`` entry."""
    estimator = _pytorch_estimator(
        sagemaker_session,
        framework_version=pytorch_training_version,
        py_version=pytorch_training_py_version,
        distribution=DISTRIBUTION_TORCH_DDP_ENABLED,
    )
    # NOTE(review): the second key spells "dpp", not "ddp"; it is kept as-is
    # here -- confirm the intended spelling against whatever reads this key.
    assert estimator._pytorch_distribution_configuration() == {
        "sagemaker_torch_ddp_enabled": True,
        "sagemaker_torch_dpp_num_of_processes_per_host": 2,
    }

0 commit comments

Comments
 (0)