feat(persistence): get or delete projects using sql (#2839)
RogerHYang committed Apr 12, 2024
1 parent e4b667d commit 527b9a9
Showing 8 changed files with 150 additions and 108 deletions.
1 change: 0 additions & 1 deletion app/schema.graphql
@@ -486,7 +486,6 @@ type Mutation {
"""
exportClusters(clusters: [ClusterInput!]!, fileName: String): ExportedFile!
deleteProject(id: GlobalID!): Query!
archiveProject(id: GlobalID!): Query!
}

"""A node in the graph with a globally unique ID"""
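With archiveProject gone from the schema, removing a project now goes through the existing deleteProject mutation. Below is a minimal sketch of invoking it from Python over HTTP; the endpoint URL, the httpx dependency, and the minimal selection set are assumptions for illustration, not part of this commit.

# Sketch: call the deleteProject mutation against a running Phoenix server.
# The endpoint URL and the use of httpx are assumptions.
import httpx

DELETE_PROJECT = """
mutation DeleteProject($id: GlobalID!) {
  deleteProject(id: $id) {
    __typename
  }
}
"""

def delete_project(project_global_id: str) -> None:
    response = httpx.post(
        "http://localhost:6006/graphql",  # assumed local Phoenix GraphQL endpoint
        json={"query": DELETE_PROJECT, "variables": {"id": project_global_id}},
    )
    response.raise_for_status()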
19 changes: 1 addition & 18 deletions src/phoenix/core/traces.py
@@ -3,7 +3,7 @@
from queue import SimpleQueue
from threading import RLock, Thread
from types import MethodType
from typing import DefaultDict, Iterator, Optional, Tuple, Union
from typing import DefaultDict, Optional, Tuple, Union

from typing_extensions import assert_never

@@ -38,23 +38,6 @@ def get_project(self, project_name: str) -> Optional["Project"]:
with self._lock:
return self._projects.get(project_name)

def get_projects(self) -> Iterator[Tuple[int, str, "Project"]]:
with self._lock:
for project_id, (project_name, project) in enumerate(self._projects.items()):
if project.is_archived:
continue
yield project_id, project_name, project

def archive_project(self, id: int) -> Optional["Project"]:
if id == 0:
raise ValueError("Cannot archive the default project")
with self._lock:
for project_id, _, project in self.get_projects():
if id == project_id:
project.archive()
return project
return None

def put(
self,
item: Union[Span, pb.Evaluation],
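The in-memory get_projects and archive_project helpers are removed; per the commit title, projects are now fetched and deleted through SQL. A rough sketch of what SQL-backed equivalents look like against the ORM models in src/phoenix/db/models.py (the AsyncSession wiring and commit placement are assumptions, not the actual resolver code):

# Sketch only: SQL-backed project get/delete using the models from this commit.
# Session construction and commit placement are assumptions.
from typing import List

from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from phoenix.db import models

async def get_projects(session: AsyncSession) -> List[models.Project]:
    result = await session.scalars(select(models.Project))
    return list(result)

async def delete_project(session: AsyncSession, project_id: int) -> None:
    # ON DELETE CASCADE (declared in the migration below) lets the database
    # remove the project's traces, spans, and annotations as part of the delete.
    await session.execute(delete(models.Project).where(models.Project.id == project_id))
    await session.commit()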
40 changes: 35 additions & 5 deletions src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py
@@ -49,7 +49,13 @@ def upgrade() -> None:
op.create_table(
"traces",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("project_rowid", sa.Integer, sa.ForeignKey("projects.id"), nullable=False),
sa.Column(
"project_rowid",
sa.Integer,
sa.ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
index=True,
),
# TODO(mikeldking): might not be the right place for this
sa.Column("session_id", sa.String, nullable=True),
sa.Column("trace_id", sa.String, nullable=False, unique=True),
@@ -61,7 +67,13 @@ def upgrade() -> None:
op.create_table(
"spans",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("trace_rowid", sa.Integer, sa.ForeignKey("traces.id"), nullable=False),
sa.Column(
"trace_rowid",
sa.Integer,
sa.ForeignKey("traces.id", ondelete="CASCADE"),
nullable=False,
index=True,
),
sa.Column("span_id", sa.String, nullable=False, unique=True),
sa.Column("parent_span_id", sa.String, nullable=True, index=True),
sa.Column("name", sa.String, nullable=False),
@@ -89,7 +101,13 @@ def upgrade() -> None:
op.create_table(
"span_annotations",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("span_rowid", sa.Integer, sa.ForeignKey("spans.id"), nullable=False),
sa.Column(
"span_rowid",
sa.Integer,
sa.ForeignKey("spans.id", ondelete="CASCADE"),
nullable=False,
index=True,
),
sa.Column("name", sa.String, nullable=False),
sa.Column("label", sa.String, nullable=True),
sa.Column("score", sa.Float, nullable=True),
@@ -128,7 +146,13 @@ def upgrade() -> None:
op.create_table(
"trace_annotations",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("trace_rowid", sa.Integer, sa.ForeignKey("traces.id"), nullable=False),
sa.Column(
"trace_rowid",
sa.Integer,
sa.ForeignKey("traces.id", ondelete="CASCADE"),
nullable=False,
index=True,
),
sa.Column("name", sa.String, nullable=False),
sa.Column("label", sa.String, nullable=True),
sa.Column("score", sa.Float, nullable=True),
@@ -167,7 +191,13 @@ def upgrade() -> None:
op.create_table(
"document_annotations",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("span_rowid", sa.Integer, sa.ForeignKey("spans.id"), nullable=False),
sa.Column(
"span_rowid",
sa.Integer,
sa.ForeignKey("spans.id", ondelete="CASCADE"),
nullable=False,
index=True,
),
sa.Column("document_index", sa.Integer, nullable=False),
sa.Column("name", sa.String, nullable=False),
sa.Column("label", sa.String, nullable=True),
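Every child foreign key in the migration now carries ondelete="CASCADE" plus an index, so a single DELETE on projects propagates to traces, spans, and annotations, and the cascaded lookups stay indexed. Note that SQLite only enforces these cascades when foreign-key support is enabled on the connection; the sketch below shows one way to do that (the database URL and the synchronous engine are assumptions for illustration):

# Sketch: turn on SQLite foreign-key enforcement so ON DELETE CASCADE fires.
# The database URL is an assumption.
from sqlalchemy import create_engine, event, text

engine = create_engine("sqlite:///phoenix.db")

@event.listens_for(engine, "connect")
def _enable_sqlite_foreign_keys(dbapi_connection, connection_record):
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA foreign_keys=ON")
    cursor.close()

with engine.begin() as conn:
    # Deleting a project cascades through traces -> spans -> annotations.
    conn.execute(text("DELETE FROM projects WHERE id = :id"), {"id": 2})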
30 changes: 24 additions & 6 deletions src/phoenix/db/models.py
@@ -93,10 +93,12 @@ class Project(Base):
UtcTimeStamp, server_default=func.now(), onupdate=func.now()
)

traces: WriteOnlyMapped["Trace"] = relationship(
traces: WriteOnlyMapped[List["Trace"]] = relationship(
"Trace",
back_populates="project",
cascade="all, delete-orphan",
passive_deletes=True,
uselist=True,
)
__table_args__ = (
UniqueConstraint(
@@ -110,7 +112,10 @@ class Project(Base):
class Trace(Base):
__tablename__ = "traces"
id: Mapped[int] = mapped_column(primary_key=True)
project_rowid: Mapped[int] = mapped_column(ForeignKey("projects.id"))
project_rowid: Mapped[int] = mapped_column(
ForeignKey("projects.id", ondelete="CASCADE"),
index=True,
)
session_id: Mapped[Optional[str]]
trace_id: Mapped[str]
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
@@ -125,6 +130,7 @@ class Trace(Base):
"Span",
back_populates="trace",
cascade="all, delete-orphan",
uselist=True,
)
__table_args__ = (
UniqueConstraint(
@@ -138,7 +144,10 @@ class Trace(Base):
class Span(Base):
__tablename__ = "spans"
id: Mapped[int] = mapped_column(primary_key=True)
trace_rowid: Mapped[int] = mapped_column(ForeignKey("traces.id"))
trace_rowid: Mapped[int] = mapped_column(
ForeignKey("traces.id", ondelete="CASCADE"),
index=True,
)
span_id: Mapped[str]
parent_span_id: Mapped[Optional[str]] = mapped_column(index=True)
name: Mapped[str]
@@ -183,7 +192,10 @@ async def init_models(engine: AsyncEngine) -> None:
class SpanAnnotation(Base):
__tablename__ = "span_annotations"
id: Mapped[int] = mapped_column(primary_key=True)
span_rowid: Mapped[int] = mapped_column(ForeignKey("spans.id"))
span_rowid: Mapped[int] = mapped_column(
ForeignKey("spans.id", ondelete="CASCADE"),
index=True,
)
name: Mapped[str]
label: Mapped[Optional[str]]
score: Mapped[Optional[float]]
@@ -209,7 +221,10 @@ class SpanAnnotation(Base):
class TraceAnnotation(Base):
__tablename__ = "trace_annotations"
id: Mapped[int] = mapped_column(primary_key=True)
trace_rowid: Mapped[int] = mapped_column(ForeignKey("traces.id"))
trace_rowid: Mapped[int] = mapped_column(
ForeignKey("traces.id", ondelete="CASCADE"),
index=True,
)
name: Mapped[str]
label: Mapped[Optional[str]]
score: Mapped[Optional[float]]
@@ -235,7 +250,10 @@ class TraceAnnotation(Base):
class DocumentAnnotation(Base):
__tablename__ = "document_annotations"
id: Mapped[int] = mapped_column(primary_key=True)
span_rowid: Mapped[int] = mapped_column(ForeignKey("spans.id"))
span_rowid: Mapped[int] = mapped_column(
ForeignKey("spans.id", ondelete="CASCADE"),
index=True,
)
document_index: Mapped[int]
name: Mapped[str]
label: Mapped[Optional[str]]
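The ORM relationships gain passive_deletes=True, which tells SQLAlchemy not to load child rows and delete them one by one when a parent is deleted; it defers to the database-level ON DELETE CASCADE instead. A minimal sketch of the effect, assuming an AsyncSession is already set up:

# Sketch: ORM delete that leaves child cleanup to the database cascade.
# Session construction is an assumption.
from sqlalchemy.ext.asyncio import AsyncSession

from phoenix.db import models

async def delete_project_orm(session: AsyncSession, project_id: int) -> None:
    project = await session.get(models.Project, project_id)
    if project is None:
        return
    # With passive_deletes=True the ORM does not fetch the project's traces;
    # the CASCADE on traces.project_rowid (and on spans/annotations below it)
    # removes them inside the database.
    await session.delete(project)
    await session.commit()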
2 changes: 1 addition & 1 deletion src/phoenix/server/api/context.py
@@ -16,7 +16,7 @@

@dataclass
class DataLoaders:
latency_ms_quantile: DataLoader[Tuple[str, Optional[TimeRange], float], Optional[float]]
latency_ms_quantile: DataLoader[Tuple[int, Optional[TimeRange], float], Optional[float]]
span_evaluations: DataLoader[int, List[SpanEvaluation]]
document_evaluations: DataLoader[int, List[DocumentEvaluation]]

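The latency_ms_quantile loader key switches from the project name (str) to the project row id (int), matching the SQL-side filter on models.Project.id. A short sketch of a load call with the new key shape; how a resolver reaches the DataLoaders instance is an assumption:

# Sketch: loading a latency quantile with the (int, Optional[TimeRange], float) key.
# How the caller obtains the DataLoaders instance is an assumption.
from typing import Optional

from phoenix.server.api.context import DataLoaders

async def latency_ms_p50(project_rowid: int, loaders: DataLoaders) -> Optional[float]:
    # key = (project row id, optional time range, quantile probability)
    return await loaders.latency_ms_quantile.load((project_rowid, None, 0.50))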
16 changes: 8 additions & 8 deletions src/phoenix/server/api/dataloaders/latency_ms_quantile.py
@@ -19,11 +19,11 @@
from phoenix.db import models
from phoenix.server.api.input_types.TimeRange import TimeRange

ProjectName: TypeAlias = str
ProjectId: TypeAlias = int
TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
Segment: TypeAlias = Tuple[ProjectName, TimeInterval]
Segment: TypeAlias = Tuple[ProjectId, TimeInterval]
Probability: TypeAlias = float
Key: TypeAlias = Tuple[ProjectName, Optional[TimeRange], Probability]
Key: TypeAlias = Tuple[ProjectId, Optional[TimeRange], Probability]
ResultPosition: TypeAlias = int
QuantileValue: TypeAlias = float
OrmExpression: TypeAlias = Any
@@ -69,21 +69,21 @@ async def _load_fn(self, keys: List[Key]) -> List[Optional[QuantileValue]]:


def _get_filter_condition(segment: Segment) -> OrmExpression:
name, (start_time, stop_time) = segment
id_, (start_time, stop_time) = segment
if start_time and stop_time:
return and_(
models.Project.name == name,
models.Project.id == id_,
start_time <= models.Trace.start_time,
models.Trace.start_time < stop_time,
)
if start_time:
return and_(
models.Project.name == name,
models.Project.id == id_,
start_time <= models.Trace.start_time,
)
if stop_time:
return and_(
models.Project.name == name,
models.Project.id == id_,
models.Trace.start_time < stop_time,
)
return models.Project.name == name
return models.Project.id == id_

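_get_filter_condition now matches on models.Project.id instead of models.Project.name, so each segment is keyed by the integer row id. A sketch of plugging the condition into a query over traces; importing the private helper and the example time window are purely for illustration:

# Sketch: using the id-keyed segment filter in a query over traces.
from datetime import datetime, timezone

from sqlalchemy import select

from phoenix.db import models
from phoenix.server.api.dataloaders.latency_ms_quantile import _get_filter_condition

segment = (1, (datetime(2024, 4, 1, tzinfo=timezone.utc), None))  # (project id, (start, stop))
stmt = (
    select(models.Trace.id)
    .join(models.Project, models.Project.id == models.Trace.project_rowid)
    .where(_get_filter_condition(segment))
)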