diff --git a/src/examples/cohere_example/rerank.py b/src/examples/cohere_example/rerank.py
index de1a21f5..3995feb5 100644
--- a/src/examples/cohere_example/rerank.py
+++ b/src/examples/cohere_example/rerank.py
@@ -1,5 +1,6 @@
 import cohere
 from dotenv import find_dotenv, load_dotenv
+from datetime import datetime
 
 from langtrace_python_sdk import langtrace
 
@@ -16,10 +17,22 @@
 # @with_langtrace_root_span("embed_create")
 def rerank():
     docs = [
-        "Carson City is the capital city of the American state of Nevada.",
-        "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
-        "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
-        "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
+        {
+            "text": "Carson City is the capital city of the American state of Nevada.",
+            "date": datetime.now(),
+        },
+        {
+            "text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
+            "date": datetime(2020, 5, 17),
+        },
+        {
+            "text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
+            "date": datetime(1776, 7, 4),
+        },
+        {
+            "text": "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
+            "date": datetime(2023, 9, 14),
+        },
     ]
 
     response = co.rerank(
diff --git a/src/examples/langchain_example/__init__.py b/src/examples/langchain_example/__init__.py
index ceaee24c..f9621a2c 100644
--- a/src/examples/langchain_example/__init__.py
+++ b/src/examples/langchain_example/__init__.py
@@ -2,7 +2,7 @@
 from .basic import basic_app, rag, load_and_split
 
 from langtrace_python_sdk import with_langtrace_root_span
-from .groq_example import groq_basic, groq_streaming
+from .groq_example import groq_basic, groq_tool_choice, groq_streaming
 from .langgraph_example_tools import basic_graph_tools
 
 
@@ -20,3 +20,5 @@ class GroqRunner:
     @with_langtrace_root_span("Groq")
     def run(self):
         groq_streaming()
+        groq_basic()
+        groq_tool_choice()
diff --git a/src/examples/langchain_example/groq_example.py b/src/examples/langchain_example/groq_example.py
index c16e4851..99a61761 100644
--- a/src/examples/langchain_example/groq_example.py
+++ b/src/examples/langchain_example/groq_example.py
@@ -1,6 +1,6 @@
+import json
+
 from dotenv import find_dotenv, load_dotenv
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_groq import ChatGroq
 from groq import Groq
 
 _ = load_dotenv(find_dotenv())
@@ -30,6 +30,82 @@ def groq_basic():
     return chat_completion
 
 
+def groq_tool_choice():
+
+    user_prompt = "What is 25 * 4 + 10?"
+    MODEL = "llama3-groq-70b-8192-tool-use-preview"
+
+    def calculate(expression):
+        """Evaluate a mathematical expression"""
+        try:
+            result = eval(expression)
+            return json.dumps({"result": result})
+        except:
+            return json.dumps({"error": "Invalid expression"})
+
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a calculator assistant. Use the calculate function to perform mathematical operations and provide the results.",
+        },
+        {
+            "role": "user",
+            "content": user_prompt,
+        },
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "calculate",
+                "description": "Evaluate a mathematical expression",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "expression": {
+                            "type": "string",
+                            "description": "The mathematical expression to evaluate",
+                        }
+                    },
+                    "required": ["expression"],
+                },
+            },
+        }
+    ]
+    response = client.chat.completions.create(
+        model=MODEL,
+        messages=messages,
+        tools=tools,
+        tool_choice={"type": "function", "function": {"name": "calculate"}},
+        max_tokens=4096,
+    )
+
+    response_message = response.choices[0].message
+    tool_calls = response_message.tool_calls
+    if tool_calls:
+        available_functions = {
+            "calculate": calculate,
+        }
+        messages.append(response_message)
+        for tool_call in tool_calls:
+            function_name = tool_call.function.name
+            function_to_call = available_functions[function_name]
+            function_args = json.loads(tool_call.function.arguments)
+            function_response = function_to_call(
+                expression=function_args.get("expression")
+            )
+            messages.append(
+                {
+                    "tool_call_id": tool_call.id,
+                    "role": "tool",
+                    "name": function_name,
+                    "content": function_response,
+                }
+            )
+        second_response = client.chat.completions.create(model=MODEL, messages=messages)
+        return second_response.choices[0].message.content
+
+
 def groq_streaming():
     chat_completion = client.chat.completions.create(
         messages=[
diff --git a/src/langtrace_python_sdk/instrumentation/cohere/patch.py b/src/langtrace_python_sdk/instrumentation/cohere/patch.py
index c165a10c..38908c3d 100644
--- a/src/langtrace_python_sdk/instrumentation/cohere/patch.py
+++ b/src/langtrace_python_sdk/instrumentation/cohere/patch.py
@@ -27,6 +27,7 @@
 )
 from langtrace.trace_attributes import Event, LLMSpanAttributes
 from langtrace_python_sdk.utils import set_span_attribute
+from langtrace_python_sdk.utils.misc import datetime_encoder
 from opentelemetry.trace import SpanKind
 from opentelemetry.trace.status import Status, StatusCode
 
@@ -50,7 +51,9 @@ def traced_method(wrapped, instance, args, kwargs):
             SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
             SpanAttributes.LLM_URL: APIS["RERANK"]["URL"],
             SpanAttributes.LLM_PATH: APIS["RERANK"]["ENDPOINT"],
-            SpanAttributes.LLM_REQUEST_DOCUMENTS: json.dumps(kwargs.get("documents")),
+            SpanAttributes.LLM_REQUEST_DOCUMENTS: json.dumps(
+                kwargs.get("documents"), cls=datetime_encoder
+            ),
             SpanAttributes.LLM_COHERE_RERANK_QUERY: kwargs.get("query"),
             **get_extra_attributes(),
         }
