In [1]:
import copy 
import os 
import itertools
from collections import Counter
import json
import pickle
import pprint
import matplotlib
import matplotlib.pyplot as plt 
import pandas as pd
import numpy as np
res_dir = '0413_reddata'
expdir = os.path.join(os.path.join(os.getcwd(), res_dir), 'causal_discovery')

# Utility Functions 

In [2]:
def dataset_name_from_unid(uid):
    if 'adult' in uid:
        return 'adult'
    if 'german' in uid:
        return 'germanCredit'
    
    assert True == False 

In [3]:
def get_hps_from_rawres(fname):
    '''rawres fname -> features'''
    unique_id = (fname.split('rawres_')[1]).split('.json')[0]
    alpha = unique_id.split('_')[0]
    feateng = unique_id.split('_')[1]
    dataset = unique_id.split('_')[2]
    seed = unique_id.split('_')[3]
    environment = unique_id.split('_')[4]
    return feateng, dataset, seed, environment
    
     #print('{}_{}_{}_{}'.format(feateng, dataset, seed, environment))

In [4]:
def str_2_pcp(pcpstr):
    pcpstr = (pcpstr.split('(')[1]).split(')')[0]
    pcpstr = pcpstr.replace(' ', '')
    ret = set(pcpstr.split(','))
    ret.discard('')
    return ret

In [5]:
def eligible_exps(queries, e_list):
    '''Gets all experiments from e_list with the queries attributes
    param: elist: list of all posible expeirment keys 
    param: queries: list of terms which must be elements of e_list tuples'''
    ret = []
    for exp in e_list:
        flag = True
        for q in queries: 
            if q not in exp:
                flag = False
        if flag:
            ret.append(exp)
    return ret

In [6]:
def norm_ctr(Ctr, n):
    '''Returns sorted list (ctr_name, p) of the n most common elements in Ctr, where p is normalized freq'''
    sort = [list(x) for x in Ctr.most_common(n)]
    norm = len(list(Ctr.elements()))
    for t in sort:
        t[1] = t[1]/norm
    
    return sort 

def add_slist(s1, s2):
    'For all common-keyed tuples in s1-2, add prob values. For new ones, add to returned list'
    s1_keys = [x[0] for x in s1]
    ret = copy.deepcopy(s1)
    
    for t in s2: 
        if t[0] not in s1_keys:
            ret.append(t)
        else:
            for cp in ret:
                if cp[0] == t[0]:
                    cp[1] += t[1]
    return ret
    
def scale_slist(s, nf):
    for t in s:
        t[1] = t[1]/nf
    return s

def sqrt_slist(s):
    for t in s:
        t[1] = np.sqrt(t[1])
    return s

def sqdiff_slist(s1, means):
    '''For every value of m, find the (m, (s-m)^2)'''
    s1_keys = [s[0] for s in s1]
    ret = []
    
    for m in means:
        if m[0] in s1_keys: 
            for s in s1:
                if m[0] == s[0]:
                    ret.append([m[0], ((m[1]-s[1])**2)])
        else:
            ret.append([m[0], (m[1]**2)])
    
    return ret

def mean_slist(all_exps, results):
    '''For all specified experiments, find the list of (CPid, prob) for each and average them'''
    cp = []
    norm = len(all_exps)
    for exp in all_exps: 
        cp = add_slist(cp, norm_ctr(results[exp], None)) #Add the causal predictors to each 
    scaled = scale_slist(cp, norm)
    return sorted(scaled, key=lambda x:x[1], reverse=True)
        

def var_slist(all_exps, results):
    cp = []
    norm = len(all_exps)
    means = mean_slist(all_exps, results)
    
    for exp in all_exps: 
        cp = add_slist(cp, sqdiff_slist(norm_ctr(results[exp], None), means))

    cp = sqrt_slist(scale_slist(cp, norm))
    
    #Sort in same order as means 
    ret = []
    for m in means:
        for e in cp:
            if m[0] == e[0]:
                ret.append(e)
    
    return ret


In [80]:
import enum 
#Part 1
START_ALPHA = 1.0
FACTOR = 2
EPS = 1e-20
#Part 2
STEP = 1e-2
FACTOR2 = 2
EPS2 = 1e-10

class POS(enum.Enum):
   big = 1
   small = 2
   perf = 3

