From 1084f703b98b07af495b2ba2b4587b98b1f74c6b Mon Sep 17 00:00:00 2001 From: Anush008 Date: Thu, 28 Nov 2024 14:56:17 +0530 Subject: [PATCH 01/10] feat: Qdrant support Signed-off-by: Anush008 --- README.md | 2 +- paperqa/__init__.py | 2 + paperqa/llms.py | 89 +++++++++++++++++++++++ pyproject.toml | 3 + tests/test_paperqa.py | 19 +++-- uv.lock | 161 +++++++++++++++++++++++++++++++++++++++++- 6 files changed, 268 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 016ac04ae..6b74a013d 100644 --- a/README.md +++ b/README.md @@ -519,7 +519,7 @@ for doc in ("myfile.pdf", "myotherfile.pdf"): Note that PaperQA2 uses Numpy as a dense vector store. Its design of using a keyword search initially reduces the number of chunks needed for each answer to a relatively small number < 1k. Therefore, `NumpyVectorStore` is a good place to start, it's a simple in-memory store, without an index. -However, if a larger-than-memory vector store is needed, we are currently lacking here. +However, if a larger-than-memory vector store is needed, you can an external vector database like [Qdrant](https://qdrant.tech/) via the `QdrantVectorStore` class. The hybrid embeddings can be customized: diff --git a/paperqa/__init__.py b/paperqa/__init__.py index 008b18255..6267de5fc 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -18,6 +18,7 @@ LLMModel, LLMResult, NumpyVectorStore, + QdrantVectorStore, SentenceTransformerEmbeddingModel, SparseEmbeddingModel, embedding_model_factory, @@ -40,6 +41,7 @@ "LiteLLMModel", "NumpyVectorStore", "PQASession", + "QdrantVectorStore", "QueryRequest", "SentenceTransformerEmbeddingModel", "Settings", diff --git a/paperqa/llms.py b/paperqa/llms.py index b8b62805e..361a3cbe4 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -3,6 +3,7 @@ import asyncio import contextlib import functools +import uuid from abc import ABC, abstractmethod from collections.abc import ( AsyncGenerator, @@ -909,6 +910,94 @@ async def similarity_search( ) +class QdrantVectorStore(VectorStore): + client: Any = Field( + default=None, description="Instance of qdrant_client.QdrantClient" + ) + collection_name: str = Field(default_factory=lambda: f"paper-qa-{uuid.uuid4().hex}") + vector_name: str | None = Field(default=None) + + @model_validator(mode="after") + def validate_client(self): + from qdrant_client import QdrantClient + + if self.client and not isinstance(self.client, QdrantClient): + raise TypeError( + f"'client' should be an instance of qdrant_client.QdrantClient. Got {type(self.client)}" + ) + + if not self.client: + self.client = QdrantClient() + + return self + + def clear(self) -> None: + super().clear() + from qdrant_client import models + + if not self.client.collection_exists(self.collection_name): + return + + self.client.delete( + collection_name=self.collection_name, + points_selector=models.Filter(must=[]), + wait=True, + ) + + def add_texts_and_embeddings(self, texts: Iterable[Embeddable]) -> None: + super().add_texts_and_embeddings(texts) + from qdrant_client import models + + texts_list = list(texts) + + if len(texts_list) > 0 and not self.client.collection_exists( + self.collection_name + ): + params = models.VectorParams( + size=len(texts_list[0].embedding), distance=models.Distance.COSINE # type: ignore[arg-type] + ) + self.client.create_collection( + self.collection_name, + vectors_config=( + {self.vector_name: params} if self.vector_name else params + ), + ) + + vectors = [ + {self.vector_name: t.embedding} if self.vector_name else t.embedding + for t in texts_list + ] + self.client.upload_collection( + collection_name=self.collection_name, vectors=vectors, wait=True + ) + + async def similarity_search( + self, query: str, k: int, embedding_model: EmbeddingModel + ) -> tuple[Sequence[Embeddable], list[float]]: + if not self.client.collection_exists(self.collection_name): + return ([], []) + + embedding_model.set_mode(EmbeddingModes.QUERY) + np_query = np.array((await embedding_model.embed_documents([query]))[0]) + embedding_model.set_mode(EmbeddingModes.DOCUMENT) + + points = self.client.query_points( + collection_name=self.collection_name, + query=np_query, + using=self.vector_name, + limit=k, + with_vectors=True, + ).points + + return ( + [ + p.vector[self.vector_name] if self.vector_name else p.vector + for p in points + ], + [p.score for p in points], + ) + + def embedding_model_factory(embedding: str, **kwargs) -> EmbeddingModel: """ Factory function to create an appropriate EmbeddingModel based on the embedding string. diff --git a/pyproject.toml b/pyproject.toml index c047a2199..f6a4d8846 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,9 @@ ldp = [ local = [ "sentence-transformers", ] +qdrant = [ + "qdrant-client", +] typing = [ "pandas-stubs", "types-PyYAML", diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index 072ba466d..0c6fc4a86 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -20,6 +20,7 @@ Docs, NumpyVectorStore, PQASession, + QdrantVectorStore, Settings, Text, print_callback, @@ -33,6 +34,7 @@ LiteLLMEmbeddingModel, LLMModel, SparseEmbeddingModel, + VectorStore, ) from paperqa.prompts import CANNOT_ANSWER_PHRASE from paperqa.prompts import qa_prompt as default_qa_prompt @@ -618,14 +620,17 @@ def test_duplicate(stub_data_dir: Path) -> None: ), "Unique documents should be hashed as unique" -def test_docs_with_custom_embedding(subtests: SubTests, stub_data_dir: Path) -> None: +@pytest.mark.parametrize("vector_store", [NumpyVectorStore, QdrantVectorStore]) +def test_docs_with_custom_embedding( + subtests: SubTests, stub_data_dir: Path, vector_store: type[VectorStore] +) -> None: class MyEmbeds(EmbeddingModel): name: str = "my_embed" async def embed_documents(self, texts): return [[1, 2, 3] for _ in texts] - docs = Docs(texts_index=NumpyVectorStore()) + docs = Docs(texts_index=vector_store()) docs.add( stub_data_dir / "bates.txt", citation="WikiMedia Foundation, 2023, Accessed now", @@ -662,8 +667,9 @@ async def embed_documents(self, texts): assert docs.texts_index == docs_deep_copy.texts_index -def test_sparse_embedding(stub_data_dir: Path) -> None: - docs = Docs(texts_index=NumpyVectorStore()) +@pytest.mark.parametrize("vector_store", [NumpyVectorStore, QdrantVectorStore]) +def test_sparse_embedding(stub_data_dir: Path, vector_store: type[VectorStore]) -> None: + docs = Docs(texts_index=vector_store()) docs.add( stub_data_dir / "bates.txt", citation="WikiMedia Foundation, 2023, Accessed now", @@ -680,11 +686,12 @@ def test_sparse_embedding(stub_data_dir: Path) -> None: assert np.shape(docs.texts[0].embedding) == np.shape(docs.texts[1].embedding) -def test_hybrid_embedding(stub_data_dir: Path) -> None: +@pytest.mark.parametrize("vector_store", [NumpyVectorStore, QdrantVectorStore]) +def test_hybrid_embedding(stub_data_dir: Path, vector_store: type[VectorStore]) -> None: emb_model = HybridEmbeddingModel( models=[LiteLLMEmbeddingModel(), SparseEmbeddingModel()] ) - docs = Docs(texts_index=NumpyVectorStore()) + docs = Docs(texts_index=vector_store()) docs.add( stub_data_dir / "bates.txt", citation="WikiMedia Foundation, 2023, Accessed now", diff --git a/uv.lock b/uv.lock index 9dbd0158d..d47657df7 100644 --- a/uv.lock +++ b/uv.lock @@ -615,6 +615,81 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "grpcio" +version = "1.68.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d5/da/132615afbfc722df4bba963844843a205aa298fd5f9a03fa2995e8dddf11/grpcio-1.68.0.tar.gz", hash = "sha256:7e7483d39b4a4fddb9906671e9ea21aaad4f031cdfc349fec76bdfa1e404543a", size = 12682655 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/5f/019594ff8130ce84f9317cfc1e3d2c2beef2b74fd8822c5f1dfe237cb0d5/grpcio-1.68.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:3b2b559beb2d433129441783e5f42e3be40a9e1a89ec906efabf26591c5cd415", size = 5180685 }, + { url = "https://files.pythonhosted.org/packages/7b/59/34dae935bbb42f3e8929c90e9dfff49090cef412cf767cf4f14cd01ded18/grpcio-1.68.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e46541de8425a4d6829ac6c5d9b16c03c292105fe9ebf78cb1c31e8d242f9155", size = 11150577 }, + { url = "https://files.pythonhosted.org/packages/a6/5e/3df718124aadfc5d565c70ebe6a32c9ee747a9ccf211041596dd471fd763/grpcio-1.68.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:c1245651f3c9ea92a2db4f95d37b7597db6b246d5892bca6ee8c0e90d76fb73c", size = 5685490 }, + { url = "https://files.pythonhosted.org/packages/4c/57/4e39ac1030875e0497debc9d5a4b3a1478ee1bd957ba4b87c27fcd7a3545/grpcio-1.68.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4f1931c7aa85be0fa6cea6af388e576f3bf6baee9e5d481c586980c774debcb4", size = 6316329 }, + { url = "https://files.pythonhosted.org/packages/26/fe/9208707b0c07d28bb9f466340e4f052142fe40d54ea5c2d57870ba0d6860/grpcio-1.68.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b0ff09c81e3aded7a183bc6473639b46b6caa9c1901d6f5e2cba24b95e59e30", size = 5939890 }, + { url = "https://files.pythonhosted.org/packages/05/b9/e344bf744e095e2795fe942ce432add2d03761c3c440a5747705ff5b8efb/grpcio-1.68.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8c73f9fbbaee1a132487e31585aa83987ddf626426d703ebcb9a528cf231c9b1", size = 6644776 }, + { url = "https://files.pythonhosted.org/packages/ef/bf/0856c5fa93c3e1bd9f42da62a7aa6988c7a8f95f30dc4f9a3d631f75bb8e/grpcio-1.68.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6b2f98165ea2790ea159393a2246b56f580d24d7da0d0342c18a085299c40a75", size = 6211889 }, + { url = "https://files.pythonhosted.org/packages/63/40/eac5203baf7f45c56b16645c81a4c8ed515510fe81322371e8625758239b/grpcio-1.68.0-cp311-cp311-win32.whl", hash = "sha256:e1e7ed311afb351ff0d0e583a66fcb39675be112d61e7cfd6c8269884a98afbc", size = 3650597 }, + { url = "https://files.pythonhosted.org/packages/e4/31/120ec7132e6b82a0df91952f71aa0aa5e9f23d70152b58d96fac9b3e7cfe/grpcio-1.68.0-cp311-cp311-win_amd64.whl", hash = "sha256:e0d2f68eaa0a755edd9a47d40e50dba6df2bceda66960dee1218da81a2834d27", size = 4400445 }, + { url = "https://files.pythonhosted.org/packages/30/66/79508e13feee4182e6f2ea260ad4eea96b8b396bbf81334660142a6eecab/grpcio-1.68.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:8af6137cc4ae8e421690d276e7627cfc726d4293f6607acf9ea7260bd8fc3d7d", size = 5147575 }, + { url = "https://files.pythonhosted.org/packages/41/8d/19ffe12a736f57e9860bad506c0e711dd3c9c7c9f06030cfd87fa3eb6b45/grpcio-1.68.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4028b8e9a3bff6f377698587d642e24bd221810c06579a18420a17688e421af7", size = 11126767 }, + { url = "https://files.pythonhosted.org/packages/9c/c6/9aa8178d0fa3c893531a3ef38fa65a0e9997047ded9a8a20e3aa5706f923/grpcio-1.68.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:f60fa2adf281fd73ae3a50677572521edca34ba373a45b457b5ebe87c2d01e1d", size = 5644649 }, + { url = "https://files.pythonhosted.org/packages/36/91/e2c451a103b8b595d3e3725fc78c76242d38a96cfe22dd9a47c31faba99d/grpcio-1.68.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e18589e747c1e70b60fab6767ff99b2d0c359ea1db8a2cb524477f93cdbedf5b", size = 6292623 }, + { url = "https://files.pythonhosted.org/packages/0b/5f/cbb2c0dfb3f7b893b30d6daca0a7829067f302c55f20b9c470111f48e6e3/grpcio-1.68.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0d30f3fee9372796f54d3100b31ee70972eaadcc87314be369360248a3dcffe", size = 5905873 }, + { url = "https://files.pythonhosted.org/packages/9d/37/ddc32a46baccac6a0a3cdcabd6908d23dfa526f061a1b81211fe029489c7/grpcio-1.68.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7e0a3e72c0e9a1acab77bef14a73a416630b7fd2cbd893c0a873edc47c42c8cd", size = 6630863 }, + { url = "https://files.pythonhosted.org/packages/45/69/4f74f67ae33be4422bd20050e09ad8b5318f8827a7eb153507de8fb78aef/grpcio-1.68.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a831dcc343440969aaa812004685ed322cdb526cd197112d0db303b0da1e8659", size = 6200368 }, + { url = "https://files.pythonhosted.org/packages/91/e9/25e51915cd972e8c66daf29644e653135f967d7411eccd2651fa347a6337/grpcio-1.68.0-cp312-cp312-win32.whl", hash = "sha256:5a180328e92b9a0050958ced34dddcb86fec5a8b332f5a229e353dafc16cd332", size = 3637786 }, + { url = "https://files.pythonhosted.org/packages/e2/1d/b1250907a727f08de6508d752f367e4b46d113d4eac9eb919ebd9da6a5d6/grpcio-1.68.0-cp312-cp312-win_amd64.whl", hash = "sha256:2bddd04a790b69f7a7385f6a112f46ea0b34c4746f361ebafe9ca0be567c78e9", size = 4390622 }, + { url = "https://files.pythonhosted.org/packages/fb/2d/d9cbdb75dc99141705f08474e97b181034c2e53a345d94b58e3c55f4dd92/grpcio-1.68.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:fc05759ffbd7875e0ff2bd877be1438dfe97c9312bbc558c8284a9afa1d0f40e", size = 5149697 }, + { url = "https://files.pythonhosted.org/packages/6f/37/a848871a5adba8cd571fa89e8aabc40ca0c475bd78b2e645e1649b20e095/grpcio-1.68.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:15fa1fe25d365a13bc6d52fcac0e3ee1f9baebdde2c9b3b2425f8a4979fccea1", size = 11084394 }, + { url = "https://files.pythonhosted.org/packages/1f/52/b09374aab9c9c2f66627ce7de39eef41d73670aa0f75286d91dcc22a2dd8/grpcio-1.68.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:32a9cb4686eb2e89d97022ecb9e1606d132f85c444354c17a7dbde4a455e4a3b", size = 5645417 }, + { url = "https://files.pythonhosted.org/packages/01/78/ec5ad7c44d7adaf0b932fd41ce8c59a95177a8c79c947c77204600b652db/grpcio-1.68.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dba037ff8d284c8e7ea9a510c8ae0f5b016004f13c3648f72411c464b67ff2fb", size = 6291062 }, + { url = "https://files.pythonhosted.org/packages/f7/7f/7f5a1a8dc63a42b78ca930d195eb0c97aa7a09e8553bb3a07b7cf37f6bc1/grpcio-1.68.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0efbbd849867e0e569af09e165363ade75cf84f5229b2698d53cf22c7a4f9e21", size = 5906505 }, + { url = "https://files.pythonhosted.org/packages/41/7b/0b048b8ad1a09fab5f4567fba2a569fb9106c4c1bb473c009c25659542cb/grpcio-1.68.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:4e300e6978df0b65cc2d100c54e097c10dfc7018b9bd890bbbf08022d47f766d", size = 6635069 }, + { url = "https://files.pythonhosted.org/packages/5e/c5/9f0ebc9cfba8309a15a9786c953ce99eaf4e1ca2df402b3c5ecf42493bd4/grpcio-1.68.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:6f9c7ad1a23e1047f827385f4713b5b8c6c7d325705be1dd3e31fb00dcb2f665", size = 6200683 }, + { url = "https://files.pythonhosted.org/packages/ce/e1/d3eba05299d5acdae6c11d056308b885f1d1be0b328baa8233d5d139ec1d/grpcio-1.68.0-cp313-cp313-win32.whl", hash = "sha256:3ac7f10850fd0487fcce169c3c55509101c3bde2a3b454869639df2176b60a03", size = 3637301 }, + { url = "https://files.pythonhosted.org/packages/3c/c1/decb2b368a54c00a6ee815c3f610903f36432e3cb591d43369319826b05e/grpcio-1.68.0-cp313-cp313-win_amd64.whl", hash = "sha256:afbf45a62ba85a720491bfe9b2642f8761ff348006f5ef67e4622621f116b04a", size = 4390939 }, +] + +[[package]] +name = "grpcio-tools" +version = "1.68.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "grpcio" }, + { name = "protobuf" }, + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/81/40/47299f96fc21b9cd448cbebcbf174b1bedeaa1f82a1e7d4ed144d084d002/grpcio_tools-1.68.0.tar.gz", hash = "sha256:737804ec2225dd4cc27e633b4ca0e963b0795161bf678285fab6586e917fd867", size = 5275538 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/60/80a141ab5e3a747f400ba585be9b690e00a232167bf6909fccaedde17bab/grpcio_tools-1.68.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f65942fab440e99113ce14436deace7554d5aa554ea18358e3a5f3fc47efe322", size = 2342417 }, + { url = "https://files.pythonhosted.org/packages/c4/a2/78a4c5c3e3ae3bd209519da5a4fc6669a5f3d06423d466028d01e7fbbbce/grpcio_tools-1.68.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8fefc6d000e169a97336feded23ce614df3fb9926fc48c7a9ff8ea459d93b5b0", size = 5587871 }, + { url = "https://files.pythonhosted.org/packages/74/58/9da5fd8840d13389805bf52c347e6405665380244c01b26fb5580b743749/grpcio_tools-1.68.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:6dd69c9f3ff85eee8d1f71adf7023c638ca8d465633244ac1b7f19bc3668612d", size = 2306367 }, + { url = "https://files.pythonhosted.org/packages/2c/85/3fdd9bc501a6c0f251bda233fec114a115c82603b6535373a5e74c77400c/grpcio_tools-1.68.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7dc5195dc02057668cc22da1ff1aea1811f6fa0deb801b3194dec1fe0bab1cf0", size = 2679524 }, + { url = "https://files.pythonhosted.org/packages/da/21/f2ed730aa8a5e8f4ab7500d4863c6b2a1cbb33beaff717a01ddacff995db/grpcio_tools-1.68.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:849b12bec2320e49e988df104c92217d533e01febac172a4495caab36d9f0edc", size = 2425894 }, + { url = "https://files.pythonhosted.org/packages/58/c4/0bd72a59192cdb6c595c7dd72f3d48eccb5017be625459427dd798e3a381/grpcio_tools-1.68.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:766c2cd2e365e0fc0e559af56f2c2d144d95fd7cb8668a34d533e66d6435eb34", size = 3288925 }, + { url = "https://files.pythonhosted.org/packages/81/c5/ee3d0e45d24c716449b4d84485f7ea39f4a8e670717270fc2bee55b0b21b/grpcio_tools-1.68.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2ec3a2e0afa4866ccc5ba33c071aebaa619245dfdd840cbb74f2b0591868d085", size = 2903913 }, + { url = "https://files.pythonhosted.org/packages/13/7f/85e1ac0a4c4d23a89d6d569f516b39f5a0467b6069fe967382ede41341d2/grpcio_tools-1.68.0-cp311-cp311-win32.whl", hash = "sha256:80b733014eb40d920d836d782e5cdea0dcc90d251a2ffb35ab378ef4f8a42c14", size = 946129 }, + { url = "https://files.pythonhosted.org/packages/48/64/591a4fe11fabc4c43780921b3e72233462810b893240f447cea0dec953ce/grpcio_tools-1.68.0-cp311-cp311-win_amd64.whl", hash = "sha256:f95103e3e4e7fee7c6123bc9e4e925e07ad24d8d09d7c1c916fb6c8d1cb9e726", size = 1097286 }, + { url = "https://files.pythonhosted.org/packages/b1/da/986224ace81c96a693f0e972b7cb330af06625dc57849aff9dcc95c98afa/grpcio_tools-1.68.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:dd9a654af8536b3de8525bff72a245fef62d572eabf96ac946fe850e707cb27d", size = 2342316 }, + { url = "https://files.pythonhosted.org/packages/8e/55/25a9a8e47d0b7f0551309bb9af641f04d076e2995e10866b5e08d0d73628/grpcio_tools-1.68.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0f77957e3a0916a0dd18d57ce6b49d95fc9a5cfed92310f226339c0fda5394f6", size = 5585973 }, + { url = "https://files.pythonhosted.org/packages/a5/db/518695c93b86db44eef2445c245b51f8d3c7413cb22941b4ce5fc0377dc7/grpcio_tools-1.68.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:92a09afe64fe26696595de2036e10967876d26b12c894cc9160f00152cacebe7", size = 2306181 }, + { url = "https://files.pythonhosted.org/packages/e8/18/8e395bea3f1ea1da49ca99685e670ec21251e8b6a6d37ced266109b33c32/grpcio_tools-1.68.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:28ebdbad2ef16699d07400b65260240851049a75502eff69a59b127d3ab960f1", size = 2679660 }, + { url = "https://files.pythonhosted.org/packages/e5/f8/7b0bc247c3607c5a3a5f09c81d37b887f684cb3863837eaeacc24835a951/grpcio_tools-1.68.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d3150d784d8050b10dcf5eb06e04fb90747a1547fed3a062a608d940fe57066", size = 2425466 }, + { url = "https://files.pythonhosted.org/packages/64/6a/91f8948b34c245b06ed738a49e0f29948168ecca967aee653f70cd8e9009/grpcio_tools-1.68.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:261d98fd635595de42aadee848f9af46da6654d63791c888891e94f66c5d0682", size = 3289408 }, + { url = "https://files.pythonhosted.org/packages/97/d7/5ff90d41e8036cbcac4c2b4f53d303b778d23f74a3dbb40c625fc0f3e475/grpcio_tools-1.68.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:061345c0079b9471f32230186ab01acb908ea0e577bc1699a8cf47acef8be4af", size = 2903935 }, + { url = "https://files.pythonhosted.org/packages/e1/26/c360f9ce0a0a49f375f2c487ba91daeb85e519ea6e1f9eed04781faabb12/grpcio_tools-1.68.0-cp312-cp312-win32.whl", hash = "sha256:533ce6791a5ba21e35d74c6c25caf4776f5692785a170c01ea1153783ad5af31", size = 946040 }, + { url = "https://files.pythonhosted.org/packages/13/96/dbc239492dac0abad04de84578a068b72f3bdff4c5afbc38a9587738b2ef/grpcio_tools-1.68.0-cp312-cp312-win_amd64.whl", hash = "sha256:56842a0ce74b4b92eb62cd5ee00181b2d3acc58ba0c4fd20d15a5db51f891ba6", size = 1096722 }, + { url = "https://files.pythonhosted.org/packages/29/63/ccdcb96d0f3a473b457f9b1cc78adb0f1226d7fed6cfffbbdeeb3ce88fbb/grpcio_tools-1.68.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:1117a81592542f0c36575082daa6413c57ca39188b18a4c50ec7332616f4b97e", size = 2342127 }, + { url = "https://files.pythonhosted.org/packages/4d/28/1bbc4cd976f518bd45c1c1ec0d1d0a3db35adcdaf5245cbaaa95c2fdf548/grpcio_tools-1.68.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:51e5a090849b30c99a2396d42140b8a3e558eff6cdfa12603f9582e2cd07724e", size = 5573889 }, + { url = "https://files.pythonhosted.org/packages/77/80/2ccdf2fd60b5ab822ff800c315afd5cbaf9368a58882b802cb64865740bb/grpcio_tools-1.68.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:4fe611d89a1836df8936f066d39c7eb03d4241806449ec45d4b8e1c843ae8011", size = 2305568 }, + { url = "https://files.pythonhosted.org/packages/1a/9c/3b8e73b8e60aaacda101a4adfdec837e60a03a1dbf54c7b80f85ceff0c9c/grpcio_tools-1.68.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c10f3faa0cc4d89eb546f53b623837af23e86dc495d3b89510bcc0e0a6c0b8b2", size = 2678660 }, + { url = "https://files.pythonhosted.org/packages/ed/95/19a545674b81ad8b8783807a125f8b51210c29ab0cea6e79a2d21c0077c1/grpcio_tools-1.68.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46b537480b8fd2195d988120a28467601a2a3de2e504043b89fb90318e1eb754", size = 2425008 }, + { url = "https://files.pythonhosted.org/packages/00/c7/9da961471f7ec6f3d437e2bf91fec0247315c0f1151e2412e6d08852f3d4/grpcio_tools-1.68.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:17d0c9004ea82b4213955a585401e80c30d4b37a1d4ace32ccdea8db4d3b7d43", size = 3288874 }, + { url = "https://files.pythonhosted.org/packages/f8/22/2147b3c104cba9dda2a28a375f88b27d86fd5f25e249d8e8547ca0ea04ef/grpcio_tools-1.68.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:2919faae04fe47bad57fc9b578aeaab527da260e851f321a253b6b11862254a8", size = 2903220 }, + { url = "https://files.pythonhosted.org/packages/d7/d8/c9e8bd2bf3128608e14bb28266a0d587ebca8bfd8279b956da1f0f939270/grpcio_tools-1.68.0-cp313-cp313-win32.whl", hash = "sha256:ee86157ef899f58ba2fe1055cce0d33bd703e99aa6d5a0895581ac3969f06bfa", size = 945292 }, + { url = "https://files.pythonhosted.org/packages/f1/0d/99bd17898a923d40869a54f80bd79ff1013ef9c014d778c7750aa4493809/grpcio_tools-1.68.0-cp313-cp313-win_amd64.whl", hash = "sha256:d0470ffc6a93c86cdda48edd428d22e2fef17d854788d60d0d5f291038873157", size = 1096010 }, +] + [[package]] name = "h11" version = "0.14.0" @@ -624,6 +699,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 }, ] +[[package]] +name = "h2" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "hpack" }, + { name = "hyperframe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2a/32/fec683ddd10629ea4ea46d206752a95a2d8a48c22521edd70b142488efe1/h2-4.1.0.tar.gz", hash = "sha256:a83aca08fbe7aacb79fec788c9c0bac936343560ed9ec18b82a13a12c28d2abb", size = 2145593 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/e5/db6d438da759efbb488c4f3fbdab7764492ff3c3f953132efa6b9f0e9e53/h2-4.1.0-py3-none-any.whl", hash = "sha256:03a46bcf682256c95b5fd9e9a99c1323584c3eec6440d379b9903d709476bc6d", size = 57488 }, +] + +[[package]] +name = "hpack" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3e/9b/fda93fb4d957db19b0f6b370e79d586b3e8528b20252c729c476a2c02954/hpack-4.0.0.tar.gz", hash = "sha256:fc41de0c63e687ebffde81187a948221294896f6bdc0ae2312708df339430095", size = 49117 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/34/e8b383f35b77c402d28563d2b8f83159319b509bc5f760b15d60b0abf165/hpack-4.0.0-py3-none-any.whl", hash = "sha256:84a076fad3dc9a9f8063ccb8041ef100867b1878b25ef0ee63847a5d53818a6c", size = 32611 }, +] + [[package]] name = "html2text" version = "2024.2.26" @@ -659,6 +756,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/56/95/9377bcb415797e44274b51d46e3249eba641711cf3348050f76ee7b15ffc/httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0", size = 76395 }, ] +[package.optional-dependencies] +http2 = [ + { name = "h2" }, +] + [[package]] name = "huggingface-hub" version = "0.26.2" @@ -677,6 +779,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/60/bf/cea0b9720c32fa01b0c4ec4b16b9f4ae34ca106b202ebbae9f03ab98cd8f/huggingface_hub-0.26.2-py3-none-any.whl", hash = "sha256:98c2a5a8e786c7b2cb6fdeb2740893cba4d53e312572ed3d8afafda65b128c46", size = 447536 }, ] +[[package]] +name = "hyperframe" +version = "6.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/2a/4747bff0a17f7281abe73e955d60d80aae537a5d203f417fa1c2e7578ebb/hyperframe-6.0.1.tar.gz", hash = "sha256:ae510046231dc8e9ecb1a6586f63d2347bf4c8905914aa84ba585ae85f28a914", size = 25008 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/de/85a784bcc4a3779d1753a7ec2dee5de90e18c7bcf402e71b51fcf150b129/hyperframe-6.0.1-py3-none-any.whl", hash = "sha256:0ec6bafd80d8ad2195c4f03aacba3a8265e57bc4cff261e802bf39970ed02a15", size = 12389 }, +] + [[package]] name = "identify" version = "2.6.3" @@ -1511,7 +1622,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.5.1.dev10+gfae5848.d20241128" +version = "5.5.1.dev10+g5d930a4.d20241128" source = { editable = "." } dependencies = [ { name = "aiohttp" }, @@ -1544,6 +1655,9 @@ ldp = [ local = [ { name = "sentence-transformers" }, ] +qdrant = [ + { name = "qdrant-client" }, +] typing = [ { name = "pandas-stubs" }, { name = "types-pyyaml" }, @@ -1599,6 +1713,7 @@ requires-dist = [ { name = "pydantic-settings" }, { name = "pymupdf", specifier = ">=1.24.12" }, { name = "pyzotero", marker = "extra == 'zotero'" }, + { name = "qdrant-client", marker = "extra == 'qdrant'" }, { name = "rich" }, { name = "sentence-transformers", marker = "extra == 'local'" }, { name = "setuptools" }, @@ -1718,6 +1833,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, ] +[[package]] +name = "portalocker" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pywin32", marker = "platform_system == 'Windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/fb/a70a4214956182e0d7a9099ab17d50bfcba1056188e9b14f35b9e2b62a0d/portalocker-2.10.1-py3-none-any.whl", hash = "sha256:53a5984ebc86a025552264b459b46a2086e269b21823cb572f8f28ee759e45bf", size = 18423 }, +] + [[package]] name = "pre-commit" version = "4.0.1" @@ -1803,6 +1930,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3d/b6/e6d98278f2d49b22b4d033c9f792eda783b9ab2094b041f013fc69bcde87/propcache-0.2.0-py3-none-any.whl", hash = "sha256:2ccc28197af5313706511fab3a8b66dcd6da067a1331372c82ea1cb74285e036", size = 11603 }, ] +[[package]] +name = "protobuf" +version = "5.29.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/bb/8e59a30b83102a37d24f907f417febb58e5f544d4f124dd1edcd12e078bf/protobuf-5.29.0.tar.gz", hash = "sha256:445a0c02483869ed8513a585d80020d012c6dc60075f96fa0563a724987b1001", size = 424944 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/cc/98140acbcc3e3a58c679d50dd4f04c3687bdd19690f388c65bb1ae4c1e5e/protobuf-5.29.0-cp310-abi3-win32.whl", hash = "sha256:ea7fb379b257911c8c020688d455e8f74efd2f734b72dc1ea4b4d7e9fd1326f2", size = 422709 }, + { url = "https://files.pythonhosted.org/packages/c9/91/38fb97b0cbe96109fa257536ad49dffdac3c8f86b46d9c85dc9e949b5291/protobuf-5.29.0-cp310-abi3-win_amd64.whl", hash = "sha256:34a90cf30c908f47f40ebea7811f743d360e202b6f10d40c02529ebd84afc069", size = 434510 }, + { url = "https://files.pythonhosted.org/packages/da/97/faeca508d61b231372cdc3006084fd97f21f3c8c726a2de5f2ebb8e4ab78/protobuf-5.29.0-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:c931c61d0cc143a2e756b1e7f8197a508de5365efd40f83c907a9febf36e6b43", size = 417827 }, + { url = "https://files.pythonhosted.org/packages/eb/d6/c6a45a285374ab14499a9ef5a69e4e7b4911e641465681c1d602518d6ab2/protobuf-5.29.0-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:85286a47caf63b34fa92fdc1fd98b649a8895db595cfa746c5286eeae890a0b1", size = 319576 }, + { url = "https://files.pythonhosted.org/packages/ee/2e/cc46181ddce0940647d21a8341bf2eddad247a5d030e8c30c7a342793978/protobuf-5.29.0-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:0d10091d6d03537c3f902279fcf11e95372bdd36a79556311da0487455791b20", size = 319672 }, + { url = "https://files.pythonhosted.org/packages/7c/6c/dd1f0e8372ec2a8006102871d8da1466b116f3328db96972e19bf24f09ca/protobuf-5.29.0-py3-none-any.whl", hash = "sha256:88c4af76a73183e21061881360240c0cdd3c39d263b4e8fb570aaf83348d608f", size = 172553 }, +] + [[package]] name = "ptyprocess" version = "0.7.0" @@ -2244,6 +2385,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/3c/717f90930fbba6ec433a8b6eb9c8854089b7cd54118fb3dd5822d53bdfc7/pyzotero-1.5.25-py3-none-any.whl", hash = "sha256:6529130cb7c7e773963d92db7e7a18b698a97ac8156766320b55ebd4c7e94ed5", size = 22966 }, ] +[[package]] +name = "qdrant-client" +version = "1.12.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "grpcio" }, + { name = "grpcio-tools" }, + { name = "httpx", extra = ["http2"] }, + { name = "numpy" }, + { name = "portalocker" }, + { name = "pydantic" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/15/5e/ec560881e086f893947c8798949c72de5cfae9453fd05c2250f8dfeaa571/qdrant_client-1.12.1.tar.gz", hash = "sha256:35e8e646f75b7b883b3d2d0ee4c69c5301000bba41c82aa546e985db0f1aeb72", size = 237441 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/c0/eef4fe9dad6d41333f7dc6567fa8144ffc1837c8a0edfc2317d50715335f/qdrant_client-1.12.1-py3-none-any.whl", hash = "sha256:b2d17ce18e9e767471368380dd3bbc4a0e3a0e2061fedc9af3542084b48451e0", size = 267171 }, +] + [[package]] name = "referencing" version = "0.35.1" From 87cdc860fecd69d31ece1083d6a25f8845858092 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Thu, 28 Nov 2024 15:02:17 +0530 Subject: [PATCH 02/10] ci: Configure tests Signed-off-by: Anush008 --- .github/workflows/tests.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c70eac650..6bf47cb6b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,6 +40,11 @@ jobs: - run: uv run pylint paperqa test: runs-on: ubuntu-latest + services: + qdrant: + image: qdrant/qdrant + ports: + - 6333:6333 strategy: matrix: python-version: [3.11, 3.12] # Our min and max supported Python versions @@ -49,7 +54,7 @@ jobs: with: enable-cache: true - run: uv python pin ${{ matrix.python-version }} - - run: uv sync --python-preference=only-managed + - run: uv sync --python-preference=only-managed --extra qdrant - name: Cache datasets uses: actions/cache@v4 with: From d15f3f38160399bae3af07c3d28ce9743ab50e8c Mon Sep 17 00:00:00 2001 From: Anush008 Date: Wed, 4 Dec 2024 00:02:04 +0530 Subject: [PATCH 03/10] chore: Review updates Signed-off-by: Anush008 --- .pre-commit-config.yaml | 1 + paperqa/llms.py | 20 ++++++++++++++++---- uv.lock | 10 ++++++++-- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 59ff8df26..47410bee7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -96,6 +96,7 @@ repos: - pandas-stubs - pydantic~=2.0,>=2.10.1 # Match pyproject.toml - pydantic-settings + - qdrant-client - rich - tantivy - tenacity diff --git a/paperqa/llms.py b/paperqa/llms.py index 361a3cbe4..e0540e622 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -37,6 +37,13 @@ from paperqa.types import Embeddable, LLMResult from paperqa.utils import is_coroutine_callable +try: + from qdrant_client import QdrantClient, models + + qdrant_installed = True +except ImportError: + qdrant_installed = False + PromptRunner = Callable[ [dict, list[Callable[[str], None]] | None, str | None], Awaitable[LLMResult], @@ -912,14 +919,20 @@ async def similarity_search( class QdrantVectorStore(VectorStore): client: Any = Field( - default=None, description="Instance of qdrant_client.QdrantClient" + default=None, + description="Instance of `qdrant_client.QdrantClient`. Tries to connect to http://localhost:6333/ by default.", ) collection_name: str = Field(default_factory=lambda: f"paper-qa-{uuid.uuid4().hex}") vector_name: str | None = Field(default=None) @model_validator(mode="after") def validate_client(self): - from qdrant_client import QdrantClient + if not qdrant_installed: + msg = ( + "`QdrantVectorStore` requires the `qdrant-client` package. " + "Install it with `pip install paper-qa[qdrant]`" + ) + raise ImportError(msg) if self.client and not isinstance(self.client, QdrantClient): raise TypeError( @@ -927,13 +940,13 @@ def validate_client(self): ) if not self.client: + # The default instance connects to http://localhost:6333/ self.client = QdrantClient() return self def clear(self) -> None: super().clear() - from qdrant_client import models if not self.client.collection_exists(self.collection_name): return @@ -946,7 +959,6 @@ def clear(self) -> None: def add_texts_and_embeddings(self, texts: Iterable[Embeddable]) -> None: super().add_texts_and_embeddings(texts) - from qdrant_client import models texts_list = list(texts) diff --git a/uv.lock b/uv.lock index 5c2a90b08..e4c6901a4 100644 --- a/uv.lock +++ b/uv.lock @@ -1611,7 +1611,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.5.1.dev15+gd03d424" +version = "5.6.1.dev6+ge347a0b.d20241203" source = { editable = "." } dependencies = [ { name = "aiohttp" }, @@ -1658,9 +1658,11 @@ zotero = [ [package.dev-dependencies] dev = [ + { name = "datasets" }, { name = "ipython" }, + { name = "ldp" }, { name = "mypy" }, - { name = "paper-qa", extra = ["datasets", "ldp", "local", "typing", "zotero"] }, + { name = "pandas-stubs" }, { name = "pre-commit" }, { name = "pydantic" }, { name = "pylint-pydantic" }, @@ -1673,8 +1675,12 @@ dev = [ { name = "pytest-timer", extra = ["colorama"] }, { name = "pytest-xdist" }, { name = "python-dotenv" }, + { name = "pyzotero" }, { name = "refurb" }, + { name = "sentence-transformers" }, { name = "typeguard" }, + { name = "types-pyyaml" }, + { name = "types-setuptools" }, ] [package.metadata] From fa1bc9415f7444dab77e5c2a76d21949b17dcede Mon Sep 17 00:00:00 2001 From: Anush008 Date: Wed, 4 Dec 2024 01:29:22 +0530 Subject: [PATCH 04/10] chore: Missed stashed commit Signed-off-by: Anush008 --- paperqa/llms.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index e0540e622..d4ebfb214 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -34,7 +34,7 @@ from paperqa.prompts import default_system_prompt from paperqa.rate_limiter import GLOBAL_LIMITER -from paperqa.types import Embeddable, LLMResult +from paperqa.types import Embeddable, LLMResult, Text from paperqa.utils import is_coroutine_callable try: @@ -911,6 +911,7 @@ async def similarity_search( # we could use arg-partition here # but a lot of algorithms expect a sorted list sorted_indices = np.argsort(-similarity_scores) + return ( [self.texts[i] for i in sorted_indices[:k]], [similarity_scores[i] for i in sorted_indices[:k]], @@ -925,6 +926,15 @@ class QdrantVectorStore(VectorStore): collection_name: str = Field(default_factory=lambda: f"paper-qa-{uuid.uuid4().hex}") vector_name: str | None = Field(default=None) + def __eq__(self, other) -> bool: + if not isinstance(other, type(self)): + return NotImplemented + return ( + self.collection_name == other.collection_name + and self.vector_name == other.vector_name + and self.client.init_options == other.client.init_options + ) + @model_validator(mode="after") def validate_client(self): if not qdrant_installed: @@ -975,12 +985,16 @@ def add_texts_and_embeddings(self, texts: Iterable[Embeddable]) -> None: ), ) + payload = [t.model_dump(exclude={"embedding"}) for t in texts_list] vectors = [ {self.vector_name: t.embedding} if self.vector_name else t.embedding for t in texts_list ] self.client.upload_collection( - collection_name=self.collection_name, vectors=vectors, wait=True + collection_name=self.collection_name, + vectors=vectors, + payload=payload, + wait=True, ) async def similarity_search( @@ -999,11 +1013,17 @@ async def similarity_search( using=self.vector_name, limit=k, with_vectors=True, + with_payload=True, ).points return ( [ - p.vector[self.vector_name] if self.vector_name else p.vector + Text( + **p.payload, + embedding=( + p.vector[self.vector_name] if self.vector_name else p.vector + ), + ) for p in points ], [p.score for p in points], From cfa6cfcf5f7f174e20b4b1cad209b2a4466c88d4 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Wed, 4 Dec 2024 03:23:37 +0530 Subject: [PATCH 05/10] chore: Fix tests Signed-off-by: Anush008 --- .github/workflows/tests.yml | 7 +------ paperqa/llms.py | 13 ++++++++----- pyproject.toml | 2 +- tests/test_paperqa.py | 6 ++++-- uv.lock | 5 +++-- 5 files changed, 17 insertions(+), 16 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6bf47cb6b..c70eac650 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,11 +40,6 @@ jobs: - run: uv run pylint paperqa test: runs-on: ubuntu-latest - services: - qdrant: - image: qdrant/qdrant - ports: - - 6333:6333 strategy: matrix: python-version: [3.11, 3.12] # Our min and max supported Python versions @@ -54,7 +49,7 @@ jobs: with: enable-cache: true - run: uv python pin ${{ matrix.python-version }} - - run: uv sync --python-preference=only-managed --extra qdrant + - run: uv sync --python-preference=only-managed - name: Cache datasets uses: actions/cache@v4 with: diff --git a/paperqa/llms.py b/paperqa/llms.py index d4ebfb214..ab972ff99 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -921,7 +921,7 @@ async def similarity_search( class QdrantVectorStore(VectorStore): client: Any = Field( default=None, - description="Instance of `qdrant_client.QdrantClient`. Tries to connect to http://localhost:6333/ by default.", + description="Instance of `qdrant_client.QdrantClient`. Defaults to an in-memory instance.", ) collection_name: str = Field(default_factory=lambda: f"paper-qa-{uuid.uuid4().hex}") vector_name: str | None = Field(default=None) @@ -929,8 +929,11 @@ class QdrantVectorStore(VectorStore): def __eq__(self, other) -> bool: if not isinstance(other, type(self)): return NotImplemented + return ( - self.collection_name == other.collection_name + self.texts_hashes == other.texts_hashes + and self.mmr_lambda == other.mmr_lambda + and self.collection_name == other.collection_name and self.vector_name == other.vector_name and self.client.init_options == other.client.init_options ) @@ -946,12 +949,12 @@ def validate_client(self): if self.client and not isinstance(self.client, QdrantClient): raise TypeError( - f"'client' should be an instance of qdrant_client.QdrantClient. Got {type(self.client)}" + f"'client' should be an instance of `qdrant_client.QdrantClient`. Got `{type(self.client)}`" ) if not self.client: - # The default instance connects to http://localhost:6333/ - self.client = QdrantClient() + # Defaults to the Python based in-memory implementation. + self.client = QdrantClient(location=":memory:") return self diff --git a/pyproject.toml b/pyproject.toml index a51267d04..03607a396 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = ["setuptools>=64", "setuptools_scm>=8"] dev = [ "ipython>=8", # Pin to keep recent "mypy>=1.8", # Pin for mutable-override - "paper-qa[datasets,ldp,typing,zotero,local]", + "paper-qa[datasets,ldp,local,qdrant,typing,zotero]", "pre-commit>=3.4", # Pin to keep recent "pydantic~=2.0", "pylint-pydantic", diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index e45dcdd63..8bff40257 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -631,7 +631,7 @@ class MyEmbeds(EmbeddingModel): name: str = "my_embed" async def embed_documents(self, texts): - return [[1, 2, 3] for _ in texts] + return [[0.0, 0.28, 0.95] for _ in texts] docs = Docs(texts_index=vector_store()) docs.add( @@ -640,7 +640,7 @@ async def embed_documents(self, texts): embedding_model=MyEmbeds(), ) with subtests.test(msg="confirm-embedding"): - assert docs.texts[0].embedding == [1, 2, 3] + assert docs.texts[0].embedding == [0.0, 0.28, 0.95] with subtests.test(msg="copying-before-get-evidence"): # Before getting evidence, shallow and deep copies are the same @@ -649,6 +649,7 @@ async def embed_documents(self, texts): **docs.model_dump(exclude={"texts_index"}), ) docs_deep_copy = deepcopy(docs) + assert ( docs.texts_index == docs_shallow_copy.texts_index @@ -666,6 +667,7 @@ async def embed_documents(self, texts): **docs.model_dump(exclude={"texts_index"}), ) docs_deep_copy = deepcopy(docs) + assert docs.texts_index != docs_shallow_copy.texts_index assert docs.texts_index == docs_deep_copy.texts_index diff --git a/uv.lock b/uv.lock index e4c6901a4..eceb83ee8 100644 --- a/uv.lock +++ b/uv.lock @@ -1611,7 +1611,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.6.1.dev6+ge347a0b.d20241203" +version = "5.6.1.dev8+gfa1bc94.d20241203" source = { editable = "." } dependencies = [ { name = "aiohttp" }, @@ -1676,6 +1676,7 @@ dev = [ { name = "pytest-xdist" }, { name = "python-dotenv" }, { name = "pyzotero" }, + { name = "qdrant-client" }, { name = "refurb" }, { name = "sentence-transformers" }, { name = "typeguard" }, @@ -1717,7 +1718,7 @@ requires-dist = [ dev = [ { name = "ipython", specifier = ">=8" }, { name = "mypy", specifier = ">=1.8" }, - { name = "paper-qa", extras = ["datasets", "ldp", "typing", "zotero", "local"] }, + { name = "paper-qa", extras = ["datasets", "ldp", "local", "qdrant", "typing", "zotero"] }, { name = "pre-commit", specifier = ">=3.4" }, { name = "pydantic", specifier = "~=2.0" }, { name = "pylint-pydantic" }, From 9836e09dec8c5ffa6c44c30e4d4754584503b961 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Thu, 5 Dec 2024 12:12:35 +0530 Subject: [PATCH 06/10] fix: unique _ids identifier Signed-off-by: Anush008 --- paperqa/llms.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index ab972ff99..a07ac8bf6 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -925,6 +925,7 @@ class QdrantVectorStore(VectorStore): ) collection_name: str = Field(default_factory=lambda: f"paper-qa-{uuid.uuid4().hex}") vector_name: str | None = Field(default=None) + _ids: list[str] = [] def __eq__(self, other) -> bool: if not isinstance(other, type(self)): @@ -936,6 +937,7 @@ def __eq__(self, other) -> bool: and self.collection_name == other.collection_name and self.vector_name == other.vector_name and self.client.init_options == other.client.init_options + and self._ids == other._ids ) @model_validator(mode="after") @@ -969,6 +971,7 @@ def clear(self) -> None: points_selector=models.Filter(must=[]), wait=True, ) + self._ids = [] def add_texts_and_embeddings(self, texts: Iterable[Embeddable]) -> None: super().add_texts_and_embeddings(texts) @@ -988,17 +991,26 @@ def add_texts_and_embeddings(self, texts: Iterable[Embeddable]) -> None: ), ) - payload = [t.model_dump(exclude={"embedding"}) for t in texts_list] - vectors = [ - {self.vector_name: t.embedding} if self.vector_name else t.embedding - for t in texts_list - ] + ids, payloads, vectors = [], [], [] + for text in texts_list: + # Entries with same IDs are overwritten. + # We generate deterministic UUIDs based on the embedding vectors. + ids.append(uuid.uuid5(uuid.NAMESPACE_URL, str(text.embedding)).hex) + payloads.append(text.model_dump(exclude={"embedding"})) + vectors.append( + {self.vector_name: text.embedding} + if self.vector_name + else text.embedding + ) + self.client.upload_collection( collection_name=self.collection_name, vectors=vectors, - payload=payload, + payload=payloads, wait=True, + ids=ids, ) + self._ids = ids async def similarity_search( self, query: str, k: int, embedding_model: EmbeddingModel From fdd769b2537499fca6a2317b054ca26aab1975a6 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Thu, 5 Dec 2024 13:31:11 +0530 Subject: [PATCH 07/10] chore: lockfile Signed-off-by: Anush008 --- pyproject.toml | 2 +- uv.lock | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 23d3af2bd..5d0363e7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ datasets = [ dev = [ "ipython>=8", # Pin to keep recent "mypy>=1.8", # Pin for mutable-override - "paper-qa[datasets,ldp,typing,zotero,local]", + "paper-qa[datasets,ldp,typing,zotero,local,qdrant]", "pre-commit>=3.4", # Pin to keep recent "pydantic~=2.0", "pylint-pydantic", diff --git a/uv.lock b/uv.lock index 10a9d5944..57b924540 100644 --- a/uv.lock +++ b/uv.lock @@ -1611,7 +1611,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.7.1.dev1+g60e3e39.d20241204" +version = "5.7.1.dev10+g654ff1b.d20241205" source = { editable = "." } dependencies = [ { name = "aiohttp" }, @@ -1657,6 +1657,7 @@ dev = [ { name = "pytest-xdist" }, { name = "python-dotenv" }, { name = "pyzotero" }, + { name = "qdrant-client" }, { name = "refurb" }, { name = "sentence-transformers" }, { name = "typeguard" }, @@ -1701,6 +1702,7 @@ dev = [ { name = "pytest-xdist" }, { name = "python-dotenv" }, { name = "pyzotero" }, + { name = "qdrant-client" }, { name = "refurb" }, { name = "sentence-transformers" }, { name = "typeguard" }, @@ -1724,7 +1726,7 @@ requires-dist = [ { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8" }, { name = "numpy" }, { name = "pandas-stubs", marker = "extra == 'typing'" }, - { name = "paper-qa", extras = ["datasets", "ldp", "local", "typing", "zotero"], marker = "extra == 'dev'" }, + { name = "paper-qa", extras = ["datasets", "ldp", "local", "qdrant", "typing", "zotero"], marker = "extra == 'dev'" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.4" }, { name = "pybtex" }, { name = "pydantic", specifier = "~=2.0,>=2.10.1" }, @@ -1742,6 +1744,7 @@ requires-dist = [ { name = "pytest-xdist", marker = "extra == 'dev'" }, { name = "python-dotenv", marker = "extra == 'dev'" }, { name = "pyzotero", marker = "extra == 'zotero'" }, + { name = "qdrant-client", marker = "extra == 'qdrant'" }, { name = "refurb", marker = "extra == 'dev'", specifier = ">=2" }, { name = "rich" }, { name = "sentence-transformers", marker = "extra == 'local'" }, From f87fea609727f29c2d91bfbe5ebef91fe60439b7 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Fri, 6 Dec 2024 12:05:37 +0530 Subject: [PATCH 08/10] chore: That one refurb Signed-off-by: Anush008 --- paperqa/llms.py | 4 +--- uv.lock | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index a07ac8bf6..55206a057 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -978,9 +978,7 @@ def add_texts_and_embeddings(self, texts: Iterable[Embeddable]) -> None: texts_list = list(texts) - if len(texts_list) > 0 and not self.client.collection_exists( - self.collection_name - ): + if texts_list and not self.client.collection_exists(self.collection_name): params = models.VectorParams( size=len(texts_list[0].embedding), distance=models.Distance.COSINE # type: ignore[arg-type] ) diff --git a/uv.lock b/uv.lock index 57b924540..39ac1fbec 100644 --- a/uv.lock +++ b/uv.lock @@ -1611,7 +1611,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.7.1.dev10+g654ff1b.d20241205" +version = "5.7.1.dev12+g8d575c0" source = { editable = "." } dependencies = [ { name = "aiohttp" }, From 9143b5c76d588fd4eeb5b2c8fa2986d09b100d25 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Fri, 6 Dec 2024 23:46:54 +0530 Subject: [PATCH 09/10] fix: try .mailmap Signed-off-by: Anush008 --- .mailmap | 1 + uv.lock | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.mailmap b/.mailmap index de79b152a..493e3378a 100644 --- a/.mailmap +++ b/.mailmap @@ -6,3 +6,4 @@ Michael Skarlinski mskarlin <12701035+mskarlin@use Odhran O'Donoghue odhran-o-d Odhran O'Donoghue <39832722+odhran-o-d@users.noreply.github.com> Samantha Cox +Anush008 Anush diff --git a/uv.lock b/uv.lock index 39ac1fbec..3e5694a64 100644 --- a/uv.lock +++ b/uv.lock @@ -1611,7 +1611,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.7.1.dev12+g8d575c0" +version = "5.7.1.dev12+gf87fea6" source = { editable = "." } dependencies = [ { name = "aiohttp" }, From 8b59afdf8fba71a67d849b5fa21ee468dc652bdc Mon Sep 17 00:00:00 2001 From: Anush008 Date: Sat, 7 Dec 2024 14:54:06 +0530 Subject: [PATCH 10/10] refactor: _point_ids Signed-off-by: Anush008 --- paperqa/llms.py | 9 ++++----- uv.lock | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index 55206a057..ae29ea0e5 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -911,7 +911,6 @@ async def similarity_search( # we could use arg-partition here # but a lot of algorithms expect a sorted list sorted_indices = np.argsort(-similarity_scores) - return ( [self.texts[i] for i in sorted_indices[:k]], [similarity_scores[i] for i in sorted_indices[:k]], @@ -925,7 +924,7 @@ class QdrantVectorStore(VectorStore): ) collection_name: str = Field(default_factory=lambda: f"paper-qa-{uuid.uuid4().hex}") vector_name: str | None = Field(default=None) - _ids: list[str] = [] + _point_ids: set[str] | None = None def __eq__(self, other) -> bool: if not isinstance(other, type(self)): @@ -937,7 +936,7 @@ def __eq__(self, other) -> bool: and self.collection_name == other.collection_name and self.vector_name == other.vector_name and self.client.init_options == other.client.init_options - and self._ids == other._ids + and self._point_ids == other._point_ids ) @model_validator(mode="after") @@ -971,7 +970,7 @@ def clear(self) -> None: points_selector=models.Filter(must=[]), wait=True, ) - self._ids = [] + self._point_ids = None def add_texts_and_embeddings(self, texts: Iterable[Embeddable]) -> None: super().add_texts_and_embeddings(texts) @@ -1008,7 +1007,7 @@ def add_texts_and_embeddings(self, texts: Iterable[Embeddable]) -> None: wait=True, ids=ids, ) - self._ids = ids + self._point_ids = set(ids) async def similarity_search( self, query: str, k: int, embedding_model: EmbeddingModel diff --git a/uv.lock b/uv.lock index 3e5694a64..9d9ee3eb6 100644 --- a/uv.lock +++ b/uv.lock @@ -1611,7 +1611,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.7.1.dev12+gf87fea6" +version = "5.7.1.dev13+g9143b5c.d20241207" source = { editable = "." } dependencies = [ { name = "aiohttp" },