Skip to content

Commit

Permalink
[FLINK-9261] [security] Fix SSL support for REST API and Web UI.
Browse files Browse the repository at this point in the history
- Remove wrong reuse of SSLEngine instances. SSLEngine must be re-created for
  every SocketChannel initialization.
- Add ChunkedWriteHandler to REST server pipeline because StaticFileServerHandler
  relies on it.
- Add integration tests to verify that SSL can be enabled.
  • Loading branch information
GJL authored and StephanEwen committed May 9, 2018
1 parent 7c87c1a commit 3afa5eb
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 80 deletions.
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.net;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;

import static java.util.Objects.requireNonNull;

/**
* Creates and configures {@link SSLEngine} instances.
*/
public class SSLEngineFactory {

private final SSLContext sslContext;

private final String[] enabledProtocols;

private final String[] enabledCipherSuites;

private final boolean clientMode;

public SSLEngineFactory(
final SSLContext sslContext,
final String[] enabledProtocols,
final String[] enabledCipherSuites,
final boolean clientMode) {
this.sslContext = requireNonNull(sslContext, "sslContext must not be null");
this.enabledProtocols = requireNonNull(enabledProtocols, "enabledProtocols must not be null");
this.enabledCipherSuites = requireNonNull(enabledCipherSuites, "cipherSuites must not be null");
this.clientMode = clientMode;
}

public SSLEngine createSSLEngine() {
final SSLEngine sslEngine = sslContext.createSSLEngine();
sslEngine.setEnabledProtocols(enabledProtocols);
sslEngine.setEnabledCipherSuites(enabledCipherSuites);
sslEngine.setUseClientMode(clientMode);
return sslEngine;
}
}
Expand Up @@ -25,6 +25,7 @@
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;


import javax.annotation.Nullable;
import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
Expand All @@ -38,6 +39,9 @@
import java.security.KeyStore; import java.security.KeyStore;
import java.util.Arrays; import java.util.Arrays;


import static java.util.Objects.requireNonNull;
import static org.apache.flink.util.Preconditions.checkState;

/** /**
* Common utilities to manage SSL transport settings. * Common utilities to manage SSL transport settings.
*/ */
Expand Down Expand Up @@ -81,16 +85,62 @@ public static void setSSLVerAndCipherSuites(ServerSocket socket, Configuration c
} }
} }


/**
* Creates a {@link SSLEngineFactory} to be used by the Server.
*
* @param config The application configuration.
*/
public static SSLEngineFactory createServerSSLEngineFactory(final Configuration config) throws Exception {
return createSSLEngineFactory(config, false);
}

/**
* Creates a {@link SSLEngineFactory} to be used by the Client.
* @param config The application configuration.
*/
public static SSLEngineFactory createClientSSLEngineFactory(final Configuration config) throws Exception {
return createSSLEngineFactory(config, true);
}

private static SSLEngineFactory createSSLEngineFactory(
final Configuration config,
final boolean clientMode) throws Exception {

final SSLContext sslContext = clientMode ?
createSSLClientContext(config) :
createSSLServerContext(config);

checkState(sslContext != null, "%s it not enabled", SecurityOptions.SSL_ENABLED.key());

return new SSLEngineFactory(
sslContext,
getEnabledProtocols(config),
getEnabledCipherSuites(config),
clientMode);
}

/** /**
* Sets SSL version and cipher suites for SSLEngine. * Sets SSL version and cipher suites for SSLEngine.
* @param engine *
* SSLEngine to be handled * @param engine SSLEngine to be handled
* @param config * @param config The application configuration
* The application configuration * @deprecated Use {@link #createClientSSLEngineFactory(Configuration)} or
* {@link #createServerSSLEngineFactory(Configuration)}.
*/ */
@Deprecated
public static void setSSLVerAndCipherSuites(SSLEngine engine, Configuration config) { public static void setSSLVerAndCipherSuites(SSLEngine engine, Configuration config) {
engine.setEnabledProtocols(config.getString(SecurityOptions.SSL_PROTOCOL).split(",")); engine.setEnabledProtocols(getEnabledProtocols(config));
engine.setEnabledCipherSuites(config.getString(SecurityOptions.SSL_ALGORITHMS).split(",")); engine.setEnabledCipherSuites(getEnabledCipherSuites(config));
}

private static String[] getEnabledProtocols(final Configuration config) {
requireNonNull(config, "config must not be null");
return config.getString(SecurityOptions.SSL_PROTOCOL).split(",");
}

private static String[] getEnabledCipherSuites(final Configuration config) {
requireNonNull(config, "config must not be null");
return config.getString(SecurityOptions.SSL_ALGORITHMS).split(",");
} }


