Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TCP TLS server SNI server name leak #5100

Merged
merged 3 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/main/java/io/vertx/core/net/impl/SSLHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ public SSLHelper(TCPSSLOptions options, List<String> applicationProtocols) {
this.applicationProtocols = applicationProtocols;
}

public synchronized int sniEntrySize() {
CachedProvider res = cachedProvider.result();
if (res != null) {
return res.sslChannelProvider.sniEntrySize();
}
return 0;
}

private static class CachedProvider {
final SSLOptions options;
final long id;
Expand Down
57 changes: 32 additions & 25 deletions src/main/java/io/vertx/core/net/impl/SslChannelProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.netty.handler.ssl.SslHandler;
import io.netty.util.AsyncMapping;
import io.netty.util.concurrent.ImmediateExecutor;
import io.vertx.core.VertxException;
import io.vertx.core.net.SocketAddress;

import javax.net.ssl.KeyManagerFactory;
Expand Down Expand Up @@ -64,6 +65,10 @@ public SslChannelProvider(SslContextProvider sslContextProvider,
this.sslContextProvider = sslContextProvider;
}

public int sniEntrySize() {
return sslContextMaps[0].size() + sslContextMaps[1].size();
}

public SslContextProvider sslContextProvider() {
return sslContextProvider;
}
Expand All @@ -73,20 +78,35 @@ public SslContext sslClientContext(String serverName, boolean useAlpn) {
}

public SslContext sslClientContext(String serverName, boolean useAlpn, boolean trustAll) {
try {
return sslContext(serverName, useAlpn, false, trustAll);
} catch (Exception e) {
throw new VertxException(e);
}
}

public SslContext sslContext(String serverName, boolean useAlpn, boolean server, boolean trustAll) throws Exception {
int idx = idx(useAlpn);
if (serverName != null) {
KeyManagerFactory kmf = sslContextProvider.resolveKeyManagerFactory(serverName);
TrustManager[] trustManagers = trustAll ? null : sslContextProvider.resolveTrustManagers(serverName);
if (kmf != null || trustManagers != null || !server) {
return sslContextMaps[idx].computeIfAbsent(serverName, s -> sslContextProvider.createContext(server, kmf, trustManagers, s, useAlpn, trustAll));
}
}
if (sslContexts[idx] == null) {
SslContext context = sslContextProvider.createClientContext(serverName, useAlpn, trustAll);
SslContext context = sslContextProvider.createContext(server, null, null, serverName, useAlpn, trustAll);
sslContexts[idx] = context;
}
return sslContexts[idx];
}

public SslContext sslServerContext(boolean useAlpn) {
int idx = idx(useAlpn);
if (sslContexts[idx] == null) {
sslContexts[idx] = sslContextProvider.createServerContext(useAlpn);
try {
return sslContext(null, useAlpn, true, false);
} catch (Exception e) {
throw new VertxException(e);
}
return sslContexts[idx];
}

/**
Expand All @@ -97,27 +117,14 @@ public SslContext sslServerContext(boolean useAlpn) {
public AsyncMapping<? super String, ? extends SslContext> serverNameMapping() {
return (AsyncMapping<String, SslContext>) (serverName, promise) -> {
workerPool.execute(() -> {
if (serverName == null) {
promise.setSuccess(sslServerContext(useAlpn));
} else {
KeyManagerFactory kmf;
try {
kmf = sslContextProvider.resolveKeyManagerFactory(serverName);
} catch (Exception e) {
promise.setFailure(e);
return;
}
TrustManager[] trustManagers;
try {
trustManagers = sslContextProvider.resolveTrustManagers(serverName);
} catch (Exception e) {
promise.setFailure(e);
return;
}
int idx = idx(useAlpn);
SslContext sslContext = sslContextMaps[idx].computeIfAbsent(serverName, s -> sslContextProvider.createServerContext(kmf, trustManagers, s, useAlpn));
promise.setSuccess(sslContext);
SslContext sslContext;
try {
sslContext = sslContext(serverName, useAlpn, true, false);
} catch (Exception e) {
promise.setFailure(e);
return;
}
promise.setSuccess(sslContext);
});
return promise;
};
Expand Down
111 changes: 63 additions & 48 deletions src/main/java/io/vertx/core/net/impl/SslContextProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,36 @@ public SslContextProvider(ClientAuth clientAuth,
this.crls = crls;
}

public VertxSslContext createClientContext(String serverName, boolean useAlpn, boolean trustAll) {
public VertxSslContext createContext(boolean server,
KeyManagerFactory keyManagerFactory,
TrustManager[] trustManagers,
String serverName,
boolean useAlpn,
boolean trustAll) {
if (keyManagerFactory == null) {
keyManagerFactory = defaultKeyManagerFactory();
}
if (trustAll) {
trustManagers = SslContextProvider.createTrustAllManager();
} else if (trustManagers == null) {
trustManagers = defaultTrustManagers();
}
if (server) {
return createServerContext(keyManagerFactory, trustManagers, serverName, useAlpn);
} else {
return createClientContext(keyManagerFactory, trustManagers, serverName, useAlpn);
}
}

public VertxSslContext createContext(boolean server, boolean useAlpn) {
return createContext(server, defaultKeyManagerFactory(), defaultTrustManagers(), null, useAlpn, false);
}

public VertxSslContext createClientContext(
KeyManagerFactory keyManagerFactory,
TrustManager[] trustManagers,
String serverName,
boolean useAlpn) {
try {
SslContextFactory factory = provider.get()
.useAlpn(useAlpn)
Expand All @@ -76,12 +105,6 @@ public VertxSslContext createClientContext(String serverName, boolean useAlpn, b
if (keyManagerFactory != null) {
factory.keyMananagerFactory(keyManagerFactory);
}
TrustManager[] trustManagers = null;
if (trustAll) {
trustManagers = new TrustManager[] { createTrustAllTrustManager() };
} else if (trustManagerFactory != null) {
trustManagers = trustManagerFactory.getTrustManagers();
}
if (trustManagers != null) {
TrustManagerFactory tmf = buildVertxTrustManagerFactory(trustManagers);
factory.trustManagerFactory(tmf);
Expand All @@ -98,10 +121,6 @@ protected void initEngine(SSLEngine engine) {
}
}

public VertxSslContext createServerContext(boolean useAlpn) {
return createServerContext(keyManagerFactory, trustManagerFactory != null ? trustManagerFactory.getTrustManagers() : null, null, useAlpn);
}

public VertxSslContext createServerContext(KeyManagerFactory keyManagerFactory,
TrustManager[] trustManagers,
String serverName,
Expand Down Expand Up @@ -135,16 +154,20 @@ protected void initEngine(SSLEngine engine) {
}
}

public KeyManagerFactory loadKeyManagerFactory(String serverName) throws Exception {
if (keyManagerFactoryMapper != null) {
return keyManagerFactoryMapper.apply(serverName);
}
return null;
public TrustManager[] defaultTrustManagers() {
return trustManagerFactory != null ? trustManagerFactory.getTrustManagers() : null;
}

public TrustManagerFactory defaultTrustManagerFactory() {
return trustManagerFactory;
}

public KeyManagerFactory defaultKeyManagerFactory() {
return keyManagerFactory;
}

/**
* Resolve the {@link KeyManagerFactory} for the {@code serverName}, when a factory cannot be resolved, the default
* factory is returned.
* Resolve the {@link KeyManagerFactory} for the {@code serverName}, when a factory cannot be resolved, {@code null} is returned.
* <br/>
* This can block and should be executed on the appropriate thread.
*
Expand All @@ -153,23 +176,14 @@ public KeyManagerFactory loadKeyManagerFactory(String serverName) throws Excepti
* @throws Exception anything that would prevent loading the factory
*/
public KeyManagerFactory resolveKeyManagerFactory(String serverName) throws Exception {
KeyManagerFactory kmf = loadKeyManagerFactory(serverName);
if (kmf == null) {
kmf = keyManagerFactory;
}
return kmf;
}

public TrustManager[] loadTrustManagers(String serverName) throws Exception {
if (trustManagerMapper != null) {
return trustManagerMapper.apply(serverName);
if (keyManagerFactoryMapper != null) {
return keyManagerFactoryMapper.apply(serverName);
}
return null;
}

/**
* Resolve the {@link TrustManager}[] for the {@code serverName}, when managers cannot be resolved, the default
* managers are returned.
* Resolve the {@link TrustManager}[] for the {@code serverName}, when managers cannot be resolved, {@code null} is returned.
* <br/>
* This can block and should be executed on the appropriate thread.
*
Expand All @@ -178,11 +192,10 @@ public TrustManager[] loadTrustManagers(String serverName) throws Exception {
* @throws Exception anything that would prevent loading the managers
*/
public TrustManager[] resolveTrustManagers(String serverName) throws Exception {
TrustManager[] trustManagers = loadTrustManagers(serverName);
if (trustManagers == null && trustManagerFactory != null) {
trustManagers = trustManagerFactory.getTrustManagers();
if (trustManagerMapper != null) {
return trustManagerMapper.apply(serverName);
}
return trustManagers;
return null;
}

private VertxTrustManagerFactory buildVertxTrustManagerFactory(TrustManager[] mgrs) {
Expand Down Expand Up @@ -232,22 +245,24 @@ public X509Certificate[] getAcceptedIssuers() {
return trustMgrs;
}

// Create a TrustManager which trusts everything
private static TrustManager createTrustAllTrustManager() {
return new X509TrustManager() {
@Override
public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException {
}
private static final TrustManager TRUST_ALL_MANAGER = new X509TrustManager() {
@Override
public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException {
}

@Override
public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException {
}
@Override
public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException {
}

@Override
public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[0];
}
};
@Override
public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[0];
}
};

// Create a TrustManager which trusts everything
private static TrustManager[] createTrustAllManager() {
return new TrustManager[] { TRUST_ALL_MANAGER };
}

public void configureEngine(SSLEngine engine, Set<String> enabledProtocols, String serverName, boolean client) {
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/io/vertx/core/net/impl/TCPServerBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ private GlobalTrafficShapingHandler createTrafficShapingHandler(EventLoopGroup e
return trafficShapingHandler;
}

public int sniEntrySize() {
return sslHelper.sniEntrySize();
}

public Future<Boolean> updateSSLOptions(SSLOptions options, boolean force) {
TCPServerBase server = actualServer;
if (server != null && server != this) {
Expand Down
37 changes: 27 additions & 10 deletions src/test/java/io/vertx/core/net/NetTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@
import io.vertx.core.impl.logging.LoggerFactory;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.core.net.impl.HAProxyMessageCompletionHandler;
import io.vertx.core.net.impl.NetServerImpl;
import io.vertx.core.net.impl.NetSocketInternal;
import io.vertx.core.net.impl.VertxHandler;
import io.vertx.core.net.impl.*;
import io.vertx.core.spi.tls.SslContextFactory;
import io.vertx.core.streams.ReadStream;
import io.vertx.test.core.CheckingSender;
Expand Down Expand Up @@ -95,12 +92,7 @@
import java.nio.charset.StandardCharsets;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
Expand Down Expand Up @@ -1532,6 +1524,31 @@ public void testSniOverrideServerName() throws Exception {
assertEquals("host2.com", cnOf(test.clientPeerCert()));
}

@Test
public void testClientSniMultipleServerName() throws Exception {
List<String> receivedServerNames = Collections.synchronizedList(new ArrayList<>());
server = vertx.createNetServer(new NetServerOptions()
.setSni(true)
.setSsl(true)
.setKeyCertOptions(Cert.SNI_JKS.get())
).connectHandler(so -> {
receivedServerNames.add(so.indicatedServerName());
});
startServer();
List<String> serverNames = Arrays.asList("host1", "host2.com", "fake");
List<String> cns = new ArrayList<>();
client = vertx.createNetClient(new NetClientOptions().setSsl(true).setTrustAll(true));
for (String serverName : serverNames) {
NetSocket so = client.connect(testAddress, serverName).toCompletionStage().toCompletableFuture().get();
String host = cnOf(so.peerCertificates().get(0));
cns.add(host);
}
assertEquals(Arrays.asList("host1", "host2.com", "localhost"), cns);
assertEquals(2, ((TCPServerBase)server).sniEntrySize());
assertWaitUntil(() -> receivedServerNames.size() == 3);
assertEquals(receivedServerNames, serverNames);
}

@Test
// SNI present an unknown server
public void testSniWithUnknownServer1() throws Exception {
Expand Down
8 changes: 4 additions & 4 deletions src/test/java/io/vertx/core/net/SSLHelperTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void testUseJdkCiphersWhenNotSpecified() throws Exception {
helper
.buildContextProvider(new SSLOptions().setKeyCertOptions(Cert.CLIENT_JKS.get()).setTrustOptions(Trust.SERVER_JKS.get()), (ContextInternal) vertx.getOrCreateContext())
.onComplete(onSuccess(provider -> {
SslContext ctx = provider.createClientContext(null, false, false);
SslContext ctx = provider.createContext(false, false);
assertEquals(new HashSet<>(Arrays.asList(expected)), new HashSet<>(ctx.cipherSuites()));
testComplete();
}));
Expand All @@ -60,7 +60,7 @@ public void testUseOpenSSLCiphersWhenNotSpecified() throws Exception {
new HttpClientOptions().setOpenSslEngineOptions(new OpenSSLEngineOptions()).setPemKeyCertOptions(Cert.CLIENT_PEM.get()).setTrustOptions(Trust.SERVER_PEM.get()),
null);
helper.buildContextProvider(new SSLOptions().setKeyCertOptions(Cert.CLIENT_PEM.get()).setTrustOptions(Trust.SERVER_PEM.get()), (ContextInternal) vertx.getOrCreateContext()).onComplete(onSuccess(provider -> {
SslContext ctx = provider.createClientContext(null, false, false);
SslContext ctx = provider.createContext(false, false);
assertEquals(expected, new HashSet<>(ctx.cipherSuites()));
testComplete();
}));
Expand Down Expand Up @@ -90,7 +90,7 @@ private void testOpenSslServerSessionContext(boolean testDefault){
defaultHelper
.buildContextProvider(httpServerOptions.getSslOptions(), (ContextInternal) vertx.getOrCreateContext())
.onComplete(onSuccess(provider -> {
SslContext ctx = provider.createServerContext(false);
SslContext ctx = provider.createContext(true, false);

SSLSessionContext sslSessionContext = ctx.sessionContext();
assertTrue(sslSessionContext instanceof OpenSslServerSessionContext);
Expand Down Expand Up @@ -185,6 +185,6 @@ private void testTLSVersions(HttpServerOptions options, Consumer<SSLEngine> chec
}

public SSLEngine createEngine(SslContextProvider provider) {
return provider.createClientContext(null, false, false).newEngine(ByteBufAllocator.DEFAULT);
return provider.createContext(false, false).newEngine(ByteBufAllocator.DEFAULT);
}
}
2 changes: 1 addition & 1 deletion src/test/java/io/vertx/it/SSLEngineTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ private void doTest(SSLEngineOptions engine,
}
}
SslContextProvider provider = ((HttpServerImpl)server).sslContextProvider();
SslContext ctx = provider.createClientContext(null, false, false);
SslContext ctx = provider.createContext(false, false);
switch (expectedSslContext != null ? expectedSslContext : "jdk") {
case "jdk":
assertTrue(ctx.sessionContext().getClass().getName().equals("sun.security.ssl.SSLSessionContextImpl"));
Expand Down