diff --git a/fairlib/datasets/bios/bios.py b/fairlib/datasets/bios/bios.py
index a64f6d6..3ec10a1 100644
--- a/fairlib/datasets/bios/bios.py
+++ b/fairlib/datasets/bios/bios.py
@@ -43,7 +43,7 @@ def bert_encoding(self):
         avg_data, cls_data = self.encoder.encode(text_data)
         split_df["bert_avg_SE"] = list(avg_data)
         split_df["bert_cls_SE"] = list(cls_data)
-        split_df["gender_class"] = split_df["g"]
+        split_df["gender_class"] = split_df["g"].map(gender2id)
         split_df["profession_class"] = split_df["p"].map(professions2id)
 
         split_df.to_pickle(Path(self.dest_folder) / "bios_{}_df.pkl".format(split))
diff --git a/fairlib/datasets/utils/bert_encoding.py b/fairlib/datasets/utils/bert_encoding.py
index 00a8771..c85bbb5 100644
--- a/fairlib/datasets/utils/bert_encoding.py
+++ b/fairlib/datasets/utils/bert_encoding.py
@@ -1,7 +1,6 @@
 import numpy as np
 import torch
-from transformers import *
-import pickle
+from transformers import BertModel, BertTokenizer
 from tqdm.auto import tqdm, trange
 
 class BERT_encoder:
diff --git a/fairlib/src/analysis/tables_and_figures.py b/fairlib/src/analysis/tables_and_figures.py
index e846a81..488e8d4 100644
--- a/fairlib/src/analysis/tables_and_figures.py
+++ b/fairlib/src/analysis/tables_and_figures.py
@@ -504,4 +504,4 @@ def make_zoom_plot(
     ax.indicate_inset_zoom(axins, edgecolor="black")
 
     if figure_name is not None:
-        fig.savefig(figure_name, dpi=960, bbox_inches="tight")
\ No newline at end of file
+        fig.savefig(figure_name+".pdf", format="pdf", dpi=960, bbox_inches="tight")
\ No newline at end of file
diff --git a/fairlib/src/base_options.py b/fairlib/src/base_options.py
index 4400dae..f816a47 100644
--- a/fairlib/src/base_options.py
+++ b/fairlib/src/base_options.py
@@ -16,6 +16,7 @@ from .networks import adv
 from .networks import FairCL
 from .networks.DyBT import Group_Difference_Loss
+from .networks.ARL import ARL
 
 
 class State(object):
@@ -185,6 +186,8 @@ def __getattr__(self, name):
                             help='how many batches to wait before logging training status')
         parser.add_argument('--save_batch_results', action='store_true', default=False,
                             help='if saving batch evaluation results')
+        parser.add_argument('--save_models', action='store_true', default=False,
+                            help='if saving model parameters')
         parser.add_argument('--checkpoint_interval', type=int, default=1, metavar='N',
                             help='checkpoint interval (epoch)')
         parser.add_argument('--dataset', type=str, default='Moji',
@@ -282,6 +285,8 @@ def __getattr__(self, name):
         # Gated adv
         parser.add_argument('--adv_gated', action='store_true', default=False,
                             help='gated discriminator for augmented inputs given target labels')
+        parser.add_argument('--adv_gated_type', type=str, default="Augmentation",
+                            help='Augmentation | Inputs | Separate')
         parser.add_argument('--adv_BT', type=str, default=None, help='instacne reweighting for adv')
         parser.add_argument('--adv_BTObj', type=str, default=None, help='instacne reweighting for adv')
 
@@ -338,6 +343,13 @@ def __getattr__(self, name):
         parser.add_argument('--GBT_N', type=nonneg_int, default=None, help='size of the manipulated dataset')
         parser.add_argument("--GBT_alpha", type=float, default=1, help="interpolation for generalized BT")
 
+        # ARL
+        parser.add_argument('--ARL', action='store_true', default=False,
+                            help='Perform adversarial reweighted learning (ARL)')
+        parser.add_argument('--ARL_n', type=pos_int, default=1,
+                            help='Update the adversary n times per main model update')
+
+
     def get_dummy_state(self, *cmdargs, yaml_file=None, **opt_pairs):
         if yaml_file is None:
             # Use default Namespace (not UniqueNamespace) because dummy state may
@@ -515,8 +527,15 @@ def set_state(self, state, dummy=False, silence=False):
 
         # Init discriminator for adversarial training
         if state.adv_debiasing:
-
-            state.opt.discriminator = networks.adv.Discriminator(state)
+            # if state.adv_decoupling:
+            #     raise NotImplementedError
+
+            if state.adv_gated and (state.adv_gated_type == "Separate"):
+                # Train a set of discriminators for each class
+                state.opt.discriminator = [networks.adv.Discriminator(state) for _ in range(state.num_classes)]
+            else:
+                # All other adv settings
+                state.opt.discriminator = networks.adv.Discriminator(state)
 
             logging.info('Discriminator built!')
             # adv.utils.print_network(state.opt.discriminator.subdiscriminators[0])
@@ -530,6 +549,11 @@ def set_state(self, state, dummy=False, silence=False):
         if (state.DyBT is not None) and (state.DyBT == "GroupDifference"):
             state.opt.group_difference_loss = Group_Difference_Loss(state)
 
+        # Init the ARL for unsupervised training
+        if state.ARL:
+            assert not state.adv_debiasing, "ARL is unsupervised bias mitigation, which cannot be used together with adversarial training"
+            state.opt.ARL_loss = ARL(state)
+
         return state
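
Note: the "Separate" option above trains one discriminator per target class and routes each example to the discriminator of its own class (see the matching train_epoch changes further down). A minimal, self-contained sketch of that routing, with nn.Linear standing in for networks.adv.Discriminator:

```python
import torch
import torch.nn as nn

# Toy sketch of the "Separate" gated-adversary routing: one discriminator per
# target class, each seeing only that class's hidden representations.
num_classes, hidden_dim, num_groups = 3, 8, 2
discriminators = [nn.Linear(hidden_dim, num_groups) for _ in range(num_classes)]  # stand-ins

hs = torch.randn(5, hidden_dim)          # hidden states for one batch
tags = torch.tensor([0, 2, 1, 0, 2])     # target labels

for y in range(num_classes):
    mask = torch.where(tags == y)[0]     # indices of class-y instances
    if len(mask) > 0:
        logits = discriminators[y](hs[mask])  # class-specific adversary forward
```
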
diff --git a/fairlib/src/dataloaders/BT.py b/fairlib/src/dataloaders/BT.py
index 5b9a8d8..475be81 100644
--- a/fairlib/src/dataloaders/BT.py
+++ b/fairlib/src/dataloaders/BT.py
@@ -76,7 +76,7 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
         method (str, optional): Downsampling | Resampling. Defaults to "Downsampling".
 
     Returns:
-        list: a list of indices of selected instacnes.
+        list: a list of indices of selected instances.
     """
 
     # init a dict for storing the index of each group.
@@ -133,7 +133,7 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
         weighting_counter = Counter(y)
 
         # a list of (weights, actual length)
-        condidate_selected = min([len(group_idx[(_y, _g)])/weighting_counter[_y] for (_y, _g) in group_idx.keys()])
+        candidate_selected = min([len(group_idx[(_y, _g)])/weighting_counter[_y] for (_y, _g) in group_idx.keys()])
 
         distinct_y_label = set(y)
         distinct_g_label = set(protected_label)
@@ -141,7 +141,7 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
         # iterate each main task class
         for y in distinct_y_label:
             if method == "Downsampling":
-                selected = int(condidate_selected * weighting_counter[y])
+                selected = int(candidate_selected * weighting_counter[y])
             elif method == "Resampling":
                 selected = int(weighting_counter[y] / len(distinct_g_label))
             for g in distinct_g_label:
@@ -157,7 +157,7 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
         weighting_counter = Counter(protected_label)
 
         # a list of (weights, actual length)
        # Noticing that if stratified_g, the order within the key has been changed.
-        condidate_selected = min([len(group_idx[(_y, _g)])/weighting_counter[_g] for (_y, _g) in group_idx.keys()])
+        candidate_selected = min([len(group_idx[(_y, _g)])/weighting_counter[_g] for (_y, _g) in group_idx.keys()])
 
         distinct_y_label = set(y)
         distinct_g_label = set(protected_label)
@@ -166,7 +166,7 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
         # for y in distinct_y_label:
         for g in distinct_g_label:
             if method == "Downsampling":
-                selected = int(condidate_selected * weighting_counter[g])
+                selected = int(candidate_selected * weighting_counter[g])
             elif method == "Resampling":
                 selected = int(weighting_counter[g] / len(distinct_y_label))
             # for g in distinct_g_label:
diff --git a/fairlib/src/dataloaders/encoder.py b/fairlib/src/dataloaders/encoder.py
index db04fdd..4910223 100644
--- a/fairlib/src/dataloaders/encoder.py
+++ b/fairlib/src/dataloaders/encoder.py
@@ -14,4 +14,4 @@ def __init__(self, args) -> None:
 
     def encoder(self, sample):
         encodings = self.tokenizer(sample, truncation=True, padding=True)
-        return encodings["input_ids"]
+        return encodings["input_ids"], encodings['token_type_ids'], encodings['attention_mask']
diff --git a/fairlib/src/dataloaders/loaders/Adult.py b/fairlib/src/dataloaders/loaders/Adult.py
index 51158e8..4cab74c 100644
--- a/fairlib/src/dataloaders/loaders/Adult.py
+++ b/fairlib/src/dataloaders/loaders/Adult.py
@@ -18,4 +18,8 @@ def load_data(self):
         if self.args.protected_task == "gender":
             self.protected_label =np.array(list(data["sex"])).astype(np.int32) # Gender
         elif self.args.protected_task == "race":
-            self.protected_label = np.array(list(data["race"])).astype(np.int32) # Race
\ No newline at end of file
+            self.protected_label = np.array(list(data["race"])).astype(np.int32) # Race
+        elif self.args.protected_task == "intersection":
+            self.protected_label = np.array(
+                [_r+_s*5 for _r,_s in zip(list(data["race"]), list(data["sex"]))]
+            ).astype(np.int32) # Intersectional
\ No newline at end of file
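
Note: the intersectional protected labels added here are mixed-radix codes over the component attributes (Adult: `_r + _s*5`; COMPAS: `_r + _s*3`; Bios: `2*_e + _g`; Trustpilot: `4*_g + 2*_a + _c`). A minimal sketch of why this yields one distinct group id per combination, assuming Adult-style codes with race in {0..4} and binary sex (the radix 5 is taken from the diff; the true cardinalities are not shown here):

```python
import numpy as np

race = np.array([0, 3, 4, 1])            # hypothetical race codes in {0..4}
sex = np.array([1, 0, 1, 0])             # hypothetical binary sex codes

intersection = (race + sex * 5).astype(np.int32)   # one id per (race, sex) pair
assert len(set(zip(race, sex))) == len(set(intersection))

# The code is invertible, so no group information is lost:
assert (intersection % 5 == race).all() and (intersection // 5 == sex).all()
```
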
diff --git a/fairlib/src/dataloaders/loaders/Bios.py b/fairlib/src/dataloaders/loaders/Bios.py
index f5ca371..690fbcc 100644
--- a/fairlib/src/dataloaders/loaders/Bios.py
+++ b/fairlib/src/dataloaders/loaders/Bios.py
@@ -12,14 +12,18 @@ def load_data(self):
         data = pd.read_pickle(Path(self.args.data_dir) / self.filename)
 
-        if self.args.protected_task in ["economy", "both"] and self.args.full_label:
+        # if self.args.protected_task in ["economy", "both"] and self.args.full_label:
+        if self.args.protected_task in ["gender", "economy", "both", "intersection"] and self.args.full_label:
             selected_rows = (data["economy_label"] != "Unknown")
             data = data[selected_rows]
 
         if self.args.encoder_architecture == "Fixed":
             self.X = list(data[self.embedding_type])
         elif self.args.encoder_architecture == "BERT":
-            self.X = self.args.text_encoder.encoder(list(data[self.text_type]))
+            _input_ids, _token_type_ids, _attention_mask = self.args.text_encoder.encoder(list(data[self.text_type]))
+            self.X = _input_ids
+            self.addition_values["input_ids"] = _input_ids
+            self.addition_values['attention_mask'] = _attention_mask
         else:
             raise NotImplementedError
 
@@ -28,5 +32,9 @@ def load_data(self):
             self.protected_label = data["gender_class"].astype(np.int32) # Gender
         elif self.args.protected_task == "economy":
             self.protected_label = data["economy_class"].astype(np.int32) # Economy
+        elif self.args.protected_task == "intersection":
+            self.protected_label = np.array(
+                [2*_e+_g for _e,_g in zip(list(data["economy_class"]), list(data["gender_class"]))]
+            ).astype(np.int32) # Intersection
         else:
             self.protected_label = data["intersection_class"].astype(np.int32) # Intersection
\ No newline at end of file
diff --git a/fairlib/src/dataloaders/loaders/COMPAS.py b/fairlib/src/dataloaders/loaders/COMPAS.py
index 92a9dc7..677b7a5 100644
--- a/fairlib/src/dataloaders/loaders/COMPAS.py
+++ b/fairlib/src/dataloaders/loaders/COMPAS.py
@@ -18,4 +18,8 @@ def load_data(self):
         if self.args.protected_task == "gender":
             self.protected_label =np.array(list(data["sex"])).astype(np.int32) # Gender
         elif self.args.protected_task == "race":
-            self.protected_label = np.array(list(data["race"])).astype(np.int32) # Race
\ No newline at end of file
+            self.protected_label = np.array(list(data["race"])).astype(np.int32) # Race
+        elif self.args.protected_task == "intersection":
+            self.protected_label = np.array(
+                [_r+_s*3 for _r,_s in zip(list(data["race"]), list(data["sex"]))]
+            ).astype(np.int32) # Intersectional
\ No newline at end of file
diff --git a/fairlib/src/dataloaders/loaders/Trustpilot.py b/fairlib/src/dataloaders/loaders/Trustpilot.py
index b260176..ffaaaad 100644
--- a/fairlib/src/dataloaders/loaders/Trustpilot.py
+++ b/fairlib/src/dataloaders/loaders/Trustpilot.py
@@ -23,8 +23,12 @@ def load_data(self):
         if self.args.protected_task == "gender":
             self.protected_label = data["gender_label"].astype(np.int32) # Gender
         elif self.args.protected_task == "age":
-            self.protected_label = data["age_label"].astype(np.int32) # Economy
+            self.protected_label = data["age_label"].astype(np.int32) # Age
         elif self.args.protected_task == "country":
-            self.protected_label = data["country_label"].astype(np.int32) # Economy
+            self.protected_label = data["country_label"].astype(np.int32) # Country
+        elif self.args.protected_task == "intersection":
+            self.protected_label = np.array(
+                [4*_g+2*_a+_c for _g,_a,_c in zip(list(data["gender_label"]), list(data["age_label"]), data["country_label"])]
+            ).astype(np.int32) # Intersection
         else:
             raise NotImplementedError
\ No newline at end of file
diff --git a/fairlib/src/dataloaders/utils.py b/fairlib/src/dataloaders/utils.py
index b24caad..3e2a888 100644
--- a/fairlib/src/dataloaders/utils.py
+++ b/fairlib/src/dataloaders/utils.py
@@ -30,6 +30,7 @@ def __init__(self, args, split):
         self.instance_weights = []
         self.adv_instance_weights = []
         self.regression_label = []
+        self.addition_values = {}
 
         self.load_data()
 
@@ -50,7 +51,6 @@ def __init__(self, args, split):
 
         if self.split == "train":
             self.adv_decoupling()
 
-        print("Loaded data shapes: {}, {}, {}".format(self.X.shape, self.y.shape, self.protected_label.shape))
 
     def __len__(self):
@@ -59,7 +59,25 @@ def __len__(self):
 
     def __getitem__(self, index):
         'Generates one sample of data'
-        return self.X[index], self.y[index], self.protected_label[index], self.instance_weights[index], self.adv_instance_weights[index], self.regression_label[index]
+        _X = self.X[index]
+        _y = self.y[index]
+        _protected_label = self.protected_label[index]
+        _instance_weights = self.instance_weights[index]
+        _adv_instance_weights = self.adv_instance_weights[index]
+        _regression_label = self.regression_label[index]
+
+        data_dict = {
+            0:_X,
+            1:_y,
+            2:_protected_label,
+            3:_instance_weights,
+            4:_adv_instance_weights,
+            5:_regression_label,
+        }
+        for _k in self.addition_values.keys():
+            if _k not in data_dict.keys():
+                data_dict[_k] = self.addition_values[_k][index]
+        return data_dict
 
     def load_data(self):
         pass
@@ -79,6 +97,9 @@ def manipulate_data_distribution(self):
         self.y = self.y[selected_index]
         self.protected_label = self.protected_label[selected_index]
+        for _k in self.addition_values.keys():
+            self.addition_values[_k] = [self.addition_values[_k][index] for index in selected_index]
+
     def balanced_training(self):
         if (self.args.BT is None) or (self.split != "train"):
             # Without balanced training
@@ -112,6 +133,9 @@ def balanced_training(self):
             self.protected_label = np.array(_protected_label)
             self.instance_weights = np.array([1 for _ in range(len(self.protected_label))])
 
+            for _k in self.addition_values.keys():
+                self.addition_values[_k] = [self.addition_values[_k][index] for index in selected_index]
+
         else:
             raise NotImplementedError
         return None
@@ -152,7 +176,7 @@ def adv_decoupling(self):
             else:
                 pass
         return None
-    
+
     def regression_init(self):
         if not self.args.regression:
             self.regression_label = np.array([0 for _ in range(len(self.protected_label))])
diff --git a/fairlib/src/evaluators/__init__.py b/fairlib/src/evaluators/__init__.py
index 7183c04..f83555a 100644
--- a/fairlib/src/evaluators/__init__.py
+++ b/fairlib/src/evaluators/__init__.py
@@ -56,4 +56,46 @@ def present_evaluation_scores(
     validation_results = ["{}: {:2.2f}\t".format(k, 100.*valid_scores[k]) for k in valid_scores.keys()]
     logging.info(('Validation {}').format("".join(validation_results)))
     Test_results = ["{}: {:2.2f}\t".format(k, 100.*test_scores[k]) for k in test_scores.keys()]
-    logging.info(('Test {}').format("".join(Test_results)))
\ No newline at end of file
+    logging.info(('Test {}').format("".join(Test_results)))
+
+
+def validation_is_best(
+    valid_preds, valid_labels, valid_private_labels,
+    model, epoch_valid_loss, selection_criterion = "DTO",
+    performance_metric = "accuracy", fairness_metric="TPR_GAP"
+    ):
+    """
+    Check if the current model is the best so far.
+    """
+
+    is_best = False
+
+    valid_scores, _ = gap_eval_scores(
+        y_pred=valid_preds,
+        y_true=valid_labels,
+        protected_attribute=valid_private_labels,
+        args = model.args,
+    )
+
+    if selection_criterion == "DTO":
+        valid_dto_score = ((1-valid_scores[performance_metric])**2 + valid_scores[fairness_metric]**2)**0.5
+        if valid_dto_score < model.best_valid_loss:
+            model.best_valid_loss = valid_dto_score
+            is_best = True
+    elif selection_criterion == "Loss":
+        if epoch_valid_loss < model.best_valid_loss:
+            model.best_valid_loss = epoch_valid_loss
+            is_best = True
+    elif selection_criterion == "Performance":
+        if (1-valid_scores[performance_metric]) < model.best_valid_loss:
+            model.best_valid_loss = 1-valid_scores[performance_metric]
+            is_best = True
+    elif selection_criterion == "Fairness":
+        if valid_scores[fairness_metric] < model.best_valid_loss:
+            model.best_valid_loss = valid_scores[fairness_metric]
+            is_best = True
+    else:
+        raise NotImplementedError
+
+
+    return is_best
\ No newline at end of file
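
Note: validation_is_best introduces checkpoint selection by DTO: the Euclidean distance from the point (performance, fairness gap) to the utopia point (1, 0), with lower being better. A standalone sketch of the rule with made-up scores:

```python
# DTO as computed in validation_is_best above: distance to the utopia point
# (performance = 1, fairness gap = 0); both metrics are assumed to lie in [0, 1].
def dto(performance: float, fairness_gap: float) -> float:
    return ((1 - performance) ** 2 + fairness_gap ** 2) ** 0.5

# A model with 85% accuracy and a 0.05 TPR gap beats one with 90% accuracy
# and a 0.20 gap under this criterion:
assert dto(0.85, 0.05) < dto(0.90, 0.20)
```
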
diff --git a/fairlib/src/evaluators/leakage_metrices.py b/fairlib/src/evaluators/leakage_metrices.py
new file mode 100644
index 0000000..18724bf
--- /dev/null
+++ b/fairlib/src/evaluators/leakage_metrices.py
@@ -0,0 +1,105 @@
+from sklearn.metrics import accuracy_score
+from sklearn.metrics import roc_auc_score
+from sklearn.metrics import average_precision_score
+from sklearn.metrics import f1_score
+from sklearn.metrics import confusion_matrix
+import numpy as np
+
+from sklearn.svm import LinearSVC
+from sklearn.linear_model import SGDClassifier, LogisticRegression
+from sklearn.neural_network import MLPClassifier
+from sklearn.utils import shuffle
+
+from collections import defaultdict, Counter
+
+
+def leakage_evaluation(model,
+                       adv_level,
+                       training_generator,
+                       validation_generator,
+                       test_generator,
+                       device,
+                       augmentation = False):
+    model.eval()
+    model.adv_level = adv_level
+
+    train_hidden = []
+    train_labels = []
+    train_private_labels = []
+
+    for batch in training_generator:
+        text = batch[0]
+        tags = batch[1]
+        p_tags = batch[2]
+
+        train_labels += list(tags.cpu().numpy())
+        train_private_labels += list(p_tags.cpu().numpy())
+
+        text = text.to(device)
+        p_tags = p_tags.to(device)
+
+        # hidden_state = model.hidden(text)
+        if augmentation:
+            hidden_state = model.hidden(text, p_tags)
+        else:
+            hidden_state = model.hidden(text)
+        train_hidden.append(hidden_state.detach().cpu().numpy())
+    train_hidden = np.concatenate(train_hidden, 0)
+
+    dev_hidden = []
+    dev_labels = []
+    dev_private_labels = []
+
+    for batch in validation_generator:
+        text = batch[0]
+        tags = batch[1]
+        p_tags = batch[2]
+
+        dev_labels += list(tags.cpu().numpy())
+        dev_private_labels += list(p_tags.cpu().numpy())
+
+        text = text.to(device)
+        p_tags = p_tags.to(device)
+
+        # hidden_state = model.hidden(text)
+        if augmentation:
+            hidden_state = model.hidden(text, p_tags)
+        else:
+            hidden_state = model.hidden(text)
+        dev_hidden.append(hidden_state.detach().cpu().numpy())
+    dev_hidden = np.concatenate(dev_hidden, 0)
+
+    test_hidden = []
+    test_labels = []
+    test_private_labels = []
+
+    for batch in test_generator:
+        text = batch[0]
+        tags = batch[1]
+        p_tags = batch[2]
+
+        test_labels += list(tags.cpu().numpy())
+        test_private_labels += list(p_tags.cpu().numpy())
+
+        text = text.to(device)
+        p_tags = p_tags.to(device)
+
+        # hidden_state = model.hidden(text)
+        if augmentation:
+            hidden_state = model.hidden(text, p_tags)
+        else:
+            hidden_state = model.hidden(text)
+        test_hidden.append(hidden_state.detach().cpu().numpy())
+    test_hidden = np.concatenate(test_hidden, 0)
+
+
+    biased_classifier = LinearSVC(fit_intercept=True, class_weight='balanced', dual=False, C=0.1, max_iter=10000)
+    # biased_classifier = MLPClassifier(max_iter=50, batch_size=1024)
+    # biased_classifier.fit(train_hidden, train_private_labels)
+    biased_classifier.fit(dev_hidden, dev_private_labels)
+    # dev_leakage = biased_classifier.score(dev_hidden, dev_private_labels)
+    dev_leakage = biased_classifier.score(dev_hidden, dev_private_labels)
+    test_leakage = biased_classifier.score(test_hidden, test_private_labels)
+    # print("Dev Accuracy: {}".format(dev_leakage))
+    # print("Test Accuracy: {}".format(test_leakage))
+    return test_leakage
diff --git a/fairlib/src/evaluators/utils.py b/fairlib/src/evaluators/utils.py
index c7586c8..27cbc87 100644
--- a/fairlib/src/evaluators/utils.py
+++ b/fairlib/src/evaluators/utils.py
@@ -45,16 +45,16 @@ def save_checkpoint(
     _state = {
         'epoch': epoch,
         'epochs_since_improvement': epochs_since_improvement,
-        # 'model': model.state_dict(),
         'loss': loss,
-        # 'dev_predictions': dev_predictions,
-        # 'test_predictions': test_predictions,
         "valid_confusion_matrices" : valid_confusion_matrices,
         "test_confusion_matrices" : test_confusion_matrices,
         'dev_evaluations': dev_evaluations,
         'test_evaluations': test_evaluations
     }
 
+    if model.args.save_models:
+        _state["model"] = model.state_dict()
+
     if dev_predictions is not None:
         _state["dev_predictions"] = dev_predictions
     if test_predictions is not None:
diff --git a/fairlib/src/networks/ARL/ARL.py b/fairlib/src/networks/ARL/ARL.py
new file mode 100644
index 0000000..a8e40fc
--- /dev/null
+++ b/fairlib/src/networks/ARL/ARL.py
@@ -0,0 +1,209 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import logging
+from ..adv.utils import BaseDiscriminator
+from ..augmentation_layer import Augmentation_layer
+
+
+
+class ARL(BaseDiscriminator):
+    def __init__(self, args):
+        """The class for ARL training
+
+        This class is similar to the standard adversary for adversarial training, so it is
+        implemented as a child class of the BaseDiscriminator, the same as the other
+        sub-discriminators. The main differences include:
+        1. the forward function of ARL returns the predicted weights of each instance.
+        2. the ARL is trained to predict weights such that the weighted loss of the main model is maximized.
+
+        Args:
+            args (namespace): arguments for training
+
+        Raises:
+            NotImplementedError: if a custom adv_gated_mapping is provided
+        """
+        super(ARL, self).__init__()
+        self.args = args
+        assert args.adv_n_hidden >= 0, "n_hidden must be nonnegative"
+
+        assert self.args.adv_level in ["input", "last_hidden", "output"]
+        if self.args.adv_level == "input":
+            self.input_dim = self.args.emb_size
+        elif self.args.adv_level == "last_hidden":
+            if self.args.n_hidden == 0:
+                self.input_dim = self.args.emb_size
+            else:
+                self.input_dim = self.args.hidden_size
+        elif self.args.adv_level == "output":
+            self.input_dim = self.args.num_classes
+        else:
+            pass
+
+        # Add the onehot target to the input of adv
+        if self.args.adv_gated and self.args.adv_gated_type == "Inputs":
+            self.input_dim = self.input_dim + self.args.num_classes
+            # One-hot mapping for the class
+            self.mapping = torch.eye(self.args.num_classes, requires_grad=False)
+            self.mapping = self.mapping.to(self.args.device)
+
+        # The output dim of the layer is 1, corresponding to the weight for each instance.
+        if args.adv_n_hidden == 0:
+            self.output_layer = nn.Linear(self.input_dim, 1)
+        else:
+            self.output_layer = nn.Linear(args.adv_hidden_size, 1)
+
+        # Init batch norm, dropout, and activation function
+        self.init_hyperparameters()
+        # Hidden layers
+        self.hidden_layers = self.init_hidden_layers()
+
+        # Augmentation layers
+        if self.args.adv_gated and self.args.adv_gated_type == "Augmentation":
+            if self.args.adv_n_hidden == 0:
+                logging.info("Gated component requires at least one hidden layer in the model")
+                pass
+            else:
+                # Init the mapping for the augmentation layer
+                if self.args.adv_gated_mapping is None:
+                    # For each class init a discriminator component
+                    self.mapping = torch.eye(self.args.num_classes, requires_grad=False)
+                else:
+                    # self.mapping = torch.from_numpy(mapping, requires_grad=False)
+                    raise NotImplementedError
+
+                # Init the augmentation layer
+                self.augmentation_components = Augmentation_layer(
+                    mapping=self.mapping,
+                    num_component=self.args.num_classes,
+                    device=self.args.device,
+                    sample_component=self.hidden_layers
+                )
+
+        self.init_for_training()
+        self.sigmoid = nn.Sigmoid()
+
+        if self.args.regression:
+            self.criterion = torch.nn.MSELoss(reduction = "none")
+        else:
+            self.criterion = torch.nn.CrossEntropyLoss(reduction = "none")
+
+    def forward(self, input_data, group_label = None):
+        if (self.args.adv_gated):
+            assert group_label is not None, "Group labels are needed for augmentation"
+
+            if (self.args.adv_gated_type == "Inputs"):
+                # Get one hot representations of y
+                onehot_y = self.mapping[group_label.long()]
+                # Concat
+                input_data = torch.cat([input_data, onehot_y], dim=1)
+
+        # Main model
+        main_output = input_data
+        for layer in self.hidden_layers:
+            main_output = layer(main_output)
+
+        # Augmentation
+        if (self.args.adv_gated) and (
+            self.args.adv_n_hidden > 0) and (
+            self.args.adv_gated_type == "Augmentation"):
+
+            specific_output = self.augmentation_components(input_data, group_label)
+
+            main_output = main_output + specific_output
+
+        main_output = self.output_layer(main_output)
+        # Normalize outputs
+        main_output = self.sigmoid(main_output)
+        main_output = main_output / torch.mean(main_output)
+        main_output = main_output + torch.ones_like(main_output)
+
+        return main_output
+
+    # Train the discriminator 1 batch
+    def step(self, hs, preds, tags, args, train_step=True):
+        """Train the adversary for one batch, or compute instance weights in eval mode.
+
+        Args:
+            hs (torch.Tensor): hidden representations from the main task model
+            preds (torch.Tensor): predictions of the main task model
+            tags (torch.Tensor): target labels
+            args (namespace): arguments for training
+            train_step (bool, optional): whether to update the adversary. Defaults to True.
+
+        Returns:
+            torch.Tensor: predicted instance weights
+        """
+
+        if train_step:
+            self.train()
+        else:
+            self.eval()
+
+        if args.adv_gated:
+            adv_predictions = self(hs, tags.long())
+        else:
+            adv_predictions = self(hs)
+
+        if train_step:
+            adv_optimizer = self.optimizer
+            adv_optimizer.zero_grad()
+
+            weighted_loss = self.get_adversary_loss(preds, tags, adv_predictions)
+            # print(weighted_loss)
+            weighted_loss.backward(retain_graph=True)
+            # print(self.output_layer.weight.grad)
+            adv_optimizer.step()
+            adv_optimizer.zero_grad()
+
+        return adv_predictions
+
+    def get_learner_loss(self, preds, tags, adversarial_weights):
+        """
+        Compute the loss for the main task model.
+        """
+        loss = self.criterion(preds, tags)
+        weighted_loss = loss * adversarial_weights.squeeze()
+        weighted_loss = torch.mean(weighted_loss)
+        return weighted_loss
+
+    def get_adversary_loss(self, preds, tags, adversarial_weights):
+        """
+        Compute the loss for the adversary.
+        """
+        return -1 * self.get_learner_loss(preds, tags, adversarial_weights)
+
+    def get_arl_loss(self, model, batch, predictions, args):
+        """Get ARL loss
+
+        Args:
+            model (torch.nn.Module): the main task model
+            batch (tuple): batch data, including inputs, target labels, protected labels, etc.
+            predictions (torch.Tensor): main model predictions on the batch
+            args (namespace): arguments for training
+
+        Returns:
+            float: training loss
+        """
+
+        text = batch[0]
+        tags = batch[1].long()
+
+        text = text.to(args.device)
+        tags = tags.to(args.device)
+
+        hs = model.hidden(text).detach()
+        preds = model(text).detach()
+
+        for _ in range(args.ARL_n):
+            # Train the adversary one step
+            _ = self.step(hs, preds, tags, args, train_step=True)
+        # Get the adversarial weights
+        adversarial_weights = self.step(hs, preds, tags, args, train_step=False)
+        # print(adversarial_weights)
+        # Weighted loss
+        arl_weighted_loss = self.get_learner_loss(predictions, tags, adversarial_weights)
+        # Unweighted loss, i.e., the standard loss for the vanilla model
+        # vanilla_unweighted_loss = self.get_learner_loss(predictions, tags, torch.ones_like(adversarial_weights))
+        # By default, replace the vanilla loss with ARL
+        # arl_loss = arl_weighted_loss - vanilla_unweighted_loss
+
+        return arl_weighted_loss
\ No newline at end of file
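
Note: ARL.forward converts the adversary's scores into per-instance loss weights: a sigmoid squashes the scores, the result is normalized by its batch mean, and 1 is added, so the weights average 2 and no instance is ever weighted to zero. A standalone sketch of just that normalization, with made-up logits:

```python
import torch

raw_scores = torch.tensor([[-2.0], [0.0], [1.5], [3.0]])  # adversary outputs

w = torch.sigmoid(raw_scores)       # squash into (0, 1)
w = w / torch.mean(w)               # rescale so the batch mean is 1
w = w + torch.ones_like(w)          # shift so every instance keeps weight > 1

print(w.squeeze())                  # higher-scored (harder) examples weigh more
assert torch.isclose(w.mean(), torch.tensor(2.0))
```
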
diff --git a/fairlib/src/networks/ARL/__init__.py b/fairlib/src/networks/ARL/__init__.py
new file mode 100644
index 0000000..ac9ef02
--- /dev/null
+++ b/fairlib/src/networks/ARL/__init__.py
@@ -0,0 +1 @@
+from .ARL import ARL
\ No newline at end of file
diff --git a/fairlib/src/networks/DyBT/fairbatch_sampler.py b/fairlib/src/networks/DyBT/fairbatch_sampler.py
index 0bbd13c..8ea65e5 100644
--- a/fairlib/src/networks/DyBT/fairbatch_sampler.py
+++ b/fairlib/src/networks/DyBT/fairbatch_sampler.py
@@ -180,7 +180,7 @@ def epoch_loss(self):
 
     def adjust_lambda(self):
         """Adjusts the lambda values for FairBatch algorithm.
-        The detailed algorithms are decribed in the paper.
+        The detailed algorithms are described in the paper.
""" epoch_loss = self.epoch_loss() diff --git a/fairlib/src/networks/adv/discriminator.py b/fairlib/src/networks/adv/discriminator.py index d536de5..f31dc63 100644 --- a/fairlib/src/networks/adv/discriminator.py +++ b/fairlib/src/networks/adv/discriminator.py @@ -43,8 +43,16 @@ def adv_train_batch(model, discriminators, batch, args): tags = tags.to(args.device) p_tags = p_tags.to(args.device) adv_instance_weights = adv_instance_weights.to(args.device) + + if args.encoder_architecture == "BERT": + # Modify the inputs for BERT models + attention_mask = torch.stack(batch["attention_mask"]).float().squeeze().T + if args.adv_decoupling: + attention_mask = attention_mask[decoupling_masks] + attention_mask = attention_mask.to(args.device) + text = (text, attention_mask) - # hidden representations from the model + # hidden representations from the model if args.gated: hs = model.hidden(text, p_tags).detach() else: diff --git a/fairlib/src/networks/adv/utils.py b/fairlib/src/networks/adv/utils.py index 056f19d..5895e65 100644 --- a/fairlib/src/networks/adv/utils.py +++ b/fairlib/src/networks/adv/utils.py @@ -126,6 +126,13 @@ def __init__(self, args): else: pass + # Add the onehot target to the input of adv + if self.args.adv_gated and self.args.adv_gated_type == "Inputs": + self.input_dim = self.input_dim + self.args.num_classes + # One-hot mapping for the class + self.mapping = torch.eye(self.args.num_classes, requires_grad=False) + self.mapping = self.mapping.to(self.args.device) + if args.adv_n_hidden == 0: self.output_layer = nn.Linear(self.input_dim, args.num_groups) @@ -138,7 +145,7 @@ def __init__(self, args): self.hidden_layers = self.init_hidden_layers() # Augmentation layers - if self.args.adv_gated: + if self.args.adv_gated and self.args.adv_gated_type == "Augmentation": if self.args.adv_n_hidden == 0: logging.info("Gated component requires at least one hidden layers in the model") pass @@ -162,7 +169,14 @@ def __init__(self, args): self.init_for_training() def forward(self, input_data, group_label = None): - # input_data = self.grad_rev(input_data) + if (self.args.adv_gated): + assert group_label is not None, "Group labels are needed for augmentation" + + if (self.args.adv_gated_type == "Inputs"): + # Get one hot representations of y + onehot_y = self.mapping[group_label.long()] + # Concat + input_data = torch.cat([input_data, onehot_y], dim=1) # Main model main_output = input_data @@ -170,8 +184,9 @@ def forward(self, input_data, group_label = None): main_output = layer(main_output) # Augmentation - if self.args.adv_gated and self.args.adv_n_hidden > 0: - assert group_label is not None, "Group labels are needed for augmentaiton" + if (self.args.adv_gated) and ( + self.args.adv_n_hidden > 0) and ( + self.args.adv_gated_type == "Augmentation"): specific_output = self.augmentation_components(input_data, group_label) @@ -181,16 +196,24 @@ def forward(self, input_data, group_label = None): return output def hidden(self, input_data, group_label = None): - # input_data = self.grad_rev(input_data) - + if (self.args.adv_gated): + assert group_label is not None, "Group labels are needed for augmentation" + + if (self.args.adv_gated_type == "Inputs"): + # Get one hot representations of y + onehot_y = self.mapping[group_label.long()] + # Concat + input_data = torch.cat([input_data, onehot_y], dim=1) + # Main model main_output = input_data for layer in self.hidden_layers: main_output = layer(main_output) # Augmentation - if self.args.adv_gated and self.args.adv_n_hidden > 0: - assert group_label is 
+        if (self.args.adv_gated) and (
+            self.args.adv_n_hidden > 0) and (
+            self.args.adv_gated_type == "Augmentation"):
 
             specific_output = self.augmentation_components(input_data, group_label)
diff --git a/fairlib/src/networks/classifier.py b/fairlib/src/networks/classifier.py
index 1262bb3..97ae98c 100644
--- a/fairlib/src/networks/classifier.py
+++ b/fairlib/src/networks/classifier.py
@@ -25,7 +25,6 @@ def __init__(self, args):
 
         # Init batch norm, dropout, and activation function
         self.init_hyperparameters()
-        self.cls_parameter = self.get_cls_parameter()
 
         # Init hidden layers
         self.hidden_layers = self.init_hidden_layers()
@@ -51,6 +50,8 @@ def __init__(self, args):
                 sample_component=self.hidden_layers
             )
 
+        self.cls_parameter = self.get_cls_parameter()
+
         self.init_for_training()
 
     def forward(self, input_data, group_label = None):
@@ -141,7 +142,7 @@ def get_cls_parameter(self):
 
 class BERTClassifier(BaseModel):
     model_name = 'bert-base-cased'
-    n_freezed_layers = 10
+    n_freezed_layers = 12
 
     def __init__(self, args):
         super(BERTClassifier, self).__init__()
@@ -171,12 +172,14 @@ def __init__(self, args):
         self.init_for_training()
 
     def forward(self, input_data, group_label = None):
-        bert_output = self.bert(input_data)[1]
+        input_ids, input_masks = input_data
+        bert_output = self.bert(input_ids, attention_mask=input_masks)[1]
 
         return self.classifier(bert_output, group_label)
 
     def hidden(self, input_data, group_label = None):
-        bert_output = self.bert(input_data)[1]
+        input_ids, input_masks = input_data
+        bert_output = self.bert(input_ids, attention_mask=input_masks)[1]
 
         return self.classifier.hidden(bert_output, group_label)
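
Note: BERTClassifier now unpacks a (input_ids, attention_mask) tuple rather than a bare tensor, matching the dataloader changes that carry attention_mask through addition_values; BertModel applies the padding mask via its attention_mask argument. A minimal sketch of the convention against a bare BertModel (the fairlib classifier head is omitted; the sentences are made up):

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
bert = BertModel.from_pretrained("bert-base-cased")

enc = tokenizer(["she is a surgeon", "he is a nurse"],
                padding=True, truncation=True, return_tensors="pt")

input_data = (enc["input_ids"], enc["attention_mask"])   # tuple, not a bare tensor
input_ids, input_masks = input_data
pooled = bert(input_ids, attention_mask=input_masks)[1]  # pooler output, as in forward()
print(pooled.shape)                                      # (2, 768)
```
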
diff --git a/fairlib/src/networks/utils.py b/fairlib/src/networks/utils.py
index 2b70876..d2da70b 100644
--- a/fairlib/src/networks/utils.py
+++ b/fairlib/src/networks/utils.py
@@ -4,7 +4,7 @@ from torch.optim import Adam
 import time
 from pathlib import Path
-from ..evaluators import print_network, present_evaluation_scores
+from ..evaluators import print_network, present_evaluation_scores, validation_is_best
 import pandas as pd
 
 # train the main model with adv loss
@@ -25,6 +25,16 @@ def train_epoch(model, iterator, args, epoch):
         tags = batch[1].long().squeeze()
         p_tags = batch[2].float().squeeze()
 
+        text = text.to(args.device)
+        tags = tags.to(args.device)
+        p_tags = p_tags.to(args.device)
+
+        if args.encoder_architecture == "BERT":
+            # Modify the inputs for BERT models
+            mask = torch.stack(batch["attention_mask"]).float().squeeze().T
+            mask = mask.to(args.device)
+            text = (text, mask)
+
         if args.BT is not None and args.BT == "Reweighting":
             instance_weights = batch[3].float()
             instance_weights = instance_weights.to(args.device)
@@ -33,10 +43,6 @@ def train_epoch(model, iterator, args, epoch):
             regression_tags = batch[5].float().squeeze()
             regression_tags = regression_tags.to(args.device)
 
-        text = text.to(args.device)
-        tags = tags.to(args.device)
-        p_tags = p_tags.to(args.device)
-
         data_t += (time.time() - data_t0)
         t0 = time.time()
@@ -57,10 +63,22 @@ def train_epoch(model, iterator, args, epoch):
         else:
             loss = criterion(predictions, tags if not args.regression else regression_tags)
 
+        if args.ARL:
+            # loss = loss + args.ARL_loss.get_arl_loss(model, batch, predictions, args)
+            loss = args.ARL_loss.get_arl_loss(model, batch, predictions, args)
+
         if args.adv_debiasing:
             # Update discriminator if needed
             if args.adv_update_frequency == "Batch":
-                args.discriminator.train_self_batch(model, batch)
+                # Update the class-specific discriminator
+                if args.adv_gated and (args.adv_gated_type == "Separate"):
+                    for tmp_y in range(args.num_classes):
+                        tmp_y_mask = list(torch.where(tags == tmp_y)[0].cpu().numpy())
+                        if len(tmp_y_mask) > 0:
+                            _batch = [i[tmp_y_mask] for i in batch]
+                            args.discriminator[tmp_y].train_self_batch(model, _batch)
+                else:
+                    args.discriminator.train_self_batch(model, batch)
 
             # get hidden representations
             if args.gated:
@@ -68,10 +86,20 @@ def train_epoch(model, iterator, args, epoch):
             else:
                 hs = model.hidden(text)
 
-            adv_losses = args.discriminator.adv_loss(hs, tags, p_tags)
+            # Get adv losses
+            if args.adv_gated and (args.adv_gated_type == "Separate"):
+                for tmp_y in range(args.num_classes):
+                    tmp_y_mask = list(torch.where(tags == tmp_y)[0].cpu().numpy())
+                    if len(tmp_y_mask) > 0:
+                        tmp_y_adv_losses = args.discriminator[tmp_y].adv_loss(hs[tmp_y_mask], tags[tmp_y_mask], p_tags[tmp_y_mask])
+
+                        for tmp_y_adv_loss in tmp_y_adv_losses:
+                            loss = loss - (tmp_y_adv_loss / (args.adv_num_subDiscriminator * args.num_classes))
+            else:
+                adv_losses = args.discriminator.adv_loss(hs, tags, p_tags)
 
-            for adv_loss in adv_losses:
-                loss = loss - (adv_loss / args.adv_num_subDiscriminator)
+                for adv_loss in adv_losses:
+                    loss = loss - (adv_loss / args.adv_num_subDiscriminator)
 
         if args.FCL:
             # get hidden representations
@@ -89,7 +117,7 @@ def train_epoch(model, iterator, args, epoch):
                 predictions, tags, p_tags,
                 regression_tags = None if not args.regression else regression_tags,
             )
-        
+
         optimizer.zero_grad()
         loss.backward()
         # Zero gradients of the cls head
@@ -121,11 +149,17 @@ def train_epoch(model, iterator, args, epoch):
                 iterator = args.opt.dev_generator,
                 args = args)
 
+            is_best = validation_is_best(
+                valid_preds, valid_labels, valid_private_labels,
+                model, epoch_valid_loss, selection_criterion = "DTO",
+                performance_metric = "accuracy", fairness_metric="TPR_GAP"
+            )
+
             present_evaluation_scores(
                 valid_preds, valid_labels, valid_private_labels,
                 test_preds, test_labels, test_private_labels,
-                epoch=epoch+(it / len(iterator)), epochs_since_improvement=None, model=model, epoch_valid_loss=None,
-                is_best=False,
+                epoch=epoch+(it / len(iterator)), epochs_since_improvement=None, model=model,
+                epoch_valid_loss=None, is_best=is_best
             )
 
             model.train()
@@ -158,6 +192,12 @@ def eval_epoch(model, iterator, args):
         tags = tags.to(device).long()
         p_tags = p_tags.to(device).float()
 
+        if args.encoder_architecture == "BERT":
+            # Modify the inputs for BERT models
+            mask = torch.stack(batch["attention_mask"]).float().squeeze().T
+            mask = mask.to(args.device)
+            text = (text, mask)
+
         if args.BT is not None and args.BT == "Reweighting":
             instance_weights = batch[3].float()
             instance_weights = instance_weights.to(device)
@@ -220,7 +260,9 @@ def init_for_training(self):
             self.criterion = torch.nn.MSELoss(reduction = reduction)
         else:
             self.criterion = torch.nn.CrossEntropyLoss(reduction = reduction)
-        
+
+        self.best_valid_loss = 1e+5
+
         print_network(self, verbose=True)
 
     def init_hyperparameters(self):
@@ -273,7 +315,7 @@ def train_self(self, **opt_pairs):
             logging.info("Reinitialized DyBT sampler for dataloader")
 
         epochs_since_improvement = 0
-        best_valid_loss = 1e+5
+        # best_valid_loss = 1e+5
 
         for epoch in range(self.args.opt.epochs):
@@ -300,8 +342,13 @@ def train_self(self, **opt_pairs):
                 self.args.discriminator.train_self(self)
 
             # Check if there was an improvement
-            is_best = epoch_valid_loss < best_valid_loss
-            best_valid_loss = min(epoch_valid_loss, best_valid_loss)
+            # is_best = epoch_valid_loss < best_valid_loss
+            # best_valid_loss = min(epoch_valid_loss, best_valid_loss)
+            is_best = validation_is_best(
+                valid_preds, valid_labels, valid_private_labels,
+                self, epoch_valid_loss, selection_criterion = "DTO",
+                performance_metric = "accuracy", fairness_metric="TPR_GAP"
+            )
 
             if not is_best:
                 epochs_since_improvement += 1
diff --git a/requirements.txt b/requirements.txt
index 1488f25..de0f2e6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 tqdm==4.62.3
-numpy==1.22
+numpy==1.21
 docopt==0.6.2
 pandas==1.3.4
 scikit-learn==1.0
@@ -10,3 +10,4 @@ matplotlib==3.5.0
 pickle5==0.0.12
 transformers==4.11.3
 sacremoses==0.0.53
+sentencepiece
diff --git a/setup.py b/setup.py
index f0bb0df..77ab4a9 100644
--- a/setup.py
+++ b/setup.py
@@ -28,6 +28,7 @@
     "pickle5",
    "transformers",
    "sacremoses",
+    "sentencepiece"
 ]
 
 
@@ -36,7 +37,7 @@
 if __name__ == '__main__':
     setup(
         name='fairlib',
-        version="0.0.9",
+        version="0.1.0",
         author="Xudong Han",
         author_email="xudongh1@student.unimelb.edu.au",
         description='Unified framework for assessing and improving fairness.',
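
Note: for reference, a hypothetical way to exercise the new options from Python, via the get_dummy_state(*cmdargs, yaml_file=None, **opt_pairs) hook shown in base_options.py above; the import path and the keyword-override semantics of opt_pairs are assumed here, not verified:

```python
from fairlib.src.base_options import BaseOptions  # assumed import path

options = BaseOptions()
state = options.get_dummy_state(
    dataset="Moji",    # default dataset from base_options.py
    ARL=True,          # enable adversarial reweighted learning (new flag)
    ARL_n=2,           # adversary updates per main-model update (new flag)
    save_models=True,  # include model weights in saved checkpoints (new flag)
)
```
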