### Dataset to use — change this value to analyze a different dataset

In [2]:
# Name of the dataset folder to analyze, and the number of nodes in each of its
# DAGs (the posterior cell reads one sample file per node 0..nnodes-1).
dataset, nnodes = 'chain_test5', 10

In [3]:
from analysis import check_folder
# Project helper run on the selected dataset before anything is loaded;
# presumably a sanity check of the dataset's folder contents — TODO confirm.
# Its displayed result in this run is an empty list.
check_folder.check_folder(dataset)

[]

In [4]:
import pandas as pd

def order_tpr_fpr(fprs, tprs):
    """Arrange (FPR, TPR) points into an ordered curve ready for plotting.

    Parameters
    ----------
    fprs, tprs : sequence of float
        Parallel sequences of false- and true-positive rates, one pair per point.

    Returns
    -------
    pandas.DataFrame
        Columns ``fpr`` and ``tpr``, sorted by decreasing FPR, with one row per
        distinct FPR value and a fresh ``0..m-1`` index. Among points that share
        an FPR, the one appearing last in the input is kept.
    """
    return (
        pd.DataFrame(dict(fpr=fprs, tpr=tprs))
        # sort_values returns a new frame; the old code discarded that result,
        # so the points were never ordered. A stable sort keeps ties in input
        # order, which makes keep='last' below deterministic.
        .sort_values('fpr', ascending=False, kind='mergesort')
        .drop_duplicates(subset='fpr', keep='last')
        .reset_index(drop=True)
    )


a = order_tpr_fpr([.2, .1, .2], [.3, .2, .4])

### Load the true DAGs and compute MEC posteriors, parent probabilities, and rates

In [5]:
from analysis import check_gies, check_samples
import numpy as np
import itertools as itr
from collections import defaultdict
import xarray as xr
from tqdm import tqdm
import os
import causaldag as cd
from utils import graph_utils
from scipy.special import logsumexp
# Intervention magnitude: each node's two constant interventions are set to
# -IV_STRENGTH and +IV_STRENGTH times that node's marginal standard deviation
# (see the interventions_by_dag cell).
IV_STRENGTH = 5

In [6]:
def get_arc_probs(nnodes, dags):
    """Fraction of ``dags`` containing each possible directed arc.

    Every ordered pair ``(i, j)`` with ``i != j`` over ``range(nnodes)`` gets an
    entry (0.0 when no DAG has it). Each DAG is read only through its ``arcs``.
    Raises ZeroDivisionError when ``dags`` is empty and ``nnodes >= 2``.
    """
    total = len(dags)
    arc_counts = dict.fromkeys(itr.permutations(range(nnodes), 2), 0)
    for graph in dags:
        for edge in graph.arcs:
            arc_counts[edge] += 1
    return {edge: hits / total for edge, hits in arc_counts.items()}

In [7]:
def l1_score_full(arc_probs, true_dag):
    """L1 distance between arc probabilities and the adjacency of ``true_dag``.

    Each arc of ``true_dag`` contributes ``1 - arc_probs[arc]`` (missed mass).
    Each other ordered pair of its nodes contributes ``arc_probs[pair]``
    (spurious mass).
    """
    arcs = true_dag.arcs
    absent = set(itr.permutations(true_dag.nodes, 2)) - arcs
    missed = sum(1 - arc_probs[edge] for edge in arcs)
    spurious = sum(arc_probs[edge] for edge in absent)
    return missed + spurious


def l1_score_fp_full(arc_probs, true_dag):
    """False-positive part of the L1 score.

    Sums the probability assigned to every ordered node pair of ``true_dag``
    that is NOT one of its arcs.
    """
    absent = set(itr.permutations(true_dag.nodes, 2)) - true_dag.arcs
    total = 0
    for edge in absent:
        total += arc_probs[edge]
    return total


