This is an R implementation of the following paper:
[1] Gao, M., Ding, Y., Aragam, B. (2020). A polynomial-time algorithm for learning nonparametric causal graphs (NeurIPS 2020).
If you find this code useful, please consider citing:
@inproceedings{gao2020npvar,
author = {Gao, Ming and Ding, Yi and Aragam, Bryon},
booktitle = {Advances in Neural Information Processing Systems},
title = {{A polynomial-time algorithm for learning nonparametric causal graphs}},
year = {2020}
}
NPVAR is an algorithm for learning the structure of a directed acyclic graph (DAG) G that represents a potentially high-dimensional joint distribution P(X). In other words, given n samples from P(X), NPVAR learns the DAG G that generated these samples. In general, this problem is not well-defined since the DAG is not unique; however, we consider the setting where the residual variances E[var(Xj|pa(j))] are all approximately equal. In this setting, NPVAR is a polynomial-time algorithm that provably recovers the DAG G.
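To give some intuition for why the equal-variance setting helps, the following toy sketch (base R only, not the code in NPVAR.R) simulates a two-node graph x1 -> x2 with a common noise variance. The variable names, the sin link, and the use of loess as the regression estimator are assumptions made for illustration. By the law of total variance, the source node has the smallest marginal variance, and the residual variance of x2 given its parent is close to the common value.

```r
# Toy illustration under assumed names; this is not the NPVAR implementation.
set.seed(1)
n <- 5000
sigma2 <- 0.5                                   # common residual variance (assumed)

x1 <- rnorm(n, sd = sqrt(sigma2))               # source node: var(x1) = sigma2
x2 <- sin(x1) + rnorm(n, sd = sqrt(sigma2))     # child of x1 with the same noise variance

# Marginal variances: var(x2) = var(sin(x1)) + sigma2 > var(x1)
marginal_var <- c(x1 = var(x1), x2 = var(x2))
source_node <- names(which.min(marginal_var))

# Residual variance of x2 after a nonparametric regression on x1
resid_var <- var(residuals(loess(x2 ~ x1)))

print(marginal_var)
print(source_node)
print(resid_var)
```

In this sketch the node with the smallest marginal variance is picked as the source, and conditioning on it brings the variance of the remaining node back down to roughly sigma2.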
- R
- Package np
- Package mgcv
- Package igraph
- NPVAR.R: Main function to run our algorithm, see demo below
- utils.R: Some helper functions to simulate data and evaluate results
- ANM_gp.R: File used to generate data from a Gaussian process model, see references
Generate an ER graph with 5 nodes and 5 expected edges. Then generate data from a SIN model with noise variance 0.5 (errvar = 0.5) and sample size 1000.
source('NPVAR.R')
source('utils.R')
data = data_simu(graph_type = 'ER-SIN', errvar = 0.5, d = 5, n = 1000, s0 = 1, x2 = TRUE)
X = data$X
G = data$G
X2 = data$X2
Apply NPVAR through 3 implementations:
- Naively recover the ordering node by node
- Recover nodes layer by layer with a fixed eta
- Recover nodes layer by layer, using X2 to determine eta adaptively
result1 = NPVAR(X)
result2 = NPVAR(X, layer.select = TRUE, eta = 0.01)
result3 = NPVAR(X, layer.select = TRUE, x2 = X2)
Check outputs:
print(result1)
print(result2$ancestors)
print(result2$layers)
print(result3$ancestors)
print(result3$layers)
Furthermore, infer the adjacency matrix using the significance given by gam:
est = prune(X, result1)
print(est)
print(sum(abs(est - G)))
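The last line above counts the entries on which the estimated and true adjacency matrices disagree. The standalone sketch below shows how that count behaves on small hand-built matrices. The matrices and their names are hypothetical and are not output of prune(); the code assumes only that both matrices are 0/1 and use the same orientation convention.

```r
# Hypothetical 3-node adjacency matrices (not produced by prune()).
G_true <- matrix(0, 3, 3)
G_true[1, 2] <- 1
G_true[2, 3] <- 1

# Estimate with one missing edge (2,3) and one extra edge (1,3)
G_est <- matrix(0, 3, 3)
G_est[1, 2] <- 1
G_est[1, 3] <- 1
n_diff <- sum(abs(G_est - G_true))       # one missing + one extra entry

# Estimate that reverses the edge (1,2) and keeps (2,3)
G_rev <- matrix(0, 3, 3)
G_rev[2, 1] <- 1
G_rev[2, 3] <- 1
n_diff_rev <- sum(abs(G_rev - G_true))   # a reversal changes two entries

print(n_diff)
print(n_diff_rev)
```

Under this measure a reversed edge counts twice, once for the missing entry and once for the extra one, which is worth keeping in mind when comparing against metrics that count a reversal as a single error.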