Copyright 2024 by Jiahe Lin, Huitian Lei and George Michailidis; the paper has been accepted at TMLR, 2024 ([paper](https://openreview.net/pdf?id=kNCZ95mw7N)).
Assume anaconda/miniconda/miniforge has already been installed. To set up the environment, proceed with the following commands:

```bash
conda create -n vae-gc python=3.9
conda activate vae-gc
conda install pyyaml numpy pandas scipy scikit-learn
conda install matplotlib seaborn
pip install pytorch-lightning torch
```

See also `requirements.txt`.

To verify that your GPU is up and running:

```bash
python -c "import torch; print(torch.cuda.is_available())"
```
We outline the major components in this repository for ease of navigation.
- `bin/`: shell scripts for execution; see also the Experiments section in the paper
- `src/`
  - `networks.py`: torch-based modules whose forward passes correspond to the proposed methods. Models supported: `OneLayer` and `TwoLayer`, respectively for single-entity and multi-entity VAE-based models.
  - `models_pl.py`: pytorch-lightning-based modules that encapsulate the forward-backward propagation pipeline for running the VAE-based multi-entity/single-entity methods on a given dataset.
  - `models_pl_sim.py`: similar to the above, but specifically for synthetic datasets, where the underlying true GC graphs are known. In particular, the dataloader and graph evaluation (through `torchmetrics`) are integrated into every training step to facilitate model development and tracking.
    - Of note, the metrics printed during training do not correspond to the final metrics reported in the paper (e.g., AUROC and AUPRC). In particular, when the graph type is numeric, a Pearson-correlation-type metric between the truth and the estimates is computed at the individual sample level.
  - `datasets/`: objects with `torch.utils.data.dataset.Dataset` as the base class, which read a specific (type of) dataset from disk so that it can later be loaded through `DataLoader`; a minimal sketch of such an object is given after this list.
    - See also the demo in Run Your Own Datasets for a concrete example.
- `generator/`: scripts used for generating synthetic data
  - `simulator/`: various simulator objects for synthetic data generation in the corresponding settings
- `utils/`: utilities
  - `simrunner.py`, `realrunner.py`: wrapper functions for training models in the synthetic and real data experiments, respectively
  - `utils_data.py`: utility functions for data processing and trajectory parsing
  - `utils_eval.py`: utility functions for results evaluation
- `configs/`: data parameters (e.g., number of entities, number of nodes, trajectory length) and hyperparameters for all VAE-based methods. Naming conventions (an illustrative config sketch is given after this list):
  - no suffix: the base config; the synthetic data setting parameters are specified in its `data_params` section, and the remaining sections correspond to multi-entity learning using a node-centric decoder
  - pattern `*_edge`: edge-centric decoder
  - pattern `*_oneSub`: parameters corresponding to single-entity learning
  - One can alternatively deviate from these naming conventions: use any customized config file name and pass it with `--config` in the run command to override the defaults.
- repository root
  - `run_sim.py`: script for running synthetic data experiments using the multi-entity VAE-based method
  - `run_simOne.py`: script for running synthetic data experiments using the single-entity VAE-based method
  - `train.py`: script for running real data experiments using the multi-entity VAE-based method
  - `train_one.py`: script for running real data experiments using the single-entity VAE-based method
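As referenced in the `datasets/` entry above, the sketch below illustrates the general shape of such a dataset object. It is a hypothetical example under assumed conventions (the class name, file layout, and array shapes are placeholders), not the repository's actual implementation:

```python
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class DemoTrajectoryDataset(Dataset):
    """Hypothetical example: reads parsed trajectories stored as one .npy file per sample."""

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # assume each file holds one trajectory of shape (seq_len, num_nodes, feat_dim)
        self.files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npy"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        traj = np.load(os.path.join(self.data_dir, self.files[idx]))
        return torch.from_numpy(traj).float()


# an object like this can then be fed to a standard DataLoader, e.g.:
# loader = DataLoader(DemoTrajectoryDataset("data/demo"), batch_size=16, shuffle=True)
```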
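Likewise, the snippet below sketches the rough layout of a base (no-suffix) config and how a YAML config is typically consumed; only the `data_params` section name and the `--config` flag are taken from the description above, while every other key and value is an illustrative placeholder:

```python
import yaml

# illustrative base config: `data_params` holds the synthetic data setting,
# the remaining sections hold model/training hyperparameters (placeholder keys)
config_text = """
data_params:
  num_entities: 10
  num_nodes: 20
  traj_length: 100
model_params:
  hidden_dim: 64
  latent_dim: 16
train_params:
  lr: 1.0e-3
  max_epochs: 200
"""

config = yaml.safe_load(config_text)
print(config["data_params"]["num_nodes"])  # -> 20
```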
- Run synthetic data experiments using the VAE-based methods, including the proposed one (multi-entity learning) and its single-entity counterpart:

```bash
cd bin
mkdir -p logs
## data generation is included by default; toggle (in the shell script) to false
## if not needed (say, the data has already been generated)
## argvs: SETTING_NAME, GPU_ID, CONFIG_VERSION (defaults to none, indicating no suffix)
## choose SETTING_NAME amongst {Lorenz96, LinearVAR, NonLinearVAR, Lotka, Springs5}
bash run-simVAE.sh [SETTING_NAME] 0 &>logs/[SETTING_NAME]_run.log
```
- Evaluate a single run (a specific experiment setting and a single data seed):
```bash
## the following command should be executed in the root dir
python -u eval_sim.py --ds_str=[SETTING_NAME] --seed=0
```
- Evaluate all data replicates for a specific experiment setting:
```bash
cd bin
## argv: ds_str
bash eval_sim.sh [SETTING_NAME]
```
The data are available from rsed2017-dataverse. Once downloaded, the files should be placed under `data_real/EEG_CSV_files/`, with filenames `Subject[ID]_EO.csv` or `Subject[ID]_EC.csv`, depending on the underlying neurophysiological experiment setting.
- Prepare the raw datasets so that the long trajectories are parsed into segments that the VAE-based method can consume (a generic sketch of such windowing is given after these steps):

```bash
python -u process_EEG.py --ds_str='EEG_EC,EEG_EO'
```
- Run the experiments:

```bash
cd bin
## argv: ds_str; choose between EEG_EO and EEG_EC
bash run-real.sh EEG_EO &>logs/EO_log.log
```
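As referenced in the preparation step above, parsing a long recording into fixed-length segments can be sketched roughly as follows; the window/stride values, shapes, and function name are illustrative assumptions, not the actual logic in `process_EEG.py` / `utils_data.py`:

```python
import numpy as np

def parse_trajectory(long_traj, window=100, stride=50):
    """Split one long multivariate recording of shape (T, num_channels) into
    overlapping segments of shape (window, num_channels). Illustrative only."""
    segments = [
        long_traj[start:start + window]
        for start in range(0, long_traj.shape[0] - window + 1, stride)
    ]
    return np.stack(segments)  # shape: (num_segments, window, num_channels)

# e.g., a 10,000-step, 64-channel recording yields 199 segments of length 100
demo = parse_trajectory(np.random.randn(10000, 64))
print(demo.shape)  # (199, 100, 64)
```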
See `./demo.ipynb`. In the notebook, we generate a demo dataset and outline the steps and files required to utilize our end-to-end training pipeline.
To cite this work:
```bibtex
@article{Lin2024VAE,
  title   = {A VAE-based Framework for Learning Multi-Level Neural Granger-Causal Connectivity},
  author  = {Lin, Jiahe and Lei, Huitian and Michailidis, George},
  year    = {2024},
  journal = {Transactions on Machine Learning Research},
  issn    = {2835-8856},
  url     = {https://openreview.net/pdf?id=kNCZ95mw7N}
}
```
- For questions on the paper and/or collaborations based on the methods (extensions or applications), contact George Michailidis
- For questions on the code implementation, contact Jiahe Lin and/or Huitian Lei
We list below the repositories referenced during the development of this codebase:
- Part of the design of `Traj2GraphEncoder` and `Graph2TrajGC*Decoder` referenced the implementation in NRI-encoder-decoder
- Data generation in the synthetic data experiments is adapted from Lorenz96, LotkaVolterra and Springs5
Below we list the competitor models considered in the paper and their corresponding repositories:
- VAE-based neural relational learning that identifies edge types for a single entity
  - NRI (Kipf et al., 2018) [link to NRI paper] [link to NRI repo]
- Prediction-model-based Granger-causal estimation for single-entity learning; the code therein has been referenced
  - GVAR (Marcinkevics and Vogt, 2021) [link to GVAR paper] [link to GVAR repo]
  - NGC (Tank et al., 2021) [link to NGC paper] [link to NGC repo]
  - TCDF (Nauta et al., 2019) [link to TCDF paper] [link to TCDF repo]