# Debiasing a Language Model

In [2]:
import json

# Computing libraries
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW

# Check if CUDA is available and set device
device = torch.device("cuda")

# huggging face
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    EarlyStoppingCallback
)
from datasets import load_dataset, Dataset, DatasetDict
import evaluate

# Custom imports
import os
import sys
ruta_raiz = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) \
    if "__file__" in globals() else os.path.abspath("..")
sys.path.insert(0, ruta_raiz)

from FairLangProc.datasets import BiasDataLoader
from FairLangProc.metrics import WEAT

from FairLangProc.algorithms.preprocessors import CDA, BLINDModelForClassification, SentDebiasForSequenceClassification
from FairLangProc.algorithms.inprocessors import EARModel, DebiasAdapter, selective_unfreezing 
from FairLangProc.algorithms.intraprocessors import add_EAT_hook, DiffPrunningBERT, DiffPrunedDebiasing

## Configuration

In [3]:
# Configuration
MODELS = [
    'bert-base-uncased',
    'deepseek-ai/deepseek-llm-7b-base',
    'huggyllama/llama-7b'
]
TASKS = [
    "cola",
    "sst2",
    "mrpc",
    "stsb",
    "qqp",
    "mnli",
    "qnli",
    "rte",
    "wnli"
]
TASK_LABELS = {
    "cola": 2,
    "sst2": 2,
    "mrpc": 2,
    "qqp": 2,
    "stsb": 1,
    "mnli": 3,
    "qnli": 2,
    "rte": 2,
    "wnli": 2,
}
DEBIAS_METHODS = [
    "none",
    "cda",
    "blind",
    "embedding",
    "ear",
    "adele",
    "selective",
    "eat",
    "diff"
]
TASK_METRICS = {
    "cola": "eval_matthews_correlation",
    "sst2": "eval_accuracy",
    "mrpc": "eval_accuracy",
    "stsb": "eval_pearson",
    "mnli": "eval_accuracy",
    "qnli": "eval_accuracy",
    "rte": "eval_accuracy",
    "wnli": "eval_accuracy",
}
CDA_METHOD = {
    "none": False,
    "cda": True,
    "blind": False,
    "embedding": False,
    "ear": False,
    "adele": True,
    "selective": True,
    "eat": False,
    "diff": False
}

In [5]:
MODEL_NAME = MODELS[0]
TASK = "mnli"
DEBIAS = "none"

In [6]:
METRIC_FOR_BEST = TASK_METRICS.get(TASK, "eval_accuracy")
BATCH_SIZE = 16
WEIGHT_DECAY = 0.1
# DEVICE = torch.device("cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
counterfactual_pairs = [
    ("gods", "goddesses"), ("manager", "manageress"), ("barons", "baronesses"),
    ("nephew", "niece"), ("prince", "princess"), ("boars", "sows"),
    ("baron", "baroness"), ("stepfathers", "stepmothers"), ("wizard", "witch"),
    ("father", "mother"), ("stepsons", "stepdaughters"), ("sons-in-law", "daughters-in-law"),
    ("dukes", "duchesses"), ("boyfriend", "girlfriend"), ("fiances", "fiancees"),
    ("dad", "mom"), ("shepherd", "shepherdess"), ("uncles", "aunts"),
    ("beau", "belle"), ("males", "females"), ("hunter", "huntress"),
    ("beaus", "belles"), ("grandfathers", "grandmothers"), ("lads", "lasses"),
    ("daddies", "mummies"), ("step-son", "step-daughter"), ("masters", "mistresses"),
    ("policeman", "policewoman"), ("nephews", "nieces"), ("brother", "sister"),
    ("grandfather", "grandmother"), ("priest", "priestess"), ("hosts", "hostesses"),
    ("landlord", "landlady"), ("husband", "wife"), ("poet", "poetess"),
    ("landlords", "landladies"), ("fathers", "mothers"), ("masseur", "masseuse"),
    ("monks", "nuns"), ("usher", "usherette"), ("hero", "heroine"),
    ("stepson", "stepdaughter"), ("postman", "postwoman"), ("god", "goddess"),
    ("milkmen", "milkmaids"), ("stags", "hinds"), ("grandpa", "grandma"),
    ("chairmen", "chairwomen"), ("husbands", "wives"), ("grandpas", "grandmas"),
    ("stewards", "stewardesses"), ("murderer", "murderess"), ("manservant", "maidservant"),
    ("men", "women"), ("host", "hostess"), ("heirs", "heiresses"),
    ("masseurs", "masseuses"), ("boy", "girl"), ("male", "female"),
    ("son-in-law", "daughter-in-law"), ("waiter", "waitress"), ("tutors", "governesses"),
    ("priests", "priestesses"), ("bachelor", "spinster"), ("millionaire", "millionairess"),
    ("steward", "stewardess"), ("businessmen", "businesswomen"), ("congressman", "congresswoman"),
    ("emperor", "empress"), ("duke", "duchess"), ("sire", "dam"),
    ("son", "daughter"), ("sirs", "madams"), ("widower", "widow"),
    ("kings", "queens"), ("papas", "mamas"), ("grandsons", "granddaughters"),
    ("proprietor", "proprietress"), ("monk", "nun"), ("headmasters", "headmistresses"),
    ("grooms", "brides"), ("heir", "heiress"), ("boys", "girls"),
    ("gentleman", "lady"), ("uncle", "aunt"), ("he", "she"),
    ("king", "queen"), ("princes", "princesses"), ("policemen", "policewomen"),
    ("governor", "matron"), ("fiance", "fiancee"), ("step-father", "step-mother"),
    ("waiters", "waitresses"), ("mr", "mrs"), ("stepfather", "stepmother"),
    ("daddy", "mummy"), ("lords", "ladies"), ("widowers", "widows"),
    ("emperors", "empresses"), ("father-in-law", "mother-in-law"), ("abbot", "abbess"),
    ("sir", "madam"), ("actor", "actress"), ("mr.", "mrs."),
    ("wizards", "witches"), ("actors", "actresses"), ("chairman", "chairwoman"),
    ("sorcerer", "sorceress"), ("postmaster", "postmistress"), ("brothers", "sisters"),
    ("lad", "lass"), ("headmaster", "headmistress"), ("papa", "mama"),
    ("milkman", "milkmaid"), ("heroes", "heroines"), ("man", "woman"),
    ("grandson", "granddaughter"), ("groom", "bride"), ("sons", "daughters"),
    ("congressmen", "congresswomen"), ("businessman", "businesswoman"), ("boyfriends", "girlfriends"),
    ("dads", "moms"),
]

