Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions src/examples/milvus_example/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from pymilvus import MilvusClient, model
from typing import List
from langtrace_python_sdk import langtrace, with_langtrace_root_span
from dotenv import load_dotenv

# Load environment configuration before tracing is initialized
# (presumably Langtrace credentials come from a local .env file — confirm).
load_dotenv()
# Initialize Langtrace without echoing spans to the console.
langtrace.init(write_spans_to_console=False)

# Module-level Milvus client shared by the helper functions in this example.
# NOTE(review): the ".db" path looks like a local Milvus Lite file — confirm.
client = MilvusClient("milvus_demo.db")

# Default collection name used by the helper functions in this example.
COLLECTION_NAME = "demo_collection"
# Embedding function used to encode both the documents and the search queries.
embedding_fn = model.DefaultEmbeddingFunction()


def create_collection(collection_name: str = COLLECTION_NAME):
    """
    Create ``collection_name`` from scratch.

    If a collection with that name already exists it is dropped first, so the
    demo always starts from an empty collection.
    """
    already_exists = client.has_collection(collection_name=collection_name)
    if already_exists:
        client.drop_collection(collection_name=collection_name)

    # The vectors used in this demo have 768 dimensions.
    client.create_collection(collection_name=collection_name, dimension=768)


def create_embedding(docs: List[str] = [], subject: str = "history"):
    """
    Encode ``docs`` and package them as entities ready for insertion.

    Each entity carries an ``id`` (the document's position, starting at 0 on
    every call), its ``vector``, the raw ``text`` and a ``subject`` label that
    is used to demo metadata filtering later.

    Returns:
        A list with one entity dict per encoded document.
    """
    embeddings = embedding_fn.encode_documents(docs)

    entities = []
    for position, embedding in enumerate(embeddings):
        entities.append(
            {
                "id": position,
                "vector": embedding,
                "text": docs[position],
                "subject": subject,
            }
        )
    return entities


def insert_data(collection_name: str = COLLECTION_NAME, data: List[dict] = []):
    """Insert the entities in ``data`` into ``collection_name`` using the shared client."""
    insert_arguments = {"collection_name": collection_name, "data": data}
    client.insert(**insert_arguments)


def vector_search(collection_name: str = COLLECTION_NAME, queries: List[str] = []):
    """
    Embed ``queries`` and run a vector search against ``collection_name``.

    Args:
        collection_name: Collection to search. This used to be ignored in
            favour of a hard-coded "demo_collection"; it is now honoured.
        queries: Query strings to encode with the module's embedding function.

    Returns:
        Whatever ``client.search`` returns, so callers can inspect the matches
        (existing callers that ignore the return value are unaffected).
    """
    query_vectors = embedding_fn.encode_queries(queries)
    # If you don't have the embedding function you can use a fake vector to finish the demo:
    # query_vectors = [ [ random.uniform(-1, 1) for _ in range(768) ] ]

    # NOTE(review): nothing in this example creates a partition named "history"
    # ("history" is only used as a `subject` field value) — confirm that
    # `partition_names` is intended here.
    return client.search(
        collection_name=collection_name,  # target collection
        data=query_vectors,  # query vectors
        limit=2,  # number of returned entities
        output_fields=["text", "subject"],  # specifies fields to be returned
        timeout=10,
        partition_names=["history"],
        anns_field="vector",
        search_params={"nprobe": 10},
    )


def query(collection_name: str = COLLECTION_NAME, query: str = ""):
    """
    Run a filter-expression query against ``collection_name``.

    Args:
        collection_name: Collection to query.
        query: Filter expression passed to the client as ``filter``.

    Returns:
        Whatever ``client.query`` returns. The result used to be bound to an
        unused local and dropped; returning it lets callers inspect the rows.
    """
    return client.query(
        collection_name=collection_name,
        filter=query,
        # output_fields=["text", "subject"],
    )


