In [1]:
import pandas as pd
import numpy as np
import sklearn
import sys
import os
import torch
import tqdm
import warnings
warnings.filterwarnings('ignore')

sys.path.append('../')
sys.path.append('../../')

from sklearn.linear_model import LogisticRegression
from fairws.data_util import load_dataset, load_LIFT_embedding, load_LF
from fairws.metrics import exp_eval
from fairws.sbm import get_baseline_pseudolabel, get_sbm_pseudolabel



In [None]:
# Experiment: for each dataset, repeat over `repetition` seeds:
#   1. load the dataset and its labeling-function outputs,
#   2. replace the training labels with SBM pseudolabels (sinkhorn OT),
#   3. fit a logistic-regression downstream model and evaluate it on the test split.
# One CSV of per-seed metrics is written per dataset.
repetition = 5
results_dir = '../rebuttal_results'
# Create the output directory up front so a long run cannot fail at the final to_csv.
os.makedirs(results_dir, exist_ok=True)

# for dataset_name in  ["bank_marketing", "utkface", "hatexplain", "celeba", "civilcomments"]:
for dataset_name in  ["bank_marketing", "utkface", "hatexplain"]:
    use_LIFT_embedding = False # only for adult, bank_marketing
    sbm_diff_threshold = 0.05

    records = []  # one metrics dict per (seed, ot_type) run
    result_collection = pd.DataFrame() # to keep results

    for seed in range(repetition):
        np.random.seed(seed)
        x_train, y_train, a_train, x_test, y_test, a_test = load_dataset(dataset_name=dataset_name,
                                                                    data_base_path='../data/')
        for ot_type in ["sinkhorn"]:
            cond = f"sbm({ot_type})"

            L = load_LF(dataset_name, data_base_path='../data')
            # From here on y_train holds SBM pseudolabels, not the labels returned by load_dataset.
            if use_LIFT_embedding:
                x_embedding_train, x_embedding_test = load_LIFT_embedding(dataset_name=dataset_name,
                                                                            data_base_path='../data/')
                y_train= get_sbm_pseudolabel(L, x_embedding_train, a_train, dataset_name, seed=seed,
                                             ot_type=ot_type, diff_threshold=sbm_diff_threshold,
                                             use_LIFT_embedding=True)

            else:
                y_train= get_sbm_pseudolabel(L, x_train, a_train, dataset_name, seed=seed,
                                             ot_type=ot_type, diff_threshold=sbm_diff_threshold)

            # downstream task
            model = LogisticRegression(random_state=seed)
            model.fit(x_train, y_train)
            y_pred = model.predict(x_test)
            result = exp_eval(y_test, y_pred, a_test, cond=cond)
            result['seed'] = seed
            print(result)

            # DataFrame.append was removed in pandas 2.0: accumulate plain records and
            # rebuild the frame from them. Rebuilding after every run (only a handful of
            # rows) keeps partial results inspectable if the cell is interrupted.
            records.append(result)
            result_collection = pd.DataFrame(records)

    result_collection.to_csv(os.path.join(results_dir, f'01_{dataset_name}_sinkhorn.csv'))

100%|██████████| 1000/1000 [00:00<00:00, 1866.13epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.8465647001699441, 'fscore': 0.5150895140664962, 'precision': 0.3992862807295797, 'recall': 0.7255043227665706, 'demographic_parity_gap': 0.12193702161312103, 'equal_opportunity_gap': 0.07966578006744385, 'seed': 0}


computing sbm mapping...: 100%|██████████| 188/188 [00:03<00:00, 54.32it/s]
computing sbm mapping...: 100%|██████████| 7021/7021 [02:11<00:00, 53.30it/s]


SBM (sinkhorn) saved in ../data/bank_marketing/SBM_mapping/bank_marketing_SBM_mapping_sinkhorn_0->1_1.pt, ../data/bank_marketing/SBM_mapping/bank_marketing_SBM_mapping_sinkhorn_0->1_1.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2931.95epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.8318362061989156, 'fscore': 0.5, 'precision': 0.37536127167630057, 'recall': 0.7485590778097982, 'demographic_parity_gap': 0.1401161104440689, 'equal_opportunity_gap': 0.06574338674545288, 'seed': 1}


computing sbm mapping...: 100%|██████████| 188/188 [00:03<00:00, 54.05it/s]
computing sbm mapping...: 100%|██████████| 7021/7021 [02:11<00:00, 53.47it/s]


SBM (sinkhorn) saved in ../data/bank_marketing/SBM_mapping/bank_marketing_SBM_mapping_sinkhorn_0->1_2.pt, ../data/bank_marketing/SBM_mapping/bank_marketing_SBM_mapping_sinkhorn_0->1_2.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2942.48epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.833211944646759, 'fscore': 0.5010893246187363, 'precision': 0.3773240977032446, 'recall': 0.7456772334293948, 'demographic_parity_gap': 0.1325392723083496, 'equal_opportunity_gap': 0.07549279928207397, 'seed': 2}


