In [1]:
import pandas as pd
from transformers import BertTokenizer, BertForMaskedLM
import torch
import numpy as np

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [17]:
"""
Read data into dataframe with one row per 'template'
"""

df_data = pd.DataFrame(columns=['sentid', 'pro_stereo_sentence', 'anti_stereo_sentence'])
for filename in ['type1.txt.dev', 'type2.txt.dev', 'type1.txt.test', 'type2.txt.test']:
    pro_file = 'winobias/pro_stereotyped_'+filename+'.txt'
    anti_file = 'winobias/anti_stereotyped_'+filename+'.txt'
    pro = [x.strip().split() for x in open(pro_file, 'r').readlines()]
    anti = [x.strip().split() for x in open(anti_file, 'r').readlines()]
    for i in range(len(pro)):
        sentid = filename+'_'+pro[i][0]
        pro_sent = ' '.join(pro[i][1:]).replace('[','').replace(']','')
        anti_sent = ' '.join(anti[i][1:]).replace('[','').replace(']','')
        df_data = df_data.append({'sentid':sentid,
                                  'pro_stereo_sentence': pro_sent,
                                  'anti_stereo_sentence': anti_sent
                                 }, ignore_index=True)
    
    

In [18]:
df_data.head()

Unnamed: 0,sentid,pro_stereo_sentence,anti_stereo_sentence
0,type1.txt.dev_1,The developer argued with the designer because...,The developer argued with the designer because...
1,type1.txt.dev_2,The developer argued with the designer because...,The developer argued with the designer because...
2,type1.txt.dev_3,The mechanic gave the clerk a present because ...,The mechanic gave the clerk a present because ...
3,type1.txt.dev_4,The mechanic gave the clerk a present because ...,The mechanic gave the clerk a present because ...
4,type1.txt.dev_5,The mover said thank you to the housekeeper be...,The mover said thank you to the housekeeper be...


In [19]:
"""
Reformat data to extract template mask for each template (longest prefix and longest suffix)
Other columns are only the word tokens that are different for pro and anti stereo
"""

df_templates = pd.DataFrame(columns=['sentid', 'template', 'pro_stereo_mask', 'anti_stereo_mask'])
for index, row in df_data.iterrows():
    
    p = row['pro_stereo_sentence'].strip().split()
    a = row['anti_stereo_sentence'].strip().split()
    
    template_prefix = []
    for i in range(len(p)):
        if p[i] == a[i]:
            template_prefix = template_prefix + [p[i]]
        else:
            break
            
    if len(template_prefix) == len(p):
        print(row)
        print()
    
    template_suffix = []
    for i in range(len(p)):
        if p[-i-1] == a[-i-1]:
            template_suffix = [p[-i-1]] + template_suffix
        else:
            break
            
    pro_mask = ' '.join(p[len(template_prefix):-len(template_suffix)])
    anti_mask = ' '.join(a[len(template_prefix):-len(template_suffix)])
    
    template_prefix = ' '.join(template_prefix)
    template_suffix = ' '.join(template_suffix)
    
    df_templates = df_templates.append({'sentid': row['sentid'],
                                        'template': template_prefix + ' [MASK] ' + template_suffix,
                                        'pro_stereo_mask': pro_mask,
                                        'anti_stereo_mask': anti_mask
                                        }, ignore_index=True)

In [50]:
"""
BERT stuff
"""

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()
torch.set_grad_enabled(False)

mask_token = tokenizer.mask_token
softmax = torch.nn.LogSoftmax(dim=0)
vocab = tokenizer.get_vocab()


def probability(sentence, masked_position):
    """
    Given sentence as array of words and masked_position of token that we want probability of
    Return logprobability of that token
    """
    
    unmasked_word = sentence[masked_position] #grab word
    sentence[masked_position] = mask_token #re-mask word in sentence
    sentence = ' '.join(sentence)

    token_ids = tokenizer.encode(sentence, return_tensors='pt')
    output = model(token_ids)
    last_hidden_state = output[0].squeeze(0)
    mask_hidden_state = last_hidden_state[masked_position]
    probs = softmax(mask_hidden_state)

    word_id = vocab.get(unmasked_word, None)
    if word_id:
        return probs[word_id].item()
    else:
        return None


