In [1]:
import numpy as np
import torch
import argparse
from sklearn.preprocessing import scale
from pandas import DataFrame, Series
from cdt.data import load_dataset

Detecting 1 CUDA device(s).


In [2]:
import os
import sys
# Put the directory two levels above this notebook (presumably the repository
# root) on the import path so the local `functions` / `causal` packages resolve.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
from functions.tcep_utils import cut_num_pairs,_get_wd, _get_nc
from causal.generative.mmdgen.gnn import CausalMmdNet

Detecting 1 CUDA device(s).


In [4]:
from itertools import product

In [5]:
# set hyperparams

nh, lr, n_kernels = 20, 0.01, 5
epochs = 1000

In [6]:
# Seed NumPy and PyTorch once so the fits below are repeatable
# (CUDA kernels may still introduce some nondeterminism).
SEED = 0
np.random.seed(SEED)
# Trailing ';' keeps the returned torch Generator repr out of the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7ff9bdfcba70>

In [7]:
# Tuebingen cause-effect pairs, loaded without shuffling so row order is fixed.
(data, labels) = load_dataset('tuebingen', shuffle=False)

In [8]:
# Upper bound handed to cut_num_pairs as `num_max`.
MAX_N_PAIR = 1_000
# NOTE(review): the return value is discarded, so this presumably truncates
# `data` in place — confirm against functions.tcep_utils.cut_num_pairs.
cut_num_pairs(data, num_max=MAX_N_PAIR)

In [9]:
slice_lo, slice_up = 0, 3   # rows of `data` (by position) to process
n_c = 1000                  # forwarded as `n_cause` to generate_conditional_sampling
variabilities = ['mean', 'max']
stats = ['quantiles', 'variances']


def _print_stat_grid(score_fn):
    """Print score_fn(stat, varm) for every (stat, variability) combination.

    score_fn is a bound CausalMmdNet method taking (stat, variability) and
    returning a (causal, anticausal) result, e.g. add_penalty or
    compute_variability.
    """
    for stat, varm in product(stats, variabilities):
        print(f"score for penalty using stat:{stat} & variability meas: {varm}")
        print(f"causal, anticausal : {score_fn(stat, varm)}")


# For each selected pair: fit a CausalMmdNet in both directions, then report
# the penalised scores and their components (MMD, variability, weight norm).
for i, row in data.iloc[slice_lo:slice_up].iterrows():
    print(10*'-<>-')
    print(row['A'].shape[0], row.shape, i, '<-- shape and # of pair')
    # weight decay derived from the number of observations in column A
    wd = _get_wd(row['A'].shape[0])

    # standardise each variable separately, then stack them as the two rows
    pair = np.vstack((scale(row['A']), scale(row['B'])))

    mmdNet = CausalMmdNet(nh=nh, lr=lr, n_kernels=n_kernels, weight_decay=wd)

    # set & fit
    mmdNet.set_data(pair)
    mmdNet.fit_two_directions(train_epochs=epochs, idx=i)

    # sample for testing; previously hard-coded as 1000, which silently
    # ignored the n_c setting above
    mmdNet.generate_conditional_sampling(pair=pair, n_cause=n_c, sampling_type='sample')
    # compute vars & qs's
    mmdNet.estimate_conditional_var()
    mmdNet.estimate_conditional_quants()

    _print_stat_grid(mmdNet.add_penalty)

    print("score for penalty using stat: norm")
    print(f"causal, anticausal : {mmdNet.add_penalty('norm', None)}")

    print(10*'~*')

    print('-- These differences are broken down as --')

    print(f"diff due to mmd: {mmdNet.mmd_score_causal} vs {mmdNet.mmd_score_anticausal}")
    print("diff due to  variabilities: ---")
    _print_stat_grid(mmdNet.compute_variability)

    norm_causal, norm_anticausal = mmdNet.penalize_weight(p=2)
    print(f"diff due to norm: {norm_causal} vs {norm_anticausal}")

-<>--<>--<>--<>--<>--<>--<>--<>--<>--<>-
349 (2,) pair1 <-- shape and # of pair


100%|██████████| 1000/1000 [00:05<00:00, 168.43it/s, idx=pair1, score=(0.016290457919239998, 0.01596209965646267)]
  5%|▌         | 27/500 [00:00<00:02, 206.95it/s]

score for penalty using stat:quantiles & variability meas: mean


