In [None]:
import pandas as pd 
import os 
from os.path import join
from collections import Counter

import matplotlib.pyplot as plt 
plt.rcParams['figure.figsize'] = [6, 4]

import sys
# Directory that gets added to the module search path. The default is the
# original author's local path; set the CAUSAL_DISCOVERY_BASEDIR environment
# variable to point somewhere else without editing the notebook.
basedir = os.environ.get(
    'CAUSAL_DISCOVERY_BASEDIR',
    '/Users/RobertAdragna/Documents/School/Fourth_Year/ESC499-Thesis/codebases/causal_discovery',
)
# Guard against piling up duplicate sys.path entries when the cell is re-run.
if basedir not in sys.path:
    sys.path.append(basedir)

In [None]:
def order_cols(cols):
    '''Order columns pulled from dataframe in the desired order of display.

    Args:
        cols: iterable of column-name strings. It must contain the exact name
            'TestSet' and, for each of the tags 'ACC' and 'DP', one name that
            contains 'train' and one that contains 'test'.

    Returns:
        A five-element list:
        [TestSet, ACC/train, ACC/test, DP/train, DP/test].
        If several names fit the same slot, the last one in `cols` is used.

    Raises:
        ValueError: if one or more of the five slots has no matching column.
    '''
    cols = list(cols)  # materialise once; every slot scans the full list

    # (label used in error messages, predicate selecting the column)
    slots = [
        ('TestSet', lambda c: c == 'TestSet'),
        ('ACC train', lambda c: ('ACC' in c) and ('train' in c)),
        ('ACC test', lambda c: ('ACC' in c) and ('test' in c)),
        ('DP train', lambda c: ('DP' in c) and ('train' in c)),
        ('DP test', lambda c: ('DP' in c) and ('test' in c)),
    ]

    ordered, missing = [], []
    for label, matches in slots:
        hits = [c for c in cols if matches(c)]
        if hits:
            ordered.append(hits[-1])  # last match wins
        else:
            missing.append(label)

    # Fail with an explicit message instead of an unbound-local error.
    if missing:
        raise ValueError('No column found for: {} (got columns {})'.format(
            ', '.join(missing), cols))

    return ordered

In [None]:
def fname_2_algo(n):
    '''Map a results file name (or path) to the algorithm that produced it.

    Args:
        n: string expected to contain '<algo>_results' for one of the known
            algorithm tags.

    Returns:
        The matching algorithm tag ('linear-irm', 'irm', 'icp', 'logreg',
        'linreg', 'mlp' or 'constant'), or None if no tag matches.
    '''
    # Order matters: 'linear-irm_results' contains 'irm_results', so the
    # longer tag has to be tested before the shorter one; otherwise every
    # linear-irm file would be reported as plain 'irm'.
    algos = ('linear-irm', 'irm', 'icp', 'logreg', 'linreg', 'mlp', 'constant')
    for algo in algos:
        if '{}_results'.format(algo) in n:
            return algo
    return None

#Generate relevant files 
expdir = '0610_baseline_adultgerman/testing'
known_algos = ['irm', 'linear-irm', 'icp', 'linreg', 'logreg', 'mlp', 'constant']
# The saving cell of this notebook writes its output into a 'final_results'
# folder inside expdir; skip it so the directory check still passes on re-runs.
output_dirs = ['final_results']

rel_files = []
for rdir in sorted(os.listdir(expdir)):  # sorted: os.listdir order is arbitrary
    if rdir in output_dirs or not os.path.isdir(join(expdir, rdir)):
        continue
    assert rdir in known_algos, 'Unexpected experiment directory: {}'.format(rdir)
    rel_files.append(join(expdir, rdir, 'analysis', '{}_results.xlsx'.format(rdir)))

#Collect relevant dataframe of each experiment run 
data_store = {}
for f in rel_files:
    res = pd.read_excel(f, index_col=0, header=0)
    algo = fname_2_algo(f)
    # Keep only the context, accuracy and DP columns, in display order.
    rel_cols = [c for c in res.columns if ("ACC" in c) or ("DP" in c) or (c == "TestSet")]
    rel_cols = order_cols(rel_cols)
    # TODO(review): a commented-out step here once tried to split
    # ':'-separated DP entries from older result files; confirm whether any
    # input still needs that correction.
    data_store[algo] = res[rel_cols]

#Get final df tables by context variable 
context_name = 'TestSet'
context_vals = ['workclass_DUMmY', 'native-country_DUMmY', 'relationship_DUMmY', 'Purpose_DUMmY', 'Housing_DUMmY']
# NOTE(review): the ACC columns end up labelled '*_loss'; later cells select
# columns by these names, so the labels are left as they are.
final_colnames = ['algo', 'training_loss', 'testing_loss', 'training_fairness', 'testing_fairness']

# One list of rows per context; each row is [algo, <the four metric values>].
context_rows = {c: [] for c in context_vals}
for al, df in data_store.items():
    assert Counter(context_vals) == Counter(list(df[context_name]))  #Make one of each context in results
    for _, row in df.iterrows():
        c = row[context_name]
        # Plain list concatenation: prepends the algo name without needing
        # numpy, which this notebook never imports.
        context_rows[c].append([al] + row.drop(context_name).tolist())
context_store = {k: pd.DataFrame(v, columns=final_colnames) for k, v in context_rows.items()}


In [None]:
##Aggregating relevant columns 
adult_keys = ['workclass_DUMmY', 'native-country_DUMmY', 'relationship_DUMmY']
german_keys = ['Housing_DUMmY', 'Purpose_DUMmY']
metric_cols = ['training_loss', 'testing_loss', 'training_fairness', 'testing_fairness']

# Sort the per-context tables into one bucket per dataset.
frames_by_dataset = {'adult': [], 'german': []}
for c, df in context_store.items():
    if c in adult_keys:
        frames_by_dataset['adult'].append(df)
    elif c in german_keys:
        frames_by_dataset['german'].append(df)

method_id = 'algo'
# Final mapping: dataset name -> matplotlib Axes of the bar plot
# (the saving cell calls .get_figure() on these values).
agg_context_store = {}
for name, frames in frames_by_dataset.items():
    # Columns are selected with a list: indexing a groupby with a bare tuple
    # of keys is deprecated and rejected by current pandas.
    mean_by_algo = pd.concat(frames).groupby(method_id)[metric_cols].mean()
    agg_context_store[name] = mean_by_algo.plot.bar(rot=15, title=name)

In [None]:
##Note - this is formatting contingent on the plots you want to save 

# Output folder lives inside the experiment directory.
results_dir = join(expdir, 'final_results')
# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(results_dir, exist_ok=True)

# One LaTeX table per context value.
# NOTE(review): the captions are the raw context names, which contain
# underscores - confirm they render correctly in the LaTeX document.
for k, v in context_store.items():
    v.to_latex(join(results_dir, '{}_merge_results.tex'.format(k)), caption=k)

# One PNG per aggregated dataset; values are the Axes returned by the bar plots.
for k, v in agg_context_store.items():
    v.get_figure().savefig(join(results_dir, '{}_merge_results.png'.format(k)))