Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: 192 persist chat history for each phase #201

Merged
merged 17 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions backend/src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,26 @@ import session from "express-session";
import { getInitialDefences } from "./defence";
import { setOpenAiApiKey } from "./openai";
import { router } from "./router";
import { ChatCompletionRequestMessage } from "openai";
import { ChatHistoryMessage } from "./models/chat";
import { EmailInfo } from "./models/email";
import { DefenceInfo } from "./models/defence";
import { CHAT_MODELS } from "./models/chat";
import { PHASE_NAMES } from "./models/phase";
import { retrievalQAPrePrompt } from "./promptTemplates";

dotenv.config();

// Augment express-session's Session type with the per-user game state.
// NOTE: the scraped diff interleaved removed and added lines here, leaving a
// duplicate `openAiApiKey` member and stale `chatHistory`/`defences` members;
// this PR moves chat history, defences and sent emails into per-phase state.
declare module "express-session" {
  interface Session {
    // model used for chat completions across all phases
    gptModel: CHAT_MODELS;
    // one entry per phase so history/defences/emails persist independently
    phaseState: PhaseState[];
    numPhasesCompleted: number;
    // null until configured; falls back to the OPENAI_API_KEY env var
    openAiApiKey: string | null;
  }
  // State tracked separately for each phase of the game.
  interface PhaseState {
    phase: PHASE_NAMES;
    chatHistory: ChatHistoryMessage[];
    defences: DefenceInfo[];
    sentEmails: EmailInfo[];
  }
}
Expand Down Expand Up @@ -65,12 +70,6 @@ app.use(

app.use(async (req, _res, next) => {
// initialise session variables
if (!req.session.chatHistory) {
req.session.chatHistory = [];
}
if (!req.session.defences) {
req.session.defences = getInitialDefences();
}
if (!req.session.gptModel) {
req.session.gptModel = defaultModel;
}
Expand All @@ -80,16 +79,28 @@ app.use(async (req, _res, next) => {
if (!req.session.openAiApiKey) {
req.session.openAiApiKey = process.env.OPENAI_API_KEY || null;
}
if (!req.session.sentEmails) {
req.session.sentEmails = [];
if (!req.session.phaseState) {
req.session.phaseState = [];
// add empty states for phases 0-3
Object.values(PHASE_NAMES).forEach((value) => {
if (isNaN(Number(value))) {
req.session.phaseState.push({
phase: value as PHASE_NAMES,
chatHistory: [],
defences: getInitialDefences(),
sentEmails: [],
});
}
});
console.log("Initialised phase state: ", req.session.phaseState);
}
next();
});

app.use("/", router);
app.listen(port, () => {
console.log("Server is running on port: " + port);

// for dev purposes only - set the API key from the environment variable
const envOpenAiKey = process.env.OPENAI_API_KEY;
const prePrompt = retrievalQAPrePrompt;
Expand Down
13 changes: 13 additions & 0 deletions backend/src/models/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@ enum CHAT_MODELS {

// Discriminator for how a chat history entry was produced / should be shown.
// NOTE(review): numeric enum — these values are stored in session phase state,
// so member order must not change without migrating persisted sessions.
enum CHAT_MESSAGE_TYPE {
  BOT,
  // bot reply that a defence subsequently blocked; never replayed to the model
  BOT_BLOCKED,
  INFO,
  USER,
  // user message after an input-transformation defence rewrote it
  USER_TRANSFORMED,
  PHASE_INFO,
  DEFENCE_ALERTED,
  DEFENCE_TRIGGERED,
  // system-role prompt injected at the start of the history
  SYSTEM,
  // function call requested by the model
  FUNCTION_CALL,
}

interface ChatDefenceReport {
Expand Down Expand Up @@ -50,11 +56,18 @@ interface ChatHttpResponse {
wonPhase: boolean;
}

// A single entry in a phase's persisted chat history.
interface ChatHistoryMessage {
  // Raw OpenAI message, or null for UI-only entries that must never be
  // forwarded to the model (filtered out before calling the API).
  completion: ChatCompletionRequestMessage | null;
  // How this entry was produced / should be rendered.
  chatMessageType: CHAT_MESSAGE_TYPE;
  // Display text for non-completion entries; both absent and null appear to
  // be allowed — TODO confirm which the frontend expects.
  infoMessage?: string | null;
}

// Type-only exports (erased at compile time).
export type {
  ChatAnswer,
  ChatDefenceReport,
  ChatMalicious,
  ChatResponse,
  ChatHttpResponse,
  ChatHistoryMessage,
};
// Runtime exports: enums are real values at runtime.
export { CHAT_MODELS, CHAT_MESSAGE_TYPE };
2 changes: 1 addition & 1 deletion backend/src/models/phase.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
enum PHASE_NAMES {
PHASE_0 = 0,
PHASE_0,
PHASE_1,
PHASE_2,
SANDBOX,
Expand Down
93 changes: 79 additions & 14 deletions backend/src/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ import {
Configuration,
OpenAIApi,
} from "openai";
import { CHAT_MODELS, ChatDefenceReport } from "./models/chat";
import {
CHAT_MESSAGE_TYPE,
CHAT_MODELS,
ChatDefenceReport,
ChatHistoryMessage,
} from "./models/chat";
import { DEFENCE_TYPES, DefenceInfo } from "./models/defence";
import { PHASE_NAMES } from "./models/phase";

Expand Down Expand Up @@ -223,7 +228,7 @@ async function chatGptCallFunction(
}

async function chatGptChatCompletion(
chatHistory: ChatCompletionRequestMessage[],
chatHistory: ChatHistoryMessage[],
defences: DefenceInfo[],
gptModel: CHAT_MODELS,
openai: OpenAIApi,
Expand All @@ -237,35 +242,72 @@ async function chatGptChatCompletion(
isDefenceActive(DEFENCE_TYPES.SYSTEM_ROLE, defences)
) {
// check to see if there's already a system role
if (!chatHistory.find((message) => message.role === "system")) {
if (!chatHistory.find((message) => message.completion?.role === "system")) {
// add the system role to the start of the chat history
chatHistory.unshift({
role: "system",
content: getSystemRole(defences, currentPhase),
completion: {
role: "system",
content: getSystemRole(defences, currentPhase),
},
chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM,
});
}
} else {
// remove the system role from the chat history
while (chatHistory.length > 0 && chatHistory[0].role === "system") {
while (
chatHistory.length > 0 &&
chatHistory[0].completion?.role === "system"
) {
chatHistory.shift();
}
}

const chat_completion = await openai.createChatCompletion({
model: gptModel,
messages: chatHistory,
messages: getChatCompletionsFromHistory(chatHistory),
functions: chatGptFunctions,
});

// get the reply
return chat_completion.data.choices[0].message || null;
}

// Extract only the OpenAI completion payloads from the chat history, dropping
// UI-only entries (completion === null) before sending the history to GPT.
// The previous `length > 0` ternary was redundant (filter/map on an empty
// array already yields []) and the `as` assertion is replaced by a narrowing
// expression so no unchecked cast is needed.
const getChatCompletionsFromHistory = (
  chatHistory: ChatHistoryMessage[]
): ChatCompletionRequestMessage[] =>
  chatHistory.flatMap((message) =>
    message.completion !== null ? [message.completion] : []
  );

// Append a completion to the chat history, tagged with its message type.
// Blocked bot replies are deliberately excluded so they are never replayed
// to the model in later requests; the history array is returned either way.
const pushCompletionToHistory = (
  chatHistory: ChatHistoryMessage[],
  completion: ChatCompletionRequestMessage,
  messageType: CHAT_MESSAGE_TYPE
) => {
  if (messageType === CHAT_MESSAGE_TYPE.BOT_BLOCKED) {
    // do not add the bots reply which was subsequently blocked
    console.log("Skipping adding blocked message to chat history", completion);
    return chatHistory;
  }
  chatHistory.push({ completion, chatMessageType: messageType });
  return chatHistory;
};

async function chatGptSendMessage(
chatHistory: ChatCompletionRequestMessage[],
chatHistory: ChatHistoryMessage[],
defences: DefenceInfo[],
gptModel: CHAT_MODELS,
message: string,
messageIsTransformed: boolean,
openAiApiKey: string,
sentEmails: EmailInfo[],
// default to sandbox
Expand All @@ -283,7 +325,16 @@ async function chatGptSendMessage(
let wonPhase: boolean | undefined | null = false;

// add user message to chat
chatHistory.push({ role: "user", content: message });
chatHistory = pushCompletionToHistory(
chatHistory,
{
role: "user",
content: message,
},
messageIsTransformed
? CHAT_MESSAGE_TYPE.USER_TRANSFORMED
: CHAT_MESSAGE_TYPE.USER
);

const openai = getOpenAiFromKey(openAiApiKey);
let reply = await chatGptChatCompletion(
Expand All @@ -295,7 +346,11 @@ async function chatGptSendMessage(
);
// check if GPT wanted to call a function
while (reply && reply.function_call) {
chatHistory.push(reply);
chatHistory = pushCompletionToHistory(
chatHistory,
reply,
CHAT_MESSAGE_TYPE.FUNCTION_CALL
);

// call the function and get a new reply and defence info from
const functionCallReply = await chatGptCallFunction(
Expand All @@ -309,12 +364,15 @@ async function chatGptSendMessage(
wonPhase = functionCallReply.wonPhase;
// add the function call to the chat history
if (functionCallReply.completion !== undefined) {
chatHistory.push(functionCallReply.completion);
chatHistory = pushCompletionToHistory(
chatHistory,
functionCallReply.completion,
CHAT_MESSAGE_TYPE.FUNCTION_CALL
);
}
// update the defence info
defenceInfo = functionCallReply.defenceInfo;
}

// get a new reply from ChatGPT now that the function has been called
reply = await chatGptChatCompletion(
chatHistory,
Expand All @@ -326,7 +384,6 @@ async function chatGptSendMessage(
}

if (reply && reply.content) {
console.debug("GPT reply: " + reply.content);
// if output filter defence is active, check for blocked words/phrases
if (
currentPhase === PHASE_NAMES.PHASE_2 ||
Expand All @@ -353,9 +410,17 @@ async function chatGptSendMessage(
}
}
// add the ai reply to the chat history
chatHistory.push(reply);
chatHistory = pushCompletionToHistory(
chatHistory,
reply,
defenceInfo.isBlocked
? CHAT_MESSAGE_TYPE.BOT_BLOCKED
: CHAT_MESSAGE_TYPE.BOT
);

// log the entire chat history so far
console.log(chatHistory);

return {
completion: reply,
defenceInfo: defenceInfo,
Expand Down
Loading