From 2f6573cbb165041f3cdbc2192fdb1c2a8765814b Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Fri, 5 Mar 2021 14:03:18 +1100 Subject: [PATCH 1/2] Add compute type argument --- sagemaker_studio_image_build/builder.py | 4 ++-- sagemaker_studio_image_build/cli.py | 19 ++++++++++++++++++- sagemaker_studio_image_build/codebuild.py | 5 +++-- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/sagemaker_studio_image_build/builder.py b/sagemaker_studio_image_build/builder.py index 352f63b..2733d92 100644 --- a/sagemaker_studio_image_build/builder.py +++ b/sagemaker_studio_image_build/builder.py @@ -64,12 +64,12 @@ def delete_zip_file(bucket, key): s3.delete_object(Bucket=bucket, Key=key) -def build_image(repository, role, bucket, extra_args, log=True): +def build_image(repository, role, bucket, compute_type, extra_args, log=True): bucket, key = upload_zip_file(repository, bucket, " ".join(extra_args)) try: from sagemaker_studio_image_build.codebuild import TempCodeBuildProject - with TempCodeBuildProject(f"{bucket}/{key}", role, repository=repository) as p: + with TempCodeBuildProject(f"{bucket}/{key}", role, repository=repository, compute_type=compute_type) as p: p.build(log) finally: delete_zip_file(bucket, key) diff --git a/sagemaker_studio_image_build/cli.py b/sagemaker_studio_image_build/cli.py index 8fab82d..71b1122 100644 --- a/sagemaker_studio_image_build/cli.py +++ b/sagemaker_studio_image_build/cli.py @@ -25,6 +25,14 @@ def validate_args(args, extra_args): f"The value of the -f/file argument [{file_value}] is outside the working directory [{os.getcwd()}]" ) + # Validate arg compute_type + if args.compute_type: + if not args.compute_type in ['BUILD_GENERAL1_SMALL', 'BUILD_GENERAL1_MEDIUM', + 'BUILD_GENERAL1_LARGE', 'BUILD_GENERAL1_LARGE', 'BUILD_GENERAL1_LARGE', 'BUILD_GENERAL1_2XLARGE']: + raise ValueError( + f'Error parsing reference: "{args.repository}" is not a valid repository/tag' + ) + def get_role(args): if args.role: @@ -50,7 +58,7 @@ def build_image(args, extra_args): validate_args(args, extra_args) builder.build_image( - args.repository, get_role(args), args.bucket, extra_args, log=not args.no_logs + args.repository, get_role(args), args.bucket, args.compute_type, extra_args, log=not args.no_logs ) @@ -70,6 +78,15 @@ def main(): "--repository", help="The ECR repository:tag for the image (default: sagemaker-studio-${domain_id}:latest)", ) + build_parser.add_argument( + "--image", + help="The ECR repository:tag for the image (default: sagemaker-studio-${domain_id}:latest)", + ) + build_parser.add_argument( + "--compute-type", + help="The code build compute type (default: BUILD_GENERAL1_SMALL)", + default="BUILD_GENERAL1_SMALL" + ) build_parser.add_argument( "--role", help=f"The IAM role name for CodeBuild to use (default: the Studio execution role).", diff --git a/sagemaker_studio_image_build/codebuild.py b/sagemaker_studio_image_build/codebuild.py index 372eace..a0310de 100644 --- a/sagemaker_studio_image_build/codebuild.py +++ b/sagemaker_studio_image_build/codebuild.py @@ -11,13 +11,14 @@ class TempCodeBuildProject: - def __init__(self, s3_location, role, repository=None): + def __init__(self, s3_location, role, repository=None, compute_type=None): self.s3_location = s3_location self.role = role self.session = boto3.session.Session() self.domain_id, self.user_profile_name = self._get_studio_metadata() self.repo_name = None + self.compute_type = compute_type or 'BUILD_GENERAL1_SMALL' if repository: self.repo_name, self.tag = repository.split(":", maxsplit=1) @@ -62,7 +63,7 @@ def __enter__(self): "environment": { "type": "LINUX_CONTAINER", "image": "aws/codebuild/standard:4.0", - "computeType": "BUILD_GENERAL1_SMALL", + "computeType": self.compute_type, "environmentVariables": [ {"name": "AWS_DEFAULT_REGION", "value": region}, {"name": "AWS_ACCOUNT_ID", "value": account}, From f8601bbe12481f33e88b4d718f843e2da1dc4b6d Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Thu, 18 Mar 2021 11:27:03 +1100 Subject: [PATCH 2/2] Remove image arg. Move CodeBuild compute type choices to arg parse --- sagemaker_studio_image_build/cli.py | 16 +++------------- sagemaker_studio_image_build/codebuild.py | 2 +- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/sagemaker_studio_image_build/cli.py b/sagemaker_studio_image_build/cli.py index 71b1122..9ef02b2 100644 --- a/sagemaker_studio_image_build/cli.py +++ b/sagemaker_studio_image_build/cli.py @@ -25,14 +25,6 @@ def validate_args(args, extra_args): f"The value of the -f/file argument [{file_value}] is outside the working directory [{os.getcwd()}]" ) - # Validate arg compute_type - if args.compute_type: - if not args.compute_type in ['BUILD_GENERAL1_SMALL', 'BUILD_GENERAL1_MEDIUM', - 'BUILD_GENERAL1_LARGE', 'BUILD_GENERAL1_LARGE', 'BUILD_GENERAL1_LARGE', 'BUILD_GENERAL1_2XLARGE']: - raise ValueError( - f'Error parsing reference: "{args.repository}" is not a valid repository/tag' - ) - def get_role(args): if args.role: @@ -78,13 +70,11 @@ def main(): "--repository", help="The ECR repository:tag for the image (default: sagemaker-studio-${domain_id}:latest)", ) - build_parser.add_argument( - "--image", - help="The ECR repository:tag for the image (default: sagemaker-studio-${domain_id}:latest)", - ) build_parser.add_argument( "--compute-type", - help="The code build compute type (default: BUILD_GENERAL1_SMALL)", + help="The CodeBuild compute type (default: BUILD_GENERAL1_SMALL)", + choices=["BUILD_GENERAL1_SMALL", "BUILD_GENERAL1_MEDIUM", + "BUILD_GENERAL1_LARGE", "BUILD_GENERAL1_2XLARGE"] default="BUILD_GENERAL1_SMALL" ) build_parser.add_argument( diff --git a/sagemaker_studio_image_build/codebuild.py b/sagemaker_studio_image_build/codebuild.py index a0310de..583859c 100644 --- a/sagemaker_studio_image_build/codebuild.py +++ b/sagemaker_studio_image_build/codebuild.py @@ -18,7 +18,7 @@ def __init__(self, s3_location, role, repository=None, compute_type=None): self.session = boto3.session.Session() self.domain_id, self.user_profile_name = self._get_studio_metadata() self.repo_name = None - self.compute_type = compute_type or 'BUILD_GENERAL1_SMALL' + self.compute_type = compute_type or "BUILD_GENERAL1_SMALL" if repository: self.repo_name, self.tag = repository.split(":", maxsplit=1)