Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

A proof-of-concept for smart parameter checking in Python.

  • Loading branch information...
commit 83d44dca20180eea9f7b3a657f4e9aef66cbac33 0 parents
@airportyh authored
Showing with 202 additions and 0 deletions.
  1. +73 −0 smart_param_check.py
  2. +129 −0 tests.py
73 smart_param_check.py
@@ -0,0 +1,73 @@
+import inspect
+
+class Args(object):
+ def __init__(self, kwargs, vargs):
+ self.kwargs = kwargs
+ self.vargs = vargs
+
+ def keys(self):
+ return self.kwargs.keys()
+
+ def __eq__(one, other):
+ return one.kwargs == other.kwargs and \
+ one.vargs == other.vargs
+
+ def __ne__(one, other):
+ return not one.__eq__(other)
+
+ def __repr__(self):
+ return "Args(%s, %s)" % (self.kwargs, self.vargs)
+
+ @classmethod
+ def combine(clz, sig, args, kwds):
+ args_by_key = dict()
+ extra = []
+ for i in xrange(len(args)):
+ try:
+ argname = sig.args[i]
+ args_by_key[argname] = args[i]
+ except IndexError:
+ extra.append(args[i])
+ for key, val in kwds.items():
+ if key in args_by_key:
+ raise Exception('TODO')
+ args_by_key[key] = val
+ return Args(args_by_key, extra)
+
+class Mock(object):
+ def __init__(self, func):
+ self.func = func
+ self.sig = inspect.getargspec(self.func)
+ self.required_args = self.get_required_args()
+ self.named_args = self.get_named_args()
+ self.call_args = []
+
+ def __call__(self, *args, **kwds):
+ args = self.combine_args(args, kwds)
+ self.call_args.append(args)
+ if not set(self.required_args).issubset(set(args.keys())):
+ raise Exception('Not all required args are given')
+ if self.sig.keywords is None and \
+ not set(args.keys()).issubset(set(self.named_args)):
+ raise Exception('Unexpected args were given')
+ if len(args.vargs) > 0 and self.sig.varargs is None:
+ raise Exception('Unexpected positional args were given')
+
+ def get_required_args(self):
+ ret = list(self.sig.args)
+ if self.sig.defaults:
+ for df in self.sig.defaults:
+ ret.pop()
+ return ret
+
+ def get_named_args(self):
+ return list(self.sig.args)
+
+ def combine_args(self, args, kwds):
+ return Args.combine(self.sig, args, kwds)
+
+ def assert_called_with(self, *args, **kwds):
+ args = self.combine_args(args, kwds)
+ if self.call_args[0] != args:
+ raise Exception('Expected call with %s but called with %s' %
+ (args, self.call_args[0]))
129 tests.py
@@ -0,0 +1,129 @@
+import unittest
+from smart_param_check import Mock
+class TestSmartParamCheckBasic(unittest.TestCase):
+
+ def setUp(self):
+ def f(a, b):
+ pass
+ self.f = Mock(f)
+
+ def test_positional(self):
+ self.f(1, 2)
+ self.f.assert_called_with(1, 2)
+
+ def test_positional_should_fail(self):
+ self.f(2, 1)
+ self.assertRaises(Exception,
+ lambda: self.f.assert_called_with(1, 2))
+
+ def test_keyword_args(self):
+ self.f(a=1, b=2)
+ self.f.assert_called_with(a=1, b=2)
+
+ def test_keyword_args_should_fail(self):
+ self.f(b=2, a=1)
+ self.assertRaises(Exception,
+ lambda: self.f.assert_called_with(a=2, b=1))
+
+ def test_positional_should_work_w_keyword(self):
+ self.f(1, 2)
+ self.f.assert_called_with(1, b=2)
+ self.f.assert_called_with(b=2, a=1)
+ self.f.assert_called_with(a=1, b=2)
+
+ def test_position_w_keyword_fail(self):
+ self.f(2, 1)
+ self.assertRaises(Exception,
+ lambda: self.f.assert_called_with(a=1, b=2))
+
+ def test_keyword_should_work_w_positional(self):
+ self.f(a=1, b=2)
+ self.f.assert_called_with(1, 2)
+
+ def test_keyword_should_work_w_positional_fail(self):
+ self.f(b=1, a=2)
+ self.assertRaises(Exception,
+ lambda: self.f.assert_called_with(1, 2))
+
+ def test_should_check_required(self):
+ self.assertRaises(Exception,
+ lambda: self.f(1))
+ self.assertRaises(Exception,
+ lambda: self.f(b=2))
+ self.assertRaises(Exception,
+ lambda: self.f(a=1, c=3))
+
+ def test_should_disallow_unexpected(self):
+ self.assertRaises(Exception,
+ lambda: self.f(a=1, b=2, c=3))
+
+ def test_should_disallow_unexpected_positional(self):
+ self.assertRaises(Exception,
+ lambda: self.f(1, 2, 3))
+
+class TestSmartParamCheckOptional(unittest.TestCase):
+ def setUp(self):
+ def f(a, b=None):
+ pass
+ self.f = Mock(f)
+
+ def test_should_allow_omitting_optional(self):
+ self.f(1)
+
+ def test_should_allow_providing_optional(self):
+ self.f(1, 2)
+
+ def test_keywords_allow_omitting(self):
+ self.f(a=1)
+
+ def test_keywords_allow_providing(self):
+ self.f(1, b=2)
+ self.f(a=1, b=2)
+
+ def test_still_check_required(self):
+ self.assertRaises(Exception,
+ lambda: self.f(b=2))
+
+class TestSmartParamCheckVarargs(unittest.TestCase):
+ def setUp(self):
+ def f(a, b=None, *vargs):
+ pass
+ self.f = Mock(f)
+
+ def test_allow_vargs(self):
+ self.f(1,2,3,4)
+ self.f.assert_called_with(1,2,3,4)
+
+ def test_vargs_fail(self):
+ self.f(1,2,3,4)
+ self.assertRaises(Exception,
+ lambda: self.f.assert_called_with(1,2,3))
+ self.assertRaises(Exception,
+ lambda: self.f.assert_called_with(1,2,3,4,5))
+ self.assertRaises(Exception,
+ lambda: self.f.assert_called_with(1,2,3,5))
+
+class TestSmartParamCheckKwargs(unittest.TestCase):
+ def setUp(self):
+ def f(a, b=None, *vargs, **kwargs):
+ pass
+ self.f = Mock(f)
+
+ def test_allow_kwargs(self):
+ self.f(a=1, b=2, c=3)
+ self.f.assert_called_with(a=1, b=2, c=3)
+ self.f.assert_called_with(1, 2, c=3)
+
+ def test_kwargs_fail(self):
+ self.f(a=1, b=2, c=3)
+ self.assertRaises(Exception,
+ lambda: self.f.assert_called_with(1, 2, d=3))
+ self.assertRaises(Exception,
+ lambda: self.f.assert_called_with(1, 2, c=4))
+ self.assertRaises(Exception,
+ lambda: self.f.assert_called_with(1, 2, c=3, d=4))
+ self.assertRaises(Exception,
+ lambda: self.f.assert_called_with(a=1, b=2, c=3, d=4))
+
+if __name__ == '__main__':
+ unittest.main()

0 comments on commit 83d44dc

Please sign in to comment.
Something went wrong with that request. Please try again.