|
7 | 7 |
|
8 | 8 | import discord |
9 | 9 | import lavalink |
| 10 | +from aiohttp import ClientError as AiohttpClientError, ContentTypeError |
10 | 11 | from discord import app_commands |
11 | 12 | from discord.ext import commands |
| 13 | +from lavalink.errors import ClientError as LavalinkClientError |
12 | 14 |
|
13 | 15 | from src.services.lavalink_service import LavalinkVoiceClient |
14 | 16 | from src.utils.embeds import EmbedFactory |
|
17 | 19 | VOICE_PERMISSIONS = ("connect", "speak", "view_channel") |
18 | 20 |
|
19 | 21 |
|
| 22 | +class TrackLookupError(Exception): |
| 23 | + """Raised when Lavalink fails to provide a usable load result.""" |
| 24 | + |
| 25 | + |
20 | 26 | def ms_to_clock(ms: int) -> str: |
21 | 27 | """Convert milliseconds into a human readable duration string.""" |
22 | 28 | seconds = max(0, int(ms // 1000)) |
@@ -140,19 +146,35 @@ def _tag_tracks( |
140 | 146 | track.requester = requester_id |
141 | 147 | return tracks |
142 | 148 |
|
| 149 | + async def _safe_get_tracks(self, identifier: str) -> lavalink.LoadResult: |
| 150 | + """Fetch tracks from Lavalink with user-friendly error handling.""" |
| 151 | + try: |
| 152 | + return await self.bot.lavalink.get_tracks(identifier) |
| 153 | + except ContentTypeError as exc: |
| 154 | + status = getattr(exc, "status", "unknown") |
| 155 | + raise TrackLookupError( |
| 156 | + f"Lavalink returned an unexpected response (HTTP {status}). Please verify the node is reachable." |
| 157 | + ) from exc |
| 158 | + except (AiohttpClientError, LavalinkClientError) as exc: |
| 159 | + raise TrackLookupError( |
| 160 | + "Unable to reach the Lavalink node. Please try again in a few moments." |
| 161 | + ) from exc |
| 162 | + |
143 | 163 | async def _resolve(self, query: str) -> lavalink.LoadResult: |
144 | 164 | query = query.strip() |
| 165 | + if not query: |
| 166 | + raise TrackLookupError("Please provide a search query or URL.") |
145 | 167 | if URL_REGEX.match(query): |
146 | | - result = await self.bot.lavalink.get_tracks(query) |
| 168 | + result = await self._safe_get_tracks(query) |
147 | 169 | if result.tracks: |
148 | 170 | return result |
149 | 171 | last: Optional[lavalink.LoadResult] = None |
150 | 172 | for prefix in ("ytsearch", "scsearch", "amsearch"): |
151 | | - result = await self.bot.lavalink.get_tracks(f"{prefix}:{query}") |
| 173 | + result = await self._safe_get_tracks(f"{prefix}:{query}") |
152 | 174 | if result.tracks: |
153 | 175 | return result |
154 | 176 | last = result |
155 | | - return last or await self.bot.lavalink.get_tracks(query) |
| 177 | + return last or await self._safe_get_tracks(query) |
156 | 178 |
|
157 | 179 | async def _player(self, inter: discord.Interaction) -> Optional[lavalink.DefaultPlayer]: |
158 | 180 | factory = EmbedFactory(inter.guild.id if inter.guild else None) |
@@ -220,7 +242,11 @@ async def play(self, inter: discord.Interaction, query: str): |
220 | 242 | if not player: |
221 | 243 | return |
222 | 244 |
|
223 | | - results = await self._resolve(query) |
| 245 | + try: |
| 246 | + results = await self._resolve(query) |
| 247 | + except TrackLookupError as exc: |
| 248 | + await inter.followup.send(embed=factory.error(str(exc)), ephemeral=True) |
| 249 | + return |
224 | 250 | if results.load_type == "LOAD_FAILED": |
225 | 251 | return await inter.followup.send(embed=factory.error("Loading the track failed."), ephemeral=True) |
226 | 252 | if not results.tracks: |
|
0 commit comments