Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dnsdist: Expose trailing data #6967

Merged
merged 14 commits into from Jan 18, 2019
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 42 additions & 0 deletions pdns/dnsdist-lua-bindings-dnsquestion.cc
Expand Up @@ -63,6 +63,27 @@ void setupLuaBindingsDNSQuestion()

return *dq.ednsOptions;
});
g_lua.registerFunction<std::string(DNSQuestion::*)(void)>("getTrailingData", [](const DNSQuestion& dq) {
const char* message = reinterpret_cast<const char*>(dq.dh);
const uint16_t messageLen = getDNSPacketLength(message, dq.len);
const std::string tail = std::string(message + messageLen, dq.len - messageLen);
return tail;
});
g_lua.registerFunction<bool(DNSQuestion::*)(std::string)>("setTrailingData", [](DNSQuestion& dq, const std::string& tail) {
char* message = reinterpret_cast<char*>(dq.dh);
const uint16_t messageLen = getDNSPacketLength(message, dq.len);
const uint16_t tailLen = tail.size();
if(messageLen + tailLen > dq.size) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a big deal but it looks like messageLen + tailLen could overflow before being promoted. Perhaps something like:

Suggested change
if(messageLen + tailLen > dq.size) {
if(tailLen > (dq.size - messageLen)) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch! Updated.

return false;
}

/* Update length and copy data from the Lua string. */
dq.len = messageLen + tailLen;
if(tailLen > 0) {
tail.copy(message + messageLen, tailLen);
}
return true;
});

