Merge branch 'master' into fix-autoupdate
Sharpz7 authored Sep 18, 2023
2 parents 79d2d1d + 291ef41 commit 5b79fd1
Showing 8 changed files with 246 additions and 9 deletions.
28 changes: 26 additions & 2 deletions docs/python_airflow_operator.md
@@ -239,9 +239,27 @@ Reports the result of the job and returns.



#### serialize()
Get a serialized version of this object.


* **Returns**

A dict of keyword arguments used when instantiating this object.



* **Return type**

dict


#### template_fields: Sequence[str] = ('job_request_items',)

### _class_ armada.operators.armada_deferrable.ArmadaJobCompleteTrigger(job_id, job_service_channel_args, armada_queue, job_set_id, airflow_task_name)
### _class_ armada.operators.armada_deferrable.ArmadaJobCompleteTrigger(job_id, job_service_channel_args, armada_queue, job_set_id, airflow_task_name, poll_interval=30)
Bases: `BaseTrigger`

An airflow trigger that monitors the job state of an armada job.
@@ -269,6 +287,9 @@ Triggers when the job is complete.
belongs.


* **poll_interval** (*int*) – How often to poll jobservice to get status.



* **Returns**

@@ -664,7 +685,7 @@ A terminated event is SUCCEEDED, FAILED or CANCELLED



### _async_ armada.operators.utils.search_for_job_complete_async(armada_queue, job_set_id, airflow_task_name, job_id, job_service_client, log, time_out_for_failure=7200)
### _async_ armada.operators.utils.search_for_job_complete_async(armada_queue, job_set_id, airflow_task_name, job_id, job_service_client, log, poll_interval, time_out_for_failure=7200)
Poll JobService cache asynchronously until you get a terminated event.

A terminated event is SUCCEEDED, FAILED or CANCELLED
@@ -689,6 +710,9 @@ A terminated event is SUCCEEDED, FAILED or CANCELLED
It is optional only for testing


* **poll_interval** (*int*) – How often to poll jobservice to get status.


* **time_out_for_failure** (*int*) – The amount of time a job
can be in job_id_not_found
before we decide it was an invalid job
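For orientation, here is a minimal sketch (not part of this commit) of constructing the trigger with the new `poll_interval` argument, following the signature documented above; the channel target, queue, and identifiers are placeholder values.

```python
# Hypothetical construction of the trigger with the new poll_interval knob.
# The target address, queue, and IDs below are placeholders.
from armada.operators.armada_deferrable import ArmadaJobCompleteTrigger

trigger = ArmadaJobCompleteTrigger(
    job_id="example-job-id",
    job_service_channel_args={"target": "127.0.0.1:60003"},  # kwargs for GrpcChannelArguments
    armada_queue="example-queue",
    job_set_id="example-job-set",
    airflow_task_name="example-task",
    poll_interval=10,  # poll jobservice every 10 seconds instead of the default 30
)
```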
39 changes: 37 additions & 2 deletions third_party/airflow/armada/operators/armada_deferrable.py
@@ -103,6 +103,25 @@ def __init__(
self.lookout_url_template = lookout_url_template
self.poll_interval = poll_interval

def serialize(self) -> dict:
"""
Get a serialized version of this object.
:return: A dict of keyword arguments used when instantiating
this object.
"""

return {
"task_id": self.task_id,
"name": self.name,
"armada_channel_args": self.armada_channel_args.serialize(),
"job_service_channel_args": self.job_service_channel_args.serialize(),
"armada_queue": self.armada_queue,
"job_request_items": self.job_request_items,
"lookout_url_template": self.lookout_url_template,
"poll_interval": self.poll_interval,
}

def execute(self, context) -> None:
"""
Executes the Armada Operator. Only meant to be called by airflow.
@@ -156,6 +175,7 @@ def execute(self, context) -> None:
armada_queue=self.armada_queue,
job_set_id=context["run_id"],
airflow_task_name=self.name,
poll_interval=self.poll_interval,
),
method_name="resume_job_complete",
kwargs={"job_id": job_id},
@@ -216,6 +236,7 @@ class ArmadaJobCompleteTrigger(BaseTrigger):
:param job_set_id: The ID of the job set.
:param airflow_task_name: Name of the airflow task to which this trigger
belongs.
:param poll_interval: How often to poll jobservice to get status.
:return: An armada job complete trigger instance.
"""

@@ -226,13 +247,15 @@ def __init__(
armada_queue: str,
job_set_id: str,
airflow_task_name: str,
poll_interval: int = 30,
) -> None:
super().__init__()
self.job_id = job_id
self.job_service_channel_args = GrpcChannelArguments(**job_service_channel_args)
self.armada_queue = armada_queue
self.job_set_id = job_set_id
self.airflow_task_name = airflow_task_name
self.poll_interval = poll_interval

def serialize(self) -> tuple:
return (
@@ -243,9 +266,21 @@
"armada_queue": self.armada_queue,
"job_set_id": self.job_set_id,
"airflow_task_name": self.airflow_task_name,
"poll_interval": self.poll_interval,
},
)

def __eq__(self, o):
return (
self.task_id == o.task_id
and self.job_id == o.job_id
and self.job_service_channel_args == o.job_service_channel_args
and self.armada_queue == o.armada_queue
and self.job_set_id == o.job_set_id
and self.airflow_task_name == o.airflow_task_name
and self.poll_interval == o.poll_interval
)

async def run(self):
"""
Runs the trigger. Meant to be called by an airflow triggerer process.
@@ -255,12 +290,12 @@ async def run(self):
)

job_state, job_message = await search_for_job_complete_async(
job_service_client=job_service_client,
armada_queue=self.armada_queue,
job_set_id=self.job_set_id,
airflow_task_name=self.airflow_task_name,
job_id=self.job_id,
poll_interval=self.poll_interval,
job_service_client=job_service_client,
log=self.log,
poll_interval=self.poll_interval,
)
yield TriggerEvent({"job_state": job_state, "job_message": job_message})
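The new `serialize()` method lets an operator be flattened into plain keyword arguments and rebuilt as an equal instance. A minimal sketch of that round trip, mirroring the added unit test; all constructor values are placeholders:

```python
# Round-trip sketch: serialize() yields the kwargs needed to rebuild an equal
# operator. All values below are placeholders taken from the test pattern.
from armada.operators.armada_deferrable import ArmadaDeferrableOperator

source = ArmadaDeferrableOperator(
    task_id="example_task_id",
    name="example task",
    armada_channel_args={"target": "127.0.0.1:50051"},
    job_service_channel_args={"target": "127.0.0.1:60003"},
    armada_queue="example-queue",
    job_request_items=[],
    lookout_url_template="https://lookout.test.domain/",
    poll_interval=5,
)

reconstituted = ArmadaDeferrableOperator(**source.serialize())
assert reconstituted == source
```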
4 changes: 3 additions & 1 deletion third_party/airflow/armada/operators/utils.py
@@ -217,6 +217,7 @@ async def search_for_job_complete_async(
job_id: str,
job_service_client: JobServiceAsyncIOClient,
log,
poll_interval: int,
time_out_for_failure: int = 7200,
) -> Tuple[JobState, str]:
"""
@@ -231,6 +232,7 @@
:param job_id: The name of the job id that armada assigns to it
:param job_service_client: A JobServiceClient that is used for polling.
It is optional only for testing
:param poll_interval: How often to poll jobservice to get status.
:param time_out_for_failure: The amount of time a job
can be in job_id_not_found
before we decide it was an invalid job
@@ -251,7 +253,7 @@
job_state = job_state_from_pb(job_status_return.state)
log.debug(f"Got job state '{job_state.name}' for job {job_id}")

await asyncio.sleep(3)
await asyncio.sleep(poll_interval)

if job_state == JobState.SUCCEEDED:
job_message = f"Armada {airflow_task_name}:{job_id} succeeded"
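For context, a minimal calling sketch for the updated coroutine; it assumes a `JobServiceAsyncIOClient` has already been constructed elsewhere, and the queue and job identifiers are placeholders:

```python
# Hypothetical caller for search_for_job_complete_async with the new
# poll_interval argument. job_service_client is assumed to be an already
# constructed JobServiceAsyncIOClient.
import logging

from armada.operators.utils import search_for_job_complete_async

log = logging.getLogger(__name__)


async def wait_for_job(job_service_client, job_id: str):
    job_state, job_message = await search_for_job_complete_async(
        armada_queue="example-queue",
        job_set_id="example-job-set",
        airflow_task_name="example-task",
        job_id=job_id,
        job_service_client=job_service_client,
        log=log,
        poll_interval=10,  # sleep 10 seconds between jobservice status checks
        time_out_for_failure=7200,  # default shown in the signature above
    )
    log.info("Job %s finished in state %s: %s", job_id, job_state, job_message)
    return job_state, job_message

# e.g. asyncio.run(wait_for_job(client, "example-job-id"))
```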
2 changes: 1 addition & 1 deletion third_party/airflow/examples/big_armada.py
@@ -57,7 +57,7 @@ def submit_sleep_job():
with DAG(
dag_id="big_armada",
start_date=pendulum.datetime(2016, 1, 1, tz="UTC"),
schedule_interval="@daily",
schedule="@daily",
catchup=False,
default_args={"retries": 2},
) as dag:
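The `schedule_interval` to `schedule` rename in the example DAG tracks newer Airflow releases (2.4+), where `schedule` is the single scheduling argument and `schedule_interval` is deprecated. A minimal sketch of the updated pattern, with a placeholder `dag_id`:

```python
# Minimal DAG sketch using the unified `schedule` argument (Airflow 2.4+).
import pendulum
from airflow import DAG

with DAG(
    dag_id="example_dag",  # placeholder
    start_date=pendulum.datetime(2016, 1, 1, tz="UTC"),
    schedule="@daily",  # replaces the deprecated schedule_interval
    catchup=False,
) as dag:
    ...  # tasks go here
```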
2 changes: 1 addition & 1 deletion third_party/airflow/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "armada_airflow"
version = "0.5.5"
version = "0.5.6"
description = "Armada Airflow Operator"
requires-python = ">=3.7"
# Note(JayF): This dependency value is not suitable for release. Whatever
Expand Down
4 changes: 2 additions & 2 deletions third_party/airflow/tests/unit/test_airflow_operator_mock.py
@@ -170,7 +170,7 @@ def test_annotate_job_request_items():
dag = DAG(
dag_id="hello_armada",
start_date=pendulum.datetime(2016, 1, 1, tz="UTC"),
schedule_interval="@daily",
schedule="@daily",
catchup=False,
default_args={"retries": 2},
)
@@ -204,7 +204,7 @@ def test_parameterize_armada_operator():
dag = DAG(
dag_id="hello_armada",
start_date=pendulum.datetime(2016, 1, 1, tz="UTC"),
schedule_interval="@daily",
schedule="@daily",
catchup=False,
default_args={"retries": 2},
)
171 changes: 171 additions & 0 deletions third_party/airflow/tests/unit/test_armada_deferrable_operator.py
@@ -0,0 +1,171 @@
import copy

import pytest

from armada_client.armada import submit_pb2
from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1
from armada_client.k8s.io.apimachinery.pkg.api.resource import (
generated_pb2 as api_resource,
)
from armada.operators.armada_deferrable import ArmadaDeferrableOperator
from armada.operators.grpc import CredentialsCallback


def test_serialize_armada_deferrable():
grpc_chan_args = {
"target": "localhost:443",
"credentials_callback_args": {
"module_name": "channel_test",
"function_name": "get_credentials",
"function_kwargs": {
"example_arg": "test",
},
},
}

pod = core_v1.PodSpec(
containers=[
core_v1.Container(
name="sleep",
image="busybox",
args=["sleep", "10s"],
securityContext=core_v1.SecurityContext(runAsUser=1000),
resources=core_v1.ResourceRequirements(
requests={
"cpu": api_resource.Quantity(string="120m"),
"memory": api_resource.Quantity(string="510Mi"),
},
limits={
"cpu": api_resource.Quantity(string="120m"),
"memory": api_resource.Quantity(string="510Mi"),
},
),
)
],
)

job_requests = [
submit_pb2.JobSubmitRequestItem(
priority=1,
pod_spec=pod,
namespace="personal-anonymous",
annotations={"armadaproject.io/hello": "world"},
)
]

source = ArmadaDeferrableOperator(
task_id="test_task_id",
name="test task",
armada_channel_args=grpc_chan_args,
job_service_channel_args=grpc_chan_args,
armada_queue="test-queue",
job_request_items=job_requests,
lookout_url_template="https://lookout.test.domain/",
poll_interval=5,
)

serialized = source.serialize()
assert serialized["name"] == source.name

reconstituted = ArmadaDeferrableOperator(**serialized)
assert reconstituted == source


get_lookout_url_test_cases = [
(
"http://localhost:8089/jobs?job_id=<job_id>",
"test_id",
"http://localhost:8089/jobs?job_id=test_id",
),
(
"https://lookout.armada.domain/jobs?job_id=<job_id>",
"test_id",
"https://lookout.armada.domain/jobs?job_id=test_id",
),
("", "test_id", ""),
(None, "test_id", ""),
]


@pytest.mark.parametrize(
"lookout_url_template, job_id, expected_url", get_lookout_url_test_cases
)
def test_get_lookout_url(lookout_url_template, job_id, expected_url):
armada_channel_args = {"target": "127.0.0.1:50051"}
job_service_channel_args = {"target": "127.0.0.1:60003"}

operator = ArmadaDeferrableOperator(
task_id="test_task_id",
name="test_task",
armada_channel_args=armada_channel_args,
job_service_channel_args=job_service_channel_args,
armada_queue="test_queue",
job_request_items=[],
lookout_url_template=lookout_url_template,
)

assert operator._get_lookout_url(job_id) == expected_url


def test_deepcopy_operator():
armada_channel_args = {"target": "127.0.0.1:50051"}
job_service_channel_args = {"target": "127.0.0.1:60003"}

operator = ArmadaDeferrableOperator(
task_id="test_task_id",
name="test_task",
armada_channel_args=armada_channel_args,
job_service_channel_args=job_service_channel_args,
armada_queue="test_queue",
job_request_items=[],
lookout_url_template="http://localhost:8089/jobs?job_id=<job_id>",
)

try:
copy.deepcopy(operator)
except Exception as e:
assert False, f"{e}"


def test_deepcopy_operator_with_grpc_credentials_callback():
armada_channel_args = {
"target": "127.0.0.1:50051",
"credentials_callback_args": {
"module_name": "tests.unit.test_armada_operator",
"function_name": "__example_test_callback",
"function_kwargs": {
"test_arg": "fake_arg",
},
},
}
job_service_channel_args = {"target": "127.0.0.1:60003"}

operator = ArmadaDeferrableOperator(
task_id="test_task_id",
name="test_task",
armada_channel_args=armada_channel_args,
job_service_channel_args=job_service_channel_args,
armada_queue="test_queue",
job_request_items=[],
lookout_url_template="http://localhost:8089/jobs?job_id=<job_id>",
)

try:
copy.deepcopy(operator)
except Exception as e:
assert False, f"{e}"


def __example_test_callback(foo=None):
return f"fake_cred {foo}"


def test_credentials_callback():
callback = CredentialsCallback(
module_name="test_armada_operator",
function_name="__example_test_callback",
function_kwargs={"foo": "bar"},
)

result = callback.call()
assert result == "fake_cred bar"
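The new test module can be exercised on its own; a small sketch using pytest's Python API, with the path taken from the diff header above (assumes pytest and the airflow package dependencies are installed):

```python
# Run just the new deferrable-operator tests.
import pytest

pytest.main(["third_party/airflow/tests/unit/test_armada_deferrable_operator.py"])
```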