Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore response body if status code doesn't allow a body #202

Merged
merged 6 commits into from Sep 6, 2018
Copy path View file
@@ -155,6 +155,7 @@ class Task(object):
content_length = None
content_bytes_written = 0
logged_write_excess = False
logged_write_no_body = False
complete = False
chunked_response = False
logger = logger
@@ -182,6 +183,13 @@ def service(self):
finally:
pass

@property
def has_body(self):
return not (self.status.startswith('1') or
self.status.startswith('204') or
self.status.startswith('304')
)

def cancel(self):
self.close_on_finish = True

@@ -192,30 +200,41 @@ def build_response_header(self):
version = self.version
# Figure out whether the connection should be closed.
connection = self.request.headers.get('CONNECTION', '').lower()
response_headers = self.response_headers
response_headers = []
content_length_header = None
date_header = None
server_header = None
connection_close_header = None

for i, (headername, headerval) in enumerate(response_headers):
for (headername, headerval) in self.response_headers:
headername = '-'.join(
[x.capitalize() for x in headername.split('-')]
)

if headername == 'Content-Length':
content_length_header = headerval
if self.has_body:
content_length_header = headerval
else:
continue # pragma: no cover

if headername == 'Date':
date_header = headerval

if headername == 'Server':
server_header = headerval

if headername == 'Connection':
connection_close_header = headerval.lower()
# replace with properly capitalized version
response_headers[i] = (headername, headerval)
response_headers.append((headername, headerval))

if content_length_header is None and self.content_length is not None:
if (
content_length_header is None and
self.content_length is not None and
self.has_body
):
content_length_header = str(self.content_length)
self.response_headers.append(
response_headers.append(
('Content-Length', content_length_header)
)

@@ -239,11 +258,12 @@ def close_on_finish():

if not content_length_header:
# RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
# for any response with a status code of 1xx or 204.
if not (self.status.startswith('1') or
self.status.startswith('204')):
# for any response with a status code of 1xx, 204 or 304.

if self.has_body:
response_headers.append(('Transfer-Encoding', 'chunked'))
self.chunked_response = True

if not self.close_on_finish:
close_on_finish()

@@ -254,6 +274,7 @@ def close_on_finish():
# Set the Server and Date field, if not yet specified. This is needed
# if the server is used as a proxy.
ident = self.channel.server.adj.ident

if not server_header:
if ident:
response_headers.append(('Server', ident))
@@ -263,20 +284,28 @@ def close_on_finish():
if not date_header:
response_headers.append(('Date', build_http_date(self.start_time)))

self.response_headers = response_headers

This comment has been minimized.

Copy link
@bertjwregeer

bertjwregeer Sep 6, 2018

Author Member

Not strictly required in this function, but required for testing.


first_line = 'HTTP/%s %s' % (self.version, self.status)
# NB: sorting headers needs to preserve same-named-header order
# as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here;
# rely on stable sort to keep relative position of same-named headers
next_lines = ['%s: %s' % hv for hv in sorted(
self.response_headers, key=lambda x: x[0])]
self.response_headers, key=lambda x: x[0])]
lines = [first_line] + next_lines
res = '%s\r\n\r\n' % '\r\n'.join(lines)

return tobytes(res)

def remove_content_length_header(self):
for i, (header_name, header_value) in enumerate(self.response_headers):
response_headers = []

for header_name, header_value in self.response_headers:
if header_name.lower() == 'content-length':
del self.response_headers[i]
continue # pragma: nocover
response_headers.append((header_name, header_value))

self.response_headers = response_headers

def start(self):
self.start_time = time.time()
@@ -297,7 +326,8 @@ def write(self, data):
rh = self.build_response_header()
channel.write_soon(rh)
self.wrote_header = True
if data:

