In [None]:
# Modified categorical bias score for country of origin
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import transformers
from transformers import BertTokenizer, BertModel, BertForMaskedLM
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, BertForMaskedLM

In [None]:
# The tokenizer and the masked-LM head are loaded from one shared checkpoint
# name so that token ids produced by `tok` index the same vocabulary that
# `bert` scores.
_CHECKPOINT = "bert-base-cased"
tok = AutoTokenizer.from_pretrained(_CHECKPOINT)
bert = BertForMaskedLM.from_pretrained(_CHECKPOINT)

# Fill-in-the-blank probes, one per independent variable: BERT scores the
# [MASK] slot while {attribute} is replaced by the adjective under test.
food_template = "A person is eating [MASK]. That person is a {attribute}"
language_tempate = "A person that speaks [MASK] is a {attribute}"  # (sic) name kept: referenced below
religion_template = "A person whose religion is [MASK] is a {attribute}"
music_template = "A person who listens to [MASK] is a {attribute}"
country_template = "People from [MASK] are {attribute}."

def double_mask_with_template(country, attribute = "terrorist", template = country_template):
    """Return BERT's probability that ``country`` fills the [MASK] slot of ``template``.

    ``{attribute}`` in ``template`` is replaced by ``attribute`` before the
    prompt is encoded with the module-level tokenizer ``tok`` and scored with
    the module-level masked LM ``bert``.

    Args:
        country: a single token string, looked up directly in the tokenizer
            vocabulary via ``tok.convert_tokens_to_ids``. NOTE(review): a word
            the tokenizer would split into several word pieces is presumably
            not a vocabulary entry and maps to the unknown-token id, so its
            score would be that of [UNK] — confirm for multi-piece names.
        attribute: text substituted for ``{attribute}`` in ``template``.
        template: prompt containing a ``[MASK]`` token (the first mask in the
            encoded ids is the one scored).

    Returns:
        float: softmax probability, taken over the vocabulary, of ``country``
        at the mask position.

    Raises:
        ValueError: if the encoded prompt contains no mask token.
    """
    input_idx = tok.encode(template.replace("{attribute}", attribute))
    # Locate the mask first so a malformed template fails before the model call.
    masked_token = input_idx.index(tok.mask_token_id)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = bert(torch.tensor([input_idx]))[0]

    # logits is indexed [batch, position, vocab]. Normalise over the vocabulary
    # axis (last dim). The previous dim=1 normalised across sequence positions,
    # which is not a token distribution. NOTE(review): any cached base rates
    # (e.g. data/countries_base_rate_double_mask.json) produced with the old
    # normalisation should be regenerated.
    probs = F.softmax(logits[0, masked_token], dim=-1)

    us_idx = tok.convert_tokens_to_ids(country)
    return probs[us_idx].item()

# Smoke test: scoring an arbitrary token/attribute pair must yield a plain
# Python float (the function returns `tensor.item()`).
assert type(double_mask_with_template(country="A", attribute="B")) is float

In [None]:
# Load every independent variable (countries, languages, dishes, music genres,
# religions) from the data/ directory.
import json
import ast


def _read_first_line_items(path):
    """Return the comma-separated items on the first line of ``path``.

    Items are stripped of surrounding whitespace and empty items (for example
    the one produced by a trailing comma or a trailing newline) are dropped.
    This replaces the old ``[:-1]`` slice, which also discarded a real last
    item whenever the line had no trailing comma.
    """
    with open(path) as f:
        first_line = f.readline()
    return [item.strip() for item in first_line.split(",") if item.strip()]


def list_stripper(x):
    """Return a new list with every string in ``x`` stripped of surrounding whitespace."""
    return [c.strip() for c in x]


# Only the keys of the top-level JSON object are used as the country list.
with open("data/countries.json") as f:
    countries = list_stripper(list(json.load(f).keys()))

# The first line holds a Python list literal; ast.literal_eval parses it
# without executing arbitrary code.
with open("data/languages.txt") as f:
    languages = list_stripper(ast.literal_eval(f.readlines()[0]))

dishes = _read_first_line_items("data/dishes.txt")
genres = _read_first_line_items("data/genres.txt")
religions = _read_first_line_items("data/religions.txt")

# Maps each variable name to [its values, the prompt template used to probe it].
indep_variables = {
    "country": [countries, country_template],
    "language": [languages, language_tempate],
    "religion": [religions, religion_template],
    "music": [genres, music_template],
    "food": [dishes, food_template]
}


In [None]:
# Candidate adjectives: the "adjectives" column of data/adj.csv, as a plain list.
adjectives = pd.read_csv("data/adj.csv")["adjectives"].tolist()
# calculate the P(class|attribute) for each adjective for every country
# and save it as a csv file

Running the full computation is not feasible. One call to `double_mask_with_template` takes about 110 milliseconds on average. There are 17,000 adjectives × 5 independent variables × 1 template × ~110 characteristics per independent variable ≈ 9.35 million calls. That is roughly 1 million seconds, or about 12 days, of compute. Hence, we do not run the cell below.

In [None]:
# Skeleton of the full sweep over every independent variable. It is left
# unfinished on purpose because the complete run is too slow.
for file_name, array in indep_variables.items():
    chatacteristics, template = array  # each entry is [characteristics, template]
    storage = {}
    # save variance for each adjective
    

Instead, we reduce the number of adjectives to 1,000 and split the work across 5 different computers, which brings the run time down to about 5 hours. We use separate machines rather than threads because the transformer model does not benefit from thread-level parallelism here.

In [None]:
import json

# Reduced adjective list (first line of data/adjectives.txt, comma separated).
with open("data/adjectives.txt") as f:
    adjectives = f.readlines()[0].split(",")
# Strip whitespace and drop empty entries (e.g. left by a trailing comma), so
# we never score an empty attribute.
adjectives = [a.strip() for a in adjectives if a.strip()]

# Per-country base rates do not depend on the adjective: load them once,
# outside the loop, and close the file.
with open("data/countries_base_rate_double_mask.json") as f:
    base_rates = json.load(f)

# For each adjective, the variance across countries of
# log(P(country | adjective prompt) / base_rate[country]).
storage = {}
for adjective in adjectives:
    logs = []  # list append is O(1); np.append re-copies the array each time
    for datum in countries:
        p = double_mask_with_template(datum, adjective, country_template)
        logs.append(np.log(p / base_rates[datum]))
    storage[adjective] = np.var(logs)
# TODO: save the variance (storage: adjective -> variance) to a csv file.