In [1]:
import os

# Expose only the GPU with index 3 to CUDA. This is set ahead of the tensorflow
# import in the next cell; the device index is machine-specific.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [2]:
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.data_generators import translate
from tensor2tensor.layers import common_attention
from tensor2tensor.utils import registry
from tensor2tensor import problems
import tensorflow as tf
import os
import logging
import sentencepiece as spm
import transformer_tag
from tensor2tensor.layers import modalities







In [3]:
# Load the SentencePiece model file; `sp` is wrapped by `Encoder` below to
# tokenise both the inputs and the targets.
vocab = 'sp10m.cased.t5.model'
sp = spm.SentencePieceProcessor()
sp.Load(vocab)

class Encoder:
    """Adapter exposing a SentencePiece processor as an encode/decode object.

    Attributes:
        sp: the wrapped SentencePiece processor.
        vocab_size: ``sp.GetPieceSize() + 100``.
    """

    def __init__(self, sp):
        """Wrap ``sp`` and record the vocabulary size.

        Args:
            sp: a loaded SentencePiece processor.
        """
        self.sp = sp
        # 100 ids are reserved on top of the SentencePiece pieces
        # (presumably T5-style sentinel tokens -- TODO confirm).
        self.vocab_size = sp.GetPieceSize() + 100

    def encode(self, s):
        """Return the subword ids SentencePiece produces for the string ``s``."""
        return self.sp.EncodeAsIds(s)

    def decode(self, ids, strip_extraneous = False):
        """Return the text SentencePiece decodes from ``ids``.

        Args:
            ids: any iterable of integer ids (list, tuple, numpy array, ...).
            strip_extraneous: accepted for interface compatibility; unused.
        """
        # Coerce every id to a plain Python int so that numpy integer scalars
        # (e.g. a row of a session output) are handed over as built-in ints.
        return self.sp.DecodeIds([int(i) for i in ids])

In [4]:
# Table of grammatical-error tag classes. Each entry holds:
#   'class'       -- integer id of the tag (equal to the entry's position in the list),
#   'Description' -- Malay label of the error type (used as the lookup key by Tatabahasa),
#   'salah'       -- an example of the erroneous form (empty for classes 0-2),
#   'betul'       -- the corrected form of that example (empty for classes 0-2).
d = [
    {'class': 0, 'Description': 'PAD', 'salah': '', 'betul': ''},
    {
        'class': 1,
        'Description': 'kesambungan subwords',
        'salah': '',
        'betul': '',
    },
    {
        'class': 2,
        'Description': 'tiada kesalahan',
        'salah': '',
        'betul': '',
    },
    {
        'class': 3,
        'Description': 'kesalahan frasa nama, Perkara yang diterangkan mesti mendahului "penerang"',
        'salah': 'Cili sos',
        'betul': 'sos cili',
    },
    {
        'class': 4,
        'Description': 'kesalahan kata jamak',
        'salah': 'mereka-mereka',
        'betul': 'mereka',
    },
    {
        'class': 5,
        'Description': 'kesalahan kata penguat',
        'salah': 'sangat tinggi sekali',
        'betul': 'sangat tinggi',
    },
    {
        'class': 6,
        'Description': 'kata adjektif dan imbuhan "ter" tanpa penguat.',
        'salah': 'Sani mendapat markah yang tertinggi sekali.',
        'betul': 'Sani mendapat markah yang tertinggi.',
    },
    {
        'class': 7,
        'Description': 'kesalahan kata hubung',
        'salah': 'Sally sedang membaca bila saya tiba di rumahnya.',
        'betul': 'Sally sedang membaca apabila saya tiba di rumahnya.',
    },
    {
        'class': 8,
        'Description': 'kesalahan kata bilangan',
        'salah': 'Beribu peniaga tidak membayar cukai pendapatan.',
        'betul': 'Beribu-ribu peniaga tidak membayar cukai pendapatan',
    },
    {
        'class': 9,
        'Description': 'kesalahan kata sendi',
        'salah': 'Umar telah berpindah daripada sekolah ini bulan lalu.',
        'betul': 'Umar telah berpindah dari sekolah ini bulan lalu.',
    },
    {
        'class': 10,
        'Description': 'kesalahan penjodoh bilangan',
        'salah': 'Setiap orang pelajar',
        'betul': 'Setiap pelajar.',
    },
    {
        'class': 11,
        'Description': 'kesalahan kata ganti diri',
        'salah': 'Pencuri itu telah ditangkap. Beliau dibawa ke balai polis.',
        'betul': 'Pencuri itu telah ditangkap. Dia dibawa ke balai polis.',
    },
    {
        'class': 12,
        'Description': 'kesalahan ayat pasif',
        'salah': 'Cerpen itu telah dikarang oleh saya.',
        'betul': 'Cerpen itu telah saya karang.',
    },
    {
        'class': 13,
        'Description': 'kesalahan kata tanya',
        'salah': 'Kamu berasal dari manakah ?',
        'betul': 'Kamu berasal dari mana ?',
    },
    {
        'class': 14,
        'Description': 'kesalahan tanda baca',
        'salah': 'Kamu berasal dari manakah .',
        'betul': 'Kamu berasal dari mana ?',
    },
    {
        'class': 15,
        'Description': 'kesalahan kata kerja tak transitif',
        'salah': 'Dia kata kepada saya',
        'betul': 'Dia berkata kepada saya',
    },
    {
        'class': 16,
        'Description': 'kesalahan kata kerja transitif',
        'salah': 'Dia suka baca buku',
        'betul': 'Dia suka membaca buku',
    },
    {
        'class': 17,
        'Description': 'penggunaan kata yang tidak tepat',
        'salah': 'Tembuk Besar negeri Cina dibina oleh Shih Huang Ti.',
        'betul': 'Tembok Besar negeri Cina dibina oleh Shih Huang Ti',
    },
]


