In [1]:
import os
import time
from tqdm import tqdm
import numpy as np
from sklearn.metrics import classification_report
import tensorflow.compat.v1 as tf
import utils
import bert_utils

In [2]:
# Notebook-wide configuration read by the cells below.
BERTLARGE     = False  # True -> the L-24/H-1024 hub module, False -> the L-12/H-768 one (see the next cell)
MAX_SEQ_LEN   = 128    # passed to the dataset loader and feature conversion; also the width of each inference input
LEARNING_RATE = 1e-5   # NOTE(review): not referenced by any cell visible in this notebook - confirm it is still needed
TUNE_LAYERS   = -1     # NOTE(review): not referenced by any visible cell; the meaning of -1 is not established here

In [3]:
# Choose the TF-Hub module URL together with its hidden size in one step:
# the large variant when BERTLARGE is set, the base variant otherwise.
BERT_PATH, H_SIZE = (
    ("https://tfhub.dev/google/bert_uncased_L-24_H-1024_A-16/1", 1024)
    if BERTLARGE
    else ("https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1", 768)
)

In [4]:
# Build the tokenizer from the hub module. The session is created only to be
# handed to the helper, so scope it with `with` to make sure it is closed
# instead of leaking for the lifetime of the kernel.
# NOTE(review): assumes the returned tokenizer does not keep using the session
# after the call - confirm against bert_utils.create_tokenizer_from_hub_module.
with tf.Session() as tokenizer_sess:
    tokenizer = bert_utils.create_tokenizer_from_hub_module(BERT_PATH, tokenizer_sess)

# Load the AG News test split (test=True) with texts limited by MAX_SEQ_LEN.
examples, labels, num_classes = utils.load_ag_news_dataset(max_seq_len=MAX_SEQ_LEN,
                                                           test=True)
labels = np.asarray(labels)

# Convert the raw examples into the model's input features.
test_examples = bert_utils.convert_text_to_examples(examples, labels)
feat = bert_utils.convert_examples_to_features(tokenizer,
                                               test_examples,
                                               max_seq_length=MAX_SEQ_LEN,
                                               verbose=True)

# The helper returns a 4-tuple: token ids, attention masks, segment ids, labels.
(test_input_ids, test_input_masks, test_segment_ids, test_labels) = feat

Converting examples to features:   2%|▏         | 124/7600 [00:00<00:06, 1232.90it/s]

Loaded test set from: /home/jovyan/.keras/datasets/ag_news
Examples: 7600 Classes: 4


Converting examples to features: 100%|██████████| 7600/7600 [00:06<00:00, 1247.04it/s]


In [5]:
def reshape_list(input_list, output_shape):
    """Return a new list with every item of ``input_list`` reshaped to ``output_shape``.

    Each item must provide a ``reshape`` method (e.g. a NumPy array). The
    input sequence itself is left unmodified.
    """
    return [item.reshape(output_shape) for item in input_list]

In [6]:
# NOTE: `tf` is deliberately not re-imported in this cell. The first cell binds
# `tf` to `tensorflow.compat.v1`; importing plain `tensorflow` here would
# shadow that alias mid-notebook and make the TF1-style calls below
# (ConfigProto, Session, gfile, GraphDef) depend on the installed major version.
from tensorflow.python.compiler.tensorrt import trt_convert as trt

FROZEN_GRAPH_PATH = "frozen_model.pb"   # frozen GraphDef to convert
OUTPUT_OP = "dense/Softmax"             # op whose first output is fetched
WARMUP_RUNS = 10                        # untimed inferences before the benchmark

config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

trt_session = tf.Session(config=config)

with trt_session as sess:
    # Read the frozen graph from disk.
    with tf.gfile.GFile(FROZEN_GRAPH_PATH, "rb") as f:
        frozen_graph = tf.GraphDef()
        frozen_graph.ParseFromString(f.read())

    # Convert to a TF-TRT graph in FP16. The output op is blacklisted so it is
    # not folded into a TensorRT segment and can still be fetched by name.
    converter = trt.TrtGraphConverter(input_graph_def=frozen_graph,
                                      nodes_blacklist=[OUTPUT_OP],
                                      session_config=config,
                                      precision_mode=trt.TrtPrecisionMode.FP16,
                                      maximum_cached_engines=100,
                                      is_dynamic_op=False)
    trt_graph = converter.convert()

    # import_graph_def returns a list with one entry per requested element and
    # prefixes imported names with "import/" (its default name scope).
    output_node = tf.import_graph_def(trt_graph,
                                      return_elements=[OUTPUT_OP + ":0"])

    # Inference runs one example at a time: every input becomes [1, MAX_SEQ_LEN].
    input_shape = [1, MAX_SEQ_LEN]

    test_input_ids = reshape_list(test_input_ids, input_shape)
    test_input_masks = reshape_list(test_input_masks, input_shape)
    test_segment_ids = reshape_list(test_segment_ids, input_shape)

    # Inputs are fed directly by tensor name. Build every feed dict up front so
    # the timed loop below measures only sess.run.
    feed_dicts = [
        {'import/input_ids:0': ids,
         'import/input_masks:0': masks,
         'import/segment_ids:0': segments}
        for ids, masks, segments in zip(test_input_ids,
                                        test_input_masks,
                                        test_segment_ids)
    ]

    num_examples = len(test_labels)

    # warm-up inference engine (bounded so a tiny dataset cannot raise IndexError)
    for i in range(min(WARMUP_RUNS, num_examples)):
        sess.run([output_node], feed_dict=feed_dicts[i])

    # actual benchmark
    # Each entry of `preds` keeps the nested [[array]] structure produced by
    # fetching the list-valued `output_node` inside another list; the
    # post-processing cell indexes it as pred[0][0][0].
    preds = []

    # perf_counter is monotonic and high-resolution, unlike time.time(), which
    # can jump if the system clock is adjusted during the run.
    start_time = time.perf_counter()

    for i in tqdm(range(num_examples)):
        output = sess.run([output_node], feed_dict=feed_dicts[i])
        preds.append(output)

    end_time = time.perf_counter()

100%|██████████| 7600/7600 [01:48<00:00, 70.30it/s]


In [7]:
# Elapsed time between start_time and end_time, and the resulting throughput
# over the whole test set.
duration = end_time - start_time
examples_per_second = len(test_labels) / duration
print("Duration:", round(duration, 2))
print("Examples/second:", round(examples_per_second, 2))

Duration: 108.11
Examples/second: 70.3


In [8]:
# Reduce every raw prediction to its highest-scoring class index. Each entry
# of `preds` is nested, so the scores are reached through pred[0][0][0].
class_preds = [np.argmax(pred[0][0][0]) for pred in preds]

In [9]:
# Flatten the ground-truth labels to one dimension, then convert them to a
# plain Python list for the report below.
labels_flat = np.asarray(test_labels).flatten()
y_true = labels_flat.tolist()

In [10]:
# Per-class precision / recall / F1 of the predictions against the true labels.
report = classification_report(y_true, class_preds)
print(report)

              precision    recall  f1-score   support

           0       0.97      0.93      0.95      1900
           1       0.98      0.99      0.99      1900
           2       0.92      0.90      0.91      1900
           3       0.89      0.94      0.91      1900

   micro avg       0.94      0.94      0.94      7600
   macro avg       0.94      0.94      0.94      7600
weighted avg       0.94      0.94      0.94      7600

