From 2c828c48299112140f371acf60615dd1d76bcf04 Mon Sep 17 00:00:00 2001 From: Elliot Cubit Date: Sun, 14 May 2023 13:58:17 -0400 Subject: [PATCH] fix: allow creating forum threads with files --- CHANGELOG.md | 2 ++ discord/channel.py | 49 ++++++++++++---------------------------------- discord/http.py | 29 +++++++++++++++++++++++++-- 3 files changed, 42 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b981cd659e..ca70b98dd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -113,6 +113,8 @@ These changes are available on the `master` branch, but have not yet been releas ([#2048](https://github.com/Pycord-Development/pycord/pull/2048)) - Fixed the Slash command syncronization method `indiviual`. ([#1925](https://github.com/Pycord-Development/pycord/pull/1925)) +- Fixed `HttpException` when trying to create a Forum thread with files. + ([#2075](https://github.com/Pycord-Development/pycord/pull/2075)) ## [2.4.1] - 2023-03-20 diff --git a/discord/channel.py b/discord/channel.py index ce2e87aad9..9a2e931eba 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -1274,26 +1274,7 @@ async def create_thread( if file is not None and files is not None: raise InvalidArgument("cannot pass both file and files parameter to send()") - if file is not None: - if not isinstance(file, File): - raise InvalidArgument("file parameter must be File") - - try: - data = await state.http.send_files( - self.id, - files=[file], - allowed_mentions=allowed_mentions, - content=message_content, - embed=embed, - embeds=embeds, - nonce=nonce, - stickers=stickers, - components=components, - ) - finally: - file.close() - - elif files is not None: + if files is not None: if len(files) > 10: raise InvalidArgument( "files parameter must be a list of up to 10 elements" @@ -1301,26 +1282,17 @@ async def create_thread( elif not all(isinstance(file, File) for file in files): raise InvalidArgument("files parameter must be a list of File") - try: - data = await state.http.send_files( - self.id, - files=files, - content=message_content, - embed=embed, - embeds=embeds, - nonce=nonce, - allowed_mentions=allowed_mentions, - stickers=stickers, - components=components, - ) - finally: - for f in files: - f.close() - else: + if file is not None: + if not isinstance(file, File): + raise InvalidArgument("file parameter must be File") + files = [file] + + try: data = await state.http.start_forum_thread( self.id, content=message_content, name=name, + files=files, embed=embed, embeds=embeds, nonce=nonce, @@ -1333,6 +1305,11 @@ async def create_thread( applied_tags=applied_tags, reason=reason, ) + finally: + if files is not None: + for f in files: + f.close() + ret = Thread(guild=self.guild, state=self._state, data=data) msg = ret.get_partial_message(data["last_message_id"]) if view: diff --git a/discord/http.py b/discord/http.py index d70380cd4f..7808a16c79 100644 --- a/discord/http.py +++ b/discord/http.py @@ -1170,6 +1170,7 @@ def start_forum_thread( invitable: bool = True, applied_tags: SnowflakeList | None = None, reason: str | None = None, + files: Sequence[File] | None = None, embed: embed.Embed | None = None, embeds: list[embed.Embed] | None = None, nonce: str | None = None, @@ -1177,7 +1178,7 @@ def start_forum_thread( stickers: list[sticker.StickerItem] | None = None, components: list[components.Component] | None = None, ) -> Response[threads.Thread]: - payload = { + payload: dict[str, Any] = { "name": name, "auto_archive_duration": auto_archive_duration, "invitable": invitable, @@ -1208,13 +1209,37 @@ def start_forum_thread( if rate_limit_per_user: payload["rate_limit_per_user"] = rate_limit_per_user + + form = [{"name": "payload_json"}] + if files: + attachments = [] + for index, file in enumerate(files): + attachments.append( + { + "id": index, + "filename": file.filename, + "description": file.description, + } + ) + form.append( + { + "name": f"files[{index}]", + "value": file.fp, + "filename": file.filename, + "content_type": "application/octet-stream", + } + ) + payload["attachments"] = attachments + + form[0]["value"] = utils._to_json(payload) + # TODO: Once supported by API, remove has_message=true query parameter route = Route( "POST", "/channels/{channel_id}/threads?has_message=true", channel_id=channel_id, ) - return self.request(route, json=payload, reason=reason) + return self.request(route, form=form, reason=reason) def join_thread(self, channel_id: Snowflake) -> Response[None]: return self.request(