From 7da0bfff0a710e5ec7b9e44f5d060c781e233d92 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Wed, 23 Aug 2023 10:12:49 +0900 Subject: [PATCH] GH-37257: [Ruby][FlightSQL] Use the same options for auto prepared statement close request (#37258) ### Rationale for this change If we don't pass the same options for auto prepared statement close request, the close request may be failed. For example, it's caused when authentication information in the options. ### What changes are included in this PR? `ArrowFlightSQL::Client#prepare` change is the main change to use the given options for auto prepared statement close request too. Other changes (`gaflight_server_call_context_foreach_incoming_header()` and related changes) are for testing the above change. ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes. * Closes: #37257 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- c_glib/arrow-flight-glib/client.cpp | 2 +- c_glib/arrow-flight-glib/server.cpp | 63 ++++++++++++++++--- c_glib/arrow-flight-glib/server.h | 19 +++--- .../arrow-flight-glib-docs.xml | 4 ++ .../lib/arrow-flight-sql/client.rb | 6 +- .../test/helper/server.rb | 7 ++- ruby/red-arrow-flight-sql/test/test-client.rb | 6 +- .../lib/arrow-flight/loader.rb | 1 + .../lib/arrow-flight/server-call-context.rb | 31 +++++++++ 9 files changed, 115 insertions(+), 24 deletions(-) create mode 100644 ruby/red-arrow-flight/lib/arrow-flight/server-call-context.rb diff --git a/c_glib/arrow-flight-glib/client.cpp b/c_glib/arrow-flight-glib/client.cpp index 65bc2d56a51f6..60dec29dbbdfb 100644 --- a/c_glib/arrow-flight-glib/client.cpp +++ b/c_glib/arrow-flight-glib/client.cpp @@ -148,7 +148,7 @@ gaflight_call_options_clear_headers(GAFlightCallOptions *options) * @func: (scope call): The user's callback function. * @user_data: (closure): Data for @func. * - * Iterates over all header in the options. + * Iterates over all headers in the options. * * Since: 9.0.0 */ diff --git a/c_glib/arrow-flight-glib/server.cpp b/c_glib/arrow-flight-glib/server.cpp index 53ecbf6e08db5..9deb1623b16cd 100644 --- a/c_glib/arrow-flight-glib/server.cpp +++ b/c_glib/arrow-flight-glib/server.cpp @@ -293,9 +293,11 @@ gaflight_message_reader_get_descriptor(GAFlightMessageReader *reader) } -typedef struct GAFlightServerCallContextPrivate_ { +struct GAFlightServerCallContextPrivate { arrow::flight::ServerCallContext *call_context; -} GAFlightServerCallContextPrivate; + std::string current_incoming_header_key; + std::string current_incoming_header_value; +}; enum { PROP_CALL_CONTEXT = 1, @@ -310,6 +312,15 @@ G_DEFINE_TYPE_WITH_PRIVATE(GAFlightServerCallContext, gaflight_server_call_context_get_instance_private( \ GAFLIGHT_SERVER_CALL_CONTEXT(obj))) +static void +gaflight_server_call_context_finalize(GObject *object) +{ + auto priv = GAFLIGHT_SERVER_CALL_CONTEXT_GET_PRIVATE(object); + priv->current_incoming_header_key.~basic_string(); + priv->current_incoming_header_value.~basic_string(); + G_OBJECT_CLASS(gaflight_server_call_context_parent_class)->finalize(object); +} + static void gaflight_server_call_context_set_property(GObject *object, guint prop_id, @@ -333,6 +344,9 @@ gaflight_server_call_context_set_property(GObject *object, static void gaflight_server_call_context_init(GAFlightServerCallContext *object) { + auto priv = GAFLIGHT_SERVER_CALL_CONTEXT_GET_PRIVATE(object); + new(&(priv->current_incoming_header_key)) std::string; + new(&(priv->current_incoming_header_value)) std::string; } static void @@ -340,6 +354,7 @@ gaflight_server_call_context_class_init(GAFlightServerCallContextClass *klass) { auto gobject_class = G_OBJECT_CLASS(klass); + gobject_class->finalize = gaflight_server_call_context_finalize; gobject_class->set_property = gaflight_server_call_context_set_property; GParamSpec *spec; @@ -351,6 +366,33 @@ gaflight_server_call_context_class_init(GAFlightServerCallContextClass *klass) g_object_class_install_property(gobject_class, PROP_CALL_CONTEXT, spec); } +/** + * gaflight_server_call_context_foreach_incoming_header: + * @context: A #GAFlightServerCallContext. + * @func: (scope call): The user's callback function. + * @user_data: (closure): Data for @func. + * + * Iterates over all incoming headers. + * + * Since: 14.0.0 + */ +void +gaflight_server_call_context_foreach_incoming_header( + GAFlightServerCallContext *context, + GAFlightHeaderFunc func, + gpointer user_data) +{ + auto priv = GAFLIGHT_SERVER_CALL_CONTEXT_GET_PRIVATE(context); + auto flight_context = gaflight_server_call_context_get_raw(context); + for (const auto &header : flight_context->incoming_headers()) { + priv->current_incoming_header_key = std::string(header.first); + priv->current_incoming_header_value = std::string(header.second); + func(priv->current_incoming_header_key.c_str(), + priv->current_incoming_header_value.c_str(), + user_data); + } +} + struct GAFlightServerAuthSenderPrivate { arrow::flight::ServerAuthSender *sender; @@ -630,9 +672,8 @@ namespace gaflight { auto klass = GAFLIGHT_SERVER_CUSTOM_AUTH_HANDLER_GET_CLASS(handler_); auto gacontext = gaflight_server_call_context_new_raw(&context); auto gtoken = g_bytes_new_static(token.data(), token.size()); - GBytes *gpeer_identity = nullptr; GError *error = nullptr; - klass->is_valid(handler_, gacontext, gtoken, &gpeer_identity, &error); + auto gpeer_identity = klass->is_valid(handler_, gacontext, gtoken, &error); g_bytes_unref(gtoken); g_object_unref(gacontext); if (gpeer_identity) { @@ -718,20 +759,20 @@ gaflight_server_custom_auth_handler_authenticate( * @context: A #GAFlightServerCallContext. * @token: The client token. May be the empty string if the client does not * provide a token. - * @peer_identity: (out): The identity of the peer, if this authentication - * method supports it. * @error: (nullable): Return location for a #GError or %NULL. * * Validates a per-call client token. * + * Returns: (nullable) (transfer full): The identity of the peer, if + * this authentication method supports it. + * * Since: 12.0.0 */ -void +GBytes * gaflight_server_custom_auth_handler_is_valid( GAFlightServerCustomAuthHandler *handler, GAFlightServerCallContext *context, GBytes *token, - GBytes **peer_identity, GError **error) { auto flight_handler = @@ -749,8 +790,10 @@ gaflight_server_custom_auth_handler_is_valid( status, "[flight-server-custom-auth-handler]" "[is-valid]")) { - *peer_identity = g_bytes_new(flight_peer_identity.data(), - flight_peer_identity.size()); + return g_bytes_new(flight_peer_identity.data(), + flight_peer_identity.size()); + } else { + return nullptr; } } diff --git a/c_glib/arrow-flight-glib/server.h b/c_glib/arrow-flight-glib/server.h index 7fa0dcf878000..77ecf36fd5221 100644 --- a/c_glib/arrow-flight-glib/server.h +++ b/c_glib/arrow-flight-glib/server.h @@ -84,6 +84,13 @@ struct _GAFlightServerCallContextClass GObjectClass parent_class; }; +GARROW_AVAILABLE_IN_14_0 +void +gaflight_server_call_context_foreach_incoming_header( + GAFlightServerCallContext *context, + GAFlightHeaderFunc func, + gpointer user_data); + #define GAFLIGHT_TYPE_SERVER_AUTH_SENDER \ (gaflight_server_auth_sender_get_type()) @@ -158,11 +165,10 @@ struct _GAFlightServerCustomAuthHandlerClass GAFlightServerAuthSender *sender, GAFlightServerAuthReader *reader, GError **error); - void (*is_valid)(GAFlightServerCustomAuthHandler *handler, - GAFlightServerCallContext *context, - GBytes *token, - GBytes **peer_identity, - GError **error); + GBytes *(*is_valid)(GAFlightServerCustomAuthHandler *handler, + GAFlightServerCallContext *context, + GBytes *token, + GError **error); }; GARROW_AVAILABLE_IN_12_0 @@ -175,12 +181,11 @@ gaflight_server_custom_auth_handler_authenticate( GError **error); GARROW_AVAILABLE_IN_12_0 -void +GBytes * gaflight_server_custom_auth_handler_is_valid( GAFlightServerCustomAuthHandler *handler, GAFlightServerCallContext *context, GBytes *token, - GBytes **peer_identity, GError **error); diff --git a/c_glib/doc/arrow-flight-glib/arrow-flight-glib-docs.xml b/c_glib/doc/arrow-flight-glib/arrow-flight-glib-docs.xml index e078b1c037b5f..e58ff375c5964 100644 --- a/c_glib/doc/arrow-flight-glib/arrow-flight-glib-docs.xml +++ b/c_glib/doc/arrow-flight-glib/arrow-flight-glib-docs.xml @@ -55,6 +55,10 @@ Index of deprecated API + + Index of new symbols in 14.0.0 + + Index of new symbols in 12.0.0 diff --git a/ruby/red-arrow-flight-sql/lib/arrow-flight-sql/client.rb b/ruby/red-arrow-flight-sql/lib/arrow-flight-sql/client.rb index ff3169d5621b2..dc8815fced6e9 100644 --- a/ruby/red-arrow-flight-sql/lib/arrow-flight-sql/client.rb +++ b/ruby/red-arrow-flight-sql/lib/arrow-flight-sql/client.rb @@ -18,13 +18,13 @@ module ArrowFlightSQL class Client alias_method :prepare_raw, :prepare - def prepare(*args) - statement = prepare_raw(*args) + def prepare(query, options=nil) + statement = prepare_raw(query, options) if block_given? begin yield(statement) ensure - statement.close unless statement.closed? + statement.close(options) unless statement.closed? end else statement diff --git a/ruby/red-arrow-flight-sql/test/helper/server.rb b/ruby/red-arrow-flight-sql/test/helper/server.rb index f7c935fab91fa..6426e11f2984d 100644 --- a/ruby/red-arrow-flight-sql/test/helper/server.rb +++ b/ruby/red-arrow-flight-sql/test/helper/server.rb @@ -39,7 +39,7 @@ def virtual_do_do_get_statement(context, command) end def virtual_do_create_prepared_statement(context, request) - unless request.query == "INSERT INTO page_view_table VALUES (?, true)" + unless request.query == "INSERT INTO page_view_table VALUES ($1, true)" raise Arrow::Error::Invalid.new("invalid SQL") end result = ArrowFlightSQL::CreatePreparedStatementResult.new @@ -62,6 +62,11 @@ def virtual_do_close_prepared_statement(context, request) unless request.handle.to_s == "valid-handle" raise Arrow::Error::Invalid.new("invalid handle") end + access_key = context.incoming_headers.assoc("x-access-key") + unless access_key == ["x-access-key", "secret"] + message = "invalid access key: #{access_key.inspect}" + raise Arrow::Error::Invalid.new(message) + end end end end diff --git a/ruby/red-arrow-flight-sql/test/test-client.rb b/ruby/red-arrow-flight-sql/test/test-client.rb index 21554c1bdab84..1fff21da0193c 100644 --- a/ruby/red-arrow-flight-sql/test/test-client.rb +++ b/ruby/red-arrow-flight-sql/test/test-client.rb @@ -41,9 +41,11 @@ def test_execute end def test_prepare - insert_sql = "INSERT INTO page_view_table VALUES (?, true)" + insert_sql = "INSERT INTO page_view_table VALUES ($1, true)" block_called = false - @sql_client.prepare(insert_sql) do |statement| + options = ArrowFlight::CallOptions.new + options.add_header("x-access-key", "secret") + @sql_client.prepare(insert_sql, options) do |statement| block_called = true assert_equal([ Arrow::Schema.new(count: :uint64, private: :boolean), diff --git a/ruby/red-arrow-flight/lib/arrow-flight/loader.rb b/ruby/red-arrow-flight/lib/arrow-flight/loader.rb index 042e15769368e..e9e9e08f32dd9 100644 --- a/ruby/red-arrow-flight/lib/arrow-flight/loader.rb +++ b/ruby/red-arrow-flight/lib/arrow-flight/loader.rb @@ -34,6 +34,7 @@ def require_libraries require "arrow-flight/client-options" require "arrow-flight/location" require "arrow-flight/record-batch-reader" + require "arrow-flight/server-call-context" require "arrow-flight/server-options" require "arrow-flight/ticket" end diff --git a/ruby/red-arrow-flight/lib/arrow-flight/server-call-context.rb b/ruby/red-arrow-flight/lib/arrow-flight/server-call-context.rb new file mode 100644 index 0000000000000..7cbad1fd079df --- /dev/null +++ b/ruby/red-arrow-flight/lib/arrow-flight/server-call-context.rb @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module ArrowFlight + class ServerCallContext + def each_incoming_header + return to_enum(__method__) unless block_given? + foreach_incoming_header do |key, value| + yield(key, value) + end + end + + def incoming_headers + each_incoming_header.to_a + end + end +end