-
Notifications
You must be signed in to change notification settings - Fork 3.6k
/
chatreadretrieveread.py
370 lines (335 loc) · 16.8 KB
/
chatreadretrieveread.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
import json
import logging
import re
from typing import Any, AsyncGenerator, Optional, Union
import aiohttp
import openai
from azure.search.documents.aio import SearchClient
from azure.search.documents.models import QueryType
from approaches.approach import Approach
from core.messagebuilder import MessageBuilder
from core.modelhelper import get_token_limit
from text import nonewlines
class ChatReadRetrieveReadApproach(Approach):
# Chat roles
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
NO_RESPONSE = "0"
"""
A multi-step approach that first uses OpenAI to turn the user's question into a search query,
then uses Azure AI Search to retrieve relevant documents, and then sends the conversation history,
original user question, and search results to OpenAI to generate a response.
"""
system_message_chat_conversation = """Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers.
Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question.
For tabular information return it as an html table. Do not return markdown format. If the question is not in English, answer in the language used in the question.
Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. Use square brackets to reference the source, for example [info1.txt]. Don't combine sources, list each source separately, for example [info1.txt][info2.pdf].
{follow_up_questions_prompt}
{injected_prompt}
"""
follow_up_questions_prompt_content = """Generate 3 very brief follow-up questions that the user would likely ask next.
Enclose the follow-up questions in double angle brackets. Example:
<<Are there exclusions for prescriptions?>>
<<Which pharmacies can be ordered from?>>
<<What is the limit for over-the-counter medication?>>
Do no repeat questions that have already been asked.
Make sure the last question ends with ">>"."""
query_prompt_template = """Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching in a knowledge base about employee healthcare plans and the employee handbook.
You have access to an Azure AI Search index with 100's of documents.
Generate a search query based on the conversation and the new question.
Do not include cited source filenames and document names e.g info.txt or doc.pdf in the search query terms.
Do not include any text inside [] or <<>> in the search query terms.
Do not include any special characters like '+'.
If the question is not in English, translate the question to English before generating the search query.
If you cannot generate a search query, return just the number 0.
"""
query_prompt_few_shots = [
{"role": USER, "content": "What are my health plans?"},
{"role": ASSISTANT, "content": "Show available health plans"},
{"role": USER, "content": "does my plan cover cardio?"},
{"role": ASSISTANT, "content": "Health plan cardio coverage"},
]
def __init__(
self,
search_client: SearchClient,
openai_host: str,
chatgpt_deployment: Optional[str], # Not needed for non-Azure OpenAI
chatgpt_model: str,
embedding_deployment: Optional[str], # Not needed for non-Azure OpenAI or for retrieval_mode="text"
embedding_model: str,
sourcepage_field: str,
content_field: str,
query_language: str,
query_speller: str,
):
self.search_client = search_client
self.openai_host = openai_host
self.chatgpt_deployment = chatgpt_deployment
self.chatgpt_model = chatgpt_model
self.embedding_deployment = embedding_deployment
self.embedding_model = embedding_model
self.sourcepage_field = sourcepage_field
self.content_field = content_field
self.query_language = query_language
self.query_speller = query_speller
self.chatgpt_token_limit = get_token_limit(chatgpt_model)
async def run_until_final_call(
self,
history: list[dict[str, str]],
overrides: dict[str, Any],
auth_claims: dict[str, Any],
should_stream: bool = False,
) -> tuple:
has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
top = overrides.get("top", 3)
filter = self.build_filter(overrides, auth_claims)
original_user_query = history[-1]["content"]
user_query_request = "Generate search query for: " + original_user_query
functions = [
{
"name": "search_sources",
"description": "Retrieve sources from the Azure AI Search index",
"parameters": {
"type": "object",
"properties": {
"search_query": {
"type": "string",
"description": "Query string to retrieve documents from azure search eg: 'Health care plan'",
}
},
"required": ["search_query"],
},
}
]
# STEP 1: Generate an optimized keyword search query based on the chat history and the last question
messages = self.get_messages_from_history(
system_prompt=self.query_prompt_template,
model_id=self.chatgpt_model,
history=history,
user_content=user_query_request,
max_tokens=self.chatgpt_token_limit - len(user_query_request),
few_shots=self.query_prompt_few_shots,
)
chatgpt_args = {"deployment_id": self.chatgpt_deployment} if self.openai_host == "azure" else {}
chat_completion = await openai.ChatCompletion.acreate(
**chatgpt_args,
model=self.chatgpt_model,
messages=messages,
temperature=0.0,
max_tokens=100, # Setting too low risks malformed JSON, setting too high may affect performance
n=1,
functions=functions,
function_call="auto",
)
query_text = self.get_search_query(chat_completion, original_user_query)
# STEP 2: Retrieve relevant documents from the search index with the GPT optimized query
# If retrieval mode includes vectors, compute an embedding for the query
if has_vector:
embedding_args = {"deployment_id": self.embedding_deployment} if self.openai_host == "azure" else {}
embedding = await openai.Embedding.acreate(**embedding_args, model=self.embedding_model, input=query_text)
query_vector = embedding["data"][0]["embedding"]
else:
query_vector = None
# Only keep the text query if the retrieval mode uses text, otherwise drop it
if not has_text:
query_text = None
# Use semantic L2 reranker if requested and if retrieval mode is text or hybrid (vectors + text)
if overrides.get("semantic_ranker") and has_text:
r = await self.search_client.search(
query_text,
filter=filter,
query_type=QueryType.SEMANTIC,
query_language=self.query_language,
query_speller=self.query_speller,
semantic_configuration_name="default",
top=top,
query_caption="extractive|highlight-false" if use_semantic_captions else None,
vector=query_vector,
top_k=50 if query_vector else None,
vector_fields="embedding" if query_vector else None,
)
else:
r = await self.search_client.search(
query_text,
filter=filter,
top=top,
vector=query_vector,
top_k=50 if query_vector else None,
vector_fields="embedding" if query_vector else None,
)
if use_semantic_captions:
results = [
doc[self.sourcepage_field] + ": " + nonewlines(" . ".join([c.text for c in doc["@search.captions"]]))
async for doc in r
]
else:
results = [doc[self.sourcepage_field] + ": " + nonewlines(doc[self.content_field]) async for doc in r]
content = "\n".join(results)
follow_up_questions_prompt = (
self.follow_up_questions_prompt_content if overrides.get("suggest_followup_questions") else ""
)
# STEP 3: Generate a contextual and content specific answer using the search results and chat history
# Allow client to replace the entire prompt, or to inject into the exiting prompt using >>>
prompt_override = overrides.get("prompt_template")
if prompt_override is None:
system_message = self.system_message_chat_conversation.format(
injected_prompt="", follow_up_questions_prompt=follow_up_questions_prompt
)
elif prompt_override.startswith(">>>"):
system_message = self.system_message_chat_conversation.format(
injected_prompt=prompt_override[3:] + "\n", follow_up_questions_prompt=follow_up_questions_prompt
)
else:
system_message = prompt_override.format(follow_up_questions_prompt=follow_up_questions_prompt)
response_token_limit = 1024
messages_token_limit = self.chatgpt_token_limit - response_token_limit
messages = self.get_messages_from_history(
system_prompt=system_message,
model_id=self.chatgpt_model,
history=history,
# Model does not handle lengthy system messages well. Moving sources to latest user conversation to solve follow up questions prompt.
user_content=original_user_query + "\n\nSources:\n" + content,
max_tokens=messages_token_limit,
)
msg_to_display = "\n\n".join([str(message) for message in messages])
extra_info = {
"data_points": results,
"thoughts": f"Searched for:<br>{query_text}<br><br>Conversations:<br>"
+ msg_to_display.replace("\n", "<br>"),
}
chat_coroutine = openai.ChatCompletion.acreate(
**chatgpt_args,
model=self.chatgpt_model,
messages=messages,
temperature=overrides.get("temperature") or 0.7,
max_tokens=response_token_limit,
n=1,
stream=should_stream,
)
return (extra_info, chat_coroutine)
async def run_without_streaming(
self,
history: list[dict[str, str]],
overrides: dict[str, Any],
auth_claims: dict[str, Any],
session_state: Any = None,
) -> dict[str, Any]:
extra_info, chat_coroutine = await self.run_until_final_call(
history, overrides, auth_claims, should_stream=False
)
chat_resp = dict(await chat_coroutine)
chat_resp["choices"][0]["context"] = extra_info
if overrides.get("suggest_followup_questions"):
content, followup_questions = self.extract_followup_questions(chat_resp["choices"][0]["message"]["content"])
chat_resp["choices"][0]["message"]["content"] = content
chat_resp["choices"][0]["context"]["followup_questions"] = followup_questions
chat_resp["choices"][0]["session_state"] = session_state
return chat_resp
async def run_with_streaming(
self,
history: list[dict[str, str]],
overrides: dict[str, Any],
auth_claims: dict[str, Any],
session_state: Any = None,
) -> AsyncGenerator[dict, None]:
extra_info, chat_coroutine = await self.run_until_final_call(
history, overrides, auth_claims, should_stream=True
)
yield {
"choices": [
{
"delta": {"role": self.ASSISTANT},
"context": extra_info,
"session_state": session_state,
"finish_reason": None,
"index": 0,
}
],
"object": "chat.completion.chunk",
}
followup_questions_started = False
followup_content = ""
async for event in await chat_coroutine:
# "2023-07-01-preview" API version has a bug where first response has empty choices
if event["choices"]:
# if event contains << and not >>, it is start of follow-up question, truncate
content = event["choices"][0]["delta"].get("content", "")
if overrides.get("suggest_followup_questions") and "<<" in content:
followup_questions_started = True
earlier_content = content[: content.index("<<")]
if earlier_content:
event["choices"][0]["delta"]["content"] = earlier_content
yield event
followup_content += content[content.index("<<") :]
elif followup_questions_started:
followup_content += content
else:
yield event
if followup_content:
_, followup_questions = self.extract_followup_questions(followup_content)
yield {
"choices": [
{
"delta": {"role": self.ASSISTANT},
"context": {"followup_questions": followup_questions},
"finish_reason": None,
"index": 0,
}
],
"object": "chat.completion.chunk",
}
async def run(
self, messages: list[dict], stream: bool = False, session_state: Any = None, context: dict[str, Any] = {}
) -> Union[dict[str, Any], AsyncGenerator[dict[str, Any], None]]:
overrides = context.get("overrides", {})
auth_claims = context.get("auth_claims", {})
if stream is False:
# Workaround for: https://github.com/openai/openai-python/issues/371
async with aiohttp.ClientSession() as s:
openai.aiosession.set(s)
response = await self.run_without_streaming(messages, overrides, auth_claims, session_state)
return response
else:
return self.run_with_streaming(messages, overrides, auth_claims, session_state)
def get_messages_from_history(
self,
system_prompt: str,
model_id: str,
history: list[dict[str, str]],
user_content: str,
max_tokens: int,
few_shots=[],
) -> list:
message_builder = MessageBuilder(system_prompt, model_id)
# Add examples to show the chat what responses we want. It will try to mimic any responses and make sure they match the rules laid out in the system message.
for shot in reversed(few_shots):
message_builder.insert_message(shot.get("role"), shot.get("content"))
append_index = len(few_shots) + 1
message_builder.insert_message(self.USER, user_content, index=append_index)
total_token_count = message_builder.count_tokens_for_message(message_builder.messages[-1])
newest_to_oldest = list(reversed(history[:-1]))
for message in newest_to_oldest:
potential_message_count = message_builder.count_tokens_for_message(message)
if (total_token_count + potential_message_count) > max_tokens:
logging.debug("Reached max tokens of %d, history will be truncated", max_tokens)
break
message_builder.insert_message(message["role"], message["content"], index=append_index)
total_token_count += potential_message_count
return message_builder.messages
def get_search_query(self, chat_completion: dict[str, Any], user_query: str):
response_message = chat_completion["choices"][0]["message"]
if function_call := response_message.get("function_call"):
if function_call["name"] == "search_sources":
arg = json.loads(function_call["arguments"])
search_query = arg.get("search_query", self.NO_RESPONSE)
if search_query != self.NO_RESPONSE:
return search_query
elif query_text := response_message.get("content"):
if query_text.strip() != self.NO_RESPONSE:
return query_text
return user_query
def extract_followup_questions(self, content: str):
return content.split("<<")[0], re.findall(r"<<([^>>]+)>>", content)