# Debiasing a Language Model

In [1]:
# Standard imports
import json
import os
import sys

# Computing libraries
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    EarlyStoppingCallback
)
from adapters import AdapterTrainer
from datasets import load_dataset, Dataset, DatasetDict
import evaluate

# Custom imports: when running from a local checkout, put the repository root
# (the parent of this notebook's directory) on sys.path so FairLangProc resolves.
LOCAL = True
if LOCAL:
    ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) \
        if "__file__" in globals() else os.path.abspath("..")
    sys.path.insert(0, ROOT_PATH)

from FairLangProc.datasets import BiasDataLoader
from FairLangProc.metrics import WEAT

from FairLangProc.algorithms.preprocessors import CDA, BLINDTrainer, SentDebiasForSequenceClassification
from FairLangProc.algorithms.inprocessors import EARModel, DebiasAdapter, selective_unfreezing
from FairLangProc.algorithms.intraprocessors import add_EAT_hook, DiffPrunBERT, DiffPrunDebiasing

## Configuration

In [2]:
# Configuration: model checkpoints, GLUE tasks and debiasing methods covered by
# this notebook, plus per-task / per-method settings keyed by those names.
MODELS = [
    'bert-base-uncased',
    'deepseek-ai/deepseek-llm-7b-base',
    'huggyllama/llama-7b'
]
TASKS = [
    "cola",
    "sst2",
    "mrpc",
    "stsb",
    "qqp",
    "mnli",
    "qnli",
    "rte",
    "wnli"
]
# Number of model outputs per task (STS-B is a single-value regression).
TASK_LABELS = {
    "cola": 2,
    "sst2": 2,
    "mrpc": 2,
    "qqp": 2,
    "stsb": 1,
    "mnli": 3,
    "qnli": 2,
    "rte": 2,
    "wnli": 2,
}
DEBIAS_METHODS = [
    "none",
    "cda",
    "blind",
    "embedding",
    "ear",
    "adele",
    "selective",
    "eat",
    "diff"
]
# Metric used to pick the best checkpoint; every task in TASKS has an entry.
TASK_METRICS = {
    "cola": "eval_matthews_correlation",
    "sst2": "eval_accuracy",
    "mrpc": "eval_accuracy",
    "stsb": "eval_pearson",
    "qqp": "eval_accuracy",
    "mnli": "eval_accuracy",
    "qnli": "eval_accuracy",
    "rte": "eval_accuracy",
    "wnli": "eval_accuracy",
}
# Whether each debiasing method trains on counterfactually augmented (CDA) data.
CDA_METHOD = {
    "none": False,
    "cda": True,
    "blind": False,
    "embedding": False,
    "ear": False,
    "adele": True,
    "selective": True,
    "eat": False,
    "diff": False
}

In [3]:
# Experiment selection: which model, GLUE task and debiasing method to run.
MODEL_NAME = MODELS[0]
TASK = "cola"
DEBIAS = "diff"

# Fail fast on typos; otherwise they surface much later as KeyError/NameError.
if TASK not in TASKS:
    raise ValueError(f"Unknown TASK {TASK!r}; expected one of {TASKS}")
if DEBIAS not in DEBIAS_METHODS:
    raise ValueError(f"Unknown DEBIAS {DEBIAS!r}; expected one of {DEBIAS_METHODS}")

