In [1]:
import pandas as pd
import numpy as np
import sklearn
import sys
import os
import torch
import tqdm
import warnings
warnings.filterwarnings('ignore')

sys.path.append('../')
sys.path.append('../../')

from domino import DominoSlicer
from sklearn.linear_model import LogisticRegression
from hyperlm import HyperLabelModel as LabelModel

from fairws.data_util import load_wrench_dataset, load_LF
from fairws.metrics import exp_eval
from fairws.sbm import get_baseline_pseudolabel, get_sbm_pseudolabel, find_sbm_mapping, correct_bias
from fairws.data_util import load_dataset

In [2]:
# Hyper LM: instantiate the label model (HyperLabelModel, imported as LabelModel) once.
# Later cells call lm.infer(...) on the labeling-function outputs returned by load_LF.
lm = LabelModel()

# Configurations

In [3]:
# --- Experiment settings shared by the cells below ---------------------------

# Which dataset to run on, and where the data lives relative to this notebook.
dataset_name = "census"
data_base_path = '../data'

# LIFT embedding switch; per the original note it applies only to the
# adult and bank_marketing datasets.
use_LIFT_embedding = False

# Threshold handed to correct_bias in the SBM cell.
sbm_diff_threshold = 0.05

# Accumulator for the metrics of every experimental condition.
result_collection = pd.DataFrame()

# Fully supervised

In [4]:
# Fully supervised reference: fit the downstream classifier on the labels
# returned by load_wrench_dataset and evaluate on the test split.
cond = "fs"
x_train, y_train, x_test, y_test = load_wrench_dataset(dataset_name=dataset_name,
                                                       data_base_path=data_base_path)
model = LogisticRegression()
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
result = exp_eval(y_test, y_pred, cond=cond, fairness=False)
print(result)

# DataFrame.append was removed in pandas 2.0; wrap the metrics dict in a
# one-row frame and concatenate it instead.
result_collection = pd.concat([result_collection, pd.DataFrame([result])],
                              ignore_index=True)

{'condition': 'fs', 'accuracy': 0.8456483017013697, 'fscore': 0.6341534430048041, 'precision': 0.7204763479986768, 'recall': 0.5663026521060842}


# WS baseline

In [5]:
# Weak-supervision baseline: replace the training labels with the label
# model's output on the labeling-function matrix, then train and evaluate
# the same downstream classifier.
cond = "ws_baseline"

x_train, y_train, x_test, y_test = load_wrench_dataset(dataset_name=dataset_name,
                                                       data_base_path=data_base_path)
# weak supervision: the labels loaded above are overwritten with pseudo-labels
L = load_LF(dataset_name, data_base_path=data_base_path)
y_train = lm.infer(L)
# Alternative aggregator kept for reference:
# y_train = get_baseline_pseudolabel(L)

# downstream task
model = LogisticRegression()
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
result = exp_eval(y_test, y_pred, cond=cond, fairness=False)
print(result)

# DataFrame.append was removed in pandas 2.0; wrap the metrics dict in a
# one-row frame and concatenate it instead.
result_collection = pd.concat([result_collection, pd.DataFrame([result])],
                              ignore_index=True)

{'condition': 'ws_baseline', 'accuracy': 0.7998894416804865, 'fscore': 0.5506206896551723, 'precision': 0.5863689776733255, 'recall': 0.5189807592303692}


# Domino + SBM

In [6]:
# Domino + SBM: for each transport option, discover two slices with Domino,
# use the slice assignment as the group signal for SBM bias correction of the
# labeling-function outputs, re-infer pseudo-labels, and train/evaluate the
# downstream classifier.
x_train, y_train, x_test, y_test = load_wrench_dataset(dataset_name=dataset_name,
                                                       data_base_path=data_base_path)

sbm_results = []  # one metrics dict per ot_type, added to result_collection once
for ot_type in [None, "linear", "sinkhorn"]:
    cond = f"sbm({ot_type})"

    # NOTE(review): L is reloaded and the slicer refit on every iteration even
    # though neither depends on ot_type. Kept as-is in case correct_bias
    # modifies L in place or the slicer is stochastic -- confirm before hoisting.
    L = load_LF(dataset_name, data_base_path=data_base_path)

    y_pseudo_train = lm.infer(L)
    y_train_proba = lm.infer(L, return_probs=True)
    slicer = DominoSlicer(n_slices=2, n_pca_components=min(x_train.shape[1], 128))
    slicer.fit(embeddings=x_train, targets=y_pseudo_train, pred_probs=y_train_proba)
    # Column 1 of the slicer prediction is used as the group attribute
    # (presumably the membership score for the second slice -- TODO confirm).
    a_train = slicer.predict(
        data={'embedding': x_train, 'target': y_pseudo_train, 'pred_probs': y_train_proba},
        embeddings='embedding',
    )[:, 1]

    # apply sbm
    sbm_mapping = find_sbm_mapping(x_train, a_train, ot_type)
    L = correct_bias(L, a_train, sbm_mapping, sbm_diff_threshold)
    y_train = lm.infer(L)

    # downstream task
    model = LogisticRegression()
    model.fit(x_train, y_train)
    y_pred = model.predict(x_test)
    result = exp_eval(y_test, y_pred, cond=cond, fairness=False)
    print(result)
    sbm_results.append(result)

# DataFrame.append was removed in pandas 2.0; build the new rows in one frame
# and concatenate once instead of growing the DataFrame inside the loop.
result_collection = pd.concat([result_collection, pd.DataFrame(sbm_results)],
                              ignore_index=True)


 22%|[38;2;241;122;74m██▏       [0m| 22/100 [00:00<00:01, 57.60it/s]


{'condition': 'sbm(None)', 'accuracy': 0.8005036545666728, 'fscore': 0.5636754433100484, 'precision': 0.5831017231795442, 'recall': 0.5455018200728029}


 41%|[38;2;241;122;74m████      [0m| 41/100 [00:00<00:01, 54.09it/s]


{'condition': 'sbm(linear)', 'accuracy': 0.7981696455991647, 'fscore': 0.5417015341701534, 'precision': 0.5842358604091457, 'recall': 0.5049401976079043}


 52%|[38;2;241;122;74m█████▏    [0m| 52/100 [00:00<00:00, 55.77it/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.8067686260057736, 'fscore': 0.5788487282463185, 'precision': 0.5965783664459161, 'recall': 0.5621424856994279}


# Result summary

In [7]:
# Display the metrics accumulated by the cells above (one row per condition).
result_collection

Unnamed: 0,condition,accuracy,fscore,precision,recall
0,fs,0.845648,0.634153,0.720476,0.566303
1,ws_baseline,0.799889,0.550621,0.586369,0.518981
2,sbm(None),0.800504,0.563675,0.583102,0.545502
3,sbm(linear),0.79817,0.541702,0.584236,0.50494
4,sbm(sinkhorn),0.806769,0.578849,0.596578,0.562142