def l1_score_fn_full(arc_probs, gdag):
    """False-negative part of the L1 score: probability missing on true arcs.

    Parameters
    ----------
    arc_probs : dict
        Maps ordered node pairs ``(i, j)`` to the estimated probability of arc i -> j.
    gdag : DAG-like object with an ``arcs`` attribute
        The reference DAG; each of its arcs contributes ``1 - arc_probs[arc]``.

    Returns
    -------
    float
        Sum of ``1 - arc_probs[arc]`` over ``gdag.arcs``.
    """
    # Use the argument, not a global: the old body read a module-level
    # ``true_dag``, which raised NameError or silently scored whatever DAG a
    # notebook loop had last bound to that name.
    return sum(1 - arc_probs[arc] for arc in gdag.arcs)

In [8]:
# Load the ground-truth DAG of every DAG folder in the dataset.
dag_folders = check_gies.get_dag_folders(dataset)
true_dags = check_gies.get_true_dags(dag_folders)
covs = [d.covariance for d in true_dags]
# Structure-only copies (node set + arcs) of the true DAGs.
true_dags_barren = [cd.DAG(set(dag.nodes), dag.arcs) for dag in true_dags]
# For each true DAG, every member of its Markov equivalence class, enumerated
# from the CPDAG as structure-only DAGs.
true_mecs_barren = [[cd.DAG(true_dag.nodes, arcs) for arcs in true_dag.cpdag().all_dags()] for true_dag in true_dags_barren]
# Give each MEC member parameters derived from the true DAG's covariance;
# presumably cov2dag fits a model on that structure to `cov` — TODO confirm.
true_mecs = [[graph_utils.cov2dag(cov, d) for d in mec] for mec, cov in zip(true_mecs_barren, covs)]
ndags = len(true_dags)

# Experiment grid whose results are read back from disk: strategies, k
# (meaning not visible here — TODO document), number of batches b, and
# number of samples n.
strategy_names = ['entropy-dag-collection', 'random']
ks = [1]
bs = [1]
ns = [2048]

In [10]:
# For every true DAG, one binary intervention per node, built from two constant
# interventions at -IV_STRENGTH and +IV_STRENGTH times the node's marginal
# standard deviation (square root of the covariance diagonal).
interventions_by_dag = [
    [
        cd.BinaryIntervention(
            intervention1=cd.ConstantIntervention(val=-IV_STRENGTH * node_std),
            intervention2=cd.ConstantIntervention(val=IV_STRENGTH * node_std),
        )
        for node_std in np.diag(dag.covariance) ** .5
    ]
    for dag in true_dags
]

In [11]:
# One zero-initialized array per true DAG, holding the posterior of each MEC
# member for every (strategy, n, b, k) configuration; filled by the next cell.
mec_posteriors = []
for true_dag, true_mec in zip(true_dags, true_mecs):
    da = xr.DataArray(
        np.zeros((len(true_mec), len(strategy_names), len(ns), len(bs), len(ks))),
        dims=('mec_member', 'strategy', 'n', 'b', 'k'),
        coords=dict(
            mec_member=list(range(len(true_mec))),
            strategy=strategy_names,
            n=ns,
            b=bs,
            k=ks,
        ),
    )
    mec_posteriors.append(da)


