From 5ca24410658b6d8825ae75f943b472ab6d04e134 Mon Sep 17 00:00:00 2001 From: Daniel Nelson Date: Wed, 4 May 2016 15:20:04 -0700 Subject: [PATCH] Fix leak of connection slot during connection error --- aiohttp/connector.py | 63 +++++++++++++++++++++++------------------ tests/test_connector.py | 21 ++++++++++++-- 2 files changed, 55 insertions(+), 29 deletions(-) diff --git a/aiohttp/connector.py b/aiohttp/connector.py index c261037f7db..d578c9827c7 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -297,27 +297,33 @@ def connect(self, req): # This connection will now count towards the limit. waiters.append(fut) - yield from fut - - transport, proto = self._get(key) - if transport is None: - try: - if self._conn_timeout: - transport, proto = yield from asyncio.wait_for( - self._create_connection(req), - self._conn_timeout, loop=self._loop) - else: - transport, proto = yield from self._create_connection(req) - - except asyncio.TimeoutError as exc: - raise ClientTimeoutError( - 'Connection timeout to host {0[0]}:{0[1]} ssl:{0[2]}' - .format(key)) from exc - except OSError as exc: - raise ClientOSError( - exc.errno, - 'Cannot connect to host {0[0]}:{0[1]} ssl:{0[2]} [{1}]' - .format(key, exc.strerror)) from exc + try: + if limit is not None: + yield from fut + + transport, proto = self._get(key) + if transport is None: + try: + if self._conn_timeout: + transport, proto = yield from asyncio.wait_for( + self._create_connection(req), + self._conn_timeout, loop=self._loop) + else: + transport, proto = \ + yield from self._create_connection(req) + + except asyncio.TimeoutError as exc: + raise ClientTimeoutError( + 'Connection timeout to host {0[0]}:{0[1]} ssl:{0[2]}' + .format(key)) from exc + except OSError as exc: + raise ClientOSError( + exc.errno, + 'Cannot connect to host {0[0]}:{0[1]} ssl:{0[2]} [{1}]' + .format(key, exc.strerror)) from exc + except: + self._release_waiter(key) + raise self._acquired[key].add(transport) conn = Connection(self, key, req, transport, proto, self._loop) @@ -344,6 +350,14 @@ def _get(self, key): del self._conns[key] return None, None + def _release_waiter(self, key): + waiters = self._waiters[key] + while waiters: + waiter = waiters.pop(0) + if not waiter.done(): + waiter.set_result(None) + break + def _release(self, key, req, transport, protocol, *, should_close=False): if self._closed: # acquired connection is already released on connector closing @@ -358,12 +372,7 @@ def _release(self, key, req, transport, protocol, *, should_close=False): pass else: if self._limit is not None and len(acquired) < self._limit: - waiters = self._waiters[key] - while waiters: - waiter = waiters.pop(0) - if not waiter.done(): - waiter.set_result(None) - break + self._release_waiter(key) resp = req.response diff --git a/tests/test_connector.py b/tests/test_connector.py index b56f95b39c7..564f42ebe1c 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -560,11 +560,28 @@ class Req: # limit exhausted yield from asyncio.wait_for(conn.connect(Req), 0.01, loop=self.loop) - connection.close() - self.loop.run_until_complete(go()) + def test_connect_with_limit_release_waiters(self): + + def check_with_exc(err): + conn = aiohttp.BaseConnector(limit=1, loop=self.loop) + conn._create_connection = unittest.mock.Mock() + conn._create_connection.return_value = \ + asyncio.Future(loop=self.loop) + conn._create_connection.return_value.set_exception(err) + + with self.assertRaises(Exception): + req = unittest.mock.Mock() + self.loop.run_until_complete(conn.connect(req)) + key = (req.host, req.port, req.ssl) + self.assertFalse(conn._waiters[key]) + + check_with_exc(OSError(1, 'permission error')) + check_with_exc(RuntimeError()) + check_with_exc(asyncio.TimeoutError()) + def test_connect_with_limit_concurrent(self): @asyncio.coroutine