153 changes: 43 additions & 110 deletions src/langtrace_python_sdk/instrumentation/ollama/patch.py
@@ -1,6 +1,7 @@
 from langtrace_python_sdk.constants.instrumentation.ollama import APIS
 from langtrace_python_sdk.utils import set_span_attribute
 from langtrace_python_sdk.utils.llm import (
+    StreamWrapper,
     get_extra_attributes,
     get_langtrace_attributes,
     get_llm_request_attributes,
@@ -16,9 +17,10 @@
 import json
 from opentelemetry.trace.status import Status, StatusCode
 from langtrace.trace_attributes import SpanAttributes
+from opentelemetry.trace import Tracer


-def generic_patch(operation_name, version, tracer):
+def generic_patch(operation_name, version, tracer: Tracer):
     def traced_method(wrapped, instance, args, kwargs):
         api = APIS[operation_name]
         service_provider = SERVICE_PROVIDERS["OLLAMA"]
@@ -35,36 +37,29 @@ def traced_method(wrapped, instance, args, kwargs):
         }

         attributes = LLMSpanAttributes(**span_attributes)
-        with tracer.start_as_current_span(
-            name=get_span_name(f'ollama.{api["METHOD"]}'), kind=SpanKind.CLIENT
-        ) as span:
-            _set_input_attributes(span, kwargs, attributes)
-
-            try:
-                result = wrapped(*args, **kwargs)
-                if result:
-                    if span.is_recording():
-
-                        if kwargs.get("stream"):
-                            return _handle_streaming_response(
-                                span, result, api["METHOD"]
-                            )
-
-                        _set_response_attributes(span, result)
-                        span.set_status(Status(StatusCode.OK))
-
-                span.end()
-                return result
-
-            except Exception as err:
-                # Record the exception in the span
-                span.record_exception(err)
-
-                # Set the span status to indicate an error
-                span.set_status(Status(StatusCode.ERROR, str(err)))
-
-                # Reraise the exception to ensure it's not swallowed
-                raise
+        span = tracer.start_span(
+            name=get_span_name(f'ollama.{api["METHOD"]}'), kind=SpanKind.CLIENT
+        )
+        _set_input_attributes(span, kwargs, attributes)
+
+        try:
+            result = wrapped(*args, **kwargs)
+            if kwargs.get("stream"):
+                return StreamWrapper(result, span)
+            else:
+                _set_response_attributes(span, result)
+                span.end()
+                return result
+
+        except Exception as err:
+            # Record the exception in the span
+            span.record_exception(err)
+
+            # Set the span status to indicate an error
+            span.set_status(Status(StatusCode.ERROR, str(err)))
+
+            # Reraise the exception to ensure it's not swallowed
+            raise

     return traced_method

@@ -82,30 +77,28 @@ async def traced_method(wrapped, instance, args, kwargs):
             **get_extra_attributes(),
         }
         attributes = LLMSpanAttributes(**span_attributes)
-        with tracer.start_as_current_span(api["METHOD"], kind=SpanKind.CLIENT) as span:
-            _set_input_attributes(span, kwargs, attributes)
-            try:
-                result = await wrapped(*args, **kwargs)
-                if result:
-                    if span.is_recording():
-                        if kwargs.get("stream"):
-                            return _ahandle_streaming_response(
-                                span, result, api["METHOD"]
-                            )
-
-                        _set_response_attributes(span, result)
-                        span.set_status(Status(StatusCode.OK))
-                span.end()
-                return result
-            except Exception as err:
-                # Record the exception in the span
-                span.record_exception(err)
-
-                # Set the span status to indicate an error
-                span.set_status(Status(StatusCode.ERROR, str(err)))
-
-                # Reraise the exception to ensure it's not swallowed
-                raise
+        span = tracer.start_span(
+            name=get_span_name(f'ollama.{api["METHOD"]}'), kind=SpanKind.CLIENT
+        )
+
+        _set_input_attributes(span, kwargs, attributes)
+        try:
+            result = await wrapped(*args, **kwargs)
+            if kwargs.get("stream"):
+                return StreamWrapper(result, span)
+            else:
+                _set_response_attributes(span, result)
+                span.end()
+                return result
+        except Exception as err:
+            # Record the exception in the span
+            span.record_exception(err)
+
+            # Set the span status to indicate an error
+            span.set_status(Status(StatusCode.ERROR, str(err)))
+
+            # Reraise the exception to ensure it's not swallowed
+            raise

     return traced_method

@@ -162,63 +155,3 @@ def _set_input_attributes(span, kwargs, attributes):
             SpanAttributes.LLM_PRESENCE_PENALTY,
             options.get("presence_penalty"),
         )
-
-
-def _handle_streaming_response(span, response, api):
-    accumulated_tokens = None
-    if api == "chat":
-        accumulated_tokens = {"message": {"content": "", "role": ""}}
-    if api == "completion" or api == "generate":
-        accumulated_tokens = {"response": ""}
-    span.add_event(Event.STREAM_START.value)
-    try:
-        for chunk in response:
-            content = None
-            if api == "chat":
-                content = chunk["message"]["content"]
-                accumulated_tokens["message"]["content"] += chunk["message"]["content"]
-                accumulated_tokens["message"]["role"] = chunk["message"]["role"]
-            if api == "generate":
-                content = chunk["response"]
-                accumulated_tokens["response"] += chunk["response"]
-
-            set_event_completion_chunk(span, content)
-
-            _set_response_attributes(span, chunk | accumulated_tokens)
-    finally:
-        # Finalize span after processing all chunks
-        span.add_event(Event.STREAM_END.value)
-        span.set_status(StatusCode.OK)
-        span.end()
-
-    return response
-
-
-async def _ahandle_streaming_response(span, response, api):
-    accumulated_tokens = None
-    if api == "chat":
-        accumulated_tokens = {"message": {"content": "", "role": ""}}
-    if api == "completion" or api == "generate":
-        accumulated_tokens = {"response": ""}
-
-    span.add_event(Event.STREAM_START.value)
-    try:
-        async for chunk in response:
-            content = None
-            if api == "chat":
-                content = chunk["message"]["content"]
-                accumulated_tokens["message"]["content"] += chunk["message"]["content"]
-                accumulated_tokens["message"]["role"] = chunk["message"]["role"]
-            if api == "generate":
-                content = chunk["response"]
-                accumulated_tokens["response"] += chunk["response"]
-
-            set_event_completion_chunk(span, content)
-            _set_response_attributes(span, chunk | accumulated_tokens)
-    finally:
-        # Finalize span after processing all chunks
-        span.add_event(Event.STREAM_END.value)
-        span.set_status(StatusCode.OK)
-        span.end()
-
-    return response
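
Why the restructuring: the removed _handle_streaming_response helpers iterated the stream eagerly inside the tracing code and only then returned it, so callers could receive an already-consumed iterator, and the two helpers duplicated logic that StreamWrapper already centralizes for other providers. The replacement starts the span explicitly, rather than via the start_as_current_span context manager, because the span must outlive traced_method when a stream is handed back to the caller. Below is a minimal sketch of that span-lifetime pattern, assuming only the opentelemetry-api package; SpanEndingStream and traced_call are hypothetical stand-ins for the SDK's StreamWrapper and traced_method, not the actual implementation.

# Hypothetical sketch of the span-lifetime pattern adopted above.
from opentelemetry.trace import SpanKind, Tracer
from opentelemetry.trace.status import Status, StatusCode


class SpanEndingStream:
    """Wraps a chunk iterator; the span stays open until the stream ends."""

    def __init__(self, stream, span):
        self.stream = stream
        self.span = span

    def __iter__(self):
        try:
            for chunk in self.stream:
                # The real StreamWrapper records per-chunk events and
                # accumulates completion/usage attributes here.
                yield chunk
            self.span.set_status(Status(StatusCode.OK))
        except Exception as err:
            self.span.record_exception(err)
            self.span.set_status(Status(StatusCode.ERROR, str(err)))
            raise
        finally:
            # The span outlives the traced call and closes with the stream.
            self.span.end()


def traced_call(tracer: Tracer, wrapped, **kwargs):
    # start_span instead of the start_as_current_span context manager:
    # the span must survive this function's return when streaming.
    span = tracer.start_span("ollama.chat", kind=SpanKind.CLIENT)
    try:
        result = wrapped(**kwargs)
        if kwargs.get("stream"):
            return SpanEndingStream(result, span)
        span.end()
        return result
    except Exception as err:
        span.record_exception(err)
        span.set_status(Status(StatusCode.ERROR, str(err)))
        raise

The trade-off of dropping the context manager is that every path must now end the span explicitly: non-streaming success, stream exhaustion, and error, which is what the wrapper's finally block guarantees.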
13 changes: 11 additions & 2 deletions src/langtrace_python_sdk/utils/llm.py
@@ -237,7 +236,6 @@ class StreamWrapper:
     def __init__(
         self, stream, span, prompt_tokens=0, function_call=False, tool_calls=False
     ):
-
         self.stream = stream
         self.span = span
         self.prompt_tokens = prompt_tokens
@@ -284,7 +283,6 @@ def cleanup(self):
                     }
                 ],
             )
-
             self.span.set_status(StatusCode.OK)
             self.span.end()
             self._span_started = False
@@ -377,6 +375,10 @@ def build_streaming_response(self, chunk):
         if hasattr(chunk, "delta") and chunk.delta is not None:
            content = [chunk.delta.text] if hasattr(chunk.delta, "text") else []

+        if isinstance(chunk, dict):
+            if "message" in chunk:
+                if "content" in chunk["message"]:
+                    content = [chunk["message"]["content"]]
         if content:
             self.result_content.append(content[0])

@@ -401,6 +403,13 @@ def set_usage_attributes(self, chunk):
             self.completion_tokens = chunk.usage_metadata.candidates_token_count
             self.prompt_tokens = chunk.usage_metadata.prompt_token_count

+        # Ollama
+        if isinstance(chunk, dict):
+            if "prompt_eval_count" in chunk:
+                self.prompt_tokens = chunk["prompt_eval_count"]
+            if "eval_count" in chunk:
+                self.completion_tokens = chunk["eval_count"]
+
     def process_chunk(self, chunk):
         self.set_response_model(chunk=chunk)
         self.build_streaming_response(chunk=chunk)
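
These StreamWrapper additions are what make the Ollama switch above work: unlike the OpenAI- or Anthropic-style clients, whose chunks are objects with attributes, the Ollama client yields plain dicts, so the existing attribute-based branches never matched. A sketch of the chunk shapes the new isinstance(chunk, dict) branches consume; the values are invented sample data, but the field names (message.content, prompt_eval_count, eval_count) follow Ollama's documented streaming responses:

# Sample Ollama-style chunks (made-up values; real chunks carry more
# fields, such as model and created_at).
mid_chunk = {
    "message": {"role": "assistant", "content": "Hello, "},
    "done": False,
}
final_chunk = {
    "message": {"role": "assistant", "content": ""},
    "done": True,
    "prompt_eval_count": 26,  # prompt tokens -> self.prompt_tokens
    "eval_count": 42,         # generated tokens -> self.completion_tokens
}

content, prompt_tokens, completion_tokens = [], 0, 0
for chunk in (mid_chunk, final_chunk):
    # Mirrors build_streaming_response: collect text from dict chunks.
    if isinstance(chunk, dict) and "content" in chunk.get("message", {}):
        content.append(chunk["message"]["content"])
    # Mirrors set_usage_attributes: usage counts arrive on the final chunk.
    if isinstance(chunk, dict):
        prompt_tokens = chunk.get("prompt_eval_count", prompt_tokens)
        completion_tokens = chunk.get("eval_count", completion_tokens)

print("".join(content), prompt_tokens, completion_tokens)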
2 changes: 1 addition & 1 deletion src/langtrace_python_sdk/version.py
@@ -1 +1 @@
__version__ = "2.2.21"
__version__ = "2.2.22"