In [0]:
!pip install transformers
!pip install torch

In [0]:
# Experiment switches: which masked-LM family to probe, and (for BERT only)
# whether to load the cased checkpoint.
model_type = 'roberta'
cased_model = True
# BERT writes its mask placeholder as '[MASK]'; any other model type gets '<mask>'.
mask_token = {'bert': '[MASK]'}.get(model_type, '<mask>')

In [0]:
# Entities to compare. The list position is the ground truth used by the probe:
# an entity that appears earlier is expected to take a `bigger` word when it is
# placed in front of a later one. Exactly one `order = ...` line should be active.
order = [
  'The Bible',
  'The Divine Comedy',
  'The Wealth of Nations',
  'Faust',
  'Moby Dick',
  'Brave New World',
  'To Kill a Mockingbird',
  'Harry Potter',
]

# Alternative entity lists (uncomment one and pick matching comparison words below):
# order = ['Jimmy Carter', 'Donald Trump', 'Barack Obama', 'Tiger Woods', 'Serena Williams', 'Ariana Grande', 'Kylie Jenner']
# order = ['falcon', 'cheetah', 'swordfish','antelope', 'lion', 'kangaroo', 'dog', 'pig', 'cow', 'hedgehog', 'snail']
# order = ['Ferrari', 'Porsche', 'Audi', 'VW', 'bike']
# order = ['the Sun', 'Jupiter', 'Saturn', 'Uranus', 'Neptune', 'Earth', 'Venus', 'Mars', 'Mercury' ,'Pluto']
# ordered by population:
# order = ['China', 'Indonesia', 'Brazil', 'Japan', 'Egypt', 'Germany', 'Italy', 'Argentina', 'Australia', 'Chile', 'Belgium', 'Sweden', 'Denmark', 'Ireland', 'Slovenia', 'Malta']
# order = ['New York City', 'Los Angeles', 'Chicago', 'Houston', 'Phoenix', 'Philadelphia', 'San Antonio', 'San Diego', 'Dallas', 'San Jose', 'Austin', 'San Francisco', 'Seattle', 'Boston', 'Detroit', 'Portland', 'Las Vegas', 'Atlanta', 'Miami', 'New Orleans']
# order = ['nail', 'pen', 'laptop', 'table', 'house', 'airplane', 'city', 'sun'] # reproduce olmpics
# order = ['Russia', 'Canada', 'China', 'Brazil','Australia', 'India', 'Argentina', 'Kazakhstan', 'Algeria', 'Saudi Arabia', 'Mexico', 'Indonesia', 'Turkey', 'France', 'Italy', 'Ireland', 'Belgium', 'Monaco']

# Comparison words scored at the mask position. `bigger` holds the words that are
# correct when the first entity of a probe precedes the second in `order`;
# `smaller` holds the opposite words. The first element of each list is the
# label used when reporting results.
# Alternative word sets:
# bigger = ['faster']
# smaller = ['slower']
# bigger = ['bigger', 'larger', 'greater', 'more', 'older']
# smaller = ['smaller', 'less', 'younger']
# bigger = ['older']
# smaller = ['younger']
bigger = ['before']
smaller = ['after']

In [0]:
# Build one entry per unordered pair of entities (earlier one first), holding the
# pair itself plus a forward probe sentence and the same sentence with the two
# entities swapped. Each sentence contains exactly one mask slot.
# No [CLS]/[SEP] markers are added to the sentences here.
prefix = ''  # text placed directly in front of every entity name (currently nothing)
pairs = []
for i, e1 in enumerate(order):
  for e2 in order[i + 1:]:
    probe = f'{prefix}{e1} was published {mask_token} {prefix}{e2} .'
    rev_probe = f'{prefix}{e2} was published {mask_token} {prefix}{e1} .'
    # Upper-case only the very first character; the rest of the sentence is kept as is.
    probe = probe[0].upper() + probe[1:]
    rev_probe = rev_probe[0].upper() + rev_probe[1:]
    pairs.append(((e1, e2), (probe, rev_probe)))
pairs[:5]

[(('The Bible', 'The Divine Comedy'),
  ('The Bible was published <mask> The Divine Comedy .',
   'The Divine Comedy was published <mask> The Bible .')),
 (('The Bible', 'The Wealth of Nations'),
  ('The Bible was published <mask> The Wealth of Nations .',
   'The Wealth of Nations was published <mask> The Bible .')),
 (('The Bible', 'Faust'),
  ('The Bible was published <mask> Faust .',
   'Faust was published <mask> The Bible .')),
 (('The Bible', 'Moby Dick'),
  ('The Bible was published <mask> Moby Dick .',
   'Moby Dick was published <mask> The Bible .')),
 (('The Bible', 'Brave New World'),
  ('The Bible was published <mask> Brave New World .',
   'Brave New World was published <mask> The Bible .'))]

In [0]:
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM, RobertaForMaskedLM, RobertaTokenizer
import numpy as np

