In [2]:
import glob
import sys
sys.path.append('..')

import pandas as pd
import pytorch_lightning as pl
import torch as T

from src import metrics
from src import utils
from src.ds import CausalGraph
from src.scm import (SCM, CTM, DiscreteDistribution,
                     PartiallyCorrelatedDistribution, Distribution)
from src.nn import Simple, NCM, NLLModule, UniformDistribution

In [8]:
x = NCM(CausalGraph.read('../dat/cg/backdoor.cg'))

with metrics.evaluating(x):
    v = x(1)
    n = 123
    batch_size = len(next(iter(v.values())))

    u = {k: t.expand((batch_size,) + t.shape).transpose(0, 1)
         for k, t in x.pu.sample(n=n).items()} #(n, batch_size, var_siz)
    v_new = {k: t.expand((n,) + t.shape).float()
             for k, t in v.items()}

In [9]:
v

{'Z': tensor([[1]]), 'X': tensor([[0]]), 'Y': tensor([[1]])}

In [18]:
import math
math.e ** -x.nll(v, n=1000)

tensor([0.2652], grad_fn=<PowBackward1>)

In [19]:
metrics.probability_table(x)

Unnamed: 0,Z,X,Y,P(V)
0,0,0,0,0.037919
1,0,0,1,0.090704
2,0,1,0,0.047098
3,0,1,1,0.109799
4,1,0,0,0.121551
5,1,0,1,0.264594
6,1,1,0,0.092198
7,1,1,1,0.236137


In [10]:

            logpv = sum(
                T.log(
                    (1 - v_new[k])
                    + (2 * v_new[k] - 1)
                    * x.f[k](v_new, u)
                )
                for k in x.v)

In [None]:

        try:
            self.train()
            batch_size = len(next(iter(v.values())))
            u = {k: t.expand((batch_size,) + t.shape).transpose(0, 1)
                 for k, t in self.pu.sample(n=n).items()} #(n, batch_size, var_siz)
            v_new = {k: t.expand((n,) + t.shape).float()
                     for k, t in v.items()}
            logpv = sum(
                T.log(
                    (1 - v_new[k])
                    + (2 * v_new[k] - 1)
                    * self.f[k](v_new, u)
                )
                for k in self.v)
            logpv = T.log((T.exp(logpv - logpv.max(dim=0, keepdim=True).values).mean(dim=0))
                     + logpv.max(dim=0).values)
            logpv = T.log((T.exp(logpv - logpv.max(dim=0, keepdim=True).values).mean(dim=0))
                     + logpv.max(dim=0).values)
            return logpv.item()
        finally:

In [5]:
x.nll(v)

nan

In [20]:
import glob
import json

def process(d):
    results = {}
    _, results['name'], results['key'] = d.rsplit('/', 2)
    (results['graph'],
     results['n_samples'],
     results['n_trial']) = results['key'].split('-')
    results['n_samples'] = int(results['n_samples'],)
    results['n_trial'] = int(results['n_trial'])

#     m = T.load(f'{d}/best.th')
    try:
        dat = T.load(f'{d}/dat.th')
        with open(f'{d}/results.json') as file:
            results.update(json.load(file))
        results['data_tv'] = (dat['Y'][dat['X'] == 1].float().mean()
                              - dat['Y'][dat['X'] == 0].float().mean()).item()
        
        def diff(g):
            return (-next(g) + next(g)).item()
        
        if results['graph'] == 'backdoor':
            results['plugin_ate'] = diff(
                sum(dat['Y'][(dat['Z'] == z)
                             & (dat['X'] == x)].float().mean()
                    * (dat['Z'] == z).float().mean()
                    for z in (0, 1))
                for x in (0, 1))
        elif results['graph'] == 'frontdoor':
            results['plugin_ate'] = diff(
                sum((dat['M'][dat['X'] == x] == m).float().mean()
                    * sum(dat['Y'][(dat['X'] == xp)
                                   & (dat['M'] == m)].float().mean()
                          * (dat['X'] == xp).float().mean()
                          for xp in (0, 1))
                    for m in (0, 1))
                for x in (0, 1))
        elif results['graph'] == 'napkin':
            results['plugin_ate'] = diff(
                sum(sum((((dat['X'] == x)
                          & (dat['Y'] == 1))
                         [(dat['W'] == w)
                          & (dat['Z'] == z)]).float().mean()
                        * (dat['W'] == w).float().mean()
                        for w in (0, 1))
                    / sum(((dat['X'] == x)
                           [(dat['W'] == w)
                            & (dat['Z'] == z)]).float().mean()
                          * (dat['W'] == w).float().mean()
                          for w in (0, 1))
                    for z in (0, 1)) / 2
                 for x in (0, 1))
        elif results['graph'] in ('m', 'simple'):
            results['plugin_ate'] = results['dat_tv']
        else:
            results['plugin_ate'] = float('nan')
            
        results['err_dat_tv_ncm_ate'] = results['dat_tv'] - results['ncm_ate']
        results['err_true_ate_ncm_tv'] = results['true_ate'] - results['ncm_tv']
        results['err_plugin_ate_ncm_ate'] = results['plugin_ate'] - results['ncm_ate']
        results['err_plugin_ate'] = results['true_ate'] - results['plugin_ate']
    except FileNotFoundError:
        print('error processing', d)
    return results

