In [1]:
# matrix with all the gene expression values for a particular condition 

genes, gene lengths, 
observed- gene expressionn values
for neg bino reg model - need parameter for mean expression level and dispersion parameter 
mean expression level can be stratified by catogery

In [1]:
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
# These MUST come before importing pymc or pytensor
os.environ["PYTENSOR_FLAGS"] = "optimizer_excluding=constant_folding,mode=FAST_RUN,linker=py"
os.environ["PYTENSOR_WARN__C_COMPILE"] = "False"

import pytensor
pytensor.config.cxx = ""

# Now import pymc
import pymc as pm
import arviz as az
print("Linker:", pytensor.config.linker)

import numpy as np
import pandas as pd
import scipy.sparse as sp



Linker: py


In [2]:
import scanpy as sc
import matplotlib.pyplot  as plt

In [3]:
import pymc as pm
print("PyMC version:", pm.__version__)


PyMC version: 5.12.0


In [4]:
adata = sc.read("merged_dataset.h5ad")

In [5]:
# Print a summary of the dataset
print(adata)

AnnData object with n_obs × n_vars = 186224 × 39241
    obs: 'AvgSpotLen', 'Bases', 'Bytes', 'original_index'


In [6]:
# Fill NaNs in obs['Bases']
adata.obs["Bases"] = adata.obs["Bases"].fillna(0)

In [7]:
n_cells = adata.shape[0]
n_genes = adata.shape[1]

In [8]:
gene_lengths = np.random.uniform(1000, 3000, size=n_genes)
seq_depth = np.median(adata.obs['Bases'].values) if 'Bases' in adata.obs.columns else 1e4
ref_seq_depth = 1e5
ref_gene_length = 2000.0
batch_size = 4096

In [9]:
# Prepare output shape: (n_genes, n_cells)
observed_counts = np.zeros((n_genes, n_cells), dtype=np.int32)

In [10]:
for start in range(0, n_cells, batch_size):
    end = min(start + batch_size, n_cells)
    # Get cell-wise batch
    X_batch = adata.X[start:end]

    # Convert batch to dense safely
    X_batch = X_batch.toarray() if sp.issparse(X_batch) else np.asarray(X_batch)

    # Store in observed_counts
    observed_counts[:, start:end] = X_batch.T.astype(np.int32)

In [12]:
n_genes, n_conditions = observed_counts.shape

In [13]:
import scipy.sparse as sp

def generate_minibatch_indices_from_adata(adata, batch_size=8192):
    n_genes, n_cells = adata.shape
    total_size = n_genes * n_cells

    for start in range(0, total_size, batch_size):
        end = min(start + batch_size, total_size)

        gene_idx = np.arange(start, end) // n_cells
        cond_idx = np.arange(start, end) % n_cells

        # Load only necessary columns per minibatch
        cols_needed = np.unique(cond_idx)
        adata_chunk = adata[:, cols_needed]

        # Convert sparse to dense
        X_chunk = adata_chunk.X.toarray() if sp.issparse(adata_chunk.X) else adata_chunk.X

        # Pull count values based on local cond_idx
        cond_local = np.searchsorted(cols_needed, cond_idx)
        counts = X_chunk[gene_idx, cond_local]

        yield {
            "counts": counts.astype(np.int32),
            "gene_idx": gene_idx.astype(np.int32),
            "cond_idx": cond_idx.astype(np.int32),
        }


In [14]:
init_batch = next(generate_minibatch_indices_from_adata(adata, batch_size=4096))

In [15]:
counts = init_batch['counts']
gene_idx = init_batch['gene_idx']
cond_idx = init_batch['cond_idx']

In [16]:
flat_gene_idx = gene_idx.flatten()
flat_cond_idx = cond_idx.flatten()

In [None]:
# --------------------------
# PyMC Model
# --------------------------
with pm.Model() as model_vi:
    counts_shared = pm.MutableData("counts_shared", init_batch["counts"])
    gene_idx_shared = pm.MutableData("gene_idx_shared", init_batch["gene_idx"])
    cond_idx_shared = pm.MutableData("cond_idx_shared", init_batch["cond_idx"])
    gene_lengths_shared = pm.Data("gene_lengths_shared", gene_lengths)

    mu_gene = pm.Normal("mu_gene", mu=2.0, sigma=1.0)
    sigma_gene = pm.HalfNormal("sigma_gene", sigma=1.0)
    z_gene = pm.Normal("z_gene", mu=0, sigma=1, shape=n_genes)
    gene_log_expression = pm.Deterministic("gene_log_expression", mu_gene + sigma_gene * z_gene)

    mu_cond = pm.Normal("mu_cond", mu=0.0, sigma=1.0)
    sigma_cond = pm.HalfNormal("sigma_cond", sigma=1.0)
    z_cond = pm.Normal("z_cond", mu=0, sigma=1, shape=n_cells)
    condition_intercepts_log = pm.Deterministic("condition_intercepts_log", mu_cond + sigma_cond * z_cond)

    alpha = pm.HalfCauchy("alpha", beta=2)

    log_mu = gene_log_expression[gene_idx_shared] + condition_intercepts_log[cond_idx_shared]
    log_mu += np.log(seq_depth / ref_seq_depth)
    log_mu += pm.math.log(gene_lengths_shared[gene_idx_shared] / ref_gene_length)
    mu = pm.Deterministic("mu", pm.math.exp(log_mu))

    pm.NegativeBinomial("obs", mu=mu, alpha=alpha, observed=counts_shared)

    approx = pm.fit(n=1, method="advi", progressbar=False)

