feat(persistence): bulk inserter for spans (#2808)
RogerHYang committed Apr 8, 2024
1 parent 399a9f6 commit 9ce841e
Showing 9 changed files with 293 additions and 166 deletions.
6 changes: 3 additions & 3 deletions src/phoenix/db/alembic.ini
@@ -94,17 +94,17 @@ keys = console
 keys = generic

 [logger_root]
-level = DEBUG
+level = WARN
 handlers = console
 qualname =

 [logger_sqlalchemy]
-level = DEBUG
+level = WARN
 handlers =
 qualname = sqlalchemy.engine

 [logger_alembic]
-level = DEBUG
+level = WARN
 handlers =
 qualname = alembic

187 changes: 187 additions & 0 deletions src/phoenix/db/bulk_inserter.py
@@ -0,0 +1,187 @@
import asyncio
import logging
from itertools import islice
from time import time
from typing import Any, AsyncContextManager, Callable, Iterable, List, Optional, Tuple, cast

from openinference.semconv.trace import SpanAttributes
from sqlalchemy import func, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from phoenix.db import models
from phoenix.trace.schemas import Span, SpanStatusCode

logger = logging.getLogger(__name__)


class BulkInserter:
    def __init__(
        self,
        db: Callable[[], AsyncContextManager[AsyncSession]],
        initial_batch_of_spans: Optional[Iterable[Tuple[Span, str]]] = None,
        run_interval_in_seconds: float = 0.5,
        max_num_per_transaction: int = 100,
    ) -> None:
        """
        :param db: A function to initiate a new database session.
        :param initial_batch_of_spans: Initial batch of spans to insert.
        :param run_interval_in_seconds: The time interval between the starts of each
            bulk insert. If there's nothing to insert, the inserter goes back to sleep.
        :param max_num_per_transaction: The maximum number of items to insert in a single
            transaction. Multiple transactions will be used if there are more items in the batch.
        """
        self._db = db
        self._running = False
        self._run_interval_seconds = run_interval_in_seconds
        self._max_num_per_transaction = max_num_per_transaction
        self._spans: List[Tuple[Span, str]] = (
            [] if initial_batch_of_spans is None else list(initial_batch_of_spans)
        )
        self._task: Optional[asyncio.Task[None]] = None

    async def __aenter__(self) -> Callable[[Span, str], None]:
        self._running = True
        self._task = asyncio.create_task(self._bulk_insert())
        return self._queue_span

    async def __aexit__(self, *args: Any) -> None:
        self._running = False

    def _queue_span(self, span: Span, project_name: str) -> None:
        self._spans.append((span, project_name))

    async def _bulk_insert(self) -> None:
        next_run_at = time() + self._run_interval_seconds
        while self._spans or self._running:
            await asyncio.sleep(next_run_at - time())
            next_run_at = time() + self._run_interval_seconds
            if self._spans:
                await self._insert_spans()

    async def _insert_spans(self) -> None:
        spans = self._spans
        self._spans = []
        for i in range(0, len(spans), self._max_num_per_transaction):
            try:
                async with self._db() as session:
                    for span, project_name in islice(spans, i, i + self._max_num_per_transaction):
                        try:
                            async with session.begin_nested():
                                await _insert_span(session, span, project_name)
                        except Exception:
                            logger.exception(
                                f"Failed to insert span with span_id={span.context.span_id}"
                            )
            except Exception:
                logger.exception("Failed to insert spans")

async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> None:
    if await session.scalar(select(1).where(models.Span.span_id == span.context.span_id)):
        # Span already exists
        return
    if not (
        project_rowid := await session.scalar(
            select(models.Project.id).where(models.Project.name == project_name)
        )
    ):
        project_rowid = await session.scalar(
            insert(models.Project).values(name=project_name).returning(models.Project.id)
        )
    if trace := await session.scalar(
        select(models.Trace).where(models.Trace.trace_id == span.context.trace_id)
    ):
        trace_rowid = trace.id
        # TODO(persistence): Figure out how to reliably retrieve timezone-aware
        # datetime from the (sqlite) database, because all datetime in our
        # programs should be timezone-aware.
        if span.start_time < trace.start_time or trace.end_time < span.end_time:
            trace.start_time = min(trace.start_time, span.start_time)
            trace.end_time = max(trace.end_time, span.end_time)
            await session.execute(
                update(models.Trace)
                .where(models.Trace.id == trace_rowid)
                .values(
                    start_time=min(trace.start_time, span.start_time),
                    end_time=max(trace.end_time, span.end_time),
                )
            )
    else:
        trace_rowid = cast(
            int,
            await session.scalar(
                insert(models.Trace)
                .values(
                    project_rowid=project_rowid,
                    trace_id=span.context.trace_id,
                    start_time=span.start_time,
                    end_time=span.end_time,
                )
                .returning(models.Trace.id)
            ),
        )
    cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR)
    cumulative_llm_token_count_prompt = cast(
        int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0)
    )
    cumulative_llm_token_count_completion = cast(
        int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0)
    )
    if accumulation := (
        await session.execute(
            select(
                func.sum(models.Span.cumulative_error_count),
                func.sum(models.Span.cumulative_llm_token_count_prompt),
                func.sum(models.Span.cumulative_llm_token_count_completion),
            ).where(models.Span.parent_span_id == span.context.span_id)
        )
    ).first():
        cumulative_error_count += cast(int, accumulation[0] or 0)
        cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0)
        cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0)
    latency_ms = (span.end_time - span.start_time).total_seconds() * 1000
    session.add(
        models.Span(
            span_id=span.context.span_id,
            trace_rowid=trace_rowid,
            parent_span_id=span.parent_id,
            kind=span.span_kind.value,
            name=span.name,
            start_time=span.start_time,
            end_time=span.end_time,
            attributes=span.attributes,
            events=span.events,
            status=span.status_code.value,
            status_message=span.status_message,
            latency_ms=latency_ms,
            cumulative_error_count=cumulative_error_count,
            cumulative_llm_token_count_prompt=cumulative_llm_token_count_prompt,
            cumulative_llm_token_count_completion=cumulative_llm_token_count_completion,
        )
    )
    # Propagate cumulative values to ancestors. This is usually a no-op, since
    # the parent usually arrives after the child. But in the event that a
    # child arrives after its parent, we need to make sure that all the
    # ancestors' cumulative values are updated.
    # The recursive CTE below starts from the new span's parent and walks up
    # the parent_span_id chain to collect every ancestor's rowid.
    ancestors = (
        select(models.Span.id, models.Span.parent_span_id)
        .where(models.Span.span_id == span.parent_id)
        .cte(recursive=True)
    )
    child = ancestors.alias()
    ancestors = ancestors.union_all(
        select(models.Span.id, models.Span.parent_span_id).join(
            child, models.Span.span_id == child.c.parent_span_id
        )
    )
    await session.execute(
        update(models.Span)
        .where(models.Span.id.in_(select(ancestors.c.id)))
        .values(
            cumulative_error_count=models.Span.cumulative_error_count + cumulative_error_count,
            cumulative_llm_token_count_prompt=models.Span.cumulative_llm_token_count_prompt
            + cumulative_llm_token_count_prompt,
            cumulative_llm_token_count_completion=models.Span.cumulative_llm_token_count_completion
            + cumulative_llm_token_count_completion,
        )
    )
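Not part of the commit, but for orientation, here is a minimal sketch of how BulkInserter is meant to be driven. The engine URL, session factory, and the empty span source are hypothetical stand-ins; async_sessionmaker and the aiosqlite driver are assumptions based on SQLAlchemy 2.x, not something this diff shows.

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from phoenix.db.bulk_inserter import BulkInserter


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    session_factory = async_sessionmaker(engine, expire_on_commit=False)

    @asynccontextmanager
    async def db() -> AsyncIterator[AsyncSession]:
        # Each bulk-insert transaction gets a fresh session that commits on exit.
        async with session_factory.begin() as session:
            yield session

    async with BulkInserter(db, run_interval_in_seconds=0.5) as queue_span:
        # queue_span is synchronous and cheap (it only appends to a list),
        # so producers can call it on a hot path without awaiting any I/O.
        for span, project_name in []:  # hypothetical span source goes here
            queue_span(span, project_name)
        await asyncio.sleep(1)  # let the background task wake at least once


asyncio.run(main())

Note that __aexit__ only flips the running flag: the background loop keeps draining as long as spans remain queued, but the task itself is not awaited on exit.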
7 changes: 6 additions & 1 deletion src/phoenix/db/engines.py
@@ -51,7 +51,12 @@ def aio_sqlite_engine(
     engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps)
     event.listen(engine.sync_engine, "connect", set_sqlite_pragma)
     if str(database) == ":memory:":
-        asyncio.run(init_models(engine))
+        try:
+            asyncio.get_running_loop()
+        except RuntimeError:
+            asyncio.run(init_models(engine))
+        else:
+            asyncio.create_task(init_models(engine))
     else:
         migrate(engine.url)
     return engine
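The guard above exists because asyncio.run() raises RuntimeError when called from inside a running event loop (as happens in a notebook), while asyncio.create_task() requires one. A standalone sketch of the same dispatch pattern; the helper name is hypothetical:

import asyncio
from typing import Any, Coroutine


def run_or_schedule(coro: Coroutine[Any, Any, None]) -> None:
    """Run a coroutine to completion when no event loop is running,
    otherwise schedule it on the current loop. Caveat: in the scheduled
    case the work may not be finished by the time this returns."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop, e.g. ordinary synchronous startup code.
        asyncio.run(coro)
    else:
        # Already inside a loop (e.g. Jupyter): asyncio.run() would raise,
        # so fire and forget on the current loop instead.
        asyncio.create_task(coro)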
17 changes: 11 additions & 6 deletions src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py
@@ -25,10 +25,15 @@ def upgrade() -> None:
         # TODO does the uniqueness constraint need to be named
         sa.Column("name", sa.String, nullable=False, unique=True),
         sa.Column("description", sa.String, nullable=True),
-        sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
+        sa.Column(
+            "created_at",
+            sa.DateTime(timezone=True),
+            nullable=False,
+            server_default=sa.func.now(),
+        ),
         sa.Column(
             "updated_at",
-            sa.DateTime(),
+            sa.DateTime(timezone=True),
             nullable=False,
             server_default=sa.func.now(),
             onupdate=sa.func.now(),
@@ -41,8 +46,8 @@ def upgrade() -> None:
         # TODO(mikeldking): might not be the right place for this
         sa.Column("session_id", sa.String, nullable=True),
         sa.Column("trace_id", sa.String, nullable=False, unique=True),
-        sa.Column("start_time", sa.DateTime(), nullable=False, index=True),
-        sa.Column("end_time", sa.DateTime(), nullable=False),
+        sa.Column("start_time", sa.DateTime(timezone=True), nullable=False, index=True),
+        sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
     )

     op.create_table(
@@ -53,8 +58,8 @@ def upgrade() -> None:
         sa.Column("parent_span_id", sa.String, nullable=True, index=True),
         sa.Column("name", sa.String, nullable=False),
         sa.Column("kind", sa.String, nullable=False),
-        sa.Column("start_time", sa.DateTime(), nullable=False),
-        sa.Column("end_time", sa.DateTime(), nullable=False),
+        sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
         sa.Column("attributes", sa.JSON, nullable=False),
         sa.Column("events", sa.JSON, nullable=False),
         sa.Column(
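A note on why the timezone=True columns above are paired with the UtcTimeStamp decorator in models.py below: SQLite has no native datetime type, and to our understanding SQLAlchemy's SQLite dialect stores datetimes as text without an offset, so values come back timezone-naive regardless of timezone=True. A throwaway sketch of that behavior, under that assumption:

from datetime import datetime, timezone

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
metadata = sa.MetaData()
table = sa.Table("t", metadata, sa.Column("ts", sa.DateTime(timezone=True)))
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(sa.insert(table).values(ts=datetime.now(timezone.utc)))
    (value,) = conn.execute(sa.select(table.c.ts)).one()
    print(value.tzinfo)  # expected: None, the offset is not round-tripped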
52 changes: 45 additions & 7 deletions src/phoenix/db/models.py
@@ -1,12 +1,14 @@
-from datetime import datetime
+from datetime import datetime, timezone
 from typing import Any, Dict, List, Optional

 from sqlalchemy import (
     JSON,
     CheckConstraint,
     DateTime,
+    Dialect,
     ForeignKey,
     MetaData,
+    TypeDecorator,
     UniqueConstraint,
     func,
     insert,
@@ -21,6 +23,42 @@
 )


+class UtcTimeStamp(TypeDecorator[datetime]):
+    """TODO(persistence): Figure out how to reliably store and retrieve
+    timezone-aware datetime objects from the (sqlite) database. Below is a
+    workaround to guarantee that the timestamps we fetch from the database
+    are always timezone-aware, in order to prevent comparisons of
+    timezone-naive datetime with timezone-aware datetime, because objects
+    in the rest of our programs are always timezone-aware.
+    """
+
+    cache_ok = True
+    impl = DateTime
+    _LOCAL_TIMEZONE = datetime.now(timezone.utc).astimezone().tzinfo
+
+    def process_bind_param(
+        self,
+        value: Optional[datetime],
+        dialect: Dialect,
+    ) -> Optional[datetime]:
+        if not value:
+            return None
+        if value.tzinfo is None:
+            value = value.astimezone(self._LOCAL_TIMEZONE)
+        return value.astimezone(timezone.utc)
+
+    def process_result_value(
+        self,
+        value: Optional[Any],
+        dialect: Dialect,
+    ) -> Optional[datetime]:
+        if not isinstance(value, datetime):
+            return None
+        if value.tzinfo is None:
+            return value.replace(tzinfo=timezone.utc)
+        return value.astimezone(timezone.utc)
+
+
 class Base(DeclarativeBase):
     # Enforce best practices for naming constraints
     # https://alembic.sqlalchemy.org/en/latest/naming.html#integration-of-naming-conventions-into-operations-autogenerate
@@ -44,9 +82,9 @@ class Project(Base):
     id: Mapped[int] = mapped_column(primary_key=True)
     name: Mapped[str]
     description: Mapped[Optional[str]]
-    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
+    updated_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
     created_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
+        UtcTimeStamp, server_default=func.now(), onupdate=func.now()
     )

     traces: WriteOnlyMapped["Trace"] = relationship(
@@ -69,8 +107,8 @@ class Trace(Base):
     project_rowid: Mapped[int] = mapped_column(ForeignKey("projects.id"))
     session_id: Mapped[Optional[str]]
     trace_id: Mapped[str]
-    start_time: Mapped[datetime] = mapped_column(DateTime(), index=True)
-    end_time: Mapped[datetime] = mapped_column(DateTime())
+    start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
+    end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)

     project: Mapped["Project"] = relationship(
         "Project",
@@ -98,8 +136,8 @@ class Span(Base):
     parent_span_id: Mapped[Optional[str]] = mapped_column(index=True)
     name: Mapped[str]
     kind: Mapped[str]
-    start_time: Mapped[datetime] = mapped_column(DateTime())
-    end_time: Mapped[datetime] = mapped_column(DateTime())
+    start_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
+    end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
     attributes: Mapped[Dict[str, Any]]
     events: Mapped[List[Dict[str, Any]]]
     status: Mapped[str] = mapped_column(
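To make the decorator's behavior concrete, a small illustration (not in the diff) of what the two hooks do with aware and naive values. Passing dialect=None is an illustration-only shortcut; SQLAlchemy supplies the real dialect at runtime.

from datetime import datetime, timedelta, timezone

from phoenix.db.models import UtcTimeStamp

ts = UtcTimeStamp()

# An aware value is normalized to UTC before being bound to the database.
aware = datetime(2024, 4, 8, 9, 30, tzinfo=timezone(timedelta(hours=-7)))
print(ts.process_bind_param(aware, dialect=None))  # 2024-04-08 16:30:00+00:00

# A naive value is treated as local time, then converted to UTC.
print(ts.process_bind_param(datetime(2024, 4, 8, 9, 30), dialect=None))

# Naive values coming back from SQLite get stamped as UTC on the way out.
print(ts.process_result_value(datetime(2024, 4, 8, 16, 30), dialect=None))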
