Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions discord/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ class VersionInfo(NamedTuple):
serial: int


version_info: VersionInfo = VersionInfo(
major=2, minor=0, micro=0, releaselevel="beta", serial=4
)
version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel="beta", serial=4)

logging.getLogger(__name__).addHandler(logging.NullHandler())
34 changes: 8 additions & 26 deletions discord/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,10 @@


def show_version() -> None:
entries = [
"- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(
sys.version_info
)
]
entries = ["- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(sys.version_info)]

version_info = discord.version_info
entries.append(
"- py-cord v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(version_info)
)
entries.append("- py-cord v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(version_info))
if version_info.releaselevel != "final":
pkg = pkg_resources.get_distribution("py-cord")
if pkg:
Expand Down Expand Up @@ -299,9 +293,7 @@ def newcog(parser, args) -> None:


def add_newbot_args(subparser: argparse._SubParsersAction) -> None:
parser = subparser.add_parser(
"newbot", help="creates a command bot project quickly"
)
parser = subparser.add_parser("newbot", help="creates a command bot project quickly")
parser.set_defaults(func=newbot)

parser.add_argument("name", help="the bot project name")
Expand All @@ -311,12 +303,8 @@ def add_newbot_args(subparser: argparse._SubParsersAction) -> None:
nargs="?",
default=Path.cwd(),
)
parser.add_argument(
"--prefix", help="the bot prefix (default: $)", default="$", metavar="<prefix>"
)
parser.add_argument(
"--sharded", help="whether to use AutoShardedBot", action="store_true"
)
parser.add_argument("--prefix", help="the bot prefix (default: $)", default="$", metavar="<prefix>")
parser.add_argument("--sharded", help="whether to use AutoShardedBot", action="store_true")
parser.add_argument(
"--no-git",
help="do not create a .gitignore file",
Expand Down Expand Up @@ -347,18 +335,12 @@ def add_newcog_args(subparser: argparse._SubParsersAction) -> None:
help="whether to hide all commands in the cog",
action="store_true",
)
parser.add_argument(
"--full", help="add all special methods as well", action="store_true"
)
parser.add_argument("--full", help="add all special methods as well", action="store_true")


def parse_args() -> Tuple[argparse.ArgumentParser, argparse.Namespace]:
parser = argparse.ArgumentParser(
prog="discord", description="Tools for helping with Pycord"
)
parser.add_argument(
"-v", "--version", action="store_true", help="shows the library version"
)
parser = argparse.ArgumentParser(prog="discord", description="Tools for helping with Pycord")
parser.add_argument("-v", "--version", action="store_true", help="shows the library version")
parser.set_defaults(func=core)

subparser = parser.add_subparsers(dest="subcommand", title="subcommands")
Expand Down
114 changes: 28 additions & 86 deletions discord/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@
from .ui.view import View
from .user import ClientUser

PartialMessageableChannel = Union[
TextChannel, Thread, DMChannel, PartialMessageable
]
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable]
MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
SnowflakeTime = Union["Snowflake", datetime]

Expand Down Expand Up @@ -262,9 +260,7 @@ class GuildChannel:

if TYPE_CHECKING:

def __init__(
self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any]
):
def __init__(self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any]):
...

