Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: API for task run counts by state #12244

Merged
merged 8 commits into from
Mar 14, 2024
19 changes: 19 additions & 0 deletions src/prefect/server/api/ui/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,22 @@ async def read_dashboard_task_run_counts(
buckets[index].failed = row.failed_count

return buckets


@router.post("/count")
async def read_task_run_counts_by_state(
flows: Optional[schemas.filters.FlowFilter] = None,
flow_runs: Optional[schemas.filters.FlowRunFilter] = None,
task_runs: Optional[schemas.filters.TaskRunFilter] = None,
deployments: Optional[schemas.filters.DeploymentFilter] = None,
db: PrefectDBInterface = Depends(provide_database_interface),
) -> schemas.states.CountByState:
async with db.session_context(begin_transaction=False) as session:
return await models.task_runs.count_task_runs_by_state(
session=session,
db=db,
flow_filter=flows,
flow_run_filter=flow_runs,
task_run_filter=task_runs,
deployment_filter=deployments,
)
49 changes: 49 additions & 0 deletions src/prefect/server/models/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import contextlib
from typing import Optional
from uuid import UUID

import pendulum
Expand Down Expand Up @@ -334,6 +335,54 @@ async def count_task_runs(
return result.scalar()


async def count_task_runs_by_state(
session: AsyncSession,
db: PrefectDBInterface,
flow_filter: Optional[schemas.filters.FlowFilter] = None,
flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
) -> schemas.states.CountByState:
"""
Count task runs by state.

Args:
session: a database session
flow_filter: only count task runs whose flows match these filters
flow_run_filter: only count task runs whose flow runs match these filters
task_run_filter: only count task runs that match these filters
deployment_filter: only count task runs whose deployments match these filters
Returns:
schemas.states.CountByState: count of task runs by state
"""

base_query = (
select(
db.TaskRun.state_type,
sa.func.count(sa.text("*")).label("count"),
)
.select_from(db.TaskRun)
.group_by(db.TaskRun.state_type)
)

query = await _apply_task_run_filters(
base_query,
flow_filter=flow_filter,
flow_run_filter=flow_run_filter,
task_run_filter=task_run_filter,
deployment_filter=deployment_filter,
)

result = await session.execute(query)

counts = schemas.states.CountByState()

for row in result:
setattr(counts, row.state_type, row.count)

return counts


@inject_db
async def delete_task_run(
session: sa.orm.Session, task_run_id: UUID, db: PrefectDBInterface
Expand Down
19 changes: 19 additions & 0 deletions src/prefect/server/schemas/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,25 @@ class StateType(AutoEnum):
CANCELLING = AutoEnum.auto()


class CountByState(PrefectBaseModel):
COMPLETED: int = Field(default=0)
PENDING: int = Field(default=0)
RUNNING: int = Field(default=0)
FAILED: int = Field(default=0)
CANCELLED: int = Field(default=0)
CRASHED: int = Field(default=0)
PAUSED: int = Field(default=0)
CANCELLING: int = Field(default=0)
SCHEDULED: int = Field(default=0)

@validator("*")
@classmethod
def check_key(cls, value, field):
if field.name not in StateType.__members__:
raise ValueError(f"{field.name} is not a valid StateType")
return value


TERMINAL_STATES = {
StateType.COMPLETED,
StateType.CANCELLED,
Expand Down
190 changes: 189 additions & 1 deletion tests/server/api/ui/test_task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from prefect.server import models
from prefect.server.api.ui.task_runs import TaskRunCount
from prefect.server.schemas import core, filters, states
from prefect.server.schemas import actions, core, filters, states
from prefect.server.utilities.schemas import DateTimeTZ


Expand Down Expand Up @@ -128,3 +128,191 @@ async def test_returns_completed_and_failed_counts(
TaskRunCount(completed=2, failed=3),
TaskRunCount(completed=2, failed=2),
]


class TestReadTaskRunCountsByState:
@pytest.fixture
def url(self) -> str:
return "/ui/task_runs/count"

@pytest.fixture
async def create_flow_runs(
self,
session: AsyncSession,
flow,
):
run_1 = await models.flow_runs.create_flow_run(
session=session,
flow_run=actions.FlowRunCreate(
flow_id=flow.id,
state=states.Completed(),
),
)

run_2 = await models.flow_runs.create_flow_run(
session=session,
flow_run=actions.FlowRunCreate(
flow_id=flow.id,
state=states.Failed(),
),
)

run_3 = await models.flow_runs.create_flow_run(
session=session,
flow_run=actions.FlowRunCreate(
flow_id=flow.id,
state=states.Pending(),
),
)

await session.commit()

return [run_1, run_2, run_3]

@pytest.fixture
async def create_task_runs(
self,
session: AsyncSession,
flow_run,
create_flow_runs,
):
task_runs_per_flow_run = 27
now = cast(DateTimeTZ, pendulum.datetime(2023, 6, 1, 18, tz="UTC"))
serinamarie marked this conversation as resolved.
Show resolved Hide resolved

for flow_run in create_flow_runs:
for i in range(task_runs_per_flow_run):
# This means that each flow run should have task runs with the following states:
# 9 completed, 9 failed, 3 scheduled, 3 running, 2 cancelled, 1 crashed, 1 paused, 1 cancelling, 1 pending
if i < 9:
state_type = states.StateType.COMPLETED
state_name = "Completed"
elif i < 15:
state_type = states.StateType.FAILED
state_name = "Failed"
elif i < 18:
state_type = states.StateType.SCHEDULED
state_name = "Scheduled"
elif i < 21:
state_type = states.StateType.RUNNING
state_name = "Running"
elif i < 23:
state_type = states.StateType.CANCELLED
state_name = "Cancelled"
elif i < 24:
state_type = states.StateType.CRASHED
state_name = "Crashed"
elif i < 25:
state_type = states.StateType.PAUSED
state_name = "Paused"
elif i < 26:
state_type = states.StateType.CANCELLING
state_name = "Cancelling"
else:
state_type = states.StateType.PENDING
state_name = "Pending"

await models.task_runs.create_task_run(
session=session,
task_run=core.TaskRun(
flow_run_id=flow_run.id,
task_key=f"task-{i}",
dynamic_key=str(i),
state_type=state_type,
state_name=state_name,
start_time=now,
end_time=now,
),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering why the tests are failing, is that the case locally?

Do we need to commit these, e.g.
await session.commit()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oo maybe that's the missing piece


await session.commit()

async def test_returns_all_state_types(
self,
url: str,
client: AsyncClient,
):
response = await client.post(url)
assert response.status_code == 200

counts = response.json()

assert set(counts.keys()) == set(states.StateType.__members__.keys())

async def test_none(
self,
url: str,
client: AsyncClient,
):
response = await client.post(url)
assert response.status_code == 200

counts = response.json()
assert counts == {
"COMPLETED": 0,
"FAILED": 0,
"PENDING": 0,
"RUNNING": 0,
"CANCELLED": 0,
"CRASHED": 0,
"PAUSED": 0,
"CANCELLING": 0,
"SCHEDULED": 0,
}

async def test_returns_counts(
self,
url: str,
client: AsyncClient,
create_task_runs,
):
response = await client.post(url)
assert response.status_code == 200

counts = response.json()

assert counts == {
"COMPLETED": 9 * 3,
"FAILED": 6 * 3,
"PENDING": 1 * 3,
"RUNNING": 3 * 3,
"CANCELLED": 2 * 3,
"CRASHED": 1 * 3,
"PAUSED": 1 * 3,
"CANCELLING": 1 * 3,
"SCHEDULED": 3 * 3,
}

async def test_returns_counts_with_filter(
self,
url: str,
client: AsyncClient,
create_task_runs,
):
response = await client.post(
url,
json={
"flow_runs": filters.FlowRunFilter(
state=filters.FlowRunFilterState(
type=filters.FlowRunFilterStateType(
any_=[states.StateType.COMPLETED]
)
)
).dict(json_compatible=True)
},
)

assert response.status_code == 200

counts = response.json()

assert counts == {
"COMPLETED": 9,
"FAILED": 6,
"PENDING": 1,
"RUNNING": 3,
"CANCELLED": 2,
"CRASHED": 1,
"PAUSED": 1,
"CANCELLING": 1,
"SCHEDULED": 3,
}
Loading