In [1]:
import causaldag as cd
from strategies.simulator import IterationData
import numpy as np
import os
from config import DATA_FOLDER
import itertools as itr
from utils import graph_utils
# Root output folder for everything this notebook writes to disk. It is
# created up front; exist_ok=True keeps the cell re-runnable once the folder
# already exists.
samples_folder = os.path.join(DATA_FOLDER, 'check-interventions', 'samples')
os.makedirs(samples_folder, exist_ok=True)
from analysis.check_gies import get_parent_probs_by_dag, get_l1_score

from tqdm import tqdm
from strategies.information_gain import create_info_gain_strategy

In [2]:
# --- Experiment configuration -------------------------------------------
# Graph: `p` nodes labelled 0..p-1; `s` is the second argument handed to
# cd.rand.directed_erdos (presumably the edge density -- confirm).
p = 10
s = 0.5

# Node whose parent set is examined in the bootstrapped DAGs; it is also the
# one node that is never intervened on.
target = int(np.ceil(p / 2))

# Each intervention clamps a node to -/+ IV_STRENGTH times its standard
# deviation, and `n` samples are drawn per interventional dataset.
IV_STRENGTH = 2
n = 2000

# First argument to graph_utils.run_gies_boot (presumably the number of
# bootstrap replicates -- confirm).
N_BOOT = 500

target

5

In [3]:
# Draw a random DAG, give every arc a weight drawn from
# graph_utils.RAND_RANGE(), and build the GaussDAG that is treated as the
# ground-truth model for the rest of the notebook.
# NOTE(review): no random seed is set anywhere in view, so every run
# produces a different graph and different samples.
dag = cd.rand.directed_erdos(p, s)
arcs = {
    (u, v): graph_utils.RAND_RANGE()
    for u, v in dag.arcs
}
gdag = cd.GaussDAG(
    nodes=list(range(p)),
    arcs=arcs,
)

# 250 observational samples, reused by every experiment below.
obs_samples = gdag.sample(250)

In [4]:
# Candidate interventions: every node except `target`. The two lists are kept
# index-aligned, i.e. interventions[k] is the intervention for
# intervention_nodes[k].
intervention_nodes = [node for node in range(p) if node != target]

# Each intervention is a BinaryIntervention over two constants placed
# IV_STRENGTH standard deviations below and above zero; the standard
# deviation comes from the diagonal of the model covariance.
interventions = [
    cd.BinaryIntervention(
        intervention1=cd.ConstantIntervention(val=-IV_STRENGTH*std),
        intervention2=cd.ConstantIntervention(val=IV_STRENGTH*std),
    )
    for node, std in enumerate(np.diag(gdag.covariance)**.5)
    if node != target
]

In [5]:
# Maps an intervened node (or -1 for the observational-only run) to the list
# of DAGs loaded from the GIES bootstrap for that dataset.
ivs2dags = dict()

In [6]:
# Baseline: run the GIES bootstrap on the observational samples alone and
# store the loaded DAGs under key -1 (the same key used for the observational
# data passed to _write_data).
folder = os.path.join(samples_folder, 'observational')
# exist_ok=True matches how samples_folder itself is created and lets this
# cell be re-run; a bare makedirs raises FileExistsError once the folder is
# on disk from an earlier execution.
# NOTE(review): assumes _write_data / run_gies_boot overwrite whatever an
# earlier run left in this folder -- confirm, otherwise clear it first.
os.makedirs(folder, exist_ok=True)
samples_file = os.path.join(folder, 'samples.csv')
interventions_folder = os.path.join(folder, 'interventions')
gies_dags_folder = os.path.join(folder, 'gies_dags/')

graph_utils._write_data({-1: obs_samples}, samples_file, interventions_folder)
graph_utils.run_gies_boot(N_BOOT, samples_file, interventions_folder, gies_dags_folder)
amats, dags = graph_utils._load_dags(gies_dags_folder)

ivs2dags[-1] = dags

