Skip to content

Commit

Permalink
feat: fetch annotation names (#2964)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang committed Apr 24, 2024
1 parent d665e49 commit 6c5d25d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 20 deletions.
53 changes: 41 additions & 12 deletions src/phoenix/server/api/types/Project.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from phoenix.server.api.types.Trace import Trace
from phoenix.server.api.types.ValidationResult import ValidationResult
from phoenix.trace.dsl import SpanFilter
from phoenix.trace.schemas import SpanID


@strawberry.type
Expand Down Expand Up @@ -208,26 +207,56 @@ async def spans(
description="Names of all available evaluations for traces. "
"(The list contains no duplicates.)"
) # type: ignore
def trace_evaluation_names(self) -> List[str]:
return self.project.get_trace_evaluation_names()
async def trace_evaluation_names(
self,
info: Info[Context, None],
) -> List[str]:
stmt = (
select(distinct(models.TraceAnnotation.name))
.join(models.Trace)
.where(models.Trace.project_rowid == self.id_attr)
.where(models.TraceAnnotation.annotator_kind == "LLM")
)
async with info.context.db() as session:
return list(await session.scalars(stmt))

@strawberry.field(
description="Names of all available evaluations for spans. "
"(The list contains no duplicates.)"
) # type: ignore
def span_evaluation_names(self) -> List[str]:
return self.project.get_span_evaluation_names()
async def span_evaluation_names(
self,
info: Info[Context, None],
) -> List[str]:
stmt = (
select(distinct(models.SpanAnnotation.name))
.join(models.Span)
.join(models.Trace, models.Span.trace_rowid == models.Trace.id)
.where(models.Trace.project_rowid == self.id_attr)
.where(models.SpanAnnotation.annotator_kind == "LLM")
)
async with info.context.db() as session:
return list(await session.scalars(stmt))

@strawberry.field(
description="Names of available document evaluations.",
) # type: ignore
def document_evaluation_names(
async def document_evaluation_names(
self,
info: Info[Context, None],
span_id: Optional[ID] = UNSET,
) -> List[str]:
return self.project.get_document_evaluation_names(
None if span_id is UNSET else SpanID(span_id),
stmt = (
select(distinct(models.DocumentAnnotation.name))
.join(models.Span)
.join(models.Trace, models.Span.trace_rowid == models.Trace.id)
.where(models.Trace.project_rowid == self.id_attr)
.where(models.DocumentAnnotation.annotator_kind == "LLM")
)
if span_id:
stmt = stmt.where(models.Span.span_id == str(span_id))
async with info.context.db() as session:
return list(await session.scalars(stmt))

@strawberry.field
async def trace_evaluation_summary(
Expand Down Expand Up @@ -355,13 +384,13 @@ def streaming_last_updated_at(
return info.context.streaming_last_updated_at()

@strawberry.field
def validate_span_filter_condition(self, condition: str) -> ValidationResult:
valid_eval_names = self.project.get_span_evaluation_names()
async def validate_span_filter_condition(self, condition: str) -> ValidationResult:
# TODO(persistence): this query is too expensive to run on every validation
# valid_eval_names = await self.span_evaluation_names()
try:
SpanFilter(
condition=condition,
evals=self.project,
valid_eval_names=valid_eval_names,
# valid_eval_names=valid_eval_names,
)
return ValidationResult(is_valid=True, error_message=None)
except SyntaxError as e:
Expand Down
9 changes: 4 additions & 5 deletions src/phoenix/trace/dsl/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ def get_span_evaluation(self, span_id: SpanID, name: str) -> typing.Optional[pb.
@dataclass(frozen=True)
class SpanFilter:
condition: str = ""
# TODO(persistence): remove `evals` and `valid_eval_names` from this class
evals: typing.Optional[SupportsGetSpanEvaluation] = None
valid_eval_names: typing.Optional[typing.Sequence[str]] = None
translated: ast.Expression = field(init=False, repr=False)
compiled: typing.Any = field(init=False, repr=False)
Expand Down Expand Up @@ -124,11 +122,12 @@ def to_dict(self) -> typing.Dict[str, typing.Any]:
def from_dict(
cls,
obj: typing.Mapping[str, typing.Any],
# TODO(persistence): remove `evals` and `valid_eval_names` from this class
evals: typing.Optional[SupportsGetSpanEvaluation] = None,
valid_eval_names: typing.Optional[typing.Sequence[str]] = None,
) -> "SpanFilter":
return cls(condition=obj.get("condition") or "")
return cls(
condition=obj.get("condition") or "",
valid_eval_names=valid_eval_names,
)


@dataclass(frozen=True)
Expand Down
4 changes: 1 addition & 3 deletions src/phoenix/trace/dsl/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
unflatten,
)
from phoenix.trace.dsl import SpanFilter
from phoenix.trace.dsl.filter import Projector, SupportsGetSpanEvaluation
from phoenix.trace.dsl.filter import Projector
from phoenix.trace.schemas import ATTRIBUTE_PREFIX

# supported SQL dialects
Expand Down Expand Up @@ -611,7 +611,6 @@ def to_dict(self) -> Dict[str, Any]:
def from_dict(
cls,
obj: Mapping[str, Any],
evals: Optional[SupportsGetSpanEvaluation] = None,
valid_eval_names: Optional[Sequence[str]] = None,
) -> "SpanQuery":
return cls(
Expand All @@ -631,7 +630,6 @@ def from_dict(
{
"_filter": SpanFilter.from_dict(
cast(Mapping[str, Any], filter),
evals=evals,
valid_eval_names=valid_eval_names,
)
} # type: ignore
Expand Down

0 comments on commit 6c5d25d

Please sign in to comment.