diff --git a/discord/commands/core.py b/discord/commands/core.py index 46766024f8..a14c5cb334 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -50,7 +50,7 @@ ) from ..enums import ChannelType, SlashCommandOptionType -from ..errors import ClientException, ValidationError +from ..errors import ClientException, ValidationError, NotFound from ..member import Member from ..message import Attachment, Message from ..user import User @@ -744,9 +744,10 @@ async def _invoke(self, ctx: ApplicationContext) -> None: elif op.input_type == SlashCommandOptionType.mentionable: arg_id = int(arg) - arg = await get_or_fetch(ctx.guild, "member", arg_id) - if arg is None: - arg = ctx.guild.get_role(arg_id) or arg_id + try: + arg = await get_or_fetch(ctx.guild, "member", arg_id) + except NotFound: + arg = await get_or_fetch(ctx.guild, "role", arg_id) elif op.input_type == SlashCommandOptionType.string and (converter := op.converter) is not None: arg = await converter.convert(converter, ctx, arg) diff --git a/discord/guild.py b/discord/guild.py index 4263379436..d9ed58cb3c 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -2474,6 +2474,38 @@ async def fetch_roles(self) -> List[Role]: data = await self._state.http.get_roles(self.id) return [Role(guild=self, state=self._state, data=d) for d in data] + async def _fetch_role(self, role_id: int) -> Role: + """|coro| + + Retrieves a :class:`Role` that the guild has. + + .. note:: + + This method is an API call. For general usage, consider using :attr:`get_role` instead. + + .. versionadded:: 2.0 + + Parameters + ----------- + role_id: :class:`int` + The role ID to fetch from the guild. + + Raises + ------- + HTTPException + Retrieving the role failed. + + Returns + ------- + Optional[:class:`Role`] + The role in the guild with the specified ID. + Returns ``None`` if not found. + """ + roles = await self.fetch_roles() + for role in roles: + if role.id == role_id: + return role + @overload async def create_role( self, diff --git a/discord/utils.py b/discord/utils.py index 0878c78712..98b54e1be4 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -377,7 +377,7 @@ def time_snowflake(dt: datetime.datetime, high: bool = False) -> int: The snowflake representing the time given. """ discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH) - return (discord_millis << 22) + (2**22 - 1 if high else 0) + return (discord_millis << 22) + (2 ** 22 - 1 if high else 0) def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]: @@ -477,6 +477,8 @@ async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING): if getter is None: try: getter = await getattr(obj, f"fetch_{attr}")(id) + except AttributeError: + getter = await getattr(obj, f"_fetch_{attr}")(id) except HTTPException: if default is not MISSING: return default