class Tatabahasa:
    """Encoder between error-tag descriptions and their integer ids.

    The id of a tag is the position of its entry in the table passed to the
    constructor (the entry's own 'class' field is not consulted).

    Attributes:
        d: the tag table, a sequence of dicts that each carry a 'Description'.
        kesalahan: mapping description -> id.
        reverse_kesalahan: mapping id -> description.
        vocab_size: number of entries in the table.
    """

    def __init__(self, d):
        self.d = d
        self.kesalahan = {}
        for position, entry in enumerate(self.d):
            self.kesalahan[entry['Description']] = position
        self.reverse_kesalahan = dict(
            (tag_id, label) for label, tag_id in self.kesalahan.items()
        )
        self.vocab_size = len(self.d)

    def encode(self, s):
        """Map an iterable of descriptions to ids; unknown ones raise KeyError."""
        lookup = self.kesalahan
        return [lookup[label] for label in s]

    def decode(self, ids, strip_extraneous = False):
        """Map an iterable of ids back to descriptions.

        ``strip_extraneous`` is accepted but has no effect.
        """
        lookup = self.reverse_kesalahan
        return [lookup[tag_id] for tag_id in ids]

In [5]:
@registry.register_problem
class Grammar(text_problems.Text2TextProblem):
    """Grammatical error correction problem with an extra error-tag feature."""

    def feature_encoders(self, data_dir):
        """Return the encoders for 'inputs', 'targets' and 'targets_error_tag'.

        ``data_dir`` is not used: the encoders are built from the notebook-level
        ``sp`` processor and ``d`` tag table. Inputs and targets share a single
        subword encoder instance.
        """
        subword_encoder = Encoder(sp)
        tag_encoder = Tatabahasa(d)
        return {
            'inputs': subword_encoder,
            'targets': subword_encoder,
            'targets_error_tag': tag_encoder,
        }

    def hparams(self, defaults, model_hparams):
        """Extend the base problem hparams with the error-tagging settings.

        Every tagging hyperparameter is added to ``model_hparams`` only when it
        is not already present. When ``use_error_tags`` is true, the
        'targets_error_tag' feature is registered on ``defaults`` as a SYMBOL
        modality whose vocabulary size comes from the tag encoder.
        """
        super(Grammar, self).hparams(defaults, model_hparams)

        tagging_defaults = (
            ('use_error_tags', True),
            ('middle_prediction', False),
            ('middle_prediction_layer_factor', 2),
            ('ffn_in_prediction_cascade', 1),
            ('error_tag_embed_size', 12),
        )
        for hparam_name, default_value in tagging_defaults:
            if hparam_name not in model_hparams:
                model_hparams.add_hparam(hparam_name, default_value)

        if not model_hparams.use_error_tags:
            return

        tag_feature = 'targets_error_tag'
        defaults.modality[tag_feature] = modalities.ModalityType.SYMBOL
        tag_vocab_size = self._encoders[tag_feature].vocab_size
        defaults.vocab_size[tag_feature] = tag_vocab_size

    def example_reading_spec(self):
        """Return the base reading spec plus a variable-length int64 tag feature."""
        data_fields, _unused_decoders = super(Grammar, self).example_reading_spec()
        data_fields['targets_error_tag'] = tf.VarLenFeature(tf.int64)
        return data_fields, None

    @property
    def approx_vocab_size(self):
        """Approximate subword vocabulary size."""
        return 32100

    @property
    def is_generate_per_split(self):
        """False: data is not generated separately for each split."""
        return False

    @property
    def dataset_splits(self):
        """Split layout: 200 training shards and a single evaluation shard."""
        train_split = {'split': problem.DatasetSplit.TRAIN, 'shards': 200}
        eval_split = {'split': problem.DatasetSplit.EVAL, 'shards': 1}
        return [train_split, eval_split]

