Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cpp/src/arrow/python/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,14 @@ Status CheckPyError(StatusCode code) {
return Status::OK();
}

Status PassPyError() {
if (PyErr_Occurred()) {
// Do not call PyErr_Clear, the assumption is that someone further
// up the call stack will want to deal with the Python error.
return Status(StatusCode::PythonError, "");
}
return Status::OK();
}

} // namespace py
} // namespace arrow
2 changes: 2 additions & 0 deletions cpp/src/arrow/python/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ struct ARROW_EXPORT PyObjectStringify {

Status CheckPyError(StatusCode code = StatusCode::UnknownError);

Status PassPyError();

// TODO(wesm): We can just let errors pass through. To be explored later
#define RETURN_IF_PYERROR() RETURN_NOT_OK(CheckPyError());

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/python/python_to_arrow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ Status CallCustomCallback(PyObject* context, PyObject* method_name, PyObject* el
return Status::SerializationError(ss.str());
} else {
*result = PyObject_CallMethodObjArgs(context, method_name, elem, NULL);
RETURN_IF_PYERROR();
return PassPyError();
}
return Status::OK();
}
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ enum class StatusCode : char {
UnknownError = 9,
NotImplemented = 10,
SerializationError = 11,
PythonError = 12,
PlasmaObjectExists = 20,
PlasmaObjectNonexistent = 21,
PlasmaStoreFull = 22
Expand Down Expand Up @@ -154,6 +155,8 @@ class ARROW_EXPORT Status {
bool IsNotImplemented() const { return code() == StatusCode::NotImplemented; }
// An object could not be serialized or deserialized.
bool IsSerializationError() const { return code() == StatusCode::SerializationError; }
// An error is propagated from a nested Python function.
bool IsPythonError() const { return code() == StatusCode::PythonError; }
// An object with this object ID already exists in the plasma store.
bool IsPlasmaObjectExists() const { return code() == StatusCode::PlasmaObjectExists; }
// An object was requested that doesn't exist in the plasma store.
Expand Down
3 changes: 3 additions & 0 deletions python/pyarrow/error.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ cdef int check_status(const CStatus& status) nogil except -1:
if status.ok():
return 0

if status.IsPythonError():
return -1

with gil:
message = frombytes(status.message())
if status.IsInvalid():
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/includes/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
c_bool IsNotImplemented()
c_bool IsTypeError()
c_bool IsSerializationError()
c_bool IsPythonError()
c_bool IsPlasmaObjectExists()
c_bool IsPlasmaObjectNonexistent()
c_bool IsPlasmaStoreFull()
Expand Down
4 changes: 3 additions & 1 deletion python/pyarrow/serialization.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ cdef class SerializationContext:
else:
assert type_id not in self.types_to_pickle
if type_id not in self.whitelisted_types:
raise "error"
msg = "Type ID " + str(type_id) + " not registered in " \
"deserialization callback"
raise DeserializationCallbackError(msg, type_id)
type_ = self.whitelisted_types[type_id]
if type_id in self.custom_deserializers:
obj = self.custom_deserializers[type_id](
Expand Down
23 changes: 23 additions & 0 deletions python/pyarrow/tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,26 @@ def test_custom_serialization(large_memory_map):
with pa.memory_map(large_memory_map, mode="r+") as mmap:
for obj in CUSTOM_OBJECTS:
serialization_roundtrip(obj, mmap)

def test_serialization_callback_error():

class TempClass(object):
pass

# Pass a SerializationContext into serialize, but TempClass
# is not registered
serialization_context = pa.SerializationContext()
val = TempClass()
with pytest.raises(pa.SerializationCallbackError) as err:
serialized_object = pa.serialize(val, serialization_context)
assert err.value.example_object == val

serialization_context.register_type(TempClass, 20*b"\x00")
serialized_object = pa.serialize(TempClass(), serialization_context)
deserialization_context = pa.SerializationContext()

# Pass a Serialization Context into deserialize, but TempClass
# is not registered
with pytest.raises(pa.DeserializationCallbackError) as err:
serialized_object.deserialize(deserialization_context)
assert err.value.type_id == 20*b"\x00"