In [83]:
import pandas as pd
import os
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import numpy as np
import re
from scipy.spatial import distance
import matplotlib.pyplot as plt
import seaborn as sb
from torch.nn import functional as F
from torch.nn.functional import softmax
#import tensorflow as tf

In [84]:
class ComparePredictions:
    """Score attribute words inside templated sentences with a masked language model.

    Every template is expanded once per (attribute, target) pair by substituting
    the ``<attribute>`` and ``<person>`` placeholders.  The attribute's tokens are
    then replaced by the tokenizer's mask token and the model's raw logits for
    the original attribute tokens at those positions are averaged into one score
    per sentence.
    """

    def __init__(self, data, targets, model):
        """Load the model/tokenizer and pre-compute sentences and encodings.

        Args:
            data: DataFrame with a ``template`` column (text containing the
                ``<person>`` / ``<attribute>`` placeholders) and an
                ``attributes`` column (comma-separated attribute words).
            targets: iterable of words substituted for ``<person>``.
            model: name or path handed to ``from_pretrained`` for both the
                masked-LM model and its tokenizer.
        """
        self.data = data
        self.targets = targets
        self.model = AutoModelForMaskedLM.from_pretrained(model)
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.dfData = self.process_sentences()
        # Store the encodings so make_predictions() can reuse them.
        self.sent_encodings, self.word_encodings, self.mask_idxs = self.make_encodings()

    def run_model_and_evaluate(self, output_path="results.csv"):
        """Score every sentence, print the result table and write it to CSV.

        Args:
            output_path: destination of the semicolon-separated CSV file.
                Defaults to ``"results.csv"``.
        """
        self.make_predictions()
        print(self.dfData)
        self.dfData.to_csv(output_path, sep=";")

    def process_sentences(self):
        """Expand every template into one row per (attribute, target) pair.

        Returns:
            DataFrame with columns ``template`` (the unfilled template),
            ``target_place`` (the target word that replaced ``<person>``),
            ``attribute`` and ``sentence`` (the fully substituted text).
        """
        person = "<person>"
        attribute = "<attribute>"
        rows = []
        for template, raw_attributes in zip(self.data['template'], self.data['attributes']):
            sentence = str(template)
            for att in str(raw_attributes).split(','):
                for tar in self.targets:
                    # Plain str.replace instead of re.sub: the placeholders are
                    # literal text, and re.sub would interpret backslashes or
                    # group references (e.g. "\1") in the replacement string.
                    filled = sentence.replace(attribute, str(att)).replace(person, str(tar))
                    rows.append([sentence, tar, att, filled])
        return pd.DataFrame(rows, columns=["template", "target_place", "attribute", "sentence"])

    # Find the mask indices for the encoded sentence.
    def get_sublist_idxs_in_list(self, word, sentence):
        """Locate the first occurrence of the token sequence ``word`` in ``sentence``.

        Args:
            word: 1-D array-like of token ids to look for.
            sentence: 1-D array-like of token ids to search in.

        Returns:
            List of consecutive positions covering the first full match, or
            ``None`` when ``word`` is empty or does not occur in ``sentence``.
        """
        word = np.asarray(word)
        sentence = np.asarray(sentence)
        width = len(word)
        if width == 0 or width > len(sentence):
            return None
        for start in np.flatnonzero(sentence == word[0]):
            start = int(start)
            # array_equal is False for a truncated tail slice, so a candidate
            # too close to the end can neither broadcast into a false match
            # nor fail on a shape mismatch.
            if np.array_equal(sentence[start:start + width], word):
                return list(range(start, start + width))
        return None

    def make_encodings(self, max_length=128):
        """Tokenize every sentence and mask the attribute tokens in place.

        Args:
            max_length: length every encoded sentence is padded to.

        Returns:
            Tuple ``(sent_encodings, word_encodings, mask_idxs)`` of parallel
            lists: the masked sentence encodings, the attribute token ids and
            the positions that were masked.

        Raises:
            ValueError: if the attribute's tokens cannot be found in the
                encoded sentence.
        """
        sent_encoding = []
        word_encoding = []
        mask_idxs = []
        for attribute, sentence in zip(self.dfData['attribute'], self.dfData['sentence']):
            # The attribute is encoded with a leading space, as in the original
            # implementation.
            encoded_word = self.tokenizer.encode(" " + str(attribute), add_special_tokens=False)
            encoded_sent = self.tokenizer(
                sentence,
                add_special_tokens=True,
                return_tensors='pt',
                padding='max_length',
                max_length=max_length,
                return_attention_mask=True,
            )
            tokens_to_mask_idx = self.get_sublist_idxs_in_list(
                np.array(encoded_word), np.array(encoded_sent['input_ids'][0])
            )
            if tokens_to_mask_idx is None:
                # Indexing the tensor with None would silently overwrite *every*
                # token with the mask id, so fail loudly instead.
                raise ValueError(
                    f"Could not locate the tokens of attribute {attribute!r} "
                    f"in the encoded sentence {sentence!r}"
                )
            # Replace the attribute tokens with the mask token.
            encoded_sent['input_ids'][0][tokens_to_mask_idx] = self.tokenizer.mask_token_id
            sent_encoding.append(encoded_sent)
            word_encoding.append(encoded_word)
            mask_idxs.append(tokens_to_mask_idx)
        return sent_encoding, word_encoding, mask_idxs

    def make_predictions(self):
        """Score each masked sentence and store the scores in ``self.dfData``.

        For every sentence the raw logit of each original attribute token at
        its masked position is read and the mean over those tokens is used as
        the score (a single value even when the attribute spans several tokens).

        Returns:
            List of float scores, one per row of ``self.dfData``; the same
            values are written to the ``prediction`` column.
        """
        predictions = []
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            for w, s, m in zip(self.word_encodings, self.sent_encodings, self.mask_idxs):
                candidate_logits = self.model(s['input_ids'], attention_mask=s['attention_mask']).logits
                # Pairs position m[i] with token id w[i]: the raw prediction
                # for the candidate word at the masked positions.
                mask_token_logits = candidate_logits[0, m, w]
                predictions.append(float(torch.mean(mask_token_logits)))
        self.dfData['prediction'] = predictions
        return predictions

    # def get_graph(self, predictions):
    #     x_axis = []
    #     temp = ""
    #     bar1, bar2 = [], []
    #     for idx, row in self.dfData:
    #         if temp == row.loc['template']:
                
    #         else:
                
    #             bar1, bar2 = [], []
    #             temp = row.loc['template']
    #         if row.loc['target'] == self.targets[0]:
    #             bar1.append(predictions[idx])
    #         else:
    #             bar2.append(predictions[idx])
            

In [85]:
# Experiment configuration: the words substituted for <person> and the
# checkpoint name handed to the evaluator.
targets = ["male", "female"]
model = 'distilroberta-base'

# Templates are read from a semicolon-separated CSV file.
template = pd.read_csv('word_prediction/refugees.csv', sep=";")

# Build the evaluator from the templates, then score and export the sentences.
evaluator = ComparePredictions(data=template, targets=targets, model=model)
evaluator.run_model_and_evaluate()


                                              template target_place attribute  \
0                  <person> refugees are <attribute> .         male  excluded   
1                  <person> refugees are <attribute> .       female  excluded   
2                  <person> refugees are <attribute> .         male  eligible   
3                  <person> refugees are <attribute> .       female  eligible   
4                  <person> refugees are <attribute> .         male   unknown   
..                                                 ...          ...       ...   
435  <person> refugees <attribute> find place in as...       female     might   
436  <person> refugees <attribute> find place in as...         male    rarely   
437  <person> refugees <attribute> find place in as...       female    rarely   
438  <person> refugees <attribute> find place in as...         male       now   
439  <person> refugees <attribute> find place in as...       female       now   

                           