diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index 1275db6a8d417..f315e42a6a6a9 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -83,22 +83,60 @@ TEST(FlightTypes, LocationUnknownScheme) { } TEST(FlightTypes, RoundTripTypes) { + ActionType action_type{"action-type1", "action-type1-description"}; + ASSERT_OK_AND_ASSIGN(std::string action_type_serialized, + action_type.SerializeToString()); + ASSERT_OK_AND_ASSIGN(ActionType action_type_deserialized, + ActionType::Deserialize(action_type_serialized)); + ASSERT_EQ(action_type, action_type_deserialized); + + Criteria criteria{"criteria1"}; + ASSERT_OK_AND_ASSIGN(std::string criteria_serialized, criteria.SerializeToString()); + ASSERT_OK_AND_ASSIGN(Criteria criteria_deserialized, + Criteria::Deserialize(criteria_serialized)); + ASSERT_EQ(criteria, criteria_deserialized); + + Action action{"action1", Buffer::FromString("action1-content")}; + ASSERT_OK_AND_ASSIGN(std::string action_serialized, action.SerializeToString()); + ASSERT_OK_AND_ASSIGN(Action action_deserialized, + Action::Deserialize(action_serialized)); + ASSERT_EQ(action, action_deserialized); + + Result result{Buffer::FromString("result1-content")}; + ASSERT_OK_AND_ASSIGN(std::string result_serialized, result.SerializeToString()); + ASSERT_OK_AND_ASSIGN(Result result_deserialized, + Result::Deserialize(result_serialized)); + ASSERT_EQ(result, result_deserialized); + + BasicAuth basic_auth{"username1", "password1"}; + ASSERT_OK_AND_ASSIGN(std::string basic_auth_serialized, basic_auth.SerializeToString()); + ASSERT_OK_AND_ASSIGN(BasicAuth basic_auth_deserialized, + BasicAuth::Deserialize(basic_auth_serialized)); + ASSERT_EQ(basic_auth, basic_auth_deserialized); + + SchemaResult schema_result{"schema_result1"}; + ASSERT_OK_AND_ASSIGN(std::string schema_result_serialized, + schema_result.SerializeToString()); + ASSERT_OK_AND_ASSIGN(SchemaResult schema_result_deserialized, + SchemaResult::Deserialize(schema_result_serialized)); + ASSERT_EQ(schema_result, schema_result_deserialized); + Ticket ticket{"foo"}; ASSERT_OK_AND_ASSIGN(std::string ticket_serialized, ticket.SerializeToString()); ASSERT_OK_AND_ASSIGN(Ticket ticket_deserialized, Ticket::Deserialize(ticket_serialized)); - ASSERT_EQ(ticket.ticket, ticket_deserialized.ticket); + ASSERT_EQ(ticket, ticket_deserialized); FlightDescriptor desc = FlightDescriptor::Command("select * from foo;"); ASSERT_OK_AND_ASSIGN(std::string desc_serialized, desc.SerializeToString()); ASSERT_OK_AND_ASSIGN(FlightDescriptor desc_deserialized, FlightDescriptor::Deserialize(desc_serialized)); - ASSERT_TRUE(desc.Equals(desc_deserialized)); + ASSERT_EQ(desc, desc_deserialized); desc = FlightDescriptor::Path({"a", "b", "test.arrow"}); ASSERT_OK_AND_ASSIGN(desc_serialized, desc.SerializeToString()); ASSERT_OK_AND_ASSIGN(desc_deserialized, FlightDescriptor::Deserialize(desc_serialized)); - ASSERT_TRUE(desc.Equals(desc_deserialized)); + ASSERT_EQ(desc, desc_deserialized); FlightInfo::Data data; std::shared_ptr schema = @@ -114,10 +152,17 @@ TEST(FlightTypes, RoundTripTypes) { ASSERT_OK_AND_ASSIGN(std::string info_serialized, info->SerializeToString()); ASSERT_OK_AND_ASSIGN(std::unique_ptr info_deserialized, FlightInfo::Deserialize(info_serialized)); - ASSERT_TRUE(info->descriptor().Equals(info_deserialized->descriptor())); + ASSERT_EQ(info->descriptor(), info_deserialized->descriptor()); ASSERT_EQ(info->endpoints(), info_deserialized->endpoints()); ASSERT_EQ(info->total_records(), info_deserialized->total_records()); ASSERT_EQ(info->total_bytes(), info_deserialized->total_bytes()); + + FlightEndpoint flight_endpoint{ticket, {location1, location2}}; + ASSERT_OK_AND_ASSIGN(std::string flight_endpoint_serialized, + flight_endpoint.SerializeToString()); + ASSERT_OK_AND_ASSIGN(FlightEndpoint flight_endpoint_deserialized, + FlightEndpoint::Deserialize(flight_endpoint_serialized)); + ASSERT_EQ(flight_endpoint, flight_endpoint_deserialized); } TEST(FlightTypes, RoundtripStatus) { diff --git a/cpp/src/arrow/flight/serialization_internal.h b/cpp/src/arrow/flight/serialization_internal.h index c27bc79b315e1..0e1d7a6d843d9 100644 --- a/cpp/src/arrow/flight/serialization_internal.h +++ b/cpp/src/arrow/flight/serialization_internal.h @@ -60,6 +60,7 @@ Status FromProto(const pb::SchemaResult& pb_result, std::string* result); Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* info); Status ToProto(const FlightDescriptor& descr, pb::FlightDescriptor* pb_descr); +Status ToProto(const FlightEndpoint& endpoint, pb::FlightEndpoint* pb_endpoint); Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info); Status ToProto(const ActionType& type, pb::ActionType* pb_type); Status ToProto(const Action& action, pb::Action* pb_action); diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 6e80f40cfbf38..a505e6d6e1ecf 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -162,13 +162,42 @@ Status SchemaResult::GetSchema(ipc::DictionaryMemo* dictionary_memo, return GetSchema(dictionary_memo).Value(out); } +bool SchemaResult::Equals(const SchemaResult& other) const { + return raw_schema_ == other.raw_schema_; +} + +arrow::Result SchemaResult::SerializeToString() const { + pb::SchemaResult pb_schema_result; + RETURN_NOT_OK(internal::ToProto(*this, &pb_schema_result)); + + std::string out; + if (!pb_schema_result.SerializeToString(&out)) { + return Status::IOError("Serialized SchemaResult exceeded 2 GiB limit"); + } + return out; +} + +arrow::Result SchemaResult::Deserialize( + arrow::util::string_view serialized) { + pb::SchemaResult pb_schema_result; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized SchemaResult size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_schema_result.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid SchemaResult"); + } + return SchemaResult{pb_schema_result.schema()}; +} + arrow::Result FlightDescriptor::SerializeToString() const { pb::FlightDescriptor pb_descriptor; RETURN_NOT_OK(internal::ToProto(*this, &pb_descriptor)); std::string out; if (!pb_descriptor.SerializeToString(&out)) { - return Status::IOError("Serialized descriptor exceeded 2 GiB limit"); + return Status::IOError("Serialized FlightDescriptor exceeded 2 GiB limit"); } return out; } @@ -186,7 +215,7 @@ arrow::Result FlightDescriptor::Deserialize( google::protobuf::io::ArrayInputStream input(serialized.data(), static_cast(serialized.size())); if (!pb_descriptor.ParseFromZeroCopyStream(&input)) { - return Status::Invalid("Not a valid descriptor"); + return Status::Invalid("Not a valid FlightDescriptor"); } FlightDescriptor out; RETURN_NOT_OK(internal::FromProto(pb_descriptor, &out)); @@ -206,7 +235,7 @@ arrow::Result Ticket::SerializeToString() const { std::string out; if (!pb_ticket.SerializeToString(&out)) { - return Status::IOError("Serialized ticket exceeded 2 GiB limit"); + return Status::IOError("Serialized Ticket exceeded 2 GiB limit"); } return out; } @@ -223,7 +252,7 @@ arrow::Result Ticket::Deserialize(arrow::util::string_view serialized) { google::protobuf::io::ArrayInputStream input(serialized.data(), static_cast(serialized.size())); if (!pb_ticket.ParseFromZeroCopyStream(&input)) { - return Status::Invalid("Not a valid ticket"); + return Status::Invalid("Not a valid Ticket"); } Ticket out; RETURN_NOT_OK(internal::FromProto(pb_ticket, &out)); @@ -370,10 +399,154 @@ bool FlightEndpoint::Equals(const FlightEndpoint& other) const { return ticket == other.ticket && locations == other.locations; } +arrow::Result FlightEndpoint::SerializeToString() const { + pb::FlightEndpoint pb_flight_endpoint; + RETURN_NOT_OK(internal::ToProto(*this, &pb_flight_endpoint)); + + std::string out; + if (!pb_flight_endpoint.SerializeToString(&out)) { + return Status::IOError("Serialized FlightEndpoint exceeded 2 GiB limit"); + } + return out; +} + +arrow::Result FlightEndpoint::Deserialize( + arrow::util::string_view serialized) { + pb::FlightEndpoint pb_flight_endpoint; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized FlightEndpoint size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_flight_endpoint.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid FlightEndpoint"); + } + FlightEndpoint out; + RETURN_NOT_OK(internal::FromProto(pb_flight_endpoint, &out)); + return out; +} + bool ActionType::Equals(const ActionType& other) const { return type == other.type && description == other.description; } +arrow::Result ActionType::SerializeToString() const { + pb::ActionType pb_action_type; + RETURN_NOT_OK(internal::ToProto(*this, &pb_action_type)); + + std::string out; + if (!pb_action_type.SerializeToString(&out)) { + return Status::IOError("Serialized ActionType exceeded 2 GiB limit"); + } + return out; +} + +arrow::Result ActionType::Deserialize(arrow::util::string_view serialized) { + pb::ActionType pb_action_type; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized ActionType size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_action_type.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid ActionType"); + } + ActionType out; + RETURN_NOT_OK(internal::FromProto(pb_action_type, &out)); + return out; +} + +bool Criteria::Equals(const Criteria& other) const { + return expression == other.expression; +} + +arrow::Result Criteria::SerializeToString() const { + pb::Criteria pb_criteria; + RETURN_NOT_OK(internal::ToProto(*this, &pb_criteria)); + + std::string out; + if (!pb_criteria.SerializeToString(&out)) { + return Status::IOError("Serialized Criteria exceeded 2 GiB limit"); + } + return out; +} + +arrow::Result Criteria::Deserialize(arrow::util::string_view serialized) { + pb::Criteria pb_criteria; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized Criteria size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_criteria.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid Criteria"); + } + Criteria out; + RETURN_NOT_OK(internal::FromProto(pb_criteria, &out)); + return out; +} + +bool Action::Equals(const Action& other) const { + return (type == other.type) && + ((body == other.body) || (body && other.body && body->Equals(*other.body))); +} + +arrow::Result Action::SerializeToString() const { + pb::Action pb_action; + RETURN_NOT_OK(internal::ToProto(*this, &pb_action)); + + std::string out; + if (!pb_action.SerializeToString(&out)) { + return Status::IOError("Serialized Action exceeded 2 GiB limit"); + } + return out; +} + +arrow::Result Action::Deserialize(arrow::util::string_view serialized) { + pb::Action pb_action; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized Action size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_action.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid Action"); + } + Action out; + RETURN_NOT_OK(internal::FromProto(pb_action, &out)); + return out; +} + +bool Result::Equals(const Result& other) const { + return (body == other.body) || (body && other.body && body->Equals(*other.body)); +} + +arrow::Result Result::SerializeToString() const { + pb::Result pb_result; + RETURN_NOT_OK(internal::ToProto(*this, &pb_result)); + + std::string out; + if (!pb_result.SerializeToString(&out)) { + return Status::IOError("Serialized Result exceeded 2 GiB limit"); + } + return out; +} + +arrow::Result Result::Deserialize(arrow::util::string_view serialized) { + pb::Result pb_result; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized Result size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_result.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid Result"); + } + Result out; + RETURN_NOT_OK(internal::FromProto(pb_result, &out)); + return out; +} + Status ResultStream::Next(std::unique_ptr* info) { return Next().Value(info); } Status MetadataRecordBatchReader::Next(FlightStreamChunk* next) { @@ -468,6 +641,10 @@ arrow::Result> SimpleResultStream::Next() { return std::unique_ptr(new Result(std::move(results_[position_++]))); } +bool BasicAuth::Equals(const BasicAuth& other) const { + return (username == other.username) && (password == other.password); +} + arrow::Result BasicAuth::Deserialize(arrow::util::string_view serialized) { pb::BasicAuth pb_result; if (serialized.size() > static_cast(std::numeric_limits::max())) { diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 2ec24ff586851..ae9867e44a1f5 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -148,12 +148,33 @@ struct ARROW_FLIGHT_EXPORT ActionType { friend bool operator!=(const ActionType& left, const ActionType& right) { return !(left == right); } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(arrow::util::string_view serialized); }; /// \brief Opaque selection criteria for ListFlights RPC struct ARROW_FLIGHT_EXPORT Criteria { /// Opaque criteria expression, dependent on server implementation std::string expression; + + bool Equals(const Criteria& other) const; + + friend bool operator==(const Criteria& left, const Criteria& right) { + return left.Equals(right); + } + friend bool operator!=(const Criteria& left, const Criteria& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(arrow::util::string_view serialized); }; /// \brief An action to perform with the DoAction RPC @@ -163,11 +184,41 @@ struct ARROW_FLIGHT_EXPORT Action { /// The action content as a Buffer std::shared_ptr body; + + bool Equals(const Action& other) const; + + friend bool operator==(const Action& left, const Action& right) { + return left.Equals(right); + } + friend bool operator!=(const Action& left, const Action& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(arrow::util::string_view serialized); }; /// \brief Opaque result returned after executing an action struct ARROW_FLIGHT_EXPORT Result { std::shared_ptr body; + + bool Equals(const Result& other) const; + + friend bool operator==(const Result& left, const Result& right) { + return left.Equals(right); + } + friend bool operator!=(const Result& left, const Result& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(arrow::util::string_view serialized); }; /// \brief message for simple auth @@ -175,6 +226,15 @@ struct ARROW_FLIGHT_EXPORT BasicAuth { std::string username; std::string password; + bool Equals(const BasicAuth& other) const; + + friend bool operator==(const BasicAuth& left, const BasicAuth& right) { + return left.Equals(right); + } + friend bool operator!=(const BasicAuth& left, const BasicAuth& right) { + return !(left == right); + } + /// \brief Deserialize this message from its wire-format representation. static arrow::Result Deserialize(arrow::util::string_view serialized); /// \brief Serialize this message to its wire-format representation. @@ -377,6 +437,12 @@ struct ARROW_FLIGHT_EXPORT FlightEndpoint { friend bool operator!=(const FlightEndpoint& left, const FlightEndpoint& right) { return !(left == right); } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(arrow::util::string_view serialized); }; /// \brief Staging data structure for messages about to be put on the wire @@ -394,6 +460,7 @@ struct ARROW_FLIGHT_EXPORT FlightPayload { /// \brief Schema result returned after a schema request RPC struct ARROW_FLIGHT_EXPORT SchemaResult { public: + SchemaResult() = default; explicit SchemaResult(std::string schema) : raw_schema_(std::move(schema)) {} /// \brief Factory method to construct a SchemaResult. @@ -412,6 +479,21 @@ struct ARROW_FLIGHT_EXPORT SchemaResult { const std::string& serialized_schema() const { return raw_schema_; } + bool Equals(const SchemaResult& other) const; + + friend bool operator==(const SchemaResult& left, const SchemaResult& right) { + return left.Equals(right); + } + friend bool operator!=(const SchemaResult& left, const SchemaResult& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(arrow::util::string_view serialized); + private: std::string raw_schema_; }; diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 2ad3f7128c414..16e4aad5a00c5 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -289,6 +289,31 @@ cdef class Action(_Weakrefable): type(action))) return ( action).action + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.action.SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef Action action = Action.__new__(Action) + action.action = GetResultValue( + CAction.Deserialize(tobytes(serialized))) + return action + + def __eq__(self, Action other): + return self.action == other.action + _ActionType = collections.namedtuple('_ActionType', ['type', 'description']) @@ -327,6 +352,31 @@ cdef class Result(_Weakrefable): """Get the Buffer containing the result.""" return pyarrow_wrap_buffer(self.result.get().body) + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.result.get().SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef Result result = Result.__new__(Result) + result.result.reset(new CFlightResult(GetResultValue( + CFlightResult.Deserialize(tobytes(serialized))))) + return result + + def __eq__(self, Result other): + return deref(self.result.get()) == deref(other.result.get()) + cdef class BasicAuth(_Weakrefable): """A container for basic auth.""" @@ -360,13 +410,16 @@ cdef class BasicAuth(_Weakrefable): @staticmethod def deserialize(serialized): auth = BasicAuth() - check_flight_status( - CBasicAuth.Deserialize(serialized).Value(auth.basic_auth.get())) + auth.basic_auth.reset(new CBasicAuth(GetResultValue( + CBasicAuth.Deserialize(tobytes(serialized))))) return auth def serialize(self): return GetResultValue(self.basic_auth.get().SerializeToString()) + def __eq__(self, BasicAuth other): + return deref(self.basic_auth.get()) == deref(other.basic_auth.get()) + class DescriptorType(enum.Enum): """ @@ -686,6 +739,28 @@ cdef class FlightEndpoint(_Weakrefable): return [Location.wrap(location) for location in self.endpoint.locations] + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.endpoint.SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef FlightEndpoint endpoint = FlightEndpoint.__new__(FlightEndpoint) + endpoint.endpoint = GetResultValue( + CFlightEndpoint.Deserialize(tobytes(serialized))) + return endpoint + def __repr__(self): return "".format( self.ticket, self.locations) @@ -721,6 +796,31 @@ cdef class SchemaResult(_Weakrefable): check_flight_status(self.result.get().GetSchema(&dummy_memo).Value(&schema)) return pyarrow_wrap_schema(schema) + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.result.get().SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef SchemaResult result = SchemaResult.__new__(SchemaResult) + result.result.reset(new CSchemaResult(GetResultValue( + CSchemaResult.Deserialize(tobytes(serialized))))) + return result + + def __eq__(self, SchemaResult other): + return deref(self.result.get()) == deref(other.result.get()) + cdef class FlightInfo(_Weakrefable): """A description of a Flight stream.""" diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 3698292b5a03a..3b9ac54fe9dee 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -28,15 +28,30 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CActionType" arrow::flight::ActionType": c_string type c_string description + bint operator==(CActionType) + CResult[c_string] SerializeToString() + + @staticmethod + CResult[CActionType] Deserialize(const c_string& serialized) cdef cppclass CAction" arrow::flight::Action": c_string type shared_ptr[CBuffer] body + bint operator==(CAction) + CResult[c_string] SerializeToString() + + @staticmethod + CResult[CAction] Deserialize(const c_string& serialized) cdef cppclass CFlightResult" arrow::flight::Result": CFlightResult() CFlightResult(CFlightResult) shared_ptr[CBuffer] body + bint operator==(CFlightResult) + CResult[c_string] SerializeToString() + + @staticmethod + CResult[CFlightResult] Deserialize(const c_string& serialized) cdef cppclass CBasicAuth" arrow::flight::BasicAuth": CBasicAuth() @@ -44,7 +59,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CBasicAuth(CBasicAuth) c_string username c_string password - + bint operator==(CBasicAuth) CResult[c_string] SerializeToString() @staticmethod @@ -68,11 +83,11 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CDescriptorType type c_string cmd vector[c_string] path + bint operator==(CFlightDescriptor) CResult[c_string] SerializeToString() @staticmethod CResult[CFlightDescriptor] Deserialize(const c_string& serialized) - bint operator==(CFlightDescriptor) cdef cppclass CTicket" arrow::flight::Ticket": CTicket() @@ -86,6 +101,11 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CCriteria" arrow::flight::Criteria": CCriteria() c_string expression + bint operator==(CCriteria) + CResult[c_string] SerializeToString() + + @staticmethod + CResult[CCriteria] Deserialize(const c_string& serialized) cdef cppclass CLocation" arrow::flight::Location": CLocation() @@ -111,6 +131,10 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: vector[CLocation] locations bint operator==(CFlightEndpoint) + CResult[c_string] SerializeToString() + + @staticmethod + CResult[CFlightEndpoint] Deserialize(const c_string& serialized) cdef cppclass CFlightInfo" arrow::flight::FlightInfo": CFlightInfo(CFlightInfo info) @@ -126,8 +150,14 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: const c_string& serialized) cdef cppclass CSchemaResult" arrow::flight::SchemaResult": + CSchemaResult() CSchemaResult(CSchemaResult result) CResult[shared_ptr[CSchema]] GetSchema(CDictionaryMemo* memo) + bint operator==(CSchemaResult) + CResult[c_string] SerializeToString() + + @staticmethod + CResult[CSchemaResult] Deserialize(const c_string& serialized) cdef cppclass CFlightListing" arrow::flight::FlightListing": CResult[unique_ptr[CFlightInfo]] Next() diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 905efa564b0c7..72d1fa5ec3359 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -1560,9 +1560,22 @@ def block_read(): def test_roundtrip_types(): """Make sure serializable types round-trip.""" + action = flight.Action("action1", b"action1-body") + assert action == flight.Action.deserialize(action.serialize()) + ticket = flight.Ticket("foo") assert ticket == flight.Ticket.deserialize(ticket.serialize()) + result = flight.Result(b"result1") + assert result == flight.Result.deserialize(result.serialize()) + + basic_auth = flight.BasicAuth("username1", "password1") + assert basic_auth == flight.BasicAuth.deserialize(basic_auth.serialize()) + + schema_result = flight.SchemaResult(pa.schema([('a', pa.int32())])) + assert schema_result == flight.SchemaResult.deserialize( + schema_result.serialize()) + desc = flight.FlightDescriptor.for_command("test") assert desc == flight.FlightDescriptor.deserialize(desc.serialize()) @@ -1589,6 +1602,12 @@ def test_roundtrip_types(): assert info.total_records == info2.total_records assert info.endpoints == info2.endpoints + endpoint = flight.FlightEndpoint( + ticket, + ['grpc://test', flight.Location.for_grpc_tcp('localhost', 5005)] + ) + assert endpoint == flight.FlightEndpoint.deserialize(endpoint.serialize()) + def test_roundtrip_errors(): """Ensure that Flight errors propagate from server to client."""