In [4]:
# Checkpoint-selection metric (accuracy unless the task defines its own).
METRIC_FOR_BEST = TASK_METRICS.get(TASK, "eval_accuracy")
WEIGHT_DECAY = 0.1
BATCH_SIZE = 16
# Prefer the GPU, but keep the notebook runnable on CPU-only machines.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Gendered word pairs as (male term, female term). Used by CDA (as
# `dict(counterfactual_pairs)`, i.e. male term -> female term, so male terms
# should stay unique), by Sent-Debias, and by DiffPrun (first / second elements
# become the two word groups). Order matters: the DiffPrun cell may use only
# a prefix of this list for non-BERT models.
counterfactual_pairs = [
    ("gods", "goddesses"), ("manager", "manageress"), ("barons", "baronesses"),
    ("nephew", "niece"), ("prince", "princess"), ("boars", "sows"),
    ("baron", "baroness"), ("stepfathers", "stepmothers"), ("wizard", "witch"),
    ("father", "mother"), ("stepsons", "stepdaughters"), ("sons-in-law", "daughters-in-law"),
    ("dukes", "duchesses"), ("boyfriend", "girlfriend"), ("fiances", "fiancees"),
    ("dad", "mom"), ("shepherd", "shepherdess"), ("uncles", "aunts"),
    ("beau", "belle"), ("males", "females"), ("hunter", "huntress"),
    ("beaus", "belles"), ("grandfathers", "grandmothers"), ("lads", "lasses"),
    ("daddies", "mummies"), ("step-son", "step-daughter"), ("masters", "mistresses"),
    ("policeman", "policewoman"), ("nephews", "nieces"), ("brother", "sister"),
    ("grandfather", "grandmother"), ("priest", "priestess"), ("hosts", "hostesses"),
    ("landlord", "landlady"), ("husband", "wife"), ("poet", "poetess"),
    ("landlords", "landladies"), ("fathers", "mothers"), ("masseur", "masseuse"),
    ("monks", "nuns"), ("usher", "usherette"), ("hero", "heroine"),
    ("stepson", "stepdaughter"), ("postman", "postwoman"), ("god", "goddess"),
    ("milkmen", "milkmaids"), ("stags", "hinds"), ("grandpa", "grandma"),
    ("chairmen", "chairwomen"), ("husbands", "wives"), ("grandpas", "grandmas"),
    ("stewards", "stewardesses"), ("murderer", "murderess"), ("manservant", "maidservant"),
    ("men", "women"), ("host", "hostess"), ("heirs", "heiresses"),
    ("masseurs", "masseuses"), ("boy", "girl"), ("male", "female"),
    ("son-in-law", "daughter-in-law"), ("waiter", "waitress"), ("tutors", "governesses"),
    ("priests", "priestesses"), ("bachelor", "spinster"), ("millionaire", "millionairess"),
    ("steward", "stewardess"), ("businessmen", "businesswomen"), ("congressman", "congresswoman"),
    ("emperor", "empress"), ("duke", "duchess"), ("sire", "dam"),
    ("son", "daughter"), ("sirs", "madams"), ("widower", "widow"),
    ("kings", "queens"), ("papas", "mamas"), ("grandsons", "granddaughters"),
    ("proprietor", "proprietress"), ("monk", "nun"), ("headmasters", "headmistresses"),
    ("grooms", "brides"), ("heir", "heiress"), ("boys", "girls"),
    ("gentleman", "lady"), ("uncle", "aunt"), ("he", "she"),
    ("king", "queen"), ("princes", "princesses"), ("policemen", "policewomen"),
    ("governor", "matron"), ("fiance", "fiancee"), ("step-father", "step-mother"),
    ("waiters", "waitresses"), ("mr", "mrs"), ("stepfather", "stepmother"),
    ("daddy", "mummy"), ("lords", "ladies"), ("widowers", "widows"),
    ("emperors", "empresses"), ("father-in-law", "mother-in-law"), ("abbot", "abbess"),
    ("sir", "madam"), ("actor", "actress"), ("mr.", "mrs."),
    ("wizards", "witches"), ("actors", "actresses"), ("chairman", "chairwoman"),
    ("sorcerer", "sorceress"), ("postmaster", "postmistress"), ("brothers", "sisters"),
    ("lad", "lass"), ("headmaster", "headmistress"), ("papa", "mama"),
    ("milkman", "milkmaid"), ("heroes", "heroines"), ("man", "woman"),
    ("grandson", "granddaughter"), ("groom", "bride"), ("sons", "daughters"),
    ("congressmen", "congresswomen"), ("businessman", "businesswoman"), ("boyfriends", "girlfriends"),
    ("dads", "moms"),
]

## Load model and debias method

