# Introduction

This notebook is intended for experimenting with the Masked Language Modeling (MLM) task.

# BERT

In [2]:
# Import required libraries
from transformers import BertTokenizer, BertForPreTraining
import torch

In [3]:
# Initialize the tokenizer for preprocessing
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [4]:
# Initialize the model (NOTE: BertForPreTraining is BERT with both pre-training heads:
# masked language modeling and next sentence prediction)
model = BertForPreTraining.from_pretrained('bert-base-uncased')

In [6]:
# Define input text
text = ("After Abraham Lincoln won the November 1860 presidential [MASK] on an "
        "anti-slavery platform, an initial seven slave states declared their "
        "secession from the country to form the Confederacy.")
text2 = ("War broke out in April 1861 when secessionist forces [MASK] Fort "
         "Sumter in South Carolina, just over a month after Lincoln's "
         "inauguration.")

In [7]:
# Tokenize both sentences together as a single sentence pair
tokens = tokenizer(text, text2, return_tensors="pt")

In [11]:
tokens.token_type_ids

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

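The values in `tokens.token_type_ids` distinguish Sentence A (0) from Sentence B (1). The `[CLS]` token and the first `[SEP]` are counted as part of Sentence A, and the final `[SEP]` as part of Sentence B.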
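
As a minimal sketch of how these segment IDs line up with the input, assuming the standard BERT pair layout `[CLS] A [SEP] B [SEP]`, the helper below rebuilds the same pattern from the number of WordPiece tokens in each sentence. `build_token_type_ids` is a hypothetical name, not part of `transformers`, and the lengths 32 and 28 are read off the output above rather than recomputed.

```python
def build_token_type_ids(len_a, len_b):
    """Segment IDs for the layout [CLS] A [SEP] B [SEP]."""
    # [CLS], sentence A, and the first [SEP] belong to segment 0
    segment_a = [0] * (1 + len_a + 1)
    # Sentence B and the final [SEP] belong to segment 1
    segment_b = [1] * (len_b + 1)
    return segment_a + segment_b


# Token counts read off the tensor above (assumed, not recomputed here)
type_ids = build_token_type_ids(len_a=32, len_b=28)
print(type_ids.count(0), type_ids.count(1), len(type_ids))  # 34 29 63
```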

In [12]:
tokens.input_ids

tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,   103,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,   102,  2162,  3631,  2041,  1999,  2258,  6863,
          2043, 22965,  2923,  2749,   103,  3481,  7680,  3334,  1999,  2148,
          3792,  1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055,
         17331,  1012,   102]])

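In `tokens.input_ids`, the two sentences are separated by the special `[SEP]` token (ID 102), which also closes the sequence. The sequence starts with `[CLS]` (ID 101), and each `[MASK]` placeholder is encoded as ID 103.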
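
To make the layout concrete, the following sketch locates the special tokens in the ID sequence shown above. It assumes the `bert-base-uncased` vocabulary, where `[CLS]`, `[SEP]`, and `[MASK]` map to 101, 102, and 103. The list is copied by hand from the output rather than recomputed, and `SPECIAL_TOKENS` is an illustrative name.

```python
# IDs copied from the tokens.input_ids output above
input_ids = [
    101, 2044, 8181, 5367, 2180, 1996, 2281, 7313, 4883, 103,
    2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,
    6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,
    1996, 18179, 1012, 102, 2162, 3631, 2041, 1999, 2258, 6863,
    2043, 22965, 2923, 2749, 103, 3481, 7680, 3334, 1999, 2148,
    3792, 1010, 2074, 2058, 1037, 3204, 2044, 5367, 1005, 1055,
    17331, 1012, 102,
]

# Special-token IDs in the bert-base-uncased vocabulary
SPECIAL_TOKENS = {"[CLS]": 101, "[SEP]": 102, "[MASK]": 103}

# Collect every position at which each special token occurs
positions = {
    name: [i for i, token_id in enumerate(input_ids) if token_id == special_id]
    for name, special_id in SPECIAL_TOKENS.items()
}
print(positions)
```

Each sentence contains exactly one `[MASK]`, and the first `[SEP]` sits at the boundary where `token_type_ids` switches from 0 to 1.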

In [13]:
# Compute outputs
outputs = model(**tokens)
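
Besides the next sentence prediction scores, the output of `BertForPreTraining` carries `prediction_logits` with shape `(batch, sequence_length, vocab_size)`, which holds the masked language modeling scores. The sketch below illustrates the decoding idea on plain Python lists with a made-up six-entry vocabulary, so it runs without loading the model. `predict_masked`, `toy_vocab`, and the toy scores are all hypothetical and only stand in for the real tensors.

```python
def predict_masked(input_ids, logits, mask_id):
    """Return {position: best vocabulary index} for every masked position."""
    predictions = {}
    for position, token_id in enumerate(input_ids):
        if token_id == mask_id:
            scores = logits[position]
            # Index of the highest score over the vocabulary at this position
            predictions[position] = max(range(len(scores)), key=scores.__getitem__)
    return predictions


# Toy vocabulary: the list index plays the role of the token ID
toy_vocab = ["[CLS]", "[SEP]", "[MASK]", "election", "war", "the"]
toy_mask_id = toy_vocab.index("[MASK]")

# Encodes "[CLS] the [MASK] [SEP]"
toy_input_ids = [0, 5, 2, 1]

# One row of made-up vocabulary scores per input position
toy_logits = [
    [3.0, 0.1, 0.0, 0.2, 0.1, 0.3],
    [0.0, 0.1, 0.0, 0.4, 0.2, 2.8],
    [0.0, 0.0, 0.1, 2.5, 1.1, 0.3],
    [0.2, 3.1, 0.0, 0.1, 0.0, 0.2],
]

predicted = predict_masked(toy_input_ids, toy_logits, toy_mask_id)
print({position: toy_vocab[index] for position, index in predicted.items()})
```

With the real model, the same lookup would presumably be applied to `outputs.prediction_logits[0]`, using `tokenizer.mask_token_id` as the mask ID and `tokenizer.convert_ids_to_tokens` to map the winning indices back to tokens.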

In [15]:
outputs.seq_relationship_logits

tensor([[ 6.0843, -5.6813]], grad_fn=<AddmmBackward0>)

In [17]:
# Get the index of the highest logit (0 = IsNext, 1 = NotNext)
argmax = torch.argmax(outputs.seq_relationship_logits, dim=-1)

In [18]:
'NotNext' if argmax.item() else 'IsNext'

'IsNext'
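
The raw logits are easier to interpret as probabilities. As a small sketch, the snippet below applies a softmax to the two values printed above using only the standard library. The numbers are copied by hand, so treat them as illustrative. Index 0 corresponds to `IsNext` and index 1 to `NotNext`, matching the mapping used in the last cell.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)  # subtracting the maximum avoids overflow in exp
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Values copied from outputs.seq_relationship_logits above
seq_relationship_logits = [6.0843, -5.6813]
labels = ["IsNext", "NotNext"]

probs = softmax(seq_relationship_logits)
best = max(range(len(probs)), key=probs.__getitem__)
print(labels[best], probs)
```

The gap of almost 12 between the two logits means the model assigns nearly all of the probability mass to `IsNext`.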