diff --git a/docs/user-guides/community/privateai.md b/docs/user-guides/community/privateai.md index b305d7d53..226457d00 100644 --- a/docs/user-guides/community/privateai.md +++ b/docs/user-guides/community/privateai.md @@ -1,6 +1,6 @@ # Private AI Integration -[Private AI](https://docs.private-ai.com/?utm_medium=github&utm_campaign=nemo-guardrails) allows you to detect and mask Personally Identifiable Information (PII) in your data. This integration enables NeMo Guardrails to use Private AI for PII detection in input, output and retrieval flows. +[Private AI](https://docs.private-ai.com/?utm_medium=github&utm_campaign=nemo-guardrails) allows you to detect and mask Personally Identifiable Information (PII) in your data. This integration enables NeMo Guardrails to use Private AI for PII detection and masking in input, output, and retrieval flows. ## Setup @@ -8,6 +8,8 @@ 2. Update your `config.yml` file to include the Private AI settings: +**PII detection config** + ```yaml rails: config: @@ -31,19 +33,48 @@ rails: - detect pii on output ``` +The detection flow will not let the input/output/retrieval text pass if PII is detected. + +**PII masking config** + +```yaml +rails: + config: + privateai: + server_endpoint: http://your-privateai-api-endpoint/process/text # Replace this with your Private AI process text endpoint + input: + entities: # If no entity is specified here, all supported entities will be detected by default. + - NAME_FAMILY + - LOCATION_ADDRESS_STREET + - EMAIL_ADDRESS + output: + entities: + - NAME_FAMILY + - LOCATION_ADDRESS_STREET + - EMAIL_ADDRESS + input: + flows: + - mask pii on input + output: + flows: + - mask pii on output +``` + +The masking flow will mask the PII in the input/output/retrieval text before it is sent to the LLM/user. For example, `Hi John Doe, my email is john.doe@example.com` will be converted to `Hi [NAME], my email is [EMAIL_ADDRESS]`.
+ Replace `http://your-privateai-api-endpoint/process/text` with your actual Private AI process text endpoint and set the `PAI_API_KEY` environment variable if you're using the Private AI cloud API. 3. You can customize the `entities` list under both `input` and `output` to include the PII types you want to detect. A full list of supported entities can be found [here](https://docs.private-ai.com/entities/?utm_medium=github&utm_campaign=nemo-guardrails). ## Usage -Once configured, the Private AI integration will automatically: +Once configured, the Private AI integration can automatically: -1. Detect PII in user inputs before they are processed by the LLM. -2. Detect PII in LLM outputs before they are sent back to the user. -3. Detect PII in retrieved chunks before they are sent to the LLM. +1. Detect or mask PII in user inputs before they are processed by the LLM. +2. Detect or mask PII in LLM outputs before they are sent back to the user. +3. Detect or mask PII in retrieved chunks before they are sent to the LLM. -The `detect_pii` action in `nemoguardrails/library/privateai/actions.py` handles the PII detection process. +The `detect_pii` and `mask_pii` actions in `nemoguardrails/library/privateai/actions.py` handle the PII detection and masking processes, respectively. ## Customization @@ -56,6 +87,6 @@ If the Private AI detection API request fails, the system will assume PII is pre ## Notes - Ensure that your Private AI process text endpoint is properly set up and accessible from your NeMo Guardrails environment. -- The integration currently supports PII detection only. +- The integration currently supports PII detection and masking. For more information on Private AI and its capabilities, please refer to the [Private AI documentation](https://docs.private-ai.com/?utm_medium=github&utm_campaign=nemo-guardrails). 
diff --git a/docs/user-guides/guardrails-library.md b/docs/user-guides/guardrails-library.md index 14b84fe44..051c2bafc 100644 --- a/docs/user-guides/guardrails-library.md +++ b/docs/user-guides/guardrails-library.md @@ -694,9 +694,9 @@ For more details, check out the [GCP Text Moderation](https://github.com/NVIDIA/ ### Private AI PII Detection -NeMo Guardrails supports using [Private AI API](https://docs.private-ai.com/?utm_medium=github&utm_campaign=nemo-guardrails) for PII detection in input, output and retrieval flows. +NeMo Guardrails supports using [Private AI API](https://docs.private-ai.com/?utm_medium=github&utm_campaign=nemo-guardrails) for PII detection and masking in input, output, and retrieval flows. -To activate the PII detection, you need specify `server_endpoint`, and the entities that you want to detect. You'll also need to set the `PAI_API_KEY` environment variable if you're using the Private AI cloud API. +To activate the PII detection or masking, you need to specify `server_endpoint` and the entities that you want to detect or mask. You'll also need to set the `PAI_API_KEY` environment variable if you're using the Private AI cloud API. ```yaml rails: @@ -717,6 +717,8 @@ rails: #### Example usage +**PII detection** + ```yaml rails: input: @@ -730,44 +732,19 @@ rails: - detect pii on retrieval ``` -For more details, check out the [Private AI Integration](https://github.com/NVIDIA/NeMo-Guardrails/blob/develop/docs/user-guides/community/privateai.md) page. - -### Private AI PII Detection - -NeMo Guardrails supports using [Private AI API](https://docs.private-ai.com/?utm_medium=github&utm_campaign=nemo-guardrails) for PII detection in input, output and retrieval flows. - -To activate the PII detection, you need specify `server_endpoint`, and the entities that you want to detect. You'll also need to set the `PAI_API_KEY` environment variable if you're using the Private AI cloud API.
- -```yaml -rails: - config: - privateai: - server_endpoint: http://your-privateai-api-endpoint/process/text # Replace this with your Private AI process text endpoint - input: - entities: # If no entity is specified here, all supported entities will be detected by default. - - NAME_FAMILY - - EMAIL_ADDRESS - ... - output: - entities: - - NAME_FAMILY - - EMAIL_ADDRESS - ... -``` - -#### Example usage +**PII masking** ```yaml rails: input: flows: - - detect pii on input + - mask pii on input output: flows: - - detect pii on output + - mask pii on output retrieval: flows: - - detect pii on retrieval + - mask pii on retrieval ``` For more details, check out the [Private AI Integration](./community/privateai.md) page. diff --git a/examples/configs/privateai/pii_masking/config.yml b/examples/configs/privateai/pii_masking/config.yml new file mode 100644 index 000000000..9c8b76c6d --- /dev/null +++ b/examples/configs/privateai/pii_masking/config.yml @@ -0,0 +1,26 @@ +models: + - type: main + engine: openai + model: gpt-3.5-turbo-instruct + +rails: + config: + privateai: + server_endpoint: https://api.private-ai.com/cloud/v3/process/text + input: + entities: + - NAME_FAMILY + - LOCATION_ADDRESS_STREET + - EMAIL_ADDRESS + output: + entities: # If no entity is specified here, all supported entities will be masked by default. + - NAME_FAMILY + - LOCATION_ADDRESS_STREET + - EMAIL_ADDRESS + input: + flows: + - mask pii on input + + output: + flows: + - mask pii on output diff --git a/examples/notebooks/privateai_pii_detection.ipynb b/examples/notebooks/privateai_pii_detection.ipynb index 146541a80..5f2b5e412 100644 --- a/examples/notebooks/privateai_pii_detection.ipynb +++ b/examples/notebooks/privateai_pii_detection.ipynb @@ -6,19 +6,26 @@ "source": [ "# Private AI PII detection example\n", "\n", - "This notebook shows how to use Private AI for PII detection in NeMo Guardrails." + "This notebook shows how to use Private AI for PII detection and PII masking in NeMo Guardrails." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Import libraries" + "## PII Detection" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import libraries" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -29,7 +36,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -42,7 +49,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Create rails with Private AI PII detection\n", + "### Create rails with Private AI PII detection\n", "\n", "For this step you'll need your OpenAI API key & Private AI API key.\n", "\n", @@ -98,7 +105,123 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Input rails" + "### Input rails" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = rails.generate(messages=[{\"role\": \"user\", \"content\": \"Hello! I'm John. My email id is text@gmail.com. 
I live in California, USA.\"}])\n", + "\n", + "info = rails.explain()\n", + "\n", + "print(\"Response\")\n", + "print(\"----------------------------------------\")\n", + "print(response[\"content\"])\n", + "\n", + "\n", + "print(\"\\n\\nColang history\")\n", + "print(\"----------------------------------------\")\n", + "print(info.colang_history)\n", + "\n", + "print(\"\\n\\nLLM calls summary\")\n", + "print(\"----------------------------------------\")\n", + "info.print_llm_calls_summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Output rails" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = rails.generate(messages=[{\"role\": \"user\", \"content\": \"give me a sample email id\"}])\n", + "\n", + "info = rails.explain()\n", + "\n", + "print(\"Response\")\n", + "print(\"----------------------------------------\\n\\n\")\n", + "print(response[\"content\"])\n", + "\n", + "\n", + "print(\"\\n\\nColang history\")\n", + "print(\"----------------------------------------\")\n", + "print(info.colang_history)\n", + "\n", + "print(\"\\n\\nLLM calls summary\")\n", + "print(\"----------------------------------------\")\n", + "info.print_llm_calls_summary()\n", + "\n", + "\n", + "print(\"\\n\\nCompletions where PII was detected!\")\n", + "print(\"----------------------------------------\")\n", + "print(info.llm_calls[0].completion)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## PII Masking\n", + "\n", + "Note: This example uses an Ollama model."
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Input rails" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"PAI_API_KEY\"] = \"YOUR PRIVATE AI API KEY\" # Visit https://portal.private-ai.com to get your API key\n", + "\n", + "YAML_CONFIG = \"\"\"\n", + "\n", + "\n", + "models:\n", + " - type: main\n", + " engine: ollama\n", + " model: llama3.2\n", + " parameters:\n", + " base_url: http://localhost:11434\n", + "\n", + "rails:\n", + " config:\n", + " privateai:\n", + " server_endpoint: https://api.private-ai.com/cloud/v3/process/text\n", + " input:\n", + " entities:\n", + " - LOCATION\n", + " - EMAIL_ADDRESS\n", + " input:\n", + " flows:\n", + " - mask pii on input\n", + "\"\"\"\n", + "\n", + "\n", + "\n", + "config = RailsConfig.from_content(yaml_content=YAML_CONFIG)\n", + "rails = LLMRails(config)" ] }, { @@ -129,7 +252,47 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Output rails" + "### Output rails" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"PAI_API_KEY\"] = \"YOUR PRIVATE AI API KEY\" # Visit https://portal.private-ai.com to get your API key\n", + "\n", + "YAML_CONFIG = \"\"\"\n", + "\n", + "\n", + "models:\n", + " - type: main\n", + " engine: ollama\n", + " model: llama3.2\n", + " parameters:\n", + " base_url: http://localhost:11434\n", + "\n", + "rails:\n", + " config:\n", + " privateai:\n", + " server_endpoint: https://api.private-ai.com/cloud/v3/process/text\n", + " output:\n", + " entities:\n", + " - LOCATION\n", + " - EMAIL_ADDRESS\n", + " output:\n", + " flows:\n", + " - mask pii on output\n", + "\"\"\"\n", + "\n", + "\n", + "\n", + "config = RailsConfig.from_content(yaml_content=YAML_CONFIG)\n", + "rails = LLMRails(config)" ] }, { @@ -160,6 +323,13 @@ "print(\"----------------------------------------\")\n", "print(info.llm_calls[0].completion)" ] + 
}, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/nemoguardrails/library/privateai/actions.py b/nemoguardrails/library/privateai/actions.py index 1c31efc2a..bb021357e 100644 --- a/nemoguardrails/library/privateai/actions.py +++ b/nemoguardrails/library/privateai/actions.py @@ -21,7 +21,7 @@ from nemoguardrails import RailsConfig from nemoguardrails.actions import action -from nemoguardrails.library.privateai.request import private_ai_detection_request +from nemoguardrails.library.privateai.request import private_ai_request from nemoguardrails.rails.llm.config import PrivateAIDetection log = logging.getLogger(__name__) @@ -38,8 +38,10 @@ async def detect_pii(source: str, text: str, config: RailsConfig): Returns True if PII is detected, False otherwise. - """ + Raises: + ValueError: If PAI_API_KEY is missing when using cloud API or if the response is invalid. + """ pai_config: PrivateAIDetection = getattr(config.rails.config, "privateai") pai_api_key = os.environ.get("PAI_API_KEY") server_endpoint = pai_config.server_endpoint @@ -58,11 +60,66 @@ async def detect_pii(source: str, text: str, config: RailsConfig): f"The current flow, '{source}', is not allowed." ) - entity_detected = await private_ai_detection_request( + private_ai_response = await private_ai_request( text, enabled_entities, server_endpoint, pai_api_key, ) + try: + entity_detected = any(res["entities_present"] for res in private_ai_response) + except (KeyError, TypeError) as e: + raise ValueError(f"Invalid response from Private AI service: {str(e)}") return entity_detected + + +@action(is_system_action=True) +async def mask_pii(source: str, text: str, config: RailsConfig): + """Masks any detected PII in the provided text. + + Args: + source (str): The source for the text, i.e. "input", "output", "retrieval". + text (str): The text to check. + config (RailsConfig): The rails configuration object. 
+ + Returns: + str: The altered text with PII masked. + + Raises: + ValueError: If PAI_API_KEY is missing when using cloud API or if the response is invalid. + """ + pai_config: PrivateAIDetection = getattr(config.rails.config, "privateai") + pai_api_key = os.environ.get("PAI_API_KEY") + server_endpoint = pai_config.server_endpoint + enabled_entities = getattr(pai_config, source).entities + + parsed_url = urlparse(server_endpoint) + if parsed_url.hostname == "api.private-ai.com" and not pai_api_key: + raise ValueError( + "PAI_API_KEY environment variable required for Private AI cloud API." + ) + + valid_sources = ["input", "output", "retrieval"] + if source not in valid_sources: + raise ValueError( + f"Private AI can only be defined in the following flows: {valid_sources}. " + f"The current flow, '{source}', is not allowed." + ) + + private_ai_response = await private_ai_request( + text, + enabled_entities, + server_endpoint, + pai_api_key, + ) + + if not private_ai_response or not isinstance(private_ai_response, list): + raise ValueError( + "Invalid response received from Private AI service. The response is not a list." 
+ ) + + try: + return private_ai_response[0]["processed_text"] + except (IndexError, KeyError) as e: + raise ValueError(f"Invalid response from Private AI service: {str(e)}") diff --git a/nemoguardrails/library/privateai/flows.co b/nemoguardrails/library/privateai/flows.co index 04465deba..c3cf1148f 100644 --- a/nemoguardrails/library/privateai/flows.co +++ b/nemoguardrails/library/privateai/flows.co @@ -1,3 +1,5 @@ +#### PII DETECTION RAILS #### + # INPUT RAILS @active @@ -32,3 +34,35 @@ flow detect pii on retrieval if $has_pii bot inform answer unknown abort + + +####################################################### + + +#### PII MASKING RAILS #### + +# INPUT RAILS + +@active +flow mask pii on input + """Mask any detected PII in the user input.""" + $masked_input = await MaskPiiAction(source="input", text=$user_message) + + global $user_message + $user_message = $masked_input + + +# OUTPUT RAILS + +@active +flow mask pii on output + """Mask any detected PII in the bot output.""" + $bot_message = await MaskPiiAction(source="output", text=$bot_message) + + +# RETRIEVAL RAILS + +@active +flow mask pii on retrieval + """Mask any detected PII in the relevant chunks from the knowledge base.""" + $relevant_chunks = await MaskPiiAction(source="retrieval", text=$relevant_chunks) diff --git a/nemoguardrails/library/privateai/flows.v1.co b/nemoguardrails/library/privateai/flows.v1.co index a7e4fca55..02884da1c 100644 --- a/nemoguardrails/library/privateai/flows.v1.co +++ b/nemoguardrails/library/privateai/flows.v1.co @@ -1,3 +1,5 @@ +#### PII DETECTION RAILS #### + # INPUT RAILS define subflow detect pii on input @@ -29,3 +31,31 @@ define subflow detect pii on retrieval if $has_pii bot inform answer unknown stop + + +####################################################### + + +#### PII MASKING RAILS #### + +# INPUT RAILS + +define subflow mask pii on input + """Mask any detected PII in the user input.""" + $masked_input = execute mask_pii(source="input", 
text=$user_message) + + $user_message = $masked_input + + +# OUTPUT RAILS + +define subflow mask pii on output + """Mask any detected PII in the bot output.""" + $bot_message = execute mask_pii(source="output", text=$bot_message) + + +# RETRIEVAL RAILS + +define subflow mask pii on retrieval + """Mask any detected PII in the relevant chunks from the knowledge base.""" + $relevant_chunks = execute mask_pii(source="retrieval", text=$relevant_chunks) diff --git a/nemoguardrails/library/privateai/request.py b/nemoguardrails/library/privateai/request.py index dfa586bee..99571a7b6 100644 --- a/nemoguardrails/library/privateai/request.py +++ b/nemoguardrails/library/privateai/request.py @@ -15,7 +15,6 @@ """Module for handling Private AI detection requests.""" -import json import logging from typing import Any, Dict, List, Optional from urllib.parse import urlparse @@ -25,14 +24,13 @@ log = logging.getLogger(__name__) -async def private_ai_detection_request( +async def private_ai_request( text: str, enabled_entities: List[str], server_endpoint: str, api_key: Optional[str] = None, ): - """ - Send a detection request to the Private AI API. + """Send a text processing request to the Private AI API. Args: text: The text to analyze. @@ -41,7 +39,12 @@ async def private_ai_detection_request( api_key: The API key for the Private AI service. Returns: - True if PII is detected, False otherwise. + The response from the Private AI API. See Private AI API reference for more details: + https://docs.private-ai.com/reference/latest/operation/process_text_process_text_post/ + + Raises: + ValueError: If api_key is missing for cloud API, if the API call fails, + or if the response cannot be parsed as JSON. 
""" parsed_url = urlparse(server_endpoint) if parsed_url.hostname == "api.private-ai.com" and not api_key: @@ -73,6 +76,10 @@ async def private_ai_detection_request( f"Details: {await resp.text()}" ) - result = await resp.json() - - return any(res["entities_present"] for res in result) + try: + return await resp.json() + except aiohttp.ContentTypeError: + raise ValueError( + f"Failed to parse Private AI response as JSON. Status: {resp.status}, " + f"Content: {await resp.text()}" + ) diff --git a/tests/test_privateai.py b/tests/test_privateai.py index 4d127d6b3..4e147b91f 100644 --- a/tests/test_privateai.py +++ b/tests/test_privateai.py @@ -13,12 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest from nemoguardrails import RailsConfig from nemoguardrails.actions.actions import ActionResult, action from tests.utils import TestChat +PAI_API_KEY_PRESENT = os.getenv("PAI_API_KEY") is not None + @action() def retrieve_relevant_chunks(): @@ -30,13 +34,9 @@ def retrieve_relevant_chunks(): ) -def mock_detect_pii(return_value=True): - def mock_request(*args, **kwargs): - return return_value - - return mock_request - - +@pytest.mark.skipif( + not PAI_API_KEY_PRESENT, reason="Private AI API key is not present." +) @pytest.mark.unit def test_privateai_pii_detection_no_active_pii_detection(): config = RailsConfig.from_content( @@ -69,11 +69,13 @@ def test_privateai_pii_detection_no_active_pii_detection(): ) chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") - chat.app.register_action(mock_detect_pii(True), "detect_pii") chat >> "Hi! I am Mr. John! And my email is test@gmail.com" chat << "Hi! My name is John as well." +@pytest.mark.skipif( + not PAI_API_KEY_PRESENT, reason="Private AI API key is not present." 
+) @pytest.mark.unit def test_privateai_pii_detection_input(): config = RailsConfig.from_content( @@ -113,11 +115,13 @@ def test_privateai_pii_detection_input(): ) chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") - chat.app.register_action(mock_detect_pii(True), "detect_pii") chat >> "Hi! I am Mr. John! And my email is test@gmail.com" chat << "I can't answer that." +@pytest.mark.skipif( + not PAI_API_KEY_PRESENT, reason="Private AI API key is not present." +) @pytest.mark.unit def test_privateai_pii_detection_output(): config = RailsConfig.from_content( @@ -157,7 +161,6 @@ def test_privateai_pii_detection_output(): ) chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") - chat.app.register_action(mock_detect_pii(True), "detect_pii") chat >> "Hi!" chat << "I can't answer that." @@ -203,17 +206,19 @@ def test_privateai_pii_detection_retrieval_with_pii(): ) chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") - chat.app.register_action(mock_detect_pii(True), "detect_pii") # When the relevant_chunks has_pii, a bot intent will get invoked via (bot inform answer unknown), which in turn # will invoke retrieve_relevant_chunks action. - # With a mocked retrieve_relevant_chunks always returning something & mocked detect_pii always returning True, + # With a mocked retrieve_relevant_chunks always returning something & detect_pii always returning True, # the process goes in an infinite loop and raises an Exception: Too many events. with pytest.raises(Exception, match="Too many events."): chat >> "Hi!" chat << "I can't answer that." +@pytest.mark.skipif( + not PAI_API_KEY_PRESENT, reason="Private AI API key is not present." 
+) @pytest.mark.unit def test_privateai_pii_detection_retrieval_with_no_pii(): config = RailsConfig.from_content( @@ -253,7 +258,178 @@ def test_privateai_pii_detection_retrieval_with_no_pii(): ) chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") - chat.app.register_action(mock_detect_pii(False), "detect_pii") chat >> "Hi!" chat << "Hi! My name is John as well." + + +@pytest.mark.skipif( + not PAI_API_KEY_PRESENT, reason="Private AI API key is not present." +) +@pytest.mark.unit +def test_privateai_pii_masking_on_output(): + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + privateai: + server_endpoint: https://api.private-ai.com/cloud/v3/process/text + output: + entities: + - EMAIL_ADDRESS + - NAME + output: + flows: + - mask pii on output + """, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define bot inform answer unknown + "I can't answer that." + """, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + ' "Hi! I am John.', + ], + ) + + chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") + + chat >> "Hi!" + chat << "Hi! I am [NAME_1]." + + +@pytest.mark.skipif( + not PAI_API_KEY_PRESENT, reason="Private AI API key is not present." +) +@pytest.mark.unit +def test_privateai_pii_masking_on_input(): + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + privateai: + server_endpoint: https://api.private-ai.com/cloud/v3/process/text + input: + entities: + - EMAIL_ADDRESS + - NAME + input: + flows: + - mask pii on input + - check user message + """, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define bot inform answer unknown + "I can't answer that." 
+ + define flow check user message + execute check_user_message(user_message=$user_message) + """, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + ' "Hi! I am John.', + ], + ) + + @action() + def check_user_message(user_message: str): + """Check if the user message is converted to the expected message with masked PII.""" + assert user_message == "Hi there! Are you [NAME_1]?" + + chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") + chat.app.register_action(check_user_message, "check_user_message") + + chat >> "Hi there! Are you John?" + chat << "Hi! I am John." + + +@pytest.mark.skipif( + not PAI_API_KEY_PRESENT, reason="Private AI API key is not present." +) +@pytest.mark.unit +def test_privateai_pii_masking_on_retrieval(): + config = RailsConfig.from_content( + yaml_content=""" + models: [] + rails: + config: + privateai: + server_endpoint: https://api.private-ai.com/cloud/v3/process/text + retrieval: + entities: + - EMAIL_ADDRESS + - NAME + retrieval: + flows: + - mask pii on retrieval + - check relevant chunks + """, + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define bot inform answer unknown + "I can't answer that." 
+ + define flow check relevant chunks + execute check_relevant_chunks(relevant_chunks=$relevant_chunks) + """, + ) + + chat = TestChat( + config, + llm_completions=[ + " express greeting", + " Sorry, I don't have that in my knowledge base.", + ], + ) + + @action() + def check_relevant_chunks(relevant_chunks: str): + """Check if the relevant chunks is converted to the expected message with masked PII.""" + assert relevant_chunks == "[NAME_1]'s Email: [EMAIL_ADDRESS_1]" + + @action() + def retrieve_relevant_chunk_for_masking(): + # Mock retrieval of relevant chunks with PII + context_updates = {"relevant_chunks": "John's Email: john@email.com"} + return ActionResult( + return_value=context_updates["relevant_chunks"], + context_updates=context_updates, + ) + + chat.app.register_action( + retrieve_relevant_chunk_for_masking, "retrieve_relevant_chunks" + ) + chat.app.register_action(check_relevant_chunks) + + chat >> "Hey! Can you help me get John's email?" + chat << "Sorry, I don't have that in my knowledge base."
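Both new actions in this patch parse the same `/process/text` response list: `detect_pii` reads the `entities_present` field and `mask_pii` reads `processed_text`. A minimal offline sketch of that parsing logic follows; the sample response shape is an assumption inferred from the fields the actions access, not the authoritative Private AI schema.

```python
# Sketch of how detect_pii / mask_pii interpret a Private AI
# /process/text response. The sample below is an assumed shape,
# based only on the fields the actions in this diff access.

def entity_detected(response: list) -> bool:
    # Mirrors detect_pii: flag PII if any result reports entities_present.
    return any(res["entities_present"] for res in response)

def masked_text(response: list) -> str:
    # Mirrors mask_pii: validate the shape, then return the redacted
    # text carried by the first result.
    if not response or not isinstance(response, list):
        raise ValueError("Invalid response: expected a non-empty list.")
    return response[0]["processed_text"]

sample = [
    {
        "processed_text": "Hi [NAME_1], my email is [EMAIL_ADDRESS_1]",
        "entities_present": True,
    }
]

print(entity_detected(sample))  # True
print(masked_text(sample))      # Hi [NAME_1], my email is [EMAIL_ADDRESS_1]
```

Keeping this parsing in the actions (rather than in `private_ai_request`) is what lets the renamed request helper serve both the detection and masking flows.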