Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions airflow-core/src/airflow/models/serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@

if TYPE_CHECKING:
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.elements import ColumnElement


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -448,7 +450,7 @@ def write_dag(
)
)

if result.rowcount == 0:
if getattr(result, "rowcount", 0) == 0:
# No rows updated - serialized DAG doesn't exist
return False
# The dag_version and dag_code may not have changed, still we should
Expand Down Expand Up @@ -491,7 +493,7 @@ def latest_item_select_object(cls, dag_id):
@provide_session
def get_latest_serialized_dags(
cls, *, dag_ids: list[str], session: Session = NEW_SESSION
) -> list[SerializedDagModel]:
) -> Sequence[SerializedDagModel]:
"""
Get the latest serialized dags of given DAGs.

Expand Down Expand Up @@ -613,7 +615,8 @@ def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[

:param session: ORM Session
"""
load_json: Callable | None
load_json: Callable
data_col_to_select: ColumnElement[Any] | InstrumentedAttribute[bytes | None]
if COMPRESS_SERIALIZED_DAGS is False:
dialect = get_dialect_name(session)
if dialect in ["sqlite", "mysql"]:
Expand All @@ -625,10 +628,10 @@ def load_json(deps_data):
# Use #> operator which works for both JSON and JSONB types
# Returns the JSON sub-object at the specified path
data_col_to_select = cls._data.op("#>")(literal('{"dag","dag_dependencies"}'))
load_json = None
load_json = lambda x: x
else:
data_col_to_select = func.json_extract_path(cls._data, "dag", "dag_dependencies")
load_json = None
load_json = lambda x: x
else:
data_col_to_select = cls._data_compressed

Expand All @@ -648,11 +651,7 @@ def load_json(deps_data):
.join(cls.dag_model)
.where(~DagModel.is_stale)
)
iterator = (
[(dag_id, load_json(deps_data)) for dag_id, deps_data in query]
if load_json is not None
else query.all()
)
resolver = _DagDependenciesResolver(dag_id_dependencies=iterator, session=session)
dag_depdendencies = [(str(dag_id), load_json(deps_data)) for dag_id, deps_data in query]
resolver = _DagDependenciesResolver(dag_id_dependencies=dag_depdendencies, session=session)
dag_depdendencies_by_dag = resolver.resolve()
return dag_depdendencies_by_dag