Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the hidden exception when STARTTLS is called #101

Merged
merged 16 commits into from
May 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 8 additions & 19 deletions aiosmtpd/controller.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
import os
import socket
import asyncio
import threading

from aiosmtpd.smtp import SMTP
from public import public

try:
from socket import socketpair
except ImportError: # pragma: nocover
from asyncio.windows_utils import socketpair


@public
class Controller:
Expand All @@ -26,17 +20,6 @@ def __init__(self, handler, loop=None, hostname=None, port=8025, *,
self._thread_exception = None
self.ready_timeout = os.getenv(
'AIOSMTPD_CONTROLLER_TIMEOUT', ready_timeout)
# For exiting the loop.
self._rsock, self._wsock = socketpair()
self.loop.add_reader(self._rsock, self._reader)

def _reader(self):
self.loop.remove_reader(self._rsock)
self.loop.stop()
for task in asyncio.Task.all_tasks(self.loop):
task.cancel()
self._rsock.close()
self._wsock.close()

def factory(self):
"""Allow subclasses to customize the handler/server creation."""
Expand All @@ -48,7 +31,7 @@ def _run(self, ready_event):
self.server = self.loop.run_until_complete(
self.loop.create_server(
self.factory, host=self.hostname, port=self.port))
except socket.error as error:
except Exception as error:
self._thread_exception = error
return
self.loop.call_soon(ready_event.set)
Expand All @@ -69,7 +52,13 @@ def start(self):
if self._thread_exception is not None:
raise self._thread_exception

def _stop(self):
self.loop.stop()
for task in asyncio.Task.all_tasks(self.loop):
task.cancel()

def stop(self):
assert self._thread is not None, 'SMTP daemon not running'
self._wsock.send(b'x')
self.loop.call_soon_threadsafe(self._stop)
self._thread.join()
self._thread = None
7 changes: 7 additions & 0 deletions aiosmtpd/docs/NEWS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
NEWS for aiosmtpd
===================

1.1 (201X-XX-XX)
================
* When aiosmtpd handles a ``STARTTLS`` it must arrange for the original
transport to be closed when the wrapped transport is closed. This fixes a
hidden exception which occurs when an EOF is received on the original
tranport after the connection is lost. (Closes #83)

1.0 (2017-05-15)
================
* Release.
Expand Down
42 changes: 25 additions & 17 deletions aiosmtpd/smtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
_has_ssl = sslproto and hasattr(ssl, 'MemoryBIO')


__version__ = '1.0'
__version__ = '1.0+'
__ident__ = 'Python SMTP {}'.format(__version__)
log = logging.getLogger('mail.log')

Expand Down Expand Up @@ -97,6 +97,7 @@ def __init__(self, handler,
self.require_starttls = tls_context and require_starttls
self._tls_handshake_okay = True
self._tls_protocol = None
self._original_transport = None
self.session = None
self.envelope = None
self.transport = None
Expand Down Expand Up @@ -128,15 +129,14 @@ def connection_made(self, transport):
self._set_rset_state()
self.session = self._create_session()
self.session.peer = transport.get_extra_info('peername')
is_instance = (_has_ssl and
isinstance(transport, sslproto._SSLProtocolTransport))
if self.transport is not None and is_instance: # pragma: nopy34
seen_starttls = (_has_ssl and self._original_transport is not None)
if self.transport is not None and seen_starttls: # pragma: nopy34
# It is STARTTLS connection over normal connection.
self._reader._transport = transport
self._writer._transport = transport
self.transport = transport
# Do SSL certificate checking as rfc3207 part 4.1 says.
# Why _extra is protected attribute?
# Why is _extra a protected attribute?
self.session.ssl = self._tls_protocol._extra
handler = getattr(self.event_handler, 'handle_STARTTLS', None)
if handler is None:
Expand All @@ -154,21 +154,28 @@ def connection_made(self, transport):

def connection_lost(self, error):
log.info('%r connection lost', self.session.peer)
# If STARTTLS was issued, then our transport is the SSL protocol
# transport, and we need to close the original transport explicitly,
# otherwise an unexpected eof_received() will be called *after* the
# connection_lost(). At that point the stream reader will already be
# destroyed and we'll get a traceback in super().eof_received() below.
if self._original_transport is not None: # pragma: nopy34
self._original_transport.close()
super().connection_lost(error)
self._writer.close()
self.transport = None

def eof_received(self):
log.info('%r EOF received', self.session.peer)
self._handler_coroutine.cancel()
return super().eof_received()

def _client_connected_cb(self, reader, writer):
# This is redundant since we subclass StreamReaderProtocol, but I like
# the shorter names.
self._reader = reader
self._writer = writer

def eof_received(self):
log.info('%r EOF received', self.session.peer)
self._handler_coroutine.cancel()
return super().eof_received()

def _set_post_data_state(self):
"""Reset state variables to their post-DATA state."""
self.envelope = self._create_envelope()
Expand Down Expand Up @@ -375,14 +382,15 @@ def smtp_STARTTLS(self, arg): # pragma: nopy34
self.tls_context,
None,
server_side=True)
# Reconfigure transport layer.
socket_transport = self.transport
socket_transport._protocol = self._tls_protocol
# Reconfigure protocol layer. Cant understand why app transport is
# protected property, if it MUST be used externally.
# Reconfigure transport layer. Keep a reference to the original
# transport so that we can close it explicitly when the connection is
# lost. XXX BaseTransport.set_protocol() was added in Python 3.5.3 :(
self._original_transport = self.transport
self._original_transport._protocol = self._tls_protocol
# Reconfigure the protocol layer. Why is the app transport a protected
# property, if it MUST be used externally?
self.transport = self._tls_protocol._app_transport
# Start handshake.
self._tls_protocol.connection_made(socket_transport)
self._tls_protocol.connection_made(self._original_transport)

def _strip_command_keyword(self, keyword, arg):
keylen = len(keyword)
Expand Down