-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'development' of github.com:Scale3-Labs/langtrace-python-sdk into ali/sampling
- Loading branch information
Showing
14 changed files
with
364 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,6 +49,7 @@ dev = [ | |
"cohere", | ||
"qdrant_client", | ||
"weaviate-client", | ||
"ollama" | ||
] | ||
|
||
test = [ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from .basic import chat, async_chat, async_generate, generate, embed, async_embed | ||
from langtrace_python_sdk import with_langtrace_root_span | ||
import asyncio | ||
|
||
|
||
class OllamaRunner:
    """Drives every example from ``basic.py`` under a single root span."""

    @with_langtrace_root_span("OllamaRunner")
    def run(self):
        # Synchronous examples first, then each async variant in its own
        # event loop — same order as before: chat, generate, embed.
        for sync_example in (chat, generate, embed):
            sync_example()
        for async_example in (async_chat, async_generate, async_embed):
            asyncio.run(async_example())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from langtrace_python_sdk import langtrace, with_langtrace_root_span | ||
import ollama | ||
from ollama import AsyncClient | ||
from dotenv import load_dotenv | ||
|
||
load_dotenv() | ||
|
||
langtrace.init(write_spans_to_console=False) | ||
|
||
|
||
def chat():
    """Run a streaming chat completion against the local llama3 model.

    Returns the stream iterator produced by ``ollama.chat``; the caller
    is responsible for consuming the chunks.
    """
    messages = [{"role": "user", "content": "hi"}]
    return ollama.chat(model="llama3", messages=messages, stream=True)
|
||
|
||
async def async_chat():
    """Send one chat message through the async Ollama client and await the reply."""
    prompt = {"role": "user", "content": "Why is the sky blue?"}
    client = AsyncClient()
    return await client.chat(model="llama3", messages=[prompt])
|
||
|
||
def generate():
    """Run a one-shot (non-streaming) completion with the sync client."""
    return ollama.generate(
        model="llama3",
        prompt="Why is the sky blue?",
    )
|
||
|
||
async def async_generate():
    """Run a one-shot completion through the async Ollama client.

    Declared ``async`` (previously a plain ``def`` that returned an
    un-awaited coroutine) so it matches the other async helpers and
    raises errors here rather than in the caller.
    ``asyncio.run(async_generate())`` keeps working unchanged.
    """
    return await AsyncClient().generate(model="llama3", prompt="Why is the sky blue?")
|
||
|
||
def embed():
    """Compute an embedding for a single prompt with the sync client."""
    return ollama.embeddings(model="llama3", prompt="cat")
|
||
|
||
async def async_embed():
    """Compute an embedding for a single prompt with the async client."""
    client = AsyncClient()
    return await client.embeddings(model="llama3", prompt="cat")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Maps each instrumented Ollama operation to the client method implementing it.
APIS = {
    "GENERATE": {"METHOD": "generate"},
    "CHAT": {"METHOD": "chat"},
    "EMBEDDINGS": {"METHOD": "embeddings"},
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .instrumentation import OllamaInstrumentor | ||
|
||
__all__ = ["OllamaInstrumentor"] |
58 changes: 58 additions & 0 deletions
58
src/langtrace_python_sdk/instrumentation/ollama/instrumentation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
""" | ||
Copyright (c) 2024 Scale3 Labs | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor | ||
from opentelemetry.trace import get_tracer | ||
from wrapt import wrap_function_wrapper as _W | ||
from typing import Collection | ||
from importlib_metadata import version as v | ||
from langtrace_python_sdk.constants.instrumentation.ollama import APIS | ||
from .patch import generic_patch, ageneric_patch | ||
|
||
|
||
class OllamaInstrumentor(BaseInstrumentor):
    """OpenTelemetry instrumentor for the Ollama python client.

    Wraps the ``generate``, ``chat`` and ``embeddings`` entry points on
    ``Client``, ``AsyncClient`` and the module-level ``ollama`` namespace.
    """

    def instrumentation_dependencies(self) -> Collection[str]:
        # The patched attribute paths below rely on the 0.x client layout.
        return ["ollama >= 0.2.0, < 1"]

    def _instrument(self, **kwargs):
        tracer_provider = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, "", tracer_provider)
        version = v("ollama")
        for operation_name, details in APIS.items():
            operation = details["METHOD"]
            # Sync client method, e.g. Client.chat.
            _W(
                "ollama._client",
                f"Client.{operation}",
                generic_patch(operation_name, version, tracer),
            )
            # Async client method, e.g. AsyncClient.chat.
            _W(
                "ollama._client",
                f"AsyncClient.{operation}",
                ageneric_patch(operation_name, version, tracer),
            )
            # Module-level convenience function, e.g. ollama.chat.
            # NOTE(review): if ollama.<op> delegates to a default Client
            # instance this may yield nested spans per call — confirm.
            _W(
                "ollama",
                operation,
                generic_patch(operation_name, version, tracer),
            )

    def _uninstrument(self, **kwargs):
        # Wrappers are intentionally left in place; BaseInstrumentor's
        # guard prevents double instrumentation on re-instrument.
        pass
215 changes: 215 additions & 0 deletions
215
src/langtrace_python_sdk/instrumentation/ollama/patch.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
from langtrace_python_sdk.constants.instrumentation.ollama import APIS | ||
from importlib_metadata import version as v | ||
from langtrace_python_sdk.constants import LANGTRACE_SDK_NAME | ||
from langtrace_python_sdk.utils import set_span_attribute | ||
from langtrace_python_sdk.utils.silently_fail import silently_fail | ||
from langtrace_python_sdk.constants.instrumentation.common import ( | ||
LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY, | ||
SERVICE_PROVIDERS, | ||
) | ||
from opentelemetry import baggage | ||
from langtrace.trace_attributes import LLMSpanAttributes, Event | ||
from opentelemetry.trace import SpanKind | ||
import json | ||
from opentelemetry.trace.status import Status, StatusCode | ||
|
||
|
||
def generic_patch(operation_name, version, tracer):
    """Build a wrapt wrapper that traces a synchronous Ollama client call.

    Args:
        operation_name: key into ``APIS`` ("GENERATE" / "CHAT" / "EMBEDDINGS").
        version: installed ollama package version, recorded on the span.
        tracer: OpenTelemetry tracer used to create spans.
    """

    def traced_method(wrapped, instance, args, kwargs):
        # Module-level functions have no ``_client``; fall back to "".
        base_url = (
            str(instance._client._base_url)
            if hasattr(instance, "_client") and hasattr(instance._client, "_base_url")
            else ""
        )
        api = APIS[operation_name]
        service_provider = SERVICE_PROVIDERS["OLLAMA"]
        extra_attributes = baggage.get_baggage(LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY)
        span_attributes = {
            "langtrace.sdk.name": "langtrace-python-sdk",
            "langtrace.service.name": service_provider,
            "langtrace.service.type": "llm",
            "langtrace.service.version": version,
            "langtrace.version": v(LANGTRACE_SDK_NAME),
            "llm.model": kwargs.get("model"),
            "llm.stream": kwargs.get("stream"),
            "url.full": base_url,
            "llm.api": api["METHOD"],
            **(extra_attributes if extra_attributes is not None else {}),
        }

        attributes = LLMSpanAttributes(**span_attributes)
        # end_on_exit=False: the span was previously ended both explicitly
        # and again by the context manager on exit ("span already ended"
        # warnings); each path below now ends it exactly once.
        with tracer.start_as_current_span(
            f'ollama.{api["METHOD"]}', kind=SpanKind.CLIENT, end_on_exit=False
        ) as span:
            _set_input_attributes(span, kwargs, attributes)
            try:
                result = wrapped(*args, **kwargs)
                if result and span.is_recording():
                    if kwargs.get("stream"):
                        # The stream handler finalizes (ends) the span itself.
                        return _handle_streaming_response(
                            span, result, api["METHOD"]
                        )
                    _set_response_attributes(span, result)
                    span.set_status(Status(StatusCode.OK))
                span.end()
                return result
            except Exception as err:
                # Record the failure on the span, then re-raise so it is
                # not swallowed.
                span.record_exception(err)
                span.set_status(Status(StatusCode.ERROR, str(err)))
                span.end()
                raise

    return traced_method
|
||
|
||
def ageneric_patch(operation_name, version, tracer):
    """Build a wrapt wrapper that traces an asynchronous Ollama client call.

    Mirrors ``generic_patch`` for ``AsyncClient`` methods.
    """

    async def traced_method(wrapped, instance, args, kwargs):
        api = APIS[operation_name]
        service_provider = SERVICE_PROVIDERS["OLLAMA"]
        extra_attributes = baggage.get_baggage(LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY)
        span_attributes = {
            "langtrace.sdk.name": "langtrace-python-sdk",
            "langtrace.service.name": service_provider,
            # No base URL is readily reachable from here for AsyncClient.
            "url.full": "",
            # Previously hard-coded to ""; record the real method so async
            # spans carry the same metadata as sync ones.
            "llm.api": api["METHOD"],
            "langtrace.service.type": "llm",
            "langtrace.service.version": version,
            "langtrace.version": v(LANGTRACE_SDK_NAME),
            "llm.model": kwargs.get("model"),
            "llm.stream": kwargs.get("stream"),
            **(extra_attributes if extra_attributes is not None else {}),
        }

        attributes = LLMSpanAttributes(**span_attributes)
        # Span named "ollama.<method>" to match the sync path (it was
        # previously just the bare method name). end_on_exit=False so each
        # path below ends the span exactly once instead of twice.
        with tracer.start_as_current_span(
            f'ollama.{api["METHOD"]}', kind=SpanKind.CLIENT, end_on_exit=False
        ) as span:
            _set_input_attributes(span, kwargs, attributes)
            try:
                result = await wrapped(*args, **kwargs)
                if result and span.is_recording():
                    if kwargs.get("stream"):
                        # Await the handler (it was previously returned as a
                        # bare coroutine, so the caller got a coroutine
                        # object instead of the stream); it ends the span.
                        return await _ahandle_streaming_response(
                            span, result, api["METHOD"]
                        )
                    _set_response_attributes(span, result)
                    span.set_status(Status(StatusCode.OK))
                span.end()
                return result
            except Exception as err:
                # Record the failure on the span, then re-raise so it is
                # not swallowed.
                span.record_exception(err)
                span.set_status(Status(StatusCode.ERROR, str(err)))
                span.end()
                raise

    return traced_method
|
||
|
||
@silently_fail
def _set_response_attributes(span, response):
    """Record token usage and the response payload from an Ollama reply.

    ``response`` is the dict-like body from generate/chat/embeddings; any
    failure here is swallowed by ``@silently_fail``.
    """
    input_tokens = response.get("prompt_eval_count") or 0
    output_tokens = response.get("eval_count") or 0
    total_tokens = input_tokens + output_tokens

    # Only emit usage when the reply actually carried counts (embeddings
    # replies, for instance, have none).
    if total_tokens > 0:
        usage_dict = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": total_tokens,
        }
        set_span_attribute(span, "llm.token.counts", json.dumps(usage_dict))

    # Previously nested under the token-count check; the finish reason is
    # independent of usage, so record it whenever the reply provides one.
    set_span_attribute(span, "llm.finish_reason", response.get("done_reason"))

    # chat replies carry "message"; generate replies carry "response".
    if "message" in response:
        set_span_attribute(span, "llm.responses", json.dumps([response.get("message")]))

    if "response" in response:
        set_span_attribute(
            span, "llm.responses", json.dumps([response.get("response")])
        )
|
||
|
||
@silently_fail
def _set_input_attributes(span, kwargs, attributes):
    """Copy the validated attribute model and the user prompt onto *span*."""
    for attr_name, attr_value in attributes.model_dump(by_alias=True).items():
        set_span_attribute(span, attr_name, attr_value)

    # chat-style calls supply "messages"; generate/embeddings supply "prompt".
    if "messages" in kwargs:
        prompts = json.dumps([kwargs.get("messages", [])])
        set_span_attribute(span, "llm.prompts", prompts)
    if "prompt" in kwargs:
        prompt_message = {"role": "user", "content": kwargs.get("prompt", [])}
        set_span_attribute(span, "llm.prompts", json.dumps([prompt_message]))
|
||
|
||
def _handle_streaming_response(span, response, api):
    """Consume a synchronous stream, accumulating content onto *span*.

    NOTE(review): the stream is consumed eagerly here, so the iterator
    handed back to the caller is already exhausted — confirm whether a
    lazily-yielding wrapper is intended instead.
    """
    accumulated_tokens = None
    if api == "chat":
        accumulated_tokens = {"message": {"content": "", "role": ""}}
    # BUG FIX: this previously tested api == "completion", which never
    # matches any APIS method ("generate" is used), so streaming generate
    # calls crashed on the None accumulator below.
    if api == "generate":
        accumulated_tokens = {"response": ""}

    span.add_event(Event.STREAM_START.value)
    try:
        for chunk in response:
            if api == "chat":
                accumulated_tokens["message"]["content"] += chunk["message"]["content"]
                accumulated_tokens["message"]["role"] = chunk["message"]["role"]
            if api == "generate":
                accumulated_tokens["response"] += chunk["response"]

            _set_response_attributes(span, chunk | accumulated_tokens)
    finally:
        # Finalize the span after processing all chunks, even on error.
        span.add_event(Event.STREAM_END.value)
        span.set_status(StatusCode.OK)
        span.end()

    return response
|
||
|
||
async def _ahandle_streaming_response(span, response, api):
    """Consume an asynchronous stream, accumulating content onto *span*.

    NOTE(review): the stream is consumed eagerly here, so the iterator
    handed back to the caller is already exhausted — confirm whether a
    lazily-yielding wrapper is intended instead.
    """
    accumulated_tokens = None
    if api == "chat":
        accumulated_tokens = {"message": {"content": "", "role": ""}}
    # BUG FIX: this previously tested api == "completion", which never
    # matches any APIS method ("generate" is used), so streaming generate
    # calls crashed on the None accumulator below.
    if api == "generate":
        accumulated_tokens = {"response": ""}

    span.add_event(Event.STREAM_START.value)
    try:
        async for chunk in response:
            if api == "chat":
                accumulated_tokens["message"]["content"] += chunk["message"]["content"]
                accumulated_tokens["message"]["role"] = chunk["message"]["role"]
            if api == "generate":
                accumulated_tokens["response"] += chunk["response"]

            _set_response_attributes(span, chunk | accumulated_tokens)
    finally:
        # Finalize the span after processing all chunks, even on error.
        span.add_event(Event.STREAM_END.value)
        span.set_status(StatusCode.OK)
        span.end()

    return response
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.