In [1]:
import os
# Expose only the CUDA device with index 1 to this process. This is set here,
# in the first cell, so that it is already in the environment when TensorFlow
# is imported further down.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
# !wget https://f000.backblazeb2.com/file/malaya-model/v38/translation/en-ms/base-translation.pb

In [3]:
import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph
from glob import glob
tf.compat.v1.set_random_seed(0)

In [4]:
import tensorflow_text
import tf_sentencepiece

In [5]:
# Every frozen graph (*.pb) in the working directory. glob() returns matches in
# arbitrary order, so sort them to make the processing order reproducible.
pbs = sorted(glob('*.pb'))
pbs

['base-translation.pb']

In [6]:
# Graph Transform Tool passes, applied to each frozen graph in this order.
transforms = ['add_default_attributes',
             'remove_nodes(op=Identity, op=CheckNumerics)',
             'fold_constants(ignore_errors=true)',
             'fold_batch_norms',
             'fold_old_batch_norms',
             'strip_unused_nodes',
             'sort_by_execution_order']

for pb in pbs:
    # Parse the serialized GraphDef from disk. GFile is used instead of
    # FastGFile, which is deprecated and only emits a warning pointing here.
    input_graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.gfile.GFile(pb, 'rb') as f:
        input_graph_def.ParseFromString(f.read())

    print(pb)

    # 'Placeholder' is the graph input; 'greedy' and 'beam' are the outputs
    # that the transforms must keep reachable.
    transformed_graph_def = TransformGraph(input_graph_def,
                                           ['Placeholder'],
                                           ['greedy', 'beam'], transforms)

    # The result is written next to the input as '<name>.pb.optimized'.
    with tf.compat.v1.gfile.GFile(f'{pb}.optimized', 'wb') as f:
        f.write(transformed_graph_def.SerializeToString())

Instructions for updating:
Use tf.compat.v1.gfile.GFile.
base-translation.pb


In [7]:
import tensorflow as tf
from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format
import numpy as np

In [10]:
def load_graph_with_sess(frozen_graph_filename):
    """
    Read a serialized GraphDef from `frozen_graph_filename` and import it
    into a freshly created tf.Graph, which is returned.

    import_graph_def is called without a `name` argument, so the imported
    nodes receive TensorFlow's default name prefix.
    """
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, 'rb') as pb_file:
        parsed_def = tf.compat.v1.GraphDef()
        parsed_def.ParseFromString(pb_file.read())

    imported = tf.compat.v1.Graph()
    with imported.as_default():
        tf.compat.v1.import_graph_def(parsed_def)
    return imported

def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'):
    """
    Append a FusedBatchNormV2 version of `node` to `graph_def`.

    FusedBatchNorm is replaced by FusedBatchNormV2 because reserve_space_1 and
    reserve_space_2 in FusedBatchNorm require float32 for gradient calculation
    (see https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/fused-batch-norm).

    The new node reuses the name, inputs and attributes of `node`; its "U"
    attribute is set to float32 and its "T" attribute to the dtype selected by
    `target_type` ('fp16' -> half, 'fp64' -> double, anything else -> float).

    Note that the "T" attribute of the *source* `node` is retyped as well, and
    that `node` itself is neither removed nor renamed. Nothing is returned.
    """
    dtype = (
        types_pb2.DT_HALF if target_type == 'fp16'
        else types_pb2.DT_DOUBLE if target_type == 'fp64'
        else types_pb2.DT_FLOAT
    )

    bn_v2 = graph_def.node.add()
    bn_v2.name = node.name
    bn_v2.op = "FusedBatchNormV2"
    bn_v2.input.extend(node.input)
    # "U" is written first so that a same-named attribute on the source node,
    # if any, still overrides it in the copy loop below.
    bn_v2.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))

    # Retype the source node's "T" up front; the loop then copies it verbatim.
    if "T" in node.attr:
        node.attr["T"].type = dtype
    for key in list(node.attr.keys()):
        bn_v2.attr[key].CopyFrom(node.attr[key])

    print("rewrite fused_batch_norm done!")

