diff --git a/src/examples/cohere_example/__init__.py b/src/examples/cohere_example/__init__.py
index 5610cf4b..f36a7e7d 100644
--- a/src/examples/cohere_example/__init__.py
+++ b/src/examples/cohere_example/__init__.py
@@ -1,8 +1,11 @@
 from examples.cohere_example.chat import chat_comp
+from examples.cohere_example.chatv2 import chat_v2
+from examples.cohere_example.chat_streamv2 import chat_stream_v2
 from examples.cohere_example.chat_stream import chat_stream
 from examples.cohere_example.tools import tool_calling
 from examples.cohere_example.embed import embed
 from examples.cohere_example.rerank import rerank
+from examples.cohere_example.rerankv2 import rerank_v2
 from langtrace_python_sdk import with_langtrace_root_span
 
 
@@ -10,8 +13,11 @@ class CohereRunner:
 
     @with_langtrace_root_span("Cohere")
     def run(self):
+        chat_v2()
+        chat_stream_v2()
         chat_comp()
         chat_stream()
         tool_calling()
         embed()
         rerank()
+        rerank_v2()
diff --git a/src/examples/cohere_example/chat_streamv2.py b/src/examples/cohere_example/chat_streamv2.py
new file mode 100644
index 00000000..2bce996b
--- /dev/null
+++ b/src/examples/cohere_example/chat_streamv2.py
@@ -0,0 +1,17 @@
+import os
+from langtrace_python_sdk import langtrace
+import cohere
+
+langtrace.init(api_key=os.getenv("LANGTRACE_API_KEY"))
+co = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
+
+def chat_stream_v2():
+    res = co.chat_stream(
+        model="command-r-plus-08-2024",
+        messages=[{"role": "user", "content": "Write a title for a blog post about API design. Only output the title text"}],
+    )
+
+    for event in res:
+        if event:
+            if event.type == "content-delta":
+                print(event.delta.message.content.text)
\ No newline at end of file
diff --git a/src/examples/cohere_example/chatv2.py b/src/examples/cohere_example/chatv2.py
new file mode 100644
index 00000000..26f59745
--- /dev/null
+++ b/src/examples/cohere_example/chatv2.py
@@ -0,0 +1,21 @@
+import os
+from langtrace_python_sdk import langtrace
+import cohere
+
+langtrace.init(api_key=os.getenv("LANGTRACE_API_KEY"))
+
+
+def chat_v2():
+    co = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
+
+    res = co.chat(
+        model="command-r-plus-08-2024",
+        messages=[
+            {
+                "role": "user",
+                "content": "Write a title for a blog post about API design. Only output the title text.",
+            }
+        ],
+    )
+
+    print(res.message.content[0].text)
diff --git a/src/examples/cohere_example/rerankv2.py b/src/examples/cohere_example/rerankv2.py
new file mode 100644
index 00000000..55e262cc
--- /dev/null
+++ b/src/examples/cohere_example/rerankv2.py
@@ -0,0 +1,23 @@
+import os
+from langtrace_python_sdk import langtrace
+import cohere
+
+langtrace.init(api_key=os.getenv("LANGTRACE_API_KEY"))
+co = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
+
+docs = [
+    "Carson City is the capital city of the American state of Nevada.",
+    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
+    "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
+    "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
+    "Capital punishment has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
+]
+
+def rerank_v2():
+    response = co.rerank(
+        model="rerank-v3.5",
+        query="What is the capital of the United States?",
+        documents=docs,
+        top_n=3,
+    )
+    print(response)
diff --git a/src/langtrace_python_sdk/constants/instrumentation/cohere.py b/src/langtrace_python_sdk/constants/instrumentation/cohere.py
index cc38254b..650fa75c 100644
--- a/src/langtrace_python_sdk/constants/instrumentation/cohere.py
+++ b/src/langtrace_python_sdk/constants/instrumentation/cohere.py
@@ -4,19 +4,39 @@
         "METHOD": "cohere.client.chat",
         "ENDPOINT": "/v1/chat",
     },
+    "CHAT_CREATE_V2": {
+        "URL": "https://api.cohere.ai",
+        "METHOD": "cohere.client_v2.chat",
+        "ENDPOINT": "/v2/chat",
+    },
     "EMBED": {
         "URL": "https://api.cohere.ai",
         "METHOD": "cohere.client.embed",
         "ENDPOINT": "/v1/embed",
     },
