Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions webgpu/jupyter.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def _draw_scene(scene: Scene, width, height, id_):
html_canvas.height = height
gui_element = platform.js.document.getElementById(f"{id_}lilgui")

canvas = Canvas(utils.get_device(), html_canvas)
# Lazily initialize the WebGPU device the first time we draw.
canvas = Canvas(init_device_sync(), html_canvas)
scene.gui = LilGUI(gui_element, scene)
scene.init(canvas)
scene.render()
Expand Down Expand Up @@ -136,7 +137,13 @@ def Draw(
height = height if height is not None else 640

scene, id_ = _init_html(scene, width, height, flex)
_draw_scene(scene, width, height, id_)

# In classic Jupyter we already have a websocket connection at import
# time, so this callback runs immediately. In VS Code, outputs are only
# processed once the cell has finished executing; using execute_when_init
# ensures that drawing happens once the websocket connection is ready
# instead of blocking the import.
platform.execute_when_init(lambda js: _draw_scene(scene, width, height, id_))
return scene


Expand Down Expand Up @@ -178,12 +185,26 @@ def Draw(
js_code += f"\nwindow.pyodide_ready = init_pyodide('{webgpu_module_b64}');"
display(Javascript(js_code))
else:
# Not exporting and not running in pyodide -> Start a websocket server and wait for the client to connect
# Not exporting and not running in pyodide -> Start a websocket server
# and wait for the client to connect.
#
# In VS Code notebooks, outputs are typically only processed once the
# cell has completed execution. If we were to block here waiting for
# the websocket connection, the JavaScript that establishes the
# connection would never run, leading to a deadlock. We therefore
# avoid blocking on the connection in that environment and instead
# defer drawing until the link is ready via execute_when_init.

def _webgpu_js(server):
js = _link_js_code + """
const __is_vscode = (typeof location !== 'undefined' && location.protocol === 'vscode-webview:');
const __webgpu_host = __is_vscode ? '127.0.0.1' : ((typeof location !== 'undefined' && location.hostname) || '127.0.0.1');
WebsocketLink('ws://' + __webgpu_host + ':{port}');
""".format(port=server.port)
display(Javascript(js))

is_vscode = "VSCODE_PID" in os.environ
platform.init(
before_wait_for_connection=lambda server: display(
Javascript(
_link_js_code + f"WebsocketLink('ws://'+location.hostname+':{server.port}');"
)
)
before_wait_for_connection=_webgpu_js,
block_on_connection=not is_vscode,
)
device = init_device_sync()
53 changes: 43 additions & 10 deletions webgpu/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

from collections.abc import Mapping
import threading

is_pyodide = False
is_pyodide_main_thread = False
Expand Down Expand Up @@ -127,16 +128,37 @@ def _serialize_jsproxy(link, value):


def execute_when_init(func):
"""Register a callback to run once the JS side is ready.

If the platform has already been initialized, the callback is executed
immediately. Otherwise it is queued and executed from ``init`` once the
websocket connection has been established and ``js`` is set.
"""

if js is not None:
func(js)
else:
_funcs_after_init.append(func)


def init(before_wait_for_connection=None):
def init(before_wait_for_connection=None, block_on_connection: bool = True):
"""Initialize the websocket link to the browser.

In the default (classic Jupyter) mode, this blocks until the browser has
connected via websocket so that ``js`` is ready to use.

In environments like VS Code notebooks, outputs are typically only
processed once the cell has finished executing. In that situation calling
``init`` with ``block_on_connection=False`` avoids a deadlock by moving the
blocking ``wait_for_connection`` part to a background thread. Code that
depends on ``js`` should use :func:`execute_when_init` so it runs once the
connection is ready.
"""

global js, create_proxy, destroy_proxy, websocket_server, link
if is_pyodide or js is not None:
return

websocket_server = WebsocketLinkServer()
create_proxy = websocket_server.create_proxy
destroy_proxy = websocket_server.destroy_proxy
Expand All @@ -147,19 +169,30 @@ def init(before_wait_for_connection=None):
if before_wait_for_connection:
before_wait_for_connection(websocket_server)

websocket_server.wait_for_connection()
js = websocket_server.get(None, None)

from .link.base import LinkBase
from .webgpu_api import BaseWebGPUHandle, BaseWebGPUObject

LinkBase.register_serializer(BaseWebGPUHandle, lambda _, v: v.handle)
LinkBase.register_serializer(BaseWebGPUObject, lambda _, v: v.__dict__ or None)
def _finish_init():
websocket_server.wait_for_connection()
js_local = websocket_server.get(None, None)

websocket_server._start_handling_messages.set()
for func in _funcs_after_init:
func(js)
_funcs_after_init.clear()
LinkBase.register_serializer(BaseWebGPUHandle, lambda _, v: v.handle)
LinkBase.register_serializer(
BaseWebGPUObject, lambda _, v: v.__dict__ or None
)

# Publish js and run any deferred callbacks.
globals()["js"] = js_local
websocket_server._start_handling_messages.set()
for func in _funcs_after_init:
func(js_local)
_funcs_after_init.clear()

if block_on_connection:
_finish_init()
else:
thread = threading.Thread(target=_finish_init, daemon=True)
thread.start()


def init_pyodide(link_):
Expand Down