In [1]:
# Geometry of Cell States Across Representations (PCA, Autoencoder, scVI)

## Scientific Question
How does the geometric organization of cell states differ between disease and treatment conditions, and which geometric conclusions are robust to representation choice?

## Overview
In this notebook, we analyze single-cell RNA-seq data from progressive multiple sclerosis neural stem cells using three representation learning approaches:
- Linear (PCA)
- Nonlinear deterministic (Autoencoder)
- Probabilistic, noise-aware (scVI)

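For each representation, fitted separately to each condition, we quantify within-condition cell-state geometry using pairwise Euclidean distances, dispersion (mean pairwise distance), compactness (mean distance to the k nearest neighbors), and the spread of the distance distribution, and then compare disease (DMSO) versus treatment (ABT) conditions.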

## Goal
To assess whether disease-associated expansion and heterogeneity of cell-state space persist across fundamentally different representation assumptions.


# Mount Google Drive in the Colab runtime so the files under /content/drive are readable.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [2]:
# List the project data directory on the mounted Drive to confirm the input files are present.
!ls /content/drive/MyDrive/ms_singlecell/data

adata_abt_n   HD_ABT_filtered.h5ad   PMS_ABT_filtered.h5ad
adata_dmso_n  HD_DMSO_filtered.h5ad  PMS_DMSO_filtered.h5ad


In [3]:
!pip install scanpy

Collecting scanpy
  Downloading scanpy-1.11.5-py3-none-any.whl.metadata (9.3 kB)
Collecting anndata>=0.8 (from scanpy)
  Downloading anndata-0.12.7-py3-none-any.whl.metadata (9.9 kB)
Collecting legacy-api-wrap>=1.4.1 (from scanpy)
  Downloading legacy_api_wrap-1.5-py3-none-any.whl.metadata (2.2 kB)
Collecting session-info2 (from scanpy)
  Downloading session_info2-0.3-py3-none-any.whl.metadata (3.5 kB)
Collecting array-api-compat>=1.7.1 (from anndata>=0.8->scanpy)
  Downloading array_api_compat-1.13.0-py3-none-any.whl.metadata (2.5 kB)
Collecting zarr!=3.0.*,>=2.18.7 (from anndata>=0.8->scanpy)
  Downloading zarr-3.1.5-py3-none-any.whl.metadata (10 kB)
Collecting donfig>=0.8 (from zarr!=3.0.*,>=2.18.7->anndata>=0.8->scanpy)
  Downloading donfig-0.8.1.post1-py3-none-any.whl.metadata (5.0 kB)
Collecting numcodecs>=0.14 (from zarr!=3.0.*,>=2.18.7->anndata>=0.8->scanpy)
  Downloading numcodecs-0.16.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata 

In [4]:
import scanpy as sc

# Drive folder holding the input files; note that it already ends with "/".
base = "/content/drive/MyDrive/ms_singlecell/data/"

# Join without an extra separator: the previous f"{base}/..." produced ".../data//adata_abt_n".
adata_abt_n  = sc.read_h5ad(f"{base}adata_abt_n")


In [5]:
import scanpy as sc

# Drive folder holding the input files; note that it already ends with "/".
base = "/content/drive/MyDrive/ms_singlecell/data/"

# Join without an extra separator: the previous f"{base}/..." produced ".../data//adata_dmso_n".
adata_dmso_n  = sc.read_h5ad(f"{base}adata_dmso_n")

In [6]:
# Write both loaded objects to the current working directory before the in-place
# scaling applied further down, so the pre-scaling state is kept on disk.
# NOTE(review): these are relative paths in the runtime's working directory —
# confirm that location persists if the snapshots are meant to be reused.
adata_dmso_n.write("adata_dmso_frozen.h5ad")
adata_abt_n.write("adata_abt_frozen.h5ad")

In [7]:
# (cells, genes) for each condition.
print(adata_dmso_n.shape)
print(adata_abt_n.shape)

# Index.equals reports False when the gene lists differ in length or order,
# whereas the element-wise `==` comparison raises on a length mismatch.
print(adata_dmso_n.var_names.equals(adata_abt_n.var_names))

(3281, 23494)
(8987, 23494)
True


In [8]:
# In-memory copy of the DMSO object taken before scaling.
# NOTE(review): adata_tmp is not referenced again in the visible part of this
# notebook — confirm it is still needed, since it duplicates the full matrix in memory.
adata_tmp = adata_dmso_n.copy()

In [9]:
# (cells, genes) for each condition.
print("DMSO shape:", adata_dmso_n.shape)
print("ABT  shape:", adata_abt_n.shape)
# Index.equals reports False when the gene lists differ in length or order,
# whereas the element-wise `==` comparison raises on a length mismatch.
print("Genes aligned:",
      adata_dmso_n.var_names.equals(adata_abt_n.var_names))

DMSO shape: (3281, 23494)
ABT  shape: (8987, 23494)
Genes aligned: True


In [10]:
import scanpy as sc

# Scale each condition's expression matrix separately, capping values at 10.
# The calls are not reassigned and no copy is requested, so both objects are
# modified in place: every later cell that reads adata_dmso_n.X / adata_abt_n.X
# (PCA, the autoencoder tensors) sees the scaled values.
# NOTE(review): the "frozen" snapshots written earlier are the only unscaled
# versions of these objects kept by the notebook.
sc.pp.scale(adata_dmso_n, max_value =10)
sc.pp.scale(adata_abt_n, max_value =10)

  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)


In [11]:
# Number of principal components kept; the scVI models and the autoencoder
# further down also use 20 latent dimensions.
n_pcs = 20

# A separate PCA is fitted for each condition, so the two sets of component
# axes are not shared between DMSO and ABT.
sc.tl.pca(adata_dmso_n, n_comps=n_pcs, svd_solver="arpack")
sc.tl.pca(adata_abt_n,  n_comps=n_pcs, svd_solver="arpack")

In [12]:
# Variance ratios of the leading ten components from each per-condition PCA.
print(
    "DMSO variance explained (first 10 PCs):",
    adata_dmso_n.uns["pca"]["variance_ratio"][:10],
    sep="\n",
)
print(
    "\nABT variance explained (first 10 PCs):",
    adata_abt_n.uns["pca"]["variance_ratio"][:10],
    sep="\n",
)

DMSO variance explained (first 10 PCs):
[0.06545809 0.01281095 0.00512672 0.00359378 0.00265931 0.00239149
 0.0022202  0.00209686 0.00161917 0.0015904 ]

ABT variance explained (first 10 PCs):
[0.02668111 0.0133332  0.00780506 0.00372046 0.0027715  0.00224595
 0.00179549 0.00158333 0.00127058 0.00120427]


In [13]:
import numpy as np

# Pull the PCA coordinates out of each object, keeping the leading n_pcs columns.
X_dmso_pca = adata_dmso_n.obsm["X_pca"][:, 0:n_pcs]
X_abt_pca = adata_abt_n.obsm["X_pca"][:, 0:n_pcs]

print(f"PCA shapes: {X_dmso_pca.shape} {X_abt_pca.shape}")

PCA shapes: (3281, 20) (8987, 20)


In [14]:
# Sanity checks: expected embedding width and no NaN/inf coordinates.
assert n_pcs == X_dmso_pca.shape[1]
assert n_pcs == X_abt_pca.shape[1]
assert np.isfinite(X_dmso_pca).all()
assert np.isfinite(X_abt_pca).all()

In [15]:
from scipy.spatial.distance import pdist

# Condensed vectors of all within-condition Euclidean distances, n*(n-1)/2 entries each.
dist_dmso = pdist(X_dmso_pca, "euclidean")
dist_abt = pdist(X_abt_pca, "euclidean")

print(f"DMSO distances: {dist_dmso.shape}")
print(f"ABT  distances: {dist_abt.shape}")

DMSO distances: (5380840,)
ABT  distances: (40378591,)


In [16]:
# Dispersion = mean of all pairwise distances within a condition.
dispersion_dmso = np.mean(dist_dmso)
dispersion_abt = np.mean(dist_abt)

print("DMSO dispersion:", dispersion_dmso)
print("ABT  dispersion:", dispersion_abt)

DMSO dispersion: 55.32548561162718
ABT  dispersion: 47.68096583151013


In [17]:
from sklearn.neighbors import NearestNeighbors

def compute_compactness(X, k=10):
    """Return the mean distance from each row of ``X`` to its ``k`` nearest neighbours.

    ``k + 1`` neighbours are queried because the query set is the fitted set
    itself, so the first column returned for every row is a zero distance.
    """
    knn = NearestNeighbors(n_neighbors=k + 1)
    knn.fit(X)
    neighbor_dists = knn.kneighbors(X, return_distance=True)[0]
    # Drop column 0 (the zero self-distance) and average what is left.
    without_self = neighbor_dists[:, 1:]
    return without_self.mean()

# Mean 10-nearest-neighbour distance in each condition's PCA space.
compact_dmso, compact_abt = (
    compute_compactness(X_dmso_pca, k=10),
    compute_compactness(X_abt_pca, k=10),
)

print("DMSO compactness:", compact_dmso)
print("ABT  compactness:", compact_abt)

DMSO compactness: 13.371022090770893
ABT  compactness: 13.436358024668904


In [18]:
# Location and spread of the pairwise-distance distribution per condition.
print("DMSO mean/std:", np.mean(dist_dmso), np.std(dist_dmso))
print("ABT  mean/std:", np.mean(dist_abt), np.std(dist_abt))

DMSO mean/std: 55.32548561162718 29.858979921883307
ABT  mean/std: 47.68096583151013 15.487923299258153


In [19]:
!pip install scvi-tools

Collecting scvi-tools
  Downloading scvi_tools-1.4.1-py3-none-any.whl.metadata (22 kB)
Collecting docrep>=0.3.2 (from scvi-tools)
  Downloading docrep-0.3.2.tar.gz (33 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting lightning>=2.0 (from scvi-tools)
  Downloading lightning-2.6.0-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.9/44.9 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ml-collections (from scvi-tools)
  Downloading ml_collections-1.1.0-py3-none-any.whl.metadata (22 kB)
Collecting mudata (from scvi-tools)
  Downloading mudata-0.3.2-py3-none-any.whl.metadata (8.4 kB)
Collecting pyro-ppl (from scvi-tools)
  Downloading pyro_ppl-1.9.1-py3-none-any.whl.metadata (7.8 kB)
Collecting sparse>=0.14.0 (from scvi-tools)
  Downloading sparse-0.17.0-py2.py3-none-any.whl.metadata (5.3 kB)
Collecting torchmetrics (from scvi-tools)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecti

In [20]:
import scvi
import torch

In [21]:
# Display the installed scvi-tools version for provenance.
scvi.__version__

'1.4.1'

In [22]:
# Load the filtered PMS objects that the scVI models are built on.
# NOTE(review): presumably these hold raw counts, as scVI expects — confirm.
# The printed shapes show 60623 genes here versus 23494 in the objects used for
# PCA and the autoencoder, so scVI is fitted on a different feature set.
adata_dmso_raw = sc.read_h5ad("/content/drive/MyDrive/ms_singlecell/data/PMS_DMSO_filtered.h5ad")
adata_abt_raw  = sc.read_h5ad("/content/drive/MyDrive/ms_singlecell/data/PMS_ABT_filtered.h5ad")

print(adata_dmso_raw.shape)
print(adata_abt_raw.shape)

(3281, 60623)
(8987, 60623)


In [23]:
import numpy as np

# Genes present in both raw objects; np.intersect1d returns them sorted and de-duplicated.
common_genes = np.intersect1d(adata_dmso_raw.var_names, adata_abt_raw.var_names)

print(f"Number of common genes: {len(common_genes)}")

Number of common genes: 60623


In [24]:
# Restrict both raw objects to the shared genes, in the order given by common_genes,
# and materialise the result as an independent copy.
adata_dmso_raw = adata_dmso_raw[:, common_genes].copy()
adata_abt_raw = adata_abt_raw[:, common_genes].copy()

print(f"DMSO shape: {adata_dmso_raw.shape}")
print(f"ABT  shape: {adata_abt_raw.shape}")

DMSO shape: (3281, 60623)
ABT  shape: (8987, 60623)


In [25]:
# Snapshot the gene-aligned raw objects to the current working directory before
# they are registered with scVI.
adata_dmso_raw.write("adata_dmso_raw_frozen.h5ad")
adata_abt_raw.write("adata_abt_raw_frozen.h5ad")

In [26]:
# Register each object with scVI. No layer, batch or covariate arguments are
# passed, so the library defaults apply.
# NOTE(review): this presumably means the model reads counts from .X — confirm
# that .X of these objects holds unnormalized counts.
scvi.model.SCVI.setup_anndata(adata_dmso_raw)
scvi.model.SCVI.setup_anndata(adata_abt_raw)

In [27]:
# Fix scvi-tools' global seed before the models are constructed, so that weight
# initialisation and the subsequent training are repeatable across runs.
scvi.settings.seed = 0

# One model per condition, each with a 20-dimensional latent space (the same
# width as the PCA embedding used above).
model_dmso = scvi.model.SCVI(
    adata_dmso_raw,
    n_latent=20
)

model_abt = scvi.model.SCVI(
    adata_abt_raw,
    n_latent=20
)

In [28]:
# Train each per-condition model for at most 100 epochs; all other training
# options are left at the library defaults.
model_dmso.train(max_epochs=100)
model_abt.train(max_epochs=100)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Training:   0%|          | 0/100 [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=100` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Training:   0%|          | 0/100 [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=100` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


In [29]:
# Latent coordinates of every cell under its own condition's trained model.
X_dmso_scvi = model_dmso.get_latent_representation()
X_abt_scvi = model_abt.get_latent_representation()

print(f"DMSO scVI shape: {X_dmso_scvi.shape}")
print(f"ABT  scVI shape: {X_abt_scvi.shape}")

DMSO scVI shape: (3281, 20)
ABT  scVI shape: (8987, 20)


In [30]:
import torch
# Report whether PyTorch can see a CUDA device in this runtime.
print(torch.cuda.is_available())

True


In [31]:
from scipy.spatial.distance import pdist

# All within-condition Euclidean distances in the scVI latent space (condensed form).
dist_dmso_scvi = pdist(X_dmso_scvi, "euclidean")
dist_abt_scvi = pdist(X_abt_scvi, "euclidean")

print(f"DMSO scVI distances: {dist_dmso_scvi.shape}")
print(f"ABT  scVI distances: {dist_abt_scvi.shape}")

DMSO scVI distances: (5380840,)
ABT  scVI distances: (40378591,)


In [32]:
# Dispersion = mean pairwise distance in the scVI latent space.
disp_dmso_scvi = np.mean(dist_dmso_scvi)
disp_abt_scvi = np.mean(dist_abt_scvi)

print("scVI DMSO dispersion:", disp_dmso_scvi)
print("scVI ABT  dispersion:", disp_abt_scvi)

scVI DMSO dispersion: 7.186738149651342
scVI ABT  dispersion: 7.027669421191334


In [33]:
from sklearn.neighbors import NearestNeighbors

def compute_compactness(X, k=10):
    """Mean distance from each row of ``X`` to its ``k`` nearest neighbours.

    Behaves the same as the ``compute_compactness`` defined in the PCA section;
    this redefinition shadows that one.
    """
    fitted = NearestNeighbors(n_neighbors=k + 1).fit(X)
    dist_to_neighbors = fitted.kneighbors(X)[0]
    # The first column is the zero distance of each point to itself.
    return dist_to_neighbors[:, 1:].mean()

# Mean 10-nearest-neighbour distance in each condition's scVI latent space.
comp_dmso_scvi, comp_abt_scvi = (
    compute_compactness(X_dmso_scvi, k=10),
    compute_compactness(X_abt_scvi, k=10),
)

print("scVI DMSO compactness:", comp_dmso_scvi)
print("scVI ABT  compactness:", comp_abt_scvi)


scVI DMSO compactness: 1.5539119039919638
scVI ABT  compactness: 2.6241633893684244


In [34]:
# Location and spread of the scVI pairwise-distance distribution per condition.
print("scVI DMSO mean/std:", np.mean(dist_dmso_scvi), np.std(dist_dmso_scvi))
print("scVI ABT  mean/std:", np.mean(dist_abt_scvi), np.std(dist_abt_scvi))

scVI DMSO mean/std: 7.186738149651342 2.9847841260363044
scVI ABT  mean/std: 7.027669421191334 1.7113930283901302


In [44]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

In [45]:
# float32 tensors of the expression matrices, used as autoencoder input.
# adata_dmso_n.X / adata_abt_n.X were scaled in place by sc.pp.scale earlier in
# the notebook, so these are the same scaled values that PCA was fitted on.
# NOTE(review): assumes .X is a dense array at this point (not a sparse matrix) — confirm.
X_dmso = torch.tensor(adata_dmso_n.X, dtype=torch.float32)
X_abt  = torch.tensor(adata_abt_n.X,  dtype=torch.float32)

In [50]:
class AutoEncoder(nn.Module):
    """Symmetric fully connected autoencoder.

    The encoder maps ``input_dim -> 128 -> 64 -> latent_dim`` and the decoder
    mirrors it back to ``input_dim``. Hidden layers use ReLU; the latent layer
    and the reconstruction layer have no activation.
    """

    def __init__(self, input_dim, latent_dim=20):
        super().__init__()
        wide, narrow = 128, 64

        # Layers are created encoder-first so parameter initialisation order is unchanged.
        encoder_layers = [
            nn.Linear(input_dim, wide),
            nn.ReLU(),
            nn.Linear(wide, narrow),
            nn.ReLU(),
            nn.Linear(narrow, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(latent_dim, narrow),
            nn.ReLU(),
            nn.Linear(narrow, wide),
            nn.ReLU(),
            nn.Linear(wide, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for the input batch ``x``."""
        latent = self.encoder(x)
        return self.decoder(latent), latent

In [51]:
def train_autoencoder(X, epochs=50, batch_size=128, latent_dim=20, lr=1e-3, seed=None):
    """Fit an ``AutoEncoder`` to the rows of ``X`` by minimising MSE reconstruction loss.

    Parameters
    ----------
    X : torch.Tensor
        Training data; ``X.shape[1]`` is used as the network's input width.
    epochs : int
        Number of full passes over ``X``.
    batch_size : int
        Mini-batch size; batches are reshuffled every epoch.
    latent_dim : int
        Width of the bottleneck layer (default 20, the ``AutoEncoder`` default).
    lr : float
        Adam learning rate.
    seed : int or None
        When given, ``torch.manual_seed(seed)`` is called first so that weight
        initialisation and shuffle order are repeatable. ``None`` leaves the
        global RNG untouched, which is the previous behaviour.

    Returns
    -------
    AutoEncoder
        The trained model, switched to evaluation mode.
    """
    if seed is not None:
        torch.manual_seed(seed)

    loader = DataLoader(TensorDataset(X), batch_size=batch_size, shuffle=True)

    model = AutoEncoder(X.shape[1], latent_dim=latent_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    model.train()
    for _ in range(epochs):
        for (batch,) in loader:
            recon, _latent = model(batch)
            loss = loss_fn(recon, batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Hand back a model that is ready for inference.
    model.eval()
    return model

In [52]:
# Autoencoder training is stochastic (random weight initialisation and batch
# shuffling). Re-seed PyTorch before each run so the embeddings, and every
# metric derived from them, are repeatable and independent of training order.
torch.manual_seed(0)
ae_dmso = train_autoencoder(X_dmso, epochs=50)

torch.manual_seed(0)
ae_abt  = train_autoencoder(X_abt,  epochs=50)

In [53]:
# Encode every cell with its own condition's autoencoder. The forward pass
# returns (reconstruction, latent); only the latent part is kept, as a NumPy array.
with torch.no_grad():
    X_dmso_ae = ae_dmso(X_dmso)[1].numpy()
    X_abt_ae = ae_abt(X_abt)[1].numpy()

print(X_dmso_ae.shape, X_abt_ae.shape)

(3281, 20) (8987, 20)


In [54]:
from scipy.spatial.distance import pdist

# All within-condition Euclidean distances in the autoencoder latent space (condensed form).
dist_dmso_ae = pdist(X_dmso_ae, "euclidean")
dist_abt_ae = pdist(X_abt_ae, "euclidean")

print(f"DMSO AE distances: {dist_dmso_ae.shape}")
print(f"ABT  AE distances: {dist_abt_ae.shape}")

DMSO AE distances: (5380840,)
ABT  AE distances: (40378591,)


In [55]:
# Dispersion = mean pairwise distance in the autoencoder latent space.
disp_dmso_ae = np.mean(dist_dmso_ae)
disp_abt_ae = np.mean(dist_abt_ae)

print("AE DMSO dispersion:", disp_dmso_ae)
print("AE ABT  dispersion:", disp_abt_ae)

AE DMSO dispersion: 32.04225380299988
AE ABT  dispersion: 25.11343715559013


In [56]:
from sklearn.neighbors import NearestNeighbors

def compute_compactness(X, k=10):
    """Average distance between each row of ``X`` and its ``k`` nearest neighbours.

    Behaves the same as the ``compute_compactness`` definitions earlier in the
    notebook; this redefinition shadows them.
    """
    n_with_self = k + 1  # the query points are the fitted points, so each finds itself first
    knn_dists, _indices = NearestNeighbors(n_neighbors=n_with_self).fit(X).kneighbors(X)
    return knn_dists[:, 1:].mean()

# Mean 10-nearest-neighbour distance in each condition's autoencoder latent space.
comp_dmso_ae, comp_abt_ae = (
    compute_compactness(X_dmso_ae, k=10),
    compute_compactness(X_abt_ae, k=10),
)

print("AE DMSO compactness:", comp_dmso_ae)
print("AE ABT  compactness:", comp_abt_ae)

AE DMSO compactness: 8.602520816241121
AE ABT  compactness: 8.401924006920629


In [57]:
# Location and spread of the autoencoder pairwise-distance distribution per condition.
print("AE DMSO mean/std:", np.mean(dist_dmso_ae), np.std(dist_dmso_ae))
print("AE ABT  mean/std:", np.mean(dist_abt_ae), np.std(dist_abt_ae))

AE DMSO mean/std: 32.04225380299988 20.794785339311627
AE ABT  mean/std: 25.11343715559013 12.302529003397849
