Skip to content

Commit

Permalink
Return a Future from IOStream methods.
Browse files Browse the repository at this point in the history
This makes it easier to use IOStreams directly from coroutines.

Closes tornadoweb#953.
  • Loading branch information
bdarnell committed Jan 20, 2014
1 parent c51970d commit 66a647d
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 8 deletions.
51 changes: 49 additions & 2 deletions tornado/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,20 @@

import collections
import errno
import functools
import numbers
import os
import socket
import ssl
import sys
import re

from tornado.concurrent import TracebackFuture
from tornado import ioloop
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError
from tornado import stack_context
from tornado.util import bytes_type
from tornado.util import bytes_type, ArgReplacer

try:
from tornado.platform.posix import _set_nonblocking
Expand All @@ -66,6 +68,37 @@ class StreamClosedError(IOError):
pass


def _iostream_return_future(f):
"""Similar to tornado.concurrent.return_future, but the Future will
also raise a StreamClosedError if the stream is closed before
it resolves.
Unlike return_future (and _auth_return_future), no Future will be
returned if a callback is given.
"""
replacer = ArgReplacer(f, 'callback')

@functools.wraps(f)
def wrapper(*args, **kwargs):
if replacer.get_old_value(args, kwargs) is not None:
# If a callaback is present, just call in to the decorated
# function. This is a slight optimization (by not creating a
# Future that is unlikely to be used), but mainly avoids the
# complexity of running the callback in the expected way.
return f(*args, **kwargs)
future = TracebackFuture()
callback, args, kwargs = replacer.replace(
lambda value=None: future.set_result(value),
args, kwargs)
f(*args, **kwargs)
stream = args[0]
stream._pending_futures.add(future)
future.add_done_callback(
lambda fut: stream._pending_futures.discard(fut))
return future
return wrapper


class BaseIOStream(object):
"""A utility class to write to and read from a non-blocking file or socket.
Expand Down Expand Up @@ -102,6 +135,7 @@ def __init__(self, io_loop=None, max_buffer_size=None,
self._state = None
self._pending_callbacks = 0
self._closed = False
self._pending_futures = set()

def fileno(self):
"""Returns the file descriptor for this stream."""
Expand Down Expand Up @@ -142,6 +176,7 @@ def get_fd_error(self):
"""
return None

@_iostream_return_future
def read_until_regex(self, regex, callback):
"""Run ``callback`` when we read the given regex pattern.
Expand All @@ -152,6 +187,7 @@ def read_until_regex(self, regex, callback):
self._read_regex = re.compile(regex)
self._try_inline_read()

@_iostream_return_future
def read_until(self, delimiter, callback):
"""Run ``callback`` when we read the given delimiter.
Expand All @@ -162,6 +198,7 @@ def read_until(self, delimiter, callback):
self._read_delimiter = delimiter
self._try_inline_read()

@_iostream_return_future
def read_bytes(self, num_bytes, callback, streaming_callback=None):
"""Run callback when we read the given number of bytes.
Expand All @@ -176,6 +213,7 @@ def read_bytes(self, num_bytes, callback, streaming_callback=None):
self._streaming_callback = stack_context.wrap(streaming_callback)
self._try_inline_read()

@_iostream_return_future
def read_until_close(self, callback, streaming_callback=None):
"""Reads all data from the socket until it is closed.
Expand All @@ -202,6 +240,7 @@ def read_until_close(self, callback, streaming_callback=None):
self._streaming_callback = stack_context.wrap(streaming_callback)
self._try_inline_read()

@_iostream_return_future
def write(self, data, callback=None):
"""Write the given data to this stream.
Expand Down Expand Up @@ -266,6 +305,10 @@ def _maybe_run_close_callback(self):
# If there are pending callbacks, don't run the close callback
# until they're done (see _maybe_add_error_handler)
if self.closed() and self._pending_callbacks == 0:
# Copy the _pending_futures set because each will remove itself
# from the set as it is closed.
for fut in list(self._pending_futures):
fut.set_exception(StreamClosedError())
if self._close_callback is not None:
cb = self._close_callback
self._close_callback = None
Expand Down Expand Up @@ -704,6 +747,7 @@ def read_from_fd(self):
def write_to_fd(self, data):
return self.socket.send(data)

@_iostream_return_future
def connect(self, address, callback=None, server_hostname=None):
"""Connects the socket to a remote address without blocking.
Expand Down Expand Up @@ -904,7 +948,10 @@ def connect(self, address, callback=None, server_hostname=None):
# has completed.
self._ssl_connect_callback = stack_context.wrap(callback)
self._server_hostname = server_hostname
super(SSLIOStream, self).connect(address, callback=None)
# Note: Since we don't pass our callback argument along to
# super.connect(), this will always return a Future.
# This is harmless, but a bit less efficient than it could be.
return super(SSLIOStream, self).connect(address, callback=None)

