Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure websocket transport is closed when client does not close it (#8200) #8257

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGES/8200.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Ensure websocket transport is closed when client does not close it
-- by :user:`bdraco`.

The transport could remain open if the client did not close it. This
change ensures the transport is closed when the client does not close
it.
21 changes: 16 additions & 5 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ async def close(
return True

if self._closing:
self._close_transport()
return True

reader = self._reader
Expand All @@ -418,9 +419,18 @@ async def close(
self._exception = asyncio.TimeoutError()
return True

def _set_closing(self, code: WSCloseCode) -> None:
"""Set the close code and mark the connection as closing."""
self._closing = True
self._close_code = code

def _set_code_close_transport(self, code: WSCloseCode) -> None:
"""Set the close code and close the transport."""
self._close_code = code
self._close_transport()

def _close_transport(self) -> None:
"""Close the transport."""
if self._req is not None and self._req.transport is not None:
self._req.transport.close()

Expand Down Expand Up @@ -465,14 +475,12 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._closing = True
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._set_closing(WSCloseCode.ABNORMAL_CLOSURE)
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)

if msg.type == WSMsgType.CLOSE:
self._closing = True
self._close_code = msg.data
self._set_closing(msg.data)
# Could be closed while awaiting reader.
if not self._closed and self._autoclose:
# The client is likely going to close the
Expand All @@ -481,7 +489,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
# likely result writing to a broken pipe.
await self.close(drain=False)
elif msg.type == WSMsgType.CLOSING:
self._closing = True
self._set_closing(WSCloseCode.OK)
elif msg.type == WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
Expand Down Expand Up @@ -525,5 +533,8 @@ async def __anext__(self) -> WSMessage:
return msg

def _cancel(self, exc: BaseException) -> None:
# web_protocol calls this from connection_lost
# or when the server is shutting down.
self._closing = True
if self._reader is not None:
set_exception(self._reader, exc)
89 changes: 88 additions & 1 deletion tests/test_web_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,94 @@ async def test_receive_eofstream_in_reader(make_request, loop) -> None:
assert ws.closed


async def test_receive_timeouterror(make_request, loop) -> None:
async def test_receive_exception_in_reader(make_request: Any, loop: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)

ws._reader = mock.Mock()
exc = Exception()
res = loop.create_future()
res.set_exception(exc)
ws._reader.read = make_mocked_coro(res)
ws._payload_writer.drain = mock.Mock()
ws._payload_writer.drain.return_value = loop.create_future()
ws._payload_writer.drain.return_value.set_result(True)

msg = await ws.receive()
assert msg.type == WSMsgType.ERROR
assert ws.closed
assert len(ws._req.transport.close.mock_calls) == 1


async def test_receive_close_but_left_open(make_request: Any, loop: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
close_message = WSMessage(WSMsgType.CLOSE, 1000, "close")

ws._reader = mock.Mock()
ws._reader.read = mock.AsyncMock(return_value=close_message)
ws._payload_writer.drain = mock.Mock()
ws._payload_writer.drain.return_value = loop.create_future()
ws._payload_writer.drain.return_value.set_result(True)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSE
assert ws.closed
assert len(ws._req.transport.close.mock_calls) == 1


async def test_receive_closing(make_request: Any, loop: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
closing_message = WSMessage(WSMsgType.CLOSING, 1000, "closing")

ws._reader = mock.Mock()
read_mock = mock.AsyncMock(return_value=closing_message)
ws._reader.read = read_mock
ws._payload_writer.drain = mock.Mock()
ws._payload_writer.drain.return_value = loop.create_future()
ws._payload_writer.drain.return_value.set_result(True)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSING
assert not ws.closed

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSING
assert not ws.closed

ws._cancel(ConnectionResetError("Connection lost"))

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSING


async def test_close_after_closing(make_request: Any, loop: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
closing_message = WSMessage(WSMsgType.CLOSING, 1000, "closing")

ws._reader = mock.Mock()
ws._reader.read = mock.AsyncMock(return_value=closing_message)
ws._payload_writer.drain = mock.Mock()
ws._payload_writer.drain.return_value = loop.create_future()
ws._payload_writer.drain.return_value.set_result(True)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSING
assert not ws.closed
assert len(ws._req.transport.close.mock_calls) == 0

await ws.close()
assert ws.closed
assert len(ws._req.transport.close.mock_calls) == 1


async def test_receive_timeouterror(make_request: Any, loop: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
Expand Down
33 changes: 31 additions & 2 deletions tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# HTTP websocket server functional tests

import asyncio
from typing import Any, Optional

import pytest

Expand Down Expand Up @@ -258,7 +259,7 @@ async def handler(request):
assert "reply" == (await ws.receive_str())

# The server closes here. Then the client sends bogus messages with an
# internval shorter than server-side close timeout, to make the server
# interval shorter than server-side close timeout, to make the server
# hanging indefinitely.
await asyncio.sleep(0.08)
msg = await ws._reader.read()
Expand Down Expand Up @@ -310,8 +311,36 @@ async def handler(request):
assert msg.type == WSMsgType.CLOSED


async def test_auto_pong_with_closing_by_peer(loop, aiohttp_client) -> None:
async def test_close_op_code_from_client(loop: Any, aiohttp_client: Any) -> None:
srv_ws: Optional[web.WebSocketResponse] = None

async def handler(request):
nonlocal srv_ws
ws = srv_ws = web.WebSocketResponse(protocols=("foo", "bar"))
await ws.prepare(request)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSE
await asyncio.sleep(0)
return ws

app = web.Application()
app.router.add_get("/", handler)
client = await aiohttp_client(app)

ws: web.WebSocketResponse = await client.ws_connect("/", protocols=("eggs", "bar"))

await ws._writer._send_frame(b"", WSMsgType.CLOSE)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSE

await asyncio.sleep(0)
msg = await ws.receive()
assert msg.type == WSMsgType.CLOSED


async def test_auto_pong_with_closing_by_peer(loop: Any, aiohttp_client: Any) -> None:
closed = loop.create_future()

async def handler(request):
Expand Down
Loading