+    "EMBED_V2": {
+        "URL": "https://api.cohere.ai",
+        "METHOD": "cohere.client_v2.embed",
+        "ENDPOINT": "/v2/embed",
+    },
     "CHAT_STREAM": {
         "URL": "https://api.cohere.ai",
         "METHOD": "cohere.client.chat_stream",
-        "ENDPOINT": "/v1/messages",
+        "ENDPOINT": "/v1/chat",
+    },
+    "CHAT_STREAM_V2": {
+        "URL": "https://api.cohere.ai",
+        "METHOD": "cohere.client_v2.chat_stream",
+        "ENDPOINT": "/v2/chat",
     },
     "RERANK": {
         "URL": "https://api.cohere.ai",
         "METHOD": "cohere.client.rerank",
         "ENDPOINT": "/v1/rerank",
     },
+    "RERANK_V2": {
+        "URL": "https://api.cohere.ai",
+        "METHOD": "cohere.client_v2.rerank",
+        "ENDPOINT": "/v2/rerank",
+    },
 }
diff --git a/src/langtrace_python_sdk/instrumentation/cohere/instrumentation.py b/src/langtrace_python_sdk/instrumentation/cohere/instrumentation.py
index df433c13..414672ce 100644
--- a/src/langtrace_python_sdk/instrumentation/cohere/instrumentation.py
+++ b/src/langtrace_python_sdk/instrumentation/cohere/instrumentation.py
@@ -23,6 +23,7 @@
 
 from langtrace_python_sdk.instrumentation.cohere.patch import (
     chat_create,
+    chat_create_v2,
     chat_stream,
     embed,
     rerank,
@@ -48,6 +49,18 @@ def _instrument(self, **kwargs):
             chat_create("cohere.client.chat", version, tracer),
         )
 
+        wrap_function_wrapper(
+            "cohere.client_v2",
+            "ClientV2.chat",
+            chat_create_v2("cohere.client_v2.chat", version, tracer),
+        )
+
+        wrap_function_wrapper(
+            "cohere.client_v2",
+            "ClientV2.chat_stream",
+            chat_create_v2("cohere.client_v2.chat", version, tracer, stream=True),
+        )
+
         wrap_function_wrapper(
             "cohere.client",
             "Client.chat_stream",
@@ -60,12 +73,24 @@ def _instrument(self, **kwargs):
             embed("cohere.client.embed", version, tracer),
         )
 
+        wrap_function_wrapper(
+            "cohere.client_v2",
+            "ClientV2.embed",
+            embed("cohere.client.embed", version, tracer, v2=True),
+        )
+
         wrap_function_wrapper(
             "cohere.client",
             "Client.rerank",
             rerank("cohere.client.rerank", version, tracer),
         )
 
+        wrap_function_wrapper(
+            "cohere.client_v2",
+            "ClientV2.rerank",
+            rerank("cohere.client.rerank", version, tracer, v2=True),
+        )
+
     def _instrument_module(self, module_name):
         pass
 
diff --git a/src/langtrace_python_sdk/instrumentation/cohere/patch.py b/src/langtrace_python_sdk/instrumentation/cohere/patch.py
index 38908c3d..8b9a8e53 100644
--- a/src/langtrace_python_sdk/instrumentation/cohere/patch.py
+++ b/src/langtrace_python_sdk/instrumentation/cohere/patch.py
@@ -24,6 +24,7 @@
     get_span_name,
     set_event_completion,
     set_usage_attributes,
+    StreamWrapper
 )
 from langtrace.trace_attributes import Event, LLMSpanAttributes
 from langtrace_python_sdk.utils import set_span_attribute
@@ -38,7 +39,7 @@
 from langtrace.trace_attributes import SpanAttributes
 
 
-def rerank(original_method, version, tracer):
+def rerank(original_method, version, tracer, v2=False):
     """Wrap the `rerank` method."""
 
     def traced_method(wrapped, instance, args, kwargs):
@@ -49,8 +50,8 @@ def traced_method(wrapped, instance, args, kwargs):
             **get_llm_request_attributes(kwargs, operation_name="rerank"),
             **get_llm_url(instance),
             SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
-            SpanAttributes.LLM_URL: APIS["RERANK"]["URL"],
-            SpanAttributes.LLM_PATH: APIS["RERANK"]["ENDPOINT"],
+            SpanAttributes.LLM_URL: APIS["RERANK" if not v2 else "RERANK_V2"]["URL"],
+            SpanAttributes.LLM_PATH: APIS["RERANK" if not v2 else "RERANK_V2"]["ENDPOINT"],
             SpanAttributes.LLM_REQUEST_DOCUMENTS: json.dumps(
                 kwargs.get("documents"), cls=datetime_encoder
             ),
