Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: #1805 #1806

Merged
merged 7 commits into from Jun 9, 2023
Merged

fix: #1805 #1806

merged 7 commits into from Jun 9, 2023

Conversation

Nzii3
Copy link
Contributor

@Nzii3 Nzii3 commented Nov 29, 2022

Summary

Fixes the commands.FlagConverter bug in the stable release.
Fixes #1805

Information

  • This PR fixes an issue.
  • This PR adds something new (e.g. new method or parameters).
  • This PR is a breaking change (e.g. methods or parameters removed/renamed).
  • This PR is not a code change (e.g. documentation, README, typehinting,
    examples, ...).

Checklist

  • I have searched the open pull requests for duplicates.
  • If code changes were made then they have been tested.
    • I have updated the documentation to reflect the changes.
  • If type: ignore comments were used, a comment is also left explaining why.

@Nzii3 Nzii3 requested a review from a team as a code owner November 29, 2022 01:37
@Nzii3
Copy link
Contributor Author

Nzii3 commented Nov 29, 2022

The updated flags.py file works in my bots that are using flags. Here is the commit that changed MISSING to MissingField which broke it: #1680 (specific file commit)

@BobDotCom
Copy link
Member

BobDotCom commented Nov 29, 2022

I'll wait on the CI to be sure, but I'm pretty sure this is a syntax error on python 3.11

Edit: The CI doesn't catch it, but a ValueError is raised on import, causing the library to completely fail on 3.11. I'll look into a better fix.

@BobDotCom BobDotCom marked this pull request as draft November 29, 2022 03:06
@Nzii3
Copy link
Contributor Author

Nzii3 commented Nov 29, 2022

👍

@Lulalaby
Copy link
Member

Lulalaby commented Dec 7, 2022

bob do something

@Lulalaby
Copy link
Member

@Pycord-Development/maintainers

@Lulalaby Lulalaby changed the title Fix issue #1805 fix: #1805 Jan 5, 2023
@JustaSqu1d JustaSqu1d added the status: in progress Work in Progess label Jan 9, 2023
@xFGhoul
Copy link
Member

xFGhoul commented Feb 25, 2023

are there any updates?

@BobDotCom
Copy link
Member

are there any updates?

Haven't had time to look at this yet. If anyone is able to work on it, that would be great. Otherwise it'll get done when I get around to it.

@NeloBlivion
Copy link
Member

We discussed this with Ghoul in a help thread and he managed to come up with this very cursed solution... I wouldn't dare myself but perhaps this could be explored if you're desperate.
image

@xFGhoul
Copy link
Member

xFGhoul commented Apr 9, 2023

We discussed this with Ghoul in a help thread and he managed to come up with this very cursed solution... I wouldn't dare myself but perhaps this could be explored if you're desperate. image

yeah...I would not exactly recommend this be implemented...not the best at typing in python so surely someone will come up with something better.

@VincentRPS
Copy link
Member

I'm not too sure about it but I think #2008 might fix it?

MissingField = field(default=MISSING)

# ...
    name: Maybe[str] = MissingField
    aliases: list[str] = field(default_factory=list)
    attribute: Maybe[str] = MissingField
    annotation: Maybe[Any] = MissingField
    default: Maybe[Any] = MissingField
    max_args: Maybe[int] = MissingField
    override: Maybe[bool] = MissingField
    cast_to_dict: Maybe[bool] = False

If the ValueError originates from type concatenation, this should fix it perfectly.

@xFGhoul
Copy link
Member

xFGhoul commented Apr 9, 2023

I'm not too sure about it but I think #2008 might fix it?

MissingField = field(default=MISSING)

# ...
    name: Maybe[str] = MissingField
    aliases: list[str] = field(default_factory=list)
    attribute: Maybe[str] = MissingField
    annotation: Maybe[Any] = MissingField
    default: Maybe[Any] = MissingField
    max_args: Maybe[int] = MissingField
    override: Maybe[bool] = MissingField
    cast_to_dict: Maybe[bool] = False