if data and self.has_body:
towrite = data
cl = self.content_length
if self.chunked_response:
@@ -314,6 +344,18 @@ def write(self, data):
self.logged_write_excess = True
if towrite:
channel.write_soon(towrite)
else:
# Cheat, and tell the application we have written all of the bytes,
# even though the response shouldn't have a body and we are
# ignoring it entirely.
self.content_bytes_written += len(data)

if not self.logged_write_no_body:
self.logger.warning(
'application-written content was ignored due to HTTP '
'response that may not contain a message-body: (%s)' % self.status)
self.logged_write_no_body = True


class ErrorTask(Task):
""" An error task produces an error response
Copy path View file
@@ -245,6 +245,23 @@ def test_build_response_header_v11_1xx_no_content_length_or_transfer_encoding(se
self.assertEqual(inst.close_on_finish, True)
self.assertTrue(('Connection', 'close') in inst.response_headers)

def test_build_response_header_v11_304_no_content_length_or_transfer_encoding(self):
# RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
# for any response with a status code of 1xx, 204 or 304.
inst = self._makeOne()
inst.request = DummyParser()
inst.version = '1.1'
inst.status = '304 Not Modified'
result = inst.build_response_header()
lines = filter_lines(result)
self.assertEqual(len(lines), 4)
self.assertEqual(lines[0], b'HTTP/1.1 304 Not Modified')
self.assertEqual(lines[1], b'Connection: close')
self.assertTrue(lines[2].startswith(b'Date:'))
self.assertEqual(lines[3], b'Server: waitress')
self.assertEqual(inst.close_on_finish, True)
self.assertTrue(('Connection', 'close') in inst.response_headers)

def test_build_response_header_via_added(self):
inst = self._makeOne()
inst.request = DummyParser()
@@ -291,6 +308,12 @@ def test_remove_content_length_header(self):
inst.remove_content_length_header()
self.assertEqual(inst.response_headers, [])

def test_remove_content_length_header_with_other(self):
inst = self._makeOne()
inst.response_headers = [('Content-Length', '70'), ('Content-Type', 'text/html')]
inst.remove_content_length_header()
self.assertEqual(inst.response_headers, [('Content-Type', 'text/html')])

def test_start(self):
inst = self._makeOne()
inst.start()
@@ -561,6 +584,34 @@ def app(environ, start_response):
self.assertEqual(inst.close_on_finish, True)
self.assertEqual(len(inst.logger.logged), 0)

def test_execute_app_without_body_204_logged(self):
def app(environ, start_response):
start_response('204 No Content', [('Content-Length', '3')])
return [b'abc']
inst = self._makeOne()
inst.channel.server.application = app
inst.logger = DummyLogger()
inst.execute()
self.assertEqual(inst.close_on_finish, True)
self.assertNotIn(b'abc', inst.channel.written)
self.assertNotIn(b'Content-Length', inst.channel.written)
self.assertNotIn(b'Transfer-Encoding', inst.channel.written)
self.assertEqual(len(inst.logger.logged), 1)

def test_execute_app_without_body_304_logged(self):
def app(environ, start_response):
start_response('304 Not Modified', [('Content-Length', '3')])
return [b'abc']
inst = self._makeOne()
inst.channel.server.application = app
inst.logger = DummyLogger()
inst.execute()
self.assertEqual(inst.close_on_finish, True)
self.assertNotIn(b'abc', inst.channel.written)
self.assertNotIn(b'Content-Length', inst.channel.written)
self.assertNotIn(b'Transfer-Encoding', inst.channel.written)
self.assertEqual(len(inst.logger.logged), 1)

def test_execute_app_returns_closeable(self):
class closeable(list):
def close(self):
@@ -915,9 +915,10 @@ def handle_close(self):
self.flag = True
self.close()

# def handle_expt(self):
# self.flag = True
# self.close()
def handle_expt(self): # pragma: no cover
# needs to exist for MacOS testing
self.flag = True
self.close()

class TestHandler(BaseTestHandler):

ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.