# STODE: Reproducible Spatiotemporal Dynamics Notebook

This notebook is intended for **reviewers** and **readers of the STODE paper**.
Its goal is to make the end‑to‑end workflow as transparent as possible:

- How the spatial transcriptomics data are preprocessed.
- How the **VAE** and **potential‑guided neural ODE** are trained.
- How **backward simulations** and downstream analyses are run.

The notebook is deliberately **high‑level and heavily commented** so that the
logic of each step is clear even if you do not execute all cells (full
training is computationally expensive).

## 0. Prerequisites and project layout

This notebook assumes you have cloned the STODE project repository, e.g.

```bash
git clone https://github.com/LzrRacer/STODE.git
cd STODE
```

and that you have created a conda (or equivalent) environment with the
required Python packages (PyTorch, Scanpy, POT, GeomLoss, etc.).

The code base is organized in three main parts:

- `src/` – model, loss, data, and utility modules (VAE, potential, ODE, etc.).
- `scripts/` – command‑line entry points for training, simulation, and analysis.
- `data/` and `results/` – input AnnData files and experiment outputs.

This notebook will *not* redefine the models; instead, it calls the same
scripts that are used for the main experiments, but wraps them in a way that
explains **what** is happening and **why** each step is needed.

In [None]:
from pathlib import Path
import sys

# EDIT THIS if you are running the notebook from a different location
PROJECT_ROOT = Path.cwd()  # assumed to be the repo root when you open the notebook
print('PROJECT_ROOT =', PROJECT_ROOT)

SRC_DIR = PROJECT_ROOT / 'src'
SCRIPTS_DIR = PROJECT_ROOT / 'scripts'
DATA_DIR = PROJECT_ROOT / 'data'
RESULTS_DIR = PROJECT_ROOT / 'results'

for p in [SRC_DIR, SCRIPTS_DIR, DATA_DIR, RESULTS_DIR]:
    print(' -', p)

# Ensure src/ is importable
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))
print('\nPython path updated.')

## 1. Data preprocessing (MOSTA mouse organogenesis example)

The raw spatial transcriptomics data are stored as an AnnData object
(`.h5ad`). The preprocessing pipeline performs:

1. **Quality control**
   - Remove spots with too few detected genes.
   - Remove genes expressed in very few spots.
2. **Normalization + log transform**
   - `sc.pp.normalize_total(adata, target_sum=1e4)`
   - `sc.pp.log1p(adata)`
3. **Highly variable genes (HVGs)**
   - Select ~2,000 informative genes across time points.
4. **Batch correction across time points**
   - Treat each biological time (`obs['timepoint']`) as a batch.
   - Apply ComBat (`sc.pp.combat`) to log‑normalized values.
   - This yields **continuous, batch‑corrected expression**, which matches
     the Gaussian reconstruction loss used in the VAE.
5. **Save preprocessed AnnData** for downstream steps.

This lets reviewers verify that the model never trains directly on raw
counts: all training uses the normalized, log-transformed, batch-corrected
matrix.
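
To make step 2 of the list above concrete, the following NumPy-only sketch reproduces the arithmetic of library-size normalization followed by `log1p` on a toy count matrix. The helper name `normalize_log1p` and the toy counts are illustrative assumptions; the notebook itself relies on the Scanpy calls listed above.

```python
import numpy as np

def normalize_log1p(counts, target_sum=1e4):
    """Scale each spot (row) to `target_sum` total counts, then apply log(1 + x)."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)
    # Guard against empty spots so they stay all-zero instead of dividing by zero.
    safe_totals = np.where(totals == 0, 1.0, totals)
    scaled = counts / safe_totals * target_sum
    return np.log1p(scaled)

# Toy matrix: 2 spots x 3 genes (hypothetical values).
toy_counts = np.array([[10, 0, 30],
                       [1, 1, 2]])
toy_norm = normalize_log1p(toy_counts)

# Undoing the log shows that every non-empty spot now sums to target_sum.
print(np.expm1(toy_norm).sum(axis=1))
```

Because the output is continuous rather than integer-valued, it is compatible with a Gaussian reconstruction loss, which is the motivation stated in step 4.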

In [None]:
import scanpy as sc
import numpy as np

RAW_DATA_PATH = DATA_DIR / 'mosta_data.h5ad'   # adjust if your filename differs
PREPROCESSED_DATA_PATH = DATA_DIR / 'preprocessed.h5ad'

print('Reading raw data from:', RAW_DATA_PATH)
adata = sc.read_h5ad(RAW_DATA_PATH).copy()
print('Raw shape (spots × genes):', adata.shape)
print('Timepoints:', sorted(adata.obs['timepoint'].astype(str).unique()))

# --- 1) Preserve raw counts ---
adata.layers['counts'] = adata.X.copy()
adata.raw = adata

# --- 2) Define batch as timepoint ---
adata.obs['batch'] = adata.obs['timepoint'].astype('category')

# --- 3) Basic filters ---
sc.pp.filter_cells(adata, min_genes=100)
sc.pp.filter_genes(adata, min_cells=10)
print('After QC:', adata.shape)

# --- 4) Normalize and log1p ---
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# --- 5) Highly variable genes ---
sc.pp.highly_variable_genes(adata, n_top_genes=2000, subset=True)
print('After HVG selection:', adata.shape)

# --- 6) Batch correction with ComBat ---
sc.pp.combat(adata, key='batch')
adata.uns['combat_corrected'] = True

# --- 7) Save preprocessed AnnData ---
adata.write_h5ad(PREPROCESSED_DATA_PATH)
print('Saved preprocessed data to:', PREPROCESSED_DATA_PATH)

## 2. VAE pre‑training (latent representation)

The first model component is a **Variational Autoencoder (VAE)** that learns a
low‑dimensional latent representation `z` of the batch‑corrected expression.

**Architecture (as used in the paper):**

- Encoder MLP: input dimension = number of HVGs; hidden layers = `[128, 64]`.
- Latent dimension: `d = 8`.
- Decoder MLP mirrors the encoder.
- Reconstruction loss: **Gaussian (MSE)** on the preprocessed, continuous
  expression matrix.
- KL regularization to a standard normal prior.

The VAE is pre‑trained for ~100 epochs to stabilize the latent space before
training the dynamical model. This notebook exposes the configuration used in
the scripts and shows how to launch pre‑training.
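
The objective described above (Gaussian reconstruction plus a weighted KL term) can be written out in a few lines. This is a minimal NumPy sketch, not the code in `src/`: the function name `gaussian_vae_loss` and the choice to sum over features and average over spots are assumptions, and the training script may reduce the terms differently.

```python
import numpy as np

def gaussian_vae_loss(x, x_hat, mu, logvar, kl_weight=0.005):
    """Return (total, reconstruction, KL) for one mini-batch."""
    # With a fixed-variance Gaussian likelihood, the reconstruction term
    # reduces to a squared error between input and reconstruction.
    recon = np.mean(np.sum((x - x_hat) ** 2, axis=1))
    # Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dims.
    kl = np.mean(-0.5 * np.sum(1.0 + logvar - mu ** 2 - np.exp(logvar), axis=1))
    return recon + kl_weight * kl, recon, kl

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 2000))   # 4 spots, 2,000 HVGs
mu = np.ones((4, 8))             # latent_dim = 8, posterior mean shifted from the prior
logvar = np.zeros((4, 8))        # unit posterior variance

# A perfect reconstruction isolates the KL contribution.
total, recon, kl = gaussian_vae_loss(x, x, mu, logvar)
print(total, recon, kl)
```

With a small `kl_weight` such as the 0.005 used in the configuration below, the reconstruction term dominates early training and the prior acts as a mild regularizer.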

> **Note for reviewers:** full training can take hours on a GPU. You do not
> need to rerun it to follow the logic; the next cell is mainly for
> transparency and reproducibility.

In [None]:
import json

PRETRAINED_VAE_OUTPUT_DIR = RESULTS_DIR / 'pretrained_vae'
PRETRAINED_VAE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
PRETRAINED_VAE_MODEL_PATH = PRETRAINED_VAE_OUTPUT_DIR / 'vae_model_pretrained.pt'

config_train_vae = {
    'data_path': str(PREPROCESSED_DATA_PATH),
    'time_key': 'timepoint',
    'spatial_key': 'spatial',
    'hidden_dims_str': '128,64',
    'latent_dim': 8,
    'dropout_rate': 0.1,
    'epochs': 100,
    'learning_rate': 1e-3,
    'batch_size': 128,
    'recon_loss_type': 'gaussian',
    'kl_weight': 0.005,
    'output_model_path': str(PRETRAINED_VAE_MODEL_PATH),
    'results_dir': str(PRETRAINED_VAE_OUTPUT_DIR),
    'seed': 42,
    'device': ''  # empty → auto‑select GPU/CPU in the script
}

with open(PRETRAINED_VAE_OUTPUT_DIR / 'config_train_vae_only.json', 'w') as f:
    json.dump(config_train_vae, f, indent=2)

print('VAE config written to:', PRETRAINED_VAE_OUTPUT_DIR / 'config_train_vae_only.json')
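print('\nTo train the VAE from the notebook, uncomment the subprocess.run line in the next cell.')

In [None]:
# WARNING: Full VAE training can take a long time.
# By default this cell only prints the command; uncomment the last line to execute it.

import shlex
import subprocess
import sys

train_vae_script = SCRIPTS_DIR / '00_train_vae.py'
# sys.executable keeps the subprocess in the same Python environment as the notebook kernel.
cmd = [sys.executable, str(train_vae_script)]
for k, v in config_train_vae.items():
    cmd.extend([f'--{k}', str(v)])
# shlex.join quotes empty values (e.g. --device '') so the printed command is copy-pasteable.
print('Running:', shlex.join(cmd))
# subprocess.run(cmd, check=True)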

## 3. Generative spatiotemporal dynamics (potential + neural ODE)

Once the VAE is pre‑trained, the **dynamical system** is trained on pairs of
consecutive time points. Conceptually:

1. Each spot at time \(t_k\) is encoded to a latent vector \(z_k\) and has
   spatial coordinates \(s_k\).
2. We form a **joint state** \(y_k = [s_k, z_k]\).
3. A learnable **potential field** \(U(s, z, t)\) defines an energy landscape
   over space–latent–time.
4. A **Time‑Aware ODE** network learns a velocity field and is regularized so
   that the velocity is consistent with the negative gradient of the potential
   (a potential‑guided flow).
5. Using a simple Euler integrator, we map states from \(t_k\) to \(t_{k-1}\)
   and match the simulated distribution to the observed distribution at
   \(t_{k-1}\) using a **sliced Wasserstein distance** (distributional loss).
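
The distributional loss in step 5 can be illustrated with a small Monte Carlo estimator. This is a NumPy stand-in under stated assumptions (equal sample sizes, 2-Wasserstein, random unit projections); the function name `sliced_wasserstein` is hypothetical, and the repository may compute this loss differently, for example through one of the optimal-transport packages listed in the prerequisites.

```python
import numpy as np

def sliced_wasserstein(a, b, n_projections=64, seed=0):
    """Sliced 2-Wasserstein distance between two equally sized point clouds."""
    if a.shape != b.shape:
        raise ValueError('this sketch assumes equally sized point clouds')
    rng = np.random.default_rng(seed)
    # Random unit directions in the joint state space.
    dirs = rng.normal(size=(n_projections, a.shape[1]))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)
    # In 1-D, optimal transport simply matches sorted samples.
    proj_a = np.sort(a @ dirs.T, axis=0)
    proj_b = np.sort(b @ dirs.T, axis=0)
    return np.sqrt(np.mean((proj_a - proj_b) ** 2))

rng = np.random.default_rng(1)
# Joint states y = [s, z]: 2 spatial + 8 latent dimensions.
y_sim = rng.normal(size=(256, 10))
y_obs = rng.normal(size=(256, 10)) + 1.0   # shifted "observed" cloud

print('same cloud   :', sliced_wasserstein(y_sim, y_sim))
print('shifted cloud:', sliced_wasserstein(y_sim, y_obs))
```

The estimator compares distributions rather than individual spots, which matters here because spots at \(t_k\) and \(t_{k-1}\) come from different tissue sections and have no one-to-one correspondence.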

Additional regularization terms encourage:

- **Time alignment**: a small regressor predicts biological time from latent
  codes, encouraging monotonic trajectories in latent space.
- **Force consistency**: the ODE’s velocity agrees with the gradient of the
  potential.
- **Convergence to a progenitor state**: trajectories at the earliest time
  cluster around a learned `t0` anchor.
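
As an illustration of the force-consistency term, the sketch below penalizes the squared mismatch between a velocity and the force \(-\nabla U\). Everything here is an assumption made for illustration: the quadratic toy potential stands in for the learned network, the gradient is taken by finite differences rather than automatic differentiation, and the squared-error form is one plausible choice rather than the confirmed loss in the repository.

```python
import numpy as np

def potential(y):
    """Toy potential U(y) = 0.5 * ||y||^2 (stand-in for the learned potential)."""
    return 0.5 * np.sum(y ** 2, axis=1)

def grad_potential(y, eps=1e-5):
    """Central finite-difference gradient of `potential` with respect to y."""
    grad = np.zeros_like(y)
    for j in range(y.shape[1]):
        step = np.zeros(y.shape[1])
        step[j] = eps
        grad[:, j] = (potential(y + step) - potential(y - step)) / (2 * eps)
    return grad

def force_consistency_loss(velocity, y):
    """Mean squared mismatch between the velocity and the force -grad U."""
    return np.mean(np.sum((velocity + grad_potential(y)) ** 2, axis=1))

rng = np.random.default_rng(0)
y = rng.normal(size=(5, 10))   # 5 joint states [s, z] with 2 + 8 dimensions

# For this potential, -grad U = -y, so v = -y is perfectly consistent.
print('consistent velocity:', force_consistency_loss(-y, y))
print('opposed velocity   :', force_consistency_loss(+y, y))
```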

Below we record the configuration used for training this system and show how
the training script is invoked.

In [None]:
GENERATIVE_OUTPUT_DIR = RESULTS_DIR / 'generative_system_training_output'
GENERATIVE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
TRAINED_MODEL_PATH = GENERATIVE_OUTPUT_DIR / 'system_model_final.pt'

config_train_dynamics = {
    'data_path': str(PREPROCESSED_DATA_PATH),
    'time_key': 'timepoint',
    'spatial_key': 'spatial',
    'pretrained_vae_path': str(PRETRAINED_VAE_MODEL_PATH),
    'spatial_dim': 2,
    'vae_hidden_dims': '128,64',
    'vae_latent_dim': 8,
    'vae_recon_loss_type': 'gaussian',
    'vae_kl_weight': 0.005,
    'vae_dropout_rate': 0.1,
    # Potential + ODE
    'potential_time_embedding_dim': 4,
    'potential_hidden_dims': '32,16',
    'ode_time_embedding_dim': 4,
    'ode_hidden_dims': '64,32',
    'ode_damping_coeff': 0.01,
    # Training
    'epochs': 50,
    'learning_rate': 5e-4,
    'batch_size_transition': 32,
    'ode_n_integration_steps': 3,
    # Loss weights (see paper Methods)
    'loss_weight_vae_recon': 1.0,
    'loss_weight_vae_kl': 0.1,
    'loss_weight_ode_swd': 1.0,
    'loss_weight_time_align': 0.1,
    'loss_weight_force_consistency': 0.1,
    'loss_weight_t0_distance': 10.0,
    'loss_weight_t0_velocity_align': 10.0,
    'results_dir': str(GENERATIVE_OUTPUT_DIR),
    'seed': 42,
    'device': ''
}

with open(GENERATIVE_OUTPUT_DIR / 'config_train_generative_system.json', 'w') as f:
    json.dump(config_train_dynamics, f, indent=2)

print('Dynamics config written to:', GENERATIVE_OUTPUT_DIR / 'config_train_generative_system.json')

In [None]:
# WARNING: Training the full generative system is compute-intensive.
# By default this cell only prints the command; uncomment the last line to launch training.

import shlex
import subprocess
import sys

gen_script = SCRIPTS_DIR / '01_train_generative_system.py'
# sys.executable keeps the subprocess in the same Python environment as the notebook kernel.
cmd = [sys.executable, str(gen_script)]
for k, v in config_train_dynamics.items():
    cmd.extend([f'--{k}', str(v)])
print('Running:', shlex.join(cmd))
# subprocess.run(cmd, check=True)

## 4. Backward simulation from an observed stage

After training, the model can **simulate trajectories backward in biological
time**. For example, starting from E11.5, we integrate the learned ODE
backwards to reconstruct earlier progenitor states.

Conceptually:

1. Take the observed AnnData at a late time point (e.g., E11.5).
2. Encode each spot into \(z\), combine with spatial coordinates \(s\) to get
   \(y = [s, z]\).
3. Numerically integrate the ODE backward from \(t_{\text{obs}}\) down to \(t_0\).
4. Save the full trajectory of states \(y(t)\) and associated biological times.
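
The four steps above can be sketched with an explicit Euler loop. This is a toy illustration, not the simulation script: `toy_velocity` is a hypothetical stand-in for the trained ODE network, the random states stand in for encoded spots, and the real integrator may treat the time direction or damping differently.

```python
import numpy as np

def toy_velocity(y, t):
    """Hypothetical forward-time velocity dy/dt = y (states expand over time)."""
    return y

def integrate_backward(y_obs, t_obs, t_final, n_steps, velocity_fn):
    """Explicit Euler integration of dy/dt = velocity_fn(y, t) from t_obs to t_final."""
    times = np.linspace(t_obs, t_final, n_steps + 1)
    dt = (t_final - t_obs) / n_steps   # negative when integrating backward in time
    y = y_obs.copy()
    trajectory = [y.copy()]
    for t in times[:-1]:
        y = y + dt * velocity_fn(y, t)
        trajectory.append(y.copy())
    return times, np.stack(trajectory)

rng = np.random.default_rng(0)
y_obs = rng.normal(size=(100, 10))   # 100 spots, joint state [s, z] with 2 + 8 dims

# Same time span and step count as the configuration in the next cell.
times, traj = integrate_backward(y_obs, t_obs=11.5, t_final=0.0,
                                 n_steps=300, velocity_fn=toy_velocity)
print(traj.shape)                    # (n_steps + 1, n_spots, state_dim)
print(np.linalg.norm(traj[0]), np.linalg.norm(traj[-1]))
```

Under this toy field the states contract toward the origin as time runs backward, which mirrors the qualitative picture of trajectories converging on a progenitor state.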

The next cell shows the configuration used by the simulation script.

In [None]:
SIM_BWD_DIR = RESULTS_DIR / 'backward_simulation_from_observed'
SIM_BWD_DIR.mkdir(parents=True, exist_ok=True)

# Here we assume the latest timepoint in PREPROCESSED_DATA_PATH is E11.5
LATEST_TIME = 11.5
SNAPSHOT_PATH = DATA_DIR / f'snapshot_E{LATEST_TIME}.h5ad'  # created in the paper workflow

config_sim_bwd = {
    'model_load_path': str(TRAINED_MODEL_PATH),
    'config_train_load_path': str(GENERATIVE_OUTPUT_DIR / 'config_train_generative_system.json'),
    'original_adata_path_for_vae_input_dim': str(PREPROCESSED_DATA_PATH),
    'observed_adata_path': str(SNAPSHOT_PATH if SNAPSHOT_PATH.exists() else PREPROCESSED_DATA_PATH),
    'observed_time_point_numeric': LATEST_TIME,
    't_final_bio_target': 0.0,
    'num_cells_to_sample_from_observed': 0,
    'output_dir': str(SIM_BWD_DIR),
    'simulation_n_steps': 300,
    'grid_size': 0,   # set >0 to enable spatial merging / coarse‑graining
    'seed': 45,
    'device': ''
}

print('Simulation output directory:', SIM_BWD_DIR)

In [None]:
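# As before, this call can take a long time; it is shown for reproducibility.
# By default this cell only prints the command; uncomment the last line to run the simulation.

import shlex
import subprocess
import sys

sim_script = SCRIPTS_DIR / '02_simulate_backward_from_observed.py'
# sys.executable keeps the subprocess in the same Python environment as the notebook kernel.
cmd = [sys.executable, str(sim_script)]
for k, v in config_sim_bwd.items():
    cmd.extend([f'--{k}', str(v)])
print('Running:', shlex.join(cmd))
# subprocess.run(cmd, check=True)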

## 5. Analysis and visualization helpers

Once trajectories have been simulated, several analysis scripts can be run to
reproduce figures similar to those in the manuscript:

- **Backward contraction / forward replay animations** of the tissue
  (morphological changes over time).
- **Temporal clustering** and mapping of simulated particles to observed
  annotations.
- **Latent dynamics summaries** (e.g., how cluster‑wise latent coordinates
  evolve over simulated time).
- **Potential and divergence maps**, showing regions of expansion and
  contraction.
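
For readers unfamiliar with divergence maps, the sketch below computes the divergence of a 2-D velocity field on a regular grid using finite differences. The field is a hand-written toy example and `divergence_2d` is a hypothetical helper; the analysis scripts may evaluate the learned field differently (for example with automatic differentiation).

```python
import numpy as np

def divergence_2d(vx, vy, dx, dy):
    """Finite-difference divergence d(vx)/dx + d(vy)/dy on a regular grid.

    Arrays are indexed as [iy, ix], matching np.meshgrid's default 'xy' layout.
    """
    dvx_dx = np.gradient(vx, dx, axis=1)
    dvy_dy = np.gradient(vy, dy, axis=0)
    return dvx_dx + dvy_dy

xs = np.linspace(-1.0, 1.0, 21)
ys = np.linspace(-1.0, 1.0, 11)
X, Y = np.meshgrid(xs, ys)           # both of shape (11, 21)

# Toy field: stretching along x, weaker compression along y.
vx, vy = X, -0.5 * Y
div = divergence_2d(vx, vy, xs[1] - xs[0], ys[1] - ys[0])
print(div.shape, div.min(), div.max())
```

Positive divergence marks regions where the flow spreads points apart (expansion) and negative divergence marks regions where it draws them together (contraction).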

For clarity, we only give an example of how to launch the backward animation
analysis here; other analyses follow the same pattern (config dictionary →
script invocation).
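
The config dictionary → script invocation pattern used throughout this notebook can be factored into one small helper. The name `build_command` is hypothetical and not part of the repository; it only reproduces the `--key value` convention of the cells in this notebook. How each script interprets empty strings or boolean values depends on its own argument parser, which is not shown here.

```python
import shlex
import sys

def build_command(script_path, config):
    """Turn a config dict into [python, script, '--key', 'value', ...]."""
    # sys.executable is the interpreter running this notebook kernel.
    cmd = [sys.executable, str(script_path)]
    for key, value in config.items():
        cmd.extend([f'--{key}', str(value)])
    return cmd

example_cmd = build_command(
    'scripts/03_analyze_backward_simulation.py',
    {'animation_fps': 10, 'device': ''},
)

# shlex.join quotes empty strings and spaces, so the printed command
# can be pasted into a shell unchanged.
print(shlex.join(example_cmd))
```

Passing the result to `subprocess.run(example_cmd, check=True)` would execute it; `check=True` raises an error if the script exits with a non-zero status.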

In [None]:
ANALYSIS_BWD_DIR = RESULTS_DIR / 'analysis_backward_simulation'
ANALYSIS_BWD_DIR.mkdir(parents=True, exist_ok=True)

config_analyze_bwd = {
    'backward_simulation_dir': str(SIM_BWD_DIR),
    'model_load_path': str(TRAINED_MODEL_PATH),
    'config_train_load_path': str(GENERATIVE_OUTPUT_DIR / 'config_train_generative_system.json'),
    'original_adata_path_for_vae_config': str(PREPROCESSED_DATA_PATH),
    'output_dir': str(ANALYSIS_BWD_DIR),
    'animation_frames': 0,          # 0 → use all time points
    'animation_fps': 10,
    'animation_dot_size': 15,
    'latent_compress_method': 'umap',
    'latent_compress_umap_neighbors': 15,
    'animation_end_time_bio': 0.0,  # stop when reaching t=0
    'auto_adjust_fps_to_bio_time': True,
    'seed': 46,
    'device': ''
}

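# Example command: printed by default; uncomment the last line to run the analysis.
import shlex
import subprocess
import sys

analyze_script = SCRIPTS_DIR / '03_analyze_backward_simulation.py'
# sys.executable keeps the subprocess in the same Python environment as the notebook kernel.
cmd = [sys.executable, str(analyze_script)]
for k, v in config_analyze_bwd.items():
    cmd.extend([f'--{k}', str(v)])
print('Running:', shlex.join(cmd))
# subprocess.run(cmd, check=True)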

## 6. How to read this notebook as a reviewer

You can use this notebook in two complementary ways:

1. **As a narrative of the pipeline**
   - Read the markdown cells to understand the modeling choices and how the
     command‑line scripts fit together.
   - Skim the configuration dictionaries to see exact hyperparameters.
2. **As an executable reproduction script (optional)**
   - If you have access to the raw data and sufficient compute, you can
     uncomment the `subprocess.run(...)` calls one stage at a time to rerun
     VAE pre-training, dynamics training, simulation, and analysis.

Either way, the intent is that every major step in STODE’s pipeline is
visible in one place, with clear motivation and parameterization.