Skip to content

Commit

Permalink
fix: eliminate interference on global tracer provider (#2998)
Browse files: browse the repository at this point in the history
  • Loading branch information
RogerHYang committed Apr 26, 2024
1 parent 310c9b0 commit 5d7b843
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 33 deletions.
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"starlette",
"uvicorn",
"psutil",
"strawberry-graphql==0.208.2",
"strawberry-graphql==0.227.2", # need to pin version because we're monkey-patching
"pyarrow",
"typing-extensions>=4.5; python_version<'3.12'",
# A minimum version of typing-extensions==4.6.0 is needed to avoid this issue on Python 3.12: https://github.com/Azure/azure-sdk-for-python/issues/33442#issuecomment-1847886784
Expand Down Expand Up @@ -73,7 +73,7 @@ dev = [
"pytest-postgresql",
"asyncpg",
"psycopg[binary]",
"strawberry-graphql[debug-server]==0.208.2",
"strawberry-graphql[debug-server,opentelemetry]==0.227.2", # need to pin version because we're monkey-patching
"pre-commit",
"arize[AutoEmbeddings, LLM_Evaluation]",
"llama-index>=0.10.3",
Expand Down Expand Up @@ -104,7 +104,7 @@ container = [
"opentelemetry-exporter-otlp",
"opentelemetry-instrumentation-starlette",
"opentelemetry-instrumentation-sqlalchemy",
"strawberry-graphql[opentelemetry]",
"strawberry-graphql[opentelemetry]==0.227.2", # need to pin version because we're monkey-patching
]

[project.urls]
Expand Down Expand Up @@ -175,7 +175,7 @@ dependencies = [
"opentelemetry-exporter-otlp",
"opentelemetry-instrumentation-starlette",
"opentelemetry-instrumentation-sqlalchemy",
"strawberry-graphql[opentelemetry]",
"strawberry-graphql[opentelemetry]==0.227.2", # need to pin version because we're monkey-patching
]

[tool.hatch.envs.style]
Expand Down Expand Up @@ -274,7 +274,7 @@ check = [

[tool.hatch.envs.gql]
dependencies = [
"strawberry-graphql[cli]==0.208.2",
"strawberry-graphql[cli]==0.227.2", # need to pin version because we're monkey-patching
"requests",
]

Expand Down
2 changes: 1 addition & 1 deletion src/phoenix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)


def is_server_instrumentation_enabled() -> bool:
def server_instrumentation_is_enabled() -> bool:
    """Report whether server-side OTLP trace export has been configured.

    Returns:
        True when at least one of the HTTP or gRPC trace-collector endpoint
        environment variables is set to a non-empty value, False otherwise.
    """
    # The HTTP endpoint is consulted first, then the gRPC one.
    endpoint_variables = (
        ENV_PHOENIX_SERVER_INSTRUMENTATION_OTLP_TRACE_COLLECTOR_HTTP_ENDPOINT,
        ENV_PHOENIX_SERVER_INSTRUMENTATION_OTLP_TRACE_COLLECTOR_GRPC_ENDPOINT,
    )
    return any(os.getenv(variable) for variable in endpoint_variables)
Expand Down
13 changes: 5 additions & 8 deletions src/phoenix/server/api/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from strawberry.types import Info
from typing_extensions import Annotated

from phoenix.config import DEFAULT_PROJECT_NAME, is_server_instrumentation_enabled
from phoenix.config import DEFAULT_PROJECT_NAME
from phoenix.db import models
from phoenix.pointcloud.clustering import Hdbscan
from phoenix.server.api.context import Context
Expand Down Expand Up @@ -272,14 +272,11 @@ async def clear_project(self, info: Info[Context, None], id: GlobalID) -> Query:
return Query()


_extensions = []
if is_server_instrumentation_enabled():
from strawberry.extensions.tracing import OpenTelemetryExtension

_extensions.append(OpenTelemetryExtension)

# This is the schema for generating `schema.graphql`.
# See https://strawberry.rocks/docs/guides/schema-export
# It should be kept in sync with the server's runtime-initialized
# instance. To do so, search for the usage of `strawberry.Schema(...)`.
schema = strawberry.Schema(
query=Query,
mutation=Mutation,
extensions=_extensions,
)
49 changes: 40 additions & 9 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
AsyncIterator,
Expand All @@ -13,8 +14,10 @@
Optional,
Tuple,
Union,
cast,
)

import strawberry
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
Expand Down Expand Up @@ -42,7 +45,7 @@
from phoenix.config import (
DEFAULT_PROJECT_NAME,
SERVER_DIR,
is_server_instrumentation_enabled,
server_instrumentation_is_enabled,
)
from phoenix.core.model_schema import Model
from phoenix.db.bulk_inserter import BulkInserter
Expand All @@ -60,6 +63,7 @@
from phoenix.server.api.dataloaders.span_descendants import SpanDescendantsDataLoader
from phoenix.server.api.routers.v1 import V1_ROUTES
from phoenix.server.api.schema import schema
from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
from phoenix.trace.schemas import Span

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -259,21 +263,45 @@ def create_app(
""
)
raise PhoenixMigrationError(msg) from e

if is_server_instrumentation_enabled():
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor

SQLAlchemyInstrumentor().instrument(engine=engine.sync_engine)

db = _db(engine)
bulk_inserter = BulkInserter(
db,
initial_batch_of_spans=initial_batch_of_spans,
initial_batch_of_evaluations=initial_batch_of_evaluations,
)
tracer_provider = None
strawberry_extensions = schema.get_extensions()
if server_instrumentation_is_enabled():
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.trace import TracerProvider
from strawberry.extensions.tracing import OpenTelemetryExtension

tracer_provider = initialize_opentelemetry_tracer_provider()
SQLAlchemyInstrumentor().instrument(
engine=engine.sync_engine,
tracer_provider=tracer_provider,
)
if TYPE_CHECKING:
# Type-check the class before monkey-patching its private attribute.
assert OpenTelemetryExtension._tracer

class _OpenTelemetryExtension(OpenTelemetryExtension):
    """Strawberry OpenTelemetry extension bound to the server's own provider.

    Behaves like its parent, except that after the parent initializer runs,
    the private ``_tracer`` attribute is replaced with a tracer obtained from
    the enclosing scope's ``tracer_provider``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Swap out the parent's private tracer so the global TracerProvider
        # is not used; in a notebook setting the global one could be the
        # provider used by OpenInference.
        provider = cast(TracerProvider, tracer_provider)
        self._tracer = provider.get_tracer("strawberry")

strawberry_extensions.append(_OpenTelemetryExtension)
graphql = GraphQLWithContext(
db=db,
schema=schema,
schema=strawberry.Schema(
query=schema.query,
mutation=schema.mutation,
subscription=schema.subscription,
extensions=strawberry_extensions,
),
model=model,
corpus=corpus,
export_path=export_path,
Expand All @@ -286,7 +314,6 @@ def create_app(
prometheus_middlewares = [Middleware(PrometheusMiddleware)]
else:
prometheus_middlewares = []

app = Starlette(
lifespan=_lifespan(bulk_inserter),
middleware=[
Expand Down Expand Up @@ -329,4 +356,8 @@ def create_app(
)
app.state.read_only = read_only
app.state.db = db
if tracer_provider:
from opentelemetry.instrumentation.starlette import StarletteInstrumentor

StarletteInstrumentor.instrument_app(app, tracer_provider=tracer_provider)
return app
7 changes: 0 additions & 7 deletions src/phoenix/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
get_env_port,
get_pids_path,
get_working_dir,
is_server_instrumentation_enabled,
)
from phoenix.core.model_schema_adapter import create_model_from_datasets
from phoenix.db import get_printable_db_url
Expand All @@ -31,7 +30,6 @@
UMAPParameters,
)
from phoenix.server.app import create_app
from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
from phoenix.settings import Settings
from phoenix.trace.fixtures import (
TRACES_FIXTURES,
Expand Down Expand Up @@ -230,11 +228,6 @@ def _get_pid_file() -> Path:
initial_spans=fixture_spans,
initial_evaluations=fixture_evals,
)
if is_server_instrumentation_enabled():
from opentelemetry.instrumentation.starlette import StarletteInstrumentor

initialize_opentelemetry_tracer_provider()
StarletteInstrumentor.instrument_app(app)
server = Server(config=Config(app, host=host, port=port))
Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start()

Expand Down
9 changes: 6 additions & 3 deletions src/phoenix/server/telemetry.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from opentelemetry.trace import TracerProvider

from phoenix.config import (
ENV_PHOENIX_SERVER_INSTRUMENTATION_OTLP_TRACE_COLLECTOR_GRPC_ENDPOINT,
ENV_PHOENIX_SERVER_INSTRUMENTATION_OTLP_TRACE_COLLECTOR_HTTP_ENDPOINT,
)


def initialize_opentelemetry_tracer_provider() -> None:
from opentelemetry import trace as trace_api
def initialize_opentelemetry_tracer_provider() -> "TracerProvider":
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.trace.export import BatchSpanProcessor

Expand All @@ -28,4 +31,4 @@ def initialize_opentelemetry_tracer_provider() -> None:
)

tracer_provider.add_span_processor(BatchSpanProcessor(GrpcExporter(grpc_endpoint)))
trace_api.set_tracer_provider(tracer_provider)
return tracer_provider

0 comments on commit 5d7b843

Please sign in to comment.