def convert_graph_to_fp16(model_path, as_text=False, target_type='fp16', 
                          input_name=None, output_names=None):
    """
    Retype the float32 parts of a frozen graph and write the result to disk.

    The graph at `model_path` is loaded, its GraphDef is edited in place, and
    the edited GraphDef is serialized to f'{model_path}.optimized'.

    Parameters
    ----------
    model_path : str
        Path of the frozen graph (.pb) to convert.
    as_text : bool
        Accepted for interface compatibility; not used.
    target_type : str
        'fp16' -> half, 'fp64' -> double, anything else -> float32.
    input_name, output_names
        Accepted for interface compatibility; not used.

    Returns
    -------
    None

    Notes
    -----
    Node names listed in the module-level `keep_fp32_node_name` are left
    untouched. That name must be defined before this function is called.
    """
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    graph = load_graph_with_sess(model_path)
    source_graph_def = graph.as_graph_def()
    for node in source_graph_def.node:
        # fused batch norm node: switch to the V2 op in place. "U" stays
        # float32 while "T" takes the target dtype. Editing the node itself
        # (instead of appending a second node) keeps node names unique and
        # avoids growing the node list while it is being iterated.
        if node.op == "FusedBatchNorm":
            node.op = "FusedBatchNormV2"
            node.attr["U"].type = types_pb2.DT_FLOAT
            node.attr["T"].type = dtype
            continue
        # keep batch norm params node
        if ("BatchNorm" in node.name) or ('batch_normalization' in node.name):
            continue
        # This custom op is replaced by a plain Identity of the target dtype.
        if node.op == 'convert_gradient_to_tensor_HBc3xYw22Mw':
            node.op = 'Identity'
            # Accessing a missing key on a protobuf message map creates it.
            node.attr['T'].type = dtype
            if '_disable_call_shape_inference' in node.attr:
                del node.attr['_disable_call_shape_inference']
        # keep special node in fp32
        if node.name in keep_fp32_node_name:
            continue
        # Snapshot the keys only now, after the attribute removal above:
        # reading node.attr[<deleted key>] would silently re-create it.
        for attr in list(node.attr.keys()):
            # replace dtype in node attr with target dtype
            if node.attr[attr].type == types_pb2.DT_FLOAT:
                node.attr[attr].type = dtype
            if attr == "value":
                tensor = node.attr[attr].tensor
                if tensor.dtype != types_pb2.DT_FLOAT:
                    continue
                if tensor.float_val:
                    # values stored in the repeated float_val field
                    float_val = tf.compat.v1.make_ndarray(tensor)
                    tensor.CopyFrom(tf.compat.v1.make_tensor_proto(float_val, dtype=dtype))
                elif tensor.tensor_content:
                    # values stored as packed bytes in tensor_content
                    tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                    tensor_weights = tf.compat.v1.make_ndarray(tensor)
                    tensor_weights = np.reshape(tensor_weights, tensor_shape)
                    tensor.CopyFrom(tf.compat.v1.make_tensor_proto(tensor_weights, dtype=dtype))

    with tf.compat.v1.gfile.GFile(f'{model_path}.optimized', 'wb') as f:
        f.write(source_graph_def.SerializeToString())

In [11]:
# Graph endpoints (passed through to the converter).
input_name = ['Placeholder']
output_names = ['greedy', 'beam']

# Read as a module-level global by convert_graph_to_fp16: names of nodes that
# must be left in float32. Empty means every eligible node is converted.
keep_fp32_node_name = []

# Conversion settings.
model_path = "base-translation.pb.optimized"
target_type = 'fp16'
as_text = False

# The converter writes f'{model_path}.optimized' to disk.
g = convert_graph_to_fp16(
    model_path,
    as_text=as_text,
    target_type=target_type,
    input_name=input_name,
    output_names=output_names,
)

