<a href="https://colab.research.google.com/github/YuvalPeleg/transformers-workshop/blob/master/BertNSP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
pip install pytorch_transformers

In [0]:
import torch
from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM,BertForNextSentencePrediction

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Load pre-trained model (weights)
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
model = model.to(DEVICE)
# eval() switches off dropout, so repeated predictions on the same input agree
model = model.eval()



In [0]:
"""
Calculates the probability that sentence2 follows sentence1

Arguments
  sentence1 (str): the first sentence, must be space tokenized
  sentence2 (str): the second sentence, must be space tokenized
  model: (BertForNextSentencePrediction)
  tokenizer: The corresponding tokenizer for the model

Returns
  The probability that sentence2 follows sentence1

"""

def predict_NSP(sentence1, sentence2, model, tokenizer, device=None):
  """
  Calculates the probability that sentence2 follows sentence1.

  Arguments
    sentence1 (str): the first sentence. It should not contain a literal
      "[SEP]" token, because the first [SEP] marks the segment boundary.
    sentence2 (str): the second sentence, scored as a possible continuation
      of sentence1
    model: (BertForNextSentencePrediction)
    tokenizer: The corresponding tokenizer for the model
    device: where to run the forward pass. Defaults to the device holding
      the model's parameters (CPU if the model has none).

  Returns
    float: the softmax probability of the model's first NSP class, i.e. the
    probability that sentence2 follows sentence1
  """
  text_input = f"[CLS] {sentence1} [SEP] {sentence2} [SEP]"
  tokenized_text = tokenizer.tokenize(text_input)

  # Convert tokens to vocabulary indices; a batch of one -> shape (1, seq_len)
  indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
  tokens_tensor = torch.tensor([indexed_tokens])

  # Segment (token type) ids: 0 for "[CLS] sentence1 [SEP]", 1 for the rest
  # (the sentence A / sentence B embeddings from the BERT paper).
  sep_index = (tokens_tensor[0] == tokenizer.sep_token_id).nonzero()[0][0].item()
  segment_ids = torch.ones_like(tokens_tensor)  # integer dtype, like the ids
  segment_ids[0, :sep_index + 1] = 0

  if device is None:
    # Run wherever the model lives instead of assuming a GPU is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

  # Inference only: don't build an autograd graph.
  with torch.no_grad():
    # Pass segment ids by keyword: newer `transformers` releases take
    # attention_mask as the second positional argument.
    predictions = model(tokens_tensor.to(device),
                        token_type_ids=segment_ids.to(device))

  # predictions[0] holds the NSP logits for the single example in the batch
  return torch.softmax(predictions[0][0], 0)[0].item()
  
# Probability that the (unrelated) second sentence follows the first
predict_NSP(sentence1="Who was Jim Henson ?", sentence2="Paris is in France", model=model, tokenizer=tokenizer)





In [0]:
# Compare an unrelated pair with a coherent one; the coherent pair is
# expected to get the higher probability.
sentence_pairs = [
    ("Who was Jim Henson ?", "Paris is in France"),
    ("Who was Jim Henson ?", "Jim Henson was a puppeteer"),
]
for first, second in sentence_pairs:
    print(f"{predict_NSP(first, second, model, tokenizer):.4f}  {first!r} -> {second!r}")

In [0]:
"""
Task:
Try some sentence pairs of your own to get a feel for next sentence prediction (NSP).
Look at the [CLS] token at the start of the input. Why is it there?

"""
