Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@

from __future__ import annotations

import json
from textwrap import dedent

import sqlalchemy as sa
from alembic import op
from alembic import context, op

from airflow.configuration import conf

# revision identifiers, used by Alembic.
revision = "808787349f22"
Expand All @@ -38,17 +43,199 @@
airflow_version = "3.1.0"


# Value stored under the ``__classname__`` key of the JSON envelope that
# upgrade() writes for every migrated deadline callback.
_ASYNC_CALLBACK_CLASSNAME = "airflow.sdk.definitions.deadline.AsyncCallback"
# Maximum length of the callback VARCHAR column in the pre-0080 schema.
# downgrade() truncates any longer callback path to this many characters.
_CALLBACK_MAX_LEN = 500


def upgrade():
    """
    Replace deadline table's string callback and JSON callback_kwargs with JSON callback.

    In offline mode the existing rows cannot be read, so the table is emptied
    and the columns are swapped directly.

    In online mode the data is preserved: a temporary nullable ``callback_new``
    JSON column is added, filled in batches of ``[database] migration_batch_size``
    rows (default 1000) with an envelope of the form
    ``{"__data__": {"path": ..., "kwargs": ...}, "__classname__": ..., "__version__": 0}``,
    and finally renamed to ``callback`` once the legacy columns are dropped.
    """
    if context.is_offline_mode():
        print(
            dedent("""
            ------------
            -- WARNING: Unable to migrate the data in the deadline table
            -- while in offline mode! All rows in the deadline table will
            -- be deleted.
            ------------
            """)
        )
        op.execute("DELETE FROM deadline")
        with op.batch_alter_table("deadline", schema=None) as batch_op:
            batch_op.drop_column("callback")
            batch_op.drop_column("callback_kwargs")
            batch_op.add_column(sa.Column("callback", sa.JSON(), nullable=False))
        return

    conn = op.get_bind()
    batch_size = conf.getint("database", "migration_batch_size", fallback=1000)

    # Add the destination column alongside the existing ones so we can migrate
    # in batches without loading the whole table into memory at once.
    with op.batch_alter_table("deadline", schema=None) as batch_op:
        batch_op.add_column(sa.Column("callback_new", sa.JSON(), nullable=True))

    deadline_read = sa.table(
        "deadline",
        sa.column("id"),
        sa.column("callback"),
        sa.column("callback_kwargs", sa.JSON()),
        sa.column("callback_new", sa.JSON()),
    )
    deadline_write = sa.table(
        "deadline",
        sa.column("id"),
        sa.column("callback_new", sa.JSON()),
    )

    while True:
        # Rows that still have a NULL callback_new have not been migrated yet;
        # every processed row gets a non-NULL envelope, so the loop terminates.
        rows = conn.execute(
            sa.select(
                deadline_read.c.id,
                deadline_read.c.callback,
                deadline_read.c.callback_kwargs,
            )
            .where(deadline_read.c.callback_new.is_(None))
            .limit(batch_size)
        ).fetchall()

        if not rows:
            break

        batch = []
        for row_id, path, kwargs in rows:
            # Some drivers hand JSON columns back as raw text rather than
            # decoded objects.
            if isinstance(kwargs, str):
                kwargs = json.loads(kwargs) if kwargs else {}
            # NULL or any non-object payload is normalised to "no kwargs".
            if not isinstance(kwargs, dict):
                kwargs = {}
            batch.append(
                {
                    "row_id": row_id,
                    "new_callback": {
                        "__data__": {"path": path or "", "kwargs": kwargs},
                        "__classname__": _ASYNC_CALLBACK_CLASSNAME,
                        "__version__": 0,
                    },
                }
            )

        # Passing a list of parameter sets runs the UPDATE once per row
        # (executemany) inside a single call.
        conn.execute(
            sa.update(deadline_write)
            .where(deadline_write.c.id == sa.bindparam("row_id"))
            .values(callback_new=sa.bindparam("new_callback")),
            batch,
        )

        # A short batch means the table is exhausted; skip the extra SELECT.
        if len(rows) < batch_size:
            break

    # Drop the legacy columns and promote the back-filled column. No new
    # "callback" column is added here: callback_new already holds the data
    # and is renamed into place, so adding one would collide with the rename.
    with op.batch_alter_table("deadline", schema=None) as batch_op:
        batch_op.drop_column("callback")
        batch_op.drop_column("callback_kwargs")
        batch_op.alter_column(
            "callback_new",
            new_column_name="callback",
            existing_type=sa.JSON(),
            nullable=False,
        )


