In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [2]:
import pickle
import json
import tensorflow as tf
import numpy as np
from unidecode import unidecode

In [3]:
import bert
from bert import run_classifier
from bert import optimization
from bert import tokenization
from bert import modeling

In [4]:
import unicodedata
import six
from functools import partial

SPIECE_UNDERLINE = '▁'

def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
  if remove_space:
    outputs = ' '.join(inputs.strip().split())
  else:
    outputs = inputs
  outputs = outputs.replace("``", '"').replace("''", '"')

  if six.PY2 and isinstance(outputs, str):
    outputs = outputs.decode('utf-8')

  if not keep_accents:
    outputs = unicodedata.normalize('NFKD', outputs)
    outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
  if lower:
    outputs = outputs.lower()

  return outputs


def encode_pieces(sp_model, text, return_unicode=True, sample=False):
  # return_unicode is used only for py2

  # note(zhiliny): in some systems, sentencepiece only accepts str for py2
  if six.PY2 and isinstance(text, unicode):
    text = text.encode('utf-8')

  if not sample:
    pieces = sp_model.EncodeAsPieces(text)
  else:
    pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
  new_pieces = []
  for piece in pieces:
    if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
      cur_pieces = sp_model.EncodeAsPieces(
          piece[:-1].replace(SPIECE_UNDERLINE, ''))
      if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
        if len(cur_pieces[0]) == 1:
          cur_pieces = cur_pieces[1:]
        else:
          cur_pieces[0] = cur_pieces[0][1:]
      cur_pieces.append(piece[-1])
      new_pieces.extend(cur_pieces)
    else:
      new_pieces.append(piece)

  # note(zhiliny): convert back to unicode for py2
  if six.PY2 and return_unicode:
    ret_pieces = []
    for piece in new_pieces:
      if isinstance(piece, str):
        piece = piece.decode('utf-8')
      ret_pieces.append(piece)
    new_pieces = ret_pieces

  return new_pieces


def encode_ids(sp_model, text, sample=False):
  pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
  ids = [sp_model.PieceToId(piece) for piece in pieces]
  return ids

In [8]:
# !wget https://f000.backblazeb2.com/file/malaya-model/bert-bahasa/bert-base-2020-03-19.tar.gz
# !tar -zxf bert-base-2020-03-19.tar.gz

In [9]:
!ls bert-base-2020-03-19

bert_config.json			model.ckpt-2000002.meta
model.ckpt-2000002.data-00000-of-00001	sp10m.cased.v9.model
model.ckpt-2000002.index		sp10m.cased.v9.vocab


In [7]:
# !wget https://raw.githubusercontent.com/huseinzol05/Malaya/master/pretrained-model/preprocess/sp10m.cased.bert.model
# !wget https://raw.githubusercontent.com/huseinzol05/Malaya/master/pretrained-model/preprocess/sp10m.cased.bert.vocab

In [10]:
import sentencepiece as spm

sp_model = spm.SentencePieceProcessor()
sp_model.Load('sp10m.cased.bert.model')

with open('sp10m.cased.bert.vocab') as fopen:
    v = fopen.read().split('\n')[:-1]
v = [i.split('\t') for i in v]
v = {i[0]: i[1] for i in v}

class Tokenizer:
    def __init__(self, v):
        self.vocab = v
        pass
    
    def tokenize(self, string):
        return encode_pieces(sp_model, string, return_unicode=False, sample=False)
    
    def convert_tokens_to_ids(self, tokens):
        return [sp_model.PieceToId(piece) for piece in tokens]
    
    def convert_ids_to_tokens(self, ids):
        return [sp_model.IdToPiece(i) for i in ids]
    
tokenizer = Tokenizer(v)

In [16]:
import json

with open('ontonotes5-train-test.json') as fopen:
    data = json.load(fopen)
data.keys()

dict_keys(['train_X', 'train_Y', 'test_X', 'test_Y'])

In [17]:
train_X = data['train_X']
test_X = data['test_X']
train_Y = data['train_Y']
test_Y = data['test_Y']

In [18]:
from glob import glob

augmented = glob('augmentation-*-ontonotes5.json')

for f in augmented:
    print(f)
    with open(f) as fopen:
        data = json.load(fopen)
        
    print(len(data.get('train_X', [])), len(data.get('train_Y', [])))
    print(len(data.get('test_X', [])), len(data.get('test_Y', [])))
    
    train_X.extend(data.get('train_X', []))
    train_Y.extend(data.get('train_Y', []))
    test_X.extend(data.get('test_X', []))
    test_Y.extend(data.get('test_Y', []))

