In [1]:
import os

# Hide every GPU and silence TensorFlow's C++ logging. These are set before
# tensorflow is imported in the next cell so that they take effect.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '',
    'TF_CPP_MIN_LOG_LEVEL': '3',
})

In [2]:
import numpy as np
import json
import tensorflow as tf
import itertools
import collections
import re
import random
import sentencepiece as spm
from tqdm import tqdm
import xlnet_utils as squad_utils
import xlnet

In [3]:
from prepro_utils import preprocess_text, encode_ids

# Load the SentencePiece model; `Load` returns True on success, which is the
# cell output shown below.
# NOTE(review): neither sp_model nor preprocess_text/encode_ids is used by any
# visible cell — confirm this is still needed.
sp_model = spm.SentencePieceProcessor()
sp_model.Load('sp10m.cased.v9.model')

True

In [4]:
import pickle

# Pre-built SQuAD test features/examples. The path defaults to the original
# author's local file but can be overridden without editing the notebook.
# NOTE(security): pickle.load can execute arbitrary code — only load a file you
# produced yourself.
squad_test_pkl = os.environ.get(
    'XLNET_SQUAD_TEST_PKL', '/home/husein/xlnet/xlnet-squad-test.pkl'
)
with open(squad_test_pkl, 'rb') as fopen:
    test_features, test_examples = pickle.load(fopen)

In [5]:
# Feature-construction settings: max tokens per window, sliding-window stride,
# and max question tokens. NOTE(review): the pickled test features were built
# elsewhere — confirm they used these same values.
max_seq_length, doc_stride, max_query_length = 512, 128, 64

In [6]:
# Optimisation-schedule hyperparameters; the step counts only feed
# `training_parameters` below. No visible cell trains.
# NOTE(review): step counts are derived from the *test* features here.
epoch, batch_size = 5, 6
warmup_proportion = 0.1
n_best_size = 20
num_train_steps = int(len(test_features) / batch_size * epoch)
num_warmup_steps = int(num_train_steps * warmup_proportion)
learning_rate = 2e-5

In [7]:
# Run-time settings for the XLNet encoder (inference mode).
# NOTE(review): dropout/dropatt are presumably only applied by xlnet when
# is_training is True — confirm in xlnet.RunConfig.
kwargs = {
    'is_training': False,
    'use_tpu': False,
    'use_bfloat16': False,
    'dropout': 0.1,
    'dropatt': 0.1,
    'init': 'normal',
    'init_range': 0.1,
    'init_std': 0.05,
    'clamp_len': -1,
}

xlnet_parameters = xlnet.RunConfig(**kwargs)
# Architecture hyperparameters (e.g. d_model, used by the QA head) of the base checkpoint.
xlnet_config = xlnet.XLNetConfig(json_path = 'alxlnet-base-2020-04-10/config.json')




In [10]:
# Optimiser settings in the shape expected by `Parameter` (next cell). Keys that
# Parameter does not name (dropout, init, ...) are accepted and dropped there.
training_parameters = {
    'decay_method': 'poly',
    'train_steps': num_train_steps,
    'learning_rate': learning_rate,
    'warmup_steps': num_warmup_steps,
    'min_lr_ratio': 0.0,
    'weight_decay': 0.00,
    'adam_epsilon': 1e-8,
    'num_core_per_host': 1,
    'lr_layer_decay_rate': 1,
    'use_tpu': False,
    'use_bfloat16': False,
    'dropout': 0.0,
    'dropatt': 0.0,
    'init': 'normal',
    'init_range': 0.1,
    'init_std': 0.05,
    'clip': 1.0,
    'clamp_len': -1,
}

In [11]:
class Parameter:
    """Attribute container for optimiser settings.

    Only the explicitly named arguments are stored, as attributes of the same
    name. Any extra keyword arguments are accepted and silently discarded.
    """

    def __init__(self, decay_method, warmup_steps, weight_decay, adam_epsilon, 
                num_core_per_host, lr_layer_decay_rate, use_tpu, learning_rate, train_steps,
                min_lr_ratio, clip, **kwargs):
        stored = {
            'decay_method': decay_method,
            'warmup_steps': warmup_steps,
            'weight_decay': weight_decay,
            'adam_epsilon': adam_epsilon,
            'num_core_per_host': num_core_per_host,
            'lr_layer_decay_rate': lr_layer_decay_rate,
            'use_tpu': use_tpu,
            'learning_rate': learning_rate,
            'train_steps': train_steps,
            'min_lr_ratio': min_lr_ratio,
            'clip': clip,
        }
        for attr_name, attr_value in stored.items():
            setattr(self, attr_name, attr_value)
        
