Skip to content

Commit

Permalink
feat: make graphql api for span evaluations read from database (#2860)
Browse files Browse the repository at this point in the history
Makes the GraphQL resolver for span evaluations read from the database, using a data loader.

Co-authored-by: Mikyo King <mikeldking@gmail.com>
  • Loading branch information
axiomofjoy committed Apr 11, 2024
1 parent d539e77 commit 5adf750
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 14 deletions.
4 changes: 3 additions & 1 deletion src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from pathlib import Path
from typing import AsyncContextManager, Callable, Optional, Tuple, Union
from typing import AsyncContextManager, Callable, List, Optional, Tuple, Union

from sqlalchemy.ext.asyncio import AsyncSession
from starlette.requests import Request
Expand All @@ -11,11 +11,13 @@
from phoenix.core.model_schema import Model
from phoenix.core.traces import Traces
from phoenix.server.api.input_types.TimeRange import TimeRange
from phoenix.server.api.types.Evaluation import SpanEvaluation


# Bundle of strawberry DataLoader instances carried on the GraphQL context;
# app.py's get_context builds one and passes it as `data_loaders`.
@dataclass
class DataLoaders:
# Keyed by a (str, Optional[TimeRange], float) tuple and resolves to an optional float.
# NOTE(review): presumably (project name, time range, quantile) — confirm against
# LatencyMsQuantileDataLoader.
latency_ms_quantile: DataLoader[Tuple[str, Optional[TimeRange], float], Optional[float]]
# Keyed by an int (Span.span_evaluations loads it with the span's `span_rowid`)
# and resolves to the list of SpanEvaluation objects for that span.
span_evaluations: DataLoader[int, List[SpanEvaluation]]


@dataclass
Expand Down
7 changes: 7 additions & 0 deletions src/phoenix/server/api/dataloaders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .latency_ms_quantile import LatencyMsQuantileDataLoader
from .span_evaluations import SpanEvaluationsDataLoader

# Names exported by `from phoenix.server.api.dataloaders import *`; these are the
# dataloader classes imported above from this package's submodules.
__all__ = [
"LatencyMsQuantileDataLoader",
"SpanEvaluationsDataLoader",
]
5 changes: 3 additions & 2 deletions src/phoenix/server/api/dataloaders/latency_ms_quantile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
)

from ddsketch.ddsketch import DDSketch
from phoenix.db import models
from phoenix.server.api.input_types.TimeRange import TimeRange
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.api.input_types.TimeRange import TimeRange

ProjectName: TypeAlias = str
TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
Segment: TypeAlias = Tuple[ProjectName, TimeInterval]
Expand Down
39 changes: 39 additions & 0 deletions src/phoenix/server/api/dataloaders/span_evaluations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from collections import defaultdict
from typing import (
AsyncContextManager,
Callable,
DefaultDict,
List,
)

from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.api.types.Evaluation import SpanEvaluation

# Dataloader key type; keys are matched against `models.SpanAnnotation.span_rowid`.
Key: TypeAlias = int


class SpanEvaluationsDataLoader(DataLoader[Key, List[SpanEvaluation]]):
    """Batch-loads span evaluations, keyed by span row id.

    Each batch of keys is resolved with a single query against
    ``models.SpanAnnotation``, restricted to annotations whose
    ``annotator_kind`` is ``"LLM"``.
    """

    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
        """Register ``_load_fn`` as the batch function and keep the session factory.

        Args:
            db: Zero-argument callable returning an async context manager that
                yields an ``AsyncSession``.
        """
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[List[SpanEvaluation]]:
        """Return the evaluations for each requested span.

        Args:
            keys: Span row ids to look up.

        Returns:
            One list of ``SpanEvaluation`` per entry of ``keys``, in the same
            order. A key with no matching annotation maps to an empty list.
        """
        grouped: DefaultDict[Key, List[SpanEvaluation]] = defaultdict(list)
        async with self._db() as session:
            wanted = and_(
                models.SpanAnnotation.span_rowid.in_(keys),
                models.SpanAnnotation.annotator_kind == "LLM",
            )
            stmt = select(models.SpanAnnotation).where(wanted)
            annotations = await session.scalars(stmt)
            for annotation in annotations:
                evaluation = SpanEvaluation.from_sql_span_annotation(annotation)
                grouped[annotation.span_rowid].append(evaluation)
        # Indexing the defaultdict yields [] for spans that had no annotations.
        return [grouped[span_rowid] for span_rowid in keys]