# --------------------------
# Manual VI Loop (stable)
# --------------------------
with model_vi:
    elbos = []
    gen = generate_minibatch_indices_from_adata(adata, batch_size=batch_size)

    for i in range(10000):  # tune based on time
        try:
            batch = next(gen)
        except StopIteration:
            gen = generate_minibatch_indices_from_adata(adata, batch_size=batch_size)
            batch = next(gen)

        pm.set_data({
            "counts_shared": batch["counts"],
            "gene_idx_shared": batch["gene_idx"],
            "cond_idx_shared": batch["cond_idx"],
        })

        approx = pm.fit(n=1, method="advi", progressbar=True)
        elbos.append(approx.hist[-1])

        if i % 100 == 0:
            print(f"Step {i}, ELBO: {elbos[-1]:.2f}")

    trace_vi = approx.sample(1000)

  warn(
Finished [100%]: Loss = 1.4478e+05
  warn(


Finished [100%]: Loss = 1.1512e+05


Step 0, ELBO: 115123.54


  warn(


Finished [100%]: Loss = 2.2347e+05
  warn(


Finished [100%]: Loss = 1.7416e+05
  warn(


Finished [100%]: Loss = 87,609
  warn(


Finished [100%]: Loss = 1.126e+05
  warn(


Finished [100%]: Loss = 51,995
  warn(


Finished [100%]: Loss = 1.0766e+05
  warn(


Finished [100%]: Loss = 1.3917e+05
  warn(


Finished [100%]: Loss = 1.1315e+05
  warn(


Finished [100%]: Loss = 76,326
  warn(


Finished [100%]: Loss = 55,731
  warn(


Finished [100%]: Loss = 2.1276e+05
  warn(


Finished [100%]: Loss = 2.5891e+05
  warn(


Finished [100%]: Loss = 99,729
  warn(


Finished [100%]: Loss = 69,658
  warn(


Finished [100%]: Loss = 55,637
  warn(


Finished [100%]: Loss = 1.5322e+05
  warn(


Finished [100%]: Loss = 45,614
  warn(


Finished [100%]: Loss = 53,266
  warn(


Finished [100%]: Loss = 1.5619e+05
  warn(


Finished [100%]: Loss = 1.5778e+05
  warn(


Finished [100%]: Loss = 1.0518e+05
  warn(


Finished [100%]: Loss = 1.4477e+05
  warn(


Finished [100%]: Loss = 1.7635e+05
  warn(


Finished [100%]: Loss = 68,878
  warn(


Finished [100%]: Loss = 1.1143e+05
  warn(


Finished [100%]: Loss = 1.7175e+05
  warn(


Finished [100%]: Loss = 1.0955e+05
  warn(


Finished [100%]: Loss = 1.8848e+05
  warn(


Finished [100%]: Loss = 77,880
  warn(


Finished [100%]: Loss = 53,029
  warn(


Finished [100%]: Loss = 3.3779e+05
  warn(


Finished [100%]: Loss = 76,810
  warn(


Finished [100%]: Loss = 1.6573e+05
  warn(


Finished [100%]: Loss = 81,003
  warn(


Finished [100%]: Loss = 58,089
  warn(


Finished [100%]: Loss = 98,957
  warn(


Finished [100%]: Loss = 1.3517e+05
  warn(


Finished [100%]: Loss = 71,949
  warn(


Finished [100%]: Loss = 1.946e+05
  warn(


Finished [100%]: Loss = 1.0094e+05
  warn(


Finished [100%]: Loss = 2.2371e+05
  warn(


Finished [100%]: Loss = 2.2736e+05
  warn(


Finished [100%]: Loss = 84,480
  warn(


Finished [100%]: Loss = 1.4225e+05
  warn(


Finished [100%]: Loss = 1.9373e+05
  warn(


Finished [100%]: Loss = 3.5366e+05
  warn(


Finished [100%]: Loss = 1.2123e+05
  warn(


Finished [100%]: Loss = 56,501
  warn(


Finished [100%]: Loss = 1.0311e+05
  warn(


Finished [100%]: Loss = 1.0089e+05
  warn(


Finished [100%]: Loss = 1.5192e+05
  warn(


Finished [100%]: Loss = 73,549
  warn(


Finished [100%]: Loss = 1.5789e+05
  warn(


Finished [100%]: Loss = 32,462
  warn(


Finished [100%]: Loss = 51,103
  warn(


Finished [100%]: Loss = 77,301
  warn(


Finished [100%]: Loss = 1.4911e+05
  warn(


Finished [100%]: Loss = 1.7742e+05
  warn(


Finished [100%]: Loss = 80,224
  warn(


Finished [100%]: Loss = 1.0003e+05
  warn(


Finished [100%]: Loss = 57,293
  warn(


Finished [100%]: Loss = 1.0374e+05
  warn(


Finished [100%]: Loss = 1.2158e+05
  warn(


Finished [100%]: Loss = 82,736
  warn(


Finished [100%]: Loss = 1.7871e+05
  warn(


Finished [100%]: Loss = 1.9849e+05
  warn(


Finished [100%]: Loss = 3.4897e+05
  warn(


Finished [100%]: Loss = 67,817
  warn(


Finished [100%]: Loss = 68,560
  warn(


Finished [100%]: Loss = 97,142
  warn(


Finished [100%]: Loss = 65,665
  warn(


Finished [100%]: Loss = 93,462
  warn(


Finished [100%]: Loss = 1.1771e+05
  warn(


Finished [100%]: Loss = 1.1348e+05
  warn(


Finished [100%]: Loss = 1.071e+05
  warn(


Finished [100%]: Loss = 97,011
  warn(


Finished [100%]: Loss = 1.1574e+05
  warn(


Finished [100%]: Loss = 3.1502e+05
  warn(


Finished [100%]: Loss = 46,999
  warn(


Finished [100%]: Loss = 62,436
  warn(


Finished [100%]: Loss = 1.3659e+05
  warn(


Finished [100%]: Loss = 75,056
  warn(


Finished [100%]: Loss = 48,001
  warn(


Finished [100%]: Loss = 1.3218e+05
  warn(


Finished [100%]: Loss = 1.1006e+05
  warn(


Finished [100%]: Loss = 1.4649e+05
  warn(


Finished [100%]: Loss = 89,590
  warn(


Finished [100%]: Loss = 88,473
  warn(


Finished [100%]: Loss = 1.3059e+05
  warn(


Finished [100%]: Loss = 2.7734e+05
  warn(


Finished [100%]: Loss = 2.4574e+05
  warn(


Finished [100%]: Loss = 55,378
  warn(


Finished [100%]: Loss = 1.1537e+05
  warn(


Finished [100%]: Loss = 2.3943e+05
  warn(


Finished [100%]: Loss = 1.2798e+05
  warn(


Finished [100%]: Loss = 1.0905e+05
  warn(


Finished [100%]: Loss = 56,457
  warn(


Finished [100%]: Loss = 1.4979e+05
  warn(


Finished [100%]: Loss = 1.6568e+05
  warn(


Finished [100%]: Loss = 1.3706e+05


Step 100, ELBO: 137063.46


  warn(


Finished [100%]: Loss = 1.553e+05
  warn(


Finished [100%]: Loss = 93,655
  warn(


Finished [100%]: Loss = 1.8188e+05
  warn(


Finished [100%]: Loss = 3.3304e+05
  warn(


Finished [100%]: Loss = 97,439
  warn(


Finished [100%]: Loss = 2.6457e+05
  warn(


Finished [100%]: Loss = 1.1832e+05
  warn(


Finished [100%]: Loss = 2.1798e+05
  warn(


Finished [100%]: Loss = 96,714
  warn(


Finished [100%]: Loss = 93,062
  warn(


Finished [100%]: Loss = 3.0641e+05
  warn(


Finished [100%]: Loss = 1.3035e+05
  warn(


Finished [100%]: Loss = 1.2076e+05
  warn(


Finished [100%]: Loss = 1.7403e+05
  warn(


Finished [100%]: Loss = 82,835
  warn(


Finished [100%]: Loss = 1.4145e+05
  warn(


Finished [100%]: Loss = 70,268
  warn(


Finished [100%]: Loss = 64,508
  warn(


Finished [100%]: Loss = 1.917e+05
  warn(


Finished [100%]: Loss = 53,783
  warn(


Finished [100%]: Loss = 1.3866e+05
  warn(


Finished [100%]: Loss = 1.3028e+05
  warn(


Finished [100%]: Loss = 72,133
  warn(


Finished [100%]: Loss = 65,742
  warn(


Finished [100%]: Loss = 1.2292e+05
  warn(


Finished [100%]: Loss = 1.0766e+05
  warn(


Finished [100%]: Loss = 1.0725e+05
  warn(


Finished [100%]: Loss = 48,333
  warn(


Finished [100%]: Loss = 1.5895e+05
  warn(


Finished [100%]: Loss = 98,152
  warn(


Finished [100%]: Loss = 56,226
  warn(


Finished [100%]: Loss = 73,874
  warn(


Finished [100%]: Loss = 1.5201e+05
  warn(


Finished [100%]: Loss = 2.0089e+05
  warn(


Finished [100%]: Loss = 47,479
  warn(


Finished [100%]: Loss = 85,792
  warn(


Finished [100%]: Loss = 59,359
  warn(


Finished [100%]: Loss = 1.6004e+05
  warn(


Finished [100%]: Loss = 1.1022e+05
  warn(


Finished [100%]: Loss = 1.2583e+05
  warn(


Finished [100%]: Loss = 68,229
  warn(


Finished [100%]: Loss = 74,966
  warn(


Finished [100%]: Loss = 1.728e+05
  warn(


Finished [100%]: Loss = 72,537
  warn(


Finished [100%]: Loss = 2.5772e+05
  warn(


Finished [100%]: Loss = 46,078
  warn(


Finished [100%]: Loss = 1.3708e+05
  warn(


Finished [100%]: Loss = 1.189e+05
  warn(


Finished [100%]: Loss = 79,881
  warn(


Finished [100%]: Loss = 1.191e+05
  warn(


Finished [100%]: Loss = 1.401e+05
  warn(


Finished [100%]: Loss = 1.0895e+05
  warn(


Finished [100%]: Loss = 62,263
  warn(


Finished [100%]: Loss = 2.4747e+05
  warn(


Finished [100%]: Loss = 97,993
  warn(


Finished [100%]: Loss = 96,595
  warn(


Finished [100%]: Loss = 1.4234e+05
  warn(


Finished [100%]: Loss = 1.481e+05
  warn(


Finished [100%]: Loss = 3.328e+05
  warn(


Finished [100%]: Loss = 90,755
  warn(


Finished [100%]: Loss = 73,601
  warn(


Finished [100%]: Loss = 3.0332e+05
  warn(


Finished [100%]: Loss = 1.9982e+05
  warn(


Finished [100%]: Loss = 2.342e+05
  warn(


Finished [100%]: Loss = 1.1115e+05
  warn(


Finished [100%]: Loss = 2.0859e+05
  warn(


Finished [100%]: Loss = 1.6678e+05
  warn(


Finished [100%]: Loss = 1.5663e+05
  warn(


Finished [100%]: Loss = 1.1797e+05
  warn(


Finished [100%]: Loss = 92,284
  warn(


Finished [100%]: Loss = 1.528e+05
  warn(


Finished [100%]: Loss = 1.5715e+05
  warn(


Finished [100%]: Loss = 1.6485e+05
  warn(


Finished [100%]: Loss = 99,954
  warn(


Finished [100%]: Loss = 1.3479e+05
  warn(


Finished [100%]: Loss = 1.5475e+05
  warn(


Finished [100%]: Loss = 2.3079e+05
  warn(


Finished [100%]: Loss = 92,403
  warn(


Finished [100%]: Loss = 1.1945e+05
  warn(


Finished [100%]: Loss = 95,852
  warn(


Finished [100%]: Loss = 1.1626e+05
  warn(


Finished [100%]: Loss = 65,306
  warn(


Finished [100%]: Loss = 1.0321e+05
  warn(


Finished [100%]: Loss = 1.1964e+05
  warn(


Finished [100%]: Loss = 84,246
  warn(


Finished [100%]: Loss = 92,075
  warn(


Finished [100%]: Loss = 1.1733e+05
  warn(


Finished [100%]: Loss = 1.4131e+05
  warn(


Finished [100%]: Loss = 1.0391e+05
  warn(


Finished [100%]: Loss = 1.4751e+05


In [None]:
az.to_netcdf(trace_vi, "trace_vi.nc")

In [None]:
az.summary(trace_vi)

In [None]:
# --------------------------
# ELBO Plot for VI
# --------------------------
plt.plot(approx.hist)
plt.title("ELBO Loss Curve (VI)")
plt.xlabel("Iteration")
plt.ylabel("ELBO")
plt.tight_layout()
plt.show()

In [None]:
# --------------------------
# Posterior Analysis
# --------------------------
az.plot_trace(trace_vi, var_names=["mu_gene", "sigma_gene", "mu_cond", "sigma_cond", "alpha"])
plt.tight_layout()
plt.show()