In [12]:
for dag_ix, dag_folder, true_dag, true_mec in tqdm(zip(range(ndags), dag_folders, true_dags, true_mecs), total=ndags):
    for strat, n, b, k in itr.product(strategy_names, ns, bs, ks):
        strat_str = '%s,n=%s,b=%s,k=%s' % (strat, n, b, k)
        # Summed log-likelihood of all collected samples under each MEC member.
        # Normalizing below gives the posterior under a uniform prior over the MEC.
        log_posteriors = np.zeros(len(true_mec))

        # == calculate log posteriors based on interventional data
        # iv_node == -1 holds the observational samples (no intervention applied).
        for iv_node in list(range(nnodes)) + [-1]:
            intervention_fn = os.path.join(dag_folder, strat_str, 'samples', 'intervention=%d.csv' % iv_node)
            # Skip files with no samples. A file has zero lines exactly when its
            # size is 0, and getsize, unlike the old `open(...)` line count,
            # leaves no file handle open.
            if os.path.getsize(intervention_fn) == 0:
                continue
            # ndmin=2: a file holding a single sample still loads as a (1, p)
            # matrix instead of collapsing to a 1-D vector.
            samples = np.loadtxt(intervention_fn, ndmin=2)
            for mec_ix, mec_member in enumerate(true_mec):
                if iv_node == -1:
                    logpdfs = mec_member.logpdf(samples)
                else:
                    logpdfs = mec_member.logpdf(samples, {iv_node: interventions_by_dag[dag_ix][iv_node]})
                log_posteriors[mec_ix] += logpdfs.sum()

        # Normalize in log space (logsumexp) to avoid underflow.
        posteriors = np.exp(log_posteriors - logsumexp(log_posteriors))
        if not np.isclose(posteriors.sum(), 1):
            raise ValueError(
                'posteriors for DAG %d under %s sum to %r, not 1'
                % (dag_ix, strat_str, posteriors.sum())
            )
        mec_posteriors[dag_ix].loc[dict(strategy=strat, n=n, b=b, k=k)] = posteriors

100%|██████████| 20/20 [00:02<00:00,  7.72it/s]


In [17]:
# L1 distance between each MEC posterior and the point mass on the true DAG:
# (1 - P(true DAG)) + sum of P(other members). Because the posteriors sum to 1
# (checked in the previous cell), this equals 2 * (1 - P(true DAG)).
l1_error_da = xr.DataArray(
    np.zeros([len(true_dags), len(strategy_names), len(ns), len(bs), len(ks)]),
    dims=['dag', 'strategy', 'n', 'b', 'k'],
    coords={
        'dag': list(range(len(true_dags))),
        'strategy': strategy_names,
        'n': ns,
        'b': bs,
        'k': ks
    }
)
for dag_ix, true_dag, true_mec in zip(range(len(true_dags)), true_dags, true_mecs):
    for strat, n, b, k in itr.product(strategy_names, ns, bs, ks):
        config = dict(strategy=strat, n=n, b=b, k=k)
        # .values gives plain floats, so accumulating below avoids one xarray
        # .loc read-modify-write (with coordinate alignment) per MEC member.
        posteriors = mec_posteriors[dag_ix].loc[config].values
        ntrue = 0
        l1_error = 0.0
        for mec_member, posterior in zip(true_mec, posteriors):
            if mec_member.arcs == true_dag.arcs:
                ntrue += 1
                penalty = 1 - posterior
            else:
                penalty = posterior
            l1_error += penalty
        # Exactly one MEC member must be the true DAG; otherwise the score is meaningless.
        if ntrue != 1:
            raise ValueError(
                'expected exactly one MEC member equal to true DAG %d, found %d (%s)'
                % (dag_ix, ntrue, config)
            )
        l1_error_da.loc[dict(dag=dag_ix, **config)] = l1_error

In [19]:
# Average L1 error over all DAGs, per (strategy, n, b, k).
l1_error_da.mean(dim='dag')

<xarray.DataArray (strategy: 2, n: 1, b: 1, k: 1)>
array([[[[1.506667]]],


       [[[1.713968]]]])
Coordinates:
  * strategy  (strategy) <U22 'entropy-dag-collection' 'random'
  * n         (n) int64 2048
  * b         (b) int64 1
  * k         (k) int64 1

In [None]:
# Mean L1 error over DAGs (one value per strategy) for every (n, k, b)
# configuration. The scores live in `l1_error_da`; the old name `l1_scores_da`
# is not defined by any cell and raised NameError.
for n, k, b in itr.product(ns, ks, bs):
    print(n, k, b)
    print(l1_error_da.sel(n=n, k=k, b=b).mean(dim='dag').values)

