# Next Sentence Prediction (NSP)


Use **AutoModelForNextSentencePrediction** to predict whether the second sentence of a pair follows the first.
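As background, the original BERT paper describes how the NSP head was pretrained. Each sentence A is paired either with the sentence that actually follows it (label 0 in the Hugging Face convention, meaning "is a continuation") or with a random sentence from the corpus (label 1), each about half the time. The sketch below illustrates that recipe on a toy list of sentences. `make_nsp_pairs` is a hypothetical helper, not part of `transformers`, and a real pipeline would normally draw the random sentence from a different document.

```python
import random


def make_nsp_pairs(sentences, seed=0):
    """Build (sentence_a, sentence_b, label) triples in the style of BERT's NSP task.

    label 0: sentence_b is the sentence that really follows sentence_a
    label 1: sentence_b is a randomly chosen other sentence
    """
    if len(sentences) < 3:
        raise ValueError("need at least three sentences to sample random pairs")
    rng = random.Random(seed)
    pairs = []
    for i in range(len(sentences) - 1):
        a = sentences[i]
        if rng.random() < 0.5:
            # true continuation
            pairs.append((a, sentences[i + 1], 0))
        else:
            # any sentence except A itself and its true successor
            candidates = [s for j, s in enumerate(sentences) if j not in (i, i + 1)]
            pairs.append((a, rng.choice(candidates), 1))
    return pairs


corpus = [
    "The cat sat on the rug.",
    "It watched the birds outside.",
    "Clouds are white",
    "Rain fell all afternoon.",
]
for a, b, label in make_nsp_pairs(corpus):
    print(label, "|", a, "->", b)
```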

```python
from transformers import AutoModelForNextSentencePrediction, AutoTokenizer

import torch
```

## 1. Create the model for NSP

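```python
# change this to try out other checkpoints that have an NSP head (e.g. other BERT variants)
model_name = "bert-base-uncased"

# create an instance of the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# load the model together with its pretrained next-sentence-prediction head
model = AutoModelForNextSentencePrediction.from_pretrained(model_name)

# uncomment to display the model architecture
# model
```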
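For `bert-base-uncased`, the auto class resolves to `BertForNextSentencePrediction`, which places a two-way classifier on the pooled `[CLS]` representation, so every sentence pair yields two logits. To see that shape without downloading weights, the sketch below builds a tiny, randomly initialized model from a `BertConfig`. The config sizes are arbitrary assumptions chosen to keep it small, so its predictions are meaningless and only the output shape matters.

```python
import torch
from transformers import BertConfig, BertForNextSentencePrediction

# deliberately tiny, randomly initialized config (sizes are arbitrary)
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
)
tiny_model = BertForNextSentencePrediction(config)
tiny_model.eval()

# a fake batch of 3 "sentence pairs", 8 token IDs each
input_ids = torch.randint(0, config.vocab_size, (3, 8))
# first 4 positions belong to sentence A (segment 0), last 4 to sentence B (segment 1)
token_type_ids = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1]] * 3)

with torch.no_grad():
    out = tiny_model(input_ids=input_ids, token_type_ids=token_type_ids)

# one pair of logits per example: column 0 = "is next", column 1 = "not next"
print(out.logits.shape)  # torch.Size([3, 2])
```

This is why the cells below take `argmax(dim=1)`: dimension 1 holds the two NSP classes.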

## 2. Carry out next sentence prediction



### 2.1 Sentence is NOT a continuation

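```python
# example sentence pair; the tokenizer adds BERT's special tokens
# ([CLS] at the start, [SEP] after each sentence) automatically
input_text_1 = "The cat sat on the rug."
input_text_2 = "Clouds are white"

# encode both sentences as a single pair
inputs = tokenizer(input_text_1, input_text_2, return_tensors="pt")

# optional label, only used to compute outputs.loss:
# 0 = 2nd sentence continues the 1st, 1 = 2nd sentence is random
labels = torch.LongTensor([1])

# inference only, so skip gradient tracking
with torch.no_grad():
    outputs = model(**inputs, labels=labels)

# index of the larger of the two logits (same as the higher probability)
prediction = outputs.logits.argmax(dim=1).item()

if prediction == 0:
    print("2nd sentence is a continuation of the 1st sentence.")
else:
    print("2nd sentence is NOT a continuation of the 1st sentence.")
```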


2nd sentence is NOT a continuation of the 1st sentence.
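`argmax` only reports the winning class. A softmax over the two NSP logits turns them into probabilities, which gives a sense of how confident the model is. When `labels` is passed, `outputs.loss` holds the cross-entropy between the logits and that label, which the sketch also checks by hand. The logits tensor below is made up for illustration; in the notebook you would pass `outputs.logits` instead, and `nsp_probabilities` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def nsp_probabilities(logits):
    """Turn NSP logits of shape (batch, 2) into probabilities that sum to 1 per row."""
    return F.softmax(logits, dim=-1)


# hypothetical logits for one sentence pair (stand-in for outputs.logits)
logits = torch.tensor([[-2.0, 3.0]])
labels = torch.LongTensor([1])

probs = nsp_probabilities(logits)
print(f"P(continuation)       = {probs[0, 0].item():.4f}")
print(f"P(not a continuation) = {probs[0, 1].item():.4f}")

# cross-entropy is -log p(label); this matches outputs.loss when labels are given
loss_builtin = F.cross_entropy(logits, labels)
loss_manual = -torch.log(probs[0, labels[0]])
print(torch.allclose(loss_builtin, loss_manual))  # True
```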


### 2.2 Sentence is a continuation

```python
# example sentence pair; the tokenizer adds the special tokens automatically
input_text_1 = "The cat sat on the rug."
input_text_2 = "It watched the birds outside."

# encode both sentences as a single pair
inputs = tokenizer(input_text_1, input_text_2, return_tensors="pt")

# this pair is a real continuation, so the matching label is 0
labels = torch.LongTensor([0])

# inference only, so skip gradient tracking
with torch.no_grad():
    outputs = model(**inputs, labels=labels)

# index of the larger of the two logits (same as the higher probability)
prediction = outputs.logits.argmax(dim=1).item()

if prediction == 0:
    print("2nd sentence is a continuation of the 1st sentence.")
else:
    print("2nd sentence is NOT a continuation of the 1st sentence.")
```

2nd sentence is a continuation of the 1st sentence.
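The two cells above score one pair at a time. The tokenizer also accepts a list of first sentences and a list of second sentences, so several pairs can be scored in one forward pass; `padding=True` pads the shorter pairs, and the returned `attention_mask` tells the model to ignore that padding. A minimal sketch, assuming the same `bert-base-uncased` checkpoint; `predict_next_batch` is a hypothetical helper name.

```python
import torch
from transformers import AutoModelForNextSentencePrediction, AutoTokenizer

model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForNextSentencePrediction.from_pretrained(model_name)
model.eval()  # dropout off (from_pretrained already does this; kept for clarity)


def predict_next_batch(first_sentences, second_sentences):
    """Return P(2nd sentence continues the 1st) for each pair, as a 1-D tensor."""
    # parallel lists for text and text_pair give one encoded pair per position
    inputs = tokenizer(
        first_sentences,
        second_sentences,
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**inputs).logits  # shape (batch, 2)
    # column 0 is the "continuation" class
    return torch.softmax(logits, dim=-1)[:, 0]


firsts = ["The cat sat on the rug.", "The cat sat on the rug."]
seconds = ["Clouds are white", "It watched the birds outside."]
for a, b, p in zip(firsts, seconds, predict_next_batch(firsts, seconds)):
    print(f"{p.item():.3f}  {a} -> {b}")
```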


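To see the special tokens that the tokenizer inserts, encode the pair and decode it back to text:

```python
# a single sentence, encoded in the next cell
input_sequence = "The cat sat on the rug."

# encode the sentence pair from the previous cell (text, text_pair)
encoded = tokenizer.encode(input_text_1, input_text_2)

# decode the IDs back into a string, special tokens included
tokenizer.decode(encoded)
```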

'[CLS] the cat sat on the rug. [SEP] it watched the birds outside. [SEP]'
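Decoding joins the word pieces back into one string. To see the individual tokens, together with the `token_type_ids` (segment IDs) that tell BERT which sentence each token belongs to, the same pair can be inspected with `convert_ids_to_tokens`. A short sketch, assuming the `bert-base-uncased` tokenizer used throughout:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

input_text_1 = "The cat sat on the rug."
input_text_2 = "It watched the birds outside."

# without return_tensors, the fields come back as plain Python lists
pair = tokenizer(input_text_1, input_text_2)
tokens = tokenizer.convert_ids_to_tokens(pair["input_ids"])

# 0 = token from the first sentence (including [CLS] and the first [SEP])
# 1 = token from the second sentence (including the final [SEP])
for token, segment in zip(tokens, pair["token_type_ids"]):
    print(f"{segment}  {token}")
```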

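Encoding a single sentence produces one `[CLS]` (ID 101) at the start and one `[SEP]` (ID 102) at the end:

```python
tokenizer(input_sequence)["input_ids"]
```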

[101, 1996, 4937, 2938, 2006, 1996, 20452, 1012, 102]
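The IDs above show the single-sentence layout: `[CLS]` (101) first, `[SEP]` (102) last, with the word-piece IDs in between. For a pair, BERT uses `[CLS] A [SEP] B [SEP]`, with segment ID 0 for the first part and 1 for the second. The pure-Python sketch below rebuilds that layout by hand to make it explicit. `build_bert_inputs` is a hypothetical helper; the special-token IDs 101 and 102 are taken from the output above and are specific to this vocabulary, and in practice the tokenizer does all of this for you.

```python
def build_bert_inputs(ids_a, ids_b=None, cls_id=101, sep_id=102):
    """Lay out one or two sequences of token IDs the way BERT expects.

    Single sentence: [CLS] A [SEP]
    Sentence pair:   [CLS] A [SEP] B [SEP]
    """
    input_ids = [cls_id] + list(ids_a) + [sep_id]
    # segment 0 covers [CLS], sentence A and its [SEP]
    token_type_ids = [0] * len(input_ids)
    if ids_b is not None:
        second = list(ids_b) + [sep_id]
        input_ids += second
        # segment 1 covers sentence B and the final [SEP]
        token_type_ids += [1] * len(second)
    # no padding here, so every position is attended to
    attention_mask = [1] * len(input_ids)
    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
    }


# word-piece IDs for "the cat sat on the rug ." taken from the output above
ids_a = [1996, 4937, 2938, 2006, 1996, 20452, 1012]
single = build_bert_inputs(ids_a)
print(single["input_ids"])  # prints [101, 1996, 4937, 2938, 2006, 1996, 20452, 1012, 102]

# pairing the sentence with itself, just to show the segment pattern
pair = build_bert_inputs(ids_a, ids_a)
print(pair["token_type_ids"])
```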