def downgrade():
    """
    Replace deadline table's JSON callback with string callback and JSON callback_kwargs.

    In offline mode the existing rows cannot be read, so the table is emptied
    and the columns are swapped directly.

    In online mode the data is preserved: nullable ``callback_old`` (string)
    and ``callback_kwargs`` (JSON) columns are added next to the JSON
    ``callback`` column, back-filled in batches of
    ``[database] migration_batch_size`` rows (default 1000) from the stored
    envelope, and only then is the JSON column dropped and ``callback_old``
    renamed to ``callback``. Paths longer than the legacy column width are
    truncated with a warning; malformed envelopes fall back to an empty path
    and empty kwargs instead of aborting the migration.
    """
    if context.is_offline_mode():
        print(
            dedent("""
            ------------
            -- WARNING: Unable to migrate the data in the deadline table
            -- while in offline mode! All rows in the deadline table will
            -- be deleted.
            ------------
            """)
        )
        op.execute("DELETE FROM deadline")
        with op.batch_alter_table("deadline", schema=None) as batch_op:
            batch_op.drop_column("callback")
            batch_op.add_column(sa.Column("callback_kwargs", sa.JSON(), nullable=True))
            batch_op.add_column(sa.Column("callback", sa.String(length=500), nullable=False))
        return

    conn = op.get_bind()
    batch_size = conf.getint("database", "migration_batch_size", fallback=1000)

    # Add the restored columns alongside the existing JSON callback so we can
    # back-fill in batches before dropping the JSON column. The JSON column
    # must stay in place until the back-fill below has read it.
    with op.batch_alter_table("deadline", schema=None) as batch_op:
        batch_op.add_column(sa.Column("callback_old", sa.String(length=500), nullable=True))
        batch_op.add_column(sa.Column("callback_kwargs", sa.JSON(), nullable=True))

    deadline_read = sa.table(
        "deadline",
        sa.column("id"),
        sa.column("callback", sa.JSON()),
        sa.column("callback_old", sa.String(500)),
    )
    deadline_write = sa.table(
        "deadline",
        sa.column("id"),
        sa.column("callback_old", sa.String(500)),
        sa.column("callback_kwargs", sa.JSON()),
    )

    while True:
        # Rows that still have a NULL callback_old have not been converted yet;
        # every processed row gets a non-NULL string (possibly ""), so the
        # loop terminates.
        rows = conn.execute(
            sa.select(deadline_read.c.id, deadline_read.c.callback)
            .where(deadline_read.c.callback_old.is_(None))
            .limit(batch_size)
        ).fetchall()

        if not rows:
            break

        batch = []
        for row_id, cb in rows:
            # Some drivers hand JSON columns back as raw text rather than
            # decoded objects.
            if isinstance(cb, str):
                cb = json.loads(cb) if cb else {}
            # Accept both the serialized envelope ({"__data__": {...}}) and a
            # bare {"path": ..., "kwargs": ...} mapping. NULL or any other
            # shape yields an empty path and empty kwargs.
            cb_inner = cb.get("__data__", cb) if isinstance(cb, dict) else {}
            if not isinstance(cb_inner, dict):
                cb_inner = {}

            # A JSON null path is treated like a missing one.
            path = cb_inner.get("path") or ""
            if not isinstance(path, str):
                print(
                    f"WARNING: callback path for deadline {row_id} is not a string "
                    f"(type={type(path).__name__}); resetting to empty string."
                )
                path = ""
            if len(path) > _CALLBACK_MAX_LEN:
                print(
                    f"WARNING: callback path for deadline {row_id} exceeds "
                    f"{_CALLBACK_MAX_LEN} chars and will be truncated."
                )
                path = path[:_CALLBACK_MAX_LEN]

            kwargs = cb_inner.get("kwargs", {})
            if not isinstance(kwargs, dict):
                print(
                    f"WARNING: kwargs for deadline {row_id} is not a dict "
                    f"(type={type(kwargs).__name__}); resetting to empty dict."
                )
                kwargs = {}

            batch.append({"row_id": row_id, "old_callback": path, "old_kwargs": kwargs})

        # Passing a list of parameter sets runs the UPDATE once per row
        # (executemany) inside a single call.
        conn.execute(
            sa.update(deadline_write)
            .where(deadline_write.c.id == sa.bindparam("row_id"))
            .values(
                callback_old=sa.bindparam("old_callback"),
                callback_kwargs=sa.bindparam("old_kwargs"),
            ),
            batch,
        )

        # A short batch means the table is exhausted; skip the extra SELECT.
        if len(rows) < batch_size:
            break

    # Every row now has its string form, so the JSON column can go and the
    # back-filled column can take over its name.
    with op.batch_alter_table("deadline", schema=None) as batch_op:
        batch_op.drop_column("callback")
        batch_op.alter_column(
            "callback_old",
            new_column_name="callback",
            existing_type=sa.String(500),
            nullable=False,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def _upgrade_postgresql(conn, batch_size):
d.id AS deadline_id,
gen_random_uuid() AS callback_id,
COALESCE(dr.dag_id, '') AS dag_id,
d.callback::jsonb->'__data__'->>'path' AS cb_path,
d.callback::jsonb->'__data__'->'kwargs' AS cb_kwargs,
COALESCE(d.callback::jsonb->'__data__'->>'path', '') AS cb_path,
COALESCE(NULLIF(d.callback::jsonb->'__data__'->'kwargs', 'null'::jsonb), '{}'::jsonb) AS cb_kwargs,
CASE
WHEN d.callback_state IN (:state_success, :state_failed) THEN d.callback_state
ELSE :state_pending
Expand Down Expand Up @@ -177,6 +177,7 @@ def _upgrade_mysql_sqlite(conn, batch_size):
)

