diff --git a/src/examples/milvus_example/main.py b/src/examples/milvus_example/main.py
new file mode 100644
index 00000000..48ec7c06
--- /dev/null
+++ b/src/examples/milvus_example/main.py
@@ -0,0 +1,125 @@
+from pymilvus import MilvusClient, model
+from typing import List, Optional
+from langtrace_python_sdk import langtrace, with_langtrace_root_span
+from dotenv import load_dotenv
+
+load_dotenv()
+langtrace.init(write_spans_to_console=False)
+
+client = MilvusClient("milvus_demo.db")
+
+COLLECTION_NAME = "demo_collection"
+embedding_fn = model.DefaultEmbeddingFunction()
+
+
+def create_collection(collection_name: str = COLLECTION_NAME):
+    """Recreate the demo collection, dropping any existing one first."""
+    if client.has_collection(collection_name=collection_name):
+        client.drop_collection(collection_name=collection_name)
+
+    client.create_collection(
+        collection_name=collection_name,
+        dimension=768,  # The vectors we use in this demo have 768 dimensions
+    )
+
+
+def create_embedding(
+    docs: Optional[List[str]] = None,
+    subject: str = "history",
+    start_id: int = 0,
+):
+    """
+    Create embeddings for the given documents.
+
+    Returns one entity dict per document with the keys "id", "vector", "text"
+    and "subject". Ids are numbered from start_id so that successive batches
+    inserted into the same collection do not reuse primary keys.
+    """
+    if not docs:
+        return []
+
+    vectors = embedding_fn.encode_documents(docs)
+    # Each entity has id, vector representation, raw text, and a subject label that we use
+    # to demo metadata filtering later.
+    return [
+        {"id": start_id + i, "vector": vector, "text": text, "subject": subject}
+        for i, (vector, text) in enumerate(zip(vectors, docs))
+    ]
+
+
+def insert_data(
+    collection_name: str = COLLECTION_NAME, data: Optional[List[dict]] = None
+):
+    """Insert the given entities into the collection; a no-op when data is empty."""
+    if not data:
+        return
+    client.insert(
+        collection_name=collection_name,
+        data=data,
+    )
+
+
+def vector_search(
+    collection_name: str = COLLECTION_NAME, queries: Optional[List[str]] = None
+):
+    """Embed the query strings, search the collection and return the hits."""
+    query_vectors = embedding_fn.encode_queries(queries or [])
+    # If you don't have the embedding function you can use a fake vector to finish the demo:
+    # query_vectors = [ [ random.uniform(-1, 1) for _ in range(768) ] ]
+
+    return client.search(
+        collection_name=collection_name,  # target collection
+        data=query_vectors,  # query vectors
+        limit=2,  # number of returned entities
+        output_fields=["text", "subject"],  # specifies fields to be returned
+        timeout=10,
+        # NOTE(review): this example never creates a "history" partition -- confirm
+        # that the search accepts this name.
+        partition_names=["history"],
+        anns_field="vector",
+        search_params={"nprobe": 10},
+    )
+
+
+def query(collection_name: str = COLLECTION_NAME, query: str = ""):
+    """Run a scalar filter query against the collection and return the rows."""
+    return client.query(
+        collection_name=collection_name,
+        filter=query,
+    )
+
+
+@with_langtrace_root_span("milvus_example")
+def main():
+    """Run the demo: create the collection, insert two batches, search and query."""
+    create_collection()
+    # insert Alan Turing's history
+    turing_data = create_embedding(
+        docs=[
+            "Artificial intelligence was founded as an academic discipline in 1956.",
+            "Alan Turing was the first person to conduct substantial research in AI.",
+            "Born in Maida Vale, London, Turing was raised in southern England.",
+        ]
+    )
+    insert_data(data=turing_data)
+
+    # insert AI Drug Discovery
+    drug_data = create_embedding(
+        docs=[
+            "Machine learning has been used for drug design.",
+            "Computational synthesis with AI algorithms predicts molecular properties.",
+            "DDR1 is involved in cancers and fibrosis.",
+        ],
+        subject="biology",
+        # Continue numbering after the history batch so primary keys stay unique.
+        start_id=len(turing_data),
+    )
+    insert_data(data=drug_data)
+
+    vector_search(queries=["Who is Alan Turing?"])
+    query(query="subject == 'history'")
+    query(query="subject == 'biology'")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/langtrace_python_sdk/constants/instrumentation/common.py b/src/langtrace_python_sdk/constants/instrumentation/common.py
index 937c596f..ae5d7003 100644
--- a/src/langtrace_python_sdk/constants/instrumentation/common.py
+++ b/src/langtrace_python_sdk/constants/instrumentation/common.py
@@ -37,6 +37,7 @@
     "MONGODB": "MongoDB",
     "AWS_BEDROCK": "AWS Bedrock",
     "CEREBRAS": "Cerebras",
