Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,19 @@ model GuildConfig {
}

model Case {
case_id BigInt @id @default(autoincrement())
case_status Boolean? @default(true)
case_type CaseType
case_reason String
case_moderator_id BigInt
case_user_id BigInt
case_user_roles BigInt[] @default([])
case_number BigInt?
case_created_at DateTime? @default(now())
case_expires_at DateTime?
guild_id BigInt
guild Guild @relation(fields: [guild_id], references: [guild_id])
case_id BigInt @id @default(autoincrement())
case_status Boolean? @default(true)
case_type CaseType
case_reason String
case_moderator_id BigInt
case_user_id BigInt
case_user_roles BigInt[] @default([])
case_number BigInt?
case_created_at DateTime? @default(now())
case_expires_at DateTime?
case_tempban_expired Boolean? @default(false)
guild_id BigInt
guild Guild @relation(fields: [guild_id], references: [guild_id])

@@unique([case_number, guild_id])
@@index([case_number, guild_id])
Expand Down Expand Up @@ -167,4 +168,5 @@ enum CaseType {
JAIL
UNJAIL
SNIPPETUNBAN
UNTEMPBAN
}
113 changes: 113 additions & 0 deletions tux/cogs/moderation/tempban.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from datetime import UTC, datetime

import discord
from discord.ext import commands, tasks
from loguru import logger

from prisma.enums import CaseType
from tux.bot import Tux
from tux.utils import checks
from tux.utils.flags import TempBanFlags, generate_usage
from tux.utils.functions import parse_time_string

from . import ModerationCogBase


class TempBan(ModerationCogBase):
def __init__(self, bot: Tux) -> None:
super().__init__(bot)
self.tempban.usage = generate_usage(self.tempban, TempBanFlags)
self.tempban_check.start()

@commands.hybrid_command(name="tempban", aliases=["tb"])
@commands.guild_only()
@checks.has_pl(3)
async def tempban(
self,
ctx: commands.Context[Tux],
member: discord.Member,
*,
flags: TempBanFlags,
) -> None:
"""
Temporarily ban a member from the server.

Parameters
----------
ctx : commands.Context[Tux]
The context in which the command is being invoked.
member : discord.Member
The member to ban.
flags : TempBanFlags
The flags for the command.

Raises
------
discord.Forbidden
If the bot is unable to ban the user.
discord.HTTPException
If an error occurs while banning the user.
"""
if not ctx.guild:
logger.warning("Ban command used outside of a guild context.")
return

if not await self.check_conditions(ctx, member, ctx.author, "temporarily ban"):
return

duration = parse_time_string(f"{flags.expires_at}d")
expires_at = datetime.now(UTC) + duration

try:
await self.send_dm(ctx, flags.silent, member, flags.reason, action="temporarily banned")
await ctx.guild.ban(member, reason=flags.reason, delete_message_days=flags.purge_days)
except (discord.Forbidden, discord.HTTPException) as e:
logger.error(f"Failed to temporarily ban {member}. {e}")
await ctx.send(f"Failed to temporarily ban {member}. {e}", delete_after=30, ephemeral=True)
return

case = await self.db.case.insert_case(
case_user_id=member.id,
case_moderator_id=ctx.author.id,
case_type=CaseType.TEMPBAN,
case_reason=flags.reason,
guild_id=ctx.guild.id,
case_expires_at=expires_at,
case_tempban_expired=False,
)

await self.handle_case_response(ctx, CaseType.TEMPBAN, case.case_number, flags.reason, member)

@tasks.loop(hours=1)
async def tempban_check(self) -> None:
expired_temp_bans = await self.db.case.get_expired_tempbans()

for temp_ban in expired_temp_bans:
guild = self.bot.get_guild(temp_ban.guild_id) or await self.bot.fetch_guild(temp_ban.guild_id)
if not guild:
logger.error(f"Failed to get guild with ID {temp_ban.guild_id} for tempban check.")
continue

try:
ban_entry = await guild.fetch_ban(discord.Object(id=temp_ban.case_user_id))
await guild.unban(ban_entry.user, reason=f"Tempban expired | Case number: {temp_ban.case_number}")

await self.db.case.set_tempban_expired(temp_ban.case_number, temp_ban.guild_id)
await self.db.case.insert_case(
guild_id=temp_ban.guild_id,
case_user_id=temp_ban.case_user_id,
case_moderator_id=temp_ban.case_moderator_id,
case_type=CaseType.UNTEMPBAN,
case_reason="Expired tempban",
case_tempban_expired=True,
)
logger.debug(f"Unbanned user with ID {temp_ban.case_user_id} | Case number {temp_ban.case_number}")

except (discord.Forbidden, discord.HTTPException, discord.NotFound) as e:
logger.error(
f"Failed to unban user with ID {temp_ban.case_user_id} | Case number {temp_ban.case_number}. Error: {e}",
)


async def setup(bot: Tux) -> None:
await bot.add_cog(TempBan(bot))
56 changes: 55 additions & 1 deletion tux/database/controllers/case.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import UTC, datetime

from prisma.enums import CaseType
from prisma.models import Case, Guild
Expand Down Expand Up @@ -82,6 +82,7 @@ async def insert_case(
case_reason: str,
case_user_roles: list[int] | None = None,
case_expires_at: datetime | None = None,
case_tempban_expired: bool = False,
) -> Case:
"""
Insert a case into the database.
Expand All @@ -102,6 +103,8 @@ async def insert_case(
The roles of the target of the case.
case_expires_at : datetime | None
The expiration date of the case.
case_tempban_expired : bool
Whether the tempban has expired (Use only for tempbans).

Returns
-------
Expand All @@ -122,6 +125,7 @@ async def insert_case(
"case_reason": case_reason,
"case_expires_at": case_expires_at,
"case_user_roles": case_user_roles if case_user_roles is not None else [],
"case_tempban_expired": case_tempban_expired,
},
)

Expand Down Expand Up @@ -308,3 +312,53 @@ async def delete_case_by_number(self, guild_id: int, case_number: int) -> Case |
if case is not None:
return await self.table.delete(where={"case_id": case.case_id})
return None

async def get_expired_tempbans(self) -> list[Case]:
"""
Get all cases that have expired tempbans.

Returns
-------
list[Case]
A list of cases of the type in the guild.
"""
return await self.table.find_many(
where={
"case_type": CaseType.TEMPBAN,
"case_expires_at": {"lt": datetime.now(UTC)},
"case_tempban_expired": False,
},
)

async def set_tempban_expired(self, case_number: int | None, guild_id: int) -> int | None:
"""
Set a tempban case as expired.

Parameters
----------
case_number : int
The number of the case to update.
guild_id : int
The ID of the guild the case belongs to.

Returns
-------
Optional[int]
The number of Case records updated (1) if successful, None if no records were found,
or raises an exception if multiple records were affected.
"""
if case_number is None:
msg = "Case number not found"
raise ValueError(msg)

result = await self.table.update_many(
where={"case_number": case_number, "guild_id": guild_id},
data={"case_tempban_expired": True},
)
if result == 1:
return result
if result == 0:
return None

msg = f"Multiple records ({result}) were affected when updating case {case_number} in guild {guild_id}"
raise ValueError(msg)
3 changes: 3 additions & 0 deletions tux/database/controllers/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ async def insert_guild_by_id(self, guild_id: int) -> Guild:

async def delete_guild_by_id(self, guild_id: int) -> None:
await self.table.delete(where={"guild_id": guild_id})

async def get_all_guilds(self) -> list[Guild]:
return await self.table.find_many()
4 changes: 2 additions & 2 deletions tux/utils/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ class TempBanFlags(commands.FlagConverter, case_insensitive=True, delimiter=" ",
default=MISSING,
)
expires_at: int = commands.flag(
name="expires_at",
name="duration",
description="Number of days the ban will last for.",
aliases=["t", "d", "e", "duration", "expires", "time"],
aliases=["t", "d", "e", "expires", "time"],
)
purge_days: int = commands.flag(
name="purge_days",
Expand Down