diff --git a/sagify/api/build.py b/sagify/api/build.py index 672dac5..a252bcd 100644 --- a/sagify/api/build.py +++ b/sagify/api/build.py @@ -16,6 +16,7 @@ def build(dir, requirements_dir, docker_tag): :param dir: [str], source root directory :param requirements_dir: [str], path to requirements.txt + :param docker_tag: [str], the Docker tag for the image """ sagify_module_path = os.path.relpath(os.path.join(dir, 'sagify/')) diff --git a/sagify/api/cloud.py b/sagify/api/cloud.py index a5b802f..13a3145 100644 --- a/sagify/api/cloud.py +++ b/sagify/api/cloud.py @@ -60,10 +60,11 @@ def train( :param input_s3_dir: [str], S3 location to input data :param output_s3_dir: [str], S3 location to save output (models, etc) :param hyperparams_file: [str], path to hyperparams json file - :param ec2_type: [str], ec2 instance type. Refere to: + :param ec2_type: [str], ec2 instance type. Refer to: https://aws.amazon.com/sagemaker/pricing/instance-types/ :param volume_size: [int], size in GB of the EBS volume :param time_out: [int], time-out in seconds + :param docker_tag: [str], the Docker tag for the image :param tags: [optional[list[dict]], default: None], List of tags for labeling a training job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. Example: @@ -107,8 +108,9 @@ def deploy(dir, s3_model_location, num_instances, ec2_type, docker_tag, tags=Non :param dir: [str], source root directory :param s3_model_location: [str], S3 model location :param num_instances: [int], number of ec2 instances - :param ec2_type: [str], ec2 instance type. Refere to: + :param ec2_type: [str], ec2 instance type. Refer to: https://aws.amazon.com/sagemaker/pricing/instance-types/ + :param docker_tag: [str], the Docker tag for the image :param tags: [optional[list[dict]], default: None], List of tags for labeling a training job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. Example: diff --git a/sagify/api/local.py b/sagify/api/local.py index b266564..5ff4035 100644 --- a/sagify/api/local.py +++ b/sagify/api/local.py @@ -13,6 +13,7 @@ def train(dir, docker_tag): Trains ML model(s) locally :param dir: [str], source root directory + :param docker_tag: [str], the Docker tag for the image """ sagify_module_path = os.path.join(dir, 'sagify') local_train_script_path = os.path.join(sagify_module_path, 'local_test', 'train_local.sh') @@ -36,6 +37,7 @@ def deploy(dir, docker_tag): Deploys ML models(s) locally :param dir: [str], source root directory + :param docker_tag: [str], the Docker tag for the image """ sagify_module_path = os.path.join(dir, 'sagify') local_deploy_script_path = os.path.join(sagify_module_path, 'local_test', 'deploy_local.sh') diff --git a/sagify/api/push.py b/sagify/api/push.py index f5a0b73..a93f1bd 100644 --- a/sagify/api/push.py +++ b/sagify/api/push.py @@ -13,6 +13,7 @@ def push(dir, docker_tag): Push Docker image to AWS ECS :param dir: [str], source root directory + :param docker_tag: [str], the Docker tag for the image """ sagify_module_path = os.path.relpath(os.path.join(dir, 'sagify/')) diff --git a/sagify/sagemaker/sagemaker.py b/sagify/sagemaker/sagemaker.py index 253c8d4..aeab69c 100644 --- a/sagify/sagemaker/sagemaker.py +++ b/sagify/sagemaker/sagemaker.py @@ -158,7 +158,7 @@ def _construct_image_location(self, image_name): account = self.boto_session.client('sts').get_caller_identity()['Account'] region = self.boto_session.region_name - return '{account}.dkr.ecr.{region}.amazonaws.com/{image}:latest'.format( + return '{account}.dkr.ecr.{region}.amazonaws.com/{image}'.format( account=account, region=region, image=image_name diff --git a/setup.cfg b/setup.cfg index 57df763..5a67fd9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,3 +25,4 @@ universal=1 [flake8] max-line-length=100 +exclude=.svn,CVS,.bzr,.hg,.git,__pycache__,.tox,.eggs,venv/,sagify/commands/__init__.py diff --git a/tests/commands/test_cloud.py b/tests/commands/test_cloud.py index 6e85836..2cddc00 100644 --- a/tests/commands/test_cloud.py +++ b/tests/commands/test_cloud.py @@ -220,6 +220,52 @@ def test_train_with_tags_arg_happy_case(self): assert result.exit_code == 0 + def test_train_with_docker_tag_arg_happy_case(self): + runner = CliRunner() + + with patch( + 'sagify.commands.initialize._get_local_aws_profiles', + return_value=['default', 'sagify'] + ): + with patch.object( + sagify.config.config.ConfigManager, + 'get_config', + lambda _: Config( + image_name='sagemaker-img', aws_profile='sagify', aws_region='us-east-1' + ) + ): + with patch( + 'sagify.sagemaker.sagemaker.SageMakerClient' + ) as mocked_sage_maker_client: + instance = mocked_sage_maker_client.return_value + with runner.isolated_filesystem(): + runner.invoke(cli=cli, args=['init'], input='my_app\n1\n2\nus-east-1\n') + result = runner.invoke( + cli=cli, + args=[ + '--docker-tag', 'some-docker-tag', + 'cloud', 'train', + '-i', 's3://bucket/input', + '-o', 's3://bucket/output', + '-e', 'ml.c4.2xlarge' + ] + ) + + assert instance.train.call_count == 1 + instance.train.assert_called_with( + image_name='sagemaker-img:some-docker-tag', + input_s3_data_location='s3://bucket/input', + train_instance_count=1, + train_instance_type='ml.c4.2xlarge', + train_volume_size=30, + train_max_run=24 * 60 * 60, + output_path='s3://bucket/output', + hyperparameters=None, + tags=None + ) + + assert result.exit_code == 0 + def test_train_with_dir_arg_happy_case(self): runner = CliRunner() @@ -446,6 +492,48 @@ def test_deploy_with_tags_arg_happy_case(self): assert result.exit_code == 0 + def test_deploy_with_docker_tag_arg_happy_case(self): + runner = CliRunner() + + with patch( + 'sagify.commands.initialize._get_local_aws_profiles', + return_value=['default', 'sagify'] + ): + with patch.object( + sagify.config.config.ConfigManager, + 'get_config', + lambda _: Config( + image_name='sagemaker-img', aws_profile='sagify', aws_region='us-east-1' + ) + ): + with patch( + 'sagify.sagemaker.sagemaker.SageMakerClient' + ) as mocked_sage_maker_client: + instance = mocked_sage_maker_client.return_value + with runner.isolated_filesystem(): + runner.invoke(cli=cli, args=['init'], input='my_app\n1\n2\nus-east-1\n') + result = runner.invoke( + cli=cli, + args=[ + '-t', 'some-docker-tag', + 'cloud', 'deploy', + '-m', 's3://bucket/model/location/model.tar.gz', + '-n', '2', + '-e', 'ml.c4.2xlarge' + ] + ) + + assert instance.deploy.call_count == 1 + instance.deploy.assert_called_with( + image_name='sagemaker-img:some-docker-tag', + s3_model_location='s3://bucket/model/location/model.tar.gz', + train_instance_count=2, + train_instance_type='ml.c4.2xlarge', + tags=None + ) + + assert result.exit_code == 0 + def test_deploy_with_invalid_dir_arg_happy_case(self): runner = CliRunner()