-
Notifications
You must be signed in to change notification settings - Fork 56
/
ChatService.cs
321 lines (226 loc) · 11.8 KB
/
ChatService.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
using Humanizer.Localisation.TimeToClockNotation;
using Search.Constants;
using SharedLib.Models;
using SharedLib.Services;
using SharpToken;
using Microsoft.ML.Tokenizers;
namespace Search.Services;
public class ChatService
{
    /// <summary>
    /// All session data is cached in the _sessions list.
    /// NOTE(review): this field is static but is read and mutated through instance
    /// methods with no synchronization — every ChatService instance shares (and can
    /// race on) the same cache. Confirm this service is registered as a singleton.
    /// </summary>
    private static List<Session> _sessions = new();

    private readonly OpenAiService _openAiService;
    private readonly MongoDbService _mongoDbService;

    // Token budgets supplied by the OpenAI service configuration.
    private readonly int _maxConversationTokens;
    private readonly int _maxCompletionTokens;

    private readonly ILogger _logger;

    public ChatService(OpenAiService openAiService, MongoDbService mongoDbService, ILogger logger)
    {
        _openAiService = openAiService;
        _mongoDbService = mongoDbService;
        _maxConversationTokens = openAiService.MaxConversationTokens;
        _maxCompletionTokens = openAiService.MaxCompletionTokens;
        _logger = logger;
    }

    /// <summary>
    /// Locates the cached session with the given id.
    /// Centralizes the FindIndex lookup that was previously repeated (and unchecked)
    /// in six call sites; an unknown id used to surface as an opaque
    /// ArgumentOutOfRangeException from the list indexer.
    /// </summary>
    /// <param name="sessionId">Id of the session to locate.</param>
    /// <returns>Index of the session in the _sessions cache.</returns>
    /// <exception cref="ArgumentException">No session with <paramref name="sessionId"/> is cached.</exception>
    private static int GetSessionIndex(string sessionId)
    {
        int index = _sessions.FindIndex(s => s.SessionId == sessionId);
        if (index < 0)
        {
            throw new ArgumentException($"No chat session found for id '{sessionId}'.", nameof(sessionId));
        }
        return index;
    }

    /// <summary>
    /// Returns list of chat session ids and names for left-hand nav to bind to (display Name and ChatSessionId as hidden)
    /// </summary>
    public async Task<List<Session>> GetAllChatSessionsAsync()
    {
        return _sessions = await _mongoDbService.GetSessionsAsync();
    }

    /// <summary>
    /// Returns the chat messages to display on the main web page when the user selects a chat from the left-hand nav
    /// </summary>
    /// <param name="sessionId">Id of the selected session.</param>
    /// <returns>The session's messages; an empty list when the session is not cached.</returns>
    public async Task<List<Message>> GetChatSessionMessagesAsync(string? sessionId)
    {
        ArgumentNullException.ThrowIfNull(sessionId);

        // Robustness: previously an unknown id (with a non-empty cache) crashed on
        // the -1 indexer; treat "not cached" the same as "empty cache".
        int index = _sessions.FindIndex(s => s.SessionId == sessionId);
        if (index < 0)
        {
            return new List<Message>();
        }

        if (_sessions[index].Messages.Count == 0)
        {
            // Messages are not cached, go read from database, then cache the result.
            _sessions[index].Messages = await _mongoDbService.GetSessionMessagesAsync(sessionId);
        }

        return _sessions[index].Messages;
    }

    /// <summary>
    /// User creates a new Chat Session.
    /// </summary>
    public async Task CreateNewChatSessionAsync()
    {
        Session session = new();
        _sessions.Add(session);
        await _mongoDbService.InsertSessionAsync(session);
    }

    /// <summary>
    /// Rename the Chat Session from "New Chat" to the summary provided by OpenAI
    /// </summary>
    /// <exception cref="ArgumentException">No cached session matches <paramref name="sessionId"/>.</exception>
    public async Task RenameChatSessionAsync(string? sessionId, string newChatSessionName)
    {
        ArgumentNullException.ThrowIfNull(sessionId);

        int index = GetSessionIndex(sessionId);
        _sessions[index].Name = newChatSessionName;
        await _mongoDbService.UpdateSessionAsync(_sessions[index]);
    }

    /// <summary>
    /// User deletes a chat session
    /// </summary>
    /// <exception cref="ArgumentException">No cached session matches <paramref name="sessionId"/>.</exception>
    public async Task DeleteChatSessionAsync(string? sessionId)
    {
        ArgumentNullException.ThrowIfNull(sessionId);

        int index = GetSessionIndex(sessionId);
        _sessions.RemoveAt(index);
        await _mongoDbService.DeleteSessionAndMessagesAsync(sessionId);
    }

