Skip to content

Commit

Permalink
BUG: core: ensure cfloat and clongdouble scalars have a __complex__ m…
Browse files Browse the repository at this point in the history
…ethod, so that complex(...) cast works properly (fixes #1617)
  • Loading branch information
pv committed Sep 20, 2010
1 parent d82003b commit 14d8e20
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
34 changes: 34 additions & 0 deletions numpy/core/src/multiarray/scalartypes.c.src
Expand Up @@ -1468,6 +1468,22 @@ gentype_setflags(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args),
return Py_None;
}

/* casting complex numbers (that don't inherit from Python complex)
* to Python complex */

/**begin repeat
* #name=cfloat,clongdouble#
* #Name=CFloat,CLongDouble#
*/
static PyObject *
@name@_complex(PyObject *self, PyObject *NPY_UNUSED(args),
PyObject *NPY_UNUSED(kwds))
{
return PyComplex_FromDoubles(PyArrayScalar_VAL(self, @Name@).real,
PyArrayScalar_VAL(self, @Name@).imag);
}
/**end repeat**/

/* need to fill in doc-strings for these methods on import -- copy from
array docstrings
*/
Expand Down Expand Up @@ -1687,6 +1703,17 @@ static PyMethodDef voidtype_methods[] = {
{NULL, NULL, 0, NULL}
};

/**begin repeat
* #name=cfloat,clongdouble#
*/
static PyMethodDef @name@type_methods[] = {
{"__complex__",
(PyCFunction)@name@_complex,
METH_VARARGS | METH_KEYWORDS, NULL},
{NULL, NULL, 0, NULL}
};
/**end repeat**/

/************* As_mapping functions for void array scalar ************/

static Py_ssize_t
Expand Down Expand Up @@ -3307,6 +3334,13 @@ initialize_numeric_types(void)
Py@NAME@ArrType_Type.tp_hash = @name@_arrtype_hash;
/**end repeat**/

/**begin repeat
* #name = cfloat, clongdouble#
* #NAME = CFloat, CLongDouble#
*/
Py@NAME@ArrType_Type.tp_methods = @name@type_methods;
/**end repeat**/

#if (SIZEOF_INT != SIZEOF_LONG) || defined(NPY_PY3K)
/* We won't be inheriting from Python Int type. */
PyIntArrType_Type.tp_hash = int_arrtype_hash;
Expand Down
5 changes: 5 additions & 0 deletions numpy/core/tests/test_regression.py
Expand Up @@ -1411,5 +1411,10 @@ def test_complex_scalar_warning(self):
assert_equal(float(x), float(x.real))
ctx.__exit__()

def test_complex_scalar_complex_cast(self):
for tp in [np.csingle, np.cdouble, np.clongdouble]:
x = tp(1+2j)
assert_equal(complex(x), 1+2j)

if __name__ == "__main__":
run_module_suite()

0 comments on commit 14d8e20

Please sign in to comment.