diff --git a/CHANGES.rst b/CHANGES.rst index c2af7ae..ee46dfd 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,11 @@ Version history This library adheres to `Semantic Versioning 2.0 `_. +**UNRELEASED** + +- Backported upstream fix for gh-99553 (custom subclasses of ``BaseExceptionGroup`` that + also inherit from ``Exception`` should not be able to wrap base exceptions) + **1.0.4** - Fixed regression introduced in v1.0.3 where the code computing the suggestions would diff --git a/src/exceptiongroup/_exceptions.py b/src/exceptiongroup/_exceptions.py index ebdd172..a2ca092 100644 --- a/src/exceptiongroup/_exceptions.py +++ b/src/exceptiongroup/_exceptions.py @@ -67,6 +67,18 @@ def __new__( if all(isinstance(exc, Exception) for exc in __exceptions): cls = ExceptionGroup + if issubclass(cls, Exception): + for exc in __exceptions: + if not isinstance(exc, Exception): + if cls is ExceptionGroup: + raise TypeError( + "Cannot nest BaseExceptions in an ExceptionGroup" + ) + else: + raise TypeError( + f"Cannot nest BaseExceptions in {cls.__name__!r}" + ) + return super().__new__(cls, __message, __exceptions) def __init__( @@ -219,15 +231,7 @@ def __repr__(self) -> str: class ExceptionGroup(BaseExceptionGroup[_ExceptionT_co], Exception): def __new__(cls, __message: str, __exceptions: Sequence[_ExceptionT_co]) -> Self: - instance: ExceptionGroup[_ExceptionT_co] = super().__new__( - cls, __message, __exceptions - ) - if cls is ExceptionGroup: - for exc in __exceptions: - if not isinstance(exc, Exception): - raise TypeError("Cannot nest BaseExceptions in an ExceptionGroup") - - return instance + return super().__new__(cls, __message, __exceptions) if TYPE_CHECKING: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index d0d33cd..9a0e39b 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -3,6 +3,8 @@ import sys import unittest +import pytest + from exceptiongroup import BaseExceptionGroup, ExceptionGroup @@ -90,19 +92,35 @@ def test_BEG_wraps_BaseException__creates_BEG(self): beg = BaseExceptionGroup("beg", [ValueError(1), KeyboardInterrupt(2)]) self.assertIs(type(beg), BaseExceptionGroup) - def test_EG_subclass_wraps_anything(self): + def test_EG_subclass_wraps_non_base_exceptions(self): class MyEG(ExceptionGroup): pass self.assertIs(type(MyEG("eg", [ValueError(12), TypeError(42)])), MyEG) - self.assertIs(type(MyEG("eg", [ValueError(12), KeyboardInterrupt(42)])), MyEG) - def test_BEG_subclass_wraps_anything(self): - class MyBEG(BaseExceptionGroup): + @pytest.mark.skipif( + sys.version_info[:3] == (3, 11, 0), + reason="Behavior was made stricter in 3.11.1", + ) + def test_EG_subclass_does_not_wrap_base_exceptions(self): + class MyEG(ExceptionGroup): + pass + + msg = "Cannot nest BaseExceptions in 'MyEG'" + with self.assertRaisesRegex(TypeError, msg): + MyEG("eg", [ValueError(12), KeyboardInterrupt(42)]) + + @pytest.mark.skipif( + sys.version_info[:3] == (3, 11, 0), + reason="Behavior was made stricter in 3.11.1", + ) + def test_BEG_and_E_subclass_does_not_wrap_base_exceptions(self): + class MyEG(BaseExceptionGroup, ValueError): pass - self.assertIs(type(MyBEG("eg", [ValueError(12), TypeError(42)])), MyBEG) - self.assertIs(type(MyBEG("eg", [ValueError(12), KeyboardInterrupt(42)])), MyBEG) + msg = "Cannot nest BaseExceptions in 'MyEG'" + with self.assertRaisesRegex(TypeError, msg): + MyEG("eg", [ValueError(12), KeyboardInterrupt(42)]) def create_simple_eg():