2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
2.108.1.dev0
2.108.3.dev0
2 changes: 1 addition & 1 deletion doc/frameworks/pytorch/using_pytorch.rst
@@ -212,7 +212,7 @@ with the ``pytorchddp`` option as the distribution strategy.
.. note::

This PyTorch DDP support is available
in the SageMaker PyTorch Deep Learning Containers v1.12 and later.
in the SageMaker PyTorch Deep Learning Containers v1.11 and later.
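As a rough sketch of what opting into this option looks like (the entry point, role, and instance settings below are placeholders, not values taken from this change):

from sagemaker.pytorch import PyTorch

# Illustrative only: adjust entry_point, role, and instance settings to your environment.
estimator = PyTorch(
    entry_point="train.py",            # hypothetical training script
    role="SageMakerRole",              # example IAM role name
    framework_version="1.12",
    py_version="py3",
    instance_count=2,
    instance_type="ml.p4d.24xlarge",
    distribution={"pytorchddp": {"enabled": True}},
)
estimator.fit("s3://my-bucket/training-data")  # hypothetical S3 input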

Adapt Your Training Script
--------------------------
65 changes: 62 additions & 3 deletions src/sagemaker/fw_utils.py
@@ -125,15 +125,19 @@
}

PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [
"1.10",
"1.10.0",
"1.10.2",
"1.11",
"1.11.0",
"1.12",
"1.12.0",
]

# TODO: Change to 1.12.1 before merging
ACCL_SUPPORTED_FRAMEWORK_VERSIONS = (
"1.12",
"1.12.0",
)
ACCL_SUPPORTED_INSTANCE_TYPES = ("ml.p4d.24xlarge",)

SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]


@@ -855,6 +859,61 @@ def validate_pytorch_distribution(
raise ValueError(err_msg)


def validate_accl_support(
use_accl, framework_version, py_version, image_uri, instance_type, instance_count
):
"""Check if ACCL is supported for current invocation.

Args:
use_accl (bool): Whether ACCL was explicitly requested (True), explicitly
disabled (False), or left unspecified (None).
framework_version (str): A string representing the framework version selected.
py_version (str): A string representing the python version selected.
image_uri (str): A string representing a Docker image URI.
instance_type (str): SageMaker instance type.
instance_count (int): Number of training instances to use.

Raises:
ValueError: if `use_accl` is set to true and validation fails, i.e.
`instance_type` is not in ACCL_SUPPORTED_INSTANCE_TYPES or
`py_version` is not python3 or
`framework_version` is not in ACCL_SUPPORTED_FRAMEWORK_VERSIONS or
`instance_count` is not greater than 1
"""
err_msg = ""
if not image_uri:
# framework_version and py_version are ignored when image_uri is set;
# when image_uri is not set, both are mandatory and are validated here
if framework_version not in ACCL_SUPPORTED_FRAMEWORK_VERSIONS:
err_msg += (
f"Provided framework_version {framework_version} is not supported by"
" ACCL.\n"
"Please specify one of the supported framework versions:"
f" {ACCL_SUPPORTED_FRAMEWORK_VERSIONS}.\n"
)
if "py3" not in py_version:
err_msg += (
f"Provided py_version {py_version} is not supported by ACCL.\n"
"Please specify py_version>=py3.\n"
)
if instance_type not in ACCL_SUPPORTED_INSTANCE_TYPES:
err_msg += (
f"Provided instance_type {instance_type} is not supported by ACCL.\n"
"Please specify one of the supported instance types:"
f"{ACCL_SUPPORTED_INSTANCE_TYPES}.\n"
)
if instance_count == 1:
# ACCL is not supported for single-node jobs
err_msg += (
"ACCL is not supported for single-node jobs.\n"
"Please increase instance_count to be greater than 1.\n"
)
if not err_msg:
return True
if use_accl:
raise ValueError(f"Could not enable ACCL.\n {err_msg}")
logger.warning("Could not enable ACCL.\n %s", err_msg)
return False
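A quick, hypothetical sanity check of the helper above (the values are illustrative, not part of this change): a multi-node ml.p4d.24xlarge job on a supported framework and Python version passes, while explicitly requesting ACCL on a single node raises.

# Supported configuration: returns True, so ACCL can be enabled.
validate_accl_support(
    use_accl=None,
    framework_version="1.12",
    py_version="py3",
    image_uri=None,
    instance_type="ml.p4d.24xlarge",
    instance_count=2,
)

# Explicit request on a single node: raises ValueError, per the docstring above.
validate_accl_support(
    use_accl=True,
    framework_version="1.12",
    py_version="py3",
    image_uri=None,
    instance_type="ml.p4d.24xlarge",
    instance_count=1,
)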


def python_deprecation_warning(framework, latest_supported_version):
"""Placeholder docstring"""
return PYTHON_2_DEPRECATION_WARNING.format(
50 changes: 45 additions & 5 deletions src/sagemaker/pytorch/estimator.py
@@ -25,6 +25,7 @@
python_deprecation_warning,
validate_version_or_image_args,
validate_distribution,
validate_accl_support,
)
from sagemaker.pytorch import defaults
from sagemaker.pytorch.model import PyTorchModel
@@ -40,6 +41,7 @@ class PyTorch(Framework):
_framework_name = "pytorch"
LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
ACCL_ENABLED_ENV_NAME = "sagemaker_accl_enabled"

def __init__(
self,
@@ -50,7 +52,7 @@ def __init__(
hyperparameters=None,
image_uri=None,
distribution=None,
**kwargs
**kwargs,
):
"""This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.

@@ -242,18 +244,56 @@ def _pytorch_distribution_configuration(self, distribution):
"""
distribution_config = {}
pytorch_ddp_enabled = False
if "pytorchddp" in distribution:
pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
pytorch_ddp_dict = distribution.get("pytorchddp")
if pytorch_ddp_dict:
pytorch_ddp_enabled = pytorch_ddp_dict.get("enabled", False)

if pytorch_ddp_enabled:
distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
if self.instance_type is not None:
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
is_accl_enabled = self._get_accl_enabled(pytorch_ddp_dict)
distribution_config[self.ACCL_ENABLED_ENV_NAME] = is_accl_enabled
else:
distribution_config = self._distribution_configuration(distribution=distribution)

return distribution_config
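For illustration, assuming a hypothetical two-node ml.p4d.24xlarge job that passes the ACCL checks, the dictionary returned here would look roughly like this (the keys come from the class constants above; the instance type is an example value):

distribution_config = {
    "sagemaker_pytorch_ddp_enabled": True,         # LAUNCH_PYTORCH_DDP_ENV_NAME
    "sagemaker_instance_type": "ml.p4d.24xlarge",  # INSTANCE_TYPE_ENV_NAME (example value)
    "sagemaker_accl_enabled": True,                # ACCL_ENABLED_ENV_NAME
}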

def _get_accl_enabled(self, pytorch_ddp_dict):
"""Evaluates if ACCL should be enabled for current training jobs.

Case 1: Customer explicitly disables ACCL by setting use_accl to False.
Return false.

Case 2: Customer explicitly enables ACCL by setting use_accl to True.
Test if configuration is supported for ACCL.
If yes, return true. If not, raise an error.

Case 3: Customer does not specify use_accl. We try to enable by default.
Test if configuration is supported for ACCL.
If not, we return false.

Args:
pytorch_ddp_dict (dict): A dictionary with options for pytorchddp distribution.
Returns:
A boolean that indicates whether to enable ACCL
"""
use_accl = pytorch_ddp_dict.get("use_accl")
is_accl_supported = validate_accl_support(
use_accl,
self.framework_version,
self.py_version,
self.image_uri,
self.instance_type,
self.instance_count,
)

if use_accl is False or not is_accl_supported:
return False
# use_accl is True or left unspecified, and the configuration supports ACCL
return True
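A sketch of how the three docstring cases map onto the ``distribution`` argument passed to the estimator (illustrative values only):

# Case 1: explicitly disabled -> ACCL stays off even on a supported configuration.
distribution = {"pytorchddp": {"enabled": True, "use_accl": False}}
# Case 2: explicitly enabled -> validated; an unsupported setup raises ValueError.
distribution = {"pytorchddp": {"enabled": True, "use_accl": True}}
# Case 3: unspecified -> best-effort; unsupported setups log a warning and ACCL stays off.
distribution = {"pytorchddp": {"enabled": True}}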

def hyperparameters(self):
"""Return hyperparameters used by your custom PyTorch code during model training."""
hyperparameters = super(PyTorch, self).hyperparameters()
Expand All @@ -273,7 +313,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
**kwargs,
):
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.

@@ -324,7 +364,7 @@
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
dependencies=(dependencies or self.dependencies),
**kwargs
**kwargs,
)

@classmethod
4 changes: 1 addition & 3 deletions tests/conftest.py
@@ -440,9 +440,7 @@ def pytorch_ddp_py_version():
return "py3"


@pytest.fixture(
scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"]
)
@pytest.fixture(scope="module", params=["1.11", "1.11.0", "1.12", "1.12.0"])
def pytorch_ddp_framework_version(request):
return request.param

26 changes: 26 additions & 0 deletions tests/integ/test_pytorchddp.py
@@ -51,3 +51,29 @@ def test_pytorchddp_pt_mnist(

with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name)


@pytest.mark.skip(
reason="This test is skipped for now due ML capacity error."
"This test should be re-enabled later."
)
@pytest.mark.skipif(
integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS,
reason="Only allow this test to run in IAD and CMH to limit usage of p3.16xlarge",
)
def test_pytorchddp_pt_mnist_accl_disabled(sagemaker_session):
job_name = sagemaker.utils.unique_name_from_base("pt-pytorch-ddp")
estimator = PyTorch(
entry_point="mnist_pt.py",
role="SageMakerRole",
source_dir=pytorchddp_dir,
instance_count=2,
instance_type="ml.p3.16xlarge",
sagemaker_session=sagemaker_session,
framework_version="1.12",
py_version="py3",
distribution={"pytorchddp": {"enabled": True, "use_accl": False}},
)

with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name)
86 changes: 82 additions & 4 deletions tests/unit/test_fw_utils.py
@@ -880,9 +880,6 @@ def test_validate_pytorchddp_not_raises():
# Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions
pytorchddp_enabled = {"pytorchddp": {"enabled": True}}
pytorchddp_supported_fw_versions = [
"1.10",
"1.10.0",
"1.10.2",
"1.11",
"1.11.0",
"1.12",
@@ -905,7 +902,7 @@
fw_utils.validate_pytorch_distribution(
distribution=pytorchddp_enabled,
framework_name="pytorch",
framework_version="1.8",
framework_version="1.10",
py_version="py3",
image_uri=None,
)
@@ -919,3 +916,84 @@
py_version="py2",
image_uri=None,
)


def test_validate_accl_support_true():
# Framework is PyTorch, Distribution is PyTorchDDP enabled, all supported parameters
accl_supported_fw_versions = [
"1.12",
"1.12.0",
]
for framework_version in accl_supported_fw_versions:
assert (
fw_utils.validate_accl_support(
use_accl=True,
framework_version=framework_version,
py_version="py3",
image_uri=None,
instance_type="ml.p4d.24xlarge",
instance_count=2,
)
is True
)


def test_validate_accl_support_false():
# use_accl is False and the configuration is unsupported, so validation returns False without raising
assert (
fw_utils.validate_accl_support(
use_accl=False,
framework_version="1.11",
py_version="py3",
image_uri=None,
instance_type="ml.p3dn.24xlarge",
instance_count=1,
)
is False
)


def test_validate_accl_support_error():
# Case 1: Unsupported framework version
with pytest.raises(ValueError):
fw_utils.validate_accl_support(
use_accl=True,
framework_version="1.10",
py_version="py3",
image_uri=None,
instance_type="ml.p4d.24xlarge",
instance_count=2,
)

# Case 2: Unsupported Py version (other parameters supported, so the case is isolated)
with pytest.raises(ValueError):
fw_utils.validate_accl_support(
use_accl=True,
framework_version="1.12",
py_version="py2",
image_uri=None,
instance_type="ml.p4d.24xlarge",
instance_count=2,
)

# Case 3: Unsupported Instance Type
with pytest.raises(ValueError):
fw_utils.validate_accl_support(
use_accl=True,
framework_version="1.12",
py_version="py3",
image_uri=None,
instance_type="ml.p3.16xlarge",
instance_count=2,
)

# Case 4: Unsupported Instance Count
with pytest.raises(ValueError):
fw_utils.validate_accl_support(
use_accl=True,
framework_version="1.12",
py_version="py3",
image_uri=None,
instance_type="ml.p4d.24xlarge",
instance_count=1,
)