Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

225 multi user langchain #231

Merged
merged 7 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions backend/src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import { EmailInfo } from "./models/email";
import { DefenceInfo } from "./models/defence";
import { CHAT_MODELS } from "./models/chat";
import { PHASE_NAMES } from "./models/phase";
import { retrievalQAPrePrompt } from "./promptTemplates";
import path from "path";
import { getInitialDefences } from "./defence";
import { initDocumentVectors } from "./langchain";

dotenv.config();

Expand Down Expand Up @@ -97,13 +97,21 @@ app.use("/", router);
app.listen(port, () => {
console.log(`Server is running on port: ${port}`);

// initialise the documents on app startup
initDocumentVectors()
.then(() => {
console.debug("Document vectors initialised");
})
.catch((err) => {
console.error("Error initialising document vectors", err);
});

// for dev purposes only - set the API key from the environment variable
const envOpenAiKey = process.env.OPENAI_API_KEY;
const prePrompt = retrievalQAPrePrompt;
if (envOpenAiKey) {
console.debug("Initializing models with API key from environment variable");
// asynchronously set the API key
void setOpenAiApiKey(envOpenAiKey, defaultModel, prePrompt).then(() => {
void setOpenAiApiKey(envOpenAiKey, defaultModel).then(() => {
console.debug("OpenAI models initialized");
});
}
Expand Down
5 changes: 3 additions & 2 deletions backend/src/defence.ts
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ function transformMessage(message: string, defences: DefenceInfo[]) {
// detects triggered defences in original message and blocks the message if necessary
async function detectTriggeredDefences(
message: string,
defences: DefenceInfo[]
defences: DefenceInfo[],
openAiApiKey: string
) {
// keep track of any triggered defences
const defenceReport: ChatDefenceReport = {
Expand Down Expand Up @@ -352,7 +353,7 @@ async function detectTriggeredDefences(
}

// evaluate the message for prompt injection
const evalPrompt = await queryPromptEvaluationModel(message);
const evalPrompt = await queryPromptEvaluationModel(message, openAiApiKey);
if (evalPrompt.isMalicious) {
if (isDefenceActive(DEFENCE_TYPES.LLM_EVALUATION, defences)) {
defenceReport.triggeredDefences.push(DEFENCE_TYPES.LLM_EVALUATION);
Expand Down
110 changes: 63 additions & 47 deletions backend/src/langchain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import { OpenAI } from "langchain/llms/openai";
import { PromptTemplate } from "langchain/prompts";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { MemoryVectorStore } from "langchain/vectorstores/memory";

import { CHAT_MODELS, ChatAnswer } from "./models/chat";
import { DocumentsVector } from "./models/document";

import {
maliciousPromptTemplate,
promptInjectionEvalTemplate,
Expand All @@ -21,23 +22,15 @@ import {
import { PHASE_NAMES } from "./models/phase";
import { PromptEvaluationChainReply, QaChainReply } from "./models/langchain";

// chain we use in question/answer request
let qaChain: RetrievalQAChain | null = null;

// chain we use in prompt evaluation request
let promptEvaluationChain: SequentialChain | null = null;
// store vectorised documents for each phase as array
let vectorisedDocuments: DocumentsVector[] = [];

function setQAChain(chain: RetrievalQAChain | null) {
console.debug("Setting QA chain.");
qaChain = chain;
// set the global varibale
function setVectorisedDocuments(docs: DocumentsVector[]) {
vectorisedDocuments = docs;
}

function setPromptEvaluationChain(chain: SequentialChain | null) {
console.debug("Setting evaluation chain.");
promptEvaluationChain = chain;
}

function getFilepath(currentPhase: PHASE_NAMES = PHASE_NAMES.SANDBOX) {
function getFilepath(currentPhase: PHASE_NAMES) {
let filePath = "resources/documents/";
switch (currentPhase) {
case PHASE_NAMES.PHASE_0:
Expand All @@ -46,11 +39,13 @@ function getFilepath(currentPhase: PHASE_NAMES = PHASE_NAMES.SANDBOX) {
return (filePath += "phase_1/");
case PHASE_NAMES.PHASE_2:
return (filePath += "phase_2/");
default:
case PHASE_NAMES.SANDBOX:
return (filePath += "common/");
default:
console.error(`No document filepath found for phase: ${filePath}`);
return "";
}
}

// load the documents from filesystem
async function getDocuments(filePath: string) {
console.debug(`Loading documents from: ${filePath}`);
Expand All @@ -74,51 +69,66 @@ async function getDocuments(filePath: string) {
// join the configurable preprompt to the context template
function getQAPromptTemplate(prePrompt: string) {
if (!prePrompt) {
console.debug("Using default retrieval QA pre-prompt");
// use the default prePrompt
prePrompt = retrievalQAPrePrompt;
}
const fullPrompt = prePrompt + qAcontextTemplate;
console.debug(`QA prompt: ${fullPrompt}`);
const template: PromptTemplate = PromptTemplate.fromTemplate(fullPrompt);
return template;
}
// create and store the document vectors for each phase
async function initDocumentVectors() {
const docVectors: DocumentsVector[] = [];

for (const value of Object.values(PHASE_NAMES)) {
if (!isNaN(Number(value))) {
const phase = value as PHASE_NAMES;
// get the documents
const filePath: string = getFilepath(phase);
const documents: Document[] = await getDocuments(filePath);

// embed and store the splits - will use env variable for API key
const embeddings = new OpenAIEmbeddings();

const vectorStore: MemoryVectorStore =
await MemoryVectorStore.fromDocuments(documents, embeddings);
// store the document vectors for the phase
docVectors.push({
phase: phase,
docVector: vectorStore,
});
}
}
setVectorisedDocuments(docVectors);
console.debug(
"Intitialised document vectors for each phase. count=",
vectorisedDocuments.length
);
}

// QA Chain - ask the chat model a question about the documents
async function initQAModel(
openAiApiKey: string,
function initQAModel(
phase: PHASE_NAMES,
prePrompt: string,
// default to sandbox
currentPhase: PHASE_NAMES = PHASE_NAMES.SANDBOX
openAiApiKey: string
) {
if (!openAiApiKey) {
console.debug("No OpenAI API key set to initialise QA model");
return;
}
// get the documents
const docs: Document[] = await getDocuments(getFilepath(currentPhase));

// embed and store the splits
const embeddings = new OpenAIEmbeddings({
openAIApiKey: openAiApiKey,
});
const vectorStore = await MemoryVectorStore.fromDocuments(docs, embeddings);
const documentVectors = vectorisedDocuments[phase].docVector;

// initialise model
const model = new ChatOpenAI({
modelName: CHAT_MODELS.GPT_4,
streaming: true,
openAIApiKey: openAiApiKey,
});
const promptTemplate = getQAPromptTemplate(prePrompt);

// prompt template for question and answering
const qaPrompt = getQAPromptTemplate(prePrompt);

// set chain to retrieval QA chain
setQAChain(
RetrievalQAChain.fromLLM(model, vectorStore.asRetriever(), {
prompt: qaPrompt,
})
);
console.debug("QA chain initialised.");
return RetrievalQAChain.fromLLM(model, documentVectors.asRetriever(), {
prompt: promptTemplate,
});
}

// initialise the prompt evaluation model
Expand Down Expand Up @@ -163,13 +173,18 @@ function initPromptEvaluationModel(openAiApiKey: string) {
inputVariables: ["prompt"],
outputVariables: ["promptInjectionEval", "maliciousInputEval"],
});
setPromptEvaluationChain(sequentialChain);

console.debug("Prompt evaluation chain initialised.");
return sequentialChain;
}

// ask the question and return models answer
async function queryDocuments(question: string) {
async function queryDocuments(
question: string,
prePrompt: string,
currentPhase: PHASE_NAMES,
openAiApiKey: string
) {
const qaChain = initQAModel(currentPhase, prePrompt, openAiApiKey);
if (!qaChain) {
console.debug("QA chain not initialised.");
return { reply: "", questionAnswered: false };
Expand All @@ -187,7 +202,8 @@ async function queryDocuments(question: string) {
}

// ask LLM whether the prompt is malicious
async function queryPromptEvaluationModel(input: string) {
async function queryPromptEvaluationModel(input: string, openAIApiKey: string) {
const promptEvaluationChain = initPromptEvaluationModel(openAIApiKey);
if (!promptEvaluationChain) {
console.debug("Prompt evaluation chain not initialised.");
return { isMalicious: false, reason: "" };
Expand Down Expand Up @@ -252,6 +268,6 @@ export {
queryDocuments,
queryPromptEvaluationModel,
formatEvaluationOutput,
setQAChain,
setPromptEvaluationChain,
setVectorisedDocuments,
initDocumentVectors,
};
10 changes: 9 additions & 1 deletion backend/src/models/document.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { PHASE_NAMES } from "./phase";

interface Document {
filename: string;
filetype: string;
}

export type { Document };
interface DocumentsVector {
phase: PHASE_NAMES;
docVector: MemoryVectorStore;
}

export type { Document, DocumentsVector };
38 changes: 20 additions & 18 deletions backend/src/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@ import {
getSystemRole,
detectFilterList,
getFilterList,
getQALLMprePrompt,
} from "./defence";
import { sendEmail, getEmailWhitelist, isEmailInWhitelist } from "./email";
import {
initQAModel,
initPromptEvaluationModel,
queryDocuments,
} from "./langchain";
import { queryDocuments } from "./langchain";
import { EmailInfo, EmailResponse } from "./models/email";
import {
ChatCompletionRequestMessage,
Expand Down Expand Up @@ -101,19 +98,11 @@ async function validateApiKey(openAiApiKey: string, gptModel: string) {
}
}

async function setOpenAiApiKey(
openAiApiKey: string,
gptModel: string,
prePrompt: string,
// default to sandbox mode
currentPhase: PHASE_NAMES = PHASE_NAMES.SANDBOX
) {
// initialise all models with the new key
async function setOpenAiApiKey(openAiApiKey: string, gptModel: string) {
// initialise models with the new key
if (await validateApiKey(openAiApiKey, gptModel)) {
console.debug("Setting API key and initialising models");
initOpenAi(openAiApiKey);
await initQAModel(openAiApiKey, prePrompt, currentPhase);
initPromptEvaluationModel(openAiApiKey);
return true;
} else {
// set to empty in case it was previously set
Expand Down Expand Up @@ -158,7 +147,8 @@ async function chatGptCallFunction(
functionCall: ChatCompletionRequestMessageFunctionCall,
sentEmails: EmailInfo[],
// default to sandbox
currentPhase: PHASE_NAMES = PHASE_NAMES.SANDBOX
currentPhase: PHASE_NAMES = PHASE_NAMES.SANDBOX,
openAiApiKey: string
) {
let reply: ChatCompletionRequestMessage | null = null;
let wonPhase = false;
Expand Down Expand Up @@ -217,7 +207,18 @@ async function chatGptCallFunction(
) as FunctionAskQuestionParams;
console.debug(`Asking question: ${params.question}`);
// if asking a question, call the queryDocuments
response = (await queryDocuments(params.question)).reply;
let qaPrompt = "";
if (isDefenceActive(DEFENCE_TYPES.QA_LLM_INSTRUCTIONS, defences)) {
qaPrompt = getQALLMprePrompt(defences);
}
response = (
await queryDocuments(
params.question,
qaPrompt,
currentPhase,
openAiApiKey
)
).reply;
} else {
console.error("No arguments provided to askQuestion function");
}
Expand Down Expand Up @@ -377,7 +378,8 @@ async function chatGptSendMessage(
defences,
reply.function_call,
sentEmails,
currentPhase
currentPhase,
openAiApiKey
);
if (functionCallReply) {
wonPhase = functionCallReply.wonPhase;
Expand Down
Loading