Merged
7 changes: 5 additions & 2 deletions docs/guide/python/configuration.md
@@ -211,7 +211,6 @@ class SafeDeserializationPolicy(DeserializationPolicy):
def validate_class(self, cls, is_local, **kwargs):
if cls.__module__ in dangerous_modules:
raise ValueError(f"Blocked dangerous class: {cls.__module__}.{cls.__name__}")
-return None

def intercept_reduce_call(self, callable_obj, args, **kwargs):
if getattr(callable_obj, "__name__", "") == "Popen":
@@ -229,11 +228,15 @@ fory = pyfory.Fory(xlang=False, ref=True, strict=False, policy=policy)

Available policy hooks include:

+Reference validation hooks reject by raising exceptions and otherwise leave deserialized references
+unchanged.
+
| Hook | Description |
| -------------------------------------------- | --------------------------------------------------- |
| `validate_class(cls, is_local)` | Validate or block class types |
-| `validate_module(module, is_local)` | Validate or block module imports |
+| `validate_module(module_name, is_local)` | Validate or block module imports |
| `validate_function(func, is_local)` | Validate or block function references |
+| `validate_method(method, is_local)` | Validate or block method references |
| `intercept_reduce_call(callable_obj, args)` | Intercept `__reduce__` invocations |
| `inspect_reduced_object(obj)` | Inspect or replace objects created via `__reduce__` |
| `intercept_setstate(obj, state)` | Sanitize state before `__setstate__` |
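
As a rough illustration of the `validate_module(module_name, is_local)` row, the sketch below checks a module name against an allowlist before the import happens. `DeserializationPolicy` here is a minimal stand-in base class and `guarded_import` is a hypothetical helper, both defined only for this example; neither is pyfory's implementation.

```python
import importlib


class DeserializationPolicy:
    """Stand-in base class for this sketch (not the pyfory class)."""

    def validate_module(self, module_name, *, is_local, **kwargs):
        return None


class AllowlistPolicy(DeserializationPolicy):
    ALLOWED = {"builtins", "datetime", "decimal"}

    def validate_module(self, module_name, is_local, **kwargs):
        # Reject by raising; returning normally accepts the module.
        root = module_name.split(".")[0]
        if root not in self.ALLOWED:
            raise ValueError(f"Module {module_name} is not allowed")


def guarded_import(policy, module_name):
    """Hypothetical helper: run the hook first, import only if it did not raise."""
    policy.validate_module(module_name, is_local=False)
    return importlib.import_module(module_name)


policy = AllowlistPolicy()
datetime_module = guarded_import(policy, "datetime")

try:
    guarded_import(policy, "subprocess")
    blocked = False
except ValueError:
    blocked = True
```

Because the hook runs before `importlib.import_module`, a rejected module never has its import-time side effects executed in this sketch.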
5 changes: 3 additions & 2 deletions python/README.md
@@ -1119,7 +1119,6 @@ class SafeDeserializationPolicy(DeserializationPolicy):
# Block dangerous modules
if cls.__module__ in dangerous_modules:
raise ValueError(f"Blocked dangerous class: {cls.__module__}.{cls.__name__}")
-return None

def intercept_reduce_call(self, callable_obj, args, **kwargs):
# Block specific callable invocations during __reduce__
@@ -1144,9 +1143,11 @@ result = fory.deserialize(data) # Policy hooks will be invoked

**Available Policy Hooks:**

+- Reference validation hooks reject by raising exceptions and otherwise leave deserialized references unchanged.
- `validate_class(cls, is_local)` - Validate/block class types during deserialization
-- `validate_module(module, is_local)` - Validate/block module imports
+- `validate_module(module_name, is_local)` - Validate/block module imports
- `validate_function(func, is_local)` - Validate/block function references
+- `validate_method(method, is_local)` - Validate/block method references
- `intercept_reduce_call(callable_obj, args)` - Intercept `__reduce__` invocations
- `inspect_reduced_object(obj)` - Inspect/replace objects created via `__reduce__`
- `intercept_setstate(obj, state)` - Sanitize state before `__setstate__`
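
A minimal sketch of the `validate_function` and `validate_method` hooks used for blocking and audit logging, under the raise-to-reject contract described in the list above. The `DeserializationPolicy` base class here is a stand-in defined for the example, and `AuditingPolicy` and its `audit_log` attribute are hypothetical names, not part of pyfory.

```python
class DeserializationPolicy:
    """Stand-in base class for this sketch (not the pyfory class)."""

    def validate_function(self, func, is_local, **kwargs):
        return None

    def validate_method(self, method, is_local, **kwargs):
        return None


class AuditingPolicy(DeserializationPolicy):
    BLOCKED_FUNCTIONS = {"eval", "exec", "compile", "__import__"}

    def __init__(self):
        self.audit_log = []

    def validate_function(self, func, is_local, **kwargs):
        # Record every function reference, then reject blocked names by raising.
        self.audit_log.append(("function", func.__name__, is_local))
        if func.__name__ in self.BLOCKED_FUNCTIONS:
            raise ValueError(f"Function {func.__name__} is forbidden")

    def validate_method(self, method, is_local, **kwargs):
        # Record the owning class and method name; accept by returning normally.
        owner = type(method.__self__).__name__
        self.audit_log.append(("method", f"{owner}.{method.__name__}", is_local))


policy = AuditingPolicy()
accepted = policy.validate_function(len, is_local=False)
policy.validate_method("abc".upper, is_local=False)

try:
    policy.validate_function(eval, is_local=False)
    rejected = False
except ValueError:
    rejected = True
```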
4 changes: 1 addition & 3 deletions python/pyfory/meta/typedef_decoder.py
@@ -184,9 +184,7 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef:
type_cls = make_dataclass(class_name, field_definitions)
policy = getattr(resolver, "policy", None)
if policy is not None:
-result = policy.validate_class(type_cls, is_local=True)
-if result is not None:
-type_cls = result
+policy.validate_class(type_cls, is_local=True)
elif type_cls is None:
raise ValueError(f"TypeDef {name} is not registered")
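
The effect of this hunk can be sketched in isolation: the caller invokes the hook for its side effect (a possible exception) and ignores whatever it returns. `resolve_class`, `ReplacingPolicy`, and `RejectingPolicy` below are hypothetical names written for this illustration and are not pyfory code.

```python
from dataclasses import make_dataclass


class ReplacingPolicy:
    """Hypothetical policy that still returns a replacement class."""

    def validate_class(self, cls, *, is_local, **kwargs):
        return dict


class RejectingPolicy:
    """Hypothetical policy that rejects every class by raising."""

    def validate_class(self, cls, *, is_local, **kwargs):
        raise ValueError(f"Blocked class: {cls.__name__}")


def resolve_class(type_cls, policy):
    # Mirrors the updated call pattern: the return value is not used.
    if policy is not None:
        policy.validate_class(type_cls, is_local=True)
    return type_cls


Point = make_dataclass("Point", [("x", int), ("y", int)])
kept = resolve_class(Point, ReplacingPolicy())

try:
    resolve_class(Point, RejectingPolicy())
    rejected = False
except ValueError:
    rejected = True
```

In this sketch a returned replacement has no effect, so the only way for a policy to influence the outcome is to raise.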

71 changes: 22 additions & 49 deletions python/pyfory/policy.py
@@ -48,7 +48,7 @@ class DeserializationPolicy:
| __reduce__ interception | no | intercept_reduce_call() |
| Post-reduce inspection | no | inspect_reduced_object() |
| __setstate__ interception | no | intercept_setstate() |
-| Object replacement | no | return from validators |
+| Object replacement | no | inspect_reduced_object() |
| State sanitization | no | modify in-place |
| Local class/function | no | is_local flag |
+---------------------------+----------------------+----------------------------+
@@ -87,8 +87,8 @@ def intercept_reduce_call(self, callable_obj, args, **kwargs):

This DeserializationPolicy interface allows users to implement custom security policies by
subclassing and overriding specific hook methods. Each hook is called at a critical
-point during deserialization, allowing inspection, replacement, or rejection of
-dangerous constructs.
+point during deserialization, allowing validation hooks to inspect or reject
+dangerous constructs and interceptor hooks to control protocol operations.

Hook Categories
---------------
@@ -98,7 +98,7 @@ def intercept_reduce_call(self, callable_obj, args, **kwargs):

2. **Reference Validation Hooks** (Validators)
- Validate deserialized type/function/module references
-- Return None to accept original, return object to replace, raise exception to block,
+- Raise exception to block, otherwise return normally

3. **Protocol Interception Hooks** (Interceptors)
- Intercept pickle protocol operations (__reduce__, __setstate__)
@@ -109,17 +109,15 @@ def intercept_reduce_call(self, callable_obj, args, **kwargs):
>>> class SafeDeserializationPolicy(DeserializationPolicy):
... ALLOWED_MODULES = {'builtins', 'datetime', 'decimal'}
...
-... def validate_module(self, module_name, **kwargs):
+... def validate_module(self, module_name, is_local, **kwargs):
... # Reject imports from disallowed modules
... if module_name.split('.')[0] not in self.ALLOWED_MODULES:
... raise ValueError(f"Module {module_name} is not allowed")
-... return None # Accept
...
... def validate_class(self, cls, is_local, **kwargs):
... # Reject dangerous built-in classes
... if cls.__name__ in ('eval', 'exec', 'compile'):
... raise ValueError(f"Class {cls} is forbidden")
-... return None # Accept
...
... def intercept_reduce_call(self, callable_obj, args, **kwargs):
... # Log all __reduce__ callables for audit
@@ -201,7 +199,7 @@ def validate_class(self, cls, *, is_local: bool, **kwargs):

This hook is called after a class reference has been deserialized (either by
importing from a module or reconstructing a local class), but before it is used.
-It allows inspection, replacement, or rejection of class references.
+It allows inspection or rejection of class references.

When Called
-----------
@@ -212,9 +210,8 @@ def validate_class(self, cls, *, is_local: bool, **kwargs):
Security Use Cases
------------------
- Block dangerous classes (subprocess.Popen, os.system, etc.)
-- Replace untrusted classes with safe alternatives
- Validate that local classes match expected signatures
-- Implement class versioning or adaptation logic
+- Audit class imports for security logging

Args:
cls (type): The deserialized class object.
@@ -223,29 +220,20 @@ def validate_class(self, cls, *, is_local: bool, **kwargs):
class from an importable module.
**kwargs: Reserved for future extensions.

-Returns:
-None: Return None to accept the class as-is.
-type: Return a different class to replace the original. The replacement
-class will be used instead for deserialization.
-
Raises:
Exception: Raise any exception to reject the class and abort deserialization.

Example:
->>> class ClassAdapter(DeserializationPolicy):
+>>> class ClassChecker(DeserializationPolicy):
... def validate_class(self, cls, is_local, **kwargs):
-... # Map a serialized class name to the current class.
-... if cls.__name__ == 'ArchivedUserClass':
-... return NewUserClass
... # Block dangerous classes
... if cls.__module__ == 'subprocess':
... raise ValueError("subprocess classes not allowed")
-... return None # Accept

Note:
`check_class` is an alias for this hook.
"""
-pass
+return None

def validate_function(self, func, is_local: bool, **kwargs):
"""Validate a deserialized function reference.
@@ -263,7 +251,6 @@ def validate_function(self, func, is_local: bool, **kwargs):
------------------
- Block dangerous built-in functions (eval, exec, compile, __import__)
- Validate that reconstructed functions have expected signatures
-- Replace untrusted functions with safe alternatives
- Audit function imports for security logging

Args:
@@ -272,10 +259,6 @@ def validate_function(self, func, is_local: bool, **kwargs):
within a function scope), False if it's a global function.
**kwargs: Reserved for future extensions.

-Returns:
-None: Return None to accept the function as-is.
-function: Return a different function to replace the original.
-
Raises:
Exception: Raise any exception to reject the function.

@@ -286,12 +269,11 @@ def validate_function(self, func, is_local: bool, **kwargs):
... def validate_function(self, func, is_local, **kwargs):
... if func.__name__ in self.BLOCKED:
... raise ValueError(f"Function {func.__name__} is forbidden")
-... return None

Note:
`check_function` is an alias for this hook.
"""
-pass
+return None

def validate_method(self, method, is_local: bool, **kwargs):
"""Validate a deserialized method reference.
@@ -309,17 +291,13 @@ def validate_method(self, method, is_local: bool, **kwargs):
------------------
- Validate that methods belong to expected classes
- Block methods that could perform dangerous operations
-- Replace methods with safer alternatives
+- Audit method references for security logging

Args:
method (method): The deserialized bound method object.
is_local (bool): True if the method's class is local, False if global.
**kwargs: Reserved for future extensions.

-Returns:
-None: Return None to accept the method as-is.
-method: Return a different method to replace the original.
-
Raises:
Exception: Raise any exception to reject the method.

@@ -329,56 +307,51 @@ def validate_method(self, method, is_local: bool, **kwargs):
... # Block methods from dangerous classes
... if method.__self__.__class__.__name__ == 'FileRemover':
... raise ValueError("FileRemover methods not allowed")
-... return None

Note:
`check_method` is an alias for this hook.
"""
-pass
+return None

-def validate_module(self, module_name: str, **kwargs):
+def validate_module(self, module_name: str, *, is_local: bool, **kwargs):
"""Validate a deserialized module reference.

-This hook is called after a module has been imported during deserialization,
-but before it is used.
+This hook is called before a module is imported during deserialization.

When Called
-----------
-- After importing modules via importlib.import_module()
-- Before the module is stored or its contents accessed
+- Before importing modules via importlib.import_module()
+- Before the module is stored or its contents are accessed

Security Use Cases
------------------
- Whitelist/blacklist modules by name or prefix
- Prevent imports of system modules (os, subprocess, sys, etc.)
-- Replace modules with safe alternatives or mocks
- Audit module imports for security logging

Args:
-module_name (str): The name of the imported module (e.g., 'os.path').
+module_name (str): The name of the module to import (e.g., 'os.path').
+is_local (bool): True if the reference being resolved is local (defined
+in __main__ or within a function/method scope), False
+otherwise.
**kwargs: Reserved for future extensions.

-Returns:
-None: Return None to accept the module as-is.
-module: Return a different module object to replace the original.
-
Raises:
Exception: Raise any exception to reject the module import.

Example:
>>> class ModuleWhitelistChecker(DeserializationPolicy):
... ALLOWED = {'builtins', 'datetime', 'decimal', 'collections'}
...
-... def validate_module(self, module_name, **kwargs):
+... def validate_module(self, module_name, is_local, **kwargs):
... root = module_name.split('.')[0]
... if root not in self.ALLOWED:
... raise ValueError(f"Module {module_name} not whitelisted")
-... return None

Note:
`check_module` is an alias for this hook.
"""
-pass
+return None

# ============================================================================
# Protocol Interception Hooks (Interceptors)