From 245381f981ca3e2b146886e0ffc8011540ba7f6d Mon Sep 17 00:00:00 2001 From: zakstucke <44890343+zakstucke@users.noreply.github.com> Date: Mon, 18 Sep 2023 12:23:55 +0200 Subject: [PATCH] Implementation of __cause__ for ValidationError using ExceptionGroups (#780) --- python/pydantic_core/core_schema.py | 3 + src/errors/validation_exception.rs | 101 ++++++++++++- src/validators/function.rs | 18 ++- src/validators/generator.rs | 29 +++- src/validators/mod.rs | 5 + tests/requirements.txt | 1 + tests/test_errors.py | 210 ++++++++++++++++++++++++++++ 7 files changed, 355 insertions(+), 12 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 74442b44ca..0b4348e81d 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -61,6 +61,8 @@ class CoreConfig(TypedDict, total=False): ser_json_timedelta: The serialization option for `timedelta` values. Default is 'iso8601'. ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'. hide_input_in_errors: Whether to hide input data from `ValidationError` representation. + validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError. + Requires exceptiongroup backport pre Python 3.11. """ title: str @@ -92,6 +94,7 @@ class CoreConfig(TypedDict, total=False): ser_json_bytes: Literal['utf8', 'base64'] # default: 'utf8' # used to hide input data from ValidationError repr hide_input_in_errors: bool + validation_error_cause: bool # default: False IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None' diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 9dde6551c3..e6563f5974 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -3,10 +3,11 @@ use std::fmt::{Display, Write}; use std::str::from_utf8; use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError}; -use pyo3::intern; +use pyo3::ffi; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString}; +use pyo3::{intern, AsPyPointer}; use serde::ser::{Error, SerializeMap, SerializeSeq}; use serde::{Serialize, Serializer}; @@ -51,6 +52,7 @@ impl ValidationError { error: ValError, outer_location: Option, hide_input: bool, + validation_error_cause: bool, ) -> PyErr { match error { ValError::LineErrors(raw_errors) => { @@ -61,9 +63,19 @@ impl ValidationError { .collect(), None => raw_errors.into_iter().map(|e| e.into_py(py)).collect(), }; + let validation_error = Self::new(line_errors, title, error_mode, hide_input); + match Py::new(py, validation_error) { - Ok(err) => PyErr::from_value(err.into_ref(py)), + Ok(err) => { + if validation_error_cause { + // Will return an import error if the backport was needed and not installed: + if let Some(cause_problem) = ValidationError::maybe_add_cause(err.borrow(py), py) { + return cause_problem; + } + } + PyErr::from_value(err.as_ref(py)) + } Err(err) => err, } } @@ -93,6 +105,91 @@ impl ValidationError { pub fn use_default_error() -> PyErr { py_schema_error_type!("Uncaught UseDefault error, please check your usage of `default` validators.") } + + fn maybe_add_cause(self_: PyRef<'_, Self>, py: Python) -> Option { + let mut user_py_errs = vec![]; + for line_error in &self_.line_errors { + if let ErrorType::AssertionError { + error: Some(err), + context: _, + } + | ErrorType::ValueError { + error: Some(err), + context: _, + } = &line_error.error_type + { + let note: PyObject = if let Location::Empty = &line_error.location { + "Pydantic: cause of loc: root".into_py(py) + } else { + format!( + "Pydantic: cause of loc: {}", + // Location formats with a newline at the end, hence the trim() + line_error.location.to_string().trim() + ) + .into_py(py) + }; + + // Notes only support 3.11 upwards: + #[cfg(Py_3_11)] + { + // Add the location context as a note, no direct c api for this, + // fine performance wise, add_note() goes directly to C: "(PyCFunction)BaseException_add_note": + // https://github.com/python/cpython/blob/main/Objects/exceptions.c + if err.call_method1(py, "add_note", (format!("\n{note}"),)).is_ok() { + user_py_errs.push(err.clone_ref(py)); + } + } + + // Pre 3.11 notes support, use a UserWarning exception instead: + #[cfg(not(Py_3_11))] + { + use pyo3::exceptions::PyUserWarning; + + let wrapped = PyUserWarning::new_err((note,)); + wrapped.set_cause(py, Some(PyErr::from_value(err.as_ref(py)))); + user_py_errs.push(wrapped); + } + } + } + + // Only add the cause if there are actually python user exceptions to show: + if !user_py_errs.is_empty() { + let title = "Pydantic User Code Exceptions"; + + // Native ExceptionGroup(s) only supported 3.11 and later: + #[cfg(Py_3_11)] + let cause = { + use pyo3::exceptions::PyBaseExceptionGroup; + Some(PyBaseExceptionGroup::new_err((title, user_py_errs)).into_py(py)) + }; + + // Pre 3.11 ExceptionGroup support, use the python backport instead: + // If something's gone wrong with the backport, just don't add the cause: + #[cfg(not(Py_3_11))] + let cause = { + use pyo3::exceptions::PyImportError; + match py.import("exceptiongroup") { + Ok(py_mod) => match py_mod.getattr("ExceptionGroup") { + Ok(group_cls) => match group_cls.call1((title, user_py_errs)) { + Ok(group_instance) => Some(group_instance.into_py(py)), + Err(_) => None, + }, + Err(_) => None, + }, + Err(_) => return Some(PyImportError::new_err("validation_error_cause flag requires the exceptiongroup module backport to be installed when used on Python <3.11.")), + } + }; + + // Set the cause to the ValidationError: + if let Some(cause) = cause { + unsafe { + // PyException_SetCause _steals_ a reference to cause, so must use .into_ptr() + ffi::PyException_SetCause(self_.as_ptr(), cause.into_ptr()); + } + } + } + None + } } static URL_ENV_VAR: GILOnceCell = GILOnceCell::new(); diff --git a/src/validators/function.rs b/src/validators/function.rs index fa8a0673b2..206e10a0c3 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -288,6 +288,7 @@ pub struct FunctionWrapValidator { field_name: Option>, info_arg: bool, hide_input_in_errors: bool, + validation_error_cause: bool, } impl BuildValidator for FunctionWrapValidator { @@ -302,6 +303,7 @@ impl BuildValidator for FunctionWrapValidator { let validator = build_validator(schema.get_as_req(intern!(py, "schema"))?, config, definitions)?; let function_info = destructure_function_schema(schema)?; let hide_input_in_errors: bool = config.get_as(intern!(py, "hide_input_in_errors"))?.unwrap_or(false); + let validation_error_cause: bool = config.get_as(intern!(py, "validation_error_cause"))?.unwrap_or(false); Ok(Self { validator: Box::new(validator), func: function_info.function.clone(), @@ -313,6 +315,7 @@ impl BuildValidator for FunctionWrapValidator { field_name: function_info.field_name.clone(), info_arg: function_info.info_arg, hide_input_in_errors, + validation_error_cause, } .into()) } @@ -356,6 +359,7 @@ impl Validator for FunctionWrapValidator { &self.validator, state, self.hide_input_in_errors, + self.validation_error_cause, ), }; self._validate( @@ -381,6 +385,7 @@ impl Validator for FunctionWrapValidator { &self.validator, state, self.hide_input_in_errors, + self.validation_error_cause, ), updated_field_name: field_name.to_string(), updated_field_value: field_value.to_object(py), @@ -478,12 +483,12 @@ impl AssignmentValidatorCallable { } macro_rules! py_err_string { - ($error_value:expr, $type_member:ident, $input:ident) => { + ($py:expr, $py_err:expr, $error_value:expr, $type_member:ident, $input:ident) => { match $error_value.str() { Ok(py_string) => match py_string.to_str() { Ok(_) => ValError::new( ErrorType::$type_member { - error: Some($error_value.into()), + error: Some($py_err.into_py($py)), context: None, }, $input, @@ -499,17 +504,18 @@ macro_rules! py_err_string { /// as validation errors, `TypeError` is now considered as a runtime error to catch errors in function signatures pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) -> ValError<'a> { if err.is_instance_of::(py) { - if let Ok(pydantic_value_error) = err.value(py).extract::() { + let error_value = err.value(py); + if let Ok(pydantic_value_error) = error_value.extract::() { pydantic_value_error.into_val_error(input) - } else if let Ok(pydantic_error_type) = err.value(py).extract::() { + } else if let Ok(pydantic_error_type) = error_value.extract::() { pydantic_error_type.into_val_error(input) } else if let Ok(validation_error) = err.value(py).extract::() { validation_error.into_val_error(py) } else { - py_err_string!(err.value(py), ValueError, input) + py_err_string!(py, err, error_value, ValueError, input) } } else if err.is_instance_of::(py) { - py_err_string!(err.value(py), AssertionError, input) + py_err_string!(py, err, err.value(py), AssertionError, input) } else if err.is_instance_of::(py) { ValError::Omit } else if err.is_instance_of::(py) { diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 1047e31bdd..c52910500d 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -19,6 +19,7 @@ pub struct GeneratorValidator { max_length: Option, name: String, hide_input_in_errors: bool, + validation_error_cause: bool, } impl BuildValidator for GeneratorValidator { @@ -37,12 +38,16 @@ impl BuildValidator for GeneratorValidator { let hide_input_in_errors: bool = config .get_as(pyo3::intern!(schema.py(), "hide_input_in_errors"))? .unwrap_or(false); + let validation_error_cause: bool = config + .get_as(pyo3::intern!(schema.py(), "validation_error_cause"))? + .unwrap_or(false); Ok(Self { item_validator, name, min_length: schema.get_as(pyo3::intern!(schema.py(), "min_length"))?, max_length: schema.get_as(pyo3::intern!(schema.py(), "max_length"))?, hide_input_in_errors, + validation_error_cause, } .into()) } @@ -58,10 +63,16 @@ impl Validator for GeneratorValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let iterator = input.validate_iter()?; - let validator = self - .item_validator - .as_ref() - .map(|v| InternalValidator::new(py, "ValidatorIterator", v, state, self.hide_input_in_errors)); + let validator = self.item_validator.as_ref().map(|v| { + InternalValidator::new( + py, + "ValidatorIterator", + v, + state, + self.hide_input_in_errors, + self.validation_error_cause, + ) + }); let v_iterator = ValidatorIterator { iterator, @@ -69,6 +80,7 @@ impl Validator for GeneratorValidator { min_length: self.min_length, max_length: self.max_length, hide_input_in_errors: self.hide_input_in_errors, + validation_error_cause: self.validation_error_cause, }; Ok(v_iterator.into_py(py)) } @@ -105,6 +117,7 @@ struct ValidatorIterator { min_length: Option, max_length: Option, hide_input_in_errors: bool, + validation_error_cause: bool, } #[pymethods] @@ -117,6 +130,7 @@ impl ValidatorIterator { let min_length = slf.min_length; let max_length = slf.max_length; let hide_input_in_errors = slf.hide_input_in_errors; + let validation_error_cause = slf.validation_error_cause; let Self { validator, iterator, .. } = &mut *slf; @@ -143,6 +157,7 @@ impl ValidatorIterator { val_error, None, hide_input_in_errors, + validation_error_cause, )); } } @@ -169,6 +184,7 @@ impl ValidatorIterator { val_error, None, hide_input_in_errors, + validation_error_cause, )); } } @@ -217,6 +233,7 @@ pub struct InternalValidator { recursion_guard: RecursionGuard, validation_mode: InputType, hide_input_in_errors: bool, + validation_error_cause: bool, } impl fmt::Debug for InternalValidator { @@ -232,6 +249,7 @@ impl InternalValidator { validator: &CombinedValidator, state: &ValidationState, hide_input_in_errors: bool, + validation_error_cause: bool, ) -> Self { let extra = state.extra(); Self { @@ -246,6 +264,7 @@ impl InternalValidator { recursion_guard: state.recursion_guard.clone(), validation_mode: extra.mode, hide_input_in_errors, + validation_error_cause, } } @@ -277,6 +296,7 @@ impl InternalValidator { e, outer_location, self.hide_input_in_errors, + self.validation_error_cause, ) }) } @@ -305,6 +325,7 @@ impl InternalValidator { e, outer_location, self.hide_input_in_errors, + self.validation_error_cause, ) }) } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 5824888ec2..9a9c6a185c 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -107,6 +107,7 @@ pub struct SchemaValidator { #[pyo3(get)] title: PyObject, hide_input_in_errors: bool, + validation_error_cause: bool, } #[pymethods] @@ -133,12 +134,14 @@ impl SchemaValidator { None => validator.get_name().into_py(py), }; let hide_input_in_errors: bool = config.get_as(intern!(py, "hide_input_in_errors"))?.unwrap_or(false); + let validation_error_cause: bool = config.get_as(intern!(py, "validation_error_cause"))?.unwrap_or(false); Ok(Self { validator, definitions, schema: schema.into_py(py), title, hide_input_in_errors, + validation_error_cause, }) } @@ -329,6 +332,7 @@ impl SchemaValidator { error, None, self.hide_input_in_errors, + self.validation_error_cause, ) } } @@ -385,6 +389,7 @@ impl<'py> SelfValidator<'py> { schema: py.None(), title: "Self Schema".into_py(py), hide_input_in_errors: false, + validation_error_cause: false, }) } } diff --git a/tests/requirements.txt b/tests/requirements.txt index ce5cc9f237..ae8b9e50d8 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -14,3 +14,4 @@ pytest-timeout==2.1.0 pytz==2023.3.post1 # numpy doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux numpy==1.25.2; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64' +exceptiongroup==1.1; python_version < "3.11" diff --git a/tests/test_errors.py b/tests/test_errors.py index 881abd5081..fe71b98604 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,11 +1,15 @@ +import enum import re +import sys from decimal import Decimal from typing import Any, Optional +from unittest.mock import patch import pytest from dirty_equals import HasRepr, IsInstance, IsJson, IsStr from pydantic_core import ( + CoreConfig, PydanticCustomError, PydanticKnownError, PydanticOmit, @@ -517,6 +521,212 @@ def test_all_errors(): pytest.fail('core_schema.ErrorType needs to be updated') +@pytest.mark.skipif(sys.version_info < (3, 11), reason='This is the modern version used post 3.10.') +def test_validation_error_cause_contents(): + enabled_config: CoreConfig = {'validation_error_cause': True} + + def multi_raise_py_error(v: Any) -> Any: + try: + raise AssertionError('Wrong') + except AssertionError as e: + raise ValueError('Oh no!') from e + + s2 = SchemaValidator(core_schema.no_info_plain_validator_function(multi_raise_py_error), config=enabled_config) + with pytest.raises(ValidationError) as exc_info: + s2.validate_python('anything') + + cause_group = exc_info.value.__cause__ + assert isinstance(cause_group, BaseExceptionGroup) + assert len(cause_group.exceptions) == 1 + + cause = cause_group.exceptions[0] + assert cause.__notes__ + assert cause.__notes__[-1].startswith('\nPydantic: ') + assert repr(cause) == repr(ValueError('Oh no!')) + assert cause.__traceback__ is not None + + sub_cause = cause.__cause__ + assert repr(sub_cause) == repr(AssertionError('Wrong')) + assert sub_cause.__cause__ is None + assert sub_cause.__traceback__ is not None + + # Edge case: make sure a deep inner ValidationError(s) causing a validator failure doesn't cause any problems: + def outer_raise_py_error(v: Any) -> Any: + try: + s2.validate_python('anything') + except ValidationError as e: + raise ValueError('Sub val failure') from e + + s3 = SchemaValidator(core_schema.no_info_plain_validator_function(outer_raise_py_error), config=enabled_config) + with pytest.raises(ValidationError) as exc_info: + s3.validate_python('anything') + + assert isinstance(exc_info.value.__cause__, BaseExceptionGroup) + assert len(exc_info.value.__cause__.exceptions) == 1 + cause = exc_info.value.__cause__.exceptions[0] + assert cause.__notes__ and cause.__notes__[-1].startswith('\nPydantic: ') + assert repr(cause) == repr(ValueError('Sub val failure')) + subcause = cause.__cause__ + assert isinstance(subcause, ValidationError) + + cause_group = subcause.__cause__ + assert isinstance(cause_group, BaseExceptionGroup) + assert len(cause_group.exceptions) == 1 + + cause = cause_group.exceptions[0] + assert cause.__notes__ + assert cause.__notes__[-1].startswith('\nPydantic: ') + assert repr(cause) == repr(ValueError('Oh no!')) + assert cause.__traceback__ is not None + + sub_cause = cause.__cause__ + assert repr(sub_cause) == repr(AssertionError('Wrong')) + assert sub_cause.__cause__ is None + assert sub_cause.__traceback__ is not None + + +@pytest.mark.skipif(sys.version_info >= (3, 11), reason='This is the backport/legacy version used pre 3.11 only.') +def test_validation_error_cause_contents_legacy(): + from exceptiongroup import BaseExceptionGroup + + enabled_config: CoreConfig = {'validation_error_cause': True} + + def multi_raise_py_error(v: Any) -> Any: + try: + raise AssertionError('Wrong') + except AssertionError as e: + raise ValueError('Oh no!') from e + + s2 = SchemaValidator(core_schema.no_info_plain_validator_function(multi_raise_py_error), config=enabled_config) + with pytest.raises(ValidationError) as exc_info: + s2.validate_python('anything') + + cause_group = exc_info.value.__cause__ + assert isinstance(cause_group, BaseExceptionGroup) + assert len(cause_group.exceptions) == 1 + + cause = cause_group.exceptions[0] + assert repr(cause).startswith("UserWarning('Pydantic: ") + + assert cause.__cause__ is not None + cause = cause.__cause__ + assert repr(cause) == repr(ValueError('Oh no!')) + assert cause.__traceback__ is not None + + sub_cause = cause.__cause__ + assert repr(sub_cause) == repr(AssertionError('Wrong')) + assert sub_cause.__cause__ is None + assert sub_cause.__traceback__ is not None + + # Make sure a deep inner ValidationError(s) causing a validator failure doesn't cause any problems: + def outer_raise_py_error(v: Any) -> Any: + try: + s2.validate_python('anything') + except ValidationError as e: + raise ValueError('Sub val failure') from e + + s3 = SchemaValidator(core_schema.no_info_plain_validator_function(outer_raise_py_error), config=enabled_config) + with pytest.raises(ValidationError) as exc_info: + s3.validate_python('anything') + + assert isinstance(exc_info.value.__cause__, BaseExceptionGroup) + assert len(exc_info.value.__cause__.exceptions) == 1 + cause = exc_info.value.__cause__.exceptions[0] + assert repr(cause).startswith("UserWarning('Pydantic: ") + assert cause.__cause__ is not None + cause = cause.__cause__ + assert repr(cause) == repr(ValueError('Sub val failure')) + subcause = cause.__cause__ + assert isinstance(subcause, ValidationError) + + cause_group = subcause.__cause__ + assert isinstance(cause_group, BaseExceptionGroup) + assert len(cause_group.exceptions) == 1 + + cause = cause_group.exceptions[0] + assert repr(cause).startswith("UserWarning('Pydantic: ") + assert cause.__cause__ is not None + cause = cause.__cause__ + assert repr(cause) == repr(ValueError('Oh no!')) + assert cause.__traceback__ is not None + + sub_cause = cause.__cause__ + assert repr(sub_cause) == repr(AssertionError('Wrong')) + assert sub_cause.__cause__ is None + assert sub_cause.__traceback__ is not None + + +class CauseResult(enum.Enum): + CAUSE = enum.auto() + NO_CAUSE = enum.auto() + IMPORT_ERROR = enum.auto() + + +@pytest.mark.parametrize( + 'desc,config,expected_result', + [ # Without the backport should still work after 3.10 as not needed: + ( + 'Enabled', + {'validation_error_cause': True}, + CauseResult.CAUSE if sys.version_info >= (3, 11) else CauseResult.IMPORT_ERROR, + ), + ('Disabled specifically', {'validation_error_cause': False}, CauseResult.NO_CAUSE), + ('Disabled implicitly', {}, CauseResult.NO_CAUSE), + ], +) +def test_validation_error_cause_config_variants(desc: str, config: CoreConfig, expected_result: CauseResult): + # Simulate the package being missing: + with patch.dict('sys.modules', {'exceptiongroup': None}): + + def singular_raise_py_error(v: Any) -> Any: + raise ValueError('Oh no!') + + s = SchemaValidator(core_schema.no_info_plain_validator_function(singular_raise_py_error), config=config) + + if expected_result is CauseResult.IMPORT_ERROR: + # Confirm error message contains "requires the exceptiongroup module" in the middle of the string: + with pytest.raises(ImportError, match='requires the exceptiongroup module'): + s.validate_python('anything') + elif expected_result is CauseResult.CAUSE: + with pytest.raises(ValidationError) as exc_info: + s.validate_python('anything') + assert exc_info.value.__cause__ is not None + assert hasattr(exc_info.value.__cause__, 'exceptions') + assert len(exc_info.value.__cause__.exceptions) == 1 + assert repr(exc_info.value.__cause__.exceptions[0]) == repr(ValueError('Oh no!')) + elif expected_result is CauseResult.NO_CAUSE: + with pytest.raises(ValidationError) as exc_info: + s.validate_python('anything') + assert exc_info.value.__cause__ is None + else: + raise AssertionError('Unhandled result: {}'.format(expected_result)) + + +def test_validation_error_cause_traceback_preserved(): + """Makes sure historic bug of traceback being lost is fixed.""" + + enabled_config: CoreConfig = {'validation_error_cause': True} + + def singular_raise_py_error(v: Any) -> Any: + raise ValueError('Oh no!') + + s1 = SchemaValidator(core_schema.no_info_plain_validator_function(singular_raise_py_error), config=enabled_config) + with pytest.raises(ValidationError) as exc_info: + s1.validate_python('anything') + + base_errs = getattr(exc_info.value.__cause__, 'exceptions', []) + assert len(base_errs) == 1 + base_err = base_errs[0] + + # Get to the root error: + cause = base_err + while cause.__cause__ is not None: + cause = cause.__cause__ + + # Should still have a traceback: + assert cause.__traceback__ is not None + + class BadRepr: def __repr__(self): raise RuntimeError('bad repr')