Refactor sqlalchemy queries to 2.0 style (Part 1) (#31569)
* Refactor sqlalchemy queries to 2.0 style (Part 1)

This commit updates the SQLAlchemy queries to the 2.0 query style, which is
also compatible with version 1.4. The engine is now created with the
future=True flag so that queries execute in 2.0 style; as a result, textual
SQL statements must be wrapped in the text() function.
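A minimal sketch of that pattern (the in-memory SQLite URL is illustrative,
not Airflow's actual configuration):

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import sessionmaker

    # future=True opts a 1.4 engine/session into 2.0-style execution.
    engine = create_engine("sqlite://", future=True)
    Session = sessionmaker(bind=engine, future=True)

    with Session() as session:
        # In 2.0 style, raw SQL strings must be wrapped in text().
        print(session.execute(text("SELECT 1")).scalar())  # -> 1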

In addition, queries that previously used the legacy Query delete and update
methods now use the new delete()/update() constructs, and all queries within
the jobs/ and api/ directories have been updated to the new style.
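A sketch of the old and new bulk-delete/update shapes, using a hypothetical
MyModel rather than any specific Airflow model:

    from sqlalchemy import delete, update

    # Legacy 1.x style: bulk delete through Query.
    count = session.query(MyModel).filter(MyModel.dag_id == dag_id).delete()

    # 2.0 style: build a delete() construct, execute it, and read the
    # matched row count off the result.
    count = session.execute(delete(MyModel).where(MyModel.dag_id == dag_id)).rowcount

    # update() follows the same shape.
    session.execute(update(MyModel).where(MyModel.dag_id == dag_id).values(state="failed"))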

This commit intentionally stops at this point to keep the pull request easy
to review. The only test change is the addition of the future flag to the
create_engine() call.

* Apply suggestions from code review

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>

* Use session.scalar()

* Fix typing and remove unique()

* fixup! Apply suggestions from code review

* revert session.scalar()

* use session.scalars where possible

* Use session.scalar() with limit

* Fix pre-commit

* fixup! Use session.scalar() with limit
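A rough sketch of the session.scalars()/session.scalar() patterns settled on
in the fixup commits above (Pool stands in for any mapped model):

    from sqlalchemy import select

    # Query.first() becomes session.scalar() on a select() with limit(1).
    pool = session.scalar(select(Pool).where(Pool.pool == name).limit(1))

    # Query.all() becomes session.scalars(...).all(), which yields ORM
    # objects rather than Row tuples.
    pools = session.scalars(select(Pool)).all()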

---------

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
ephraimbuddy and uranusjr committed May 30, 2023
1 parent 0cbc0dc commit 0f1cef2
Showing 39 changed files with 443 additions and 362 deletions.
47 changes: 26 additions & 21 deletions airflow/api/common/delete_dag.py
@@ -20,7 +20,7 @@

import logging

-from sqlalchemy import and_, or_
+from sqlalchemy import and_, delete, or_, select
from sqlalchemy.orm import Session

from airflow import models
@@ -47,25 +47,28 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session =
:return count of deleted dags
"""
log.info("Deleting DAG: %s", dag_id)
-running_tis = (
-session.query(models.TaskInstance.state)
-.filter(models.TaskInstance.dag_id == dag_id)
-.filter(models.TaskInstance.state == State.RUNNING)
-.first()
+running_tis = session.scalar(
+select(models.TaskInstance.state)
+.where(models.TaskInstance.dag_id == dag_id)
+.where(models.TaskInstance.state == State.RUNNING)
+.limit(1)
)
if running_tis:
raise AirflowException("TaskInstances still running")
-dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
+dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id).limit(1))
if dag is None:
raise DagNotFound(f"Dag id {dag_id} not found")

# deleting a DAG should also delete all of its subdags
-dags_to_delete_query = session.query(DagModel.dag_id).filter(
-or_(
-DagModel.dag_id == dag_id,
-and_(DagModel.dag_id.like(f"{dag_id}.%"), DagModel.is_subdag),
+dags_to_delete_query = session.execute(
+select(DagModel.dag_id).where(
+or_(
+DagModel.dag_id == dag_id,
+and_(DagModel.dag_id.like(f"{dag_id}.%"), DagModel.is_subdag),
+)
)
)

dags_to_delete = [dag_id for dag_id, in dags_to_delete_query]

# Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
@@ -79,22 +82,24 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session =
if hasattr(model, "dag_id"):
if keep_records_in_log and model.__name__ == "Log":
continue
-count += (
-session.query(model)
-.filter(model.dag_id.in_(dags_to_delete))
-.delete(synchronize_session="fetch")
-)
+count += session.execute(
+delete(model)
+.where(model.dag_id.in_(dags_to_delete))
+.execution_options(synchronize_session="fetch")
+).rowcount
if dag.is_subdag:
parent_dag_id, task_id = dag_id.rsplit(".", 1)
for model in TaskFail, models.TaskInstance:
-count += (
-session.query(model).filter(model.dag_id == parent_dag_id, model.task_id == task_id).delete()
-)
+count += session.execute(
+delete(model).where(model.dag_id == parent_dag_id, model.task_id == task_id)
+).rowcount

# Delete entries in Import Errors table for a deleted DAG
# This handles the case when the dag_id is changed in the file
-session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete(
-synchronize_session="fetch"
+session.execute(
+delete(models.ImportError)
+.where(models.ImportError.filename == dag.fileloc)
+.execution_options(synchronize_session="fetch")
)

return count
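One subtlety in the hunks above: session.execute() on a single-column
select() yields Row tuples rather than bare values, which is why the
comprehension unpacks each row with a trailing comma. A sketch:

    # Each result row is a one-element Row; "for dag_id," unpacks it.
    rows = session.execute(select(DagModel.dag_id))
    dags_to_delete = [dag_id for dag_id, in rows]

    # Equivalent form using scalars(), which yields bare column values.
    dags_to_delete = list(session.scalars(select(DagModel.dag_id)))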
7 changes: 4 additions & 3 deletions airflow/api/common/experimental/pool.py
@@ -19,6 +19,7 @@
from __future__ import annotations

from deprecated import deprecated
+from sqlalchemy import select
from sqlalchemy.orm import Session

from airflow.exceptions import AirflowBadRequest, PoolNotFound
@@ -33,7 +34,7 @@ def get_pool(name, session: Session = NEW_SESSION):
if not (name and name.strip()):
raise AirflowBadRequest("Pool name shouldn't be empty")

-pool = session.query(Pool).filter_by(pool=name).first()
+pool = session.scalar(select(Pool).filter_by(pool=name).limit(1))
if pool is None:
raise PoolNotFound(f"Pool '{name}' doesn't exist")

@@ -65,7 +66,7 @@ def create_pool(name, slots, description, session: Session = NEW_SESSION):
raise AirflowBadRequest(f"Pool name can't be more than {pool_name_length} characters")

session.expire_on_commit = False
-pool = session.query(Pool).filter_by(pool=name).first()
+pool = session.scalar(select(Pool).filter_by(pool=name).limit(1))
if pool is None:
pool = Pool(pool=name, slots=slots, description=description)
session.add(pool)
@@ -88,7 +89,7 @@ def delete_pool(name, session: Session = NEW_SESSION):
if name == Pool.DEFAULT_POOL_NAME:
raise AirflowBadRequest(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")

-pool = session.query(Pool).filter_by(pool=name).first()
+pool = session.scalar(select(Pool).filter_by(pool=name).limit(1))
if pool is None:
raise PoolNotFound(f"Pool '{name}' doesn't exist")

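Note that filter_by() composes onto select() in SQLAlchemy 1.4+ just as it
did on Query, so these hunks only swap the execution style; a sketch:

    from sqlalchemy import select

    # filter_by(pool=name) is keyword shorthand for
    # .where(Pool.pool == name) against the statement's primary entity.
    pool = session.scalar(select(Pool).filter_by(pool=name).limit(1))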
56 changes: 30 additions & 26 deletions airflow/api/common/mark_tasks.py
@@ -21,7 +21,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple

-from sqlalchemy import or_
+from sqlalchemy import or_, select
from sqlalchemy.orm import Session as SASession, lazyload

from airflow.models.dag import DAG
@@ -148,10 +148,10 @@ def set_state(
qry_dag = get_all_dag_task_query(dag, session, state, task_id_map_index_list, dag_run_ids)

if commit:
-tis_altered = qry_dag.with_for_update().all()
+tis_altered = session.scalars(qry_dag.with_for_update()).all()
if sub_dag_run_ids:
qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
-tis_altered += qry_sub_dag.with_for_update().all()
+tis_altered += session.scalars(qry_sub_dag.with_for_update()).all()
for task_instance in tis_altered:
# The try_number was decremented when setting to up_for_reschedule and deferred.
# Increment it back when changing the state again
@@ -160,10 +160,10 @@
task_instance.set_state(state, session=session)
session.flush()
else:
-tis_altered = qry_dag.all()
+tis_altered = session.scalars(qry_dag).all()
if sub_dag_run_ids:
qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
-tis_altered += qry_sub_dag.all()
+tis_altered += session.scalars(qry_sub_dag).all()
return tis_altered


@@ -175,9 +175,9 @@ def all_subdag_tasks_query(
):
"""Get *all* tasks of the sub dags."""
qry_sub_dag = (
-session.query(TaskInstance)
-.filter(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates))
-.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
+select(TaskInstance)
+.where(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates))
+.where(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
)
return qry_sub_dag

@@ -190,13 +190,13 @@ def get_all_dag_task_query(
run_ids: Iterable[str],
):
"""Get all tasks of the main dag that will be affected by a state change."""
-qry_dag = session.query(TaskInstance).filter(
+qry_dag = select(TaskInstance).where(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id.in_(run_ids),
TaskInstance.ti_selector_condition(task_ids),
)

-qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
+qry_dag = qry_dag.where(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
lazyload(TaskInstance.dag_run)
)
return qry_dag
@@ -324,11 +324,8 @@ def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASess
"""Return DAG executions' run_ids."""
last_dagrun = dag.get_last_dagrun(include_externally_triggered=True, session=session)
current_dagrun = dag.get_dagrun(run_id=run_id, session=session)
-first_dagrun = (
-session.query(DagRun)
-.filter(DagRun.dag_id == dag.dag_id)
-.order_by(DagRun.execution_date.asc())
-.first()
+first_dagrun = session.scalar(
+select(DagRun).filter(DagRun.dag_id == dag.dag_id).order_by(DagRun.execution_date.asc()).limit(1)
)

if last_dagrun is None:
@@ -361,7 +358,7 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA
:param state: target state
:param session: database session
"""
-dag_run = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one()
+dag_run = session.execute(
+select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
+).scalar_one()
dag_run.state = state
if state == State.RUNNING:
dag_run.start_date = timezone.utcnow()
@@ -464,12 +463,15 @@ def set_dag_run_state_to_failed(

# Mark only RUNNING task instances.
task_ids = [task.task_id for task in dag.tasks]
-tis = session.query(TaskInstance).filter(
-TaskInstance.dag_id == dag.dag_id,
-TaskInstance.run_id == run_id,
-TaskInstance.task_id.in_(task_ids),
-TaskInstance.state.in_([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]),
+tis = session.scalars(
+select(TaskInstance).where(
+TaskInstance.dag_id == dag.dag_id,
+TaskInstance.run_id == run_id,
+TaskInstance.task_id.in_(task_ids),
+TaskInstance.state.in_([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]),
+)
)

task_ids_of_running_tis = [task_instance.task_id for task_instance in tis]

tasks = []
Expand All @@ -480,11 +482,13 @@ def set_dag_run_state_to_failed(
tasks.append(task)

# Mark non-finished tasks as SKIPPED.
-tis = session.query(TaskInstance).filter(
-TaskInstance.dag_id == dag.dag_id,
-TaskInstance.run_id == run_id,
-TaskInstance.state.not_in(State.finished),
-TaskInstance.state.not_in([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]),
+tis = session.scalars(
+select(TaskInstance).filter(
+TaskInstance.dag_id == dag.dag_id,
+TaskInstance.run_id == run_id,
+TaskInstance.state.not_in(State.finished),
+TaskInstance.state.not_in([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]),
+)
)

tis = [ti for ti in tis]
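Because a 2.0-style select() is an inert statement object, helpers such as
get_all_dag_task_query() now return statements that the caller composes
further and then executes; a sketch of that flow using names from the hunks
above:

    # The helper builds, but does not run, a SELECT statement.
    stmt = get_all_dag_task_query(dag, session, state, task_id_map_index_list, dag_run_ids)

    # The caller may still compose onto it (e.g. row locking) before executing.
    tis_altered = session.scalars(stmt.with_for_update()).all()

    # Query.one() becomes Result.scalar_one(), which raises unless exactly
    # one row matched.
    dag_run = session.execute(
        select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
    ).scalar_one()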
7 changes: 5 additions & 2 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -23,7 +23,7 @@
from flask import g
from flask_login import current_user
from marshmallow import ValidationError
-from sqlalchemy import or_
+from sqlalchemy import delete, or_
from sqlalchemy.orm import Query, Session

from airflow.api.common.mark_tasks import (
@@ -74,7 +74,10 @@
@provide_session
def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Delete a DAG Run."""
-if session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).delete() == 0:
+deleted_count = session.execute(
+delete(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)
+).rowcount
+if deleted_count == 0:
raise NotFound(detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found")
return NoContent, HTTPStatus.NO_CONTENT

5 changes: 3 additions & 2 deletions airflow/api_connexion/endpoints/pool_endpoint.py
@@ -20,7 +20,7 @@

from flask import Response
from marshmallow import ValidationError
-from sqlalchemy import func
+from sqlalchemy import delete, func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

@@ -41,7 +41,8 @@ def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIRespons
"""Delete a pool."""
if pool_name == "default_pool":
raise BadRequest(detail="Default Pool can't be deleted")
-affected_count = session.query(Pool).filter(Pool.pool == pool_name).delete()
+affected_count = session.execute(delete(Pool).where(Pool.pool == pool_name)).rowcount
+
if affected_count == 0:
raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
return Response(status=HTTPStatus.NO_CONTENT)
3 changes: 2 additions & 1 deletion airflow/cli/commands/dag_command.py
@@ -28,6 +28,7 @@
import warnings

from graphviz.dot import Dot
+from sqlalchemy import delete
from sqlalchemy.orm import Session

from airflow import settings
@@ -507,7 +508,7 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No
@cli_utils.action_cli
def dag_reserialize(args, session: Session = NEW_SESSION) -> None:
"""Serialize a DAG instance."""
-session.query(SerializedDagModel).delete(synchronize_session=False)
+session.execute(delete(SerializedDagModel).execution_options(synchronize_session=False))

if not args.clear_only:
dagbag = DagBag(process_subdir(args.subdir))
10 changes: 6 additions & 4 deletions airflow/dag_processing/processor.py
@@ -30,7 +30,7 @@
from typing import TYPE_CHECKING, Iterable, Iterator

from setproctitle import setproctitle
-from sqlalchemy import exc, func, or_
+from sqlalchemy import delete, exc, func, or_
from sqlalchemy.orm.session import Session

from airflow import settings
@@ -610,9 +610,11 @@ def update_import_errors(
# Clear the errors of the processed files
# that no longer have errors
for dagbag_file in files_without_error:
-session.query(errors.ImportError).filter(
-errors.ImportError.filename.startswith(dagbag_file)
-).delete(synchronize_session="fetch")
+session.execute(
+delete(errors.ImportError)
+.where(errors.ImportError.filename.startswith(dagbag_file))
+.execution_options(synchronize_session="fetch")
+)

# files that still have errors
existing_import_error_files = [x.filename for x in session.query(errors.ImportError.filename).all()]
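For reference, the synchronize_session strategy that Query.delete() took as
an argument now travels on the statement via .execution_options(); this
commit uses both "fetch" (the ImportError cleanup above) and False
(dag_reserialize). A sketch with a hypothetical MyModel:

    from sqlalchemy import delete

    # "fetch": select the matching rows first so objects already in the
    # session are expired to reflect the delete.
    session.execute(delete(MyModel).execution_options(synchronize_session="fetch"))

    # False: skip session synchronization entirely; cheapest, and safe when
    # none of the affected rows are loaded in the session.
    session.execute(delete(MyModel).execution_options(synchronize_session=False))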
