In [1]:
import os

# Hide every GPU from TensorFlow so the whole notebook runs on CPU.
os.environ.update(CUDA_VISIBLE_DEVICES = '')

In [2]:
from albert import modeling
from albert import optimization
from albert import tokenization
import tensorflow as tf
import numpy as np




In [3]:
# Cased ALBERT-base tokenizer backed by a SentencePiece model plus its vocab file.
tokenizer = tokenization.FullTokenizer(
    spm_model_file = 'albert-base-2020-04-10/sp10m.cased.v10.model',
    vocab_file = 'albert-base-2020-04-10/sp10m.cased.v10.vocab',
    do_lower_case = False,
)


INFO:tensorflow:loading sentence piece model


In [4]:
# Model hyper-parameters (hidden size, initializer range, dropout, ...) for ALBERT-base.
bert_config = modeling.AlbertConfig.from_json_file(
    'albert-base-2020-04-10/config.json'
)
bert_config




<albert.modeling.AlbertConfig at 0x7ff2f3dddbe0>

In [5]:
import pickle

# Pre-computed test features and examples, stored as a 2-tuple.
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
with open('albert-squad-test.pkl', 'rb') as pickle_file:
    test_features, test_examples = pickle.load(pickle_file)

In [6]:
# Sequence-length settings; the answer head below hard-codes max_seq_length in its
# one_hot/tile/reshape calls, so fed inputs must be padded to exactly this length.
max_seq_length, doc_stride, max_query_length = 384, 128, 64

In [7]:
# Fine-tuning schedule values. NOTE(review): the visible inference cells do not use
# the step counts or n_best_size; they appear to be kept from the training notebook.
epoch, batch_size = 5, 22
warmup_proportion, n_best_size = 0.1, 20
num_train_steps = int(len(test_features) / batch_size * epoch)
num_warmup_steps = int(num_train_steps * warmup_proportion)

In [8]:
from tensorflow.contrib import layers as contrib_layers

class Model:
    """Build the ALBERT encoder graph on top of four int32 input placeholders.

    Reads the module-level ``bert_config`` for the encoder configuration.

    Args:
        is_training: forwarded to ``modeling.AlbertModel``.

    Attributes:
        X: token-id placeholder, shape [None, None]
            (presumably (batch, seq) -- TODO confirm).
        segment_ids: token-type placeholder, same shape as ``X``.
        input_masks: attention-mask placeholder, same shape as ``X``.
        p_mask: invalid-position mask; not consumed here, it is used by the
            answer head built in a later cell.
        output: the encoder's ``get_sequence_output()`` tensor.

    NOTE(review): the placeholders are unnamed, so TensorFlow auto-names them
    'Placeholder', 'Placeholder_1', 'Placeholder_2', 'Placeholder_3' in creation
    order. The freeze/quantize cells refer to them by exactly those names, so
    keep the creation order X, segment_ids, input_masks, p_mask.
    """
    def __init__(self, is_training = True):
        self.X = tf.placeholder(tf.int32, [None, None])
        self.segment_ids = tf.placeholder(tf.int32, [None, None])
        self.input_masks = tf.placeholder(tf.int32, [None, None])
        self.p_mask = tf.placeholder(tf.int32, [None, None])
        
        model = modeling.AlbertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=self.X,
            input_mask=self.input_masks,
            token_type_ids=self.segment_ids,
            use_one_hot_embeddings=False)
        
        final_hidden = model.get_sequence_output()
        self.output = final_hidden
        # Not used from Python: this only creates a named op ('logits_vectorize')
        # so the encoder output can be fetched by name from the exported graph.
        vectorize = tf.identity(final_hidden, name = 'logits_vectorize')

In [9]:
# Head settings: how many start / end candidates the beam search keeps.
learning_rate, start_n_top, end_n_top = 2e-5, 5, 5
# Inference mode: selects the beam-search branch of the answer head.
is_training = False

# Start from an empty default graph so re-running this cell does not duplicate ops.
tf.reset_default_graph()
model = Model(is_training = is_training)





