# document_retriever.py (forked from microsoft/TaskWeaver)
import json
import os
import pickle

from taskweaver.plugin import Plugin, register_plugin


@register_plugin
class DocumentRetriever(Plugin):
    vectorstore = None

    def _init(self):
        try:
            import tiktoken
            from langchain_community.embeddings import HuggingFaceEmbeddings
            from langchain_community.vectorstores import FAISS
        except ImportError:
            raise ImportError("Please install tiktoken and langchain-community first.")

        # Embed queries with the same sentence-transformers model that was used
        # to build the index.
        self.embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        # Note: newer langchain-community releases may additionally require
        # allow_dangerous_deserialization=True to load a pickled docstore.
        self.vectorstore = FAISS.load_local(
            folder_path=self.config.get("index_folder"),
            embeddings=self.embeddings,
        )

        # Mapping from "{source}_{chunk_id}" to the FAISS docstore position,
        # produced alongside the index; used below to look up neighbor chunks.
        with open(
            os.path.join(
                self.config.get("index_folder"),
                "chunk_id_to_index.pkl",
            ),
            "rb",
        ) as f:
            self.chunk_id_to_index = pickle.load(f)

        self.enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
    def __call__(self, query: str, size: int = 5, target_length: int = 256):
        # Lazy initialization: load the index and tokenizer on first use.
        if self.vectorstore is None:
            self._init()

        result = self.vectorstore.similarity_search(
            query=query,
            k=size,
        )

        expanded_chunks = self.do_expand(result, target_length)

        return f"DocumentRetriever has done searching for `{query}`.\n" + self.ctx.wrap_text_with_delimiter_temporal(
            "\n```json\n" + json.dumps(expanded_chunks, indent=4) + "\n```\n",
        )
    def do_expand(self, result, target_length):
        expanded_chunks = []
        # Expand each retrieved chunk with its neighboring chunks (alternating
        # left and right) until the token budget `target_length` is reached or
        # no neighbor remains on either side.
        for r in result:
            source = r.metadata["source"]
            chunk_id = r.metadata["chunk_id"]
            content = r.page_content

            expanded_result = content
            left_chunk_id, right_chunk_id = chunk_id - 1, chunk_id + 1
            left_valid, right_valid = True, True
            chunk_ids = [chunk_id]
            while True:
                current_length = len(self.enc.encode(expanded_result))
                if f"{source}_{left_chunk_id}" in self.chunk_id_to_index:
                    chunk_ids.append(left_chunk_id)
                    left_chunk_index = self.vectorstore.index_to_docstore_id[
                        self.chunk_id_to_index[f"{source}_{left_chunk_id}"]
                    ]
                    left_chunk = self.vectorstore.docstore.search(left_chunk_index)
                    encoded_left_chunk = self.enc.encode(left_chunk.page_content)
                    if len(encoded_left_chunk) + current_length < target_length:
                        expanded_result = left_chunk.page_content + expanded_result
                        left_chunk_id -= 1
                        current_length += len(encoded_left_chunk)
                    else:
                        # Budget exhausted: prepend only the tail of the left
                        # neighbor that still fits.
                        expanded_result = (
                            self.enc.decode(
                                encoded_left_chunk[-(target_length - current_length) :],
                            )
                            + expanded_result
                        )
                        current_length = target_length
                        break
                else:
                    left_valid = False

                if f"{source}_{right_chunk_id}" in self.chunk_id_to_index:
                    chunk_ids.append(right_chunk_id)
                    right_chunk_index = self.vectorstore.index_to_docstore_id[
                        self.chunk_id_to_index[f"{source}_{right_chunk_id}"]
                    ]
                    right_chunk = self.vectorstore.docstore.search(right_chunk_index)
                    encoded_right_chunk = self.enc.encode(right_chunk.page_content)
                    if len(encoded_right_chunk) + current_length < target_length:
                        expanded_result += right_chunk.page_content
                        right_chunk_id += 1
                        current_length += len(encoded_right_chunk)
                    else:
                        # Budget exhausted: append only the head of the right
                        # neighbor that still fits.
                        expanded_result += self.enc.decode(
                            encoded_right_chunk[: target_length - current_length],
                        )
                        current_length = target_length
                        break
                else:
                    right_valid = False

                if not left_valid and not right_valid:
                    break

            expanded_chunks.append(
                {
                    "chunk": expanded_result,
                    "metadata": r.metadata,
                    # "length": current_length,
                    # "chunk_ids": chunk_ids
                },
            )
        return expanded_chunks
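

if __name__ == "__main__":
    # A minimal sketch of exercising the plugin outside TaskWeaver, for
    # illustration only. It assumes a FAISS index built with the same
    # "all-MiniLM-L6-v2" embeddings plus a "chunk_id_to_index.pkl" mapping in
    # the index folder. The `_StubContext` class and the "./index" path are
    # hypothetical stand-ins for the context and config that the TaskWeaver
    # framework normally injects when it instantiates the plugin.
    class _StubContext:
        def wrap_text_with_delimiter_temporal(self, text: str) -> str:
            # TaskWeaver's real context wraps text with temporal delimiters;
            # this stub passes it through unchanged.
            return text

    retriever = DocumentRetriever.__new__(DocumentRetriever)  # skip framework wiring
    retriever.config = {"index_folder": "./index"}  # hypothetical index location
    retriever.ctx = _StubContext()
    print(retriever(query="example question", size=3, target_length=256))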