From 8b8791b662c0f62a574a09f305cd204dfb0a6a05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 9 Aug 2023 17:55:46 +0300 Subject: [PATCH] Fixed bare `raise` and exception chaining when a handler raises an exception (#71) --- CHANGES.rst | 4 ++++ src/exceptiongroup/_catch.py | 21 ++++++++++++++++----- tests/test_catch.py | 36 ++++++++++++++++++++++++++++++++++-- tests/test_catch_py311.py | 34 ++++++++++++++++++++++++++++++++-- 4 files changed, 86 insertions(+), 9 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index eccb5b3..366fca9 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,10 @@ This library adheres to `Semantic Versioning 2.0 `_. - `catch()` now raises a `TypeError` if passed an async exception handler instead of just giving a `RuntimeWarning` about the coroutine never being awaited. (#66, PR by John Litborn) +- Fixed plain ``raise`` statement in an exception handler callback to work like a + ``raise`` in an ``except*`` block +- Fixed new exception group not being chained to the original exception when raising an + exception group from exceptions raised in handler callbacks **1.1.2** diff --git a/src/exceptiongroup/_catch.py b/src/exceptiongroup/_catch.py index 2d82be1..b76f51f 100644 --- a/src/exceptiongroup/_catch.py +++ b/src/exceptiongroup/_catch.py @@ -34,7 +34,16 @@ def __exit__( elif unhandled is None: return True else: - raise unhandled from None + if isinstance(exc, BaseExceptionGroup): + try: + raise unhandled from exc.__cause__ + except BaseExceptionGroup: + # Change __context__ to __cause__ because Python 3.11 does this + # too + unhandled.__context__ = exc.__cause__ + raise + + raise unhandled from exc return False @@ -50,7 +59,12 @@ def handle_exception(self, exc: BaseException) -> BaseException | None: matched, excgroup = excgroup.split(exc_types) if matched: try: - result = handler(matched) + try: + raise matched + except BaseExceptionGroup: + result = handler(matched) + except BaseExceptionGroup as new_exc: + new_exceptions.extend(new_exc.exceptions) except BaseException as new_exc: new_exceptions.append(new_exc) else: @@ -67,9 +81,6 @@ def handle_exception(self, exc: BaseException) -> BaseException | None: if len(new_exceptions) == 1: return new_exceptions[0] - if excgroup: - new_exceptions.append(excgroup) - return BaseExceptionGroup("", new_exceptions) elif ( excgroup and len(excgroup.exceptions) == 1 and excgroup.exceptions[0] is exc diff --git a/tests/test_catch.py b/tests/test_catch.py index 586fd43..7fb93c1 100644 --- a/tests/test_catch.py +++ b/tests/test_catch.py @@ -148,9 +148,41 @@ def test_catch_handler_raises(): def handler(exc): raise RuntimeError("new") - with pytest.raises(RuntimeError, match="new"): + with pytest.raises(RuntimeError, match="new") as exc: with catch({(ValueError, ValueError): handler}): - raise ExceptionGroup("booboo", [ValueError("bar")]) + excgrp = ExceptionGroup("booboo", [ValueError("bar")]) + raise excgrp + + context = exc.value.__context__ + assert isinstance(context, ExceptionGroup) + assert str(context) == "booboo (1 sub-exception)" + assert len(context.exceptions) == 1 + assert isinstance(context.exceptions[0], ValueError) + assert exc.value.__cause__ is None + + +def test_bare_raise_in_handler(): + """Test that a bare "raise" "middle" ecxeption group gets discarded.""" + + def handler(exc): + raise + + with pytest.raises(ExceptionGroup) as excgrp: + with catch({(ValueError,): handler, (RuntimeError,): lambda eg: None}): + try: + first_exc = RuntimeError("first") + raise first_exc + except RuntimeError as exc: + middle_exc = ExceptionGroup( + "bad", [ValueError(), ValueError(), TypeError()] + ) + raise middle_exc from exc + + assert len(excgrp.value.exceptions) == 2 + assert all(isinstance(exc, ValueError) for exc in excgrp.value.exceptions) + assert excgrp.value is not middle_exc + assert excgrp.value.__cause__ is first_exc + assert excgrp.value.__context__ is first_exc def test_catch_subclass(): diff --git a/tests/test_catch_py311.py b/tests/test_catch_py311.py index 5880f0a..29f4dd5 100644 --- a/tests/test_catch_py311.py +++ b/tests/test_catch_py311.py @@ -128,12 +128,20 @@ def test_catch_full_match(): reason="Behavior was changed in 3.11.4", ) def test_catch_handler_raises(): - with pytest.raises(RuntimeError, match="new"): + with pytest.raises(RuntimeError, match="new") as exc: try: - raise ExceptionGroup("booboo", [ValueError("bar")]) + excgrp = ExceptionGroup("booboo", [ValueError("bar")]) + raise excgrp except* ValueError: raise RuntimeError("new") + context = exc.value.__context__ + assert isinstance(context, ExceptionGroup) + assert str(context) == "booboo (1 sub-exception)" + assert len(context.exceptions) == 1 + assert isinstance(context.exceptions[0], ValueError) + assert exc.value.__cause__ is None + def test_catch_subclass(): lookup_errors = [] @@ -146,3 +154,25 @@ def test_catch_subclass(): assert isinstance(lookup_errors[0], ExceptionGroup) exceptions = lookup_errors[0].exceptions assert isinstance(exceptions[0], KeyError) + + +def test_bare_raise_in_handler(): + """Test that the "middle" ecxeption group gets discarded.""" + with pytest.raises(ExceptionGroup) as excgrp: + try: + try: + first_exc = RuntimeError("first") + raise first_exc + except RuntimeError as exc: + middle_exc = ExceptionGroup( + "bad", [ValueError(), ValueError(), TypeError()] + ) + raise middle_exc from exc + except* ValueError: + raise + except* TypeError: + pass + + assert excgrp.value is not middle_exc + assert excgrp.value.__cause__ is first_exc + assert excgrp.value.__context__ is first_exc