/** /**
Expand Down Expand Up @@ -122,6 +172,7 @@ public static void setSSLVerifyHostname(Configuration sslConfig, SSLParameters s
* @throws Exception * @throws Exception
* Thrown if there is any misconfiguration * Thrown if there is any misconfiguration
*/ */
@Nullable
public static SSLContext createSSLClientContext(Configuration sslConfig) throws Exception { public static SSLContext createSSLClientContext(Configuration sslConfig) throws Exception {


Preconditions.checkNotNull(sslConfig); Preconditions.checkNotNull(sslConfig);
Expand Down Expand Up @@ -170,6 +221,7 @@ public static SSLContext createSSLClientContext(Configuration sslConfig) throws
* @throws Exception * @throws Exception
* Thrown if there is any misconfiguration * Thrown if there is any misconfiguration
*/ */
@Nullable
public static SSLContext createSSLServerContext(Configuration sslConfig) throws Exception { public static SSLContext createSSLServerContext(Configuration sslConfig) throws Exception {


Preconditions.checkNotNull(sslConfig); Preconditions.checkNotNull(sslConfig);
Expand All @@ -191,14 +243,8 @@ public static SSLContext createSSLServerContext(Configuration sslConfig) throws
Preconditions.checkNotNull(certPassword, SecurityOptions.SSL_KEY_PASSWORD.key() + " was not configured."); Preconditions.checkNotNull(certPassword, SecurityOptions.SSL_KEY_PASSWORD.key() + " was not configured.");


KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
FileInputStream keyStoreFile = null; try (FileInputStream keyStoreFile = new FileInputStream(new File(keystoreFilePath))) {
try {
keyStoreFile = new FileInputStream(new File(keystoreFilePath));
ks.load(keyStoreFile, keystorePassword.toCharArray()); ks.load(keyStoreFile, keystorePassword.toCharArray());
} finally {
if (keyStoreFile != null) {
keyStoreFile.close();
}
} }


// Set up key manager factory to use the server key store // Set up key manager factory to use the server key store
Expand Down
Expand Up @@ -21,6 +21,7 @@
import org.apache.flink.api.common.time.Time; import org.apache.flink.api.common.time.Time;
import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.RestOptions; import org.apache.flink.configuration.RestOptions;
import org.apache.flink.runtime.net.SSLEngineFactory;
import org.apache.flink.runtime.rest.messages.ErrorResponseBody; import org.apache.flink.runtime.rest.messages.ErrorResponseBody;
import org.apache.flink.runtime.rest.messages.MessageHeaders; import org.apache.flink.runtime.rest.messages.MessageHeaders;
import org.apache.flink.runtime.rest.messages.MessageParameters; import org.apache.flink.runtime.rest.messages.MessageParameters;
Expand Down Expand Up @@ -66,8 +67,6 @@
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;


import javax.net.ssl.SSLEngine;

import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.StringWriter; import java.io.StringWriter;
Expand All @@ -93,13 +92,13 @@ public RestClient(RestClientConfiguration configuration, Executor executor) {
Preconditions.checkNotNull(configuration); Preconditions.checkNotNull(configuration);
this.executor = Preconditions.checkNotNull(executor); this.executor = Preconditions.checkNotNull(executor);


SSLEngine sslEngine = configuration.getSslEngine(); final SSLEngineFactory sslEngineFactory = configuration.getSslEngineFactory();
ChannelInitializer<SocketChannel> initializer = new ChannelInitializer<SocketChannel>() { ChannelInitializer<SocketChannel> initializer = new ChannelInitializer<SocketChannel>() {
@Override @Override
protected void initChannel(SocketChannel socketChannel) throws Exception { protected void initChannel(SocketChannel socketChannel) {
// SSL should be the first handler in the pipeline // SSL should be the first handler in the pipeline
if (sslEngine != null) { if (sslEngineFactory != null) {
socketChannel.pipeline().addLast("ssl", new SslHandler(sslEngine)); socketChannel.pipeline().addLast("ssl", new SslHandler(sslEngineFactory.createSSLEngine()));
} }


socketChannel.pipeline() socketChannel.pipeline()
Expand Down
Expand Up @@ -21,12 +21,12 @@
import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.RestOptions; import org.apache.flink.configuration.RestOptions;
import org.apache.flink.configuration.SecurityOptions; import org.apache.flink.configuration.SecurityOptions;
import org.apache.flink.runtime.net.SSLEngineFactory;
import org.apache.flink.runtime.net.SSLUtils; import org.apache.flink.runtime.net.SSLUtils;
import org.apache.flink.util.ConfigurationException; import org.apache.flink.util.ConfigurationException;
import org.apache.flink.util.Preconditions; import org.apache.flink.util.Preconditions;


import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;


import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkArgument;
Expand All @@ -37,18 +37,18 @@
public final class RestClientConfiguration { public final class RestClientConfiguration {


@Nullable @Nullable
private final SSLEngine sslEngine; private final SSLEngineFactory sslEngineFactory;


private final long connectionTimeout; private final long connectionTimeout;


private final int maxContentLength; private final int maxContentLength;


private RestClientConfiguration( private RestClientConfiguration(
@Nullable final SSLEngine sslEngine, @Nullable final SSLEngineFactory sslEngineFactory,
final long connectionTimeout, final long connectionTimeout,
final int maxContentLength) { final int maxContentLength) {
checkArgument(maxContentLength > 0, "maxContentLength must be positive, was: %d", maxContentLength); checkArgument(maxContentLength > 0, "maxContentLength must be positive, was: %d", maxContentLength);
this.sslEngine = sslEngine; this.sslEngineFactory = sslEngineFactory;
this.connectionTimeout = connectionTimeout; this.connectionTimeout = connectionTimeout;
this.maxContentLength = maxContentLength; this.maxContentLength = maxContentLength;
} }
Expand All @@ -58,9 +58,9 @@ private RestClientConfiguration(
* *
* @return SSLEngine that the REST client endpoint should use, or null if SSL was disabled * @return SSLEngine that the REST client endpoint should use, or null if SSL was disabled
*/ */

@Nullable
public SSLEngine getSslEngine() { public SSLEngineFactory getSslEngineFactory() {
return sslEngine; return sslEngineFactory;
} }


/** /**
Expand Down Expand Up @@ -90,25 +90,22 @@ public int getMaxContentLength() {
public static RestClientConfiguration fromConfiguration(Configuration config) throws ConfigurationException { public static RestClientConfiguration fromConfiguration(Configuration config) throws ConfigurationException {
Preconditions.checkNotNull(config); Preconditions.checkNotNull(config);


SSLEngine sslEngine = null; final SSLEngineFactory sslEngineFactory;
boolean enableSSL = config.getBoolean(SecurityOptions.SSL_ENABLED); final boolean enableSSL = config.getBoolean(SecurityOptions.SSL_ENABLED);
if (enableSSL) { if (enableSSL) {
try { try {
SSLContext sslContext = SSLUtils.createSSLServerContext(config); sslEngineFactory = SSLUtils.createClientSSLEngineFactory(config);
if (sslContext != null) {
sslEngine = sslContext.createSSLEngine();
SSLUtils.setSSLVerAndCipherSuites(sslEngine, config);
sslEngine.setUseClientMode(false);
}
} catch (Exception e) { } catch (Exception e) {
throw new ConfigurationException("Failed to initialize SSLContext for the web frontend", e); throw new ConfigurationException("Failed to initialize SSLContext for the web frontend", e);
} }
} else {
sslEngineFactory = null;
} }


final long connectionTimeout = config.getLong(RestOptions.CONNECTION_TIMEOUT); final long connectionTimeout = config.getLong(RestOptions.CONNECTION_TIMEOUT);


int maxContentLength = config.getInteger(RestOptions.CLIENT_MAX_CONTENT_LENGTH); int maxContentLength = config.getInteger(RestOptions.CLIENT_MAX_CONTENT_LENGTH);


return new RestClientConfiguration(sslEngine, connectionTimeout, maxContentLength); return new RestClientConfiguration(sslEngineFactory, connectionTimeout, maxContentLength);
} }
} }
Expand Up @@ -22,6 +22,7 @@
import org.apache.flink.api.common.time.Time; import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.concurrent.FutureUtils; import org.apache.flink.runtime.concurrent.FutureUtils;
import org.apache.flink.runtime.net.SSLEngineFactory;
import org.apache.flink.runtime.rest.handler.PipelineErrorHandler; import org.apache.flink.runtime.rest.handler.PipelineErrorHandler;
import org.apache.flink.runtime.rest.handler.RestHandlerSpecification; import org.apache.flink.runtime.rest.handler.RestHandlerSpecification;
import org.apache.flink.runtime.rest.handler.RouterHandler; import org.apache.flink.runtime.rest.handler.RouterHandler;
Expand All @@ -41,13 +42,13 @@
import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Handler; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Handler;
import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Router; import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.router.Router;
import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslHandler; import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslHandler;
import org.apache.flink.shaded.netty4.io.netty.handler.stream.ChunkedWriteHandler;
import org.apache.flink.shaded.netty4.io.netty.util.concurrent.DefaultThreadFactory; import org.apache.flink.shaded.netty4.io.netty.util.concurrent.DefaultThreadFactory;


import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;


import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.net.ssl.SSLEngine;


import java.io.IOException; import java.io.IOException;
import java.io.Serializable; import java.io.Serializable;
Expand Down Expand Up @@ -76,7 +77,8 @@ public abstract class RestServerEndpoint implements AutoCloseableAsync {
private final String restAddress; private final String restAddress;
private final String restBindAddress; private final String restBindAddress;
private final int restBindPort; private final int restBindPort;
private final SSLEngine sslEngine; @Nullable
private final SSLEngineFactory sslEngineFactory;
private final int maxContentLength; private final int maxContentLength;


protected final Path uploadDir; protected final Path uploadDir;
Expand All @@ -96,7 +98,7 @@ public RestServerEndpoint(RestServerEndpointConfiguration configuration) throws
this.restAddress = configuration.getRestAddress(); this.restAddress = configuration.getRestAddress();
this.restBindAddress = configuration.getRestBindAddress(); this.restBindAddress = configuration.getRestBindAddress();
this.restBindPort = configuration.getRestBindPort(); this.restBindPort = configuration.getRestBindPort();
this.sslEngine = configuration.getSslEngine(); this.sslEngineFactory = configuration.getSslEngineFactory();


this.uploadDir = configuration.getUploadDir(); this.uploadDir = configuration.getUploadDir();
createUploadDir(uploadDir, log); createUploadDir(uploadDir, log);
Expand Down Expand Up @@ -155,14 +157,15 @@ protected void initChannel(SocketChannel ch) {
Handler handler = new RouterHandler(router, responseHeaders); Handler handler = new RouterHandler(router, responseHeaders);


// SSL should be the first handler in the pipeline // SSL should be the first handler in the pipeline
if (sslEngine != null) { if (sslEngineFactory != null) {
ch.pipeline().addLast("ssl", new SslHandler(sslEngine)); ch.pipeline().addLast("ssl", new SslHandler(sslEngineFactory.createSSLEngine()));
} }


ch.pipeline() ch.pipeline()
.addLast(new HttpServerCodec()) .addLast(new HttpServerCodec())
.addLast(new FileUploadHandler(uploadDir)) .addLast(new FileUploadHandler(uploadDir))
.addLast(new FlinkHttpObjectAggregator(maxContentLength, responseHeaders)) .addLast(new FlinkHttpObjectAggregator(maxContentLength, responseHeaders))
.addLast(new ChunkedWriteHandler())
.addLast(handler.name(), handler) .addLast(handler.name(), handler)
.addLast(new PipelineErrorHandler(log, responseHeaders)); .addLast(new PipelineErrorHandler(log, responseHeaders));
} }
Expand Down Expand Up @@ -198,7 +201,7 @@ protected void initChannel(SocketChannel ch) {


final String protocol; final String protocol;


if (sslEngine != null) { if (sslEngineFactory != null) {
protocol = "https://"; protocol = "https://";
} else { } else {
protocol = "http://"; protocol = "http://";
Expand Down

0 comments on commit 3afa5eb

Please sign in to comment.