In [1]:
import os
# Expose only the GPU with index 1 to this process. This is set before the
# TensorFlow import in the next cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
import tensorflow as tf
import numpy as np
from rotary_embedding_tensorflow import apply_rotary_emb, RotaryEmbedding
from fast_transformer import FastTransformer

In [3]:
from malaya.text.bpe import WordPieceTokenizer

In [4]:
# WordPiece tokenizer built from the 'BERT.wordpiece' vocabulary file; casing is
# kept as-is (do_lower_case = False).
tokenizer = WordPieceTokenizer('BERT.wordpiece', do_lower_case = False)
# tokenizer.tokenize('halo nama sayacomel')

In [5]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a dataset file that comes from a trusted source.
# The pickle unpacks into four objects: train inputs/labels and test inputs/labels.
with open('ontonotes5-fastformer.pkl', 'rb') as fopen:
    train_X, train_Y, test_X, test_Y = pickle.load(fopen)

In [6]:
# Training hyperparameters.
# NOTE(review): in the cells visible in this notebook only `batch_size` is read
# later; the training loop iterates `range(2)` and the optimizer is plain Adam,
# so `epoch`, `num_train_steps` and `num_warmup_steps` do not appear to be
# consumed — confirm before relying on them.
epoch = 10
batch_size = 32
warmup_proportion = 0.1
num_train_steps = int(len(train_X) / batch_size * epoch)
num_warmup_steps = int(num_train_steps * warmup_proportion)

In [7]:
import optimization




In [10]:
def create_initializer(initializer_range=0.02):
    """Build a truncated-normal initializer with stddev `initializer_range`."""
    initializer = tf.compat.v1.truncated_normal_initializer(
        stddev=initializer_range
    )
    return initializer

class Model:
    """Token-tagging graph: FastTransformer encoder, dense projection and a CRF layer.

    Constructing an instance adds every op to the current default TensorFlow
    graph. Attributes set by the constructor:

    - X, Y: int32 placeholders of shape [None, None] for token ids and tag ids.
    - maxlen: dynamic size of the second axis of X.
    - lengths: per-row count of non-zero ids in X.
    - model: the FastTransformer instance.
    - logits: first element of what calling the FastTransformer on X returns
      (this is the encoder output, not the per-tag scores).
    - cost: mean negative CRF log-likelihood.
    - optimizer: Adam minimize op for `cost`.
    - tags_seq: CRF-decoded tag ids; the op is named 'logits' in the graph.
    - prediction: `tags_seq` restricted to positions inside `lengths`.
    - accuracy: mean of (prediction == label) over those positions.
    """

    def __init__(
        self,
        dimension_output,
        learning_rate = 2e-5,
        training = True,
    ):
        """
        Parameters
        ----------
        dimension_output: int
            Number of tag classes produced by the dense projection.
        learning_rate: float
            Learning rate given to the Adam optimizer.
        training: bool
            Accepted for interface compatibility; not read by this constructor.
        """
        # Placeholders are created in the order X, Y so their auto-generated
        # graph names stay 'Placeholder' and 'Placeholder_1'.
        self.X = tf.compat.v1.placeholder(tf.compat.v1.int32, [None, None])
        # True where the token id is non-zero. `not_equal` already yields a
        # boolean tensor, so no additional cast is required.
        token_mask = tf.compat.v1.math.not_equal(self.X, 0)
        self.Y = tf.compat.v1.placeholder(tf.compat.v1.int32, [None, None])
        self.maxlen = tf.compat.v1.shape(self.X)[1]
        # NOTE(review): this equals the real sequence length only if id 0 occurs
        # solely as trailing padding — confirm against the data pipeline.
        self.lengths = tf.compat.v1.count_nonzero(self.X, 1)

        self.model = FastTransformer(
            num_tokens = 32000,
            dim = 336,
            depth = 4,
            heads = 12,
            max_seq_len = 2048,
            absolute_pos_emb = True,
            mask = token_mask
        )
        self.logits = self.model(self.X)[0]
        # Per-token scores over the tag classes, fed to the CRF layer.
        logits = tf.compat.v1.layers.dense(self.logits, dimension_output,
                                         kernel_initializer=create_initializer())

        # Y is already an int32 placeholder, so it is used directly.
        log_likelihood, transition_params = tf.compat.v1.contrib.crf.crf_log_likelihood(
            logits, self.Y, self.lengths
        )
        self.cost = tf.compat.v1.reduce_mean(-log_likelihood)
        self.optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate = learning_rate
        ).minimize(self.cost)
        # Mask selecting the first `lengths[i]` positions of every row.
        length_mask = tf.compat.v1.sequence_mask(self.lengths, maxlen = self.maxlen)
        self.tags_seq, _ = tf.compat.v1.contrib.crf.crf_decode(
            logits, transition_params, self.lengths
        )
        # The decoded tag ids are exposed under the graph name 'logits'; later
        # cells look the tensor up by that name.
        self.tags_seq = tf.compat.v1.identity(self.tags_seq, name = 'logits')

        self.prediction = tf.compat.v1.boolean_mask(self.tags_seq, length_mask)
        mask_label = tf.compat.v1.boolean_mask(self.Y, length_mask)
        correct_pred = tf.compat.v1.equal(self.prediction, mask_label)
        correct_index = tf.compat.v1.cast(correct_pred, tf.compat.v1.float32)
        self.accuracy = tf.compat.v1.reduce_mean(correct_index)