In [7]:
# For every candidate intervention: draw `n` interventional samples, combine
# them with the shared observational samples, run the GIES bootstrap on the
# combined data and record the loaded DAGs under the intervened node.
for iv_node, intervention in tqdm(zip(intervention_nodes, interventions), total=len(interventions)):
    samples = gdag.sample_interventional({iv_node: intervention}, n)
    all_samples = {-1: obs_samples, iv_node: samples}
    folder = os.path.join(samples_folder, 'iv=%d' % iv_node)
    # exist_ok=True so the loop can be re-run; a bare makedirs raises
    # FileExistsError on the first folder that already exists.
    # NOTE(review): assumes _write_data / run_gies_boot overwrite files left
    # by an earlier run -- confirm, otherwise clear the folder first.
    os.makedirs(folder, exist_ok=True)

    samples_file = os.path.join(folder, 'samples.csv')
    interventions_folder = os.path.join(folder, 'interventions')
    gies_dags_folder = os.path.join(folder, 'gies_dags/')
    graph_utils._write_data(all_samples, samples_file, interventions_folder)
    graph_utils.run_gies_boot(N_BOOT, samples_file, interventions_folder, gies_dags_folder)

    amats, dags = graph_utils._load_dags(gies_dags_folder)
    ivs2dags[iv_node] = dags

100%|██████████| 9/9 [01:18<00:00,  8.78s/it]


In [8]:
def _target_parent_probs(boot_dags, nodes, target_node):
    """Fraction of `boot_dags` in which each node is a parent of `target_node`.

    Parameters
    ----------
    boot_dags : sized collection of DAG objects exposing `parents[node]`.
    nodes : iterable of node labels; every one gets an entry in the result,
        with 0 for nodes that never appear as a parent.
    target_node : node whose parent sets are tallied.

    Returns
    -------
    dict mapping node -> count / len(boot_dags).

    Raises
    ------
    ZeroDivisionError if `boot_dags` is empty and `nodes` is not.
    """
    counts = {node: 0 for node in nodes}
    for boot_dag in boot_dags:
        for parent in boot_dag.parents[target_node]:
            counts[parent] += 1
    return {node: count / len(boot_dags) for node, count in counts.items()}


# The tally lives in a helper so its loop variables stay local. Written as
# module-level loops over `dag` and `p`, it overwrote the globals `dag` and
# `p` (the node count), which breaks earlier cells if they are re-run.
ivs2parent_probs = {
    iv_node: _target_parent_probs(boot_dags, gdag.nodes, target)
    for iv_node, boot_dags in ivs2dags.items()
}

In [9]:
# Score every dataset (key = intervened node, -1 = observational only) by
# passing its parent probabilities, the true model and the target node to
# get_l1_score.
ivs2scores = {
    node: get_l1_score(probs, gdag, target)
    for node, probs in ivs2parent_probs.items()
}

In [10]:
# Display the get_l1_score result per dataset; key -1 is the
# observational-only run, every other key is the intervened node.
ivs2scores

{-1: 2.522,
 0: 1.6,
 1: 4.556,
 2: 4.73,
 3: 0.034000000000000016,
 4: 0.446,
 6: 1.526,
 7: 0.08800000000000004,
 8: 0.842,
 9: 1.556}

In [11]:
# Show the string form of the ground-truth model so the scores can be read
# against the true graph (entries look like `[node|parents]` -- presumably).
str(gdag)

'[4][0][1|0][2|1][6|0,1,2,4][8|6][5|2][3|0,2][7|0,2,3,4,5][9|2,4,5,7,8]'

In [12]:
def get_parent_functional(parent, target_node=None):
    """Build a predicate testing whether `parent` is a parent of a node in a DAG.

    Parameters
    ----------
    parent : node label to look for in the parent set.
    target_node : node whose parent set is inspected. When None (the default)
        the module-level `target` is used and looked up each time the
        predicate is called, which is the original behaviour.

    Returns
    -------
    callable taking a `dag` (any object exposing `dag.parents[node]`) and
    returning True if `parent` is in that node's parent set, else False.
    """
    def parent_functional(dag):
        # `is None` rather than truthiness, so node 0 is a valid target_node.
        node = target if target_node is None else target_node
        return parent in dag.parents[node]
    return parent_functional

# One "is this node a parent of the target?" predicate per non-target node of
# the true model.
parent_functionals = [
    get_parent_functional(candidate)
    for candidate in gdag.nodes
    if candidate != target
]

