In [34]:
import os
import random
from collections import namedtuple

from tokenizers import BertWordPieceTokenizer
from tokenizers.processors import TemplateProcessing
from transformers import BertTokenizer

In [11]:
# Root directory of the project. Every path below is derived from it so the
# notebook can run on another machine by setting the NSBERT_ROOT environment
# variable instead of editing each path; the default is the original location.
PROJECT_ROOT = os.environ.get(
    'NSBERT_ROOT',
    '/home/americanthinker/notebooks/pytorch/NationalSecurityBERT',
).rstrip('/')

# WordPiece vocabulary file used to build the tokenizers.
tokenizer_path = f'{PROJECT_ROOT}/Preprocessing/Tokenization/wp-vocab-30500-vocab.txt'
# Directory holding the text files (trailing slash kept so a file name can be appended).
text_data_path = f'{PROJECT_ROOT}/Data/Text/'
# Serialized encodings file.
encodings_data_path = f'{PROJECT_ROOT}/Data/Encodings/encodings_0_395390.pt'
# File name of the working corpus.
working_data = 'combined_4Gb.txt'

In [35]:
# One masked position: `index` is the position in the token list, `label` is
# the original token found at that position.
MaskedLmInstance = namedtuple("MaskedLmInstance", ["index", "label"])

def create_masked_lm_predictions(tokens: list, vocab_words: list, masked_lm_prob: float=0.15, max_predictions_per_seq: int=77, rng=None, pad_token: str="[PAD]"):
  """Creates the predictions for the masked LM objective (whole-word masking).

  Words are chosen at random; every WordPiece of a chosen word is replaced
  (80% "[MASK]", 10% left unchanged, 10% a random entry of `vocab_words`).

  Args:
    tokens: token strings of one sequence. "[CLS]", "[SEP]" and `pad_token`
      entries are never selected for masking.
    vocab_words: tokens to draw random replacements from. Must be non-empty,
      otherwise the random-replacement branch raises ValueError.
    masked_lm_prob: fraction of the non-padding tokens to predict.
    max_predictions_per_seq: upper bound on the number of predictions.
    rng: optional object exposing `shuffle`, `random` and `randint`
      (e.g. `random.Random(seed)`) for reproducible masking. Defaults to the
      global `random` module.
    pad_token: padding token to ignore, both as a masking candidate and when
      sizing the prediction budget. Pass None to treat every token as real.

  Returns:
    A tuple `(output_tokens, masked_lm_positions, masked_lm_labels)`:
    a copy of `tokens` with the replacements applied, the selected positions
    in ascending order, and the original tokens at those positions.
  """
  if rng is None:
    rng = random

  cand_indexes = []
  num_real_tokens = 0
  for (i, token) in enumerate(tokens):
    # Padding carries no content: it must neither be masked nor count
    # towards the number of predictions.
    if pad_token is not None and token == pad_token:
      continue
    num_real_tokens += 1
    if token == "[CLS]" or token == "[SEP]":
      continue
    # Whole Word Masking means that we mask all of the WordPieces
    # corresponding to an original word. When a word has been split into
    # WordPieces, the first token does not have any marker and any subsequent
    # tokens are prefixed with ##. So whenever we see a ## token, we
    # append it to the previous set of word indexes.
    #
    # Note that Whole Word Masking does *not* change the training code
    # at all -- we still predict each WordPiece independently, softmaxed
    # over the entire vocabulary.
    if cand_indexes and token.startswith("##"):
      cand_indexes[-1].append(i)
    else:
      cand_indexes.append([i])

  rng.shuffle(cand_indexes)

  output_tokens = list(tokens)

  # Without padding tokens this equals the budget computed from len(tokens).
  num_to_predict = min(max_predictions_per_seq, max(1, int(round(num_real_tokens * masked_lm_prob))))

  masked_lms = []
  for index_set in cand_indexes:
    if len(masked_lms) >= num_to_predict:
      break
    # If adding a whole-word mask would exceed the maximum number of
    # predictions, then just skip this candidate.
    if len(masked_lms) + len(index_set) > num_to_predict:
      continue
    # Candidate sets are disjoint by construction (each position is placed in
    # exactly one set), so no overlap bookkeeping is required here.
    for index in index_set:
      # 80% of the time, replace with [MASK]
      if rng.random() < 0.8:
        masked_token = "[MASK]"
      # 10% of the time, keep original
      elif rng.random() < 0.5:
        masked_token = tokens[index]
      # 10% of the time, replace with random word
      else:
        masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

      output_tokens[index] = masked_token

      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
  assert len(masked_lms) <= num_to_predict
  masked_lms.sort(key=lambda x: x.index)

  masked_lm_positions = [p.index for p in masked_lms]
  masked_lm_labels = [p.label for p in masked_lms]

  return (output_tokens, masked_lm_positions, masked_lm_labels)

In [36]:
# `transformers` tokenizer loaded from the same vocabulary file.
alternative_tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

# `tokenizers` WordPiece tokenizer: lowercases, strips accents, truncates to
# 512 tokens and has padding enabled with the library defaults.
tokenizer = BertWordPieceTokenizer(
    tokenizer_path,
    strip_accents=True,
    lowercase=True,
)
tokenizer.enable_truncation(max_length=512)
tokenizer.enable_padding()

# Wrap a single sequence as "[CLS] A [SEP]" and a pair as
# "[CLS] A [SEP] B [SEP]" (second segment gets type id 1).
tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[
        (special, tokenizer.token_to_id(special))
        for special in ("[CLS]", "[SEP]", "[MASK]")
    ],
)



In [37]:
# Sample sentence fed to the tokenizer and the masking function in the cells below.
text = (
    "This city of Hoboken is known for snuffleupagus, histrionics, "
    "missspellled words, such as acetylcholinesterase and dopaminergic effects."
)

In [38]:
# Encode the sample sentence; the `.tokens` attribute of the result is what
# the masking call below consumes.
tokens = tokenizer.encode(text)

In [58]:
# Build the vocabulary list here: `vocab` is otherwise only assigned in a
# later cell, so this call raises NameError on a fresh Restart & Run All.
vocab = sorted(tokenizer.get_vocab().keys())
create_masked_lm_predictions(tokens.tokens, vocab)

(['[CLS]',
  'this',
  'city',
  'of',
  'hob',
  '##oken',
  'is',
  'known',
  '[MASK]',
  'sn',
  '##uff',
  '##le',
  '##up',
  '##agus',
  ',',
  'hist',
  '##rion',
  '##ics',
  ',',
  'miss',
  '##sp',
  '##ell',
  '##led',
  'words',
  ',',
  'such',
  'as',
  '[MASK]',
  '[MASK]',
  '[MASK]',
  '[MASK]',
  'and',
  'dopaminergic',
  'effects',
  '.',
  '[SEP]'],
 [8, 27, 28, 29, 30],
 ['for', 'acetylcholine', '##st', '##era', '##se'])

In [49]:
# Alphabetically sorted list of every token string in the tokenizer's vocabulary
# (iterating the vocab mapping yields its keys).
vocab = sorted(tokenizer.get_vocab())