100%|██████████| 500/500 [00:01<00:00, 262.75it/s]
  2%|▏         | 17/1000 [00:00<00:05, 164.46it/s, idx=pair2, score=(0.2704072594642639, 0.3066158890724182)]

causal, anticausal : (0.3736643501136054, 0.3736643501136054)
score for penalty using stat:quantiles & variability meas: max
causal, anticausal : (0.4685495929464313, 0.4685495929464313)
score for penalty using stat:variances & variability meas: mean
causal, anticausal : (0.33217980597312113, 0.33217980597312113)
score for penalty using stat:variances & variability meas: max
causal, anticausal : (0.4324882764397, 0.4324882764397)
score for penalty using stat: norm
causal, anticausal : (0.46847618584880807, 0.46847618584880807)
-- These differences are broken down as --
diff due to mmd: 0.015163770876824856 vs 0.01849326677620411
diff due to  var: [0.17000694 0.18697478 0.14205549 0.1317832  0.16244683 0.14096985
 0.14969817 0.14507339 0.14584018 0.16119585 0.14158907 0.14770648
 0.17075472 0.13504725 0.15824557 0.15537202 0.1613862  0.16895653
 0.13825087 0.14922663 0.1602183  0.15532074 0.13677896 0.13683191
 0.14075733 0.1368337  0.13805265 0.13555197 0.14412579 0.12058141
 0.1689225

100%|██████████| 1000/1000 [00:06<00:00, 165.48it/s, idx=pair2, score=(0.015637231990695, 0.01745494455099106)]   
  5%|▍         | 23/500 [00:00<00:02, 200.78it/s]

score for penalty using stat:quantiles & variability meas: mean


100%|██████████| 500/500 [00:01<00:00, 264.37it/s]
  2%|▏         | 17/1000 [00:00<00:05, 164.69it/s, idx=pair3, score=(0.3520370125770569, 0.33479559421539307)]

causal, anticausal : (0.4933817121087555, 0.4933817121087555)
score for penalty using stat:quantiles & variability meas: max
causal, anticausal : (0.45121517848318016, 0.45121517848318016)
score for penalty using stat:variances & variability meas: mean
causal, anticausal : (0.4945221544840565, 0.4945221544840565)
score for penalty using stat:variances & variability meas: max
causal, anticausal : (0.46029019300056273, 0.46029019300056273)
score for penalty using stat: norm
causal, anticausal : (0.5040058162610129, 0.5040058162610129)
-- These differences are broken down as --
diff due to mmd: 0.015412960201501846 vs 0.01554118748754263
diff due to  var: [0.1319367  0.19873039 0.13074463 0.15316173 0.15031584 0.14548705
 0.13183881 0.14302304 0.13683326 0.11294561 0.19285678 0.13743827
 0.15802739 0.18987669 0.11936425 0.11043453 0.15120083 0.12567828
 0.15895045 0.16019084 0.15002007 0.14254457 0.1963059  0.1147765
 0.15583928 0.12806747 0.14127679 0.18488986 0.11365275 0.14387335
 0.16

100%|██████████| 1000/1000 [00:05<00:00, 169.06it/s, idx=pair3, score=(0.013105934485793114, 0.018739428371191025)]
  6%|▌         | 28/500 [00:00<00:01, 279.19it/s]

score for penalty using stat:quantiles & variability meas: mean


100%|██████████| 500/500 [00:01<00:00, 255.98it/s]

causal, anticausal : (0.4554370854659009, 0.4554370854659009)
score for penalty using stat:quantiles & variability meas: max
causal, anticausal : (0.5342508543863367, 0.5342508543863367)
score for penalty using stat:variances & variability meas: mean
causal, anticausal : (0.41234623615585086, 0.41234623615585086)
score for penalty using stat:variances & variability meas: max
causal, anticausal : (0.48263232639963555, 0.48263232639963555)
score for penalty using stat: norm
causal, anticausal : (0.45908668726692725, 0.45908668726692725)
-- These differences are broken down as --
diff due to mmd: 0.01617363840341568 vs 0.02204710803925991
diff due to  var: [0.8158237  0.62843782 0.76222115 0.56179453 0.47254979 0.72153066
 0.66191762 0.62176047 0.39529219 0.58535093 0.52634409 0.48221911
 0.69253014 0.67612919 0.57760569 0.60922053 0.61579708 0.59426238
 0.87151471 0.73820918 0.53570785 0.39730384 0.77801868 0.49653563
 0.78246082 0.6838687  0.53061101 0.68247832 0.56253246 0.57818622
 0.