## Load model and debias method

In [18]:
num_labels = TASK_LABELS[TASK]
if TASK == 'mnli':
    problem_type='regression'
else:
    problem_type='single_label_classification'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
original_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=num_labels, problem_type=problem_type)

hidden_dim = original_model.config.hidden_size
if not hasattr(original_model, 'classifier'):
    original_model.classifier = original_model.score

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
if DEBIAS in ("none", "cda", "eat"):
    model = original_model

In [20]:
if DEBIAS == "embedding":

    class SentDebiasBert(SentDebiasForSequenceClassification):        
        def _get_embedding(self, input_ids, attention_mask = None, token_type_ids = None):
            return self.model.bert(
                input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids
                ).last_hidden_state[:,0,:]

    class SentDebiasAverageAutoreg(SentDebiasForSequenceClassification):
        def _get_embedding(self, input_ids, attention_mask = None, token_type_ids = None):
            return self.model.model(
                input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids
                ).last_hidden_state.mean(dim = 1)
        
    if MODEL_NAME == 'bert-base-uncased':
        model = SentDebiasBert(
            model = original_model,
            config = None,
            tokenizer = tokenizer,
            word_pairs = counterfactual_pairs,
            n_components = 1,
            n_labels = num_labels
        )
    else:
        model = SentDebiasAverageAutoreg(
            model = original_model,
            config = None,
            tokenizer = tokenizer,
            word_pairs = counterfactual_pairs,
            n_components = 1,
            n_labels = num_labels
        )

In [21]:
if DEBIAS == "blind":
    
    class BLINDBERT(BLINDModelForClassification):
        def _get_embedding(self, input_ids = None, attention_mask = None, token_type_ids = None):
            return self.model.bert(
                input_ids = input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids
                ).last_hidden_state[:,0,:]
        
    class BLINDAverageAutoreg(BLINDModelForClassification):
        def _get_embedding(self, input_ids = None, attention_mask = None, token_type_ids = None):
            return self.model.model(
                input_ids = input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids
                ).last_hidden_state.mean(dim = 1)
        
    if MODEL_NAME == 'bert-base-uncased':
        model = BLINDBERT(
            model = original_model,
            config = None,
            gamma=2.0,
            temperature=1.0,
            hidden_dim = hidden_dim,
            n_labels = num_labels
        )
    else:
        model = BLINDAverageAutoreg(
            model = original_model,
            config = None,
            gamma=2.0,
            temperature=1.0,
            hidden_dim = hidden_dim,
            n_labels = num_labels
        )

In [22]:
if DEBIAS == "adele":
    model = DebiasAdapter(
        model = original_model,
        config = 'lora'
    )

In [23]:
if DEBIAS == "ear":
    model = EARModel(
        model = original_model,
        ear_reg_strength = 0.01
    )

In [24]:
if DEBIAS == "selective":
    model = original_model
    selective_unfreezing(model, ["attention.self", "attention.output"])

