Skip to content

Commit

Permalink
dnsdist: Add TCP management options from rfc7766 section 10
Browse files Browse the repository at this point in the history
  • Loading branch information
rgacogne committed Dec 8, 2016
1 parent 439b085 commit 3990b7b
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 29 deletions.
3 changes: 3 additions & 0 deletions pdns/README-dnsdist.md
Expand Up @@ -1488,6 +1488,9 @@ instantiate a server with additional parameters
* member `add(DNSName)`: add this DNSName to the node
* Tuning related:
* `setMaxTCPClientThreads(n)`: set the maximum of TCP client threads, handling TCP connections
* `setMaxTCPConnectionDuration(n)`: set the maximum duration of an incoming TCP connection, in seconds. 0 (the default) means unlimited
* `setMaxTCPConnectionsPerClient(n)`: set the maximum number of TCP connections per client. 0 (the default) means unlimited
* `setMaxTCPQueriesPerConnection(n)`: set the maximum number of queries in an incoming TCP connection. 0 (the default) means unlimited
* `setMaxTCPQueuedConnections(n)`: set the maximum number of TCP connections queued (waiting to be picked up by a client thread), defaults to 1000. 0 means unlimited
* `setMaxUDPOutstanding(n)`: set the maximum number of outstanding UDP queries to a given backend server. This can only be set at configuration time and defaults to 10240
* `setCacheCleaningDelay(n)`: set the interval in seconds between two runs of the cache cleaning algorithm, removing expired entries
Expand Down
3 changes: 3 additions & 0 deletions pdns/dnsdist-console.cc
Expand Up @@ -332,6 +332,9 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
{ "setKey", true, "key", "set access key to that key" },
{ "setLocal", true, "netmask, [true], [false], [TCP Fast Open queue size]", "reset list of addresses we listen on to this address. Second optional parameter sets TCP or not. Third optional parameter sets SO_REUSEPORT when available. Last parameter sets the TCP Fast Open queue size, enabling TCP Fast Open when available and the value is larger than 0." },
{ "setMaxTCPClientThreads", true, "n", "set the maximum of TCP client threads, handling TCP connections" },
{ "setMaxTCPConnectionDuration", true, "n", "set the maximum duration of an incoming TCP connection, in seconds. 0 means unlimited" },
{ "setMaxTCPConnectionsPerClient", true, "n", "set the maximum number of TCP connections per client. 0 means unlimited" },
{ "setMaxTCPQueriesPerConnection", true, "n", "set the maximum number of queries in an incoming TCP connection. 0 means unlimited" },
{ "setMaxTCPQueuedConnections", true, "n", "set the maximum number of TCP connections queued (waiting to be picked up by a client thread)" },
{ "setMaxUDPOutstanding", true, "n", "set the maximum number of outstanding UDP queries to a given backend server. This can only be set at configuration time and defaults to 10240" },
{ "setQueryCount", true, "bool", "set whether queries should be counted" },
Expand Down
24 changes: 24 additions & 0 deletions pdns/dnsdist-lua.cc
Expand Up @@ -1520,6 +1520,30 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
}
});

g_lua.writeFunction("setMaxTCPQueriesPerConnection", [](size_t max) {
if (!g_configurationDone) {
g_maxTCPQueriesPerConn = max;
} else {
g_outputBuffer="The maximum number of queries per TCP connection cannot be altered at runtime!\n";
}
});

g_lua.writeFunction("setMaxTCPConnectionsPerClient", [](size_t max) {
if (!g_configurationDone) {
g_maxTCPConnectionsPerClient = max;
} else {
g_outputBuffer="The maximum number of TCP connection per client cannot be altered at runtime!\n";
}
});

g_lua.writeFunction("setMaxTCPConnectionDuration", [](size_t max) {
if (!g_configurationDone) {
g_maxTCPConnectionDuration = max;
} else {
g_outputBuffer="The maximum duration of a TCP connection cannot be altered at runtime!\n";
}
});