df = pd.DataFrame.from_records(list(map(process, glob.glob('../out/nll/*'))))
err_cols = df.columns[df.columns.str.contains('err')]
mae_cols = err_cols.str.replace('err', 'mae')
t = (
    df
    .groupby(['graph', 'n_samples'])[err_cols]
    .apply(lambda s: s.abs().mean())
    .rename(dict(zip(err_cols, mae_cols)), axis=1)
    .assign(n_trials=df.groupby(['graph', 'n_samples']).apply(lambda s: len(s.err_ncm_ate.dropna()))))
t[['n_trials', 'mae_ncm_ate', 'mae_ncm_tv', 'mae_dat_tv', 'mae_true_ate_ncm_tv', 'mae_plugin_ate']][~t.mae_ncm_ate.isna()]

error processing ../out/nll/frontdoor-10000-0


Unnamed: 0_level_0,Unnamed: 1_level_0,n_trials,mae_ncm_ate,mae_ncm_tv,mae_dat_tv,mae_true_ate_ncm_tv,mae_plugin_ate
graph,n_samples,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
backdoor,10000,10,0.015311,0.018663,0.009362,0.273277,0.008907
bow,10000,10,0.281393,0.024465,0.013909,0.282395,


In [37]:
(df.query('graph == "backdoor" and n_samples == 10000').err_ncm_ate ** 2).mean() ** .5

0.011396379179280686

In [33]:
rm -rf out/nll/napkin-10000-0

In [70]:
{
    'rmse_ncm': '%.5f'%((df.true_ate - df.ncm_ate)**2).sum()**.5,
    'rmse_plugin': '%.5f'%((df.true_ate - df['plug-in ate'])**2).sum()**.5,
    'rmse_ncm_obs': '%.5f'%((df.true_ate - df['ncm_obs_effect'])**2).sum()**.5,
    'rmse_obs': '%.5f'%((df.true_ate - df['obs effect'])**2).sum()**.5,
}

{'rmse_ncm': '0.01663',
 'rmse_plugin': '0.01882',
 'rmse_ncm_obs': '0.28532',
 'rmse_obs': '0.28116'}

In [53]:
g = 'frontdoor'
t3 = []

In [4]:
g = 'backdoor'
i = 0
d = T.load(f'../out/nll/{g}-10000-{i}/best.th')
dat = T.load(f'../out/nll/{g}-10000-{i}/dat.th')
m = NLLModule(CTM(utils.RPA[g]), NCM(CausalGraph.read(f'../dat/cg/{g}.cg')))
m.load_state_dict(d)

<All keys matched successfully>

In [5]:
metrics.probability_table(m)

Unnamed: 0,Z,X,Y,P(V)
0,0,0,0,0.999997
1,0,0,1,3e-06


In [6]:
metrics.probability_table(dat=dat)

Unnamed: 0,Z,X,Y,P(V)
0,0,0,0,0.499
1,0,0,1,0.38
2,0,1,0,0.0051
3,0,1,1,0.0184
4,1,0,0,0.0183
5,1,0,1,0.0116
6,1,1,0,0.0514
7,1,1,1,0.0162


In [19]:
math.e ** -m.ncm.nll(dat1)

tensor([4.9632e-06], grad_fn=<PowBackward1>)

In [22]:
math.e ** -m.ncm.nll(dat1, n=1000)

tensor([4.4107e-06], grad_fn=<PowBackward1>)

In [None]:
m.ncm.nll(dat, n=1000)

In [15]:
dat1 = {k: dat[k][:1] for k in dat}

In [16]:
m.ncm.nll(dat1)

tensor([12.1280], grad_fn=<NegBackward>)

In [11]:
m.ncm(1)