In [6]:
# Working directories, relative to the notebook's current directory.
# (expanduser only changes paths that start with '~', so these are returned as-is.)
DATA_DIR = os.path.expanduser('t2t-tatabahasa/data')
TMP_DIR = os.path.expanduser('t2t-tatabahasa/tmp')
TRAIN_DIR = os.path.expanduser('t2t-tatabahasa/train-small')

In [7]:
# Look up the problem registered by the `Grammar` class above
# ('grammar' is presumably the registry name derived from the class name -- TODO confirm).
PROBLEM = 'grammar'
t2t_problem = problems.problem(PROBLEM)

In [8]:
# Registry names of the model (looked up with registry.model below) and of the
# base hyperparameter set (passed to create_hparams below).
MODEL = 'transformer_tag'
HPARAMS = 'transformer_base'

In [9]:
from tensor2tensor.utils.trainer_lib import create_run_config, create_experiment
from tensor2tensor.utils.trainer_lib import create_hparams
from tensor2tensor.utils import registry
from tensor2tensor import models
from tensor2tensor import problems
from tensor2tensor.utils import trainer_lib

In [10]:
# Graph inputs: 2-D int32 id matrices for the source tokens (X), the target
# tokens (Y) and the error-tag ids (targets_error_tag).
X = tf.placeholder(tf.int32, [None, None], name = 'x_placeholder')
Y = tf.placeholder(tf.int32, [None, None], name = 'y_placeholder')
targets_error_tag = tf.placeholder(tf.int32, [None, None], 'error_placeholder')
# Number of non-zero ids in each row of X.
X_seq_len = tf.count_nonzero(X, 1, dtype=tf.int32)
# Largest per-row count in the batch; used as the decode length for greedy inference.
maxlen_decode = tf.reduce_max(X_seq_len)

# Append two trailing singleton axes to each id matrix
# (presumably the 4-D layout tensor2tensor expects -- TODO confirm).
x = tf.expand_dims(tf.expand_dims(X, -1), -1)
y = tf.expand_dims(tf.expand_dims(Y, -1), -1)
targets_error_tag_ = tf.expand_dims(tf.expand_dims(targets_error_tag, -1), -1)

# NOTE(review): 'targets_error_tag' below is given the raw 2-D placeholder; the
# expanded `targets_error_tag_` is not referenced anywhere in the visible cells.
# Confirm which one the model expects if this dict is ever passed to it.
features = {
    "inputs": x,
    "targets": y,
    "target_space_id": tf.constant(1, dtype=tf.int32),
    'targets_error_tag': targets_error_tag,
}
Modes = tf.estimator.ModeKeys
# Base hyperparameter set combined with the problem's own hparams.
hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)

Instructions for updating:
reduction_indices is deprecated, use axis instead


Instructions for updating:
reduction_indices is deprecated, use axis instead


In [11]:
# Architecture overrides applied on top of the base hyperparameter set.
hparams.filter_size = 2048
hparams.hidden_size = 512
hparams.num_heads = 8
hparams.num_hidden_layers = 6
hparams.vocab_divisor = 128
hparams.dropout = 0.1
hparams.max_length = 256

# LM
hparams.label_smoothing = 0.0
hparams.shared_embedding_and_softmax_weights = False
hparams.eval_drop_long_sequences = True
hparams.max_length = 256  # repeats the assignment above with the same value
hparams.multiproblem_mixing_schedule = 'pretrain'

# tpu
hparams.symbol_modality_num_shards = 1
hparams.attention_dropout_broadcast_dims = '0,1'
hparams.relu_dropout_broadcast_dims = '1'
hparams.layer_prepostprocess_dropout_broadcast_dims = '1'

In [12]:
# Instantiate the registered model class in PREDICT mode (the log below shows
# the dropout-related hparams being reset to 0.0 for inference).
model = registry.model(MODEL)(hparams, Modes.PREDICT)

INFO:tensorflow:Setting T2TModel mode to 'infer'


INFO:tensorflow:Setting T2TModel mode to 'infer'


INFO:tensorflow:Setting hparams.dropout to 0.0


INFO:tensorflow:Setting hparams.dropout to 0.0