augmentation-org-ontonotes5.json
35150 35150
6300 6300
augmentation-fac-ontonotes5.json
17040 17040
4260 4260
augmentation-loc-ontonotes5.json
13040 13040
2460 2460
augmentation-gpe-ontonotes5.json
21060 21060
0 0
augmentation-work-of-art-ontonotes5.json
4020 4020
1020 1020
augmentation-event-ontonotes5.json
1665 1665
0 0
augmentation-person-ontonotes5.json
37883 37883
2530 2530
augmentation-product-ontonotes5.json
7040 7040
1760 1760
augmentation-law-ontonotes5.json
12584 12584
3161 3161
augmentation-language-ontonotes5.json
6860 6860
0 0
augmentation-norp-ontonotes5.json
30660 30660
5940 5940
augmentation-address-ontonotes5.json
57502 57502
15106 15106


In [19]:
d = [
    {'Tag': 'OTHER', 'Description': 'other'},
    {'Tag': 'ADDRESS', 'Description': 'Address of physical location.'},
    {'Tag': 'PERSON', 'Description': 'People, including fictional.'},
    {
        'Tag': 'NORP',
        'Description': 'Nationalities or religious or political groups.',
    },
    {
        'Tag': 'FAC',
        'Description': 'Buildings, airports, highways, bridges, etc.',
    },
    {
        'Tag': 'ORG',
        'Description': 'Companies, agencies, institutions, etc.',
    },
    {'Tag': 'GPE', 'Description': 'Countries, cities, states.'},
    {
        'Tag': 'LOC',
        'Description': 'Non-GPE locations, mountain ranges, bodies of water.',
    },
    {
        'Tag': 'PRODUCT',
        'Description': 'Objects, vehicles, foods, etc. (Not services.)',
    },
    {
        'Tag': 'EVENT',
        'Description': 'Named hurricanes, battles, wars, sports events, etc.',
    },
    {'Tag': 'WORK_OF_ART', 'Description': 'Titles of books, songs, etc.'},
    {'Tag': 'LAW', 'Description': 'Named documents made into laws.'},
    {'Tag': 'LANGUAGE', 'Description': 'Any named language.'},
    {
        'Tag': 'DATE',
        'Description': 'Absolute or relative dates or periods.',
    },
    {'Tag': 'TIME', 'Description': 'Times smaller than a day.'},
    {'Tag': 'PERCENT', 'Description': 'Percentage, including "%".'},
    {'Tag': 'MONEY', 'Description': 'Monetary values, including unit.'},
    {
        'Tag': 'QUANTITY',
        'Description': 'Measurements, as of weight or distance.',
    },
    {'Tag': 'ORDINAL', 'Description': '"first", "second", etc.'},
    {
        'Tag': 'CARDINAL',
        'Description': 'Numerals that do not fall under another type.',
    },
]
d = [d['Tag'] for d in d]
d = ['PAD', 'X'] + d
tag2idx = {i: no for no, i in enumerate(d)}
idx2tag = {no: i for no, i in enumerate(d)}

In [20]:
from tqdm import tqdm

def XY(strings):
    left_train, right_train = strings[0], strings[1]
    X, Y, MASK = [], [], []
    for i in tqdm(range(len(left_train))):
        left = [d for d in left_train[i]]
        right = [d for d in right_train[i]]
        bert_tokens = ['[CLS]']
        y = ['PAD']
        for no, orig_token in enumerate(left):
            t = tokenizer.tokenize(orig_token)
            bert_tokens.extend(t)
            if len(t):
                y.append(right[no])
            y.extend(['X'] * (len(t) - 1))
        bert_tokens.append('[SEP]')
        y.append('PAD')
        x = tokenizer.convert_tokens_to_ids(bert_tokens)
        y = [tag2idx[i] for i in y]
        input_mask = [1] * len(y)
        if len(x) != len(y):
            print(i)
        X.append(x)
        Y.append(y)
        MASK.append(input_mask)
    return [(X, Y, MASK,)]

In [21]:
import cleaning
t = cleaning.multiprocessing(train_X, train_Y, XY)
train_X, train_Y, train_masks = [], [], []
for x, y, m in t:
    train_X.extend(x)
    train_Y.extend(y)
    train_masks.extend(m)

100%|██████████| 40812/40812 [00:24<00:00, 1650.02it/s]
100%|██████████| 1/1 [00:00<00:00, 1236.89it/s].14it/s]
100%|██████████| 40812/40812 [00:24<00:00, 1637.13it/s]
 40%|████      | 16509/40812 [00:10<00:16, 1515.38it/s]
