In [1]:
import os

# Expose only GPU 0 to CUDA. This is set before TensorFlow is imported in the next cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
import tensorflow as tf
from malaya.train.model.bigbird import modeling, utils

In [3]:
# Hyper-parameters handed to modeling.TransformerModel further down.
# Keys and their order are exactly those of the original literal.
bert_config = dict(
    attention_probs_dropout_prob=0.1,
    hidden_act='gelu',
    hidden_dropout_prob=0.1,
    hidden_size=512,
    initializer_range=0.02,
    intermediate_size=2048,
    max_position_embeddings=2048,
    max_encoder_length=1024,
    max_decoder_length=1024,
    num_attention_heads=8,
    num_hidden_layers=6,
    type_vocab_size=2,
    scope='bert',
    use_bias=True,
    rescale_embedding=False,
    vocab_model_file=None,
    attention_type='block_sparse',
    block_size=16,
    num_rand_blocks=3,
    vocab_size=32000,
    couple_encoder_decoder=False,
    beam_size=1,
    alpha=0.0,
    label_smoothing=0.1,
    norm_type='postnorm',
)

In [4]:
import sentencepiece as spm

# SentencePiece model file, given as a path relative to the working directory. The same
# processor encodes both the source (left) and reference (right) sentences further down.
vocab = 'sp10m.cased.translation.model'
sp = spm.SentencePieceProcessor()
sp.Load(vocab)

class Encoder:
    """Wrap a SentencePiece processor for id-level encoding and decoding.

    `encode` appends the id 1 to every encoded sentence (presumably the
    end-of-sequence id -- confirm against the vocabulary). `decode` accepts
    any iterable of integer-like ids.
    """

    def __init__(self, sp):
        """Keep a reference to `sp`, which must offer `EncodeAsIds` and `DecodeIds`."""
        self.sp = sp

    def encode(self, s):
        """Return the SentencePiece ids of `s` followed by the id 1."""
        return self.sp.EncodeAsIds(s) + [1]

    def decode(self, ids, strip_extraneous=False):
        """Decode an iterable of ids back into text.

        Every id is coerced to a built-in `int`, so NumPy integers (for
        example a row taken straight from a `sess.run` result) are handed to
        SentencePiece in the same form as plain Python ints.

        `strip_extraneous` is accepted for interface compatibility only and
        has no effect.
        """
        return self.sp.DecodeIds([int(i) for i in ids])
    
encoder = Encoder(sp)

In [5]:
# Instantiate modeling.TransformerModel (from malaya.train.model.bigbird) with the config dict above.
model = modeling.TransformerModel(bert_config)

In [6]:
# int32 input with two unknown dimensions; it is fed with padded token-id batches below.
# No name is given, so the node ends up as 'Placeholder' -- the export and reload cells
# later refer to it by that literal name.
X = tf.compat.v1.placeholder(tf.int32, [None, None])

In [7]:
# Per the echo below, the call returns ((log_probs, logits, int32 output of the decoding loop),
# final encoder layer output). r[0][2] (int32, shape (?, 1024)) is what later cells run and decode.
r = model(X, training = False)
r

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Instructions for updating:
reduction_indices is deprecated, use axis instead


Instructions for updating:
reduction_indices is deprecated, use axis instead


((<tf.Tensor 'bert/log_probs:0' shape=(?, 1024) dtype=float32>,
  <tf.Tensor 'bert/logits:0' shape=(?, 1024, 32000) dtype=float32>,
  <tf.Tensor 'bert/while/Exit_1:0' shape=(?, 1024) dtype=int32>),
 <tf.Tensor 'bert/encoder/layer_5/output/LayerNorm/batchnorm/add_1:0' shape=(?, 1024, 512) dtype=float32>)

In [8]:
# Despite its name this is not the logits tensor: r[0][2] is the int32 output of the decoding
# loop (shape (?, 1024)); the float logits are 'bert/logits:0'. The identity op only gives
# that output the fixed node name 'logits', which the export cells use as the output node.
logits = tf.identity(r[0][2], name = 'logits')
logits

<tf.Tensor 'logits:0' shape=(?, 1024) dtype=int32>

In [10]:
# Look up the most recent checkpoint recorded in the 'bigbird-base-en-ms' directory.
ckpt_path = tf.train.latest_checkpoint('bigbird-base-en-ms')
ckpt_path

'bigbird-base-en-ms/model.ckpt-375000'

In [11]:
# Open an interactive session and initialise all variables; the restore in the next cell
# then loads the checkpoint values on top of this initialisation.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [12]:
# Load the weights stored at ckpt_path into the live session.
saver = tf.train.Saver()
saver.restore(sess, ckpt_path)

