diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4b38d403a3112..80f46592710e7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1681,6 +1681,7 @@ repos: ^airflow-core/src/airflow/models/baseoperator\.py$| ^airflow-core/src/airflow/models/connection\.py$| ^airflow-core/src/airflow/models/dag\.py$| + ^airflow-core/src/airflow/models/deadline\.py$| ^airflow-core/src/airflow/models/dagbag\.py$| ^airflow-core/src/airflow/models/dagrun\.py$| ^airflow-core/src/airflow/models/mappedoperator\.py$| diff --git a/airflow-core/docs/img/airflow_erd.sha256 b/airflow-core/docs/img/airflow_erd.sha256 index 07e93bb4ceafa..8bfb32a3713df 100644 --- a/airflow-core/docs/img/airflow_erd.sha256 +++ b/airflow-core/docs/img/airflow_erd.sha256 @@ -1 +1 @@ -efbae2f1de68413e5a6f671a306e748581fe454b81e25eeb2927567f11ebd59c \ No newline at end of file +625f362919679fe85b2bfb3f9e053261248ac6ec7a974eee51012e55a8105b94 \ No newline at end of file diff --git a/airflow-core/docs/img/airflow_erd.svg b/airflow-core/docs/img/airflow_erd.svg index 537b1e5f6e3fd..e6776c5f1538d 100644 --- a/airflow-core/docs/img/airflow_erd.svg +++ b/airflow-core/docs/img/airflow_erd.svg @@ -453,24 +453,24 @@ asset_trigger - -asset_trigger - -asset_id - - [INTEGER] - NOT NULL - -trigger_id - - [INTEGER] - NOT NULL + +asset_trigger + +asset_id + + [INTEGER] + NOT NULL + +trigger_id + + [INTEGER] + NOT NULL asset--asset_trigger - -0..N + +0..N 1 @@ -692,25 +692,25 @@ dagrun_asset_event - -dagrun_asset_event - -dag_run_id - - [INTEGER] - NOT NULL - -event_id - - [INTEGER] - NOT NULL + +dagrun_asset_event + +dag_run_id + + [INTEGER] + NOT NULL + +event_id + + [INTEGER] + NOT NULL asset_event--dagrun_asset_event - -0..N -1 + +0..N +1 @@ -745,9 +745,9 @@ trigger--asset_trigger - -0..N -1 + +0..N +1 @@ -922,50 +922,46 @@ deadline - -deadline - -id - - [UUID] - NOT NULL - -callback - - [VARCHAR(500)] - NOT NULL - -callback_kwargs - - [JSON] - -callback_state - - [VARCHAR(20)] - -dag_id - - [VARCHAR(250)] - -dagrun_id - - [INTEGER] - -deadline_time - - [TIMESTAMP] - NOT NULL - -trigger_id - - [INTEGER] + +deadline + +id + + [UUID] + NOT NULL + +callback + + [JSON] + NOT NULL + +callback_state + + [VARCHAR(20)] + +dag_id + + [VARCHAR(250)] + +dagrun_id + + [INTEGER] + +deadline_time + + [TIMESTAMP] + NOT NULL + +trigger_id + + [INTEGER] trigger--deadline - -0..N -{0,1} + +0..N +{0,1} @@ -1643,35 +1639,35 @@ {0,1} - + dag--dag_schedule_asset_alias_reference 0..N 1 - + dag--dag_schedule_asset_reference 0..N 1 - + dag--task_outlet_asset_reference 0..N 1 - + dag--task_inlet_asset_reference 0..N 1 - + dag--asset_dag_run_queue 0..N @@ -1680,12 +1676,58 @@ dag--deadline - -0..N + +0..N {0,1} - + +dag_version + +dag_version + +id + + [UUID] + NOT NULL + +bundle_name + + [VARCHAR(250)] + +bundle_version + + [VARCHAR(250)] + +created_at + + [TIMESTAMP] + NOT NULL + +dag_id + + [VARCHAR(250)] + NOT NULL + +last_updated + + [TIMESTAMP] + NOT NULL + +version_number + + [INTEGER] + NOT NULL + + + +dag--dag_version + +0..N +1 + + + dag_schedule_asset_name_reference dag_schedule_asset_name_reference @@ -1706,14 +1748,14 @@ NOT NULL - + dag--dag_schedule_asset_name_reference 0..N 1 - + dag_schedule_asset_uri_reference dag_schedule_asset_uri_reference @@ -1734,58 +1776,12 @@ NOT NULL - + dag--dag_schedule_asset_uri_reference 0..N 1 - - -dag_version - -dag_version - -id - - [UUID] - NOT NULL - -bundle_name - - [VARCHAR(250)] - -bundle_version - - [VARCHAR(250)] - -created_at - - [TIMESTAMP] - NOT NULL - -dag_id - - [VARCHAR(250)] - NOT NULL - -last_updated - - [TIMESTAMP] - NOT NULL - -version_number - - [INTEGER] - NOT NULL - - - -dag--dag_version - -0..N -1 - dag_tag @@ -2131,9 +2127,9 @@ dag_run--dagrun_asset_event - -0..N -1 + +0..N +1 @@ -2152,8 +2148,8 @@ dag_run--deadline - -0..N + +0..N {0,1} diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst index cf9dc37553978..a27474fdead69 100644 --- a/airflow-core/docs/migrations-ref.rst +++ b/airflow-core/docs/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``3bda03debd04`` (head) | ``f56f68b9e02f`` | ``3.1.0`` | Add url template and template params to DagBundleModel. | +| ``808787349f22`` (head) | ``3bda03debd04`` | ``3.1.0`` | Modify deadline's callback schema. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``3bda03debd04`` | ``f56f68b9e02f`` | ``3.1.0`` | Add url template and template params to DagBundleModel. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``f56f68b9e02f`` | ``09fa89ba1710`` | ``3.1.0`` | Add callback_state to deadline. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/airflow-core/src/airflow/migrations/versions/0081_3_1_0_modify_deadline_callback_schema.py b/airflow-core/src/airflow/migrations/versions/0081_3_1_0_modify_deadline_callback_schema.py new file mode 100644 index 0000000000000..0d5431d663f52 --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0081_3_1_0_modify_deadline_callback_schema.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Modify deadline's callback schema. + +Revision ID: 808787349f22 +Revises: 3bda03debd04 +Create Date: 2025-07-31 19:35:53.150465 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +import sqlalchemy_jsonfield +from alembic import op + +# revision identifiers, used by Alembic. +revision = "808787349f22" +down_revision = "3bda03debd04" +branch_labels = None +depends_on = None +airflow_version = "3.1.0" + + +def upgrade(): + """Replace deadline table's string callback and JSON callback_kwargs with JSON callback.""" + with op.batch_alter_table("deadline", schema=None) as batch_op: + batch_op.drop_column("callback") + batch_op.drop_column("callback_kwargs") + batch_op.add_column(sa.Column("callback", sqlalchemy_jsonfield.jsonfield.JSONField(), nullable=False)) + + +def downgrade(): + """Replace deadline table's JSON callback with string callback and JSON callback_kwargs.""" + with op.batch_alter_table("deadline", schema=None) as batch_op: + batch_op.drop_column("callback") + batch_op.add_column( + sa.Column("callback_kwargs", sqlalchemy_jsonfield.jsonfield.JSONField(), nullable=True) + ) + batch_op.add_column(sa.Column("callback", sa.String(length=500), nullable=False)) diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py index 5d68ad226d208..d9cbb94c744bc 100644 --- a/airflow-core/src/airflow/models/dag.py +++ b/airflow-core/src/airflow/models/dag.py @@ -1605,7 +1605,6 @@ def create_dagrun( run_id=run_id, ), callback=self.deadline.callback, - callback_kwargs=self.deadline.callback_kwargs or {}, dag_id=self.dag_id, dagrun_id=orm_dagrun.id, ) diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py index 2a0924f91d29f..c52c1fb785672 100644 --- a/airflow-core/src/airflow/models/deadline.py +++ b/airflow-core/src/airflow/models/deadline.py @@ -21,7 +21,8 @@ from dataclasses import dataclass from datetime import datetime, timedelta from enum import Enum -from typing import TYPE_CHECKING, Any +from functools import cached_property +from typing import TYPE_CHECKING, Any, cast import sqlalchemy_jsonfield import uuid6 @@ -33,6 +34,7 @@ from airflow._shared.timezones import timezone from airflow.models import Trigger from airflow.models.base import Base, StringID +from airflow.serialization.serde import deserialize, serialize from airflow.settings import json from airflow.triggers.deadline import PAYLOAD_STATUS_KEY, DeadlineCallbackTrigger from airflow.utils.log.logging_mixin import LoggingMixin @@ -42,6 +44,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session + from airflow.sdk.definitions.deadline import Callback from airflow.triggers.base import TriggerEvent @@ -76,7 +79,11 @@ def __get__(self, instance, cls=None): class DeadlineCallbackState(str, Enum): - """All possible states of deadline callbacks.""" + """ + All possible states of deadline callbacks once the deadline is missed. + + `None` state implies that the deadline is pending (`deadline_time` hasn't passed yet). + """ QUEUED = "queued" SUCCESS = "success" @@ -96,10 +103,8 @@ class Deadline(Base): # The time after which the Deadline has passed and the callback should be triggered. deadline_time = Column(UtcDateTime, nullable=False) - # The Callback to be called when the Deadline has passed. - callback = Column(String(500), nullable=False) - # Serialized kwargs to pass to the callback. - callback_kwargs = Column(sqlalchemy_jsonfield.JSONField(json=json)) + # The (serialized) callback to be called when the Deadline has passed. + _callback = Column("callback", sqlalchemy_jsonfield.JSONField(json=json), nullable=False) # The state of the deadline callback callback_state = Column(String(20)) @@ -114,15 +119,13 @@ class Deadline(Base): def __init__( self, deadline_time: datetime, - callback: str, - callback_kwargs: dict | None = None, + callback: Callback, dag_id: str | None = None, dagrun_id: int | None = None, ): super().__init__() self.deadline_time = deadline_time - self.callback = callback - self.callback_kwargs = callback_kwargs + self._callback = serialize(callback) self.dag_id = dag_id self.dagrun_id = dagrun_id @@ -136,11 +139,10 @@ def _determine_resource() -> tuple[str, str]: return "Unknown", "" resource_type, resource_details = _determine_resource() - callback_kwargs = json.dumps(self.callback_kwargs) if self.callback_kwargs else "" return ( f"[{resource_type} Deadline] {resource_details} needed by " - f"{self.deadline_time} or run: {self.callback}({callback_kwargs})" + f"{self.deadline_time} or run: {self.callback.path}({self.callback.kwargs or ''})" ) @classmethod @@ -193,18 +195,30 @@ def prune_deadlines(cls, *, session: Session, conditions: dict[Column, Any]) -> return deleted_count + @cached_property + def callback(self) -> Callback: + return cast("Callback", deserialize(self._callback)) + def handle_miss(self, session: Session): - """Handle a missed deadline by creating a trigger to run the callback.""" - # TODO: check to see if the callback is meant to run in triggerer or executor. For now, the code below assumes it's for the triggerer - callback_trigger = DeadlineCallbackTrigger( - callback_path=self.callback, - callback_kwargs=self.callback_kwargs, - ) + """Handle a missed deadline by running the callback in the appropriate host and updating the `callback_state`.""" + from airflow.sdk.definitions.deadline import AsyncCallback, SyncCallback + + if isinstance(self.callback, AsyncCallback): + callback_trigger = DeadlineCallbackTrigger( + callback_path=self.callback.path, + callback_kwargs=self.callback.kwargs, + ) + trigger_orm = Trigger.from_object(callback_trigger) + session.add(trigger_orm) + session.flush() + self.trigger = trigger_orm + + elif isinstance(self.callback, SyncCallback): + raise NotImplementedError("SyncCallback is currently not supported") + + else: + raise TypeError("Unknown Callback type") - trigger_orm = Trigger.from_object(callback_trigger) - session.add(trigger_orm) - session.flush() - self.trigger_id = trigger_orm.id self.callback_state = DeadlineCallbackState.QUEUED session.add(self) diff --git a/airflow-core/src/airflow/triggers/deadline.py b/airflow-core/src/airflow/triggers/deadline.py index d04910dd8b8ef..4229695854d10 100644 --- a/airflow-core/src/airflow/triggers/deadline.py +++ b/airflow-core/src/airflow/triggers/deadline.py @@ -22,7 +22,7 @@ from typing import Any from airflow.triggers.base import BaseTrigger, TriggerEvent -from airflow.utils.module_loading import import_string +from airflow.utils.module_loading import import_string, qualname log = logging.getLogger(__name__) @@ -40,8 +40,8 @@ def __init__(self, callback_path: str, callback_kwargs: dict[str, Any] | None = def serialize(self) -> tuple[str, dict[str, Any]]: return ( - f"{type(self).__module__}.{type(self).__qualname__}", - {"callback_path": self.callback_path, "callback_kwargs": self.callback_kwargs}, + qualname(self), + {attr: getattr(self, attr) for attr in ("callback_path", "callback_kwargs")}, ) async def run(self) -> AsyncIterator[TriggerEvent]: diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 4cf633360dab6..311198e5f0b95 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -93,7 +93,7 @@ class MappedClassProtocol(Protocol): "2.10.3": "5f2621c13b39", "3.0.0": "29ce7909c52b", "3.0.3": "fe199e1abd77", - "3.1.0": "3bda03debd04", + "3.1.0": "808787349f22", } diff --git a/airflow-core/src/airflow/utils/module_loading.py b/airflow-core/src/airflow/utils/module_loading.py index 49f9db05c9df8..028b5d31103ec 100644 --- a/airflow-core/src/airflow/utils/module_loading.py +++ b/airflow-core/src/airflow/utils/module_loading.py @@ -53,6 +53,7 @@ def import_string(dotted_path: str): Raise ImportError if the import failed. """ + # TODO: Add support for nested classes. Currently, it only works for top-level classes. try: module_path, class_name = dotted_path.rsplit(".", 1) except ValueError: diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index 7f2f4b01394f5..55fb1a3cf27d9 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -70,7 +70,7 @@ from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext from airflow.sdk.definitions._internal.templater import NativeEnvironment, SandboxedEnvironment from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny -from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference +from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert, DeadlineReference from airflow.sdk.definitions.param import Param from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import ( @@ -112,6 +112,11 @@ repo_root = Path(__file__).parents[2] +async def empty_callback_for_deadline(): + """Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly.""" + pass + + @pytest.fixture def clear_dags(): clear_db_dags() @@ -2024,7 +2029,7 @@ def test_dagrun_deadline(self, reference_type, reference_column, dag_maker, sess deadline=DeadlineAlert( reference=reference_type, interval=interval, - callback=print, + callback=AsyncCallback(empty_callback_for_deadline), ), ) as dag: ... diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index ced147148d70c..24c6cf87ab156 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -44,7 +44,7 @@ from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator from airflow.sdk import BaseOperator, setup, task, task_group, teardown -from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference +from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert, DeadlineReference from airflow.serialization.serialized_objects import SerializedDAG from airflow.stats import Stats from airflow.triggers.base import StartTriggerArgs @@ -69,7 +69,7 @@ DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE) -def test_callback_for_deadline(): +async def empty_callback_for_deadline(): """Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly.""" pass @@ -1294,7 +1294,7 @@ def on_success_callable(context): deadline=DeadlineAlert( reference=DeadlineReference.FIXED_DATETIME(future_date), interval=datetime.timedelta(hours=1), - callback=test_callback_for_deadline, + callback=AsyncCallback(empty_callback_for_deadline), ), ) as dag: ... diff --git a/airflow-core/tests/unit/models/test_deadline.py b/airflow-core/tests/unit/models/test_deadline.py index ef656f7455ae8..d1a7949951f69 100644 --- a/airflow-core/tests/unit/models/test_deadline.py +++ b/airflow-core/tests/unit/models/test_deadline.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import json from datetime import datetime, timedelta from unittest import mock @@ -28,7 +27,7 @@ from airflow.models import DagRun, Trigger from airflow.models.deadline import Deadline, DeadlineCallbackState, ReferenceModels, _fetch_from_db from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.sdk.definitions.deadline import DeadlineReference +from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineReference, SyncCallback from airflow.triggers.base import TriggerEvent from airflow.triggers.deadline import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY from airflow.utils.state import DagRunState @@ -55,6 +54,8 @@ async def callback_for_deadline(): TEST_CALLBACK_PATH = f"{__name__}.{callback_for_deadline.__name__}" TEST_CALLBACK_KWARGS = {"arg1": "value1"} +TEST_ASYNC_CALLBACK = AsyncCallback(TEST_CALLBACK_PATH, kwargs=TEST_CALLBACK_KWARGS) +TEST_SYNC_CALLBACK = SyncCallback(TEST_CALLBACK_PATH, kwargs=TEST_CALLBACK_KWARGS) def _clean_db(): @@ -90,8 +91,7 @@ def test_add_deadline(self, dagrun, session): assert session.query(Deadline).count() == 0 deadline_orm = Deadline( deadline_time=DEFAULT_DATE, - callback=TEST_CALLBACK_PATH, - callback_kwargs=TEST_CALLBACK_KWARGS, + callback=TEST_ASYNC_CALLBACK, dag_id=DAG_ID, dagrun_id=dagrun.id, ) @@ -106,7 +106,6 @@ def test_add_deadline(self, dagrun, session): assert result.dagrun_id == deadline_orm.dagrun_id assert result.deadline_time == deadline_orm.deadline_time assert result.callback == deadline_orm.callback - assert result.callback_kwargs == deadline_orm.callback_kwargs @pytest.mark.parametrize( "conditions", @@ -145,23 +144,20 @@ def test_prune_deadlines(self, mock_session, conditions): def test_orm(self): deadline_orm = Deadline( deadline_time=DEFAULT_DATE, - callback=TEST_CALLBACK_PATH, - callback_kwargs=TEST_CALLBACK_KWARGS, + callback=TEST_ASYNC_CALLBACK, dag_id=DAG_ID, dagrun_id=RUN_ID, ) assert deadline_orm.deadline_time == DEFAULT_DATE - assert deadline_orm.callback == TEST_CALLBACK_PATH - assert deadline_orm.callback_kwargs == TEST_CALLBACK_KWARGS + assert deadline_orm.callback == TEST_ASYNC_CALLBACK assert deadline_orm.dag_id == DAG_ID assert deadline_orm.dagrun_id == RUN_ID def test_repr_with_callback_kwargs(self): deadline_orm = Deadline( deadline_time=DEFAULT_DATE, - callback=TEST_CALLBACK_PATH, - callback_kwargs=TEST_CALLBACK_KWARGS, + callback=TEST_ASYNC_CALLBACK, dag_id=DAG_ID, dagrun_id=RUN_ID, ) @@ -169,18 +165,18 @@ def test_repr_with_callback_kwargs(self): assert ( repr(deadline_orm) == f"[DagRun Deadline] Dag: {deadline_orm.dag_id} Run: {deadline_orm.dagrun_id} needed by " - f"{deadline_orm.deadline_time} or run: {TEST_CALLBACK_PATH}({json.dumps(deadline_orm.callback_kwargs)})" + f"{deadline_orm.deadline_time} or run: {TEST_CALLBACK_PATH}({TEST_CALLBACK_KWARGS})" ) def test_repr_without_callback_kwargs(self): deadline_orm = Deadline( deadline_time=DEFAULT_DATE, - callback=TEST_CALLBACK_PATH, + callback=AsyncCallback(TEST_CALLBACK_PATH), dag_id=DAG_ID, dagrun_id=RUN_ID, ) - assert deadline_orm.callback_kwargs is None + assert deadline_orm.callback.kwargs is None assert ( repr(deadline_orm) == f"[DagRun Deadline] Dag: {deadline_orm.dag_id} Run: {deadline_orm.dagrun_id} needed by " @@ -191,8 +187,7 @@ def test_repr_without_callback_kwargs(self): def test_handle_miss_async_callback(self, dagrun, session): deadline_orm = Deadline( deadline_time=DEFAULT_DATE, - callback=TEST_CALLBACK_PATH, - callback_kwargs=TEST_CALLBACK_KWARGS, + callback=TEST_ASYNC_CALLBACK, dag_id=DAG_ID, dagrun_id=dagrun.id, ) @@ -209,6 +204,22 @@ def test_handle_miss_async_callback(self, dagrun, session): assert trigger.kwargs["callback_path"] == TEST_CALLBACK_PATH assert trigger.kwargs["callback_kwargs"] == TEST_CALLBACK_KWARGS + @pytest.mark.db_test + def test_handle_miss_sync_callback(self, dagrun, session): + deadline_orm = Deadline( + deadline_time=DEFAULT_DATE, + callback=TEST_SYNC_CALLBACK, + dag_id=DAG_ID, + dagrun_id=dagrun.id, + ) + session.add(deadline_orm) + session.flush() + + with pytest.raises(NotImplementedError): + deadline_orm.handle_miss(session=session) + session.flush() + assert deadline_orm.trigger_id is None + @pytest.mark.db_test @pytest.mark.parametrize( "event, none_trigger_expected", @@ -238,7 +249,7 @@ def test_handle_miss_async_callback(self, dagrun, session): def test_handle_callback_event(self, dagrun, session, event, none_trigger_expected): deadline_orm = Deadline( deadline_time=DEFAULT_DATE, - callback=TEST_CALLBACK_PATH, + callback=TEST_ASYNC_CALLBACK, dag_id=DAG_ID, dagrun_id=dagrun.id, ) diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py index d54cb3eac799a..cf56a42211fa3 100644 --- a/airflow-core/tests/unit/serialization/test_serialized_objects.py +++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py @@ -58,7 +58,12 @@ AssetUniqueKey, AssetWatcher, ) -from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineAlertFields, DeadlineReference +from airflow.sdk.definitions.deadline import ( + AsyncCallback, + DeadlineAlert, + DeadlineAlertFields, + DeadlineReference, +) from airflow.sdk.definitions.decorators import task from airflow.sdk.definitions.param import Param from airflow.sdk.definitions.taskgroup import TaskGroup @@ -76,7 +81,7 @@ DAG_ID = "dag_id_1" -TEST_CALLBACK_PATH = f"{__name__}.test_callback_for_deadline" +TEST_CALLBACK_PATH = f"{__name__}.empty_callback_for_deadline" TEST_CALLBACK_KWARGS = {"arg1": "value1"} REFERENCE_TYPES = [ @@ -86,7 +91,7 @@ ] -def test_callback_for_deadline(): +async def empty_callback_for_deadline(): """Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly.""" pass @@ -364,7 +369,7 @@ def __len__(self) -> int: DeadlineAlert( reference=DeadlineReference.DAGRUN_LOGICAL_DATE, interval=timedelta(), - callback="fake_callable", + callback=AsyncCallback("fake_callable"), ), None, None, @@ -410,8 +415,7 @@ def __len__(self) -> int: DeadlineAlert( reference=DeadlineReference.DAGRUN_QUEUED_AT, interval=timedelta(hours=1), - callback="valid.callback.path", - callback_kwargs={"arg1": "value1"}, + callback=AsyncCallback("valid.callback.path", kwargs={"arg1": "value1"}), ), DAT.DEADLINE_ALERT, equals, @@ -445,8 +449,7 @@ def test_serialize_deserialize_deadline_alert(reference): original = DeadlineAlert( reference=reference, interval=timedelta(hours=1), - callback=test_callback_for_deadline, - callback_kwargs=TEST_CALLBACK_KWARGS, + callback=AsyncCallback(empty_callback_for_deadline, kwargs=TEST_CALLBACK_KWARGS), ) serialized = original.serialize_deadline_alert() @@ -456,9 +459,7 @@ def test_serialize_deserialize_deadline_alert(reference): deserialized = DeadlineAlert.deserialize_deadline_alert(serialized) assert deserialized.reference.serialize_reference() == reference.serialize_reference() assert deserialized.interval == original.interval - assert deserialized.callback_kwargs == original.callback_kwargs - assert isinstance(deserialized.callback, str) - assert deserialized.callback == TEST_CALLBACK_PATH + assert deserialized.callback == original.callback @pytest.mark.parametrize( diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py b/task-sdk/src/airflow/sdk/definitions/deadline.py index 05d73e1ae4a99..8e1e67ae08118 100644 --- a/task-sdk/src/airflow/sdk/definitions/deadline.py +++ b/task-sdk/src/airflow/sdk/definitions/deadline.py @@ -16,12 +16,16 @@ # under the License. from __future__ import annotations +import inspect import logging +from abc import ABC from collections.abc import Callable from datetime import datetime, timedelta +from typing import Any, cast from airflow.models.deadline import DeadlineReferenceType, ReferenceModels from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding +from airflow.serialization.serde import deserialize, serialize from airflow.utils.module_loading import import_string, is_valid_dotpath logger = logging.getLogger(__name__) @@ -38,7 +42,6 @@ class DeadlineAlertFields: REFERENCE = "reference" INTERVAL = "interval" CALLBACK = "callback" - CALLBACK_KWARGS = "callback_kwargs" class DeadlineAlert: @@ -48,13 +51,11 @@ def __init__( self, reference: DeadlineReferenceType, interval: timedelta, - callback: Callable | str, - callback_kwargs: dict | None = None, + callback: Callback, ): self.reference = reference self.interval = interval - self.callback_kwargs = callback_kwargs or {} - self.callback = self.get_callback_path(callback) + self.callback = callback def __eq__(self, other: object) -> bool: if not isinstance(other, DeadlineAlert): @@ -63,7 +64,6 @@ def __eq__(self, other: object) -> bool: isinstance(self.reference, type(other.reference)) and self.interval == other.interval and self.callback == other.callback - and self.callback_kwargs == other.callback_kwargs ) def __hash__(self) -> int: @@ -72,14 +72,64 @@ def __hash__(self) -> int: type(self.reference).__name__, self.interval, self.callback, - tuple(sorted(self.callback_kwargs.items())) if self.callback_kwargs else None, ) ) - @staticmethod - def get_callback_path(_callback: str | Callable) -> str: + def serialize_deadline_alert(self): + """Return the data in a format that BaseSerialization can handle.""" + return { + Encoding.TYPE: DAT.DEADLINE_ALERT, + Encoding.VAR: { + DeadlineAlertFields.REFERENCE: self.reference.serialize_reference(), + DeadlineAlertFields.INTERVAL: self.interval.total_seconds(), + DeadlineAlertFields.CALLBACK: serialize(self.callback), + }, + } + + @classmethod + def deserialize_deadline_alert(cls, encoded_data: dict) -> DeadlineAlert: + """Deserialize a DeadlineAlert from serialized data.""" + data = encoded_data.get(Encoding.VAR, encoded_data) + + reference_data = data[DeadlineAlertFields.REFERENCE] + reference_type = reference_data[ReferenceModels.REFERENCE_TYPE_FIELD] + + reference_class = ReferenceModels.get_reference_class(reference_type) + reference = reference_class.deserialize_reference(reference_data) + + return cls( + reference=reference, + interval=timedelta(seconds=data[DeadlineAlertFields.INTERVAL]), + callback=cast("Callback", deserialize(data[DeadlineAlertFields.CALLBACK])), + ) + + +class Callback(ABC): + """ + Base class for Deadline Alert callbacks. + + Callbacks are used to execute custom logic when a deadline is missed. + + The `callback_callable` can be a Python callable type or a string containing the path to the callable that + can be used to import the callable. It must be a top-level callable in a module present on the host where + it will run. + + It will be called with Airflow context and specified kwargs when a deadline is missed. + """ + + path: str + kwargs: dict | None + + def __init__(self, callback_callable: Callable | str, kwargs: dict | None = None): + self.path = self.get_callback_path(callback_callable) + self.kwargs = kwargs + + @classmethod + def get_callback_path(cls, _callback: str | Callable) -> str: """Convert callback to a string path that can be used to import it later.""" if callable(_callback): + cls.verify_callable(_callback) + # TODO: This implementation doesn't support using a lambda function as a callback. # We should consider that in the future, but the addition is non-trivial. # Get the reference path to the callable in the form `airflow.models.deadline.get_from_db` @@ -96,6 +146,9 @@ def get_callback_path(_callback: str | Callable) -> str: if not callable(callback): # The input is a string which can be imported, but is not callable. raise AttributeError(f"Provided callback {callback} is not callable.") + + cls.verify_callable(callback) + except ImportError as e: # Logging here instead of failing because it is possible that the code for the callable # exists somewhere other than on the DAG processor. We are making a best effort to validate, @@ -108,35 +161,80 @@ def get_callback_path(_callback: str | Callable) -> str: return stripped_callback - def serialize_deadline_alert(self): - """Return the data in a format that BaseSerialization can handle.""" - return { - Encoding.TYPE: DAT.DEADLINE_ALERT, - Encoding.VAR: { - DeadlineAlertFields.REFERENCE: self.reference.serialize_reference(), - DeadlineAlertFields.INTERVAL: self.interval.total_seconds(), - DeadlineAlertFields.CALLBACK: self.callback, # Already stored as a string path - DeadlineAlertFields.CALLBACK_KWARGS: self.callback_kwargs, - }, - } + @classmethod + def verify_callable(cls, callback: Callable): + """For additional verification of the callable during initialization in subclasses.""" + pass # No verification needed in the base class @classmethod - def deserialize_deadline_alert(cls, encoded_data: dict) -> DeadlineAlert: - """Deserialize a DeadlineAlert from serialized data.""" - data = encoded_data.get(Encoding.VAR, encoded_data) + def deserialize(cls, data: dict, version): + path = data.pop("path") + return cls(callback_callable=path, **data) - reference_data = data[DeadlineAlertFields.REFERENCE] - reference_type = reference_data[ReferenceModels.REFERENCE_TYPE_FIELD] + @classmethod + def serialized_fields(cls) -> tuple[str, ...]: + return ("path", "kwargs") - reference_class = ReferenceModels.get_reference_class(reference_type) - reference = reference_class.deserialize_reference(reference_data) + def serialize(self) -> dict[str, Any]: + return {f: getattr(self, f) for f in self.serialized_fields()} - return cls( - reference=reference, - interval=timedelta(seconds=data[DeadlineAlertFields.INTERVAL]), - callback=data[DeadlineAlertFields.CALLBACK], # Keep as string path - callback_kwargs=data[DeadlineAlertFields.CALLBACK_KWARGS], - ) + def __eq__(self, other): + if type(self) is not type(other): + return NotImplemented + return self.serialize() == other.serialize() + + def __hash__(self): + serialized = self.serialize() + hashable_items = [] + for k, v in serialized.items(): + if isinstance(v, dict) and v: + hashable_items.append((k, tuple(sorted(v.items())))) + else: + hashable_items.append((k, v)) + return hash(tuple(sorted(hashable_items))) + + +class AsyncCallback(Callback): + """ + Asynchronous callback that runs in the triggerer. + + The `callback_callable` can be a Python callable type or a string containing the path to the callable that + can be used to import the callable. It must be a top-level awaitable callable in a module present on the + triggerer. + + It will be called with Airflow context and specified kwargs when a deadline is missed. + """ + + def __init__(self, callback_callable: Callable | str, kwargs: dict | None = None): + super().__init__(callback_callable=callback_callable, kwargs=kwargs) + + @classmethod + def verify_callable(cls, callback: Callable): + if not (inspect.iscoroutinefunction(callback) or hasattr(callback, "__await__")): + raise AttributeError(f"Provided callback {callback} is not awaitable.") + + +class SyncCallback(Callback): + """ + Synchronous callback that runs in the specified or default executor. + + The `callback_callable` can be a Python callable type or a string containing the path to the callable that + can be used to import the callable. It must be a top-level callable in a module present on the executor. + + It will be called with Airflow context and specified kwargs when a deadline is missed. + """ + + executor: str | None + + def __init__( + self, callback_callable: Callable | str, kwargs: dict | None = None, executor: str | None = None + ): + super().__init__(callback_callable=callback_callable, kwargs=kwargs) + self.executor = executor + + @classmethod + def serialized_fields(cls) -> tuple[str, ...]: + return super().serialized_fields() + ("executor",) class DeadlineReference: diff --git a/task-sdk/tests/task_sdk/definitions/test_deadline.py b/task-sdk/tests/task_sdk/definitions/test_deadline.py index b62ce1c754e68..761a86ae0a9a2 100644 --- a/task-sdk/tests/task_sdk/definitions/test_deadline.py +++ b/task-sdk/tests/task_sdk/definitions/test_deadline.py @@ -17,10 +17,19 @@ from __future__ import annotations from datetime import datetime, timedelta +from typing import cast import pytest -from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference +from airflow.sdk.definitions.deadline import ( + AsyncCallback, + Callback, + DeadlineAlert, + DeadlineReference, + SyncCallback, +) +from airflow.serialization.serde import deserialize, serialize +from airflow.utils.module_loading import qualname UNIMPORTABLE_DOT_PATH = "valid.but.nonexistent.path" @@ -28,9 +37,6 @@ RUN_ID = 1 DEFAULT_DATE = datetime(2025, 6, 26) -TEST_CALLBACK_PATH = f"{__name__}.test_callback_for_deadline" -TEST_CALLBACK_KWARGS = {"arg1": "value1"} - REFERENCE_TYPES = [ pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, id="logical_date"), pytest.param(DeadlineReference.DAGRUN_QUEUED_AT, id="queued_at"), @@ -38,47 +44,22 @@ ] -def test_callback_for_deadline(): +async def empty_async_callback_for_deadline_tests(): """Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly.""" pass -class TestDeadlineAlert: - @pytest.mark.parametrize( - "callback_value, expected_path", - [ - pytest.param(test_callback_for_deadline, TEST_CALLBACK_PATH, id="valid_callable"), - pytest.param(TEST_CALLBACK_PATH, TEST_CALLBACK_PATH, id="valid_path_string"), - pytest.param(lambda x: x, None, id="lambda_function"), - pytest.param(TEST_CALLBACK_PATH + " ", TEST_CALLBACK_PATH, id="path_with_whitespace"), - pytest.param(UNIMPORTABLE_DOT_PATH, UNIMPORTABLE_DOT_PATH, id="valid_format_not_importable"), - ], - ) - def test_get_callback_path_happy_cases(self, callback_value, expected_path): - path = DeadlineAlert.get_callback_path(callback_value) - if expected_path is None: - assert path.endswith("") - else: - assert path == expected_path +def empty_sync_callback_for_deadline_tests(): + """Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly.""" + pass - @pytest.mark.parametrize( - "callback_value, error_type", - [ - pytest.param(42, ImportError, id="not_a_string"), - pytest.param("", ImportError, id="empty_string"), - pytest.param("os.path", AttributeError, id="non_callable_module"), - ], - ) - def test_get_callback_path_error_cases(self, callback_value, error_type): - expected_message = "" - if error_type is ImportError: - expected_message = "doesn't look like a valid dot path." - elif error_type is AttributeError: - expected_message = "is not callable." - with pytest.raises(error_type, match=expected_message): - DeadlineAlert.get_callback_path(callback_value) +TEST_CALLBACK_PATH = qualname(empty_async_callback_for_deadline_tests) +TEST_CALLBACK_KWARGS = {"arg1": "value1"} +TEST_DEADLINE_CALLBACK = AsyncCallback(TEST_CALLBACK_PATH, kwargs=TEST_CALLBACK_KWARGS) + +class TestDeadlineAlert: @pytest.mark.parametrize( "test_alert, should_equal", [ @@ -86,8 +67,7 @@ def test_get_callback_path_error_cases(self, callback_value, error_type): DeadlineAlert( reference=DeadlineReference.DAGRUN_QUEUED_AT, interval=timedelta(hours=1), - callback=TEST_CALLBACK_PATH, - callback_kwargs=TEST_CALLBACK_KWARGS, + callback=TEST_DEADLINE_CALLBACK, ), True, id="same_alert", @@ -96,8 +76,7 @@ def test_get_callback_path_error_cases(self, callback_value, error_type): DeadlineAlert( reference=DeadlineReference.DAGRUN_LOGICAL_DATE, interval=timedelta(hours=1), - callback=TEST_CALLBACK_PATH, - callback_kwargs=TEST_CALLBACK_KWARGS, + callback=TEST_DEADLINE_CALLBACK, ), False, id="different_reference", @@ -106,8 +85,7 @@ def test_get_callback_path_error_cases(self, callback_value, error_type): DeadlineAlert( reference=DeadlineReference.DAGRUN_QUEUED_AT, interval=timedelta(hours=2), - callback=TEST_CALLBACK_PATH, - callback_kwargs=TEST_CALLBACK_KWARGS, + callback=TEST_DEADLINE_CALLBACK, ), False, id="different_interval", @@ -116,8 +94,7 @@ def test_get_callback_path_error_cases(self, callback_value, error_type): DeadlineAlert( reference=DeadlineReference.DAGRUN_QUEUED_AT, interval=timedelta(hours=1), - callback="other.callback", - callback_kwargs=TEST_CALLBACK_KWARGS, + callback=AsyncCallback(UNIMPORTABLE_DOT_PATH, kwargs=TEST_CALLBACK_KWARGS), ), False, id="different_callback", @@ -126,8 +103,7 @@ def test_get_callback_path_error_cases(self, callback_value, error_type): DeadlineAlert( reference=DeadlineReference.DAGRUN_QUEUED_AT, interval=timedelta(hours=1), - callback=TEST_CALLBACK_PATH, - callback_kwargs={"arg2": "value2"}, + callback=AsyncCallback(TEST_CALLBACK_PATH, kwargs={"arg2": "value2"}), ), False, id="different_kwargs", @@ -139,8 +115,7 @@ def test_deadline_alert_equality(self, test_alert, should_equal): base_alert = DeadlineAlert( reference=DeadlineReference.DAGRUN_QUEUED_AT, interval=timedelta(hours=1), - callback=TEST_CALLBACK_PATH, - callback_kwargs=TEST_CALLBACK_KWARGS, + callback=TEST_DEADLINE_CALLBACK, ) assert (base_alert == test_alert) == should_equal @@ -153,14 +128,12 @@ def test_deadline_alert_hash(self): alert1 = DeadlineAlert( reference=DeadlineReference.DAGRUN_QUEUED_AT, interval=std_interval, - callback=std_callback, - callback_kwargs=std_kwargs, + callback=AsyncCallback(std_callback, kwargs=std_kwargs), ) alert2 = DeadlineAlert( reference=DeadlineReference.DAGRUN_QUEUED_AT, interval=std_interval, - callback=std_callback, - callback_kwargs=std_kwargs, + callback=AsyncCallback(std_callback, kwargs=std_kwargs), ) assert hash(alert1) == hash(alert1) @@ -174,19 +147,195 @@ def test_deadline_alert_in_set(self): alert1 = DeadlineAlert( reference=DeadlineReference.DAGRUN_QUEUED_AT, interval=std_interval, - callback=std_callback, - callback_kwargs=std_kwargs, + callback=AsyncCallback(std_callback, kwargs=std_kwargs), ) alert2 = DeadlineAlert( reference=DeadlineReference.DAGRUN_QUEUED_AT, interval=std_interval, - callback=std_callback, - callback_kwargs=std_kwargs, + callback=AsyncCallback(std_callback, kwargs=std_kwargs), ) alert_set = {alert1, alert2} assert len(alert_set) == 1 +class TestCallback: + @pytest.mark.parametrize( + "callback_callable, expected_path", + [ + pytest.param( + empty_sync_callback_for_deadline_tests, + qualname(empty_sync_callback_for_deadline_tests), + id="valid_sync_callable", + ), + pytest.param( + empty_async_callback_for_deadline_tests, + qualname(empty_async_callback_for_deadline_tests), + id="valid_async_callable", + ), + pytest.param(TEST_CALLBACK_PATH, TEST_CALLBACK_PATH, id="valid_path_string"), + pytest.param(lambda x: x, None, id="lambda_function"), + pytest.param(TEST_CALLBACK_PATH + " ", TEST_CALLBACK_PATH, id="path_with_whitespace"), + pytest.param(UNIMPORTABLE_DOT_PATH, UNIMPORTABLE_DOT_PATH, id="valid_format_not_importable"), + ], + ) + def test_get_callback_path_happy_cases(self, callback_callable, expected_path): + path = Callback.get_callback_path(callback_callable) + if expected_path is None: + assert path.endswith("") + else: + assert path == expected_path + + @pytest.mark.parametrize( + "callback_callable, error_type", + [ + pytest.param(42, ImportError, id="not_a_string"), + pytest.param("", ImportError, id="empty_string"), + pytest.param("os.path", AttributeError, id="non_callable_module"), + ], + ) + def test_get_callback_path_error_cases(self, callback_callable, error_type): + expected_message = "" + if error_type is ImportError: + expected_message = "doesn't look like a valid dot path." + elif error_type is AttributeError: + expected_message = "is not callable." + + with pytest.raises(error_type, match=expected_message): + Callback.get_callback_path(callback_callable) + + @pytest.mark.parametrize( + "callback1_args, callback2_args, should_equal", + [ + pytest.param( + (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + True, + id="identical", + ), + pytest.param( + (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + (UNIMPORTABLE_DOT_PATH, TEST_CALLBACK_KWARGS), + False, + id="different_path", + ), + pytest.param( + (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + (TEST_CALLBACK_PATH, {"other": "kwargs"}), + False, + id="different_kwargs", + ), + pytest.param((TEST_CALLBACK_PATH, None), (TEST_CALLBACK_PATH, None), True, id="both_no_kwargs"), + ], + ) + def test_callback_equality(self, callback1_args, callback2_args, should_equal): + callback1 = AsyncCallback(*callback1_args) + callback2 = AsyncCallback(*callback2_args) + assert (callback1 == callback2) == should_equal + + @pytest.mark.parametrize( + "callback_class, args1, args2, should_be_same_hash", + [ + pytest.param( + AsyncCallback, + (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + True, + id="async_identical", + ), + pytest.param( + SyncCallback, + (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + True, + id="sync_identical", + ), + pytest.param( + AsyncCallback, + (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + (UNIMPORTABLE_DOT_PATH, TEST_CALLBACK_KWARGS), + False, + id="async_different_path", + ), + pytest.param( + SyncCallback, + (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + (TEST_CALLBACK_PATH, {"other": "kwargs"}), + False, + id="sync_different_kwargs", + ), + pytest.param( + AsyncCallback, + (TEST_CALLBACK_PATH, None), + (TEST_CALLBACK_PATH, None), + True, + id="async_no_kwargs", + ), + ], + ) + def test_callback_hash_and_set_behavior(self, callback_class, args1, args2, should_be_same_hash): + callback1 = callback_class(*args1) + callback2 = callback_class(*args2) + assert (hash(callback1) == hash(callback2)) == should_be_same_hash + + +class TestAsyncCallback: + @pytest.mark.parametrize( + "callback_callable, kwargs, expected_path", + [ + pytest.param( + empty_async_callback_for_deadline_tests, + TEST_CALLBACK_KWARGS, + TEST_CALLBACK_PATH, + id="callable", + ), + pytest.param(TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS, TEST_CALLBACK_PATH, id="string_path"), + pytest.param( + UNIMPORTABLE_DOT_PATH, TEST_CALLBACK_KWARGS, UNIMPORTABLE_DOT_PATH, id="unimportable_path" + ), + ], + ) + def test_init(self, callback_callable, kwargs, expected_path): + callback = AsyncCallback(callback_callable, kwargs=kwargs) + assert callback.path == expected_path + assert callback.kwargs == kwargs + assert isinstance(callback, Callback) + + def test_init_error(self): + with pytest.raises(AttributeError, match="is not awaitable."): + AsyncCallback(empty_sync_callback_for_deadline_tests) + + def test_serialize_deserialize(self): + callback = AsyncCallback(TEST_CALLBACK_PATH, kwargs=TEST_CALLBACK_KWARGS) + serialized = serialize(callback) + deserialized = cast("Callback", deserialize(serialized.copy())) + assert callback == deserialized + + +class TestSyncCallback: + @pytest.mark.parametrize( + "callback_callable, executor", + [ + pytest.param(empty_sync_callback_for_deadline_tests, "remote", id="with_executor"), + pytest.param(empty_sync_callback_for_deadline_tests, None, id="without_executor"), + pytest.param(qualname(empty_sync_callback_for_deadline_tests), None, id="importable_path"), + pytest.param(UNIMPORTABLE_DOT_PATH, None, id="unimportable_path"), + ], + ) + def test_init(self, callback_callable, executor): + callback = SyncCallback(TEST_CALLBACK_PATH, kwargs=TEST_CALLBACK_KWARGS, executor=executor) + + assert callback.path == TEST_CALLBACK_PATH + assert callback.kwargs == TEST_CALLBACK_KWARGS + assert callback.executor == executor + assert isinstance(callback, Callback) + + def test_serialize_deserialize(self): + callback = SyncCallback(TEST_CALLBACK_PATH, kwargs=TEST_CALLBACK_KWARGS, executor="local") + serialized = serialize(callback) + deserialized = cast("Callback", deserialize(serialized.copy())) + assert callback == deserialized + + # While DeadlineReference lives in the SDK package, the unit tests to confirm it # works need database access so they live in the models/test_deadline.py module.