In [1]:
import os
# Expose only CUDA device index 1 to this process. The variable is set here,
# before TensorFlow is imported in the next cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
import tensorflow as tf
import numpy as np
from rotary_embedding_tensorflow import apply_rotary_emb, RotaryEmbedding
from fast_transformer import FastTransformer

In [3]:
from malaya.text.bpe import WordPieceTokenizer

In [4]:
# WordPiece tokenizer built from the local 'BERT.wordpiece' file (presumably the
# vocabulary file -- confirm), with lower-casing disabled.
tokenizer = WordPieceTokenizer('BERT.wordpiece', do_lower_case = False)
# Example call, kept disabled: tokenizer.tokenize('halo nama sayacomel')

In [5]:
import pickle

# Load the prepared dataset; the pickle must unpack into exactly four values.
# NOTE(review): pickle.load can execute arbitrary code while deserialising --
# only open this file if it comes from a trusted source.
with open('ontonotes5-fastformer.pkl', 'rb') as fopen:
    train_X, train_Y, test_X, test_Y = pickle.load(fopen)

In [6]:
# Fine-tuning hyper-parameters.
epoch = 10
batch_size = 32
warmup_proportion = 0.1
# Total number of optimizer steps for `epoch` passes over train_X, and the
# fraction of those steps reserved for warm-up.
# NOTE(review): in the cells visible here the training loop iterates over
# `range(1)` and Model uses a plain AdamOptimizer, so `epoch`,
# `num_train_steps` and `num_warmup_steps` do not appear to be consumed -- confirm.
num_train_steps = int(len(train_X) / batch_size * epoch)
num_warmup_steps = int(num_train_steps * warmup_proportion)

In [7]:
import optimization




In [8]:
def create_initializer(initializer_range=0.02):
    """Build a truncated-normal initializer.

    Args:
        initializer_range: value forwarded as ``stddev`` to
            ``tf.truncated_normal_initializer``.

    Returns:
        The initializer object created by TensorFlow.
    """
    initializer = tf.truncated_normal_initializer(stddev=initializer_range)
    return initializer

class Model:
    """Token-tagging graph: FastTransformer encoder -> dense projection -> CRF.

    Constructing an instance adds every op to the current default TensorFlow
    graph. Attributes set by the constructor: ``X``, ``Y``, ``maxlen``,
    ``lengths``, ``model``, ``logits``, ``cost``, ``optimizer``, ``tags_seq``,
    ``prediction`` and ``accuracy``.
    """
    def __init__(
        self,
        dimension_output,
        learning_rate = 2e-5,
        training = True,
    ):
        """Build the tagging graph.

        Args:
            dimension_output: number of units of the dense projection, i.e. the
                number of tag scores produced per position.
            learning_rate: learning rate handed to ``tf.train.AdamOptimizer``.
            training: accepted but not read anywhere in this constructor.
        """
        # Input ids; both dimensions are dynamic (presumably batch x time -- confirm).
        # NOTE: X is created before Y, so they receive the default node names
        # 'Placeholder' and 'Placeholder_1'; later cells look up
        # 'import/Placeholder:0' in the frozen graph, so keep this order.
        self.X = tf.compat.v1.placeholder(tf.int32, [None, None])
        # True wherever the input id is non-zero (0 is presumably the padding id).
        mask = tf.math.not_equal(self.X, 0)
        mask = tf.cast(mask, tf.bool)
        # Target tag ids, same dynamic shape convention as X.
        self.Y = tf.compat.v1.placeholder(tf.int32, [None, None])
        # Size of the second dimension of the fed batch.
        self.maxlen = tf.shape(self.X)[1]
        # Number of non-zero ids along axis 1; used below as the sequence length.
        self.lengths = tf.count_nonzero(self.X, 1)
        
        # Encoder; the non-zero mask built above is passed to its constructor.
        self.model = FastTransformer(
            num_tokens = 32000,
            dim = 768,
            depth = 12,
            heads = 12,
            max_seq_len = 2048,
            absolute_pos_emb = True,
            mask = mask
        )
        # First element of the encoder call result (presumably the per-position
        # hidden states -- confirm against FastTransformer).
        self.logits = self.model(self.X)[0]
        # Per-position tag scores. Note this local `logits` is distinct from
        # `self.logits` above.
        logits = tf.layers.dense(self.logits, dimension_output,
                                         kernel_initializer=create_initializer())
        
        y_t = self.Y
        # CRF log-likelihood of the gold tags; also yields the transition
        # parameters reused for decoding below.
        log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
            logits, y_t, self.lengths
        )
        # Loss: mean negative log-likelihood.
        self.cost = tf.reduce_mean(-log_likelihood)
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate = learning_rate
        ).minimize(self.cost)
        # Rebinds `mask`: from here on it is the length-based sequence mask,
        # used only for the accuracy computation.
        mask = tf.sequence_mask(self.lengths, maxlen = self.maxlen)
        # Decode tag ids with the CRF (tags_score is not used afterwards).
        self.tags_seq, tags_score = tf.contrib.crf.crf_decode(
            logits, transition_params, self.lengths
        )
        # The node named 'logits' therefore holds the DECODED TAG IDS, not raw
        # scores; the freezing/export cells select it by this name.
        self.tags_seq = tf.identity(self.tags_seq, name = 'logits')

        y_t = tf.cast(y_t, tf.int32)
        # Accuracy over the positions selected by the sequence mask.
        self.prediction = tf.boolean_mask(self.tags_seq, mask)
        mask_label = tf.boolean_mask(y_t, mask)
        correct_pred = tf.equal(self.prediction, mask_label)
        # NOTE(review): correct_index is not used after this assignment.
        correct_index = tf.cast(correct_pred, tf.float32)
        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