@with_langtrace_root_span("milvus_example")
def main():
    """
    Run the demo end to end under a single "milvus_example" root span.

    Recreates the collection, inserts two batches of documents ("history" and
    "biology"), runs one vector search and then one filter query per subject.
    """
    history_docs = [
        "Artificial intelligence was founded as an academic discipline in 1956.",
        "Alan Turing was the first person to conduct substantial research in AI.",
        "Born in Maida Vale, London, Turing was raised in southern England.",
    ]
    biology_docs = [
        "Machine learning has been used for drug design.",
        "Computational synthesis with AI algorithms predicts molecular properties.",
        "DDR1 is involved in cancers and fibrosis.",
    ]

    create_collection()

    # insert Alan Turing's history
    insert_data(data=create_embedding(docs=history_docs))

    # insert AI Drug Discovery
    insert_data(data=create_embedding(docs=biology_docs, subject="biology"))

    vector_search(queries=["Who is Alan Turing?"])
    for subject_filter in ("subject == 'history'", "subject == 'biology'"):
        query(query=subject_filter)


if __name__ == "__main__":
    main()
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"MONGODB": "MongoDB",
"AWS_BEDROCK": "AWS Bedrock",
"CEREBRAS": "Cerebras",
"MILVUS": "Milvus",
}

LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY = "langtrace_additional_attributes"
38 changes: 38 additions & 0 deletions src/langtrace_python_sdk/constants/instrumentation/milvus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Table of the pymilvus ``MilvusClient`` methods that get instrumented.
# Keys are the upper-cased operation names; every entry has the same shape:
#   MODULE    - module that owns the patched attribute
#   METHOD    - dotted path of the method inside that module
#   OPERATION - operation name
#   SPAN_NAME - human-readable span name ("Milvus <Operation>")
APIS = {
    operation.upper(): {
        "MODULE": "pymilvus",
        "METHOD": f"MilvusClient.{operation}",
        "OPERATION": operation,
        "SPAN_NAME": "Milvus " + operation.replace("_", " ").title(),
    }
    for operation in (
        "insert",
        "query",
        "search",
        "delete",
        "create_collection",
        "upsert",
    )
}
2 changes: 2 additions & 0 deletions src/langtrace_python_sdk/instrumentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .litellm import LiteLLMInstrumentation
from .pymongo import PyMongoInstrumentation
from .cerebras import CerebrasInstrumentation
from .milvus import MilvusInstrumentation

__all__ = [
"AnthropicInstrumentation",
Expand Down Expand Up @@ -50,4 +51,5 @@
"PyMongoInstrumentation",
"AWSBedrockInstrumentation",
"CerebrasInstrumentation",
"MilvusInstrumentation",
]
3 changes: 3 additions & 0 deletions src/langtrace_python_sdk/instrumentation/milvus/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .instrumentation import MilvusInstrumentation

__all__ = ["MilvusInstrumentation"]
29 changes: 29 additions & 0 deletions src/langtrace_python_sdk/instrumentation/milvus/instrumentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.trace import get_tracer

from typing import Collection
from importlib_metadata import version as v
from wrapt import wrap_function_wrapper as _W

from langtrace_python_sdk.constants.instrumentation.milvus import APIS
from .patch import generic_patch


class MilvusInstrumentation(BaseInstrumentor):
    """
    Instrumentor that wraps the pymilvus methods listed in ``APIS``.

    Each listed method is wrapped with the tracing wrapper produced by
    ``generic_patch``; ``_uninstrument`` removes those wrappers again.
    """

    def instrumentation_dependencies(self) -> Collection[str]:
        """Return the package requirement this instrumentation applies to."""
        return ["pymilvus >= 2.4.1"]

    def _instrument(self, **kwargs):
        """
        Wrap every method in ``APIS`` with a tracing wrapper.

        Args:
            **kwargs: ``tracer_provider`` is read (may be absent) and passed
                to ``get_tracer``.
        """
        tracer_provider = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, "", tracer_provider)
        version = v("pymilvus")
        for api in APIS.values():
            _W(
                module=api["MODULE"],
                name=api["METHOD"],
                wrapper=generic_patch(api, version, tracer),
            )

    def _uninstrument(self, **kwargs):
        """
        Remove the wrappers installed by ``_instrument``.

        This used to be a no-op, which left the methods wrapped after
        uninstrumenting and caused double wrapping on re-instrumentation.
        """
        # Local imports keep this block self-contained.
        import importlib

        from opentelemetry.instrumentation.utils import unwrap

        for api in APIS.values():
            # "MilvusClient.insert" -> owner path ["MilvusClient"], method "insert".
            *owner_path, method_name = api["METHOD"].split(".")
            owner = importlib.import_module(api["MODULE"])
            for attribute in owner_path:
                owner = getattr(owner, attribute)
            unwrap(owner, method_name)