14 changes: 10 additions & 4 deletions src/phoenix/server/api/types/Evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import strawberry

import phoenix.trace.v1 as pb
from phoenix.db.models import SpanAnnotation
from phoenix.trace.schemas import SpanID, TraceID


Expand Down Expand Up @@ -46,21 +47,26 @@ def from_pb_evaluation(evaluation: pb.Evaluation) -> "TraceEvaluation":

# GraphQL type (registered with @strawberry.type) for an evaluation of a span;
# extends Evaluation and adds two alternate constructors.
# NOTE(review): this view shows lines from before and after the change together,
# without +/- markers. The `span_id` field and the two `span_id` lines in
# from_pb_evaluation look like the removed side, since from_sql_span_annotation
# constructs SpanEvaluation without a span_id — confirm against the repository.
@strawberry.type
class SpanEvaluation(Evaluation):
span_id: strawberry.Private[SpanID]

# Builds a SpanEvaluation from a protobuf Evaluation message. score, label and
# explanation are read from `evaluation.result` only when the corresponding
# field is set (HasField); otherwise they are None.
@staticmethod
def from_pb_evaluation(evaluation: pb.Evaluation) -> "SpanEvaluation":
result = evaluation.result
score = result.score.value if result.HasField("score") else None
label = result.label.value if result.HasField("label") else None
explanation = result.explanation.value if result.HasField("explanation") else None
span_id = SpanID(evaluation.subject_id.span_id)
return SpanEvaluation(
name=evaluation.name,
score=score,
label=label,
explanation=explanation,
span_id=span_id,
)

# Builds a SpanEvaluation from a SpanAnnotation database row by copying its
# name, score, label and explanation attributes unchanged.
@staticmethod
def from_sql_span_annotation(annotation: "SpanAnnotation") -> "SpanEvaluation":
return SpanEvaluation(
name=annotation.name,
score=annotation.score,
label=annotation.label,
explanation=annotation.explanation,
)


Expand Down
10 changes: 4 additions & 6 deletions src/phoenix/server/api/types/Span.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def from_dict(
@strawberry.type
class Span:
project: strawberry.Private[Project]
span_rowid: strawberry.Private[int]
name: str
status_code: SpanStatusCode
status_message: str
Expand Down Expand Up @@ -146,12 +147,8 @@ class Span:
"an LLM, an evaluation may assess the helpfulness of its response with "
"respect to its input."
) # type: ignore
def span_evaluations(self) -> List[SpanEvaluation]:
span_id = SpanID(str(self.context.span_id))
return [
SpanEvaluation.from_pb_evaluation(evaluation)
for evaluation in self.project.get_evaluations_by_span_id(span_id)
]
async def span_evaluations(self, info: Info[Context, None]) -> List[SpanEvaluation]:
return await info.context.data_loaders.span_evaluations.load(self.span_rowid)

@strawberry.field(
description="Evaluations of the documents associated with the span, e.g. "
Expand Down Expand Up @@ -240,6 +237,7 @@ def to_gql_span(span: models.Span, project: Project) -> Span:
num_documents = len(retrieval_documents) if isinstance(retrieval_documents, Sized) else None
return Span(
project=project,
span_rowid=span.id,
name=span.name,
status_code=SpanStatusCode(span.status),
status_message=span.status_message,
Expand Down
4 changes: 3 additions & 1 deletion src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@
from phoenix.db.engines import create_engine
from phoenix.pointcloud.umap_parameters import UMAPParameters
from phoenix.server.api.context import Context, DataLoaders
from phoenix.server.api.dataloaders.latency_ms_quantile import (
from phoenix.server.api.dataloaders import (
LatencyMsQuantileDataLoader,
SpanEvaluationsDataLoader,
)
from phoenix.server.api.routers.evaluation_handler import EvaluationHandler
from phoenix.server.api.routers.span_handler import SpanHandler
Expand Down Expand Up @@ -149,6 +150,7 @@ async def get_context(
export_path=self.export_path,
data_loaders=DataLoaders(
latency_ms_quantile=LatencyMsQuantileDataLoader(self.db),
span_evaluations=SpanEvaluationsDataLoader(self.db),
),
)

Expand Down

0 comments on commit 5adf750

Please sign in to comment.