/
ReadRetrieveReadChatService.cs
247 lines (215 loc) · 10.6 KB
/
ReadRetrieveReadChatService.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
// Copyright (c) Microsoft. All rights reserved.
using Azure.Core;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Embeddings;
namespace MinimalApi.Services;
#pragma warning disable SKEXP0011 // Semantic Kernel experimental-API diagnostic
#pragma warning disable SKEXP0001 // Semantic Kernel experimental-API diagnostic
/// <summary>
/// Answers chat questions with a "read-retrieve-read" flow: the chat model rewrites the latest user
/// question into a search query, <see cref="ISearchService"/> retrieves matching documents (and images
/// when a vision service is configured), and the chat model then answers from those sources as JSON.
/// </summary>
public class ReadRetrieveReadChatService
{
    private const string SearchQuerySystemPrompt = @"You are a helpful AI assistant, generate search query for followup question.
Make your respond simple and precise. Return the query only, do not return any other text.
e.g.
Northwind Health Plus AND standard plan.
standard plan AND dental AND employee benefit.
";

    private const string AnswerSystemPrompt =
        "You are a system assistant who helps the company employees with their questions. Be brief in your answers";

    private readonly ISearchService _searchClient;
    private readonly Kernel _kernel;
    private readonly IConfiguration _configuration;
    private readonly IComputerVisionService? _visionService;
    private readonly TokenCredential? _tokenCredential;

    /// <summary>
    /// Builds the Semantic Kernel used for chat completion and (optionally) embedding generation.
    /// </summary>
    /// <param name="searchClient">Search backend used to retrieve documents and images.</param>
    /// <param name="client">OpenAI client; used only when <c>UseAOAI</c> is <c>"false"</c>.</param>
    /// <param name="configuration">
    /// Reads <c>UseAOAI</c> and either <c>OpenAiChatGptDeployment</c>/<c>OpenAiEmbeddingDeployment</c>
    /// or <c>AzureOpenAiChatGptDeployment</c>/<c>AzureOpenAiServiceEndpoint</c>/<c>AzureOpenAiEmbeddingDeployment</c>.
    /// </param>
    /// <param name="visionService">Optional vision service; when set, images are retrieved too.</param>
    /// <param name="tokenCredential">
    /// Credential for Azure OpenAI and storage; Azure OpenAI falls back to <c>DefaultAzureCredential</c> when null.
    /// </param>
    /// <exception cref="ArgumentException">A required deployment name or endpoint is missing or blank.</exception>
    public ReadRetrieveReadChatService(
        ISearchService searchClient,
        OpenAIClient client,
        IConfiguration configuration,
        IComputerVisionService? visionService = null,
        TokenCredential? tokenCredential = null)
    {
        _searchClient = searchClient;
        var kernelBuilder = Kernel.CreateBuilder();

        // Case-insensitive so that "False" (e.g. from bool.ToString()) also selects the OpenAI path.
        if (string.Equals(configuration["UseAOAI"], "false", StringComparison.OrdinalIgnoreCase))
        {
            var deployment = configuration["OpenAiChatGptDeployment"];
            ArgumentException.ThrowIfNullOrWhiteSpace(deployment);
            kernelBuilder = kernelBuilder.AddOpenAIChatCompletion(deployment, client);

            var embeddingModelName = configuration["OpenAiEmbeddingDeployment"];
            ArgumentException.ThrowIfNullOrWhiteSpace(embeddingModelName);
            kernelBuilder = kernelBuilder.AddOpenAITextEmbeddingGeneration(embeddingModelName, client);
        }
        else
        {
            var deployedModelName = configuration["AzureOpenAiChatGptDeployment"];
            ArgumentException.ThrowIfNullOrWhiteSpace(deployedModelName);
            var endpoint = configuration["AzureOpenAiServiceEndpoint"];
            ArgumentException.ThrowIfNullOrWhiteSpace(endpoint);
            TokenCredential credential = tokenCredential ?? new DefaultAzureCredential();

            // Chat completion is always required by ReplyAsync, so register it unconditionally;
            // previously it was only registered when an embedding deployment was also configured.
            kernelBuilder = kernelBuilder.AddAzureOpenAIChatCompletion(deployedModelName, endpoint, credential);

            // Embeddings are optional on Azure OpenAI; without them retrieval falls back to text queries.
            var embeddingModelName = configuration["AzureOpenAiEmbeddingDeployment"];
            if (!string.IsNullOrEmpty(embeddingModelName))
            {
                kernelBuilder = kernelBuilder.AddAzureOpenAITextEmbeddingGeneration(embeddingModelName, endpoint, credential);
            }
        }

        _kernel = kernelBuilder.Build();
        _configuration = configuration;
        _visionService = visionService;
        _tokenCredential = tokenCredential;
    }

