Skip to content
Browse files

Merge pull request #5 from allenap/tsize-support

tsize support
  • Loading branch information...
2 parents d2fd0fd + 26c0c79 commit df484eb8c52024f4ddc54d10c9b768decad8c174 @shylent shylent committed Sep 24, 2012
Showing with 297 additions and 58 deletions.
  1. +2 −3 README.markdown
  2. +22 −6 tftp/backend.py
  3. +39 −1 tftp/bootstrap.py
  4. +28 −22 tftp/protocol.py
  5. +4 −0 tftp/session.py
  6. +43 −11 tftp/test/test_backend.py
  7. +65 −7 tftp/test/test_bootstrap.py
  8. +77 −7 tftp/test/test_protocol.py
  9. +2 −0 tftp/test/test_sessions.py
  10. +15 −1 tftp/util.py
View
5 README.markdown
@@ -12,9 +12,8 @@ A Twisted-based TFTP implementation
- netascii transfer mode.
- [RFC2347](http://tools.ietf.org/html/rfc2347) (TFTP Option
Extension) support. *blksize*
-([RFC2348](http://tools.ietf.org/html/rfc2348)) and *timeout* (partial
-support for [RFC2349](http://tools.ietf.org/html/rfc2349)) options are
-supported.
+([RFC2348](http://tools.ietf.org/html/rfc2348)), *timeout* and *tsize*
+([RFC2349](http://tools.ietf.org/html/rfc2349)) options are supported.
- An actual TFTP server.
- Plugin for twistd.
- Tests
View
28 tftp/backend.py
@@ -1,7 +1,9 @@
'''
@author: shylent
'''
+from os import fstat
from tftp.errors import Unsupported, FileExists, AccessViolation, FileNotFound
+from tftp.util import deferred
from twisted.python.filepath import FilePath, InsecurePath
import shutil
import tempfile
@@ -32,7 +34,7 @@ def get_reader(file_name):
@raise BackendError: for any other errors, that were encountered while
attempting to construct a reader
- @return: an object, that provides L{IReader}
+ @return: a L{Deferred} that will fire with an L{IReader}
"""
@@ -55,13 +57,16 @@ def get_writer(file_name):
@raise BackendError: for any other errors, that were encountered while
attempting to construct a writer
- @return: an object, that provides L{IWriter}
+ @return: a L{Deferred} that will fire with an L{IWriter}
"""
class IReader(interface.Interface):
"""An object, that performs reads on request of the TFTP protocol"""
+ size = interface.Attribute(
+ "The size of the file to be read, or C{None} if it's not known.")
+
def read(size):
"""Attempt to read C{size} number of bytes.
@@ -130,6 +135,17 @@ def __init__(self, file_path):
raise FileNotFound(self.file_path)
self.state = 'active'
+ @property
+ def size(self):
+ """
+ @see: L{IReader.size}
+
+ """
+ if self.file_obj.closed:
+ return None
+ else:
+ return fstat(self.file_obj.fileno()).st_size
+
def read(self, size):
"""
@see: L{IReader.read}
@@ -240,12 +256,12 @@ def __init__(self, base_path, can_read=True, can_write=True):
self.base = FilePath(base_path)
self.can_read, self.can_write = can_read, can_write
+ @deferred
def get_reader(self, file_name):
"""
@see: L{IBackend.get_reader}
- @return: an object, providing L{IReader}
- @rtype: L{FilesystemReader}
+ @rtype: L{Deferred}, yielding a L{FilesystemReader}
"""
if not self.can_read:
@@ -256,12 +272,12 @@ def get_reader(self, file_name):
raise AccessViolation("Insecure path: %s" % e)
return FilesystemReader(target_path)
+ @deferred
def get_writer(self, file_name):
"""
@see: L{IBackend.get_writer}
- @return: an object, providing L{IWriter}
- @rtype: L{FilesystemWriter}
+ @rtype: L{Deferred}, yielding a L{FilesystemWriter}
"""
if not self.can_write:
View
40 tftp/bootstrap.py
@@ -46,7 +46,7 @@ class TFTPBootstrap(DatagramProtocol):
@type backend: L{IReader} or L{IWriter} provider
"""
- supported_options = ('blksize', 'timeout')
+ supported_options = ('blksize', 'timeout', 'tsize')
def __init__(self, remote, backend, options=None, _clock=None):
if options is None:
@@ -124,6 +124,25 @@ def option_timeout(self, val):
return None
return str(int_timeout)
+ def option_tsize(self, val):
+ """Process tsize interval option
+ (U{RFC2349<http://tools.ietf.org/html/rfc2349>}). Valid range is 0 and up.
+
+ @param val: value of the option
+ @type val: C{str}
+
+ @return: accepted option value or C{None}, if it is invalid
+ @rtype: C{str} or C{None}
+
+ """
+ try:
+ int_tsize = int(val)
+ except ValueError:
+ return None
+ if int_tsize < 0:
+ return None
+ return str(int_tsize)
+
def applyOptions(self, session, options):
"""Apply given options mapping to the given L{WriteSession} or
L{ReadSession} object.
@@ -141,6 +160,9 @@ def applyOptions(self, session, options):
elif opt_name == 'timeout':
timeout = int(opt_val)
session.timeout = (timeout,) * 3
+ elif opt_name == 'tsize':
+ tsize = int(opt_val)
+ session.tsize = tsize
def datagramReceived(self, datagram, addr):
if self.remote[1] != addr[1]:
@@ -328,6 +350,22 @@ def __init__(self, remote, reader, options=None, _clock=None):
TFTPBootstrap.__init__(self, remote, reader, options, _clock)
self.session = ReadSession(reader, self._clock)
+ def option_tsize(self, val):
+ """Process tsize option.
+
+ If tsize is zero, get the size of the file to be read so that it can
+ be returned in the OACK datagram.
+
+ @see: L{TFTPBootstrap.option_tsize}
+
+ """
+ val = TFTPBootstrap.option_tsize(self, val)
+ if val == str(0):
+ val = self.session.reader.size
+ if val is not None:
+ val = str(val)
+ return val
+
def startProtocol(self):
"""Start sending an OACK datagram if we were initialized with options
or start the L{ReadSession} immediately.
View
50 tftp/protocol.py
@@ -9,6 +9,7 @@
FileNotFound)
from tftp.netascii import NetasciiReceiverProxy, NetasciiSenderProxy
from twisted.internet import reactor
+from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.protocol import DatagramProtocol
from twisted.python import log
@@ -42,34 +43,39 @@ def datagramReceived(self, datagram, addr):
return self.transport.write(ERRORDatagram.from_code(ERR_ILLEGAL_OP,
"Unknown transfer mode %s, - expected "
"'netascii' or 'octet' (case-insensitive)" % mode).to_wire(), addr)
+
+ self._clock.callLater(0, self._startSession, datagram, addr, mode)
+
+ @inlineCallbacks
+ def _startSession(self, datagram, addr, mode):
try:
if datagram.opcode == OP_WRQ:
- fs_interface = self.backend.get_writer(datagram.filename)
+ fs_interface = yield self.backend.get_writer(datagram.filename)
elif datagram.opcode == OP_RRQ:
- fs_interface = self.backend.get_reader(datagram.filename)
+ fs_interface = yield self.backend.get_reader(datagram.filename)
except Unsupported, e:
- return self.transport.write(ERRORDatagram.from_code(ERR_ILLEGAL_OP,
+ self.transport.write(ERRORDatagram.from_code(ERR_ILLEGAL_OP,
str(e)).to_wire(), addr)
except AccessViolation:
- return self.transport.write(ERRORDatagram.from_code(ERR_ACCESS_VIOLATION).to_wire(), addr)
+ self.transport.write(ERRORDatagram.from_code(ERR_ACCESS_VIOLATION).to_wire(), addr)
except FileExists:
- return self.transport.write(ERRORDatagram.from_code(ERR_FILE_EXISTS).to_wire(), addr)
+ self.transport.write(ERRORDatagram.from_code(ERR_FILE_EXISTS).to_wire(), addr)
except FileNotFound:
- return self.transport.write(ERRORDatagram.from_code(ERR_FILE_NOT_FOUND).to_wire(), addr)
+ self.transport.write(ERRORDatagram.from_code(ERR_FILE_NOT_FOUND).to_wire(), addr)
except BackendError, e:
- return self.transport.write(ERRORDatagram.from_code(ERR_NOT_DEFINED, str(e)).to_wire(), addr)
-
- if datagram.opcode == OP_WRQ:
- if mode == 'netascii':
- fs_interface = NetasciiReceiverProxy(fs_interface)
- session = RemoteOriginWriteSession(addr, fs_interface,
- datagram.options, _clock=self._clock)
- reactor.listenUDP(0, session)
- return session
- elif datagram.opcode == OP_RRQ:
- if mode == 'netascii':
- fs_interface = NetasciiSenderProxy(fs_interface)
- session = RemoteOriginReadSession(addr, fs_interface,
- datagram.options, _clock=self._clock)
- reactor.listenUDP(0, session)
- return session
+ self.transport.write(ERRORDatagram.from_code(ERR_NOT_DEFINED, str(e)).to_wire(), addr)
+ else:
+ if datagram.opcode == OP_WRQ:
+ if mode == 'netascii':
+ fs_interface = NetasciiReceiverProxy(fs_interface)
+ session = RemoteOriginWriteSession(addr, fs_interface,
+ datagram.options, _clock=self._clock)
+ reactor.listenUDP(0, session)
+ returnValue(session)
+ elif datagram.opcode == OP_RRQ:
+ if mode == 'netascii':
+ fs_interface = NetasciiSenderProxy(fs_interface)
+ session = RemoteOriginReadSession(addr, fs_interface,
+ datagram.options, _clock=self._clock)
+ reactor.listenUDP(0, session)
+ returnValue(session)
View
4 tftp/session.py
@@ -34,6 +34,7 @@ class WriteSession(DatagramProtocol):
block_size = 512
timeout = (1, 3, 7)
+ tsize = None
def __init__(self, writer, _clock=None):
self.writer = writer
@@ -124,6 +125,9 @@ def blockWriteSuccess(self, ign, datagram):
if len(datagram.data) < self.block_size:
self.completed = True
self.writer.finish()
+ # TODO: If self.tsize is not None, compare it with the actual
+ # count of bytes written. Log if there's a mismatch. Should it
+ # also emit an error datagram?
def blockWriteFailure(self, failure):
"""Write failed"""
View
54 tftp/test/test_backend.py
@@ -2,9 +2,10 @@
@author: shylent
'''
from tftp.backend import (FilesystemSynchronousBackend, FilesystemReader,
- FilesystemWriter)
+ FilesystemWriter, IReader, IWriter)
from tftp.errors import Unsupported, AccessViolation, FileNotFound, FileExists
from twisted.python.filepath import FilePath
+from twisted.internet.defer import inlineCallbacks
from twisted.trial import unittest
import os.path
import shutil
@@ -24,21 +25,35 @@ def setUp(self):
with open(self.existing_file_name, 'w') as f:
f.write(self.test_data)
- def test_unsupported(self):
+ @inlineCallbacks
+ def test_read_supported_by_default(self):
+ b = FilesystemSynchronousBackend(self.temp_dir)
+ reader = yield b.get_reader('foo')
+ self.assertTrue(IReader.providedBy(reader))
+
+ @inlineCallbacks
+ def test_write_supported_by_default(self):
+ b = FilesystemSynchronousBackend(self.temp_dir)
+ writer = yield b.get_writer('bar')
+ self.assertTrue(IWriter.providedBy(writer))
+
+ def test_read_unsupported(self):
b = FilesystemSynchronousBackend(self.temp_dir, can_read=False)
- self.assertRaises(Unsupported, b.get_reader, 'foo')
- self.assert_(b.get_writer('bar'),
- "A writer should be dispatched")
+ return self.assertFailure(b.get_reader('foo'), Unsupported)
+
+ def test_write_unsupported(self):
b = FilesystemSynchronousBackend(self.temp_dir, can_write=False)
- self.assertRaises(Unsupported, b.get_writer, 'bar')
- self.assert_(b.get_reader('foo'),
- "A reader should be dispatched")
+ return self.assertFailure(b.get_writer('bar'), Unsupported)
- def test_insecure(self):
+ def test_insecure_reader(self):
b = FilesystemSynchronousBackend(self.temp_dir)
- self.assertRaises(AccessViolation, b.get_reader, '../foo')
+ return self.assertFailure(
+ b.get_reader('../foo'), AccessViolation)
+
+ def test_insecure_writer(self):
b = FilesystemSynchronousBackend(self.temp_dir)
- self.assertRaises(AccessViolation, b.get_writer, '../foo')
+ return self.assertFailure(
+ b.get_writer('../foo'), AccessViolation)
def tearDown(self):
shutil.rmtree(self.temp_dir)
@@ -73,11 +88,28 @@ def test_read_existing_file(self):
"The file has been exhausted and should be in the closed state")
self.assertEqual(ostring, self.test_data)
+ def test_size(self):
+ r = FilesystemReader(self.temp_dir.child('foo'))
+ self.assertEqual(len(self.test_data), r.size)
+
+ def test_size_when_reader_finished(self):
+ r = FilesystemReader(self.temp_dir.child('foo'))
+ r.finish()
+ self.assertIsNone(r.size)
+
+ def test_size_when_file_removed(self):
+ # FilesystemReader.size uses fstat() to discover the file's size, so
+ # the absence of the file does not matter.
+ r = FilesystemReader(self.temp_dir.child('foo'))
+ self.existing_file_name.remove()
+ self.assertEqual(len(self.test_data), r.size)
+
def test_cancel(self):
r = FilesystemReader(self.temp_dir.child('foo'))
r.read(3)
r.finish()
self.failUnless(r.file_obj.closed,
+
"The session has been finished, so the file object should be in the closed state")
r.finish()
View
72 tftp/test/test_bootstrap.py
@@ -47,6 +47,7 @@ def active(self):
class MockSession(object):
block_size = 512
timeout = (1, 3, 5)
+ tsize = None
# Testing implementation here, but if I don't, I'll have a TON of duplicate code
class TestOptionProcessing(unittest.TestCase):
@@ -123,6 +124,27 @@ def test_timeout(self):
self.assertEqual(self.s.timeout, (1, 3, 5))
self.assertEqual(opts, OrderedDict())
+ def test_tsize(self):
+ self.s = MockSession()
+ opts = self.proto.processOptions(OrderedDict({'tsize':'1'}))
+ self.proto.applyOptions(self.s, opts)
+ self.assertEqual(self.s.tsize, 1)
+ self.assertEqual(opts, OrderedDict({'tsize':'1'}))
+
+ def test_tsize_ignored_when_not_a_number(self):
+ self.s = MockSession()
+ opts = self.proto.processOptions(OrderedDict({'tsize':'foo'}))
+ self.proto.applyOptions(self.s, opts)
+ self.assertIsNone(self.s.tsize)
+ self.assertEqual(opts, OrderedDict({}))
+
+ def test_tsize_ignored_when_less_than_zero(self):
+ self.s = MockSession()
+ opts = self.proto.processOptions(OrderedDict({'tsize':'-1'}))
+ self.proto.applyOptions(self.s, opts)
+ self.assertIsNone(self.s.tsize)
+ self.assertEqual(opts, OrderedDict({}))
+
def test_multiple_options(self):
got_options = OrderedDict()
got_options['timeout'] = '123'
@@ -321,14 +343,18 @@ def setUp(self):
self.target = FilePath(self.tmp_dir_path).child('foo')
self.writer = DelayedWriter(self.target, _clock=self.clock, delay=2)
self.transport = FakeTransport(hostAddress=('127.0.0.1', self.port))
- self.ws = RemoteOriginWriteSession(('127.0.0.1', 65465), self.writer,
- options={'blksize':'9'}, _clock=self.clock)
+ self.options = OrderedDict()
+ self.options['blksize'] = '9'
+ self.options['tsize'] = '45'
+ self.ws = RemoteOriginWriteSession(
+ ('127.0.0.1', 65465), self.writer, options=self.options,
+ _clock=self.clock)
self.ws.transport = self.transport
def test_option_normal(self):
self.ws.startProtocol()
self.clock.advance(0.1)
- oack_datagram = OACKDatagram({'blksize':'9'}).to_wire()
+ oack_datagram = OACKDatagram(self.options).to_wire()
self.assertEqual(self.transport.value(), oack_datagram)
self.clock.advance(3)
self.assertEqual(self.transport.value(), oack_datagram * 2)
@@ -351,7 +377,7 @@ def test_option_normal(self):
def test_option_timeout(self):
self.ws.startProtocol()
self.clock.advance(0.1)
- oack_datagram = OACKDatagram({'blksize':'9'}).to_wire()
+ oack_datagram = OACKDatagram(self.options).to_wire()
self.assertEqual(self.transport.value(), oack_datagram)
self.failIf(self.transport.disconnecting)
@@ -367,6 +393,22 @@ def test_option_timeout(self):
self.assertEqual(self.transport.value(), oack_datagram * 3)
self.failUnless(self.transport.disconnecting)
+ def test_option_tsize(self):
+ # A tsize option sent as part of a write session is recorded.
+ self.ws.startProtocol()
+ self.clock.advance(0.1)
+ oack_datagram = OACKDatagram(self.options).to_wire()
+ self.assertEqual(self.transport.value(), oack_datagram)
+ self.failIf(self.transport.disconnecting)
+ self.assertIsInstance(self.ws.session, WriteSession)
+ # Options are not applied to the WriteSession until the first DATA
+ # datagram is received,
+ self.assertIsNone(self.ws.session.tsize)
+ self.ws.datagramReceived(
+ DATADatagram(1, 'foobarbaz').to_wire(), ('127.0.0.1', 65465))
+ # The tsize option has been applied to the WriteSession.
+ self.assertEqual(45, self.ws.session.tsize)
+
def tearDown(self):
shutil.rmtree(self.tmp_dir_path)
@@ -529,14 +571,17 @@ def setUp(self):
temp_fd.write(self.test_data)
self.reader = DelayedReader(self.target, _clock=self.clock, delay=2)
self.transport = FakeTransport(hostAddress=('127.0.0.1', self.port))
+ self.options = OrderedDict()
+ self.options['blksize'] = '9'
+ self.options['tsize'] = '34'
self.rs = RemoteOriginReadSession(('127.0.0.1', 65465), self.reader,
- options={'blksize':'9'}, _clock=self.clock)
+ options=self.options, _clock=self.clock)
self.rs.transport = self.transport
def test_option_normal(self):
self.rs.startProtocol()
self.clock.advance(0.1)
- oack_datagram = OACKDatagram({'blksize':'9'}).to_wire()
+ oack_datagram = OACKDatagram(self.options).to_wire()
self.assertEqual(self.transport.value(), oack_datagram)
self.clock.advance(3)
self.assertEqual(self.transport.value(), oack_datagram * 2)
@@ -551,7 +596,7 @@ def test_option_normal(self):
def test_option_timeout(self):
self.rs.startProtocol()
self.clock.advance(0.1)
- oack_datagram = OACKDatagram({'blksize':'9'}).to_wire()
+ oack_datagram = OACKDatagram(self.options).to_wire()
self.assertEqual(self.transport.value(), oack_datagram)
self.failIf(self.transport.disconnecting)
@@ -567,5 +612,18 @@ def test_option_timeout(self):
self.assertEqual(self.transport.value(), oack_datagram * 3)
self.failUnless(self.transport.disconnecting)
+ def test_option_tsize(self):
+ # A tsize option of 0 sent as part of a read session prompts a tsize
+ # response with the actual size of the file.
+ self.options['tsize'] = '0'
+ self.rs.startProtocol()
+ self.clock.advance(0.1)
+ self.transport.clear()
+ self.clock.advance(3)
+ # The response contains the size of the test data.
+ self.options['tsize'] = str(len(self.test_data))
+ oack_datagram = OACKDatagram(self.options).to_wire()
+ self.assertEqual(self.transport.value(), oack_datagram)
+
def tearDown(self):
shutil.rmtree(self.tmp_dir_path)
View
84 tftp/test/test_protocol.py
@@ -1,7 +1,7 @@
'''
@author: shylent
'''
-from tftp.backend import FilesystemSynchronousBackend
+from tftp.backend import FilesystemSynchronousBackend, IReader, IWriter
from tftp.bootstrap import RemoteOriginWriteSession, RemoteOriginReadSession
from tftp.datagram import (WRQDatagram, TFTPDatagramFactory, split_opcode,
ERR_ILLEGAL_OP, RRQDatagram, ERR_ACCESS_VIOLATION, ERR_FILE_EXISTS,
@@ -72,12 +72,14 @@ def test_unsupported(self):
tftp.transport = self.transport
wrq_datagram = WRQDatagram('foobar', 'netascii', {})
tftp.datagramReceived(wrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_ILLEGAL_OP)
self.transport.clear()
rrq_datagram = RRQDatagram('foobar', 'octet', {})
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_ILLEGAL_OP)
@@ -86,12 +88,14 @@ def test_access_violation(self):
tftp.transport = self.transport
wrq_datagram = WRQDatagram('foobar', 'netascii', {})
tftp.datagramReceived(wrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_ACCESS_VIOLATION)
self.transport.clear()
rrq_datagram = RRQDatagram('foobar', 'octet', {})
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_ACCESS_VIOLATION)
@@ -100,6 +104,7 @@ def test_file_exists(self):
tftp.transport = self.transport
wrq_datagram = WRQDatagram('foobar', 'netascii', {})
tftp.datagramReceived(wrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_FILE_EXISTS)
@@ -108,6 +113,7 @@ def test_file_not_found(self):
tftp.transport = self.transport
rrq_datagram = RRQDatagram('foobar', 'netascii', {})
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_FILE_NOT_FOUND)
@@ -116,12 +122,14 @@ def test_generic_backend_error(self):
tftp.transport = self.transport
rrq_datagram = RRQDatagram('foobar', 'netascii', {})
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_NOT_DEFINED)
self.transport.clear()
rrq_datagram = RRQDatagram('foobar', 'octet', {})
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_NOT_DEFINED)
@@ -135,8 +143,15 @@ def startProtocol(self):
class TFTPWrapper(TFTP):
- def datagramReceived(self, *args, **kwargs):
- self.session = TFTP.datagramReceived(self, *args, **kwargs)
+ def _startSession(self, *args, **kwargs):
+ d = TFTP._startSession(self, *args, **kwargs)
+
+ def save_session(session):
+ self.session = session
+ return session
+
+ d.addCallback(save_session)
+ return d
class SuccessfulDispatch(unittest.TestCase):
@@ -156,8 +171,8 @@ def test_WRQ(self):
self.client.transport.write(WRQDatagram('foobar', 'NetASCiI', {}).to_wire(), ('127.0.0.1', 1069))
d = Deferred()
def cb(ign):
- self.failUnless(isinstance(self.tftp.session, RemoteOriginWriteSession))
- self.failUnless(isinstance(self.tftp.session.backend, NetasciiReceiverProxy))
+ self.assertIsInstance(self.tftp.session, RemoteOriginWriteSession)
+ self.assertIsInstance(self.tftp.session.backend, NetasciiReceiverProxy)
self.tftp.session.cancel()
d.addCallback(cb)
reactor.callLater(0.5, d.callback, None)
@@ -167,8 +182,8 @@ def test_RRQ(self):
self.client.transport.write(RRQDatagram('nonempty', 'NetASCiI', {}).to_wire(), ('127.0.0.1', 1069))
d = Deferred()
def cb(ign):
- self.failUnless(isinstance(self.tftp.session, RemoteOriginReadSession))
- self.failUnless(isinstance(self.tftp.session.backend, NetasciiSenderProxy))
+ self.assertIsInstance(self.tftp.session, RemoteOriginReadSession)
+ self.assertIsInstance(self.tftp.session.backend, NetasciiSenderProxy)
self.tftp.session.cancel()
d.addCallback(cb)
reactor.callLater(0.5, d.callback, None)
@@ -177,3 +192,58 @@ def cb(ign):
def tearDown(self):
self.tftp.transport.stopListening()
self.client.transport.stopListening()
+
+
+class FilesystemAsyncBackend(FilesystemSynchronousBackend):
+
+ def __init__(self, base_path, clock):
+ super(FilesystemAsyncBackend, self).__init__(
+ base_path, can_read=True, can_write=True)
+ self.clock = clock
+
+ def get_reader(self, file_name):
+ d_get = super(FilesystemAsyncBackend, self).get_reader(file_name)
+ d = Deferred()
+ # d_get has already fired, so don't chain d_get to d until later,
+ # otherwise d will be fired too early.
+ self.clock.callLater(0, d_get.chainDeferred, d)
+ return d
+
+ def get_writer(self, file_name):
+ d_get = super(FilesystemAsyncBackend, self).get_writer(file_name)
+ d = Deferred()
+ # d_get has already fired, so don't chain d_get to d until later,
+ # otherwise d will be fired too early.
+ self.clock.callLater(0, d_get.chainDeferred, d)
+ return d
+
+
+class SuccessfulAsyncDispatch(unittest.TestCase):
+
+ def setUp(self):
+ self.clock = Clock()
+ self.tmp_dir_path = tempfile.mkdtemp()
+ with FilePath(self.tmp_dir_path).child('nonempty').open('w') as fd:
+ fd.write('Something uninteresting')
+ self.backend = FilesystemAsyncBackend(self.tmp_dir_path, self.clock)
+ self.tftp = TFTP(self.backend, self.clock)
+
+ def test_get_reader_defers(self):
+ rrq_datagram = RRQDatagram('nonempty', 'NetASCiI', {})
+ rrq_addr = ('127.0.0.1', 1069)
+ rrq_mode = "octet"
+ d = self.tftp._startSession(rrq_datagram, rrq_addr, rrq_mode)
+ self.assertFalse(d.called)
+ self.clock.advance(1)
+ self.assertTrue(d.called)
+ self.assertTrue(IReader.providedBy(d.result.backend))
+
+ def test_get_writer_defers(self):
+ wrq_datagram = WRQDatagram('foobar', 'NetASCiI', {})
+ wrq_addr = ('127.0.0.1', 1069)
+ wrq_mode = "octet"
+ d = self.tftp._startSession(wrq_datagram, wrq_addr, wrq_mode)
+ self.assertFalse(d.called)
+ self.clock.advance(1)
+ self.assertTrue(d.called)
+ self.assertTrue(IWriter.providedBy(d.result.backend))
View
2 tftp/test/test_sessions.py
@@ -54,6 +54,8 @@ def c(ign):
class FailingReader(object):
interface.implements(IReader)
+ size = None
+
def read(self, size):
raise IOError('A failure')
View
16 tftp/util.py
@@ -1,10 +1,12 @@
'''
@author: shylent
'''
+from functools import wraps
from twisted.internet import reactor
+from twisted.internet.defer import maybeDeferred
-__all__ = ['SequentialCall', 'Spent', 'Cancelled']
+__all__ = ['SequentialCall', 'Spent', 'Cancelled', 'deferred']
class Spent(Exception):
@@ -120,3 +122,15 @@ def cancel(self):
def active(self):
"""Whether or not this L{SequentialCall} object is considered active"""
return not (self._spent or self._cancelled)
+
+
+def deferred(func):
+ """Decorates a function to ensure that it always returns a `Deferred`.
+
+ This also serves a secondary documentation purpose; functions decorated
+ with this are readily identifiable as asynchronous.
+ """
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ return maybeDeferred(func, *args, **kwargs)
+ return wrapper

0 comments on commit df484eb

Please sign in to comment.
Something went wrong with that request. Please try again.