INFO:tensorflow:Setting hparams.label_smoothing to 0.0


INFO:tensorflow:Setting hparams.label_smoothing to 0.0


INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0


INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0


INFO:tensorflow:Setting hparams.symbol_dropout to 0.0


INFO:tensorflow:Setting hparams.symbol_dropout to 0.0


INFO:tensorflow:Setting hparams.attention_dropout to 0.0


INFO:tensorflow:Setting hparams.attention_dropout to 0.0


INFO:tensorflow:Setting hparams.relu_dropout to 0.0


INFO:tensorflow:Setting hparams.relu_dropout to 0.0


In [13]:
# logits = model(features)
# logits

# sess = tf.InteractiveSession()
# sess.run(tf.global_variables_initializer())
# l = sess.run(logits, feed_dict = {X: [[10,10, 10, 10,10,1],[10,10, 10, 10,10,1]],
#                              Y: [[10,10, 10, 10,10,1],[10,10, 10, 10,10,1]],
#                              targets_error_tag: [[10,10, 10, 10,10,1],
#                                                 [10,10, 10, 10,10,1]]})

In [14]:
# Inference-only feature dict: replaces the earlier `features` and carries no targets.
features = {
    "inputs": x,
    "target_space_id": tf.constant(1, dtype=tf.int32),
}

# Build the greedy-decoding graph; `maxlen_decode` is passed as the decode length.
# NOTE: `_greedy_infer` is a private (underscore-prefixed) method of the model.
with tf.variable_scope(tf.get_variable_scope(), reuse = False):
    fast_result = model._greedy_infer(features, maxlen_decode)

Instructions for updating:
Use `tf.cast` instead.


Instructions for updating:
Use `tf.cast` instead.


Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.


Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Instructions for updating:
Use `tf.cast` instead.


Instructions for updating:
Use `tf.cast` instead.


Instructions for updating:
Use `tf.cast` instead.


Instructions for updating:
Use `tf.cast` instead.


In [15]:
# Give the two decoder outputs stable node names ('greedy', 'tag_greedy'); the
# frozen graph is queried by these names later in the notebook.
result_seq = tf.identity(fast_result['outputs'], name = 'greedy')
result_tag = tf.identity(fast_result['outputs_tag'], name = 'tag_greedy')

In [16]:
from tensor2tensor.layers import common_layers

def accuracy_per_sequence(predictions, targets, weights_fn = common_layers.weights_nonzero):
    """Fraction of sequences whose weighted positions are all predicted correctly.

    Args:
        predictions: tensor of predicted ids.
        targets: tensor of reference ids.
        weights_fn: callable mapping the padded labels to per-position weights.

    Returns:
        A scalar tensor: the mean over the batch of a per-sequence score that is
        1.0 when no weighted position mismatches and 0.0 when at least one
        whole unit of weighted mismatch is present.
    """
    preds, labels = common_layers.pad_with_zeros(predictions, targets)
    mask = weights_fn(labels)
    labels = tf.to_int32(labels)
    preds = tf.to_int32(preds)
    mismatches = tf.to_float(tf.not_equal(preds, labels)) * mask
    # Reduce over every axis except the leading (batch) one.
    non_batch_axes = list(range(1, len(preds.get_shape())))
    errors_per_sequence = tf.reduce_sum(mismatches, axis=non_batch_axes)
    sequence_scores = 1.0 - tf.minimum(1.0, errors_per_sequence)
    return tf.reduce_mean(sequence_scores)

def padded_accuracy(predictions, targets, weights_fn = common_layers.weights_nonzero):
    """Weighted per-position accuracy of ``predictions`` against ``targets``.

    Args:
        predictions: tensor of predicted ids.
        targets: tensor of reference ids.
        weights_fn: callable mapping the padded labels to per-position weights.

    Returns:
        A scalar tensor: weighted count of matching positions divided by the
        sum of the weights (no guard against a zero weight sum).
    """
    preds, labels = common_layers.pad_with_zeros(predictions, targets)
    mask = weights_fn(labels)
    labels = tf.to_int32(labels)
    preds = tf.to_int32(preds)
    weighted_hits = tf.to_float(tf.equal(preds, labels)) * mask
    total_weight = tf.reduce_sum(mask)
    return tf.reduce_sum(weighted_hits) / total_weight

In [17]:
# Per-position accuracy of the decoded tokens against Y and of the decoded tags
# against the fed error tags. (`accuracy_per_sequence` is defined above but not used here.)
acc_seq = padded_accuracy(result_seq, Y)
acc_tag = padded_accuracy(result_tag, targets_error_tag)

