In [1]:
import pandas as pd
import numpy as np
import sklearn
import sys
import os
import torch
import tqdm
import warnings
warnings.filterwarnings('ignore')

sys.path.append('../')
sys.path.append('../../')

from domino import DominoSlicer
from sklearn.linear_model import LogisticRegression
from snorkel.labeling.model import LabelModel

from fairws.data_util import load_wrench_dataset, load_LF
from fairws.metrics import exp_eval
from fairws.sbm import get_baseline_pseudolabel, get_sbm_pseudolabel

Note: the current version of domino has a bug, which can be fixed modifying \<domino_path\>/_embed/\_\_init\_\_.py

```def infer_modality(col: mk.AbstractColumn)``` --> ```def infer_modality(col: mk.Column)```


# Configurations

In [2]:
data_base_path = '../data'
dataset_name = "census" # adult | bank_marketing | CivilComments | hateXplain | CelebA | UTKFace
use_LIFT_embedding = False # only for adult, bank_marketing
sbm_diff_threshold = 0.05

result_collection = pd.DataFrame() # to keep results

# Fully supervised

In [3]:
cond = "fs"
x_train, y_train, x_test, y_test = load_wrench_dataset(dataset_name=dataset_name,
                                                       data_base_path='../data')
model = LogisticRegression()
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
result = exp_eval(y_test, y_pred, cond=cond, fairness=False)
print(result)

result_collection = result_collection.append(result, ignore_index=True)

{'condition': 'fs', 'accuracy': 0.8456483017013697, 'fscore': 0.6341534430048041, 'precision': 0.7204763479986768, 'recall': 0.5663026521060842}


# WS baseline

In [4]:
cond = "ws_baseline"

x_train, y_train, x_test, y_test = load_wrench_dataset(dataset_name=dataset_name,
                                                       data_base_path='../data')
# weak supervision
L = load_LF(dataset_name, data_base_path='../data')
y_train = get_baseline_pseudolabel(L)

# downstream task
model = LogisticRegression()
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
result = exp_eval(y_test, y_pred, cond=cond, fairness=False)
print(result)

result_collection = result_collection.append(result, ignore_index=True)

100%|███████████████████████████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 167.28epoch/s]


{'condition': 'ws_baseline', 'accuracy': 0.6017443645967693, 'fscore': 0.511820508959494, 'precision': 0.3602161933022467, 'recall': 0.8837753510140406}


# Domino + SBM

In [6]:
x_train, y_train, x_test, y_test = load_wrench_dataset(dataset_name=dataset_name,
                                                       data_base_path='../data')
for ot_type in [None, "linear", "sinkhorn"]:
    cond = f"sbm({ot_type})"

    L = load_LF(dataset_name, data_base_path='../data')
    
    # get a_train using domino
    label_model = LabelModel(cardinality=2, verbose=False)
    label_model.fit(L_train=L,
                    n_epochs=1000, log_freq=100, seed=123)

    y_pseudo_train = label_model.predict(L, tie_break_policy="random")
    y_train_proba = label_model.predict_proba(L)
    slicer = DominoSlicer(n_slices=2, n_pca_components=min(x_train.shape[1], 128))
    slicer.fit(embeddings=x_train, targets=y_pseudo_train, pred_probs=y_train_proba)
    a_train = slicer.predict(
        data={'embedding':x_train, 'target':y_pseudo_train, 'pred_probs':y_train_proba},
            embeddings='embedding'
        )[:, 1]
    
    # apply sbm
    y_train= get_sbm_pseudolabel(L, x_train, a_train, dataset_name, 
                                 ot_type=ot_type, diff_threshold=sbm_diff_threshold)
    
    # downstream task
    model = LogisticRegression()
    model.fit(x_train, y_train)
    y_pred = model.predict(x_test)
    result = exp_eval(y_test, y_pred, cond=cond, fairness=False)
    print(result)
    result_collection = result_collection.append(result, ignore_index=True)


100%|███████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 166.54epoch/s]


  0%|          | 0/100 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 159.93epoch/s]


{'condition': 'sbm(None)', 'accuracy': 0.7322646029113691, 'fscore': 0.54626834599771, 'precision': 0.45547647977781636, 'recall': 0.6822672906916276}


100%|███████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 148.95epoch/s]


  0%|          | 0/100 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 190.58epoch/s]


{'condition': 'sbm(linear)', 'accuracy': 0.7142681653461089, 'fscore': 0.47064178425125175, 'precision': 0.41845406717927963, 'recall': 0.5377015080603225}


100%|███████████████████████████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 170.62epoch/s]


  0%|          | 0/100 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 159.08epoch/s]


{'condition': 'sbm(sinkhorn)', 'accuracy': 0.678398132792826, 'fscore': 0.3693086003372681, 'precision': 0.34403052064631956, 'recall': 0.3985959438377535}


# Result summary

In [7]:
result_collection

Unnamed: 0,condition,accuracy,fscore,precision,recall
0,fs,0.845648,0.634153,0.720476,0.566303
1,ws_baseline,0.601744,0.511821,0.360216,0.883775
2,sbm(None),0.732265,0.546268,0.455476,0.682267
3,sbm(linear),0.714268,0.470642,0.418454,0.537702
4,sbm(sinkhorn),0.678398,0.369309,0.344031,0.398596