Instructions for updating:
Use keras.layers.Dense instead.
Instructions for updating:
Please use `layer.__call__` method instead.


In [10]:
# Collects the tensors exposed by the answer head.
return_dict = {}

# Encoder sequence output and its dynamic batch size (used for the CLS selector).
output = model.output
bsz = tf.shape(output)[0]
# Swap the first two axes; the einsum subscripts below ('lbh') expect
# sequence-major layout.
output = tf.transpose(output, perm = [1, 0, 2])

# Float copy of the invalid-position mask (query tokens and special symbols such as
# PAD, SEP, CLS); a value of 1 marks a position that must not be predicted.
p_mask = tf.cast(model.p_mask, tf.float32)

# logit of the start position
# NOTE(review): variable/layer names created in this scope must match the checkpoint
# restored later with tf.train.Saver, so do not rename or reorder the layers.
with tf.variable_scope('start_logits'):
    # One scalar score per position; the dense layer is unnamed, so it uses the
    # default layer name inside this scope.
    start_logits = tf.layers.dense(
        output,
        1,
        kernel_initializer = modeling.create_initializer(
            bert_config.initializer_range
        ),
    )
    # Drop the size-1 axis and swap back to batch-major: (seq, batch) -> (batch, seq).
    start_logits = tf.transpose(tf.squeeze(start_logits, -1), [1, 0])
    # Positions with p_mask == 1 get -1e30, i.e. ~0 probability after the softmax.
    start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask
    start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)

# logit of the end position
# NOTE(review): both branches hard-code max_seq_length (one_hot depth, tile count,
# reshape), so every fed sequence must be padded to exactly max_seq_length tokens.
# Layer names ('dense_0', 'dense_1') must match the restored checkpoint.
with tf.variable_scope('end_logits'):
    if is_training:
        # during training, compute the end logits based on the
        # ground truth of the start position
        # NOTE(review): Model defines no `start_positions` attribute, so this branch
        # raises AttributeError unless a start_positions placeholder is added to Model.
        start_positions = tf.reshape(model.start_positions, [-1])
        start_index = tf.one_hot(
            start_positions,
            depth = max_seq_length,
            axis = -1,
            dtype = tf.float32,
        )
        # Pick the hidden vector at each example's start position:
        # (seq, batch, hidden) x (batch, seq) -> (batch, hidden).
        start_features = tf.einsum('lbh,bl->bh', output, start_index)
        # Repeat that start vector for every sequence position.
        start_features = tf.tile(
            start_features[None], [max_seq_length, 1, 1]
        )
        # Score each position conditioned on the start:
        # concat -> dense(tanh) -> layer norm -> dense(1).
        end_logits = tf.layers.dense(
            tf.concat([output, start_features], axis = -1),
            bert_config.hidden_size,
            kernel_initializer = modeling.create_initializer(
                bert_config.initializer_range
            ),
            activation = tf.tanh,
            name = 'dense_0',
        )
        end_logits = contrib_layers.layer_norm(
            end_logits, begin_norm_axis = -1
        )

        end_logits = tf.layers.dense(
            end_logits,
            1,
            kernel_initializer = modeling.create_initializer(
                bert_config.initializer_range
            ),
            name = 'dense_1',
        )
        # (seq, batch, 1) -> (batch, seq), then mask invalid positions as for start.
        end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0])
        end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask
        end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
    else:
        # during inference, compute the end logits based on beam search

        # Keep the start_n_top best start positions per example: (batch, k).
        start_top_log_probs, start_top_index = tf.nn.top_k(
            start_log_probs, k = start_n_top
        )
        # (batch, k) indices -> (batch, k, seq) one-hot selectors.
        start_index = tf.one_hot(
            start_top_index,
            depth = max_seq_length,
            axis = -1,
            dtype = tf.float32,
        )
        # Hidden vector of each candidate start: (batch, k, hidden).
        start_features = tf.einsum('lbh,bkl->bkh', output, start_index)
        # Replicate the sequence output for each candidate: (seq, batch, k, hidden).
        end_input = tf.tile(output[:, :, None], [1, 1, start_n_top, 1])
        # Replicate each candidate start vector over positions: (seq, batch, k, hidden).
        start_features = tf.tile(
            start_features[None], [max_seq_length, 1, 1, 1]
        )
        end_input = tf.concat([end_input, start_features], axis = -1)
        # Same dense(tanh) -> layer norm -> dense(1) stack as the training branch,
        # sharing its layer names so both map onto the same checkpoint variables.
        end_logits = tf.layers.dense(
            end_input,
            bert_config.hidden_size,
            kernel_initializer = modeling.create_initializer(
                bert_config.initializer_range
            ),
            activation = tf.tanh,
            name = 'dense_0',
        )
        end_logits = contrib_layers.layer_norm(
            end_logits, begin_norm_axis = -1
        )
        end_logits = tf.layers.dense(
            end_logits,
            1,
            kernel_initializer = modeling.create_initializer(
                bert_config.initializer_range
            ),
            name = 'dense_1',
        )
        # (seq, batch, k, 1) -> (seq, batch, k) -> (batch, k, seq).
        end_logits = tf.reshape(
            end_logits, [max_seq_length, -1, start_n_top]
        )
        end_logits = tf.transpose(end_logits, [1, 2, 0])
        # Broadcast the (batch, seq) mask over the k start candidates.
        end_logits_masked = (
            end_logits * (1 - p_mask[:, None]) - 1e30 * p_mask[:, None]
        )
        end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
        # Best end_n_top ends for every start candidate: (batch, k, end_n_top).
        end_top_log_probs, end_top_index = tf.nn.top_k(
            end_log_probs, k = end_n_top
        )
        # Flatten row-major to (batch, start_n_top * end_n_top); entry
        # s * end_n_top + e belongs to start candidate s and end candidate e.
        end_top_log_probs = tf.reshape(
            end_top_log_probs, [-1, start_n_top * end_n_top]
        )
        end_top_index = tf.reshape(
            end_top_index, [-1, start_n_top * end_n_top]
        )
        