In [18]:
# Path of the most recent checkpoint in TRAIN_DIR
# (os.path.join with a single argument returns it unchanged).
ckpt_path = tf.train.latest_checkpoint(os.path.join(TRAIN_DIR))
ckpt_path

't2t-tatabahasa/train-small/model.ckpt-140000'

In [19]:
# Open an interactive session on the default graph and initialise all global
# variables; the trainable ones are overwritten from the checkpoint in the next cell.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [20]:
# Restore only the trainable variables from the checkpoint found above.
var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
saver = tf.train.Saver(var_list = var_lists)
saver.restore(sess, ckpt_path)

INFO:tensorflow:Restoring parameters from t2t-tatabahasa/train-small/model.ckpt-140000


INFO:tensorflow:Restoring parameters from t2t-tatabahasa/train-small/model.ckpt-140000


In [21]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code while loading; only open
# this file if it comes from a trusted source.
with open('../pure-text/dataset-tatabahasa.pkl', 'rb') as fopen:
    data = pickle.load(fopen)

# Subword encoder used for the ad-hoc examples in the following cells.
encoder = Encoder(sp)

In [22]:
def get_xy(row, encoder):
    """Turn one dataset row into flat id lists for source, target and tags.

    Args:
        row: pair of parallel sequences; ``row[0][i][0]`` is the text encoded
            into the target ids, ``row[1][i][0]`` is the text encoded into the
            source ids and ``row[1][i][1]`` is the tag repeated once per source
            subword of that item.
        encoder: object whose ``encode(text)`` returns a list of ids.

    Returns:
        ``(x, y, tag)``: source ids, target ids and tag ids. ``x`` and ``y``
        end with id 1 and ``tag`` ends with 0 (the original comment marks this
        final position as EOS).
    """
    source_ids, target_ids, tag_ids = [], [], []

    for position, target_item in enumerate(row[0]):
        target_ids += encoder.encode(target_item[0])
        source_item = row[1][position]
        source_pieces = encoder.encode(source_item[0])
        source_ids += source_pieces
        tag_ids += [source_item[1]] * len(source_pieces)

    # EOS
    return source_ids + [1], target_ids + [1], tag_ids + [0]

In [23]:
import numpy as np

In [24]:
# Encode row 10 of the pickled dataset as a sample.
# NOTE: this rebinds `x` and `y`, which earlier held the expanded placeholder tensors.
x, y, tag = get_xy(data[10], encoder)

In [25]:
# Encode a single test sentence and append id 1 (the same id get_xy appends as EOS).
e = encoder.encode('Pilih mana jurusan yang sesuai dengan kebolehan anda dalam peperiksaan Sijil Pelajaran Malaysia semasa memohon kemasukan ke institusi pengajian tinggi.') + [1]

In [26]:
# Greedy-decode the test sentence as a batch of one.
r = sess.run(fast_result, 
         feed_dict = {X: [e]})

In [27]:
# Predicted error-tag id for every decoded position.
r['outputs_tag']

array([[ 2,  4,  4,  4,  2,  2,  2,  2,  2, 11,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  0]])

In [28]:
# Decoded text of the model output for the test sentence.
encoder.decode(r['outputs'][0].tolist())

'Pilih mana-mana jurusan yang sesuai dengan kebolehan ini dalam peperiksaan Sijil Pelajaran Malaysia semasa memohon kemasukan ke institusi pengajian tinggi.'

In [29]:
# Source side of the sample row (built by get_xy from row[1]).
encoder.decode(x)

'Marta Vieira da Silva ( lahir 19 Februari 1986 ) yang biasanya dikenali sebagai Marta merupakan seorang bola pemain sepak Brazil yang main laksanabagai penyerang posisi hingga kelab Liga Bola Sepak Wanita Nasional , Orlando Pride sungguhpun juga pasukan sepak kebangsaan bola wanita Brazil .'

In [30]:
# Target side of the sample row (built by get_xy from row[0]).
encoder.decode(y)

'Marta Vieira da Silva ( lahir 19 Februari 1986 ) yang biasanya dikenali sebagai Marta merupakan seorang pemain bola sepak Brazil yang bermain dalam posisi penyerang untuk kelab Liga Bola Sepak Wanita Nasional , Orlando Pride dan juga pasukan bola sepak kebangsaan wanita Brazil .'

In [31]:
# Feature spec used to parse the serialized examples; includes the extra error-tag feature.
hparams.problem.example_reading_spec()[0]

{'targets': VarLenFeature(dtype=tf.int64),
 'inputs': VarLenFeature(dtype=tf.int64),
 'targets_error_tag': VarLenFeature(dtype=tf.int64)}