In [11]:
# Entity tags together with a short description; only the 'Tag' values are
# used to build the label vocabulary below.
_tag_catalogue = [
    {'Tag': 'OTHER', 'Description': 'other'},
    {'Tag': 'ADDRESS', 'Description': 'Address of physical location.'},
    {'Tag': 'PERSON', 'Description': 'People, including fictional.'},
    {'Tag': 'NORP', 'Description': 'Nationalities or religious or political groups.'},
    {'Tag': 'FAC', 'Description': 'Buildings, airports, highways, bridges, etc.'},
    {'Tag': 'ORG', 'Description': 'Companies, agencies, institutions, etc.'},
    {'Tag': 'GPE', 'Description': 'Countries, cities, states.'},
    {'Tag': 'LOC', 'Description': 'Non-GPE locations, mountain ranges, bodies of water.'},
    {'Tag': 'PRODUCT', 'Description': 'Objects, vehicles, foods, etc. (Not services.)'},
    {'Tag': 'EVENT', 'Description': 'Named hurricanes, battles, wars, sports events, etc.'},
    {'Tag': 'WORK_OF_ART', 'Description': 'Titles of books, songs, etc.'},
    {'Tag': 'LAW', 'Description': 'Named documents made into laws.'},
    {'Tag': 'LANGUAGE', 'Description': 'Any named language.'},
    {'Tag': 'DATE', 'Description': 'Absolute or relative dates or periods.'},
    {'Tag': 'TIME', 'Description': 'Times smaller than a day.'},
    {'Tag': 'PERCENT', 'Description': 'Percentage, including "%".'},
    {'Tag': 'MONEY', 'Description': 'Monetary values, including unit.'},
    {'Tag': 'QUANTITY', 'Description': 'Measurements, as of weight or distance.'},
    {'Tag': 'ORDINAL', 'Description': '"first", "second", etc.'},
    {'Tag': 'CARDINAL', 'Description': 'Numerals that do not fall under another type.'},
]

# Label vocabulary: the two special labels come first, then the entity tags in
# catalogue order. `d` ends up as the ordered list of label names.
d = ['PAD', 'X'] + [entry['Tag'] for entry in _tag_catalogue]
# Lookup tables in both directions between label name and integer id.
tag2idx = {label: position for position, label in enumerate(d)}
idx2tag = dict(enumerate(d))

In [12]:
# One output class per entry in the label vocabulary.
dimension_output = len(tag2idx)
learning_rate = 2e-5

# Start from an empty default graph so re-running this cell does not stack a
# second copy of the model onto the previous one.
tf.compat.v1.reset_default_graph()
sess = tf.compat.v1.InteractiveSession()
model = Model(
    dimension_output,
    learning_rate
)

# Initialise every variable; pretrained weights are restored over these values
# in a later cell.
sess.run(tf.compat.v1.global_variables_initializer())
var_lists = tf.compat.v1.trainable_variables()

