From 49f2945d4d3842ea16601830d7f163f780390eb7 Mon Sep 17 00:00:00 2001 From: Christopher Lackner Date: Tue, 30 Dec 2025 15:03:53 +0100 Subject: [PATCH] fixes for running webgpu in vscode notebooks --- webgpu/jupyter.py | 39 ++++++++++++++++++++++++++-------- webgpu/platform.py | 53 +++++++++++++++++++++++++++++++++++++--------- 2 files changed, 73 insertions(+), 19 deletions(-) diff --git a/webgpu/jupyter.py b/webgpu/jupyter.py index 0db2fd0..a1c201b 100644 --- a/webgpu/jupyter.py +++ b/webgpu/jupyter.py @@ -91,7 +91,8 @@ def _draw_scene(scene: Scene, width, height, id_): html_canvas.height = height gui_element = platform.js.document.getElementById(f"{id_}lilgui") - canvas = Canvas(utils.get_device(), html_canvas) + # Lazily initialize the WebGPU device the first time we draw. + canvas = Canvas(init_device_sync(), html_canvas) scene.gui = LilGUI(gui_element, scene) scene.init(canvas) scene.render() @@ -136,7 +137,13 @@ def Draw( height = height if height is not None else 640 scene, id_ = _init_html(scene, width, height, flex) - _draw_scene(scene, width, height, id_) + + # In classic Jupyter we already have a websocket connection at import + # time, so this callback runs immediately. In VS Code, outputs are only + # processed once the cell has finished executing; using execute_when_init + # ensures that drawing happens once the websocket connection is ready + # instead of blocking the import. + platform.execute_when_init(lambda js: _draw_scene(scene, width, height, id_)) return scene @@ -178,12 +185,26 @@ def Draw( js_code += f"\nwindow.pyodide_ready = init_pyodide('{webgpu_module_b64}');" display(Javascript(js_code)) else: - # Not exporting and not running in pyodide -> Start a websocket server and wait for the client to connect + # Not exporting and not running in pyodide -> Start a websocket server + # and wait for the client to connect. + # + # In VS Code notebooks, outputs are typically only processed once the + # cell has completed execution. If we were to block here waiting for + # the websocket connection, the JavaScript that establishes the + # connection would never run, leading to a deadlock. We therefore + # avoid blocking on the connection in that environment and instead + # defer drawing until the link is ready via execute_when_init. + + def _webgpu_js(server): + js = _link_js_code + """ +const __is_vscode = (typeof location !== 'undefined' && location.protocol === 'vscode-webview:'); +const __webgpu_host = __is_vscode ? '127.0.0.1' : ((typeof location !== 'undefined' && location.hostname) || '127.0.0.1'); +WebsocketLink('ws://' + __webgpu_host + ':{port}'); +""".format(port=server.port) + display(Javascript(js)) + + is_vscode = "VSCODE_PID" in os.environ platform.init( - before_wait_for_connection=lambda server: display( - Javascript( - _link_js_code + f"WebsocketLink('ws://'+location.hostname+':{server.port}');" - ) - ) + before_wait_for_connection=_webgpu_js, + block_on_connection=not is_vscode, ) - device = init_device_sync() diff --git a/webgpu/platform.py b/webgpu/platform.py index b0b8a3a..76f8f7e 100644 --- a/webgpu/platform.py +++ b/webgpu/platform.py @@ -8,6 +8,7 @@ """ from collections.abc import Mapping +import threading is_pyodide = False is_pyodide_main_thread = False @@ -127,16 +128,37 @@ def _serialize_jsproxy(link, value): def execute_when_init(func): + """Register a callback to run once the JS side is ready. + + If the platform has already been initialized, the callback is executed + immediately. Otherwise it is queued and executed from ``init`` once the + websocket connection has been established and ``js`` is set. + """ + if js is not None: func(js) else: _funcs_after_init.append(func) -def init(before_wait_for_connection=None): +def init(before_wait_for_connection=None, block_on_connection: bool = True): + """Initialize the websocket link to the browser. + + In the default (classic Jupyter) mode, this blocks until the browser has + connected via websocket so that ``js`` is ready to use. + + In environments like VS Code notebooks, outputs are typically only + processed once the cell has finished executing. In that situation calling + ``init`` with ``block_on_connection=False`` avoids a deadlock by moving the + blocking ``wait_for_connection`` part to a background thread. Code that + depends on ``js`` should use :func:`execute_when_init` so it runs once the + connection is ready. + """ + global js, create_proxy, destroy_proxy, websocket_server, link if is_pyodide or js is not None: return + websocket_server = WebsocketLinkServer() create_proxy = websocket_server.create_proxy destroy_proxy = websocket_server.destroy_proxy @@ -147,19 +169,30 @@ def init(before_wait_for_connection=None): if before_wait_for_connection: before_wait_for_connection(websocket_server) - websocket_server.wait_for_connection() - js = websocket_server.get(None, None) - from .link.base import LinkBase from .webgpu_api import BaseWebGPUHandle, BaseWebGPUObject - LinkBase.register_serializer(BaseWebGPUHandle, lambda _, v: v.handle) - LinkBase.register_serializer(BaseWebGPUObject, lambda _, v: v.__dict__ or None) + def _finish_init(): + websocket_server.wait_for_connection() + js_local = websocket_server.get(None, None) - websocket_server._start_handling_messages.set() - for func in _funcs_after_init: - func(js) - _funcs_after_init.clear() + LinkBase.register_serializer(BaseWebGPUHandle, lambda _, v: v.handle) + LinkBase.register_serializer( + BaseWebGPUObject, lambda _, v: v.__dict__ or None + ) + + # Publish js and run any deferred callbacks. + globals()["js"] = js_local + websocket_server._start_handling_messages.set() + for func in _funcs_after_init: + func(js_local) + _funcs_after_init.clear() + + if block_on_connection: + _finish_init() + else: + thread = threading.Thread(target=_finish_init, daemon=True) + thread.start() def init_pyodide(link_):