diff --git a/discord/app_commands/installs.py b/discord/app_commands/installs.py index 3907b65813da..7d9b2f049245 100644 --- a/discord/app_commands/installs.py +++ b/discord/app_commands/installs.py @@ -23,7 +23,7 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar, Optional, Sequence +from typing import TYPE_CHECKING, ClassVar, List, Optional, Sequence __all__ = ( 'AppInstallationType', @@ -32,6 +32,7 @@ if TYPE_CHECKING: from typing_extensions import Self + from ..types.interactions import InteractionContextType, InteractionInstallationType class AppInstallationType: @@ -84,14 +85,14 @@ def merge(self, other: AppInstallationType) -> AppInstallationType: def _is_unset(self) -> bool: return all(x is None for x in (self._guild, self._user)) - def _merge_to_array(self, other: Optional[AppInstallationType]) -> Optional[list[int]]: + def _merge_to_array(self, other: Optional[AppInstallationType]) -> Optional[List[InteractionInstallationType]]: result = self.merge(other) if other is not None else self if result._is_unset(): return None return result.to_array() @classmethod - def _from_value(cls, value: Sequence[int]) -> Self: + def _from_value(cls, value: Sequence[InteractionInstallationType]) -> Self: self = cls() for x in value: if x == cls.GUILD: @@ -100,7 +101,7 @@ def _from_value(cls, value: Sequence[int]) -> Self: self._user = True return self - def to_array(self) -> list[int]: + def to_array(self) -> List[InteractionInstallationType]: values = [] if self._guild: values.append(self.GUILD) @@ -177,14 +178,14 @@ def merge(self, other: AppCommandContext) -> AppCommandContext: def _is_unset(self) -> bool: return all(x is None for x in (self._guild, self._dm_channel, self._private_channel)) - def _merge_to_array(self, other: Optional[AppCommandContext]) -> Optional[list[int]]: + def _merge_to_array(self, other: Optional[AppCommandContext]) -> Optional[List[InteractionContextType]]: result = self.merge(other) if other is not None else self if result._is_unset(): return None return result.to_array() @classmethod - def _from_value(cls, value: Sequence[int]) -> Self: + def _from_value(cls, value: Sequence[InteractionContextType]) -> Self: self = cls() for x in value: if x == cls.GUILD: @@ -195,7 +196,7 @@ def _from_value(cls, value: Sequence[int]) -> Self: self._private_channel = True return self - def to_array(self) -> list[int]: + def to_array(self) -> List[InteractionContextType]: values = [] if self._guild: values.append(self.GUILD) diff --git a/discord/interactions.py b/discord/interactions.py index f7df595b15f2..5638886b348f 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -45,6 +45,7 @@ from .permissions import Permissions from .http import handle_message_parameters from .webhook.async_ import async_context, Webhook, interaction_response_params, interaction_message_response_params +from .app_commands.installs import AppCommandContext from .app_commands.namespace import Namespace from .app_commands.translator import locale_str, TranslationContext, TranslationContextLocation from .channel import _threaded_channel_factory @@ -140,6 +141,10 @@ class Interaction(Generic[ClientT]): command_failed: :class:`bool` Whether the command associated with this interaction failed to execute. This includes checks and execution. + context: :class:`.AppCommandContext` + The context of the interaction. + + .. versionadded:: 2.4 """ __slots__: Tuple[str, ...] = ( @@ -158,6 +163,7 @@ class Interaction(Generic[ClientT]): 'command_failed', 'entitlement_sku_ids', 'entitlements', + "context", '_integration_owners', '_permissions', '_app_permissions', @@ -200,6 +206,10 @@ def _from_data(self, data: InteractionPayload): self._integration_owners: Dict[int, Snowflake] = { int(k): int(v) for k, v in data.get('authorizing_integration_owners', {}).items() } + try: + self.context = AppCommandContext._from_value([data['context']]) + except KeyError: + self.context = AppCommandContext() self.locale: Locale = try_enum(Locale, data.get('locale', 'en-US')) self.guild_locale: Optional[Locale] @@ -380,6 +390,22 @@ def is_expired(self) -> bool: """:class:`bool`: Returns ``True`` if the interaction is expired.""" return utils.utcnow() >= self.expires_at + def is_guild_integration(self) -> bool: + """:class:`bool`: Returns ``True`` if the interaction is a guild integration. + + .. versionadded:: 2.4 + """ + if self.guild_id: + return self.guild_id == self._integration_owners.get(0) + return False + + def is_user_integration(self) -> bool: + """:class:`bool`: Returns ``True`` if the interaction is a user integration. + + .. versionadded:: 2.4 + """ + return self.user.id == self._integration_owners.get(1) + async def original_response(self) -> InteractionMessage: """|coro| diff --git a/discord/types/interactions.py b/discord/types/interactions.py index ae63a126f5d1..d9446ee0eb6f 100644 --- a/discord/types/interactions.py +++ b/discord/types/interactions.py @@ -43,6 +43,7 @@ InteractionType = Literal[1, 2, 3, 4, 5] InteractionContextType = Literal[0, 1, 2] +InteractionInstallationType = Literal[0, 1] class _BasePartialChannel(TypedDict):