def alpha_tune(pVals):
    #First find a CP returning alpha 
    a0 = START_ALPHA
    bounds0 = [0, 100.0]
    cp_ret = False 
    while not cp_ret:
        pos = 0
        accepted = pVals[pVals['Final_tstat'] > a0]
        
        #Determine position of alpha 
        if len(accepted.index) == 0:
            pos = POS.big
        else: 
            accepted_sets = [str_2_pcp(a) for a in list(accepted.index)]
            causal_preds = set.intersection(*accepted_sets)
            if len(causal_preds) == 0:
                pos = POS.small 
            else:
                pos = POS.perf
                cp_ret = True
                continue
        
#         print(a0, bounds0, pos)
        #Determine what alpha to check next 
        if pos == POS.big:
            bounds0[1] = a0
            if a0/FACTOR <= bounds0[0]:
                a0 = a0 - abs((a0 - bounds0[0])/2)
            else:
                a0 = a0/FACTOR
        elif pos == POS.small:
            bounds0[0] = a0
            if a0 * FACTOR >= bounds0[1]:
                a0 = a0 + abs((a0 - bounds0[1])/2)
            else:
                a0 = a0 * FACTOR
        
        #Stability check in case no CPs 
        if abs(bounds0[0] - bounds0[1]) < EPS:
            return None
    
    #Then establish interval bounds 
    lowerB = [0, a0]
    upperB = [a0, 100]
    
    #Upper Bound
    a1 = a0
    step = STEP
    pos = POS.perf
    while abs(upperB[0] - upperB[1]) > EPS2:
        a1 = a1 + step
        accepted = pVals[pVals['Final_tstat'] > a1]
        
        #Determine position of alpha 
        if len(accepted.index) == 0:
            pos = POS.big
        else:
            pos = POS.perf
        
        #Determine what alpha to check next 
        if pos == POS.perf:
            upperB[0] = a1
            if a1 + abs(step * FACTOR2) >= upperB[1]:
                step = abs(a1 - upperB[1])/FACTOR2
            else:
                step = abs(step * FACTOR2) 
        elif pos == POS.big:
            upperB[1] = a1
            if (a1 - abs(step * FACTOR2)) <= upperB[0]:
                step = -1 * abs(a1 - upperB[0])/FACTOR2
            else:
                step = -1 * abs(step * FACTOR2) 
        else:
            assert False

    #Lower Bound
    a2 = a0
    if a2 - STEP > 1e-20:
        step = STEP
    else: 
        step = a2/FACTOR2 
    pos = POS.perf
    while abs(lowerB[0] - lowerB[1]) > EPS2:
        print(a2, lowerB)

        a2 = a2 - step
        accepted = pVals[pVals['Final_tstat'] > a2]
        
        #Determine position of alpha 
        accepted_sets = [str_2_pcp(a) for a in list(accepted.index)]
        causal_preds = set.intersection(*accepted_sets)
        if len(causal_preds) == 0:
            pos = POS.small 
        else:
            pos = POS.perf       
        
        #Determine what alpha to check next 
        if pos == POS.perf:
            lowerB[1] = a2
            if a2 - abs(step * FACTOR2) <= lowerB[0]:
                step = abs(a2 - lowerB[0])/FACTOR2
            else:
                step = abs(step * FACTOR2) 
        elif pos == POS.small:
            lowerB[0] = a2
            if (a1 + abs(step * FACTOR2)) >= lowerB[1]:
                step = -1 * abs(a2 - lowerB[1])/FACTOR2
            else:
                step = -1 * abs(step * FACTOR2) 
        else:
            assert False
            
    return (a2, a1)


# # File Generation

In [7]:
#Collect all files appropiate to each unique identifier 
rawres_files= []
for f in os.listdir(expdir):
    if ('rawres_' in f):
        rawres_files.append(f)

In [9]:
total_combos = []
for fname in rawres_files:
    f, d, s, e = get_hps_from_rawres(fname) 
    total_combos.append((s, d, e))
total_combos = pd.DataFrame({x:[0] for x in total_combos}).T
total_combos.head(1000)

Unnamed: 0,Unnamed: 1,Unnamed: 2,0
1000,adult,marital-status,0
1000,adult,native-country,0
1000,adult,occupation,0
1000,adult,relationship,0
1000,adult,workclass,0
147,adult,marital-status,0
147,adult,native-country,0
147,adult,occupation,0
147,adult,relationship,0
147,adult,workclass,0


