Skip to content

Commit

Permalink
Add "queuedEvent" endpoint to get/delete DatasetDagRunQueue (#37176)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Feb 20, 2024
1 parent de4502f commit 16d2671
Show file tree
Hide file tree
Showing 7 changed files with 1,026 additions and 4 deletions.
183 changes: 180 additions & 3 deletions airflow/api_connexion/endpoints/dataset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,31 @@
# under the License.
from __future__ import annotations

from http import HTTPStatus
from typing import TYPE_CHECKING

from sqlalchemy import func, select
from connexion import NoContent
from sqlalchemy import delete, func, select
from sqlalchemy.orm import joinedload, subqueryload

from airflow.api_connexion import security
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters
from airflow.api_connexion.schemas.dataset_schema import (
DatasetCollection,
DatasetEventCollection,
QueuedEvent,
QueuedEventCollection,
dataset_collection_schema,
dataset_event_collection_schema,
dataset_schema,
queued_event_collection_schema,
queued_event_schema,
)
from airflow.models.dataset import DatasetEvent, DatasetModel
from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.www.extensions.init_auth_manager import get_auth_manager

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -124,3 +131,173 @@ def get_dataset_events(
return dataset_event_collection_schema.dump(
DatasetEventCollection(dataset_events=events, total_entries=total_entries)
)


def _generate_queued_event_where_clause(
*,
dag_id: str | None = None,
dataset_id: int | None = None,
uri: str | None = None,
before: str | None = None,
permitted_dag_ids: set[str] | None = None,
) -> list:
"""Get DatasetDagRunQueue where clause."""
where_clause = []
if dag_id is not None:
where_clause.append(DatasetDagRunQueue.target_dag_id == dag_id)
if dataset_id is not None:
where_clause.append(DatasetDagRunQueue.dataset_id == dataset_id)
if uri is not None:
where_clause.append(
DatasetDagRunQueue.dataset_id.in_(
select(DatasetModel.id).where(DatasetModel.uri == uri),
),
)
if before is not None:
where_clause.append(DatasetDagRunQueue.created_at < format_datetime(before))
if permitted_dag_ids is not None:
where_clause.append(DatasetDagRunQueue.target_dag_id.in_(permitted_dag_ids))
return where_clause


@security.requires_access_dataset("GET")
@security.requires_access_dag("GET")
@provide_session
def get_dag_dataset_queued_event(
*, dag_id: str, uri: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Get a queued Dataset event for a DAG."""
where_clause = _generate_queued_event_where_clause(dag_id=dag_id, uri=uri, before=before)
ddrq = session.scalar(
select(DatasetDagRunQueue)
.join(DatasetModel, DatasetDagRunQueue.dataset_id == DatasetModel.id)
.where(*where_clause)
)
if ddrq is None:
raise NotFound(
"Queue event not found",
detail=f"Queue event with dag_id: `{dag_id}` and dataset uri: `{uri}` was not found",
)
queued_event = {"created_at": ddrq.created_at, "dag_id": dag_id, "uri": uri}
return queued_event_schema.dump(queued_event)


@security.requires_access_dataset("DELETE")
@security.requires_access_dag("GET")
@provide_session
def delete_dag_dataset_queued_event(
*, dag_id: str, uri: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Delete a queued Dataset event for a DAG."""
where_clause = _generate_queued_event_where_clause(dag_id=dag_id, uri=uri, before=before)
delete_stmt = (
delete(DatasetDagRunQueue).where(*where_clause).execution_options(synchronize_session="fetch")
)
result = session.execute(delete_stmt)
if result.rowcount > 0:
return NoContent, HTTPStatus.NO_CONTENT
raise NotFound(
"Queue event not found",
detail=f"Queue event with dag_id: `{dag_id}` and dataset uri: `{uri}` was not found",
)


@security.requires_access_dataset("GET")
@security.requires_access_dag("GET")
@provide_session
def get_dag_dataset_queued_events(
*, dag_id: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Get queued Dataset events for a DAG."""
where_clause = _generate_queued_event_where_clause(dag_id=dag_id, before=before)
query = (
select(DatasetDagRunQueue, DatasetModel.uri)
.join(DatasetModel, DatasetDagRunQueue.dataset_id == DatasetModel.id)
.where(*where_clause)
)
result = session.execute(query).all()
total_entries = get_query_count(query, session=session)
if not result:
raise NotFound(
"Queue event not found",
detail=f"Queue event with dag_id: `{dag_id}` was not found",
)
queued_events = [
QueuedEvent(created_at=ddrq.created_at, dag_id=ddrq.target_dag_id, uri=uri) for ddrq, uri in result
]
return queued_event_collection_schema.dump(
QueuedEventCollection(queued_events=queued_events, total_entries=total_entries)
)


@security.requires_access_dataset("DELETE")
@security.requires_access_dag("GET")
@provide_session
def delete_dag_dataset_queued_events(
*, dag_id: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Delete queued Dataset events for a DAG."""
where_clause = _generate_queued_event_where_clause(dag_id=dag_id, before=before)
delete_stmt = delete(DatasetDagRunQueue).where(*where_clause)
result = session.execute(delete_stmt)
if result.rowcount > 0:
return NoContent, HTTPStatus.NO_CONTENT

raise NotFound(
"Queue event not found",
detail=f"Queue event with dag_id: `{dag_id}` was not found",
)


@security.requires_access_dataset("GET")
@provide_session
def get_dataset_queued_events(
*, uri: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Get queued Dataset events for a Dataset."""
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"])
where_clause = _generate_queued_event_where_clause(
uri=uri, before=before, permitted_dag_ids=permitted_dag_ids
)
query = (
select(DatasetDagRunQueue, DatasetModel.uri)
.join(DatasetModel, DatasetDagRunQueue.dataset_id == DatasetModel.id)
.where(*where_clause)
)
total_entries = get_query_count(query, session=session)
result = session.execute(query).all()
if total_entries > 0:
queued_events = [
QueuedEvent(created_at=ddrq.created_at, dag_id=ddrq.target_dag_id, uri=uri)
for ddrq, uri in result
]
return queued_event_collection_schema.dump(
QueuedEventCollection(queued_events=queued_events, total_entries=total_entries)
)
raise NotFound(
"Queue event not found",
detail=f"Queue event with dataset uri: `{uri}` was not found",
)


@security.requires_access_dataset("DELETE")
@provide_session
def delete_dataset_queued_events(
*, uri: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Delete queued Dataset events for a Dataset."""
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"])
where_clause = _generate_queued_event_where_clause(
uri=uri, before=before, permitted_dag_ids=permitted_dag_ids
)
delete_stmt = (
delete(DatasetDagRunQueue).where(*where_clause).execution_options(synchronize_session="fetch")
)

result = session.execute(delete_stmt)
if result.rowcount > 0:
return NoContent, HTTPStatus.NO_CONTENT
raise NotFound(
"Queue event not found",
detail=f"Queue event with dataset uri: `{uri}` was not found",
)

0 comments on commit 16d2671

Please sign in to comment.