diff --git a/tux/cogs/services/bookmarks.py b/tux/cogs/services/bookmarks.py index fdedf2c7c..3e453b761 100644 --- a/tux/cogs/services/bookmarks.py +++ b/tux/cogs/services/bookmarks.py @@ -1,6 +1,10 @@ -from typing import cast +from __future__ import annotations +import io + +import aiohttp import discord +from discord.abc import Messageable from discord.ext import commands from loguru import logger @@ -12,111 +16,262 @@ class Bookmarks(commands.Cog): def __init__(self, bot: Tux) -> None: self.bot = bot + self.add_bookmark_emojis = CONST.ADD_BOOKMARK + self.remove_bookmark_emojis = CONST.REMOVE_BOOKMARK + self.valid_emojis = self.add_bookmark_emojis + self.remove_bookmark_emojis + self.session = aiohttp.ClientSession() + + async def cog_unload(self) -> None: + """Cleans up the cog, closing the aiohttp session.""" + await self.session.close() @commands.Cog.listener() async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> None: """ - Handle the addition of a reaction to a message. + Handles bookmarking messages via reactions. + + This listener checks for specific reaction emojis on messages and triggers + the bookmarking or unbookmarking process accordingly. Parameters ---------- payload : discord.RawReactionActionEvent - The payload of the reaction event. - - Returns - ------- - None + The event payload containing information about the reaction. """ - if str(payload.emoji) != "🔖": + # If the bot reacted to the message, or the user is the bot, or the emoji is not valid, return + if not self.bot.user or payload.user_id == self.bot.user.id or not payload.emoji.name: return - # Fetch the channel where the reaction was added - channel = self.bot.get_channel(payload.channel_id) - if channel is None: - logger.error(f"Channel not found for ID: {payload.channel_id}") + # If the emoji is not valid, return + if payload.emoji.name not in self.valid_emojis: return - channel = cast(discord.TextChannel | discord.Thread, channel) - # Fetch the message that was reacted to try: + # Get the user who reacted to the message + user = self.bot.get_user(payload.user_id) or await self.bot.fetch_user(payload.user_id) + + # Get the channel where the reaction was added + channel = self.bot.get_channel(payload.channel_id) + if channel is None: + channel = await self.bot.fetch_channel(payload.channel_id) + + # If the channel is not messageable, return + if not isinstance(channel, Messageable): + logger.warning(f"Bookmark reaction in non-messageable channel {payload.channel_id}.") + return + + # Get the message that was reacted to message = await channel.fetch_message(payload.message_id) - except discord.NotFound: - logger.error(f"Message not found for ID: {payload.message_id}") - return - except (discord.Forbidden, discord.HTTPException) as fetch_error: - logger.error(f"Failed to fetch message: {fetch_error}") + + # If the message is not found, return + except (discord.NotFound, discord.Forbidden, discord.HTTPException) as e: + logger.error(f"Failed to fetch data for bookmark event: {e}") return - # Create an embed for the bookmarked message + # If the emoji is the add bookmark emoji, add the bookmark + if payload.emoji.name in self.add_bookmark_emojis: + await self.add_bookmark(user, message) + + # If the emoji is the remove bookmark emoji, remove the bookmark + elif payload.emoji.name in self.remove_bookmark_emojis: + await self.remove_bookmark(message) + + async def add_bookmark(self, user: discord.User, message: discord.Message) -> None: + """ + Sends a bookmarked message to the user's DMs. + + Parameters + ---------- + user : discord.User + The user who bookmarked the message. + message : discord.Message + The message to be bookmarked. + """ embed = self._create_bookmark_embed(message) + files = await self._get_files_from_message(message) - # Get the user who reacted to the message - user = self.bot.get_user(payload.user_id) - if user is None: - logger.error(f"User not found for ID: {payload.user_id}") + try: + dm_message = await user.send(embed=embed, files=files) + await dm_message.add_reaction(self.remove_bookmark_emojis) + + except (discord.Forbidden, discord.HTTPException) as e: + logger.warning(f"Could not send DM to {user.name} ({user.id}): {e}") + + try: + await message.channel.send( + f"{user.mention}, I couldn't send you a DM. Please check your privacy settings.", + delete_after=30, + ) + + except (discord.Forbidden, discord.HTTPException) as e2: + logger.error(f"Could not send notification in channel {message.channel.id}: {e2}") + + @staticmethod + async def remove_bookmark(message: discord.Message) -> None: + """ + Deletes a bookmark DM when the user reacts with the remove emoji. + + Parameters + ---------- + message : discord.Message + The bookmark message in the user's DMs to be deleted. + """ + + try: + await message.delete() + + except (discord.Forbidden, discord.HTTPException) as e: + logger.error(f"Failed to delete bookmark message {message.id}: {e}") + + async def _get_files_from_attachments(self, message: discord.Message, files: list[discord.File]) -> None: + for attachment in message.attachments: + if len(files) >= 10: + break + + if attachment.content_type and "image" in attachment.content_type: + try: + files.append(await attachment.to_file()) + except (discord.HTTPException, discord.NotFound) as e: + logger.error(f"Failed to get attachment {attachment.filename}: {e}") + + async def _get_files_from_stickers(self, message: discord.Message, files: list[discord.File]) -> None: + if len(files) >= 10: return - # Send the bookmarked message to the user - await self._send_bookmark(user, message, embed, payload.emoji) + for sticker in message.stickers: + if len(files) >= 10: + break - def _create_bookmark_embed( - self, - message: discord.Message, - ) -> discord.Embed: - if len(message.content) > CONST.EMBED_MAX_DESC_LENGTH: - message.content = f"{message.content[: CONST.EMBED_MAX_DESC_LENGTH - 3]}..." + if sticker.format in {discord.StickerFormatType.png, discord.StickerFormatType.apng}: + try: + sticker_bytes = await sticker.read() + files.append(discord.File(io.BytesIO(sticker_bytes), filename=f"{sticker.name}.png")) + except (discord.HTTPException, discord.NotFound) as e: + logger.error(f"Failed to read sticker {sticker.name}: {e}") - embed = EmbedCreator.create_embed( - bot=self.bot, - embed_type=EmbedCreator.INFO, - title="Message Bookmarked", - description=f"> {message.content}", - ) + async def _get_files_from_embeds(self, message: discord.Message, files: list[discord.File]) -> None: + if len(files) >= 10: + return - embed.add_field(name="Author", value=message.author.name, inline=False) + for embed in message.embeds: + if len(files) >= 10: + break - embed.add_field(name="Jump to Message", value=f"[Click Here]({message.jump_url})", inline=False) + if embed.image and embed.image.url: + try: + async with self.session.get(embed.image.url) as resp: + if resp.status == 200: + data = await resp.read() + filename = embed.image.url.split("/")[-1].split("?")[0] + files.append(discord.File(io.BytesIO(data), filename=filename)) + except aiohttp.ClientError as e: + logger.error(f"Failed to fetch embed image {embed.image.url}: {e}") - if message.attachments: - attachments_info = "\n".join([attachment.url for attachment in message.attachments]) - embed.add_field(name="Attachments", value=attachments_info, inline=False) + async def _get_files_from_message(self, message: discord.Message) -> list[discord.File]: + """ + Gathers images from a message to be sent as attachments. - return embed + This function collects images from attachments, stickers, and embeds, + respecting Discord's 10-file limit. - @staticmethod - async def _send_bookmark( - user: discord.User, - message: discord.Message, - embed: discord.Embed, - emoji: discord.PartialEmoji, - ) -> None: + Parameters + ---------- + message : discord.Message + The message to extract files from. + + Returns + ------- + list[discord.File] + A list of files to be attached to the bookmark message. + """ + files: list[discord.File] = [] + + await self._get_files_from_attachments(message, files) + await self._get_files_from_stickers(message, files) + await self._get_files_from_embeds(message, files) + + return files + + def _create_bookmark_embed(self, message: discord.Message) -> discord.Embed: """ - Send a bookmarked message to the user. + Creates an embed for a bookmarked message. + + This function constructs a detailed embed that includes the message content, + author, attachments, and other contextual information. Parameters ---------- - user : discord.User - The user to send the bookmarked message to. message : discord.Message - The message that was bookmarked. - embed : discord.Embed - The embed to send to the user. - emoji : str - The emoji that was reacted to the message. + The message to create an embed from. + + Returns + ------- + discord.Embed + The generated bookmark embed. """ - try: - await user.send(embed=embed) + # Get the content of the message + content = message.content or "" + + # Truncate the content if it's too long + if len(content) > CONST.EMBED_MAX_DESC_LENGTH: + content = f"{content[: CONST.EMBED_MAX_DESC_LENGTH - 4]}..." + + embed = EmbedCreator.create_embed( + bot=self.bot, + embed_type=EmbedCreator.INFO, + title="Message Bookmarked", + description=f"{content}" if content else "> No content available to display", + ) - except (discord.Forbidden, discord.HTTPException) as dm_error: - logger.error(f"Cannot send a DM to {user.name}: {dm_error}") + # Add author to the embed + embed.set_author( + name=message.author.display_name, + icon_url=message.author.display_avatar.url, + ) + + # Add reference to the embed if it exists + if message.reference and message.reference.resolved: + ref_msg = message.reference.resolved + if isinstance(ref_msg, discord.Message): + embed.add_field( + name="Replying to", + value=f"[Click Here]({ref_msg.jump_url})", + ) + + # Add jump to message to the embed + embed.add_field( + name="Jump to Message", + value=f"[Click Here]({message.jump_url})", + ) + + # Add attachments to the embed + if message.attachments: + attachments = "\n".join(f"[{a.filename}]({a.url})" for a in message.attachments) + embed.add_field(name="Attachments", value=attachments, inline=False) + + # Add stickers to the embed + if message.stickers: + stickers = "\n".join(f"[{s.name}]({s.url})" for s in message.stickers) + embed.add_field(name="Stickers", value=stickers, inline=False) - notify_message = await message.channel.send( - f"{user.mention}, I couldn't send you a DM. Please make sure your DMs are open for bookmarks to work.", + # Handle embeds + if message.embeds: + embed.add_field( + name="Contains Embeds", + value="Original message contains embeds which are not shown here.", + inline=False, ) - await notify_message.delete(delay=30) + # Add footer to the embed + if message.guild and isinstance(message.channel, discord.TextChannel | discord.Thread): + embed.set_footer(text=f"In #{message.channel.name} on {message.guild.name}") + + # Add timestamp to the embed + embed.timestamp = message.created_at + + return embed async def setup(bot: Tux) -> None: diff --git a/tux/cogs/utility/poll.py b/tux/cogs/utility/poll.py index fb75a2982..0affb1565 100644 --- a/tux/cogs/utility/poll.py +++ b/tux/cogs/utility/poll.py @@ -1,3 +1,5 @@ +from typing import cast + import discord from discord import app_commands from discord.ext import commands @@ -74,13 +76,15 @@ async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> # get reaction from payload.message_id, payload.channel_id, payload.guild_id, payload.emoji channel = self.bot.get_channel(payload.channel_id) if channel is None: - logger.error(f"Channel with ID {payload.channel_id} not found.") - return - if isinstance(channel, discord.ForumChannel | discord.CategoryChannel | discord.abc.PrivateChannel): - logger.error( - f"Channel with ID {payload.channel_id} is not a compatible channel type. How the fuck did you get here?", - ) - return + try: + channel = await self.bot.fetch_channel(payload.channel_id) + except discord.NotFound: + logger.error(f"Channel not found for ID: {payload.channel_id}") + return + except (discord.Forbidden, discord.HTTPException) as fetch_error: + logger.error(f"Failed to fetch channel: {fetch_error}") + return + channel = cast(discord.TextChannel | discord.Thread, channel) message = await channel.fetch_message(payload.message_id) # Lookup the reaction object for this event diff --git a/tux/utils/constants.py b/tux/utils/constants.py index 166673c2b..f1f2a94b8 100644 --- a/tux/utils/constants.py +++ b/tux/utils/constants.py @@ -73,5 +73,9 @@ class Constants: EIGHT_BALL_QUESTION_LENGTH_LIMIT = 120 EIGHT_BALL_RESPONSE_WRAP_WIDTH = 30 + # Bookmark constants + ADD_BOOKMARK = "🔖" + REMOVE_BOOKMARK = "🗑️" + CONST = Constants()