# Wrap the settings dict in a Parameter object. The guard keeps the cell
# idempotent: after the first run the name already holds a Parameter, and
# `Parameter(**<Parameter>)` would raise TypeError (not a mapping).
if isinstance(training_parameters, dict):
    training_parameters = Parameter(**training_parameters)

In [12]:
from tensorflow.contrib import layers as contrib_layers

class Model:
    """XLNet encoder graph with batch-major input placeholders.

    Inputs are fed as (batch, seq_len) and transposed to the time-major layout
    passed to ``xlnet.XLNetModel``. The QA head is built outside this class, on
    top of ``self.output``.

    Args:
        is_training: when True, also create ``self.start_positions`` (int32,
            shape [batch]). This is the gold start index read by the training
            branch of the QA head.
            NOTE(review): the encoder's own training/dropout mode is still set by
            the module-level ``xlnet_parameters`` (built with is_training=False),
            not by this flag.
    """

    def __init__(self, is_training = True):
        # Creation order fixes the auto-generated node names 'Placeholder' ..
        # 'Placeholder_4' that the export/quantization cells look up; keep it.
        self.X = tf.compat.v1.placeholder(tf.compat.v1.int32, [None, None])
        self.segment_ids = tf.compat.v1.placeholder(tf.compat.v1.int32, [None, None])
        self.input_masks = tf.compat.v1.placeholder(tf.compat.v1.float32, [None, None])
        # 1.0 marks positions that may not be part of an answer (the head
        # pushes their logits to -1e30).
        self.p_mask = tf.compat.v1.placeholder(tf.compat.v1.float32, [None, None])
        self.cls_index = tf.compat.v1.placeholder(tf.compat.v1.int32, [None])
        if is_training:
            # Created after the five inference inputs so their names are unchanged.
            self.start_positions = tf.compat.v1.placeholder(tf.compat.v1.int32, [None])

        # Batch-major inputs are transposed to (seq_len, batch) for XLNet.
        xlnet_model = xlnet.XLNetModel(
            xlnet_config=xlnet_config,
            run_config=xlnet_parameters,
            input_ids=tf.compat.v1.transpose(self.X, [1, 0]),
            seg_ids=tf.compat.v1.transpose(self.segment_ids, [1, 0]),
            input_mask=tf.compat.v1.transpose(self.input_masks, [1, 0]))

        output = xlnet_model.get_sequence_output()
        self.output = output
        # Batch-major copy of the sequence output, exported as node 'logits_vectorize'.
        self.vectorize = tf.compat.v1.identity(tf.compat.v1.transpose(output, [1, 0, 2]), name = 'logits_vectorize')
        self.model = xlnet_model

In [13]:
# Start from an empty default graph, then build the inference-mode encoder.
tf.compat.v1.reset_default_graph()
is_training = False
model = Model(is_training = is_training)




INFO:tensorflow:memory input None
INFO:tensorflow:Use float type <dtype: 'float32'>

Instructions for updating:
Use keras.layers.dropout instead.
Instructions for updating:
Please use `layer.__call__` method instead.
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.Dense instead.


In [14]:
# QA head (start / end / answerability) built in the top-level graph on top of
# the XLNet sequence output. Scope and layer names become part of the variable
# names matched when the checkpoint is restored, so keep them unchanged.
# The einsum subscripts below index `output` as 'lbh', i.e. time-major
# (seq_len, batch, hidden).
start_n_top = 5  # start candidates kept at inference
end_n_top = 5  # end candidates kept per start candidate
seq_len = tf.compat.v1.shape(model.X)[1]  # dynamic length; model.X is fed batch-major
initializer = model.model.get_initializer()
return_dict = {}
p_mask = model.p_mask
output = model.output
cls_index = model.cls_index

