Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/examples/cohere_example/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from examples.cohere_example.chat import chat_comp
from examples.cohere_example.chatv2 import chat_v2
from examples.cohere_example.chat_streamv2 import chat_stream_v2
from examples.cohere_example.chat_stream import chat_stream
from examples.cohere_example.tools import tool_calling
from examples.cohere_example.embed import embed
from examples.cohere_example.rerank import rerank
from examples.cohere_example.rerankv2 import rerank_v2
from langtrace_python_sdk import with_langtrace_root_span


class CohereRunner:
    """Runs every Cohere example once, grouped under a single root span."""

    @with_langtrace_root_span("Cohere")
    def run(self):
        # Order matters only in that it mirrors the original example sequence;
        # each callable is an independent demo hitting the Cohere API.
        examples = (
            chat_v2,
            chat_stream_v2,
            chat_comp,
            chat_stream,
            tool_calling,
            embed,
            rerank,
            rerank_v2,
        )
        for example in examples:
            example()
17 changes: 17 additions & 0 deletions src/examples/cohere_example/chat_streamv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import os
from langtrace_python_sdk import langtrace
import cohere

# Initialise Langtrace tracing and a Cohere v2 client once at import time.
langtrace.init(api_key=os.getenv("LANGTRACE_API_KEY"))
co = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))

def chat_stream_v2():
    """Stream a v2 chat completion and print each text delta as it arrives."""
    stream = co.chat_stream(
        model="command-r-plus-08-2024",
        messages=[{"role": "user", "content": "Write a title for a blog post about API design. Only output the title text"}],
    )

    # Only "content-delta" events carry generated text; other event types
    # (message start/end, etc.) are ignored.
    for event in stream:
        if event and event.type == "content-delta":
            print(event.delta.message.content.text)
21 changes: 21 additions & 0 deletions src/examples/cohere_example/chatv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os
from langtrace_python_sdk import langtrace
import cohere

# Initialise Langtrace tracing once at import time.
langtrace.init(api_key=os.getenv("LANGTRACE_API_KEY"))


def chat_v2():
    """Send a single v2 chat request and print the first content block's text."""
    client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))

    message = {
        "role": "user",
        "content": "Write a title for a blog post about API design. Only output the title text.",
    }
    response = client.chat(
        model="command-r-plus-08-2024",
        messages=[message],
    )

    print(response.message.content[0].text)
23 changes: 23 additions & 0 deletions src/examples/cohere_example/rerankv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import os
from langtrace_python_sdk import langtrace
import cohere

# Initialise Langtrace tracing and a Cohere v2 client once at import time.
langtrace.init(api_key=os.getenv("LANGTRACE_API_KEY"))
co = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))

# Example corpus: only two entries actually answer the query; the rest are
# near-miss distractors ("capital" in other senses).
docs = [
    "Carson City is the capital city of the American state of Nevada.",
    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
    "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
    "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
    # Fixed typo: "beforethe" -> "before the".
    "Capital punishment has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
]


def rerank_v2():
    """Rerank the example docs against a fixed query with the v2 client.

    Prints the raw rerank response (top 3 documents by relevance).
    """
    response = co.rerank(
        model="rerank-v3.5",
        query="What is the capital of the United States?",
        documents=docs,
        top_n=3,
    )
    print(response)
22 changes: 21 additions & 1 deletion src/langtrace_python_sdk/constants/instrumentation/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,39 @@
"METHOD": "cohere.client.chat",
"ENDPOINT": "/v1/chat",
},
"CHAT_CREATE_V2": {
"URL": "https://api.cohere.ai",
"METHOD": "cohere.client_v2.chat",
"ENDPOINT": "/v2/chat",
},
"EMBED": {
"URL": "https://api.cohere.ai",
"METHOD": "cohere.client.embed",
"ENDPOINT": "/v1/embed",
},
"EMBED_V2": {
"URL": "https://api.cohere.ai",
"METHOD": "cohere.client_v2.embed",
"ENDPOINT": "/v2/embed",
},
"CHAT_STREAM": {
"URL": "https://api.cohere.ai",
"METHOD": "cohere.client.chat_stream",
"ENDPOINT": "/v1/messages",
"ENDPOINT": "/v1/chat",
},
"CHAT_STREAM_V2": {
"URL": "https://api.cohere.ai",
"METHOD": "cohere.client_v2.chat_stream",
"ENDPOINT": "/v2/chat",
},
"RERANK": {
"URL": "https://api.cohere.ai",
"METHOD": "cohere.client.rerank",
"ENDPOINT": "/v1/rerank",
},
"RERANK_V2": {
"URL": "https://api.cohere.ai",
"METHOD": "cohere.client_v2.rerank",
"ENDPOINT": "/v2/rerank",
},
}
25 changes: 25 additions & 0 deletions src/langtrace_python_sdk/instrumentation/cohere/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from langtrace_python_sdk.instrumentation.cohere.patch import (
chat_create,
chat_create_v2,
chat_stream,
embed,
rerank,
Expand All @@ -48,6 +49,18 @@ def _instrument(self, **kwargs):
chat_create("cohere.client.chat", version, tracer),
)

