Skip to content

Commit

Permalink
[FLINK-32030][sql-client] Add URLs support for SQL Client gateway mode (
Browse files Browse the repository at this point in the history
  • Loading branch information
afedulov committed May 16, 2023
1 parent ce286c9 commit 4bd51ce
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 18 deletions.
12 changes: 12 additions & 0 deletions flink-core/src/main/java/org/apache/flink/util/NetUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,18 @@ private static URL validateHostPortString(String hostPort) {
}
}

/**
* Converts an InetSocketAddress to a URL. This method assigns the "http://" schema to the URL
* by default.
*
* @param socketAddress the InetSocketAddress to be converted
* @return a URL object representing the provided socket address with "http://" schema
*/
public static URL socketToUrl(InetSocketAddress socketAddress) {
String hostPort = socketAddress.getHostString() + ":" + socketAddress.getPort();
return validateHostPortString(hostPort);
}

/**
* Calls {@link ServerSocket#accept()} on the provided server socket, suppressing any thrown
* {@link SocketTimeoutException}s. This is a workaround for the underlying JDK-8237858 bug in
Expand Down
11 changes: 11 additions & 0 deletions flink-core/src/test/java/org/apache/flink/util/NetUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@

package org.apache.flink.util;

import org.assertj.core.api.Assertions;
import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.MalformedURLException;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
Expand All @@ -32,6 +34,7 @@
import java.util.Iterator;
import java.util.Set;

import static org.apache.flink.util.NetUtils.socketToUrl;
import static org.hamcrest.core.IsCollectionContaining.hasItems;
import static org.hamcrest.core.IsNot.not;
import static org.junit.Assert.assertEquals;
Expand Down Expand Up @@ -343,4 +346,12 @@ public void testFormatAddress() {
NetUtils.unresolvedHostAndPortToNormalizedString(host, port));
}
}

@Test
public void testSocketToUrl() throws MalformedURLException {
InetSocketAddress socketAddress = new InetSocketAddress("foo.com", 8080);
URL expectedResult = new URL("http://foo.com:8080");

Assertions.assertThat(socketToUrl(socketAddress)).isEqualTo(expectedResult);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ public class RestClient implements AutoCloseableAsync {

private final AtomicBoolean isRunning = new AtomicBoolean(true);

public static final String VERSION_PLACEHOLDER = "{{VERSION}}";

@VisibleForTesting List<OutboundChannelHandlerFactory> outboundChannelHandlerFactories;

public RestClient(Configuration configuration, Executor executor)
Expand Down Expand Up @@ -353,7 +355,7 @@ CompletableFuture<P> sendRequest(
}

String versionedHandlerURL =
"/" + apiVersion.getURLVersionPrefix() + messageHeaders.getTargetRestEndpointURL();
constructVersionedHandlerUrl(messageHeaders, apiVersion.getURLVersionPrefix());
String targetUrl = MessageParameters.resolveUrl(versionedHandlerURL, messageParameters);

LOG.debug(
Expand Down Expand Up @@ -394,6 +396,16 @@ CompletableFuture<P> sendRequest(
return submitRequest(targetAddress, targetPort, httpRequest, responseType);
}

private static <M extends MessageHeaders<?, ?, ?>> String constructVersionedHandlerUrl(
M messageHeaders, String urlVersionPrefix) {
String targetUrl = messageHeaders.getTargetRestEndpointURL();
if (targetUrl.contains(VERSION_PLACEHOLDER)) {
return targetUrl.replace(VERSION_PLACEHOLDER, urlVersionPrefix);
} else {
return "/" + urlVersionPrefix + messageHeaders.getTargetRestEndpointURL();
}
}

private static Request createRequest(
String targetAddress,
String targetUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import javax.annotation.Nullable;

import java.net.InetSocketAddress;
import java.net.URL;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -135,7 +134,7 @@ public Configuration getPythonConfiguration() {
/** Command option lines to configure SQL Client in the gateway mode. */
public static class GatewayCliOptions extends CliOptions {

private final @Nullable InetSocketAddress gatewayAddress;
private final @Nullable URL gatewayAddress;

GatewayCliOptions(
boolean isPrintHelp,
Expand All @@ -144,7 +143,7 @@ public static class GatewayCliOptions extends CliOptions {
URL sqlFile,
String updateStatement,
String historyFilePath,
@Nullable InetSocketAddress gatewayAddress,
@Nullable URL gatewayAddress,
Properties sessionConfig) {
super(
isPrintHelp,
Expand All @@ -157,7 +156,7 @@ public static class GatewayCliOptions extends CliOptions {
this.gatewayAddress = gatewayAddress;
}

public Optional<InetSocketAddress> getGatewayAddress() {
public Optional<URL> getGatewayAddress() {
return Optional.ofNullable(gatewayAddress);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.PrintWriter;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.Arrays;
import java.util.List;
Expand All @@ -49,6 +52,8 @@
/** Parser for command line options. */
public class CliOptionsParser {

private static final Logger LOG = LoggerFactory.getLogger(CliOptionsParser.class);

public static final Option OPTION_HELP =
Option.builder("h")
.required(false)
Expand Down Expand Up @@ -288,7 +293,7 @@ public static CliOptions parseGatewayModeClient(String[] args) {
line.getOptionValue(CliOptionsParser.OPTION_UPDATE.getOpt()),
line.getOptionValue(CliOptionsParser.OPTION_HISTORY.getOpt()),
line.hasOption(CliOptionsParser.OPTION_ENDPOINT_ADDRESS.getOpt())
? NetUtils.parseHostPortAddress(
? parseGatewayAddress(
line.getOptionValue(
CliOptionsParser.OPTION_ENDPOINT_ADDRESS.getOpt()))
: null,
Expand All @@ -298,6 +303,30 @@ public static CliOptions parseGatewayModeClient(String[] args) {
}
}

private static URL parseGatewayAddress(String cliOptionAddress) {
URL url;
try {
url = new URL(cliOptionAddress);
if (!NetUtils.isValidHostPort(url.getPort())) {
url =
new URL(
url.getProtocol(),
url.getHost(),
url.getDefaultPort(),
url.getPath());
}

} catch (MalformedURLException e) {
// Required for backwards compatibility
LOG.warn(
"The gateway address should be specified as a URL, i.e. https://hostname:port/optional_path.");
LOG.warn(
"Trying to fallback to hostname:port (will use non-encrypted, http connection).");
url = NetUtils.getCorrectHostnamePort(cliOptionAddress);
}
return url;
}

// --------------------------------------------------------------------------------------------

private static URL checkUrl(CommandLine line, Option option) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.io.Closeable;
import java.net.InetSocketAddress;
import java.net.URL;
import java.util.List;

/** A gateway for communicating with Flink and other external systems. */
Expand All @@ -34,6 +35,10 @@ static Executor create(
return new ExecutorImpl(defaultContext, address, sessionId);
}

static Executor create(DefaultContext defaultContext, URL address, String sessionId) {
return new ExecutorImpl(defaultContext, address, sessionId);
}

/**
* Configures session with statement.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.flink.table.gateway.rest.header.statement.ExecuteStatementHeaders;
import org.apache.flink.table.gateway.rest.header.statement.FetchResultsHeaders;
import org.apache.flink.table.gateway.rest.header.util.GetApiVersionHeaders;
import org.apache.flink.table.gateway.rest.header.util.UrlPrefixDecorator;
import org.apache.flink.table.gateway.rest.message.operation.OperationMessageParameters;
import org.apache.flink.table.gateway.rest.message.operation.OperationStatusResponseBody;
import org.apache.flink.table.gateway.rest.message.session.CloseSessionResponseBody;
Expand All @@ -68,6 +69,7 @@
import org.apache.flink.table.gateway.rest.util.SqlGatewayRestEndpointUtils;
import org.apache.flink.table.gateway.service.context.DefaultContext;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.NetUtils;
import org.apache.flink.util.Preconditions;

import org.slf4j.Logger;
Expand All @@ -77,6 +79,7 @@

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URL;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
Expand All @@ -101,7 +104,8 @@ public class ExecutorImpl implements Executor {
private static final long HEARTBEAT_INTERVAL_MILLISECONDS = 60_000L;

private final AutoCloseableRegistry registry;
private final InetSocketAddress gatewayAddress;
private final URL gatewayUrl;

private final ExecutorService executorService;
private final RestClient restClient;

Expand All @@ -110,7 +114,15 @@ public class ExecutorImpl implements Executor {

public ExecutorImpl(
DefaultContext defaultContext, InetSocketAddress gatewayAddress, String sessionId) {
this(defaultContext, gatewayAddress, sessionId, HEARTBEAT_INTERVAL_MILLISECONDS);
this(
defaultContext,
NetUtils.socketToUrl(gatewayAddress),
sessionId,
HEARTBEAT_INTERVAL_MILLISECONDS);
}

public ExecutorImpl(DefaultContext defaultContext, URL gatewayUrl, String sessionId) {
this(defaultContext, gatewayUrl, sessionId, HEARTBEAT_INTERVAL_MILLISECONDS);
}

@VisibleForTesting
Expand All @@ -119,9 +131,18 @@ public ExecutorImpl(
InetSocketAddress gatewayAddress,
String sessionId,
long heartbeatInterval) {
this(defaultContext, NetUtils.socketToUrl(gatewayAddress), sessionId, heartbeatInterval);
}

@VisibleForTesting
ExecutorImpl(
DefaultContext defaultContext,
URL gatewayUrl,
String sessionId,
long heartbeatInterval) {
this.registry = new AutoCloseableRegistry();
this.gatewayUrl = gatewayUrl;
try {
this.gatewayAddress = gatewayAddress;
// register required resource
this.executorService = Executors.newCachedThreadPool();
registry.registerCloseable(executorService::shutdownNow);
Expand All @@ -134,7 +155,7 @@ public ExecutorImpl(
// register session
LOG.info(
"Open session to {} with connection version: {}.",
gatewayAddress,
gatewayUrl,
connectionVersion);
OpenSessionResponseBody response =
sendRequest(
Expand Down Expand Up @@ -180,7 +201,7 @@ public void configureSession(String statement) {
.get();
} catch (Exception e) {
throw new SqlExecutionException(
String.format("Failed to open session to %s", gatewayAddress), e);
String.format("Failed to open session to %s", gatewayUrl), e);
}
}

Expand Down Expand Up @@ -342,7 +363,11 @@ private FetchResultsResponseBody fetchResults(OperationHandle operationHandle, l
P extends ResponseBody>
CompletableFuture<P> sendRequest(M messageHeaders, U messageParameters, R request) {
Preconditions.checkNotNull(connectionVersion, "The connection version should not be null.");
return sendRequest(messageHeaders, messageParameters, request, connectionVersion);
return sendRequest(
new UrlPrefixDecorator<>(messageHeaders, gatewayUrl.getPath()),
messageParameters,
request,
connectionVersion);
}

private <
Expand All @@ -357,8 +382,8 @@ CompletableFuture<P> sendRequest(
SqlGatewayRestAPIVersion connectionVersion) {
try {
return restClient.sendRequest(
gatewayAddress.getHostName(),
gatewayAddress.getPort(),
gatewayUrl.getHost(),
gatewayUrl.getPort(),
messageHeaders,
messageParameters,
request,
Expand Down Expand Up @@ -454,9 +479,11 @@ private SqlGatewayRestAPIVersion negotiateVersion() throws Exception {
List<SqlGatewayRestAPIVersion> gatewayVersions =
getResponse(
restClient.sendRequest(
gatewayAddress.getHostName(),
gatewayAddress.getPort(),
GetApiVersionHeaders.getInstance(),
gatewayUrl.getHost(),
gatewayUrl.getPort(),
new UrlPrefixDecorator<>(
GetApiVersionHeaders.getInstance(),
gatewayUrl.getPath()),
EmptyMessageParameters.getInstance(),
EmptyRequestBody.getInstance(),
Collections.emptyList(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void testEmptyOptions() throws Exception {
}

@Test
void testGatewayMode() throws Exception {
void testGatewayModeHostnamePort() throws Exception {
String[] args =
new String[] {
"gateway",
Expand All @@ -156,6 +156,21 @@ void testGatewayMode() throws Exception {
assertThat(actual).contains("execution.target", "yarn-session");
}

@Test
void testGatewayModeUrl() throws Exception {
String[] args =
new String[] {
"gateway",
"-e",
String.format(
"http://%s:%d",
SQL_GATEWAY_REST_ENDPOINT_EXTENSION.getTargetAddress(),
SQL_GATEWAY_REST_ENDPOINT_EXTENSION.getTargetPort())
};
String actual = runSqlClient(args, String.join("\n", "SET;", "QUIT;"), false);
assertThat(actual).contains("execution.target", "yarn-session");
}

@Test
void testGatewayModeWithoutAddress() throws Exception {
String[] args = new String[] {"gateway"};
Expand Down
Loading

0 comments on commit 4bd51ce

Please sign in to comment.