This is the present of scNODE for Single-cell transcriptomics time-series modelling and cell trajectory inference.

Author:
    Jiaqi Zhang <jiaqi_zhang2@brown.edu>

In [None]:
# Set up matplotlib for interactive display
import matplotlib
try:
    # Try to use TkAgg backend for interactive plotting
    matplotlib.use('TkAgg')
except:
    # Fallback to default backend if TkAgg is not available
    pass
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'sans-serif']
import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='matplotlib')
warnings.filterwarnings('ignore', category=UserWarning, module='plotting.visualization')

import sys
import os
import pathlib

project_root = str(pathlib.Path.cwd().resolve().parent)
sys.path.insert(0, project_root)
os.chdir(project_root)  # Change the current working directory to the project root

import torch
import numpy as np
import time

from BenchmarkUtils import loadSCData, tpSplitInd, tunedOurPars, splitBySpec
from plotting.__init__ import *
from plotting.visualization import plotPredAllTime, plotPredTestTime, computeDrift, plotStream, plotStreamByCellType
from plotting.PlottingUtils import umapWithPCA, computeLatentEmbedding
from optim.running import constructscNODEModel, scNODETrainWithPreTrain, scNODEPredict
from optim.evaluation import globalEvaluation

In [None]:
# Load data and pre-processing
# Specify the dataset: zebrafish, drosophila, wot
# Representing ZB, DR, SC, repectively
data_name= "zebrafish"
print("[ {} ]".format(data_name).center(60))
split_type = "three_interpolation"
print("Split type: {}".format(split_type))
ann_data, cell_tps, cell_types, n_genes, n_tps, all_tps = loadSCData(data_name, split_type)
train_tps, test_tps = tpSplitInd(data_name, split_type)
data = ann_data.X

# Convert to torch project
traj_data = [torch.FloatTensor(data[np.where(cell_tps == t)[0], :]) for t in range(1, n_tps + 1)]
if cell_types is not None:
    traj_cell_types = [cell_types[np.where(cell_tps == t)[0]] for t in range(1, n_tps + 1)]

all_tps = list(all_tps)  # Convert to list
train_data, test_data = splitBySpec(traj_data, train_tps, test_tps)
tps = torch.FloatTensor(all_tps)
train_tps = torch.FloatTensor(train_tps)
test_tps = torch.FloatTensor(test_tps)
n_cells = [each.shape[0] for each in traj_data]
print("# tps={}, # genes={}".format(n_tps, n_genes))
print("# cells={}".format(n_cells))
print("Train tps={}".format(train_tps))
print("Test tps={}".format(test_tps))

                       [ zebrafish ]                        
Split type: three_interpolation
[ Data=zebrafish | Split=three_interpolation ] Loading data...
Dataset is loaded.
Dataset is loaded.
# tps=12, # genes=2000
# cells=[311, 200, 1158, 1467, 5716, 1026, 4101, 6178, 5442, 7114, 1614, 4404]
Train tps=tensor([ 0.,  1.,  2.,  3.,  5.,  7.,  9., 10., 11.])
Test tps=tensor([4., 6., 8.])
# tps=12, # genes=2000
# cells=[311, 200, 1158, 1467, 5716, 1026, 4101, 6178, 5442, 7114, 1614, 4404]
Train tps=tensor([ 0.,  1.,  2.,  3.,  5.,  7.,  9., 10., 11.])
Test tps=tensor([4., 6., 8.])


In [3]:
# Model Defining and Training Settings
NUM_EPOCHS = 10
pretrain_lr = 1e-3
latent_coeff = 1.0 # regularization coefficient: beta
batch_size = 32
lr = 1e-3
act_name = "relu"
n_sim_cells = 2000

latent_dim, drift_latent_size, enc_latent_list, dec_latent_list = tunedOurPars(data_name, split_type) # use tuned hyperparameters
latent_ode_model = constructscNODEModel(
    n_genes, latent_dim=latent_dim,
    enc_latent_list=enc_latent_list, dec_latent_list=dec_latent_list, drift_latent_size=drift_latent_size,
    latent_enc_act="none", latent_dec_act=act_name, drift_act=act_name,
    ode_method="euler"
)

# Start timing for progress tracking
print(f"Starting training for {NUM_EPOCHS} epochs...")
train_start_time = time.time()

# Import necessary modules for custom training loop
import geomloss
import itertools