100%|██████████| 40812/40812 [00:24<00:00, 1644.78it/s]
100%|██████████| 40812/40812 [00:24<00:00, 1655.85it/s]
100%|██████████| 40812/40812 [00:25<00:00, 1606.47it/s]
100%|██████████| 40812/40812 [00:24<00:00, 1673.43it/s]
100%|██████████| 40812/40812 [00:24<00:00, 1674.92it/s]
100%|██████████| 40812/40812 [00:25<00:00, 1595.99it/s]
100%|██████████| 40812/40812 [00:25<00:00, 1612.04it/s]
100%|██████████| 40812/40812 [00:26<00:00, 1538.24it/s]
100%|██████████| 40812/40812 [00:25<00:00, 1575.62it/s]
100%|██████████| 40812/40812 [00:26<00:00, 1560.13it/s]
100%|██████████| 40812/40812 [00:26<00:00, 1561.15it/s]
100%|██████████| 40812/40812 [00:29<00:00, 1405.82it/s]
100%|██████████| 40812/40812 [00:27<00:00, 1475.96it/s]


In [22]:
t = cleaning.multiprocessing(test_X, test_Y, XY)
test_X, test_Y, test_masks = [], [], []
for x, y, m in t:
    test_X.extend(x)
    test_Y.extend(y)
    test_masks.extend(m)

100%|██████████| 5494/5494 [00:03<00:00, 1501.47it/s]
 94%|█████████▍| 5186/5494 [00:03<00:00, 1314.24it/s]
100%|██████████| 5494/5494 [00:03<00:00, 1611.70it/s]
 64%|██████▍   | 3506/5494 [00:02<00:01, 1574.45it/s]
100%|██████████| 5494/5494 [00:03<00:00, 1521.03it/s]
100%|██████████| 5494/5494 [00:03<00:00, 1663.32it/s]
 76%|███████▌  | 4157/5494 [00:02<00:00, 1524.00it/s]
100%|██████████| 5494/5494 [00:03<00:00, 1542.82it/s]
100%|██████████| 5494/5494 [00:03<00:00, 1528.34it/s]
100%|██████████| 5494/5494 [00:03<00:00, 1537.33it/s]
 96%|█████████▌| 5263/5494 [00:03<00:00, 1416.31it/s]
100%|██████████| 5494/5494 [00:03<00:00, 1536.99it/s]
100%|██████████| 5494/5494 [00:03<00:00, 1471.49it/s]
100%|██████████| 5494/5494 [00:03<00:00, 1419.56it/s]
100%|██████████| 5494/5494 [00:03<00:00, 1461.35it/s]
100%|██████████| 5494/5494 [00:04<00:00, 1328.59it/s]
100%|██████████| 5494/5494 [00:03<00:00, 1396.69it/s]


In [23]:
from sklearn.utils import shuffle

train_X, train_Y, train_masks = shuffle(train_X, train_Y, train_masks)

In [24]:
BERT_INIT_CHKPNT = 'bert-base-2020-03-19/model.ckpt-2000002'
BERT_CONFIG = 'bert-base-2020-03-19/bert_config.json'

In [25]:
epoch = 5
batch_size = 32
warmup_proportion = 0.1
num_train_steps = int(1000 / batch_size * epoch)
num_warmup_steps = int(num_train_steps * warmup_proportion)
bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)

In [26]:
def create_initializer(initializer_range=0.02):
    return tf.compat.v1.truncated_normal_initializer(stddev=initializer_range)

