diff --git a/.env.example b/.env.example
index 90bec0aa..4b829165 100644
--- a/.env.example
+++ b/.env.example
@@ -16,6 +16,11 @@ LOCAL_DB_PATH=local.sqlite
 ## QDRANT_PORT=your_qdrant_port
 ## QDRANT_API_KEY=your_qdrant_api_key
 
+# ## lancedb
+## LANCEDB_URI=your_lancedb_uri_local_or_cloud
+## LANCEDB_API_KEY=your_lancedb_cloud_api_key
+## LANCEDB_REGION=your_lancedb_cloud_region
+
 # LLM Providers
 
 ## openai
diff --git a/pyproject.toml b/pyproject.toml
index deb727e1..febf46c7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ boto3 = {version = "^1.34.71", optional = true}
 exa-py = {version = "^1.0.9", optional = true}
 llama-cpp-python = {version = "^0.2.57", optional = true}
 sentence-transformers = {version = "^2.6.1", optional = true}
+lancedb = {version = "^0.6.5", optional = true}
 
 [tool.poetry.extras]
 parsing = ["bs4", "pypdf"]
@@ -60,7 +61,8 @@ ionic = ["ionic-api-sdk"]
 reducto = ["boto3"]
 exa = ["exa-py"]
 local_llm = ["llama-cpp-python", "sentence-transformers"]
-all = ["bs4", "pypdf", "tiktoken", "datasets", "qdrant_client", "psycopg2-binary", "sentry-sdk", "parea-ai", "boto3", "exa-py", "llama-cpp-python"]
+lancedb = ["lancedb"]
+all = ["bs4", "pypdf", "tiktoken", "datasets", "qdrant_client", "psycopg2-binary", "sentry-sdk", "parea-ai", "boto3", "exa-py", "llama-cpp-python", "lancedb"]
 
 [tool.poetry.group.dev.dependencies]
 black = "^24.3.0"
diff --git a/r2r/core/providers/vector_db.py b/r2r/core/providers/vector_db.py
index 59f69e16..4fa7a310 100644
--- a/r2r/core/providers/vector_db.py
+++ b/r2r/core/providers/vector_db.py
@@ -64,7 +64,7 @@ def to_dict(self) -> dict:
 
 
 class VectorDBProvider(ABC):
-    supported_providers = ["local", "pgvector", "qdrant"]
+    supported_providers = ["local", "pgvector", "qdrant", "lancedb"]
 
     def __init__(self, provider: str):
         if provider not in VectorDBProvider.supported_providers:
diff --git a/r2r/main/factory.py b/r2r/main/factory.py
index 5bb11b0f..9e07a50e 100644
--- a/r2r/main/factory.py
+++ b/r2r/main/factory.py
@@ -41,6 +41,10 @@ def get_vector_db(database_config: dict[str, Any]):
             from r2r.vector_dbs import LocalVectorDB
 
             return LocalVectorDB()
+        elif database_config["provider"] == "lancedb":
+            from r2r.vector_dbs import LanceDB
+
+            return LanceDB()
 
     @staticmethod
     def get_embeddings_provider(embedding_config: dict[str, Any]):
diff --git a/r2r/vector_dbs/__init__.py b/r2r/vector_dbs/__init__.py
index 419a0431..6f7ad4ac 100644
--- a/r2r/vector_dbs/__init__.py
+++ b/r2r/vector_dbs/__init__.py
@@ -1,5 +1,6 @@
 from .local.base import LocalVectorDB
 from .pg_vector.base import PGVectorDB
 from .qdrant.base import QdrantDB
+from .lancedb.base import LanceDB
 
-__all__ = ["LocalVectorDB", "PGVectorDB", "QdrantDB"]
+__all__ = ["LocalVectorDB", "PGVectorDB", "QdrantDB", "LanceDB"]
diff --git a/r2r/vector_dbs/lancedb/base.py b/r2r/vector_dbs/lancedb/base.py
new file mode 100644
index 00000000..8585d1a5
--- /dev/null
+++ b/r2r/vector_dbs/lancedb/base.py
@@ -0,0 +1,165 @@
+import logging
+import os
+from typing import Optional, Union
+
+from r2r.core import VectorDBProvider, VectorEntry, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class LanceDB(VectorDBProvider):
+    def __init__(
+        self, provider: str = "lancedb", db_path: Optional[str] = None
+    ) -> None:
+        logger.info("Initializing `LanceDB` to store and retrieve embeddings.")
+
+        super().__init__(provider)
+
+        if provider != "lancedb":
+            raise ValueError(
+                "LanceDB must be initialized with provider `lancedb`."
+            )
+
+        try:
+            import lancedb
+        except ImportError:
+            raise ValueError(
+                "Error, `lancedb` is not installed. Please install it using `pip install lancedb`."
+            )
+
+        self.db_path = db_path
+        try:
+            self.client = lancedb.connect(
+                uri=self.db_path or os.environ.get("LANCEDB_URI"),
+                api_key=os.environ.get("LANCEDB_API_KEY") or None,
+                region=os.environ.get("LANCEDB_REGION") or None,
+            )
+        except Exception as e:
+            raise ValueError(
+                f"Error {e} occurred while attempting to connect to the lancedb provider."
+            )
+        self.collection_name: Optional[str] = None
+
+    def initialize_collection(
+        self, collection_name: str, dimension: int
+    ) -> None:
+        self.collection_name = collection_name
+
+        try:
+            import pyarrow
+        except ImportError:
+            raise ValueError(
+                "Error, `pyarrow` is not installed. Please install it using `pip install pyarrow`."
+            )
+
+        table_schema = pyarrow.schema(
+            [
+                pyarrow.field("id", pyarrow.string()),
+                pyarrow.field(
+                    "vector", pyarrow.list_(pyarrow.float32(), dimension)
+                ),
+                # TODO Handle storing metadata
+            ]
+        )
+
+        try:
+            self.client.create_table(
+                name=collection_name,
+                on_bad_vectors="error",
+                schema=table_schema,
+            )
+        except Exception:
+            # TODO - Handle more appropriately; `create_table` raises when
+            # the table already exists, so that case is ignored for now.
+            pass
+
+    def copy(self, entry: VectorEntry, commit=True) -> None:
+        raise NotImplementedError(
+            "LanceDB does not support the `copy` method."
+        )
+
+    def upsert(self, entry: VectorEntry, commit=True) -> None:
+        if self.collection_name is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `upsert`."
+            )
+        # `Table.add` expects an iterable of records; ids are stored as strings.
+        self.client.open_table(self.collection_name).add(
+            [
+                {
+                    "vector": entry.vector,
+                    "id": str(entry.id),
+                    # TODO ADD metadata storage
+                }
+            ],
+            mode="overwrite",
+        )
+
+    def upsert_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        if self.collection_name is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `upsert_entries`."
+            )
+
+        self.client.open_table(self.collection_name).add(
+            [
+                {
+                    "vector": entry.vector,
+                    "id": str(entry.id),
+                    # TODO ADD metadata storage
+                }
+                for entry in entries
+            ],
+            mode="overwrite",
+        )
+
+    def search(
+        self,
+        query_vector: list[float],
+        filters: dict[str, Union[bool, int, str]] = {},
+        limit: int = 10,
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        if self.collection_name is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `search`."
+            )
+
+        results = (
+            self.client.open_table(self.collection_name)
+            .search(query=query_vector)
+            # TODO implement metadata filter
+            .limit(limit)
+            .to_list()
+        )
+
+        # Each row dict carries the stored `id` column and the computed
+        # `_distance` for the query.
+        return [
+            VectorSearchResult(
+                result["id"], result["_distance"], {}  # TODO Handle metadata
+            )
+            for result in results
+        ]
+
+    def create_index(self, index_type, column_name, index_options):
+        # TODO - Implement index creation.
+        pass
+
+    def close(self):
+        pass
+
+    def filtered_deletion(
+        self, key: str, value: Union[bool, int, str]
+    ) -> None:
+        # TODO - Implement metadata-filtered deletion.
+        pass
+
+    def get_all_unique_values(
+        self, collection_name: str, metadata_field: str, filters: dict = {}
+    ) -> list:
+        # TODO - Implement once metadata storage is handled.
+        pass
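A quick way to exercise the new wiring end to end (the env-based connection, the `lancedb` provider string, and the table schema created above). This is a sketch for review, not part of the change: the `LANCEDB_URI` value and table name are placeholders, and the `VectorEntry(id, vector, metadata)` constructor shape and the `id`/`score` attributes on `VectorSearchResult` are assumed from how the new code constructs them.

```python
import os
import uuid

from r2r.core import VectorEntry
from r2r.vector_dbs import LanceDB

# Local database directory; for LanceDB Cloud, LANCEDB_API_KEY and
# LANCEDB_REGION would be set as well (see .env.example above).
os.environ["LANCEDB_URI"] = "/tmp/lancedb_demo"

db = LanceDB()  # same instance the factory returns for provider "lancedb"
db.initialize_collection("demo_vectors", dimension=3)

# Assumed VectorEntry shape: (id, vector, metadata).
db.upsert_entries(
    [
        VectorEntry(uuid.uuid4(), [0.1, 0.2, 0.3], {}),
        VectorEntry(uuid.uuid4(), [0.9, 0.8, 0.7], {}),
    ]
)

# Nearest neighbour to a probe vector; metadata filters are still a TODO.
for hit in db.search(query_vector=[0.1, 0.2, 0.25], limit=1):
    print(hit.id, hit.score)  # attribute names assumed from the constructor
```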
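One design note for reviewers: both upsert paths call `Table.add(..., mode="overwrite")`, which rewrites the whole table rather than merging individual rows, and metadata storage and filtering are still TODOs, so every search hit currently comes back with empty metadata.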