125 changes: 125 additions & 0 deletions src/langtrace_python_sdk/instrumentation/milvus/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from langtrace_python_sdk.utils.silently_fail import silently_fail
from opentelemetry.trace import Tracer
from opentelemetry.trace import SpanKind
from langtrace_python_sdk.utils import handle_span_error, set_span_attribute
from langtrace_python_sdk.utils.llm import (
get_extra_attributes,
set_span_attributes,
)
import json


def generic_patch(api, version: str, tracer: Tracer):
    """
    Build the wrapper that traces one Milvus client API.

    Args:
        api: Mapping describing the API; ``SPAN_NAME`` and ``OPERATION`` are read.
        version: Installed pymilvus version (accepted but not used here).
        tracer: Tracer used to open a CLIENT span around every call.

    Returns:
        A ``traced_method(wrapped, instance, args, kwargs)`` callable that sets
        request attributes, invokes ``wrapped`` with the original arguments,
        records response details for ``query``/``search`` and returns the
        result. Exceptions are reported via ``handle_span_error`` and re-raised.
    """
    # Local import: only needed to map positional arguments to their names.
    import inspect

    def named_arguments(wrapped, args, kwargs):
        """
        Return the call's arguments keyed by parameter name.

        The attribute helpers look arguments up by name, so positional
        arguments (e.g. ``client.insert("col", rows)``) would otherwise be
        invisible to them. Falls back to ``kwargs`` when the signature cannot
        be inspected or the arguments do not fit it.
        """
        if not args:
            # Keyword-only call: nothing to translate.
            return kwargs
        try:
            signature = inspect.signature(wrapped)
            bound = signature.bind_partial(*args, **kwargs)
        except (TypeError, ValueError):
            return kwargs

        named = {}
        for name, value in bound.arguments.items():
            kind = signature.parameters[name].kind
            if kind is inspect.Parameter.VAR_KEYWORD:
                # Flatten ``**kwargs`` so extra keywords stay addressable by name.
                named.update(value)
            elif kind is not inspect.Parameter.VAR_POSITIONAL:
                named[name] = value
        return named

    def traced_method(wrapped, instance, args, kwargs):
        span_name = api["SPAN_NAME"]
        operation = api["OPERATION"]
        with tracer.start_as_current_span(span_name, kind=SpanKind.CLIENT) as span:
            try:
                call_arguments = named_arguments(wrapped, args, kwargs)
                span_attributes = {
                    "db.system": "milvus",
                    "db.operation": operation,
                    "db.name": call_arguments.get("collection_name", None),
                    **get_extra_attributes(),
                }

                if operation == "create_collection":
                    set_create_collection_attributes(span_attributes, call_arguments)

                elif operation == "insert" or operation == "upsert":
                    set_insert_or_upsert_attributes(span_attributes, call_arguments)

                elif operation == "search":
                    set_search_attributes(span_attributes, call_arguments)

                elif operation == "query":
                    set_query_attributes(span_attributes, call_arguments)

                set_span_attributes(span, span_attributes)
                # The wrapped call always receives the caller's original arguments.
                result = wrapped(*args, **kwargs)

                if operation == "query":
                    set_query_response_attributes(span, result)

                if operation == "search":
                    set_search_response_attributes(span, result)
                return result
            except Exception as err:
                handle_span_error(span, err)
                raise

    return traced_method


@silently_fail
def set_create_collection_attributes(span_attributes, kwargs):
    """Store the ``dimension`` argument (``None`` when absent) as ``db.dimension``."""
    dimension = kwargs.get("dimension")
    span_attributes["db.dimension"] = dimension