computing sbm mapping...: 100%|██████████| 188/188 [00:03<00:00, 53.06it/s]
computing sbm mapping...: 100%|██████████| 7021/7021 [02:11<00:00, 53.22it/s]


SBM (sinkhorn) saved in ../data/bank_marketing/SBM_mapping/bank_marketing_SBM_mapping_sinkhorn_0->1_3.pt, ../data/bank_marketing/SBM_mapping/bank_marketing_SBM_mapping_sinkhorn_0->1_3.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2918.72epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.8366917536618921, 'fscore': 0.5066014669926651, 'precision': 0.38341968911917096, 'recall': 0.7463976945244957, 'demographic_parity_gap': 0.1359451413154602, 'equal_opportunity_gap': 0.0762590765953064, 'seed': 3}


computing sbm mapping...: 100%|██████████| 188/188 [00:03<00:00, 53.64it/s]
computing sbm mapping...: 100%|██████████| 7021/7021 [02:11<00:00, 53.24it/s]


SBM (sinkhorn) saved in ../data/bank_marketing/SBM_mapping/bank_marketing_SBM_mapping_sinkhorn_0->1_4.pt, ../data/bank_marketing/SBM_mapping/bank_marketing_SBM_mapping_sinkhorn_0->1_4.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2924.27epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.8677672574249413, 'fscore': 0.5412689500280742, 'precision': 0.44342226310947563, 'recall': 0.6945244956772334, 'demographic_parity_gap': 0.1283203810453415, 'equal_opportunity_gap': 0.046715617179870605, 'seed': 4}


computing sbm mapping...: 100%|██████████| 4059/4059 [12:27<00:00,  5.43it/s]
computing sbm mapping...: 100%|██████████| 682/682 [02:09<00:00,  5.25it/s]


SBM (sinkhorn) saved in ../data/UTKFace/SBM_mapping/UTKFace_SBM_mapping_sinkhorn_0->1_0.pt, ../data/UTKFace/SBM_mapping/UTKFace_SBM_mapping_sinkhorn_0->1_0.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2892.15epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.7987766294030795, 'fscore': 0.7909728308501315, 'precision': 0.7955046275892463, 'recall': 0.7864923747276689, 'demographic_parity_gap': 0.12833914160728455, 'equal_opportunity_gap': 0.044963181018829346, 'seed': 0}


computing sbm mapping...: 100%|██████████| 4059/4059 [12:37<00:00,  5.36it/s]
computing sbm mapping...: 100%|██████████| 682/682 [02:12<00:00,  5.16it/s]


SBM (sinkhorn) saved in ../data/UTKFace/SBM_mapping/UTKFace_SBM_mapping_sinkhorn_0->1_1.pt, ../data/UTKFace/SBM_mapping/UTKFace_SBM_mapping_sinkhorn_0->1_1.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2319.51epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.8006749630879562, 'fscore': 0.7930807970221152, 'precision': 0.7970950704225352, 'recall': 0.7891067538126362, 'demographic_parity_gap': 0.15089532732963562, 'equal_opportunity_gap': 0.06007653474807739, 'seed': 1}


computing sbm mapping...: 100%|██████████| 4059/4059 [12:29<00:00,  5.41it/s]
computing sbm mapping...: 100%|██████████| 682/682 [02:11<00:00,  5.18it/s]


SBM (sinkhorn) saved in ../data/UTKFace/SBM_mapping/UTKFace_SBM_mapping_sinkhorn_0->1_2.pt, ../data/UTKFace/SBM_mapping/UTKFace_SBM_mapping_sinkhorn_0->1_2.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2604.03epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.8034170006327779, 'fscore': 0.7958826106000876, 'precision': 0.8000880669308674, 'recall': 0.7917211328976035, 'demographic_parity_gap': 0.13949334621429443, 'equal_opportunity_gap': 0.050823748111724854, 'seed': 2}


computing sbm mapping...: 100%|██████████| 4059/4059 [12:32<00:00,  5.40it/s]
computing sbm mapping...: 100%|██████████| 682/682 [02:09<00:00,  5.25it/s]


SBM (sinkhorn) saved in ../data/UTKFace/SBM_mapping/UTKFace_SBM_mapping_sinkhorn_0->1_3.pt, ../data/UTKFace/SBM_mapping/UTKFace_SBM_mapping_sinkhorn_0->1_3.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2371.67epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.8000421851929973, 'fscore': 0.7945383615084525, 'precision': 0.7904269081500647, 'recall': 0.7986928104575164, 'demographic_parity_gap': 0.14424002170562744, 'equal_opportunity_gap': 0.04848533868789673, 'seed': 3}


