diff --git a/discord/app/commands.py b/discord/app/commands.py index 623995e4e8..183315c08b 100644 --- a/discord/app/commands.py +++ b/discord/app/commands.py @@ -30,7 +30,6 @@ from typing import Callable, Dict, List, Optional, Union from ..enums import SlashCommandOptionType -from ..interactions import Interaction from ..member import Member from ..user import User from ..message import Message @@ -138,12 +137,10 @@ def __eq__(self, other) -> bool: and other.description == self.description ) - async def invoke(self, interaction) -> None: + async def invoke(self, ctx: InteractionContext) -> None: # TODO: Parse the args better, apply custom converters etc. - ctx = InteractionContext(interaction) - kwargs = {} - for arg in interaction.data.get("options", []): + for arg in ctx.interaction.data.get("options", []): op = find(lambda x: x.name == arg["name"], self.options) arg = arg["value"] @@ -257,11 +254,11 @@ def command_group(self, name, description) -> SubCommandGroup: self.subcommands.append(sub_command_group) return sub_command_group - async def invoke(self, interaction: Interaction) -> None: - option = interaction.data["options"][0] + async def invoke(self, ctx: InteractionContext) -> None: + option = ctx.interaction.data["options"][0] command = find(lambda x: x.name == option["name"], self.subcommands) - interaction.data = option - await command.invoke(interaction) + ctx.interaction.data = option + await command.invoke(ctx) class UserCommand(ApplicationCommand): @@ -290,29 +287,28 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: def to_dict(self) -> Dict[str, Union[str, int]]: return {"name": self.name, "description": self.description, "type": self.type} - async def invoke(self, interaction: Interaction) -> None: - if "members" not in interaction.data["resolved"]: - _data = interaction.data["resolved"]["users"] + async def invoke(self, ctx: InteractionContext) -> None: + if "members" not in ctx.interaction.data["resolved"]: + _data = ctx.interaction.data["resolved"]["users"] for i, v in _data.items(): v["id"] = int(i) user = v - target = User(state=interaction._state, data=user) + target = User(state=ctx.interaction._state, data=user) else: - _data = interaction.data["resolved"]["members"] + _data = ctx.interaction.data["resolved"]["members"] for i, v in _data.items(): v["id"] = int(i) member = v - _data = interaction.data["resolved"]["users"] + _data = ctx.interaction.data["resolved"]["users"] for i, v in _data.items(): v["id"] = int(i) user = v member["user"] = user target = Member( data=member, - guild=interaction._state._get_guild(interaction.guild_id), - state=interaction._state, + guild=ctx.interaction._state._get_guild(ctx.interaction.guild_id), + state=ctx.interaction._state, ) - ctx = InteractionContext(interaction) await self.callback(ctx, target) @@ -340,18 +336,17 @@ def __init__(self, func, *args, **kwargs): def to_dict(self): return {"name": self.name, "description": self.description, "type": self.type} - async def invoke(self, interaction): - _data = interaction.data["resolved"]["messages"] + async def invoke(self, ctx: InteractionContext): + _data = ctx.interaction.data["resolved"]["messages"] for i, v in _data.items(): v["id"] = int(i) message = v - channel = interaction._state.get_channel(int(message["channel_id"])) + channel = ctx.interaction._state.get_channel(int(message["channel_id"])) if channel is None: - data = await interaction._state.http.start_private_message( + data = await ctx.interaction._state.http.start_private_message( int(message["author"]["id"]) ) - channel = interaction._state.add_dm_channel(data) + channel = ctx.interaction._state.add_dm_channel(data) - target = Message(state=interaction._state, channel=channel, data=message) - ctx = InteractionContext(interaction) + target = Message(state=ctx.interaction._state, channel=channel, data=message) await self.callback(ctx, target) diff --git a/discord/app/context.py b/discord/app/context.py index 96ee8706ad..ac65780c21 100644 --- a/discord/app/context.py +++ b/discord/app/context.py @@ -22,6 +22,11 @@ DEALINGS IN THE SOFTWARE. """ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import discord + from ..interactions import Interaction from ..utils import cached_property @@ -35,7 +40,8 @@ class InteractionContext: .. versionadded:: 2.0 """ - def __init__(self, interaction: Interaction): + def __init__(self, bot: "discord.Bot", interaction: Interaction): + self.bot = bot self.interaction = interaction @cached_property diff --git a/discord/bot.py b/discord/bot.py index 8b3d40aad2..1b5b0a0024 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -38,6 +38,7 @@ MessageCommand, UserCommand, ApplicationCommand, + InteractionContext, ) from .errors import Forbidden from .interactions import Interaction @@ -245,7 +246,8 @@ async def handle_interaction(self, interaction: Interaction) -> None: except KeyError: self.dispatch("unknown_command", interaction) else: - await command.invoke(interaction) + context = await self.get_application_context(interaction) + await command.invoke(context) def slash_command(self, **kwargs) -> SlashCommand: """A shortcut decorator that invokes :func:`.ApplicationCommandMixin.command` and adds it to @@ -336,6 +338,30 @@ def command_group(self, name: str, description: str, guild_ids=None) -> SubComma self.add_application_command(group) return group + async def get_application_context( + self, interaction: Interaction, cls=None + ) -> InteractionContext: + r"""|coro| + + Returns the invocation context from the interaction. + + This is a more low-level counter-part for :meth:`.handle_interaction` + to allow users more fine grained control over the processing. + + Parameters + ----------- + interaction: :class:`discord.Interaction` + The interaction to get the invocation context from. + + Returns + -------- + :class:`.InteractionContext` + The invocation context. + """ + if cls == None: + cls = InteractionContext + return cls(self, interaction) + class BotBase(ApplicationCommandMixin): # To Insert: CogMixin # TODO I think