In [6]:
num_labels = TASK_LABELS[TASK]
# STS-B is a regression task (one similarity score); all others are classification.
problem_type = 'regression' if TASK == 'stsb' else 'single_label_classification'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Decoder-only tokenizers (e.g. LLaMA) ship without a padding token, but every
# tokenizer call in this notebook pads; reuse EOS as the pad token then.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def _checkpoint_step(dirname):
    """Training step encoded in a ``checkpoint-<step>`` directory name (-1 if none)."""
    suffix = dirname.rsplit('-', 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


def find_last_checkpoint(task, model_name, roots=("../output", "output")):
    """Return the path of the highest-step checkpoint of the un-debiased ("none") run.

    The run directory is ``<root>/<task>-none-<model_name with '/' replaced by '-'>``,
    matching the ``output_dir`` naming used by the fine-tuning cell. Roots are
    tried in order; the first one holding at least one checkpoint wins.

    Raises:
        FileNotFoundError: if no root contains a checkpoint for this run.
    """
    run_name = f"{task}-none-{model_name.replace('/', '-')}"
    for root in roots:
        run_dir = os.path.join(root, run_name)
        if not os.path.isdir(run_dir):
            continue
        checkpoints = [d for d in os.listdir(run_dir) if d.startswith('checkpoint')]
        if checkpoints:
            # os.listdir order is arbitrary (and lexicographic order would put
            # "checkpoint-1000" before "checkpoint-500"), so compare step numbers.
            return os.path.join(run_dir, max(checkpoints, key=_checkpoint_step))
    raise FileNotFoundError(f"No checkpoint for run {run_name!r} under any of {list(roots)}")


if DEBIAS in ("adele", "eat", "diff"):
    # These methods start from the already fine-tuned, un-debiased model.
    LAST_CHECKPOINT_PATH = find_last_checkpoint(TASK, MODEL_NAME)
    original_model = AutoModelForSequenceClassification.from_pretrained(LAST_CHECKPOINT_PATH)
else:
    original_model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=num_labels, problem_type=problem_type
    )

# HF sequence-classification heads for decoder-only models need pad_token_id to
# locate each sequence's last real token when batching.
if original_model.config.pad_token_id is None:
    original_model.config.pad_token_id = tokenizer.pad_token_id

hidden_dim = original_model.config.hidden_size
# BERT names its head `classifier`, LLaMA-style models name it `score`;
# alias it so later cells can always use `.classifier`.
if not hasattr(original_model, 'classifier'):
    original_model.classifier = original_model.score

In [7]:
# Baseline, CDA (data-level only) and EAT (a hook added after construction)
# train/evaluate the loaded model unchanged.
if DEBIAS in {"none", "cda", "eat"}:
    model = original_model

In [8]:
if DEBIAS == "embedding":

    class SentDebiasBert(SentDebiasForSequenceClassification):
        """Sent-Debias for BERT: the sentence embedding is the first-token ([CLS]) hidden state."""

        def _get_embedding(self, input_ids, attention_mask = None, token_type_ids = None):
            return self.model.bert(
                input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids
                ).last_hidden_state[:,0,:]

    class SentDebiasAverageAutoreg(SentDebiasForSequenceClassification):
        """Sent-Debias for decoder-only models: mean hidden state over non-padding tokens."""

        def _get_embedding(self, input_ids, attention_mask = None, token_type_ids = None):
            # Decoder-only backbones have no segment embeddings; token_type_ids is
            # accepted for interface compatibility but not forwarded.
            hidden = self.model.model(
                input_ids, attention_mask = attention_mask
                ).last_hidden_state
            if attention_mask is None:
                return hidden.mean(dim = 1)
            # Average only over real tokens so padding does not dilute the embedding.
            # NOTE(review): assumes attention_mask is (batch, seq) with 1 = token.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            return (hidden * mask).sum(dim = 1) / mask.sum(dim = 1).clamp(min = 1)

    sent_debias_cls = SentDebiasBert if MODEL_NAME == 'bert-base-uncased' else SentDebiasAverageAutoreg
    model = sent_debias_cls(
        model = original_model,
        config = None,
        tokenizer = tokenizer,
        word_pairs = counterfactual_pairs,
        n_components = 1,
        n_labels = num_labels
    )

