Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cache ssl contexts and reuse them #12404

Merged
merged 4 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@

import io.netty.handler.ssl.SslProvider;
import java.security.KeyStore;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.apache.commons.lang3.StringUtils;


/**
* Container object for TLS/SSL configuration of pinot clients and servers (netty, grizzly, etc.)
*/
@Getter
@Setter
@EqualsAndHashCode
public class TlsConfig {
private boolean _clientAuthEnabled;
private String _keyStoreType = KeyStore.getDefaultType();
Expand All @@ -35,6 +41,7 @@ public class TlsConfig {
private String _trustStorePath;
private String _trustStorePassword;
private String _sslProvider = SslProvider.JDK.toString();
// If true, the client will not verify the server's certificate
private boolean _insecure = false;

public TlsConfig() {
Expand All @@ -52,79 +59,7 @@ public TlsConfig(TlsConfig tlsConfig) {
_sslProvider = tlsConfig._sslProvider;
}

public boolean isClientAuthEnabled() {
return _clientAuthEnabled;
}

public void setClientAuthEnabled(boolean clientAuthEnabled) {
_clientAuthEnabled = clientAuthEnabled;
}

public String getKeyStoreType() {
return _keyStoreType;
}

public void setKeyStoreType(String keyStoreType) {
_keyStoreType = keyStoreType;
}

public String getKeyStorePath() {
return _keyStorePath;
}

public void setKeyStorePath(String keyStorePath) {
_keyStorePath = keyStorePath;
}

public String getKeyStorePassword() {
return _keyStorePassword;
}

public void setKeyStorePassword(String keyStorePassword) {
_keyStorePassword = keyStorePassword;
}

public String getTrustStoreType() {
return _trustStoreType;
}

public void setTrustStoreType(String trustStoreType) {
_trustStoreType = trustStoreType;
}

public String getTrustStorePath() {
return _trustStorePath;
}

public void setTrustStorePath(String trustStorePath) {
_trustStorePath = trustStorePath;
}

public String getTrustStorePassword() {
return _trustStorePassword;
}

public void setTrustStorePassword(String trustStorePassword) {
_trustStorePassword = trustStorePassword;
}

public String getSslProvider() {
return _sslProvider;
}

public void setSslProvider(String sslProvider) {
_sslProvider = sslProvider;
}

public boolean isCustomized() {
return StringUtils.isNoneBlank(_keyStorePath) || StringUtils.isNoneBlank(_trustStorePath);
}

public boolean isInsecure() {
return _insecure;
}

public void setInsecure(boolean insecure) {
_insecure = insecure;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLException;
import nl.altindag.ssl.SSLFactory;
Expand All @@ -41,6 +44,10 @@
public class GrpcQueryClient {
private static final Logger LOGGER = LoggerFactory.getLogger(GrpcQueryClient.class);
private static final int DEFAULT_CHANNEL_SHUTDOWN_TIMEOUT_SECOND = 10;
// the key is the hashCode of the TlsConfig, the value is the SslContext
// We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
// hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
private static final Map<Integer, SslContext> CLIENT_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();

private final ManagedChannel _managedChannel;
private final PinotQueryServerGrpc.PinotQueryServerBlockingStub _blockingStub;
Expand All @@ -55,8 +62,17 @@ public GrpcQueryClient(String host, int port, GrpcConfig config) {
ManagedChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
.usePlaintext().build();
} else {
_managedChannel =
NettyChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
.sslContext(buildSslContext(config.getTlsConfig())).build();
}
_blockingStub = PinotQueryServerGrpc.newBlockingStub(_managedChannel);
}

private SslContext buildSslContext(TlsConfig tlsConfig) {
LOGGER.info("Building gRPC SSL context");
SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
try {
TlsConfig tlsConfig = config.getTlsConfig();
SSLFactory sslFactory = TlsUtils.createSSLFactory(tlsConfig);
if (TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getKeyStorePath())
&& TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getTrustStorePath())) {
Expand All @@ -71,14 +87,12 @@ public GrpcQueryClient(String host, int port, GrpcConfig config) {
} else {
sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder);
}
_managedChannel =
NettyChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
.sslContext(sslContextBuilder.build()).build();
return sslContextBuilder.build();
} catch (SSLException e) {
throw new RuntimeException("Failed to create Netty gRPC channel with SSL Context", e);
throw new RuntimeException("Failed to build gRPC SSL context", e);
}
}
_blockingStub = PinotQueryServerGrpc.newBlockingStub(_managedChannel);
});
return sslContext;
}

