Parts of the code were inspired by this [BERT Tutorial](https://towardsdatascience.com/fine-tuning-bert-for-text-classification-54e7df642894#96e0) and adapted to work with XLNet. <br>
The EAR (Entropy-based Attention Regularization) technique, i.e. the function `compute_negative_entropy`, has been implemented following the [GitHub repository](https://github.com/g8a9/ear) that hosts the code accompanying the original [EAR paper](https://aclanthology.org/2022.findings-acl.88/) by Attanasio et al.

In [None]:
! pip install transformers datasets evaluate
! pip install SentencePiece
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import XLNetTokenizer, XLNetForSequenceClassification
from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np

from tqdm import trange
import random

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.13.1-py3-none-any.whl (486 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m20.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m28.9 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from

In [None]:
#@title
from google.colab import drive
drive.mount('/content/drive')
train_df = pd.read_csv('/content/drive/My Drive/DATASETS/wiki_toxic/train.csv')
validation_df = pd.read_csv('/content/drive/My Drive/DATASETS/wiki_toxic/validation.csv')
frac = 0.5

def _drop_random_fraction(df, fraction):
  """Print the row count, drop a reproducible random `fraction` of the rows, print the new count."""
  print(len(df))
  # random_state=1 makes the sampled (and therefore dropped) rows reproducible
  dropped_index = df.sample(frac=fraction, random_state=1).index
  kept = df.drop(dropped_index)
  print(len(kept))
  return kept

train_df = _drop_random_fraction(train_df, frac)            # TRAIN
validation_df = _drop_random_fraction(validation_df, frac)  # VALIDATION

train_text, train_labels = train_df['comment_text'].values, train_df['label'].values
validation_text, validation_labels = validation_df['comment_text'].values, validation_df['label'].values

In [None]:
from google.colab import drive
drive.mount('/content/drive')
#TEST
test_df = pd.read_csv('/content/drive/My Drive/DATASETS/wiki_toxic/test.csv')
print(len(test_df))  # number of rows in the test split
test_text, test_labels = test_df['comment_text'].values, test_df['label'].values

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
63978


In [None]:
def compute_negative_entropy(
    inputs: tuple, attention_mask: torch.Tensor, return_values=False
):
    """EAR regularisation term: negative attention entropy, summed over layers.

    For each sample, the attention maps are averaged over heads and restricted to
    the non-padded tokens (both rows and columns). A softmax is applied to each
    row, and the negative entropy of every token's attention row is computed.
    Per sample, these values are averaged over tokens and summed over layers. The
    returned value is the mean over the batch.

    Args:
        - inputs: tuple of length num_layers. Each item should be in the form BHSS
          (batch, heads, seq_len, seq_len).
        - attention_mask: tensor with dims BS; non-zero entries mark real tokens.
        - return_values: if True, also return the per-sample token-level values.

    Returns:
        A scalar tensor, or ``(scalar, values)`` when ``return_values`` is True,
        where ``values[b]`` is a detached tensor of shape (num_layers, n_tokens_b).
    """
    layered = torch.stack(inputs)  # layers x batch x heads x seq x seq
    assert layered.ndim == 5, "Here we expect 5 dimensions in the form LBHSS"

    head_mean = layered.mean(2)  # average over attention heads -> L x B x S x S

    per_sample_scores = []
    token_scores = []
    for idx in range(head_mean.shape[1]):
        keep = attention_mask[idx].bool()
        # Select the real-token rows, then the real-token columns.
        attn = head_mean[:, idx, keep, :][:, :, keep]

        # sum_j p_j * log p_j for every (layer, token) attention row
        neg_entropy = (attn.softmax(-1) * attn.log_softmax(-1)).sum(-1)
        if return_values:
            token_scores.append(neg_entropy.detach())

        # mean over tokens, then sum over layers
        per_sample_scores.append(neg_entropy.mean(-1).sum(0))

    batch_mean = torch.stack(per_sample_scores).mean()
    return (batch_mean, token_scores) if return_values else batch_mean

In [None]:
#@title
# NOTE(review): 'xlnet-base-cased' is a cased checkpoint, but do_lower_case=True
# lowercases every input. It is kept so tokenization matches the saved model and
# the model-loading cell; change both places together if you retrain.
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased', do_lower_case=True)

def preprocessing(input_text, tokenizer):
  """Tokenize one text into a fixed-length (250-token) encoding.

  Returns the tokenizer's encoding with 'input_ids' and 'attention_mask' as
  PyTorch tensors of shape (1, 250). `padding='max_length'` replaces the
  deprecated `pad_to_max_length=True` and produces the same padding.
  """
  return tokenizer.encode_plus(
                        input_text,
                        add_special_tokens = True,
                        max_length = 250,
                        truncation = True,
                        padding = 'max_length',
                        return_attention_mask = True,
                        return_tensors = 'pt'
                   )

def _encode_texts(texts, tokenizer):
  """Encode every text with `preprocessing` and stack the ids / masks into (N, 250) tensors."""
  encodings = [preprocessing(text, tokenizer) for text in texts]
  token_ids = torch.cat([enc['input_ids'] for enc in encodings], dim = 0)
  attention_masks = torch.cat([enc['attention_mask'] for enc in encodings], dim = 0)
  return token_ids, attention_masks

train_token_id, train_attention_masks = _encode_texts(train_text, tokenizer)
# as_tensor is a no-op if the labels are already a tensor (cell re-run)
train_labels = torch.as_tensor(train_labels)

validation_token_id, validation_attention_masks = _encode_texts(validation_text, tokenizer)
validation_labels = torch.as_tensor(validation_labels)

In [None]:
#@title
# Recommended batch size: 16, 32
batch_size = 32

# Index arrays covering every row (no subsetting happens here).
train_idx = np.arange(len(train_labels))
val_idx = np.arange(len(validation_labels))

def _tensor_dataset(token_ids, attention_masks, labels, idx):
  """Bundle the rows selected by `idx` into a TensorDataset of (ids, mask, label)."""
  return TensorDataset(token_ids[idx], attention_masks[idx], labels[idx])

# Train and validation sets
train_set = _tensor_dataset(train_token_id, train_attention_masks, train_labels, train_idx)
val_set = _tensor_dataset(validation_token_id, validation_attention_masks, validation_labels, val_idx)

# Shuffle training batches; keep validation batches in a fixed order.
train_dataloader = DataLoader(train_set, sampler=RandomSampler(train_set), batch_size=batch_size)
validation_dataloader = DataLoader(val_set, sampler=SequentialSampler(val_set), batch_size=batch_size)

In [None]:
#@title
def b_tp(preds, labels):
  '''Returns True Positives (TP): count of samples predicted 1 whose label is 1.'''
  preds, labels = np.asarray(preds), np.asarray(labels)
  return int(np.sum((preds == 1) & (labels == 1)))

def b_fp(preds, labels):
  '''Returns False Positives (FP): count of samples predicted 1 whose label is not 1.'''
  preds, labels = np.asarray(preds), np.asarray(labels)
  return int(np.sum((preds == 1) & (labels != 1)))

def b_tn(preds, labels):
  '''Returns True Negatives (TN): count of samples predicted 0 whose label is 0.'''
  preds, labels = np.asarray(preds), np.asarray(labels)
  return int(np.sum((preds == 0) & (labels == 0)))

def b_fn(preds, labels):
  '''Returns False Negatives (FN): count of samples predicted 0 whose label is not 0.'''
  preds, labels = np.asarray(preds), np.asarray(labels)
  return int(np.sum((preds == 0) & (labels != 0)))

def b_metrics(preds, labels):
  '''
  Batch-level precision and recall for binary classification.

  Args:
    - preds: scores/logits with shape (batch, num_classes); the predicted class is
      the argmax along axis 1.
    - labels: gold labels (flattened before comparison).

  Returns:
    (precision, recall), where
      - precision   = TP / (TP + FP)
      - recall      = TP / (TP + FN)
    Each value is the string 'nan' when its denominator is 0 (callers filter on
    this sentinel).
  '''
  preds = np.argmax(preds, axis = 1).flatten()
  labels = np.asarray(labels).flatten()
  tp = b_tp(preds, labels)
  fp = b_fp(preds, labels)
  fn = b_fn(preds, labels)
  b_precision = tp / (tp + fp) if (tp + fp) > 0 else 'nan'
  b_recall = tp / (tp + fn) if (tp + fn) > 0 else 'nan'
  return b_precision, b_recall

In [None]:
#@title
# Load the XLNetForSequenceClassification model
model = XLNetForSequenceClassification.from_pretrained(
    'xlnet-base-cased',
    num_labels = 2,
    output_attentions = True,   # the EAR regulariser consumes the attention maps
    output_hidden_states = False,
)

model.config.problem_type = "single_label_classification" #in this way Cross Entropy loss is selected

# Run on GPU. The model is moved before the optimizer is built, as the torch.optim
# docs recommend. The trailing ';' stops Jupyter from printing the whole model repr.
model.cuda();

# Recommended learning rates (Adam): 5e-5, 3e-5, 2e-5. See: https://arxiv.org/pdf/1810.04805.pdf
optimizer = torch.optim.AdamW(model.parameters(),
                              lr = 2e-5,
                              weight_decay = 0.01,
                              )

In [None]:
#@title
device = torch.device('cuda')

# Recommended number of epochs: 2, 3, 4. See: https://arxiv.org/pdf/1810.04805.pdf
epochs = 2
# Weight of the EAR term added to the classification loss; set to 0.0 to train
# without regularisation.
reg_strength = 0.01

for _ in trange(epochs, desc = 'Epoch'):

    # ========== Training ==========

    # Set model to training mode
    model.train()

    # Tracking variables
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0

    for batch in train_dataloader:
        b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)
        optimizer.zero_grad()
        # Forward pass. Attentions are returned because the model was created with
        # output_attentions=True; the EAR term below uses them.
        train_output = model(b_input_ids,
                             token_type_ids = None,
                             attention_mask = b_input_mask,
                             labels = b_labels)

        neg_entropy = compute_negative_entropy(
            inputs=train_output.attentions,
            attention_mask=b_input_mask
        )
        reg_loss = reg_strength * neg_entropy
        loss = train_output.loss + reg_loss

        # Backward pass
        loss.backward()
        optimizer.step()
        # Update tracking variables
        tr_loss += loss.item()
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1

    # ========== Validation ==========

    # Set model to evaluation mode
    model.eval()

    # Per-batch metrics. NOTE: the reported precision/recall are means of per-batch
    # values, which approximates (but is not equal to) dataset-level precision/recall.
    val_precision = []
    val_recall = []

    for batch in validation_dataloader:
        b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)
        with torch.no_grad():
            # Attention maps are not needed for evaluation; skip materialising them.
            eval_output = model(b_input_ids,
                                token_type_ids = None,
                                attention_mask = b_input_mask,
                                output_attentions = False)
        logits = eval_output.logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        # Calculate validation metrics
        b_precision, b_recall = b_metrics(logits, label_ids)
        # b_metrics returns the string 'nan' when a denominator is 0; skip those batches
        if b_precision != 'nan': val_precision.append(b_precision)
        if b_recall != 'nan': val_recall.append(b_recall)

    print('\n\t - Train loss: {:.4f}'.format(tr_loss / nb_tr_steps))
    # Guard each division before performing it; empty lists or P + R == 0 would
    # otherwise raise ZeroDivisionError. `val_f1` avoids shadowing sklearn's f1_score.
    precision = sum(val_precision) / len(val_precision) if val_precision else None
    recall = sum(val_recall) / len(val_recall) if val_recall else None
    if precision is not None and recall is not None and (precision + recall) != 0:
        val_f1 = 2 * ((precision * recall) / (precision + recall))
    else:
        val_f1 = None
    print('\t - Validation Precision: {:.4f}'.format(precision) if precision is not None else '\t - Validation Precision: NaN')
    print('\t - Validation Recall: {:.4f}'.format(recall) if recall is not None else '\t - Validation Recall: NaN')
    print('\t - Validation F1-score: {:.4f}'.format(val_f1) if val_f1 is not None else '\t - Validation F1-score: NaN')

In [None]:
#@title
# Persist only the fine-tuned weights (state dict) to Google Drive.
# Relative path: assumes the working directory is /content, where Drive is mounted.
model_save_name = 'FINAL_xlnet_ear_reg_0_01_.bin'
path = "drive/My Drive/MODELS/" + model_save_name
model_state = model.state_dict()
torch.save(model_state, path)

In [None]:
#LOAD MODEL
from google.colab import drive
import pandas as pd
# Silence pandas' SettingWithCopyWarning for the rest of the session.
pd.options.mode.chained_assignment = None
drive.mount('/content/drive')

# Fine-tuned weights written by the save cell.
model_save_name = 'FINAL_xlnet_ear_reg_0_01_.bin'
path = f"drive/My Drive/MODELS/{model_save_name}"

# Rebuild the architecture, then overwrite its weights with the saved state dict.
model = XLNetForSequenceClassification.from_pretrained(
    'xlnet-base-cased',
    num_labels=2,
    output_attentions=True,
    output_hidden_states=False,
)
device = torch.device('cuda')
# NOTE: torch.load unpickles the file; only load checkpoints you produced yourself.
model.load_state_dict(torch.load(path, map_location=device))
model.to(device)
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased', do_lower_case=True)

def preprocessing(input_text, tokenizer):
  """Tokenize one text into a fixed-length (250-token) encoding.

  Returns the tokenizer's encoding with 'input_ids' and 'attention_mask' as
  PyTorch tensors. `padding='max_length'` replaces the deprecated
  `pad_to_max_length=True` and produces the same padding.
  """
  return tokenizer.encode_plus(
                        input_text,
                        add_special_tokens = True,
                        max_length = 250,
                        truncation = True,
                        padding = 'max_length',
                        return_attention_mask = True,
                        return_tensors = 'pt'
                   )

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.bias', 'logits_proj.weight', 'sequence_summary.summary.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

In [None]:
def classifier(comment:str):
  """Classify a single comment with the global fine-tuned `model`.

  Args:
    comment: raw text, encoded with `preprocessing` (padded/truncated to 250 tokens).

  Returns:
    'TOXIC' when the argmax over the two logits is index 1, otherwise 'NON-TOXIC'.
  """
  encoding = preprocessing(comment, tokenizer)
  with torch.no_grad():
    # Attention maps are only needed for the EAR loss during training; skip them.
    output = model(encoding['input_ids'].to(device),
                   token_type_ids = None,
                   attention_mask = encoding['attention_mask'].to(device),
                   output_attentions = False)
  return 'TOXIC' if output.logits.argmax(-1).item() == 1 else 'NON-TOXIC'

In [None]:
#@title
# Probe the classifier with neutral / positive / negative sentences per identity.
for group_index, identity in enumerate(("african", "blind")):
  if group_index:
    print('-')
  for template in ("I am {}", "I love {} people", "I hate {} people"):
    print(classifier(template.format(identity)))

## Compute F1-score on test dataset

In [None]:
from sklearn.metrics import f1_score
# Encode the classifier's verdicts as 1 (TOXIC) / 0 (NON-TOXIC) to match test_labels.
predictions = [1 if classifier(comment_text) == 'TOXIC' else 0 for comment_text in test_text]
test_f1 = f1_score(test_labels, predictions)
print("F1-score on test dataset (XLNet 0_01): " + str(test_f1))



F1-score on test dataset (XLNet 0_01): 0.6067347249733354


## Synthetic dataset pre-processing

In [None]:
from sklearn.metrics import roc_auc_score
synthetic_test_df = pd.read_csv('/content/drive/My Drive/DATASETS/en_sentence_templates.csv')

# Identity terms grouped by category.
# NOTE(review): "blind", "deaf" and "paralyzed" are disability terms but are grouped
# with the age terms here; the spelling "millenial" is kept as-is (it must match the data).
gender_sub_list_identities = ["lesbian", "gay", "bisexual","transgender", "trans", "queer", "lgbt","lgbtq", "homosexual", "straight", "heterosexual","male", "female", "nonbinary"]
ethnicity_sub_list_identities = ["african","african american", "black", "white", "european","hispanic", "latino", "latina", "latinx", "mexican","canadian", "american", "asian", "indian","middle eastern", "chinese", "japanese"]
religion_sub_list_identities = ["christian","muslim", "jewish", "buddhist", "catholic","protestant", "sikh", "taoist"]
age_sub_list_identities = ["old", "older", "young","younger", "teenage", "millenial", "middle aged","elderly", "blind", "deaf", "paralyzed"]

# All the identities considered in the EAR paper: the four sub-lists, in order.
list_identities = (gender_sub_list_identities + ethnicity_sub_list_identities
                   + religion_sub_list_identities + age_sub_list_identities)
print(len(list_identities))


In [None]:
print(synthetic_test_df.shape[0])
# Drop every template phrase that mentions one of these first names
# (case-sensitive substring match on the 'phrase' column).
# NOTE(review): several entries (e.g. 'Ram√≥n', 'Agust√≠n', 'L√©') look like
# mis-decoded UTF-8; they only match if the CSV is read with the same bytes — TODO verify.
name_drop = ['Êú™Êú™','Zoe','Ziv','Ying','Thel','Sulin','Sooyin','Sohaib','Sara','Santiago','Ram√≥n','Nithum','Myesha','Mukhtar','Morgan','Mohammed','Moeta','Medina','Mateo','Mahir','Lucas','ÏÑúÏú§','Abdul','Agust√≠n','Aishah','Aki','Alva','Anastasiya','Angela','Anita','Baratunde','Binh','Chris','Dlshad','Elizabeth','Emad','Hiro','Hoshi','Ioannis','Javeria','Jim','Joe','L√©','Liam','Ling']
mentions_name = pd.Series(False, index=synthetic_test_df.index)
for name in name_drop:
  mentions_name = mentions_name | synthetic_test_df['phrase'].str.contains(name)
synthetic_test_df = synthetic_test_df[~mentions_name]
print(synthetic_test_df.shape[0])

In [None]:
# Keep only phrases relevant to the age sub-list: drop every phrase that mentions
# any other identity (case-sensitive substring match on 'phrase').
drop_identities = [term for term in list_identities if term not in age_sub_list_identities]
mentions_other_identity = pd.Series(False, index=synthetic_test_df.index)
for identity in drop_identities:
  mentions_other_identity = mentions_other_identity | synthetic_test_df['phrase'].str.contains(identity)
synthetic_test_df = synthetic_test_df[~mentions_other_identity]
print(synthetic_test_df.shape[0])

## Compute AUC metrics

In [None]:
def _toxic_probability(comment: str) -> float:
  """Softmax probability the global `model` assigns to the toxic class (logit index 1)."""
  encoding = preprocessing(comment, tokenizer)
  with torch.no_grad():
    output = model(encoding['input_ids'].to(device),
                   token_type_ids = None,
                   attention_mask = encoding['attention_mask'].to(device),
                   output_attentions = False)
  return output.logits.softmax(-1)[0, 1].item()

def compute_auc(df:pd.DataFrame):
  """ROC-AUC of the model's toxicity scores on the phrases in `df`.

  Args:
    df: frame with a 'phrase' text column and a 'toxicity' column holding the
      strings "toxic" / "nontoxic". It is not modified.

  Returns:
    float: sklearn ``roc_auc_score`` of the toxic-class probability against the
    binary ground truth (toxic = 1).

  Raises:
    ValueError: from ``roc_auc_score`` when `df` contains a single toxicity class.
  """
  # AUC ranks examples by a continuous score; hard 0/1 predictions collapse the
  # ROC curve to a single operating point, so use the class probability instead.
  scores = [_toxic_probability(comment) for comment in df['phrase']]
  truth = df['toxicity'].map({"nontoxic": 0, "toxic": 1})
  return roc_auc_score(truth, scores)

In [None]:
import re

# Subgroup, BPSN and BNSP AUCs for each identity in the age sub-list.
metrics_results = {'subgroup_AUC':[],'BPSN_AUC':[],'BNSP_AUC':[]}
is_toxic = synthetic_test_df['toxicity'] == 'toxic'
is_nontoxic = synthetic_test_df['toxicity'] == 'nontoxic'
for identity in age_sub_list_identities:
  print(identity)
  # Whole-word match: a plain substring test would also put "older" phrases in
  # the "old" subgroup and "younger" phrases in the "young" subgroup.
  in_subgroup = synthetic_test_df['phrase'].str.contains(rf'\b{re.escape(identity)}\b', regex=True)
  # Subgroup AUC: only phrases that mention the identity.
  subgroup_auc_df = synthetic_test_df.loc[in_subgroup].copy()
  score = compute_auc(subgroup_auc_df)
  metrics_results['subgroup_AUC'].append((identity,score))
  # BPSN: non-toxic subgroup phrases + toxic background phrases.
  BPSN_auc_df = synthetic_test_df.loc[(in_subgroup & is_nontoxic) | (~in_subgroup & is_toxic)].copy()
  score = compute_auc(BPSN_auc_df)
  metrics_results['BPSN_AUC'].append((identity,score))
  # BNSP: toxic subgroup phrases + non-toxic background phrases.
  BNSP_auc_df = synthetic_test_df.loc[(in_subgroup & is_toxic) | (~in_subgroup & is_nontoxic)].copy()
  score = compute_auc(BNSP_auc_df)
  metrics_results['BNSP_AUC'].append((identity,score))

In [None]:
def compute_avg_scores(AUC_dict:dict):
  """Print the mean score of every metric in `AUC_dict`.

  Args:
    AUC_dict: maps a metric name (e.g. 'subgroup_AUC') to a list of
      ``(identity, score)`` tuples.

  Prints one line per metric, ``Avg <metric>: <mean rounded to 10 digits>``.
  A metric with no scores is reported as ``nan`` instead of raising
  ZeroDivisionError. Returns None.
  """
  for metric, scored_identities in AUC_dict.items():
    scores = [score for _, score in scored_identities]
    average = sum(scores) / len(scores) if scores else float('nan')
    print('Avg '+metric+': '+str(round(average, 10)))

In [None]:
# Averages of the bias AUCs computed above for the age sub-list.
compute_avg_scores(AUC_dict=metrics_results)

## Compute F1-score on synthetic dataset

In [None]:
from sklearn.metrics import f1_score
synthetic_comments = list(synthetic_test_df['phrase'])
synthetic_labels = synthetic_test_df["toxicity"].map({"nontoxic": 0, "toxic": 1})
# 1 = classifier says TOXIC, 0 otherwise (same encoding as synthetic_labels).
predictions = [1 if classifier(comment) == 'TOXIC' else 0 for comment in synthetic_comments]
# The model loaded in this notebook is FINAL_xlnet_ear_reg_0_01_.bin, hence "0_01".
print("F1-score on synthetic dataset (XLNet 0_01): "+str(f1_score(synthetic_labels,predictions)))

## Compute Avg AUC scores for all identities

In [None]:
# Bias-AUC results hard-coded from earlier runs (not recomputed in this notebook).
# Each metric maps to a list of (identity, score) pairs spanning all identity
# sub-lists, not only the age sub-list evaluated above.
# Presumably produced by the model trained with reg_strength = 0.0 (no EAR) — TODO confirm.
FINAL_xlnet_ear_reg_0_00 = {'subgroup_AUC': [('lesbian', 0.6616541353383458), ('gay', 0.5075187969924813), ('bisexual', 0.5526315789473684), ('transgender', 0.6992481203007519), ('trans', 0.8233082706766917), ('queer', 0.5864661654135338), ('lgbt', 0.9774436090225564), ('lgbtq', 0.9774436090225564), ('homosexual', 0.5), ('straight', 0.9586466165413533), ('heterosexual', 0.6015037593984962), ('male', 0.9812030075187971), ('female', 0.9849624060150376), ('nonbinary', 0.9661654135338346), ('african', 0.9774436090225563), ('african american', 0.981203007518797), ('black', 0.7969924812030076), ('white', 0.9699248120300752), ('european', 0.9736842105263157), ('hispanic', 0.9699248120300752), ('latino', 0.9473684210526316), ('latina', 0.9473684210526316), ('latinx', 0.9661654135338344), ('mexican', 0.9699248120300752), ('canadian', 0.9736842105263157), ('american', 0.981203007518797), ('asian', 0.9135338345864661), ('indian', 0.9661654135338347), ('middle eastern', 0.9548872180451127), ('chinese', 0.9774436090225564), ('japanese', 0.9736842105263157), ('christian', 0.9887218045112782), ('muslim', 0.9736842105263158), ('jewish', 0.8571428571428572), ('buddhist', 0.9774436090225564), ('catholic', 0.93609022556391), ('protestant', 0.969924812030075), ('sikh', 0.9774436090225564), ('taoist', 0.906015037593985), ('old', 0.9699248120300752), ('older', 0.9624060150375939), ('young', 0.9718045112781954), ('younger', 0.9586466165413534), ('teenage', 0.9548872180451128), ('millenial', 0.9736842105263157), ('middle aged', 0.9736842105263158), ('elderly', 0.981203007518797), ('blind', 0.7819548872180451), ('deaf', 0.612781954887218), ('paralyzed', 0.6278195488721805)], 'BPSN_AUC': [('lesbian', 0.6474840948525158), ('gay', 0.4933487565066512), ('bisexual', 0.5384615384615384), ('transgender', 0.6931752458068248), ('trans', 0.8342731829573934), ('queer', 0.5803932909196068), ('lgbt', 0.9884085213032581), ('lgbtq', 0.9875650665124349), ('homosexual', 0.48582995951417), ('straight', 
0.9728166570271833), ('heterosexual', 0.5954308849045691), ('male', 0.9790100250626567), ('female', 0.9829381145170618), ('nonbinary', 0.9843840370156159), ('african', 0.9765664160401002), ('african american', 0.9783834586466165), ('black', 0.7861842105263159), ('white', 0.9671052631578947), ('european', 0.9788533834586466), ('hispanic', 0.9631109022556391), ('latino', 0.9485432330827068), ('latina', 0.9405545112781956), ('latinx', 0.9593515037593985), ('mexican', 0.9790883458646616), ('canadian', 0.9788533834586466), (
    'american', 0.9781954887218045), ('asian', 0.918703007518797), ('indian', 0.9673402255639099), ('middle eastern', 0.9600563909774436), ('chinese', 0.9786184210526316), ('japanese', 0.9788533834586466), ('christian', 0.976906552094522), ('muslim', 0.9747583243823845), ('jewish', 0.8582169709989259), ('buddhist', 0.9785177228786252), ('catholic', 0.9371643394199786), ('protestant', 0.9752953813104188), ('sikh', 0.9785177228786252), ('taoist', 0.9070891514500539), ('old', 0.9809941520467836), ('older', 0.9827067669172933), ('young', 0.9828738512949039), ('younger', 0.9830827067669172), ('teenage', 0.9462406015037593), ('millenial', 0.9815789473684211), ('middle aged', 0.9774436090225563), ('elderly', 0.9808270676691729), ('blind', 0.7691729323308271), ('deaf', 0.5958646616541354), ('paralyzed', 0.6150375939849624)], 'BNSP_AUC': [('lesbian', 0.8016194331983806), ('gay', 0.813475997686524), ('bisexual', 0.8100057836899942), ('transgender', 0.7906304222093695), ('trans', 0.7600250626566416), ('queer', 0.799305957200694), ('lgbt', 0.7343358395989975), ('lgbtq', 0.7530364372469636), ('homosexual', 0.8140543666859457), ('straight', 0.7504337767495661), ('heterosexual', 0.7981492192018507), ('male', 0.7468671679197996), ('female', 0.7646038172353962), ('nonbinary', 0.7458068247541932), ('african', 0.9530075187969925), ('african american', 0.9562969924812029), ('black', 0.9757988721804511), ('white', 0.9570018796992481), ('european', 0.9487781954887217), ('hispanic', 0.9609962406015037), ('latino', 0.9544172932330828), ('latina', 0.962406015037594), ('latinx', 0.9612312030075187), ('mexican', 0.9450187969924813), ('canadian', 0.9487781954887217), ('american', 0.9546365914786967), ('asian', 0.9525375939849624), ('indian', 0.9532424812030076), ('middle eastern', 0.949953007518797), ('chinese', 0.9525375939849625), ('japanese', 0.9487781954887217), ('christian', 0.9543501611170784), ('muslim', 0.943609022556391), ('jewish', 0.9602577873254566), ('buddhist', 
0.9430719656283566), ('catholic', 0.9489795918367347), ('protestant', 0.9398496240601503), ('sikh', 0.9430719656283566), ('taoist', 0.9532760472610098), ('old', 0.8611111111111109), ('older', 0.8624060150375941), ('young', 0.8606934001670843), ('younger', 0.8586466165413534), ('teenage', 0.8921052631578947), ('millenial', 0.8736842105263157), ('middle aged', 0.8778195488721805), ('elderly', 0.8812030075187971), ('blind', 0.9135338345864661), ('deaf', 0.9345864661654135), ('paralyzed', 0.9289473684210525)]}

# Bias-AUC results hard-coded from earlier runs (not recomputed in this notebook).
# Same layout as FINAL_xlnet_ear_reg_0_00: metric -> list of (identity, score) pairs.
# Presumably produced by the model saved as FINAL_xlnet_ear_reg_0_01_.bin
# (reg_strength = 0.01) — TODO confirm.
FINAL_xlnet_ear_reg_0_01 = {'subgroup_AUC': [('lesbian', 0.8533834586466166), ('gay', 0.6992481203007519), ('bisexual', 0.8984962406015038), ('transgender', 0.8984962406015038), ('trans', 0.9135338345864661), ('queer', 0.5939849624060151), ('lgbt', 0.9229323308270676), ('lgbtq', 0.9172932330827067), ('homosexual', 0.5037593984962406), ('straight', 0.8759398496240601), ('heterosexual', 0.868421052631579), ('male', 0.981203007518797), ('female', 0.9924812030075187), ('nonbinary', 0.9135338345864661), ('african', 0.9680451127819549), ('african american', 0.9624060150375939), ('black', 0.9360902255639098), ('white', 0.9060150375939849), ('european', 0.9624060150375939), ('hispanic', 0.8609022556390977), ('latino', 0.9210526315789475), ('latina', 0.9022556390977443), ('latinx', 0.9097744360902256), ('mexican', 0.9060150375939849), ('canadian', 0.9398496240601503), ('american', 0.9736842105263158), ('asian', 0.9285714285714285), ('indian', 0.951127819548872), ('middle eastern', 0.87593984962406), ('chinese', 0.9661654135338346), ('japanese', 0.9285714285714285), ('christian', 0.9774436090225564), ('muslim', 0.9511278195488723), ('jewish', 0.9097744360902255), ('buddhist', 0.9360902255639098), ('catholic', 0.8947368421052632), ('protestant', 0.9248120300751879), ('sikh', 0.9661654135338347), ('taoist', 0.9060150375939849), ('old', 0.8759398496240602), ('older', 0.8646616541353384), ('young', 0.8909774436090225), ('younger', 0.8796992481203008), ('teenage', 0.9548872180451128), ('millenial', 0.8721804511278194), ('middle aged', 0.8796992481203008), ('elderly', 0.9436090225563909), ('blind', 0.9285714285714284), ('deaf', 0.9022556390977443), ('paralyzed', 0.9022556390977443)], 'BPSN_AUC': [('lesbian', 0.8241758241758242), ('gay', 0.6700404858299596), ('bisexual', 0.8692886061307115), ('transgender', 0.8692886061307115), ('trans', 0.919172932330827), ('queer', 0.5647773279352226), ('lgbt', 0.9548872180451128), ('lgbtq', 0.9609600925390399), ('homosexual', 
0.47455176402544824), ('straight', 0.9762868710237129), ('heterosexual', 0.8392134181607868), ('male', 0.9495614035087718), ('female', 0.9632735685367264), ('nonbinary', 0.9774436090225564), ('african', 0.9106516290726817), ('african american', 0.9046052631578948), ('black', 0.8982612781954887), ('white', 0.9241071428571428), ('european', 0.9405545112781953), ('hispanic', 0.9508928571428572), ('latino', 0.90718984962406), ('latina', 0.9363251879699248), ('latinx', 0.9039003759398496), ('mexican', 0.9480733082706767), (
    'canadian', 0.9419642857142856), ('american', 0.918421052631579), ('asian', 0.9306860902255639), ('indian', 0.9412593984962405), ('middle eastern', 0.9419642857142857), ('chinese', 0.9403195488721803), ('japanese', 0.9426691729323308), ('christian', 0.9215896885069818), ('muslim', 0.9210526315789475), ('jewish', 0.9355531686358753), ('buddhist', 0.9360902255639098), ('catholic', 0.9591836734693877), ('protestant', 0.9334049409237379), ('sikh', 0.9360902255639099), ('taoist', 0.9232008592910849), ('old', 0.9342105263157895), ('older', 0.9315789473684211), ('young', 0.9331662489557226), ('younger', 0.9300751879699248), ('teenage', 0.8977443609022557), ('millenial', 0.9266917293233082), ('middle aged', 0.8969924812030075), ('elderly', 0.9236842105263158), ('blind', 0.88796992481203), ('deaf', 0.8285714285714285), ('paralyzed', 0.8409774436090225)], 'BNSP_AUC': [('lesbian', 0.8744939271255061), ('gay', 0.8863504916136495), ('bisexual', 0.8710237131289763), ('transgender', 0.8710237131289763), ('trans', 0.8289473684210527), ('queer', 0.8944476576055523), ('lgbt', 0.8010651629072681), ('lgbtq', 0.7967032967032966), ('homosexual', 0.9013880855986119), ('straight', 0.7432041642567958), ('heterosexual', 0.8733371891266628), ('male', 0.8549498746867168), ('female', 0.8637941006362059), ('nonbinary', 0.7767495662232503), ('african', 0.9827067669172931), ('african american', 0.9861372180451128), ('black', 0.9678101503759398), ('white', 0.9137687969924813), ('european', 0.950187969924812), ('hispanic', 0.8446898496240601), ('latino', 0.9447838345864662), ('latina', 0.8980263157894736), ('latinx', 0.9375), ('mexican', 0.8898026315789473), ('canadian', 0.9276315789473684), ('american', 0.9798245614035087), ('asian', 0.9283364661654134), ('indian', 0.9389097744360902), ('middle eastern', 0.8677161654135337), ('chinese', 0.9539473684210527), ('japanese', 0.9163533834586467), ('christian', 0.9828141783029002), ('muslim', 0.9607948442534909), ('jewish', 0.9108485499462943), 
('buddhist', 0.9328678839957036), ('catholic', 0.8743286788399569), ('protestant', 0.9258861439312566), ('sikh', 0.9586466165413533), ('taoist', 0.9199785177228785), ('old', 0.8489974937343359), ('older', 0.8383458646616542), ('young', 0.8617376775271512), ('younger', 0.8533834586466167), ('teenage', 0.9533834586466166), ('millenial', 0.85), ('middle aged', 0.8864661654135338), ('elderly', 0.9172932330827068), ('blind', 0.9394736842105262), ('deaf', 0.975187969924812), ('paralyzed', 0.9627819548872181)]}

# Compare the average bias AUCs of the two recorded runs.
for run_name, run_scores in (("FINAL_xlnet_ear_reg_0_00", FINAL_xlnet_ear_reg_0_00),
                             ("FINAL_xlnet_ear_reg_0_01", FINAL_xlnet_ear_reg_0_01)):
  print(run_name)
  compute_avg_scores(run_scores)