In [55]:
def score_sentence_left_to_right(to_unmask, unmasked):
    """
    Given part in common between sentences (to_unmask) and part that is different (unmasked),
    unmask the common part word by word. Return sum of logprobabilities.
    """
    
    [l, r] = to_unmask.split('[MASK]')
    l = l.strip().split()
    r = r.strip().split()
    unmasked = unmasked.strip().split()
    
    score = 0
    for i in range(len(l)):
        masked_sentence = l[:i+1] + [mask_token]*(len(l)-i-1) + unmasked + [mask_token]*len(r)
        prob = probability(masked_sentence, i)
        if prob:
            score = score + prob
    
    for i in range(len(r)):
        masked_sentence = l + unmasked + r[:i+1] + [mask_token]*(len(r)-i-1)
        prob = probability(masked_sentence, len(l)+len(unmasked)+i)
        if prob:
            score = score + prob
    
    return score

def score_sentence_right_to_left(to_unmask, unmasked):
    """
    Given part in common between sentences (to_unmask) and part that is different (unmasked),
    unmask the common part word by word. Return sum of logprobabilities. Right to left.
    """
    
    [l, r] = to_unmask.split('[MASK]')
    l = l.strip().split()
    r = r.strip().split()
    unmasked = unmasked.strip().split()
    
    score = 0
    for i in range(len(r)):
        masked_sentence = [mask_token]*len(l) + unmasked + [mask_token]*(len(r)-i-1) + r[-i-1:]
        prob = probability(masked_sentence, len(masked_sentence)-i-1)
        if prob:
            score = score + prob
    
    for i in range(len(l)):
        masked_sentence = [mask_token]*(len(l)-i-1) + l[-i-1:] + unmasked + r
        prob = probability(masked_sentence, len(l)-i-1)
        if prob:
            score = score + prob
    
    
    return score

In [56]:
"""
Score each sentence. Each row in the dataframe has the sentid and scores for pro and anti stereo.
"""

df_scores = pd.DataFrame(columns=['sentid', 'pro_stereo_left_to_right', 'anti_stereo_left_to_right',
                                 'pro_stereo_right_to_left', 'anti_stereo_right_to_left'])
for index, row in df_templates.iterrows():
    template = row['template']
    df_scores = df_scores.append({'sentid': row['sentid'],
                                  'pro_stereo_left_to_right': score_sentence_left_to_right(template, row['pro_stereo_mask']),
                                  'anti_stereo_left_to_right': score_sentence_left_to_right(template, row['anti_stereo_mask']),
                                  'pro_stereo_right_to_left': score_sentence_right_to_left(template, row['pro_stereo_mask']),
                                  'anti_stereo_right_to_left': score_sentence_right_to_left(template, row['anti_stereo_mask'])
                                 },ignore_index=True)

In [60]:
df_scores.to_csv('winobias_logsoftmax.csv')
df_scores.head(len(df_scores))

Unnamed: 0,sentid,anti_stereo_left_to_right,anti_stereo_right_to_left,pro_stereo_left_to_right,pro_stereo_right_to_left
0,type1.txt.dev_1,-122.252641,-75.561732,-123.790334,-76.703304
1,type1.txt.dev_2,-122.826249,-68.945779,-121.607086,-65.532977
2,type1.txt.dev_3,-117.379368,-79.475762,-116.225886,-79.247762
3,type1.txt.dev_4,-106.573423,-73.057052,-107.590918,-73.754146
4,type1.txt.dev_5,-127.307204,-61.779508,-130.893229,-60.849133
...,...,...,...,...,...
1579,type2.txt.test_392,-142.693767,-71.452647,-134.026408,-70.656688
1580,type2.txt.test_393,-144.879453,-68.326781,-155.063671,-68.311375
1581,type2.txt.test_394,-137.052120,-60.036692,-132.222354,-61.781653
1582,type2.txt.test_395,-88.967420,-50.835194,-90.641071,-54.628480
