# Cellstates JAX vs CPU (Colab ready)

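This notebook installs `cellstates`, builds a small synthetic dataset, and compares the CPU and JAX code paths. It runs a JAX prototype of the cluster hierarchy and a JAX greedy-partition prototype, then times the CPU vs. JAX gene contribution calculations. It is meant to be run in Google Colab. Switch the runtime to GPU or TPU if you want the JAX paths to run on an accelerator.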


**How to use**

1. Run the install cell below. Keep only the JAX install line that matches your Colab runtime (CPU, GPU, or TPU) active.
2. Generate the toy dataset (200 cells, 1,000 genes, 20 clusters) with a fixed seed and build the CPU reference hierarchy.
3. Run the JAX hierarchy and greedy-partition prototypes.
4. Benchmark `gene_contribution_table` (CPU) and `gene_contribution_table_jax` (JAX) with timings and parity checks.


In [None]:
#@title Install cellstates (Colab)
# Clean any previous checkout so it doesn't shadow the installed package
!rm -rf cellstates

# 1) Make sure build tools and git are present
!apt-get update -qq
!apt-get install -y build-essential git

# 2) Python deps that setup.py expects
!pip install -q cython numpy scipy pandas matplotlib

# 3) Clone your fork
!git clone https://github.com/RemyNicolle/cellstates.git
%cd cellstates

# 4) Build and install via pip (handles the C extension build)
!python -m pip install --no-build-isolation --no-cache-dir .

# 5) Install JAX for the JAX paths. Keep exactly one of the lines below active,
#    matching your Colab runtime (the TPU line is active here).
# For CPU runtimes:
# !pip install -q 'jax[cpu]' -f https://storage.googleapis.com/jax-releases/jax_releases.html
# For GPU runtimes:
# !pip install -q 'jax[cuda12_pip]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# For TPU runtimes:
!pip install -q 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/jax_releases.html

# 6) Go back to the notebook root and remove the checkout so imports use site-packages
%cd /content
!rm -rf cellstates

# 7) Verify the import resolves to the installed package
import importlib.util, cellstates
print("cellstates spec:", importlib.util.find_spec("cellstates"))
from cellstates import Cluster, get_hierarchy_df
print("Cluster import OK; module file:", cellstates.__file__)
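After installing, it can help to confirm which backend JAX actually picked up. The sketch below uses only core JAX calls (`jax.default_backend`, `jax.devices`). The helper `pick_jax_backend` is a hypothetical stand-in that mirrors the TPU > GPU > CPU preference used later in this notebook. It is not necessarily how cellstates' `available_jax_devices` works.

```python
import jax


def pick_jax_backend(preference=("tpu", "gpu", "cpu")):
    """Hypothetical helper: return the first backend in `preference` that has devices."""
    for name in preference:
        try:
            if jax.devices(name):
                return name
        except RuntimeError:
            # jax.devices raises RuntimeError when the requested backend is unavailable
            continue
    return "cpu"


print("JAX version:", jax.__version__)
print("Default backend:", jax.default_backend())
print("Devices:", jax.devices())
print("Preferred backend:", pick_jax_backend())
```

If the preferred backend prints `cpu` on a GPU or TPU runtime, the accelerator build of JAX was probably not installed.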

In [None]:
import cellstates
from cellstates import Cluster, get_hierarchy_df


## Generate a tiny synthetic dataset
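We keep it small (200 cells and 1,000 genes, split into 20 clusters) so it runs quickly in Colab. Counts are Poisson draws around gamma-distributed, per-cluster gene rates, and a fixed seed makes the data reproducible. The cell then builds the CPU reference hierarchy used by the later comparisons.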


In [None]:
import numpy as np
from cellstates import Cluster, get_hierarchy_df

SEED = 7
GENES = 1000
CELLS = 200
CLUSTERS = 20

rng = np.random.default_rng(SEED)
cells_per_cluster = np.full(CLUSTERS, CELLS // CLUSTERS)
cells_per_cluster[: CELLS % CLUSTERS] += 1

base_rates = rng.gamma(shape=1.5, scale=1.0, size=(GENES, CLUSTERS))
cluster_scales = np.linspace(0.7, 1.3, CLUSTERS)
cluster_rates = base_rates * cluster_scales

counts_parts = []
labels = []
for idx, n_cells in enumerate(cells_per_cluster):
    lam = cluster_rates[:, idx : idx + 1]
    counts_parts.append(rng.poisson(lam, size=(GENES, n_cells)))
    labels.append(np.full(n_cells, idx, dtype=np.int32))

counts = np.concatenate(counts_parts, axis=1).astype(np.int64)
init_clusters = np.concatenate(labels)

print(f"Counts shape: {counts.shape} | cluster sizes: {np.bincount(init_clusters)}")

clst_cpu = Cluster(counts, c=init_clusters, max_clusters=init_clusters.max() + 1, seed=SEED)
hierarchy, delta_ll = clst_cpu.get_cluster_hierarchy()
hierarchy_df = get_hierarchy_df(hierarchy, delta_ll)
print(f"Hierarchy steps: {hierarchy_df.shape[0]}")
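The generative model in the cell above can be checked on its own. Each cell's counts are Poisson draws around its cluster's rates, so per-cluster mean counts should correlate strongly with the true rates. This is a minimal sketch with smaller, illustrative sizes. The `demo_`-prefixed names are assumptions, chosen so that the notebook's `counts` and `init_clusters` are not overwritten.

```python
import numpy as np

# Miniature re-run of the generator above (illustrative sizes, distinct variable names).
demo_rng = np.random.default_rng(7)
demo_genes, demo_clusters = 50, 4
demo_sizes = np.full(demo_clusters, 10)
demo_rates = demo_rng.gamma(shape=1.5, scale=1.0, size=(demo_genes, demo_clusters))
demo_rates = demo_rates * np.linspace(0.7, 1.3, demo_clusters)
demo_counts = np.concatenate(
    [demo_rng.poisson(demo_rates[:, [k]], size=(demo_genes, n)) for k, n in enumerate(demo_sizes)],
    axis=1,
)
demo_labels = np.repeat(np.arange(demo_clusters), demo_sizes)

# Per-cluster mean counts should track the Poisson rates they were drawn from.
demo_means = np.stack(
    [demo_counts[:, demo_labels == k].mean(axis=1) for k in range(demo_clusters)], axis=1
)
demo_corr = np.corrcoef(demo_means.ravel(), demo_rates.ravel())[0, 1]
print(f"corr(mean counts, true rates) = {demo_corr:.3f}")
```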


## JAX hierarchy prototype (float32, TPU-friendly, chunked)
This computes the cluster hierarchy directly in JAX from aggregated per-cluster counts. It operates on the current partition (clusters), not on individual cells, and uses float32 to reduce memory use on TPU/GPU. The greedy search still scores O(K^2) cluster pairs for K clusters. `pair_chunk_size` splits those pair evaluations into chunks, which lowers peak memory at the cost of more host/device transfers.

In [None]:
import jax
import numpy as np
from cellstates import available_jax_devices, get_cluster_hierarchy_jax_from_counts

# Aggregate counts per cluster (shape: genes x clusters)
counts_per_cluster = np.zeros((counts.shape[0], init_clusters.max() + 1), dtype=np.int64)
for idx in range(counts_per_cluster.shape[1]):
    mask = init_clusters == idx
    if mask.any():
        counts_per_cluster[:, idx] = counts[:, mask].sum(axis=1)

lam = np.asarray(clst_cpu.dirichlet_pseudocounts, dtype=np.float32)

# Device preference: TPU > GPU > CPU
if available_jax_devices("tpu"):
    device_preference = "tpu"
elif available_jax_devices("gpu"):
    device_preference = "gpu"
else:
    device_preference = "cpu"

pair_chunk_size = 10_000  # tune down to lower peak memory

merges_jax, delta_jax = get_cluster_hierarchy_jax_from_counts(
    counts_per_cluster,
    lam,
    device=device_preference,
    enable_x64=False,
    dtype=jax.numpy.float32,
    pair_chunk_size=pair_chunk_size,
)

print(f"JAX hierarchy merges: {len(merges_jax)}")
print("First 5 merges:", merges_jax[:5])
print("First 5 delta_LL:", [float(x) for x in delta_jax[:5]])
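The sketch below makes the O(K^2) pair search concrete by scoring every cluster pair in chunks. It assumes a Dirichlet-multinomial cluster score with per-gene pseudocounts $\lambda_g$ (per-cell constants dropped). This objective is an assumption and is not necessarily the exact one implemented by `get_cluster_hierarchy_jax_from_counts`. With $\Lambda = \sum_g \lambda_g$ and $N_s = \sum_g n_{gs}$:

$$f(n_s) = \sum_g \left[\ln\Gamma(n_{gs} + \lambda_g) - \ln\Gamma(\lambda_g)\right] + \ln\Gamma(\Lambda) - \ln\Gamma(N_s + \Lambda), \qquad \Delta_{st} = f(n_s + n_t) - f(n_s) - f(n_t)$$

`dm_log_score`, `pair_merge_deltas`, and `best_merge` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import gammaln


def dm_log_score(n, lam):
    """Assumed Dirichlet-multinomial score f(n_s) for each column of n (genes x clusters)."""
    big_lam = lam.sum()
    per_gene = gammaln(n + lam[:, None]) - gammaln(lam[:, None])
    return per_gene.sum(axis=0) + gammaln(big_lam) - gammaln(n.sum(axis=0) + big_lam)


@jax.jit
def pair_merge_deltas(n, lam, left, right):
    """Delta_st = f(n_s + n_t) - f(n_s) - f(n_t) for each pair (left[i], right[i])."""
    single = dm_log_score(n, lam)
    merged = dm_log_score(n[:, left] + n[:, right], lam)
    return merged - single[left] - single[right]


def best_merge(n, lam, chunk=4096):
    """Scan all K*(K-1)/2 pairs in chunks; return (s, t, delta) of the best merge."""
    n = jnp.asarray(n, jnp.float32)
    lam = jnp.asarray(lam, jnp.float32)
    left, right = np.triu_indices(n.shape[1], k=1)
    best = (-1, -1, -np.inf)
    for start in range(0, left.size, chunk):
        # A shorter final chunk triggers one extra compilation; padding would avoid it.
        l, r = left[start:start + chunk], right[start:start + chunk]
        deltas = np.asarray(pair_merge_deltas(n, lam, jnp.asarray(l), jnp.asarray(r)))
        i = int(np.argmax(deltas))
        if deltas[i] > best[2]:
            best = (int(l[i]), int(r[i]), float(deltas[i]))
    return best
```

Smaller `chunk` values lower peak memory but need more device round trips. This is the same trade-off `pair_chunk_size` controls above. A greedy agglomerative hierarchy follows from repeating `best_merge` after each merge, replacing the two merged columns of `n` with their sum.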


## JAX greedy partition (multi-sweep prototype)
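Runs float32, JAX-based greedy sweeps that reassign individual cells. Each cell's proposals can optionally be restricted to the top `candidate_topk` clusters, and cell order is reshuffled between sweeps. This is not a full replica of the cellstates MCMC sampler, but it is TPU-friendly and chunked over clusters (`cluster_chunk`).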

In [None]:
from cellstates import run_greedy_partition_jax

# Tune these to trade memory vs speed/coverage
cluster_chunk = 128
candidate_topk = 32  # restrict proposals to top-N clusters per cell (None to scan all)
sweeps = 3

new_clusters, moves, total_delta = run_greedy_partition_jax(
    counts,
    init_clusters,
    lam=np.asarray(clst_cpu.dirichlet_pseudocounts, dtype=np.float32),
    sweeps=sweeps,
    device=device_preference,
    enable_x64=False,
    dtype=jax.numpy.float32,
    cluster_chunk=cluster_chunk,
    candidate_topk=candidate_topk,
    seed=SEED,
)

print(f"Sweeps: {sweeps}, total moves: {moves}, total delta_LL: {total_delta:.3f}")
print("New cluster labels (first 10):", new_clusters[:10])
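A single greedy sweep can be sketched as follows. The sketch assumes a Dirichlet-multinomial cluster score $f$ (per-cell constants dropped), which is not necessarily the objective `run_greedy_partition_jax` uses. Moving a cell with counts $x$ from cluster $a$ to cluster $b$ changes the total score by $[f(n_b + x) - f(n_b)] + [f(n_a - x) - f(n_a)]$. Candidate pruning uses a cheap proxy rank (the log-probability of $x$ under each cluster's smoothed gene frequencies) followed by `jax.lax.top_k`. The exact score is then computed only for those candidates. The proxy, `greedy_sweep`, and the other helper names are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import gammaln


def dm_log_score(n, lam):
    """Assumed Dirichlet-multinomial score per column of n (per-cell constants dropped)."""
    big_lam = lam.sum()
    per_gene = gammaln(n + lam[:, None]) - gammaln(lam[:, None])
    return per_gene.sum(axis=0) + gammaln(big_lam) - gammaln(n.sum(axis=0) + big_lam)


@jax.jit
def candidate_deltas(n, lam, x, src, cand):
    """Score change for moving cell counts x from cluster src to each cluster in cand."""
    base = dm_log_score(n, lam)
    gain = dm_log_score(n[:, cand] + x[:, None], lam) - base[cand]
    loss = dm_log_score((n[:, src] - x)[:, None], lam)[0] - base[src]
    return jnp.where(cand == src, 0.0, gain + loss)  # staying put changes nothing


@jax.jit
def cheap_rank(n, lam, x):
    """Hypothetical proxy: log-probability of x under each cluster's smoothed frequencies."""
    freqs = (n + lam[:, None]) / (n.sum(axis=0) + lam.sum())
    return x @ jnp.log(freqs)


def greedy_sweep(counts, labels, lam, k, topk=None, seed=0):
    """One greedy pass over cells in shuffled order; returns (new_labels, n_moves)."""
    labels = np.array(labels, copy=True)
    n = np.stack([counts[:, labels == j].sum(axis=1) for j in range(k)], axis=1)
    n = jnp.asarray(n, jnp.float32)
    lam = jnp.asarray(lam, jnp.float32)
    moves = 0
    for c in np.random.default_rng(seed).permutation(counts.shape[1]):
        x = jnp.asarray(counts[:, c], jnp.float32)
        src = int(labels[c])
        if topk is None:
            cand = jnp.arange(k)
        else:
            _, cand = jax.lax.top_k(cheap_rank(n, lam, x), topk)
        delta = candidate_deltas(n, lam, x, src, cand)
        best = int(cand[int(jnp.argmax(delta))])
        if best != src and float(jnp.max(delta)) > 0:
            # Apply the move to the aggregated counts and the labels
            n = n.at[:, src].add(-x).at[:, best].add(x)
            labels[c] = best
            moves += 1
    return labels, moves
```

Calling `greedy_sweep` repeatedly with a different `seed` each time reproduces the reshuffle-between-sweeps pattern. Each accepted move strictly increases the assumed score, so the total never decreases.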


## Benchmark CPU vs JAX gene contribution scores
The first JAX call includes tracing and compilation. The second reuses the compiled code, so it reflects steady-state execution time. The CPU and JAX outputs are then checked for numerical parity.


In [None]:
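import os
# Optional GPU memory settings to avoid grabbing most of the accelerator memory up front.
# They only take effect if set before JAX initializes its backend, i.e. in a fresh runtime
# before any earlier cell has run a JAX computation.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.7")
os.environ.setdefault("XLA_PYTHON_CLIENT_ALLOCATOR", "platform")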

import time
import numpy as np
import jax

from cellstates import (
    Cluster,
    available_jax_devices,
    gene_contribution_table,
    gene_contribution_table_jax,
    jax_available,
)

# Keep JAX's global default at float32; the gene_contribution_table_jax calls below
# pass enable_x64=True to request float64 and match CPU precision.
jax.config.update("jax_enable_x64", False)


def time_call(fn):
    start = time.perf_counter()
    out = fn()
    return out, time.perf_counter() - start


if not jax_available():
    raise RuntimeError("JAX is not available; run the install cell first.")

# Prioritize TPU, then GPU, then CPU
if available_jax_devices("tpu"):
    device_preference = "tpu"
elif available_jax_devices("gpu"):
    device_preference = "gpu"
else:
    device_preference = "cpu"

# Fresh cluster for the JAX path so we avoid any state mutations
clst_jax = Cluster(counts, c=init_clusters, max_clusters=init_clusters.max() + 1, seed=SEED)

cpu_scores, cpu_time = time_call(lambda: gene_contribution_table(clst_cpu, hierarchy_df))

print(f"JAX devices detected: {jax.devices()}")
print(f"Device used for JAX helper: {device_preference}")
print(f"JAX enable_x64: {jax.config.read('jax_enable_x64')}")
print(f"CPU gene_contribution_table: {cpu_time:.4f} s")

jax_scores_first, jax_time_first = time_call(
    lambda: gene_contribution_table_jax(
        clst_jax,
        hierarchy_df,
        device=device_preference,
        enable_x64=True,  # Set to True to match CPU precision
    )
)

jax_scores_cached, jax_time_cached = time_call(
    lambda: gene_contribution_table_jax(
        clst_jax,
        hierarchy_df,
        device=device_preference,
        enable_x64=True,  # Set to True to match CPU precision
    )
)

max_diff = float(np.max(np.abs(cpu_scores - jax_scores_cached)))

print(f"JAX gene_contribution_table_jax (compile+run): {jax_time_first:.4f} s")
print(f"JAX gene_contribution_table_jax (cached run): {jax_time_cached:.4f} s")
print(f"Max |CPU - JAX| difference: {max_diff:.3e}")

# Run the parity check last so the timings above are reported even if it fails
np.testing.assert_allclose(cpu_scores, jax_scores_cached, rtol=5e-4, atol=5e-2)
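JAX dispatches work asynchronously, so a wall-clock timer can stop before the device has finished unless the result is forced. Whether this affects the timings of `gene_contribution_table_jax` depends on whether it returns host (NumPy) arrays. For raw JAX functions, the usual pattern is to wrap the call in `jax.block_until_ready`. The sketch below shows the compile-then-cached measurement on a stand-in kernel (`toy_logsumexp` and `timed` are hypothetical names).

```python
import time

import jax
import jax.numpy as jnp


def timed(fn, *args):
    """Call fn(*args), wait until all device results are ready, return (result, seconds)."""
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start


@jax.jit
def toy_logsumexp(x):
    # Stand-in workload: log-sum-exp over the gene axis for each cell.
    return jax.nn.logsumexp(x, axis=0)


demo_x = jnp.ones((1000, 200), dtype=jnp.float32)
_, first_s = timed(toy_logsumexp, demo_x)   # includes tracing + XLA compilation
_, cached_s = timed(toy_logsumexp, demo_x)  # same shape/dtype, reuses the compiled executable
print(f"first call: {first_s:.4f} s | cached call: {cached_s:.4f} s")
```

Changing the input shape or dtype triggers a new compilation. A cached measurement is therefore only meaningful when it reuses the shapes of the first call.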