diff --git a/src/chat_sdk/__init__.py b/src/chat_sdk/__init__.py index fbf57dd..c1b6172 100644 --- a/src/chat_sdk/__init__.py +++ b/src/chat_sdk/__init__.py @@ -63,14 +63,6 @@ resolve_emoji_from_slack, ) from chat_sdk.errors import ChatError, ChatNotImplementedError, LockError, RateLimitError -from chat_sdk.shared.errors import ( - AdapterRateLimitError, - AuthenticationError, - NetworkError, - PermissionError as AdapterPermissionError, - ResourceNotFoundError, - ValidationError, -) from chat_sdk.from_full_stream import from_full_stream from chat_sdk.logger import ConsoleLogger, Logger, LogLevel from chat_sdk.message_history import MessageHistoryCache, MessageHistoryConfig @@ -95,6 +87,16 @@ text_input, ) from chat_sdk.shared.base_format_converter import BaseFormatConverter +from chat_sdk.shared.errors import ( + AdapterRateLimitError, + AuthenticationError, + NetworkError, + ResourceNotFoundError, + ValidationError, +) +from chat_sdk.shared.errors import ( + PermissionError as AdapterPermissionError, +) from chat_sdk.shared.streaming_markdown import StreamingMarkdownRenderer from chat_sdk.state.memory import MemoryStateAdapter from chat_sdk.thread import ThreadImpl diff --git a/src/chat_sdk/adapters/discord/adapter.py b/src/chat_sdk/adapters/discord/adapter.py index d8b72f7..ce4da92 100644 --- a/src/chat_sdk/adapters/discord/adapter.py +++ b/src/chat_sdk/adapters/discord/adapter.py @@ -13,7 +13,7 @@ import os import re from contextvars import ContextVar -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any from urllib.parse import quote @@ -275,7 +275,9 @@ async def _verify_signature( verify_key.verify(message, bytes.fromhex(signature)) return True except ImportError: - self._logger.error("PyNaCl is required for Discord signature verification. Install with: pip install PyNaCl") + self._logger.error( + "PyNaCl is required for Discord signature verification. Install with: pip install PyNaCl" + ) return False except Exception as exc: self._logger.warn("Discord signature verification failed", {"error": str(exc)}) @@ -320,7 +322,9 @@ def _handle_component_interaction( channel_type = channel.get("type", 0) is_thread = channel_type in (CHANNEL_TYPE_PUBLIC_THREAD, CHANNEL_TYPE_PRIVATE_THREAD) parent_channel_id = ( - channel.get("parent_id", interaction_channel_id) if is_thread and channel.get("parent_id") else interaction_channel_id + channel.get("parent_id", interaction_channel_id) + if is_thread and channel.get("parent_id") + else interaction_channel_id ) thread_id = self.encode_thread_id( @@ -391,7 +395,9 @@ def _handle_application_command_interaction( channel_type = channel.get("type", 0) is_thread = channel_type in (CHANNEL_TYPE_PUBLIC_THREAD, CHANNEL_TYPE_PRIVATE_THREAD) parent_channel_id = ( - channel.get("parent_id", interaction_channel_id) if is_thread and channel.get("parent_id") else interaction_channel_id + channel.get("parent_id", interaction_channel_id) + if is_thread and channel.get("parent_id") + else interaction_channel_id ) channel_id = self.encode_thread_id( @@ -534,7 +540,9 @@ async def _handle_forwarded_message( mentions = data.get("mentions", []) is_user_mentioned = data.get("is_mention", False) or any(m.get("id") == self._application_id for m in mentions) mention_roles = data.get("mention_roles", []) - is_role_mentioned = bool(self._mention_role_ids) and any(role_id in self._mention_role_ids for role_id in mention_roles) + is_role_mentioned = bool(self._mention_role_ids) and any( + role_id in self._mention_role_ids for role_id in mention_roles + ) is_mentioned = is_user_mentioned or is_role_mentioned # If mentioned and not in a thread, create one @@ -578,7 +586,7 @@ async def _handle_forwarded_message( metadata=MessageMetadata( date_sent=datetime.fromisoformat(data.get("timestamp", "")) if data.get("timestamp") - else datetime.now(timezone.utc), + else datetime.now(UTC), edited=False, ), attachments=[ @@ -1158,7 +1166,7 @@ def _parse_discord_message(self, raw: dict[str, Any], thread_id: str) -> Message is_me=is_me, ), metadata=MessageMetadata( - date_sent=datetime.fromisoformat(msg["timestamp"]) if msg.get("timestamp") else datetime.now(timezone.utc), + date_sent=datetime.fromisoformat(msg["timestamp"]) if msg.get("timestamp") else datetime.now(UTC), edited=msg.get("edited_timestamp") is not None, edited_at=datetime.fromisoformat(msg["edited_timestamp"]) if msg.get("edited_timestamp") else None, ), @@ -1203,7 +1211,7 @@ async def _create_discord_thread( message_id: str, ) -> dict[str, str]: """Create a Discord thread from a message.""" - thread_name = f"Thread {datetime.now(timezone.utc).isoformat()}" + thread_name = f"Thread {datetime.now(UTC).isoformat()}" self._logger.debug( "Discord API: POST thread", diff --git a/src/chat_sdk/adapters/github/adapter.py b/src/chat_sdk/adapters/github/adapter.py index f84a683..5fb2ab9 100644 --- a/src/chat_sdk/adapters/github/adapter.py +++ b/src/chat_sdk/adapters/github/adapter.py @@ -15,7 +15,7 @@ import re import time from collections.abc import AsyncIterable -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any from chat_sdk.adapters.github.cards import card_to_github_markdown @@ -148,7 +148,8 @@ def __init__(self, config: GitHubAdapterConfig | None = None) -> None: else: self._app_credentials = {"app_id": app_id, "private_key": private_key} self._logger.info( - "GitHub adapter initialized in multi-tenant mode (installation ID will be extracted from webhooks)" + "GitHub adapter initialized in multi-tenant mode " + "(installation ID will be extracted from webhooks)" ) else: raise ValidationError( @@ -366,7 +367,7 @@ def _parse_issue_comment( metadata=MessageMetadata( date_sent=datetime.fromisoformat(created_at.replace("Z", "+00:00")) if created_at - else datetime.now(tz=timezone.utc), + else datetime.now(tz=UTC), edited=edited, edited_at=datetime.fromisoformat(updated_at.replace("Z", "+00:00")) if edited and updated_at else None, ), @@ -408,7 +409,7 @@ def _parse_review_comment( metadata=MessageMetadata( date_sent=datetime.fromisoformat(created_at.replace("Z", "+00:00")) if created_at - else datetime.now(tz=timezone.utc), + else datetime.now(tz=UTC), edited=edited, edited_at=datetime.fromisoformat(updated_at.replace("Z", "+00:00")) if edited and updated_at else None, ), @@ -682,7 +683,10 @@ async def fetch_thread(self, thread_id: str) -> ThreadInfo: def encode_thread_id(self, platform_data: GitHubThreadId) -> str: """Encode platform data into a thread ID string.""" if platform_data.review_comment_id: - return f"github:{platform_data.owner}/{platform_data.repo}:{platform_data.pr_number}:rc:{platform_data.review_comment_id}" + return ( + f"github:{platform_data.owner}/{platform_data.repo}" + f":{platform_data.pr_number}:rc:{platform_data.review_comment_id}" + ) return f"github:{platform_data.owner}/{platform_data.repo}:{platform_data.pr_number}" def decode_thread_id(self, thread_id: str) -> GitHubThreadId: @@ -762,7 +766,7 @@ async def list_threads( metadata=MessageMetadata( date_sent=datetime.fromisoformat(pr.get("created_at", "").replace("Z", "+00:00")) if pr.get("created_at") - else datetime.now(tz=timezone.utc), + else datetime.now(tz=UTC), edited=pr.get("created_at") != pr.get("updated_at"), ), ) diff --git a/src/chat_sdk/adapters/github/cards.py b/src/chat_sdk/adapters/github/cards.py index 3c4b93f..33c2df2 100644 --- a/src/chat_sdk/adapters/github/cards.py +++ b/src/chat_sdk/adapters/github/cards.py @@ -152,7 +152,8 @@ def _render_text(text: dict[str, Any]) -> list[str]: def _render_fields(fields: dict[str, Any]) -> list[str]: """Render fields as key-value pairs.""" return [ - f"**{_escape_markdown(f.get('label', ''))}:** {_escape_markdown(f.get('value', ''))}" for f in fields.get("children", []) + f"**{_escape_markdown(f.get('label', ''))}:** {_escape_markdown(f.get('value', ''))}" + for f in fields.get("children", []) ] diff --git a/src/chat_sdk/adapters/github/types.py b/src/chat_sdk/adapters/github/types.py index 428814f..7a4d3a9 100644 --- a/src/chat_sdk/adapters/github/types.py +++ b/src/chat_sdk/adapters/github/types.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Literal, TypedDict, Union +from typing import Literal, TypedDict from chat_sdk.logger import Logger @@ -81,12 +81,9 @@ class GitHubAdapterAutoConfig(GitHubAdapterBaseConfig, total=False): # Union of all configuration types -GitHubAdapterConfig = Union[ - GitHubAdapterPATConfig, - GitHubAdapterAppConfig, - GitHubAdapterMultiTenantAppConfig, - GitHubAdapterAutoConfig, -] +GitHubAdapterConfig = ( + GitHubAdapterPATConfig | GitHubAdapterAppConfig | GitHubAdapterMultiTenantAppConfig | GitHubAdapterAutoConfig +) # ============================================================================= # Thread ID @@ -290,7 +287,7 @@ class GitHubRawReviewComment(TypedDict): pr_number: int -GitHubRawMessage = Union[GitHubRawIssueComment, GitHubRawReviewComment] +GitHubRawMessage = GitHubRawIssueComment | GitHubRawReviewComment # ============================================================================= # GitHub API Response Types diff --git a/src/chat_sdk/adapters/google_chat/adapter.py b/src/chat_sdk/adapters/google_chat/adapter.py index 8a2dec2..f0cfcd7 100644 --- a/src/chat_sdk/adapters/google_chat/adapter.py +++ b/src/chat_sdk/adapters/google_chat/adapter.py @@ -16,7 +16,7 @@ import re import time from collections.abc import AsyncIterable, Awaitable, Callable -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any from chat_sdk.adapters.google_chat.cards import card_to_google_card @@ -135,7 +135,9 @@ def __init__(self, config: GoogleChatAdapterConfig | None = None) -> None: self._pubsub_topic = config.pubsub_topic or os.environ.get("GOOGLE_CHAT_PUBSUB_TOPIC") self._impersonate_user = config.impersonate_user or os.environ.get("GOOGLE_CHAT_IMPERSONATE_USER") self._endpoint_url = config.endpoint_url - self._google_chat_project_number = config.google_chat_project_number or os.environ.get("GOOGLE_CHAT_PROJECT_NUMBER") + self._google_chat_project_number = config.google_chat_project_number or os.environ.get( + "GOOGLE_CHAT_PROJECT_NUMBER" + ) self._pubsub_audience = config.pubsub_audience or os.environ.get("GOOGLE_CHAT_PUBSUB_AUDIENCE") # In-progress subscription creations to prevent duplicate requests @@ -433,7 +435,9 @@ async def on_thread_subscribe(self, thread_id: str) -> None: ) if not self._pubsub_topic: - self._logger.warn("No pubsubTopic configured, skipping space subscription. Set GOOGLE_CHAT_PUBSUB_TOPIC env var.") + self._logger.warn( + "No pubsubTopic configured, skipping space subscription. Set GOOGLE_CHAT_PUBSUB_TOPIC env var." + ) return decoded = self.decode_thread_id(thread_id) @@ -470,7 +474,9 @@ async def _ensure_space_subscription(self, space_name: str) -> None: # Check if we already have a valid subscription cached = await self._state.get(cache_key) if cached: - expire_time = cached.get("expire_time", 0) if isinstance(cached, dict) else getattr(cached, "expire_time", 0) + expire_time = ( + cached.get("expire_time", 0) if isinstance(cached, dict) else getattr(cached, "expire_time", 0) + ) time_until_expiry = expire_time - int(time.time() * 1000) if time_until_expiry > SUBSCRIPTION_REFRESH_BUFFER_MS: self._logger.debug( @@ -2270,7 +2276,7 @@ async def list_threads( last_reply_at=last_reply_at, ) ) - count += 1 + count += 1 # noqa: SIM113 self._logger.debug( "GChat API: listThreads result", @@ -2677,9 +2683,9 @@ def _parse_message_metadata(message: dict[str, Any]) -> Any: try: date_sent = datetime.fromisoformat(create_time.replace("Z", "+00:00")) except (ValueError, AttributeError): - date_sent = datetime.now(tz=timezone.utc) + date_sent = datetime.now(tz=UTC) else: - date_sent = datetime.now(tz=timezone.utc) + date_sent = datetime.now(tz=UTC) return MessageMetadata( date_sent=date_sent, diff --git a/src/chat_sdk/adapters/google_chat/cards.py b/src/chat_sdk/adapters/google_chat/cards.py index fa3c7d1..36e1070 100644 --- a/src/chat_sdk/adapters/google_chat/cards.py +++ b/src/chat_sdk/adapters/google_chat/cards.py @@ -16,7 +16,8 @@ card_child_to_fallback_text, table_element_to_ascii, ) -from chat_sdk.shared import card_to_fallback_text as shared_card_to_fallback_text, create_emoji_converter +from chat_sdk.shared import card_to_fallback_text as shared_card_to_fallback_text +from chat_sdk.shared import create_emoji_converter # Convert emoji placeholders in text to GChat format (Unicode). convert_emoji = create_emoji_converter("gchat") diff --git a/src/chat_sdk/adapters/linear/adapter.py b/src/chat_sdk/adapters/linear/adapter.py index d12d384..6217b6c 100644 --- a/src/chat_sdk/adapters/linear/adapter.py +++ b/src/chat_sdk/adapters/linear/adapter.py @@ -15,7 +15,7 @@ import os import re import time -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any from chat_sdk.adapters.linear.cards import card_to_linear_markdown @@ -414,7 +414,7 @@ def _build_message( raw=LinearRawMessage(comment=comment), author=author, metadata=MessageMetadata( - date_sent=datetime.fromisoformat(created_at) if created_at else datetime.now(timezone.utc), + date_sent=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC), edited=created_at != updated_at, edited_at=datetime.fromisoformat(updated_at) if (created_at != updated_at and updated_at) else None, ), @@ -432,10 +432,7 @@ async def post_message( # Render message to markdown card = extract_card(message) - if card: - body = card_to_linear_markdown(card) - else: - body = self._format_converter.render_postable(message) + body = card_to_linear_markdown(card) if card else self._format_converter.render_postable(message) # Convert emoji placeholders to unicode body = convert_emoji_placeholders(body, "linear") @@ -496,10 +493,7 @@ async def edit_message( decoded = self.decode_thread_id(thread_id) card = extract_card(message) - if card: - body = card_to_linear_markdown(card) - else: - body = self._format_converter.render_postable(message) + body = card_to_linear_markdown(card) if card else self._format_converter.render_postable(message) body = convert_emoji_placeholders(body, "linear") @@ -754,7 +748,7 @@ def _comment_node_to_message( is_me=user_id == self._bot_user_id, ), metadata=MessageMetadata( - date_sent=datetime.fromisoformat(node["createdAt"]) if node.get("createdAt") else datetime.now(timezone.utc), + date_sent=datetime.fromisoformat(node["createdAt"]) if node.get("createdAt") else datetime.now(UTC), edited=node.get("createdAt") != node.get("updatedAt"), edited_at=( datetime.fromisoformat(node["updatedAt"]) @@ -854,7 +848,7 @@ def parse_message(self, raw: LinearRawMessage) -> Message: ), metadata=MessageMetadata( date_sent=( - datetime.fromisoformat(comment["created_at"]) if comment.get("created_at") else datetime.now(timezone.utc) + datetime.fromisoformat(comment["created_at"]) if comment.get("created_at") else datetime.now(UTC) ), edited=comment.get("created_at") != comment.get("updated_at"), edited_at=( diff --git a/src/chat_sdk/adapters/linear/cards.py b/src/chat_sdk/adapters/linear/cards.py index a7958f8..7f7d938 100644 --- a/src/chat_sdk/adapters/linear/cards.py +++ b/src/chat_sdk/adapters/linear/cards.py @@ -123,7 +123,8 @@ def _render_text(text: TextElement) -> list[str]: def _render_fields(fields: FieldsElement) -> list[str]: """Render fields as key-value pairs.""" return [ - f"**{_escape_markdown(f.get('label', ''))}:** {_escape_markdown(f.get('value', ''))}" for f in fields.get("children", []) + f"**{_escape_markdown(f.get('label', ''))}:** {_escape_markdown(f.get('value', ''))}" + for f in fields.get("children", []) ] diff --git a/src/chat_sdk/adapters/linear/types.py b/src/chat_sdk/adapters/linear/types.py index 561b95f..b585d14 100644 --- a/src/chat_sdk/adapters/linear/types.py +++ b/src/chat_sdk/adapters/linear/types.py @@ -66,7 +66,9 @@ class LinearAdapterAppConfig(LinearAdapterBaseConfig): # Union type for all config options -LinearAdapterConfig = LinearAdapterBaseConfig | LinearAdapterAPIKeyConfig | LinearAdapterOAuthConfig | LinearAdapterAppConfig +LinearAdapterConfig = ( + LinearAdapterBaseConfig | LinearAdapterAPIKeyConfig | LinearAdapterOAuthConfig | LinearAdapterAppConfig +) # ============================================================================= diff --git a/src/chat_sdk/adapters/slack/adapter.py b/src/chat_sdk/adapters/slack/adapter.py index 480ca2c..7482e11 100644 --- a/src/chat_sdk/adapters/slack/adapter.py +++ b/src/chat_sdk/adapters/slack/adapter.py @@ -18,7 +18,7 @@ import time from collections.abc import AsyncIterable, Awaitable, Callable from contextvars import ContextVar -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any from urllib.parse import parse_qs @@ -164,7 +164,9 @@ class SlackAdapter: def __init__(self, config: SlackAdapterConfig | None = None) -> None: # ContextVar replaces Node AsyncLocalStorage for per-request token context. # Created per-instance so multiple SlackAdapter instances don't share state. - self._request_context: ContextVar[RequestContext | None] = ContextVar(f"slack_request_context_{id(self)}", default=None) + self._request_context: ContextVar[RequestContext | None] = ContextVar( + f"slack_request_context_{id(self)}", default=None + ) if config is None: config = SlackAdapterConfig() @@ -197,7 +199,9 @@ def __init__(self, config: SlackAdapterConfig | None = None) -> None: # Multi-workspace OAuth fields self._client_id: str | None = config.client_id or (os.environ.get("SLACK_CLIENT_ID") if zero_config else None) - self._client_secret: str | None = config.client_secret or (os.environ.get("SLACK_CLIENT_SECRET") if zero_config else None) + self._client_secret: str | None = config.client_secret or ( + os.environ.get("SLACK_CLIENT_SECRET") if zero_config else None + ) self._installation_key_prefix = config.installation_key_prefix or "slack:installation" encryption_key_raw = config.encryption_key or os.environ.get("SLACK_ENCRYPTION_KEY") @@ -387,7 +391,9 @@ async def handle_oauth_callback(self, request: Any) -> dict[str, Any]: query = dict(parse_qs(url.split("?", 1)[1])) code = query.get("code", [None])[0] if isinstance(query.get("code"), list) else query.get("code") redirect_uri = ( - query.get("redirect_uri", [None])[0] if isinstance(query.get("redirect_uri"), list) else query.get("redirect_uri") + query.get("redirect_uri", [None])[0] + if isinstance(query.get("redirect_uri"), list) + else query.get("redirect_uri") ) else: code = None @@ -507,7 +513,11 @@ async def _lookup_user(self, user_id: str) -> dict[str, str]: profile = user.get("profile", {}) display_name = ( - profile.get("display_name") or profile.get("real_name") or user.get("real_name") or user.get("name") or user_id + profile.get("display_name") + or profile.get("real_name") + or user.get("real_name") + or user.get("name") + or user_id ) real_name = user.get("real_name") or profile.get("real_name") or display_name @@ -814,7 +824,9 @@ def _handle_block_actions(self, payload: dict[str, Any], options: WebhookOptions channel = (payload.get("channel") or {}).get("id") or (payload.get("container") or {}).get("channel_id") message_ts = (payload.get("message") or {}).get("ts") or (payload.get("container") or {}).get("message_ts") thread_ts = ( - (payload.get("message") or {}).get("thread_ts") or (payload.get("container") or {}).get("thread_ts") or message_ts + (payload.get("message") or {}).get("thread_ts") + or (payload.get("container") or {}).get("thread_ts") + or message_ts ) is_view_action = (payload.get("container") or {}).get("type") == "view" @@ -872,7 +884,9 @@ def _handle_block_actions(self, payload: dict[str, Any], options: WebhookOptions # View submission / close # ================================================================== - async def _handle_view_submission(self, payload: dict[str, Any], options: WebhookOptions | None = None) -> dict[str, Any]: + async def _handle_view_submission( + self, payload: dict[str, Any], options: WebhookOptions | None = None + ) -> dict[str, Any]: if not self._chat: self._logger.warn("Chat instance not initialized, ignoring view submission") return {"body": "", "status": 200} @@ -884,7 +898,9 @@ async def _handle_view_submission(self, payload: dict[str, Any], options: Webhoo values: dict[str, str] = {} for block_values in state_values.values(): for action_id, input_val in block_values.items(): - values[action_id] = input_val.get("value") or (input_val.get("selected_option") or {}).get("value") or "" + values[action_id] = ( + input_val.get("value") or (input_val.get("selected_option") or {}).get("value") or "" + ) meta = decode_modal_metadata(view.get("private_metadata") or None) user_ref = payload.get("user", {}) @@ -1077,7 +1093,9 @@ async def _resolve_and_process() -> None: except RuntimeError: return # No running event loop task.add_done_callback( - lambda t: self._logger.error("Reaction resolve error", {"error": str(t.exception())}) if t.exception() else None + lambda t: ( + self._logger.error("Reaction resolve error", {"error": str(t.exception())}) if t.exception() else None + ) ) if options and options.wait_until: options.wait_until(task) @@ -1399,7 +1417,9 @@ def replace_mention(match: re.Match[str]) -> str: return mention_pattern.sub(replace_mention, text) - async def _resolve_message_mentions(self, message: AdapterPostableMessage, thread_id: str) -> AdapterPostableMessage: + async def _resolve_message_mentions( + self, message: AdapterPostableMessage, thread_id: str + ) -> AdapterPostableMessage: """Pre-process outgoing message to resolve @name mentions.""" if not self._chat: return message @@ -1507,14 +1527,14 @@ async def _parse_slack_message( ts_str = event.get("ts", "0") try: - date_sent = datetime.fromtimestamp(float(ts_str), tz=timezone.utc) + date_sent = datetime.fromtimestamp(float(ts_str), tz=UTC) except (ValueError, TypeError, OSError): - date_sent = datetime.now(tz=timezone.utc) + date_sent = datetime.now(tz=UTC) edited_at: datetime | None = None if event.get("edited"): try: - edited_at = datetime.fromtimestamp(float(event["edited"].get("ts", "0")), tz=timezone.utc) + edited_at = datetime.fromtimestamp(float(event["edited"].get("ts", "0")), tz=UTC) except (ValueError, TypeError, OSError): edited_at = None @@ -1549,14 +1569,14 @@ def _parse_slack_message_sync(self, event: dict[str, Any], thread_id: str) -> Me ts_str = event.get("ts", "0") try: - date_sent = datetime.fromtimestamp(float(ts_str), tz=timezone.utc) + date_sent = datetime.fromtimestamp(float(ts_str), tz=UTC) except (ValueError, TypeError, OSError): - date_sent = datetime.now(tz=timezone.utc) + date_sent = datetime.now(tz=UTC) edited_at: datetime | None = None if event.get("edited"): try: - edited_at = datetime.fromtimestamp(float(event["edited"].get("ts", "0")), tz=timezone.utc) + edited_at = datetime.fromtimestamp(float(event["edited"].get("ts", "0")), tz=UTC) except (ValueError, TypeError, OSError): edited_at = None @@ -2296,8 +2316,12 @@ async def fetch_messages(self, thread_id: str, options: FetchOptions | None = No try: if direction == "forward": - return await self._fetch_messages_forward(channel, thread_ts, thread_id, limit, getattr(opts, "cursor", None)) - return await self._fetch_messages_backward(channel, thread_ts, thread_id, limit, getattr(opts, "cursor", None)) + return await self._fetch_messages_forward( + channel, thread_ts, thread_id, limit, getattr(opts, "cursor", None) + ) + return await self._fetch_messages_backward( + channel, thread_ts, thread_id, limit, getattr(opts, "cursor", None) + ) except Exception as error: self._handle_slack_error(error) @@ -2355,7 +2379,9 @@ async def fetch_message(self, thread_id: str, message_id: str) -> Message | None try: client = self._get_client() - result = await client.conversations_replies(channel=channel, ts=thread_ts, oldest=message_id, inclusive=True, limit=1) + result = await client.conversations_replies( + channel=channel, ts=thread_ts, oldest=message_id, inclusive=True, limit=1 + ) messages = result.get("messages", []) target = next((m for m in messages if m.get("ts") == message_id), None) if not target: @@ -2450,7 +2476,9 @@ async def _fetch_channel_messages_forward(self, channel: str, limit: int, cursor return FetchResult(messages=list(messages), next_cursor=next_cursor) - async def _fetch_channel_messages_backward(self, channel: str, limit: int, cursor: str | None = None) -> FetchResult: + async def _fetch_channel_messages_backward( + self, channel: str, limit: int, cursor: str | None = None + ) -> FetchResult: client = self._get_client() kwargs: dict[str, Any] = {"channel": channel, "limit": limit} if cursor: @@ -2510,7 +2538,7 @@ async def list_threads(self, channel_id: str, options: ListThreadsOptions | None last_reply_at: datetime | None = None if msg.get("latest_reply"): try: - last_reply_at = datetime.fromtimestamp(float(msg["latest_reply"]), tz=timezone.utc) + last_reply_at = datetime.fromtimestamp(float(msg["latest_reply"]), tz=UTC) except (ValueError, TypeError, OSError): last_reply_at = None diff --git a/src/chat_sdk/adapters/slack/cards.py b/src/chat_sdk/adapters/slack/cards.py index 66c0466..2630f4d 100644 --- a/src/chat_sdk/adapters/slack/cards.py +++ b/src/chat_sdk/adapters/slack/cards.py @@ -27,7 +27,8 @@ table_element_to_ascii, ) from chat_sdk.modals import SelectElement -from chat_sdk.shared import card_to_fallback_text as shared_card_to_fallback_text, create_emoji_converter, map_button_style +from chat_sdk.shared import card_to_fallback_text as shared_card_to_fallback_text +from chat_sdk.shared import create_emoji_converter, map_button_style # Type aliases for Slack Block Kit structures SlackBlock = dict[str, Any] @@ -364,7 +365,10 @@ def convert_fields_to_block(element: FieldsElement) -> SlackBlock: fields.append( { "type": "mrkdwn", - "text": f"*{_markdown_to_mrkdwn(convert_emoji(f['label']))}*\n{_markdown_to_mrkdwn(convert_emoji(f['value']))}", + "text": ( + f"*{_markdown_to_mrkdwn(convert_emoji(f['label']))}*" + f"\n{_markdown_to_mrkdwn(convert_emoji(f['value']))}" + ), } ) diff --git a/src/chat_sdk/adapters/teams/adapter.py b/src/chat_sdk/adapters/teams/adapter.py index 8fe35a3..b81c462 100644 --- a/src/chat_sdk/adapters/teams/adapter.py +++ b/src/chat_sdk/adapters/teams/adapter.py @@ -9,13 +9,11 @@ from __future__ import annotations import base64 -import hmac import json import os import re -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any -from urllib.parse import urlparse from chat_sdk.adapters.teams.cards import card_to_adaptive_card from chat_sdk.adapters.teams.format_converter import TeamsFormatConverter @@ -32,9 +30,11 @@ AdapterRateLimitError, AuthenticationError, NetworkError, - PermissionError as AdapterPermissionError, ValidationError, ) +from chat_sdk.shared.errors import ( + PermissionError as AdapterPermissionError, +) from chat_sdk.types import ( ActionEvent, AdapterPostableMessage, @@ -71,9 +71,7 @@ ] # Bot Framework OpenID configuration URL for JWT verification -BOT_FRAMEWORK_OPENID_CONFIG_URL = ( - "https://login.botframework.com/v1/.well-known/openid-configuration" -) +BOT_FRAMEWORK_OPENID_CONFIG_URL = "https://login.botframework.com/v1/.well-known/openid-configuration" def _validate_service_url(url: str) -> None: @@ -98,14 +96,18 @@ def _handle_teams_error(error: Any, operation: str) -> None: """ if error and isinstance(error, dict): inner_error = error.get("innerHttpError", {}) - status_code = inner_error.get("statusCode") or error.get("statusCode") or error.get("status") or error.get("code") + status_code = ( + inner_error.get("statusCode") or error.get("statusCode") or error.get("status") or error.get("code") + ) if status_code == 401: raise AuthenticationError( "teams", f"Authentication failed for {operation}: {error.get('message', 'unauthorized')}", ) - if status_code == 403 or (isinstance(error.get("message"), str) and "permission" in error.get("message", "").lower()): + if status_code == 403 or ( + isinstance(error.get("message"), str) and "permission" in error.get("message", "").lower() + ): raise AdapterPermissionError("teams", operation) if status_code == 404: raise NetworkError( @@ -501,7 +503,7 @@ def _parse_teams_message( metadata=MessageMetadata( date_sent=datetime.fromisoformat(activity["timestamp"]) if activity.get("timestamp") - else datetime.now(timezone.utc), + else datetime.now(UTC), edited=False, ), attachments=attachments, @@ -532,9 +534,7 @@ def _is_message_from_self(self, activity: dict[str, Any]) -> bool: return False if from_id == self._app_id: return True - if from_id.endswith(f":{self._app_id}"): - return True - return False + return bool(from_id.endswith(f":{self._app_id}")) async def post_message( self, @@ -780,7 +780,9 @@ def encode_thread_id(self, platform_data: TeamsThreadId) -> str: encoded_conversation_id = ( base64.urlsafe_b64encode(platform_data.conversation_id.encode("utf-8")).decode("ascii").rstrip("=") ) - encoded_service_url = base64.urlsafe_b64encode(platform_data.service_url.encode("utf-8")).decode("ascii").rstrip("=") + encoded_service_url = ( + base64.urlsafe_b64encode(platform_data.service_url.encode("utf-8")).decode("ascii").rstrip("=") + ) return f"teams:{encoded_conversation_id}:{encoded_service_url}" def decode_thread_id(self, thread_id: str) -> TeamsThreadId: @@ -1396,7 +1398,7 @@ def _map_graph_message(self, msg: dict[str, Any], thread_id: str) -> Message: ), metadata=MessageMetadata( date_sent=( - datetime.fromisoformat(msg["createdDateTime"]) if msg.get("createdDateTime") else datetime.now(timezone.utc) + datetime.fromisoformat(msg["createdDateTime"]) if msg.get("createdDateTime") else datetime.now(UTC) ), edited=bool(msg.get("lastModifiedDateTime")), ), @@ -1446,7 +1448,7 @@ def _extract_card_title(self, card: Any) -> str | None: # First pass: look for prominent text blocks for element in body: - if isinstance(element, dict) and element.get("type") == "TextBlock": + if isinstance(element, dict) and element.get("type") == "TextBlock": # noqa: SIM102 if element.get("weight") == "bolder" or element.get("size") in ("large", "extraLarge"): text = element.get("text") if isinstance(text, str): @@ -1673,7 +1675,7 @@ async def _verify_bot_framework_token(self, request: Any) -> Any | None: if self._jwks_client is None: import aiohttp - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession() as session: # noqa: SIM117 async with session.get(BOT_FRAMEWORK_OPENID_CONFIG_URL) as resp: if resp.status != 200: self._logger.error("Failed to fetch Bot Framework OpenID config", {"status": resp.status}) diff --git a/src/chat_sdk/adapters/telegram/adapter.py b/src/chat_sdk/adapters/telegram/adapter.py index 1f660ab..7b720b2 100644 --- a/src/chat_sdk/adapters/telegram/adapter.py +++ b/src/chat_sdk/adapters/telegram/adapter.py @@ -17,7 +17,7 @@ import os import re from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any from chat_sdk.adapters.telegram.cards import ( @@ -53,10 +53,12 @@ AdapterRateLimitError, AuthenticationError, NetworkError, - PermissionError as AdapterPermissionError, ResourceNotFoundError, ValidationError, ) +from chat_sdk.shared.errors import ( + PermissionError as AdapterPermissionError, +) from chat_sdk.types import ( ActionEvent, AdapterPostableMessage, @@ -926,7 +928,7 @@ async def edit_message( metadata=MessageMetadata( date_sent=existing.metadata.date_sent, edited=True, - edited_at=datetime.now(timezone.utc), + edited_at=datetime.now(UTC), ), attachments=existing.attachments, is_mention=existing.is_mention, @@ -1237,9 +1239,9 @@ def parse_telegram_message( raw=raw, author=author, metadata=MessageMetadata( - date_sent=datetime.fromtimestamp(raw["date"], tz=timezone.utc), + date_sent=datetime.fromtimestamp(raw["date"], tz=UTC), edited=edit_date is not None, - edited_at=(datetime.fromtimestamp(edit_date, tz=timezone.utc) if edit_date is not None else None), + edited_at=(datetime.fromtimestamp(edit_date, tz=UTC) if edit_date is not None else None), ), attachments=self.extract_attachments(raw), is_mention=self.is_bot_mentioned(raw, plain_text), diff --git a/src/chat_sdk/adapters/whatsapp/adapter.py b/src/chat_sdk/adapters/whatsapp/adapter.py index 3228b0a..f12a46b 100644 --- a/src/chat_sdk/adapters/whatsapp/adapter.py +++ b/src/chat_sdk/adapters/whatsapp/adapter.py @@ -15,7 +15,7 @@ import os import time from collections.abc import AsyncIterable -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any from urllib.parse import parse_qs, urlparse @@ -508,7 +508,7 @@ def _build_message( metadata=MessageMetadata( date_sent=datetime.fromtimestamp( int(inbound.get("timestamp", "0")), - tz=timezone.utc, + tz=UTC, ), edited=False, ), @@ -647,14 +647,14 @@ async def download_media(self, media_id: str) -> bytes: ) host = (parsed.hostname or "").lower() allowed_suffixes = ( - ".facebook.com", ".fbcdn.net", ".fbsbx.com", - ".whatsapp.net", ".whatsapp.com", + ".facebook.com", + ".fbcdn.net", + ".fbsbx.com", + ".whatsapp.net", + ".whatsapp.com", ) allowed_exact = {"facebook.com", "fbcdn.net", "fbsbx.com", "whatsapp.net", "whatsapp.com"} - if not ( - any(host.endswith(s) for s in allowed_suffixes) - or host in allowed_exact - ): + if not (any(host.endswith(s) for s in allowed_suffixes) or host in allowed_exact): raise ValidationError( "whatsapp", f"Media download URL host is not an allowed Meta domain: {host}", @@ -803,7 +803,9 @@ async def edit_message( message: AdapterPostableMessage, ) -> RawMessage: """Edit a message. Not supported by WhatsApp Cloud API.""" - raise RuntimeError("WhatsApp does not support editing messages. Use post_message to send a new message instead.") + raise RuntimeError( + "WhatsApp does not support editing messages. Use post_message to send a new message instead." + ) async def stream( self, @@ -940,7 +942,9 @@ def parse_message(self, raw: WhatsAppRawMessage) -> Message: contact = raw.get("contact") contact_name = "" - contact_name = contact.get("profile", {}).get("name", "") or raw["message"]["from"] if contact else raw["message"]["from"] + contact_name = ( + contact.get("profile", {}).get("name", "") or raw["message"]["from"] if contact else raw["message"]["from"] + ) return Message( id=raw["message"]["id"], @@ -957,7 +961,7 @@ def parse_message(self, raw: WhatsAppRawMessage) -> Message: metadata=MessageMetadata( date_sent=datetime.fromtimestamp( int(raw["message"].get("timestamp", "0")), - tz=timezone.utc, + tz=UTC, ), edited=False, ), diff --git a/src/chat_sdk/adapters/whatsapp/cards.py b/src/chat_sdk/adapters/whatsapp/cards.py index 32b022c..b00ffb3 100644 --- a/src/chat_sdk/adapters/whatsapp/cards.py +++ b/src/chat_sdk/adapters/whatsapp/cards.py @@ -12,7 +12,7 @@ from __future__ import annotations import json -from typing import Any, Literal, TypedDict, Union +from typing import Any, Literal, TypedDict from chat_sdk.adapters.whatsapp.types import WhatsAppInteractiveMessage from chat_sdk.cards import ( @@ -55,7 +55,7 @@ class WhatsAppCardResultText(TypedDict): text: str -WhatsAppCardResult = Union[WhatsAppCardResultInteractive, WhatsAppCardResultText] +WhatsAppCardResult = WhatsAppCardResultInteractive | WhatsAppCardResultText def encode_whatsapp_callback_data(action_id: str, value: str | None = None) -> str: diff --git a/src/chat_sdk/cards.py b/src/chat_sdk/cards.py index a195f97..2df3789 100644 --- a/src/chat_sdk/cards.py +++ b/src/chat_sdk/cards.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Literal, TypedDict, Union +from typing import Any, Literal, TypedDict # Button style options ButtonStyle = Literal["primary", "danger", "default"] @@ -123,16 +123,16 @@ class SectionElement(TypedDict): # Union of all card child element types -CardChild = Union[ - TextElement, - ImageElement, - DividerElement, - ActionsElement, - SectionElement, - FieldsElement, - LinkElement, - TableElement, -] +CardChild = ( + TextElement + | ImageElement + | DividerElement + | ActionsElement + | SectionElement + | FieldsElement + | LinkElement + | TableElement +) class CardElement(TypedDict, total=False): diff --git a/src/chat_sdk/channel.py b/src/chat_sdk/channel.py index e847fe9..9ed1529 100644 --- a/src/chat_sdk/channel.py +++ b/src/chat_sdk/channel.py @@ -9,7 +9,7 @@ from collections.abc import AsyncIterator from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any from chat_sdk.errors import ChatNotImplementedError @@ -424,7 +424,7 @@ async def _remove_reaction(emoji: EmojiValue | str) -> None: is_me=True, ), metadata=MessageMetadata( - date_sent=datetime.now(tz=timezone.utc), + date_sent=datetime.now(tz=UTC), edited=False, ), attachments=attachments, diff --git a/src/chat_sdk/chat.py b/src/chat_sdk/chat.py index 313641c..661651f 100644 --- a/src/chat_sdk/chat.py +++ b/src/chat_sdk/chat.py @@ -13,7 +13,7 @@ import uuid from collections.abc import Awaitable, Callable from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any from chat_sdk.channel import ChannelImpl @@ -731,7 +731,9 @@ async def _task() -> None: if task is not None: task.add_done_callback( lambda t: ( - self._logger.error("Message processing error", {"thread_id": thread_id, "error": str(t.exception())}) + self._logger.error( + "Message processing error", {"thread_id": thread_id, "error": str(t.exception())} + ) if not t.cancelled() and t.exception() else None ) @@ -1113,7 +1115,7 @@ async def _handle_action_event(self, event: ActionEvent) -> None: formatted={"type": "root", "children": []}, raw=event.raw, author=event.user, - metadata=MessageMetadata(date_sent=datetime.now(tz=timezone.utc), edited=False), + metadata=MessageMetadata(date_sent=datetime.now(tz=UTC), edited=False), attachments=[], ) thread = self._create_thread(event.adapter, event.thread_id, dummy_message, is_subscribed) @@ -1145,7 +1147,7 @@ async def _open_modal(modal: Any) -> dict[str, str] | None: metadata=getattr( raw_fetched, "metadata", - MessageMetadata(date_sent=datetime.now(tz=timezone.utc), edited=False), + MessageMetadata(date_sent=datetime.now(tz=UTC), edited=False), ), ) except Exception: @@ -1223,7 +1225,7 @@ async def _handle_reaction_event(self, event: ReactionEvent) -> None: formatted={"type": "root", "children": []}, raw=None, author=event.user, - metadata=MessageMetadata(date_sent=datetime.now(tz=timezone.utc), edited=False), + metadata=MessageMetadata(date_sent=datetime.now(tz=UTC), edited=False), ), is_subscribed, ) @@ -1252,7 +1254,8 @@ async def _handle_reaction_event(self, event: ReactionEvent) -> None: filt is full_event.emoji or (isinstance(filt, str) and (filt == full_event.emoji.name or filt == full_event.raw_emoji)) or ( - isinstance(filt, EmojiValue) and (filt.name == full_event.emoji.name or filt.name == full_event.raw_emoji) + isinstance(filt, EmojiValue) + and (filt.name == full_event.emoji.name or filt.name == full_event.raw_emoji) ) ) for filt in pat.emoji @@ -1283,7 +1286,7 @@ async def open_dm(self, user: str | Author) -> ThreadImpl: formatted={"type": "root", "children": []}, raw=None, author=Author(user_id="", user_name="", full_name="", is_bot=False, is_me=False), - metadata=MessageMetadata(date_sent=datetime.now(tz=timezone.utc), edited=False), + metadata=MessageMetadata(date_sent=datetime.now(tz=UTC), edited=False), ), False, ) @@ -1347,7 +1350,11 @@ async def _get_lock_key(self, adapter: Adapter, thread_id: str) -> str: scope: LockScope if callable(self._lock_scope_config): - is_dm = adapter.is_dm(thread_id) if hasattr(adapter, "is_dm") and callable(getattr(adapter, "is_dm", None)) else False # type: ignore[union-attr] + is_dm = ( + adapter.is_dm(thread_id) + if hasattr(adapter, "is_dm") and callable(getattr(adapter, "is_dm", None)) + else False + ) # type: ignore[union-attr] scope = await self._lock_scope_config( LockScopeContext( adapter=adapter, @@ -1480,7 +1487,7 @@ async def _handle_queue_or_debounce( ) return - now = int(datetime.now(tz=timezone.utc).timestamp() * 1000) + now = int(datetime.now(tz=UTC).timestamp() * 1000) entry = QueueEntry( message=message, enqueued_at=now, @@ -1498,7 +1505,7 @@ async def _handle_queue_or_debounce( try: if strategy == "debounce": - now = int(datetime.now(tz=timezone.utc).timestamp() * 1000) + now = int(datetime.now(tz=UTC).timestamp() * 1000) await self._state_adapter.enqueue( lock_key, QueueEntry(message=message, enqueued_at=now, expires_at=now + queue_entry_ttl_ms), @@ -1555,7 +1562,7 @@ async def _debounce_loop( break msg = self._rehydrate_message(entry.message) - now = int(datetime.now(tz=timezone.utc).timestamp() * 1000) + now = int(datetime.now(tz=UTC).timestamp() * 1000) if now > entry.expires_at: self._logger.info("message-expired", {"thread_id": thread_id, "message_id": msg.id}) continue @@ -1585,7 +1592,7 @@ async def _drain_queue( if entry is None: break msg = self._rehydrate_message(entry.message) - now = int(datetime.now(tz=timezone.utc).timestamp() * 1000) + now = int(datetime.now(tz=UTC).timestamp() * 1000) if now <= entry.expires_at: pending.append((msg, entry.expires_at)) else: @@ -1596,7 +1603,9 @@ async def _drain_queue( extended = await self._state_adapter.extend_lock(lock, DEFAULT_LOCK_TTL_MS) if not extended: - self._logger.warn("Lock lost during drain processing, aborting", {"thread_id": thread_id, "lock_key": lock_key}) + self._logger.warn( + "Lock lost during drain processing, aborting", {"thread_id": thread_id, "lock_key": lock_key} + ) return latest_msg, _ = pending[-1] @@ -1788,7 +1797,7 @@ def _rehydrate_message(self, raw: Any) -> Message: if isinstance(date_sent, str): date_sent = datetime.fromisoformat(date_sent) elif not isinstance(date_sent, datetime): - date_sent = datetime.now(tz=timezone.utc) + date_sent = datetime.now(tz=UTC) edited_at = metadata_raw.get("edited_at") if isinstance(edited_at, str): @@ -1849,7 +1858,7 @@ def _message_from_json(data: dict[str, Any]) -> Message: if isinstance(date_sent, str): date_sent = datetime.fromisoformat(date_sent) elif not isinstance(date_sent, datetime): - date_sent = datetime.now(tz=timezone.utc) + date_sent = datetime.now(tz=UTC) edited_at = metadata_raw.get("edited_at") if isinstance(edited_at, str): diff --git a/src/chat_sdk/from_full_stream.py b/src/chat_sdk/from_full_stream.py index 4f4be54..16e637f 100644 --- a/src/chat_sdk/from_full_stream.py +++ b/src/chat_sdk/from_full_stream.py @@ -54,7 +54,9 @@ async def from_full_stream( # AI SDK v5 uses textDelta, v6 uses text text_delta = event.get("textDelta") if event.get("textDelta") is not None else event.get("text_delta") text_content = ( - text_delta if text_delta is not None else (event.get("text") if event.get("text") is not None else event.get("delta")) + text_delta + if text_delta is not None + else (event.get("text") if event.get("text") is not None else event.get("delta")) ) if event_type == "text-delta" and isinstance(text_content, str): diff --git a/src/chat_sdk/shared/markdown_parser.py b/src/chat_sdk/shared/markdown_parser.py index 2d0a3a6..66a8ce7 100644 --- a/src/chat_sdk/shared/markdown_parser.py +++ b/src/chat_sdk/shared/markdown_parser.py @@ -318,19 +318,13 @@ def _collect_list_items(lines: list[str], start: int, ordered: bool) -> tuple[li """ items: list[Content] = [] i = start - if ordered: - item_re = re.compile(r"^(\d+)[.)]\s+(.*)") - else: - item_re = re.compile(r"^[-*+]\s+(.*)") + item_re = re.compile(r"^(\d+)[.)]\s+(.*)") if ordered else re.compile(r"^[-*+]\s+(.*)") while i < len(lines): line = lines[i] m = item_re.match(line) if m: - if ordered: - item_text = m.group(2) - else: - item_text = m.group(1) + item_text = m.group(2) if ordered else m.group(1) item_children_lines = [item_text] i += 1 diff --git a/src/chat_sdk/shared/mock_adapter.py b/src/chat_sdk/shared/mock_adapter.py index be3f8e7..5a8728c 100644 --- a/src/chat_sdk/shared/mock_adapter.py +++ b/src/chat_sdk/shared/mock_adapter.py @@ -9,7 +9,7 @@ import time from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any from chat_sdk.types import ( @@ -339,7 +339,7 @@ def create_test_message( is_me=False, ), "metadata": MessageMetadata( - date_sent=datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc), + date_sent=datetime(2024, 1, 15, 10, 30, 0, tzinfo=UTC), edited=False, ), "attachments": [], diff --git a/src/chat_sdk/state/memory.py b/src/chat_sdk/state/memory.py index 329bccd..9901e90 100644 --- a/src/chat_sdk/state/memory.py +++ b/src/chat_sdk/state/memory.py @@ -16,9 +16,10 @@ from dataclasses import dataclass from typing import Any +from chat_sdk.types import Lock, QueueEntry + logger = logging.getLogger(__name__) -from chat_sdk.types import Lock, QueueEntry # --------------------------------------------------------------------------- # Internal helpers diff --git a/src/chat_sdk/state/postgres.py b/src/chat_sdk/state/postgres.py index 588049f..1c562ba 100644 --- a/src/chat_sdk/state/postgres.py +++ b/src/chat_sdk/state/postgres.py @@ -139,7 +139,9 @@ def __init__( self._pool = None # created lazily in connect() resolved_url = url or os.environ.get("POSTGRES_URL") or os.environ.get("DATABASE_URL") if not resolved_url: - raise ValueError("Postgres url is required. Set POSTGRES_URL or DATABASE_URL, or provide url in options.") + raise ValueError( + "Postgres url is required. Set POSTGRES_URL or DATABASE_URL, or provide url in options." + ) self._url = resolved_url # -- lifecycle ----------------------------------------------------------- @@ -460,30 +462,29 @@ async def enqueue(self, thread_id: str, entry: QueueEntry, max_size: int) -> int expires_at = _pg_timestamp_from_epoch_ms(entry.expires_at) # Wrap insert + trim in a transaction to avoid TOCTOU races - async with self._pool.acquire() as conn: - async with conn.transaction(): - # Purge expired entries first - await conn.execute( - """DELETE FROM chat_state_queues + async with self._pool.acquire() as conn, conn.transaction(): + # Purge expired entries first + await conn.execute( + """DELETE FROM chat_state_queues WHERE key_prefix = $1 AND thread_id = $2 AND expires_at <= now()""", - self._key_prefix, - thread_id, - ) + self._key_prefix, + thread_id, + ) - # Insert the new entry - await conn.execute( - """INSERT INTO chat_state_queues (key_prefix, thread_id, value, expires_at) + # Insert the new entry + await conn.execute( + """INSERT INTO chat_state_queues (key_prefix, thread_id, value, expires_at) VALUES ($1, $2, $3, $4)""", - self._key_prefix, - thread_id, - serialized, - expires_at, - ) + self._key_prefix, + thread_id, + serialized, + expires_at, + ) - # Trim overflow (keep newest max_size non-expired entries) - if max_size > 0: - await conn.execute( - """DELETE FROM chat_state_queues + # Trim overflow (keep newest max_size non-expired entries) + if max_size > 0: + await conn.execute( + """DELETE FROM chat_state_queues WHERE key_prefix = $1 AND thread_id = $2 AND seq IN ( SELECT seq FROM chat_state_queues WHERE key_prefix = $1 AND thread_id = $2 @@ -496,19 +497,19 @@ async def enqueue(self, thread_id: str, entry: QueueEntry, max_size: int) -> int 0 ) )""", - self._key_prefix, - thread_id, - max_size, - ) - - # Return current non-expired depth - depth = await conn.fetchval( - """SELECT count(*) FROM chat_state_queues - WHERE key_prefix = $1 AND thread_id = $2 AND expires_at > now()""", self._key_prefix, thread_id, + max_size, ) + # Return current non-expired depth + depth = await conn.fetchval( + """SELECT count(*) FROM chat_state_queues + WHERE key_prefix = $1 AND thread_id = $2 AND expires_at > now()""", + self._key_prefix, + thread_id, + ) + return int(depth) async def dequeue(self, thread_id: str) -> QueueEntry | None: @@ -585,12 +586,12 @@ async def _ensure_schema(self) -> None: def _pg_timestamp_from_ms(ttl_ms: int) -> _dt.datetime: """Return a timezone-aware datetime ``ttl_ms`` milliseconds from now.""" - return _dt.datetime.now(_dt.timezone.utc) + _dt.timedelta(milliseconds=ttl_ms) + return _dt.datetime.now(_dt.UTC) + _dt.timedelta(milliseconds=ttl_ms) def _pg_timestamp_from_epoch_ms(epoch_ms: int) -> _dt.datetime: """Return a timezone-aware datetime from an epoch-millisecond value.""" - return _dt.datetime.fromtimestamp(epoch_ms / 1000, tz=_dt.timezone.utc) + return _dt.datetime.fromtimestamp(epoch_ms / 1000, tz=_dt.UTC) # --------------------------------------------------------------------------- diff --git a/src/chat_sdk/thread.py b/src/chat_sdk/thread.py index 4833066..e5acdf5 100644 --- a/src/chat_sdk/thread.py +++ b/src/chat_sdk/thread.py @@ -10,7 +10,7 @@ import asyncio from collections.abc import AsyncIterator from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any from chat_sdk.errors import ChatNotImplementedError @@ -139,7 +139,10 @@ def _extract_message_content( # Simplified: store markdown as text node (no full parser in Python port) return ( message.markdown, - {"type": "root", "children": [{"type": "paragraph", "children": [{"type": "text", "value": message.markdown}]}]}, + { + "type": "root", + "children": [{"type": "paragraph", "children": [{"type": "text", "value": message.markdown}]}], + }, list(message.attachments or []), ) @@ -538,7 +541,9 @@ async def _fallback_stream( Posts an initial placeholder, then edits the message at intervals as new text arrives from the stream. """ - interval_ms = options.update_interval_ms if options and options.update_interval_ms else self._streaming_update_interval_ms + interval_ms = ( + options.update_interval_ms if options and options.update_interval_ms else self._streaming_update_interval_ms + ) interval_s = interval_ms / 1000.0 placeholder_text = self._fallback_streaming_placeholder_text @@ -739,7 +744,7 @@ async def _remove_reaction(emoji: EmojiValue | str) -> None: is_me=True, ), metadata=MessageMetadata( - date_sent=datetime.now(tz=timezone.utc), + date_sent=datetime.now(tz=UTC), edited=False, ), attachments=attachments, diff --git a/src/chat_sdk/types.py b/src/chat_sdk/types.py index a039dd3..d68da7a 100644 --- a/src/chat_sdk/types.py +++ b/src/chat_sdk/types.py @@ -13,7 +13,6 @@ Literal, Protocol, TypedDict, - Union, runtime_checkable, ) @@ -680,10 +679,10 @@ class PostableCard: # Union of adapter-postable message types -AdapterPostableMessage = Union[str, PostableRaw, PostableMarkdown, PostableAst, PostableCard, CardElement] +AdapterPostableMessage = str | PostableRaw | PostableMarkdown | PostableAst | PostableCard | CardElement # Union of all postable message types (includes streaming) -PostableMessage = Union[AdapterPostableMessage, AsyncIterable[Any]] +PostableMessage = AdapterPostableMessage | AsyncIterable[Any] # ============================================================================= # Streaming Types @@ -717,7 +716,7 @@ class PlanUpdateChunk: title: str = "" -StreamChunk = Union[MarkdownTextChunk, TaskUpdateChunk, PlanUpdateChunk] +StreamChunk = MarkdownTextChunk | TaskUpdateChunk | PlanUpdateChunk @dataclass @@ -876,7 +875,9 @@ async def get(self, key: str) -> Any | None: ... async def set(self, key: str, value: Any, ttl_ms: int | None = None) -> None: ... async def set_if_not_exists(self, key: str, value: Any, ttl_ms: int | None = None) -> bool: ... async def delete(self, key: str) -> None: ... - async def append_to_list(self, key: str, value: Any, *, max_length: int | None = None, ttl_ms: int | None = None) -> None: ... + async def append_to_list( + self, key: str, value: Any, *, max_length: int | None = None, ttl_ms: int | None = None + ) -> None: ... async def get_list(self, key: str) -> list[Any]: ... async def enqueue(self, thread_id: str, entry: QueueEntry, max_size: int) -> int: ... async def dequeue(self, thread_id: str) -> QueueEntry | None: ... @@ -1279,7 +1280,9 @@ def process_slash_command(self, event: Any, options: WebhookOptions | None = Non def process_modal_submit( self, event: Any, context_id: str | None = None, options: WebhookOptions | None = None ) -> Awaitable[ModalResponse | None]: ... - def process_modal_close(self, event: Any, context_id: str | None = None, options: WebhookOptions | None = None) -> None: ... + def process_modal_close( + self, event: Any, context_id: str | None = None, options: WebhookOptions | None = None + ) -> None: ... def process_assistant_thread_started( self, event: AssistantThreadStartedEvent, options: WebhookOptions | None = None ) -> None: ... @@ -1287,7 +1290,9 @@ def process_assistant_context_changed( self, event: AssistantContextChangedEvent, options: WebhookOptions | None = None ) -> None: ... def process_app_home_opened(self, event: AppHomeOpenedEvent, options: WebhookOptions | None = None) -> None: ... - def process_member_joined_channel(self, event: MemberJoinedChannelEvent, options: WebhookOptions | None = None) -> None: ... + def process_member_joined_channel( + self, event: MemberJoinedChannelEvent, options: WebhookOptions | None = None + ) -> None: ... # =============================================================================