# Getting Started with GRN inference using a diffusion model

Diffusion models have been widely used in generative AI, especially in the vision domain. In our paper, we proposed RegDiffusion, a diffusion-based model for gene regulatory network (GRN) inference. Compared with previous models, RegDiffusion completes inference in a fraction of the time and yields better benchmarking results.

In this tutorial, we provide an example of running GRN inference using RegDiffusion and generating biological insights from the inferred network. 

## Requirements

We need the Python package `regdiffusion` for GRN inference. For faster inference, you may want to run `regdiffusion` on a GPU with a recent CUDA installation.

```python
import regdiffusion as rd
import numpy as np

# For displaying the interactive network visualization in the notebook
from IPython.display import HTML
```

## Data loading

The input to `regdiffusion` is a single-cell gene expression matrix in which rows are cells and columns are genes. The data should be log-transformed. Because RegDiffusion can infer GRNs among 10,000+ genes within minutes (depending on your GPU hardware), there is no need for heavy gene filtering. The only genes you may want to remove are those that are not expressed at all, i.e., genes whose total raw count across all cells is 0.
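As a rough illustration of this preprocessing, the sketch below drops genes with a zero total count and then applies `log1p` (log(1 + x)) using plain NumPy. The helper name `prepare_expression` and the choice of `log1p` as the log transform are our assumptions, not part of `regdiffusion`. If your data live in an AnnData object, scanpy's `sc.pp.filter_genes(adata, min_counts=1)` followed by `sc.pp.log1p(adata)` does the same job.

```python
import numpy as np

def prepare_expression(raw_counts, gene_names):
    """Hypothetical helper: drop never-expressed genes, then log-transform.

    Rows are cells and columns are genes. log1p is one common choice of
    log transform; it is an assumption here, not a regdiffusion requirement.
    """
    raw_counts = np.asarray(raw_counts, dtype=np.float64)
    # Keep a gene only if its total raw count across all cells is above zero
    expressed = raw_counts.sum(axis=0) > 0
    kept_names = [name for name, keep in zip(gene_names, expressed) if keep]
    # log(1 + x) maps zero counts to zero and avoids log(0)
    return np.log1p(raw_counts[:, expressed]), kept_names

counts = np.array([[0, 3, 1],
                   [0, 0, 7],
                   [0, 1, 0]])
X, genes = prepare_expression(counts, ['GeneA', 'GeneB', 'GeneC'])
print(X.shape, genes)  # (3, 2) ['GeneB', 'GeneC']
```

`GeneA` is removed because it has no counts in any cell; the remaining matrix keeps the cells-by-genes layout that the trainer expects.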

The `regdiffusion` package comes with a set of preprocessed datasets, including the [BEELINE benchmarks](https://pubmed.ncbi.nlm.nih.gov/31907445/), [Hammond microglia](https://pubmed.ncbi.nlm.nih.gov/30471926/) from adult male mice, and a labelled microglia subset from a [mouse cerebellum atlas project](https://singlecell.broadinstitute.org/single_cell/study/SCP795/a-transcriptomic-atlas-of-the-mouse-cerebellum#study-summary).

Here we use the `mESC` dataset from the BEELINE benchmark, which profiles [mouse embryonic stem cells](https://www.nature.com/articles/s41467-018-02866-0). It has 421 cells and 1,620 genes.

If you want to see the inference on a larger network with 14,000+ genes and 8,000+ cells, check out the other example. 

```python
bl_dt, bl_gt = rd.data.load_beeline(
    benchmark_data='mESC', benchmark_setting='1000_STRING'
)
```

Here, `load_beeline` returns a tuple: the first element is an AnnData object holding the single-cell expression data, and the second is an array of all ground-truth links (based on the STRING network in this case).

```python
bl_dt
```

```
AnnData object with n_obs × n_vars = 421 × 1620
    obs: 'cell_type', 'cell_type_index'
```

```python
bl_gt
```

```
array([['KLF6', 'JUN'],
       ['JUN', 'KLF6'],
       ['KLF6', 'ATF3'],
       ...,
       ['SIN3A', 'TET1'],
       ['MEF2C', 'TCF12'],
       ['TCF12', 'MEF2C']], dtype=object)
```
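The printed array already contains links in both directions (for example KLF6 to JUN and JUN to KLF6). The hypothetical helper below, written with the standard library only, summarizes such a link list, assuming the first column is the regulator and the second the target. On the real data you could call `summarize_links(bl_gt)`, since iterating over a two-column NumPy array yields one row per link.

```python
def summarize_links(links):
    """Hypothetical helper: summarize a list of [regulator, target] links."""
    # Distinct directed links; duplicates are counted once
    pairs = {(str(src), str(dst)) for src, dst in links}
    regulators = {src for src, _ in pairs}
    # A reciprocal pair has both A->B and B->A; store it once, sorted
    reciprocal = {tuple(sorted(p)) for p in pairs
                  if p[0] != p[1] and (p[1], p[0]) in pairs}
    return {'links': len(pairs), 'regulators': len(regulators),
            'reciprocal_pairs': len(reciprocal)}

example = [['KLF6', 'JUN'], ['JUN', 'KLF6'], ['KLF6', 'ATF3']]
print(summarize_links(example))
# {'links': 3, 'regulators': 2, 'reciprocal_pairs': 1}
```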

## GRN inference

We recommend training a RegDiffusion model with the provided trainer, `RegDiffusionTrainer`. Pass the expression matrix (here `bl_dt.X`) to the trainer as a NumPy array.

During training, the progress bar shows the training loss and the average change in the adjacency matrix. The model has converged when the per-step change in the adjacency matrix is near zero. By default, the `train` method runs 1,000 iterations, which should be sufficient in most cases. If you want to continue training afterwards, simply call `train` again with the desired number of iterations.
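To make the stopping criterion concrete, here is a minimal sketch that compares two consecutive snapshots of an adjacency matrix. The helpers `mean_abs_change` and `has_converged` and the tolerance `1e-4` are hypothetical. The statistic on the `regdiffusion` progress bar is computed inside the library and can be negative (`-0.000` in the training output), so it is not literally the mean absolute change used here.

```python
import numpy as np

def mean_abs_change(prev_adj, curr_adj):
    """Average absolute change between two consecutive adjacency snapshots."""
    return float(np.mean(np.abs(np.asarray(curr_adj) - np.asarray(prev_adj))))

def has_converged(prev_adj, curr_adj, tol=1e-4):
    """Treat training as converged once the step change falls below tol."""
    return mean_abs_change(prev_adj, curr_adj) < tol

rng = np.random.default_rng(0)
adj = rng.normal(size=(5, 5))
print(has_converged(adj, adj + 1e-6))  # True: tiny update
print(has_converged(adj, adj + 0.1))   # False: still changing
```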

```python
rd_trainer = rd.RegDiffusionTrainer(bl_dt.X)
rd_trainer.train()
```

```
Training loss: 0.251, Change on Adj: -0.000: 100%|██████████| 1000/1000 [00:08<00:00, 119.87it/s]
```

When ground-truth links are available, you can measure inference performance with an evaluator. You need to provide both the ground-truth links and the gene names. Note that the gene names must be in the same order as the columns of the expression table (and therefore of the inferred adjacency matrix).
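The sketch below shows why the gene order matters. A hypothetical helper turns ground-truth links into a binary matrix by looking up each gene's position in `gene_names`, assuming rows are regulators and columns are targets. A predicted adjacency matrix can only be compared cell by cell with this matrix if both use the same gene order. This is our illustration, not the internals of `GRNEvaluator`.

```python
import numpy as np

def edges_to_matrix(edges, gene_names):
    """Hypothetical helper: binary matrix with 1 at [regulator, target]."""
    index = {name: i for i, name in enumerate(gene_names)}
    truth = np.zeros((len(gene_names), len(gene_names)), dtype=np.int8)
    for src, dst in edges:
        if src in index and dst in index:  # skip genes absent from the data
            truth[index[src], index[dst]] = 1
    return truth

links = [('KLF6', 'JUN'), ('JUN', 'KLF6'), ('KLF6', 'ATF3')]
genes = ['ATF3', 'JUN', 'KLF6']
print(edges_to_matrix(links, genes))
# [[0 0 0]
#  [0 0 1]
#  [1 1 0]]
```

Passing the same links with a reordered gene list produces a permuted matrix, so a mismatch between the gene list and the adjacency columns would silently score predictions against the wrong truth.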

```python
evaluator = rd.evaluator.GRNEvaluator(bl_gt, bl_dt.var_names)
inferred_adj = rd_trainer.get_adj()
evaluator.evaluate(inferred_adj)
```

```
{'AUROC': 0.6128314776240764,
 'AUPR': 0.051967174752646526,
 'AUPRR': 2.443609451710688,
 'EP': 750,
 'EPR': 4.159291109741152}
```
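If the abbreviations are unfamiliar: AUROC and AUPR are the areas under the ROC and precision-recall curves, AUPRR is AUPR divided by that of a random predictor, and EP/EPR describe early precision, i.e. how many of the top-ranked predicted edges are true, and its ratio over random. The sketch below computes the last four on a toy ranked edge list under BEELINE-style definitions that we assume here (average precision for AUPR; the top k edges with k equal to the number of true edges for EP). `GRNEvaluator` may differ in details such as curve integration or which candidate edges it counts, so use this for intuition, not to reproduce the numbers above.

```python
import numpy as np

def ranking_metrics(scores, labels):
    """BEELINE-style summaries for a ranked edge list (illustrative sketch).

    scores: predicted strength of each candidate edge
    labels: 1 if the candidate edge is in the ground truth, else 0
    Assumes at least one true edge.
    """
    scores = np.asarray(scores, dtype=float)
    hits = np.asarray(labels, dtype=bool)[np.argsort(-scores, kind='stable')]
    n_true = int(hits.sum())
    density = n_true / hits.size  # expected precision of a random ranking
    # Average precision: mean precision at the rank of each true edge
    precision_at_rank = np.cumsum(hits) / np.arange(1, hits.size + 1)
    aupr = float(precision_at_rank[hits].mean())
    # Early precision: true edges among the top k, with k = number of true edges
    ep = int(hits[:n_true].sum())
    return {'AUPR': aupr, 'AUPRR': aupr / density,
            'EP': ep, 'EPR': (ep / n_true) / density}

m = ranking_metrics([0.9, 0.8, 0.7, 0.6, 0.5, 0.4], [1, 0, 1, 0, 0, 0])
print(m)  # AUPR ≈ 0.833, AUPRR ≈ 2.5, EP = 1, EPR ≈ 1.5
```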

## GRN object

To facilitate downstream analyses, the `regdiffusion` package defines a `GRN` object. You need to provide the gene names in the same order as in your expression table.

```python
grn = rd_trainer.get_grn(bl_dt.var_names)
grn
```

```
Inferred GRN: 1,620 TFs x 1,620 Target Genes
```

You can export the GRN object as an HDF5 file. Currently, HDF5 is the only supported export format, but more formats will be added in the future.

```python
grn.to_hdf5('demo_mESC_grn.hdf5')
```

## Inspecting the local network around particular genes

In this example, we ran GRN inference on one of the BEELINE benchmark single-cell datasets. The provided ground truth makes it possible to validate the result with standard statistical metrics. However, such ground truth is noisy and incomplete.

In our paper, we proposed a method to visualize the local 2-hop to 3-hop neighborhood around selected genes. We find that genes with similar functions tend to cluster together topologically and form clear functional groups. Inspecting these local networks gives us confidence that the inferred networks are biologically meaningful. Here we show an example of using an inferred network to make novel findings.
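The 2-hop idea can be sketched in a few lines of NumPy, assuming rows of the adjacency matrix are regulators and edge strength is the absolute weight. The hypothetical `two_hop_topk` expands the seed gene along its `k` strongest outgoing edges, then expands each newly reached gene the same way. `regdiffusion`'s own `extract_node_2hop_neighborhood` may rank or direct edges differently.

```python
import numpy as np

def two_hop_topk(adj, gene_names, seed, k):
    """Hypothetical sketch: collect the 2-hop top-k neighborhood of `seed`.

    Returns (source, target, weight) tuples, following outgoing edges only.
    """
    adj = np.asarray(adj, dtype=float)
    k = min(k, len(gene_names) - 1)  # never select more than the other genes
    index = {name: i for i, name in enumerate(gene_names)}
    frontier = [index[seed]]
    visited = {index[seed]}
    edges = []
    for _ in range(2):  # two hops away from the seed
        next_frontier = []
        for src in frontier:
            strength = np.abs(adj[src])  # a copy, so adj is not modified
            strength[src] = -np.inf      # exclude self-loops
            for dst in np.argsort(-strength, kind='stable')[:k]:
                dst = int(dst)
                edges.append((gene_names[src], gene_names[dst], float(adj[src, dst])))
                if dst not in visited:
                    visited.add(dst)
                    next_frontier.append(dst)
        frontier = next_frontier
    return edges

genes = ['A', 'B', 'C', 'D']
adj = np.array([[0.0, 0.9, 0.1, 0.0],
                [0.0, 0.0, 0.0, 0.8],
                [0.5, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.7, 0.0]])
print(two_hop_topk(adj, genes, 'A', k=1))  # [('A', 'B', 0.9), ('B', 'D', 0.8)]
```

Increasing `k` grows the neighborhood quickly (up to `k + k*k` edges), which is why a larger `k` can turn a trivial local network into a more informative one.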

### Step 1. Discover target genes

There are many ways to choose target genes for studying local networks. For example, you can focus on the most variable genes or the top up- or down-regulated genes, using any method you prefer. Here, we simply pick the gene with the strongest single regulatory link in the inferred adjacency matrix.

```python
grn.gene_names[np.argmax(grn.adj_matrix.max(1))]
```

```
'HIST1H1D'
```
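The one-liner above returns a single gene. A small extension, sketched below with a hypothetical helper, lists the top `n` candidates instead, which is handy when you want several starting points. Like the original expression, it takes the maximum over each row (`axis=1`) of the signed weights, which treats rows as regulators; that orientation is our reading of the `TFs x Target Genes` summary printed earlier.

```python
import numpy as np

def strongest_regulators(adj, gene_names, n=5):
    """Rank genes by their single strongest outgoing edge (rows = regulators)."""
    adj = np.asarray(adj, dtype=float)
    best_edge = adj.max(axis=1)  # same as adj_matrix.max(1) in the cell above
    order = np.argsort(-best_edge, kind='stable')[:n]
    return [(gene_names[i], float(best_edge[i])) for i in order]

genes = ['X', 'Y', 'Z']
adj = np.array([[0.0, 0.2, 0.1],
                [0.7, 0.0, 0.3],
                [0.4, 0.5, 0.0]])
print(strongest_regulators(adj, genes, n=2))  # [('Y', 0.7), ('Z', 0.5)]
```

The first entry is always the gene that `np.argmax(adj.max(1))` picks.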

### Step 2. Visualize the local network around the selected gene

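The `visualize_local_neighborhood` method of a `GRN` object extracts the 2-hop top-k neighborhood around a selected gene and visualizes it using `pyvis`/`vis.js`. The default `k` is 20. However, when the regulatory relationships are strong and bidirectional, `k=20` yields only a very simple network, so you may need to increase `k` to reveal meaningful structure; here we use `k=40`. To color the plot by functional group, we first extract the same neighborhood as an edge table with `extract_node_2hop_neighborhood`, embed its nodes with node2vec, and cluster the embeddings into four groups with k-means.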

```python
import networkx as nx
from sklearn.cluster import KMeans
from node2vec import Node2Vec

# Extract the 2-hop top-k neighborhood (k=40) around HIST1H1D as an edge table
adj_table = grn.extract_node_2hop_neighborhood('HIST1H1D', 40)
nxg = nx.from_pandas_edgelist(adj_table)

# Learn a 64-dimensional embedding for every node from random walks on the graph
node2vec = Node2Vec(nxg, dimensions=64, walk_length=30, num_walks=200,
                    workers=4, seed=123)
model = node2vec.fit(window=10, min_count=1, batch_words=4)

node_embeddings = [model.wv.get_vector(str(node)) for node in nxg.nodes()]

# Group nodes with similar embeddings; setting n_init explicitly avoids
# the scikit-learn FutureWarning about its changing default
kmeans = KMeans(n_clusters=4, n_init=10, random_state=0).fit(node_embeddings)
node_labels = kmeans.labels_

print("Clusters:")
for cluster_id in range(max(node_labels) + 1):
    cluster_nodes = [node for node, label in zip(nxg.nodes(), node_labels) if label == cluster_id]
    print(f"Cluster {cluster_id}: {','.join(cluster_nodes)}")
```

```
Clusters:
Cluster 0: HIST1H1D,HIST1H2BN,HIST1H2BK,HIST1H1B,HIST1H2BL,HIST1H2AK,HIST1H1A,HIST1H2AC,HIST1H2BF,HIST1H4K,HIST1H3H,HIST1H2AF,HIST1H2AI,HIST1H2AG,HIST1H2BB,DNMT1,BRCA1,KNTC1,RAD54B,GM44335,FBXO5,TAF1,ABTB1,DEK,KANK3
Cluster 1: MCM10,TIMELESS,RAD51,RBBP4,RRM2,MCM6,PCNA,E2F1,UHRF1,MCM4,MCM5,UNG,MCM7,MCM3,ZFP367,EZH2,BARD1
Cluster 2: TOP2A,MAZ,POLR3B,GM10184,ATF4
Cluster 3: GM26448,EGR1
```


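```python
# Map each gene to its k-means cluster id (assumed node_group_dict format)
gene_group_dict = {node: int(label) for node, label in zip(nxg.nodes(), node_labels)}
g = grn.visualize_local_neighborhood('HIST1H1D', k=40, node_group_dict=gene_group_dict)
HTML(g.generate_html())
```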