def _handle_connect(self):
# When the connection is complete, wrap the socket for SSL
Expand Down
64 changes: 62 additions & 2 deletions tornado/test/iostream_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado import netutil
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream, PipeIOStream
from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError
from tornado.httputil import HTTPHeaders
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket
from tornado.stack_context import NullContext
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
import errno
Expand Down Expand Up @@ -106,6 +107,46 @@ def write_callback():

stream.close()

@gen_test
def test_future_interface(self):
"""Basic test of IOStream's ability to return Futures."""
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
first_line = yield stream.read_until(b"\r\n")
self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n")
# callback=None is equivalent to no callback.
header_data = yield stream.read_until(b"\r\n\r\n", callback=None)
headers = HTTPHeaders.parse(header_data.decode('latin1'))
content_length = int(headers['Content-Length'])
body = yield stream.read_bytes(content_length)
self.assertEqual(body, b'Hello')
stream.close()

@gen_test
def test_future_close_while_reading(self):
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
with self.assertRaises(StreamClosedError):
yield stream.read_bytes(1024 * 1024)
stream.close()

@gen_test
def test_future_read_until_close(self):
# Ensure that the data comes through before the StreamClosedError.
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
yield stream.read_until(b"\r\n\r\n")
body = yield stream.read_until_close()
self.assertEqual(body, b"Hello")

# Nothing else to read; the error comes immediately without waiting
# for yield.
with self.assertRaises(StreamClosedError):
stream.read_bytes(1)


class TestIOStreamMixin(object):
def _make_server_iostream(self, connection, **kwargs):
Expand Down Expand Up @@ -298,6 +339,25 @@ def callback2(data):
server.close()
client.close()

def test_future_delayed_close_callback(self):
# Same as test_delayed_close_callback, but with the future interface.
server, client = self.make_iostream_pair()
# We can't call make_iostream_pair inside a gen_test function
# because the ioloop is not reentrant.
@gen_test
def f(self):
server.write(b"12")
chunks = []
chunks.append((yield client.read_bytes(1)))
server.close()
chunks.append((yield client.read_bytes(1)))
self.assertEqual(chunks, [b"1", b"2"])
try:
f(self)
finally:
server.close()
client.close()

def test_close_buffered_data(self):
# Similar to the previous test, but with data stored in the OS's
# socket buffers instead of the IOStream's read buffer. Out-of-band
Expand Down
16 changes: 12 additions & 4 deletions tornado/test/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,22 @@ def function(x, y, callback=None, z=None):
self.replacer = ArgReplacer(function, 'callback')

def test_omitted(self):
self.assertEqual(self.replacer.replace('new', (1, 2), dict()),
args = (1, 2)
kwargs = dict()
self.assertIs(self.replacer.get_old_value(args, kwargs), None)
self.assertEqual(self.replacer.replace('new', args, kwargs),
(None, (1, 2), dict(callback='new')))

def test_position(self):
self.assertEqual(self.replacer.replace('new', (1, 2, 'old', 3), dict()),
args = (1, 2, 'old', 3)
kwargs = dict()
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', [1, 2, 'new', 3], dict()))

def test_keyword(self):
self.assertEqual(self.replacer.replace('new', (1,),
dict(y=2, callback='old', z=3)),
args = (1,)
kwargs = dict(y=2, callback='old', z=3)
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', (1,), dict(y=2, callback='new', z=3)))
10 changes: 10 additions & 0 deletions tornado/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,16 @@ def __init__(self, func, name):
# Not a positional parameter
self.arg_pos = None

def get_old_value(self, args, kwargs, default=None):
"""Returns the old value of the named argument without replacing it.
Returns ``default`` if the argument is not present.
"""
if self.arg_pos is not None and len(args) > self.arg_pos:
return args[self.arg_pos]
else:
return kwargs.get(self.name, default)

def replace(self, new_value, args, kwargs):
"""Replace the named argument in ``args, kwargs`` with ``new_value``.
Expand Down

0 comments on commit 66a647d

Please sign in to comment.