Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions runware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,8 @@ def check(resolve: callable, reject: callable, *args: Any) -> bool:

except Exception as e:
if retry_count >= 2:
self.logger.error(f"Error in photoMaker request: {e}")
exit()
return self.handle_incomplete_images(task_uuids=task_uuids, error=e)
self.logger.error(f"Error in photoMaker request:", exc_info=e)
raise RunwareAPIError({"message": f"PhotoMaker failed after retries: {str(e)}"})
else:
raise e

Expand Down Expand Up @@ -464,9 +463,8 @@ async def imageInference(
)
except Exception as e:
if retry_count >= 2:
self.logger.error(f"Error in requestImages: {e}")
exit()
return self.handle_incomplete_images(task_uuids=task_uuids, error=e)
self.logger.error(f"Error in requestImages:", exc_info=e)
raise RunwareAPIError({"message": f"Image inference failed after retries: {str(e)}"})
else:
raise e

Expand Down
134 changes: 56 additions & 78 deletions runware/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,11 @@
import logging
import websockets
from websockets.protocol import State
import inspect
import pprint
from typing import Any, Callable, Dict, List, Union, Optional, TypeVar
from typing import Any, Dict, Optional


from .types import RunwareBaseType, SdkType
from .types import SdkType
from .utils import (
delay,
getUUID,
removeListener,
BASE_RUNWARE_URLS,
PING_INTERVAL,
PING_TIMEOUT_DURATION,
Expand All @@ -21,12 +16,7 @@
from .base import RunwareBase
from .types import (
Environment,
EPreProcessor,
EPreProcessorGroup,
ListenerType,
IControlNet,
File,
GetWithPromiseCallBackType,
)

from .logging_config import configure_logging
Expand All @@ -50,6 +40,7 @@ def __init__(
self._apiKey: str = api_key
self._message_handler_task: Optional[asyncio.Task] = None
self._last_pong_time: float = 0.0
self._is_shutting_down: bool = False

# Configure logging
configure_logging(log_level)
Expand Down Expand Up @@ -160,29 +151,19 @@ def pong_lis(m):
async def on_message(self, ws, message):
if not message:
return
m = json.loads(message)
# print(
# f"\n\n\n================================================ Received message ============================================================"
# )
# print(f"{m}")

# print(f"Listenerse:")
# for lis in self._listeners:
# print(lis, "\n")
# print(
# f"============================================= End received message ============================================================\n\n\n"
# )

try:
m = json.loads(message)
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse JSON message:", exc_info=e)
return

for lis in self._listeners:
try:
# result = True
result = lis.listener(m)
except Exception as e:
print(f"Unexpected error in on_message: {e}")
print(dir(lis))
print(f"Listeners: {self._listeners}")
for lis in self._listeners:
print(dir(lis), "\n")
return
self.logger.error(f"Error in listener {lis.key}:", exc_info=e)
continue
if result:
return

Expand All @@ -192,31 +173,25 @@ async def _handle_messages(self):
f"Starting message handler task {self._message_handler_task}"
)
async for message in self._ws:
if self._is_shutting_down:
break
try:
await self.on_message(self._ws, message)
except Exception as e:
print(f"Unexpected error in async loop: {e}")
print(self.on_message)
exit()
self.logger.error(f"Error in on_message:", exc_info=e)
continue
except websockets.exceptions.ConnectionClosedError as e:
self.logger.error(f"Connection Closed Error: {e}")
await self.handleClose()
if not self._is_shutting_down:
self.logger.error(f"Connection Closed Error:", exc_info=e)
await self.handleClose()
except Exception as e:
print(f"Unexpected error in _handle_messages: {e}")
print(self.on_message)
exit()
await self._ws.close()
self.logger.error(f"Critical error in _handle_messages:", exc_info=e)
if not self._is_shutting_down:
await self.handleClose()

async def send(self, msg: Dict[str, Any]):
self.logger.debug(f"Sending message: {msg}")
# print(
# f"\n\n\n================================================= Sending message ================================================================="
# )
# print(f"{msg}")
# print(
# f"=============================================== End sending message ===============================================================\n\n\n"
# )
if self._ws and self._ws.state is State.OPEN:
if self._ws and self._ws.state is State.OPEN and not self._is_shutting_down:
await self._ws.send(json.dumps(msg))

def _get_task_by_name(self, name):
Expand All @@ -240,7 +215,7 @@ async def handleClose(self):
try:
reconnecting_task.cancel()
except Exception as e:
self.logger.error(f"Error while cancelling Task_Reconnecting: {e}")
self.logger.error(f"Error while cancelling Task_Reconnecting:", exc_info=e)

message_handler_task = self._get_task_by_name("Task_Message_Handler")
if message_handler_task is not None:
Expand All @@ -252,7 +227,7 @@ async def handleClose(self):
message_handler_task.cancel()
except Exception as e:
self.logger.error(
f"Error while cancelling Task_Message_Handler: {e}"
f"Error while cancelling Task_Message_Handler:", exc_info=e
)

heartbeat_task = self._get_task_by_name("Task_Heartbeat")
Expand All @@ -262,12 +237,15 @@ async def handleClose(self):
try:
heartbeat_task.cancel()
except Exception as e:
self.logger.error(f"Error while cancelling Task_Heartbeat: {e}")
self.logger.error(f"Error while cancelling Task_Heartbeat:", exc_info=e)

async def reconnect():
while True:
self.logger.info("Reconnecting...")
await asyncio.sleep(1)
reconnect_attempts = 0
max_reconnect_attempts = 5

while reconnect_attempts < max_reconnect_attempts and not self._is_shutting_down:
self.logger.info(f"Reconnecting... (attempt {reconnect_attempts + 1})")
await asyncio.sleep(min(reconnect_attempts * 2 + 1, 10))
try:
await self.connect()
if self.isWebsocketReadyState():
Expand All @@ -278,43 +256,43 @@ async def reconnect():
"WebSocket connection is not in a ready state after reconnecting"
)
except Exception as e:
self.logger.error(f"Error while reconnecting: {e}")
self.logger.error(f"Error while reconnecting:", exc_info=e)

