Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
56464b8
tests | initial commit
aMahanna Apr 28, 2025
822fcda
fix: dir
aMahanna Apr 28, 2025
a835d8b
fix: lint
aMahanna Apr 28, 2025
e3d3c42
bring back `test_compile.py`
aMahanna Apr 28, 2025
f4cc9e8
update: compose.yml
aMahanna Apr 30, 2025
f920f7b
fix: ArangoGraphQAChain
aMahanna Apr 30, 2025
82896fb
new: `test_aql_generating_run`
aMahanna Apr 30, 2025
7a61902
fix: lint
aMahanna Apr 30, 2025
9e0eff7
type: ignore
aMahanna Apr 30, 2025
b8d16b8
update: tests
aMahanna May 5, 2025
f7fa9d9
integration_tests_chat_history all passing
ajaykallepalli May 7, 2025
463ea08
Adding chat history unit tests
ajaykallepalli May 7, 2025
237a235
new: raise_on_write_operation
aMahanna May 12, 2025
3c3f6e6
rename: force_read_only_query
aMahanna May 12, 2025
3246b0c
fix: `AQL_WRITE_OPERATIONS`
aMahanna May 12, 2025
b859232
fix: lint
aMahanna May 12, 2025
c76b3c6
new: `from_existing_collection`
aMahanna May 12, 2025
16303f2
cleanup
aMahanna May 12, 2025
4560c9c
All 18 tests pass
ajaykallepalli May 14, 2025
11a08fe
minimal changes to arangodb_vector.py
ajaykallepalli May 14, 2025
b95bb04
No changes to arangodb_vector
ajaykallepalli May 14, 2025
fa2f1f2
fix: docstring
aMahanna May 14, 2025
b361cd2
new: coverage flags
aMahanna May 14, 2025
5679003
Merge branch 'tests' into chat_vector_tests
ajaykallepalli May 14, 2025
895a97a
All integration test and unit test passing, coverage 73% and 66%
ajaykallepalli May 19, 2025
4025fb7
Adding unit tests and integration tests for get by id
ajaykallepalli May 19, 2025
581808f
Testing from existing collection, all major coverage complete
ajaykallepalli May 19, 2025
ccad356
Fixing linting and formatting errors
ajaykallepalli May 19, 2025
9be4539
temp
aMahanna May 19, 2025
4494ee0
attempt: remove chmod
aMahanna May 19, 2025
9344bf6
Update README.md
aMahanna May 19, 2025
9c35b8f
No lint errors
ajaykallepalli May 21, 2025
5034e4a
No lint errors, all tests pass
ajaykallepalli May 21, 2025
cde5615
Merge branch 'tests' into chat_vector_tests
ajaykallepalli May 21, 2025
bbbcecc
Updating assert statements to match latest ruff requirements
ajaykallepalli May 21, 2025
8ceac2d
Updating assert statements to match latest ruff requirements python 12
ajaykallepalli May 21, 2025
9e0031a
make format py312
ajaykallepalli May 21, 2025
5906fbf
Updating assert statements
ajaykallepalli May 28, 2025
65aace7
Updating assert statements
ajaykallepalli May 28, 2025
24a28ac
ruff format locally failing CI/CD
ajaykallepalli May 28, 2025
99c3a48
new: hybrid search
aMahanna May 30, 2025
dd70425
fix: docstrings
aMahanna May 30, 2025
61950e2
fix: lint
aMahanna May 30, 2025
b7b53f2
Merge branch 'tests' into chat_vector_tests
aMahanna May 30, 2025
70e6cfd
fix: ci
aMahanna May 30, 2025
d00bd64
Merge branch 'main' into chat_vector_tests
aMahanna May 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified libs/arangodb/.coverage
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import os

import pytest
from arango import ArangoClient
from arango.database import StandardDatabase
from arango.exceptions import ArangoError
from langchain_core.messages import AIMessage, HumanMessage

from langchain_arangodb.chat_message_histories.arangodb import ArangoChatMessageHistory
from langchain_arangodb.graphs.arangodb_graph import ArangoGraph
from tests.integration_tests.utils import ArangoCredentials


@pytest.mark.usefixtures("clear_arangodb_database")
Expand Down Expand Up @@ -44,3 +50,108 @@ def test_add_messages(db: StandardDatabase) -> None:
message_store_another.clear()
assert len(message_store.messages) == 0
assert len(message_store_another.messages) == 0


@pytest.mark.usefixtures("clear_arangodb_database")
def test_add_messages_graph_object(arangodb_credentials: ArangoCredentials) -> None:
"""Basic testing: Passing driver through graph object."""
graph = ArangoGraph.from_db_credentials(
url=arangodb_credentials["url"],
username=arangodb_credentials["username"],
password=arangodb_credentials["password"],
)

# rewrite env for testing
old_username = os.environ.get("ARANGO_USERNAME", "root")
os.environ["ARANGO_USERNAME"] = "foo"

message_store = ArangoChatMessageHistory("23334", db=graph.db)
message_store.clear()
assert len(message_store.messages) == 0
message_store.add_user_message("Hello! Language Chain!")
message_store.add_ai_message("Hi Guys!")
# Now check if the messages are stored in the database correctly
assert len(message_store.messages) == 2

# Restore original environment
os.environ["ARANGO_USERNAME"] = old_username