INFO:tensorflow:Restoring parameters from bigbird-base-en-ms/model.ckpt-375000


INFO:tensorflow:Restoring parameters from bigbird-base-en-ms/model.ckpt-375000


In [13]:
# Short alias for the Keras padding helper used by every encoding cell below.
pad_sequences = tf.keras.preprocessing.sequence.pad_sequences

In [14]:
import re
from unidecode import unidecode

def cleaning(string):
    """Flatten `string` into one transliterated line.

    Newlines are replaced by spaces, the text is passed through `unidecode`,
    runs of spaces are collapsed to a single space, and the result is stripped
    of leading and trailing whitespace.
    """
    single_line = string.replace('\n', ' ')
    transliterated = unidecode(single_line)
    collapsed = re.sub(r'[ ]+', ' ', transliterated)
    return collapsed.strip()

In [15]:
string = """
Amongst the wide-ranging initiatives proposed are a sustainable food labelling framework, a reformulation of processed foods, and a sustainability chapter in all EU bilateral trade agreements. The EU also plans to publish a proposal for a legislative framework for sustainable food systems by 2023 to ensure all foods on the EU market become increasingly sustainable.
"""
cleaning(string)

'Amongst the wide-ranging initiatives proposed are a sustainable food labelling framework, a reformulation of processed foods, and a sustainability chapter in all EU bilateral trade agreements. The EU also plans to publish a proposal for a legislative framework for sustainable food systems by 2023 to ensure all foods on the EU market become increasingly sustainable.'

In [16]:
# cleaning() already returns a str, so it goes straight into the encoder.
# NOTE(review): encoder.encode appends the id 1 itself, so the sequence ends with two 1s --
# confirm this matches how the training data was prepared.
encoded = [*encoder.encode(cleaning(string)), 1]
s = pad_sequences([encoded], maxlen = 1024, padding = 'post')

In [17]:
%%time
# Runs r[0][2] directly; this is the tensor that the 'logits' identity node wraps.
l = sess.run(r[0][2], feed_dict = {X: s})

CPU times: user 2.35 s, sys: 111 ms, total: 2.46 s
Wall time: 2.34 s


In [18]:
# Keeps every id greater than 0 before decoding. Unlike the batch cells below, which drop
# both 0 and 1, this filter leaves the id 1 in the decoded sequence.
encoder.decode([i for i in l[0].tolist() if i > 0])

'Antara inisiatif yang dicadangkan secara meluas ialah rangka kerja pelabelan makanan yang mampan, pembaharuan makanan diproses, dan bab kelestarian dalam semua perjanjian perdagangan dua hala EU. EU juga merancang untuk menerbitkan cadangan rangka kerja perundangan untuk sistem makanan yang mampan menjelang 2023 untuk memastikan semua makanan di pasaran EU menjadi semakin mampan.'

In [19]:
# !wget https://f000.backblazeb2.com/file/malay-dataset/test-en-ms.tar.gz
# !tar -zxf test-en-ms.tar.gz

In [20]:
batch_size = 24

path = 'test-en'


def _read_lines(filename):
    """Return the content of `path/filename` split on newline characters."""
    # Decode explicitly as UTF-8 (assumes the corpus files are UTF-8 -- confirm) so the
    # text does not depend on the platform's default locale encoding.
    with open(os.path.join(path, filename), encoding = 'utf-8') as fopen:
        return fopen.read().split('\n')


left = _read_lines('left.txt')
right = _read_lines('right.txt')

# The BLEU cells pair left[i] with right[i], so the two files must stay line-aligned.
assert len(left) == len(right), 'left.txt and right.txt have different line counts'

len(left), len(right)

(77707, 77707)

In [21]:
%%time

# Encode the first test sentence, append one more id 1 and pad it to 1024 positions.
encoded = [*encoder.encode(left[0]), 1]
s = pad_sequences([encoded], maxlen = 1024, padding = 'post')

CPU times: user 0 ns, sys: 1.65 ms, total: 1.65 ms
Wall time: 1.65 ms


In [22]:
%%time

# Decode the single padded sentence and drop the ids 0 and 1 from each output row.
p = sess.run(logits, feed_dict = {X: s}).tolist()
results = [[token for token in row if token not in (0, 1)] for row in p]
results

CPU times: user 4.27 s, sys: 168 ms, total: 4.43 s
Wall time: 3.95 s


