In [None]:
%matplotlib inline
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('retina')
# sns.set(rc={'figure.figsize':(11.7,8.27)})
sns.set_palette(sns.color_palette())

import pandas as pd
import pickle
from tqdm.notebook import tqdm
from plotnine import *
import igraph
import itertools
from functools import reduce
from scipy.stats import chisquare

from mcmcmd.samplers import *
from mcmcmd.tests import *

import grakel

# Helper functions

In [5]:
def b_nk(n, k, A, B, B_terms):
    """Memoised term b_{n,k} of Robinson's DAG-counting recursion.

    b_{n,n} = 1 and, for k < n,
        b_{n,k} = sum_{s=1}^{n-k} (2^k - 1)^s * 2^{k(n-k-s)} * a_{n-k,s},
    where a_{m,s} is read from A[m-1, s-1].

    Args:
        n, k: number of nodes and number of outpoints.
        A: int table whose row n-k-1 must already be filled in.
        B: memo dict (n, k) -> int total; returned unchanged if the key is present.
        B_terms: memo dict (n, k) -> int array of the individual summands.

    Returns:
        b_{n,k} as a Python int.
    """
    key = (n, k)
    if key in B:
        return B[key]

    if n == k:
        summands = onp.array([1])
    else:
        s = onp.arange(1, n - k + 1)
        summands = ((2**k - 1)**s * 2**(k*(n - k - s)) * A[n - k - 1, s - 1]).astype('int')

    B_terms[key] = summands
    B[key] = int(summands.sum())
    return B[key]

def count_DAG(num_nodes):
    """Table of labelled-DAG counts via Robinson's recursion.

    Returns an int (num_nodes, num_nodes) array `table` where table[n-1, k-1] is the
    number of DAGs with n nodes and k outpoints (roots). Entries with k > n are 0, so
    table[n-1, :].sum() is the total number of DAGs on n nodes.
    """
    table = onp.identity(num_nodes, dtype='int')
    memo, memo_terms = {}, {}

    # Rows must be filled in increasing n: b_nk reads the row for n-k.
    pairs = ((n, k) for n in range(1, num_nodes + 1) for k in range(1, n + 1))
    for n, k in pairs:
        table[n - 1, k - 1] = comb(n, k, exact=True) * b_nk(n, k, table, memo, memo_terms)

    return table

# Sample a DAG uniformly using the enumeration method from Kuipers and Moffa 2013.
def sample_uniform_DAG(num_nodes, rng=None):
    """Draw a labelled DAG on `num_nodes` nodes uniformly at random.

    Draws a rank `r` uniformly among all DAGs (counts from `count_DAG`), decodes it into a
    sequence of layer sizes `K`, fills in edges layer by layer, and relabels the nodes with
    a random permutation (Kuipers & Moffa, 2013).

    Args:
        num_nodes: number of nodes, 1 <= num_nodes <= 10 (larger counts overflow the int
            table built by `count_DAG`).
        rng: numpy Generator; a fresh `default_rng()` is used when None.

    Returns:
        (num_nodes, num_nodes) int 0/1 adjacency matrix with zero diagonal.
    """
    assert num_nodes > 0 and num_nodes <= 10 # overflow
    if rng is None:
        rng = onp.random.default_rng()

    A = count_DAG(num_nodes)

    K = [] # Layer sizes k, in the order they are decoded
    n = num_nodes

    # Rank of the target DAG in 1..(number of DAGs). Kept as a Python int and divided with
    # exact integer ceil-division below: float64 cannot represent ranks above 2**53, which
    # occur for num_nodes == 10, so float division could decode the wrong DAG.
    r = int(rng.choice(A[n-1, :].sum())) + 1

    k = 1
    while r > A[n-1, k-1]:
        r -= int(A[n-1, k-1])
        k += 1
    K.append(k)

    r = -(-r // comb(n, k, exact=True))  # exact ceil(r / C(n, k))
    m = n-k
    while m > 0:
        s = 1
        t = (2**k - 1)**s * 2**(k*(m-s)) * int(A[m-1, s-1])
        while r > t:
            r -= t
            s += 1
            t = (2**k - 1)**s * 2**(k*(m-s)) * int(A[m-1, s-1])

        # exact ceil(r / (C(m, s) * (2^k - 1)^s * 2^(k(m-s))))
        r = -(-r // (comb(m, s, exact=True) * (2**k - 1)**s * 2**(k*(m-s))))

        n = m
        k = s
        m = n-k
        K.append(k)

    # Build a strictly lower-triangular adjacency matrix layer by layer.
    Q = onp.zeros(shape=(num_nodes, num_nodes), dtype='int')
    j = K[-1]
    for i in range(len(K)-1, 0, -1):
        for l in range(j-K[i], j):
            # Every node of layer i needs at least one edge from layer i-1: redraw until non-empty.
            bln_zeroCol = True
            while bln_zeroCol:
                for row in range(j, j+K[i-1]):
                    Q[row, l] = rng.choice(2)
                    if Q[row, l] == 1:
                        bln_zeroCol = False

            # Edges from layers further out are unconstrained.
            for row in range(j+K[i-1], num_nodes):
                Q[row, l] = rng.choice(2)
        j += K[i-1]

    # Random relabelling makes every labelled DAG equally likely.
    node_labels = rng.permutation(num_nodes)
    Q = Q[node_labels, :][:, node_labels]
    return Q

# Check if DAG is cyclic
def isCyclic(adj_matrix):
    """Return True if the directed graph `adj_matrix` contains a cycle.

    A non-zero diagonal (self-loop) counts as a cycle. Otherwise nodes without outgoing
    edges (row sum 0) are pruned repeatedly; the graph is cyclic exactly when pruning
    stops with nodes still left.
    """
    if onp.diag(adj_matrix).sum() != 0:
        return True

    remaining = adj_matrix
    while True:
        has_out_edge = remaining.sum(1).astype('bool')
        if has_out_edge.all():
            return remaining.shape[0] != 0
        remaining = remaining[has_out_edge, :][:, has_out_edge]

# Sample from likelihood
def sample_DAG_data(adj_matrix, N=1, epsilon=1, rng=None):
    """Draw N samples from the linear-Gaussian model encoded by `adj_matrix`.

    Root nodes (no incoming edges, i.e. zero column sum) are N(0, epsilon^2); every other
    node j is N(sum of x[:, parents(j)], epsilon^2), where parents(j) are the rows with
    adj_matrix[:, j] == 1. Nodes are generated once all of their parents are.

    Args:
        adj_matrix: (num_nodes, num_nodes) 0/1 adjacency; adj_matrix[a, b] == 1 means a -> b.
        N: number of samples.
        epsilon: noise standard deviation.
        rng: numpy Generator; a fresh `default_rng()` is used when None.

    Returns:
        (N, num_nodes) float array.

    Raises:
        ValueError: if the graph has no root, or generation stalls because of a cycle.
    """
    if rng is None:
        rng = onp.random.default_rng()

    num_nodes = adj_matrix.shape[0]
    nodes = onp.arange(num_nodes)

    x = onp.zeros(shape=(N, num_nodes))
    node_gen = onp.zeros(num_nodes)
    node_gen_count = 0

    isRoot = ~adj_matrix.sum(0).astype('bool')
    roots = nodes[isRoot]
    if len(roots) == 0:
        raise ValueError('adj_matrix encodes a cyclic graph!')

    # children[c] > 0 marks nodes with at least one generated parent (the frontier).
    children = onp.zeros(num_nodes, dtype='int')
    for r in roots:
        x[:, r] = rng.normal(0, epsilon, size = N)
        node_gen[r] = 1
        node_gen_count += 1
        children += adj_matrix[r, :]

    while node_gen_count < num_nodes:
        progressed = False
        for child in nodes[children.astype('bool')]:
            if node_gen[child] == 1:
                raise ValueError('adj_matrix encodes a cyclic graph!')
            parents = nodes[adj_matrix[:, child] == 1]
            if node_gen[parents].sum() == len(parents):
                x[:, child] = rng.normal(x[:, parents].sum(1), epsilon, size = N)
                node_gen[child] = 1
                node_gen_count += 1
                children += adj_matrix[child, :]
                children[child] = 0
                progressed = True
        # In a DAG some frontier node always has all parents generated, so a pass without
        # progress means a cycle (previously this looped forever).
        if not progressed:
            raise ValueError('adj_matrix encodes a cyclic graph!')
    return x

# Calculate log-evidence
def log_evidence(X, adj_matrix, epsilon):
    """Gaussian log-likelihood of data `X` under the linear-Gaussian DAG model.

    Args:
        X: (n, p) data, one row per sample and one column per node.
        adj_matrix: (p, p) adjacency; column j marks the parents of node j.
        epsilon: noise standard deviation.

    Returns:
        sum over i, j of log N(X[i, j]; sum_k adj_matrix[k, j] * X[i, k], epsilon^2).
    """
    n, p = X.shape
    # Batched (p, p) @ (p, 1) per sample: parent_sums[i, j] = sum_k adj_matrix[k, j] * X[i, k]
    parent_sums = onp.matmul(adj_matrix.T[onp.newaxis, :, :], X[:, :, onp.newaxis])[:, :, 0]
    return norm.logpdf(X, loc=parent_sums, scale=epsilon).sum()

# Modify an edge
def changeEdge_DAG(adj_matrix, i, j, change_type='toggle'):
    """Return a copy of `adj_matrix` with entry (i, j) modified; the input is never mutated.

    change_type='toggle': flip adj_matrix[i, j] between 0 and 1.
    change_type='reverse': turn edge i -> j into j -> i; reversing a non-edge is a no-op.
    i == j == -1 denotes the "no change" move and returns an unmodified copy.

    Raises:
        AssertionError: unknown change_type.
        ValueError: adj_matrix[i, j] is neither 0 nor 1.
    """
    assert change_type in ['toggle', 'reverse']
    proposal = adj_matrix.copy()

    if i == -1 and j == -1:
        return proposal

    current = adj_matrix[i, j]
    if current != 1 and current != 0:
        raise ValueError('adj_matrix is non-binary')

    if change_type == 'toggle':
        proposal[i, j] = 0 if current == 1 else 1
    elif current == 1:
        proposal[i, j] = 0
        proposal[j, i] = 1
    # Reversing a non-existent edge leaves the copy unchanged.
    return proposal

# Enumerate all DAGs that can be reached by adding/deleting/reversing edges. Optionally sample one uniformly at random
def neighbors_DAG(adj_matrix, return_sample=False, rng=None):
    """Count the acyclic one-move neighbours of `adj_matrix`, optionally sampling one.

    Moves are: toggle an off-diagonal entry (i, j), reverse an existing edge (i, j), or the
    "no change" move (i = j = -1). Only moves whose result passes `isCyclic` are kept.

    Args:
        adj_matrix: (p, p) 0/1 adjacency matrix.
        return_sample: if True, also draw one move uniformly and apply it.
        rng: numpy Generator used for the draw; a fresh `default_rng()` when None.

    Returns:
        k, the number of valid moves (including "no change"); or (k, proposal) when
        return_sample is True.
    """
    if rng is None:
        rng = onp.random.default_rng()

    moves = []
    for i, j in itertools.product(onp.arange(adj_matrix.shape[0]), repeat=2):
        if i == j:
            continue
        toggled = changeEdge_DAG(adj_matrix, i, j, change_type='toggle')
        if not isCyclic(toggled):
            moves.append({'i':i, 'j':j, 'change_type':'toggle'})
        if adj_matrix[i, j] == 1:
            reversed_edge = changeEdge_DAG(adj_matrix, i, j, change_type='reverse')
            if not isCyclic(reversed_edge):
                moves.append({'i':i, 'j':j, 'change_type':'reverse'})

    # "No change" is always a valid move and is listed last.
    moves.append({'i':-1, 'j':-1, 'change_type':'toggle'})

    k = len(moves)
    if return_sample == True:
        chosen = rng.choice(moves)
        return k, changeEdge_DAG(adj_matrix, **chosen)
    return k

# Row-wise
def array_to_strings(z):
    """Encode each row of 2-D `z` as one string of its integer entries, e.g. [0, 1, 1] -> '011'.

    Entries are cast to int and concatenated without separators, so ids are only
    unambiguous for single-digit (e.g. 0/1) data. Returns an (n_rows, 1) array of strings.
    """
    digits = z.astype('int').astype('str')
    row_ids = [''.join(row) for row in digits]
    return onp.array(row_ids, dtype=str).reshape(-1, 1)

def count_sample_DAG(z):
    """Tally the distinct rows (graphs) of `z`.

    Returns:
        (vals, counts): sorted unique row ids from `array_to_strings` and how often each occurs.
    """
    unique_ids, frequencies = onp.unique(array_to_strings(z), return_counts=True)
    return unique_ids, frequencies

# Generate all possible kernel evaluations for caching
def graph_kernel_cache(num_nodes, rng=None, **kwargs):
    """Precompute random-walk kernel values between every DAG on `num_nodes` nodes.

    All DAGs are found by drawing batches of uniform samples until every one of the
    `count_DAG` DAGs has been seen.

    Args:
        num_nodes: number of nodes per graph.
        rng: optional numpy Generator forwarded to `sample_uniform_DAG`, making the
            enumeration reproducible. None keeps a fresh generator per draw.
        **kwargs: forwarded to `grakel.RandomWalk`, so the table can match a
            `graph_kernel` built with the same keyword arguments.

    Returns:
        (kernel_table, index_table). index_table maps each DAG's string id (from
        `array_to_strings`, ids in sorted order) to its row/column in kernel_table.
    """
    num_DAGs = count_DAG(num_nodes)[-1, :].sum()
    sample_size = 5*num_DAGs

    def draw_ids():
        # One batch of uniform DAGs -> sorted distinct ids seen in the batch.
        batch = onp.vstack([sample_uniform_DAG(num_nodes, rng).reshape(1, num_nodes**2) for _ in range(sample_size)])
        return count_sample_DAG(batch)[0]

    graph_ids = draw_ids()
    while len(graph_ids) != num_DAGs:
        graph_ids = onp.unique(onp.hstack([graph_ids, draw_ids()]))

    graphs = [grakel.Graph(initialization_object=onp.array(list(g)).astype('int').reshape(num_nodes, num_nodes)) for g in graph_ids]

    K = grakel.RandomWalk(**kwargs)
    index_table = dict(zip(graph_ids, onp.arange(num_DAGs))) # lookup table
    kernel_table = K.fit_transform(graphs)
    return kernel_table, index_table

# Calculate graph kernel using adjacency matrices
class graph_kernel(kernel):
    """Graph kernel between two samples of graphs given as flattened adjacency matrices.

    Each row of `X` and `Y` is a flattened (num_nodes x num_nodes) adjacency matrix. With
    caching, kernel values between all DAGs on `num_nodes` nodes are precomputed once
    (or supplied) and then looked up by each row's string id from `array_to_strings`.

    Args:
        X, Y: 2-D arrays with the same number of columns, which must be a perfect square.
        kernel_type: only 'random_walk' is supported; None means 'random_walk'.
        cache: if truthy, use lookup tables (the supplied ones, or ones built by enumerating
            all DAGs); if falsy, evaluate the grakel kernel directly.
        cache_index_table, cache_kernel_table: precomputed tables (e.g. from
            `graph_kernel_cache`); used only when both are given.
        **kwargs: forwarded to `grakel.RandomWalk`.
    """

    def __init__(self, X, Y, kernel_type='random_walk', cache=True, cache_index_table=None, cache_kernel_table=None, **kwargs):
        # Check the rank first so 1-D input fails the assertion instead of raising IndexError.
        assert len(X.shape) == 2 and len(X.shape) == len(Y.shape)
        assert X.shape[1] == Y.shape[1]
        assert int(onp.sqrt(X.shape[1]))**2 == X.shape[1] # adjacency matrix must be square
        if kernel_type is None:
            kernel_type = 'random_walk'
        assert kernel_type in ['random_walk']

        self._ids_X = array_to_strings(X).flatten()
        self._ids_Y = array_to_strings(Y).flatten()

        self._num_nodes = int(onp.sqrt(X.shape[1]))
        self._X = [grakel.Graph(initialization_object=row.reshape(self._num_nodes, self._num_nodes)) for row in X]
        self._Y = [grakel.Graph(initialization_object=row.reshape(self._num_nodes, self._num_nodes)) for row in Y]
        self._kernel_type = kernel_type
        if self._kernel_type == 'random_walk':
            self._K = grakel.RandomWalk(**kwargs)

        # Cache only when asked to. Previously cache=False still cached whenever no tables
        # were supplied. cache() builds the tables itself if none are given.
        self._cached = False
        if cache:
            self.cache(index_table=cache_index_table, kernel_table=cache_kernel_table)

    @property
    def params(self):
        """Kernel parameters exposed to the tests: the kernel type string."""
        return self._kernel_type

    def set_params(self, params):
        # Only records the type; self._K is not rebuilt (only 'random_walk' is supported).
        self._kernel_type = params

    def cache(self, index_table=None, kernel_table=None):
        """Install lookup tables for eval/f_kernel.

        Uses the given tables when both are provided; otherwise enumerates every DAG on
        `self._num_nodes` nodes and evaluates `self._K` on all pairs.
        """
        if index_table is not None and kernel_table is not None:
            self._index_table = index_table
            self._kernel_table = kernel_table
        else:
            graph_ids = self._enumerate_DAG_ids()
            graphs = [grakel.Graph(initialization_object=onp.array(list(g)).astype('int').reshape(self._num_nodes, self._num_nodes)) for g in graph_ids]

            self._index_table = dict(zip(graph_ids, onp.arange(len(graph_ids)))) # lookup table
            self._kernel_table = self._K.fit_transform(graphs)
        self._cached = True

    def _enumerate_DAG_ids(self):
        """Sorted string ids of all DAGs on `self._num_nodes` nodes, found by uniform sampling."""
        num_DAGs = count_DAG(self._num_nodes)[-1, :].sum()
        sample_size = 5*num_DAGs

        def draw_ids():
            batch = onp.vstack([sample_uniform_DAG(self._num_nodes).reshape(1, self._num_nodes**2) for _ in range(sample_size)])
            return count_sample_DAG(batch)[0]

        graph_ids = draw_ids()
        while len(graph_ids) != num_DAGs:
            graph_ids = onp.unique(onp.hstack([graph_ids, draw_ids()]))
        return graph_ids

    def eval(self):
        """Kernel matrix with entry [i, j] = k(X_i, Y_j), shape (len(X), len(Y))."""
        if not self._cached:
            # grakel's transform returns shape (len(Y), len(X)). Transpose it to match the
            # cached branch, whose rows index X.
            return self._K.fit(self._X).transform(self._Y).T
        rows = onp.array([self._index_table[id_X] for id_X in self._ids_X], dtype=int)
        cols = onp.array([self._index_table[id_Y] for id_Y in self._ids_Y], dtype=int)
        # Vectorised gather instead of a Python double loop over len(X) * len(Y) pairs.
        return onp.asarray(self._kernel_table, dtype=float)[onp.ix_(rows, cols)]

    def f_kernel(self, x, y, **kwargs):
        """Kernel value between two single flattened adjacency matrices `x` and `y` (1-D)."""
        assert len(x.shape) == len(y.shape) and len(x.shape) == 1
        if not self._cached:
            x_graph = [grakel.Graph(initialization_object=x.reshape(self._num_nodes, self._num_nodes))]
            y_graph = [grakel.Graph(initialization_object=y.reshape(self._num_nodes, self._num_nodes))]
            # transform returns a 1x1 matrix. Index it explicitly, because float() on an
            # ndim > 0 array is deprecated in NumPy.
            return float(self._K.fit(x_graph).transform(y_graph)[0, 0])
        id_x = array_to_strings(x.astype('int').reshape(1,-1))[0][0]
        id_y = array_to_strings(y.astype('int').reshape(1,-1))[0][0]
        return self._kernel_table[self._index_table[id_x], self._index_table[id_y]]

# Setup

For simplicity, assume the model parameters are fixed. Given a graph structure $\mathcal{G}$, data $\mathbf{X}$ (one row per sample, one column per node) and noise scale $\epsilon=1$, let each root node be a zero-mean normal random variable (standard normal since $\epsilon=1$)
\begin{equation}
    x_{r} \sim \mathcal{N}(0, \epsilon^2)
\end{equation}
and let each child node be normally distributed with mean equal to the sum of its parents
\begin{equation}
    x_{j}|\mathbf{pa}(x_{j}) \sim \mathcal{N}\Big(\sum_{z \in \mathbf{pa}(x_{j})} z, \epsilon^2\Big)
\end{equation}

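Because the parameters are fixed, the model evidence is simply the likelihood of the data given the graph:
\begin{equation}
    p(\mathbf{X}|\mathcal{G}) = \prod_{i=1}^{n} \prod_{j=1}^{p} p(x_{ij}|\mathbf{pa}(x_{ij}))
\end{equation}
Here a root node has an empty parent set, so its factor is just $\mathcal{N}(x_{ij}; 0, \epsilon^2)$.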

In [3]:
class linear_gaussian_sampler(model_sampler):
    """Linear-Gaussian DAG model: uniform prior over DAGs, fixed parameters, MH posterior step.

    Root nodes are N(0, epsilon^2); every other node is N(sum of its parents, epsilon^2).
    In a joint sample the adjacency entries occupy the last num_nodes**2 positions
    (`adj_matrix_indices`), after N*num_nodes data entries.

    Defaults: N=1, num_nodes=3, epsilon=1. Keyword arguments are passed to
    `model_sampler.__init__`, which presumably overrides `_N`, `_num_nodes` and `_epsilon`
    (TODO confirm against mcmcmd).
    """

    def __init__(self, **kwargs):
        self._N = 1
        self._num_nodes = 3
        self._epsilon = 1
        super().__init__(**kwargs)

        # Check inputs
        for attr in ['_N', '_num_nodes', '_epsilon']:
            assert hasattr(self, attr)

        # Accept numpy integers as well as Python ints; bool is still rejected.
        for value in (self._N, self._num_nodes):
            assert isinstance(value, (int, onp.integer)) and not isinstance(value, bool)
        assert self._epsilon > 0

    @property
    def sample_dim(self):
        return self._N*self._num_nodes + self._num_nodes**2

    @property
    def adj_matrix_indices(self):
        """Sample-vector indices of the flattened adjacency matrix (the last num_nodes**2)."""
        return onp.arange(self._N*self._num_nodes, self.sample_dim)

    @property
    def theta_indices(self):
        """Adjacency indices excluding the diagonal, which is always 0 for a DAG."""
        num_entries = self._num_nodes**2
        ind_diag = onp.arange(0, num_entries, self._num_nodes+1)
        return self.adj_matrix_indices[onp.setdiff1d(onp.arange(num_entries), ind_diag)]

    def drawPrior(self, rng=None):
        """Draw a DAG uniformly, store it as self._G, and return it flattened."""
        if rng is None:
            rng = onp.random.Generator(onp.random.MT19937())
        self._G = sample_uniform_DAG(self._num_nodes, rng)
        return self._G.reshape(1, self._num_nodes**2).flatten()

    def drawLikelihood(self, rng=None):
        """Draw N data rows given self._G, store them as self._X, and return them flattened."""
        if rng is None:
            rng = onp.random.Generator(onp.random.MT19937())
        self._X = sample_DAG_data(self._G, self._N, self._epsilon, rng)
        return self._X.reshape(1,self._N*self._num_nodes).flatten()

    def drawPosterior(self, rng=None):
        """One Metropolis-Hastings step on the DAG given self._X.

        Proposes a uniformly chosen valid neighbour G' of G (see `neighbors_DAG`) and accepts
        with probability min(1, |nb(G)| / |nb(G')| * p(X|G') / p(X|G)).
        """
        if rng is None:
            rng = onp.random.Generator(onp.random.MT19937())

        num_neighbors, proposal = neighbors_DAG(self._G, True, rng)
        num_neighbors_proposal = neighbors_DAG(proposal)
        log_MH = (onp.log(num_neighbors/num_neighbors_proposal)
                  + log_evidence(self._X, proposal, self._epsilon)
                  - log_evidence(self._X, self._G, self._epsilon))

        # Since u <= 1, comparing with min(1, MH) gives the same decision. Capping the
        # exponent at 0 avoids overflow to inf when the proposal is far more likely.
        if rng.uniform() <= onp.exp(min(0.0, log_MH)):
            self._G = proposal
        return self._G.reshape(1, self._num_nodes**2).flatten()

## Account for impossible test functions, e.g., transposed entries in adj matrix. Need to change theta_indices -> adj_matrix indices
#     def test_functions(self, samples):
#         assert samples.shape[1] >= self._num_nodes**2
#         n, p = samples.shape

#         # First, handle the adjacency matrix test functions
#         p_adj = self._num_nodes**2
#         samples_adj = samples[:, -p_adj:]
#         f1_adj = samples_adj
#         f2_adj = onp.empty([n, int((p_adj**2-p_adj)*(p_adj**2-p_adj-1))])
#         counter = 0
#         for i in range(p_adj):
#             for j in range(i+1):
#                 row_i, col_i = i//self._num_nodes + 1, i % self._num_nodes
#                 row_j, col_j = j//self._num_nodes + 1, j % self._num_nodes
#                 if row_i != col_j and row_j != col_i:
#                     f2_adj[:, counter] = f1_adj[:, i] * f1_adj[:, j]
#                     counter += 1
  
#         # Non-adjacency matrix test functions
#         p_other = max(0, p-self._num_nodes**2)
#         f1 = samples[:, :p_other]
#         f2 = onp.empty([n, p_other*(p_other+1)/2 + p_other*self._num_nodes**2])
#         counter = 0
#         for i in range(p_other):
#             for j in range(i+1):
#                 f2[:, counter] = f1[:, i] * f1[:, j]
#                 counter += 1
#         return onp.hstack([f1, f2, f1_adj, f2_adj])

# Error 1: count all graphs (rather than DAGs) reached by reversing edges
class linear_gaussian_sampler_error_1(linear_gaussian_sampler):
    """Deliberately faulty variant of `linear_gaussian_sampler`, used to check that the tests detect errors.

    Proposals are still drawn only among acyclic neighbours. However, the neighbour count `k`
    used in the Metropolis-Hastings ratio counts every candidate move regardless of
    acyclicity, so the Hastings correction no longer matches the proposal distribution.
    """
    def drawPosterior(self, rng=None):
        """Same MH step as the parent class, but using this class's faulty `neighbors_DAG` count."""
        if rng is None:
            rng = onp.random.Generator(onp.random.MT19937())
        
        num_neighbors, proposal = self.neighbors_DAG(self._G, True, rng)
        num_neighbors_proposal = self.neighbors_DAG(proposal)
        MH = num_neighbors/num_neighbors_proposal * onp.exp(log_evidence(self._X, proposal, self._epsilon) - log_evidence(self._X, self._G, self._epsilon))
#         self._MH.append(MH)
        if rng.uniform() <= MH:
            self._G = proposal
        return self._G.reshape(1, self._num_nodes**2).flatten()
    
    # Error
    def neighbors_DAG(self, adj_matrix, return_sample=False, rng=None):
        """Faulty neighbour count; the proposal list itself matches the module-level `neighbors_DAG`.

        `k` never consults `isCyclic`. The "no change" move adds 2 (once in its own branch,
        once in the i == j branch), each diagonal pair adds 1, and each off-diagonal pair adds
        1 for the toggle plus 1 if the edge exists. Hence
        k = num_nodes**2 + 2 + (number of off-diagonal entries equal to 1).

        Returns k, or (k, proposal) when return_sample is True.
        """
        if rng is None:
            rng = onp.random.default_rng()

        nodes = onp.arange(adj_matrix.shape[0])
        edges = [(i, j) for i, j in itertools.product(nodes, nodes)]
        edges += [(-1, -1)] # no change
        
        k = 0 # count neighbors (deliberately over-counted, see docstring)
        lst_proposals = []
        for i, j in edges:

            # No change
            if i == -1 and j == -1:
                k += 1
                lst_proposals.append({'i':i, 'j':j, 'change_type':'toggle'})

            # Diagonal pairs (and the (-1, -1) "no change" pair) are counted but never proposed
            if i == j:
                k += 1
                continue

            # Add DAG reached by toggling edge i,j (counted even if the result is cyclic)
            k += 1
            proposal = changeEdge_DAG(adj_matrix, i, j, change_type='toggle')
            if not isCyclic(proposal):
                lst_proposals.append({'i':i, 'j':j, 'change_type':'toggle'})
                

            # Add DAG reached by reversing edge i,j (counted even if the result is cyclic)
            if adj_matrix[i, j] == 1:
                k += 1
                proposal = changeEdge_DAG(adj_matrix, i, j, change_type='reverse')
                if not isCyclic(proposal):
                    lst_proposals.append({'i':i, 'j':j, 'change_type':'reverse'})
            

        if return_sample == True:
            args = rng.choice(lst_proposals)
            proposal = changeEdge_DAG(adj_matrix, **args)
            return k, proposal
        else:
            return k
        
# Error 2: double-count the number of DAGs reached by reversing edges (accidentally reversing non-edges and counting the result). Error does not affect sampling
# (i.e. the proposal is still drawn from the correct list of acyclic moves; only the count k used in the MH ratio is affected).
# NOTE(review): as written, `rev_count` is reset to 0 for every off-diagonal (i, j), so k exceeds the correct count by at
# most 1. That happens only when the last off-diagonal pair (num_nodes-1, num_nodes-2) is an edge whose reversal stays
# acyclic. With num_nodes == 1 no off-diagonal pair exists, `rev_count` is never assigned, and `k = ...` raises
# UnboundLocalError. Confirm which error was intended.
class linear_gaussian_sampler_error_2(linear_gaussian_sampler):
    """Deliberately faulty variant of `linear_gaussian_sampler` with an inflated neighbour count (see the comment above)."""
    def drawPosterior(self, rng=None):
        """Same MH step as the parent class, but using this class's faulty `neighbors_DAG` count."""
        if rng is None:
            rng = onp.random.Generator(onp.random.MT19937())
        
        num_neighbors, proposal = self.neighbors_DAG(self._G, True, rng)
        num_neighbors_proposal = self.neighbors_DAG(proposal)
        MH = num_neighbors/num_neighbors_proposal * onp.exp(log_evidence(self._X, proposal, self._epsilon) - log_evidence(self._X, self._G, self._epsilon))
#         self._MH.append(MH)
        if rng.uniform() <= MH:
            self._G = proposal
        return self._G.reshape(1, self._num_nodes**2).flatten()
    

    # Error
    def neighbors_DAG(self, adj_matrix, return_sample=False, rng=None):
        """Like the module-level `neighbors_DAG`, but returns k = len(proposals) + rev_count.

        The proposal list is built exactly as in the module-level function. Returns k, or
        (k, proposal) when return_sample is True.
        """
        if rng is None:
            rng = onp.random.default_rng()

        nodes = onp.arange(adj_matrix.shape[0])
        edges = [(i, j) for i, j in itertools.product(nodes, nodes)]
        edges += [(-1, -1)] # no change

        lst_proposals = []
        for i, j in edges:

            # No change
            if i == -1 and j == -1:
                lst_proposals.append({'i':i, 'j':j, 'change_type':'toggle'})

            # Skip diagonals
            if i == j:
                continue

            # Add DAG reached by toggling edge i,j
            proposal = changeEdge_DAG(adj_matrix, i, j, change_type='toggle')
            if not isCyclic(proposal):
                lst_proposals.append({'i':i, 'j':j, 'change_type':'toggle'})

            # Add DAG reached by reversing edge i,j
            # NOTE(review): re-initialised on every off-diagonal pair, so only the last pair survives
            rev_count = 0
            if adj_matrix[i, j] == 1:
                proposal = changeEdge_DAG(adj_matrix, i, j, change_type='reverse')
                if not isCyclic(proposal):
                    lst_proposals.append({'i':i, 'j':j, 'change_type':'reverse'})
                    rev_count += 1

        # Injected error: the reversal count is added on top of len(lst_proposals)
        k = len(lst_proposals) + rev_count
        if return_sample == True:
            args = rng.choice(lst_proposals)
            proposal = changeEdge_DAG(adj_matrix, **args)
            return k, proposal
        else:
            return k

In [None]:
num_nodes = 3
num_trials = 20
alpha = 0.05  # significance level shared by every test below

# Cache kernel values between all DAGs on `num_nodes` nodes once; every MMD test reuses them.
kernel_table, index_table = graph_kernel_cache(num_nodes)

lst_experiments = [linear_gaussian_sampler(num_nodes=num_nodes), linear_gaussian_sampler_error_1(num_nodes=num_nodes)]
lst_sample_size = [3000]

# Number of distinct DAGs on `num_nodes` nodes = number of categories in the chi-square test.
total_count = count_DAG(num_nodes)[num_nodes-1, :].sum()

records = []
for model in tqdm(lst_experiments):
    for n in tqdm(lst_sample_size):
        mmd_test_size = n
        # Rows passed to the MMD tests (every row while mmd_test_size == n).
        mmd_rows = onp.arange(0, int(n), int(n/mmd_test_size))
        adj_idx = model.adj_matrix_indices
        theta_idx = model.theta_indices

        rej_chisq = 0
        rej_geweke = 0
        rej_mmd_wb = 0
        rej_mmd = 0
        for _ in tqdm(range(num_trials)):
            samples_mc = model.sample_mc(n)
            samples_sc = model.sample_sc(n)
            samples_bc = model.sample_bc(n, 5)

            # Chi-square test of uniform visit frequencies over all DAGs in the `sample_sc`
            # draws. DAGs never visited are padded in with zero counts.
            _, counts_sc = count_sample_DAG(samples_sc[onp.arange(0, int(n), 1), :][:, theta_idx])
            counts_sc = onp.hstack([counts_sc, onp.zeros(total_count - len(counts_sc))])
            rej_chisq += chisquare(counts_sc).pvalue <= alpha
            rej_geweke += geweke_test(model.test_functions(samples_mc[:, theta_idx]), model.test_functions(samples_sc[:, theta_idx]), l=0.08, alpha=alpha)['result'].max()
            rej_mmd_wb += mmd_wb_test(samples_mc[:int(mmd_test_size)][:, adj_idx], samples_sc[mmd_rows, :][:, adj_idx], graph_kernel, alpha=alpha, kernel_learn_method=None, cache_index_table=index_table, cache_kernel_table=kernel_table)['result']
            rej_mmd += mmd_test(samples_mc[:int(mmd_test_size)][:, adj_idx], samples_bc[mmd_rows, :][:, adj_idx], graph_kernel, alpha=alpha, mmd_type='unbiased', kernel_learn_method=None, cache_index_table=index_table, cache_kernel_table=kernel_table)['result']

        # One row per (experiment, test): rejection rate over the trials.
        for test_type, n_rejections in [('chisq', rej_chisq), ('geweke', rej_geweke), ('mmd_wb', rej_mmd_wb), ('mmd', rej_mmd)]:
            records.append({'experiment': type(model).__name__, 'test_type': test_type, 'sample_size': n, 'result': n_rejections / num_trials})

df_results = pd.DataFrame(records, columns=['experiment', 'test_type', 'sample_size', 'result'])
df_results['result'] = pd.to_numeric(df_results['result'])

In [19]:
df_results[df_results['test_type'] == 'mmd_wb']

Unnamed: 0,experiment,test_type,sample_size,result
2,linear_gaussian_sampler,mmd_wb,3000,0.0
6,linear_gaussian_sampler_error_1,mmd_wb,3000,0.6


In [None]:
def exp_label(x):
    """Facet label for a sampler class name.

    Names containing 'error' (e.g. 'linear_gaussian_sampler_error_1') map to
    'Error <id>'. All other names map to 'No Error'.
    """
    if 'error' in x:
        # Everything after the last underscore, so multi-digit ids such as '_error_12' stay intact.
        return f"Error {x.rsplit('_', 1)[-1]}"
    return 'No Error'

# Rejection rate of each test versus sample size, one facet per sampler.
plot_results = (
    ggplot(df_results, aes(x='sample_size', y='result', color='test_type', group='test_type'))
    + geom_point()
    + geom_line()
    + facet_grid('~experiment', labeller=exp_label)
    + labs(x='Sample Size', y='Rejection Rate', title='', color='Test Type')
    + scale_y_continuous(breaks=onp.arange(0, 1.05, 0.05))
)
# Name the figure after num_nodes, as the pickled results are, instead of hard-coding '3'.
ggsave(plot_results, f'{num_nodes}_DAG.png')

In [None]:
pd.to_pickle(df_results, f'{num_nodes}_DAG')