Skip to content
This repository has been archived by the owner on Mar 9, 2022. It is now read-only.

Implement chat-type command checks. #14

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions slash_util/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@


from .context import Context
from .core import Command

BotT = TypeVar("BotT", bound='Bot')

if TYPE_CHECKING:
from .bot import Bot
from .modal import Modal
from .core import Command
from typing import Awaitable, Any, Callable
from typing_extensions import Self, ParamSpec, Concatenate

Expand All @@ -34,14 +34,14 @@ async def command_error_wrapper(func: Callable[WrapperPS, Awaitable[Any]], *args
class Cog(commands.Cog, Generic[BotT]):
"""
The cog that must be used for application commands.

Attributes:
- bot: [``slash_util.Bot``](#class-botcommand_prefix-help_commanddefault-help-command-descriptionnone-options)
- - The bot instance."""
def __init__(self, bot: BotT):
self.bot: BotT = bot
self._commands: dict[str, Command]

async def slash_command_error(self, ctx: Context[BotT, Self], error: Exception) -> None:
print("Error occured in command", ctx.command.name, file=sys.stderr)
traceback.print_exception(type(error), error, error.__traceback__)
Expand All @@ -60,16 +60,16 @@ async def _internal_interaction_handler(self, interaction: discord.Interaction):

if interaction.type is not discord.InteractionType.application_command:
return

name = interaction.data['name'] # type: ignore
command = self._commands.get(name)

if not command:
return

state = self.bot._connection
params: dict = command._build_arguments(interaction, state)

ctx = Context(self.bot, command, interaction)
try:
await command_error_wrapper(command.invoke, ctx, **params)
Expand Down
86 changes: 71 additions & 15 deletions slash_util/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from collections import defaultdict
from typing import TYPE_CHECKING, TypeVar, overload, Union, Generic, get_origin, get_args, Literal

import discord, discord.state
import discord
import discord.state
from discord.ext import commands
import discord.ext.commands._types
from .cog import Cog

NumT = Union[int, float]

Expand All @@ -17,7 +21,6 @@
from typing_extensions import Concatenate, ParamSpec

from .bot import Bot
from .cog import Cog
from .context import Context

CmdP = ParamSpec("CmdP")
Expand All @@ -29,11 +32,13 @@
RngT = TypeVar("RngT", bound="Range")

__all__ = ['describe', 'slash_command', 'message_command', 'user_command', 'Range', 'Command', 'SlashCommand', 'ContextMenuCommand', 'UserCommand', 'MessageCommand']


def _parse_resolved_data(interaction: discord.Interaction, data, state: discord.state.ConnectionState):
if not data:
return {}

assert interaction.guild
assert interaction.guild
resolved = {}

resolved_users = data.get('users')
Expand All @@ -44,7 +49,7 @@ def _parse_resolved_data(interaction: discord.Interaction, data, state: discord.
member_data['user'] = d
member = discord.Member(data=member_data, guild=interaction.guild, state=state)
resolved[int(id)] = member

resolved_channels = data.get('channels')
if resolved_channels:
for id, d in resolved_channels.items():
Expand All @@ -64,16 +69,17 @@ def _parse_resolved_data(interaction: discord.Interaction, data, state: discord.
for id, d in resolved_roles.items():
role = discord.Role(guild=interaction.guild, state=state, data=d)
resolved[int(id)] = role

resolved_attachments = data.get('attachments')
if resolved_attachments:
for id, d in resolved_attachments.items():
attachment = discord.Attachment(state=state, data=d)
resolved[int(id)] = attachment


return resolved


command_type_map: dict[type[Any], int] = {
str: 3,
int: 4,
Expand All @@ -94,6 +100,7 @@ def _parse_resolved_data(interaction: discord.Interaction, data, state: discord.
discord.CategoryChannel: 4
}


def describe(**kwargs):
"""
Sets the description for the specified parameters of the slash command. Sample usage:
Expand All @@ -114,10 +121,11 @@ def _inner(cmd):
return cmd
return _inner


def slash_command(**kwargs) -> Callable[[CmdT], SlashCommand]:
"""
Defines a function as a slash-type application command.

Parameters:
- name: ``str``
- - The display name of the command. If unspecified, will use the functions name.
Expand All @@ -129,11 +137,12 @@ def slash_command(**kwargs) -> Callable[[CmdT], SlashCommand]:
def _inner(func: CmdT) -> SlashCommand:
return SlashCommand(func, **kwargs)
return _inner



def message_command(**kwargs) -> Callable[[MsgCmdT], MessageCommand]:
"""
Defines a function as a message-type application command.

Parameters:
- name: ``str``
- - The display name of the command. If unspecified, will use the functions name.
Expand All @@ -144,10 +153,11 @@ def _inner(func: MsgCmdT) -> MessageCommand:
return MessageCommand(func, **kwargs)
return _inner


def user_command(**kwargs) -> Callable[[UsrCmdT], UserCommand]:
"""
Defines a function as a user-type application command.

Parameters:
- name: ``str``
- - The display name of the command. If unspecified, will use the functions name.
Expand All @@ -158,7 +168,9 @@ def _inner(func: UsrCmdT) -> UserCommand:
return UserCommand(func, **kwargs)
return _inner


class _RangeMeta(type):

@overload
def __getitem__(cls: type[RngT], max: int) -> type[int]: ...
@overload
Expand All @@ -173,6 +185,7 @@ def __getitem__(cls, max):
return cls(*max)
return cls(None, max)


class Range(metaclass=_RangeMeta):
"""
Defines a minimum and maximum value for float or int values. The minimum value is optional.
Expand All @@ -186,22 +199,45 @@ def __init__(self, min: NumT | None, max: NumT):
self.min = min
self.max = max


class Command(Generic[CogT]):
cog: CogT
func: Callable
name: str
guild_id: int | None
checks: list[commands._types.Check]

def _build_command_payload(self) -> dict[str, Any]:
raise NotImplementedError

def _build_arguments(self, interaction: discord.Interaction, state: discord.state.ConnectionState) -> dict[str, Any]:
raise NotImplementedError

async def can_run(self, ctx: Context[BotT, CogT]) -> bool:

if not await ctx.bot.can_run(ctx):
raise commands.CheckFailure(f"The global check functions for application command '{self.name}' failed.")

local_check = Cog._get_overridden_method(self.cog.cog_check)
if local_check is not None:
ret = await discord.utils.maybe_coroutine(local_check, ctx)
if not ret:
return False

predicates = self.checks
if not predicates:
return True

return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore

async def invoke(self, context: Context[BotT, CogT], **params) -> None:
await self.func(self.cog, context, **params)

if not await self.can_run(context):
raise commands.CheckFailure(f"The check functions for application command '{self.name}' failed.")


class SlashCommand(Command[CogT]):

def __init__(self, func: CmdT, **kwargs):
self.func = func
self.cog: CogT
Expand All @@ -215,6 +251,12 @@ def __init__(self, func: CmdT, **kwargs):
self.parameters = self._build_parameters()
self._parameter_descriptions: dict[str, str] = defaultdict(lambda: "No description provided")

try:
checks = function.__commands_checks__ # type: ignore
except AttributeError:
checks = kwargs.get("checks", [])
self.checks: list[commands._types.Check] = checks # type: ignore

def _build_arguments(self, interaction, state):
if 'options' not in interaction.data:
return {}
Expand All @@ -235,7 +277,7 @@ def _build_parameters(self) -> dict[str, inspect.Parameter]:
params.pop(0)
except IndexError:
raise ValueError("expected argument `self` is missing")

try:
params.pop(0)
except IndexError:
Expand All @@ -246,7 +288,7 @@ def _build_parameters(self) -> dict[str, inspect.Parameter]:
def _build_descriptions(self):
if not hasattr(self.func, '_param_desc_'):
return

for k, v in self.func._param_desc_.items():
if k not in self.parameters:
raise TypeError(f"@describe used to describe a non-existant parameter `{k}`")
Expand Down Expand Up @@ -292,7 +334,7 @@ def _build_command_payload(self):
}
if param.default is param.empty:
option['required'] = True

if isinstance(ann, Range):
option['max_value'] = ann.max
option['min_value'] = ann.min
Expand All @@ -313,12 +355,17 @@ def _build_command_payload(self):

elif issubclass(ann, discord.abc.GuildChannel):
option['channel_types'] = [channel_filter[ann]]

options.append(option)
options.sort(key=lambda f: not f.get('required'))
payload['options'] = options
return payload

async def invoke(self, context: Context[BotT, CogT], **params) -> None:
await super().invoke(context, **params)
await self.func(self.cog, context, **params)


class ContextMenuCommand(Command[CogT]):
_type: ClassVar[int]

Expand All @@ -327,6 +374,12 @@ def __init__(self, func: CtxMnT, **kwargs):
self.guild_id: int | None = kwargs.get('guild_id', None)
self.name: str = kwargs.get('name', func.__name__)

try:
checks = function.__commands_checks__ # type: ignore
except AttributeError:
checks = kwargs.get("checks", [])
self.checks: list[commands._types.Check] = checks # type: ignore

def _build_command_payload(self):
payload = {
'name': self.name,
Expand All @@ -342,10 +395,13 @@ def _build_arguments(self, interaction: discord.Interaction, state: discord.stat
return {'target': value}

async def invoke(self, context: Context[BotT, CogT], **params) -> None:
await super().invoke(context, **params)
await self.func(self.cog, context, *params.values())


class MessageCommand(ContextMenuCommand[CogT]):
_type = 3


class UserCommand(ContextMenuCommand[CogT]):
_type = 2