[[2439,
  1676,
  1663,
  3568,
  47,
  9,
  3045,
  103,
  4202,
  3853,
  19928,
  15084,
  20,
  1596,
  1136,
  10954,
  33,
  614,
  81,
  9,
  863,
  3925,
  107,
  194,
  133,
  20,
  12195,
  20,
  22720,
  20,
  12195,
  207,
  27,
  516,
  14236,
  27,
  2190,
  26,
  2067,
  31,
  163,
  107,
  194,
  133,
  9,
  7391,
  20,
  81,
  130,
  1058,
  3459,
  879,
  3356,
  2010,
  4711,
  13,
  791,
  28,
  130,
  115,
  10856,
  26,
  22694,
  20,
  262,
  36,
  65,
  5350,
  26,
  2345,
  3754,
  25676,
  20,
  26,
  2345,
  753,
  33,
  6957,
  20,
  2345,
  2588,
  50,
  30941,
  2929,
  27,
  2345,
  9790,
  4603,
  633,
  8431,
  5662,
  26,
  12265,
  39,
  614,
  344,
  43,
  233,
  85,
  1093,
  31,
  427,
  118,
  15767,
  25946,
  1431,
  154,
  19,
  6892,
  13,
  791,
  28,
  168,
  1196,
  3203,
  162,
  891,
  20,
  1185,
  192,
  13434,
  1679,
  5662,
  26,
  54,
  127,
  4184,
  9,
  206,
  442,
  1290,
  1616,
  367,
  177,
  17474,
  158,
  27253,
  20,
  26

In [23]:
from tensor2tensor.utils import bleu_hook
# The hypotheses in `results` had the ids 0 and 1 removed, so the reference gets the same
# filter; otherwise the id 1 that encoder.encode appends makes the reference one token
# longer than anything the hypothesis can contain.
bleu_hook.compute_bleu(
    reference_corpus = [[i for i in encoder.encode(right[0]) if i not in [0, 1]]],
    translation_corpus = results,
)

0.6417256

In [24]:
from tqdm import tqdm

results = []
for start in tqdm(range(0, len(left), batch_size)):
    # List slicing clamps at the end of `left`, so the last batch may simply be shorter.
    x = left[start: start + batch_size]
    encoded = [encoder.encode(line) + [1] for line in x]
    batch_x = pad_sequences(encoded, padding='post', maxlen = 1024)

    p = sess.run(logits, feed_dict = {X: batch_x}).tolist()
    # Keep every output id except 0 and 1, one list per input sentence.
    results.extend([token for token in row if token not in (0, 1)] for row in p)

100%|██████████| 3238/3238 [6:08:41<00:00,  6.83s/it]   


In [25]:
# References are filtered exactly like the hypotheses (ids 0 and 1 removed). Without this,
# every reference keeps the id 1 appended by encoder.encode while no hypothesis can contain it.
rights = [
    [i for i in encoder.encode(sentence) if i not in [0, 1]]
    for sentence in right[:len(results)]
]
bleu_hook.compute_bleu(reference_corpus = rights,
                       translation_corpus = results)

0.16509123

In [26]:
# Re-save only the trainable variables under output/; freeze_graph('output', ...) below
# reads the checkpoint and its .meta graph from this directory. Note this rebinds `saver`.
saver = tf.train.Saver(tf.trainable_variables())
saver.save(sess, 'output/model.ckpt')

'output/model.ckpt'

In [27]:
# A node is a candidate when its op type contains 'Variable' or its name contains one of these.
_KEPT_NAME_PARTS = ('Placeholder', 'logits', 'alphas', 'self/Softmax')
# A candidate is rejected when its name contains any of these. Note that 'beta' also removes
# the LayerNorm beta variables, which is why only the gamma entries appear in the listing below.
_REJECTED_NAME_PARTS = ('adam', 'beta', 'global_step', 'gradients')

node_names = []
for n in tf.get_default_graph().as_graph_def().node:
    is_candidate = 'Variable' in n.op or any(part in n.name for part in _KEPT_NAME_PARTS)
    if is_candidate and not any(part in n.name for part in _REJECTED_NAME_PARTS):
        node_names.append(n.name)

# freeze_graph() takes the output node names as one comma-separated string.
strings = ','.join(node_names)
strings.split(',')

['bert/embeddings/word_embeddings',
 'bert/embeddings/position_embeddings',
 'Placeholder',
 'bert/encoder/LayerNorm/gamma',
 'bert/encoder/layer_0/attention/self/query/kernel',
 'bert/encoder/layer_0/attention/self/query/bias',
 'bert/encoder/layer_0/attention/self/key/kernel',
 'bert/encoder/layer_0/attention/self/key/bias',
 'bert/encoder/layer_0/attention/self/value/kernel',
 'bert/encoder/layer_0/attention/self/value/bias',
 'bert/encoder/layer_0/attention/self/Softmax',
 'bert/encoder/layer_0/attention/self/Softmax_1',
 'bert/encoder/layer_0/attention/self/Softmax_2',
 'bert/encoder/layer_0/attention/self/Softmax_3',
 'bert/encoder/layer_0/attention/self/Softmax_4',
 'bert/encoder/layer_0/attention/output/dense/kernel',
 'bert/encoder/layer_0/attention/output/dense/bias',
 'bert/encoder/layer_0/attention/output/LayerNorm/gamma',
 'bert/encoder/layer_0/intermediate/dense/kernel',
 'bert/encoder/layer_0/intermediate/dense/bias',
 'bert/encoder/layer_0/output/dense/kernel',
 'bert/e

In [28]:
def freeze_graph(model_dir, output_node_names):
    """Freeze the checkpoint recorded in `model_dir` into `frozen_model.pb`.

    The meta graph stored next to the checkpoint is imported into a fresh
    graph, the weights are restored, variables are converted to constants and
    the result is written beside the checkpoint files.

    Args:
        model_dir: directory containing the checkpoint state, the checkpoint
            and its `.meta` file.
        output_node_names: comma-separated names of the nodes to keep as
            outputs of the frozen graph.

    Raises:
        AssertionError: if `model_dir` does not exist or contains no
            checkpoint state.
    """
    if not tf.io.gfile.exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            'directory: %s' % model_dir
        )

    checkpoint = tf.train.get_checkpoint_state(model_dir)
    # get_checkpoint_state returns None when the directory holds no checkpoint state;
    # fail with a clear message instead of an AttributeError on the next line.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError('No checkpoint found in directory: %s' % model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # Write next to the checkpoint. os.path handles a checkpoint path without a directory
    # component, where splitting on '/' used to produce '/frozen_model.pb'.
    output_graph = os.path.join(os.path.dirname(input_checkpoint), 'frozen_model.pb')

    with tf.Session(graph = tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices = True
        )
        saver.restore(sess, input_checkpoint)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            output_node_names.split(','),
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))

