21 changes: 17 additions & 4 deletions src/examples/cohere_example/rerank.py
@@ -1,5 +1,6 @@
import cohere
from dotenv import find_dotenv, load_dotenv
from datetime import datetime

from langtrace_python_sdk import langtrace

@@ -16,10 +17,22 @@
# @with_langtrace_root_span("embed_create")
def rerank():
docs = [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
{
"text": "Carson City is the capital city of the American state of Nevada.",
"date": datetime.now(),
},
{
"text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"date": datetime(2020, 5, 17),
},
{
"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"date": datetime(1776, 7, 4),
},
{
"text": "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
"date": datetime(2023, 9, 14),
},
]

response = co.rerank(
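The reworked documents above now carry `datetime` values, which the standard `json` module cannot serialize on its own. A minimal sketch (not part of this diff) of the failure mode that the `datetime_encoder` change later in this PR addresses:

import json
from datetime import datetime

docs = [{"text": "Carson City is the capital of Nevada.", "date": datetime(2020, 5, 17)}]

try:
    json.dumps(docs)  # raises TypeError: Object of type datetime is not JSON serializable
except TypeError as err:
    print(err)

# A quick workaround without a custom encoder: stringify unknown types up front
print(json.dumps(docs, default=str))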
4 changes: 3 additions & 1 deletion src/examples/langchain_example/__init__.py
@@ -2,7 +2,7 @@
from .basic import basic_app, rag, load_and_split
from langtrace_python_sdk import with_langtrace_root_span

from .groq_example import groq_basic, groq_streaming
from .groq_example import groq_basic, groq_tool_choice, groq_streaming
from .langgraph_example_tools import basic_graph_tools


@@ -20,3 +20,5 @@ class GroqRunner:
@with_langtrace_root_span("Groq")
def run(self):
groq_streaming()
groq_basic()
groq_tool_choice()
80 changes: 78 additions & 2 deletions src/examples/langchain_example/groq_example.py
@@ -1,6 +1,6 @@
import json

from dotenv import find_dotenv, load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from groq import Groq

_ = load_dotenv(find_dotenv())
@@ -30,6 +30,82 @@ def groq_basic():
return chat_completion


def groq_tool_choice():

user_prompt = "What is 25 * 4 + 10?"
MODEL = "llama3-groq-70b-8192-tool-use-preview"

def calculate(expression):
"""Evaluate a mathematical expression"""
try:
result = eval(expression)
return json.dumps({"result": result})
except:
return json.dumps({"error": "Invalid expression"})

messages = [
{
"role": "system",
"content": "You are a calculator assistant. Use the calculate function to perform mathematical operations and provide the results.",
},
{
"role": "user",
"content": user_prompt,
},
]
tools = [
{
"type": "function",
"function": {
"name": "calculate",
"description": "Evaluate a mathematical expression",
"parameters": {
"type": "object",
"properties": {
"expression": {
"type": "string",
"description": "The mathematical expression to evaluate",
}
},
"required": ["expression"],
},
},
}
]
response = client.chat.completions.create(
model=MODEL,
messages=messages,
tools=tools,
tool_choice={"type": "function", "function": {"name": "calculate"}},
max_tokens=4096,
)

response_message = response.choices[0].message
tool_calls = response_message.tool_calls
if tool_calls:
available_functions = {
"calculate": calculate,
}
messages.append(response_message)
for tool_call in tool_calls:
function_name = tool_call.function.name
function_to_call = available_functions[function_name]
function_args = json.loads(tool_call.function.arguments)
function_response = function_to_call(
expression=function_args.get("expression")
)
messages.append(
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": function_response,
}
)
second_response = client.chat.completions.create(model=MODEL, messages=messages)
return second_response.choices[0].message.content


def groq_streaming():
chat_completion = client.chat.completions.create(
messages=[
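For context, `tool_choice` in the Groq chat completions API follows the OpenAI-compatible convention: it accepts a mode string or an object that forces a specific function, which is what `groq_tool_choice()` above does. A hedged sketch of the accepted forms, reusing the names defined in that function (values assumed from the convention, not taken from this diff):

# "auto": the model decides whether to call a tool (default when tools are provided)
# "none": the model must answer directly without calling a tool
# object form: force a call to one named tool, as in the example above
forced = {"type": "function", "function": {"name": "calculate"}}

response = client.chat.completions.create(
    model=MODEL,
    messages=messages,
    tools=tools,
    tool_choice=forced,  # or "auto" / "none"
)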
5 changes: 4 additions & 1 deletion src/langtrace_python_sdk/instrumentation/cohere/patch.py
@@ -27,6 +27,7 @@
)
from langtrace.trace_attributes import Event, LLMSpanAttributes
from langtrace_python_sdk.utils import set_span_attribute
from langtrace_python_sdk.utils.misc import datetime_encoder
from opentelemetry.trace import SpanKind
from opentelemetry.trace.status import Status, StatusCode

@@ -50,7 +51,9 @@ def traced_method(wrapped, instance, args, kwargs):
SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
SpanAttributes.LLM_URL: APIS["RERANK"]["URL"],
SpanAttributes.LLM_PATH: APIS["RERANK"]["ENDPOINT"],
SpanAttributes.LLM_REQUEST_DOCUMENTS: json.dumps(kwargs.get("documents")),
SpanAttributes.LLM_REQUEST_DOCUMENTS: json.dumps(
kwargs.get("documents"), cls=datetime_encoder
),
SpanAttributes.LLM_COHERE_RERANK_QUERY: kwargs.get("query"),
**get_extra_attributes(),
}
3 changes: 2 additions & 1 deletion src/langtrace_python_sdk/utils/llm.py
@@ -124,6 +124,7 @@ def get_llm_request_attributes(kwargs, prompts=None, model=None, operation_name=

top_p = kwargs.get("p", None) or kwargs.get("top_p", None)
tools = kwargs.get("tools", None)
tool_choice = kwargs.get("tool_choice", None)
return {
SpanAttributes.LLM_OPERATION_NAME: operation_name,
SpanAttributes.LLM_REQUEST_MODEL: model
@@ -141,7 +142,7 @@ def get_llm_request_attributes(kwargs, prompts=None, model=None, operation_name=
SpanAttributes.LLM_FREQUENCY_PENALTY: kwargs.get("frequency_penalty"),
SpanAttributes.LLM_REQUEST_SEED: kwargs.get("seed"),
SpanAttributes.LLM_TOOLS: json.dumps(tools) if tools else None,
SpanAttributes.LLM_TOOL_CHOICE: kwargs.get("tool_choice"),
SpanAttributes.LLM_TOOL_CHOICE: json.dumps(tool_choice) if tool_choice else None,
SpanAttributes.LLM_REQUEST_LOGPROPS: kwargs.get("logprobs"),
SpanAttributes.LLM_REQUEST_LOGITBIAS: kwargs.get("logit_bias"),
SpanAttributes.LLM_REQUEST_TOP_LOGPROPS: kwargs.get("top_logprobs"),
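The change above mirrors how `tools` is already handled: OpenTelemetry span attributes only accept primitive types (and homogeneous sequences of them), so a `tool_choice` dict has to be serialized to a string before being set on the span, or it is dropped with a warning. A minimal sketch of the difference (the attribute key is illustrative, not taken from this diff):

import json

tool_choice = {"type": "function", "function": {"name": "calculate"}}

# span.set_attribute("llm.tool_choice", tool_choice)               # dict is not a valid attribute value
# span.set_attribute("llm.tool_choice", json.dumps(tool_choice))   # stored as a plain JSON string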
8 changes: 8 additions & 0 deletions src/langtrace_python_sdk/utils/misc.py
@@ -60,3 +60,11 @@ def is_serializable(value):

# Convert to string representation
return json.dumps(serializable_args)


class datetime_encoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, datetime):
return o.isoformat()

return json.JSONEncoder.default(self, o)
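A quick usage sketch of the encoder added above, showing how it pairs with the `cls=` hook in `json.dumps` (the sample document is illustrative):

import json
from datetime import datetime

from langtrace_python_sdk.utils.misc import datetime_encoder

doc = {"text": "Carson City is the capital of Nevada.", "date": datetime(2020, 5, 17)}
print(json.dumps(doc, cls=datetime_encoder))
# {"text": "Carson City is the capital of Nevada.", "date": "2020-05-17T00:00:00"}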
2 changes: 1 addition & 1 deletion src/langtrace_python_sdk/version.py
@@ -1 +1 @@
__version__ = "2.3.15"
__version__ = "2.3.16"