wrap_function_wrapper(
"cohere.client_v2",
"ClientV2.chat",
chat_create_v2("cohere.client_v2.chat", version, tracer),
)

wrap_function_wrapper(
"cohere.client_v2",
"ClientV2.chat_stream",
chat_create_v2("cohere.client_v2.chat", version, tracer, stream=True),
)

wrap_function_wrapper(
"cohere.client",
"Client.chat_stream",
Expand All @@ -60,12 +73,24 @@ def _instrument(self, **kwargs):
embed("cohere.client.embed", version, tracer),
)

wrap_function_wrapper(
"cohere.client_v2",
"ClientV2.embed",
embed("cohere.client.embed", version, tracer, v2=True),
)

wrap_function_wrapper(
"cohere.client",
"Client.rerank",
rerank("cohere.client.rerank", version, tracer),
)

wrap_function_wrapper(
"cohere.client_v2",
"ClientV2.rerank",
rerank("cohere.client.rerank", version, tracer, v2=True),
)

def _instrument_module(self, module_name):
pass

Expand Down
114 changes: 106 additions & 8 deletions src/langtrace_python_sdk/instrumentation/cohere/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
get_span_name,
set_event_completion,
set_usage_attributes,
StreamWrapper
)
from langtrace.trace_attributes import Event, LLMSpanAttributes
from langtrace_python_sdk.utils import set_span_attribute
Expand All @@ -38,7 +39,7 @@
from langtrace.trace_attributes import SpanAttributes


def rerank(original_method, version, tracer):
def rerank(original_method, version, tracer, v2=False):
"""Wrap the `rerank` method."""

def traced_method(wrapped, instance, args, kwargs):
Expand All @@ -49,8 +50,8 @@ def traced_method(wrapped, instance, args, kwargs):
**get_llm_request_attributes(kwargs, operation_name="rerank"),
**get_llm_url(instance),
SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
SpanAttributes.LLM_URL: APIS["RERANK"]["URL"],
SpanAttributes.LLM_PATH: APIS["RERANK"]["ENDPOINT"],
SpanAttributes.LLM_URL: APIS["RERANK" if not v2 else "RERANK_V2"]["URL"],
SpanAttributes.LLM_PATH: APIS["RERANK" if not v2 else "RERANK_V2"]["ENDPOINT"],
SpanAttributes.LLM_REQUEST_DOCUMENTS: json.dumps(
kwargs.get("documents"), cls=datetime_encoder
),
Expand All @@ -61,7 +62,7 @@ def traced_method(wrapped, instance, args, kwargs):
attributes = LLMSpanAttributes(**span_attributes)

span = tracer.start_span(
name=get_span_name(APIS["RERANK"]["METHOD"]), kind=SpanKind.CLIENT
name=get_span_name(APIS["RERANK" if not v2 else "RERANK_V2"]["METHOD"]), kind=SpanKind.CLIENT
)
for field, value in attributes.model_dump(by_alias=True).items():
set_span_attribute(span, field, value)
Expand Down Expand Up @@ -119,7 +120,7 @@ def traced_method(wrapped, instance, args, kwargs):
return traced_method


def embed(original_method, version, tracer):
def embed(original_method, version, tracer, v2=False):
"""Wrap the `embed` method."""