In [29]:
# Freeze the checkpoint saved under output/ using the node names collected in `strings`.
freeze_graph('output', strings)

INFO:tensorflow:Restoring parameters from output/model.ckpt


INFO:tensorflow:Restoring parameters from output/model.ckpt


Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`


Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`


Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`


Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`


INFO:tensorflow:Froze 258 variables.


INFO:tensorflow:Froze 258 variables.


INFO:tensorflow:Converted 258 variables to const ops.


INFO:tensorflow:Converted 258 variables to const ops.


13154 ops in the final graph.


In [30]:
from tensorflow.tools.graph_transforms import TransformGraph

In [31]:
# Transform passes handed, in this order, to TransformGraph in the export cell below;
# the resulting graph is the one written out with the '.quantized' suffix.
transforms = ['add_default_attributes',
             'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',
             'fold_batch_norms',
             'fold_old_batch_norms',
             'quantize_weights(fallback_min=-10, fallback_max=10)',
             'strip_unused_nodes',
             'sort_by_execution_order']

In [32]:
pb = 'output/frozen_model.pb'

input_graph_def = tf.GraphDef()
# tf.gfile.FastGFile is deprecated (TensorFlow's own warning says to use tf.gfile.GFile),
# so the frozen graph is read with GFile, the same class used for writing below.
with tf.gfile.GFile(pb, 'rb') as f:
    input_graph_def.ParseFromString(f.read())

# Input and output node names of the frozen graph, as exported by the cells above.
inputs = ['Placeholder']
transformed_graph_def = TransformGraph(input_graph_def,
                                       inputs,
                                       ['logits'], transforms)

with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())

Instructions for updating:
Use tf.gfile.GFile.


Instructions for updating:
Use tf.gfile.GFile.


In [33]:
def load_graph(frozen_graph_filename, **kwargs):
    """Load a serialized GraphDef from disk into a fresh `tf.Graph`.

    Before import, a few stateful ops are rewritten to stateless equivalents
    (RefSwitch -> Switch, AssignSub -> Sub, AssignAdd -> Add,
    Assign -> Identity) so the graph can be imported for inference.

    Args:
        frozen_graph_filename: path of the serialized GraphDef file.
        **kwargs: accepted but unused.

    Returns:
        The new `tf.Graph`. `import_graph_def` is called without a name, so
        callers look tensors up with the 'import/' prefix.
    """
    with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # https://github.com/onnx/tensorflow-onnx/issues/77#issuecomment-445066091
    # to fix import T5
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # `range` rather than Python 2's `xrange`: this notebook runs on Python 3
            # (it uses f-strings), where `xrange` raises NameError.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'Assign':
            node.op = 'Identity'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
            if 'validate_shape' in node.attr:
                del node.attr['validate_shape']
            # Identity takes a single input: keep the assigned value, drop the target ref.
            if len(node.input) == 2:
                node.input[0] = node.input[1]
                del node.input[1]

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)
    return graph