# Expose only the tensors produced by the branch that was built above.
if not is_training:
    return_dict.update(
        start_top_log_probs = start_top_log_probs,
        start_top_index = start_top_index,
        end_top_log_probs = end_top_log_probs,
        end_top_index = end_top_index,
    )
else:
    return_dict.update(
        start_log_probs = start_log_probs,
        end_log_probs = end_log_probs,
    )

# an additional layer to predict answerability
# NOTE(review): layer names here must match the restored checkpoint; keep them as-is.
with tf.variable_scope('answer_class'):
    # get the representation of CLS
    # One-hot on position 0 for every example in the batch: (batch, seq).
    cls_index = tf.one_hot(
        tf.zeros([bsz], dtype = tf.int32),
        max_seq_length,
        axis = -1,
        dtype = tf.float32,
    )
    # Hidden vector at position 0: (batch, hidden).
    cls_feature = tf.einsum('lbh,bl->bh', output, cls_index)

    # get the representation of START
    # Probability-weighted average of hidden vectors under the masked start
    # distribution: (batch, hidden).
    start_p = tf.nn.softmax(
        start_logits_masked, axis = -1, name = 'softmax_start'
    )
    start_feature = tf.einsum('lbh,bl->bh', output, start_p)

    # note(zhiliny): no dependency on end_feature so that we can obtain
    # one single `cls_logits` for each sample
    ans_feature = tf.concat([start_feature, cls_feature], -1)
    ans_feature = tf.layers.dense(
        ans_feature,
        bert_config.hidden_size,
        activation = tf.tanh,
        kernel_initializer = modeling.create_initializer(
            bert_config.initializer_range
        ),
        name = 'dense_0',
    )
    # Dropout is only active when is_training is True.
    ans_feature = tf.layers.dropout(
        ans_feature, bert_config.hidden_dropout_prob, training = is_training
    )
    cls_logits = tf.layers.dense(
        ans_feature,
        1,
        kernel_initializer = modeling.create_initializer(
            bert_config.initializer_range
        ),
        name = 'dense_1',
        use_bias = False,
    )
    # One answerability logit per example: (batch,).
    cls_logits = tf.squeeze(cls_logits, -1)
    
