Code for generating the results for semi-synthetic linear Gaussian graphs from [*bnlearn*](https://www.bnlearn.com/) (Figs. 11 and 17).

In [23]:
import numpy as np
import pandas as pd
import networkx as nx
import numpy
import pandas
import networkx
from itertools import combinations, permutations
import logging
from causallearn.utils.cit import CIT
import collections
import matplotlib.pyplot as plt
import pickle as pkl
from datetime import datetime
import copy

from causal_discovery.utils import *
from causal_discovery.pc_alg import PCAlgorithm
from causal_discovery.mb_by_mb import MBbyMBAlgorithm
from causal_discovery.sd_alg import SequentialDiscoveryAlgorithm
from causal_discovery.ldecc import LDECCAlgorithm

In [24]:
def get_bnlearn_graph_params(name):
  """Load the pickled parameters of a bnlearn graph.

  Args:
    name: Graph identifier; the file read is "data/bnlearn-<name>.pkl",
      relative to the current working directory.

  Returns:
    The object stored in that pickle file. Callers in this notebook unpack it
    as (G_weighted, G_std, node_names).

  Raises:
    FileNotFoundError: if the pickle file does not exist.
  """
  import pickle
  # NOTE: unpickling can execute arbitrary code; only load trusted data files.
  # The context manager closes the file handle as soon as loading finishes.
  with open("data/bnlearn-%s.pkl" % name, "rb") as params_file:
    return pickle.load(params_file)

def get_cov_matrix(G_weighted, G_std):
  """Return B @ (G_std**2) @ B.T, where B = inv(I - G_weighted).

  Args:
    G_weighted: Square weighted adjacency matrix (NumPy array).
    G_std: Noise standard deviations; squared element-wise to form the noise
      covariance. NOTE(review): the matrix products only yield a matrix when
      this is a 2-D (e.g. diagonal) array rather than a 1-D vector -- confirm
      against the pickled data.

  Returns:
    The resulting covariance matrix as a NumPy array.
  """
  num_nodes = G_weighted.shape[0]
  # (I - A)^{-1}, used on both sides of the product below.
  inv_i_minus_a = np.linalg.inv(np.eye(num_nodes) - G_weighted)
  return inv_i_minus_a @ (G_std**2) @ inv_i_minus_a.T

In [25]:
# For each supported graph name, the node names that are relabelled to "X"
# (treatment) and "Y" (outcome) elsewhere in this notebook.
graph_to_tmt_outcome = {
    "magic-niab": dict(treatment="G266", outcome="HT"),
    "magic-irri": dict(treatment="G6003", outcome="BROWN"),
}

In [30]:
# one of "magic-niab" or "magic-irri".
graph_name = "magic-niab"

In [31]:
G_weighted, G_std, node_names = get_bnlearn_graph_params(graph_name)
graph_true = adj_matrix_to_nx_graph(G_weighted, node_names)

In [32]:
graph_true = nx.relabel_nodes(graph_true, {
    graph_to_tmt_outcome[graph_name]["treatment"]: "X",
    graph_to_tmt_outcome[graph_name]["outcome"]: "Y"
})

In [29]:
data = get_dataset(G_weighted, np.diag(G_std)**2, 1000, list(graph_true.nodes()))
data.head()

Unnamed: 0,YR.GLASS,Y,YR.FIELD,MIL,FT,G418,G311,G1217,G800,G866,...,G2318,G1294,G1800,YLD,FUS,G1750,G524,G775,G2835,G43
0,0.420137,-9.412403,-0.64128,-0.5483,-4.104396,-1.234455,-0.115014,0.269181,1.177147,0.493421,...,-0.867749,-0.775289,-0.026453,-0.076053,0.333144,0.09113,0.756651,-0.503202,0.314705,0.617827
1,-0.00275,-4.332171,0.020122,-0.989963,0.95485,-0.858694,-1.436846,0.102548,1.271658,-0.096132,...,-0.243383,-1.825503,0.11041,0.632123,-1.325769,0.258811,-0.94472,-1.031226,0.257486,0.620786
2,-1.117361,3.640469,0.100454,-0.244524,-1.813873,-1.017628,0.392061,-1.361735,-0.468206,0.275902,...,-0.883446,-0.426219,-0.255114,0.077915,-0.00194,0.847443,-2.61956,0.22962,-0.641768,-0.194753
3,0.216346,3.526878,-0.247618,0.67826,1.121236,0.071027,1.054274,1.330939,0.096987,0.863503,...,-0.164306,0.591188,-0.123854,-0.655673,-1.221081,1.406668,-0.18054,0.215197,-0.161733,-0.432399
4,-0.565661,-5.829983,-0.33491,-0.453876,7.003238,0.478129,-0.347818,0.436353,-0.373046,1.388168,...,-1.933598,0.456514,-0.064923,-0.237212,-0.112231,-0.715062,0.641959,0.660831,-0.151781,-0.023642


In [None]:
sample_cov = data.cov().values

In [None]:
# Run MB-by-MB on the sampled data (no CI oracle) and report its test count.
mb_by_mb_alg = MBbyMBAlgorithm(use_ci_oracle=False)
result_mb_by_mb = mb_by_mb_alg.run(data)
print("Total CI tests done: %d" % mb_by_mb_alg.ci_test_calls["total"])

# One ATE estimate per combination returned by get_all_combinations, each
# computed from the detected treatment parents plus that combination.
ate_set_mb_by_mb = {
    get_ATE_using_nodes_and_cov(
        list(data.columns),
        list(result_mb_by_mb["tmt_parents"]) + par,
        sample_cov)
    for par in get_all_combinations(result_mb_by_mb["unoriented"],
                                    result_mb_by_mb["non_colliders"])
}
print("Estimated ATE set: %s" % str(ate_set_mb_by_mb))

In [None]:
# Run sequential discovery (SD) on the sampled data (no CI oracle) and report
# its test count.
sd_alg = SequentialDiscoveryAlgorithm(use_ci_oracle=False)
result_sd = sd_alg.run(data)
print("Total CI tests done: %d" % sd_alg.ci_test_calls["total"])

# One ATE estimate per combination returned by get_all_combinations, each
# computed from the detected treatment parents plus that combination.
ate_set_sd = {
    get_ATE_using_nodes_and_cov(
        list(data.columns),
        list(result_sd["tmt_parents"]) + par,
        sample_cov)
    for par in get_all_combinations(result_sd["unoriented"],
                                    result_sd["non_colliders"])
}
print("Estimated ATE set: %s" % str(ate_set_sd))

In [None]:
# Run LDECC with its checks enabled on the sampled data (no CI oracle) and
# report its test count.
ldecc_alg = LDECCAlgorithm(use_ci_oracle=False, ldecc_do_checks=True)
result_ldecc = ldecc_alg.run(data)
print("Total CI tests done: %d" % ldecc_alg.ci_test_calls["total"])

# One ATE estimate per combination returned by get_all_combinations, each
# computed from the detected treatment parents plus that combination.
ate_set_ldecc = {
    get_ATE_using_nodes_and_cov(
        list(data.columns),
        list(result_ldecc["tmt_parents"]) + par,
        sample_cov)
    for par in get_all_combinations(result_ldecc["unoriented"],
                                    result_ldecc["non_colliders"])
}
print("Estimated ATE set: %s" % str(ate_set_ldecc))

In [None]:
import ipyparallel as ipp

In [2]:
# Verify that ipcluster is running and import the necessary Python packages.
# ipp.Client() connects to an already-running ipyparallel cluster, so one must
# have been started (e.g. with `ipcluster start`) before this cell is executed.

parallel_client = ipp.Client(debug=False)
# View over all engines of the cluster.
dview = parallel_client[:]
# Execute an identity map in parallel.
# The first mapped element is 0, so the assert is a smoke test that the engines
# respond and results come back in order.
ar = dview.map(lambda x: x, (i for i in range(0, 20000, 2)))
assert ar.get()[0] == 0

# Import the required Python packages.
# Imports inside sync_imports() are performed locally and on the engines.
# Modules are imported under their full names (numpy, networkx, ...);
# execute_iteration builds its own np/nx aliases from these names.
with dview.sync_imports():
  import numpy
  import os
  import pandas
  import networkx
  from itertools import combinations, permutations
  import logging
  from causallearn.utils.cit import CIT
  import collections
  import copy
  import ipyparallel

  # Helpers are listed by name rather than with a wildcard import --
  # presumably so each one is defined on the engines; TODO confirm.
  from causal_discovery.utils import orient_colliders, get_dataset, get_connected_component_with_node, remove_directed_only_edges, apply_meek_rules, adj_matrix_to_nx_graph, dag_to_cpdag, get_neighbor_types, get_all_combinations, get_ATE_using_nodes_and_cov, get_ATE_using_cov, generate_local_dags_from, to_dict    
  from causal_discovery.pc_alg import PCAlgorithm
  from causal_discovery.mb_by_mb import MBbyMBAlgorithm
  from causal_discovery.sd_alg import SequentialDiscoveryAlgorithm
  from causal_discovery.ldecc import LDECCAlgorithm
  
  # Prefer cPickle when it exists (Python 2), otherwise fall back to pickle.
  try:
    from cPickle import dumps, loads, HIGHEST_PROTOCOL as PICKLE_PROTOCOL
  except ImportError:
    from pickle import dumps, loads, HIGHEST_PROTOCOL as PICKLE_PROTOCOL

# Make sure ipyparallel is still able to execute functions.
# Same identity-map smoke test as above, repeated after the imports.
dview = parallel_client[:]
ar = dview.map(lambda x: x, (i for i in range(0, 20000, 2)))
assert ar.get()[0] == 0

In [None]:
def execute_iteration(params, num_samples, seed, progress_log_filepath=None):
  """Sample one dataset and run every local-discovery algorithm on it.

  The function is pushed to the ipyparallel engines, so it only uses names
  available there: modules under their full names (numpy, networkx, os), the
  helpers imported from causal_discovery, and the pushed graph_to_tmt_outcome.

  Args:
    params: Tuple (G_weighted, G_std, node_names, graph_name, cov_matrix).
    num_samples: Number of samples drawn for this iteration.
    seed: Value passed to np.random.seed before sampling.
    progress_log_filepath: Optional path. When given, the seed is appended to
      that file (one per line) after all algorithms have finished.

  Returns:
    A dict with the keys
      "graph_true_sample_cov_ate": ATE from the true graph and sample covariance;
      "oracle": {"ate", "neighbors"} from the oracle CPDAG and `cov_matrix`;
      "oracle_sample_cov": {"ate"} from the oracle CPDAG and sample covariance;
      "ldecc", "ldecc-checks", "mb-by-mb", "sd-alg": each
        {"ci_tests", "ate", "neighbors"} for the corresponding algorithm.
  """
  np = numpy
  nx = networkx

  np.random.seed(seed)

  # Scale alpha inversely with the sample size (0.01 at num_samples=32000).
  alpha = (32000 / num_samples) * 0.01

  MAX_TESTS = 7000

  G_weighted, G_std, node_names, graph_name, cov_matrix = params
  graph_true = adj_matrix_to_nx_graph(G_weighted, node_names)
  graph_true = nx.relabel_nodes(graph_true, {
      graph_to_tmt_outcome[graph_name]["treatment"]: "X",
      graph_to_tmt_outcome[graph_name]["outcome"]: "Y"
  })
  data = get_dataset(G_weighted, np.diag(G_std)**2, num_samples,
                     list(graph_true.nodes()))
  sample_cov = data.cov().values

  # Defined inside the function so that it travels to the engines with it.
  def run_algorithm(alg):
    """Run `alg` on `data` and summarise it in the shared result layout."""
    alg_result = alg.run(data)
    ci_tests = to_dict(alg.ci_test_calls)
    tmt_par = list(alg_result["tmt_parents"])
    ate = [get_ATE_using_nodes_and_cov(list(data.columns), tmt_par + par, sample_cov)
           for par in get_all_combinations(alg_result["unoriented"],
                                           alg_result["non_colliders"])]
    return {
        "ci_tests": ci_tests,
        "ate": ate,
        "neighbors": {
            "parents": alg_result["tmt_parents"],
            "children": alg_result["tmt_children"],
            "unoriented": alg_result["unoriented"],
        },
    }

  result = {}
  result["graph_true_sample_cov_ate"] = get_ATE_using_cov(graph_true, sample_cov)

  cpdag_oracle = dag_to_cpdag(graph_true)
  result["oracle"] = {
      "ate": [get_ATE_using_cov(g, cov_matrix)
              for g in generate_local_dags_from(cpdag_oracle)],
      "neighbors": get_neighbor_types(cpdag_oracle, "X"),
  }
  result["oracle_sample_cov"] = {
      "ate": [get_ATE_using_cov(g, sample_cov)
              for g in generate_local_dags_from(cpdag_oracle)],
  }

  # Each algorithm is constructed immediately before it is run, in a fixed
  # order: LDECC, LDECC with checks, MB-by-MB, SD.
  result["ldecc"] = run_algorithm(
      LDECCAlgorithm(use_ci_oracle=False, alpha=alpha,
                     ldecc_do_checks=False, max_tests=MAX_TESTS))
  result["ldecc-checks"] = run_algorithm(
      LDECCAlgorithm(use_ci_oracle=False, alpha=alpha,
                     ldecc_do_checks=True, max_tests=MAX_TESTS))
  result["mb-by-mb"] = run_algorithm(
      MBbyMBAlgorithm(use_ci_oracle=False, alpha=alpha,
                      max_tests=MAX_TESTS))
  result["sd-alg"] = run_algorithm(
      SequentialDiscoveryAlgorithm(use_ci_oracle=False, alpha=alpha,
                                   max_tests=MAX_TESTS))

  if progress_log_filepath is not None:
    # Append with plain file I/O rather than a shell command, so paths that
    # contain spaces or shell metacharacters are handled correctly.
    # expanduser keeps "~"-prefixed paths working.
    with open(os.path.expanduser(progress_log_filepath), "a") as progress_log:
      progress_log.write("%d\n" % seed)

  np.random.seed(None)
  return result

In [None]:
def async_results_to_results(async_results, results):
  """Collect finished parallel results into `results`, keyed by sample size.

  Args:
    async_results: One async map result per entry of results["sample_sizes"],
      in the same order; each must provide a .get() returning an iterable of
      per-iteration results.
    results: Dict updated in place. Items are appended to results[N], so
      repeated calls accumulate rather than overwrite.

  Returns:
    The same `results` dict.
  """
  sizes_and_handles = zip(results["sample_sizes"], async_results)
  for sample_size, handle in sizes_and_handles:
    bucket = results.setdefault(sample_size, [])
    bucket.extend(handle.get())
  return results

def execute_estimation_in_parallel(results, start_graph, end_graph, iters_per_graph=1,
                                   graph_name="magic-niab",
                                   sample_sizes=(8000, 16000, 24000, 32000)):
  """Submit the estimation experiments to the ipyparallel engines.

  For every sample size, one execute_iteration task is submitted per
  (graph index, iteration) pair. Tasks are submitted asynchronously; use the
  returned handles to wait for and collect the results.

  Args:
    results: Dict with list-valued keys "params" and "ground_truth_ate"; both
      are appended to, and "sample_sizes" is set.
    start_graph: First graph index (inclusive).
    end_graph: Last graph index (exclusive). The same bnlearn graph is used
      for every index; only the number of repetitions depends on the range.
    iters_per_graph: Number of iterations per graph index.
    graph_name: Key of graph_to_tmt_outcome naming the bnlearn graph to load
      ("magic-niab" or "magic-irri").
    sample_sizes: Sample sizes to run, one batch of tasks per entry.

  Returns:
    (results, async_results), where async_results holds one async map result
    per sample size, in the order of `sample_sizes`.
  """
  num_threads = len(parallel_client.ids)
  dview = parallel_client[:]

  num_graphs = end_graph - start_graph
  print("Executing %d iterations across %d engines" % (
      num_graphs*iters_per_graph, num_threads))

  # Push the names execute_iteration relies on to every engine.
  dview["get_dataset"] = get_dataset
  dview["graph_to_tmt_outcome"] = graph_to_tmt_outcome
  # NOTE(review): get_noise_matrix is not imported by name anywhere in this
  # notebook; it presumably comes from the wildcard import of
  # causal_discovery.utils -- confirm.
  dview["get_noise_matrix"] = get_noise_matrix
  dview["execute_iteration"] = execute_iteration

  G_weighted, G_std, node_names = get_bnlearn_graph_params(graph_name)
  cov_matrix = get_cov_matrix(G_weighted, G_std)
  params = (G_weighted, G_std, node_names, graph_name, cov_matrix)
  graph_true = adj_matrix_to_nx_graph(G_weighted, node_names)
  graph_true = nx.relabel_nodes(graph_true, {
      graph_to_tmt_outcome[graph_name]["treatment"]: "X",
      graph_to_tmt_outcome[graph_name]["outcome"]: "Y"
  })

  RANDOM_SEED = 716697
  np.random.seed(RANDOM_SEED)

  # The ground truth depends only on the graph, so compute it once.
  ground_truth_ate = get_ATE_using_cov(graph_true, cov_matrix)

  new_param_list = []
  for _ in range(start_graph, end_graph):
    for _ in range(iters_per_graph):
      new_param_list.append(params)
      results["params"].append(params)
      results["ground_truth_ate"].append(ground_truth_ate)

  sample_sizes = list(sample_sizes)
  results["sample_sizes"] = sample_sizes
  async_results = []
  for N in sample_sizes:
    print("%s: Starting iterations %d for N=%d" % (
        str(datetime.now()), len(new_param_list), N))

    # Each task gets its own seed derived from the base seed, N and its index.
    def execute_estimation_iteration(i):
      return execute_iteration(new_param_list[i], N, (RANDOM_SEED + N + i))

    async_result = dview.map(execute_estimation_iteration, range(len(new_param_list)))
    async_results.append(async_result)

  np.random.seed(None)
  return results, async_results

In [None]:
# Containers that execute_estimation_in_parallel appends to.
results = {
    "params": [],
    "ground_truth_ate": [],
}

results, async_results = execute_estimation_in_parallel(
    results, start_graph=0, end_graph=100, iters_per_graph=1)

In [None]:
# Monitor progress of the tasks.
for ar in async_results:
  ar.wait_interactive(timeout=1)

In [None]:
results = async_results_to_results(async_results, results)

The code below plots the results.

In [None]:
# For the color map:
# https://gist.github.com/AndiH/c957b4d769e628f506bd

# Tableau 20 Colors, rescaled from 0-255 to values between 0 and 1.
tableau20 = [
    (r / 255., g / 255., b / 255.)
    for r, g, b in [
        (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
        (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
        (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
        (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
        (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229),
    ]
]

# Tableau Color Blind 10, rescaled the same way.
tableau20blind = [
    (r / 255., g / 255., b / 255.)
    for r, g, b in [
        (0, 107, 164), (255, 128, 14), (171, 171, 171), (89, 89, 89),
        (95, 158, 209), (200, 82, 0), (137, 137, 137), (163, 200, 236),
        (255, 188, 121), (207, 207, 207),
    ]
]

In [None]:
SMALL_SIZE = 8
MEDIUM_SIZE = 8
BIGGER_SIZE = 8

# (rc group, setting, value) triples, applied in order through plt.rc.
for rc_group, rc_key, rc_value in (
    ("font", "size", SMALL_SIZE),               # default text sizes
    ("axes", "titlesize", SMALL_SIZE + 6),      # axes title
    ("axes", "labelsize", MEDIUM_SIZE + 6),     # x and y labels
    ("xtick", "labelsize", SMALL_SIZE + 4),     # x tick labels
    ("ytick", "labelsize", SMALL_SIZE + 4),     # y tick labels
    ("legend", "fontsize", SMALL_SIZE + 2),     # legend
    ("figure", "titlesize", BIGGER_SIZE + 20),  # figure title
):
  plt.rc(rc_group, **{rc_key: rc_value})

In [None]:
# Per-algorithm plot style: [linestyle, color, marker]. The plotting functions
# index these lists by position, so the order of the three entries matters.
plot_prop = {
    alg_key: [linestyle, tableau20blind[color_index], marker]
    for alg_key, linestyle, color_index, marker in (
        ("pc-alg", "dashed", 0, "o"),
        ("ldecc", "dashdot", 1, "^"),
        ("ldecc-checks", "solid", 5, "s"),
        ("mb-by-mb", "dashdot", 7, "D"),
        ("sd-alg", "dotted", 3, "v"),
    )
}

In [None]:
def get_hausdorff_distance(arr_1, arr_2):
  """Return the Hausdorff distance between two collections of numbers.

  For each direction, every point of one collection is matched to its closest
  point (absolute difference) in the other, and the largest such gap is taken;
  the result is the larger of the two directed values.

  Args:
    arr_1: Sequence of numbers (converted with np.array).
    arr_2: Sequence of numbers (converted with np.array).

  Returns:
    The distance as a NumPy scalar.
  """
  points_1 = np.array(arr_1)
  points_2 = np.array(arr_2)

  def directed_distance(source, target):
    # Gap from each source point to its nearest target point; keep the largest.
    nearest_gaps = [np.abs(target - point).min() for point in source]
    return np.max(nearest_gaps)

  forward = directed_distance(points_1, points_2)
  backward = directed_distance(points_2, points_1)
  return max(forward, backward)

def plot_mse(results):
  """Plot the median squared Hausdorff error of each algorithm vs sample size.

  For every run, the error is the squared Hausdorff distance between the
  oracle ATE set (res["oracle"]["ate"]) and the algorithm's ATE set
  (res[k]["ate"]); the median over runs is plotted for each sample size.

  Args:
    results: Dict with "sample_sizes" (list of N) and, for each N, a list of
      per-iteration result dicts as produced by execute_iteration.
  """
  plt.figure(figsize=(6, 4))
  sample_sizes = results["sample_sizes"]

  dict_to_name = {
      'ldecc': "LDECC", 
      'ldecc-checks': "LDECC-checks",
      'sd-alg': "SD",
      'mb-by-mb': "MB-by-MB",
  }

  samples = np.array(sample_sizes)
  for k, label in dict_to_name.items():
    # One row per sample size, one column per run.
    squared_errors = np.array([
        [get_hausdorff_distance(res["oracle"]["ate"], res[k]["ate"])**2
         for res in results[N]]
        for N in sample_sizes
    ])
    median_errors = np.median(squared_errors, axis=-1)

    plt.plot(samples, median_errors, label=label,
             color=plot_prop[k][1], linestyle=plot_prop[k][0],
             marker=plot_prop[k][2])
  
  plt.title("Median SE vs Sample size")
  plt.xlabel("Sample size")
  plt.ylabel("Median SE (Hausdorff distance)")
  plt.legend()
  plt.show()
  
plot_mse(results)

In [None]:
def plot_accuracy(results):
  """Plot, per algorithm, the percentage of runs whose detected neighbors of
  the treatment agree with the oracle's, against sample size.

  Args:
    results: Dict with "sample_sizes" (list of N) and, for each N, a list of
      per-iteration result dicts as produced by execute_iteration. The number
      of runs is taken from the first sample size.
  """

  dict_to_name = {
      'ldecc': "LDECC", 
      'ldecc-checks': "LDECC-checks",
      'sd-alg': "SD",
      'mb-by-mb': "MB-by-MB",
  }

  def get_detected_neighbors(r):
    """Return the (parents, children, unoriented) entries of r["neighbors"]."""
    neighbors = r["neighbors"]
    return neighbors["parents"], neighbors["children"], neighbors["unoriented"]

  def is_correct_adj_set(pa, ch, uo, pa_or, ch_or, uo_or):
    """Compare detected neighbors (pa, ch, uo) with the oracle's (*_or)."""
    # When neither side has unoriented neighbors, only the parents are compared.
    if uo == set() and uo == uo_or:
      return pa == pa_or

    return pa == pa_or and ch == ch_or and uo == uo_or

  samples = results["sample_sizes"]
  N = len(results[samples[0]])
  print("Total = %d" % N)

  # correct_count[algorithm][sample size] -> number of runs judged correct.
  correct_count = {k: {s: 0 for s in samples} for k in dict_to_name}

  for s in samples:
    for i in range(N):
      run = results[s][i]
      parents_or, children_or, uo_or = get_detected_neighbors(run["oracle"])

      for k in dict_to_name:
        parents, children, uo = get_detected_neighbors(run[k])
        if is_correct_adj_set(parents, children, uo, parents_or, children_or, uo_or):
          correct_count[k][s] += 1

  plt.figure(figsize=(6, 4))

  for k, label in dict_to_name.items():
    accuracy = np.array([correct_count[k][s] * 100 / N for s in samples])

    plt.plot(samples, accuracy, label=label,
             color=plot_prop[k][1], linestyle=plot_prop[k][0],
             marker=plot_prop[k][2])
  
  plt.title("Accuracy vs Sample size")
  plt.xlabel("Sample size")
  plt.ylabel("Accuracy (%)")
  plt.legend()
  plt.show()

plot_accuracy(results)

In [None]:
def plot_recall(results, rtol=1e-9, atol=1e-12):
  """Plot, per algorithm, the percentage of runs whose estimated ATE set
  contains the ATE of the true graph (under the sample covariance).

  Membership is tested with a small floating-point tolerance instead of exact
  equality, so that values differing only by rounding are still matched. Pass
  rtol=0 and atol=0 for exact matching.

  Args:
    results: Dict with "sample_sizes" (list of N) and, for each N, a list of
      per-iteration result dicts as produced by execute_iteration. The number
      of runs is taken from the first sample size.
    rtol: Relative tolerance forwarded to np.isclose.
    atol: Absolute tolerance forwarded to np.isclose.
  """

  dict_to_name = {
      'ldecc': "LDECC", 
      'ldecc-checks': "LDECC-checks",
      'sd-alg': "SD",
      'mb-by-mb': "MB-by-MB",
  }

  def contains_ate(ate_values, target):
    """True when some entry of ate_values equals target up to the tolerances."""
    return any(np.isclose(ate, target, rtol=rtol, atol=atol)
               for ate in ate_values)

  samples = results["sample_sizes"]
  N = len(results[samples[0]])

  # correct_count[algorithm][sample size] -> number of runs covering the truth.
  correct_count = {k: {s: 0 for s in samples} for k in dict_to_name}

  for s in samples:
    for i in range(N):
      run = results[s][i]
      for k in dict_to_name:
        if contains_ate(run[k]["ate"], run["graph_true_sample_cov_ate"]):
          correct_count[k][s] += 1

  plt.figure(figsize=(6, 4))

  for k, label in dict_to_name.items():
    coverage = np.array([correct_count[k][s] * 100 / N for s in samples])

    plt.plot(samples, coverage, label=label,
             color=plot_prop[k][1], linestyle=plot_prop[k][0],
             marker=plot_prop[k][2])
  
  plt.title("Recall vs Sample size")
  plt.xlabel("Sample size")
  plt.ylabel("Recall (%)")
  plt.legend()
  plt.show()

plot_recall(results)

In [None]:
def plot_number_of_tests(results):
  """Plot the mean number of CI tests per run of each algorithm vs sample size.

  Args:
    results: Dict with "sample_sizes" (list of N) and, for each N, a list of
      per-iteration result dicts whose [algorithm]["ci_tests"]["total"] entry
      holds the number of tests of that run.
  """

  plt.figure(figsize=(6, 4))

  dict_to_name = {
      'ldecc': "LDECC", 
      'ldecc-checks': "LDECC-checks",
      'sd-alg': "SD",
      'mb-by-mb': "MB-by-MB",
  }

  samples = np.array(results["sample_sizes"])

  # Average the per-run totals first, then draw one line per algorithm.
  mean_tests = {
      k: [np.mean([run[k]["ci_tests"]["total"] for run in results[s]])
          for s in samples]
      for k in dict_to_name
  }

  for k, label in dict_to_name.items():
    plt.plot(samples, mean_tests[k], label=label,
             color=plot_prop[k][1], linestyle=plot_prop[k][0],
             marker=plot_prop[k][2])

  plt.title("Number of CI tests")
  plt.xlabel("Sample size")
  plt.ylabel("Number of CI tests")
  plt.legend()
  plt.show()

plot_number_of_tests(results)

The code below plots the distribution of the number of conditional independence (CI) tests performed by SD, LDECC, and MB-by-MB when given a CI oracle, treating each node in turn as the treatment. For computational reasons, we cap the number of tests per node at 20,000.

In [34]:
def run_local_discovery_for_each_node(graph_true, max_tests=20000):
  """Count the CI tests each algorithm performs with every node as treatment.

  All three algorithms are run with a CI oracle built from `graph_true`, so
  they are handed an empty DataFrame that only carries the node names as
  columns.

  Args:
    graph_true: Graph whose nodes are used, one at a time, as the treatment.
    max_tests: Value passed to each algorithm as `max_tests`.

  Returns:
    Dict mapping each node to {"mb-by-mb": ..., "ldecc": ..., "sd-alg": ...},
    the "total" CI test count reported by the respective algorithm.
  """
  node_to_tests = {}

  for node in graph_true.nodes():
    empty_df = pd.DataFrame(columns=list(graph_true.nodes()))

    mb_by_mb_alg = MBbyMBAlgorithm(use_ci_oracle=True, graph_true=graph_true,
                                   treatment_node=node, max_tests=max_tests)
    mb_by_mb_alg.run(empty_df)

    sd_alg = SequentialDiscoveryAlgorithm(use_ci_oracle=True, graph_true=graph_true,
                                          treatment_node=node, max_tests=max_tests)
    sd_alg.run(empty_df)

    # NOTE(review): the outcome node is set to the treatment node itself here;
    # presumably intentional since only test counts are collected -- confirm.
    ldecc_alg = LDECCAlgorithm(use_ci_oracle=True, graph_true=graph_true,
                               treatment_node=node, outcome_node=node,
                               max_tests=max_tests)
    ldecc_alg.run(empty_df)

    node_to_tests[node] = {
        "mb-by-mb": mb_by_mb_alg.ci_test_calls["total"],
        "ldecc": ldecc_alg.ci_test_calls["total"],
        "sd-alg": sd_alg.ci_test_calls["total"],
    }

    print("Node done: %s, Tests: %s" % (node, node_to_tests[node]))

  return node_to_tests
  
node_to_tests = run_local_discovery_for_each_node(graph_true)

In [None]:
def plot_test_statistics(node_to_tests):
  """Plot a histogram of per-node CI test counts for MB-by-MB, SD and LDECC.

  Args:
    node_to_tests: Dict mapping each node to a dict with the keys "mb-by-mb",
      "sd-alg" and "ldecc" holding that algorithm's number of CI tests.
  """
  plt.figure(figsize=(6, 4))

  # (key in node_to_tests / plot_prop, legend label), in plotting order.
  algorithms = [("mb-by-mb", "MB-by-MB"), ("sd-alg", "SD"), ("ldecc", "LDECC")]

  test_counts = [[per_node[alg_key] for per_node in node_to_tests.values()]
                 for alg_key, _ in algorithms]
  labels = [alg_label for _, alg_label in algorithms]
  colors = [plot_prop[alg_key][1] for alg_key, _ in algorithms]

  plt.hist(test_counts, label=labels, color=colors)
  plt.legend()
  plt.title("Distribution of CI tests")
  plt.xlabel("Number of CI tests")
  plt.ylabel("Number of nodes")
  plt.show()


plot_test_statistics(node_to_tests)