def __str__(self) -> str:
Expand All @@ -290,9 +286,7 @@ async def _move(

http = self._state.http
bucket = self._sorting_bucket
channels: List[GuildChannel] = [
c for c in self.guild.channels if c._sorting_bucket == bucket
]
channels: List[GuildChannel] = [c for c in self.guild.channels if c._sorting_bucket == bucket]

channels.sort(key=lambda c: c.position)

Expand All @@ -319,9 +313,7 @@ async def _move(

await http.bulk_channel_update(self.guild.id, payload, reason=reason)

async def _edit(
self, options: Dict[str, Any], reason: Optional[str]
) -> Optional[ChannelPayload]:
async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]:
try:
parent = options.pop("category")
except KeyError:
Expand Down Expand Up @@ -357,18 +349,14 @@ async def _edit(
if lock_permissions:
category = self.guild.get_channel(parent_id)
if category:
options["permission_overwrites"] = [
c._asdict() for c in category._overwrites
]
options["permission_overwrites"] = [c._asdict() for c in category._overwrites]
options["parent_id"] = parent_id
elif lock_permissions and self.category_id is not None:
# if we're syncing permissions on a pre-existing channel category without changing it
# we need to update the permissions to point to the pre-existing category
category = self.guild.get_channel(self.category_id)
if category:
options["permission_overwrites"] = [
c._asdict() for c in category._overwrites
]
options["permission_overwrites"] = [c._asdict() for c in category._overwrites]
else:
await self._move(
position,
Expand All @@ -382,18 +370,14 @@ async def _edit(
perms = []
for target, perm in overwrites.items():
if not isinstance(perm, PermissionOverwrite):
raise InvalidArgument(
f"Expected PermissionOverwrite received {perm.__class__.__name__}"
)
raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}")

allow, deny = perm.pair()
payload = {
"allow": allow.value,
"deny": deny.value,
"id": target.id,
"type": _Overwrites.ROLE
if isinstance(target, Role)
else _Overwrites.MEMBER,
"type": _Overwrites.ROLE if isinstance(target, Role) else _Overwrites.MEMBER,
}

perms.append(payload)
Expand All @@ -409,9 +393,7 @@ async def _edit(
options["type"] = ch_type.value

if options:
return await self._state.http.edit_channel(
self.id, reason=reason, **options
)
return await self._state.http.edit_channel(self.id, reason=reason, **options)

def _fill_overwrites(self, data: GuildChannelPayload) -> None:
self._overwrites = []
Expand Down Expand Up @@ -617,9 +599,7 @@ def permissions_for(self, obj: Union[Member, Role], /) -> Permissions:
try:
maybe_everyone = self._overwrites[0]
if maybe_everyone.id == self.guild.id:
base.handle_overwrite(
allow=maybe_everyone.allow, deny=maybe_everyone.deny
)
base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny)
except IndexError:
pass

Expand Down Expand Up @@ -650,9 +630,7 @@ def permissions_for(self, obj: Union[Member, Role], /) -> Permissions:
try:
maybe_everyone = self._overwrites[0]
if maybe_everyone.id == self.guild.id:
base.handle_overwrite(
allow=maybe_everyone.allow, deny=maybe_everyone.deny
)
base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny)
remaining_overwrites = self._overwrites[1:]
else:
remaining_overwrites = self._overwrites
Expand Down Expand Up @@ -735,9 +713,7 @@ async def set_permissions(
) -> None:
...

async def set_permissions(
self, target, *, overwrite=_undefined, reason=None, **permissions
):
async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions):
r"""|coro|

Sets the channel specific permission overwrites for a target in the
Expand Down Expand Up @@ -831,9 +807,7 @@ async def set_permissions(
await http.delete_channel_permissions(self.id, target.id, reason=reason)
elif isinstance(overwrite, PermissionOverwrite):
(allow, deny) = overwrite.pair()
await http.edit_channel_permissions(
self.id, target.id, allow.value, deny.value, perm_type, reason=reason
)
await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason)
else:
raise InvalidArgument("Invalid overwrite type provided.")

Expand All @@ -849,18 +823,14 @@ async def _clone_impl(
base_attrs["name"] = name or self.name
guild_id = self.guild.id
cls = self.__class__
data = await self._state.http.create_channel(
guild_id, self.type.value, reason=reason, **base_attrs
)
data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs)
obj = cls(state=self._state, guild=self.guild, data=data)

# temporarily add it to the cache
self.guild._channels[obj.id] = obj # type: ignore
return obj

async def clone(
self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None
) -> GCH:
async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH:
"""|coro|

Clones this channel. This creates a channel with the same properties
Expand Down Expand Up @@ -1007,25 +977,19 @@ async def move(self, **kwargs) -> None:
before, after = kwargs.get("before"), kwargs.get("after")
offset = kwargs.get("offset", 0)
if sum(bool(a) for a in (beginning, end, before, after)) > 1:
raise InvalidArgument(
"Only one of [before, after, end, beginning] can be used."
)
raise InvalidArgument("Only one of [before, after, end, beginning] can be used.")

bucket = self._sorting_bucket
parent_id = kwargs.get("category", MISSING)
channels: List[GuildChannel]
if parent_id not in (MISSING, None):
parent_id = parent_id.id
channels = [
ch
for ch in self.guild.channels
if ch._sorting_bucket == bucket and ch.category_id == parent_id
ch for ch in self.guild.channels if ch._sorting_bucket == bucket and ch.category_id == parent_id
]
else:
channels = [
ch
for ch in self.guild.channels
if ch._sorting_bucket == bucket and ch.category_id == self.category_id
ch for ch in self.guild.channels if ch._sorting_bucket == bucket and ch.category_id == self.category_id
]

channels.sort(key=lambda c: (c.position, c.id))
Expand All @@ -1045,9 +1009,7 @@ async def move(self, **kwargs) -> None:
elif before:
index = next((i for i, c in enumerate(channels) if c.id == before.id), None)
elif after:
index = next(
(i + 1 for i, c in enumerate(channels) if c.id == after.id), None
)
index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None)