computing sbm mapping...: 100%|██████████| 4059/4059 [12:32<00:00,  5.39it/s]
computing sbm mapping...: 100%|██████████| 682/682 [02:05<00:00,  5.43it/s]


SBM (sinkhorn) saved in ../data/UTKFace/SBM_mapping/UTKFace_SBM_mapping_sinkhorn_0->1_4.pt, ../data/UTKFace/SBM_mapping/UTKFace_SBM_mapping_sinkhorn_0->1_4.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2879.98epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.7985657034380932, 'fscore': 0.7926167209554832, 'precision': 0.79004329004329, 'recall': 0.7952069716775599, 'demographic_parity_gap': 0.14480620622634888, 'equal_opportunity_gap': 0.04051727056503296, 'seed': 4}


100%|██████████| 1000/1000 [00:00<00:00, 2816.75epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.6121811556480999, 'fscore': 0.6873688627780109, 'precision': 0.6588897827835881, 'recall': 0.718421052631579, 'demographic_parity_gap': 0.07214945554733276, 'equal_opportunity_gap': 0.036746084690093994, 'seed': 0}


computing sbm mapping...: 100%|██████████| 2399/2399 [05:58<00:00,  6.69it/s]
computing sbm mapping...: 100%|██████████| 1440/1440 [03:36<00:00,  6.64it/s]


SBM (sinkhorn) saved in ../data/hateXplain/SBM_mapping/hateXplain_SBM_mapping_sinkhorn_0->1_1.pt, ../data/hateXplain/SBM_mapping/hateXplain_SBM_mapping_sinkhorn_0->1_1.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2725.55epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.610098906819365, 'fscore': 0.6851618327028164, 'precision': 0.6577885391444713, 'recall': 0.7149122807017544, 'demographic_parity_gap': 0.0821344256401062, 'equal_opportunity_gap': 0.043850839138031006, 'seed': 1}


computing sbm mapping...: 100%|██████████| 2399/2399 [05:58<00:00,  6.70it/s]
computing sbm mapping...: 100%|██████████| 1440/1440 [03:36<00:00,  6.65it/s]


SBM (sinkhorn) saved in ../data/hateXplain/SBM_mapping/hateXplain_SBM_mapping_sinkhorn_0->1_2.pt, ../data/hateXplain/SBM_mapping/hateXplain_SBM_mapping_sinkhorn_0->1_2.pt!


100%|██████████| 1000/1000 [00:00<00:00, 1700.73epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.610098906819365, 'fscore': 0.6851618327028164, 'precision': 0.6577885391444713, 'recall': 0.7149122807017544, 'demographic_parity_gap': 0.0821344256401062, 'equal_opportunity_gap': 0.043850839138031006, 'seed': 2}


computing sbm mapping...: 100%|██████████| 2399/2399 [05:57<00:00,  6.72it/s]
computing sbm mapping...: 100%|██████████| 1440/1440 [03:38<00:00,  6.60it/s]


SBM (sinkhorn) saved in ../data/hateXplain/SBM_mapping/hateXplain_SBM_mapping_sinkhorn_0->1_3.pt, ../data/hateXplain/SBM_mapping/hateXplain_SBM_mapping_sinkhorn_0->1_3.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2826.49epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.610098906819365, 'fscore': 0.6851618327028164, 'precision': 0.6577885391444713, 'recall': 0.7149122807017544, 'demographic_parity_gap': 0.0821344256401062, 'equal_opportunity_gap': 0.043850839138031006, 'seed': 3}


computing sbm mapping...: 100%|██████████| 2399/2399 [05:58<00:00,  6.69it/s]
computing sbm mapping...: 100%|██████████| 1440/1440 [03:36<00:00,  6.65it/s]


SBM (sinkhorn) saved in ../data/hateXplain/SBM_mapping/hateXplain_SBM_mapping_sinkhorn_0->1_4.pt, ../data/hateXplain/SBM_mapping/hateXplain_SBM_mapping_sinkhorn_0->1_4.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2613.44epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.610098906819365, 'fscore': 0.6851618327028164, 'precision': 0.6577885391444713, 'recall': 0.7149122807017544, 'demographic_parity_gap': 0.0821344256401062, 'equal_opportunity_gap': 0.043850839138031006, 'seed': 4}


100%|██████████| 1000/1000 [00:00<00:00, 2888.01epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.8720781340229481, 'fscore': 0.8854618618348582, 'precision': 0.9455994622365199, 'recall': 0.8325160635779506, 'demographic_parity_gap': 0.30642226338386536, 'equal_opportunity_gap': 0.1839233636856079, 'seed': 0}