If the ValueError originates from type concatenation, this should fix it perfectly.

potentially, yes, have you tested?

@Nzii3
Copy link
Contributor Author

Nzii3 commented Jun 9, 2023

I think it’s time for this to be fixed, I have to copy and paste a fix everything I install or update py-cord in new bots, which is becoming frustrating since I keep getting the original error if I don’t.

@Lulalaby
Copy link
Member

Lulalaby commented Jun 9, 2023

Yeah what now. Is this tested. Is that fixed.

@Nzii3
Copy link
Contributor Author

Nzii3 commented Jun 9, 2023

This is exactly what I used to fix it and it works fine. But idk what other problems it can cause. So far i've found none.
image

from __future__ import annotations

import inspect
import re
import sys
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Iterator, Literal, Pattern, TypeVar, Union

from discord.utils import MISSING, MissingField, maybe_coroutine, resolve_annotation

from .converter import run_converters
from .errors import (
    BadFlagArgument,
    CommandError,
    MissingFlagArgument,
    MissingRequiredFlag,
    TooManyFlags,
)
from .view import StringView

__all__ = (
    "Flag",
    "flag",
    "FlagConverter",
)


if TYPE_CHECKING:
    from .context import Context


@dataclass
class Flag:
    """Represents a flag parameter for :class:`FlagConverter`.

    The :func:`~discord.ext.commands.flag` function helps
    create these flag objects, but it is not necessary to
    do so. These cannot be constructed manually.

    Attributes
    ----------
    name: :class:`str`
        The name of the flag.
    aliases: List[:class:`str`]
        The aliases of the flag name.
    attribute: :class:`str`
        The attribute in the class that corresponds to this flag.
    default: Any
        The default value of the flag, if available.
    annotation: Any
        The underlying evaluated annotation of the flag.
    max_args: :class:`int`
        The maximum number of arguments the flag can accept.
        A negative value indicates an unlimited amount of arguments.
    override: :class:`bool`
        Whether multiple given values overrides the previous value.
    """

    name: str = MISSING
    aliases: list[str] = field(default_factory=list)
    attribute: str = MISSING
    annotation: Any = MISSING
    default: Any = MISSING
    max_args: int = MISSING
    override: bool = MISSING
    cast_to_dict: bool = False

    @property
    def required(self) -> bool:
        """Whether the flag is required.

        A required flag has no default value.
        """
        return self.default is MISSING


def flag(
    *,
    name: str = MISSING,
    aliases: list[str] = MISSING,
    default: Any = MISSING,
    max_args: int = MISSING,
    override: bool = MISSING,
) -> Any:
    """Override default functionality and parameters of the underlying :class:`FlagConverter`
    class attributes.

    Parameters
    ----------
    name: :class:`str`
        The flag name. If not given, defaults to the attribute name.
    aliases: List[:class:`str`]
        Aliases to the flag name. If not given, no aliases are set.
    default: Any
        The default parameter. This could be either a value or a callable that takes
        :class:`Context` as its sole parameter. If not given then it defaults to
        the default value given to the attribute.
    max_args: :class:`int`
        The maximum number of arguments the flag can accept.
        A negative value indicates an unlimited amount of arguments.
        The default value depends on the annotation given.
    override: :class:`bool`
        Whether multiple given values overrides the previous value. The default
        value depends on the annotation given.
    """
    return Flag(
        name=name,
        aliases=aliases,
        default=default,
        max_args=max_args,
        override=override,
    )


def validate_flag_name(name: str, forbidden: set[str]):
    if not name:
        raise ValueError("flag names should not be empty")

    for ch in name:
        if ch.isspace():
            raise ValueError(f"flag name {name!r} cannot have spaces")
        if ch == "\\":
            raise ValueError(f"flag name {name!r} cannot have backslashes")
        if ch in forbidden:
            raise ValueError(
                f"flag name {name!r} cannot have any of {forbidden!r} within them"
            )


