From 1ab661caf25b89cc46fa9acb4bca132fe8b89fd5 Mon Sep 17 00:00:00 2001 From: mccstan Date: Thu, 7 Mar 2024 17:35:13 +0100 Subject: [PATCH 01/10] Was able to submit jobs, with network configs, cpu, ram, accelerators, machine type --- .idea/.gitignore | 8 + .idea/dsub.iml | 12 + .../inspectionProfiles/profiles_settings.xml | 6 + .idea/misc.xml | 7 + .idea/modules.xml | 8 + .idea/vcs.xml | 6 + README.md | 2 +- dsub/providers/google_batch.py | 1201 +++++++++-------- dsub/providers/google_batch_operations.py | 318 +++-- 9 files changed, 848 insertions(+), 720 deletions(-) create mode 100644 .idea/.gitignore create mode 100644 .idea/dsub.iml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/dsub.iml b/.idea/dsub.iml new file mode 100644 index 0000000..78caf75 --- /dev/null +++ b/.idea/dsub.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..ae54fb9 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..fb6592c --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/README.md b/README.md index d89f924..8fb4770 100644 --- a/README.md +++ b/README.md @@ -724,7 +724,7 @@ The image below illustrates this: By default, `dsub` will use the [default Compute Engine service account](https://cloud.google.com/compute/docs/access/service-accounts#default_service_account) as the authorized service account on the VM instance. You can choose to specify -the email address of another service acount using `--service-account`. +the email address of another service account using `--service-account`. By default, `dsub` will grant the following access scopes to the service account: diff --git a/dsub/providers/google_batch.py b/dsub/providers/google_batch.py index ca4133e..853e7f2 100644 --- a/dsub/providers/google_batch.py +++ b/dsub/providers/google_batch.py @@ -28,17 +28,17 @@ from . import google_base from . import google_batch_operations from . import google_utils +from .google_batch_operations import build_compute_resource from ..lib import job_model from ..lib import param_util from ..lib import providers_util - # pylint: disable=g-import-not-at-top try: - from google.cloud import batch_v1 + from google.cloud import batch_v1 except ImportError: - # TODO: Remove conditional import when batch library is available - from . import batch_dummy as batch_v1 + # TODO: Remove conditional import when batch library is available + from . 
import batch_dummy as batch_v1 # pylint: enable=g-import-not-at-top _PROVIDER_NAME = 'google-batch' @@ -231,32 +231,32 @@ class GoogleBatchOperation(base.Task): - """Task wrapper around a Batch API Job object.""" + """Task wrapper around a Batch API Job object.""" - def __init__(self, operation_data: batch_v1.types.Job): - self._op = operation_data - self._job_descriptor = self._try_op_to_job_descriptor() + def __init__(self, operation_data: batch_v1.types.Job): + self._op = operation_data + self._job_descriptor = self._try_op_to_job_descriptor() - def raw_task_data(self): - return self._op + def raw_task_data(self): + return self._op - def _try_op_to_job_descriptor(self): - # The _META_YAML_REPR field in the 'prepare' action enables reconstructing - # the original job descriptor. - # TODO: Currently, we set the environment across all runnables - # We really only want the env for the prepare action (runnable) here. - env = google_batch_operations.get_environment(self._op) - if not env: - return + def _try_op_to_job_descriptor(self): + # The _META_YAML_REPR field in the 'prepare' action enables reconstructing + # the original job descriptor. + # TODO: Currently, we set the environment across all runnables + # We really only want the env for the prepare action (runnable) here. + env = google_batch_operations.get_environment(self._op) + if not env: + return - meta = env.get(google_utils.META_YAML_VARNAME) - if not meta: - return + meta = env.get(google_utils.META_YAML_VARNAME) + if not meta: + return - return job_model.JobDescriptor.from_yaml(ast.literal_eval(meta)) + return job_model.JobDescriptor.from_yaml(ast.literal_eval(meta)) - def get_field(self, field: str, default: str = None): - """Returns a value from the operation for a specific set of field names. + def get_field(self, field: str, default: str = None): + """Returns a value from the operation for a specific set of field names. This is the implementation of base.Task's abstract get_field method. See base.py get_field for more details. @@ -271,104 +271,104 @@ def get_field(self, field: str, default: str = None): Raises: ValueError: if the field label is not supported by the operation """ - value = None - if field == 'internal-id': - value = self._op.name - elif field == 'user-project': - if self._job_descriptor: - value = self._job_descriptor.job_metadata.get(field) - elif field in [ - 'job-id', - 'job-name', - 'task-id', - 'task-attempt', - 'user-id', - 'dsub-version', - ]: - value = google_batch_operations.get_label(self._op, field) - elif field == 'task-status': - value = self._operation_status() - elif field == 'logging': - if self._job_descriptor: - # The job_resources will contain the "--logging" value. - # The task_resources will contain the resolved logging path. - # Return the resolved logging path. 
- task_resources = self._job_descriptor.task_descriptors[0].task_resources - value = task_resources.logging_path - elif field in ['envs', 'labels']: - if self._job_descriptor: - items = providers_util.get_job_and_task_param( - self._job_descriptor.job_params, - self._job_descriptor.task_descriptors[0].task_params, - field, - ) - value = {item.name: item.value for item in items} - elif field in [ - 'inputs', - 'outputs', - 'input-recursives', - 'output-recursives', - ]: - if self._job_descriptor: - value = {} - items = providers_util.get_job_and_task_param( - self._job_descriptor.job_params, - self._job_descriptor.task_descriptors[0].task_params, - field, - ) - value.update({item.name: item.value for item in items}) - elif field == 'mounts': - if self._job_descriptor: - items = providers_util.get_job_and_task_param( - self._job_descriptor.job_params, - self._job_descriptor.task_descriptors[0].task_params, - field, - ) - value = {item.name: item.value for item in items} - elif field == 'provider': - return _PROVIDER_NAME - elif field == 'provider-attributes': - # TODO: This needs to return instance (VM) metadata - value = {} - elif field == 'events': - # TODO: This needs to return a list of events - value = [] - elif field == 'script-name': - if self._job_descriptor: - value = self._job_descriptor.job_metadata.get(field) - elif field == 'script': - value = self._try_op_to_script_body() - elif field == 'create-time' or field == 'start-time': - # TODO: Does Batch offer a start or end-time? - # Check http://shortn/_FPYmD1weUF - ds = google_batch_operations.get_create_time(self._op) - value = google_base.parse_rfc3339_utc_string(ds) - elif field == 'end-time' or field == 'last-update': - # TODO: Does Batch offer an end-time? - # Check http://shortn/_FPYmD1weUF - ds = google_batch_operations.get_update_time(self._op) - if ds: - value = google_base.parse_rfc3339_utc_string(ds) - elif field == 'status': - value = self._operation_status() - elif field == 'status-message': - value = self._operation_status_message() - elif field == 'status-detail': - value = self._operation_status_message() - else: - raise ValueError(f'Unsupported field: "{field}"') - - return value if value else default - - def _try_op_to_script_body(self): - # TODO: Currently, we set the environment across all runnables - # We really only want the env for the prepare action (runnable) here. - env = google_batch_operations.get_environment(self._op) - if env: - return ast.literal_eval(env.get(google_utils.SCRIPT_VARNAME)) - - def _operation_status(self): - """Returns the status of this operation. + value = None + if field == 'internal-id': + value = self._op.name + elif field == 'user-project': + if self._job_descriptor: + value = self._job_descriptor.job_metadata.get(field) + elif field in [ + 'job-id', + 'job-name', + 'task-id', + 'task-attempt', + 'user-id', + 'dsub-version', + ]: + value = google_batch_operations.get_label(self._op, field) + elif field == 'task-status': + value = self._operation_status() + elif field == 'logging': + if self._job_descriptor: + # The job_resources will contain the "--logging" value. + # The task_resources will contain the resolved logging path. + # Return the resolved logging path. 
+ task_resources = self._job_descriptor.task_descriptors[0].task_resources + value = task_resources.logging_path + elif field in ['envs', 'labels']: + if self._job_descriptor: + items = providers_util.get_job_and_task_param( + self._job_descriptor.job_params, + self._job_descriptor.task_descriptors[0].task_params, + field, + ) + value = {item.name: item.value for item in items} + elif field in [ + 'inputs', + 'outputs', + 'input-recursives', + 'output-recursives', + ]: + if self._job_descriptor: + value = {} + items = providers_util.get_job_and_task_param( + self._job_descriptor.job_params, + self._job_descriptor.task_descriptors[0].task_params, + field, + ) + value.update({item.name: item.value for item in items}) + elif field == 'mounts': + if self._job_descriptor: + items = providers_util.get_job_and_task_param( + self._job_descriptor.job_params, + self._job_descriptor.task_descriptors[0].task_params, + field, + ) + value = {item.name: item.value for item in items} + elif field == 'provider': + return _PROVIDER_NAME + elif field == 'provider-attributes': + # TODO: This needs to return instance (VM) metadata + value = {} + elif field == 'events': + # TODO: This needs to return a list of events + value = [] + elif field == 'script-name': + if self._job_descriptor: + value = self._job_descriptor.job_metadata.get(field) + elif field == 'script': + value = self._try_op_to_script_body() + elif field == 'create-time' or field == 'start-time': + # TODO: Does Batch offer a start or end-time? + # Check http://shortn/_FPYmD1weUF + ds = google_batch_operations.get_create_time(self._op) + value = google_base.parse_rfc3339_utc_string(ds) + elif field == 'end-time' or field == 'last-update': + # TODO: Does Batch offer an end-time? + # Check http://shortn/_FPYmD1weUF + ds = google_batch_operations.get_update_time(self._op) + if ds: + value = google_base.parse_rfc3339_utc_string(ds) + elif field == 'status': + value = self._operation_status() + elif field == 'status-message': + value = self._operation_status_message() + elif field == 'status-detail': + value = self._operation_status_message() + else: + raise ValueError(f'Unsupported field: "{field}"') + + return value if value else default + + def _try_op_to_script_body(self): + # TODO: Currently, we set the environment across all runnables + # We really only want the env for the prepare action (runnable) here. + env = google_batch_operations.get_environment(self._op) + if env: + return ast.literal_eval(env.get(google_utils.SCRIPT_VARNAME)) + + def _operation_status(self): + """Returns the status of this operation. Raises: ValueError: if the operation status cannot be determined. @@ -376,429 +376,450 @@ def _operation_status(self): Returns: A printable status string (RUNNING, SUCCESS, CANCELED or FAILURE). 
""" - if not google_batch_operations.is_done(self._op): - return 'RUNNING' - if google_batch_operations.is_success(self._op): - return 'SUCCESS' - if google_batch_operations.is_canceled(): - return 'CANCELED' - if google_batch_operations.is_failed(self._op): - return 'FAILURE' - - raise ValueError( - 'Status for operation {} could not be determined'.format( - self._op['name'] + if not google_batch_operations.is_done(self._op): + return 'RUNNING' + if google_batch_operations.is_success(self._op): + return 'SUCCESS' + if google_batch_operations.is_canceled(): + return 'CANCELED' + if google_batch_operations.is_failed(self._op): + return 'FAILURE' + + raise ValueError( + 'Status for operation {} could not be determined'.format( + self._op['name'] + ) ) - ) - def _operation_status_message(self): - # TODO: This is intended to grab as much detail as possible - # Currently, just grabbing the description field from the last status_event - status_events = google_batch_operations.get_status_events(self._op) - if status_events: - return status_events[-1].description + def _operation_status_message(self): + # TODO: This is intended to grab as much detail as possible + # Currently, just grabbing the description field from the last status_event + status_events = google_batch_operations.get_status_events(self._op) + if status_events: + return status_events[-1].description class GoogleBatchBatchHandler(object): - """Implement the HttpBatch interface to enable simple serial batches.""" + """Implement the HttpBatch interface to enable simple serial batches.""" - def __init__(self, callback): - self._cancel_list = [] - self._response_handler = callback + def __init__(self, callback): + self._cancel_list = [] + self._response_handler = callback - def add(self, cancel_fn, request_id): - self._cancel_list.append((request_id, cancel_fn)) + def add(self, cancel_fn, request_id): + self._cancel_list.append((request_id, cancel_fn)) - def execute(self): - for request_id, cancel_fn in self._cancel_list: - response = None - exception = None - try: - response = cancel_fn.result() - except: # pylint: disable=bare-except - exception = sys.exc_info()[1] + def execute(self): + for request_id, cancel_fn in self._cancel_list: + response = None + exception = None + try: + response = cancel_fn.result() + except: # pylint: disable=bare-except + exception = sys.exc_info()[1] - self._response_handler(request_id, response, exception) + self._response_handler(request_id, response, exception) class GoogleBatchJobProvider(google_utils.GoogleJobProviderBase): - """dsub provider implementation managing Jobs on Google Cloud.""" - - def __init__( - self, dry_run: bool, project: str, location: str, credentials=None - ): - self._dry_run = dry_run - self._location = location - self._project = project - - def _batch_handler_def(self): - return GoogleBatchBatchHandler - - def _operations_cancel_api_def(self): - return batch_v1.BatchServiceClient().delete_job - - def _get_create_time_filters(self, create_time_min, create_time_max): - # TODO: Currently, Batch API does not support filtering by create t. 
- return [] - - def _get_logging_env(self, logging_uri, user_project, include_filter_script): - """Returns the environment for actions that copy logging files.""" - if not logging_uri.endswith('.log'): - raise ValueError('Logging URI must end in ".log": {}'.format(logging_uri)) - - logging_prefix = logging_uri[: -len('.log')] - env = { - 'LOGGING_PATH': '{}.log'.format(logging_prefix), - 'STDOUT_PATH': '{}-stdout.log'.format(logging_prefix), - 'STDERR_PATH': '{}-stderr.log'.format(logging_prefix), - 'USER_PROJECT': user_project, - } - if include_filter_script: - env[_LOG_FILTER_VAR] = repr(_LOG_FILTER_PYTHON) - - return env - - def _create_batch_request( - self, - task_view: job_model.JobDescriptor, - job_id, - all_envs: List[batch_v1.types.Environment], - ): - job_metadata = task_view.job_metadata - job_params = task_view.job_params - job_resources = task_view.job_resources - task_metadata = task_view.task_descriptors[0].task_metadata - task_params = task_view.task_descriptors[0].task_params - task_resources = task_view.task_descriptors[0].task_resources - - # Set up VM-specific variables - datadisk_volume = google_batch_operations.build_volume( - disk=google_utils.DATA_DISK_NAME, path=_VOLUME_MOUNT_POINT - ) - - # Set up the task labels - # pylint: disable=g-complex-comprehension - labels = { - label.name: label.value if label.value else '' - for label in google_base.build_pipeline_labels( - job_metadata, task_metadata - ) - | job_params['labels'] - | task_params['labels'] - } - # pylint: enable=g-complex-comprehension - - # Set local variables for the core pipeline values - script = task_view.job_metadata['script'] - - # Track 0-based runnable indexes for cross-task awareness - user_action = 3 - - continuous_logging_cmd = _CONTINUOUS_LOGGING_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - gsutil_cp_fn=google_utils.GSUTIL_CP_FN, - log_filter_var=_LOG_FILTER_VAR, - log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, - python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, - logging_dir=_LOGGING_DIR, - log_file_path=_LOG_FILE_PATH, - log_cp=_LOG_CP.format( + """dsub provider implementation managing Jobs on Google Cloud.""" + + def __init__( + self, dry_run: bool, project: str, location: str, credentials=None + ): + self._dry_run = dry_run + self._location = location + self._project = project + + def _batch_handler_def(self): + return GoogleBatchBatchHandler + + def _operations_cancel_api_def(self): + return batch_v1.BatchServiceClient().delete_job + + def _get_create_time_filters(self, create_time_min, create_time_max): + # TODO: Currently, Batch API does not support filtering by create t. 
+ return [] + + def _get_logging_env(self, logging_uri, user_project, include_filter_script): + """Returns the environment for actions that copy logging files.""" + if not logging_uri.endswith('.log'): + raise ValueError('Logging URI must end in ".log": {}'.format(logging_uri)) + + logging_prefix = logging_uri[: -len('.log')] + env = { + 'LOGGING_PATH': '{}.log'.format(logging_prefix), + 'STDOUT_PATH': '{}-stdout.log'.format(logging_prefix), + 'STDERR_PATH': '{}-stderr.log'.format(logging_prefix), + 'USER_PROJECT': user_project, + } + if include_filter_script: + env[_LOG_FILTER_VAR] = repr(_LOG_FILTER_PYTHON) + + return env + + def _create_batch_request( + self, + task_view: job_model.JobDescriptor, + job_id, + all_envs: List[batch_v1.types.Environment], + ): + job_metadata = task_view.job_metadata + job_params = task_view.job_params + job_resources = task_view.job_resources + task_metadata = task_view.task_descriptors[0].task_metadata + task_params = task_view.task_descriptors[0].task_params + task_resources = task_view.task_descriptors[0].task_resources + + # Set up VM-specific variables + datadisk_volume = google_batch_operations.build_volume( + disk=google_utils.DATA_DISK_NAME, path=_VOLUME_MOUNT_POINT + ) + + # Set up the task labels + # pylint: disable=g-complex-comprehension + labels = { + label.name: label.value if label.value else '' + for label in google_base.build_pipeline_labels( + job_metadata, task_metadata + ) + | job_params['labels'] + | task_params['labels'] + } + # pylint: enable=g-complex-comprehension + + # Set local variables for the core pipeline values + script = task_view.job_metadata['script'] + + # Track 0-based runnable indexes for cross-task awareness + user_action = 3 + + continuous_logging_cmd = _CONTINUOUS_LOGGING_CMD.format( + log_msg_fn=google_utils.LOG_MSG_FN, + gsutil_cp_fn=google_utils.GSUTIL_CP_FN, + log_filter_var=_LOG_FILTER_VAR, log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, + python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, + logging_dir=_LOGGING_DIR, log_file_path=_LOG_FILE_PATH, - user_action=user_action, - ), - log_interval=job_resources.log_interval or '60s', - ) - - logging_cmd = _FINAL_LOGGING_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - gsutil_cp_fn=google_utils.GSUTIL_CP_FN, - log_filter_var=_LOG_FILTER_VAR, - log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, - python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, - logging_dir=_LOGGING_DIR, - log_file_path=_LOG_FILE_PATH, - log_cp=_LOG_CP.format( + log_cp=_LOG_CP.format( + log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, + log_file_path=_LOG_FILE_PATH, + user_action=user_action, + ), + log_interval=job_resources.log_interval or '60s', + ) + + logging_cmd = _FINAL_LOGGING_CMD.format( + log_msg_fn=google_utils.LOG_MSG_FN, + gsutil_cp_fn=google_utils.GSUTIL_CP_FN, + log_filter_var=_LOG_FILTER_VAR, log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, + python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, + logging_dir=_LOGGING_DIR, log_file_path=_LOG_FILE_PATH, - user_action=user_action, - ), - ) - - # Set up command and environments for the prepare, localization, user, - # and de-localization actions - script_path = os.path.join(_SCRIPT_DIR, script.name) - user_project = task_view.job_metadata['user-project'] or '' - - prepare_command = google_utils.PREPARE_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - mk_runtime_dirs=google_utils.make_runtime_dirs_command( - _SCRIPT_DIR, _TMP_DIR, _WORKING_DIR - ), - script_var=google_utils.SCRIPT_VARNAME, - 
python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, - script_path=script_path, - mk_io_dirs=google_utils.MK_IO_DIRS, - ) - # pylint: disable=line-too-long - - continuous_logging_env = google_batch_operations.build_environment( - self._get_logging_env( - task_resources.logging_path.uri, user_project, True - ) - ) - final_logging_env = google_batch_operations.build_environment( - self._get_logging_env( - task_resources.logging_path.uri, user_project, False - ) - ) - - # Build the list of runnables (aka actions) - runnables = [] - - runnables.append( - # logging - google_batch_operations.build_runnable( - run_in_background=True, - always_run=False, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=continuous_logging_env, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=['-c', continuous_logging_cmd], - ) - ) - - runnables.append( - # prepare - google_batch_operations.build_runnable( - run_in_background=False, - always_run=False, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=None, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=['-c', prepare_command], - ) - ) - - runnables.append( - # localization - google_batch_operations.build_runnable( - run_in_background=False, - always_run=False, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=None, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=[ - '-c', - google_utils.LOCALIZATION_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - recursive_cp_fn=google_utils.GSUTIL_RSYNC_FN, - cp_fn=google_utils.GSUTIL_CP_FN, - cp_loop=google_utils.LOCALIZATION_LOOP, - ), - ], - ) - ) - - runnables.append( - # user-command - google_batch_operations.build_runnable( - run_in_background=False, - always_run=False, - image_uri=job_resources.image, - environment=None, - entrypoint='/usr/bin/env', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=[ - 'bash', - '-c', - google_utils.USER_CMD.format( - tmp_dir=_TMP_DIR, - working_dir=_WORKING_DIR, - user_script=script_path, - ), - ], - ) - ) - - runnables.append( - # delocalization - google_batch_operations.build_runnable( - run_in_background=False, - always_run=False, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=None, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}:ro'], - commands=[ - '-c', - google_utils.LOCALIZATION_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - recursive_cp_fn=google_utils.GSUTIL_RSYNC_FN, - cp_fn=google_utils.GSUTIL_CP_FN, - cp_loop=google_utils.DELOCALIZATION_LOOP, - ), - ], - ) - ) - - runnables.append( - # final_logging - google_batch_operations.build_runnable( - run_in_background=False, - always_run=True, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=final_logging_env, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=['-c', logging_cmd], - ), - ) - - # Prepare the VM (resources) configuration. The InstancePolicy describes an - # instance type and resources attached to each VM. The AllocationPolicy - # describes when, where, and how compute resources should be allocated - # for the Job. 
- disk = google_batch_operations.build_persistent_disk( - size_gb=job_resources.disk_size, - disk_type=job_resources.disk_type or job_model.DEFAULT_DISK_TYPE, - ) - attached_disk = google_batch_operations.build_attached_disk( - disk=disk, device_name=google_utils.DATA_DISK_NAME - ) - instance_policy = google_batch_operations.build_instance_policy( - attached_disk - ) - ipt = google_batch_operations.build_instance_policy_or_template( - instance_policy - ) - allocation_policy = google_batch_operations.build_allocation_policy([ipt]) - logs_policy = google_batch_operations.build_logs_policy( - batch_v1.LogsPolicy.Destination.PATH, _BATCH_LOG_FILE_PATH - ) - - # Bring together the task definition(s) and build the Job request. - task_spec = google_batch_operations.build_task_spec( - runnables=runnables, volumes=[datadisk_volume] - ) - task_group = google_batch_operations.build_task_group( - task_spec, all_envs, task_count=len(all_envs), task_count_per_node=1 - ) - - job = google_batch_operations.build_job( - [task_group], allocation_policy, labels, logs_policy - ) - - job_request = batch_v1.CreateJobRequest( - parent=f'projects/{self._project}/locations/{self._location}', - job=job, - job_id=job_id, - ) - # pylint: enable=line-too-long - return job_request - - def _submit_batch_job(self, request) -> str: - client = batch_v1.BatchServiceClient() - job_response = client.create_job(request=request) - op = GoogleBatchOperation(job_response) - print(f'Provider internal-id (operation): {job_response.name}') - return op.get_field('task-id') - - def _create_env_for_task( - self, task_view: job_model.JobDescriptor - ) -> Dict[str, str]: - job_params = task_view.job_params - task_params = task_view.task_descriptors[0].task_params - - # Set local variables for the core pipeline values - script = task_view.job_metadata['script'] - user_project = task_view.job_metadata['user-project'] or '' - - envs = job_params['envs'] | task_params['envs'] - inputs = job_params['inputs'] | task_params['inputs'] - outputs = job_params['outputs'] | task_params['outputs'] - mounts = job_params['mounts'] - - prepare_env = self._get_prepare_env( - script, task_view, inputs, outputs, mounts, _DATA_MOUNT_POINT - ) - localization_env = self._get_localization_env( - inputs, user_project, _DATA_MOUNT_POINT - ) - user_environment = self._build_user_environment( - envs, inputs, outputs, mounts, _DATA_MOUNT_POINT - ) - delocalization_env = self._get_delocalization_env( - outputs, user_project, _DATA_MOUNT_POINT - ) - # This merges all the envs into one dict. Need to use this syntax because - # of python3.6. In python3.9 we'd prefer to use | operator. - all_env = { - **prepare_env, - **localization_env, - **user_environment, - **delocalization_env, - } - return all_env - - def submit_job( - self, - job_descriptor: job_model.JobDescriptor, - skip_if_output_present: bool, - ) -> Dict[str, any]: - # Validate task data and resources. - param_util.validate_submit_args_or_fail( - job_descriptor, - provider_name=_PROVIDER_NAME, - input_providers=_SUPPORTED_INPUT_PROVIDERS, - output_providers=_SUPPORTED_OUTPUT_PROVIDERS, - logging_providers=_SUPPORTED_LOGGING_PROVIDERS, - ) - - # Prepare and submit jobs. - launched_tasks = [] - requests = [] - job_id = job_descriptor.job_metadata['job-id'] - # Instead of creating one job per task, create one job with several tasks. - # We also need to create a list of environments per task. 
The length of this - # list determines how many tasks are in the job, and is specified in the - # TaskGroup's task_count field. - envs = [] - for task_view in job_model.task_view_generator(job_descriptor): - env = self._create_env_for_task(task_view) - envs.append(google_batch_operations.build_environment(env)) - - request = self._create_batch_request(job_descriptor, job_id, envs) - if self._dry_run: - requests.append(request) - else: - # task_id = client.create_job(request=request) - task_id = self._submit_batch_job(request) - launched_tasks.append(task_id) - # If this is a dry-run, emit all the pipeline request objects - if self._dry_run: - print( - json.dumps(requests, indent=2, sort_keys=True, separators=(',', ': ')) - ) - return { - 'job-id': job_id, - 'user-id': job_descriptor.job_metadata.get('user-id'), - 'task-id': [task_id for task_id in launched_tasks if task_id], - } - - def delete_jobs( - self, - user_ids, - job_ids, - task_ids, - labels, - create_time_min=None, - create_time_max=None, - ): - """Kills the operations associated with the specified job or job.task. + log_cp=_LOG_CP.format( + log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, + log_file_path=_LOG_FILE_PATH, + user_action=user_action, + ), + ) + + # Set up command and environments for the prepare, localization, user, + # and de-localization actions + script_path = os.path.join(_SCRIPT_DIR, script.name) + user_project = task_view.job_metadata['user-project'] or '' + + prepare_command = google_utils.PREPARE_CMD.format( + log_msg_fn=google_utils.LOG_MSG_FN, + mk_runtime_dirs=google_utils.make_runtime_dirs_command( + _SCRIPT_DIR, _TMP_DIR, _WORKING_DIR + ), + script_var=google_utils.SCRIPT_VARNAME, + python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, + script_path=script_path, + mk_io_dirs=google_utils.MK_IO_DIRS, + ) + # pylint: disable=line-too-long + + continuous_logging_env = google_batch_operations.build_environment( + self._get_logging_env( + task_resources.logging_path.uri, user_project, True + ) + ) + final_logging_env = google_batch_operations.build_environment( + self._get_logging_env( + task_resources.logging_path.uri, user_project, False + ) + ) + + # Build the list of runnables (aka actions) + runnables = [] + + runnables.append( + # logging + google_batch_operations.build_runnable( + run_in_background=True, + always_run=False, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=continuous_logging_env, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=['-c', continuous_logging_cmd], + ) + ) + + runnables.append( + # prepare + google_batch_operations.build_runnable( + run_in_background=False, + always_run=False, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=None, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=['-c', prepare_command], + ) + ) + + runnables.append( + # localization + google_batch_operations.build_runnable( + run_in_background=False, + always_run=False, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=None, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=[ + '-c', + google_utils.LOCALIZATION_CMD.format( + log_msg_fn=google_utils.LOG_MSG_FN, + recursive_cp_fn=google_utils.GSUTIL_RSYNC_FN, + cp_fn=google_utils.GSUTIL_CP_FN, + cp_loop=google_utils.LOCALIZATION_LOOP, + ), + ], + ) + ) + + runnables.append( + # user-command + google_batch_operations.build_runnable( + run_in_background=False, + always_run=False, + 
image_uri=job_resources.image, + environment=None, + entrypoint='/usr/bin/env', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=[ + 'bash', + '-c', + google_utils.USER_CMD.format( + tmp_dir=_TMP_DIR, + working_dir=_WORKING_DIR, + user_script=script_path, + ), + ], + ) + ) + + runnables.append( + # delocalization + google_batch_operations.build_runnable( + run_in_background=False, + always_run=False, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=None, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}:ro'], + commands=[ + '-c', + google_utils.LOCALIZATION_CMD.format( + log_msg_fn=google_utils.LOG_MSG_FN, + recursive_cp_fn=google_utils.GSUTIL_RSYNC_FN, + cp_fn=google_utils.GSUTIL_CP_FN, + cp_loop=google_utils.DELOCALIZATION_LOOP, + ), + ], + ) + ) + + runnables.append( + # final_logging + google_batch_operations.build_runnable( + run_in_background=False, + always_run=True, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=final_logging_env, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=['-c', logging_cmd], + ), + ) + + # Prepare the VM (resources) configuration. The InstancePolicy describes an + # instance type and resources attached to each VM. The AllocationPolicy + # describes when, where, and how compute resources should be allocated + # for the Job. + disk = google_batch_operations.build_persistent_disk( + size_gb=job_resources.disk_size, + disk_type=job_resources.disk_type or job_model.DEFAULT_DISK_TYPE, + ) + attached_disk = google_batch_operations.build_attached_disk( + disk=disk, device_name=google_utils.DATA_DISK_NAME + ) + instance_policy = google_batch_operations.build_instance_policy( + attached_disk + ) + ipt = google_batch_operations.build_instance_policy_or_template( + instance_policy + ) + + service_account = google_batch_operations.build_service_account( + service_account_email=job_resources.service_account) + + network_policy = google_batch_operations.build_network_policy( + network=job_resources.network, + subnetwork=job_resources.subnetwork, + no_external_ip_address=job_resources.use_private_address, + ) + + allocation_policy = google_batch_operations.build_allocation_policy( + ipts=[ipt], + service_account=service_account, + network_policy=network_policy + ) + + logs_policy = google_batch_operations.build_logs_policy( + batch_v1.LogsPolicy.Destination.PATH, _BATCH_LOG_FILE_PATH + ) + + compute_resource = build_compute_resource( + cpu_milli=job_resources.min_cores * 1000, + memory_mib=job_resources.min_ram * 1024, + boot_disk_mib=job_resources.boot_disk_size * 1024 + ) + + # Bring together the task definition(s) and build the Job request. 
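For reference, a minimal sketch of the kind of objects these new helpers assemble, built directly with the batch_v1 types; the project, machine type, network paths, and resource sizes below are placeholders for illustration, not values taken from this change (dsub derives them from job_resources):

    from google.cloud import batch_v1

    # Placeholder values throughout.
    allocation_policy = batch_v1.AllocationPolicy(
        instances=[
            batch_v1.AllocationPolicy.InstancePolicyOrTemplate(
                policy=batch_v1.AllocationPolicy.InstancePolicy(
                    machine_type='n1-standard-4',
                ),
            )
        ],
        service_account=batch_v1.ServiceAccount(
            email='my-sa@my-project.iam.gserviceaccount.com'
        ),
        network=batch_v1.AllocationPolicy.NetworkPolicy(
            network_interfaces=[
                batch_v1.AllocationPolicy.NetworkInterface(
                    network='projects/my-project/global/networks/default',
                    subnetwork='regions/us-central1/subnetworks/default',
                    no_external_ip_address=True,
                )
            ]
        ),
    )
    # Per-task resource request: 2 vCPUs, 8 GiB RAM, 30 GiB boot disk.
    compute_resource = batch_v1.ComputeResource(
        cpu_milli=2000,
        memory_mib=8192,
        boot_disk_mib=30720,
    )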
+ task_spec = google_batch_operations.build_task_spec( + runnables=runnables, volumes=[datadisk_volume], compute_resource=compute_resource + ) + task_group = google_batch_operations.build_task_group( + task_spec, all_envs, task_count=len(all_envs), task_count_per_node=1 + ) + + job = google_batch_operations.build_job( + [task_group], allocation_policy, labels, logs_policy + ) + + job_request = batch_v1.CreateJobRequest( + parent=f'projects/{self._project}/locations/{self._location}', + job=job, + job_id=job_id, + ) + # pylint: enable=line-too-long + return job_request + + def _submit_batch_job(self, request) -> str: + client = batch_v1.BatchServiceClient() + job_response = client.create_job(request=request) + op = GoogleBatchOperation(job_response) + print(f'Provider internal-id (operation): {job_response.name}') + return op.get_field('task-id') + + def _create_env_for_task( + self, task_view: job_model.JobDescriptor + ) -> Dict[str, str]: + job_params = task_view.job_params + task_params = task_view.task_descriptors[0].task_params + + # Set local variables for the core pipeline values + script = task_view.job_metadata['script'] + user_project = task_view.job_metadata['user-project'] or '' + + envs = job_params['envs'] | task_params['envs'] + inputs = job_params['inputs'] | task_params['inputs'] + outputs = job_params['outputs'] | task_params['outputs'] + mounts = job_params['mounts'] + + prepare_env = self._get_prepare_env( + script, task_view, inputs, outputs, mounts, _DATA_MOUNT_POINT + ) + localization_env = self._get_localization_env( + inputs, user_project, _DATA_MOUNT_POINT + ) + user_environment = self._build_user_environment( + envs, inputs, outputs, mounts, _DATA_MOUNT_POINT + ) + delocalization_env = self._get_delocalization_env( + outputs, user_project, _DATA_MOUNT_POINT + ) + # This merges all the envs into one dict. Need to use this syntax because + # of python3.6. In python3.9 we'd prefer to use | operator. + all_env = { + **prepare_env, + **localization_env, + **user_environment, + **delocalization_env, + } + return all_env + + def submit_job( + self, + job_descriptor: job_model.JobDescriptor, + skip_if_output_present: bool, + ) -> Dict[str, any]: + # Validate task data and resources. + param_util.validate_submit_args_or_fail( + job_descriptor, + provider_name=_PROVIDER_NAME, + input_providers=_SUPPORTED_INPUT_PROVIDERS, + output_providers=_SUPPORTED_OUTPUT_PROVIDERS, + logging_providers=_SUPPORTED_LOGGING_PROVIDERS, + ) + + # Prepare and submit jobs. + launched_tasks = [] + requests = [] + job_id = job_descriptor.job_metadata['job-id'] + # Instead of creating one job per task, create one job with several tasks. + # We also need to create a list of environments per task. The length of this + # list determines how many tasks are in the job, and is specified in the + # TaskGroup's task_count field. 
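As a minimal sketch of the one-job-many-tasks pattern described in the comment above (paths and counts are illustrative only), each task gets its own Environment object and task_count is simply the length of that list:

    from google.cloud import batch_v1

    # Three tasks in a single job, each seeing its own variables.
    task_environments = [
        batch_v1.Environment(variables={'INPUT_PATH': f'gs://my-bucket/shard-{i}.txt'})
        for i in range(3)
    ]
    task_group = batch_v1.TaskGroup(
        task_spec=batch_v1.TaskSpec(),  # runnables/volumes omitted for brevity
        task_environments=task_environments,
        task_count=len(task_environments),
        task_count_per_node=1,
    )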
+ envs = [] + for task_view in job_model.task_view_generator(job_descriptor): + env = self._create_env_for_task(task_view) + envs.append(google_batch_operations.build_environment(env)) + + request = self._create_batch_request(job_descriptor, job_id, envs) + if self._dry_run: + requests.append(request) + else: + # task_id = client.create_job(request=request) + task_id = self._submit_batch_job(request) + launched_tasks.append(task_id) + # If this is a dry-run, emit all the pipeline request objects + if self._dry_run: + print( + json.dumps(requests, indent=2, sort_keys=True, separators=(',', ': ')) + ) + return { + 'job-id': job_id, + 'user-id': job_descriptor.job_metadata.get('user-id'), + 'task-id': [task_id for task_id in launched_tasks if task_id], + } + + def delete_jobs( + self, + user_ids, + job_ids, + task_ids, + labels, + create_time_min=None, + create_time_max=None, + ): + """Kills the operations associated with the specified job or job.task. Args: user_ids: List of user ids who "own" the job(s) to cancel. @@ -813,64 +834,64 @@ def delete_jobs( Returns: A list of tasks canceled and a list of error messages. """ - # Look up the job(s) - tasks = list( - self.lookup_job_tasks( - {'RUNNING'}, - user_ids=user_ids, - job_ids=job_ids, - task_ids=task_ids, - labels=labels, - create_time_min=create_time_min, - create_time_max=create_time_max, - ) - ) - - print('Found %d tasks to delete.' % len(tasks)) - return google_base.cancel( - self._batch_handler_def(), self._operations_cancel_api_def(), tasks - ) - - def lookup_job_tasks( - self, - statuses: Set[str], - user_ids=None, - job_ids=None, - job_names=None, - task_ids=None, - task_attempts=None, - labels=None, - create_time_min=None, - create_time_max=None, - max_tasks=0, - page_size=0, - ): - client = batch_v1.BatchServiceClient() - # TODO: Batch API has no 'done' filter like lifesciences API. - # Need to figure out how to filter for jobs that are completed. - empty_statuses = set() - ops_filter = self._build_query_filter( - empty_statuses, - user_ids, - job_ids, - job_names, - task_ids, - task_attempts, - labels, - create_time_min, - create_time_max, - ) - # Initialize request argument(s) - request = batch_v1.ListJobsRequest( - parent=f'projects/{self._project}/locations/{self._location}', - filter=ops_filter, - ) - - # Make the request - response = client.list_jobs(request=request) - for page in response: - yield GoogleBatchOperation(page) - - def get_tasks_completion_messages(self, tasks): - # TODO: This needs to return a list of error messages for each task - pass + # Look up the job(s) + tasks = list( + self.lookup_job_tasks( + {'RUNNING'}, + user_ids=user_ids, + job_ids=job_ids, + task_ids=task_ids, + labels=labels, + create_time_min=create_time_min, + create_time_max=create_time_max, + ) + ) + + print('Found %d tasks to delete.' % len(tasks)) + return google_base.cancel( + self._batch_handler_def(), self._operations_cancel_api_def(), tasks + ) + + def lookup_job_tasks( + self, + statuses: Set[str], + user_ids=None, + job_ids=None, + job_names=None, + task_ids=None, + task_attempts=None, + labels=None, + create_time_min=None, + create_time_max=None, + max_tasks=0, + page_size=0, + ): + client = batch_v1.BatchServiceClient() + # TODO: Batch API has no 'done' filter like lifesciences API. + # Need to figure out how to filter for jobs that are completed. 
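Since the Batch API exposes no 'done' filter, one workable approach (a sketch with placeholder project and location, not part of this change) is to filter server-side by label and check the job state client-side:

    from google.cloud import batch_v1

    client = batch_v1.BatchServiceClient()
    request = batch_v1.ListJobsRequest(
        parent='projects/my-project/locations/us-central1',
        filter='labels."job-id" = "my-job"',
    )
    done_states = (
        batch_v1.JobStatus.State.SUCCEEDED,
        batch_v1.JobStatus.State.FAILED,
    )
    completed = [
        job for job in client.list_jobs(request=request)
        if job.status.state in done_states
    ]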
+ empty_statuses = set() + ops_filter = self._build_query_filter( + empty_statuses, + user_ids, + job_ids, + job_names, + task_ids, + task_attempts, + labels, + create_time_min, + create_time_max, + ) + # Initialize request argument(s) + request = batch_v1.ListJobsRequest( + parent=f'projects/{self._project}/locations/{self._location}', + filter=ops_filter, + ) + + # Make the request + response = client.list_jobs(request=request) + for page in response: + yield GoogleBatchOperation(page) + + def get_tasks_completion_messages(self, tasks): + # TODO: This needs to return a list of error messages for each task + pass diff --git a/dsub/providers/google_batch_operations.py b/dsub/providers/google_batch_operations.py index 01f92c0..890c322 100644 --- a/dsub/providers/google_batch_operations.py +++ b/dsub/providers/google_batch_operations.py @@ -12,129 +12,164 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utility routines for constructing a Google Batch API request.""" +import logging from typing import List, Optional, Dict +from google.cloud.batch_v1 import ServiceAccount, AllocationPolicy # pylint: disable=g-import-not-at-top try: - from google.cloud import batch_v1 + from google.cloud import batch_v1 except ImportError: - # TODO: Remove conditional import when batch library is available - from . import batch_dummy as batch_v1 + # TODO: Remove conditional import when batch library is available + from . import batch_dummy as batch_v1 + + # pylint: enable=g-import-not-at-top def label_filter(label_key: str, label_value: str) -> str: - """Return a valid label filter for operations.list().""" - return 'labels."{}" = "{}"'.format(label_key, label_value) + """Return a valid label filter for operations.list().""" + return 'labels."{}" = "{}"'.format(label_key, label_value) def get_label(op: batch_v1.types.Job, name: str) -> str: - """Return the value for the specified label.""" - return op.labels.get(name) + """Return the value for the specified label.""" + return op.labels.get(name) def get_environment(op: batch_v1.types.Job) -> Dict[str, str]: - # Currently Batch only supports task_groups of size 1 - task_group = op.task_groups[0] - env_dict = {} - for env in task_group.task_environments: - env_dict.update(env.variables) - return env_dict + # Currently Batch only supports task_groups of size 1 + task_group = op.task_groups[0] + env_dict = {} + for env in task_group.task_environments: + env_dict.update(env.variables) + return env_dict def is_done(op: batch_v1.types.Job) -> bool: - """Return whether the operation has been marked done.""" - return op.status.state in [ - batch_v1.types.job.JobStatus.State.SUCCEEDED, - batch_v1.types.job.JobStatus.State.FAILED, - ] + """Return whether the operation has been marked done.""" + return op.status.state in [ + batch_v1.types.job.JobStatus.State.SUCCEEDED, + batch_v1.types.job.JobStatus.State.FAILED, + ] def is_success(op: batch_v1.types.Job) -> bool: - """Return whether the operation has completed successfully.""" - return op.status.state == batch_v1.types.job.JobStatus.State.SUCCEEDED + """Return whether the operation has completed successfully.""" + return op.status.state == batch_v1.types.job.JobStatus.State.SUCCEEDED def is_canceled() -> bool: - """Return whether the operation was canceled by the user.""" - # TODO: Verify if the batch job has a canceled enum - return False + """Return whether the operation was canceled by the user.""" + # TODO: Verify if the batch job has a canceled enum + 
return False def is_failed(op: batch_v1.types.Job) -> bool: - """Return whether the operation has failed.""" - return op.status.state == batch_v1.types.job.JobStatus.State.FAILED + """Return whether the operation has failed.""" + return op.status.state == batch_v1.types.job.JobStatus.State.FAILED def _pad_timestamps(ts: str) -> str: - """Batch API removes trailing zeroes from the fractional part of seconds.""" - # ts looks like 2022-06-23T19:38:23.11506605Z - # Pad zeroes until the fractional part is 9 digits long - dt, fraction = ts.split('.') - fraction = fraction.rstrip('Z') - fraction = fraction.ljust(9, '0') - return f'{dt}.{fraction}Z' + """Batch API removes trailing zeroes from the fractional part of seconds.""" + # ts looks like 2022-06-23T19:38:23.11506605Z + # Pad zeroes until the fractional part is 9 digits long + dt, fraction = ts.split('.') + fraction = fraction.rstrip('Z') + fraction = fraction.ljust(9, '0') + return f'{dt}.{fraction}Z' def get_update_time(op: batch_v1.types.Job) -> Optional[str]: - """Return the update time string of the operation.""" - update_time = op.update_time - if update_time: - return _pad_timestamps(op.update_time.rfc3339()) - else: - return None + """Return the update time string of the operation.""" + update_time = op.update_time.ToDatetime() if op.update_time else None + if update_time: + return update_time.isoformat('T') + 'Z' # Representing the datetime object in rfc3339 format + else: + return None def get_create_time(op: batch_v1.types.Job) -> Optional[str]: - """Return the create time string of the operation.""" - create_time = op.create_time - if create_time: - return _pad_timestamps(op.create_time.rfc3339()) - else: - return None + """Return the create time string of the operation.""" + create_time = op.create_time.ToDatetime() if op.create_time else None + if create_time: + return create_time.isoformat('T') + 'Z' + else: + return None def get_status_events(op: batch_v1.types.Job): - return op.status.status_events + return op.status.status_events def build_job( - task_groups: List[batch_v1.types.TaskGroup], - allocation_policy: batch_v1.types.AllocationPolicy, - labels: Dict[str, str], - logs_policy: batch_v1.types.LogsPolicy, + task_groups: List[batch_v1.types.TaskGroup], + allocation_policy: batch_v1.types.AllocationPolicy, + labels: Dict[str, str], + logs_policy: batch_v1.types.LogsPolicy, ) -> batch_v1.types.Job: - job = batch_v1.Job() - job.task_groups = task_groups - job.allocation_policy = allocation_policy - job.labels = labels - job.logs_policy = logs_policy - return job + job = batch_v1.Job() + job.task_groups = task_groups + job.allocation_policy = allocation_policy + job.labels = labels + job.logs_policy = logs_policy + return job + + +def build_compute_resource(cpu_milli: int, memory_mib: int, boot_disk_mib: int) -> batch_v1.types.ComputeResource: + """Build a ComputeResource object for a Batch request. + + Args: + cpu_milli (int): Number of milliCPU units + memory_mib (int): Amount of memory in Mebibytes (MiB) + boot_disk_mib (int): The boot disk size in Mebibytes (MiB) + + Returns: + A ComputeResource object. 
+ """ + compute_resource = batch_v1.ComputeResource( + cpu_milli=cpu_milli, + memory_mib=memory_mib, + boot_disk_mib=boot_disk_mib + ) + return compute_resource def build_task_spec( - runnables: List[batch_v1.types.task.Runnable], - volumes: List[batch_v1.types.Volume], + runnables: List[batch_v1.types.task.Runnable], + volumes: List[batch_v1.types.Volume], + compute_resource: batch_v1.types.ComputeResource, ) -> batch_v1.types.TaskSpec: - task_spec = batch_v1.TaskSpec() - task_spec.runnables = runnables - task_spec.volumes = volumes - return task_spec + """Build a TaskSpec object for a Batch request. + + Args: + runnables (List[Runnable]): List of Runnable objects + volumes (List[Volume]): List of Volume objects + compute_resource (ComputeResource): The compute resources to use + + Returns: + A TaskSpec object. + """ + task_spec = batch_v1.TaskSpec() + task_spec.runnables = runnables + task_spec.volumes = volumes + task_spec.compute_resource = compute_resource + return task_spec def build_environment(env_vars: Dict[str, str]): - environment = batch_v1.Environment() - environment.variables = env_vars - return environment + environment = batch_v1.Environment() + environment.variables = env_vars + return environment def build_task_group( - task_spec: batch_v1.types.TaskSpec, - task_environments: List[batch_v1.types.Environment], - task_count: int, - task_count_per_node: int, + task_spec: batch_v1.types.TaskSpec, + task_environments: List[batch_v1.types.Environment], + task_count: int, + task_count_per_node: int, ) -> batch_v1.types.TaskGroup: - """Build a TaskGroup object for a Batch request. + """Build a TaskGroup object for a Batch request. Args: task_spec (TaskSpec): TaskSpec object @@ -145,35 +180,35 @@ def build_task_group( Returns: A TaskGroup object. """ - task_group = batch_v1.TaskGroup() - task_group.task_spec = task_spec - task_group.task_environments = task_environments - task_group.task_count = task_count - task_group.task_count_per_node = task_count_per_node - return task_group + task_group = batch_v1.TaskGroup() + task_group.task_spec = task_spec + task_group.task_environments = task_environments + task_group.task_count = task_count + task_group.task_count_per_node = task_count_per_node + return task_group def build_container( - image_uri: str, entrypoint: str, volumes: List[str], commands: List[str] + image_uri: str, entrypoint: str, volumes: List[str], commands: List[str] ) -> batch_v1.types.task.Runnable.Container: - container = batch_v1.types.task.Runnable.Container() - container.image_uri = image_uri - container.entrypoint = entrypoint - container.commands = commands - container.volumes = volumes - return container + container = batch_v1.types.task.Runnable.Container() + container.image_uri = image_uri + container.entrypoint = entrypoint + container.commands = commands + container.volumes = volumes + return container def build_runnable( - run_in_background: bool, - always_run: bool, - environment: batch_v1.types.Environment, - image_uri: str, - entrypoint: str, - volumes: List[str], - commands: List[str], + run_in_background: bool, + always_run: bool, + environment: batch_v1.types.Environment, + image_uri: str, + entrypoint: str, + volumes: List[str], + commands: List[str], ) -> batch_v1.types.task.Runnable: - """Build a Runnable object for a Batch request. + """Build a Runnable object for a Batch request. 
Args: run_in_background (bool): True for the action to run in the background @@ -188,17 +223,17 @@ def build_runnable( Returns: An object representing a Runnable """ - container = build_container(image_uri, entrypoint, volumes, commands) - runnable = batch_v1.Runnable() - runnable.container = container - runnable.background = run_in_background - runnable.always_run = always_run - runnable.environment = environment - return runnable + container = build_container(image_uri, entrypoint, volumes, commands) + runnable = batch_v1.Runnable() + runnable.container = container + runnable.background = run_in_background + runnable.always_run = always_run + runnable.environment = environment + return runnable def build_volume(disk: str, path: str) -> batch_v1.types.Volume: - """Build a Volume object for a Batch request. + """Build a Volume object for a Batch request. Args: disk (str): Name of disk to mount, as specified in the resources section. @@ -207,59 +242,84 @@ def build_volume(disk: str, path: str) -> batch_v1.types.Volume: Returns: An object representing a Mount. """ - volume = batch_v1.Volume() - volume.device_name = disk - volume.mount_path = path - return volume + volume = batch_v1.Volume() + volume.device_name = disk + volume.mount_path = path + return volume + + +def build_network_policy(network: str, subnetwork: str, + no_external_ip_address: bool) -> batch_v1.types.job.AllocationPolicy.NetworkPolicy: + network_polycy = AllocationPolicy.NetworkPolicy( + network_interfaces=[ + AllocationPolicy.NetworkInterface( + network=network, + subnetwork=subnetwork, + no_external_ip_address=no_external_ip_address, + ) + ] + ) + return network_polycy + + +def build_service_account(service_account_email: str) -> batch_v1.ServiceAccount: + service_account = ServiceAccount( + email=service_account_email + ) + return service_account def build_allocation_policy( - ipts: List[batch_v1.types.AllocationPolicy.InstancePolicyOrTemplate], + ipts: List[batch_v1.types.AllocationPolicy.InstancePolicyOrTemplate], + service_account: batch_v1.ServiceAccount, + network_policy: batch_v1.types.job.AllocationPolicy.NetworkPolicy, ) -> batch_v1.types.AllocationPolicy: - allocation_policy = batch_v1.AllocationPolicy() - allocation_policy.instances = ipts - return allocation_policy + allocation_policy = batch_v1.AllocationPolicy() + allocation_policy.instances = ipts + allocation_policy.service_account = service_account + allocation_policy.network = network_policy + return allocation_policy def build_instance_policy_or_template( - instance_policy: batch_v1.types.AllocationPolicy.InstancePolicy, + instance_policy: batch_v1.types.AllocationPolicy.InstancePolicy, ) -> batch_v1.types.AllocationPolicy.InstancePolicyOrTemplate: - ipt = batch_v1.AllocationPolicy.InstancePolicyOrTemplate() - ipt.policy = instance_policy - return ipt + ipt = batch_v1.AllocationPolicy.InstancePolicyOrTemplate() + ipt.policy = instance_policy + return ipt def build_logs_policy( - destination: batch_v1.types.LogsPolicy.Destination, logs_path: str + destination: batch_v1.types.LogsPolicy.Destination, logs_path: str ) -> batch_v1.types.LogsPolicy: - logs_policy = batch_v1.LogsPolicy() - logs_policy.destination = destination - logs_policy.logs_path = logs_path + logs_policy = batch_v1.LogsPolicy() + logs_policy.destination = destination + logs_policy.logs_path = logs_path - return logs_policy + return logs_policy def build_instance_policy( - disks: List[batch_v1.types.AllocationPolicy.AttachedDisk], + disks: 
List[batch_v1.types.AllocationPolicy.AttachedDisk], ) -> batch_v1.types.AllocationPolicy.InstancePolicy: - instance_policy = batch_v1.AllocationPolicy.InstancePolicy() - instance_policy.disks = [disks] - return instance_policy + instance_policy = batch_v1.AllocationPolicy.InstancePolicy() + instance_policy.disks = [disks] + return instance_policy def build_attached_disk( - disk: batch_v1.types.AllocationPolicy.Disk, device_name: str + disk: batch_v1.types.AllocationPolicy.Disk, device_name: str ) -> batch_v1.types.AllocationPolicy.AttachedDisk: - attached_disk = batch_v1.AllocationPolicy.AttachedDisk() - attached_disk.new_disk = disk - attached_disk.device_name = device_name - return attached_disk + attached_disk = batch_v1.AllocationPolicy.AttachedDisk() + attached_disk.new_disk = disk + attached_disk.device_name = device_name + return attached_disk def build_persistent_disk( - size_gb: int, disk_type: str + size_gb: int, disk_type: str ) -> batch_v1.types.AllocationPolicy.Disk: - disk = batch_v1.AllocationPolicy.Disk() - disk.type = disk_type - disk.size_gb = size_gb - return disk + disk = batch_v1.AllocationPolicy.Disk() + disk.type = disk_type + disk.size_gb = size_gb + return disk From 31879f80b553e54f363b972f68165ab637e2b36d Mon Sep 17 00:00:00 2001 From: mccstan Date: Thu, 7 Mar 2024 17:44:07 +0100 Subject: [PATCH 02/10] Remove unwanted idea from commit history --- .gitignore | 2 ++ .idea/.gitignore | 8 -------- .idea/dsub.iml | 12 ------------ .idea/inspectionProfiles/profiles_settings.xml | 6 ------ .idea/misc.xml | 7 ------- .idea/modules.xml | 8 -------- .idea/vcs.xml | 6 ------ 7 files changed, 2 insertions(+), 47 deletions(-) delete mode 100644 .idea/.gitignore delete mode 100644 .idea/dsub.iml delete mode 100644 .idea/inspectionProfiles/profiles_settings.xml delete mode 100644 .idea/misc.xml delete mode 100644 .idea/modules.xml delete mode 100644 .idea/vcs.xml diff --git a/.gitignore b/.gitignore index b46da6e..f75d3ac 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ *.py[cod] *$py.class +.idea + build/ dist/ dsub_libs/ diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 13566b8..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml -# Editor-based HTTP Client requests -/httpRequests/ -# Datasource local storage ignored files -/dataSources/ -/dataSources.local.xml diff --git a/.idea/dsub.iml b/.idea/dsub.iml deleted file mode 100644 index 78caf75..0000000 --- a/.idea/dsub.iml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - - - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2d..0000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index ae54fb9..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,7 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index fb6592c..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 35eb1dd..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file From 8f6add5607aa8d114eccf683c4de4c12efe0c2d2 Mon Sep 17 00:00:00 2001 From: mccstan Date: Thu, 7 Mar 2024 
17:44:23 +0100 Subject: [PATCH 03/10] Was able to submit jobs, with network configs, cpu, ram, accelerators, machine type --- dsub/providers/google_batch.py | 15 +++++++++++---- dsub/providers/google_batch_operations.py | 23 +++++++++++++++++++++-- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/dsub/providers/google_batch.py b/dsub/providers/google_batch.py index 853e7f2..8dcdbfa 100644 --- a/dsub/providers/google_batch.py +++ b/dsub/providers/google_batch.py @@ -22,13 +22,13 @@ import os import sys import textwrap -from typing import Dict, List, Set +from typing import Dict, List, Set, MutableSequence from . import base from . import google_base from . import google_batch_operations from . import google_utils -from .google_batch_operations import build_compute_resource +from .google_batch_operations import build_compute_resource, build_accelerators, build_instance_policy_or_template from ..lib import job_model from ..lib import param_util from ..lib import providers_util @@ -669,9 +669,16 @@ def _create_batch_request( attached_disk = google_batch_operations.build_attached_disk( disk=disk, device_name=google_utils.DATA_DISK_NAME ) + instance_policy = google_batch_operations.build_instance_policy( - attached_disk + disks=attached_disk, + machine_type=job_resources.machine_type, + accelerators=build_accelerators( + accelerator_type=job_resources.accelerator_type, + accelerator_count=job_resources.accelerator_count, + ) ) + ipt = google_batch_operations.build_instance_policy_or_template( instance_policy ) @@ -688,7 +695,7 @@ def _create_batch_request( allocation_policy = google_batch_operations.build_allocation_policy( ipts=[ipt], service_account=service_account, - network_policy=network_policy + network_policy=network_policy, ) logs_policy = google_batch_operations.build_logs_policy( diff --git a/dsub/providers/google_batch_operations.py b/dsub/providers/google_batch_operations.py index 890c322..5afdac0 100644 --- a/dsub/providers/google_batch_operations.py +++ b/dsub/providers/google_batch_operations.py @@ -13,7 +13,7 @@ # limitations under the License. 
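# A minimal sketch, with placeholder values, of how the instance-policy
# helpers extended in this patch fit together for a GPU VM. The machine type,
# accelerator type/count, disk size/type and device name below are
# illustrative, not dsub defaults.
from dsub.providers import google_batch_operations

disk = google_batch_operations.build_persistent_disk(
    size_gb=200, disk_type='pd-standard')
attached_disk = google_batch_operations.build_attached_disk(
    disk=disk, device_name='datadisk')
instance_policy = google_batch_operations.build_instance_policy(
    disks=attached_disk,
    machine_type='n1-standard-8',
    accelerators=google_batch_operations.build_accelerators(
        accelerator_type='nvidia-tesla-t4', accelerator_count=1))
ipt = google_batch_operations.build_instance_policy_or_template(instance_policy)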
"""Utility routines for constructing a Google Batch API request.""" import logging -from typing import List, Optional, Dict +from typing import List, Optional, Dict, MutableSequence from google.cloud.batch_v1 import ServiceAccount, AllocationPolicy # pylint: disable=g-import-not-at-top @@ -272,12 +272,13 @@ def build_service_account(service_account_email: str) -> batch_v1.ServiceAccount def build_allocation_policy( ipts: List[batch_v1.types.AllocationPolicy.InstancePolicyOrTemplate], service_account: batch_v1.ServiceAccount, - network_policy: batch_v1.types.job.AllocationPolicy.NetworkPolicy, + network_policy: batch_v1.types.job.AllocationPolicy.NetworkPolicy ) -> batch_v1.types.AllocationPolicy: allocation_policy = batch_v1.AllocationPolicy() allocation_policy.instances = ipts allocation_policy.service_account = service_account allocation_policy.network = network_policy + return allocation_policy @@ -301,9 +302,14 @@ def build_logs_policy( def build_instance_policy( disks: List[batch_v1.types.AllocationPolicy.AttachedDisk], + machine_type: str, + accelerators: MutableSequence[batch_v1.types.AllocationPolicy.Accelerator] ) -> batch_v1.types.AllocationPolicy.InstancePolicy: instance_policy = batch_v1.AllocationPolicy.InstancePolicy() instance_policy.disks = [disks] + instance_policy.machine_type = machine_type + instance_policy.accelerators = accelerators + return instance_policy @@ -323,3 +329,16 @@ def build_persistent_disk( disk.type = disk_type disk.size_gb = size_gb return disk + + +def build_accelerators( + accelerator_type, + accelerator_count +) -> MutableSequence[batch_v1.types.AllocationPolicy.Accelerator]: + accelerators = [] + accelerator = batch_v1.AllocationPolicy.Accelerator() + accelerator.count = accelerator_count + accelerator.type = accelerator_type + accelerators.append(accelerator) + + return accelerators From bc2cd53e74e1aad06d8ee1156e0b6e7dcde51d9d Mon Sep 17 00:00:00 2001 From: mccstan Date: Fri, 8 Mar 2024 13:54:56 +0100 Subject: [PATCH 04/10] Fixed accelerators not working, missing volume mount and install drivers was not configured --- dsub/_dsub_version.py | 2 +- dsub/providers/google_batch.py | 16 +++++++++++++--- dsub/providers/google_batch_operations.py | 3 ++- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/dsub/_dsub_version.py b/dsub/_dsub_version.py index 3feda53..9e7dad2 100644 --- a/dsub/_dsub_version.py +++ b/dsub/_dsub_version.py @@ -26,4 +26,4 @@ 0.1.3.dev0 -> 0.1.3 -> 0.1.4.dev0 -> ... """ -DSUB_VERSION = '0.4.10' +DSUB_VERSION = '0.4.11.dev0' diff --git a/dsub/providers/google_batch.py b/dsub/providers/google_batch.py index 8dcdbfa..c62bb44 100644 --- a/dsub/providers/google_batch.py +++ b/dsub/providers/google_batch.py @@ -28,7 +28,7 @@ from . import google_base from . import google_batch_operations from . 
import google_utils -from .google_batch_operations import build_compute_resource, build_accelerators, build_instance_policy_or_template +from .google_batch_operations import build_compute_resource, build_accelerators from ..lib import job_model from ..lib import param_util from ..lib import providers_util @@ -603,6 +603,14 @@ def _create_batch_request( ) ) + # user-command volumes + user_command_volumes = [f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'] + if job_resources.accelerator_type is not None: + user_command_volumes.extend([ + "/var/lib/nvidia/lib64:/usr/local/nvidia/lib64", + "/var/lib/nvidia/bin:/usr/local/nvidia/bin" + ]) + runnables.append( # user-command google_batch_operations.build_runnable( @@ -611,7 +619,7 @@ def _create_batch_request( image_uri=job_resources.image, environment=None, entrypoint='/usr/bin/env', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + volumes=user_command_volumes, commands=[ 'bash', '-c', @@ -680,7 +688,9 @@ def _create_batch_request( ) ipt = google_batch_operations.build_instance_policy_or_template( - instance_policy + instance_policy=instance_policy, + install_gpu_drivers=True if job_resources.accelerator_type is not None else False + ) service_account = google_batch_operations.build_service_account( diff --git a/dsub/providers/google_batch_operations.py b/dsub/providers/google_batch_operations.py index 5afdac0..2b26906 100644 --- a/dsub/providers/google_batch_operations.py +++ b/dsub/providers/google_batch_operations.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utility routines for constructing a Google Batch API request.""" -import logging from typing import List, Optional, Dict, MutableSequence from google.cloud.batch_v1 import ServiceAccount, AllocationPolicy @@ -284,9 +283,11 @@ def build_allocation_policy( def build_instance_policy_or_template( instance_policy: batch_v1.types.AllocationPolicy.InstancePolicy, + install_gpu_drivers: bool ) -> batch_v1.types.AllocationPolicy.InstancePolicyOrTemplate: ipt = batch_v1.AllocationPolicy.InstancePolicyOrTemplate() ipt.policy = instance_policy + ipt.install_gpu_drivers = install_gpu_drivers return ipt From 35922546a18a95aac20188563e9c66713e5e1a29 Mon Sep 17 00:00:00 2001 From: mccstan Date: Fri, 8 Mar 2024 15:44:30 +0100 Subject: [PATCH 05/10] Fix issue with accelerators object --- dsub/providers/google_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dsub/providers/google_batch.py b/dsub/providers/google_batch.py index c62bb44..91ad6e6 100644 --- a/dsub/providers/google_batch.py +++ b/dsub/providers/google_batch.py @@ -684,7 +684,7 @@ def _create_batch_request( accelerators=build_accelerators( accelerator_type=job_resources.accelerator_type, accelerator_count=job_resources.accelerator_count, - ) + ) if job_resources.accelerator_type is not None else None ) ipt = google_batch_operations.build_instance_policy_or_template( From 199f9cb5b816ea8fc18d4d8c3345d3c08d18c2d3 Mon Sep 17 00:00:00 2001 From: mccstan Date: Tue, 12 Mar 2024 11:31:20 +0100 Subject: [PATCH 06/10] Set 2 space indent for this project --- .gitignore | 1 + dsub/providers/google_batch.py | 1322 ++++++++++----------- dsub/providers/google_batch_operations.py | 426 +++---- 3 files changed, 875 insertions(+), 874 deletions(-) diff --git a/.gitignore b/.gitignore index f75d3ac..73ec146 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *$py.class .idea +.editorconfig build/ dist/ diff --git 
a/dsub/providers/google_batch.py b/dsub/providers/google_batch.py index 91ad6e6..e39c64e 100644 --- a/dsub/providers/google_batch.py +++ b/dsub/providers/google_batch.py @@ -35,10 +35,10 @@ # pylint: disable=g-import-not-at-top try: - from google.cloud import batch_v1 + from google.cloud import batch_v1 except ImportError: - # TODO: Remove conditional import when batch library is available - from . import batch_dummy as batch_v1 + # TODO: Remove conditional import when batch library is available + from . import batch_dummy as batch_v1 # pylint: enable=g-import-not-at-top _PROVIDER_NAME = 'google-batch' @@ -231,684 +231,684 @@ class GoogleBatchOperation(base.Task): - """Task wrapper around a Batch API Job object.""" - - def __init__(self, operation_data: batch_v1.types.Job): - self._op = operation_data - self._job_descriptor = self._try_op_to_job_descriptor() - - def raw_task_data(self): - return self._op - - def _try_op_to_job_descriptor(self): - # The _META_YAML_REPR field in the 'prepare' action enables reconstructing - # the original job descriptor. - # TODO: Currently, we set the environment across all runnables - # We really only want the env for the prepare action (runnable) here. - env = google_batch_operations.get_environment(self._op) - if not env: - return - - meta = env.get(google_utils.META_YAML_VARNAME) - if not meta: - return - - return job_model.JobDescriptor.from_yaml(ast.literal_eval(meta)) - - def get_field(self, field: str, default: str = None): - """Returns a value from the operation for a specific set of field names. - - This is the implementation of base.Task's abstract get_field method. See - base.py get_field for more details. - - Args: - field: a dsub-specific job metadata key - default: default value to return if field does not exist or is empty. - - Returns: - A text string for the field or a list for 'inputs'. - - Raises: - ValueError: if the field label is not supported by the operation - """ - value = None - if field == 'internal-id': - value = self._op.name - elif field == 'user-project': - if self._job_descriptor: - value = self._job_descriptor.job_metadata.get(field) - elif field in [ - 'job-id', - 'job-name', - 'task-id', - 'task-attempt', - 'user-id', - 'dsub-version', - ]: - value = google_batch_operations.get_label(self._op, field) - elif field == 'task-status': - value = self._operation_status() - elif field == 'logging': - if self._job_descriptor: - # The job_resources will contain the "--logging" value. - # The task_resources will contain the resolved logging path. - # Return the resolved logging path. 
- task_resources = self._job_descriptor.task_descriptors[0].task_resources - value = task_resources.logging_path - elif field in ['envs', 'labels']: - if self._job_descriptor: - items = providers_util.get_job_and_task_param( - self._job_descriptor.job_params, - self._job_descriptor.task_descriptors[0].task_params, - field, - ) - value = {item.name: item.value for item in items} - elif field in [ - 'inputs', - 'outputs', - 'input-recursives', - 'output-recursives', - ]: - if self._job_descriptor: - value = {} - items = providers_util.get_job_and_task_param( - self._job_descriptor.job_params, - self._job_descriptor.task_descriptors[0].task_params, - field, - ) - value.update({item.name: item.value for item in items}) - elif field == 'mounts': - if self._job_descriptor: - items = providers_util.get_job_and_task_param( - self._job_descriptor.job_params, - self._job_descriptor.task_descriptors[0].task_params, - field, - ) - value = {item.name: item.value for item in items} - elif field == 'provider': - return _PROVIDER_NAME - elif field == 'provider-attributes': - # TODO: This needs to return instance (VM) metadata - value = {} - elif field == 'events': - # TODO: This needs to return a list of events - value = [] - elif field == 'script-name': - if self._job_descriptor: - value = self._job_descriptor.job_metadata.get(field) - elif field == 'script': - value = self._try_op_to_script_body() - elif field == 'create-time' or field == 'start-time': - # TODO: Does Batch offer a start or end-time? - # Check http://shortn/_FPYmD1weUF - ds = google_batch_operations.get_create_time(self._op) - value = google_base.parse_rfc3339_utc_string(ds) - elif field == 'end-time' or field == 'last-update': - # TODO: Does Batch offer an end-time? - # Check http://shortn/_FPYmD1weUF - ds = google_batch_operations.get_update_time(self._op) - if ds: - value = google_base.parse_rfc3339_utc_string(ds) - elif field == 'status': - value = self._operation_status() - elif field == 'status-message': - value = self._operation_status_message() - elif field == 'status-detail': - value = self._operation_status_message() - else: - raise ValueError(f'Unsupported field: "{field}"') - - return value if value else default - - def _try_op_to_script_body(self): - # TODO: Currently, we set the environment across all runnables - # We really only want the env for the prepare action (runnable) here. - env = google_batch_operations.get_environment(self._op) - if env: - return ast.literal_eval(env.get(google_utils.SCRIPT_VARNAME)) - - def _operation_status(self): - """Returns the status of this operation. - - Raises: - ValueError: if the operation status cannot be determined. - - Returns: - A printable status string (RUNNING, SUCCESS, CANCELED or FAILURE). 
- """ - if not google_batch_operations.is_done(self._op): - return 'RUNNING' - if google_batch_operations.is_success(self._op): - return 'SUCCESS' - if google_batch_operations.is_canceled(): - return 'CANCELED' - if google_batch_operations.is_failed(self._op): - return 'FAILURE' - - raise ValueError( - 'Status for operation {} could not be determined'.format( - self._op['name'] - ) - ) - - def _operation_status_message(self): - # TODO: This is intended to grab as much detail as possible - # Currently, just grabbing the description field from the last status_event - status_events = google_batch_operations.get_status_events(self._op) - if status_events: - return status_events[-1].description - - -class GoogleBatchBatchHandler(object): - """Implement the HttpBatch interface to enable simple serial batches.""" - - def __init__(self, callback): - self._cancel_list = [] - self._response_handler = callback - - def add(self, cancel_fn, request_id): - self._cancel_list.append((request_id, cancel_fn)) - - def execute(self): - for request_id, cancel_fn in self._cancel_list: - response = None - exception = None - try: - response = cancel_fn.result() - except: # pylint: disable=bare-except - exception = sys.exc_info()[1] + """Task wrapper around a Batch API Job object.""" - self._response_handler(request_id, response, exception) + def __init__(self, operation_data: batch_v1.types.Job): + self._op = operation_data + self._job_descriptor = self._try_op_to_job_descriptor() + def raw_task_data(self): + return self._op -class GoogleBatchJobProvider(google_utils.GoogleJobProviderBase): - """dsub provider implementation managing Jobs on Google Cloud.""" - - def __init__( - self, dry_run: bool, project: str, location: str, credentials=None - ): - self._dry_run = dry_run - self._location = location - self._project = project - - def _batch_handler_def(self): - return GoogleBatchBatchHandler - - def _operations_cancel_api_def(self): - return batch_v1.BatchServiceClient().delete_job - - def _get_create_time_filters(self, create_time_min, create_time_max): - # TODO: Currently, Batch API does not support filtering by create t. 
- return [] - - def _get_logging_env(self, logging_uri, user_project, include_filter_script): - """Returns the environment for actions that copy logging files.""" - if not logging_uri.endswith('.log'): - raise ValueError('Logging URI must end in ".log": {}'.format(logging_uri)) - - logging_prefix = logging_uri[: -len('.log')] - env = { - 'LOGGING_PATH': '{}.log'.format(logging_prefix), - 'STDOUT_PATH': '{}-stdout.log'.format(logging_prefix), - 'STDERR_PATH': '{}-stderr.log'.format(logging_prefix), - 'USER_PROJECT': user_project, - } - if include_filter_script: - env[_LOG_FILTER_VAR] = repr(_LOG_FILTER_PYTHON) - - return env - - def _create_batch_request( - self, - task_view: job_model.JobDescriptor, - job_id, - all_envs: List[batch_v1.types.Environment], - ): - job_metadata = task_view.job_metadata - job_params = task_view.job_params - job_resources = task_view.job_resources - task_metadata = task_view.task_descriptors[0].task_metadata - task_params = task_view.task_descriptors[0].task_params - task_resources = task_view.task_descriptors[0].task_resources - - # Set up VM-specific variables - datadisk_volume = google_batch_operations.build_volume( - disk=google_utils.DATA_DISK_NAME, path=_VOLUME_MOUNT_POINT - ) - - # Set up the task labels - # pylint: disable=g-complex-comprehension - labels = { - label.name: label.value if label.value else '' - for label in google_base.build_pipeline_labels( - job_metadata, task_metadata - ) - | job_params['labels'] - | task_params['labels'] - } - # pylint: enable=g-complex-comprehension - - # Set local variables for the core pipeline values - script = task_view.job_metadata['script'] - - # Track 0-based runnable indexes for cross-task awareness - user_action = 3 - - continuous_logging_cmd = _CONTINUOUS_LOGGING_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - gsutil_cp_fn=google_utils.GSUTIL_CP_FN, - log_filter_var=_LOG_FILTER_VAR, - log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, - python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, - logging_dir=_LOGGING_DIR, - log_file_path=_LOG_FILE_PATH, - log_cp=_LOG_CP.format( - log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, - log_file_path=_LOG_FILE_PATH, - user_action=user_action, - ), - log_interval=job_resources.log_interval or '60s', - ) - - logging_cmd = _FINAL_LOGGING_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - gsutil_cp_fn=google_utils.GSUTIL_CP_FN, - log_filter_var=_LOG_FILTER_VAR, - log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, - python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, - logging_dir=_LOGGING_DIR, - log_file_path=_LOG_FILE_PATH, - log_cp=_LOG_CP.format( - log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, - log_file_path=_LOG_FILE_PATH, - user_action=user_action, - ), - ) + def _try_op_to_job_descriptor(self): + # The _META_YAML_REPR field in the 'prepare' action enables reconstructing + # the original job descriptor. + # TODO: Currently, we set the environment across all runnables + # We really only want the env for the prepare action (runnable) here. 
+ env = google_batch_operations.get_environment(self._op) + if not env: + return - # Set up command and environments for the prepare, localization, user, - # and de-localization actions - script_path = os.path.join(_SCRIPT_DIR, script.name) - user_project = task_view.job_metadata['user-project'] or '' + meta = env.get(google_utils.META_YAML_VARNAME) + if not meta: + return - prepare_command = google_utils.PREPARE_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - mk_runtime_dirs=google_utils.make_runtime_dirs_command( - _SCRIPT_DIR, _TMP_DIR, _WORKING_DIR - ), - script_var=google_utils.SCRIPT_VARNAME, - python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, - script_path=script_path, - mk_io_dirs=google_utils.MK_IO_DIRS, - ) - # pylint: disable=line-too-long - - continuous_logging_env = google_batch_operations.build_environment( - self._get_logging_env( - task_resources.logging_path.uri, user_project, True - ) - ) - final_logging_env = google_batch_operations.build_environment( - self._get_logging_env( - task_resources.logging_path.uri, user_project, False - ) - ) - - # Build the list of runnables (aka actions) - runnables = [] - - runnables.append( - # logging - google_batch_operations.build_runnable( - run_in_background=True, - always_run=False, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=continuous_logging_env, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=['-c', continuous_logging_cmd], - ) - ) - - runnables.append( - # prepare - google_batch_operations.build_runnable( - run_in_background=False, - always_run=False, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=None, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=['-c', prepare_command], - ) - ) - - runnables.append( - # localization - google_batch_operations.build_runnable( - run_in_background=False, - always_run=False, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=None, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=[ - '-c', - google_utils.LOCALIZATION_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - recursive_cp_fn=google_utils.GSUTIL_RSYNC_FN, - cp_fn=google_utils.GSUTIL_CP_FN, - cp_loop=google_utils.LOCALIZATION_LOOP, - ), - ], - ) - ) - - # user-command volumes - user_command_volumes = [f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'] - if job_resources.accelerator_type is not None: - user_command_volumes.extend([ - "/var/lib/nvidia/lib64:/usr/local/nvidia/lib64", - "/var/lib/nvidia/bin:/usr/local/nvidia/bin" - ]) - - runnables.append( - # user-command - google_batch_operations.build_runnable( - run_in_background=False, - always_run=False, - image_uri=job_resources.image, - environment=None, - entrypoint='/usr/bin/env', - volumes=user_command_volumes, - commands=[ - 'bash', - '-c', - google_utils.USER_CMD.format( - tmp_dir=_TMP_DIR, - working_dir=_WORKING_DIR, - user_script=script_path, - ), - ], - ) - ) - - runnables.append( - # delocalization - google_batch_operations.build_runnable( - run_in_background=False, - always_run=False, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=None, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}:ro'], - commands=[ - '-c', - google_utils.LOCALIZATION_CMD.format( - log_msg_fn=google_utils.LOG_MSG_FN, - recursive_cp_fn=google_utils.GSUTIL_RSYNC_FN, - cp_fn=google_utils.GSUTIL_CP_FN, - cp_loop=google_utils.DELOCALIZATION_LOOP, - ), - ], - ) - ) + return 
job_model.JobDescriptor.from_yaml(ast.literal_eval(meta)) - runnables.append( - # final_logging - google_batch_operations.build_runnable( - run_in_background=False, - always_run=True, - image_uri=google_utils.CLOUD_SDK_IMAGE, - environment=final_logging_env, - entrypoint='/bin/bash', - volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], - commands=['-c', logging_cmd], - ), - ) + def get_field(self, field: str, default: str = None): + """Returns a value from the operation for a specific set of field names. - # Prepare the VM (resources) configuration. The InstancePolicy describes an - # instance type and resources attached to each VM. The AllocationPolicy - # describes when, where, and how compute resources should be allocated - # for the Job. - disk = google_batch_operations.build_persistent_disk( - size_gb=job_resources.disk_size, - disk_type=job_resources.disk_type or job_model.DEFAULT_DISK_TYPE, - ) - attached_disk = google_batch_operations.build_attached_disk( - disk=disk, device_name=google_utils.DATA_DISK_NAME - ) +This is the implementation of base.Task's abstract get_field method. See +base.py get_field for more details. - instance_policy = google_batch_operations.build_instance_policy( - disks=attached_disk, - machine_type=job_resources.machine_type, - accelerators=build_accelerators( - accelerator_type=job_resources.accelerator_type, - accelerator_count=job_resources.accelerator_count, - ) if job_resources.accelerator_type is not None else None - ) +Args: + field: a dsub-specific job metadata key + default: default value to return if field does not exist or is empty. - ipt = google_batch_operations.build_instance_policy_or_template( - instance_policy=instance_policy, - install_gpu_drivers=True if job_resources.accelerator_type is not None else False +Returns: + A text string for the field or a list for 'inputs'. - ) +Raises: + ValueError: if the field label is not supported by the operation +""" + value = None + if field == 'internal-id': + value = self._op.name + elif field == 'user-project': + if self._job_descriptor: + value = self._job_descriptor.job_metadata.get(field) + elif field in [ + 'job-id', + 'job-name', + 'task-id', + 'task-attempt', + 'user-id', + 'dsub-version', + ]: + value = google_batch_operations.get_label(self._op, field) + elif field == 'task-status': + value = self._operation_status() + elif field == 'logging': + if self._job_descriptor: + # The job_resources will contain the "--logging" value. + # The task_resources will contain the resolved logging path. + # Return the resolved logging path. 
+ task_resources = self._job_descriptor.task_descriptors[0].task_resources + value = task_resources.logging_path + elif field in ['envs', 'labels']: + if self._job_descriptor: + items = providers_util.get_job_and_task_param( + self._job_descriptor.job_params, + self._job_descriptor.task_descriptors[0].task_params, + field, + ) + value = {item.name: item.value for item in items} + elif field in [ + 'inputs', + 'outputs', + 'input-recursives', + 'output-recursives', + ]: + if self._job_descriptor: + value = {} + items = providers_util.get_job_and_task_param( + self._job_descriptor.job_params, + self._job_descriptor.task_descriptors[0].task_params, + field, + ) + value.update({item.name: item.value for item in items}) + elif field == 'mounts': + if self._job_descriptor: + items = providers_util.get_job_and_task_param( + self._job_descriptor.job_params, + self._job_descriptor.task_descriptors[0].task_params, + field, + ) + value = {item.name: item.value for item in items} + elif field == 'provider': + return _PROVIDER_NAME + elif field == 'provider-attributes': + # TODO: This needs to return instance (VM) metadata + value = {} + elif field == 'events': + # TODO: This needs to return a list of events + value = [] + elif field == 'script-name': + if self._job_descriptor: + value = self._job_descriptor.job_metadata.get(field) + elif field == 'script': + value = self._try_op_to_script_body() + elif field == 'create-time' or field == 'start-time': + # TODO: Does Batch offer a start or end-time? + # Check http://shortn/_FPYmD1weUF + ds = google_batch_operations.get_create_time(self._op) + value = google_base.parse_rfc3339_utc_string(ds) + elif field == 'end-time' or field == 'last-update': + # TODO: Does Batch offer an end-time? + # Check http://shortn/_FPYmD1weUF + ds = google_batch_operations.get_update_time(self._op) + if ds: + value = google_base.parse_rfc3339_utc_string(ds) + elif field == 'status': + value = self._operation_status() + elif field == 'status-message': + value = self._operation_status_message() + elif field == 'status-detail': + value = self._operation_status_message() + else: + raise ValueError(f'Unsupported field: "{field}"') - service_account = google_batch_operations.build_service_account( - service_account_email=job_resources.service_account) + return value if value else default - network_policy = google_batch_operations.build_network_policy( - network=job_resources.network, - subnetwork=job_resources.subnetwork, - no_external_ip_address=job_resources.use_private_address, - ) + def _try_op_to_script_body(self): + # TODO: Currently, we set the environment across all runnables + # We really only want the env for the prepare action (runnable) here. + env = google_batch_operations.get_environment(self._op) + if env: + return ast.literal_eval(env.get(google_utils.SCRIPT_VARNAME)) - allocation_policy = google_batch_operations.build_allocation_policy( - ipts=[ipt], - service_account=service_account, - network_policy=network_policy, - ) + def _operation_status(self): + """Returns the status of this operation. - logs_policy = google_batch_operations.build_logs_policy( - batch_v1.LogsPolicy.Destination.PATH, _BATCH_LOG_FILE_PATH - ) +Raises: + ValueError: if the operation status cannot be determined. - compute_resource = build_compute_resource( - cpu_milli=job_resources.min_cores * 1000, - memory_mib=job_resources.min_ram * 1024, - boot_disk_mib=job_resources.boot_disk_size * 1024 - ) +Returns: + A printable status string (RUNNING, SUCCESS, CANCELED or FAILURE). 
+""" + if not google_batch_operations.is_done(self._op): + return 'RUNNING' + if google_batch_operations.is_success(self._op): + return 'SUCCESS' + if google_batch_operations.is_canceled(): + return 'CANCELED' + if google_batch_operations.is_failed(self._op): + return 'FAILURE' + + raise ValueError( + 'Status for operation {} could not be determined'.format( + self._op['name'] + ) + ) + + def _operation_status_message(self): + # TODO: This is intended to grab as much detail as possible + # Currently, just grabbing the description field from the last status_event + status_events = google_batch_operations.get_status_events(self._op) + if status_events: + return status_events[-1].description - # Bring together the task definition(s) and build the Job request. - task_spec = google_batch_operations.build_task_spec( - runnables=runnables, volumes=[datadisk_volume], compute_resource=compute_resource - ) - task_group = google_batch_operations.build_task_group( - task_spec, all_envs, task_count=len(all_envs), task_count_per_node=1 - ) - job = google_batch_operations.build_job( - [task_group], allocation_policy, labels, logs_policy - ) +class GoogleBatchBatchHandler(object): + """Implement the HttpBatch interface to enable simple serial batches.""" - job_request = batch_v1.CreateJobRequest( - parent=f'projects/{self._project}/locations/{self._location}', - job=job, - job_id=job_id, - ) - # pylint: enable=line-too-long - return job_request - - def _submit_batch_job(self, request) -> str: - client = batch_v1.BatchServiceClient() - job_response = client.create_job(request=request) - op = GoogleBatchOperation(job_response) - print(f'Provider internal-id (operation): {job_response.name}') - return op.get_field('task-id') - - def _create_env_for_task( - self, task_view: job_model.JobDescriptor - ) -> Dict[str, str]: - job_params = task_view.job_params - task_params = task_view.task_descriptors[0].task_params - - # Set local variables for the core pipeline values - script = task_view.job_metadata['script'] - user_project = task_view.job_metadata['user-project'] or '' - - envs = job_params['envs'] | task_params['envs'] - inputs = job_params['inputs'] | task_params['inputs'] - outputs = job_params['outputs'] | task_params['outputs'] - mounts = job_params['mounts'] - - prepare_env = self._get_prepare_env( - script, task_view, inputs, outputs, mounts, _DATA_MOUNT_POINT - ) - localization_env = self._get_localization_env( - inputs, user_project, _DATA_MOUNT_POINT - ) - user_environment = self._build_user_environment( - envs, inputs, outputs, mounts, _DATA_MOUNT_POINT - ) - delocalization_env = self._get_delocalization_env( - outputs, user_project, _DATA_MOUNT_POINT - ) - # This merges all the envs into one dict. Need to use this syntax because - # of python3.6. In python3.9 we'd prefer to use | operator. - all_env = { - **prepare_env, - **localization_env, - **user_environment, - **delocalization_env, - } - return all_env - - def submit_job( - self, - job_descriptor: job_model.JobDescriptor, - skip_if_output_present: bool, - ) -> Dict[str, any]: - # Validate task data and resources. - param_util.validate_submit_args_or_fail( - job_descriptor, - provider_name=_PROVIDER_NAME, - input_providers=_SUPPORTED_INPUT_PROVIDERS, - output_providers=_SUPPORTED_OUTPUT_PROVIDERS, - logging_providers=_SUPPORTED_LOGGING_PROVIDERS, - ) + def __init__(self, callback): + self._cancel_list = [] + self._response_handler = callback - # Prepare and submit jobs. 
- launched_tasks = [] - requests = [] - job_id = job_descriptor.job_metadata['job-id'] - # Instead of creating one job per task, create one job with several tasks. - # We also need to create a list of environments per task. The length of this - # list determines how many tasks are in the job, and is specified in the - # TaskGroup's task_count field. - envs = [] - for task_view in job_model.task_view_generator(job_descriptor): - env = self._create_env_for_task(task_view) - envs.append(google_batch_operations.build_environment(env)) - - request = self._create_batch_request(job_descriptor, job_id, envs) - if self._dry_run: - requests.append(request) - else: - # task_id = client.create_job(request=request) - task_id = self._submit_batch_job(request) - launched_tasks.append(task_id) - # If this is a dry-run, emit all the pipeline request objects - if self._dry_run: - print( - json.dumps(requests, indent=2, sort_keys=True, separators=(',', ': ')) - ) - return { - 'job-id': job_id, - 'user-id': job_descriptor.job_metadata.get('user-id'), - 'task-id': [task_id for task_id in launched_tasks if task_id], - } - - def delete_jobs( - self, - user_ids, - job_ids, - task_ids, - labels, - create_time_min=None, - create_time_max=None, - ): - """Kills the operations associated with the specified job or job.task. - - Args: - user_ids: List of user ids who "own" the job(s) to cancel. - job_ids: List of job_ids to cancel. - task_ids: List of task-ids to cancel. - labels: List of LabelParam, each must match the job(s) to be canceled. - create_time_min: a timezone-aware datetime value for the earliest create - time of a task, inclusive. - create_time_max: a timezone-aware datetime value for the most recent - create time of a task, inclusive. - - Returns: - A list of tasks canceled and a list of error messages. - """ - # Look up the job(s) - tasks = list( - self.lookup_job_tasks( - {'RUNNING'}, - user_ids=user_ids, - job_ids=job_ids, - task_ids=task_ids, - labels=labels, - create_time_min=create_time_min, - create_time_max=create_time_max, - ) - ) + def add(self, cancel_fn, request_id): + self._cancel_list.append((request_id, cancel_fn)) - print('Found %d tasks to delete.' % len(tasks)) - return google_base.cancel( - self._batch_handler_def(), self._operations_cancel_api_def(), tasks - ) + def execute(self): + for request_id, cancel_fn in self._cancel_list: + response = None + exception = None + try: + response = cancel_fn.result() + except: # pylint: disable=bare-except + exception = sys.exc_info()[1] - def lookup_job_tasks( - self, - statuses: Set[str], - user_ids=None, - job_ids=None, - job_names=None, - task_ids=None, - task_attempts=None, - labels=None, - create_time_min=None, - create_time_max=None, - max_tasks=0, - page_size=0, - ): - client = batch_v1.BatchServiceClient() - # TODO: Batch API has no 'done' filter like lifesciences API. - # Need to figure out how to filter for jobs that are completed. 
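# A sketch of a label-based lookup, assuming the label_filter() helper and
# the labels dsub attaches to each Batch job. The project, location and
# job-id value are placeholders.
from google.cloud import batch_v1
from dsub.providers import google_batch_operations

ops_filter = google_batch_operations.label_filter(
    'job-id', 'my-script--my-user--240307-173513-000000')
request = batch_v1.ListJobsRequest(
    parent='projects/my-project/locations/us-central1', filter=ops_filter)
for job in batch_v1.BatchServiceClient().list_jobs(request=request):
  print(job.name, job.status.state)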
- empty_statuses = set() - ops_filter = self._build_query_filter( - empty_statuses, - user_ids, - job_ids, - job_names, - task_ids, - task_attempts, - labels, - create_time_min, - create_time_max, - ) - # Initialize request argument(s) - request = batch_v1.ListJobsRequest( - parent=f'projects/{self._project}/locations/{self._location}', - filter=ops_filter, - ) + self._response_handler(request_id, response, exception) - # Make the request - response = client.list_jobs(request=request) - for page in response: - yield GoogleBatchOperation(page) - def get_tasks_completion_messages(self, tasks): - # TODO: This needs to return a list of error messages for each task - pass +class GoogleBatchJobProvider(google_utils.GoogleJobProviderBase): + """dsub provider implementation managing Jobs on Google Cloud.""" + + def __init__( + self, dry_run: bool, project: str, location: str, credentials=None + ): + self._dry_run = dry_run + self._location = location + self._project = project + + def _batch_handler_def(self): + return GoogleBatchBatchHandler + + def _operations_cancel_api_def(self): + return batch_v1.BatchServiceClient().delete_job + + def _get_create_time_filters(self, create_time_min, create_time_max): + # TODO: Currently, Batch API does not support filtering by create t. + return [] + + def _get_logging_env(self, logging_uri, user_project, include_filter_script): + """Returns the environment for actions that copy logging files.""" + if not logging_uri.endswith('.log'): + raise ValueError('Logging URI must end in ".log": {}'.format(logging_uri)) + + logging_prefix = logging_uri[: -len('.log')] + env = { + 'LOGGING_PATH': '{}.log'.format(logging_prefix), + 'STDOUT_PATH': '{}-stdout.log'.format(logging_prefix), + 'STDERR_PATH': '{}-stderr.log'.format(logging_prefix), + 'USER_PROJECT': user_project, + } + if include_filter_script: + env[_LOG_FILTER_VAR] = repr(_LOG_FILTER_PYTHON) + + return env + + def _create_batch_request( + self, + task_view: job_model.JobDescriptor, + job_id, + all_envs: List[batch_v1.types.Environment], + ): + job_metadata = task_view.job_metadata + job_params = task_view.job_params + job_resources = task_view.job_resources + task_metadata = task_view.task_descriptors[0].task_metadata + task_params = task_view.task_descriptors[0].task_params + task_resources = task_view.task_descriptors[0].task_resources + + # Set up VM-specific variables + datadisk_volume = google_batch_operations.build_volume( + disk=google_utils.DATA_DISK_NAME, path=_VOLUME_MOUNT_POINT + ) + + # Set up the task labels + # pylint: disable=g-complex-comprehension + labels = { + label.name: label.value if label.value else '' + for label in google_base.build_pipeline_labels( + job_metadata, task_metadata + ) + | job_params['labels'] + | task_params['labels'] + } + # pylint: enable=g-complex-comprehension + + # Set local variables for the core pipeline values + script = task_view.job_metadata['script'] + + # Track 0-based runnable indexes for cross-task awareness + user_action = 3 + + continuous_logging_cmd = _CONTINUOUS_LOGGING_CMD.format( + log_msg_fn=google_utils.LOG_MSG_FN, + gsutil_cp_fn=google_utils.GSUTIL_CP_FN, + log_filter_var=_LOG_FILTER_VAR, + log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, + python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, + logging_dir=_LOGGING_DIR, + log_file_path=_LOG_FILE_PATH, + log_cp=_LOG_CP.format( + log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, + log_file_path=_LOG_FILE_PATH, + user_action=user_action, + ), + log_interval=job_resources.log_interval or '60s', + ) + + 
logging_cmd = _FINAL_LOGGING_CMD.format( + log_msg_fn=google_utils.LOG_MSG_FN, + gsutil_cp_fn=google_utils.GSUTIL_CP_FN, + log_filter_var=_LOG_FILTER_VAR, + log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, + python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, + logging_dir=_LOGGING_DIR, + log_file_path=_LOG_FILE_PATH, + log_cp=_LOG_CP.format( + log_filter_script_path=_LOG_FILTER_SCRIPT_PATH, + log_file_path=_LOG_FILE_PATH, + user_action=user_action, + ), + ) + + # Set up command and environments for the prepare, localization, user, + # and de-localization actions + script_path = os.path.join(_SCRIPT_DIR, script.name) + user_project = task_view.job_metadata['user-project'] or '' + + prepare_command = google_utils.PREPARE_CMD.format( + log_msg_fn=google_utils.LOG_MSG_FN, + mk_runtime_dirs=google_utils.make_runtime_dirs_command( + _SCRIPT_DIR, _TMP_DIR, _WORKING_DIR + ), + script_var=google_utils.SCRIPT_VARNAME, + python_decode_script=google_utils.PYTHON_DECODE_SCRIPT, + script_path=script_path, + mk_io_dirs=google_utils.MK_IO_DIRS, + ) + # pylint: disable=line-too-long + + continuous_logging_env = google_batch_operations.build_environment( + self._get_logging_env( + task_resources.logging_path.uri, user_project, True + ) + ) + final_logging_env = google_batch_operations.build_environment( + self._get_logging_env( + task_resources.logging_path.uri, user_project, False + ) + ) + + # Build the list of runnables (aka actions) + runnables = [] + + runnables.append( + # logging + google_batch_operations.build_runnable( + run_in_background=True, + always_run=False, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=continuous_logging_env, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=['-c', continuous_logging_cmd], + ) + ) + + runnables.append( + # prepare + google_batch_operations.build_runnable( + run_in_background=False, + always_run=False, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=None, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=['-c', prepare_command], + ) + ) + + runnables.append( + # localization + google_batch_operations.build_runnable( + run_in_background=False, + always_run=False, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=None, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=[ + '-c', + google_utils.LOCALIZATION_CMD.format( + log_msg_fn=google_utils.LOG_MSG_FN, + recursive_cp_fn=google_utils.GSUTIL_RSYNC_FN, + cp_fn=google_utils.GSUTIL_CP_FN, + cp_loop=google_utils.LOCALIZATION_LOOP, + ), + ], + ) + ) + + # user-command volumes + user_command_volumes = [f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'] + if job_resources.accelerator_type is not None: + user_command_volumes.extend([ + "/var/lib/nvidia/lib64:/usr/local/nvidia/lib64", + "/var/lib/nvidia/bin:/usr/local/nvidia/bin" + ]) + + runnables.append( + # user-command + google_batch_operations.build_runnable( + run_in_background=False, + always_run=False, + image_uri=job_resources.image, + environment=None, + entrypoint='/usr/bin/env', + volumes=user_command_volumes, + commands=[ + 'bash', + '-c', + google_utils.USER_CMD.format( + tmp_dir=_TMP_DIR, + working_dir=_WORKING_DIR, + user_script=script_path, + ), + ], + ) + ) + + runnables.append( + # delocalization + google_batch_operations.build_runnable( + run_in_background=False, + always_run=False, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=None, + entrypoint='/bin/bash', + 
volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}:ro'], + commands=[ + '-c', + google_utils.LOCALIZATION_CMD.format( + log_msg_fn=google_utils.LOG_MSG_FN, + recursive_cp_fn=google_utils.GSUTIL_RSYNC_FN, + cp_fn=google_utils.GSUTIL_CP_FN, + cp_loop=google_utils.DELOCALIZATION_LOOP, + ), + ], + ) + ) + + runnables.append( + # final_logging + google_batch_operations.build_runnable( + run_in_background=False, + always_run=True, + image_uri=google_utils.CLOUD_SDK_IMAGE, + environment=final_logging_env, + entrypoint='/bin/bash', + volumes=[f'{_VOLUME_MOUNT_POINT}:{_DATA_MOUNT_POINT}'], + commands=['-c', logging_cmd], + ), + ) + + # Prepare the VM (resources) configuration. The InstancePolicy describes an + # instance type and resources attached to each VM. The AllocationPolicy + # describes when, where, and how compute resources should be allocated + # for the Job. + disk = google_batch_operations.build_persistent_disk( + size_gb=job_resources.disk_size, + disk_type=job_resources.disk_type or job_model.DEFAULT_DISK_TYPE, + ) + attached_disk = google_batch_operations.build_attached_disk( + disk=disk, device_name=google_utils.DATA_DISK_NAME + ) + + instance_policy = google_batch_operations.build_instance_policy( + disks=attached_disk, + machine_type=job_resources.machine_type, + accelerators=build_accelerators( + accelerator_type=job_resources.accelerator_type, + accelerator_count=job_resources.accelerator_count, + ) if job_resources.accelerator_type is not None else None + ) + + ipt = google_batch_operations.build_instance_policy_or_template( + instance_policy=instance_policy, + install_gpu_drivers=True if job_resources.accelerator_type is not None else False + + ) + + service_account = google_batch_operations.build_service_account( + service_account_email=job_resources.service_account) + + network_policy = google_batch_operations.build_network_policy( + network=job_resources.network, + subnetwork=job_resources.subnetwork, + no_external_ip_address=job_resources.use_private_address, + ) + + allocation_policy = google_batch_operations.build_allocation_policy( + ipts=[ipt], + service_account=service_account, + network_policy=network_policy, + ) + + logs_policy = google_batch_operations.build_logs_policy( + batch_v1.LogsPolicy.Destination.PATH, _BATCH_LOG_FILE_PATH + ) + + compute_resource = build_compute_resource( + cpu_milli=job_resources.min_cores * 1000, + memory_mib=job_resources.min_ram * 1024, + boot_disk_mib=job_resources.boot_disk_size * 1024 + ) + + # Bring together the task definition(s) and build the Job request. 
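# A sketch, with placeholder values, of the allocation-policy pieces built
# just above. Batch expects the network and subnetwork as Compute Engine
# resource paths (or full URLs), and the service account is identified by
# email; the project, region, email and resource sizes are illustrative, not
# dsub defaults.
from dsub.providers import google_batch_operations

network_policy = google_batch_operations.build_network_policy(
    network='projects/my-project/global/networks/my-network',
    subnetwork='regions/us-central1/subnetworks/my-subnet',
    no_external_ip_address=True)  # from dsub's --use-private-address
service_account = google_batch_operations.build_service_account(
    service_account_email='my-sa@my-project.iam.gserviceaccount.com')
# CPU, RAM and boot disk are converted to Batch units before the request:
# 2 cores -> 2000 milli-CPU, 7.5 GB RAM -> 7680 MiB, 30 GB boot disk -> 30720 MiB.
compute_resource = google_batch_operations.build_compute_resource(
    cpu_milli=2 * 1000, memory_mib=int(7.5 * 1024), boot_disk_mib=30 * 1024)
# These feed build_allocation_policy() and build_task_spec() together with
# the InstancePolicyOrTemplate and runnables assembled above.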
+ task_spec = google_batch_operations.build_task_spec( + runnables=runnables, volumes=[datadisk_volume], compute_resource=compute_resource + ) + task_group = google_batch_operations.build_task_group( + task_spec, all_envs, task_count=len(all_envs), task_count_per_node=1 + ) + + job = google_batch_operations.build_job( + [task_group], allocation_policy, labels, logs_policy + ) + + job_request = batch_v1.CreateJobRequest( + parent=f'projects/{self._project}/locations/{self._location}', + job=job, + job_id=job_id, + ) + # pylint: enable=line-too-long + return job_request + + def _submit_batch_job(self, request) -> str: + client = batch_v1.BatchServiceClient() + job_response = client.create_job(request=request) + op = GoogleBatchOperation(job_response) + print(f'Provider internal-id (operation): {job_response.name}') + return op.get_field('task-id') + + def _create_env_for_task( + self, task_view: job_model.JobDescriptor + ) -> Dict[str, str]: + job_params = task_view.job_params + task_params = task_view.task_descriptors[0].task_params + + # Set local variables for the core pipeline values + script = task_view.job_metadata['script'] + user_project = task_view.job_metadata['user-project'] or '' + + envs = job_params['envs'] | task_params['envs'] + inputs = job_params['inputs'] | task_params['inputs'] + outputs = job_params['outputs'] | task_params['outputs'] + mounts = job_params['mounts'] + + prepare_env = self._get_prepare_env( + script, task_view, inputs, outputs, mounts, _DATA_MOUNT_POINT + ) + localization_env = self._get_localization_env( + inputs, user_project, _DATA_MOUNT_POINT + ) + user_environment = self._build_user_environment( + envs, inputs, outputs, mounts, _DATA_MOUNT_POINT + ) + delocalization_env = self._get_delocalization_env( + outputs, user_project, _DATA_MOUNT_POINT + ) + # This merges all the envs into one dict. Need to use this syntax because + # of python3.6. In python3.9 we'd prefer to use | operator. + all_env = { + **prepare_env, + **localization_env, + **user_environment, + **delocalization_env, + } + return all_env + + def submit_job( + self, + job_descriptor: job_model.JobDescriptor, + skip_if_output_present: bool, + ) -> Dict[str, any]: + # Validate task data and resources. + param_util.validate_submit_args_or_fail( + job_descriptor, + provider_name=_PROVIDER_NAME, + input_providers=_SUPPORTED_INPUT_PROVIDERS, + output_providers=_SUPPORTED_OUTPUT_PROVIDERS, + logging_providers=_SUPPORTED_LOGGING_PROVIDERS, + ) + + # Prepare and submit jobs. + launched_tasks = [] + requests = [] + job_id = job_descriptor.job_metadata['job-id'] + # Instead of creating one job per task, create one job with several tasks. + # We also need to create a list of environments per task. The length of this + # list determines how many tasks are in the job, and is specified in the + # TaskGroup's task_count field. 
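# A sketch of the one-job / many-tasks layout described in the comment above:
# one Environment per dsub task, with task_count equal to the number of
# environments. task_spec stands in for the TaskSpec built with
# build_task_spec() just above; the variable values are placeholders.
from dsub.providers import google_batch_operations

per_task_envs = [
    google_batch_operations.build_environment({'SAMPLE_ID': 'S1'}),
    google_batch_operations.build_environment({'SAMPLE_ID': 'S2'}),
]
task_group = google_batch_operations.build_task_group(
    task_spec,  # the TaskSpec from build_task_spec() above
    per_task_envs,
    task_count=len(per_task_envs),
    task_count_per_node=1)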
+ envs = [] + for task_view in job_model.task_view_generator(job_descriptor): + env = self._create_env_for_task(task_view) + envs.append(google_batch_operations.build_environment(env)) + + request = self._create_batch_request(job_descriptor, job_id, envs) + if self._dry_run: + requests.append(request) + else: + # task_id = client.create_job(request=request) + task_id = self._submit_batch_job(request) + launched_tasks.append(task_id) + # If this is a dry-run, emit all the pipeline request objects + if self._dry_run: + print( + json.dumps(requests, indent=2, sort_keys=True, separators=(',', ': ')) + ) + return { + 'job-id': job_id, + 'user-id': job_descriptor.job_metadata.get('user-id'), + 'task-id': [task_id for task_id in launched_tasks if task_id], + } + + def delete_jobs( + self, + user_ids, + job_ids, + task_ids, + labels, + create_time_min=None, + create_time_max=None, + ): + """Kills the operations associated with the specified job or job.task. + +Args: + user_ids: List of user ids who "own" the job(s) to cancel. + job_ids: List of job_ids to cancel. + task_ids: List of task-ids to cancel. + labels: List of LabelParam, each must match the job(s) to be canceled. + create_time_min: a timezone-aware datetime value for the earliest create + time of a task, inclusive. + create_time_max: a timezone-aware datetime value for the most recent + create time of a task, inclusive. + +Returns: + A list of tasks canceled and a list of error messages. +""" + # Look up the job(s) + tasks = list( + self.lookup_job_tasks( + {'RUNNING'}, + user_ids=user_ids, + job_ids=job_ids, + task_ids=task_ids, + labels=labels, + create_time_min=create_time_min, + create_time_max=create_time_max, + ) + ) + + print('Found %d tasks to delete.' % len(tasks)) + return google_base.cancel( + self._batch_handler_def(), self._operations_cancel_api_def(), tasks + ) + + def lookup_job_tasks( + self, + statuses: Set[str], + user_ids=None, + job_ids=None, + job_names=None, + task_ids=None, + task_attempts=None, + labels=None, + create_time_min=None, + create_time_max=None, + max_tasks=0, + page_size=0, + ): + client = batch_v1.BatchServiceClient() + # TODO: Batch API has no 'done' filter like lifesciences API. + # Need to figure out how to filter for jobs that are completed. + empty_statuses = set() + ops_filter = self._build_query_filter( + empty_statuses, + user_ids, + job_ids, + job_names, + task_ids, + task_attempts, + labels, + create_time_min, + create_time_max, + ) + # Initialize request argument(s) + request = batch_v1.ListJobsRequest( + parent=f'projects/{self._project}/locations/{self._location}', + filter=ops_filter, + ) + + # Make the request + response = client.list_jobs(request=request) + for page in response: + yield GoogleBatchOperation(page) + + def get_tasks_completion_messages(self, tasks): + # TODO: This needs to return a list of error messages for each task + pass diff --git a/dsub/providers/google_batch_operations.py b/dsub/providers/google_batch_operations.py index 2b26906..f2b4641 100644 --- a/dsub/providers/google_batch_operations.py +++ b/dsub/providers/google_batch_operations.py @@ -17,329 +17,329 @@ # pylint: disable=g-import-not-at-top try: - from google.cloud import batch_v1 + from google.cloud import batch_v1 except ImportError: - # TODO: Remove conditional import when batch library is available - from . import batch_dummy as batch_v1 + # TODO: Remove conditional import when batch library is available + from . 
import batch_dummy as batch_v1 # pylint: enable=g-import-not-at-top def label_filter(label_key: str, label_value: str) -> str: - """Return a valid label filter for operations.list().""" - return 'labels."{}" = "{}"'.format(label_key, label_value) + """Return a valid label filter for operations.list().""" + return 'labels."{}" = "{}"'.format(label_key, label_value) def get_label(op: batch_v1.types.Job, name: str) -> str: - """Return the value for the specified label.""" - return op.labels.get(name) + """Return the value for the specified label.""" + return op.labels.get(name) def get_environment(op: batch_v1.types.Job) -> Dict[str, str]: - # Currently Batch only supports task_groups of size 1 - task_group = op.task_groups[0] - env_dict = {} - for env in task_group.task_environments: - env_dict.update(env.variables) - return env_dict + # Currently Batch only supports task_groups of size 1 + task_group = op.task_groups[0] + env_dict = {} + for env in task_group.task_environments: + env_dict.update(env.variables) + return env_dict def is_done(op: batch_v1.types.Job) -> bool: - """Return whether the operation has been marked done.""" - return op.status.state in [ - batch_v1.types.job.JobStatus.State.SUCCEEDED, - batch_v1.types.job.JobStatus.State.FAILED, - ] + """Return whether the operation has been marked done.""" + return op.status.state in [ + batch_v1.types.job.JobStatus.State.SUCCEEDED, + batch_v1.types.job.JobStatus.State.FAILED, + ] def is_success(op: batch_v1.types.Job) -> bool: - """Return whether the operation has completed successfully.""" - return op.status.state == batch_v1.types.job.JobStatus.State.SUCCEEDED + """Return whether the operation has completed successfully.""" + return op.status.state == batch_v1.types.job.JobStatus.State.SUCCEEDED def is_canceled() -> bool: - """Return whether the operation was canceled by the user.""" - # TODO: Verify if the batch job has a canceled enum - return False + """Return whether the operation was canceled by the user.""" + # TODO: Verify if the batch job has a canceled enum + return False def is_failed(op: batch_v1.types.Job) -> bool: - """Return whether the operation has failed.""" - return op.status.state == batch_v1.types.job.JobStatus.State.FAILED + """Return whether the operation has failed.""" + return op.status.state == batch_v1.types.job.JobStatus.State.FAILED def _pad_timestamps(ts: str) -> str: - """Batch API removes trailing zeroes from the fractional part of seconds.""" - # ts looks like 2022-06-23T19:38:23.11506605Z - # Pad zeroes until the fractional part is 9 digits long - dt, fraction = ts.split('.') - fraction = fraction.rstrip('Z') - fraction = fraction.ljust(9, '0') - return f'{dt}.{fraction}Z' + """Batch API removes trailing zeroes from the fractional part of seconds.""" + # ts looks like 2022-06-23T19:38:23.11506605Z + # Pad zeroes until the fractional part is 9 digits long + dt, fraction = ts.split('.') + fraction = fraction.rstrip('Z') + fraction = fraction.ljust(9, '0') + return f'{dt}.{fraction}Z' def get_update_time(op: batch_v1.types.Job) -> Optional[str]: - """Return the update time string of the operation.""" - update_time = op.update_time.ToDatetime() if op.update_time else None - if update_time: - return update_time.isoformat('T') + 'Z' # Representing the datetime object in rfc3339 format - else: - return None + """Return the update time string of the operation.""" + update_time = op.update_time.ToDatetime() if op.update_time else None + if update_time: + return update_time.isoformat('T') + 'Z' # Representing 
the datetime object in rfc3339 format + else: + return None def get_create_time(op: batch_v1.types.Job) -> Optional[str]: - """Return the create time string of the operation.""" - create_time = op.create_time.ToDatetime() if op.create_time else None - if create_time: - return create_time.isoformat('T') + 'Z' - else: - return None + """Return the create time string of the operation.""" + create_time = op.create_time.ToDatetime() if op.create_time else None + if create_time: + return create_time.isoformat('T') + 'Z' + else: + return None def get_status_events(op: batch_v1.types.Job): - return op.status.status_events + return op.status.status_events def build_job( - task_groups: List[batch_v1.types.TaskGroup], - allocation_policy: batch_v1.types.AllocationPolicy, - labels: Dict[str, str], - logs_policy: batch_v1.types.LogsPolicy, + task_groups: List[batch_v1.types.TaskGroup], + allocation_policy: batch_v1.types.AllocationPolicy, + labels: Dict[str, str], + logs_policy: batch_v1.types.LogsPolicy, ) -> batch_v1.types.Job: - job = batch_v1.Job() - job.task_groups = task_groups - job.allocation_policy = allocation_policy - job.labels = labels - job.logs_policy = logs_policy - return job + job = batch_v1.Job() + job.task_groups = task_groups + job.allocation_policy = allocation_policy + job.labels = labels + job.logs_policy = logs_policy + return job def build_compute_resource(cpu_milli: int, memory_mib: int, boot_disk_mib: int) -> batch_v1.types.ComputeResource: - """Build a ComputeResource object for a Batch request. + """Build a ComputeResource object for a Batch request. - Args: - cpu_milli (int): Number of milliCPU units - memory_mib (int): Amount of memory in Mebibytes (MiB) - boot_disk_mib (int): The boot disk size in Mebibytes (MiB) + Args: + cpu_milli (int): Number of milliCPU units + memory_mib (int): Amount of memory in Mebibytes (MiB) + boot_disk_mib (int): The boot disk size in Mebibytes (MiB) - Returns: - A ComputeResource object. - """ - compute_resource = batch_v1.ComputeResource( - cpu_milli=cpu_milli, - memory_mib=memory_mib, - boot_disk_mib=boot_disk_mib - ) - return compute_resource + Returns: + A ComputeResource object. + """ + compute_resource = batch_v1.ComputeResource( + cpu_milli=cpu_milli, + memory_mib=memory_mib, + boot_disk_mib=boot_disk_mib + ) + return compute_resource def build_task_spec( - runnables: List[batch_v1.types.task.Runnable], - volumes: List[batch_v1.types.Volume], - compute_resource: batch_v1.types.ComputeResource, + runnables: List[batch_v1.types.task.Runnable], + volumes: List[batch_v1.types.Volume], + compute_resource: batch_v1.types.ComputeResource, ) -> batch_v1.types.TaskSpec: - """Build a TaskSpec object for a Batch request. + """Build a TaskSpec object for a Batch request. - Args: - runnables (List[Runnable]): List of Runnable objects - volumes (List[Volume]): List of Volume objects - compute_resource (ComputeResource): The compute resources to use + Args: + runnables (List[Runnable]): List of Runnable objects + volumes (List[Volume]): List of Volume objects + compute_resource (ComputeResource): The compute resources to use - Returns: - A TaskSpec object. - """ - task_spec = batch_v1.TaskSpec() - task_spec.runnables = runnables - task_spec.volumes = volumes - task_spec.compute_resource = compute_resource - return task_spec + Returns: + A TaskSpec object. 
+ """ + task_spec = batch_v1.TaskSpec() + task_spec.runnables = runnables + task_spec.volumes = volumes + task_spec.compute_resource = compute_resource + return task_spec def build_environment(env_vars: Dict[str, str]): - environment = batch_v1.Environment() - environment.variables = env_vars - return environment + environment = batch_v1.Environment() + environment.variables = env_vars + return environment def build_task_group( - task_spec: batch_v1.types.TaskSpec, - task_environments: List[batch_v1.types.Environment], - task_count: int, - task_count_per_node: int, + task_spec: batch_v1.types.TaskSpec, + task_environments: List[batch_v1.types.Environment], + task_count: int, + task_count_per_node: int, ) -> batch_v1.types.TaskGroup: - """Build a TaskGroup object for a Batch request. + """Build a TaskGroup object for a Batch request. - Args: - task_spec (TaskSpec): TaskSpec object - task_environments (List[Environment]): List of Environment objects - task_count (int): The number of total tasks in the job - task_count_per_node (int): The number of tasks to schedule on one VM +Args: + task_spec (TaskSpec): TaskSpec object + task_environments (List[Environment]): List of Environment objects + task_count (int): The number of total tasks in the job + task_count_per_node (int): The number of tasks to schedule on one VM - Returns: - A TaskGroup object. - """ - task_group = batch_v1.TaskGroup() - task_group.task_spec = task_spec - task_group.task_environments = task_environments - task_group.task_count = task_count - task_group.task_count_per_node = task_count_per_node - return task_group +Returns: + A TaskGroup object. +""" + task_group = batch_v1.TaskGroup() + task_group.task_spec = task_spec + task_group.task_environments = task_environments + task_group.task_count = task_count + task_group.task_count_per_node = task_count_per_node + return task_group def build_container( - image_uri: str, entrypoint: str, volumes: List[str], commands: List[str] + image_uri: str, entrypoint: str, volumes: List[str], commands: List[str] ) -> batch_v1.types.task.Runnable.Container: - container = batch_v1.types.task.Runnable.Container() - container.image_uri = image_uri - container.entrypoint = entrypoint - container.commands = commands - container.volumes = volumes - return container + container = batch_v1.types.task.Runnable.Container() + container.image_uri = image_uri + container.entrypoint = entrypoint + container.commands = commands + container.volumes = volumes + return container def build_runnable( - run_in_background: bool, - always_run: bool, - environment: batch_v1.types.Environment, - image_uri: str, - entrypoint: str, - volumes: List[str], - commands: List[str], + run_in_background: bool, + always_run: bool, + environment: batch_v1.types.Environment, + image_uri: str, + entrypoint: str, + volumes: List[str], + commands: List[str], ) -> batch_v1.types.task.Runnable: - """Build a Runnable object for a Batch request. 
- - Args: - run_in_background (bool): True for the action to run in the background - always_run (bool): True for the action to run even in case of error from - prior actions - environment (Environment): Environment variables for action - image_uri (str): Docker image path - entrypoint (str): Docker image entrypoint path - volumes (List[str]): List of volume mounts (host_path:container_path) - commands (List[str]): Command arguments to pass to the entrypoint - - Returns: - An object representing a Runnable - """ - container = build_container(image_uri, entrypoint, volumes, commands) - runnable = batch_v1.Runnable() - runnable.container = container - runnable.background = run_in_background - runnable.always_run = always_run - runnable.environment = environment - return runnable + """Build a Runnable object for a Batch request. + +Args: + run_in_background (bool): True for the action to run in the background + always_run (bool): True for the action to run even in case of error from + prior actions + environment (Environment): Environment variables for action + image_uri (str): Docker image path + entrypoint (str): Docker image entrypoint path + volumes (List[str]): List of volume mounts (host_path:container_path) + commands (List[str]): Command arguments to pass to the entrypoint + +Returns: + An object representing a Runnable +""" + container = build_container(image_uri, entrypoint, volumes, commands) + runnable = batch_v1.Runnable() + runnable.container = container + runnable.background = run_in_background + runnable.always_run = always_run + runnable.environment = environment + return runnable def build_volume(disk: str, path: str) -> batch_v1.types.Volume: - """Build a Volume object for a Batch request. + """Build a Volume object for a Batch request. - Args: - disk (str): Name of disk to mount, as specified in the resources section. - path (str): Path to mount the disk at inside the container. +Args: + disk (str): Name of disk to mount, as specified in the resources section. + path (str): Path to mount the disk at inside the container. - Returns: - An object representing a Mount. - """ - volume = batch_v1.Volume() - volume.device_name = disk - volume.mount_path = path - return volume +Returns: + An object representing a Mount. 
+""" + volume = batch_v1.Volume() + volume.device_name = disk + volume.mount_path = path + return volume def build_network_policy(network: str, subnetwork: str, no_external_ip_address: bool) -> batch_v1.types.job.AllocationPolicy.NetworkPolicy: - network_polycy = AllocationPolicy.NetworkPolicy( - network_interfaces=[ - AllocationPolicy.NetworkInterface( - network=network, - subnetwork=subnetwork, - no_external_ip_address=no_external_ip_address, - ) - ] - ) - return network_polycy + network_polycy = AllocationPolicy.NetworkPolicy( + network_interfaces=[ + AllocationPolicy.NetworkInterface( + network=network, + subnetwork=subnetwork, + no_external_ip_address=no_external_ip_address, + ) + ] + ) + return network_polycy def build_service_account(service_account_email: str) -> batch_v1.ServiceAccount: - service_account = ServiceAccount( - email=service_account_email - ) - return service_account + service_account = ServiceAccount( + email=service_account_email + ) + return service_account def build_allocation_policy( - ipts: List[batch_v1.types.AllocationPolicy.InstancePolicyOrTemplate], - service_account: batch_v1.ServiceAccount, - network_policy: batch_v1.types.job.AllocationPolicy.NetworkPolicy + ipts: List[batch_v1.types.AllocationPolicy.InstancePolicyOrTemplate], + service_account: batch_v1.ServiceAccount, + network_policy: batch_v1.types.job.AllocationPolicy.NetworkPolicy ) -> batch_v1.types.AllocationPolicy: - allocation_policy = batch_v1.AllocationPolicy() - allocation_policy.instances = ipts - allocation_policy.service_account = service_account - allocation_policy.network = network_policy + allocation_policy = batch_v1.AllocationPolicy() + allocation_policy.instances = ipts + allocation_policy.service_account = service_account + allocation_policy.network = network_policy - return allocation_policy + return allocation_policy def build_instance_policy_or_template( - instance_policy: batch_v1.types.AllocationPolicy.InstancePolicy, - install_gpu_drivers: bool + instance_policy: batch_v1.types.AllocationPolicy.InstancePolicy, + install_gpu_drivers: bool ) -> batch_v1.types.AllocationPolicy.InstancePolicyOrTemplate: - ipt = batch_v1.AllocationPolicy.InstancePolicyOrTemplate() - ipt.policy = instance_policy - ipt.install_gpu_drivers = install_gpu_drivers - return ipt + ipt = batch_v1.AllocationPolicy.InstancePolicyOrTemplate() + ipt.policy = instance_policy + ipt.install_gpu_drivers = install_gpu_drivers + return ipt def build_logs_policy( - destination: batch_v1.types.LogsPolicy.Destination, logs_path: str + destination: batch_v1.types.LogsPolicy.Destination, logs_path: str ) -> batch_v1.types.LogsPolicy: - logs_policy = batch_v1.LogsPolicy() - logs_policy.destination = destination - logs_policy.logs_path = logs_path + logs_policy = batch_v1.LogsPolicy() + logs_policy.destination = destination + logs_policy.logs_path = logs_path - return logs_policy + return logs_policy def build_instance_policy( - disks: List[batch_v1.types.AllocationPolicy.AttachedDisk], - machine_type: str, - accelerators: MutableSequence[batch_v1.types.AllocationPolicy.Accelerator] + disks: List[batch_v1.types.AllocationPolicy.AttachedDisk], + machine_type: str, + accelerators: MutableSequence[batch_v1.types.AllocationPolicy.Accelerator] ) -> batch_v1.types.AllocationPolicy.InstancePolicy: - instance_policy = batch_v1.AllocationPolicy.InstancePolicy() - instance_policy.disks = [disks] - instance_policy.machine_type = machine_type - instance_policy.accelerators = accelerators + instance_policy = 
batch_v1.AllocationPolicy.InstancePolicy() + instance_policy.disks = [disks] + instance_policy.machine_type = machine_type + instance_policy.accelerators = accelerators - return instance_policy + return instance_policy def build_attached_disk( - disk: batch_v1.types.AllocationPolicy.Disk, device_name: str + disk: batch_v1.types.AllocationPolicy.Disk, device_name: str ) -> batch_v1.types.AllocationPolicy.AttachedDisk: - attached_disk = batch_v1.AllocationPolicy.AttachedDisk() - attached_disk.new_disk = disk - attached_disk.device_name = device_name - return attached_disk + attached_disk = batch_v1.AllocationPolicy.AttachedDisk() + attached_disk.new_disk = disk + attached_disk.device_name = device_name + return attached_disk def build_persistent_disk( - size_gb: int, disk_type: str + size_gb: int, disk_type: str ) -> batch_v1.types.AllocationPolicy.Disk: - disk = batch_v1.AllocationPolicy.Disk() - disk.type = disk_type - disk.size_gb = size_gb - return disk + disk = batch_v1.AllocationPolicy.Disk() + disk.type = disk_type + disk.size_gb = size_gb + return disk def build_accelerators( - accelerator_type, - accelerator_count + accelerator_type, + accelerator_count ) -> MutableSequence[batch_v1.types.AllocationPolicy.Accelerator]: - accelerators = [] - accelerator = batch_v1.AllocationPolicy.Accelerator() - accelerator.count = accelerator_count - accelerator.type = accelerator_type - accelerators.append(accelerator) + accelerators = [] + accelerator = batch_v1.AllocationPolicy.Accelerator() + accelerator.count = accelerator_count + accelerator.type = accelerator_type + accelerators.append(accelerator) - return accelerators + return accelerators From c37c0ea32a3a168e4255f399ea2461a13a0aba31 Mon Sep 17 00:00:00 2001 From: mccstan Date: Tue, 12 Mar 2024 11:37:52 +0100 Subject: [PATCH 07/10] rollack useless changes --- dsub/providers/google_batch_operations.py | 24 +++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/dsub/providers/google_batch_operations.py b/dsub/providers/google_batch_operations.py index f2b4641..1f752af 100644 --- a/dsub/providers/google_batch_operations.py +++ b/dsub/providers/google_batch_operations.py @@ -163,22 +163,22 @@ def build_environment(env_vars: Dict[str, str]): def build_task_group( - task_spec: batch_v1.types.TaskSpec, - task_environments: List[batch_v1.types.Environment], - task_count: int, - task_count_per_node: int, + task_spec: batch_v1.types.TaskSpec, + task_environments: List[batch_v1.types.Environment], + task_count: int, + task_count_per_node: int, ) -> batch_v1.types.TaskGroup: """Build a TaskGroup object for a Batch request. -Args: - task_spec (TaskSpec): TaskSpec object - task_environments (List[Environment]): List of Environment objects - task_count (int): The number of total tasks in the job - task_count_per_node (int): The number of tasks to schedule on one VM + Args: + task_spec (TaskSpec): TaskSpec object + task_environments (List[Environment]): List of Environment objects + task_count (int): The number of total tasks in the job + task_count_per_node (int): The number of tasks to schedule on one VM -Returns: - A TaskGroup object. -""" + Returns: + A TaskGroup object. 
+ """ task_group = batch_v1.TaskGroup() task_group.task_spec = task_spec task_group.task_environments = task_environments From 064668c092b476383e4730d2eea3e14983a1fdda Mon Sep 17 00:00:00 2001 From: mccstan Date: Tue, 12 Mar 2024 11:40:47 +0100 Subject: [PATCH 08/10] Put back spaces to avoid changes --- dsub/providers/google_batch_operations.py | 46 +++++++++++------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/dsub/providers/google_batch_operations.py b/dsub/providers/google_batch_operations.py index 1f752af..ab072d8 100644 --- a/dsub/providers/google_batch_operations.py +++ b/dsub/providers/google_batch_operations.py @@ -163,10 +163,10 @@ def build_environment(env_vars: Dict[str, str]): def build_task_group( - task_spec: batch_v1.types.TaskSpec, - task_environments: List[batch_v1.types.Environment], - task_count: int, - task_count_per_node: int, + task_spec: batch_v1.types.TaskSpec, + task_environments: List[batch_v1.types.Environment], + task_count: int, + task_count_per_node: int, ) -> batch_v1.types.TaskGroup: """Build a TaskGroup object for a Batch request. @@ -209,19 +209,19 @@ def build_runnable( ) -> batch_v1.types.task.Runnable: """Build a Runnable object for a Batch request. -Args: - run_in_background (bool): True for the action to run in the background - always_run (bool): True for the action to run even in case of error from - prior actions - environment (Environment): Environment variables for action - image_uri (str): Docker image path - entrypoint (str): Docker image entrypoint path - volumes (List[str]): List of volume mounts (host_path:container_path) - commands (List[str]): Command arguments to pass to the entrypoint - -Returns: - An object representing a Runnable -""" + Args: + run_in_background (bool): True for the action to run in the background + always_run (bool): True for the action to run even in case of error from + prior actions + environment (Environment): Environment variables for action + image_uri (str): Docker image path + entrypoint (str): Docker image entrypoint path + volumes (List[str]): List of volume mounts (host_path:container_path) + commands (List[str]): Command arguments to pass to the entrypoint + + Returns: + An object representing a Runnable + """ container = build_container(image_uri, entrypoint, volumes, commands) runnable = batch_v1.Runnable() runnable.container = container @@ -234,13 +234,13 @@ def build_runnable( def build_volume(disk: str, path: str) -> batch_v1.types.Volume: """Build a Volume object for a Batch request. -Args: - disk (str): Name of disk to mount, as specified in the resources section. - path (str): Path to mount the disk at inside the container. + Args: + disk (str): Name of disk to mount, as specified in the resources section. + path (str): Path to mount the disk at inside the container. -Returns: - An object representing a Mount. -""" + Returns: + An object representing a Mount. 
+ """ volume = batch_v1.Volume() volume.device_name = disk volume.mount_path = path From 564e035d88156258d42a19d53cb42b949eccd443 Mon Sep 17 00:00:00 2001 From: mccstan Date: Tue, 12 Mar 2024 11:45:06 +0100 Subject: [PATCH 09/10] Put back spaces on doc strings start --- dsub/providers/google_batch.py | 56 +++++++++++++++++----------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/dsub/providers/google_batch.py b/dsub/providers/google_batch.py index e39c64e..d62278f 100644 --- a/dsub/providers/google_batch.py +++ b/dsub/providers/google_batch.py @@ -258,19 +258,19 @@ def _try_op_to_job_descriptor(self): def get_field(self, field: str, default: str = None): """Returns a value from the operation for a specific set of field names. -This is the implementation of base.Task's abstract get_field method. See -base.py get_field for more details. + This is the implementation of base.Task's abstract get_field method. See + base.py get_field for more details. -Args: - field: a dsub-specific job metadata key - default: default value to return if field does not exist or is empty. + Args: + field: a dsub-specific job metadata key + default: default value to return if field does not exist or is empty. -Returns: - A text string for the field or a list for 'inputs'. + Returns: + A text string for the field or a list for 'inputs'. -Raises: - ValueError: if the field label is not supported by the operation -""" + Raises: + ValueError: if the field label is not supported by the operation + """ value = None if field == 'internal-id': value = self._op.name @@ -370,12 +370,12 @@ def _try_op_to_script_body(self): def _operation_status(self): """Returns the status of this operation. -Raises: - ValueError: if the operation status cannot be determined. + Raises: + ValueError: if the operation status cannot be determined. -Returns: - A printable status string (RUNNING, SUCCESS, CANCELED or FAILURE). -""" + Returns: + A printable status string (RUNNING, SUCCESS, CANCELED or FAILURE). + """ if not google_batch_operations.is_done(self._op): return 'RUNNING' if google_batch_operations.is_success(self._op): @@ -838,19 +838,19 @@ def delete_jobs( ): """Kills the operations associated with the specified job or job.task. -Args: - user_ids: List of user ids who "own" the job(s) to cancel. - job_ids: List of job_ids to cancel. - task_ids: List of task-ids to cancel. - labels: List of LabelParam, each must match the job(s) to be canceled. - create_time_min: a timezone-aware datetime value for the earliest create - time of a task, inclusive. - create_time_max: a timezone-aware datetime value for the most recent - create time of a task, inclusive. - -Returns: - A list of tasks canceled and a list of error messages. -""" + Args: + user_ids: List of user ids who "own" the job(s) to cancel. + job_ids: List of job_ids to cancel. + task_ids: List of task-ids to cancel. + labels: List of LabelParam, each must match the job(s) to be canceled. + create_time_min: a timezone-aware datetime value for the earliest create + time of a task, inclusive. + create_time_max: a timezone-aware datetime value for the most recent + create time of a task, inclusive. + + Returns: + A list of tasks canceled and a list of error messages. 
+ """ # Look up the job(s) tasks = list( self.lookup_job_tasks( From 25369a206a35920d9e8eea37a1a48e0e0644e81b Mon Sep 17 00:00:00 2001 From: mccstan Date: Tue, 12 Mar 2024 11:51:05 +0100 Subject: [PATCH 10/10] Remove last changes --- dsub/providers/google_batch.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/dsub/providers/google_batch.py b/dsub/providers/google_batch.py index d62278f..5d79e7a 100644 --- a/dsub/providers/google_batch.py +++ b/dsub/providers/google_batch.py @@ -258,19 +258,19 @@ def _try_op_to_job_descriptor(self): def get_field(self, field: str, default: str = None): """Returns a value from the operation for a specific set of field names. - This is the implementation of base.Task's abstract get_field method. See - base.py get_field for more details. + This is the implementation of base.Task's abstract get_field method. See + base.py get_field for more details. - Args: - field: a dsub-specific job metadata key - default: default value to return if field does not exist or is empty. + Args: + field: a dsub-specific job metadata key + default: default value to return if field does not exist or is empty. - Returns: - A text string for the field or a list for 'inputs'. + Returns: + A text string for the field or a list for 'inputs'. - Raises: - ValueError: if the field label is not supported by the operation - """ + Raises: + ValueError: if the field label is not supported by the operation + """ value = None if field == 'internal-id': value = self._op.name @@ -370,12 +370,12 @@ def _try_op_to_script_body(self): def _operation_status(self): """Returns the status of this operation. - Raises: - ValueError: if the operation status cannot be determined. + Raises: + ValueError: if the operation status cannot be determined. - Returns: - A printable status string (RUNNING, SUCCESS, CANCELED or FAILURE). - """ + Returns: + A printable status string (RUNNING, SUCCESS, CANCELED or FAILURE). + """ if not google_batch_operations.is_done(self._op): return 'RUNNING' if google_batch_operations.is_success(self._op):
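
Reviewer note (not part of any patch in this series): the builder helpers touched above are easier to evaluate when seen end to end, so here is a minimal sketch of how they could be composed into a single Batch job submission. The builder functions come from dsub/providers/google_batch_operations.py as introduced in this series; the client calls are the standard google-cloud-batch API. The project, region, image URI, disk name, machine type, accelerator type, service account email, labels, and job id below are all made-up placeholders, and the real provider performs this wiring inside dsub/providers/google_batch.py rather than in a standalone script.

# Illustrative sketch only -- all resource names and values are placeholders.
from google.cloud import batch_v1

from dsub.providers import google_batch_operations as ops

project = 'my-project'   # placeholder
region = 'us-central1'   # placeholder

# One runnable: a bash command in a container, with the data disk mounted.
runnable = ops.build_runnable(
  run_in_background=False,
  always_run=False,
  environment=ops.build_environment({'GREETING': 'hello'}),
  image_uri='gcr.io/my-project/my-image:latest',  # placeholder image
  entrypoint='/bin/bash',
  volumes=['/mnt/disks/data:/mnt/disks/data'],
  commands=['-c', 'echo "${GREETING} from Batch" > /mnt/disks/data/out.txt'],
)

# Task spec: the runnable, the volume it expects, and its CPU/RAM request.
task_spec = ops.build_task_spec(
  runnables=[runnable],
  volumes=[ops.build_volume(disk='data', path='/mnt/disks/data')],
  compute_resource=ops.build_compute_resource(
    cpu_milli=2000, memory_mib=4096, boot_disk_mib=30 * 1024),
)
task_group = ops.build_task_group(
  task_spec=task_spec,
  task_environments=[ops.build_environment({'TASK_INDEX': '0'})],
  task_count=1,
  task_count_per_node=1,
)

# Allocation policy: attached disk, machine type, accelerator, network, SA.
data_disk = ops.build_attached_disk(
  disk=ops.build_persistent_disk(size_gb=200, disk_type='pd-standard'),
  device_name='data',  # must match the Volume device name above
)
instance_policy = ops.build_instance_policy(
  disks=data_disk,  # the helper wraps this value in a list itself
  machine_type='n1-standard-2',
  accelerators=ops.build_accelerators(
    accelerator_type='nvidia-tesla-t4', accelerator_count=1),
)
allocation_policy = ops.build_allocation_policy(
  ipts=[ops.build_instance_policy_or_template(
    instance_policy=instance_policy, install_gpu_drivers=True)],
  service_account=ops.build_service_account(
    service_account_email='my-sa@my-project.iam.gserviceaccount.com'),
  network_policy=ops.build_network_policy(
    network='global/networks/default',
    subnetwork=f'regions/{region}/subnetworks/default',
    no_external_ip_address=True),
)

job = ops.build_job(
  task_groups=[task_group],
  allocation_policy=allocation_policy,
  labels={'job-id': 'my-job', 'user-id': 'someuser'},
  logs_policy=ops.build_logs_policy(
    batch_v1.LogsPolicy.Destination.CLOUD_LOGGING, ''),
)

client = batch_v1.BatchServiceClient()
client.create_job(
  parent=f'projects/{project}/locations/{region}', job=job, job_id='my-job-0')

Two details surfaced by this exercise that may be worth double-checking during review: build_instance_policy assigns [disks] even though its annotation declares a List of AttachedDisk, so callers appear to be expected to pass a single disk rather than a list; and build_network_policy and build_service_account reference AllocationPolicy and ServiceAccount unqualified, which assumes those names are imported into google_batch_operations.py.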
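
The docstring-only changes to _operation_status above are likewise easier to read next to the status helpers they depend on. A minimal sketch of polling a previously submitted job from outside the provider, assuming the same helpers; the fully qualified job name is a placeholder:

# Illustrative sketch only -- the job name below is a placeholder.
from google.cloud import batch_v1

from dsub.providers import google_batch_operations as ops

client = batch_v1.BatchServiceClient()
job = client.get_job(
  name='projects/my-project/locations/us-central1/jobs/my-job-0')

# Mirrors the RUNNING / SUCCESS / CANCELED / FAILURE mapping in
# GoogleBatchOperation._operation_status().
if not ops.is_done(job):
  status = 'RUNNING'
elif ops.is_success(job):
  status = 'SUCCESS'
elif ops.is_canceled():  # currently always False; see the TODO in is_canceled
  status = 'CANCELED'
elif ops.is_failed(job):
  status = 'FAILURE'
else:
  raise ValueError(f'Status for operation {job.name} could not be determined')

print(status)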