diff --git a/twitchio/ext/commands/cooldowns.py b/twitchio/ext/commands/cooldowns.py index ded01efc..96c6cea9 100644 --- a/twitchio/ext/commands/cooldowns.py +++ b/twitchio/ext/commands/cooldowns.py @@ -47,22 +47,31 @@ class Bucket(enum.Enum): The default bucket. channel: :class:`enum.Enum` Cooldown is shared amongst all chatters per channel. + user: :class:`enum.Enum` + Cooldown operates on a per user basis across all channels. member: :class:`enum.Enum` Cooldown operates on a per channel basis per user. - user: :class:`enum.Enum` - Cooldown operates on a user basis across all channels. + turbo: :class:`enum.Enum` + Cooldown for turbo users. subscriber: :class:`enum.Enum` Cooldown for subscribers. + vip: :class:`enum.Enum` + Cooldown for VIPs. mod: :class:`enum.Enum` Cooldown for mods. + broadcaster: :class:`enum.Enum` + Cooldown for the broadcaster. """ default = 0 channel = 1 - member = 2 - user = 3 - subscriber = 4 - mod = 5 + user = 2 + member = 3 + turbo = 4 + subscriber = 5 + vip = 6 + mod = 7 + broadcaster = 8 class Cooldown: @@ -100,80 +109,90 @@ async def my_command(self, ctx: commands.Context): @commands.command() async def my_command(self, ctx: commands.Context): pass + + # Restrict a command to 5 times every 60 seconds globally for a user, + # 5 times every 30 seconds if the user is turbo, + # and 1 time every 1 second if they're the channel broadcaster + @commands.cooldown(rate=5, per=60, bucket=commands.Bucket.user) + @commands.cooldown(rate=5, per=30, bucket=commands.Bucket.turbo) + @commands.cooldown(rate=1, per=1, bucket=commands.Bucket.broadcaster) + @commands.command() + async def my_command(self, ctx: commands.Context): + pass """ - __slots__ = ("_rate", "_per", "bucket", "_window", "_tokens", "_cache") + __slots__ = ("_rate", "_per", "bucket", "_cache") - def __init__(self, rate: int, per: float, bucket: Bucket): + def __init__(self, rate: int, per: float, bucket: Bucket) -> None: self._rate = rate self._per = per self.bucket = bucket self._cache = {} - def update_bucket(self, ctx): - now = time.time() - - bucket_keys = self._bucket_keys(ctx) - buckets = [] - - for bucket in bucket_keys: - (tokens, window) = self._cache[bucket] - - if tokens == self._rate: - retry = self._per - (now - window) - raise CommandOnCooldown(command=ctx.command, retry_after=retry) + def _update_cooldown(self, bucket_key, now) -> int | None: + tokens = self._cache[bucket_key] - tokens += 1 + if len(tokens) == self._rate: + retry = self._per - (now - tokens[0]) + return retry - if tokens == self._rate: - window = now + tokens.append(now) - self._cache[bucket] = (tokens, window) - - def reset(self): + def reset(self) -> None: self._cache = {} - def _bucket_keys(self, ctx): - buckets = [] - - for bucket in ctx.command._cooldowns: - if bucket.bucket == Bucket.default: - buckets.append("default") - - if bucket.bucket == Bucket.channel: - buckets.append(ctx.channel.name) - - if bucket.bucket == Bucket.member: - buckets.append((ctx.channel.name, ctx.author.id)) - if bucket.bucket == Bucket.user: - buckets.append(ctx.author.id) - - if bucket.bucket == Bucket.subscriber: - buckets.append((ctx.channel.name, ctx.author.id, 0)) - if bucket.bucket == Bucket.mod: - buckets.append((ctx.channel.name, ctx.author.id, 1)) - - return buckets - - def _update_cache(self, now=None): - now = now or time.time() - dead = [key for key, cooldown in self._cache.items() if now > cooldown[1] + self._per] - - for bucket in dead: - del self._cache[bucket] - - def get_buckets(self, ctx): + def _bucket_key(self, ctx): + key = None + + if self.bucket == Bucket.default: + key = "default" + elif self.bucket == Bucket.channel: + key = ctx.channel.name + elif self.bucket == Bucket.user: + key = ctx.author.id + elif self.bucket == Bucket.member: + key = (ctx.channel.name, ctx.author.id) + elif self.bucket == Bucket.turbo and ctx.author.is_turbo: + key = (ctx.channel.name, ctx.author.id) + elif self.bucket == Bucket.subscriber and ctx.author.is_subscriber: + key = (ctx.channel.name, ctx.author.id) + elif self.bucket == Bucket.vip and ctx.author.is_vip: + key = (ctx.channel.name, ctx.author.id) + elif self.bucket == Bucket.mod and ctx.author.is_mod: + key = (ctx.channel.name, ctx.author.id) + elif self.bucket == Bucket.broadcaster and ctx.author.is_broadcaster: + key = (ctx.channel.name, ctx.author.id) + + return key + + def _update_cache(self, now) -> None: + expired_bucket_keys = [] + + for bucket_key, tokens in self._cache.items(): + expired_tokens = [] + + for token in tokens: + if now - token > self._per: + expired_tokens.append(token) + + for expired_token in expired_tokens: + tokens.remove(expired_token) + + if not tokens: + expired_bucket_keys.append(bucket_key) + + for expired_bucket_key in expired_bucket_keys: + del self._cache[expired_bucket_key] + + def on_cooldown(self, ctx) -> int | None: now = time.time() self._update_cache(now) - bucket_keys = self._bucket_keys(ctx) - buckets = [] - - for index, bucket in enumerate(bucket_keys): - buckets.append(ctx.command._cooldowns[index]) - if bucket not in self._cache: - self._cache[bucket] = (0, now) + bucket_key = self._bucket_key(ctx) + if bucket_key: + if not bucket_key in self._cache: + self._cache[bucket_key] = [] - return buckets + return self._update_cooldown(bucket_key, now) diff --git a/twitchio/ext/commands/core.py b/twitchio/ext/commands/core.py index 2b40bdfb..9337eb31 100644 --- a/twitchio/ext/commands/core.py +++ b/twitchio/ext/commands/core.py @@ -356,7 +356,8 @@ async def try_run(func, *, to_command=False): limited = self._run_cooldowns(context) if limited: - context.bot.run_event("command_error", context, limited[0]) + e = CommandOnCooldown(command=context.command, retry_after=limited) + context.bot.run_event("command_error", context, e) return instance = self._instance args = [instance, context] if instance else [context] @@ -377,19 +378,16 @@ async def try_run(func, *, to_command=False): await try_run(self._after_invoke(*args), to_command=True) await try_run(context.bot.global_after_invoke(context)) - def _run_cooldowns(self, context: Context) -> Optional[List[CommandOnCooldown]]: - try: - buckets = self._cooldowns[0].get_buckets(context) - except IndexError: + def _run_cooldowns(self, context: Context) -> Optional[int]: + if not self._cooldowns: return None - expired = [] - try: - for bucket in buckets: - bucket.update_bucket(context) - except CommandOnCooldown as e: - expired.append(e) - return expired + retries = [] + for c in self._cooldowns: + retry = c.on_cooldown(context) + retries.append(retry) + if all(retries): + return min(retries) async def handle_checks(self, context: Context) -> Union[Literal[True], Exception]: # TODO Docs