Skip to content

Commit

Permalink
Add new flows pagination endpoint (#13846)
Browse files Browse the repository at this point in the history
  • Loading branch information
pleek91 authored Jun 7, 2024
1 parent 5879df3 commit 239701f
Show file tree
Hide file tree
Showing 4 changed files with 313 additions and 13 deletions.
16 changes: 5 additions & 11 deletions src/prefect/server/api/flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
status,
)
from fastapi.responses import ORJSONResponse, PlainTextResponse
from pydantic import BaseModel
from pydantic_extra_types.pendulum_dt import DateTime
from sqlalchemy.exc import IntegrityError

Expand All @@ -39,7 +38,10 @@
from prefect.server.orchestration import dependencies as orchestration_dependencies
from prefect.server.orchestration.policies import BaseOrchestrationPolicy
from prefect.server.schemas.graph import Graph
from prefect.server.schemas.responses import OrchestrationResult
from prefect.server.schemas.responses import (
FlowRunPaginationResponse,
OrchestrationResult,
)
from prefect.server.utilities.server import PrefectRouter
from prefect.utilities import schema_tools

Expand Down Expand Up @@ -703,14 +705,6 @@ async def delete_flow_run_input(
)


class FlowRunPaginationResponse(BaseModel):
results: list[schemas.responses.FlowRunResponse]
count: int
limit: int
pages: int
page: int