In [None]:
# Intervention targets read by the posterior cell: every node plus -1 for the
# observational samples. Uses nnodes rather than a hard-coded 5 so the list
# matches what that loop actually iterates over.
list(range(nnodes)) + [-1]

In [None]:
# Sample counts for each (strategy, k, b, n) configuration of the dataset;
# presumably per DAG and per intervention target — TODO confirm the dims.
# Later cells divide slices of it by n and take entropies.
counts_da = check_samples.count_samples(dataset, strategy_names, ks, bs, ns)

In [None]:
# TPR/FPR rates per DAG and configuration, evaluated at alphas 0, 0.1, ..., 1.0;
# the plots below select them via rate='tpr' / rate='fpr'.
# FIXME(review): `parent_probs_by_dag` is not defined in any cell visible in this
# notebook, so this cell raises NameError on Restart & Run All. It was probably
# produced by a deleted cell; restore that cell before running.
# NOTE(review): target=9 looks like the last node (nnodes - 1); confirm.
rates_da = check_gies.get_rates_data_array(
    parent_probs_by_dag,
    true_dags,
    target=9,
    strategy_names=strategy_names,
    ks=ks,
    bs=bs,
    ns=ns,
    alphas=np.linspace(0, 1, 11)
)
print(rates_da.dims)

### Plot average ROC curves (TPR vs. FPR) for each strategy

In [None]:
import matplotlib.pylab as plt
import seaborn as sns
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
%matplotlib inline
# Large default figures and bigger fonts for all plots below.
plt.rcParams.update({"figure.figsize": (20, 12)})
sns.set(font_scale=1.5)

In [None]:

# Legend proxies: line style encodes the sample size n, color encodes the
# strategy. zip stops at the shorter list, so only the first three sample
# sizes get a line style (and a legend entry).
linestyles = ['solid', 'dashed', 'dotted']
colors = sns.color_palette()
sample_handles = [
    mlines.Line2D([0], [0], label=n_samples, color='k', linestyle=style)
    for n_samples, style in zip(ns, linestyles)
]
strat_handles = [
    mpatches.Patch(label=name, facecolor=swatch)
    for name, swatch in zip(strategy_names, colors)
]

In [None]:
plt.clf()

# Average the rates over DAGs, then draw one curve per (strategy, n):
# color encodes the strategy and line style encodes the sample size.
rate_avgs_da = rates_da.mean(dim='dag')
b = k = 1
for strategy, color in zip(strategy_names, colors):
    for n, linestyle in zip(ns, linestyles):
        strat_avg_fprs = rate_avgs_da.sel(strategy=strategy, b=b, k=k, n=n, rate='fpr').values
        strat_avg_tprs = rate_avgs_da.sel(strategy=strategy, b=b, k=k, n=n, rate='tpr').values
        tpr_fpr_df = order_tpr_fpr(strat_avg_fprs, strat_avg_tprs)
        plt.plot(tpr_fpr_df['fpr'], tpr_fpr_df['tpr'], color=color, linestyle=linestyle)

plt.legend(handles=strat_handles + sample_handles)
plt.title(dataset + ', batches=%s, k=%s' % (b, k))
plt.xlabel('Average FPR')
plt.ylabel('Average TPR');

In [None]:
# Grid of ROC panels: rows are batch counts b, columns are sample sizes n.
# squeeze=False keeps `ax` a 2-D array for any grid shape. With the default
# squeeze=True and the current single-entry bs/ns, plt.subplots returns a bare
# Axes and `ax[b_ix, n_ix]` raised TypeError.
fig, ax = plt.subplots(len(bs), len(ns), sharey=True, sharex=True, squeeze=False)
k = 1
for (b_ix, b), (n_ix, n) in itr.product(list(enumerate(bs)), list(enumerate(ns))):
    panel = ax[b_ix, n_ix]
    for strategy, color in zip(strategy_names, colors):
        avg_rates = rate_avgs_da.sel(strategy=strategy, b=b, k=k, n=n)
        tpr_fpr_df = order_tpr_fpr(avg_rates.sel(rate='fpr').values, avg_rates.sel(rate='tpr').values)
        panel.plot(tpr_fpr_df['fpr'], tpr_fpr_df['tpr'], color=color)
    # Label only the outer edge of the grid (axes are shared).
    if b_ix == len(bs) - 1:
        panel.set_xlabel('n = %s' % n)
    if n_ix == 0:
        panel.set_ylabel('b = %s' % b)