In [9]:
# Entity tag inventory; the descriptions are informational only and are dropped
# right after this literal.
d = [
    {'Tag': 'OTHER', 'Description': 'other'},
    {'Tag': 'ADDRESS', 'Description': 'Address of physical location.'},
    {'Tag': 'PERSON', 'Description': 'People, including fictional.'},
    {'Tag': 'NORP',
     'Description': 'Nationalities or religious or political groups.'},
    {'Tag': 'FAC',
     'Description': 'Buildings, airports, highways, bridges, etc.'},
    {'Tag': 'ORG', 'Description': 'Companies, agencies, institutions, etc.'},
    {'Tag': 'GPE', 'Description': 'Countries, cities, states.'},
    {'Tag': 'LOC',
     'Description': 'Non-GPE locations, mountain ranges, bodies of water.'},
    {'Tag': 'PRODUCT',
     'Description': 'Objects, vehicles, foods, etc. (Not services.)'},
    {'Tag': 'EVENT',
     'Description': 'Named hurricanes, battles, wars, sports events, etc.'},
    {'Tag': 'WORK_OF_ART', 'Description': 'Titles of books, songs, etc.'},
    {'Tag': 'LAW', 'Description': 'Named documents made into laws.'},
    {'Tag': 'LANGUAGE', 'Description': 'Any named language.'},
    {'Tag': 'DATE', 'Description': 'Absolute or relative dates or periods.'},
    {'Tag': 'TIME', 'Description': 'Times smaller than a day.'},
    {'Tag': 'PERCENT', 'Description': 'Percentage, including "%".'},
    {'Tag': 'MONEY', 'Description': 'Monetary values, including unit.'},
    {'Tag': 'QUANTITY',
     'Description': 'Measurements, as of weight or distance.'},
    {'Tag': 'ORDINAL', 'Description': '"first", "second", etc.'},
    {'Tag': 'CARDINAL',
     'Description': 'Numerals that do not fall under another type.'},
]
# Keep only the tag names; ids 0 and 1 go to the extra 'PAD' and 'X' tags.
d = ['PAD', 'X'] + [entry['Tag'] for entry in d]
# Lookup tables in both directions between tag name and integer id.
tag2idx = dict(zip(d, range(len(d))))
idx2tag = dict(enumerate(d))

In [10]:
# One output unit per entry of the tag vocabulary.
dimension_output = len(tag2idx)
learning_rate = 2e-5

# Start from an empty default graph, then open the session that the rest of
# the notebook uses and build the model into that graph.
tf.reset_default_graph()
sess = tf.InteractiveSession()
model = Model(
    dimension_output,
    learning_rate
)

# Initialise every variable; a later cell restores checkpoint values over the
# variables whose names match the checkpoint.
sess.run(tf.global_variables_initializer())
# NOTE(review): var_lists is not referenced by any later cell visible here.
var_lists = tf.trainable_variables()