{'Z': tensor([[8.0537e-07]], grad_fn=<SigmoidBackward>),
 'X': tensor([[1.2135e-08]], grad_fn=<SigmoidBackward>),
 'Y': tensor([[4.8810e-06]], grad_fn=<SigmoidBackward>)}

In [56]:
metrics.probability_table(m.ctm)

Unnamed: 0,X,M,Y,P(V)
0,0,0,0,0.408095
1,0,0,1,0.255021
2,0,1,0,0.081951
3,0,1,1,0.041043
4,1,0,0,0.016633
5,1,0,1,0.016292
6,1,1,0,0.0571
7,1,1,1,0.123865


In [57]:
metrics.probability_table(m.ncm)

Unnamed: 0,X,M,Y,P(V)
0,0,0,0,0.393242
1,0,0,1,0.257669
2,0,1,0,0.065628
3,0,1,1,0.072716
4,1,0,0,0.019863
5,1,0,1,0.012859
6,1,1,0,0.084287
7,1,1,1,0.093736


In [59]:
metrics.supremum_norm(m.ncm)

TypeError: supremum_norm() missing 1 required positional argument: 'ncm'

In [51]:
ls out/nll

ls: out/nll: No such file or directory


In [2]:
g='backdoor'
i=0
m = NLLModule(CTM(utils.RPA[g]), NCM(CausalGraph.read(f'../dat/cg/{g}.cg')))
d = T.load(f'../err/nll/backdoor-1000-0/logs/default/version_0/checkpoints/epoch=43-step=43.ckpt')
m.load_state_dict(d['state_dict'])

<All keys matched successfully>

In [3]:
tmp = metrics.all_metrics(m.ctm, m.ncm)

In [75]:
import json

In [76]:
json.dumps(tmp)

'{"true_ate": -0.3767334818840027, "ncm_ate": -0.3195730149745941, "true_obs_effect": -0.11592955277964911, "ncm_obs_effect": -0.0732731819152832, "kl": 0.005362575132306613, "supremum_norm": 0.012062007682800302}'

In [60]:
i = 2
d = T.load(f'out/nll/{g}-10000-{i}/best.th')
dat = T.load(f'out/nll/{g}-10000-{i}/dat.th')
m = NLLModule(CTM(utils.RPA[g]), NCM(CausalGraph.read(f'../dat/cg/{g}.cg')))
m.load_state_dict(d)

tmp = {}
tmp['obs effect'] = \
      ((dat['Y'][dat['X'] == 1].float().mean()
       - dat['Y'][dat['X'] == 0].float().mean()).item())
for k, v in metrics.all_metrics(m.ctm, m.ncm).items():
    tmp[k] = v
t3.append(tmp)

In [73]:
df = pd.DataFrame.from_records(t3)[['true_obs_effect', 'obs effect', 'ncm_obs_effect',
                               'true_ate', 'ncm_ate', 'kl', 'supremum_norm']]
df.index = df.index.rename('bow graph run')
df

Unnamed: 0_level_0,true_obs_effect,obs effect,ncm_obs_effect,true_ate,ncm_ate,kl,supremum_norm
bow graph run,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,0.232097,0.255174,0.257038,0.010851,0.256474,0.000197,0.002974
1,0.062687,0.05419,0.056226,0.177448,0.055701,0.000151,0.006151
2,0.003147,0.010869,0.012482,0.164128,0.011901,0.00062,0.01321


In [74]:
{
    'rmse_ncm': '%.5f'%((df.true_ate - df.ncm_ate)**2).sum()**.5,
    'rmse_ncm_obs': '%.5f'%((df.true_ate - df['ncm_obs_effect'])**2).sum()**.5,
    'rmse_obs': '%.5f'%((df.true_ate - df['obs effect'])**2).sum()**.5,
}

{'rmse_ncm': '0.31357', 'rmse_ncm_obs': '0.31353', 'rmse_obs': '0.31365'}

In [16]:
truth, dat, train_loader, val_loader = utils.datagen(cg_file, n=n)
ctm = truth
display(pd.merge(
    pd.merge(
        metrics.probability_table(ctm, n=n),
        (metrics.probability_table(ctm, n=n, do={'X': T.ones(1, 1).long()})
         .rename({'P(V)': 'P(V | do(X=1))'}, axis=1))),
    (metrics.probability_table(ctm, n=n, do={'X': T.zeros(1, 1).long()})
     .rename({'P(V)': 'P(V | do(X=0))'}, axis=1))))