+    "MILVUS": "Milvus",
 }
 
 LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY = "langtrace_additional_attributes"
diff --git a/src/langtrace_python_sdk/constants/instrumentation/milvus.py b/src/langtrace_python_sdk/constants/instrumentation/milvus.py
new file mode 100644
index 00000000..cf1a1c11
--- /dev/null
+++ b/src/langtrace_python_sdk/constants/instrumentation/milvus.py
@@ -0,0 +1,38 @@
+APIS = {
+    "INSERT": {
+        "MODULE": "pymilvus",
+        "METHOD": "MilvusClient.insert",
+        "OPERATION": "insert",
+        "SPAN_NAME": "Milvus Insert",
+    },
+    "QUERY": {
+        "MODULE": "pymilvus",
+        "METHOD": "MilvusClient.query",
+        "OPERATION": "query",
+        "SPAN_NAME": "Milvus Query",
+    },
+    "SEARCH": {
+        "MODULE": "pymilvus",
+        "METHOD": "MilvusClient.search",
+        "OPERATION": "search",
+        "SPAN_NAME": "Milvus Search",
+    },
+    "DELETE": {
+        "MODULE": "pymilvus",
+        "METHOD": "MilvusClient.delete",
+        "OPERATION": "delete",
+        "SPAN_NAME": "Milvus Delete",
+    },
+    "CREATE_COLLECTION": {
+        "MODULE": "pymilvus",
+        "METHOD": "MilvusClient.create_collection",
+        "OPERATION": "create_collection",
+        "SPAN_NAME": "Milvus Create Collection",
+    },
+    "UPSERT": {
+        "MODULE": "pymilvus",
+        "METHOD": "MilvusClient.upsert",
+        "OPERATION": "upsert",
+        "SPAN_NAME": "Milvus Upsert",
+    },
+}
diff --git a/src/langtrace_python_sdk/instrumentation/__init__.py b/src/langtrace_python_sdk/instrumentation/__init__.py
index 2f13fddd..952a4e10 100644
---
a/src/langtrace_python_sdk/instrumentation/__init__.py
+++ b/src/langtrace_python_sdk/instrumentation/__init__.py
@@ -23,6 +23,7 @@
 from .litellm import LiteLLMInstrumentation
 from .pymongo import PyMongoInstrumentation
 from .cerebras import CerebrasInstrumentation
+from .milvus import MilvusInstrumentation
 
 __all__ = [
     "AnthropicInstrumentation",
@@ -50,4 +51,5 @@
     "PyMongoInstrumentation",
     "AWSBedrockInstrumentation",
     "CerebrasInstrumentation",
+    "MilvusInstrumentation",
 ]
