
Eager import of jax.numpy in _type_utils.py causes unnecessary JAX initialization #1469

@rgeronimi

Description


Summary

Since v4.4.1 (PR #1151, commit c088167), _type_utils.py unconditionally tries to import jax.numpy at module level:

try:
    import jax.numpy as jnp
except ImportError:
    jnp = None

When JAX happens to be installed in the same environment (e.g. as a transitive dependency of PyMC/PyTensor), this triggers a full JAX initialization (about 344 ms, as reported in jax-ml/jax#24967) every time lets_plot is imported, even if no JAX arrays are ever used.
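To check whether a given environment is affected, one option is to import lets_plot in a fresh interpreter and then look at sys.modules. The sketch below is a generic helper (imports_pull_in is a hypothetical name, not part of lets-plot) that runs the check in a subprocess, so modules already loaded by the current session, such as a notebook kernel, cannot skew the answer.

```python
import subprocess
import sys


def imports_pull_in(importer: str, target: str) -> bool:
    """Return True if importing `importer` in a fresh interpreter loads `target`.

    A subprocess is used so that modules already loaded by the current
    session (for example a notebook kernel) cannot affect the result.
    """
    code = (
        "import importlib, sys\n"
        f"importlib.import_module({importer!r})\n"
        f"print({target!r} in sys.modules)\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        check=True,  # raise if `importer` itself fails to import
    )
    # Only the last stdout line is ours; the imported package might print too.
    return result.stdout.strip().splitlines()[-1] == "True"


# e.g. imports_pull_in("lets_plot", "jax.numpy")
```

In an environment matching this report, imports_pull_in("lets_plot", "jax.numpy") would be expected to return True.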

JAX is not listed as a dependency (required or optional) in lets-plot's setup.py, so this side effect is unexpected.

Reproduction

In an environment where both lets-plot and JAX are installed:

import builtins, time, traceback

_real_import = builtins.__import__

def _tracing_import(name, *args, **kwargs):
    # Print the call stack whenever jax or one of its submodules is imported
    if name == 'jax' or name.startswith('jax.'):
        print(f'=== Import of {name} ===')
        traceback.print_stack()
    return _real_import(name, *args, **kwargs)

builtins.__import__ = _tracing_import

t0 = time.perf_counter()
import lets_plot
print(f'lets_plot import took {time.perf_counter() - t0:.3f}s')

The output shows jax.numpy being imported from lets_plot._type_utils, which adds significant startup time.
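For a per-module cost breakdown instead of stack traces, CPython's -X importtime option writes the self and cumulative import time (in microseconds) of every module to stderr. The sketch below (import_times is a hypothetical helper name) runs it in a fresh interpreter and parses that output; the parser assumes the default single-level report format.

```python
import subprocess
import sys


def import_times(module: str) -> list[tuple[str, int, int]]:
    """Run `import <module>` under `-X importtime` in a fresh interpreter.

    Returns (module_name, self_us, cumulative_us) for every reported import.
    """
    result = subprocess.run(
        [sys.executable, "-X", "importtime", "-c", f"import {module}"],
        capture_output=True,
        text=True,
        check=True,
    )
    rows = []
    for line in result.stderr.splitlines():
        # Data lines look like: "import time:      1234 |       5678 |   pkg.mod"
        if not line.startswith("import time:"):
            continue
        parts = line[len("import time:"):].split("|")
        if len(parts) != 3:
            continue
        self_us, cumulative_us, name = (p.strip() for p in parts)
        if not self_us.isdigit():
            continue  # skip the "self [us] | cumulative | imported package" header
        rows.append((name, int(self_us), int(cumulative_us)))
    return rows


# e.g. [r for r in import_times("lets_plot") if r[0].split(".")[0] == "jax"]
```

Reading the cumulative column of the top-level jax entry from import_times("lets_plot") should show how much of the lets_plot import is spent initializing JAX.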

Suggested fix

Use a lazy module-level sentinel so that jax.numpy is only imported on first use:

import importlib

_jnp_loaded = False
jnp = None

def _ensure_jnp():
    global _jnp_loaded, jnp
    if not _jnp_loaded:
        _jnp_loaded = True
        try:
            jnp = importlib.import_module('jax.numpy')
        except ImportError:
            jnp = None
    return jnp

Then, in is_int, is_float, is_ndarray, and _standardize_value, replace the existing jnp and isinstance(...) checks with _ensure_jnp() and isinstance(...).
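The same lazy-loading idea can also be packaged as a small reusable helper. In this sketch (lazy_optional_import and _get_jnp are hypothetical names, not lets-plot code), functools.cache from Python 3.9+ takes the place of the _jnp_loaded flag; because it also caches a None result, a missing jax.numpy is probed only once.

```python
import functools
import importlib
from types import ModuleType
from typing import Callable, Optional


def lazy_optional_import(module_name: str) -> Callable[[], Optional[ModuleType]]:
    """Return a zero-argument accessor that imports `module_name` on first call.

    functools.cache memoizes the result, including None for a missing module,
    so the import is normally attempted only once.
    """

    @functools.cache
    def load() -> Optional[ModuleType]:
        try:
            return importlib.import_module(module_name)
        except ImportError:
            return None

    return load


# Hypothetical module-level accessor: creating it does not import JAX.
_get_jnp = lazy_optional_import("jax.numpy")
```

A check would then read jnp_mod = _get_jnp() followed by jnp_mod is not None and isinstance(value, jnp_mod.ndarray). One design difference: no module-level jnp name is rebound. If any other module did from lets_plot._type_utils import jnp, it would keep the initial None even after _ensure_jnp() rebinds the global, and the accessor form sidesteps that pitfall.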

This would preserve the existing behavior for users who do pass JAX arrays, while avoiding the startup cost for the vast majority of users who don't.
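One refinement worth considering: if is_int, is_float and the other helpers run on every plot call, the lazy loader still pays the JAX startup cost on the first plot, because _ensure_jnp() imports jax.numpy regardless of the value being checked. A value can only be a JAX array if JAX has already been imported in the process, so a check can consult sys.modules first and skip the import otherwise. The sketch below is illustrative only: is_jax_array is a hypothetical helper, and it assumes jnp.ndarray is the type the existing checks test against.

```python
import importlib
import sys
from typing import Any


def is_jax_array(value: Any) -> bool:
    """Hypothetical check for JAX arrays that never triggers JAX startup.

    Importing any jax submodule first runs the top-level `jax` package, so if
    "jax" is missing from sys.modules, no JAX array can exist in this process.
    """
    if "jax" not in sys.modules:
        # Nothing in this process has imported JAX: skip the import entirely.
        return False
    try:
        # JAX is already loaded, so its startup cost has already been paid.
        jnp = importlib.import_module("jax.numpy")
    except ImportError:
        return False
    # Assumption: jnp.ndarray is the type the existing checks test against.
    return isinstance(value, jnp.ndarray)
```

Users who pass JAX arrays have necessarily imported JAX already, so they see the same results, while users who never touch JAX never pay the import, not even on their first plot.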

Environment

  • lets-plot 4.8.2
  • JAX installed as a transitive dependency (PyMC → PyTensor → JAX)
  • Python 3.12, macOS
