Summary
Since v4.4.1 (PR #1151, commit c088167), _type_utils.py unconditionally imports jax.numpy at module level:

    try:
        import jax.numpy as jnp
    except ImportError:
        jnp = None

When JAX happens to be installed in the same environment (e.g. as a transitive dependency of PyMC/PyTensor), this triggers a full JAX initialization (~344ms per jax-ml/jax#24967) on every import of lets_plot, even when no JAX arrays are ever used.
JAX is not listed as a dependency (required or optional) in lets-plot's setup.py, so this side effect is unexpected.
Reproduction
In an environment where both lets-plot and JAX are installed:

    import builtins, time

    _real_import = builtins.__import__

    def _tracing_import(name, *args, **kwargs):
        if name == 'jax' or name.startswith('jax.'):
            import traceback
            print(f'=== Import of {name} ===')
            traceback.print_stack()
        return _real_import(name, *args, **kwargs)

    builtins.__import__ = _tracing_import

    t0 = time.time()
    import lets_plot
    print(f'lets_plot import took {time.time() - t0:.3f}s')

The output shows jax.numpy being imported via lets_plot._type_utils, adding significant startup time.
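A lighter-weight cross-check, sketched here as an assumption rather than part of the original report: import the package in a fresh interpreter and list which modules under a given prefix ended up in sys.modules. The helper name modules_pulled_in is hypothetical, and running it in a subprocess keeps the current session's already-imported modules from skewing the result.

```python
import json
import subprocess
import sys


def modules_pulled_in(target, prefix):
    """Import `target` in a fresh interpreter and return the sorted names of
    loaded modules that are `prefix` itself or live under `prefix.`."""
    probe = (
        "import importlib, json, sys\n"
        f"importlib.import_module({target!r})\n"
        f"hits = sorted(m for m in sys.modules "
        f"if m == {prefix!r} or m.startswith({prefix!r} + '.'))\n"
        "print(json.dumps(hits))\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", probe],
        capture_output=True,
        text=True,
        check=True,  # raises CalledProcessError if `target` cannot be imported
    )
    # The probe prints its JSON last, so ignore anything the import itself printed.
    return json.loads(result.stdout.strip().splitlines()[-1])


# Hypothetical usage in an environment with both packages installed:
#   modules_pulled_in("lets_plot", "jax")
# A non-empty list would mean that importing lets_plot alone loaded JAX.
```

An empty list after a fix would indicate that the import no longer touches JAX; this only checks what was loaded, not how long it took.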
Suggested fix
Use a lazy module-level sentinel so that jax.numpy is only imported on first use:

    import importlib

    _jnp_loaded = False
    jnp = None

    def _ensure_jnp():
        global _jnp_loaded, jnp
        if not _jnp_loaded:
            _jnp_loaded = True
            try:
                jnp = importlib.import_module('jax.numpy')
            except ImportError:
                jnp = None
        return jnp

Then replace each "jnp and isinstance(...)" check with "_ensure_jnp() and isinstance(...)" in is_int, is_float, is_ndarray, and _standardize_value.
This would preserve the existing behavior for users who do pass JAX arrays, while avoiding the startup cost for the vast majority of users who don't.
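To make the call-site change concrete, here is a minimal sketch. The function is_jax_array is a hypothetical stand-in; the real bodies of is_int, is_float, is_ndarray, and _standardize_value are not shown in this report. The second function adds an optional guard that is not part of the suggestion above: it assumes a value can only be a JAX array if the caller has already imported jax, so the check can return early when jax is absent from sys.modules.

```python
import importlib
import sys

_jnp_loaded = False
jnp = None


def _ensure_jnp():
    """Import jax.numpy on the first call; return the module, or None if JAX is absent."""
    global _jnp_loaded, jnp
    if not _jnp_loaded:
        _jnp_loaded = True
        try:
            jnp = importlib.import_module('jax.numpy')
        except ImportError:
            jnp = None
    return jnp


def is_jax_array(value):
    # Hypothetical call site following the suggested pattern:
    # "_ensure_jnp() and isinstance(...)" instead of "jnp and isinstance(...)".
    return bool(_ensure_jnp() and isinstance(value, jnp.ndarray))


def is_jax_array_without_import(value):
    # Optional guard (an assumption, not from the report): if jax was never
    # imported in this process, no JAX array can exist, so skip the import.
    if 'jax' not in sys.modules:
        return False
    return is_jax_array(value)
```

With the plain pattern, the JAX import is deferred until the first type check runs. The guarded variant might avoid it altogether in sessions that never import JAX themselves.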
Environment
- lets-plot 4.8.2
- JAX installed as transitive dependency (PyMC → PyTensor → JAX)
- Python 3.12, macOS