GRAVITY: Dynamic gene regulatory network-enhanced RNA velocity modeling for trajectory inference and biological discovery
This subpackage provides a refactored, modular implementation of the GRAVITY workflow, inspired by the usability of scVelo while retaining GRAVITY’s regulation‑aware kinetics. The library exposes high‑level pipeline helpers and modular building blocks for preprocessing, two‑stage training, future position estimation, visualization, and TF importance analysis.
This subpackage focuses on a streamlined, research‑oriented interface with clear separation between preprocessing, training, future projection, and visualization components.
- End‑to‑end pipeline with a single configuration object.
- Two‑stage training (cell‑wise, then gene‑wise) with multi‑GPU support.
- Regulation‑informed future position estimation and plotting utilities.
- Exportable TF attention scores for downstream analysis.
It is recommended to use Python 3.9+ and a fresh virtual environment.
```bash
python -m venv .venv
source .venv/bin/activate
pip install -e .
```

If you need to control dependency versions manually, refer to `pyproject.toml`. For GPU training, first install a torch wheel built for your CUDA version, then run `pip install -e .`.
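Before a GPU run, a quick interpreter check can confirm the Python 3.9+ requirement and whether torch sees any CUDA devices. The sketch below is a generic helper (the name `check_environment` is hypothetical and not part of `gravity`). It uses only the standard library, plus `torch.cuda.is_available()` and `torch.cuda.device_count()` when torch is installed.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 9)):
    """Report whether the interpreter and the (optional) torch install look usable.

    Generic sanity check; hypothetical helper, not part of the gravity package.
    """
    report = {
        "python": ".".join(map(str, sys.version_info[:3])),
        "python_ok": sys.version_info[:2] >= tuple(min_python),
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
        "gpu_count": 0,
    }
    if report["torch_installed"]:
        import torch  # imported lazily so the check also runs without torch

        report["cuda_available"] = torch.cuda.is_available()
        report["gpu_count"] = torch.cuda.device_count()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

If `gpu_count` is lower than the number of entries you plan to pass in `devices`, shorten the list or switch to `accelerator="cpu"`.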
```python
from gravity import PipelineConfig, run_pipeline

cfg = PipelineConfig(
    raw_counts="data/pancreas_long.csv",
    workdir="gravity_outputs",
    prior_network="prior_data/network_mouse.zip",
    # Forwarded to PyTorch Lightning: two GPUs with DDP
    accelerator="gpu",
    devices=[0, 1],
    strategy="ddp",
    # Plot velocities for the selected genes
    make_plot=True,
    plot_genes=["GCG", "INS1"],
)
outputs = run_pipeline(cfg)
print(outputs)
```

If your dataset starts as an AnnData object, convert it once to the long-format CSV that GRAVITY consumes:
```python
from gravity import export_intermediate_from_h5ad

export_intermediate_from_h5ad(
    input_h5ad="data/postprocessed.h5ad",
    output_csv="data/hair.csv",
    n_top_genes=1000,
    embed_key="X_umap",
    celltype_key="celltype",
)
```

This helper mirrors the workflow in `gravity/smoke_test_hair.py` and stores the embeddings and cluster labels alongside the spliced/unspliced counts.
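To make the long format concrete, the sketch below writes one row per (cell, gene) pair with the columns GRAVITY requires (`cellID`, `gene_name`, `unsplice`, `splice`, `embedding1`, `embedding2`) plus the optional `clusters`. It is an illustration under stated assumptions, not the implementation of `export_intermediate_from_h5ad`: the helper `write_long_format` is hypothetical, and it takes plain nested lists instead of reading an `.h5ad` file.

```python
import csv

REQUIRED_COLUMNS = ["cellID", "gene_name", "unsplice", "splice", "embedding1", "embedding2"]


def write_long_format(path, cell_ids, gene_names, spliced, unspliced, embedding, clusters=None):
    """Write one row per (cell, gene) pair in GRAVITY's long format.

    Hypothetical helper. Assumed input layout:
      spliced / unspliced: nested lists shaped [n_cells][n_genes]
      embedding: nested list shaped [n_cells][2], e.g. UMAP coordinates
      clusters: optional per-cell labels for the optional 'clusters' column
    Returns the number of data rows written.
    """
    columns = REQUIRED_COLUMNS + (["clusters"] if clusters is not None else [])
    n_rows = 0
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(columns)
        for i, cell in enumerate(cell_ids):
            x, y = embedding[i][0], embedding[i][1]  # same 2-D position for all genes of a cell
            for j, gene in enumerate(gene_names):
                row = [cell, gene, unspliced[i][j], spliced[i][j], x, y]
                if clusters is not None:
                    row.append(clusters[i])
                writer.writerow(row)
                n_rows += 1
    return n_rows
```

The row count is cells × genes, so 1,000 genes produce 1,000 rows per cell; this is why long-format files grow quickly.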
Upon completion, `workdir` contains the following (file names are configurable via `PipelineConfig`):
- `combine.csv`: preprocessed wide table
- `stage1.csv`, `stage1.ckpt`: cell‑wise stage outputs
- `future_positions.npy`: predicted future positions
- `stage2.csv`, `stage2.ckpt`: gene‑wise stage outputs
- `attentions/`: TF score matrices and cell‑type mean attention networks
- `velocity_plots/*.png`: cell‑ and gene‑level velocity plots (if enabled)
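After a run, a small check can list which of the default outputs above are missing, which helps locate the stage that failed. This is a hypothetical helper (`missing_outputs` is not part of `gravity`) that assumes the default file names; edit the lists if you renamed outputs via `PipelineConfig`.

```python
from pathlib import Path

# Default output names (assumed unchanged in PipelineConfig).
EXPECTED_FILES = [
    "combine.csv",
    "stage1.csv",
    "stage1.ckpt",
    "future_positions.npy",
    "stage2.csv",
    "stage2.ckpt",
]
EXPECTED_DIRS = ["attentions"]


def missing_outputs(workdir, expect_plots=False):
    """Return the names of expected pipeline outputs that are absent from workdir."""
    root = Path(workdir)
    missing = [name for name in EXPECTED_FILES if not (root / name).is_file()]
    missing += [name + "/" for name in EXPECTED_DIRS if not (root / name).is_dir()]
    # Plots are only produced when make_plot=True, so checking them is opt-in.
    if expect_plots and not any((root / "velocity_plots").glob("*.png")):
        missing.append("velocity_plots/*.png")
    return missing
```

`missing_outputs("gravity_outputs", expect_plots=True)` returns an empty list when every expected output is present.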
```python
from gravity import (
    preprocess_counts,
    CellStageConfig, train_cell_stage,
    GeneStageConfig, train_gene_stage,
)
from gravity.tools.future import estimate_future_positions
from gravity.plotting.velocity import plot_velocity_cell, plot_velocity_gene

# 1) Preprocess: long-format counts -> wide table (combine.csv)
middle_csv = preprocess_counts("data/pancreas_long.csv", "gravity_outputs/combine.csv")

# 2) Cell‑wise training (multi‑GPU optional)
cell_cfg = CellStageConfig(
    raw_counts="data/pancreas_long.csv",
    middle_csv=str(middle_csv),
    prior_network="prior_data/network_mouse.zip",
    output_dir="gravity_outputs",
    accelerator="gpu",
    devices=[0, 1],
    strategy="ddp",
)
stage1 = train_cell_stage(cell_cfg)

# 3) Future position estimation from the stage-1 output
estimate_future_positions(stage1["stage1_csv"], "gravity_outputs/future_positions.npy")

# 4) Gene‑wise fine‑tuning, using the stage-1 checkpoint and the future positions
gene_cfg = GeneStageConfig(
    raw_counts="data/pancreas_long.csv",
    middle_csv=str(middle_csv),
    stage1_checkpoint=str(stage1["checkpoint"]),
    future_positions="gravity_outputs/future_positions.npy",
    prior_network="prior_data/network_mouse.zip",
    output_dir="gravity_outputs",
    accelerator="gpu",
    devices=[0, 1],
    strategy="ddp",
)
stage2 = train_gene_stage(gene_cfg)

# 5) Visualization (cell‑ and gene‑level)
plot_velocity_cell(str(stage2["stage2_csv"]), output_path="gravity_outputs/cell_velocity.png")
plot_velocity_gene(
    str(stage2["stage2_csv"]),
    gene="GCG",
    output_path="gravity_outputs/gcg_velocity_expression.png",
)
```

`PipelineConfig` options:
- `gene_subset`: restrict the gene set used for training
- `stage1_epochs` / `stage2_epochs`: number of epochs per stage
- `val_fraction_stage1` / `val_fraction_stage2`: optional hold-out ratio (default `0.0`, meaning no validation split)
- `future_tau`: scaling factor controlling the radius of the future-neighbor search
- `accelerator` / `devices` / `strategy`: forwarded to PyTorch Lightning (e.g., `accelerator="gpu"`, `devices=[0, 1]`, `strategy="ddp"`)
- `make_plot`, `plot_genes`: enable plotting and choose genes; `'all'` plots every gene
`CellStageConfig` options:
- `attention_output`: whether to export TF attention matrices
- `attention_topk`: number of TFs kept per cell
`GeneStageConfig` options:
- `future_positions`: path to the `.npy` file produced by future position estimation
- `stage1_checkpoint`: path to the cell‑wise (stage 1) checkpoint
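The hold-out options (`val_fraction_stage1` / `val_fraction_stage2`) are described only as a ratio. As a hedged illustration of how such a ratio is commonly applied, the sketch below holds out a random fraction of cells with a fixed seed. Whether GRAVITY splits by cell, how it seeds its RNG, and the helper name `split_cells` are all assumptions; this is not the package's code.

```python
import random


def split_cells(cell_ids, val_fraction=0.0, seed=0):
    """Randomly hold out a fraction of (unique) cell IDs for validation.

    Hypothetical helper. val_fraction=0.0 mirrors the documented default:
    every cell is used for training and the validation list is empty.
    """
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError("val_fraction must be in [0, 1)")
    ids = list(cell_ids)
    n_val = int(round(len(ids) * val_fraction))
    rng = random.Random(seed)  # local RNG: reproducible, leaves global state alone
    shuffled = ids[:]
    rng.shuffle(shuffled)
    held_out = set(shuffled[:n_val])
    # Preserve the original cell order in both partitions.
    train = [c for c in ids if c not in held_out]
    val = [c for c in ids if c in held_out]
    return train, val
```

Using a dedicated `random.Random(seed)` keeps the split reproducible across runs without affecting other code that draws from the global RNG.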
The long‑format CSV must include at least the columns `cellID`, `gene_name`, `unsplice`, `splice`, `embedding1`, and `embedding2`. The optional `clusters` column is used to color plots and to group summary tables. The prior network archive (`prior_data/network_mouse.zip` in the examples above) should match the original GRAVITY prior format.
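A header check before training can catch a missing column early. The stdlib sketch below (`check_long_csv` is a hypothetical name, not a GRAVITY function) reports the missing required columns and whether the optional `clusters` column is present. It inspects only the header; it does not validate values or the prior network archive.

```python
import csv

REQUIRED = ("cellID", "gene_name", "unsplice", "splice", "embedding1", "embedding2")
OPTIONAL = ("clusters",)


def check_long_csv(path):
    """Inspect the header of a long-format CSV.

    Hypothetical helper. Returns (missing_required, present_optional);
    an empty first list means all required columns are present.
    """
    with open(path, newline="") as handle:
        header = next(csv.reader(handle), [])  # empty file -> empty header
    columns = {name.strip() for name in header}
    missing = [name for name in REQUIRED if name not in columns]
    optional = [name for name in OPTIONAL if name in columns]
    return missing, optional
```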
- Out‑of‑memory (OOM): reduce `batch_size` or pass a smaller `gene_subset`.
- No GPU available: set `accelerator="cpu"` (or `"auto"`) so Lightning trains on the CPU; training will be slower.
- Optional dependencies missing (e.g., SciPy): the plotting or sampling features that need them may be disabled gracefully instead of raising an error.
- Verbosity: use `from gravity.utils import log_verbose` or the project‑level toggles.
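One way to build a smaller `gene_subset` for the OOM case is to keep the genes with the most counts. The sketch below is a hedged example: the helper `top_genes_by_counts` is hypothetical, ranking by total spliced plus unspliced counts is a simple stand-in for highly-variable-gene selection, and it assumes `gene_subset` accepts a list of gene names.

```python
import csv
from collections import defaultdict


def top_genes_by_counts(path, n_genes=500):
    """Rank genes in a long-format CSV by total spliced + unspliced counts.

    Hypothetical helper. Streams the file row by row, so memory grows with
    the number of genes rather than the number of rows. Ties are broken
    alphabetically to keep the result deterministic.
    """
    totals = defaultdict(float)
    with open(path, newline="") as handle:
        for row in csv.DictReader(handle):
            totals[row["gene_name"]] += float(row["splice"]) + float(row["unsplice"])
    ranked = sorted(totals, key=lambda gene: (-totals[gene], gene))
    return ranked[:n_genes]
```

For instance, `PipelineConfig(..., gene_subset=top_genes_by_counts("data/pancreas_long.csv", 500))`, again assuming a list of gene names is accepted.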
If this package contributes to your research, please cite the GRAVITY paper: “GRAVITY: Dynamic gene regulatory network-enhanced RNA velocity modeling for trajectory inference and biological discovery.” When possible, report the package version, environment details, and key configuration options in your methods section.
Please open issues/PRs with reproduction steps and sample commands. This subpackage is MIT‑licensed as declared in the project’s metadata.
An up‑to‑date Chinese version is available as README_zh.md in this folder.