1616import pytest
1717from mock import Mock
1818
19- from sagemaker . s3 import S3Uploader , S3Downloader
19+ from sagemaker import s3
2020
2121BUCKET_NAME = "mybucket"
2222REGION = "us-west-2"
@@ -42,7 +42,7 @@ def sagemaker_session():
4242
4343def test_upload (sagemaker_session , caplog ):
4444 desired_s3_uri = os .path .join ("s3://" , BUCKET_NAME , CURRENT_JOB_NAME , SOURCE_NAME )
45- S3Uploader .upload (
45+ s3 . S3Uploader .upload (
4646 local_path = "/path/to/app.jar" ,
4747 desired_s3_uri = desired_s3_uri ,
4848 sagemaker_session = sagemaker_session ,
@@ -57,7 +57,7 @@ def test_upload(sagemaker_session, caplog):
5757
5858def test_upload_with_kms_key (sagemaker_session ):
5959 desired_s3_uri = os .path .join ("s3://" , BUCKET_NAME , CURRENT_JOB_NAME , SOURCE_NAME )
60- S3Uploader .upload (
60+ s3 . S3Uploader .upload (
6161 local_path = "/path/to/app.jar" ,
6262 desired_s3_uri = desired_s3_uri ,
6363 kms_key = KMS_KEY ,
@@ -73,7 +73,7 @@ def test_upload_with_kms_key(sagemaker_session):
7373
7474def test_download (sagemaker_session ):
7575 s3_uri = os .path .join ("s3://" , BUCKET_NAME , CURRENT_JOB_NAME , SOURCE_NAME )
76- S3Downloader .download (
76+ s3 . S3Downloader .download (
7777 s3_uri = s3_uri , local_path = "/path/for/download/" , sagemaker_session = sagemaker_session
7878 )
7979 sagemaker_session .download_data .assert_called_with (
@@ -86,7 +86,7 @@ def test_download(sagemaker_session):
8686
8787def test_download_with_kms_key (sagemaker_session ):
8888 s3_uri = os .path .join ("s3://" , BUCKET_NAME , CURRENT_JOB_NAME , SOURCE_NAME )
89- S3Downloader .download (
89+ s3 . S3Downloader .download (
9090 s3_uri = s3_uri ,
9191 local_path = "/path/for/download/" ,
9292 kms_key = KMS_KEY ,
@@ -98,3 +98,15 @@ def test_download_with_kms_key(sagemaker_session):
9898 key_prefix = os .path .join (CURRENT_JOB_NAME , SOURCE_NAME ),
9999 extra_args = {"SSECustomerKey" : KMS_KEY },
100100 )
101+
102+
103+ def test_parse_s3_url ():
104+ bucket , key_prefix = s3 .parse_s3_url ("s3://bucket/code_location" )
105+ assert "bucket" == bucket
106+ assert "code_location" == key_prefix
107+
108+
109+ def test_parse_s3_url_fail ():
110+ with pytest .raises (ValueError ) as error :
111+ s3 .parse_s3_url ("t3://code_location" )
112+ assert "Expecting 's3' scheme" in str (error )
0 commit comments