reconnect_attempts += 1

if reconnect_attempts >= max_reconnect_attempts:
self.logger.error("Max reconnection attempts reached. Giving up.")
self._is_shutting_down = True

# TODO: I don't need to close self._ws here, as it will be cleaned by sockets library based on it's interrnal ping mechanism
# Attempting to reconnect...
self._reconnecting_task = asyncio.create_task(
reconnect(), name="Task_Reconnecting"
)
if not self._is_shutting_down:
self._reconnecting_task = asyncio.create_task(
reconnect(), name="Task_Reconnecting"
)

async def heartBeat(self):
# TODO: Not sure if we need this, as the websocket server responds to default PING messages
# 2024-04-29 10:46:23,193 - websockets.client - DEBUG - % sending keepalive ping
# 2024-04-29 10:46:23,194 - websockets.client - DEBUG - > PING f2 0b eb 3d [binary, 4 bytes]
# 2024-04-29 10:46:23,197 - runware.server - DEBUG - Sending ping
# 2024-04-29 10:46:23,197 - runware.server - DEBUG - Sending message: {'ping': True}
# 2024-04-29 10:46:23,197 - websockets.client - DEBUG - > TEXT '{"ping": true}' [14 bytes]
# 2024-04-29 10:46:23,241 - websockets.client - DEBUG - < PONG f2 0b eb 3d [binary, 4 bytes]
# 2024-04-29 10:46:23,241 - websockets.client - DEBUG - % received keepalive pong
# 2024-04-29 10:46:23,244 - websockets.client - DEBUG - < TEXT '{"pong":true}' [13 bytes]
while True:
while not self._is_shutting_down:
if self.isWebsocketReadyState():
self.logger.debug("Sending ping")
try:
await self.send([{"taskType": "ping", "ping": True}])
except websockets.exceptions.ConnectionClosedError as e:
self.logger.error(
f"Error sending ping: {e}. Connection likely closed."
f"Error sending ping. Connection likely closed.", exc_info=e
)
# Potentially handle reconnection here
except Exception as e: # Catch other potential exceptions
self.logger.error(f"Unexpected error sending ping: {e}")
# Handle unexpected errors appropriately
break
except Exception as e:
self.logger.error(f"Unexpected error sending ping", exc_info=e)
break

await asyncio.sleep(PING_INTERVAL / 1000)

if (
asyncio.get_event_loop().time() - self._last_pong_time
> PING_TIMEOUT_DURATION / 1000
): # No pong received within the timeout period
asyncio.get_event_loop().time() - self._last_pong_time
> PING_TIMEOUT_DURATION / 1000
):
self.logger.warning("No pong received. Connection may be lost.")
# Initiate a reconnection
await self.handleClose()
break
else:
break