In [None]:
if DEBIAS == "diff":

    class DiffPrunningAvgAutoReg(DiffPrunedDebiasing):
        def _get_embedding(self, outputs):
            return outputs.mean(dim = 1)
        def _get_encoder(self):
            self.encoder = self.base_model.model

    tokens_male = [words[0] for words in counterfactual_pairs]
    tokens_female = [words[1] for words in counterfactual_pairs]

    inputs_male = tokenizer(tokens_male, padding = True, return_tensors = "pt")
    inputs_female = tokenizer(tokens_female, padding = True, return_tensors = "pt")

    if MODEL_NAME == 'bert-base-uncased':
        model = DiffPrunningBERT(
            model = original_model,
            input_ids_A = inputs_male,
            input_ids_B = inputs_female
        )

    else:
        model = DiffPrunningAvgAutoReg(
            model = original_model,
            input_ids_A = inputs_male,
            input_ids_B = inputs_female
        )

## Auxiliary functions

In [26]:
def preprocess_function(examples, tokenizer, task):
    if task in ["sst2", "cola"]:
        return tokenizer(examples["sentence"], truncation=True, padding="max_length", max_length=128)
    elif task == "mnli":
        return tokenizer(examples["premise"], examples["hypothesis"], truncation=True, padding="max_length", max_length=128)
    elif task == "qnli":
        return tokenizer(examples["question"], examples["sentence"], truncation=True, padding="max_length", max_length=128)
    elif task in ["rte", "wnli"]:
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding="max_length", max_length=128)
    elif task == "mrpc":
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding="max_length", max_length=128)
    elif task == "qqp":
        return tokenizer(examples["question1"], examples["question2"], truncation=True, padding="max_length", max_length=128)
    elif task == "stsb":
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding="max_length", max_length=128)
    
def get_metrics(task_name):
    metric = evaluate.load("glue", task_name)
    if task_name == "stsb":
        return metric, lambda logits: np.squeeze(logits, axis=-1)
    return metric, lambda logits: np.argmax(logits, axis=-1)

def compute_metrics_fn(p, task_name):
    logits = p.predictions
    labels = p.label_ids

    if isinstance(logits, tuple) or isinstance(logits, list):
        logits = logits[0]

    metric, postprocess_fn = get_metrics(task_name)
    predictions = postprocess_fn(logits)

    return metric.compute(predictions=predictions, references=labels)

## Load dataset

In [27]:
# Dataset
dataset = load_dataset("glue", TASK)

if CDA_METHOD[DEBIAS] and TASK != 'mnli':
    train_dataset = Dataset.from_dict(
        CDA(dataset['train'][:], pairs = dict(counterfactual_pairs))
        )
    dataset = DatasetDict({
        "train": train_dataset,
        "validation": dataset["validation"],
        "test": dataset["test"]
    })
elif CDA_METHOD[DEBIAS] and TASK == 'mnli':
    train_dataset = Dataset.from_dict(
        CDA(dataset['train'][:], pairs = dict(counterfactual_pairs))
        )
    dataset = DatasetDict({
        "train": train_dataset,
        "validation_matched": dataset["validation_matched"],
        "validation_mismatched": dataset["validation_mismatched"],
        "test_matched": dataset["test_matched"],
        "test_mismatched": dataset["test_mismatched"]
    })

if TASK == 'mnli':
    dataset["validation"] = dataset["validation_matched"]
    
