Skip to content

Commit

Permalink
Fix local mode not using the right s3 bucket. (#144)
Browse files Browse the repository at this point in the history
* Fix local mode not using the right s3 bucket.

Local Mode should honor the inputs instead of wrongly assuming that
everyone is using the default bucket.
  • Loading branch information
iquintero committed Apr 13, 2018
1 parent 524dc86 commit d76cd2b
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 10 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
CHANGELOG
=========

1.2.3-dev
=========
* bug-fix: Fix local mode not using the right s3 bucket

1.2.2
=====

Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def train(self, input_data_config, hyperparameters):
os.mkdir(os.path.join(self.container_root, 'output'))

data_dir = self._create_tmp_folder()
bucket_name = self.sagemaker_session.default_bucket()
volumes = []

# Set up the channels for the containers. For local data we will
Expand All @@ -102,7 +101,8 @@ def train(self, input_data_config, hyperparameters):
channel_dir = os.path.join(data_dir, channel_name)
os.mkdir(channel_dir)

if uri.lower().startswith("s3://"):
if parsed_uri.scheme == 's3':
bucket_name = parsed_uri.netloc
self._download_folder(bucket_name, key, channel_dir)
else:
volumes.append(_Volume(uri, channel=channel_name))
Expand Down
52 changes: 44 additions & 8 deletions tests/unit/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pytest
import yaml
from mock import patch, Mock
from mock import call, patch, Mock

import sagemaker
from sagemaker.local.image import _SageMakerContainer
Expand All @@ -40,7 +40,7 @@
'S3DataSource': {
'S3DataDistributionType': 'FullyReplicated',
'S3DataType': 'S3Prefix',
'S3Uri': 's3://foo/bar'
'S3Uri': 's3://my-own-bucket/prefix'
}
}
}
Expand All @@ -54,12 +54,12 @@ def sagemaker_session():
boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'}
boto_mock.resource('s3').Bucket(BUCKET_NAME).objects.filter.return_value = []

ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
sms = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())

ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
sms.expand_role = Mock(return_value=EXPANDED_ROLE)

return ims
return sms


@patch('sagemaker.local.local_session.LocalSession')
Expand Down Expand Up @@ -181,16 +181,22 @@ def test_check_output():
@patch('sagemaker.local.local_session.LocalSession')
@patch('sagemaker.local.image._execute_and_stream_output')
@patch('sagemaker.local.image._SageMakerContainer._cleanup')
def test_train(LocalSession, _execute_and_stream_output, _cleanup, tmpdir, sagemaker_session):
@patch('sagemaker.local.image._SageMakerContainer._download_folder')
def test_train(_download_folder, _cleanup, _execute_and_stream_output, LocalSession, tmpdir, sagemaker_session):

directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
side_effect=[str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]):
side_effect=directories):

instance_count = 2
image = 'my-image'
sagemaker_container = _SageMakerContainer('local', instance_count, image, sagemaker_session=sagemaker_session)
sagemaker_container.train(INPUT_DATA_CONFIG, HYPERPARAMETERS)

channel_dir = os.path.join(directories[1], 'b')
download_folder_calls = [call('my-own-bucket', 'prefix', channel_dir)]
_download_folder.assert_has_calls(download_folder_calls)

docker_compose_file = os.path.join(sagemaker_container.container_root, 'docker-compose.yaml')

call_args = _execute_and_stream_output.call_args[0][0]
Expand Down Expand Up @@ -231,6 +237,36 @@ def test_serve(up, copy, copytree, tmpdir, sagemaker_session):
assert config['services'][h]['command'] == 'serve'


@patch('os.makedirs')
def test_download_folder(makedirs):
boto_mock = Mock(name='boto_session')
boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'}

session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())

train_data = Mock()
validation_data = Mock()

train_data.bucket_name.return_value = BUCKET_NAME
train_data.key = '/prefix/train/train_data.csv'
validation_data.bucket_name.return_value = BUCKET_NAME
validation_data.key = '/prefix/train/validation_data.csv'

s3_files = [train_data, validation_data]
boto_mock.resource('s3').Bucket(BUCKET_NAME).objects.filter.return_value = s3_files

obj_mock = Mock()
boto_mock.resource('s3').Object.return_value = obj_mock

sagemaker_container = _SageMakerContainer('local', 2, 'my-image', sagemaker_session=session)
sagemaker_container._download_folder(BUCKET_NAME, '/prefix', '/tmp')

obj_mock.download_file.assert_called()
calls = [call(os.path.join('/tmp', 'train/train_data.csv')),
call(os.path.join('/tmp', 'train/validation_data.csv'))]
obj_mock.download_file.assert_has_calls(calls)


def test_ecr_login_non_ecr():
session_mock = Mock()
sagemaker.local.image._ecr_login_if_needed(session_mock, 'ubuntu')
Expand Down

0 comments on commit d76cd2b

Please sign in to comment.