# Cellstates JAX vs CPU (Colab ready)

**How to use**

1. Run the install cell.
2. Run the clustering synthetic data + PCA/k-means init.
3. Run the clustering benchmark (CPU vs JAX devices).
4. Run the hierarchy synthetic data + benchmark.

In [None]:
#@title Install cellstates (Colab)
import os, subprocess, sys

# Clean any previous checkout and any shadowing namespace dirs
subprocess.run("find /content -maxdepth 2 -type d -name 'cellstates' -print -exec rm -rf {} +", shell=True)
subprocess.run("rm -rf /content/cellstates_src", shell=True)

# 1) Build tools and git
subprocess.check_call("apt-get update -qq", shell=True)
subprocess.check_call("apt-get install -y build-essential git", shell=True)

# 2) Python deps
subprocess.check_call("pip install -q cython numpy scipy pandas matplotlib scikit-learn", shell=True)

# 3) Clone and install cellstates
subprocess.check_call("git clone https://github.com/RemyNicolle/cellstates.git /content/cellstates_src", shell=True)
subprocess.check_call("python -m pip install --no-build-isolation --no-cache-dir /content/cellstates_src", shell=True)

# 4) Install JAX with simple detection: TPU > GPU > CPU
has_tpu = bool(os.environ.get("COLAB_TPU_ADDR") or os.environ.get("TPU_NAME"))
has_gpu = subprocess.run("nvidia-smi", shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0

if has_tpu:
    jax_cmd = "pip install -q 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
    ret = subprocess.run(jax_cmd, shell=True)
    if ret.returncode != 0:
        print("TPU JAX install failed; falling back to CPU wheel")
        subprocess.check_call("pip install -q 'jax[cpu]' -f https://storage.googleapis.com/jax-releases/jax_releases.html", shell=True)
elif has_gpu:
    jax_cmd = "pip install -q 'jax[cuda12_pip]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
    ret = subprocess.run(jax_cmd, shell=True)
    if ret.returncode != 0:
        print("CUDA JAX install failed; falling back to CPU wheel")
        subprocess.check_call("pip install -q 'jax[cpu]' -f https://storage.googleapis.com/jax-releases/jax_releases.html", shell=True)
else:
    subprocess.check_call("pip install -q 'jax[cpu]' -f https://storage.googleapis.com/jax-releases/jax_releases.html", shell=True)

# 5) Return to root and clean clone to avoid shadowing
subprocess.run("rm -rf /content/cellstates_src", shell=True)

# 6) Verify import
import importlib.util, cellstates
print("cellstates spec:", importlib.util.find_spec("cellstates"))

import numpy as np
import time
import jax
from cellstates import (
    Cluster,
    available_jax_devices,
    stochastic_partition_jax,
    run_gibbs_partition_jax,
    get_cluster_hierarchy_jax_from_counts,
)
print("Cluster import OK; module file:", cellstates.__file__)
print("TPU detected:", has_tpu)
print("GPU detected:", has_gpu)


## Synthetic data for clustering benchmark

## Benchmark options

In [None]:
# Benchmark options
USE_JAX_GIBBS = False  # set False to skip JAX Gibbs variants
GIBBS_SWEEPS = 3
MCMC_STEPS = 50
MCMC_TRIES = 400

SEED = 7
GENES_CLUST = 100
CELLS_CLUST = 2000
CLUSTERS_CLUST = 100


rng = np.random.default_rng(SEED)
cells_per_cluster = np.full(CLUSTERS_CLUST, CELLS_CLUST // CLUSTERS_CLUST)
cells_per_cluster[: CELLS_CLUST % CLUSTERS_CLUST] += 1

base_rates = rng.gamma(shape=1.5, scale=1.0, size=(GENES_CLUST, CLUSTERS_CLUST))
cluster_scales = np.linspace(0.7, 1.3, CLUSTERS_CLUST)
cluster_rates = base_rates * cluster_scales

counts_parts = []
labels = []
for idx, n_cells in enumerate(cells_per_cluster):
    lam = cluster_rates[:, idx : idx + 1]
    counts_parts.append(rng.poisson(lam, size=(GENES_CLUST, n_cells)))
    labels.append(np.full(n_cells, idx, dtype=np.int32))

counts_clust = np.concatenate(counts_parts, axis=1).astype(np.int64)
print('Clustering counts shape:', counts_clust.shape)

true_clusters = np.concatenate(labels)

## Init clusters with PCA + k-means (synthetic data)

In [None]:
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

max_genes_for_pca = min(2000, counts_clust.shape[0])
X = counts_clust
if X.shape[0] > max_genes_for_pca:
    gene_idx = np.random.default_rng(SEED).choice(X.shape[0], size=max_genes_for_pca, replace=False)
    X = X[gene_idx]

X_t = np.log1p(X).astype(np.float32)
scaler = StandardScaler(with_mean=True, with_std=True)
X_scaled = scaler.fit_transform(X_t)

n_pcs = 30
pca = PCA(n_components=min(n_pcs, X_scaled.shape[0], X_scaled.shape[1]), random_state=SEED)
X_pca = pca.fit_transform(X_scaled.T)

n_init_clusters = min(100, counts_clust.shape[1])
kmeans = KMeans(n_clusters=n_init_clusters, n_init=5, random_state=SEED, verbose=0)
init_clusters = kmeans.fit_predict(X_pca).astype(np.int32)

print(f"Init clusters from k-means: {n_init_clusters}, counts per cluster (first 10): {np.bincount(init_clusters)[:10]}")


## Clustering benchmark (CPU vs JAX prototype)

In [None]:
# Device preference: TPU > GPU > CPU
if available_jax_devices("tpu"):
    device_preference = "tpu"
elif available_jax_devices("gpu"):
    device_preference = "gpu"
else:
    device_preference = "cpu"

cluster_chunk = 128
candidate_topk = None
sweeps = 3
lam = np.full(counts_clust.shape[0], 0.5, dtype=np.float32)

# Cython CPU baseline (short MCMC)
start = time.perf_counter()
cl_baseline = Cluster(counts_clust, c=init_clusters, max_clusters=init_clusters.max() + 1, seed=SEED)
try:
    cl_baseline.biased_monte_carlo_sampling(N_steps=50, tries_per_step=400, min_index=0)
except RuntimeError:
    pass
cython_time = time.perf_counter() - start
labels_cython = np.asarray(cl_baseline.clusters, dtype=np.int32)

print(f"Cython baseline time: {cython_time:.3f} s, clusters: {labels_cython.max()+1}")


## JAX stochastic partition (alternative to greedy)
Samples a small set of candidate clusters per cell each sweep (always includes the current cluster) and accepts only improving moves. Use when K is large and you want lower memory than the full greedy scan.

In [None]:
from cellstates import stochastic_partition_jax

# fixed size for JAX shapes
proposals_per_cell = 8
sweeps_alt = 3

start = time.perf_counter()
labels_stochastic, alt_moves, alt_delta = stochastic_partition_jax(
    counts_clust,
    init_clusters,
    lam=lam,
    sweeps=sweeps_alt,
    device=device_preference,
    enable_x64=False,
    dtype=jax.numpy.float32,
    seed=SEED,
)
alt_time = time.perf_counter() - start
print(f"Stochastic ({device_preference}) time: {alt_time:.3f} s, moves: {alt_moves}, delta_LL: {alt_delta:.3f}")


## Likelihood and ARI comparison
Compare partitions: initial k-means, JAX greedy (CPU/JAX), JAX stochastic, and Cython MCMC with/without stochastic warm start. ARI is measured against the synthetic ground-truth labels. Likelihoods are from the Cython Cluster for each labeling.

**Note:** likelihood reported is log-likelihood (natural log), higher is better.

In [None]:
from sklearn.metrics import adjusted_rand_score
from cellstates import Cluster, stochastic_partition_jax, run_gibbs_partition_jax
import pandas as pd
import matplotlib.pyplot as plt
import time

# Helper to compute log-likelihood for a given labeling (natural log)
def partition_log_likelihood(labels):
    cl = Cluster(counts_clust, c=labels, max_clusters=labels.max() + 1, seed=SEED)
    return float(cl.total_likelihood)

# Baseline partitions and timings
baseline_parts = {
    "cython_mcmc": (labels_cython, cython_time),
    "kmeans_init": (init_clusters, 0.0),
    "stochastic": (labels_stochastic, alt_time),
}

rows = []
for name, (labels, t_val) in baseline_parts.items():
    t_val = t_val or 0.0
    rows.append({
        "partition": name,
        "device": device_preference,
        "log_likelihood": partition_log_likelihood(labels),
        "ari_true": adjusted_rand_score(true_clusters, labels),
        "ari_cython": adjusted_rand_score(labels_cython, labels),
        "refined": False,
        "note": "",
        "warm_time_s": t_val,
        "mcmc_time_s": 0.0,
        "total_time_s": t_val,
        "n_clusters": int(labels.max() + 1),
    })

# MCMC refinement from stochastic warm start
cl = Cluster(counts_clust, c=labels_stochastic, max_clusters=labels_stochastic.max() + 1, seed=SEED)
note = ""
start = time.perf_counter()
try:
    cl.biased_monte_carlo_sampling(N_steps=MCMC_STEPS, tries_per_step=MCMC_TRIES, min_index=0)
except RuntimeError as err:
    note = f"MCMC early stop: {err}"
mcmc_elapsed = time.perf_counter() - start
refined = np.asarray(cl.clusters, dtype=np.int32)
rows.append({
    "partition": "mcmc_warm_stochastic",
    "device": device_preference,
    "log_likelihood": partition_log_likelihood(refined),
    "ari_true": adjusted_rand_score(true_clusters, refined),
    "ari_cython": adjusted_rand_score(labels_cython, refined),
    "refined": True,
    "note": note,
    "warm_time_s": alt_time,
    "mcmc_time_s": mcmc_elapsed,
    "total_time_s": alt_time + mcmc_elapsed,
    "n_clusters": int(refined.max() + 1),
})

if USE_JAX_GIBBS:
    # JAX Gibbs from scratch
    start = time.perf_counter()
    labels_gibbs, gibbs_moves, gibbs_delta = run_gibbs_partition_jax(
        counts_clust,
        init_clusters,
        lam=lam,
        sweeps=GIBBS_SWEEPS,
        device=device_preference,
        enable_x64=False,
        dtype=jax.numpy.float32,
        seed=SEED,
    )
    gibbs_time = time.perf_counter() - start
    rows.append({
        "partition": "jax_gibbs",
        "device": device_preference,
        "log_likelihood": partition_log_likelihood(labels_gibbs),
        "ari_true": adjusted_rand_score(true_clusters, labels_gibbs),
        "ari_cython": adjusted_rand_score(labels_cython, labels_gibbs),
        "refined": False,
        "note": "",
        "warm_time_s": gibbs_time,
        "mcmc_time_s": 0.0,
        "total_time_s": gibbs_time,
        "n_clusters": int(labels_gibbs.max() + 1),
    })

    # JAX Gibbs warm start from stochastic
    start = time.perf_counter()
    labels_gibbs_warm, gibbs_moves_warm, gibbs_delta_warm = run_gibbs_partition_jax(
        counts_clust,
        labels_stochastic,
        lam=lam,
        sweeps=GIBBS_SWEEPS,
        device=device_preference,
        enable_x64=False,
        dtype=jax.numpy.float32,
        seed=SEED + 1,
    )
    gibbs_warm_time = time.perf_counter() - start
    rows.append({
        "partition": "jax_gibbs_warm_stochastic",
        "device": device_preference,
        "log_likelihood": partition_log_likelihood(labels_gibbs_warm),
        "ari_true": adjusted_rand_score(true_clusters, labels_gibbs_warm),
        "ari_cython": adjusted_rand_score(labels_cython, labels_gibbs_warm),
        "refined": False,
        "note": "",
        "warm_time_s": alt_time + gibbs_warm_time,
        "mcmc_time_s": 0.0,
        "total_time_s": alt_time + gibbs_warm_time,
        "n_clusters": int(labels_gibbs_warm.max() + 1),
    })

summary_df = pd.DataFrame(rows)
print(summary_df)

plt.figure(figsize=(10,4))
summary_df.boxplot(column="log_likelihood", by="partition", rot=45)
plt.suptitle("")
plt.title("Partition log-likelihood (natural log)")
plt.show()


from sklearn.metrics import adjusted_rand_score
from cellstates import Cluster, stochastic_partition_jax, run_gibbs_partition_jax
import pandas as pd
import matplotlib.pyplot as plt
import time

# Helper to compute likelihood for a given labeling
def partition_likelihood(labels):
    cl = Cluster(counts_clust, c=labels, max_clusters=labels.max() + 1, seed=SEED)
    return float(cl.total_likelihood)

# Baseline partitions and timings
baseline_parts = {
    "cython_mcmc": (labels_cython, cython_time),
    "stochastic": (labels_stochastic, alt_time),
}

rows = []
for name, (labels, t_val) in baseline_parts.items():
    t_val = t_val or 0.0
    rows.append({
        "partition": name,
        "device": device_preference,
        "likelihood": partition_likelihood(labels),
        "ari_true": adjusted_rand_score(true_clusters, labels),
        "ari_cython": adjusted_rand_score(labels_cython, labels),
        "refined": False,
        "note": "",
        "warm_time_s": t_val,
        "mcmc_time_s": 0.0,
        "total_time_s": t_val,
        "n_clusters": int(labels.max() + 1),
    })

# MCMC refinement from stochastic warm start
cl = Cluster(counts_clust, c=labels_stochastic, max_clusters=labels_stochastic.max() + 1, seed=SEED)
note = ""
start = time.perf_counter()
try:
    cl.biased_monte_carlo_sampling(N_steps=MCMC_STEPS, tries_per_step=MCMC_TRIES, min_index=0)
except RuntimeError as err:
    note = f"MCMC early stop: {err}"
mcmc_elapsed = time.perf_counter() - start
refined = np.asarray(cl.clusters, dtype=np.int32)
rows.append({
    "partition": "mcmc_warm_stochastic",
    "device": device_preference,
    "likelihood": partition_likelihood(refined),
    "ari_true": adjusted_rand_score(true_clusters, refined),
    "ari_cython": adjusted_rand_score(labels_cython, refined),
    "refined": True,
    "note": note,
    "warm_time_s": alt_time,
    "mcmc_time_s": mcmc_elapsed,
    "total_time_s": alt_time + mcmc_elapsed,
    "n_clusters": int(refined.max() + 1),
})

if USE_JAX_GIBBS:
    # JAX Gibbs from scratch
    start = time.perf_counter()
    labels_gibbs, gibbs_moves, gibbs_delta = run_gibbs_partition_jax(
        counts_clust,
        init_clusters,
        lam=lam,
        sweeps=GIBBS_SWEEPS,
        device=device_preference,
        enable_x64=False,
        dtype=jax.numpy.float32,
        seed=SEED,
    )
    gibbs_time = time.perf_counter() - start
    rows.append({
        "partition": "jax_gibbs",
        "device": device_preference,
        "likelihood": partition_likelihood(labels_gibbs),
        "ari_true": adjusted_rand_score(true_clusters, labels_gibbs),
        "ari_cython": adjusted_rand_score(labels_cython, labels_gibbs),
        "refined": False,
        "note": "",
        "warm_time_s": gibbs_time,
        "mcmc_time_s": 0.0,
        "total_time_s": gibbs_time,
        "n_clusters": int(labels_gibbs.max() + 1),
    })

    # JAX Gibbs warm start from stochastic
    start = time.perf_counter()
    labels_gibbs_warm, gibbs_moves_warm, gibbs_delta_warm = run_gibbs_partition_jax(
        counts_clust,
        labels_stochastic,
        lam=lam,
        sweeps=GIBBS_SWEEPS,
        device=device_preference,
        enable_x64=False,
        dtype=jax.numpy.float32,
        seed=SEED + 1,
    )
    gibbs_warm_time = time.perf_counter() - start
    rows.append({
        "partition": "jax_gibbs_warm_stochastic",
        "device": device_preference,
        "likelihood": partition_likelihood(labels_gibbs_warm),
        "ari_true": adjusted_rand_score(true_clusters, labels_gibbs_warm),
        "ari_cython": adjusted_rand_score(labels_cython, labels_gibbs_warm),
        "refined": False,
        "note": "",
        "warm_time_s": alt_time + gibbs_warm_time,
        "mcmc_time_s": 0.0,
        "total_time_s": alt_time + gibbs_warm_time,
        "n_clusters": int(labels_gibbs_warm.max() + 1),
    })

summary_df = pd.DataFrame(rows)
print(summary_df)

plt.figure(figsize=(10,4))
summary_df.boxplot(column="likelihood", by="partition", rot=45)
plt.suptitle("")
plt.title("Partition likelihoods")
plt.show()


In [None]:
# CPU hierarchy (reuse clustering synthetic data)
counts_hier = counts_clust
init_clusters_hier = init_clusters
clst_cpu = Cluster(counts_hier, c=init_clusters_hier, max_clusters=init_clusters_hier.max() + 1, seed=SEED)

start = time.perf_counter()
hierarchy_cpu, delta_cpu = clst_cpu.get_cluster_hierarchy()
cpu_hier_time = time.perf_counter() - start

# Aggregate counts per cluster for JAX
counts_per_cluster = np.zeros((counts_hier.shape[0], init_clusters_hier.max() + 1), dtype=np.int64)
for idx2 in range(counts_per_cluster.shape[1]):
    mask = init_clusters_hier == idx2
    if mask.any():
        counts_per_cluster[:, idx2] = counts_hier[:, mask].sum(axis=1)

lam_hier = np.asarray(clst_cpu.dirichlet_pseudocounts, dtype=np.float32)

pair_chunk_size = 50_000

start = time.perf_counter()
merges_jax, delta_jax = get_cluster_hierarchy_jax_from_counts(
    counts_per_cluster,
    lam_hier,
    device=device_preference,
    enable_x64=False,
    dtype=jax.numpy.float32,
    pair_chunk_size=pair_chunk_size,
)
jax_hier_time = time.perf_counter() - start

import pandas as pd
hier_rows = [
    {
        "method": "cpu_hierarchy",
        "merges": len(hierarchy_cpu),
        "time_s": cpu_hier_time,
    },
    {
        "method": f"jax_hierarchy_{device_preference}",
        "merges": len(merges_jax),
        "time_s": jax_hier_time,
    },
]
hier_df = pd.DataFrame(hier_rows)
print(hier_df)

print(f"CPU hierarchy: merges={len(hierarchy_cpu)}, time={cpu_hier_time:.3f} s")
print(f"JAX hierarchy ({device_preference}): merges={len(merges_jax)}, time={jax_hier_time:.3f} s")
print("First 5 CPU merges:", hierarchy_cpu[:5])
print("First 5 JAX merges:", merges_jax[:5])


## Hierarchy benchmark (CPU vs JAX)