Skip to content

Commit

Permalink
Recursion guard (pydantic#134)
Browse files Browse the repository at this point in the history
* adding recursion_guard argument

* fix linting, start on logic

* basic recursion implementation working

* make recursion guard option-al

* more tests

* move RecursionGuard, optimise recursion check

* tests for recursion across a wrap validator

* bump

* tweaks
  • Loading branch information
samuelcolvin committed Jun 29, 2022
1 parent 771c928 commit b95b3d2
Show file tree
Hide file tree
Showing 36 changed files with 565 additions and 73 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Expand Up @@ -19,6 +19,7 @@ indexmap = "1.8.1"
mimalloc = { version = "0.1.29", default-features = false, optional = true }
speedate = "0.4.1"
ahash = "0.7.6"
nohash-hasher = "0.2.0"

[lib]
name = "_pydantic_core"
Expand Down
4 changes: 4 additions & 0 deletions src/errors/kinds.rs
Expand Up @@ -8,6 +8,10 @@ pub enum ErrorKind {
#[strum(message = "Invalid JSON: {parser_error}")]
InvalidJson,
// ---------------------
// recursion error
#[strum(message = "Recursion error - cyclic reference detected")]
RecursionLoop,
// ---------------------
// typed dict specific errors
#[strum(message = "Value must be a valid dictionary or instance to extract fields from")]
DictAttributesType,
Expand Down
25 changes: 22 additions & 3 deletions src/errors/line_error.rs
Expand Up @@ -17,14 +17,27 @@ pub enum ValError<'a> {
InternalErr(PyErr),
}

impl<'a> From<PyErr> for ValError<'a> {
fn from(py_err: PyErr) -> Self {
Self::InternalErr(py_err)
}
}

impl<'a> From<Vec<ValLineError<'a>>> for ValError<'a> {
fn from(line_errors: Vec<ValLineError<'a>>) -> Self {
Self::LineErrors(line_errors)
}
}

// ValError used to implement Error, see #78 for removed code

// TODO, remove and replace with just .into()
pub fn as_internal<'a>(err: PyErr) -> ValError<'a> {
ValError::InternalErr(err)
err.into()
}

pub fn pretty_line_errors(py: Python, line_errors: Vec<ValLineError>) -> String {
let py_line_errors: Vec<PyLineError> = line_errors.into_iter().map(|e| PyLineError::new(py, e)).collect();
let py_line_errors: Vec<PyLineError> = line_errors.into_iter().map(|e| e.into_py(py)).collect();
pretty_py_line_errors(Some(py), py_line_errors.iter())
}

Expand Down Expand Up @@ -58,7 +71,7 @@ impl<'a> ValLineError<'a> {
ValLineError {
kind: self.kind,
reverse_location: self.reverse_location,
input_value: InputValue::PyObject(self.input_value.to_object(py)),
input_value: self.input_value.to_object(py).into(),
context: self.context,
}
}
Expand All @@ -79,6 +92,12 @@ impl Default for InputValue<'_> {
}
}

impl<'a> From<PyObject> for InputValue<'a> {
fn from(py_object: PyObject) -> Self {
Self::PyObject(py_object)
}
}