return_dict['cls_logits'] = cls_logits

Instructions for updating:
Use keras.layers.dropout instead.


In [11]:
# Initialise every variable first, then overwrite the trainable ones with the
# weights stored in the 'albert-base-squad' checkpoint.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(var_list = tf.trainable_variables())
saver.restore(sess = sess, save_path = 'albert-base-squad/model.ckpt')

INFO:tensorflow:Restoring parameters from albert-base-squad/model.ckpt


In [12]:
# Re-expose the prediction tensors under stable top-level op names so they can be
# used as output nodes when the graph is frozen and quantized.
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = [
    tf.identity(tensor, name = op_name)
    for tensor, op_name in (
        (start_top_log_probs, 'start_top_log_probs'),
        (start_top_index, 'start_top_index'),
        (end_top_log_probs, 'end_top_log_probs'),
        (end_top_index, 'end_top_index'),
        (cls_logits, 'cls_logits'),
    )
]

In [13]:
import bert_utils as squad_utils

In [14]:
# Smoke-test the restored graph on the first couple of test features.
i = 0
# NOTE(review): this overrides the batch_size of 22 set in the hyper-parameter cell.
batch_size = 2
batch = test_features[i:i + batch_size]

# Per-field lists of the batch, one list per feature attribute.
batch_ids, batch_masks, batch_segment = (
    [getattr(feature, field) for feature in batch]
    for field in ('input_ids', 'input_mask', 'segment_ids')
)
batch_start, batch_end, is_impossible = (
    [getattr(feature, field) for feature in batch]
    for field in ('start_position', 'end_position', 'is_impossible')
)
# NOTE(review): rebinds the name `p_mask` (previously the float mask tensor) to a
# plain list; the graph itself keeps using model.p_mask.
p_mask = [feature.p_mask for feature in batch]

o = sess.run(
    [
        start_top_log_probs,
        start_top_index,
        end_top_log_probs,
        end_top_index,
        cls_logits,
    ],
    feed_dict = {
        model.X: batch_ids,
        model.segment_ids: batch_segment,
        model.input_masks: batch_masks,
        model.p_mask: p_mask,
    },
)

In [15]:
# Saver.save refuses to write into a missing parent directory, so create the export
# directory first; this keeps the cell working on a fresh checkout.
os.makedirs('output-albert-base-squad', exist_ok = True)
# Save the trainable variables together with the meta graph (which now also
# contains the named identity output ops), for freeze_graph to re-import.
saver = tf.train.Saver(tf.trainable_variables())
saver.save(sess, 'output-albert-base-squad/model.ckpt')

'output-albert-base-squad/model.ckpt'

In [16]:
def _keep_node(node):
    """Return True for graph nodes whose names are passed to freeze_graph.

    Keeps variables, placeholders and anything whose name mentions logits or
    start/end, minus optimizer ('adam'), 'beta' and global_step nodes.
    NOTE(review): the 'beta' filter also drops LayerNorm beta variables from the
    list; presumably they are still frozen because the outputs depend on them.
    """
    wanted = 'Variable' in node.op or any(
        token in node.name for token in ('Placeholder', 'logits', 'start_', 'end_')
    )
    unwanted = any(token in node.name for token in ('adam', 'beta', 'global_step'))
    return wanted and not unwanted


strings = ','.join(
    node.name
    for node in tf.get_default_graph().as_graph_def().node
    if _keep_node(node)
)
strings.split(',')