In [9]:
if DEBIAS == "blind":

    class BLINDBERTTrainer(BLINDTrainer):
        """BLIND trainer whose sentence embedding is BERT's first-token ([CLS]) hidden state."""

        def _get_embedding(self, inputs):
            encoder_outputs = self.model.bert(
                input_ids=inputs.get("input_ids"),
                attention_mask=inputs.get("attention_mask"),
                token_type_ids=inputs.get("token_type_ids"),
            )
            cls_embedding = encoder_outputs.last_hidden_state[:, 0, :]
            return cls_embedding

    # BLIND acts in the trainer, so the model itself is used unchanged.
    model = original_model

In [10]:
if DEBIAS == "adele":
    # Bind the wrapper to its own name so the imported `DebiasAdapter` class is
    # not shadowed by the instance; otherwise re-running this cell fails.
    debias_adapter = DebiasAdapter(model = original_model)
    model = debias_adapter.get_model()

In [11]:
if DEBIAS == "ear":
    # Wrap the fine-tuning model with EAR; 0.01 is passed as its
    # regularisation strength.
    model = EARModel(model=original_model, ear_reg_strength=0.01)

In [12]:
if DEBIAS == "selective":
    model = original_model
    # Module-name patterns handed to `selective_unfreezing` (presumably the only
    # parameters left trainable — see FairLangProc).
    # NOTE(review): these follow BERT's module naming; LLaMA-style models name
    # their attention blocks `self_attn`, so the patterns may not match there.
    trainable_patterns = ["attention.self", "attention.output"]
    selective_unfreezing(model, trainable_patterns)