In [32]:
def parse(serialized_example):
    """Decode one serialized example into a dict of dense 1-D value tensors.

    The feature spec comes from the notebook-level ``hparams.problem``. Every
    parsed feature is replaced by its ``.values`` tensor.

    Args:
        serialized_example: a scalar string tensor holding one serialized example.

    Returns:
        dict mapping each feature name to the values tensor of that feature.
    """
    reading_spec = hparams.problem.example_reading_spec()[0]
    parsed = tf.parse_single_example(serialized_example, features = reading_spec)
    return {name: sparse.values for name, sparse in parsed.items()}

In [33]:
# Stream the dev shard, decode each record with `parse`, and zero-pad every
# feature to the longest sequence within each batch of 32.
dataset = tf.data.TFRecordDataset('t2t-tatabahasa/data/grammar-dev-00000-of-00001')
dataset = dataset.map(parse, num_parallel_calls=32)
dataset = dataset.padded_batch(32, 
    padded_shapes = {
    'inputs': tf.TensorShape([None]),
    'targets': tf.TensorShape([None]),
    'targets_error_tag': tf.TensorShape([None])
    },
    padding_values = {
        'inputs': tf.constant(0, dtype = tf.int64),
        'targets': tf.constant(0, dtype = tf.int64),
        'targets_error_tag': tf.constant(0, dtype = tf.int64),
    })
# NOTE: from this line on `dataset` is the dict of next-batch tensors, not a tf.data.Dataset.
dataset = dataset.make_one_shot_iterator().get_next()
dataset







Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.


Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.


{'inputs': <tf.Tensor 'IteratorGetNext:0' shape=(?, ?) dtype=int64>,
 'targets': <tf.Tensor 'IteratorGetNext:1' shape=(?, ?) dtype=int64>,
 'targets_error_tag': <tf.Tensor 'IteratorGetNext:2' shape=(?, ?) dtype=int64>}

In [34]:
# Score every batch of the dev set: per-batch token accuracy goes to `seqs`,
# per-batch tag accuracy to `tags`.
seqs, tags = [], []
index = 0
while True:
    try:
        # NOTE: `d` rebinds the notebook-level name that held the error-tag
        # table; later cells read the last fetched batch through d['inputs'].
        d = sess.run(dataset)
    except tf.errors.OutOfRangeError:
        # The one-shot iterator is exhausted: every dev batch has been scored.
        break
    # Only end-of-data stops the loop. Any other failure (including one raised
    # while computing the metrics) propagates instead of silently truncating
    # the evaluation.
    s, t = sess.run(
        [acc_seq, acc_tag],
        feed_dict = {
            X: d['inputs'],
            Y: d['targets'],
            targets_error_tag: d['targets_error_tag'],
        },
    )
    seqs.append(s)
    tags.append(t)
    print(f'done {index}')
    index += 1

done 0
done 1
done 2
done 3
done 4
done 5
done 6
done 7
done 8
done 9
done 10
done 11
done 12
done 13
done 14
done 15
done 16
done 17
done 18
done 19
done 20
done 21
done 22
done 23
done 24
done 25
done 26
done 27
done 28
done 29
done 30
done 31
done 32
done 33
done 34
done 35
done 36
done 37
done 38
done 39
done 40
done 41
done 42
done 43
done 44
done 45
done 46
done 47
done 48
done 49
done 50
done 51
done 52
done 53
done 54
done 55
done 56
done 57
done 58
done 59
done 60
done 61
done 62
done 63
done 64
done 65
done 66
done 67
done 68
done 69
done 70
done 71
done 72
done 73
done 74
done 75
done 76
done 77
done 78
done 79
done 80
done 81
done 82
done 83
done 84
done 85
done 86
done 87
done 88
done 89
done 90
done 91
done 92
done 93
done 94
done 95
done 96
done 97
done 98
done 99
done 100
done 101
done 102
done 103
done 104
done 105
done 106
done 107
done 108
done 109
done 110
done 111
done 112
done 113
done 114
done 115
done 116
done 117
done 118
done 119
done 120
done 121
done 122
don

In [35]:
# Unweighted mean of the per-batch accuracies (token accuracy, tag accuracy).
# NOTE(review): every batch counts equally here regardless of how many weighted
# positions it contains, so this can differ slightly from a corpus-level accuracy.
np.mean(seqs), np.mean(tags)

(0.860198, 0.96326745)

In [36]:
# Re-save only the trainable variables under 'transformertag-small/'; freeze_graph
# later reads this checkpoint together with its '.meta' file.
saver = tf.train.Saver(tf.trainable_variables())
saver.save(sess, 'transformertag-small/model.ckpt')

