## 1. Loading and Initializing

### 1.1 Sentence processing functions

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
print(f"Using Tensorflow version: {tf.__version__}")
print(tf.config.list_physical_devices('GPU'))

# Local, pre-downloaded TF2 SavedModel of BERT (presumably the same model as
# the TF Hub URL in the comment below).
BERT_DIR = "/home/aufish/Downloads/bert"

# Downloading from TF Hub did not work here, so the local copy is used.
# Original attempt:
# bert_module = hub.Module("https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/1")

In [None]:
bert_module = hub.KerasLayer(handle=BERT_DIR, trainable=True)  # fine-tunable BERT layer

In [None]:
# tokenizer
from bert import tokenization

def create_tokenizer(vocab_file, do_lower_case=False):
    """Build a BERT ``FullTokenizer`` for the given vocabulary file.

    vocab_file: path to the vocabulary file shipped with the BERT model.
    do_lower_case: forwarded unchanged to ``FullTokenizer``. It defaults to
        False, which presumably matches the cased BERT model referenced above.
    """
    tokenizer_kwargs = dict(vocab_file=vocab_file, do_lower_case=do_lower_case)
    return tokenization.FullTokenizer(**tokenizer_kwargs)

tokenizer = create_tokenizer(vocab_file=f"{BERT_DIR}/assets/vocab.txt")  # vocab bundled with the SavedModel

def convert_sentence_to_features(sentence, tokenizer, max_seq_len=50):
    """Encode one sentence as fixed-length BERT input lists.

    The tokens of ``sentence`` are wrapped as ``[CLS] tokens... [SEP]``. The
    ``[CLS]`` + tokens part is cut to ``max_seq_len - 1`` entries before
    ``[SEP]`` is appended, so the separator is never truncated away. All three
    lists are then right-padded with 0 up to ``max_seq_len``.

    Returns a tuple ``(input_ids, input_mask, segment_ids)`` of lists:
    input_ids   -- token ids from ``tokenizer``, 0-padded
    input_mask  -- 1 for each real token, 0 for padding
    segment_ids -- all 0 (single-sentence input)
    """
    tokens = ['[CLS]', *tokenizer.tokenize(sentence)][:max_seq_len - 1]
    tokens.append('[SEP]')

    n_tokens = len(tokens)
    padding = [0] * (max_seq_len - n_tokens)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids += padding
    input_mask = [1] * n_tokens + padding
    segment_ids = [0] * n_tokens + padding
    return input_ids, input_mask, segment_ids

def convert_sentences_to_features(sentences, tokenizer, max_seq_len=50):
    """Apply ``convert_sentence_to_features`` to every sentence.

    Returns three parallel lists (ids, masks, segment ids). Each holds one
    per-sentence list, in the same order as ``sentences``.
    """
    per_sentence = [
        convert_sentence_to_features(sentence, tokenizer, max_seq_len)
        for sentence in sentences
    ]
    all_input_ids = [features[0] for features in per_sentence]
    all_input_mask = [features[1] for features in per_sentence]
    all_segment_ids = [features[2] for features in per_sentence]
    return all_input_ids, all_input_mask, all_segment_ids

import random, copy
import numpy as np
def make_rand_mask(input_ids, input_mask, vocab_size, segment_id_vals=None,
                   exclude_special=True):
    '''
    Randomly choose one real token of a single sentence to be masked.

    input_ids: token ids of the sentence. Not read here; kept so existing
        callers keep working.
    input_mask: initial mask, 1 for a real token and 0 for padding. The 1s
        are expected to form a prefix, as produced by
        convert_sentence_to_features. It is not modified.
    vocab_size: unused; kept so existing callers keep working.
    segment_id_vals: returned unchanged.
    exclude_special: when True (the default) and there are more than two real
        tokens, never choose index 0 or the last real index. For input built by
        convert_sentence_to_features those are [CLS] and [SEP]. They are not
        words, so they should not become prediction targets. Pass False for
        the old behaviour of choosing among all real positions.

    returns (new_input_mask, segment_id_vals, mask_word):
    new_input_mask: copy of input_mask with the chosen position set to 0
    segment_id_vals: the argument, unchanged
    mask_word: index of the chosen position

    raises ValueError if input_mask marks no real tokens, or if the chosen
    position is not 1 in input_mask (the 1s are not a contiguous prefix).
    '''
    total_word = sum(input_mask)
    if total_word < 1:
        raise ValueError("input_mask contains no real tokens to mask")

    if exclude_special and total_word > 2:
        # Skip [CLS] (index 0) and [SEP] (index total_word - 1).
        mask_word = random.randint(1, total_word - 2)
    else:
        mask_word = random.randint(0, total_word - 1)

    if input_mask[mask_word] != 1:
        raise ValueError(
            "input_mask must start with a contiguous run of 1s; "
            "position %d is %r" % (mask_word, input_mask[mask_word]))

    new_input_mask = copy.deepcopy(input_mask)
    new_input_mask[mask_word] = 0
    return new_input_mask, segment_id_vals, mask_word

