# Debiasing a Language Model

In [None]:
# Computing libraries
import numpy as np
import torch
import torch.nn as nn

# huggging face
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
from datasets import load_dataset, Dataset, DatasetDict
import evaluate

# Custom imports
from FairLangProc.datasets import BiasDataLoader
from FairLangProc.metrics import WEAT

from FairLangProc.algorithms.preprocessors import CDA, BLINDModelForClassification, SentDebiasForSequenceClassification
from FairLangProc.algorithms.inprocessors import EARModel, DebiasAdapter, selective_unfreezing 
from FairLangProc.algorithms.intraprocessors import add_EAT_hook, DiffPrunningBERT

## Configuration

In [None]:
# Configuration
MODELS = [
    'bert-base-uncased',
    'deepseek-ai/deepseek-llm-7b-base',
    'huggyllama/llama-7b'
]
TASKS = [
    "cola",
    "sst2",
    "mrpc",
    "stsb",
    "qqp",
    "mnli",
    "qnli",
    "rte",
    "wnli"
]
TASK_LABELS = {
    "cola": 2,
    "sst2": 2,
    "mrpc": 2,
    "qqp": 2,
    "stsb": 1,
    "mnli": 3,
    "qnli": 2,
    "rte": 2,
    "wnli": 2,
}
DEBIAS_METHODS = [
    "none",
    "cda",
    "blind",
    "embedding",
    "ear",
    "adele",
    "selective",
    "eat",
    "diff"
]


MODEL_NAME = MODELS[0]
TASK = "cola"
BATCH_SIZE = 16
# DEVICE = torch.device("cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEBIAS = "none"

In [None]:
counterfactual_pairs = [
    ("gods", "goddesses"), ("manager", "manageress"), ("barons", "baronesses"),
    ("nephew", "niece"), ("prince", "princess"), ("boars", "sows"),
    ("baron", "baroness"), ("stepfathers", "stepmothers"), ("wizard", "witch"),
    ("father", "mother"), ("stepsons", "stepdaughters"), ("sons-in-law", "daughters-in-law"),
    ("dukes", "duchesses"), ("boyfriend", "girlfriend"), ("fiances", "fiancees"),
    ("dad", "mom"), ("shepherd", "shepherdess"), ("uncles", "aunts"),
    ("beau", "belle"), ("males", "females"), ("hunter", "huntress"),
    ("beaus", "belles"), ("grandfathers", "grandmothers"), ("lads", "lasses"),
    ("daddies", "mummies"), ("step-son", "step-daughter"), ("masters", "mistresses"),
    ("policeman", "policewoman"), ("nephews", "nieces"), ("brother", "sister"),
    ("grandfather", "grandmother"), ("priest", "priestess"), ("hosts", "hostesses"),
    ("landlord", "landlady"), ("husband", "wife"), ("poet", "poetess"),
    ("landlords", "landladies"), ("fathers", "mothers"), ("masseur", "masseuse"),
    ("monks", "nuns"), ("usher", "usherette"), ("hero", "heroine"),
    ("stepson", "stepdaughter"), ("postman", "postwoman"), ("god", "goddess"),
    ("milkmen", "milkmaids"), ("stags", "hinds"), ("grandpa", "grandma"),
    ("chairmen", "chairwomen"), ("husbands", "wives"), ("grandpas", "grandmas"),
    ("stewards", "stewardesses"), ("murderer", "murderess"), ("manservant", "maidservant"),
    ("men", "women"), ("host", "hostess"), ("heirs", "heiresses"),
    ("masseurs", "masseuses"), ("boy", "girl"), ("male", "female"),
    ("son-in-law", "daughter-in-law"), ("waiter", "waitress"), ("tutors", "governesses"),
    ("priests", "priestesses"), ("bachelor", "spinster"), ("millionaire", "millionairess"),
    ("steward", "stewardess"), ("businessmen", "businesswomen"), ("congressman", "congresswoman"),
    ("emperor", "empress"), ("duke", "duchess"), ("sire", "dam"),
    ("son", "daughter"), ("sirs", "madams"), ("widower", "widow"),
    ("kings", "queens"), ("papas", "mamas"), ("grandsons", "granddaughters"),
    ("proprietor", "proprietress"), ("monk", "nun"), ("headmasters", "headmistresses"),
    ("grooms", "brides"), ("heir", "heiress"), ("boys", "girls"),
    ("gentleman", "lady"), ("uncle", "aunt"), ("he", "she"),
    ("king", "queen"), ("princes", "princesses"), ("policemen", "policewomen"),
    ("governor", "matron"), ("fiance", "fiancee"), ("step-father", "step-mother"),
    ("waiters", "waitresses"), ("mr", "mrs"), ("stepfather", "stepmother"),
    ("daddy", "mummy"), ("lords", "ladies"), ("widowers", "widows"),
    ("emperors", "empresses"), ("father-in-law", "mother-in-law"), ("abbot", "abbess"),
    ("sir", "madam"), ("actor", "actress"), ("mr.", "mrs."),
    ("wizards", "witches"), ("actors", "actresses"), ("chairman", "chairwoman"),
    ("sorcerer", "sorceress"), ("postmaster", "postmistress"), ("brothers", "sisters"),
    ("lad", "lass"), ("headmaster", "headmistress"), ("papa", "mama"),
    ("milkman", "milkmaid"), ("heroes", "heroines"), ("man", "woman"),
    ("grandson", "granddaughter"), ("groom", "bride"), ("sons", "daughters"),
    ("congressmen", "congresswomen"), ("businessman", "businesswoman"), ("boyfriends", "girlfriends"),
    ("dads", "moms"),
]

