diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 53d812976..777a9cc3e 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -56,18 +56,19 @@ model GuildConfig { } model Case { - case_id BigInt @id @default(autoincrement()) - case_status Boolean? @default(true) - case_type CaseType - case_reason String - case_moderator_id BigInt - case_user_id BigInt - case_user_roles BigInt[] @default([]) - case_number BigInt? - case_created_at DateTime? @default(now()) - case_expires_at DateTime? - guild_id BigInt - guild Guild @relation(fields: [guild_id], references: [guild_id]) + case_id BigInt @id @default(autoincrement()) + case_status Boolean? @default(true) + case_type CaseType + case_reason String + case_moderator_id BigInt + case_user_id BigInt + case_user_roles BigInt[] @default([]) + case_number BigInt? + case_created_at DateTime? @default(now()) + case_expires_at DateTime? + case_tempban_expired Boolean? @default(false) + guild_id BigInt + guild Guild @relation(fields: [guild_id], references: [guild_id]) @@unique([case_number, guild_id]) @@index([case_number, guild_id]) @@ -167,4 +168,5 @@ enum CaseType { JAIL UNJAIL SNIPPETUNBAN + UNTEMPBAN } diff --git a/tux/cogs/moderation/tempban.py b/tux/cogs/moderation/tempban.py new file mode 100644 index 000000000..d726fc90f --- /dev/null +++ b/tux/cogs/moderation/tempban.py @@ -0,0 +1,113 @@ +from datetime import UTC, datetime + +import discord +from discord.ext import commands, tasks +from loguru import logger + +from prisma.enums import CaseType +from tux.bot import Tux +from tux.utils import checks +from tux.utils.flags import TempBanFlags, generate_usage +from tux.utils.functions import parse_time_string + +from . import ModerationCogBase + + +class TempBan(ModerationCogBase): + def __init__(self, bot: Tux) -> None: + super().__init__(bot) + self.tempban.usage = generate_usage(self.tempban, TempBanFlags) + self.tempban_check.start() + + @commands.hybrid_command(name="tempban", aliases=["tb"]) + @commands.guild_only() + @checks.has_pl(3) + async def tempban( + self, + ctx: commands.Context[Tux], + member: discord.Member, + *, + flags: TempBanFlags, + ) -> None: + """ + Temporarily ban a member from the server. + + Parameters + ---------- + ctx : commands.Context[Tux] + The context in which the command is being invoked. + member : discord.Member + The member to ban. + flags : TempBanFlags + The flags for the command. + + Raises + ------ + discord.Forbidden + If the bot is unable to ban the user. + discord.HTTPException + If an error occurs while banning the user. + """ + if not ctx.guild: + logger.warning("Ban command used outside of a guild context.") + return + + if not await self.check_conditions(ctx, member, ctx.author, "temporarily ban"): + return + + duration = parse_time_string(f"{flags.expires_at}d") + expires_at = datetime.now(UTC) + duration + + try: + await self.send_dm(ctx, flags.silent, member, flags.reason, action="temporarily banned") + await ctx.guild.ban(member, reason=flags.reason, delete_message_days=flags.purge_days) + except (discord.Forbidden, discord.HTTPException) as e: + logger.error(f"Failed to temporarily ban {member}. {e}") + await ctx.send(f"Failed to temporarily ban {member}. {e}", delete_after=30, ephemeral=True) + return + + case = await self.db.case.insert_case( + case_user_id=member.id, + case_moderator_id=ctx.author.id, + case_type=CaseType.TEMPBAN, + case_reason=flags.reason, + guild_id=ctx.guild.id, + case_expires_at=expires_at, + case_tempban_expired=False, + ) + + await self.handle_case_response(ctx, CaseType.TEMPBAN, case.case_number, flags.reason, member) + + @tasks.loop(hours=1) + async def tempban_check(self) -> None: + expired_temp_bans = await self.db.case.get_expired_tempbans() + + for temp_ban in expired_temp_bans: + guild = self.bot.get_guild(temp_ban.guild_id) or await self.bot.fetch_guild(temp_ban.guild_id) + if not guild: + logger.error(f"Failed to get guild with ID {temp_ban.guild_id} for tempban check.") + continue + + try: + ban_entry = await guild.fetch_ban(discord.Object(id=temp_ban.case_user_id)) + await guild.unban(ban_entry.user, reason=f"Tempban expired | Case number: {temp_ban.case_number}") + + await self.db.case.set_tempban_expired(temp_ban.case_number, temp_ban.guild_id) + await self.db.case.insert_case( + guild_id=temp_ban.guild_id, + case_user_id=temp_ban.case_user_id, + case_moderator_id=temp_ban.case_moderator_id, + case_type=CaseType.UNTEMPBAN, + case_reason="Expired tempban", + case_tempban_expired=True, + ) + logger.debug(f"Unbanned user with ID {temp_ban.case_user_id} | Case number {temp_ban.case_number}") + + except (discord.Forbidden, discord.HTTPException, discord.NotFound) as e: + logger.error( + f"Failed to unban user with ID {temp_ban.case_user_id} | Case number {temp_ban.case_number}. Error: {e}", + ) + + +async def setup(bot: Tux) -> None: + await bot.add_cog(TempBan(bot)) diff --git a/tux/database/controllers/case.py b/tux/database/controllers/case.py index 2e0f8e120..eee226df7 100644 --- a/tux/database/controllers/case.py +++ b/tux/database/controllers/case.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import UTC, datetime from prisma.enums import CaseType from prisma.models import Case, Guild @@ -82,6 +82,7 @@ async def insert_case( case_reason: str, case_user_roles: list[int] | None = None, case_expires_at: datetime | None = None, + case_tempban_expired: bool = False, ) -> Case: """ Insert a case into the database. @@ -102,6 +103,8 @@ async def insert_case( The roles of the target of the case. case_expires_at : datetime | None The expiration date of the case. + case_tempban_expired : bool + Whether the tempban has expired (Use only for tempbans). Returns ------- @@ -122,6 +125,7 @@ async def insert_case( "case_reason": case_reason, "case_expires_at": case_expires_at, "case_user_roles": case_user_roles if case_user_roles is not None else [], + "case_tempban_expired": case_tempban_expired, }, ) @@ -308,3 +312,53 @@ async def delete_case_by_number(self, guild_id: int, case_number: int) -> Case | if case is not None: return await self.table.delete(where={"case_id": case.case_id}) return None + + async def get_expired_tempbans(self) -> list[Case]: + """ + Get all cases that have expired tempbans. + + Returns + ------- + list[Case] + A list of cases of the type in the guild. + """ + return await self.table.find_many( + where={ + "case_type": CaseType.TEMPBAN, + "case_expires_at": {"lt": datetime.now(UTC)}, + "case_tempban_expired": False, + }, + ) + + async def set_tempban_expired(self, case_number: int | None, guild_id: int) -> int | None: + """ + Set a tempban case as expired. + + Parameters + ---------- + case_number : int + The number of the case to update. + guild_id : int + The ID of the guild the case belongs to. + + Returns + ------- + Optional[int] + The number of Case records updated (1) if successful, None if no records were found, + or raises an exception if multiple records were affected. + """ + if case_number is None: + msg = "Case number not found" + raise ValueError(msg) + + result = await self.table.update_many( + where={"case_number": case_number, "guild_id": guild_id}, + data={"case_tempban_expired": True}, + ) + if result == 1: + return result + if result == 0: + return None + + msg = f"Multiple records ({result}) were affected when updating case {case_number} in guild {guild_id}" + raise ValueError(msg) diff --git a/tux/database/controllers/guild.py b/tux/database/controllers/guild.py index 411a4b5c6..9fdbf9541 100644 --- a/tux/database/controllers/guild.py +++ b/tux/database/controllers/guild.py @@ -14,3 +14,6 @@ async def insert_guild_by_id(self, guild_id: int) -> Guild: async def delete_guild_by_id(self, guild_id: int) -> None: await self.table.delete(where={"guild_id": guild_id}) + + async def get_all_guilds(self) -> list[Guild]: + return await self.table.find_many() diff --git a/tux/utils/flags.py b/tux/utils/flags.py index 6869ea098..7fbfb3480 100644 --- a/tux/utils/flags.py +++ b/tux/utils/flags.py @@ -91,9 +91,9 @@ class TempBanFlags(commands.FlagConverter, case_insensitive=True, delimiter=" ", default=MISSING, ) expires_at: int = commands.flag( - name="expires_at", + name="duration", description="Number of days the ban will last for.", - aliases=["t", "d", "e", "duration", "expires", "time"], + aliases=["t", "d", "e", "expires", "time"], ) purge_days: int = commands.flag( name="purge_days",