### 1.2 Blank filler model

In [None]:
class WordPredictor(tf.keras.Model):
    """Score candidate tokens for a single blank (masked position) per example.

    The stacked BERT inputs run through ``bert``. The sequence output at each
    example's masked position then goes through dropout and a dense layer,
    which yields ``class_num`` logits. The output means how likely each
    candidate word is to fit the blank.
    """

    # NOTE(review): the default ``bert=bert_module`` is evaluated once, when
    # the class is defined. Every WordPredictor built with the default
    # therefore shares the SAME trainable BERT layer, including any weights
    # already updated by an earlier model instance.
    def __init__(self, class_num, bert=bert_module, dropout=0.1):
        super(WordPredictor, self).__init__()
        self.bert = bert
        self.drop = tf.keras.layers.Dropout(rate=dropout, trainable=True)

        self.dense = tf.keras.layers.Dense(
            class_num,
            activation=None,
            kernel_initializer='glorot_uniform',
            name='word_prediction',
            trainable=True)

    def call(self, inputs, mask_loc, training=None):
        """Return logits for the token at ``mask_loc`` of each example.

        inputs: [ids, mask, segment ids] stacked along axis 1, presumably
            shaped (batch, 3, seq_len) as built by create_dataset — TODO confirm.
        mask_loc: one position per example (tensor or list of length batch).
        training: forwarded to the BERT layer and to dropout. Pass True inside
            a custom training loop so dropout is active.
        """
        # Split the stacked tensor back into the list of three BERT inputs.
        inputs = tf.unstack(tf.cast(inputs, tf.dtypes.int32), axis=1)

        _pooled, sequential = self.bert(inputs, training=training)

        # Pick sequential[i, mask_loc[i]] for every example i. Unlike a Python
        # loop over sequential.shape[0], this also works when the batch
        # dimension is not statically known.
        s = tf.gather(sequential, mask_loc, batch_dims=1)
        # s: (batch, hidden), e.g. hidden=768 for BERT-base — TODO confirm.

        x = self.drop(s, training=training)
        return self.dense(x)

### 1.2.1 Sanity test

In [None]:
# Sanity check: the model can be constructed and compiled (no loss is given).
model = WordPredictor(class_num=1)
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(optimizer=opt)

## 2. Prepare data

In [None]:
# One time run, write all sentences in the json file into txt
# import json 

# DATA_FILE = "/home/aufish/Documents/ScratchGan++/scratchgan/emnlp_data/train.json"
# all_sentences = json.load(open(DATA_FILE, "r"))

# SENTENCE_FILE = "./sentences.txt"

# output_file = open(SENTENCE_FILE, "w")
# for sentence in all_sentences:
#     output_file.write(sentence['s'] + '\n')

In [None]:
import numpy as np
# One sentence per line, extracted from the EMNLP data (see the cell above).
DATA_FILE = "./sentences.txt"

# Vocabulary id written over the token chosen for prediction.
(MASK_ID,) = tokenizer.convert_tokens_to_ids(['[MASK]'])

def parse_line(line, max_seq_len=30):
    """Turn one raw text line into a masked-token training example.

    line: the sentence text. When called through tf.numpy_function it arrives
        as bytes; presumably the tokenizer decodes it — TODO confirm.
    max_seq_len: padded sequence length. It defaults to 30 as before and is now
        a parameter so it can be matched to the length used at prediction time.

    Returns (features, target):
    features: [input_ids, input_mask, segment_ids]. One position in input_ids
        is replaced by [MASK], and the same position is 0 in input_mask.
    target: [mask_position, original token id at that position]
    """
    input_ids, input_mask, segment_ids = convert_sentence_to_features(
        line, tokenizer, max_seq_len=max_seq_len)

    input_mask, segment_ids, mask_position = make_rand_mask(
        input_ids, input_mask, len(tokenizer.vocab), segment_ids)

    # Remember the true token before overwriting it with [MASK].
    label = input_ids[mask_position]
    input_ids[mask_position] = MASK_ID

    return [input_ids, input_mask, segment_ids], [mask_position, label]

