diff --git a/pdns/README-dnsdist.md b/pdns/README-dnsdist.md index b9521c2984ee..99e23fcb9022 100644 --- a/pdns/README-dnsdist.md +++ b/pdns/README-dnsdist.md @@ -1488,6 +1488,9 @@ instantiate a server with additional parameters * member `add(DNSName)`: add this DNSName to the node * Tuning related: * `setMaxTCPClientThreads(n)`: set the maximum of TCP client threads, handling TCP connections + * `setMaxTCPConnectionDuration(n)`: set the maximum duration of an incoming TCP connection, in seconds. 0 (the default) means unlimited + * `setMaxTCPConnectionsPerClient(n)`: set the maximum number of TCP connections per client. 0 (the default) means unlimited + * `setMaxTCPQueriesPerConnection(n)`: set the maximum number of queries in an incoming TCP connection. 0 (the default) means unlimited * `setMaxTCPQueuedConnections(n)`: set the maximum number of TCP connections queued (waiting to be picked up by a client thread), defaults to 1000. 0 means unlimited * `setMaxUDPOutstanding(n)`: set the maximum number of outstanding UDP queries to a given backend server. This can only be set at configuration time and defaults to 10240 * `setCacheCleaningDelay(n)`: set the interval in seconds between two runs of the cache cleaning algorithm, removing expired entries diff --git a/pdns/dnsdist-console.cc b/pdns/dnsdist-console.cc index 45052f08db49..6a1c7fc39c1a 100644 --- a/pdns/dnsdist-console.cc +++ b/pdns/dnsdist-console.cc @@ -332,6 +332,9 @@ const std::vector g_consoleKeywords{ { "setKey", true, "key", "set access key to that key" }, { "setLocal", true, "netmask, [true], [false], [TCP Fast Open queue size]", "reset list of addresses we listen on to this address. Second optional parameter sets TCP or not. Third optional parameter sets SO_REUSEPORT when available. Last parameter sets the TCP Fast Open queue size, enabling TCP Fast Open when available and the value is larger than 0." }, { "setMaxTCPClientThreads", true, "n", "set the maximum of TCP client threads, handling TCP connections" }, + { "setMaxTCPConnectionDuration", true, "n", "set the maximum duration of an incoming TCP connection, in seconds. 0 means unlimited" }, + { "setMaxTCPConnectionsPerClient", true, "n", "set the maximum number of TCP connections per client. 0 means unlimited" }, + { "setMaxTCPQueriesPerConnection", true, "n", "set the maximum number of queries in an incoming TCP connection. 0 means unlimited" }, { "setMaxTCPQueuedConnections", true, "n", "set the maximum number of TCP connections queued (waiting to be picked up by a client thread)" }, { "setMaxUDPOutstanding", true, "n", "set the maximum number of outstanding UDP queries to a given backend server. This can only be set at configuration time and defaults to 10240" }, { "setQueryCount", true, "bool", "set whether queries should be counted" }, diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index 3fc64c9b5018..38a075796c37 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -1520,6 +1520,30 @@ vector> setupLua(bool client, const std::string& confi } }); + g_lua.writeFunction("setMaxTCPQueriesPerConnection", [](size_t max) { + if (!g_configurationDone) { + g_maxTCPQueriesPerConn = max; + } else { + g_outputBuffer="The maximum number of queries per TCP connection cannot be altered at runtime!\n"; + } + }); + + g_lua.writeFunction("setMaxTCPConnectionsPerClient", [](size_t max) { + if (!g_configurationDone) { + g_maxTCPConnectionsPerClient = max; + } else { + g_outputBuffer="The maximum number of TCP connection per client cannot be altered at runtime!\n"; + } + }); + + g_lua.writeFunction("setMaxTCPConnectionDuration", [](size_t max) { + if (!g_configurationDone) { + g_maxTCPConnectionDuration = max; + } else { + g_outputBuffer="The maximum duration of a TCP connection cannot be altered at runtime!\n"; + } + }); + g_lua.writeFunction("showTCPStats", [] { setLuaNoSideEffect(); boost::format fmt("%-10d %-10d %-10d %-10d\n"); diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 4b2223680119..15c400d4f4df 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -68,8 +68,25 @@ struct ConnectionInfo }; uint64_t g_maxTCPQueuedConnections{1000}; +size_t g_maxTCPQueriesPerConn{0}; +size_t g_maxTCPConnectionDuration{0}; +size_t g_maxTCPConnectionsPerClient{0}; +static std::mutex tcpClientsCountMutex; +static std::map tcpClientsCount; + void* tcpClientThread(int pipefd); +static void decrementTCPClientCount(const ComboAddress& client) +{ + if (g_maxTCPConnectionsPerClient) { + std::lock_guard lock(tcpClientsCountMutex); + tcpClientsCount[client]--; + if (tcpClientsCount[client] == 0) { + tcpClientsCount.erase(client); + } + } +} + // Should not be called simultaneously! void TCPClientCollection::addTCPClientThread() { @@ -144,6 +161,18 @@ static bool sendResponseToClient(int fd, const char* response, uint16_t response return true; } +static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime) +{ + if (maxConnectionDuration) { + time_t elapsed = time(NULL) - start; + if (elapsed >= maxConnectionDuration) { + return true; + } + remainingTime = maxConnectionDuration - elapsed; + } + return false; +} + std::shared_ptr g_tcpclientthreads; void* tcpClientThread(int pipefd) @@ -194,6 +223,9 @@ void* tcpClientThread(int pipefd) memset(&dest, 0, sizeof(dest)); dest.sin4.sin_family = ci.remote.sin4.sin_family; socklen_t len = dest.getSocklen(); + size_t queriesCount = 0; + time_t connectionStartTime = time(NULL); + if (!setNonBlocking(ci.fd)) goto drop; @@ -203,6 +235,7 @@ void* tcpClientThread(int pipefd) try { for(;;) { + unsigned int remainingTime = 0; ds = nullptr; outstanding = false; @@ -212,6 +245,18 @@ void* tcpClientThread(int pipefd) ci.cs->queries++; g_stats.queries++; + queriesCount++; + + if (g_maxTCPQueriesPerConn && queriesCount > g_maxTCPQueriesPerConn) { + vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", ci.remote.toStringWithPort(), queriesCount, g_maxTCPQueriesPerConn); + break; + } + + if (maxConnectionDurationReached(g_maxTCPConnectionDuration, connectionStartTime, remainingTime)) { + vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", ci.remote.toStringWithPort()); + break; + } + if (qlen < sizeof(dnsheader)) { g_stats.nonCompliantQueries++; break; @@ -225,7 +270,7 @@ void* tcpClientThread(int pipefd) size_t querySize = qlen <= 4096 ? qlen + 512 : qlen; char queryBuffer[querySize]; const char* query = queryBuffer; - readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout); + readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout, remainingTime); #ifdef HAVE_DNSCRYPT std::shared_ptr dnsCryptQuery = 0; @@ -519,18 +564,18 @@ void* tcpClientThread(int pipefd) outstanding = false; --ds->outstanding; } + decrementTCPClientCount(ci.remote); } return 0; } - /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and they will hand off to worker threads & spawn more of them if required */ void* tcpAcceptorThread(void* p) { ClientState* cs = (ClientState*) p; - + bool tcpClientCountIncremented = false; ComboAddress remote; remote.sin4.sin_family = cs->local.sin4.sin_family; @@ -540,6 +585,7 @@ void* tcpAcceptorThread(void* p) for(;;) { bool queuedCounterIncremented = false; ConnectionInfo* ci = nullptr; + tcpClientCountIncremented = false; try { ci = new ConnectionInfo; ci->cs = cs; @@ -563,8 +609,22 @@ void* tcpAcceptorThread(void* p) continue; } + if (g_maxTCPConnectionsPerClient) { + std::lock_guard lock(tcpClientsCountMutex); + + if (tcpClientsCount[remote] >= g_maxTCPConnectionsPerClient) { + close(ci->fd); + delete ci; + ci=nullptr; + vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort()); + continue; + } + tcpClientsCount[remote]++; + tcpClientCountIncremented = true; + } + vinfolog("Got TCP connection from %s", remote.toStringWithPort()); - + ci->remote = remote; int pipe = g_tcpclientthreads->getThread(); if (pipe >= 0) { @@ -577,12 +637,18 @@ void* tcpAcceptorThread(void* p) close(ci->fd); delete ci; ci=nullptr; + if(tcpClientCountIncremented) { + decrementTCPClientCount(remote); + } } } catch(std::exception& e) { errlog("While reading a TCP question: %s", e.what()); if(ci && ci->fd >= 0) close(ci->fd); + if(tcpClientCountIncremented) { + decrementTCPClientCount(remote); + } delete ci; ci = nullptr; if (queuedCounterIncremented) { diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index a8d877fbab9b..f4a4fcad2163 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -688,6 +688,9 @@ extern uint16_t g_maxOutstanding; extern std::atomic g_configurationDone; extern uint64_t g_maxTCPClientThreads; extern uint64_t g_maxTCPQueuedConnections; +extern size_t g_maxTCPQueriesPerConn; +extern size_t g_maxTCPConnectionDuration; +extern size_t g_maxTCPConnectionsPerClient; extern std::atomic g_cacheCleaningDelay; extern bool g_verboseHealthChecks; extern uint32_t g_staleCacheEntriesTTL; diff --git a/pdns/misc.cc b/pdns/misc.cc index 038500276cc0..2eca97d7d5e1 100644 --- a/pdns/misc.cc +++ b/pdns/misc.cc @@ -106,9 +106,15 @@ size_t readn2(int fd, void* buffer, size_t len) return len; } -size_t readn2WithTimeout(int fd, void* buffer, size_t len, int timeout) +size_t readn2WithTimeout(int fd, void* buffer, size_t len, int idleTimeout, int totalTimeout) { size_t pos = 0; + time_t start = 0; + int remainingTime = totalTimeout; + if (totalTimeout) { + start = time(NULL); + } + do { ssize_t got = read(fd, (char *)buffer + pos, len - pos); if (got > 0) { @@ -119,7 +125,7 @@ size_t readn2WithTimeout(int fd, void* buffer, size_t len, int timeout) } else { if (errno == EAGAIN) { - int res = waitForData(fd, timeout); + int res = waitForData(fd, (totalTimeout == 0 || idleTimeout <= remainingTime) ? idleTimeout : remainingTime); if (res > 0) { /* there is data available */ } @@ -133,6 +139,16 @@ size_t readn2WithTimeout(int fd, void* buffer, size_t len, int timeout) unixDie("failed in readn2WithTimeout"); } } + + if (totalTimeout) { + time_t now = time(NULL); + int elapsed = now - start; + if (elapsed >= remainingTime) { + throw runtime_error("Timeout while reading data"); + } + start = now; + remainingTime -= elapsed; + } } while (pos < len); diff --git a/pdns/misc.hh b/pdns/misc.hh index fa3104d55dd4..50caa8bf80c9 100644 --- a/pdns/misc.hh +++ b/pdns/misc.hh @@ -147,7 +147,7 @@ vstringtok (Container &container, string const &in, size_t writen2(int fd, const void *buf, size_t count); inline size_t writen2(int fd, const std::string &s) { return writen2(fd, s.data(), s.size()); } size_t readn2(int fd, void* buffer, size_t len); -size_t readn2WithTimeout(int fd, void* buffer, size_t len, int timeout); +size_t readn2WithTimeout(int fd, void* buffer, size_t len, int idleTimeout, int totalTimeout=0); size_t writen2WithTimeout(int fd, const void * buffer, size_t len, int timeout); const string toLower(const string &upper); diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index d04c247803db..f62aed9baaa3 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -235,42 +235,57 @@ def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=Fals return (receivedQuery, message) @classmethod - def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False): - if useQueue: - cls._toResponderQueue.put(response, True, timeout) + def openTCPConnection(cls, timeout=None): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) if timeout: sock.settimeout(timeout) sock.connect(("127.0.0.1", cls._dnsDistPort)) + return sock - try: - if not rawQuery: - wire = query.to_wire() - else: - wire = query + @classmethod + def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False): + if not rawQuery: + wire = query.to_wire() + else: + wire = query - sock.send(struct.pack("!H", len(wire))) - sock.send(wire) - data = sock.recv(2) + sock.send(struct.pack("!H", len(wire))) + sock.send(wire) + + @classmethod + def recvTCPResponseOverConnection(cls, sock): + message = None + data = sock.recv(2) + if data: + (datalen,) = struct.unpack("!H", data) + data = sock.recv(datalen) if data: - (datalen,) = struct.unpack("!H", data) - data = sock.recv(datalen) + message = dns.message.from_wire(data) + return message + + @classmethod + def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False): + message = None + if useQueue: + cls._toResponderQueue.put(response, True, timeout) + + sock = cls.openTCPConnection(timeout) + + try: + cls.sendTCPQueryOverConnection(sock, query, rawQuery) + message = cls.recvTCPResponseOverConnection(sock) except socket.timeout as e: print("Timeout: %s" % (str(e))) - data = None except socket.error as e: print("Network error: %s" % (str(e))) - data = None finally: sock.close() receivedQuery = None - message = None if useQueue and not cls._fromResponderQueue.empty(): receivedQuery = cls._fromResponderQueue.get(True, timeout) - if data: - message = dns.message.from_wire(data) + return (receivedQuery, message) @classmethod diff --git a/regression-tests.dnsdist/test_AXFR.py b/regression-tests.dnsdist/test_AXFR.py index 6f54dc30218e..9fed01195846 100644 --- a/regression-tests.dnsdist/test_AXFR.py +++ b/regression-tests.dnsdist/test_AXFR.py @@ -24,10 +24,6 @@ def startResponders(cls): cls._TCPResponder.setDaemon(True) cls._TCPResponder.start() - _config_template = """ - newServer{address="127.0.0.1:%s"} - """ - def testOneMessageAXFR(self): """ AXFR: One message diff --git a/regression-tests.dnsdist/test_TCPLimits.py b/regression-tests.dnsdist/test_TCPLimits.py new file mode 100644 index 000000000000..fb9dc03a6774 --- /dev/null +++ b/regression-tests.dnsdist/test_TCPLimits.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python +import struct +import time +import dns +from dnsdisttests import DNSDistTest + +class TestTCPLimits(DNSDistTest): + + _tcpIdleTimeout = 2 + _maxTCPQueriesPerConn = 5 + _maxTCPConnsPerClient = 3 + _maxTCPConnDuration = 5 + _config_template = """ + newServer{address="127.0.0.1:%s"} + setTCPRecvTimeout(%s) + setMaxTCPQueriesPerConnection(%s) + setMaxTCPConnectionsPerClient(%s) + setMaxTCPConnectionDuration(%s) + """ + _config_params = ['_testServerPort', '_tcpIdleTimeout', '_maxTCPQueriesPerConn', '_maxTCPConnsPerClient', '_maxTCPConnDuration'] + + def testTCPQueriesPerConn(self): + """ + TCP Limits: Maximum number of queries + """ + name = 'maxqueriesperconn.tcp.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + conn = self.openTCPConnection() + + count = 0 + for idx in xrange(self._maxTCPQueriesPerConn): + try: + self.sendTCPQueryOverConnection(conn, query) + response = self.recvTCPResponseOverConnection(conn) + self.assertTrue(response) + count = count + 1 + except: + pass + + # this one should fail + failed = False + try: + self.sendTCPQueryOverConnection(conn, query) + response = self.recvTCPResponseOverConnection(conn) + self.assertFalse(response) + if not response: + failed = True + else: + count = count + 1 + except: + failed = True + + conn.close() + self.assertTrue(failed) + self.assertEqual(count, self._maxTCPQueriesPerConn) + + def testTCPConnsPerClient(self): + """ + TCP Limits: Maximum number of conns per client + """ + name = 'maxconnsperclient.tcp.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + conns = [] + + for idx in xrange(self._maxTCPConnsPerClient + 1): + conns.append(self.openTCPConnection()) + + count = 0 + failed = 0 + for conn in conns: + try: + self.sendTCPQueryOverConnection(conn, query) + response = self.recvTCPResponseOverConnection(conn) + if response: + count = count + 1 + else: + failed = failed + 1 + except: + failed = failed + 1 + + for conn in conns: + conn.close() + + self.assertEqual(count, self._maxTCPConnsPerClient) + self.assertEqual(failed, 1) + + def testTCPDuration(self): + """ + TCP Limits: Maximum duration + """ + name = 'duration.tcp.tests.powerdns.com.' + + start = time.time() + conn = self.openTCPConnection() + # immediately send the maximum size + conn.send(struct.pack("!H", 65535)) + + count = 0 + while count < (self._maxTCPConnDuration * 2): + try: + # sleeping for only one second keeps us below the + # idle timeout (setTCPRecvTimeout()) + time.sleep(1) + conn.send('A') + count = count + 1 + except: + break + + end = time.time() + + self.assertAlmostEquals(count, self._maxTCPConnDuration, delta=2) + self.assertAlmostEquals(end - start, self._maxTCPConnDuration, delta=2) + + conn.close()