In [None]:
from scipy.stats import entropy

In [None]:
# Inspect one configuration taken from the configured grid. counts_da is built
# from ns/ks/bs/strategy_names, and the previous hard-coded n=256, k=2, b=2 and
# strategy 'entropy' are not among them, so .sel raised KeyError.
n = ns[0]
k = ks[0]
b = bs[0]

# Counts for the entropy strategy, normalized by the number of samples n.
c_e = counts_da.sel(strategy='entropy-dag-collection', k=k, b=b, n=n)/n
c_e

In [None]:
# Same normalization for the random strategy, at the same (k, b, n).
c_r = counts_da.sel({'strategy': 'random', 'k': k, 'b': b, 'n': n}) / n

In [None]:
# scipy's entropy reduces axis 0, so transposing gives one entropy per entry
# of the counts' first dimension (assuming 2-D counts — TODO confirm dims).
ent_e, ent_r = (entropy(counts.T) for counts in (c_e, c_r))

In [None]:
# Entropies computed from the entropy strategy's normalized counts (c_e).
ent_e

In [None]:
# Mean of the entropy strategy's entropies.
np.mean(ent_e)

In [None]:
# Entropies computed from the random strategy's normalized counts (c_r).
ent_r

In [None]:
# Mean of the random strategy's entropies.
np.mean(ent_r)

In [None]:
# ln(4) ≈ 1.386: maximum entropy (in nats) of a distribution over 4 outcomes,
# presumably a reference value for the entropies above.
np.log(4)

In [None]:
import random
from collections import Counter

def random_choices(p, k):
    """Empirical distribution of ``k`` uniform draws from ``range(p)``.

    Draws with ``random.choices`` (the module-level RNG) and returns a
    length-``p`` float array of relative frequencies. With ``k == 0`` every
    entry is NaN (0/0).
    """
    tally = Counter(random.choices(list(range(p)), k=k))
    freqs = np.array([tally.get(i, 0) for i in range(p)], dtype=float)
    return freqs / freqs.sum()


In [None]:
# Empirical entropy of n_draws uniform draws over 20 categories, for growing
# n_draws (the maximum possible is ln(20) ≈ 3.0). The loop variable is not
# named `k`, so this cell no longer overwrites the global configuration `k`.
for n_draws in [20, 40, 100, 200]:
    print(entropy(random_choices(20, n_draws)))

In [None]:
# Entropy of a fair coin, in nats (scipy's default base e); equals ln(2).
entropy([.5, .5])

In [None]:
# ln(2) ≈ 0.693, for comparison with entropy([.5, .5]).
np.log(2)

In [None]:
# Re-display c_e (the entropy strategy's counts divided by n).
c_e

In [None]:
# Counts of the entropy strategy for DAG 1 at the first configured (k, b, n).
# The old coordinates ('entropy', n=256) are not among strategy_names/ns, so
# .sel raised KeyError.
counts_da.sel(strategy='entropy-dag-collection', k=ks[0], b=bs[0], n=ns[0], dag=1)

In [None]:
# Counts of the random strategy for DAG 1 at the first configured (k, b, n).
# The old coordinates (k=2, b=2, n=256) are not among ks/bs/ns, so .sel raised
# KeyError.
counts_da.sel(strategy='random', k=ks[0], b=bs[0], n=ns[0], dag=1)