## Load model and debias method

In [None]:
num_labels = TASK_LABELS[TASK]
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
original_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=num_labels)
hidden_dim = original_model.config.hidden_size

if not hasattr(original_model, 'classifier'):
    original_model.classifier = original_model.score

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at deepseek-ai/deepseek-llm-7b-base and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
if DEBIAS in ("none", "cda", "eat"):
    model = original_model

In [None]:
if DEBIAS == "embedding":

    class SentDebiasBert(SentDebiasForSequenceClassification):        
        def _get_embedding(self, input_ids, attention_mask = None, token_type_ids = None):
            return self.model.bert(
                input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids
                ).last_hidden_state[:,0,:]

    class SentDebiasAverageAutoreg(SentDebiasForSequenceClassification):
        def _get_embedding(self, input_ids, attention_mask = None, token_type_ids = None):
            return self.model.model(
                input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids
                ).last_hidden_state.mean(dim = 1)
        
    if model == 'bert-base-uncased':
        SentDebias = SentDebiasBert(
            model = model,
            config = None,
            tokenizer = tokenizer,
            word_pairs = counterfactual_pairs,
            n_components = 1
        )
    else:
        SentDebias = SentDebiasAverageAutoreg(
            model = model,
            config = None,
            tokenizer = tokenizer,
            word_pairs = counterfactual_pairs,
            n_components = 1
        )

In [None]:
if DEBIAS == "blind":
    
    class BLINDBERT(BLINDModelForClassification):
        def _get_embedding(self, input_ids = None, attention_mask = None, token_type_ids = None):
            return self.model.bert(
                input_ids = input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids
                ).last_hidden_state[:,0,:]
        
    class BLINDAverageAutoreg(BLINDModelForClassification):
        def _get_embedding(self, input_ids = None, attention_mask = None, token_type_ids = None):
            return self.model.model(
                input_ids = input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids
                ).last_hidden_state.mean(dim = 1)
        
    if model == 'bert-base-uncased':
        model = BLINDBERT(
            model = model,
            config = None,
            base_loss = nn.CrossEntropyLoss(),
            alpha=1.0,
            gamma=2.0,
            temperature=1.0,
            size_average=True,
            hidden_dim = hidden_dim
        )
    else:
        model = BLINDAverageAutoreg(
            model = model,
            config = None,
            base_loss = nn.CrossEntropyLoss(),
            alpha=1.0,
            gamma=2.0,
            temperature=1.0,
            size_average=True,
            hidden_dim = hidden_dim
        )

In [None]:
if DEBIAS == "adele":
    model = DebiasAdapter(
        model = original_model,
        config = 'lora'
    )

In [None]:
if DEBIAS == "ear":
    model = EARModel(
        model = original_model,
        ear_reg_strength = 0.01
    )

