<a href="https://colab.research.google.com/github/BennoKrojer/Probe-Masked-LMs/blob/master/probeBERT_reverse.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

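This notebook tests whether a masked language model has learnt that a relation is symmetric, or that it has an inverse.
To test this, we first ask "a relation [MASK]" and take the top prediction p, then ask "p reverse [MASK]", where reverse is the inverse relation (or the relation itself, if it is symmetric). If the model has learnt the symmetry (or the inverse), it should answer a, regardless of whether a is factually correct.
Unfortunately, this approach only works properly for 1-to-1 relations: a 1-to-many relation has several valid answers, so in the second step the model may return a different but equally correct object instead of a.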
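The round trip can be sketched without a language model by using a lookup table in place of the model's top prediction. The table `toy_top_prediction` and the helpers `lookup` and `round_trip_consistent` are hypothetical and for illustration only; the second call shows how a 1-to-many relation can fail the check even though every answer is true.

```python
# Hypothetical stand-in for the model's top prediction for "<subject> <relation> [MASK]".
toy_top_prediction = {
    ('Vienna', 'is the capital of'): 'Austria',
    ('Austria', 'has the capital'): 'Vienna',
    ('Austria', 'is north of'): 'Italy',
    ('Italy', 'is south of'): 'Switzerland',  # true, but not the original subject
}


def lookup(subject, relation):
    """Toy predictor: return the single stored answer for (subject, relation)."""
    return toy_top_prediction[(subject, relation)]


def round_trip_consistent(top_prediction, subject, relation, reverse):
    """Ask 'subject relation [MASK]', feed the answer p into
    'p reverse [MASK]', and report whether the original subject comes back."""
    p = top_prediction(subject, relation)
    return top_prediction(p, reverse) == subject


print(round_trip_consistent(lookup, 'Vienna', 'is the capital of', 'has the capital'))  # True
print(round_trip_consistent(lookup, 'Austria', 'is north of', 'is south of'))  # False
```

In the notebook, `predict` plays the role of `top_prediction`, and the comparison is made on WordPiece tokens rather than whole words.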

In [0]:
!pip install transformers
!pip install torch

In [0]:
subjects = ['Austria', 'Denmark', 'Switzerland', 'Ukraine', 'Belarus', 'Estonia', 'Afghanistan', 'Mexico', 'Egypt', 'Angola', 'Honduras', 'Panama', 'Turkey', 'Belgium', 'Mongolia', 'Hungary', 'Niger']
relation = 'is north of'
reverse = 'is south of'
probes = [f'{subj} {relation}' for subj in subjects]
cased_model = True
numb_predictions_displayed = 5
ignore_self_reference_output = True  # BERT often predicts the subject itself; if True, skip that prediction and use the next-best one.
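The same loop can probe a symmetric relation by setting `reverse` equal to `relation`. Below is a small sketch of how the two query strings are built from such a pair with the notebook's template. Apart from the north/south pair used above, the `RELATION_PAIRS` entries and both helper functions are hypothetical examples.

```python
# Hypothetical relation pairs: an inverse relation maps to its converse,
# a symmetric relation maps to itself.
RELATION_PAIRS = {
    'is north of': 'is south of',      # inverse (the pair configured above)
    'is east of': 'is west of',        # inverse
    'is married to': 'is married to',  # symmetric
}


def forward_query(subject, relation):
    """Same template as the notebook: '[CLS] <subject> <relation> [MASK] . [SEP]'."""
    return f'[CLS] {subject} {relation} [MASK] . [SEP]'


def reverse_query(prediction, relation):
    """The first prediction becomes the new subject, paired with the reverse relation."""
    return f'[CLS] {prediction} {RELATION_PAIRS[relation]} [MASK] . [SEP]'


print(forward_query('Austria', 'is north of'))  # [CLS] Austria is north of [MASK] . [SEP]
print(reverse_query('Italy', 'is north of'))    # [CLS] Italy is south of [MASK] . [SEP]
```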

In [0]:
import torch

def predict(model, tokenized_text):
  """Return the top `numb_predictions_displayed` tokens for the first [MASK].

  Uses the global `tokenizer`, which is loaded in the next cell.
  """
  masked_index = tokenized_text.index('[MASK]')
  print(f'TOKENIZED TEXT: {tokenized_text}')

  indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
  segments_ids = [0] * len(tokenized_text)  # single sentence: every token is in segment A

  tokens_tensor = torch.tensor([indexed_tokens])
  segments_tensors = torch.tensor([segments_ids])

  # Score every vocabulary entry at every position, then keep the [MASK] row
  with torch.no_grad():
    outputs = model(tokens_tensor, token_type_ids=segments_tensors)
    predictions = outputs[0][0][masked_index]
  predicted_ids = torch.argsort(predictions, descending=True)[:numb_predictions_displayed]
  predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_ids.tolist())
  return predicted_tokens
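`predict` scores every vocabulary entry at the `[MASK]` position and keeps the `numb_predictions_displayed` highest. The pure-Python sketch below mimics that ranking on made-up scores (`toy_vocab` and `toy_logits` are invented for illustration) and adds a hypothetical `merge_wordpieces` helper. Because a single `[MASK]` yields a single WordPiece, a subject that the tokenizer splits (for example `Luanda` into `Lu`, `##anda`, as in the output below) can never be predicted whole, which is why the main loop compares only against the subject's first wordpiece.

```python
def top_k_at_mask(tokens, logits, vocab, k):
    """Return the k highest-scoring vocabulary tokens at the [MASK] position.

    logits[i][j] is the score of vocab[j] at position i, mirroring
    outputs[0][0] in predict (shape: sequence length x vocabulary size).
    """
    masked_index = tokens.index('[MASK]')
    scores = logits[masked_index]
    ranked = sorted(range(len(vocab)), key=lambda j: scores[j], reverse=True)
    return [vocab[j] for j in ranked[:k]]


def merge_wordpieces(tokens):
    """Join WordPiece continuations ('##...') onto the preceding piece."""
    words = []
    for tok in tokens:
        if tok.startswith('##') and words:
            words[-1] += tok[2:]
        else:
            words.append(tok)
    return words


toy_vocab = ['Italy', 'Germany', 'the', 'Austria']
tokens = ['[CLS]', 'Austria', 'is', 'north', 'of', '[MASK]', '.', '[SEP]']
# One row of made-up scores per position; only the [MASK] row (index 5) matters here.
toy_logits = [[0.0] * len(toy_vocab) for _ in tokens]
toy_logits[5] = [3.2, 2.7, 0.4, 1.1]

print(top_k_at_mask(tokens, toy_logits, toy_vocab, k=2))  # ['Italy', 'Germany']
print(merge_wordpieces(['[CLS]', 'Lu', '##anda', 'is', 'north', 'of', '[MASK]']))
# ['[CLS]', 'Luanda', 'is', 'north', 'of', '[MASK]']
```

In the real `predict`, `torch.argsort(..., descending=True)[:k]` performs the ranking step on a tensor of vocabulary scores.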

In [0]:
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM, RobertaForMaskedLM, RobertaTokenizer
import numpy as np

# from: https://huggingface.co/transformers/quickstart.html#bert-example

bert_model = 'bert-large-cased' if cased_model else 'bert-large-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model)
model = BertForMaskedLM.from_pretrained(bert_model)
model.eval()

output = {'symmetric': [], 'asymmetric': []}
for probe in probes:
  # Forward query: "<subject> <relation> [MASK] ."
  text = f'[CLS] {probe} [MASK] . [SEP]'
  tokenized_text = tokenizer.tokenize(text)
  predicted_tokens = predict(model, tokenized_text)

  # Take the top prediction; optionally skip it if it just repeats the subject
  # (compared against the subject's first wordpiece only)
  first_pred = predicted_tokens[0]
  if ignore_self_reference_output and first_pred == tokenized_text[1]:
    first_pred = predicted_tokens[1]

  # Reverse query: "<first prediction> <reverse relation> [MASK] ."
  reverse_text = f'{first_pred} {reverse}'
  reverse_probe = tokenizer.tokenize(f'[CLS] {reverse_text} [MASK] . [SEP]')
  rev_predicted_tokens = predict(model, reverse_probe)
  rev_first_pred = rev_predicted_tokens[0]
  if ignore_self_reference_output and rev_first_pred == reverse_probe[1]:
    rev_first_pred = rev_predicted_tokens[1]

  # The round trip is consistent if the model returns the original subject
  result = ((probe, predicted_tokens), (reverse_text, rev_predicted_tokens))
  if rev_first_pred == tokenized_text[1]:
    output['symmetric'].append(result)
  else:
    output['asymmetric'].append(result)

for key, val in output.items():
  print(key)
  for (probe, pred), (rev_probe, rev_pred) in val:
    print(f'{probe} [MASK] → {pred}')
    print(f'{rev_probe} [MASK] → {rev_pred}')

  print('\n')

TOKENIZED TEXT: ['[CLS]', 'Lu', '##anda', 'is', 'north', 'of', '[MASK]', '.', '[SEP]']
TOKENIZED TEXT: ['[CLS]', 'Angola', 'is', 'south', 'of', '[MASK]', '.', '[SEP]']
TOKENIZED TEXT: ['[CLS]', 'Jay', '-', 'Z', 'is', 'north', 'of', '[MASK]', '.', '[SEP]']
TOKENIZED TEXT: ['[CLS]', 'town', 'is', 'south', 'of', '[MASK]', '.', '[SEP]']
TOKENIZED TEXT: ['[CLS]', 'Urban', 'is', 'north', 'of', '[MASK]', '.', '[SEP]']
TOKENIZED TEXT: ['[CLS]', 'Central', 'is', 'south', 'of', '[MASK]', '.', '[SEP]']
TOKENIZED TEXT: ['[CLS]', 'Kid', '##man', 'is', 'north', 'of', '[MASK]', '.', '[SEP]']
TOKENIZED TEXT: ['[CLS]', 'Mt', 'is', 'south', 'of', '[MASK]', '.', '[SEP]']
TOKENIZED TEXT: ['[CLS]', 'Co', '##ba', '##in', 'is', 'north', 'of', '[MASK]', '.', '[SEP]']
TOKENIZED TEXT: ['[CLS]', 'Mt', 'is', 'south', 'of', '[MASK]', '.', '[SEP]']
TOKENIZED TEXT: ['[CLS]', 'Reynolds', 'is', 'north', 'of', '[MASK]', '.', '[SEP]']
TOKENIZED TEXT: ['[CLS]', 'town', 'is', 'south', 'of', '[MASK]', '.', '[SEP]']
TOKENIZ