### Since the BERT_Cloze experiment shows that BERT predicts a missing word reasonably well, let's try using BERT for text generation directly

## 1. Loading and Initializing

In [1]:
import os

import tensorflow as tf
import tensorflow_hub as hub
# Report the TensorFlow build and which GPUs it can see.
print("Using Tensorflow version: " + tf.__version__)
visible_gpus = tf.config.list_physical_devices('GPU')
print(visible_gpus)

# Directory of the exported BERT TF-Hub module (the tokenizer vocabulary is
# read from BERT_DIR + "/assets/vocab.txt"). Set the BERT_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
BERT_DIR = os.environ.get("BERT_DIR", "/home/aufish/Downloads/bert")

Using Tensorflow version: 2.1.0
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
# Wrap the BERT module stored in BERT_DIR as a Keras layer; trainable=True
# lets its weights be fine-tuned together with the rest of the model.
bert_module = hub.KerasLayer(handle=BERT_DIR, trainable=True)

In [3]:
# tokenizer
from bert import tokenization

def create_tokenizer(vocab_file, do_lower_case=False):
    """Build a BERT FullTokenizer over the vocabulary in ``vocab_file``.

    do_lower_case is forwarded unchanged. NOTE(review): the default False
    assumes a cased BERT model; confirm it matches the checkpoint in BERT_DIR.
    """
    full_tokenizer = tokenization.FullTokenizer(
        vocab_file=vocab_file,
        do_lower_case=do_lower_case,
    )
    return full_tokenizer

# Vocabulary file bundled inside the BERT module directory.
vocab_path = BERT_DIR + "/assets/vocab.txt"
tokenizer = create_tokenizer(vocab_file=vocab_path)

def convert_sentence_to_features(sentence, tokenizer, max_seq_len=50):
    """Encode one sentence as fixed-length, single-segment BERT inputs.

    The tokens are laid out as ``[CLS] tokens... [SEP]``; the sentence tokens
    are truncated so the sequence never exceeds ``max_seq_len``, then every
    output is zero-padded up to ``max_seq_len``.

    Returns (input_ids, input_mask, segment_ids):
        input_ids: vocabulary ids, 0 for padding.
        input_mask: 1 for each real token, 0 for padding.
        segment_ids: all 0.
    """
    # Truncate before appending [SEP] so the closing token always survives.
    tokens = ['[CLS]', *tokenizer.tokenize(sentence)][:max_seq_len - 1] + ['[SEP]']

    token_ids = list(tokenizer.convert_tokens_to_ids(tokens))
    padding = [0] * (max_seq_len - len(tokens))

    input_ids = token_ids + padding
    input_mask = [1] * len(token_ids) + padding
    segment_ids = [0] * len(tokens) + padding

    return input_ids, input_mask, segment_ids

def convert_sentences_to_features(sentences, tokenizer, max_seq_len=50):
    """Encode every sentence with convert_sentence_to_features.

    Returns three parallel lists (input_ids, input_mask, segment_ids), each
    holding one padded list per sentence, in input order.
    """
    encoded = [
        convert_sentence_to_features(sentence, tokenizer, max_seq_len)
        for sentence in sentences
    ]
    all_input_ids = [ids for ids, _, _ in encoded]
    all_input_mask = [mask for _, mask, _ in encoded]
    all_segment_ids = [segs for _, _, segs in encoded]
    return all_input_ids, all_input_mask, all_segment_ids

In [None]:
# TODO: rewrite make_rand_mask so that all words after a given index are masked
import random, copy
import numpy as np
def make_rand_mask(input_ids, input_mask, vocab_size, segment_id_vals=None, rng=None):
    '''
    Mask one randomly chosen real position of a single sentence.

    input_ids: the ids of the words in the sentence (currently unused).
    input_mask: initial mask (1 if there is a word; 0 for padding); the 1s are
        assumed to form a prefix of the list.
    vocab_size: currently unused.
    segment_id_vals: returned unchanged.
    rng: optional object with a ``randint(a, b)`` method (e.g. random.Random(seed))
        for reproducible masking; defaults to the module-level ``random``.

    returns (new_input_mask, segment_id_vals, mask_word)
    new_input_mask: a copy of input_mask with position mask_word set to 0,
        meaning that the word will be masked; input_mask itself is not modified.
    segment_id_vals: the argument, unchanged.
    mask_word: the masked position, drawn uniformly from
        [0, sum(input_mask) - 1]; this range includes [CLS]/[SEP] if present.

    raises ValueError if input_mask has no 1s to choose from.
    '''
    if rng is None:
        rng = random

    total_word = sum(input_mask)
    if total_word < 1:
        raise ValueError("input_mask has no unmasked positions to choose from")

    mask_word = rng.randint(0, total_word - 1)

    # Copy so the caller's mask is left untouched.
    new_input_mask = copy.deepcopy(input_mask)
    assert new_input_mask[mask_word] == 1
    new_input_mask[mask_word] = 0

    return new_input_mask, segment_id_vals, mask_word

In [14]:
class TextGenerator(tf.keras.Model):
    """Score every vocabulary entry as the filler of one blank per example.

    BERT encodes the stacked [input_ids, input_mask, segment_ids]; the
    sequence output at each example's blank position goes through dropout
    and a dense layer that yields one logit per class.
    """

    def __init__(self, class_num, bert=bert_module, dropout=0.1):
        """
        class_num: number of output logits (the vocabulary size in this notebook).
        bert: BERT layer called with the list [input_ids, input_mask, segment_ids]
            and returning (pooled_output, sequence_output).
        dropout: dropout rate applied before the output layer.
        """
        super(TextGenerator, self).__init__()
        self.bert = bert
        self.drop = tf.keras.layers.Dropout(rate=dropout, trainable=True)

        self.dense = tf.keras.layers.Dense(
            class_num,
            activation=None,
            kernel_initializer='glorot_uniform',
            name='word_prediction',
            trainable=True)

    def call(self, inputs, mask_loc, training=None):
        """
        inputs: integer tensor stacking [input_ids, input_mask, segment_ids]
            along axis 1, i.e. shape (batch, 3, seq_len).
        mask_loc: (batch,) integer positions of the blank in each example.
        training: forwarded to the BERT layer and to Dropout; pass True while
            training so dropout is active. With None, those layers fall back
            to Keras' default learning-phase handling.

        Returns logits of shape (batch, class_num).
        """
        # Split the stacked tensor into the list of three inputs BERT expects.
        inputs = tf.unstack(tf.cast(inputs, tf.dtypes.int32), axis=1)

        _pooled, sequential = self.bert(inputs, training=training)

        # Take sequential[b, mask_loc[b], :] for every example b:
        # (batch, seq_len, hidden) -> (batch, hidden). Unlike a Python loop
        # over sequential.shape[0], this also works when the batch dimension
        # is unknown (e.g. inside tf.function).
        selected = tf.gather(sequential, tf.convert_to_tensor(mask_loc), batch_dims=1)

        x = self.drop(selected, training=training)
        return self.dense(x)

## 2. Prepare data

In [6]:
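# Preprocess sentences.txt into masked_sentences.txt: a sentence of n words
# becomes n lines, where line k keeps the first k words and replaces the
# rest with [MASK]. Only needs to be run once.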

# DATA_FILE = "./sentences.txt"
# MASKED_SENTENCE_FILE = "./masked_sentences.txt"

# data = open(DATA_FILE, "r")
# masked_data = open(MASKED_SENTENCE_FILE, "w")
# line = data.readline()
# while line != '':
#     line = line.split(" ")
#     new_line = ['[MASK]' for i in range(len(line))]
#     for i in range(len(line)):
#         new_line[i] = line[i]
#         masked_data.write(" ".join(new_line) + "\n")
        
#     line = data.readline()
        
# masked_data.close()
# data.close()

In [17]:
import numpy as np
# Sentences extracted from EMNLP, one per line.
DATA_FILE = "./sentences.txt"

# Vocabulary ids of the special tokens that parse_line searches for.
MASK_ID, SENTENCE_END_ID = tokenizer.convert_tokens_to_ids(['[MASK]', '[SEP]'])

# parse_line turns one line of text into one training example: the last
# token before [MASK] (or before [SEP]) is masked and becomes the target.
def parse_line(line):
    """Turn one line of the data file into one masked-word training example.

    The token just before the first [MASK] (or, when the line has no [MASK],
    just before [SEP]) is replaced by [MASK] and becomes the prediction target.

    line: one line of text (bytes when called through tf.numpy_function).

    Returns (features, target) as int64 numpy arrays:
        features: shape (3, 30), rows [input_ids, input_mask, segment_ids];
            input_mask is 0 at every [MASK] position as well as at padding.
        target: shape (2,), [mask_loc, label_id].
    """
    input_ids, input_mask, segment_ids = convert_sentence_to_features(line, tokenizer, max_seq_len=30)

    # Mask the last non-[MASK] token before the boundary and use it as target.
    # NOTE(review): when the boundary sits at position 1 (e.g. an empty line),
    # this masks [CLS] itself and the label is the [CLS] id.
    boundary_id = MASK_ID if MASK_ID in input_ids else SENTENCE_END_ID
    word_to_mask_loc = input_ids.index(boundary_id) - 1

    label = input_ids[word_to_mask_loc]
    input_ids[word_to_mask_loc] = MASK_ID
    input_mask = [0 if token_id == MASK_ID else m for token_id, m in zip(input_ids, input_mask)]

    # tf.numpy_function requires each result's dtype to match Tout exactly
    # (tf.int64 in create_dataset). Arrays built from plain Python ints use the
    # platform default int, which is int32 on Windows with NumPy < 2.
    features = np.array([input_ids, input_mask, segment_ids], dtype=np.int64)
    target = np.array([word_to_mask_loc, label], dtype=np.int64)
    return features, target

def create_dataset(filename=DATA_FILE, data_size=268528, batch_size=10):
    """Build a shuffled, batched tf.data pipeline of masked-word examples.

    filename: text file with one sentence per line.
    data_size: shuffle buffer size; presumably the line count of DATA_FILE so
        the buffer spans the whole file (TODO confirm).
    batch_size: examples per batch.

    Each element is (bert_input, target): parse_line's two outputs, batched.
    """
    dataset = tf.data.TextLineDataset([filename])

    # Shuffle raw lines before parsing. parse_line is deterministic and
    # one-to-one, so the example order is equally random, but the buffer holds
    # raw lines instead of parsed int64 arrays and fills without tokenizing
    # the whole file first.
    dataset = dataset.shuffle(data_size, reshuffle_each_iteration=True)

    dataset = dataset.map(lambda x: tf.numpy_function(parse_line, [x], [tf.int64, tf.int64]))

    dataset = dataset.batch(batch_size)

    # Overlap input preparation with the training step.
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)

## 3. Training

### 3.1 Training Sanity Run

In [15]:
# Sanity-run setup: a fresh model, Adam optimizer, cross-entropy on logits and
# a running-mean loss metric, fed batches of 10 examples.
vocab_size = len(tokenizer.vocab)
model = TextGenerator(class_num=vocab_size)

opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss_metric = tf.keras.metrics.Mean()

dataset = create_dataset(batch_size=10)

In [16]:
for bert_input, target in dataset.take(5):
    # Each target row is [mask_loc, label_id].
    mask, label = tf.unstack(target, axis=1)

    with tf.GradientTape() as tape:
        logits = model(bert_input, mask)
        # Cross-entropy on the masked word plus any losses the model registers.
        loss_val = loss(label, logits) + sum(model.losses)

    weights = model.trainable_weights
    grads = tape.gradient(loss_val, weights)
    opt.apply_gradients(zip(grads, weights))

    loss_metric(loss_val)
    print('mean loss = %s' % (loss_metric.result()))

mean loss = tf.Tensor(10.241837, shape=(), dtype=float32)
mean loss = tf.Tensor(10.220797, shape=(), dtype=float32)
mean loss = tf.Tensor(10.223202, shape=(), dtype=float32)
mean loss = tf.Tensor(10.196743, shape=(), dtype=float32)
mean loss = tf.Tensor(10.169739, shape=(), dtype=float32)


### 3.2 Actual Training Setup

In [18]:
# Full training setup: a fresh model, Adam optimizer, cross-entropy on logits
# and a running-mean loss metric, fed batches of 20 examples.
vocab_size = len(tokenizer.vocab)
model = TextGenerator(class_num=vocab_size)

opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss_metric = tf.keras.metrics.Mean()

dataset = create_dataset(batch_size=20)

In [None]:
import time

epochs = 10
for epoch in range(epochs):
    print('Start of epoch %d' % (epoch,))
    # Report a running mean per epoch instead of one accumulated over all epochs.
    loss_metric.reset_states()

    # Iterate over the batches of the dataset.
    for step, (bert_input, target) in enumerate(dataset):
        # Each target row is [mask_loc, label_id].
        mask, target = tf.unstack(target, axis=1)
        with tf.GradientTape() as tape:
            # NOTE(review): the model is called without training=True, so its
            # Dropout layers likely run in inference mode here — confirm.
            output = model(bert_input, mask)

            # Cross-entropy between the predicted logits and the masked word id.
            loss_val = loss(target, output)
            # Add any regularization losses registered by the model's layers.
            loss_val += sum(model.losses)

        grads = tape.gradient(loss_val, model.trainable_weights)
        opt.apply_gradients(zip(grads, model.trainable_weights))

        loss_metric(loss_val)

        if step % 1000 == 0:
            print('step %s: mean loss = %s' % (step, loss_metric.result()))

    # One checkpoint per epoch: ./text_generator_0, ./text_generator_1, ...
    model.save_weights("./text_generator_{}".format(epoch))

Start of epoch 0
step 0: mean loss = tf.Tensor(10.303389, shape=(), dtype=float32)
step 1000: mean loss = tf.Tensor(4.598281, shape=(), dtype=float32)
step 2000: mean loss = tf.Tensor(3.7107375, shape=(), dtype=float32)
step 3000: mean loss = tf.Tensor(3.3310423, shape=(), dtype=float32)
step 4000: mean loss = tf.Tensor(3.0930488, shape=(), dtype=float32)
step 5000: mean loss = tf.Tensor(2.9535217, shape=(), dtype=float32)
step 6000: mean loss = tf.Tensor(2.8466659, shape=(), dtype=float32)
step 7000: mean loss = tf.Tensor(2.7577126, shape=(), dtype=float32)
step 8000: mean loss = tf.Tensor(2.6867654, shape=(), dtype=float32)
step 9000: mean loss = tf.Tensor(2.624818, shape=(), dtype=float32)
step 10000: mean loss = tf.Tensor(2.5790024, shape=(), dtype=float32)
step 11000: mean loss = tf.Tensor(2.5331786, shape=(), dtype=float32)
step 12000: mean loss = tf.Tensor(2.4976485, shape=(), dtype=float32)
step 13000: mean loss = tf.Tensor(2.4622, shape=(), dtype=float32)
Start of epoch 1
step

## 4. Complete a Sentence

In [None]:
def complete_next_word(model, prompt, candidate_num=10, sentence_length=50):
    """Return ids of the ``candidate_num`` most likely tokens to follow ``prompt``.

    The prompt is encoded as ``[CLS] prompt-tokens`` and every later position
    (including the one [SEP] occupied) is set to [MASK] with input mask 0.
    The model's logits at the first [MASK] are then ranked.

    NOTE(review): parse_line's training examples keep an attended [SEP] after
    the masked token, whereas here [SEP] is overwritten too; consider matching
    that layout.

    model: called as model(bert_input, [blank_position]).
    prompt: text to continue.
    candidate_num: how many candidate ids to return.
    sentence_length: sequence length the prompt is truncated/padded to.

    Returns a 1-D int32 tensor of token ids, most likely first.
    """
    ids, masks, seg_ids = convert_sentence_to_features(prompt, tokenizer, max_seq_len=sentence_length)

    # masks has one 1 per real token ([CLS], prompt tokens, [SEP]), so this is
    # the position of [SEP]: the first slot after the prompt.
    blank_loc = sum(masks) - 1

    # Blank out everything from the end of the prompt onwards.
    for index in range(blank_loc, len(ids)):
        ids[index] = MASK_ID
        masks[index] = 0

    # A batch of one: shape (1, 3, sentence_length).
    bert_input = tf.constant([[ids, masks, seg_ids]])

    logits = model(bert_input, [blank_loc])[0]
    # top_k avoids fully sorting the whole vocabulary; clamp k like slicing did.
    k = min(candidate_num, int(logits.shape[-1]))
    return tf.math.top_k(logits, k=k).indices

In [None]:
def complete_sentence(model, prompt, sentence_length=50):
    """Greedily extend ``prompt`` one WordPiece token at a time.

    Each step takes the single most likely next token from complete_next_word
    and appends it. Generation stops in any of these cases:
      - [CLS] + the prompt's tokens + the blank would no longer fit in
        ``sentence_length`` positions;
      - [SEP] is predicted;
      - ``sentence_length`` steps have run. This hard cap is needed because
        re-tokenizing the growing text need not add a token every step.

    model: passed through to complete_next_word.
    prompt: text to continue.
    sentence_length: sequence length, in tokens, used for every prediction.

    Returns the extended prompt string.
    """
    for _ in range(sentence_length):
        # The blank goes right after [CLS] + prompt tokens; stop once it no
        # longer fits inside sentence_length positions.
        if len(tokenizer.tokenize(prompt)) + 2 > sentence_length:
            break

        result = complete_next_word(model, prompt, candidate_num=1, sentence_length=sentence_length)
        next_id = int(result.numpy()[0])
        if next_id == SENTENCE_END_ID:
            break

        # convert_ids_to_tokens returns a list; take its single element.
        next_token = tokenizer.convert_ids_to_tokens([next_id])[0]
        if next_token.startswith("##"):
            # WordPiece continuation piece: attach to the previous piece without a space.
            prompt += next_token[2:]
        else:
            prompt += " " + next_token

    return prompt