Permalink
Browse files

make assert_raises usable as a context manager.

  • Loading branch information...
1 parent e40cf8c commit 1dc45da073e6c7d92aad6242de3b4be639484f46 @sumeet sumeet committed with Sumeet Agarwal Apr 8, 2012
Showing with 113 additions and 7 deletions.
  1. +77 −0 test/assertions_test.py
  2. +36 −7 testify/assertions.py
View
@@ -1,7 +1,10 @@
+from __future__ import with_statement
+
from testify import TestCase
from testify import assertions
from testify import run
from testify import assert_equal
+from testify import assert_not_reached
class DiffMessageTestCase(TestCase):
@@ -43,5 +46,79 @@ def test_shows_pretty_diff_output(self):
assert False, 'Expected `AssertionError`.'
+class MyException(Exception):
+ pass
+
+class AssertRaisesAsContextManagerTestCase(TestCase):
+
+ def test_fails_when_exception_is_not_raised(self):
+ def exception_should_be_raised():
+ with assertions.assert_raises(MyException):
+ pass
+
+ try:
+ exception_should_be_raised()
+ except AssertionError:
+ pass
+ else:
+ assert_not_reached('AssertionError should have been raised')
+
+ def test_passes_when_exception_is_raised(self):
+ def exception_should_be_raised():
+ with assertions.assert_raises(MyException):
+ raise MyException
+
+ exception_should_be_raised()
+
+ def test_crashes_when_another_exception_class_is_raised(self):
+ def assert_raises_an_exception_and_raise_another():
+ with assertions.assert_raises(MyException):
+ raise ValueError
+
+ try:
+ assert_raises_an_exception_and_raise_another()
+ except ValueError:
+ pass
+ else:
+ raise AssertionError('ValueError should have been raised')
+
+
+class AssertRaisesAsCallableTestCase(TestCase):
+
+ def test_fails_when_exception_is_not_raised(self):
+ raises_nothing = lambda: None
+ try:
+ assertions.assert_raises(ValueError, raises_nothing)
+ except AssertionError:
+ pass
+ else:
+ assert_not_reached('AssertionError should have been raised')
+
+ def test_passes_when_exception_is_raised(self):
+ def raises_value_error():
+ raise ValueError
+ assertions.assert_raises(ValueError, raises_value_error)
+
+ def test_fails_when_wrong_exception_is_raised(self):
+ def raises_value_error():
+ raise ValueError
+ try:
+ assertions.assert_raises(MyException, raises_value_error)
+ except ValueError:
+ pass
+ else:
+ assert_not_reached('ValueError should have been raised')
+
+ def test_callable_is_called_with_all_arguments(self):
+ class GoodArguments(Exception): pass
+ arg1, arg2, kwarg = object(), object(), object()
+ def check_arguments(*args, **kwargs):
+ assert_equal((arg1, arg2), args)
+ assert_equal({'kwarg': kwarg}, kwargs)
+ raise GoodArguments
+ assertions.assert_raises(GoodArguments, check_arguments, arg1, arg2,
+ kwarg=kwarg)
+
+
if __name__ == '__main__':
run()
View
@@ -11,16 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import with_statement
import datetime
import functools
import re
+import contextlib
+
from .utils import stringdiffer
__testify = 1
+
def _val_subtract(val1, val2, dict_subtractor, list_subtractor):
"""
Find the difference between two container types
@@ -122,14 +126,39 @@ def _list_subtract(list1, list2):
return res_list
-def assert_raises(expected_exception_class, callable_obj, *args, **kwargs):
- """Returns true only if the callable raises expected_exception_class"""
+def assert_raises(*args, **kwargs):
+ """Assert an exception is raised as a context manager or by passing in a
+ callable and its arguments.
+
+ As a context manager:
+ >>> with assert_raises(Exception):
+ ... raise Exception
+
+ Pass in a callable:
+ >>> def raise_exception(arg, kwarg=None):
+ ... raise Exception
+ >>> assert_raises(Exception, raise_exception, 1, kwarg=234)
+ """
+ if (len(args) == 1) and not kwargs:
+ return _assert_raises_context_manager(args[0])
+ else:
+ return _assert_raises(*args, **kwargs)
+
+
+@contextlib.contextmanager
+def _assert_raises_context_manager(exception_class):
try:
- callable_obj(*args, **kwargs)
- except expected_exception_class:
- # we got the expected exception
- return True
- assert_not_reached("No exception was raised (expected %s)" % expected_exception_class)
+ yield
+ except exception_class:
+ return
+ else:
+ assert_not_reached("No exception was raised (expected %r)" %
+ exception_class)
+
+
+def _assert_raises(exception_class, callable, *args, **kwargs):
+ with _assert_raises_context_manager(exception_class):
+ callable(*args, **kwargs)
def _diff_message(lhs, rhs):

0 comments on commit 1dc45da

Please sign in to comment.