-
Notifications
You must be signed in to change notification settings - Fork 304
/
OutputParserTool.py
96 lines (82 loc) · 3.53 KB
/
OutputParserTool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import json
import logging
import re
from typing import List, Optional

from .ParserBase import ParserBase
from ..common.SourceDocument import SourceDocument
logger = logging.getLogger(__name__)


class OutputParserTool(ParserBase):
    """Turns an LLM answer plus its source documents into the two-message
    payload expected by the Azure "Bring Your Own Data" chat frontend:
    a "tool" message carrying JSON-stringified citations, followed by an
    "assistant" message carrying the renumbered answer text.
    """

    def __init__(self) -> None:
        self.name = "OutputParser"

    def _clean_up_answer(self, answer: str) -> str:
        """Collapse the double spaces the model sometimes emits."""
        return answer.replace("  ", " ")

    def _get_source_docs_from_answer(self, answer: str) -> List[int]:
        """Return the N's of all ``[docN]`` citations in *answer*, in order
        of first appearance, without duplicates.

        De-duplication matters: emitting one citation object per occurrence
        produced duplicate citations for an answer that cites the same
        document twice, and broke sequential renumbering.
        """
        results = re.findall(r"\[doc(\d+)\]", answer)
        # dict.fromkeys keeps first-seen order while dropping repeats.
        return [int(i) for i in dict.fromkeys(results)]

    def _replace_last(self, text: str, old: str, new: str) -> str:
        """Replace the last occurrence of *old* in *text* with *new*.

        Works by reversing the string, replacing the first occurrence of the
        reversed substring, and reversing back. Retained for backward
        compatibility; renumbering no longer relies on it.
        """
        return (text[::-1].replace(old[::-1], new[::-1], 1))[::-1]

    def _make_doc_references_sequential(self, answer: str, doc_ids: List[int]) -> str:
        """Renumber citations so they read ``[doc1]``, ``[doc2]``, ... in
        citation order.

        Uses a two-pass substitution through placeholder tokens so a
        renumbering step can never clobber a tag produced by an earlier step
        (e.g. mapping [doc2]->[doc1] and then [doc1]->[doc2]).
        """
        for position, original_id in enumerate(doc_ids):
            answer = answer.replace(
                f"[doc{original_id}]", f"[{{doc{position + 1}}}]"
            )
        # Second pass: strip the placeholder braces back off.
        return re.sub(r"\[\{(doc\d+)\}\]", r"[\1]", answer)

    def parse(
        self,
        question: str,
        answer: str,
        source_documents: Optional[List[SourceDocument]] = None,
        **kwargs: dict,
    ) -> List[dict]:
        """Build the frontend message list.

        :param question: the user's question, echoed back as the "intent".
        :param answer: raw LLM answer containing ``[docN]`` citation tags.
        :param source_documents: documents the N's index into (1-based).
        :return: ``[tool_message, assistant_message]`` where the tool
            message's content is a JSON string (the Azure BYOD frontend
            requires stringified content).
        """
        if source_documents is None:
            source_documents = []
        answer = self._clean_up_answer(answer)
        doc_ids = self._get_source_docs_from_answer(answer)
        answer = self._make_doc_references_sequential(answer, doc_ids)
        # Create the return message object.
        messages = [
            {
                "role": "tool",
                "content": {"citations": [], "intent": question},
                "end_turn": False,
            }
        ]
        for i in doc_ids:
            idx = i - 1
            # Guard both ends: [doc0] would otherwise wrap to the last
            # document via Python's negative indexing.
            if idx < 0 or idx >= len(source_documents):
                logger.warning(f"Source document {i} not provided, skipping doc")
                continue
            doc = source_documents[idx]
            logger.debug(f"doc{idx}: {doc}")
            # chunk_id may be None or carry no digits; fall back to chunk so
            # a malformed id cannot raise IndexError.
            chunk_digits = (
                re.findall(r"\d+", doc.chunk_id) if doc.chunk_id is not None else []
            )
            # The citation object needs filepath and chunk_id to render in
            # the UI as a file.
            messages[0]["content"]["citations"].append(
                {
                    "content": doc.get_markdown_url() + "\n\n\n" + doc.content,
                    "id": doc.id,
                    "chunk_id": chunk_digits[-1] if chunk_digits else doc.chunk,
                    "title": doc.title,
                    "filepath": doc.get_filename(include_path=True),
                    "url": doc.get_markdown_url(),
                    "metadata": {
                        "offset": doc.offset,
                        "source": doc.source,
                        "markdown_url": doc.get_markdown_url(),
                        "title": doc.title,
                        "original_url": doc.source,  # TODO: do we need this?
                        "chunk": doc.chunk,
                        "key": doc.id,
                        "filename": doc.get_filename(),
                    },
                }
            )
        # With no usable citations, strip the dangling [docN] tags so the
        # answer does not reference citations that were never attached.
        if messages[0]["content"]["citations"] == []:
            answer = re.sub(r"\[doc\d+\]", "", answer)
        messages.append({"role": "assistant", "content": answer, "end_turn": True})
        # Everything in content needs to be stringified to work with the
        # Azure BYOD frontend.
        messages[0]["content"] = json.dumps(messages[0]["content"])
        return messages