In [1]:
import numpy as np
import torch
import argparse
from sklearn.preprocessing import scale
from pandas import DataFrame, Series
from cdt.data import load_dataset

Detecting 1 CUDA device(s).


In [2]:
import os
import sys
# Put the directory two levels above the working directory on sys.path
# (once only) so the project-local packages imported below can be resolved.
module_path = os.path.abspath('../..')
already_registered = module_path in sys.path
if not already_registered:
    sys.path.append(module_path)

In [3]:
from functions.tcep_utils import cut_num_pairs,_get_wd, _get_nc
from causal.generative.mmdgen.gnn import CausalMmdNet

Detecting 1 CUDA device(s).


In [4]:
from itertools import product

In [5]:
# Hyperparameters shared by every model fitted in this notebook.
nh = 20           # forwarded as CausalMmdNet(nh=...)
lr = 0.01         # forwarded as CausalMmdNet(lr=...)
n_kernels = 5     # forwarded as CausalMmdNet(n_kernels=...)
epochs = 1000     # forwarded as fit_two_directions(train_epochs=...)

In [6]:
# Seed NumPy's and PyTorch's global random generators with the same fixed value.
np.random.seed(seed=0)
torch.manual_seed(seed=0)

<torch._C.Generator at 0x7fe0c653c910>

In [7]:
# Load the 'tuebingen' dataset through cdt with shuffling disabled.
# NOTE(review): `data` is later accessed via .iloc and the columns 'A' / 'B',
# so it is presumably a DataFrame holding one variable pair per row — confirm.
data, labels = load_dataset('tuebingen', shuffle=False)

In [8]:
# Upper bound handed to cut_num_pairs as `num_max`.
MAX_N_PAIR = 1000
# The return value is discarded, so this presumably trims `data` in place
# — TODO confirm against functions.tcep_utils.cut_num_pairs.
cut_num_pairs(data, num_max=MAX_N_PAIR)

In [9]:
# Fit a CausalMmdNet on each selected pair and print the penalised scores
# for every (statistic, variability measure) combination plus the 'norm' penalty.

# Positional (iloc) range of rows of `data` to process.
slice_lo, slice_up = 0, 3
# Forwarded as `n_cause` to generate_conditional_sampling below.
n_c = 1000
# Second argument values tried for add_penalty.
variabilities = ['mean', 'max']
# First argument values tried for add_penalty.
stats = ['quantiles', 'variances']


for i, row in data.iloc[slice_lo:slice_up].iterrows():
    # Length of the 'A' series; also the input from which weight decay is derived.
    n_samples = row['A'].shape[0]

    print(10*'-<>-')
    print(n_samples, row.shape, i, '<-- shape and # of pair')
    wd = _get_wd(n_samples)

    # Standardise both variables and stack them row-wise: row 0 is 'A', row 1 is 'B'.
    pair = np.vstack((scale(row['A']), scale(row['B'])))

    mmdNet = CausalMmdNet(nh=nh, lr=lr, n_kernels=n_kernels, weight_decay=wd)

    # set & fit
    mmdNet.set_data(pair)
    mmdNet.fit_two_directions(train_epochs=epochs, idx=i)

    # sample for testing
    # Use the configured n_c instead of a duplicated literal so that changing
    # the setting above actually takes effect.
    mmdNet.generate_conditional_sampling(pair=pair, n_cause=n_c, sampling_type='sample')
    # compute vars & qs's
    mmdNet.estimate_conditional_var()
    mmdNet.estimate_conditional_quants()

    for stat, varm in product(stats, variabilities):
        print(f"score for penalty using stat:{stat} & variability meas: {varm}")
        print(f"causal, anticausal : {mmdNet.add_penalty(stat, varm)}")

    # No placeholders here, so a plain string is used; the printed text is unchanged.
    print("score for penalty using stat: norm")
    print(f"causal, anticausal : {mmdNet.add_penalty('norm', None)}")

    print(10*'~*')

    # Optional diagnostic breakdown of the scores, kept disabled.
    #print('-- These differences are broken down as --')
    #
    #print(f"diff due to mmd: {mmdNet.mmd_score_causal} vs {mmdNet.mmd_score_anticausal}")
    #print(f"diff due to  variabilities: ---")
    #for stat, varm in product(stats, variabilities):
    #    print(f"score for penalty using stat:{stat} & variability meas: {varm}")
    #    print(f"causal, anticausal : {mmdNet.compute_variability(stat, varm)}")
    #
    #norm_causal,norm_anticausal = mmdNet.penalize_weight(p=2)
    #print(f"diff due to norm: {norm_causal} vs {norm_anticausal}")

-<>--<>--<>--<>--<>--<>--<>--<>--<>--<>-
349 (2,) pair1 <-- shape and # of pair