# Parameters

In [53]:
feateng = ['12']
dataset = ['adult']  #['adult', germanCredit']
seed = ['8079']  #, '8079', '52', '147', '256', '784', '990', '587', '304','737']
environment = ['workclass'] 

available_exps = itertools.product(feateng, dataset, seed, environment)

alphas = {'8079':{'adult':{'workclass':(0.1, 0.6, 100),\
                           'occupation':(1e-6, 1e-4, 100),\
                           'native-country':(0.5, 1.5, 100),\
                           'marital-status':(1e-60, 1e-20, 100)},\
                  'german':{'Purpose':(3.5, 5.5, 100), \
                            'Housing':(1.6, 3.0, 100), \
                            'Telephone':(3.0, 4.0, 100), \
                            'Property':(4, 5.5, 100)} \
                 },
          '1000':{'adult':{'workclass':(0.1, 0.6, 100),\
                           'occupation':(1e-6, 1e-4, 100),\
                           'native-country':(1e-2, 0.15, 100),\
                           'marital-status':(1e-60, 1e-20, 100)},\
                  'german':{'Purpose':(3.5, 5.5, 100), \
                            'Housing':(1.6, 3.0, 100), \
                            'Telephone':(3.0, 4.0, 100), \
                            'Property':(4, 5.5, 100)} \
                 }
}

alphas = pd.DataFrame({(s, d, e):alphas[s][d][e] for s in alphas.keys() for d in alphas[s].keys() for e in alphas[s][d].keys()}).T
alphas.columns = ['start', 'stop', 'num_points']

# CALIBRATION

In [79]:
for exp in itertools.product(feateng, dataset, seed, environment):
    for fname in rawres_files:
        f, d, s, e = get_hps_from_rawres(fname)
        if (f == exp[0]) and (d == exp[1]) and (s == exp[2]) and (e == exp[3]):
            unid = '{}_{}_{}_{}'.format(f,d,s,e)
            try:
                pvals = json.load(open(os.path.join(expdir, fname), 'rb'))
                del pvals["()"]
            except:
                continue
            pvals = pd.DataFrame.from_dict(pvals, orient='index')
            print(alpha_tune(pvals))

0.125 [0, 0.125]
0.115 [0, 0.115]
0.095 [0.095, 0.115]
0.10500000000000001 [0.10500000000000001, 0.115]
0.11000000000000001 [0.10500000000000001, 0.11000000000000001]
0.10750000000000001 [0.10750000000000001, 0.11000000000000001]
0.10875000000000001 [0.10875000000000001, 0.11000000000000001]
0.10937500000000001 [0.10875000000000001, 0.10937500000000001]
0.1090625 [0.10875000000000001, 0.1090625]
0.10890625000000001 [0.10890625000000001, 0.1090625]
0.10898437500000001 [0.10890625000000001, 0.10898437500000001]
0.1089453125 [0.1089453125, 0.10898437500000001]
0.10896484375000001 [0.10896484375000001, 0.10898437500000001]
0.108974609375 [0.10896484375000001, 0.108974609375]
0.10896972656250001 [0.10896484375000001, 0.10896972656250001]
0.10896728515625001 [0.10896484375000001, 0.10896728515625001]
0.10896606445312501 [0.10896606445312501, 0.10896728515625001]
0.10896667480468751 [0.10896667480468751, 0.10896728515625001]
0.10896697998046875 [0.10896667480468751, 0.10896697998046875]
0.108

In [81]:
# #Plot Accepted subsets vs Alpha for specified hyperparams 

# #fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(w*5, int(l/w)*5)) #Note - is +2 for reason


for exp in itertools.product(feateng, dataset, seed, environment):
    for fname in rawres_files:
        f, d, s, e = get_hps_from_rawres(fname)
        if (f == exp[0]) and (d == exp[1]) and (s == exp[2]) and (e == exp[3]):
            unid = '{}_{}_{}_{}'.format(f,d,s,e)
            try:
                pvals = json.load(open(os.path.join(expdir, fname), 'rb'))
                del pvals["()"]
            except:
                continue
            pvals = pd.DataFrame.from_dict(pvals, orient='index')
            
            start, stop, num_points = alphas.loc[s, d, e][0], alphas.loc[s, d, e][1], alphas.loc[s, d, e][2]
            for a in np.linspace(start, stop, num_points): 
                accepted = pvals[pvals['Final_tstat'] > a]
                if len(accepted.index) == 0:
                    print(a, unid, 0, 'null')
                elif len(accepted.index) < 1000:
                    accepted_sets = list(accepted.index)
                    accepted_sets = [str_2_pcp(a) for a in accepted_sets]
                    print(a, unid, len(accepted.index), set.intersection(*accepted_sets))
                else:
                    print(a, unid, len(accepted.index), 'too_many_intersections')
            
    
    print('#####################################')

