|
50 | 50 |
|
51 | 51 | # Standard |
52 | 52 | import asyncio |
| 53 | +from asyncio import Task |
53 | 54 | from datetime import datetime, timezone |
54 | 55 | import json |
55 | 56 | import logging |
@@ -184,7 +185,7 @@ def __init__( |
184 | 185 | # Set up backend-specific components |
185 | 186 | if self._backend == "memory": |
186 | 187 | # Nothing special needed for memory backend |
187 | | - self._session_message = None |
| 188 | + self._session_message: dict[str, Any] | None = None |
188 | 189 |
|
189 | 190 | elif self._backend == "none": |
190 | 191 | # No session tracking - this is just a dummy registry |
@@ -295,7 +296,7 @@ def __init__( |
295 | 296 | super().__init__(backend=backend, redis_url=redis_url, database_url=database_url, session_ttl=session_ttl, message_ttl=message_ttl) |
296 | 297 | self._sessions: Dict[str, Any] = {} # Local transport cache |
297 | 298 | self._lock = asyncio.Lock() |
298 | | - self._cleanup_task = None |
| 299 | + self._cleanup_task: Task | None = None |
299 | 300 |
|
300 | 301 | async def initialize(self) -> None: |
301 | 302 | """Initialize the registry with async setup. |
@@ -697,7 +698,7 @@ async def broadcast(self, session_id: str, message: Dict[str, Any]) -> None: |
697 | 698 | else: |
698 | 699 | msg_json = json.dumps(str(message)) |
699 | 700 |
|
700 | | - self._session_message: Dict[str, Any] = {"session_id": session_id, "message": msg_json} |
| 701 | + self._session_message: Dict[str, Any] | None = {"session_id": session_id, "message": msg_json} |
701 | 702 |
|
702 | 703 | elif self._backend == "redis": |
703 | 704 | try: |
@@ -835,7 +836,7 @@ async def respond( |
835 | 836 | elif self._backend == "memory": |
836 | 837 | # if self._session_message: |
837 | 838 | transport = self.get_session_sync(session_id) |
838 | | - if transport: |
| 839 | + if transport and self._session_message: |
839 | 840 | message = json.loads(str(self._session_message.get("message"))) |
840 | 841 | await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url) |
841 | 842 |
|
@@ -863,7 +864,7 @@ async def respond( |
863 | 864 |
|
864 | 865 | elif self._backend == "database": |
865 | 866 |
|
866 | | - def _db_read_session(session_id: str) -> SessionRecord: |
| 867 | + def _db_read_session(session_id: str) -> SessionRecord | None: |
867 | 868 | """Check if session still exists in the database. |
868 | 869 |
|
869 | 870 | Queries the SessionRecord table to verify that the session |
@@ -898,7 +899,7 @@ def _db_read_session(session_id: str) -> SessionRecord: |
898 | 899 | finally: |
899 | 900 | db_session.close() |
900 | 901 |
|
901 | | - def _db_read(session_id: str) -> SessionMessageRecord: |
| 902 | + def _db_read(session_id: str) -> SessionMessageRecord | None: |
902 | 903 | """Read pending message for a session from the database. |
903 | 904 |
|
904 | 905 | Retrieves the first (oldest) unprocessed message for the given |
@@ -1284,23 +1285,23 @@ async def generate_response(self, message: Dict[str, Any], transport: SSETranspo |
1284 | 1285 | result = {} |
1285 | 1286 |
|
1286 | 1287 | if "method" in message and "id" in message: |
| 1288 | + method = message["method"] |
| 1289 | + params = message.get("params", {}) |
| 1290 | + params["server_id"] = server_id |
| 1291 | + req_id = message["id"] |
| 1292 | + |
| 1293 | + rpc_input = { |
| 1294 | + "jsonrpc": "2.0", |
| 1295 | + "method": method, |
| 1296 | + "params": params, |
| 1297 | + "id": req_id, |
| 1298 | + } |
| 1299 | + # Get the token from the current authentication context |
| 1300 | + # The user object doesn't contain the token directly, we need to reconstruct it |
| 1301 | + # Since we don't have access to the original headers here, we need a different approach |
| 1302 | + # We'll extract the token from the session or create a new admin token |
| 1303 | + token = None |
1287 | 1304 | try: |
1288 | | - method = message["method"] |
1289 | | - params = message.get("params", {}) |
1290 | | - params["server_id"] = server_id |
1291 | | - req_id = message["id"] |
1292 | | - |
1293 | | - rpc_input = { |
1294 | | - "jsonrpc": "2.0", |
1295 | | - "method": method, |
1296 | | - "params": params, |
1297 | | - "id": req_id, |
1298 | | - } |
1299 | | - # Get the token from the current authentication context |
1300 | | - # The user object doesn't contain the token directly, we need to reconstruct it |
1301 | | - # Since we don't have access to the original headers here, we need a different approach |
1302 | | - # We'll extract the token from the session or create a new admin token |
1303 | | - token = None |
1304 | 1305 | if hasattr(user, "get") and "auth_token" in user: |
1305 | 1306 | token = user["auth_token"] |
1306 | 1307 | else: |
|
0 commit comments