Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion plugins/rust/python-package/pii_filter/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pii_filter"
version = "0.3.1"
version = "0.3.2"
edition.workspace = true
authors.workspace = true
license.workspace = true
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
description: "Rust-backed PII detection and masking for prompt arguments, tool inputs, and tool outputs"
author: "ContextForge Contributors"
version: "0.3.1"
version: "0.3.2"
kind: "cpex_pii_filter.pii_filter.PIIFilterPlugin"
available_hooks:
- "prompt_pre_fetch"
Expand Down
73 changes: 64 additions & 9 deletions plugins/rust/python-package/pii_filter/src/detector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use log::{debug, warn};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict, PyList, PySet, PyString, PyTuple};
use pyo3::types::{PyAny, PyDict, PyList, PyMapping, PySet, PyString, PyTuple};
use pyo3_stub_gen::derive::*;
use std::collections::HashMap;

Expand Down Expand Up @@ -357,24 +357,28 @@ impl PIIDetectorRust {
}
}

// Handle dictionaries
if let Ok(dict) = data.cast::<PyDict>() {
let mut entries: Vec<(Py<PyAny>, Py<PyAny>)> = Vec::with_capacity(dict.len());
// Handle mappings through the Python protocol. CPEX isolation wraps
// dicts in copy-on-write dict subclasses whose visible entries are not
// stored in the underlying PyDict table.
if let Ok(mapping) = data.cast::<PyMapping>() {
let mapping_len = mapping.len()?;
let mut entries: Vec<(Py<PyAny>, Py<PyAny>)> = Vec::with_capacity(mapping_len);
let mut all_detections = HashMap::new();
if dict.len() > self.config.max_collection_items {
if mapping_len > self.config.max_collection_items {
warn!(
"Rejected nested mapping at path '{}' because size {} exceeds max {}",
path,
dict.len(),
self.config.max_collection_items
path, mapping_len, self.config.max_collection_items
);
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Nested mapping exceeds maximum size of {} items",
self.config.max_collection_items
)));
}

for (key, value) in dict.iter() {
for item in mapping.items()?.iter() {
let item = item.cast::<PyTuple>()?;
let key = item.get_item(0)?;
let value = item.get_item(1)?;
let key_str = key.str()?.to_string_lossy().into_owned();
let new_path = if path.is_empty() {
key_str.clone()
Expand Down Expand Up @@ -1659,6 +1663,57 @@ class ConfigModel:
});
}

/// A mapping whose size equals `max_collection_items` must still be processed
/// and masked rather than rejected.
#[test]
fn test_process_nested_mapping_allows_collection_limit_boundary() {
    Python::initialize();
    Python::attach(|py| {
        // Detector configured with a collection limit of exactly one item.
        let settings = PyDict::new(py);
        settings.set_item("detect_email", true).unwrap();
        settings.set_item("max_collection_items", 1).unwrap();
        let detector = PIIDetectorRust::new(&settings.into_any()).unwrap();

        // One entry: sits exactly on the configured limit.
        let payload = PyDict::new(py);
        payload.set_item("email", "alice@example.com").unwrap();

        let (modified, new_data, _) = detector
            .process_nested(py, &payload.into_any(), "")
            .unwrap();
        assert!(modified);

        // The returned object must be a dict whose email value has been masked.
        let masked = new_data.bind(py).cast::<PyDict>().unwrap();
        let email: String = masked
            .get_item("email")
            .unwrap()
            .unwrap()
            .extract()
            .unwrap();
        assert_eq!(email, "[REDACTED]");
    });
}

/// A mapping with more entries than `max_collection_items` must be rejected
/// with a Python `ValueError`.
#[test]
fn test_process_nested_mapping_rejects_over_collection_limit() {
    Python::initialize();
    Python::attach(|py| {
        // Detector configured with a collection limit of exactly one item.
        let settings = PyDict::new(py);
        settings.set_item("detect_email", true).unwrap();
        settings.set_item("max_collection_items", 1).unwrap();
        let detector = PIIDetectorRust::new(&settings.into_any()).unwrap();

        // Two entries: one more than the limit allows.
        let payload = PyDict::new(py);
        for (key, address) in [("first", "alice@example.com"), ("second", "bob@example.com")] {
            payload.set_item(key, address).unwrap();
        }

        let outcome = detector.process_nested(py, &payload.into_any(), "");
        let err = outcome.unwrap_err();
        assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
    });
}

