diff --git a/webgpu/link/base.py b/webgpu/link/base.py index f5f9024..6ea0f26 100644 --- a/webgpu/link/base.py +++ b/webgpu/link/base.py @@ -480,8 +480,6 @@ def __init__(self): super().__init__() self._send_loop = asyncio.new_event_loop() self._callback_loop = asyncio.new_event_loop() - self._callback_queue = asyncio.Queue() - self._callback_thread = threading.Thread(target=self._start_callback_thread, daemon=True) self._callback_thread.start() @@ -516,7 +514,13 @@ def _send_data(self, metadata, data, key=None): event = threading.Event() self._requests[request_id] = event, key - asyncio.run_coroutine_threadsafe(self._send_async(data), self._send_loop) + try: + asyncio.run_coroutine_threadsafe(self._send_async(data), self._send_loop) + except RuntimeError: + # Event loop is closed — connection is dead, clean up and bail. + if event: + self._requests.pop(request_id, None) + return None if event: event.wait() return self._requests.pop(request_id) @@ -530,15 +534,19 @@ async def handle_callbacks(): try: func, args = await self._callback_queue.get() func(*args) + except asyncio.CancelledError: + break + except RuntimeError: + break except asyncio.QueueEmpty: pass except Exception as e: print("error in callback", type(e), str(e)) - # await asyncio.sleep(0.01) try: self._callback_loop = asyncio.new_event_loop() asyncio.set_event_loop(self._callback_loop) + self._callback_queue = asyncio.Queue() self._callback_loop.create_task(handle_callbacks()) self._callback_loop.run_forever() except Exception as e: diff --git a/webgpu/link/websocket.py b/webgpu/link/websocket.py index 3567bbc..ae79d23 100644 --- a/webgpu/link/websocket.py +++ b/webgpu/link/websocket.py @@ -122,3 +122,16 @@ async def start_websocket(): def stop(self): self._send_loop.call_soon_threadsafe(self._stop.set_result, None) + + # Stop the callback event loop so the _callback_thread exits. + try: + self._callback_loop.call_soon_threadsafe(self._callback_loop.stop) + except RuntimeError: + pass # Event loop already closed + + # Unblock any threads stuck waiting for websocket RPC responses. + for rid, val in list(self._requests.items()): + if isinstance(val, tuple): + event, key = val + if isinstance(event, threading.Event): + event.set() diff --git a/webgpu/platform.py b/webgpu/platform.py index 76f8f7e..92b2094 100644 --- a/webgpu/platform.py +++ b/webgpu/platform.py @@ -206,3 +206,32 @@ def init_pyodide(link_): LinkBase.register_serializer(BaseWebGPUHandle, lambda _, v: v.handle) LinkBase.register_serializer(BaseWebGPUObject, lambda _, v: v.__dict__ or None) + + +def reset(): + """Reset the platform globals so that init can be called again. + + Used by test runners that start and stop multiple app instances + in the same process. + """ + global js, websocket_server, link, create_proxy, destroy_proxy + if websocket_server is not None: + try: + websocket_server.stop() + except RuntimeError: + pass # Event loop already closed + js = None + websocket_server = None + link = None + if not is_pyodide: + create_proxy = None + destroy_proxy = None + + # Reset cached WebGPU device so it is re-requested on the new connection. + from . import utils as _utils + _utils._device = None + + # Reset the cached font atlas — its GPU texture is tied to the old + # connection and would deadlock when accessed on a new one. + from . import font as _font + _font._default_font_atlas = None diff --git a/webgpu/scene.py b/webgpu/scene.py index 31a7ab1..d8d9ddf 100644 --- a/webgpu/scene.py +++ b/webgpu/scene.py @@ -97,7 +97,11 @@ def init(self, canvas): self.options.timestamp = time.time() self.options.update_buffers() for obj in self.render_objects: - obj._update_and_create_render_pipeline(self.options) + try: + obj._update_and_create_render_pipeline(self.options) + except Exception as e: + print(f'Warning: failed to init renderer {type(obj).__name__}: {e}') + obj.active = False camera = self.options.camera self._js_render = platform.create_proxy(self._render_direct) @@ -308,8 +312,11 @@ def cleanup(self): self.options.camera._render_function = None self.options.camera._get_position_function = None self.input_handler.unregister_callbacks() - platform.destroy_proxy(self._js_render) - del self._js_render - self.canvas._on_resize_callbacks.remove(self.render) - self.canvas._on_update_html_canvas.remove(self.__on_update_html_canvas) + if hasattr(self, '_js_render'): + platform.destroy_proxy(self._js_render) + del self._js_render + if self.render in self.canvas._on_resize_callbacks: + self.canvas._on_resize_callbacks.remove(self.render) + if self.__on_update_html_canvas in self.canvas._on_update_html_canvas: + self.canvas._on_update_html_canvas.remove(self.__on_update_html_canvas) self.canvas = None