@@ -61,7 +62,7 @@ def traced_method(wrapped, instance, args, kwargs):
 
         attributes = LLMSpanAttributes(**span_attributes)
         span = tracer.start_span(
-            name=get_span_name(APIS["RERANK"]["METHOD"]), kind=SpanKind.CLIENT
+            name=get_span_name(APIS["RERANK" if not v2 else "RERANK_V2"]["METHOD"]), kind=SpanKind.CLIENT
         )
         for field, value in attributes.model_dump(by_alias=True).items():
             set_span_attribute(span, field, value)
@@ -119,7 +120,7 @@ def traced_method(wrapped, instance, args, kwargs):
     return traced_method
 
 
-def embed(original_method, version, tracer):
+def embed(original_method, version, tracer, v2=False):
     """Wrap the `embed` method."""
 
     def traced_method(wrapped, instance, args, kwargs):
@@ -129,8 +130,8 @@ def traced_method(wrapped, instance, args, kwargs):
             **get_langtrace_attributes(version, service_provider),
             **get_llm_request_attributes(kwargs, operation_name="embed"),
             **get_llm_url(instance),
-            SpanAttributes.LLM_URL: APIS["EMBED"]["URL"],
-            SpanAttributes.LLM_PATH: APIS["EMBED"]["ENDPOINT"],
+            SpanAttributes.LLM_URL: APIS["EMBED" if not v2 else "EMBED_V2"]["URL"],
+            SpanAttributes.LLM_PATH: APIS["EMBED" if not v2 else "EMBED_V2"]["ENDPOINT"],
             SpanAttributes.LLM_REQUEST_EMBEDDING_INPUTS: json.dumps(
                 kwargs.get("texts")
             ),
@@ -143,7 +144,7 @@ def traced_method(wrapped, instance, args, kwargs):
 
         attributes = LLMSpanAttributes(**span_attributes)
         span = tracer.start_span(
-            name=get_span_name(APIS["EMBED"]["METHOD"]),
+            name=get_span_name(APIS["EMBED" if not v2 else "EMBED_V2"]["METHOD"]),
             kind=SpanKind.CLIENT,
         )
         for field, value in attributes.model_dump(by_alias=True).items():
@@ -343,6 +344,103 @@ def traced_method(wrapped, instance, args, kwargs):
     return traced_method
 
 
+def chat_create_v2(original_method, version, tracer, stream=False):
+    """Wrap the `chat_create` method for Cohere API v2."""
+
+    def traced_method(wrapped, instance, args, kwargs):
+        service_provider = SERVICE_PROVIDERS["COHERE"]
+
+        messages = kwargs.get("messages", [])
+        if kwargs.get("preamble"):
+            messages = [{"role": "system", "content": kwargs["preamble"]}] + messages
+
+        span_attributes = {
+            **get_langtrace_attributes(version, service_provider),
+            **get_llm_request_attributes(kwargs, prompts=messages),
+            **get_llm_url(instance),
+            SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
+            SpanAttributes.LLM_URL: APIS["CHAT_CREATE_V2"]["URL"],
+            SpanAttributes.LLM_PATH: APIS["CHAT_CREATE_V2"]["ENDPOINT"],
+            **get_extra_attributes(),
+        }
+
+        attributes = LLMSpanAttributes(**span_attributes)
+
+        for attr_name in ["max_input_tokens", "conversation_id", "connectors", "tools", "tool_results"]:
+            value = kwargs.get(attr_name)
+            if value is not None:
+                if attr_name == "max_input_tokens":
+                    attributes.llm_max_input_tokens = str(value)
+                elif attr_name == "conversation_id":
+                    attributes.conversation_id = value
+                else:
+                    setattr(attributes, f"llm_{attr_name}", json.dumps(value))
+
+        span = tracer.start_span(
+            name=get_span_name(APIS["CHAT_CREATE_V2"]["METHOD"]),
+            kind=SpanKind.CLIENT
+        )
+
+        for field, value in attributes.model_dump(by_alias=True).items():
+            set_span_attribute(span, field, value)
+
+        try:
+            result = wrapped(*args, **kwargs)
+
+            if stream:
+                return StreamWrapper(
+                    result,
+                    span,
+                    tool_calls=kwargs.get("tools") is not None,
+                )
+            else:
+                if hasattr(result, "id") and result.id is not None:
+                    span.set_attribute(SpanAttributes.LLM_GENERATION_ID, result.id)
+                    span.set_attribute(SpanAttributes.LLM_RESPONSE_ID, result.id)
+
+                if (hasattr(result, "message") and
+                    hasattr(result.message, "content") and
+                    len(result.message.content) > 0 and
+                    hasattr(result.message.content[0], "text") and
+                    result.message.content[0].text is not None and
+                    result.message.content[0].text != ""):
+                    responses = [{
+                        "role": result.message.role,
+                        "content": result.message.content[0].text
+                    }]
+                    set_event_completion(span, responses)
+                if hasattr(result, "tool_calls") and result.tool_calls is not None:
+                    tool_calls = [tool_call.json() for tool_call in result.tool_calls]
+                    span.set_attribute(
+                        SpanAttributes.LLM_TOOL_RESULTS,
+                        json.dumps(tool_calls)
+                    )
+                if hasattr(result, "usage") and result.usage is not None:
+                    if (hasattr(result.usage, "billed_units") and
+                        result.usage.billed_units is not None):
+                        usage = result.usage.billed_units
+                        for metric, value in {
+                            "input": usage.input_tokens or 0,
+                            "output": usage.output_tokens or 0,
+                            "total": (usage.input_tokens or 0) + (usage.output_tokens or 0),
+                        }.items():
+                            span.set_attribute(
+                                f"gen_ai.usage.{metric}_tokens",
+                                int(value)
+                            )
+                span.set_status(StatusCode.OK)
+                span.end()
+                return result
+
+        except Exception as error:
+            span.record_exception(error)
+            span.set_status(Status(StatusCode.ERROR, str(error)))
+            span.end()
+            raise
+
+    return traced_method
+
+
 def chat_stream(original_method, version, tracer):
     """Wrap the `messages_stream` method."""
 
