v0.1.3

@mahdi-shafiei mahdi-shafiei released this 01 May 02:51

Highlights

Pipeline replaces DAGExecutor

The DAGExecutor class is gone. Pipeline(nnx.Module) is now the single user-facing class for data-loading topologies, with three construction shapes:

  • Pipeline(source=..., stages=[...], ...) — linear chain.
  • Pipeline.from_dag(source=..., nodes=..., edges=..., sink=...) — declarative DAG.
  • Subclass override of __call__ — custom topologies.

Pipeline.scan(step_fn, length, modules=(...), init_carry=...) runs an entire epoch under nnx.scan, lifting user nnx.Module state via StateAxes. Scan bodies are now cached on the Pipeline instance, keyed on (id(step_fn), length, n_modules, has_init_carry), so subsequent calls with the same key reuse the compiled graph instead of re-tracing.
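The cache key above can be illustrated with a small plain-Python sketch. This is not the Datarax implementation: the class ScanCacheSketch, its builds counter, and the placeholder _build step (a Python loop standing in for tracing and compilation) are all hypothetical. Only the key tuple is taken from the notes.

```python
from typing import Any, Callable, Dict, Tuple

CacheKey = Tuple[int, int, int, bool]


class ScanCacheSketch:
    """Hypothetical stand-in for a per-instance scan-body cache."""

    def __init__(self) -> None:
        self._compiled: Dict[CacheKey, Callable[[Any], Any]] = {}
        self.builds = 0  # how many times a scan body had to be built

    def get(self, step_fn, length, modules=(), init_carry=None):
        # Key shape from the release notes:
        # (id(step_fn), length, n_modules, has_init_carry)
        key = (id(step_fn), length, len(modules), init_carry is not None)
        if key not in self._compiled:
            self.builds += 1
            self._compiled[key] = self._build(step_fn, length)
        return self._compiled[key]

    @staticmethod
    def _build(step_fn, length):
        # Placeholder for the expensive trace/compile step (assumption).
        def run(carry):
            for i in range(length):
                carry = step_fn(carry, i)
            return carry

        return run


def step(carry, i):
    return carry + i


cache = ScanCacheSketch()
first = cache.get(step, length=100)
second = cache.get(step, length=100)               # same object: cache hit
third = cache.get(lambda c, i: c + i, length=100)  # new object: cache miss
```

Because the key includes id(step_fn), a function defined once and reused hits the cache, while a lambda or closure recreated on every call would presumably miss it even if its body is identical.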

Locally measured (CUDA, 100 batches, A100-class):

| Scenario    | iter mode | scan mode | speedup |
|-------------|-----------|-----------|---------|
| NLP-1 small | 70 ms     | 2.6 ms    | 27x     |
| TAB-1 small | 30 ms     | 2.7 ms    | 11x     |
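As a quick check, the speedup column appears to be the iter-mode time divided by the scan-mode time, rounded to the nearest integer. The snippet below reproduces it from the two reported rows; the variable names are illustrative only.

```python
# (iter mode ms, scan mode ms) as reported in the table above.
measurements_ms = {
    "NLP-1 small": (70.0, 2.6),
    "TAB-1 small": (30.0, 2.7),
}

# Speedup = iter time / scan time, rounded to the nearest integer.
speedups = {
    name: round(iter_ms / scan_ms)
    for name, (iter_ms, scan_ms) in measurements_ms.items()
}

for name, factor in speedups.items():
    print(f"{name}: {factor}x")
```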

Source primitives

EagerSourceBase, StreamingDiskSource, ArrayRecord, MemorySource, and MixDataSourcesNode all expose get_batch_at(start, batch_size, key) for JIT-traceable random access from inside scan bodies. MemorySource gains shuffled-mode get_batch_at via jax.random.permutation. MixDataSourcesNode.get_batch_at is JIT-traceable via vmap + lax.switch.
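To illustrate the get_batch_at(start, batch_size, key) contract in shuffled mode, here is a plain-Python analogue. It is not the Datarax code and is not JIT-traceable: the notes say MemorySource uses jax.random.permutation, whereas this sketch swaps in random.Random(key).shuffle with an integer seed. The class name ShuffledMemorySourceSketch and the wrap-around at the end of the data are assumptions.

```python
import random
from typing import List, Sequence


class ShuffledMemorySourceSketch:
    """Hypothetical plain-Python analogue of a shuffled in-memory source."""

    def __init__(self, data: Sequence[int]) -> None:
        self.data = list(data)

    def get_batch_at(self, start: int, batch_size: int, key: int) -> List[int]:
        # Rebuild the same permutation from `key` on every call, so the
        # result depends only on the arguments (stateless random access).
        perm = list(range(len(self.data)))
        random.Random(key).shuffle(perm)
        # Wrap around at the end of the data (assumption, not from the notes).
        idx = [perm[(start + i) % len(perm)] for i in range(batch_size)]
        return [self.data[j] for j in idx]


source = ShuffledMemorySourceSketch(range(10))
batch_a = source.get_batch_at(0, 5, key=0)
batch_b = source.get_batch_at(5, 5, key=0)
batch_a_again = source.get_batch_at(0, 5, key=0)
```

The point of the pattern is that a batch is a pure function of (start, batch_size, key), which is what lets a scan body fetch data by step index instead of advancing a stateful iterator.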

Benchmark + interpretation tooling

  • New DataraxScanAdapter (registered as Datarax-scan) measures whole-epoch execution alongside the iterator-mode DataraxAdapter.
  • Auto-generated comparison reports gain an Interpretation Notes section.
  • New docs/benchmarks/interpretation.md page documenting the three throughput-gap categories, which adapters iterate on GPU vs CPU regardless of the JAX backend, and the iter-vs-scan adapter pair.

P3 memory-efficiency test hardening

Decision logic extracted into classify_rss_comparison() helper:

  • Hard cap on Datarax peak RSS that fails loud regardless of SPDL.
  • SPDL-zero null floor that skips with a clear "zero-copy / pre-loaded fixture" message instead of misattributing the cause to allocator noise.
  • 9 fast unit tests covering each branch.
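The branches listed above could be arranged roughly as follows. This is a sketch under stated assumptions, not the actual helper: the notes do not show its signature, so the function name classify_rss_comparison_sketch, the Verdict type, the parameter names, the threshold defaults, and the final ratio check are all hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Verdict:
    outcome: str  # "pass", "fail", or "skip"
    reason: str


def classify_rss_comparison_sketch(
    datarax_peak_mb: float,
    spdl_peak_mb: float,
    hard_cap_mb: float = 4096.0,   # hypothetical threshold
    null_floor_mb: float = 1.0,    # hypothetical threshold
    max_ratio: float = 1.5,        # hypothetical threshold
) -> Verdict:
    # 1. Hard cap: fail loudly regardless of what SPDL measured.
    if datarax_peak_mb > hard_cap_mb:
        return Verdict("fail", f"Datarax peak RSS exceeds {hard_cap_mb} MB cap")
    # 2. Null floor: an SPDL reading near zero makes a ratio meaningless.
    if spdl_peak_mb <= null_floor_mb:
        return Verdict("skip", "SPDL RSS ~0: zero-copy / pre-loaded fixture")
    # 3. Relative comparison (assumed form of the remaining check).
    if datarax_peak_mb > max_ratio * spdl_peak_mb:
        return Verdict("fail", "Datarax peak RSS too high relative to SPDL")
    return Verdict("pass", "Datarax peak RSS within bounds")


over_cap = classify_rss_comparison_sketch(5000.0, 0.0)
spdl_null = classify_rss_comparison_sketch(200.0, 0.0)
too_high = classify_rss_comparison_sketch(400.0, 200.0)
within = classify_rss_comparison_sketch(250.0, 200.0)
```

Keeping the decision in a pure function like this is what makes it cheap to cover each branch with fast unit tests, since no memory measurement is needed to exercise it.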

flax.nnx API migration

  • var.value → var[...] for Variable[Array].
  • var.value → var.get_value() for Variable[PyTree] (dict-typed).

All call sites in src/, tests/, benchmarks/, and examples/ updated; zero .value DeprecationWarnings in the local test run.

Examples and documentation

  • Migrated to Pipeline API: core tutorials 02-09, integration backends (HF, TFDS, ArrayRecord), differentiable flagship guides (ISP, DDSP, DADA), checkpointing example (NNX-standard to_pure_dict / replace_by_pure_dict pattern).
  • New examples/integration/ with runnable templates for downstream ML and SciML users.
  • Pipeline migration guide + Pipeline quickstart example shipped.
  • New Branching DAG Cookbook page.

Deployment + automation

  • vastai-sdk pinned to >=0.2.6,<0.3 in the automation extra to unblock SkyPilot 0.11.x.
  • matplotlib added to the automation extra so the orchestrator's local analyze step can render charts without a separate --extra benchmark install.
  • torch capped at <2.11 to keep CUDA 12 wheels (2.11+ pulls libcudart.so.13).
  • pyarrow>=21 floored for pa.json_ and macOS-14 arm64 stability.
  • uv.lock regenerated against the updated constraints.

Tests

2339 passed, 11 skipped in the local run. The skips (GPU-disabled by conftest policy, SPDL-null, and optional apache_beam) were all audited and found justified.

Breaking changes

  • DAGExecutor is gone. Migrate to Pipeline.from_dag(...) for branching topologies; see docs/benchmarks/interpretation.md and the Pipeline migration guide for the full diff.
  • Pipeline.epoch was replaced by Pipeline.scan(step_fn, modules=..., length=..., init_carry=...).