Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import org.apache.doris.common.Config;
import org.apache.doris.service.FrontendOptions;
import org.apache.doris.service.arrowflight.auth2.FlightBearerTokenAuthenticator;
import org.apache.doris.service.arrowflight.auth2.FlightRemoteIpServerStreamTracer;
import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager;
import org.apache.doris.service.arrowflight.sessions.FlightSessionsWithTokenManager;
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;
import org.apache.doris.service.arrowflight.tokens.FlightTokenManagerImpl;

import io.grpc.ServerBuilder;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.Location;
import org.apache.arrow.memory.BufferAllocator;
Expand All @@ -33,12 +35,14 @@
import org.apache.logging.log4j.Logger;

import java.io.IOException;
import java.util.function.Consumer;

/**
* flight sql protocol implementation based on nio.
*/
public class DorisFlightSqlService {
private static final Logger LOG = LogManager.getLogger(DorisFlightSqlService.class);
private static final String GRPC_BUILDER_CONSUMER = "grpc.builderConsumer";
private final FlightServer flightServer;
private final FlightTokenManager flightTokenManager;
private final FlightSessionsManager flightSessionsManager;
Expand All @@ -56,6 +60,8 @@ public DorisFlightSqlService(int port) {
DorisFlightSqlProducer producer = new DorisFlightSqlProducer(
Location.forGrpcInsecure(FrontendOptions.getLocalHostAddress(), port), flightSessionsManager);
flightServer = FlightServer.builder(allocator, Location.forGrpcInsecure("0.0.0.0", port), producer)
.transportHint(GRPC_BUILDER_CONSUMER, (Consumer<ServerBuilder<?>>) builder ->
builder.addStreamTracerFactory(new FlightRemoteIpServerStreamTracer.Factory()))
.headerAuthenticator(new FlightBearerTokenAuthenticator(flightTokenManager)).build();
LOG.info("Arrow Flight SQL service is created, port: {}, arrow_flight_max_connections: {},"
+ "arrow_flight_token_alive_time_second: {}", port, Config.arrow_flight_max_connections,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ public FlightCredentialValidator(FlightTokenManager flightTokenManager) {
*/
@Override
public AuthResult validate(String username, String password) {
// TODO Add ClientAddress information while creating a Token
String remoteIp = "0.0.0.0";
String remoteIp = FlightRemoteIpServerStreamTracer.getRemoteIp();
FlightAuthResult flightAuthResult = FlightAuthUtils.authenticateCredentials(username, password, remoteIp, LOG);
return getAuthResultWithBearerToken(flightAuthResult);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.service.arrowflight.auth2;

import io.grpc.Attributes;
import io.grpc.Context;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;

/**
* Captures the gRPC peer address before Arrow Flight header authentication runs.
* Arrow registers header authentication ahead of user interceptors, so use ServerStreamTracer to
* seed the remote IP into the gRPC Context for Basic credential validation.
*/
public class FlightRemoteIpServerStreamTracer extends ServerStreamTracer {
static final String UNKNOWN_REMOTE_IP = "0.0.0.0";
private static final Context.Key<RemoteIpHolder> REMOTE_IP_CONTEXT_KEY =
Context.key("doris.arrow.flight.remote_ip");

@Override
public Context filterContext(Context context) {
return context.withValue(REMOTE_IP_CONTEXT_KEY, new RemoteIpHolder());
}

@Override
public void serverCallStarted(ServerCallInfo<?, ?> callInfo) {
RemoteIpHolder holder = REMOTE_IP_CONTEXT_KEY.get();
if (holder == null) {
return;
}

Attributes attributes = callInfo.getAttributes();
SocketAddress remoteAddress = attributes == null ? null : attributes.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
holder.setRemoteIp(extractRemoteIp(remoteAddress));
}

public static String getRemoteIp() {
RemoteIpHolder holder = REMOTE_IP_CONTEXT_KEY.get();
if (holder == null) {
return UNKNOWN_REMOTE_IP;
}
return holder.getRemoteIp();
}

static String extractRemoteIp(SocketAddress remoteAddress) {
if (!(remoteAddress instanceof InetSocketAddress)) {
return UNKNOWN_REMOTE_IP;
}

InetSocketAddress inetSocketAddress = (InetSocketAddress) remoteAddress;
InetAddress address = inetSocketAddress.getAddress();
if (address != null && isNotEmpty(address.getHostAddress())) {
return address.getHostAddress();
}
if (isNotEmpty(inetSocketAddress.getHostString())) {
return inetSocketAddress.getHostString();
}
return UNKNOWN_REMOTE_IP;
}

private static boolean isNotEmpty(String value) {
return value != null && !value.isEmpty();
}

public static class Factory extends ServerStreamTracer.Factory {
@Override
public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) {
return new FlightRemoteIpServerStreamTracer();
}
}

private static class RemoteIpHolder {
private volatile String remoteIp = UNKNOWN_REMOTE_IP;

String getRemoteIp() {
return remoteIp;
}

void setRemoteIp(String remoteIp) {
this.remoteIp = isNotEmpty(remoteIp) ? remoteIp : UNKNOWN_REMOTE_IP;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.service.arrowflight.auth2;

import io.grpc.Attributes;
import io.grpc.Context;
import io.grpc.Grpc;
import io.grpc.MethodDescriptor;
import io.grpc.ServerStreamTracer;
import org.junit.Assert;
import org.junit.Test;

import java.net.InetSocketAddress;
import java.net.SocketAddress;

public class FlightRemoteIpServerStreamTracerTest {

@Test
public void testGetRemoteIpFromServerCallAttributes() {
FlightRemoteIpServerStreamTracer tracer = new FlightRemoteIpServerStreamTracer();
Context context = tracer.filterContext(Context.current());
Context previous = context.attach();
try {
tracer.serverCallStarted(new TestServerCallInfo(new InetSocketAddress("10.26.20.3", 12345)));

Assert.assertEquals("10.26.20.3", FlightRemoteIpServerStreamTracer.getRemoteIp());
} finally {
context.detach(previous);
}
}

@Test
public void testFallbackRemoteIpWithoutServerCallAttributes() {
FlightRemoteIpServerStreamTracer tracer = new FlightRemoteIpServerStreamTracer();
Context context = tracer.filterContext(Context.current());
Context previous = context.attach();
try {
tracer.serverCallStarted(new TestServerCallInfo(null));

Assert.assertEquals("0.0.0.0", FlightRemoteIpServerStreamTracer.getRemoteIp());
} finally {
context.detach(previous);
}
}

@Test
public void testFallbackRemoteIpWithoutFlightContext() {
Assert.assertEquals("0.0.0.0", FlightRemoteIpServerStreamTracer.getRemoteIp());
}

private static class TestServerCallInfo extends ServerStreamTracer.ServerCallInfo<Object, Object> {
private final Attributes attributes;

TestServerCallInfo(SocketAddress remoteAddress) {
Attributes.Builder builder = Attributes.newBuilder();
if (remoteAddress != null) {
builder.set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress);
}
this.attributes = builder.build();
}

@Override
public MethodDescriptor<Object, Object> getMethodDescriptor() {
return null;
}

@Override
public Attributes getAttributes() {
return attributes;
}

@Override
public String getAuthority() {
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import java.sql.Connection
import java.sql.DriverManager

suite("test_auth_remote_ip", "arrow_flight_sql") {
String user = "flight_auth_remote_ip_user"
String password = "flight_auth_remote_ip_pwd"
String wrongPassword = "wrong_flight_auth_remote_ip_pwd"
List<String> remoteIpHosts = [
"127.%",
"10.%",
"172.%",
"192.168.%",
"::1",
"0:0:0:0:0:0:0:1"
]
List<String> allHosts = ["0.0.0.0"] + remoteIpHosts

try {
allHosts.each { host ->
jdbc_sql """DROP USER IF EXISTS '${user}'@'${host}'"""
}

jdbc_sql """CREATE USER '${user}'@'0.0.0.0' IDENTIFIED BY '${wrongPassword}'"""
String validComputeGroup = null
if (isCloudMode()) {
def computeGroups = sql """SHOW COMPUTE GROUPS"""
assertTrue(!computeGroups.isEmpty())
validComputeGroup = computeGroups[0][0]
}
remoteIpHosts.each { host ->
jdbc_sql """CREATE USER '${user}'@'${host}' IDENTIFIED BY '${password}'"""
jdbc_sql """GRANT SELECT_PRIV ON *.* TO '${user}'@'${host}'"""
if (validComputeGroup != null) {
jdbc_sql """GRANT USAGE_PRIV ON COMPUTE GROUP '${validComputeGroup}' TO '${user}'@'${host}'"""
}
}

Class.forName("org.apache.arrow.driver.jdbc.ArrowFlightJdbcDriver")
String arrowFlightSqlHost = context.config.otherConfigs.get("extArrowFlightSqlHost")
String arrowFlightSqlPort = context.config.otherConfigs.get("extArrowFlightSqlPort")
String arrowFlightSqlUrl = "jdbc:arrow-flight-sql://${arrowFlightSqlHost}:${arrowFlightSqlPort}" +
"/?useServerPrepStmts=false&useSSL=false&useEncryption=false"

Connection conn = DriverManager.getConnection(arrowFlightSqlUrl, user, password)
try {
List<List<Object>> result = sql_impl(conn, "SELECT 1")
assertEquals(1, result.size())
assertEquals(1, (result[0][0] as Number).intValue())
} finally {
conn.close()
}
} finally {
allHosts.each { host ->
try {
jdbc_sql """DROP USER IF EXISTS '${user}'@'${host}'"""
} catch (Throwable t) {
logger.warn("Failed to drop test user '${user}'@'${host}'", t)
}
}
}
}
Loading