diff --git a/src/langtrace_python_sdk/instrumentation/milvus/__init__.py b/src/langtrace_python_sdk/instrumentation/milvus/__init__.py
new file mode 100644
index 00000000..1e760673
--- /dev/null
+++ b/src/langtrace_python_sdk/instrumentation/milvus/__init__.py
@@ -0,0 +1,3 @@
+from .instrumentation import MilvusInstrumentation
+
+__all__ = ["MilvusInstrumentation"]
diff --git a/src/langtrace_python_sdk/instrumentation/milvus/instrumentation.py b/src/langtrace_python_sdk/instrumentation/milvus/instrumentation.py
new file mode 100644
index 00000000..80bd1ea5
--- /dev/null
+++ b/src/langtrace_python_sdk/instrumentation/milvus/instrumentation.py
@@ -0,0 +1,40 @@
+from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
+from opentelemetry.instrumentation.utils import unwrap
+from opentelemetry.trace import get_tracer
+
+from importlib import import_module
+from typing import Collection
+from importlib_metadata import version as v
+from wrapt import wrap_function_wrapper as _W
+
+from langtrace_python_sdk.constants.instrumentation.milvus import APIS
+from .patch import generic_patch
+
+
+class MilvusInstrumentation(BaseInstrumentor):
+    """Instrumentor that traces the pymilvus MilvusClient methods listed in APIS."""
+
+    def instrumentation_dependencies(self) -> Collection[str]:
+        """Return the pymilvus requirement this instrumentor depends on."""
+        return ["pymilvus >= 2.4.1"]
+
+    def _instrument(self, **kwargs):
+        """Wrap every method listed in APIS with a tracing wrapper."""
+        tracer_provider = kwargs.get("tracer_provider")
+        tracer = get_tracer(__name__, "", tracer_provider)
+        version = v("pymilvus")
+        for api in APIS.values():
+            _W(
+                module=api["MODULE"],
+                name=api["METHOD"],
+                wrapper=generic_patch(api, version, tracer),
+            )
+
+    def _uninstrument(self, **kwargs):
+        """Restore the original methods replaced by _instrument."""
+        for api in APIS.values():
+            # METHOD has the form "<Class>.<method>"; unwrap needs the class
+            # object and the attribute name separately.
+            class_name, _, method_name = api["METHOD"].partition(".")
+            owner = getattr(import_module(api["MODULE"]), class_name)
+            unwrap(owner, method_name)
diff --git
a/src/langtrace_python_sdk/instrumentation/milvus/patch.py b/src/langtrace_python_sdk/instrumentation/milvus/patch.py
new file mode 100644
index 00000000..2c19ac79
--- /dev/null
+++ b/src/langtrace_python_sdk/instrumentation/milvus/patch.py
@@ -0,0 +1,157 @@
+from langtrace_python_sdk.utils.silently_fail import silently_fail
+from opentelemetry.trace import Tracer
+from opentelemetry.trace import SpanKind
+from langtrace_python_sdk.utils import handle_span_error, set_span_attribute
+from langtrace_python_sdk.utils.llm import (
+    get_extra_attributes,
+    set_span_attributes,
+)
+import inspect
+import json
+
+
+def generic_patch(api, version: str, tracer: Tracer):
+    """
+    Build a wrapt wrapper that traces one MilvusClient method.
+
+    api is an entry of APIS and supplies the span name and the operation.
+    """
+
+    def traced_method(wrapped, instance, args, kwargs):
+        span_name = api["SPAN_NAME"]
+        operation = api["OPERATION"]
+        with tracer.start_as_current_span(span_name, kind=SpanKind.CLIENT) as span:
+            try:
+                # Attributes are looked up by parameter name, so positional
+                # arguments are resolved to their names first.
+                call_args = _named_arguments(wrapped, args, kwargs)
+                span_attributes = {
+                    "db.system": "milvus",
+                    "db.operation": operation,
+                    "db.name": call_args.get("collection_name", None),
+                    **get_extra_attributes(),
+                }
+
+                if operation == "create_collection":
+                    set_create_collection_attributes(span_attributes, call_args)
+
+                elif operation == "insert" or operation == "upsert":
+                    set_insert_or_upsert_attributes(span_attributes, call_args)
+
+                elif operation == "search":
+                    set_search_attributes(span_attributes, call_args)
+
+                elif operation == "query":
+                    set_query_attributes(span_attributes, call_args)
+
+                set_span_attributes(span, span_attributes)
+                result = wrapped(*args, **kwargs)
+
+                if operation == "query":
+                    set_query_response_attributes(span, result)
+
+                if operation == "search":
+                    set_search_response_attributes(span, result)
+                return result
+            except Exception as err:
+                handle_span_error(span, err)
+                raise
+
+    return traced_method
+
+
+def _named_arguments(wrapped, args, kwargs):
+    """
+    Map the positional and keyword arguments of a call to parameter names.
+
+    Falls back to the keyword arguments alone when the call cannot be bound
+    to the signature of wrapped.
+    """
+    try:
+        bound = inspect.signature(wrapped).bind_partial(*args, **kwargs)
+    except (TypeError, ValueError):
+        return dict(kwargs)
+
+    named = {}
+    for name, value in bound.arguments.items():
+        kind = bound.signature.parameters[name].kind
+        if kind is inspect.Parameter.VAR_KEYWORD:
+            named.update(value)
+        elif kind is not inspect.Parameter.VAR_POSITIONAL:
+            named[name] = value
+    return named
+
+
+def _num_items(data):
+    """Return how many items data holds; a single dict counts as one."""
+    if data is None:
+        return None
+    if isinstance(data, dict):
+        return 1
+    return len(data)
+
+
+def _dumps_if_present(value):
+    """Return value serialized as JSON, or None when value is None."""
+    return None if value is None else json.dumps(value)
+
+
+@silently_fail
+def set_create_collection_attributes(span_attributes, kwargs):
+    """Record the vector dimension requested for the new collection."""
+    span_attributes["db.dimension"] = kwargs.get("dimension", None)
+
+
+@silently_fail
+def set_insert_or_upsert_attributes(span_attributes, kwargs):
+    """Record entity count, timeout and partition of an insert or upsert."""
+    span_attributes["db.num_entities"] = _num_items(kwargs.get("data"))
+    span_attributes["db.timeout"] = kwargs.get("timeout")
+    span_attributes["db.partition_name"] = kwargs.get("partition_name")
+
+
+@silently_fail
+def set_search_attributes(span_attributes, kwargs):
+    """Record the parameters of a search call."""
+    span_attributes["db.num_queries"] = _num_items(kwargs.get("data"))
+    span_attributes["db.filter"] = kwargs.get("filter")
+    span_attributes["db.limit"] = kwargs.get("limit")
+    # Serialize only the arguments that were supplied; json.dumps(None) would
+    # otherwise record the string "null" for every omitted one.
+    for name in ("output_fields", "search_params", "partition_names"):
+        span_attributes[f"db.{name}"] = _dumps_if_present(kwargs.get(name))
+    span_attributes["db.anns_field"] = kwargs.get("anns_field")
+    span_attributes["db.timeout"] = kwargs.get("timeout")
+
+
+@silently_fail
+def set_query_attributes(span_attributes, kwargs):
+    """Record the parameters of a query call."""
+    for name in ("filter", "output_fields", "timeout", "partition_names", "ids"):
+        span_attributes[f"db.{name}"] = kwargs.get(name)
+
+
+@silently_fail
+def set_query_response_attributes(span, result):
+    """Record the number of rows a query returned and add one event per row."""
+    set_span_attribute(span, name="db.num_matches", value=len(result))
+    for match in result:
+        span.add_event(
+            "db.query.match",
+            attributes=match,
+        )
+
+
+@silently_fail
+def set_search_response_attributes(span, result):
+    """Add one "db.search.match" event per hit returned by a search."""
+    for res in result:
+        for match in res:
+            span.add_event(
+                "db.search.match",
+                attributes={
+                    "id": match["id"],
+                    "distance": str(match["distance"]),
+                    "entity": json.dumps(match["entity"]),
+                },
+            )
diff --git a/src/langtrace_python_sdk/instrumentation/pymongo/patch.py b/src/langtrace_python_sdk/instrumentation/pymongo/patch.py
index dd0a59dc..773c7482 100644
--- a/src/langtrace_python_sdk/instrumentation/pymongo/patch.py
+++ b/src/langtrace_python_sdk/instrumentation/pymongo/patch.py
@@ -38,7 +38,6 @@ def traced_method(wrapped, instance, args, kwargs):
 
             try:
                 result = wrapped(*args, **kwargs)
