diff --git a/libs/arangodb/.coverage b/libs/arangodb/.coverage index 52fc3cf..35611f6 100644 Binary files a/libs/arangodb/.coverage and b/libs/arangodb/.coverage differ diff --git a/libs/arangodb/tests/integration_tests/chat_message_histories/test_arangodb.py b/libs/arangodb/tests/integration_tests/chat_message_histories/test_arangodb.py index 8a836e7..14e8941 100644 --- a/libs/arangodb/tests/integration_tests/chat_message_histories/test_arangodb.py +++ b/libs/arangodb/tests/integration_tests/chat_message_histories/test_arangodb.py @@ -1,8 +1,14 @@ +import os + import pytest +from arango import ArangoClient from arango.database import StandardDatabase +from arango.exceptions import ArangoError from langchain_core.messages import AIMessage, HumanMessage from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory +from langchain_arangodb.graphs.arangodb_graph import ArangoGraph +from tests.integration_tests.utils import ArangoCredentials @pytest.mark.usefixtures("clear_arangodb_database") @@ -44,3 +50,108 @@ def test_add_messages(db: StandardDatabase) -> None: message_store_another.clear() assert len(message_store.messages) == 0 assert len(message_store_another.messages) == 0 + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_add_messages_graph_object(arangodb_credentials: ArangoCredentials) -> None: + """Basic test: pass the database connection through an ArangoGraph object.""" + graph = ArangoGraph.from_db_credentials( + url=arangodb_credentials["url"], + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + + # Temporarily rewrite the environment for testing + old_username = os.environ.get("ARANGO_USERNAME", "root") + os.environ["ARANGO_USERNAME"] = "foo" + + message_store = ArangoChatMessageHistory("23334", db=graph.db) + message_store.clear() + assert len(message_store.messages) == 0 + message_store.add_user_message("Hello! 
Language Chain!") + message_store.add_ai_message("Hi Guys!") + # Now check if the messages are stored in the database correctly + assert len(message_store.messages) == 2 + + # Restore original environment + os.environ["ARANGO_USERNAME"] = old_username + + +def test_invalid_credentials(arangodb_credentials: ArangoCredentials) -> None: + """Test initializing with invalid credentials raises an authentication error.""" + with pytest.raises(ArangoError) as exc_info: + client = ArangoClient(arangodb_credentials["url"]) + db = client.db(username="invalid_username", password="invalid_password") + # Try to perform a database operation to trigger an authentication error + db.collections() + + # Check for any authentication-related error message + error_msg = str(exc_info.value) + # Just check for "error" which should be in any auth error + assert "not authorized" in error_msg + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangodb_message_history_clear_messages( + db: StandardDatabase, +) -> None: + """Test adding multiple messages at once to ArangoChatMessageHistory.""" + # Specify a custom collection name that includes the session_id + collection_name = "chat_history_123" + message_history = ArangoChatMessageHistory( + session_id="123", db=db, collection_name=collection_name + ) + message_history.add_messages( + [ + HumanMessage(content="You are a helpful assistant."), + AIMessage(content="Hello"), + ] + ) + assert len(message_history.messages) == 2 + assert isinstance(message_history.messages[0], HumanMessage) + assert isinstance(message_history.messages[1], AIMessage) + assert message_history.messages[0].content == "You are a helpful assistant." + assert message_history.messages[1].content == "Hello" + + message_history.clear() + assert len(message_history.messages) == 0 + + # Verify all messages are removed but collection still exists + assert db.has_collection(message_history._collection_name) + assert message_history._collection_name == collection_name + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangodb_message_history_clear_session_collection( + db: StandardDatabase, +) -> None: + """Test clearing messages and removing the collection for a session.""" + # Create a test collection specific to the session + session_id = "456" + collection_name = f"chat_history_{session_id}" + + if not db.has_collection(collection_name): + db.create_collection(collection_name) + + message_history = ArangoChatMessageHistory( + session_id=session_id, db=db, collection_name=collection_name + ) + + message_history.add_messages( + [ + HumanMessage(content="You are a helpful assistant."), + AIMessage(content="Hello"), + ] + ) + assert len(message_history.messages) == 2 + + # Clear messages + message_history.clear() + assert len(message_history.messages) == 0 + + # The collection should still exist after clearing messages + assert db.has_collection(collection_name) + + # Delete the collection (equivalent to delete_session_node in Neo4j) + db.delete_collection(collection_name) + assert not db.has_collection(collection_name) diff --git a/libs/arangodb/tests/integration_tests/vectorstores/fake_embeddings.py b/libs/arangodb/tests/integration_tests/vectorstores/fake_embeddings.py new file mode 100644 index 0000000..9b19c4a --- /dev/null +++ b/libs/arangodb/tests/integration_tests/vectorstores/fake_embeddings.py @@ -0,0 +1,111 @@ +"""Fake Embedding class for testing purposes.""" + +import math +from typing import List + +from langchain_core.embeddings import Embeddings + 
+fake_texts = ["foo", "bar", "baz"] + + +class FakeEmbeddings(Embeddings): + """Fake embeddings functionality for testing.""" + + def __init__(self, dimension: int = 10): + if dimension < 1: + raise ValueError( + "Dimension must be at least 1 for this FakeEmbeddings style." + ) + self.dimension = dimension + # global_fake_texts maps query texts to the 'i' in [1.0]*(dim-1) + [float(i)] + self.global_fake_texts = ["foo", "bar", "baz", "qux", "quux", "corge", "grault"] + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return simple embeddings. + Embeddings encode each text as its index.""" + if self.dimension == 1: + # Special case for dimension 1: just use the index + return [[float(i)] for i in range(len(texts))] + else: + return [ + [1.0] * (self.dimension - 1) + [float(i)] for i in range(len(texts)) + ] + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + return self.embed_documents(texts) + + def embed_query(self, text: str) -> List[float]: + """Return constant query embeddings. + Embeddings are identical to embed_documents(texts)[0]. + Distance to each text will be that text's index, + as it was passed to embed_documents.""" + try: + idx = self.global_fake_texts.index(text) + val = float(idx) + except ValueError: + # Text not in global_fake_texts, use a default 'unknown query' value + val = -1.0 + + if self.dimension == 1: + return [val] # Corrected: List[float] + else: + return [1.0] * (self.dimension - 1) + [val] + + async def aembed_query(self, text: str) -> List[float]: + return self.embed_query(text) + + @property + def identifer(self) -> str: + return "fake" + + +class ConsistentFakeEmbeddings(FakeEmbeddings): + """Fake embeddings which remember all the texts seen so far to return consistent + vectors for the same texts.""" + + def __init__(self, dimensionality: int = 10) -> None: + self.known_texts: List[str] = [] + self.dimensionality = dimensionality + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return consistent embeddings for each text seen so far.""" + out_vectors = [] + for text in texts: + if text not in self.known_texts: + self.known_texts.append(text) + vector = [float(1.0)] * (self.dimensionality - 1) + [ + float(self.known_texts.index(text)) + ] + out_vectors.append(vector) + return out_vectors + + def embed_query(self, text: str) -> List[float]: + """Return consistent embeddings for the text, if seen before, or a constant + one if the text is unknown.""" + return self.embed_documents([text])[0] + + +class AngularTwoDimensionalEmbeddings(Embeddings): + """ + From angles (as strings in units of pi) to unit embedding vectors on a circle. + """ + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """ + Make a list of texts into a list of embedding vectors. + """ + return [self.embed_query(text) for text in texts] + + def embed_query(self, text: str) -> List[float]: + """ + Convert input text to a 'vector' (list of floats). + If the text is a number, use it as the angle for the + unit vector in units of pi. + Any other input text becomes the singular result [0, 0] ! + """ + try: + angle = float(text) + return [math.cos(angle * math.pi), math.sin(angle * math.pi)] + except ValueError: + # Assume: just test string, no attention is paid to values. 
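+ # Worked example (editor's note): embed_query("0.5") returns
+ # [cos(pi/2), sin(pi/2)] ~= [0.0, 1.0], while non-numeric text such
+ # as "hello" lands here and returns the constant vector below.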
+ return [0.0, 0.0] diff --git a/libs/arangodb/tests/integration_tests/vectorstores/test_arangodb_vector.py b/libs/arangodb/tests/integration_tests/vectorstores/test_arangodb_vector.py new file mode 100644 index 0000000..c6a81f4 --- /dev/null +++ b/libs/arangodb/tests/integration_tests/vectorstores/test_arangodb_vector.py @@ -0,0 +1,1064 @@ +"""Integration tests for ArangoVector.""" + +from typing import Any, Dict, List + +import pytest +from arango import ArangoClient +from arango.collection import StandardCollection +from arango.cursor import Cursor +from langchain_core.documents import Document + +from langchain_arangodb.vectorstores.arangodb_vector import ArangoVector +from langchain_arangodb.vectorstores.utils import DistanceStrategy +from tests.integration_tests.utils import ArangoCredentials + +from .fake_embeddings import FakeEmbeddings + +EMBEDDING_DIMENSION = 10 + + +@pytest.fixture(scope="session") +def fake_embedding_function() -> FakeEmbeddings: + """Provides a FakeEmbeddings instance.""" + return FakeEmbeddings() + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_from_texts_and_similarity_search( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test end-to-end construction from texts and basic similarity search.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + # Pre-create a scratch collection to verify the connection works + if not db.has_collection( + "test_collection_init" + ): # Use a different name to avoid conflict if already exists + _test_init_coll = db.create_collection("test_collection_init") + assert isinstance(_test_init_coll, StandardCollection) + + texts_to_embed = ["hello world", "hello arango", "test document"] + metadatas = [{"source": "doc1"}, {"source": "doc2"}, {"source": "doc3"}] + + vector_store = ArangoVector.from_texts( + texts=texts_to_embed, + embedding=fake_embedding_function, + metadatas=metadatas, + database=db, + collection_name="test_collection", + index_name="test_index", + overwrite_index=True, # Ensure clean state for the index + ) + + # Manually create the index as from_texts with overwrite=True only deletes it + # in the current version of arangodb_vector.py + vector_store.create_vector_index() + + # Check if the collection was created + assert db.has_collection("test_collection") + _collection_obj = db.collection("test_collection") + assert isinstance(_collection_obj, StandardCollection) + collection: StandardCollection = _collection_obj + assert collection.count() == len(texts_to_embed) + + # Check if the index was created + index_info = None + indexes_raw = collection.indexes() + assert indexes_raw is not None, "collection.indexes() returned None" + assert isinstance( + indexes_raw, list + ), f"collection.indexes() expected list, got {type(indexes_raw)}" + indexes: List[Dict[str, Any]] = indexes_raw + for index in indexes: + if index.get("name") == "test_index" and index.get("type") == "vector": + index_info = index + break + assert index_info is not None + assert index_info["fields"] == ["embedding"] # Default embedding field + + # Test similarity search + query = "hello" + results = vector_store.similarity_search(query, k=1, return_fields={"source"}) + + assert len(results) == 1 + assert results[0].page_content == "hello world" + assert results[0].metadata.get("source") == "doc1" + + +@pytest.mark.usefixtures("clear_arangodb_database") 
+def test_arangovector_euclidean_distance( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test ArangoVector with Euclidean distance.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + texts_to_embed = ["docA", "docB", "docC"] + + vector_store = ArangoVector.from_texts( + texts=texts_to_embed, + embedding=fake_embedding_function, + database=db, + collection_name="test_collection", + index_name="test_index", + distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE, + overwrite_index=True, + ) + + # Manually create the index as from_texts with overwrite=True only deletes it + vector_store.create_vector_index() + + # Check that the vector index was created + _collection_obj_euclidean = db.collection("test_collection") + assert isinstance(_collection_obj_euclidean, StandardCollection) + collection_euclidean: StandardCollection = _collection_obj_euclidean + index_info = None + indexes_raw_euclidean = collection_euclidean.indexes() + assert ( + indexes_raw_euclidean is not None + ), "collection_euclidean.indexes() returned None" + assert isinstance( + indexes_raw_euclidean, list + ), f"collection_euclidean.indexes() expected list, \ + got {type(indexes_raw_euclidean)}" + indexes_euclidean: List[Dict[str, Any]] = indexes_raw_euclidean + for index in indexes_euclidean: + if index.get("name") == "test_index" and index.get("type") == "vector": + index_info = index + break + + assert index_info is not None + query = "docA" + results = vector_store.similarity_search(query, k=1) + assert len(results) == 1 + assert results[0].page_content == "docA" + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_similarity_search_with_score( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test similarity search with scores.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + texts_to_embed = ["alpha", "beta", "gamma"] + metadatas = [{"id": 1}, {"id": 2}, {"id": 3}] + + vector_store = ArangoVector.from_texts( + texts=texts_to_embed, + embedding=fake_embedding_function, + metadatas=metadatas, + database=db, + collection_name="test_collection", + index_name="test_index", + overwrite_index=True, + ) + + query = "foo" + results_with_scores = vector_store.similarity_search_with_score( + query, k=1, return_fields={"id"} + ) + + assert len(results_with_scores) == 1 + doc, score = results_with_scores[0] + + assert doc.page_content == "alpha" + assert doc.metadata.get("id") == 1 + + # Test with exact cosine similarity + results_with_scores_exact = vector_store.similarity_search_with_score( + query, k=1, use_approx=False, return_fields={"id"} + ) + assert len(results_with_scores_exact) == 1 + doc_exact, score_exact = results_with_scores_exact[0] + assert doc_exact.page_content == "alpha" + assert ( + score_exact == 1.0 + ) # Exact cosine similarity should be 1.0 for identical vectors + + # Test with Euclidean distance + vector_store_l2 = ArangoVector.from_texts( + texts=texts_to_embed, # Re-using same texts for simplicity + embedding=fake_embedding_function, + metadatas=metadatas, + database=db, # db is managed by fixture, collection will be overwritten + collection_name="test_collection" + + "_l2", # Use a different collection or ensure overwrite + 
index_name="test_index" + "_l2", + distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE, + overwrite_index=True, + ) + results_with_scores_l2 = vector_store_l2.similarity_search_with_score( + query, k=1, return_fields={"id"} + ) + assert len(results_with_scores_l2) == 1 + doc_l2, score_l2 = results_with_scores_l2[0] + assert doc_l2.page_content == "alpha" + assert score_l2 == 0.0 # For L2 (Euclidean) distance, perfect match is 0.0 + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_add_embeddings_and_search( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test construction from pre-computed embeddings and search.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + texts_to_embed = ["apple", "banana", "cherry"] + metadatas = [ + {"fruit_type": "pome"}, + {"fruit_type": "berry"}, + {"fruit_type": "drupe"}, + ] + + # Manually create embeddings + embeddings = fake_embedding_function.embed_documents(texts_to_embed) + + # Initialize ArangoVector - embedding_dimension must match FakeEmbeddings + vector_store = ArangoVector( + embedding=fake_embedding_function, # Still needed for query embedding + embedding_dimension=EMBEDDING_DIMENSION, # Should be 10 from FakeEmbeddings + database=db, + collection_name="test_collection", # Will be created if not exists + vector_index_name="test_index", + ) + + # Add embeddings first, so the index has data to train on + vector_store.add_embeddings(texts_to_embed, embeddings, metadatas=metadatas) + + # Create the index if it doesn't exist + # For similarity_search to work with approx=True (default), an index is needed. 
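+ # (Editor's note, an assumption about ArangoDB vector search rather than
+ # behavior documented here: approximate search goes through the vector
+ # index, so the index must exist before querying, while exact search with
+ # use_approx=False compares embeddings directly and works without it.)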
+ if not vector_store.retrieve_vector_index(): + vector_store.create_vector_index() + + # Check collection count + _collection_obj_add_embed = db.collection("test_collection") + assert isinstance(_collection_obj_add_embed, StandardCollection) + collection_add_embed: StandardCollection = _collection_obj_add_embed + assert collection_add_embed.count() == len(texts_to_embed) + + # Perform search + query = "apple" + results = vector_store.similarity_search(query, k=1, return_fields={"fruit_type"}) + assert len(results) == 1 + assert results[0].page_content == "apple" + assert results[0].metadata.get("fruit_type") == "pome" + + +# NEW TEST +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_retriever_search_threshold( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test using retriever for searching with a score threshold.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + texts_to_embed = ["dog", "cat", "mouse"] + metadatas = [ + {"animal_type": "canine"}, + {"animal_type": "feline"}, + {"animal_type": "rodent"}, + ] + + vector_store = ArangoVector.from_texts( + texts=texts_to_embed, + embedding=fake_embedding_function, + metadatas=metadatas, + database=db, + collection_name="test_collection", + index_name="test_index", + overwrite_index=True, + ) + + # Default is COSINE, perfect match (score 1.0 with exact, close with approx) + # Test with a threshold that should only include a perfect/near-perfect match + retriever = vector_store.as_retriever( + search_type="similarity_score_threshold", + score_threshold=0.95, + search_kwargs={ + "k": 3, + "use_approx": False, + "score_threshold": 0.95, + "return_fields": {"animal_type"}, + }, + ) + + query = "foo" + results = retriever.invoke(query) + + assert len(results) == 1 + assert results[0].page_content == "dog" + assert results[0].metadata.get("animal_type") == "canine" + + retriever_strict = vector_store.as_retriever( + search_type="similarity_score_threshold", + score_threshold=1.01, + search_kwargs={ + "k": 3, + "use_approx": False, + "score_threshold": 1.01, + "return_fields": {"animal_type"}, + }, + ) + results_strict = retriever_strict.invoke(query) + assert len(results_strict) == 0 + + +# NEW TEST +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_delete_documents( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test deleting documents from ArangoVector.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + texts_to_embed = [ + "doc_to_keep1", + "doc_to_delete1", + "doc_to_keep2", + "doc_to_delete2", + ] + metadatas = [ + {"id_val": 1, "status": "keep"}, + {"id_val": 2, "status": "delete"}, + {"id_val": 3, "status": "keep"}, + {"id_val": 4, "status": "delete"}, + ] + + # Use specific IDs for easier deletion and verification + doc_ids = ["id_keep1", "id_delete1", "id_keep2", "id_delete2"] + + vector_store = ArangoVector.from_texts( + texts=texts_to_embed, + embedding=fake_embedding_function, + metadatas=metadatas, + ids=doc_ids, # Pass our custom IDs + database=db, + collection_name="test_collection", + index_name="test_index", + overwrite_index=True, + ) + + # Verify initial count + _collection_obj_delete = db.collection("test_collection") + 
assert isinstance(_collection_obj_delete, StandardCollection) + collection_delete: StandardCollection = _collection_obj_delete + assert collection_delete.count() == 4 + + # IDs to delete + ids_to_delete = ["id_delete1", "id_delete2"] + delete_result = vector_store.delete(ids=ids_to_delete) + assert delete_result is True + + # Verify count after deletion + assert collection_delete.count() == 2 + + # Verify that specific documents are gone and others remain + # Use direct DB checks for presence/absence of docs by ID + + # Check that deleted documents are indeed gone + deleted_docs_check_raw = collection_delete.get_many(ids_to_delete) + assert ( + deleted_docs_check_raw is not None + ), "collection.get_many() returned None for deleted_docs_check" + assert isinstance( + deleted_docs_check_raw, list + ), f"collection.get_many() expected list for deleted_docs_check,\ + got {type(deleted_docs_check_raw)}" + deleted_docs_check: List[Dict[str, Any]] = deleted_docs_check_raw + assert len(deleted_docs_check) == 0 + + # Check that remaining documents are still present + remaining_ids_expected = ["id_keep1", "id_keep2"] + remaining_docs_check_raw = collection_delete.get_many(remaining_ids_expected) + assert ( + remaining_docs_check_raw is not None + ), "collection.get_many() returned None for remaining_docs_check" + assert isinstance( + remaining_docs_check_raw, list + ), f"collection.get_many() expected list for remaining_docs_check,\ + got {type(remaining_docs_check_raw)}" + remaining_docs_check: List[Dict[str, Any]] = remaining_docs_check_raw + assert len(remaining_docs_check) == 2 + + # Optionally, verify content of remaining documents if needed + retrieved_contents = sorted( + [d[vector_store.text_field] for d in remaining_docs_check] + ) + assert retrieved_contents == sorted( + [texts_to_embed[0], texts_to_embed[2]] + ) # doc_to_keep1, doc_to_keep2 + + +# NEW TEST +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_similarity_search_with_return_fields( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test similarity search with specified return_fields for metadata.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + texts = ["alpha beta", "gamma delta", "epsilon zeta"] + metadatas = [ + {"source": "doc1", "chapter": "ch1", "page": 10, "author": "A"}, + {"source": "doc2", "chapter": "ch2", "page": 20, "author": "B"}, + {"source": "doc3", "chapter": "ch3", "page": 30, "author": "C"}, + ] + doc_ids = ["id1", "id2", "id3"] + + vector_store = ArangoVector.from_texts( + texts=texts, + embedding=fake_embedding_function, + metadatas=metadatas, + ids=doc_ids, + database=db, + collection_name="test_collection", + index_name="test_index", + overwrite_index=True, + ) + + query_text = "alpha beta" + + # Test 1: Explicitly request all metadata fields via return_fields + results_all_meta = vector_store.similarity_search( + query_text, k=1, return_fields={"source", "chapter", "page", "author"} + ) + assert len(results_all_meta) == 1 + assert results_all_meta[0].page_content == query_text + expected_meta_all = {"source": "doc1", "chapter": "ch1", "page": 10, "author": "A"} + assert results_all_meta[0].metadata == expected_meta_all + + # Test 2: Specific return_fields + fields_to_return = {"source", "page"} + results_specific_meta = vector_store.similarity_search( + query_text, k=1, 
return_fields=fields_to_return + ) + assert len(results_specific_meta) == 1 + assert results_specific_meta[0].page_content == query_text + expected_meta_specific = {"source": "doc1", "page": 10} + assert results_specific_meta[0].metadata == expected_meta_specific + + # Test 3: Requesting all metadata fields again yields the full metadata + results_empty_set_meta = vector_store.similarity_search( + query_text, k=1, return_fields={"source", "chapter", "page", "author"} + ) + assert len(results_empty_set_meta) == 1 + assert results_empty_set_meta[0].page_content == query_text + assert results_empty_set_meta[0].metadata == expected_meta_all + + # Test 4: return_fields requesting a non-existent field + # and one existing field + fields_with_non_existent = {"source", "non_existent_field"} + results_non_existent_meta = vector_store.similarity_search( + query_text, k=1, return_fields=fields_with_non_existent + ) + assert len(results_non_existent_meta) == 1 + assert results_non_existent_meta[0].page_content == query_text + expected_meta_non_existent = {"source": "doc1"} + assert results_non_existent_meta[0].metadata == expected_meta_non_existent + + +# NEW TEST +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_max_marginal_relevance_search( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, # Using existing FakeEmbeddings +) -> None: + """Test max marginal relevance search.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + + # Texts designed so some are close to each other via FakeEmbeddings + # FakeEmbeddings: embedding[last_dim] = index i + # apple (0), apricot (1) -> similar + # banana (2), blueberry (3) -> similar + # grape (4) -> distinct + texts = ["apple", "apricot", "banana", "blueberry", "grape"] + metadatas = [ + {"fruit": "apple", "idx": 0}, + {"fruit": "apricot", "idx": 1}, + {"fruit": "banana", "idx": 2}, + {"fruit": "blueberry", "idx": 3}, + {"fruit": "grape", "idx": 4}, + ] + doc_ids = ["id_apple", "id_apricot", "id_banana", "id_blueberry", "id_grape"] + + vector_store = ArangoVector.from_texts( + texts=texts, + embedding=fake_embedding_function, + metadatas=metadatas, + ids=doc_ids, + database=db, + collection_name="test_collection", + index_name="test_index", + overwrite_index=True, + ) + + query_text = "foo" + + # Test with lambda_mult = 0.5 (balance between similarity and diversity) + mmr_results = vector_store.max_marginal_relevance_search( + query_text, k=2, fetch_k=4, lambda_mult=0.5, use_approx=False + ) + assert len(mmr_results) == 2 + assert mmr_results[0].page_content == "apple" + # With new FakeEmbeddings, lambda=0.5 should pick "apricot" as second. 
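+ # (Editor's note: the query "foo" embeds identically to "apple"
+ # ([1.0] * 9 + [0.0]), and "apricot" ([1.0] * 9 + [1.0]) is the nearest
+ # remaining candidate, so it survives the relevance/diversity trade-off
+ # at lambda_mult=0.5.)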
+ assert mmr_results[1].page_content == "apricot" + + result_contents = {doc.page_content for doc in mmr_results} + assert "apple" in result_contents + assert len(result_contents) == 2 # Ensure two distinct docs + + # Test with lambda_mult favoring diversity (0.1; lower values weight diversity more) + mmr_results_div = vector_store.max_marginal_relevance_search( + query_text, k=2, fetch_k=4, lambda_mult=0.1, use_approx=False + ) + assert len(mmr_results_div) == 2 + assert mmr_results_div[0].page_content == "apple" + assert mmr_results_div[1].page_content == "blueberry" + + # Test with lambda_mult favoring similarity (0.9; higher values weight similarity more) + mmr_results_sim = vector_store.max_marginal_relevance_search( + query_text, k=2, fetch_k=4, lambda_mult=0.9, use_approx=False + ) + assert len(mmr_results_sim) == 2 + assert mmr_results_sim[0].page_content == "apple" + assert mmr_results_sim[1].page_content == "apricot" + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_delete_vector_index( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test creating and deleting a vector index.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + texts_to_embed = ["alpha", "beta", "gamma"] + + # Create the vector store + vector_store = ArangoVector.from_texts( + texts=texts_to_embed, + embedding=fake_embedding_function, + database=db, + collection_name="test_collection", + index_name="test_index", + overwrite_index=False, + ) + + # Create the index explicitly + vector_store.create_vector_index() + + # Verify the index exists + _collection_obj_del_idx = db.collection("test_collection") + assert isinstance(_collection_obj_del_idx, StandardCollection) + collection_del_idx: StandardCollection = _collection_obj_del_idx + index_info = None + indexes_raw_del_idx = collection_del_idx.indexes() + assert indexes_raw_del_idx is not None + assert isinstance(indexes_raw_del_idx, list) + indexes_del_idx: List[Dict[str, Any]] = indexes_raw_del_idx + for index in indexes_del_idx: + if index.get("name") == "test_index" and index.get("type") == "vector": + index_info = index + break + + assert index_info is not None, "Vector index was not created" + + # Now delete the index + vector_store.delete_vector_index() + + # Verify the index no longer exists + indexes_after_delete_raw = collection_del_idx.indexes() + assert indexes_after_delete_raw is not None + assert isinstance(indexes_after_delete_raw, list) + indexes_after_delete: List[Dict[str, Any]] = indexes_after_delete_raw + index_after_delete = None + for index in indexes_after_delete: + if index.get("name") == "test_index" and index.get("type") == "vector": + index_after_delete = index + break + + assert index_after_delete is None, "Vector index was not deleted" + + # Ensure delete_vector_index is idempotent (calling it again doesn't cause errors) + vector_store.delete_vector_index() + + # Recreate the index and verify + vector_store.create_vector_index() + + indexes_after_recreate_raw = collection_del_idx.indexes() + assert indexes_after_recreate_raw is not None + assert isinstance(indexes_after_recreate_raw, list) + indexes_after_recreate: List[Dict[str, Any]] = indexes_after_recreate_raw + index_after_recreate = None + for index in indexes_after_recreate: + if index.get("name") == "test_index" and index.get("type") == "vector": + index_after_recreate = index + break + + assert index_after_recreate is not None, "Vector 
index was not recreated" + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_get_by_ids( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test retrieving documents by their IDs.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + + # Create test data with specific IDs + texts = ["apple", "banana", "cherry", "date"] + custom_ids = ["fruit_1", "fruit_2", "fruit_3", "fruit_4"] + metadatas = [ + {"type": "pome", "color": "red", "calories": 95}, + {"type": "berry", "color": "yellow", "calories": 105}, + {"type": "drupe", "color": "red", "calories": 50}, + {"type": "drupe", "color": "brown", "calories": 20}, + ] + + # Create the vector store with custom IDs + vector_store = ArangoVector.from_texts( + texts=texts, + embedding=fake_embedding_function, + metadatas=metadatas, + ids=custom_ids, + database=db, + collection_name="test_collection", + ) + + # Create the index explicitly + vector_store.create_vector_index() + + # Test retrieving a single document by ID + single_doc = vector_store.get_by_ids(["fruit_1"]) + assert len(single_doc) == 1 + assert single_doc[0].page_content == "apple" + assert single_doc[0].id == "fruit_1" + assert single_doc[0].metadata["type"] == "pome" + assert single_doc[0].metadata["color"] == "red" + assert single_doc[0].metadata["calories"] == 95 + + # Test retrieving multiple documents by ID + docs = vector_store.get_by_ids(["fruit_2", "fruit_4"]) + assert len(docs) == 2 + + # Verify each document has the correct content and metadata + banana_doc = next((doc for doc in docs if doc.id == "fruit_2"), None) + date_doc = next((doc for doc in docs if doc.id == "fruit_4"), None) + + assert banana_doc is not None + assert banana_doc.page_content == "banana" + assert banana_doc.metadata["type"] == "berry" + assert banana_doc.metadata["color"] == "yellow" + + assert date_doc is not None + assert date_doc.page_content == "date" + assert date_doc.metadata["type"] == "drupe" + assert date_doc.metadata["color"] == "brown" + + # Test with non-existent ID (should return empty list for that ID) + non_existent_docs = vector_store.get_by_ids(["fruit_999"]) + assert len(non_existent_docs) == 0 + + # Test with mix of existing and non-existing IDs + mixed_docs = vector_store.get_by_ids(["fruit_1", "fruit_999", "fruit_3"]) + assert len(mixed_docs) == 2 # Only fruit_1 and fruit_3 should be found + + # Verify the documents match the expected content + found_ids = [doc.id for doc in mixed_docs] + assert "fruit_1" in found_ids + assert "fruit_3" in found_ids + assert "fruit_999" not in found_ids + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_core_functionality( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test the core functionality of ArangoVector with an integrated workflow.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + + # 1. 
Setup - Create a vector store with documents + corpus = [ + "The quick brown fox jumps over the lazy dog", + "Pack my box with five dozen liquor jugs", + "How vexingly quick daft zebras jump", + "Amazingly few discotheques provide jukeboxes", + "Sphinx of black quartz, judge my vow", + ] + + metadatas = [ + {"source": "english", "pangram": True, "length": len(corpus[0])}, + {"source": "english", "pangram": True, "length": len(corpus[1])}, + {"source": "english", "pangram": True, "length": len(corpus[2])}, + {"source": "english", "pangram": True, "length": len(corpus[3])}, + {"source": "english", "pangram": True, "length": len(corpus[4])}, + ] + + custom_ids = ["pangram_1", "pangram_2", "pangram_3", "pangram_4", "pangram_5"] + + vector_store = ArangoVector.from_texts( + texts=corpus, + embedding=fake_embedding_function, + metadatas=metadatas, + ids=custom_ids, + database=db, + collection_name="test_pangrams", + ) + + # Create the vector index + vector_store.create_vector_index() + + # 2. Test similarity_search - the most basic search function + query = "jumps" + results = vector_store.similarity_search(query, k=2) + + # Should return documents with "jumps" in them + assert len(results) == 2 + text_contents = [doc.page_content for doc in results] + # The most relevant results should include docs with "jumps" + has_jump_docs = [doc for doc in text_contents if "jump" in doc.lower()] + assert len(has_jump_docs) > 0 + + # 3. Test similarity_search_with_score - core search with relevance scores + results_with_scores = vector_store.similarity_search_with_score( + query, k=3, return_fields={"source", "pangram"} + ) + + assert len(results_with_scores) == 3 + # Check result format + for doc, score in results_with_scores: + assert isinstance(doc, Document) + assert isinstance(score, float) + # Verify metadata got properly transferred + assert doc.metadata["source"] == "english" + assert doc.metadata["pangram"] is True + + # 4. Test similarity_search_by_vector_with_score + query_embedding = fake_embedding_function.embed_query(query) + vector_results = vector_store.similarity_search_by_vector_with_score( + embedding=query_embedding, + k=2, + return_fields={"source", "length"}, + ) + + assert len(vector_results) == 2 + # Check result format + for doc, score in vector_results: + assert isinstance(doc, Document) + assert isinstance(score, float) + # Verify specific metadata fields were returned + assert "source" in doc.metadata + assert "length" in doc.metadata + # Verify length is a number (as defined in metadatas) + assert isinstance(doc.metadata["length"], int) + + # 5. Test with exact search (non-approximate) + exact_results = vector_store.similarity_search_with_score( + query, k=2, use_approx=False + ) + assert len(exact_results) == 2 + + # 6. Test max_marginal_relevance_search - for getting diverse results + mmr_results = vector_store.max_marginal_relevance_search( + query, k=3, fetch_k=5, lambda_mult=0.5 + ) + assert len(mmr_results) == 3 + # MMR results should be diverse, so they might differ from regular search + + # 7. 
Test adding new documents to the existing vector store + new_texts = ["The five boxing wizards jump quickly"] + new_metadatas = [ + {"source": "english", "pangram": True, "length": len(new_texts[0])} + ] + new_ids = vector_store.add_texts(texts=new_texts, metadatas=new_metadatas) + + # Verify the document was added by directly checking the collection + _collection_obj_core = db.collection("test_pangrams") + assert isinstance(_collection_obj_core, StandardCollection) + collection_core: StandardCollection = _collection_obj_core + assert collection_core.count() == 6 # Original 5 + 1 new document + + # Verify retrieving by ID works + added_doc = vector_store.get_by_ids([new_ids[0]]) + assert len(added_doc) == 1 + assert added_doc[0].page_content == new_texts[0] + assert "wizard" in added_doc[0].page_content.lower() + + # 8. Testing search by ID + all_docs_cursor = collection_core.all() + assert all_docs_cursor is not None, "collection.all() returned None" + assert isinstance( + all_docs_cursor, Cursor + ), f"collection.all() expected Cursor, got {type(all_docs_cursor)}" + all_ids = [doc["_key"] for doc in all_docs_cursor] + assert new_ids[0] in all_ids + + # 9. Test deleting documents + vector_store.delete(ids=[new_ids[0]]) + + # Verify the document was deleted + deleted_check = vector_store.get_by_ids([new_ids[0]]) + assert len(deleted_check) == 0 + + # Also verify via direct collection count + assert collection_core.count() == 5 # Back to the original 5 documents + + +@pytest.mark.usefixtures("clear_arangodb_database") +def test_arangovector_from_existing_collection( + arangodb_credentials: ArangoCredentials, + fake_embedding_function: FakeEmbeddings, +) -> None: + """Test creating a vector store from an existing collection.""" + client = ArangoClient(hosts=arangodb_credentials["url"]) + db = client.db( + username=arangodb_credentials["username"], + password=arangodb_credentials["password"], + ) + + # Create a test collection with documents that have multiple text fields + collection_name = "test_source_collection" + + if db.has_collection(collection_name): + db.delete_collection(collection_name) + + _collection_obj_exist = db.create_collection(collection_name) + assert isinstance(_collection_obj_exist, StandardCollection) + collection_exist: StandardCollection = _collection_obj_exist + # Create documents with multiple text fields to test different scenarios + documents = [ + { + "_key": "doc1", + "title": "The Solar System", + "abstract": ( + "The Solar System is the gravitationally bound system of the " + "Sun and the objects that orbit it." + ), + "content": ( + "The Solar System formed 4.6 billion years ago from the " + "gravitational collapse of a giant interstellar molecular cloud." + ), + "tags": ["astronomy", "science", "space"], + "author": "John Doe", + }, + { + "_key": "doc2", + "title": "Machine Learning", + "abstract": ( + "Machine learning is a field of inquiry devoted to understanding and " + "building methods that 'learn'." + ), + "content": ( + "Machine learning approaches are traditionally divided into three broad" + " categories: supervised, unsupervised, and reinforcement learning." + ), + "tags": ["ai", "computer science", "data science"], + "author": "Jane Smith", + }, + { + "_key": "doc3", + "title": "The Theory of Relativity", + "abstract": ( + "The theory of relativity usually encompasses two interrelated" + " theories by Albert Einstein." + ), + "content": ( + "Special relativity applies to all physical phenomena in the absence of" + " gravity. 
General relativity explains the law of gravitation and its" + " relation to other forces of nature." + ), + "tags": ["physics", "science", "Einstein"], + "author": "Albert Einstein", + }, + { + "_key": "doc4", + "title": "Quantum Mechanics", + "abstract": ( + "Quantum mechanics is a fundamental theory in physics that provides a" + " description of the physical properties of nature " + " at the scale of atoms and subatomic particles." + ), + "content": ( + "Quantum mechanics allows the calculation of properties and behaviour " + "of physical systems." + ), + "tags": ["physics", "science", "quantum"], + "author": "Max Planck", + }, + ] + + # Import documents to the collection + collection_exist.import_bulk(documents) + assert collection_exist.count() == 4 + + # 1. Basic usage - embedding title and abstract + text_properties = ["title", "abstract"] + + vector_store = ArangoVector.from_existing_collection( + collection_name=collection_name, + text_properties_to_embed=text_properties, + embedding=fake_embedding_function, + database=db, + embedding_field="embedding", + text_field="combined_text", + insert_text=True, + ) + + # Create the vector index + vector_store.create_vector_index() + + # Verify the vector store was created correctly + # First, check that the original collection still has 4 documents + assert collection_exist.count() == 4 + + # Check that embeddings were added to the original documents + doc_data1 = collection_exist.get("doc1") + assert doc_data1 is not None, "Document 'doc1' not found in collection_exist" + assert isinstance( + doc_data1, dict + ), f"Expected 'doc1' to be a dict, got {type(doc_data1)}" + doc1: Dict[str, Any] = doc_data1 + assert "embedding" in doc1 + assert isinstance(doc1["embedding"], list) + assert "combined_text" in doc1 # Now this field should exist + + # Perform a search to verify functionality + results = vector_store.similarity_search("astronomy") + assert len(results) > 0 + + # 2. Test with custom AQL query to modify the text extraction + custom_aql_query = "RETURN CONCAT(doc[p], ' by ', doc.author)" + + vector_store_custom = ArangoVector.from_existing_collection( + collection_name=collection_name, + text_properties_to_embed=["title"], # Only embed titles + embedding=fake_embedding_function, + database=db, + embedding_field="custom_embedding", + text_field="custom_text", + index_name="custom_vector_index", + aql_return_text_query=custom_aql_query, + insert_text=True, + ) + + # Create the vector index + vector_store_custom.create_vector_index() + + # Check that custom embeddings were added + doc_data2 = collection_exist.get("doc1") + assert doc_data2 is not None, "Document 'doc1' not found after custom processing" + assert isinstance( + doc_data2, dict + ), f"Expected 'doc1' after custom processing to be a dict, got {type(doc_data2)}" + doc2: Dict[str, Any] = doc_data2 + assert "custom_embedding" in doc2 + assert "custom_text" in doc2 + assert "by John Doe" in doc2["custom_text"] # Check the custom extraction format + + # 3. 
Test with skip_existing_embeddings=True + vector_store.delete_vector_index() + + collection_exist.update({"_key": "doc3", "embedding": None}) + + vector_store_skip = ArangoVector.from_existing_collection( + collection_name=collection_name, + text_properties_to_embed=["title", "abstract"], + embedding=fake_embedding_function, + database=db, + embedding_field="embedding", + text_field="combined_text", + index_name="skip_vector_index", # Use a different index name + skip_existing_embeddings=True, + insert_text=True, # Important for search to work + ) + + # Create the vector index + vector_store_skip.create_vector_index() + + # 4. Test with insert_text=True + vector_store_insert = ArangoVector.from_existing_collection( + collection_name=collection_name, + text_properties_to_embed=["title", "content"], + embedding=fake_embedding_function, + database=db, + embedding_field="content_embedding", + text_field="combined_title_content", + index_name="content_vector_index", # Use a different index name + insert_text=True, # Already set to True, but kept for clarity + ) + + # Create the vector index + vector_store_insert.create_vector_index() + + # Check that the combined text was inserted + doc_data3 = collection_exist.get("doc1") + assert ( + doc_data3 is not None + ), "Document 'doc1' not found after insert_text processing" + assert isinstance( + doc_data3, dict + ), f"Expected 'doc1' after insert_text to be a dict, got {type(doc_data3)}" + doc3: Dict[str, Any] = doc_data3 + assert "combined_title_content" in doc3 + assert "The Solar System" in doc3["combined_title_content"] + assert "formed 4.6 billion years ago" in doc3["combined_title_content"] + + # 5. Test searching in the custom store + results_custom = vector_store_custom.similarity_search("Einstein", k=1) + assert len(results_custom) == 1 + + # 6. Test max_marginal_relevance search + mmr_results = vector_store.max_marginal_relevance_search( + "science", k=2, fetch_k=4, lambda_mult=0.5 + ) + assert len(mmr_results) == 2 + + # 7. 
Test the get_by_ids method + docs = vector_store.get_by_ids(["doc1", "doc3"]) + assert len(docs) == 2 + assert any(doc.id == "doc1" for doc in docs) + assert any(doc.id == "doc3" for doc in docs) diff --git a/libs/arangodb/tests/unit_tests/chat_message_histories/test_arangodb_chat_message_history.py b/libs/arangodb/tests/unit_tests/chat_message_histories/test_arangodb_chat_message_history.py new file mode 100644 index 0000000..28592b4 --- /dev/null +++ b/libs/arangodb/tests/unit_tests/chat_message_histories/test_arangodb_chat_message_history.py @@ -0,0 +1,200 @@ +from unittest.mock import MagicMock + +import pytest +from arango.database import StandardDatabase + +from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory + + +def test_init_without_session_id() -> None: + """Test initializing without session_id raises ValueError.""" + mock_db = MagicMock(spec=StandardDatabase) + with pytest.raises(ValueError) as exc_info: + ArangoChatMessageHistory(None, db=mock_db) # type: ignore[arg-type] + assert "Please ensure that the session_id parameter is provided" in str( + exc_info.value + ) + + +def test_messages_setter() -> None: + """Test that assigning to messages raises NotImplementedError.""" + mock_db = MagicMock(spec=StandardDatabase) + mock_collection = MagicMock() + mock_db.collection.return_value = mock_collection + mock_db.has_collection.return_value = True + + message_store = ArangoChatMessageHistory( + session_id="test_session", + db=mock_db, + ) + + with pytest.raises(NotImplementedError) as exc_info: + message_store.messages = [] + assert "Direct assignment to 'messages' is not allowed." in str(exc_info.value) + + +def test_collection_creation() -> None: + """Test that collection is created if it doesn't exist.""" + mock_db = MagicMock(spec=StandardDatabase) + mock_collection = MagicMock() + mock_db.collection.return_value = mock_collection + + # First test when collection doesn't exist + mock_db.has_collection.return_value = False + + ArangoChatMessageHistory( + session_id="test_session", + db=mock_db, + collection_name="TestCollection", + ) + + # Verify collection creation was called + mock_db.create_collection.assert_called_once_with("TestCollection") + mock_db.collection.assert_called_once_with("TestCollection") + + # Now test when collection exists + mock_db.reset_mock() + mock_db.has_collection.return_value = True + + ArangoChatMessageHistory( + session_id="test_session", + db=mock_db, + collection_name="TestCollection", + ) + + # Verify collection creation was not called + mock_db.create_collection.assert_not_called() + mock_db.collection.assert_called_once_with("TestCollection") + + +def test_index_creation() -> None: + """Test that index on session_id is created if it doesn't exist.""" + mock_db = MagicMock(spec=StandardDatabase) + mock_collection = MagicMock() + mock_db.collection.return_value = mock_collection + mock_db.has_collection.return_value = True + + # First test when index doesn't exist + mock_collection.indexes.return_value = [] + + ArangoChatMessageHistory( + session_id="test_session", + db=mock_db, + ) + + # Verify index creation was called + mock_collection.add_persistent_index.assert_called_once_with( + ["session_id"], unique=False + ) + + # Now test when index exists + mock_db.reset_mock() + mock_collection.reset_mock() + mock_collection.indexes.return_value = [{"fields": ["session_id"]}] + + ArangoChatMessageHistory( + session_id="test_session", + db=mock_db, + ) + + # Verify index creation was not called + 
mock_collection.add_persistent_index.assert_not_called() + + +def test_add_message() -> None: + """Test adding a message to the collection.""" + mock_db = MagicMock(spec=StandardDatabase) + mock_collection = MagicMock() + mock_db.collection.return_value = mock_collection + mock_db.has_collection.return_value = True + mock_collection.indexes.return_value = [{"fields": ["session_id"]}] + + message_store = ArangoChatMessageHistory( + session_id="test_session", + db=mock_db, + ) + + # Create a mock message + mock_message = MagicMock() + mock_message.type = "human" + mock_message.content = "Hello, world!" + + # Add the message + message_store.add_message(mock_message) + + # Verify the message was added to the collection + mock_db.collection.assert_called_with("ChatHistory") + mock_collection.insert.assert_called_once_with( + { + "role": "human", + "content": "Hello, world!", + "session_id": "test_session", + } + ) + + +def test_clear() -> None: + """Test clearing messages from the collection.""" + mock_db = MagicMock(spec=StandardDatabase) + mock_collection = MagicMock() + mock_aql = MagicMock() + mock_db.collection.return_value = mock_collection + mock_db.aql = mock_aql + mock_db.has_collection.return_value = True + mock_collection.indexes.return_value = [{"fields": ["session_id"]}] + + message_store = ArangoChatMessageHistory( + session_id="test_session", + db=mock_db, + ) + + # Clear the messages + message_store.clear() + + # Verify the AQL query was executed + mock_aql.execute.assert_called_once() + # Check that the bind variables are correct + call_args = mock_aql.execute.call_args[1] + assert call_args["bind_vars"]["@col"] == "ChatHistory" + assert call_args["bind_vars"]["session_id"] == "test_session" + + +def test_messages_property() -> None: + """Test retrieving messages from the collection.""" + mock_db = MagicMock(spec=StandardDatabase) + mock_collection = MagicMock() + mock_aql = MagicMock() + mock_cursor = MagicMock() + mock_db.collection.return_value = mock_collection + mock_db.aql = mock_aql + mock_db.has_collection.return_value = True + mock_collection.indexes.return_value = [{"fields": ["session_id"]}] + mock_aql.execute.return_value = mock_cursor + + # Mock cursor to return two messages + mock_cursor.__iter__.return_value = [ + {"role": "human", "content": "Hello"}, + {"role": "ai", "content": "Hi there"}, + ] + + message_store = ArangoChatMessageHistory( + session_id="test_session", + db=mock_db, + ) + + # Get the messages + messages = message_store.messages + + # Verify the AQL query was executed + mock_aql.execute.assert_called_once() + # Check that the bind variables are correct + call_args = mock_aql.execute.call_args[1] + assert call_args["bind_vars"]["@col"] == "ChatHistory" + assert call_args["bind_vars"]["session_id"] == "test_session" + + # Check that we got the right number of messages + assert len(messages) == 2 + assert messages[0].type == "human" + assert messages[0].content == "Hello" + assert messages[1].type == "ai" + assert messages[1].content == "Hi there" diff --git a/libs/arangodb/tests/unit_tests/vectorstores/test_arangodb.py b/libs/arangodb/tests/unit_tests/vectorstores/test_arangodb.py new file mode 100644 index 0000000..197181a --- /dev/null +++ b/libs/arangodb/tests/unit_tests/vectorstores/test_arangodb.py @@ -0,0 +1,650 @@ +from typing import Any, Optional +from unittest.mock import MagicMock, patch + +import pytest + +from langchain_arangodb.vectorstores.arangodb_vector import ( + ArangoVector, + DistanceStrategy, + StandardDatabase, +) + + 
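+# Editor's note: the fixtures below stub the python-arango client with
+# MagicMock so these unit tests run without a live ArangoDB server; e.g.
+# mock_db.has_collection.return_value controls whether the store believes
+# its collection already exists.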
+@pytest.fixture +def mock_vector_store() -> ArangoVector: + """Create a mock ArangoVector instance for testing.""" + mock_db = MagicMock() + mock_collection = MagicMock() + mock_async_db = MagicMock() + + mock_db.has_collection.return_value = True + mock_db.collection.return_value = mock_collection + mock_db.begin_async_execution.return_value = mock_async_db + + with patch( + "langchain_arangodb.vectorstores.arangodb_vector.StandardDatabase", + return_value=mock_db, + ): + vector_store = ArangoVector( + embedding=MagicMock(), + embedding_dimension=64, + database=mock_db, + ) + + return vector_store + + +@pytest.fixture +def arango_vector_factory() -> Any: + """Factory fixture to create ArangoVector instances + with different configurations.""" + + def _create_vector_store( + method: Optional[str] = None, + texts: Optional[list[str]] = None, + text_embeddings: Optional[list[tuple[str, list[float]]]] = None, + collection_exists: bool = True, + vector_index_exists: bool = True, + **kwargs: Any, + ) -> Any: + mock_db = MagicMock() + mock_collection = MagicMock() + mock_async_db = MagicMock() + + # Configure has_collection + mock_db.has_collection.return_value = collection_exists + mock_db.collection.return_value = mock_collection + mock_db.begin_async_execution.return_value = mock_async_db + + # Configure vector index + if vector_index_exists: + mock_collection.indexes.return_value = [ + { + "name": kwargs.get("index_name", "vector_index"), + "type": "vector", + "fields": [kwargs.get("embedding_field", "embedding")], + "id": "12345", + } + ] + else: + mock_collection.indexes.return_value = [] + + # Create embedding instance + embedding = kwargs.pop("embedding", MagicMock()) + if embedding is not None: + embedding.embed_documents.return_value = [ + [0.1] * kwargs.get("embedding_dimension", 64) + ] * (len(texts) if texts else 1) + embedding.embed_query.return_value = [0.1] * kwargs.get( + "embedding_dimension", 64 + ) + + # Create vector store based on method + common_kwargs = { + "embedding": embedding, + "database": mock_db, + **kwargs, + } + + if method == "from_texts" and texts: + common_kwargs["embedding_dimension"] = kwargs.get("embedding_dimension", 64) + vector_store = ArangoVector.from_texts( + texts=texts, + **common_kwargs, + ) + elif method == "from_embeddings" and text_embeddings: + texts = [t[0] for t in text_embeddings] + embeddings = [t[1] for t in text_embeddings] + + with patch.object( + ArangoVector, "add_embeddings", return_value=["id1", "id2"] + ): + vector_store = ArangoVector( + **common_kwargs, + embedding_dimension=len(embeddings[0]) if embeddings else 64, + ) + else: + vector_store = ArangoVector( + **common_kwargs, + embedding_dimension=kwargs.get("embedding_dimension", 64), + ) + + return vector_store + + return _create_vector_store + + +def test_init_with_invalid_search_type() -> None: + """Test that initializing with an invalid search type raises ValueError.""" + mock_db = MagicMock() + + with pytest.raises(ValueError) as exc_info: + ArangoVector( + embedding=MagicMock(), + embedding_dimension=64, + database=mock_db, + search_type="invalid_search_type", # type: ignore + ) + + assert "search_type must be 'vector'" in str(exc_info.value) + + +def test_init_with_invalid_distance_strategy() -> None: + """Test that initializing with an invalid distance strategy raises ValueError.""" + mock_db = MagicMock() + + with pytest.raises(ValueError) as exc_info: + ArangoVector( + embedding=MagicMock(), + embedding_dimension=64, + database=mock_db, + 
distance_strategy="INVALID_STRATEGY", # type: ignore + ) + + assert "distance_strategy must be 'COSINE' or 'EUCLIDEAN_DISTANCE'" in str( + exc_info.value + ) + + +def test_collection_creation_if_not_exists(arango_vector_factory: Any) -> None: + """Test that collection is created if it doesn't exist.""" + # Configure collection doesn't exist + vector_store = arango_vector_factory(collection_exists=False) + + # Verify collection was created + vector_store.db.create_collection.assert_called_once_with( + vector_store.collection_name + ) + + +def test_collection_not_created_if_exists(arango_vector_factory: Any) -> None: + """Test that collection is not created if it already exists.""" + # Configure collection exists + vector_store = arango_vector_factory(collection_exists=True) + + # Verify collection was not created + vector_store.db.create_collection.assert_not_called() + + +def test_retrieve_vector_index_exists(arango_vector_factory: Any) -> None: + """Test retrieving vector index when it exists.""" + vector_store = arango_vector_factory(vector_index_exists=True) + + index = vector_store.retrieve_vector_index() + + assert index is not None + assert index["name"] == "vector_index" + assert index["type"] == "vector" + + +def test_retrieve_vector_index_not_exists(arango_vector_factory: Any) -> None: + """Test retrieving vector index when it doesn't exist.""" + vector_store = arango_vector_factory(vector_index_exists=False) + + index = vector_store.retrieve_vector_index() + + assert index is None + + +def test_create_vector_index(arango_vector_factory: Any) -> None: + """Test creating vector index.""" + vector_store = arango_vector_factory() + + vector_store.create_vector_index() + + # Verify index creation was called with correct parameters + vector_store.collection.add_index.assert_called_once() + + call_args = vector_store.collection.add_index.call_args[0][0] + assert call_args["name"] == "vector_index" + assert call_args["type"] == "vector" + assert call_args["fields"] == ["embedding"] + assert call_args["params"]["metric"] == "cosine" + assert call_args["params"]["dimension"] == 64 + + +def test_delete_vector_index_exists(arango_vector_factory: Any) -> None: + """Test deleting vector index when it exists.""" + vector_store = arango_vector_factory(vector_index_exists=True) + + with patch.object( + vector_store, + "retrieve_vector_index", + return_value={"id": "12345", "name": "vector_index"}, + ): + vector_store.delete_vector_index() + + # Verify delete_index was called with correct ID + vector_store.collection.delete_index.assert_called_once_with("12345") + + +def test_delete_vector_index_not_exists(arango_vector_factory: Any) -> None: + """Test deleting vector index when it doesn't exist.""" + vector_store = arango_vector_factory(vector_index_exists=False) + + with patch.object(vector_store, "retrieve_vector_index", return_value=None): + vector_store.delete_vector_index() + + # Verify delete_index was not called + vector_store.collection.delete_index.assert_not_called() + + +def test_delete_vector_index_with_real_index_data(arango_vector_factory: Any) -> None: + """Test deleting vector index with real index data structure.""" + vector_store = arango_vector_factory(vector_index_exists=True) + + # Create a realistic index object with all expected fields + mock_index = { + "id": "vector_index_12345", + "name": "vector_index", + "type": "vector", + "fields": ["embedding"], + "selectivity": 1, + "sparse": False, + "unique": False, + "deduplicate": False, + } + + # Mock retrieve_vector_index to 
+
+
+def test_delete_vector_index_with_real_index_data(arango_vector_factory: Any) -> None:
+    """Test deleting the vector index with a realistic index data structure."""
+    vector_store = arango_vector_factory(vector_index_exists=True)
+
+    # Create a realistic index object with all expected fields
+    mock_index = {
+        "id": "vector_index_12345",
+        "name": "vector_index",
+        "type": "vector",
+        "fields": ["embedding"],
+        "selectivity": 1,
+        "sparse": False,
+        "unique": False,
+        "deduplicate": False,
+    }
+
+    # Mock retrieve_vector_index to return our realistic index
+    with patch.object(vector_store, "retrieve_vector_index", return_value=mock_index):
+        # Call the method under test
+        vector_store.delete_vector_index()
+
+        # Verify delete_index was called with the exact ID from our mock index
+        vector_store.collection.delete_index.assert_called_once_with(
+            "vector_index_12345"
+        )
+
+    # Test the case where the index doesn't have an "id" field
+    bad_index = {"name": "vector_index", "type": "vector"}
+    with patch.object(vector_store, "retrieve_vector_index", return_value=bad_index):
+        with pytest.raises(KeyError):
+            vector_store.delete_vector_index()
+
+
+def test_add_embeddings_with_mismatched_lengths(arango_vector_factory: Any) -> None:
+    """Test that adding embeddings with mismatched input lengths raises ValueError."""
+    vector_store = arango_vector_factory()
+
+    # Deliberately mismatched: 1 id, 2 texts, 3 embeddings, 4 metadatas
+    ids = ["id1"]
+    texts = ["text1", "text2"]
+    embeddings = [[0.1] * 64, [0.2] * 64, [0.3] * 64]
+    metadatas = [
+        {"key": "value1"},
+        {"key": "value2"},
+        {"key": "value3"},
+        {"key": "value4"},
+    ]
+
+    with pytest.raises(ValueError) as exc_info:
+        vector_store.add_embeddings(
+            texts=texts,
+            embeddings=embeddings,
+            metadatas=metadatas,
+            ids=ids,
+        )
+
+    assert "Length of ids, texts, embeddings and metadatas must be the same" in str(
+        exc_info.value
+    )
+
+
+def test_add_embeddings(arango_vector_factory: Any) -> None:
+    """Test adding embeddings to the vector store."""
+    vector_store = arango_vector_factory()
+
+    texts = ["text1", "text2"]
+    embeddings = [[0.1] * 64, [0.2] * 64]
+    metadatas = [{"key": "value1"}, {"key": "value2"}]
+
+    with patch(
+        "langchain_arangodb.vectorstores.arangodb_vector.farmhash.Fingerprint64"
+    ) as mock_hash:
+        mock_hash.side_effect = ["id1", "id2"]
+
+        ids = vector_store.add_embeddings(
+            texts=texts,
+            embeddings=embeddings,
+            metadatas=metadatas,
+        )
+
+        # Verify import_bulk was called
+        vector_store.collection.import_bulk.assert_called()
+
+        # Check the data structure
+        call_args = vector_store.collection.import_bulk.call_args_list[0][0][0]
+        assert len(call_args) == 2
+        assert call_args[0]["_key"] == "id1"
+        assert call_args[0]["text"] == "text1"
+        assert call_args[0]["embedding"] == embeddings[0]
+        assert call_args[0]["key"] == "value1"
+
+        assert call_args[1]["_key"] == "id2"
+        assert call_args[1]["text"] == "text2"
+        assert call_args[1]["embedding"] == embeddings[1]
+        assert call_args[1]["key"] == "value2"
+
+        # Verify the correct IDs were returned
+        assert ids == ["id1", "id2"]
+
+
+def test_add_texts(arango_vector_factory: Any) -> None:
+    """Test adding texts to the vector store."""
+    vector_store = arango_vector_factory()
+
+    texts = ["text1", "text2"]
+    metadatas = [{"key": "value1"}, {"key": "value2"}]
+
+    # Mock the embedding.embed_documents method
+    mock_embeddings = [[0.1] * 64, [0.2] * 64]
+    vector_store.embedding.embed_documents.return_value = mock_embeddings
+
+    # Mock the add_embeddings method
+    with patch.object(
+        vector_store, "add_embeddings", return_value=["id1", "id2"]
+    ) as mock_add_embeddings:
+        ids = vector_store.add_texts(
+            texts=texts,
+            metadatas=metadatas,
+        )
+
+        # Verify embed_documents was called with the texts
+        vector_store.embedding.embed_documents.assert_called_once_with(texts)
+
+        # Verify add_embeddings was called with the correct parameters
+        mock_add_embeddings.assert_called_once_with(
+            texts=texts,
+            embeddings=mock_embeddings,
+            metadatas=metadatas,
+            ids=None,
+        )
+
+        # Verify the correct IDs were returned
+        assert ids == ["id1", "id2"]
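+
+
+# Editorial sketch (an assumption, not the shipped implementation): the
+# bulk-import payload that test_add_embeddings above asserts on — one document
+# per text, with a hash-derived _key (the test patches farmhash.Fingerprint64
+# for determinism), the raw text, its embedding, and the metadata flattened
+# into the document. The helper name is hypothetical.
+def _sketch_import_bulk_payload(
+    ids: list, texts: list, embeddings: list, metadatas: list
+) -> list:
+    """Illustrative only: mirrors the document shape asserted above."""
+    return [
+        {"_key": id_, "text": text, "embedding": embedding, **metadata}
+        for id_, text, embedding, metadata in zip(ids, texts, embeddings, metadatas)
+    ]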
"""Test similarity search.""" + vector_store = arango_vector_factory() + + # Mock the embedding.embed_query method + mock_embedding = [0.1] * 64 + vector_store.embedding.embed_query.return_value = mock_embedding + + # Mock the similarity_search_by_vector method + expected_docs = [MagicMock(), MagicMock()] + with patch.object( + vector_store, "similarity_search_by_vector", return_value=expected_docs + ) as mock_search_by_vector: + docs = vector_store.similarity_search( + query="test query", + k=2, + return_fields={"field1", "field2"}, + use_approx=True, + ) + + # Verify embed_query was called with query + vector_store.embedding.embed_query.assert_called_once_with("test query") + + # Verify similarity_search_by_vector was called with correct parameters + mock_search_by_vector.assert_called_once_with( + embedding=mock_embedding, + k=2, + return_fields={"field1", "field2"}, + use_approx=True, + filter_clause="", + ) + + # Verify the correct documents were returned + assert docs == expected_docs + + +def test_similarity_search_with_score(arango_vector_factory: Any) -> None: + """Test similarity search with score.""" + vector_store = arango_vector_factory() + + # Mock the embedding.embed_query method + mock_embedding = [0.1] * 64 + vector_store.embedding.embed_query.return_value = mock_embedding + + # Mock the similarity_search_by_vector_with_score method + expected_results = [(MagicMock(), 0.8), (MagicMock(), 0.6)] + with patch.object( + vector_store, + "similarity_search_by_vector_with_score", + return_value=expected_results, + ) as mock_search_by_vector_with_score: + results = vector_store.similarity_search_with_score( + query="test query", + k=2, + return_fields={"field1", "field2"}, + use_approx=True, + ) + + # Verify embed_query was called with query + vector_store.embedding.embed_query.assert_called_once_with("test query") + + # Verify similarity_search_by_vector_with_score was called with correct parameters + mock_search_by_vector_with_score.assert_called_once_with( + embedding=mock_embedding, + k=2, + return_fields={"field1", "field2"}, + use_approx=True, + filter_clause="", + ) + + # Verify the correct results were returned + assert results == expected_results + + +def test_max_marginal_relevance_search(arango_vector_factory: Any) -> None: + """Test max marginal relevance search.""" + vector_store = arango_vector_factory() + + # Mock the embedding.embed_query method + mock_embedding = [0.1] * 64 + vector_store.embedding.embed_query.return_value = mock_embedding + + # Create mock documents and similarity scores + mock_docs = [MagicMock(), MagicMock(), MagicMock()] + mock_similarities = [0.9, 0.8, 0.7] + + with ( + patch.object( + vector_store, + "similarity_search_by_vector_with_score", + return_value=list(zip(mock_docs, mock_similarities)), + ), + patch( + "langchain_arangodb.vectorstores.arangodb_vector.maximal_marginal_relevance", + return_value=[0, 2], # Indices of selected documents + ) as mock_mmr, + ): + results = vector_store.max_marginal_relevance_search( + query="test query", + k=2, + fetch_k=3, + lambda_mult=0.5, + ) + + # Verify embed_query was called with query + vector_store.embedding.embed_query.assert_called_once_with("test query") + + mmr_call_kwargs = mock_mmr.call_args[1] + assert mmr_call_kwargs["k"] == 2 + assert mmr_call_kwargs["lambda_mult"] == 0.5 + + # Verify the selected documents were returned + assert results == [mock_docs[0], mock_docs[2]] + + +def test_from_texts(arango_vector_factory: Any) -> None: + """Test creating vector store from texts.""" + texts = 
["text1", "text2"] + mock_embedding = MagicMock() + mock_embedding.embed_documents.return_value = [[0.1] * 64, [0.2] * 64] + + # Configure mock_db for this specific test to simulate no pre-existing index + mock_db_instance = MagicMock(spec=StandardDatabase) + mock_collection_instance = MagicMock() + mock_db_instance.collection.return_value = mock_collection_instance + mock_db_instance.has_collection.return_value = ( + True # Assume collection exists or is created by __init__ + ) + mock_collection_instance.indexes.return_value = [] + + with patch.object(ArangoVector, "add_embeddings", return_value=["id1", "id2"]): + vector_store = ArangoVector.from_texts( + texts=texts, + embedding=mock_embedding, + database=mock_db_instance, # Use the specifically configured mock_db + collection_name="custom_collection", + ) + + # Verify the vector store was initialized correctly + assert vector_store.collection_name == "custom_collection" + assert vector_store.embedding == mock_embedding + assert vector_store.embedding_dimension == 64 + + # Note: create_vector_index is not automatically called in from_texts + # so we don't verify it was called here + + +def test_delete(arango_vector_factory: Any) -> None: + """Test deleting documents from the vector store.""" + vector_store = arango_vector_factory() + + # Test deleting specific IDs + ids = ["id1", "id2"] + vector_store.delete(ids=ids) + + # Verify collection.delete_many was called with correct IDs + vector_store.collection.delete_many.assert_called_once() + # ids are passed as the first positional argument to collection.delete_many + positional_args = vector_store.collection.delete_many.call_args[0] + assert set(positional_args[0]) == set(ids) + + +def test_get_by_ids(arango_vector_factory: Any) -> None: + """Test getting documents by IDs.""" + vector_store = arango_vector_factory() + + # Test case 1: Multiple documents returned + # Mock documents to be returned + mock_docs = [ + {"_key": "id1", "text": "content1", "color": "red", "size": 10}, + {"_key": "id2", "text": "content2", "color": "blue", "size": 20}, + ] + + # Mock collection.get_many to return the mock documents + vector_store.collection.get_many.return_value = mock_docs + + ids = ["id1", "id2"] + docs = vector_store.get_by_ids(ids) + + # Verify collection.get_many was called with correct IDs + vector_store.collection.get_many.assert_called_with(ids) + + # Verify the correct documents were returned + assert len(docs) == 2 + assert docs[0].page_content == "content1" + assert docs[0].id == "id1" + assert docs[0].metadata["color"] == "red" + assert docs[0].metadata["size"] == 10 + assert docs[1].page_content == "content2" + assert docs[1].id == "id2" + assert docs[1].metadata["color"] == "blue" + assert docs[1].metadata["size"] == 20 + + # Test case 2: No documents returned (empty result) + vector_store.collection.get_many.reset_mock() + vector_store.collection.get_many.return_value = [] + + empty_docs = vector_store.get_by_ids(["non_existent_id"]) + + # Verify collection.get_many was called with the non-existent ID + vector_store.collection.get_many.assert_called_with(["non_existent_id"]) + + # Verify an empty list was returned + assert empty_docs == [] + + # Test case 3: Custom text field + vector_store = arango_vector_factory(text_field="custom_text") + + custom_field_docs = [ + {"_key": "id3", "custom_text": "custom content", "tag": "important"}, + ] + + vector_store.collection.get_many.return_value = custom_field_docs + + result_docs = vector_store.get_by_ids(["id3"]) + + # Verify 
+
+
+def test_get_by_ids(arango_vector_factory: Any) -> None:
+    """Test getting documents by ID across several scenarios."""
+    vector_store = arango_vector_factory()
+
+    # Test case 1: Multiple documents returned
+    # Mock documents to be returned
+    mock_docs = [
+        {"_key": "id1", "text": "content1", "color": "red", "size": 10},
+        {"_key": "id2", "text": "content2", "color": "blue", "size": 20},
+    ]
+
+    # Mock collection.get_many to return the mock documents
+    vector_store.collection.get_many.return_value = mock_docs
+
+    ids = ["id1", "id2"]
+    docs = vector_store.get_by_ids(ids)
+
+    # Verify collection.get_many was called with the correct IDs
+    vector_store.collection.get_many.assert_called_with(ids)
+
+    # Verify the correct documents were returned
+    assert len(docs) == 2
+    assert docs[0].page_content == "content1"
+    assert docs[0].id == "id1"
+    assert docs[0].metadata["color"] == "red"
+    assert docs[0].metadata["size"] == 10
+    assert docs[1].page_content == "content2"
+    assert docs[1].id == "id2"
+    assert docs[1].metadata["color"] == "blue"
+    assert docs[1].metadata["size"] == 20
+
+    # Test case 2: No documents returned (empty result)
+    vector_store.collection.get_many.reset_mock()
+    vector_store.collection.get_many.return_value = []
+
+    empty_docs = vector_store.get_by_ids(["non_existent_id"])
+
+    # Verify collection.get_many was called with the non-existent ID
+    vector_store.collection.get_many.assert_called_with(["non_existent_id"])
+
+    # Verify an empty list was returned
+    assert empty_docs == []
+
+    # Test case 3: Custom text field
+    vector_store = arango_vector_factory(text_field="custom_text")
+
+    custom_field_docs = [
+        {"_key": "id3", "custom_text": "custom content", "tag": "important"},
+    ]
+
+    vector_store.collection.get_many.return_value = custom_field_docs
+
+    result_docs = vector_store.get_by_ids(["id3"])
+
+    # Verify collection.get_many was called
+    vector_store.collection.get_many.assert_called_with(["id3"])
+
+    # Verify the document was correctly processed with the custom text field
+    assert len(result_docs) == 1
+    assert result_docs[0].page_content == "custom content"
+    assert result_docs[0].id == "id3"
+    assert result_docs[0].metadata["tag"] == "important"
+
+    # Test case 4: Document is missing the text field
+    vector_store = arango_vector_factory()
+
+    # Document without the text field
+    incomplete_docs = [
+        {"_key": "id4", "other_field": "some value"},
+    ]
+
+    vector_store.collection.get_many.return_value = incomplete_docs
+
+    # This should raise a KeyError when accessing the missing text field
+    with pytest.raises(KeyError):
+        vector_store.get_by_ids(["id4"])
+
+
+def test_select_relevance_score_fn_override(arango_vector_factory: Any) -> None:
+    """Test that the override relevance score function is used if provided."""
+
+    def custom_score_fn(score: float) -> float:
+        return score * 10.0
+
+    vector_store = arango_vector_factory(relevance_score_fn=custom_score_fn)
+    selected_fn = vector_store._select_relevance_score_fn()
+    assert selected_fn(0.5) == 5.0
+    assert selected_fn == custom_score_fn
+
+
+def test_select_relevance_score_fn_default_strategies(
+    arango_vector_factory: Any,
+) -> None:
+    """Test the default relevance score function for supported strategies."""
+    # Test for COSINE
+    vector_store_cosine = arango_vector_factory(
+        distance_strategy=DistanceStrategy.COSINE
+    )
+    fn_cosine = vector_store_cosine._select_relevance_score_fn()
+    assert fn_cosine(0.75) == 0.75
+
+    # Test for EUCLIDEAN_DISTANCE
+    vector_store_euclidean = arango_vector_factory(
+        distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE
+    )
+    fn_euclidean = vector_store_euclidean._select_relevance_score_fn()
+    assert fn_euclidean(1.25) == 1.25
+
+
+def test_select_relevance_score_fn_invalid_strategy_raises_error(
+    arango_vector_factory: Any,
+) -> None:
+    """Test that a ValueError is raised if _distance_strategy is mutated to an
+    unsupported value post-init."""
+    vector_store = arango_vector_factory()
+    vector_store._distance_strategy = "INVALID_STRATEGY"
+
+    with pytest.raises(ValueError) as exc_info:
+        vector_store._select_relevance_score_fn()
+
+    # The two sentences are concatenated without a space, matching the message
+    # string raised by the implementation verbatim
+    expected_message = (
+        "No supported normalization function for distance_strategy of INVALID_STRATEGY."
+        "Consider providing relevance_score_fn to ArangoVector constructor."
+    )
+    assert str(exc_info.value) == expected_message
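+
+
+# Editorial sketch (an assumption distilled from the three tests above, not the
+# shipped code): the dispatch they pin down — an explicit override wins, COSINE
+# and EUCLIDEAN_DISTANCE fall back to the identity function, and anything else
+# raises a ValueError with the exact message asserted above.
+def _sketch_select_relevance_score_fn(override: Any, strategy: Any) -> Any:
+    """Illustrative only: mirrors the behaviour the assertions require."""
+    if override is not None:
+        return override
+    if strategy in (DistanceStrategy.COSINE, DistanceStrategy.EUCLIDEAN_DISTANCE):
+        return lambda score: score
+    raise ValueError(
+        f"No supported normalization function for distance_strategy of {strategy}."
+        "Consider providing relevance_score_fn to ArangoVector constructor."
+    )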