    /// <summary>
    /// Produces an assistant reply to the last user message in <paramref name="history"/>,
    /// grounded in documents (and images) retrieved from the search service.
    /// </summary>
    /// <param name="history">Conversation so far; must contain at least one user message.</param>
    /// <param name="overrides">Optional per-request settings (retrieval mode, temperature, follow-ups, ...).</param>
    /// <param name="cancellationToken">Cancels model, search and token calls.</param>
    /// <returns>A single-choice response with the answer, supporting content and thoughts.</returns>
    /// <exception cref="InvalidOperationException">
    /// No user message exists, the model returned no content, or a storage token could not be obtained.
    /// </exception>
    public async Task<ChatAppResponse> ReplyAsync(
        ChatMessage[] history,
        RequestOverrides? overrides,
        CancellationToken cancellationToken = default)
    {
        var chat = _kernel.GetRequiredService<IChatCompletionService>();

        // The embedding service may legitimately be absent (Azure OpenAI without an embedding
        // deployment), so resolve it without throwing; the null check below then skips vectors.
        var embedding = _kernel.Services.GetService(typeof(ITextEmbeddingGenerationService)) as ITextEmbeddingGenerationService;

        var question = history.LastOrDefault(m => m.IsUser)?.Content is { } userQuestion
            ? userQuestion
            : throw new InvalidOperationException("User question is null");

        float[]? embeddings = null;
        if (overrides?.RetrievalMode != RetrievalMode.Text && embedding is not null)
        {
            embeddings = (await embedding.GenerateEmbeddingAsync(question, cancellationToken: cancellationToken)).ToArray();
        }

        // Step 1: ask the model for a keyword search query unless retrieval is vector-only.
        string? query = null;
        if (overrides?.RetrievalMode != RetrievalMode.Vector)
        {
            query = await GenerateSearchQueryAsync(chat, question, cancellationToken);
        }

        // Step 2: retrieve documents matching the query and/or embedding.
        var documentContentList = await _searchClient.QueryDocumentsAsync(query, embeddings, overrides, cancellationToken);
        var documentContents = documentContentList.Length == 0
            ? "no source available."
            : string.Join("\r", documentContentList.Select(x => $"{x.Title}:{x.Content}"));

        // Step 2.5: retrieve images when a vision service is configured.
        SupportingImageRecord[]? images = null;
        if (_visionService is not null)
        {
            var queryEmbeddings = await _visionService.VectorizeTextAsync(query ?? question, cancellationToken);
            images = await _searchClient.QueryImagesAsync(query, queryEmbeddings.vector, overrides, cancellationToken);
        }

        // Step 3: combine the conversation history and retrieved sources to generate the answer.
        var answerChat = new ChatHistory(AnswerSystemPrompt);
        foreach (var message in history)
        {
            if (message.IsUser)
            {
                answerChat.AddUserMessage(message.Content);
            }
            else
            {
                answerChat.AddAssistantMessage(message.Content);
            }
        }

        if (images is not null)
        {
            var prompt = @$"## Source ##
{documentContents}
## End ##
Answer question based on available source and images.
Your answer needs to be a json object with answer and thoughts field.
Don't put your answer between ```json and ```, return the json string directly. e.g {{""answer"": ""I don't know"", ""thoughts"": ""I don't know""}}";

            // NOTE(review): an Entra ID access token for storage.azure.com is appended to each image URL
            // as if it were a SAS query string. Azure Storage normally expects such tokens in an
            // Authorization header — confirm the model can actually fetch these URLs.
            var tokenRequestContext = new TokenRequestContext(new[] { "https://storage.azure.com/.default" });
            var sasToken = await (_tokenCredential?.GetTokenAsync(tokenRequestContext, cancellationToken)
                ?? throw new InvalidOperationException("Failed to get token"));
            var sasTokenString = sasToken.Token;

            var collection = new ChatMessageContentItemCollection();
            collection.Add(new TextContent(prompt));
            foreach (var imageUrl in images.Select(x => $"{x.Url}?{sasTokenString}"))
            {
                collection.Add(new ImageContent(new Uri(imageUrl)));
            }

            answerChat.AddUserMessage(collection);
        }
        else
        {
            var prompt = @$" ## Source ##
{documentContents}
## End ##
You answer needs to be a json object with the following format.
{{
""answer"": // the answer to the question, add a source reference to the end of each sentence. e.g. Apple is a fruit [reference1.pdf][reference2.pdf]. If no source available, put the answer as I don't know.
""thoughts"": // brief thoughts on how you came up with the answer, e.g. what sources you used, what you thought about, etc.
}}";
            answerChat.AddUserMessage(prompt);
        }

        var promptExecutingSetting = new OpenAIPromptExecutionSettings
        {
            MaxTokens = 1024,
            Temperature = overrides?.Temperature ?? 0.7,
            StopSequences = [],
        };

        var answer = await chat.GetChatMessageContentAsync(
            answerChat,
            promptExecutingSetting,
            cancellationToken: cancellationToken);
        var answerJson = answer.Content ?? throw new InvalidOperationException("Failed to get answer");
        var answerObject = ParseModelJson(answerJson);
        var ans = answerObject.GetProperty("answer").GetString() ?? throw new InvalidOperationException("Failed to get answer");
        var thoughts = answerObject.GetProperty("thoughts").GetString() ?? throw new InvalidOperationException("Failed to get thoughts");

        // Step 4: optionally append follow-up questions, both inline (<<...>>) and as a list.
        string[] followUpQuestionList = Array.Empty<string>();
        if (overrides?.SuggestFollowupQuestions is true)
        {
            followUpQuestionList = await GenerateFollowUpQuestionsAsync(chat, ans, promptExecutingSetting, cancellationToken);
            foreach (var followUpQuestion in followUpQuestionList)
            {
                ans += $" <<{followUpQuestion}>> ";
            }
        }

        var responseMessage = new ResponseMessage("assistant", ans);
        var responseContext = new ResponseContext(
            DataPointsContent: documentContentList.Select(x => new SupportingContentRecord(x.Title, x.Content)).ToArray(),
            DataPointsImages: images?.Select(x => new SupportingImageRecord(x.Title, x.Url)).ToArray(),
            FollowupQuestions: followUpQuestionList,
            Thoughts: new[] { new Thoughts("Thoughts", thoughts) });

        var choice = new ResponseChoice(
            Index: 0,
            Message: responseMessage,
            Context: responseContext,
            CitationBaseUrl: _configuration.ToCitationBaseUrl());

        return new ChatAppResponse(new[] { choice });
    }

