<a href="https://colab.research.google.com/github/BennoKrojer/Probe-Masked-LMs/blob/master/probeROBERTA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# %pip (unlike !pip) installs into the environment of the running kernel.
%pip install -q transformers torch

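Let's first set up our probes: each one is a subject followed by a relation, and RoBERTa predicts the masked word that completes it. Alongside the probes we define the expected answers and a few display settings.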

In [0]:
# Each probe is "<subject> <relation>"; RoBERTa fills in the <mask> placed after it.
# answers[i] holds the accepted filler(s) for subjects[i]: a string or a list of strings.
subjects = ['France', 'Paris', 'Berlin','Haidhausen', 'Mount Everest', 'My heart']
relation = 'is located in'
answers = ['Europe', 'France', 'Germany', ['Munich', 'Bavaria', 'Germany'], ['Nepal', 'China'], ''] # If there is no correct answer, simply put '' in the list
# The probing cell pairs answers with probes via zip(), which silently drops
# unmatched entries -- fail loudly here instead if the two lists drift apart.
assert len(answers) == len(subjects), f'{len(subjects)} subjects but {len(answers)} answers'
probes = [f'{subj} {relation}' for subj in subjects]
cased_model = True  # NOTE(review): not read by the probing cell below
numb_predictions_displayed = 5  # top-k predictions shown per probe
ignore_self_reference_output = True # RoBERTa tends to predict the subject again in many cases. This can be ignored.

In [6]:
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM

# Load the pretrained RoBERTa-large tokenizer and masked-LM head from a single
# checkpoint name; eval() puts the module in inference mode (e.g. dropout off)
# and returns the module itself, so it can be chained.
checkpoint = 'roberta-large'
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
model = RobertaForMaskedLM.from_pretrained(checkpoint).eval()

# Query RoBERTa with "<probe> <mask>." for every probe and sort the results into
# 'correct' (chosen prediction is an accepted answer), 'false' (it is not) and
# 'undefined' (no answer was given for the probe).


def _top_k_predictions(probe, k):
  """Return the ``k`` most likely fillers for ``'<probe> <mask>.'``, best first.

  Each filler is a decoded token with surrounding whitespace stripped.
  """
  text = f'{probe} <mask>.'
  input_ids = torch.tensor(tokenizer.encode(text, add_special_tokens=True)).unsqueeze(0)
  # Locate <mask> in the encoded ids themselves (they include the leading <s>)
  # instead of offsetting an index into the tokenized text by hand.
  masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
  with torch.no_grad():
    # No labels are passed: the loss was never used, and `masked_lm_labels`
    # was removed in transformers v4 (superseded by `labels`). Without labels,
    # element 0 of the model output is the prediction scores.
    prediction_scores = model(input_ids)[0]
  top_k_ids = torch.topk(prediction_scores[0, masked_index], k=k)[1]
  return [tokenizer.decode(i).strip() for i in top_k_ids.tolist()]


def _first_non_self_prediction(predicted_tokens, subject):
  """Return the top prediction, skipping fillers that merely repeat the subject.

  Skipping only happens when `ignore_self_reference_output` is set. A filler
  counts as a self-reference if it equals the subject or one of its
  whitespace-separated words (e.g. 'Everest' for 'Mount Everest'). If every
  candidate is a self-reference, the top prediction is returned anyway.
  """
  if not ignore_self_reference_output:
    return predicted_tokens[0]
  subject_words = {subject, *subject.split()}
  return next((tok for tok in predicted_tokens if tok not in subject_words), predicted_tokens[0])


output = {'correct': [], 'false': [], 'undefined': []}
for subject, probe, answer in zip(subjects, probes, answers):
  predicted_tokens = _top_k_predictions(probe, numb_predictions_displayed)
  if not answer:  # '' (or an empty list) marks a probe without a correct answer
    output['undefined'].append((probe, predicted_tokens))
    continue
  accepted = [answer] if isinstance(answer, str) else answer
  first_pred = _first_non_self_prediction(predicted_tokens, subject)
  output['correct' if first_pred in accepted else 'false'].append((probe, predicted_tokens))

# Print each outcome category followed by its probes and their top predictions,
# with a blank gap between categories.
for category, results in output.items():
  report_lines = [category] + [f'{probe} [MASK] → {predicted_tokens}' for probe, predicted_tokens in results]
  print('\n'.join(report_lines))
  print('\n')

correct
France is located in [MASK] → ['Europe', 'Asia', 'Africa', 'France', 'Spain']
Paris is located in [MASK] → ['France', 'Paris', 'Europe', 'Belgium', 'Spain']
Berlin is located in [MASK] → ['Germany', 'Berlin', 'Austria', 'Europe', 'Poland']
Haidhausen is located in [MASK] → ['Germany', 'Austria', 'Switzerland', 'Belgium', 'Poland']
Mount Everest is located in [MASK] → ['Nepal', 'Tibet', 'India', 'China', 'Asia']


false


undefined
My heart is located in [MASK] → ['Texas', 'California', 'Chicago', 'Ohio', 'Boston']


