Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG-V3.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ release.

### Changed

- Removed the custom `enums.Enum` implementation in favor of a stdlib `enum.Enum` subclass.

### Deprecated

### Removed
Expand Down
155 changes: 31 additions & 124 deletions discord/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from __future__ import annotations

import types
from collections import namedtuple
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union
from enum import Enum as EnumBase
from typing import Any, Self, TypeVar, Union

E = TypeVar("E", bound="Enum")

__all__ = (
"Enum",
Expand Down Expand Up @@ -83,118 +85,36 @@
)


def _create_value_cls(name, comparable):
cls = namedtuple(f"_EnumValue_{name}", "name value")
cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>"
cls.__str__ = lambda self: f"{name}.{self.name}"
if comparable:
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
return cls


def _is_descriptor(obj):
return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")


class EnumMeta(type):
if TYPE_CHECKING:
__name__: ClassVar[str]
_enum_member_names_: ClassVar[list[str]]
_enum_member_map_: ClassVar[dict[str, Any]]
_enum_value_map_: ClassVar[dict[Any, Any]]

def __new__(cls, name, bases, attrs, *, comparable: bool = False):
value_mapping = {}
member_mapping = {}
member_names = []

value_cls = _create_value_cls(name, comparable)
for key, value in list(attrs.items()):
is_descriptor = _is_descriptor(value)
if key[0] == "_" and not is_descriptor:
continue

# Special case classmethod to just pass through
if isinstance(value, classmethod):
continue

if is_descriptor:
setattr(value_cls, key, value)
del attrs[key]
continue

try:
new_value = value_mapping[value]
except KeyError:
new_value = value_cls(name=key, value=value)
value_mapping[value] = new_value
member_names.append(key)

member_mapping[key] = new_value
attrs[key] = new_value

attrs["_enum_value_map_"] = value_mapping
attrs["_enum_member_map_"] = member_mapping
attrs["_enum_member_names_"] = member_names
attrs["_enum_value_cls_"] = value_cls
actual_cls = super().__new__(cls, name, bases, attrs)
value_cls._actual_enum_cls_ = actual_cls # type: ignore
return actual_cls

def __iter__(cls):
return (cls._enum_member_map_[name] for name in cls._enum_member_names_)

def __reversed__(cls):
return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_))

def __len__(cls):
return len(cls._enum_member_names_)

def __repr__(cls):
return f"<enum {cls.__name__}>"

@property
def __members__(cls):
return types.MappingProxyType(cls._enum_member_map_)

def __call__(cls, value):
try:
return cls._enum_value_map_[value]
except (KeyError, TypeError) as e:
raise ValueError(f"{value!r} is not a valid {cls.__name__}") from e

def __getitem__(cls, key):
return cls._enum_member_map_[key]

def __setattr__(cls, name, value):
raise TypeError("Enums are immutable.")
class Enum(EnumBase):
"""An :class:`enum.Enum` subclass that implements a missing value creation behavior if it is
not present in any of the members of it.
"""

def __delattr__(cls, attr):
raise TypeError("Enums are immutable")
def __init_subclass__(cls, *, comparable: bool = False) -> None:
super().__init_subclass__()

def __instancecheck__(self, instance):
# isinstance(x, Y)
# -> __instancecheck__(Y, x)
try:
return instance._actual_enum_cls_ is self
except AttributeError:
return False
if comparable is True:
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value

@classmethod
def _missing_(cls, value: Any) -> Self:
name = f"unknown_{value}"
if name in cls.__members__:
return cls.__members__[name]

if TYPE_CHECKING:
from enum import Enum
else:
# this creates the new unknown value member
obj = object.__new__(cls)
obj._name_ = name
obj._value_ = value

class Enum(metaclass=EnumMeta):
@classmethod
def try_value(cls, value):
try:
return cls._enum_value_map_[value]
except (KeyError, TypeError):
return value
# and adds it to the member mapping of this enum so we don't
# create a different enum member value each time
cls._member_map_[name] = obj
cls._value2member_map_[value] = obj
return obj


class ChannelType(Enum):
Expand Down Expand Up @@ -1078,22 +998,9 @@ def __int__(self):
return self.value


T = TypeVar("T")


def create_unknown_value(cls: type[T], val: Any) -> T:
value_cls = cls._enum_value_cls_ # type: ignore
name = f"unknown_{val}"
return value_cls(name=name, value=val)


def try_enum(cls: type[T], val: Any) -> T:
def try_enum(cls: type[E], val: Any) -> E:
"""A function that tries to turn the value into enum ``cls``.
If it fails it returns a proxy invalid value instead.
"""

try:
return cls._enum_value_map_[val] # type: ignore
except (KeyError, TypeError, AttributeError):
return create_unknown_value(cls, val)
return cls(val)