Instructions for updating:
reduction_indices is deprecated, use axis instead
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Instructions for updating:
Use tf.compat.v1.where in 2.0, which has the same broadcast rule as np.where
Tensor("fast_transformer/pre_norm/fast_attention/Select:0", shape=(?, 12, ?), dtype=float32)
Tensor("fast_transformer/pre_norm/fast_attention/Select:0", shape=(?, 12, ?), dtype=float32)
Tensor("fast_transformer/pre_norm_2/fast_attention_1/Select:0", shape=(?, 12, ?), dtype=float32)
Tensor("fast_transformer/pre_norm_2/fast_attention_1/Select:0", shape=(?, 12, ?), dtype=float32)
Tensor("fast_transformer/pre_norm_4/fast_attention_2/Select:0", shape=(?, 12, ?), dtype=float32)
Tensor("fast_transformer/pre_norm_4/fast_attention_2/Select:0", shape=(?, 12, ?), dtype=float32)
Tensor("fast_transformer/pre_norm_6/fast_atten

In [13]:
import collections
import re

def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """Match checkpoint variables to the given graph variables by name.

    Only names present in BOTH `tvars` and the checkpoint are kept (an
    intersection); variables found on one side only are skipped.

    Parameters
    ----------
    tvars: iterable of variables
        Graph variables; each must expose a `.name`. A trailing ':<digits>'
        suffix is stripped before matching.
    init_checkpoint: str
        Checkpoint path passed to `tf.compat.v1.train.list_variables`.

    Returns
    -------
    (assignment_map, initialized_variable_names)
        `assignment_map` is an OrderedDict from checkpoint variable name to the
        matching variable, in checkpoint listing order.
        `initialized_variable_names` is a dict whose keys are each matched name
        and that name with ':0' appended; every value is 1.
    """
    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        m = re.match('^(.*):\\d+$', name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var

    assignment_map = collections.OrderedDict()
    initialized_variable_names = {}
    for entry in tf.compat.v1.train.list_variables(init_checkpoint):
        # Only the first item of each listed entry (the variable name) is needed.
        name = entry[0]
        if name not in name_to_variable:
            continue
        assignment_map[name] = name_to_variable[name]
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ':0'] = 1

    return (assignment_map, initialized_variable_names)

In [14]:
# Work out which trainable variables of the freshly built graph also exist in
# the pretrained checkpoint.
tvars = tf.compat.v1.trainable_variables()
checkpoint = 'fastformer-tiny/model.ckpt-500000'
assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, 
                                                                                checkpoint)

In [15]:
# Restore only the variables listed in `assignment_map`; variables that are not
# in the map keep the values given by the global initializer.
saver = tf.compat.v1.train.Saver(var_list = assignment_map)
saver.restore(sess, checkpoint)

INFO:tensorflow:Restoring parameters from fastformer-tiny/model.ckpt-500000


In [16]:
# Short alias for the Keras padding helper used when batching below.
pad_sequences = tf.compat.v1.keras.preprocessing.sequence.pad_sequences

In [17]:
# Sample text used throughout the notebook as a quick qualitative check of the tagger.
string = 'KUALA LUMPUR: Sempena sambutan Aidilfitri minggu depan, Perdana Menteri Tun Dr Mahathir Mohamad dan Menteri Pengangkutan Anthony Loke Siew Fook menitipkan pesanan khas kepada orang ramai yang mahu pulang ke kampung halaman masing-masing. Dalam video pendek terbitan Jabatan Keselamatan Jalan Raya (JKJR) itu, Dr Mahathir menasihati mereka supaya berhenti berehat dan tidur sebentar  sekiranya mengantuk ketika memandu.'

import re

def entities_textcleaning(string, lowering = False):
    """
    Clean raw text into whitespace-separated words.

    Used by entity recognition, POS recognition and dependency parsing.

    Every run of characters other than ASCII letters, digits, '-', '/', '(',
    ')' and space is replaced by a single space, repeated spaces are collapsed
    and the text is stripped before splitting.

    Parameters
    ----------
    string: str
        Raw input text.
    lowering: bool
        When True the cleaned text is lower-cased before the second list is
        built (so no word is upper-case any more and none gets title-cased).

    Returns
    -------
    (original_words, processed_words): two lists of equal length.
        `original_words` are the cleaned words with their original casing;
        `processed_words` are the same words with fully upper-case words
        converted to title case (or all lower-cased when `lowering` is True).
    """
    # Raw string literal: the pattern is unchanged, but '\-' and '\/' no longer
    # trigger Python's invalid-escape-sequence warning.
    string = re.sub(r'[^A-Za-z0-9\-\/() ]+', ' ', string)
    string = re.sub(r'[ ]+', ' ', string).strip()
    original_words = string.split()
    if lowering:
        string = string.lower()
    # str.split() never yields empty strings, so no length filter is required.
    processed_words = [
        word.title() if word.isupper() else word for word in string.split()
    ]
    return original_words, processed_words

