In [1]:
import argparse
import json
import mxnet as mx
import tensorflow as tf
from tensorflow.python.framework import graph_util

import numpy as np

from converter import Converter
import os
import time

%reload_ext autoreload
%autoreload 2

# MXNet operator name -> name of the Converter method that emits the TF op(s).
# The method is looked up with getattr only when the op is actually seen, so a
# Converter lacking a handler for an op the model never uses still works.
_OP_HANDLERS = {
    'BatchNorm': 'create_bn',
    'elemwise_add': 'create_elementwise',
    '_Plus': 'create_elementwise',
    'Activation': 'create_activation',
    'SoftmaxOutput': 'create_softmax',
    'Convolution': 'create_conv',
    'Pooling': 'create_pooling',
    'Flatten': 'create_flatten',
    'FullyConnected': 'create_fc',
    'L2Normalization': 'create_norm',
    'Concat': 'create_concat',
    'Reshape': 'create_reshape',
    'Crop': 'create_crop',
    'UpSampling': 'create_upsampling',  # alternative: create_upsampling_v2
    'SoftmaxActivation': 'create_softmaxactivation',
}

# (stride, index of the head passed through unchanged, index of the head fed
# into the bbox/landmark convolutions), in output order.
# NOTE(review): assumes the symbol's heads are ordered as the original
# hard-coded indices implied (heads 0-2 passed through, heads 3-5 used as conv
# inputs for stride 32/16/8) -- confirm against the symbol JSON.
_STRIDE_HEADS = ((32, 0, 3), (16, 1, 4), (8, 2, 5))


def _build_stride_outputs(passthrough_node, feature_node, stride):
    """Return ``[passthrough_node, bbox_pred, landmark_pred]`` for one stride.

    The weights/biases named ``face_rpn_{bbox,landmark}_pred_stride<N>_{weight,bias}``
    are fetched with ``tf.get_variable`` under ``reuse=tf.AUTO_REUSE`` and no
    shape/initializer -- presumably they were already created by the Converter
    (confirm). Each is applied to ``feature_node`` as a ``conv2d`` with strides
    ``[1, 1, 1, 1]`` and VALID padding, followed by a bias ``tf.add`` whose
    output is named ``face_rpn_{bbox,landmark}_pred_stride<N>_rev``.
    """
    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
        bbox_w = tf.get_variable("face_rpn_bbox_pred_stride%d_weight" % stride)
        bbox_b = tf.get_variable("face_rpn_bbox_pred_stride%d_bias" % stride)
        landmark_w = tf.get_variable("face_rpn_landmark_pred_stride%d_weight" % stride)
        landmark_b = tf.get_variable("face_rpn_landmark_pred_stride%d_bias" % stride)

    bbox = tf.nn.conv2d(feature_node, bbox_w, [1, 1, 1, 1], padding='VALID')
    bbox = tf.add(bbox, bbox_b, name="face_rpn_bbox_pred_stride%d_rev" % stride)

    landmark = tf.nn.conv2d(feature_node, landmark_w, [1, 1, 1, 1], padding='VALID')
    landmark = tf.add(landmark, landmark_b,
                      name="face_rpn_landmark_pred_stride%d_rev" % stride)

    return [passthrough_node, bbox, landmark]


def main(model_prefix, output_prefix, input_h=128, input_w=128,
         epoch=0, checkpoint_dir="./checkpoint"):
    """Convert an MXNet checkpoint to a TF graph, smoke-test it, and save it.

    Args:
        model_prefix: MXNet checkpoint prefix. ``<prefix>-symbol.json`` is parsed
            for the node list, and ``mx.model.load_checkpoint(prefix, epoch)``
            supplies the parameters.
        output_prefix: currently unused; kept for interface compatibility.
        input_h, input_w: spatial size of the ``data`` placeholder, whose shape
            is ``(1, input_h, input_w, 3)``.
        epoch: MXNet checkpoint epoch to load (default 0, the previous
            hard-coded value).
        checkpoint_dir: directory the TF checkpoint
            ``mnet.25-<input_h>x<input_w>`` is written to; created if missing.
    """
    # Parsing JSON is easier because it contains operator name
    with open(model_prefix + '-symbol.json', 'r') as f:
        js_model = json.load(f)
    _, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch)
    # Merge into a fresh dict so the arg_params mapping is not mutated.
    params = dict(arg_params)
    params.update(aux_params)
    tf.reset_default_graph()

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

    with tf.Session(config=config) as sess:
        tf_nodes = dict()
        # Workaround for input node
        input_data = tf.placeholder('float32', (1, input_h, input_w, 3), name='data')
        tf_nodes['data'] = input_data
        nodes = js_model['nodes']
        conv = Converter(tf_nodes, nodes, params)
        for node_idx, node in enumerate(nodes):
            op = node['op']
            print('Parsing node %s with operator %s and index %d' % (node['name'], op, node_idx))
            # Hack for older versions of MxNet
            if 'param' in node:
                node['attrs'] = node['param']
            if op == 'null':
                # Parameter / input entry: no TF op is created for it here.
                continue
            handler_name = _OP_HANDLERS.get(op)
            if handler_name is None:
                print("------------------unsupported op!!!----------------------")
                continue
            getattr(conv, handler_name)(node)

        # Each head entry's first element is the index of its node in `nodes`.
        output_node_names = [nodes[head[0]]['name'] for head in js_model['heads']]
        print("output node names:", output_node_names)
        print("------------------------------Output nodes:--------------------------")
        output_nodes = []
        for name in output_node_names:
            print(tf_nodes[name])
            output_nodes.append(tf_nodes[name])

        # Rebuild the bbox/landmark predictions per stride with explicit TF ops.
        output_nodes_rev = []
        for stride, passthrough_idx, feature_idx in _STRIDE_HEADS:
            output_nodes_rev.extend(_build_stride_outputs(
                output_nodes[passthrough_idx], output_nodes[feature_idx], stride))

        print("------------------------output nodes rev:----------------------")
        for node in output_nodes_rev:
            print(node)

        print("---------------------test inference-------------------")
        # Allocate the dummy input once, already float32, so the timings below
        # measure sess.run rather than array allocation and dtype conversion.
        dummy_input = np.zeros((1, input_h, input_w, 3), dtype=np.float32)
        for _ in range(5):
            start = time.time()
            out = sess.run(output_nodes_rev, feed_dict={input_data: dummy_input})
            end = time.time()
            print("inference time: %.6f" % (end - start))
        for arr in out:
            # NHWC -> NCHW so shapes can be compared with the MXNet outputs.
            print("output node shape:", arr.transpose(0, 3, 1, 2).shape)

        g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        print("number of variables (trainable & global):", len(g_vars))
        saver = tf.train.Saver(g_vars)
        # The TF1 Saver refuses to save when the parent directory is missing.
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver.save(sess,
                   os.path.join(checkpoint_dir, "mnet.25-%dx%d" % (input_h, input_w)),
                   write_meta_graph=True)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