# Pre-training the VAE component
latent_encoder = latent_ode_model.latent_encoder
obs_decoder = latent_ode_model.obs_decoder
all_train_data = torch.cat(train_data, dim=0)
pretrain_iters = 100
print("Pre-training VAE component...")
dim_reduction_params = itertools.chain(*[latent_encoder.parameters(), obs_decoder.parameters()])
dim_reduction_optimizer = torch.optim.Adam(params=dim_reduction_params, lr=pretrain_lr, betas=(0.95, 0.99))
latent_encoder.train()
obs_decoder.train()
for i in range(pretrain_iters):
    rand_idx = np.random.choice(all_train_data.shape[0], size=batch_size, replace=False)
    batch_data = all_train_data[rand_idx, :]
    dim_reduction_optimizer.zero_grad()
    latent_mu, latent_std = latent_encoder(batch_data)
    latent_sample = latent_mu + torch.randn_like(latent_std) * latent_std
    recon_obs = obs_decoder(latent_sample)
    recon_loss = torch.mean((recon_obs - batch_data) ** 2)
    recon_loss.backward()
    dim_reduction_optimizer.step()
    if (i + 1) % 20 == 0:
        print(f"Pre-training iteration {i+1}/{pretrain_iters}, Loss: {recon_loss.item():.4f}")

# Dynamic training setup
optimizer = torch.optim.Adam(params=latent_ode_model.parameters(), lr=lr, betas=(0.95, 0.99))
ot_solver = geomloss.SamplesLoss("sinkhorn", p=2, blur=0.05, scaling=0.5, debias=True, backend="tensorized")
loss_list = []

# Convert train_tps to list once for model forward
train_tps_list = train_tps.detach().numpy().tolist() if isinstance(train_tps, torch.Tensor) else train_tps

Starting training for 10 epochs...
Pre-training VAE component...
Pre-training iteration 20/100, Loss: 6.4668
Pre-training iteration 40/100, Loss: 5.7228
Pre-training iteration 60/100, Loss: 2.0397
Pre-training iteration 80/100, Loss: 1.9819
Pre-training iteration 100/100, Loss: 2.5954
Pre-training iteration 20/100, Loss: 6.4668
Pre-training iteration 40/100, Loss: 5.7228
Pre-training iteration 60/100, Loss: 2.0397
Pre-training iteration 80/100, Loss: 1.9819
Pre-training iteration 100/100, Loss: 2.5954


In [None]:
# Training process
print("\nStarting main training loop...")
for epoch in range(NUM_EPOCHS):
    epoch_loss = 0
    n_iters = 0
    
    for iter_idx in range(100):  # iterations per epoch
        rand_t_idx = np.random.choice(len(train_tps))
        rand_idx = np.random.choice(train_data[rand_t_idx].shape[0], size=batch_size, replace=False)
        batch_data = train_data[rand_t_idx][rand_idx, :]
        
        optimizer.zero_grad()
        latent_ode_model.train()
        # Model expects: forward(data_list, tps, batch_size=None)
        # data should be a list with data at first timepoint
        recon_obs, first_latent_dist, first_tp_data, latent_seq = latent_ode_model(
            [batch_data], 
            torch.FloatTensor(train_tps_list)
        )
        
        # Compute optimal transpost loss
        ot_loss = 0
        for t_idx, t in enumerate(train_tps):
            pred_x = recon_obs[:, t_idx, :]
            true_x = train_data[t_idx]
            subsample_size = min(200, true_x.shape[0])
            subsample_idx = np.random.choice(true_x.shape[0], subsample_size, replace=False)
            ot_loss += ot_solver(pred_x, true_x[subsample_idx])
        
        latent_drift_loss = torch.mean((latent_seq[:, 1:, :] - latent_seq[:, :-1, :]) ** 2)
        loss = ot_loss + latent_coeff * latent_drift_loss
        loss.backward()
        optimizer.step()
        loss_list.append((loss.item(), ot_loss.item(), latent_drift_loss.item()))
        
        epoch_loss += loss.item()
        n_iters += 1
    
    avg_loss = epoch_loss / n_iters
    elapsed = time.time() - train_start_time
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Avg Loss: {avg_loss:.4f}, Time: {elapsed/60:.1f}m")

print("\nGenerating final predictions...")
latent_ode_model.eval()
with torch.no_grad():
    first_obs = train_data[0]
    recon_obs, first_latent_dist, first_tp_data, latent_seq = latent_ode_model(
        [first_obs], 
        torch.FloatTensor(train_tps_list)
    )

# Training summary
train_duration = time.time() - train_start_time
print("=" * 70)
print(f"Training completed! Total epochs: {NUM_EPOCHS}")
print(f"Training time: {train_duration:.2f} seconds ({train_duration/60:.2f} minutes)")
print("=" * 70)

all_recon_obs = scNODEPredict(
    latent_ode_model, 
    traj_data[0], 
    tps, 
    n_cells=n_sim_cells
)  # (# cells, # tps, # genes)


