In [None]:
import tensorflow as tf
import pandas as pd
import tensorflow_hub as hub
import os
import re
import numpy as np
from bert.tokenization import FullTokenizer
from tqdm import tqdm_notebook
from tensorflow.keras import backend as K
import prepare_data
from prepare_data import tokenizer
import read
from keras.utils import to_categorical
# Create the TF1-style session; initialize_vars() later runs the initializers
# in it and registers it with Keras via K.set_session().
sess = tf.Session()

# Legacy inline BERT parameters, kept for reference only. The notebook reads
# prepare_data.BERT_MODEL_HUB and prepare_data.MAX_SEQ_LENGTH instead.
# bert_path = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"
# max_seq_length = 256

# Data

First, we load the data: examples are read from `gw_extractions.pickle` and tokenized, and the first 100 are held out for validation.

In [None]:
# Switches forwarded to prepare_data.tokenize_if_small_enough() in the next cell.
sentence = True        # passed positionally as the second argument
no_context = False     # passed positionally as the third argument
neeg_dataset = False   # passed as the is_neeg keyword

In [None]:
# Read the raw examples and tokenize them with the flags chosen above.
train_dataset = read.read_data_iterator('gw_extractions.pickle')
features = list(
    prepare_data.tokenize_if_small_enough(
        train_dataset,
        sentence,
        no_context,
        is_neeg=neeg_dataset,
    )
)

# Hold out the first 100 tokenized examples for validation; train on the rest.
val_features, train_features = features[:100], features[100:]

In [None]:
def _features_to_arrays(features, num_classes=5):
    """Stack a list of tokenized features into the arrays fed to the model.

    Args:
        features: iterable of objects exposing ``input_ids``, ``input_mask``,
            ``segment_ids`` and ``label_id`` attributes.
        num_classes: width of the one-hot label vectors.

    Returns:
        Tuple ``(input_ids, input_masks, segment_ids, labels)`` of numpy arrays,
        each with one leading entry per feature.
    """
    input_ids = np.array([f.input_ids for f in features])
    input_masks = np.array([f.input_mask for f in features])
    segment_ids = np.array([f.segment_ids for f in features])
    # label_id is shifted down by one before one-hot encoding
    # (presumably labels are 1-based -- confirm against prepare_data).
    labels = np.array(
        [to_categorical(f.label_id - 1, num_classes=num_classes) for f in features]
    )
    return input_ids, input_masks, segment_ids, labels


# Same conversion for both splits, so train and validation cannot drift apart.
train_input_ids, train_input_masks, train_segment_ids, train_labels = (
    _features_to_arrays(train_features)
)
val_input_ids, val_input_masks, val_segment_ids, val_labels = (
    _features_to_arrays(val_features)
)

# Tokenize

Next, tokenize our text to create `input_ids`, `input_masks`, and `segment_ids`

