diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlService.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlService.java index 5c47941b291415..22f9466e7db4f7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlService.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlService.java @@ -20,11 +20,13 @@ import org.apache.doris.common.Config; import org.apache.doris.service.FrontendOptions; import org.apache.doris.service.arrowflight.auth2.FlightBearerTokenAuthenticator; +import org.apache.doris.service.arrowflight.auth2.FlightRemoteIpServerStreamTracer; import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager; import org.apache.doris.service.arrowflight.sessions.FlightSessionsWithTokenManager; import org.apache.doris.service.arrowflight.tokens.FlightTokenManager; import org.apache.doris.service.arrowflight.tokens.FlightTokenManagerImpl; +import io.grpc.ServerBuilder; import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.Location; import org.apache.arrow.memory.BufferAllocator; @@ -33,12 +35,14 @@ import org.apache.logging.log4j.Logger; import java.io.IOException; +import java.util.function.Consumer; /** * flight sql protocol implementation based on nio. */ public class DorisFlightSqlService { private static final Logger LOG = LogManager.getLogger(DorisFlightSqlService.class); + private static final String GRPC_BUILDER_CONSUMER = "grpc.builderConsumer"; private final FlightServer flightServer; private final FlightTokenManager flightTokenManager; private final FlightSessionsManager flightSessionsManager; @@ -56,6 +60,8 @@ public DorisFlightSqlService(int port) { DorisFlightSqlProducer producer = new DorisFlightSqlProducer( Location.forGrpcInsecure(FrontendOptions.getLocalHostAddress(), port), flightSessionsManager); flightServer = FlightServer.builder(allocator, Location.forGrpcInsecure("0.0.0.0", port), producer) + .transportHint(GRPC_BUILDER_CONSUMER, (Consumer>) builder -> + builder.addStreamTracerFactory(new FlightRemoteIpServerStreamTracer.Factory())) .headerAuthenticator(new FlightBearerTokenAuthenticator(flightTokenManager)).build(); LOG.info("Arrow Flight SQL service is created, port: {}, arrow_flight_max_connections: {}," + "arrow_flight_token_alive_time_second: {}", port, Config.arrow_flight_max_connections, diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightCredentialValidator.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightCredentialValidator.java index 6676e8526ef0f5..259d5083446e24 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightCredentialValidator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightCredentialValidator.java @@ -48,8 +48,7 @@ public FlightCredentialValidator(FlightTokenManager flightTokenManager) { */ @Override public AuthResult validate(String username, String password) { - // TODO Add ClientAddress information while creating a Token - String remoteIp = "0.0.0.0"; + String remoteIp = FlightRemoteIpServerStreamTracer.getRemoteIp(); FlightAuthResult flightAuthResult = FlightAuthUtils.authenticateCredentials(username, password, remoteIp, LOG); return getAuthResultWithBearerToken(flightAuthResult); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightRemoteIpServerStreamTracer.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightRemoteIpServerStreamTracer.java new file mode 100644 index 00000000000000..5f5deee49bca4a --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/auth2/FlightRemoteIpServerStreamTracer.java @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.service.arrowflight.auth2; + +import io.grpc.Attributes; +import io.grpc.Context; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.ServerStreamTracer; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; + +/** + * Captures the gRPC peer address before Arrow Flight header authentication runs. + * Arrow registers header authentication ahead of user interceptors, so use ServerStreamTracer to + * seed the remote IP into the gRPC Context for Basic credential validation. + */ +public class FlightRemoteIpServerStreamTracer extends ServerStreamTracer { + static final String UNKNOWN_REMOTE_IP = "0.0.0.0"; + private static final Context.Key REMOTE_IP_CONTEXT_KEY = + Context.key("doris.arrow.flight.remote_ip"); + + @Override + public Context filterContext(Context context) { + return context.withValue(REMOTE_IP_CONTEXT_KEY, new RemoteIpHolder()); + } + + @Override + public void serverCallStarted(ServerCallInfo callInfo) { + RemoteIpHolder holder = REMOTE_IP_CONTEXT_KEY.get(); + if (holder == null) { + return; + } + + Attributes attributes = callInfo.getAttributes(); + SocketAddress remoteAddress = attributes == null ? null : attributes.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + holder.setRemoteIp(extractRemoteIp(remoteAddress)); + } + + public static String getRemoteIp() { + RemoteIpHolder holder = REMOTE_IP_CONTEXT_KEY.get(); + if (holder == null) { + return UNKNOWN_REMOTE_IP; + } + return holder.getRemoteIp(); + } + + static String extractRemoteIp(SocketAddress remoteAddress) { + if (!(remoteAddress instanceof InetSocketAddress)) { + return UNKNOWN_REMOTE_IP; + } + + InetSocketAddress inetSocketAddress = (InetSocketAddress) remoteAddress; + InetAddress address = inetSocketAddress.getAddress(); + if (address != null && isNotEmpty(address.getHostAddress())) { + return address.getHostAddress(); + } + if (isNotEmpty(inetSocketAddress.getHostString())) { + return inetSocketAddress.getHostString(); + } + return UNKNOWN_REMOTE_IP; + } + + private static boolean isNotEmpty(String value) { + return value != null && !value.isEmpty(); + } + + public static class Factory extends ServerStreamTracer.Factory { + @Override + public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { + return new FlightRemoteIpServerStreamTracer(); + } + } + + private static class RemoteIpHolder { + private volatile String remoteIp = UNKNOWN_REMOTE_IP; + + String getRemoteIp() { + return remoteIp; + } + + void setRemoteIp(String remoteIp) { + this.remoteIp = isNotEmpty(remoteIp) ? remoteIp : UNKNOWN_REMOTE_IP; + } + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/service/arrowflight/auth2/FlightRemoteIpServerStreamTracerTest.java b/fe/fe-core/src/test/java/org/apache/doris/service/arrowflight/auth2/FlightRemoteIpServerStreamTracerTest.java new file mode 100644 index 00000000000000..250cd9f26ecb06 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/service/arrowflight/auth2/FlightRemoteIpServerStreamTracerTest.java @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.service.arrowflight.auth2; + +import io.grpc.Attributes; +import io.grpc.Context; +import io.grpc.Grpc; +import io.grpc.MethodDescriptor; +import io.grpc.ServerStreamTracer; +import org.junit.Assert; +import org.junit.Test; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; + +public class FlightRemoteIpServerStreamTracerTest { + + @Test + public void testGetRemoteIpFromServerCallAttributes() { + FlightRemoteIpServerStreamTracer tracer = new FlightRemoteIpServerStreamTracer(); + Context context = tracer.filterContext(Context.current()); + Context previous = context.attach(); + try { + tracer.serverCallStarted(new TestServerCallInfo(new InetSocketAddress("10.26.20.3", 12345))); + + Assert.assertEquals("10.26.20.3", FlightRemoteIpServerStreamTracer.getRemoteIp()); + } finally { + context.detach(previous); + } + } + + @Test + public void testFallbackRemoteIpWithoutServerCallAttributes() { + FlightRemoteIpServerStreamTracer tracer = new FlightRemoteIpServerStreamTracer(); + Context context = tracer.filterContext(Context.current()); + Context previous = context.attach(); + try { + tracer.serverCallStarted(new TestServerCallInfo(null)); + + Assert.assertEquals("0.0.0.0", FlightRemoteIpServerStreamTracer.getRemoteIp()); + } finally { + context.detach(previous); + } + } + + @Test + public void testFallbackRemoteIpWithoutFlightContext() { + Assert.assertEquals("0.0.0.0", FlightRemoteIpServerStreamTracer.getRemoteIp()); + } + + private static class TestServerCallInfo extends ServerStreamTracer.ServerCallInfo { + private final Attributes attributes; + + TestServerCallInfo(SocketAddress remoteAddress) { + Attributes.Builder builder = Attributes.newBuilder(); + if (remoteAddress != null) { + builder.set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress); + } + this.attributes = builder.build(); + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return null; + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + @Override + public String getAuthority() { + return null; + } + } +} diff --git a/regression-test/suites/arrow_flight_sql_p0/test_auth_remote_ip.groovy b/regression-test/suites/arrow_flight_sql_p0/test_auth_remote_ip.groovy new file mode 100644 index 00000000000000..1b58a31fb63810 --- /dev/null +++ b/regression-test/suites/arrow_flight_sql_p0/test_auth_remote_ip.groovy @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import java.sql.Connection +import java.sql.DriverManager + +suite("test_auth_remote_ip", "arrow_flight_sql") { + String user = "flight_auth_remote_ip_user" + String password = "flight_auth_remote_ip_pwd" + String wrongPassword = "wrong_flight_auth_remote_ip_pwd" + List remoteIpHosts = [ + "127.%", + "10.%", + "172.%", + "192.168.%", + "::1", + "0:0:0:0:0:0:0:1" + ] + List allHosts = ["0.0.0.0"] + remoteIpHosts + + try { + allHosts.each { host -> + jdbc_sql """DROP USER IF EXISTS '${user}'@'${host}'""" + } + + jdbc_sql """CREATE USER '${user}'@'0.0.0.0' IDENTIFIED BY '${wrongPassword}'""" + String validComputeGroup = null + if (isCloudMode()) { + def computeGroups = sql """SHOW COMPUTE GROUPS""" + assertTrue(!computeGroups.isEmpty()) + validComputeGroup = computeGroups[0][0] + } + remoteIpHosts.each { host -> + jdbc_sql """CREATE USER '${user}'@'${host}' IDENTIFIED BY '${password}'""" + jdbc_sql """GRANT SELECT_PRIV ON *.* TO '${user}'@'${host}'""" + if (validComputeGroup != null) { + jdbc_sql """GRANT USAGE_PRIV ON COMPUTE GROUP '${validComputeGroup}' TO '${user}'@'${host}'""" + } + } + + Class.forName("org.apache.arrow.driver.jdbc.ArrowFlightJdbcDriver") + String arrowFlightSqlHost = context.config.otherConfigs.get("extArrowFlightSqlHost") + String arrowFlightSqlPort = context.config.otherConfigs.get("extArrowFlightSqlPort") + String arrowFlightSqlUrl = "jdbc:arrow-flight-sql://${arrowFlightSqlHost}:${arrowFlightSqlPort}" + + "/?useServerPrepStmts=false&useSSL=false&useEncryption=false" + + Connection conn = DriverManager.getConnection(arrowFlightSqlUrl, user, password) + try { + List> result = sql_impl(conn, "SELECT 1") + assertEquals(1, result.size()) + assertEquals(1, (result[0][0] as Number).intValue()) + } finally { + conn.close() + } + } finally { + allHosts.each { host -> + try { + jdbc_sql """DROP USER IF EXISTS '${user}'@'${host}'""" + } catch (Throwable t) { + logger.warn("Failed to drop test user '${user}'@'${host}'", t) + } + } + } +}