class Model:
    def __init__(
        self,
        dimension_output,
        learning_rate = 2e-5,
        training = True
    ):
        self.X = tf.compat.v1.placeholder(tf.compat.v1.int32, [None, None])
        self.MASK = tf.compat.v1.placeholder(tf.compat.v1.int32, [None, None])
        self.Y = tf.compat.v1.placeholder(tf.compat.v1.int32, [None, None])
        self.maxlen = tf.compat.v1.shape(self.X)[1]
        self.lengths = tf.compat.v1.count_nonzero(self.X, 1)
        
        model = modeling.BertModel(
            config=bert_config,
            is_training=training,
            input_ids=self.X,
            input_mask=self.MASK,
            use_one_hot_embeddings=False)
        output_layer = model.get_sequence_output()
        output_layer = tf.compat.v1.layers.dense(
            output_layer,
            bert_config.hidden_size,
            activation=tf.compat.v1.tanh,
            kernel_initializer=create_initializer())
        logits = tf.compat.v1.layers.dense(output_layer, dimension_output,
                                         kernel_initializer=create_initializer())
        y_t = self.Y
        log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
            logits, y_t, self.lengths
        )
        self.cost = tf.compat.v1.reduce_mean(-log_likelihood)
        self.optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate = learning_rate
        ).minimize(self.cost)
        mask = tf.compat.v1.sequence_mask(self.lengths, maxlen = self.maxlen)
        self.tags_seq, tags_score = tf.contrib.crf.crf_decode(
            logits, transition_params, self.lengths
        )
        self.tags_seq = tf.compat.v1.identity(self.tags_seq, name = 'logits')

        y_t = tf.compat.v1.cast(y_t, tf.compat.v1.int32)
        self.prediction = tf.compat.v1.boolean_mask(self.tags_seq, mask)
        mask_label = tf.compat.v1.boolean_mask(y_t, mask)
        correct_pred = tf.compat.v1.equal(self.prediction, mask_label)
        correct_index = tf.compat.v1.cast(correct_pred, tf.compat.v1.float32)
        self.accuracy = tf.compat.v1.reduce_mean(tf.compat.v1.cast(correct_pred, tf.compat.v1.float32))

In [27]:
dimension_output = len(tag2idx)
learning_rate = 2e-5

tf.compat.v1.reset_default_graph()
sess = tf.compat.v1.InteractiveSession()
model = Model(
    dimension_output,
    learning_rate
)

sess.run(tf.compat.v1.global_variables_initializer())
var_lists = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')
saver = tf.compat.v1.train.Saver(var_list = var_lists)
saver.restore(sess, BERT_INIT_CHKPNT)

Instructions for updating:
reduction_indices is deprecated, use axis instead
Instructions for updating:
Use keras.layers.Dense instead.
Instructions for updating:
Please use `layer.__call__` method instead.
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use tf.compat.v1.where in 2.0, which has the same broadcast rule as np.where
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
INFO:tensorflow:Restoring parameters from bert-base-2020-03-19/model.ckpt-2000002


In [28]:
def merge_sentencepiece_tokens_tagging(x, y):
    new_paired_tokens = []
    n_tokens = len(x)
    rejected = ['[CLS]', '[SEP]']

    i = 0

    while i < n_tokens:

        current_token, current_label = x[i], y[i]
        if not current_token.startswith('▁') and current_token not in rejected:
            previous_token, previous_label = new_paired_tokens.pop()
            merged_token = previous_token
            merged_label = [previous_label]
            while (
                not current_token.startswith('▁')
                and current_token not in rejected
            ):
                merged_token = merged_token + current_token.replace('▁', '')
                merged_label.append(current_label)
                i = i + 1
                current_token, current_label = x[i], y[i]
            merged_label = merged_label[0]
            new_paired_tokens.append((merged_token, merged_label))

        else:
            new_paired_tokens.append((current_token, current_label))
            i = i + 1

    words = [
        i[0].replace('▁', '')
        for i in new_paired_tokens
        if i[0] not in rejected
    ]
    labels = [i[1] for i in new_paired_tokens if i[0] not in rejected]
    return words, labels

In [31]:
string = 'KUALA LUMPUR: Sempena sambutan Aidilfitri minggu depan, Perdana Menteri Tun Dr Mahathir Mohamad dan Menteri Pengangkutan Anthony Loke Siew Fook menitipkan pesanan khas kepada orang ramai yang mahu pulang ke kampung halaman masing-masing. Dalam video pendek terbitan Jabatan Keselamatan Jalan Raya (JKJR) itu, Dr Mahathir menasihati mereka supaya berhenti berehat dan tidur sebentar  sekiranya mengantuk ketika memandu.'

import re

def entities_textcleaning(string, lowering = False):
    """
    use by entities recognition, pos recognition and dependency parsing
    """
    string = re.sub(r'[ ]+', ' ', string).strip()
    original_string = string.split()
    if lowering:
        string = string.lower()
    string = [
        (original_string[no], word.title() if word.isupper() else word)
        for no, word in enumerate(string.split())
        if len(word)
    ]
    return [s[0] for s in string], [s[1] for s in string]

def parse_X(left):
    bert_tokens = ['[CLS]']
    for no, orig_token in enumerate(left):
        t = tokenizer.tokenize(orig_token)
        bert_tokens.extend(t)
    bert_tokens.append("[SEP]")
    input_mask = [1] * len(bert_tokens)
    return tokenizer.convert_tokens_to_ids(bert_tokens), bert_tokens, input_mask

