Skip to content

Commit

Permalink
Ensure no leaks
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavsingh committed Oct 16, 2019
1 parent d4fe97f commit 0d3040f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 19 deletions.
3 changes: 2 additions & 1 deletion plugin_examples.py
Expand Up @@ -55,10 +55,11 @@ def before_upstream_connection(self, request: HttpParser) -> Optional[HttpParser
def handle_client_request(self, request: HttpParser) -> Optional[HttpParser]:
if request.host and proxy.DOT not in request.host:
if request.host in self.SHORT_LINKS:
path = proxy.SLASH if not request.path else request.path
self.client.queue(proxy.build_http_response(
proxy.httpStatusCodes.SEE_OTHER, reason=b'See Other',
headers={
b'Location': b'http://' + self.SHORT_LINKS[request.host] + request.path,
b'Location': b'http://' + self.SHORT_LINKS[request.host] + path,
b'Content-Length': b'0',
b'Connection': b'close',
}
Expand Down
90 changes: 72 additions & 18 deletions tests.py
Expand Up @@ -294,14 +294,16 @@ def setUp(
mock_protocol_handler,
config=self.protocol_config)

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
@mock.patch('proxy.recv_handle')
def test_continues_when_no_events(
self,
mock_recv_handle: mock.Mock,
mock_fromfd: mock.Mock,
mock_selector: mock.Mock) -> None:
mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
fileno = 10
conn = mock.MagicMock()
addr = mock.MagicMock()
Expand All @@ -317,14 +319,16 @@ def test_continues_when_no_events(
sock.accept.assert_not_called()
self.mock_protocol_handler.assert_not_called()

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
@mock.patch('proxy.recv_handle')
def test_worker_doesnt_teardown_on_blocking_io_error(
self,
mock_recv_handle: mock.Mock,
mock_fromfd: mock.Mock,
mock_selector: mock.Mock) -> None:
mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
fileno = 10
conn = mock.MagicMock()
addr = mock.MagicMock()
Expand All @@ -340,14 +344,16 @@ def test_worker_doesnt_teardown_on_blocking_io_error(

self.mock_protocol_handler.assert_not_called()

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
@mock.patch('proxy.recv_handle')
def test_accepts_client_from_server_socket(
self,
mock_recv_handle: mock.Mock,
mock_fromfd: mock.Mock,
mock_selector: mock.Mock) -> None:
mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
fileno = 10
conn = mock.MagicMock()
addr = mock.MagicMock()
Expand Down Expand Up @@ -967,9 +973,13 @@ def test_handshake(self, mock_connect: mock.Mock, mock_b64encode: mock.Mock) ->

class TestHttpProtocolHandler(unittest.TestCase):

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
def setUp(self,
mock_fromfd: mock.Mock,
mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
self.fileno = 10
self._addr = ('127.0.0.1', 54382)
self._conn = mock_fromfd.return_value
Expand All @@ -982,6 +992,7 @@ def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
self.mock_selector = mock_selector
self.proxy = proxy.ProtocolHandler(
self.fileno, self._addr, config=self.config)
mock_os_close.assert_called_with(self.fileno)
self.proxy.initialize()

@mock.patch('proxy.TcpServerConnection')
Expand Down Expand Up @@ -1095,10 +1106,14 @@ def test_proxy_connection_failed(self) -> None:
self.proxy.run_once()
self.assertEqual(self.proxy.client.buffer, proxy.ProxyConnectionFailed.RESPONSE_PKT)

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
def test_proxy_authentication_failed(
self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
self,
mock_fromfd: mock.Mock,
mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
self._conn = mock_fromfd.return_value
self.mock_selector_for_client_read(mock_selector)
config = proxy.ProtocolConfig(
Expand All @@ -1108,6 +1123,7 @@ def test_proxy_authentication_failed(
b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
self.proxy = proxy.ProtocolHandler(
self.fileno, self._addr, config=config)
mock_os_close.assert_called_with(self.fileno)
self.proxy.initialize()
self._conn.recv.return_value = proxy.CRLF.join([
b'GET http://abhinavsingh.com HTTP/1.1',
Expand All @@ -1119,13 +1135,15 @@ def test_proxy_authentication_failed(
self.proxy.client.buffer,
proxy.ProxyAuthenticationFailed.RESPONSE_PKT)

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
@mock.patch('proxy.TcpServerConnection')
def test_authenticated_proxy_http_get(
self, mock_server_connection: mock.Mock,
mock_fromfd: mock.Mock,
mock_selector: mock.Mock) -> None:
mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
self._conn = mock_fromfd.return_value
self.mock_selector_for_client_read(mock_selector)

Expand All @@ -1141,6 +1159,7 @@ def test_authenticated_proxy_http_get(

self.proxy = proxy.ProtocolHandler(
self.fileno, addr=self._addr, config=config)
mock_os_close.assert_called_with(self.fileno)
self.proxy.initialize()
assert self.http_server_port is not None

Expand All @@ -1167,13 +1186,15 @@ def test_authenticated_proxy_http_get(
])
self.assert_data_queued(mock_server_connection, server)

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
@mock.patch('proxy.TcpServerConnection')
def test_authenticated_proxy_http_tunnel(
self, mock_server_connection: mock.Mock,
mock_fromfd: mock.Mock,
mock_selector: mock.Mock) -> None:
mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
server = mock_server_connection.return_value
server.connect.return_value = True
server.buffer_size.return_value = 0
Expand All @@ -1188,6 +1209,7 @@ def test_authenticated_proxy_http_tunnel(

self.proxy = proxy.ProtocolHandler(
self.fileno, self._addr, config=config)
mock_os_close.assert_called_with(self.fileno)
self.proxy.initialize()

assert self.http_server_port is not None
Expand Down Expand Up @@ -1281,9 +1303,10 @@ def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None:

class TestWebServerPlugin(unittest.TestCase):

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, mock_os_close: mock.Mock) -> None:
self.fileno = 10
self._addr = ('127.0.0.1', 54382)
self._conn = mock_fromfd.return_value
Expand All @@ -1293,16 +1316,20 @@ def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
self.proxy = proxy.ProtocolHandler(
self.fileno, self._addr, config=self.config)
mock_os_close.assert_called_with(self.fileno)
self.proxy.initialize()

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
def test_pac_file_served_from_disk(
self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
self, mock_fromfd: mock.Mock, mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
pac_file = 'proxy.pac'
self._conn = mock_fromfd.return_value
self.mock_selector_for_client_read(mock_selector)
self.init_and_make_pac_file_request(pac_file)
mock_os_close.assert_called_with(self.fileno)
self.proxy.run_once()
self.assertEqual(
self.proxy.request.state,
Expand All @@ -1315,14 +1342,17 @@ def test_pac_file_served_from_disk(
}, body=f.read()
))

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
def test_pac_file_served_from_buffer(
self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
self, mock_fromfd: mock.Mock, mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
self._conn = mock_fromfd.return_value
self.mock_selector_for_client_read(mock_selector)
pac_file_content = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }'
self.init_and_make_pac_file_request(proxy.text_(pac_file_content))
mock_os_close.assert_called_with(self.fileno)
self.proxy.run_once()
self.assertEqual(
self.proxy.request.state,
Expand All @@ -1334,10 +1364,12 @@ def test_pac_file_served_from_buffer(
}, body=pac_file_content
))

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
def test_default_web_server_returns_404(
self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
self, mock_fromfd: mock.Mock, mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
self._conn = mock_fromfd.return_value
mock_selector.return_value.select.return_value = [(
selectors.SelectorKey(
Expand All @@ -1350,6 +1382,7 @@ def test_default_web_server_returns_404(
b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin')
self.proxy = proxy.ProtocolHandler(
self.fileno, self._addr, config=config)
mock_os_close.assert_called_with(self.fileno)
self.proxy.initialize()
self._conn.recv.return_value = proxy.CRLF.join([
b'GET /hello HTTP/1.1',
Expand All @@ -1363,10 +1396,12 @@ def test_default_web_server_returns_404(
self.proxy.client.buffer,
proxy.HttpWebServerPlugin.DEFAULT_404_RESPONSE)

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
def test_static_web_server_serves(
self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
self, mock_fromfd: mock.Mock, mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
# Setup a static directory
static_server_dir = os.path.join(tempfile.gettempdir(), 'static')
index_file_path = os.path.join(static_server_dir, 'index.html')
Expand Down Expand Up @@ -1418,10 +1453,14 @@ def test_static_web_server_serves(
body=html_file_content
))

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
def test_static_web_server_serves_404(
self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
self,
mock_fromfd: mock.Mock,
mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
self._conn = mock_fromfd.return_value
self._conn.recv.return_value = proxy.build_http_request(b'GET', b'/not-found.html')

Expand All @@ -1443,6 +1482,7 @@ def test_static_web_server_serves_404(

self.proxy = proxy.ProtocolHandler(
self.fileno, self._addr, config=config)
mock_os_close.assert_called_with(self.fileno)
self.proxy.initialize()

self.proxy.run_once()
Expand All @@ -1453,15 +1493,17 @@ def test_static_web_server_serves_404(
self.assertEqual(self._conn.send.call_args[0][0],
proxy.HttpWebServerPlugin.DEFAULT_404_RESPONSE)

@mock.patch('os.close')
@mock.patch('socket.fromfd')
def test_on_client_connection_called_on_teardown(
self, mock_fromfd: mock.Mock) -> None:
self, mock_fromfd: mock.Mock, mock_os_close: mock.Mock) -> None:
config = proxy.ProtocolConfig()
plugin = mock.MagicMock()
config.plugins = {b'ProtocolHandlerPlugin': [plugin]}
self._conn = mock_fromfd.return_value
self.proxy = proxy.ProtocolHandler(
self.fileno, self._addr, config=config)
mock_os_close.assert_called_with(self.fileno)
self.proxy.initialize()
plugin.assert_called()
with mock.patch.object(self.proxy, 'run_once') as mock_run_once:
Expand Down Expand Up @@ -1493,11 +1535,13 @@ def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None:

class TestHttpProxyPlugin(unittest.TestCase):

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
def setUp(self,
mock_fromfd: mock.Mock,
mock_selector: mock.Mock) -> None:
mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
self.mock_fromfd = mock_fromfd
self.mock_selector = mock_selector

Expand All @@ -1512,6 +1556,7 @@ def setUp(self,
self._conn = mock_fromfd.return_value
self.proxy = proxy.ProtocolHandler(
self.fileno, self._addr, config=self.config)
mock_os_close.assert_called_with(self.fileno)
self.proxy.initialize()

def test_proxy_plugin_initialized(self) -> None:
Expand Down Expand Up @@ -1566,11 +1611,13 @@ def test_proxy_plugin_before_upstream_connection_can_teardown(

class TestHttpProxyPluginExamples(unittest.TestCase):

@mock.patch('os.close')
@mock.patch('selectors.DefaultSelector')
@mock.patch('socket.fromfd')
def setUp(self,
mock_fromfd: mock.Mock,
mock_selector: mock.Mock) -> None:
mock_selector: mock.Mock,
mock_os_close: mock.Mock) -> None:
self.fileno = 10
self._addr = ('127.0.0.1', 54382)
self.config = proxy.ProtocolConfig()
Expand All @@ -1588,6 +1635,7 @@ def setUp(self,
self._conn = mock_fromfd.return_value
self.proxy = proxy.ProtocolHandler(
self.fileno, self._addr, config=self.config)
mock_os_close.assert_called_with(self.fileno)
self.proxy.initialize()

@mock.patch('proxy.TcpServerConnection')
Expand Down Expand Up @@ -1788,6 +1836,7 @@ def closed() -> bool:

class TestHttpProxyTlsInterception(unittest.TestCase):

@mock.patch('os.close')
@mock.patch('ssl.wrap_socket')
@mock.patch('ssl.create_default_context')
@mock.patch('proxy.TcpServerConnection')
Expand All @@ -1801,7 +1850,8 @@ def test_e2e(
mock_popen: mock.Mock,
mock_server_conn: mock.Mock,
mock_ssl_context: mock.Mock,
mock_ssl_wrap: mock.Mock) -> None:
mock_ssl_wrap: mock.Mock,
mock_os_close: mock.Mock) -> None:
host, port = uuid.uuid4().hex, 443
netloc = '{0}:{1}'.format(host, port)

Expand Down Expand Up @@ -1841,6 +1891,7 @@ def mock_connection() -> Any:
self._conn = mock_fromfd.return_value
self.proxy = proxy.ProtocolHandler(
self.fileno, self._addr, config=self.config)
mock_os_close.assert_called_with(self.fileno)
self.proxy.initialize()

self.plugin.assert_called()
Expand Down Expand Up @@ -1919,6 +1970,7 @@ def mock_connection() -> Any:

class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase):

@mock.patch('os.close')
@mock.patch('ssl.wrap_socket')
@mock.patch('ssl.create_default_context')
@mock.patch('proxy.TcpServerConnection')
Expand All @@ -1931,7 +1983,8 @@ def setUp(self,
mock_popen: mock.Mock,
mock_server_conn: mock.Mock,
mock_ssl_context: mock.Mock,
mock_ssl_wrap: mock.Mock) -> None:
mock_ssl_wrap: mock.Mock,
mock_os_close: mock.Mock) -> None:
self.mock_fromfd = mock_fromfd
self.mock_selector = mock_selector
self.mock_popen = mock_popen
Expand All @@ -1957,6 +2010,7 @@ def setUp(self,
mock_fromfd.return_value = self._conn
self.proxy = proxy.ProtocolHandler(
self.fileno, self._addr, config=self.config)
mock_os_close.assert_called_with(self.fileno)
self.proxy.initialize()

self.server = self.mock_server_conn.return_value
Expand Down

0 comments on commit 0d3040f

Please sign in to comment.