def parse_X(left):
    """Tokenize a list of words into one '[CLS]' ... '[SEP]' wordpiece sequence.

    Returns a 3-tuple: the ids produced by `tokenizer.convert_tokens_to_ids`,
    the list of token strings, and a list of 1s with one entry per token.
    """
    pieces = ['[CLS]']
    for word in left:
        pieces += tokenizer.tokenize(word)
    pieces += ["[SEP]"]
    ones = [1] * len(pieces)
    token_ids = tokenizer.convert_tokens_to_ids(pieces)
    return token_ids, pieces, ones

# Use the second list returned by the cleaner (fully upper-case words
# title-cased) and turn it into wordpiece ids for the model.
sequence = entities_textcleaning(string)[1]
parsed_sequence, bert_sequence, input_mask = parse_X(sequence)

In [18]:
# Decode the sample as a batch of one; [0] takes the tag ids of that single row.
predicted = sess.run(model.tags_seq,
                feed_dict = {
                    model.X: [parsed_sequence],
                },
        )[0]

In [19]:
def merge_wordpiece_tokens_tagging(x, y):
    """Merge wordpiece tokens back into words, keeping one label per word.

    A token starting with '##' is a continuation piece: it is appended (with
    every '##' removed) to the token before it, and the merged word keeps the
    label of its first piece. After merging, entries whose token is exactly
    '[CLS]', '[SEP]' or '[PAD]' are dropped.

    Unlike the index-driven version this replaces, a sequence that ends with a
    continuation piece is handled, and a continuation piece with nothing
    before it is kept as its own token instead of raising.

    Parameters
    ----------
    x: sequence of str
        Wordpiece tokens.
    y: sequence
        Labels aligned with `x`; must be at least as long as `x`
        (otherwise an IndexError is raised).

    Returns
    -------
    (words, labels): two lists of equal length.
    """
    rejected = ('[CLS]', '[SEP]', '[PAD]')
    merged_pairs = []

    for position, token in enumerate(x):
        label = y[position]
        if token.startswith('##') and merged_pairs:
            # Glue the piece onto the previous token; the first piece's label wins.
            previous_token, previous_label = merged_pairs[-1]
            merged_pairs[-1] = (
                previous_token + token.replace('##', ''),
                previous_label,
            )
        else:
            merged_pairs.append((token, label))

    kept = [pair for pair in merged_pairs if pair[0] not in rejected]
    words = [pair[0] for pair in kept]
    labels = [pair[1] for pair in kept]
    return words, labels

In [20]:
# Map predicted tag ids to tag names, merge wordpieces back into words and show
# the (word, tag) pairs.
merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
list(zip(merged[0], merged[1]))