def get_flags(
    namespace: dict[str, Any], globals: dict[str, Any], locals: dict[str, Any]
) -> dict[str, Flag]:
    annotations = namespace.get("__annotations__", {})
    case_insensitive = namespace["__commands_flag_case_insensitive__"]
    flags: dict[str, Flag] = {}
    cache: dict[str, Any] = {}
    names: set[str] = set()
    for name, annotation in annotations.items():
        flag = namespace.pop(name, MISSING)
        if isinstance(flag, Flag):
            flag.annotation = annotation
        else:
            flag = Flag(name=name, annotation=annotation, default=flag)

        flag.attribute = name
        if flag.name is MISSING:
            flag.name = name

        annotation = flag.annotation = resolve_annotation(
            flag.annotation, globals, locals, cache
        )

        if (
            flag.default is MISSING
            and hasattr(annotation, "__commands_is_flag__")
            and annotation._can_be_constructible()
        ):
            flag.default = annotation._construct_default

        if flag.aliases is MISSING:
            flag.aliases = []

        # Add sensible defaults based off of the type annotation
        # <type> -> (max_args=1)
        # List[str] -> (max_args=-1)
        # Tuple[int, ...] -> (max_args=1)
        # Dict[K, V] -> (max_args=-1, override=True)
        # Union[str, int] -> (max_args=1)
        # Optional[str] -> (default=None, max_args=1)

        try:
            origin = annotation.__origin__
        except AttributeError:
            # A regular type hint
            if flag.max_args is MISSING:
                flag.max_args = 1
        else:
            if origin is Union:
                # typing.Union
                if flag.max_args is MISSING:
                    flag.max_args = 1
                if annotation.__args__[-1] is type(None) and flag.default is MISSING:
                    # typing.Optional
                    flag.default = None
            elif origin is tuple:
                # typing.Tuple
                # tuple parsing is e.g. `flag: peter 20`
                # for Tuple[str, int] would give you flag: ('peter', 20)
                if flag.max_args is MISSING:
                    flag.max_args = 1
            elif origin is list:
                # typing.List
                if flag.max_args is MISSING:
                    flag.max_args = -1
            elif origin is dict:
                # typing.Dict[K, V]
                # Equivalent to:
                # typing.List[typing.Tuple[K, V]]
                flag.cast_to_dict = True
                if flag.max_args is MISSING:
                    flag.max_args = -1
                if flag.override is MISSING:
                    flag.override = True
            elif origin is Literal:
                if flag.max_args is MISSING:
                    flag.max_args = 1
            else:
                raise TypeError(
                    f"Unsupported typing annotation {annotation!r} for"
                    f" {flag.name!r} flag"
                )

        if flag.override is MISSING:
            flag.override = False

        # Validate flag names are unique
        name = flag.name.casefold() if case_insensitive else flag.name
        if name in names:
            raise TypeError(
                f"{flag.name!r} flag conflicts with previous flag or alias."
            )
        else:
            names.add(name)

        for alias in flag.aliases:
            # Validate alias is unique
            alias = alias.casefold() if case_insensitive else alias
            if alias in names:
                raise TypeError(
                    f"{flag.name!r} flag alias {alias!r} conflicts with previous flag"
                    " or alias."
                )
            else:
                names.add(alias)

        flags[flag.name] = flag

    return flags


