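This notebook tests whether a masked language model (BERT) has learnt that certain relations are symmetric or inverse.
To test this, we ask "a relation [MASK]?", take the top prediction p, and then ask "p relation [MASK]?" (for an inverse relation, "p inverse-relation [MASK]?"). If the model has learnt that the relation is symmetric (or inverse), it should then predict a again, regardless of whether p itself was correct.
Unfortunately, this approach only works properly for 1-to-1 relations: for 1-to-many relations the reverse query has several valid answers, so not getting a back does not show that the model missed the symmetry.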

In [0]:
# Install the Hugging Face transformers library and PyTorch.
# NOTE(review): versions are unpinned; the captured install log below shows
# transformers 2.9.1 was installed when this notebook was run — consider pinning
# (and using %pip so the kernel's environment is targeted) for reproducibility.
!pip install transformers
!pip install torch

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/22/97/7db72a0beef1825f82188a4b923e62a146271ac2ced7928baa4d47ef2467/transformers-2.9.1-py3-none-any.whl (641kB)
[K     |████████████████████████████████| 645kB 2.7MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 13.0MB/s 
Collecting tokenizers==0.7.0
[?25l  Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)
[K     |████████████████████████████████| 3.8MB 17.5MB/s 
Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/3b/88/49e772d686088e1278766ad68a463513642a2a877487decbd691dec02955/sentencepiece-0.1.90-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |██████████

In [0]:
# subjects = ['Austria', 'Denmark', 'Switzerland', 'Ukraine', 'Belarus', 'Estonia', 'Afghanistan', 'Mexico', 'Egypt', 'Angola', 'Honduras', 'Panama', 'Turkey', 'Belgium', 'Mongolia', 'Hungary', 'Niger']
subjects = ['to disagree', 'to admit', 'new', 'fresh', 'novel', 'job', 'occupation', 'price', 'cost', 'peak', 'summit', 'state', 'nation', 'country', 'land', 'earth', 'world', 'humanity', 'mankind', 'good', 'competent', 'situated', 'entire', 'path', 'big', 'great', 'large', 'real', 'actual', 'center', 'middle', 'travel', 'journey', 'trip', 'correct', 'proper', 'result', 'outcome', 'purpose', 'intention', 'region', 'subject', 'topic', 'field', 'discipline', 'home', 'category', 'type', 'class', 'black', 'dark', 'particular', 'movie', 'film', 'photo', 'picture']
# subjects = ['absent', 'mountain', 'valley', 'present', 'dirt', 'house', 'careless', 'expensive', 'indefinite']
relation = 'is the same as'
verb = 'is the same as'

# Build masked probe pairs from an antonym list. Each kept line is expected to read
# "first - second" (an en dash is normalised to '-'); a trailing " v" on the second
# word marks a verb pair, whose probes are phrased as "to ...".
# Lines of length <= 2 and lines containing ',' are skipped.
with open('/content/drive/My Drive/tmp/antonyms') as antonyms_file:
  lines = antonyms_file.readlines()
ant = []        # (probe for first word, probe for second word) for non-verb pairs
ant_verbs = []  # same, for verb pairs

for line in lines:
  if len(line) > 2 and ',' not in line:
    line = line.replace('–','-')
    parts = line.strip().split(' - ')
    if len(parts) != 2:
      raise ValueError(f'Expected "first - second" in antonym line, got {line!r}')
    first, second = parts
    # firsts = first.split(',')
    # seconds = second.replace(' v','').split(',')
    if second.endswith(' v'):
      # Strip only the trailing verb marker; replace(' v', '') would also mangle
      # any word in `second` that starts with "v".
      second = second[:-len(' v')]
      first, second = 'to '+first, 'to '+second
      first, second = f'[CLS] {first} {relation} to [MASK] . [SEP]', f'[CLS] {second} {relation} to [MASK] . [SEP]'
      ant_verbs.append((first, second))
    else:
      first, second = f'[CLS] {first} {relation} [MASK] . [SEP]', f'[CLS] {second} {relation} [MASK] . [SEP]'
      ant.append((first, second))

# probes = [f'{subj} {relation}' for subj in subjects]


In [0]:

# Settings read by the model-loading and prediction cells.
numb_predictions_displayed = 5  # number of top-ranked tokens predict() returns
cased_model = True  # True -> 'bert-large-cased', False -> 'bert-large-uncased'
# BERT tends to predict the subject again in many cases; this flag says whether that
# echo should be ignored. NOTE(review): confirm the probing loop actually reads it.
ignore_self_reference_output = True

In [0]:
# Load probe pairs from 'capital_templates'. The file alternates lines: each even-indexed
# line is the first template of a pair and the following line its counterpart (named
# country / city below). Each is wrapped as "[CLS] <template> [MASK] . [SEP]".
with open('capital_templates') as templates_file:
  lines = templates_file.readlines()

# Drop trailing blank lines so an extra newline at the end of the file cannot leave an
# unpaired template behind; any other odd count is a malformed file.
while lines and not lines[-1].strip():
  lines.pop()
if len(lines) % 2 != 0:
  raise ValueError(f'capital_templates must hold an even number of lines (template pairs), got {len(lines)}')

probes = []
for country, city in zip(lines[0::2], lines[1::2]):
  country = f'[CLS] {country.strip()} [MASK] . [SEP]'
  city = f'[CLS] {city.strip()} [MASK] . [SEP]'
  probes.append((country, city))
probes[:5]

In [0]:
def predict(model, tokenized_text, k=None, tok=None):
  """Return the model's top-k token predictions for the first '[MASK]' position.

  Args:
    model: masked-LM callable as model(ids, token_type_ids=...); the first element of
      its output is indexed as outputs[0][0][position] to get per-vocabulary scores.
    tokenized_text: list of tokens containing at least one '[MASK]'.
    k: number of predictions to return; defaults to the global
      `numb_predictions_displayed`. Clamped to the vocabulary size.
    tok: tokenizer used for token <-> id conversion; defaults to the global `tokenizer`.

  Returns:
    List of up to k token strings, highest score first.

  Raises:
    ValueError: if tokenized_text contains no '[MASK]' token.
  """
  if k is None:
    k = numb_predictions_displayed
  if tok is None:
    tok = tokenizer
  if '[MASK]' not in tokenized_text:
    raise ValueError(f'No [MASK] token in {tokenized_text!r}')
  masked_index = tokenized_text.index('[MASK]')

  indexed_tokens = tok.convert_tokens_to_ids(tokenized_text)
  # Single-sentence input: every token belongs to segment 0.
  segments_ids = [0] * len(tokenized_text)

  tokens_tensor = torch.tensor([indexed_tokens])
  segments_tensors = torch.tensor([segments_ids])

  # Predict all tokens (no gradients needed for inference).
  with torch.no_grad():
    outputs = model(tokens_tensor, token_type_ids=segments_tensors)
  predictions = outputs[0][0][masked_index]

  # topk returns indices sorted by descending score; .tolist() yields plain ints
  # rather than 0-d tensors for the tokenizer.
  _, predicted_ids = torch.topk(predictions, min(k, predictions.numel()))
  return tok.convert_ids_to_tokens(predicted_ids.tolist())


In [0]:
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM, RobertaForMaskedLM, RobertaTokenizer
import numpy as np

# from: https://huggingface.co/transformers/quickstart.html#bert-example

bert_model = 'bert-large-' + ('cased' if cased_model else 'uncased')
# Weights and vocabulary come from the same pretrained checkpoint name.
model = BertForMaskedLM.from_pretrained(bert_model)
tokenizer = BertTokenizer.from_pretrained(bert_model)
# Evaluation mode (disables dropout) for deterministic inference.
model.eval()

In [0]:
def _top_prediction(predicted_tokens, self_token):
  """Return the highest-ranked prediction for a probe.

  If `ignore_self_reference_output` is set and the top prediction merely repeats the
  probe's own subject (`self_token`), the runner-up is returned instead.
  """
  if ignore_self_reference_output and predicted_tokens[0] == self_token and len(predicted_tokens) > 1:
    return predicted_tokens[1]
  return predicted_tokens[0]


def _substitute_word(probe_str, idx, new_word):
  """Return `probe_str` with only its idx-th whitespace-separated word set to `new_word`.

  Unlike str.replace, words elsewhere in the template that equal (or contain) the old
  word are left untouched.
  """
  words = probe_str.split()
  words[idx] = new_word
  return ' '.join(words)


def _bucket(symmetric, correct):
  """Map the (symmetric, correct) outcome of one probe direction to its key in `output`."""
  return ('symmetric' if symmetric else 'asymmetric') + ('_correct' if correct else '_incorrect')


output = {'symmetric_correct': [], 'symmetric_incorrect': [],'asymmetric_correct':[], 'asymmetric_incorrect':[]}
# Subject position right after [CLS]: index into .split() words and, because subjects
# are restricted to a single word piece below, also into the tokenized probe.
subj_idx = 1
for probe in probes:
  probe1, probe2 = probe[0], probe[1]
  subj = probe1.split()[subj_idx]
  obj = probe2.split()[subj_idx]
  subj_pieces = tokenizer.tokenize(subj)
  obj_pieces = tokenizer.tokenize(obj)
  if len(subj_pieces) > 1 or len(obj_pieces) > 1:
    print('INVALID: '+str(subj_pieces) + str(obj_pieces))
    continue

  tokenized_probe1 = tokenizer.tokenize(probe1)
  tokenized_probe2 = tokenizer.tokenize(probe2)
  print(tokenized_probe1)
  print(tokenized_probe2)
  subj_token = tokenized_probe1[subj_idx]
  obj_token = tokenized_probe2[subj_idx]

  # Forward direction: predict from probe1, then put that prediction in probe2's
  # subject slot and check whether the model answers with probe1's subject.
  # NOTE(review): if the prediction is a '##' continuation piece, the reverse probe
  # re-tokenizes differently and index subj_idx may no longer hold it.
  predicted_tokens = predict(model, tokenized_probe1)
  first_pred = _top_prediction(predicted_tokens, subj_token)
  reverse_probe_str = _substitute_word(probe2, subj_idx, first_pred)
  print(reverse_probe_str)
  reverse_probe = tokenizer.tokenize(reverse_probe_str)
  rev_predicted_tokens = predict(model, reverse_probe)
  rev_first_pred = _top_prediction(rev_predicted_tokens, reverse_probe[subj_idx])

  correct1 = first_pred == obj_token
  symmetric1 = rev_first_pred == subj_token
  result = ((probe1, predicted_tokens), (reverse_probe_str, rev_predicted_tokens))
  output[_bucket(symmetric1, correct1)].append(result)

  # If the forward direction was neither correct nor symmetric, try the pair the
  # other way round: start from probe2 and plug its prediction into probe1.
  if not correct1 and not symmetric1:
    predicted_tokens2 = predict(model, tokenized_probe2)
    first_pred2 = _top_prediction(predicted_tokens2, obj_token)
    reverse_probe_str2 = _substitute_word(probe1, subj_idx, first_pred2)
    reverse_probe2 = tokenizer.tokenize(reverse_probe_str2)
    rev_predicted_tokens2 = predict(model, reverse_probe2)
    rev_first_pred2 = _top_prediction(rev_predicted_tokens2, reverse_probe2[subj_idx])
    correct2 = first_pred2 == subj_token
    symmetric2 = rev_first_pred2 == obj_token
    result = ((probe2, predicted_tokens2), (reverse_probe_str2, rev_predicted_tokens2))
    if correct2 and symmetric2:
      # correct2 makes the reverse probe equal to probe1, so symmetric2 would imply
      # correct1 — contradicting the branch condition.
      print("IMPOSSIBLE"+str(result))
    else:
      output[_bucket(symmetric2, correct2)].append(result)


In [0]:
# For each outcome category, print its size followed by every forward / reverse
# probe together with the model's predictions.
for category, results in output.items():
  print(f'{category} {len(results)}')
  for (fwd_probe, fwd_preds), (back_probe, back_preds) in results:
    print(f'{fwd_probe} → {fwd_preds}')
    print(f'{back_probe} → {back_preds}', end='\n\n')
  print('\n')