Skip to content

Commit

Permalink
Fix voice disconnect+connect race condition
Browse files Browse the repository at this point in the history
Fixes a race condition when disconnecting and immediately connecting
again.  Also fixes disconnect() being called twice.

Let me be clear, I DO NOT LIKE THIS SOLUTION.  I think it's dumb but I
don't see any other reasonable alternative.  There isn't a way to
transfer state to a new connection state object and I can't think of a
nice way to do it either.  That said, waiting an arbitrary amount of
time for an arbitrary websocket event doesn't seem like the right
solution either, but it's the best I can do at this point.
  • Loading branch information
imayhaveborkedit committed Dec 15, 2023
1 parent 50190e0 commit 9db0dad
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions discord/voice_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def __init__(self, voice_client: VoiceClient, *, hook: Optional[WebsocketHook] =
self._expecting_disconnect: bool = False
self._connected = threading.Event()
self._state_event = asyncio.Event()
self._disconnected = asyncio.Event()
self._runner: Optional[asyncio.Task] = None
self._connector: Optional[asyncio.Task] = None
self._socket_reader = SocketReader(self)
Expand Down Expand Up @@ -254,8 +255,10 @@ async def voice_state_update(self, data: GuildVoiceStatePayload) -> None:
channel_id = data['channel_id']

if channel_id is None:
self._disconnected.set()

# If we know we're going to get a voice_state_update where we have no channel due to
# being in the reconnect flow, we ignore it. Otherwise, it probably wasn't from us.
# being in the reconnect or disconnect flow, we ignore it. Otherwise, it probably wasn't from us.
if self._expecting_disconnect:
self._expecting_disconnect = False
else:
Expand Down Expand Up @@ -419,9 +422,9 @@ async def disconnect(self, *, force: bool = True, cleanup: bool = True) -> None:
return

try:
await self._voice_disconnect()
if self.ws:
await self.ws.close()
await self._voice_disconnect()
except Exception:
_log.debug('Ignoring exception disconnecting from voice', exc_info=True)
finally:
Expand All @@ -436,11 +439,25 @@ async def disconnect(self, *, force: bool = True, cleanup: bool = True) -> None:

if cleanup:
self._socket_reader.stop()
self.voice_client.cleanup()

if self.socket:
self.socket.close()

# Skip this part if disconnect was called from the poll loop task
if self._runner and asyncio.current_task() != self._runner:
# Wait for the voice_state_update event confirming the bot left the voice channel.
# This prevents a race condition caused by disconnecting and immediately connecting again.
# The new VoiceConnectionState object receives the voice_state_update event containing channel=None while still
# connecting leaving it in a bad state. Since there's no nice way to transfer state to the new one, we have to do this.
try:
async with atimeout(self.timeout):
await self._disconnected.wait()
except TimeoutError:
_log.debug('Timed out waiting for disconnect confirmation event')

if cleanup:
self.voice_client.cleanup()

async def soft_disconnect(self, *, with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates) -> None:
_log.debug('Soft disconnecting from voice')
# Stop the websocket reader because closing the websocket will trigger an unwanted reconnect
Expand Down Expand Up @@ -524,6 +541,7 @@ async def _voice_disconnect(self) -> None:
self.state = ConnectionFlowState.disconnected
await self.voice_client.channel.guild.change_voice_state(channel=None)
self._expecting_disconnect = True
self._disconnected.clear()

async def _connect_websocket(self, resume: bool) -> DiscordVoiceWebSocket:
ws = await DiscordVoiceWebSocket.from_connection_state(self, resume=resume, hook=self.hook)
Expand Down Expand Up @@ -557,8 +575,10 @@ async def _poll_voice_ws(self, reconnect: bool) -> None:
# 4014 - we were externally disconnected (voice channel deleted, we were moved, etc)
# 4015 - voice server has crashed
if exc.code in (1000, 4015):
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
await self.disconnect()
# Don't call disconnect a second time if the websocket closed from a disconnect call
if not self._expecting_disconnect:
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
await self.disconnect()
break

if exc.code == 4014:
Expand Down Expand Up @@ -602,13 +622,17 @@ async def _potential_reconnect(self) -> bool:
)
except asyncio.TimeoutError:
return False

previous_ws = self.ws
try:
self.ws = await self._connect_websocket(False)
await self._handshake_websocket()
except (ConnectionClosed, asyncio.TimeoutError):
return False
else:
return True
finally:
await previous_ws.close()

async def _move_to(self, channel: abc.Snowflake) -> None:
await self.voice_client.channel.guild.change_voice_state(channel=channel)
Expand Down

0 comments on commit 9db0dad

Please sign in to comment.