@router.post("/paginate", response_class=ORJSONResponse)
async def read_flow_runs(
sort: schemas.sorting.FlowRunSort = Body(schemas.sorting.FlowRunSort.ID_DESC),
Expand All @@ -723,7 +717,7 @@ async def read_flow_runs(
work_pools: Optional[schemas.filters.WorkPoolFilter] = None,
work_pool_queues: Optional[schemas.filters.WorkQueueFilter] = None,
db: PrefectDBInterface = Depends(provide_database_interface),
) -> List[schemas.responses.FlowRunResponse]:
) -> FlowRunPaginationResponse:
"""
Pagination query for flow runs.
"""
Expand Down
51 changes: 50 additions & 1 deletion src/prefect/server/api/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Routes for interacting with flow objects.
"""

from typing import List
from typing import List, Optional
from uuid import UUID

import pendulum
Expand All @@ -14,6 +14,7 @@
import prefect.server.schemas as schemas
from prefect.server.database.dependencies import provide_database_interface
from prefect.server.database.interface import PrefectDBInterface
from prefect.server.schemas.responses import FlowPaginationResponse
from prefect.server.utilities.server import PrefectRouter

router = PrefectRouter(prefix="/flows", tags=["Flows"])
Expand Down Expand Up @@ -160,3 +161,51 @@ async def delete_flow(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Flow not found"
)


@router.post("/paginate")
async def paginate_flows(
limit: int = dependencies.LimitBody(),
page: int = Body(1, ge=1),
flows: Optional[schemas.filters.FlowFilter] = None,
flow_runs: Optional[schemas.filters.FlowRunFilter] = None,
task_runs: Optional[schemas.filters.TaskRunFilter] = None,
deployments: Optional[schemas.filters.DeploymentFilter] = None,
work_pools: Optional[schemas.filters.WorkPoolFilter] = None,
sort: schemas.sorting.FlowSort = Body(schemas.sorting.FlowSort.NAME_ASC),
db: PrefectDBInterface = Depends(provide_database_interface),
) -> FlowPaginationResponse:
"""
Pagination query for flows.
"""
offset = (page - 1) * limit

async with db.session_context() as session:
results = await models.flows.read_flows(
session=session,
flow_filter=flows,
flow_run_filter=flow_runs,
task_run_filter=task_runs,
deployment_filter=deployments,
work_pool_filter=work_pools,
sort=sort,
offset=offset,
limit=limit,
)

count = await models.flows.count_flows(
session=session,
flow_filter=flows,
flow_run_filter=flow_runs,
task_run_filter=task_runs,
deployment_filter=deployments,
work_pool_filter=work_pools,
)

return FlowPaginationResponse(
results=results,
count=count,
limit=limit,
pages=(count + limit - 1) // limit,
page=page,
)
18 changes: 17 additions & 1 deletion src/prefect/server/schemas/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from uuid import UUID

import pendulum
from pydantic import ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic_extra_types.pendulum_dt import DateTime
from typing_extensions import Literal, Self

Expand Down Expand Up @@ -571,3 +571,19 @@ class GlobalConcurrencyLimitResponse(ORMBaseModel):
default=2.0,
description="The decay rate for active slots when used as a rate limit.",
)


class FlowPaginationResponse(BaseModel):
results: list[schemas.core.Flow]
count: int
limit: int
pages: int
page: int


class FlowRunPaginationResponse(BaseModel):
results: list[FlowRunResponse]
count: int
limit: int
pages: int
page: int
241 changes: 241 additions & 0 deletions tests/server/orchestration/api/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,244 @@ async def test_delete_flow(self, client):
async def test_delete_flow_returns_404_if_does_not_exist(self, client):
response = await client.delete(f"/flows/{uuid4()}")
assert response.status_code == status.HTTP_404_NOT_FOUND


class TestPaginateFlows:
@pytest.fixture
async def flows(self, client):
await client.post("/flows/", json={"name": "my-flow-1"})
await client.post("/flows/", json={"name": "my-flow-2"})

@pytest.mark.usefixtures("flows")
async def test_paginate_flows(self, client):
response = await client.post("/flows/paginate")
assert response.status_code == status.HTTP_200_OK

json = response.json()

assert len(json["results"]) == 2
assert json["page"] == 1
assert json["pages"] == 1
assert json["count"] == 2

@pytest.mark.usefixtures("flows")
async def test_paginate_flows_applies_limit(self, client):
response = await client.post("/flows/paginate", json=dict(limit=1))
assert response.status_code == status.HTTP_200_OK

json = response.json()

assert len(json["results"]) == 1
assert json["page"] == 1
assert json["pages"] == 2
assert json["count"] == 2

async def test_paginate_flows_applies_flow_filter(self, client, session):
flow_1 = await models.flows.create_flow(
session=session,
flow=schemas.core.Flow(name="my-flow-1", tags=["db", "blue"]),
)
await models.flows.create_flow(
session=session, flow=schemas.core.Flow(name="my-flow-2", tags=["db"])
)
await session.commit()

flow_filter = dict(
flows=schemas.filters.FlowFilter(
name=schemas.filters.FlowFilterName(any_=["my-flow-1"])
).model_dump(mode="json")
)
response = await client.post("/flows/paginate", json=flow_filter)
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert UUID(results[0]["id"]) == flow_1.id

async def test_paginate_flows_applies_flow_run_filter(self, client, session):
flow_1 = await models.flows.create_flow(
session=session,
flow=schemas.core.Flow(name="my-flow-1", tags=["db", "blue"]),
)
await models.flows.create_flow(
session=session, flow=schemas.core.Flow(name="my-flow-2", tags=["db"])
)
flow_run_1 = await models.flow_runs.create_flow_run(
session=session,
flow_run=schemas.actions.FlowRunCreate(flow_id=flow_1.id),
)
await session.commit()

flow_filter = dict(
flow_runs=schemas.filters.FlowRunFilter(
id=schemas.filters.FlowRunFilterId(any_=[flow_run_1.id])
).model_dump(mode="json")
)

response = await client.post("/flows/paginate", json=flow_filter)
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert UUID(results[0]["id"]) == flow_1.id

async def test_paginate_flows_applies_task_run_filter(self, client, session):
flow_1 = await models.flows.create_flow(
session=session,
flow=schemas.core.Flow(name="my-flow-1", tags=["db", "blue"]),
)
await models.flows.create_flow(
session=session, flow=schemas.core.Flow(name="my-flow-2", tags=["db"])
)
flow_run_1 = await models.flow_runs.create_flow_run(
session=session,
flow_run=schemas.actions.FlowRunCreate(flow_id=flow_1.id),
)
task_run_1 = await models.task_runs.create_task_run(
session=session,
task_run=schemas.actions.TaskRunCreate(
flow_run_id=flow_run_1.id,
task_key="my-key",
dynamic_key="0",
),
)
await session.commit()

flow_filter = dict(
task_runs=schemas.filters.TaskRunFilter(
id=schemas.filters.TaskRunFilterId(any_=[task_run_1.id])
).model_dump(mode="json")
)
response = await client.post("/flows/paginate", json=flow_filter)
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert UUID(results[0]["id"]) == flow_1.id

async def test_paginate_flows_applies_work_pool(self, client, session, work_pool):
flow_1 = await models.flows.create_flow(
session=session,
flow=schemas.core.Flow(name="my-flow-1", tags=["db", "blue"]),
)
await models.flows.create_flow(
session=session, flow=schemas.core.Flow(name="my-flow-2", tags=["db"])
)
await session.commit()

response = await client.post("/flows/paginate")
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["results"]) == 2

# work queue
work_queue = await models.workers.create_work_queue(
session=session,
work_pool_id=work_pool.id,
work_queue=schemas.actions.WorkQueueCreate(name="test-queue"), # type: ignore
)
# deployment
await models.deployments.create_deployment(
session=session,
deployment=schemas.core.Deployment(
name="My Deployment X",
manifest_path="file.json",
flow_id=flow_1.id,
is_schedule_active=True,
work_queue_id=work_queue.id,
),
)
await session.commit()

work_pool_filter = dict(
work_pools=schemas.filters.WorkPoolFilter(
id=schemas.filters.WorkPoolFilterId(any_=[work_pool.id])
).model_dump(mode="json")
)
response = await client.post("/flows/paginate", json=work_pool_filter)
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert UUID(results[0]["id"]) == flow_1.id

async def test_paginate_flows_applies_deployment_is_null(self, client, session):
undeployed_flow = await models.flows.create_flow(
session=session,
flow=schemas.core.Flow(name="undeployment_flow"),
)
deployed_flow = await models.flows.create_flow(
session=session, flow=schemas.core.Flow(name="deployed_flow")
)
await session.commit()

await models.deployments.create_deployment(
session=session,
deployment=schemas.core.Deployment(
name="Mr. Deployment",
manifest_path="file.json",
flow_id=deployed_flow.id,
),
)
await session.commit()

deployment_filter_isnull = dict(
flows=schemas.filters.FlowFilter(
deployment=schemas.filters.FlowFilterDeployment(is_null_=True)
).model_dump(mode="json")
)
response = await client.post("/flows/paginate", json=deployment_filter_isnull)
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert UUID(results[0]["id"]) == undeployed_flow.id

deployment_filter_not_isnull = dict(
flows=schemas.filters.FlowFilter(
deployment=schemas.filters.FlowFilterDeployment(is_null_=False)
).model_dump(mode="json")
)
response = await client.post(
"/flows/paginate", json=deployment_filter_not_isnull
)
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert UUID(results[0]["id"]) == deployed_flow.id

async def test_paginate_flows_page(self, flows, client):
# right now this works because flows are ordered by name
# by default, when ordering is actually implemented, this test
# should be re-written
response = await client.post("/flows/paginate", json=dict(page=2, limit=1))
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert results[0]["name"] == "my-flow-2"

async def test_paginate_flows_sort(self, flows, client):
response = await client.post(
"/flows/paginate", json=dict(sort=schemas.sorting.FlowSort.NAME_ASC)
)
assert response.status_code == status.HTTP_200_OK
assert response.json()["results"][0]["name"] == "my-flow-1"

response_desc = await client.post(
"/flows/paginate", json=dict(sort=schemas.sorting.FlowSort.NAME_DESC)
)
assert response_desc.status_code == status.HTTP_200_OK
assert response_desc.json()["results"][0]["name"] == "my-flow-2"

async def test_read_flows_returns_empty_list(self, client):
response = await client.post("/flows/paginate")
assert response.status_code == status.HTTP_200_OK
assert response.json()["results"] == []

0 comments on commit 239701f

Please sign in to comment.