In [None]:
!pip install pytorch-transformers > /dev/null
import torch
import numpy as np
from pytorch_transformers import BertTokenizer, BertForMaskedLM
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:
# Load the pretrained 'bert-base-uncased' WordPiece tokenizer.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Load the matching masked-language-model head; output_attentions=True asks the
# model to include its attention weights in the returned outputs.
model = BertForMaskedLM.from_pretrained('bert-base-uncased', output_attentions=True)
# Switch the model to evaluation mode before running predictions.
model.eval()

In [None]:
def predict():
    """Prompt for a text with ``____`` blanks and fill each blank using BERT.

    The text is read from standard input. Every ``____`` is turned into a
    ``[MASK]`` token and the masks are filled one at a time, left to right,
    each with the single highest-scoring token predicted by the module-level
    ``model``. Tokens already predicted stay in the text, so they are part of
    the context for the masks that follow.

    Returns:
        str: the (truncated) input text with every ``____`` replaced by its
        predicted token, or the truncated input unchanged when it contains
        no ``____``.
    """
    # Only the first 513 characters are kept. This is a character limit,
    # not a limit on the number of BERT tokens.
    sentence_orig = input('Enter text:')[:513]
    if '____' not in sentence_orig:
        return sentence_orig

    sentence = sentence_orig.replace('____', '[MASK]')
    sentences = nltk.sent_tokenize(sentence)
    if len(sentences) > 2:
        # The input is encoded as two segments, so everything after the first
        # sentence is merged into one pseudo-sentence. Only a real sentence
        # terminator is dropped before joining: stripping the last character
        # unconditionally would turn a trailing "[MASK]" into "[MASK" and the
        # blank would never be filled.
        merged = ' '.join(
            s[:-1] if s.endswith(('.', '!', '?')) else s
            for s in sentences[1:]
        )
        sentences = [sentences[0], merged + '.']
    sentence = "[CLS] " + " [SEP] ".join(sentences) + " [SEP]"

    while '[MASK]' in sentence:
        tokenized_text = tokenizer.tokenize(sentence)
        # Position of the left-most mask still waiting for a prediction.
        masked_index = tokenized_text.index('[MASK]')
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

        # Segment id is 0 up to and including the first [SEP], then
        # increases by one after each [SEP].
        segments_ids = []
        segment = 0
        for token in tokenized_text:
            segments_ids.append(segment)
            if token == '[SEP]':
                segment += 1

        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segments_ids])

        with torch.no_grad():
            outputs = model(tokens_tensor, token_type_ids=segments_tensors)
        # The first output holds the vocabulary scores for every position.
        predictions = outputs[0]

        predicted_index = torch.argmax(predictions[0, masked_index]).item()
        predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
        # Fill the same (left-most) blank in both the model input and the
        # text that is returned to the user.
        sentence = sentence.replace('[MASK]', predicted_token, 1)
        sentence_orig = sentence_orig.replace('____', predicted_token, 1)

    return sentence_orig

In [None]:
predict()

Enter text:Today I went to the ____ and bought some milk and eggs. I knew it was going to rain but I forgot to take my ____ and ended up getting wet on the way.


'Today I went to the store and bought some milk and eggs. I knew it was going to rain but I forgot to take my shower and ended up getting wet on the way.'

In [None]:
predict()

Enter text:Today is sunny ____ we may go to the picnic.


'Today is sunny . we may go to the picnic.'

In [None]:
predict()

Enter text:Today I went to the ____ and bought some milk and eggs. I knew it was going to rain but I forgot to take my ____ and ended up getting wet on the way.


'Today I went to the store and bought some milk and eggs. I knew it was going to rain but I forgot to take my shower and ended up getting wet on the way.'

In [None]:
predict()

Enter text:____ mother is upset with ____ because I got caught in the rain.


'my mother is upset with me because I got caught in the rain.'

In [None]:
predict()

Enter text:The balloon is filled with helium. So it is ____ than air.


'The balloon is filled with helium. So it is lighter than air.'

In [None]:
predict()

Enter text:Animals hibernate during ____ when it is very cold.


'Animals hibernate during winter when it is very cold.'

In [None]:
predict()

