Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 39 additions & 29 deletions discord/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import inspect
import sys
import traceback
from abc import ABC, abstractmethod
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -75,7 +76,7 @@
)


class ApplicationCommandMixin:
class ApplicationCommandMixin(ABC):
"""A mixin that implements common functionality for classes that need
application command compatibility.

Expand Down Expand Up @@ -104,8 +105,8 @@ def pending_application_commands(self):
@property
def commands(self) -> List[Union[ApplicationCommand, Any]]:
commands = self.application_commands
if self._supports_prefixed_commands:
commands += self.prefixed_commands
if self._bot._supports_prefixed_commands and hasattr(self._bot, "prefixed_commands"):
commands += self._bot.prefixed_commands
return commands

@property
Expand All @@ -128,8 +129,8 @@ def add_application_command(self, command: ApplicationCommand) -> None:
if isinstance(command, SlashCommand) and command.is_subcommand:
raise TypeError("The provided command is a sub-command of group")

if self.debug_guilds and command.guild_ids is None:
command.guild_ids = self.debug_guilds
if self._bot.debug_guilds and command.guild_ids is None:
command.guild_ids = self._bot.debug_guilds

for cmd in self.pending_application_commands:
if cmd == command:
Expand Down Expand Up @@ -239,10 +240,10 @@ async def get_desynced_commands(self, guild_id: Optional[int] = None) -> List[Di
cmds = self.pending_application_commands.copy()

if guild_id is None:
registered_commands = await self.http.get_global_commands(self.user.id)
registered_commands = await self._bot.http.get_global_commands(self._bot.user.id)
pending = [cmd for cmd in cmds if cmd.guild_ids is None]
else:
registered_commands = await self.http.get_guild_commands(self.user.id, guild_id)
registered_commands = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)
pending = [cmd for cmd in cmds if cmd.guild_ids is not None and guild_id in cmd.guild_ids]

registered_commands_dict = {cmd["name"]: cmd for cmd in registered_commands}
Expand Down Expand Up @@ -385,14 +386,14 @@ async def register_commands(
if is_global:
pending = list(filter(lambda c: c.guild_ids is None, commands))
registration_methods = {
"bulk": self.http.bulk_upsert_global_commands,
"upsert": self.http.upsert_global_command,
"delete": self.http.delete_global_command,
"edit": self.http.edit_global_command,
"bulk": self._bot.http.bulk_upsert_global_commands,
"upsert": self._bot.http.upsert_global_command,
"delete": self._bot.http.delete_global_command,
"edit": self._bot.http.edit_global_command,
}

def register(method: str, *args, **kwargs):
return registration_methods[method](self.user.id, *args, **kwargs)
return registration_methods[method](self._bot.user.id, *args, **kwargs)

else:
pending = list(
Expand All @@ -402,14 +403,14 @@ def register(method: str, *args, **kwargs):
)
)
registration_methods = {
"bulk": self.http.bulk_upsert_guild_commands,
"upsert": self.http.upsert_guild_command,
"delete": self.http.delete_guild_command,
"edit": self.http.edit_guild_command,
"bulk": self._bot.http.bulk_upsert_guild_commands,
"upsert": self._bot.http.upsert_guild_command,
"delete": self._bot.http.delete_guild_command,
"edit": self._bot.http.edit_guild_command,
}

def register(method: str, *args, **kwargs):
return registration_methods[method](self.user.id, guild_id, *args, **kwargs)
return registration_methods[method](self._bot.user.id, guild_id, *args, **kwargs)

pending_actions = []

Expand Down Expand Up @@ -472,9 +473,9 @@ def register(method: str, *args, **kwargs):

# TODO: Our lists dont work sometimes, see if that can be fixed so we can avoid this second API call
if guild_id is None:
registered = await self.http.get_global_commands(self.user.id)
registered = await self._bot.http.get_global_commands(self._bot.user.id)
else:
registered = await self.http.get_guild_commands(self.user.id, guild_id)
registered = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)

for i in registered:
cmd = get(
Expand Down Expand Up @@ -626,7 +627,7 @@ async def sync_commands(
# Replace Role Names
if permission["type"] == 1:
role = get(
self.get_guild(guild_id).roles,
self._bot.get_guild(guild_id).roles,
name=permission["id"],
)

Expand Down Expand Up @@ -682,7 +683,7 @@ async def sync_commands(

# Upsert
try:
await self.http.bulk_upsert_command_permissions(self.user.id, guild_id, guild_cmd_perms)
await self._bot.http.bulk_upsert_command_permissions(self._bot.user.id, guild_id, guild_cmd_perms)
except Forbidden:
raise RuntimeError(
f"Failed to add command permissions to guild {guild_id}",
Expand Down Expand Up @@ -716,7 +717,7 @@ async def process_application_commands(self, interaction: Interaction, auto_sync
based on the type of the command, respectively. Defaults to :attr:`.Bot.auto_sync_commands`.
"""
if auto_sync is None:
auto_sync = self.auto_sync_commands
auto_sync = self._bot.auto_sync_commands
if interaction.type not in (
InteractionType.application_command,
InteractionType.auto_complete,
Expand All @@ -740,7 +741,7 @@ async def process_application_commands(self, interaction: Interaction, auto_sync
await self.sync_commands()
else:
await self.sync_commands(unregister_guilds=[guild_id])
return self.dispatch("unknown_application_command", interaction)
return self._bot.dispatch("unknown_application_command", interaction)

if interaction.type is InteractionType.auto_complete:
ctx = await self.get_autocomplete_context(interaction)
Expand Down Expand Up @@ -988,19 +989,24 @@ async def invoke_application_command(self, ctx: ApplicationContext) -> None:
ctx: :class:`.ApplicationCommand`
The invocation context to invoke.
"""
self.dispatch("application_command", ctx)
self._bot.dispatch("application_command", ctx)
try:
if await self.can_run(ctx, call_once=True):
if await self._bot.can_run(ctx, call_once=True):
await ctx.command.invoke(ctx)
else:
raise CheckFailure("The global check once functions failed.")
except DiscordException as exc:
await ctx.command.dispatch_error(ctx, exc)
else:
self.dispatch("application_command_completion", ctx)
self._bot.dispatch("application_command_completion", ctx)

@property
@abstractmethod
def _bot(self) -> Union["Bot", "AutoShardedBot"]:
...


class BotBase(ApplicationCommandMixin, CogMixin):
class BotBase(ApplicationCommandMixin, CogMixin, ABC):
_supports_prefixed_commands = False

# TODO I think
Expand Down Expand Up @@ -1404,7 +1410,9 @@ class Bot(BotBase, Client):
.. versionadded:: 2.0
"""

pass
@property
def _bot(self) -> "Bot":
return self


class AutoShardedBot(BotBase, AutoShardedClient):
Expand All @@ -1414,4 +1422,6 @@ class AutoShardedBot(BotBase, AutoShardedClient):
.. versionadded:: 2.0
"""

pass
@property
def _bot(self) -> "AutoShardedBot":
return self
41 changes: 16 additions & 25 deletions discord/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,7 @@ class ApplicationCommand(_BaseCommand, Generic[CogT, P, T]):
def __init__(self, func: Callable, **kwargs) -> None:
from ..ext.commands.cooldowns import BucketType, CooldownMapping, MaxConcurrency

try:
cooldown = func.__commands_cooldown__
except AttributeError:
cooldown = kwargs.get("cooldown")
cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown"))

if cooldown is None:
buckets = CooldownMapping(cooldown, BucketType.default)
Expand All @@ -160,16 +157,26 @@ def __init__(self, func: Callable, **kwargs) -> None:
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
self._buckets: CooldownMapping = buckets

try:
max_concurrency = func.__commands_max_concurrency__
except AttributeError:
max_concurrency = kwargs.get("max_concurrency")
max_concurrency = getattr(func, "__commands_max_concurrency__", kwargs.get("max_concurrency"))

self._max_concurrency: Optional[MaxConcurrency] = max_concurrency

self._callback = None
self.module = None

self.name: str = kwargs.get("name", func.__name__)

try:
checks = func.__commands_checks__
checks.reverse()
except AttributeError:
checks = kwargs.get("checks", [])

self.checks = checks
self.id: Optional[int] = kwargs.get("id")
self.guild_ids: Optional[List[int]] = kwargs.get("guild_ids", None)
self.parent = kwargs.get("parent")

def __repr__(self) -> str:
return f"<discord.commands.{self.__class__.__name__} name={self.name}>"

Expand Down Expand Up @@ -578,21 +585,15 @@ def __init__(self, func: Callable, *args, **kwargs) -> None:
raise TypeError("Callback must be a coroutine.")
self.callback = func

self.guild_ids: Optional[List[int]] = kwargs.get("guild_ids", None)

name = kwargs.get("name") or func.__name__
validate_chat_input_name(name)
self.name: str = name
validate_chat_input_name(self.name)
self.name_localizations: Optional[Dict[str, str]] = kwargs.get("name_localizations", None)
self.id = None

description = kwargs.get("description") or (
inspect.cleandoc(func.__doc__).splitlines()[0] if func.__doc__ is not None else "No description provided"
)
validate_chat_input_description(description)
self.description: str = description
self.description_localizations: Optional[Dict[str, str]] = kwargs.get("description_localizations", None)
self.parent = kwargs.get("parent")
self.attached_to_group: bool = False

self.cog = None
Expand Down Expand Up @@ -1152,25 +1153,15 @@ def __init__(self, func: Callable, *args, **kwargs) -> None:
raise TypeError("Callback must be a coroutine.")
self.callback = func

self.guild_ids: Optional[List[int]] = kwargs.get("guild_ids", None)

# Discord API doesn't support setting descriptions for context menu commands
# so it must be empty
self.description = ""
self.name: str = kwargs.pop("name", func.__name__)
if not isinstance(self.name, str):
raise TypeError("Name of a command must be a string.")

self.cog = None
self.id = None

try:
checks = func.__commands_checks__
checks.reverse()
except AttributeError:
checks = kwargs.get("checks", [])

self.checks = checks
self._before_invoke = None
self._after_invoke = None

Expand Down