# Number PCPS Accepted

In [None]:
x_axis = {}  #x,y values  for plot of alpha vs #CPs 
y_axis = {}
CPid_results = {}  #Stores CPids of each expierment 
# assert len(list(itertools.product(feateng, dataset, seed, environment))) == 1

for exp in available_exps:
    for fname in rawres_files:
        f, d, s, e = get_hps_from_rawres(fname)  
        if (f == exp[0]) and (d == exp[1]) and (s == exp[2]) and (e == exp[3]):
            unid = '{}_{}_{}_{}'.format(f,d,s,e)
            try:
                pvals = json.load(open(os.path.join(expdir, fname), 'rb'))
                del pvals["()"]
            except:
                continue
            pvals = pd.DataFrame.from_dict(pvals, orient='index')
            
            #For Storing all the results 
            x_axis[(s, d, e)] = []
            y_axis[(s, d, e)] = []
            CPid_results[(s, d, e)] = Counter()
            norm = 0
            
            start, stop, num_points = alphas.loc[s, d, e][0], alphas.loc[s, d, e][1], alphas.loc[s, d, e][2]
            for a in np.linspace(start, stop, num_points): 
                accepted = pvals[pvals['Final_tstat'] > a]
                if len(accepted.index) > 100000:
                    raise ValueError('too many subsets: {}'.format(len(accepted.index)))
                
                accepted_sets = list(accepted.index)
                accepted_sets = [str_2_pcp(a) for a in accepted_sets]
                if len(accepted_sets) > 0:
                    pcps = set.intersection(*accepted_sets)
                else:
                    pcps = set([])
                
                #Number of Accepted Sets 
                x_axis[(s, d, e)].append(a)
                if len(accepted_sets) == 0:
                    y_axis[(s, d, e)].append(0)
                else:
                    y_axis[(s, d, e)].append(len(set.intersection(*accepted_sets)))
                    
                #Causal predictor  
                for pcp in pcps: 
                    CPid_results[(s, d, e)].update({pcp:1})
                    
                
#                 if len(pcps) > 0:
#                     norm += 1
#                     for pcp in pcps: 
#                         if pcp not in CPid_results[(s, d, e)]:
#                             CPid_results[(s, d, e)][pcp] = 1 
#                         else:
#                             CPid_results[(s, d, e)][pcp] += 1
                        
#             #Causal Predictor More  
#             for cat in CPid_results[(s, d, e)]:
#                 CPid_results[(s, d, e)][cat] = CPid_results[(s, d, e)][cat]/norm


In [None]:
norm_ctr(CPid_results[('8079', 'adult', 'workclass')], 5)

In [None]:
matplotlib.rcParams.update({'font.size': 26})
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(40, 20))

for i, exp in enumerate(x_axis.keys()):  #Assume x_axis, y_axis keys are the same 
    axes[int(i/4), i%4].plot(x_axis[exp], y_axis[exp])
    axes[int(i/4), i%4].set_title(exp)
    axes[int(i/4), i%4].set_ylabel('Number Accepted')
    axes[int(i/4), i%4].set_xlabel('alpha')
plt.ticklabel_format(axis="x", style="sci", scilimits=(0,0))
plt.show()

# ALPHA SENSITIVITY PLOTS

##Regular

In [None]:
colours = ['b', 'g', 'r', 'c', 'k']
matplotlib.rcParams.update({'font.size': 16})
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(40, 20))

