diff --git a/docs/guide/python/configuration.md b/docs/guide/python/configuration.md index a902192382..7a16bbee0f 100644 --- a/docs/guide/python/configuration.md +++ b/docs/guide/python/configuration.md @@ -211,7 +211,6 @@ class SafeDeserializationPolicy(DeserializationPolicy): def validate_class(self, cls, is_local, **kwargs): if cls.__module__ in dangerous_modules: raise ValueError(f"Blocked dangerous class: {cls.__module__}.{cls.__name__}") - return None def intercept_reduce_call(self, callable_obj, args, **kwargs): if getattr(callable_obj, "__name__", "") == "Popen": @@ -229,11 +228,15 @@ fory = pyfory.Fory(xlang=False, ref=True, strict=False, policy=policy) Available policy hooks include: +Reference validation hooks reject by raising exceptions and otherwise leave deserialized references +unchanged. + | Hook | Description | | -------------------------------------------- | --------------------------------------------------- | | `validate_class(cls, is_local)` | Validate or block class types | -| `validate_module(module, is_local)` | Validate or block module imports | +| `validate_module(module_name, is_local)` | Validate or block module imports | | `validate_function(func, is_local)` | Validate or block function references | +| `validate_method(method, is_local)` | Validate or block method references | | `intercept_reduce_call(callable_obj, args)` | Intercept `__reduce__` invocations | | `inspect_reduced_object(obj)` | Inspect or replace objects created via `__reduce__` | | `intercept_setstate(obj, state)` | Sanitize state before `__setstate__` | diff --git a/python/README.md b/python/README.md index d9ba9c7077..011e78b0cf 100644 --- a/python/README.md +++ b/python/README.md @@ -1119,7 +1119,6 @@ class SafeDeserializationPolicy(DeserializationPolicy): # Block dangerous modules if cls.__module__ in dangerous_modules: raise ValueError(f"Blocked dangerous class: {cls.__module__}.{cls.__name__}") - return None def intercept_reduce_call(self, callable_obj, args, **kwargs): # Block specific callable invocations during __reduce__ @@ -1144,9 +1143,11 @@ result = fory.deserialize(data) # Policy hooks will be invoked **Available Policy Hooks:** +- Reference validation hooks reject by raising exceptions and otherwise leave deserialized references unchanged. - `validate_class(cls, is_local)` - Validate/block class types during deserialization -- `validate_module(module, is_local)` - Validate/block module imports +- `validate_module(module_name, is_local)` - Validate/block module imports - `validate_function(func, is_local)` - Validate/block function references +- `validate_method(method, is_local)` - Validate/block method references - `intercept_reduce_call(callable_obj, args)` - Intercept `__reduce__` invocations - `inspect_reduced_object(obj)` - Inspect/replace objects created via `__reduce__` - `intercept_setstate(obj, state)` - Sanitize state before `__setstate__` diff --git a/python/pyfory/meta/typedef_decoder.py b/python/pyfory/meta/typedef_decoder.py index 3f2b224c1c..2c6a89d607 100644 --- a/python/pyfory/meta/typedef_decoder.py +++ b/python/pyfory/meta/typedef_decoder.py @@ -184,9 +184,7 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: type_cls = make_dataclass(class_name, field_definitions) policy = getattr(resolver, "policy", None) if policy is not None: - result = policy.validate_class(type_cls, is_local=True) - if result is not None: - type_cls = result + policy.validate_class(type_cls, is_local=True) elif type_cls is None: raise ValueError(f"TypeDef {name} is not registered") diff --git a/python/pyfory/policy.py b/python/pyfory/policy.py index 5070821c0b..b8f8ea2e44 100644 --- a/python/pyfory/policy.py +++ b/python/pyfory/policy.py @@ -48,7 +48,7 @@ class DeserializationPolicy: | __reduce__ interception | no | intercept_reduce_call() | | Post-reduce inspection | no | inspect_reduced_object() | | __setstate__ interception | no | intercept_setstate() | - | Object replacement | no | return from validators | + | Object replacement | no | inspect_reduced_object() | | State sanitization | no | modify in-place | | Local class/function | no | is_local flag | +---------------------------+----------------------+----------------------------+ @@ -87,8 +87,8 @@ def intercept_reduce_call(self, callable_obj, args, **kwargs): This DeserializationPolicy interface allows users to implement custom security policies by subclassing and overriding specific hook methods. Each hook is called at a critical - point during deserialization, allowing inspection, replacement, or rejection of - dangerous constructs. + point during deserialization, allowing validation hooks to inspect or reject + dangerous constructs and interceptor hooks to control protocol operations. Hook Categories --------------- @@ -98,7 +98,7 @@ def intercept_reduce_call(self, callable_obj, args, **kwargs): 2. **Reference Validation Hooks** (Validators) - Validate deserialized type/function/module references - - Return None to accept original, return object to replace, raise exception to block, + - Raise exception to block, otherwise return normally 3. **Protocol Interception Hooks** (Interceptors) - Intercept pickle protocol operations (__reduce__, __setstate__) @@ -109,17 +109,15 @@ def intercept_reduce_call(self, callable_obj, args, **kwargs): >>> class SafeDeserializationPolicy(DeserializationPolicy): ... ALLOWED_MODULES = {'builtins', 'datetime', 'decimal'} ... - ... def validate_module(self, module_name, **kwargs): + ... def validate_module(self, module_name, is_local, **kwargs): ... # Reject imports from disallowed modules ... if module_name.split('.')[0] not in self.ALLOWED_MODULES: ... raise ValueError(f"Module {module_name} is not allowed") - ... return None # Accept ... ... def validate_class(self, cls, is_local, **kwargs): ... # Reject dangerous built-in classes ... if cls.__name__ in ('eval', 'exec', 'compile'): ... raise ValueError(f"Class {cls} is forbidden") - ... return None # Accept ... ... def intercept_reduce_call(self, callable_obj, args, **kwargs): ... # Log all __reduce__ callables for audit @@ -201,7 +199,7 @@ def validate_class(self, cls, *, is_local: bool, **kwargs): This hook is called after a class reference has been deserialized (either by importing from a module or reconstructing a local class), but before it is used. - It allows inspection, replacement, or rejection of class references. + It allows inspection or rejection of class references. When Called ----------- @@ -212,9 +210,8 @@ def validate_class(self, cls, *, is_local: bool, **kwargs): Security Use Cases ------------------ - Block dangerous classes (subprocess.Popen, os.system, etc.) - - Replace untrusted classes with safe alternatives - Validate that local classes match expected signatures - - Implement class versioning or adaptation logic + - Audit class imports for security logging Args: cls (type): The deserialized class object. @@ -223,29 +220,20 @@ def validate_class(self, cls, *, is_local: bool, **kwargs): class from an importable module. **kwargs: Reserved for future extensions. - Returns: - None: Return None to accept the class as-is. - type: Return a different class to replace the original. The replacement - class will be used instead for deserialization. - Raises: Exception: Raise any exception to reject the class and abort deserialization. Example: - >>> class ClassAdapter(DeserializationPolicy): + >>> class ClassChecker(DeserializationPolicy): ... def validate_class(self, cls, is_local, **kwargs): - ... # Map a serialized class name to the current class. - ... if cls.__name__ == 'ArchivedUserClass': - ... return NewUserClass ... # Block dangerous classes ... if cls.__module__ == 'subprocess': ... raise ValueError("subprocess classes not allowed") - ... return None # Accept Note: `check_class` is an alias for this hook. """ - pass + return None def validate_function(self, func, is_local: bool, **kwargs): """Validate a deserialized function reference. @@ -263,7 +251,6 @@ def validate_function(self, func, is_local: bool, **kwargs): ------------------ - Block dangerous built-in functions (eval, exec, compile, __import__) - Validate that reconstructed functions have expected signatures - - Replace untrusted functions with safe alternatives - Audit function imports for security logging Args: @@ -272,10 +259,6 @@ def validate_function(self, func, is_local: bool, **kwargs): within a function scope), False if it's a global function. **kwargs: Reserved for future extensions. - Returns: - None: Return None to accept the function as-is. - function: Return a different function to replace the original. - Raises: Exception: Raise any exception to reject the function. @@ -286,12 +269,11 @@ def validate_function(self, func, is_local: bool, **kwargs): ... def validate_function(self, func, is_local, **kwargs): ... if func.__name__ in self.BLOCKED: ... raise ValueError(f"Function {func.__name__} is forbidden") - ... return None Note: `check_function` is an alias for this hook. """ - pass + return None def validate_method(self, method, is_local: bool, **kwargs): """Validate a deserialized method reference. @@ -309,17 +291,13 @@ def validate_method(self, method, is_local: bool, **kwargs): ------------------ - Validate that methods belong to expected classes - Block methods that could perform dangerous operations - - Replace methods with safer alternatives + - Audit method references for security logging Args: method (method): The deserialized bound method object. is_local (bool): True if the method's class is local, False if global. **kwargs: Reserved for future extensions. - Returns: - None: Return None to accept the method as-is. - method: Return a different method to replace the original. - Raises: Exception: Raise any exception to reject the method. @@ -329,39 +307,35 @@ def validate_method(self, method, is_local: bool, **kwargs): ... # Block methods from dangerous classes ... if method.__self__.__class__.__name__ == 'FileRemover': ... raise ValueError("FileRemover methods not allowed") - ... return None Note: `check_method` is an alias for this hook. """ - pass + return None - def validate_module(self, module_name: str, **kwargs): + def validate_module(self, module_name: str, *, is_local: bool, **kwargs): """Validate a deserialized module reference. - This hook is called after a module has been imported during deserialization, - but before it is used. + This hook is called before a module is imported during deserialization. When Called ----------- - - After importing modules via importlib.import_module() - - Before the module is stored or its contents accessed + - Before importing modules via importlib.import_module() + - Before the module is stored or its contents are accessed Security Use Cases ------------------ - Whitelist/blacklist modules by name or prefix - Prevent imports of system modules (os, subprocess, sys, etc.) - - Replace modules with safe alternatives or mocks - Audit module imports for security logging Args: - module_name (str): The name of the imported module (e.g., 'os.path'). + module_name (str): The name of the module to import (e.g., 'os.path'). + is_local (bool): True if the reference being resolved is local (defined + in __main__ or within a function/method scope), False + otherwise. **kwargs: Reserved for future extensions. - Returns: - None: Return None to accept the module as-is. - module: Return a different module object to replace the original. - Raises: Exception: Raise any exception to reject the module import. @@ -369,16 +343,15 @@ def validate_module(self, module_name: str, **kwargs): >>> class ModuleWhitelistChecker(DeserializationPolicy): ... ALLOWED = {'builtins', 'datetime', 'decimal', 'collections'} ... - ... def validate_module(self, module_name, **kwargs): + ... def validate_module(self, module_name, is_local, **kwargs): ... root = module_name.split('.')[0] ... if root not in self.ALLOWED: ... raise ValueError(f"Module {module_name} not whitelisted") - ... return None Note: `check_module` is an alias for this hook. """ - pass + return None # ============================================================================ # Protocol Interception Hooks (Interceptors) diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index 2736599b57..e2cdf7994f 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -46,23 +46,18 @@ from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION -def _import_validated_module(policy, module_name): - result = policy.validate_module(module_name) - if result is not None: - if isinstance(result, types.ModuleType): - return result - assert isinstance(result, str), f"validate_module must return module, str, or None, got {type(result)}" - module_name = result +def _import_validated_module(policy, module_name, is_local=False): + policy.validate_module(module_name, is_local=is_local) return importlib.import_module(module_name) -def _resolve_validated_module_attr(policy, module_name, attr_name): - module = _import_validated_module(policy, module_name) +def _resolve_validated_module_attr(policy, module_name, attr_name, is_local=False): + module = _import_validated_module(policy, module_name, is_local=is_local) return getattr(module, attr_name) def _resolve_validated_module_qualname(policy, module_name, qualname): - obj = _import_validated_module(policy, module_name) + obj = _import_validated_module(policy, module_name, is_local=_is_local_qualname(module_name, qualname)) for name in qualname.split("."): obj = getattr(obj, name) return obj @@ -111,21 +106,15 @@ def _is_bound_method_value(obj): def _validate_function_value(policy, func, is_local): if isinstance(func, type): - result = policy.validate_class(func, is_local=is_local) - if result is not None: - func = result + policy.validate_class(func, is_local=is_local) if isinstance(func, type): raise TypeError(f"Function serializer resolved class {func.__module__}.{func.__qualname__}") if _is_bound_method_value(func): - result = policy.validate_method(func, is_local=is_local) - if result is not None: - func = result + policy.validate_method(func, is_local=is_local) return func if not callable(func): raise TypeError(f"Function serializer resolved non-callable object {func!r}") - result = policy.validate_function(func, is_local=is_local) - if result is not None: - func = result + policy.validate_function(func, is_local=is_local) return func @@ -159,9 +148,7 @@ def _resolve_validated_bound_method(policy, obj, method_name, is_local): if policy is DEFAULT_POLICY: return getattr(obj, method_name) method = _bind_static_method(obj, method_name) - result = policy.validate_method(method, is_local=is_local) - if result is not None: - method = result + policy.validate_method(method, is_local=is_local) return method @@ -1214,15 +1201,12 @@ def __init__(self, type_resolver, cls): self._getnewargs = getattr(cls, "__getnewargs__", None) def _validate_global_object(self, policy, obj): - result = None if isinstance(obj, type): - result = policy.validate_class(obj, is_local=_is_local_class(obj)) + policy.validate_class(obj, is_local=_is_local_class(obj)) elif _is_bound_method_value(obj): - result = policy.validate_method(obj, is_local=_is_local_callable(obj)) + policy.validate_method(obj, is_local=_is_local_callable(obj)) elif isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)): - result = policy.validate_function(obj, is_local=_is_local_callable(obj)) - if result is not None: - obj = result + policy.validate_function(obj, is_local=_is_local_callable(obj)) return obj def _resolve_global_name(self, read_context, global_name): @@ -1232,7 +1216,12 @@ def _resolve_global_name(self, read_context, global_name): else: module_name, obj_name = "builtins", global_name try: - obj = _resolve_validated_module_attr(policy, module_name, obj_name) + obj = _resolve_validated_module_attr( + policy, + module_name, + obj_name, + is_local=_is_local_qualname(module_name, obj_name), + ) except AttributeError: raise ValueError(f"Cannot resolve global name: {global_name}") return self._validate_global_object(policy, obj) @@ -1370,9 +1359,7 @@ def read(self, read_context): module_name = read_context.read_string() qualname = read_context.read_string() cls = _resolve_validated_module_qualname(read_context.policy, module_name, qualname) - result = read_context.policy.validate_class(cls, is_local=_is_local_class(cls)) - if result is not None: - cls = result + read_context.policy.validate_class(cls, is_local=_is_local_class(cls)) return cls def _serialize_local_class(self, write_context, cls): @@ -1422,9 +1409,7 @@ def _deserialize_local_class(self, read_context): read_context.policy.authorize_instantiation(type, module=module, qualname=qualname, bases=bases) cls = type(name, bases, {}) read_context.set_read_ref(ref_id, cls) - result = read_context.policy.validate_class(cls, is_local=True) - if result is not None: - cls = result + read_context.policy.validate_class(cls, is_local=True) num_class_methods = read_context.read_var_uint32() _check_collection_size(read_context, num_class_methods, "local class method") @@ -1440,9 +1425,7 @@ def _deserialize_local_class(self, read_context): # Set module and qualname cls.__module__ = module cls.__qualname__ = qualname - result = read_context.policy.validate_class(cls, is_local=True) - if result is not None: - cls = result + read_context.policy.validate_class(cls, is_local=True) return cls @@ -1457,7 +1440,7 @@ def write(self, buffer, value): def read(self, read_context): mod_name = read_context.read_string() - return _import_validated_module(read_context.policy, mod_name) + return _import_validated_module(read_context.policy, mod_name, is_local=_is_local_qualname(mod_name, "")) class MappingProxySerializer(Serializer): @@ -1617,7 +1600,7 @@ def _deserialize_function(self, read_context): module = read_context.read_string() qualname = read_context.read_string() - mod = _import_validated_module(read_context.policy, module) + mod = _import_validated_module(read_context.policy, module, is_local=_is_local_qualname(module, qualname)) name = qualname.rsplit(".")[-1] marshalled_code = read_context.read_bytes_and_size() @@ -1699,7 +1682,12 @@ def read(self, read_context): name = read_context.read_string() if read_context.read_bool(): module = read_context.read_string() - func = _resolve_validated_module_attr(read_context.policy, module, name) + func = _resolve_validated_module_attr( + read_context.policy, + module, + name, + is_local=_is_local_qualname(module, name), + ) func = _validate_function_value(read_context.policy, func, is_local=_is_local_callable(func)) else: obj = read_context.read_ref() diff --git a/python/pyfory/tests/test_policy.py b/python/pyfory/tests/test_policy.py index 9eb7d55f0c..dc0737369d 100644 --- a/python/pyfory/tests/test_policy.py +++ b/python/pyfory/tests/test_policy.py @@ -30,6 +30,10 @@ def policy_global_function(): return "safe" +def policy_replacement_function(): + return "replacement" + + class PolicyMethodHolder: def run(self): return "safe" @@ -491,26 +495,26 @@ def test_validate_module(): import json import collections - # Test 1: Return module object directly class ReturnModulePolicy(DeserializationPolicy): - def validate_module(self, module_name, **kwargs): + def validate_module(self, module_name, is_local, **kwargs): + assert not is_local return collections fory1 = Fory(xlang=False, ref=True, strict=False, policy=ReturnModulePolicy()) data = fory1.serialize(json) - assert fory1.deserialize(data) is collections + assert fory1.deserialize(data) is json - # Test 2: Return string to redirect import class RedirectPolicy(DeserializationPolicy): - def validate_module(self, module_name, **kwargs): + def validate_module(self, module_name, is_local, **kwargs): + assert not is_local return "collections" if module_name == "json" else None fory2 = Fory(xlang=False, ref=True, strict=False, policy=RedirectPolicy()) - assert fory2.deserialize(fory2.serialize(json)).__name__ == "collections" + assert fory2.deserialize(fory2.serialize(json)).__name__ == "json" - # Test 3: Raise to block module class BlockPolicy(DeserializationPolicy): - def validate_module(self, module_name, **kwargs): + def validate_module(self, module_name, is_local, **kwargs): + assert not is_local raise ValueError(f"Module {module_name} blocked") fory3 = Fory(xlang=False, ref=True, strict=False, policy=BlockPolicy()) @@ -518,6 +522,63 @@ def validate_module(self, module_name, **kwargs): fory3.deserialize(fory3.serialize(json)) +def test_validator_returns_ignored(): + import json + import collections + + class ReplacementClass: + pass + + class ReturnPolicy(DeserializationPolicy): + def validate_module(self, module_name, is_local, **kwargs): + assert not is_local + return collections + + def validate_class(self, cls, is_local, **kwargs): + return ReplacementClass + + def validate_function(self, func, is_local, **kwargs): + return policy_replacement_function + + def validate_method(self, method, is_local, **kwargs): + return policy_replacement_function + + policy = ReturnPolicy() + fory = Fory(xlang=False, ref=True, strict=False, policy=policy) + assert fory.deserialize(fory.serialize(json)) is json + assert fory.deserialize(fory.serialize(PolicyGlobalClass)) is PolicyGlobalClass + assert fory.deserialize(fory.serialize(policy_global_function)) is policy_global_function + + serializer = FunctionSerializer(fory.type_resolver, type(policy_global_function)) + read_context = FakeReadContext(policy, [1, __name__, "policy_global_bound_method"]) + assert serializer._deserialize_function(read_context) is policy_global_bound_method + + +def test_local_class_return_ignored(): + class SafeClass: + @classmethod + def run(cls): + return "safe" + + def make_payload_class(): + class PayloadClass: + @classmethod + def run(cls): + return "payload" + + return PayloadClass + + class ReturnClassPolicy(DeserializationPolicy): + def validate_class(self, cls, is_local, **kwargs): + return SafeClass if is_local else None + + fory = Fory(xlang=False, ref=True, strict=False, policy=ReturnClassPolicy()) + decoded = fory.deserialize(fory.serialize(make_payload_class())) + assert decoded is not SafeClass + assert decoded.run() == "payload" + assert SafeClass.run() == "safe" + + def test_type_deserialization_validates_module(): """Test validate_module policy hook for global class deserialization.""" import subprocess @@ -525,9 +586,11 @@ def test_type_deserialization_validates_module(): class BlockModulePolicy(DeserializationPolicy): def __init__(self): self.validate_module_calls = 0 + self.is_local_values = [] - def validate_module(self, module_name, **kwargs): + def validate_module(self, module_name, is_local, **kwargs): self.validate_module_calls += 1 + self.is_local_values.append(is_local) if module_name == "subprocess": raise ValueError("subprocess blocked") return None @@ -537,6 +600,7 @@ def validate_module(self, module_name, **kwargs): with pytest.raises(ValueError, match="subprocess blocked"): fory.deserialize(fory.serialize(subprocess.Popen)) assert policy.validate_module_calls == 1 + assert policy.is_local_values == [False] def test_native_bound_method_uses_validate_method(): @@ -818,9 +882,11 @@ def test_global_function_deserialization_validates_module(): class BlockModulePolicy(DeserializationPolicy): def __init__(self): self.validate_module_calls = 0 + self.is_local_values = [] - def validate_module(self, module_name, **kwargs): + def validate_module(self, module_name, is_local, **kwargs): self.validate_module_calls += 1 + self.is_local_values.append(is_local) if module_name == policy_global_function.__module__: raise ValueError("function module blocked") return None @@ -830,6 +896,7 @@ def validate_module(self, module_name, **kwargs): with pytest.raises(ValueError, match="function module blocked"): fory.deserialize(fory.serialize(policy_global_function)) assert policy.validate_module_calls == 1 + assert policy.is_local_values == [False] def test_local_function_deserialization_validates_module(): @@ -841,9 +908,11 @@ def local_function(): class BlockModulePolicy(DeserializationPolicy): def __init__(self): self.validate_module_calls = 0 + self.is_local_values = [] - def validate_module(self, module_name, **kwargs): + def validate_module(self, module_name, is_local, **kwargs): self.validate_module_calls += 1 + self.is_local_values.append(is_local) if module_name == local_function.__module__: raise ValueError("local function module blocked") return None @@ -853,6 +922,7 @@ def validate_module(self, module_name, **kwargs): with pytest.raises(ValueError, match="local function module blocked"): fory.deserialize(fory.serialize(local_function)) assert policy.validate_module_calls == 1 + assert policy.is_local_values == [True] def test_native_function_deserialization_validates_module(): @@ -862,9 +932,11 @@ def test_native_function_deserialization_validates_module(): class BlockModulePolicy(DeserializationPolicy): def __init__(self): self.validate_module_calls = 0 + self.is_local_values = [] - def validate_module(self, module_name, **kwargs): + def validate_module(self, module_name, is_local, **kwargs): self.validate_module_calls += 1 + self.is_local_values.append(is_local) if module_name == "time": raise ValueError("time blocked") return None @@ -874,6 +946,7 @@ def validate_module(self, module_name, **kwargs): with pytest.raises(ValueError, match="time blocked"): fory.deserialize(fory.serialize(time.time)) assert policy.validate_module_calls == 1 + assert policy.is_local_values == [False] def test_type_metadata_load_validates_module(): @@ -882,9 +955,11 @@ def test_type_metadata_load_validates_module(): class BlockModulePolicy(DeserializationPolicy): def __init__(self): self.validate_module_calls = 0 + self.is_local_values = [] - def validate_module(self, module_name, **kwargs): + def validate_module(self, module_name, is_local, **kwargs): self.validate_module_calls += 1 + self.is_local_values.append(is_local) if module_name == "subprocess": raise ValueError("subprocess blocked") return None @@ -902,6 +977,7 @@ def validate_module(self, module_name, **kwargs): with pytest.raises(ValueError, match="subprocess blocked"): resolver._load_metabytes_to_type_info(ns_metabytes, type_metabytes) assert policy.validate_module_calls == 1 + assert policy.is_local_values == [False] def test_type_metadata_load_validates_class(): @@ -942,9 +1018,11 @@ def __reduce__(self): class BlockModulePolicy(DeserializationPolicy): def __init__(self): self.validate_module_calls = 0 + self.is_local_values = [] - def validate_module(self, module_name, **kwargs): + def validate_module(self, module_name, is_local, **kwargs): self.validate_module_calls += 1 + self.is_local_values.append(is_local) if module_name == "subprocess": raise ValueError(f"Module {module_name} blocked") return None @@ -954,6 +1032,7 @@ def validate_module(self, module_name, **kwargs): with pytest.raises(ValueError, match="subprocess blocked"): fory.deserialize(fory.serialize(GlobalNamePayload())) assert policy.validate_module_calls == 1 + assert policy.is_local_values == [False] def test_reduce_global_name_validates_class(): @@ -968,8 +1047,9 @@ def __init__(self): self.validate_module_calls = 0 self.validate_class_calls = 0 - def validate_module(self, module_name, **kwargs): + def validate_module(self, module_name, is_local, **kwargs): self.validate_module_calls += 1 + assert not is_local return None def validate_class(self, cls, is_local, **kwargs): @@ -998,8 +1078,9 @@ def __init__(self): self.validate_module_calls = 0 self.validate_function_calls = 0 - def validate_module(self, module_name, **kwargs): + def validate_module(self, module_name, is_local, **kwargs): self.validate_module_calls += 1 + assert not is_local return None def validate_function(self, func, is_local, **kwargs): @@ -1029,8 +1110,9 @@ def __init__(self): self.validate_method_calls = 0 self.validate_function_calls = 0 - def validate_module(self, module_name, **kwargs): + def validate_module(self, module_name, is_local, **kwargs): self.validate_module_calls += 1 + assert not is_local return None def validate_method(self, method, is_local, **kwargs): diff --git a/python/pyfory/type_util.py b/python/pyfory/type_util.py index d0e8b02a93..ca1a750dc9 100644 --- a/python/pyfory/type_util.py +++ b/python/pyfory/type_util.py @@ -18,7 +18,6 @@ import dataclasses import importlib import inspect -import types import typing from typing import TypeVar @@ -367,34 +366,20 @@ def qualified_class_name(cls): def load_class(classname: str, policy=None): mod_name, cls_name = classname.rsplit("#", 1) + is_local = mod_name == "__main__" or "" in cls_name if policy is not None: - result = policy.validate_module(mod_name) - if result is not None: - if isinstance(result, str): - mod_name = result - mod = None - else: - assert isinstance(result, types.ModuleType), f"validate_module must return module, str, or None, got {type(result)}" - mod = result - else: - mod = None - else: - mod = None - if mod is None: - try: - mod = importlib.import_module(mod_name) - except ImportError as ex: - raise Exception(f"Can't import module {mod_name}") from ex + policy.validate_module(mod_name, is_local=is_local) + try: + mod = importlib.import_module(mod_name) + except ImportError as ex: + raise Exception(f"Can't import module {mod_name}") from ex try: classes = cls_name.split(".") cls = getattr(mod, classes.pop(0)) while classes: cls = getattr(cls, classes.pop(0)) if policy is not None: - is_local = cls.__module__ == "__main__" or "" in cls.__qualname__ - result = policy.validate_class(cls, is_local=is_local) - if result is not None: - cls = result + policy.validate_class(cls, is_local=is_local) return cls except AttributeError as ex: raise Exception(f"Can't import class {cls_name} from module {mod_name}") from ex