diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index bb865268a2..acab02e575 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -29,11 +29,10 @@
 from sagemaker.debugger import DebuggerHookConfig
 from sagemaker.debugger import TensorBoardOutputConfig  # noqa: F401 # pylint: disable=unused-import
 from sagemaker.debugger import get_rule_container_image_uri
-from sagemaker.s3 import S3Uploader
+from sagemaker.s3 import S3Uploader, parse_s3_url
 from sagemaker.fw_utils import (
     tar_and_upload_dir,
-    parse_s3_url,
     UploadedCode,
     validate_source_dir,
     _region_supports_debugger,
diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py
index f757087fdb..6259396436 100644
--- a/src/sagemaker/fw_utils.py
+++ b/src/sagemaker/fw_utils.py
@@ -21,7 +21,6 @@
 from collections import namedtuple
 
 import sagemaker.utils
-from sagemaker import s3
 
 logger = logging.getLogger("sagemaker")
@@ -264,20 +263,6 @@ def framework_version_from_tag(image_tag):
     return None if tag_match is None else tag_match.group(1)
 
 
-def parse_s3_url(url):
-    """Calls the method with the same name in the s3 module.
-
-    :func:~sagemaker.s3.parse_s3_url
-
-    Args:
-        url: A URL, expected with an s3 scheme.
-
-    Returns: The return value of s3.parse_s3_url, which is a tuple containing:
-        str: S3 bucket name str: S3 key
-    """
-    return s3.parse_s3_url(url)
-
-
 def model_code_key_prefix(code_location_key_prefix, model_name, image):
     """Returns the s3 key prefix for uploading code during model deployment
     The location returned is a potential concatenation of 2 parts
diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py
index 286732a2f9..908eeae3b1 100644
--- a/src/sagemaker/model.py
+++ b/src/sagemaker/model.py
@@ -18,8 +18,7 @@
 import os
 
 import sagemaker
-from sagemaker import fw_utils, image_uris, local, session, utils, git_utils
-from sagemaker.fw_utils import UploadedCode
+from sagemaker import fw_utils, image_uris, local, s3, session, utils, git_utils
 from sagemaker.transformer import Transformer
 
 LOGGER = logging.getLogger("sagemaker")
@@ -715,7 +714,7 @@ def __init__(
         self.git_config = git_config
         self.container_log_level = container_log_level
         if code_location:
-            self.bucket, self.key_prefix = fw_utils.parse_s3_url(code_location)
+            self.bucket, self.key_prefix = s3.parse_s3_url(code_location)
         else:
             self.bucket, self.key_prefix = None, None
         if self.git_config:
@@ -788,7 +787,7 @@ def _upload_code(self, key_prefix, repack=False):
         )
 
         self.repacked_model_data = repacked_model_data
-        self.uploaded_code = UploadedCode(
+        self.uploaded_code = fw_utils.UploadedCode(
             s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point)
         )
diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py
index 5874a53170..b2e8a18b01 100644
--- a/src/sagemaker/workflow/airflow.py
+++ b/src/sagemaker/workflow/airflow.py
@@ -17,7 +17,7 @@
 import re
 
 import sagemaker
-from sagemaker import fw_utils, job, utils, session, vpc_utils
+from sagemaker import fw_utils, job, utils, s3, session, vpc_utils
 from sagemaker.amazon import amazon_estimator
 from sagemaker.tensorflow import TensorFlow
@@ -33,10 +33,10 @@ def prepare_framework(estimator, s3_operations):
             `source_dir` ).
     """
     if estimator.code_location is not None:
-        bucket, key = fw_utils.parse_s3_url(estimator.code_location)
+        bucket, key = s3.parse_s3_url(estimator.code_location)
         key = os.path.join(key, estimator._current_job_name, "source", "sourcedir.tar.gz")
     elif estimator.uploaded_code is not None:
-        bucket, key = fw_utils.parse_s3_url(estimator.uploaded_code.s3_prefix)
+        bucket, key = s3.parse_s3_url(estimator.uploaded_code.s3_prefix)
     else:
         bucket = estimator.sagemaker_session._default_bucket
         key = os.path.join(estimator._current_job_name, "source", "sourcedir.tar.gz")
diff --git a/tests/unit/test_airflow.py b/tests/unit/test_airflow.py
index 608b89978a..f5c83f2fa8 100644
--- a/tests/unit/test_airflow.py
+++ b/tests/unit/test_airflow.py
@@ -168,7 +168,7 @@ def test_byo_training_config_all_args(sagemaker_session):
 @patch("os.path.isfile", MagicMock(return_value=True))
 @patch("sagemaker.estimator.tar_and_upload_dir", MagicMock())
 @patch(
-    "sagemaker.fw_utils.parse_s3_url",
+    "sagemaker.s3.parse_s3_url",
     MagicMock(
         return_value=["output", "tensorflow-training-{}/source/sourcedir.tar.gz".format(TIME_STAMP)]
     ),
@@ -468,7 +468,7 @@ def test_amazon_alg_training_config_all_args(sagemaker_session):
 @patch("os.path.isfile", MagicMock(return_value=True))
 @patch("sagemaker.estimator.tar_and_upload_dir", MagicMock())
 @patch(
-    "sagemaker.fw_utils.parse_s3_url",
+    "sagemaker.s3.parse_s3_url",
     MagicMock(
         return_value=[
             "output",
@@ -610,7 +610,7 @@ def test_framework_tuning_config(retrieve_image_uri, sagemaker_session):
 @patch("os.path.isfile", MagicMock(return_value=True))
 @patch("sagemaker.estimator.tar_and_upload_dir", MagicMock())
 @patch(
-    "sagemaker.fw_utils.parse_s3_url",
+    "sagemaker.s3.parse_s3_url",
     MagicMock(
         return_value=[
             "output",
@@ -1020,7 +1020,7 @@ def test_amazon_alg_model_config(sagemaker_session):
 @patch("os.path.isfile", MagicMock(return_value=True))
 @patch("sagemaker.estimator.tar_and_upload_dir", MagicMock())
 @patch(
-    "sagemaker.fw_utils.parse_s3_url",
+    "sagemaker.s3.parse_s3_url",
     MagicMock(
         return_value=[
             "output",
@@ -1185,7 +1185,7 @@ def test_transform_config(sagemaker_session):
 @patch("os.path.isfile", MagicMock(return_value=True))
 @patch("sagemaker.estimator.tar_and_upload_dir", MagicMock())
 @patch(
-    "sagemaker.fw_utils.parse_s3_url",
+    "sagemaker.s3.parse_s3_url",
     MagicMock(
         return_value=[
             "output",
@@ -1436,7 +1436,7 @@ def test_deploy_amazon_alg_model_config(sagemaker_session):
 @patch("os.path.isfile", MagicMock(return_value=True))
 @patch("sagemaker.estimator.tar_and_upload_dir", MagicMock())
 @patch(
-    "sagemaker.fw_utils.parse_s3_url",
+    "sagemaker.s3.parse_s3_url",
     MagicMock(
         return_value=[
             "output",
diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py
index c482155c2c..6f51ee891f 100644
--- a/tests/unit/test_fw_utils.py
+++ b/tests/unit/test_fw_utils.py
@@ -411,18 +411,6 @@ def test_framework_version_from_tag_other():
     assert version is None
 
 
-def test_parse_s3_url():
-    bucket, key_prefix = fw_utils.parse_s3_url("s3://bucket/code_location")
-    assert "bucket" == bucket
-    assert "code_location" == key_prefix
-
-
-def test_parse_s3_url_fail():
-    with pytest.raises(ValueError) as error:
-        fw_utils.parse_s3_url("t3://code_location")
-    assert "Expecting 's3' scheme" in str(error)
-
-
 def test_model_code_key_prefix_with_all_values_present():
     key_prefix = fw_utils.model_code_key_prefix("prefix", "model_name", "image_uri")
     assert key_prefix == "prefix/model_name"
diff --git a/tests/unit/test_s3.py b/tests/unit/test_s3.py
index f0fd35587c..fe1f868bbb 100644
--- a/tests/unit/test_s3.py
+++ b/tests/unit/test_s3.py
@@ -16,7 +16,7 @@
 import pytest
 from mock import Mock
 
-from sagemaker.s3 import S3Uploader, S3Downloader
+from sagemaker import s3
 
 BUCKET_NAME = "mybucket"
 REGION = "us-west-2"
@@ -42,7 +42,7 @@ def sagemaker_session():
 def test_upload(sagemaker_session, caplog):
     desired_s3_uri = os.path.join("s3://", BUCKET_NAME, CURRENT_JOB_NAME, SOURCE_NAME)
-    S3Uploader.upload(
+    s3.S3Uploader.upload(
         local_path="/path/to/app.jar",
         desired_s3_uri=desired_s3_uri,
         sagemaker_session=sagemaker_session,
@@ -57,7 +57,7 @@ def test_upload_with_kms_key(sagemaker_session):
     desired_s3_uri = os.path.join("s3://", BUCKET_NAME, CURRENT_JOB_NAME, SOURCE_NAME)
-    S3Uploader.upload(
+    s3.S3Uploader.upload(
         local_path="/path/to/app.jar",
         desired_s3_uri=desired_s3_uri,
         kms_key=KMS_KEY,
@@ -73,7 +73,7 @@ def test_download(sagemaker_session):
     s3_uri = os.path.join("s3://", BUCKET_NAME, CURRENT_JOB_NAME, SOURCE_NAME)
-    S3Downloader.download(
+    s3.S3Downloader.download(
         s3_uri=s3_uri, local_path="/path/for/download/", sagemaker_session=sagemaker_session
     )
     sagemaker_session.download_data.assert_called_with(
@@ -86,7 +86,7 @@ def test_download_with_kms_key(sagemaker_session):
     s3_uri = os.path.join("s3://", BUCKET_NAME, CURRENT_JOB_NAME, SOURCE_NAME)
-    S3Downloader.download(
+    s3.S3Downloader.download(
         s3_uri=s3_uri,
         local_path="/path/for/download/",
         kms_key=KMS_KEY,
@@ -98,3 +98,15 @@ def test_download_with_kms_key(sagemaker_session):
         key_prefix=os.path.join(CURRENT_JOB_NAME, SOURCE_NAME),
         extra_args={"SSECustomerKey": KMS_KEY},
     )
+
+
+def test_parse_s3_url():
+    bucket, key_prefix = s3.parse_s3_url("s3://bucket/code_location")
+    assert "bucket" == bucket
+    assert "code_location" == key_prefix
+
+
+def test_parse_s3_url_fail():
+    with pytest.raises(ValueError) as error:
+        s3.parse_s3_url("t3://code_location")
+    assert "Expecting 's3' scheme" in str(error)
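
Not part of the patch: a minimal usage sketch of the relocated helper for anyone trying the change locally, mirroring the unit test added above (the bucket name and key are illustrative only).

from sagemaker.s3 import parse_s3_url

# After this change, parse_s3_url is imported from sagemaker.s3 instead of
# sagemaker.fw_utils; it still returns a (bucket, key) tuple.
bucket, key_prefix = parse_s3_url("s3://bucket/code_location")
assert bucket == "bucket"
assert key_prefix == "code_location"

# Non-s3 schemes still raise ValueError ("Expecting 's3' scheme"), as
# exercised by test_parse_s3_url_fail in tests/unit/test_s3.py.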