In [11]:
import copy 
import os 
import shutil
import itertools
from collections import Counter
import json
import pickle
import pprint
import pandas as pd
import numpy as np

In [12]:
import matplotlib.pyplot as plt

# Default (width, height) in inches for every figure drawn in this notebook
FIGSIZE = [6, 8]
plt.rcParams['figure.figsize'] = FIGSIZE

In [13]:
cwd_entries = os.listdir(os.getcwd())
print(cwd_entries)

['.DS_Store', '.ipynb_checkpoints', '0226_pre-standard', '0409_adult-alldata', '0409_german-alldata', '0410_adult-nodata', '0410_german-nodata', '0413_reddata1000', '0501_reddata10000', '0501_reddata2000', '0501_reddata20000', '0501_reddata5000', '0505_reddata1-20t', 'analysis.ipynb', 'backwards_compatibility', 'plotting.ipynb', 'processed_results']


In [15]:
# Results directory to post-process, and an optional suffix for the output folder name
res_dir = '0505_reddata1-20t'
extra = ''

res_path = os.path.join(os.getcwd(), res_dir)
expdir = os.path.join(res_path, 'causal_discovery')               # raw-results input
savedir = os.path.join(res_path, '{}_{}'.format(res_dir, extra))  # processed output

# Start every run from an empty output folder
if os.path.exists(savedir):
    shutil.rmtree(savedir)
os.mkdir(savedir)



# Utility Functions 

In [16]:
def dataset_name_from_unid(uid):
    '''Map an experiment unique id to the name of its dataset.

    Returns 'adult' when 'adult' occurs in `uid`, otherwise 'germanCredit' when
    'german' occurs in it.

    Raises:
        ValueError: if `uid` mentions neither dataset. An explicit raise is used
            because a bare `assert` disappears under `python -O`, which would make
            the function silently return None.
    '''
    if 'adult' in uid:
        return 'adult'
    if 'german' in uid:
        return 'germanCredit'
    raise ValueError('cannot determine dataset from unique id: {!r}'.format(uid))

In [17]:
def get_hps_from_rawres(fname):
    '''Parse a raw-results filename into its hyperparameters.

    The part of `fname` between 'rawres_' and '.json' is split on '_' and read as
    ``<alpha>_<feateng>_<dataset>_<redsize>_<seed>_<environment>``. Any fields
    after the sixth are ignored, so an environment name containing '_' would be
    truncated.

    Returns:
        (feateng, dataset, seed, environment, redsize) as strings. Note that this
        order differs from the order of the fields in the filename.

    Raises:
        IndexError: if `fname` lacks 'rawres_' or has fewer than six fields.
    '''
    unique_id = fname.split('rawres_')[1].split('.json')[0]
    fields = unique_id.split('_')  # fields[0] is alpha, which callers do not need
    feateng = fields[1]
    dataset = fields[2]
    redsize = fields[3]
    seed = fields[4]
    environment = fields[5]
    return feateng, dataset, seed, environment, redsize

In [18]:
def str_2_pcp(pcpstr):
    '''Parse a subset string such as "(a, b)" into the set {'a', 'b'}.

    Only the text after the first '(' and up to the next '(' or ')' is used.
    Spaces are removed, and empty entries (from "()" or a trailing comma) are
    dropped.
    '''
    body = pcpstr.split('(')[1]
    body = body.split(')')[0].replace(' ', '')
    return {token for token in body.split(',') if token}

In [19]:
import enum

# Stage 1: coarse search for an alpha at which causal predictors exist
START_ALPHA = 1.0  # first alpha probed
FACTOR = 2         # multiplicative step used until the target is bracketed
EPS = 1e-20        # give up once the stage-1 bracket is narrower than this
# Stages 2-3: refine the upper and lower edges of the alpha interval
STEP = 1e-2        # initial additive step away from the stage-1 alpha
FACTOR2 = 2        # step growth / shrink factor
EPS2 = 1e-10       # an edge counts as found once its bracket is narrower than this


class POS(enum.Enum):
    '''Position of a probed alpha relative to the interval with causal predictors.'''
    big = 1    # too large: no subset is accepted
    small = 2  # too small: the accepted subsets have an empty intersection
    perf = 3   # the accepted subsets share at least one predictor


