From 9fc3134fb3547fdc14cc4f8a655b618e19e482ab Mon Sep 17 00:00:00 2001 From: Yue Chao Qin Date: Mon, 23 Feb 2026 22:39:58 -0800 Subject: [PATCH] refactor: Search pipeline run using Pydantic query parameter --- cloud_pipelines_backend/api_server_sql.py | 234 ++++++---- tests/test_api_server_sql.py | 492 +++++++++++++++++++++- 2 files changed, 628 insertions(+), 98 deletions(-) diff --git a/cloud_pipelines_backend/api_server_sql.py b/cloud_pipelines_backend/api_server_sql.py index 9341db4..13b6c35 100644 --- a/cloud_pipelines_backend/api_server_sql.py +++ b/cloud_pipelines_backend/api_server_sql.py @@ -4,7 +4,14 @@ import json import logging import typing -from typing import Any, Optional +from typing import Any, Final, Optional + +import sqlalchemy as sql +from sqlalchemy import orm + +from . import backend_types_sql as bts +from . import component_structures as structures +from . import errors if typing.TYPE_CHECKING: from cloud_pipelines.orchestration.storage_providers import ( @@ -26,12 +33,6 @@ def _get_current_time() -> datetime.datetime: return datetime.datetime.now(tz=datetime.timezone.utc) -from . import component_structures as structures -from . import backend_types_sql as bts -from . import errors -from .errors import ItemNotFoundError - - # ==== PipelineJobService @dataclasses.dataclass(kw_only=True) class PipelineRunResponse: @@ -65,12 +66,11 @@ class ListPipelineJobsResponse: next_page_token: str | None = None -import sqlalchemy as sql -from sqlalchemy import orm - - class PipelineRunsApiService_Sql: - PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name" + _PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name" + _PAGE_TOKEN_OFFSET_KEY: Final[str] = "offset" + _PAGE_TOKEN_FILTER_KEY: Final[str] = "filter" + _DEFAULT_PAGE_SIZE: Final[int] = 10 def create( self, @@ -104,7 +104,7 @@ def create( annotations=annotations, created_by=created_by, extra_data={ - self.PIPELINE_NAME_EXTRA_DATA_KEY: pipeline_name, + self._PIPELINE_NAME_EXTRA_DATA_KEY: pipeline_name, }, ) session.add(pipeline_run) @@ -116,7 +116,7 @@ def create( def get(self, session: orm.Session, id: bts.IdType) -> PipelineRunResponse: pipeline_run = session.get(bts.PipelineRun, id) if not pipeline_run: - raise ItemNotFoundError(f"Pipeline run {id} not found.") + raise errors.ItemNotFoundError(f"Pipeline run {id} not found.") return PipelineRunResponse.from_db(pipeline_run) def terminate( @@ -128,7 +128,7 @@ def terminate( ): pipeline_run = session.get(bts.PipelineRun, id) if not pipeline_run: - raise ItemNotFoundError(f"Pipeline run {id} not found.") + raise errors.ItemNotFoundError(f"Pipeline run {id} not found.") if not skip_user_check and (terminated_by != pipeline_run.created_by): raise errors.PermissionError( f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be terminated by {terminated_by}" @@ -166,98 +166,86 @@ def list( *, session: orm.Session, page_token: str | None = None, - # page_size: int = 10, filter: str | None = None, current_user: str | None = None, include_pipeline_names: bool = False, include_execution_stats: bool = False, ) -> ListPipelineJobsResponse: - page_token_dict = _decode_page_token(page_token) - OFFSET_KEY = "offset" - offset = page_token_dict.get(OFFSET_KEY, 0) - page_size = 10 - - FILTER_KEY = "filter" - if page_token: - filter = page_token_dict.get(FILTER_KEY, None) - where_clauses = [] - parsed_filter = _parse_filter(filter) if filter else {} - for key, value in parsed_filter.items(): - if key == "_text": - raise NotImplementedError("Text search is not implemented yet.") - elif key == "created_by": - if value == "me": - if current_user is None: - # raise ApiServiceError( - # f"The `created_by:me` filter requires `current_user`." - # ) - current_user = "" - value = current_user - # TODO: Maybe make this a bit more robust. - # We need to change the filter since it goes into the next_page_token. - filter = filter.replace( - "created_by:me", f"created_by:{current_user}" - ) - if value: - where_clauses.append(bts.PipelineRun.created_by == value) - else: - where_clauses.append(bts.PipelineRun.created_by == None) - else: - raise NotImplementedError(f"Unsupported filter {filter}.") + filter_value, offset = _resolve_filter_value( + filter=filter, + page_token=page_token, + ) + where_clauses, next_page_filter_value = _build_filter_where_clauses( + filter_value=filter_value, + current_user=current_user, + ) + pipeline_runs = list( session.scalars( sql.select(bts.PipelineRun) .where(*where_clauses) .order_by(bts.PipelineRun.created_at.desc()) .offset(offset) - .limit(page_size) + .limit(self._DEFAULT_PAGE_SIZE) ).all() ) - next_page_offset = offset + page_size - next_page_token_dict = {OFFSET_KEY: next_page_offset, FILTER_KEY: filter} + next_page_offset = offset + self._DEFAULT_PAGE_SIZE + next_page_token_dict = { + self._PAGE_TOKEN_OFFSET_KEY: next_page_offset, + self._PAGE_TOKEN_FILTER_KEY: next_page_filter_value, + } next_page_token = _encode_page_token(next_page_token_dict) - if len(pipeline_runs) < page_size: + if len(pipeline_runs) < self._DEFAULT_PAGE_SIZE: next_page_token = None - def create_pipeline_run_response( - pipeline_run: bts.PipelineRun, - ) -> PipelineRunResponse: - response = PipelineRunResponse.from_db(pipeline_run) - if include_pipeline_names: - pipeline_name = None - extra_data = pipeline_run.extra_data or {} - if self.PIPELINE_NAME_EXTRA_DATA_KEY in extra_data: - pipeline_name = extra_data[self.PIPELINE_NAME_EXTRA_DATA_KEY] - else: - execution_node = session.get( - bts.ExecutionNode, pipeline_run.root_execution_id - ) - if execution_node: - task_spec = structures.TaskSpec.from_json_dict( - execution_node.task_spec - ) - component_spec = task_spec.component_ref.spec - if component_spec: - pipeline_name = component_spec.name - response.pipeline_name = pipeline_name - if include_execution_stats: - execution_status_stats = self._calculate_execution_status_stats( - session=session, root_execution_id=pipeline_run.root_execution_id - ) - response.execution_status_stats = { - status.value: count - for status, count in execution_status_stats.items() - } - return response - return ListPipelineJobsResponse( pipeline_runs=[ - create_pipeline_run_response(pipeline_run) + self._create_pipeline_run_response( + session=session, + pipeline_run=pipeline_run, + include_pipeline_names=include_pipeline_names, + include_execution_stats=include_execution_stats, + ) for pipeline_run in pipeline_runs ], next_page_token=next_page_token, ) + def _create_pipeline_run_response( + self, + *, + session: orm.Session, + pipeline_run: bts.PipelineRun, + include_pipeline_names: bool, + include_execution_stats: bool, + ) -> PipelineRunResponse: + response = PipelineRunResponse.from_db(pipeline_run) + if include_pipeline_names: + pipeline_name = None + extra_data = pipeline_run.extra_data or {} + if self._PIPELINE_NAME_EXTRA_DATA_KEY in extra_data: + pipeline_name = extra_data[self._PIPELINE_NAME_EXTRA_DATA_KEY] + else: + execution_node = session.get( + bts.ExecutionNode, pipeline_run.root_execution_id + ) + if execution_node: + task_spec = structures.TaskSpec.from_json_dict( + execution_node.task_spec + ) + component_spec = task_spec.component_ref.spec + if component_spec: + pipeline_name = component_spec.name + response.pipeline_name = pipeline_name + if include_execution_stats: + execution_status_stats = self._calculate_execution_status_stats( + session=session, root_execution_id=pipeline_run.root_execution_id + ) + response.execution_status_stats = { + status.value: count for status, count in execution_status_stats.items() + } + return response + def _calculate_execution_status_stats( self, session: orm.Session, root_execution_id: bts.IdType ) -> dict[bts.ContainerExecutionStatus, int]: @@ -316,7 +304,7 @@ def set_annotation( ): pipeline_run = session.get(bts.PipelineRun, id) if not pipeline_run: - raise ItemNotFoundError(f"Pipeline run {id} not found.") + raise errors.ItemNotFoundError(f"Pipeline run {id} not found.") if not skip_user_check and (user_name != pipeline_run.created_by): raise errors.PermissionError( f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be changed by {user_name}" @@ -338,7 +326,7 @@ def delete_annotation( ): pipeline_run = session.get(bts.PipelineRun, id) if not pipeline_run: - raise ItemNotFoundError(f"Pipeline run {id} not found.") + raise errors.ItemNotFoundError(f"Pipeline run {id} not found.") if not skip_user_check and (user_name != pipeline_run.created_by): raise errors.PermissionError( f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be changed by {user_name}" @@ -349,6 +337,64 @@ def delete_annotation( session.commit() +def _resolve_filter_value( + *, + filter: str | None, + page_token: str | None, +) -> tuple[str | None, int]: + """Decode page_token and return the effective (filter_value, offset). + + If a page_token is present, its stored filter takes precedence over the + raw filter parameter (the token carries the resolved filter forward across pages). + """ + page_token_dict = _decode_page_token(page_token) + offset = page_token_dict.get( + PipelineRunsApiService_Sql._PAGE_TOKEN_OFFSET_KEY, + 0, + ) + if page_token: + filter = page_token_dict.get( + PipelineRunsApiService_Sql._PAGE_TOKEN_FILTER_KEY, + None, + ) + return filter, offset + + +def _build_filter_where_clauses( + *, + filter_value: str | None, + current_user: str | None, +) -> tuple[list[sql.ColumnElement], str | None]: + """Parse a filter string into SQLAlchemy WHERE clauses. + + Returns (where_clauses, next_page_filter_value). The second value is the + filter string with shorthand values resolved (e.g. "created_by:me" becomes + "created_by:alice@example.com") so it can be embedded in the next page token. + """ + where_clauses: list[sql.ColumnElement] = [] + parsed_filter = _parse_filter(filter_value) if filter_value else {} + for key, value in parsed_filter.items(): + if key == "_text": + raise NotImplementedError("Text search is not implemented yet.") + elif key == "created_by": + if value == "me": + if current_user is None: + current_user = "" + value = current_user + # TODO: Maybe make this a bit more robust. + # We need to change the filter since it goes into the next_page_token. + filter_value = filter_value.replace( + "created_by:me", f"created_by:{current_user}" + ) + if value: + where_clauses.append(bts.PipelineRun.created_by == value) + else: + where_clauses.append(bts.PipelineRun.created_by == None) + else: + raise NotImplementedError(f"Unsupported filter {filter_value}.") + return where_clauses, filter_value + + def _decode_page_token(page_token: str) -> dict[str, Any]: return json.loads(base64.b64decode(page_token)) if page_token else {} @@ -524,7 +570,7 @@ class ExecutionNodesApiService_Sql: def get(self, session: orm.Session, id: bts.IdType) -> GetExecutionInfoResponse: execution_node = session.get(bts.ExecutionNode, id) if execution_node is None: - raise ItemNotFoundError(f"Execution with {id=} does not exist.") + raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.") parent_pipeline_run_id = session.scalar( sql.select(bts.PipelineRun.id).where( @@ -676,7 +722,7 @@ def get_container_execution_state( ) -> GetContainerExecutionStateResponse: execution = session.get(bts.ExecutionNode, id) if not execution: - raise ItemNotFoundError(f"Execution with {id=} does not exist.") + raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.") container_execution = execution.container_execution if not container_execution: raise RuntimeError( @@ -696,7 +742,7 @@ def get_artifacts( if not session.scalar( sql.select(sql.exists().where(bts.ExecutionNode.id == id)) ): - raise ItemNotFoundError(f"Execution with {id=} does not exist.") + raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.") input_artifact_links = session.scalars( sql.select(bts.InputArtifactLink) @@ -742,7 +788,7 @@ def get_container_execution_log( ) -> GetContainerExecutionLogResponse: execution = session.get(bts.ExecutionNode, id) if not execution: - raise ItemNotFoundError(f"Execution with {id=} does not exist.") + raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.") container_execution = execution.container_execution execution_extra_data = execution.extra_data or {} system_error_exception_full = execution_extra_data.get( @@ -829,7 +875,9 @@ def stream_container_execution_log( ) -> typing.Iterator[str]: execution = session.get(bts.ExecutionNode, execution_id) if not execution: - raise ItemNotFoundError(f"Execution with {execution_id=} does not exist.") + raise errors.ItemNotFoundError( + f"Execution with {execution_id=} does not exist." + ) container_execution = execution.container_execution if not container_execution: raise ApiServiceError( @@ -970,7 +1018,7 @@ class ArtifactNodesApiService_Sql: def get(self, session: orm.Session, id: bts.IdType) -> GetArtifactInfoResponse: artifact_node = session.get(bts.ArtifactNode, id) if artifact_node is None: - raise ItemNotFoundError(f"Artifact with {id=} does not exist.") + raise errors.ItemNotFoundError(f"Artifact with {id=} does not exist.") artifact_data = artifact_node.artifact_data result = GetArtifactInfoResponse(id=artifact_node.id) if artifact_data: @@ -986,7 +1034,7 @@ def get_signed_artifact_url( .where(bts.ArtifactNode.id == id) ) if not artifact_data: - raise ItemNotFoundError(f"Artifact node with {id=} does not exist.") + raise errors.ItemNotFoundError(f"Artifact node with {id=} does not exist.") if not artifact_data.uri: raise ValueError(f"Artifact node with {id=} does not have artifact URI.") if artifact_data.is_dir: diff --git a/tests/test_api_server_sql.py b/tests/test_api_server_sql.py index c0f26a7..eab8c95 100644 --- a/tests/test_api_server_sql.py +++ b/tests/test_api_server_sql.py @@ -1,17 +1,22 @@ +import pytest +from sqlalchemy import orm + from cloud_pipelines_backend import backend_types_sql as bts -from cloud_pipelines_backend.api_server_sql import ExecutionStatusSummary +from cloud_pipelines_backend import component_structures as structures +from cloud_pipelines_backend import api_server_sql +from cloud_pipelines_backend import database_ops class TestExecutionStatusSummary: def test_initial_state(self): - summary = ExecutionStatusSummary() + summary = api_server_sql.ExecutionStatusSummary() assert summary.total_executions == 0 assert summary.ended_executions == 0 assert summary.has_ended is False def test_accumulate_all_ended_statuses(self): """Add each ended status with 2^i count for robust uniqueness.""" - summary = ExecutionStatusSummary() + summary = api_server_sql.ExecutionStatusSummary() ended_statuses = sorted(bts.CONTAINER_STATUSES_ENDED, key=lambda s: s.value) expected_total = 0 expected_ended = 0 @@ -26,7 +31,7 @@ def test_accumulate_all_ended_statuses(self): def test_accumulate_all_in_progress_statuses(self): """Add each in-progress status with 2^i count for robust uniqueness.""" - summary = ExecutionStatusSummary() + summary = api_server_sql.ExecutionStatusSummary() in_progress_statuses = sorted( set(bts.ContainerExecutionStatus) - bts.CONTAINER_STATUSES_ENDED, key=lambda s: s.value, @@ -42,7 +47,7 @@ def test_accumulate_all_in_progress_statuses(self): def test_accumulate_all_statuses(self): """Add every status with 2^i count. Summary math must be exact.""" - summary = ExecutionStatusSummary() + summary = api_server_sql.ExecutionStatusSummary() all_statuses = sorted(bts.ContainerExecutionStatus, key=lambda s: s.value) expected_total = 0 expected_ended = 0 @@ -55,3 +60,480 @@ def test_accumulate_all_statuses(self): assert summary.total_executions == expected_total assert summary.ended_executions == expected_ended assert summary.has_ended == (expected_ended == expected_total) + + +def _make_task_spec(pipeline_name: str = "test-pipeline") -> structures.TaskSpec: + return structures.TaskSpec( + component_ref=structures.ComponentReference( + spec=structures.ComponentSpec( + name=pipeline_name, + implementation=structures.ContainerImplementation( + container=structures.ContainerSpec(image="test-image:latest"), + ), + ), + ), + ) + + +@pytest.fixture() +def session_factory(): + engine = database_ops.create_db_engine(database_uri="sqlite://") + bts._TableBase.metadata.create_all(engine) + return orm.sessionmaker(engine) + + +@pytest.fixture() +def db_session(session_factory): + with session_factory() as session: + yield session + + +@pytest.fixture() +def service(): + return api_server_sql.PipelineRunsApiService_Sql() + + +def _create_run(session_factory, service, **kwargs): + """Create a pipeline run using a fresh session (mirrors production per-request sessions).""" + with session_factory() as session: + return service.create(session, **kwargs) + + +class TestPipelineRunServiceList: + def test_list_empty(self, session_factory, service): + with session_factory() as session: + result = service.list( + session=session, + ) + assert result.pipeline_runs == [] + assert result.next_page_token is None + + def test_list_returns_pipeline_runs(self, session_factory, service): + _create_run(session_factory, service, root_task=_make_task_spec("pipeline-a")) + _create_run(session_factory, service, root_task=_make_task_spec("pipeline-b")) + + with session_factory() as session: + result = service.list( + session=session, + ) + assert len(result.pipeline_runs) == 2 + + def test_list_with_execution_stats(self, session_factory, service): + _create_run(session_factory, service, root_task=_make_task_spec()) + + with session_factory() as session: + result = service.list( + session=session, + include_execution_stats=True, + ) + assert len(result.pipeline_runs) == 1 + assert result.pipeline_runs[0].execution_status_stats is not None + + def test_list_filter_created_by(self, session_factory, service): + _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user2", + ) + + with session_factory() as session: + result = service.list( + session=session, + filter="created_by:user1", + ) + assert len(result.pipeline_runs) == 1 + assert result.pipeline_runs[0].created_by == "user1" + + def test_list_filter_created_by_empty(self, session_factory, service): + _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by=None, + ) + _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + + with session_factory() as session: + result = service.list( + session=session, + filter="created_by:", + ) + assert len(result.pipeline_runs) == 1 + assert result.pipeline_runs[0].created_by is None + + def test_list_pagination(self, session_factory, service): + for i in range(12): + _create_run( + session_factory, + service, + root_task=_make_task_spec(f"pipeline-{i}"), + ) + + with session_factory() as session: + page1 = service.list( + session=session, + ) + assert len(page1.pipeline_runs) == 10 + assert page1.next_page_token is not None + + with session_factory() as session: + page2 = service.list( + session=session, + page_token=page1.next_page_token, + ) + assert len(page2.pipeline_runs) == 2 + assert page2.next_page_token is None + + def test_list_filter_unsupported(self, session_factory, service): + with session_factory() as session: + with pytest.raises(NotImplementedError, match="Unsupported filter"): + service.list( + session=session, + filter="unknown_key:value", + ) + + def test_list_with_pipeline_names(self, session_factory, service): + _create_run(session_factory, service, root_task=_make_task_spec("my-pipeline")) + + with session_factory() as session: + result = service.list( + session=session, + include_pipeline_names=True, + ) + assert len(result.pipeline_runs) == 1 + assert result.pipeline_runs[0].pipeline_name == "my-pipeline" + + def test_list_filter_created_by_me(self, session_factory, service): + _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="alice@example.com", + ) + _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="bob@example.com", + ) + + with session_factory() as session: + result = service.list( + session=session, + current_user="alice@example.com", + filter="created_by:me", + ) + assert len(result.pipeline_runs) == 1 + assert result.pipeline_runs[0].created_by == "alice@example.com" + + +class TestCreatePipelineRunResponse: + def test_base_response(self, session_factory, service): + run = _create_run(session_factory, service, root_task=_make_task_spec()) + with session_factory() as session: + db_run = session.get(bts.PipelineRun, run.id) + response = service._create_pipeline_run_response( + session=session, + pipeline_run=db_run, + include_pipeline_names=False, + include_execution_stats=False, + ) + assert response.id == run.id + assert response.pipeline_name is None + assert response.execution_status_stats is None + + def test_pipeline_name_from_task_spec(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec("my-pipeline"), + ) + with session_factory() as session: + db_run = session.get(bts.PipelineRun, run.id) + response = service._create_pipeline_run_response( + session=session, + pipeline_run=db_run, + include_pipeline_names=True, + include_execution_stats=False, + ) + assert response.pipeline_name == "my-pipeline" + + def test_pipeline_name_from_extra_data(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec("spec-name"), + ) + with session_factory() as session: + db_run = session.get(bts.PipelineRun, run.id) + db_run.extra_data = {"pipeline_name": "cached-name"} + session.commit() + with session_factory() as session: + db_run = session.get(bts.PipelineRun, run.id) + response = service._create_pipeline_run_response( + session=session, + pipeline_run=db_run, + include_pipeline_names=True, + include_execution_stats=False, + ) + assert response.pipeline_name == "cached-name" + + def test_pipeline_name_none_when_no_execution_node(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec("some-name"), + ) + with session_factory() as session: + db_run = session.get(bts.PipelineRun, run.id) + db_run.root_execution_id = "nonexistent-id" + db_run.extra_data = {} + session.commit() + with session_factory() as session: + db_run = session.get(bts.PipelineRun, run.id) + response = service._create_pipeline_run_response( + session=session, + pipeline_run=db_run, + include_pipeline_names=True, + include_execution_stats=False, + ) + assert response.pipeline_name is None + + def test_with_execution_stats(self, session_factory, service): + run = _create_run(session_factory, service, root_task=_make_task_spec()) + with session_factory() as session: + db_run = session.get(bts.PipelineRun, run.id) + response = service._create_pipeline_run_response( + session=session, + pipeline_run=db_run, + include_pipeline_names=False, + include_execution_stats=True, + ) + assert response.execution_status_stats is not None + + +class TestPipelineRunServiceCreate: + def test_create_returns_pipeline_run(self, session_factory, service): + result = _create_run( + session_factory, service, root_task=_make_task_spec("my-pipeline") + ) + assert result.id is not None + assert result.root_execution_id is not None + assert result.created_at is not None + + def test_create_with_created_by(self, session_factory, service): + result = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1@example.com", + ) + assert result.created_by == "user1@example.com" + + def test_create_with_annotations(self, session_factory, service): + annotations = {"team": "ml-ops", "project": "search"} + result = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + annotations=annotations, + ) + assert result.annotations == annotations + + def test_create_without_created_by(self, session_factory, service): + result = _create_run(session_factory, service, root_task=_make_task_spec()) + assert result.created_by is None + + +class TestPipelineRunAnnotationCrud: + def test_set_annotation(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + with session_factory() as session: + service.set_annotation( + session=session, + id=run.id, + key="team", + value="ml-ops", + user_name="user1", + ) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert annotations == {"team": "ml-ops"} + + def test_set_annotation_overwrites(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + with session_factory() as session: + service.set_annotation( + session=session, + id=run.id, + key="team", + value="old-value", + user_name="user1", + ) + with session_factory() as session: + service.set_annotation( + session=session, + id=run.id, + key="team", + value="new-value", + user_name="user1", + ) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert annotations == {"team": "new-value"} + + def test_delete_annotation(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + with session_factory() as session: + service.set_annotation( + session=session, + id=run.id, + key="team", + value="ml-ops", + user_name="user1", + ) + with session_factory() as session: + service.delete_annotation( + session=session, + id=run.id, + key="team", + user_name="user1", + ) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert annotations == {} + + def test_list_annotations_empty(self, session_factory, service): + run = _create_run(session_factory, service, root_task=_make_task_spec()) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert annotations == {} + + +class TestResolveFilterValue: + def test_no_page_token_no_filter(self): + filter_value, offset = api_server_sql._resolve_filter_value( + filter=None, page_token=None + ) + assert filter_value is None + assert offset == 0 + + def test_no_page_token_with_filter(self): + filter_value, offset = api_server_sql._resolve_filter_value( + filter="created_by:alice", + page_token=None, + ) + assert filter_value == "created_by:alice" + assert offset == 0 + + def test_page_token_overrides_filter(self): + token = api_server_sql._encode_page_token( + {"offset": 20, "filter": "created_by:bob"} + ) + filter_value, offset = api_server_sql._resolve_filter_value( + filter="created_by:alice", + page_token=token, + ) + assert filter_value == "created_by:bob" + assert offset == 20 + + def test_page_token_without_filter_key(self): + token = api_server_sql._encode_page_token({"offset": 10}) + filter_value, offset = api_server_sql._resolve_filter_value( + filter="created_by:alice", + page_token=token, + ) + assert filter_value is None + assert offset == 10 + + def test_page_token_without_offset_key(self): + token = api_server_sql._encode_page_token({"filter": "created_by:bob"}) + filter_value, offset = api_server_sql._resolve_filter_value( + filter=None, + page_token=token, + ) + assert filter_value == "created_by:bob" + assert offset == 0 + + +class TestBuildFilterWhereClauses: + def test_no_filter(self): + clauses, next_filter = api_server_sql._build_filter_where_clauses( + filter_value=None, + current_user=None, + ) + assert clauses == [] + assert next_filter is None + + def test_created_by_literal(self): + clauses, next_filter = api_server_sql._build_filter_where_clauses( + filter_value="created_by:alice", + current_user=None, + ) + assert len(clauses) == 1 + assert next_filter == "created_by:alice" + + def test_created_by_me_resolves(self): + clauses, next_filter = api_server_sql._build_filter_where_clauses( + filter_value="created_by:me", + current_user="alice@example.com", + ) + assert len(clauses) == 1 + assert next_filter == "created_by:alice@example.com" + + def test_created_by_me_no_current_user(self): + clauses, next_filter = api_server_sql._build_filter_where_clauses( + filter_value="created_by:me", + current_user=None, + ) + assert len(clauses) == 1 + assert next_filter == "created_by:" + + def test_created_by_empty_value(self): + clauses, next_filter = api_server_sql._build_filter_where_clauses( + filter_value="created_by:", + current_user=None, + ) + assert len(clauses) == 1 + assert next_filter == "created_by:" + + def test_unsupported_key_raises(self): + with pytest.raises(NotImplementedError, match="Unsupported filter"): + api_server_sql._build_filter_where_clauses( + filter_value="unknown_key:value", + current_user=None, + ) + + def test_text_search_raises(self): + with pytest.raises(NotImplementedError, match="Text search"): + api_server_sql._build_filter_where_clauses( + filter_value="some_text_without_colon", + current_user=None, + )