In [33]:
import collections
import random
from transformers import *

In [53]:
import six
import numpy as np

In [44]:
class DummyObject:
    def __init__(self,**kwargs):
        self.__dict__.update(kwargs)

FLAGS=DummyObject(do_whole_word_mask=True,
                 ngram=3,
                 favor_shorter_ngram=False,
                 do_permutation=False,
                 random_seed=3107,
                 max_predictions_per_seq=5,
                 masked_lm_prob=0.15)

In [45]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [46]:
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])

In [47]:
rng = random.Random(FLAGS.random_seed)
vocab_words = list(tokenizer.vocab.keys())
tokens = tokenizer.tokenize('hi how are you doing, i am fine and how about you. you say. i am also good')
max_predictions_per_seq = FLAGS.max_predictions_per_seq
masked_lm_prob = FLAGS.masked_lm_prob

In [51]:
def _is_start_piece_bert(piece):
  """Check if the current word piece is the starting piece (BERT)."""
  # When a word has been split into
  # WordPieces, the first token does not have any marker and any subsequence
  # tokens are prefixed with ##. So whenever we see the ## token, we
  # append it to the previous set of word indexes.
  return not six.ensure_str(piece).startswith("##")


def is_start_piece(piece):
    return _is_start_piece_bert(piece)

def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng):
  """Creates the predictions for the masked LM objective."""

  cand_indexes = []
  # Note(mingdachen): We create a list for recording if the piece is
  # the starting piece of current token, where 1 means true, so that
  # on-the-fly whole word masking is possible.
  token_boundary = [0] * len(tokens)

  for (i, token) in enumerate(tokens):
    if token == "[CLS]" or token == "[SEP]":
      token_boundary[i] = 1
      continue
    # Whole Word Masking means that if we mask all of the wordpieces
    # corresponding to an original word.
    #
    # Note that Whole Word Masking does *not* change the training code
    # at all -- we still predict each WordPiece independently, softmaxed
    # over the entire vocabulary.
    if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
        not is_start_piece(token)):
      cand_indexes[-1].append(i)
    else:
      cand_indexes.append([i])
      if is_start_piece(token):
        token_boundary[i] = 1

  output_tokens = list(tokens)

  masked_lm_positions = []
  masked_lm_labels = []

  if masked_lm_prob == 0:
    return (output_tokens, masked_lm_positions,
            masked_lm_labels, token_boundary)

  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))

  # Note(mingdachen):
  # By default, we set the probilities to favor longer ngram sequences.
  ngrams = np.arange(1, FLAGS.ngram + 1, dtype=np.int64)
  pvals = 1. / np.arange(1, FLAGS.ngram + 1)
  pvals /= pvals.sum(keepdims=True)

  if FLAGS.favor_shorter_ngram:
    pvals = pvals[::-1]

  ngram_indexes = []
  for idx in range(len(cand_indexes)):
    ngram_index = []
    for n in ngrams:
      ngram_index.append(cand_indexes[idx:idx+n])
    ngram_indexes.append(ngram_index)

  rng.shuffle(ngram_indexes)

  masked_lms = []
  covered_indexes = set()
  for cand_index_set in ngram_indexes:
    if len(masked_lms) >= num_to_predict:
      break
    if not cand_index_set:
      continue
    # Note(mingdachen):
    # Skip current piece if they are covered in lm masking or previous ngrams.
    for index_set in cand_index_set[0]:
      for index in index_set:
        if index in covered_indexes:
          continue

    n = np.random.choice(ngrams[:len(cand_index_set)],
                         p=pvals[:len(cand_index_set)] /
                         pvals[:len(cand_index_set)].sum(keepdims=True))
    index_set = sum(cand_index_set[n - 1], [])
    n -= 1
    # Note(mingdachen):
    # Repeatedly looking for a candidate that does not exceed the
    # maximum number of predictions by trying shorter ngrams.
    while len(masked_lms) + len(index_set) > num_to_predict:
      if n == 0:
        break
      index_set = sum(cand_index_set[n - 1], [])
      n -= 1
    # If adding a whole-word mask would exceed the maximum number of
    # predictions, then just skip this candidate.
    if len(masked_lms) + len(index_set) > num_to_predict:
      continue
    is_any_index_covered = False
    for index in index_set:
      if index in covered_indexes:
        is_any_index_covered = True
        break
    if is_any_index_covered:
      continue
    for index in index_set:
      covered_indexes.add(index)

      masked_token = None
      # 80% of the time, replace with [MASK]
      if rng.random() < 0.8:
        masked_token = "[MASK]"
      else:
        # 10% of the time, keep original
        if rng.random() < 0.5:
          masked_token = tokens[index]
        # 10% of the time, replace with random word
        else:
          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

      output_tokens[index] = masked_token

      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
  assert len(masked_lms) <= num_to_predict

  rng.shuffle(ngram_indexes)

  select_indexes = set()
  if FLAGS.do_permutation:
    for cand_index_set in ngram_indexes:
      if len(select_indexes) >= num_to_predict:
        break
      if not cand_index_set:
        continue
      # Note(mingdachen):
      # Skip current piece if they are covered in lm masking or previous ngrams.
      for index_set in cand_index_set[0]:
        for index in index_set:
          if index in covered_indexes or index in select_indexes:
            continue

      n = np.random.choice(ngrams[:len(cand_index_set)],
                           p=pvals[:len(cand_index_set)] /
                           pvals[:len(cand_index_set)].sum(keepdims=True))
      index_set = sum(cand_index_set[n - 1], [])
      n -= 1

      while len(select_indexes) + len(index_set) > num_to_predict:
        if n == 0:
          break
        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
      # If adding a whole-word mask would exceed the maximum number of
      # predictions, then just skip this candidate.
      if len(select_indexes) + len(index_set) > num_to_predict:
        continue
      is_any_index_covered = False
      for index in index_set:
        if index in covered_indexes or index in select_indexes:
          is_any_index_covered = True
          break
      if is_any_index_covered:
        continue
      for index in index_set:
        select_indexes.add(index)
    assert len(select_indexes) <= num_to_predict

    select_indexes = sorted(select_indexes)
    permute_indexes = list(select_indexes)
    rng.shuffle(permute_indexes)
    orig_token = list(output_tokens)

    for src_i, tgt_i in zip(select_indexes, permute_indexes):
      output_tokens[src_i] = orig_token[tgt_i]
      masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

  masked_lms = sorted(masked_lms, key=lambda x: x.index)

  for p in masked_lms:
    masked_lm_positions.append(p.index)
    masked_lm_labels.append(p.label)
  return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)

In [54]:
create_masked_lm_predictions(tokens, masked_lm_prob,max_predictions_per_seq, vocab_words, rng)

(['hi',
  'how',
  'are',
  'you',
  'doing',
  ',',
  'i',
  '[MASK]',
  'fine',
  'and',
  'how',
  'about',
  'you',
  '.',
  'you',
  'say',
  '.',
  'i',
  'am',
  'also',
  '[MASK]'],
 [3, 7, 20],
 ['you', 'am', 'good'],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])