diff --git a/newsfragments/3379.changed.md b/newsfragments/3379.changed.md new file mode 100644 index 00000000000..b5a246eeabe --- /dev/null +++ b/newsfragments/3379.changed.md @@ -0,0 +1 @@ +Sped up FromPyObject::extract for BigInt and BigUint by up to 43% (although mileage may vary depending on int size and sign) diff --git a/src/conversions/num_bigint.rs b/src/conversions/num_bigint.rs index 0c61c2f07b5..d9428fea932 100644 --- a/src/conversions/num_bigint.rs +++ b/src/conversions/num_bigint.rs @@ -51,25 +51,37 @@ use crate::{ ffi, types::*, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, Python, ToPyObject, }; -use num_bigint::{BigInt, BigUint}; -use std::os::raw::{c_int, c_uchar}; +use num_bigint::{BigInt, BigUint, Sign}; +use std::os::raw::c_int; #[cfg(not(Py_LIMITED_API))] -unsafe fn extract(ob: &PyLong, buffer: &mut [c_uchar], is_signed: c_int) -> PyResult<()> { +use std::os::raw::c_uchar; + +#[cfg(Py_LIMITED_API)] +use std::slice; + +#[cfg(not(Py_LIMITED_API))] +#[inline] +unsafe fn extract(ob: &PyLong, length: usize, is_signed: c_int) -> PyResult> { + let mut buffer = Vec::::with_capacity(length); crate::err::error_on_minusone( ob.py(), ffi::_PyLong_AsByteArray( ob.as_ptr() as *mut ffi::PyLongObject, - buffer.as_mut_ptr(), - buffer.len(), + buffer.as_mut_ptr() as *mut u8, + length * 4, 1, is_signed, ), - ) + )?; + buffer.set_len(length); + + Ok(buffer) } #[cfg(Py_LIMITED_API)] -unsafe fn extract(ob: &PyLong, buffer: &mut [c_uchar], is_signed: c_int) -> PyResult<()> { +#[inline] +unsafe fn extract(ob: &PyLong, length: usize, is_signed: c_int) -> PyResult> { use crate::intern; let py = ob.py(); let kwargs = if is_signed != 0 { @@ -81,14 +93,16 @@ unsafe fn extract(ob: &PyLong, buffer: &mut [c_uchar], is_signed: c_int) -> PyRe }; let bytes_obj = ob .getattr(intern!(py, "to_bytes"))? - .call((buffer.len(), "little"), kwargs)?; + .call((length * 4, intern!(py, "little")), kwargs)?; let bytes: &PyBytes = bytes_obj.downcast_unchecked(); - buffer.copy_from_slice(bytes.as_bytes()); - Ok(()) + let bytes_u32 = slice::from_raw_parts(bytes.as_bytes().as_ptr().cast(), length); + + Ok(bytes_u32.to_vec()) } +// for identical functionality between BigInt and BigUint macro_rules! bigint_conversion { - ($rust_ty: ty, $is_signed: expr, $to_bytes: path, $from_bytes: path) => { + ($rust_ty: ty, $is_signed: expr, $to_bytes: path) => { #[cfg_attr(docsrs, doc(cfg(feature = "num-bigint")))] impl ToPyObject for $rust_ty { #[cfg(not(Py_LIMITED_API))] @@ -129,62 +143,92 @@ macro_rules! bigint_conversion { self.to_object(py) } } + }; +} - #[cfg_attr(docsrs, doc(cfg(feature = "num-bigint")))] - impl<'source> FromPyObject<'source> for $rust_ty { - fn extract(ob: &'source PyAny) -> PyResult<$rust_ty> { - let py = ob.py(); - unsafe { - let num: Py = - Py::from_owned_ptr_or_err(py, ffi::PyNumber_Index(ob.as_ptr()))?; +bigint_conversion!(BigUint, 0, BigUint::to_bytes_le); +bigint_conversion!(BigInt, 1, BigInt::to_signed_bytes_le); - let n_bytes = { - cfg_if::cfg_if! { - if #[cfg(not(Py_LIMITED_API))] { - // fast path - let n_bits = ffi::_PyLong_NumBits(num.as_ptr()); - if n_bits == (-1isize as usize) { - return Err(crate::PyErr::fetch(py)); - } else if n_bits == 0 { - 0 - } else { - (n_bits - 1 + $is_signed) / 8 + 1 - } - } else { - // slow path - let n_bits_obj = num.getattr(py, crate::intern!(py, "bit_length"))?.call0(py)?; - let n_bits_int: &PyLong = n_bits_obj.downcast_unchecked(py); - let n_bits = n_bits_int.extract::()?; - if n_bits == 0 { - 0 - } else { - (n_bits - 1 + $is_signed) / 8 + 1 - } - } +#[cfg_attr(docsrs, doc(cfg(feature = "num-bigint")))] +impl<'source> FromPyObject<'source> for BigInt { + fn extract(ob: &'source PyAny) -> PyResult { + let py = ob.py(); + unsafe { + let num: Py = Py::from_owned_ptr_or_err(py, ffi::PyNumber_Index(ob.as_ptr()))?; + let n_bits = { + cfg_if::cfg_if! { + if #[cfg(not(Py_LIMITED_API))] { + // fast path + let n_bits = ffi::_PyLong_NumBits(num.as_ptr()); + if n_bits == (-1isize as usize) { + return Err(crate::PyErr::fetch(py)); } - }; + n_bits + } else { + // slow path + let n_bits_obj = num.getattr(py, crate::intern!(py, "bit_length"))?.call0(py)?; + let n_bits_int: &PyLong = n_bits_obj.downcast_unchecked(py); + n_bits_int.extract::()? + } + } + }; + + if n_bits == 0 { + return Ok(BigInt::from(0isize)); + } + let n_digits = (n_bits + 32) / 32; + let mut buffer = extract(num.as_ref(py), n_digits, 1)?; + buffer + .iter_mut() + .for_each(|chunk| *chunk = u32::from_le(*chunk)); + + Ok(if buffer.last().unwrap() >> 31 != 0 { + buffer.iter_mut().for_each(|element| *element = !*element); + BigInt::new(Sign::Minus, buffer) - 1 + } else { + BigInt::new(Sign::Plus, buffer) + }) + } + } +} - if n_bytes <= 128 { - let mut buffer = [0; 128]; - extract(num.as_ref(py), &mut buffer[..n_bytes], $is_signed)?; - Ok($from_bytes(&buffer[..n_bytes])) +#[cfg_attr(docsrs, doc(cfg(feature = "num-bigint")))] +impl<'source> FromPyObject<'source> for BigUint { + fn extract(ob: &'source PyAny) -> PyResult { + let py = ob.py(); + unsafe { + let num: Py = Py::from_owned_ptr_or_err(py, ffi::PyNumber_Index(ob.as_ptr()))?; + let n_bits = { + cfg_if::cfg_if! { + if #[cfg(not(Py_LIMITED_API))] { + // fast path + let n_bits = ffi::_PyLong_NumBits(num.as_ptr()); + if n_bits == (-1isize as usize) { + return Err(crate::PyErr::fetch(py)); + } + n_bits } else { - let mut buffer = vec![0; n_bytes]; - extract(num.as_ref(py), &mut buffer, $is_signed)?; - Ok($from_bytes(&buffer)) + // slow path + let n_bits_obj = num.getattr(py, crate::intern!(py, "bit_length"))?.call0(py)?; + let n_bits_int: &PyLong = n_bits_obj.downcast_unchecked(py); + n_bits_int.extract::()? } } + }; + + if n_bits == 0 { + return Ok(BigUint::from(0usize)); } + let n_digits = (n_bits + 31) / 32; + let mut buffer = extract(num.as_ref(py), n_digits, 0)?; + buffer + .iter_mut() + .for_each(|chunk| *chunk = u32::from_le(*chunk)); + + Ok(BigUint::new(buffer)) } - }; + } } -bigint_conversion!(BigUint, 0, BigUint::to_bytes_le, BigUint::from_bytes_le); -bigint_conversion!( - BigInt, - 1, - BigInt::to_signed_bytes_le, - BigInt::from_signed_bytes_le -); #[cfg(test)] mod tests { @@ -312,7 +356,7 @@ mod tests { ($T:ty, $value:expr, $py:expr) => { let value = $value; println!("{}: {}", stringify!($T), value); - let python_value = value.clone().to_object(py); + let python_value = value.clone().into_py(py); let roundtrip_value = python_value.extract::<$T>(py).unwrap(); assert_eq!(value, roundtrip_value); };