sequence = entities_textcleaning(string)[1]
parsed_sequence, bert_sequence, input_mask = parse_X(sequence)

In [32]:
predicted = sess.run(model.tags_seq,
                feed_dict = {
                    model.X: [parsed_sequence],
                    model.MASK: [input_mask]
                },
        )[0]
merged = merge_sentencepiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
list(zip(merged[0], merged[1]))

[('Kuala', 'PAD'),
 ('Lumpur:', 'TIME'),
 ('Sempena', 'CARDINAL'),
 ('sambutan', 'CARDINAL'),
 ('Aidilfitri', 'LOC'),
 ('minggu', 'GPE'),
 ('depan,', 'CARDINAL'),
 ('Perdana', 'PRODUCT'),
 ('Menteri', 'LOC'),
 ('Tun', 'OTHER'),
 ('Dr', 'OTHER'),
 ('Mahathir', 'PERCENT'),
 ('Mohamad', 'GPE'),
 ('dan', 'NORP'),
 ('Menteri', 'OTHER'),
 ('Pengangkutan', 'OTHER'),
 ('Anthony', 'TIME'),
 ('Loke', 'PERSON'),
 ('Siew', 'GPE'),
 ('Fook', 'ADDRESS'),
 ('menitipkan', 'LAW'),
 ('pesanan', 'LAW'),
 ('khas', 'LAW'),
 ('kepada', 'LAW'),
 ('orang', 'FAC'),
 ('ramai', 'MONEY'),
 ('yang', 'LAW'),
 ('mahu', 'LAW'),
 ('pulang', 'LAW'),
 ('ke', 'NORP'),
 ('kampung', 'OTHER'),
 ('halaman', 'OTHER'),
 ('masing-masing.', 'OTHER'),
 ('Dalam', 'PRODUCT'),
 ('video', 'LAW'),
 ('pendek', 'FAC'),
 ('terbitan', 'LOC'),
 ('Jabatan', 'OTHER'),
 ('Keselamatan', 'TIME'),
 ('Jalan', 'PERSON'),
 ('Raya', 'X'),
 ('(Jkjr)', 'PERSON'),
 ('itu,', 'PERSON'),
 ('Dr', 'NORP'),
 ('Mahathir', 'OTHER'),
 ('menasihati', 'OTHER'),
 

In [33]:
pad_sequences = tf.compat.v1.keras.preprocessing.sequence.pad_sequences

In [34]:
import time

for e in range(epoch):
    lasttime = time.time()
    train_acc, train_loss, test_acc, test_loss = [], [], [], []
    pbar = tqdm(
        range(0, len(train_X), batch_size), desc = 'train minibatch loop'
    )
    for i in pbar:
        index = min(i + batch_size, len(train_X))
        batch_x = train_X[i : index]
        batch_y = train_Y[i : index]
        batch_masks = train_masks[i : index]
        batch_x = pad_sequences(batch_x, padding='post')
        batch_y = pad_sequences(batch_y, padding='post')
        batch_masks = pad_sequences(batch_masks, padding='post')
        
        acc, cost, _ = sess.run(
            [model.accuracy, model.cost, model.optimizer],
            feed_dict = {
                model.X: batch_x,
                model.Y: batch_y,
                model.MASK: batch_masks,
            },
        )
        assert not np.isnan(cost)
        train_loss.append(cost)
        train_acc.append(acc)
        pbar.set_postfix(cost = cost, accuracy = acc)
    
    pbar = tqdm(
        range(0, len(test_X), batch_size), desc = 'test minibatch loop'
    )
    for i in pbar:
        index = min(i + batch_size, len(test_X))
        batch_x = test_X[i : index]
        batch_y = test_Y[i : index]
        batch_masks = test_masks[i : index]
        batch_x = pad_sequences(batch_x, padding='post')
        batch_y = pad_sequences(batch_y, padding='post')
        batch_masks = pad_sequences(batch_masks, padding='post')
        
        acc, cost = sess.run(
            [model.accuracy, model.cost],
            feed_dict = {
                model.X: batch_x,
                model.Y: batch_y,
                model.MASK: batch_masks,
            },
        )
        assert not np.isnan(cost)
        test_loss.append(cost)
        test_acc.append(acc)
        pbar.set_postfix(cost = cost, accuracy = acc)
    
    train_loss = np.mean(train_loss)
    train_acc = np.mean(train_acc)
    test_loss = np.mean(test_loss)
    test_acc = np.mean(test_acc)

    print('time taken:', time.time() - lasttime)
    print(
        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\n'
        % (e, train_loss, train_acc, test_loss, test_acc)
    )
    predicted = sess.run(model.tags_seq,
                feed_dict = {
                    model.X: [parsed_sequence],
                    model.MASK: [input_mask]
                },
        )[0]
    merged = merge_sentencepiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
    print(list(zip(merged[0], merged[1])))

train minibatch loop: 100%|██████████| 20407/20407 [2:06:31<00:00,  2.69it/s, accuracy=1, cost=0.705]      
test minibatch loop: 100%|██████████| 2748/2748 [07:59<00:00,  5.73it/s, accuracy=0.999, cost=0.736] 
train minibatch loop:   0%|          | 0/20407 [00:00<?, ?it/s]

time taken: 8070.775153160095
epoch: 0, training loss: 2.863052, training acc: 0.986971, valid loss: 4.914965, valid acc: 0.983747

[('Kuala', 'TIME'), ('Lumpur:', 'TIME'), ('Sempena', 'OTHER'), ('sambutan', 'OTHER'), ('Aidilfitri', 'DATE'), ('minggu', 'DATE'), ('depan,', 'DATE'), ('Perdana', 'OTHER'), ('Menteri', 'OTHER'), ('Tun', 'ORG'), ('Dr', 'PERSON'), ('Mahathir', 'PERSON'), ('Mohamad', 'PERSON'), ('dan', 'OTHER'), ('Menteri', 'OTHER'), ('Pengangkutan', 'OTHER'), ('Anthony', 'PERSON'), ('Loke', 'PERSON'), ('Siew', 'PERSON'), ('Fook', 'PERSON'), ('menitipkan', 'OTHER'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing-masing.', 'OTHER'), ('Dalam', 'OTHER'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'ORG'), ('Keselamatan', 'ORG'), ('Jalan', 'ORG'), ('Raya', 'ORG'), ('(J

train minibatch loop: 100%|██████████| 20407/20407 [2:06:29<00:00,  2.69it/s, accuracy=1, cost=0.00201]    
test minibatch loop: 100%|██████████| 2748/2748 [07:57<00:00,  5.76it/s, accuracy=1, cost=0.0769]    
train minibatch loop:   0%|          | 0/20407 [00:00<?, ?it/s]

time taken: 8066.735553979874
epoch: 1, training loss: 0.540903, training acc: 0.997296, valid loss: 5.464173, valid acc: 0.984979

[('Kuala', 'PERSON'), ('Lumpur:', 'TIME'), ('Sempena', 'OTHER'), ('sambutan', 'DATE'), ('Aidilfitri', 'DATE'), ('minggu', 'DATE'), ('depan,', 'DATE'), ('Perdana', 'PERSON'), ('Menteri', 'PERSON'), ('Tun', 'PERSON'), ('Dr', 'PERSON'), ('Mahathir', 'PERSON'), ('Mohamad', 'PERSON'), ('dan', 'OTHER'), ('Menteri', 'OTHER'), ('Pengangkutan', 'OTHER'), ('Anthony', 'PERSON'), ('Loke', 'PERSON'), ('Siew', 'PERSON'), ('Fook', 'PERSON'), ('menitipkan', 'OTHER'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing-masing.', 'OTHER'), ('Dalam', 'OTHER'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'ORG'), ('Keselamatan', 'ORG'), ('Jalan', 'ORG'), ('Raya', 'ORG')

train minibatch loop:   4%|▎         | 718/20407 [04:24<2:01:03,  2.71it/s, accuracy=0.996, cost=0.735]


KeyboardInterrupt: 

In [36]:
saver = tf.compat.v1.train.Saver(tf.compat.v1.trainable_variables())
saver.save(sess, 'bert-base-entities/model.ckpt')

'bert-base-entities/model.ckpt'

In [37]:
dimension_output = len(tag2idx)
learning_rate = 2e-5

tf.compat.v1.reset_default_graph()
sess = tf.compat.v1.InteractiveSession()
model = Model(
    dimension_output,
    learning_rate,
    training = False
)

sess.run(tf.compat.v1.global_variables_initializer())
saver = tf.compat.v1.train.Saver(tf.compat.v1.trainable_variables())
saver.restore(sess, 'bert-base-entities/model.ckpt')



INFO:tensorflow:Restoring parameters from bert-base-entities/model.ckpt


In [38]:
def pred2label(pred):
    out = []
    for pred_i in pred:
        out_i = []
        for p in pred_i:
            out_i.append(idx2tag[p])
        out.append(out_i)
    return out

In [39]:
real_Y, predict_Y = [], []

pbar = tqdm(
    range(0, len(test_X), batch_size), desc = 'validation minibatch loop'
)
for i in pbar:
    index = min(i + batch_size, len(test_X))
    batch_x = test_X[i : index]
    batch_y = test_Y[i : index]
    batch_masks = test_masks[i : index]
    batch_x = pad_sequences(batch_x, padding='post')
    batch_y = pad_sequences(batch_y, padding='post')
    batch_masks = pad_sequences(batch_masks, padding='post')
    predicted = pred2label(sess.run(model.tags_seq,
            feed_dict = {
                model.X: batch_x,
                model.MASK: batch_masks,
            },
    ))
    real = pred2label(batch_y)
    predict_Y.extend(predicted)
    real_Y.extend(real)

validation minibatch loop: 100%|██████████| 2748/2748 [07:45<00:00,  5.91it/s]


In [40]:
temp_real_Y = []
for r in real_Y:
    temp_real_Y.extend(r)
    
temp_predict_Y = []
for r in predict_Y:
    temp_predict_Y.extend(r)

In [41]:
from sklearn.metrics import classification_report
print(classification_report(temp_real_Y, temp_predict_Y, digits = 5))

              precision    recall  f1-score   support

     ADDRESS    0.99858   0.99974   0.99916     93446
    CARDINAL    0.93840   0.90631   0.92207     48255
        DATE    0.95490   0.93656   0.94564    126548
       EVENT    0.92876   0.93591   0.93232      5711
         FAC    0.93271   0.92658   0.92964     27392
         GPE    0.93437   0.94852   0.94139    101357
    LANGUAGE    0.93478   0.96389   0.94911       803
         LAW    0.94824   0.95744   0.95281     24834
         LOC    0.94148   0.93213   0.93678     34538
       MONEY    0.87803   0.87563   0.87683     30032
        NORP    0.95516   0.90446   0.92912     57014
     ORDINAL    0.91510   0.91083   0.91296      6213
         ORG    0.92453   0.95354   0.93881    219533
       OTHER    0.99135   0.99308   0.99221   3553350
         PAD    0.99956   1.00000   0.99978   1292421
     PERCENT    0.96287   0.96814   0.96550     21722
      PERSON    0.97376   0.93891   0.95602    101981
     PRODUCT    0.87537   0

In [42]:
strings = ','.join(
    [
        n.name
        for n in tf.compat.v1.get_default_graph().as_graph_def().node
        if ('Variable' in n.op
        or 'Placeholder' in n.name
        or 'logits' in n.name
        or 'alphas' in n.name
        or 'self/Softmax' in n.name)
        and 'Adam' not in n.name
        and 'beta' not in n.name
        and 'global_step' not in n.name
    ]
)
strings.split(',')

['Placeholder',
 'Placeholder_1',
 'Placeholder_2',
 'bert/embeddings/word_embeddings',
 'bert/embeddings/token_type_embeddings',
 'bert/embeddings/position_embeddings',
 'bert/embeddings/LayerNorm/gamma',
 'bert/encoder/layer_0/attention/self/query/kernel',
 'bert/encoder/layer_0/attention/self/query/bias',
 'bert/encoder/layer_0/attention/self/key/kernel',
 'bert/encoder/layer_0/attention/self/key/bias',
 'bert/encoder/layer_0/attention/self/value/kernel',
 'bert/encoder/layer_0/attention/self/value/bias',
 'bert/encoder/layer_0/attention/self/Softmax',
 'bert/encoder/layer_0/attention/output/dense/kernel',
 'bert/encoder/layer_0/attention/output/dense/bias',
 'bert/encoder/layer_0/attention/output/LayerNorm/gamma',
 'bert/encoder/layer_0/intermediate/dense/kernel',
 'bert/encoder/layer_0/intermediate/dense/bias',
 'bert/encoder/layer_0/output/dense/kernel',
 'bert/encoder/layer_0/output/dense/bias',
 'bert/encoder/layer_0/output/LayerNorm/gamma',
 'bert/encoder/layer_1/attention/sel

In [43]:
def freeze_graph(model_dir, output_node_names):

    if not tf.compat.v1.io.gfile.exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            'directory: %s' % model_dir
        )

    checkpoint = tf.compat.v1.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/frozen_model.pb'
    clear_devices = True
    with tf.compat.v1.Session(graph = tf.compat.v1.Graph()) as sess:
        saver = tf.compat.v1.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices = clear_devices
        )
        saver.restore(sess, input_checkpoint)
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess,
            tf.compat.v1.get_default_graph().as_graph_def(),
            output_node_names.split(','),
        )
        with tf.compat.v1.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))

In [45]:
freeze_graph('bert-base-entities', strings)

INFO:tensorflow:Restoring parameters from bert-base-entities/model.ckpt
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
INFO:tensorflow:Froze 204 variables.
INFO:tensorflow:Converted 204 variables to const ops.
13221 ops in the final graph.


In [46]:
def load_graph(frozen_graph_filename):
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.compat.v1.Graph().as_default() as graph:
        tf.compat.v1.import_graph_def(graph_def)
    return graph

g = load_graph('bert-base-entities/frozen_model.pb')
x = g.get_tensor_by_name('import/Placeholder:0')
mask = g.get_tensor_by_name('import/Placeholder_1:0')
logits = g.get_tensor_by_name('import/logits:0')
test_sess = tf.compat.v1.InteractiveSession(graph = g)



In [47]:
string = 'Kyrgios, 25, membuat pesanan itu kerana menyedaari pelbagai kesukaran menimpa rakyat Australia ekoran perintah kawalan pergerakan yang diumumkan Mac lalu bagi memerangi wabak COVID-19 di negara berkenaan. Pemain tenis ranking ke-40 dunia yang dilahirkan di Canberra itu meminta pengikut dan penyokongnya agar jangan tidur dalam keadaan perut kosong dalam hantaran Instagram yang meraih lebih 92,000 tanda suka.'

ori, sequence = entities_textcleaning(string)
parsed_sequence, bert_sequence, input_mask = parse_X(sequence)

In [48]:
predicted = test_sess.run(logits,
            feed_dict = {
                x: [parsed_sequence],
                mask: [input_mask]
            },
    )[0]
merged = merge_sentencepiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])
print(list(zip(merged[0], merged[1])))

