diff --git a/docs/python_airflow_operator.md b/docs/python_airflow_operator.md
index c74a464751d..048667a2562 100644
--- a/docs/python_airflow_operator.md
+++ b/docs/python_airflow_operator.md
@@ -239,9 +239,27 @@ Reports the result of the job and returns.
 
 
 
+#### serialize()
+Get a serialized version of this object.
+
+
+* **Returns**
+
+    A dict of keyword arguments used when instantiating
+    this object.
+
+
+
+* **Return type**
+
+    dict
+
+
+
+
 #### template_fields(_: Sequence[str_ _ = ('job_request_items',_ )
 
-### _class_ armada.operators.armada_deferrable.ArmadaJobCompleteTrigger(job_id, job_service_channel_args, armada_queue, job_set_id, airflow_task_name)
+### _class_ armada.operators.armada_deferrable.ArmadaJobCompleteTrigger(job_id, job_service_channel_args, armada_queue, job_set_id, airflow_task_name, poll_interval=30)
 Bases: `BaseTrigger`
 
 An airflow trigger that monitors the job state of an armada job.
@@ -269,6 +287,9 @@ Triggers when the job is complete.
     belongs.
 
 
+    * **poll_interval** (*int*) – How often to poll jobservice to get status.
+
+
 
 * **Returns**
 
@@ -664,7 +685,7 @@ A terminated event is SUCCEEDED, FAILED or CANCELLED
 
 
 
-### _async_ armada.operators.utils.search_for_job_complete_async(armada_queue, job_set_id, airflow_task_name, job_id, job_service_client, log, time_out_for_failure=7200)
+### _async_ armada.operators.utils.search_for_job_complete_async(armada_queue, job_set_id, airflow_task_name, job_id, job_service_client, log, poll_interval, time_out_for_failure=7200)
 Poll JobService cache asyncronously until you get a terminated event.
 
 A terminated event is SUCCEEDED, FAILED or CANCELLED
@@ -689,6 +710,9 @@ A terminated event is SUCCEEDED, FAILED or CANCELLED
     It is optional only for testing
 
 
+    * **poll_interval** (*int*) – How often to poll jobservice to get status.
+
+
     * **time_out_for_failure** (*int*) – The amount of time a job can be in
     job_id_not_found before we decide it was a invalid job
 
diff --git a/third_party/airflow/armada/operators/armada_deferrable.py b/third_party/airflow/armada/operators/armada_deferrable.py
index 2f53a702228..f7aa1413637 100644
--- a/third_party/airflow/armada/operators/armada_deferrable.py
+++ b/third_party/airflow/armada/operators/armada_deferrable.py
@@ -103,6 +103,25 @@ def __init__(
         self.lookout_url_template = lookout_url_template
         self.poll_interval = poll_interval
 
+    def serialize(self) -> dict:
+        """
+        Get a serialized version of this object.
+
+        :return: A dict of keyword arguments used when instantiating
+            this object.
+        """
+
+        return {
+            "task_id": self.task_id,
+            "name": self.name,
+            "armada_channel_args": self.armada_channel_args.serialize(),
+            "job_service_channel_args": self.job_service_channel_args.serialize(),
+            "armada_queue": self.armada_queue,
+            "job_request_items": self.job_request_items,
+            "lookout_url_template": self.lookout_url_template,
+            "poll_interval": self.poll_interval,
+        }
+
     def execute(self, context) -> None:
         """
         Executes the Armada Operator. Only meant to be called by airflow.
@@ -156,6 +175,7 @@ def execute(self, context) -> None:
                 armada_queue=self.armada_queue,
                 job_set_id=context["run_id"],
                 airflow_task_name=self.name,
+                poll_interval=self.poll_interval,
             ),
             method_name="resume_job_complete",
             kwargs={"job_id": job_id},
@@ -216,6 +236,7 @@ class ArmadaJobCompleteTrigger(BaseTrigger):
     :param job_set_id: The ID of the job set.
     :param airflow_task_name: Name of the airflow task to which this trigger
         belongs.
+    :param poll_interval: How often to poll jobservice to get status.
     :return: An armada job complete trigger instance.
     """
 
@@ -226,6 +247,7 @@ def __init__(
         armada_queue: str,
         job_set_id: str,
         airflow_task_name: str,
+        poll_interval: int = 30,
     ) -> None:
         super().__init__()
         self.job_id = job_id
@@ -233,6 +255,7 @@ def __init__(
         self.armada_queue = armada_queue
         self.job_set_id = job_set_id
         self.airflow_task_name = airflow_task_name
+        self.poll_interval = poll_interval
 
     def serialize(self) -> tuple:
         return (
@@ -243,9 +266,21 @@ def serialize(self) -> tuple:
                 "armada_queue": self.armada_queue,
                 "job_set_id": self.job_set_id,
                 "airflow_task_name": self.airflow_task_name,
+                "poll_interval": self.poll_interval,
             },
         )
 
+    def __eq__(self, o):
+        return (
+            isinstance(o, ArmadaJobCompleteTrigger)
+            and self.job_id == o.job_id
+            and self.job_service_channel_args == o.job_service_channel_args
+            and self.armada_queue == o.armada_queue
+            and self.job_set_id == o.job_set_id
+            and self.airflow_task_name == o.airflow_task_name
+            and self.poll_interval == o.poll_interval
+        )
+
     async def run(self):
         """
         Runs the trigger. Meant to be called by an airflow triggerer process.
@@ -255,12 +290,12 @@ async def run(self):
         )
 
         job_state, job_message = await search_for_job_complete_async(
-            job_service_client=job_service_client,
             armada_queue=self.armada_queue,
             job_set_id=self.job_set_id,
             airflow_task_name=self.airflow_task_name,
             job_id=self.job_id,
-            poll_interval=self.poll_interval,
+            job_service_client=job_service_client,
             log=self.log,
+            poll_interval=self.poll_interval,
         )
         yield TriggerEvent({"job_state": job_state, "job_message": job_message})
diff --git a/third_party/airflow/armada/operators/utils.py b/third_party/airflow/armada/operators/utils.py
index e3c68beb321..1ab7fa35d04 100644
--- a/third_party/airflow/armada/operators/utils.py
+++ b/third_party/airflow/armada/operators/utils.py
@@ -217,6 +217,7 @@ async def search_for_job_complete_async(
     job_id: str,
     job_service_client: JobServiceAsyncIOClient,
     log,
+    poll_interval: int,
     time_out_for_failure: int = 7200,
 ) -> Tuple[JobState, str]:
     """
@@ -231,6 +232,7 @@ async def search_for_job_complete_async(
     :param job_id: The name of the job id that armada assigns to it
     :param job_service_client: A JobServiceClient that is used for polling.
         It is optional only for testing
+    :param poll_interval: How often to poll jobservice to get status.
     :param time_out_for_failure: The amount of time a job can be in
         job_id_not_found before we decide it was a invalid job
 
@@ -251,7 +253,7 @@ async def search_for_job_complete_async(
         job_state = job_state_from_pb(job_status_return.state)
         log.debug(f"Got job state '{job_state.name}' for job {job_id}")
 
-        await asyncio.sleep(3)
+        await asyncio.sleep(poll_interval)
 
         if job_state == JobState.SUCCEEDED:
             job_message = f"Armada {airflow_task_name}:{job_id} succeeded"
diff --git a/third_party/airflow/examples/big_armada.py b/third_party/airflow/examples/big_armada.py
index f1196307227..dc64cdc76b2 100644
--- a/third_party/airflow/examples/big_armada.py
+++ b/third_party/airflow/examples/big_armada.py
@@ -57,7 +57,7 @@ def submit_sleep_job():
 with DAG(
     dag_id="big_armada",
     start_date=pendulum.datetime(2016, 1, 1, tz="UTC"),
-    schedule_interval="@daily",
+    schedule="@daily",
     catchup=False,
     default_args={"retries": 2},
 ) as dag:
diff --git a/third_party/airflow/tests/unit/test_airflow_operator_mock.py b/third_party/airflow/tests/unit/test_airflow_operator_mock.py
index 4634e644795..1ab2d37ced1 100644
--- a/third_party/airflow/tests/unit/test_airflow_operator_mock.py
+++ b/third_party/airflow/tests/unit/test_airflow_operator_mock.py
@@ -170,7 +170,7 @@ def test_annotate_job_request_items():
     dag = DAG(
         dag_id="hello_armada",
         start_date=pendulum.datetime(2016, 1, 1, tz="UTC"),
-        schedule_interval="@daily",
+        schedule="@daily",
         catchup=False,
         default_args={"retries": 2},
     )
@@ -204,7 +204,7 @@ def test_parameterize_armada_operator():
     dag = DAG(
         dag_id="hello_armada",
         start_date=pendulum.datetime(2016, 1, 1, tz="UTC"),
-        schedule_interval="@daily",
+        schedule="@daily",
         catchup=False,
         default_args={"retries": 2},
     )
diff --git a/third_party/airflow/tests/unit/test_armada_deferrable_operator.py b/third_party/airflow/tests/unit/test_armada_deferrable_operator.py
new file mode 100644
index 00000000000..0f156ed177e
--- /dev/null
+++ b/third_party/airflow/tests/unit/test_armada_deferrable_operator.py
@@ -0,0 +1,171 @@
+import copy
+
+import pytest
+
+from armada_client.armada import submit_pb2
+from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1
+from armada_client.k8s.io.apimachinery.pkg.api.resource import (
+    generated_pb2 as api_resource,
+)
+from armada.operators.armada_deferrable import ArmadaDeferrableOperator
+from armada.operators.grpc import CredentialsCallback
+
+
+def test_serialize_armada_deferrable():
+    grpc_chan_args = {
+        "target": "localhost:443",
+        "credentials_callback_args": {
+            "module_name": "channel_test",
+            "function_name": "get_credentials",
+            "function_kwargs": {
+                "example_arg": "test",
+            },
+        },
+    }
+
+    pod = core_v1.PodSpec(
+        containers=[
+            core_v1.Container(
+                name="sleep",
+                image="busybox",
+                args=["sleep", "10s"],
+                securityContext=core_v1.SecurityContext(runAsUser=1000),
+                resources=core_v1.ResourceRequirements(
+                    requests={
+                        "cpu": api_resource.Quantity(string="120m"),
+                        "memory": api_resource.Quantity(string="510Mi"),
+                    },
+                    limits={
+                        "cpu": api_resource.Quantity(string="120m"),
+                        "memory": api_resource.Quantity(string="510Mi"),
+                    },
+                ),
+            )
+        ],
+    )
+
+    job_requests = [
+        submit_pb2.JobSubmitRequestItem(
+            priority=1,
+            pod_spec=pod,
+            namespace="personal-anonymous",
+            annotations={"armadaproject.io/hello": "world"},
+        )
+    ]
+
+    source = ArmadaDeferrableOperator(
+        task_id="test_task_id",
+        name="test task",
+        armada_channel_args=grpc_chan_args,
+        job_service_channel_args=grpc_chan_args,
+        armada_queue="test-queue",
+        job_request_items=job_requests,
lookout_url_template="https://lookout.test.domain/", + poll_interval=5, + ) + + serialized = source.serialize() + assert serialized["name"] == source.name + + reconstituted = ArmadaDeferrableOperator(**serialized) + assert reconstituted == source + + +get_lookout_url_test_cases = [ + ( + "http://localhost:8089/jobs?job_id=", + "test_id", + "http://localhost:8089/jobs?job_id=test_id", + ), + ( + "https://lookout.armada.domain/jobs?job_id=", + "test_id", + "https://lookout.armada.domain/jobs?job_id=test_id", + ), + ("", "test_id", ""), + (None, "test_id", ""), +] + + +@pytest.mark.parametrize( + "lookout_url_template, job_id, expected_url", get_lookout_url_test_cases +) +def test_get_lookout_url(lookout_url_template, job_id, expected_url): + armada_channel_args = {"target": "127.0.0.1:50051"} + job_service_channel_args = {"target": "127.0.0.1:60003"} + + operator = ArmadaDeferrableOperator( + task_id="test_task_id", + name="test_task", + armada_channel_args=armada_channel_args, + job_service_channel_args=job_service_channel_args, + armada_queue="test_queue", + job_request_items=[], + lookout_url_template=lookout_url_template, + ) + + assert operator._get_lookout_url(job_id) == expected_url + + +def test_deepcopy_operator(): + armada_channel_args = {"target": "127.0.0.1:50051"} + job_service_channel_args = {"target": "127.0.0.1:60003"} + + operator = ArmadaDeferrableOperator( + task_id="test_task_id", + name="test_task", + armada_channel_args=armada_channel_args, + job_service_channel_args=job_service_channel_args, + armada_queue="test_queue", + job_request_items=[], + lookout_url_template="http://localhost:8089/jobs?job_id=", + ) + + try: + copy.deepcopy(operator) + except Exception as e: + assert False, f"{e}" + + +def test_deepcopy_operator_with_grpc_credentials_callback(): + armada_channel_args = { + "target": "127.0.0.1:50051", + "credentials_callback_args": { + "module_name": "tests.unit.test_armada_operator", + "function_name": "__example_test_callback", + "function_kwargs": { + "test_arg": "fake_arg", + }, + }, + } + job_service_channel_args = {"target": "127.0.0.1:60003"} + + operator = ArmadaDeferrableOperator( + task_id="test_task_id", + name="test_task", + armada_channel_args=armada_channel_args, + job_service_channel_args=job_service_channel_args, + armada_queue="test_queue", + job_request_items=[], + lookout_url_template="http://localhost:8089/jobs?job_id=", + ) + + try: + copy.deepcopy(operator) + except Exception as e: + assert False, f"{e}" + + +def __example_test_callback(foo=None): + return f"fake_cred {foo}" + + +def test_credentials_callback(): + callback = CredentialsCallback( + module_name="test_armada_operator", + function_name="__example_test_callback", + function_kwargs={"foo": "bar"}, + ) + + result = callback.call() + assert result == "fake_cred bar" diff --git a/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py b/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py index 83cc3e220aa..a842fa994d3 100644 --- a/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py +++ b/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py @@ -71,6 +71,7 @@ async def test_failed_event(js_aio_client): job_service_client=js_aio_client, time_out_for_failure=5, log=logging.getLogger(), + poll_interval=1, ) assert job_complete[0] == JobState.FAILED assert ( @@ -89,6 +90,7 @@ async def test_successful_event(js_aio_client): job_service_client=js_aio_client, time_out_for_failure=5, log=logging.getLogger(), + 
+        poll_interval=1,
     )
     assert job_complete[0] == JobState.SUCCEEDED
     assert job_complete[1] == "Armada test:test_succeeded succeeded"
@@ -104,6 +106,7 @@ async def test_cancelled_event(js_aio_client):
         job_service_client=js_aio_client,
         time_out_for_failure=5,
         log=logging.getLogger(),
+        poll_interval=1,
     )
     assert job_complete[0] == JobState.CANCELLED
     assert job_complete[1] == "Armada test:test_cancelled cancelled"
@@ -119,6 +122,7 @@ async def test_job_id_not_found(js_aio_client):
         time_out_for_failure=5,
         job_service_client=js_aio_client,
         log=logging.getLogger(),
+        poll_interval=1,
     )
     assert job_complete[0] == JobState.JOB_ID_NOT_FOUND
     assert (
@@ -142,6 +146,7 @@ async def test_error_retry(js_aio_retry_client):
         job_service_client=js_aio_retry_client,
         time_out_for_failure=5,
         log=logging.getLogger(),
+        poll_interval=1,
    )
     assert job_complete[0] == JobState.SUCCEEDED
     assert job_complete[1] == "Armada test:test_succeeded succeeded"
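
Below is a minimal usage sketch of the new poll_interval argument and the serialize() round-trip exercised by test_serialize_armada_deferrable above. The task id and name, gRPC targets, queue name, lookout URL, and interval value are placeholders chosen for illustration rather than values taken from this change, and the empty job_request_items list stands in for real JobSubmitRequestItem objects.

from armada.operators.armada_deferrable import ArmadaDeferrableOperator

# Placeholder channel arguments; a real deployment points these at the
# Armada server and job service endpoints.
armada_channel_args = {"target": "127.0.0.1:50051"}
job_service_channel_args = {"target": "127.0.0.1:60003"}

operator = ArmadaDeferrableOperator(
    task_id="sleep_job",
    name="sleep_job",
    armada_channel_args=armada_channel_args,
    job_service_channel_args=job_service_channel_args,
    armada_queue="test_queue",
    job_request_items=[],  # real DAGs pass submit_pb2.JobSubmitRequestItem objects
    lookout_url_template="http://localhost:8089/jobs?job_id=",
    poll_interval=10,  # the deferred trigger sleeps 10 seconds between job service polls
)

# serialize() returns the keyword arguments needed to rebuild an equivalent
# operator, which is what the new unit test asserts.
rebuilt = ArmadaDeferrableOperator(**operator.serialize())
assert rebuilt == operator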