Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new flows pagination endpoint #13846

Merged
merged 5 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions src/prefect/server/api/flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
status,
)
from fastapi.responses import ORJSONResponse, PlainTextResponse
from pydantic import BaseModel
from pydantic_extra_types.pendulum_dt import DateTime
from sqlalchemy.exc import IntegrityError

Expand All @@ -39,7 +38,10 @@
from prefect.server.orchestration import dependencies as orchestration_dependencies
from prefect.server.orchestration.policies import BaseOrchestrationPolicy
from prefect.server.schemas.graph import Graph
from prefect.server.schemas.responses import OrchestrationResult
from prefect.server.schemas.responses import (
FlowRunPaginationResponse,
OrchestrationResult,
)
from prefect.server.utilities.server import PrefectRouter
from prefect.utilities import schema_tools

Expand Down Expand Up @@ -703,14 +705,6 @@ async def delete_flow_run_input(
)


class FlowRunPaginationResponse(BaseModel):
results: list[schemas.responses.FlowRunResponse]
count: int
limit: int
pages: int
page: int


@router.post("/paginate", response_class=ORJSONResponse)
async def read_flow_runs(
sort: schemas.sorting.FlowRunSort = Body(schemas.sorting.FlowRunSort.ID_DESC),
Expand All @@ -723,7 +717,7 @@ async def read_flow_runs(
work_pools: Optional[schemas.filters.WorkPoolFilter] = None,
work_pool_queues: Optional[schemas.filters.WorkQueueFilter] = None,
db: PrefectDBInterface = Depends(provide_database_interface),
) -> List[schemas.responses.FlowRunResponse]:
) -> FlowRunPaginationResponse:
"""
Pagination query for flow runs.
"""
Expand Down
51 changes: 50 additions & 1 deletion src/prefect/server/api/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Routes for interacting with flow objects.
"""

from typing import List
from typing import List, Optional
from uuid import UUID

import pendulum
Expand All @@ -14,6 +14,7 @@
import prefect.server.schemas as schemas
from prefect.server.database.dependencies import provide_database_interface
from prefect.server.database.interface import PrefectDBInterface
from prefect.server.schemas.responses import FlowPaginationResponse
from prefect.server.utilities.server import PrefectRouter

router = PrefectRouter(prefix="/flows", tags=["Flows"])
Expand Down Expand Up @@ -160,3 +161,51 @@ async def delete_flow(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Flow not found"
)


@router.post("/paginate")
async def paginate_flows(
limit: int = dependencies.LimitBody(),
page: int = Body(1, ge=1),
flows: Optional[schemas.filters.FlowFilter] = None,
flow_runs: Optional[schemas.filters.FlowRunFilter] = None,
task_runs: Optional[schemas.filters.TaskRunFilter] = None,
deployments: Optional[schemas.filters.DeploymentFilter] = None,
work_pools: Optional[schemas.filters.WorkPoolFilter] = None,
sort: schemas.sorting.FlowSort = Body(schemas.sorting.FlowSort.NAME_ASC),
db: PrefectDBInterface = Depends(provide_database_interface),
) -> FlowPaginationResponse:
"""
Pagination query for flows.
"""
offset = (page - 1) * limit

async with db.session_context() as session:
results = await models.flows.read_flows(
session=session,
flow_filter=flows,
flow_run_filter=flow_runs,
task_run_filter=task_runs,
deployment_filter=deployments,
work_pool_filter=work_pools,
sort=sort,
offset=offset,
limit=limit,
)

count = await models.flows.count_flows(
session=session,
flow_filter=flows,
flow_run_filter=flow_runs,
task_run_filter=task_runs,
deployment_filter=deployments,
work_pool_filter=work_pools,
)

return FlowPaginationResponse(
results=results,
count=count,
limit=limit,
pages=(count + limit - 1) // limit,
page=page,
)
18 changes: 17 additions & 1 deletion src/prefect/server/schemas/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from uuid import UUID

import pendulum
from pydantic import ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic_extra_types.pendulum_dt import DateTime
from typing_extensions import Literal, Self

Expand Down Expand Up @@ -571,3 +571,19 @@ class GlobalConcurrencyLimitResponse(ORMBaseModel):
default=2.0,
description="The decay rate for active slots when used as a rate limit.",
)


class FlowPaginationResponse(BaseModel):
results: list[schemas.core.Flow]
count: int
limit: int
pages: int
page: int


class FlowRunPaginationResponse(BaseModel):
results: list[FlowRunResponse]
count: int
limit: int
pages: int
page: int
241 changes: 241 additions & 0 deletions tests/server/orchestration/api/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,244 @@ async def test_delete_flow(self, client):
async def test_delete_flow_returns_404_if_does_not_exist(self, client):
response = await client.delete(f"/flows/{uuid4()}")
assert response.status_code == status.HTTP_404_NOT_FOUND


class TestPaginateFlows:
@pytest.fixture
async def flows(self, client):
await client.post("/flows/", json={"name": "my-flow-1"})
await client.post("/flows/", json={"name": "my-flow-2"})

@pytest.mark.usefixtures("flows")
async def test_paginate_flows(self, client):
response = await client.post("/flows/paginate")
assert response.status_code == status.HTTP_200_OK

json = response.json()

assert len(json["results"]) == 2
assert json["page"] == 1
assert json["pages"] == 1
assert json["count"] == 2

@pytest.mark.usefixtures("flows")
async def test_paginate_flows_applies_limit(self, client):
response = await client.post("/flows/paginate", json=dict(limit=1))
assert response.status_code == status.HTTP_200_OK

json = response.json()

assert len(json["results"]) == 1
assert json["page"] == 1
assert json["pages"] == 2
assert json["count"] == 2

async def test_paginate_flows_applies_flow_filter(self, client, session):
flow_1 = await models.flows.create_flow(
session=session,
flow=schemas.core.Flow(name="my-flow-1", tags=["db", "blue"]),
)
await models.flows.create_flow(
session=session, flow=schemas.core.Flow(name="my-flow-2", tags=["db"])
)
await session.commit()

flow_filter = dict(
flows=schemas.filters.FlowFilter(
name=schemas.filters.FlowFilterName(any_=["my-flow-1"])
).model_dump(mode="json")
)
response = await client.post("/flows/paginate", json=flow_filter)
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert UUID(results[0]["id"]) == flow_1.id

async def test_paginate_flows_applies_flow_run_filter(self, client, session):
flow_1 = await models.flows.create_flow(
session=session,
flow=schemas.core.Flow(name="my-flow-1", tags=["db", "blue"]),
)
await models.flows.create_flow(
session=session, flow=schemas.core.Flow(name="my-flow-2", tags=["db"])
)
flow_run_1 = await models.flow_runs.create_flow_run(
session=session,
flow_run=schemas.actions.FlowRunCreate(flow_id=flow_1.id),
)
await session.commit()

flow_filter = dict(
flow_runs=schemas.filters.FlowRunFilter(
id=schemas.filters.FlowRunFilterId(any_=[flow_run_1.id])
).model_dump(mode="json")
)

response = await client.post("/flows/paginate", json=flow_filter)
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert UUID(results[0]["id"]) == flow_1.id

async def test_paginate_flows_applies_task_run_filter(self, client, session):
flow_1 = await models.flows.create_flow(
session=session,
flow=schemas.core.Flow(name="my-flow-1", tags=["db", "blue"]),
)
await models.flows.create_flow(
session=session, flow=schemas.core.Flow(name="my-flow-2", tags=["db"])
)
flow_run_1 = await models.flow_runs.create_flow_run(
session=session,
flow_run=schemas.actions.FlowRunCreate(flow_id=flow_1.id),
)
task_run_1 = await models.task_runs.create_task_run(
session=session,
task_run=schemas.actions.TaskRunCreate(
flow_run_id=flow_run_1.id,
task_key="my-key",
dynamic_key="0",
),
)
await session.commit()

flow_filter = dict(
task_runs=schemas.filters.TaskRunFilter(
id=schemas.filters.TaskRunFilterId(any_=[task_run_1.id])
).model_dump(mode="json")
)
response = await client.post("/flows/paginate", json=flow_filter)
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert UUID(results[0]["id"]) == flow_1.id

async def test_paginate_flows_applies_work_pool(self, client, session, work_pool):
flow_1 = await models.flows.create_flow(
session=session,
flow=schemas.core.Flow(name="my-flow-1", tags=["db", "blue"]),
)
await models.flows.create_flow(
session=session, flow=schemas.core.Flow(name="my-flow-2", tags=["db"])
)
await session.commit()

response = await client.post("/flows/paginate")
assert response.status_code == status.HTTP_200_OK
assert len(response.json()["results"]) == 2

# work queue
work_queue = await models.workers.create_work_queue(
session=session,
work_pool_id=work_pool.id,
work_queue=schemas.actions.WorkQueueCreate(name="test-queue"), # type: ignore
)
# deployment
await models.deployments.create_deployment(
session=session,
deployment=schemas.core.Deployment(
name="My Deployment X",
manifest_path="file.json",
flow_id=flow_1.id,
is_schedule_active=True,
work_queue_id=work_queue.id,
),
)
await session.commit()

work_pool_filter = dict(
work_pools=schemas.filters.WorkPoolFilter(
id=schemas.filters.WorkPoolFilterId(any_=[work_pool.id])
).model_dump(mode="json")
)
response = await client.post("/flows/paginate", json=work_pool_filter)
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert UUID(results[0]["id"]) == flow_1.id

async def test_paginate_flows_applies_deployment_is_null(self, client, session):
undeployed_flow = await models.flows.create_flow(
session=session,
flow=schemas.core.Flow(name="undeployment_flow"),
)
deployed_flow = await models.flows.create_flow(
session=session, flow=schemas.core.Flow(name="deployed_flow")
)
await session.commit()

await models.deployments.create_deployment(
session=session,
deployment=schemas.core.Deployment(
name="Mr. Deployment",
manifest_path="file.json",
flow_id=deployed_flow.id,
),
)
await session.commit()

deployment_filter_isnull = dict(
flows=schemas.filters.FlowFilter(
deployment=schemas.filters.FlowFilterDeployment(is_null_=True)
).model_dump(mode="json")
)
response = await client.post("/flows/paginate", json=deployment_filter_isnull)
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert UUID(results[0]["id"]) == undeployed_flow.id

deployment_filter_not_isnull = dict(
flows=schemas.filters.FlowFilter(
deployment=schemas.filters.FlowFilterDeployment(is_null_=False)
).model_dump(mode="json")
)
response = await client.post(
"/flows/paginate", json=deployment_filter_not_isnull
)
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert UUID(results[0]["id"]) == deployed_flow.id

async def test_paginate_flows_page(self, flows, client):
# right now this works because flows are ordered by name
# by default, when ordering is actually implemented, this test
# should be re-written
response = await client.post("/flows/paginate", json=dict(page=2, limit=1))
assert response.status_code == status.HTTP_200_OK

results = response.json()["results"]

assert len(results) == 1
assert results[0]["name"] == "my-flow-2"

async def test_paginate_flows_sort(self, flows, client):
response = await client.post(
"/flows/paginate", json=dict(sort=schemas.sorting.FlowSort.NAME_ASC)
)
assert response.status_code == status.HTTP_200_OK
assert response.json()["results"][0]["name"] == "my-flow-1"

response_desc = await client.post(
"/flows/paginate", json=dict(sort=schemas.sorting.FlowSort.NAME_DESC)
)
assert response_desc.status_code == status.HTTP_200_OK
assert response_desc.json()["results"][0]["name"] == "my-flow-2"

async def test_read_flows_returns_empty_list(self, client):
response = await client.post("/flows/paginate")
assert response.status_code == status.HTTP_200_OK
assert response.json()["results"] == []
Loading