This is a simple notebook that probes BERT for factual knowledge using cloze statements of the form "Subject relation [MASK]", where BERT has to fill in the masked object.

In [0]:
!pip install transformers
!pip install torch

Let's first set up our probes and some detailed settings.

In [0]:
# --- Probes -----------------------------------------------------------------
# Every probe is the text "<subject> <relation>"; the masked object slot is
# appended later, when the probe is fed to the model.
relation = 'is located in'
subjects = ['France', 'Paris', 'Bonn', 'Mount Everest']
probes = [' '.join((subj, relation)) for subj in subjects]

# Gold answers, one entry per subject: either a single string or a list of
# accepted strings. An empty entry means "no gold answer", in which case the
# probe is reported under 'undefined'.
# Example: answers = ['Europe', 'France', 'Germany', ['Nepal', 'China']]
answers = ['' for _ in subjects]

# --- Settings ---------------------------------------------------------------
# True selects the cased BERT checkpoint, False the uncased one.
cased_model = True
# How many of the top-ranked predictions are kept and displayed per probe.
numb_predictions_displayed = 5
# BERT tends to predict the subject again in many cases. This can be ignored.
ignore_self_reference_output = True

Let's probe!

In [0]:
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM, RobertaForMaskedLM, RobertaTokenizer
import numpy as np

# Model setup, adapted from the Hugging Face quickstart:
# https://huggingface.co/transformers/quickstart.html#bert-example

# Checkpoint name follows the `cased_model` setting.
bert_model = 'bert-large-' + ('cased' if cased_model else 'uncased')

# Tokenizer and masked-LM head are loaded from the same checkpoint so that
# their vocabularies agree.
tokenizer = BertTokenizer.from_pretrained(bert_model)
model = BertForMaskedLM.from_pretrained(bert_model)

# Switch the model to evaluation mode; we only run inference here.
model.eval()

# Query the masked LM once per probe and file each result into one of three
# buckets:
#   'correct'   - a gold answer exists and the top-ranked prediction matches it
#   'false'     - a gold answer exists but the top-ranked prediction misses it
#   'undefined' - no gold answer was supplied for the probe
# Every bucket entry is a (probe, predicted_tokens) tuple, where
# predicted_tokens holds the `numb_predictions_displayed` highest-scoring
# tokens for the [MASK] position, best first.
output = {'correct': [], 'false': [], 'undefined': []}
for probe, answer in zip(probes, answers):
  text = f'[CLS] {probe} [MASK] . [SEP]'
  tokenized_text = tokenizer.tokenize(text)
  masked_index = tokenized_text.index('[MASK]')
  print(f'TOKENIZED TEXT: {tokenized_text}')

  indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
  # Single-sentence input: every token belongs to segment 0.
  segments_ids = [0]*len(tokenized_text)

  # The model expects a batch dimension, hence the extra list nesting.
  tokens_tensor = torch.tensor([indexed_tokens])
  segments_tensors = torch.tensor([segments_ids])

  # Predict all tokens; only the scores at the [MASK] position are needed.
  with torch.no_grad():
    outputs = model(tokens_tensor, token_type_ids=segments_tensors)
  predictions = outputs[0][0][masked_index]
  predicted_ids = torch.argsort(predictions, descending=True)[:numb_predictions_displayed]
  predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_ids.tolist())

  if not answer:
    output['undefined'].append((probe, predicted_tokens))
    continue

  # Accept either a single gold answer or a collection of alternatives.
  accepted = [answer] if isinstance(answer, str) else list(answer)
  if not cased_model:
    # The uncased tokenizer produces lower-case tokens, so gold answers have
    # to be lower-cased as well or they could never match.
    accepted = [gold.lower() for gold in accepted]

  # Pick the prediction that is scored against the gold answer. When
  # self-references are ignored, a prediction that merely repeats the first
  # token of the probe (tokenized_text[0] is [CLS]) is skipped. If nothing
  # else is left (e.g. only one prediction is displayed), fall back to the
  # top prediction instead of failing with an IndexError.
  first_pred = predicted_tokens[0] if predicted_tokens else None
  if ignore_self_reference_output:
    subject_token = tokenized_text[1]
    first_pred = next(
        (tok for tok in predicted_tokens if tok != subject_token),
        first_pred,
    )

  bucket = 'correct' if first_pred in accepted else 'false'
  output[bucket].append((probe, predicted_tokens))

# Report: one section per bucket, listing every probe filed under it together
# with its top predictions, followed by a blank separator.
for bucket in output:
  print(bucket)
  for probe_text, top_tokens in output[bucket]:
    print(f'{probe_text} [MASK] → {top_tokens}')
  print('\n')