def traced_method(wrapped, instance, args, kwargs):
Expand All @@ -129,8 +130,8 @@ def traced_method(wrapped, instance, args, kwargs):
**get_langtrace_attributes(version, service_provider),
**get_llm_request_attributes(kwargs, operation_name="embed"),
**get_llm_url(instance),
SpanAttributes.LLM_URL: APIS["EMBED"]["URL"],
SpanAttributes.LLM_PATH: APIS["EMBED"]["ENDPOINT"],
SpanAttributes.LLM_URL: APIS["EMBED" if not v2 else "EMBED_V2"]["URL"],
SpanAttributes.LLM_PATH: APIS["EMBED" if not v2 else "EMBED_V2"]["ENDPOINT"],
SpanAttributes.LLM_REQUEST_EMBEDDING_INPUTS: json.dumps(
kwargs.get("texts")
),
Expand All @@ -143,7 +144,7 @@ def traced_method(wrapped, instance, args, kwargs):
attributes = LLMSpanAttributes(**span_attributes)

span = tracer.start_span(
name=get_span_name(APIS["EMBED"]["METHOD"]),
name=get_span_name(APIS["EMBED" if not v2 else "EMBED_V2"]["METHOD"]),
kind=SpanKind.CLIENT,
)
for field, value in attributes.model_dump(by_alias=True).items():
Expand Down Expand Up @@ -343,6 +344,103 @@ def traced_method(wrapped, instance, args, kwargs):
return traced_method


def chat_create_v2(original_method, version, tracer, stream=False):
    """Wrap the `chat_create` method for Cohere API v2.

    Returns a wrapt-style wrapper that opens a CLIENT span around the call,
    records request/response attributes, and re-raises any error after
    tagging the span. With ``stream=True`` the raw result is wrapped in a
    ``StreamWrapper`` instead of being inspected here.
    """

    def traced_method(wrapped, instance, args, kwargs):
        service_provider = SERVICE_PROVIDERS["COHERE"]

        # A v1-style `preamble` kwarg is folded into the message list as a
        # leading system message so prompts are captured uniformly.
        messages = kwargs.get("messages", [])
        if kwargs.get("preamble"):
            messages = [{"role": "system", "content": kwargs["preamble"]}] + messages

        span_attributes = {
            **get_langtrace_attributes(version, service_provider),
            **get_llm_request_attributes(kwargs, prompts=messages),
            **get_llm_url(instance),
            SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
            SpanAttributes.LLM_URL: APIS["CHAT_CREATE_V2"]["URL"],
            SpanAttributes.LLM_PATH: APIS["CHAT_CREATE_V2"]["ENDPOINT"],
            **get_extra_attributes(),
        }

        attributes = LLMSpanAttributes(**span_attributes)

        # Optional request kwargs are copied onto the attributes object only
        # when present. Scalars are stringified; list/dict-valued kwargs
        # (connectors, tools, tool_results) are JSON-encoded.
        # NOTE(review): assumes LLMSpanAttributes accepts these `llm_*` /
        # `conversation_id` fields via setattr — confirm against the model.
        for attr_name in ["max_input_tokens", "conversation_id", "connectors", "tools", "tool_results"]:
            value = kwargs.get(attr_name)
            if value is not None:
                if attr_name == "max_input_tokens":
                    attributes.llm_max_input_tokens = str(value)
                elif attr_name == "conversation_id":
                    attributes.conversation_id = value
                else:
                    setattr(attributes, f"llm_{attr_name}", json.dumps(value))

        span = tracer.start_span(
            name=get_span_name(APIS["CHAT_CREATE_V2"]["METHOD"]),
            kind=SpanKind.CLIENT
        )

        # Copy every non-None attribute onto the span before the call.
        for field, value in attributes.model_dump(by_alias=True).items():
            set_span_attribute(span, field, value)

        try:
            result = wrapped(*args, **kwargs)

            if stream:
                # Streaming: hand the span to StreamWrapper; presumably it
                # records deltas/usage and ends the span when the stream is
                # exhausted — TODO confirm StreamWrapper's contract.
                return StreamWrapper(
                    result,
                    span,
                    tool_calls=kwargs.get("tools") is not None,
                )
            else:
                # Non-streaming: pull id, completion text, tool calls and
                # billed-unit usage off the response, defensively guarding
                # every attribute since response shapes vary by SDK version.
                if hasattr(result, "id") and result.id is not None:
                    span.set_attribute(SpanAttributes.LLM_GENERATION_ID, result.id)
                    span.set_attribute(SpanAttributes.LLM_RESPONSE_ID, result.id)

                # Only the first content block's text is recorded as the
                # completion event.
                if (hasattr(result, "message") and
                    hasattr(result.message, "content") and
                    len(result.message.content) > 0 and
                    hasattr(result.message.content[0], "text") and
                    result.message.content[0].text is not None and
                    result.message.content[0].text != ""):
                    responses = [{
                        "role": result.message.role,
                        "content": result.message.content[0].text
                    }]
                    set_event_completion(span, responses)
                # NOTE(review): tool *calls* are stored under the
                # LLM_TOOL_RESULTS attribute — verify that is the intended key.
                if hasattr(result, "tool_calls") and result.tool_calls is not None:
                    tool_calls = [tool_call.json() for tool_call in result.tool_calls]
                    span.set_attribute(
                        SpanAttributes.LLM_TOOL_RESULTS,
                        json.dumps(tool_calls)
                    )
                # Usage is taken from billed_units; missing counters default
                # to 0 and total is derived as input + output.
                if hasattr(result, "usage") and result.usage is not None:
                    if (hasattr(result.usage, "billed_units") and
                        result.usage.billed_units is not None):
                        usage = result.usage.billed_units
                        for metric, value in {
                            "input": usage.input_tokens or 0,
                            "output": usage.output_tokens or 0,
                            "total": (usage.input_tokens or 0) + (usage.output_tokens or 0),
                        }.items():
                            span.set_attribute(
                                f"gen_ai.usage.{metric}_tokens",
                                int(value)
                            )
                span.set_status(StatusCode.OK)
                span.end()
                return result

        except Exception as error:
            # Record, mark the span as errored, end it, then propagate.
            span.record_exception(error)
            span.set_status(Status(StatusCode.ERROR, str(error)))
            span.end()
            raise

    return traced_method


