# Getting Started with GRN-VAE

This document provides an end-to-end demonstration of how to infer a gene regulatory network (GRN) with our implementation of GRN-VAE.

In [1]:
import numpy as np
from data import load_beeline
from logger import LightLogger
from runner import runGRNVAE, runGRNVAE_single_opt
from evaluate import extract_edges, get_metrics
import seaborn as sns
import matplotlib.pyplot as plt

## Model Configurations

First, define the configuration for running the model. We suggest starting with the set of parameters below. The three key concepts proposed in the GRN-VAE paper are controlled by the following parameters:

- `delayed_steps_on_sparse`: Number of steps to wait before introducing the sparse loss.
- `dropout_augmentation`: The proportion of data that is randomly masked as dropout in each training step.
- `train_on_non_zero`: Whether to train the model on non-zero expression data.

Note that we also use lower learning rates over more epochs when the model is trained on non-zero data.
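To make the `dropout_augmentation` parameter more concrete, here is a minimal sketch of what masking a proportion of a batch as dropout could look like. The helper `augment_with_dropout` and the toy batch are hypothetical and are not taken from the GRN-VAE code base. The actual implementation may differ, for example in which entries are eligible for masking.

```python
import numpy as np

def augment_with_dropout(x, rate, rng):
    """Hypothetical helper: set each entry to zero with probability `rate`."""
    mask = rng.random(x.shape) < rate  # True where a value is masked
    return np.where(mask, 0.0, x), mask

rng = np.random.default_rng(0)
# Strictly positive toy batch: 64 cells x 500 genes (illustrative sizes only).
batch = rng.poisson(5.0, size=(64, 500)).astype(float) + 1.0
augmented, mask = augment_with_dropout(batch, rate=0.1, rng=rng)
print(f"masked fraction: {mask.mean():.3f}")
```

Because each entry is masked independently, the masked fraction is only approximately equal to `rate` in any single batch.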

In [2]:
configs = {
    # Train/Test split
    'train_split': 1.0,
    'train_split_seed': None, 
    
    # Neural Net Definition
    'hidden_dim': 128,
    'z_dim': 1,
    'train_on_non_zero': True,
    'dropout_augmentation': 0.1,
    'cuda': True,
    
    # Loss
    'alpha': 100,
    'beta': 1,
    'delayed_steps_on_sparse': 30,
    
    # Neural Net Training
    'batch_size': 64,
    'n_epochs': 800,
    'eval_on_n_steps': 10,
    'early_stopping': 0,
    'lr_nn': 5e-5,
    'lr_adj': 5e-6,
    'K1': 1,
    'K2': 1
}

To train the model in the original DeepSEM flavor, modify the configs above with the following settings. Note that even in this case, our implementation differs from the original DeepSEM implementation in several ways.

```
'train_on_non_zero': False
'dropout_augmentation': 0.0
'delayed_steps_on_sparse': 0

'n_epochs': 120
'lr_nn': 1e-4
'lr_adj': 2e-5
'K1': 2
```
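One way to apply these overrides without editing the original dictionary is to merge them into a copy. The sketch below uses an abbreviated copy of the GRN-VAE configs that contains only the affected keys. The names `grnvae_configs`, `deepsem_overrides`, and `deepsem_configs` are illustrative and do not come from the repository.

```python
# Abbreviated copy of the GRN-VAE configs (only the keys relevant here).
grnvae_configs = {
    'train_on_non_zero': True,
    'dropout_augmentation': 0.1,
    'delayed_steps_on_sparse': 30,
    'n_epochs': 800,
    'lr_nn': 5e-5,
    'lr_adj': 5e-6,
    'K1': 1,
    'K2': 1,
}

# DeepSEM-flavored settings listed above.
deepsem_overrides = {
    'train_on_non_zero': False,
    'dropout_augmentation': 0.0,
    'delayed_steps_on_sparse': 0,
    'n_epochs': 120,
    'lr_nn': 1e-4,
    'lr_adj': 2e-5,
    'K1': 2,
}

# Later keys win, so the overrides replace the GRN-VAE defaults.
deepsem_configs = {**grnvae_configs, **deepsem_overrides}
print(deepsem_configs)
```

Keys that are not overridden, such as `K2`, keep their GRN-VAE values, and `grnvae_configs` itself is left unchanged.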

## Data Loading

[BEELINE benchmarks](https://github.com/Murali-group/Beeline) can be loaded with the `load_beeline` function, where you specify the data directory and the benchmark to load. On first use, this function downloads the files automatically.

The `data` object returned by `load_beeline` is an [AnnData](https://anndata.readthedocs.io/en/stable/generated/anndata.AnnData.html#anndata.AnnData) object read by [scanpy](https://scanpy.readthedocs.io/en/stable/). The `ground_truth` object contains the ground truth edges from the BEELINE benchmark, but it is not required for network inference.

When you use GRN-VAE on real-world data to discover novel regulatory relationships, here are a few tips for preparing your data:

- You can read in data in any format, but make sure that genes are in the columns (`var`) and cells are in the rows (`obs`). Transpose your data if necessary.
- Select the most variable genes. Unlike many traditional algorithms, GRN-VAE can run on large amounts of data, so you can set the number of variable genes very high. As described in the paper, we used 5,000 for our Hammond experiment. The only reason for this gene filter is to help the model converge.
- Normalize your data. A simple log transformation is good enough.
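The tips above can be sketched with plain NumPy on a toy count matrix. This is only an illustration under stated assumptions: ranking genes by variance of the log-transformed values is a simple stand-in for a dedicated method such as scanpy's `sc.pp.highly_variable_genes`, and the matrix sizes and variable names are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy count matrix stored as genes x cells, i.e. the wrong orientation.
counts_genes_by_cells = rng.poisson(2.0, size=(200, 50)).astype(float)

# 1. Put cells in the rows and genes in the columns.
counts = counts_genes_by_cells.T  # shape: (50 cells, 200 genes)

# 2. Log-transform the counts; log1p maps zero counts to zero.
log_expr = np.log1p(counts)

# 3. Keep the n_top genes with the highest variance across cells.
n_top = 100
gene_var = log_expr.var(axis=0)
keep = np.sort(np.argsort(gene_var)[-n_top:])  # sorted to preserve gene order
filtered = log_expr[:, keep]
print(filtered.shape)
```

On a real dataset, `n_top` can be set much higher, as noted in the tips above.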

In [3]:
# Load data from a BEELINE benchmark
data, ground_truth = load_beeline(
    data_dir='data', 
    benchmark_data='hESC', 
    benchmark_setting='500_STRING'
)

## Model Training

Model training is simple with the `runGRNVAE` function. As noted above, if ground truth is not available, set `ground_truth` to `None`.

In [4]:
logger = LightLogger()
# runGRNVAE initializes and trains a GRNVAE model with the configs specified. 
vae = runGRNVAE(
    data.X, configs, ground_truth=ground_truth, logger=logger)

100%|██████████| 800/800 [01:41<00:00,  7.88it/s]


The learned adjacency matrix can be obtained with the `get_adj()` method. For BEELINE benchmarks, you can compute the performance metrics of this run with the `get_metrics` function.

In [5]:
A = vae.get_adj()
get_metrics(A, ground_truth)

{'AUPR': 0.053550111348527926,
 'AUPRR': 2.2263864475812216,
 'EP': 448,
 'EPR': 4.375367487418227}
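When no ground truth is available, a common next step is to rank candidate edges by the magnitude of their learned weights. The sketch below does this on a toy matrix with a hypothetical helper `top_edges`. It is not the repository's `extract_edges` function, whose signature is not shown in this document. Whether rows or columns of the matrix returned by `get_adj()` correspond to regulators is also an assumption to verify against the code base, so the helper reports `(row_gene, col_gene, weight)` without naming a direction.

```python
import numpy as np

def top_edges(adj, gene_names, k):
    """Hypothetical helper: k largest off-diagonal |weights| as (row_gene, col_gene, weight)."""
    adj = np.asarray(adj, dtype=float)
    scores = np.abs(adj)           # new array, so `adj` is not modified
    np.fill_diagonal(scores, 0.0)  # ignore self-loops
    flat_order = np.argsort(scores, axis=None)[::-1][:k]
    rows, cols = np.unravel_index(flat_order, scores.shape)
    return [(gene_names[r], gene_names[c], float(adj[r, c]))
            for r, c in zip(rows, cols)]

# Toy 3-gene adjacency matrix (made-up values).
toy_adj = np.array([
    [0.0, 0.9, -0.1],
    [0.2, 0.0, -0.7],
    [0.05, 0.3, 0.0],
])
toy_genes = ['G1', 'G2', 'G3']
edges = top_edges(toy_adj, toy_genes, k=3)
for edge in edges:
    print(edge)
```

With real data, the matrix `A` from the cell above would take the place of `toy_adj`, and the gene names could come from `data.var_names`.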