# from: https://huggingface.co/transformers/quickstart.html#bert-example
# Load the tokenizer and the masked-LM model for the configured model family.
if model_type == 'bert':
  # `cased_model` is only consulted for BERT; the RoBERTa branch always loads 'roberta-large'.
  bert_model = 'bert-large-cased' if cased_model else 'bert-large-uncased'
  tokenizer = BertTokenizer.from_pretrained(bert_model)
  model = BertForMaskedLM.from_pretrained(bert_model)
elif model_type == 'roberta':
  tokenizer = RobertaTokenizer.from_pretrained('roberta-large')
  model = RobertaForMaskedLM.from_pretrained('roberta-large')
else:
  # Fail here with a clear message instead of a NameError on `model` just below.
  raise ValueError(f"Unsupported model_type {model_type!r}; expected 'bert' or 'roberta'")
# Put the model in evaluation mode before running inference.
model.eval()
print(f'{model_type} ready')

roberta ready


In [0]:
# Map every comparison word to the vocabulary id whose score is read at the mask position.
all_adj = bigger + smaller
# For RoBERTa the words are looked up with a leading 'Ġ', the marker its tokenizer
# puts on tokens that follow a space (the mask always sits mid-sentence).
proc_adj = ['Ġ'+adj for adj in all_adj] if model_type == 'roberta' else all_adj
ids = tokenizer.convert_tokens_to_ids(proc_adj)
# convert_tokens_to_ids does not fail on strings that are not a single vocabulary
# entry: it returns the unknown-token id. Comparing that id's score would be
# meaningless, so refuse to continue instead of producing silent garbage.
missing = [adj for adj, token_id in zip(all_adj, ids) if token_id == tokenizer.unk_token_id]
if missing:
  raise ValueError(f'comparison words not found as single tokens in the vocabulary: {missing}')
adj2id = dict(zip(all_adj, ids))
print(adj2id)

{'before': 137, 'after': 71}


In [0]:
def predict(model, tokenized_text):
  """Score the vocabulary at the mask position of a tokenized probe sentence.

  Relies on the notebook-level `tokenizer` and `model_type`.

  Args:
    model: masked-LM model belonging to the same family as `tokenizer`.
    tokenized_text: list of token strings (e.g. from `tokenizer.tokenize`)
      containing the model's mask token ('[MASK]' for BERT, '<mask>' for RoBERTa).
      Only the first mask is scored.

  Returns:
    A pair `(predictions, top_tokens)`: `predictions` is the 1-D tensor of
    vocabulary scores at the mask position, and `top_tokens` holds the five
    highest-scoring entries as strings (vocabulary tokens for BERT, decoded
    text for RoBERTa).

  Raises:
    ValueError: if `tokenized_text` has no mask token or `model_type` is not
      'bert' or 'roberta'.
  """
  print(tokenized_text)
  if model_type == 'bert':
    # BERT expects '[CLS] ... [SEP]' around the sentence. Callers pass the bare
    # output of tokenizer.tokenize(), so add whichever marker is missing.
    tokens = list(tokenized_text)
    if tokens[:1] != ['[CLS]']:
      tokens.insert(0, '[CLS]')
    if tokens[-1:] != ['[SEP]']:
      tokens.append('[SEP]')
    masked_index = tokens.index('[MASK]')
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokens)

    tokens_tensor = torch.tensor([indexed_tokens])
    # Single-sentence input: every position belongs to segment 0.
    segments_tensors = torch.zeros_like(tokens_tensor)

    with torch.no_grad():
      outputs = model(tokens_tensor, token_type_ids=segments_tensors)
    predictions = outputs[0][0][masked_index]
    predicted_ids = torch.argsort(predictions, descending=True)[:5]
    top_tokens = tokenizer.convert_ids_to_tokens(predicted_ids.tolist())

  elif model_type == 'roberta':
    # +1 because encode(..., add_special_tokens=True) puts one special token
    # in front of the sequence.
    masked_index = tokenized_text.index('<mask>') + 1
    input_ids = torch.tensor(tokenizer.encode(tokenized_text, add_special_tokens=True)).unsqueeze(0)
    with torch.no_grad():
      # No labels are passed: the loss was never used, and the old
      # `masked_lm_labels` keyword is rejected by current transformers releases.
      outputs = model(input_ids)
    # Without labels the first output is the score tensor (batch, position, vocab).
    predictions = outputs[0][0, masked_index]
    predicted_indexes_list = torch.topk(predictions, k=5)[1]
    top_tokens = [tokenizer.decode(i) for i in predicted_indexes_list.tolist()]

  else:
    raise ValueError(f"Unsupported model_type {model_type!r}; expected 'bert' or 'roberta'")
  return predictions, top_tokens

def top_adj(preds, adj_ids=None):
  """Return the candidate word with the highest score in `preds`.

  Args:
    preds: indexable sequence of scores over the vocabulary (e.g. the tensor
      returned by `predict`); `preds[token_id]` is the score of that token.
    adj_ids: optional mapping from candidate word to its token id. Defaults to
      the notebook-level `adj2id`.

  Returns:
    The word whose score is largest; on ties the word listed first wins.
    Returns '' only when there are no candidates at all.
  """
  if adj_ids is None:
    adj_ids = adj2id
  if not adj_ids:
    return ''
  # Scores are raw model outputs and may all be negative, so take a true arg-max
  # instead of comparing against a starting threshold of 0.
  return max(adj_ids, key=lambda adj: preds[adj_ids[adj]])

