diff --git a/src/prefect/server/api/ui/task_runs.py b/src/prefect/server/api/ui/task_runs.py index 8f51c5c9cfb9..86f4b37763c9 100644 --- a/src/prefect/server/api/ui/task_runs.py +++ b/src/prefect/server/api/ui/task_runs.py @@ -187,3 +187,22 @@ async def read_dashboard_task_run_counts( buckets[index].failed = row.failed_count return buckets + + +@router.post("/count") +async def read_task_run_counts_by_state( + flows: Optional[schemas.filters.FlowFilter] = None, + flow_runs: Optional[schemas.filters.FlowRunFilter] = None, + task_runs: Optional[schemas.filters.TaskRunFilter] = None, + deployments: Optional[schemas.filters.DeploymentFilter] = None, + db: PrefectDBInterface = Depends(provide_database_interface), +) -> schemas.states.CountByState: + async with db.session_context(begin_transaction=False) as session: + return await models.task_runs.count_task_runs_by_state( + session=session, + db=db, + flow_filter=flows, + flow_run_filter=flow_runs, + task_run_filter=task_runs, + deployment_filter=deployments, + ) diff --git a/src/prefect/server/models/task_runs.py b/src/prefect/server/models/task_runs.py index c6956c682fbc..8e282c9b0448 100644 --- a/src/prefect/server/models/task_runs.py +++ b/src/prefect/server/models/task_runs.py @@ -4,6 +4,7 @@ """ import contextlib +from typing import Optional from uuid import UUID import pendulum @@ -334,6 +335,54 @@ async def count_task_runs( return result.scalar() +async def count_task_runs_by_state( + session: AsyncSession, + db: PrefectDBInterface, + flow_filter: Optional[schemas.filters.FlowFilter] = None, + flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None, + task_run_filter: Optional[schemas.filters.TaskRunFilter] = None, + deployment_filter: Optional[schemas.filters.DeploymentFilter] = None, +) -> schemas.states.CountByState: + """ + Count task runs by state. + + Args: + session: a database session + flow_filter: only count task runs whose flows match these filters + flow_run_filter: only count task runs whose flow runs match these filters + task_run_filter: only count task runs that match these filters + deployment_filter: only count task runs whose deployments match these filters + Returns: + schemas.states.CountByState: count of task runs by state + """ + + base_query = ( + select( + db.TaskRun.state_type, + sa.func.count(sa.text("*")).label("count"), + ) + .select_from(db.TaskRun) + .group_by(db.TaskRun.state_type) + ) + + query = await _apply_task_run_filters( + base_query, + flow_filter=flow_filter, + flow_run_filter=flow_run_filter, + task_run_filter=task_run_filter, + deployment_filter=deployment_filter, + ) + + result = await session.execute(query) + + counts = schemas.states.CountByState() + + for row in result: + setattr(counts, row.state_type, row.count) + + return counts + + @inject_db async def delete_task_run( session: sa.orm.Session, task_run_id: UUID, db: PrefectDBInterface diff --git a/src/prefect/server/schemas/states.py b/src/prefect/server/schemas/states.py index a3526537ad0b..a5c074d3b1bc 100644 --- a/src/prefect/server/schemas/states.py +++ b/src/prefect/server/schemas/states.py @@ -43,6 +43,25 @@ class StateType(AutoEnum): CANCELLING = AutoEnum.auto() +class CountByState(PrefectBaseModel): + COMPLETED: int = Field(default=0) + PENDING: int = Field(default=0) + RUNNING: int = Field(default=0) + FAILED: int = Field(default=0) + CANCELLED: int = Field(default=0) + CRASHED: int = Field(default=0) + PAUSED: int = Field(default=0) + CANCELLING: int = Field(default=0) + SCHEDULED: int = Field(default=0) + + @validator("*") + @classmethod + def check_key(cls, value, field): + if field.name not in StateType.__members__: + raise ValueError(f"{field.name} is not a valid StateType") + return value + + TERMINAL_STATES = { StateType.COMPLETED, StateType.CANCELLED, diff --git a/tests/server/api/ui/test_task_runs.py b/tests/server/api/ui/test_task_runs.py index 97dba731d0e6..7e48ac64c574 100644 --- a/tests/server/api/ui/test_task_runs.py +++ b/tests/server/api/ui/test_task_runs.py @@ -7,7 +7,7 @@ from prefect.server import models from prefect.server.api.ui.task_runs import TaskRunCount -from prefect.server.schemas import core, filters, states +from prefect.server.schemas import actions, core, filters, states from prefect.server.utilities.schemas import DateTimeTZ @@ -128,3 +128,191 @@ async def test_returns_completed_and_failed_counts( TaskRunCount(completed=2, failed=3), TaskRunCount(completed=2, failed=2), ] + + +class TestReadTaskRunCountsByState: + @pytest.fixture + def url(self) -> str: + return "/ui/task_runs/count" + + @pytest.fixture + async def create_flow_runs( + self, + session: AsyncSession, + flow, + ): + run_1 = await models.flow_runs.create_flow_run( + session=session, + flow_run=actions.FlowRunCreate( + flow_id=flow.id, + state=states.Completed(), + ), + ) + + run_2 = await models.flow_runs.create_flow_run( + session=session, + flow_run=actions.FlowRunCreate( + flow_id=flow.id, + state=states.Failed(), + ), + ) + + run_3 = await models.flow_runs.create_flow_run( + session=session, + flow_run=actions.FlowRunCreate( + flow_id=flow.id, + state=states.Pending(), + ), + ) + + await session.commit() + + return [run_1, run_2, run_3] + + @pytest.fixture + async def create_task_runs( + self, + session: AsyncSession, + flow_run, + create_flow_runs, + ): + task_runs_per_flow_run = 27 + now = cast(DateTimeTZ, pendulum.datetime(2023, 6, 1, 18, tz="UTC")) + + for flow_run in create_flow_runs: + for i in range(task_runs_per_flow_run): + # This means that each flow run should have task runs with the following states: + # 9 completed, 9 failed, 3 scheduled, 3 running, 2 cancelled, 1 crashed, 1 paused, 1 cancelling, 1 pending + if i < 9: + state_type = states.StateType.COMPLETED + state_name = "Completed" + elif i < 15: + state_type = states.StateType.FAILED + state_name = "Failed" + elif i < 18: + state_type = states.StateType.SCHEDULED + state_name = "Scheduled" + elif i < 21: + state_type = states.StateType.RUNNING + state_name = "Running" + elif i < 23: + state_type = states.StateType.CANCELLED + state_name = "Cancelled" + elif i < 24: + state_type = states.StateType.CRASHED + state_name = "Crashed" + elif i < 25: + state_type = states.StateType.PAUSED + state_name = "Paused" + elif i < 26: + state_type = states.StateType.CANCELLING + state_name = "Cancelling" + else: + state_type = states.StateType.PENDING + state_name = "Pending" + + await models.task_runs.create_task_run( + session=session, + task_run=core.TaskRun( + flow_run_id=flow_run.id, + task_key=f"task-{i}", + dynamic_key=str(i), + state_type=state_type, + state_name=state_name, + start_time=now, + end_time=now, + ), + ) + + await session.commit() + + async def test_returns_all_state_types( + self, + url: str, + client: AsyncClient, + ): + response = await client.post(url) + assert response.status_code == 200 + + counts = response.json() + + assert set(counts.keys()) == set(states.StateType.__members__.keys()) + + async def test_none( + self, + url: str, + client: AsyncClient, + ): + response = await client.post(url) + assert response.status_code == 200 + + counts = response.json() + assert counts == { + "COMPLETED": 0, + "FAILED": 0, + "PENDING": 0, + "RUNNING": 0, + "CANCELLED": 0, + "CRASHED": 0, + "PAUSED": 0, + "CANCELLING": 0, + "SCHEDULED": 0, + } + + async def test_returns_counts( + self, + url: str, + client: AsyncClient, + create_task_runs, + ): + response = await client.post(url) + assert response.status_code == 200 + + counts = response.json() + + assert counts == { + "COMPLETED": 9 * 3, + "FAILED": 6 * 3, + "PENDING": 1 * 3, + "RUNNING": 3 * 3, + "CANCELLED": 2 * 3, + "CRASHED": 1 * 3, + "PAUSED": 1 * 3, + "CANCELLING": 1 * 3, + "SCHEDULED": 3 * 3, + } + + async def test_returns_counts_with_filter( + self, + url: str, + client: AsyncClient, + create_task_runs, + ): + response = await client.post( + url, + json={ + "flow_runs": filters.FlowRunFilter( + state=filters.FlowRunFilterState( + type=filters.FlowRunFilterStateType( + any_=[states.StateType.COMPLETED] + ) + ) + ).dict(json_compatible=True) + }, + ) + + assert response.status_code == 200 + + counts = response.json() + + assert counts == { + "COMPLETED": 9, + "FAILED": 6, + "PENDING": 1, + "RUNNING": 3, + "CANCELLED": 2, + "CRASHED": 1, + "PAUSED": 1, + "CANCELLING": 1, + "SCHEDULED": 3, + }