[('Kyrgios,', 'PERSON'), ('25,', 'NORP'), ('membuat', 'OTHER'), ('pesanan', 'OTHER'), ('itu', 'OTHER'), ('kerana', 'OTHER'), ('menyedaari', 'OTHER'), ('pelbagai', 'OTHER'), ('kesukaran', 'OTHER'), ('menimpa', 'OTHER'), ('rakyat', 'OTHER'), ('Australia', 'NORP'), ('ekoran', 'OTHER'), ('perintah', 'OTHER'), ('kawalan', 'OTHER'), ('pergerakan', 'OTHER'), ('yang', 'OTHER'), ('diumumkan', 'OTHER'), ('Mac', 'DATE'), ('lalu', 'DATE'), ('bagi', 'OTHER'), ('memerangi', 'OTHER'), ('wabak', 'OTHER'), ('Covid-19', 'OTHER'), ('di', 'OTHER'), ('negara', 'OTHER'), ('berkenaan.', 'OTHER'), ('Pemain', 'OTHER'), ('tenis', 'OTHER'), ('ranking', 'OTHER'), ('ke-40', 'ORDINAL'), ('dunia', 'OTHER'), ('yang', 'OTHER'), ('dilahirkan', 'OTHER'), ('di', 'OTHER'), ('Canberra', 'GPE'), ('itu', 'OTHER'), ('meminta', 'OTHER'), ('pengikut', 'OTHER'), ('dan', 'OTHER'), ('penyokongnya', 'OTHER'), ('agar', 'OTHER'), ('jangan', 'OTHER'), ('tidur', 'OTHER'), ('dalam', 'OTHER'), ('keadaan', 'OTHER'), ('perut', 'OTHER'), ('

In [49]:
import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph
tf.compat.v1.set_random_seed(0)

In [50]:
transforms = ['add_default_attributes',
             'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',
             'fold_batch_norms',
             'fold_old_batch_norms',
             'quantize_weights(fallback_min=-10, fallback_max=10)',
             'strip_unused_nodes',
             'sort_by_execution_order']

pb = 'bert-base-entities/frozen_model.pb'

input_graph_def = tf.compat.v1.GraphDef()
with tf.compat.v1.gfile.FastGFile(pb, 'rb') as f:
    input_graph_def.ParseFromString(f.read())

inputs = ['Placeholder', 'Placeholder_1']
outputs = ['dense/BiasAdd']

print(pb, inputs)

transformed_graph_def = TransformGraph(input_graph_def, 
                                       inputs,
                                       ['logits'] + outputs, transforms)

with tf.compat.v1.gfile.GFile(f'{pb}.quantized', 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())

Instructions for updating:
Use tf.compat.v1.gfile.GFile.
bert-base-entities/frozen_model.pb ['Placeholder', 'Placeholder_1']