for i, exp in enumerate(CPid_results.keys()):
    sort_pcp_sens = norm_ctr(CPid_results[exp], 5)    
    
    labels = []
    bars = []
    for d in range(min(5,len(sort_pcp_sens))):
        labels.append(sort_pcp_sens[d][0])
        bars.append(sort_pcp_sens[d][1])


    X = 0 
    width = 0.05  # the width of the bars
    
    axes[int(i/4), i%4].set_title(exp, pad=30)
    if (i%4) == 0:
        axes[int(i/4), i%4].set_ylabel('Proportion Included', fontsize=32)
        axes[int(i/4), i%4].yaxis.labelpad = 40
    axes[int(i/4), i%4].set_ylim(0,1)
    axes[int(i/4), i%4].tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        labelbottom=False) # labels along the 
    
    for d in range(len(labels)):
        axes[int(i/4), i%4].bar(X + (d*width), bars[d], color = colours[d], width = width, label=labels[d])
    axes[int(i/4), i%4].legend(loc='lower left', prop={'size':30}) 

plt.show()

In [None]:
print(CPid_results.keys())

##Aggregate Random Seeds

In [None]:
colours = ['b', 'g', 'r', 'c', 'k']
matplotlib.rcParams.update({'font.size': 16})
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(40, 20))

aggregate = {}
for i, big_exp in enumerate(list(itertools.product(dataset, environment))):
    aggregate[big_exp] = {}
    avg_sort_pcp_sens = mean_slist(eligible_exps(big_exp, CPid_results.keys()), CPid_results)
    errors = var_slist(eligible_exps(big_exp, CPid_results.keys()), CPid_results)
    
    labels = []
    bars = []
    errors_plt = []
    for d in range(min(5,len(avg_sort_pcp_sens))):
        labels.append(avg_sort_pcp_sens[d][0])
        bars.append(avg_sort_pcp_sens[d][1])
        errors_plt.append(errors[d][1])

    X = 0 
    width = 0.05  # the width of the bars

    axes[int(i/4), i%4].set_title(big_exp, pad=30)
    if (i%4) == 0:
        axes[int(i/4), i%4].set_ylabel('Proportion Included', fontsize=32)
        axes[int(i/4), i%4].yaxis.labelpad = 40
    axes[int(i/4), i%4].set_ylim(0,1)
    axes[int(i/4), i%4].tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        labelbottom=False) # labels along the 

    for d in range(len(labels)):
        axes[int(i/4), i%4].bar(X + (d*width), bars[d], color = colours[d], \
                                width = width, label=labels[d], yerr=errors_plt[d])
    axes[int(i/4), i%4].legend(loc='lower left', prop={'size':30}) 

plt.show()

## Confusion Matrix

In [None]:
res = np.zeros((4,4))
for i, env_1 in enumerate(results.keys()):
    for j, env_2 in enumerate(results.keys()):
        e1 = set(a[0] for a in sorted(results[env_1].items(), key=lambda kv: kv[1], reverse=True)[:5])
        e2 = set(a[0] for a in sorted(results[env_2].items(), key=lambda kv: kv[1], reverse=True)[:5])
        res[i,j] = len(set.intersection(e1, e2))
        
print(res)

# Feature Engineering

In [None]:
#PARAMS
alphas = {'8079':{'workclass':np.linspace(0.1, 0.6, 100),\
                 'occupation':np.linspace(1e-6, 1e-4, 100),\
                 'native-country':np.linspace(0.5, 1.5, 100),\
                 'marital-status':np.linspace(1e-60, 1e-20, 100), \
                 'Purpose':np.linspace(3.5, 5.5, 100), \
                 'Housing':np.linspace(1.6, 3.0, 100), \
                 'Telephone':np.linspace(3.0, 4.0, 100), \
                 'Property':np.linspace(4, 5.5, 100)},
         '1000':  {'workclass':np.linspace(0.1, 0.6, 100),\
                 'occupation':np.linspace(1e-6, 1e-4, 100),\
                 'native-country':np.linspace(1e-2, 0.15, 100),\
                 'marital-status':np.linspace(1e-60, 1e-20, 100), \
                 'Purpose':np.linspace(3.5, 5.5, 100), \
                 'Housing':np.linspace(1.6, 3.0, 100), \
                 'Telephone':np.linspace(3.0, 4.0, 100), \
                 'Property':np.linspace(4, 5.5, 100)}
}


feateng = ['12']
dataset = ['adult']  #['adult', germanCredit']
seed = ['1000', '8079'] #seeds=(1000 8079 52 147)
environment = ['workclass', 'occupation', 'native-country', 'marital-status']
log_alpha = False

In [None]:
#Plot PCP distribution per Enviriornment variable over time 
assert ((len(dataset) == 1) and ((len(seed) == 1)) ^ (len(feateng) == 1))
# fig, axes = plt.subplots(nrows=1, ncols=4)