'transformertag-small/model.ckpt'

In [38]:
# Comma-separated names of the graph nodes to keep when freezing: variable ops
# plus nodes whose name matches one of the listed substrings, minus anything
# whose name contains 'adam', 'beta', 'global_step', 'modality' or 'Assign'.
# Note: the 'tag_greedy' test is already implied by the 'greedy' test, and the
# name checks are case-sensitive ('Placeholder' does not match 'x_placeholder').
strings = ','.join(
    [
        n.name
        for n in tf.get_default_graph().as_graph_def().node
        if ('Variable' in n.op
        or 'Placeholder' in n.name
        or 'greedy' in n.name
        or 'tag_greedy' in n.name
        or 'x_placeholder' in n.name
        or 'self/Softmax' in n.name)
        and 'adam' not in n.name
        and 'beta' not in n.name
        and 'global_step' not in n.name
        and 'modality' not in n.name
        and 'Assign' not in n.name
    ]
)
strings.split(',')

['x_placeholder',
 'transformer_tag/body/target_space_embedding/kernel',
 'transformer_tag/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale',
 'transformer_tag/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias',
 'transformer_tag/body/encoder/layer_0/self_attention/multihead_attention/q/kernel',
 'transformer_tag/body/encoder/layer_0/self_attention/multihead_attention/k/kernel',
 'transformer_tag/body/encoder/layer_0/self_attention/multihead_attention/v/kernel',
 'transformer_tag/body/encoder/layer_0/self_attention/multihead_attention/output_transform/kernel',
 'transformer_tag/body/encoder/layer_0/ffn/layer_prepostprocess/layer_norm/layer_norm_scale',
 'transformer_tag/body/encoder/layer_0/ffn/layer_prepostprocess/layer_norm/layer_norm_bias',
 'transformer_tag/body/encoder/layer_0/ffn/conv1/kernel',
 'transformer_tag/body/encoder/layer_0/ffn/conv1/bias',
 'transformer_tag/body/encoder/layer_0/ffn/conv2/kernel',
 'tr

In [39]:
def freeze_graph(model_dir, output_node_names):
    """Freeze the checkpoint recorded in ``model_dir`` into ``frozen_model.pb``.

    The checkpoint's meta graph is imported into a fresh graph, its variables
    are restored and converted to constants, and the resulting GraphDef is
    written next to the checkpoint as ``frozen_model.pb``.

    Args:
        model_dir: directory that holds the checkpoint state file.
        output_node_names: comma-separated names of the nodes to keep.

    Raises:
        AssertionError: if ``model_dir`` does not exist or holds no checkpoint.
    """
    if not tf.io.gfile.exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            'directory: %s' % model_dir
        )

    # get_checkpoint_state returns None when the directory has no checkpoint
    # state; fail with a clear message instead of an AttributeError.
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError(
            'No checkpoint found in directory: %s' % model_dir
        )
    input_checkpoint = checkpoint.model_checkpoint_path

    # Write the frozen graph beside the checkpoint. dirname/join keeps a
    # checkpoint path with no directory part relative, where naive '/'
    # splitting would have produced '/frozen_model.pb' at the filesystem root.
    output_graph = os.path.join(
        os.path.dirname(input_checkpoint), 'frozen_model.pb'
    )
    clear_devices = True
    with tf.Session(graph = tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices = clear_devices
        )
        saver.restore(sess, input_checkpoint)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            output_node_names.split(','),
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))

In [40]:
# Freeze the checkpoint saved above, keeping the nodes selected in `strings`.
freeze_graph('transformertag-small', strings)

INFO:tensorflow:Restoring parameters from transformertag-small/model.ckpt


INFO:tensorflow:Restoring parameters from transformertag-small/model.ckpt


Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`


Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`


Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`


Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`


INFO:tensorflow:Froze 209 variables.


INFO:tensorflow:Froze 209 variables.


INFO:tensorflow:Converted 209 variables to const ops.


INFO:tensorflow:Converted 209 variables to const ops.


5777 ops in the final graph.


In [41]:
def load_graph(frozen_graph_filename):
    """Load a serialized GraphDef file into a new ``tf.Graph`` and return it.

    The graph is imported with ``tf.import_graph_def`` and no explicit name,
    so callers in this notebook address its tensors with the 'import/' prefix.

    Args:
        frozen_graph_filename: path of the serialized GraphDef file.

    Returns:
        The newly created graph that contains the imported nodes.
    """
    frozen_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_graph_filename, 'rb') as pb_file:
        frozen_def.ParseFromString(pb_file.read())

    imported_graph = tf.Graph()
    with imported_graph.as_default():
        tf.import_graph_def(frozen_def)
    return imported_graph

