-
Notifications
You must be signed in to change notification settings - Fork 353
/
post_prompt_tool.py
66 lines (58 loc) · 2.15 KB
/
post_prompt_tool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.callbacks import get_openai_callback
from ..common.answer import Answer
from ..helpers.llm_helper import LLMHelper
from ..helpers.config.config_helper import ConfigHelper
class PostPromptTool:
    """Post-answering validation: asks the LLM to judge an answer and filters it if rejected."""

    def __init__(self) -> None:
        pass

    def validate_answer(self, answer: Answer) -> Answer:
        """Run the configured post-answering prompt over *answer*.

        The prompt receives the question, the answer text and a numbered
        rendering of the source documents, and is expected to reply
        "true"/"yes" (case-insensitive) when the answer is acceptable.

        Returns the original answer when it passes, otherwise a replacement
        Answer carrying the configured filter message and no sources. The
        token usage of the validation call is attached to the returned
        Answer in both cases.
        """
        config = ConfigHelper.get_active_config_or_default()
        llm_helper = LLMHelper()

        post_answering_prompt = PromptTemplate(
            template=config.prompts.post_answering_prompt,
            input_variables=["question", "answer", "sources"],
        )
        # NOTE(review): LLMChain and its dict-call invocation are deprecated in
        # recent LangChain releases — consider `prompt | llm` with .invoke()
        # when the pinned langchain version is upgraded.
        post_answering_chain = LLMChain(
            llm=llm_helper.get_llm(),
            prompt=post_answering_prompt,
            output_key="correct",
            verbose=True,
        )

        # Render sources as "[doc1]: ...", matching the citation style the
        # prompt template expects.
        sources = "\n".join(
            f"[doc{i+1}]: {source.content}"
            for i, source in enumerate(answer.source_documents)
        )

        # Track token usage of the validation call so it can be reported on
        # the returned Answer.
        with get_openai_callback() as cb:
            post_result = post_answering_chain(
                {
                    "question": answer.question,
                    "answer": answer.answer,
                    "sources": sources,
                }
            )

        # Anything other than "true"/"yes" means the answer was rejected.
        was_message_filtered = post_result["correct"].lower() not in ("true", "yes")

        # Return the filter message (sources stripped) or the original answer.
        if was_message_filtered:
            return Answer(
                question=answer.question,
                answer=config.messages.post_answering_filter,
                source_documents=[],
                prompt_tokens=cb.prompt_tokens,
                completion_tokens=cb.completion_tokens,
            )
        return Answer(
            question=answer.question,
            answer=answer.answer,
            source_documents=answer.source_documents,
            prompt_tokens=cb.prompt_tokens,
            completion_tokens=cb.completion_tokens,
        )