computing sbm mapping...:  40%|████      | 12748/31698 [5:53:19<9:13:20,  1.75s/it] 

In [None]:
# Display the per-seed metrics currently held in `result_collection`. The experiment
# loop re-creates this frame for every dataset, so only the most recently processed
# dataset's runs are shown here.
result_collection

In [3]:
# Per-condition mean and standard deviation of each classification metric.
(
    result_collection
    .groupby(['condition'], as_index=False)
    .agg({metric: ['mean', 'std']
          for metric in ('accuracy', 'fscore', 'precision', 'recall')})
)

Unnamed: 0_level_0,condition,accuracy,accuracy,fscore,fscore,precision,precision,recall,recall
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std,mean,std,mean,std
0,sbm(sinkhorn),0.872078,,0.885462,,0.945599,,0.832516,


In [None]:
# Experiment (LIFT-embedding variant): for each dataset, repeat over `repetition` seeds:
#   1. load the dataset, its labeling-function outputs and its LIFT embedding,
#   2. replace the training labels with SBM pseudolabels (sinkhorn OT) computed on the embedding,
#   3. fit a logistic-regression downstream model and evaluate it on the test split.
# One CSV of per-seed metrics is written per dataset.
repetition = 5
results_dir = '../rebuttal_results'
# Create the output directory up front so a long run cannot fail at the final to_csv.
os.makedirs(results_dir, exist_ok=True)

for dataset_name in  ["adult", "bank_marketing"]:
    use_LIFT_embedding = True # only for adult, bank_marketing
    sbm_diff_threshold = 0.05

    records = []  # one metrics dict per (seed, ot_type) run
    result_collection = pd.DataFrame() # to keep results

    for seed in range(repetition):
        np.random.seed(seed)
        x_train, y_train, a_train, x_test, y_test, a_test = load_dataset(dataset_name=dataset_name,
                                                                    data_base_path='../data/')
        for ot_type in ["sinkhorn"]:
            cond = f"sbm({ot_type})"

            L = load_LF(dataset_name, data_base_path='../data')
            # From here on y_train holds SBM pseudolabels, not the labels returned by load_dataset.
            if use_LIFT_embedding:
                x_embedding_train, x_embedding_test = load_LIFT_embedding(dataset_name=dataset_name,
                                                                            data_base_path='../data/')
                y_train= get_sbm_pseudolabel(L, x_embedding_train, a_train, dataset_name, seed=seed,
                                             ot_type=ot_type, diff_threshold=sbm_diff_threshold,
                                             use_LIFT_embedding=True)

            else:
                y_train= get_sbm_pseudolabel(L, x_train, a_train, dataset_name, seed=seed,
                                             ot_type=ot_type, diff_threshold=sbm_diff_threshold)

            # downstream task
            # NOTE(review): the embedding is only passed to get_sbm_pseudolabel; the
            # classifier is still fit and evaluated on x_train / x_test — confirm intended.
            model = LogisticRegression(random_state=seed)
            model.fit(x_train, y_train)
            y_pred = model.predict(x_test)
            result = exp_eval(y_test, y_pred, a_test, cond=cond)
            result['seed'] = seed
            print(result)

            # DataFrame.append was removed in pandas 2.0: accumulate plain records and
            # rebuild the frame from them. Rebuilding after every run (only a handful of
            # rows) keeps partial results inspectable if the cell is interrupted.
            records.append(result)
            result_collection = pd.DataFrame(records)

    result_collection.to_csv(os.path.join(results_dir, f'01_{dataset_name}_embedding_sinkhorn.csv'))

computing sbm mapping...: 100%|██████████| 2693/2693 [14:46<00:00,  3.04it/s]
computing sbm mapping...: 100%|██████████| 5448/5448 [29:45<00:00,  3.05it/s]


SBM (sinkhorn) saved in ../data/adult/SBM_mapping/adult_embedding_SBM_mapping_sinkhorn_0->1_0.pt, ../data/adult/SBM_mapping/adult_embedding_SBM_mapping_sinkhorn_0->1_0.pt!


100%|██████████| 1000/1000 [00:00<00:00, 2727.55epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.7862539156071494, 'fscore': 0.45828144458281445, 'precision': 0.5709852598913887, 'recall': 0.3827353094123765, 'demographic_parity_gap': 0.006466612219810486, 'equal_opportunity_gap': 0.16454002261161804, 'seed': 0}


computing sbm mapping...: 100%|██████████| 2693/2693 [14:50<00:00,  3.02it/s]
computing sbm mapping...:  31%|███       | 1663/5448 [09:02<20:15,  3.11it/s]