diff --git a/discord/__main__.py b/discord/__main__.py index 6e34be54cbc2..843274b53a76 100644 --- a/discord/__main__.py +++ b/discord/__main__.py @@ -157,7 +157,7 @@ async def bot_check_once(self, ctx): async def cog_command_error(self, ctx, error): # error handling to every command in here pass - + async def cog_app_command_error(self, interaction, error): # error handling to every application command in here pass diff --git a/discord/abc.py b/discord/abc.py index 71eaff6ab62c..4c1f24618349 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -1842,7 +1842,7 @@ def _get_voice_state_pair(self) -> Tuple[int, int]: async def connect( self, *, - timeout: float = 60.0, + timeout: float = 30.0, reconnect: bool = True, cls: Callable[[Client, Connectable], T] = VoiceClient, self_deaf: bool = False, @@ -1858,7 +1858,7 @@ async def connect( Parameters ----------- timeout: :class:`float` - The timeout in seconds to wait for the voice endpoint. + The timeout in seconds to wait the connection to complete. reconnect: :class:`bool` Whether the bot should automatically attempt a reconnect if a part of the handshake fails diff --git a/discord/activity.py b/discord/activity.py index 534d12a2b85d..82b979b2417b 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -732,7 +732,9 @@ class CustomActivity(BaseActivity): __slots__ = ('name', 'emoji', 'state') - def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any) -> None: + def __init__( + self, name: Optional[str], *, emoji: Optional[Union[PartialEmoji, Dict[str, Any], str]] = None, **extra: Any + ) -> None: super().__init__(**extra) self.name: Optional[str] = name self.state: Optional[str] = extra.pop('state', name) diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index 0766475ecfb4..6c2aae7b8df6 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -1548,6 +1548,9 @@ def __init__( if not self.description: raise TypeError('groups must have a description') + if not self.name: + raise TypeError('groups must have a name') + self.parent: Optional[Group] = parent self.module: Optional[str] if cls.__discord_app_commands_has_module__: diff --git a/discord/app_commands/errors.py b/discord/app_commands/errors.py index 3cc12c72d8cb..dc63f10e8c88 100644 --- a/discord/app_commands/errors.py +++ b/discord/app_commands/errors.py @@ -28,6 +28,7 @@ from ..enums import AppCommandOptionType, AppCommandType, Locale from ..errors import DiscordException, HTTPException, _flatten_error_dict +from ..utils import _human_join __all__ = ( 'AppCommandError', @@ -242,13 +243,7 @@ class MissingAnyRole(CheckFailure): def __init__(self, missing_roles: SnowflakeList) -> None: self.missing_roles: SnowflakeList = missing_roles - missing = [f"'{role}'" for role in missing_roles] - - if len(missing) > 2: - fmt = '{}, or {}'.format(', '.join(missing[:-1]), missing[-1]) - else: - fmt = ' or '.join(missing) - + fmt = _human_join([f"'{role}'" for role in missing_roles]) message = f'You are missing at least one of the required roles: {fmt}' super().__init__(message) @@ -271,11 +266,7 @@ def __init__(self, missing_permissions: List[str], *args: Any) -> None: self.missing_permissions: List[str] = missing_permissions missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] - - if len(missing) > 2: - fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1]) - else: - fmt = ' and '.join(missing) + fmt = _human_join(missing, final='and') message = f'You are missing {fmt} permission(s) to run this command.' super().__init__(message, *args) @@ -298,11 +289,7 @@ def __init__(self, missing_permissions: List[str], *args: Any) -> None: self.missing_permissions: List[str] = missing_permissions missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] - - if len(missing) > 2: - fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1]) - else: - fmt = ' and '.join(missing) + fmt = _human_join(missing, final='and') message = f'Bot requires {fmt} permission(s) to run this command.' super().__init__(message, *args) @@ -530,8 +517,18 @@ def __init__(self, child: HTTPException, commands: List[CommandTypes]) -> None: messages = [f'Failed to upload commands to Discord (HTTP status {self.status}, error code {self.code})'] if self._errors: - for index, inner in self._errors.items(): - _get_command_error(index, inner, commands, messages) + # Handle case where the errors dict has no actual chain such as APPLICATION_COMMAND_TOO_LARGE + if len(self._errors) == 1 and '_errors' in self._errors: + errors = self._errors['_errors'] + if len(errors) == 1: + extra = errors[0].get('message') + if extra: + messages[0] += f': {extra}' + else: + messages.extend(f'Error {e.get("code", "")}: {e.get("message", "")}' for e in errors) + else: + for index, inner in self._errors.items(): + _get_command_error(index, inner, commands, messages) # Equivalent to super().__init__(...) but skips other constructors self.args = ('\n'.join(messages),) diff --git a/discord/app_commands/transformers.py b/discord/app_commands/transformers.py index dfdbefa23afc..59b3af758310 100644 --- a/discord/app_commands/transformers.py +++ b/discord/app_commands/transformers.py @@ -525,7 +525,7 @@ class Transform: .. versionadded:: 2.0 """ - def __class_getitem__(cls, items) -> _TransformMetadata: + def __class_getitem__(cls, items) -> Transformer: if not isinstance(items, tuple): raise TypeError(f'expected tuple for arguments, received {items.__class__.__name__} instead') @@ -570,7 +570,7 @@ async def range(interaction: discord.Interaction, value: app_commands.Range[int, await interaction.response.send_message(f'Your value is {value}', ephemeral=True) """ - def __class_getitem__(cls, obj) -> _TransformMetadata: + def __class_getitem__(cls, obj) -> RangeTransformer: if not isinstance(obj, tuple): raise TypeError(f'expected tuple for arguments, received {obj.__class__.__name__} instead') diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index 5bdfbec58ddd..c75682e0ea2c 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -1240,7 +1240,7 @@ async def _call(self, interaction: Interaction[ClientT]) -> None: await command._invoke_autocomplete(interaction, focused, namespace) except Exception: # Suppress exception since it can't be handled anyway. - pass + _log.exception('Ignoring exception in autocomplete for %r', command.qualified_name) return diff --git a/discord/channel.py b/discord/channel.py index c31a6af0d73d..02f2de075cc9 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -1600,6 +1600,7 @@ async def create_instance( topic: str, privacy_level: PrivacyLevel = MISSING, send_start_notification: bool = False, + scheduled_event: Snowflake = MISSING, reason: Optional[str] = None, ) -> StageInstance: """|coro| @@ -1621,6 +1622,10 @@ async def create_instance( You must have :attr:`~Permissions.mention_everyone` to do this. .. versionadded:: 2.3 + scheduled_event: :class:`~discord.abc.Snowflake` + The guild scheduled event associated with the stage instance. + + .. versionadded:: 2.4 reason: :class:`str` The reason the stage instance was created. Shows up on the audit log. @@ -1647,6 +1652,9 @@ async def create_instance( payload['privacy_level'] = privacy_level.value + if scheduled_event is not MISSING: + payload['guild_scheduled_event_id'] = scheduled_event.id + payload['send_start_notification'] = send_start_notification data = await self._state.http.create_stage_instance(**payload, reason=reason) diff --git a/discord/components.py b/discord/components.py index 6a834580134c..297f815fe304 100644 --- a/discord/components.py +++ b/discord/components.py @@ -25,7 +25,7 @@ from __future__ import annotations from typing import ClassVar, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload -from .enums import try_enum, ComponentType, ButtonStyle, TextStyle, ChannelType +from .enums import try_enum, ComponentType, ButtonStyle, TextStyle, ChannelType, SelectDefaultValueType from .utils import get_slots, MISSING from .partial_emoji import PartialEmoji, _EmojiTag @@ -40,8 +40,10 @@ ActionRow as ActionRowPayload, TextInput as TextInputPayload, ActionRowChildComponent as ActionRowChildComponentPayload, + SelectDefaultValues as SelectDefaultValuesPayload, ) from .emoji import Emoji + from .abc import Snowflake ActionRowChildComponentType = Union['Button', 'SelectMenu', 'TextInput'] @@ -53,6 +55,7 @@ 'SelectMenu', 'SelectOption', 'TextInput', + 'SelectDefaultValue', ) @@ -263,6 +266,7 @@ class SelectMenu(Component): 'options', 'disabled', 'channel_types', + 'default_values', ) __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ @@ -276,6 +280,9 @@ def __init__(self, data: SelectMenuPayload, /) -> None: self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])] self.disabled: bool = data.get('disabled', False) self.channel_types: List[ChannelType] = [try_enum(ChannelType, t) for t in data.get('channel_types', [])] + self.default_values: List[SelectDefaultValue] = [ + SelectDefaultValue.from_dict(d) for d in data.get('default_values', []) + ] def to_dict(self) -> SelectMenuPayload: payload: SelectMenuPayload = { @@ -291,6 +298,8 @@ def to_dict(self) -> SelectMenuPayload: payload['options'] = [op.to_dict() for op in self.options] if self.channel_types: payload['channel_types'] = [t.value for t in self.channel_types] + if self.default_values: + payload["default_values"] = [v.to_dict() for v in self.default_values] return payload @@ -512,6 +521,79 @@ def default(self) -> Optional[str]: return self.value +class SelectDefaultValue: + """Represents a select menu's default value. + + These can be created by users. + + .. versionadded:: 2.4 + + Parameters + ----------- + id: :class:`int` + The id of a role, user, or channel. + type: :class:`SelectDefaultValueType` + The type of value that ``id`` represents. + """ + + def __init__( + self, + *, + id: int, + type: SelectDefaultValueType, + ) -> None: + self.id: int = id + self._type: SelectDefaultValueType = type + + @property + def type(self) -> SelectDefaultValueType: + return self._type + + @type.setter + def type(self, value: SelectDefaultValueType) -> None: + if not isinstance(value, SelectDefaultValueType): + raise TypeError(f'expected SelectDefaultValueType, received {value.__class__.__name__} instead') + + self._type = value + + def __repr__(self) -> str: + return f'' + + @classmethod + def from_dict(cls, data: SelectDefaultValuesPayload) -> SelectDefaultValue: + return cls( + id=data['id'], + type=try_enum(SelectDefaultValueType, data['type']), + ) + + def to_dict(self) -> SelectDefaultValuesPayload: + return { + 'id': self.id, + 'type': self._type.value, + } + + @classmethod + def from_channel(cls, channel: Snowflake, /) -> Self: + return cls( + id=channel.id, + type=SelectDefaultValueType.channel, + ) + + @classmethod + def from_role(cls, role: Snowflake, /) -> Self: + return cls( + id=role.id, + type=SelectDefaultValueType.role, + ) + + @classmethod + def from_user(cls, user: Snowflake, /) -> Self: + return cls( + id=user.id, + type=SelectDefaultValueType.user, + ) + + @overload def _component_factory(data: ActionRowChildComponentPayload) -> Optional[ActionRowChildComponentType]: ... diff --git a/discord/enums.py b/discord/enums.py index c0a2c3f43572..254f86bc789d 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -69,6 +69,7 @@ 'AutoModRuleActionType', 'ForumLayoutType', 'ForumOrderType', + 'SelectDefaultValueType', ) if TYPE_CHECKING: @@ -772,6 +773,12 @@ class ForumOrderType(Enum): creation_date = 1 +class SelectDefaultValueType(Enum): + user = 'user' + role = 'role' + channel = 'channel' + + def create_unknown_value(cls: Type[E], val: Any) -> E: value_cls = cls._enum_value_cls_ # type: ignore # This is narrowed below name = f'unknown_{val}' diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index f2004c93e2fe..1ffe25c702bb 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -25,6 +25,7 @@ import inspect import discord +import logging from discord import app_commands from discord.utils import maybe_coroutine, _to_kebab_case @@ -65,6 +66,7 @@ FuncT = TypeVar('FuncT', bound=Callable[..., Any]) MISSING: Any = discord.utils.MISSING +_log = logging.getLogger(__name__) class CogMeta(type): @@ -360,6 +362,8 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self: if isinstance(app_command, app_commands.Group): for child in app_command.walk_commands(): app_command_refs[child.qualified_name] = child + if hasattr(child, '__commands_is_hybrid_app_command__') and child.qualified_name in lookup: + child.wrapped = lookup[child.qualified_name] # type: ignore if self.__cog_app_commands_group__: children.append(app_command) # type: ignore # Somehow it thinks it can be None here @@ -769,7 +773,7 @@ async def _eject(self, bot: BotBase, guild_ids: Optional[Iterable[int]]) -> None try: await maybe_coroutine(self.cog_unload) except Exception: - pass + _log.exception('Ignoring exception in cog unload for Cog %r (%r)', cls, self.qualified_name) class GroupCog(Cog): diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 736d0b5af3c0..0c1e0f2d0db4 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union from discord.errors import ClientException, DiscordException +from discord.utils import _human_join if TYPE_CHECKING: from discord.abc import GuildChannel @@ -758,12 +759,7 @@ def __init__(self, missing_roles: SnowflakeList) -> None: self.missing_roles: SnowflakeList = missing_roles missing = [f"'{role}'" for role in missing_roles] - - if len(missing) > 2: - fmt = '{}, or {}'.format(', '.join(missing[:-1]), missing[-1]) - else: - fmt = ' or '.join(missing) - + fmt = _human_join(missing) message = f'You are missing at least one of the required roles: {fmt}' super().__init__(message) @@ -788,12 +784,7 @@ def __init__(self, missing_roles: SnowflakeList) -> None: self.missing_roles: SnowflakeList = missing_roles missing = [f"'{role}'" for role in missing_roles] - - if len(missing) > 2: - fmt = '{}, or {}'.format(', '.join(missing[:-1]), missing[-1]) - else: - fmt = ' or '.join(missing) - + fmt = _human_join(missing) message = f'Bot is missing at least one of the required roles: {fmt}' super().__init__(message) @@ -832,11 +823,7 @@ def __init__(self, missing_permissions: List[str], *args: Any) -> None: self.missing_permissions: List[str] = missing_permissions missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] - - if len(missing) > 2: - fmt = '{}, and {}'.format(', '.join(missing[:-1]), missing[-1]) - else: - fmt = ' and '.join(missing) + fmt = _human_join(missing, final='and') message = f'You are missing {fmt} permission(s) to run this command.' super().__init__(message, *args) @@ -857,11 +844,7 @@ def __init__(self, missing_permissions: List[str], *args: Any) -> None: self.missing_permissions: List[str] = missing_permissions missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] - - if len(missing) > 2: - fmt = '{}, and {}'.format(', '.join(missing[:-1]), missing[-1]) - else: - fmt = ' and '.join(missing) + fmt = _human_join(missing, final='and') message = f'Bot requires {fmt} permission(s) to run this command.' super().__init__(message, *args) @@ -896,11 +879,7 @@ def _get_name(x): return x.__class__.__name__ to_string = [_get_name(x) for x in converters] - if len(to_string) > 2: - fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1]) - else: - fmt = ' or '.join(to_string) - + fmt = _human_join(to_string) super().__init__(f'Could not convert "{param.displayed_name or param.name}" into {fmt}.') @@ -933,11 +912,7 @@ def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[Com self.argument: str = argument to_string = [repr(l) for l in literals] - if len(to_string) > 2: - fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1]) - else: - fmt = ' or '.join(to_string) - + fmt = _human_join(to_string) super().__init__(f'Could not convert "{param.displayed_name or param.name}" into the literal {fmt}.') diff --git a/discord/ext/commands/hybrid.py b/discord/ext/commands/hybrid.py index 3b588ef64d95..a8775474dc9a 100644 --- a/discord/ext/commands/hybrid.py +++ b/discord/ext/commands/hybrid.py @@ -297,6 +297,8 @@ def replace_parameters( class HybridAppCommand(discord.app_commands.Command[CogT, P, T]): + __commands_is_hybrid_app_command__: ClassVar[bool] = True + def __init__( self, wrapped: Union[HybridCommand[CogT, ..., T], HybridGroup[CogT, ..., T]], diff --git a/discord/gateway.py b/discord/gateway.py index 551e36a55f0e..4f98bc2c1bbb 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -34,7 +34,7 @@ import traceback import zlib -from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar +from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple import aiohttp import yarl @@ -59,7 +59,7 @@ from .client import Client from .state import ConnectionState - from .voice_client import VoiceClient + from .voice_state import VoiceConnectionState class ReconnectWebSocket(Exception): @@ -797,7 +797,7 @@ class DiscordVoiceWebSocket: if TYPE_CHECKING: thread_id: int - _connection: VoiceClient + _connection: VoiceConnectionState gateway: str _max_heartbeat_timeout: float @@ -866,16 +866,21 @@ async def identify(self) -> None: await self.send_as_json(payload) @classmethod - async def from_client( - cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None + async def from_connection_state( + cls, + state: VoiceConnectionState, + *, + resume: bool = False, + hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None, ) -> Self: """Creates a voice websocket for the :class:`VoiceClient`.""" - gateway = 'wss://' + client.endpoint + '/?v=4' + gateway = f'wss://{state.endpoint}/?v=4' + client = state.voice_client http = client._state.http socket = await http.ws_connect(gateway, compress=15) ws = cls(socket, loop=client.loop, hook=hook) ws.gateway = gateway - ws._connection = client + ws._connection = state ws._max_heartbeat_timeout = 60.0 ws.thread_id = threading.get_ident() @@ -951,29 +956,49 @@ async def initial_connection(self, data: Dict[str, Any]) -> None: state.voice_port = data['port'] state.endpoint_ip = data['ip'] + _log.debug('Connecting to voice socket') + await self.loop.sock_connect(state.socket, (state.endpoint_ip, state.voice_port)) + + state.ip, state.port = await self.discover_ip() + # there *should* always be at least one supported mode (xsalsa20_poly1305) + modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes] + _log.debug('received supported encryption modes: %s', ', '.join(modes)) + + mode = modes[0] + await self.select_protocol(state.ip, state.port, mode) + _log.debug('selected the voice protocol for use (%s)', mode) + + async def discover_ip(self) -> Tuple[str, int]: + state = self._connection packet = bytearray(74) struct.pack_into('>H', packet, 0, 1) # 1 = Send struct.pack_into('>H', packet, 2, 70) # 70 = Length struct.pack_into('>I', packet, 4, state.ssrc) - state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) - recv = await self.loop.sock_recv(state.socket, 74) - _log.debug('received packet in initial_connection: %s', recv) + + _log.debug('Sending ip discovery packet') + await self.loop.sock_sendall(state.socket, packet) + + fut: asyncio.Future[bytes] = self.loop.create_future() + + def get_ip_packet(data: bytes): + if data[1] == 0x02 and len(data) == 74: + self.loop.call_soon_threadsafe(fut.set_result, data) + + fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet)) + state.add_socket_listener(get_ip_packet) + recv = await fut + + _log.debug('Received ip discovery packet: %s', recv) # the ip is ascii starting at the 8th byte and ending at the first null ip_start = 8 ip_end = recv.index(0, ip_start) - state.ip = recv[ip_start:ip_end].decode('ascii') + ip = recv[ip_start:ip_end].decode('ascii') - state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0] - _log.debug('detected ip: %s port: %s', state.ip, state.port) + port = struct.unpack_from('>H', recv, len(recv) - 2)[0] + _log.debug('detected ip: %s port: %s', ip, port) - # there *should* always be at least one supported mode (xsalsa20_poly1305) - modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes] - _log.debug('received supported encryption modes: %s', ", ".join(modes)) - - mode = modes[0] - await self.select_protocol(state.ip, state.port, mode) - _log.debug('selected the voice protocol for use (%s)', mode) + return ip, port @property def latency(self) -> float: @@ -995,9 +1020,8 @@ async def load_secret_key(self, data: Dict[str, Any]) -> None: self.secret_key = self._connection.secret_key = data['secret_key'] # Send a speak command with the "not speaking" state. - # This also tells Discord our SSRC value, which Discord requires - # before sending any voice data (and is the real reason why we - # call this here). + # This also tells Discord our SSRC value, which Discord requires before + # sending any voice data (and is the real reason why we call this here). await self.speak(SpeakingState.none) async def poll_event(self) -> None: @@ -1006,10 +1030,10 @@ async def poll_event(self) -> None: if msg.type is aiohttp.WSMsgType.TEXT: await self.received_message(utils._from_json(msg.data)) elif msg.type is aiohttp.WSMsgType.ERROR: - _log.debug('Received %s', msg) + _log.debug('Received voice %s', msg) raise ConnectionClosed(self.ws, shard_id=None) from msg.data elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): - _log.debug('Received %s', msg) + _log.debug('Received voice %s', msg) raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) async def close(self, code: int = 1000) -> None: diff --git a/discord/http.py b/discord/http.py index 84365cec700f..aa6dc9f3ea74 100644 --- a/discord/http.py +++ b/discord/http.py @@ -1918,6 +1918,7 @@ def create_stage_instance(self, *, reason: Optional[str], **payload: Any) -> Res 'topic', 'privacy_level', 'send_start_notification', + 'guild_scheduled_event_id', ) payload = {k: v for k, v in payload.items() if k in valid_keys} diff --git a/discord/opus.py b/discord/opus.py index ab3916138e7f..971675f73ef6 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -72,6 +72,8 @@ class SignalCtl(TypedDict): _log = logging.getLogger(__name__) +OPUS_SILENCE = b'\xF8\xFF\xFE' + c_int_ptr = ctypes.POINTER(ctypes.c_int) c_int16_ptr = ctypes.POINTER(ctypes.c_int16) c_float_ptr = ctypes.POINTER(ctypes.c_float) diff --git a/discord/player.py b/discord/player.py index 5cc8b133a6ea..147c0628a533 100644 --- a/discord/player.py +++ b/discord/player.py @@ -40,7 +40,7 @@ from .enums import SpeakingState from .errors import ClientException -from .opus import Encoder as OpusEncoder +from .opus import Encoder as OpusEncoder, OPUS_SILENCE from .oggparse import OggStream from .utils import MISSING @@ -212,7 +212,8 @@ def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Pope return process def _kill_process(self) -> None: - proc = self._process + # this function gets called in __del__ so instance attributes might not even exist + proc = getattr(self, '_process', MISSING) if proc is MISSING: return @@ -702,7 +703,6 @@ def __init__( self._resumed: threading.Event = threading.Event() self._resumed.set() # we are not paused self._current_error: Optional[Exception] = None - self._connected: threading.Event = client._connected self._lock: threading.Lock = threading.Lock() if after is not None and not callable(after): @@ -713,36 +713,46 @@ def _do_run(self) -> None: self._start = time.perf_counter() # getattr lookup speed ups - play_audio = self.client.send_audio_packet + client = self.client + play_audio = client.send_audio_packet self._speak(SpeakingState.voice) while not self._end.is_set(): # are we paused? if not self._resumed.is_set(): + self.send_silence() # wait until we aren't self._resumed.wait() continue - # are we disconnected from voice? - if not self._connected.is_set(): - # wait until we are connected - self._connected.wait() - # reset our internal data - self.loops = 0 - self._start = time.perf_counter() - - self.loops += 1 data = self.source.read() if not data: self.stop() break + # are we disconnected from voice? + if not client.is_connected(): + _log.debug('Not connected, waiting for %ss...', client.timeout) + # wait until we are connected, but not forever + connected = client.wait_until_connected(client.timeout) + if self._end.is_set() or not connected: + _log.debug('Aborting playback') + return + _log.debug('Reconnected, resuming playback') + self._speak(SpeakingState.voice) + # reset our internal data + self.loops = 0 + self._start = time.perf_counter() + play_audio(data, encode=not self.source.is_opus()) + self.loops += 1 next_time = self._start + self.DELAY * self.loops delay = max(0, self.DELAY + (next_time - time.perf_counter())) time.sleep(delay) + self.send_silence() + def run(self) -> None: try: self._do_run() @@ -788,7 +798,7 @@ def is_playing(self) -> bool: def is_paused(self) -> bool: return not self._end.is_set() and not self._resumed.is_set() - def _set_source(self, source: AudioSource) -> None: + def set_source(self, source: AudioSource) -> None: with self._lock: self.pause(update_speaking=False) self.source = source @@ -799,3 +809,11 @@ def _speak(self, speaking: SpeakingState) -> None: asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.client.loop) except Exception: _log.exception("Speaking call in player failed") + + def send_silence(self, count: int = 5) -> None: + try: + for n in range(count): + self.client.send_audio_packet(OPUS_SILENCE, encode=False) + except Exception: + # Any possible error (probably a socket error) is so inconsequential it's not even worth logging + pass diff --git a/discord/raw_models.py b/discord/raw_models.py index 874edfcc83d6..556df52451ab 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -30,6 +30,7 @@ from .enums import ChannelType, try_enum from .utils import _get_as_snowflake from .app_commands import AppCommandPermissions +from .colour import Colour if TYPE_CHECKING: from .types.gateway import ( @@ -207,9 +208,29 @@ class RawReactionActionEvent(_RawReprMixin): ``REACTION_REMOVE`` for reaction removal. .. versionadded:: 1.3 + burst: :class:`bool` + Whether the reaction was a burst reaction, also known as a "super reaction". + + .. versionadded:: 2.4 + burst_colours: List[:class:`Colour`] + A list of colours used for burst reaction animation. Only available if ``burst`` is ``True`` + and if ``event_type`` is ``REACTION_ADD``. + + .. versionadded:: 2.0 """ - __slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji', 'event_type', 'member', 'message_author_id') + __slots__ = ( + 'message_id', + 'user_id', + 'channel_id', + 'guild_id', + 'emoji', + 'event_type', + 'member', + 'message_author_id', + 'burst', + 'burst_colours', + ) def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: ReactionActionType) -> None: self.message_id: int = int(data['message_id']) @@ -219,12 +240,22 @@ def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: R self.event_type: ReactionActionType = event_type self.member: Optional[Member] = None self.message_author_id: Optional[int] = _get_as_snowflake(data, 'message_author_id') + self.burst: bool = data.get('burst', False) + self.burst_colours: List[Colour] = [Colour.from_str(c) for c in data.get('burst_colours', [])] try: self.guild_id: Optional[int] = int(data['guild_id']) except KeyError: self.guild_id: Optional[int] = None + @property + def burst_colors(self) -> List[Colour]: + """An alias of :attr:`burst_colours`. + + .. versionadded:: 2.4 + """ + return self.burst_colours + class RawReactionClearEvent(_RawReprMixin): """Represents the payload for a :func:`on_raw_reaction_clear` event. diff --git a/discord/reaction.py b/discord/reaction.py index c0cbb8ee5262..cd0fbef10268 100644 --- a/discord/reaction.py +++ b/discord/reaction.py @@ -74,20 +74,40 @@ class Reaction: emoji: Union[:class:`Emoji`, :class:`PartialEmoji`, :class:`str`] The reaction emoji. May be a custom emoji, or a unicode emoji. count: :class:`int` - Number of times this reaction was made + Number of times this reaction was made. This is a sum of :attr:`normal_count` and :attr:`burst_count`. me: :class:`bool` If the user sent this reaction. message: :class:`Message` Message this reaction is for. + me_burst: :class:`bool` + If the user sent this super reaction. + + .. versionadded:: 2.4 + normal_count: :class:`int` + The number of times this reaction was made using normal reactions. + This is not available in the gateway events such as :func:`on_reaction_add` + or :func:`on_reaction_remove`. + + .. versionadded:: 2.4 + burst_count: :class:`int` + The number of times this reaction was made using super reactions. + This is not available in the gateway events such as :func:`on_reaction_add` + or :func:`on_reaction_remove`. + + .. versionadded:: 2.4 """ - __slots__ = ('message', 'count', 'emoji', 'me') + __slots__ = ('message', 'count', 'emoji', 'me', 'me_burst', 'normal_count', 'burst_count') def __init__(self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None): self.message: Message = message self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data['emoji']) self.count: int = data.get('count', 1) self.me: bool = data['me'] + details = data.get('count_details', {}) + self.normal_count: int = details.get('normal', 0) + self.burst_count: int = details.get('burst', 0) + self.me_burst: bool = data.get('me_burst', False) def is_custom_emoji(self) -> bool: """:class:`bool`: If this is a custom emoji.""" diff --git a/discord/types/components.py b/discord/types/components.py index f1790ff35985..218f5cef07bf 100644 --- a/discord/types/components.py +++ b/discord/types/components.py @@ -33,6 +33,7 @@ ComponentType = Literal[1, 2, 3, 4] ButtonStyle = Literal[1, 2, 3, 4, 5] TextStyle = Literal[1, 2] +DefaultValueType = Literal['user', 'role', 'channel'] class ActionRow(TypedDict): @@ -66,6 +67,11 @@ class SelectComponent(TypedDict): disabled: NotRequired[bool] +class SelectDefaultValues(TypedDict): + id: int + type: DefaultValueType + + class StringSelectComponent(SelectComponent): type: Literal[3] options: NotRequired[List[SelectOption]] @@ -73,19 +79,23 @@ class StringSelectComponent(SelectComponent): class UserSelectComponent(SelectComponent): type: Literal[5] + default_values: NotRequired[List[SelectDefaultValues]] class RoleSelectComponent(SelectComponent): type: Literal[6] + default_values: NotRequired[List[SelectDefaultValues]] class MentionableSelectComponent(SelectComponent): type: Literal[7] + default_values: NotRequired[List[SelectDefaultValues]] class ChannelSelectComponent(SelectComponent): type: Literal[8] channel_types: NotRequired[List[ChannelType]] + default_values: NotRequired[List[SelectDefaultValues]] class TextInput(TypedDict): @@ -104,6 +114,7 @@ class SelectMenu(SelectComponent): type: Literal[3, 5, 6, 7, 8] options: NotRequired[List[SelectOption]] channel_types: NotRequired[List[ChannelType]] + default_values: NotRequired[List[SelectDefaultValues]] ActionRowChildComponent = Union[ButtonComponent, SelectMenu, TextInput] diff --git a/discord/types/gateway.py b/discord/types/gateway.py index 3175fd9f0744..0c50671e1094 100644 --- a/discord/types/gateway.py +++ b/discord/types/gateway.py @@ -101,6 +101,8 @@ class MessageReactionAddEvent(TypedDict): member: NotRequired[MemberWithUser] guild_id: NotRequired[Snowflake] message_author_id: NotRequired[Snowflake] + burst: bool + burst_colors: NotRequired[List[str]] class MessageReactionRemoveEvent(TypedDict): @@ -109,6 +111,7 @@ class MessageReactionRemoveEvent(TypedDict): message_id: Snowflake emoji: PartialEmoji guild_id: NotRequired[Snowflake] + burst: bool class MessageReactionRemoveAllEvent(TypedDict): diff --git a/discord/types/message.py b/discord/types/message.py index 48b301ca2d8e..e1046c82ab32 100644 --- a/discord/types/message.py +++ b/discord/types/message.py @@ -50,10 +50,18 @@ class ChannelMention(TypedDict): name: str +class ReactionCountDetails(TypedDict): + burst: int + normal: int + + class Reaction(TypedDict): count: int me: bool emoji: PartialEmoji + me_burst: bool + count_details: ReactionCountDetails + burst_colors: List[str] class Attachment(TypedDict): diff --git a/discord/ui/dynamic.py b/discord/ui/dynamic.py index 6c20bd7b4463..799fe48fc468 100644 --- a/discord/ui/dynamic.py +++ b/discord/ui/dynamic.py @@ -110,7 +110,7 @@ def __init__( raise TypeError('item must be dispatchable, e.g. not a URL button') if not self.template.match(self.custom_id): - raise ValueError(f'item custom_id must match the template {self.template.pattern!r}') + raise ValueError(f'item custom_id {self.custom_id!r} must match the template {self.template.pattern!r}') @property def template(self) -> re.Pattern[str]: diff --git a/discord/ui/select.py b/discord/ui/select.py index 222596075c97..c20e4be6e849 100644 --- a/discord/ui/select.py +++ b/discord/ui/select.py @@ -22,21 +22,42 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -from typing import Any, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Callable, Union, Dict, overload +from typing import ( + Any, + List, + Literal, + Optional, + TYPE_CHECKING, + Tuple, + Type, + TypeVar, + Callable, + Union, + Dict, + overload, + Sequence, +) from contextvars import ContextVar import inspect import os from .item import Item, ItemCallbackType -from ..enums import ChannelType, ComponentType +from ..enums import ChannelType, ComponentType, SelectDefaultValueType from ..partial_emoji import PartialEmoji from ..emoji import Emoji -from ..utils import MISSING +from ..utils import MISSING, _human_join from ..components import ( SelectOption, SelectMenu, + SelectDefaultValue, ) from ..app_commands.namespace import Namespace +from ..member import Member +from ..object import Object +from ..role import Role +from ..user import User, ClientUser +from ..abc import GuildChannel +from ..threads import Thread __all__ = ( 'Select', @@ -48,15 +69,12 @@ ) if TYPE_CHECKING: - from typing_extensions import TypeAlias, Self + from typing_extensions import TypeAlias, Self, TypeGuard from .view import View from ..types.components import SelectMenu as SelectMenuPayload from ..types.interactions import SelectMessageComponentInteractionData from ..app_commands import AppCommandChannel, AppCommandThread - from ..member import Member - from ..role import Role - from ..user import User from ..interactions import Interaction ValidSelectType: TypeAlias = Literal[ @@ -69,6 +87,18 @@ PossibleValue: TypeAlias = Union[ str, User, Member, Role, AppCommandChannel, AppCommandThread, Union[Role, Member], Union[Role, User] ] + ValidDefaultValues: TypeAlias = Union[ + SelectDefaultValue, + Object, + Role, + Member, + ClientUser, + User, + GuildChannel, + AppCommandChannel, + AppCommandThread, + Thread, + ] V = TypeVar('V', bound='View', covariant=True) BaseSelectT = TypeVar('BaseSelectT', bound='BaseSelect[Any]') @@ -78,10 +108,81 @@ ChannelSelectT = TypeVar('ChannelSelectT', bound='ChannelSelect[Any]') MentionableSelectT = TypeVar('MentionableSelectT', bound='MentionableSelect[Any]') SelectCallbackDecorator: TypeAlias = Callable[[ItemCallbackType[V, BaseSelectT]], BaseSelectT] +DefaultSelectComponentTypes = Literal[ + ComponentType.user_select, + ComponentType.role_select, + ComponentType.channel_select, + ComponentType.mentionable_select, +] selected_values: ContextVar[Dict[str, List[PossibleValue]]] = ContextVar('selected_values') +def _is_valid_object_type( + obj: Any, + component_type: DefaultSelectComponentTypes, + type_to_supported_classes: Dict[ValidSelectType, Tuple[Type[ValidDefaultValues], ...]], +) -> TypeGuard[Type[ValidDefaultValues]]: + return issubclass(obj, type_to_supported_classes[component_type]) + + +def _handle_select_defaults( + defaults: Sequence[ValidDefaultValues], component_type: DefaultSelectComponentTypes +) -> List[SelectDefaultValue]: + if not defaults or defaults is MISSING: + return [] + + from ..app_commands import AppCommandChannel, AppCommandThread + + cls_to_type: Dict[Type[ValidDefaultValues], SelectDefaultValueType] = { + User: SelectDefaultValueType.user, + Member: SelectDefaultValueType.user, + ClientUser: SelectDefaultValueType.user, + Role: SelectDefaultValueType.role, + GuildChannel: SelectDefaultValueType.channel, + AppCommandChannel: SelectDefaultValueType.channel, + AppCommandThread: SelectDefaultValueType.channel, + Thread: SelectDefaultValueType.channel, + } + type_to_supported_classes: Dict[ValidSelectType, Tuple[Type[ValidDefaultValues], ...]] = { + ComponentType.user_select: (User, ClientUser, Member, Object), + ComponentType.role_select: (Role, Object), + ComponentType.channel_select: (GuildChannel, AppCommandChannel, AppCommandThread, Thread, Object), + ComponentType.mentionable_select: (User, ClientUser, Member, Role, Object), + } + + values: List[SelectDefaultValue] = [] + for obj in defaults: + if isinstance(obj, SelectDefaultValue): + values.append(obj) + continue + + object_type = obj.__class__ if not isinstance(obj, Object) else obj.type + + if not _is_valid_object_type(object_type, component_type, type_to_supported_classes): + supported_classes = _human_join([c.__name__ for c in type_to_supported_classes[component_type]]) + raise TypeError(f'Expected an instance of {supported_classes} not {object_type.__name__}') + + if object_type is Object: + if component_type is ComponentType.mentionable_select: + raise ValueError( + 'Object must have a type specified for the chosen select type. Please pass one using the `type`` kwarg.' + ) + elif component_type is ComponentType.user_select: + object_type = User + elif component_type is ComponentType.role_select: + object_type = Role + elif component_type is ComponentType.channel_select: + object_type = GuildChannel + + if issubclass(object_type, GuildChannel): + object_type = GuildChannel + + values.append(SelectDefaultValue(id=obj.id, type=cls_to_type[object_type])) + + return values + + class BaseSelect(Item[V]): """The base Select model that all other Select models inherit from. @@ -115,6 +216,13 @@ class BaseSelect(Item[V]): 'max_values', 'disabled', ) + __component_attributes__: Tuple[str, ...] = ( + 'custom_id', + 'placeholder', + 'min_values', + 'max_values', + 'disabled', + ) def __init__( self, @@ -128,6 +236,7 @@ def __init__( disabled: bool = False, options: List[SelectOption] = MISSING, channel_types: List[ChannelType] = MISSING, + default_values: Sequence[SelectDefaultValue] = MISSING, ) -> None: super().__init__() self._provided_custom_id = custom_id is not MISSING @@ -144,6 +253,7 @@ def __init__( disabled=disabled, channel_types=[] if channel_types is MISSING else channel_types, options=[] if options is MISSING else options, + default_values=[] if default_values is MISSING else default_values, ) self.row = row @@ -233,10 +343,16 @@ def is_dispatchable(self) -> bool: @classmethod def from_component(cls, component: SelectMenu) -> Self: - return cls( - **{k: getattr(component, k) for k in cls.__item_repr_attributes__}, - row=None, - ) + type_to_cls: Dict[ComponentType, Type[BaseSelect[Any]]] = { + ComponentType.string_select: Select, + ComponentType.user_select: UserSelect, + ComponentType.role_select: RoleSelect, + ComponentType.channel_select: ChannelSelect, + ComponentType.mentionable_select: MentionableSelect, + } + constructor = type_to_cls.get(component.type, Select) + kwrgs = {key: getattr(component, key) for key in constructor.__component_attributes__} + return constructor(**kwrgs) class Select(BaseSelect[V]): @@ -270,7 +386,7 @@ class Select(BaseSelect[V]): ordering. The row number must be between 0 and 4 (i.e. zero indexed). """ - __item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('options',) + __component_attributes__ = BaseSelect.__component_attributes__ + ('options',) def __init__( self, @@ -409,6 +525,10 @@ class UserSelect(BaseSelect[V]): Defaults to 1 and must be between 1 and 25. disabled: :class:`bool` Whether the select is disabled or not. + default_values: Sequence[:class:`~discord.abc.Snowflake`] + A list of objects representing the users that should be selected by default. + + .. versionadded:: 2.4 row: Optional[:class:`int`] The relative row this select menu belongs to. A Discord component can only have 5 rows. By default, items are arranged automatically into those 5 rows. If you'd @@ -417,6 +537,8 @@ class UserSelect(BaseSelect[V]): ordering. The row number must be between 0 and 4 (i.e. zero indexed). """ + __component_attributes__ = BaseSelect.__component_attributes__ + ('default_values',) + def __init__( self, *, @@ -426,6 +548,7 @@ def __init__( max_values: int = 1, disabled: bool = False, row: Optional[int] = None, + default_values: Sequence[ValidDefaultValues] = MISSING, ) -> None: super().__init__( self.type, @@ -435,6 +558,7 @@ def __init__( max_values=max_values, disabled=disabled, row=row, + default_values=_handle_select_defaults(default_values, self.type), ) @property @@ -455,6 +579,18 @@ def values(self) -> List[Union[Member, User]]: """ return super().values # type: ignore + @property + def default_values(self) -> List[SelectDefaultValue]: + """List[:class:`discord.SelectDefaultValue`]: A list of default values for the select menu. + + .. versionadded:: 2.4 + """ + return self._underlying.default_values + + @default_values.setter + def default_values(self, value: Sequence[ValidDefaultValues]) -> None: + self._underlying.default_values = _handle_select_defaults(value, self.type) + class RoleSelect(BaseSelect[V]): """Represents a UI select menu with a list of predefined options with the current roles of the guild. @@ -478,6 +614,10 @@ class RoleSelect(BaseSelect[V]): Defaults to 1 and must be between 1 and 25. disabled: :class:`bool` Whether the select is disabled or not. + default_values: Sequence[:class:`~discord.abc.Snowflake`] + A list of objects representing the users that should be selected by default. + + .. versionadded:: 2.4 row: Optional[:class:`int`] The relative row this select menu belongs to. A Discord component can only have 5 rows. By default, items are arranged automatically into those 5 rows. If you'd @@ -486,6 +626,8 @@ class RoleSelect(BaseSelect[V]): ordering. The row number must be between 0 and 4 (i.e. zero indexed). """ + __component_attributes__ = BaseSelect.__component_attributes__ + ('default_values',) + def __init__( self, *, @@ -495,6 +637,7 @@ def __init__( max_values: int = 1, disabled: bool = False, row: Optional[int] = None, + default_values: Sequence[ValidDefaultValues] = MISSING, ) -> None: super().__init__( self.type, @@ -504,6 +647,7 @@ def __init__( max_values=max_values, disabled=disabled, row=row, + default_values=_handle_select_defaults(default_values, self.type), ) @property @@ -516,6 +660,18 @@ def values(self) -> List[Role]: """List[:class:`discord.Role`]: A list of roles that have been selected by the user.""" return super().values # type: ignore + @property + def default_values(self) -> List[SelectDefaultValue]: + """List[:class:`discord.SelectDefaultValue`]: A list of default values for the select menu. + + .. versionadded:: 2.4 + """ + return self._underlying.default_values + + @default_values.setter + def default_values(self, value: Sequence[ValidDefaultValues]) -> None: + self._underlying.default_values = _handle_select_defaults(value, self.type) + class MentionableSelect(BaseSelect[V]): """Represents a UI select menu with a list of predefined options with the current members and roles in the guild. @@ -542,6 +698,11 @@ class MentionableSelect(BaseSelect[V]): Defaults to 1 and must be between 1 and 25. disabled: :class:`bool` Whether the select is disabled or not. + default_values: Sequence[:class:`~discord.abc.Snowflake`] + A list of objects representing the users/roles that should be selected by default. + if :class:`.Object` is passed, then the type must be specified in the constructor. + + .. versionadded:: 2.4 row: Optional[:class:`int`] The relative row this select menu belongs to. A Discord component can only have 5 rows. By default, items are arranged automatically into those 5 rows. If you'd @@ -550,6 +711,8 @@ class MentionableSelect(BaseSelect[V]): ordering. The row number must be between 0 and 4 (i.e. zero indexed). """ + __component_attributes__ = BaseSelect.__component_attributes__ + ('default_values',) + def __init__( self, *, @@ -559,6 +722,7 @@ def __init__( max_values: int = 1, disabled: bool = False, row: Optional[int] = None, + default_values: Sequence[ValidDefaultValues] = MISSING, ) -> None: super().__init__( self.type, @@ -568,6 +732,7 @@ def __init__( max_values=max_values, disabled=disabled, row=row, + default_values=_handle_select_defaults(default_values, self.type), ) @property @@ -588,6 +753,18 @@ def values(self) -> List[Union[Member, User, Role]]: """ return super().values # type: ignore + @property + def default_values(self) -> List[SelectDefaultValue]: + """List[:class:`discord.SelectDefaultValue`]: A list of default values for the select menu. + + .. versionadded:: 2.4 + """ + return self._underlying.default_values + + @default_values.setter + def default_values(self, value: Sequence[ValidDefaultValues]) -> None: + self._underlying.default_values = _handle_select_defaults(value, self.type) + class ChannelSelect(BaseSelect[V]): """Represents a UI select menu with a list of predefined options with the current channels in the guild. @@ -613,6 +790,10 @@ class ChannelSelect(BaseSelect[V]): Defaults to 1 and must be between 1 and 25. disabled: :class:`bool` Whether the select is disabled or not. + default_values: Sequence[:class:`~discord.abc.Snowflake`] + A list of objects representing the channels that should be selected by default. + + .. versionadded:: 2.4 row: Optional[:class:`int`] The relative row this select menu belongs to. A Discord component can only have 5 rows. By default, items are arranged automatically into those 5 rows. If you'd @@ -621,7 +802,10 @@ class ChannelSelect(BaseSelect[V]): ordering. The row number must be between 0 and 4 (i.e. zero indexed). """ - __item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('channel_types',) + __component_attributes__ = BaseSelect.__component_attributes__ + ( + 'channel_types', + 'default_values', + ) def __init__( self, @@ -633,6 +817,7 @@ def __init__( max_values: int = 1, disabled: bool = False, row: Optional[int] = None, + default_values: Sequence[ValidDefaultValues] = MISSING, ) -> None: super().__init__( self.type, @@ -643,6 +828,7 @@ def __init__( disabled=disabled, row=row, channel_types=channel_types, + default_values=_handle_select_defaults(default_values, self.type), ) @property @@ -669,6 +855,18 @@ def values(self) -> List[Union[AppCommandChannel, AppCommandThread]]: """List[Union[:class:`~discord.app_commands.AppCommandChannel`, :class:`~discord.app_commands.AppCommandThread`]]: A list of channels selected by the user.""" return super().values # type: ignore + @property + def default_values(self) -> List[SelectDefaultValue]: + """List[:class:`discord.SelectDefaultValue`]: A list of default values for the select menu. + + .. versionadded:: 2.4 + """ + return self._underlying.default_values + + @default_values.setter + def default_values(self, value: Sequence[ValidDefaultValues]) -> None: + self._underlying.default_values = _handle_select_defaults(value, self.type) + @overload def select( @@ -697,6 +895,7 @@ def select( min_values: int = ..., max_values: int = ..., disabled: bool = ..., + default_values: Sequence[ValidDefaultValues] = ..., row: Optional[int] = ..., ) -> SelectCallbackDecorator[V, UserSelectT]: ... @@ -713,6 +912,7 @@ def select( min_values: int = ..., max_values: int = ..., disabled: bool = ..., + default_values: Sequence[ValidDefaultValues] = ..., row: Optional[int] = ..., ) -> SelectCallbackDecorator[V, RoleSelectT]: ... @@ -729,6 +929,7 @@ def select( min_values: int = ..., max_values: int = ..., disabled: bool = ..., + default_values: Sequence[ValidDefaultValues] = ..., row: Optional[int] = ..., ) -> SelectCallbackDecorator[V, ChannelSelectT]: ... @@ -745,6 +946,7 @@ def select( min_values: int = ..., max_values: int = ..., disabled: bool = ..., + default_values: Sequence[ValidDefaultValues] = ..., row: Optional[int] = ..., ) -> SelectCallbackDecorator[V, MentionableSelectT]: ... @@ -760,6 +962,7 @@ def select( min_values: int = 1, max_values: int = 1, disabled: bool = False, + default_values: Sequence[ValidDefaultValues] = MISSING, row: Optional[int] = None, ) -> SelectCallbackDecorator[V, BaseSelectT]: """A decorator that attaches a select menu to a component. @@ -831,6 +1034,11 @@ async def select_channels(self, interaction: discord.Interaction, select: Channe with :class:`ChannelSelect` instances. disabled: :class:`bool` Whether the select is disabled or not. Defaults to ``False``. + default_values: Sequence[:class:`~discord.abc.Snowflake`] + A list of objects representing the default values for the select menu. This cannot be used with regular :class:`Select` instances. + If ``cls`` is :class:`MentionableSelect` and :class:`.Object` is passed, then the type must be specified in the constructor. + + .. versionadded:: 2.4 """ def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, BaseSelectT]: @@ -838,8 +1046,8 @@ def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, Bas raise TypeError('select function must be a coroutine function') callback_cls = getattr(cls, '__origin__', cls) if not issubclass(callback_cls, BaseSelect): - supported_classes = ", ".join(["ChannelSelect", "MentionableSelect", "RoleSelect", "Select", "UserSelect"]) - raise TypeError(f'cls must be one of {supported_classes} or a subclass of one of them, not {cls!r}.') + supported_classes = ', '.join(['ChannelSelect', 'MentionableSelect', 'RoleSelect', 'Select', 'UserSelect']) + raise TypeError(f'cls must be one of {supported_classes} or a subclass of one of them, not {cls.__name__}.') func.__discord_ui_model_type__ = callback_cls func.__discord_ui_model_kwargs__ = { @@ -854,6 +1062,24 @@ def decorator(func: ItemCallbackType[V, BaseSelectT]) -> ItemCallbackType[V, Bas func.__discord_ui_model_kwargs__['options'] = options if issubclass(callback_cls, ChannelSelect): func.__discord_ui_model_kwargs__['channel_types'] = channel_types + if not issubclass(callback_cls, Select): + cls_to_type: Dict[ + Type[BaseSelect], + Literal[ + ComponentType.user_select, + ComponentType.channel_select, + ComponentType.role_select, + ComponentType.mentionable_select, + ], + ] = { + UserSelect: ComponentType.user_select, + RoleSelect: ComponentType.role_select, + MentionableSelect: ComponentType.mentionable_select, + ChannelSelect: ComponentType.channel_select, + } + func.__discord_ui_model_kwargs__['default_values'] = ( + MISSING if default_values is MISSING else _handle_select_defaults(default_values, cls_to_type[callback_cls]) + ) return func diff --git a/discord/ui/view.py b/discord/ui/view.py index 883b87c84435..7c3cc3b8cb4c 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -78,9 +78,10 @@ def _component_to_item(component: Component) -> Item: return Button.from_component(component) if isinstance(component, SelectComponent): - from .select import Select + from .select import BaseSelect + + return BaseSelect.from_component(component) - return Select.from_component(component) return Item.from_component(component) diff --git a/discord/utils.py b/discord/utils.py index a3f830019114..33a4020a2504 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -1380,3 +1380,17 @@ def _shorten( def _to_kebab_case(text: str) -> str: return CAMEL_CASE_REGEX.sub('-', text).lower() + + +def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'or') -> str: + size = len(seq) + if size == 0: + return '' + + if size == 1: + return seq[0] + + if size == 2: + return f'{seq[0]} {final} {seq[1]}' + + return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}' diff --git a/discord/voice_client.py b/discord/voice_client.py index 8309218a1a22..3c5699acddc6 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -20,40 +20,24 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - - -Some documentation to refer to: - -- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID. -- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE. -- We pull the session_id from VOICE_STATE_UPDATE. -- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE. -- Then we initiate the voice web socket (vWS) pointing to the endpoint. -- We send opcode 0 with the user_id, server_id, session_id and token using the vWS. -- The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval. -- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE. -- Then we send our IP and port via vWS with opcode 1. -- When that's all done, we receive opcode 4 from the vWS. -- Finally we can transmit data to endpoint:port. """ from __future__ import annotations import asyncio -import socket import logging import struct -import threading from typing import Any, Callable, List, Optional, TYPE_CHECKING, Tuple, Union -from . import opus, utils -from .backoff import ExponentialBackoff +from . import opus from .gateway import * -from .errors import ClientException, ConnectionClosed +from .errors import ClientException from .player import AudioPlayer, AudioSource from .utils import MISSING +from .voice_state import VoiceConnectionState if TYPE_CHECKING: + from .gateway import DiscordVoiceWebSocket from .client import Client from .guild import Guild from .state import ConnectionState @@ -226,12 +210,6 @@ class VoiceClient(VoiceProtocol): """ channel: VocalGuildChannel - endpoint_ip: str - voice_port: int - ip: str - port: int - secret_key: List[int] - ssrc: int def __init__(self, client: Client, channel: abc.Connectable) -> None: if not has_nacl: @@ -239,29 +217,18 @@ def __init__(self, client: Client, channel: abc.Connectable) -> None: super().__init__(client, channel) state = client._connection - self.token: str = MISSING self.server_id: int = MISSING self.socket = MISSING self.loop: asyncio.AbstractEventLoop = state.loop self._state: ConnectionState = state - # this will be used in the AudioPlayer thread - self._connected: threading.Event = threading.Event() - self._handshaking: bool = False - self._potentially_reconnecting: bool = False - self._voice_state_complete: asyncio.Event = asyncio.Event() - self._voice_server_complete: asyncio.Event = asyncio.Event() - - self.mode: str = MISSING - self._connections: int = 0 self.sequence: int = 0 self.timestamp: int = 0 - self.timeout: float = 0 - self._runner: asyncio.Task = MISSING self._player: Optional[AudioPlayer] = None self.encoder: Encoder = MISSING self._lite_nonce: int = 0 - self.ws: DiscordVoiceWebSocket = MISSING + + self._connection: VoiceConnectionState = self.create_connection_state() warn_nacl: bool = not has_nacl supported_modes: Tuple[SupportedModes, ...] = ( @@ -280,6 +247,38 @@ def user(self) -> ClientUser: """:class:`ClientUser`: The user connected to voice (i.e. ourselves).""" return self._state.user # type: ignore + @property + def session_id(self) -> Optional[str]: + return self._connection.session_id + + @property + def token(self) -> Optional[str]: + return self._connection.token + + @property + def endpoint(self) -> Optional[str]: + return self._connection.endpoint + + @property + def ssrc(self) -> int: + return self._connection.ssrc + + @property + def mode(self) -> SupportedModes: + return self._connection.mode + + @property + def secret_key(self) -> List[int]: + return self._connection.secret_key + + @property + def ws(self) -> DiscordVoiceWebSocket: + return self._connection.ws + + @property + def timeout(self) -> float: + return self._connection.timeout + def checked_add(self, attr: str, value: int, limit: int) -> None: val = getattr(self, attr) if val + value > limit: @@ -289,149 +288,23 @@ def checked_add(self, attr: str, value: int, limit: int) -> None: # connection related + def create_connection_state(self) -> VoiceConnectionState: + return VoiceConnectionState(self) + async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: - self.session_id: str = data['session_id'] - channel_id = data['channel_id'] - - if not self._handshaking or self._potentially_reconnecting: - # If we're done handshaking then we just need to update ourselves - # If we're potentially reconnecting due to a 4014, then we need to differentiate - # a channel move and an actual force disconnect - if channel_id is None: - # We're being disconnected so cleanup - await self.disconnect() - else: - self.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore - else: - self._voice_state_complete.set() + await self._connection.voice_state_update(data) async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None: - if self._voice_server_complete.is_set(): - _log.warning('Ignoring extraneous voice server update.') - return - - self.token = data['token'] - self.server_id = int(data['guild_id']) - endpoint = data.get('endpoint') - - if endpoint is None or self.token is None: - _log.warning( - 'Awaiting endpoint... This requires waiting. ' - 'If timeout occurred considering raising the timeout and reconnecting.' - ) - return - - self.endpoint, _, _ = endpoint.rpartition(':') - if self.endpoint.startswith('wss://'): - # Just in case, strip it off since we're going to add it later - self.endpoint: str = self.endpoint[6:] - - # This gets set later - self.endpoint_ip = MISSING - - self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.socket.setblocking(False) - - if not self._handshaking: - # If we're not handshaking then we need to terminate our previous connection in the websocket - await self.ws.close(4000) - return - - self._voice_server_complete.set() - - async def voice_connect(self, self_deaf: bool = False, self_mute: bool = False) -> None: - await self.channel.guild.change_voice_state(channel=self.channel, self_deaf=self_deaf, self_mute=self_mute) - - async def voice_disconnect(self) -> None: - _log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id) - await self.channel.guild.change_voice_state(channel=None) - - def prepare_handshake(self) -> None: - self._voice_state_complete.clear() - self._voice_server_complete.clear() - self._handshaking = True - _log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1) - self._connections += 1 - - def finish_handshake(self) -> None: - _log.info('Voice handshake complete. Endpoint found %s', self.endpoint) - self._handshaking = False - self._voice_server_complete.clear() - self._voice_state_complete.clear() - - async def connect_websocket(self) -> DiscordVoiceWebSocket: - ws = await DiscordVoiceWebSocket.from_client(self) - self._connected.clear() - while ws.secret_key is None: - await ws.poll_event() - self._connected.set() - return ws + await self._connection.voice_server_update(data) async def connect(self, *, reconnect: bool, timeout: float, self_deaf: bool = False, self_mute: bool = False) -> None: - _log.info('Connecting to voice...') - self.timeout = timeout - - for i in range(5): - self.prepare_handshake() - - # This has to be created before we start the flow. - futures = [ - self._voice_state_complete.wait(), - self._voice_server_complete.wait(), - ] - - # Start the connection flow - await self.voice_connect(self_deaf=self_deaf, self_mute=self_mute) - - try: - await utils.sane_wait_for(futures, timeout=timeout) - except asyncio.TimeoutError: - await self.disconnect(force=True) - raise - - self.finish_handshake() - - try: - self.ws = await self.connect_websocket() - break - except (ConnectionClosed, asyncio.TimeoutError): - if reconnect: - _log.exception('Failed to connect to voice... Retrying...') - await asyncio.sleep(1 + i * 2.0) - await self.voice_disconnect() - continue - else: - raise - - if self._runner is MISSING: - self._runner = self.client.loop.create_task(self.poll_voice_ws(reconnect)) - - async def potential_reconnect(self) -> bool: - # Attempt to stop the player thread from playing early - self._connected.clear() - self.prepare_handshake() - self._potentially_reconnecting = True - try: - # We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected - await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout) - except asyncio.TimeoutError: - self._potentially_reconnecting = False - await self.disconnect(force=True) - return False - - self.finish_handshake() - self._potentially_reconnecting = False - - if self.ws: - _log.debug("Closing existing voice websocket") - await self.ws.close() + await self._connection.connect( + reconnect=reconnect, timeout=timeout, self_deaf=self_deaf, self_mute=self_mute, resume=False + ) - try: - self.ws = await self.connect_websocket() - except (ConnectionClosed, asyncio.TimeoutError): - return False - else: - return True + def wait_until_connected(self, timeout: Optional[float] = 30.0) -> bool: + self._connection.wait(timeout) + return self._connection.is_connected() @property def latency(self) -> float: @@ -442,7 +315,7 @@ def latency(self) -> float: .. versionadded:: 1.4 """ - ws = self.ws + ws = self._connection.ws return float("inf") if not ws else ws.latency @property @@ -451,72 +324,19 @@ def average_latency(self) -> float: .. versionadded:: 1.4 """ - ws = self.ws + ws = self._connection.ws return float("inf") if not ws else ws.average_latency - async def poll_voice_ws(self, reconnect: bool) -> None: - backoff = ExponentialBackoff() - while True: - try: - await self.ws.poll_event() - except (ConnectionClosed, asyncio.TimeoutError) as exc: - if isinstance(exc, ConnectionClosed): - # The following close codes are undocumented so I will document them here. - # 1000 - normal closure (obviously) - # 4014 - voice channel has been deleted. - # 4015 - voice server has crashed - if exc.code in (1000, 4015): - _log.info('Disconnecting from voice normally, close code %d.', exc.code) - await self.disconnect() - break - if exc.code == 4014: - _log.info('Disconnected from voice by force... potentially reconnecting.') - successful = await self.potential_reconnect() - if not successful: - _log.info('Reconnect was unsuccessful, disconnecting from voice normally...') - await self.disconnect() - break - else: - continue - - if not reconnect: - await self.disconnect() - raise - - retry = backoff.delay() - _log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry) - self._connected.clear() - await asyncio.sleep(retry) - await self.voice_disconnect() - try: - await self.connect(reconnect=True, timeout=self.timeout) - except asyncio.TimeoutError: - # at this point we've retried 5 times... let's continue the loop. - _log.warning('Could not connect to voice... Retrying...') - continue - async def disconnect(self, *, force: bool = False) -> None: """|coro| Disconnects this voice client from voice. """ - if not force and not self.is_connected(): - return - self.stop() - self._connected.clear() - - try: - if self.ws: - await self.ws.close() - - await self.voice_disconnect() - finally: - self.cleanup() - if self.socket: - self.socket.close() + await self._connection.disconnect(force=force) + self.cleanup() - async def move_to(self, channel: Optional[abc.Snowflake]) -> None: + async def move_to(self, channel: Optional[abc.Snowflake], *, timeout: Optional[float] = 30.0) -> None: """|coro| Moves you to a different voice channel. @@ -525,12 +345,21 @@ async def move_to(self, channel: Optional[abc.Snowflake]) -> None: ----------- channel: Optional[:class:`abc.Snowflake`] The channel to move to. Must be a voice channel. + timeout: Optional[:class:`float`] + How long to wait for the move to complete. + + .. versionadded:: 2.4 + + Raises + ------- + asyncio.TimeoutError + The move did not complete in time, but may still be ongoing. """ - await self.channel.guild.change_voice_state(channel=channel) + await self._connection.move_to(channel, timeout) def is_connected(self) -> bool: """Indicates if the voice client is connected to voice.""" - return self._connected.is_set() + return self._connection.is_connected() # audio related @@ -703,7 +532,7 @@ def source(self, value: AudioSource) -> None: if self._player is None: raise ValueError('Not playing anything.') - self._player._set_source(value) + self._player.set_source(value) def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: """Sends an audio packet composed of the data. @@ -732,8 +561,8 @@ def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: encoded_data = data packet = self._get_voice_packet(encoded_data) try: - self.socket.sendto(packet, (self.endpoint_ip, self.voice_port)) - except BlockingIOError: - _log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp) + self._connection.send_packet(packet) + except OSError: + _log.info('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp) self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295) diff --git a/discord/voice_state.py b/discord/voice_state.py new file mode 100644 index 000000000000..1dda1e5d9797 --- /dev/null +++ b/discord/voice_state.py @@ -0,0 +1,612 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. + + +Some documentation to refer to: + +- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID. +- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE. +- We pull the session_id from VOICE_STATE_UPDATE. +- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE. +- Then we initiate the voice web socket (vWS) pointing to the endpoint. +- We send opcode 0 with the user_id, server_id, session_id and token using the vWS. +- The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval. +- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE. +- Then we send our IP and port via vWS with opcode 1. +- When that's all done, we receive opcode 4 from the vWS. +- Finally we can transmit data to endpoint:port. +""" + +from __future__ import annotations + +import select +import socket +import asyncio +import logging +import threading + +import async_timeout + +from typing import TYPE_CHECKING, Optional, Dict, List, Callable, Coroutine, Any, Tuple + +from .enums import Enum +from .utils import MISSING, sane_wait_for +from .errors import ConnectionClosed +from .backoff import ExponentialBackoff +from .gateway import DiscordVoiceWebSocket + +if TYPE_CHECKING: + from . import abc + from .guild import Guild + from .user import ClientUser + from .member import VoiceState + from .voice_client import VoiceClient + + from .types.voice import ( + GuildVoiceState as GuildVoiceStatePayload, + VoiceServerUpdate as VoiceServerUpdatePayload, + SupportedModes, + ) + + WebsocketHook = Optional[Callable[[DiscordVoiceWebSocket, Dict[str, Any]], Coroutine[Any, Any, Any]]] + SocketReaderCallback = Callable[[bytes], Any] + + +__all__ = ('VoiceConnectionState',) + +_log = logging.getLogger(__name__) + + +class SocketReader(threading.Thread): + def __init__(self, state: VoiceConnectionState) -> None: + super().__init__(daemon=True, name=f'voice-socket-reader:{id(self):#x}') + self.state: VoiceConnectionState = state + self._callbacks: List[SocketReaderCallback] = [] + self._running = threading.Event() + self._end = threading.Event() + # If we have paused reading due to having no callbacks + self._idle_paused: bool = True + + def register(self, callback: SocketReaderCallback) -> None: + self._callbacks.append(callback) + if self._idle_paused: + self._idle_paused = False + self._running.set() + + def unregister(self, callback: SocketReaderCallback) -> None: + try: + self._callbacks.remove(callback) + except ValueError: + pass + else: + if not self._callbacks and self._running.is_set(): + # If running is not set, we are either explicitly paused and + # should be explicitly resumed, or we are already idle paused + self._idle_paused = True + self._running.clear() + + def pause(self) -> None: + self._idle_paused = False + self._running.clear() + + def resume(self, *, force: bool = False) -> None: + if self._running.is_set(): + return + # Don't resume if there are no callbacks registered + if not force and not self._callbacks: + # We tried to resume but there was nothing to do, so resume when ready + self._idle_paused = True + return + self._idle_paused = False + self._running.set() + + def stop(self) -> None: + self._end.set() + self._running.set() + + def run(self) -> None: + self._end.clear() + self._running.set() + try: + self._do_run() + except Exception: + _log.exception('Error in %s', self) + finally: + self.stop() + self._running.clear() + self._callbacks.clear() + + def _do_run(self) -> None: + while not self._end.is_set(): + if not self._running.is_set(): + self._running.wait() + continue + + # Since this socket is a non blocking socket, select has to be used to wait on it for reading. + try: + readable, _, _ = select.select([self.state.socket], [], [], 30) + except (ValueError, TypeError): + # The socket is either closed or doesn't exist at the moment + continue + + if not readable: + continue + + try: + data = self.state.socket.recv(2048) + except OSError: + _log.debug('Error reading from socket in %s, this should be safe to ignore', self, exc_info=True) + else: + for cb in self._callbacks: + try: + cb(data) + except Exception: + _log.exception('Error calling %s in %s', cb, self) + + +class ConnectionFlowState(Enum): + """Enum representing voice connection flow state.""" + + # fmt: off + disconnected = 0 + set_guild_voice_state = 1 + got_voice_state_update = 2 + got_voice_server_update = 3 + got_both_voice_updates = 4 + websocket_connected = 5 + got_websocket_ready = 6 + got_ip_discovery = 7 + connected = 8 + # fmt: on + + +class VoiceConnectionState: + """Represents the internal state of a voice connection.""" + + def __init__(self, voice_client: VoiceClient, *, hook: Optional[WebsocketHook] = None) -> None: + self.voice_client = voice_client + self.hook = hook + + self.timeout: float = 30.0 + self.reconnect: bool = True + self.self_deaf: bool = False + self.self_mute: bool = False + self.token: Optional[str] = None + self.session_id: Optional[str] = None + self.endpoint: Optional[str] = None + self.endpoint_ip: Optional[str] = None + self.server_id: Optional[int] = None + self.ip: Optional[str] = None + self.port: Optional[int] = None + self.voice_port: Optional[int] = None + self.secret_key: List[int] = MISSING + self.ssrc: int = MISSING + self.mode: SupportedModes = MISSING + self.socket: socket.socket = MISSING + self.ws: DiscordVoiceWebSocket = MISSING + + self._state: ConnectionFlowState = ConnectionFlowState.disconnected + self._expecting_disconnect: bool = False + self._connected = threading.Event() + self._state_event = asyncio.Event() + self._runner: Optional[asyncio.Task] = None + self._connector: Optional[asyncio.Task] = None + self._socket_reader = SocketReader(self) + self._socket_reader.start() + + @property + def state(self) -> ConnectionFlowState: + return self._state + + @state.setter + def state(self, state: ConnectionFlowState) -> None: + if state is not self._state: + _log.debug('Connection state changed to %s', state.name) + self._state = state + self._state_event.set() + self._state_event.clear() + + if state is ConnectionFlowState.connected: + self._connected.set() + else: + self._connected.clear() + + @property + def guild(self) -> Guild: + return self.voice_client.guild + + @property + def user(self) -> ClientUser: + return self.voice_client.user + + @property + def supported_modes(self) -> Tuple[SupportedModes, ...]: + return self.voice_client.supported_modes + + @property + def self_voice_state(self) -> Optional[VoiceState]: + return self.guild.me.voice + + async def voice_state_update(self, data: GuildVoiceStatePayload) -> None: + channel_id = data['channel_id'] + + if channel_id is None: + # If we know we're going to get a voice_state_update where we have no channel due to + # being in the reconnect flow, we ignore it. Otherwise, it probably wasn't from us. + if self._expecting_disconnect: + self._expecting_disconnect = False + else: + _log.debug('We were externally disconnected from voice.') + await self.disconnect() + + return + + self.session_id = data['session_id'] + + # we got the event while connecting + if self.state in (ConnectionFlowState.set_guild_voice_state, ConnectionFlowState.got_voice_server_update): + if self.state is ConnectionFlowState.set_guild_voice_state: + self.state = ConnectionFlowState.got_voice_state_update + else: + self.state = ConnectionFlowState.got_both_voice_updates + return + + if self.state is ConnectionFlowState.connected: + self.voice_client.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore + + elif self.state is not ConnectionFlowState.disconnected: + if channel_id != self.voice_client.channel.id: + # For some unfortunate reason we were moved during the connection flow + _log.info('Handling channel move while connecting...') + + self.voice_client.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore + + await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_state_update) + await self.connect( + reconnect=self.reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=False, + wait=False, + ) + else: + _log.debug('Ignoring unexpected voice_state_update event') + + async def voice_server_update(self, data: VoiceServerUpdatePayload) -> None: + self.token = data['token'] + self.server_id = int(data['guild_id']) + endpoint = data.get('endpoint') + + if self.token is None or endpoint is None: + _log.warning( + 'Awaiting endpoint... This requires waiting. ' + 'If timeout occurred considering raising the timeout and reconnecting.' + ) + return + + self.endpoint, _, _ = endpoint.rpartition(':') + if self.endpoint.startswith('wss://'): + # Just in case, strip it off since we're going to add it later + self.endpoint = self.endpoint[6:] + + # we got the event while connecting + if self.state in (ConnectionFlowState.set_guild_voice_state, ConnectionFlowState.got_voice_state_update): + # This gets set after READY is received + self.endpoint_ip = MISSING + self._create_socket() + + if self.state is ConnectionFlowState.set_guild_voice_state: + self.state = ConnectionFlowState.got_voice_server_update + else: + self.state = ConnectionFlowState.got_both_voice_updates + + elif self.state is ConnectionFlowState.connected: + _log.debug('Voice server update, closing old voice websocket') + await self.ws.close(4014) + self.state = ConnectionFlowState.got_voice_server_update + + elif self.state is not ConnectionFlowState.disconnected: + _log.debug('Unexpected server update event, attempting to handle') + + await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_server_update) + await self.connect( + reconnect=self.reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=False, + wait=False, + ) + self._create_socket() + + async def connect( + self, *, reconnect: bool, timeout: float, self_deaf: bool, self_mute: bool, resume: bool, wait: bool = True + ) -> None: + if self._connector: + self._connector.cancel() + self._connector = None + + if self._runner: + self._runner.cancel() + self._runner = None + + self.timeout = timeout + self.reconnect = reconnect + self._connector = self.voice_client.loop.create_task( + self._wrap_connect(reconnect, timeout, self_deaf, self_mute, resume), name='Voice connector' + ) + if wait: + await self._connector + + async def _wrap_connect(self, *args: Any) -> None: + try: + await self._connect(*args) + except asyncio.CancelledError: + _log.debug('Cancelling voice connection') + await self.soft_disconnect() + raise + except asyncio.TimeoutError: + _log.info('Timed out connecting to voice') + await self.disconnect() + raise + except Exception: + _log.exception('Error connecting to voice... disconnecting') + await self.disconnect() + raise + + async def _connect(self, reconnect: bool, timeout: float, self_deaf: bool, self_mute: bool, resume: bool) -> None: + _log.info('Connecting to voice...') + + async with async_timeout.timeout(timeout): + for i in range(5): + _log.info('Starting voice handshake... (connection attempt %d)', i + 1) + + await self._voice_connect(self_deaf=self_deaf, self_mute=self_mute) + # Setting this unnecessarily will break reconnecting + if self.state is ConnectionFlowState.disconnected: + self.state = ConnectionFlowState.set_guild_voice_state + + await self._wait_for_state(ConnectionFlowState.got_both_voice_updates) + + _log.info('Voice handshake complete. Endpoint found: %s', self.endpoint) + + try: + self.ws = await self._connect_websocket(resume) + await self._handshake_websocket() + break + except ConnectionClosed: + if reconnect: + wait = 1 + i * 2.0 + _log.exception('Failed to connect to voice... Retrying in %ss...', wait) + await self.disconnect(cleanup=False) + await asyncio.sleep(wait) + continue + else: + await self.disconnect() + raise + + _log.info('Voice connection complete.') + + if not self._runner: + self._runner = self.voice_client.loop.create_task(self._poll_voice_ws(reconnect), name='Voice websocket poller') + + async def disconnect(self, *, force: bool = True, cleanup: bool = True) -> None: + if not force and not self.is_connected(): + return + + try: + if self.ws: + await self.ws.close() + await self._voice_disconnect() + except Exception: + _log.debug('Ignoring exception disconnecting from voice', exc_info=True) + finally: + self.ip = MISSING + self.port = MISSING + self.state = ConnectionFlowState.disconnected + self._socket_reader.pause() + + # Flip the connected event to unlock any waiters + self._connected.set() + self._connected.clear() + + if cleanup: + self._socket_reader.stop() + self.voice_client.cleanup() + + if self.socket: + self.socket.close() + + async def soft_disconnect(self, *, with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates) -> None: + _log.debug('Soft disconnecting from voice') + # Stop the websocket reader because closing the websocket will trigger an unwanted reconnect + if self._runner: + self._runner.cancel() + self._runner = None + + try: + if self.ws: + await self.ws.close() + except Exception: + _log.debug('Ignoring exception soft disconnecting from voice', exc_info=True) + finally: + self.ip = MISSING + self.port = MISSING + self.state = with_state + self._socket_reader.pause() + + if self.socket: + self.socket.close() + + async def move_to(self, channel: Optional[abc.Snowflake], timeout: Optional[float]) -> None: + if channel is None: + await self.disconnect() + return + + previous_state = self.state + # this is only an outgoing ws request + # if it fails, nothing happens and nothing changes (besides self.state) + await self._move_to(channel) + + last_state = self.state + try: + await self.wait_async(timeout) + except asyncio.TimeoutError: + _log.warning('Timed out trying to move to channel %s in guild %s', channel.id, self.guild.id) + if self.state is last_state: + _log.debug('Reverting to previous state %s', previous_state.name) + + self.state = previous_state + + def wait(self, timeout: Optional[float] = None) -> bool: + return self._connected.wait(timeout) + + async def wait_async(self, timeout: Optional[float] = None) -> None: + await self._wait_for_state(ConnectionFlowState.connected, timeout=timeout) + + def is_connected(self) -> bool: + return self.state is ConnectionFlowState.connected + + def send_packet(self, packet: bytes) -> None: + self.socket.sendall(packet) + + def add_socket_listener(self, callback: SocketReaderCallback) -> None: + _log.debug('Registering socket listener callback %s', callback) + self._socket_reader.register(callback) + + def remove_socket_listener(self, callback: SocketReaderCallback) -> None: + _log.debug('Unregistering socket listener callback %s', callback) + self._socket_reader.unregister(callback) + + async def _wait_for_state( + self, state: ConnectionFlowState, *other_states: ConnectionFlowState, timeout: Optional[float] = None + ) -> None: + states = (state, *other_states) + while True: + if self.state in states: + return + await sane_wait_for([self._state_event.wait()], timeout=timeout) + + async def _voice_connect(self, *, self_deaf: bool = False, self_mute: bool = False) -> None: + channel = self.voice_client.channel + await channel.guild.change_voice_state(channel=channel, self_deaf=self_deaf, self_mute=self_mute) + + async def _voice_disconnect(self) -> None: + _log.info( + 'The voice handshake is being terminated for Channel ID %s (Guild ID %s)', + self.voice_client.channel.id, + self.voice_client.guild.id, + ) + self.state = ConnectionFlowState.disconnected + await self.voice_client.channel.guild.change_voice_state(channel=None) + self._expecting_disconnect = True + + async def _connect_websocket(self, resume: bool) -> DiscordVoiceWebSocket: + ws = await DiscordVoiceWebSocket.from_connection_state(self, resume=resume, hook=self.hook) + self.state = ConnectionFlowState.websocket_connected + return ws + + async def _handshake_websocket(self) -> None: + while not self.ip: + await self.ws.poll_event() + self.state = ConnectionFlowState.got_ip_discovery + while self.ws.secret_key is None: + await self.ws.poll_event() + self.state = ConnectionFlowState.connected + + def _create_socket(self) -> None: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.socket.setblocking(False) + self._socket_reader.resume() + + async def _poll_voice_ws(self, reconnect: bool) -> None: + backoff = ExponentialBackoff() + while True: + try: + await self.ws.poll_event() + except asyncio.CancelledError: + return + except (ConnectionClosed, asyncio.TimeoutError) as exc: + if isinstance(exc, ConnectionClosed): + # The following close codes are undocumented so I will document them here. + # 1000 - normal closure (obviously) + # 4014 - we were externally disconnected (voice channel deleted, we were moved, etc) + # 4015 - voice server has crashed + if exc.code in (1000, 4015): + _log.info('Disconnecting from voice normally, close code %d.', exc.code) + await self.disconnect() + break + + if exc.code == 4014: + _log.info('Disconnected from voice by force... potentially reconnecting.') + successful = await self._potential_reconnect() + if not successful: + _log.info('Reconnect was unsuccessful, disconnecting from voice normally...') + await self.disconnect() + break + else: + continue + + _log.debug('Not handling close code %s (%s)', exc.code, exc.reason or 'no reason') + + if not reconnect: + await self.disconnect() + raise + + retry = backoff.delay() + _log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry) + await asyncio.sleep(retry) + await self.disconnect(cleanup=False) + + try: + await self._connect( + reconnect=reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=False, + ) + except asyncio.TimeoutError: + # at this point we've retried 5 times... let's continue the loop. + _log.warning('Could not connect to voice... Retrying...') + continue + + async def _potential_reconnect(self) -> bool: + try: + await self._wait_for_state( + ConnectionFlowState.got_voice_server_update, ConnectionFlowState.got_both_voice_updates, timeout=self.timeout + ) + except asyncio.TimeoutError: + return False + try: + self.ws = await self._connect_websocket(False) + await self._handshake_websocket() + except (ConnectionClosed, asyncio.TimeoutError): + return False + else: + return True + + async def _move_to(self, channel: abc.Snowflake) -> None: + await self.voice_client.channel.guild.change_voice_state(channel=channel) + self.state = ConnectionFlowState.set_guild_voice_state diff --git a/docs/_static/style.css b/docs/_static/style.css index 01017fbc7aaa..4354344ec42b 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -113,6 +113,12 @@ section { --attribute-table-entry-hover-text: var(--blue-2); --attribute-table-badge: var(--grey-7); --highlighted-text: rgb(252, 233, 103); + --tabs--label-text: var(--main-text); + --tabs--label-text--hover: var(--main-text); + --tabs--label-text--active: var(--blue-1); + --tabs--label-text--active--hover: var(--blue-1); + --tabs--label-border--active: var(--blue-1); + --tabs--label-border--active--hover: var(--blue-1); } :root[data-font="serif"] { diff --git a/docs/api.rst b/docs/api.rst index 93029f65e5e9..89b05a8c3678 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1029,6 +1029,12 @@ Reactions Consider using :func:`on_raw_reaction_add` if you need this and do not otherwise want to enable the members intent. + .. warning:: + + This event does not have a way of differentiating whether a reaction is a + burst reaction (also known as "super reaction") or not. If you need this, + consider using :func:`on_raw_reaction_add` instead. + :param reaction: The current state of the reaction. :type reaction: :class:`Reaction` :param user: The user who added the reaction. @@ -1051,6 +1057,12 @@ Reactions Consider using :func:`on_raw_reaction_remove` if you need this and do not want to enable the members intent. + .. warning:: + + This event does not have a way of differentiating whether a reaction is a + burst reaction (also known as "super reaction") or not. If you need this, + consider using :func:`on_raw_reaction_remove` instead. + :param reaction: The current state of the reaction. :type reaction: :class:`Reaction` :param user: The user whose reaction was removed. @@ -3380,6 +3392,24 @@ of :class:`enum.Enum`. Sort forum posts by creation time (from most recent to oldest). +.. class:: SelectDefaultValueType + + Represents the default value of a select menu. + + .. versionadded:: 2.4 + + .. attribute:: user + + The underlying type of the ID is a user. + + .. attribute:: role + + The underlying type of the ID is a role. + + .. attribute:: channel + + The underlying type of the ID is a channel or thread. + .. _discord-api-audit-logs: diff --git a/docs/conf.py b/docs/conf.py index a5fcc17730a4..28b39452cfba 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -37,6 +37,7 @@ 'sphinx.ext.intersphinx', 'sphinx.ext.napoleon', 'sphinxcontrib_trio', + 'sphinx_inline_tabs', 'details', 'exception_hierarchy', 'attributetable', diff --git a/docs/interactions/api.rst b/docs/interactions/api.rst index 8e930c6ef12f..95c1922d181d 100644 --- a/docs/interactions/api.rst +++ b/docs/interactions/api.rst @@ -166,6 +166,14 @@ SelectOption .. autoclass:: SelectOption :members: +SelectDefaultValue +~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: SelectDefaultValue + +.. autoclass:: SelectDefaultValue + :members: + Choice ~~~~~~~ diff --git a/examples/advanced_startup.py b/examples/advanced_startup.py index c521841df9ec..4a452188df7b 100644 --- a/examples/advanced_startup.py +++ b/examples/advanced_startup.py @@ -1,3 +1,5 @@ +# This example requires the 'message_content' privileged intent to function, however your own bot might not. + # This example covers advanced startup options and uses some real world examples for why you may need them. import asyncio @@ -88,9 +90,16 @@ async def main(): # 2. We become responsible for starting the bot. exts = ['general', 'mod', 'dice'] - async with CustomBot(commands.when_mentioned, db_pool=pool, web_client=our_client, initial_extensions=exts) as bot: - - await bot.start(os.getenv('TOKEN', '')) + intents = discord.Intents.default() + intents.message_content = True + async with CustomBot( + commands.when_mentioned, + db_pool=pool, + web_client=our_client, + initial_extensions=exts, + intents=intents, + ) as bot: + await bot.start('token') # For most use cases, after defining what needs to run, we can just tell asyncio to run it: diff --git a/setup.py b/setup.py index ba9a075bce1d..1b0fbdcfea1d 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ 'sphinxcontrib_trio==1.1.2', 'sphinxcontrib-websupport', 'typing-extensions>=4.3,<5', + 'sphinx-inline-tabs', ], 'speed': [ 'orjson>=3.5.4', @@ -94,6 +95,7 @@ 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', 'Topic :: Internet', 'Topic :: Software Development :: Libraries', 'Topic :: Software Development :: Libraries :: Python Modules', diff --git a/tests/test_app_commands_group.py b/tests/test_app_commands_group.py index 228debde6f3c..4445cfa98b11 100644 --- a/tests/test_app_commands_group.py +++ b/tests/test_app_commands_group.py @@ -479,3 +479,31 @@ async def third(self, interaction: discord.Interaction) -> None: assert isinstance(third, app_commands.Command) assert third.parent is second assert third.binding is cog + + +def test_cog_hybrid_group_wrapped_instance(): + class MyCog(commands.Cog): + @commands.hybrid_group(fallback='fallback') + async def first(self, ctx: commands.Context) -> None: + pass + + @first.command() + async def second(self, ctx: commands.Context) -> None: + pass + + @first.group() + async def nested(self, ctx: commands.Context) -> None: + pass + + @nested.app_command.command() + async def child(self, interaction: discord.Interaction) -> None: + pass + + cog = MyCog() + + fallback = cog.first.app_command.get_command('fallback') + assert fallback is not None + assert getattr(fallback, 'wrapped', None) is cog.first + assert fallback.parent is cog.first.app_command + assert cog.second.app_command is not None + assert cog.second.app_command.wrapped is cog.second