In [None]:
if DEBIAS == "diff":

    def normalize_by_column(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Standardise each column of ``x`` to zero mean / unit std; ``eps`` avoids division by zero."""
        mean = x.mean(dim=0, keepdim=True)
        std = x.std(dim=0, keepdim=True)
        return (x - mean) / (std + eps)

    # Task loss: regression for STS-B, classification otherwise.
    loss_fn = nn.MSELoss() if TASK == "stsb" else nn.CrossEntropyLoss()

    class DiffPrunAvgAutoReg(DiffPrunDebiasing):
        """DiffPrun for decoder-only encoders: mean hidden state over non-padding tokens."""

        def _forward(self, input_ids, attention_mask=None, token_type_ids=None):
            # Decoder-only encoders have no segment embeddings; token_type_ids is
            # accepted for interface compatibility but not forwarded.
            hidden = self.encoder(
                input_ids = input_ids,
                attention_mask = attention_mask,
                ).last_hidden_state
            if attention_mask is None:
                return hidden.mean(dim = 1)
            # NOTE(review): assumes attention_mask is (batch, seq) with 1 = token.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            return (hidden * mask).sum(dim = 1) / mask.sum(dim = 1).clamp(min = 1)

    tokens_male = [words[0] for words in counterfactual_pairs]
    tokens_female = [words[1] for words in counterfactual_pairs]

    if MODEL_NAME == 'bert-base-uncased':
        diff_prun_cls, diff_encoder, n_pairs = DiffPrunBERT, original_model.bert, len(counterfactual_pairs)
    else:
        # Only the first 50 pairs for the large models (presumably to bound the
        # cost of the bias term — confirm).
        diff_prun_cls, diff_encoder, n_pairs = DiffPrunAvgAutoReg, original_model.base_model, 50

    # Tokenise the selected words directly: slicing a fast-tokenizer BatchEncoding
    # (`inputs[:50]`) yields `tokenizers.Encoding` objects rather than tensors.
    inputs_male = tokenizer(tokens_male[:n_pairs], padding = True, return_tensors = "pt")
    inputs_female = tokenizer(tokens_female[:n_pairs], padding = True, return_tensors = "pt")

    model = diff_prun_cls(
        head = original_model.classifier,
        encoder = diff_encoder,
        loss_fn = loss_fn,
        input_ids_A = inputs_male,
        input_ids_B = inputs_female,
        bias_kernel = normalize_by_column,
        upper = 10,
        lower = -0.001,
        lambda_bias = 0.5,
        lambda_sparse = 0.00001
    )

## Auxiliary functions

In [14]:
# Text column(s) fed to the tokenizer for each GLUE task (None = single sentence).
_GLUE_TEXT_KEYS = {
    "cola": ("sentence", None),
    "sst2": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "qnli": ("question", "sentence"),
    "rte": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
    "mrpc": ("sentence1", "sentence2"),
    "stsb": ("sentence1", "sentence2"),
    "qqp": ("question1", "question2"),
}


def preprocess_function(examples, tokenizer, task, max_length=128):
    """Tokenise a batch of GLUE examples for ``task``.

    Args:
        examples: batch mapping column name -> list of values (as passed by
            ``Dataset.map(batched=True)``).
        tokenizer: callable Hugging Face tokenizer.
        task: GLUE task name; must be a key of ``_GLUE_TEXT_KEYS``.
        max_length: every sequence is truncated/padded to this many tokens.

    Returns:
        The tokenizer output for the task's text column(s).

    Raises:
        ValueError: if ``task`` is not a supported GLUE task (previously this
            silently returned None).
    """
    try:
        first_key, second_key = _GLUE_TEXT_KEYS[task]
    except KeyError:
        raise ValueError(f"Unsupported GLUE task: {task!r}") from None
    if second_key is None:
        texts = (examples[first_key],)
    else:
        texts = (examples[first_key], examples[second_key])
    return tokenizer(*texts, truncation=True, padding="max_length", max_length=max_length)
    
def get_metrics(task_name):
    """Load the GLUE metric for ``task_name`` together with a logits -> predictions post-processor.

    STS-B is scored on the raw regression output (trailing singleton axis
    removed); every other task uses the arg-max over the last axis.
    """
    metric = evaluate.load("glue", task_name)
    if task_name != "stsb":
        postprocess = lambda logits: np.argmax(logits, axis=-1)
    else:
        postprocess = lambda logits: np.squeeze(logits, axis=-1)
    return metric, postprocess

def compute_metrics_fn(p, task_name):
    """Compute the GLUE metric for ``task_name`` from a Trainer evaluation prediction.

    If ``p.predictions`` is a tuple or list (model returned several outputs),
    its first element is taken as the logits.
    """
    raw_predictions = p.predictions
    if isinstance(raw_predictions, (tuple, list)):
        raw_predictions = raw_predictions[0]

    metric, postprocess_fn = get_metrics(task_name)
    return metric.compute(
        predictions=postprocess_fn(raw_predictions),
        references=p.label_ids,
    )

## Load dataset

In [15]:
# Dataset
dataset = load_dataset("glue", TASK)

if CDA_METHOD[DEBIAS] and TASK != 'mnli':
    train_dataset = Dataset.from_dict(
        CDA(dataset['train'][:], pairs = dict(counterfactual_pairs))
        )
    dataset = DatasetDict({
        "train": train_dataset,
        "validation": dataset["validation"],
        "test": dataset["test"]
    })
elif CDA_METHOD[DEBIAS] and TASK == 'mnli':
    train_dataset = Dataset.from_dict(
        CDA(dataset['train'][:], pairs = dict(counterfactual_pairs))
        )
    dataset = DatasetDict({
        "train": train_dataset,
        "validation_matched": dataset["validation_matched"],
        "validation_mismatched": dataset["validation_mismatched"],
        "test_matched": dataset["test_matched"],
        "test_mismatched": dataset["test_mismatched"]
    })

if TASK == 'mnli':
    dataset["validation"] = dataset["validation_matched"]
    
tokenized_datasets = dataset.map(lambda x: preprocess_function(x, tokenizer, TASK), batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


## Fine-tuning

In [16]:
# Checkpointing / precision policy.
LOAD_BEST_MODEL_AT_END = True
SAVE_SAFETENSORS = True
# Mixed precision needs an accelerator: TrainingArguments rejects fp16=True on CPU.
FP16 = torch.cuda.is_available()

if DEBIAS == 'diff':
    PATIENCE = 5
    # NOTE(review): safetensors disabled for DiffPrun, presumably because the
    # wrapper's parameters cannot be serialised that way — confirm.
    SAVE_SAFETENSORS = False
else:
    PATIENCE = 2

CALLBACKS = [EarlyStoppingCallback(early_stopping_patience=PATIENCE)]

# EAT only adds a hook and never trains, so nothing is checkpointed and there is
# no "best" model (EarlyStoppingCallback requires load_best_model_at_end).
if DEBIAS == 'eat':
    SAVE_STRATEGY = "no"
    LOAD_BEST_MODEL_AT_END = False
    CALLBACKS = None
else:
    SAVE_STRATEGY = "epoch"

# Trainer class for the chosen method.
if DEBIAS == 'adele':
    trainer = AdapterTrainer
elif DEBIAS == 'blind':
    trainer = BLINDBERTTrainer
else:
    trainer = Trainer

# QQP and MNLI are large: bigger batches and evaluate/save every 1000 steps.
if TASK in ('qqp', 'mnli'):
    BATCH_SIZE = 32
    EVAL_STRATEGY = "steps"
    EVAL_STEPS = 1000
    SAVE_STEPS = 1000
else:
    BATCH_SIZE = 16
    EVAL_STRATEGY = "epoch"
    EVAL_STEPS = None
    SAVE_STEPS = None

# load_best_model_at_end requires the save and eval strategies to match.
if LOAD_BEST_MODEL_AT_END:
    SAVE_STRATEGY = EVAL_STRATEGY

In [17]:
training_args = TrainingArguments(
    output_dir=f"output/{TASK}-{DEBIAS}-{MODEL_NAME.replace('/', '-')}",
    # NOTE: the trainer below receives an explicit AdamW with lr=1e-5, so with the
    # standard Trainer this value does not drive the parameter updates.
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=3,
    eval_strategy=EVAL_STRATEGY,
    eval_steps=EVAL_STEPS,
    save_strategy=SAVE_STRATEGY,
    save_steps=SAVE_STEPS,
    save_safetensors=SAVE_SAFETENSORS,
    logging_dir="logs",
    load_best_model_at_end=LOAD_BEST_MODEL_AT_END,
    metric_for_best_model=METRIC_FOR_BEST,
    fp16=FP16,
    greater_is_better = True
)

# `trainer` holds the trainer class after the configuration cell, but the instance
# after this cell has run; recover the class either way so the cell can be re-run.
trainer_cls = trainer if isinstance(trainer, type) else type(trainer)

trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=lambda p: compute_metrics_fn(p, TASK),
    callbacks=CALLBACKS,
    optimizers=(AdamW(model.parameters(), lr=1e-5, weight_decay=WEIGHT_DECAY), None),
)
if DEBIAS == 'blind':
    # BLINDTrainer additionally takes a factory for its own optimizer.
    trainer_kwargs["blind_optimizer"] = lambda x: AdamW(x, lr=1e-5, weight_decay=WEIGHT_DECAY)

trainer = trainer_cls(**trainer_kwargs)

if DEBIAS == 'eat':
    # EAT is applied through a hook on the already fine-tuned model; no training.
    add_EAT_hook(model, beta=0.7)
else:
    trainer.train()

  trainer = trainer(


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
if TASK != 'mnli':
    eval_results = trainer.evaluate()
    print("Validation results in ", TASK, ":", eval_results)
else:
    # MNLI: report both the matched and the mismatched validation splits.
    eval_results_mismd = trainer.evaluate(eval_dataset=tokenized_datasets["validation_mismatched"])
    eval_results_match = trainer.evaluate(eval_dataset=tokenized_datasets["validation_matched"])
    print("Validation results (matched) in ", TASK, ":", eval_results_match)
    print("Validation results (mismatched) in ", TASK, ":", eval_results_mismd)

Validation results in  cola : {'eval_loss': 0.1315392404794693, 'eval_matthews_correlation': 0.0, 'eval_runtime': 9.9141, 'eval_samples_per_second': 105.204, 'eval_steps_per_second': 3.329, 'epoch': 0.14925373134328357}


## WEAT

In [None]:
class BertWEAT(WEAT):
    """WEAT over BERT: each input is represented by its first-token ([CLS]) hidden state."""

    def _get_embedding(self, outputs):
        hidden_states = outputs.last_hidden_state
        return hidden_states[:, 0, :]

class AverageAutoregWEAT(WEAT):
    """WEAT over decoder-only models: each input is represented by its hidden states averaged over dim 1."""

    def _get_embedding(self, outputs):
        # NOTE(review): no attention mask is available here, so padded positions
        # (if any) are included in the average.
        hidden_states = outputs.last_hidden_state
        return hidden_states.mean(dim=1)
    
def _first_attr_path(obj, paths):
    """Return the object reached by the first dotted attribute path in ``paths`` that resolves on ``obj``.

    Raises:
        AttributeError: if none of the paths resolves.
    """
    for path in paths:
        target = obj
        try:
            for attr in path.split('.'):
                target = getattr(target, attr)
        except AttributeError:
            continue
        return target
    raise AttributeError(f"{type(obj).__name__} has none of the attribute paths {list(paths)}")


# Locate the bare encoder inside whichever wrapper the debiasing method produced.
# DiffPrun stores it as `.encoder`; the other methods nest the HF model at
# different depths, so try the known paths in order. Only missing attributes are
# skipped — errors raised by the WEAT constructor itself now propagate.
if MODEL_NAME == 'bert-base-uncased':
    weat_cls = BertWEAT
    encoder_paths = ['model.bert', 'bert', 'base_model.bert']
else:
    weat_cls = AverageAutoregWEAT
    encoder_paths = ['model.base_model', 'base_model', 'base_model.base_model']

if DEBIAS == 'diff':
    encoder_paths = ['encoder']

weat = weat_cls(model = _first_attr_path(model, encoder_paths), tokenizer = tokenizer)

# WEAT target sets (math vs. arts) and attribute sets (male vs. female terms).
math = ['math', 'algebra', 'geometry', 'calculus', 'equations', 'computation', 'numbers', 'addition']
arts = ['poetry', 'art', 'dance', 'literature', 'novel', 'symphony', 'drama', 'sculpture']
male = ['male', 'man', 'boy', 'brother', 'he', 'him', 'his', 'son']
female = ['female', 'woman', 'girl', 'sister', 'she', 'her', 'hers', 'daughter']

bias_results = weat.metric(
    W1_words = math, W2_words = arts,
    A1_words = male, A2_words = female,
    pval = False
    )
print(bias_results)

{'X-A_mean_sim': 0.9595464468002319, 'X-B_mean_sim': 0.9583925604820251, 'Y-A_mean_sim': 0.9550148248672485, 'Y-B_mean_sim': 0.95367032289505, 'W1_size': 8, 'W2_size': 8, 'A1_size': 8, 'A2_size': 8, 'effect_size': -0.02097395434975624}


## Save results

In [None]:
# Persist the validation metrics and the WEAT bias score in the run directory.
results_dir = f"output/{TASK}-{DEBIAS}-{MODEL_NAME.replace('/', '-')}"
# The directory may not exist yet, e.g. for EAT, which never trains and so
# writes no checkpoint there.
os.makedirs(results_dir, exist_ok=True)

if TASK == 'mnli':
    results = {"eval_matched": eval_results_match, "eval_mismatched": eval_results_mismd, "bias": bias_results}
else:
    results = {"eval": eval_results, "bias": bias_results}

with open(f"{results_dir}/results.json", "w") as f:
    # default=float converts numpy/torch scalars that json cannot serialise natively.
    json.dump(results, f, indent=4, default=float)