-                print(result)
                 for doc in result:
                     if span.is_recording():
                         span.add_event(
diff --git a/src/langtrace_python_sdk/langtrace.py b/src/langtrace_python_sdk/langtrace.py
index a86c613d..e742e754 100644
--- a/src/langtrace_python_sdk/langtrace.py
+++ b/src/langtrace_python_sdk/langtrace.py
@@ -66,6 +66,7 @@
     WeaviateInstrumentation,
     PyMongoInstrumentation,
     CerebrasInstrumentation,
+    MilvusInstrumentation,
 )
 
 from opentelemetry.util.re import parse_env_headers
@@ -284,6 +285,7 @@ def init(
         "autogen": AutogenInstrumentation(),
         "pymongo": PyMongoInstrumentation(),
         "cerebras-cloud-sdk": CerebrasInstrumentation(),
+        "pymilvus": MilvusInstrumentation(),
     }
 
     init_instrumentations(config.disable_instrumentations, all_instrumentations)
diff --git a/src/langtrace_python_sdk/version.py b/src/langtrace_python_sdk/version.py
index 8a6dd7cc..c90ab1ba 100644
--- a/src/langtrace_python_sdk/version.py
+++ b/src/langtrace_python_sdk/version.py
@@ -1 +1 @@
-__version__ = "3.3.4"
+__version__ = "3.3.5"