g_lua.registerFunction<void(DNSQuestion::*)(std::string)>("sendTrap", [](const DNSQuestion& dq, boost::optional<std::string> reason) {
#ifdef HAVE_NET_SNMP
Expand Down Expand Up @@ -123,6 +144,27 @@ void setupLuaBindingsDNSQuestion()
g_lua.registerFunction<void(DNSResponse::*)(std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc)>("editTTLs", [](const DNSResponse& dr, std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc) {
editDNSPacketTTL((char*) dr.dh, dr.len, editFunc);
});
g_lua.registerFunction<std::string(DNSResponse::*)(void)>("getTrailingData", [](const DNSResponse& dq) {
const char* message = reinterpret_cast<const char*>(dq.dh);
const uint16_t messageLen = getDNSPacketLength(message, dq.len);
const std::string tail = std::string(message + messageLen, dq.len - messageLen);
return tail;
});
g_lua.registerFunction<bool(DNSResponse::*)(std::string)>("setTrailingData", [](DNSResponse& dq, const std::string& tail) {
char* message = reinterpret_cast<char*>(dq.dh);
const uint16_t messageLen = getDNSPacketLength(message, dq.len);
const uint16_t tailLen = tail.size();
if(messageLen + tailLen > dq.size) {
return false;
}

/* Update length and copy data from the Lua string. */
dq.len = messageLen + tailLen;
if(tailLen > 0) {
tail.copy(message + messageLen, tailLen);
}
return true;
});
g_lua.registerFunction<void(DNSResponse::*)(std::string)>("sendTrap", [](const DNSResponse& dr, boost::optional<std::string> reason) {
#ifdef HAVE_NET_SNMP
if (g_snmpAgent && g_snmpTrapsEnabled) {
Expand Down
17 changes: 17 additions & 0 deletions pdns/dnsdistdist/docs/reference/dq.rst
Expand Up @@ -109,6 +109,14 @@ This state can be modified from the various hooks.

:returns: A table of tags, using strings as keys and values

.. method:: DNSQuestion:getTrailingData() -> string

.. versionadded:: 1.4.0

Get all data following the DNS message.

:returns: The trailing data as a null-safe string

.. method:: DNSQuestion:sendTrap(reason)

.. versionadded:: 1.2.0
Expand All @@ -134,6 +142,15 @@ This state can be modified from the various hooks.

:param table tags: A table of tags, using strings as keys and values

.. method:: DNSQuestion:setTrailingData(tail) -> bool

.. versionadded:: 1.4.0

Set the data following the DNS message, overwriting anything already present.

:param string tail: The new data
:returns: true if the operation succeeded, false otherwise

.. _DNSResponse:

DNSResponse object
Expand Down
51 changes: 36 additions & 15 deletions regression-tests.dnsdist/dnsdisttests.py
Expand Up @@ -142,42 +142,54 @@ def _ResponderIncrementCounter(cls):
cls._responsesCounter[threading.currentThread().name] = 1

@classmethod
def _getResponse(cls, request, fromQueue, toQueue):
def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
response = None
if len(request.question) != 1:
print("Skipping query with question count %d" % (len(request.question)))
return None
healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
if healthCheck:
cls._healthCheckCounter += 1
response = dns.message.make_response(request)
else:
cls._ResponderIncrementCounter()
if not fromQueue.empty():
response = fromQueue.get(True, cls._queueTimeout)
if response:
response = copy.copy(response)
response.id = request.id
toQueue.put(request, True, cls._queueTimeout)
toQueue.put(request, True, cls._queueTimeout)
if synthesize is None:
response = fromQueue.get(True, cls._queueTimeout)
if response:
response = copy.copy(response)
response.id = request.id

if not response:
if healthCheck:
if synthesize is not None:
response = dns.message.make_response(request)
response.set_rcode(synthesize)
elif cls._answerUnexpected:
response = dns.message.make_response(request)
response.set_rcode(dns.rcode.SERVFAIL)

return response

@classmethod
def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False):
ignoreTrailing = trailingDataResponse is True
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.bind(("127.0.0.1", port))
while True:
data, addr = sock.recvfrom(4096)
request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
response = cls._getResponse(request, fromQueue, toQueue)

forceRcode = None
try:
request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
except dns.message.TrailingJunk as e:
if trailingDataResponse is False:
raise
print("UDP query with trailing data, synthesizing response")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code looks quite confusing to me. ignoreTrailing will only be True if trailingDataResponse is True, not if it is a numerical value like a response code, which means we would then pass True in the synthesize parameter to _getResponse(), which feels wrong.
Since as far as I can tell we never pass set trailingDataResponse to True, perhaps we could just remove ignoreTrailing altogether?

Copy link
Contributor Author

@gibson042 gibson042 Oct 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing structure was already a bit confusing, and I was trying to preserve it as much as possible (which unfortunately didn't help). But the idea is to extend the existing ignoreTrailing parameter into trailingDataResponse as follows:

  • False: reject queries with trailing data (preexisting behavior)
  • True: ignore trailing data after queries (preexisting behavior)
  • anything else: interpret as RCODE to be sent in response to queries with trailing data (new behavior, used by TestTrailingDataToBackend in regression-tests.dnsdist/test_Trailing.py)

We will never pass synthesize=True to _getResponse, because forceRcode is only set to trailingDataResponse when dns.message.from_wire raises dns.message.TrailingJunk, which cannot happen when it is ignoring trailing data.

I have added comments and safety-enforcing code.

request = dns.message.from_wire(data, ignore_trailing=True)
forceRcode = trailingDataResponse

response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
if not response:
continue

Expand All @@ -187,7 +199,8 @@ def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
sock.close()

@classmethod
def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False):
def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False):
ignoreTrailing = trailingDataResponse is True
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
try:
Expand All @@ -207,9 +220,17 @@ def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleRe

(datalen,) = struct.unpack("!H", data)
data = conn.recv(datalen)
request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
response = cls._getResponse(request, fromQueue, toQueue)

forceRcode = None
try:
request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
except dns.message.TrailingJunk as e:
if trailingDataResponse is False:
raise
print("TCP query with trailing data, synthesizing response")
request = dns.message.from_wire(data, ignore_trailing=True)
forceRcode = trailingDataResponse

response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
if not response:
conn.close()
continue
Expand Down