Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fairlib/datasets/bios/bios.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def bert_encoding(self):
avg_data, cls_data = self.encoder.encode(text_data)
split_df["bert_avg_SE"] = list(avg_data)
split_df["bert_cls_SE"] = list(cls_data)
split_df["gender_class"] = split_df["g"]
split_df["gender_class"] = split_df["g"].map(gender2id)
split_df["profession_class"] = split_df["p"].map(professions2id)

split_df.to_pickle(Path(self.dest_folder) / "bios_{}_df.pkl".format(split))
Expand Down
3 changes: 1 addition & 2 deletions fairlib/datasets/utils/bert_encoding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import torch
from transformers import *
import pickle
from transformers import BertModel, BertTokenizer
from tqdm.auto import tqdm, trange

class BERT_encoder:
Expand Down
2 changes: 1 addition & 1 deletion fairlib/src/analysis/tables_and_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,4 +504,4 @@ def make_zoom_plot(
ax.indicate_inset_zoom(axins, edgecolor="black")

if figure_name is not None:
fig.savefig(figure_name, dpi=960, bbox_inches="tight")
fig.savefig(figure_name+".pdf", format="pdf", dpi=960, bbox_inches="tight")
28 changes: 26 additions & 2 deletions fairlib/src/base_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .networks import adv
from .networks import FairCL
from .networks.DyBT import Group_Difference_Loss
from .networks.ARL import ARL

class State(object):

Expand Down Expand Up @@ -185,6 +186,8 @@ def __getattr__(self, name):
help='how many batches to wait before logging training status')
parser.add_argument('--save_batch_results', action='store_true', default=False,
help='if saving batch evaluation results')
parser.add_argument('--save_models', action='store_true', default=False,
help='if saving model parameters')
parser.add_argument('--checkpoint_interval', type=int, default=1, metavar='N',
help='checkpoint interval (epoch)')
parser.add_argument('--dataset', type=str, default='Moji',
Expand Down Expand Up @@ -282,6 +285,8 @@ def __getattr__(self, name):
# Gated adv
parser.add_argument('--adv_gated', action='store_true', default=False,
help='gated discriminator for augmented inputs given target labels')
parser.add_argument('--adv_gated_type', type=str, default="Augmentation",
help='Augmentation | Inputs | Separate')
parser.add_argument('--adv_BT', type=str, default=None, help='instacne reweighting for adv')
parser.add_argument('--adv_BTObj', type=str, default=None, help='instacne reweighting for adv')

Expand Down Expand Up @@ -338,6 +343,13 @@ def __getattr__(self, name):
parser.add_argument('--GBT_N', type=nonneg_int, default=None, help='size of the manipulated dataset')
parser.add_argument("--GBT_alpha", type=float, default=1, help="interpolation for generalized BT")

# ARL
parser.add_argument('--ARL', action='store_true', default=False,
help='Perform adversarial reweighted learning (ARL)')
parser.add_argument('--ARL_n',type=pos_int, default=1,
help='Update the adversary n times per main model update')


def get_dummy_state(self, *cmdargs, yaml_file=None, **opt_pairs):
if yaml_file is None:
# Use default Namespace (not UniqueNamespace) because dummy state may
Expand Down Expand Up @@ -515,8 +527,15 @@ def set_state(self, state, dummy=False, silence=False):

# Init discriminator for adversarial training
if state.adv_debiasing:

state.opt.discriminator = networks.adv.Discriminator(state)
# if state.adv_decoupling:
# raise NotImplementedError

if state.adv_gated and (state.adv_gated_type == "Separate"):
# Train a set of discriminators for each class
state.opt.discriminator = [networks.adv.Discriminator(state) for _ in range(state.num_classes)]
else:
# All other adv settings
state.opt.discriminator = networks.adv.Discriminator(state)
logging.info('Discriminator built!')
# adv.utils.print_network(state.opt.discriminator.subdiscriminators[0])

Expand All @@ -530,6 +549,11 @@ def set_state(self, state, dummy=False, silence=False):
if (state.DyBT is not None) and (state.DyBT == "GroupDifference"):
state.opt.group_difference_loss = Group_Difference_Loss(state)

# Init the ARL for unsupervised training
if state.ARL:
assert not state.adv_debiasing, "ARL is unsupervised bias mitigation, which cannot be used together with adversarial training"
state.opt.ARL_loss = ARL(state)

return state


Expand Down
10 changes: 5 additions & 5 deletions fairlib/src/dataloaders/BT.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
method (str, optional): Downsampling | Resampling. Defaults to "Downsampling".

Returns:
list: a list of indices of selected instacnes.
list: a list of indices of selected instances.
"""

# init a dict for storing the index of each group.
Expand Down Expand Up @@ -133,15 +133,15 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
weighting_counter = Counter(y)

# a list of (weights, actual length)
condidate_selected = min([len(group_idx[(_y, _g)])/weighting_counter[_y] for (_y, _g) in group_idx.keys()])
candidate_selected = min([len(group_idx[(_y, _g)])/weighting_counter[_y] for (_y, _g) in group_idx.keys()])

distinct_y_label = set(y)
distinct_g_label = set(protected_label)

# iterate each main task class
for y in distinct_y_label:
if method == "Downsampling":
selected = int(condidate_selected * weighting_counter[y])
selected = int(candidate_selected * weighting_counter[y])
elif method == "Resampling":
selected = int(weighting_counter[y] / len(distinct_g_label))
for g in distinct_g_label:
Expand All @@ -157,7 +157,7 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
weighting_counter = Counter(protected_label)
# a list of (weights, actual length)
# Noticing that if stratified_g, the order within the key has been changed.
condidate_selected = min([len(group_idx[(_y, _g)])/weighting_counter[_g] for (_y, _g) in group_idx.keys()])
candidate_selected = min([len(group_idx[(_y, _g)])/weighting_counter[_g] for (_y, _g) in group_idx.keys()])

distinct_y_label = set(y)
distinct_g_label = set(protected_label)
Expand All @@ -166,7 +166,7 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
# for y in distinct_y_label:
for g in distinct_g_label:
if method == "Downsampling":
selected = int(condidate_selected * weighting_counter[g])
selected = int(candidate_selected * weighting_counter[g])
elif method == "Resampling":
selected = int(weighting_counter[g] / len(distinct_y_label))
# for g in distinct_g_label:
Expand Down
2 changes: 1 addition & 1 deletion fairlib/src/dataloaders/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ def __init__(self, args) -> None:

def encoder(self, sample):
encodings = self.tokenizer(sample, truncation=True, padding=True)
return encodings["input_ids"]
return encodings["input_ids"], encodings['token_type_ids'], encodings['attention_mask']
6 changes: 5 additions & 1 deletion fairlib/src/dataloaders/loaders/Adult.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ def load_data(self):
if self.args.protected_task == "gender":
self.protected_label =np.array(list(data["sex"])).astype(np.int32) # Gender
elif self.args.protected_task == "race":
self.protected_label = np.array(list(data["race"])).astype(np.int32) # Race
self.protected_label = np.array(list(data["race"])).astype(np.int32) # Race
elif self.args.protected_task == "intersection":
self.protected_label = np.array(
[_r+_s*5 for _r,_s in zip(list(data["race"]), list(data["sex"]))]
).astype(np.int32) # Intersectional
12 changes: 10 additions & 2 deletions fairlib/src/dataloaders/loaders/Bios.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@ def load_data(self):

data = pd.read_pickle(Path(self.args.data_dir) / self.filename)

if self.args.protected_task in ["economy", "both"] and self.args.full_label:
# if self.args.protected_task in ["economy", "both"] and self.args.full_label:
if self.args.protected_task in ["gender", "economy", "both", "intersection"] and self.args.full_label:
selected_rows = (data["economy_label"] != "Unknown")
data = data[selected_rows]

if self.args.encoder_architecture == "Fixed":
self.X = list(data[self.embedding_type])
elif self.args.encoder_architecture == "BERT":
self.X = self.args.text_encoder.encoder(list(data[self.text_type]))
_input_ids, _token_type_ids, _attention_mask = self.args.text_encoder.encoder(list(data[self.text_type]))
self.X = _input_ids
self.addition_values["input_ids"] = _input_ids
self.addition_values['attention_mask'] = _attention_mask
else:
raise NotImplementedError

Expand All @@ -28,5 +32,9 @@ def load_data(self):
self.protected_label = data["gender_class"].astype(np.int32) # Gender
elif self.args.protected_task == "economy":
self.protected_label = data["economy_class"].astype(np.int32) # Economy
elif self.args.protected_task == "intersection":
self.protected_label = np.array(
[2*_e+_g for _e,_g in zip(list(data["economy_class"]), list(data["gender_class"]))]
).astype(np.int32) # Intersection
else:
self.protected_label = data["intersection_class"].astype(np.int32) # Intersection
6 changes: 5 additions & 1 deletion fairlib/src/dataloaders/loaders/COMPAS.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ def load_data(self):
if self.args.protected_task == "gender":
self.protected_label =np.array(list(data["sex"])).astype(np.int32) # Gender
elif self.args.protected_task == "race":
self.protected_label = np.array(list(data["race"])).astype(np.int32) # Race
self.protected_label = np.array(list(data["race"])).astype(np.int32) # Race
elif self.args.protected_task == "intersection":
self.protected_label = np.array(
[_r+_s*3 for _r,_s in zip(list(data["race"]), list(data["sex"]))]
).astype(np.int32) # Intersectional
8 changes: 6 additions & 2 deletions fairlib/src/dataloaders/loaders/Trustpilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ def load_data(self):
if self.args.protected_task == "gender":
self.protected_label = data["gender_label"].astype(np.int32) # Gender
elif self.args.protected_task == "age":
self.protected_label = data["age_label"].astype(np.int32) # Economy
self.protected_label = data["age_label"].astype(np.int32) # Age
elif self.args.protected_task == "country":
self.protected_label = data["country_label"].astype(np.int32) # Economy
self.protected_label = data["country_label"].astype(np.int32) # Country
elif self.args.protected_task == "intersection":
self.protected_label = np.array(
[4*_g+2*_a+_c for _g,_a,_c in zip(list(data["gender_label"]), list(data["age_label"]), data["country_label"])]
).astype(np.int32) # Intersection
else:
raise NotImplementedError
30 changes: 27 additions & 3 deletions fairlib/src/dataloaders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self, args, split):
self.instance_weights = []
self.adv_instance_weights = []
self.regression_label = []
self.addition_values = {}

self.load_data()

Expand All @@ -50,7 +51,6 @@ def __init__(self, args, split):
if self.split == "train":
self.adv_decoupling()


print("Loaded data shapes: {}, {}, {}".format(self.X.shape, self.y.shape, self.protected_label.shape))

def __len__(self):
Expand All @@ -59,7 +59,25 @@ def __len__(self):

def __getitem__(self, index):
'Generates one sample of data'
return self.X[index], self.y[index], self.protected_label[index], self.instance_weights[index], self.adv_instance_weights[index], self.regression_label[index]
_X = self.X[index]
_y = self.y[index]
_protected_label = self.protected_label[index]
_instance_weights = self.instance_weights[index]
_adv_instance_weights = self.adv_instance_weights[index]
_regression_label = self.regression_label[index]

data_dict = {
0:_X,
1:_y,
2:_protected_label,
3:_instance_weights,
4:_adv_instance_weights,
5:_regression_label,
}
for _k in self.addition_values.keys():
if _k not in data_dict.keys():
data_dict[_k] = self.addition_values[_k][index]
return data_dict

def load_data(self):
pass
Expand All @@ -79,6 +97,9 @@ def manipulate_data_distribution(self):
self.y = self.y[selected_index]
self.protected_label = self.protected_label[selected_index]

for _k in self.addition_values.keys():
self.addition_values[_k] = [self.addition_values[_k][index] for index in selected_index]

def balanced_training(self):
if (self.args.BT is None) or (self.split != "train"):
# Without balanced training
Expand Down Expand Up @@ -112,6 +133,9 @@ def balanced_training(self):
self.protected_label = np.array(_protected_label)
self.instance_weights = np.array([1 for _ in range(len(self.protected_label))])

for _k in self.addition_values.keys():
self.addition_values[_k] = [self.addition_values[_k][index] for index in selected_index]

else:
raise NotImplementedError
return None
Expand Down Expand Up @@ -152,7 +176,7 @@ def adv_decoupling(self):
else:
pass
return None

def regression_init(self):
if not self.args.regression:
self.regression_label = np.array([0 for _ in range(len(self.protected_label))])
Expand Down
44 changes: 43 additions & 1 deletion fairlib/src/evaluators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,46 @@ def present_evaluation_scores(
validation_results = ["{}: {:2.2f}\t".format(k, 100.*valid_scores[k]) for k in valid_scores.keys()]
logging.info(('Validation {}').format("".join(validation_results)))
Test_results = ["{}: {:2.2f}\t".format(k, 100.*test_scores[k]) for k in test_scores.keys()]
logging.info(('Test {}').format("".join(Test_results)))
logging.info(('Test {}').format("".join(Test_results)))


def validation_is_best(
    valid_preds,
    valid_labels,
    valid_private_labels,
    model,
    epoch_valid_loss,
    selection_criterion="DTO",
    performance_metric="accuracy",
    fairness_metric="TPR_GAP",
):
    """Decide whether the current epoch is the best one seen so far.

    Validation scores are computed with ``gap_eval_scores`` and reduced to a
    single number according to ``selection_criterion``; lower is better. When
    that number is strictly smaller than ``model.best_valid_loss`` the
    attribute is overwritten with it.

    Args:
        valid_preds: validation predictions, forwarded as ``y_pred``.
        valid_labels: validation target labels, forwarded as ``y_true``.
        valid_private_labels: validation protected labels, forwarded as
            ``protected_attribute``.
        model: object exposing ``args`` and a mutable ``best_valid_loss``.
        epoch_valid_loss: validation loss of the current epoch; only used
            when ``selection_criterion`` is ``"Loss"``.
        selection_criterion: one of ``"DTO"``, ``"Loss"``, ``"Performance"``
            or ``"Fairness"``.
        performance_metric: key of the performance score in the validation
            scores.
        fairness_metric: key of the fairness score in the validation scores.

    Returns:
        bool: True if the current epoch improved on ``model.best_valid_loss``
        (which has then been updated), False otherwise.

    Raises:
        NotImplementedError: if ``selection_criterion`` is not recognised.
    """
    # Scores are always evaluated, even for the "Loss" criterion, before the
    # criterion itself is inspected.
    scores, _ = gap_eval_scores(
        y_pred=valid_preds,
        y_true=valid_labels,
        protected_attribute=valid_private_labels,
        args=model.args,
    )

    # Reduce the evaluation to one "smaller is better" number.
    if selection_criterion == "DTO":
        # Euclidean norm of (1 - performance, fairness gap).
        shortfall = 1 - scores[performance_metric]
        candidate = (shortfall ** 2 + scores[fairness_metric] ** 2) ** 0.5
    elif selection_criterion == "Loss":
        candidate = epoch_valid_loss
    elif selection_criterion == "Performance":
        candidate = 1 - scores[performance_metric]
    elif selection_criterion == "Fairness":
        candidate = scores[fairness_metric]
    else:
        raise NotImplementedError

    # Only a strict improvement replaces the stored best value.
    if not candidate < model.best_valid_loss:
        return False

    model.best_valid_loss = candidate
    return True
Loading