def create_dataset(filename = DATA_FILE, data_size = 268528, batch_size = 10,
                   shuffle=False):
    """Build a batched tf.data pipeline of masked-token examples.

    filename: text file with one sentence per line.
    data_size: shuffle buffer size, used only when ``shuffle`` is True.
        The default presumably equals the number of lines in DATA_FILE.
    batch_size: examples per batch.
    shuffle: when True, reshuffle the lines every epoch. Defaults to False,
        which keeps the previous file-order behaviour.

    Each element is (features, target). The per-example shapes are (3, seq_len)
    and (2,), as returned by parse_line.
    """
    dataset = tf.data.TextLineDataset([filename])

    if shuffle:
        # Shuffling raw lines is cheaper than shuffling parsed tensors.
        dataset = dataset.shuffle(data_size, reshuffle_each_iteration=True)

    def _parse(line):
        # numpy_function outputs have unknown static shape. Declare what
        # parse_line returns so that downstream ops such as tf.unstack(axis=1)
        # also work in graph mode.
        features, target = tf.numpy_function(
            parse_line, [line], [tf.int64, tf.int64])
        features.set_shape([3, None])
        target.set_shape([2])
        return features, target

    dataset = dataset.map(_parse)

    return dataset.batch(batch_size)

In [None]:
tokenizer.convert_tokens_to_ids(["[MASK]"])  # shows the id stored in MASK_ID

### 2.1 Model calling sanity test

In [None]:
temp_batch_size = 3  # small batch so the printed tensors stay readable
dataset = create_dataset(batch_size=temp_batch_size)

In [None]:
# Processing data without model: run one batch through the raw BERT layer and
# pick out the hidden state at each example's masked position.
for (bert_input, label) in dataset.take(1):
    print(type(bert_input))
    print(bert_input)
    print(label)
    # Split the stacked [ids, mask, segment ids] back into BERT's input list.
    inputs = tf.unstack(tf.cast(bert_input, tf.dtypes.int32), axis=1)

    pooled, sequential = bert_module(inputs)

    print(sequential.shape)

    # label packs (mask position, target token id) for every example.
    mask, label = tf.unstack(label, axis=1)
    # One hidden vector per example. Unlike indexing with range(temp_batch_size),
    # this also works when the batch is smaller than temp_batch_size.
    r = tf.gather(sequential, mask, batch_dims=1)

    print(r)
    print(label)

In [None]:
model = WordPredictor(class_num=1)  # tiny head, only for the call sanity test

In [None]:
for bert_input, packed_label in dataset.take(1):
    # packed_label holds (mask position, target token id) for each example.
    mask, label = tf.unstack(packed_label, axis=1)
    output = model(bert_input, mask)

    print("Output")
    print(output)
    print("Weights: ")
    weight_names = [w.name for w in model.trainable_weights]
    print(weight_names)

## 3. Training

### 3.1 Training Sanity Run

In [None]:
# Output layer scores every token in the vocabulary.
vocab_size = len(tokenizer.vocab)
model = WordPredictor(class_num=vocab_size)

opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)  # model emits raw logits
loss_metric = tf.keras.metrics.Mean()

dataset = create_dataset(batch_size=10)

In [None]:
# A few optimisation steps to check that the training step runs end to end.
for (bert_input, label) in dataset.take(5):
    # label packs (mask position, target token id) for every example.
    mask, label = tf.unstack(label, axis=1)
    with tf.GradientTape() as tape:
        # training=True puts dropout (ours and BERT's) into training mode.
        output = model(bert_input, mask, training=True)

        loss_val = loss(label, output)
        loss_val += sum(model.losses)

    grads = tape.gradient(loss_val, model.trainable_weights)
    opt.apply_gradients(zip(grads, model.trainable_weights))

    loss_metric(loss_val)

    print('mean loss = %s' % (loss_metric.result()))

### 3.2 Actual training setup

In [None]:
# Fresh full-vocabulary predictor for the real training run.
vocab_size = len(tokenizer.vocab)
model = WordPredictor(class_num=vocab_size)

opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)  # model emits raw logits
loss_metric = tf.keras.metrics.Mean()

dataset = create_dataset(batch_size=20)

In [None]:
model.load_weights(filepath="./word_predictor_0")  # NOTE: the loop below re-saves epoch 0 under this same name

In [None]:
import time

