Skip to content

Commit

Permalink
Fix ref issue when protocol is in Cython
Browse files Browse the repository at this point in the history
Because `context.run()` doesn't hold reference to the callable, when
e.g. the protocol is written in Cython, the callbacks were not
guaranteed to hold the protocol reference. This PR fixes the issue by
explicitly add a reference before `context.run()` calls.

Refs edgedb/edgedb#2222
  • Loading branch information
fantix committed Feb 15, 2021
1 parent 33c1d6a commit f0b9e65
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 13 deletions.
37 changes: 37 additions & 0 deletions tests/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,43 @@ async def runner():
self.assertIsNone(
self.loop.run_until_complete(connection_lost_called))

def test_context_run_segfault(self):
is_new = False
done = self.loop.create_future()

def server(sock):
sock.sendall(b'hello')

class Protocol(asyncio.Protocol):
def __init__(self):
self.transport = None

def connection_made(self, transport):
self.transport = transport

def data_received(self, data):
try:
self = weakref.ref(self)
nonlocal is_new
if is_new:
done.set_result(data)
else:
is_new = True
new_proto = Protocol()
self().transport.set_protocol(new_proto)
new_proto.connection_made(self().transport)
new_proto.data_received(data)
except Exception as e:
done.set_exception(e)

async def test(addr):
await self.loop.create_connection(Protocol, *addr)
data = await done
self.assertEqual(data, b'hello')

with self.tcp_server(server) as srv:
self.loop.run_until_complete(test(srv.addr))


class Test_UV_TCP(_TestTCP, tb.UVTestCase):

Expand Down
8 changes: 6 additions & 2 deletions uvloop/handles/basetransport.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ cdef class UVBaseTransport(UVSocketHandle):
try:
# _maybe_pause_protocol() is always triggered from user-calls,
# so we must copy the context to avoid entering context twice
self.context.copy().run(self._protocol.pause_writing)
run_in_context(
self.context.copy(), self._protocol.pause_writing,
)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
Expand All @@ -91,7 +93,9 @@ cdef class UVBaseTransport(UVSocketHandle):
# We're copying the context to avoid entering context twice,
# even though it's not always necessary to copy - it's easier
# to copy here than passing down a copied context.
self.context.copy().run(self._protocol.resume_writing)
run_in_context(
self.context.copy(), self._protocol.resume_writing,
)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
Expand Down
16 changes: 12 additions & 4 deletions uvloop/handles/stream.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ cdef class UVStream(UVBaseTransport):
except AttributeError:
keep_open = False
else:
keep_open = self.context.run(meth)
keep_open = run_in_context(self.context, meth)

if keep_open:
# We're keeping the connection open so the
Expand Down Expand Up @@ -826,7 +826,11 @@ cdef inline void __uv_stream_on_read_impl(uv.uv_stream_t* stream,
if UVLOOP_DEBUG:
loop._debug_stream_read_cb_total += 1

sc.context.run(sc._protocol_data_received, loop._recv_buffer[:nread])
run_in_context1(
sc.context,
sc._protocol_data_received,
loop._recv_buffer[:nread],
)
except BaseException as exc:
if UVLOOP_DEBUG:
loop._debug_stream_read_cb_errors_total += 1
Expand Down Expand Up @@ -911,7 +915,11 @@ cdef void __uv_stream_buffered_alloc(uv.uv_handle_t* stream,

sc._read_pybuf_acquired = 0
try:
buf = sc.context.run(sc._protocol_get_buffer, suggested_size)
buf = run_in_context1(
sc.context,
sc._protocol_get_buffer,
suggested_size,
)
PyObject_GetBuffer(buf, pybuf, PyBUF_WRITABLE)
got_buf = 1
except BaseException as exc:
Expand Down Expand Up @@ -976,7 +984,7 @@ cdef void __uv_stream_buffered_on_read(uv.uv_stream_t* stream,
if UVLOOP_DEBUG:
loop._debug_stream_read_cb_total += 1

sc.context.run(sc._protocol_buffer_updated, nread)
run_in_context1(sc.context, sc._protocol_buffer_updated, nread)
except BaseException as exc:
if UVLOOP_DEBUG:
loop._debug_stream_read_cb_errors_total += 1
Expand Down
2 changes: 1 addition & 1 deletion uvloop/handles/streamserver.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ cdef class UVStreamServer(UVSocketHandle):
cdef inline _on_listen(self):
cdef UVStream client

protocol = self.context.run(self.protocol_factory)
protocol = run_in_context(self.context, self.protocol_factory)

if self.ssl is None:
client = self._make_new_transport(protocol, None, self.context)
Expand Down
8 changes: 5 additions & 3 deletions uvloop/handles/udp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -257,16 +257,18 @@ cdef class UDPTransport(UVBaseTransport):

cdef _on_receive(self, bytes data, object exc, object addr):
if exc is None:
self.context.run(self._protocol.datagram_received, data, addr)
run_in_context2(
self.context, self._protocol.datagram_received, data, addr,
)
else:
self.context.run(self._protocol.error_received, exc)
run_in_context1(self.context, self._protocol.error_received, exc)

cdef _on_sent(self, object exc, object context=None):
if exc is not None:
if isinstance(exc, OSError):
if context is None:
context = self.context
context.run(self._protocol.error_received, exc)
run_in_context1(context, self._protocol.error_received, exc)
else:
self._fatal_error(
exc, False, 'Fatal write error on datagram transport')
Expand Down
28 changes: 28 additions & 0 deletions uvloop/loop.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,34 @@ cdef inline socket_dec_io_ref(sock):
sock._decref_socketios()


cdef inline run_in_context(context, method):
# This method is internally used to workaround a reference issue that in
# certain circumstances, inlined context.run() will not hold a reference to
# the given method instance, which - if deallocated - will cause segault.
# See also: edgedb/edgedb#2222
Py_INCREF(method)
try:
return context.run(method)
finally:
Py_DECREF(method)


cdef inline run_in_context1(context, method, arg):
Py_INCREF(method)
try:
return context.run(method, arg)
finally:
Py_DECREF(method)


cdef inline run_in_context2(context, method, arg1, arg2):
Py_INCREF(method)
try:
return context.run(method, arg1, arg2)
finally:
Py_DECREF(method)


# Used for deprecation and removal of `loop.create_datagram_endpoint()`'s
# *reuse_address* parameter
_unset = object()
Expand Down
8 changes: 5 additions & 3 deletions uvloop/sslproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,9 @@ cdef class SSLProtocol:
# inside the upstream callbacks like buffer_updated()
keep_open = self._app_protocol.eof_received()
else:
keep_open = context.run(self._app_protocol.eof_received)
keep_open = run_in_context(
context, self._app_protocol.eof_received,
)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as ex:
Expand All @@ -817,7 +819,7 @@ cdef class SSLProtocol:
# inside the upstream callbacks like buffer_updated()
self._app_protocol.pause_writing()
else:
context.run(self._app_protocol.pause_writing)
run_in_context(context, self._app_protocol.pause_writing)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
Expand All @@ -836,7 +838,7 @@ cdef class SSLProtocol:
# inside the upstream callbacks like resume_writing()
self._app_protocol.resume_writing()
else:
context.run(self._app_protocol.resume_writing)
run_in_context(context, self._app_protocol.resume_writing)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
Expand Down

0 comments on commit f0b9e65

Please sign in to comment.