From 8c8cb152e47a68706c28a017e4b82118326938d5 Mon Sep 17 00:00:00 2001 From: IAmTomahawkx Date: Sun, 27 Aug 2023 22:13:32 -0700 Subject: [PATCH 1/8] Add ability to parse Unions into ext.commands --- twitchio/ext/commands/core.py | 114 +++++++++++++++++++++++++++++--- twitchio/ext/commands/errors.py | 2 + 2 files changed, 105 insertions(+), 11 deletions(-) diff --git a/twitchio/ext/commands/core.py b/twitchio/ext/commands/core.py index 57556545..253d78b2 100644 --- a/twitchio/ext/commands/core.py +++ b/twitchio/ext/commands/core.py @@ -27,6 +27,7 @@ import itertools import copy +import types from typing import Any, Union, Optional, Callable, Awaitable, Tuple, TYPE_CHECKING, List, Type, Set, TypeVar from typing_extensions import Literal @@ -36,13 +37,33 @@ from . import builtin_converter if TYPE_CHECKING: + import sys + from twitchio import Message, Chatter, PartialChatter, Channel, User, PartialUser from . import Cog, Bot from .stringparser import StringParser + + if sys.version_info >= (3, 8): + UnionT = Union[types.UnionType, Union] + else: + UnionT = Union + + __all__ = ("Command", "command", "Group", "Context", "cooldown") -def _boolconverter(param: str): +class EmptyArgumentSentinel: + def __repr__(self) -> str: + return "" + + def __eq__(self, __value: object) -> bool: + return False + + +EMPTY = EmptyArgumentSentinel() + + +def _boolconverter(ctx: Context, param: str): param = param.lower() if param in {"yes", "y", "1", "true", "on"}: return True @@ -113,32 +134,94 @@ def full_name(self) -> str: if not self.parent: return self._name return f"{self.parent.full_name} {self._name}" + + def _is_optional_argument(self, converter: Any): + return (getattr(converter, '__origin__', None) is Union or isinstance(converter, types.UnionType)) and type(None) in converter.__args__ - def _resolve_converter(self, converter: Union[Callable, Awaitable, type]) -> Union[Callable[..., Any]]: + def resolve_union_callback(self, converter: UnionT) -> Callable[[Context, str], Any]: + #print(type(converter), converter.__args__) + + args = converter.__args__ # type: ignore # pyright doesnt like this + + async def _resolve(context: Context, arg: str) -> Any: + t = EMPTY + last = None + + for original in args: + underlying = self._resolve_converter(original) + + try: + t: Any = underlying(context, arg) + if inspect.iscoroutine(t): + t = await t + + break + except Exception as l: + last = l + t = EMPTY # thisll get changed when t is a coroutine, but is still invalid, so roll it back + continue + + if t is EMPTY: + fmt = f"Failed to convert argument '{arg}' to any of {', '.join(str(x) for x in args)}" + raise UnionArgumentParsingFailed(fmt, last) # type: ignore # if t is EMPTY, there has to be a last error + + return t + + return _resolve + + def resolve_optional_callback(self, converter: Any) -> Callable[[Context, str], Any]: + underlying = self._resolve_converter(converter.__args__[0]) + + async def _resolve(context: Context, arg: str) -> Any: + try: + t: Any = underlying(context, arg) + if inspect.iscoroutine(t): + t = await t + + except Exception: + return EMPTY # instruct the parser to roll back and ignore this argument + + return t + + return _resolve + + def _resolve_converter(self, converter: Union[Callable, Awaitable, type]) -> Callable[..., Any]: + #print(dir(converter)) if ( isinstance(converter, type) and converter.__module__.startswith("twitchio") and converter in builtin_converter._mapping ): return builtin_converter._mapping[converter] - return converter + + elif self._is_optional_argument(converter): + return self.resolve_optional_callback(converter) + + elif isinstance(converter, types.UnionType) or getattr(converter, "__origin__", None) is Union: + return self.resolve_union_callback(converter) # type: ignore + + if converter is bool: + converter = _boolconverter + + elif converter in (str, int): + _original = converter + converter = lambda _, param: _original(param) # type: ignore # the types dont take a ctx argument, so strip that out here + + return converter # type: ignore async def _convert_types(self, context: Context, param: inspect.Parameter, parsed: str) -> Any: converter = param.annotation + if converter is param.empty: if param.default in (param.empty, None): converter = str else: converter = type(param.default) + true_converter = self._resolve_converter(converter) try: - if true_converter in (int, str): - argument = true_converter(parsed) - elif true_converter is bool: - argument = _boolconverter(parsed) - else: - argument = true_converter(context, parsed) + argument = true_converter(context, parsed) if inspect.iscoroutine(argument): argument = await argument except BadArgument: @@ -174,8 +257,17 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di raise MissingRequiredArgument(param) args.append(param.default) else: - argument = await self._convert_types(context, param, argument) - args.append(argument) + _parsed_arg = await self._convert_types(context, param, argument) + + if _parsed_arg is EMPTY: + parsed[index] = argument + index -= 1 + args.append(None) + + continue + else: + args.append(_parsed_arg) + elif param.kind == param.KEYWORD_ONLY: rest = " ".join(parsed.values()) if rest.startswith(" "): diff --git a/twitchio/ext/commands/errors.py b/twitchio/ext/commands/errors.py index 04ba0ab9..e6148851 100644 --- a/twitchio/ext/commands/errors.py +++ b/twitchio/ext/commands/errors.py @@ -52,6 +52,8 @@ def __init__(self, message: str, original: Exception): self.original = original super().__init__(message) +class UnionArgumentParsingFailed(ArgumentParsingFailed): + pass class CommandNotFound(TwitchCommandError): pass From fccb1667b3019db78ccf0b91c57c2e7ae8dd2d85 Mon Sep 17 00:00:00 2001 From: IAmTomahawkx Date: Sun, 27 Aug 2023 22:20:19 -0700 Subject: [PATCH 2/8] pipe was 3.10, not 3.8 --- twitchio/ext/commands/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/twitchio/ext/commands/core.py b/twitchio/ext/commands/core.py index 393b6746..ecd6b692 100644 --- a/twitchio/ext/commands/core.py +++ b/twitchio/ext/commands/core.py @@ -43,7 +43,7 @@ from . import Cog, Bot from .stringparser import StringParser - if sys.version_info >= (3, 8): + if sys.version_info >= (3, 10): UnionT = Union[types.UnionType, Union] else: UnionT = Union From 80ca63470fed678accdd1325701e5425e3462eed Mon Sep 17 00:00:00 2001 From: IAmTomahawkx Date: Sun, 3 Sep 2023 19:00:19 -0700 Subject: [PATCH 3/8] Fix optional parsing when optional is the last argument --- twitchio/ext/commands/core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/twitchio/ext/commands/core.py b/twitchio/ext/commands/core.py index ecd6b692..8e430f2d 100644 --- a/twitchio/ext/commands/core.py +++ b/twitchio/ext/commands/core.py @@ -253,6 +253,10 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di try: argument = parsed.pop(index) except (KeyError, IndexError): + if self._is_optional_argument(param.annotation): # parameter is optional and at the end. + args.append(param.default if param.default is not param.empty else None) + continue + if param.default is param.empty: raise MissingRequiredArgument(param) args.append(param.default) @@ -262,7 +266,7 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di if _parsed_arg is EMPTY: parsed[index] = argument index -= 1 - args.append(None) + args.append(param.default if param.default is not param.empty else None) continue else: From 8eb5a40bc7ef068a39232728d77744633eb4e06d Mon Sep 17 00:00:00 2001 From: IAmTomahawkx Date: Sun, 3 Sep 2023 20:13:54 -0700 Subject: [PATCH 4/8] Add Annotated support --- twitchio/ext/commands/core.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/twitchio/ext/commands/core.py b/twitchio/ext/commands/core.py index 8e430f2d..f97839ee 100644 --- a/twitchio/ext/commands/core.py +++ b/twitchio/ext/commands/core.py @@ -200,6 +200,10 @@ def _resolve_converter(self, converter: Union[Callable, Awaitable, type]) -> Cal elif isinstance(converter, types.UnionType) or getattr(converter, "__origin__", None) is Union: return self.resolve_union_callback(converter) # type: ignore + elif hasattr(converter, "__metadata__"): # Annotated + annotated = converter.__metadata__ # type: ignore + return self._resolve_converter(annotated[0]) + if converter is bool: converter = _boolconverter From d4f54eb5a7c086a7479bfdbcc7c36485bf212ae5 Mon Sep 17 00:00:00 2001 From: IAmTomahawkx Date: Sun, 3 Sep 2023 20:15:23 -0700 Subject: [PATCH 5/8] Run black --- twitchio/ext/commands/core.py | 64 +++++++++++++++++---------------- twitchio/ext/commands/errors.py | 2 ++ 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/twitchio/ext/commands/core.py b/twitchio/ext/commands/core.py index f97839ee..b15b8916 100644 --- a/twitchio/ext/commands/core.py +++ b/twitchio/ext/commands/core.py @@ -134,14 +134,16 @@ def full_name(self) -> str: if not self.parent: return self._name return f"{self.parent.full_name} {self._name}" - + def _is_optional_argument(self, converter: Any): - return (getattr(converter, '__origin__', None) is Union or isinstance(converter, types.UnionType)) and type(None) in converter.__args__ + return (getattr(converter, "__origin__", None) is Union or isinstance(converter, types.UnionType)) and type( + None + ) in converter.__args__ def resolve_union_callback(self, converter: UnionT) -> Callable[[Context, str], Any]: - #print(type(converter), converter.__args__) + # print(type(converter), converter.__args__) - args = converter.__args__ # type: ignore # pyright doesnt like this + args = converter.__args__ # type: ignore # pyright doesnt like this async def _resolve(context: Context, arg: str) -> Any: t = EMPTY @@ -154,39 +156,39 @@ async def _resolve(context: Context, arg: str) -> Any: t: Any = underlying(context, arg) if inspect.iscoroutine(t): t = await t - + break except Exception as l: last = l - t = EMPTY # thisll get changed when t is a coroutine, but is still invalid, so roll it back + t = EMPTY # thisll get changed when t is a coroutine, but is still invalid, so roll it back continue - + if t is EMPTY: fmt = f"Failed to convert argument '{arg}' to any of {', '.join(str(x) for x in args)}" - raise UnionArgumentParsingFailed(fmt, last) # type: ignore # if t is EMPTY, there has to be a last error - + raise UnionArgumentParsingFailed(fmt, last) # type: ignore # if t is EMPTY, there has to be a last error + return t - + return _resolve def resolve_optional_callback(self, converter: Any) -> Callable[[Context, str], Any]: underlying = self._resolve_converter(converter.__args__[0]) - + async def _resolve(context: Context, arg: str) -> Any: try: t: Any = underlying(context, arg) if inspect.iscoroutine(t): t = await t - + except Exception: - return EMPTY # instruct the parser to roll back and ignore this argument - + return EMPTY # instruct the parser to roll back and ignore this argument + return t - + return _resolve def _resolve_converter(self, converter: Union[Callable, Awaitable, type]) -> Callable[..., Any]: - #print(dir(converter)) + # print(dir(converter)) if ( isinstance(converter, type) and converter.__module__.startswith("twitchio") @@ -196,32 +198,32 @@ def _resolve_converter(self, converter: Union[Callable, Awaitable, type]) -> Cal elif self._is_optional_argument(converter): return self.resolve_optional_callback(converter) - + elif isinstance(converter, types.UnionType) or getattr(converter, "__origin__", None) is Union: - return self.resolve_union_callback(converter) # type: ignore - - elif hasattr(converter, "__metadata__"): # Annotated - annotated = converter.__metadata__ # type: ignore + return self.resolve_union_callback(converter) # type: ignore + + elif hasattr(converter, "__metadata__"): # Annotated + annotated = converter.__metadata__ # type: ignore return self._resolve_converter(annotated[0]) - + if converter is bool: converter = _boolconverter - + elif converter in (str, int): _original = converter - converter = lambda _, param: _original(param) # type: ignore # the types dont take a ctx argument, so strip that out here - - return converter # type: ignore + converter = lambda _, param: _original(param) # type: ignore # the types dont take a ctx argument, so strip that out here + + return converter # type: ignore async def _convert_types(self, context: Context, param: inspect.Parameter, parsed: str) -> Any: converter = param.annotation - + if converter is param.empty: if param.default in (param.empty, None): converter = str else: converter = type(param.default) - + true_converter = self._resolve_converter(converter) try: @@ -257,7 +259,7 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di try: argument = parsed.pop(index) except (KeyError, IndexError): - if self._is_optional_argument(param.annotation): # parameter is optional and at the end. + if self._is_optional_argument(param.annotation): # parameter is optional and at the end. args.append(param.default if param.default is not param.empty else None) continue @@ -275,7 +277,7 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di continue else: args.append(_parsed_arg) - + elif param.kind == param.KEYWORD_ONLY: rest = " ".join(parsed.values()) if rest.startswith(" "): @@ -313,7 +315,7 @@ async def try_run(func, *, to_command=False): try: args, kwargs = await self.parse_args(context, self._instance, context.view.words, index=index) - except (MissingRequiredArgument, BadArgument) as e: + except (MissingRequiredArgument, BadArgument) as e: if self.event_error: args_ = [self._instance, context] if self._instance else [context] await try_run(self.event_error(*args_, e)) diff --git a/twitchio/ext/commands/errors.py b/twitchio/ext/commands/errors.py index e6148851..d754f1b1 100644 --- a/twitchio/ext/commands/errors.py +++ b/twitchio/ext/commands/errors.py @@ -52,9 +52,11 @@ def __init__(self, message: str, original: Exception): self.original = original super().__init__(message) + class UnionArgumentParsingFailed(ArgumentParsingFailed): pass + class CommandNotFound(TwitchCommandError): pass From 80cc23e35fb15636cd1a8017059b67c96ed470c2 Mon Sep 17 00:00:00 2001 From: IAmTomahawkx Date: Mon, 4 Sep 2023 15:50:31 -0700 Subject: [PATCH 6/8] revamp errors with proper useful messages and details on objects --- twitchio/ext/commands/core.py | 72 +++++++++++++++++++-------------- twitchio/ext/commands/errors.py | 42 ++++++++++++++----- 2 files changed, 74 insertions(+), 40 deletions(-) diff --git a/twitchio/ext/commands/core.py b/twitchio/ext/commands/core.py index b15b8916..599c827f 100644 --- a/twitchio/ext/commands/core.py +++ b/twitchio/ext/commands/core.py @@ -63,7 +63,7 @@ def __eq__(self, __value: object) -> bool: EMPTY = EmptyArgumentSentinel() -def _boolconverter(ctx: Context, param: str): +def _boolconverter(param: str): param = param.lower() if param in {"yes", "y", "1", "true", "on"}: return True @@ -140,7 +140,7 @@ def _is_optional_argument(self, converter: Any): None ) in converter.__args__ - def resolve_union_callback(self, converter: UnionT) -> Callable[[Context, str], Any]: + def resolve_union_callback(self, name: str, converter: UnionT) -> Callable[[Context, str], Any]: # print(type(converter), converter.__args__) args = converter.__args__ # type: ignore # pyright doesnt like this @@ -150,7 +150,7 @@ async def _resolve(context: Context, arg: str) -> Any: last = None for original in args: - underlying = self._resolve_converter(original) + underlying = self._resolve_converter(name, original) try: t: Any = underlying(context, arg) @@ -164,15 +164,14 @@ async def _resolve(context: Context, arg: str) -> Any: continue if t is EMPTY: - fmt = f"Failed to convert argument '{arg}' to any of {', '.join(str(x) for x in args)}" - raise UnionArgumentParsingFailed(fmt, last) # type: ignore # if t is EMPTY, there has to be a last error + raise UnionArgumentParsingFailed(name, args) return t return _resolve - def resolve_optional_callback(self, converter: Any) -> Callable[[Context, str], Any]: - underlying = self._resolve_converter(converter.__args__[0]) + def resolve_optional_callback(self, name: str, converter: Any) -> Callable[[Context, str], Any]: + underlying = self._resolve_converter(name, converter.__args__[0]) async def _resolve(context: Context, arg: str) -> Any: try: @@ -187,33 +186,46 @@ async def _resolve(context: Context, arg: str) -> Any: return _resolve - def _resolve_converter(self, converter: Union[Callable, Awaitable, type]) -> Callable[..., Any]: - # print(dir(converter)) + def _resolve_converter(self, name: str, converter: Union[Callable, Awaitable, type]) -> Callable[..., Any]: + if ( isinstance(converter, type) and converter.__module__.startswith("twitchio") and converter in builtin_converter._mapping ): - return builtin_converter._mapping[converter] + return self._convert_builtin_type(name, converter, builtin_converter._mapping[converter]) + + elif converter is bool: + converter = self._convert_builtin_type(name, bool, _boolconverter) + + elif converter in (str, int): + converter = self._convert_builtin_type(name, converter, converter) # type: ignore elif self._is_optional_argument(converter): - return self.resolve_optional_callback(converter) + return self.resolve_optional_callback(name, converter) elif isinstance(converter, types.UnionType) or getattr(converter, "__origin__", None) is Union: - return self.resolve_union_callback(converter) # type: ignore + return self.resolve_union_callback(name, converter) # type: ignore elif hasattr(converter, "__metadata__"): # Annotated annotated = converter.__metadata__ # type: ignore - return self._resolve_converter(annotated[0]) - - if converter is bool: - converter = _boolconverter - - elif converter in (str, int): - _original = converter - converter = lambda _, param: _original(param) # type: ignore # the types dont take a ctx argument, so strip that out here + return self._resolve_converter(name, annotated[0]) return converter # type: ignore + + def _convert_builtin_type(self, arg_name: str, original: type, converter: Union[Callable[[str], Any], Callable[[str], Awaitable[Any]]]) -> Callable[[Context, str], Awaitable[Any]]: + async def resolve(_, arg: str) -> Any: + try: + t = converter(arg) + + if inspect.iscoroutine(t): + t = await t + + return t + except Exception as e: + raise ArgumentParsingFailed(f"Failed to convert `{arg}` to expected type {original.__name__} for argument `{arg_name}`", original=e, argname=arg_name, expected=original) from e + + return resolve async def _convert_types(self, context: Context, param: inspect.Parameter, parsed: str) -> Any: converter = param.annotation @@ -224,20 +236,19 @@ async def _convert_types(self, context: Context, param: inspect.Parameter, parse else: converter = type(param.default) - true_converter = self._resolve_converter(converter) + true_converter = self._resolve_converter(param.name, converter) try: argument = true_converter(context, parsed) if inspect.iscoroutine(argument): argument = await argument - except BadArgument: + except BadArgument as e: + if e.name is None: + e.name = param.name + raise except Exception as e: - raise ArgumentParsingFailed( - f"Invalid argument parsed at `{param.name}` in command `{self.name}`." - f" Expected type {converter} got {type(parsed)}.", - e, - ) from e + raise ArgumentParsingFailed(f"Failed to parse `{parsed}` for argument {param.name}", original=e, argname=param.name, expected=None) from e return argument async def parse_args(self, context: Context, instance: Optional[Cog], parsed: dict, index=0) -> Tuple[list, dict]: @@ -264,7 +275,8 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di continue if param.default is param.empty: - raise MissingRequiredArgument(param) + raise MissingRequiredArgument(argname=param.name) + args.append(param.default) else: _parsed_arg = await self._convert_types(context, param, argument) @@ -285,13 +297,13 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di if rest: rest = await self._convert_types(context, param, rest) elif param.default is param.empty: - raise MissingRequiredArgument(param) + raise MissingRequiredArgument(argname=param.name) else: rest = param.default kwargs[param.name] = rest parsed.clear() break - elif param.VAR_POSITIONAL: + elif param.kind == param.VAR_POSITIONAL: args.extend([await self._convert_types(context, param, argument) for argument in parsed.values()]) parsed.clear() break diff --git a/twitchio/ext/commands/errors.py b/twitchio/ext/commands/errors.py index d754f1b1..59d0330e 100644 --- a/twitchio/ext/commands/errors.py +++ b/twitchio/ext/commands/errors.py @@ -21,7 +21,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from .core import Command class TwitchCommandError(Exception): """Base TwitchIO Command Error. All command errors derive from this error.""" @@ -38,33 +43,50 @@ class InvalidCog(TwitchCommandError): class MissingRequiredArgument(TwitchCommandError): - pass + def __init__(self, *args, argname: Optional[str] = None) -> None: + self.name: str = (argname or "unknown") + + if args: + super().__init__(*args) + else: + super().__init__(f"Missing required argument `{self.name}`") class BadArgument(TwitchCommandError): - def __init__(self, message: str): + def __init__(self, message: str, argname: Optional[str] = None): + self.name: str = argname # type: ignore # this'll get fixed in the parser handler self.message = message super().__init__(message) class ArgumentParsingFailed(BadArgument): - def __init__(self, message: str, original: Exception): - self.original = original - super().__init__(message) + def __init__(self, message: str, original: Exception, argname: Optional[str] = None, expected: Optional[type] = None): + self.original: Exception = original + self.name: str = argname # type: ignore # in theory this'll never be None but if someone is creating this themselves itll be none. + self.expected_type: Optional[type] = expected + + Exception.__init__(self, message) # bypass badArgument class UnionArgumentParsingFailed(ArgumentParsingFailed): - pass + def __init__(self, argname: str, expected: tuple[type, ...]): + self.name: str = argname + self.expected_type: tuple[type, ...] = expected + + self.message = f"Failed to convert argument `{self.name}` to any of the valid options" + Exception.__init__(self, self.message) class CommandNotFound(TwitchCommandError): - pass + def __init__(self, message: str, name: str) -> None: + self.name: str = name + super().__init__(message) class CommandOnCooldown(TwitchCommandError): - def __init__(self, command, retry_after): - self.command = command - self.retry_after = retry_after + def __init__(self, command: Command, retry_after: float): + self.command: Command = command + self.retry_after: float = retry_after super().__init__(f"Command <{command.name}> is on cooldown. Try again in ({retry_after:.2f})s") From 8c409ea216d19505b59078a50992e9314d8accfa Mon Sep 17 00:00:00 2001 From: IAmTomahawkx Date: Sat, 9 Sep 2023 10:00:40 -0700 Subject: [PATCH 7/8] update changelog with changes --- docs/changelog.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index bf8ddd76..3f51367c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -21,6 +21,13 @@ Master - Bug fixes - Fix websocket reconnection event. +- ext.commands + - Additions + - Added support for the following typing constructs in command signatures: + - ``Union[A, B]`` / ``A | B`` + - ``Optional[T]`` / ``T | None`` + - ``Annotated[T, converter]`` (accessible through the ``typing_extensions`` module on older python versions) + 2.7.0 ====== From 22d124a0953823a4f7d00a2970728d1c44f1bd42 Mon Sep 17 00:00:00 2001 From: IAmTomahawkx Date: Sat, 9 Sep 2023 10:03:57 -0700 Subject: [PATCH 8/8] run black --- twitchio/ext/commands/core.py | 28 ++++++++++++++++++---------- twitchio/ext/commands/errors.py | 13 ++++++++----- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/twitchio/ext/commands/core.py b/twitchio/ext/commands/core.py index 599c827f..b0602d93 100644 --- a/twitchio/ext/commands/core.py +++ b/twitchio/ext/commands/core.py @@ -187,7 +187,6 @@ async def _resolve(context: Context, arg: str) -> Any: return _resolve def _resolve_converter(self, name: str, converter: Union[Callable, Awaitable, type]) -> Callable[..., Any]: - if ( isinstance(converter, type) and converter.__module__.startswith("twitchio") @@ -199,32 +198,39 @@ def _resolve_converter(self, name: str, converter: Union[Callable, Awaitable, ty converter = self._convert_builtin_type(name, bool, _boolconverter) elif converter in (str, int): - converter = self._convert_builtin_type(name, converter, converter) # type: ignore + converter = self._convert_builtin_type(name, converter, converter) # type: ignore elif self._is_optional_argument(converter): return self.resolve_optional_callback(name, converter) elif isinstance(converter, types.UnionType) or getattr(converter, "__origin__", None) is Union: - return self.resolve_union_callback(name, converter) # type: ignore + return self.resolve_union_callback(name, converter) # type: ignore elif hasattr(converter, "__metadata__"): # Annotated annotated = converter.__metadata__ # type: ignore return self._resolve_converter(name, annotated[0]) return converter # type: ignore - - def _convert_builtin_type(self, arg_name: str, original: type, converter: Union[Callable[[str], Any], Callable[[str], Awaitable[Any]]]) -> Callable[[Context, str], Awaitable[Any]]: + + def _convert_builtin_type( + self, arg_name: str, original: type, converter: Union[Callable[[str], Any], Callable[[str], Awaitable[Any]]] + ) -> Callable[[Context, str], Awaitable[Any]]: async def resolve(_, arg: str) -> Any: try: t = converter(arg) if inspect.iscoroutine(t): t = await t - + return t except Exception as e: - raise ArgumentParsingFailed(f"Failed to convert `{arg}` to expected type {original.__name__} for argument `{arg_name}`", original=e, argname=arg_name, expected=original) from e - + raise ArgumentParsingFailed( + f"Failed to convert `{arg}` to expected type {original.__name__} for argument `{arg_name}`", + original=e, + argname=arg_name, + expected=original, + ) from e + return resolve async def _convert_types(self, context: Context, param: inspect.Parameter, parsed: str) -> Any: @@ -248,7 +254,9 @@ async def _convert_types(self, context: Context, param: inspect.Parameter, parse raise except Exception as e: - raise ArgumentParsingFailed(f"Failed to parse `{parsed}` for argument {param.name}", original=e, argname=param.name, expected=None) from e + raise ArgumentParsingFailed( + f"Failed to parse `{parsed}` for argument {param.name}", original=e, argname=param.name, expected=None + ) from e return argument async def parse_args(self, context: Context, instance: Optional[Cog], parsed: dict, index=0) -> Tuple[list, dict]: @@ -276,7 +284,7 @@ async def parse_args(self, context: Context, instance: Optional[Cog], parsed: di if param.default is param.empty: raise MissingRequiredArgument(argname=param.name) - + args.append(param.default) else: _parsed_arg = await self._convert_types(context, param, argument) diff --git a/twitchio/ext/commands/errors.py b/twitchio/ext/commands/errors.py index 59d0330e..eeaa2229 100644 --- a/twitchio/ext/commands/errors.py +++ b/twitchio/ext/commands/errors.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: from .core import Command + class TwitchCommandError(Exception): """Base TwitchIO Command Error. All command errors derive from this error.""" @@ -44,7 +45,7 @@ class InvalidCog(TwitchCommandError): class MissingRequiredArgument(TwitchCommandError): def __init__(self, *args, argname: Optional[str] = None) -> None: - self.name: str = (argname or "unknown") + self.name: str = argname or "unknown" if args: super().__init__(*args) @@ -54,18 +55,20 @@ def __init__(self, *args, argname: Optional[str] = None) -> None: class BadArgument(TwitchCommandError): def __init__(self, message: str, argname: Optional[str] = None): - self.name: str = argname # type: ignore # this'll get fixed in the parser handler + self.name: str = argname # type: ignore # this'll get fixed in the parser handler self.message = message super().__init__(message) class ArgumentParsingFailed(BadArgument): - def __init__(self, message: str, original: Exception, argname: Optional[str] = None, expected: Optional[type] = None): + def __init__( + self, message: str, original: Exception, argname: Optional[str] = None, expected: Optional[type] = None + ): self.original: Exception = original - self.name: str = argname # type: ignore # in theory this'll never be None but if someone is creating this themselves itll be none. + self.name: str = argname # type: ignore # in theory this'll never be None but if someone is creating this themselves itll be none. self.expected_type: Optional[type] = expected - Exception.__init__(self, message) # bypass badArgument + Exception.__init__(self, message) # bypass badArgument class UnionArgumentParsingFailed(ArgumentParsingFailed):