display(pd.merge(
    pd.merge(
        metrics.probability_table(ctm, n=n),
        (metrics.probability_table(ctm, n=n, do={'X': T.ones(1, 1).long()})
         .rename({'P(V)': 'P(V | do(X=1))'}, axis=1))),
    (metrics.probability_table(ctm, n=n, do={'X': T.zeros(1, 1).long()})
     .rename({'P(V)': 'P(V | do(X=0))'}, axis=1))).groupby(['X', 'Y']).sum())

print(metrics.all_metrics(ctm, ctm))

Unnamed: 0,Z,X,Y,P(V),P(V | do(X=1)),P(V | do(X=0))
0,0,0,0,0.093762,0.0,0.632677
1,0,0,1,0.032607,0.0,0.220019
2,0,1,0,0.591106,0.693949,0.0
3,0,1,1,0.13522,0.158746,0.0
4,1,0,0,0.085777,0.0,0.122569
5,1,0,1,0.01731,0.0,0.024735
6,1,1,0,0.035214,0.117312,0.0
7,1,1,1,0.009003,0.029993,0.0


Unnamed: 0_level_0,Unnamed: 1_level_0,Z,P(V),P(V | do(X=1)),P(V | do(X=0))
X,Y,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0,1,0.179539,0.0,0.755246
0,1,1,0.049917,0.0,0.244754
1,0,1,0.626321,0.811261,0.0
1,1,1,0.144223,0.188739,0.0


{'true_ate': -0.05601458251476288, 'ncm_ate': -0.05601458251476288, 'true_obs_effect': -0.030373276980473235, 'ncm_obs_effect': -0.030373276980473235, 'kl': 0.0, 'supremum_norm': 0.0}


In [6]:
print('dat obs effect',
      (dat['Y'][dat['X'] == 1].float().mean()
       - dat['Y'][dat['X'] == 0].float().mean()).item())
print('dat ate',
      ((((dat['Y'][(dat['X'] == 1) & (dat['Z'] == 0)].float().mean())
         * (1 - dat['Z'].float().mean()))
        + ((dat['Y'][(dat['X'] == 1) & (dat['Z'] == 1)].float().mean())
           * (dat['Z'].float().mean())))
       - (((dat['Y'][(dat['X'] == 0) & (dat['Z'] == 0)].float().mean())
         * (1 - dat['Z'].float().mean()))
        + ((dat['Y'][(dat['X'] == 0) & (dat['Z'] == 1)].float().mean())
           * (dat['Z'].float().mean())))).item())

dat obs effect -0.08716708421707153
dat ate -0.11225023865699768


In [11]:
ncm = NCM(CausalGraph.read(f'../dat/cg/{g}.cg'), default_module=Simple)
display(pd.merge(
    pd.merge(
        metrics.probability_table(ncm, n=n),
        (metrics.probability_table(ncm, n=n, do={'X': T.ones(n, 1)})
         .rename({'P(V)': 'P(V | do(X=1))'}, axis=1)), how='outer'),
    (metrics.probability_table(ncm, n=n, do={'X': T.zeros(n, 1)})
     .rename({'P(V)': 'P(V | do(X=0))'}, axis=1)), how='outer'))

display(pd.merge(
    pd.merge(
        metrics.probability_table(ncm, n=n),
        (metrics.probability_table(ncm, n=n, do={'X': T.ones(n, 1)})
         .rename({'P(V)': 'P(V | do(X=1))'}, axis=1)), how='outer'),
    (metrics.probability_table(ncm, n=n, do={'X': T.zeros(n, 1)})
     .rename({'P(V)': 'P(V | do(X=0))'}, axis=1)), how='outer').groupby(['X', 'Y']).sum())

Unnamed: 0,Z,X,Y,P(V),P(V | do(X=1)),P(V | do(X=0))
0,0,0,0,0.1649,,0.2624
1,0,0,1,0.2091,,0.3324
2,0,1,0,0.0857,0.2324,
3,0,1,1,0.1365,0.3603,
4,1,0,0,0.1334,,0.2085
5,1,0,1,0.128,,0.1967
6,1,1,0,0.0601,0.174,
7,1,1,1,0.0823,0.2333,


Unnamed: 0_level_0,Unnamed: 1_level_0,Z,P(V),P(V | do(X=1)),P(V | do(X=0))
X,Y,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0,1,0.2945,0.0,0.4601
0,1,1,0.345,0.0,0.5399
1,0,1,0.1444,0.3893,0.0
1,1,1,0.2161,0.6107,0.0


In [12]:
d = 'out/tmp'

In [13]:
rm -rf out/tmp