Starting main training loop...
Epoch 1/10, Avg Loss: 51136.5554, Time: 0.5m
Epoch 1/10, Avg Loss: 51136.5554, Time: 0.5m
Epoch 5/10, Avg Loss: 26167.1351, Time: 1.9m
Epoch 5/10, Avg Loss: 26167.1351, Time: 1.9m
Epoch 10/10, Avg Loss: 23974.3964, Time: 3.9m

Generating final predictions...
Training completed! Total epochs: 10
Training time: 236.55 seconds (3.94 minutes)
Epoch 10/10, Avg Loss: 23974.3964, Time: 3.9m

Generating final predictions...
Training completed! Total epochs: 10
Training time: 236.55 seconds (3.94 minutes)


In [None]:
# Visualization - Loss Curve
plt.figure(figsize=(8, 6))
plt.subplot(3, 1, 1)
plt.title("Loss")
plt.plot([each[0] for each in loss_list])
plt.subplot(3, 1, 2)
plt.title("OT Term")
plt.plot([each[1] for each in loss_list])
plt.subplot(3, 1, 3)
plt.title("Dynamic Reg")
plt.plot([each[2] for each in loss_list])
plt.xlabel("Dynamic Learning Iter")
plt.tight_layout()  # Optimize subplot layout
plt.show()  # Display the figure interactively
plt.savefig("figures/scNODE_Loss.png")  # Save the figure to file
plt.close()  # Explicitly close the figure to free memory

# Visualization - 2D UMAP embeddings
print("Compare true and reconstructed data...")
true_data = [each.detach().numpy() for each in traj_data]
true_cell_tps = np.concatenate([np.repeat(t, each.shape[0]) for t, each in enumerate(true_data)])
pred_cell_tps = np.concatenate([np.repeat(t, all_recon_obs[:, t, :].shape[0]) for t in range(all_recon_obs.shape[1])])
reorder_pred_data = [all_recon_obs[:, t, :] for t in range(all_recon_obs.shape[1])]

# Compute UMAP embeddings for visualization
true_umap_traj, umap_model, pca_model = umapWithPCA(
    np.concatenate(true_data, axis=0), 
    n_neighbors=50, 
    min_dist=0.1, 
    pca_pcs=50
)
pred_umap_traj = umap_model.transform(
    pca_model.transform(np.concatenate(reorder_pred_data, axis=0))
)

# Display and save the first plot - All timepoints comparison
plotPredAllTime(true_umap_traj, pred_umap_traj, true_cell_tps, pred_cell_tps)
plt.show()  # Display all timepoints comparison

# Display and save the second plot - Test timepoints comparison
plotPredTestTime(
    true_umap_traj, 
    pred_umap_traj, 
    true_cell_tps, 
    pred_cell_tps, 
    test_tps.detach().numpy(),
    save_path="figures/scNODE_Results.png"
)
plt.show()  # Display test timepoints comparison

# Compute evaluation metrics
print("Computing evaluation metrics...")
test_tps_list = [int(t) for t in test_tps]
for t in test_tps_list:
    print("-" * 70)
    print(f"Timepoint t = {t}")
    pred_global_metric = globalEvaluation(
        traj_data[t].detach().numpy(), 
        all_recon_obs[:, t, :]
    )
    print(pred_global_metric)

  plt.show()  # 先显示


Compare true and reconstructed data...


  plt.show()
  plt.show()  # 先显示所有时间点的对比图
  ax1.scatter(true_umap_traj[:, 0], true_umap_traj[:, 1], label="other", c=gray_color, s=40, alpha=0.5)
  ax2.scatter(true_umap_traj[:, 0], true_umap_traj[:, 1], label="other", c=gray_color, s=40, alpha=0.5)


   Saving figure to figures/scNODE_Results.png...


  plt.show()  # 先显示测试时间点的对比图


Compute metrics...
----------------------------------------------------------------------
t = 4
{'l2': 131.72967392419037, 'cos': 0.10999074341830356, 'corr': 0.11349175242797213, 'ot': 4088.7830120770877}
----------------------------------------------------------------------
t = 6
{'l2': 131.72967392419037, 'cos': 0.10999074341830356, 'corr': 0.11349175242797213, 'ot': 4088.7830120770877}
----------------------------------------------------------------------
t = 6
{'l2': 142.15103467594614, 'cos': 0.11507932860724181, 'corr': 0.11748344308314047, 'ot': 5456.958284737331}
----------------------------------------------------------------------
t = 8
{'l2': 142.15103467594614, 'cos': 0.11507932860724181, 'corr': 0.11748344308314047, 'ot': 5456.958284737331}
----------------------------------------------------------------------
t = 8
{'l2': 98.97907457164247, 'cos': 0.21959273818614994, 'corr': 0.22440769950576347, 'ot': 2920.9485990143908}
{'l2': 98.97907457164247, 'cos': 0.21959273818614