Skip to content

Commit

Permalink
fix: join on trace_id in get_qa_with_reference (#3248)
Browse files: browse the repository at this point in the history
  • Loading branch information
RogerHYang committed May 20, 2024
1 parent b8f83af commit a88d4ff
Showing 1 changed file with 32 additions and 20 deletions.
52 changes: 32 additions & 20 deletions src/phoenix/trace/dsl/helpers.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
import warnings
from datetime import datetime
from typing import List, Optional, Protocol, Union, cast
from typing import List, Optional, Protocol, Tuple, Union, cast

import pandas as pd
from openinference.semconv.trace import DocumentAttributes, SpanAttributes
Expand Down Expand Up @@ -75,7 +75,7 @@ def get_qa_with_reference(
project_name: Optional[str] = None,
# Deprecated
stop_time: Optional[datetime] = None,
) -> pd.DataFrame:
) -> Optional[pd.DataFrame]:
project_name = project_name or get_env_project_name()
if stop_time:
# Deprecated. Raise a warning
Expand All @@ -84,23 +84,35 @@ def get_qa_with_reference(
DeprecationWarning,
)
end_time = end_time or stop_time
return pd.concat(
cast(
List[pd.DataFrame],
obj.query_spans(
SpanQuery().select(**IO).where(IS_ROOT),
SpanQuery()
.where(IS_RETRIEVER)
.select(span_id="parent_id")
.concat(
RETRIEVAL_DOCUMENTS,
reference=DOCUMENT_CONTENT,
),
start_time=start_time,
end_time=end_time,
project_name=project_name,
),
separator = "\n\n"
qa_query = SpanQuery().select("span_id", **IO).where(IS_ROOT).with_index("trace_id")
docs_query = (
SpanQuery()
.where(IS_RETRIEVER)
.concat(RETRIEVAL_DOCUMENTS, reference=DOCUMENT_CONTENT)
.with_concat_separator(separator=separator)
.with_index("trace_id")
)
df_qa, df_docs = cast(
Tuple[pd.DataFrame, pd.DataFrame],
obj.query_spans(
qa_query,
docs_query,
start_time=start_time,
end_time=end_time,
project_name=project_name,
),
axis=1,
join="inner",
)
if df_qa is None or df_qa.empty:
print("No spans found.")
return None
if df_docs is None or df_docs.empty:
print("No retrieval documents found.")
return None
# Consolidate duplicate rows via concatenation. Duplicates arise when a single trace
# contains multiple retriever spans; their references are joined with the separator
# (in no particular order) into a single row per trace.
ref = df_docs.groupby("context.trace_id")["reference"].apply(lambda x: separator.join(x))
df_ref = pd.DataFrame({"reference": ref})
df_qa_ref = pd.concat([df_qa, df_ref], axis=1, join="inner").set_index("context.span_id")
return df_qa_ref

0 comments on commit a88d4ff

Please sign in to comment.