class FlagsMeta(type):
    if TYPE_CHECKING:
        __commands_is_flag__: bool
        __commands_flags__: dict[str, Flag]
        __commands_flag_aliases__: dict[str, str]
        __commands_flag_regex__: Pattern[str]
        __commands_flag_case_insensitive__: bool
        __commands_flag_delimiter__: str
        __commands_flag_prefix__: str

    def __new__(
        cls: type[type],
        name: str,
        bases: tuple[type, ...],
        attrs: dict[str, Any],
        *,
        case_insensitive: bool = MISSING,
        delimiter: str = MISSING,
        prefix: str = MISSING,
    ):
        attrs["__commands_is_flag__"] = True

        try:
            global_ns = sys.modules[attrs["__module__"]].__dict__
        except KeyError:
            global_ns = {}

        frame = inspect.currentframe()
        try:
            if frame is None:
                local_ns = {}
            else:
                if frame.f_back is None:
                    local_ns = frame.f_locals
                else:
                    local_ns = frame.f_back.f_locals
        finally:
            del frame

        flags: dict[str, Flag] = {}
        aliases: dict[str, str] = {}
        for base in reversed(bases):
            if base.__dict__.get("__commands_is_flag__", False):
                flags.update(base.__dict__["__commands_flags__"])
                aliases.update(base.__dict__["__commands_flag_aliases__"])
                if case_insensitive is MISSING:
                    attrs["__commands_flag_case_insensitive__"] = base.__dict__[
                        "__commands_flag_case_insensitive__"
                    ]
                if delimiter is MISSING:
                    attrs["__commands_flag_delimiter__"] = base.__dict__[
                        "__commands_flag_delimiter__"
                    ]
                if prefix is MISSING:
                    attrs["__commands_flag_prefix__"] = base.__dict__[
                        "__commands_flag_prefix__"
                    ]

        if case_insensitive is not MISSING:
            attrs["__commands_flag_case_insensitive__"] = case_insensitive
        if delimiter is not MISSING:
            attrs["__commands_flag_delimiter__"] = delimiter
        if prefix is not MISSING:
            attrs["__commands_flag_prefix__"] = prefix

        case_insensitive = attrs.setdefault("__commands_flag_case_insensitive__", False)
        delimiter = attrs.setdefault("__commands_flag_delimiter__", ":")
        prefix = attrs.setdefault("__commands_flag_prefix__", "")

        for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
            flags[flag_name] = flag
            aliases.update({alias_name: flag_name for alias_name in flag.aliases})

        forbidden = set(delimiter).union(prefix)
        for flag_name in flags:
            validate_flag_name(flag_name, forbidden)
        for alias_name in aliases:
            validate_flag_name(alias_name, forbidden)

        regex_flags = 0
        if case_insensitive:
            flags = {key.casefold(): value for key, value in flags.items()}
            aliases = {
                key.casefold(): value.casefold() for key, value in aliases.items()
            }
            regex_flags = re.IGNORECASE

        keys = [re.escape(k) for k in flags]
        keys.extend(re.escape(a) for a in aliases)
        keys = sorted(keys, key=len, reverse=True)

        joined = "|".join(keys)
        pattern = re.compile(
            f"(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})",
            regex_flags,
        )
        attrs["__commands_flag_regex__"] = pattern
        attrs["__commands_flags__"] = flags
        attrs["__commands_flag_aliases__"] = aliases

        return type.__new__(cls, name, bases, attrs)


async def tuple_convert_all(
    ctx: Context, argument: str, flag: Flag, converter: Any
) -> tuple[Any, ...]:
    view = StringView(argument)
    results = []
    param: inspect.Parameter = ctx.current_parameter  # type: ignore
    while not view.eof:
        view.skip_ws()
        if view.eof:
            break

        word = view.get_quoted_word()
        if word is None:
            break

        try:
            converted = await run_converters(ctx, converter, word, param)
        except CommandError:
            raise
        except Exception as e:
            raise BadFlagArgument(flag) from e
        else:
            results.append(converted)

    return tuple(results)


async def tuple_convert_flag(
    ctx: Context, argument: str, flag: Flag, converters: Any
) -> tuple[Any, ...]:
    view = StringView(argument)
    results = []
    param: inspect.Parameter = ctx.current_parameter  # type: ignore
    for converter in converters:
        view.skip_ws()
        if view.eof:
            break

        word = view.get_quoted_word()
        if word is None:
            break

        try:
            converted = await run_converters(ctx, converter, word, param)
        except CommandError:
            raise
        except Exception as e:
            raise BadFlagArgument(flag) from e
        else:
            results.append(converted)

    if len(results) != len(converters):
        raise BadFlagArgument(flag)

    return tuple(results)


