Merge pull request #204 from Scale3-Labs/ali/ollama
Ollama Instrumentation
alizenhom committed Jun 13, 2024
2 parents 046087f + 9e02339 commit cbde43b
Showing 14 changed files with 365 additions and 2 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -49,6 +49,7 @@ dev = [
"cohere",
"qdrant_client",
"weaviate-client",
"ollama"
]

test = [
14 changes: 14 additions & 0 deletions src/examples/ollama_example/__init__.py
@@ -0,0 +1,14 @@
from .basic import chat, async_chat, async_generate, generate, embed, async_embed
from langtrace_python_sdk import with_langtrace_root_span
import asyncio


class OllamaRunner:
@with_langtrace_root_span("OllamaRunner")
def run(self):
chat()
generate()
embed()
asyncio.run(async_chat())
asyncio.run(async_generate())
asyncio.run(async_embed())
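
Not part of the commit: a minimal sketch of driving the example suite above. It assumes a local Ollama server is running with the llama3 model already pulled (e.g. ollama pull llama3), and that src/ is on sys.path so the package imports as examples.ollama_example.

# Hypothetical usage; the import path depends on how src/ is laid out.
from examples.ollama_example import OllamaRunner

OllamaRunner().run()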
50 changes: 50 additions & 0 deletions src/examples/ollama_example/basic.py
@@ -0,0 +1,50 @@
from langtrace_python_sdk import langtrace, with_langtrace_root_span
import ollama
from ollama import AsyncClient
from dotenv import load_dotenv

load_dotenv()

langtrace.init(write_spans_to_console=False)


def chat():
response = ollama.chat(
model="llama3",
messages=[
{
"role": "user",
"content": "hi",
},
],
stream=True,
)

return response


async def async_chat():
message = {"role": "user", "content": "Why is the sky blue?"}
return await AsyncClient().chat(model="llama3", messages=[message])


def generate():
return ollama.generate(model="llama3", prompt="Why is the sky blue?")


async def async_generate():
    return await AsyncClient().generate(model="llama3", prompt="Why is the sky blue?")


def embed():
return ollama.embeddings(
model="llama3",
prompt="cat",
)


async def async_embed():
return await AsyncClient().embeddings(
model="llama3",
prompt="cat",
)
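
Worth noting (not in the diff): because chat() above passes stream=True, it returns an iterator of chunks rather than a finished response, so a caller would drain it along these lines. The chunk["message"]["content"] shape matches what the streaming patch below accumulates.

# Sketch of consuming the streamed chat() response above.
for chunk in chat():
    print(chunk["message"]["content"], end="", flush=True)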
1 change: 1 addition & 0 deletions src/examples/pinecone_example/basic.py
@@ -13,6 +13,7 @@
with_additional_attributes,
)
from langtrace_python_sdk.utils.with_root_span import SendUserFeedback
from opentelemetry.sdk.trace.export import ConsoleSpanExporter

_ = load_dotenv(find_dotenv())
langtrace.init()
1 change: 1 addition & 0 deletions src/langtrace_python_sdk/constants/instrumentation/common.py
@@ -23,6 +23,7 @@
"PPLX": "Perplexity",
"QDRANT": "Qdrant",
"WEAVIATE": "Weaviate",
"OLLAMA": "Ollama",
}

LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY = "langtrace_additional_attributes"
8 changes: 8 additions & 0 deletions src/langtrace_python_sdk/constants/instrumentation/ollama.py
@@ -0,0 +1,8 @@
APIS = {
    "GENERATE": {
        "METHOD": "generate",
    },
    "CHAT": {"METHOD": "chat"},
    "EMBEDDINGS": {"METHOD": "embeddings"},
}
2 changes: 2 additions & 0 deletions src/langtrace_python_sdk/instrumentation/__init__.py
@@ -11,6 +11,7 @@
from .pinecone import PineconeInstrumentation
from .qdrant import QdrantInstrumentation
from .weaviate import WeaviateInstrumentation
from .ollama import OllamaInstrumentor

__all__ = [
"AnthropicInstrumentation",
@@ -26,4 +27,5 @@
"PineconeInstrumentation",
"QdrantInstrumentation",
"WeaviateInstrumentation",
"OllamaInstrumentor",
]
3 changes: 3 additions & 0 deletions src/langtrace_python_sdk/instrumentation/ollama/__init__.py
@@ -0,0 +1,3 @@
from .instrumentation import OllamaInstrumentor

__all__ = ["OllamaInstrumentor"]
58 changes: 58 additions & 0 deletions src/langtrace_python_sdk/instrumentation/ollama/instrumentation.py
@@ -0,0 +1,58 @@
"""
Copyright (c) 2024 Scale3 Labs
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.trace import get_tracer
from wrapt import wrap_function_wrapper as _W
from typing import Collection
from importlib_metadata import version as v
from langtrace_python_sdk.constants.instrumentation.ollama import APIS
from .patch import generic_patch, ageneric_patch


class OllamaInstrumentor(BaseInstrumentor):
"""
The OllamaInstrumentor class represents the Ollama instrumentation"""

def instrumentation_dependencies(self) -> Collection[str]:
return ["ollama >= 0.2.0, < 1"]

def _instrument(self, **kwargs):
tracer_provider = kwargs.get("tracer_provider")
tracer = get_tracer(__name__, "", tracer_provider)
version = v("ollama")
for operation_name, details in APIS.items():
operation = details["METHOD"]
            # Patch the sync client, the async client, and the module-level
            # function for each Ollama API in APIS.
_W(
"ollama._client",
f"Client.{operation}",
generic_patch(operation_name, version, tracer),
)

_W(
"ollama._client",
f"AsyncClient.{operation}",
ageneric_patch(operation_name, version, tracer),
)
_W(
"ollama",
f"{operation}",
generic_patch(operation_name, version, tracer),
)

def _uninstrument(self, **kwargs):
pass
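
An aside, not in the diff: since OllamaInstrumentor subclasses BaseInstrumentor, it can also be wired up manually when langtrace.init() is not used. A minimal sketch, assuming the standard OpenTelemetry SDK is installed:

from opentelemetry.sdk.trace import TracerProvider
from langtrace_python_sdk.instrumentation import OllamaInstrumentor

provider = TracerProvider()
# _instrument() above reads tracer_provider from kwargs.
OllamaInstrumentor().instrument(tracer_provider=provider)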
215 changes: 215 additions & 0 deletions src/langtrace_python_sdk/instrumentation/ollama/patch.py
@@ -0,0 +1,215 @@
from langtrace_python_sdk.constants.instrumentation.ollama import APIS
from importlib_metadata import version as v
from langtrace_python_sdk.constants import LANGTRACE_SDK_NAME
from langtrace_python_sdk.utils import set_span_attribute
from langtrace_python_sdk.utils.silently_fail import silently_fail
from langtrace_python_sdk.constants.instrumentation.common import (
LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY,
SERVICE_PROVIDERS,
)
from opentelemetry import baggage
from langtrace.trace_attributes import LLMSpanAttributes, Event
from opentelemetry.trace import SpanKind
import json
from opentelemetry.trace.status import Status, StatusCode


def generic_patch(operation_name, version, tracer):
def traced_method(wrapped, instance, args, kwargs):
base_url = (
str(instance._client._base_url)
if hasattr(instance, "_client") and hasattr(instance._client, "_base_url")
else ""
)
api = APIS[operation_name]
service_provider = SERVICE_PROVIDERS["OLLAMA"]
extra_attributes = baggage.get_baggage(LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY)
span_attributes = {
"langtrace.sdk.name": "langtrace-python-sdk",
"langtrace.service.name": service_provider,
"langtrace.service.type": "llm",
"langtrace.service.version": version,
"langtrace.version": v(LANGTRACE_SDK_NAME),
"llm.model": kwargs.get("model"),
"llm.stream": kwargs.get("stream"),
"url.full": base_url,
"llm.api": api["METHOD"],
**(extra_attributes if extra_attributes is not None else {}),
}

attributes = LLMSpanAttributes(**span_attributes)
with tracer.start_as_current_span(
f'ollama.{api["METHOD"]}', kind=SpanKind.CLIENT
) as span:
_set_input_attributes(span, kwargs, attributes)

try:
result = wrapped(*args, **kwargs)
if result:
                    if span.is_recording():
                        if kwargs.get("stream"):
return _handle_streaming_response(
span, result, api["METHOD"]
)

_set_response_attributes(span, result)
span.set_status(Status(StatusCode.OK))

span.end()
return result

except Exception as err:
# Record the exception in the span
span.record_exception(err)

# Set the span status to indicate an error
span.set_status(Status(StatusCode.ERROR, str(err)))

# Reraise the exception to ensure it's not swallowed
raise

return traced_method


def ageneric_patch(operation_name, version, tracer):
    async def traced_method(wrapped, instance, args, kwargs):
        # Mirror the sync path: recover the server URL from the wrapped client.
        base_url = (
            str(instance._client._base_url)
            if hasattr(instance, "_client") and hasattr(instance._client, "_base_url")
            else ""
        )
        api = APIS[operation_name]
service_provider = SERVICE_PROVIDERS["OLLAMA"]
extra_attributes = baggage.get_baggage(LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY)
span_attributes = {
"langtrace.sdk.name": "langtrace-python-sdk",
"langtrace.service.name": service_provider,
"url.full": "",
"llm.api": "",
"langtrace.service.type": "llm",
"langtrace.service.version": version,
"langtrace.version": v(LANGTRACE_SDK_NAME),
"llm.model": kwargs.get("model"),
"llm.stream": kwargs.get("stream"),
**(extra_attributes if extra_attributes is not None else {}),
}

attributes = LLMSpanAttributes(**span_attributes)
        with tracer.start_as_current_span(
            f'ollama.{api["METHOD"]}', kind=SpanKind.CLIENT
        ) as span:
_set_input_attributes(span, kwargs, attributes)
try:
result = await wrapped(*args, **kwargs)
if result:
if span.is_recording():
if kwargs.get("stream"):
return _ahandle_streaming_response(
span, result, api["METHOD"]
)

_set_response_attributes(span, result)
span.set_status(Status(StatusCode.OK))
span.end()
return result
except Exception as err:
# Record the exception in the span
span.record_exception(err)

# Set the span status to indicate an error
span.set_status(Status(StatusCode.ERROR, str(err)))

# Reraise the exception to ensure it's not swallowed
raise

return traced_method


@silently_fail
def _set_response_attributes(span, response):
    input_tokens = response.get("prompt_eval_count") or 0
output_tokens = response.get("eval_count") or 0
total_tokens = input_tokens + output_tokens
usage_dict = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": total_tokens,
}

if total_tokens > 0:
set_span_attribute(span, "llm.token.counts", json.dumps(usage_dict))
set_span_attribute(span, "llm.finish_reason", response.get("done_reason"))

if "message" in response:
set_span_attribute(span, "llm.responses", json.dumps([response.get("message")]))

if "response" in response:
set_span_attribute(
span, "llm.responses", json.dumps([response.get("response")])
)


@silently_fail
def _set_input_attributes(span, kwargs, attributes):
for field, value in attributes.model_dump(by_alias=True).items():
set_span_attribute(span, field, value)

if "messages" in kwargs:
set_span_attribute(
span,
"llm.prompts",
json.dumps([kwargs.get("messages", [])]),
)
if "prompt" in kwargs:
set_span_attribute(
span,
"llm.prompts",
json.dumps([{"role": "user", "content": kwargs.get("prompt", [])}]),
)


def _handle_streaming_response(span, response, api):
accumulated_tokens = None
if api == "chat":
accumulated_tokens = {"message": {"content": "", "role": ""}}
if api == "completion":
accumulated_tokens = {"response": ""}

span.add_event(Event.STREAM_START.value)
try:
for chunk in response:
if api == "chat":
accumulated_tokens["message"]["content"] += chunk["message"]["content"]
accumulated_tokens["message"]["role"] = chunk["message"]["role"]
if api == "generate":
accumulated_tokens["response"] += chunk["response"]

_set_response_attributes(span, chunk | accumulated_tokens)
finally:
# Finalize span after processing all chunks
span.add_event(Event.STREAM_END.value)
span.set_status(StatusCode.OK)
span.end()

return response


async def _ahandle_streaming_response(span, response, api):
accumulated_tokens = None
if api == "chat":
accumulated_tokens = {"message": {"content": "", "role": ""}}
if api == "completion":
accumulated_tokens = {"response": ""}

span.add_event(Event.STREAM_START.value)
try:
async for chunk in response:
if api == "chat":
accumulated_tokens["message"]["content"] += chunk["message"]["content"]
accumulated_tokens["message"]["role"] = chunk["message"]["role"]
if api == "generate":
accumulated_tokens["response"] += chunk["response"]

_set_response_attributes(span, chunk | accumulated_tokens)
finally:
# Finalize span after processing all chunks
span.add_event(Event.STREAM_END.value)
span.set_status(StatusCode.OK)
span.end()

return response
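
For illustration (not in the diff): _set_response_attributes above derives token usage from Ollama's prompt_eval_count and eval_count fields. A small sketch with a hypothetical response payload shows the llm.token.counts value a span would carry:

import json

# Hypothetical payload; field names match what _set_response_attributes reads.
response = {"prompt_eval_count": 26, "eval_count": 298, "done_reason": "stop"}

input_tokens = response.get("prompt_eval_count") or 0
output_tokens = response.get("eval_count") or 0
print(json.dumps({
    "input_tokens": input_tokens,
    "output_tokens": output_tokens,
    "total_tokens": input_tokens + output_tokens,
}))
# -> {"input_tokens": 26, "output_tokens": 298, "total_tokens": 324}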
src/langtrace_python_sdk/instrumentation/pinecone/instrumentation.py
@@ -18,7 +18,6 @@
import logging
from typing import Collection

from langtrace.trace_attributes import PineconeMethods
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.trace import get_tracer
from wrapt import wrap_function_wrapper