impl<'a> ToPyObject for InputValue<'a> {
fn to_object(&self, py: Python) -> PyObject {
match self {
Expand Down
49 changes: 39 additions & 10 deletions src/errors/validation_exception.rs
Expand Up @@ -16,7 +16,7 @@ use super::location::Location;
use super::ValError;

#[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")]
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ValidationError {
line_errors: Vec<PyLineError>,
title: PyObject,
Expand All @@ -33,7 +33,7 @@ impl ValidationError {
pub fn from_val_error(py: Python, title: PyObject, error: ValError) -> PyErr {
match error {
ValError::LineErrors(raw_errors) => {
let line_errors: Vec<PyLineError> = raw_errors.into_iter().map(|e| PyLineError::new(py, e)).collect();
let line_errors: Vec<PyLineError> = raw_errors.into_iter().map(|e| e.into_py(py)).collect();
PyErr::new::<ValidationError, _>((line_errors, title))
}
ValError::InternalErr(err) => err,
Expand Down Expand Up @@ -61,6 +61,18 @@ impl Error for ValidationError {
}
}

// used to convert a validation error back to ValError for wrap functions
impl<'a> From<ValidationError> for ValError<'a> {
fn from(val_error: ValidationError) -> Self {
val_error
.line_errors
.into_iter()
.map(|e| e.into())
.collect::<Vec<_>>()
.into()
}
}

#[pymethods]
impl ValidationError {
#[new]
Expand Down Expand Up @@ -131,19 +143,36 @@ pub struct PyLineError {
context: Context,
}

impl PyLineError {
pub fn new(py: Python, raw_error: ValLineError) -> Self {
impl<'a> IntoPy<PyLineError> for ValLineError<'a> {
fn into_py(self, py: Python<'_>) -> PyLineError {
PyLineError {
kind: self.kind,
location: match self.reverse_location.len() {
0..=1 => self.reverse_location,
_ => self.reverse_location.into_iter().rev().collect(),
},
input_value: self.input_value.to_object(py),
context: self.context,
}
}
}

/// opposite of above, used to extract line errors from a validation error for wrap functions
impl<'a> From<PyLineError> for ValLineError<'a> {
fn from(py_line_error: PyLineError) -> Self {
Self {
kind: raw_error.kind,
location: match raw_error.reverse_location.len() {
0..=1 => raw_error.reverse_location,
_ => raw_error.reverse_location.into_iter().rev().collect(),
kind: py_line_error.kind,
reverse_location: match py_line_error.location.len() {
0..=1 => py_line_error.location,
_ => py_line_error.location.into_iter().rev().collect(),
},
input_value: raw_error.input_value.to_object(py),
context: raw_error.context,
input_value: py_line_error.input_value.into(),
context: py_line_error.context,
}
}
}

impl PyLineError {
pub fn as_dict(&self, py: Python) -> PyResult<PyObject> {
let dict = PyDict::new(py);
dict.set_item("kind", self.kind())?;
Expand Down
4 changes: 4 additions & 0 deletions src/input/input_abstract.rs
Expand Up @@ -16,6 +16,10 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {

fn as_error_value(&'a self) -> InputValue<'a>;

fn identity(&'a self) -> Option<usize> {
None
}

fn is_none(&self) -> bool;

fn strict_str<'data>(&'data self) -> ValResult<EitherString<'data>>;
Expand Down
6 changes: 5 additions & 1 deletion src/input/input_python.rs
@@ -1,12 +1,12 @@
use std::str::from_utf8;

use pyo3::exceptions::{PyAttributeError, PyTypeError};
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{
PyBool, PyBytes, PyDate, PyDateTime, PyDict, PyFrozenSet, PyInt, PyList, PyMapping, PySequence, PySet, PyString,
PyTime, PyTuple, PyType,
};
use pyo3::{intern, AsPyPointer};

use crate::errors::location::LocItem;
use crate::errors::{as_internal, context, err_val_error, py_err_string, ErrorKind, InputValue, ValResult};
Expand Down Expand Up @@ -36,6 +36,10 @@ impl<'a> Input<'a> for PyAny {
InputValue::PyAny(self)
}

fn identity(&'a self) -> Option<usize> {
Some(self.as_ptr() as usize)
}

fn is_none(&self) -> bool {
self.is_none()
}
Expand Down
23 changes: 17 additions & 6 deletions src/input/return_enums.rs
Expand Up @@ -4,6 +4,7 @@ use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyFrozenSet, PyList, PySet, PyString, PyTuple};

use crate::errors::{ValError, ValLineError, ValResult};
use crate::recursion_guard::RecursionGuard;
use crate::validators::{CombinedValidator, Extra, Validator};

use super::parse_json::{JsonArray, JsonObject};
Expand Down Expand Up @@ -41,11 +42,12 @@ macro_rules! build_validate_to_vec {
validator: &'s CombinedValidator,
extra: &Extra,
slots: &'a [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'a, Vec<PyObject>> {
let mut output: Vec<PyObject> = Vec::with_capacity(length);
let mut errors: Vec<ValLineError> = Vec::new();
for (index, item) in sequence.iter().enumerate() {
match validator.validate(py, item, extra, slots) {
match validator.validate(py, item, extra, slots, recursion_guard) {
Ok(item) => output.push(item),
Err(ValError::LineErrors(line_errors)) => {
errors.extend(
Expand Down Expand Up @@ -90,13 +92,22 @@ impl<'a> GenericSequence<'a> {
validator: &'s CombinedValidator,
extra: &Extra,
slots: &'a [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'a, Vec<PyObject>> {
match self {
Self::List(sequence) => validate_to_vec_list(py, sequence, length, validator, extra, slots),
Self::Tuple(sequence) => validate_to_vec_tuple(py, sequence, length, validator, extra, slots),
Self::Set(sequence) => validate_to_vec_set(py, sequence, length, validator, extra, slots),
Self::FrozenSet(sequence) => validate_to_vec_frozenset(py, sequence, length, validator, extra, slots),
Self::JsonArray(sequence) => validate_to_vec_jsonarray(py, sequence, length, validator, extra, slots),
Self::List(sequence) => {
validate_to_vec_list(py, sequence, length, validator, extra, slots, recursion_guard)
}
Self::Tuple(sequence) => {
validate_to_vec_tuple(py, sequence, length, validator, extra, slots, recursion_guard)
}
Self::Set(sequence) => validate_to_vec_set(py, sequence, length, validator, extra, slots, recursion_guard),
Self::FrozenSet(sequence) => {
validate_to_vec_frozenset(py, sequence, length, validator, extra, slots, recursion_guard)
}
Self::JsonArray(sequence) => {
validate_to_vec_jsonarray(py, sequence, length, validator, extra, slots, recursion_guard)
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Expand Up @@ -10,6 +10,7 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
mod build_tools;
mod errors;
mod input;
mod recursion_guard;
mod validators;

// required for benchmarks
Expand Down
36 changes: 36 additions & 0 deletions src/recursion_guard.rs
@@ -0,0 +1,36 @@
use std::collections::HashSet;
use std::hash::BuildHasherDefault;

use nohash_hasher::NoHashHasher;

/// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault.
/// It's used in `validators/recursive.rs` to detect when a reference is reused within itself.
#[derive(Debug, Clone, Default)]
pub struct RecursionGuard(Option<HashSet<usize, BuildHasherDefault<NoHashHasher<usize>>>>);

impl RecursionGuard {
// insert a new id into the set, return whether the set already had the id in it
pub fn contains_or_insert(&mut self, id: usize) -> bool {
match self.0 {
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
Some(ref mut set) => !set.insert(id),
None => {
let mut set: HashSet<usize, BuildHasherDefault<NoHashHasher<usize>>> =
HashSet::with_capacity_and_hasher(10, BuildHasherDefault::default());
set.insert(id);
self.0 = Some(set);
false
}
}
}

pub fn remove(&mut self, id: &usize) {
match self.0 {
Some(ref mut set) => {
set.remove(id);
}
None => unreachable!(),
};
}
}
2 changes: 2 additions & 0 deletions src/validators/any.rs
Expand Up @@ -3,6 +3,7 @@ use pyo3::types::PyDict;

use crate::errors::ValResult;
use crate::input::Input;
use crate::recursion_guard::RecursionGuard;

use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};

Expand All @@ -29,6 +30,7 @@ impl Validator for AnyValidator {
input: &'data impl Input<'data>,
_extra: &Extra,
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
// Ok(input.clone().into_py(py))
Ok(input.to_object(py))
Expand Down
4 changes: 4 additions & 0 deletions src/validators/bool.rs
Expand Up @@ -4,6 +4,7 @@ use pyo3::types::PyDict;
use crate::build_tools::is_strict;
use crate::errors::ValResult;
use crate::input::Input;
use crate::recursion_guard::RecursionGuard;

use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};

Expand Down Expand Up @@ -33,6 +34,7 @@ impl Validator for BoolValidator {
input: &'data impl Input<'data>,
_extra: &Extra,
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
// TODO in theory this could be quicker if we used PyBool rather than going to a bool
// and back again, might be worth profiling?
Expand All @@ -45,6 +47,7 @@ impl Validator for BoolValidator {
input: &'data impl Input<'data>,
_extra: &Extra,
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
Ok(input.strict_bool()?.into_py(py))
}
Expand All @@ -70,6 +73,7 @@ impl Validator for StrictBoolValidator {
input: &'data impl Input<'data>,
_extra: &Extra,
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
Ok(input.strict_bool()?.into_py(py))
}
Expand Down

0 comments on commit b95b3d2

Please sign in to comment.