From 582c6ac0037b51f78b6e11e0c32676d44fd99d70 Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Sun, 19 Nov 2023 17:16:27 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- telethon/_updates/messagebox.py | 3 +- telethon/client/chats.py | 5 +- telethon/client/downloads.py | 59 +++---- telethon/client/messages.py | 42 ++--- telethon/client/updates.py | 63 ++------ telethon/client/uploads.py | 55 +++---- telethon/client/users.py | 32 ++-- telethon/events/callbackquery.py | 5 +- telethon/extensions/html.py | 48 +++--- telethon/extensions/markdown.py | 34 ++-- telethon/helpers.py | 55 +++---- telethon/network/mtprotosender.py | 41 ++--- telethon/sessions/memory.py | 50 +++--- telethon/tl/custom/message.py | 21 +-- telethon/utils.py | 166 +++++++++----------- telethon_generator/generators/tlobject.py | 180 +++++++++++----------- 16 files changed, 355 insertions(+), 504 deletions(-) diff --git a/telethon/_updates/messagebox.py b/telethon/_updates/messagebox.py index 0c0d008be..ff027451e 100644 --- a/telethon/_updates/messagebox.py +++ b/telethon/_updates/messagebox.py @@ -109,8 +109,7 @@ def from_update(cls, update): entry = getattr(update, 'channel_id', None) or ENTRY_ACCOUNT return cls(pts=pts, pts_count=pts_count, entry=entry) - qts = getattr(update, 'qts', None) - if qts: + if qts := getattr(update, 'qts', None): pts_count = 1 if isinstance(update, tl.UpdateNewEncryptedMessage) else 0 return cls(pts=qts, pts_count=pts_count, entry=ENTRY_SECRET) diff --git a/telethon/client/chats.py b/telethon/client/chats.py index 64f60d5de..ccadf0c68 100644 --- a/telethon/client/chats.py +++ b/telethon/client/chats.py @@ -803,14 +803,13 @@ def action( try: action = _ChatAction._str_mapping[action.lower()] except KeyError: - raise ValueError( - 'No such action "{}"'.format(action)) from None + raise ValueError(f'No such action "{action}"') from None elif not isinstance(action, types.TLObject) or action.SUBCLASS_OF_ID != 0x20b2cc21: # 0x20b2cc21 = crc32(b'SendMessageAction') if isinstance(action, type): raise ValueError('You must pass an instance, not the class') else: - raise ValueError('Cannot use {} as action'.format(action)) + raise ValueError(f'Cannot use {action} as action') if isinstance(action, types.SendMessageCancelAction): # ``SetTypingRequest.resolve`` will get input peer of ``entity``. diff --git a/telethon/client/downloads.py b/telethon/client/downloads.py index 3c9fa2d16..0db8af7e4 100644 --- a/telethon/client/downloads.py +++ b/telethon/client/downloads.py @@ -265,27 +265,27 @@ async def download_profile_photo( thumb=thumb, progress_callback=None ) - for attr in ('username', 'first_name', 'title'): - possible_names.append(getattr(entity, attr, None)) - + possible_names.extend( + getattr(entity, attr, None) + for attr in ('username', 'first_name', 'title') + ) photo = entity.photo - if isinstance(photo, (types.UserProfilePhoto, types.ChatPhoto)): - dc_id = photo.dc_id - loc = types.InputPeerPhotoFileLocation( - # min users can be used to download profile photos - # self.get_input_entity would otherwise not accept those - peer=utils.get_input_peer(entity, check_hash=False), - photo_id=photo.photo_id, - big=download_big - ) - else: + if not isinstance(photo, (types.UserProfilePhoto, types.ChatPhoto)): # It doesn't make any sense to check if `photo` can be used # as input location, because then this method would be able # to "download the profile photo of a message", i.e. its # media which should be done with `download_media` instead. return None + dc_id = photo.dc_id + loc = types.InputPeerPhotoFileLocation( + # min users can be used to download profile photos + # self.get_input_entity would otherwise not accept those + peer=utils.get_input_peer(entity, check_hash=False), + photo_id=photo.photo_id, + big=download_big + ) file = self._get_proper_filename( file, 'profile_photo', '.jpg', possible_names=possible_names @@ -299,16 +299,15 @@ async def download_profile_photo( # The fix seems to be using the full channel chat photo. ie = await self.get_input_entity(entity) ty = helpers._entity_type(ie) - if ty == helpers._EntityType.CHANNEL: - full = await self(functions.channels.GetFullChannelRequest(ie)) - return await self._download_photo( - full.full_chat.chat_photo, file, - date=None, progress_callback=None, - thumb=thumb - ) - else: + if ty != helpers._EntityType.CHANNEL: # Until there's a report for chats, no need to. return None + full = await self(functions.channels.GetFullChannelRequest(ie)) + return await self._download_photo( + full.full_chat.chat_photo, file, + date=None, progress_callback=None, + thumb=thumb + ) async def download_media( self: 'TelegramClient', @@ -518,11 +517,7 @@ async def _download_file( iv: bytes = None, msg_data: tuple = None) -> typing.Optional[bytes]: if not part_size_kb: - if not file_size: - part_size_kb = 64 # Reasonable default - else: - part_size_kb = utils.get_appropriated_part_size(file_size) - + part_size_kb = utils.get_appropriated_part_size(file_size) if file_size else 64 part_size = int(part_size_kb * 1024) if part_size % MIN_CHUNK_SIZE != 0: raise ValueError( @@ -756,11 +751,7 @@ def sort_thumbs(thumb): return 1, thumb.size if isinstance(thumb, types.PhotoSizeProgressive): return 1, max(thumb.sizes) - if isinstance(thumb, types.VideoSize): - return 2, thumb.size - - # Empty size or invalid should go last - return 0, 0 + return (2, thumb.size) if isinstance(thumb, types.VideoSize) else (0, 0) thumbs = list(sorted(thumbs, key=sort_thumbs)) @@ -856,9 +847,7 @@ def _get_kind_and_names(attributes): elif isinstance(attr, types.DocumentAttributeAudio): kind = 'audio' if attr.performer and attr.title: - possible_names.append('{} - {}'.format( - attr.performer, attr.title - )) + possible_names.append(f'{attr.performer} - {attr.title}') elif attr.performer: possible_names.append(attr.performer) elif attr.title: @@ -1044,7 +1033,7 @@ def _get_proper_filename(file, kind, extension, i = 1 while True: - result = os.path.join(directory, '{} ({}){}'.format(name, i, ext)) + result = os.path.join(directory, f'{name} ({i}){ext}') if not os.path.isfile(result): return result i += 1 diff --git a/telethon/client/messages.py b/telethon/client/messages.py index e0927eedb..56f7d2b15 100644 --- a/telethon/client/messages.py +++ b/telethon/client/messages.py @@ -242,14 +242,12 @@ def _message_in_range(self, message): Determine whether the given message is in the range or it should be ignored (and avoid loading more chunks). """ - # No entity means message IDs between chats may vary if self.entity: if self.reverse: if message.id <= self.last_id or message.id >= self.max_id: return False - else: - if message.id >= self.last_id or message.id <= self.min_id: - return False + elif message.id >= self.last_id or message.id <= self.min_id: + return False return True @@ -585,11 +583,7 @@ async def get_messages(self: 'TelegramClient', *args, **kwargs) -> 'hints.TotalL message_1337 = await client.get_messages(chat, ids=1337) """ if len(args) == 1 and 'limit' not in kwargs: - if 'min_id' in kwargs and 'max_id' in kwargs: - kwargs['limit'] = None - else: - kwargs['limit'] = 1 - + kwargs['limit'] = None if 'min_id' in kwargs and 'max_id' in kwargs else 1 it = self.iter_messages(*args, **kwargs) ids = kwargs.get('ids') @@ -1203,15 +1197,14 @@ async def edit_message( # Invoke `messages.editInlineBotMessage` from the right datacenter. # Otherwise, Telegram will error with `MESSAGE_ID_INVALID` and do nothing. exported = self.session.dc_id != entity.dc_id - if exported: - try: - sender = await self._borrow_exported_sender(entity.dc_id) - return await self._call(sender, request) - finally: - await self._return_exported_sender(sender) - else: + if not exported: return await self(request) + try: + sender = await self._borrow_exported_sender(entity.dc_id) + return await self._call(sender, request) + finally: + await self._return_exported_sender(sender) entity = await self.get_input_entity(entity) request = functions.messages.EditMessageRequest( peer=entity, @@ -1223,8 +1216,7 @@ async def edit_message( reply_markup=self.build_reply_markup(buttons), schedule_date=schedule ) - msg = self._get_response_message(request, await self(request), entity) - return msg + return self._get_response_message(request, await self(request), entity) async def delete_messages( self: 'TelegramClient', @@ -1362,14 +1354,14 @@ async def send_read_acknowledge( await client.send_read_acknowledge(chat, messages) """ if max_id is None: - if not message: - max_id = 0 + if message: + max_id = ( + max(msg.id for msg in message) + if utils.is_list_like(message) + else message.id + ) else: - if utils.is_list_like(message): - max_id = max(msg.id for msg in message) - else: - max_id = message.id - + max_id = 0 entity = await self.get_input_entity(entity) if clear_mentions: await self(functions.messages.ReadMentionsRequest(entity)) diff --git a/telethon/client/updates.py b/telethon/client/updates.py index f06b28e8a..e3532a154 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -301,8 +301,7 @@ async def _update_loop(self: 'TelegramClient'): ) - get_diff = self._message_box.get_difference() - if get_diff: + if get_diff := self._message_box.get_difference(): self._log[__name__].debug('Getting difference for account updates') try: diff = await self(get_diff) @@ -346,8 +345,9 @@ async def _update_loop(self: 'TelegramClient'): updates_to_dispatch.extend(self._preprocess_updates(updates, users, chats)) continue - get_diff = self._message_box.get_channel_difference(self._mb_entity_cache) - if get_diff: + if get_diff := self._message_box.get_channel_difference( + self._mb_entity_cache + ): self._log[__name__].debug('Getting difference for channel %s updates', get_diff.channel.channel_id) try: diff = await self(get_diff) @@ -446,16 +446,15 @@ async def _update_loop(self: 'TelegramClient'): deadline = self._message_box.check_deadlines() deadline_delay = deadline - get_running_loop().time() - if deadline_delay > 0: - # Don't bother sleeping and timing out if the delay is already 0 (pollutes the logs). - try: - updates = await asyncio.wait_for(self._updates_queue.get(), deadline_delay) - except asyncio.TimeoutError: - self._log[__name__].debug('Timeout waiting for updates expired') - continue - else: + if deadline_delay <= 0: continue + # Don't bother sleeping and timing out if the delay is already 0 (pollutes the logs). + try: + updates = await asyncio.wait_for(self._updates_queue.get(), deadline_delay) + except asyncio.TimeoutError: + self._log[__name__].debug('Timeout waiting for updates expired') + continue processed = [] try: users, chats = self._message_box.process_updates(updates, self._mb_entity_cache, processed) @@ -537,16 +536,13 @@ async def _dispatch_update(self: 'TelegramClient', update): built = EventBuilderDict(self, update, others) for conv_set in self._conversations.values(): for conv in conv_set: - ev = built[events.NewMessage] - if ev: + if ev := built[events.NewMessage]: conv._on_new_message(ev) - ev = built[events.MessageEdited] - if ev: + if ev := built[events.MessageEdited]: conv._on_edit(ev) - ev = built[events.MessageRead] - if ev: + if ev := built[events.MessageRead]: conv._on_read(ev) if conv._custom: @@ -637,37 +633,6 @@ async def _handle_auto_reconnect(self: 'TelegramClient'): 'after reconnect: %s: %s', type(e), e) return - try: - self._log[__name__].info( - 'Asking for the current state after reconnect...') - - # TODO consider: - # If there aren't many updates while the client is disconnected - # (I tried with up to 20), Telegram seems to send them without - # asking for them (via updates.getDifference). - # - # On disconnection, the library should probably set a "need - # difference" or "catching up" flag so that any new updates are - # ignored, and then the library should call updates.getDifference - # itself to fetch them. - # - # In any case (either there are too many updates and Telegram - # didn't send them, or there isn't a lot and Telegram sent them - # but we dropped them), we fetch the new difference to get all - # missed updates. I feel like this would be the best solution. - - # If a disconnection occurs, the old known state will be - # the latest one we were aware of, so we can catch up since - # the most recent state we were aware of. - await self.catch_up() - - self._log[__name__].info('Successfully fetched missed updates') - except errors.RPCError as e: - self._log[__name__].warning('Failed to get missed updates after ' - 'reconnect: %r', e) - except Exception: - self._log[__name__].exception( - 'Unhandled exception while getting update difference after reconnect') # endregion diff --git a/telethon/client/uploads.py b/telethon/client/uploads.py index 7245c0985..aed010b85 100644 --- a/telethon/client/uploads.py +++ b/telethon/client/uploads.py @@ -379,11 +379,7 @@ def callback(current, total): lambda s, t: progress_callback(sent_count + s, len(file)) ) - if utils.is_list_like(caption): - captions = caption - else: - captions = [caption] - + captions = caption if utils.is_list_like(caption) else [caption] result = [] while file: result += await self._send_album( @@ -403,7 +399,7 @@ def callback(current, total): msg_entities = formatting_entities else: caption, msg_entities =\ - await self._parse_message_text(caption, parse_mode) + await self._parse_message_text(caption, parse_mode) file_handle, media, image = await self._file_to_media( file, force_document=force_document, @@ -484,10 +480,7 @@ async def _send_album(self: 'TelegramClient', entity, files, caption='', fm = utils.get_input_media( r.document, supports_streaming=supports_streaming) - if captions: - caption, msg_entities = captions.pop() - else: - caption, msg_entities = '', None + caption, msg_entities = captions.pop() if captions else ('', None) media.append(types.InputSingleMedia( fm, message=caption, @@ -644,15 +637,15 @@ async def upload_file( if not isinstance(part, bytes): raise TypeError( - 'file descriptor returned {}, not bytes (you must ' - 'open the file in bytes mode)'.format(type(part))) + f'file descriptor returned {type(part)}, not bytes (you must open the file in bytes mode)' + ) # `file_size` could be wrong in which case `part` may not be # `part_size` before reaching the end. if len(part) != part_size and part_index < part_count - 1: raise ValueError( - 'read less than {} before reaching the end; either ' - '`file_size` or `read` are wrong'.format(part_size)) + f'read less than {part_size} before reaching the end; either `file_size` or `read` are wrong' + ) pos += len(part) @@ -676,15 +669,13 @@ async def upload_file( file_id, part_index, part) result = await self(request) - if result: - self._log[__name__].debug('Uploaded %d/%d', - part_index + 1, part_count) - if progress_callback: - await helpers._maybe_await(progress_callback(pos, file_size)) - else: - raise RuntimeError( - 'Failed to upload file part {}.'.format(part_index)) + if not result: + raise RuntimeError(f'Failed to upload file part {part_index}.') + self._log[__name__].debug('Uploaded %d/%d', + part_index + 1, part_count) + if progress_callback: + await helpers._maybe_await(progress_callback(pos, file_size)) if is_big: return types.InputFileBig(file_id, part_count, file_name) else: @@ -713,7 +704,7 @@ async def _file_to_media( # `aiofiles` do not base `io.IOBase` but do have `read`, so we # just check for the read attribute to see if it's file-like. if not isinstance(file, (str, bytes, types.InputFile, types.InputFileBig))\ - and not hasattr(file, 'read'): + and not hasattr(file, 'read'): # The user may pass a Message containing media (or the media, # or anything similar) that should be treated as a file. Try # getting the input media for whatever they passed and send it. @@ -747,21 +738,19 @@ async def _file_to_media( progress_callback=progress_callback ) elif re.match('https?://', file): - if as_image: - media = types.InputMediaPhotoExternal(file, ttl_seconds=ttl) - else: - media = types.InputMediaDocumentExternal(file, ttl_seconds=ttl) - else: - bot_file = utils.resolve_bot_file_id(file) - if bot_file: - media = utils.get_input_media(bot_file, ttl=ttl) + media = ( + types.InputMediaPhotoExternal(file, ttl_seconds=ttl) + if as_image + else types.InputMediaDocumentExternal(file, ttl_seconds=ttl) + ) + elif bot_file := utils.resolve_bot_file_id(file): + media = utils.get_input_media(bot_file, ttl=ttl) if media: pass # Already have media, don't check the rest elif not file_handle: raise ValueError( - 'Failed to convert {} to media. Not an existing file, ' - 'an HTTP URL or a valid bot-API-like file ID'.format(file) + f'Failed to convert {file} to media. Not an existing file, an HTTP URL or a valid bot-API-like file ID' ) elif as_image: media = types.InputMediaUploadedPhoto(file_handle, ttl_seconds=ttl) diff --git a/telethon/client/users.py b/telethon/client/users.py index 71c72db3d..80ce996e7 100644 --- a/telethon/client/users.py +++ b/telethon/client/users.py @@ -94,8 +94,9 @@ async def _call(self: 'TelegramClient', sender, request, ordered=False, flood_sl last_error = e self._log[__name__].warning( 'Telegram is having internal issues %s: %s', - e.__class__.__name__, e) - + last_error.__class__.__name__, + last_error, + ) await asyncio.sleep(2) except (errors.FloodWaitError, errors.SlowModeWaitError, errors.FloodTestPhoneWaitError) as e: last_error = e @@ -105,18 +106,17 @@ async def _call(self: 'TelegramClient', sender, request, ordered=False, flood_sl # SLOW_MODE_WAIT is chat-specific, not request-specific if not isinstance(e, errors.SlowModeWaitError): self._flood_waited_requests\ - [request.CONSTRUCTOR_ID] = time.time() + e.seconds + [request.CONSTRUCTOR_ID] = time.time() + e.seconds # In test servers, FLOOD_WAIT_0 has been observed, and sleeping for # such a short amount will cause retries very fast leading to issues. if e.seconds == 0: e.seconds = 1 - if e.seconds <= self.flood_sleep_threshold: - self._log[__name__].info(*_fmt_flood(e.seconds, request)) - await asyncio.sleep(e.seconds) - else: + if e.seconds > self.flood_sleep_threshold: raise + self._log[__name__].info(*_fmt_flood(e.seconds, request)) + await asyncio.sleep(e.seconds) except (errors.PhoneMigrateError, errors.NetworkMigrateError, errors.UserMigrateError) as e: last_error = e @@ -130,8 +130,7 @@ async def _call(self: 'TelegramClient', sender, request, ordered=False, flood_sl if self._raise_last_call_error and last_error is not None: raise last_error - raise ValueError('Request was unsuccessful {} time(s)' - .format(attempt)) + raise ValueError(f'Request was unsuccessful {attempt} time(s)') # region Public methods @@ -466,10 +465,7 @@ async def get_input_entity( pass raise ValueError( - 'Could not find the input entity for {} ({}). Please read https://' - 'docs.telethon.dev/en/stable/concepts/entities.html to' - ' find out more details.' - .format(peer, type(peer).__name__) + f'Could not find the input entity for {peer} ({type(peer).__name__}). Please read https://docs.telethon.dev/en/stable/concepts/entities.html to find out more details.' ) async def _get_peer(self: 'TelegramClient', peer: 'hints.EntityLike'): @@ -525,8 +521,7 @@ async def _get_entity_from_string(self: 'TelegramClient', string): Returns the found entity, or raises TypeError if not found. """ - phone = utils.parse_phone(string) - if phone: + if phone := utils.parse_phone(string): try: for user in (await self( functions.contacts.GetContactsRequest(0))).users: @@ -555,8 +550,7 @@ async def _get_entity_from_string(self: 'TelegramClient', string): result = await self( functions.contacts.ResolveUsernameRequest(username)) except errors.UsernameNotOccupiedError as e: - raise ValueError('No user has "{}" as username' - .format(username)) from e + raise ValueError(f'No user has "{username}" as username') from e try: pid = utils.get_peer_id(result.peer, add_mark=False) @@ -573,9 +567,7 @@ async def _get_entity_from_string(self: 'TelegramClient', string): except ValueError: pass - raise ValueError( - 'Cannot find any entity corresponding to "{}"'.format(string) - ) + raise ValueError(f'Cannot find any entity corresponding to "{string}"') async def _get_input_dialog(self: 'TelegramClient', dialog): """ diff --git a/telethon/events/callbackquery.py b/telethon/events/callbackquery.py index 29ffd0533..0d42615f5 100644 --- a/telethon/events/callbackquery.py +++ b/telethon/events/callbackquery.py @@ -118,10 +118,7 @@ def filter(self, event): elif event.query.data != self.match: return - if self.func: - # Return the result of func directly as it may need to be awaited - return self.func(event) - return True + return self.func(event) if self.func else True class Event(EventCommon, SenderGetter): """ diff --git a/telethon/extensions/html.py b/telethon/extensions/html.py index 201312ace..45a886e4e 100644 --- a/telethon/extensions/html.py +++ b/telethon/extensions/html.py @@ -34,13 +34,13 @@ def handle_starttag(self, tag, attrs): attrs = dict(attrs) EntityType = None args = {} - if tag == 'strong' or tag == 'b': + if tag in ['strong', 'b']: EntityType = MessageEntityBold - elif tag == 'em' or tag == 'i': + elif tag in ['em', 'i']: EntityType = MessageEntityItalic elif tag == 'u': EntityType = MessageEntityUnderline - elif tag == 'del' or tag == 's': + elif tag in ['del', 's']: EntityType = MessageEntityStrike elif tag == 'blockquote': EntityType = MessageEntityBlockquote @@ -70,13 +70,12 @@ def handle_starttag(self, tag, attrs): if url.startswith('mailto:'): url = url[len('mailto:'):] EntityType = MessageEntityEmail + elif self.get_starttag_text() == url: + EntityType = MessageEntityUrl else: - if self.get_starttag_text() == url: - EntityType = MessageEntityUrl - else: - EntityType = MessageEntityTextUrl - args['url'] = del_surrogate(url) - url = None + EntityType = MessageEntityTextUrl + args['url'] = del_surrogate(url) + url = None self._open_tags_meta.popleft() self._open_tags_meta.appendleft(url) @@ -90,8 +89,7 @@ def handle_starttag(self, tag, attrs): def handle_data(self, text): previous_tag = self._open_tags[0] if len(self._open_tags) > 0 else '' if previous_tag == 'a': - url = self._open_tags_meta[0] - if url: + if url := self._open_tags_meta[0]: text = url for tag, entity in self._building_entities.items(): @@ -105,8 +103,7 @@ def handle_endtag(self, tag): self._open_tags_meta.popleft() except IndexError: pass - entity = self._building_entities.pop(tag, None) - if entity: + if entity := self._building_entities.pop(tag, None): self.entities.append(entity) @@ -135,16 +132,16 @@ def parse(html: str) -> Tuple[str, List[TypeMessageEntity]]: MessageEntityStrike: ('', ''), MessageEntityBlockquote: ('
', '
'), MessageEntityPre: lambda e, _: ( - "
\n"
-        "    \n"
-        "        ".format(e.language), "{}\n"
-        "    \n"
-        "
" + f"
\n    \n        ",
+        "{}\n" "    \n" "
", + ), + MessageEntityEmail: lambda _, t: (f'', ''), + MessageEntityUrl: lambda _, t: (f'', ''), + MessageEntityTextUrl: lambda e, _: (f'', ''), + MessageEntityMentionName: lambda e, _: ( + f'', + '', ), - MessageEntityEmail: lambda _, t: (''.format(t), ''), - MessageEntityUrl: lambda _, t: (''.format(t), ''), - MessageEntityTextUrl: lambda e, _: (''.format(escape(e.url)), ''), - MessageEntityMentionName: lambda e, _: (''.format(e.user_id), ''), } @@ -170,13 +167,10 @@ def unparse(text: str, entities: Iterable[TypeMessageEntity]) -> str: for i, entity in enumerate(entities): s = entity.offset e = entity.offset + entity.length - delimiter = ENTITY_TO_FORMATTER.get(type(entity), None) - if delimiter: + if delimiter := ENTITY_TO_FORMATTER.get(type(entity), None): if callable(delimiter): delimiter = delimiter(entity, text[s:e]) - insert_at.append((s, i, delimiter[0])) - insert_at.append((e, len(entities) - i, delimiter[1])) - + insert_at.extend(((s, i, delimiter[0]), (e, len(entities) - i, delimiter[1]))) insert_at.sort(key=lambda t: (t[0], t[1])) next_escape_bound = len(text) while insert_at: diff --git a/telethon/extensions/markdown.py b/telethon/extensions/markdown.py index 78f283856..703db1ddc 100644 --- a/telethon/extensions/markdown.py +++ b/telethon/extensions/markdown.py @@ -56,8 +56,12 @@ def parse(message, delimiters=None, url_re=None): # Build a regex to efficiently test all delimiters at once. # Note that the largest delimiter should go first, we don't # want ``` to be interpreted as a single back-tick in a code block. - delim_re = re.compile('|'.join('({})'.format(re.escape(k)) - for k in sorted(delimiters, key=len, reverse=True))) + delim_re = re.compile( + '|'.join( + f'({re.escape(k)})' + for k in sorted(delimiters, key=len, reverse=True) + ) + ) # Cannot use a for loop because we need to skip some indices i = 0 @@ -67,10 +71,7 @@ def parse(message, delimiters=None, url_re=None): # The offset will just be half the index we're at. message = add_surrogate(message) while i < len(message): - m = delim_re.match(message, pos=i) - - # Did we find some delimiter here at `i`? - if m: + if m := delim_re.match(message, pos=i): delim = next(filter(None, m.groups())) # +1 to avoid matching right after (e.g. "****") @@ -91,11 +92,7 @@ def parse(message, delimiters=None, url_re=None): # If the end is after our start, it is affected if ent.offset + ent.length > i: # If the old start is also before ours, it is fully enclosed - if ent.offset <= i: - ent.length -= len(delim) * 2 - else: - ent.length -= len(delim) - + ent.length -= len(delim) * 2 if ent.offset <= i else len(delim) # Append the found entity ent = delimiters[delim] if ent == MessageEntityPre: @@ -110,8 +107,7 @@ def parse(message, delimiters=None, url_re=None): continue elif url_re: - m = url_re.match(message, pos=i) - if m: + if m := url_re.match(message, pos=i): # Replace the whole match with only the inline URL text. message = ''.join(( message[:m.start()], @@ -167,20 +163,16 @@ def unparse(text, entities, delimiters=None, url_fmt=None): for i, entity in enumerate(entities): s = entity.offset e = entity.offset + entity.length - delimiter = delimiters.get(type(entity), None) - if delimiter: - insert_at.append((s, i, delimiter)) - insert_at.append((e, len(entities) - i, delimiter)) + if delimiter := delimiters.get(type(entity), None): + insert_at.extend(((s, i, delimiter), (e, len(entities) - i, delimiter))) else: url = None if isinstance(entity, MessageEntityTextUrl): url = entity.url elif isinstance(entity, MessageEntityMentionName): - url = 'tg://user?id={}'.format(entity.user_id) + url = f'tg://user?id={entity.user_id}' if url: - insert_at.append((s, i, '[')) - insert_at.append((e, len(entities) - i, ']({})'.format(url))) - + insert_at.extend(((s, i, '['), (e, len(entities) - i, f']({url})'))) insert_at.sort(key=lambda t: (t[0], t[1])) while insert_at: at, _, what = insert_at.pop() diff --git a/telethon/helpers.py b/telethon/helpers.py index 4fb9d58bf..f0360e263 100644 --- a/telethon/helpers.py +++ b/telethon/helpers.py @@ -31,8 +31,7 @@ def generate_random_long(signed=True): def ensure_parent_dir_exists(file_path): """Ensures that the parent directory exists""" - parent = os.path.dirname(file_path) - if parent: + if parent := os.path.dirname(file_path): os.makedirs(parent, exist_ok=True) @@ -146,7 +145,7 @@ def retry_range(retries, force_retry=True): # We need at least one iteration even if the retries are 0 # when force_retry is True. - if force_retry and not (retries is None or retries < 0): + if force_retry and retries is not None and retries >= 0: retries += 1 attempt = 0 @@ -157,10 +156,7 @@ def retry_range(retries, force_retry=True): async def _maybe_await(value): - if inspect.isawaitable(value): - return await value - else: - return value + return await value if inspect.isawaitable(value) else value async def _cancel(log, **tasks): @@ -205,11 +201,7 @@ def _sync_enter(self): Helps to cut boilerplate on async context managers that offer synchronous variants. """ - if hasattr(self, 'loop'): - loop = self.loop - else: - loop = self._client.loop - + loop = self.loop if hasattr(self, 'loop') else self._client.loop if loop.is_running(): raise RuntimeError( 'You must use "async with" if the event loop ' @@ -220,11 +212,7 @@ def _sync_enter(self): def _sync_exit(self, *args): - if hasattr(self, 'loop'): - loop = self.loop - else: - loop = self._client.loop - + loop = self.loop if hasattr(self, 'loop') else self._client.loop return loop.run_until_complete(self.__aexit__(*args)) @@ -246,22 +234,24 @@ def _entity_type(entity): 0x1f4661b9, # crc32(b'UserFull') 0xd49a2697, # crc32(b'ChatFull') ): - raise TypeError('{} does not have any entity type'.format(entity)) + raise TypeError(f'{entity} does not have any entity type') except AttributeError: - raise TypeError('{} is not a TLObject, cannot determine entity type'.format(entity)) + raise TypeError(f'{entity} is not a TLObject, cannot determine entity type') name = entity.__class__.__name__ - if 'User' in name: + if ( + 'User' in name + or 'Chat' not in name + and 'Channel' not in name + and 'Self' in name + ): return _EntityType.USER elif 'Chat' in name: return _EntityType.CHAT elif 'Channel' in name: return _EntityType.CHANNEL - elif 'Self' in name: - return _EntityType.USER - # 'Empty' in name or not found, we don't care, not a valid entity. - raise TypeError('{} does not have any entity type'.format(entity)) + raise TypeError(f'{entity} does not have any entity type') # endregion @@ -314,12 +304,10 @@ def __init__(self, *args, **kwargs): self.total = 0 def __str__(self): - return '[{}, total={}]'.format( - ', '.join(str(x) for x in self), self.total) + return f"[{', '.join(str(x) for x in self)}, total={self.total}]" def __repr__(self): - return '[{}, total={}]'.format( - ', '.join(repr(x) for x in self), self.total) + return f"[{', '.join(repr(x) for x in self)}, total={self.total}]" class _FileStream(io.IOBase): @@ -425,10 +413,9 @@ def close(self, *args, **kwargs): # endregion def get_running_loop(): - if sys.version_info >= (3, 7): - try: - return asyncio.get_running_loop() - except RuntimeError: - return asyncio.get_event_loop_policy().get_event_loop() - else: + if sys.version_info < (3, 7): return asyncio.get_event_loop() + try: + return asyncio.get_running_loop() + except RuntimeError: + return asyncio.get_event_loop_policy().get_event_loop() diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index 6c3e30c12..a8d5aaee4 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -232,8 +232,8 @@ async def _connect(self): for attempt in retry_range(self._retries): if not connected: connected = await self._try_connect(attempt) - if not connected: - continue # skip auth key generation until we're connected + if not connected: + continue # skip auth key generation until we're connected if not self.auth_key: try: @@ -257,9 +257,9 @@ async def _connect(self): break # all steps done, break retry loop else: if not connected: - raise ConnectionError('Connection to Telegram failed {} time(s)'.format(self._retries)) + raise ConnectionError(f'Connection to Telegram failed {self._retries} time(s)') - e = ConnectionError('auth_key generation failed {} time(s)'.format(self._retries)) + e = ConnectionError(f'auth_key generation failed {self._retries} time(s)') await self._disconnect(error=e) raise e @@ -477,14 +477,13 @@ async def _send_loop(self): # so even if the network fails they won't be lost. If they were # never re-enqueued, the future waiting for a response "locks". for state in batch: - if not isinstance(state, list): - if isinstance(state.request, TLRequest): - self._pending_state[state.msg_id] = state - else: + if isinstance(state, list): for s in state: if isinstance(s.request, TLRequest): self._pending_state[s.msg_id] = s + elif isinstance(state.request, TLRequest): + self._pending_state[state.msg_id] = state try: await self._connection.send(data) except IOError as e: @@ -576,23 +575,17 @@ def _pop_states(self, msg_id): This method should be used when the response isn't specific. """ - state = self._pending_state.pop(msg_id, None) - if state: + if state := self._pending_state.pop(msg_id, None): return [state] - to_pop = [] - for state in self._pending_state.values(): - if state.container_id == msg_id: - to_pop.append(state.msg_id) - - if to_pop: + if to_pop := [ + state.msg_id + for state in self._pending_state.values() + if state.container_id == msg_id + ]: return [self._pending_state.pop(x) for x in to_pop] - for ack in self._last_acks: - if ack.msg_id == msg_id: - return [ack] - - return [] + return next(([ack] for ack in self._last_acks if ack.msg_id == msg_id), []) async def _handle_rpc_result(self, message): """ @@ -730,8 +723,7 @@ async def _handle_pong(self, message): if self._ping == pong.ping_id: self._ping = None - state = self._pending_state.pop(pong.msg_id, None) - if state: + if state := self._pending_state.pop(pong.msg_id, None): state.future.set_result(pong) async def _handle_bad_server_salt(self, message): @@ -856,8 +848,7 @@ async def _handle_future_salts(self, message): # TODO save these salts and automatically adjust to the # correct one whenever the salt in use expires. self._log.debug('Handling future salts for message %d', message.msg_id) - state = self._pending_state.pop(message.msg_id, None) - if state: + if state := self._pending_state.pop(message.msg_id, None): state.future.set_result(message.obj) async def _handle_state_forgotten(self, message): diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py index 5aed60397..e0e78c1b2 100644 --- a/telethon/sessions/memory.py +++ b/telethon/sessions/memory.py @@ -142,8 +142,7 @@ def _entities_to_rows(self, tlo): rows = [] # Rows to add (id, hash, username, phone, name) for e in entities: - row = self._entity_to_row(e) - if row: + if row := self._entity_to_row(e): rows.append(row) return rows @@ -176,14 +175,13 @@ def get_entity_rows_by_id(self, id, exact=True): if exact: return next((found_id, hash) for found_id, hash, _, _, _ in self._entities if found_id == id) - else: - ids = ( - utils.get_peer_id(PeerUser(id)), - utils.get_peer_id(PeerChat(id)), - utils.get_peer_id(PeerChannel(id)) - ) - return next((found_id, hash) for found_id, hash, _, _, _ - in self._entities if found_id in ids) + ids = ( + utils.get_peer_id(PeerUser(id)), + utils.get_peer_id(PeerChat(id)), + utils.get_peer_id(PeerChannel(id)) + ) + return next((found_id, hash) for found_id, hash, _, _, _ + in self._entities if found_id in ids) except StopIteration: pass @@ -205,17 +203,14 @@ def get_input_entity(self, key): result = None if isinstance(key, str): - phone = utils.parse_phone(key) - if phone: + if phone := utils.parse_phone(key): result = self.get_entity_rows_by_phone(phone) else: username, invite = utils.parse_username(key) if username and not invite: result = self.get_entity_rows_by_username(username) - else: - tup = utils.resolve_invite_link(key)[1] - if tup: - result = self.get_entity_rows_by_id(tup, exact=False) + elif tup := utils.resolve_invite_link(key)[1]: + result = self.get_entity_rows_by_id(tup, exact=False) elif isinstance(key, int): result = self.get_entity_rows_by_id(key, exact) @@ -223,22 +218,21 @@ def get_input_entity(self, key): if not result and isinstance(key, str): result = self.get_entity_rows_by_name(key) - if result: - entity_id, entity_hash = result # unpack resulting tuple - entity_id, kind = utils.resolve_id(entity_id) - # removes the mark and returns type of entity - if kind == PeerUser: - return InputPeerUser(entity_id, entity_hash) - elif kind == PeerChat: - return InputPeerChat(entity_id) - elif kind == PeerChannel: - return InputPeerChannel(entity_id, entity_hash) - else: + if not result: raise ValueError('Could not find input entity with key ', key) + entity_id, entity_hash = result # unpack resulting tuple + entity_id, kind = utils.resolve_id(entity_id) + # removes the mark and returns type of entity + if kind == PeerUser: + return InputPeerUser(entity_id, entity_hash) + elif kind == PeerChat: + return InputPeerChat(entity_id) + elif kind == PeerChannel: + return InputPeerChannel(entity_id, entity_hash) def cache_file(self, md5_digest, file_size, instance): if not isinstance(instance, (InputDocument, InputPhoto)): - raise TypeError('Cannot cache %s instance' % type(instance)) + raise TypeError(f'Cannot cache {type(instance)} instance') key = (md5_digest, file_size, _SentFileType.from_type(type(instance))) value = (instance.id, instance.access_hash) self._files[key] = value diff --git a/telethon/tl/custom/message.py b/telethon/tl/custom/message.py index 28fd7b87f..ae76edf56 100644 --- a/telethon/tl/custom/message.py +++ b/telethon/tl/custom/message.py @@ -471,8 +471,7 @@ def file(self): etc., without having to manually inspect the ``document.attributes``. """ if not self._file: - media = self.photo or self.document - if media: + if media := self.photo or self.document: self._file = File(media) return self._file @@ -1043,10 +1042,7 @@ def find_button(): if i is None: i = 0 - if j is None: - return self._buttons_flat[i] - else: - return self._buttons[i][j] + return self._buttons_flat[i] if j is None else self._buttons[i][j] button = find_button() if button: @@ -1148,10 +1144,10 @@ def _needed_markup_bot(self): if isinstance(button, types.KeyboardButtonSwitchInline): # no via_bot_id means the bot sent the message itself (#1619) if button.same_peer or not self.via_bot_id: - bot = self.input_sender - if not bot: + if bot := self.input_sender: + return bot + else: raise ValueError('No input sender') - return bot else: try: return self._client._mb_entity_cache.get( @@ -1164,12 +1160,9 @@ def _document_by_attribute(self, kind, condition=None): Helper method to return the document only if it has an attribute that's an instance of the given kind, and passes the condition. """ - doc = self.document - if doc: + if doc := self.document: for attr in doc.attributes: if isinstance(attr, kind): - if not condition or condition(attr): - return doc - return None + return doc if not condition or condition(attr) else None # endregion Private Methods diff --git a/telethon/utils.py b/telethon/utils.py index ab78d8a4f..ffe67a2a5 100644 --- a/telethon/utils.py +++ b/telethon/utils.py @@ -87,7 +87,7 @@ def get_display_name(entity): """ if isinstance(entity, types.User): if entity.last_name and entity.first_name: - return '{} {}'.format(entity.first_name, entity.last_name) + return f'{entity.first_name} {entity.last_name}' elif entity.first_name: return entity.first_name elif entity.last_name: @@ -128,8 +128,9 @@ def get_extension(media): def _raise_cast_fail(entity, target): - raise TypeError('Cannot cast {} to any kind of {}.'.format( - type(entity).__name__, target)) + raise TypeError( + f'Cannot cast {type(entity).__name__} to any kind of {target}.' + ) def get_input_peer(entity, allow_self=True, check_hash=True): @@ -469,18 +470,17 @@ def get_input_media( if isinstance(media, (types.InputFile, types.InputFileBig)): if is_photo: return types.InputMediaUploadedPhoto(file=media, ttl_seconds=ttl) - else: - attrs, mime = get_attributes( - media, - attributes=attributes, - force_document=force_document, - voice_note=voice_note, - video_note=video_note, - supports_streaming=supports_streaming - ) - return types.InputMediaUploadedDocument( - file=media, mime_type=mime, attributes=attrs, force_file=force_document, - ttl_seconds=ttl) + attrs, mime = get_attributes( + media, + attributes=attributes, + force_document=force_document, + voice_note=voice_note, + video_note=video_note, + supports_streaming=supports_streaming + ) + return types.InputMediaUploadedDocument( + file=media, mime_type=mime, attributes=attrs, force_file=force_document, + ttl_seconds=ttl) if isinstance(media, types.MessageMediaGame): return types.InputMediaGame(id=types.InputGameID( @@ -607,7 +607,7 @@ def get_message_id(message): except AttributeError: pass - raise TypeError('Invalid message type: {}'.format(type(message))) + raise TypeError(f'Invalid message type: {type(message)}') def _get_metadata(file): @@ -631,23 +631,18 @@ def _get_metadata(file): else: stream = file close_stream = False - if getattr(file, 'seekable', None): - seekable = file.seekable() - else: - seekable = False - + seekable = file.seekable() if getattr(file, 'seekable', None) else False if not seekable: return None pos = stream.tell() filename = getattr(file, 'name', '') - parser = hachoir.parser.guess.guessParser(hachoir.stream.InputIOStream( - stream, - source='file:' + filename, - tags=[], - filename=filename - )) + parser = hachoir.parser.guess.guessParser( + hachoir.stream.InputIOStream( + stream, source=f'file:{filename}', tags=[], filename=filename + ) + ) return hachoir.metadata.extractMetadata(parser) @@ -655,10 +650,12 @@ def _get_metadata(file): _log.warning('Failed to analyze %s: %s %s', file, e.__class__, e) finally: - if stream and close_stream: - stream.close() - elif stream and seekable: - stream.seek(pos) + if close_stream: + if stream: + stream.close() + elif seekable: + if stream: + stream.seek(pos) def get_attributes(file, *, attributes=None, mime_type=None, @@ -677,8 +674,7 @@ def get_attributes(file, *, attributes=None, mime_type=None, types.DocumentAttributeFilename(os.path.basename(name))} if is_audio(file): - m = _get_metadata(file) - if m: + if m := _get_metadata(file): if m.has('author'): performer = m.get('author') elif m.has('artist'): @@ -696,8 +692,7 @@ def get_attributes(file, *, attributes=None, mime_type=None, ) if not force_document and is_video(file): - m = _get_metadata(file) - if m: + if m := _get_metadata(file): doc = types.DocumentAttributeVideo( round_message=video_note, w=m.get('width') if m.has('width') else 1, @@ -708,13 +703,8 @@ def get_attributes(file, *, attributes=None, mime_type=None, ) elif thumb: t_m = _get_metadata(thumb) - width = 1 - height = 1 - if t_m and t_m.has("width"): - width = t_m.get("width") - if t_m and t_m.has("height"): - height = t_m.get("height") - + width = t_m.get("width") if t_m and t_m.has("width") else 1 + height = t_m.get("height") if t_m and t_m.has("height") else 1 doc = types.DocumentAttributeVideo( 0, width, height, round_message=video_note, supports_streaming=supports_streaming) @@ -776,9 +766,9 @@ def unparse(text, entities): 'html': html }[mode.lower()] except KeyError: - raise ValueError('Unknown parse mode {}'.format(mode)) + raise ValueError(f'Unknown parse mode {mode}') else: - raise TypeError('Invalid parse mode type {}'.format(mode)) + raise TypeError(f'Invalid parse mode type {mode}') def get_input_location(location): @@ -846,8 +836,9 @@ def is_image(file): """ Returns `True` if the file extension looks like an image file to Telegram. """ - match = re.match(r'\.(png|jpe?g)', _get_extension(file), re.IGNORECASE) - if match: + if match := re.match( + r'\.(png|jpe?g)', _get_extension(file), re.IGNORECASE + ): return True else: return isinstance(resolve_bot_file_id(file), types.Photo) @@ -862,30 +853,30 @@ def is_gif(file): def is_audio(file): """Returns `True` if the file has an audio mime type.""" - ext = _get_extension(file) - if not ext: - metadata = _get_metadata(file) - if metadata and metadata.has('mime_type'): - return metadata.get('mime_type').startswith('audio/') - else: - return False - else: - file = 'a' + ext + if ext := _get_extension(file): + file = f'a{ext}' return (mimetypes.guess_type(file)[0] or '').startswith('audio/') + else: + metadata = _get_metadata(file) + return ( + metadata.get('mime_type').startswith('audio/') + if metadata and metadata.has('mime_type') + else False + ) def is_video(file): """Returns `True` if the file has a video mime type.""" - ext = _get_extension(file) - if not ext: - metadata = _get_metadata(file) - if metadata and metadata.has('mime_type'): - return metadata.get('mime_type').startswith('video/') - else: - return False - else: - file = 'a' + ext + if ext := _get_extension(file): + file = f'a{ext}' return (mimetypes.guess_type(file)[0] or '').startswith('video/') + else: + metadata = _get_metadata(file) + return ( + metadata.get('mime_type').startswith('video/') + if metadata and metadata.has('mime_type') + else False + ) def is_list_like(obj): @@ -903,10 +894,9 @@ def parse_phone(phone): """Parses the given phone, or returns `None` if it's invalid.""" if isinstance(phone, int): return str(phone) - else: - phone = re.sub(r'[+()\s-]', '', str(phone)) - if phone.isdigit(): - return phone + phone = re.sub(r'[+()\s-]', '', str(phone)) + if phone.isdigit(): + return phone def parse_username(username): @@ -919,8 +909,7 @@ def parse_username(username): Returns ``(None, False)`` if the ``username`` or link is not valid. """ username = username.strip() - m = USERNAME_RE.match(username) or TG_JOIN_RE.match(username) - if m: + if m := USERNAME_RE.match(username) or TG_JOIN_RE.match(username): username = username[m.end():] is_invite = bool(m.group(1)) if is_invite: @@ -1028,11 +1017,7 @@ def get_peer_id(peer, add_mark=True): if not (0 < peer.channel_id <= 9999999999): peer.channel_id = resolve_id(peer.channel_id)[0] - if not add_mark: - return peer.channel_id - - # Growing backwards from -100_0000_000_000 indicates it's a channel - return -(1000000000000 + peer.channel_id) + return peer.channel_id if not add_mark else -(1000000000000 + peer.channel_id) def resolve_id(marked_id): @@ -1041,11 +1026,10 @@ def resolve_id(marked_id): return marked_id, types.PeerUser marked_id = -marked_id - if marked_id > 1000000000000: - marked_id -= 1000000000000 - return marked_id, types.PeerChannel - else: + if marked_id <= 1000000000000: return marked_id, types.PeerChat + marked_id -= 1000000000000 + return marked_id, types.PeerChannel def _rle_decode(data): @@ -1148,19 +1132,18 @@ def resolve_bot_file_id(file_id): return None attributes = [] - if file_type == 3 or file_type == 9: + if file_type in [3, 9]: attributes.append(types.DocumentAttributeAudio( duration=0, voice=file_type == 3 )) - elif file_type == 4 or file_type == 13: + elif file_type in [4, 13]: attributes.append(types.DocumentAttributeVideo( duration=0, w=0, h=0, round_message=file_type == 13 )) - # elif file_type == 5: # other, cannot know which elif file_type == 8: attributes.append(types.DocumentAttributeSticker( alt='', @@ -1180,7 +1163,11 @@ def resolve_bot_file_id(file_id): attributes=attributes, file_reference=b'' ) - elif (version == 2 and len(data) == 44) or (version == 4 and len(data) in (49, 77)): + elif ( + (version == 2 and len(data) == 44) + or version == 4 + and len(data) in {49, 77} + ): if version == 2: (file_type, dc_id, media_id, access_hash, volume_id, secret, local_id) = struct.unpack('LQ', payload)) elif len(payload) == 16: return struct.unpack('>LLQ', payload) - else: - pass except (struct.error, TypeError): pass return None, None, None @@ -1329,9 +1314,7 @@ def get_appropriated_part_size(file_size): """ if file_size <= 104857600: # 100MB return 128 - if file_size <= 786432000: # 750MB - return 256 - return 512 + return 256 if file_size <= 786432000 else 512 def encode_waveform(waveform): @@ -1505,10 +1488,7 @@ async def wrapper(*args, **kwargs): val = w(*args, **kwargs) return await val if inspect.isawaitable(val) else val - if callable(w): - return wrapper - else: - return w + return wrapper if callable(w) else w def stripped_photo_to_jpg(stripped): diff --git a/telethon_generator/generators/tlobject.py b/telethon_generator/generators/tlobject.py index f7ce8d4b1..bea4b071a 100644 --- a/telethon_generator/generators/tlobject.py +++ b/telethon_generator/generators/tlobject.py @@ -56,8 +56,8 @@ def _write_modules( # namespace_tlobjects: {'namespace', [TLObject]} out_dir.mkdir(parents=True, exist_ok=True) for ns, tlobjects in namespace_tlobjects.items(): - file = out_dir / '{}.py'.format(ns or '__init__') - with file.open('w') as f, SourceBuilder(f) as builder: + file = out_dir / f"{ns or '__init__'}.py" + with (file.open('w') as f, SourceBuilder(f) as builder): builder.writeln(AUTO_GEN_NOTICE) builder.writeln('from {}.tl.tlobject import TLObject', '.' * depth) @@ -88,9 +88,9 @@ def _write_modules( tlobjects.sort(key=lambda x: x.name) - type_names = set() type_defs = [] + type_names = set() # Find all the types in this file and generate type definitions # based on the types. The type definitions are written to the # file at the end. @@ -106,12 +106,11 @@ def _write_modules( if not constructors: pass elif len(constructors) == 1: - type_defs.append('Type{} = {}'.format( - type_name, constructors[0].class_name)) + type_defs.append(f'Type{type_name} = {constructors[0].class_name}') else: - type_defs.append('Type{} = Union[{}]'.format( - type_name, ','.join(c.class_name - for c in constructors))) + type_defs.append( + f"Type{type_name} = Union[{','.join(c.class_name for c in constructors)}]" + ) imports = {} primitives = {'int', 'long', 'int128', 'int256', 'double', @@ -124,11 +123,11 @@ def _write_modules( if not name or name in primitives: continue - import_space = '{}.tl.types'.format('.' * depth) + import_space = f"{'.' * depth}.tl.types" if '.' in name: namespace = name.split('.')[0] name = name.split('.')[1] - import_space += '.{}'.format(namespace) + import_space += f'.{namespace}' if name not in type_names: type_names.add(name) @@ -137,7 +136,7 @@ def _write_modules( continue elif import_space not in imports: imports[import_space] = set() - imports[import_space].add('Type{}'.format(name)) + imports[import_space].add(f'Type{name}') # Add imports required for type checking if imports: @@ -188,8 +187,8 @@ def _write_class_init(tlobject, kind, type_constructors, builder): builder.writeln() # Convert the args to string parameters, those with flag having =None - args = ['{}: {}{}'.format( - a.name, a.type_hint(), '=None' if a.flag or a.can_be_inferred else '') + args = [ + f"{a.name}: {a.type_hint()}{'=None' if a.flag or a.can_be_inferred else ''}" for a in tlobject.real_args ] @@ -224,12 +223,9 @@ def _write_class_init(tlobject, kind, type_constructors, builder): if not arg.can_be_inferred: builder.writeln('self.{0} = {0}', arg.name) - # Currently the only argument that can be - # inferred are those called 'random_id' elif arg.name == 'random_id': # Endianness doesn't really matter, and 'big' is shorter - code = "int.from_bytes(os.urandom({}), 'big', signed=True)" \ - .format(8 if arg.type == 'long' else 4) + code = f"int.from_bytes(os.urandom({8 if arg.type == 'long' else 4}), 'big', signed=True)" if arg.is_vector: # Currently for the case of "messages.forwardMessages" @@ -239,7 +235,7 @@ def _write_class_init(tlobject, kind, type_constructors, builder): raise ValueError( 'Cannot infer list of random ids for ', tlobject ) - code = '[{} for _ in range(len(id))]'.format(code) + code = f'[{code} for _ in range(len(id))]' builder.writeln( "self.random_id = random_id if random_id " @@ -252,36 +248,40 @@ def _write_class_init(tlobject, kind, type_constructors, builder): def _write_resolve(tlobject, builder): - if tlobject.is_function and any( - (arg.type in AUTO_CASTS - or ((arg.name, arg.type) in NAMED_AUTO_CASTS - and tlobject.fullname not in NAMED_BLACKLIST)) - for arg in tlobject.real_args + if not tlobject.is_function or not any( + ( + arg.type in AUTO_CASTS + or ( + (arg.name, arg.type) in NAMED_AUTO_CASTS + and tlobject.fullname not in NAMED_BLACKLIST + ) + ) + for arg in tlobject.real_args ): - builder.writeln('async def resolve(self, client, utils):') - for arg in tlobject.real_args: - ac = AUTO_CASTS.get(arg.type) - if not ac: - ac = NAMED_AUTO_CASTS.get((arg.name, arg.type)) - if not ac: - continue + return + builder.writeln('async def resolve(self, client, utils):') + for arg in tlobject.real_args: + ac = AUTO_CASTS.get(arg.type) + if not ac: + ac = NAMED_AUTO_CASTS.get((arg.name, arg.type)) + if not ac: + continue - if arg.flag: - builder.writeln('if self.{}:', arg.name) + if arg.flag: + builder.writeln('if self.{}:', arg.name) - if arg.is_vector: - builder.writeln('_tmp = []') - builder.writeln('for _x in self.{0}:', arg.name) - builder.writeln('_tmp.append({})', ac.format('_x')) - builder.end_block() - builder.writeln('self.{} = _tmp', arg.name) - else: - builder.writeln('self.{} = {}', arg.name, - ac.format('self.' + arg.name)) + if arg.is_vector: + builder.writeln('_tmp = []') + builder.writeln('for _x in self.{0}:', arg.name) + builder.writeln('_tmp.append({})', ac.format('_x')) + builder.end_block() + builder.writeln('self.{} = _tmp', arg.name) + else: + builder.writeln('self.{} = {}', arg.name, ac.format(f'self.{arg.name}')) - if arg.flag: - builder.end_block() - builder.end_block() + if arg.flag: + builder.end_block() + builder.end_block() def _write_to_dict(tlobject, builder): @@ -299,19 +299,18 @@ def _write_to_dict(tlobject, builder): arg.name) else: builder.write('self.{}', arg.name) + elif arg.is_vector: + builder.write( + '[] if self.{0} is None else [x.to_dict() ' + 'if isinstance(x, TLObject) else x for x in self.{0}]', + arg.name + ) else: - if arg.is_vector: - builder.write( - '[] if self.{0} is None else [x.to_dict() ' - 'if isinstance(x, TLObject) else x for x in self.{0}]', - arg.name - ) - else: - builder.write( - 'self.{0}.to_dict() ' - 'if isinstance(self.{0}, TLObject) else self.{0}', - arg.name - ) + builder.write( + 'self.{0}.to_dict() ' + 'if isinstance(self.{0}, TLObject) else self.{0}', + arg.name + ) builder.writeln() builder.current_indent -= 1 @@ -362,7 +361,7 @@ def _write_from_reader(tlobject, builder): builder.writeln('@classmethod') builder.writeln('def from_reader(cls, reader):') for arg in tlobject.args: - _write_arg_read_code(builder, arg, tlobject, name='_' + arg.name) + _write_arg_read_code(builder, arg, tlobject, name=f'_{arg.name}') builder.writeln('return cls({})', ', '.join( '{0}=_{0}'.format(a.name) for a in tlobject.real_args)) @@ -549,7 +548,7 @@ def _write_arg_read_code(builder, arg, tlobject, name): if arg.flag: # Treat 'true' flags as a special case, since they're true if # they're set, and nothing else needs to actually be read. - if 'true' == arg.type: + if arg.type == 'true': builder.writeln('{} = bool({} & {})', name, arg.flag, 1 << arg.flag_index) return @@ -577,7 +576,7 @@ def _write_arg_read_code(builder, arg, tlobject, name): builder.writeln('{} = reader.read_int()', arg.name) builder.writeln() - elif 'int' == arg.type: + elif arg.type == 'int': # User IDs are becoming larger than 2³¹ - 1, which would translate # into reading a negative ID, which we would treat as a chat. So # special case them to read unsigned. See https://t.me/BotNews/57. @@ -586,58 +585,57 @@ def _write_arg_read_code(builder, arg, tlobject, name): else: builder.writeln('{} = reader.read_int()', name) - elif 'long' == arg.type: + elif arg.type == 'long': builder.writeln('{} = reader.read_long()', name) - elif 'int128' == arg.type: + elif arg.type == 'int128': builder.writeln('{} = reader.read_large_int(bits=128)', name) - elif 'int256' == arg.type: + elif arg.type == 'int256': builder.writeln('{} = reader.read_large_int(bits=256)', name) - elif 'double' == arg.type: + elif arg.type == 'double': builder.writeln('{} = reader.read_double()', name) - elif 'string' == arg.type: + elif arg.type == 'string': builder.writeln('{} = reader.tgread_string()', name) - elif 'Bool' == arg.type: + elif arg.type == 'Bool': builder.writeln('{} = reader.tgread_bool()', name) - elif 'true' == arg.type: + elif arg.type == 'true': # Arbitrary not-None value, don't actually read "true" flags builder.writeln('{} = True', name) - elif 'bytes' == arg.type: + elif arg.type == 'bytes': builder.writeln('{} = reader.tgread_bytes()', name) - elif 'date' == arg.type: # Custom format + elif arg.type == 'date': # Custom format builder.writeln('{} = reader.tgread_date()', name) - else: - # Else it may be a custom type - if not arg.skip_constructor_id: - builder.writeln('{} = reader.tgread_object()', name) - else: - # Import the correct type inline to avoid cyclic imports. - # There may be better solutions so that we can just access - # all the types before the files have been parsed, but I - # don't know of any. - sep_index = arg.type.find('.') - if sep_index == -1: - ns, t = '.', arg.type - else: - ns, t = '.' + arg.type[:sep_index], arg.type[sep_index+1:] - class_name = snake_to_camel_case(t) - - # There would be no need to import the type if we're in the - # file with the same namespace, but since it does no harm - # and we don't have information about such thing in the - # method we just ignore that case. - builder.writeln('from {} import {}', ns, class_name) - builder.writeln('{} = {}.from_reader(reader)', - name, class_name) + elif arg.skip_constructor_id: + # Import the correct type inline to avoid cyclic imports. + # There may be better solutions so that we can just access + # all the types before the files have been parsed, but I + # don't know of any. + sep_index = arg.type.find('.') + ns, t = ( + ('.', arg.type) + if sep_index == -1 + else (f'.{arg.type[:sep_index]}', arg.type[sep_index + 1 :]) + ) + class_name = snake_to_camel_case(t) + + # There would be no need to import the type if we're in the + # file with the same namespace, but since it does no harm + # and we don't have information about such thing in the + # method we just ignore that case. + builder.writeln('from {} import {}', ns, class_name) + builder.writeln('{} = {}.from_reader(reader)', + name, class_name) + else: + builder.writeln('{} = reader.tgread_object()', name) # End vector and flag blocks if required (if we opened them before) if arg.is_vector: builder.end_block()