    /// <summary>Asks the chat model to rewrite <paramref name="question"/> as a keyword search query.</summary>
    /// <exception cref="InvalidOperationException">The model returned no content.</exception>
    private static async Task<string> GenerateSearchQueryAsync(
        IChatCompletionService chat,
        string question,
        CancellationToken cancellationToken)
    {
        var getQueryChat = new ChatHistory(SearchQuerySystemPrompt);
        getQueryChat.AddUserMessage(question);
        var result = await chat.GetChatMessageContentAsync(
            getQueryChat,
            cancellationToken: cancellationToken);
        return result.Content ?? throw new InvalidOperationException("Failed to get search query");
    }

    /// <summary>Asks the chat model for follow-up questions to <paramref name="answer"/>, parsed from a JSON string array.</summary>
    /// <exception cref="InvalidOperationException">The model returned no content or a non-array payload.</exception>
    private static async Task<string[]> GenerateFollowUpQuestionsAsync(
        IChatCompletionService chat,
        string answer,
        OpenAIPromptExecutionSettings settings,
        CancellationToken cancellationToken)
    {
        var followUpQuestionChat = new ChatHistory(@"You are a helpful AI assistant");
        followUpQuestionChat.AddUserMessage($@"Generate three follow-up question based on the answer you just generated.
# Answer
{answer}
# Format of the response
Return the follow-up question as a json string list. Don't put your answer between ```json and ```, return the json string directly.
e.g.
[
""What is the deductible?"",
""What is the co-pay?"",
""What is the out-of-pocket maximum?""
]");
        var followUpQuestions = await chat.GetChatMessageContentAsync(
            followUpQuestionChat,
            settings,
            cancellationToken: cancellationToken);
        var followUpQuestionsJson = followUpQuestions.Content ?? throw new InvalidOperationException("Failed to get follow-up questions");

        // Drop JSON nulls so they don't become empty "<<>>" markers in the answer text.
        return ParseModelJson(followUpQuestionsJson)
            .EnumerateArray()
            .Select(x => x.GetString())
            .OfType<string>()
            .ToArray();
    }

    /// <summary>
    /// Parses model output as JSON, first removing a surrounding Markdown code fence
    /// (e.g. <c>```json ... ```</c>), which models sometimes emit despite instructions not to.
    /// </summary>
    private static JsonElement ParseModelJson(string content)
    {
        var text = content.Trim();

        // Length guard: a fence needs at least 6 chars ("```" + "```") for the slices below to be valid.
        if (text.Length >= 6
            && text.StartsWith("```", StringComparison.Ordinal)
            && text.EndsWith("```", StringComparison.Ordinal))
        {
            // Drop the opening fence line (it may carry a language tag such as "json") and the closing fence.
            var firstNewLine = text.IndexOf('\n');
            text = firstNewLine >= 0 && firstNewLine < text.Length - 3
                ? text[(firstNewLine + 1)..^3]
                : text[3..^3];
            text = text.Trim();
        }

        return JsonSerializer.Deserialize<JsonElement>(text);
    }
}