A notebook demonstrating masked language modeling (MLM), next sentence prediction (NSP), and extractive question answering (QA) with BERT.

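The MLM and NSP sections use the `bert-base-uncased` model from Hugging Face (https://huggingface.co/bert-base-uncased); the question answering section uses a BERT checkpoint fine-tuned on SQuAD 2.0 (https://huggingface.co/deepset/bert-base-cased-squad2).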

## Masked Language Modeling: predict the missing word

In [None]:
from transformers import pipeline

In [None]:
# Build a fill-mask pipeline backed by bert-base-uncased. At roughly 110M
# parameters the checkpoint is relatively small and fits into memory easily.
model = pipeline(task='fill-mask', model='bert-base-uncased')

In [None]:
# Ask the pipeline to fill in the [MASK] position; the cell displays its predictions.
masked_sentence = "Hello world! What a [MASK] day it is!"
model(masked_sentence)

## Next Sentence Prediction: predict whether one sentence follows another

In [None]:
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction

# Prefer the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
# Load the pretrained tokenizer and the BERT model that carries a next sentence
# prediction (classification) head on top.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Switch the model to evaluation mode for inference and place it on `device`
# (the GPU when one is available); both calls hand back the model itself.
model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased").eval().to(device)

In [None]:
# Build one plausible and one implausible continuation of the same first sentence.
# Label convention used for next sentence prediction (and relied on by the logit
# comparison later in this notebook):
#   0 -> the second sentence continues the first
#   1 -> the second sentence is unrelated to the first
prompt = "I will play some tennis today with a friend."

next_sentence_unlogical = "Heartbreak pain can be explained through hormonal changes."
# torch.tensor is the recommended factory (torch.LongTensor is a legacy typed
# constructor); passing `device` creates the label directly where the model lives.
label_unlogical = torch.tensor([1], dtype=torch.long, device=device)

next_sentence_logical = "It's a beautiful day today."
label_logical = torch.tensor([0], dtype=torch.long, device=device)

In [None]:
# Tokenize each (prompt, candidate) sentence pair into PyTorch tensors, move the
# encodings to `device`, and print one of them to see what the model receives.
tokens_unlogical = tokenizer(
    text=prompt, text_pair=next_sentence_unlogical, return_tensors='pt'
).to(device)
tokens_logical = tokenizer(
    text=prompt, text_pair=next_sentence_logical, return_tensors='pt'
).to(device)
print(tokens_logical)

In [None]:
# Forward both sentence pairs through the model. This is inference only, so
# gradient tracking is disabled for the duration of the two calls.
with torch.no_grad():
    output_unlogical = model(labels=label_unlogical, **tokens_unlogical)
    output_logical = model(labels=label_logical, **tokens_logical)

# Show the raw model outputs, one per line.
print(output_unlogical, output_logical, sep="\n")

In [None]:
def check_if_logical(model_output):
    """Print whether the second sentence was judged a sensible continuation.

    Only the first row of ``model_output.logits`` is inspected, i.e. the first
    (or only) sentence pair of the batch.

    Args:
        model_output: Any object exposing a ``logits`` tensor that can be
            indexed as ``logits[0, 0]`` and ``logits[0, 1]``.

    Returns:
        None. The verdict is printed to stdout.
    """
    logits = model_output.logits
    # The pair is reported as nonsensical only when the score at index 1 is
    # strictly larger than the score at index 0; a tie counts as logical.
    if (logits[0, 0] < logits[0, 1]).item():
        print("The next sentence doesn't make sense.")
    else:
        print("The next sentence is logical.")


# Report the verdict for the implausible pair first, then for the plausible one.
for nsp_output in (output_unlogical, output_logical):
    check_if_logical(nsp_output)

## Question Answering

In [None]:
import torch
from transformers import BertTokenizer, BertForQuestionAnswering

In [None]:
# Load the question answering checkpoint. NOTE: this rebinds `tokenizer` and
# `model`, replacing the next sentence prediction objects created earlier, and
# the model is not moved to `device` here.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path="deepset/bert-base-cased-squad2"
)
model = BertForQuestionAnswering.from_pretrained(
    pretrained_model_name_or_path="deepset/bert-base-cased-squad2"
)

In [None]:
# The passage to search for an answer, followed by the question asked about it.
text = (
    "Joe Carrol is a man who loves to ride the waves whereas Hank, who is thirty years old, "
    "prefers to chill in his chair. There will never be any man like Hank."
)
question = "What is Hank his age?"

In [None]:
# Encode the (question, passage) pair as PyTorch tensors and run a forward pass
# with gradient tracking disabled, since this is inference only.
inputs = tokenizer(text=question, text_pair=text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

In [None]:
# Pick the most likely start and end token positions of the answer span.
# The two argmaxes are taken independently of each other, so the resulting span
# is validated before it is decoded.
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()
print(f"Start/End index answer: {answer_start_index} / {answer_end_index}")

# A span is rejected when
#   * it starts at position 0 — presumably the [CLS] token the tokenizer prepends,
#     which SQuAD 2.0-style models conventionally use to signal "no answer"
#     (TODO confirm for this checkpoint), or
#   * the end precedes the start, which would silently decode to an empty string.
span_is_valid = bool(answer_start_index > 0) and bool(answer_start_index <= answer_end_index)

if span_is_valid:
    predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    answer = tokenizer.decode(predict_answer_tokens)
else:
    # Keep the name defined (as an empty span) so later cells can still reference it.
    predict_answer_tokens = inputs.input_ids[0, :0]
    answer = "<no answer found>"
print(f"{question} -> {answer}")