tokenized_datasets = dataset.map(lambda x: preprocess_function(x, tokenizer, TASK), batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


## Fine-tuning

In [None]:
EVAL_STRATEGY = "epoch"
SAVE_STRATEGY = "epoch"
LOAD_BEST_MODEL_AT_END = True
EVAL_STEPS = None

if DEBIAS in ('adele', 'diff', 'eat'):
    SAVE_STRATEGY = "no"
    LOAD_BEST_MODEL_AT_END = False

if TASK in ('qqp', 'mnli'):
    BATCH_SIZE = 32
    FP16 = True
    EVAL_STRATEGY = "steps"
    EVAL_STEPS = 1000
    SAVE_STEPS = 1000

else:
    BATCH_SIZE = 16
    FP16 = False
    EVAL_STRATEGY = "epoch"
    EVAL_STEPS = None
    SAVE_STEPS = None

if LOAD_BEST_MODEL_AT_END:
    SAVE_STRATEGY = EVAL_STRATEGY  


training_args = TrainingArguments(
    output_dir=f"output/{TASK}-{DEBIAS}-{MODEL_NAME.replace('/', '-')}",
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=3,
    eval_strategy=EVAL_STRATEGY,
    evaluation_strategy=EVAL_STRATEGY,
    eval_steps=EVAL_STEPS,
    save_strategy=SAVE_STRATEGY,
    save_steps=SAVE_STEPS,
    logging_dir="logs",
    load_best_model_at_end=LOAD_BEST_MODEL_AT_END,
    metric_for_best_model=METRIC_FOR_BEST,
    fp16=FP16,
    greater_is_better = True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=lambda p: compute_metrics_fn(p, TASK),
    optimizers=(AdamW(model.parameters(), lr=1e-5, weight_decay=WEIGHT_DECAY), None),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

trainer.train()

if DEBIAS == 'eat':
    add_EAT_hook(model, beta=0.7)

if TASK == 'mnli':
    eval_results_mismd = trainer.evaluate(tokenized_datasets["validation_mismatched"])
    eval_results_match = trainer.evaluate(tokenized_datasets["validation_matched"])
    test_results_mismd = trainer.evaluate(tokenized_datasets["test_mismatched"])
    test_results_match = trainer.evaluate(tokenized_datasets["test_matched"])
    print("Validation results (matched) in ", TASK, ":", eval_results_match)
    print("Validation results (mismatched) in ", TASK, ":", eval_results_mismd)
    print("Test results (matched) in ", TASK, ":", test_results_match)
    print("Test results (mismatched) in ", TASK, ":", test_results_mismd)
else:
    eval_results = trainer.evaluate()
    test_results = trainer.evaluate(tokenized_datasets["test"])
    print("Validation results in ", TASK, ":", eval_results)
    print("Test results in ", TASK, ":", test_results)

  trainer = Trainer(


Step,Training Loss,Validation Loss,Accuracy,F1
1000,0.3808,0.368326,0.830547,0.798162
2000,0.3371,0.3289,0.85214,0.815903
3000,0.3131,0.31561,0.862083,0.8286
4000,0.3116,0.291948,0.872718,0.82976
5000,0.2946,0.289571,0.87346,0.840165
6000,0.2903,0.281131,0.879718,0.839563
7000,0.2801,0.276153,0.880979,0.839718
8000,0.2693,0.263259,0.884145,0.843835
9000,0.2697,0.260025,0.887831,0.848586
10000,0.2708,0.258583,0.888721,0.855035


KeyboardInterrupt: 

## WEAT

In [None]:
class BertWEAT(WEAT):
    def _get_embedding(self, outputs):
        return outputs.last_hidden_state[:, 0, :]

class AverageAutoregWEAT(WEAT):
    def _get_embedding(self, outputs):
        return outputs.last_hidden_state.mean(dim = 1)
    
if MODEL_NAME == 'bert-base-uncased':
    try:
        weat = BertWEAT(model = model.model.bert, tokenizer = tokenizer)
    except:
        try:
            weat = BertWEAT(model = model.bert, tokenizer = tokenizer)
        except:
            weat = BertWEAT(model = model.base_model.bert, tokenizer = tokenizer)
else:
    try:
        weat = AverageAutoregWEAT(model = model.model.base_model, tokenizer = tokenizer)
    except:
        try:
            weat = AverageAutoregWEAT(model = model.base_model, tokenizer = tokenizer)
        except:
            weat = AverageAutoregWEAT(model = model.base_model.base_model, tokenizer = tokenizer)

math = ['math', 'algebra', 'geometry', 'calculus', 'equations', 'computation', 'numbers', 'addition']
arts = ['poetry', 'art', 'dance', 'literature', 'novel', 'symphony', 'drama', 'sculpture']
male = ['male', 'man', 'boy', 'brother', 'he', 'him', 'his', 'son']
female = ['female', 'woman', 'girl', 'sister', 'she', 'her', 'hers', 'daughter']

bias_results = weat.run_test(
    W1_words = math, W2_words = arts,
    A1_words = male, A2_words = female,
    pval = False
    )
print(bias_results)

{'X-A_mean_sim': 0.5240193605422974, 'X-B_mean_sim': 0.6104025840759277, 'Y-A_mean_sim': 0.5816237926483154, 'Y-B_mean_sim': 0.6603400707244873, 'W1_size': 8, 'W2_size': 8, 'A1_size': 8, 'A2_size': 8, 'effect_size': -0.07666268199682236}


## Save results

In [None]:
if TASK == 'mnli':
    with open(f"output/{TASK}-{DEBIAS}-{MODEL_NAME.replace('/', '-')}/results.json", "w") as f:
        json.dump({"eval_matched": eval_results_match, "eval_mismatched": eval_results_mismd, "test_mismatched": test_results_mismd, "test_matched": test_results_match, "bias": bias_results}, f, indent=4)
else:
    with open(f"output/{TASK}-{DEBIAS}-{MODEL_NAME.replace('/', '-')}/results.json", "w") as f:
        json.dump({"eval": eval_results, "test": test_results, "bias": bias_results}, f, indent=4)