@@ -27,14 +27,15 @@

from __future__ import annotations

from datetime import datetime, timezone
import json
from datetime import date, datetime, timedelta, timezone
from textwrap import dedent
from zoneinfo import ZoneInfo

import sqlalchemy as sa
from alembic import context, op
from sqlalchemy import column, select, table

from airflow.serialization.serde import deserialize
from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime

# revision identifiers, used by Alembic.
@@ -45,16 +46,132 @@
airflow_version = "3.2.0"

BATCH_SIZE = 1000
_CALLBACK_TYPE_TRIGGERER = "triggerer"
_CALLBACK_FETCH_METHOD_IMPORT_PATH = "import_path"
_CALLBACK_STATE_FAILED = "failed"
_CALLBACK_STATE_PENDING = "pending"
_CALLBACK_STATE_SUCCESS = "success"
_CALLBACK_TERMINAL_STATES = frozenset({_CALLBACK_STATE_FAILED, _CALLBACK_STATE_SUCCESS})
_CALLBACK_METRICS_PREFIX = "deadline_alerts"
_SERDE_CLASSNAME = "__classname__"
_SERDE_VERSION = "__version__"
_SERDE_DATA = "__data__"
_SERDE_DATETIME_TIMESTAMP = "timestamp"
_SERDE_DATETIME_TIMEZONE = "tz"
_SERDE_TUPLE_TYPES = {
"builtins.frozenset": frozenset,
"builtins.set": set,
"builtins.tuple": tuple,
}
_SERDE_DATETIME_TYPES = {
"datetime.datetime",
"pendulum.datetime.DateTime",
}
_SERDE_DATE_TYPES = {
"datetime.date",
"pendulum.date.Date",
}
_SERDE_TIMEDELTA_TYPES = {
"datetime.timedelta",
}
_SERDE_TIMEZONE_TYPES = {
"pendulum.tz.timezone.FixedTimezone",
"pendulum.tz.timezone.Timezone",
"zoneinfo.ZoneInfo",
}


def _deserialize_task_sdk_value(value):
"""Deserialize a minimal subset of Task SDK serde values used in callback kwargs."""
if value is None or isinstance(value, bool | float | int | str):
Contributor Author:

Airflow requires Python >= 3.10, and isinstance(value, bool | float | int | str) works fine from 3.10 onward.
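
A minimal standalone sketch of that union-type check (illustrative values only, not part of this migration):

# Illustrative only: on Python >= 3.10, PEP 604 unions such as bool | float | int | str
# are accepted directly by isinstance(), so no tuple of types is needed.
for sample in (True, 1.5, 7, "text", ["not", "a", "scalar"]):
    print(sample, isinstance(sample, bool | float | int | str))
# The four scalar samples print True; the list prints False.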

return value

if isinstance(value, list):
return [_deserialize_task_sdk_value(item) for item in value]

if not isinstance(value, dict):
return value

if _SERDE_CLASSNAME not in value or _SERDE_VERSION not in value:
return {key: _deserialize_task_sdk_value(item) for key, item in value.items()}

classname = value[_SERDE_CLASSNAME]
data = _deserialize_task_sdk_value(value.get(_SERDE_DATA))

if classname in _SERDE_TUPLE_TYPES:
return _SERDE_TUPLE_TYPES[classname](data)

if classname in _SERDE_TIMEZONE_TYPES:
return _deserialize_task_sdk_timezone(data)

if classname in _SERDE_DATETIME_TYPES:
if not isinstance(data, dict) or _SERDE_DATETIME_TIMESTAMP not in data:
raise ValueError(f"Unsupported datetime serde payload: {value!r}")
tz = None
if _SERDE_DATETIME_TIMEZONE in data:
tz = _deserialize_task_sdk_timezone(data[_SERDE_DATETIME_TIMEZONE])
return datetime.fromtimestamp(float(data[_SERDE_DATETIME_TIMESTAMP]), tz=tz)

if classname in _SERDE_DATE_TYPES:
if not isinstance(data, str):
raise ValueError(f"Unsupported date serde payload: {value!r}")
return date.fromisoformat(data)

if classname in _SERDE_TIMEDELTA_TYPES:
return timedelta(seconds=float(data))

raise ValueError(f"Unsupported Task SDK serde value in callback kwargs: {classname}")


def _deserialize_task_sdk_timezone(value):
"""Deserialize timezone values produced by Task SDK serde."""
if value in (None, "UTC"):
return timezone.utc if value == "UTC" else None

if isinstance(value, int):
return timezone(timedelta(seconds=value))

if isinstance(value, str):
return ZoneInfo(value)

if isinstance(value, list) and len(value) == 3:
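# A nested timezone payload may arrive as a [data, classname, version] triple; unwrap it and recurse.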
data, classname, _version = value
if classname in _SERDE_TIMEZONE_TYPES:
return _deserialize_task_sdk_timezone(data)

raise ValueError(f"Unsupported timezone serde payload: {value!r}")


def _extract_callback_data(deadline_callback, dag_id):
"""Convert the stored deadline callback payload into the new callback.data structure."""
callback = deadline_callback if isinstance(deadline_callback, dict) else json.loads(deadline_callback)
callback_data = callback.get(_SERDE_DATA)
if not isinstance(callback_data, dict):
raise ValueError(f"Deadline callback is missing {_SERDE_DATA}: {callback!r}")

callback_path = callback_data.get("path")
if not callback_path:
raise ValueError(f"Deadline callback is missing path: {callback!r}")

return {
"path": callback_path,
"kwargs": _deserialize_task_sdk_value(callback_data.get("kwargs", {})),
"prefix": _CALLBACK_METRICS_PREFIX,
"dag_id": dag_id,
}


def _get_migrated_callback_state(callback_state):
"""Return callback state and missed flag for the migrated callback row."""
if callback_state in _CALLBACK_TERMINAL_STATES:
return callback_state, True
return _CALLBACK_STATE_PENDING, False


def upgrade():
"""Replace Deadline table's inline callback fields with callback_id foreign key."""
import uuid6

from airflow.models.base import StringID
from airflow.models.callback import CallbackFetchMethod, CallbackState, CallbackType
from airflow.models.deadline import CALLBACK_METRICS_PREFIX

timestamp = datetime.now(timezone.utc)

def migrate_batch(conn, deadline_table, callback_table, batch):
@@ -65,31 +182,17 @@ def migrate_batch(conn, deadline_table, callback_table, batch):
try:
callback_id = uuid6.uuid7()

# Transform serialized callback to the new representation
callback_data = deserialize(deadline.callback).serialize() | {
"prefix": CALLBACK_METRICS_PREFIX,
"dag_id": deadline.dag_id,
}

if deadline.callback_state and deadline.callback_state in {
CallbackState.FAILED,
CallbackState.SUCCESS,
}:
deadline_missed = True
callback_state = deadline.callback_state
else:
# Mark the deadlines in non-terminal states as not missed so the scheduler handles them
deadline_missed = False
callback_state = CallbackState.PENDING
callback_data = _extract_callback_data(deadline.callback, deadline.dag_id)
callback_state, deadline_missed = _get_migrated_callback_state(deadline.callback_state)

callback_inserts.append(
{
"id": callback_id,
"type": CallbackType.TRIGGERER, # Past versions only support triggerer callbacks
"fetch_method": CallbackFetchMethod.IMPORT_PATH, # Past versions only support import_path
"type": _CALLBACK_TYPE_TRIGGERER,
"fetch_method": _CALLBACK_FETCH_METHOD_IMPORT_PATH,
"data": callback_data,
"state": callback_state,
"priority_weight": 1, # Default priority weight
"priority_weight": 1,
"created_at": timestamp,
}
)
@@ -136,7 +239,7 @@ def migrate_all_data():
dag_run_table = table(
"dag_run",
column("id", sa.Integer()),
column("dag_id", StringID()),
column("dag_id", sa.String()),
)

callback_table = table(
@@ -152,26 +255,31 @@

conn = op.get_bind()
batch_num = 0
last_seen_id = None
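# Keyset pagination: remember the id of the last migrated row so each new batch resumes after it.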
while True:
batch_num += 1
batch = conn.execute(
batch_query = (
select(
deadline_table.c.id,
deadline_table.c.dagrun_id,
deadline_table.c.deadline_time,
deadline_table.c.callback,
deadline_table.c.callback_state,
dag_run_table.c.dag_id,
)
.join(dag_run_table, deadline_table.c.dagrun_id == dag_run_table.c.id)
.where(deadline_table.c.callback_id.is_(None)) # Only get rows that haven't been migrated yet
.where(deadline_table.c.callback_id.is_(None))
.order_by(deadline_table.c.id)
.limit(BATCH_SIZE)
).fetchall()
)
if last_seen_id is not None:
batch_query = batch_query.where(deadline_table.c.id > last_seen_id)

batch = conn.execute(batch_query).fetchall()

if not batch:
break

migrate_batch(conn, deadline_table, callback_table, batch)
last_seen_id = batch[-1].id
print(f"Migrated {len(batch)} deadline records in batch {batch_num}")

# Add new columns (temporarily nullable until data has been migrated)
@@ -199,7 +307,6 @@ def migrate_all_data():

def downgrade():
"""Restore Deadline table's inline callback fields from callback_id foreign key."""
from airflow.models.callback import CallbackState

def migrate_batch(conn, deadline_table, callback_table, batch):
deadline_updates = []
@@ -215,14 +322,14 @@ def migrate_batch(conn, deadline_table, callback_table, batch):
# from airflow.sdk.definitions.deadline import AsyncCallback
# callback_serialized = serialize(AsyncCallback.deserialize(filtered_data, 0))
callback_serialized = {
"__data__": filtered_cb_data,
"__classname__": "airflow.sdk.definitions.deadline.AsyncCallback",
"__version__": 0,
_SERDE_DATA: filtered_cb_data,
_SERDE_CLASSNAME: "airflow.sdk.definitions.deadline.AsyncCallback",
_SERDE_VERSION: 0,
}

# Mark the deadline as not handled if its callback is not in a terminal state so that the
# scheduler handles it appropriately
if row.callback_state in {CallbackState.SUCCESS, CallbackState.FAILED}:
if row.callback_state in {_CALLBACK_STATE_SUCCESS, _CALLBACK_STATE_FAILED}:
callback_state = row.callback_state
else:
callback_state = None