Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,9 @@ def attempt_workload_runs(self):
try:
ser_workload_key = json.dumps(workload_key._asdict())
except AttributeError:
# Callback workloads use string id.
ser_workload_key = workload_key
# Callback workloads use CallbackKey (or legacy string id); both have a
# str() representation that round-trips through JSON.
ser_workload_key = str(workload_key)

payload = {
"task_key": ser_workload_key,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

from airflow.executors import workloads
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.providers.amazon.aws.executors.batch.utils import BatchJobWorkloadKey


from airflow.providers.amazon.aws.executors.batch.boto_schema import (
Expand Down Expand Up @@ -402,9 +403,7 @@ def _describe_jobs(self, job_ids) -> list[BatchJob]:
all_jobs.extend(describe_workloads_response["jobs"])
return all_jobs

def execute_async(
self, key: TaskInstanceKey | str, command: CommandType, queue=None, executor_config=None
):
def execute_async(self, key: BatchJobWorkloadKey, command: CommandType, queue=None, executor_config=None):
"""Save the workload to be executed in the next sync using Boto3's RunTask API."""
if executor_config and "command" in executor_config:
raise ValueError('Executor Config should never override "command"')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def set_env_vars():
@pytest.fixture
def mock_airflow_key():
def _key():
key_mock = mock.Mock()
key_mock = mock.Mock(spec=TaskInstanceKey)
# Use a "random" value (memory id of the mock obj) so each key serializes uniquely
key_mock._asdict = mock.Mock(return_value={"mock_key": id(key_mock)})
return key_mock
Expand Down Expand Up @@ -180,10 +180,12 @@ def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock
def test_task_sdk_callback(self, mock_executor):
"""Test task sdk callback execution end-to-end."""
from airflow.executors.workloads import ExecuteCallback
from airflow.models.callback import CallbackKey

callback_id = "callback_123"
callback_id = CallbackKey("callback_123")

workload = mock.Mock(spec=ExecuteCallback)
workload.key = callback_id
workload.callback = mock.Mock()
workload.callback.key = callback_id
workload.callback.data = {}
Expand Down Expand Up @@ -212,7 +214,7 @@ def test_task_sdk_callback(self, mock_executor):
mock_executor.attempt_workload_runs()
mock_executor.lambda_client.invoke.assert_called_once()
payload = json.loads(mock_executor.lambda_client.invoke.call_args.kwargs["Payload"])
assert payload["task_key"] == callback_id
assert payload["task_key"] == str(callback_id)
assert payload["command"] == [
"python",
"-m",
Expand All @@ -223,7 +225,7 @@ def test_task_sdk_callback(self, mock_executor):

# Callback is stored in running workloads.
assert len(mock_executor.running_workloads) == 1
assert callback_id in mock_executor.running_workloads
assert str(callback_id) in mock_executor.running_workloads

@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+")
def test_task_sdk_callback_with_queue(self, mock_airflow_key, mock_executor):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def mock_executor(set_env_vars) -> AwsBatchExecutor:

@pytest.fixture(autouse=True)
def mock_airflow_key():
return mock.Mock(spec=list)
return mock.Mock(spec=TaskInstanceKey)


@pytest.fixture(autouse=True)
Expand All @@ -108,7 +108,7 @@ def _setup_method(self):
self.collection = BatchJobCollection()
# Add first task
self.first_job_id = "001"
self.first_airflow_key = mock.Mock(spec=tuple)
self.first_airflow_key = mock.Mock(spec=TaskInstanceKey)
self.collection.add_job(
job_id=self.first_job_id,
airflow_workload_key=self.first_airflow_key,
Expand All @@ -119,7 +119,7 @@ def _setup_method(self):
)
# Add second task
self.second_job_id = "002"
self.second_airflow_key = mock.Mock(spec=tuple)
self.second_airflow_key = mock.Mock(spec=TaskInstanceKey)
self.collection.add_job(
job_id=self.second_job_id,
airflow_workload_key=self.second_airflow_key,
Expand Down Expand Up @@ -190,7 +190,7 @@ class TestAwsBatchExecutor:

def test_execute(self, mock_executor):
"""Test execution from end-to-end"""
airflow_key = mock.Mock(spec=tuple)
airflow_key = mock.Mock(spec=TaskInstanceKey)
airflow_cmd = ["1", "2"]

mock_executor.batch.submit_job.return_value = {"jobId": MOCK_JOB_ID, "jobName": "some-job-name"}
Expand Down Expand Up @@ -480,7 +480,7 @@ def test_attempt_all_jobs_when_jobs_fail(self, _, mock_executor):

def test_attempt_submit_jobs_failure(self, mock_executor):
mock_executor.batch.submit_job.side_effect = NoCredentialsError()
mock_executor.execute_async("airflow_key", "airflow_cmd")
mock_executor.execute_async(mock.Mock(spec=TaskInstanceKey), "airflow_cmd")
assert len(mock_executor.pending_jobs) == 1
with pytest.raises(NoCredentialsError, match="Unable to locate credentials"):
mock_executor.attempt_submit_jobs()
Expand All @@ -501,7 +501,10 @@ def test_attempt_submit_jobs_failure(self, mock_executor):
@mock.patch.object(batch_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
def test_task_retry_on_api_failure(self, _, mock_executor, caplog):
"""Test API failure retries"""
airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"]
airflow_keys = [
TaskInstanceKey("dag", "task1", "run", 1, -1),
TaskInstanceKey("dag", "task2", "run", 1, -1),
]
airflow_cmds = [["1", "2"], ["3", "4"]]

mock_executor.execute_async(airflow_keys[0], airflow_cmds[0])
Expand Down Expand Up @@ -575,7 +578,7 @@ def test_sync_running_jobs_no_jobs(self, mock_executor, caplog):
assert "No active Airflow workloads, skipping sync" in caplog.messages[0]

def test_sync_client_error(self, mock_executor, caplog):
mock_executor.execute_async("airflow_key", "airflow_cmd")
mock_executor.execute_async(mock.Mock(spec=TaskInstanceKey), "airflow_cmd")
assert len(mock_executor.pending_jobs) == 1
mock_resp = {
"Error": {
Expand Down Expand Up @@ -1053,7 +1056,7 @@ def test_submit_job_kwargs_exec_config_overrides(
)
os.environ[submit_job_kwargs_env_key] = json.dumps(submit_job_kwargs)

mock_ti_key = mock.Mock(spec=tuple)
mock_ti_key = mock.Mock(spec=TaskInstanceKey)
command = ["command"]

executor = AwsBatchExecutor()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def mock_task(arn=ARN1, state=State.RUNNING):
@pytest.fixture(autouse=True)
def mock_airflow_key():
def _key():
return mock.Mock(spec=tuple)
return mock.Mock(spec=TaskInstanceKey)

return _key

Expand Down Expand Up @@ -519,7 +519,7 @@ def test_success_execute_api_exception(self, mock_backoff, mock_executor, mock_c
"failures": [],
}
mock_executor.ecs.run_task.side_effect = [run_task_exception, run_task_exception, run_task_success]
mock_executor.execute_async(mock_airflow_key, mock_cmd)
mock_executor.execute_async(mock.Mock(spec=TaskInstanceKey), mock_cmd)
expected_retry_count = 2

# Fail 2 times
Expand Down Expand Up @@ -669,7 +669,7 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor):
mock_executor.ecs.run_task.call_args_list.clear()

# queue new task
airflow_keys[1] = mock.Mock(spec=tuple)
airflow_keys[1] = mock.Mock(spec=TaskInstanceKey)
airflow_commands[1] = _generate_mock_cmd()
mock_executor.execute_async(airflow_keys[1], airflow_commands[1])

Expand Down Expand Up @@ -710,7 +710,10 @@ def test_task_retry_on_api_failure_all_tasks_fail(self, _, mock_executor, caplog
Test API failure retries.
"""
mock_executor.max_run_task_attempts = "2"
airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"]
airflow_keys = [
TaskInstanceKey("dag", "task1", "run", 1, -1),
TaskInstanceKey("dag", "task2", "run", 1, -1),
]
airflow_commands = [_generate_mock_cmd(), _generate_mock_cmd()]

mock_executor.execute_async(airflow_keys[0], airflow_commands[0])
Expand Down Expand Up @@ -955,7 +958,7 @@ def test_failed_sync_api_exception(self, mock_executor, caplog):
@mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
def test_failed_sync_api(self, _, success_mock, fail_mock, mock_executor, mock_cmd):
"""Test what happens when ECS sync fails for certain tasks repeatedly."""
airflow_key = "test-key"
airflow_key = TaskInstanceKey("dag", "task", "run", 1, -1)
mock_executor.execute_async(airflow_key, mock_cmd)
assert len(mock_executor.pending_tasks) == 1

Expand Down Expand Up @@ -1148,7 +1151,14 @@ def _mock_sync(
@staticmethod
def _add_mock_task(executor: AwsEcsExecutor, arn: str, state=TaskInstanceState.RUNNING):
task = mock_task(arn, state)
executor.active_workers.add_task(task, mock.Mock(spec=tuple), mock_queue, mock_cmd, mock_config, 1) # type:ignore[arg-type]
executor.active_workers.add_task(
task,
mock.Mock(spec=TaskInstanceKey),
mock_queue, # type:ignore[arg-type]
mock_cmd, # type:ignore[arg-type]
mock_config, # type:ignore[arg-type]
1,
)

def _sync_mock_with_call_counts(self, sync_func: Callable):
"""Mock won't work here, because we actually want to call the 'sync' func."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ def test_process_workloads_routes_execute_callback(mock_send_workloads, callback
executor = celery_executor.CeleryExecutor()
executor._process_workloads([workload])

mock_send_workloads.assert_called_once_with([(callback_id, workload, expected_queue, None)])
mock_send_workloads.assert_called_once_with([(workload.callback.key, workload, expected_queue, None)])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, this could just be workload.key since the ExecuteCallback.key property returns self.callback.key. But thanks for fixing this. I'm curious why my PR was green but caused all these issues. Sorry about that.



@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="execute_workload is only used for Airflow 3+")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ def test_invalid_executor_config(self, mock_get_kube_client, mock_kubernetes_job
try:
assert executor.event_buffer == {}
executor.execute_async(
key=("dag", "task", timezone.utcnow(), 1),
key=TaskInstanceKey("dag", "task", "run_id", 1, -1),
queue=None,
command=["airflow", "tasks", "run", "true", "some_parameter"],
executor_config=k8s.V1Pod(
Expand Down
Loading