g_lua.writeFunction("showTCPStats", [] {
setLuaNoSideEffect();
boost::format fmt("%-10d %-10d %-10d %-10d\n");
Expand Down
74 changes: 70 additions & 4 deletions pdns/dnsdist-tcp.cc
Expand Up @@ -68,8 +68,25 @@ struct ConnectionInfo
};

uint64_t g_maxTCPQueuedConnections{1000};
size_t g_maxTCPQueriesPerConn{0};
size_t g_maxTCPConnectionDuration{0};
size_t g_maxTCPConnectionsPerClient{0};
static std::mutex tcpClientsCountMutex;
static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;

void* tcpClientThread(int pipefd);

static void decrementTCPClientCount(const ComboAddress& client)
{
if (g_maxTCPConnectionsPerClient) {
std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
tcpClientsCount[client]--;
if (tcpClientsCount[client] == 0) {
tcpClientsCount.erase(client);
}
}
}

// Should not be called simultaneously!
void TCPClientCollection::addTCPClientThread()
{
Expand Down Expand Up @@ -144,6 +161,18 @@ static bool sendResponseToClient(int fd, const char* response, uint16_t response
return true;
}

static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime)
{
if (maxConnectionDuration) {
time_t elapsed = time(NULL) - start;
if (elapsed >= maxConnectionDuration) {
return true;
}
remainingTime = maxConnectionDuration - elapsed;
}
return false;
}

std::shared_ptr<TCPClientCollection> g_tcpclientthreads;

void* tcpClientThread(int pipefd)
Expand Down Expand Up @@ -194,6 +223,9 @@ void* tcpClientThread(int pipefd)
memset(&dest, 0, sizeof(dest));
dest.sin4.sin_family = ci.remote.sin4.sin_family;
socklen_t len = dest.getSocklen();
size_t queriesCount = 0;
time_t connectionStartTime = time(NULL);

if (!setNonBlocking(ci.fd))
goto drop;

Expand All @@ -203,6 +235,7 @@ void* tcpClientThread(int pipefd)

try {
for(;;) {
unsigned int remainingTime = 0;
ds = nullptr;
outstanding = false;

Expand All @@ -212,6 +245,18 @@ void* tcpClientThread(int pipefd)
ci.cs->queries++;
g_stats.queries++;

queriesCount++;

if (g_maxTCPQueriesPerConn && queriesCount > g_maxTCPQueriesPerConn) {
vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", ci.remote.toStringWithPort(), queriesCount, g_maxTCPQueriesPerConn);
break;
}

if (maxConnectionDurationReached(g_maxTCPConnectionDuration, connectionStartTime, remainingTime)) {
vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", ci.remote.toStringWithPort());
break;
}

if (qlen < sizeof(dnsheader)) {
g_stats.nonCompliantQueries++;
break;
Expand All @@ -225,7 +270,7 @@ void* tcpClientThread(int pipefd)
size_t querySize = qlen <= 4096 ? qlen + 512 : qlen;
char queryBuffer[querySize];
const char* query = queryBuffer;
readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout);
readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout, remainingTime);

#ifdef HAVE_DNSCRYPT
std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;
Expand Down Expand Up @@ -519,18 +564,18 @@ void* tcpClientThread(int pipefd)
outstanding = false;
--ds->outstanding;
}
decrementTCPClientCount(ci.remote);
}
return 0;
}