i = 2

results = {}
for exp in itertools.product(feateng, dataset, seed, environment):
    if exp[i] not in results.keys():
        results[exp[i]] = {}
    if (exp[i] in results.keys()) and (exp[3] not in results[exp[i]].keys()):
        results[exp[i]][exp[3]] = {}
    print(results.keys(), exp[3])
    norm = 0
    for fname in rawres_files:
        f, d, s, e = get_hps_from_rawres(fname)
        if (f == exp[0]) and (d == exp[1]) and (s == exp[2]) and (e == exp[3]):
            unid = '{}_{}_{}_{}'.format(f,d,s,e)
            try:
                pvals = json.load(open(os.path.join(expdir, fname), 'rb'))
                del pvals["()"]
            except:
                continue
            pvals = pd.DataFrame.from_dict(pvals, orient='index')
            for a in alphas[s][e]: 
                accepted = pvals[pvals['Final_tstat'] > a]
                if len(accepted.index) > 100000:
                    raise ValueError('too many subsets: {}'.format(len(accepted.index)))
                    
                accepted_sets = list(accepted.index)
                accepted_sets = [str_2_pcp(a) for a in accepted_sets]
                
                if len(accepted_sets) == 0:
                    continue
                pcps = set.intersection(*accepted_sets)   
                if len(pcps) == 0:
                    continue
                else:
                    norm += 1
                    for pcp in pcps: 
                        if pcp not in results[exp[i]][exp[3]]:
                            results[exp[i]][exp[3]][pcp] = 1 
                        else:
                            results[exp[i]][exp[3]][pcp] += 1
                      
                        
            
            for cat in results[exp[i]][exp[3]]:
                results[exp[i]][exp[3]][cat] = results[exp[i]][exp[3]][cat]/norm


In [None]:
print(results.keys())
print(results['8079']['occupation'])  #.keys())

In [None]:
colours = ['b', 'g', 'r', 'c', 'k']
matplotlib.rcParams.update({'font.size': 26})
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(40, 20))

for i, fteng in enumerate(results):
    for j, env in enumerate(results[fteng]):
        labels = []
        bars = []
        to_plot = results[fteng][env]
        for res in to_plot:
            labels.append(res)
            bars.append(to_plot[res])
        X = 0 
        width = 0.05

        axes[i,j].set_title('Adult - {} - Seed={}'.format(env.capitalize(), fteng), pad=30)
        if j == 0:
            axes[i,j].set_ylabel('Proportion Included', fontsize=32)
            axes[i,j].yaxis.labelpad = 40
        axes[i,j].set_ylim(0,1)
        axes[i,j].tick_params(
            axis='x',          # changes apply to the x-axis
            which='both',      # both major and minor ticks are affected
            bottom=False,      # ticks along the bottom edge are off
            top=False,         # ticks along the top edge are off
            labelbottom=False) # labels along the 

        if len(labels) > 0:
            for d in range(0, min(len(labels), 5)):
                axes[i,j].bar(X + (d*width), bars[d], color = colours[d], width = width, label=labels[d])
            axes[i,j].legend(loc='lower left', prop={'size':30})
             

            

In [None]:
#Plot Accepted subsets vs Alpha for specified hyperparams 

#fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(w*5, int(l/w)*5)) #Note - is +2 for reason


for exp in itertools.product(feateng, dataset, seed, environment):
    for fname in rawres_files:
        f, d, s, e = get_hps_from_rawres(fname)
        if (f == exp[0]) and (d == exp[1]) and (s == exp[2]) and (e == exp[3]):
            unid = '{}_{}_{}_{}'.format(f,d,s,e)
            try:
                pvals = json.load(open(os.path.join(expdir, fname), 'rb'))
                del pvals["()"]
            except:
                continue
            pvals = pd.DataFrame.from_dict(pvals, orient='index')
            for a in alphas[e]: 
                accepted = pvals[pvals['Final_tstat'] > a]
                if len(accepted.index) == 0:
                    print(a, unid, 0, 'null')
                elif len(accepted.index) < 1000:
                    accepted_sets = list(accepted.index)
                    accepted_sets = [str_2_pcp(a) for a in accepted_sets]
                    print(a, unid, len(accepted.index), set.intersection(*accepted_sets))
                else:
                    print(a, unid, len(accepted.index), 'too_many_intersections')
            
    
    print('#####################################')
