New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dnsdist: Expose trailing data #6967
Changes from 12 commits
4aa08b6
157445b
64cda3d
93446f2
8ca2f50
f641008
06f6491
d2c336c
7d243a5
6b32cb3
bf0ff88
9e2119f
bf11f6f
3ef7ab0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -142,42 +142,54 @@ def _ResponderIncrementCounter(cls): | |
cls._responsesCounter[threading.currentThread().name] = 1 | ||
|
||
@classmethod | ||
def _getResponse(cls, request, fromQueue, toQueue): | ||
def _getResponse(cls, request, fromQueue, toQueue, synthesize=None): | ||
response = None | ||
if len(request.question) != 1: | ||
print("Skipping query with question count %d" % (len(request.question))) | ||
return None | ||
healthCheck = str(request.question[0].name).endswith(cls._healthCheckName) | ||
if healthCheck: | ||
cls._healthCheckCounter += 1 | ||
response = dns.message.make_response(request) | ||
else: | ||
cls._ResponderIncrementCounter() | ||
if not fromQueue.empty(): | ||
response = fromQueue.get(True, cls._queueTimeout) | ||
if response: | ||
response = copy.copy(response) | ||
response.id = request.id | ||
toQueue.put(request, True, cls._queueTimeout) | ||
toQueue.put(request, True, cls._queueTimeout) | ||
if synthesize is None: | ||
response = fromQueue.get(True, cls._queueTimeout) | ||
if response: | ||
response = copy.copy(response) | ||
response.id = request.id | ||
|
||
if not response: | ||
if healthCheck: | ||
if synthesize is not None: | ||
response = dns.message.make_response(request) | ||
response.set_rcode(synthesize) | ||
elif cls._answerUnexpected: | ||
response = dns.message.make_response(request) | ||
response.set_rcode(dns.rcode.SERVFAIL) | ||
|
||
return response | ||
|
||
@classmethod | ||
def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False): | ||
def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False): | ||
ignoreTrailing = trailingDataResponse is True | ||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | ||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) | ||
sock.bind(("127.0.0.1", port)) | ||
while True: | ||
data, addr = sock.recvfrom(4096) | ||
request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) | ||
response = cls._getResponse(request, fromQueue, toQueue) | ||
|
||
forceRcode = None | ||
try: | ||
request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) | ||
except dns.message.TrailingJunk as e: | ||
if trailingDataResponse is False: | ||
raise | ||
print("UDP query with trailing data, synthesizing response") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code looks quite confusing to me. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The existing structure was already a bit confusing, and I was trying to preserve it as much as possible (which unfortunately didn't help). But the idea is to extend the existing
We will never pass I have added comments and safety-enforcing code. |
||
request = dns.message.from_wire(data, ignore_trailing=True) | ||
forceRcode = trailingDataResponse | ||
|
||
response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) | ||
if not response: | ||
continue | ||
|
||
|
@@ -187,7 +199,8 @@ def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False): | |
sock.close() | ||
|
||
@classmethod | ||
def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False): | ||
def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False): | ||
ignoreTrailing = trailingDataResponse is True | ||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | ||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) | ||
try: | ||
|
@@ -207,9 +220,17 @@ def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleRe | |
|
||
(datalen,) = struct.unpack("!H", data) | ||
data = conn.recv(datalen) | ||
request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) | ||
response = cls._getResponse(request, fromQueue, toQueue) | ||
|
||
forceRcode = None | ||
try: | ||
request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) | ||
except dns.message.TrailingJunk as e: | ||
if trailingDataResponse is False: | ||
raise | ||
print("TCP query with trailing data, synthesizing response") | ||
request = dns.message.from_wire(data, ignore_trailing=True) | ||
forceRcode = trailingDataResponse | ||
|
||
response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) | ||
if not response: | ||
conn.close() | ||
continue | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a big deal but it looks like
messageLen
+tailLen
could overflow before being promoted. Perhaps something like:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch! Updated.