/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
they will hand off to worker threads & spawn more of them if required
*/
void* tcpAcceptorThread(void* p)
{
ClientState* cs = (ClientState*) p;

bool tcpClientCountIncremented = false;
ComboAddress remote;
remote.sin4.sin_family = cs->local.sin4.sin_family;

Expand All @@ -540,6 +585,7 @@ void* tcpAcceptorThread(void* p)
for(;;) {
bool queuedCounterIncremented = false;
ConnectionInfo* ci = nullptr;
tcpClientCountIncremented = false;
try {
ci = new ConnectionInfo;
ci->cs = cs;
Expand All @@ -563,8 +609,22 @@ void* tcpAcceptorThread(void* p)
continue;
}

if (g_maxTCPConnectionsPerClient) {
std::lock_guard<std::mutex> lock(tcpClientsCountMutex);

if (tcpClientsCount[remote] >= g_maxTCPConnectionsPerClient) {
close(ci->fd);
delete ci;
ci=nullptr;
vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
continue;
}
tcpClientsCount[remote]++;
tcpClientCountIncremented = true;
}

vinfolog("Got TCP connection from %s", remote.toStringWithPort());

ci->remote = remote;
int pipe = g_tcpclientthreads->getThread();
if (pipe >= 0) {
Expand All @@ -577,12 +637,18 @@ void* tcpAcceptorThread(void* p)
close(ci->fd);
delete ci;
ci=nullptr;
if(tcpClientCountIncremented) {
decrementTCPClientCount(remote);
}
}
}
catch(std::exception& e) {
errlog("While reading a TCP question: %s", e.what());
if(ci && ci->fd >= 0)
close(ci->fd);
if(tcpClientCountIncremented) {
decrementTCPClientCount(remote);
}
delete ci;
ci = nullptr;
if (queuedCounterIncremented) {
Expand Down
3 changes: 3 additions & 0 deletions pdns/dnsdist.hh
Expand Up @@ -688,6 +688,9 @@ extern uint16_t g_maxOutstanding;
extern std::atomic<bool> g_configurationDone;
extern uint64_t g_maxTCPClientThreads;
extern uint64_t g_maxTCPQueuedConnections;
extern size_t g_maxTCPQueriesPerConn;
extern size_t g_maxTCPConnectionDuration;
extern size_t g_maxTCPConnectionsPerClient;
extern std::atomic<uint16_t> g_cacheCleaningDelay;
extern bool g_verboseHealthChecks;
extern uint32_t g_staleCacheEntriesTTL;
Expand Down
20 changes: 18 additions & 2 deletions pdns/misc.cc
Expand Up @@ -106,9 +106,15 @@ size_t readn2(int fd, void* buffer, size_t len)
return len;
}

size_t readn2WithTimeout(int fd, void* buffer, size_t len, int timeout)
size_t readn2WithTimeout(int fd, void* buffer, size_t len, int idleTimeout, int totalTimeout)
{
size_t pos = 0;
time_t start = 0;
int remainingTime = totalTimeout;
if (totalTimeout) {
start = time(NULL);
}

do {
ssize_t got = read(fd, (char *)buffer + pos, len - pos);
if (got > 0) {
Expand All @@ -119,7 +125,7 @@ size_t readn2WithTimeout(int fd, void* buffer, size_t len, int timeout)
}
else {
if (errno == EAGAIN) {
int res = waitForData(fd, timeout);
int res = waitForData(fd, (totalTimeout == 0 || idleTimeout <= remainingTime) ? idleTimeout : remainingTime);
if (res > 0) {
/* there is data available */
}
Expand All @@ -133,6 +139,16 @@ size_t readn2WithTimeout(int fd, void* buffer, size_t len, int timeout)
unixDie("failed in readn2WithTimeout");
}
}

if (totalTimeout) {
time_t now = time(NULL);
int elapsed = now - start;
if (elapsed >= remainingTime) {
throw runtime_error("Timeout while reading data");
}
start = now;
remainingTime -= elapsed;
}
}
while (pos < len);

Expand Down
2 changes: 1 addition & 1 deletion pdns/misc.hh
Expand Up @@ -147,7 +147,7 @@ vstringtok (Container &container, string const &in,
size_t writen2(int fd, const void *buf, size_t count);
inline size_t writen2(int fd, const std::string &s) { return writen2(fd, s.data(), s.size()); }
size_t readn2(int fd, void* buffer, size_t len);
size_t readn2WithTimeout(int fd, void* buffer, size_t len, int timeout);
size_t readn2WithTimeout(int fd, void* buffer, size_t len, int idleTimeout, int totalTimeout=0);
size_t writen2WithTimeout(int fd, const void * buffer, size_t len, int timeout);

const string toLower(const string &upper);
Expand Down
51 changes: 33 additions & 18 deletions regression-tests.dnsdist/dnsdisttests.py
Expand Up @@ -235,42 +235,57 @@ def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=Fals
return (receivedQuery, message)

@classmethod
def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
if useQueue:
cls._toResponderQueue.put(response, True, timeout)
def openTCPConnection(cls, timeout=None):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if timeout:
sock.settimeout(timeout)

sock.connect(("127.0.0.1", cls._dnsDistPort))
return sock

try:
if not rawQuery:
wire = query.to_wire()
else:
wire = query
@classmethod
def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False):
if not rawQuery:
wire = query.to_wire()
else:
wire = query

sock.send(struct.pack("!H", len(wire)))
sock.send(wire)
data = sock.recv(2)
sock.send(struct.pack("!H", len(wire)))
sock.send(wire)

@classmethod
def recvTCPResponseOverConnection(cls, sock):
message = None
data = sock.recv(2)
if data:
(datalen,) = struct.unpack("!H", data)
data = sock.recv(datalen)
if data:
(datalen,) = struct.unpack("!H", data)
data = sock.recv(datalen)
message = dns.message.from_wire(data)
return message

@classmethod
def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
message = None
if useQueue:
cls._toResponderQueue.put(response, True, timeout)

sock = cls.openTCPConnection(timeout)

try:
cls.sendTCPQueryOverConnection(sock, query, rawQuery)
message = cls.recvTCPResponseOverConnection(sock)
except socket.timeout as e:
print("Timeout: %s" % (str(e)))
data = None
except socket.error as e:
print("Network error: %s" % (str(e)))
data = None
finally:
sock.close()

receivedQuery = None
message = None
if useQueue and not cls._fromResponderQueue.empty():
receivedQuery = cls._fromResponderQueue.get(True, timeout)
if data:
message = dns.message.from_wire(data)

return (receivedQuery, message)

@classmethod
Expand Down
4 changes: 0 additions & 4 deletions regression-tests.dnsdist/test_AXFR.py
Expand Up @@ -24,10 +24,6 @@ def startResponders(cls):
cls._TCPResponder.setDaemon(True)
cls._TCPResponder.start()

_config_template = """
newServer{address="127.0.0.1:%s"}
"""

def testOneMessageAXFR(self):
"""
AXFR: One message
Expand Down

0 comments on commit 3990b7b

Please sign in to comment.