[('Kuala', 'CARDINAL'),
 ('Lumpur', 'CARDINAL'),
 ('Sempena', 'CARDINAL'),
 ('sambutan', 'CARDINAL'),
 ('Aidilfitri', 'NORP'),
 ('minggu', 'FAC'),
 ('depan', 'EVENT'),
 ('Perdana', 'ORG'),
 ('Menteri', 'ORG'),
 ('Tun', 'ORG'),
 ('Dr', 'CARDINAL'),
 ('Mahathir', 'FAC'),
 ('Mohamad', 'PRODUCT'),
 ('dan', 'PRODUCT'),
 ('Menteri', 'ADDRESS'),
 ('Pengangkutan', 'LOC'),
 ('Anthony', 'FAC'),
 ('Loke', 'CARDINAL'),
 ('Siew', 'CARDINAL'),
 ('Fook', 'MONEY'),
 ('menitipkan', 'ORDINAL'),
 ('pesanan', 'PERSON'),
 ('khas', 'FAC'),
 ('kepada', 'CARDINAL'),
 ('orang', 'FAC'),
 ('ramai', 'PERSON'),
 ('yang', 'FAC'),
 ('mahu', 'PERSON'),
 ('pulang', 'FAC'),
 ('ke', 'PERSON'),
 ('kampung', 'FAC'),
 ('halaman', 'PERSON'),
 ('masing', 'FAC'),
 ('-', 'ADDRESS'),
 ('masing', 'FAC'),
 ('Dalam', 'CARDINAL'),
 ('video', 'QUANTITY'),
 ('pendek', 'CARDINAL'),
 ('terbitan', 'QUANTITY'),
 ('Jabatan', 'QUANTITY'),
 ('Keselamatan', 'CARDINAL'),
 ('Jalan', 'CARDINAL'),
 ('Raya', 'CARDINAL'),
 ('(', 'CARDINAL'),
 ('Jk

In [34]:
from tqdm import tqdm
import time

# Fine-tune for two passes over the training set. After each pass the test set
# is evaluated and the sample sentence is re-tagged as a qualitative check.
# NOTE(review): the pass count is the literal 2, not the `epoch` variable defined earlier.
for e in range(2):
    lasttime = time.time()
    # Per-batch metrics for this pass; replaced by their means further down.
    train_acc, train_loss, test_acc, test_loss = [], [], [], []
    pbar = tqdm(
        range(0, len(train_X), batch_size), desc = 'train minibatch loop'
    )
    for i in pbar:
        index = min(i + batch_size, len(train_X))
        batch_x = train_X[i : index]
        batch_y = train_Y[i : index]
        # Pad at the end of each sequence so the batch becomes rectangular.
        batch_x = pad_sequences(batch_x, padding='post')
        batch_y = pad_sequences(batch_y, padding='post')
        
        # Running model.optimizer applies one Adam update for this batch.
        acc, cost, _ = sess.run(
            [model.accuracy, model.cost, model.optimizer],
            feed_dict = {
                model.X: batch_x,
                model.Y: batch_y,
            },
        )
        # Stop immediately if the loss has become NaN.
        assert not np.isnan(cost)
        train_loss.append(cost)
        train_acc.append(acc)
        pbar.set_postfix(cost = cost, accuracy = acc)
    
    pbar = tqdm(
        range(0, len(test_X), batch_size), desc = 'test minibatch loop'
    )
    for i in pbar:
        index = min(i + batch_size, len(test_X))
        batch_x = test_X[i : index]
        batch_y = test_Y[i : index]
        batch_x = pad_sequences(batch_x, padding='post')
        batch_y = pad_sequences(batch_y, padding='post')
        
        # Evaluation only: the optimizer op is not run here.
        acc, cost = sess.run(
            [model.accuracy, model.cost],
            feed_dict = {
                model.X: batch_x,
                model.Y: batch_y,
            },
        )
        assert not np.isnan(cost)
        test_loss.append(cost)
        test_acc.append(acc)
        pbar.set_postfix(cost = cost, accuracy = acc)
    
    # Unweighted mean over batches (a smaller final batch counts as much as a full one).
    train_loss = np.mean(train_loss)
    train_acc = np.mean(train_acc)
    test_loss = np.mean(test_loss)
    test_acc = np.mean(test_acc)

    print('time taken:', time.time() - lasttime)
    print(
        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\n'
        % (e, train_loss, train_acc, test_loss, test_acc)
    )
    # Re-tag the sample sentence with the current weights.
    predicted = sess.run(model.tags_seq,
                feed_dict = {
                    model.X: [parsed_sequence],
                },
        )[0]
    merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
    print(list(zip(merged[0], merged[1])))

train minibatch loop: 100%|██████████| 17853/17853 [50:24<00:00,  5.90it/s, accuracy=0.999, cost=0.587]
test minibatch loop: 100%|██████████| 2464/2464 [05:15<00:00,  7.82it/s, accuracy=0.943, cost=21.3]
train minibatch loop:   0%|          | 0/17853 [00:00<?, ?it/s]

time taken: 3340.075399160385
epoch: 0, training loss: 7.648261, training acc: 0.964309, valid loss: 37.886143, valid acc: 0.924723

[('Kuala', 'PERSON'), ('Lumpur', 'PERSON'), ('Sempena', 'PERSON'), ('sambutan', 'OTHER'), ('Aidilfitri', 'PERSON'), ('minggu', 'OTHER'), ('depan', 'OTHER'), ('Perdana', 'OTHER'), ('Menteri', 'OTHER'), ('Tun', 'PERSON'), ('Dr', 'PERSON'), ('Mahathir', 'PERSON'), ('Mohamad', 'PERSON'), ('dan', 'OTHER'), ('Menteri', 'OTHER'), ('Pengangkutan', 'PERSON'), ('Anthony', 'PERSON'), ('Loke', 'PERSON'), ('Siew', 'PERSON'), ('Fook', 'ADDRESS'), ('menitipkan', 'OTHER'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing', 'OTHER'), ('-', 'X'), ('masing', 'X'), ('Dalam', 'ADDRESS'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'ORG'), ('Keselamatan', 'ORG'), ('J

train minibatch loop:   8%|▊         | 1428/17853 [03:43<42:53,  6.38it/s, accuracy=0.936, cost=18.3]


KeyboardInterrupt: 

In [22]:
# Tag the sample sentence again with the weights currently held by the session.
predicted = sess.run(model.tags_seq,
            feed_dict = {
                model.X: [parsed_sequence],
            },
    )[0]
merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
print(list(zip(merged[0], merged[1])))

[('Kuala', 'OTHER'), ('Lumpur', 'PERSON'), ('Sempena', 'OTHER'), ('sambutan', 'OTHER'), ('Aidilfitri', 'PERSON'), ('minggu', 'DATE'), ('depan', 'OTHER'), ('Perdana', 'OTHER'), ('Menteri', 'OTHER'), ('Tun', 'PERSON'), ('Dr', 'PERSON'), ('Mahathir', 'PERSON'), ('Mohamad', 'PERSON'), ('dan', 'OTHER'), ('Menteri', 'OTHER'), ('Pengangkutan', 'OTHER'), ('Anthony', 'PERSON'), ('Loke', 'PERSON'), ('Siew', 'PERSON'), ('Fook', 'PERSON'), ('menitipkan', 'OTHER'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing', 'OTHER'), ('-', 'X'), ('masing', 'X'), ('Dalam', 'ADDRESS'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'ORG'), ('Keselamatan', 'ORG'), ('Jalan', 'X'), ('Raya', 'ORG'), ('(', 'X'), ('Jkjr', 'PERSON'), (')', 'X'), ('itu', 'OTHER'), ('Dr', 'PERSON'), ('Mahathir', 'PERSON'), ('me

In [23]:
# Save only the trainable variables; this also rebinds `saver`, replacing the
# restore-only saver created earlier.
saver = tf.compat.v1.train.Saver(tf.compat.v1.trainable_variables())
saver.save(sess, 'fastformer-tiny-ontonotes5/model.ckpt')

'fastformer-tiny-ontonotes5/model.ckpt'

In [24]:
def pred2label(pred):
    """Translate nested sequences of tag ids into nested lists of tag names.

    Every id is looked up in the module-level `idx2tag` mapping; the result is
    a list holding one list of names per input row.
    """
    return [[idx2tag[tag_id] for tag_id in row] for row in pred]

In [25]:
# Collect gold and predicted tag names for the whole test set, batch by batch.
real_Y, predict_Y = [], []

pbar = tqdm(
    range(0, len(test_X), batch_size), desc = 'validation minibatch loop'
)
for i in pbar:
    index = min(i + batch_size, len(test_X))
    batch_x = test_X[i : index]
    batch_y = test_Y[i : index]
    batch_x = pad_sequences(batch_x, padding='post')
    batch_y = pad_sequences(batch_y, padding='post')
    predicted = pred2label(sess.run(model.tags_seq,
            feed_dict = {
                model.X: batch_x,
            },
    ))
    # Both sides are converted over the full padded arrays, so padded positions
    # are included in the collected labels.
    real = pred2label(batch_y)
    predict_Y.extend(predicted)
    real_Y.extend(real)

validation minibatch loop: 100%|██████████| 2464/2464 [05:22<00:00,  7.65it/s]


In [26]:
# Flatten the per-sentence tag lists into one flat list each, keeping order.
temp_real_Y = [tag for sentence_tags in real_Y for tag in sentence_tags]

temp_predict_Y = [tag for sentence_tags in predict_Y for tag in sentence_tags]

In [27]:
# Sanity check: both flattened lists should contain the same number of tags.
len(temp_real_Y), len(temp_predict_Y)

(6926122, 6926122)

In [28]:
from sklearn.metrics import classification_report
# Per-tag precision / recall / F1 over the flattened test labels, printed with 5 decimals.
print(classification_report(temp_real_Y, temp_predict_Y, digits = 5))

              precision    recall  f1-score   support

     ADDRESS    0.39247   0.99574   0.56303     93446
    CARDINAL    0.71784   0.63014   0.67114     43087
        DATE    0.81493   0.75469   0.78365    111322
       EVENT    0.65659   0.57140   0.61104      5434
         FAC    0.46067   0.15601   0.23308     26691
         GPE    0.71550   0.67235   0.69326     92188
    LANGUAGE    0.61130   0.46999   0.53141       783
         LAW    0.66443   0.04905   0.09136     24261
         LOC    0.63321   0.56149   0.59520     33418
       MONEY    0.68559   0.65366   0.66925     25790
        NORP    0.63896   0.35582   0.45710     54553
     ORDINAL    0.62671   0.69155   0.65754      5693
         ORG    0.76243   0.57854   0.65788    194233
       OTHER    0.96120   0.96728   0.96423   3180488
         PAD    1.00000   1.00000   1.00000   1332067
     PERCENT    0.79587   0.74474   0.76945     18150
      PERSON    0.75476   0.71702   0.73541     92888
     PRODUCT    0.36864   0

In [29]:
# Build the comma-separated list of graph node names passed to freeze_graph as
# output nodes: variable ops plus nodes whose name contains one of the listed
# substrings, excluding optimizer and assign/read bookkeeping nodes.
strings = ','.join(
    [
        n.name
        for n in tf.compat.v1.get_default_graph().as_graph_def().node
        if ('Variable' in n.op
        or 'Placeholder' in n.name
        or 'logits' in n.name
        or 'alphas' in n.name
        or 'self/Softmax' in n.name)
        and 'adam' not in n.name
        and 'beta' not in n.name
        and 'global_step' not in n.name
        and 'ReadVariableOp' not in n.name
        and 'AssignVariableOp' not in n.name
        and '/Assign' not in n.name
        and '/Adam' not in n.name
    ]
)
strings.split(',')

['Placeholder',
 'Placeholder_1',
 'dense/kernel',
 'dense/bias',
 'transitions',
 'logits']

In [30]:
def freeze_graph(model_dir, output_node_names):
    """Freeze the latest checkpoint in `model_dir` into 'frozen_model.pb'.

    The checkpoint's meta graph is imported into a fresh graph, its variables
    are restored and converted to constants, and the resulting GraphDef is
    written next to the checkpoint files.

    Parameters
    ----------
    model_dir: str
        Directory that holds the checkpoint state.
    output_node_names: str
        Comma-separated node names to keep as outputs of the frozen graph.

    Raises
    ------
    AssertionError
        If `model_dir` does not exist or no checkpoint state is found in it.
    """
    if not tf.compat.v1.io.gfile.exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            'directory: %s' % model_dir
        )

    checkpoint = tf.compat.v1.train.get_checkpoint_state(model_dir)
    # get_checkpoint_state returns None when the directory has no checkpoint
    # state; fail with a clear message instead of an AttributeError.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError(
            'No checkpoint found in directory: %s' % model_dir
        )
    input_checkpoint = checkpoint.model_checkpoint_path

    # os.path keeps a checkpoint path without a directory part relative,
    # instead of producing '/frozen_model.pb' at the filesystem root.
    output_graph = os.path.join(
        os.path.dirname(input_checkpoint), 'frozen_model.pb'
    )
    clear_devices = True
    with tf.compat.v1.Session(graph = tf.compat.v1.Graph()) as sess:
        saver = tf.compat.v1.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices = clear_devices
        )
        saver.restore(sess, input_checkpoint)
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess,
            tf.compat.v1.get_default_graph().as_graph_def(),
            output_node_names.split(','),
        )
        with tf.compat.v1.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))

