Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

begin integration tests for langchain #182

Merged
merged 9 commits into from
Aug 24, 2023
113 changes: 91 additions & 22 deletions backend/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 4 additions & 5 deletions backend/src/defence.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function getMaxMessageLength(defences: DefenceInfo[]) {
}

function getRandomSequenceEnclosurePrePrompt(defences: DefenceInfo[]) {
return getConfigValue(defences, "RANDOM_SEQUENCE_ENCLOSURE", "prePrompt", "");
return getConfigValue(defences, "RANDOM_SEQUENCE_ENCLOSURE", "prePrompt", retrievalQAPrePromptSecure);
}

function getRandomSequenceEnclosureLength(defences: DefenceInfo[]) {
Expand Down Expand Up @@ -179,9 +179,8 @@ function escapeXml(unsafe: string) {
case "'":
return "'";
case '"':
return """;
default:
return c;
return """;
}
});
}
Expand Down Expand Up @@ -281,11 +280,11 @@ export {
activateDefence,
configureDefence,
deactivateDefence,
detectTriggeredDefences,
getEmailWhitelistVar,
getInitialDefences,
getQALLMprePrompt,
getSystemRole,
isDefenceActive,
transformMessage,
detectTriggeredDefences,
getEmailWhitelistVar,
};
46 changes: 34 additions & 12 deletions backend/src/langchain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ let qaChain: RetrievalQAChain | null = null;
// chain we use in prompt evaluation request
let promptEvaluationChain: SequentialChain | null = null;

function setQAChain(chain: RetrievalQAChain | null) {
console.debug("Setting QA chain.");
qaChain = chain;
}

function setPromptEvaluationChain(chain: SequentialChain | null) {
console.debug("Setting evaluation chain.");
promptEvaluationChain = chain;
}

function getFilepath(currentPhase: PHASE_NAMES = PHASE_NAMES.SANDBOX) {
let filePath = "resources/documents/";
switch (currentPhase) {
Expand Down Expand Up @@ -66,7 +76,9 @@ function getQAPromptTemplate(prePrompt: string) {
console.debug("Using default retrieval QA pre-prompt");
prePrompt = retrievalQAPrePrompt;
}
return PromptTemplate.fromTemplate(prePrompt + qAcontextTemplate);
const fullPrompt = prePrompt + qAcontextTemplate;
const template: PromptTemplate = PromptTemplate.fromTemplate(fullPrompt);
return template;
}

// QA Chain - ask the chat model a question about the documents
Expand Down Expand Up @@ -100,18 +112,22 @@ async function initQAModel(
const qaPrompt = getQAPromptTemplate(prePrompt);

// set chain to retrieval QA chain
qaChain = RetrievalQAChain.fromLLM(model, vectorStore.asRetriever(), {
prompt: qaPrompt,
});
setQAChain(
RetrievalQAChain.fromLLM(model, vectorStore.asRetriever(), {
prompt: qaPrompt,
})
);
console.debug("QA chain initialised.");
}

// initialise the prompt evaluation model
function initPromptEvaluationModel(openAiApiKey: string) {
if (!openAiApiKey) {
console.debug("No OpenAI API key set to initialise prompt evaluation model");
console.debug(
"No OpenAI API key set to initialise prompt evaluation model"
);
return;
}

// create chain to detect prompt injection
const promptInjectionPrompt = PromptTemplate.fromTemplate(
promptInjectionEvalTemplate
Expand Down Expand Up @@ -141,12 +157,14 @@ function initPromptEvaluationModel(openAiApiKey: string) {
outputKey: "maliciousInputEval",
});

promptEvaluationChain = new SequentialChain({
const sequentialChain = new SequentialChain({
chains: [promptInjectionChain, maliciousInputChain],
inputVariables: ["prompt"],
outputVariables: ["promptInjectionEval", "maliciousInputEval"],
});
console.debug("Prompt evaluation chain initialised");
setPromptEvaluationChain(sequentialChain);

console.debug("Prompt evaluation chain initialised.");
}

// ask the question and return models answer
Expand All @@ -172,19 +190,17 @@ async function queryPromptEvaluationModel(input: string) {
console.debug("Prompt evaluation chain not initialised.");
return { isMalicious: false, reason: "" };
}

console.log(`Checking '${input}' for malicious prompts`);

const response = await promptEvaluationChain.call({
prompt: input,
});

const promptInjectionEval = formatEvaluationOutput(
response.promptInjectionEval
);
const maliciousInputEval = formatEvaluationOutput(
response.maliciousInputEval
);

console.debug(
"Prompt injection eval: " + JSON.stringify(promptInjectionEval)
);
Expand All @@ -210,7 +226,7 @@ function formatEvaluationOutput(response: string) {
// split response on first full stop or comma
const splitResponse = response.split(/\.|,/);
const answer = splitResponse[0]?.replace(/\W/g, "").toLowerCase();
const reason = splitResponse[1];
const reason = splitResponse[1]?.trim();
return {
isMalicious: answer === "yes",
reason: reason,
Expand All @@ -228,7 +244,13 @@ function formatEvaluationOutput(response: string) {

export {
initQAModel,
getFilepath,
getQAPromptTemplate,
getDocuments,
initPromptEvaluationModel,
queryDocuments,
queryPromptEvaluationModel,
formatEvaluationOutput,
setQAChain,
setPromptEvaluationChain,
};
Loading