Skip to content

Commit

Permalink
Fix parsing int from large decimals (pydantic#948)
Browse files Browse the repository at this point in the history
Co-authored-by: David Hewitt <1939362+davidhewitt@users.noreply.github.com>
  • Loading branch information
adriangb and davidhewitt committed Sep 7, 2023
1 parent 2d9df49 commit 6769140
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 46 deletions.
12 changes: 6 additions & 6 deletions src/input/input_abstract.rs
Expand Up @@ -152,17 +152,17 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
self.strict_float()
}

fn validate_decimal(&'a self, strict: bool, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
fn validate_decimal(&'a self, strict: bool, py: Python<'a>) -> ValResult<&'a PyAny> {
if strict {
self.strict_decimal(decimal_type)
self.strict_decimal(py)
} else {
self.lax_decimal(decimal_type)
self.lax_decimal(py)
}
}
fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny>;
fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny>;
#[cfg_attr(has_no_coverage, no_coverage)]
fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
self.strict_decimal(decimal_type)
fn lax_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> {
self.strict_decimal(py)
}

fn validate_dict(&'a self, strict: bool) -> ValResult<GenericMapping<'a>> {
Expand Down
14 changes: 6 additions & 8 deletions src/input/input_json.rs
@@ -1,7 +1,7 @@
use std::borrow::Cow;

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString, PyType};
use pyo3::types::{PyDict, PyString};
use speedate::MicrosecondsPrecisionOverflowBehavior;
use strum::EnumMessage;

Expand Down Expand Up @@ -178,13 +178,12 @@ impl<'a> Input<'a> for JsonInput {
}
}

fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
let py = decimal_type.py();
fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> {
match self {
JsonInput::Float(f) => create_decimal(PyString::new(py, &f.to_string()), self, decimal_type),
JsonInput::Float(f) => create_decimal(PyString::new(py, &f.to_string()), self, py),

JsonInput::String(..) | JsonInput::Int(..) | JsonInput::Uint(..) | JsonInput::BigInt(..) => {
create_decimal(self.to_object(py).into_ref(py), self, decimal_type)
create_decimal(self.to_object(py).into_ref(py), self, py)
}
_ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
}
Expand Down Expand Up @@ -439,9 +438,8 @@ impl<'a> Input<'a> for String {
str_as_float(self, self)
}

fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
let py = decimal_type.py();
create_decimal(self.to_object(py).into_ref(py), self, decimal_type)
fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> {
create_decimal(self.to_object(py).into_ref(py), self, py)
}

#[cfg_attr(has_no_coverage, no_coverage)]
Expand Down
24 changes: 16 additions & 8 deletions src/input/input_python.rs
Expand Up @@ -13,15 +13,15 @@ use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::tools::{extract_i64, safe_repr};
use crate::validators::decimal::create_decimal;
use crate::validators::decimal::{create_decimal, get_decimal_type};
use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl};

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime,
EitherTime,
};
use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int};
use super::shared::{decimal_as_int, float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int};
use super::{
py_string_str, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments,
GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs,
Expand Down Expand Up @@ -324,6 +324,10 @@ impl<'a> Input<'a> for PyAny {
} else if PyInt::is_type_of(self) {
// force to an int to upcast to a pure python int to maintain current behaviour
EitherInt::upcast(self)
} else if PyFloat::is_exact_type_of(self) {
float_as_int(self, self.extract::<f64>()?)
} else if let Ok(decimal) = self.strict_decimal(self.py()) {
decimal_as_int(self.py(), self, decimal)
} else if let Ok(float) = self.extract::<f64>() {
float_as_int(self, float)
} else {
Expand Down Expand Up @@ -367,15 +371,17 @@ impl<'a> Input<'a> for PyAny {
}
}

fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> {
let decimal_type_obj: Py<PyType> = get_decimal_type(py);
let decimal_type = decimal_type_obj.as_ref(py);
// Fast path for existing decimal objects
if self.is_exact_instance(decimal_type) {
return Ok(self);
}

// Try subclasses of decimals, they will be upcast to Decimal
if self.is_instance(decimal_type)? {
return create_decimal(self, self, decimal_type);
return create_decimal(self, self, py);
}

Err(ValError::new(
Expand All @@ -387,20 +393,22 @@ impl<'a> Input<'a> for PyAny {
))
}

fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
fn lax_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> {
let decimal_type_obj: Py<PyType> = get_decimal_type(py);
let decimal_type = decimal_type_obj.as_ref(py);
// Fast path for existing decimal objects
if self.is_exact_instance(decimal_type) {
return Ok(self);
}

if self.is_instance_of::<PyString>() || (self.is_instance_of::<PyInt>() && !self.is_instance_of::<PyBool>()) {
// checking isinstance for str / int / bool is fast compared to decimal / float
create_decimal(self, self, decimal_type)
create_decimal(self, self, py)
} else if self.is_instance(decimal_type)? {
// upcast subclasses to decimal
return create_decimal(self, self, decimal_type);
return create_decimal(self, self, py);
} else if self.is_instance_of::<PyFloat>() {
create_decimal(self.str()?, self, decimal_type)
create_decimal(self.str()?, self, py)
} else {
Err(ValError::new(ErrorTypeDefaults::DecimalType, self))
}
Expand Down
14 changes: 14 additions & 0 deletions src/input/shared.rs
@@ -1,4 +1,5 @@
use num_bigint::BigInt;
use pyo3::{intern, PyAny, Python};

use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};
use crate::input::EitherInt;
Expand Down Expand Up @@ -136,3 +137,16 @@ pub fn float_as_int<'a>(input: &'a impl Input<'a>, float: f64) -> ValResult<'a,
Err(ValError::new(ErrorTypeDefaults::IntParsingSize, input))
}
}

