Skip to content

Commit

Permalink
cache ssl contexts and reuse them (#12404)
Browse files Browse the repository at this point in the history
* cache ssl contexts and reuse them

* address comments

* update java inline comment

* add a comment
  • Loading branch information
zhtaoxiang authored Feb 16, 2024
1 parent 38d86b0 commit 24f48e3
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@

import io.netty.handler.ssl.SslProvider;
import java.security.KeyStore;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.apache.commons.lang3.StringUtils;


/**
* Container object for TLS/SSL configuration of pinot clients and servers (netty, grizzly, etc.)
*/
@Getter
@Setter
@EqualsAndHashCode
public class TlsConfig {
private boolean _clientAuthEnabled;
private String _keyStoreType = KeyStore.getDefaultType();
Expand All @@ -35,6 +41,7 @@ public class TlsConfig {
private String _trustStorePath;
private String _trustStorePassword;
private String _sslProvider = SslProvider.JDK.toString();
// If true, the client will not verify the server's certificate
private boolean _insecure = false;

public TlsConfig() {
Expand All @@ -52,79 +59,7 @@ public TlsConfig(TlsConfig tlsConfig) {
_sslProvider = tlsConfig._sslProvider;
}

public boolean isClientAuthEnabled() {
return _clientAuthEnabled;
}

public void setClientAuthEnabled(boolean clientAuthEnabled) {
_clientAuthEnabled = clientAuthEnabled;
}

public String getKeyStoreType() {
return _keyStoreType;
}

public void setKeyStoreType(String keyStoreType) {
_keyStoreType = keyStoreType;
}

public String getKeyStorePath() {
return _keyStorePath;
}

public void setKeyStorePath(String keyStorePath) {
_keyStorePath = keyStorePath;
}

public String getKeyStorePassword() {
return _keyStorePassword;
}

public void setKeyStorePassword(String keyStorePassword) {
_keyStorePassword = keyStorePassword;
}

public String getTrustStoreType() {
return _trustStoreType;
}

public void setTrustStoreType(String trustStoreType) {
_trustStoreType = trustStoreType;
}

public String getTrustStorePath() {
return _trustStorePath;
}

public void setTrustStorePath(String trustStorePath) {
_trustStorePath = trustStorePath;
}

public String getTrustStorePassword() {
return _trustStorePassword;
}

public void setTrustStorePassword(String trustStorePassword) {
_trustStorePassword = trustStorePassword;
}

public String getSslProvider() {
return _sslProvider;
}

public void setSslProvider(String sslProvider) {
_sslProvider = sslProvider;
}

public boolean isCustomized() {
return StringUtils.isNoneBlank(_keyStorePath) || StringUtils.isNoneBlank(_trustStorePath);
}

public boolean isInsecure() {
return _insecure;
}

public void setInsecure(boolean insecure) {
_insecure = insecure;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLException;
import nl.altindag.ssl.SSLFactory;
Expand All @@ -41,6 +44,10 @@
public class GrpcQueryClient {
private static final Logger LOGGER = LoggerFactory.getLogger(GrpcQueryClient.class);
private static final int DEFAULT_CHANNEL_SHUTDOWN_TIMEOUT_SECOND = 10;
// the key is the hashCode of the TlsConfig, the value is the SslContext
// We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
// hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
private static final Map<Integer, SslContext> CLIENT_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();

private final ManagedChannel _managedChannel;
private final PinotQueryServerGrpc.PinotQueryServerBlockingStub _blockingStub;
Expand All @@ -55,8 +62,17 @@ public GrpcQueryClient(String host, int port, GrpcConfig config) {
ManagedChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
.usePlaintext().build();
} else {
_managedChannel =
NettyChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
.sslContext(buildSslContext(config.getTlsConfig())).build();
}
_blockingStub = PinotQueryServerGrpc.newBlockingStub(_managedChannel);
}

private SslContext buildSslContext(TlsConfig tlsConfig) {
LOGGER.info("Building gRPC SSL context");
SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
try {
TlsConfig tlsConfig = config.getTlsConfig();
SSLFactory sslFactory = TlsUtils.createSSLFactory(tlsConfig);
if (TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getKeyStorePath())
&& TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getTrustStorePath())) {
Expand All @@ -71,14 +87,12 @@ public GrpcQueryClient(String host, int port, GrpcConfig config) {
} else {
sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder);
}
_managedChannel =
NettyChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
.sslContext(sslContextBuilder.build()).build();
return sslContextBuilder.build();
} catch (SSLException e) {
throw new RuntimeException("Failed to create Netty gRPC channel with SSL Context", e);
throw new RuntimeException("Failed to build gRPC SSL context", e);
}
}
_blockingStub = PinotQueryServerGrpc.newBlockingStub(_managedChannel);
});
return sslContext;
}

