Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 133 additions & 89 deletions cloud_pipelines_backend/api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
import json
import logging
import typing
from typing import Any, Optional
from typing import Any, Final, Optional

import sqlalchemy as sql
from sqlalchemy import orm

from . import backend_types_sql as bts
from . import component_structures as structures
from . import errors

if typing.TYPE_CHECKING:
from cloud_pipelines.orchestration.storage_providers import (
Expand All @@ -26,10 +33,9 @@ def _get_current_time() -> datetime.datetime:
return datetime.datetime.now(tz=datetime.timezone.utc)


from . import component_structures as structures
from . import backend_types_sql as bts
from . import errors
from .errors import ItemNotFoundError
# Key in the decoded page-token dict that holds the row offset of the next page.
_PAGE_TOKEN_OFFSET_KEY: Final[str] = "offset"
# Key in the decoded page-token dict that holds the filter string carried
# forward from one page request to the next.
_PAGE_TOKEN_FILTER_KEY: Final[str] = "filter"
# Number of pipeline runs requested per page by the listing query; also the
# amount by which the offset advances for the next page token.
_DEFAULT_PAGE_SIZE: Final[int] = 10
Comment on lines +36 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no indication that these constants belong to the PipelineRunsApiService.
If these are only for PipelineRuns, then this should be indicated in the name.
Preference: Move these constants to the PipelineRunsApiService_Sql class.



# ==== PipelineJobService
Expand Down Expand Up @@ -65,10 +71,6 @@ class ListPipelineJobsResponse:
next_page_token: str | None = None


import sqlalchemy as sql
from sqlalchemy import orm


class PipelineRunsApiService_Sql:
PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name"

Expand Down Expand Up @@ -116,7 +118,7 @@ def create(
def get(self, session: orm.Session, id: bts.IdType) -> PipelineRunResponse:
pipeline_run = session.get(bts.PipelineRun, id)
if not pipeline_run:
raise ItemNotFoundError(f"Pipeline run {id} not found.")
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for switching to module imports.
(AFAIR, the reason for this class import was that the error classes were originally declared in this module, but were then extracted into the errors module.)

return PipelineRunResponse.from_db(pipeline_run)

def terminate(
Expand All @@ -128,7 +130,7 @@ def terminate(
):
pipeline_run = session.get(bts.PipelineRun, id)
if not pipeline_run:
raise ItemNotFoundError(f"Pipeline run {id} not found.")
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
if not skip_user_check and (terminated_by != pipeline_run.created_by):
raise errors.PermissionError(
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be terminated by {terminated_by}"
Expand Down Expand Up @@ -166,98 +168,86 @@ def list(
*,
session: orm.Session,
page_token: str | None = None,
# page_size: int = 10,
filter: str | None = None,
current_user: str | None = None,
include_pipeline_names: bool = False,
include_execution_stats: bool = False,
) -> ListPipelineJobsResponse:
page_token_dict = _decode_page_token(page_token)
OFFSET_KEY = "offset"
offset = page_token_dict.get(OFFSET_KEY, 0)
page_size = 10

FILTER_KEY = "filter"
if page_token:
filter = page_token_dict.get(FILTER_KEY, None)
where_clauses = []
parsed_filter = _parse_filter(filter) if filter else {}
for key, value in parsed_filter.items():
if key == "_text":
raise NotImplementedError("Text search is not implemented yet.")
elif key == "created_by":
if value == "me":
if current_user is None:
# raise ApiServiceError(
# f"The `created_by:me` filter requires `current_user`."
# )
current_user = ""
value = current_user
# TODO: Maybe make this a bit more robust.
# We need to change the filter since it goes into the next_page_token.
filter = filter.replace(
"created_by:me", f"created_by:{current_user}"
)
if value:
where_clauses.append(bts.PipelineRun.created_by == value)
else:
where_clauses.append(bts.PipelineRun.created_by == None)
else:
raise NotImplementedError(f"Unsupported filter {filter}.")
filter_value, offset = _resolve_filter_value(
filter=filter,
page_token=page_token,
)
where_clauses, next_page_filter_value = _build_filter_where_clauses(
filter_value=filter_value,
current_user=current_user,
)

pipeline_runs = list(
session.scalars(
sql.select(bts.PipelineRun)
.where(*where_clauses)
.order_by(bts.PipelineRun.created_at.desc())
.offset(offset)
.limit(page_size)
.limit(_DEFAULT_PAGE_SIZE)
).all()
)
next_page_offset = offset + page_size
next_page_token_dict = {OFFSET_KEY: next_page_offset, FILTER_KEY: filter}
next_page_offset = offset + _DEFAULT_PAGE_SIZE
next_page_token_dict = {
_PAGE_TOKEN_OFFSET_KEY: next_page_offset,
_PAGE_TOKEN_FILTER_KEY: next_page_filter_value,
}
next_page_token = _encode_page_token(next_page_token_dict)
if len(pipeline_runs) < page_size:
if len(pipeline_runs) < _DEFAULT_PAGE_SIZE:
next_page_token = None

def create_pipeline_run_response(
pipeline_run: bts.PipelineRun,
) -> PipelineRunResponse:
response = PipelineRunResponse.from_db(pipeline_run)
if include_pipeline_names:
pipeline_name = None
extra_data = pipeline_run.extra_data or {}
if self.PIPELINE_NAME_EXTRA_DATA_KEY in extra_data:
pipeline_name = extra_data[self.PIPELINE_NAME_EXTRA_DATA_KEY]
else:
execution_node = session.get(
bts.ExecutionNode, pipeline_run.root_execution_id
)
if execution_node:
task_spec = structures.TaskSpec.from_json_dict(
execution_node.task_spec
)
component_spec = task_spec.component_ref.spec
if component_spec:
pipeline_name = component_spec.name
response.pipeline_name = pipeline_name
if include_execution_stats:
execution_status_stats = self._calculate_execution_status_stats(
session=session, root_execution_id=pipeline_run.root_execution_id
)
response.execution_status_stats = {
status.value: count
for status, count in execution_status_stats.items()
}
return response

return ListPipelineJobsResponse(
pipeline_runs=[
create_pipeline_run_response(pipeline_run)
self._create_pipeline_run_response(
session=session,
pipeline_run=pipeline_run,
include_pipeline_names=include_pipeline_names,
include_execution_stats=include_execution_stats,
)
for pipeline_run in pipeline_runs
],
next_page_token=next_page_token,
)

def _create_pipeline_run_response(
    self,
    *,
    session: orm.Session,
    pipeline_run: bts.PipelineRun,
    include_pipeline_names: bool,
    include_execution_stats: bool,
) -> PipelineRunResponse:
    """Convert a stored pipeline run into its API response object.

    The response is built with ``PipelineRunResponse.from_db`` and then
    optionally enriched:

    * ``include_pipeline_names``: set ``pipeline_name`` from the run's
      ``extra_data`` entry when that key is present; otherwise fall back to
      the name of the component spec referenced by the root execution node's
      task spec. The name stays ``None`` when neither source yields one.
    * ``include_execution_stats``: set ``execution_status_stats`` to a mapping
      of status value to count, as computed by
      ``_calculate_execution_status_stats`` for the run's root execution.
    """
    run_response = PipelineRunResponse.from_db(pipeline_run)

    if include_pipeline_names:
        resolved_name = None
        run_extra_data = pipeline_run.extra_data or {}
        name_key = self.PIPELINE_NAME_EXTRA_DATA_KEY
        if name_key in run_extra_data:
            # A name recorded on the run itself wins, even if it is empty.
            resolved_name = run_extra_data[name_key]
        else:
            root_node = session.get(
                bts.ExecutionNode, pipeline_run.root_execution_id
            )
            if root_node:
                root_task_spec = structures.TaskSpec.from_json_dict(
                    root_node.task_spec
                )
                root_component_spec = root_task_spec.component_ref.spec
                if root_component_spec:
                    resolved_name = root_component_spec.name
        run_response.pipeline_name = resolved_name

    if include_execution_stats:
        status_counts = self._calculate_execution_status_stats(
            session=session,
            root_execution_id=pipeline_run.root_execution_id,
        )
        # Keys are converted to their plain ``.value`` for the response.
        run_response.execution_status_stats = {
            status.value: count for status, count in status_counts.items()
        }

    return run_response

def _calculate_execution_status_stats(
self, session: orm.Session, root_execution_id: bts.IdType
) -> dict[bts.ContainerExecutionStatus, int]:
Expand Down Expand Up @@ -316,7 +306,7 @@ def set_annotation(
):
pipeline_run = session.get(bts.PipelineRun, id)
if not pipeline_run:
raise ItemNotFoundError(f"Pipeline run {id} not found.")
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
if not skip_user_check and (user_name != pipeline_run.created_by):
raise errors.PermissionError(
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be changed by {user_name}"
Expand All @@ -338,7 +328,7 @@ def delete_annotation(
):
pipeline_run = session.get(bts.PipelineRun, id)
if not pipeline_run:
raise ItemNotFoundError(f"Pipeline run {id} not found.")
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
if not skip_user_check and (user_name != pipeline_run.created_by):
raise errors.PermissionError(
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be changed by {user_name}"
Expand All @@ -349,6 +339,58 @@ def delete_annotation(
session.commit()


def _resolve_filter_value(
    *,
    filter: str | None,
    page_token: str | None,
) -> tuple[str | None, int]:
    """Work out which filter string and row offset apply to a list request.

    Without a page token the caller's ``filter`` is returned unchanged and the
    offset is 0. With a page token, both values come from the decoded token:
    the offset (0 when the token has none) and the stored filter (``None``
    when the token has none); the ``filter`` argument is then ignored.

    Returns:
        A ``(filter_value, offset)`` tuple.
    """
    token_payload = _decode_page_token(page_token)
    start_offset = token_payload.get(_PAGE_TOKEN_OFFSET_KEY, 0)
    if not page_token:
        return filter, start_offset
    # The token's filter takes precedence so every page uses the same filter.
    return token_payload.get(_PAGE_TOKEN_FILTER_KEY), start_offset


def _build_filter_where_clauses(
    *,
    filter_value: str | None,
    current_user: str | None,
) -> tuple[list[sql.ColumnElement], str | None]:
    """Translate a filter string into WHERE clauses for pipeline runs.

    Only the ``created_by`` key is supported. The shorthand value ``me`` is
    replaced by ``current_user`` (or by an empty string when ``current_user``
    is ``None``), both in the generated clause and in the returned filter
    string. An empty ``created_by`` value matches rows whose ``created_by``
    is NULL.

    Returns:
        ``(where_clauses, next_page_filter_value)``, where the second item is
        the filter string with ``created_by:me`` resolved, intended for the
        next page token.

    Raises:
        NotImplementedError: For the ``_text`` key or any other unsupported key.
    """
    clauses: list[sql.ColumnElement] = []
    if not filter_value:
        return clauses, filter_value

    for field, term in _parse_filter(filter_value).items():
        if field == "_text":
            raise NotImplementedError("Text search is not implemented yet.")
        if field != "created_by":
            raise NotImplementedError(f"Unsupported filter {filter_value}.")

        if term == "me":
            current_user = "" if current_user is None else current_user
            term = current_user
            # TODO: Maybe make this a bit more robust.
            # The filter text is rewritten because it is carried into the
            # next page token.
            filter_value = filter_value.replace(
                "created_by:me", f"created_by:{current_user}"
            )

        created_by_column = bts.PipelineRun.created_by
        if term:
            clauses.append(created_by_column == term)
        else:
            # Comparing a column to None with == yields an IS NULL clause.
            clauses.append(created_by_column == None)  # noqa: E711

    return clauses, filter_value


def _decode_page_token(page_token: str) -> dict[str, Any]:
return json.loads(base64.b64decode(page_token)) if page_token else {}

Expand Down Expand Up @@ -524,7 +566,7 @@ class ExecutionNodesApiService_Sql:
def get(self, session: orm.Session, id: bts.IdType) -> GetExecutionInfoResponse:
execution_node = session.get(bts.ExecutionNode, id)
if execution_node is None:
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")

parent_pipeline_run_id = session.scalar(
sql.select(bts.PipelineRun.id).where(
Expand Down Expand Up @@ -676,7 +718,7 @@ def get_container_execution_state(
) -> GetContainerExecutionStateResponse:
execution = session.get(bts.ExecutionNode, id)
if not execution:
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")
container_execution = execution.container_execution
if not container_execution:
raise RuntimeError(
Expand All @@ -696,7 +738,7 @@ def get_artifacts(
if not session.scalar(
sql.select(sql.exists().where(bts.ExecutionNode.id == id))
):
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")

input_artifact_links = session.scalars(
sql.select(bts.InputArtifactLink)
Expand Down Expand Up @@ -742,7 +784,7 @@ def get_container_execution_log(
) -> GetContainerExecutionLogResponse:
execution = session.get(bts.ExecutionNode, id)
if not execution:
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")
container_execution = execution.container_execution
execution_extra_data = execution.extra_data or {}
system_error_exception_full = execution_extra_data.get(
Expand Down Expand Up @@ -829,7 +871,9 @@ def stream_container_execution_log(
) -> typing.Iterator[str]:
execution = session.get(bts.ExecutionNode, execution_id)
if not execution:
raise ItemNotFoundError(f"Execution with {execution_id=} does not exist.")
raise errors.ItemNotFoundError(
f"Execution with {execution_id=} does not exist."
)
container_execution = execution.container_execution
if not container_execution:
raise ApiServiceError(
Expand Down Expand Up @@ -970,7 +1014,7 @@ class ArtifactNodesApiService_Sql:
def get(self, session: orm.Session, id: bts.IdType) -> GetArtifactInfoResponse:
artifact_node = session.get(bts.ArtifactNode, id)
if artifact_node is None:
raise ItemNotFoundError(f"Artifact with {id=} does not exist.")
raise errors.ItemNotFoundError(f"Artifact with {id=} does not exist.")
artifact_data = artifact_node.artifact_data
result = GetArtifactInfoResponse(id=artifact_node.id)
if artifact_data:
Expand All @@ -986,7 +1030,7 @@ def get_signed_artifact_url(
.where(bts.ArtifactNode.id == id)
)
if not artifact_data:
raise ItemNotFoundError(f"Artifact node with {id=} does not exist.")
raise errors.ItemNotFoundError(f"Artifact node with {id=} does not exist.")
if not artifact_data.uri:
raise ValueError(f"Artifact node with {id=} does not have artifact URI.")
if artifact_data.is_dir:
Expand Down
Loading