Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .mailmap
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ Michael Skarlinski <mskarlinski@futurehouse.org> mskarlin <12701035+mskarlin@use
Odhran O'Donoghue <odhran.r.odonoghue@gmail.com> odhran-o-d <odhran.r.odonoghue@gmail.com>
Odhran O'Donoghue <odhran.r.odonoghue@gmail.com> <39832722+odhran-o-d@users.noreply.github.com>
Samantha Cox <samc@futurehouse.org> <swrig30@ur.rochester.edu>
Anush008 <anushshetty90@gmail.com> Anush <anushshetty90@gmail.com>
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ repos:
- pandas-stubs
- pydantic~=2.0,>=2.10.1 # Match pyproject.toml
- pydantic-settings
- qdrant-client
- rich
- tantivy
- tenacity
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ for doc in ("myfile.pdf", "myotherfile.pdf"):
Note that PaperQA2 uses Numpy as a dense vector store.
Its design of using a keyword search initially reduces the number of chunks needed for each answer to a relatively small number < 1k.
Therefore, `NumpyVectorStore` is a good place to start: it's a simple in-memory store without an index.
However, if a larger-than-memory vector store is needed, we are currently lacking here.
However, if a larger-than-memory vector store is needed, you can use an external vector database like [Qdrant](https://qdrant.tech/) via the `QdrantVectorStore` class.

The hybrid embeddings can be customized:

Expand Down
2 changes: 2 additions & 0 deletions paperqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
LLMModel,
LLMResult,
NumpyVectorStore,
QdrantVectorStore,
SentenceTransformerEmbeddingModel,
SparseEmbeddingModel,
embedding_model_factory,
Expand All @@ -40,6 +41,7 @@
"LiteLLMModel",
"NumpyVectorStore",
"PQASession",
"QdrantVectorStore",
"QueryRequest",
"SentenceTransformerEmbeddingModel",
"Settings",
Expand Down
135 changes: 134 additions & 1 deletion paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import functools
import itertools
import logging
import uuid
from abc import ABC, abstractmethod
from collections.abc import (
AsyncGenerator,
Expand Down Expand Up @@ -35,9 +36,16 @@

from paperqa.prompts import default_system_prompt
from paperqa.rate_limiter import GLOBAL_LIMITER
from paperqa.types import Embeddable, LLMResult
from paperqa.types import Embeddable, LLMResult, Text
from paperqa.utils import is_coroutine_callable

try:
from qdrant_client import QdrantClient, models

qdrant_installed = True
except ImportError:
qdrant_installed = False

PromptRunner = Callable[
[dict, list[Callable[[str], None]] | None, str | None],
Awaitable[LLMResult],
Expand Down Expand Up @@ -994,6 +1002,131 @@ async def similarity_search(
)


class QdrantVectorStore(VectorStore):
    """Vector store that keeps texts and their embeddings in a Qdrant collection.

    Points are written with `client.upload_collection` and retrieved with
    `client.query_points`. When no client is supplied, the validator below
    creates an in-memory `QdrantClient(location=":memory:")`.
    """

    client: Any = Field(
        default=None,
        description="Instance of `qdrant_client.QdrantClient`. Defaults to an in-memory instance.",
    )
    # A fresh, random collection name is generated per store unless one is given.
    collection_name: str = Field(default_factory=lambda: f"paper-qa-{uuid.uuid4().hex}")
    # Optional named vector inside the collection; None means the unnamed vector.
    vector_name: str | None = Field(default=None)
    # IDs of every point this store has uploaded; None until the first upload
    # (and again after `clear`).
    _point_ids: set[str] | None = None

    def __eq__(self, other) -> bool:
        """Compare two stores by hashes, settings, client options and point IDs."""
        if not isinstance(other, type(self)):
            return NotImplemented

        return (
            self.texts_hashes == other.texts_hashes
            and self.mmr_lambda == other.mmr_lambda
            and self.collection_name == other.collection_name
            and self.vector_name == other.vector_name
            and self.client.init_options == other.client.init_options
            and self._point_ids == other._point_ids
        )

    @model_validator(mode="after")
    def validate_client(self):
        """Ensure `qdrant-client` is installed and `client` is usable.

        Raises:
            ImportError: If the `qdrant-client` package could not be imported.
            TypeError: If a client was supplied that is not a `QdrantClient`.
        """
        if not qdrant_installed:
            msg = (
                "`QdrantVectorStore` requires the `qdrant-client` package. "
                "Install it with `pip install paper-qa[qdrant]`"
            )
            raise ImportError(msg)

        if self.client and not isinstance(self.client, QdrantClient):
            raise TypeError(
                f"'client' should be an instance of `qdrant_client.QdrantClient`. Got `{type(self.client)}`"
            )

        if not self.client:
            # Defaults to the Python based in-memory implementation.
            self.client = QdrantClient(location=":memory:")

        return self

    def clear(self) -> None:
        """Clear the base-class state and delete all points in the collection."""
        super().clear()

        if not self.client.collection_exists(self.collection_name):
            return

        # An empty `must` filter is used to select every point for deletion.
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=models.Filter(must=[]),
            wait=True,
        )
        self._point_ids = None

    def add_texts_and_embeddings(self, texts: Iterable[Embeddable]) -> None:
        """Upload `texts` (payload + embedding) to the Qdrant collection.

        The collection is created on first use, sized from the first text's
        embedding and using cosine distance.

        Args:
            texts: Embeddable objects whose `embedding` is already populated.
        """
        # Materialize once up front: `texts` may be a one-shot iterator, and it
        # must be traversed both by the base class and by the upload below.
        texts_list = list(texts)
        super().add_texts_and_embeddings(texts_list)

        if not texts_list:
            # Nothing to upload; also avoids touching a collection that may
            # not have been created yet.
            return

        if not self.client.collection_exists(self.collection_name):
            params = models.VectorParams(
                size=len(texts_list[0].embedding), distance=models.Distance.COSINE  # type: ignore[arg-type]
            )
            self.client.create_collection(
                self.collection_name,
                vectors_config=(
                    {self.vector_name: params} if self.vector_name else params
                ),
            )

        ids, payloads, vectors = [], [], []
        for text in texts_list:
            # Entries with same IDs are overwritten.
            # We generate deterministic UUIDs based on the embedding vectors,
            # so two texts with identical embeddings share a single point.
            ids.append(uuid.uuid5(uuid.NAMESPACE_URL, str(text.embedding)).hex)
            payloads.append(text.model_dump(exclude={"embedding"}))
            vectors.append(
                {self.vector_name: text.embedding}
                if self.vector_name
                else text.embedding
            )

        self.client.upload_collection(
            collection_name=self.collection_name,
            vectors=vectors,
            payload=payloads,
            wait=True,
            ids=ids,
        )
        # Accumulate rather than overwrite so IDs from earlier batches are kept.
        self._point_ids = (self._point_ids or set()) | set(ids)

    async def similarity_search(
        self, query: str, k: int, embedding_model: EmbeddingModel
    ) -> tuple[Sequence[Embeddable], list[float]]:
        """Return up to `k` stored texts most similar to `query`, with scores.

        Args:
            query: Text to embed and search with.
            k: Maximum number of points to request from Qdrant.
            embedding_model: Model used to embed `query`.

        Returns:
            A pair `(texts, scores)`; both are empty if the collection does
            not exist yet.
        """
        if not self.client.collection_exists(self.collection_name):
            return ([], [])

        embedding_model.set_mode(EmbeddingModes.QUERY)
        try:
            np_query = np.array((await embedding_model.embed_documents([query]))[0])
        finally:
            # Always restore document mode, even if embedding the query fails.
            embedding_model.set_mode(EmbeddingModes.DOCUMENT)

        points = self.client.query_points(
            collection_name=self.collection_name,
            query=np_query,
            using=self.vector_name,
            limit=k,
            with_vectors=True,
            with_payload=True,
        ).points

        return (
            [
                Text(
                    **p.payload,
                    embedding=(
                        p.vector[self.vector_name] if self.vector_name else p.vector
                    ),
                )
                for p in points
            ],
            [p.score for p in points],
        )


def embedding_model_factory(embedding: str, **kwargs) -> EmbeddingModel:
"""
Factory function to create an appropriate EmbeddingModel based on the embedding string.
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ datasets = [
dev = [
"ipython>=8", # Pin to keep recent
"mypy>=1.8", # Pin for mutable-override
"paper-qa[datasets,ldp,typing,zotero,local]",
"paper-qa[datasets,ldp,typing,zotero,local,qdrant]",
"pre-commit>=3.4", # Pin to keep recent
"pydantic~=2.0",
"pylint-pydantic",
Expand All @@ -83,6 +83,9 @@ ldp = [
local = [
"sentence-transformers",
]
qdrant = [
"qdrant-client",
]
typing = [
"pandas-stubs",
"types-PyYAML",
Expand Down
25 changes: 17 additions & 8 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Docs,
NumpyVectorStore,
PQASession,
QdrantVectorStore,
Settings,
Text,
print_callback,
Expand All @@ -35,6 +36,7 @@
LiteLLMEmbeddingModel,
LLMModel,
SparseEmbeddingModel,
VectorStore,
)
from paperqa.prompts import CANNOT_ANSWER_PHRASE
from paperqa.prompts import qa_prompt as default_qa_prompt
Expand Down Expand Up @@ -624,21 +626,24 @@ def test_duplicate(stub_data_dir: Path) -> None:
), "Unique documents should be hashed as unique"


def test_docs_with_custom_embedding(subtests: SubTests, stub_data_dir: Path) -> None:
@pytest.mark.parametrize("vector_store", [NumpyVectorStore, QdrantVectorStore])
def test_docs_with_custom_embedding(
subtests: SubTests, stub_data_dir: Path, vector_store: type[VectorStore]
) -> None:
class MyEmbeds(EmbeddingModel):
name: str = "my_embed"

async def embed_documents(self, texts):
return [[1, 2, 3] for _ in texts]
return [[0.0, 0.28, 0.95] for _ in texts]

docs = Docs(texts_index=NumpyVectorStore())
docs = Docs(texts_index=vector_store())
docs.add(
stub_data_dir / "bates.txt",
citation="WikiMedia Foundation, 2023, Accessed now",
embedding_model=MyEmbeds(),
)
with subtests.test(msg="confirm-embedding"):
assert docs.texts[0].embedding == [1, 2, 3]
assert docs.texts[0].embedding == [0.0, 0.28, 0.95]

with subtests.test(msg="copying-before-get-evidence"):
# Before getting evidence, shallow and deep copies are the same
Expand All @@ -647,6 +652,7 @@ async def embed_documents(self, texts):
**docs.model_dump(exclude={"texts_index"}),
)
docs_deep_copy = deepcopy(docs)

assert (
docs.texts_index
== docs_shallow_copy.texts_index
Expand All @@ -664,12 +670,14 @@ async def embed_documents(self, texts):
**docs.model_dump(exclude={"texts_index"}),
)
docs_deep_copy = deepcopy(docs)

assert docs.texts_index != docs_shallow_copy.texts_index
assert docs.texts_index == docs_deep_copy.texts_index


def test_sparse_embedding(stub_data_dir: Path) -> None:
docs = Docs(texts_index=NumpyVectorStore())
@pytest.mark.parametrize("vector_store", [NumpyVectorStore, QdrantVectorStore])
def test_sparse_embedding(stub_data_dir: Path, vector_store: type[VectorStore]) -> None:
docs = Docs(texts_index=vector_store())
docs.add(
stub_data_dir / "bates.txt",
citation="WikiMedia Foundation, 2023, Accessed now",
Expand All @@ -686,11 +694,12 @@ def test_sparse_embedding(stub_data_dir: Path) -> None:
assert np.shape(docs.texts[0].embedding) == np.shape(docs.texts[1].embedding)


def test_hybrid_embedding(stub_data_dir: Path) -> None:
@pytest.mark.parametrize("vector_store", [NumpyVectorStore, QdrantVectorStore])
def test_hybrid_embedding(stub_data_dir: Path, vector_store: type[VectorStore]) -> None:
emb_model = HybridEmbeddingModel(
models=[LiteLLMEmbeddingModel(), SparseEmbeddingModel()]
)
docs = Docs(texts_index=NumpyVectorStore())
docs = Docs(texts_index=vector_store())
docs.add(
stub_data_dir / "bates.txt",
citation="WikiMedia Foundation, 2023, Accessed now",
Expand Down
Loading
Loading