epochs = 10
for epoch in range(epochs):
    print('Start of epoch %d' % (epoch,))
    # Report a per-epoch running mean instead of one accumulated over all epochs.
    loss_metric.reset_states()
    epoch_start = time.time()

    # Iterate over the batches of the dataset.
    for step, (bert_input, target) in enumerate(dataset):
        # target packs (mask position, target token id) for every example.
        mask, target = tf.unstack(target, axis=1)
        with tf.GradientTape() as tape:
            # training=True puts dropout (ours and BERT's) into training mode.
            output = model(bert_input, mask, training=True)

            # Cross-entropy of the true token at the masked position...
            loss_val = loss(target, output)
            # ...plus any regularisation losses registered on the model.
            loss_val += sum(model.losses)

        grads = tape.gradient(loss_val, model.trainable_weights)
        opt.apply_gradients(zip(grads, model.trainable_weights))

        loss_metric(loss_val)

        if step % 1000 == 0:
            print('step %s: mean loss = %s' % (step, loss_metric.result()))

    print('epoch %d done in %.1fs: mean loss = %s'
          % (epoch, time.time() - epoch_start, loss_metric.result()))
    model.save_weights("./word_predictor_{}".format(epoch))

## 4. Prediction

In [None]:
def blank_word_predict(model, sentence, blank_loc, candidate_num=10, max_seq_len=50):
    """Return the ids of the ``candidate_num`` best fillers for a blank.

    model: a WordPredictor over the full vocabulary.
    sentence: raw text containing a placeholder word at the blank.
    blank_loc: index of the blank in the BERT token sequence, where
        position 0 is ``[CLS]``. If every word before the blank is a single
        WordPiece, this is the 0-based word index plus 1. For example, in
        "the blank has caused panic" the placeholder "blank" is at 2.
    candidate_num: number of ids to return, best first.
    max_seq_len: padded length used to encode ``sentence``. The default of 50
        is unchanged; parse_line uses 30 for training data.

    Returns a 1-D tensor of vocabulary ids; map them with
    ``tokenizer.convert_ids_to_tokens``.

    Raises ValueError if ``blank_loc`` is not a real token strictly between
    ``[CLS]`` and ``[SEP]``.
    """
    ids, masks, seg_ids = convert_sentence_to_features(sentence, tokenizer, max_seq_len)

    # Real tokens occupy [0, n_real); 0 is [CLS] and n_real - 1 is [SEP].
    n_real = sum(masks)
    if not 0 < blank_loc < n_real - 1:
        raise ValueError(
            "blank_loc must be in [1, %d] for this sentence, got %r"
            % (n_real - 2, blank_loc))

    # Mirror the training-time masking in parse_line/make_rand_mask: zero the
    # position in the input mask and replace its id with [MASK].
    masks[blank_loc] = 0
    ids[blank_loc] = MASK_ID

    # Batch of one: shape (1, 3, max_seq_len).
    bert_input = tf.constant([[ids, masks, seg_ids]])

    output = model(bert_input, [blank_loc], training=False)
    return tf.argsort(output, direction='DESCENDING')[0, :candidate_num]

In [None]:
vocab_size = len(tokenizer.vocab)
model = WordPredictor(class_num=vocab_size)

# Checkpoint written by the training loop after epoch index 8.
model.load_weights(filepath="./word_predictor_8")

In [None]:
sentence = "the blank has caused panic around the world"

# With [CLS] at index 0, the placeholder "blank" is token 2.
result = blank_word_predict(model, sentence, blank_loc=2, candidate_num=5).numpy()

for token_id in result:
    print(tokenizer.convert_ids_to_tokens([token_id]))

In [None]:
sentence = "Mr . president signed the blank to fight pandemic"

# With [CLS] at index 0, the placeholder "blank" is token 6.
result = blank_word_predict(model, sentence, blank_loc=6, candidate_num=5).numpy()

for token_id in result:
    print(tokenizer.convert_ids_to_tokens([token_id]))

In [None]:
sentence = "Mr . president blank the act to fight pandemic"

# With [CLS] at index 0, the placeholder "blank" is token 4.
result = blank_word_predict(model, sentence, blank_loc=4, candidate_num=5).numpy()

for token_id in result:
    print(tokenizer.convert_ids_to_tokens([token_id]))

In [None]:
sentence = "i love the blank"

# With [CLS] at index 0, the placeholder "blank" is token 4.
result = blank_word_predict(model, sentence, blank_loc=4, candidate_num=5).numpy()

for token_id in result:
    print(tokenizer.convert_ids_to_tokens([token_id]))