/// Converts a Python decimal object into an integer, rejecting non-finite and fractional values.
///
/// * `py` - GIL token used to intern the Python method names.
/// * `input` - the original input; it is attached to any validation error produced here.
/// * `decimal` - the Python object to convert; `is_finite` and `as_integer_ratio` are called on it.
///
/// # Errors
///
/// * `FiniteNumber` when `is_finite()` returns false.
/// * `IntFromFloat` when the denominator of `as_integer_ratio()` is not exactly 1.
/// * Any error raised by the Python calls or by extracting their results is propagated.
pub fn decimal_as_int<'a>(py: Python, input: &'a impl Input<'a>, decimal: &'a PyAny) -> ValResult<'a, EitherInt<'a>> {
    let is_finite: bool = decimal.call_method0(intern!(py, "is_finite"))?.extract()?;
    if !is_finite {
        return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input));
    }

    let ratio = decimal.call_method0(intern!(py, "as_integer_ratio"))?;
    let (numerator, denominator): (&PyAny, &PyAny) = ratio.extract()?;

    // A denominator that does not fit in an i64 cannot be 1, so it is rejected
    // together with every other denominator different from 1.
    match denominator.extract::<i64>() {
        Ok(1) => Ok(EitherInt::Py(numerator)),
        _ => Err(ValError::new(ErrorTypeDefaults::IntFromFloat, input)),
    }
}
58 changes: 34 additions & 24 deletions src/validators/decimal.rs
@@ -1,4 +1,5 @@
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::sync::GILOnceCell;
use pyo3::types::{IntoPyDict, PyDict, PyTuple, PyType};
use pyo3::{intern, AsPyPointer};
use pyo3::{prelude::*, PyTypeInfo};
Expand All @@ -13,6 +14,21 @@ use crate::tools::SchemaDict;

use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

/// Cache for Python's `decimal.Decimal` type object, filled on the first call to
/// `get_decimal_type`.
static DECIMAL_TYPE: GILOnceCell<Py<PyType>> = GILOnceCell::new();

/// Returns a new strong reference to Python's `decimal.Decimal` type.
///
/// The type is imported the first time this function runs and stored in `DECIMAL_TYPE`;
/// later calls only hand out another reference to the cached object.
///
/// # Panics
///
/// Panics if the `decimal` module cannot be imported, if it has no `Decimal` attribute,
/// or if that attribute is not a type object.
pub fn get_decimal_type(py: Python) -> Py<PyType> {
    DECIMAL_TYPE
        .get_or_init(py, || {
            py.import("decimal")
                .and_then(|decimal_module| decimal_module.getattr("Decimal"))
                .expect("the standard library `decimal.Decimal` should be importable")
                .extract::<&PyType>()
                .expect("`decimal.Decimal` should be a type object")
                .into()
        })
        // The GIL token is already in hand, so take the new reference explicitly under it
        // rather than going through the generic `Clone` impl.
        .clone_ref(py)
}

