Skip to content

Commit

Permalink
feat(persistence): dataloader for span descendants (#2980)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang committed Apr 25, 2024
1 parent 481242c commit d8e10d4
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 22 deletions.
2 changes: 2 additions & 0 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from strawberry.dataloader import DataLoader

from phoenix.core.model_schema import Model
from phoenix.db import models
from phoenix.server.api.input_types.TimeRange import TimeRange
from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation, TraceEvaluation
Expand All @@ -24,6 +25,7 @@ class DataLoaders:
document_retrieval_metrics: DataLoader[
Tuple[int, Optional[str], int], List[DocumentRetrievalMetrics]
]
span_descendants: DataLoader[str, List[models.Span]]


@dataclass
Expand Down
64 changes: 64 additions & 0 deletions src/phoenix/server/api/dataloaders/span_descendants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from itertools import groupby
from random import randint
from typing import (
AsyncContextManager,
Callable,
Dict,
List,
)

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models

SpanId: TypeAlias = str
Key: TypeAlias = SpanId


class SpanDescendantsDataLoader(DataLoader[Key, List[models.Span]]):
def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
super().__init__(load_fn=self._load_fn)
self._db = db

async def _load_fn(self, keys: List[Key]) -> List[List[models.Span]]:
root_ids = set(keys)
root_id_label = f"root_id_{randint(0, 10**6):06}"
descendant_ids = (
select(
models.Span.id,
models.Span.span_id,
models.Span.parent_id.label(root_id_label),
)
.where(models.Span.parent_id.in_(root_ids))
.cte(recursive=True)
)
parent_ids = descendant_ids.alias()
descendant_ids = descendant_ids.union_all(
select(
models.Span.id,
models.Span.span_id,
parent_ids.c[root_id_label],
).join(
parent_ids,
models.Span.parent_id == parent_ids.c.span_id,
)
)
stmt = (
select(descendant_ids.c[root_id_label], models.Span)
.join(descendant_ids, models.Span.id == descendant_ids.c.id)
.join(models.Trace)
.options(contains_eager(models.Span.trace))
.order_by(descendant_ids.c[root_id_label])
)
async with self._db() as session:
data = await session.execute(stmt)
if not data:
return [[] for _ in keys]
results: Dict[SpanId, List[models.Span]] = {key: [] for key in keys}
for root_id, group in groupby(data, key=lambda d: d[0]):
results[root_id].extend(span for _, span in group)
return [results[key].copy() for key in keys]
24 changes: 2 additions & 22 deletions src/phoenix/server/api/types/Span.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import numpy as np
import strawberry
from openinference.semconv.trace import EmbeddingAttributes, SpanAttributes
from sqlalchemy import select
from sqlalchemy.orm import contains_eager
from strawberry import ID, UNSET
from strawberry.types import Info

Expand Down Expand Up @@ -180,26 +178,8 @@ async def descendants(
self,
info: Info[Context, None],
) -> List["Span"]:
# TODO(persistence): add dataloader (to avoid N+1 queries) or change how this is fetched
async with info.context.db() as session:
descendant_ids = (
select(models.Span.id, models.Span.span_id)
.filter(models.Span.parent_id == str(self.context.span_id))
.cte(recursive=True)
)
parent_ids = descendant_ids.alias()
descendant_ids = descendant_ids.union_all(
select(models.Span.id, models.Span.span_id).join(
parent_ids,
models.Span.parent_id == parent_ids.c.span_id,
)
)
spans = await session.scalars(
select(models.Span)
.join(descendant_ids, models.Span.id == descendant_ids.c.id)
.join(models.Trace)
.options(contains_eager(models.Span.trace))
)
span_id = str(self.context.span_id)
spans = await info.context.data_loaders.span_descendants.load(span_id)
return [to_gql_span(span) for span in spans]


Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
SpanEvaluationsDataLoader,
TraceEvaluationsDataLoader,
)
from phoenix.server.api.dataloaders.span_descendants import SpanDescendantsDataLoader
from phoenix.server.api.routers.v1 import V1_ROUTES
from phoenix.server.api.schema import schema
from phoenix.trace.schemas import Span
Expand Down Expand Up @@ -159,6 +160,7 @@ async def get_context(
document_evaluations=DocumentEvaluationsDataLoader(self.db),
trace_evaluations=TraceEvaluationsDataLoader(self.db),
document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(self.db),
span_descendants=SpanDescendantsDataLoader(self.db),
),
)

Expand Down

0 comments on commit d8e10d4

Please sign in to comment.