diff --git a/src/examples/mongo_vector_search_example/main.py b/src/examples/mongo_vector_search_example/main.py new file mode 100644 index 00000000..61725b8c --- /dev/null +++ b/src/examples/mongo_vector_search_example/main.py @@ -0,0 +1,61 @@ +from langtrace_python_sdk import langtrace, with_langtrace_root_span +import pymongo +import os +from dotenv import load_dotenv +from openai import OpenAI + +load_dotenv() +langtrace.init(write_spans_to_console=False, batch=False) +MODEL = "text-embedding-ada-002" +openai_client = OpenAI() +client = pymongo.MongoClient(os.environ["MONGO_URI"]) + + +# Define a function to generate embeddings +def get_embedding(text): + """Generates vector embeddings for the given text.""" + embedding = ( + openai_client.embeddings.create(input=[text], model=MODEL).data[0].embedding + ) + return embedding + + +@with_langtrace_root_span("mongo-vector-search") +def vector_query(): + db = client["sample_mflix"] + + embedded_movies_collection = db["embedded_movies"] + # define pipeline + pipeline = [ + { + "$vectorSearch": { + "index": "vector_index", + "path": "plot_embedding", + "queryVector": get_embedding("time travel"), + "numCandidates": 150, + "limit": 10, + } + }, + { + "$project": { + "_id": 0, + "plot": 1, + "title": 1, + "score": {"$meta": "vectorSearchScore"}, + } + }, + ] + + result = embedded_movies_collection.aggregate(pipeline) + for doc in result: + # print(doc) + pass + + +if __name__ == "__main__": + try: + vector_query() + except Exception as e: + print("error", e) + finally: + client.close() diff --git a/src/langtrace_python_sdk/constants/instrumentation/common.py b/src/langtrace_python_sdk/constants/instrumentation/common.py index 48779ec5..937c596f 100644 --- a/src/langtrace_python_sdk/constants/instrumentation/common.py +++ b/src/langtrace_python_sdk/constants/instrumentation/common.py @@ -34,6 +34,7 @@ "EMBEDCHAIN": "Embedchain", "AUTOGEN": "Autogen", "XAI": "XAI", + "MONGODB": "MongoDB", "AWS_BEDROCK": "AWS Bedrock", "CEREBRAS": "Cerebras", } diff --git a/src/langtrace_python_sdk/constants/instrumentation/pymongo.py b/src/langtrace_python_sdk/constants/instrumentation/pymongo.py new file mode 100644 index 00000000..d65d8052 --- /dev/null +++ b/src/langtrace_python_sdk/constants/instrumentation/pymongo.py @@ -0,0 +1,8 @@ +APIS = { + "AGGREGATE": { + "MODULE": "pymongo.collection", + "METHOD": "Collection.aggregate", + "OPERATION": "aggregate", + "SPAN_NAME": "MongoDB Aggregate", + }, +} diff --git a/src/langtrace_python_sdk/instrumentation/__init__.py b/src/langtrace_python_sdk/instrumentation/__init__.py index a742d683..2f13fddd 100644 --- a/src/langtrace_python_sdk/instrumentation/__init__.py +++ b/src/langtrace_python_sdk/instrumentation/__init__.py @@ -21,6 +21,7 @@ from .aws_bedrock import AWSBedrockInstrumentation from .embedchain import EmbedchainInstrumentation from .litellm import LiteLLMInstrumentation +from .pymongo import PyMongoInstrumentation from .cerebras import CerebrasInstrumentation __all__ = [ @@ -46,6 +47,7 @@ "VertexAIInstrumentation", "GeminiInstrumentation", "MistralInstrumentation", + "PyMongoInstrumentation", "AWSBedrockInstrumentation", "CerebrasInstrumentation", ] diff --git a/src/langtrace_python_sdk/instrumentation/openai/patch.py b/src/langtrace_python_sdk/instrumentation/openai/patch.py index 3b0da8b3..85af2ab3 100644 --- a/src/langtrace_python_sdk/instrumentation/openai/patch.py +++ b/src/langtrace_python_sdk/instrumentation/openai/patch.py @@ -27,6 +27,7 @@ set_event_completion, StreamWrapper, set_span_attributes, + set_usage_attributes, ) from langtrace_python_sdk.types import NOT_GIVEN @@ -450,6 +451,14 @@ def traced_method( span_attributes[SpanAttributes.LLM_REQUEST_EMBEDDING_INPUTS] = json.dumps( [kwargs.get("input", "")] ) + span_attributes[SpanAttributes.LLM_PROMPTS] = json.dumps( + [ + { + "role": "user", + "content": kwargs.get("input"), + } + ] + ) attributes = LLMSpanAttributes(**filter_valid_attributes(span_attributes)) @@ -463,6 +472,11 @@ def traced_method( try: # Attempt to call the original method result = wrapped(*args, **kwargs) + usage = getattr(result, "usage", None) + if usage: + set_usage_attributes( + span, {"prompt_tokens": getattr(usage, "prompt_tokens", 0)} + ) span.set_status(StatusCode.OK) return result except Exception as err: diff --git a/src/langtrace_python_sdk/instrumentation/pymongo/__init__.py b/src/langtrace_python_sdk/instrumentation/pymongo/__init__.py new file mode 100644 index 00000000..c197384c --- /dev/null +++ b/src/langtrace_python_sdk/instrumentation/pymongo/__init__.py @@ -0,0 +1,5 @@ +from .instrumentation import PyMongoInstrumentation + +__all__ = [ + "PyMongoInstrumentation", +] diff --git a/src/langtrace_python_sdk/instrumentation/pymongo/instrumentation.py b/src/langtrace_python_sdk/instrumentation/pymongo/instrumentation.py new file mode 100644 index 00000000..762394e6 --- /dev/null +++ b/src/langtrace_python_sdk/instrumentation/pymongo/instrumentation.py @@ -0,0 +1,47 @@ +""" +Copyright (c) 2024 Scale3 Labs + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from opentelemetry.instrumentation.instrumentor import BaseInstrumentor +from opentelemetry.trace import get_tracer + +from typing import Collection +from importlib_metadata import version as v +from wrapt import wrap_function_wrapper as _W +from .patch import generic_patch +from langtrace_python_sdk.constants.instrumentation.pymongo import APIS + + +class PyMongoInstrumentation(BaseInstrumentor): + """ + The PyMongoInstrumentation class represents the PyMongo instrumentation + """ + + def instrumentation_dependencies(self) -> Collection[str]: + return ["pymongo >= 4.0.0"] + + def _instrument(self, **kwargs): + tracer_provider = kwargs.get("tracer_provider") + tracer = get_tracer(__name__, "", tracer_provider) + version = v("pymongo") + for api in APIS.values(): + _W( + module=api["MODULE"], + name=api["METHOD"], + wrapper=generic_patch(api["SPAN_NAME"], version, tracer), + ) + + def _uninstrument(self, **kwargs): + pass diff --git a/src/langtrace_python_sdk/instrumentation/pymongo/patch.py b/src/langtrace_python_sdk/instrumentation/pymongo/patch.py new file mode 100644 index 00000000..dd0a59dc --- /dev/null +++ b/src/langtrace_python_sdk/instrumentation/pymongo/patch.py @@ -0,0 +1,66 @@ +from langtrace_python_sdk.utils.llm import ( + get_langtrace_attributes, + get_span_name, + set_span_attributes, + set_span_attribute, +) +from langtrace_python_sdk.utils import deduce_args_and_kwargs, handle_span_error +from opentelemetry.trace import SpanKind +from langtrace_python_sdk.constants.instrumentation.common import SERVICE_PROVIDERS +from langtrace.trace_attributes import DatabaseSpanAttributes + +import json + + +def generic_patch(name, version, tracer): + def traced_method(wrapped, instance, args, kwargs): + database = instance.database.__dict__ + span_attributes = { + **get_langtrace_attributes( + version=version, + service_provider=SERVICE_PROVIDERS["MONGODB"], + vendor_type="vectordb", + ), + "db.system": "mongodb", + "db.query": "aggregate", + } + + attributes = DatabaseSpanAttributes(**span_attributes) + + with tracer.start_as_current_span( + get_span_name(name), kind=SpanKind.CLIENT + ) as span: + if span.is_recording(): + set_input_attributes( + span, deduce_args_and_kwargs(wrapped, *args, **kwargs) + ) + set_span_attributes(span, attributes) + + try: + result = wrapped(*args, **kwargs) + print(result) + for doc in result: + if span.is_recording(): + span.add_event( + name="db.query.match", + attributes={**doc}, + ) + return result + except Exception as err: + handle_span_error(span, err) + raise + + return traced_method + + +def set_input_attributes(span, args): + pipeline = args.get("pipeline", None) + for stage in pipeline: + for k, v in stage.items(): + if k == "$vectorSearch": + set_span_attribute(span, "db.index", v.get("index", None)) + set_span_attribute(span, "db.path", v.get("path", None)) + set_span_attribute(span, "db.top_k", v.get("numCandidates")) + set_span_attribute(span, "db.limit", v.get("limit")) + else: + set_span_attribute(span, k, json.dumps(v)) diff --git a/src/langtrace_python_sdk/langtrace.py b/src/langtrace_python_sdk/langtrace.py index 5ad2bdda..a86c613d 100644 --- a/src/langtrace_python_sdk/langtrace.py +++ b/src/langtrace_python_sdk/langtrace.py @@ -64,6 +64,7 @@ AutogenInstrumentation, VertexAIInstrumentation, WeaviateInstrumentation, + PyMongoInstrumentation, CerebrasInstrumentation, ) from opentelemetry.util.re import parse_env_headers @@ -281,6 +282,7 @@ def init( "mistralai": MistralInstrumentation(), "boto3": AWSBedrockInstrumentation(), "autogen": AutogenInstrumentation(), + "pymongo": PyMongoInstrumentation(), "cerebras-cloud-sdk": CerebrasInstrumentation(), } diff --git a/src/langtrace_python_sdk/version.py b/src/langtrace_python_sdk/version.py index 3e2d550b..80014d0e 100644 --- a/src/langtrace_python_sdk/version.py +++ b/src/langtrace_python_sdk/version.py @@ -1 +1 @@ -__version__ = "3.3.2" +__version__ = "3.3.3"