diff --git a/src/langtrace_python_sdk/constants/instrumentation/common.py b/src/langtrace_python_sdk/constants/instrumentation/common.py
index 4c4ec63f..53ff5d65 100644
--- a/src/langtrace_python_sdk/constants/instrumentation/common.py
+++ b/src/langtrace_python_sdk/constants/instrumentation/common.py
@@ -33,6 +33,7 @@
     "MISTRAL": "Mistral",
     "EMBEDCHAIN": "Embedchain",
     "AUTOGEN": "Autogen",
+    "XAI": "XAI",
 }
 
 LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY = "langtrace_additional_attributes"
diff --git a/src/langtrace_python_sdk/instrumentation/groq/patch.py b/src/langtrace_python_sdk/instrumentation/groq/patch.py
index 220387ac..9b9b545e 100644
--- a/src/langtrace_python_sdk/instrumentation/groq/patch.py
+++ b/src/langtrace_python_sdk/instrumentation/groq/patch.py
@@ -55,6 +55,8 @@ def traced_method(wrapped, instance, args, kwargs):
             service_provider = SERVICE_PROVIDERS["PPLX"]
         elif "azure" in get_base_url(instance):
             service_provider = SERVICE_PROVIDERS["AZURE"]
+        elif "x.ai" in get_base_url(instance):
+            service_provider = SERVICE_PROVIDERS["XAI"]
 
         # handle tool calls in the kwargs
         llm_prompts = []
@@ -274,6 +276,8 @@ async def traced_method(wrapped, instance, args, kwargs):
             service_provider = SERVICE_PROVIDERS["PPLX"]
         elif "azure" in get_base_url(instance):
             service_provider = SERVICE_PROVIDERS["AZURE"]
+        elif "x.ai" in get_base_url(instance):
+            service_provider = SERVICE_PROVIDERS["XAI"]
 
         # handle tool calls in the kwargs
         llm_prompts = []
diff --git a/src/langtrace_python_sdk/instrumentation/litellm/patch.py b/src/langtrace_python_sdk/instrumentation/litellm/patch.py
index 09c77477..a6cbb183 100644
--- a/src/langtrace_python_sdk/instrumentation/litellm/patch.py
+++ b/src/langtrace_python_sdk/instrumentation/litellm/patch.py
@@ -248,6 +248,8 @@ def traced_method(
             service_provider = SERVICE_PROVIDERS["AZURE"]
         elif "groq" in get_base_url(instance):
             service_provider = SERVICE_PROVIDERS["GROQ"]
+        elif "x.ai" in get_base_url(instance):
+            service_provider = SERVICE_PROVIDERS["XAI"]
         llm_prompts = []
         for item in kwargs.get("messages", []):
             tools = get_tool_calls(item)
@@ -336,6 +338,8 @@ async def traced_method(
             service_provider = SERVICE_PROVIDERS["PPLX"]
         elif "azure" in get_base_url(instance):
             service_provider = SERVICE_PROVIDERS["AZURE"]
+        elif "x.ai" in get_base_url(instance):
+            service_provider = SERVICE_PROVIDERS["XAI"]
         llm_prompts = []
         for item in kwargs.get("messages", []):
             tools = get_tool_calls(item)
diff --git a/src/langtrace_python_sdk/instrumentation/openai/patch.py b/src/langtrace_python_sdk/instrumentation/openai/patch.py
index d2902aa5..3b0da8b3 100644
--- a/src/langtrace_python_sdk/instrumentation/openai/patch.py
+++ b/src/langtrace_python_sdk/instrumentation/openai/patch.py
@@ -249,6 +249,8 @@ def traced_method(
             service_provider = SERVICE_PROVIDERS["AZURE"]
         elif "groq" in get_base_url(instance):
             service_provider = SERVICE_PROVIDERS["GROQ"]
+        elif "x.ai" in get_base_url(instance):
+            service_provider = SERVICE_PROVIDERS["XAI"]
         llm_prompts = []
         for item in kwargs.get("messages", []):
             tools = get_tool_calls(item)
@@ -337,6 +339,8 @@ async def traced_method(
             service_provider = SERVICE_PROVIDERS["PPLX"]
         elif "azure" in get_base_url(instance):
             service_provider = SERVICE_PROVIDERS["AZURE"]
+        elif "x.ai" in get_base_url(instance):
+            service_provider = SERVICE_PROVIDERS["XAI"]
         llm_prompts = []
         for item in kwargs.get("messages", []):
             tools = get_tool_calls(item)
diff --git a/src/langtrace_python_sdk/version.py b/src/langtrace_python_sdk/version.py
index 911557b8..f7493720 100644
--- a/src/langtrace_python_sdk/version.py
+++ b/src/langtrace_python_sdk/version.py
@@ -1 +1 @@
-__version__ = "3.1.2"
+__version__ = "3.1.3"
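
For context, a minimal usage sketch of the new detection path (not part of the diff): assuming x.ai's OpenAI-compatible endpoint at https://api.x.ai/v1, an illustrative model name such as "grok-beta", and an XAI_API_KEY environment variable, the patched openai instrumentation would tag the resulting spans with the "XAI" service provider because "x.ai" appears in get_base_url(instance).

# Minimal sketch; the endpoint URL, model name, and env var names are assumptions.
import os

from langtrace_python_sdk import langtrace
from openai import OpenAI

langtrace.init()  # reads LANGTRACE_API_KEY from the environment, or pass api_key=...

# Point the OpenAI client at x.ai's OpenAI-compatible API; the patched instrumentation
# sees "x.ai" in the client's base URL and records the span's provider as "XAI".
client = OpenAI(
    base_url="https://api.x.ai/v1",
    api_key=os.environ["XAI_API_KEY"],
)

response = client.chat.completions.create(
    model="grok-beta",
    messages=[{"role": "user", "content": "Hello from Langtrace"}],
)
print(response.choices[0].message.content)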