Skip to content

Commit

Permalink
Merge pull request from GHSA-g96c-x7rh-99r3
Browse files Browse the repository at this point in the history
* Add support for randomizing DNS Lookup source port

* Clarify purpose of lease

* Skip initial refresh

Previously, the pool was being refreshed immediately upon initialization. Now, the refresh waits until the `poolRefreshSeconds` duration has elapsed.

* Ensure thread safety, skip unused poller refreshes

* Add change log

* Restore location of local flag
  • Loading branch information
Dan Torrey committed Jul 5, 2023
1 parent b3a967b commit a101f4f
Show file tree
Hide file tree
Showing 8 changed files with 459 additions and 75 deletions.
2 changes: 2 additions & 0 deletions changelog/unreleased/ghsa-g96c-x7rh-99r3.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
type = "security"
message = "Fix insecure source port usage for DNS Lookup adapter queries. [GHSA-g96c-x7rh-99r3](https://github.com/Graylog2/graylog2-server/security/advisories/GHSA-g96c-x7rh-99r3)"
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import org.graylog2.indexer.retention.RetentionStrategyBindings;
import org.graylog2.indexer.rotation.RotationStrategyBindings;
import org.graylog2.inputs.transports.NettyTransportConfiguration;
import org.graylog2.lookup.adapters.dnslookup.DnsLookupAdapterConfiguration;
import org.graylog2.messageprocessors.MessageProcessorModule;
import org.graylog2.migrations.MigrationsModule;
import org.graylog2.notifications.Notification;
Expand Down Expand Up @@ -131,6 +132,7 @@ public class Server extends ServerBootstrap {
private final PrometheusExporterConfiguration prometheusExporterConfiguration = new PrometheusExporterConfiguration();
private final TLSProtocolsConfiguration tlsConfiguration = new TLSProtocolsConfiguration();
private final GeoIpProcessorConfig geoIpProcessorConfig = new GeoIpProcessorConfig();
private final DnsLookupAdapterConfiguration dnsLookupAdapterConfiguration = new DnsLookupAdapterConfiguration();

public Server() {
super("server", configuration);
Expand Down Expand Up @@ -211,7 +213,8 @@ protected List<Object> getCommandConfigurationBeans() {
jobSchedulerConfiguration,
prometheusExporterConfiguration,
tlsConfiguration,
geoIpProcessorConfig);
geoIpProcessorConfig,
dnsLookupAdapterConfiguration);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.graylog2.lookup.adapters.dnslookup.ADnsAnswer;
import org.graylog2.lookup.adapters.dnslookup.DnsAnswer;
import org.graylog2.lookup.adapters.dnslookup.DnsClient;
import org.graylog2.lookup.adapters.dnslookup.DnsLookupAdapterConfiguration;
import org.graylog2.lookup.adapters.dnslookup.DnsLookupType;
import org.graylog2.lookup.adapters.dnslookup.PtrDnsAnswer;
import org.graylog2.lookup.adapters.dnslookup.TxtDnsAnswer;
Expand Down Expand Up @@ -80,6 +81,7 @@ public class DnsLookupDataAdapter extends LookupDataAdapter {
private static final String TIMER_TEXT_LOOKUP = "textLookupTime";
private DnsClient dnsClient;
private final Config config;
private final DnsLookupAdapterConfiguration adapterConfiguration;

private final Counter errorCounter;

Expand All @@ -90,9 +92,11 @@ public class DnsLookupDataAdapter extends LookupDataAdapter {

@Inject
public DnsLookupDataAdapter(@Assisted("dto") DataAdapterDto dto,
MetricRegistry metricRegistry) {
MetricRegistry metricRegistry,
DnsLookupAdapterConfiguration adapterConfiguration) {
super(dto, metricRegistry);
this.config = (Config) dto.config();
this.adapterConfiguration = adapterConfiguration;
this.errorCounter = metricRegistry.counter(MetricRegistry.name(getClass(), dto.id(), ERROR_COUNTER));
this.resolveDomainNameTimer = metricRegistry.timer(MetricRegistry.name(getClass(), dto.id(), TIMER_RESOLVE_DOMAIN_NAME));
this.reverseLookupTimer = metricRegistry.timer(MetricRegistry.name(getClass(), dto.id(), TIMER_REVERSE_LOOKUP));
Expand All @@ -101,8 +105,8 @@ public DnsLookupDataAdapter(@Assisted("dto") DataAdapterDto dto,

@Override
protected void doStart() {

dnsClient = new DnsClient(config.requestTimeout());
dnsClient = new DnsClient(config.requestTimeout(), adapterConfiguration.getPoolSize(),
adapterConfiguration.getPoolRefreshInterval().toSeconds());
dnsClient.start(config.serverIps());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import com.google.common.net.InetAddresses;
import com.google.common.net.InternetDomainName;
import io.netty.buffer.ByteBuf;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.handler.codec.dns.DefaultDnsPtrRecord;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DefaultDnsRawRecord;
Expand All @@ -35,20 +33,15 @@
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.resolver.dns.DnsServerAddressStreamProvider;
import io.netty.resolver.dns.SequentialDnsServerAddressStreamProvider;
import io.netty.util.concurrent.Future;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.graylog2.lookup.adapters.dnslookup.DnsResolverPool.ResolverLease;
import org.graylog2.shared.utilities.ExceptionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -59,10 +52,11 @@
import java.util.concurrent.TimeoutException;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

public class DnsClient {
import static org.graylog2.lookup.adapters.dnslookup.DnsLookupAdapterConfiguration.DEFAULT_POOL_SIZE;
import static org.graylog2.lookup.adapters.dnslookup.DnsLookupAdapterConfiguration.DEFAULT_REFRESH_INTERVAL_SECONDS;

public class DnsClient {
private static final Logger LOG = LoggerFactory.getLogger(DnsClient.class);

private static final int DEFAULT_DNS_PORT = 53;
Expand All @@ -80,9 +74,9 @@ public class DnsClient {
private static final char[] HEX_CHARS_ARRAY = "0123456789ABCDEF".toCharArray();
private final long queryTimeout;
private final long requestTimeout;

private NioEventLoopGroup nettyEventLoop;
private DnsNameResolver resolver;
private final long resolverPoolSize;
private final long resolverPoolRefreshSeconds;
private DnsResolverPool resolverPool;

/**
* Creates a new DNS client with the given query timeout. The request timeout will be the query timeout plus
Expand All @@ -108,66 +102,33 @@ public DnsClient(long queryTimeout) {
* @param requestTimeout the request timeout
*/
public DnsClient(long queryTimeout, long requestTimeout) {
this(queryTimeout, requestTimeout, DEFAULT_POOL_SIZE, DEFAULT_REFRESH_INTERVAL_SECONDS);
}

public DnsClient(long queryTimeout, int resolverPoolSize, long resolverPoolRefreshSeconds) {
this(queryTimeout, queryTimeout + DEFAULT_REQUEST_TIMEOUT_INCREMENT, resolverPoolSize, resolverPoolRefreshSeconds);
}

private DnsClient(long queryTimeout, long requestTimeout, int resolverPoolSize, long resolverPoolRefreshSeconds) {
this.queryTimeout = queryTimeout;
this.requestTimeout = requestTimeout;
this.resolverPoolSize = resolverPoolSize;
this.resolverPoolRefreshSeconds = resolverPoolRefreshSeconds;
}

public void start(String dnsServerIps) {

LOG.debug("Attempting to start DNS client");
final List<InetSocketAddress> iNetDnsServerIps = parseServerIpAddresses(dnsServerIps);

nettyEventLoop = new NioEventLoopGroup();

final DnsNameResolverBuilder dnsNameResolverBuilder = new DnsNameResolverBuilder(nettyEventLoop.next());
dnsNameResolverBuilder.channelType(NioDatagramChannel.class).queryTimeoutMillis(queryTimeout);

// Specify custom DNS servers if provided. If not, use those specified in local network adapter settings.
if (CollectionUtils.isNotEmpty(iNetDnsServerIps)) {

LOG.debug("Attempting to start DNS client with server IPs [{}] on port [{}] with timeout [{}]",
dnsServerIps, DEFAULT_DNS_PORT, requestTimeout);

final DnsServerAddressStreamProvider dnsServer = new SequentialDnsServerAddressStreamProvider(iNetDnsServerIps);
dnsNameResolverBuilder.nameServerProvider(dnsServer);
} else {
LOG.debug("Attempting to start DNS client with the local network adapter DNS server address on port [{}] with timeout [{}]",
DEFAULT_DNS_PORT, requestTimeout);
}

resolver = dnsNameResolverBuilder.build();

LOG.debug("DNS client startup successful");
}

private List<InetSocketAddress> parseServerIpAddresses(String dnsServerIps) {

// Parse and prepare DNS server IP addresses for Netty.
return StreamSupport
// Split comma-separated sever IP:port combos.
.stream(Splitter.on(",").trimResults().omitEmptyStrings().split(dnsServerIps).spliterator(), false)
// Parse as HostAndPort objects (allows convenient handling of port provided after colon).
.map(hostAndPort -> HostAndPort.fromString(hostAndPort).withDefaultPort(DnsClient.DEFAULT_DNS_PORT))
// Convert HostAndPort > InetSocketAddress as required by Netty.
.map(hostAndPort -> new InetSocketAddress(hostAndPort.getHost(), hostAndPort.getPort()))
.collect(Collectors.toList());
this.resolverPool = new DnsResolverPool(dnsServerIps, queryTimeout, resolverPoolSize, resolverPoolRefreshSeconds);
this.resolverPool.initialize();
}

public void stop() {

LOG.debug("Attempting to stop DNS client");

if (nettyEventLoop == null) {
LOG.error("DNS resolution event loop not initialized");
if (resolverPool == null) {
LOG.error("DNS resolution pool is not initialized.");
return;
}

// Make sure to close the resolver before shutting down the event loop
resolver.close();

// Shutdown event loop (required by Netty).
final Future<?> shutdownFuture = nettyEventLoop.shutdownGracefully();
shutdownFuture.addListener(future -> LOG.debug("DNS client shutdown successful"));
resolverPool.stop();
}

public List<ADnsAnswer> resolveIPv4AddressForHostname(String hostName, boolean includeIpVersion)
Expand All @@ -187,24 +148,28 @@ private List<ADnsAnswer> resolveIpAddresses(String hostName, DnsRecordType dnsRe

LOG.debug("Attempting to resolve [{}] records for [{}]", dnsRecordType, hostName);

if (isShutdown()) {
if (resolverPool.isStopped()) {
throw new DnsClientNotRunningException();
}

validateHostName(hostName);

final DefaultDnsQuestion aRecordDnsQuestion = new DefaultDnsQuestion(hostName, dnsRecordType);

final ResolverLease resolverLease = resolverPool.takeLease();
/* The DnsNameResolver.resolveAll(DnsQuestion) method handles all redirects through CNAME records to
* ultimately resolve a list of IP addresses with TTL values. */
try {
return resolver.resolveAll(aRecordDnsQuestion).get(requestTimeout, TimeUnit.MILLISECONDS).stream()
return resolverLease.getResolver().resolveAll(aRecordDnsQuestion).get(requestTimeout, TimeUnit.MILLISECONDS).stream()
.map(dnsRecord -> decodeDnsRecord(dnsRecord, includeIpVersion))
.filter(Objects::nonNull) // Removes any entries which the IP address could not be extracted for.
.collect(Collectors.toList());
} catch (TimeoutException e) {
throw new ExecutionException("Resolver future didn't return a result in " + requestTimeout + " ms", e);
}
finally {
resolverPool.returnLease(resolverLease);
}
}

/**
Expand Down Expand Up @@ -262,7 +227,7 @@ public PtrDnsAnswer reverseLookup(String ipAddress) throws InterruptedException,

LOG.debug("Attempting to perform reverse lookup for IP address [{}]", ipAddress);

if (isShutdown()) {
if (resolverPool.isStopped()) {
throw new DnsClientNotRunningException();
}

Expand All @@ -271,8 +236,9 @@ public PtrDnsAnswer reverseLookup(String ipAddress) throws InterruptedException,
final String inverseAddressFormat = getInverseAddressFormat(ipAddress);

DnsResponse content = null;
final ResolverLease resolverLease = resolverPool.takeLease();
try {
content = resolver.query(new DefaultDnsQuestion(inverseAddressFormat, DnsRecordType.PTR)).get(requestTimeout, TimeUnit.MILLISECONDS).content();
content = resolverLease.getResolver().query(new DefaultDnsQuestion(inverseAddressFormat, DnsRecordType.PTR)).get(requestTimeout, TimeUnit.MILLISECONDS).content();
for (int i = 0; i < content.count(DnsSection.ANSWER); i++) {

// Return the first PTR record, because there should be only one as per
Expand Down Expand Up @@ -306,6 +272,7 @@ public PtrDnsAnswer reverseLookup(String ipAddress) throws InterruptedException,
// Must manually release references on content object since the DnsResponse class extends ReferenceCounted
content.release();
}
resolverPool.returnLease(resolverLease);
}

return null;
Expand Down Expand Up @@ -348,7 +315,7 @@ public static void parseReverseLookupDomain(PtrDnsAnswer.Builder dnsAnswerBuilde

public List<TxtDnsAnswer> txtLookup(String hostName) throws InterruptedException, ExecutionException {

if (isShutdown()) {
if (resolverPool.isStopped()) {
throw new DnsClientNotRunningException();
}

Expand All @@ -357,8 +324,9 @@ public List<TxtDnsAnswer> txtLookup(String hostName) throws InterruptedException
validateHostName(hostName);

DnsResponse content = null;
final ResolverLease resolverLease = resolverPool.takeLease();
try {
content = resolver.query(new DefaultDnsQuestion(hostName, DnsRecordType.TXT)).get(requestTimeout, TimeUnit.MILLISECONDS).content();
content = resolverLease.getResolver().query(new DefaultDnsQuestion(hostName, DnsRecordType.TXT)).get(requestTimeout, TimeUnit.MILLISECONDS).content();
int count = content.count(DnsSection.ANSWER);
final ArrayList<TxtDnsAnswer> txtRecords = new ArrayList<>(count);
for (int i = 0; i < count; i++) {
Expand Down Expand Up @@ -389,13 +357,10 @@ public List<TxtDnsAnswer> txtLookup(String hostName) throws InterruptedException
// Must manually release references on content object since the DnsResponse class extends ReferenceCounted
content.release();
}
resolverPool.returnLease(resolverLease);
}
}

private boolean isShutdown() {
return nettyEventLoop == null || nettyEventLoop.isShutdown();
}

private static String decodeTxtRecord(DefaultDnsRawRecord record) {

LOG.debug("Attempting to read TXT value from DNS record [{}]", record);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (C) 2020 Graylog, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the Server Side Public License, version 1,
* as published by MongoDB, Inc.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Server Side Public License for more details.
*
* You should have received a copy of the Server Side Public License
* along with this program. If not, see
* <http://www.mongodb.com/licensing/server-side-public-license>.
*/
package org.graylog2.lookup.adapters.dnslookup;

import com.github.joschi.jadconfig.Parameter;
import com.github.joschi.jadconfig.util.Duration;
import com.github.joschi.jadconfig.validators.PositiveDurationValidator;
import com.github.joschi.jadconfig.validators.PositiveIntegerValidator;
import org.graylog2.plugin.PluginConfigBean;

public class DnsLookupAdapterConfiguration implements PluginConfigBean {
private static final String PREFIX = "dns_lookup_adapter_";
protected static final String RESOLVER_POOL_SIZE = PREFIX + "resolver_pool_size";
protected static final String RESOLVER_POOL_REFRESH_INTERVAL = PREFIX + "resolver_pool_refresh_interval";

protected static final int DEFAULT_POOL_SIZE = 10;
protected static final int DEFAULT_REFRESH_INTERVAL_SECONDS = 300;

@Parameter(value = RESOLVER_POOL_SIZE, validators = PositiveIntegerValidator.class)
private int poolSize = DEFAULT_POOL_SIZE;

@Parameter(value = RESOLVER_POOL_REFRESH_INTERVAL, validators = PositiveDurationValidator.class)
private Duration poolRefreshInterval = Duration.seconds(DEFAULT_REFRESH_INTERVAL_SECONDS);

public int getPoolSize() {
return poolSize;
}

public Duration getPoolRefreshInterval() {
return poolRefreshInterval;
}
}
Loading

0 comments on commit a101f4f

Please sign in to comment.