From 32199b63367e56a5a67b79350689b434efb03e09 Mon Sep 17 00:00:00 2001 From: Myra Gupta Date: Tue, 21 Jul 2020 11:09:33 -0700 Subject: [PATCH] fix: convert network_config in processing_config to dict --- src/sagemaker/workflow/airflow.py | 2 +- tests/unit/test_airflow.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py index 3dfcf8e7c5..6805d9b95d 100644 --- a/src/sagemaker/workflow/airflow.py +++ b/src/sagemaker/workflow/airflow.py @@ -1144,7 +1144,7 @@ def processing_config( config["Environment"] = processor.env if processor.network_config is not None: - config["NetworkConfig"] = processor.network_config + config["NetworkConfig"] = processor.network_config._to_request_dict() processing_resources = sagemaker.processing.ProcessingJob.prepare_processing_resources( instance_count=processor.instance_count, diff --git a/tests/unit/test_airflow.py b/tests/unit/test_airflow.py index 6c1dec87f9..d359e0ed69 100644 --- a/tests/unit/test_airflow.py +++ b/tests/unit/test_airflow.py @@ -16,6 +16,7 @@ from mock import Mock, MagicMock, patch from sagemaker import chainer, estimator, model, mxnet, tensorflow, transformer, tuner, processing +from sagemaker.network import NetworkConfig from sagemaker.processing import ProcessingInput, ProcessingOutput from sagemaker.workflow import airflow from sagemaker.amazon import amazon_estimator @@ -1598,6 +1599,13 @@ def test_deploy_config_from_amazon_alg_estimator(sagemaker_session): @patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_processing_config(sagemaker_session): + network_config = NetworkConfig( + encrypt_inter_container_traffic=False, + enable_network_isolation=True, + security_group_ids=["sg1"], + subnets=["subnet1"], + ) + processor = processing.Processor( role="arn:aws:iam::0122345678910:role/SageMakerPowerUser", image_uri="{{ image_uri }}", @@ -1612,6 +1620,7 @@ def test_processing_config(sagemaker_session): sagemaker_session=sagemaker_session, tags=[{"{{ key }}": "{{ value }}"}], env={"{{ key }}": "{{ value }}"}, + network_config=network_config, ) outputs = [ @@ -1699,5 +1708,10 @@ def test_processing_config(sagemaker_session): "RoleArn": "arn:aws:iam::0122345678910:role/SageMakerPowerUser", "StoppingCondition": {"MaxRuntimeInSeconds": 3600}, "Tags": [{"{{ key }}": "{{ value }}"}], + "NetworkConfig": { + "EnableInterContainerTrafficEncryption": False, + "EnableNetworkIsolation": True, + "VpcConfig": {"SecurityGroupIds": ["sg1"], "Subnets": ["subnet1"]}, + }, } assert config == expected_config