# Cellstates JAX vs CPU (Colab ready)

**How to use**

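1. Run the install cell, then the imports cell.
2. Generate the synthetic clustering data and initialize clusters with PCA + k-means.
3. Run the clustering benchmark (JAX on CPU vs JAX on the best available device); optionally run the stochastic-partition alternative.
4. Generate the synthetic hierarchy data and run the hierarchy benchmark (cellstates CPU vs JAX).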

In [None]:
#@title Install cellstates (Colab)
import os, subprocess, sys
import shlex

# Invoke pip through the interpreter running this kernel (instead of whatever
# `pip` / `python` happens to be first on PATH) so packages land in the
# environment that the later `import` statements actually use.
PIP = f"{shlex.quote(sys.executable)} -m pip"

# Clean any previous checkout and any shadowing namespace dirs
subprocess.run("find /content -maxdepth 2 -type d -name 'cellstates' -print -exec rm -rf {} +", shell=True)
subprocess.run("rm -rf /content/cellstates_src", shell=True)

# 1) Build tools and git
subprocess.check_call("apt-get update -qq", shell=True)
subprocess.check_call("apt-get install -y build-essential git", shell=True)

# 2) Python deps
subprocess.check_call(f"{PIP} install -q cython numpy scipy pandas matplotlib scikit-learn", shell=True)

# 3) Clone and install cellstates. --no-build-isolation makes pip build against
#    the packages already installed above instead of a fresh isolated env.
subprocess.check_call("git clone https://github.com/RemyNicolle/cellstates.git /content/cellstates_src", shell=True)
subprocess.check_call(f"{PIP} install --no-build-isolation --no-cache-dir /content/cellstates_src", shell=True)

# 4) Install JAX with simple detection: TPU > GPU > CPU
has_tpu = bool(os.environ.get("COLAB_TPU_ADDR") or os.environ.get("TPU_NAME"))
has_gpu = subprocess.run("nvidia-smi", shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0

JAX_CPU_CMD = f"{PIP} install -q 'jax[cpu]' -f https://storage.googleapis.com/jax-releases/jax_releases.html"


def _install_jax_or_cpu(accel_cmd, label):
    """Run an accelerator-specific JAX install; fall back to the CPU wheel if it fails."""
    if subprocess.run(accel_cmd, shell=True).returncode != 0:
        print(f"{label} JAX install failed; falling back to CPU wheel")
        subprocess.check_call(JAX_CPU_CMD, shell=True)


if has_tpu:
    _install_jax_or_cpu(
        f"{PIP} install -q 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html",
        "TPU",
    )
elif has_gpu:
    # The old `cuda12_pip` extra is deprecated/removed in newer JAX releases (the
    # install guide uses `jax[cuda12]`). With an unknown extra, pip only warns and
    # installs CPU-only jax with exit code 0, so the fallback below never fired
    # and the GPU was silently unused.
    _install_jax_or_cpu(f"{PIP} install -q -U 'jax[cuda12]'", "CUDA")
else:
    subprocess.check_call(JAX_CPU_CMD, shell=True)

# 5) Remove the cloned source tree so it cannot shadow the installed package
subprocess.run("rm -rf /content/cellstates_src", shell=True)

# 6) Verify import
import importlib.util, cellstates
print("cellstates spec:", importlib.util.find_spec("cellstates"))
from cellstates import Cluster
print("Cluster import OK; module file:", cellstates.__file__)
print("TPU detected:", has_tpu)
print("GPU detected:", has_gpu)


In [None]:
import numpy as np
import time
import jax
from cellstates import (
    Cluster,
    available_jax_devices,
    run_greedy_partition_jax,
    get_cluster_hierarchy_jax_from_counts,
)

## Synthetic data for clustering benchmark

In [None]:
SEED = 7
GENES_CLUST = 10
CELLS_CLUST = 200
CLUSTERS_CLUST = 20

# Simulate a genes x cells Poisson count matrix made of CLUSTERS_CLUST groups.
rng = np.random.default_rng(SEED)

# Near-equal cluster sizes: the first (CELLS % CLUSTERS) clusters get one extra cell.
size_each, n_extra = divmod(CELLS_CLUST, CLUSTERS_CLUST)
cells_per_cluster = np.full(CLUSTERS_CLUST, size_each)
cells_per_cluster[:n_extra] += 1

# Gamma-distributed per-gene base rates; each cluster is scaled by a factor in [0.7, 1.3].
base_rates = rng.gamma(shape=1.5, scale=1.0, size=(GENES_CLUST, CLUSTERS_CLUST))
cluster_scales = np.linspace(0.7, 1.3, CLUSTERS_CLUST)
cluster_rates = base_rates * cluster_scales[np.newaxis, :]

# Draw each cluster's block of cells in cluster order (keeps the RNG stream fixed);
# `labels` holds the ground-truth cluster id of every cell.
counts_parts = []
labels = []
for k, size_k in enumerate(cells_per_cluster):
    rate_col = cluster_rates[:, [k]]  # (genes, 1), broadcast over the cells of cluster k
    counts_parts.append(rng.poisson(rate_col, size=(GENES_CLUST, size_k)))
    labels.append(np.full(size_k, k, dtype=np.int32))

counts_clust = np.concatenate(counts_parts, axis=1).astype(np.int64)
print('Clustering counts shape:', counts_clust.shape)


## Init clusters with PCA + k-means (synthetic data)

In [None]:
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

# Build initial cluster labels for the partitioners: log1p -> per-gene
# standardization -> PCA -> k-means on the cells.
max_genes_for_pca = min(2000, counts_clust.shape[0])
X = counts_clust
if X.shape[0] > max_genes_for_pca:
    # Random subset of genes (rows) to bound PCA cost on wide count matrices.
    gene_idx = np.random.default_rng(SEED).choice(X.shape[0], size=max_genes_for_pca, replace=False)
    X = X[gene_idx]

# Transpose to cells x genes before scaling. StandardScaler standardizes
# columns, so fitting on cells x genes scales each gene across cells; fitting
# on the genes x cells matrix would scale each cell across genes instead.
X_t = np.log1p(X).astype(np.float32).T
scaler = StandardScaler(with_mean=True, with_std=True)
X_scaled = scaler.fit_transform(X_t)

n_pcs = 30
pca = PCA(n_components=min(n_pcs, X_scaled.shape[0], X_scaled.shape[1]), random_state=SEED)
X_pca = pca.fit_transform(X_scaled)  # (cells, n_components)

n_init_clusters = 200
# KMeans raises if n_clusters exceeds the number of samples (cells).
n_kmeans = min(n_init_clusters, X_pca.shape[0])
kmeans = KMeans(n_clusters=n_kmeans, n_init=5, random_state=SEED, verbose=0)
raw_labels = kmeans.fit_predict(X_pca)

# Duplicate cells can leave some k-means labels unused; relabel to 0..K-1 so
# label ids are contiguous.
_, init_clusters = np.unique(raw_labels, return_inverse=True)
init_clusters = init_clusters.reshape(-1).astype(np.int32)
n_found = int(init_clusters.max()) + 1

print(f"Init clusters from k-means: {n_found}, counts per cluster (first 10): {np.bincount(init_clusters)[:10]}")


## Clustering benchmark (JAX prototype: CPU vs best available device)

In [None]:
# Choose the device for the accelerated run: the first of TPU, then GPU that
# JAX can see, otherwise fall back to CPU.
device_preference = next(
    (kind for kind in ("tpu", "gpu") if available_jax_devices(kind)),
    "cpu",
)

# Settings shared by the greedy-partition runs below.
sweeps = 3
cluster_chunk = 128    # forwarded as `cluster_chunk` (presumably clusters scored per chunk; confirm in cellstates)
candidate_topk = None  # forwarded as `candidate_topk`; semantics of None defined by cellstates (TODO confirm)
# Constant 0.5 for every gene: one entry per row of counts_clust.
lam = np.full(counts_clust.shape[0], 0.5, dtype=np.float32)

def time_call(fn):
    """Call ``fn()`` once and measure its wall-clock duration.

    JAX dispatches work asynchronously, so device arrays in the result may
    still be computing when ``fn`` returns. ``jax.block_until_ready`` waits
    for every JAX array in the (possibly nested) result before the clock
    stops; non-JAX values such as NumPy arrays or Python scalars pass through.

    Args:
        fn: zero-argument callable to time.

    Returns:
        ``(out, seconds)``: ``fn()``'s return value and the elapsed
        ``time.perf_counter()`` time in seconds.
    """
    start = time.perf_counter()
    out = fn()
    jax.block_until_ready(out)
    return out, time.perf_counter() - start

def _greedy_run(device):
    """Run the greedy JAX partitioner on the shared synthetic data on ``device``."""
    return run_greedy_partition_jax(
        counts_clust,
        init_clusters,
        lam=lam,
        sweeps=sweeps,
        device=device,
        enable_x64=False,
        dtype=jax.numpy.float32,
        cluster_chunk=cluster_chunk,
        candidate_topk=candidate_topk,
        seed=SEED,
    )


# Baseline: the JAX implementation pinned to the host CPU.
# NOTE(review): each run is a first call on its device, so the timings
# presumably include any JIT compilation.
(cpu_clusters, cpu_moves, cpu_delta), cpu_time = time_call(lambda: _greedy_run("cpu"))

# Same algorithm on the preferred device selected earlier.
(jax_clusters, jax_moves, jax_delta), jax_time = time_call(lambda: _greedy_run(device_preference))

print(f"CPU sweeps: {sweeps}, moves: {cpu_moves}, total delta_LL: {cpu_delta:.3f}, time: {cpu_time:.3f} s")
print(f"JAX sweeps ({device_preference}): moves: {jax_moves}, total delta_LL: {jax_delta:.3f}, time: {jax_time:.3f} s")
print("Cluster label sample (CPU first 10):", cpu_clusters[:10])
print("Cluster label sample (JAX first 10):", jax_clusters[:10])


## JAX stochastic partition (alternative to greedy)
Samples a small set of candidate clusters for each cell in every sweep (always including the cell's current cluster) and accepts only improving moves. Use it when the number of clusters K is large and you need less memory than the full greedy scan.

In [None]:
from cellstates import stochastic_partition_jax

# Candidate clusters proposed per cell in each sweep. Per the original notes,
# this is kept as a small fixed integer so JAX shapes stay static and memory
# stays low.
# NOTE(review): an earlier comment said None samples all clusters; confirm
# that stochastic_partition_jax accepts None before relying on it.
proposals_per_cell = 8
sweeps_alt = 3

(alt_clusters, alt_moves, alt_delta), alt_time = time_call(
    lambda: stochastic_partition_jax(
        counts_clust,
        init_clusters,
        lam=lam,
        sweeps=sweeps_alt,
        # Previously a stray `proposals_per_cell = 8  # ...` line without a
        # trailing comma sat here, which made the whole cell a SyntaxError.
        proposals_per_cell=proposals_per_cell,
        device=device_preference,
        enable_x64=False,
        dtype=jax.numpy.float32,
        seed=SEED,
    )
)

print(f"Alt stochastic: sweeps={sweeps_alt}, moves={alt_moves}, total delta_LL={alt_delta:.3f}, time={alt_time:.3f} s")
print("Cluster label sample (alt first 10):", alt_clusters[:10])


## Synthetic data for hierarchy benchmark

In [None]:
# Hierarchy dataset: more genes, more cells and many more ground-truth
# clusters than the clustering example, giving many clusters to merge.
GENES_HIER = 3000
CELLS_HIER = 2000
CLUSTERS_HIER = 150

rng = np.random.default_rng(SEED + 1)

# Near-equal cluster sizes: the first (CELLS % CLUSTERS) clusters get one extra cell.
size_h, extra_h = divmod(CELLS_HIER, CLUSTERS_HIER)
cells_per_cluster_h = np.full(CLUSTERS_HIER, size_h)
cells_per_cluster_h[:extra_h] += 1

# Gamma-distributed per-gene base rates; each cluster is scaled by a factor in [0.7, 1.3].
base_rates_h = rng.gamma(shape=1.5, scale=1.0, size=(GENES_HIER, CLUSTERS_HIER))
cluster_scales_h = np.linspace(0.7, 1.3, CLUSTERS_HIER)
cluster_rates_h = base_rates_h * cluster_scales_h[np.newaxis, :]

# Poisson counts drawn cluster by cluster, in order (keeps the RNG stream fixed).
counts_parts_h = [
    rng.poisson(cluster_rates_h[:, [k]], size=(GENES_HIER, n_k))
    for k, n_k in enumerate(cells_per_cluster_h)
]
# Ground-truth labels, used directly as the initial partition for the hierarchy.
labels_h = [np.full(n_k, k, dtype=np.int32) for k, n_k in enumerate(cells_per_cluster_h)]

counts_hier = np.concatenate(counts_parts_h, axis=1).astype(np.int64)
init_clusters_hier = np.concatenate(labels_h)

print('Hierarchy counts shape:', counts_hier.shape)


## Hierarchy benchmark (CPU vs JAX)

In [None]:
# CPU hierarchy: cellstates' own Cluster implementation on the full count matrix.
n_clusters_hier = int(init_clusters_hier.max()) + 1
clst_cpu = Cluster(counts_hier, c=init_clusters_hier, max_clusters=n_clusters_hier, seed=SEED)

start = time.perf_counter()
hierarchy_cpu, delta_cpu = clst_cpu.get_cluster_hierarchy()
cpu_hier_time = time.perf_counter() - start

# The JAX variant takes per-cluster totals: column k sums the counts of every
# cell labelled k, giving a genes x clusters matrix.
counts_per_cluster = np.zeros((counts_hier.shape[0], n_clusters_hier), dtype=np.int64)
for k in range(n_clusters_hier):
    members = init_clusters_hier == k
    if members.any():
        counts_per_cluster[:, k] = counts_hier[:, members].sum(axis=1)

# Reuse the CPU Cluster's pseudocounts so both paths use the same values.
lam_hier = np.asarray(clst_cpu.dirichlet_pseudocounts, dtype=np.float32)

pair_chunk_size = 50_000

start = time.perf_counter()
merges_jax, delta_jax = get_cluster_hierarchy_jax_from_counts(
    counts_per_cluster,
    lam_hier,
    device=device_preference,
    enable_x64=False,
    dtype=jax.numpy.float32,
    pair_chunk_size=pair_chunk_size,
)
# JAX dispatch is asynchronous: wait for any device arrays in the result before
# stopping the clock. NOTE(review): this first call presumably also includes JIT
# compilation time.
jax.block_until_ready((merges_jax, delta_jax))
jax_hier_time = time.perf_counter() - start

print(f"CPU hierarchy: merges={len(hierarchy_cpu)}, time={cpu_hier_time:.3f} s")
print(f"JAX hierarchy ({device_preference}): merges={len(merges_jax)}, time={jax_hier_time:.3f} s")
print("First 5 CPU merges:", hierarchy_cpu[:5])
print("First 5 JAX merges:", merges_jax[:5])