In [34]:
# Reload the frozen graph. Tensors are looked up under the 'import/' prefix that
# tf.import_graph_def adds when load_graph() calls it without a name.
# Note: `logits` is rebound here and no longer refers to the tensor of the original graph.
g = load_graph('output/frozen_model.pb')
x = g.get_tensor_by_name('import/Placeholder:0')
logits = g.get_tensor_by_name('import/logits:0')
test_sess = tf.InteractiveSession(graph = g)

In [35]:
%%time
# `s` still holds the padded encoding of left[0] from the earlier timing cell, so this
# decodes the first test sentence with the frozen graph.
l = test_sess.run(logits, feed_dict = {x: s})
encoder.decode([i for i in l[0].tolist() if i > 0])

CPU times: user 4.73 s, sys: 310 ms, total: 5.03 s
Wall time: 4.56 s


'Anda tahu bagaimana ceritanya. Dua orang kulit putih heteroseksual, hampir pasti berjuang dalam hidup mereka. Mereka bertemu satu sama lain, berkumpul, berpisah, berkumpul lagi dan mendapat kebahagiaan dan menyelesaikan yang sebenar di antara satu sama lain. Akhirnya, mereka boleh mula menjalani kehidupan normal! Rom-coms boleh menjadi pelarian yang menyeronokkan, tetapi ini adalah kisah yang terlalu kerap diceritakan, yang terlalu kurang dalam kepelbagaian, terlalu bergantung pada stereotaip gender dan terlalu mementingkan menjual kami jenama cinta yang mustahil untuk hidup sehingga: kajian tahun 2008 di Universiti Heriot Watt mendapati bahawa rom-coms mempunyai kesan negatif terhadap hubungan, menjadikan kita mengejar standard cinta yang tidak dapat dicapai. Dalam proses menulis drama baru saya Ross & Rachel, yang menghadapi mitos cinta moden dan dibuka di Edinburgh Festival Fringe pada Ogos ini, saya harus berfikir banyak tentang percintaan dalam fiksyen dari Romeo dan Juliet hingg

In [36]:
# Close the session opened for the unquantized graph before replacing it; otherwise it is
# never released and a second InteractiveSession is created while the first is still active.
test_sess.close()

g = load_graph('output/frozen_model.pb.quantized')
x = g.get_tensor_by_name('import/Placeholder:0')
logits = g.get_tensor_by_name('import/logits:0')
test_sess = tf.InteractiveSession(graph = g)

In [37]:
%%time
# Same input `s` (the padded encoding of left[0]) as in the unquantized run above,
# so the two decoded outputs can be compared directly.
l = test_sess.run(logits, feed_dict = {x: s})
encoder.decode([i for i in l[0].tolist() if i > 0])

CPU times: user 8.25 s, sys: 316 ms, total: 8.57 s
Wall time: 8.55 s


'Anda tahu bagaimana ceritanya. Dua orang kulit putih heteroseksual, hampir-hampir-tentu bergelut dalam kehidupan mereka. Mereka bertemu satu sama lain, berkumpul, berpisah, berkumpul lagi dan mendapat kebahagiaan dan menyelesaikan yang sebenar di antara satu sama lain. Akhirnya, mereka boleh mula menjalani kehidupan normal! Rom-coms boleh menjadi pelarian yang menyeronokkan, tetapi ini adalah kisah yang terlalu kerap diceritakan, yang terlalu kurang dalam kepelbagaian, terlalu bergantung pada stereotaip gender dan terlalu mementingkan penjualan kita jenama cinta yang mustahil untuk hidup sehingga: kajian tahun 2008 di Universiti Heriot Watt mendapati bahawa rom-coms mempunyai kesan negatif terhadap hubungan, menjadikan kita mengejar standard cinta yang tidak dapat dicapai. Dalam proses menulis drama baru saya Ross & Rachel, yang menghadapi mitos cinta moden dan terbuka di Edinburgh Festival Fringe pada Ogos ini, saya harus berfikir banyak tentang percintaan dalam fiksyen dari Romeo da