def _accepted_intersection(pVals, alpha):
    '''Return (number of subsets accepted at `alpha`, their intersection).

    A subset (an index entry of `pVals`, parsed with `str_2_pcp`) is accepted
    when its 'Final_tstat' is strictly greater than `alpha`. The intersection is
    an empty set when no subset is accepted.
    '''
    accepted = pVals[pVals['Final_tstat'] > alpha]
    accepted_sets = [str_2_pcp(subset) for subset in accepted.index]
    if not accepted_sets:
        return 0, set()
    return len(accepted_sets), set.intersection(*accepted_sets)


def alpha_tune(pVals, amin, flag=0):
    '''Find the alpha interval over which the accepted subsets share a predictor.

    A subset is "accepted" at alpha when its 'Final_tstat' is strictly greater
    than alpha. The search runs in three stages:

    1. Starting from START_ALPHA, halve alpha while nothing is accepted and double
       it while the accepted subsets have an empty intersection (bisecting once
       bracketed), until the intersection is non-empty. This gives `a0`.
    2. Walk upwards from `a0` with a growing/shrinking step to locate the largest
       alpha at which some subset is still accepted (`a1`).
    3. Walk downwards from `a0` the same way to locate the smallest alpha at which
       the intersection is still non-empty (`a2`).

    Args:
        pVals: DataFrame indexed by subset strings, with a 'Final_tstat' column.
        amin: if the lower edge `a2` is below this, (-1, -1) is returned.
        flag: when truthy, print the causal predictors and `a0` found in stage 1.

    Returns:
        (start, stop): the interval [a2, a1] widened by a fifth of its length on
        each side, with start clipped at 0. Returns (-1, -1) when stage 1 finds no
        alpha with causal predictors, or when `a2 < amin`.
    '''
    # --- Stage 1: find some alpha a0 with a non-empty intersection ---
    a0 = START_ALPHA
    bounds0 = [0, 100.0]
    while True:
        n_accepted, causal_preds = _accepted_intersection(pVals, a0)
        if n_accepted == 0:
            pos = POS.big
        elif not causal_preds:
            pos = POS.small
        else:
            if flag:
                print(causal_preds)
                print(a0)
            break

        prev_a0 = a0
        if pos == POS.big:
            # alpha too large: it becomes the upper bound; halve, or bisect if that
            # would fall below the lower bound
            bounds0[1] = a0
            if a0 / FACTOR <= bounds0[0]:
                a0 = a0 - abs((a0 - bounds0[0]) / 2)
            else:
                a0 = a0 / FACTOR
        else:
            # alpha too small: it becomes the lower bound; double, or bisect if that
            # would exceed the upper bound
            bounds0[0] = a0
            if a0 * FACTOR >= bounds0[1]:
                a0 = a0 + abs((a0 - bounds0[1]) / 2)
            else:
                a0 = a0 * FACTOR

        # No causal predictors at any alpha. Either the bracket collapsed, or a0 can
        # no longer move because the bracket is a single float ulp wide. EPS is far
        # below double precision for most alphas, so without the no-progress check
        # the loop would repeat the same state forever.
        if abs(bounds0[0] - bounds0[1]) < EPS or a0 == prev_a0:
            return (-1, -1)

    # --- Stage 2: upper edge a1 (largest alpha that still accepts a subset) ---
    # Above a0 fewer subsets are accepted, so any non-empty acceptance still has a
    # non-empty intersection; only the count needs checking.
    upperB = [a0, 100]
    a1 = a0
    step = STEP
    while abs(upperB[0] - upperB[1]) > EPS2:
        a1 = a1 + step
        if (pVals['Final_tstat'] > a1).any():
            # still inside: move the bracket's low end up, grow the step unless it
            # would overshoot the known-outside alpha
            upperB[0] = a1
            if a1 + abs(step * FACTOR2) >= upperB[1]:
                step = abs(a1 - upperB[1]) / FACTOR2
            else:
                step = abs(step * FACTOR2)
        else:
            # nothing accepted: move the bracket's high end down and step back
            upperB[1] = a1
            if a1 - abs(step * FACTOR2) <= upperB[0]:
                step = -1 * abs(a1 - upperB[0]) / FACTOR2
            else:
                step = -1 * abs(step * FACTOR2)

    # --- Stage 3: lower edge a2 (smallest alpha with a non-empty intersection) ---
    lowerB = [0, a0]
    a2 = a0
    step = STEP if a2 - STEP > EPS else a2 / FACTOR2
    while abs(lowerB[0] - lowerB[1]) > EPS2:
        a2 = a2 - step
        _, causal_preds = _accepted_intersection(pVals, a2)
        if causal_preds:
            # still inside: move the bracket's high end down, grow the step unless
            # it would overshoot the known-outside alpha
            lowerB[1] = a2
            if a2 - abs(step * FACTOR2) <= lowerB[0]:
                step = abs(a2 - lowerB[0]) / FACTOR2
            else:
                step = abs(step * FACTOR2)
        else:
            # empty intersection: move the bracket's low end up and step back.
            # Compare against a2, the lower-edge probe (this previously used a1).
            lowerB[0] = a2
            if a2 + abs(step * FACTOR2) >= lowerB[1]:
                step = -1 * abs(a2 - lowerB[1]) / FACTOR2
            else:
                step = -1 * abs(step * FACTOR2)

    # Interval too close to 0 to be meaningful
    if a2 < amin:
        return (-1, -1)

    # Pad the interval by a fifth of its width on each side
    interval = abs(a1 - a2) / 5

    assert (a2 < a0) and (a0 < a1)

    return (max(0, a2 - interval), a1 + interval)


