diff --git a/disputils/abc.py b/disputils/abc.py index 6d3e530..f5660fa 100644 --- a/disputils/abc.py +++ b/disputils/abc.py @@ -1,5 +1,5 @@ from abc import ABC -from discord import Message, Embed +from discord import Message, Embed, TextChannel, errors from typing import Optional @@ -11,6 +11,23 @@ def __init__(self, *args, **kwargs): self.message: Optional[Message] = None self.color: hex = kwargs.get("color") or kwargs.get("colour") or 0x000000 + async def _publish(self, channel: Optional[TextChannel], **kwargs) -> TextChannel: + if channel is None and self.message is None: + raise TypeError( + "Missing argument. You need to specify a target channel or message." + ) + + if channel is None: + try: + await self.message.edit(**kwargs) + except errors.NotFound: + self.message = None + + if self.message is None: + self.message = await channel.send(**kwargs) + + return self.message.channel + async def quit(self, text: str = None): """ Quit the dialog. @@ -23,9 +40,13 @@ async def quit(self, text: str = None): if text is None: await self.message.delete() + self.message = None else: - await self.message.edit(content=text, embed=None) - await self.message.clear_reactions() + await self.display(text) + try: + await self.message.clear_reactions() + except errors.Forbidden: + pass async def update(self, text: str, color: hex = None, hide_author: bool = False): """ diff --git a/disputils/confirmation.py b/disputils/confirmation.py index cddfbea..84a9303 100644 --- a/disputils/confirmation.py +++ b/disputils/confirmation.py @@ -35,7 +35,7 @@ async def confirm( user: discord.User, channel: discord.TextChannel = None, hide_author: bool = False, - timeout: int = 20 + timeout: int = 20, ) -> bool or None: """ Run the confirmation. @@ -68,13 +68,8 @@ async def confirm( self._embed = emb - if channel is None and self.message is not None: - channel = self.message.channel - elif channel is None: - raise TypeError("Missing argument. You need to specify a target channel.") - - msg = await channel.send(embed=emb) - self.message = msg + await self._publish(channel, embed=emb) + msg = self.message for emoji in self.emojis: await msg.add_reaction(emoji) @@ -90,17 +85,15 @@ async def confirm( except asyncio.TimeoutError: self._confirmed = None return + else: + self._confirmed = self.emojis[reaction.emoji] + return self._confirmed finally: try: await msg.clear_reactions() except discord.Forbidden: pass - confirmed = self.emojis[reaction.emoji] - - self._confirmed = confirmed - return confirmed - class BotConfirmation(Confirmation): def __init__( @@ -119,7 +112,7 @@ async def confirm( user: discord.User = None, channel: discord.TextChannel = None, hide_author: bool = False, - timeout: int = 20 + timeout: int = 20, ) -> bool or None: if user is None: diff --git a/disputils/multiple_choice.py b/disputils/multiple_choice.py index b6f887f..c9ad348 100644 --- a/disputils/multiple_choice.py +++ b/disputils/multiple_choice.py @@ -122,6 +122,9 @@ async def run( - message :class:`discord.Message` - timeout :class:`int` (seconds, default: ``60``), - closable :class:`bool` (default: ``True``) + - text :class:`str`: Text to appear in the message. + - timeout_msg :class:`str`: Text to appear when dialog times out. + - quit_msg :class:`str`: Text to appear when user quits the dialog. :return: selected option and used :class:`discord.Message` :rtype: tuple[:class:`str`, :class:`discord.Message`] @@ -134,18 +137,11 @@ async def run( timeout = kwargs.get("timeout", 60) closable: bool = kwargs.get("closable", True) - config_embed = self.embed + publish_kwargs = {"embed": self.embed} + if "text" in kwargs: + publish_kwargs["content"] = kwargs["text"] - if channel is not None: - self.message = await channel.send(embed=config_embed) - elif self.message is not None: - await self.message.clear_reactions() - await self.message.edit(content=self.message.content, embed=config_embed) - else: - raise TypeError( - "Missing argument. " - + "You need to specify either 'channel' or 'message' as a target." - ) + await self._publish(channel, **publish_kwargs) for emoji in self._emojis: await self.message.add_reaction(emoji) @@ -173,10 +169,14 @@ def check(r, u): ) except asyncio.TimeoutError: self._choice = None + if "timeout_msg" in kwargs: + await self.quit(kwargs["timeout_msg"]) return None, self.message if reaction.emoji == self.close_emoji: self._choice = None + if "quit_msg" in kwargs: + await self.quit(kwargs["quit_msg"]) return None, self.message index = self._emojis.index(reaction.emoji) diff --git a/disputils/pagination.py b/disputils/pagination.py index 2ed2eba..6c29218 100644 --- a/disputils/pagination.py +++ b/disputils/pagination.py @@ -61,6 +61,7 @@ async def run( users: List[discord.User], channel: discord.TextChannel = None, timeout: int = 100, + **kwargs, ): """ Runs the paginator. @@ -79,21 +80,24 @@ async def run( :param timeout: Seconds to wait until stopping to listen for user interaction. + :param kwargs: + - text :class:`str`: Text to appear in the pagination message. + - timeout_msg :class:`str`: Text to appear when pagination times out. + - quit_msg :class:`str`: Text to appear when user quits the dialog. + :return: None """ - if channel is None and self.message is not None: - channel = self.message.channel - elif channel is None: - raise TypeError("Missing argument. You need to specify a target channel.") - self._embed = self.pages[0] + text = kwargs.get("text") if len(self.pages) == 1: # no pagination needed in this case - self.message = await channel.send(embed=self._embed) + await self._publish(channel, content=text, embed=self._embed) return - self.message = await channel.send(embed=self.formatted_pages[0]) + channel = await self._publish( + channel, content=text, embed=self.formatted_pages[0] + ) current_page_index = 0 for emoji in self.control_emojis: @@ -120,6 +124,8 @@ def check(r: discord.Reaction, u: discord.User): await self.message.clear_reactions() except discord.Forbidden: pass + if "timeout_msg" in kwargs: + await self.display(kwargs["timeout_msg"]) return emoji = reaction.emoji @@ -146,10 +152,10 @@ def check(r: discord.Reaction, u: discord.User): load_page_index = max_index else: - await self.message.delete() + await self.quit(kwargs.get("quit_msg")) return - await self.message.edit(embed=self.formatted_pages[load_page_index]) + await self.display(text, self.formatted_pages[load_page_index]) if not isinstance(channel, discord.channel.DMChannel) and not isinstance( channel, discord.channel.GroupChannel ): @@ -228,6 +234,7 @@ async def run( channel: discord.TextChannel = None, users: List[discord.User] = None, timeout: int = 100, + **kwargs, ): """ Runs the paginator. @@ -247,6 +254,11 @@ async def run( :param timeout: Seconds to wait until stopping to listen for user interaction. + :param kwargs: + - text :class:`str`: Text to appear in the pagination message. + - timeout_msg :class:`str`: Text to appear when pagination times out. + - quit_msg :class:`str`: Text to appear when user quits the dialog. + :return: None """ @@ -256,4 +268,4 @@ async def run( if self.message is None and channel is None: channel = self._ctx.channel - await super().run(users, channel, timeout) + await super().run(users, channel, timeout, **kwargs)