Skip to content

Commit

Permalink
Feature: API for task run counts by state (#12244)
Browse files Browse the repository at this point in the history
  • Loading branch information
znicholasbrown committed Mar 14, 2024
1 parent 7cf1528 commit 7728e51
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 1 deletion.
19 changes: 19 additions & 0 deletions src/prefect/server/api/ui/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,22 @@ async def read_dashboard_task_run_counts(
buckets[index].failed = row.failed_count

return buckets


@router.post("/count")
async def read_task_run_counts_by_state(
    flows: Optional[schemas.filters.FlowFilter] = None,
    flow_runs: Optional[schemas.filters.FlowRunFilter] = None,
    task_runs: Optional[schemas.filters.TaskRunFilter] = None,
    deployments: Optional[schemas.filters.DeploymentFilter] = None,
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> schemas.states.CountByState:
    """
    Return the number of task runs per state type.

    Every filter is optional and is forwarded unchanged to
    `models.task_runs.count_task_runs_by_state`.

    Args:
        flows: filter forwarded as `flow_filter`
        flow_runs: filter forwarded as `flow_run_filter`
        task_runs: filter forwarded as `task_run_filter`
        deployments: filter forwarded as `deployment_filter`
        db: the database interface, injected by FastAPI

    Returns:
        schemas.states.CountByState: count of task runs by state
    """
    # Map the request-body names onto the keyword names the model layer uses.
    filter_kwargs = {
        "flow_filter": flows,
        "flow_run_filter": flow_runs,
        "task_run_filter": task_runs,
        "deployment_filter": deployments,
    }

    # The session is opened with begin_transaction=False, as in the original.
    async with db.session_context(begin_transaction=False) as session:
        counts = await models.task_runs.count_task_runs_by_state(
            session=session,
            db=db,
            **filter_kwargs,
        )

    return counts
49 changes: 49 additions & 0 deletions src/prefect/server/models/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import contextlib
from typing import Optional
from uuid import UUID

import pendulum
Expand Down Expand Up @@ -334,6 +335,54 @@ async def count_task_runs(
return result.scalar()


async def count_task_runs_by_state(
    session: AsyncSession,
    db: PrefectDBInterface,
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
) -> schemas.states.CountByState:
    """
    Count task runs by state.

    Issues a single grouped query (`GROUP BY state_type`) and copies each
    group's count onto the matching field of a `CountByState`. State types
    with no matching task runs keep the model's default of 0.

    Args:
        session: a database session
        db: the database interface providing the `TaskRun` ORM model
        flow_filter: only count task runs whose flows match these filters
        flow_run_filter: only count task runs whose flow runs match these filters
        task_run_filter: only count task runs that match these filters
        deployment_filter: only count task runs whose deployments match these filters

    Returns:
        schemas.states.CountByState: count of task runs by state
    """

    base_query = (
        select(
            db.TaskRun.state_type,
            sa.func.count(sa.text("*")).label("count"),
        )
        .select_from(db.TaskRun)
        .group_by(db.TaskRun.state_type)
    )

    query = await _apply_task_run_filters(
        base_query,
        flow_filter=flow_filter,
        flow_run_filter=flow_run_filter,
        task_run_filter=task_run_filter,
        deployment_filter=deployment_filter,
    )

    result = await session.execute(query)

    counts = schemas.states.CountByState()

    # Unpack each row positionally (state_type, count) instead of reading
    # `row.count`, whose name coincides with the sequence method `count`.
    for state_type, count in result:
        # A NULL group has no corresponding field on CountByState, and
        # setattr() raises TypeError for a non-string attribute name.
        # NOTE(review): assumes task runs may lack a state_type - confirm
        # against the TaskRun column definition.
        if state_type is None:
            continue
        setattr(counts, state_type, count)

    return counts


@inject_db
async def delete_task_run(
session: sa.orm.Session, task_run_id: UUID, db: PrefectDBInterface
Expand Down
19 changes: 19 additions & 0 deletions src/prefect/server/schemas/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,25 @@ class StateType(AutoEnum):
CANCELLING = AutoEnum.auto()


class CountByState(PrefectBaseModel):
    """
    Per-state-type counts, one integer field per state type.

    Every field defaults to 0, so an instance built with no arguments
    reports zero for each state type.
    """

    COMPLETED: int = Field(default=0)
    PENDING: int = Field(default=0)
    RUNNING: int = Field(default=0)
    FAILED: int = Field(default=0)
    CANCELLED: int = Field(default=0)
    CRASHED: int = Field(default=0)
    PAUSED: int = Field(default=0)
    CANCELLING: int = Field(default=0)
    SCHEDULED: int = Field(default=0)

    @validator("*")
    @classmethod
    def check_key(cls, value, field):
        """Reject any validated field whose name is not a `StateType` member."""
        if field.name in StateType.__members__:
            return value
        raise ValueError(f"{field.name} is not a valid StateType")


TERMINAL_STATES = {
StateType.COMPLETED,
StateType.CANCELLED,
Expand Down
190 changes: 189 additions & 1 deletion tests/server/api/ui/test_task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from prefect.server import models
from prefect.server.api.ui.task_runs import TaskRunCount
from prefect.server.schemas import core, filters, states
from prefect.server.schemas import actions, core, filters, states
from prefect.server.utilities.schemas import DateTimeTZ


Expand Down Expand Up @@ -128,3 +128,191 @@ async def test_returns_completed_and_failed_counts(
TaskRunCount(completed=2, failed=3),
TaskRunCount(completed=2, failed=2),
]


class TestReadTaskRunCountsByState:
    """Tests for the `POST /ui/task_runs/count` endpoint."""

    @pytest.fixture
    def url(self) -> str:
        return "/ui/task_runs/count"

    @pytest.fixture
    async def create_flow_runs(
        self,
        session: AsyncSession,
        flow,
    ):
        """Create three flow runs on `flow`: Completed, Failed and Pending."""
        flow_runs = []

        # Build each state right before its run is created, in the same
        # order as before: Completed, Failed, Pending.
        for make_state in (states.Completed, states.Failed, states.Pending):
            run = await models.flow_runs.create_flow_run(
                session=session,
                flow_run=actions.FlowRunCreate(
                    flow_id=flow.id,
                    state=make_state(),
                ),
            )
            flow_runs.append(run)

        await session.commit()

        return flow_runs

    @pytest.fixture
    async def create_task_runs(
        self,
        session: AsyncSession,
        flow_run,
        create_flow_runs,
    ):
        """
        Create 27 task runs for each flow run from `create_flow_runs`.

        The `flow_run` fixture is still requested so the fixture setup is
        unchanged, but its value is not used here; task runs are attached
        only to the runs returned by `create_flow_runs`.
        """
        now = cast(DateTimeTZ, pendulum.datetime(2023, 6, 1, 18, tz="UTC"))

        # (state type, state name, number of task runs per flow run).
        # The counts add up to 27 and match the expectations asserted in
        # the tests below (note: 6 failed, not 9).
        state_counts = [
            (states.StateType.COMPLETED, "Completed", 9),
            (states.StateType.FAILED, "Failed", 6),
            (states.StateType.SCHEDULED, "Scheduled", 3),
            (states.StateType.RUNNING, "Running", 3),
            (states.StateType.CANCELLED, "Cancelled", 2),
            (states.StateType.CRASHED, "Crashed", 1),
            (states.StateType.PAUSED, "Paused", 1),
            (states.StateType.CANCELLING, "Cancelling", 1),
            (states.StateType.PENDING, "Pending", 1),
        ]

        # Expand the table into one (type, name) entry per task run so the
        # index `i` lines up with the same state as the original chain.
        task_states = [
            (state_type, state_name)
            for state_type, state_name, how_many in state_counts
            for _ in range(how_many)
        ]

        # A distinct loop name avoids shadowing the `flow_run` fixture.
        for parent_run in create_flow_runs:
            for i, (state_type, state_name) in enumerate(task_states):
                await models.task_runs.create_task_run(
                    session=session,
                    task_run=core.TaskRun(
                        flow_run_id=parent_run.id,
                        task_key=f"task-{i}",
                        dynamic_key=str(i),
                        state_type=state_type,
                        state_name=state_name,
                        start_time=now,
                        end_time=now,
                    ),
                )

        await session.commit()

    async def test_returns_all_state_types(
        self,
        url: str,
        client: AsyncClient,
    ):
        response = await client.post(url)
        assert response.status_code == 200

        counts = response.json()

        # The response must expose exactly one key per StateType member.
        assert set(counts.keys()) == set(states.StateType.__members__.keys())

    async def test_none(
        self,
        url: str,
        client: AsyncClient,
    ):
        response = await client.post(url)
        assert response.status_code == 200

        counts = response.json()
        assert counts == {
            "COMPLETED": 0,
            "FAILED": 0,
            "PENDING": 0,
            "RUNNING": 0,
            "CANCELLED": 0,
            "CRASHED": 0,
            "PAUSED": 0,
            "CANCELLING": 0,
            "SCHEDULED": 0,
        }

    async def test_returns_counts(
        self,
        url: str,
        client: AsyncClient,
        create_task_runs,
    ):
        response = await client.post(url)
        assert response.status_code == 200

        counts = response.json()

        # Per-flow-run counts multiplied by the three flow runs created.
        assert counts == {
            "COMPLETED": 9 * 3,
            "FAILED": 6 * 3,
            "PENDING": 1 * 3,
            "RUNNING": 3 * 3,
            "CANCELLED": 2 * 3,
            "CRASHED": 1 * 3,
            "PAUSED": 1 * 3,
            "CANCELLING": 1 * 3,
            "SCHEDULED": 3 * 3,
        }

    async def test_returns_counts_with_filter(
        self,
        url: str,
        client: AsyncClient,
        create_task_runs,
    ):
        response = await client.post(
            url,
            json={
                "flow_runs": filters.FlowRunFilter(
                    state=filters.FlowRunFilterState(
                        type=filters.FlowRunFilterStateType(
                            any_=[states.StateType.COMPLETED]
                        )
                    )
                ).dict(json_compatible=True)
            },
        )

        assert response.status_code == 200

        counts = response.json()

        # Only one created flow run is Completed, so the totals equal the
        # per-flow-run counts.
        assert counts == {
            "COMPLETED": 9,
            "FAILED": 6,
            "PENDING": 1,
            "RUNNING": 3,
            "CANCELLED": 2,
            "CRASHED": 1,
            "PAUSED": 1,
            "CANCELLING": 1,
            "SCHEDULED": 3,
        }

0 comments on commit 7728e51

Please sign in to comment.