From 688f7a8306d9268a0e79f55c9a1c9a1c325fcbe6 Mon Sep 17 00:00:00 2001 From: Noam Levy Date: Tue, 3 Dec 2024 19:03:55 +0200 Subject: [PATCH 1/9] fix activefence rail docs --- docs/user-guides/community/active-fence.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user-guides/community/active-fence.md b/docs/user-guides/community/active-fence.md index 688216f29..0c02b0d88 100644 --- a/docs/user-guides/community/active-fence.md +++ b/docs/user-guides/community/active-fence.md @@ -31,7 +31,7 @@ ActiveFence’s ActiveScore API gives flexibility in controlling the behavior of ```colang define flow activefence input moderation detailed - $result = execute call activefence api(text=$user_message) + $result = execute call activefence api if $result.violations.get("abusive_or_harmful.hate_speech", 0) > 0.8 bot inform cannot engage in abusive or harmful behavior From 17abbd94cdcfd268a92fc409a4331588fd4cf976 Mon Sep 17 00:00:00 2001 From: Lior Drihem Date: Thu, 19 Dec 2024 22:16:40 +0200 Subject: [PATCH 2/9] add prompt security integration --- .gitignore | 2 + docs/user-guides/community/prompt-security.md | 38 +++++ docs/user-guides/guardrails-library.md | 22 +++ examples/configs/prompt_security/README.md | 5 + examples/configs/prompt_security/config.yml | 13 ++ .../library/prompt_security/__init__.py | 14 ++ .../library/prompt_security/actions.py | 81 ++++++++++ .../library/prompt_security/flows.co | 22 +++ .../library/prompt_security/flows.v1.co | 20 +++ tests/test_prompt_security.py | 141 ++++++++++++++++++ 10 files changed, 358 insertions(+) create mode 100644 docs/user-guides/community/prompt-security.md create mode 100644 examples/configs/prompt_security/README.md create mode 100644 examples/configs/prompt_security/config.yml create mode 100644 nemoguardrails/library/prompt_security/__init__.py create mode 100644 nemoguardrails/library/prompt_security/actions.py create mode 100644 nemoguardrails/library/prompt_security/flows.co create mode 100644 
nemoguardrails/library/prompt_security/flows.v1.co create mode 100644 tests/test_prompt_security.py diff --git a/.gitignore b/.gitignore index a707cb164..11436c9a6 100644 --- a/.gitignore +++ b/.gitignore @@ -63,3 +63,5 @@ docs/**/config # Ignoring log files generated by tests firebase.json scratch.py + +.env \ No newline at end of file diff --git a/docs/user-guides/community/prompt-security.md b/docs/user-guides/community/prompt-security.md new file mode 100644 index 000000000..d5264d7f5 --- /dev/null +++ b/docs/user-guides/community/prompt-security.md @@ -0,0 +1,38 @@ +# Prompt Security Integration + +[Prompt Security AI](https://prompt.security/?utm_medium=github&utm_campaign=nemo-guardrails) allows you to protect LLM interaction. This integration enables NeMo Guardrails to use Prompt Security to protect input and output flows. + +## Setup + +1. Ensure that you have access to Prompt Security API server (SaaS or on-prem). + +2. Update your `config.yml` file to include the Private AI settings: + +```yaml +rails: + input: + flows: + - protect prompt + output: + flows: + - protect response +``` + +Set the `PS_PROTECT_URL` and `PS_APP_ID` environment variables. + +## Usage + +Once configured, the Prompt Security integration will automatically: + +1. Protect prompts before they are processed by the LLM. +2. Protect LLM outputs before they are sent back to the user. + +The `protect_text` action in `nemoguardrails/library/prompt_security/actions.py` handles the protection process. + +## Error Handling + +If the Prompt Security API request fails, it's operating in a fail-open mode (not blocking the prompt/response). + +## Notes + +For more information on Prompt Security and capabilities, please refer to the [Prompt Security documentation](https://prompt.security/?utm_medium=github&utm_campaign=nemo-guardrails). 
diff --git a/docs/user-guides/guardrails-library.md b/docs/user-guides/guardrails-library.md index 14b84fe44..a2a2d9c34 100644 --- a/docs/user-guides/guardrails-library.md +++ b/docs/user-guides/guardrails-library.md @@ -23,6 +23,7 @@ NeMo Guardrails comes with a library of built-in guardrails that you can easily - [Cleanlab Trustworthiness Score](#cleanlab) - [GCP Text Moderation](#gcp-text-moderation) - [Private AI PII detection](#private-ai-pii-detection) + - [Prompt Security Protection](#prompt-security-protection) - OpenAI Moderation API - *[COMING SOON]* 4. Other @@ -772,6 +773,27 @@ rails: For more details, check out the [Private AI Integration](./community/privateai.md) page. + +### Prompt Security Protection + +NeMo Guardrails supports using [Prompt Security API](https://prompt.security/?utm_medium=github&utm_campaign=nemo-guardrails) for protecting input and output flows. + +To activate the protection, you need to set the `PS_PROTECT_URL` and `PS_APP_ID` environment variables. + +#### Example usage + +```yaml +rails: + input: + flows: + - protect prompt + output: + flows: + - protect response +``` + +For more details, check out the [Prompt Security Integration](./community/prompt-security.md) page. + ## Other ### Jailbreak Detection Heuristics diff --git a/examples/configs/prompt_security/README.md b/examples/configs/prompt_security/README.md new file mode 100644 index 000000000..d1e8ba0fc --- /dev/null +++ b/examples/configs/prompt_security/README.md @@ -0,0 +1,5 @@ +# Prompt Security Configuration Example + +This example contains configuration files for using Prompt Security in your NeMo Guardrails project. + +For more details on the Prompt Security integration, see [Prompt Security Integration User Guide](../../../docs/user-guides/community/prompt-security.md). 
diff --git a/examples/configs/prompt_security/config.yml b/examples/configs/prompt_security/config.yml new file mode 100644 index 000000000..b33707008 --- /dev/null +++ b/examples/configs/prompt_security/config.yml @@ -0,0 +1,13 @@ +models: + - type: main + engine: openai + model: gpt-4o + +rails: + input: + flows: + - protect prompt + + output: + flows: + - protect response diff --git a/nemoguardrails/library/prompt_security/__init__.py b/nemoguardrails/library/prompt_security/__init__.py new file mode 100644 index 000000000..9ba9d4310 --- /dev/null +++ b/nemoguardrails/library/prompt_security/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemoguardrails/library/prompt_security/actions.py b/nemoguardrails/library/prompt_security/actions.py new file mode 100644 index 000000000..db997d625 --- /dev/null +++ b/nemoguardrails/library/prompt_security/actions.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Prompt/Response protection using Prompt Security.""" + +import logging +import os +from typing import Optional + +import httpx + +from nemoguardrails.actions import action + +log = logging.getLogger(__name__) + + +async def ps_protect_api_async( + ps_protect_url: str, + ps_app_id: str, + prompt: Optional[str] = None, + system_prompt: Optional[str] = None, + response: Optional[str] = None, + user: Optional[str] = None, +): + headers = { + "APP-ID": ps_app_id, + "Content-Type": "application/json", + } + payload = { + "prompt": prompt, + "system_prompt": system_prompt, + "response": response, + "user": user, + } + async with httpx.AsyncClient() as client: + ret = await client.post(ps_protect_url, headers=headers, json=payload) + return ret.json() + + +@action(is_system_action=True) +async def protect_text(source: str, text: str): + """Protects the given text. + + Args + source: The source for the text, i.e. "input", "output". + text: The text to check. + + Returns + True if text should be blocked, False otherwise. 
+ """ + + ps_protect_url = os.getenv("PS_PROTECT_URL") + if not ps_protect_url: + raise ValueError("PS_PROTECT_URL env variable required for Prompt Security.") + + ps_app_id = os.getenv("PS_APP_ID") + if not ps_app_id: + raise ValueError("PS_APP_ID env variable required for Prompt Security.") + + if source == "input": + response = await ps_protect_api_async(ps_protect_url, ps_app_id, text) + elif source == "output": + response = await ps_protect_api_async( + ps_protect_url, ps_app_id, None, None, text + ) + else: + raise ValueError(f"The flow, '{source}', is not supported by Prompt Security.") + + return response["result"]["action"] == "block" diff --git a/nemoguardrails/library/prompt_security/flows.co b/nemoguardrails/library/prompt_security/flows.co new file mode 100644 index 000000000..bee5b3ff7 --- /dev/null +++ b/nemoguardrails/library/prompt_security/flows.co @@ -0,0 +1,22 @@ +# INPUT RAILS + +@active +flow protect prompt + """Check if the prompt is valid according to Prompt Security.""" + $invalid = await protect_text(source="input", text=$user_message) + + if $invalid + bot inform answer unknown + abort + + +# OUTPUT RAILS + +@active +flow protect response + """Check if the response is valid according to Prompt Security.""" + $invalid = await protect_text(source="output", text=$bot_message) + + if $invalid + bot inform answer unknown + abort \ No newline at end of file diff --git a/nemoguardrails/library/prompt_security/flows.v1.co b/nemoguardrails/library/prompt_security/flows.v1.co new file mode 100644 index 000000000..361a41471 --- /dev/null +++ b/nemoguardrails/library/prompt_security/flows.v1.co @@ -0,0 +1,20 @@ +# INPUT RAILS + +define subflow protect prompt + """Check if the prompt is valid according to Prompt Security.""" + $invalid = execute protect_text(source="input", text=$user_message) + + if $invalid + bot inform answer unknown + stop + + +# OUTPUT RAILS + +define subflow protect response + """Check if the response is valid according to 
Prompt Security.""" + $invalid = execute protect_text(source="output", text=$bot_message) + + if $invalid + bot inform answer unknown + stop diff --git a/tests/test_prompt_security.py b/tests/test_prompt_security.py new file mode 100644 index 000000000..aba84b234 --- /dev/null +++ b/tests/test_prompt_security.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nemoguardrails import RailsConfig +from nemoguardrails.actions.actions import ActionResult, action +from tests.utils import TestChat + + +@action() +def retrieve_relevant_chunks(): + context_updates = {"relevant_chunks": "Mock retrieved context."} + + return ActionResult( + return_value=context_updates["relevant_chunks"], + context_updates=context_updates, + ) + + +def mock_protect_text(return_value=True): + def mock_request(*args, **kwargs): + return return_value + + return mock_request + + +@pytest.mark.unit +def test_prompt_security_protection_disabled(): + config = RailsConfig.from_content( + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define bot inform answer unknown + "I can't answer that." + """, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + ' "Hi! 
My name is John as well."', + ], + ) + + chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") + chat.app.register_action(mock_protect_text(True), "protect_text") + chat >> "Hi! I am Mr. John! And my email is test@gmail.com" + chat << "Hi! My name is John as well." + + +@pytest.mark.unit +def test_prompt_security_protection_input(): + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + input: + flows: + - protect prompt + """, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define bot inform answer unknown + "I can't answer that." + """, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + ' "Hi! My name is John as well."', + ], + ) + + chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") + chat.app.register_action(mock_protect_text(True), "protect_text") + chat >> "Hi! I am Mr. John! And my email is test@gmail.com" + chat << "I can't answer that." + + +@pytest.mark.unit +def test_prompt_security_protection_output(): + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + output: + flows: + - protect response + """, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define bot inform answer unknown + "I can't answer that." + """, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + ' "Hi! My name is John as well."', + ], + ) + + chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") + chat.app.register_action(mock_protect_text(True), "protect_text") + chat >> "Hi!" + chat << "I can't answer that." 
From 1a3aee725315d2c1dd0604b2874e8481b9b2b31d Mon Sep 17 00:00:00 2001 From: Lior Drihem Date: Sat, 21 Dec 2024 15:42:57 +0200 Subject: [PATCH 3/9] use context and try to modify user_message or bot_message when needed --- .../library/prompt_security/actions.py | 26 +++++++++++-------- .../library/prompt_security/flows.co | 4 +-- .../library/prompt_security/flows.v1.co | 4 +-- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/nemoguardrails/library/prompt_security/actions.py b/nemoguardrails/library/prompt_security/actions.py index db997d625..8921686fe 100644 --- a/nemoguardrails/library/prompt_security/actions.py +++ b/nemoguardrails/library/prompt_security/actions.py @@ -34,6 +34,8 @@ async def ps_protect_api_async( response: Optional[str] = None, user: Optional[str] = None, ): + """Calls Prompt Security Protect API asynchronously.""" + headers = { "APP-ID": ps_app_id, "Content-Type": "application/json", @@ -50,12 +52,8 @@ async def ps_protect_api_async( @action(is_system_action=True) -async def protect_text(source: str, text: str): - """Protects the given text. - - Args - source: The source for the text, i.e. "input", "output". - text: The text to check. +async def protect_text(context: Optional[dict] = None): + """Protects the given user_message or bot_message. Returns True if text should be blocked, False otherwise. 
@@ -69,13 +67,19 @@ async def protect_text(source: str, text: str): if not ps_app_id: raise ValueError("PS_APP_ID env variable required for Prompt Security.") - if source == "input": - response = await ps_protect_api_async(ps_protect_url, ps_app_id, text) - elif source == "output": + if context.get("bot_message"): + response = await ps_protect_api_async( + ps_protect_url, ps_app_id, None, None, context["bot_message"] + ) + if response["result"]["action"] == "modify": + context["bot_message"] = response["result"]["response"]["modified_text"] + elif context.get("user_message"): response = await ps_protect_api_async( - ps_protect_url, ps_app_id, None, None, text + ps_protect_url, ps_app_id, context["user_message"] ) + if response["result"]["action"] == "modify": + context["user_message"] = response["result"]["prompt"]["modified_text"] else: - raise ValueError(f"The flow, '{source}', is not supported by Prompt Security.") + raise ValueError(f"No user_message or bot_message in context: {context}") return response["result"]["action"] == "block" diff --git a/nemoguardrails/library/prompt_security/flows.co b/nemoguardrails/library/prompt_security/flows.co index bee5b3ff7..a8886a545 100644 --- a/nemoguardrails/library/prompt_security/flows.co +++ b/nemoguardrails/library/prompt_security/flows.co @@ -3,7 +3,7 @@ @active flow protect prompt """Check if the prompt is valid according to Prompt Security.""" - $invalid = await protect_text(source="input", text=$user_message) + $invalid = await protect_text(context=$context) if $invalid bot inform answer unknown @@ -15,7 +15,7 @@ flow protect prompt @active flow protect response """Check if the response is valid according to Prompt Security.""" - $invalid = await protect_text(source="output", text=$bot_message) + $invalid = await protect_text(context=$context) if $invalid bot inform answer unknown diff --git a/nemoguardrails/library/prompt_security/flows.v1.co b/nemoguardrails/library/prompt_security/flows.v1.co index 
361a41471..f56dc635d 100644 --- a/nemoguardrails/library/prompt_security/flows.v1.co +++ b/nemoguardrails/library/prompt_security/flows.v1.co @@ -2,7 +2,7 @@ define subflow protect prompt """Check if the prompt is valid according to Prompt Security.""" - $invalid = execute protect_text(source="input", text=$user_message) + $invalid = execute protect_text if $invalid bot inform answer unknown @@ -13,7 +13,7 @@ define subflow protect prompt define subflow protect response """Check if the response is valid according to Prompt Security.""" - $invalid = execute protect_text(source="output", text=$bot_message) + $invalid = execute protect_text if $invalid bot inform answer unknown From 87d1a541b569a786a3f173e804f730ec48814398 Mon Sep 17 00:00:00 2001 From: Lior Drihem Date: Sat, 4 Jan 2025 15:36:41 +0200 Subject: [PATCH 4/9] option to modify user_message or bot_message --- .../library/prompt_security/actions.py | 12 +++++++++--- .../library/prompt_security/flows.co | 18 ++++++++++-------- .../library/prompt_security/flows.v1.co | 14 ++++++++------ 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/nemoguardrails/library/prompt_security/actions.py b/nemoguardrails/library/prompt_security/actions.py index 8921686fe..423991f3a 100644 --- a/nemoguardrails/library/prompt_security/actions.py +++ b/nemoguardrails/library/prompt_security/actions.py @@ -72,14 +72,20 @@ async def protect_text(context: Optional[dict] = None): ps_protect_url, ps_app_id, None, None, context["bot_message"] ) if response["result"]["action"] == "modify": - context["bot_message"] = response["result"]["response"]["modified_text"] + response["result"]["modified_text"] = response["result"]["response"][ + "modified_text" + ] elif context.get("user_message"): response = await ps_protect_api_async( ps_protect_url, ps_app_id, context["user_message"] ) if response["result"]["action"] == "modify": - context["user_message"] = response["result"]["prompt"]["modified_text"] + 
response["result"]["modified_text"] = response["result"]["prompt"][ + "modified_text" + ] else: raise ValueError(f"No user_message or bot_message in context: {context}") - return response["result"]["action"] == "block" + response["result"]["is_blocked"] = response["result"]["action"] == "block" + response["result"]["is_modified"] = response["result"]["action"] == "modify" + return response["result"] diff --git a/nemoguardrails/library/prompt_security/flows.co b/nemoguardrails/library/prompt_security/flows.co index a8886a545..62cd11449 100644 --- a/nemoguardrails/library/prompt_security/flows.co +++ b/nemoguardrails/library/prompt_security/flows.co @@ -3,11 +3,12 @@ @active flow protect prompt """Check if the prompt is valid according to Prompt Security.""" - $invalid = await protect_text(context=$context) - - if $invalid + $result = await protect_text(context=$context) + if $result["is_blocked"] bot inform answer unknown - abort + stop + else if $result["is_modified"] + $user_message = $result["modified_text"] # OUTPUT RAILS @@ -15,8 +16,9 @@ flow protect prompt @active flow protect response """Check if the response is valid according to Prompt Security.""" - $invalid = await protect_text(context=$context) - - if $invalid + $result = await protect_text(context=$context) + if $result["is_blocked"] bot inform answer unknown - abort \ No newline at end of file + stop + else if $result["is_modified"] + $bot_message = $result["modified_text"] \ No newline at end of file diff --git a/nemoguardrails/library/prompt_security/flows.v1.co b/nemoguardrails/library/prompt_security/flows.v1.co index f56dc635d..484c8e62b 100644 --- a/nemoguardrails/library/prompt_security/flows.v1.co +++ b/nemoguardrails/library/prompt_security/flows.v1.co @@ -2,19 +2,21 @@ define subflow protect prompt """Check if the prompt is valid according to Prompt Security.""" - $invalid = execute protect_text - - if $invalid + $result = execute protect_text + if $result["is_blocked"] bot inform answer 
unknown stop + else if $result["is_modified"] + $user_message = $result["modified_text"] # OUTPUT RAILS define subflow protect response """Check if the response is valid according to Prompt Security.""" - $invalid = execute protect_text - - if $invalid + $result = execute protect_text + if $result["is_blocked"] bot inform answer unknown stop + else if $result["is_modified"] + $bot_message = $result["modified_text"] From 49762fac33abbdd951ec0034c37a42e16d7039ec Mon Sep 17 00:00:00 2001 From: Lior Drihem Date: Thu, 9 Jan 2025 16:20:19 +0200 Subject: [PATCH 5/9] add : --- nemoguardrails/library/prompt_security/actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemoguardrails/library/prompt_security/actions.py b/nemoguardrails/library/prompt_security/actions.py index 423991f3a..fc1afcaab 100644 --- a/nemoguardrails/library/prompt_security/actions.py +++ b/nemoguardrails/library/prompt_security/actions.py @@ -55,7 +55,7 @@ async def ps_protect_api_async( async def protect_text(context: Optional[dict] = None): """Protects the given user_message or bot_message. - Returns + Returns: True if text should be blocked, False otherwise. 
""" From 7bd343b4f929128c86f5f8c9a29afdf822bb9716 Mon Sep 17 00:00:00 2001 From: Lior Drihem Date: Sat, 18 Jan 2025 17:28:45 +0200 Subject: [PATCH 6/9] resolve pull request comments --- docs/user-guides/community/prompt-security.md | 8 +- .../library/prompt_security/actions.py | 95 +++++++++++++------ .../library/prompt_security/flows.co | 4 +- .../library/prompt_security/flows.v1.co | 4 +- 4 files changed, 75 insertions(+), 36 deletions(-) diff --git a/docs/user-guides/community/prompt-security.md b/docs/user-guides/community/prompt-security.md index d5264d7f5..53ee3b631 100644 --- a/docs/user-guides/community/prompt-security.md +++ b/docs/user-guides/community/prompt-security.md @@ -2,10 +2,14 @@ [Prompt Security AI](https://prompt.security/?utm_medium=github&utm_campaign=nemo-guardrails) allows you to protect LLM interaction. This integration enables NeMo Guardrails to use Prompt Security to protect input and output flows. +You'll need to set the following env variables to work with Prompt Security: + +1. PS_PROTECT_URL - This is the URL of the protect endpoint given by Prompt Security. This will look like https://[REGION].prompt.security/api/protect where REGION is eu, useast or apac +2. PS_APP_ID - This is the application ID given by Prompt Security (similar to an API key). You can get it from admin portal at https://[REGION].prompt.security/ where REGION is eu, useast or apac + ## Setup 1. Ensure that you have access to Prompt Security API server (SaaS or on-prem). - 2. Update your `config.yml` file to include the Private AI settings: ```yaml @@ -18,7 +22,7 @@ rails: - protect response ``` -Set the `PS_PROTECT_URL` and `PS_APP_ID` environment variables. +Don't forget to set the `PS_PROTECT_URL` and `PS_APP_ID` environment variables. 
## Usage diff --git a/nemoguardrails/library/prompt_security/actions.py b/nemoguardrails/library/prompt_security/actions.py index fc1afcaab..ff663a68b 100644 --- a/nemoguardrails/library/prompt_security/actions.py +++ b/nemoguardrails/library/prompt_security/actions.py @@ -34,7 +34,29 @@ async def ps_protect_api_async( response: Optional[str] = None, user: Optional[str] = None, ): - """Calls Prompt Security Protect API asynchronously.""" + """Calls Prompt Security Protect API asynchronously. + + Args: + ps_protect_url: the URL of the protect endpoint given by Prompt Security. + URL is https://[REGION].prompt.security/api/protect where REGION is eu, useast or apac + + ps_app_id: the application ID given by Prompt Security (similar to an API key). + Get it from the admin portal at https://[REGION].prompt.security/ where REGION is eu, useast or apac + + prompt: the user message to protect. + + system_prompt: the system message for context. + + response: the bot message to protect. + + user: the user ID or username for context. + + Returns: + A dictionary with the following items: + - is_blocked: True if the text should be blocked, False otherwise. + - is_modified: True if the text should be modified, False otherwise. + - modified_text: The modified text if is_modified is True, None otherwise. 
+ """ headers = { "APP-ID": ps_app_id, @@ -47,45 +69,58 @@ async def ps_protect_api_async( "user": user, } async with httpx.AsyncClient() as client: - ret = await client.post(ps_protect_url, headers=headers, json=payload) - return ret.json() + modified_text = None + ps_action = "log" + try: + ret = await client.post(ps_protect_url, headers=headers, json=payload) + res = ret.json() + ps_action = res.get("result", {}).get("action", "log") + if ps_action == "modify": + key = "response" if response else "prompt" + modified_text = res.get("result", {}).get(key, {}).get("modified_text") + except Exception as e: + log.error("Error calling Prompt Security Protect API: %s", e) + return { + "is_blocked": ps_action == "block", + "is_modified": ps_action == "modify", + "modified_text": modified_text, + } @action(is_system_action=True) -async def protect_text(context: Optional[dict] = None): - """Protects the given user_message or bot_message. - +async def protect_text( + user_prompt: Optional[str] = None, bot_response: Optional[str] = None +): + """Protects the given user_prompt or bot_response. + Args: + user_prompt: The user message to protect. + bot_response: The bot message to protect. Returns: - True if text should be blocked, False otherwise. + A dictionary with the following items: + - is_blocked: True if the text should be blocked, False otherwise. + - is_modified: True if the text should be modified, False otherwise. + - modified_text: The modified text if is_modified is True, None otherwise. + Raises: + ValueError is returned in one of the following cases: + 1. If PS_PROTECT_URL env variable is not set. + 2. If PS_APP_ID env variable is not set. + 3. If no user_prompt and no bot_response is provided. 
""" ps_protect_url = os.getenv("PS_PROTECT_URL") if not ps_protect_url: - raise ValueError("PS_PROTECT_URL env variable required for Prompt Security.") + raise ValueError("PS_PROTECT_URL env variable is required for Prompt Security.") ps_app_id = os.getenv("PS_APP_ID") if not ps_app_id: - raise ValueError("PS_APP_ID env variable required for Prompt Security.") + raise ValueError("PS_APP_ID env variable is required for Prompt Security.") - if context.get("bot_message"): - response = await ps_protect_api_async( - ps_protect_url, ps_app_id, None, None, context["bot_message"] + if bot_response: + return await ps_protect_api_async( + ps_protect_url, ps_app_id, None, None, bot_response ) - if response["result"]["action"] == "modify": - response["result"]["modified_text"] = response["result"]["response"][ - "modified_text" - ] - elif context.get("user_message"): - response = await ps_protect_api_async( - ps_protect_url, ps_app_id, context["user_message"] - ) - if response["result"]["action"] == "modify": - response["result"]["modified_text"] = response["result"]["prompt"][ - "modified_text" - ] - else: - raise ValueError(f"No user_message or bot_message in context: {context}") - - response["result"]["is_blocked"] = response["result"]["action"] == "block" - response["result"]["is_modified"] = response["result"]["action"] == "modify" - return response["result"] + + if user_prompt: + return await ps_protect_api_async(ps_protect_url, ps_app_id, user_prompt) + + raise ValueError("Nither user_message nor bot_message was provided") diff --git a/nemoguardrails/library/prompt_security/flows.co b/nemoguardrails/library/prompt_security/flows.co index 62cd11449..7917f3c74 100644 --- a/nemoguardrails/library/prompt_security/flows.co +++ b/nemoguardrails/library/prompt_security/flows.co @@ -3,7 +3,7 @@ @active flow protect prompt """Check if the prompt is valid according to Prompt Security.""" - $result = await protect_text(context=$context) + $result = await 
protect_text(user_prompt=$user_message) if $result["is_blocked"] bot inform answer unknown stop @@ -16,7 +16,7 @@ flow protect prompt @active flow protect response """Check if the response is valid according to Prompt Security.""" - $result = await protect_text(context=$context) + $result = await protect_text(bot_response=$bot_message) if $result["is_blocked"] bot inform answer unknown stop diff --git a/nemoguardrails/library/prompt_security/flows.v1.co b/nemoguardrails/library/prompt_security/flows.v1.co index 484c8e62b..04b747d16 100644 --- a/nemoguardrails/library/prompt_security/flows.v1.co +++ b/nemoguardrails/library/prompt_security/flows.v1.co @@ -2,7 +2,7 @@ define subflow protect prompt """Check if the prompt is valid according to Prompt Security.""" - $result = execute protect_text + $result = execute protect_text(user_prompt=$user_message) if $result["is_blocked"] bot inform answer unknown stop @@ -14,7 +14,7 @@ define subflow protect prompt define subflow protect response """Check if the response is valid according to Prompt Security.""" - $result = execute protect_text + $result = execute protect_text(bot_response=$bot_message) if $result["is_blocked"] bot inform answer unknown stop From b908e7467cdfdb3e32c8d3271160ab078482567c Mon Sep 17 00:00:00 2001 From: Lior Drihem Date: Thu, 23 Jan 2025 10:05:53 +0200 Subject: [PATCH 7/9] typo --- nemoguardrails/library/prompt_security/actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemoguardrails/library/prompt_security/actions.py b/nemoguardrails/library/prompt_security/actions.py index ff663a68b..b37135eb3 100644 --- a/nemoguardrails/library/prompt_security/actions.py +++ b/nemoguardrails/library/prompt_security/actions.py @@ -123,4 +123,4 @@ async def protect_text( if user_prompt: return await ps_protect_api_async(ps_protect_url, ps_app_id, user_prompt) - raise ValueError("Nither user_message nor bot_message was provided") + raise ValueError("Neither user_message nor bot_message 
was provided") From 11ebd9d89468a8be08b0547bb9ba69e1f109b257 Mon Sep 17 00:00:00 2001 From: Lior Drihem Date: Sat, 25 Jan 2025 17:08:58 +0200 Subject: [PATCH 8/9] fix prompt security pytest --- tests/test_prompt_security.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/tests/test_prompt_security.py b/tests/test_prompt_security.py index aba84b234..3676e4165 100644 --- a/tests/test_prompt_security.py +++ b/tests/test_prompt_security.py @@ -16,21 +16,10 @@ import pytest from nemoguardrails import RailsConfig -from nemoguardrails.actions.actions import ActionResult, action from tests.utils import TestChat -@action() -def retrieve_relevant_chunks(): - context_updates = {"relevant_chunks": "Mock retrieved context."} - - return ActionResult( - return_value=context_updates["relevant_chunks"], - context_updates=context_updates, - ) - - -def mock_protect_text(return_value=True): +def mock_protect_text(return_value): def mock_request(*args, **kwargs): return return_value @@ -61,8 +50,9 @@ def test_prompt_security_protection_disabled(): ], ) - chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") - chat.app.register_action(mock_protect_text(True), "protect_text") + chat.app.register_action( + mock_protect_text({"is_blocked": True, "is_modified": False}), "protect_text" + ) chat >> "Hi! I am Mr. John! And my email is test@gmail.com" chat << "Hi! My name is John as well." @@ -98,8 +88,9 @@ def test_prompt_security_protection_input(): ], ) - chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") - chat.app.register_action(mock_protect_text(True), "protect_text") + chat.app.register_action( + mock_protect_text({"is_blocked": True, "is_modified": False}), "protect_text" + ) chat >> "Hi! I am Mr. John! And my email is test@gmail.com" chat << "I can't answer that." 
@@ -135,7 +126,8 @@ def test_prompt_security_protection_output(): ], ) - chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") - chat.app.register_action(mock_protect_text(True), "protect_text") + chat.app.register_action( + mock_protect_text({"is_blocked": True, "is_modified": False}), "protect_text" + ) chat >> "Hi!" chat << "I can't answer that." From 73846d472ec2a8b16cf04f6db008c931e0b54da2 Mon Sep 17 00:00:00 2001 From: Lior Drihem Date: Mon, 27 Jan 2025 16:23:12 +0200 Subject: [PATCH 9/9] fix issue found by pre-commit --- .gitignore | 2 +- nemoguardrails/library/prompt_security/flows.co | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 11436c9a6..560b6f5d4 100644 --- a/.gitignore +++ b/.gitignore @@ -64,4 +64,4 @@ docs/**/config firebase.json scratch.py -.env \ No newline at end of file +.env diff --git a/nemoguardrails/library/prompt_security/flows.co b/nemoguardrails/library/prompt_security/flows.co index 7917f3c74..6d5d691dc 100644 --- a/nemoguardrails/library/prompt_security/flows.co +++ b/nemoguardrails/library/prompt_security/flows.co @@ -21,4 +21,4 @@ flow protect response bot inform answer unknown stop else if $result["is_modified"] - $bot_message = $result["modified_text"] \ No newline at end of file + $bot_message = $result["modified_text"]