Commit 3f6ac2f

SQL query improvements in utils/db.py (#32518)
uranusjr committed Jul 11, 2023
1 parent 3a8da4b commit 3f6ac2f
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions airflow/utils/db.py
@@ -30,7 +30,6 @@
 from typing import TYPE_CHECKING, Callable, Generator, Iterable
 
 from sqlalchemy import Table, and_, column, delete, exc, func, inspect, or_, select, table, text, tuple_
-from sqlalchemy.orm.session import Session
 
 import airflow
 from airflow import settings
@@ -45,9 +44,10 @@
 if TYPE_CHECKING:
     from alembic.runtime.environment import EnvironmentContext
     from alembic.script import ScriptDirectory
-    from sqlalchemy.orm import Query
+    from sqlalchemy.orm import Query, Session
 
     from airflow.models.base import Base
+    from airflow.models.connection import Connection
 
 log = logging.getLogger(__name__)
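The import hunks above drop the runtime import of Session and re-import it, alongside Query, under TYPE_CHECKING. A minimal sketch of that pattern, assuming (as hedged here) a module that uses `from __future__ import annotations` so annotations are never evaluated at runtime:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only imported by type checkers; no sqlalchemy.orm import at runtime.
        from sqlalchemy.orm import Session

    def needs_session(session: Session) -> None:  # stays a string at runtime
        ...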

@@ -90,9 +90,9 @@ def _format_airflow_moved_table_name(source_table, version, category):


 @provide_session
-def merge_conn(conn, session: Session = NEW_SESSION):
+def merge_conn(conn: Connection, session: Session = NEW_SESSION):
     """Add new Connection."""
-    if not session.scalar(select(conn.__class__).filter_by(conn_id=conn.conn_id).limit(1)):
+    if not session.scalar(select(1).where(conn.__class__.conn_id == conn.conn_id)):
         session.add(conn)
         session.commit()
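The merge_conn change replaces "load the whole Connection row, LIMIT 1" with a bare SELECT 1 existence probe: the database only has to prove a matching row exists, no ORM object is constructed, and session.scalar() fetches just the first row anyway. A self-contained sketch of the pattern, using a hypothetical toy model rather than Airflow's Connection:

    from sqlalchemy import Column, Integer, String, create_engine, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Conn(Base):  # illustrative stand-in for Airflow's Connection model
        __tablename__ = "conn"
        id = Column(Integer, primary_key=True)
        conn_id = Column(String(250))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Conn(conn_id="my_conn"))
        session.commit()
        # Renders roughly: SELECT 1 FROM conn WHERE conn.conn_id = ?
        exists = session.scalar(select(1).where(Conn.conn_id == "my_conn"))
        print(bool(exists))  # True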

@@ -957,20 +957,20 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]:
"""
from airflow.models.connection import Connection

dups = []
try:
dups = session.execute(
dups = session.scalars(
select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1)
).all()
except (exc.OperationalError, exc.ProgrammingError):
# fallback if tables hasn't been created yet
session.rollback()
return
if dups:
yield (
"Seems you have non unique conn_id in connection table.\n"
"You have to manage those duplicate connections "
"before upgrading the database.\n"
f"Duplicated conn_id: {[dup.conn_id for dup in dups]}"
f"Duplicated conn_id: {dups}"
)
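The switch from session.execute() to session.scalars() is what lets the message interpolate dups directly: execute() returns Row objects, so each entry had to be unpacked as dup.conn_id, while scalars() yields the first column of each row as a plain value. A small self-contained illustration with a toy model (illustrative names, not Airflow's):

    from sqlalchemy import Column, Integer, String, create_engine, func, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Conn(Base):  # illustrative stand-in for Airflow's Connection model
        __tablename__ = "conn"
        id = Column(Integer, primary_key=True)
        conn_id = Column(String(250))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Conn(conn_id="a"), Conn(conn_id="a"), Conn(conn_id="b")])
        session.commit()
        dupes = select(Conn.conn_id).group_by(Conn.conn_id).having(func.count() > 1)
        print(session.execute(dupes).all())  # [('a',)]  -- Row tuples
        print(session.scalars(dupes).all())  # ['a']     -- bare values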


@@ -1057,11 +1057,11 @@ def check_task_fail_for_duplicates(session):
     :param uniqueness: uniqueness constraint to evaluate against
     :param session: session of the sqlalchemy
     """
-    minimal_table_obj = table(table_name, *[column(x) for x in uniqueness])
+    minimal_table_obj = table(table_name, *(column(x) for x in uniqueness))
     try:
         subquery = session.execute(
             select(minimal_table_obj, func.count().label("dupe_count"))
-            .group_by(*[text(x) for x in uniqueness])
+            .group_by(*(text(x) for x in uniqueness))
             .having(func.count() > text("1"))
             .subquery()
         )
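This hunk, like several below, swaps *[...] for *(...) when unpacking into a call: a generator expression feeds the call directly instead of materialising a throwaway list first, and the result is identical. A quick demonstration:

    from sqlalchemy import column, table

    uniqueness = ["dag_id", "task_id"]

    # List comprehension: builds an intermediate list, then unpacks it.
    t1 = table("task_fail", *[column(x) for x in uniqueness])
    # Generator expression: items are produced as the call consumes them.
    t2 = table("task_fail", *(column(x) for x in uniqueness))

    print(list(t1.c.keys()) == list(t2.c.keys()))  # True -- same columns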
@@ -1100,20 +1100,20 @@ def check_conn_type_null(session: Session) -> Iterable[str]:
"""
from airflow.models.connection import Connection

n_nulls = []
try:
n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all()
except (exc.OperationalError, exc.ProgrammingError, exc.InternalError):
# fallback if tables hasn't been created yet
session.rollback()
return

if n_nulls:
yield (
"The conn_type column in the connection "
"table must contain content.\n"
"Make sure you don't have null "
"in the conn_type column.\n"
f"Null conn_type conn_id: {list(n_nulls)}"
f"Null conn_type conn_id: {n_nulls}"
)


@@ -1265,7 +1265,7 @@ def _dangling_against_dag_run(session, source_table, dag_run):
     )
 
     return (
-        select(*[c.label(c.name) for c in source_table.c])
+        select(*(c.label(c.name) for c in source_table.c))
         .join(dag_run, source_to_dag_run_join_cond, isouter=True)
         .where(dag_run.c.dag_id.is_(None))
     )
@@ -1306,9 +1306,9 @@ def _dangling_against_task_instance(session, source_table, dag_run, task_instance):
     )
 
     return (
-        select(*[c.label(c.name) for c in source_table.c])
-        .join(dag_run, dr_join_cond, isouter=True)
-        .join(task_instance, ti_join_cond, isouter=True)
+        select(*(c.label(c.name) for c in source_table.c))
+        .outerjoin(dag_run, dr_join_cond)
+        .outerjoin(task_instance, ti_join_cond)
         .where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
     )
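Here .join(target, cond, isouter=True) becomes .outerjoin(target, cond). The two spellings compile to the same LEFT OUTER JOIN; outerjoin simply states the intent directly. A minimal self-contained check with illustrative table names:

    from sqlalchemy import Column, Integer, MetaData, Table, select

    metadata = MetaData()
    parent = Table("parent", metadata, Column("id", Integer, primary_key=True))
    child = Table(
        "child",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("parent_id", Integer),
    )

    old_style = select(parent).join(child, parent.c.id == child.c.parent_id, isouter=True)
    new_style = select(parent).outerjoin(child, parent.c.id == child.c.parent_id)

    print(str(old_style) == str(new_style))  # True -- identical SQL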

@@ -1335,9 +1335,9 @@ def _move_duplicate_data_to_new_table(
     dialect_name = bind.dialect.name
 
     query = (
-        select(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns])
+        select(*(source_table.c[x.name].label(str(x.name)) for x in source_table.columns))
         .select_from(source_table)
-        .join(subquery, and_(*[getattr(source_table.c, x) == getattr(subquery.c, x) for x in uniqueness]))
+        .join(subquery, and_(*(source_table.c[x] == subquery.c[x] for x in uniqueness)))
     )
 
     _create_table_as(
@@ -1353,7 +1353,7 @@

     metadata = reflect_tables([target_table_name], session)
     target_table = metadata.tables[target_table_name]
-    where_clause = and_(*[getattr(source_table.c, x) == getattr(target_table.c, x) for x in uniqueness])
+    where_clause = and_(*(source_table.c[x] == target_table.c[x] for x in uniqueness))
 
     if dialect_name == "sqlite":
         subq = query.selectable.with_only_columns([text(f"{source_table}.ROWID")])
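These hunks also replace getattr(source_table.c, x) with the indexed lookup source_table.c[x]. Both resolve to the same Column object; the indexed form is the collection's documented lookup API and also works for column names that aren't valid Python identifiers. A sketch with an illustrative table:

    from sqlalchemy import Column, Integer, MetaData, Table

    metadata = MetaData()
    ti = Table(
        "task_instance",
        metadata,
        Column("dag_id", Integer),
        Column("task_id", Integer),
    )

    # Indexed and attribute access resolve to the very same Column object.
    assert ti.c["dag_id"] is getattr(ti.c, "dag_id") is ti.c.dag_id
    print(ti.c["dag_id"])  # task_instance.dag_id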
@@ -1410,7 +1410,7 @@ class BadReferenceConfig:
         (TaskFail, "2.3", missing_ti_config),
         (XCom, "2.3", missing_ti_config),
     ]
-    metadata = reflect_tables([*[x[0] for x in models_list], DagRun, TaskInstance], session)
+    metadata = reflect_tables([*(x[0] for x in models_list), DagRun, TaskInstance], session)
 
     if (
         not metadata.tables