# File Generation

In [20]:
# Collect every raw-results file in the experiment directory
rawres_files = [entry for entry in os.listdir(expdir) if 'rawres_' in entry]

# Parameters

In [21]:
# Generate the alpha sweep range for every raw-results file.
NUM_POINTS = 100   # number of alphas in each experiment's sweep
MIN_ALPHA = 1e-4   # passed to alpha_tune as `amin` (lower edges below this give (-1, -1))

alphas = {}
for fname in rawres_files:
    f, d, s, e, rd = get_hps_from_rawres(fname)
    try:
        with open(os.path.join(expdir, fname), 'rb') as fh:
            pvals = json.load(fh)
        del pvals["()"]  # drop the entry for the empty subset "()"
    except (OSError, ValueError, KeyError, TypeError):
        # unreadable or malformed results file (or no "()" entry): skip it
        continue
    pvals = pd.DataFrame.from_dict(pvals, orient='index')
    alphas[(f, rd, s, d, e)] = list(alpha_tune(pvals, MIN_ALPHA)) + [NUM_POINTS]

alphas = pd.DataFrame(alphas).T
alphas.columns = ['start', 'stop', 'num_points']
alphas['num_points'] = alphas['num_points'].astype(int)
alphas.index.names = ['feateng', 'reddata', 'seed', 'dataset', 'env']
alphas.head(1000)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,start,stop,num_points
feateng,reddata,seed,dataset,env,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1,1000,1000,adult,marital-status,-1.000000,-1.000000,100.0
1,1000,1000,adult,native-country,0.024094,0.055310,100.0
1,1000,1000,adult,occupation,-1.000000,-1.000000,100.0
1,1000,1000,adult,relationship,-1.000000,-1.000000,100.0
1,1000,1000,adult,workclass,0.369009,0.545112,100.0
1,1000,147,adult,marital-status,-1.000000,-1.000000,100.0
1,1000,147,adult,native-country,0.019480,0.041083,100.0
1,1000,147,adult,occupation,-1.000000,-1.000000,100.0
1,1000,147,adult,relationship,-1.000000,-1.000000,100.0
1,1000,147,adult,workclass,1.491107,3.266724,100.0


In [45]:
# Distinct values of each hyperparameter present in `alphas`, sorted so the
# processing order (and hence result dict order) is reproducible across kernels
feateng = sorted({str(x) for x in alphas.index.get_level_values('feateng')})
reddata = sorted({str(x) for x in alphas.index.get_level_values('reddata')})
seed = sorted({str(x) for x in alphas.index.get_level_values('seed')})
dataset = sorted({str(x) for x in alphas.index.get_level_values('dataset')})
environment = sorted({str(x) for x in alphas.index.get_level_values('env')})