public Iterator<Server.ServerResponse> submit(Server.ServerRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.ssl.SslContext;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.metrics.BrokerMetrics;
Expand All @@ -36,8 +38,15 @@
* The {@code ChannelHandlerFactory} provides all kinds of Netty ChannelHandlers
*/
public class ChannelHandlerFactory {

public static final String SSL = "ssl";
// The key is the hashCode of the TlsConfig, the value is the SslContext
// We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
// hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
private static final Map<Integer, SslContext> CLIENT_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();
// the key is the hashCode of the TlsConfig, the value is the SslContext
// We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
// hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
private static final Map<Integer, SslContext> SERVER_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();

private ChannelHandlerFactory() {
}
Expand All @@ -61,14 +70,18 @@ public static ChannelHandler getLengthFieldPrepender() {
* The {@code getClientTlsHandler} return a Client side Tls handler that encrypt and decrypt everything.
*/
public static ChannelHandler getClientTlsHandler(TlsConfig tlsConfig, SocketChannel ch) {
return TlsUtils.buildClientContext(tlsConfig).newHandler(ch.alloc());
SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE
.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> TlsUtils.buildClientContext(tlsConfig));
return sslContext.newHandler(ch.alloc());
}

/**
* The {@code getServerTlsHandler} return a Server side Tls handler that encrypt and decrypt everything.
*/
public static ChannelHandler getServerTlsHandler(TlsConfig tlsConfig, SocketChannel ch) {
return TlsUtils.buildServerContext(tlsConfig).newHandler(ch.alloc());
SslContext sslContext = SERVER_SSL_CONTEXTS_CACHE.computeIfAbsent(
tlsConfig.hashCode(), tlsConfigHashCode -> TlsUtils.buildServerContext(tlsConfig));
return sslContext.newHandler(ch.alloc());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import nl.altindag.ssl.SSLFactory;
Expand Down Expand Up @@ -56,6 +58,10 @@
// TODO: Plug in QueryScheduler
public class GrpcQueryServer extends PinotQueryServerGrpc.PinotQueryServerImplBase {
private static final Logger LOGGER = LoggerFactory.getLogger(GrpcQueryServer.class);
// the key is the hashCode of the TlsConfig, the value is the SslContext
// We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
// hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
private static final Map<Integer, SslContext> SERVER_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();

private final QueryExecutor _queryExecutor;
private final ServerMetrics _serverMetrics;
Expand Down Expand Up @@ -85,23 +91,30 @@ public GrpcQueryServer(int port, GrpcConfig config, TlsConfig tlsConfig, QueryEx
}

private SslContext buildGRpcSslContext(TlsConfig tlsConfig)
throws Exception {
throws IllegalArgumentException {
LOGGER.info("Building gRPC SSL context");
if (tlsConfig.getKeyStorePath() == null) {
throw new IllegalArgumentException("Must provide key store path for secured gRpc server");
}
SSLFactory sslFactory = TlsUtils.createSSLFactory(tlsConfig);
if (TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getKeyStorePath())
&& TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getTrustStorePath())) {
TlsUtils.enableAutoRenewalFromFileStoreForSSLFactory(sslFactory, tlsConfig);
}
SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(sslFactory.getKeyManagerFactory().get())
.sslProvider(SslProvider.valueOf(tlsConfig.getSslProvider()));
sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager);
if (tlsConfig.isClientAuthEnabled()) {
sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
}
return GrpcSslContexts.configure(sslContextBuilder).build();
SslContext sslContext = SERVER_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
try {
SSLFactory sslFactory = TlsUtils.createSSLFactory(tlsConfig);
if (TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getKeyStorePath())
&& TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getTrustStorePath())) {
TlsUtils.enableAutoRenewalFromFileStoreForSSLFactory(sslFactory, tlsConfig);
}
SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(sslFactory.getKeyManagerFactory().get())
.sslProvider(SslProvider.valueOf(tlsConfig.getSslProvider()));
sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager);
if (tlsConfig.isClientAuthEnabled()) {
sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
}
return GrpcSslContexts.configure(sslContextBuilder).build();
} catch (Exception e) {
throw new RuntimeException("Failed to build gRPC SSL context", e);
}
});
return sslContext;
}

public void start() {
Expand Down

0 comments on commit 24f48e3

Please sign in to comment.