['Placeholder',
 'Placeholder_1',
 'Placeholder_2',
 'Placeholder_3',
 'bert/embeddings/word_embeddings',
 'bert/embeddings/token_type_embeddings',
 'bert/embeddings/position_embeddings',
 'bert/embeddings/LayerNorm/gamma',
 'bert/encoder/embedding_hidden_mapping_in/kernel',
 'bert/encoder/embedding_hidden_mapping_in/bias',
 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/kernel',
 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/bias',
 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/kernel',
 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/bias',
 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/kernel',
 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/bias',
 'bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/kernel',
 'bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/bias',
 'bert/encoder/transformer/group_0

In [17]:
def freeze_graph(model_dir, output_node_names):
    """Freeze the latest checkpoint in ``model_dir`` into ``frozen_model.pb``.

    The graph is re-imported from the checkpoint's ``.meta`` file (with device
    placements cleared), the checkpoint is restored, variables are converted
    to constants, and the resulting GraphDef is written next to the checkpoint.

    Args:
        model_dir: directory holding a TensorFlow checkpoint state file.
        output_node_names: comma-separated names of the nodes to keep; they are
            passed to ``convert_variables_to_constants``.

    Raises:
        AssertionError: if ``model_dir`` does not exist or contains no checkpoint.
    """
    if not tf.io.gfile.exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            'directory: %s' % model_dir
        )

    checkpoint = tf.train.get_checkpoint_state(model_dir)
    # get_checkpoint_state returns None when no checkpoint state is found; fail
    # with a clear message instead of an AttributeError on None.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError(
            'No checkpoint found in directory: %s' % model_dir
        )
    input_checkpoint = checkpoint.model_checkpoint_path

    # os.path handles a checkpoint path without a directory component; the old
    # '/'.join(...) turned that case into '/frozen_model.pb' at the filesystem root.
    absolute_model_dir = os.path.dirname(input_checkpoint)
    output_graph = os.path.join(absolute_model_dir, 'frozen_model.pb')
    clear_devices = True
    with tf.Session(graph = tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices = clear_devices
        )
        saver.restore(sess, input_checkpoint)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            output_node_names.split(','),
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))

In [18]:
freeze_graph(model_dir = 'output-albert-base-squad', output_node_names = strings)

INFO:tensorflow:Restoring parameters from output-albert-base-squad/model.ckpt
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
INFO:tensorflow:Froze 36 variables.
INFO:tensorflow:Converted 36 variables to const ops.
3482 ops in the final graph.


In [19]:
def load_graph(frozen_graph_filename):
    """Load a serialized GraphDef and import it into a new ``tf.Graph``.

    Before importing, ref-typed ops are rewritten to plain ops:
    RefSwitch -> Switch (inputs mentioning 'moving_' are redirected to their
    '/read' tensors), AssignSub -> Sub, AssignAdd -> Add, Assign -> Identity.
    NOTE(review): presumably needed because the frozen graph no longer holds
    the variables those ref ops pointed to -- confirm for non-batch-norm graphs.

    Args:
        frozen_graph_filename: path to a binary GraphDef (``.pb``) file.

    Returns:
        The new ``tf.Graph``; imported tensors carry the default 'import/' prefix.
    """
    with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Stateful assign ops and the plain op each one is rewritten to.
    assign_replacements = {'AssignSub': 'Sub', 'AssignAdd': 'Add', 'Assign': 'Identity'}

    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # `range`, not Python 2's `xrange` (which raised NameError here).
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op in assign_replacements:
            is_plain_assign = node.op == 'Assign'
            node.op = assign_replacements[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
            if is_plain_assign:
                if 'validate_shape' in node.attr:
                    del node.attr['validate_shape']
                # Identity takes a single input: keep the assigned value, drop the ref.
                if len(node.input) == 2:
                    node.input[0] = node.input[1]
                    del node.input[1]

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)
    return graph

In [20]:
# Check that the frozen float graph parses and imports cleanly.
g = load_graph(frozen_graph_filename = 'output-albert-base-squad/frozen_model.pb')

In [21]:
# Graph Transform Tool passes, applied in this order to the frozen graph.
transforms = [
    'add_default_attributes',
    # drop Identity / CheckNumerics / Dropout nodes
    'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',
    'fold_batch_norms',
    'fold_old_batch_norms',
    # weight quantization pass
    'quantize_weights(fallback_min=-10, fallback_max=10)',
    'strip_unused_nodes',
    'sort_by_execution_order',
]