Instructions for updating:
reduction_indices is deprecated, use axis instead
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Tensor("fast_transformer/pre_norm/fast_attention/Select:0", shape=(?, 12, ?), dtype=float32)
Tensor("fast_transformer/pre_norm/fast_attention/Select:0", shape=(?, 12, ?), dtype=float32)
Tensor("fast_transformer/pre_norm_2/fast_attention_1/Select:0", shape=(?, 12, ?), dtype=float32)
Tensor("fast_transformer/pre_norm_2/fast_attention_1/Select:0", shape=(?, 12, ?), dtype=float32)
Tensor("fast_transformer/pre_norm_4/fast_attention_2/Select:0", shape=(?, 12, ?), dtype=float32)
Tensor("fast_transformer/pre_norm_4/fast_attention_2/Select:0", shape=(?, 12, ?), dtype=float32)
Tensor("fast_transformer/pre_norm_6/fast_attention_3/Sel

In [11]:
import collections
import re

def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """Match graph variables to checkpoint variables by name.

    Only names that exist BOTH in ``tvars`` and in the checkpoint are kept,
    i.e. the result is the intersection of the two sets of names (not the
    union).

    Args:
        tvars: iterable of variables; each must expose a ``.name`` attribute.
            A trailing ``:<digits>`` suffix is stripped from the name before
            matching.
        init_checkpoint: checkpoint path handed to ``tf.train.list_variables``.

    Returns:
        A tuple ``(assignment_map, initialized_variable_names)``:
        ``assignment_map`` is an ``OrderedDict`` mapping checkpoint variable
        name to the matching variable, in checkpoint listing order;
        ``initialized_variable_names`` is a dict containing, for every matched
        name, both ``name`` and ``name + ':0'`` mapped to ``1``.
    """
    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        # Drop the ':<output index>' suffix so the name can be compared with
        # the names stored in the checkpoint.
        m = re.match('^(.*):\\d+$', name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var

    assignment_map = collections.OrderedDict()
    initialized_variable_names = {}
    for entry in tf.train.list_variables(init_checkpoint):
        # Each listing entry starts with the variable name; the remaining
        # fields are not needed here.
        name = entry[0]
        if name not in name_to_variable:
            continue
        assignment_map[name] = name_to_variable[name]
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ':0'] = 1

    return (assignment_map, initialized_variable_names)

In [12]:
# Pair the graph's trainable variables with the variables stored in the
# pretrained checkpoint; only names found in both are returned.
tvars = tf.trainable_variables()
checkpoint = 'fastformer-base/model.ckpt-500000'
assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, 
                                                                                checkpoint)

In [13]:
# Restore only the variables listed in assignment_map from the pretrained
# checkpoint; variables missing from the map keep their initialised values.
saver = tf.train.Saver(var_list = assignment_map)
saver.restore(sess, checkpoint)

INFO:tensorflow:Restoring parameters from fastformer-base/model.ckpt-500000


In [14]:
# Short alias for the Keras padding helper used by the batching loops below.
pad_sequences = tf.keras.preprocessing.sequence.pad_sequences

In [15]:
string = 'KUALA LUMPUR: Sempena sambutan Aidilfitri minggu depan, Perdana Menteri Tun Dr Mahathir Mohamad dan Menteri Pengangkutan Anthony Loke Siew Fook menitipkan pesanan khas kepada orang ramai yang mahu pulang ke kampung halaman masing-masing. Dalam video pendek terbitan Jabatan Keselamatan Jalan Raya (JKJR) itu, Dr Mahathir menasihati mereka supaya berhenti berehat dan tidur sebentar  sekiranya mengantuk ketika memandu.'

import re

def entities_textcleaning(string, lowering = False):
    """
    Clean raw text for entity recognition, POS tagging and dependency parsing.

    Every run of characters other than ASCII letters, digits, ``-``, ``/``,
    ``(``, ``)`` and space is replaced by a single space, repeated spaces are
    collapsed, and the text is split on whitespace.

    Args:
        string: raw input text.
        lowering: when true, the cleaned text is lower-cased before the
            per-word normalisation (so no word is upper-case any more and
            ``title()`` is never applied).

    Returns:
        A pair of equally long lists ``(original_words, normalised_words)``.
        ``original_words`` are the cleaned words with their original casing;
        in ``normalised_words`` every fully upper-case word is title-cased
        (e.g. ``'KUALA'`` -> ``'Kuala'``).
    """
    # Raw string: '\-' and '\/' are invalid escapes in a normal string literal
    # (SyntaxWarning on recent Python versions). The pattern itself is unchanged.
    string = re.sub(r'[^A-Za-z0-9\-\/() ]+', ' ', string)
    string = re.sub(r'[ ]+', ' ', string).strip()
    original_words = string.split()
    if lowering:
        string = string.lower()
    # str.split() never yields empty strings, so no extra filtering is needed,
    # and lower-casing cannot change the number of words.
    normalised_words = [
        word.title() if word.isupper() else word for word in string.split()
    ]
    return original_words, normalised_words

