@@ -38,6 +38,8 @@ class PyTorch(Framework):
3838 """Handle end-to-end training and deployment of custom PyTorch code."""
3939
4040 _framework_name = "pytorch"
41+ LAUNCH_TORCH_DDP_ENV_NAME = "sagemaker_torch_ddp_enabled"
42+ TORCH_DDP_NUM_PROCESSES_PER_HOST = "sagemaker_torch_dpp_num_of_processes_per_host"
4143
4244 def __init__ (
4345 self ,
@@ -114,7 +116,14 @@ def __init__(
114116 "enabled": True
115117 }
116118 }
119+ To enable vanilla Torch DDP:
117120
121+ .. code:: python
122+ {
123+ "torch_ddp": {
124+ "enabled": True
125+ }
126+ }
118127 To enable MPI:
119128
120129 .. code:: python
@@ -186,12 +195,34 @@ def __init__(
186195 )
187196 self .distribution = distribution or {}
188197
198+ def _pytorch_distribution_configuration (self ):
199+ """Returns a dict of distribution config
200+
201+ Args:
202+ None
203+
204+ Returns:
205+ dict containing torch ddp config
206+ """
207+ distribution_config = {}
208+ if "torch_ddp" in self .distribution :
209+ torch_ddp_dict = self .distribution ["torch_ddp" ]
210+ torch_ddp_enabled = self .distribution .get ("torch_ddp" ).get ("enabled" , False )
211+ distribution_config [self .LAUNCH_TORCH_DDP_ENV_NAME ] = torch_ddp_enabled
212+
213+ if torch_ddp_dict .get ("processes_per_host" ):
214+ distribution_config [self .TORCH_DDP_NUM_PROCESSES_PER_HOST ] = torch_ddp_dict .get (
215+ "processes_per_host"
216+ )
217+ else :
218+ distribution_config = self ._distribution_configuration (distribution = self .distribution )
219+ return distribution_config
220+
189221 def hyperparameters (self ):
190222 """Return hyperparameters used by your custom PyTorch code during model training."""
191223 hyperparameters = super (PyTorch , self ).hyperparameters ()
192- additional_hyperparameters = self ._distribution_configuration (
193- distribution = self .distribution
194- )
224+ additional_hyperparameters = self ._pytorch_distribution_configuration ()
225+
195226 hyperparameters .update (Framework ._json_encode_hyperparameters (additional_hyperparameters ))
196227 return hyperparameters
197228
0 commit comments