# Materialised as a list: a bare itertools.product is a one-shot iterator, so
# re-running the sweep cell without re-running this one would process nothing.
# Tuple order is (feateng, dataset, seed, environment, reddata).
available_exps = list(itertools.product(feateng, dataset, seed, environment, reddata))

In [46]:
# Results keyed by (feateng, seed, dataset, env, reddata):
#   x_axis[key]       - alphas of the sweep
#   y_axis[key]       - number of predictors in the accepted-set intersection at each alpha
#   CPid_results[key] - how many alphas of the sweep each predictor appeared at
x_axis = {}
y_axis = {}
CPid_results = {}

# Parse every filename once instead of once per (experiment, file) pair
parsed_rawres = [(fname, get_hps_from_rawres(fname)) for fname in rawres_files]

for exp in available_exps:
    for fname, hps in parsed_rawres:
        # exp and hps are both (feateng, dataset, seed, environment, reddata)
        if hps != tuple(exp):
            continue
        f, d, s, e, rd = hps
        try:
            with open(os.path.join(expdir, fname), 'rb') as fh:
                pvals = json.load(fh)
            del pvals["()"]  # drop the entry for the empty subset "()"
        except (OSError, ValueError, KeyError, TypeError):
            continue
        pvals = pd.DataFrame.from_dict(pvals, orient='index')

        key = (f, s, d, e, rd)
        x_axis[key] = []
        y_axis[key] = []
        CPid_results[key] = Counter()

        # Look the sweep up by column name; the row is float64, and np.linspace
        # requires an integer number of points.
        row = alphas.loc[(f, rd, s, d, e)]
        start, stop = row['start'], row['stop']
        num_points = int(row['num_points'])
        for alpha in np.linspace(start, stop, num_points):
            accepted = pvals[pvals['Final_tstat'] > alpha]
            if len(accepted.index) > 100000:
                raise ValueError('too many subsets: {}'.format(len(accepted.index)))

            accepted_sets = [str_2_pcp(subset) for subset in accepted.index]
            pcps = set.intersection(*accepted_sets) if accepted_sets else set()

            x_axis[key].append(alpha)
            y_axis[key].append(len(pcps))
            CPid_results[key].update({pcp: 1 for pcp in pcps})





# Save Results

In [48]:
# Pickle each result dict into savedir; `with` guarantees each file is flushed and closed
for result_name, result_obj in (('x_axis', x_axis),
                                ('y_axis', y_axis),
                                ('CPid_results', CPid_results)):
    with open(os.path.join(savedir, result_name), 'wb') as fh:
        pickle.dump(result_obj, fh)

# Appendix

## CALIBRATION

In [None]:
# Calibration: for every experiment, print how many subsets are accepted at each
# alpha of its sweep and what they have in common, to sanity-check the ranges
# chosen by alpha_tune. Experiments are separated by a line of '#'.
for exp in itertools.product(feateng, dataset, seed, environment, reddata):
    for fname in rawres_files:
        f, d, s, e, rd = get_hps_from_rawres(fname)
        if (f, d, s, e, rd) != exp:
            continue
        unid = '{}_{}_{}_{}_{}'.format(f, d, s, e, rd)
        try:
            with open(os.path.join(expdir, fname), 'rb') as fh:
                pvals = json.load(fh)
            del pvals["()"]  # drop the entry for the empty subset "()"
        except (OSError, ValueError, KeyError, TypeError):
            continue
        pvals = pd.DataFrame.from_dict(pvals, orient='index')

        # np.linspace needs an integer count; the alphas row is float64
        row = alphas.loc[(f, rd, s, d, e)]
        for alpha in np.linspace(row['start'], row['stop'], int(row['num_points'])):
            accepted = pvals[pvals['Final_tstat'] > alpha]
            n_accepted = len(accepted.index)
            if n_accepted == 0:
                print(alpha, unid, 0, 'null')
            elif n_accepted < 1000:
                accepted_sets = [str_2_pcp(subset) for subset in accepted.index]
                print(alpha, unid, n_accepted, set.intersection(*accepted_sets))
            else:
                # intersection skipped for 1000+ accepted subsets
                print(alpha, unid, n_accepted, 'too_many_intersections')

    print('#####################################')