def test_invalid_credentials(arangodb_credentials: ArangoCredentials) -> None:
"""Test initializing with invalid credentials raises an authentication error."""
with pytest.raises(ArangoError) as exc_info:
client = ArangoClient(arangodb_credentials["url"])
db = client.db(username="invalid_username", password="invalid_password")
# Try to perform a database operation to trigger an authentication error
db.collections()

# Check for any authentication-related error message
error_msg = str(exc_info.value)
# Just check for "error" which should be in any auth error
assert "not authorized" in error_msg


@pytest.mark.usefixtures("clear_arangodb_database")
def test_arangodb_message_history_clear_messages(
db: StandardDatabase,
) -> None:
"""Test adding multiple messages at once to ArangoChatMessageHistory."""
# Specify a custom collection name that includes the session_id
collection_name = "chat_history_123"
message_history = ArangoChatMessageHistory(
session_id="123", db=db, collection_name=collection_name
)
message_history.add_messages(
[
HumanMessage(content="You are a helpful assistant."),
AIMessage(content="Hello"),
]
)
assert len(message_history.messages) == 2
assert isinstance(message_history.messages[0], HumanMessage)
assert isinstance(message_history.messages[1], AIMessage)
assert message_history.messages[0].content == "You are a helpful assistant."
assert message_history.messages[1].content == "Hello"

message_history.clear()
assert len(message_history.messages) == 0

# Verify all messages are removed but collection still exists
assert db.has_collection(message_history._collection_name)
assert message_history._collection_name == collection_name


@pytest.mark.usefixtures("clear_arangodb_database")
def test_arangodb_message_history_clear_session_collection(
db: StandardDatabase,
) -> None:
"""Test clearing messages and removing the collection for a session."""
# Create a test collection specific to the session
session_id = "456"
collection_name = f"chat_history_{session_id}"

if not db.has_collection(collection_name):
db.create_collection(collection_name)

message_history = ArangoChatMessageHistory(
session_id=session_id, db=db, collection_name=collection_name
)

message_history.add_messages(
[
HumanMessage(content="You are a helpful assistant."),
AIMessage(content="Hello"),
]
)
assert len(message_history.messages) == 2

# Clear messages
message_history.clear()
assert len(message_history.messages) == 0

# The collection should still exist after clearing messages
assert db.has_collection(collection_name)

# Delete the collection (equivalent to delete_session_node in Neo4j)
db.delete_collection(collection_name)
assert not db.has_collection(collection_name)
111 changes: 111 additions & 0 deletions libs/arangodb/tests/integration_tests/vectorstores/fake_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Fake Embedding class for testing purposes."""

import math
from typing import List

from langchain_core.embeddings import Embeddings

fake_texts = ["foo", "bar", "baz"]


class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""

def __init__(self, dimension: int = 10):
if dimension < 1:
raise ValueError(
"Dimension must be at least 1 for this FakeEmbeddings style."
)
self.dimension = dimension
# global_fake_texts maps query texts to the 'i' in [1.0]*(dim-1) + [float(i)]
self.global_fake_texts = ["foo", "bar", "baz", "qux", "quux", "corge", "grault"]

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return simple embeddings.
Embeddings encode each text as its index."""
if self.dimension == 1:
# Special case for dimension 1: just use the index
return [[float(i)] for i in range(len(texts))]
else:
return [
[1.0] * (self.dimension - 1) + [float(i)] for i in range(len(texts))
]

async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return self.embed_documents(texts)

def embed_query(self, text: str) -> List[float]:
"""Return constant query embeddings.
Embeddings are identical to embed_documents(texts)[0].
Distance to each text will be that text's index,
as it was passed to embed_documents."""
try:
idx = self.global_fake_texts.index(text)
val = float(idx)
except ValueError:
# Text not in global_fake_texts, use a default 'unknown query' value
val = -1.0

if self.dimension == 1:
return [val] # Corrected: List[float]
else:
return [1.0] * (self.dimension - 1) + [val]

async def aembed_query(self, text: str) -> List[float]:
return self.embed_query(text)

@property
def identifer(self) -> str:
return "fake"


class ConsistentFakeEmbeddings(FakeEmbeddings):
"""Fake embeddings which remember all the texts seen so far to return consistent
vectors for the same texts."""

def __init__(self, dimensionality: int = 10) -> None:
self.known_texts: List[str] = []
self.dimensionality = dimensionality

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return consistent embeddings for each text seen so far."""
out_vectors = []
for text in texts:
if text not in self.known_texts:
self.known_texts.append(text)
vector = [float(1.0)] * (self.dimensionality - 1) + [
float(self.known_texts.index(text))
]
out_vectors.append(vector)
return out_vectors

def embed_query(self, text: str) -> List[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
return self.embed_documents([text])[0]


class AngularTwoDimensionalEmbeddings(Embeddings):
"""
From angles (as strings in units of pi) to unit embedding vectors on a circle.
"""

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
Make a list of texts into a list of embedding vectors.
"""
return [self.embed_query(text) for text in texts]

def embed_query(self, text: str) -> List[float]:
"""
Convert input text to a 'vector' (list of floats).
If the text is a number, use it as the angle for the
unit vector in units of pi.
Any other input text becomes the singular result [0, 0] !
"""
try:
angle = float(text)
return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
except ValueError:
# Assume: just test string, no attention is paid to values.
return [0.0, 0.0]
Loading