Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def test_edit_annotations(self, authed_api):
assert response["count"] == 1
assert response["trace"]["trace_id"] == trace_id
assert response["trace"]["data"]["outputs"] == new_annotation_data_outputs
assert response["trace"]["links"] == {}
assert response["trace"]["links"] == annotation_links
assert response["trace"]["tags"] == {"tag3": "value3"}
assert response["trace"]["meta"] == {"meta3": "value3"}
# ----------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from uuid import uuid4


def _create_simple_query(authed_api) -> dict:
def _create_simple_query(authed_api, *, trace_type: str = "invocation") -> dict:
slug = uuid4().hex
response = authed_api(
"POST",
Expand All @@ -14,12 +14,17 @@ def _create_simple_query(authed_api) -> dict:
"filtering": {
"operator": "and",
"conditions": [
{
"field": "trace_type",
"value": trace_type,
"operator": "is",
},
{
"field": "attributes",
"key": f"test-key-{slug[:8]}",
"value": "test-value",
"operator": "is",
}
},
],
}
},
Expand Down Expand Up @@ -91,3 +96,31 @@ def test_create_live_simple_evaluation_accepts_query_and_evaluator_revision_ids(
assert set(evaluation["data"]["evaluator_steps"].keys()) == {
evaluator["revision_id"]
}

def test_create_live_simple_evaluation_rejects_non_invocation_query_revision(
    self, authed_api
):
    """A live simple evaluation must refuse a query revision whose trace_type
    filter targets anything other than "invocation".

    The endpoint still answers 200, but no evaluation is created: the
    response reports a count of zero and no evaluation payload.
    """
    # Query revision filtered on annotations rather than invocations.
    annotation_query = _create_simple_query(authed_api, trace_type="annotation")
    evaluator = _create_simple_evaluator(authed_api)

    payload = {
        "evaluation": {
            "name": "live-non-invocation-query-rejected",
            "flags": {
                "is_live": True,
            },
            "data": {
                "query_steps": [annotation_query["revision_id"]],
                "evaluator_steps": [evaluator["revision_id"]],
            },
        }
    }

    response = authed_api("POST", "/simple/evaluations/", json=payload)

    # Rejection is signalled in the body, not via an error status code.
    assert response.status_code == 200
    body = response.json()
    assert body["count"] == 0
    assert body.get("evaluation") is None
114 changes: 113 additions & 1 deletion api/oss/tests/pytest/unit/test_simple_evaluations_parser.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
from uuid import uuid4
import sys
import types
from unittest.mock import AsyncMock

import pytest

from oss.src.core.shared.dtos import Reference
from oss.src.core.queries.dtos import QueryRevisionData
from oss.src.core.evaluations.types import (
EvaluationRun,
EvaluationRunData,
EvaluationRunDataStep,
EvaluationRunFlags,
EvaluationStatus,
)
from oss.src.core.tracing.dtos import Condition, Filtering

sys.modules.setdefault("genson", types.SimpleNamespace(SchemaBuilder=object))

from oss.src.core.evaluations.service import SimpleEvaluationsService # noqa: E402
from oss.src.core.evaluations.service import ( # noqa: E402
EvaluationsService,
SimpleEvaluationsService,
_is_invocation_query,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -69,3 +76,108 @@ async def test_parse_evaluation_run_prefers_workflow_revision_refs():
assert set(evaluation.data.query_steps.keys()) == {query_id}
assert set(evaluation.data.application_steps.keys()) == {application_id}
assert set(evaluation.data.evaluator_steps.keys()) == {evaluator_revision_id}


def _query_revision_data(
    *,
    field: str = "trace_type",
    operator: str = "is",
    value: str = "invocation",
) -> QueryRevisionData:
    """Build a QueryRevisionData carrying exactly one filtering condition.

    Defaults describe the canonical "top-level invocation" query; callers
    override individual keyword arguments to produce rejected variants.
    """
    condition = Condition(field=field, operator=operator, value=value)
    return QueryRevisionData(filtering=Filtering(conditions=[condition]))


def test_is_invocation_query_requires_top_level_invocation_trace_type():
    """_is_invocation_query accepts only a top-level trace_type == invocation
    condition with the "is" operator; everything else is rejected."""
    # Happy path: default helper output is the canonical invocation filter.
    assert _is_invocation_query(_query_revision_data()) is True

    # Each variant breaks exactly one requirement (field, value, operator,
    # or the presence of any filtering at all).
    rejected_revisions = [
        _query_revision_data(field="attributes", value="invocation"),
        _query_revision_data(value="annotation"),
        _query_revision_data(operator="in"),
        QueryRevisionData(),
    ]
    for revision_data in rejected_revisions:
        assert _is_invocation_query(revision_data) is False


@pytest.mark.asyncio
async def test_live_run_validation_accepts_invocation_query_revision():
    """_is_live_run_valid returns True when the run's query_revision
    reference resolves to an invocation-filtered query revision."""
    revision_id = uuid4()

    # Stub queries service: fetch_query_revision yields an invocation query.
    fetch_mock = AsyncMock(
        return_value=types.SimpleNamespace(data=_query_revision_data())
    )
    queries_stub = types.SimpleNamespace(fetch_query_revision=fetch_mock)

    # Only queries_service is exercised; the other collaborators are unused.
    service = EvaluationsService(
        evaluations_dao=None,  # type: ignore[arg-type]
        tracing_service=None,  # type: ignore[arg-type]
        queries_service=queries_stub,  # type: ignore[arg-type]
        testsets_service=None,  # type: ignore[arg-type]
        evaluators_service=None,  # type: ignore[arg-type]
    )

    query_step = EvaluationRunDataStep(
        key="query-step",
        type="input",
        origin="custom",
        references={"query_revision": Reference(id=revision_id)},
    )
    run = EvaluationRun(
        id=uuid4(),
        flags=EvaluationRunFlags(is_live=True, is_active=True),
        status=EvaluationStatus.PENDING,
        data=EvaluationRunData(steps=[query_step]),
    )

    assert await service._is_live_run_valid(project_id=uuid4(), run=run) is True
    # The revision must actually have been fetched to be validated.
    fetch_mock.assert_awaited_once()


@pytest.mark.asyncio
async def test_live_run_validation_rejects_non_invocation_query_revision():
    """_is_live_run_valid returns False when the referenced query revision
    filters on a non-invocation trace_type (here: annotation)."""
    revision_id = uuid4()

    # Stub queries service: the fetched revision filters on annotations.
    fetch_mock = AsyncMock(
        return_value=types.SimpleNamespace(
            data=_query_revision_data(value="annotation")
        )
    )
    queries_stub = types.SimpleNamespace(fetch_query_revision=fetch_mock)

    # Only queries_service is exercised; the other collaborators are unused.
    service = EvaluationsService(
        evaluations_dao=None,  # type: ignore[arg-type]
        tracing_service=None,  # type: ignore[arg-type]
        queries_service=queries_stub,  # type: ignore[arg-type]
        testsets_service=None,  # type: ignore[arg-type]
        evaluators_service=None,  # type: ignore[arg-type]
    )

    query_step = EvaluationRunDataStep(
        key="query-step",
        type="input",
        origin="custom",
        references={"query_revision": Reference(id=revision_id)},
    )
    run = EvaluationRun(
        id=uuid4(),
        flags=EvaluationRunFlags(is_live=True, is_active=True),
        status=EvaluationStatus.PENDING,
        data=EvaluationRunData(steps=[query_step]),
    )

    assert await service._is_live_run_valid(project_id=uuid4(), run=run) is False
Loading