WIP: Airflow: fix undefined poll_interval in Deferrable Operator (#2975)
* Airflow: handle poll_interval attr in ArmadaJobCompleteTrigger

Fix incomplete handling of 'poll_interval' attribute in
ArmadaJobCompleteTrigger, used by the Armada Deferrable Operator for
Airflow.

Signed-off-by: Rich Scott <richscott@sent.com>

* Airflow - add unit test for armada deferrable operator

Run many of the same tests for the deferrable operator as for the
regular operator, plus test serialization. Also, update the schedule
interval signifier in the examples. A full test of the deferrable
operator that verifies the trigger handling is still needed.

Signed-off-by: Rich Scott <richscott@sent.com>

---------

Signed-off-by: Rich Scott <richscott@sent.com>
richscott committed Sep 18, 2023
1 parent babce23 commit 8eb7201
Showing 7 changed files with 245 additions and 8 deletions.
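For illustration, the failure this commit addresses can be reproduced with a minimal stand-in class (not the actual Armada code): as the diff below suggests, the pre-fix trigger's run() referenced self.poll_interval, but __init__ never stored that attribute.

```python
import asyncio


class OldTriggerStandIn:
    """Minimal stand-in for the pre-fix trigger; for illustration only."""

    def __init__(self, job_id: str):
        self.job_id = job_id  # note: poll_interval is never stored

    async def run(self):
        # The pre-fix run() read self.poll_interval, which does not exist here.
        await asyncio.sleep(self.poll_interval)


# asyncio.run(OldTriggerStandIn("job-1").run())
# AttributeError: 'OldTriggerStandIn' object has no attribute 'poll_interval'
```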
28 changes: 26 additions & 2 deletions docs/python_airflow_operator.md
@@ -239,9 +239,27 @@ Reports the result of the job and returns.



#### serialize()
Get a serialized version of this object.


* **Returns**

A dict of keyword arguments used when instantiating this object.



* **Return type**

dict
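
For illustration, the returned dict can be passed straight back to the constructor to rebuild an equivalent operator. The sketch below mirrors the new unit test further down this page; the channel targets and other values are placeholders.

```python
from armada.operators.armada_deferrable import ArmadaDeferrableOperator

op = ArmadaDeferrableOperator(
    task_id="example_task",
    name="example task",
    armada_channel_args={"target": "127.0.0.1:50051"},       # placeholder target
    job_service_channel_args={"target": "127.0.0.1:60003"},  # placeholder target
    armada_queue="example-queue",
    job_request_items=[],
    lookout_url_template="http://localhost:8089/jobs?job_id=<job_id>",
    poll_interval=30,
)

kwargs = op.serialize()                       # plain dict of keyword arguments
rebuilt = ArmadaDeferrableOperator(**kwargs)  # equivalent operator instance
```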




#### template_fields: Sequence[str] = ('job_request_items',)

### _class_ armada.operators.armada_deferrable.ArmadaJobCompleteTrigger(job_id, job_service_channel_args, armada_queue, job_set_id, airflow_task_name)
### _class_ armada.operators.armada_deferrable.ArmadaJobCompleteTrigger(job_id, job_service_channel_args, armada_queue, job_set_id, airflow_task_name, poll_interval=30)
Bases: `BaseTrigger`

An airflow trigger that monitors the job state of an armada job.
@@ -269,6 +287,9 @@ Triggers when the job is complete.
belongs.


* **poll_interval** (*int*) – How often to poll jobservice to get status.



* **Returns**

@@ -664,7 +685,7 @@ A terminated event is SUCCEEDED, FAILED or CANCELLED



### _async_ armada.operators.utils.search_for_job_complete_async(armada_queue, job_set_id, airflow_task_name, job_id, job_service_client, log, time_out_for_failure=7200)
### _async_ armada.operators.utils.search_for_job_complete_async(armada_queue, job_set_id, airflow_task_name, job_id, job_service_client, log, poll_interval, time_out_for_failure=7200)
Poll JobService cache asynchronously until you get a terminated event.

A terminated event is SUCCEEDED, FAILED or CANCELLED
@@ -689,6 +710,9 @@ A terminated event is SUCCEEDED, FAILED or CANCELLED
It is optional only for testing


* **poll_interval** (*int*) – How often to poll jobservice to get status.


* **time_out_for_failure** (*int*) – The amount of time a job
can be in job_id_not_found
before we decide it was an invalid job
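Putting the documented pieces together, a DAG can configure the deferrable operator with an explicit poll interval roughly as follows. This is a sketch with placeholder targets, queue name, and (empty) job request items; big_armada.py further down this page is the fuller, real example.

```python
import pendulum
from airflow import DAG

from armada.operators.armada_deferrable import ArmadaDeferrableOperator

with DAG(
    dag_id="armada_deferrable_example",
    start_date=pendulum.datetime(2016, 1, 1, tz="UTC"),
    schedule="@daily",
    catchup=False,
) as dag:
    sleep_job = ArmadaDeferrableOperator(
        task_id="armada_sleep_job",
        name="armada sleep job",
        armada_channel_args={"target": "127.0.0.1:50051"},
        job_service_channel_args={"target": "127.0.0.1:60003"},
        armada_queue="test-queue",
        job_request_items=[],  # list of submit_pb2.JobSubmitRequestItem
        lookout_url_template="http://localhost:8089/jobs?job_id=<job_id>",
        poll_interval=30,      # seconds between job-service status polls
    )
```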
39 changes: 37 additions & 2 deletions third_party/airflow/armada/operators/armada_deferrable.py
@@ -103,6 +103,25 @@ def __init__(
self.lookout_url_template = lookout_url_template
self.poll_interval = poll_interval

def serialize(self) -> dict:
"""
Get a serialized version of this object.
:return: A dict of keyword arguments used when instantiating
this object.
"""

return {
"task_id": self.task_id,
"name": self.name,
"armada_channel_args": self.armada_channel_args.serialize(),
"job_service_channel_args": self.job_service_channel_args.serialize(),
"armada_queue": self.armada_queue,
"job_request_items": self.job_request_items,
"lookout_url_template": self.lookout_url_template,
"poll_interval": self.poll_interval,
}

def execute(self, context) -> None:
"""
Executes the Armada Operator. Only meant to be called by airflow.
@@ -156,6 +175,7 @@ def execute(self, context) -> None:
armada_queue=self.armada_queue,
job_set_id=context["run_id"],
airflow_task_name=self.name,
poll_interval=self.poll_interval,
),
method_name="resume_job_complete",
kwargs={"job_id": job_id},
@@ -216,6 +236,7 @@ class ArmadaJobCompleteTrigger(BaseTrigger):
:param job_set_id: The ID of the job set.
:param airflow_task_name: Name of the airflow task to which this trigger
belongs.
:param poll_interval: How often to poll jobservice to get status.
:return: An armada job complete trigger instance.
"""

@@ -226,13 +247,15 @@ def __init__(
armada_queue: str,
job_set_id: str,
airflow_task_name: str,
poll_interval: int = 30,
) -> None:
super().__init__()
self.job_id = job_id
self.job_service_channel_args = GrpcChannelArguments(**job_service_channel_args)
self.armada_queue = armada_queue
self.job_set_id = job_set_id
self.airflow_task_name = airflow_task_name
self.poll_interval = poll_interval

def serialize(self) -> tuple:
return (
@@ -243,9 +266,21 @@ def serialize(self) -> tuple:
"armada_queue": self.armada_queue,
"job_set_id": self.job_set_id,
"airflow_task_name": self.airflow_task_name,
"poll_interval": self.poll_interval,
},
)

def __eq__(self, o):
return (
self.job_id == o.job_id
and self.job_service_channel_args == o.job_service_channel_args
and self.armada_queue == o.armada_queue
and self.job_set_id == o.job_set_id
and self.airflow_task_name == o.airflow_task_name
and self.poll_interval == o.poll_interval
)

async def run(self):
"""
Runs the trigger. Meant to be called by an airflow triggerer process.
@@ -255,12 +290,12 @@ async def run(self):
)

job_state, job_message = await search_for_job_complete_async(
job_service_client=job_service_client,
armada_queue=self.armada_queue,
job_set_id=self.job_set_id,
airflow_task_name=self.airflow_task_name,
job_id=self.job_id,
poll_interval=self.poll_interval,
job_service_client=job_service_client,
log=self.log,
poll_interval=self.poll_interval,
)
yield TriggerEvent({"job_state": job_state, "job_message": job_message})
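
The deferral round trip hinges on the trigger's serialize() tuple: execute() defers with a trigger instance, and Airflow's triggerer later rebuilds the trigger from the (classpath, kwargs) pair before awaiting run(). A sketch of that rebuild, assuming the entries collapsed in the hunk above include job_id and the serialized channel args (which __init__ expects); all values are placeholders:

```python
from armada.operators.armada_deferrable import ArmadaJobCompleteTrigger

trigger = ArmadaJobCompleteTrigger(
    job_id="example-job-id",
    job_service_channel_args={"target": "127.0.0.1:60003"},  # placeholder
    armada_queue="test-queue",
    job_set_id="example-run-id",
    airflow_task_name="armada_sleep_job",
    poll_interval=30,
)

classpath, kwargs = trigger.serialize()
# The triggerer reconstructs the trigger purely from these kwargs.
rebuilt = ArmadaJobCompleteTrigger(**kwargs)
assert rebuilt.poll_interval == trigger.poll_interval
```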
4 changes: 3 additions & 1 deletion third_party/airflow/armada/operators/utils.py
@@ -217,6 +217,7 @@ async def search_for_job_complete_async(
job_id: str,
job_service_client: JobServiceAsyncIOClient,
log,
poll_interval: int,
time_out_for_failure: int = 7200,
) -> Tuple[JobState, str]:
"""
@@ -231,6 +232,7 @@ async def search_for_job_complete_async(
:param job_id: The name of the job id that armada assigns to it
:param job_service_client: A JobServiceClient that is used for polling.
It is optional only for testing
:param poll_interval: How often to poll jobservice to get status.
:param time_out_for_failure: The amount of time a job
can be in job_id_not_found
before we decide it was an invalid job
@@ -251,7 +253,7 @@
job_state = job_state_from_pb(job_status_return.state)
log.debug(f"Got job state '{job_state.name}' for job {job_id}")

await asyncio.sleep(3)
await asyncio.sleep(poll_interval)

if job_state == JobState.SUCCEEDED:
job_message = f"Armada {airflow_task_name}:{job_id} succeeded"
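As an aside, the poll_interval / time_out_for_failure semantics amount to a standard async polling loop. The generic sketch below is illustration only, not Armada's API; it compresses job states to plain strings.

```python
import asyncio
import time


async def poll_until_terminal(
    get_state, poll_interval: int, time_out_for_failure: int = 7200
) -> str:
    """Poll get_state() every poll_interval seconds until a terminal state appears.

    Gives up if the job stays in JOB_ID_NOT_FOUND longer than time_out_for_failure.
    """
    not_found_since = None
    while True:
        state = await get_state()
        if state in ("SUCCEEDED", "FAILED", "CANCELLED"):
            return state
        if state == "JOB_ID_NOT_FOUND":
            if not_found_since is None:
                not_found_since = time.monotonic()
            elif time.monotonic() - not_found_since > time_out_for_failure:
                return "JOB_ID_NOT_FOUND"
        else:
            not_found_since = None
        await asyncio.sleep(poll_interval)
```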
2 changes: 1 addition & 1 deletion third_party/airflow/examples/big_armada.py
@@ -57,7 +57,7 @@ def submit_sleep_job():
with DAG(
dag_id="big_armada",
start_date=pendulum.datetime(2016, 1, 1, tz="UTC"),
schedule_interval="@daily",
schedule="@daily",
catchup=False,
default_args={"retries": 2},
) as dag:
4 changes: 2 additions & 2 deletions third_party/airflow/tests/unit/test_airflow_operator_mock.py
@@ -170,7 +170,7 @@ def test_annotate_job_request_items():
dag = DAG(
dag_id="hello_armada",
start_date=pendulum.datetime(2016, 1, 1, tz="UTC"),
schedule_interval="@daily",
schedule="@daily",
catchup=False,
default_args={"retries": 2},
)
@@ -204,7 +204,7 @@ def test_parameterize_armada_operator():
dag = DAG(
dag_id="hello_armada",
start_date=pendulum.datetime(2016, 1, 1, tz="UTC"),
schedule_interval="@daily",
schedule="@daily",
catchup=False,
default_args={"retries": 2},
)
171 changes: 171 additions & 0 deletions third_party/airflow/tests/unit/test_armada_deferrable_operator.py
@@ -0,0 +1,171 @@
import copy

import pytest

from armada_client.armada import submit_pb2
from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1
from armada_client.k8s.io.apimachinery.pkg.api.resource import (
generated_pb2 as api_resource,
)
from armada.operators.armada_deferrable import ArmadaDeferrableOperator
from armada.operators.grpc import CredentialsCallback


def test_serialize_armada_deferrable():
grpc_chan_args = {
"target": "localhost:443",
"credentials_callback_args": {
"module_name": "channel_test",
"function_name": "get_credentials",
"function_kwargs": {
"example_arg": "test",
},
},
}

pod = core_v1.PodSpec(
containers=[
core_v1.Container(
name="sleep",
image="busybox",
args=["sleep", "10s"],
securityContext=core_v1.SecurityContext(runAsUser=1000),
resources=core_v1.ResourceRequirements(
requests={
"cpu": api_resource.Quantity(string="120m"),
"memory": api_resource.Quantity(string="510Mi"),
},
limits={
"cpu": api_resource.Quantity(string="120m"),
"memory": api_resource.Quantity(string="510Mi"),
},
),
)
],
)

job_requests = [
submit_pb2.JobSubmitRequestItem(
priority=1,
pod_spec=pod,
namespace="personal-anonymous",
annotations={"armadaproject.io/hello": "world"},
)
]

source = ArmadaDeferrableOperator(
task_id="test_task_id",
name="test task",
armada_channel_args=grpc_chan_args,
job_service_channel_args=grpc_chan_args,
armada_queue="test-queue",
job_request_items=job_requests,
lookout_url_template="https://lookout.test.domain/",
poll_interval=5,
)

serialized = source.serialize()
assert serialized["name"] == source.name

reconstituted = ArmadaDeferrableOperator(**serialized)
assert reconstituted == source


get_lookout_url_test_cases = [
(
"http://localhost:8089/jobs?job_id=<job_id>",
"test_id",
"http://localhost:8089/jobs?job_id=test_id",
),
(
"https://lookout.armada.domain/jobs?job_id=<job_id>",
"test_id",
"https://lookout.armada.domain/jobs?job_id=test_id",
),
("", "test_id", ""),
(None, "test_id", ""),
]


@pytest.mark.parametrize(
"lookout_url_template, job_id, expected_url", get_lookout_url_test_cases
)
def test_get_lookout_url(lookout_url_template, job_id, expected_url):
armada_channel_args = {"target": "127.0.0.1:50051"}
job_service_channel_args = {"target": "127.0.0.1:60003"}

operator = ArmadaDeferrableOperator(
task_id="test_task_id",
name="test_task",
armada_channel_args=armada_channel_args,
job_service_channel_args=job_service_channel_args,
armada_queue="test_queue",
job_request_items=[],
lookout_url_template=lookout_url_template,
)

assert operator._get_lookout_url(job_id) == expected_url


def test_deepcopy_operator():
armada_channel_args = {"target": "127.0.0.1:50051"}
job_service_channel_args = {"target": "127.0.0.1:60003"}

operator = ArmadaDeferrableOperator(
task_id="test_task_id",
name="test_task",
armada_channel_args=armada_channel_args,
job_service_channel_args=job_service_channel_args,
armada_queue="test_queue",
job_request_items=[],
lookout_url_template="http://localhost:8089/jobs?job_id=<job_id>",
)

try:
copy.deepcopy(operator)
except Exception as e:
assert False, f"{e}"


def test_deepcopy_operator_with_grpc_credentials_callback():
armada_channel_args = {
"target": "127.0.0.1:50051",
"credentials_callback_args": {
"module_name": "tests.unit.test_armada_operator",
"function_name": "__example_test_callback",
"function_kwargs": {
"test_arg": "fake_arg",
},
},
}
job_service_channel_args = {"target": "127.0.0.1:60003"}

operator = ArmadaDeferrableOperator(
task_id="test_task_id",
name="test_task",
armada_channel_args=armada_channel_args,
job_service_channel_args=job_service_channel_args,
armada_queue="test_queue",
job_request_items=[],
lookout_url_template="http://localhost:8089/jobs?job_id=<job_id>",
)

try:
copy.deepcopy(operator)
except Exception as e:
assert False, f"{e}"


def __example_test_callback(foo=None):
return f"fake_cred {foo}"


def test_credentials_callback():
callback = CredentialsCallback(
module_name="test_armada_operator",
function_name="__example_test_callback",
function_kwargs={"foo": "bar"},
)

result = callback.call()
assert result == "fake_cred bar"
