diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4b38d403a3112..80f46592710e7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1681,6 +1681,7 @@ repos:
^airflow-core/src/airflow/models/baseoperator\.py$|
^airflow-core/src/airflow/models/connection\.py$|
^airflow-core/src/airflow/models/dag\.py$|
+ ^airflow-core/src/airflow/models/deadline\.py$|
^airflow-core/src/airflow/models/dagbag\.py$|
^airflow-core/src/airflow/models/dagrun\.py$|
^airflow-core/src/airflow/models/mappedoperator\.py$|
diff --git a/airflow-core/docs/img/airflow_erd.sha256 b/airflow-core/docs/img/airflow_erd.sha256
index 07e93bb4ceafa..8bfb32a3713df 100644
--- a/airflow-core/docs/img/airflow_erd.sha256
+++ b/airflow-core/docs/img/airflow_erd.sha256
@@ -1 +1 @@
-efbae2f1de68413e5a6f671a306e748581fe454b81e25eeb2927567f11ebd59c
\ No newline at end of file
+625f362919679fe85b2bfb3f9e053261248ac6ec7a974eee51012e55a8105b94
\ No newline at end of file
diff --git a/airflow-core/docs/img/airflow_erd.svg b/airflow-core/docs/img/airflow_erd.svg
index 537b1e5f6e3fd..e6776c5f1538d 100644
--- a/airflow-core/docs/img/airflow_erd.svg
+++ b/airflow-core/docs/img/airflow_erd.svg
@@ -453,24 +453,24 @@
asset_trigger
-
-asset_trigger
-
-asset_id
-
- [INTEGER]
- NOT NULL
-
-trigger_id
-
- [INTEGER]
- NOT NULL
+
+asset_trigger
+
+asset_id
+
+ [INTEGER]
+ NOT NULL
+
+trigger_id
+
+ [INTEGER]
+ NOT NULL
asset--asset_trigger
-
-0..N
+
+0..N
1
@@ -692,25 +692,25 @@
dagrun_asset_event
-
-dagrun_asset_event
-
-dag_run_id
-
- [INTEGER]
- NOT NULL
-
-event_id
-
- [INTEGER]
- NOT NULL
+
+dagrun_asset_event
+
+dag_run_id
+
+ [INTEGER]
+ NOT NULL
+
+event_id
+
+ [INTEGER]
+ NOT NULL
asset_event--dagrun_asset_event
-
-0..N
-1
+
+0..N
+1
@@ -745,9 +745,9 @@
trigger--asset_trigger
-
-0..N
-1
+
+0..N
+1
@@ -922,50 +922,46 @@
deadline
-
-deadline
-
-id
-
- [UUID]
- NOT NULL
-
-callback
-
- [VARCHAR(500)]
- NOT NULL
-
-callback_kwargs
-
- [JSON]
-
-callback_state
-
- [VARCHAR(20)]
-
-dag_id
-
- [VARCHAR(250)]
-
-dagrun_id
-
- [INTEGER]
-
-deadline_time
-
- [TIMESTAMP]
- NOT NULL
-
-trigger_id
-
- [INTEGER]
+
+deadline
+
+id
+
+ [UUID]
+ NOT NULL
+
+callback
+
+ [JSON]
+ NOT NULL
+
+callback_state
+
+ [VARCHAR(20)]
+
+dag_id
+
+ [VARCHAR(250)]
+
+dagrun_id
+
+ [INTEGER]
+
+deadline_time
+
+ [TIMESTAMP]
+ NOT NULL
+
+trigger_id
+
+ [INTEGER]
trigger--deadline
-
-0..N
-{0,1}
+
+0..N
+{0,1}
@@ -1643,35 +1639,35 @@
{0,1}
-
+
dag--dag_schedule_asset_alias_reference
0..N
1
-
+
dag--dag_schedule_asset_reference
0..N
1
-
+
dag--task_outlet_asset_reference
0..N
1
-
+
dag--task_inlet_asset_reference
0..N
1
-
+
dag--asset_dag_run_queue
0..N
@@ -1680,12 +1676,58 @@
dag--deadline
-
-0..N
+
+0..N
{0,1}
-
+
+dag_version
+
+dag_version
+
+id
+
+ [UUID]
+ NOT NULL
+
+bundle_name
+
+ [VARCHAR(250)]
+
+bundle_version
+
+ [VARCHAR(250)]
+
+created_at
+
+ [TIMESTAMP]
+ NOT NULL
+
+dag_id
+
+ [VARCHAR(250)]
+ NOT NULL
+
+last_updated
+
+ [TIMESTAMP]
+ NOT NULL
+
+version_number
+
+ [INTEGER]
+ NOT NULL
+
+
+
+dag--dag_version
+
+0..N
+1
+
+
+
dag_schedule_asset_name_reference
dag_schedule_asset_name_reference
@@ -1706,14 +1748,14 @@
NOT NULL
-
+
dag--dag_schedule_asset_name_reference
0..N
1
-
+
dag_schedule_asset_uri_reference
dag_schedule_asset_uri_reference
@@ -1734,58 +1776,12 @@
NOT NULL
-
+
dag--dag_schedule_asset_uri_reference
0..N
1
-
-
-dag_version
-
-dag_version
-
-id
-
- [UUID]
- NOT NULL
-
-bundle_name
-
- [VARCHAR(250)]
-
-bundle_version
-
- [VARCHAR(250)]
-
-created_at
-
- [TIMESTAMP]
- NOT NULL
-
-dag_id
-
- [VARCHAR(250)]
- NOT NULL
-
-last_updated
-
- [TIMESTAMP]
- NOT NULL
-
-version_number
-
- [INTEGER]
- NOT NULL
-
-
-
-dag--dag_version
-
-0..N
-1
-
dag_tag
@@ -2131,9 +2127,9 @@
dag_run--dagrun_asset_event
-
-0..N
-1
+
+0..N
+1
@@ -2152,8 +2148,8 @@
dag_run--deadline
-
-0..N
+
+0..N
{0,1}
diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst
index cf9dc37553978..a27474fdead69 100644
--- a/airflow-core/docs/migrations-ref.rst
+++ b/airflow-core/docs/migrations-ref.rst
@@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+=========================+==================+===================+==============================================================+
-| ``3bda03debd04`` (head) | ``f56f68b9e02f`` | ``3.1.0`` | Add url template and template params to DagBundleModel. |
+| ``808787349f22`` (head) | ``3bda03debd04`` | ``3.1.0`` | Modify deadline's callback schema. |
++-------------------------+------------------+-------------------+--------------------------------------------------------------+
+| ``3bda03debd04`` | ``f56f68b9e02f`` | ``3.1.0`` | Add url template and template params to DagBundleModel. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| ``f56f68b9e02f`` | ``09fa89ba1710`` | ``3.1.0`` | Add callback_state to deadline. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
diff --git a/airflow-core/src/airflow/migrations/versions/0081_3_1_0_modify_deadline_callback_schema.py b/airflow-core/src/airflow/migrations/versions/0081_3_1_0_modify_deadline_callback_schema.py
new file mode 100644
index 0000000000000..0d5431d663f52
--- /dev/null
+++ b/airflow-core/src/airflow/migrations/versions/0081_3_1_0_modify_deadline_callback_schema.py
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Modify deadline's callback schema.
+
+Revision ID: 808787349f22
+Revises: 3bda03debd04
+Create Date: 2025-07-31 19:35:53.150465
+
+"""
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+import sqlalchemy_jsonfield
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "808787349f22"
+down_revision = "3bda03debd04"
+branch_labels = None
+depends_on = None
+airflow_version = "3.1.0"
+
+
+def upgrade():
+ """Replace deadline table's string callback and JSON callback_kwargs with JSON callback."""
+ with op.batch_alter_table("deadline", schema=None) as batch_op:
+ batch_op.drop_column("callback")
+ batch_op.drop_column("callback_kwargs")
+ batch_op.add_column(sa.Column("callback", sqlalchemy_jsonfield.jsonfield.JSONField(), nullable=False))
+
+
+def downgrade():
+ """Replace deadline table's JSON callback with string callback and JSON callback_kwargs."""
+ with op.batch_alter_table("deadline", schema=None) as batch_op:
+ batch_op.drop_column("callback")
+ batch_op.add_column(
+ sa.Column("callback_kwargs", sqlalchemy_jsonfield.jsonfield.JSONField(), nullable=True)
+ )
+ batch_op.add_column(sa.Column("callback", sa.String(length=500), nullable=False))
diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py
index 5d68ad226d208..d9cbb94c744bc 100644
--- a/airflow-core/src/airflow/models/dag.py
+++ b/airflow-core/src/airflow/models/dag.py
@@ -1605,7 +1605,6 @@ def create_dagrun(
run_id=run_id,
),
callback=self.deadline.callback,
- callback_kwargs=self.deadline.callback_kwargs or {},
dag_id=self.dag_id,
dagrun_id=orm_dagrun.id,
)
diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py
index 2a0924f91d29f..c52c1fb785672 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -21,7 +21,8 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
-from typing import TYPE_CHECKING, Any
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, cast
import sqlalchemy_jsonfield
import uuid6
@@ -33,6 +34,7 @@
from airflow._shared.timezones import timezone
from airflow.models import Trigger
from airflow.models.base import Base, StringID
+from airflow.serialization.serde import deserialize, serialize
from airflow.settings import json
from airflow.triggers.deadline import PAYLOAD_STATUS_KEY, DeadlineCallbackTrigger
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -42,6 +44,7 @@
if TYPE_CHECKING:
from sqlalchemy.orm import Session
+ from airflow.sdk.definitions.deadline import Callback
from airflow.triggers.base import TriggerEvent
@@ -76,7 +79,11 @@ def __get__(self, instance, cls=None):
class DeadlineCallbackState(str, Enum):
- """All possible states of deadline callbacks."""
+ """
+ All possible states of deadline callbacks once the deadline is missed.
+
+    A `None` state means the deadline is still pending (`deadline_time` has not yet passed).
+ """
QUEUED = "queued"
SUCCESS = "success"
@@ -96,10 +103,8 @@ class Deadline(Base):
# The time after which the Deadline has passed and the callback should be triggered.
deadline_time = Column(UtcDateTime, nullable=False)
- # The Callback to be called when the Deadline has passed.
- callback = Column(String(500), nullable=False)
- # Serialized kwargs to pass to the callback.
- callback_kwargs = Column(sqlalchemy_jsonfield.JSONField(json=json))
+ # The (serialized) callback to be called when the Deadline has passed.
+ _callback = Column("callback", sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
# The state of the deadline callback
callback_state = Column(String(20))
@@ -114,15 +119,13 @@ class Deadline(Base):
def __init__(
self,
deadline_time: datetime,
- callback: str,
- callback_kwargs: dict | None = None,
+ callback: Callback,
dag_id: str | None = None,
dagrun_id: int | None = None,
):
super().__init__()
self.deadline_time = deadline_time
- self.callback = callback
- self.callback_kwargs = callback_kwargs
+ self._callback = serialize(callback)
self.dag_id = dag_id
self.dagrun_id = dagrun_id
@@ -136,11 +139,10 @@ def _determine_resource() -> tuple[str, str]:
return "Unknown", ""
resource_type, resource_details = _determine_resource()
- callback_kwargs = json.dumps(self.callback_kwargs) if self.callback_kwargs else ""
return (
f"[{resource_type} Deadline] {resource_details} needed by "
- f"{self.deadline_time} or run: {self.callback}({callback_kwargs})"
+ f"{self.deadline_time} or run: {self.callback.path}({self.callback.kwargs or ''})"
)
@classmethod
@@ -193,18 +195,30 @@ def prune_deadlines(cls, *, session: Session, conditions: dict[Column, Any]) ->
return deleted_count
+ @cached_property
+ def callback(self) -> Callback:
+ return cast("Callback", deserialize(self._callback))
+
def handle_miss(self, session: Session):
- """Handle a missed deadline by creating a trigger to run the callback."""
- # TODO: check to see if the callback is meant to run in triggerer or executor. For now, the code below assumes it's for the triggerer
- callback_trigger = DeadlineCallbackTrigger(
- callback_path=self.callback,
- callback_kwargs=self.callback_kwargs,
- )
+ """Handle a missed deadline by running the callback in the appropriate host and updating the `callback_state`."""
+ from airflow.sdk.definitions.deadline import AsyncCallback, SyncCallback
+
+ if isinstance(self.callback, AsyncCallback):
+ callback_trigger = DeadlineCallbackTrigger(
+ callback_path=self.callback.path,
+ callback_kwargs=self.callback.kwargs,
+ )
+ trigger_orm = Trigger.from_object(callback_trigger)
+ session.add(trigger_orm)
+ session.flush()
+ self.trigger = trigger_orm
+
+ elif isinstance(self.callback, SyncCallback):
+ raise NotImplementedError("SyncCallback is currently not supported")
+
+ else:
+ raise TypeError("Unknown Callback type")
- trigger_orm = Trigger.from_object(callback_trigger)
- session.add(trigger_orm)
- session.flush()
- self.trigger_id = trigger_orm.id
self.callback_state = DeadlineCallbackState.QUEUED
session.add(self)
diff --git a/airflow-core/src/airflow/triggers/deadline.py b/airflow-core/src/airflow/triggers/deadline.py
index d04910dd8b8ef..4229695854d10 100644
--- a/airflow-core/src/airflow/triggers/deadline.py
+++ b/airflow-core/src/airflow/triggers/deadline.py
@@ -22,7 +22,7 @@
from typing import Any
from airflow.triggers.base import BaseTrigger, TriggerEvent
-from airflow.utils.module_loading import import_string
+from airflow.utils.module_loading import import_string, qualname
log = logging.getLogger(__name__)
@@ -40,8 +40,8 @@ def __init__(self, callback_path: str, callback_kwargs: dict[str, Any] | None =
def serialize(self) -> tuple[str, dict[str, Any]]:
return (
- f"{type(self).__module__}.{type(self).__qualname__}",
- {"callback_path": self.callback_path, "callback_kwargs": self.callback_kwargs},
+ qualname(self),
+ {attr: getattr(self, attr) for attr in ("callback_path", "callback_kwargs")},
)
async def run(self) -> AsyncIterator[TriggerEvent]:
diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py
index 4cf633360dab6..311198e5f0b95 100644
--- a/airflow-core/src/airflow/utils/db.py
+++ b/airflow-core/src/airflow/utils/db.py
@@ -93,7 +93,7 @@ class MappedClassProtocol(Protocol):
"2.10.3": "5f2621c13b39",
"3.0.0": "29ce7909c52b",
"3.0.3": "fe199e1abd77",
- "3.1.0": "3bda03debd04",
+ "3.1.0": "808787349f22",
}
diff --git a/airflow-core/src/airflow/utils/module_loading.py b/airflow-core/src/airflow/utils/module_loading.py
index 49f9db05c9df8..028b5d31103ec 100644
--- a/airflow-core/src/airflow/utils/module_loading.py
+++ b/airflow-core/src/airflow/utils/module_loading.py
@@ -53,6 +53,7 @@ def import_string(dotted_path: str):
Raise ImportError if the import failed.
"""
+ # TODO: Add support for nested classes. Currently, it only works for top-level classes.
try:
module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError:
diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py
index 7f2f4b01394f5..55fb1a3cf27d9 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -70,7 +70,7 @@
from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext
from airflow.sdk.definitions._internal.templater import NativeEnvironment, SandboxedEnvironment
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny
-from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference
+from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert, DeadlineReference
from airflow.sdk.definitions.param import Param
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.timetables.simple import (
@@ -112,6 +112,11 @@
repo_root = Path(__file__).parents[2]
+async def empty_callback_for_deadline():
+ """Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly."""
+ pass
+
+
@pytest.fixture
def clear_dags():
clear_db_dags()
@@ -2024,7 +2029,7 @@ def test_dagrun_deadline(self, reference_type, reference_column, dag_maker, sess
deadline=DeadlineAlert(
reference=reference_type,
interval=interval,
- callback=print,
+ callback=AsyncCallback(empty_callback_for_deadline),
),
) as dag:
...
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index ced147148d70c..24c6cf87ab156 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -44,7 +44,7 @@
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator
from airflow.sdk import BaseOperator, setup, task, task_group, teardown
-from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference
+from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert, DeadlineReference
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.stats import Stats
from airflow.triggers.base import StartTriggerArgs
@@ -69,7 +69,7 @@
DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE)
-def test_callback_for_deadline():
+async def empty_callback_for_deadline():
"""Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly."""
pass
@@ -1294,7 +1294,7 @@ def on_success_callable(context):
deadline=DeadlineAlert(
reference=DeadlineReference.FIXED_DATETIME(future_date),
interval=datetime.timedelta(hours=1),
- callback=test_callback_for_deadline,
+ callback=AsyncCallback(empty_callback_for_deadline),
),
) as dag:
...
diff --git a/airflow-core/tests/unit/models/test_deadline.py b/airflow-core/tests/unit/models/test_deadline.py
index ef656f7455ae8..d1a7949951f69 100644
--- a/airflow-core/tests/unit/models/test_deadline.py
+++ b/airflow-core/tests/unit/models/test_deadline.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import json
from datetime import datetime, timedelta
from unittest import mock
@@ -28,7 +27,7 @@
from airflow.models import DagRun, Trigger
from airflow.models.deadline import Deadline, DeadlineCallbackState, ReferenceModels, _fetch_from_db
from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.sdk.definitions.deadline import DeadlineReference
+from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineReference, SyncCallback
from airflow.triggers.base import TriggerEvent
from airflow.triggers.deadline import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY
from airflow.utils.state import DagRunState
@@ -55,6 +54,8 @@ async def callback_for_deadline():
TEST_CALLBACK_PATH = f"{__name__}.{callback_for_deadline.__name__}"
TEST_CALLBACK_KWARGS = {"arg1": "value1"}
+TEST_ASYNC_CALLBACK = AsyncCallback(TEST_CALLBACK_PATH, kwargs=TEST_CALLBACK_KWARGS)
+TEST_SYNC_CALLBACK = SyncCallback(TEST_CALLBACK_PATH, kwargs=TEST_CALLBACK_KWARGS)
def _clean_db():
@@ -90,8 +91,7 @@ def test_add_deadline(self, dagrun, session):
assert session.query(Deadline).count() == 0
deadline_orm = Deadline(
deadline_time=DEFAULT_DATE,
- callback=TEST_CALLBACK_PATH,
- callback_kwargs=TEST_CALLBACK_KWARGS,
+ callback=TEST_ASYNC_CALLBACK,
dag_id=DAG_ID,
dagrun_id=dagrun.id,
)
@@ -106,7 +106,6 @@ def test_add_deadline(self, dagrun, session):
assert result.dagrun_id == deadline_orm.dagrun_id
assert result.deadline_time == deadline_orm.deadline_time
assert result.callback == deadline_orm.callback
- assert result.callback_kwargs == deadline_orm.callback_kwargs
@pytest.mark.parametrize(
"conditions",
@@ -145,23 +144,20 @@ def test_prune_deadlines(self, mock_session, conditions):
def test_orm(self):
deadline_orm = Deadline(
deadline_time=DEFAULT_DATE,
- callback=TEST_CALLBACK_PATH,
- callback_kwargs=TEST_CALLBACK_KWARGS,
+ callback=TEST_ASYNC_CALLBACK,
dag_id=DAG_ID,
dagrun_id=RUN_ID,
)
assert deadline_orm.deadline_time == DEFAULT_DATE
- assert deadline_orm.callback == TEST_CALLBACK_PATH
- assert deadline_orm.callback_kwargs == TEST_CALLBACK_KWARGS
+ assert deadline_orm.callback == TEST_ASYNC_CALLBACK
assert deadline_orm.dag_id == DAG_ID
assert deadline_orm.dagrun_id == RUN_ID
def test_repr_with_callback_kwargs(self):
deadline_orm = Deadline(
deadline_time=DEFAULT_DATE,
- callback=TEST_CALLBACK_PATH,
- callback_kwargs=TEST_CALLBACK_KWARGS,
+ callback=TEST_ASYNC_CALLBACK,
dag_id=DAG_ID,
dagrun_id=RUN_ID,
)
@@ -169,18 +165,18 @@ def test_repr_with_callback_kwargs(self):
assert (
repr(deadline_orm)
== f"[DagRun Deadline] Dag: {deadline_orm.dag_id} Run: {deadline_orm.dagrun_id} needed by "
- f"{deadline_orm.deadline_time} or run: {TEST_CALLBACK_PATH}({json.dumps(deadline_orm.callback_kwargs)})"
+ f"{deadline_orm.deadline_time} or run: {TEST_CALLBACK_PATH}({TEST_CALLBACK_KWARGS})"
)
def test_repr_without_callback_kwargs(self):
deadline_orm = Deadline(
deadline_time=DEFAULT_DATE,
- callback=TEST_CALLBACK_PATH,
+ callback=AsyncCallback(TEST_CALLBACK_PATH),
dag_id=DAG_ID,
dagrun_id=RUN_ID,
)
- assert deadline_orm.callback_kwargs is None
+ assert deadline_orm.callback.kwargs is None
assert (
repr(deadline_orm)
== f"[DagRun Deadline] Dag: {deadline_orm.dag_id} Run: {deadline_orm.dagrun_id} needed by "
@@ -191,8 +187,7 @@ def test_repr_without_callback_kwargs(self):
def test_handle_miss_async_callback(self, dagrun, session):
deadline_orm = Deadline(
deadline_time=DEFAULT_DATE,
- callback=TEST_CALLBACK_PATH,
- callback_kwargs=TEST_CALLBACK_KWARGS,
+ callback=TEST_ASYNC_CALLBACK,
dag_id=DAG_ID,
dagrun_id=dagrun.id,
)
@@ -209,6 +204,22 @@ def test_handle_miss_async_callback(self, dagrun, session):
assert trigger.kwargs["callback_path"] == TEST_CALLBACK_PATH
assert trigger.kwargs["callback_kwargs"] == TEST_CALLBACK_KWARGS
+ @pytest.mark.db_test
+ def test_handle_miss_sync_callback(self, dagrun, session):
+ deadline_orm = Deadline(
+ deadline_time=DEFAULT_DATE,
+ callback=TEST_SYNC_CALLBACK,
+ dag_id=DAG_ID,
+ dagrun_id=dagrun.id,
+ )
+ session.add(deadline_orm)
+ session.flush()
+
+ with pytest.raises(NotImplementedError):
+ deadline_orm.handle_miss(session=session)
+ session.flush()
+ assert deadline_orm.trigger_id is None
+
@pytest.mark.db_test
@pytest.mark.parametrize(
"event, none_trigger_expected",
@@ -238,7 +249,7 @@ def test_handle_miss_async_callback(self, dagrun, session):
def test_handle_callback_event(self, dagrun, session, event, none_trigger_expected):
deadline_orm = Deadline(
deadline_time=DEFAULT_DATE,
- callback=TEST_CALLBACK_PATH,
+ callback=TEST_ASYNC_CALLBACK,
dag_id=DAG_ID,
dagrun_id=dagrun.id,
)
diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py
index d54cb3eac799a..cf56a42211fa3 100644
--- a/airflow-core/tests/unit/serialization/test_serialized_objects.py
+++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py
@@ -58,7 +58,12 @@
AssetUniqueKey,
AssetWatcher,
)
-from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineAlertFields, DeadlineReference
+from airflow.sdk.definitions.deadline import (
+ AsyncCallback,
+ DeadlineAlert,
+ DeadlineAlertFields,
+ DeadlineReference,
+)
from airflow.sdk.definitions.decorators import task
from airflow.sdk.definitions.param import Param
from airflow.sdk.definitions.taskgroup import TaskGroup
@@ -76,7 +81,7 @@
DAG_ID = "dag_id_1"
-TEST_CALLBACK_PATH = f"{__name__}.test_callback_for_deadline"
+TEST_CALLBACK_PATH = f"{__name__}.empty_callback_for_deadline"
TEST_CALLBACK_KWARGS = {"arg1": "value1"}
REFERENCE_TYPES = [
@@ -86,7 +91,7 @@
]
-def test_callback_for_deadline():
+async def empty_callback_for_deadline():
"""Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly."""
pass
@@ -364,7 +369,7 @@ def __len__(self) -> int:
DeadlineAlert(
reference=DeadlineReference.DAGRUN_LOGICAL_DATE,
interval=timedelta(),
- callback="fake_callable",
+ callback=AsyncCallback("fake_callable"),
),
None,
None,
@@ -410,8 +415,7 @@ def __len__(self) -> int:
DeadlineAlert(
reference=DeadlineReference.DAGRUN_QUEUED_AT,
interval=timedelta(hours=1),
- callback="valid.callback.path",
- callback_kwargs={"arg1": "value1"},
+ callback=AsyncCallback("valid.callback.path", kwargs={"arg1": "value1"}),
),
DAT.DEADLINE_ALERT,
equals,
@@ -445,8 +449,7 @@ def test_serialize_deserialize_deadline_alert(reference):
original = DeadlineAlert(
reference=reference,
interval=timedelta(hours=1),
- callback=test_callback_for_deadline,
- callback_kwargs=TEST_CALLBACK_KWARGS,
+ callback=AsyncCallback(empty_callback_for_deadline, kwargs=TEST_CALLBACK_KWARGS),
)
serialized = original.serialize_deadline_alert()
@@ -456,9 +459,7 @@ def test_serialize_deserialize_deadline_alert(reference):
deserialized = DeadlineAlert.deserialize_deadline_alert(serialized)
assert deserialized.reference.serialize_reference() == reference.serialize_reference()
assert deserialized.interval == original.interval
- assert deserialized.callback_kwargs == original.callback_kwargs
- assert isinstance(deserialized.callback, str)
- assert deserialized.callback == TEST_CALLBACK_PATH
+ assert deserialized.callback == original.callback
@pytest.mark.parametrize(
diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py b/task-sdk/src/airflow/sdk/definitions/deadline.py
index 05d73e1ae4a99..8e1e67ae08118 100644
--- a/task-sdk/src/airflow/sdk/definitions/deadline.py
+++ b/task-sdk/src/airflow/sdk/definitions/deadline.py
@@ -16,12 +16,16 @@
# under the License.
from __future__ import annotations
+import inspect
import logging
+from abc import ABC
from collections.abc import Callable
from datetime import datetime, timedelta
+from typing import Any, cast
from airflow.models.deadline import DeadlineReferenceType, ReferenceModels
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
+from airflow.serialization.serde import deserialize, serialize
from airflow.utils.module_loading import import_string, is_valid_dotpath
logger = logging.getLogger(__name__)
@@ -38,7 +42,6 @@ class DeadlineAlertFields:
REFERENCE = "reference"
INTERVAL = "interval"
CALLBACK = "callback"
- CALLBACK_KWARGS = "callback_kwargs"
class DeadlineAlert:
@@ -48,13 +51,11 @@ def __init__(
self,
reference: DeadlineReferenceType,
interval: timedelta,
- callback: Callable | str,
- callback_kwargs: dict | None = None,
+ callback: Callback,
):
self.reference = reference
self.interval = interval
- self.callback_kwargs = callback_kwargs or {}
- self.callback = self.get_callback_path(callback)
+ self.callback = callback
def __eq__(self, other: object) -> bool:
if not isinstance(other, DeadlineAlert):
@@ -63,7 +64,6 @@ def __eq__(self, other: object) -> bool:
isinstance(self.reference, type(other.reference))
and self.interval == other.interval
and self.callback == other.callback
- and self.callback_kwargs == other.callback_kwargs
)
def __hash__(self) -> int:
@@ -72,14 +72,64 @@ def __hash__(self) -> int:
type(self.reference).__name__,
self.interval,
self.callback,
- tuple(sorted(self.callback_kwargs.items())) if self.callback_kwargs else None,
)
)
- @staticmethod
- def get_callback_path(_callback: str | Callable) -> str:
+ def serialize_deadline_alert(self):
+ """Return the data in a format that BaseSerialization can handle."""
+ return {
+ Encoding.TYPE: DAT.DEADLINE_ALERT,
+ Encoding.VAR: {
+ DeadlineAlertFields.REFERENCE: self.reference.serialize_reference(),
+ DeadlineAlertFields.INTERVAL: self.interval.total_seconds(),
+ DeadlineAlertFields.CALLBACK: serialize(self.callback),
+ },
+ }
+
+ @classmethod
+ def deserialize_deadline_alert(cls, encoded_data: dict) -> DeadlineAlert:
+ """Deserialize a DeadlineAlert from serialized data."""
+ data = encoded_data.get(Encoding.VAR, encoded_data)
+
+ reference_data = data[DeadlineAlertFields.REFERENCE]
+ reference_type = reference_data[ReferenceModels.REFERENCE_TYPE_FIELD]
+
+ reference_class = ReferenceModels.get_reference_class(reference_type)
+ reference = reference_class.deserialize_reference(reference_data)
+
+ return cls(
+ reference=reference,
+ interval=timedelta(seconds=data[DeadlineAlertFields.INTERVAL]),
+ callback=cast("Callback", deserialize(data[DeadlineAlertFields.CALLBACK])),
+ )
+
+
+class Callback(ABC):
+ """
+ Base class for Deadline Alert callbacks.
+
+ Callbacks are used to execute custom logic when a deadline is missed.
+
+    The `callback_callable` can be a Python callable or a string dotted path from which the callable
+    can be imported. It must be a top-level callable in a module available on the host where
+    it will run.
+
+ It will be called with Airflow context and specified kwargs when a deadline is missed.
+ """
+
+ path: str
+ kwargs: dict | None
+
+ def __init__(self, callback_callable: Callable | str, kwargs: dict | None = None):
+ self.path = self.get_callback_path(callback_callable)
+ self.kwargs = kwargs
+
+ @classmethod
+ def get_callback_path(cls, _callback: str | Callable) -> str:
"""Convert callback to a string path that can be used to import it later."""
if callable(_callback):
+ cls.verify_callable(_callback)
+
# TODO: This implementation doesn't support using a lambda function as a callback.
# We should consider that in the future, but the addition is non-trivial.
# Get the reference path to the callable in the form `airflow.models.deadline.get_from_db`
@@ -96,6 +146,9 @@ def get_callback_path(_callback: str | Callable) -> str:
if not callable(callback):
# The input is a string which can be imported, but is not callable.
raise AttributeError(f"Provided callback {callback} is not callable.")
+
+ cls.verify_callable(callback)
+
except ImportError as e:
# Logging here instead of failing because it is possible that the code for the callable
# exists somewhere other than on the DAG processor. We are making a best effort to validate,
@@ -108,35 +161,80 @@ def get_callback_path(_callback: str | Callable) -> str:
return stripped_callback
- def serialize_deadline_alert(self):
- """Return the data in a format that BaseSerialization can handle."""
- return {
- Encoding.TYPE: DAT.DEADLINE_ALERT,
- Encoding.VAR: {
- DeadlineAlertFields.REFERENCE: self.reference.serialize_reference(),
- DeadlineAlertFields.INTERVAL: self.interval.total_seconds(),
- DeadlineAlertFields.CALLBACK: self.callback, # Already stored as a string path
- DeadlineAlertFields.CALLBACK_KWARGS: self.callback_kwargs,
- },
- }
+ @classmethod
+ def verify_callable(cls, callback: Callable):
+ """For additional verification of the callable during initialization in subclasses."""
+ pass # No verification needed in the base class
@classmethod
- def deserialize_deadline_alert(cls, encoded_data: dict) -> DeadlineAlert:
- """Deserialize a DeadlineAlert from serialized data."""
- data = encoded_data.get(Encoding.VAR, encoded_data)
+ def deserialize(cls, data: dict, version):
+ path = data.pop("path")
+ return cls(callback_callable=path, **data)
- reference_data = data[DeadlineAlertFields.REFERENCE]
- reference_type = reference_data[ReferenceModels.REFERENCE_TYPE_FIELD]
+ @classmethod
+ def serialized_fields(cls) -> tuple[str, ...]:
+ return ("path", "kwargs")
- reference_class = ReferenceModels.get_reference_class(reference_type)
- reference = reference_class.deserialize_reference(reference_data)
+ def serialize(self) -> dict[str, Any]:
+ return {f: getattr(self, f) for f in self.serialized_fields()}
- return cls(
- reference=reference,
- interval=timedelta(seconds=data[DeadlineAlertFields.INTERVAL]),
- callback=data[DeadlineAlertFields.CALLBACK], # Keep as string path
- callback_kwargs=data[DeadlineAlertFields.CALLBACK_KWARGS],
- )
+ def __eq__(self, other):
+ if type(self) is not type(other):
+ return NotImplemented
+ return self.serialize() == other.serialize()
+
+ def __hash__(self):
+ serialized = self.serialize()
+ hashable_items = []
+ for k, v in serialized.items():
+ if isinstance(v, dict) and v:
+ hashable_items.append((k, tuple(sorted(v.items()))))
+ else:
+ hashable_items.append((k, v))
+ return hash(tuple(sorted(hashable_items)))
+
+
+class AsyncCallback(Callback):
+ """
+ Asynchronous callback that runs in the triggerer.
+
+    The `callback_callable` can be a Python callable or a string dotted path from which the callable
+    can be imported. It must be a top-level awaitable callable in a module available on the
+    triggerer.
+
+ It will be called with Airflow context and specified kwargs when a deadline is missed.
+ """
+
+ def __init__(self, callback_callable: Callable | str, kwargs: dict | None = None):
+ super().__init__(callback_callable=callback_callable, kwargs=kwargs)
+
+ @classmethod
+ def verify_callable(cls, callback: Callable):
+ if not (inspect.iscoroutinefunction(callback) or hasattr(callback, "__await__")):
+ raise AttributeError(f"Provided callback {callback} is not awaitable.")
+
+
+class SyncCallback(Callback):
+ """
+ Synchronous callback that runs in the specified or default executor.
+
+    The `callback_callable` can be a Python callable or a string dotted path from which the callable
+    can be imported. It must be a top-level callable in a module available on the executor.
+
+ It will be called with Airflow context and specified kwargs when a deadline is missed.
+ """
+
+ executor: str | None
+
+ def __init__(
+ self, callback_callable: Callable | str, kwargs: dict | None = None, executor: str | None = None
+ ):
+ super().__init__(callback_callable=callback_callable, kwargs=kwargs)
+ self.executor = executor
+
+ @classmethod
+ def serialized_fields(cls) -> tuple[str, ...]:
+ return super().serialized_fields() + ("executor",)
class DeadlineReference:
diff --git a/task-sdk/tests/task_sdk/definitions/test_deadline.py b/task-sdk/tests/task_sdk/definitions/test_deadline.py
index b62ce1c754e68..761a86ae0a9a2 100644
--- a/task-sdk/tests/task_sdk/definitions/test_deadline.py
+++ b/task-sdk/tests/task_sdk/definitions/test_deadline.py
@@ -17,10 +17,19 @@
from __future__ import annotations
from datetime import datetime, timedelta
+from typing import cast
import pytest
-from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference
+from airflow.sdk.definitions.deadline import (
+ AsyncCallback,
+ Callback,
+ DeadlineAlert,
+ DeadlineReference,
+ SyncCallback,
+)
+from airflow.serialization.serde import deserialize, serialize
+from airflow.utils.module_loading import qualname
UNIMPORTABLE_DOT_PATH = "valid.but.nonexistent.path"
@@ -28,9 +37,6 @@
RUN_ID = 1
DEFAULT_DATE = datetime(2025, 6, 26)
-TEST_CALLBACK_PATH = f"{__name__}.test_callback_for_deadline"
-TEST_CALLBACK_KWARGS = {"arg1": "value1"}
-
REFERENCE_TYPES = [
pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, id="logical_date"),
pytest.param(DeadlineReference.DAGRUN_QUEUED_AT, id="queued_at"),
@@ -38,47 +44,22 @@
]
-def test_callback_for_deadline():
+async def empty_async_callback_for_deadline_tests():
"""Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly."""
pass
-class TestDeadlineAlert:
- @pytest.mark.parametrize(
- "callback_value, expected_path",
- [
- pytest.param(test_callback_for_deadline, TEST_CALLBACK_PATH, id="valid_callable"),
- pytest.param(TEST_CALLBACK_PATH, TEST_CALLBACK_PATH, id="valid_path_string"),
- pytest.param(lambda x: x, None, id="lambda_function"),
- pytest.param(TEST_CALLBACK_PATH + " ", TEST_CALLBACK_PATH, id="path_with_whitespace"),
- pytest.param(UNIMPORTABLE_DOT_PATH, UNIMPORTABLE_DOT_PATH, id="valid_format_not_importable"),
- ],
- )
- def test_get_callback_path_happy_cases(self, callback_value, expected_path):
- path = DeadlineAlert.get_callback_path(callback_value)
- if expected_path is None:
- assert path.endswith("")
- else:
- assert path == expected_path
+def empty_sync_callback_for_deadline_tests():
+ """Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly."""
+ pass
- @pytest.mark.parametrize(
- "callback_value, error_type",
- [
- pytest.param(42, ImportError, id="not_a_string"),
- pytest.param("", ImportError, id="empty_string"),
- pytest.param("os.path", AttributeError, id="non_callable_module"),
- ],
- )
- def test_get_callback_path_error_cases(self, callback_value, error_type):
- expected_message = ""
- if error_type is ImportError:
- expected_message = "doesn't look like a valid dot path."
- elif error_type is AttributeError:
- expected_message = "is not callable."
- with pytest.raises(error_type, match=expected_message):
- DeadlineAlert.get_callback_path(callback_value)
+TEST_CALLBACK_PATH = qualname(empty_async_callback_for_deadline_tests)
+TEST_CALLBACK_KWARGS = {"arg1": "value1"}
+TEST_DEADLINE_CALLBACK = AsyncCallback(TEST_CALLBACK_PATH, kwargs=TEST_CALLBACK_KWARGS)
+
+class TestDeadlineAlert:
@pytest.mark.parametrize(
"test_alert, should_equal",
[
@@ -86,8 +67,7 @@ def test_get_callback_path_error_cases(self, callback_value, error_type):
DeadlineAlert(
reference=DeadlineReference.DAGRUN_QUEUED_AT,
interval=timedelta(hours=1),
- callback=TEST_CALLBACK_PATH,
- callback_kwargs=TEST_CALLBACK_KWARGS,
+ callback=TEST_DEADLINE_CALLBACK,
),
True,
id="same_alert",
@@ -96,8 +76,7 @@ def test_get_callback_path_error_cases(self, callback_value, error_type):
DeadlineAlert(
reference=DeadlineReference.DAGRUN_LOGICAL_DATE,
interval=timedelta(hours=1),
- callback=TEST_CALLBACK_PATH,
- callback_kwargs=TEST_CALLBACK_KWARGS,
+ callback=TEST_DEADLINE_CALLBACK,
),
False,
id="different_reference",
@@ -106,8 +85,7 @@ def test_get_callback_path_error_cases(self, callback_value, error_type):
DeadlineAlert(
reference=DeadlineReference.DAGRUN_QUEUED_AT,
interval=timedelta(hours=2),
- callback=TEST_CALLBACK_PATH,
- callback_kwargs=TEST_CALLBACK_KWARGS,
+ callback=TEST_DEADLINE_CALLBACK,
),
False,
id="different_interval",
@@ -116,8 +94,7 @@ def test_get_callback_path_error_cases(self, callback_value, error_type):
DeadlineAlert(
reference=DeadlineReference.DAGRUN_QUEUED_AT,
interval=timedelta(hours=1),
- callback="other.callback",
- callback_kwargs=TEST_CALLBACK_KWARGS,
+ callback=AsyncCallback(UNIMPORTABLE_DOT_PATH, kwargs=TEST_CALLBACK_KWARGS),
),
False,
id="different_callback",
@@ -126,8 +103,7 @@ def test_get_callback_path_error_cases(self, callback_value, error_type):
DeadlineAlert(
reference=DeadlineReference.DAGRUN_QUEUED_AT,
interval=timedelta(hours=1),
- callback=TEST_CALLBACK_PATH,
- callback_kwargs={"arg2": "value2"},
+ callback=AsyncCallback(TEST_CALLBACK_PATH, kwargs={"arg2": "value2"}),
),
False,
id="different_kwargs",
@@ -139,8 +115,7 @@ def test_deadline_alert_equality(self, test_alert, should_equal):
base_alert = DeadlineAlert(
reference=DeadlineReference.DAGRUN_QUEUED_AT,
interval=timedelta(hours=1),
- callback=TEST_CALLBACK_PATH,
- callback_kwargs=TEST_CALLBACK_KWARGS,
+ callback=TEST_DEADLINE_CALLBACK,
)
assert (base_alert == test_alert) == should_equal
@@ -153,14 +128,12 @@ def test_deadline_alert_hash(self):
alert1 = DeadlineAlert(
reference=DeadlineReference.DAGRUN_QUEUED_AT,
interval=std_interval,
- callback=std_callback,
- callback_kwargs=std_kwargs,
+ callback=AsyncCallback(std_callback, kwargs=std_kwargs),
)
alert2 = DeadlineAlert(
reference=DeadlineReference.DAGRUN_QUEUED_AT,
interval=std_interval,
- callback=std_callback,
- callback_kwargs=std_kwargs,
+ callback=AsyncCallback(std_callback, kwargs=std_kwargs),
)
assert hash(alert1) == hash(alert1)
@@ -174,19 +147,195 @@ def test_deadline_alert_in_set(self):
alert1 = DeadlineAlert(
reference=DeadlineReference.DAGRUN_QUEUED_AT,
interval=std_interval,
- callback=std_callback,
- callback_kwargs=std_kwargs,
+ callback=AsyncCallback(std_callback, kwargs=std_kwargs),
)
alert2 = DeadlineAlert(
reference=DeadlineReference.DAGRUN_QUEUED_AT,
interval=std_interval,
- callback=std_callback,
- callback_kwargs=std_kwargs,
+ callback=AsyncCallback(std_callback, kwargs=std_kwargs),
)
alert_set = {alert1, alert2}
assert len(alert_set) == 1
+class TestCallback:
+ @pytest.mark.parametrize(
+ "callback_callable, expected_path",
+ [
+ pytest.param(
+ empty_sync_callback_for_deadline_tests,
+ qualname(empty_sync_callback_for_deadline_tests),
+ id="valid_sync_callable",
+ ),
+ pytest.param(
+ empty_async_callback_for_deadline_tests,
+ qualname(empty_async_callback_for_deadline_tests),
+ id="valid_async_callable",
+ ),
+ pytest.param(TEST_CALLBACK_PATH, TEST_CALLBACK_PATH, id="valid_path_string"),
+ pytest.param(lambda x: x, None, id="lambda_function"),
+ pytest.param(TEST_CALLBACK_PATH + " ", TEST_CALLBACK_PATH, id="path_with_whitespace"),
+ pytest.param(UNIMPORTABLE_DOT_PATH, UNIMPORTABLE_DOT_PATH, id="valid_format_not_importable"),
+ ],
+ )
+    def test_get_callback_path_happy_cases(self, callback_callable, expected_path):
+        path = Callback.get_callback_path(callback_callable)
+        if expected_path is None:
+            assert path.endswith("<lambda>")  # lambda qualnames are unstable; assert the tail instead of a vacuous endswith("")
+        else:
+            assert path == expected_path
+
+ @pytest.mark.parametrize(
+ "callback_callable, error_type",
+ [
+ pytest.param(42, ImportError, id="not_a_string"),
+ pytest.param("", ImportError, id="empty_string"),
+ pytest.param("os.path", AttributeError, id="non_callable_module"),
+ ],
+ )
+ def test_get_callback_path_error_cases(self, callback_callable, error_type):
+ expected_message = ""
+ if error_type is ImportError:
+ expected_message = "doesn't look like a valid dot path."
+ elif error_type is AttributeError:
+ expected_message = "is not callable."
+
+ with pytest.raises(error_type, match=expected_message):
+ Callback.get_callback_path(callback_callable)
+
+ @pytest.mark.parametrize(
+ "callback1_args, callback2_args, should_equal",
+ [
+ pytest.param(
+ (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+ (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+ True,
+ id="identical",
+ ),
+ pytest.param(
+ (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+ (UNIMPORTABLE_DOT_PATH, TEST_CALLBACK_KWARGS),
+ False,
+ id="different_path",
+ ),
+ pytest.param(
+ (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+ (TEST_CALLBACK_PATH, {"other": "kwargs"}),
+ False,
+ id="different_kwargs",
+ ),
+ pytest.param((TEST_CALLBACK_PATH, None), (TEST_CALLBACK_PATH, None), True, id="both_no_kwargs"),
+ ],
+ )
+ def test_callback_equality(self, callback1_args, callback2_args, should_equal):
+ callback1 = AsyncCallback(*callback1_args)
+ callback2 = AsyncCallback(*callback2_args)
+ assert (callback1 == callback2) == should_equal
+
+ @pytest.mark.parametrize(
+ "callback_class, args1, args2, should_be_same_hash",
+ [
+ pytest.param(
+ AsyncCallback,
+ (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+ (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+ True,
+ id="async_identical",
+ ),
+ pytest.param(
+ SyncCallback,
+ (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+ (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+ True,
+ id="sync_identical",
+ ),
+ pytest.param(
+ AsyncCallback,
+ (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+ (UNIMPORTABLE_DOT_PATH, TEST_CALLBACK_KWARGS),
+ False,
+ id="async_different_path",
+ ),
+ pytest.param(
+ SyncCallback,
+ (TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+ (TEST_CALLBACK_PATH, {"other": "kwargs"}),
+ False,
+ id="sync_different_kwargs",
+ ),
+ pytest.param(
+ AsyncCallback,
+ (TEST_CALLBACK_PATH, None),
+ (TEST_CALLBACK_PATH, None),
+ True,
+ id="async_no_kwargs",
+ ),
+ ],
+ )
+ def test_callback_hash_and_set_behavior(self, callback_class, args1, args2, should_be_same_hash):
+ callback1 = callback_class(*args1)
+ callback2 = callback_class(*args2)
+ assert (hash(callback1) == hash(callback2)) == should_be_same_hash
+
+
+class TestAsyncCallback:
+ @pytest.mark.parametrize(
+ "callback_callable, kwargs, expected_path",
+ [
+ pytest.param(
+ empty_async_callback_for_deadline_tests,
+ TEST_CALLBACK_KWARGS,
+ TEST_CALLBACK_PATH,
+ id="callable",
+ ),
+ pytest.param(TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS, TEST_CALLBACK_PATH, id="string_path"),
+ pytest.param(
+ UNIMPORTABLE_DOT_PATH, TEST_CALLBACK_KWARGS, UNIMPORTABLE_DOT_PATH, id="unimportable_path"
+ ),
+ ],
+ )
+ def test_init(self, callback_callable, kwargs, expected_path):
+ callback = AsyncCallback(callback_callable, kwargs=kwargs)
+ assert callback.path == expected_path
+ assert callback.kwargs == kwargs
+ assert isinstance(callback, Callback)
+
+ def test_init_error(self):
+ with pytest.raises(AttributeError, match="is not awaitable."):
+ AsyncCallback(empty_sync_callback_for_deadline_tests)
+
+ def test_serialize_deserialize(self):
+ callback = AsyncCallback(TEST_CALLBACK_PATH, kwargs=TEST_CALLBACK_KWARGS)
+ serialized = serialize(callback)
+ deserialized = cast("Callback", deserialize(serialized.copy()))
+ assert callback == deserialized
+
+
+class TestSyncCallback:
+    @pytest.mark.parametrize(
+        "callback_callable, executor",
+        [
+            pytest.param(empty_sync_callback_for_deadline_tests, "remote", id="with_executor"),
+            pytest.param(empty_sync_callback_for_deadline_tests, None, id="without_executor"),
+            pytest.param(qualname(empty_sync_callback_for_deadline_tests), None, id="importable_path"),
+            pytest.param(UNIMPORTABLE_DOT_PATH, None, id="unimportable_path"),
+        ],
+    )
+    def test_init(self, callback_callable, executor):
+        # Use the parametrized callable: the previous version ignored it and always passed TEST_CALLBACK_PATH.
+        callback = SyncCallback(callback_callable, kwargs=TEST_CALLBACK_KWARGS, executor=executor)
+
+        assert callback.path == Callback.get_callback_path(callback_callable)
+        assert callback.kwargs == TEST_CALLBACK_KWARGS
+        assert callback.executor == executor
+        assert isinstance(callback, Callback)
+
+    def test_serialize_deserialize(self):
+        callback = SyncCallback(TEST_CALLBACK_PATH, kwargs=TEST_CALLBACK_KWARGS, executor="local")
+        serialized = serialize(callback)
+        deserialized = cast("Callback", deserialize(serialized.copy()))
+        assert callback == deserialized
+
+
# While DeadlineReference lives in the SDK package, the unit tests to confirm it
# works need database access so they live in the models/test_deadline.py module.