In [14]:
m = NLLModule(truth, ncm)
trainer = pl.Trainer(
    callbacks=[pl.callbacks.EarlyStopping(monitor='val_loss', patience=10)],
    max_epochs=100,
    accumulate_grad_batches=100,
    default_root_dir=f'{d}/checkpoints/',
    logger=pl.loggers.TensorBoardLogger(f'{d}/logs/')
)
trainer.fit(m, train_loader, val_loader)
results = metrics.all_metrics(truth, m, n=1000000)

GPU available: False, used: False
INFO:lightning:GPU available: False, used: False
TPU available: None, using: 0 TPU cores
INFO:lightning:TPU available: None, using: 0 TPU cores

  | Name | Type | Params
------------------------------
0 | ctm  | CTM  | 22    
1 | ncm  | NCM  | 100 K 
------------------------------
100 K     Trainable params
0         Non-trainable params
100 K     Total params
INFO:lightning:
  | Name | Type | Params
------------------------------
0 | ctm  | CTM  | 22    
1 | ncm  | NCM  | 100 K 
------------------------------
100 K     Trainable params
0         Non-trainable params
100 K     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




In [15]:
display(pd.merge(
    pd.merge(
        metrics.probability_table(ncm, n=n),
        (metrics.probability_table(ncm, n=n, do={'X': T.ones(n, 1)})
         .rename({'P(V)': 'P(V | do(X=1))'}, axis=1)), how='outer'),
    (metrics.probability_table(ncm, n=n, do={'X': T.zeros(n, 1)})
     .rename({'P(V)': 'P(V | do(X=0))'}, axis=1)), how='outer'))

display(pd.merge(
    pd.merge(
        metrics.probability_table(ncm, n=n),
        (metrics.probability_table(ncm, n=n, do={'X': T.ones(n, 1)})
         .rename({'P(V)': 'P(V | do(X=1))'}, axis=1)), how='outer'),
    (metrics.probability_table(ncm, n=n, do={'X': T.zeros(n, 1)})
     .rename({'P(V)': 'P(V | do(X=0))'}, axis=1)), how='outer').groupby(['X', 'Y']).sum())

Unnamed: 0,Z,X,Y,P(V),P(V | do(X=1)),P(V | do(X=0))
0,0,0,0,0.1969,,0.4539
1,0,0,1,0.2015,,0.4689
2,0,1,0,0.3104,0.5557,
3,0,1,1,0.2099,0.3669,
4,1,0,0,0.0511,,0.0584
5,1,0,1,0.019,,0.0188
6,1,1,0,0.0086,0.0606,
7,1,1,1,0.0026,0.0168,


Unnamed: 0_level_0,Unnamed: 1_level_0,Z,P(V),P(V | do(X=1)),P(V | do(X=0))
X,Y,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0,1,0.2515,0.0,0.5074
0,1,1,0.2229,0.0,0.4926
1,0,1,0.3128,0.6088,0.0
1,1,1,0.2128,0.3912,0.0


In [17]:
print(results)

{'true_ate': -0.1163431704044342, 'ncm_ate': -0.10822698473930359, 'true_obs_effect': -0.0889156094823717, 'ncm_obs_effect': -0.08144369721412659, 'kl': 0.0008372459635139565, 'supremum_norm': 0.00967546554851531}


In [10]:
sd = T.load('out/mmd/backdoor-1000-0/best.th')
g = 'backdoor'
m = MMDModule(CTM(utils.RPA[g]), NCM(CausalGraph.read(f'../dat/cg/{g}.cg')))
m.load_state_dict(sd)

<All keys matched successfully>

In [11]:
t = ncm.training
m.eval()
display(metrics.probability_table(m))
m.train(t);

Unnamed: 0,Z,X,Y,P(V)
0,0.0,0.0,0.0,0.16278
1,0.0,0.0,1.0,0.621632
2,0.0,1.0,0.0,0.059489
3,0.0,1.0,1.0,0.156099


In [15]:
t = ncm.training
m.eval()
display(metrics.probability_table(m.ctm).groupby(['X', 'Y'])['P(V)'].sum().to_frame())
m.train(t);

Unnamed: 0_level_0,Unnamed: 1_level_0,P(V)
X,Y,Unnamed: 2_level_1
0,0,0.203939
0,1,0.342791
1,0,0.28035
1,1,0.172919


In [18]:
m.train()
x = m(100000)

In [20]:
x['Z'].min()

tensor(0.1312, grad_fn=<MinBackward1>)

In [21]:
x['Z'].max()

tensor(0.2647, grad_fn=<MaxBackward1>)