Skip to content

Commit

Permalink
Consume connections better in socket-level tests (urllib3#1958)
Browse files Browse the repository at this point in the history
Co-authored-by: Quentin Pradet <quentin.pradet@gmail.com>
  • Loading branch information
hodbn and pquentin committed Sep 18, 2020
1 parent 04c55ec commit 1f3ea3b
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 23 deletions.
63 changes: 63 additions & 0 deletions dummyserver/testcase.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import threading
from contextlib import contextmanager

import pytest
from tornado import ioloop, web

from urllib3.connection import HTTPConnection

from dummyserver.server import (
SocketServerThread,
run_tornado_app,
Expand Down Expand Up @@ -217,3 +220,63 @@ class IPv6HTTPDummyProxyTestCase(HTTPDummyProxyTestCase):

proxy_host = "::1"
proxy_host_alt = "127.0.0.1"


class ConnectionMarker(object):
"""
Marks an HTTP(S)Connection's socket after a request was made.
Helps a test server understand when a client finished a request,
without implementing a complete HTTP server.
"""

MARK_FORMAT = b"$#MARK%04x*!"

@classmethod
@contextmanager
def mark(cls, monkeypatch):
"""
Mark connections under in that context.
"""

orig_request = HTTPConnection.request
orig_request_chunked = HTTPConnection.request_chunked

def call_and_mark(target):
def part(self, *args, **kwargs):
result = target(self, *args, **kwargs)
self.sock.sendall(cls._get_socket_mark(self.sock, False))
return result

return part

with monkeypatch.context() as m:
m.setattr(HTTPConnection, "request", call_and_mark(orig_request))
m.setattr(
HTTPConnection, "request_chunked", call_and_mark(orig_request_chunked)
)
yield

@classmethod
def consume_request(cls, sock, chunks=65536):
"""
Consume a socket until after the HTTP request is sent.
"""
consumed = bytearray()
mark = cls._get_socket_mark(sock, True)
while True:
b = sock.recv(chunks)
if not b:
break
consumed += b
if consumed.endswith(mark):
break
return consumed

@classmethod
def _get_socket_mark(cls, sock, server):
if server:
port = sock.getpeername()[1]
else:
port = sock.getsockname()[1]
return cls.MARK_FORMAT % (port,)
49 changes: 26 additions & 23 deletions test/with_dummyserver/test_chunked_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
from urllib3 import HTTPConnectionPool
from urllib3.util.retry import Retry
from urllib3.util import SUPPRESS_USER_AGENT
from dummyserver.testcase import SocketDummyServerTestCase, consume_socket
from test import notWindows
from dummyserver.testcase import (
SocketDummyServerTestCase,
consume_socket,
ConnectionMarker,
)

# Retry failed tests
pytestmark = pytest.mark.flaky
Expand Down Expand Up @@ -156,56 +159,56 @@ def socket_handler(listener):
sock.close()
assert self.chunked_requests == 2

@notWindows
def test_preserve_chunked_on_redirect(self):
def test_preserve_chunked_on_redirect(self, monkeypatch):
self.chunked_requests = 0

def socket_handler(listener):
for i in range(2):
sock = listener.accept()[0]
request = consume_socket(sock)
request = ConnectionMarker.consume_request(sock)
if b"Transfer-Encoding: chunked" in request.split(b"\r\n"):
self.chunked_requests += 1

if i == 0:
sock.send(
sock.sendall(
b"HTTP/1.1 301 Moved Permanently\r\n"
b"Location: /redirect\r\n\r\n"
)
else:
sock.send(b"HTTP/1.1 200 OK\r\n\r\n")
sock.sendall(b"HTTP/1.1 200 OK\r\n\r\n")
sock.close()

self._start_server(socket_handler)
with HTTPConnectionPool(self.host, self.port) as pool:
retries = Retry(redirect=1)
pool.urlopen(
"GET", "/", chunked=True, preload_content=False, retries=retries
)
with ConnectionMarker.mark(monkeypatch):
with HTTPConnectionPool(self.host, self.port) as pool:
retries = Retry(redirect=1)
pool.urlopen(
"GET", "/", chunked=True, preload_content=False, retries=retries
)
assert self.chunked_requests == 2

@notWindows
def test_preserve_chunked_on_broken_connection(self):
def test_preserve_chunked_on_broken_connection(self, monkeypatch):
self.chunked_requests = 0

def socket_handler(listener):
for i in range(2):
sock = listener.accept()[0]
request = consume_socket(sock)
request = ConnectionMarker.consume_request(sock)
if b"Transfer-Encoding: chunked" in request.split(b"\r\n"):
self.chunked_requests += 1

if i == 0:
# Bad HTTP version will trigger a connection close
sock.send(b"HTTP/0.5 200 OK\r\n\r\n")
sock.sendall(b"HTTP/0.5 200 OK\r\n\r\n")
else:
sock.send(b"HTTP/1.1 200 OK\r\n\r\n")
sock.sendall(b"HTTP/1.1 200 OK\r\n\r\n")
sock.close()

self._start_server(socket_handler)
with HTTPConnectionPool(self.host, self.port) as pool:
retries = Retry(read=1)
pool.urlopen(
"GET", "/", chunked=True, preload_content=False, retries=retries
)
assert self.chunked_requests == 2
with ConnectionMarker.mark(monkeypatch):
with HTTPConnectionPool(self.host, self.port) as pool:
retries = Retry(read=1)
pool.urlopen(
"GET", "/", chunked=True, preload_content=False, retries=retries
)
assert self.chunked_requests == 2

0 comments on commit 1f3ea3b

Please sign in to comment.