def parse_X(left):
    """Turn a list of words into model inputs.

    Each word is split with the module-level ``tokenizer``; the pieces are
    wrapped in ``[CLS]`` ... ``[SEP]``.

    Args:
        left: iterable of word strings.

    Returns:
        ``(token_ids, bert_tokens, input_mask)`` where ``token_ids`` is what
        ``tokenizer.convert_tokens_to_ids`` returns for ``bert_tokens`` and
        ``input_mask`` is a list of ``1`` with one entry per token.
    """
    pieces = [piece for word in left for piece in tokenizer.tokenize(word)]
    bert_tokens = ['[CLS]'] + pieces + ["[SEP]"]
    token_ids = tokenizer.convert_tokens_to_ids(bert_tokens)
    input_mask = [1 for _ in bert_tokens]
    return token_ids, bert_tokens, input_mask

# Use the normalised words (second element returned by the cleaner) and
# convert them into token ids, WordPiece tokens and an all-ones input mask.
sequence = entities_textcleaning(string)[1]
parsed_sequence, bert_sequence, input_mask = parse_X(sequence)

In [16]:
# Decode tags for the sample sentence (fed as a batch of one) before any
# fine-tuning step has run; [0] takes the single row of the result.
predicted = sess.run(model.tags_seq,
                feed_dict = {
                    model.X: [parsed_sequence],
                },
        )[0]

In [17]:
def merge_wordpiece_tokens_tagging(x, y):
    """Merge WordPiece pieces back into words, keeping one label per word.

    A token starting with ``##`` is a continuation piece: it is appended
    (with every ``##`` removed) to the token before it, and the merged word
    keeps the label of its FIRST piece. Afterwards the special tokens
    ``[CLS]``, ``[SEP]`` and ``[PAD]`` are dropped.

    Args:
        x: sequence of WordPiece token strings.
        y: sequence of labels, aligned with ``x``.

    Returns:
        ``(words, labels)``: two lists of equal length.

    Edge cases handled explicitly (both previously raised ``IndexError``):
        * the sequence ends with a ``##`` piece;
        * the sequence starts with a ``##`` piece, in which case the piece
          starts a new word and its own label is used.
    """
    rejected = ['[CLS]', '[SEP]', '[PAD]']
    n_tokens = len(x)
    new_paired_tokens = []

    i = 0
    while i < n_tokens:
        current_token, current_label = x[i], y[i]
        if current_token.startswith('##'):
            if new_paired_tokens:
                merged_token, merged_label = new_paired_tokens.pop()
            else:
                # Continuation piece with nothing before it: start a new word.
                merged_token, merged_label = '', current_label
            # Absorb the whole run of continuation pieces; the bounds check
            # stops cleanly when the run reaches the end of the sequence.
            while i < n_tokens and x[i].startswith('##'):
                merged_token = merged_token + x[i].replace('##', '')
                i = i + 1
            new_paired_tokens.append((merged_token, merged_label))
        else:
            new_paired_tokens.append((current_token, current_label))
            i = i + 1

    kept = [pair for pair in new_paired_tokens if pair[0] not in rejected]
    words = [token for token, _ in kept]
    labels = [label for _, label in kept]
    return words, labels

In [18]:
# Map predicted ids to tag names, merge WordPiece pieces into words and show
# (word, tag) pairs. These predictions come from the model before fine-tuning.
merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
list(zip(merged[0], merged[1]))

