diff --git a/metaflow/cmd/configure_cmd.py b/metaflow/cmd/configure_cmd.py
index d4ef1a5a54..ffe8729d7b 100644
--- a/metaflow/cmd/configure_cmd.py
+++ b/metaflow/cmd/configure_cmd.py
@@ -499,6 +499,14 @@ def configure_aws_batch(existing_env):
         default=existing_env.get("METAFLOW_BATCH_CONTAINER_IMAGE", ""),
         show_default=True,
     )
+    # Set private image repository credentials secret
+    env["METAFLOW_BATCH_CONTAINER_IMAGE_CREDS_SECRET"] = click.prompt(
+        cyan("[METAFLOW_BATCH_CONTAINER_IMAGE_CREDS_SECRET]")
+        + yellow(" (optional)")
+        + " Secret containing credentials if using a private image repository",
+        default=existing_env.get("METAFLOW_BATCH_CONTAINER_IMAGE_CREDS_SECRET", ""),
+        show_default=True,
+    )

     # Configure AWS Step Functions for scheduling.
     if click.confirm(
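Aside: the wizard above persists this key into the Metaflow profile JSON, and from_conf() (see metaflow_config.py below) also honors the corresponding environment variable, so the setting can be supplied without re-running the wizard. A minimal sketch, assuming a placeholder secret ARN:

    import os

    # Hypothetical ARN for illustration only; it must reference a Secrets
    # Manager secret holding the private registry's username and password.
    os.environ["METAFLOW_BATCH_CONTAINER_IMAGE_CREDS_SECRET"] = (
        "arn:aws:secretsmanager:us-east-1:111122223333:secret:docker-creds"
    )
    # Any Metaflow process started with this variable set resolves
    # BATCH_CONTAINER_IMAGE_CREDS_SECRET to the ARN above.
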
diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py
index 52cedd89f5..6add8aa84a 100644
--- a/metaflow/metaflow_config.py
+++ b/metaflow/metaflow_config.py
@@ -100,9 +100,11 @@
 DATATOOLS_SUFFIX = from_conf("DATATOOLS_SUFFIX", "data")
 DATATOOLS_S3ROOT = from_conf(
     "DATATOOLS_S3ROOT",
-    os.path.join(DATASTORE_SYSROOT_S3, DATATOOLS_SUFFIX)
-    if DATASTORE_SYSROOT_S3
-    else None,
+    (
+        os.path.join(DATASTORE_SYSROOT_S3, DATATOOLS_SUFFIX)
+        if DATASTORE_SYSROOT_S3
+        else None
+    ),
 )

 TEMPDIR = from_conf("TEMPDIR", ".")
@@ -120,25 +122,31 @@
 # Similar to DATATOOLS_LOCALROOT, this is used ONLY by the IncludeFile's internal implementation.
 DATATOOLS_AZUREROOT = from_conf(
     "DATATOOLS_AZUREROOT",
-    os.path.join(DATASTORE_SYSROOT_AZURE, DATATOOLS_SUFFIX)
-    if DATASTORE_SYSROOT_AZURE
-    else None,
+    (
+        os.path.join(DATASTORE_SYSROOT_AZURE, DATATOOLS_SUFFIX)
+        if DATASTORE_SYSROOT_AZURE
+        else None
+    ),
 )
 # GS datatools root location
 # Note: we do not expose an actual datatools library for GS (like we do for S3)
 # Similar to DATATOOLS_LOCALROOT, this is used ONLY by the IncludeFile's internal implementation.
 DATATOOLS_GSROOT = from_conf(
     "DATATOOLS_GSROOT",
-    os.path.join(DATASTORE_SYSROOT_GS, DATATOOLS_SUFFIX)
-    if DATASTORE_SYSROOT_GS
-    else None,
+    (
+        os.path.join(DATASTORE_SYSROOT_GS, DATATOOLS_SUFFIX)
+        if DATASTORE_SYSROOT_GS
+        else None
+    ),
 )
 # Local datatools root location
 DATATOOLS_LOCALROOT = from_conf(
     "DATATOOLS_LOCALROOT",
-    os.path.join(DATASTORE_SYSROOT_LOCAL, DATATOOLS_SUFFIX)
-    if DATASTORE_SYSROOT_LOCAL
-    else None,
+    (
+        os.path.join(DATASTORE_SYSROOT_LOCAL, DATATOOLS_SUFFIX)
+        if DATASTORE_SYSROOT_LOCAL
+        else None
+    ),
 )

 # Secrets Backend - AWS Secrets Manager configuration
@@ -156,9 +164,11 @@
 )
 CARD_AZUREROOT = from_conf(
     "CARD_AZUREROOT",
-    os.path.join(DATASTORE_SYSROOT_AZURE, CARD_SUFFIX)
-    if DATASTORE_SYSROOT_AZURE
-    else None,
+    (
+        os.path.join(DATASTORE_SYSROOT_AZURE, CARD_SUFFIX)
+        if DATASTORE_SYSROOT_AZURE
+        else None
+    ),
 )
 CARD_GSROOT = from_conf(
     "CARD_GSROOT",
@@ -246,6 +256,8 @@
 BATCH_CONTAINER_REGISTRY = from_conf(
     "BATCH_CONTAINER_REGISTRY", DEFAULT_CONTAINER_REGISTRY
 )
+# Secret containing credentials if using a private image repository
+BATCH_CONTAINER_IMAGE_CREDS_SECRET = from_conf("BATCH_CONTAINER_IMAGE_CREDS_SECRET")
 # Metadata service URL for AWS Batch
 SERVICE_INTERNAL_URL = from_conf("SERVICE_INTERNAL_URL", SERVICE_URL)

@@ -276,9 +288,11 @@
 # Amazon S3 path for storing the results of AWS Step Functions Distributed Map
 SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH = from_conf(
     "SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH",
-    os.path.join(DATASTORE_SYSROOT_S3, "sfn_distributed_map_output")
-    if DATASTORE_SYSROOT_S3
-    else None,
+    (
+        os.path.join(DATASTORE_SYSROOT_S3, "sfn_distributed_map_output")
+        if DATASTORE_SYSROOT_S3
+        else None
+    ),
 )
 ###
 # Kubernetes configuration
diff --git a/metaflow/plugins/aws/batch/batch.py b/metaflow/plugins/aws/batch/batch.py
index 16ce9a06ce..db1e9c336c 100644
--- a/metaflow/plugins/aws/batch/batch.py
+++ b/metaflow/plugins/aws/batch/batch.py
@@ -195,6 +195,7 @@ def create_job(
         ephemeral_storage=None,
         log_driver=None,
         log_options=None,
+        repo_creds_secret=None,
     ):
         job_name = self._job_name(
             attrs.get("metaflow.user"),
@@ -246,6 +247,7 @@
                 ephemeral_storage=ephemeral_storage,
                 log_driver=log_driver,
                 log_options=log_options,
+                repo_creds_secret=repo_creds_secret,
             )
             .task_id(attrs.get("metaflow.task_id"))
             .environment_variable("AWS_DEFAULT_REGION", self._client.region())
@@ -362,6 +364,7 @@ def launch_job(
         ephemeral_storage=None,
         log_driver=None,
         log_options=None,
+        repo_creds_secret=None,
     ):
         if queue is None:
             queue = next(self._client.active_job_queues(), None)
@@ -402,6 +405,7 @@
             ephemeral_storage=ephemeral_storage,
             log_driver=log_driver,
             log_options=log_options,
+            repo_creds_secret=repo_creds_secret,
         )
         self.num_parallel = num_parallel
         self.job = job.execute()
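Note that create_job/launch_job above only pass the value through; the secret itself must already exist in AWS Secrets Manager. A sketch of creating one with boto3, assuming the username/password JSON shape that AWS Batch (like ECS) expects for private registry authentication; all names and credentials are placeholders:

    import json

    import boto3

    # The resulting ARN is the value to use for repo_creds_secret /
    # METAFLOW_BATCH_CONTAINER_IMAGE_CREDS_SECRET.
    client = boto3.client("secretsmanager")
    response = client.create_secret(
        Name="docker-creds",  # placeholder secret name
        SecretString=json.dumps(
            {"username": "registry-user", "password": "registry-pass"}
        ),
    )
    print(response["ARN"])
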
diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py
index 7f12094900..8bf677b645 100644
--- a/metaflow/plugins/aws/batch/batch_cli.py
+++ b/metaflow/plugins/aws/batch/batch_cli.py
@@ -107,6 +107,10 @@ def kill(ctx, run_id, user, my_runs):
     "--image",
     help="Docker image requirement for AWS Batch. In name:version format.",
 )
+@click.option(
+    "--repo-creds-secret",
+    help="Secret containing credentials if using a private image repository",
+)
 @click.option("--iam-role", help="IAM role requirement for AWS Batch.")
 @click.option(
     "--execution-role",
@@ -189,6 +193,7 @@ def step(
     code_package_url,
     executable=None,
     image=None,
+    repo_creds_secret=None,
     iam_role=None,
     execution_role=None,
     cpu=None,
@@ -345,6 +350,7 @@ def _sync_metadata():
             log_driver=log_driver,
             log_options=log_options,
             num_parallel=num_parallel,
+            repo_creds_secret=repo_creds_secret,
         )
     except Exception:
         traceback.print_exc()
diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py
index 76f60eb65f..e3e3ebb5e9 100644
--- a/metaflow/plugins/aws/batch/batch_client.py
+++ b/metaflow/plugins/aws/batch/batch_client.py
@@ -161,6 +161,7 @@ def _register_job_definition(
         ephemeral_storage,
         log_driver,
         log_options,
+        repo_creds_secret,
     ):
         # identify platform from any compute environment associated with the
         # queue
@@ -198,6 +199,11 @@
             "propagateTags": True,
         }

+        if repo_creds_secret:
+            job_definition["containerProperties"]["repositoryCredentials"] = {
+                "credentialsParameter": repo_creds_secret
+            }
+
         log_options_dict = {}
         if log_options:
             if isinstance(log_options, str):
@@ -479,6 +485,7 @@ def job_def(
         ephemeral_storage,
         log_driver,
         log_options,
+        repo_creds_secret,
     ):
         self.payload["jobDefinition"] = self._register_job_definition(
             image,
@@ -501,6 +508,7 @@
             ephemeral_storage,
             log_driver,
             log_options,
+            repo_creds_secret,
         )
         return self

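For reference, when repo_creds_secret is set, the payload that _register_job_definition sends to the AWS Batch RegisterJobDefinition API gains a repositoryCredentials block. An abridged sketch of the resulting structure, with placeholder image and ARN; all other containerProperties are unchanged by this diff:

    # Abridged shape of the registered job definition (field names follow
    # the RegisterJobDefinition API).
    job_definition = {
        "type": "container",
        "containerProperties": {
            "image": "registry.example.com/team/model:1.0",  # placeholder
            "repositoryCredentials": {
                "credentialsParameter": (
                    "arn:aws:secretsmanager:us-east-1:111122223333:secret:docker-creds"
                )
            },
            # ... cpu, memory, jobRoleArn, etc. as before ...
        },
        "propagateTags": True,
    }
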
diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py
index bf84f6f5db..8af7c9b806 100644
--- a/metaflow/plugins/aws/batch/batch_decorator.py
+++ b/metaflow/plugins/aws/batch/batch_decorator.py
@@ -17,6 +17,7 @@
     BATCH_JOB_QUEUE,
     BATCH_CONTAINER_IMAGE,
     BATCH_CONTAINER_REGISTRY,
+    BATCH_CONTAINER_IMAGE_CREDS_SECRET,
     ECS_FARGATE_EXECUTION_ROLE,
     DATASTORE_LOCAL_DIR,
 )
@@ -51,6 +52,8 @@ class BatchDecorator(StepDecorator):
         Docker image to use when launching on AWS Batch. If not specified, and
         METAFLOW_BATCH_CONTAINER_IMAGE is specified, that image is used. If
         not, a default Docker image mapping to the current version of Python is used.
+    repo_creds_secret : str, optional, default None
+        Secret containing credentials if using a private image repository.
     queue : str, default METAFLOW_BATCH_JOB_QUEUE
         AWS Batch Job Queue to submit the job to.
     iam_role : str, default METAFLOW_ECS_S3_ACCESS_IAM_ROLE
@@ -105,6 +108,7 @@ class BatchDecorator(StepDecorator):
         "gpu": None,
         "memory": None,
         "image": None,
+        "repo_creds_secret": None,
         "queue": BATCH_JOB_QUEUE,
         "iam_role": ECS_S3_ACCESS_IAM_ROLE,
         "execution_role": ECS_FARGATE_EXECUTION_ROLE,
@@ -162,6 +166,12 @@ def __init__(self, attributes=None, statically_defined=False):
                     self.attributes["image"],
                 )

+        if not self.attributes["repo_creds_secret"]:
+            if BATCH_CONTAINER_IMAGE_CREDS_SECRET:
+                self.attributes[
+                    "repo_creds_secret"
+                ] = BATCH_CONTAINER_IMAGE_CREDS_SECRET
+
         # Alias trainium to inferentia and check that both are not in use.
         if (
             self.attributes["inferentia"] is not None
diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py
index aa703a072a..0d57a3c7c1 100644
--- a/metaflow/plugins/aws/step_functions/step_functions.py
+++ b/metaflow/plugins/aws/step_functions/step_functions.py
@@ -842,6 +842,7 @@ def _batch(self, node):
                 ephemeral_storage=resources["ephemeral_storage"],
                 log_driver=resources["log_driver"],
                 log_options=resources["log_options"],
+                repo_creds_secret=resources["repo_creds_secret"],
             )
             .attempts(total_retries + 1)
         )
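End to end, a step can now pull its image from a private registry either via the decorator attribute added above or via the METAFLOW_BATCH_CONTAINER_IMAGE_CREDS_SECRET fallback wired up in batch_decorator.py. A minimal sketch, with placeholder image and secret ARN:

    from metaflow import FlowSpec, batch, step


    class PrivateImageFlow(FlowSpec):
        # Placeholder image and ARN; omit repo_creds_secret to fall back to
        # METAFLOW_BATCH_CONTAINER_IMAGE_CREDS_SECRET from the config.
        @batch(
            image="registry.example.com/team/model:1.0",
            repo_creds_secret=(
                "arn:aws:secretsmanager:us-east-1:111122223333:secret:docker-creds"
            ),
        )
        @step
        def start(self):
            self.next(self.end)

        @step
        def end(self):
            pass


    if __name__ == "__main__":
        PrivateImageFlow()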