# cellstates: JAX vs. CPU benchmark (Colab-ready)

This notebook installs `cellstates`, builds a tiny synthetic dataset, and times the CPU and JAX-accelerated gene contribution calculations. It is written for Google Colab. To benchmark JAX on a GPU, switch the runtime type to GPU (Runtime → Change runtime type) and install the CUDA build of JAX in the first cell.


**How to use**

1. Run the install cell below. It installs the CPU build of JAX by default. On a GPU runtime, comment out the `jax[cpu]` line and uncomment the CUDA line instead.
2. Generate the toy dataset (<=100 cells, <=200 genes) with a fixed seed.
3. Benchmark `gene_contribution_table` (CPU) and `gene_contribution_table_jax` (JAX) with timings and parity checks.


In [None]:
#@title Install cellstates (Colab)
# Clean any previous checkout so it doesn't shadow the installed package
!rm -rf cellstates

# 1) Make sure build tools and git are present
!apt-get update -qq
!apt-get install -y build-essential git

# 2) Python deps that setup.py expects
!pip install -q cython numpy scipy pandas matplotlib

# 3) Clone the cellstates fork used in this benchmark
!git clone https://github.com/RemyNicolle/cellstates.git
%cd cellstates

# 4) Build and install via pip (handles the C extension build)
!python -m pip install --no-build-isolation --no-cache-dir .

# 5) Install JAX for the benchmark (CPU build by default)
!pip install -q -U 'jax[cpu]'
# On a GPU runtime, install the CUDA 12 build instead:
# !pip install -q -U 'jax[cuda12]'

# 6) Go back to the notebook root and remove the checkout so imports use site-packages
%cd /content
!rm -rf cellstates

# 7) Verify the import resolves to the installed package
import importlib.util
import cellstates
print("cellstates spec:", importlib.util.find_spec("cellstates"))
from cellstates import Cluster, get_hierarchy_df
print("Cluster import OK; module file:", cellstates.__file__)

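After installation, you can confirm which backend JAX picked up using only public JAX calls. The helper name `describe_jax_backend` is hypothetical and not part of `cellstates`. The check is independent of the fork's `available_jax_devices` and `jax_available` helpers, whose internals are not shown here. On a GPU runtime with the CUDA build installed, the backend should read `gpu`.

```python
import jax


def describe_jax_backend():
    """Hypothetical helper: report JAX's default backend and the platforms of its devices."""
    backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
    platforms = [d.platform for d in jax.devices()]  # devices of the default backend
    return backend, platforms


backend, platforms = describe_jax_backend()
print(f"JAX {jax.__version__} | default backend: {backend} | devices: {platforms}")
```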

In [None]:
import cellstates

## Generate a tiny synthetic dataset
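We keep it small (90 cells and 120 genes in 3 clusters, within the ≤100 cells / ≤200 genes budget) so it runs quickly in Colab. Counts are Poisson draws whose per-gene rates are Gamma-distributed and scaled differently in each cluster. A fixed seed makes the dataset reproducible.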
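Reproducibility comes from drawing every random number from a single `np.random.default_rng(SEED)` generator in a fixed order. The following minimal sketch uses a hypothetical `toy_counts` helper, which is a shrunken version of the recipe in the next cell. It shows that re-seeding reproduces the counts exactly:

```python
import numpy as np


def toy_counts(seed, genes=5, cells=4):
    """Hypothetical mini version of the generator: Gamma rates, then Poisson counts."""
    rng = np.random.default_rng(seed)
    rates = rng.gamma(shape=1.5, scale=1.0, size=(genes, 1))  # one rate per gene
    return rng.poisson(rates, size=(genes, cells))  # broadcast each rate across cells


first = toy_counts(7)
second = toy_counts(7)
print(np.array_equal(first, second))  # True: same seed, same sequence of draws
```

Because the draws are consumed sequentially, changing `GENES`, `CELLS` or `CLUSTERS` alters every subsequent count, even with the same seed. The same applies to reordering the `gamma`/`poisson` calls.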


In [None]:
import numpy as np
from cellstates import Cluster, get_hierarchy_df

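SEED = 7
GENES = 120     # rows of the count matrix
CELLS = 90      # columns of the count matrix
CLUSTERS = 3

rng = np.random.default_rng(SEED)
# Split CELLS as evenly as possible; the first CELLS % CLUSTERS clusters get one extra cell
cells_per_cluster = np.full(CLUSTERS, CELLS // CLUSTERS)
cells_per_cluster[: CELLS % CLUSTERS] += 1

# Per-gene, per-cluster Poisson rates: Gamma-distributed base rates times a cluster-specific scale
base_rates = rng.gamma(shape=1.5, scale=1.0, size=(GENES, CLUSTERS))
cluster_scales = np.linspace(0.7, 1.3, CLUSTERS)
cluster_rates = base_rates * cluster_scales  # broadcasts over the cluster axis

counts_parts = []
labels = []
for idx, n_cells in enumerate(cells_per_cluster):
    lam = cluster_rates[:, idx : idx + 1]  # shape (GENES, 1), broadcast across this cluster's cells
    counts_parts.append(rng.poisson(lam, size=(GENES, n_cells)))
    labels.append(np.full(n_cells, idx, dtype=np.int32))

# Genes x cells count matrix (int64) and the matching int32 cluster labels
counts = np.concatenate(counts_parts, axis=1).astype(np.int64)
init_clusters = np.concatenate(labels)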

print(f"Counts shape: {counts.shape} | cluster sizes: {np.bincount(init_clusters)}")

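# Cluster object for the CPU path, initialised with the true labels
clst_cpu = Cluster(counts, c=init_clusters, max_clusters=init_clusters.max() + 1, seed=SEED)
# Build the cluster hierarchy that both benchmark paths score
hierarchy, delta_ll = clst_cpu.get_cluster_hierarchy()
hierarchy_df = get_hierarchy_df(hierarchy, delta_ll)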
print(f"Hierarchy steps: {hierarchy_df.shape[0]}")

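The cluster-size split in the cell above gives the first `CELLS % CLUSTERS` clusters one extra cell. With the notebook's defaults (90 cells, 3 clusters) every cluster gets 30 cells. With 100 cells and 3 clusters, the sizes are 34, 33 and 33. This matches how NumPy's own `np.array_split` distributes a remainder:

```python
import numpy as np

CELLS, CLUSTERS = 100, 3  # example sizes, not the notebook's defaults

# Same recipe as the dataset cell: equal base size, remainder to the first clusters
sizes = np.full(CLUSTERS, CELLS // CLUSTERS)
sizes[: CELLS % CLUSTERS] += 1

# np.array_split distributes a remainder the same way
split_sizes = [len(part) for part in np.array_split(np.arange(CELLS), CLUSTERS)]

print(sizes.tolist(), split_sizes)  # [34, 33, 33] [34, 33, 33]
```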

## Benchmark CPU vs JAX gene contribution scores
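The first JAX call includes tracing and compilation. The second call reuses the cached compiled code, so it reflects steady-state execution time. Both JAX results are checked against the CPU scores for numerical parity (`rtol=atol=1e-9`, with 64-bit floats enabled).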
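The compile-versus-cached distinction, and the need to wait for JAX's asynchronous dispatch before stopping the clock, can be seen on any jitted function. The following minimal sketch uses a hypothetical stand-in kernel, a per-gene Poisson log-likelihood without the `log(k!)` term. It is not the `cellstates` implementation.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def per_gene_poisson_loglik(k, lam):
    # Hypothetical stand-in kernel: sum over cells of k*log(lam) - lam (log(k!) omitted)
    return jnp.sum(k * jnp.log(lam) - lam, axis=1)


def timed(fn, *args):
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))  # wait for the async computation to finish
    return out, time.perf_counter() - start


k = jnp.ones((120, 90))
lam = jnp.full((120, 90), 2.0)

_, t_first = timed(per_gene_poisson_loglik, k, lam)  # trace + compile + run
scores, t_cached = timed(per_gene_poisson_loglik, k, lam)  # reuses the compiled executable
print(f"first call: {t_first:.4f} s | cached call: {t_cached:.4f} s")
```

Calling a jitted function with a different input shape or dtype triggers a fresh compilation. The cached timing therefore holds only while the inputs keep the same shapes.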


In [None]:
import time
import numpy as np
import jax

from cellstates import (
    Cluster,
    available_jax_devices,
    gene_contribution_table,
    gene_contribution_table_jax,
    jax_available,
)


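def time_call(fn):
    start = time.perf_counter()
    # Block on any JAX arrays so asynchronous dispatch doesn't hide compute time;
    # non-JAX outputs (NumPy arrays, DataFrames) pass through unchanged
    out = jax.block_until_ready(fn())
    return out, time.perf_counter() - start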


if not jax_available():
    raise RuntimeError("JAX is not available; run the install cell first.")

device_preference = "gpu" if available_jax_devices("gpu") else "cpu"

# Fresh cluster for the JAX path so we avoid any state mutations
clst_jax = Cluster(counts, c=init_clusters, max_clusters=init_clusters.max() + 1, seed=SEED)

cpu_scores, cpu_time = time_call(lambda: gene_contribution_table(clst_cpu, hierarchy_df))

jax_scores_first, jax_time_first = time_call(
    lambda: gene_contribution_table_jax(
        clst_jax,
        hierarchy_df,
        device=device_preference,
        enable_x64=True,
    )
)

jax_scores_cached, jax_time_cached = time_call(
    lambda: gene_contribution_table_jax(
        clst_jax,
        hierarchy_df,
        device=device_preference,
        enable_x64=True,
    )
)

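# Convert to plain arrays so the difference is positional, as in assert_allclose
cpu_arr = np.asarray(cpu_scores)
jax_first_arr = np.asarray(jax_scores_first)
jax_cached_arr = np.asarray(jax_scores_cached)
max_diff = float(np.max(np.abs(cpu_arr - jax_cached_arr)))

print(f"JAX devices detected: {jax.devices()}")
print(f"Device used for JAX helper: {device_preference}")
print(f"CPU gene_contribution_table: {cpu_time:.4f} s")
print(f"JAX gene_contribution_table_jax (compile+run): {jax_time_first:.4f} s")
print(f"JAX gene_contribution_table_jax (cached run): {jax_time_cached:.4f} s")
print(f"Max |CPU - JAX| difference: {max_diff:.3e}")

# Check parity after printing so the timings stay visible even if a check fails
np.testing.assert_allclose(cpu_arr, jax_first_arr, rtol=1e-9, atol=1e-9)
np.testing.assert_allclose(cpu_arr, jax_cached_arr, rtol=1e-9, atol=1e-9)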

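JAX computes in 32-bit floats unless its `jax_enable_x64` flag is turned on. Presumably this is what the fork's `enable_x64=True` argument arranges. The `rtol=1e-9, atol=1e-9` parity check depends on it, because float32 carries only about seven significant digits. The following NumPy-only sketch uses made-up scores, not benchmark output. It shows that merely rounding to float32 already breaks that tolerance:

```python
import numpy as np

rng = np.random.default_rng(0)
scores64 = rng.normal(scale=50.0, size=(120, 2))  # made-up float64 "scores"
scores32 = scores64.astype(np.float32).astype(np.float64)  # round-trip through float32

print(np.finfo(np.float32).eps, np.finfo(np.float64).eps)  # machine epsilon of each dtype

# Float32 rounding error (up to about 6e-8 relative) dwarfs the 1e-9 tolerance
print(np.allclose(scores64, scores32, rtol=1e-9, atol=1e-9))  # False
# ...but stays within a float32-appropriate tolerance
print(np.allclose(scores64, scores32, rtol=1e-6, atol=1e-6))  # True
```

If the JAX path is ever run in 32-bit mode, loosen the tolerance to roughly float32 precision rather than comparing at 1e-9.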