diff --git a/redbot/cogs/audio/audio.py b/redbot/cogs/audio/audio.py index 828d069a01e..8135f8e7e81 100644 --- a/redbot/cogs/audio/audio.py +++ b/redbot/cogs/audio/audio.py @@ -31,6 +31,7 @@ from redbot.core.utils.predicates import MessagePredicate, ReactionPredicate from urllib.parse import urlparse from .manager import ServerManager +from .errors import LavalinkDownloadFailed _ = Translator("Audio", __file__) @@ -91,6 +92,7 @@ def __init__(self, bot): self._connect_task = None self._disconnect_task = None self._cleaned_up = False + self._connection_aborted = False self.spotify_token = None self.play_lock = {} @@ -121,7 +123,10 @@ def _restart_connect(self): self._connect_task = self.bot.loop.create_task(self.attempt_connect()) async def attempt_connect(self, timeout: int = 30): - while True: # run until success + self._connection_aborted = False + max_retries = 5 + retry_count = 0 + while retry_count < max_retries: external = await self.config.use_external_lavalink() if external is False: settings = self._default_lavalink_settings @@ -134,21 +139,52 @@ async def attempt_connect(self, timeout: int = 30): self._manager = ServerManager() try: await self._manager.start() - except RuntimeError as exc: - log.exception( - "Exception whilst starting internal Lavalink server, retrying...", - exc_info=exc, - ) + except LavalinkDownloadFailed as exc: await asyncio.sleep(1) - continue + if exc.should_retry: + log.exception( + "Exception whilst starting internal Lavalink server, retrying...", + exc_info=exc, + ) + retry_count += 1 + continue + else: + log.exception( + "Fatal exception whilst starting internal Lavalink server, " + "aborting...", + exc_info=exc, + ) + self._connection_aborted = True + raise except asyncio.CancelledError: log.exception("Invalid machine architecture, cannot run Lavalink.") raise + except Exception as exc: + log.exception( + "Unhandled exception whilst starting internal Lavalink server, " + "aborting...", + exc_info=exc, + ) + self._connection_aborted = True + raise + else: + break else: host = await self.config.host() password = await self.config.password() rest_port = await self.config.rest_port() ws_port = await self.config.ws_port() + break + else: + log.critical( + "Setting up the Lavalink server failed after multiple attempts. See above " + "tracebacks for details." + ) + self._connection_aborted = True + return + + retry_count = 0 + while retry_count < max_retries: try: await lavalink.initialize( bot=self.bot, @@ -158,12 +194,26 @@ async def attempt_connect(self, timeout: int = 30): ws_port=ws_port, timeout=timeout, ) - return # break infinite loop except asyncio.TimeoutError: log.error("Connecting to Lavalink server timed out, retrying...") if external is False and self._manager is not None: await self._manager.shutdown() + retry_count += 1 await asyncio.sleep(1) # prevent busylooping + except Exception as exc: + log.exception( + "Unhandled exception whilst connecting to Lavalink, aborting...", exc_info=exc + ) + self._connection_aborted = True + raise + else: + break + else: + self._connection_aborted = True + log.critical( + "Connecting to the Lavalink server failed after multiple attempts. See above " + "tracebacks for details." + ) async def event_handler(self, player, event_type, extra): disconnect = await self.config.guild(player.channel.guild).disconnect() @@ -1160,6 +1210,11 @@ async def play(self, ctx, *, query): if not url_check: return await self._embed_msg(ctx, _("That URL is not allowed.")) if not self._player_check(ctx): + if self._connection_aborted: + msg = _("Connection to Lavalink has failed.") + if await ctx.bot.is_owner(ctx.author): + msg += " " + _("Please check your console or logs for details.") + return await self._embed_msg(ctx, msg) try: if ( not ctx.author.voice.channel.permissions_for(ctx.me).connect @@ -2096,15 +2151,22 @@ async def _playlist_check(self, ctx): await self._embed_msg(ctx, _("You need the DJ role to use playlists.")) return False if not self._player_check(ctx): + if self._connection_aborted: + msg = _("Connection to Lavalink has failed.") + if await ctx.bot.is_owner(ctx.author): + msg += " " + _("Please check your console or logs for details.") + await self._embed_msg(ctx, msg) + return False try: if ( not ctx.author.voice.channel.permissions_for(ctx.me).connect or not ctx.author.voice.channel.permissions_for(ctx.me).move_members and self._userlimit(ctx.author.voice.channel) ): - return await self._embed_msg( + await self._embed_msg( ctx, _("I don't have permission to connect to your channel.") ) + return False await lavalink.connect(ctx.author.voice.channel) player = lavalink.get_player(ctx.guild.id) player.store("connect", datetime.datetime.utcnow()) @@ -2560,6 +2622,11 @@ async def _search_menu( } if not self._player_check(ctx): + if self._connection_aborted: + msg = _("Connection to Lavalink has failed.") + if await ctx.bot.is_owner(ctx.author): + msg += " " + _("Please check your console or logs for details.") + return await self._embed_msg(ctx, msg) try: if ( not ctx.author.voice.channel.permissions_for(ctx.me).connect @@ -2673,6 +2740,11 @@ async def _search_menu( async def _search_button_action(self, ctx, tracks, emoji, page): if not self._player_check(ctx): + if self._connection_aborted: + msg = _("Connection to Lavalink has failed.") + if await ctx.bot.is_owner(ctx.author): + msg += " " + _("Please check your console or logs for details.") + return await self._embed_msg(ctx, msg) try: await lavalink.connect(ctx.author.voice.channel) player = lavalink.get_player(ctx.guild.id) @@ -3493,8 +3565,9 @@ def _play_lock(self, ctx, tf): else: self.play_lock[ctx.message.guild.id] = False - @staticmethod - def _player_check(ctx): + def _player_check(self, ctx: commands.Context): + if self._connection_aborted: + return False try: lavalink.get_player(ctx.guild.id) return True diff --git a/redbot/cogs/audio/errors.py b/redbot/cogs/audio/errors.py new file mode 100644 index 00000000000..9785a9b82d3 --- /dev/null +++ b/redbot/cogs/audio/errors.py @@ -0,0 +1,33 @@ +import aiohttp + + +class AudioError(Exception): + """Base exception for errors in the Audio cog.""" + + +class LavalinkDownloadFailed(AudioError, RuntimeError): + """Downloading the Lavalink jar failed. + + Attributes + ---------- + response : aiohttp.ClientResponse + The response from the server to the failed GET request. + should_retry : bool + Whether or not the Audio cog should retry downloading the jar. + + """ + + def __init__(self, *args, response: aiohttp.ClientResponse, should_retry: bool = False): + super().__init__(*args) + self.response = response + self.should_retry = should_retry + + def __repr__(self) -> str: + str_args = [*map(str, self.args), self._response_repr()] + return f"LavalinkDownloadFailed({', '.join(str_args)}" + + def __str__(self) -> str: + return f"{super().__str__()} {self._response_repr()}" + + def _response_repr(self) -> str: + return f"[{self.response.status} {self.response.reason}]" diff --git a/redbot/cogs/audio/manager.py b/redbot/cogs/audio/manager.py index 3c9b7a501e1..db4c140b23e 100644 --- a/redbot/cogs/audio/manager.py +++ b/redbot/cogs/audio/manager.py @@ -6,12 +6,15 @@ import asyncio.subprocess # disables for # https://github.com/PyCQA/pylint/issues/1469 import logging import re +import sys import tempfile from typing import Optional, Tuple, ClassVar, List import aiohttp +from tqdm import tqdm from redbot.core import data_manager +from .errors import LavalinkDownloadFailed JAR_VERSION = "3.2.0.3" JAR_BUILD = 796 @@ -200,22 +203,45 @@ async def _download_jar() -> None: async with aiohttp.ClientSession() as session: async with session.get(LAVALINK_DOWNLOAD_URL) as response: if response.status == 404: - raise RuntimeError( - f"Lavalink jar version {JAR_VERSION}_{JAR_BUILD} hasn't been published" + # A 404 means our LAVALINK_DOWNLOAD_URL is invalid, so likely the jar version + # hasn't been published yet + raise LavalinkDownloadFailed( + f"Lavalink jar version {JAR_VERSION}_{JAR_BUILD} hasn't been published " + f"yet", + response=response, + should_retry=False, ) + elif 400 <= response.status < 600: + # Other bad responses should be raised but we should retry just incase + raise LavalinkDownloadFailed(response=response, should_retry=True) fd, path = tempfile.mkstemp() file = open(fd, "wb") - try: - chunk = await response.content.read(1024) - while chunk: - file.write(chunk) + nbytes = 0 + with tqdm( + desc="Lavalink.jar", + total=response.content_length, + file=sys.stdout, + unit="B", + unit_scale=True, + miniters=1, + dynamic_ncols=True, + leave=False, + ) as progress_bar: + try: chunk = await response.content.read(1024) - file.flush() - finally: - file.close() + while chunk: + chunk_size = file.write(chunk) + nbytes += chunk_size + progress_bar.update(chunk_size) + chunk = await response.content.read(1024) + file.flush() + finally: + file.close() shutil.move(path, str(LAVALINK_JAR_FILE), copy_function=shutil.copyfile) + log.info("Successfully downloaded Lavalink.jar (%s bytes written)", format(nbytes, ",")) + @classmethod async def _is_up_to_date(cls): if cls._up_to_date is True: diff --git a/setup.cfg b/setup.cfg index 3f6ceff495c..ed3044aa3db 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,7 @@ install_requires = pyyaml==3.13 red-lavalink>=0.3.0,<0.4 schema==0.6.8 + tqdm==4.32.1 yarl==1.3.0 discord.py==1.0.1 websockets<7