diff --git a/src/langtrace_python_sdk/utils/llm.py b/src/langtrace_python_sdk/utils/llm.py
index b9355124..bdb6bb53 100644
--- a/src/langtrace_python_sdk/utils/llm.py
+++ b/src/langtrace_python_sdk/utils/llm.py
@@ -393,8 +393,19 @@ def build_streaming_response(self, chunk):
         if hasattr(chunk, "text") and chunk.text is not None:
             content = [chunk.text]
 
+        # CohereV2
+        if (hasattr(chunk, "delta") and
+            chunk.delta is not None and
+            hasattr(chunk.delta, "message") and
+            chunk.delta.message is not None and
+            hasattr(chunk.delta.message, "content") and
+            chunk.delta.message.content is not None and
+            hasattr(chunk.delta.message.content, "text") and
+            chunk.delta.message.content.text is not None):
+            content = [chunk.delta.message.content.text]
+
         # Anthropic
-        if hasattr(chunk, "delta") and chunk.delta is not None:
+        if hasattr(chunk, "delta") and chunk.delta is not None and not hasattr(chunk.delta, "message"):
             content = [chunk.delta.text] if hasattr(chunk.delta, "text") else []
 
         if isinstance(chunk, dict):
@@ -408,7 +419,17 @@ def set_usage_attributes(self, chunk):
 
         # Anthropic & OpenAI
         if hasattr(chunk, "type") and chunk.type == "message_start":
-            self.prompt_tokens = chunk.message.usage.input_tokens
+            if hasattr(chunk.message, "usage") and chunk.message.usage is not None:
+                self.prompt_tokens = chunk.message.usage.input_tokens
+
+        # CohereV2
+        if hasattr(chunk, "type") and chunk.type == "message-end":
+            if (hasattr(chunk, "delta") and chunk.delta is not None and
+                hasattr(chunk.delta, "usage") and chunk.delta.usage is not None and
+                hasattr(chunk.delta.usage, "billed_units") and chunk.delta.usage.billed_units is not None):
+                usage = chunk.delta.usage.billed_units
+                self.completion_tokens = int(usage.output_tokens)
+                self.prompt_tokens = int(usage.input_tokens)
 
         if hasattr(chunk, "usage") and chunk.usage is not None:
             if hasattr(chunk.usage, "output_tokens"):
diff --git a/src/langtrace_python_sdk/version.py b/src/langtrace_python_sdk/version.py
index ef7fb174..263d012c 100644
--- a/src/langtrace_python_sdk/version.py
+++ b/src/langtrace_python_sdk/version.py
@@ -1 +1 @@
-__version__ = "3.3.14"
+__version__ = "3.3.15"
diff --git a/src/run_example.py b/src/run_example.py
index 9f4cecf2..ab925665 100644
--- a/src/run_example.py
+++ b/src/run_example.py
@@ -4,7 +4,7 @@
     "anthropic": False,
     "azureopenai": False,
     "chroma": False,
-    "cohere": False,
+    "cohere": True,
     "fastapi": False,
     "langchain": False,
     "llamaindex": False,
@@ -21,7 +21,7 @@
     "gemini": False,
     "mistral": False,
     "awsbedrock": False,
-    "cerebras": True,
+    "cerebras": False,
 }
 
 if ENABLED_EXAMPLES["anthropic"]: