Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Proxy/Client] Fix DNS server denial-of-service issue when DNS entry expires #15403

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
import com.google.common.collect.Lists;
import io.netty.channel.EventLoopGroup;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;
import org.apache.pulsar.broker.auth.MockedPulsarServiceBaseTest;
import org.apache.pulsar.client.impl.conf.ClientConfigurationData;
import org.apache.pulsar.common.util.netty.EventLoopUtil;
Expand All @@ -31,22 +35,18 @@
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;

@Test(groups = "broker-impl")
public class ConnectionPoolTest extends MockedPulsarServiceBaseTest {

String serviceUrl;
int brokerPort;

@BeforeClass
@Override
protected void setup() throws Exception {
super.internalSetup();
serviceUrl = "pulsar://non-existing-dns-name:" + pulsar.getBrokerListenPort().get();
brokerPort = pulsar.getBrokerListenPort().get();
serviceUrl = "pulsar://non-existing-dns-name:" + brokerPort;
}

@AfterClass(alwaysRun = true)
Expand All @@ -63,9 +63,11 @@ public void testSingleIpAddress() throws Exception {
conf.setServiceUrl(serviceUrl);
PulsarClientImpl client = new PulsarClientImpl(conf, eventLoop, pool);

List<InetAddress> result = Lists.newArrayList();
result.add(InetAddress.getByName("127.0.0.1"));
Mockito.when(pool.resolveName("non-existing-dns-name")).thenReturn(CompletableFuture.completedFuture(result));
List<InetSocketAddress> result = Lists.newArrayList();
result.add(new InetSocketAddress("127.0.0.1", brokerPort));
Mockito.when(pool.resolveName(InetSocketAddress.createUnresolved("non-existing-dns-name",
brokerPort)))
.thenReturn(CompletableFuture.completedFuture(result));

client.newProducer().topic("persistent://sample/standalone/ns/my-topic").create();

Expand All @@ -75,20 +77,20 @@ public void testSingleIpAddress() throws Exception {

@Test
public void testDoubleIpAddress() throws Exception {
String serviceUrl = "pulsar://non-existing-dns-name:" + pulsar.getBrokerListenPort().get();

ClientConfigurationData conf = new ClientConfigurationData();
EventLoopGroup eventLoop = EventLoopUtil.newEventLoopGroup(1, false, new DefaultThreadFactory("test"));
ConnectionPool pool = spyWithClassAndConstructorArgs(ConnectionPool.class, conf, eventLoop);
conf.setServiceUrl(serviceUrl);
PulsarClientImpl client = new PulsarClientImpl(conf, eventLoop, pool);

List<InetAddress> result = Lists.newArrayList();
List<InetSocketAddress> result = Lists.newArrayList();

// Add a non existent IP to the response to check that we're trying the 2nd address as well
result.add(InetAddress.getByName("127.0.0.99"));
result.add(InetAddress.getByName("127.0.0.1"));
Mockito.when(pool.resolveName("non-existing-dns-name")).thenReturn(CompletableFuture.completedFuture(result));
result.add(new InetSocketAddress("127.0.0.99", brokerPort));
result.add(new InetSocketAddress("127.0.0.1", brokerPort));
Mockito.when(pool.resolveName(InetSocketAddress.createUnresolved("non-existing-dns-name",
brokerPort)))
.thenReturn(CompletableFuture.completedFuture(result));

// Create producer should succeed by trying the 2nd IP
client.newProducer().topic("persistent://sample/standalone/ns/my-topic").create();
Expand All @@ -105,7 +107,7 @@ public void testNoConnectionPool() throws Exception {
ConnectionPool pool = spyWithClassAndConstructorArgs(ConnectionPool.class, conf, eventLoop);

InetSocketAddress brokerAddress =
InetSocketAddress.createUnresolved("127.0.0.1", pulsar.getBrokerListenPort().get());
InetSocketAddress.createUnresolved("127.0.0.1", brokerPort);
IntStream.range(1, 5).forEach(i -> {
pool.getConnection(brokerAddress).thenAccept(cnx -> {
Assert.assertTrue(cnx.channel().isActive());
Expand All @@ -127,7 +129,7 @@ public void testEnableConnectionPool() throws Exception {
ConnectionPool pool = spyWithClassAndConstructorArgs(ConnectionPool.class, conf, eventLoop);

InetSocketAddress brokerAddress =
InetSocketAddress.createUnresolved("127.0.0.1", pulsar.getBrokerListenPort().get());
InetSocketAddress.createUnresolved("127.0.0.1", brokerPort);
IntStream.range(1, 10).forEach(i -> {
pool.getConnection(brokerAddress).thenAccept(cnx -> {
Assert.assertTrue(cnx.channel().isActive());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.AddressResolver;
import io.netty.resolver.dns.DnsAddressResolverGroup;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.util.concurrent.Future;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -63,7 +63,7 @@ public class ConnectionPool implements AutoCloseable {
private final int maxConnectionsPerHosts;
private final boolean isSniProxy;

protected final DnsNameResolver dnsResolver;
protected final AddressResolver<InetSocketAddress> addressResolver;
private final boolean shouldCloseDnsResolver;

public ConnectionPool(ClientConfigurationData conf, EventLoopGroup eventLoopGroup) throws PulsarClientException {
Expand All @@ -76,7 +76,8 @@ public ConnectionPool(ClientConfigurationData conf, EventLoopGroup eventLoopGrou
}

public ConnectionPool(ClientConfigurationData conf, EventLoopGroup eventLoopGroup,
Supplier<ClientCnx> clientCnxSupplier, Optional<DnsNameResolver> dnsNameResolver)
Supplier<ClientCnx> clientCnxSupplier,
Optional<AddressResolver<InetSocketAddress>> addressResolver)
throws PulsarClientException {
this.eventLoopGroup = eventLoopGroup;
this.clientConfig = conf;
Expand All @@ -101,20 +102,24 @@ public ConnectionPool(ClientConfigurationData conf, EventLoopGroup eventLoopGrou
throw new PulsarClientException(e);
}

this.shouldCloseDnsResolver = !dnsNameResolver.isPresent();
this.dnsResolver = dnsNameResolver.orElseGet(() -> createDnsNameResolver(conf, eventLoopGroup));
this.shouldCloseDnsResolver = !addressResolver.isPresent();
this.addressResolver = addressResolver.orElseGet(() -> createAddressResolver(conf, eventLoopGroup));
}

private static DnsNameResolver createDnsNameResolver(ClientConfigurationData conf, EventLoopGroup eventLoopGroup) {
DnsNameResolverBuilder dnsNameResolverBuilder = new DnsNameResolverBuilder(eventLoopGroup.next())
private static AddressResolver<InetSocketAddress> createAddressResolver(ClientConfigurationData conf,
EventLoopGroup eventLoopGroup) {
DnsNameResolverBuilder dnsNameResolverBuilder = new DnsNameResolverBuilder()
.traceEnabled(true).channelType(EventLoopUtil.getDatagramChannelClass(eventLoopGroup));
if (conf.getDnsLookupBindAddress() != null) {
InetSocketAddress addr = new InetSocketAddress(conf.getDnsLookupBindAddress(),
conf.getDnsLookupBindPort());
dnsNameResolverBuilder.localAddress(addr);
}
DnsResolverUtil.applyJdkDnsCacheSettings(dnsNameResolverBuilder);
return dnsNameResolverBuilder.build();
// use DnsAddressResolverGroup to create the AddressResolver since it contains a solution
// to prevent cache stampede / thundering herds problem when a DNS entry expires while the system
// is under high load
return new DnsAddressResolverGroup(dnsNameResolverBuilder).getResolver(eventLoopGroup.next());
}

private static final Random random = new Random();
Expand Down Expand Up @@ -239,19 +244,17 @@ private CompletableFuture<ClientCnx> createConnection(InetSocketAddress logicalA
* Resolve DNS asynchronously and attempt to connect to any IP address returned by DNS server.
*/
private CompletableFuture<Channel> createConnection(InetSocketAddress unresolvedAddress) {
int port;
CompletableFuture<List<InetAddress>> resolvedAddress;
CompletableFuture<List<InetSocketAddress>> resolvedAddress;
try {
if (isSniProxy) {
URI proxyURI = new URI(clientConfig.getProxyServiceUrl());
port = proxyURI.getPort();
resolvedAddress = resolveName(proxyURI.getHost());
resolvedAddress =
resolveName(InetSocketAddress.createUnresolved(proxyURI.getHost(), proxyURI.getPort()));
} else {
port = unresolvedAddress.getPort();
resolvedAddress = resolveName(unresolvedAddress.getHostString());
resolvedAddress = resolveName(unresolvedAddress);
}
return resolvedAddress.thenCompose(
inetAddresses -> connectToResolvedAddresses(inetAddresses.iterator(), port,
inetAddresses -> connectToResolvedAddresses(inetAddresses.iterator(),
isSniProxy ? unresolvedAddress : null));
} catch (URISyntaxException e) {
log.error("Invalid Proxy url {}", clientConfig.getProxyServiceUrl(), e);
Expand All @@ -264,18 +267,17 @@ private CompletableFuture<Channel> createConnection(InetSocketAddress unresolved
* Try to connect to a sequence of IP addresses until a successful connection can be made, or fail if no
* address is working.
*/
private CompletableFuture<Channel> connectToResolvedAddresses(Iterator<InetAddress> unresolvedAddresses,
int port,
private CompletableFuture<Channel> connectToResolvedAddresses(Iterator<InetSocketAddress> unresolvedAddresses,
InetSocketAddress sniHost) {
CompletableFuture<Channel> future = new CompletableFuture<>();

// Successfully connected to server
connectToAddress(unresolvedAddresses.next(), port, sniHost)
connectToAddress(unresolvedAddresses.next(), sniHost)
.thenAccept(future::complete)
.exceptionally(exception -> {
if (unresolvedAddresses.hasNext()) {
// Try next IP address
connectToResolvedAddresses(unresolvedAddresses, port, sniHost).thenAccept(future::complete)
connectToResolvedAddresses(unresolvedAddresses, sniHost).thenAccept(future::complete)
.exceptionally(ex -> {
// This is already unwinding the recursive call
future.completeExceptionally(ex);
Expand All @@ -291,9 +293,9 @@ private CompletableFuture<Channel> connectToResolvedAddresses(Iterator<InetAddre
return future;
}

CompletableFuture<List<InetAddress>> resolveName(String hostname) {
CompletableFuture<List<InetAddress>> future = new CompletableFuture<>();
dnsResolver.resolveAll(hostname).addListener((Future<List<InetAddress>> resolveFuture) -> {
CompletableFuture<List<InetSocketAddress>> resolveName(InetSocketAddress unresolvedAddress) {
CompletableFuture<List<InetSocketAddress>> future = new CompletableFuture<>();
addressResolver.resolveAll(unresolvedAddress).addListener((Future<List<InetSocketAddress>> resolveFuture) -> {
if (resolveFuture.isSuccess()) {
future.complete(resolveFuture.get());
} else {
Expand All @@ -306,8 +308,7 @@ CompletableFuture<List<InetAddress>> resolveName(String hostname) {
/**
* Attempt to establish a TCP connection to an already resolved single IP address.
*/
private CompletableFuture<Channel> connectToAddress(InetAddress ipAddress, int port, InetSocketAddress sniHost) {
InetSocketAddress remoteAddress = new InetSocketAddress(ipAddress, port);
private CompletableFuture<Channel> connectToAddress(InetSocketAddress remoteAddress, InetSocketAddress sniHost) {
if (clientConfig.isUseTls()) {
return toCompletableFuture(bootstrap.register())
.thenCompose(channel -> channelInitializerHandler
Expand Down Expand Up @@ -337,7 +338,7 @@ public void releaseConnection(ClientCnx cnx) {
public void close() throws Exception {
closeAllConnections();
if (shouldCloseDnsResolver) {
dnsResolver.close();
addressResolver.close();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.ssl.SslHandler;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsAddressResolverGroup;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collections;
Expand Down Expand Up @@ -77,7 +77,7 @@ public class ProxyConnection extends PulsarHandler {
private final AtomicLong requestIdGenerator =
new AtomicLong(ThreadLocalRandom.current().nextLong(0, Long.MAX_VALUE / 2));
private final ProxyService service;
private final DnsNameResolver dnsNameResolver;
private final DnsAddressResolverGroup dnsAddressResolverGroup;
AuthenticationDataSource authenticationData;
private State state;
private final Supplier<SslHandler> sslHandlerSupplier;
Expand Down Expand Up @@ -130,10 +130,10 @@ ConnectionPool getConnectionPool() {
}

public ProxyConnection(ProxyService proxyService, Supplier<SslHandler> sslHandlerSupplier,
DnsNameResolver dnsNameResolver) {
DnsAddressResolverGroup dnsAddressResolverGroup) {
super(30, TimeUnit.SECONDS);
this.service = proxyService;
this.dnsNameResolver = dnsNameResolver;
this.dnsAddressResolverGroup = dnsAddressResolverGroup;
this.state = State.Init;
this.sslHandlerSupplier = sslHandlerSupplier;
this.brokerProxyValidator = service.getBrokerProxyValidator();
Expand Down Expand Up @@ -276,7 +276,8 @@ private synchronized void completeConnect(AuthData clientData) throws PulsarClie

if (this.connectionPool == null) {
this.connectionPool = new ConnectionPool(clientConf, service.getWorkerGroup(),
clientCnxSupplier, Optional.of(dnsNameResolver));
clientCnxSupplier,
Optional.of(dnsAddressResolverGroup.getResolver(service.getWorkerGroup().next())));
} else {
LOG.error("BUG! Connection Pool has already been created for proxy connection to {} state {} role {}",
remoteAddress, state, clientAuthRole);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsAddressResolverGroup;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
import io.prometheus.client.Counter;
Expand Down Expand Up @@ -80,7 +80,7 @@ public class ProxyService implements Closeable {
private final ProxyConfiguration proxyConfig;
private final Authentication proxyClientAuthentication;
@Getter
private final DnsNameResolver dnsNameResolver;
private final DnsAddressResolverGroup dnsAddressResolverGroup;
@Getter
private final BrokerProxyValidator brokerProxyValidator;
private String serviceUrl;
Expand Down Expand Up @@ -162,13 +162,13 @@ public ProxyService(ProxyConfiguration proxyConfig,
false, workersThreadFactory);
this.authenticationService = authenticationService;

DnsNameResolverBuilder dnsNameResolverBuilder = new DnsNameResolverBuilder(workerGroup.next())
DnsNameResolverBuilder dnsNameResolverBuilder = new DnsNameResolverBuilder()
.channelType(EventLoopUtil.getDatagramChannelClass(workerGroup));
DnsResolverUtil.applyJdkDnsCacheSettings(dnsNameResolverBuilder);

dnsNameResolver = dnsNameResolverBuilder.build();
dnsAddressResolverGroup = new DnsAddressResolverGroup(dnsNameResolverBuilder);

brokerProxyValidator = new BrokerProxyValidator(dnsNameResolver.asAddressResolver(),
brokerProxyValidator = new BrokerProxyValidator(dnsAddressResolverGroup.getResolver(workerGroup.next()),
proxyConfig.getBrokerProxyAllowedHostNames(),
proxyConfig.getBrokerProxyAllowedIPAddresses(),
proxyConfig.getBrokerProxyAllowedTargetPorts());
Expand Down Expand Up @@ -331,7 +331,7 @@ public BrokerDiscoveryProvider getDiscoveryProvider() {
}

public void close() throws IOException {
dnsNameResolver.close();
dnsAddressResolverGroup.close();

if (discoveryProvider != null) {
discoveryProvider.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ public SslHandler get() {
}

ch.pipeline().addLast("handler",
new ProxyConnection(proxyService, sslHandlerSupplier, proxyService.getDnsNameResolver()));
new ProxyConnection(proxyService, sslHandlerSupplier, proxyService.getDnsAddressResolverGroup()));

}
}