In [22]:
from tensorflow.tools.graph_transforms import TransformGraph

# NOTE(review): the graph-level seed is set after the model graph was built, so it
# likely has no effect on the ops created earlier -- confirm it is still needed.
tf.set_random_seed(seed = 0)

In [23]:
pb = 'output-albert-base-squad/frozen_model.pb'

# Read the frozen graph back. tf.gfile.GFile replaces the deprecated FastGFile
# (which emitted a deprecation warning) and matches how load_graph reads files.
input_graph_def = tf.GraphDef()
with tf.gfile.GFile(pb, 'rb') as f:
    input_graph_def.ParseFromString(f.read())

# Auto-named Model placeholders, in creation order: X, segment_ids, input_masks, p_mask.
inputs = ['Placeholder', 'Placeholder_1', 'Placeholder_2', 'Placeholder_3']
# Named output ops: the beam-search predictions, the answerability logit and the
# raw encoder output.
outputs = ['start_top_log_probs',
 'start_top_index',
 'end_top_log_probs',
 'end_top_index',
 'cls_logits',
 'logits_vectorize']

transformed_graph_def = TransformGraph(input_graph_def,
                                       inputs,
                                       outputs, transforms)

# Written alongside the float graph with a '.quantized' suffix.
with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())

Instructions for updating:
Use tf.gfile.GFile.


In [24]:
# Import the quantized graph; its tensors live under the 'import/' prefix.
g = load_graph(frozen_graph_filename = 'output-albert-base-squad/frozen_model.pb.quantized')

In [25]:
# Map each input/output name to its tensor in the imported graph.
input_nodes, output_nodes = (
    {i: g.get_tensor_by_name(f'import/{i}:0') for i in names}
    for names in (inputs, outputs)
)

In [26]:
# Separate session bound to the quantized graph (the earlier `sess` stays open).
test_sess = tf.InteractiveSession(graph=g)



In [27]:
# Pair each imported placeholder with its batch values; relies on `inputs` listing
# the placeholders in Model creation order (X, segment_ids, input_masks, p_mask).
b = dict(zip(
    (input_nodes[i] for i in inputs),
    (batch_ids, batch_segment, batch_masks, p_mask),
))

In [28]:
# Run the quantized graph; fetching a dict returns a dict keyed by output name.
o = test_sess.run(output_nodes, feed_dict = b)

In [29]:
# Display the raw predictions from the quantized graph.
o

{'start_top_log_probs': array([[-0.046217  , -3.8763359 , -4.272926  , -6.079081  , -6.1347513 ],
        [-0.17415078, -2.2447011 , -3.8741577 , -4.4495482 , -5.3810267 ]],
       dtype=float32), 'start_top_index': array([[ 56,  84,   0, 153,  89],
        [ 39,  38,   0,  40,  33]], dtype=int32), 'end_top_log_probs': array([[-9.7745610e-03, -4.6820168e+00, -8.3515291e+00, -1.0275996e+01,
         -1.0657001e+01, -6.4443070e-01, -8.9851624e-01, -3.4607339e+00,
         -3.9201808e+00, -4.3888478e+00, -3.3040240e-03, -6.1882215e+00,
         -7.3882465e+00, -8.8793802e+00, -9.5075378e+00, -5.5482339e-02,
         -3.6105275e+00, -3.6759305e+00, -7.3098426e+00, -8.7379532e+00,
         -1.2261421e-01, -3.0495484e+00, -3.4539359e+00, -3.9068253e+00,
         -4.5645785e+00],
        [-1.7485154e-01, -1.8497354e+00, -6.4667540e+00, -6.9058638e+00,
         -8.5502615e+00, -2.0349765e-01, -1.7505805e+00, -5.5180950e+00,
         -5.8632727e+00, -7.1287708e+00, -4.5225713e-03, -5.6678548e+0