In [15]:
# Ask the information-gain strategy which single intervention to perform,
# given only the observational data and the parent-membership functionals.
# NOTE(review): the meaning of the first argument (100) is not visible here;
# the outer progress bar in the output counts to 100 -- confirm.
info_strat = create_info_gain_strategy(100, parent_functionals)
info_strat_folder = os.path.join(samples_folder, 'info/')
# exist_ok=True so the cell can be re-run; a bare makedirs raises
# FileExistsError once the folder exists from an earlier execution.
# NOTE(review): assumes the strategy tolerates or overwrites files left in
# batch_folder by an earlier run -- confirm.
os.makedirs(info_strat_folder, exist_ok=True)
iteration_data = IterationData(
    current_data={-1: obs_samples},
    max_interventions=1,
    n_samples=n,
    batch_num=1,
    n_batches=1,
    intervention_set=intervention_nodes,
    interventions=interventions,
    batch_folder=info_strat_folder,
    precision_matrix=gdag.precision
)
info_strat(iteration_data)

COLLECTING DATA POINTS


  0%|          | 0/100 [00:00<?, ?it/s]
  0%|          | 0/9 [00:00<?, ?it/s][A

CALCULATING LOG PDFS



 11%|█         | 1/9 [00:00<00:02,  3.03it/s][A
 22%|██▏       | 2/9 [00:00<00:02,  2.99it/s][A
 33%|███▎      | 3/9 [00:01<00:02,  2.97it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  3.01it/s][A
 56%|█████▌    | 5/9 [00:01<00:01,  3.06it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.03it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  2.99it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  2.97it/s][A
100%|██████████| 9/9 [00:03<00:00,  2.96it/s][A
  1%|          | 1/100 [00:03<04:59,  3.02s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:02,  3.04it/s][A
 22%|██▏       | 2/9 [00:00<00:02,  3.00it/s][A
 33%|███▎      | 3/9 [00:01<00:02,  2.97it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  3.01it/s][A
 56%|█████▌    | 5/9 [00:01<00:01,  3.06it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.03it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  2.99it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  2.97it/s][A
100%|██████████| 9/9 [00:03<00:00,  2.95it/s][A
  2%|▏         | 2/100 [00:0

 56%|█████▌    | 5/9 [00:01<00:01,  3.06it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.04it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  3.00it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  2.97it/s][A
100%|██████████| 9/9 [00:03<00:00,  2.96it/s][A
 16%|█▌        | 16/100 [00:48<04:13,  3.02s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:02,  3.04it/s][A
 22%|██▏       | 2/9 [00:00<00:02,  3.00it/s][A
 33%|███▎      | 3/9 [00:01<00:02,  2.97it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  3.00it/s][A
 56%|█████▌    | 5/9 [00:01<00:01,  3.05it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.03it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  2.99it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  2.97it/s][A
100%|██████████| 9/9 [00:03<00:00,  2.96it/s][A
 17%|█▋        | 17/100 [00:51<04:10,  3.02s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:02,  3.04it/s][A
 22%|██▏       | 2/9 [00:00<00:02,  2.99it/s][A
 33%|███▎      | 3/9 [00:01<00:02,  

100%|██████████| 9/9 [00:03<00:00,  2.95it/s][A
 31%|███       | 31/100 [01:33<03:28,  3.03s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:02,  3.04it/s][A
 22%|██▏       | 2/9 [00:00<00:02,  3.00it/s][A
 33%|███▎      | 3/9 [00:01<00:02,  2.98it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  3.01it/s][A
 56%|█████▌    | 5/9 [00:01<00:01,  3.06it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.04it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  3.00it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  2.97it/s][A
100%|██████████| 9/9 [00:03<00:00,  2.96it/s][A
 32%|███▏      | 32/100 [01:36<03:25,  3.02s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:02,  3.04it/s][A
 22%|██▏       | 2/9 [00:00<00:02,  3.01it/s][A
 33%|███▎      | 3/9 [00:01<00:02,  2.98it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  3.01it/s][A
 56%|█████▌    | 5/9 [00:01<00:01,  3.06it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.04it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  

 22%|██▏       | 2/9 [00:00<00:02,  3.13it/s][A
 33%|███▎      | 3/9 [00:00<00:01,  3.11it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  3.15it/s][A
 56%|█████▌    | 5/9 [00:01<00:01,  3.20it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.18it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  3.13it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  3.11it/s][A
100%|██████████| 9/9 [00:02<00:00,  3.09it/s][A
 47%|████▋     | 47/100 [02:20<02:33,  2.89s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:02,  3.18it/s][A
 22%|██▏       | 2/9 [00:00<00:02,  3.14it/s][A
 33%|███▎      | 3/9 [00:00<00:01,  3.11it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  3.15it/s][A
 56%|█████▌    | 5/9 [00:01<00:01,  3.20it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.18it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  3.14it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  3.11it/s][A
100%|██████████| 9/9 [00:02<00:00,  3.09it/s][A
 48%|████▊     | 48/100 [02:23<02:30,  2.89s/it]
  0%|          | 0/9 [00:00<

 67%|██████▋   | 6/9 [00:01<00:00,  3.17it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  3.13it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  3.11it/s][A
100%|██████████| 9/9 [00:02<00:00,  3.09it/s][A
 62%|██████▏   | 62/100 [03:03<01:49,  2.89s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:02,  3.18it/s][A
 22%|██▏       | 2/9 [00:00<00:02,  3.14it/s][A
 33%|███▎      | 3/9 [00:00<00:01,  3.11it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  3.15it/s][A
 56%|█████▌    | 5/9 [00:01<00:01,  3.20it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.18it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  3.13it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  3.11it/s][A
100%|██████████| 9/9 [00:02<00:00,  3.09it/s][A
 63%|██████▎   | 63/100 [03:06<01:46,  2.89s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:02,  3.18it/s][A
 22%|██▏       | 2/9 [00:00<00:02,  3.14it/s][A
 33%|███▎      | 3/9 [00:00<00:01,  3.11it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  

 77%|███████▋  | 77/100 [03:46<01:06,  2.89s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:02,  3.14it/s][A
 22%|██▏       | 2/9 [00:00<00:02,  3.11it/s][A
 33%|███▎      | 3/9 [00:00<00:01,  3.09it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  3.14it/s][A
 56%|█████▌    | 5/9 [00:01<00:01,  3.19it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.17it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  3.13it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  3.10it/s][A
100%|██████████| 9/9 [00:02<00:00,  3.09it/s][A
 78%|███████▊  | 78/100 [03:49<01:03,  2.89s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:02,  3.18it/s][A
 22%|██▏       | 2/9 [00:00<00:02,  3.14it/s][A
 33%|███▎      | 3/9 [00:00<00:01,  3.11it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  3.15it/s][A
 56%|█████▌    | 5/9 [00:01<00:01,  3.20it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.17it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  3.13it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  

 33%|███▎      | 3/9 [00:00<00:01,  3.11it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  3.15it/s][A
 56%|█████▌    | 5/9 [00:01<00:01,  3.20it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.17it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  3.13it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  3.10it/s][A
100%|██████████| 9/9 [00:02<00:00,  3.09it/s][A
 93%|█████████▎| 93/100 [04:33<00:20,  2.89s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:02,  3.18it/s][A
 22%|██▏       | 2/9 [00:00<00:02,  3.14it/s][A
 33%|███▎      | 3/9 [00:00<00:01,  3.11it/s][A
 44%|████▍     | 4/9 [00:01<00:01,  3.15it/s][A
 56%|█████▌    | 5/9 [00:01<00:01,  3.19it/s][A
 67%|██████▋   | 6/9 [00:01<00:00,  3.16it/s][A
 78%|███████▊  | 7/9 [00:02<00:00,  3.12it/s][A
 89%|████████▉ | 8/9 [00:02<00:00,  3.10it/s][A
100%|██████████| 9/9 [00:02<00:00,  3.09it/s][A
 94%|█████████▍| 94/100 [04:36<00:17,  2.89s/it]
  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:02,  

COLLECTING SAMPLES


defaultdict(int, {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 2000, 7: 0, 8: 0})