async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any:
    param: inspect.Parameter = ctx.current_parameter  # type: ignore
    annotation = annotation or flag.annotation
    try:
        origin = annotation.__origin__
    except AttributeError:
        pass
    else:
        if origin is tuple:
            if annotation.__args__[-1] is Ellipsis:
                return await tuple_convert_all(
                    ctx, argument, flag, annotation.__args__[0]
                )
            else:
                return await tuple_convert_flag(
                    ctx, argument, flag, annotation.__args__
                )
        elif origin is list:
            # typing.List[x]
            annotation = annotation.__args__[0]
            return await convert_flag(ctx, argument, flag, annotation)
        elif origin is Union and annotation.__args__[-1] is type(None):
            # typing.Optional[x]
            annotation = Union[annotation.__args__[:-1]]
            return await run_converters(ctx, annotation, argument, param)
        elif origin is dict:
            # typing.Dict[K, V] -> typing.Tuple[K, V]
            return await tuple_convert_flag(ctx, argument, flag, annotation.__args__)

    try:
        return await run_converters(ctx, annotation, argument, param)
    except CommandError:
        raise
    except Exception as e:
        raise BadFlagArgument(flag) from e


F = TypeVar("F", bound="FlagConverter")


class FlagConverter(metaclass=FlagsMeta):
    """A converter that allows for a user-friendly flag syntax.

    The flags are defined using :pep:`526` type annotations similar
    to the :mod:`dataclasses` Python module. For more information on
    how this converter works, check the appropriate
    :ref:`documentation <ext_commands_flag_converter>`.

    .. container:: operations

        .. describe:: iter(x)

            Returns an iterator of ``(flag_name, flag_value)`` pairs. This allows it
            to be, for example, constructed as a dict or a list of pairs.
            Note that aliases are not shown.

    .. versionadded:: 2.0

    Parameters
    ----------
    case_insensitive: :class:`bool`
        A class parameter to toggle case insensitivity of the flag parsing.
        If ``True`` then flags are parsed in a case-insensitive manner.
        Defaults to ``False``.
    prefix: :class:`str`
        The prefix that all flags must be prefixed with. By default,
        there is no prefix.
    delimiter: :class:`str`
        The delimiter that separates a flag's argument from the flag's name.
        By default, this is ``:``.
    """

    @classmethod
    def get_flags(cls) -> dict[str, Flag]:
        """A mapping of flag name to flag object this converter has."""
        return cls.__commands_flags__.copy()

    @classmethod
    def _can_be_constructible(cls) -> bool:
        return all(not flag.required for flag in cls.__commands_flags__.values())

    def __iter__(self) -> Iterator[tuple[str, Any]]:
        for flag in self.__class__.__commands_flags__.values():
            yield flag.name, getattr(self, flag.attribute)

    @classmethod
    async def _construct_default(cls: type[F], ctx: Context) -> F:
        self: F = cls.__new__(cls)
        flags = cls.__commands_flags__
        for flag in flags.values():
            if callable(flag.default):
                default = await maybe_coroutine(flag.default, ctx)
                setattr(self, flag.attribute, default)
            else:
                setattr(self, flag.attribute, flag.default)
        return self

    def __repr__(self) -> str:
        pairs = " ".join(
            [
                f"{flag.attribute}={getattr(self, flag.attribute)!r}"
                for flag in self.get_flags().values()
            ]
        )
        return f"<{self.__class__.__name__} {pairs}>"

    @classmethod
    def parse_flags(cls, argument: str) -> dict[str, list[str]]:
        result: dict[str, list[str]] = {}
        flags = cls.__commands_flags__
        aliases = cls.__commands_flag_aliases__
        last_position = 0
        last_flag: Flag | None = None

        case_insensitive = cls.__commands_flag_case_insensitive__
        for match in cls.__commands_flag_regex__.finditer(argument):
            begin, end = match.span(0)
            key = match.group("flag")
            if case_insensitive:
                key = key.casefold()

            if key in aliases:
                key = aliases[key]

            flag = flags.get(key)
            if last_position and last_flag is not None:
                value = argument[last_position : begin - 1].lstrip()
                if not value:
                    raise MissingFlagArgument(last_flag)

                try:
                    values = result[last_flag.name]
                except KeyError:
                    result[last_flag.name] = [value]
                else:
                    values.append(value)

            last_position = end
            last_flag = flag

        # Add the remaining string to the last available flag
        if last_position and last_flag is not None:
            value = argument[last_position:].strip()
            if not value:
                raise MissingFlagArgument(last_flag)

            try:
                values = result[last_flag.name]
            except KeyError:
                result[last_flag.name] = [value]
            else:
                values.append(value)

        # Verification of values will come at a later stage
        return result

    @classmethod
    async def convert(cls: type[F], ctx: Context, argument: str) -> F:
        """|coro|

        The method that actually converters an argument to the flag mapping.

        Parameters
        ----------
        cls: Type[:class:`FlagConverter`]
            The flag converter class.
        ctx: :class:`Context`
            The invocation context.
        argument: :class:`str`
            The argument to convert from.

        Returns
        -------
        :class:`FlagConverter`
            The flag converter instance with all flags parsed.

        Raises
        ------
        FlagError
            A flag related parsing error.
        CommandError
            A command related error.
        """
        arguments = cls.parse_flags(argument)
        flags = cls.__commands_flags__

        self: F = cls.__new__(cls)
        for name, flag in flags.items():
            try:
                values = arguments[name]
            except KeyError:
                if flag.required:
                    raise MissingRequiredFlag(flag)
                else:
                    if callable(flag.default):
                        default = await maybe_coroutine(flag.default, ctx)
                        setattr(self, flag.attribute, default)
                    else:
                        setattr(self, flag.attribute, flag.default)
                    continue

            if 0 < flag.max_args < len(values):
                if flag.override:
                    values = values[-flag.max_args :]
                else:
                    raise TooManyFlags(flag, values)

            # Special case:
            if flag.max_args == 1:
                value = await convert_flag(ctx, values[0], flag)
                setattr(self, flag.attribute, value)
                continue

            # Another special case, tuple parsing.
            # Tuple parsing is basically converting arguments within the flag
            # So, given flag: hello 20 as the input and Tuple[str, int] as the type hint
            # We would receive ('hello', 20) as the resulting value
            # This uses the same whitespace and quoting rules as regular parameters.
            values = [await convert_flag(ctx, value, flag) for value in values]

            if flag.cast_to_dict:
                values = dict(values)  # type: ignore

            setattr(self, flag.attribute, values)

        return self

