Skip to content

Commit

Permalink
Rewrite voice connection internals
Browse files Browse the repository at this point in the history
  • Loading branch information
imayhaveborkedit committed Sep 28, 2023
1 parent 5559403 commit 44284ae
Show file tree
Hide file tree
Showing 5 changed files with 735 additions and 279 deletions.
4 changes: 2 additions & 2 deletions discord/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1842,7 +1842,7 @@ def _get_voice_state_pair(self) -> Tuple[int, int]:
async def connect(
self,
*,
timeout: float = 60.0,
timeout: float = 30.0,
reconnect: bool = True,
cls: Callable[[Client, Connectable], T] = VoiceClient,
self_deaf: bool = False,
Expand All @@ -1858,7 +1858,7 @@ async def connect(
Parameters
-----------
timeout: :class:`float`
The timeout in seconds to wait for the voice endpoint.
The timeout in seconds to wait the connection to complete.
reconnect: :class:`bool`
Whether the bot should automatically attempt
a reconnect if a part of the handshake fails
Expand Down
74 changes: 49 additions & 25 deletions discord/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import traceback
import zlib

from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple

import aiohttp
import yarl
Expand All @@ -59,7 +59,7 @@

from .client import Client
from .state import ConnectionState
from .voice_client import VoiceClient
from .voice_state import VoiceConnectionState


class ReconnectWebSocket(Exception):
Expand Down Expand Up @@ -797,7 +797,7 @@ class DiscordVoiceWebSocket:

if TYPE_CHECKING:
thread_id: int
_connection: VoiceClient
_connection: VoiceConnectionState
gateway: str
_max_heartbeat_timeout: float

Expand Down Expand Up @@ -866,16 +866,21 @@ async def identify(self) -> None:
await self.send_as_json(payload)

@classmethod
async def from_client(
cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None
async def from_connection_state(
cls,
state: VoiceConnectionState,
*,
resume: bool = False,
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
) -> Self:
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4'
gateway = f'wss://{state.endpoint}/?v=4'
client = state.voice_client
http = client._state.http
socket = await http.ws_connect(gateway, compress=15)
ws = cls(socket, loop=client.loop, hook=hook)
ws.gateway = gateway
ws._connection = client
ws._connection = state
ws._max_heartbeat_timeout = 60.0
ws.thread_id = threading.get_ident()

Expand Down Expand Up @@ -951,29 +956,49 @@ async def initial_connection(self, data: Dict[str, Any]) -> None:
state.voice_port = data['port']
state.endpoint_ip = data['ip']

_log.debug('Connecting to voice socket')
await self.loop.sock_connect(state.socket, (state.endpoint_ip, state.voice_port))

state.ip, state.port = await self.discover_ip()
# there *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
_log.debug('received supported encryption modes: %s', ', '.join(modes))

mode = modes[0]
await self.select_protocol(state.ip, state.port, mode)
_log.debug('selected the voice protocol for use (%s)', mode)

async def discover_ip(self) -> Tuple[str, int]:
state = self._connection
packet = bytearray(74)
struct.pack_into('>H', packet, 0, 1) # 1 = Send
struct.pack_into('>H', packet, 2, 70) # 70 = Length
struct.pack_into('>I', packet, 4, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = await self.loop.sock_recv(state.socket, 74)
_log.debug('received packet in initial_connection: %s', recv)

_log.debug('Sending ip discovery packet')
await self.loop.sock_sendall(state.socket, packet)

fut: asyncio.Future[bytes] = self.loop.create_future()

def get_ip_packet(data: bytes):
if data[1] == 0x02 and len(data) == 74:
self.loop.call_soon_threadsafe(fut.set_result, data)

fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet))
state.add_socket_listener(get_ip_packet)
recv = await fut

_log.debug('Received ip discovery packet: %s', recv)

# the ip is ascii starting at the 8th byte and ending at the first null
ip_start = 8
ip_end = recv.index(0, ip_start)
state.ip = recv[ip_start:ip_end].decode('ascii')
ip = recv[ip_start:ip_end].decode('ascii')

state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', state.ip, state.port)
port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', ip, port)

# there *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
_log.debug('received supported encryption modes: %s', ", ".join(modes))

mode = modes[0]
await self.select_protocol(state.ip, state.port, mode)
_log.debug('selected the voice protocol for use (%s)', mode)
return ip, port

@property
def latency(self) -> float:
Expand All @@ -995,9 +1020,8 @@ async def load_secret_key(self, data: Dict[str, Any]) -> None:
self.secret_key = self._connection.secret_key = data['secret_key']

# Send a speak command with the "not speaking" state.
# This also tells Discord our SSRC value, which Discord requires
# before sending any voice data (and is the real reason why we
# call this here).
# This also tells Discord our SSRC value, which Discord requires before
# sending any voice data (and is the real reason why we call this here).
await self.speak(SpeakingState.none)

async def poll_event(self) -> None:
Expand All @@ -1006,10 +1030,10 @@ async def poll_event(self) -> None:
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(utils._from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received %s', msg)
_log.debug('Received voice %s', msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
_log.debug('Received %s', msg)
_log.debug('Received voice %s', msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)

async def close(self, code: int = 1000) -> None:
Expand Down
30 changes: 18 additions & 12 deletions discord/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,6 @@ def __init__(
self._resumed: threading.Event = threading.Event()
self._resumed.set() # we are not paused
self._current_error: Optional[Exception] = None
self._connected: threading.Event = client._connected
self._lock: threading.Lock = threading.Lock()

if after is not None and not callable(after):
Expand All @@ -714,7 +713,8 @@ def _do_run(self) -> None:
self._start = time.perf_counter()

# getattr lookup speed ups
play_audio = self.client.send_audio_packet
client = self.client
play_audio = client.send_audio_packet
self._speak(SpeakingState.voice)

while not self._end.is_set():
Expand All @@ -725,22 +725,28 @@ def _do_run(self) -> None:
self._resumed.wait()
continue

# are we disconnected from voice?
if not self._connected.is_set():
# wait until we are connected
self._connected.wait()
# reset our internal data
self.loops = 0
self._start = time.perf_counter()

self.loops += 1
data = self.source.read()

if not data:
self.stop()
break

# are we disconnected from voice?

This comment has been minimized.

Copy link
@issamansur

issamansur Sep 28, 2023

interesting

if not client.is_connected():
_log.debug('Not connected, waiting for %ss...', client.timeout)
# wait until we are connected, but not forever
connected = client.wait_until_connected(client.timeout)
if self._end.is_set() or not connected:
_log.debug('Aborting playback')
return
_log.debug('Reconnected, resuming playback')
self._speak(SpeakingState.voice)
# reset our internal data
self.loops = 0
self._start = time.perf_counter()

play_audio(data, encode=not self.source.is_opus())
self.loops += 1
next_time = self._start + self.DELAY * self.loops
delay = max(0, self.DELAY + (next_time - time.perf_counter()))
time.sleep(delay)
Expand Down Expand Up @@ -792,7 +798,7 @@ def is_playing(self) -> bool:
def is_paused(self) -> bool:
return not self._end.is_set() and not self._resumed.is_set()

def _set_source(self, source: AudioSource) -> None:
def set_source(self, source: AudioSource) -> None:
with self._lock:
self.pause(update_speaking=False)
self.source = source
Expand Down

1 comment on commit 44284ae

@TruncatedDinoSour
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

damnn

Please sign in to comment.