diff --git a/plugin_examples.py b/plugin_examples.py index 051c611176..cdd1f66946 100644 --- a/plugin_examples.py +++ b/plugin_examples.py @@ -55,10 +55,11 @@ def before_upstream_connection(self, request: HttpParser) -> Optional[HttpParser def handle_client_request(self, request: HttpParser) -> Optional[HttpParser]: if request.host and proxy.DOT not in request.host: if request.host in self.SHORT_LINKS: + path = proxy.SLASH if not request.path else request.path self.client.queue(proxy.build_http_response( proxy.httpStatusCodes.SEE_OTHER, reason=b'See Other', headers={ - b'Location': b'http://' + self.SHORT_LINKS[request.host] + request.path, + b'Location': b'http://' + self.SHORT_LINKS[request.host] + path, b'Content-Length': b'0', b'Connection': b'close', } diff --git a/tests.py b/tests.py index b430c95748..dddcd2a2f9 100644 --- a/tests.py +++ b/tests.py @@ -294,6 +294,7 @@ def setUp( mock_protocol_handler, config=self.protocol_config) + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') @mock.patch('proxy.recv_handle') @@ -301,7 +302,8 @@ def test_continues_when_no_events( self, mock_recv_handle: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: fileno = 10 conn = mock.MagicMock() addr = mock.MagicMock() @@ -317,6 +319,7 @@ def test_continues_when_no_events( sock.accept.assert_not_called() self.mock_protocol_handler.assert_not_called() + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') @mock.patch('proxy.recv_handle') @@ -324,7 +327,8 @@ def test_worker_doesnt_teardown_on_blocking_io_error( self, mock_recv_handle: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: fileno = 10 conn = mock.MagicMock() addr = mock.MagicMock() @@ -340,6 +344,7 @@ def test_worker_doesnt_teardown_on_blocking_io_error( self.mock_protocol_handler.assert_not_called() + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') @mock.patch('proxy.recv_handle') @@ -347,7 +352,8 @@ def test_accepts_client_from_server_socket( self, mock_recv_handle: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: fileno = 10 conn = mock.MagicMock() addr = mock.MagicMock() @@ -967,9 +973,13 @@ def test_handshake(self, mock_connect: mock.Mock, mock_b64encode: mock.Mock) -> class TestHttpProtocolHandler(unittest.TestCase): + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') - def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + def setUp(self, + mock_fromfd: mock.Mock, + mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: self.fileno = 10 self._addr = ('127.0.0.1', 54382) self._conn = mock_fromfd.return_value @@ -982,6 +992,7 @@ def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: self.mock_selector = mock_selector self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=self.config) + mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() @mock.patch('proxy.TcpServerConnection') @@ -1095,10 +1106,14 @@ def test_proxy_connection_failed(self) -> None: self.proxy.run_once() self.assertEqual(self.proxy.client.buffer, proxy.ProxyConnectionFailed.RESPONSE_PKT) + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_proxy_authentication_failed( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + self, + mock_fromfd: mock.Mock, + mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: self._conn = mock_fromfd.return_value self.mock_selector_for_client_read(mock_selector) config = proxy.ProtocolConfig( @@ -1108,6 +1123,7 @@ def test_proxy_authentication_failed( b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=config) + mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() self._conn.recv.return_value = proxy.CRLF.join([ b'GET http://abhinavsingh.com HTTP/1.1', @@ -1119,13 +1135,15 @@ def test_proxy_authentication_failed( self.proxy.client.buffer, proxy.ProxyAuthenticationFailed.RESPONSE_PKT) + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') @mock.patch('proxy.TcpServerConnection') def test_authenticated_proxy_http_get( self, mock_server_connection: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: self._conn = mock_fromfd.return_value self.mock_selector_for_client_read(mock_selector) @@ -1141,6 +1159,7 @@ def test_authenticated_proxy_http_get( self.proxy = proxy.ProtocolHandler( self.fileno, addr=self._addr, config=config) + mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() assert self.http_server_port is not None @@ -1167,13 +1186,15 @@ def test_authenticated_proxy_http_get( ]) self.assert_data_queued(mock_server_connection, server) + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') @mock.patch('proxy.TcpServerConnection') def test_authenticated_proxy_http_tunnel( self, mock_server_connection: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: server = mock_server_connection.return_value server.connect.return_value = True server.buffer_size.return_value = 0 @@ -1188,6 +1209,7 @@ def test_authenticated_proxy_http_tunnel( self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=config) + mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() assert self.http_server_port is not None @@ -1281,9 +1303,10 @@ def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None: class TestWebServerPlugin(unittest.TestCase): + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') - def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, mock_os_close: mock.Mock) -> None: self.fileno = 10 self._addr = ('127.0.0.1', 54382) self._conn = mock_fromfd.return_value @@ -1293,16 +1316,20 @@ def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=self.config) + mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_pac_file_served_from_disk( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: pac_file = 'proxy.pac' self._conn = mock_fromfd.return_value self.mock_selector_for_client_read(mock_selector) self.init_and_make_pac_file_request(pac_file) + mock_os_close.assert_called_with(self.fileno) self.proxy.run_once() self.assertEqual( self.proxy.request.state, @@ -1315,14 +1342,17 @@ def test_pac_file_served_from_disk( }, body=f.read() )) + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_pac_file_served_from_buffer( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: self._conn = mock_fromfd.return_value self.mock_selector_for_client_read(mock_selector) pac_file_content = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }' self.init_and_make_pac_file_request(proxy.text_(pac_file_content)) + mock_os_close.assert_called_with(self.fileno) self.proxy.run_once() self.assertEqual( self.proxy.request.state, @@ -1334,10 +1364,12 @@ def test_pac_file_served_from_buffer( }, body=pac_file_content )) + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_default_web_server_returns_404( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: self._conn = mock_fromfd.return_value mock_selector.return_value.select.return_value = [( selectors.SelectorKey( @@ -1350,6 +1382,7 @@ def test_default_web_server_returns_404( b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=config) + mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() self._conn.recv.return_value = proxy.CRLF.join([ b'GET /hello HTTP/1.1', @@ -1363,10 +1396,12 @@ def test_default_web_server_returns_404( self.proxy.client.buffer, proxy.HttpWebServerPlugin.DEFAULT_404_RESPONSE) + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_static_web_server_serves( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: # Setup a static directory static_server_dir = os.path.join(tempfile.gettempdir(), 'static') index_file_path = os.path.join(static_server_dir, 'index.html') @@ -1418,10 +1453,14 @@ def test_static_web_server_serves( body=html_file_content )) + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_static_web_server_serves_404( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: + self, + mock_fromfd: mock.Mock, + mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: self._conn = mock_fromfd.return_value self._conn.recv.return_value = proxy.build_http_request(b'GET', b'/not-found.html') @@ -1443,6 +1482,7 @@ def test_static_web_server_serves_404( self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=config) + mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() self.proxy.run_once() @@ -1453,15 +1493,17 @@ def test_static_web_server_serves_404( self.assertEqual(self._conn.send.call_args[0][0], proxy.HttpWebServerPlugin.DEFAULT_404_RESPONSE) + @mock.patch('os.close') @mock.patch('socket.fromfd') def test_on_client_connection_called_on_teardown( - self, mock_fromfd: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_os_close: mock.Mock) -> None: config = proxy.ProtocolConfig() plugin = mock.MagicMock() config.plugins = {b'ProtocolHandlerPlugin': [plugin]} self._conn = mock_fromfd.return_value self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=config) + mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() plugin.assert_called() with mock.patch.object(self.proxy, 'run_once') as mock_run_once: @@ -1493,11 +1535,13 @@ def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None: class TestHttpProxyPlugin(unittest.TestCase): + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def setUp(self, mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: self.mock_fromfd = mock_fromfd self.mock_selector = mock_selector @@ -1512,6 +1556,7 @@ def setUp(self, self._conn = mock_fromfd.return_value self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=self.config) + mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() def test_proxy_plugin_initialized(self) -> None: @@ -1566,11 +1611,13 @@ def test_proxy_plugin_before_upstream_connection_can_teardown( class TestHttpProxyPluginExamples(unittest.TestCase): + @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def setUp(self, mock_fromfd: mock.Mock, - mock_selector: mock.Mock) -> None: + mock_selector: mock.Mock, + mock_os_close: mock.Mock) -> None: self.fileno = 10 self._addr = ('127.0.0.1', 54382) self.config = proxy.ProtocolConfig() @@ -1588,6 +1635,7 @@ def setUp(self, self._conn = mock_fromfd.return_value self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=self.config) + mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() @mock.patch('proxy.TcpServerConnection') @@ -1788,6 +1836,7 @@ def closed() -> bool: class TestHttpProxyTlsInterception(unittest.TestCase): + @mock.patch('os.close') @mock.patch('ssl.wrap_socket') @mock.patch('ssl.create_default_context') @mock.patch('proxy.TcpServerConnection') @@ -1801,7 +1850,8 @@ def test_e2e( mock_popen: mock.Mock, mock_server_conn: mock.Mock, mock_ssl_context: mock.Mock, - mock_ssl_wrap: mock.Mock) -> None: + mock_ssl_wrap: mock.Mock, + mock_os_close: mock.Mock) -> None: host, port = uuid.uuid4().hex, 443 netloc = '{0}:{1}'.format(host, port) @@ -1841,6 +1891,7 @@ def mock_connection() -> Any: self._conn = mock_fromfd.return_value self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=self.config) + mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() self.plugin.assert_called() @@ -1919,6 +1970,7 @@ def mock_connection() -> Any: class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase): + @mock.patch('os.close') @mock.patch('ssl.wrap_socket') @mock.patch('ssl.create_default_context') @mock.patch('proxy.TcpServerConnection') @@ -1931,7 +1983,8 @@ def setUp(self, mock_popen: mock.Mock, mock_server_conn: mock.Mock, mock_ssl_context: mock.Mock, - mock_ssl_wrap: mock.Mock) -> None: + mock_ssl_wrap: mock.Mock, + mock_os_close: mock.Mock) -> None: self.mock_fromfd = mock_fromfd self.mock_selector = mock_selector self.mock_popen = mock_popen @@ -1957,6 +2010,7 @@ def setUp(self, mock_fromfd.return_value = self._conn self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=self.config) + mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() self.server = self.mock_server_conn.return_value