diff --git a/.travis.yml b/.travis.yml
index d29b04f..3b2d42a 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -4,7 +4,7 @@ python:
   - "3.7"
   - "3.8"
 install:
-  - pip install apache-airflow boto3 pylint isort
+  - pip install apache-airflow boto3 pylint isort marshmallow
 env:
   - AIRFLOW__BATCH__REGION=us-west-1 AIRFLOW__BATCH__JOB_NAME=some-job-name AIRFLOW__BATCH__JOB_QUEUE=some-job-queue AIRFLOW__BATCH__JOB_DEFINITION=some-job-def AIRFLOW__ECS_FARGATE__REGION=us-west-1 AIRFLOW__ECS_FARGATE__CLUSTER=some-cluster AIRFLOW__ECS_FARGATE__CONTAINER_NAME=some-container-name AIRFLOW__ECS_FARGATE__TASK_DEFINITION=some-task-def AIRFLOW__ECS_FARGATE__LAUNCH_TYPE=FARGATE AIRFLOW__ECS_FARGATE__PLATFORM_VERSION=LATEST AIRFLOW__ECS_FARGATE__ASSIGN_PUBLIC_IP=DISABLED AIRFLOW__ECS_FARGATE__SECURITY_GROUPS=SG1,SG2 AIRFLOW__ECS_FARGATE__SUBNETS=SUB1,SUB2
 script:
diff --git a/airflow_aws_executors/batch_executor.py b/airflow_aws_executors/batch_executor.py
index d343ac9..e1582eb 100644
--- a/airflow_aws_executors/batch_executor.py
+++ b/airflow_aws_executors/batch_executor.py
@@ -9,7 +9,7 @@
 from airflow.executors.base_executor import BaseExecutor
 from airflow.utils.module_loading import import_string
 from airflow.utils.state import State
-from marshmallow import Schema, fields, post_load
+from marshmallow import EXCLUDE, Schema, ValidationError, fields, post_load
 
 CommandType = List[str]
 TaskInstanceKeyType = Tuple[Any]
@@ -105,16 +105,17 @@ def _describe_tasks(self, job_ids) -> List[BatchJob]:
         for i in range((len(job_ids) // max_batch_size) + 1):
             batched_job_ids = job_ids[i * max_batch_size: (i + 1) * max_batch_size]
             boto_describe_tasks = self.batch.describe_jobs(jobs=batched_job_ids)
-            describe_tasks_response = BatchDescribeJobsResponseSchema().load(boto_describe_tasks)
-            if describe_tasks_response.errors:
+            try:
+                describe_tasks_response = BatchDescribeJobsResponseSchema().load(boto_describe_tasks)
+            except ValidationError as err:
                 self.log.error('Batch DescribeJobs API Response: %s', boto_describe_tasks)
                 raise BatchError(
                     'DescribeJobs API call does not match expected JSON shape. '
                     'Are you sure that the correct version of Boto3 is installed? {}'.format(
-                        describe_tasks_response.errors
+                        err
                     )
                 )
-            all_jobs.extend(describe_tasks_response.data['jobs'])
+            all_jobs.extend(describe_tasks_response['jobs'])
         return all_jobs
 
     def execute_async(self, key: TaskInstanceKeyType, command: CommandType, queue=None, executor_config=None):
@@ -135,16 +136,17 @@ def _submit_job(self, cmd: CommandType, exec_config: ExecutorConfigType) -> str:
         submit_job_api['containerOverrides'].update(exec_config)
         submit_job_api['containerOverrides']['command'] = cmd
         boto_run_task = self.batch.submit_job(**submit_job_api)
-        submit_job_response = BatchSubmitJobResponseSchema().load(boto_run_task)
-        if submit_job_response.errors:
-            self.log.error('Batch SubmitJob Response: %s', submit_job_response)
+        try:
+            submit_job_response = BatchSubmitJobResponseSchema().load(boto_run_task)
+        except ValidationError as err:
+            self.log.error('Batch SubmitJob Response: %s', err)
             raise BatchError(
                 'RunTask API call does not match expected JSON shape. '
                 'Are you sure that the correct version of Boto3 is installed? {}'.format(
-                    submit_job_response.errors
+                    err
                 )
             )
-        return submit_job_response.data['job_id']
+        return submit_job_response['job_id']
 
     def end(self, heartbeat_interval=10):
         """
@@ -213,29 +215,38 @@ def __len__(self):
 class BatchSubmitJobResponseSchema(Schema):
     """API Response for SubmitJob"""
     # The unique identifier for the job.
-    job_id = fields.String(load_from='jobId', required=True)
+    job_id = fields.String(data_key='jobId', required=True)
+
+    class Meta:
+        unknown = EXCLUDE
 
 
 class BatchJobDetailSchema(Schema):
     """API Response for Describe Jobs"""
     # The unique identifier for the job.
-    job_id = fields.String(load_from='jobId', required=True)
+    job_id = fields.String(data_key='jobId', required=True)
     # The current status for the job: 'SUBMITTED', 'PENDING', 'RUNNABLE', 'STARTING', 'RUNNING', 'SUCCEEDED', 'FAILED'
     status = fields.String(required=True)
     # A short, human-readable string to provide additional details about the current status of the job.
-    status_reason = fields.String(load_from='statusReason')
+    status_reason = fields.String(data_key='statusReason')
 
     @post_load
     def make_job(self, data, **kwargs):
-        """Overwrites marshmallow data property to return an instance of BatchJob instead of a dictionary"""
+        """Overwrites marshmallow load() to return an instance of BatchJob instead of a dictionary"""
         return BatchJob(**data)
 
+    class Meta:
+        unknown = EXCLUDE
+
 
 class BatchDescribeJobsResponseSchema(Schema):
     """API Response for Describe Jobs"""
     # The list of jobs
     jobs = fields.List(fields.Nested(BatchJobDetailSchema), required=True)
 
+    class Meta:
+        unknown = EXCLUDE
+
 
 class BatchError(Exception):
     """Thrown when something unexpected has occurred within the AWS Batch ecosystem"""
diff --git a/airflow_aws_executors/conf.py b/airflow_aws_executors/conf.py
index b6b0a53..6737eb3 100644
--- a/airflow_aws_executors/conf.py
+++ b/airflow_aws_executors/conf.py
@@ -25,7 +25,8 @@
 from airflow.configuration import conf
 
 
-def has_option(section, config_name):
+def has_option(section, config_name) -> bool:
+    """Returns True if configuration has a section and an option"""
     if conf.has_option(section, config_name):
         config_val = conf.get(section, config_name)
         return config_val is not None and config_val != ''
diff --git a/airflow_aws_executors/ecs_fargate_executor.py b/airflow_aws_executors/ecs_fargate_executor.py
index 15c3f39..c339ef3 100644
--- a/airflow_aws_executors/ecs_fargate_executor.py
+++ b/airflow_aws_executors/ecs_fargate_executor.py
@@ -10,14 +10,14 @@
 from airflow.executors.base_executor import BaseExecutor
 from airflow.utils.module_loading import import_string
 from airflow.utils.state import State
-from marshmallow import Schema, fields, post_load
+from marshmallow import EXCLUDE, Schema, ValidationError, fields, post_load
 
 CommandType = List[str]
 TaskInstanceKeyType = Tuple[Any]
 ExecutorConfigFunctionType = Callable[[CommandType], dict]
-EcsFargateQueuedTask = namedtuple('EcsFargateQueuedTask', ('key', 'command', 'executor_config'))
+EcsFargateQueuedTask = namedtuple('EcsFargateQueuedTask', ('key', 'command', 'queue', 'executor_config'))
 ExecutorConfigType = Dict[str, Any]
-EcsFargateTaskInfo = namedtuple('EcsFargateTaskInfo', ('cmd', 'config'))
+EcsFargateTaskInfo = namedtuple('EcsFargateTaskInfo', ('cmd', 'queue', 'config'))
 
 
 class EcsFargateTask:
@@ -147,17 +147,18 @@ def __describe_tasks(self, task_arns):
         for i in range((len(task_arns) // self.DESCRIBE_TASKS_BATCH_SIZE) + 1):
             batched_task_arns = task_arns[i * self.DESCRIBE_TASKS_BATCH_SIZE: (i + 1) * self.DESCRIBE_TASKS_BATCH_SIZE]
             boto_describe_tasks = self.ecs.describe_tasks(tasks=batched_task_arns, cluster=self.cluster)
-            describe_tasks_response = BotoDescribeTasksSchema().load(boto_describe_tasks)
-            if describe_tasks_response.errors:
+            try:
+                describe_tasks_response = BotoDescribeTasksSchema().load(boto_describe_tasks)
+            except ValidationError as err:
                 self.log.error('ECS DescribeTask Response: %s', boto_describe_tasks)
                 raise EcsFargateError(
                     'DescribeTasks API call does not match expected JSON shape. '
                     'Are you sure that the correct version of Boto3 is installed? {}'.format(
-                        describe_tasks_response.errors
+                        err
                     )
                 )
-            all_task_descriptions['tasks'].extend(describe_tasks_response.data['tasks'])
-            all_task_descriptions['failures'].extend(describe_tasks_response.data['failures'])
+            all_task_descriptions['tasks'].extend(describe_tasks_response['tasks'])
+            all_task_descriptions['failures'].extend(describe_tasks_response['failures'])
         return all_task_descriptions
 
     def __handle_failed_task(self, task_arn: str, reason: str):
@@ -166,14 +167,14 @@
         ECS/Fargate Cloud. If an API failure occurs the task is simply rescheduled.
         """
         task_key = self.active_workers.arn_to_key[task_arn]
-        task_cmd, exec_info = self.active_workers.info_by_key(task_key)
+        task_cmd, queue, exec_info = self.active_workers.info_by_key(task_key)
         failure_count = self.active_workers.failure_count_by_key(task_key)
         if failure_count < self.__class__.MAX_FAILURE_CHECKS:
             self.log.warning('Task %s has failed due to %s. '
                              'Failure %s out of %s occurred on %s. Rescheduling.',
                              task_key, reason, failure_count, self.__class__.MAX_FAILURE_CHECKS, task_arn)
             self.active_workers.increment_failure_count(task_key)
-            self.pending_tasks.appendleft(EcsFargateQueuedTask(task_key, task_cmd, exec_info))
+            self.pending_tasks.appendleft(EcsFargateQueuedTask(task_key, task_cmd, queue, exec_info))
         else:
             self.log.error('Task %s has failed a maximum of %s times. Marking as failed',
                            task_key, failure_count)
@@ -192,8 +193,8 @@ def attempt_task_runs(self):
         failure_reasons = defaultdict(int)
         for _ in range(queue_len):
             ecs_task = self.pending_tasks.popleft()
-            task_key, cmd, exec_config = ecs_task
-            run_task_response = self.__run_task(cmd, exec_config)
+            task_key, cmd, queue, exec_config = ecs_task
+            run_task_response = self._run_task(task_key, cmd, queue, exec_config)
             if run_task_response['failures']:
                 for f in run_task_response['failures']:
                     failure_reasons[f['reason']] += 1
@@ -203,39 +204,53 @@
                 raise EcsFargateError('No failures and no tasks provided in response. This should never happen.')
             else:
                 task = run_task_response['tasks'][0]
-                self.active_workers.add_task(task, task_key, cmd, exec_config)
+                self.active_workers.add_task(task, task_key, queue, cmd, exec_config)
         if failure_reasons:
             self.log.debug('Pending tasks failed to launch for the following reasons: %s. Will retry later.',
                            dict(failure_reasons))
 
-    def __run_task(self, cmd: CommandType, exec_config: ExecutorConfigType):
+    def _run_task(self, task_id: TaskInstanceKeyType, cmd: CommandType, queue: str, exec_config: ExecutorConfigType):
         """
+        This function is the actual attempt to run a queued-up airflow task. Not to be confused with
+        execute_async() which inserts tasks into the queue.
         The command and executor config will be placed in the container-override section of the JSON request, before
         calling Boto3's "run_task" function.
         """
-        run_task_api = deepcopy(self.run_task_kwargs)
-        container_override = self.get_container(run_task_api['overrides']['containerOverrides'])
-        container_override['command'] = cmd
-        container_override.update(exec_config)
+        run_task_api = self._run_task_kwargs(task_id, cmd, queue, exec_config)
         boto_run_task = self.ecs.run_task(**run_task_api)
-        run_task_response = BotoRunTaskSchema().load(boto_run_task)
-        if run_task_response.errors:
-            self.log.error('ECS RunTask Response: %s', run_task_response)
+        try:
+            run_task_response = BotoRunTaskSchema().load(boto_run_task)
+        except ValidationError as err:
+            self.log.error('ECS RunTask Response: %s', err)
             raise EcsFargateError(
                 'RunTask API call does not match expected JSON shape. '
                 'Are you sure that the correct version of Boto3 is installed? {}'.format(
-                    run_task_response.errors
+                    err
                 )
             )
-        return run_task_response.data
+        return run_task_response
+
+    def _run_task_kwargs(self, task_id: TaskInstanceKeyType, cmd: CommandType,
+                         queue: str, exec_config: ExecutorConfigType) -> dict:
+        """
+        This modifies the standard kwargs to be specific to this task by overriding the airflow command and updating
+        the container overrides.
+
+        One last chance to modify Boto3's "run_task" kwarg params before it gets passed into the Boto3 client.
+        """
+        run_task_api = deepcopy(self.run_task_kwargs)
+        container_override = self.get_container(run_task_api['overrides']['containerOverrides'])
+        container_override['command'] = cmd
+        container_override.update(exec_config)
+        return run_task_api
 
     def execute_async(self, key: TaskInstanceKeyType, command: CommandType, queue=None, executor_config=None):
         """
-        Save the task to be executed in the next sync using Boto3's RunTask API
+        Save the task to be executed in the next sync by inserting the commands into a queue.
         """
         if executor_config and ('name' in executor_config or 'command' in executor_config):
             raise ValueError('Executor Config should never override "name" or "command"')
-        self.pending_tasks.append(EcsFargateQueuedTask(key, command, executor_config or {}))
+        self.pending_tasks.append(EcsFargateQueuedTask(key, command, queue, executor_config or {}))
 
     def end(self, heartbeat_interval=10):
         """
@@ -298,14 +313,14 @@ def __init__(self):
         self.key_to_failure_counts: Dict[TaskInstanceKeyType, int] = defaultdict(int)
         self.key_to_task_info: Dict[TaskInstanceKeyType, EcsFargateTaskInfo] = {}
 
-    def add_task(self, task: EcsFargateTask, airflow_task_key: TaskInstanceKeyType, airflow_cmd: CommandType,
-                 exec_config: ExecutorConfigType):
+    def add_task(self, task: EcsFargateTask, airflow_task_key: TaskInstanceKeyType, queue: str,
+                 airflow_cmd: CommandType, exec_config: ExecutorConfigType):
         """Adds a task to the collection"""
         arn = task.task_arn
         self.tasks[arn] = task
         self.key_to_arn[airflow_task_key] = arn
         self.arn_to_key[arn] = airflow_task_key
-        self.key_to_task_info[airflow_task_key] = EcsFargateTaskInfo(airflow_cmd, exec_config)
+        self.key_to_task_info[airflow_task_key] = EcsFargateTaskInfo(airflow_cmd, queue, exec_config)
 
     def update_task(self, task: EcsFargateTask):
         """Updates the state of the given task based on task ARN"""
@@ -366,28 +381,34 @@ class BotoContainerSchema(Schema):
     Botocore Serialization Object for ECS 'Container' shape.
     Note that there are many more parameters, but the executor only needs the members listed below.
     """
-    exit_code = fields.Integer(load_from='exitCode')
-    last_status = fields.String(load_from='lastStatus')
+    exit_code = fields.Integer(data_key='exitCode')
+    last_status = fields.String(data_key='lastStatus')
     name = fields.String(required=True)
 
+    class Meta:
+        unknown = EXCLUDE
+
 
 class BotoTaskSchema(Schema):
     """
     Botocore Serialization Object for ECS 'Task' shape.
     Note that there are many more parameters, but the executor only needs the members listed below.
     """
-    task_arn = fields.String(load_from='taskArn', required=True)
-    last_status = fields.String(load_from='lastStatus', required=True)
-    desired_status = fields.String(load_from='desiredStatus', required=True)
+    task_arn = fields.String(data_key='taskArn', required=True)
+    last_status = fields.String(data_key='lastStatus', required=True)
+    desired_status = fields.String(data_key='desiredStatus', required=True)
     containers = fields.List(fields.Nested(BotoContainerSchema), required=True)
-    started_at = fields.Field(load_from='startedAt')
-    stopped_reason = fields.String(load_from='stoppedReason')
+    started_at = fields.Field(data_key='startedAt')
+    stopped_reason = fields.String(data_key='stoppedReason')
 
     @post_load
     def make_task(self, data, **kwargs):
-        """Overwrites marshmallow .data property to return an instance of EcsFargateTask instead of a dictionary"""
+        """Overwrites marshmallow load() to return an instance of EcsFargateTask instead of a dictionary"""
        return EcsFargateTask(**data)
 
+    class Meta:
+        unknown = EXCLUDE
+
 
 class BotoFailureSchema(Schema):
     """
@@ -396,6 +417,9 @@ class BotoFailureSchema(Schema):
     arn = fields.String()
     reason = fields.String()
 
+    class Meta:
+        unknown = EXCLUDE
+
 
 class BotoRunTaskSchema(Schema):
     """
@@ -404,6 +428,9 @@ class BotoRunTaskSchema(Schema):
     tasks = fields.List(fields.Nested(BotoTaskSchema), required=True)
     failures = fields.List(fields.Nested(BotoFailureSchema), required=True)
 
+    class Meta:
+        unknown = EXCLUDE
+
 
 class BotoDescribeTasksSchema(Schema):
     """
@@ -412,6 +439,9 @@ class BotoDescribeTasksSchema(Schema):
     tasks = fields.List(fields.Nested(BotoTaskSchema), required=True)
     failures = fields.List(fields.Nested(BotoFailureSchema), required=True)
 
+    class Meta:
+        unknown = EXCLUDE
+
 
 class EcsFargateError(Exception):
     """Thrown when something unexpected has occurred within the AWS ECS/Fargate ecosystem"""
diff --git a/setup.py b/setup.py
index a05f7fc..6ea6b9d 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,7 @@
 
 setup(
     name="airflow-aws-executors",
-    version="1.0.0",
+    version="1.1.0",
     description=description,
     long_description=long_description,
     long_description_content_type="text/markdown",
@@ -29,5 +29,5 @@
     ],
     packages=["airflow_aws_executors"],
     include_package_data=True,
-    install_requires=["boto3", "apache-airflow>=1.10.5"]
+    install_requires=["boto3", "apache-airflow>=1.10.5", "marshmallow>=3"]
 )
diff --git a/tests/test_batch_executor.py b/tests/test_batch_executor.py
index 63cbe37..78373ae 100644
--- a/tests/test_batch_executor.py
+++ b/tests/test_batch_executor.py
@@ -112,8 +112,7 @@ def test_sync(self, success_mock, fail_mock):
 
         # sanity check that container's status code is mocked to success
         loaded_batch_job = BatchJobDetailSchema().load(after_sync_reponse)
-        self.assertFalse(loaded_batch_job.errors, msg='Mocked message is not like defined schema')
-        self.assertEqual(State.SUCCESS, loaded_batch_job.data.get_job_state())
+        self.assertEqual(State.SUCCESS, loaded_batch_job.get_job_state())
 
         self.executor.sync()
 
@@ -135,7 +134,7 @@ def test_failed_sync(self, success_mock, fail_mock):
 
         # set container's status code to failure & sanity-check
         after_sync_reponse['status'] = 'FAILED'
-        self.assertEqual(State.FAILED, BatchJobDetailSchema().load(after_sync_reponse).data.get_job_state())
+        self.assertEqual(State.FAILED, BatchJobDetailSchema().load(after_sync_reponse).get_job_state())
 
         self.executor.sync()
         # ensure that run_task is called correctly as defined by Botocore docs
@@ -152,7 +151,7 @@ def test_terminate(self):
         """Test that executor can shut everything down; forcing all tasks to unnaturally exit"""
         mocked_job_json = self.__mock_sync()
         mocked_job_json['status'] = 'FAILED'
-        self.assertEqual(State.FAILED, BatchJobDetailSchema().load(mocked_job_json).data.get_job_state())
+        self.assertEqual(State.FAILED, BatchJobDetailSchema().load(mocked_job_json).get_job_state())
 
         self.executor.terminate()
 
@@ -190,13 +189,6 @@ def setUp(self) -> None:
 
     def __set_mocked_executor(self):
         """Mock ECS such that there's nothing wrong with anything"""
         from airflow.configuration import conf
-
-        if not conf.has_section('batch'):
-            conf.add_section('batch')
-        conf.set('batch', 'region', 'us-west-1')
-        conf.set('batch', 'job_name', 'some-job-name')
-        conf.set('batch', 'job_queue', 'some-job-queue')
-        conf.set('batch', 'job_definition', 'some-job-def')
         executor = AwsBatchExecutor()
         executor.start()
diff --git a/tests/test_ecs_fargate_executor.py b/tests/test_ecs_fargate_executor.py
index 2d6ed32..e79d5f8 100644
--- a/tests/test_ecs_fargate_executor.py
+++ b/tests/test_ecs_fargate_executor.py
@@ -20,6 +20,7 @@ def test_get_and_add(self):
         self.assertEqual(self.collection['001'], self.first_task)
         self.assertEqual(self.collection.task_by_key(self.first_airflow_key), self.first_task)
         self.assertEqual(self.collection.info_by_key(self.first_airflow_key).cmd, self.first_airflow_cmd)
+        self.assertEqual(self.collection.info_by_key(self.first_airflow_key).queue, self.first_airflow_queue)
         self.assertEqual(self.collection.info_by_key(self.first_airflow_key).config, self.first_airflow_exec_config)
 
         # Check basic get for second task
@@ -27,6 +28,7 @@
         self.assertEqual(self.collection['002'], self.second_task)
         self.assertEqual(self.collection.task_by_key(self.second_airflow_key), self.second_task)
         self.assertEqual(self.collection.info_by_key(self.second_airflow_key).cmd, self.second_airflow_cmd)
+        self.assertEqual(self.collection.info_by_key(self.second_airflow_key).queue, self.second_airflow_queue)
         self.assertEqual(self.collection.info_by_key(self.second_airflow_key).config, self.second_airflow_exec_config)
 
     def test_list(self):
@@ -72,9 +74,10 @@ def setUp(self):
         self.first_task.task_arn = '001'
         self.first_airflow_key = mock.Mock(spec=tuple)
         self.first_airflow_cmd = mock.Mock(spec=list)
+        self.first_airflow_queue = mock.Mock(spec=str)
         self.first_airflow_exec_config = mock.Mock(spec=dict)
         self.collection.add_task(
-            self.first_task, self.first_airflow_key,
+            self.first_task, self.first_airflow_key, self.first_airflow_queue,
             self.first_airflow_cmd, self.first_airflow_exec_config
         )
         # Add second task
@@ -82,9 +85,10 @@
         self.second_task.task_arn = '002'
         self.second_airflow_key = mock.Mock(spec=tuple)
         self.second_airflow_cmd = mock.Mock(spec=list)
+        self.second_airflow_queue = mock.Mock(spec=str)
         self.second_airflow_exec_config = mock.Mock(spec=dict)
         self.collection.add_task(
-            self.second_task, self.second_airflow_key,
+            self.second_task, self.second_airflow_key, self.second_airflow_queue,
             self.second_airflow_cmd, self.second_airflow_exec_config
         )
 
@@ -207,8 +211,7 @@ def test_sync(self, success_mock, fail_mock):
         """Test synch from end-to-end"""
         after_fargate_json = self.__mock_sync()
         loaded_fargate_json = BotoTaskSchema().load(after_fargate_json)
-        self.assertFalse(loaded_fargate_json.errors, msg='Mocked message is not like defined schema')
-        self.assertEqual(State.SUCCESS, loaded_fargate_json.data.get_task_state())
+        self.assertEqual(State.SUCCESS, loaded_fargate_json.get_task_state())
 
         self.executor.sync_running_tasks()
 
@@ -230,7 +233,7 @@ def test_failed_sync(self, success_mock, fail_mock):
 
         # set container's exit code to failure
         after_fargate_json['containers'][0]['exitCode'] = 100
-        self.assertEqual(State.FAILED, BotoTaskSchema().load(after_fargate_json).data.get_task_state())
+        self.assertEqual(State.FAILED, BotoTaskSchema().load(after_fargate_json).get_task_state())
 
         self.executor.sync()
         # ensure that run_task is called correctly as defined by Botocore docs
@@ -281,7 +284,7 @@ def test_terminate(self):
         """Test that executor can shut everything down; forcing all tasks to unnaturally exit"""
         after_fargate_task = self.__mock_sync()
         after_fargate_task['containers'][0]['exitCode'] = 100
-        self.assertEqual(State.FAILED, BotoTaskSchema().load(after_fargate_task).data.get_task_state())
+        self.assertEqual(State.FAILED, BotoTaskSchema().load(after_fargate_task).get_task_state())
 
         self.executor.terminate()
 
@@ -316,15 +319,6 @@ def setUp(self) -> None:
 
     def __set_mocked_executor(self):
         """Mock ECS such that there's nothing wrong with anything"""
-        from airflow.configuration import conf
-
-        if not conf.has_section('ecs_fargate'):
-            conf.add_section('ecs_fargate')
-        conf.set('ecs_fargate', 'region', 'us-west-1')
-        conf.set('ecs_fargate', 'cluster', 'some-ecs-cluster')
-        conf.set('ecs_fargate', 'task_definition', 'some-ecs-task-definition')
-        conf.set('ecs_fargate', 'container_name', 'some-ecs-container')
-        conf.set('ecs_fargate', 'launch_type', 'FARGATE')
 
         executor = AwsEcsFargateExecutor()
         executor.start()
@@ -350,8 +344,10 @@ def __mock_sync(self):
 
         airflow_cmd = mock.Mock(spec=list)
         airflow_key = mock.Mock(spec=tuple)
+        airflow_queue = mock.Mock(spec=str)
         airflow_exec_conf = mock.Mock(spec=dict)
-        self.executor.active_workers.add_task(before_fargate_task, airflow_key, airflow_cmd, airflow_exec_conf)
+        self.executor.active_workers.add_task(before_fargate_task, airflow_key, airflow_queue,
+                                              airflow_cmd, airflow_exec_conf)
 
         after_task_json = {
             'taskArn': 'ABC',
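
Note on the marshmallow upgrade driving this patch: marshmallow 3 renames the field kwarg load_from= to data_key=, raises on unknown keys unless a schema opts out, and makes load() return the deserialized data directly (raising ValidationError on failure) instead of a (data, errors) result tuple. A minimal sketch of the new pattern, with an invented schema and payload for illustration (JobSchema is not from this repo):

# Illustrative only -- this schema and payload are made up, not part of the diff.
from marshmallow import EXCLUDE, Schema, ValidationError, fields

class JobSchema(Schema):
    # marshmallow 3: data_key= replaces the marshmallow 2 load_from= kwarg
    job_id = fields.String(data_key='jobId', required=True)

    class Meta:
        # marshmallow 3 raises on unknown keys by default; EXCLUDE restores the
        # permissive marshmallow 2 behaviour, which matters here because Boto3
        # responses carry many fields these schemas never declare.
        unknown = EXCLUDE

try:
    # marshmallow 3: load() returns the data itself, not a (data, errors) tuple
    result = JobSchema().load({'jobId': 'abc-123', 'someUndeclaredField': 42})
    print(result['job_id'])  # -> 'abc-123'
except ValidationError as err:
    # marshmallow 3: failures surface as an exception instead of result.errors
    print(err.messages)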
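
The newly factored-out _run_task_kwargs() also doubles as an extension point: a subclass can adjust the Boto3 run_task request per task, per queue, before it is sent. A hypothetical example (the subclass and the 'spot' queue name are invented, not part of this patch) that routes one Airflow queue onto Fargate Spot:

# Hypothetical subclass for illustration; not included in this diff.
class SpotAwareEcsFargateExecutor(AwsEcsFargateExecutor):
    def _run_task_kwargs(self, task_id, cmd, queue, exec_config) -> dict:
        run_task_api = super()._run_task_kwargs(task_id, cmd, queue, exec_config)
        if queue == 'spot':
            # launchType and capacityProviderStrategy are mutually exclusive in
            # the ECS RunTask API, so drop the configured launch type first.
            run_task_api.pop('launchType', None)
            run_task_api['capacityProviderStrategy'] = [{'capacityProvider': 'FARGATE_SPOT'}]
        return run_task_api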