diff --git a/gns3server/utils/asyncio/telnet_server.py b/gns3server/utils/asyncio/telnet_server.py index 608ce6846..96dd60eaf 100644 --- a/gns3server/utils/asyncio/telnet_server.py +++ b/gns3server/utils/asyncio/telnet_server.py @@ -53,20 +53,23 @@ NAWS = 31 # Negotiate About Window Size LINEMO = 34 # Line Mode +READ_SIZE = 1024 + class AsyncioTelnetServer: def __init__(self, reader=None, writer=None): self._reader = reader self._writer = writer - self._clients = {} + self._clients = set() + self._lock = asyncio.Lock() + self._reader_process = None + self._current_read = None @asyncio.coroutine def run(self, network_reader, network_writer): - READ_SIZE = 1024 - - # Keep track of - self._clients[network_reader] = network_writer + # Keep track of connected clients + self._clients.add(network_writer) try: # Send initial telnet session opening @@ -76,42 +79,66 @@ def run(self, network_reader, network_writer): IAC, DO, BINARY])) yield from network_writer.drain() - network_read = asyncio.async(network_reader.read(READ_SIZE)) - reader_read = asyncio.async(self._reader.read(READ_SIZE)) + yield from self._process(network_reader, network_writer) + except ConnectionResetError: + with (yield from self._lock): + if self._reader_process == network_reader: + self._reader_process = None + # Cancel current read from this reader + self._current_read.cancel() + self._clients.remove(network_writer) + + @asyncio.coroutine + def _get_reader(self, network_reader): + """ + Get a reader or None if another reader is already reading. + """ + with (yield from self._lock): + if self._reader_process is None: + self._reader_process = network_reader + if self._reader_process == network_reader: + self._current_read = asyncio.async(self._reader.read(READ_SIZE)) + return self._current_read + print(network_reader) + return None + + @asyncio.coroutine + def _process(self, network_reader, network_writer): + network_read = asyncio.async(network_reader.read(READ_SIZE)) + reader_read = yield from self._get_reader(network_reader) - while True: + while True: + if reader_read is None: + reader_read = yield from self._get_reader(network_reader) + if reader_read is None: + done, pending = yield from asyncio.wait( + [ + network_read, + ], + timeout=1, + return_when=asyncio.FIRST_COMPLETED) + else: done, pending = yield from asyncio.wait( [ network_read, reader_read ], return_when=asyncio.FIRST_COMPLETED) - for coro in done: - try: - data = coro.result() - # Raise if another process is reading the same - # datas - except RuntimeError: - continue - if coro == network_read: - network_read = asyncio.async(network_reader.read(READ_SIZE)) - if IAC in data: - data = yield from self._IAC_parser(data, network_reader, network_writer) - if self._writer: - self._writer.write(data) - yield from self._writer.drain() - elif coro == reader_read: - reader_read = asyncio.async(self._reader.read(READ_SIZE)) - network_writer.write(data) - yield from network_writer.drain() - # Replicate the output on other clients - for writer in self._clients.values(): - if writer != network_writer: - writer.write(data) - yield from writer.drain() - except ConnectionResetError: - del self._clients[network_reader] - return + for coro in done: + data = coro.result() + if coro == network_read: + network_read = asyncio.async(network_reader.read(READ_SIZE)) + if IAC in data: + data = yield from self._IAC_parser(data, network_reader, network_writer) + if self._writer: + self._writer.write(data) + yield from self._writer.drain() + elif coro == reader_read: + reader_read = yield from self._get_reader(network_reader) + # Replicate the output on all clients + for writer in self._clients: + writer.write(data) + yield from writer.drain() def _IAC_parser(self, buf, network_reader, network_writer): """