Skip to content

Commit

Permalink
refactor(api_connexion): merge multiple SQLs into one SQL
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Feb 19, 2024
1 parent 941fb30 commit beeba70
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions airflow/api_connexion/endpoints/dataset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,11 @@ def _generate_queued_event_where_clause(
if dataset_id is not None:
where_clause.append(DatasetDagRunQueue.dataset_id == dataset_id)
if uri is not None:
where_clause.append(DatasetModel.uri == uri)
where_clause.append(
DatasetDagRunQueue.dataset_id.in_(
select(DatasetModel.id).where(DatasetModel.uri == uri),
),
)
if before is not None:
where_clause.append(DatasetDagRunQueue.created_at < format_datetime(before))
if permitted_dag_ids is not None:
Expand Down Expand Up @@ -185,9 +189,10 @@ def delete_dag_dataset_queued_event(
*, dag_id: str, uri: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Delete a queued Dataset event for a DAG."""
dataset_id = session.scalars(select(DatasetModel.id).where(DatasetModel.uri == uri)).one_or_none()
where_clause = _generate_queued_event_where_clause(dag_id=dag_id, dataset_id=dataset_id, before=before)
delete_stmt = delete(DatasetDagRunQueue).where(*where_clause)
where_clause = _generate_queued_event_where_clause(dag_id=dag_id, uri=uri, before=before)
delete_stmt = (
delete(DatasetDagRunQueue).where(*where_clause).execution_options(synchronize_session="fetch")
)
result = session.execute(delete_stmt)
if result.rowcount > 0:
return NoContent, HTTPStatus.NO_CONTENT
Expand Down Expand Up @@ -281,13 +286,14 @@ def delete_dataset_queued_events(
*, uri: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Delete queued Dataset events for a Dataset."""
dataset_id = session.scalars(select(DatasetModel.id).where(DatasetModel.uri == uri)).one_or_none()
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"])
where_clause = _generate_queued_event_where_clause(
dataset_id=dataset_id, before=before, permitted_dag_ids=permitted_dag_ids
uri=uri, before=before, permitted_dag_ids=permitted_dag_ids
)
delete_stmt = (
delete(DatasetDagRunQueue).where(*where_clause).execution_options(synchronize_session="fetch")
)

delete_stmt = delete(DatasetDagRunQueue).where(*where_clause)
result = session.execute(delete_stmt)
if result.rowcount > 0:
return NoContent, HTTPStatus.NO_CONTENT
Expand Down

0 comments on commit beeba70

Please sign in to comment.