with tf.compat.v1.variable_scope('start_logits'):
    # One logit per position: dense gives (seq_len, batch, 1); squeeze +
    # transpose gives (batch, seq_len).
    start_logits = tf.compat.v1.layers.dense(
        output, 1, kernel_initializer = initializer
    )
    start_logits = tf.compat.v1.transpose(tf.compat.v1.squeeze(start_logits, -1), [1, 0])
    # Where p_mask == 1 the logit becomes about -1e30, so those positions get
    # (near-)zero probability after the softmax.
    start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask
    start_log_probs = tf.compat.v1.nn.log_softmax(start_logits_masked, -1)
    
with tf.compat.v1.variable_scope('end_logits'):
    if is_training:
        # during training, compute the end logits based on the
        # ground truth of the start position
        # NOTE(review): requires `model.start_positions` (gold start index per
        # example); confirm Model defines it whenever is_training is True.

        start_positions = tf.compat.v1.reshape(model.start_positions, [-1])
        # (batch, seq_len) one-hot of the gold start position.
        start_index = tf.compat.v1.one_hot(
            start_positions, depth = seq_len, axis = -1, dtype = tf.compat.v1.float32
        )
        # Hidden state at the gold start, (batch, hidden), repeated for every position.
        start_features = tf.compat.v1.einsum('lbh,bl->bh', output, start_index)
        start_features = tf.compat.v1.tile(start_features[None], [seq_len, 1, 1])
        # End scorer: [token state ; start state] -> tanh dense -> layer norm -> 1 logit.
        end_logits = tf.compat.v1.layers.dense(
            tf.compat.v1.concat([output, start_features], axis = -1),
            xlnet_config.d_model,
            kernel_initializer = initializer,
            activation = tf.compat.v1.tanh,
            name = 'dense_0',
        )
        end_logits = tf.contrib.layers.layer_norm(
            end_logits, begin_norm_axis = -1
        )

        end_logits = tf.compat.v1.layers.dense(
            end_logits,
            1,
            kernel_initializer = initializer,
            name = 'dense_1',
        )
        end_logits = tf.compat.v1.transpose(tf.compat.v1.squeeze(end_logits, -1), [1, 0])
        end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask
        end_log_probs = tf.compat.v1.nn.log_softmax(end_logits_masked, -1)
    else:
        # during inference, compute the end logits based on beam search

        # Best start_n_top start positions per example: (batch, start_n_top).
        start_top_log_probs, start_top_index = tf.compat.v1.nn.top_k(
            start_log_probs, k = start_n_top
        )
        # (batch, start_n_top, seq_len) one-hot of each candidate start.
        start_index = tf.compat.v1.one_hot(
            start_top_index, depth = seq_len, axis = -1, dtype = tf.compat.v1.float32
        )
        # Hidden state of each candidate start: (batch, start_n_top, hidden).
        start_features = tf.compat.v1.einsum('lbh,bkl->bkh', output, start_index)
        # Token states repeated per candidate: (seq_len, batch, start_n_top, hidden).
        end_input = tf.compat.v1.tile(
            output[:, :, None], [1, 1, start_n_top, 1]
        )
        # Candidate-start states repeated over positions (same shape as end_input).
        start_features = tf.compat.v1.tile(start_features[None], [seq_len, 1, 1, 1])
        end_input = tf.compat.v1.concat([end_input, start_features], axis = -1)
        # Same layer names (dense_0 / dense_1) as the training branch, so both
        # branches resolve to the same variables in this scope.
        end_logits = tf.compat.v1.layers.dense(
            end_input,
            xlnet_config.d_model,
            kernel_initializer = initializer,
            activation = tf.compat.v1.tanh,
            name = 'dense_0',
        )
        end_logits = tf.contrib.layers.layer_norm(
            end_logits, begin_norm_axis = -1
        )
        end_logits = tf.compat.v1.layers.dense(
            end_logits,
            1,
            kernel_initializer = initializer,
            name = 'dense_1',
        )
        # (seq_len, batch, start_n_top) -> (batch, start_n_top, seq_len).
        end_logits = tf.compat.v1.reshape(
            end_logits, [seq_len, -1, start_n_top]
        )
        end_logits = tf.compat.v1.transpose(end_logits, [1, 2, 0])
        # p_mask[:, None] broadcasts the (batch, seq_len) mask over the candidates.
        end_logits_masked = (
            end_logits * (1 - p_mask[:, None]) - 1e30 * p_mask[:, None]
        )
        end_log_probs = tf.compat.v1.nn.log_softmax(end_logits_masked, -1)
        end_top_log_probs, end_top_index = tf.compat.v1.nn.top_k(
            end_log_probs, k = end_n_top
        )
        # Flatten to (batch, start_n_top * end_n_top): entry s * end_n_top + e
        # belongs to start candidate s and its e-th best end.
        end_top_log_probs = tf.compat.v1.reshape(
            end_top_log_probs, [-1, start_n_top * end_n_top]
        )
        end_top_index = tf.compat.v1.reshape(
            end_top_index, [-1, start_n_top * end_n_top]
        )