@Nzii3
Copy link
Contributor Author

Nzii3 commented Jun 9, 2023

Basically changing MissingField to MISSING fixes it.

@Lulalaby Lulalaby enabled auto-merge (squash) June 9, 2023 15:20
@Lulalaby Lulalaby disabled auto-merge June 9, 2023 15:20
@Lulalaby Lulalaby enabled auto-merge (squash) June 9, 2023 15:20
@Lulalaby Lulalaby merged commit a5c2b52 into Pycord-Development:master Jun 9, 2023
25 checks passed
@NeloBlivion NeloBlivion mentioned this pull request Jun 11, 2023
9 tasks
@Nzii3
Copy link
Contributor Author

Nzii3 commented Feb 1, 2024

@Lulalaby was this ever merged/fixed?

@Nzii3 Nzii3 deleted the patch-1 branch February 1, 2024 04:08
@NeloBlivion
Copy link
Member

It was merged... 8 months ago (and subsequently further patched in #2111), and the last release was 10 months ago 😅
If you're not on master, you'll see it on 2.5

@Nzii3
Copy link
Contributor Author

Nzii3 commented Mar 2, 2024

Was this merged?

@Lulalaby
Copy link
Member

Lulalaby commented Mar 2, 2024

try and see lol

@Pycord-Development Pycord-Development deleted a comment from pullapprove4 bot Mar 2, 2024
@Dorukyum
Copy link
Member

Dorukyum commented Mar 2, 2024

Yes, this was merged, just without a changelog entry. We can still add one.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

commands.FlagConverter 'MissingField' bug
8 participants