diff --git a/src/langtrace_python_sdk/utils/llm.py b/src/langtrace_python_sdk/utils/llm.py
index 6d9647e7..e1d6fa9d 100644
--- a/src/langtrace_python_sdk/utils/llm.py
+++ b/src/langtrace_python_sdk/utils/llm.py
@@ -124,6 +124,7 @@ def get_llm_request_attributes(kwargs, prompts=None, model=None, operation_name=
     top_p = kwargs.get("p", None) or kwargs.get("top_p", None)
     tools = kwargs.get("tools", None)
+    tool_choice = kwargs.get("tool_choice", None)
 
     return {
         SpanAttributes.LLM_OPERATION_NAME: operation_name,
         SpanAttributes.LLM_REQUEST_MODEL: model
@@ -141,7 +142,7 @@
         SpanAttributes.LLM_FREQUENCY_PENALTY: kwargs.get("frequency_penalty"),
         SpanAttributes.LLM_REQUEST_SEED: kwargs.get("seed"),
         SpanAttributes.LLM_TOOLS: json.dumps(tools) if tools else None,
-        SpanAttributes.LLM_TOOL_CHOICE: kwargs.get("tool_choice"),
+        SpanAttributes.LLM_TOOL_CHOICE: json.dumps(tool_choice) if tool_choice else None,
         SpanAttributes.LLM_REQUEST_LOGPROPS: kwargs.get("logprobs"),
         SpanAttributes.LLM_REQUEST_LOGITBIAS: kwargs.get("logit_bias"),
         SpanAttributes.LLM_REQUEST_TOP_LOGPROPS: kwargs.get("top_logprobs"),
diff --git a/src/langtrace_python_sdk/utils/misc.py b/src/langtrace_python_sdk/utils/misc.py
index a0d20452..56924cfe 100644
--- a/src/langtrace_python_sdk/utils/misc.py
+++ b/src/langtrace_python_sdk/utils/misc.py
@@ -60,3 +60,11 @@ def is_serializable(value):
 
         # Convert to string representation
         return json.dumps(serializable_args)
+
+
+class datetime_encoder(json.JSONEncoder):
+    def default(self, o):
+        if isinstance(o, datetime):
+            return o.isoformat()
+
+        return json.JSONEncoder.default(self, o)
diff --git a/src/langtrace_python_sdk/version.py b/src/langtrace_python_sdk/version.py
index ab20ff9c..e4f37b40 100644
--- a/src/langtrace_python_sdk/version.py
+++ b/src/langtrace_python_sdk/version.py
@@ -1 +1 @@
-__version__ = "2.3.15"
+__version__ = "2.3.16"