Skip to content

Commit

Permalink
Implementation of __cause__ for ValidationError using ExceptionGroups (
Browse files Browse the repository at this point in the history
  • Loading branch information
zakstucke committed Sep 18, 2023
1 parent e801531 commit 245381f
Show file tree
Hide file tree
Showing 7 changed files with 355 additions and 12 deletions.
3 changes: 3 additions & 0 deletions python/pydantic_core/core_schema.py
Expand Up @@ -61,6 +61,8 @@ class CoreConfig(TypedDict, total=False):
ser_json_timedelta: The serialization option for `timedelta` values. Default is 'iso8601'.
ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'.
hide_input_in_errors: Whether to hide input data from `ValidationError` representation.
validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError.
Requires exceptiongroup backport pre Python 3.11.
"""

title: str
Expand Down Expand Up @@ -92,6 +94,7 @@ class CoreConfig(TypedDict, total=False):
ser_json_bytes: Literal['utf8', 'base64'] # default: 'utf8'
# used to hide input data from ValidationError repr
hide_input_in_errors: bool
validation_error_cause: bool # default: False


IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None'
Expand Down
101 changes: 99 additions & 2 deletions src/errors/validation_exception.rs
Expand Up @@ -3,10 +3,11 @@ use std::fmt::{Display, Write};
use std::str::from_utf8;

use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError};
use pyo3::intern;
use pyo3::ffi;
use pyo3::once_cell::GILOnceCell;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString};
use pyo3::{intern, AsPyPointer};
use serde::ser::{Error, SerializeMap, SerializeSeq};
use serde::{Serialize, Serializer};

Expand Down Expand Up @@ -51,6 +52,7 @@ impl ValidationError {
error: ValError,
outer_location: Option<LocItem>,
hide_input: bool,
validation_error_cause: bool,
) -> PyErr {
match error {
ValError::LineErrors(raw_errors) => {
Expand All @@ -61,9 +63,19 @@ impl ValidationError {
.collect(),
None => raw_errors.into_iter().map(|e| e.into_py(py)).collect(),
};

let validation_error = Self::new(line_errors, title, error_mode, hide_input);

match Py::new(py, validation_error) {
Ok(err) => PyErr::from_value(err.into_ref(py)),
Ok(err) => {
if validation_error_cause {
// Will return an import error if the backport was needed and not installed:
if let Some(cause_problem) = ValidationError::maybe_add_cause(err.borrow(py), py) {
return cause_problem;
}
}
PyErr::from_value(err.as_ref(py))
}
Err(err) => err,
}
}
Expand Down Expand Up @@ -93,6 +105,91 @@ impl ValidationError {
pub fn use_default_error() -> PyErr {
py_schema_error_type!("Uncaught UseDefault error, please check your usage of `default` validators.")
}

fn maybe_add_cause(self_: PyRef<'_, Self>, py: Python) -> Option<PyErr> {
let mut user_py_errs = vec![];
for line_error in &self_.line_errors {
if let ErrorType::AssertionError {
error: Some(err),
context: _,
}
| ErrorType::ValueError {
error: Some(err),
context: _,
} = &line_error.error_type
{
let note: PyObject = if let Location::Empty = &line_error.location {
"Pydantic: cause of loc: root".into_py(py)
} else {
format!(
"Pydantic: cause of loc: {}",
// Location formats with a newline at the end, hence the trim()
line_error.location.to_string().trim()
)
.into_py(py)
};

// Notes only support 3.11 upwards:
#[cfg(Py_3_11)]
{
// Add the location context as a note, no direct c api for this,
// fine performance wise, add_note() goes directly to C: "(PyCFunction)BaseException_add_note":
// https://github.com/python/cpython/blob/main/Objects/exceptions.c
if err.call_method1(py, "add_note", (format!("\n{note}"),)).is_ok() {
user_py_errs.push(err.clone_ref(py));
}
}

// Pre 3.11 notes support, use a UserWarning exception instead:
#[cfg(not(Py_3_11))]
{
use pyo3::exceptions::PyUserWarning;

let wrapped = PyUserWarning::new_err((note,));
wrapped.set_cause(py, Some(PyErr::from_value(err.as_ref(py))));
user_py_errs.push(wrapped);
}
}
}

// Only add the cause if there are actually python user exceptions to show:
if !user_py_errs.is_empty() {
let title = "Pydantic User Code Exceptions";

// Native ExceptionGroup(s) only supported 3.11 and later:
#[cfg(Py_3_11)]
let cause = {
use pyo3::exceptions::PyBaseExceptionGroup;
Some(PyBaseExceptionGroup::new_err((title, user_py_errs)).into_py(py))
};

// Pre 3.11 ExceptionGroup support, use the python backport instead:
// If something's gone wrong with the backport, just don't add the cause:
#[cfg(not(Py_3_11))]
let cause = {
use pyo3::exceptions::PyImportError;
match py.import("exceptiongroup") {
Ok(py_mod) => match py_mod.getattr("ExceptionGroup") {
Ok(group_cls) => match group_cls.call1((title, user_py_errs)) {
Ok(group_instance) => Some(group_instance.into_py(py)),
Err(_) => None,
},
Err(_) => None,
},
Err(_) => return Some(PyImportError::new_err("validation_error_cause flag requires the exceptiongroup module backport to be installed when used on Python <3.11.")),
}
};

// Set the cause to the ValidationError:
if let Some(cause) = cause {
unsafe {
// PyException_SetCause _steals_ a reference to cause, so must use .into_ptr()
ffi::PyException_SetCause(self_.as_ptr(), cause.into_ptr());
}
}
}
None
}
}

static URL_ENV_VAR: GILOnceCell<bool> = GILOnceCell::new();
Expand Down
18 changes: 12 additions & 6 deletions src/validators/function.rs
Expand Up @@ -288,6 +288,7 @@ pub struct FunctionWrapValidator {
field_name: Option<Py<PyString>>,
info_arg: bool,
hide_input_in_errors: bool,
validation_error_cause: bool,
}

impl BuildValidator for FunctionWrapValidator {
Expand All @@ -302,6 +303,7 @@ impl BuildValidator for FunctionWrapValidator {
let validator = build_validator(schema.get_as_req(intern!(py, "schema"))?, config, definitions)?;
let function_info = destructure_function_schema(schema)?;
let hide_input_in_errors: bool = config.get_as(intern!(py, "hide_input_in_errors"))?.unwrap_or(false);
let validation_error_cause: bool = config.get_as(intern!(py, "validation_error_cause"))?.unwrap_or(false);
Ok(Self {
validator: Box::new(validator),
func: function_info.function.clone(),
Expand All @@ -313,6 +315,7 @@ impl BuildValidator for FunctionWrapValidator {
field_name: function_info.field_name.clone(),
info_arg: function_info.info_arg,
hide_input_in_errors,
validation_error_cause,
}
.into())
}
Expand Down Expand Up @@ -356,6 +359,7 @@ impl Validator for FunctionWrapValidator {
&self.validator,
state,
self.hide_input_in_errors,
self.validation_error_cause,
),
};
self._validate(
Expand All @@ -381,6 +385,7 @@ impl Validator for FunctionWrapValidator {
&self.validator,
state,
self.hide_input_in_errors,
self.validation_error_cause,
),
updated_field_name: field_name.to_string(),
updated_field_value: field_value.to_object(py),
Expand Down Expand Up @@ -478,12 +483,12 @@ impl AssignmentValidatorCallable {
}

macro_rules! py_err_string {
($error_value:expr, $type_member:ident, $input:ident) => {
($py:expr, $py_err:expr, $error_value:expr, $type_member:ident, $input:ident) => {
match $error_value.str() {
Ok(py_string) => match py_string.to_str() {
Ok(_) => ValError::new(
ErrorType::$type_member {
error: Some($error_value.into()),
error: Some($py_err.into_py($py)),
context: None,
},
$input,
Expand All @@ -499,17 +504,18 @@ macro_rules! py_err_string {
/// as validation errors, `TypeError` is now considered as a runtime error to catch errors in function signatures
pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) -> ValError<'a> {
if err.is_instance_of::<PyValueError>(py) {
if let Ok(pydantic_value_error) = err.value(py).extract::<PydanticCustomError>() {
let error_value = err.value(py);
if let Ok(pydantic_value_error) = error_value.extract::<PydanticCustomError>() {
pydantic_value_error.into_val_error(input)
} else if let Ok(pydantic_error_type) = err.value(py).extract::<PydanticKnownError>() {
} else if let Ok(pydantic_error_type) = error_value.extract::<PydanticKnownError>() {
pydantic_error_type.into_val_error(input)
} else if let Ok(validation_error) = err.value(py).extract::<ValidationError>() {
validation_error.into_val_error(py)
} else {
py_err_string!(err.value(py), ValueError, input)
py_err_string!(py, err, error_value, ValueError, input)
}
} else if err.is_instance_of::<PyAssertionError>(py) {
py_err_string!(err.value(py), AssertionError, input)
py_err_string!(py, err, err.value(py), AssertionError, input)
} else if err.is_instance_of::<PydanticOmit>(py) {
ValError::Omit
} else if err.is_instance_of::<PydanticUseDefault>(py) {
Expand Down
29 changes: 25 additions & 4 deletions src/validators/generator.rs
Expand Up @@ -19,6 +19,7 @@ pub struct GeneratorValidator {
max_length: Option<usize>,
name: String,
hide_input_in_errors: bool,
validation_error_cause: bool,
}

impl BuildValidator for GeneratorValidator {
Expand All @@ -37,12 +38,16 @@ impl BuildValidator for GeneratorValidator {
let hide_input_in_errors: bool = config
.get_as(pyo3::intern!(schema.py(), "hide_input_in_errors"))?
.unwrap_or(false);
let validation_error_cause: bool = config
.get_as(pyo3::intern!(schema.py(), "validation_error_cause"))?
.unwrap_or(false);
Ok(Self {
item_validator,
name,
min_length: schema.get_as(pyo3::intern!(schema.py(), "min_length"))?,
max_length: schema.get_as(pyo3::intern!(schema.py(), "max_length"))?,
hide_input_in_errors,
validation_error_cause,
}
.into())
}
Expand All @@ -58,17 +63,24 @@ impl Validator for GeneratorValidator {
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let iterator = input.validate_iter()?;
let validator = self
.item_validator
.as_ref()
.map(|v| InternalValidator::new(py, "ValidatorIterator", v, state, self.hide_input_in_errors));
let validator = self.item_validator.as_ref().map(|v| {
InternalValidator::new(
py,
"ValidatorIterator",
v,
state,
self.hide_input_in_errors,
self.validation_error_cause,
)
});

let v_iterator = ValidatorIterator {
iterator,
validator,
min_length: self.min_length,
max_length: self.max_length,
hide_input_in_errors: self.hide_input_in_errors,
validation_error_cause: self.validation_error_cause,
};
Ok(v_iterator.into_py(py))
}
Expand Down Expand Up @@ -105,6 +117,7 @@ struct ValidatorIterator {
min_length: Option<usize>,
max_length: Option<usize>,
hide_input_in_errors: bool,
validation_error_cause: bool,
}

#[pymethods]
Expand All @@ -117,6 +130,7 @@ impl ValidatorIterator {
let min_length = slf.min_length;
let max_length = slf.max_length;
let hide_input_in_errors = slf.hide_input_in_errors;
let validation_error_cause = slf.validation_error_cause;
let Self {
validator, iterator, ..
} = &mut *slf;
Expand All @@ -143,6 +157,7 @@ impl ValidatorIterator {
val_error,
None,
hide_input_in_errors,
validation_error_cause,
));
}
}
Expand All @@ -169,6 +184,7 @@ impl ValidatorIterator {
val_error,
None,
hide_input_in_errors,
validation_error_cause,
));
}
}
Expand Down Expand Up @@ -217,6 +233,7 @@ pub struct InternalValidator {
recursion_guard: RecursionGuard,
validation_mode: InputType,
hide_input_in_errors: bool,
validation_error_cause: bool,
}

impl fmt::Debug for InternalValidator {
Expand All @@ -232,6 +249,7 @@ impl InternalValidator {
validator: &CombinedValidator,
state: &ValidationState,
hide_input_in_errors: bool,
validation_error_cause: bool,
) -> Self {
let extra = state.extra();
Self {
Expand All @@ -246,6 +264,7 @@ impl InternalValidator {
recursion_guard: state.recursion_guard.clone(),
validation_mode: extra.mode,
hide_input_in_errors,
validation_error_cause,
}
}

Expand Down Expand Up @@ -277,6 +296,7 @@ impl InternalValidator {
e,
outer_location,
self.hide_input_in_errors,
self.validation_error_cause,
)
})
}
Expand Down Expand Up @@ -305,6 +325,7 @@ impl InternalValidator {
e,
outer_location,
self.hide_input_in_errors,
self.validation_error_cause,
)
})
}
Expand Down
5 changes: 5 additions & 0 deletions src/validators/mod.rs
Expand Up @@ -107,6 +107,7 @@ pub struct SchemaValidator {
#[pyo3(get)]
title: PyObject,
hide_input_in_errors: bool,
validation_error_cause: bool,
}

#[pymethods]
Expand All @@ -133,12 +134,14 @@ impl SchemaValidator {
None => validator.get_name().into_py(py),
};
let hide_input_in_errors: bool = config.get_as(intern!(py, "hide_input_in_errors"))?.unwrap_or(false);
let validation_error_cause: bool = config.get_as(intern!(py, "validation_error_cause"))?.unwrap_or(false);
Ok(Self {
validator,
definitions,
schema: schema.into_py(py),
title,
hide_input_in_errors,
validation_error_cause,
})
}

Expand Down Expand Up @@ -329,6 +332,7 @@ impl SchemaValidator {
error,
None,
self.hide_input_in_errors,
self.validation_error_cause,
)
}
}
Expand Down Expand Up @@ -385,6 +389,7 @@ impl<'py> SelfValidator<'py> {
schema: py.None(),
title: "Self Schema".into_py(py),
hide_input_in_errors: false,
validation_error_cause: false,
})
}
}
Expand Down
1 change: 1 addition & 0 deletions tests/requirements.txt
Expand Up @@ -14,3 +14,4 @@ pytest-timeout==2.1.0
pytz==2023.3.post1
# numpy doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux
numpy==1.25.2; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64'
exceptiongroup==1.1; python_version < "3.11"

0 comments on commit 245381f

Please sign in to comment.