In [None]:
if DEBIAS == "selective":
    model = original_model
    selective_unfreezing(model, ["attention.self", "attention.output"])

In [None]:
if DEBIAS == "diff":
    tokens_male = [words[0] for words in counterfactual_pairs]
    tokens_female = [words[1] for words in counterfactual_pairs]

    inputs_male = tokenizer(tokens_male, padding = True, return_tensors = "pt")
    inputs_female = tokenizer(tokens_female, padding = True, return_tensors = "pt")

    ModularDebiasingBERT = DiffPrunningBERT(
        base_model = original_model,
        input_ids_A = inputs_male,
        input_ids_B = inputs_female
    )

## Auxiliary functions

In [10]:
def preprocess_function(examples, tokenizer, task):
    if task in ["sst2", "cola"]:
        return tokenizer(examples["sentence"], truncation=True, padding="max_length", max_length=128)
    elif task in ["mnli", "qnli", "rte", "wnli"]:
        return tokenizer(examples["premise"], examples["hypothesis"], truncation=True, padding="max_length", max_length=128)
    elif task in ["mrpc", "qqp"]:
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding="max_length", max_length=128)
    elif task == "stsb":
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding="max_length", max_length=128)
    
def get_metric(task_name):
    if task_name == "stsb":
        return evaluate.load("glue", task_name), lambda x: x
    return evaluate.load("glue", task_name), lambda logits: np.argmax(logits, axis=1)

def compute_metrics_fn(eval_pred, task_name):
    logits, labels = eval_pred
    metric, pred_fn = get_metric(task_name)
    if task_name == "stsb":
        predictions = np.squeeze(logits)
    else:
        predictions = pred_fn(logits)
    return metric.compute(predictions=predictions, references=labels)

## Load dataset

In [None]:
# Dataset
dataset = load_dataset("glue", TASK)

if DEBIAS == "cda":
    train_dataset = Dataset.from_dict(
        CDA(dataset['train'][:], pairs = dict(counterfactual_pairs))
        )
    dataset = DatasetDict({
        "train": train_dataset,
        'validation': dataset["validation"],
        "test": dataset["test"]
        
    })
    
tokenized_datasets = dataset.map(lambda x: preprocess_function(x, tokenizer, TASK), batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


## Fine-tuning

In [None]:
training_args = TrainingArguments(
    output_dir=f"./output/{TASK}-{DEBIAS}-{MODEL_NAME.replace('/', '-')}",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy" if TASK != "stsb" else "pearson",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=lambda p: compute_metrics_fn(p, TASK),
)

if DEBIAS == 'eat':
    model = add_EAT_hook(
        model = model,
        beta = 0.7
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=lambda p: compute_metrics_fn(p, TASK),
    )

trainer.train()
eval_results = trainer.evaluate()
print("Results in ", TASK, ":", eval_results)

Linear(in_features=4096, out_features=2, bias=False)

## WEAT

In [33]:
words = load_dataset("fairnlp/weat", split="train")

README.md:   0%|          | 0.00/2.09k [00:00<?, ?B/s]

ValueError: Config name is missing.
Please pick one among the available configs: ['words', 'associations', 'associations_wefat']
Example of usage:
	`load_dataset('fairnlp/weat', 'words')`

In [None]:
class BertWEAT(WEAT):
    def _get_embedding(self, outputs):
        return outputs.last_hidden_state[:, 0, :]

class AverageAutoregWEAT(WEAT):
    def _get_embedding(self, outputs):
        return outputs.last_hidden_state.mean(dim = 1)
    
if MODEL_NAME == 'bert-base-uncased':
    weat = BertWEAT(model = model, tokenizer = tokenizer)
else:
    weat = AverageAutoregWEAT(model = model, tokenizer = tokenizer)

math = ['math', 'algebra', 'geometry', 'calculus', 'equations', 'computation', 'numbers', 'addition']
arts = ['poetry', 'art', 'dance', 'literature', 'novel', 'symphony', 'drama', 'sculpture']
male = ['male', 'man', 'boy', 'brother', 'he', 'him', 'his', 'son']
female = ['female', 'woman', 'girl', 'sister', 'she', 'her', 'hers', 'daughter']

results = weat.run_test(
    W1_words = math, W2_words = arts,
    A1_words = male, A2_words = female,
    pval = False
    )