From 2bd0ae7c18bda4c7f0b949fbd05093280337d37b Mon Sep 17 00:00:00 2001 From: Ethan Coward Date: Fri, 20 Oct 2023 21:24:53 +0100 Subject: [PATCH] Add handling for new link criteria --- bot.py | 5 ++- cogs/voicestate.py | 81 +++++++++++++++++++++++++++++++++++++------- prisma/schema.prisma | 1 + requirements.txt | 3 +- utils/checks.py | 2 +- utils/database.py | 3 +- 6 files changed, 75 insertions(+), 20 deletions(-) diff --git a/bot.py b/bot.py index 1e2819f..869f3f7 100644 --- a/bot.py +++ b/bot.py @@ -119,7 +119,6 @@ async def on_command_error( async def main(): - # Removing TTS Files for filename in os.listdir("tts"): @@ -139,10 +138,11 @@ async def main(): f.write("") async with client: - # Setting up topgg integration if using_topgg: + import topgg + client.topgg_webhook = ( topgg.WebhookManager().set_data(client).endpoint(dbl_endpoint) ) @@ -162,5 +162,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) diff --git a/cogs/voicestate.py b/cogs/voicestate.py index 65e7d82..ad12b13 100644 --- a/cogs/voicestate.py +++ b/cogs/voicestate.py @@ -1,6 +1,9 @@ +from typing import Any import discord from discord.ext import commands from prisma.enums import LinkType +from prisma.models import Link +from pydantic import Json from utils import VCRolesClient from utils.types import ( @@ -31,7 +34,7 @@ async def on_voice_state_update( # Joining if not before.channel and after.channel: - roles_changed, failed_roles = await self.join(member, after) + roles_changed, failed_roles = await self.handle_join(member, after) if failed_roles: self.client.log( @@ -48,7 +51,7 @@ async def on_voice_state_update( # Leaving elif before.channel and not after.channel: - roles_changed, failed_roles = await self.leave(member, before) + roles_changed, failed_roles = await self.handle_leave(member, before) if failed_roles: self.client.log( @@ -65,10 +68,11 @@ async def on_voice_state_update( # Changing elif before.channel and after.channel and before.channel != after.channel: - - leave_roles_changed, join_roles_changed, failed_roles = await self.change( - member, before, after - ) + ( + leave_roles_changed, + join_roles_changed, + failed_roles, + ) = await self.handle_change(member, before, after) if failed_roles: self.client.log( @@ -126,7 +130,42 @@ async def on_voice_state_update( except discord.errors.HTTPException: pass - async def join( + def check_link_criteria(self, link: Link, member: discord.Member) -> bool: + # Check if the criteria is empty, return True if it is + if not link.linkCriteria or bool(link.linkCriteria) == False: + return True + + # Define a recursive function to evaluate criteria + def evaluate_criteria(criteria: Json[Any]) -> bool: + # Check for "and," "or," and "not" keys in the criteria + if "and" in criteria: + return all(evaluate_criteria(c) for c in criteria["and"]) + if "or" in criteria: + return any(evaluate_criteria(c) for c in criteria["or"]) + if "not" in criteria: + return not evaluate_criteria(criteria["not"]) + + # If none of the special keys are present, you can now check the criteria + if "hasRole" in criteria: + role_id = criteria["hasRole"] + return any(role.id == role_id for role in member.roles) + if "hasPermission" in criteria: + permission = criteria["hasPermission"] + return any( + p[0] == permission.lower() and p[1] + for p in member.guild_permissions + ) + if "isUser" in criteria: + user_id = criteria["isUser"] + return str(member.id) == user_id + + # If no recognizable criteria are found, return False + return False + + # Start the evaluation with the top-level criteria + return evaluate_criteria(link.linkCriteria) + + async def handle_join( self, member: discord.Member, after: discord.VoiceState, @@ -135,7 +174,7 @@ async def join( # Unreachable. return [], [] - links = await self.client.db.get_all_linked_channel( + links = await self.client.db.get_all_links_for_channel( member.guild.id, after.channel.id, after.channel.category.id if after.channel.category else None, @@ -150,6 +189,10 @@ async def join( for link in links: if str(after.channel.id) in link.excludeChannels: continue + + if not self.check_link_criteria(link, member): + continue + addable_roles.extend(link.linkedRoles) removeable_roles.extend(link.reverseLinkedRoles) return_data.append( @@ -302,7 +345,7 @@ async def join( return return_data, list(set(failed_roles)) - async def leave( + async def handle_leave( self, member: discord.Member, before: discord.VoiceState, @@ -311,7 +354,7 @@ async def leave( # Unreachable. return [], [] - links = await self.client.db.get_all_linked_channel( + links = await self.client.db.get_all_links_for_channel( member.guild.id, before.channel.id, before.channel.category.id if before.channel.category else None, @@ -328,6 +371,10 @@ async def leave( continue if link.type == LinkType.PERMANENT: continue + + if not self.check_link_criteria(link, member): + continue + addable_roles.extend(link.reverseLinkedRoles) removeable_roles.extend(link.linkedRoles) return_data.append( @@ -473,7 +520,7 @@ async def leave( return return_data, list(set(failed_roles)) - async def change( + async def handle_change( self, member: discord.Member, before: discord.VoiceState, @@ -485,13 +532,13 @@ async def change( # Unreachable. return [], [], [] - before_links = await self.client.db.get_all_linked_channel( + before_links = await self.client.db.get_all_links_for_channel( member.guild.id, before.channel.id, before.channel.category.id if before.channel.category else None, ) - after_links = await self.client.db.get_all_linked_channel( + after_links = await self.client.db.get_all_links_for_channel( member.guild.id, after.channel.id, after.channel.category.id if after.channel.category else None, @@ -511,6 +558,10 @@ async def change( continue if link.type == LinkType.PERMANENT: continue + + if not self.check_link_criteria(link, member): + continue + addable_roles.extend(link.reverseLinkedRoles) removeable_roles.extend(link.linkedRoles) leave_return_data.append( @@ -526,6 +577,10 @@ async def change( for link in after_links: if str(after.channel.id) in link.excludeChannels: continue + + if not self.check_link_criteria(link, member): + continue + addable_roles.extend(link.linkedRoles) removeable_roles.extend(link.reverseLinkedRoles) join_return_data.append( diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 11d8aa5..89cce3e 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -52,6 +52,7 @@ model Link { suffix String? speakerRoles String[] excludeChannels String[] + linkCriteria Json @default("{}") @@id([id, type]) } diff --git a/requirements.txt b/requirements.txt index 0242fd3..dd45af0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,9 +4,10 @@ mutagen git+https://github.com/top-gg/python-sdk/ # topggpy pynacl redis +types-redis pre-commit requests -prisma==0.7.1 +prisma cachetools types-cachetools asyncache diff --git a/utils/checks.py b/utils/checks.py index 3d75d09..e342059 100644 --- a/utils/checks.py +++ b/utils/checks.py @@ -39,7 +39,7 @@ async def command_available(interaction: Interaction) -> bool: premium = await client.ar.hget("premium", str(interaction.user.id)) if premium and str(premium) == "1": return True - except redis.exceptions.RedisError: + except redis.RedisError: pass cmds_count = await client.ar.hget("commands", str(interaction.user.id)) diff --git a/utils/database.py b/utils/database.py index 2e5db83..03be60e 100644 --- a/utils/database.py +++ b/utils/database.py @@ -434,13 +434,12 @@ async def delete_generator( except KeyError: pass - async def get_all_linked_channel( + async def get_all_links_for_channel( self, guild_id: DiscordID, channel_id: DiscordID, category_id: Optional[DiscordID] = None, ) -> List[Link]: - if category_id: s = [str(channel_id), str(category_id), str(guild_id)] else: