diff --git a/sagemaker_studio_image_build/builder.py b/sagemaker_studio_image_build/builder.py index 352f63b..2733d92 100644 --- a/sagemaker_studio_image_build/builder.py +++ b/sagemaker_studio_image_build/builder.py @@ -64,12 +64,12 @@ def delete_zip_file(bucket, key): s3.delete_object(Bucket=bucket, Key=key) -def build_image(repository, role, bucket, extra_args, log=True): +def build_image(repository, role, bucket, compute_type, extra_args, log=True): bucket, key = upload_zip_file(repository, bucket, " ".join(extra_args)) try: from sagemaker_studio_image_build.codebuild import TempCodeBuildProject - with TempCodeBuildProject(f"{bucket}/{key}", role, repository=repository) as p: + with TempCodeBuildProject(f"{bucket}/{key}", role, repository=repository, compute_type=compute_type) as p: p.build(log) finally: delete_zip_file(bucket, key) diff --git a/sagemaker_studio_image_build/cli.py b/sagemaker_studio_image_build/cli.py index 8fab82d..9ef02b2 100644 --- a/sagemaker_studio_image_build/cli.py +++ b/sagemaker_studio_image_build/cli.py @@ -50,7 +50,7 @@ def build_image(args, extra_args): validate_args(args, extra_args) builder.build_image( - args.repository, get_role(args), args.bucket, extra_args, log=not args.no_logs + args.repository, get_role(args), args.bucket, args.compute_type, extra_args, log=not args.no_logs ) @@ -70,6 +70,13 @@ def main(): "--repository", help="The ECR repository:tag for the image (default: sagemaker-studio-${domain_id}:latest)", ) + build_parser.add_argument( + "--compute-type", + help="The CodeBuild compute type (default: BUILD_GENERAL1_SMALL)", + choices=["BUILD_GENERAL1_SMALL", "BUILD_GENERAL1_MEDIUM", + "BUILD_GENERAL1_LARGE", "BUILD_GENERAL1_2XLARGE"] + default="BUILD_GENERAL1_SMALL" + ) build_parser.add_argument( "--role", help=f"The IAM role name for CodeBuild to use (default: the Studio execution role).", diff --git a/sagemaker_studio_image_build/codebuild.py b/sagemaker_studio_image_build/codebuild.py index 372eace..583859c 100644 --- a/sagemaker_studio_image_build/codebuild.py +++ b/sagemaker_studio_image_build/codebuild.py @@ -11,13 +11,14 @@ class TempCodeBuildProject: - def __init__(self, s3_location, role, repository=None): + def __init__(self, s3_location, role, repository=None, compute_type=None): self.s3_location = s3_location self.role = role self.session = boto3.session.Session() self.domain_id, self.user_profile_name = self._get_studio_metadata() self.repo_name = None + self.compute_type = compute_type or "BUILD_GENERAL1_SMALL" if repository: self.repo_name, self.tag = repository.split(":", maxsplit=1) @@ -62,7 +63,7 @@ def __enter__(self): "environment": { "type": "LINUX_CONTAINER", "image": "aws/codebuild/standard:4.0", - "computeType": "BUILD_GENERAL1_SMALL", + "computeType": self.compute_type, "environmentVariables": [ {"name": "AWS_DEFAULT_REGION", "value": region}, {"name": "AWS_ACCOUNT_ID", "value": account},