[('Kuala', 'EVENT'),
 ('Lumpur', 'OTHER'),
 ('Sempena', 'OTHER'),
 ('sambutan', 'OTHER'),
 ('Aidilfitri', 'OTHER'),
 ('minggu', 'OTHER'),
 ('depan', 'OTHER'),
 ('Perdana', 'OTHER'),
 ('Menteri', 'EVENT'),
 ('Tun', 'OTHER'),
 ('Dr', 'EVENT'),
 ('Mahathir', 'OTHER'),
 ('Mohamad', 'OTHER'),
 ('dan', 'OTHER'),
 ('Menteri', 'OTHER'),
 ('Pengangkutan', 'OTHER'),
 ('Anthony', 'OTHER'),
 ('Loke', 'OTHER'),
 ('Siew', 'OTHER'),
 ('Fook', 'OTHER'),
 ('menitipkan', 'OTHER'),
 ('pesanan', 'OTHER'),
 ('khas', 'OTHER'),
 ('kepada', 'OTHER'),
 ('orang', 'OTHER'),
 ('ramai', 'EVENT'),
 ('yang', 'OTHER'),
 ('mahu', 'OTHER'),
 ('pulang', 'OTHER'),
 ('ke', 'OTHER'),
 ('kampung', 'OTHER'),
 ('halaman', 'OTHER'),
 ('masing', 'GPE'),
 ('-', 'LANGUAGE'),
 ('masing', 'GPE'),
 ('Dalam', 'PERCENT'),
 ('video', 'OTHER'),
 ('pendek', 'OTHER'),
 ('terbitan', 'OTHER'),
 ('Jabatan', 'OTHER'),
 ('Keselamatan', 'OTHER'),
 ('Jalan', 'EVENT'),
 ('Raya', 'OTHER'),
 ('(', 'OTHER'),
 ('Jkjr', 'EVENT'),
 (')', 'WORK_OF_ART')

In [20]:
from tqdm import tqdm
import time

# Fine-tune, then evaluate on the test split, then print the tags predicted
# for the sample sentence.
# NOTE(review): the loop runs over range(1), i.e. a single pass, although
# `epoch = 10` is configured earlier -- confirm which one is intended.
for e in range(1):
    lasttime = time.time()
    train_acc, train_loss, test_acc, test_loss = [], [], [], []
    pbar = tqdm(
        range(0, len(train_X), batch_size), desc = 'train minibatch loop'
    )
    for i in pbar:
        # Slice one minibatch; min() clamps the final, possibly shorter, batch.
        index = min(i + batch_size, len(train_X))
        batch_x = train_X[i : index]
        batch_y = train_Y[i : index]
        # Post-pad inputs and labels to the longest sequence in the batch.
        batch_x = pad_sequences(batch_x, padding='post')
        batch_y = pad_sequences(batch_y, padding='post')
        
        # Fetching model.optimizer is what applies the gradient update.
        acc, cost, _ = sess.run(
            [model.accuracy, model.cost, model.optimizer],
            feed_dict = {
                model.X: batch_x,
                model.Y: batch_y,
            },
        )
        # Stop immediately if the loss becomes NaN.
        assert not np.isnan(cost)
        train_loss.append(cost)
        train_acc.append(acc)
        pbar.set_postfix(cost = cost, accuracy = acc)
    
    pbar = tqdm(
        range(0, len(test_X), batch_size), desc = 'test minibatch loop'
    )
    for i in pbar:
        index = min(i + batch_size, len(test_X))
        batch_x = test_X[i : index]
        batch_y = test_Y[i : index]
        batch_x = pad_sequences(batch_x, padding='post')
        batch_y = pad_sequences(batch_y, padding='post')
        
        # Evaluation only: the optimizer op is not fetched, so no weights change.
        acc, cost = sess.run(
            [model.accuracy, model.cost],
            feed_dict = {
                model.X: batch_x,
                model.Y: batch_y,
            },
        )
        assert not np.isnan(cost)
        test_loss.append(cost)
        test_acc.append(acc)
        pbar.set_postfix(cost = cost, accuracy = acc)
    
    # Plain means of the per-batch values: a shorter final batch weighs the
    # same as a full batch.
    train_loss = np.mean(train_loss)
    train_acc = np.mean(train_acc)
    test_loss = np.mean(test_loss)
    test_acc = np.mean(test_acc)

    print('time taken:', time.time() - lasttime)
    print(
        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\n'
        % (e, train_loss, train_acc, test_loss, test_acc)
    )
    # Qualitative check on the sample sentence after this pass.
    predicted = sess.run(model.tags_seq,
                feed_dict = {
                    model.X: [parsed_sequence],
                },
        )[0]
    merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
    print(list(zip(merged[0], merged[1])))

train minibatch loop: 100%|██████████| 17853/17853 [1:51:40<00:00,  2.66it/s, accuracy=1, cost=0.28]      
test minibatch loop: 100%|██████████| 2464/2464 [08:39<00:00,  4.74it/s, accuracy=0.964, cost=12.8]

time taken: 7220.066801071167
epoch: 0, training loss: 6.377372, training acc: 0.971428, valid loss: 35.847187, valid acc: 0.938606

[('Kuala', 'NORP'), ('Lumpur', 'PERSON'), ('Sempena', 'PERSON'), ('sambutan', 'OTHER'), ('Aidilfitri', 'PERSON'), ('minggu', 'OTHER'), ('depan', 'OTHER'), ('Perdana', 'OTHER'), ('Menteri', 'OTHER'), ('Tun', 'PERSON'), ('Dr', 'PERSON'), ('Mahathir', 'PERSON'), ('Mohamad', 'PERSON'), ('dan', 'OTHER'), ('Menteri', 'OTHER'), ('Pengangkutan', 'OTHER'), ('Anthony', 'PERSON'), ('Loke', 'PERSON'), ('Siew', 'FAC'), ('Fook', 'OTHER'), ('menitipkan', 'OTHER'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing', 'OTHER'), ('-', 'X'), ('masing', 'X'), ('Dalam', 'ADDRESS'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'ADDRESS'), ('Keselamatan', 'ORG'), ('Jalan




In [21]:
# Re-run the sample sentence through the fine-tuned model; this repeats the
# check printed at the end of the training cell.
predicted = sess.run(model.tags_seq,
            feed_dict = {
                model.X: [parsed_sequence],
            },
    )[0]
merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
print(list(zip(merged[0], merged[1])))

[('Kuala', 'NORP'), ('Lumpur', 'PERSON'), ('Sempena', 'PERSON'), ('sambutan', 'OTHER'), ('Aidilfitri', 'PERSON'), ('minggu', 'OTHER'), ('depan', 'OTHER'), ('Perdana', 'OTHER'), ('Menteri', 'OTHER'), ('Tun', 'PERSON'), ('Dr', 'PERSON'), ('Mahathir', 'PERSON'), ('Mohamad', 'PERSON'), ('dan', 'OTHER'), ('Menteri', 'OTHER'), ('Pengangkutan', 'OTHER'), ('Anthony', 'PERSON'), ('Loke', 'PERSON'), ('Siew', 'FAC'), ('Fook', 'OTHER'), ('menitipkan', 'OTHER'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing', 'OTHER'), ('-', 'X'), ('masing', 'X'), ('Dalam', 'ADDRESS'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'ADDRESS'), ('Keselamatan', 'ORG'), ('Jalan', 'ADDRESS'), ('Raya', 'ADDRESS'), ('(', 'ADDRESS'), ('Jkjr', 'PERSON'), (')', 'X'), ('itu', 'OTHER'), ('Dr', 'PERSON'), ('Mahathir'

In [22]:
# Save the trainable variables only to a new checkpoint. This rebinds `saver`,
# which previously pointed at the restore-only Saver.
saver = tf.train.Saver(tf.trainable_variables())
saver.save(sess, 'fastformer-base-ontonotes5/model.ckpt')

'fastformer-base-ontonotes5/model.ckpt'

In [23]:
def pred2label(pred):
    """Translate nested tag ids into tag names.

    Args:
        pred: iterable of rows, each an iterable of keys of ``idx2tag``.

    Returns:
        A list of lists with ``idx2tag[key]`` for every key, row by row.
    """
    return [[idx2tag[tag_id] for tag_id in row] for row in pred]

In [24]:
# Collect gold and predicted tag names for every test sequence.
real_Y, predict_Y = [], []

pbar = tqdm(
    range(0, len(test_X), batch_size), desc = 'validation minibatch loop'
)
for i in pbar:
    index = min(i + batch_size, len(test_X))
    batch_x = test_X[i : index]
    batch_y = test_Y[i : index]
    batch_x = pad_sequences(batch_x, padding='post')
    batch_y = pad_sequences(batch_y, padding='post')
    predicted = pred2label(sess.run(model.tags_seq,
            feed_dict = {
                model.X: batch_x,
            },
    ))
    # NOTE(review): pred2label is applied to the full padded arrays, so padded
    # positions are kept in both lists; they appear as the 'PAD' row of the
    # classification report computed later.
    real = pred2label(batch_y)
    predict_Y.extend(predicted)
    real_Y.extend(real)

validation minibatch loop: 100%|██████████| 2464/2464 [08:51<00:00,  4.64it/s]


In [25]:
# Flatten the per-sequence tag lists into one flat list each, as expected by
# the token-level classification report.
temp_real_Y = [tag for sequence_tags in real_Y for tag in sequence_tags]
temp_predict_Y = [tag for sequence_tags in predict_Y for tag in sequence_tags]

In [26]:
# Sanity check: both flat lists should contain the same number of tags.
len(temp_real_Y), len(temp_predict_Y)

(6926122, 6926122)

In [27]:
from sklearn.metrics import classification_report
# Per-tag precision / recall / F1 over all collected positions, printed with
# five decimals.
print(classification_report(temp_real_Y, temp_predict_Y, digits = 5))

              precision    recall  f1-score   support

     ADDRESS    0.40259   0.99782   0.57370     93446
    CARDINAL    0.79275   0.73096   0.76060     43087
        DATE    0.88270   0.81414   0.84704    111322
       EVENT    0.72272   0.66912   0.69489      5434
         FAC    0.67824   0.20194   0.31122     26691
         GPE    0.81544   0.80176   0.80854     92188
    LANGUAGE    0.73010   0.48020   0.57935       783
         LAW    0.76660   0.10181   0.17975     24261
         LOC    0.73178   0.66006   0.69407     33418
       MONEY    0.78543   0.72160   0.75216     25790
        NORP    0.78792   0.43103   0.55723     54553
     ORDINAL    0.67934   0.73757   0.70726      5693
         ORG    0.85413   0.66220   0.74602    194233
       OTHER    0.97412   0.97779   0.97595   3180488
         PAD    1.00000   1.00000   1.00000   1332067
     PERCENT    0.90844   0.85063   0.87859     18150
      PERSON    0.84472   0.79316   0.81813     92888
     PRODUCT    0.55112   0

In [28]:
# Comma-separated names of the graph nodes to keep when freezing: nodes whose
# op mentions 'Variable' or whose name mentions one of the wanted substrings,
# excluding anything whose name mentions an optimizer/bookkeeping substring.
strings = ','.join(
    n.name
    for n in tf.get_default_graph().as_graph_def().node
    if (
        'Variable' in n.op
        or any(
            wanted in n.name
            for wanted in ('Placeholder', 'logits', 'alphas', 'self/Softmax')
        )
    )
    and not any(
        unwanted in n.name
        for unwanted in (
            'adam',
            'beta',
            'global_step',
            'ReadVariableOp',
            'AssignVariableOp',
            '/Assign',
            '/Adam',
        )
    )
)
strings.split(',')

['Placeholder',
 'Placeholder_1',
 'dense/kernel',
 'dense/bias',
 'transitions',
 'logits']

In [29]:
def freeze_graph(model_dir, output_node_names):
    """Freeze the latest checkpoint of ``model_dir`` into ``frozen_model.pb``.

    The checkpoint's meta graph is imported into a fresh graph and session,
    the variables are restored and converted to constants, and the resulting
    GraphDef is written next to the checkpoint as ``frozen_model.pb``. The
    number of ops in the frozen graph is printed.

    Args:
        model_dir: directory that contains the checkpoint.
        output_node_names: comma-separated node names to keep as outputs.

    Raises:
        AssertionError: if ``model_dir`` does not exist, or if
            ``tf.train.get_checkpoint_state`` finds no checkpoint state in it
            (it returns ``None`` in that case, which previously surfaced as an
            opaque ``AttributeError``). AssertionError is used for the second
            case to match the existing check.
    """
    if not tf.io.gfile.exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            'directory: %s' % model_dir
        )

    checkpoint = tf.train.get_checkpoint_state(model_dir)
    if checkpoint is None:
        raise AssertionError(
            'No checkpoint state found in directory: %s' % model_dir
        )
    input_checkpoint = checkpoint.model_checkpoint_path

    # The frozen graph is written into the directory part of the checkpoint path.
    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/frozen_model.pb'

    # Work in a separate graph/session so the notebook's live graph is untouched.
    with tf.Session(graph = tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices = True
        )
        saver.restore(sess, input_checkpoint)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            output_node_names.split(','),
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))

In [30]:
# Freeze the fine-tuned checkpoint, keeping the nodes selected in `strings`.
freeze_graph('fastformer-base-ontonotes5', strings)

INFO:tensorflow:Restoring parameters from fastformer-base-ontonotes5/model.ckpt
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
INFO:tensorflow:Froze 163 variables.
INFO:tensorflow:Converted 163 variables to const ops.
7024 ops in the final graph.


In [31]:
def load_graph(frozen_graph_filename):
    """Load a serialized GraphDef from disk into a brand-new ``tf.Graph``.

    Args:
        frozen_graph_filename: path of the frozen ``.pb`` file.

    Returns:
        The new graph with the GraphDef imported. ``tf.import_graph_def`` is
        called without a ``name`` argument; the cells that use the result look
        tensors up under the ``import/`` prefix.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_graph_filename, 'rb') as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def)
    return graph

In [32]:
# Load the frozen graph and fetch its input placeholder and output tensor.
# Tensors are looked up under the 'import/' prefix added on import.
# NOTE: the node named 'logits' is the identity over the CRF-decoded tag ids
# (see Model), so it yields tag ids rather than raw scores.
g = load_graph('fastformer-base-ontonotes5/frozen_model.pb')
x = g.get_tensor_by_name('import/Placeholder:0')
logits = g.get_tensor_by_name('import/logits:0')
test_sess = tf.InteractiveSession(graph = g)

In [33]:
%%time

# Run the sample sentence through the frozen graph; the printed tags can be
# compared with the ones produced by the live model earlier.
predicted = test_sess.run(logits,
            feed_dict = {
                x: [parsed_sequence],
            },
    )[0]
merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
print(list(zip(merged[0], merged[1])))

[('Kuala', 'NORP'), ('Lumpur', 'PERSON'), ('Sempena', 'PERSON'), ('sambutan', 'OTHER'), ('Aidilfitri', 'PERSON'), ('minggu', 'OTHER'), ('depan', 'OTHER'), ('Perdana', 'OTHER'), ('Menteri', 'OTHER'), ('Tun', 'PERSON'), ('Dr', 'PERSON'), ('Mahathir', 'PERSON'), ('Mohamad', 'PERSON'), ('dan', 'OTHER'), ('Menteri', 'OTHER'), ('Pengangkutan', 'OTHER'), ('Anthony', 'PERSON'), ('Loke', 'PERSON'), ('Siew', 'FAC'), ('Fook', 'OTHER'), ('menitipkan', 'OTHER'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing', 'OTHER'), ('-', 'X'), ('masing', 'X'), ('Dalam', 'ADDRESS'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'ADDRESS'), ('Keselamatan', 'ORG'), ('Jalan', 'ADDRESS'), ('Raya', 'ADDRESS'), ('(', 'ADDRESS'), ('Jkjr', 'PERSON'), (')', 'X'), ('itu', 'OTHER'), ('Dr', 'PERSON'), ('Mahathir'

In [34]:
from tensorflow.tools.graph_transforms import TransformGraph

In [35]:
# Graph-transform pipeline applied to the frozen graph, in this order.
transforms = ['add_default_attributes',
             'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',
             'fold_batch_norms',
             'fold_old_batch_norms',
             'quantize_weights(fallback_min=-10, fallback_max=10)',
             'strip_unused_nodes',
             'sort_by_execution_order']

# Graph boundary handed to TransformGraph: the input placeholder and the node
# that carries the decoded tags.
input_nodes = [
    'Placeholder',
]
output_nodes = [
    'logits',
]

pb = 'fastformer-base-ontonotes5/frozen_model.pb'

input_graph_def = tf.GraphDef()
# tf.gfile.FastGFile is deprecated (TensorFlow itself prints "Use
# tf.gfile.GFile."); GFile is the same reader and is already used below.
with tf.gfile.GFile(pb, 'rb') as f:
    input_graph_def.ParseFromString(f.read())

transformed_graph_def = TransformGraph(input_graph_def,
                                       input_nodes,
                                       output_nodes, transforms)

# The transformed graph is written next to the original as '<pb>.quantized'.
with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())

Instructions for updating:
Use tf.gfile.GFile.


In [37]:
# Disabled: load the quantized graph for a comparison run.
# NOTE(review): the path below says 'fastformer-tiny-ontonotes5', but this
# notebook writes 'fastformer-base-ontonotes5/frozen_model.pb.quantized' --
# confirm the path before re-enabling this cell.
# g = load_graph('fastformer-tiny-ontonotes5/frozen_model.pb.quantized')
# x = g.get_tensor_by_name('import/Placeholder:0')
# logits = g.get_tensor_by_name('import/logits:0')
# test_sess = tf.InteractiveSession(graph = g)

In [38]:
# %%time

# predicted = test_sess.run(logits,
#             feed_dict = {
#                 x: [parsed_sequence],
#             },
#     )[0]
# merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
# print(list(zip(merged[0], merged[1])))

In [40]:
# Upload the frozen (unquantized) graph to the bucket.
# NOTE(review): `b2_bucket` and `file_info` are not defined in any cell visible
# in this notebook (execution counts jump from 38 to 40), so this cell fails on
# a fresh Restart & Run All -- confirm where they are created, and keep any
# credentials out of the notebook.
file = 'fastformer-base-ontonotes5/frozen_model.pb'
outPutname = 'entity-ontonotes5/fastformer/model.pb'
b2_bucket.upload_local_file(
    local_file=file,
    file_name=outPutname,
    file_infos=file_info,
)

<b2sdk.file_version.FileVersionInfo at 0x7feae8176908>

In [41]:
# Upload the quantized graph to the bucket under its own remote name.
# NOTE(review): relies on `b2_bucket` and `file_info`, which are not defined in
# any cell visible in this notebook -- confirm where they come from.
file = 'fastformer-base-ontonotes5/frozen_model.pb.quantized'
outPutname = 'entity-ontonotes5/fastformer-quantized/model.pb'
b2_bucket.upload_local_file(
    local_file=file,
    file_name=outPutname,
    file_infos=file_info,
)

<b2sdk.file_version.FileVersionInfo at 0x7feae81760f0>