In [43]:
# Load the frozen graph and fetch its input and output tensors by name.
# NOTE: `x` is rebound here to the frozen graph's input placeholder.
g = load_graph('transformertag-small/frozen_model.pb')
x = g.get_tensor_by_name('import/x_placeholder:0')
greedy = g.get_tensor_by_name('import/greedy:0')
tag_greedy = g.get_tensor_by_name('import/tag_greedy:0')
test_sess = tf.InteractiveSession(graph = g)



In [44]:
# Run the frozen graph on the inputs of the last batch fetched by the evaluation loop.
test_sess.run([greedy, tag_greedy], feed_dict = {x:d['inputs']})

[array([[ 4881, 27510,   158, ...,  2202,  5767,    15],
        [  720,    17,   130, ..., 13790, 20724,    13],
        [27130,     7, 29076, ...,     3,     1,    15],
        ...,
        [16256,  5222,    36, ...,     3,     1,    15],
        [ 1151,   787,    27, ...,    15,     3,     1],
        [  104,    89,  3502, ...,     1,     1,     1]]),
 array([[ 2,  3,  3, ...,  2,  2, 14],
        [ 2,  2,  2, ...,  2,  2,  2],
        [ 2,  2,  2, ...,  2,  0,  2],
        ...,
        [ 2,  2,  2, ...,  2,  0,  2],
        [ 2,  2,  2, ...,  2,  2,  0],
        [ 2,  2,  2, ...,  0,  0,  0]])]

In [45]:
import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph
from glob import glob
# Fix the graph-level random seed of the current default graph.
tf.set_random_seed(0)

In [46]:
import tensorflow_text
import tf_sentencepiece

In [47]:
# Graph Transform Tool passes applied, in order, to the frozen graph.
# quantize_weights is the pass that produces the '.quantized' variant.
transforms = ['add_default_attributes',
             'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',
             'fold_constants(ignore_errors=true)',
             'fold_batch_norms',
             'fold_old_batch_norms',
             'quantize_weights(fallback_min=-10, fallback_max=10)',
             'strip_unused_nodes',
             'sort_by_execution_order']

pb = 'transformertag-small/frozen_model.pb'
input_graph_def = tf.GraphDef()
# tf.gfile.GFile replaces the deprecated tf.gfile.FastGFile and matches the
# file API used by the rest of this notebook.
with tf.gfile.GFile(pb, 'rb') as f:
    input_graph_def.ParseFromString(f.read())

# Inputs and outputs are given by node name, without the 'import/' prefix.
transformed_graph_def = TransformGraph(input_graph_def,
                                       ['x_placeholder'],
                                       ['greedy', 'tag_greedy'], transforms)

# The transformed graph is written beside the original with a '.quantized' suffix.
with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())

Instructions for updating:
Use tf.gfile.GFile.


Instructions for updating:
Use tf.gfile.GFile.


In [48]:
# Load the quantized graph and fetch the same tensors by name.
# NOTE(review): this rebinds `g`, `x`, `greedy`, `tag_greedy` and `test_sess`;
# the previously opened `test_sess` is not closed before being replaced.
g = load_graph('transformertag-small/frozen_model.pb.quantized')
x = g.get_tensor_by_name('import/x_placeholder:0')
greedy = g.get_tensor_by_name('import/greedy:0')
tag_greedy = g.get_tensor_by_name('import/tag_greedy:0')
test_sess = tf.InteractiveSession(graph = g)

In [49]:
# Run the quantized graph on the same batch, for comparison with the float
# graph's output shown earlier (the displayed arrays differ in some positions).
test_sess.run([greedy, tag_greedy], feed_dict = {x:d['inputs']})

[array([[ 4881, 27510,   158, ..., 27302,    15,     3],
        [  720,    17,   130, ...,     1,    15,     3],
        [27130,     7, 29076, ...,     3,     1,    15],
        ...,
        [16256,  5222,    36, ...,     3,     1,    15],
        [ 1151,   787,    27, ...,     3,     1,    15],
        [  104,    89,  3502, ...,     1,     1,     1]]),
 array([[ 2,  3,  3, ...,  2, 14, 14],
        [ 2,  2,  2, ...,  0,  2,  2],
        [ 2,  2,  2, ...,  2,  0,  2],
        ...,
        [ 2,  2,  2, ...,  2,  0,  2],
        [ 2,  2,  2, ...,  2,  0,  2],
        [ 2,  2,  2, ...,  0,  0,  0]])]