diff --git a/Cargo.lock b/Cargo.lock index 18e99107f4..af7d64e63b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -155,6 +155,12 @@ dependencies = [ "libmimalloc-sys", ] +[[package]] +name = "nohash-hasher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" + [[package]] name = "once_cell" version = "1.10.0" @@ -203,6 +209,7 @@ dependencies = [ "enum_dispatch", "indexmap", "mimalloc", + "nohash-hasher", "pyo3", "regex", "serde", diff --git a/Cargo.toml b/Cargo.toml index adb8bf1f93..24338b51a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ indexmap = "1.8.1" mimalloc = { version = "0.1.29", default-features = false, optional = true } speedate = "0.4.1" ahash = "0.7.6" +nohash-hasher = "0.2.0" [lib] name = "_pydantic_core" diff --git a/src/errors/kinds.rs b/src/errors/kinds.rs index b155391d80..77c9197a30 100644 --- a/src/errors/kinds.rs +++ b/src/errors/kinds.rs @@ -8,6 +8,10 @@ pub enum ErrorKind { #[strum(message = "Invalid JSON: {parser_error}")] InvalidJson, // --------------------- + // recursion error + #[strum(message = "Recursion error - cyclic reference detected")] + RecursionLoop, + // --------------------- // typed dict specific errors #[strum(message = "Value must be a valid dictionary or instance to extract fields from")] DictAttributesType, diff --git a/src/errors/line_error.rs b/src/errors/line_error.rs index 9594c03f9b..8b61ae2c50 100644 --- a/src/errors/line_error.rs +++ b/src/errors/line_error.rs @@ -17,14 +17,27 @@ pub enum ValError<'a> { InternalErr(PyErr), } +impl<'a> From for ValError<'a> { + fn from(py_err: PyErr) -> Self { + Self::InternalErr(py_err) + } +} + +impl<'a> From>> for ValError<'a> { + fn from(line_errors: Vec>) -> Self { + Self::LineErrors(line_errors) + } +} + // ValError used to implement Error, see #78 for removed code +// TODO, remove and replace with just .into() pub fn as_internal<'a>(err: PyErr) -> ValError<'a> { - ValError::InternalErr(err) + err.into() } pub fn pretty_line_errors(py: Python, line_errors: Vec) -> String { - let py_line_errors: Vec = line_errors.into_iter().map(|e| PyLineError::new(py, e)).collect(); + let py_line_errors: Vec = line_errors.into_iter().map(|e| e.into_py(py)).collect(); pretty_py_line_errors(Some(py), py_line_errors.iter()) } @@ -58,7 +71,7 @@ impl<'a> ValLineError<'a> { ValLineError { kind: self.kind, reverse_location: self.reverse_location, - input_value: InputValue::PyObject(self.input_value.to_object(py)), + input_value: self.input_value.to_object(py).into(), context: self.context, } } @@ -79,6 +92,12 @@ impl Default for InputValue<'_> { } } +impl<'a> From for InputValue<'a> { + fn from(py_object: PyObject) -> Self { + Self::PyObject(py_object) + } +} + impl<'a> ToPyObject for InputValue<'a> { fn to_object(&self, py: Python) -> PyObject { match self { diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index d0285ca917..621b4c2825 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -16,7 +16,7 @@ use super::location::Location; use super::ValError; #[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")] -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ValidationError { line_errors: Vec, title: PyObject, @@ -33,7 +33,7 @@ impl ValidationError { pub fn from_val_error(py: Python, title: PyObject, error: ValError) -> PyErr { match error { ValError::LineErrors(raw_errors) => { - let line_errors: Vec = raw_errors.into_iter().map(|e| PyLineError::new(py, e)).collect(); + let line_errors: Vec = raw_errors.into_iter().map(|e| e.into_py(py)).collect(); PyErr::new::((line_errors, title)) } ValError::InternalErr(err) => err, @@ -61,6 +61,18 @@ impl Error for ValidationError { } } +// used to convert a validation error back to ValError for wrap functions +impl<'a> From for ValError<'a> { + fn from(val_error: ValidationError) -> Self { + val_error + .line_errors + .into_iter() + .map(|e| e.into()) + .collect::>() + .into() + } +} + #[pymethods] impl ValidationError { #[new] @@ -131,19 +143,36 @@ pub struct PyLineError { context: Context, } -impl PyLineError { - pub fn new(py: Python, raw_error: ValLineError) -> Self { +impl<'a> IntoPy for ValLineError<'a> { + fn into_py(self, py: Python<'_>) -> PyLineError { + PyLineError { + kind: self.kind, + location: match self.reverse_location.len() { + 0..=1 => self.reverse_location, + _ => self.reverse_location.into_iter().rev().collect(), + }, + input_value: self.input_value.to_object(py), + context: self.context, + } + } +} + +/// opposite of above, used to extract line errors from a validation error for wrap functions +impl<'a> From for ValLineError<'a> { + fn from(py_line_error: PyLineError) -> Self { Self { - kind: raw_error.kind, - location: match raw_error.reverse_location.len() { - 0..=1 => raw_error.reverse_location, - _ => raw_error.reverse_location.into_iter().rev().collect(), + kind: py_line_error.kind, + reverse_location: match py_line_error.location.len() { + 0..=1 => py_line_error.location, + _ => py_line_error.location.into_iter().rev().collect(), }, - input_value: raw_error.input_value.to_object(py), - context: raw_error.context, + input_value: py_line_error.input_value.into(), + context: py_line_error.context, } } +} +impl PyLineError { pub fn as_dict(&self, py: Python) -> PyResult { let dict = PyDict::new(py); dict.set_item("kind", self.kind())?; diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 9eeae3a896..ecfa77a52d 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -16,6 +16,10 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { fn as_error_value(&'a self) -> InputValue<'a>; + fn identity(&'a self) -> Option { + None + } + fn is_none(&self) -> bool; fn strict_str<'data>(&'data self) -> ValResult>; diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 2e9bbf5084..a4cabe4e35 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -1,12 +1,12 @@ use std::str::from_utf8; use pyo3::exceptions::{PyAttributeError, PyTypeError}; -use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{ PyBool, PyBytes, PyDate, PyDateTime, PyDict, PyFrozenSet, PyInt, PyList, PyMapping, PySequence, PySet, PyString, PyTime, PyTuple, PyType, }; +use pyo3::{intern, AsPyPointer}; use crate::errors::location::LocItem; use crate::errors::{as_internal, context, err_val_error, py_err_string, ErrorKind, InputValue, ValResult}; @@ -36,6 +36,10 @@ impl<'a> Input<'a> for PyAny { InputValue::PyAny(self) } + fn identity(&'a self) -> Option { + Some(self.as_ptr() as usize) + } + fn is_none(&self) -> bool { self.is_none() } diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index cf8902f83e..e1f37ae16a 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -4,6 +4,7 @@ use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyFrozenSet, PyList, PySet, PyString, PyTuple}; use crate::errors::{ValError, ValLineError, ValResult}; +use crate::recursion_guard::RecursionGuard; use crate::validators::{CombinedValidator, Extra, Validator}; use super::parse_json::{JsonArray, JsonObject}; @@ -41,11 +42,12 @@ macro_rules! build_validate_to_vec { validator: &'s CombinedValidator, extra: &Extra, slots: &'a [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'a, Vec> { let mut output: Vec = Vec::with_capacity(length); let mut errors: Vec = Vec::new(); for (index, item) in sequence.iter().enumerate() { - match validator.validate(py, item, extra, slots) { + match validator.validate(py, item, extra, slots, recursion_guard) { Ok(item) => output.push(item), Err(ValError::LineErrors(line_errors)) => { errors.extend( @@ -90,13 +92,22 @@ impl<'a> GenericSequence<'a> { validator: &'s CombinedValidator, extra: &Extra, slots: &'a [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'a, Vec> { match self { - Self::List(sequence) => validate_to_vec_list(py, sequence, length, validator, extra, slots), - Self::Tuple(sequence) => validate_to_vec_tuple(py, sequence, length, validator, extra, slots), - Self::Set(sequence) => validate_to_vec_set(py, sequence, length, validator, extra, slots), - Self::FrozenSet(sequence) => validate_to_vec_frozenset(py, sequence, length, validator, extra, slots), - Self::JsonArray(sequence) => validate_to_vec_jsonarray(py, sequence, length, validator, extra, slots), + Self::List(sequence) => { + validate_to_vec_list(py, sequence, length, validator, extra, slots, recursion_guard) + } + Self::Tuple(sequence) => { + validate_to_vec_tuple(py, sequence, length, validator, extra, slots, recursion_guard) + } + Self::Set(sequence) => validate_to_vec_set(py, sequence, length, validator, extra, slots, recursion_guard), + Self::FrozenSet(sequence) => { + validate_to_vec_frozenset(py, sequence, length, validator, extra, slots, recursion_guard) + } + Self::JsonArray(sequence) => { + validate_to_vec_jsonarray(py, sequence, length, validator, extra, slots, recursion_guard) + } } } } diff --git a/src/lib.rs b/src/lib.rs index 3e36b6088d..2ac4b9c745 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; mod build_tools; mod errors; mod input; +mod recursion_guard; mod validators; // required for benchmarks diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs new file mode 100644 index 0000000000..3a6b9dd9a3 --- /dev/null +++ b/src/recursion_guard.rs @@ -0,0 +1,36 @@ +use std::collections::HashSet; +use std::hash::BuildHasherDefault; + +use nohash_hasher::NoHashHasher; + +/// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault. +/// It's used in `validators/recursive.rs` to detect when a reference is reused within itself. +#[derive(Debug, Clone, Default)] +pub struct RecursionGuard(Option>>>); + +impl RecursionGuard { + // insert a new id into the set, return whether the set already had the id in it + pub fn contains_or_insert(&mut self, id: usize) -> bool { + match self.0 { + // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert + // "If the set did not have this value present, `true` is returned." + Some(ref mut set) => !set.insert(id), + None => { + let mut set: HashSet>> = + HashSet::with_capacity_and_hasher(10, BuildHasherDefault::default()); + set.insert(id); + self.0 = Some(set); + false + } + } + } + + pub fn remove(&mut self, id: &usize) { + match self.0 { + Some(ref mut set) => { + set.remove(id); + } + None => unreachable!(), + }; + } +} diff --git a/src/validators/any.rs b/src/validators/any.rs index bedb6bd7bc..38df421977 100644 --- a/src/validators/any.rs +++ b/src/validators/any.rs @@ -3,6 +3,7 @@ use pyo3::types::PyDict; use crate::errors::ValResult; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -29,6 +30,7 @@ impl Validator for AnyValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { // Ok(input.clone().into_py(py)) Ok(input.to_object(py)) diff --git a/src/validators/bool.rs b/src/validators/bool.rs index 9b888aad15..b425b3cb2f 100644 --- a/src/validators/bool.rs +++ b/src/validators/bool.rs @@ -4,6 +4,7 @@ use pyo3::types::PyDict; use crate::build_tools::is_strict; use crate::errors::ValResult; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -33,6 +34,7 @@ impl Validator for BoolValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { // TODO in theory this could be quicker if we used PyBool rather than going to a bool // and back again, might be worth profiling? @@ -45,6 +47,7 @@ impl Validator for BoolValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_bool()?.into_py(py)) } @@ -70,6 +73,7 @@ impl Validator for StrictBoolValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_bool()?.into_py(py)) } diff --git a/src/validators/bytes.rs b/src/validators/bytes.rs index c9b3510e75..8c4f45337d 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -4,6 +4,7 @@ use pyo3::types::PyDict; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValResult}; use crate::input::{EitherBytes, Input}; +use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -36,6 +37,7 @@ impl Validator for BytesValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_bytes = input.lax_bytes()?; Ok(either_bytes.into_py(py)) @@ -47,6 +49,7 @@ impl Validator for BytesValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_bytes = input.strict_bytes()?; Ok(either_bytes.into_py(py)) @@ -73,6 +76,7 @@ impl Validator for StrictBytesValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_bytes = input.strict_bytes()?; Ok(either_bytes.into_py(py)) @@ -97,6 +101,7 @@ impl Validator for BytesConstrainedValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let bytes = match self.strict { true => input.strict_bytes()?, @@ -111,6 +116,7 @@ impl Validator for BytesConstrainedValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self._validation_logic(py, input, input.strict_bytes()?) } diff --git a/src/validators/date.rs b/src/validators/date.rs index 39eb312e8d..927532e1bb 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -5,6 +5,7 @@ use speedate::{Date, Time}; use crate::build_tools::{is_strict, SchemaDict, SchemaError}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValError, ValResult}; use crate::input::{EitherDate, Input}; +use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -58,6 +59,7 @@ impl Validator for DateValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let date = match self.strict { true => input.strict_date()?, @@ -80,6 +82,7 @@ impl Validator for DateValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self.validation_comparison(py, input, input.strict_date()?) } diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index 6f3b5465b5..e628071802 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -5,6 +5,7 @@ use speedate::DateTime; use crate::build_tools::{is_strict, SchemaDict, SchemaError}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValResult}; use crate::input::{EitherDateTime, Input}; +use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -58,6 +59,7 @@ impl Validator for DateTimeValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let date = match self.strict { true => input.strict_datetime()?, @@ -72,6 +74,7 @@ impl Validator for DateTimeValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self.validation_comparison(py, input, input.strict_datetime()?) } diff --git a/src/validators/dict.rs b/src/validators/dict.rs index 473035965b..edabb2a510 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -4,6 +4,7 @@ use pyo3::types::PyDict; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValError, ValLineError, ValResult}; use crate::input::{GenericMapping, Input, JsonObject}; +use crate::recursion_guard::RecursionGuard; use super::any::AnyValidator; use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -49,12 +50,13 @@ impl Validator for DictValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let dict = match self.strict { true => input.strict_dict()?, false => input.lax_dict()?, }; - self._validation_logic(py, input, dict, extra, slots) + self._validation_logic(py, input, dict, extra, slots, recursion_guard) } fn validate_strict<'s, 'data>( @@ -63,8 +65,9 @@ impl Validator for DictValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self._validation_logic(py, input, input.strict_dict()?, extra, slots) + self._validation_logic(py, input, input.strict_dict()?, extra, slots, recursion_guard) } fn get_name(&self, py: Python) -> String { @@ -86,6 +89,7 @@ macro_rules! build_validate { dict: &'data $dict_type, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if let Some(min_length) = self.min_items { if dict.len() < min_length { @@ -112,7 +116,7 @@ macro_rules! build_validate { let value_validator = self.value_validator.as_ref(); for (key, value) in dict.iter() { - let output_key = match key_validator.validate(py, key, extra, slots) { + let output_key = match key_validator.validate(py, key, extra, slots, recursion_guard) { Ok(value) => Some(value), Err(ValError::LineErrors(line_errors)) => { for err in line_errors { @@ -125,7 +129,7 @@ macro_rules! build_validate { } Err(err) => return Err(err), }; - let output_value = match value_validator.validate(py, value, extra, slots) { + let output_value = match value_validator.validate(py, value, extra, slots, recursion_guard) { Ok(value) => Some(value), Err(ValError::LineErrors(line_errors)) => { for err in line_errors { @@ -160,11 +164,14 @@ impl DictValidator { dict: GenericMapping<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { match dict { - GenericMapping::PyDict(py_dict) => self.validate_dict(py, input, py_dict, extra, slots), + GenericMapping::PyDict(py_dict) => self.validate_dict(py, input, py_dict, extra, slots, recursion_guard), GenericMapping::PyGetAttr(_) => unreachable!(), - GenericMapping::JsonObject(json_object) => self.validate_json_object(py, input, json_object, extra, slots), + GenericMapping::JsonObject(json_object) => { + self.validate_json_object(py, input, json_object, extra, slots, recursion_guard) + } } } } diff --git a/src/validators/float.rs b/src/validators/float.rs index 13faeb15ab..37993c422e 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -4,6 +4,7 @@ use pyo3::types::PyDict; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{context, err_val_error, ErrorKind, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -40,6 +41,7 @@ impl Validator for FloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.lax_float()?.into_py(py)) } @@ -50,6 +52,7 @@ impl Validator for FloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_float()?.into_py(py)) } @@ -75,6 +78,7 @@ impl Validator for StrictFloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_float()?.into_py(py)) } @@ -101,6 +105,7 @@ impl Validator for ConstrainedFloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let float = match self.strict { true => input.strict_float()?, @@ -115,6 +120,7 @@ impl Validator for ConstrainedFloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self._validation_logic(py, input, input.strict_float()?) } diff --git a/src/validators/frozenset.rs b/src/validators/frozenset.rs index c977c03d33..b44bf8e108 100644 --- a/src/validators/frozenset.rs +++ b/src/validators/frozenset.rs @@ -4,6 +4,7 @@ use pyo3::types::{PyDict, PyFrozenSet}; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind}; use crate::input::{GenericSequence, Input}; +use crate::recursion_guard::RecursionGuard; use super::any::AnyValidator; use super::list::sequence_build_function; @@ -29,12 +30,13 @@ impl Validator for FrozenSetValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let frozenset = match self.strict { true => input.strict_frozenset()?, false => input.lax_frozenset()?, }; - self._validation_logic(py, input, frozenset, extra, slots) + self._validation_logic(py, input, frozenset, extra, slots, recursion_guard) } fn validate_strict<'s, 'data>( @@ -43,8 +45,9 @@ impl Validator for FrozenSetValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self._validation_logic(py, input, input.strict_frozenset()?, extra, slots) + self._validation_logic(py, input, input.strict_frozenset()?, extra, slots, recursion_guard) } fn get_name(&self, py: Python) -> String { @@ -60,6 +63,7 @@ impl FrozenSetValidator { list: GenericSequence<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let length = list.generic_len(); if let Some(min_length) = self.min_items { @@ -81,7 +85,7 @@ impl FrozenSetValidator { } } - let output = list.validate_to_vec(py, length, &self.item_validator, extra, slots)?; + let output = list.validate_to_vec(py, length, &self.item_validator, extra, slots, recursion_guard)?; Ok(PyFrozenSet::new(py, &output).map_err(as_internal)?.into_py(py)) } } diff --git a/src/validators/function.rs b/src/validators/function.rs index 094a27cdb1..8f743d927e 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -5,6 +5,7 @@ use pyo3::types::{PyAny, PyDict}; use crate::build_tools::{py_error, SchemaDict}; use crate::errors::{context, val_line_error, ErrorKind, ValError, ValResult, ValidationError}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -71,6 +72,7 @@ impl Validator for FunctionBeforeValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let kwargs = kwargs!(py, "data" => extra.data, "config" => self.config.as_ref()); let value = self @@ -79,7 +81,7 @@ impl Validator for FunctionBeforeValidator { .map_err(|e| convert_err(py, e, input))?; // maybe there's some way to get the PyAny here and explicitly tell rust it should have lifespan 'a? let new_input: &PyAny = value.as_ref(py); - match self.validator.validate(py, new_input, extra, slots) { + match self.validator.validate(py, new_input, extra, slots, recursion_guard) { Ok(v) => Ok(v), Err(ValError::InternalErr(err)) => Err(ValError::InternalErr(err)), Err(ValError::LineErrors(line_errors)) => { @@ -116,8 +118,9 @@ impl Validator for FunctionAfterValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let v = self.validator.validate(py, input, extra, slots)?; + let v = self.validator.validate(py, input, extra, slots, recursion_guard)?; let kwargs = kwargs!(py, "data" => extra.data, "config" => self.config.as_ref()); self.func.call(py, (v,), kwargs).map_err(|e| convert_err(py, e, input)) } @@ -154,6 +157,7 @@ impl Validator for FunctionPlainValidator { input: &'data impl Input<'data>, extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let kwargs = kwargs!(py, "data" => extra.data, "config" => self.config.as_ref()); self.func @@ -182,12 +186,14 @@ impl Validator for FunctionWrapValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let validator_kwarg = ValidatorCallable { validator: self.validator.clone(), slots: slots.to_vec(), data: extra.data.map(|d| d.into_py(py)), field: extra.field.map(|f| f.to_string()), + recursion_guard: recursion_guard.clone(), }; let kwargs = kwargs!( py, @@ -212,17 +218,18 @@ struct ValidatorCallable { slots: Vec, data: Option>, field: Option, + recursion_guard: RecursionGuard, } #[pymethods] impl ValidatorCallable { - fn __call__(&self, py: Python, arg: &PyAny) -> PyResult { + fn __call__(&mut self, py: Python, arg: &PyAny) -> PyResult { let extra = Extra { data: self.data.as_ref().map(|data| data.as_ref(py)), field: self.field.as_deref(), }; self.validator - .validate(py, arg, &extra, &self.slots) + .validate(py, arg, &extra, &self.slots, &mut self.recursion_guard) .map_err(|e| ValidationError::from_val_error(py, "Model".to_object(py), e)) } @@ -252,6 +259,9 @@ fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) -> Val // Only ValueError and AssertionError are considered as validation errors, // TypeError is now considered as a runtime error to catch errors in function signatures let kind = if err.is_instance_of::(py) { + if let Ok(validation_error) = err.value(py).extract::() { + return validation_error.into(); + } ErrorKind::ValueError } else if err.is_instance_of::(py) { ErrorKind::AssertionError diff --git a/src/validators/int.rs b/src/validators/int.rs index 21b0718f49..758d4f1f34 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -4,6 +4,7 @@ use pyo3::types::PyDict; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{context, err_val_error, ErrorKind, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -40,6 +41,7 @@ impl Validator for IntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.lax_int()?.into_py(py)) } @@ -50,6 +52,7 @@ impl Validator for IntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_int()?.into_py(py)) } @@ -75,6 +78,7 @@ impl Validator for StrictIntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_int()?.into_py(py)) } @@ -101,6 +105,7 @@ impl Validator for ConstrainedIntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let int = match self.strict { true => input.strict_int()?, @@ -115,6 +120,7 @@ impl Validator for ConstrainedIntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self._validation_logic(py, input, input.strict_int()?) } diff --git a/src/validators/list.rs b/src/validators/list.rs index c7bb225135..61f261240b 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -4,6 +4,7 @@ use pyo3::types::PyDict; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{context, err_val_error, ErrorKind}; use crate::input::{GenericSequence, Input}; +use crate::recursion_guard::RecursionGuard; use super::any::AnyValidator; use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; @@ -50,12 +51,13 @@ impl Validator for ListValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let list = match self.strict { true => input.strict_list()?, false => input.lax_list()?, }; - self._validation_logic(py, input, list, extra, slots) + self._validation_logic(py, input, list, extra, slots, recursion_guard) } fn validate_strict<'s, 'data>( @@ -64,8 +66,9 @@ impl Validator for ListValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self._validation_logic(py, input, input.strict_list()?, extra, slots) + self._validation_logic(py, input, input.strict_list()?, extra, slots, recursion_guard) } fn get_name(&self, py: Python) -> String { @@ -81,6 +84,7 @@ impl ListValidator { list: GenericSequence<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let length = list.generic_len(); if let Some(min_length) = self.min_items { @@ -102,7 +106,7 @@ impl ListValidator { } } - let output = list.validate_to_vec(py, length, &self.item_validator, extra, slots)?; + let output = list.validate_to_vec(py, length, &self.item_validator, extra, slots, recursion_guard)?; Ok(output.into_py(py)) } } diff --git a/src/validators/literal.rs b/src/validators/literal.rs index aee203d055..534ddfc54f 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -6,6 +6,7 @@ use ahash::AHashSet; use crate::build_tools::{py_error, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -63,6 +64,7 @@ impl Validator for LiteralSingleStringValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_str = input.strict_str()?; if either_str.as_cow().as_ref() == self.expected.as_str() { @@ -99,6 +101,7 @@ impl Validator for LiteralSingleIntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let str = input.strict_int()?; if str == self.expected { @@ -150,6 +153,7 @@ impl Validator for LiteralMultipleStringsValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_str = input.strict_str()?; if self.expected.contains(either_str.as_cow().as_ref()) { @@ -201,6 +205,7 @@ impl Validator for LiteralMultipleIntsValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let int = input.strict_int()?; if self.expected.contains(&int) { @@ -260,6 +265,7 @@ impl Validator for LiteralGeneralValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if !self.expected_int.is_empty() { if let Ok(int) = input.strict_int() { diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 33ac898cbd..7799d0e9df 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -10,6 +10,7 @@ use serde_json::from_str as parse_json; use crate::build_tools::{py_error, SchemaDict, SchemaError}; use crate::errors::{context, val_line_error, ErrorKind, ValError, ValResult, ValidationError}; use crate::input::{Input, JsonInput}; +use crate::recursion_guard::RecursionGuard; mod any; mod bool; @@ -74,12 +75,24 @@ impl SchemaValidator { } pub fn validate_python(&self, py: Python, input: &PyAny) -> PyResult { - let r = self.validator.validate(py, input, &Extra::default(), &self.slots); + let r = self.validator.validate( + py, + input, + &Extra::default(), + &self.slots, + &mut RecursionGuard::default(), + ); r.map_err(|e| self.prepare_validation_err(py, e)) } pub fn isinstance_python(&self, py: Python, input: &PyAny) -> PyResult { - match self.validator.validate(py, input, &Extra::default(), &self.slots) { + match self.validator.validate( + py, + input, + &Extra::default(), + &self.slots, + &mut RecursionGuard::default(), + ) { Ok(_) => Ok(true), Err(ValError::InternalErr(err)) => Err(err), _ => Ok(false), @@ -89,7 +102,13 @@ impl SchemaValidator { pub fn validate_json(&self, py: Python, input: String) -> PyResult { match parse_json::(&input) { Ok(input) => { - let r = self.validator.validate(py, &input, &Extra::default(), &self.slots); + let r = self.validator.validate( + py, + &input, + &Extra::default(), + &self.slots, + &mut RecursionGuard::default(), + ); r.map_err(|e| self.prepare_validation_err(py, e)) } Err(e) => { @@ -106,11 +125,19 @@ impl SchemaValidator { pub fn isinstance_json(&self, py: Python, input: String) -> PyResult { match parse_json::(&input) { - Ok(input) => match self.validator.validate(py, &input, &Extra::default(), &self.slots) { - Ok(_) => Ok(true), - Err(ValError::InternalErr(err)) => Err(err), - _ => Ok(false), - }, + Ok(input) => { + match self.validator.validate( + py, + &input, + &Extra::default(), + &self.slots, + &mut RecursionGuard::default(), + ) { + Ok(_) => Ok(true), + Err(ValError::InternalErr(err)) => Err(err), + _ => Ok(false), + } + } Err(_) => Ok(false), } } @@ -120,7 +147,9 @@ impl SchemaValidator { data: Some(data), field: Some(field.as_str()), }; - let r = self.validator.validate(py, input, &extra, &self.slots); + let r = self + .validator + .validate(py, input, &extra, &self.slots, &mut RecursionGuard::default()); r.map_err(|e| self.prepare_validation_err(py, e)) } @@ -344,6 +373,7 @@ pub trait Validator: Send + Sync + Clone + Debug { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject>; /// This is used in unions for the first pass to see if we have an "exact match", @@ -354,8 +384,9 @@ pub trait Validator: Send + Sync + Clone + Debug { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self.validate(py, input, extra, slots) + self.validate(py, input, extra, slots, recursion_guard) } /// `get_name` generally returns `Self::EXPECTED_TYPE` or some other clear identifier of the validator diff --git a/src/validators/model_class.rs b/src/validators/model_class.rs index f167b5edf1..68eb4bdac5 100644 --- a/src/validators/model_class.rs +++ b/src/validators/model_class.rs @@ -10,6 +10,7 @@ use pyo3::{ffi, intern}; use crate::build_tools::{py_error, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValError, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -59,6 +60,7 @@ impl Validator for ModelClassValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let class = self.class.as_ref(py); if input.strict_model_check(class)? { @@ -70,7 +72,7 @@ impl Validator for ModelClassValidator { context = context!("class_name" => self.get_name(py)) ) } else { - let output = self.validator.validate(py, input, extra, slots)?; + let output = self.validator.validate(py, input, extra, slots, recursion_guard)?; self.create_class(py, output).map_err(as_internal) } } @@ -81,6 +83,7 @@ impl Validator for ModelClassValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if input.strict_model_check(self.class.as_ref(py))? { Ok(input.to_object(py)) diff --git a/src/validators/none.rs b/src/validators/none.rs index 94b3116779..6e7aef25bc 100644 --- a/src/validators/none.rs +++ b/src/validators/none.rs @@ -3,6 +3,7 @@ use pyo3::types::PyDict; use crate::errors::{err_val_error, ErrorKind, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -28,6 +29,7 @@ impl Validator for NoneValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { match input.is_none() { true => Ok(py.None()), diff --git a/src/validators/nullable.rs b/src/validators/nullable.rs index 47cef2a292..dd2f059833 100644 --- a/src/validators/nullable.rs +++ b/src/validators/nullable.rs @@ -3,6 +3,7 @@ use pyo3::types::PyDict; use crate::build_tools::SchemaDict; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; @@ -34,10 +35,11 @@ impl Validator for NullableValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { match input.is_none() { true => Ok(py.None()), - false => self.validator.validate(py, input, extra, slots), + false => self.validator.validate(py, input, extra, slots, recursion_guard), } } @@ -47,10 +49,11 @@ impl Validator for NullableValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { match input.is_none() { true => Ok(py.None()), - false => self.validator.validate_strict(py, input, extra, slots), + false => self.validator.validate_strict(py, input, extra, slots, recursion_guard), } } diff --git a/src/validators/recursive.rs b/src/validators/recursive.rs index 5de7c26075..ba4f3a0eb8 100644 --- a/src/validators/recursive.rs +++ b/src/validators/recursive.rs @@ -2,8 +2,9 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use crate::build_tools::SchemaDict; -use crate::errors::ValResult; +use crate::errors::{err_val_error, ErrorKind, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -25,9 +26,9 @@ impl Validator for RecursiveContainerValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let validator = unsafe { slots.get_unchecked(self.validator_id) }; - validator.validate(py, input, extra, slots) + validate(self.validator_id, py, input, extra, slots, recursion_guard) } fn get_name(&self, _py: Python) -> String { @@ -61,12 +62,36 @@ impl Validator for RecursiveRefValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let validator = unsafe { slots.get_unchecked(self.validator_id) }; - validator.validate(py, input, extra, slots) + validate(self.validator_id, py, input, extra, slots, recursion_guard) } fn get_name(&self, _py: Python) -> String { Self::EXPECTED_TYPE.to_string() } } + +fn validate<'s, 'data>( + validator_id: usize, + py: Python<'data>, + input: &'data impl Input<'data>, + extra: &Extra, + slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, +) -> ValResult<'data, PyObject> { + if let Some(id) = input.identity() { + if recursion_guard.contains_or_insert(id) { + // remove ID in case we use recursion_guard again + recursion_guard.remove(&id); + return err_val_error!(kind = ErrorKind::RecursionLoop, input_value = input.as_error_value()); + } + let validator = unsafe { slots.get_unchecked(validator_id) }; + let output = validator.validate(py, input, extra, slots, recursion_guard); + recursion_guard.remove(&id); + output + } else { + let validator = unsafe { slots.get_unchecked(validator_id) }; + validator.validate(py, input, extra, slots, recursion_guard) + } +} diff --git a/src/validators/set.rs b/src/validators/set.rs index fd40b71d69..60203d027c 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -4,6 +4,7 @@ use pyo3::types::{PyDict, PySet}; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind}; use crate::input::{GenericSequence, Input}; +use crate::recursion_guard::RecursionGuard; use super::any::AnyValidator; use super::list::sequence_build_function; @@ -29,12 +30,13 @@ impl Validator for SetValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let set = match self.strict { true => input.strict_set()?, false => input.lax_set()?, }; - self._validation_logic(py, input, set, extra, slots) + self._validation_logic(py, input, set, extra, slots, recursion_guard) } fn validate_strict<'s, 'data>( @@ -43,8 +45,9 @@ impl Validator for SetValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self._validation_logic(py, input, input.strict_set()?, extra, slots) + self._validation_logic(py, input, input.strict_set()?, extra, slots, recursion_guard) } fn get_name(&self, py: Python) -> String { @@ -60,6 +63,7 @@ impl SetValidator { list: GenericSequence<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let length = list.generic_len(); if let Some(min_length) = self.min_items { @@ -81,7 +85,7 @@ impl SetValidator { } } - let output = list.validate_to_vec(py, length, &self.item_validator, extra, slots)?; + let output = list.validate_to_vec(py, length, &self.item_validator, extra, slots, recursion_guard)?; Ok(PySet::new(py, &output).map_err(as_internal)?.into_py(py)) } } diff --git a/src/validators/string.rs b/src/validators/string.rs index 678f015b7a..330be9d53c 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -5,6 +5,7 @@ use regex::Regex; use crate::build_tools::{is_strict, py_error, schema_or_config}; use crate::errors::{context, err_val_error, ErrorKind, ValResult}; use crate::input::{EitherString, Input}; +use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -53,6 +54,7 @@ impl Validator for StrValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.lax_str()?.into_py(py)) } @@ -63,6 +65,7 @@ impl Validator for StrValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_str()?.into_py(py)) } @@ -88,6 +91,7 @@ impl Validator for StrictStrValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_str()?.into_py(py)) } @@ -115,6 +119,7 @@ impl Validator for StrConstrainedValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let str = match self.strict { true => input.strict_str()?, @@ -129,6 +134,7 @@ impl Validator for StrConstrainedValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self._validation_logic(py, input, input.strict_str()?) } diff --git a/src/validators/time.rs b/src/validators/time.rs index 871d432c96..350927a5cc 100644 --- a/src/validators/time.rs +++ b/src/validators/time.rs @@ -5,6 +5,7 @@ use speedate::Time; use crate::build_tools::{is_strict, SchemaDict, SchemaError}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValResult}; use crate::input::{EitherTime, Input}; +use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -58,6 +59,7 @@ impl Validator for TimeValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let time = match self.strict { true => input.strict_time()?, @@ -72,6 +74,7 @@ impl Validator for TimeValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self.validation_comparison(py, input, input.strict_time()?) } diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index f5f7e329c2..62f563dd07 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -4,6 +4,7 @@ use pyo3::types::{PyDict, PyList, PyTuple}; use crate::build_tools::{is_strict, py_error, SchemaDict}; use crate::errors::{context, err_val_error, ErrorKind, ValError, ValLineError}; use crate::input::{GenericSequence, Input}; +use crate::recursion_guard::RecursionGuard; use super::any::AnyValidator; use super::list::sequence_build_function; @@ -29,12 +30,13 @@ impl Validator for TupleVarLenValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let tuple = match self.strict { true => input.strict_tuple()?, false => input.lax_tuple()?, }; - self._validation_logic(py, input, tuple, extra, slots) + self._validation_logic(py, input, tuple, extra, slots, recursion_guard) } fn validate_strict<'s, 'data>( @@ -43,8 +45,9 @@ impl Validator for TupleVarLenValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self._validation_logic(py, input, input.strict_tuple()?, extra, slots) + self._validation_logic(py, input, input.strict_tuple()?, extra, slots, recursion_guard) } fn get_name(&self, py: Python) -> String { @@ -60,6 +63,7 @@ impl TupleVarLenValidator { tuple: GenericSequence<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let length = tuple.generic_len(); if let Some(min_length) = self.min_items { @@ -81,7 +85,7 @@ impl TupleVarLenValidator { } } - let output = tuple.validate_to_vec(py, length, &self.item_validator, extra, slots)?; + let output = tuple.validate_to_vec(py, length, &self.item_validator, extra, slots, recursion_guard)?; Ok(PyTuple::new(py, &output).into_py(py)) } } @@ -124,12 +128,13 @@ impl Validator for TupleFixLenValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let tuple = match self.strict { true => input.strict_tuple()?, false => input.lax_tuple()?, }; - self._validation_logic(py, input, tuple, extra, slots) + self._validation_logic(py, input, tuple, extra, slots, recursion_guard) } fn validate_strict<'s, 'data>( @@ -138,8 +143,9 @@ impl Validator for TupleFixLenValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self._validation_logic(py, input, input.strict_tuple()?, extra, slots) + self._validation_logic(py, input, input.strict_tuple()?, extra, slots, recursion_guard) } fn get_name(&self, py: Python) -> String { @@ -161,6 +167,7 @@ impl TupleFixLenValidator { tuple: GenericSequence<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let expected_length = self.items_validators.len(); @@ -181,7 +188,7 @@ impl TupleFixLenValidator { macro_rules! iter { ($sequence:expr) => { for (validator, (index, item)) in self.items_validators.iter().zip($sequence.iter().enumerate()) { - match validator.validate(py, item, extra, slots) { + match validator.validate(py, item, extra, slots, recursion_guard) { Ok(item) => output.push(item), Err(ValError::LineErrors(line_errors)) => { errors.extend( diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 6783bba1a9..6183e7db6d 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -10,6 +10,7 @@ use crate::errors::{ as_internal, context, err_val_error, py_err_string, val_line_error, ErrorKind, ValError, ValLineError, ValResult, }; use crate::input::{GenericMapping, Input, JsonInput, JsonObject}; +use crate::recursion_guard::RecursionGuard; use crate::SchemaError; use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -127,10 +128,11 @@ impl Validator for TypedDictValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if let Some(field) = extra.field { // we're validating assignment, completely different logic - return self.validate_assignment(py, field, input, extra, slots); + return self.validate_assignment(py, field, input, extra, slots, recursion_guard); } let dict = input.typed_dict(self.from_attributes, !self.strict)?; @@ -176,7 +178,7 @@ impl Validator for TypedDictValidator { // extra logic either way used_keys.insert(used_key); } - match field.validator.validate(py, value, &extra, slots) { + match field.validator.validate(py, value, &extra, slots, recursion_guard) { Ok(value) => { output_dict .set_item(&field.name_pystring, value) @@ -243,7 +245,7 @@ impl Validator for TypedDictValidator { } if let Some(ref validator) = self.extra_validator { - match validator.validate(py, value, &extra, slots) { + match validator.validate(py, value, &extra, slots, recursion_guard) { Ok(value) => { output_dict.set_item(py_key, value).map_err(as_internal)?; if let Some(ref mut fs) = fields_set_vec { @@ -298,6 +300,7 @@ impl TypedDictValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> where 'data: 's, @@ -331,11 +334,11 @@ impl TypedDictValidator { }; if let Some(field) = self.fields.iter().find(|f| f.name == field) { - prepare_result(field.validator.validate(py, input, extra, slots)) + prepare_result(field.validator.validate(py, input, extra, slots, recursion_guard)) } else if self.check_extra && !self.forbid_extra { // this is the "allow" case of extra_behavior match self.extra_validator { - Some(ref validator) => prepare_result(validator.validate(py, input, extra, slots)), + Some(ref validator) => prepare_result(validator.validate(py, input, extra, slots, recursion_guard)), None => prepare_tuple(input.to_object(py)), } } else { diff --git a/src/validators/union.rs b/src/validators/union.rs index 2d4f05ebd1..581a571d93 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -4,6 +4,7 @@ use pyo3::types::{PyDict, PyList}; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{ValError, ValLineError}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; @@ -38,12 +39,13 @@ impl Validator for UnionValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if self.strict { let mut errors: Vec = Vec::with_capacity(self.choices.len()); for validator in &self.choices { - let line_errors = match validator.validate_strict(py, input, extra, slots) { + let line_errors = match validator.validate_strict(py, input, extra, slots, recursion_guard) { Err(ValError::LineErrors(line_errors)) => line_errors, otherwise => return otherwise, }; @@ -61,7 +63,7 @@ impl Validator for UnionValidator { if let Some(res) = self .choices .iter() - .map(|validator| validator.validate_strict(py, input, extra, slots)) + .map(|validator| validator.validate_strict(py, input, extra, slots, recursion_guard)) .find(ValResult::is_ok) { return res; @@ -71,7 +73,7 @@ impl Validator for UnionValidator { // 2nd pass: check if the value can be coerced into one of the Union types for validator in &self.choices { - let line_errors = match validator.validate(py, input, extra, slots) { + let line_errors = match validator.validate(py, input, extra, slots, recursion_guard) { Err(ValError::LineErrors(line_errors)) => line_errors, otherwise => return otherwise, }; diff --git a/tests/validators/test_function.py b/tests/validators/test_function.py index 57cafcddf0..c4f4739546 100644 --- a/tests/validators/test_function.py +++ b/tests/validators/test_function.py @@ -103,9 +103,7 @@ def f(input_value, *, validator, **kwargs): v = SchemaValidator({'title': 'Test', 'type': 'function', 'mode': 'wrap', 'function': f, 'schema': 'str'}) - # with pytest.raises(ValidationError) as exc_info: assert v.validate_python('input value') == 'input value Changed' - # print(exc_info.value) def test_function_wrap_repr(): @@ -134,6 +132,25 @@ def test_function_wrap_not_callable(): SchemaValidator({'title': 'Test', 'type': 'function', 'mode': 'wrap', 'schema': 'str'}) +def test_wrap_error(): + def f(input_value, *, validator, **kwargs): + return validator(input_value) * 2 + + v = SchemaValidator({'title': 'Test', 'type': 'function', 'mode': 'wrap', 'function': f, 'schema': 'int'}) + + assert v.validate_python('42') == 84 + with pytest.raises(ValidationError) as exc_info: + v.validate_python('wrong') + assert exc_info.value.errors() == [ + { + 'kind': 'int_parsing', + 'loc': [], + 'message': 'Value must be a valid integer, unable to parse string as an integer', + 'input_value': 'wrong', + } + ] + + def test_wrong_mode(): with pytest.raises(SchemaError, match='SchemaError: Unexpected function mode "foobar"'): SchemaValidator({'title': 'Test', 'type': 'function', 'mode': 'foobar', 'schema': 'str'}) diff --git a/tests/validators/test_recursive.py b/tests/validators/test_recursive.py index a4e39f8103..38541b01b1 100644 --- a/tests/validators/test_recursive.py +++ b/tests/validators/test_recursive.py @@ -1,9 +1,13 @@ from typing import Optional import pytest +from dirty_equals import AnyThing, HasAttributes, IsList, IsPartialDict from pydantic_core import SchemaError, SchemaValidator, ValidationError +from ..conftest import Err +from .test_typed_dict import Cls + def test_branch_nullable(): v = SchemaValidator( @@ -231,3 +235,206 @@ def test_outside_parent(): 'tuple1': (1, 1, 'frog'), 'tuple2': (2, 2, 'toad'), } + + +def test_recursion_branch(): + v = SchemaValidator( + { + 'type': 'typed-dict', + 'ref': 'Branch', + 'config': {'from_attributes': True}, + 'fields': { + 'name': {'schema': {'type': 'str'}}, + 'branch': { + 'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 'Branch'}}, + 'default': None, + }, + }, + } + ) + assert v.validate_python({'name': 'root'}) == {'name': 'root', 'branch': None} + assert v.validate_python({'name': 'root', 'branch': {'name': 'b1', 'branch': None}}) == { + 'name': 'root', + 'branch': {'name': 'b1', 'branch': None}, + } + + data = Cls(name='root') + data.branch = Cls(name='b1', branch=None) + assert v.validate_python(data) == {'name': 'root', 'branch': {'name': 'b1', 'branch': None}} + + b = {'name': 'recursive'} + b['branch'] = b + with pytest.raises(ValidationError) as exc_info: + assert v.validate_python(b) + assert exc_info.value.title == 'recursive-container' + assert exc_info.value.errors() == [ + { + 'kind': 'recursion_loop', + 'loc': ['branch'], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': {'name': 'recursive', 'branch': IsPartialDict(name='recursive')}, + } + ] + + data = Cls(name='root') + data.branch = data + with pytest.raises(ValidationError) as exc_info: + v.validate_python(data) + assert exc_info.value.errors() == [ + { + 'kind': 'recursion_loop', + 'loc': ['branch'], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': HasAttributes(name='root', branch=AnyThing()), + } + ] + + +def test_recursive_list(): + v = SchemaValidator( + {'type': 'list', 'ref': 'the-list', 'items_schema': {'type': 'recursive-ref', 'schema_ref': 'the-list'}} + ) + assert v.validate_python([]) == [] + assert v.validate_python([[]]) == [[]] + + data = list() + data.append(data) + with pytest.raises(ValidationError, match='Recursion error - cyclic reference detected'): + assert v.validate_python(data) + + +@pytest.fixture(scope='module') +def multiple_tuple_schema(): + return SchemaValidator( + { + 'type': 'typed-dict', + 'fields': { + 'f1': { + 'schema': { + 'type': 'tuple-fix-len', + 'ref': 't', + 'items_schema': [ + {'type': 'int'}, + {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 't'}}, + ], + } + }, + 'f2': { + 'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 't'}}, + 'default': None, + }, + }, + } + ) + + +@pytest.mark.parametrize( + 'input_value,expected', + [ + ({'f1': [1, None]}, {'f1': (1, None), 'f2': None}), + ({'f1': [1, None], 'f2': [2, None]}, {'f1': (1, None), 'f2': (2, None)}), + ( + {'f1': [1, (3, None)], 'f2': [2, (4, (4, (5, None)))]}, + {'f1': (1, (3, None)), 'f2': (2, (4, (4, (5, None))))}, + ), + ({'f1': [1, 2]}, Err(r'f1 -> 1\s+Value must be a valid tuple')), + ( + {'f1': [1, (3, None)], 'f2': [2, (4, (4, (5, 6)))]}, + Err(r'f2 -> 1 -> 1 -> 1 -> 1\s+Value must be a valid tuple'), + ), + ], +) +def test_multiple_tuple_param(multiple_tuple_schema, input_value, expected): + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=expected.message): + multiple_tuple_schema.validate_python(input_value) + # debug(repr(exc_info.value)) + else: + assert multiple_tuple_schema.validate_python(input_value) == expected + + +def test_multiple_tuple_repeat(multiple_tuple_schema): + t = (42, None) + assert multiple_tuple_schema.validate_python({'f1': (1, t), 'f2': (2, t)}) == { + 'f1': (1, (42, None)), + 'f2': (2, (42, None)), + } + + +def test_multiple_tuple_recursion(multiple_tuple_schema): + data = [1] + data.append(data) + with pytest.raises(ValidationError) as exc_info: + multiple_tuple_schema.validate_python({'f1': data, 'f2': data}) + + assert exc_info.value.errors() == [ + { + 'kind': 'recursion_loop', + 'loc': ['f1', 1], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': [1, IsList(length=2)], + }, + { + 'kind': 'recursion_loop', + 'loc': ['f2', 1], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': [1, IsList(length=2)], + }, + ] + + +def test_multiple_tuple_recursion_once(multiple_tuple_schema): + data = [1] + data.append(data) + with pytest.raises(ValidationError) as exc_info: + multiple_tuple_schema.validate_python({'f1': data, 'f2': data}) + + assert exc_info.value.errors() == [ + { + 'kind': 'recursion_loop', + 'loc': ['f1', 1], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': [1, IsList(length=2)], + }, + { + 'kind': 'recursion_loop', + 'loc': ['f2', 1], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': [1, IsList(length=2)], + }, + ] + + +def test_recursive_wrap(): + def wrap_func(input_value, *, validator, **kwargs): + return validator(input_value) + (42,) + + v = SchemaValidator( + { + 'type': 'function', + 'ref': 'wrapper', + 'mode': 'wrap', + 'function': wrap_func, + 'schema': { + 'type': 'tuple-fix-len', + 'items_schema': [ + {'type': 'int'}, + {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 'wrapper'}}, + ], + }, + } + ) + assert v.validate_python((1, None)) == (1, None, 42) + assert v.validate_python((1, (2, (3, None)))) == (1, (2, (3, None, 42), 42), 42) + t = [1] + t.append(t) + with pytest.raises(ValidationError) as exc_info: + v.validate_python(t) + assert exc_info.value.errors() == [ + { + 'kind': 'recursion_loop', + 'loc': [1], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': IsList(positions={0: 1}, length=2), + } + ] diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index a3576835d4..3999d3bf5f 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -678,6 +678,8 @@ class MyDataclass: (Cls(a=1, b=2, c='ham'), ({'a': 1, 'b': 2, 'c': 'ham'}, {'a', 'b', 'c'})), (dict(a=1, b=2, c='ham'), ({'a': 1, 'b': 2, 'c': 'ham'}, {'a', 'b', 'c'})), (Map(a=1, b=2, c='ham'), ({'a': 1, 'b': 2, 'c': 'ham'}, {'a', 'b', 'c'})), + # using type gives `__module__ == 'builtins'` + (type('Testing', (), {}), Err('[kind=dict_attributes_type,')), ('123', Err('Value must be a valid dictionary or instance to extract fields from [kind=dict_attributes_type,')), ([(1, 2)], Err('kind=dict_attributes_type,')), (((1, 2),), Err('kind=dict_attributes_type,')),