@silently_fail
def set_insert_or_upsert_attributes(span_attributes, kwargs):
    """
    Record insert/upsert request details on ``span_attributes``.

    Sets ``db.num_entities``, ``db.timeout`` and ``db.partition_name`` from the
    ``data``, ``timeout`` and ``partition_name`` arguments.

    A single entity passed as one dict counts as 1 entity; ``len()`` of that
    dict would report its number of fields instead. Missing or empty ``data``
    is recorded as ``None``.
    """
    data = kwargs.get("data")
    if not data:
        num_entities = None
    elif isinstance(data, dict):
        num_entities = 1
    else:
        num_entities = len(data)

    span_attributes["db.num_entities"] = num_entities
    span_attributes["db.timeout"] = kwargs.get("timeout")
    span_attributes["db.partition_name"] = kwargs.get("partition_name")


@silently_fail
def set_search_attributes(span_attributes, kwargs):
    """
    Record search request details on ``span_attributes``.

    ``output_fields``, ``search_params`` and ``partition_names`` are stored as
    JSON strings. Arguments that were not supplied are stored as ``None``;
    previously ``json.dumps(None)`` recorded the literal string "null" for them.
    """

    def dump_if_given(value):
        # Keep "not supplied" distinguishable from a real JSON payload.
        return json.dumps(value) if value is not None else None

    data = kwargs.get("data")
    # Use len() rather than truthiness so array-like query batches whose
    # truth value is ambiguous can still be counted; empty stays None.
    num_queries = None if data is None else (len(data) or None)

    span_attributes["db.num_queries"] = num_queries
    span_attributes["db.filter"] = kwargs.get("filter")
    span_attributes["db.limit"] = kwargs.get("limit")
    span_attributes["db.output_fields"] = dump_if_given(kwargs.get("output_fields"))
    span_attributes["db.search_params"] = dump_if_given(kwargs.get("search_params"))
    span_attributes["db.partition_names"] = dump_if_given(
        kwargs.get("partition_names")
    )
    span_attributes["db.anns_field"] = kwargs.get("anns_field")
    span_attributes["db.timeout"] = kwargs.get("timeout")


@silently_fail
def set_query_attributes(span_attributes, kwargs):
    """
    Copy the query arguments onto ``span_attributes`` as ``db.<argument>`` keys.

    Arguments missing from ``kwargs`` are stored as ``None``.
    """
    for argument in ("filter", "output_fields", "timeout", "partition_names", "ids"):
        span_attributes[f"db.{argument}"] = kwargs.get(argument)


@silently_fail
def set_query_response_attributes(span, result):
    """
    Record the number of rows in ``result`` as ``db.num_matches`` and add one
    "db.query.match" event per row, using the row itself as event attributes.
    """
    match_count = len(result)
    set_span_attribute(span, name="db.num_matches", value=match_count)

    for row in result:
        span.add_event("db.query.match", attributes=row)


@silently_fail
def set_search_response_attributes(span, result):
    """
    Add one "db.search.match" event per hit in ``result``.

    ``result`` is iterated as a sequence of hit lists. Each event carries the
    hit's ``id``, its ``distance`` as a string and its ``entity`` as JSON.
    """
    for hits in result:
        for hit in hits:
            event_attributes = {
                "id": hit["id"],
                "distance": str(hit["distance"]),
                "entity": json.dumps(hit["entity"]),
            }
            span.add_event("db.search.match", attributes=event_attributes)
1 change: 0 additions & 1 deletion src/langtrace_python_sdk/instrumentation/pymongo/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def traced_method(wrapped, instance, args, kwargs):

try:
result = wrapped(*args, **kwargs)
print(result)
for doc in result:
if span.is_recording():
span.add_event(
Expand Down
2 changes: 2 additions & 0 deletions src/langtrace_python_sdk/langtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
WeaviateInstrumentation,
PyMongoInstrumentation,
CerebrasInstrumentation,
MilvusInstrumentation,
)
from opentelemetry.util.re import parse_env_headers

Expand Down Expand Up @@ -284,6 +285,7 @@ def init(
"autogen": AutogenInstrumentation(),
"pymongo": PyMongoInstrumentation(),
"cerebras-cloud-sdk": CerebrasInstrumentation(),
"pymilvus": MilvusInstrumentation(),
}

init_instrumentations(config.disable_instrumentations, all_instrumentations)
Expand Down
2 changes: 1 addition & 1 deletion src/langtrace_python_sdk/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.3.4"
__version__ = "3.3.5"
Loading