From 682905abc2bdafb2456f0745278136aadb778690 Mon Sep 17 00:00:00 2001 From: lucarlig Date: Fri, 8 May 2026 14:43:32 +0100 Subject: [PATCH] Fix PII result copy regression Signed-off-by: lucarlig --- Cargo.lock | 2 +- .../rust/python-package/pii_filter/Cargo.toml | 2 +- .../cpex_pii_filter/plugin-manifest.yaml | 2 +- .../python-package/pii_filter/src/detector.rs | 73 ++++++++++++++++--- plugins/tests/conftest.py | 9 +++ plugins/tests/pii_filter/test_integration.py | 55 ++++++++++++++ plugins/tests/plugin_hooks.py | 61 +++++++++++++++- 7 files changed, 191 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0989f67..472d2d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1031,7 +1031,7 @@ dependencies = [ [[package]] name = "pii_filter" -version = "0.3.1" +version = "0.3.2" dependencies = [ "cpex_framework_bridge", "criterion", diff --git a/plugins/rust/python-package/pii_filter/Cargo.toml b/plugins/rust/python-package/pii_filter/Cargo.toml index 2bd2e4c..b2cdedf 100644 --- a/plugins/rust/python-package/pii_filter/Cargo.toml +++ b/plugins/rust/python-package/pii_filter/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pii_filter" -version = "0.3.1" +version = "0.3.2" edition.workspace = true authors.workspace = true license.workspace = true diff --git a/plugins/rust/python-package/pii_filter/cpex_pii_filter/plugin-manifest.yaml b/plugins/rust/python-package/pii_filter/cpex_pii_filter/plugin-manifest.yaml index 4498941..7e563e9 100644 --- a/plugins/rust/python-package/pii_filter/cpex_pii_filter/plugin-manifest.yaml +++ b/plugins/rust/python-package/pii_filter/cpex_pii_filter/plugin-manifest.yaml @@ -1,6 +1,6 @@ description: "Rust-backed PII detection and masking for prompt arguments, tool inputs, and tool outputs" author: "ContextForge Contributors" -version: "0.3.1" +version: "0.3.2" kind: "cpex_pii_filter.pii_filter.PIIFilterPlugin" available_hooks: - "prompt_pre_fetch" diff --git a/plugins/rust/python-package/pii_filter/src/detector.rs b/plugins/rust/python-package/pii_filter/src/detector.rs index c16eb08..1718688 100644 --- a/plugins/rust/python-package/pii_filter/src/detector.rs +++ b/plugins/rust/python-package/pii_filter/src/detector.rs @@ -5,7 +5,7 @@ use log::{debug, warn}; use pyo3::prelude::*; -use pyo3::types::{PyAny, PyDict, PyList, PySet, PyString, PyTuple}; +use pyo3::types::{PyAny, PyDict, PyList, PyMapping, PySet, PyString, PyTuple}; use pyo3_stub_gen::derive::*; use std::collections::HashMap; @@ -357,16 +357,17 @@ impl PIIDetectorRust { } } - // Handle dictionaries - if let Ok(dict) = data.cast::() { - let mut entries: Vec<(Py, Py)> = Vec::with_capacity(dict.len()); + // Handle mappings through the Python protocol. CPEX isolation wraps + // dicts in copy-on-write dict subclasses whose visible entries are not + // stored in the underlying PyDict table. + if let Ok(mapping) = data.cast::() { + let mapping_len = mapping.len()?; + let mut entries: Vec<(Py, Py)> = Vec::with_capacity(mapping_len); let mut all_detections = HashMap::new(); - if dict.len() > self.config.max_collection_items { + if mapping_len > self.config.max_collection_items { warn!( "Rejected nested mapping at path '{}' because size {} exceeds max {}", - path, - dict.len(), - self.config.max_collection_items + path, mapping_len, self.config.max_collection_items ); return Err(PyErr::new::(format!( "Nested mapping exceeds maximum size of {} items", @@ -374,7 +375,10 @@ impl PIIDetectorRust { ))); } - for (key, value) in dict.iter() { + for item in mapping.items()?.iter() { + let item = item.cast::()?; + let key = item.get_item(0)?; + let value = item.get_item(1)?; let key_str = key.str()?.to_string_lossy().into_owned(); let new_path = if path.is_empty() { key_str.clone() @@ -1659,6 +1663,57 @@ class ConfigModel: }); } + #[test] + fn test_process_nested_mapping_allows_collection_limit_boundary() { + Python::initialize(); + Python::attach(|py| { + let config = PyDict::new(py); + config.set_item("detect_email", true).unwrap(); + config.set_item("max_collection_items", 1).unwrap(); + + let detector = PIIDetectorRust::new(&config.into_any()).unwrap(); + let data = PyDict::new(py); + data.set_item("email", "alice@example.com").unwrap(); + + let (modified, new_data, _) = + detector.process_nested(py, &data.into_any(), "").unwrap(); + + assert!(modified); + assert_eq!( + new_data + .bind(py) + .cast::() + .unwrap() + .get_item("email") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + "[REDACTED]" + ); + }); + } + + #[test] + fn test_process_nested_mapping_rejects_over_collection_limit() { + Python::initialize(); + Python::attach(|py| { + let config = PyDict::new(py); + config.set_item("detect_email", true).unwrap(); + config.set_item("max_collection_items", 1).unwrap(); + + let detector = PIIDetectorRust::new(&config.into_any()).unwrap(); + let data = PyDict::new(py); + data.set_item("first", "alice@example.com").unwrap(); + data.set_item("second", "bob@example.com").unwrap(); + + let err = detector + .process_nested(py, &data.into_any(), "") + .unwrap_err(); + assert!(err.is_instance_of::(py)); + }); + } + #[test] fn test_detects_plus_prefixed_international_phone_number() { let config = PIIConfig { diff --git a/plugins/tests/conftest.py b/plugins/tests/conftest.py index b768482..5bafa69 100644 --- a/plugins/tests/conftest.py +++ b/plugins/tests/conftest.py @@ -40,9 +40,18 @@ cpex = types.ModuleType("cpex") framework = types.ModuleType("cpex.framework") +hooks = types.ModuleType("cpex.framework.hooks") +policies = types.ModuleType("cpex.framework.hooks.policies") +memory = types.ModuleType("cpex.framework.memory") framework.__dict__.update(plugin_hooks.__dict__) +policies.HookPayloadPolicy = plugin_hooks.HookPayloadPolicy +policies.apply_policy = plugin_hooks.apply_policy +memory.wrap_payload_for_isolation = plugin_hooks.wrap_payload_for_isolation sys.modules["cpex"] = cpex sys.modules["cpex.framework"] = framework +sys.modules["cpex.framework.hooks"] = hooks +sys.modules["cpex.framework.hooks.policies"] = policies +sys.modules["cpex.framework.memory"] = memory sys.modules["cpex.framework.models"] = plugin_hooks sys.modules["cpex.framework.settings"] = plugin_hooks diff --git a/plugins/tests/pii_filter/test_integration.py b/plugins/tests/pii_filter/test_integration.py index 8784e76..201974e 100644 --- a/plugins/tests/pii_filter/test_integration.py +++ b/plugins/tests/pii_filter/test_integration.py @@ -13,6 +13,8 @@ ToolPostInvokePayload, ToolPreInvokePayload, ) +from cpex.framework.hooks.policies import HookPayloadPolicy, apply_policy +from cpex.framework.memory import wrap_payload_for_isolation from cpex.framework.models import GlobalContext from cpex_pii_filter.pii_filter import PIIDetectorRust, PIIFilterPlugin @@ -314,6 +316,59 @@ async def test_tool_post_invoke_returns_copied_payload_for_frozen_models(): assert result.modified_payload.result["contact"] == "[REDACTED]" +@pytest.mark.asyncio +async def test_tool_post_invoke_returns_new_nested_result_for_mcp_content(): + plugin = PIIFilterPlugin(_make_config()) + payload = ToolPostInvokePayload( + name="search", + result={ + "content": [ + { + "type": "text", + "text": "Contact alice@example.com", + } + ], + "isError": False, + }, + ) + + result = await plugin.tool_post_invoke(payload, _make_context()) + + assert result.modified_payload is not None + assert result.modified_payload is not payload + assert result.modified_payload.result is not payload.result + assert result.modified_payload.result["content"] is not payload.result["content"] + assert result.modified_payload.result["content"][0] is not payload.result["content"][0] + assert payload.result["content"][0]["text"] == "Contact alice@example.com" + assert result.modified_payload.result["content"][0]["text"] == "Contact [REDACTED]" + + +@pytest.mark.asyncio +async def test_tool_post_invoke_survives_cpex_policy_with_isolated_payload(): + plugin = PIIFilterPlugin(_make_config()) + payload = ToolPostInvokePayload( + name="search", + result={ + "content": [{"type": "text", "text": "Contact alice@example.com"}], + "isError": False, + }, + ) + plugin_input = wrap_payload_for_isolation(payload) + + result = await plugin.tool_post_invoke(plugin_input, _make_context()) + + assert result.modified_payload is not None + filtered = apply_policy( + plugin_input, + result.modified_payload, + HookPayloadPolicy(writable_fields=frozenset({"result"})), + apply_to=payload, + ) + assert filtered is not None + assert payload.result["content"][0]["text"] == "Contact alice@example.com" + assert filtered.result["content"][0]["text"] == "Contact [REDACTED]" + + @pytest.mark.asyncio async def test_tool_post_invoke_blocks_when_configured(): plugin = PIIFilterPlugin(_make_config(block_on_detection=True)) diff --git a/plugins/tests/plugin_hooks.py b/plugins/tests/plugin_hooks.py index a078240..d0991ce 100644 --- a/plugins/tests/plugin_hooks.py +++ b/plugins/tests/plugin_hooks.py @@ -3,11 +3,70 @@ from __future__ import annotations import importlib -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields, is_dataclass from enum import Enum from typing import Any +@dataclass(frozen=True) +class HookPayloadPolicy: + writable_fields: frozenset[str] + + +class CopyOnWriteDict(dict): + def __init__(self, original: dict[str, Any]) -> None: + super().__init__() + self._original = original + + def __getitem__(self, key: Any) -> Any: + return super().__getitem__(key) if key in self else self._original[key] + + def __iter__(self): + return iter(self._original) + + def __len__(self) -> int: + return len(self._original) + + def items(self): + return ((key, self[key]) for key in self) + + def copy(self) -> dict: + return dict(self.items()) + + +def wrap_payload_for_isolation(payload: Any) -> Any: + if not is_dataclass(payload): + return payload + updates = {} + for item in fields(payload): + value = getattr(payload, item.name) + updates[item.name] = CopyOnWriteDict(value) if isinstance(value, dict) else value + return type(payload)(**updates) + + +def apply_policy( + original: Any, + modified: Any, + policy: HookPayloadPolicy, + *, + apply_to: Any | None = None, +) -> Any | None: + target = apply_to if apply_to is not None else original + updates = {} + for item in fields(modified): + old_value = getattr(original, item.name) + new_value = getattr(modified, item.name) + if new_value == old_value: + continue + if item.name in policy.writable_fields: + updates[item.name] = new_value + if not updates: + return None + values = {item.name: getattr(target, item.name) for item in fields(target)} + values.update(updates) + return type(target)(**values) + + class PromptHookType(str, Enum): PROMPT_PRE_FETCH = "prompt_pre_fetch" PROMPT_POST_FETCH = "prompt_post_fetch"