100%|██████████| 1000/1000 [00:05<00:00, 167.92it/s, idx=pair1, score=(0.016290457919239998, 0.01596209965646267)]
  5%|▌         | 27/500 [00:00<00:02, 217.25it/s]

score for penalty using stat:quantiles & variability meas: mean


100%|██████████| 500/500 [00:01<00:00, 263.39it/s]
  2%|▏         | 17/1000 [00:00<00:05, 165.49it/s, idx=pair2, score=(0.2704072594642639, 0.3066158890724182)]

mmd: c=0.015163770876824856, ac=0.01849326677620411
penalty: c=0.01684212258691206, ac=0.04207919043406938
scores before norm: c=0.023584832170280887, ac=0.0395328619932388
normaliz. factor = 0.06311769416351969
scores after norm: c=0.3736643501136054, ac=0.6263356498863947
causal, anticausal : (0.3736643501136054, 0.6263356498863947)
score for penalty using stat:quantiles & variability meas: max
mmd: c=0.015163770876824856, ac=0.01849326677620411
penalty: c=0.10718702254030625, ac=0.1189887831910052
scores before norm: c=0.06875728214697799, ac=0.07798765837170671
normaliz. factor = 0.1467449405186847
scores after norm: c=0.4685495929464313, ac=0.5314504070535687
causal, anticausal : (0.4685495929464313, 0.5314504070535687)
score for penalty using stat:variances & variability meas: mean
mmd: c=0.015163770876824856, ac=0.01849326677620411
penalty: c=0.05229739748045858, ac=0.1291241148331416
scores before norm: c=0.04131246961705415, ac=0.08305532419277491
normaliz. factor = 0.12436779

100%|██████████| 1000/1000 [00:06<00:00, 166.67it/s, idx=pair2, score=(0.015637231990695, 0.01745494455099106)]   
  5%|▍         | 23/500 [00:00<00:02, 201.04it/s]

score for penalty using stat:quantiles & variability meas: mean


100%|██████████| 500/500 [00:01<00:00, 256.27it/s]
  2%|▏         | 16/1000 [00:00<00:06, 158.88it/s, idx=pair3, score=(0.3520370125770569, 0.33479559421539307)]

mmd: c=0.015412960201501846, ac=0.01554118748754263
penalty: c=0.036040307100906076, ac=0.03757775746271788
scores before norm: c=0.03343311375195489, ac=0.03433006621890157
normaliz. factor = 0.06776317997085646
scores after norm: c=0.4933817121087555, ac=0.5066182878912445
causal, anticausal : (0.4933817121087555, 0.5066182878912445)
score for penalty using stat:quantiles & variability meas: max
mmd: c=0.015412960201501846, ac=0.01554118748754263
penalty: c=0.11508149659405381, ac=0.1463756973882655
scores before norm: c=0.07295370849852875, ac=0.08872903618167538
normaliz. factor = 0.16168274468020413
scores after norm: c=0.45121517848318016, ac=0.5487848215168198
causal, anticausal : (0.45121517848318016, 0.5487848215168198)
score for penalty using stat:variances & variability meas: mean
mmd: c=0.015412960201501846, ac=0.01554118748754263
penalty: c=0.13853176312402965, ac=0.14202727484460909
scores before norm: c=0.08467884176351667, ac=0.08655482490984717
normaliz. factor = 0.171

100%|██████████| 1000/1000 [00:05<00:00, 166.89it/s, idx=pair3, score=(0.013105934485793114, 0.018739428371191025)]
  6%|▌         | 28/500 [00:00<00:01, 273.89it/s]

score for penalty using stat:quantiles & variability meas: mean


100%|██████████| 500/500 [00:01<00:00, 257.98it/s]

mmd: c=0.01617363840341568, ac=0.02204710803925991
penalty: c=0.03423545445844574, ac=0.035518286619560746
scores before norm: c=0.03329136563263855, ac=0.039806251349040284
normaliz. factor = 0.07309761698167883
scores after norm: c=0.4554370854659009, ac=0.5445629145340992
causal, anticausal : (0.4554370854659009, 0.5445629145340992)
score for penalty using stat:quantiles & variability meas: max
mmd: c=0.01617363840341568, ac=0.02204710803925991
penalty: c=0.14311887717406346, ac=0.10887364654584233
scores before norm: c=0.08773307699044741, ac=0.07648393131218108
normaliz. factor = 0.16421700830262848
scores after norm: c=0.5342508543863367, ac=0.4657491456136634
causal, anticausal : (0.5342508543863367, 0.4657491456136634)
score for penalty using stat:variances & variability meas: mean
mmd: c=0.01617363840341568, ac=0.02204710803925991
penalty: c=0.16226256527368257, ac=0.23325330440789158
scores before norm: c=0.09730492104025697, ac=0.1386737602432057
normaliz. factor = 0.2359786


