-
Notifications
You must be signed in to change notification settings - Fork 21
/
_chroma.py
139 lines (120 loc) · 4.64 KB
/
_chroma.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import uuid
import ragna
from ragna.core import (
Document,
Source,
)
from ._vector_database import VectorDatabaseSourceStorage
class Chroma(VectorDatabaseSourceStorage):
    """[Chroma vector database](https://www.trychroma.com/)

    !!! info "Required packages"

        - `chromadb>=0.4.13`
    """

    # No extra requirements are declared here: the base class already pulls in
    # the chromadb package.

    def __init__(self) -> None:
        super().__init__()

        import chromadb

        # Persist the index under the ragna local root so collections survive
        # across process restarts.
        settings = chromadb.config.Settings(
            is_persistent=True,
            persist_directory=str(ragna.local_root() / "chroma"),
            anonymized_telemetry=False,
        )
        self._client = chromadb.Client(settings)

    def store(
        self,
        documents: list[Document],
        *,
        chat_id: uuid.UUID,
        chunk_size: int = 500,
        chunk_overlap: int = 250,
    ) -> None:
        """Chunk the given documents and index them in a per-chat collection."""
        collection = self._client.create_collection(
            str(chat_id), embedding_function=self._embedding_function
        )

        # Chroma takes parallel lists: one entry per chunk in each of them.
        chunk_ids: list[str] = []
        chunk_texts: list[str] = []
        chunk_metadatas: list[dict] = []
        for document in documents:
            chunks = self._chunk_pages(
                document.extract_pages(),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
            for chunk in chunks:
                chunk_ids.append(str(uuid.uuid4()))
                chunk_texts.append(chunk.text)
                chunk_metadatas.append(
                    {
                        "document_id": str(document.id),
                        "page_numbers": self._page_numbers_to_str(chunk.page_numbers),
                        "num_tokens": chunk.num_tokens,
                    }
                )

        collection.add(
            ids=chunk_ids,
            documents=chunk_texts,
            metadatas=chunk_metadatas,  # type: ignore[arg-type]
        )

    def retrieve(
        self,
        documents: list[Document],
        prompt: str,
        *,
        chat_id: uuid.UUID,
        chunk_size: int = 500,
        num_tokens: int = 1024,
    ) -> list[Source]:
        """Query the chat's collection and return sources up to ``num_tokens``."""
        collection = self._client.get_collection(
            str(chat_id), embedding_function=self._embedding_function
        )

        # Chroma cannot cap retrieval by a token budget, so we estimate how
        # many chunks we need and overestimate by a factor of two to avoid
        # fetching too few sources and needing a second query.
        # ---
        # FIXME: querying only a low number of documents can lead to not finding
        # the most relevant one.
        # See https://github.com/chroma-core/chroma/issues/1205 for details.
        # Instead of just querying more documents here, we should use the
        # appropriate index parameters when creating the collection. However,
        # they are undocumented for now.
        estimated_num_chunks = max(int(num_tokens * 2 / chunk_size), 100)
        raw = collection.query(
            query_texts=prompt,
            n_results=min(estimated_num_chunks, collection.count()),
            include=["distances", "metadatas", "documents"],
        )

        num_results = len(raw["ids"][0])
        # Fields we did not ask for come back as None; pad them so every
        # column has one value per result.
        columns = {
            key: [None] * num_results if value is None else value[0]  # type: ignore[index]
            for key, value in raw.items()
        }
        # dict of columns -> list of per-result rows, with singular keys
        # ("ids" -> "id", "distances" -> "distance", ...).
        rows = []
        for idx in range(num_results):
            rows.append({key[:-1]: column[idx] for key, column in columns.items()})
        # Ascending distance should already be the default order, but we make
        # extra sure here.
        rows.sort(key=lambda row: row["distance"])

        # TODO: we should have some functionality here to remove results with a high
        # distance to keep only "valid" sources. However, there are two issues:
        # 1. A "high distance" is fairly subjective
        # 2. Whatever threshold we use is very much dependent on the encoding method
        # Thus, we likely need to have a callable parameter for this class

        documents_by_id = {str(document.id): document for document in documents}

        def as_source(row: dict) -> Source:
            metadata = row["metadata"]
            return Source(
                id=row["id"],
                document=documents_by_id[metadata["document_id"]],
                location=metadata["page_numbers"],
                content=row["document"],
                num_tokens=metadata["num_tokens"],
            )

        return self._take_sources_up_to_max_tokens(
            map(as_source, rows), max_tokens=num_tokens
        )