From c23dc441a681ffcd841ed04a5819ed106f8df731 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Thu, 28 Mar 2024 14:09:09 -0400 Subject: [PATCH] Allow setting AppCommandContext and AppInstallationType on the tree --- discord/app_commands/__init__.py | 1 + discord/app_commands/commands.py | 66 ++++++++++++++++---------------- discord/app_commands/tree.py | 28 ++++++++++++-- discord/ext/commands/bot.py | 22 ++++++++++- docs/interactions/api.rst | 8 ++-- 5 files changed, 84 insertions(+), 41 deletions(-) diff --git a/discord/app_commands/__init__.py b/discord/app_commands/__init__.py index 971461713449..a338cab75dc5 100644 --- a/discord/app_commands/__init__.py +++ b/discord/app_commands/__init__.py @@ -16,5 +16,6 @@ from .namespace import * from .transformers import * from .translator import * +from .installs import * from . import checks as checks from .checks import Cooldown as Cooldown diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index 9f09db7bafde..b1f3c3ee7cee 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -49,7 +49,7 @@ from copy import copy as shallow_copy from ..enums import AppCommandOptionType, AppCommandType, ChannelType, Locale -from ..flags import AppCommandContext, AppInstallationType +from .installs import AppCommandContext, AppInstallationType from .models import Choice from .transformers import annotation_to_parameter, CommandParameter, NoneType from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered @@ -66,6 +66,8 @@ from ..abc import Snowflake from .namespace import Namespace from .models import ChoiceT + from .tree import CommandTree + from .._types import ClientT # Generally, these two libraries are supposed to be separate from each other. # However, for type hinting purposes it's unfortunately necessary for one to @@ -744,8 +746,8 @@ def _copy_with( return copy - async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]: - base = self.to_dict() + async def get_translated_payload(self, tree: CommandTree[ClientT], translator: Translator) -> Dict[str, Any]: + base = self.to_dict(tree) name_localizations: Dict[str, str] = {} description_localizations: Dict[str, str] = {} @@ -771,7 +773,7 @@ async def get_translated_payload(self, translator: Translator) -> Dict[str, Any] ] return base - def to_dict(self) -> Dict[str, Any]: + def to_dict(self, tree: CommandTree[ClientT]) -> Dict[str, Any]: # If we have a parent then our type is a subcommand # Otherwise, the type falls back to the specific command type (e.g. slash command or context menu) option_type = AppCommandType.chat_input.value if self.parent is None else AppCommandOptionType.subcommand.value @@ -786,8 +788,8 @@ def to_dict(self) -> Dict[str, Any]: base['nsfw'] = self.nsfw base['dm_permission'] = not self.guild_only base['default_member_permissions'] = None if self.default_permissions is None else self.default_permissions.value - base['contexts'] = self.allowed_contexts.to_array() if self.allowed_contexts is not None else None - base['integration_types'] = self.allowed_installs.to_array() if self.allowed_installs is not None else None + base['contexts'] = tree.allowed_contexts._merge_to_array(self.allowed_contexts) + base['integration_types'] = tree.allowed_installs._merge_to_array(self.allowed_installs) return base @@ -1277,8 +1279,8 @@ def qualified_name(self) -> str: """:class:`str`: Returns the fully qualified command name.""" return self.name - async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]: - base = self.to_dict() + async def get_translated_payload(self, tree: CommandTree[ClientT], translator: Translator) -> Dict[str, Any]: + base = self.to_dict(tree) context = TranslationContext(location=TranslationContextLocation.command_name, data=self) if self._locale_name: name_localizations: Dict[str, str] = {} @@ -1290,13 +1292,13 @@ async def get_translated_payload(self, translator: Translator) -> Dict[str, Any] base['name_localizations'] = name_localizations return base - def to_dict(self) -> Dict[str, Any]: + def to_dict(self, tree: CommandTree[ClientT]) -> Dict[str, Any]: return { 'name': self.name, 'type': self.type.value, 'dm_permission': not self.guild_only, - 'contexts': self.allowed_contexts.to_array() if self.allowed_contexts is not None else None, - 'integration_types': self.allowed_installs.to_array() if self.allowed_installs is not None else None, + 'contexts': tree.allowed_contexts._merge_to_array(self.allowed_contexts), + 'integration_types': tree.allowed_installs._merge_to_array(self.allowed_installs), 'default_member_permissions': None if self.default_permissions is None else self.default_permissions.value, 'nsfw': self.nsfw, } @@ -1711,8 +1713,8 @@ def _copy_with( return copy - async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]: - base = self.to_dict() + async def get_translated_payload(self, tree: CommandTree[ClientT], translator: Translator) -> Dict[str, Any]: + base = self.to_dict(tree) name_localizations: Dict[str, str] = {} description_localizations: Dict[str, str] = {} @@ -1732,10 +1734,10 @@ async def get_translated_payload(self, translator: Translator) -> Dict[str, Any] base['name_localizations'] = name_localizations base['description_localizations'] = description_localizations - base['options'] = [await child.get_translated_payload(translator) for child in self._children.values()] + base['options'] = [await child.get_translated_payload(tree, translator) for child in self._children.values()] return base - def to_dict(self) -> Dict[str, Any]: + def to_dict(self, tree: CommandTree[ClientT]) -> Dict[str, Any]: # If this has a parent command then it's part of a subcommand group # Otherwise, it's just a regular command option_type = 1 if self.parent is None else AppCommandOptionType.subcommand_group.value @@ -1743,15 +1745,15 @@ def to_dict(self) -> Dict[str, Any]: 'name': self.name, 'description': self.description, 'type': option_type, - 'options': [child.to_dict() for child in self._children.values()], + 'options': [child.to_dict(tree) for child in self._children.values()], } if self.parent is None: base['nsfw'] = self.nsfw base['dm_permission'] = not self.guild_only base['default_member_permissions'] = None if self.default_permissions is None else self.default_permissions.value - base['contexts'] = self.allowed_contexts.to_array() if self.allowed_contexts is not None else None - base['integration_types'] = self.allowed_installs.to_array() if self.allowed_installs is not None else None + base['contexts'] = tree.allowed_contexts._merge_to_array(self.allowed_contexts) + base['integration_types'] = tree.allowed_installs._merge_to_array(self.allowed_installs) return base @@ -2501,12 +2503,12 @@ async def my_guild_only_command(interaction: discord.Interaction) -> None: def inner(f: T) -> T: if isinstance(f, (Command, Group, ContextMenu)): f.guild_only = True - allowed_contexts = f.allowed_contexts or AppCommandContext.none() + allowed_contexts = f.allowed_contexts or AppCommandContext() f.allowed_contexts = allowed_contexts else: f.__discord_app_commands_guild_only__ = True # type: ignore # Runtime attribute assignment - allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext.none() + allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext() f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment allowed_contexts.guild = True @@ -2545,10 +2547,10 @@ async def my_private_channel_only_command(interaction: discord.Interaction) -> N def inner(f: T) -> T: if isinstance(f, (Command, Group, ContextMenu)): f.guild_only = False - allowed_contexts = f.allowed_contexts or AppCommandContext.none() + allowed_contexts = f.allowed_contexts or AppCommandContext() f.allowed_contexts = allowed_contexts else: - allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext.none() + allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext() f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment allowed_contexts.private_channel = True @@ -2587,10 +2589,10 @@ async def my_dm_only_command(interaction: discord.Interaction) -> None: def inner(f: T) -> T: if isinstance(f, (Command, Group, ContextMenu)): f.guild_only = False - allowed_contexts = f.allowed_contexts or AppCommandContext.none() + allowed_contexts = f.allowed_contexts or AppCommandContext() f.allowed_contexts = allowed_contexts else: - allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext.none() + allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext() f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment allowed_contexts.dm_channel = True @@ -2628,10 +2630,10 @@ async def my_command(interaction: discord.Interaction) -> None: def inner(f: T) -> T: if isinstance(f, (Command, Group, ContextMenu)): f.guild_only = False - allowed_contexts = f.allowed_contexts or AppCommandContext.none() + allowed_contexts = f.allowed_contexts or AppCommandContext() f.allowed_contexts = allowed_contexts else: - allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext.none() + allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext() f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment if guilds is not MISSING: @@ -2668,10 +2670,10 @@ async def my_guild_install_command(interaction: discord.Interaction) -> None: def inner(f: T) -> T: if isinstance(f, (Command, Group, ContextMenu)): - allowed_installs = f.allowed_installs or AppInstallationType.none() + allowed_installs = f.allowed_installs or AppInstallationType() f.allowed_installs = allowed_installs else: - allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType.none() + allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType() f.__discord_app_commands_installation_types__ = allowed_installs # type: ignore # Runtime attribute assignment allowed_installs.guild = True @@ -2706,10 +2708,10 @@ async def my_user_install_command(interaction: discord.Interaction) -> None: def inner(f: T) -> T: if isinstance(f, (Command, Group, ContextMenu)): - allowed_installs = f.allowed_installs or AppInstallationType.none() + allowed_installs = f.allowed_installs or AppInstallationType() f.allowed_installs = allowed_installs else: - allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType.none() + allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType() f.__discord_app_commands_installation_types__ = allowed_installs # type: ignore # Runtime attribute assignment allowed_installs.user = True @@ -2748,10 +2750,10 @@ async def my_command(interaction: discord.Interaction) -> None: def inner(f: T) -> T: if isinstance(f, (Command, Group, ContextMenu)): - allowed_installs = f.allowed_installs or AppInstallationType.none() + allowed_installs = f.allowed_installs or AppInstallationType() f.allowed_installs = allowed_installs else: - allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType.none() + allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType() f.__discord_app_commands_installation_types__ = allowed_installs # type: ignore # Runtime attribute assignment if guilds is not MISSING: diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index c75682e0ea2c..abd8924806fd 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -58,6 +58,7 @@ CommandSyncFailure, MissingApplicationID, ) +from .installs import AppCommandContext, AppInstallationType from .translator import Translator, locale_str from ..errors import ClientException, HTTPException from ..enums import AppCommandType, InteractionType @@ -121,9 +122,26 @@ class CommandTree(Generic[ClientT]): to find the guild-specific ``/ping`` command it will fall back to the global ``/ping`` command. This has the potential to raise more :exc:`~discord.app_commands.CommandSignatureMismatch` errors than usual. Defaults to ``True``. + allowed_contexts: :class:`~discord.app_commands.AppCommandContext` + The default allowed contexts that applies to all commands in this tree. + Note that you can override this on a per command basis. + + .. versionadded:: 2.4 + allowed_installs: :class:`~discord.app_commands.AppInstallationType` + The default allowed install locations that apply to all commands in this tree. + Note that you can override this on a per command basis. + + .. versionadded:: 2.4 """ - def __init__(self, client: ClientT, *, fallback_to_global: bool = True): + def __init__( + self, + client: ClientT, + *, + fallback_to_global: bool = True, + allowed_contexts: AppCommandContext = MISSING, + allowed_installs: AppInstallationType = MISSING, + ): self.client: ClientT = client self._http = client.http self._state = client._connection @@ -133,6 +151,8 @@ def __init__(self, client: ClientT, *, fallback_to_global: bool = True): self._state._command_tree = self self.fallback_to_global: bool = fallback_to_global + self.allowed_contexts = AppCommandContext() if allowed_contexts is MISSING else allowed_contexts + self.allowed_installs = AppInstallationType() if allowed_installs is MISSING else allowed_installs self._guild_commands: Dict[int, Dict[str, Union[Command, Group]]] = {} self._global_commands: Dict[str, Union[Command, Group]] = {} # (name, guild_id, command_type): Command @@ -722,7 +742,7 @@ def walk_commands( else: guild_id = None if guild is None else guild.id value = type.value - for ((_, g, t), command) in self._context_menus.items(): + for (_, g, t), command in self._context_menus.items(): if g == guild_id and t == value: yield command @@ -1058,9 +1078,9 @@ async def sync(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]: translator = self.translator if translator: - payload = [await command.get_translated_payload(translator) for command in commands] + payload = [await command.get_translated_payload(self, translator) for command in commands] else: - payload = [command.to_dict() for command in commands] + payload = [command.to_dict(self) for command in commands] try: if guild is None: diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index b691c5af29a9..208948335568 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -166,6 +166,8 @@ def __init__( help_command: Optional[HelpCommand] = _default, tree_cls: Type[app_commands.CommandTree[Any]] = app_commands.CommandTree, description: Optional[str] = None, + allowed_contexts: app_commands.AppCommandContext = MISSING, + allowed_installs: app_commands.AppInstallationType = MISSING, intents: discord.Intents, **options: Any, ) -> None: @@ -174,6 +176,11 @@ def __init__( self.extra_events: Dict[str, List[CoroFunc]] = {} # Self doesn't have the ClientT bound, but since this is a mixin it technically does self.__tree: app_commands.CommandTree[Self] = tree_cls(self) # type: ignore + if allowed_contexts is not MISSING: + self.__tree.allowed_contexts = allowed_contexts + if allowed_installs is not MISSING: + self.__tree.allowed_installs = allowed_installs + self.__cogs: Dict[str, Cog] = {} self.__extensions: Dict[str, types.ModuleType] = {} self._checks: List[UserCheck] = [] @@ -521,7 +528,6 @@ async def is_owner(self, user: User, /) -> bool: elif self.owner_ids: return user.id in self.owner_ids else: - app: discord.AppInfo = await self.application_info() # type: ignore if app.team: self.owner_ids = ids = { @@ -1489,6 +1495,20 @@ class Bot(BotBase, discord.Client): The type of application command tree to use. Defaults to :class:`~discord.app_commands.CommandTree`. .. versionadded:: 2.0 + allowed_contexts: :class:`~discord.app_commands.AppCommandContext` + The default allowed contexts that applies to all application commands + in the application command tree. + + Note that you can override this on a per command basis. + + .. versionadded:: 2.4 + allowed_installs: :class:`~discord.app_commands.AppInstallationType` + The default allowed install locations that apply to all application commands + in the application command tree. + + Note that you can override this on a per command basis. + + .. versionadded:: 2.4 """ pass diff --git a/docs/interactions/api.rst b/docs/interactions/api.rst index 02088d829ebf..6aa234257797 100644 --- a/docs/interactions/api.rst +++ b/docs/interactions/api.rst @@ -132,17 +132,17 @@ AppCommandPermissions AppCommandContext ~~~~~~~~~~~~~~~~~ -.. attributetable:: AppCommandContext +.. attributetable:: discord.app_commands.AppCommandContext -.. autoclass:: AppCommandContext +.. autoclass:: discord.app_commands.AppCommandContext :members: AppInstallationType ~~~~~~~~~~~~~~~~~~~~ -.. attributetable:: AppInstallationType +.. attributetable:: discord.app_commands.AppInstallationType -.. autoclass:: AppInstallationType +.. autoclass:: discord.app_commands.AppInstallationType :members: GuildAppCommandPermissions