    /// <summary>
    /// Receive a prompt from a user, Vectorize it from _openAIService Get a completion from _openAiService
    /// </summary>
    /// <param name="sessionId">Session the prompt belongs to.</param>
    /// <param name="userPrompt">The user's prompt text.</param>
    /// <param name="collectionName">Mongo collection to run the vector search against.</param>
    /// <returns>The completion text generated by the model.</returns>
    public async Task<string> GetChatCompletionAsync(string? sessionId, string userPrompt, string collectionName)
    {
        try
        {
            ArgumentNullException.ThrowIfNull(sessionId);

            //Get embeddings for user prompt and number of tokens it uses.
            (float[] promptVectors, int promptTokens) = await _openAiService.GetEmbeddingsAsync(sessionId, userPrompt);

            //Create the prompt message object. Created here to give it a timestamp that precedes the completion message.
            Message promptMessage = new Message(sessionId, nameof(Participants.User), promptTokens, default, userPrompt);

            //Do vector search on the user prompt, return list of documents
            string retrievedDocuments = await _mongoDbService.VectorSearchAsync(collectionName, promptVectors);

            //Get the most recent conversation history up to _maxConversationTokens
            string conversation = GetConversationHistory(sessionId);

            //Construct our prompts sent to Azure OpenAI. Calculate token usage and trim the RAG payload and conversation history to prevent exceeding token limits.
            (string augmentedContent, string conversationAndUserPrompt) = BuildPrompts(userPrompt, conversation, retrievedDocuments);

            //Generate the completion from Azure OpenAI to return to the user
            (string completionText, int ragTokens, int completionTokens) = await _openAiService.GetChatCompletionAsync(sessionId, conversationAndUserPrompt, augmentedContent);

            //Create the completion message object
            Message completionMessage = new Message(sessionId, nameof(Participants.Assistant), completionTokens, ragTokens, completionText);

            //Add the user prompt and completion to cache, then persist to Cosmos in a transaction
            await AddPromptCompletionMessagesAsync(sessionId, promptMessage, completionMessage);

            return completionText;
        }
        catch (Exception ex)
        {
            // Pass the exception object so the stack trace is preserved in the log
            // (the previous code logged only ex.Message as a formatted string).
            _logger.LogError(ex, "ChatService.GetChatCompletionAsync() failed.");
            throw;
        }
    }

    /// <summary>
    /// Estimate the token usage for OpenAI completion to prevent exceeding the OpenAI model's maximum token limit. This function estimates the
    /// amount of tokens the vector search result data and the user prompt will consume. If the search result data exceeds the configured amount
    /// the function reduces the number of vectors, reducing the amount of data sent.
    /// </summary>
    /// <param name="userPrompt">User prompt — never trimmed; always appended in full.</param>
    /// <param name="conversation">Conversation history; trimmed from the oldest end when over budget.</param>
    /// <param name="retrievedData">Vector-search (RAG) payload; trimmed from its tail when over budget.</param>
    private (string augmentedContent, string conversationAndUserPrompt) BuildPrompts(string userPrompt, string conversation, string retrievedData)
    {
        string updatedAugmentedContent;
        string updatedConversationAndUserPrompt;

        //SharpToken only estimates token usage and often undercounts. Add a buffer of 200 tokens.
        int bufferTokens = 200;

        //Create a new instance of SharpToken
        var encoding = GptEncoding.GetEncoding("cl100k_base"); //encoding base for GPT 3.5 Turbo and GPT 4

        List<int> ragVectors = encoding.Encode(retrievedData);
        int ragTokens = ragVectors.Count;

        List<int> convVectors = encoding.Encode(conversation);
        int convTokens = convVectors.Count;

        int userPromptTokens = encoding.Encode(userPrompt).Count;

        //If RAG data plus user prompt, plus conversation, plus tokens for completion is greater than max completion tokens we've defined, reduce the rag data and conversation by relative amount.
        int totalTokens = ragTokens + convTokens + userPromptTokens + bufferTokens;

        //Too much data, reduce the rag data and conversation data. Do not reduce the user prompt as this is required for the completion.
        if (totalTokens > _maxCompletionTokens)
        {
            //Get the number of tokens to reduce by
            int tokensToReduce = totalTokens - _maxCompletionTokens;

            //Bug fix: distribute the reduction over only the trimmable data (RAG + conversation).
            //The previous denominator was totalTokens, which included the untrimmable user
            //prompt and buffer, so the actual reduction fell short of tokensToReduce and the
            //result could still exceed _maxCompletionTokens.
            int reducibleTokens = ragTokens + convTokens;
            float ragTokenPct = (float)ragTokens / reducibleTokens;
            float convTokenPct = (float)convTokens / reducibleTokens;

            //Calculate the new number of tokens for each data set. Clamp at zero in case the
            //user prompt plus buffer alone exceed the budget.
            int newRagTokens = Math.Max(0, (int)Math.Round(ragTokens - (ragTokenPct * tokensToReduce), 0));
            int newConvTokens = Math.Max(0, (int)Math.Round(convTokens - (convTokenPct * tokensToReduce), 0));

            //Keep the front of the RAG payload and convert the vectors back to text
            List<int> trimmedRagVectors = ragVectors.GetRange(0, newRagTokens);
            updatedAugmentedContent = encoding.Decode(trimmedRagVectors);

            //Keep the tail (most recent part) of the conversation and convert back to text
            int offset = convVectors.Count - newConvTokens;
            List<int> trimmedConvVectors = convVectors.GetRange(offset, newConvTokens);
            updatedConversationAndUserPrompt = encoding.Decode(trimmedConvVectors);

            //add user prompt
            updatedConversationAndUserPrompt += Environment.NewLine + userPrompt;
        }
        //If everything is less than _maxCompletionTokens then good to go.
        else
        {
            //Return all of the content
            updatedAugmentedContent = retrievedData;
            updatedConversationAndUserPrompt = conversation + Environment.NewLine + userPrompt;
        }

        return (augmentedContent: updatedAugmentedContent, conversationAndUserPrompt: updatedConversationAndUserPrompt);
    }

