From 8eb7201d20d28d74e892cd7489366e9db0c31855 Mon Sep 17 00:00:00 2001 From: Rich Scott Date: Mon, 18 Sep 2023 10:24:31 -0600 Subject: [PATCH] WIP: Airflow: fix undefined poll_interval in Deferrable Operator (#2975) * Airflow: handle poll_interval attr in ArmadaJobCompleteTrigger Fix incomplete handling of 'poll_interval' attribute in ArmadaJobCompleteTrigger, used by the Armada Deferrable Operator for Airflow. Signed-off-by: Rich Scott * Airflow - add unit test for armada deferrable operator Run much of the same tests for the deferrable operator as for the regular operator, plus test serialization. Also, update interval signifier in examples. A full test of the deferrable operator that verifies the trigger handling is still needed. Signed-off-by: Rich Scott --------- Signed-off-by: Rich Scott --- docs/python_airflow_operator.md | 28 ++- .../armada/operators/armada_deferrable.py | 39 +++- third_party/airflow/armada/operators/utils.py | 4 +- third_party/airflow/examples/big_armada.py | 2 +- .../tests/unit/test_airflow_operator_mock.py | 4 +- .../unit/test_armada_deferrable_operator.py | 171 ++++++++++++++++++ .../test_search_for_job_complete_asyncio.py | 5 + 7 files changed, 245 insertions(+), 8 deletions(-) create mode 100644 third_party/airflow/tests/unit/test_armada_deferrable_operator.py diff --git a/docs/python_airflow_operator.md b/docs/python_airflow_operator.md index c74a464751d..048667a2562 100644 --- a/docs/python_airflow_operator.md +++ b/docs/python_airflow_operator.md @@ -239,9 +239,27 @@ Reports the result of the job and returns. +#### serialize() +Get a serialized version of this object. + + +* **Returns** + + A dict of keyword arguments used when instantiating + + + +* **Return type** + + dict + + +this object. + + #### template_fields(_: Sequence[str_ _ = ('job_request_items',_ ) -### _class_ armada.operators.armada_deferrable.ArmadaJobCompleteTrigger(job_id, job_service_channel_args, armada_queue, job_set_id, airflow_task_name) +### _class_ armada.operators.armada_deferrable.ArmadaJobCompleteTrigger(job_id, job_service_channel_args, armada_queue, job_set_id, airflow_task_name, poll_interval=30) Bases: `BaseTrigger` An airflow trigger that monitors the job state of an armada job. @@ -269,6 +287,9 @@ Triggers when the job is complete. belongs. + * **poll_interval** (*int*) – How often to poll jobservice to get status. + + * **Returns** @@ -664,7 +685,7 @@ A terminated event is SUCCEEDED, FAILED or CANCELLED -### _async_ armada.operators.utils.search_for_job_complete_async(armada_queue, job_set_id, airflow_task_name, job_id, job_service_client, log, time_out_for_failure=7200) +### _async_ armada.operators.utils.search_for_job_complete_async(armada_queue, job_set_id, airflow_task_name, job_id, job_service_client, log, poll_interval, time_out_for_failure=7200) Poll JobService cache asyncronously until you get a terminated event. A terminated event is SUCCEEDED, FAILED or CANCELLED @@ -689,6 +710,9 @@ A terminated event is SUCCEEDED, FAILED or CANCELLED It is optional only for testing + * **poll_interval** (*int*) – How often to poll jobservice to get status. + + * **time_out_for_failure** (*int*) – The amount of time a job can be in job_id_not_found before we decide it was a invalid job diff --git a/third_party/airflow/armada/operators/armada_deferrable.py b/third_party/airflow/armada/operators/armada_deferrable.py index 2f53a702228..f7aa1413637 100644 --- a/third_party/airflow/armada/operators/armada_deferrable.py +++ b/third_party/airflow/armada/operators/armada_deferrable.py @@ -103,6 +103,25 @@ def __init__( self.lookout_url_template = lookout_url_template self.poll_interval = poll_interval + def serialize(self) -> dict: + """ + Get a serialized version of this object. + + :return: A dict of keyword arguments used when instantiating + this object. + """ + + return { + "task_id": self.task_id, + "name": self.name, + "armada_channel_args": self.armada_channel_args.serialize(), + "job_service_channel_args": self.job_service_channel_args.serialize(), + "armada_queue": self.armada_queue, + "job_request_items": self.job_request_items, + "lookout_url_template": self.lookout_url_template, + "poll_interval": self.poll_interval, + } + def execute(self, context) -> None: """ Executes the Armada Operator. Only meant to be called by airflow. @@ -156,6 +175,7 @@ def execute(self, context) -> None: armada_queue=self.armada_queue, job_set_id=context["run_id"], airflow_task_name=self.name, + poll_interval=self.poll_interval, ), method_name="resume_job_complete", kwargs={"job_id": job_id}, @@ -216,6 +236,7 @@ class ArmadaJobCompleteTrigger(BaseTrigger): :param job_set_id: The ID of the job set. :param airflow_task_name: Name of the airflow task to which this trigger belongs. + :param poll_interval: How often to poll jobservice to get status. :return: An armada job complete trigger instance. """ @@ -226,6 +247,7 @@ def __init__( armada_queue: str, job_set_id: str, airflow_task_name: str, + poll_interval: int = 30, ) -> None: super().__init__() self.job_id = job_id @@ -233,6 +255,7 @@ def __init__( self.armada_queue = armada_queue self.job_set_id = job_set_id self.airflow_task_name = airflow_task_name + self.poll_interval = poll_interval def serialize(self) -> tuple: return ( @@ -243,9 +266,21 @@ def serialize(self) -> tuple: "armada_queue": self.armada_queue, "job_set_id": self.job_set_id, "airflow_task_name": self.airflow_task_name, + "poll_interval": self.poll_interval, }, ) + def __eq__(self, o): + return ( + self.task_id == o.task_id + and self.job_id == o.job_id + and self.job_service_channel_args == o.job_service_channel_args + and self.armada_queue == o.armada_queue + and self.job_set_id == o.job_set_id + and self.airflow_task_name == o.airflow_task_name + and self.poll_interval == o.poll_interval + ) + async def run(self): """ Runs the trigger. Meant to be called by an airflow triggerer process. @@ -255,12 +290,12 @@ async def run(self): ) job_state, job_message = await search_for_job_complete_async( - job_service_client=job_service_client, armada_queue=self.armada_queue, job_set_id=self.job_set_id, airflow_task_name=self.airflow_task_name, job_id=self.job_id, - poll_interval=self.poll_interval, + job_service_client=job_service_client, log=self.log, + poll_interval=self.poll_interval, ) yield TriggerEvent({"job_state": job_state, "job_message": job_message}) diff --git a/third_party/airflow/armada/operators/utils.py b/third_party/airflow/armada/operators/utils.py index e3c68beb321..1ab7fa35d04 100644 --- a/third_party/airflow/armada/operators/utils.py +++ b/third_party/airflow/armada/operators/utils.py @@ -217,6 +217,7 @@ async def search_for_job_complete_async( job_id: str, job_service_client: JobServiceAsyncIOClient, log, + poll_interval: int, time_out_for_failure: int = 7200, ) -> Tuple[JobState, str]: """ @@ -231,6 +232,7 @@ async def search_for_job_complete_async( :param job_id: The name of the job id that armada assigns to it :param job_service_client: A JobServiceClient that is used for polling. It is optional only for testing + :param poll_interval: How often to poll jobservice to get status. :param time_out_for_failure: The amount of time a job can be in job_id_not_found before we decide it was a invalid job @@ -251,7 +253,7 @@ async def search_for_job_complete_async( job_state = job_state_from_pb(job_status_return.state) log.debug(f"Got job state '{job_state.name}' for job {job_id}") - await asyncio.sleep(3) + await asyncio.sleep(poll_interval) if job_state == JobState.SUCCEEDED: job_message = f"Armada {airflow_task_name}:{job_id} succeeded" diff --git a/third_party/airflow/examples/big_armada.py b/third_party/airflow/examples/big_armada.py index f1196307227..dc64cdc76b2 100644 --- a/third_party/airflow/examples/big_armada.py +++ b/third_party/airflow/examples/big_armada.py @@ -57,7 +57,7 @@ def submit_sleep_job(): with DAG( dag_id="big_armada", start_date=pendulum.datetime(2016, 1, 1, tz="UTC"), - schedule_interval="@daily", + schedule="@daily", catchup=False, default_args={"retries": 2}, ) as dag: diff --git a/third_party/airflow/tests/unit/test_airflow_operator_mock.py b/third_party/airflow/tests/unit/test_airflow_operator_mock.py index 4634e644795..1ab2d37ced1 100644 --- a/third_party/airflow/tests/unit/test_airflow_operator_mock.py +++ b/third_party/airflow/tests/unit/test_airflow_operator_mock.py @@ -170,7 +170,7 @@ def test_annotate_job_request_items(): dag = DAG( dag_id="hello_armada", start_date=pendulum.datetime(2016, 1, 1, tz="UTC"), - schedule_interval="@daily", + schedule="@daily", catchup=False, default_args={"retries": 2}, ) @@ -204,7 +204,7 @@ def test_parameterize_armada_operator(): dag = DAG( dag_id="hello_armada", start_date=pendulum.datetime(2016, 1, 1, tz="UTC"), - schedule_interval="@daily", + schedule="@daily", catchup=False, default_args={"retries": 2}, ) diff --git a/third_party/airflow/tests/unit/test_armada_deferrable_operator.py b/third_party/airflow/tests/unit/test_armada_deferrable_operator.py new file mode 100644 index 00000000000..0f156ed177e --- /dev/null +++ b/third_party/airflow/tests/unit/test_armada_deferrable_operator.py @@ -0,0 +1,171 @@ +import copy + +import pytest + +from armada_client.armada import submit_pb2 +from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 +from armada_client.k8s.io.apimachinery.pkg.api.resource import ( + generated_pb2 as api_resource, +) +from armada.operators.armada_deferrable import ArmadaDeferrableOperator +from armada.operators.grpc import CredentialsCallback + + +def test_serialize_armada_deferrable(): + grpc_chan_args = { + "target": "localhost:443", + "credentials_callback_args": { + "module_name": "channel_test", + "function_name": "get_credentials", + "function_kwargs": { + "example_arg": "test", + }, + }, + } + + pod = core_v1.PodSpec( + containers=[ + core_v1.Container( + name="sleep", + image="busybox", + args=["sleep", "10s"], + securityContext=core_v1.SecurityContext(runAsUser=1000), + resources=core_v1.ResourceRequirements( + requests={ + "cpu": api_resource.Quantity(string="120m"), + "memory": api_resource.Quantity(string="510Mi"), + }, + limits={ + "cpu": api_resource.Quantity(string="120m"), + "memory": api_resource.Quantity(string="510Mi"), + }, + ), + ) + ], + ) + + job_requests = [ + submit_pb2.JobSubmitRequestItem( + priority=1, + pod_spec=pod, + namespace="personal-anonymous", + annotations={"armadaproject.io/hello": "world"}, + ) + ] + + source = ArmadaDeferrableOperator( + task_id="test_task_id", + name="test task", + armada_channel_args=grpc_chan_args, + job_service_channel_args=grpc_chan_args, + armada_queue="test-queue", + job_request_items=job_requests, + lookout_url_template="https://lookout.test.domain/", + poll_interval=5, + ) + + serialized = source.serialize() + assert serialized["name"] == source.name + + reconstituted = ArmadaDeferrableOperator(**serialized) + assert reconstituted == source + + +get_lookout_url_test_cases = [ + ( + "http://localhost:8089/jobs?job_id=", + "test_id", + "http://localhost:8089/jobs?job_id=test_id", + ), + ( + "https://lookout.armada.domain/jobs?job_id=", + "test_id", + "https://lookout.armada.domain/jobs?job_id=test_id", + ), + ("", "test_id", ""), + (None, "test_id", ""), +] + + +@pytest.mark.parametrize( + "lookout_url_template, job_id, expected_url", get_lookout_url_test_cases +) +def test_get_lookout_url(lookout_url_template, job_id, expected_url): + armada_channel_args = {"target": "127.0.0.1:50051"} + job_service_channel_args = {"target": "127.0.0.1:60003"} + + operator = ArmadaDeferrableOperator( + task_id="test_task_id", + name="test_task", + armada_channel_args=armada_channel_args, + job_service_channel_args=job_service_channel_args, + armada_queue="test_queue", + job_request_items=[], + lookout_url_template=lookout_url_template, + ) + + assert operator._get_lookout_url(job_id) == expected_url + + +def test_deepcopy_operator(): + armada_channel_args = {"target": "127.0.0.1:50051"} + job_service_channel_args = {"target": "127.0.0.1:60003"} + + operator = ArmadaDeferrableOperator( + task_id="test_task_id", + name="test_task", + armada_channel_args=armada_channel_args, + job_service_channel_args=job_service_channel_args, + armada_queue="test_queue", + job_request_items=[], + lookout_url_template="http://localhost:8089/jobs?job_id=", + ) + + try: + copy.deepcopy(operator) + except Exception as e: + assert False, f"{e}" + + +def test_deepcopy_operator_with_grpc_credentials_callback(): + armada_channel_args = { + "target": "127.0.0.1:50051", + "credentials_callback_args": { + "module_name": "tests.unit.test_armada_operator", + "function_name": "__example_test_callback", + "function_kwargs": { + "test_arg": "fake_arg", + }, + }, + } + job_service_channel_args = {"target": "127.0.0.1:60003"} + + operator = ArmadaDeferrableOperator( + task_id="test_task_id", + name="test_task", + armada_channel_args=armada_channel_args, + job_service_channel_args=job_service_channel_args, + armada_queue="test_queue", + job_request_items=[], + lookout_url_template="http://localhost:8089/jobs?job_id=", + ) + + try: + copy.deepcopy(operator) + except Exception as e: + assert False, f"{e}" + + +def __example_test_callback(foo=None): + return f"fake_cred {foo}" + + +def test_credentials_callback(): + callback = CredentialsCallback( + module_name="test_armada_operator", + function_name="__example_test_callback", + function_kwargs={"foo": "bar"}, + ) + + result = callback.call() + assert result == "fake_cred bar" diff --git a/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py b/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py index 83cc3e220aa..a842fa994d3 100644 --- a/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py +++ b/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py @@ -71,6 +71,7 @@ async def test_failed_event(js_aio_client): job_service_client=js_aio_client, time_out_for_failure=5, log=logging.getLogger(), + poll_interval=1, ) assert job_complete[0] == JobState.FAILED assert ( @@ -89,6 +90,7 @@ async def test_successful_event(js_aio_client): job_service_client=js_aio_client, time_out_for_failure=5, log=logging.getLogger(), + poll_interval=1, ) assert job_complete[0] == JobState.SUCCEEDED assert job_complete[1] == "Armada test:test_succeeded succeeded" @@ -104,6 +106,7 @@ async def test_cancelled_event(js_aio_client): job_service_client=js_aio_client, time_out_for_failure=5, log=logging.getLogger(), + poll_interval=1, ) assert job_complete[0] == JobState.CANCELLED assert job_complete[1] == "Armada test:test_cancelled cancelled" @@ -119,6 +122,7 @@ async def test_job_id_not_found(js_aio_client): time_out_for_failure=5, job_service_client=js_aio_client, log=logging.getLogger(), + poll_interval=1, ) assert job_complete[0] == JobState.JOB_ID_NOT_FOUND assert ( @@ -142,6 +146,7 @@ async def test_error_retry(js_aio_retry_client): job_service_client=js_aio_retry_client, time_out_for_failure=5, log=logging.getLogger(), + poll_interval=1, ) assert job_complete[0] == JobState.SUCCEEDED assert job_complete[1] == "Armada test:test_succeeded succeeded"