Skip to content

Commit

Permalink
馃Чfix: Handle Abort Message Edge Cases (danny-avila#1462)
Browse files Browse the repository at this point in the history
* chore: bump langchain to v0.0.213 from v0.0.186

* fix: handle abort edge cases:
- abort message server-side if response experienced error mid-generation
- attempt to recover message if aborting resulted in error
- if abortKey is not provided, use conversationId if it exists
- if headers were already sent, send an Event stream message
- issue warning for possible Google censor/filter

refactor(streamResponse): for `sendError`, allow passing overrides so that error can include partial generation, improve typing for `sendMessage`

* chore(MessageContent): remove eslint warning for unused `i`, rephrase unfinished message text

* fix(useSSE): avoid invoking cancelHandler if the abort response was 404

* chore(TMessage): remove unnecessary, unused legacy message property `submitting`

* chore(TMessage): remove unnecessary legacy message property `cancelled`

* chore(abortMiddleware): remove unused `errorText` property to avoid confusion
  • Loading branch information
danny-avila committed Dec 30, 2023
1 parent a9bef52 commit 7a2ad12
Show file tree
Hide file tree
Showing 13 changed files with 57 additions and 32 deletions.
2 changes: 1 addition & 1 deletion app/clients/BaseClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ class BaseClient {
}

async saveMessageToDatabase(message, endpointOptions, user = null) {
await saveMessage({ ...message, user, unfinished: false, cancelled: false });
await saveMessage({ ...message, user, unfinished: false });
await saveConvo(user, {
conversationId: message.conversationId,
endpoint: this.options.endpoint,
Expand Down
1 change: 0 additions & 1 deletion app/clients/prompts/formatMessages.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ describe('formatMessage', () => {
_id: '6512cdfb92cbf69fea615331',
messageId: 'b620bf73-c5c3-4a38-b724-76886aac24c4',
__v: 0,
cancelled: false,
conversationId: '5c23d24f-941f-4aab-85df-127b596c8aa5',
createdAt: Date.now(),
error: false,
Expand Down
2 changes: 0 additions & 2 deletions models/Message.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ module.exports = {
isCreatedByUser = false,
error,
unfinished,
cancelled,
files,
isEdited = false,
finish_reason = null,
Expand All @@ -45,7 +44,6 @@ module.exports = {
finish_reason,
error,
unfinished,
cancelled,
tokenCount,
plugin,
plugins,
Expand Down
4 changes: 0 additions & 4 deletions models/schema/messageSchema.js
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,6 @@ const messageSchema = mongoose.Schema(
type: Boolean,
default: false,
},
cancelled: {
type: Boolean,
default: false,
},
error: {
type: Boolean,
default: false,
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"keyv": "^4.5.4",
"keyv-file": "^0.2.0",
"klona": "^2.0.6",
"langchain": "^0.0.186",
"langchain": "^0.0.213",
"librechat-data-provider": "*",
"lodash": "^4.17.21",
"meilisearch": "^0.33.0",
Expand Down
1 change: 0 additions & 1 deletion server/controllers/AskController.js
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
text: partialText,
model: client.modelOptions.model,
unfinished: true,
cancelled: false,
error: false,
user,
});
Expand Down
1 change: 0 additions & 1 deletion server/controllers/EditController.js
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ const EditController = async (req, res, next, initializeClient) => {
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
isEdited: true,
error: false,
user,
Expand Down
37 changes: 30 additions & 7 deletions server/middleware/abortMiddleware.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,26 @@ const spendTokens = require('~/models/spendTokens');
const { logger } = require('~/config');

async function abortMessage(req, res) {
const { abortKey } = req.body;
let { abortKey, conversationId } = req.body;

if (!abortKey && conversationId) {
abortKey = conversationId;
}

if (!abortControllers.has(abortKey) && !res.headersSent) {
return res.status(404).send({ message: 'Request not found' });
}

const { abortController } = abortControllers.get(abortKey);
const ret = await abortController.abortCompletion();
const finalEvent = await abortController.abortCompletion();
logger.debug('[abortMessage] Aborted request', { abortKey });
abortControllers.delete(abortKey);
res.send(JSON.stringify(ret));

if (res.headersSent && finalEvent) {
return sendMessage(res, finalEvent);
}

res.send(JSON.stringify(finalEvent));
}

const handleAbort = () => {
Expand Down Expand Up @@ -58,7 +67,6 @@ const createAbortController = (req, res, getAbortData) => {
finish_reason: 'incomplete',
model: endpointOption.modelOptions.model,
unfinished: false,
cancelled: true,
error: false,
isCreatedByUser: false,
tokenCount: completionTokens,
Expand All @@ -84,10 +92,16 @@ const createAbortController = (req, res, getAbortData) => {
};

const handleAbortError = async (res, req, error, data) => {
logger.error('[handleAbortError] response error and aborting request', error);
logger.error('[handleAbortError] AI response error; aborting request:', error);
const { sender, conversationId, messageId, parentMessageId, partialText } = data;

const respondWithError = async () => {
if (error.stack && error.stack.includes('google')) {
logger.warn(
`AI Response error for conversation ${conversationId} likely caused by Google censor/filter`,
);
}

const respondWithError = async (partialText) => {
const options = {
sender,
messageId,
Expand All @@ -97,6 +111,15 @@ const handleAbortError = async (res, req, error, data) => {
shouldSaveMessage: true,
user: req.user.id,
};

if (partialText) {
options.overrideProps = {
error: false,
unfinished: true,
text: partialText,
};
}

const callback = async () => {
if (abortControllers.has(conversationId)) {
const { abortController } = abortControllers.get(conversationId);
Expand All @@ -113,7 +136,7 @@ const handleAbortError = async (res, req, error, data) => {
return await abortMessage(req, res);
} catch (err) {
logger.error('[handleAbortError] error while trying to abort message', err);
return respondWithError();
return respondWithError(partialText);
}
} else {
return respondWithError();
Expand Down
3 changes: 0 additions & 3 deletions server/routes/ask/askChatGPTBrowser.js
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ const ask = async ({
parentMessageId: overrideParentMessageId || userMessageId,
text: text,
unfinished: true,
cancelled: false,
error: false,
isCreatedByUser: false,
user,
Expand Down Expand Up @@ -155,7 +154,6 @@ const ask = async ({
text: await handleText(response),
sender: endpointOption?.chatGptLabel || 'ChatGPT',
unfinished: false,
cancelled: false,
error: false,
isCreatedByUser: false,
};
Expand Down Expand Up @@ -226,7 +224,6 @@ const ask = async ({
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
unfinished: false,
cancelled: false,
error: true,
isCreatedByUser: false,
text: `${getPartialMessage() ?? ''}\n\nError message: "${error.message}"`,
Expand Down
4 changes: 0 additions & 4 deletions server/routes/ask/bingAI.js
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ const ask = async ({
model,
text: text,
unfinished: true,
cancelled: false,
error: false,
isCreatedByUser: false,
user,
Expand Down Expand Up @@ -193,7 +192,6 @@ const ask = async ({
response.details.suggestedResponses &&
response.details.suggestedResponses.map((s) => s.text),
unfinished,
cancelled: false,
error: false,
isCreatedByUser: false,
};
Expand Down Expand Up @@ -263,7 +261,6 @@ const ask = async ({
text: partialText,
model,
unfinished: true,
cancelled: false,
error: false,
isCreatedByUser: false,
};
Expand All @@ -285,7 +282,6 @@ const ask = async ({
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
unfinished: false,
cancelled: false,
error: true,
text: error.message,
model,
Expand Down
1 change: 0 additions & 1 deletion server/routes/ask/gptPlugins.js
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
error: false,
plugins,
user,
Expand Down
1 change: 0 additions & 1 deletion server/routes/edit/gptPlugins.js
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
cancelled: false,
isEdited: true,
error: false,
user,
Expand Down
30 changes: 25 additions & 5 deletions server/utils/streamResponse.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
const crypto = require('crypto');
const { saveMessage } = require('~/models/Message');
const { saveMessage, getMessages } = require('~/models/Message');
const { getConvo } = require('~/models/Conversation');

/**
* Sends error data in Server Sent Events format and ends the response.
Expand All @@ -15,7 +16,7 @@ const handleError = (res, message) => {
* Sends message data in Server Sent Events format.
* @param {object} res - - The server response.
* @param {string} message - The message to be sent.
* @param {string} event - [Optional] The type of event. Default is 'message'.
* @param {'message' | 'error' | 'cancel'} event - [Optional] The type of event. Default is 'message'.
*/
const sendMessage = (res, message, event = 'message') => {
if (message.length === 0) {
Expand All @@ -32,19 +33,27 @@ const sendMessage = (res, message, event = 'message') => {
* @param {function} callback - [Optional] The callback function to be executed.
*/
const sendError = async (res, options, callback) => {
const { user, sender, conversationId, messageId, parentMessageId, text, shouldSaveMessage } =
options;
const {
user,
sender,
conversationId,
messageId,
parentMessageId,
text,
shouldSaveMessage,
overrideProps = {},
} = options;
const errorMessage = {
sender,
messageId: messageId ?? crypto.randomUUID(),
conversationId,
parentMessageId,
unfinished: false,
cancelled: false,
error: true,
final: true,
text,
isCreatedByUser: false,
...overrideProps,
};
if (callback && typeof callback === 'function') {
await callback();
Expand All @@ -54,6 +63,17 @@ const sendError = async (res, options, callback) => {
await saveMessage({ ...errorMessage, user });
}

if (!errorMessage.error) {
const requestMessage = { messageId: parentMessageId, conversationId };
const query = await getMessages(requestMessage);
return sendMessage(res, {
final: true,
requestMessage: query?.[0] ? query[0] : requestMessage,
responseMessage: errorMessage,
conversation: await getConvo(user, conversationId),
});
}

handleError(res, errorMessage);
};

Expand Down

0 comments on commit 7a2ad12

Please sign in to comment.