In [1]:
!pip install tensorflow_hub
!pip install bert-for-tf2
!pip install sentencepiece

Collecting tensorflow_hub
[?25l  Downloading https://files.pythonhosted.org/packages/00/0e/a91780d07592b1abf9c91344ce459472cc19db3b67fdf3a61dca6ebb2f5c/tensorflow_hub-0.7.0-py2.py3-none-any.whl (89kB)
[K     |████████████████████████████████| 92kB 1.6MB/s eta 0:00:01
Installing collected packages: tensorflow-hub
Successfully installed tensorflow-hub-0.7.0
You should consider upgrading via the 'pip install --upgrade pip' command.[0m
Collecting bert-for-tf2
[?25l  Downloading https://files.pythonhosted.org/packages/60/b4/1a3da73498960866ad0510ead86b133569ff012bf1c77d82ce95203779fc/bert-for-tf2-0.13.2.tar.gz (40kB)
[K     |████████████████████████████████| 40kB 1.5MB/s eta 0:00:01
[?25hCollecting py-params>=0.7.3 (from bert-for-tf2)
  Downloading https://files.pythonhosted.org/packages/ec/17/71c5f3c0ab511de96059358bcc5e00891a804cd4049021e5fa80540f201a/py-params-0.8.2.tar.gz
Collecting params-flow>=0.7.1 (from bert-for-tf2)
  Downloading https://files.pythonhosted.org/packages/0d/12/

In [2]:
import tensorflow as tf
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)

In [3]:
with open('data/counts_1grams.txt') as fopen:
    f = fopen.read().split('\n')[:-1]
    
words = {}
for l in f:
    w, c = l.split('\t')
    c = int(c)
    words[w] = c + words.get(w, 0)

In [4]:
import re
from collections import Counter

class SpellCorrector:
    """
    The SpellCorrector extends the functionality of the Peter Norvig's
    spell-corrector in http://norvig.com/spell-correct.html
    """
    REGEX_TOKEN = re.compile(r'\b[a-z]{2,}\b')

    def __init__(self, words):
        """
        :param corpus: the statistics from which corpus to use for the spell correction.
        """
        super().__init__()
        self.WORDS = words
        self.N = sum(self.WORDS.values())
              
    @staticmethod
    def tokens(text):
        return REGEX_TOKEN.findall(text.lower())

    def P(self, word):
        """
        Probability of `word`.
        """
        return self.WORDS[word] / self.N

    def most_probable(self, words):
        _known = self.known(words)
        if _known:
            return max(_known, key=self.P)
        else:
            return []

    @staticmethod
    def edit_step(word):
        """
        All edits that are one edit away from `word`.
        """
        letters = 'abcdefghijklmnopqrstuvwxyz'
        splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
        deletes = [L + R[1:] for L, R in splits if R]
        transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
        replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
        inserts = [L + c + R for L, R in splits for c in letters]
        return set(deletes + transposes + replaces + inserts)

    def edits2(self, word):
        """
        All edits that are two edits away from `word`.
        """
        return (e2 for e1 in self.edit_step(word)
                for e2 in self.edit_step(e1))

    def known(self, words):
        """
        The subset of `words` that appear in the dictionary of WORDS.
        """
        return set(w for w in words if w in self.WORDS)

    def edit_candidates(self, word, assume_wrong=False, fast=True):
        """
        Generate possible spelling corrections for word.
        """

        if fast:
            ttt = self.known(self.edit_step(word)) or {word}
        else:
            ttt = self.known(self.edit_step(word)) or self.known(self.edits2(word)) or {word}
        
        ttt = self.known([word]) | ttt
        return list(ttt)

In [5]:
corrector = SpellCorrector(words)

In [6]:
#possible_states = corrector.edit_candidates('eting')
possible_states = corrector.edit_candidates('gife')
possible_states

['wife', 'rife', 'life', 'gift', 'give', 'gibe']

In [7]:
#text = 'scientist suggests eting burger can lead to obesity'
text = 'gife me something to eat'
#text_mask = text.replace('eting', '**mask**')
text_mask = text.replace('gife', '**mask**')
text_mask

'**mask** me something to eat'

In [9]:
import tensorflow_hub as hub
import bert
FullTokenizer = bert.bert_tokenization.FullTokenizer
from tensorflow.keras.models import Model

In [10]:
max_seq_length = 128  # Your choice here.
input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                       name="input_word_ids")
input_mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                   name="input_mask")
segment_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                    name="segment_ids")
bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1",
                            trainable=True)
pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])

In [11]:
model = Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=[pooled_output, sequence_output])

In [12]:
# See BERT paper: https://arxiv.org/pdf/1810.04805.pdf
# And BERT implementation convert_single_example() at https://github.com/google-research/bert/blob/master/run_classifier.py

def get_masks(tokens, max_seq_length):
    """Mask for padding"""
    if len(tokens)>max_seq_length:
        raise IndexError("Token length more than max seq length!")
    return [1]*len(tokens) + [0] * (max_seq_length - len(tokens))


def get_segments(tokens, max_seq_length):
    """Segments: 0 for the first sequence, 1 for the second"""
    if len(tokens)>max_seq_length:
        raise IndexError("Token length more than max seq length!")
    segments = []
    current_segment_id = 0
    for token in tokens:
        segments.append(current_segment_id)
        if token == "[SEP]":
            current_segment_id = 1
    return segments + [0] * (max_seq_length - len(tokens))


def get_ids(tokens, tokenizer, max_seq_length):
    """Token ids from Tokenizer vocab"""
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = token_ids + [0] * (max_seq_length-len(token_ids))
    return input_ids

In [13]:
vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
tokenizer = FullTokenizer(vocab_file, do_lower_case)

In [14]:
replaced_masks = [text_mask.replace('**mask**', state) for state in possible_states]
replaced_masks

['wife me something to eat',
 'rife me something to eat',
 'life me something to eat',
 'gift me something to eat',
 'give me something to eat',
 'gibe me something to eat']

In [19]:
for sent in replaced_masks:
    candidat_inputs = list()
    tokens = tokenizer.tokenize(sent)
    stokens = ["[CLS]"] + tokens + ["[SEP]"]
    input_ids = get_ids(stokens, tokenizer, max_seq_length)
    input_masks = get_masks(stokens, max_seq_length)
    input_segments = get_segments(stokens, max_seq_length)
    candidat_inputs.append([[input_ids], [input_masks], [input_segments]])
    
    

In [22]:
for inp in candidat_inputs:
    pool_embs, all_embs = model.predict([[input_ids],[input_masks],[input_segments]])
    print(all_embs)

[[[-0.01106459  0.3070191  -0.00853465 ... -0.3068895   0.00639752
    0.27061567]
  [ 0.18352532 -0.5772168   0.7146587  ... -0.11176492  0.6690078
   -0.02929941]
  [-0.12705982 -0.5298172   1.2385045  ...  0.22001466 -0.09279799
    0.24640633]
  ...
  [ 0.27750766  0.28496078  0.53381324 ...  0.21475478 -0.30043748
    0.00225102]
  [ 0.26203728  0.22347955  0.64350784 ...  0.17977902 -0.2661268
   -0.08173329]
  [ 0.28714928  0.14969665  0.68946457 ...  0.14365995 -0.27989298
   -0.14007063]]]