#[test]
fn test_detects_plus_prefixed_international_phone_number() {
let config = PIIConfig {
Expand Down
9 changes: 9 additions & 0 deletions plugins/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,18 @@

# Build empty stand-in module objects for the `cpex` package hierarchy.
# NOTE(review): `plugin_hooks`, `types` and `sys` are bound earlier in this
# file (outside this hunk) — confirm before moving these statements.
cpex = types.ModuleType("cpex")
framework = types.ModuleType("cpex.framework")
hooks = types.ModuleType("cpex.framework.hooks")
policies = types.ModuleType("cpex.framework.hooks.policies")
memory = types.ModuleType("cpex.framework.memory")

# `cpex.framework` exposes every name defined in `plugin_hooks`.
framework.__dict__.update(plugin_hooks.__dict__)
# The policy and isolation helpers are exposed under the submodule paths that
# the tests import them from.
policies.HookPayloadPolicy = plugin_hooks.HookPayloadPolicy
policies.apply_policy = plugin_hooks.apply_policy
memory.wrap_payload_for_isolation = plugin_hooks.wrap_payload_for_isolation
# Register the stand-ins so `import cpex...` statements resolve to them.
sys.modules["cpex"] = cpex
sys.modules["cpex.framework"] = framework
sys.modules["cpex.framework.hooks"] = hooks
sys.modules["cpex.framework.hooks.policies"] = policies
sys.modules["cpex.framework.memory"] = memory
# `models` and `settings` are served directly by the `plugin_hooks` module.
sys.modules["cpex.framework.models"] = plugin_hooks
sys.modules["cpex.framework.settings"] = plugin_hooks
55 changes: 55 additions & 0 deletions plugins/tests/pii_filter/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
ToolPostInvokePayload,
ToolPreInvokePayload,
)
from cpex.framework.hooks.policies import HookPayloadPolicy, apply_policy
from cpex.framework.memory import wrap_payload_for_isolation
from cpex.framework.models import GlobalContext

from cpex_pii_filter.pii_filter import PIIDetectorRust, PIIFilterPlugin
Expand Down Expand Up @@ -314,6 +316,59 @@ async def test_tool_post_invoke_returns_copied_payload_for_frozen_models():
assert result.modified_payload.result["contact"] == "[REDACTED]"


@pytest.mark.asyncio
async def test_tool_post_invoke_returns_new_nested_result_for_mcp_content():
    """Masking a nested result must return new containers and leave the input untouched."""
    plugin = PIIFilterPlugin(_make_config())
    text_item = {"type": "text", "text": "Contact alice@example.com"}
    payload = ToolPostInvokePayload(
        name="search",
        result={"content": [text_item], "isError": False},
    )

    outcome = await plugin.tool_post_invoke(payload, _make_context())

    masked = outcome.modified_payload
    assert masked is not None
    assert masked is not payload
    # Every container on the path to the masked text must be a fresh object.
    assert masked.result is not payload.result
    assert masked.result["content"] is not payload.result["content"]
    assert masked.result["content"][0] is not payload.result["content"][0]
    # The input keeps its original text; only the returned payload is masked.
    assert payload.result["content"][0]["text"] == "Contact alice@example.com"
    assert masked.result["content"][0]["text"] == "Contact [REDACTED]"


@pytest.mark.asyncio
async def test_tool_post_invoke_survives_cpex_policy_with_isolated_payload():
    """Run the plugin on an isolation-wrapped payload, then filter its output through apply_policy.

    Only the ``result`` field is writable under the policy used here; the
    masked text must survive that filtering while the source payload stays
    unchanged.
    """
    plugin = PIIFilterPlugin(_make_config())
    payload = ToolPostInvokePayload(
        name="search",
        result={
            "content": [{"type": "text", "text": "Contact alice@example.com"}],
            "isError": False,
        },
    )
    isolated = wrap_payload_for_isolation(payload)

    outcome = await plugin.tool_post_invoke(isolated, _make_context())
    assert outcome.modified_payload is not None

    result_only = HookPayloadPolicy(writable_fields=frozenset({"result"}))
    filtered = apply_policy(
        isolated,
        outcome.modified_payload,
        result_only,
        apply_to=payload,
    )

    assert filtered is not None
    assert payload.result["content"][0]["text"] == "Contact alice@example.com"
    assert filtered.result["content"][0]["text"] == "Contact [REDACTED]"


@pytest.mark.asyncio
async def test_tool_post_invoke_blocks_when_configured():
plugin = PIIFilterPlugin(_make_config(block_on_detection=True))
Expand Down
61 changes: 60 additions & 1 deletion plugins/tests/plugin_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,70 @@
from __future__ import annotations

import importlib
from dataclasses import dataclass, field
from dataclasses import dataclass, field, fields, is_dataclass
from enum import Enum
from typing import Any


@dataclass(frozen=True)
class HookPayloadPolicy:
    """Immutable record of which payload fields a hook is allowed to overwrite."""

    # Names of the payload fields whose changes may be accepted.
    writable_fields: frozenset[str]


class CopyOnWriteDict(dict):
    """Dict subclass that reads through to a wrapped dict it never modifies.

    The inherited dict storage starts empty. Iteration and ``len()`` reflect
    the wrapped dict, and item lookup falls back to it whenever the key is not
    present in the inherited storage.
    """

    def __init__(self, original: dict[str, Any]) -> None:
        # Nothing is copied: the inherited storage is intentionally left empty.
        super().__init__()
        self._original = original

    def __getitem__(self, key: Any) -> Any:
        # ``in`` is not overridden, so this checks only the inherited storage.
        if key in self:
            return super().__getitem__(key)
        return self._original[key]

    def __iter__(self):
        # Keys come from the wrapped dict.
        return iter(self._original)

    def __len__(self) -> int:
        # Size comes from the wrapped dict.
        return len(self._original)

    def items(self):
        # Lazily pair each visible key with the value item lookup returns.
        return ((name, self[name]) for name in self)

    def copy(self) -> dict:
        # Materialise the visible entries into a plain dict.
        return {name: value for name, value in self.items()}


def wrap_payload_for_isolation(payload: Any) -> Any:
    """Return a copy of a dataclass payload with its dict fields wrapped.

    Non-dataclass payloads are returned unchanged. For dataclasses, a new
    instance of the same type is built in which every field whose value is a
    ``dict`` is replaced by a ``CopyOnWriteDict`` around that value; all other
    field values are passed through as-is.
    """
    if is_dataclass(payload):

        def isolate(value: Any) -> Any:
            # Only direct dict values are wrapped; nested containers are not visited.
            return CopyOnWriteDict(value) if isinstance(value, dict) else value

        wrapped = {spec.name: isolate(getattr(payload, spec.name)) for spec in fields(payload)}
        return type(payload)(**wrapped)
    return payload


def apply_policy(
    original: Any,
    modified: Any,
    policy: HookPayloadPolicy,
    *,
    apply_to: Any | None = None,
) -> Any | None:
    """Build a new payload containing only the changes the policy allows.

    Each field of ``modified`` is compared with the same field on ``original``.
    A changed field is accepted only if its name is in
    ``policy.writable_fields``.

    Returns ``None`` when no change is accepted. Otherwise returns a new
    instance of the target's type, where the target is ``apply_to`` if given
    and ``original`` if not, with the accepted values laid over the target's
    current field values.
    """
    base = original if apply_to is None else apply_to

    accepted = {}
    for spec in fields(modified):
        before = getattr(original, spec.name)
        after = getattr(modified, spec.name)
        changed = not (after == before)
        # Changes to fields outside the writable set are silently dropped.
        if changed and spec.name in policy.writable_fields:
            accepted[spec.name] = after

    if not accepted:
        return None

    merged = {spec.name: getattr(base, spec.name) for spec in fields(base)}
    merged.update(accepted)
    return type(base)(**merged)


class PromptHookType(str, Enum):
PROMPT_PRE_FETCH = "prompt_pre_fetch"
PROMPT_POST_FETCH = "prompt_post_fetch"
Expand Down
Loading