if index is None:
raise InvalidArgument("Could not resolve appropriate move position")
Expand All @@ -1062,9 +1024,7 @@ async def move(self, **kwargs) -> None:
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d)

await self._state.http.bulk_channel_update(
self.guild.id, payload, reason=reason
)
await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason)

async def create_invite(
self,
Expand Down Expand Up @@ -1180,10 +1140,7 @@ async def invites(self) -> List[Invite]:
state = self._state
data = await state.http.invites_from_channel(self.id)
guild = self.guild
return [
Invite(state=state, data=invite, channel=self, guild=guild)
for invite in data
]
return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data]


class Messageable:
Expand Down Expand Up @@ -1390,27 +1347,21 @@ async def send(
content = str(content) if content is not None else None

if embed is not None and embeds is not None:
raise InvalidArgument(
"cannot pass both embed and embeds parameter to send()"
)
raise InvalidArgument("cannot pass both embed and embeds parameter to send()")

if embed is not None:
embed = embed.to_dict()

elif embeds is not None:
if len(embeds) > 10:
raise InvalidArgument(
"embeds parameter must be a list of up to 10 elements"
)
raise InvalidArgument("embeds parameter must be a list of up to 10 elements")
embeds = [embed.to_dict() for embed in embeds]

if stickers is not None:
stickers = [sticker.id for sticker in stickers]

if allowed_mentions is None:
allowed_mentions = (
state.allowed_mentions and state.allowed_mentions.to_dict()
)
allowed_mentions = state.allowed_mentions and state.allowed_mentions.to_dict()

elif state.allowed_mentions is not None:
allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict()
Expand All @@ -1430,9 +1381,7 @@ async def send(

if view:
if not hasattr(view, "__discord_ui_view__"):
raise InvalidArgument(
f"view parameter must be View not {view.__class__!r}"
)
raise InvalidArgument(f"view parameter must be View not {view.__class__!r}")

components = view.to_components()
else:
Expand Down Expand Up @@ -1464,9 +1413,7 @@ async def send(

elif files is not None:
if len(files) > 10:
raise InvalidArgument(
"files parameter must be a list of up to 10 elements"
)
raise InvalidArgument("files parameter must be a list of up to 10 elements")
elif not all(isinstance(file, File) for file in files):
raise InvalidArgument("files parameter must be a list of File")

Expand Down Expand Up @@ -1635,15 +1582,10 @@ def can_send(self, *objects) -> bool:
if obj is None:
permission = mapping["Message"]
else:
permission = (
mapping.get(type(obj).__name__) or mapping[obj.__name__]
)
permission = mapping.get(type(obj).__name__) or mapping[obj.__name__]

if type(obj).__name__ == "Emoji":
if (
obj._to_partial().is_unicode_emoji
or obj.guild_id == channel.guild.id
):
if obj._to_partial().is_unicode_emoji or obj.guild_id == channel.guild.id:
continue
elif type(obj).__name__ == "GuildSticker":
if obj.guild_id == channel.guild.id:
Expand Down
Loading