input_h, input_w = 1080, 1920  # (H, W) of the exported graph's input placeholder
mx_prefix = "models/mnet.25"  # MXNet checkpoint prefix
tf_prefix = "xxx"  # forwarded as main()'s output_prefix argument
main(mx_prefix, tf_prefix, input_h=input_h, input_w=input_w)

Parsing node data with operator null and index 0
Parsing node mobilenet0_conv0_weight with operator null and index 1
Parsing node mobilenet0_conv0_fwd with operator Convolution and index 2
Instructions for updating:
Colocations handled automatically by placer.
Parsing node mobilenet0_batchnorm0_gamma with operator null and index 3
Parsing node mobilenet0_batchnorm0_beta with operator null and index 4
Parsing node mobilenet0_batchnorm0_running_mean with operator null and index 5
Parsing node mobilenet0_batchnorm0_running_var with operator null and index 6
Parsing node mobilenet0_batchnorm0_fwd with operator BatchNorm and index 7
Parsing node mobilenet0_relu0_fwd with operator Activation and index 8
Parsing node mobilenet0_conv1_weight with operator null and index 9
Parsing node mobilenet0_conv1_fwd with operator Convolution and index 10
Parsing node mobilenet0_batchnorm1_gamma with operator null and index 11
Parsing node mobilenet0_batchnorm1_beta with operator null and index 12
Parsing

Parsing node mobilenet0_relu16_fwd with operator Activation and index 136
Parsing node mobilenet0_conv17_weight with operator null and index 137
Parsing node mobilenet0_conv17_fwd with operator Convolution and index 138
Parsing node mobilenet0_batchnorm17_gamma with operator null and index 139
Parsing node mobilenet0_batchnorm17_beta with operator null and index 140
Parsing node mobilenet0_batchnorm17_running_mean with operator null and index 141
Parsing node mobilenet0_batchnorm17_running_var with operator null and index 142
Parsing node mobilenet0_batchnorm17_fwd with operator BatchNorm and index 143
Parsing node mobilenet0_relu17_fwd with operator Activation and index 144
Parsing node mobilenet0_conv18_weight with operator null and index 145
Parsing node mobilenet0_conv18_fwd with operator Convolution and index 146
Parsing node mobilenet0_batchnorm18_gamma with operator null and index 147
Parsing node mobilenet0_batchnorm18_beta with operator null and index 148
Parsing node mobilene

Parsing node rf_c3_det_context_conv3_1_relu with operator Activation and index 259
Parsing node rf_c3_det_context_conv3_2_weight with operator null and index 260
Parsing node rf_c3_det_context_conv3_2_bias with operator null and index 261
Parsing node rf_c3_det_context_conv3_2 with operator Convolution and index 262
Parsing node rf_c3_det_context_conv3_2_bn_gamma with operator null and index 263
Parsing node rf_c3_det_context_conv3_2_bn_beta with operator null and index 264
Parsing node rf_c3_det_context_conv3_2_bn_moving_mean with operator null and index 265
Parsing node rf_c3_det_context_conv3_2_bn_moving_var with operator null and index 266
Parsing node rf_c3_det_context_conv3_2_bn with operator BatchNorm and index 267
Parsing node rf_c3_det_concat with operator Concat and index 268
Parsing node rf_c3_det_concat_relu with operator Activation and index 269
Parsing node face_rpn_cls_score_stride32_weight with operator null and index 270
Parsing node face_rpn_cls_score_stride32_bias wi

Parsing node rf_c1_red_conv_relu with operator Activation and index 367
Parsing node rf_c2_upsampling with operator UpSampling and index 368
Parsing node crop1 with operator Crop and index 369
Parsing node plus1 with operator elemwise_add and index 370
-------------adding just 2 nodes, replace with tf.add-------------------
Parsing node rf_c1_aggr_weight with operator null and index 371
Parsing node rf_c1_aggr_bias with operator null and index 372
Parsing node rf_c1_aggr with operator Convolution and index 373
Parsing node rf_c1_aggr_bn_gamma with operator null and index 374
Parsing node rf_c1_aggr_bn_beta with operator null and index 375
Parsing node rf_c1_aggr_bn_moving_mean with operator null and index 376
Parsing node rf_c1_aggr_bn_moving_var with operator null and index 377
Parsing node rf_c1_aggr_bn with operator BatchNorm and index 378
Parsing node rf_c1_aggr_relu with operator Activation and index 379
Parsing node rf_c1_det_conv1_weight with operator null and index 380
Parsing 