public Iterator<Server.ServerResponse> submit(Server.ServerRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.ssl.SslContext;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.metrics.BrokerMetrics;
Expand All @@ -36,8 +38,15 @@
* The {@code ChannelHandlerFactory} provides all kinds of Netty ChannelHandlers
*/
public class ChannelHandlerFactory {

public static final String SSL = "ssl";
// The key is the hashCode of the TlsConfig, the value is the SslContext
// We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
// hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
private static final Map<Integer, SslContext> CLIENT_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();
// the key is the hashCode of the TlsConfig, the value is the SslContext
// We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
// hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
private static final Map<Integer, SslContext> SERVER_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();

private ChannelHandlerFactory() {
}
Expand All @@ -61,14 +70,18 @@ public static ChannelHandler getLengthFieldPrepender() {
* The {@code getClientTlsHandler} return a Client side Tls handler that encrypt and decrypt everything.
*/
public static ChannelHandler getClientTlsHandler(TlsConfig tlsConfig, SocketChannel ch) {
return TlsUtils.buildClientContext(tlsConfig).newHandler(ch.alloc());
SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE
.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> TlsUtils.buildClientContext(tlsConfig));
return sslContext.newHandler(ch.alloc());
}

/**
* The {@code getServerTlsHandler} return a Server side Tls handler that encrypt and decrypt everything.
*/
public static ChannelHandler getServerTlsHandler(TlsConfig tlsConfig, SocketChannel ch) {
return TlsUtils.buildServerContext(tlsConfig).newHandler(ch.alloc());
SslContext sslContext = SERVER_SSL_CONTEXTS_CACHE.computeIfAbsent(
tlsConfig.hashCode(), tlsConfigHashCode -> TlsUtils.buildServerContext(tlsConfig));
return sslContext.newHandler(ch.alloc());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import nl.altindag.ssl.SSLFactory;
Expand Down Expand Up @@ -56,6 +58,10 @@
// TODO: Plug in QueryScheduler
public class GrpcQueryServer extends PinotQueryServerGrpc.PinotQueryServerImplBase {
private static final Logger LOGGER = LoggerFactory.getLogger(GrpcQueryServer.class);
// the key is the hashCode of the TlsConfig, the value is the SslContext
// We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
// hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
private static final Map<Integer, SslContext> SERVER_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();

private final QueryExecutor _queryExecutor;
private final ServerMetrics _serverMetrics;
Expand Down Expand Up @@ -85,23 +91,30 @@ public GrpcQueryServer(int port, GrpcConfig config, TlsConfig tlsConfig, QueryEx
}

private SslContext buildGRpcSslContext(TlsConfig tlsConfig)
throws Exception {
throws IllegalArgumentException {
LOGGER.info("Building gRPC SSL context");
if (tlsConfig.getKeyStorePath() == null) {
throw new IllegalArgumentException("Must provide key store path for secured gRpc server");
}
SSLFactory sslFactory = TlsUtils.createSSLFactory(tlsConfig);
if (TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getKeyStorePath())
&& TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getTrustStorePath())) {
TlsUtils.enableAutoRenewalFromFileStoreForSSLFactory(sslFactory, tlsConfig);
}
SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(sslFactory.getKeyManagerFactory().get())
.sslProvider(SslProvider.valueOf(tlsConfig.getSslProvider()));
sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager);
if (tlsConfig.isClientAuthEnabled()) {
sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
}
return GrpcSslContexts.configure(sslContextBuilder).build();
SslContext sslContext = SERVER_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
try {
SSLFactory sslFactory = TlsUtils.createSSLFactory(tlsConfig);
if (TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getKeyStorePath())
&& TlsUtils.isKeyOrTrustStorePathNullOrHasFileScheme(tlsConfig.getTrustStorePath())) {
TlsUtils.enableAutoRenewalFromFileStoreForSSLFactory(sslFactory, tlsConfig);
}
SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(sslFactory.getKeyManagerFactory().get())
.sslProvider(SslProvider.valueOf(tlsConfig.getSslProvider()));
sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager);
if (tlsConfig.isClientAuthEnabled()) {
sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
}
return GrpcSslContexts.configure(sslContextBuilder).build();
} catch (Exception e) {
throw new RuntimeException("Failed to build gRPC SSL context", e);
}
});
return sslContext;
}

public void start() {
Expand Down
Loading