NetSpec: type-check Function inputs (they must be Top instances) #3153

Merged
merged 1 commit into from Apr 14, 2017
Jump to file or symbol
Failed to load files and symbols.
+12 −0
Split
View
@@ -103,6 +103,10 @@ class Function(object):
def __init__(self, type_name, inputs, params):
self.type_name = type_name
+ for index, input in enumerate(inputs):
+ if not isinstance(input, Top):
+ raise TypeError('%s input %d is not a Top (type is %s)' %
+ (type_name, index, type(input)))
self.inputs = inputs
self.params = params
self.ntop = self.params.get('ntop', 1)
@@ -79,3 +79,11 @@ def test_zero_tops(self):
net_proto = silent_net()
net = self.load_net(net_proto)
self.assertEqual(len(net.forward()), 0)
+
+ def test_type_error(self):
+ """Test that a TypeError is raised when a Function input isn't a Top."""
+ data = L.DummyData(ntop=2) # data is a 2-tuple of Tops
+ r = r"^Silence input 0 is not a Top \(type is <(type|class) 'tuple'>\)$"
+ with self.assertRaisesRegexp(TypeError, r):
+ L.Silence(data, ntop=0) # should raise: data is a tuple, not a Top
+ L.Silence(*data, ntop=0) # shouldn't raise: each elt of data is a Top