# NOTE(review): return_dict is filled for reference but not read by any visible cell.
if is_training:
    return_dict['start_log_probs'] = start_log_probs
    return_dict['end_log_probs'] = end_log_probs
else:
    return_dict['start_top_log_probs'] = start_top_log_probs
    return_dict['start_top_index'] = start_top_index
    return_dict['end_top_log_probs'] = end_top_log_probs
    return_dict['end_top_index'] = end_top_index

# an additional layer to predict answerability
with tf.compat.v1.variable_scope('answer_class'):
    # get the representation of CLS
    # Rebinds `cls_index` from the int placeholder to its (batch, seq_len) one-hot.
    cls_index = tf.compat.v1.one_hot(
        cls_index, seq_len, axis = -1, dtype = tf.compat.v1.float32
    )
    cls_feature = tf.compat.v1.einsum('lbh,bl->bh', output, cls_index)

    # get the representation of START
    # Expected start representation: token states weighted by start probabilities.
    start_p = tf.compat.v1.nn.softmax(
        start_logits_masked, axis = -1, name = 'softmax_start'
    )
    start_feature = tf.compat.v1.einsum('lbh,bl->bh', output, start_p)

    # note(zhiliny): no dependency on end_feature so that we can obtain
    # one single `cls_logits` for each sample
    ans_feature = tf.compat.v1.concat([start_feature, cls_feature], -1)
    ans_feature = tf.compat.v1.layers.dense(
        ans_feature,
        xlnet_config.d_model,
        activation = tf.compat.v1.tanh,
        kernel_initializer = initializer,
        name = 'dense_0',
    )
    # Dropout is only active when is_training is True.
    ans_feature = tf.compat.v1.layers.dropout(
        ans_feature, 0.1, training = is_training
    )
    cls_logits = tf.compat.v1.layers.dense(
        ans_feature,
        1,
        kernel_initializer = initializer,
        name = 'dense_1',
        use_bias = False,
    )
    # One answerability logit per example: (batch,).
    cls_logits = tf.compat.v1.squeeze(cls_logits, -1)

    return_dict['cls_logits'] = cls_logits

In [15]:
# Interactive session so later cells can evaluate the graph directly.
sess = tf.compat.v1.InteractiveSession()
sess.run(tf.compat.v1.global_variables_initializer())
# Only trainable variables are restored from the fine-tuned checkpoint; any
# others keep the values set by the initializer above.
saver = tf.compat.v1.train.Saver(tf.compat.v1.trainable_variables())
saver.restore(sess, save_path = 'alxlnet-base-squad/model.ckpt')

INFO:tensorflow:Restoring parameters from alxlnet-base-squad/model.ckpt


In [16]:
# Give each inference fetch a stable node name so the frozen/quantized graph
# can be queried by name. The Python names are rebound to the identity tensors.
# NOTE(review): re-running this cell creates suffixed duplicates
# (e.g. 'start_top_log_probs_1') because TF uniquifies node names.
(start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) = [
    tf.compat.v1.identity(tensor, name = node_name)
    for tensor, node_name in (
        (start_top_log_probs, 'start_top_log_probs'),
        (start_top_index, 'start_top_index'),
        (end_top_log_probs, 'end_top_log_probs'),
        (end_top_index, 'end_top_index'),
        (cls_logits, 'cls_logits'),
    )
]