batch_num = 0
null_callback_count = 0
while True:
batch_num += 1
batch = conn.execute(
Expand All @@ -199,11 +200,23 @@ def _upgrade_mysql_sqlite(conn, batch_size):

for row in batch:
callback_id = uuid6.uuid7()
cb = row.callback if isinstance(row.callback, dict) else json.loads(row.callback)
cb_inner = cb.get("__data__", cb)
raw_cb = row.callback
if raw_cb is None:
null_callback_count += 1
cb = {}
elif isinstance(raw_cb, dict):
cb = raw_cb
else:
cb = json.loads(raw_cb)
cb_inner = cb.get("__data__", cb) if isinstance(cb, dict) else {}
if not isinstance(cb_inner, dict):
cb_inner = {}
kwargs = cb_inner.get("kwargs", {})
if not isinstance(kwargs, dict):
kwargs = {}
cb_data = {
"path": cb_inner.get("path", ""),
"kwargs": cb_inner.get("kwargs", {}),
"path": cb_inner.get("path", "") or "",
"kwargs": kwargs,
"prefix": _CALLBACK_METRICS_PREFIX,
"dag_id": row.dag_id or "",
}
Expand Down Expand Up @@ -237,6 +250,12 @@ def _upgrade_mysql_sqlite(conn, batch_size):
)
print(f"Migrated {len(batch)} deadline records in batch {batch_num}")

if null_callback_count:
print(
f"WARNING: {null_callback_count} deadline rows had NULL callback "
"(legacy 0080 data); migrated with empty envelope."
)


def upgrade():
"""Replace Deadline table's inline callback fields with callback_id foreign key."""
Expand Down
Loading
Loading