In [31]:
# Freeze the fine-tuned checkpoint, keeping the node names collected in `strings`.
freeze_graph('fastformer-tiny-ontonotes5', strings)

INFO:tensorflow:Restoring parameters from fastformer-tiny-ontonotes5/model.ckpt
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
INFO:tensorflow:Froze 59 variables.
INFO:tensorflow:Converted 59 variables to const ops.
2592 ops in the final graph.


In [32]:
def load_graph(frozen_graph_filename):
    """Read a serialized GraphDef from disk and import it into a new Graph.

    The nodes are imported with `import_graph_def`'s default arguments, and the
    newly created graph is returned.
    """
    graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, 'rb') as f:
        serialized = f.read()
    graph_def.ParseFromString(serialized)

    graph = tf.compat.v1.Graph()
    with graph.as_default():
        tf.compat.v1.import_graph_def(graph_def)
    return graph

In [33]:
# Load the frozen graph and look up its input and output tensors; imported
# nodes are addressed with the 'import/' prefix.
g = load_graph('fastformer-tiny-ontonotes5/frozen_model.pb')
x = g.get_tensor_by_name('import/Placeholder:0')
# NOTE: despite its name, the 'logits' node holds the decoded tag ids.
logits = g.get_tensor_by_name('import/logits:0')
test_sess = tf.compat.v1.InteractiveSession(graph = g)

