Skip to content

Commit

Permalink
Avoid error if origin has a buggy __eq__ (python#422)
Browse files Browse the repository at this point in the history
Fixes python#419

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
  • Loading branch information
JelleZijlstra and AlexWaygood committed Jun 3, 2024
1 parent 7269638 commit 53bcdde
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Unreleased

- Fix regression in v4.12.0 where specialization of certain
generics with an overridden `__eq__` method would raise errors.
Patch by Jelle Zijlstra.

# Release 4.12.1 (June 1, 2024)

- Preliminary changes for compatibility with the draft implementation
Expand Down
16 changes: 16 additions & 0 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6617,6 +6617,22 @@ def test_allow_default_after_non_default_in_alias(self):
a4 = Callable[[Unpack[Ts]], T]
self.assertEqual(a4.__args__, (Unpack[Ts], T))

@skip_if_py313_beta_1
def test_generic_with_broken_eq(self):
# See https://github.com/python/typing_extensions/pull/422 for context
class BrokenEq(type):
def __eq__(self, other):
if other is typing_extensions.Protocol:
raise TypeError("I'm broken")
return False

class G(Generic[T], metaclass=BrokenEq):
pass

alias = G[int]
self.assertIs(get_origin(alias), G)
self.assertEqual(get_args(alias), (int,))

@skipIf(
sys.version_info < (3, 11, 1),
"Not yet backported for older versions of Python"
Expand Down
17 changes: 12 additions & 5 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2954,13 +2954,20 @@ def _check_generic(cls, parameters, elen):
def _has_generic_or_protocol_as_origin() -> bool:
try:
frame = sys._getframe(2)
# not all platforms have sys._getframe()
except AttributeError:
# - Catch AttributeError: not all Python implementations have sys._getframe()
# - Catch ValueError: maybe we're called from an unexpected module
# and the call stack isn't deep enough
except (AttributeError, ValueError):
return False # err on the side of leniency
else:
return frame.f_locals.get("origin") in (
typing.Generic, Protocol, typing.Protocol
)
# If we somehow get invoked from outside typing.py,
# also err on the side of leniency
if frame.f_globals.get("__name__") != "typing":
return False
origin = frame.f_locals.get("origin")
# Cannot use "in" because origin may be an object with a buggy __eq__ that
# throws an error.
return origin is typing.Generic or origin is Protocol or origin is typing.Protocol


_TYPEVARTUPLE_TYPES = {TypeVarTuple, getattr(typing, "TypeVarTuple", None)}
Expand Down

0 comments on commit 53bcdde

Please sign in to comment.