#[derive(Debug, Clone)]
pub struct DecimalValidator {
strict: bool,
Expand All @@ -25,7 +41,6 @@ pub struct DecimalValidator {
gt: Option<Py<PyAny>>,
max_digits: Option<u64>,
decimal_places: Option<u64>,
decimal_type: Py<PyType>,
}

impl BuildValidator for DecimalValidator {
Expand Down Expand Up @@ -55,10 +70,6 @@ impl BuildValidator for DecimalValidator {
ge: schema.get_as(intern!(py, "ge"))?,
gt: schema.get_as(intern!(py, "gt"))?,
max_digits,
decimal_type: py
.import(intern!(py, "decimal"))?
.getattr(intern!(py, "Decimal"))?
.extract()?,
}
.into())
}
Expand All @@ -69,8 +80,7 @@ impl_py_gc_traverse!(DecimalValidator {
le,
lt,
ge,
gt,
decimal_type
gt
});

impl Validator for DecimalValidator {
Expand All @@ -80,11 +90,7 @@ impl Validator for DecimalValidator {
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let decimal = input.validate_decimal(
state.strict_or(self.strict),
// Safety: self and py both outlive this call
unsafe { py.from_borrowed_ptr(self.decimal_type.as_ptr()) },
)?;
let decimal = input.validate_decimal(state.strict_or(self.strict), py)?;

if !self.allow_inf_nan || self.check_digits {
if !decimal.call_method0(intern!(py, "is_finite"))?.extract()? {
Expand Down Expand Up @@ -244,19 +250,23 @@ impl Validator for DecimalValidator {
pub(crate) fn create_decimal<'a>(
arg: &'a PyAny,
input: &'a impl Input<'a>,
decimal_type: &'a PyType,
py: Python<'a>,
) -> ValResult<'a, &'a PyAny> {
decimal_type.call1((arg,)).map_err(|e| {
let decimal_exception = match arg
.py()
.import("decimal")
.and_then(|decimal_module| decimal_module.getattr("DecimalException"))
{
Ok(decimal_exception) => decimal_exception,
Err(e) => return ValError::InternalErr(e),
};
handle_decimal_new_error(arg.py(), input.as_error_value(), e, decimal_exception)
})
let decimal_type_obj: Py<PyType> = get_decimal_type(py);
decimal_type_obj
.call1(py, (arg,))
.map_err(|e| {
let decimal_exception = match arg
.py()
.import("decimal")
.and_then(|decimal_module| decimal_module.getattr("DecimalException"))
{
Ok(decimal_exception) => decimal_exception,
Err(e) => return ValError::InternalErr(e),
};
handle_decimal_new_error(arg.py(), input.as_error_value(), e, decimal_exception)
})
.map(|v| v.into_ref(py))
}

fn handle_decimal_new_error<'a>(
Expand Down
21 changes: 21 additions & 0 deletions tests/validators/test_int.py
Expand Up @@ -58,14 +58,24 @@ def test_int_py_and_json(py_and_json: PyAndJson, input_value, expected):
'input_value,expected',
[
(Decimal('1'), 1),
(Decimal('1' + '0' * 1_000), int('1' + '0' * 1_000)), # a large decimal
(Decimal('1.0'), 1),
(1.0, 1),
(i64_max, i64_max),
(str(i64_max), i64_max),
(str(i64_max * 2), i64_max * 2),
(i64_max + 1, i64_max + 1),
(-i64_max + 1, -i64_max + 1),
(i64_max * 2, i64_max * 2),
(-i64_max * 2, -i64_max * 2),
pytest.param(
1.00000000001,
Err(
'Input should be a valid integer, got a number with a fractional part '
'[type=int_from_float, input_value=1.00000000001, input_type=float]'
),
id='decimal-remainder',
),
pytest.param(
Decimal('1.001'),
Err(
Expand Down Expand Up @@ -437,3 +447,14 @@ def test_int_subclass_constraint() -> None:

with pytest.raises(ValidationError, match='Input should be greater than 0'):
v.validate_python(IntSubclass(0))


# Trivial `float` subclass used by `test_float_subclass` below as an input value
# whose type is not exactly `float`.
class FloatSubclass(float):
pass


# Checks that an `int` schema accepts a `FloatSubclass` instance holding a whole number
# and that the validated result is an exact `int` equal to 1.
def test_float_subclass() -> None:
v = SchemaValidator({'type': 'int'})
v_lax = v.validate_python(FloatSubclass(1))
assert v_lax == 1
assert type(v_lax) == int

0 comments on commit 6769140

Please sign in to comment.