In [33]:
%%time

# Tag the sample sentence with the frozen graph.
predicted = test_sess.run(logits,
            feed_dict = {
                x: [parsed_sequence],
            },
    )[0]
merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
print(list(zip(merged[0], merged[1])))

[('Kuala', 'NORP'), ('Lumpur', 'PERSON'), ('Sempena', 'PERSON'), ('sambutan', 'OTHER'), ('Aidilfitri', 'PERSON'), ('minggu', 'OTHER'), ('depan', 'OTHER'), ('Perdana', 'OTHER'), ('Menteri', 'OTHER'), ('Tun', 'PERSON'), ('Dr', 'PERSON'), ('Mahathir', 'PERSON'), ('Mohamad', 'PERSON'), ('dan', 'OTHER'), ('Menteri', 'OTHER'), ('Pengangkutan', 'OTHER'), ('Anthony', 'PERSON'), ('Loke', 'PERSON'), ('Siew', 'FAC'), ('Fook', 'OTHER'), ('menitipkan', 'OTHER'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing', 'OTHER'), ('-', 'X'), ('masing', 'X'), ('Dalam', 'ADDRESS'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'ADDRESS'), ('Keselamatan', 'ORG'), ('Jalan', 'ADDRESS'), ('Raya', 'ADDRESS'), ('(', 'ADDRESS'), ('Jkjr', 'PERSON'), (')', 'X'), ('itu', 'OTHER'), ('Dr', 'PERSON'), ('Mahathir'

In [35]:
from tensorflow.tools.graph_transforms import TransformGraph

In [36]:
# Graph-transform passes applied, in this order, to the frozen graph.
transforms = ['add_default_attributes',
             'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',
             'fold_batch_norms',
             'fold_old_batch_norms',
             'quantize_weights(fallback_min=-10, fallback_max=10)',
             'strip_unused_nodes',
             'sort_by_execution_order']

# Graph input and output node names handed to TransformGraph.
input_nodes = [
    'Placeholder',
]
output_nodes = [
    'logits',
]

pb = 'fastformer-tiny-ontonotes5/frozen_model.pb'

input_graph_def = tf.compat.v1.GraphDef()
# GFile replaces the deprecated FastGFile (TensorFlow's deprecation notice
# asks for tf.compat.v1.gfile.GFile).
with tf.compat.v1.gfile.GFile(pb, 'rb') as f:
    input_graph_def.ParseFromString(f.read())

transformed_graph_def = TransformGraph(input_graph_def, 
                                           input_nodes,
                                           output_nodes, transforms)
    
# The transformed graph is written next to the original with a '.quantized' suffix.
with tf.compat.v1.gfile.GFile(f'{pb}.quantized', 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())

Instructions for updating:
Use tf.compat.v1.gfile.GFile.


In [37]:
# g = load_graph('fastformer-tiny-ontonotes5/frozen_model.pb.quantized')
# x = g.get_tensor_by_name('import/Placeholder:0')
# logits = g.get_tensor_by_name('import/logits:0')
# test_sess = tf.compat.v1.InteractiveSession(graph = g)

In [38]:
# %%time

# predicted = test_sess.run(logits,
#             feed_dict = {
#                 x: [parsed_sequence],
#             },
#     )[0]
# merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
# print(list(zip(merged[0], merged[1])))

In [40]:
# Upload the frozen model to the bucket.
# NOTE(review): `b2_bucket` and `file_info` are not defined in any cell visible
# here — they presumably come from a removed or unseen setup cell; confirm
# before re-running on a fresh kernel.
file = 'fastformer-tiny-ontonotes5/frozen_model.pb'
outPutname = 'entity-ontonotes5/tiny-fastformer/model.pb'
b2_bucket.upload_local_file(
    local_file=file,
    file_name=outPutname,
    file_infos=file_info,
)

<b2sdk.file_version.FileVersionInfo at 0x7f51941fd7f0>

In [41]:
# Upload the quantized model to the bucket.
# NOTE(review): `b2_bucket` and `file_info` are not defined in any cell visible
# here — confirm where they are created before re-running on a fresh kernel.
file = 'fastformer-tiny-ontonotes5/frozen_model.pb.quantized'
outPutname = 'entity-ontonotes5/tiny-fastformer-quantized/model.pb'
b2_bucket.upload_local_file(
    local_file=file,
    file_name=outPutname,
    file_infos=file_info,
)

<b2sdk.file_version.FileVersionInfo at 0x7f51443b2d30>