def chat_stream(original_method, version, tracer):
"""Wrap the `messages_stream` method."""

Expand Down
25 changes: 23 additions & 2 deletions src/langtrace_python_sdk/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,19 @@ def build_streaming_response(self, chunk):
if hasattr(chunk, "text") and chunk.text is not None:
content = [chunk.text]

# CohereV2
if (hasattr(chunk, "delta") and
chunk.delta is not None and
hasattr(chunk.delta, "message") and
chunk.delta.message is not None and
hasattr(chunk.delta.message, "content") and
chunk.delta.message.content is not None and
hasattr(chunk.delta.message.content, "text") and
chunk.delta.message.content.text is not None):
content = [chunk.delta.message.content.text]

# Anthropic
if hasattr(chunk, "delta") and chunk.delta is not None:
if hasattr(chunk, "delta") and chunk.delta is not None and not hasattr(chunk.delta, "message"):
content = [chunk.delta.text] if hasattr(chunk.delta, "text") else []

if isinstance(chunk, dict):
Expand All @@ -408,7 +419,17 @@ def set_usage_attributes(self, chunk):

# Anthropic & OpenAI
if hasattr(chunk, "type") and chunk.type == "message_start":
self.prompt_tokens = chunk.message.usage.input_tokens
if hasattr(chunk.message, "usage") and chunk.message.usage is not None:
self.prompt_tokens = chunk.message.usage.input_tokens

# CohereV2
if hasattr(chunk, "type") and chunk.type == "message-end":
if (hasattr(chunk, "delta") and chunk.delta is not None and
hasattr(chunk.delta, "usage") and chunk.delta.usage is not None and
hasattr(chunk.delta.usage, "billed_units") and chunk.delta.usage.billed_units is not None):
usage = chunk.delta.usage.billed_units
self.completion_tokens = int(usage.output_tokens)
self.prompt_tokens = int(usage.input_tokens)

if hasattr(chunk, "usage") and chunk.usage is not None:
if hasattr(chunk.usage, "output_tokens"):
Expand Down
2 changes: 1 addition & 1 deletion src/langtrace_python_sdk/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.3.14"
__version__ = "3.3.15"
Loading
Loading