Skip to content

Commit

Permalink
add PyAnyMethods for binary operators
Browse files Browse the repository at this point in the history
also pow

fixes PyO3#3709
  • Loading branch information
alex committed Dec 29, 2023
1 parent 6776b90 commit 339660c
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 0 deletions.
1 change: 1 addition & 0 deletions newsfragments/3712.added.md
@@ -0,0 +1 @@
Added methods to `PyAnyMethods` for binary operators (`add`, `sub`, etc.)
7 changes: 7 additions & 0 deletions src/tests/common.rs
Expand Up @@ -23,6 +23,13 @@ mod inner {
};
}

#[macro_export]
macro_rules! assert_py_eq {
($val:expr, $expected:expr) => {
assert!($val.eq($expected).unwrap());
};
}

#[macro_export]
macro_rules! py_expect_exception {
// Case1: idents & no err_msg
Expand Down
108 changes: 108 additions & 0 deletions src/types/any.rs
Expand Up @@ -1208,6 +1208,58 @@ pub trait PyAnyMethods<'py> {
where
O: ToPyObject;

/// Computes `self + other`.
fn add<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Computes `self - other`.
fn sub<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Computes `self * other`.
fn mul<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Computes `self / other`.
fn div<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Computes `self << other`.
fn lshift<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Computes `self >> other`.
fn rshift<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Computes `self ** other % modulus` (`pow(self, other, modulus)`).
/// `py.None()` may be passed for the `modulus`.
fn pow<O1, O2>(&self, other: O1, modulus: O2) -> PyResult<Bound<'py, PyAny>>
where
O1: ToPyObject,
O2: ToPyObject;

/// Computes `self & other`.
fn bitand<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Computes `self | other`.
fn bitor<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Computes `self ^ other`.
fn bitxor<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Determines whether this object appears callable.
///
/// This is equivalent to Python's [`callable()`][1] function.
Expand Down Expand Up @@ -1680,6 +1732,26 @@ pub trait PyAnyMethods<'py> {
fn py_super(&self) -> PyResult<Bound<'py, PySuper>>;
}

macro_rules! implement_binop {
($name:ident, $c_api:ident, $op:expr) => {
#[doc = concat!("Computes `self ", $op, " other`.")]
fn $name<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject,
{
fn inner<'py>(
any: &Bound<'py, PyAny>,
other: Bound<'_, PyAny>,
) -> PyResult<Bound<'py, PyAny>> {
unsafe { ffi::$c_api(any.as_ptr(), other.as_ptr()).assume_owned_or_err(any.py()) }
}

let py = self.py();
inner(self, other.to_object(py).into_bound(py))
}
};
}

impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
#[inline]
fn is<T: AsPyPointer>(&self, other: &T) -> bool {
Expand Down Expand Up @@ -1855,6 +1927,42 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
.and_then(|any| any.is_truthy())
}

implement_binop!(add, PyNumber_Add, "+");
implement_binop!(sub, PyNumber_Subtract, "-");
implement_binop!(mul, PyNumber_Multiply, "*");
implement_binop!(div, PyNumber_TrueDivide, "/");
implement_binop!(lshift, PyNumber_Lshift, "<<");
implement_binop!(rshift, PyNumber_Rshift, ">>");
implement_binop!(bitand, PyNumber_And, "&");
implement_binop!(bitor, PyNumber_Or, "|");
implement_binop!(bitxor, PyNumber_Xor, "^");

/// Computes `self ** other % modulus` (`pow(self, other, modulus)`).
/// `py.None()` may be passed for the `modulus`.
fn pow<O1, O2>(&self, other: O1, modulus: O2) -> PyResult<Bound<'py, PyAny>>
where
O1: ToPyObject,
O2: ToPyObject,
{
fn inner<'py>(
any: &Bound<'py, PyAny>,
other: Bound<'_, PyAny>,
modulus: Bound<'_, PyAny>,
) -> PyResult<Bound<'py, PyAny>> {
unsafe {
ffi::PyNumber_Power(any.as_ptr(), other.as_ptr(), modulus.as_ptr())
.assume_owned_or_err(any.py())
}
}

let py = self.py();
inner(
self,
other.to_object(py).into_bound(py),
modulus.to_object(py).into_bound(py),
)
}

fn is_callable(&self) -> bool {
unsafe { ffi::PyCallable_Check(self.as_ptr()) != 0 }
}
Expand Down
16 changes: 16 additions & 0 deletions tests/test_arithmetics.rs
Expand Up @@ -178,6 +178,10 @@ impl BinaryArithmetic {
format!("BA * {:?}", rhs)
}

fn __truediv__(&self, rhs: &PyAny) -> String {
format!("BA / {:?}", rhs)
}

fn __lshift__(&self, rhs: &PyAny) -> String {
format!("BA << {:?}", rhs)
}
Expand Down Expand Up @@ -233,6 +237,18 @@ fn binary_arithmetic() {
py_expect_exception!(py, c, "1 ** c", PyTypeError);

py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'");

let c: Bound<'_, PyAny> = c.extract().unwrap();
assert_py_eq!(c.add(&c).unwrap(), "BA + BA");
assert_py_eq!(c.sub(&c).unwrap(), "BA - BA");
assert_py_eq!(c.mul(&c).unwrap(), "BA * BA");
assert_py_eq!(c.div(&c).unwrap(), "BA / BA");
assert_py_eq!(c.lshift(&c).unwrap(), "BA << BA");
assert_py_eq!(c.rshift(&c).unwrap(), "BA >> BA");
assert_py_eq!(c.bitand(&c).unwrap(), "BA & BA");
assert_py_eq!(c.bitor(&c).unwrap(), "BA | BA");
assert_py_eq!(c.bitxor(&c).unwrap(), "BA ^ BA");
assert_py_eq!(c.pow(&c, py.None()).unwrap(), "BA ** BA (mod: None)");
});
}

Expand Down

0 comments on commit 339660c

Please sign in to comment.