In [None]:
class BertLayer(tf.layers.Layer):
    """Layer wrapping a TF-Hub BERT module and returning its ``pooled_output``.

    Args:
        n_fine_tune_layers: how many of the module's trailing variables
            (after dropping those whose name contains ``/cls/``) are
            registered as trainable. ``0`` (or a negative value) registers
            none, i.e. the whole module is frozen.
        **kwargs: forwarded to the base layer constructor.
    """

    def __init__(self, n_fine_tune_layers=10, **kwargs):
        self.n_fine_tune_layers = n_fine_tune_layers
        self.trainable = True
        # Reported width of the pooled output; assumes a hidden size of 768
        # for prepare_data.BERT_MODEL_HUB -- TODO confirm.
        self.output_size = 768
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Instantiate the hub module and sort its variables into trainable / frozen."""
        self.bert = hub.Module(
            prepare_data.BERT_MODEL_HUB,
            trainable=self.trainable,
            name="{}_module".format(self.name)
        )

        # Drop the variables under "/cls/"; they are never fine-tuned here.
        candidate_vars = [var for var in self.bert.variables if "/cls/" not in var.name]

        # Select how many trailing variables to fine tune. A bare
        # ``candidate_vars[-n:]`` is wrong for n == 0 because ``[-0:]`` is the
        # *entire* list, which would silently unfreeze the whole module.
        if self.n_fine_tune_layers > 0:
            trainable_vars = candidate_vars[-self.n_fine_tune_layers:]
        else:
            trainable_vars = []

        # Add to trainable weights
        for var in trainable_vars:
            self._trainable_weights.append(var)

        # Everything else (including the "/cls/" variables) is tracked as frozen.
        for var in self.bert.variables:
            if var not in self._trainable_weights:
                self._non_trainable_weights.append(var)

        super(BertLayer, self).build(input_shape)

    def call(self, inputs):
        """Run the module on ``(input_ids, input_mask, segment_ids)``.

        Each of the three inputs is cast to int32 before being passed to the
        module's "tokens" signature. Returns the "pooled_output" entry.
        """
        inputs = [K.cast(x, dtype="int32") for x in inputs]
        input_ids, input_mask, segment_ids = inputs
        bert_inputs = dict(
            input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
        )
        result = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)[
            "pooled_output"
        ]
        return result

    def compute_output_shape(self, input_shape):
        """Return ``(input_shape[0], self.output_size)``."""
        return (input_shape[0], self.output_size)

In [None]:
class MultiBertLayer(BertLayer):
    """BertLayer variant that encodes several input groups with one shared module."""

    def call(self, inputs):
        """Apply ``BertLayer.call`` to each group obtained by transposing ``inputs``.

        ``inputs`` is unpacked with ``zip(*inputs)``, so the i-th element of
        every entry is handed to ``BertLayer.call`` together. Returns a list
        holding one result per group, in order.
        """
        pooled_outputs = []
        for group in zip(*inputs):
            pooled_outputs.append(super().call(group))
        return pooled_outputs

    def compute_output_shape(self, input_shape):
        """Return a 3-tuple built from the first two entries of ``input_shape``."""
        # NOTE(review): call() returns a list of tensors while this describes a
        # single 3-D tensor -- confirm which one Keras actually relies on here.
        leading, second = input_shape[0], input_shape[1]
        return (leading, second, self.output_size)

In [None]:
# Build model
def build_model(max_seq_length, num_labels, append_vector_len=0):
    """Build and compile the classifier that scores ``num_labels`` candidates.

    Each of the ``num_labels`` candidates is encoded by one shared, frozen
    ``MultiBertLayer``; the pooled vectors (optionally extended with an extra
    per-candidate feature vector) are concatenated and fed through a
    256-unit ReLU layer and a softmax over the candidates.

    Args:
        max_seq_length: token length of every candidate sequence.
        num_labels: number of candidates per example (and softmax width).
        append_vector_len: if non-zero, adds an ``input_append`` input of shape
            ``(num_labels, append_vector_len)`` whose i-th row is concatenated
            to the i-th candidate's pooled BERT vector.

    Returns:
        The compiled ``tf.keras`` model (its summary is printed as a side effect).
    """
    in_ids = tf.keras.layers.Input(shape=(num_labels, max_seq_length), name="input_ids")
    in_masks = tf.keras.layers.Input(shape=(num_labels, max_seq_length), name="input_masks")
    in_segments = tf.keras.layers.Input(shape=(num_labels, max_seq_length), name="segment_ids")

    inputs = [in_ids, in_masks, in_segments]
    if append_vector_len:
        in_append = tf.keras.layers.Input(shape=(num_labels, append_vector_len), name="input_append")
        inputs.append(in_append)

    def split_candidates(tensor):
        """Slice a (batch, num_labels, width) tensor into num_labels (batch, width) tensors."""
        # ``i=i`` freezes the index per Lambda. A plain closure would see only
        # the final value of ``i`` if the function is invoked again later
        # (e.g. after the layer is serialised and restored).
        return [
            tf.keras.layers.Lambda(lambda x, i=i: x[:, i, :])(tensor)
            for i in range(num_labels)
        ]

    split_in_ids = split_candidates(in_ids)
    split_in_masks = split_candidates(in_masks)
    split_in_segments = split_candidates(in_segments)

    if append_vector_len:
        split_in_append = split_candidates(in_append)

    # n_fine_tune_layers=0: BERT is used as a fixed feature extractor.
    bert_outputs = MultiBertLayer(n_fine_tune_layers=0)([split_in_ids, split_in_masks, split_in_segments])

    if append_vector_len:
        # Pair each candidate's pooled vector with *its own* slice of the extra
        # features (the unsplit ``in_append`` tensor cannot be zipped per candidate).
        augmented_outputs = [
            tf.keras.layers.Concatenate(axis=1)([bo, ia])
            for (bo, ia) in zip(bert_outputs, split_in_append)
        ]
    else:
        augmented_outputs = bert_outputs
    concat_output = tf.keras.layers.Concatenate(axis=1)(augmented_outputs)
    dense = tf.keras.layers.Dense(256, activation='relu')(concat_output)
    pred = tf.keras.layers.Dense(num_labels, activation='softmax')(dense)

    model = tf.keras.models.Model(inputs=inputs, outputs=pred)
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.summary()

    return model

def initialize_vars(sess):
    """Run the local, global and table initializers in ``sess``, then hand it to Keras.

    Args:
        sess: session in which the initializer ops are run and which is
            registered through ``K.set_session``.
    """
    initializer_factories = (
        tf.local_variables_initializer,
        tf.global_variables_initializer,
        tf.tables_initializer,
    )
    # Build and run each initializer op in turn, preserving the original order.
    for make_initializer in initializer_factories:
        sess.run(make_initializer())
    K.set_session(sess)


In [None]:
# Five candidates per example, each MAX_SEQ_LENGTH tokens long.
model = build_model(
    max_seq_length=prepare_data.MAX_SEQ_LENGTH,
    num_labels=5,
)

# Instantiate variables


In [None]:
# Run the variable/table initializers and register the session with Keras,
# then train for a single epoch with one example per batch.
initialize_vars(sess)

model.fit(
    x=[train_input_ids, train_input_masks, train_segment_ids],
    y=train_labels,
    batch_size=1,
    epochs=1,
    validation_data=(
        [val_input_ids, val_input_masks, val_segment_ids],
        val_labels,
    ),
)

In [None]:
# Round-trip check: predictions must survive saving and reloading the weights.
# The validation arrays are used because no test split is built in this notebook.
model.save('BertModel.h5')
check_inputs = [val_input_ids[0:100],
                val_input_masks[0:100],
                val_segment_ids[0:100]]
pre_save_preds = model.predict(check_inputs)  # predictions before we clear and reload model

# Clear and load model (same arguments as the original build_model call)
model = None
model = build_model(prepare_data.MAX_SEQ_LENGTH, num_labels=5)
initialize_vars(sess)
model.load_weights('BertModel.h5')

post_save_preds = model.predict(check_inputs)  # predictions after we clear and reload model

# Are they the same? The builtin all() cannot reduce a 2-D boolean array (it
# raises "truth value of an array ... is ambiguous"), so compare with numpy.
# Swap in np.allclose if tiny floating-point drift should be tolerated.
np.array_equal(pre_save_preds, post_save_preds)