184 use defence types enum #186

Merged
merged 2 commits on Aug 24, 2023
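
This PR replaces the hard-coded defence-id strings used across the backend with the DEFENCE_TYPES enum. The enum itself is defined in backend/src/models/defence.ts, which is not part of this diff; judging from the identifiers used in the changes below, it presumably looks something like the following sketch (a string enum, so the runtime values stay identical to the old literals and existing DefenceInfo.id comparisons keep working):

// Sketch only — the real definition lives in backend/src/models/defence.ts and is not shown in this diff.
export enum DEFENCE_TYPES {
  CHARACTER_LIMIT = "CHARACTER_LIMIT",
  EMAIL_WHITELIST = "EMAIL_WHITELIST",
  LLM_EVALUATION = "LLM_EVALUATION",
  QA_LLM_INSTRUCTIONS = "QA_LLM_INSTRUCTIONS",
  RANDOM_SEQUENCE_ENCLOSURE = "RANDOM_SEQUENCE_ENCLOSURE",
  SYSTEM_ROLE = "SYSTEM_ROLE",
  XML_TAGGING = "XML_TAGGING",
}
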
36 changes: 18 additions & 18 deletions backend/src/defence.ts
@@ -45,22 +45,22 @@ const getInitialDefences = (): DefenceInfo[] => {
];
};

function activateDefence(id: string, defences: DefenceInfo[]) {
function activateDefence(id: DEFENCE_TYPES, defences: DefenceInfo[]) {
// return the updated list of defences
return defences.map((defence) =>
defence.id === id ? { ...defence, isActive: true } : defence
);
}

function deactivateDefence(id: string, defences: DefenceInfo[]) {
function deactivateDefence(id: DEFENCE_TYPES, defences: DefenceInfo[]) {
// return the updated list of defences
return defences.map((defence) =>
defence.id === id ? { ...defence, isActive: false } : defence
);
}

function configureDefence(
id: string,
id: DEFENCE_TYPES,
defences: DefenceInfo[],
config: DefenceConfig[]
) {
@@ -72,7 +72,7 @@ function configureDefence(

function getConfigValue(
defences: DefenceInfo[],
defenceId: string,
defenceId: DEFENCE_TYPES,
configId: string,
defaultValue: string
) {
@@ -85,20 +85,20 @@ function getConfigValue(
function getMaxMessageLength(defences: DefenceInfo[]) {
return getConfigValue(
defences,
"CHARACTER_LIMIT",
DEFENCE_TYPES.CHARACTER_LIMIT,
"maxMessageLength",
String(280)
);
}

function getRandomSequenceEnclosurePrePrompt(defences: DefenceInfo[]) {
return getConfigValue(defences, "RANDOM_SEQUENCE_ENCLOSURE", "prePrompt", retrievalQAPrePromptSecure);
return getConfigValue(defences, DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, "prePrompt", retrievalQAPrePromptSecure);
}

function getRandomSequenceEnclosureLength(defences: DefenceInfo[]) {
return getConfigValue(
defences,
"RANDOM_SEQUENCE_ENCLOSURE",
DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE,
"length",
String(10)
);
@@ -117,19 +117,19 @@ function getSystemRole(
case PHASE_NAMES.PHASE_2:
return process.env.SYSTEM_ROLE_PHASE_2 || "";
default:
return getConfigValue(defences, "SYSTEM_ROLE", "systemRole", "");
return getConfigValue(defences, DEFENCE_TYPES.SYSTEM_ROLE, "systemRole", "");
}
}

function getEmailWhitelistVar(defences: DefenceInfo[]) {
return getConfigValue(defences, "EMAIL_WHITELIST", "whitelist", "");
return getConfigValue(defences, DEFENCE_TYPES.EMAIL_WHITELIST, "whitelist", "");
}

function getQALLMprePrompt(defences: DefenceInfo[]) {
return getConfigValue(defences, "QA_LLM_INSTRUCTIONS", "prePrompt", "");
return getConfigValue(defences, DEFENCE_TYPES.QA_LLM_INSTRUCTIONS, "prePrompt", "");
}

function isDefenceActive(id: string, defences: DefenceInfo[]) {
function isDefenceActive(id: DEFENCE_TYPES, defences: DefenceInfo[]) {
return defences.find((defence) => defence.id === id && defence.isActive)
? true
: false;
@@ -207,13 +207,13 @@ function transformXmlTagging(message: string) {
//apply defence string transformations to original message
function transformMessage(message: string, defences: DefenceInfo[]) {
let transformedMessage: string = message;
if (isDefenceActive("RANDOM_SEQUENCE_ENCLOSURE", defences)) {
if (isDefenceActive(DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, defences)) {
transformedMessage = transformRandomSequenceEnclosure(
transformedMessage,
defences
);
}
if (isDefenceActive("XML_TAGGING", defences)) {
if (isDefenceActive(DEFENCE_TYPES.XML_TAGGING, defences)) {
transformedMessage = transformXmlTagging(transformedMessage);
}
if (message == transformedMessage) {
@@ -242,9 +242,9 @@ async function detectTriggeredDefences(
if (message.length > maxMessageLength) {
console.debug("CHARACTER_LIMIT defence triggered.");
// add the defence to the list of triggered defences
defenceReport.triggeredDefences.push("CHARACTER_LIMIT");
defenceReport.triggeredDefences.push(DEFENCE_TYPES.CHARACTER_LIMIT);
// check if the defence is active
if (isDefenceActive("CHARACTER_LIMIT", defences)) {
if (isDefenceActive(DEFENCE_TYPES.CHARACTER_LIMIT, defences)) {
// block the message
defenceReport.isBlocked = true;
defenceReport.blockedReason = "Message is too long";
@@ -257,14 +257,14 @@
if (detectXMLTags(message)) {
console.debug("XML_TAGGING defence triggered.");
// add the defence to the list of triggered defences
defenceReport.triggeredDefences.push("XML_TAGGING");
defenceReport.triggeredDefences.push(DEFENCE_TYPES.XML_TAGGING);
}

// evaluate the message for prompt injection
const evalPrompt = await queryPromptEvaluationModel(message);
if (evalPrompt.isMalicious) {
defenceReport.triggeredDefences.push("LLM_EVALUATION");
if (isDefenceActive("LLM_EVALUATION", defences)) {
defenceReport.triggeredDefences.push(DEFENCE_TYPES.LLM_EVALUATION);
if (isDefenceActive(DEFENCE_TYPES.LLM_EVALUATION, defences)) {
console.debug("LLM evalutation defence active.");
defenceReport.isBlocked = true;
defenceReport.blockedReason =
4 changes: 2 additions & 2 deletions backend/src/email.ts
@@ -1,4 +1,4 @@
import { DefenceInfo } from "./models/defence";
import { DEFENCE_TYPES, DefenceInfo } from "./models/defence";
import { EmailInfo } from "./models/email";
import { getEmailWhitelistVar, isDefenceActive } from "./defence";
import { PHASE_NAMES } from "./models/phase";
@@ -11,7 +11,7 @@ function getEmailWhitelistValues(defences: DefenceInfo[]) {

// if defense active return the whitelist of emails and domains
function getEmailWhitelist(defences: DefenceInfo[]) {
if (!isDefenceActive("EMAIL_WHITELIST", defences)) {
if (!isDefenceActive(DEFENCE_TYPES.EMAIL_WHITELIST, defences)) {
return "As the email whitelist defence is not active, any email address can be emailed.";
} else {
return (
8 changes: 4 additions & 4 deletions backend/src/openai.ts
@@ -13,7 +13,7 @@ import {
OpenAIApi,
} from "openai";
import { CHAT_MODELS, ChatDefenceReport } from "./models/chat";
import { DefenceInfo } from "./models/defence";
import { DEFENCE_TYPES, DefenceInfo } from "./models/defence";
import { PHASE_NAMES } from "./models/phase";

// OpenAI config
@@ -165,8 +165,8 @@ async function chatGptCallFunction(
isAllowedToSendEmail = true;
} else {
// trigger email defence even if it is not active
defenceInfo.triggeredDefences.push("EMAIL_WHITELIST");
if (isDefenceActive("EMAIL_WHITELIST", defences)) {
defenceInfo.triggeredDefences.push(DEFENCE_TYPES.EMAIL_WHITELIST);
if (isDefenceActive(DEFENCE_TYPES.EMAIL_WHITELIST, defences)) {
// do not send email if defence is on and set to blocked
defenceInfo.isBlocked = true;
defenceInfo.blockedReason =
@@ -229,7 +229,7 @@ async function chatGptChatCompletion(
// system role is always active on phases
if (
currentPhase !== PHASE_NAMES.SANDBOX ||
isDefenceActive("SYSTEM_ROLE", defences)
isDefenceActive(DEFENCE_TYPES.SYSTEM_ROLE, defences)
) {
// check to see if there's already a system role
if (!chatHistory.find((message) => message.role === "system")) {
12 changes: 6 additions & 6 deletions backend/src/router.ts
@@ -11,7 +11,7 @@ import {
} from "./defence";
import { initQAModel } from "./langchain";
import { CHAT_MODELS, ChatHttpResponse } from "./models/chat";
import { DefenceConfig } from "./models/defence";
import { DEFENCE_TYPES, DefenceConfig } from "./models/defence";
import { chatGptSendMessage, setOpenAiApiKey, setGptModel } from "./openai";
import { retrievalQAPrePrompt } from "./promptTemplates";
import { PHASE_NAMES } from "./models/phase";
@@ -24,13 +24,13 @@ let prevPhase: PHASE_NAMES = PHASE_NAMES.SANDBOX;
// Activate a defence
router.post("/defence/activate", (req, res) => {
// id of the defence
const defenceId: string = req.body?.defenceId;
const defenceId: DEFENCE_TYPES = req.body?.defenceId;
if (defenceId) {
// activate the defence
req.session.defences = activateDefence(defenceId, req.session.defences);

// need to re-initialize QA model when turned on
if (defenceId === "QA_LLM_INSTRUCTIONS" && req.session.openAiApiKey) {
if (defenceId === DEFENCE_TYPES.QA_LLM_INSTRUCTIONS && req.session.openAiApiKey) {
console.debug(
"Activating qa llm instruction defence - reinitializing qa model"
);
@@ -47,12 +47,12 @@ router.post("/defence/activate", (req, res) => {
// Deactivate a defence
router.post("/defence/deactivate", (req, res) => {
// id of the defence
const defenceId: string = req.body?.defenceId;
const defenceId: DEFENCE_TYPES = req.body?.defenceId;
if (defenceId) {
// deactivate the defence
req.session.defences = deactivateDefence(defenceId, req.session.defences);

if (defenceId === "QA_LLM_INSTRUCTIONS" && req.session.openAiApiKey) {
if (defenceId === DEFENCE_TYPES.QA_LLM_INSTRUCTIONS && req.session.openAiApiKey) {
console.debug("Resetting QA model with default prompt");
initQAModel(req.session.openAiApiKey, getQALLMprePrompt(req.session.defences));
}
@@ -66,7 +66,7 @@ router.post("/defence/deactivate", (req, res) => {
// Configure a defence
router.post("/defence/configure", (req, res) => {
// id of the defence
const defenceId: string = req.body?.defenceId;
const defenceId: DEFENCE_TYPES = req.body?.defenceId;
const config: DefenceConfig[] = req.body?.config;
if (defenceId && config) {
// configure the defence
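
A note on the router changes above: typing req.body?.defenceId as DEFENCE_TYPES is a compile-time annotation only, so an arbitrary string in the request body would still reach activateDefence at runtime. If runtime validation were wanted, a small type guard along these lines could be added — a sketch under the assumption that DEFENCE_TYPES is a string enum; it is not part of this PR:

// Sketch only — not part of this PR. Narrows untrusted request input to a known defence id.
function isDefenceType(value: unknown): value is DEFENCE_TYPES {
  return (
    typeof value === "string" &&
    (Object.values(DEFENCE_TYPES) as string[]).includes(value)
  );
}

// Hypothetical usage inside the /defence/activate handler:
// const defenceId = req.body?.defenceId;
// if (isDefenceType(defenceId)) {
//   req.session.defences = activateDefence(defenceId, req.session.defences);
// }
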
17 changes: 9 additions & 8 deletions backend/test/integration/defences.test.ts
@@ -1,5 +1,6 @@
import { activateDefence, detectTriggeredDefences, getInitialDefences } from "../../src/defence";
import { initPromptEvaluationModel } from "../../src/langchain";
import { DEFENCE_TYPES } from "../../src/models/defence";

// Define a mock implementation for the createChatCompletion method
const mockCall = jest.fn();
@@ -31,14 +32,14 @@ test("GIVEN LLM_EVALUATION defence is active AND prompt is malicious WHEN detect

let defences = getInitialDefences();
// activate the defence
defences = activateDefence("LLM_EVALUATION", defences);
defences = activateDefence(DEFENCE_TYPES.LLM_EVALUATION, defences);
// create a malicious prompt
const message = "some kind of malicious prompt";
// detect triggered defences
const result = await detectTriggeredDefences(message, defences);
// check that the defence is triggered and the message is blocked
expect(result.isBlocked).toBe(true);
expect(result.triggeredDefences).toContain("LLM_EVALUATION");
expect(result.triggeredDefences).toContain(DEFENCE_TYPES.LLM_EVALUATION);
});

test("GIVEN LLM_EVALUATION defence is active AND prompt only triggers malice detection WHEN detectTriggeredDefences is called THEN defence is triggered AND defence is blocked", async () => {
@@ -53,14 +54,14 @@ test("GIVEN LLM_EVALUATION defence is active AND prompt only triggers malice det

let defences = getInitialDefences();
// activate the defence
defences = activateDefence("LLM_EVALUATION", defences);
defences = activateDefence(DEFENCE_TYPES.LLM_EVALUATION, defences);
// create a malicious prompt
const message = "some kind of malicious prompt";
// detect triggered defences
const result = await detectTriggeredDefences(message, defences);
// check that the defence is triggered and the message is blocked
expect(result.isBlocked).toBe(true);
expect(result.triggeredDefences).toContain("LLM_EVALUATION");
expect(result.triggeredDefences).toContain(DEFENCE_TYPES.LLM_EVALUATION);
});

test("GIVEN LLM_EVALUATION defence is active AND prompt only triggers prompt injection detection WHEN detectTriggeredDefences is called THEN defence is triggered AND defence is blocked", async () => {
@@ -75,14 +76,14 @@ test("GIVEN LLM_EVALUATION defence is active AND prompt only triggers prompt inj

let defences = getInitialDefences();
// activate the defence
defences = activateDefence("LLM_EVALUATION", defences);
defences = activateDefence(DEFENCE_TYPES.LLM_EVALUATION, defences);
// create a malicious prompt
const message = "some kind of malicious prompt";
// detect triggered defences
const result = await detectTriggeredDefences(message, defences);
// check that the defence is triggered and the message is blocked
expect(result.isBlocked).toBe(true);
expect(result.triggeredDefences).toContain("LLM_EVALUATION");
expect(result.triggeredDefences).toContain(DEFENCE_TYPES.LLM_EVALUATION);
});

test("GIVEN LLM_EVALUATION defence is active AND prompt not is malicious WHEN detectTriggeredDefences is called THEN defence is not triggered AND defence is not blocked", async () => {
@@ -97,7 +98,7 @@ test("GIVEN LLM_EVALUATION defence is active AND prompt not is malicious WHEN de

let defences = getInitialDefences();
// activate the defence
defences = activateDefence("LLM_EVALUATION", defences);
defences = activateDefence(DEFENCE_TYPES.LLM_EVALUATION, defences);
// create a malicious prompt
const message = "some kind of malicious prompt";
// detect triggered defences
@@ -124,5 +125,5 @@ test("GIVEN LLM_EVALUATION defence is not active AND prompt is malicious WHEN de
const result = await detectTriggeredDefences(message, defences);
// check that the defence is triggered and the message is blocked
expect(result.isBlocked).toBe(false);
expect(result.triggeredDefences).toContain("LLM_EVALUATION");
expect(result.triggeredDefences).toContain(DEFENCE_TYPES.LLM_EVALUATION);
});
16 changes: 8 additions & 8 deletions backend/test/integration/openai.test.ts
@@ -1,7 +1,7 @@
import { ChatCompletionRequestMessage } from "openai";
import { CHAT_MODELS } from "../../src/models/chat";
import { chatGptSendMessage } from "../../src/openai";
import { DefenceInfo } from "../../src/models/defence";
import { DEFENCE_TYPES, DefenceInfo } from "../../src/models/defence";
import { EmailInfo } from "../../src/models/email";
import { activateDefence, getInitialDefences } from "../../src/defence";

@@ -91,7 +91,7 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role
const gptModel = CHAT_MODELS.GPT_4;
const openAiApiKey = "sk-12345";

defences = activateDefence("SYSTEM_ROLE", defences);
defences = activateDefence(DEFENCE_TYPES.SYSTEM_ROLE, defences);

// Mock the createChatCompletion function
mockCreateChatCompletion.mockResolvedValueOnce({
@@ -154,7 +154,7 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role
const openAiApiKey = "sk-12345";

// activate the SYSTEM_ROLE defence
defences = activateDefence("SYSTEM_ROLE", defences);
defences = activateDefence(DEFENCE_TYPES.SYSTEM_ROLE, defences);

// Mock the createChatCompletion function
mockCreateChatCompletion.mockResolvedValueOnce({
@@ -292,7 +292,7 @@ test(
const gptModel = CHAT_MODELS.GPT_4;
const openAiApiKey = "sk-12345";

defences = activateDefence("SYSTEM_ROLE", defences);
defences = activateDefence(DEFENCE_TYPES.SYSTEM_ROLE, defences);

// Mock the createChatCompletion function
mockCreateChatCompletion.mockResolvedValueOnce({
@@ -410,7 +410,7 @@ test(
expect(reply?.defenceInfo.isBlocked).toBe(false);
// EMAIL_WHITELIST defence is triggered
expect(reply?.defenceInfo.triggeredDefences.length).toBe(1);
expect(reply?.defenceInfo.triggeredDefences[0]).toBe("EMAIL_WHITELIST");
expect(reply?.defenceInfo.triggeredDefences[0]).toBe(DEFENCE_TYPES.EMAIL_WHITELIST);

// restore the mock
mockCreateChatCompletion.mockRestore();
@@ -432,7 +432,7 @@ test(
const gptModel = CHAT_MODELS.GPT_4;
const openAiApiKey = "sk-12345";

defences = activateDefence("EMAIL_WHITELIST", defences);
defences = activateDefence(DEFENCE_TYPES.EMAIL_WHITELIST, defences);

// Mock the createChatCompletion function
mockCreateChatCompletion
@@ -486,7 +486,7 @@ test(
expect(reply?.defenceInfo.isBlocked).toBe(true);
// EMAIL_WHITELIST defence is triggered
expect(reply?.defenceInfo.triggeredDefences.length).toBe(1);
expect(reply?.defenceInfo.triggeredDefences[0]).toBe("EMAIL_WHITELIST");
expect(reply?.defenceInfo.triggeredDefences[0]).toBe(DEFENCE_TYPES.EMAIL_WHITELIST);

// restore the mock
mockCreateChatCompletion.mockRestore();
@@ -508,7 +508,7 @@ test(
const gptModel = CHAT_MODELS.GPT_4;
const openAiApiKey = "sk-12345";

defences = activateDefence("EMAIL_WHITELIST", defences);
defences = activateDefence(DEFENCE_TYPES.EMAIL_WHITELIST, defences);

// Mock the createChatCompletion function
mockCreateChatCompletion