In [17]:
# Run the float graph on the first two test features as a sanity check.
# NOTE(review): this overrides batch_size (6 above) and rebinds p_mask /
# cls_index from graph tensors to Python lists.
i = 0
batch_size = 2
batch = test_features[i: i + batch_size]
(batch_ids, batch_masks, batch_segment, batch_start, batch_end,
 is_impossible, p_mask, cls_index) = (
    [getattr(feature, field) for feature in batch]
    for field in (
        'input_ids', 'input_mask', 'segment_ids', 'start_position',
        'end_position', 'is_impossible', 'p_mask', 'cls_index',
    )
)
o = sess.run(
    fetches = [start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits],
    feed_dict = {
        model.X: batch_ids,
        model.segment_ids: batch_segment,
        model.input_masks: batch_masks,
        model.p_mask: p_mask,
        model.cls_index: cls_index,
    },
)

In [18]:
# Re-save the trainable variables (plus the .meta graph that freeze_graph
# imports) so the export includes the named QA-head outputs.
saver = tf.compat.v1.train.Saver(var_list = tf.compat.v1.trainable_variables())
saver.save(sess, save_path = 'output-alxlnet-base-squad/model.ckpt')

'output-alxlnet-base-squad/model.ckpt'

In [19]:
# Comma-separated node names passed to freeze_graph as outputs to keep:
# variable ops, plus nodes whose name contains 'Placeholder', 'logits',
# 'start_' or 'end_'. Names containing 'adam', 'beta' or 'global_step' are excluded.
strings = ','.join(
    node.name
    for node in tf.compat.v1.get_default_graph().as_graph_def().node
    if (
        'Variable' in node.op
        or any(marker in node.name for marker in ('Placeholder', 'logits', 'start_', 'end_'))
    )
    and not any(marker in node.name for marker in ('adam', 'beta', 'global_step'))
)
strings.split(',')

['Placeholder',
 'Placeholder_1',
 'Placeholder_2',
 'Placeholder_3',
 'Placeholder_4',
 'model/transformer/r_w_bias',
 'model/transformer/r_r_bias',
 'model/transformer/word_embedding/lookup_table',
 'model/transformer/word_embedding/lookup_table_2',
 'model/transformer/r_s_bias',
 'model/transformer/seg_embed',
 'model/transformer/layer_shared/rel_attn/q/kernel',
 'model/transformer/layer_shared/rel_attn/k/kernel',
 'model/transformer/layer_shared/rel_attn/v/kernel',
 'model/transformer/layer_shared/rel_attn/r/kernel',
 'model/transformer/layer_shared/rel_attn/o/kernel',
 'model/transformer/layer_shared/rel_attn/LayerNorm/gamma',
 'model/transformer/layer_shared/ff/layer_1/kernel',
 'model/transformer/layer_shared/ff/layer_1/bias',
 'model/transformer/layer_shared/ff/layer_2/kernel',
 'model/transformer/layer_shared/ff/layer_2/bias',
 'model/transformer/layer_shared/ff/LayerNorm/gamma',
 'logits_vectorize',
 'start_logits/dense/kernel/Initializer/random_normal/shape',
 'start_logits/

In [20]:
def freeze_graph(model_dir, output_node_names):

    if not tf.compat.v1.io.gfile.exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            'directory: %s' % model_dir
        )

    checkpoint = tf.compat.v1.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/frozen_model.pb'
    clear_devices = True
    with tf.compat.v1.Session(graph = tf.compat.v1.Graph()) as sess:
        saver = tf.compat.v1.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices = clear_devices
        )
        saver.restore(sess, input_checkpoint)
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess,
            tf.compat.v1.get_default_graph().as_graph_def(),
            output_node_names.split(','),
        )
        with tf.compat.v1.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))

In [21]:
freeze_graph(model_dir = 'output-alxlnet-base-squad', output_node_names = strings)

INFO:tensorflow:Restoring parameters from output-alxlnet-base-squad/model.ckpt
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
INFO:tensorflow:Froze 30 variables.
INFO:tensorflow:Converted 30 variables to const ops.
6909 ops in the final graph.


