feat: OpenAI instrumentation to capture context attributes (#415)
fjcasti1 committed May 2, 2024
1 parent a7c2425 commit 8e0cab9
Showing 6 changed files with 141 additions and 14 deletions.
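
This commit teaches the OpenAI instrumentor to copy context attributes — the session ID, user ID, metadata, and tags set via the `using_attributes` context manager — onto every span it emits. A minimal sketch of the resulting usage pattern, assuming a tracer provider is already configured (exporter setup elided, as in the examples below):

import openai
from openinference.instrumentation import using_attributes
from openinference.instrumentation.openai import OpenAIInstrumentor

OpenAIInstrumentor().instrument()  # patch the OpenAI client for tracing

# Spans created inside this block carry the session/user/metadata/tag
# attributes (the SpanAttributes keys referenced in the tests below) in
# addition to the usual LLM attributes.
with using_attributes(session_id="my-session", user_id="my-user", tags=["demo"]):
    response = openai.OpenAI().chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Write a haiku."}],
    )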
── File 1 of 6: synchronous chat-completions example ──
@@ -1,4 +1,5 @@
 import openai
+from openinference.instrumentation import using_attributes
 from openinference.instrumentation.openai import OpenAIInstrumentor
 from opentelemetry import trace as trace_api
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
@@ -16,9 +17,23 @@

 if __name__ == "__main__":
     client = openai.OpenAI()
-    response = client.chat.completions.create(
-        model="gpt-3.5-turbo",
-        messages=[{"role": "user", "content": "Write a haiku."}],
-        max_tokens=20,
-    )
-    print(response.choices[0].message.content)
+    with using_attributes(
+        session_id="my-test-session",
+        user_id="my-test-user",
+        metadata={
+            "test-int": 1,
+            "test-str": "string",
+            "test-list": [1, 2, 3],
+            "test-dict": {
+                "key-1": "val-1",
+                "key-2": "val-2",
+            },
+        },
+        tags=["tag-1", "tag-2"],
+    ):
+        response = client.chat.completions.create(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Write a haiku."}],
+            max_tokens=20,
+        )
+        print(response.choices[0].message.content)
── File 2 of 6: async streaming chat-completions example ──
@@ -1,6 +1,7 @@
 import asyncio

 import openai
+from openinference.instrumentation import using_attributes
 from openinference.instrumentation.openai import OpenAIInstrumentor
 from opentelemetry import trace as trace_api
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
@@ -18,10 +19,24 @@

 async def chat_completions(**kwargs):
     client = openai.AsyncOpenAI()
-    response = await client.chat.completions.create(**kwargs)
-    async for chunk in response:
-        if content := chunk.choices[0].delta.content:
-            print(content, end="")
+    with using_attributes(
+        session_id="my-test-session",
+        user_id="my-test-user",
+        metadata={
+            "test-int": 1,
+            "test-str": "string",
+            "test-list": [1, 2, 3],
+            "test-dict": {
+                "key-1": "val-1",
+                "key-2": "val-2",
+            },
+        },
+        tags=["tag-1", "tag-2"],
+    ):
+        response = await client.chat.completions.create(**kwargs)
+        async for chunk in response:
+            if content := chunk.choices[0].delta.content:
+                print(content, end="")


 if __name__ == "__main__":
── File 3 of 6: package manifest (new dependency) ──
@@ -28,6 +28,7 @@ dependencies = [
   "opentelemetry-api",
   "opentelemetry-instrumentation",
   "opentelemetry-semantic-conventions",
+  "openinference-instrumentation>=0.1.2",
   "openinference-semantic-conventions",
   "typing-extensions",
   "wrapt",
── File 4 of 6: request wrapper (span creation) ──
@@ -12,6 +12,7 @@
     Tuple,
 )

+from openinference.instrumentation import get_attributes_from_context
 from openinference.instrumentation.openai._request_attributes_extractor import (
     _RequestAttributesExtractor,
 )
@@ -60,6 +61,7 @@ def _start_as_current_span(
         self,
         span_name: str,
         attributes: Iterable[Tuple[str, AttributeValue]],
+        context_attributes: Iterable[Tuple[str, AttributeValue]],
         extra_attributes: Iterable[Tuple[str, AttributeValue]],
     ) -> Iterator[_WithSpan]:
         # Because OTEL has a default limit of 128 attributes, we split our attributes into
@@ -76,7 +78,11 @@
             record_exception=False,
             set_status_on_exception=False,
         ) as span:
-            yield _WithSpan(span=span, extra_attributes=dict(extra_attributes))
+            yield _WithSpan(
+                span=span,
+                context_attributes=dict(context_attributes),
+                extra_attributes=dict(extra_attributes),
+            )


 _RequestParameters: TypeAlias = Mapping[str, Any]
@@ -265,6 +271,7 @@ def __call__(
                 cast_to=cast_to,
                 request_parameters=request_parameters,
             ),
+            context_attributes=get_attributes_from_context(),
             extra_attributes=self._get_extra_attributes_from_request(
                 cast_to=cast_to,
                 request_parameters=request_parameters,
@@ -318,6 +325,7 @@ async def __call__(
                 cast_to=cast_to,
                 request_parameters=request_parameters,
             ),
+            context_attributes=get_attributes_from_context(),
             extra_attributes=self._get_extra_attributes_from_request(
                 cast_to=cast_to,
                 request_parameters=request_parameters,
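The wrapper-side glue is `get_attributes_from_context`, called once per request and handed to the span as `context_attributes`. A hedged sketch of the behavior the hunks above rely on (the empty-outside-the-block assertion is an assumption, not shown in this commit):

from openinference.instrumentation import get_attributes_from_context, using_attributes

# Outside any using_attributes block the context carries no attributes
# (assumption: get_attributes_from_context yields nothing in that case).
assert dict(get_attributes_from_context()) == {}

with using_attributes(session_id="s-1", user_id="u-1", tags=["a"]):
    # Inside the block, the key-value pairs that _start_as_current_span
    # copies onto the span via _WithSpan are available from OTel context.
    context_attributes = dict(get_attributes_from_context())
    assert context_attributes  # e.g. session/user/tag keys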
── File 5 of 6: the _WithSpan helper ──
@@ -11,16 +11,19 @@
 class _WithSpan:
     __slots__ = (
         "_span",
+        "_context_attributes",
         "_extra_attributes",
         "_is_finished",
     )

     def __init__(
         self,
         span: trace_api.Span,
+        context_attributes: Attributes = None,
         extra_attributes: Attributes = None,
     ) -> None:
         self._span = span
+        self._context_attributes = context_attributes
         self._extra_attributes = extra_attributes
         try:
             self._is_finished = not self._span.is_recording()
@@ -58,6 +61,7 @@ def finish_tracing(
             return
         for mapping in (
             attributes,
+            self._context_attributes,
             self._extra_attributes,
             extra_attributes,
         ):
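One detail worth calling out in the `finish_tracing` hunk: `self._context_attributes` is merged after the primary `attributes` but before the extra-attribute mappings, so later mappings can overwrite earlier keys. A self-contained sketch of that precedence (`merged_span_attributes` is a hypothetical helper, not part of this commit; the loop body itself is outside the hunks shown):

from typing import Any, Dict, Mapping, Optional

def merged_span_attributes(*mappings: Optional[Mapping[str, Any]]) -> Dict[str, Any]:
    # Mirrors the loop order in finish_tracing: attributes, then context
    # attributes, then extra attributes; later mappings win on collisions.
    merged: Dict[str, Any] = {}
    for mapping in mappings:
        if mapping:
            merged.update(mapping)
    return merged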
── File 6 of 6: instrumentor test suite ──
@@ -2,7 +2,7 @@
 import json
 import logging
 import random
-from contextlib import suppress
+from contextlib import ExitStack, suppress
 from importlib import import_module
 from importlib.metadata import version
 from itertools import count
@@ -23,6 +23,7 @@

 import pytest
 from httpx import AsyncByteStream, Response
+from openinference.instrumentation import using_attributes
 from openinference.instrumentation.openai import OpenAIInstrumentor
 from openinference.semconv.trace import (
     EmbeddingAttributes,
@@ -63,6 +64,10 @@ def test_chat_completions(
     completion_usage: Dict[str, Any],
     model_name: str,
     chat_completion_mock_stream: Tuple[List[bytes], List[Dict[str, Any]]],
+    session_id: str,
+    user_id: str,
+    metadata: Dict[str, Any],
+    tags: List[str],
 ) -> None:
     input_messages: List[Dict[str, Any]] = get_messages()
     output_messages: List[Dict[str, Any]] = (
@@ -100,7 +105,17 @@
         else openai.OpenAI(api_key="sk-").chat.completions
     )
     create = completions.with_raw_response.create if is_raw else completions.create
-    with suppress(openai.BadRequestError):
+
+    with ExitStack() as stack:
+        stack.enter_context(suppress(openai.BadRequestError))
+        stack.enter_context(
+            using_attributes(
+                session_id=session_id,
+                user_id=user_id,
+                metadata=metadata,
+                tags=tags,
+            )
+        )
         if is_async:

             async def task() -> None:
@@ -181,6 +196,18 @@ async def task() -> None:
     )
     # We left out model_name from our mock stream.
     assert attributes.pop(LLM_MODEL_NAME, None) == model_name
+    assert attributes.pop(SESSION_ID, None) == session_id
+    assert attributes.pop(USER_ID, None) == user_id
+    attr_tags = attributes.pop(TAG_TAGS, None)
+    assert attr_tags is not None
+    assert isinstance(attr_tags, tuple)
+    assert len(attr_tags) == len(tags)
+    assert list(attr_tags) == tags
+    attr_metadata = attributes.pop(METADATA, None)
+    assert attr_metadata is not None
+    assert isinstance(attr_metadata, str)  # must be json string
+    metadata_dict = json.loads(attr_metadata)
+    assert metadata_dict == metadata
     assert attributes == {}  # test should account for all span attributes


@@ -198,6 +225,10 @@ def test_completions(
     completion_usage: Dict[str, Any],
     model_name: str,
     completion_mock_stream: Tuple[List[bytes], List[str]],
+    session_id: str,
+    user_id: str,
+    metadata: Dict[str, Any],
+    tags: List[str],
 ) -> None:
     prompt: List[str] = get_texts()
     output_texts: List[str] = completion_mock_stream[1] if is_stream else get_texts()
@@ -233,7 +264,16 @@
         else openai.OpenAI(api_key="sk-").completions
     )
     create = completions.with_raw_response.create if is_raw else completions.create
-    with suppress(openai.BadRequestError):
+    with ExitStack() as stack:
+        stack.enter_context(suppress(openai.BadRequestError))
+        stack.enter_context(
+            using_attributes(
+                session_id=session_id,
+                user_id=user_id,
+                metadata=metadata,
+                tags=tags,
+            )
+        )
         if is_async:

             async def task() -> None:
@@ -286,6 +326,18 @@ async def task() -> None:
     )
     # We left out model_name from our mock stream.
     assert attributes.pop(LLM_MODEL_NAME, None) == model_name
+    assert attributes.pop(SESSION_ID, None) == session_id
+    assert attributes.pop(USER_ID, None) == user_id
+    attr_tags = attributes.pop(TAG_TAGS, None)
+    assert attr_tags is not None
+    assert isinstance(attr_tags, tuple)
+    assert len(attr_tags) == len(tags)
+    assert list(attr_tags) == tags
+    attr_metadata = attributes.pop(METADATA, None)
+    assert attr_metadata is not None
+    assert isinstance(attr_metadata, str)  # must be json string
+    metadata_dict = json.loads(attr_metadata)
+    assert metadata_dict == metadata
     assert attributes == {}  # test should account for all span attributes


@@ -388,6 +440,34 @@ async def task() -> None:
     assert attributes == {}  # test should account for all span attributes


+@pytest.fixture()
+def session_id() -> str:
+    return "my-test-session-id"
+
+
+@pytest.fixture()
+def user_id() -> str:
+    return "my-test-user-id"
+
+
+@pytest.fixture()
+def metadata() -> Dict[str, Any]:
+    return {
+        "test-int": 1,
+        "test-str": "string",
+        "test-list": [1, 2, 3],
+        "test-dict": {
+            "key-1": "val-1",
+            "key-2": "val-2",
+        },
+    }
+
+
+@pytest.fixture()
+def tags() -> List[str]:
+    return ["tag-1", "tag-2"]
+
+
 @pytest.fixture(scope="module")
 def in_memory_span_exporter() -> InMemorySpanExporter:
     return InMemorySpanExporter()
@@ -671,3 +751,7 @@ def tool_call_function_arguments(prefix: str, i: int, j: int) -> str:
 EMBEDDING_MODEL_NAME = SpanAttributes.EMBEDDING_MODEL_NAME
 EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR
 EMBEDDING_TEXT = EmbeddingAttributes.EMBEDDING_TEXT
+SESSION_ID = SpanAttributes.SESSION_ID
+USER_ID = SpanAttributes.USER_ID
+METADATA = SpanAttributes.METADATA
+TAG_TAGS = SpanAttributes.TAG_TAGS
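
For reference, the four new constants the tests pop off the span resolve to flat string keys. Assuming the openinference-semantic-conventions values current at the time of this commit (worth verifying against the package), they are:

# Assumed string values of the new SpanAttributes constants:
SESSION_ID = "session.id"
USER_ID = "user.id"
METADATA = "metadata"   # stored as a JSON string, per the tests above
TAG_TAGS = "tag.tags"   # stored as a tuple of strings, per the tests above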