    /// <summary>
    /// Get the most recent conversation history to provide additional context for the completion LLM
    /// </summary>
    /// <param name="sessionId">Session whose history to collect.</param>
    /// <returns>Newline-joined message texts, oldest first, limited to _maxConversationTokens.</returns>
    /// <exception cref="ArgumentException">No cached session matches <paramref name="sessionId"/>.</exception>
    private string GetConversationHistory(string sessionId)
    {
        int index = GetSessionIndex(sessionId);

        //Walk messages newest-to-oldest, accumulating until the token budget is exhausted.
        int tokensUsed = 0;
        List<string> recentTexts = new();
        foreach (Message message in _sessions[index].Messages.OrderByDescending(m => m.TimeStamp))
        {
            tokensUsed += message.Tokens;
            if (tokensUsed > _maxConversationTokens)
            {
                break;
            }
            recentTexts.Add(message.Text);
        }

        //Restore chronological (oldest-first) order before joining.
        recentTexts.Reverse();
        return string.Join(Environment.NewLine, recentTexts);
    }

    /// <summary>
    /// Summarizes the session via OpenAI and renames the session to that summary.
    /// </summary>
    /// <returns>The summary used as the new session name.</returns>
    public async Task<string> SummarizeChatSessionNameAsync(string? sessionId, string prompt)
    {
        ArgumentNullException.ThrowIfNull(sessionId);

        string response = await _openAiService.SummarizeAsync(sessionId, prompt);
        await RenameChatSessionAsync(sessionId, response);
        return response;
    }

    /// <summary>
    /// Add user prompt to the chat session message list object and insert into the data service.
    /// </summary>
    /// <exception cref="ArgumentException">No cached session matches <paramref name="sessionId"/>.</exception>
    private async Task AddPromptMessageAsync(string sessionId, string promptText)
    {
        Message promptMessage = new(sessionId, nameof(Participants.User), default, default, promptText);

        int index = GetSessionIndex(sessionId);
        _sessions[index].AddMessage(promptMessage);
        await _mongoDbService.InsertMessageAsync(promptMessage);
    }

    /// <summary>
    /// Add user prompt and AI assistance response to the chat session message list object and insert into the data service as a transaction.
    /// </summary>
    /// <exception cref="ArgumentException">No cached session matches <paramref name="sessionId"/>.</exception>
    private async Task AddPromptCompletionMessagesAsync(string sessionId, Message promptMessage, Message completionMessage)
    {
        int index = GetSessionIndex(sessionId);

        //Add prompt and completion to the cache
        _sessions[index].AddMessage(promptMessage);
        _sessions[index].AddMessage(completionMessage);

        //Update session cache with tokens used
        _sessions[index].TokensUsed += promptMessage.Tokens;
        _sessions[index].TokensUsed += completionMessage.PromptTokens;
        _sessions[index].TokensUsed += completionMessage.Tokens;

        await _mongoDbService.UpsertSessionBatchAsync(session: _sessions[index], promptMessage: promptMessage, completionMessage: completionMessage);
    }
}