In [22]:
def load_graph(frozen_graph_filename):
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
        
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in xrange(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'Assign':
            node.op = 'Identity'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
            if 'validate_shape' in node.attr:
                del node.attr['validate_shape']
            if len(node.input) == 2:
                node.input[0] = node.input[1]
                del node.input[1]
                
    with tf.compat.v1.Graph().as_default() as graph:
        tf.compat.v1.import_graph_def(graph_def)
    return graph

In [24]:
g = load_graph(frozen_graph_filename = 'output-alxlnet-base-squad/frozen_model.pb')

In [25]:
# Graph Transform Tool passes applied, in order, when quantizing the frozen graph.
transforms = [
    'add_default_attributes',
    'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',
    'fold_batch_norms',
    'fold_old_batch_norms',
    # NOTE(review): fallback range presumably used for weights whose range
    # cannot be determined — confirm in the graph_transforms docs.
    'quantize_weights(fallback_min=-10, fallback_max=10)',
    'strip_unused_nodes',
    'sort_by_execution_order',
]

In [26]:
from tensorflow.tools.graph_transforms import TransformGraph
# NOTE(review): a graph-level seed only influences ops created afterwards, so
# this likely has no effect on the already-built graph.
tf.compat.v1.set_random_seed(seed = 0)

In [27]:
# Quantize the frozen graph's weights and write it next to the original as
# '<pb>.quantized'.
pb = 'output-alxlnet-base-squad/frozen_model.pb'

input_graph_def = tf.compat.v1.GraphDef()
# GFile replaces the deprecated FastGFile (which logs a deprecation notice).
with tf.compat.v1.gfile.GFile(pb, 'rb') as f:
    input_graph_def.ParseFromString(f.read())

# Auto-named placeholders in Model creation order:
# X, segment_ids, input_masks, p_mask, cls_index.
inputs = ['Placeholder', 'Placeholder_1', 'Placeholder_2', 'Placeholder_3', 'Placeholder_4']
outputs = ['start_top_log_probs',
 'start_top_index',
 'end_top_log_probs',
 'end_top_index',
 'cls_logits',
 'logits_vectorize']

transformed_graph_def = TransformGraph(input_graph_def, 
                                           inputs,
                                           outputs, transforms)

with tf.compat.v1.gfile.GFile(f'{pb}.quantized', 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())

Instructions for updating:
Use tf.compat.v1.gfile.GFile.


In [28]:
g = load_graph(frozen_graph_filename = 'output-alxlnet-base-squad/frozen_model.pb.quantized')

In [29]:
# Map each input/output node name to its first tensor in the imported graph
# (load_graph imports under the 'import/' prefix).
input_nodes, output_nodes = (
    {name: g.get_tensor_by_name(f'import/{name}:0') for name in node_names}
    for node_names in (inputs, outputs)
)

In [30]:
# Plain Session bound to the quantized graph. Opening a second
# InteractiveSession while `sess` is still open makes TF warn about multiple
# active interactive sessions and swaps the default session.
test_sess = tf.compat.v1.Session(graph = g)



In [31]:
# Feed the same sample batch; values are listed in the order of `inputs`.
b = {
    input_nodes[name]: value
    for name, value in zip(inputs, [batch_ids, batch_segment, batch_masks, p_mask, cls_index])
}

In [32]:
# Evaluate every exported output at once; the result is a dict keyed like output_nodes.
o = test_sess.run(fetches = output_nodes, feed_dict = b)

In [33]:
# Display the quantized graph's outputs, keyed by output node name.
o

{'start_top_log_probs': array([[-0.01017361, -5.3508844 , -6.230524  , -7.0358524 , -7.1334867 ],
        [-0.12347065, -2.545479  , -4.830178  , -5.2084336 , -5.41917   ]],
       dtype=float32), 'start_top_index': array([[ 47, 190,  45,  46,  44],
        [ 29,  28, 191,  24,  26]], dtype=int32), 'end_top_log_probs': array([[-1.0673070e-03, -6.8515882e+00, -1.2371474e+01, -1.3416997e+01,
         -1.4230107e+01, -5.7099620e-05, -1.0555126e+01, -1.2802329e+01,
         -1.2937642e+01, -1.3404304e+01, -4.9938989e-01, -1.3587554e+00,
         -2.5299537e+00, -3.4406898e+00, -4.0150442e+00, -1.4348902e-01,
         -2.6492026e+00, -2.8838952e+00, -5.4946475e+00, -7.2217665e+00,
         -9.6618307e-01, -1.1195062e+00, -1.6656684e+00, -2.9643059e+00,
         -3.5911660e+00],
        [-1.7521942e-02, -4.3683023e+00, -5.9846005e+00, -6.4855165e+00,
         -8.3603144e+00, -4.2344693e-02, -4.3113942e+00, -4.5560584e+00,
         -4.6392651e+00, -6.5254221e+00, -2.0668755e-04, -1.0992020e+0