In [12]:
def load_graph(frozen_graph_filename, **kwargs):
    """
    Load a frozen GraphDef from disk, patch stateful ops, and import it.

    Before the import, variable-style ops are rewritten to stateless
    equivalents: RefSwitch -> Switch, AssignSub -> Sub, AssignAdd -> Add and
    Assign -> Identity, dropping the attributes the replacement op does not
    take.

    Parameters
    ----------
    frozen_graph_filename : str
        Path of the serialized GraphDef.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    tf.Graph
        A new graph holding the imported nodes. import_graph_def is called
        without a `name`, so node names get TensorFlow's default prefix.
    """
    with tf.compat.v1.io.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # https://github.com/onnx/tensorflow-onnx/issues/77#issuecomment-445066091
    # to fix import T5
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Point inputs that mention 'moving_' at their '/read' node.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'Assign':
            node.op = 'Identity'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
            if 'validate_shape' in node.attr:
                del node.attr['validate_shape']
            # Identity takes one input: keep the second input of Assign and
            # drop the first.
            if len(node.input) == 2:
                node.input[0] = node.input[1]
                del node.input[1]

    with tf.compat.v1.Graph().as_default() as graph:
        tf.compat.v1.import_graph_def(graph_def)
    return graph

# g_optimized = load_graph('base-translation.pb.optimized.fp16/base-translation.pb.optimized.fp16')

In [13]:
# Load the converted graph. convert_graph_to_fp16 appends '.optimized' to the
# path it was given ('base-translation.pb.optimized'), hence the doubled suffix.
g_optimized = load_graph('base-translation.pb.optimized.optimized')

In [14]:
from malaya.text.t2t import text_encoder

# Subword vocabulary used to turn text into token ids.
# NOTE(review): this is an absolute, machine-specific path; it has to be
# adjusted before the notebook can run anywhere else.
encoder = text_encoder.SubwordTextEncoder('/home/husein/Malaya/translation-en-ms/base/en-ms.subwords')

In [15]:
# Baseline: the original, untransformed frozen graph. This rebinds `g`, which
# previously held the return value of convert_graph_to_fp16.
g = load_graph('base-translation.pb')

In [16]:
# Input and greedy-output tensors of the converted graph. The names carry two
# 'import/' prefixes because that graph went through import_graph_def twice
# without an explicit name: once in load_graph_with_sess before it was
# re-serialized, and once more in load_graph.
x_optimized = g_optimized.get_tensor_by_name('import/import/Placeholder:0')
greedy_optimized = g_optimized.get_tensor_by_name('import/import/greedy:0')

In [17]:
# Input and greedy-output tensors of the baseline graph (imported once, so a
# single 'import/' prefix).
x = g.get_tensor_by_name('import/Placeholder:0')
greedy = g.get_tensor_by_name('import/greedy:0')

In [19]:
# One session per graph, so the baseline and the converted model can be timed
# against each other with the same input.
sess = tf.compat.v1.InteractiveSession(graph = g)
sess_optimized = tf.compat.v1.InteractiveSession(graph = g_optimized)

In [20]:
# Benchmark input: one English sentence encoded to subword ids, with id 1
# appended (presumably the end-of-sequence token; confirm against the vocabulary).
e = encoder.encode('Palestine, recognized officially as the State of Palestine by the United Nations and other entities, is a de jure sovereign state in Western Asia claiming the West Bank and Gaza Strip with Jerusalem as the designated capital, although its administrative center is currently located in Ramallah') + [1]

In [21]:
import time
from tqdm import tqdm

In [22]:
# Per-run latency (seconds) of greedy decoding on the baseline graph.
r = []

for _ in tqdm(range(10)):
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # and may jump if the system clock is adjusted during the measurement.
    before = time.perf_counter()
    sess.run(greedy, feed_dict = {x: [e]})
    r.append(time.perf_counter() - before)

100%|██████████| 10/10 [00:23<00:00,  2.34s/it]


In [23]:
import numpy as np

# Mean seconds per greedy decode for the baseline graph (over the runs in `r`).
np.mean(r)

2.3338852405548094

In [24]:
# Per-run latency (seconds) of greedy decoding on the converted graph.
# `r` is reset here, so the baseline timings are no longer available afterwards.
r = []

for _ in tqdm(range(10)):
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # and may jump if the system clock is adjusted during the measurement.
    before = time.perf_counter()
    sess_optimized.run(greedy_optimized, feed_dict = {x_optimized: [e]})
    r.append(time.perf_counter() - before)

100%|██████████| 10/10 [00:51<00:00,  5.16s/it]


In [25]:
# Mean seconds per greedy decode for the converted graph (over the runs in `r`).
np.mean(r)

5.159882307052612