Enter text:Imagine what it ____ be like if you ____ in your bedroom during an earthquake. Books and stuffed animals tumble ____ shelves. Your computer monitor skitters ____ your desk and crashes to the floor. The walls creak and groan as they flex. Your whole house could ____ in an earthquake.


'Imagine what it would be like if you were in your bedroom during an earthquake. Books and stuffed animals tumble off shelves. Your computer monitor skitters off your desk and crashes to the floor. The walls creak and groan as they flex. Your whole house could collapse in an earthquake.'

In [None]:
predict()

Enter text:Imagine what it ____ be like if you ____ in your bedroom during an earthquake. Books and stuffed animals tumble ____ shelves. Your computer monitor skitters ____ your desk and crashes to the floor. The walls creak and groan as they flex. Your whole house could ____.


'Imagine what it would be like if you were in your bedroom during an earthquake. Books and stuffed animals tumble off shelves. Your computer monitor skitters off your desk and crashes to the floor. The walls creak and groan as they flex. Your whole house could collapse.'

In [None]:
predict()

Enter text:The girl did not cross the street because ____ was too wide.


'The girl did not cross the street because it was too wide.'

In [None]:
predict()

Enter text:The girl did not cross the street because ____ was too slow.


'The girl did not cross the street because she was too slow.'

In [None]:
# With the newer `transformers` package the equivalent import would be:
# from transformers import BertTokenizer, BertForMaskedLM
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
# Evaluation mode; same effect as chaining .eval() onto from_pretrained().
bert_model.eval()

# Number of raw candidates requested from the model for a mask.
top_k = 10

In [None]:
import string
def decode(tokenizer, pred_idx, top_clean):
    """Turn predicted token ids into a newline-separated list of clean words.

    Each id is decoded with ``tokenizer.decode`` and stripped of all
    whitespace. Empty tokens, the ``[PAD]`` token and tokens made up only of
    punctuation characters are skipped; the WordPiece continuation marker
    ``##`` is removed from the tokens that are kept.

    Args:
        tokenizer: object exposing ``decode(token_id) -> str``.
        pred_idx: iterable of predicted token ids, best candidate first.
        top_clean: maximum number of cleaned tokens to return.

    Returns:
        str: up to ``top_clean`` tokens joined with ``'\\n'`` (empty string
        when nothing survives the filtering).
    """
    # A set gives exact per-character membership. The previous string-based
    # check was a substring test, so it also discarded real tokens such as
    # "A" or "PAD" merely because they occur inside "[PAD]".
    punctuation = set(string.punctuation)
    tokens = []
    for w in pred_idx:
        token = ''.join(tokenizer.decode(w).split())
        if not token or token == '[PAD]':
            continue
        if all(ch in punctuation for ch in token):
            continue
        tokens.append(token.replace('##', ''))
    return '\n'.join(tokens[:top_clean])


