diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py index d4b3bc30a5f4f..72455165e83e2 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py @@ -345,8 +345,9 @@ def attempt_workload_runs(self): try: ser_workload_key = json.dumps(workload_key._asdict()) except AttributeError: - # Callback workloads use string id. - ser_workload_key = workload_key + # Callback workloads use CallbackKey (or legacy string id); both have a + # str() representation that round-trips through JSON. + ser_workload_key = str(workload_key) payload = { "task_key": ser_workload_key, diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py index 835dfe2e1c4d7..e04464883c0aa 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py @@ -42,6 +42,7 @@ from airflow.executors import workloads from airflow.models.taskinstance import TaskInstance, TaskInstanceKey + from airflow.providers.amazon.aws.executors.batch.utils import BatchJobWorkloadKey from airflow.providers.amazon.aws.executors.batch.boto_schema import ( @@ -402,9 +403,7 @@ def _describe_jobs(self, job_ids) -> list[BatchJob]: all_jobs.extend(describe_workloads_response["jobs"]) return all_jobs - def execute_async( - self, key: TaskInstanceKey | str, command: CommandType, queue=None, executor_config=None - ): + def execute_async(self, key: BatchJobWorkloadKey, command: CommandType, queue=None, executor_config=None): """Save the workload to be executed in the next sync using Boto3's RunTask API.""" if executor_config and "command" in executor_config: raise ValueError('Executor Config should never override "command"') diff --git a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py index d98ceaa5a89dc..284865285f164 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py @@ -64,7 +64,7 @@ def set_env_vars(): @pytest.fixture def mock_airflow_key(): def _key(): - key_mock = mock.Mock() + key_mock = mock.Mock(spec=TaskInstanceKey) # Use a "random" value (memory id of the mock obj) so each key serializes uniquely key_mock._asdict = mock.Mock(return_value={"mock_key": id(key_mock)}) return key_mock @@ -180,10 +180,12 @@ def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock def test_task_sdk_callback(self, mock_executor): """Test task sdk callback execution end-to-end.""" from airflow.executors.workloads import ExecuteCallback + from airflow.models.callback import CallbackKey - callback_id = "callback_123" + callback_id = CallbackKey("callback_123") workload = mock.Mock(spec=ExecuteCallback) + workload.key = callback_id workload.callback = mock.Mock() workload.callback.key = callback_id workload.callback.data = {} @@ -212,7 +214,7 @@ def test_task_sdk_callback(self, mock_executor): mock_executor.attempt_workload_runs() mock_executor.lambda_client.invoke.assert_called_once() payload = json.loads(mock_executor.lambda_client.invoke.call_args.kwargs["Payload"]) - assert payload["task_key"] == callback_id + assert payload["task_key"] == str(callback_id) assert payload["command"] == [ "python", "-m", @@ -223,7 +225,7 @@ def test_task_sdk_callback(self, mock_executor): # Callback is stored in running workloads. assert len(mock_executor.running_workloads) == 1 - assert callback_id in mock_executor.running_workloads + assert str(callback_id) in mock_executor.running_workloads @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+") def test_task_sdk_callback_with_queue(self, mock_airflow_key, mock_executor): diff --git a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py index 00c3d622dcb5e..4695ca1d47afa 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py @@ -88,7 +88,7 @@ def mock_executor(set_env_vars) -> AwsBatchExecutor: @pytest.fixture(autouse=True) def mock_airflow_key(): - return mock.Mock(spec=list) + return mock.Mock(spec=TaskInstanceKey) @pytest.fixture(autouse=True) @@ -108,7 +108,7 @@ def _setup_method(self): self.collection = BatchJobCollection() # Add first task self.first_job_id = "001" - self.first_airflow_key = mock.Mock(spec=tuple) + self.first_airflow_key = mock.Mock(spec=TaskInstanceKey) self.collection.add_job( job_id=self.first_job_id, airflow_workload_key=self.first_airflow_key, @@ -119,7 +119,7 @@ def _setup_method(self): ) # Add second task self.second_job_id = "002" - self.second_airflow_key = mock.Mock(spec=tuple) + self.second_airflow_key = mock.Mock(spec=TaskInstanceKey) self.collection.add_job( job_id=self.second_job_id, airflow_workload_key=self.second_airflow_key, @@ -190,7 +190,7 @@ class TestAwsBatchExecutor: def test_execute(self, mock_executor): """Test execution from end-to-end""" - airflow_key = mock.Mock(spec=tuple) + airflow_key = mock.Mock(spec=TaskInstanceKey) airflow_cmd = ["1", "2"] mock_executor.batch.submit_job.return_value = {"jobId": MOCK_JOB_ID, "jobName": "some-job-name"} @@ -480,7 +480,7 @@ def test_attempt_all_jobs_when_jobs_fail(self, _, mock_executor): def test_attempt_submit_jobs_failure(self, mock_executor): mock_executor.batch.submit_job.side_effect = NoCredentialsError() - mock_executor.execute_async("airflow_key", "airflow_cmd") + mock_executor.execute_async(mock.Mock(spec=TaskInstanceKey), "airflow_cmd") assert len(mock_executor.pending_jobs) == 1 with pytest.raises(NoCredentialsError, match="Unable to locate credentials"): mock_executor.attempt_submit_jobs() @@ -501,7 +501,10 @@ def test_attempt_submit_jobs_failure(self, mock_executor): @mock.patch.object(batch_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) def test_task_retry_on_api_failure(self, _, mock_executor, caplog): """Test API failure retries""" - airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"] + airflow_keys = [ + TaskInstanceKey("dag", "task1", "run", 1, -1), + TaskInstanceKey("dag", "task2", "run", 1, -1), + ] airflow_cmds = [["1", "2"], ["3", "4"]] mock_executor.execute_async(airflow_keys[0], airflow_cmds[0]) @@ -575,7 +578,7 @@ def test_sync_running_jobs_no_jobs(self, mock_executor, caplog): assert "No active Airflow workloads, skipping sync" in caplog.messages[0] def test_sync_client_error(self, mock_executor, caplog): - mock_executor.execute_async("airflow_key", "airflow_cmd") + mock_executor.execute_async(mock.Mock(spec=TaskInstanceKey), "airflow_cmd") assert len(mock_executor.pending_jobs) == 1 mock_resp = { "Error": { @@ -1053,7 +1056,7 @@ def test_submit_job_kwargs_exec_config_overrides( ) os.environ[submit_job_kwargs_env_key] = json.dumps(submit_job_kwargs) - mock_ti_key = mock.Mock(spec=tuple) + mock_ti_key = mock.Mock(spec=TaskInstanceKey) command = ["command"] executor = AwsBatchExecutor() diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py index f57cd13d28aab..ca2d1e255b272 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py @@ -110,7 +110,7 @@ def mock_task(arn=ARN1, state=State.RUNNING): @pytest.fixture(autouse=True) def mock_airflow_key(): def _key(): - return mock.Mock(spec=tuple) + return mock.Mock(spec=TaskInstanceKey) return _key @@ -519,7 +519,7 @@ def test_success_execute_api_exception(self, mock_backoff, mock_executor, mock_c "failures": [], } mock_executor.ecs.run_task.side_effect = [run_task_exception, run_task_exception, run_task_success] - mock_executor.execute_async(mock_airflow_key, mock_cmd) + mock_executor.execute_async(mock.Mock(spec=TaskInstanceKey), mock_cmd) expected_retry_count = 2 # Fail 2 times @@ -669,7 +669,7 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor): mock_executor.ecs.run_task.call_args_list.clear() # queue new task - airflow_keys[1] = mock.Mock(spec=tuple) + airflow_keys[1] = mock.Mock(spec=TaskInstanceKey) airflow_commands[1] = _generate_mock_cmd() mock_executor.execute_async(airflow_keys[1], airflow_commands[1]) @@ -710,7 +710,10 @@ def test_task_retry_on_api_failure_all_tasks_fail(self, _, mock_executor, caplog Test API failure retries. """ mock_executor.max_run_task_attempts = "2" - airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"] + airflow_keys = [ + TaskInstanceKey("dag", "task1", "run", 1, -1), + TaskInstanceKey("dag", "task2", "run", 1, -1), + ] airflow_commands = [_generate_mock_cmd(), _generate_mock_cmd()] mock_executor.execute_async(airflow_keys[0], airflow_commands[0]) @@ -955,7 +958,7 @@ def test_failed_sync_api_exception(self, mock_executor, caplog): @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) def test_failed_sync_api(self, _, success_mock, fail_mock, mock_executor, mock_cmd): """Test what happens when ECS sync fails for certain tasks repeatedly.""" - airflow_key = "test-key" + airflow_key = TaskInstanceKey("dag", "task", "run", 1, -1) mock_executor.execute_async(airflow_key, mock_cmd) assert len(mock_executor.pending_tasks) == 1 @@ -1148,7 +1151,14 @@ def _mock_sync( @staticmethod def _add_mock_task(executor: AwsEcsExecutor, arn: str, state=TaskInstanceState.RUNNING): task = mock_task(arn, state) - executor.active_workers.add_task(task, mock.Mock(spec=tuple), mock_queue, mock_cmd, mock_config, 1) # type:ignore[arg-type] + executor.active_workers.add_task( + task, + mock.Mock(spec=TaskInstanceKey), + mock_queue, # type:ignore[arg-type] + mock_cmd, # type:ignore[arg-type] + mock_config, # type:ignore[arg-type] + 1, + ) def _sync_mock_with_call_counts(self, sync_func: Callable): """Mock won't work here, because we actually want to call the 'sync' func.""" diff --git a/providers/celery/tests/unit/celery/executors/test_celery_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_executor.py index 3dcd32e84ef0b..c11ea80a5baaf 100644 --- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py +++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py @@ -919,7 +919,7 @@ def test_process_workloads_routes_execute_callback(mock_send_workloads, callback executor = celery_executor.CeleryExecutor() executor._process_workloads([workload]) - mock_send_workloads.assert_called_once_with([(callback_id, workload, expected_queue, None)]) + mock_send_workloads.assert_called_once_with([(workload.callback.key, workload, expected_queue, None)]) @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="execute_workload is only used for Airflow 3+") diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py index 9269c4debd05f..2930eb7f2c42a 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py @@ -808,7 +808,7 @@ def test_invalid_executor_config(self, mock_get_kube_client, mock_kubernetes_job try: assert executor.event_buffer == {} executor.execute_async( - key=("dag", "task", timezone.utcnow(), 1), + key=TaskInstanceKey("dag", "task", "run_id", 1, -1), queue=None, command=["airflow", "tasks", "run", "true", "some_parameter"], executor_config=k8s.V1Pod(