Skip to content

Commit

Permalink
check tensor numel in PyObject_CheckLongOrToLong
Browse files Browse the repository at this point in the history
  • Loading branch information
RedContritio committed Jan 31, 2023
1 parent 23d20e3 commit ed96bae
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion paddle/fluid/pybind/op_function_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/phi/common/complex.h"

Expand Down Expand Up @@ -70,7 +71,8 @@ bool PyObject_CheckLongOrToLong(PyObject** obj) {
if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) ||
PyObject_IsInstance(*obj, (PyObject*)g_vartype_pytype) || // NOLINT
PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype) || // NOLINT
PyObject_IsInstance(*obj, (PyObject*)p_tensor_type)) { // NOLINT
(PyObject_IsInstance(*obj, (PyObject*)p_tensor_type) && // NOLINT
(((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT
return true;
}

Expand Down

0 comments on commit ed96bae

Please sign in to comment.