In [0]:
# Run every pair through the model in both directions and sort the outcome into
# three buckets:
#   - 'inconsistent':         both directions received the same word;
#   - 'consistent_correct':   the directions disagree and the forward probe got the `bigger` word;
#   - 'consistent_incorrect': the directions disagree and the forward probe got the `smaller` word.
output = {'consistent_correct': [], 'consistent_incorrect': [],'inconsistent':[]}
for i, ((e1, e2), (probe, rev_probe)) in enumerate(pairs):
  print(f'pair {i}')
  answers = []  # word chosen for the forward probe, then for the reversed one
  filled = []   # (sentence with the mask replaced, top-5 tokens) for each direction
  for sentence in (probe, rev_probe):
    mask_preds, top_tokens = predict(model, tokenizer.tokenize(sentence))
    # Collapse whichever candidate scored highest onto the representative word of its group.
    if top_adj(mask_preds) in bigger:
      answer = bigger[0]
    else:
      answer = smaller[0]
    answers.append(answer)
    filled.append((sentence.replace(mask_token, answer.upper()), top_tokens))
  result = tuple(filled)
  token_pred1, token_pred2 = answers
  if token_pred1 != token_pred2:
    bucket = 'consistent_correct' if token_pred1 == bigger[0] else 'consistent_incorrect'
  else:
    bucket = 'inconsistent'
  output[bucket].append(result)

pair 0
['The', 'ĠBible', 'Ġwas', 'Ġpublished', '<mask>', 'ĠThe', 'ĠDivine', 'ĠComedy', 'Ġ.']
['The', 'ĠDivine', 'ĠComedy', 'Ġwas', 'Ġpublished', '<mask>', 'ĠThe', 'ĠBible', 'Ġ.']
pair 1
['The', 'ĠBible', 'Ġwas', 'Ġpublished', '<mask>', 'ĠThe', 'ĠWealth', 'Ġof', 'ĠNations', 'Ġ.']
['The', 'ĠWealth', 'Ġof', 'ĠNations', 'Ġwas', 'Ġpublished', '<mask>', 'ĠThe', 'ĠBible', 'Ġ.']
pair 2
['The', 'ĠBible', 'Ġwas', 'Ġpublished', '<mask>', 'ĠFaust', 'Ġ.']
['Fa', 'ust', 'Ġwas', 'Ġpublished', '<mask>', 'ĠThe', 'ĠBible', 'Ġ.']
pair 3
['The', 'ĠBible', 'Ġwas', 'Ġpublished', '<mask>', 'ĠMob', 'y', 'ĠDick', 'Ġ.']
['M', 'oby', 'ĠDick', 'Ġwas', 'Ġpublished', '<mask>', 'ĠThe', 'ĠBible', 'Ġ.']
pair 4
['The', 'ĠBible', 'Ġwas', 'Ġpublished', '<mask>', 'ĠBrave', 'ĠNew', 'ĠWorld', 'Ġ.']
['Brave', 'ĠNew', 'ĠWorld', 'Ġwas', 'Ġpublished', '<mask>', 'ĠThe', 'ĠBible', 'Ġ.']
pair 5
['The', 'ĠBible', 'Ġwas', 'Ġpublished', '<mask>', 'ĠTo', 'ĠKill', 'Ġa', 'ĠM', 'ocking', 'bird', 'Ġ.']
['To', 'ĠKill', 'Ġa', 'ĠM', 'ocking'

In [0]:
# Print the size of every bucket, then list each bucket's probe pairs: the
# sentence with the chosen word filled in, followed by the top-5 tokens.
print([(name, len(items)) for name, items in output.items()])
for name in output:
  items = output[name]
  print(f'{name} {len(items)}')
  for forward, backward in items:
    print(f'{forward[0]} → {forward[1]}')
    print(f'{backward[0]} → {backward[1]}\n')

  print('\n')

[('consistent_correct', 5), ('consistent_incorrect', 3), ('inconsistent', 20)]
consistent_correct 5
The Divine Comedy was published BEFORE The Wealth of Nations . → [' in', ' as', ' by', ' with', ' alongside']
The Wealth of Nations was published AFTER The Divine Comedy . → [' in', ' as', ' with', ' by', ' alongside']

The Divine Comedy was published BEFORE Faust . → [' as', ' in', ' before', ' by', ' after']
Faust was published AFTER The Divine Comedy . → [' in', ' as', ' by', ' with', ' after']

The Divine Comedy was published BEFORE Brave New World . → [' in', ' by', ' at', ' as', ' on']
Brave New World was published AFTER The Divine Comedy . → [' in', ' by', ' as', ' at', ' with']

The Wealth of Nations was published BEFORE Brave New World . → [' in', ' by', ' at', ' on', ' as']
Brave New World was published AFTER The Wealth of Nations . → [' in', ' by', ' as', ' at', ' on']

The Wealth of Nations was published BEFORE Harry Potter . → [' in', ' as', ' before', ' after', ' with']
Har