In [1]:
# Notebook setup: imports, warning policy, and the import path for the local `fairws` package.
# NOTE(review): `sklearn`, `os`, and `tqdm` are not referenced by name in the visible cells.
import pandas as pd
import numpy as np
import sklearn
import sys
import os
import torch
import tqdm
import warnings
# NOTE(review): this silences *all* warnings for the whole kernel, including sklearn
# convergence warnings from LogisticRegression and pandas deprecation warnings
# (e.g. for DataFrame.append, which was removed in pandas 2.0). Consider narrowing
# the filter to specific categories/modules.
warnings.filterwarnings('ignore')

# Make `fairws` importable from the repository checkout rather than an installed
# package; presumably the notebook lives one or two directories below the repo
# root — TODO confirm.
sys.path.append('../')
sys.path.append('../../')

# Project helpers used below: loaders for the dataset / labeling-function outputs /
# LIFT embeddings, the evaluation function that produces each result dict, and the
# baseline and SBM pseudolabel generators.
from fairws.data_util import load_dataset, load_LIFT_embedding, load_LF
from sklearn.linear_model import LogisticRegression
from fairws.metrics import exp_eval
from fairws.sbm import get_baseline_pseudolabel, get_sbm_pseudolabel

# Configurations

In [2]:
# Experiment configuration shared by every cell below.
dataset_name = "bank_marketing" # adult | bank_marketing | CivilComments | hateXplain | CelebA | UTKFace
use_LIFT_embedding = True # only for adult, bank_marketing
sbm_diff_threshold = 0.05  # passed to get_sbm_pseudolabel as `diff_threshold`
SEED = 0  # global RNG seed for numpy and torch

SUPPORTED_DATASETS = ("adult", "bank_marketing", "CivilComments", "hateXplain", "CelebA", "UTKFace")
LIFT_EMBEDDING_DATASETS = ("adult", "bank_marketing")

# Fail fast on an invalid configuration instead of deep inside a long-running cell.
if dataset_name not in SUPPORTED_DATASETS:
    raise ValueError(f"Unknown dataset_name {dataset_name!r}; expected one of {SUPPORTED_DATASETS}")
if use_LIFT_embedding and dataset_name not in LIFT_EMBEDDING_DATASETS:
    raise ValueError(
        f"use_LIFT_embedding=True is only supported for {LIFT_EMBEDDING_DATASETS}, got {dataset_name!r}"
    )

# Seed the global numpy/torch RNGs so stochastic steps that draw from them are repeatable.
# NOTE(review): unverified whether the fairws pseudolabel helpers use these global RNGs.
np.random.seed(SEED)
torch.manual_seed(SEED)

result_collection = pd.DataFrame() # to keep results

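# Fully supervised

Train the downstream logistic regression directly on the ground-truth training labels; this is the reference point for the weakly supervised conditions below.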

In [3]:
# Reference condition: train the downstream model on the ground-truth training labels.
cond = "fs"
x_train, y_train, a_train, x_test, y_test, a_test = load_dataset(dataset_name=dataset_name,
                                                                    data_base_path='../data')
model = LogisticRegression()
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
result = exp_eval(y_test, y_pred, a_test, cond=cond)
print(result)

# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0; build the new row
# explicitly. The first row is taken as-is so we never concatenate an empty frame.
new_row = pd.DataFrame([result])
result_collection = (new_row if result_collection.empty
                     else pd.concat([result_collection, new_row], ignore_index=True))

{'condition': 'fs', 'accuracy': 0.9121955167111758, 'fscore': 0.517992003553976, 'precision': 0.6755504055619931, 'recall': 0.420028818443804, 'demographic_parity_gap': 0.12781807780265808, 'equal_opportunity_gap': 0.11709368228912354}


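# Weak supervision (WS) baseline

Replace the ground-truth training labels with pseudolabels produced by `get_baseline_pseudolabel` from the labeling-function outputs `L`, then train the same downstream model.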

In [4]:
# Weak-supervision baseline: same downstream model, trained on pseudolabels
# derived from the labeling-function outputs instead of ground truth.
cond = "ws_baseline"

x_train, y_train, a_train, x_test, y_test, a_test = load_dataset(dataset_name=dataset_name,
                                                                    data_base_path='../data')

# weak supervision — kept in its own variable so the ground-truth y_train is not clobbered
L = load_LF(dataset_name, data_base_path='../data')
y_train_ws = get_baseline_pseudolabel(L)

# downstream task
model = LogisticRegression()
model.fit(x_train, y_train_ws)
y_pred = model.predict(x_test)
result = exp_eval(y_test, y_pred, a_test, cond=cond)
print(result)

# DataFrame.append was removed in pandas 2.0; build the new row explicitly and skip
# concat when the collection is still empty.
new_row = pd.DataFrame([result])
result_collection = (new_row if result_collection.empty
                     else pd.concat([result_collection, new_row], ignore_index=True))

100%|██████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1995.22epoch/s]


{'condition': 'ws_baseline', 'accuracy': 0.6809905316824472, 'fscore': 0.2753676470588235, 'precision': 0.18484698914116485, 'recall': 0.5396253602305475, 'demographic_parity_gap': 0.5547482967376709, 'equal_opportunity_gap': 0.42558276653289795}


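# SBM

Pseudolabels from `get_sbm_pseudolabel`, computed once for each `ot_type` (`None`, `"linear"`, `"sinkhorn"`) and then used to train the same downstream model. Each run computes an SBM mapping and saves it under `../data/<dataset>/SBM_mapping/`. On `bank_marketing` this cell is slow: roughly 15 minutes for `None` and 35 minutes for `"linear"` in the saved output.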

In [None]:
# Reload the split so this cell does not depend on labels produced by earlier cells.
x_train, y_train, a_train, x_test, y_test, a_test = load_dataset(dataset_name=dataset_name,
                                                                    data_base_path='../data')

# One run per `ot_type` passed to get_sbm_pseudolabel. Each run computes and saves an
# SBM mapping; on bank_marketing this took roughly 15 (None) to 35 (linear) minutes.
for ot_type in [None, "linear", "sinkhorn"]:
    cond = f"sbm({ot_type})"

    # Loaded inside the loop so every variant starts from freshly loaded inputs.
    # NOTE(review): unverified whether get_sbm_pseudolabel mutates its arguments;
    # if it does not, these loads can be hoisted out of the loop.
    L = load_LF(dataset_name, data_base_path='../data')
    if use_LIFT_embedding:
        # Only the training-split embedding is used to build pseudolabels here.
        x_embedding_train, x_embedding_test = load_LIFT_embedding(dataset_name=dataset_name,
                                                                    data_base_path='../data')
        y_train_sbm = get_sbm_pseudolabel(L, x_embedding_train, a_train, dataset_name,
                                          ot_type=ot_type, diff_threshold=sbm_diff_threshold,
                                          use_LIFT_embedding=True)
    else:
        y_train_sbm = get_sbm_pseudolabel(L, x_train, a_train, dataset_name,
                                          ot_type=ot_type, diff_threshold=sbm_diff_threshold)

    # downstream task — trained on SBM pseudolabels, kept separate from ground-truth y_train
    model = LogisticRegression()
    model.fit(x_train, y_train_sbm)
    y_pred = model.predict(x_test)
    result = exp_eval(y_test, y_pred, a_test, cond=cond)
    print(result)

    # DataFrame.append was removed in pandas 2.0; build the new row explicitly and
    # skip concat when the collection is still empty.
    new_row = pd.DataFrame([result])
    result_collection = (new_row if result_collection.empty
                         else pd.concat([result_collection, new_row], ignore_index=True))


computing sbm mapping...: 100%|███████████████████████████████████████████████████████| 188/188 [00:25<00:00,  7.49it/s]
computing sbm mapping...: 100%|█████████████████████████████████████████████████████| 7021/7021 [15:18<00:00,  7.64it/s]


SBM (None) saved in ../data/bank_marketing/SBM_mapping/bank_marketing_embedding_SBM_mapping_None_0->1.pt, ../data/bank_marketing/SBM_mapping/bank_marketing_embedding_SBM_mapping_None_0->1.pt!


100%|██████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2890.54epoch/s]


{'condition': 'sbm(None)', 'accuracy': 0.7130371449380918, 'fscore': 0.26765799256505574, 'precision': 0.18760856977417487, 'recall': 0.4668587896253602, 'demographic_parity_gap': 0.07025843858718872, 'equal_opportunity_gap': 0.09291419386863708}


computing sbm mapping...: 100%|███████████████████████████████████████████████████████| 188/188 [00:58<00:00,  3.23it/s]
computing sbm mapping...: 100%|█████████████████████████████████████████████████████| 7021/7021 [34:50<00:00,  3.36it/s]


SBM (linear) saved in ../data/bank_marketing/SBM_mapping/bank_marketing_embedding_SBM_mapping_linear_0->1.pt, ../data/bank_marketing/SBM_mapping/bank_marketing_embedding_SBM_mapping_linear_0->1.pt!


100%|██████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2455.86epoch/s]


{'condition': 'sbm(linear)', 'accuracy': 0.8918831431577243, 'fscore': 0.30489073881373574, 'precision': 0.548689138576779, 'recall': 0.21109510086455333, 'demographic_parity_gap': 0.10366232693195343, 'equal_opportunity_gap': 0.12146978080272675}


computing sbm mapping...: 100%|███████████████████████████████████████████████████████| 188/188 [00:56<00:00,  3.30it/s]
computing sbm mapping...:  69%|████████████████████████████████████▍                | 4824/7021 [23:54<10:28,  3.49it/s]

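# Result summary

**Reproducibility note:** the SBM cell above has no execution count (`In [None]`), and its saved output stops partway through the `sinkhorn` run (69%). Yet the table below contains an `sbm(sinkhorn)` row. The table therefore comes from a different execution than the outputs shown above. Use *Restart Kernel → Run All* before quoting these numbers.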

In [7]:
# One row per experimental condition; indexing by `condition` replaces the
# meaningless 0..n row numbers so the table reads as a side-by-side comparison.
result_collection.set_index("condition")

Unnamed: 0,condition,accuracy,fscore,precision,recall,demographic_parity_gap,equal_opportunity_gap
0,fs,0.912196,0.517992,0.67555,0.420029,0.127818,0.117094
1,ws_baseline,0.680991,0.275368,0.184847,0.539625,0.554748,0.425583
2,sbm(None),0.713037,0.267658,0.187609,0.466859,0.070258,0.092914
3,sbm(linear),0.891883,0.304891,0.548689,0.211095,0.103662,0.12147
4,sbm(sinkhorn),0.698147,0.079467,0.060435,0.115994,0.108674,0.072114