def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Encode a sentence containing ``<mask>`` and locate the first mask.

    Args:
        tokenizer: object exposing ``mask_token``, ``mask_token_id`` and
            ``encode(text, add_special_tokens=...)``.
        text_sentence: text in which the blank is written as ``<mask>``.
        add_special_tokens: forwarded unchanged to ``tokenizer.encode``.

    Returns:
        tuple: ``(input_ids, mask_idx)`` where ``input_ids`` is a tensor
        holding a single encoded sequence (one row) and ``mask_idx`` is the
        integer position of the first mask token in that row.

    Raises:
        ValueError: if the encoded text contains no mask token (this includes
            empty or whitespace-only input).
    """
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
    # If the mask is the last word, append a "." so that the model does not
    # simply predict punctuation. Guard against empty input, for which
    # split() yields no words at all.
    words = text_sentence.split()
    if words and words[-1] == tokenizer.mask_token:
        text_sentence += ' .'

    input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
    # Column indices of every mask token in the single encoded row.
    mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
    if not mask_positions:
        raise ValueError("no mask token found in the input; mark the blank with '<mask>'")
    return input_ids, mask_positions[0]


In [None]:
def get_all_predictions(text_sentence, top_clean=5):
    """Print a sentence followed by BERT's cleaned candidates for its mask.

    Uses the module-level ``bert_tokenizer``, ``bert_model`` and ``top_k``:
    the ``top_k`` highest-scoring token ids at the mask position are decoded
    and at most ``top_clean`` of the cleaned tokens are printed, one per line.
    Nothing is returned.
    """
    print(text_sentence)
    token_ids, mask_position = encode(bert_tokenizer, text_sentence)

    with torch.no_grad():
        scores = bert_model(token_ids)[0]

    # Scores over the vocabulary at the mask position of the only sequence.
    mask_scores = scores[0, mask_position, :]
    candidate_ids = mask_scores.topk(top_k).indices.tolist()
    print(decode(bert_tokenizer, candidate_ids, top_clean))

In [None]:
def get_prediction_eos():
    """Prompt for a text and print predictions for the word that follows it.

    Whitespace in the input is normalised to single spaces and ``' <mask>'``
    is appended, so the mask is always the last word. Any exception raised
    while predicting is caught and its message is printed instead, which keeps
    the interactive cell from failing with a traceback.
    """
    try:
        input_text = ' '.join(input("Enter text: ").split())
        input_text += ' <mask>'
        # get_all_predictions prints its own results and returns None, so its
        # return value is deliberately not printed (that only emitted "None").
        get_all_predictions(input_text, top_clean=top_k)
    except Exception as error:
        print(str(error))

def get_prediction_mask():
    """Prompt for a text containing ``<mask>`` and print predictions for it.

    Whitespace in the input is normalised to single spaces before the text is
    passed on. Any exception raised while predicting is caught and its message
    is printed instead, which keeps the interactive cell from failing with a
    traceback.
    """
    try:
        input_text = ' '.join(input("Enter text: ").split())
        # get_all_predictions prints its own results and returns None, so its
        # return value is deliberately not printed (that only emitted "None").
        get_all_predictions(input_text, top_clean=top_k)
    except Exception as error:
        print(str(error))

In [None]:
get_prediction_eos()

Enter text: today is a good
today is a good <mask>
day
morning
time
week
night
afternoon
year
month
evening
thing
None


In [None]:
get_prediction_mask()

Enter text: Today I went to the <mask> and bought some eggs and milk.
Today I went to the <mask> and bought some eggs and milk.
store
grocery
supermarket
fridge
kitchen
refrigerator
dairy
market
mall
bathroom
None


In [None]:
get_prediction_mask()

Enter text: The girl did not cross the street because <mask> was too wide.
The girl did not cross the street because <mask> was too wide.
tensor([[[ -6.6597,  -6.6559,  -6.6483,  ...,  -6.0181,  -5.8601,  -3.9669],
         [-18.8028, -18.6821, -18.5608,  ..., -15.9036, -14.8830, -13.1857],
         [-11.7190, -11.8977, -11.6776,  ...,  -9.2338, -10.3024,  -4.3298],
         ...,
         [ -9.8016, -10.0319,  -9.8785,  ...,  -8.0871,  -8.3019,  -5.6903],
         [-12.3961, -12.0771, -12.2415,  ..., -10.5092, -10.7978,  -6.7733],
         [-15.8802, -15.9526, -15.8541,  ..., -14.0989, -13.6420,  -9.7592]]])
it
she
this
one
that
hers
he
there
her
everything
None


In [None]:
get_prediction_mask()

Enter text: Today is <mask>
Today is <mask>
tensor([[[ -6.5143,  -6.4589,  -6.4877,  ...,  -5.8938,  -5.6277,  -3.9238],
         [-11.3656, -11.4199, -11.1213,  ...,  -9.1138,  -9.6923,  -8.7958],
         [-13.9106, -13.7745, -13.7661,  ..., -12.6961, -10.6249,  -8.2464],
         [ -6.1176,  -6.0248,  -6.0567,  ...,  -5.4127,  -6.3371,  -3.1804],
         [-10.7985, -10.4356, -11.0958,  ...,  -8.7714,  -8.7744,  -4.9093],
         [-13.9459, -13.8575, -14.1216,  ..., -11.4664, -10.7277,  -8.8884]]])
different
better
today
good
sunday
nothing
summer
perfect
not
saturday
None
