In [1]:
import tensorflow as tf
import numpy as np
import os
from tensorflow.python.tools import inspect_checkpoint as chkp
from tensorflow.python import pywrap_tensorflow

# Hide all CUDA devices from this process (the id '-1' matches no GPU);
# presumably so the conversion below runs on CPU only.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# Source checkpoint: opened with NewCheckpointReader in the conversion cell.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
imgnet_ckp = "/work/wangyu/imagenet/resnet_imagenet_v2_fp32_20181001/model.ckpt-225207"
# Destination checkpoint prefix: passed to Saver.save at the end of the conversion cell.
save_path = '/work/wangyu/imagenet/resnet_v2_imagenet_transformed_mask/resnet50_v2_mask.ckpt'

In [2]:
def get_filters(reader, tensor_key):
    """Fetch one tensor from a checkpoint reader and log its shape.

    Args:
        reader: object exposing ``get_tensor(name)`` (here a TF checkpoint reader).
        tensor_key: name of the tensor to fetch.

    Returns:
        The tensor as a NumPy array.
    """
    tensor = np.array(reader.get_tensor(tensor_key))
    # One log line per tensor, so the notebook output doubles as a conversion trace.
    print('Read filter {0}, {1}'.format(tensor_key, tensor.shape))
    return tensor

def trans_filters(kernels):
    """Replicate conv kernels three times along both spatial axes.

    Intended for 1x1 kernels: a [1, 1, Cin, Cout] kernel becomes a
    [3, 3, Cin, Cout] kernel in which every spatial tap holds the original
    1x1 weights. In general a [H, W, Cin, Cout] input yields
    [3*H, 3*W, Cin, Cout].

    NOTE(review): the replicated taps are not rescaled (e.g. by 1/9);
    confirm that this is the intended initialisation.

    Args:
        kernels: 4-D array-like laid out as [H, W, Cin, Cout].

    Returns:
        A new array of shape [3*H, 3*W, Cin, Cout] with the input's dtype.

    Raises:
        ValueError: if ``kernels`` is not 4-D.
    """
    kernels = np.asarray(kernels)
    if kernels.ndim != 4:
        raise ValueError(
            'expected a 4-D kernel [H, W, Cin, Cout], got shape {}'.format(kernels.shape))
    # A single tile over the two spatial axes replaces the per-filter
    # tile + stack loop and also works when there are zero output channels.
    return np.tile(kernels, (3, 3, 1, 1))

In [3]:
def trans_init_filters(kernels):
    """Extend a [7, 7, 3, 64] first-layer kernel to [7, 7, 4, 64].

    The three existing input channels are kept unchanged. A fourth input
    channel is appended to every output filter and filled with noise drawn
    from N(0, 0.01) and clipped to [-0.01, 0.01]. Noise comes from NumPy's
    global random state, so seed it beforehand for reproducible results.

    Args:
        kernels: 4-D array-like of shape [7, 7, 3, 64] ([H, W, Cin, Cout]).

    Returns:
        float32 array of shape [7, 7, 4, 64].

    Raises:
        ValueError: if the input is not 4-D, the kernel is not 7x7, the
            resulting input-channel count is not 4, or Cout is not 64.
    """
    kernels = np.asarray(kernels)
    if kernels.ndim != 4:
        raise ValueError(
            'expected a 4-D kernel [H, W, Cin, Cout], got shape {}'.format(kernels.shape))
    height, width, c_in, c_out = kernels.shape

    # Validate before doing any work: the appended noise is hard-wired to 7x7,
    # so a wrong kernel size would otherwise fail inside np.concatenate with an
    # unrelated message and never reach these checks.
    if height != width or height != 7:
        raise ValueError('kernel size {} != 7'.format(height))
    if c_in + 1 != 4:
        raise ValueError('In channel {} != 4'.format(c_in + 1))
    if c_out != 64:
        raise ValueError('Out channel {} != 64'.format(c_out))

    # Draw all filters' noise at once as [Cout, 7, 7]; values are produced in
    # the same order as one [7, 7] draw per output filter.
    noise = np.random.normal(loc=0.0, scale=0.01, size=[c_out, 7, 7])
    noise = np.clip(noise, -0.01, 0.01)
    extra_channel = np.transpose(noise, (1, 2, 0))[:, :, np.newaxis, :]  # [7,7,1,Cout]

    # Single concatenate along the input-channel axis instead of growing the
    # result one filter at a time.
    final_filter = np.concatenate((kernels, extra_channel), axis=2)  # [7,7,4,Cout]
    return final_filter.astype(np.float32)

In [4]:
# Convert the ImageNet ResNet checkpoint: read every tensor of the initial conv
# and of stages C2-C4 from `imgnet_ckp`, re-create it as a variable under the
# new 'temp_begin' / 'search_begin' / 'backbone/...' names, and save the result
# to `save_path`.

# Names of the per-layer batch-norm tensors copied next to each conv kernel.
BN_TENSORS = ('beta', 'gamma', 'moving_mean', 'moving_variance')


def _src_scope(layer, index):
    """Return the source-checkpoint scope of the index-th `layer`.

    The first layer carries no suffix ('resnet_model/conv2d'); later ones are
    suffixed with '_<index>' ('resnet_model/conv2d_1', ...).
    """
    suffix = '' if index == 0 else '_' + str(index)
    return 'resnet_model/' + layer + suffix


def _copy_kernel(reader, conv_idx, dst_scope, transform=None):
    """Create variable '<dst_scope>/kernel' from source conv number `conv_idx`.

    If `transform` is given, it is applied to the kernel before it is stored.
    """
    kernel = get_filters(reader, _src_scope('conv2d', conv_idx) + '/kernel')
    if transform is not None:
        kernel = transform(kernel)
    return tf.Variable(kernel, name=dst_scope + '/kernel')


def _copy_conv_bn(reader, conv_idx, bn_idx, dst_scope):
    """Copy one conv kernel plus the four tensors of batch-norm layer `bn_idx`."""
    _copy_kernel(reader, conv_idx, dst_scope)
    bn_scope = _src_scope('batch_normalization', bn_idx)
    for tensor_name in BN_TENSORS:
        tf.Variable(get_filters(reader, bn_scope + '/' + tensor_name),
                    name=dst_scope + '/' + tensor_name)


def _copy_stage(reader, stage, num_blocks, conv_idx, bn_idx, shortcut_transform=None):
    """Copy one backbone stage and return the next free (conv_idx, bn_idx).

    A stage consists of the block-1 shortcut kernel (source conv `conv_idx`)
    followed by `num_blocks` blocks of three conv + batch-norm pairs that use
    consecutive source indices.
    """
    stage_scope = 'backbone/' + stage
    _copy_kernel(reader, conv_idx, stage_scope + '/block1/shortcut', shortcut_transform)
    conv_idx += 1
    for block_i in range(1, num_blocks + 1):
        for conv_i in range(1, 4):
            dst_scope = stage_scope + '/block' + str(block_i) + '/conv' + str(conv_i)
            _copy_conv_bn(reader, conv_idx, bn_idx, dst_scope)
            conv_idx += 1
            bn_idx += 1
    return conv_idx, bn_idx


# Fixed seed so the random 4th input channel (and therefore the saved
# checkpoint) is identical on every run.
np.random.seed(0)

reader = pywrap_tensorflow.NewCheckpointReader(imgnet_ckp)

# Build in a fresh graph: re-running this cell on the shared default graph
# would create duplicate variables with '_1'-suffixed names, and the Saver
# would write those as well.
graph = tf.Graph()
with graph.as_default():
    ################## Initial conv ##################
    # Two independently extended copies of the first-layer kernel.
    stem_kernel = get_filters(reader, 'resnet_model/conv2d/kernel')
    templar_f = trans_init_filters(stem_kernel)
    tf.Variable(templar_f, name='temp_begin/kernel')
    search_f = trans_init_filters(stem_kernel)
    tf.Variable(search_f, name='search_begin/kernel')

    ################## C2 / C3 / C4 ##################
    # Source indices run consecutively through the stages:
    #   C2: shortcut conv 1,  convs 2-10,  batch norms 0-8   (3 blocks)
    #   C3: shortcut conv 11, convs 12-23, batch norms 9-20  (4 blocks)
    #   C4: shortcut conv 24, convs 25-42, batch norms 21-38 (6 blocks)
    conv_idx, bn_idx = 1, 0
    conv_idx, bn_idx = _copy_stage(reader, 'C2', 3, conv_idx, bn_idx)
    # Only the C3 shortcut kernel is passed through trans_filters.
    conv_idx, bn_idx = _copy_stage(reader, 'C3', 4, conv_idx, bn_idx,
                                   shortcut_transform=trans_filters)
    conv_idx, bn_idx = _copy_stage(reader, 'C4', 6, conv_idx, bn_idx)

    print('Renamed kernels created.')

    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

# save to .ckpt file
# Saver.save does not create missing parent directories, so make sure it exists.
save_dir = os.path.dirname(save_path)
if save_dir and not os.path.isdir(save_dir):
    os.makedirs(save_dir)

with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    saver.save(sess=sess, save_path=save_path, write_meta_graph=False)

print('Saved ckpt to {}'.format(save_path))

Read filter resnet_model/conv2d/kernel, (7, 7, 3, 64)
Read filter resnet_model/conv2d_1/kernel, (1, 1, 64, 256)
Read filter resnet_model/conv2d_2/kernel, (1, 1, 64, 64)
Read filter resnet_model/batch_normalization/beta, (64,)
Read filter resnet_model/batch_normalization/gamma, (64,)
Read filter resnet_model/batch_normalization/moving_mean, (64,)
Read filter resnet_model/batch_normalization/moving_variance, (64,)
Read filter resnet_model/conv2d_3/kernel, (3, 3, 64, 64)
Read filter resnet_model/batch_normalization_1/beta, (64,)
Read filter resnet_model/batch_normalization_1/gamma, (64,)
Read filter resnet_model/batch_normalization_1/moving_mean, (64,)
Read filter resnet_model/batch_normalization_1/moving_variance, (64,)
Read filter resnet_model/conv2d_4/kernel, (1, 1, 64, 256)
Read filter resnet_model/batch_normalization_2/beta, (64,)
Read filter resnet_model/batch_normalization_2/gamma, (64,)
Read filter resnet_model/batch_normalization_2/moving_mean, (64,)
Read filter resnet_model/batc

Read filter resnet_model/batch_normalization_25/gamma, (256,)
Read filter resnet_model/batch_normalization_25/moving_mean, (256,)
Read filter resnet_model/batch_normalization_25/moving_variance, (256,)
Read filter resnet_model/conv2d_30/kernel, (1, 1, 256, 1024)
Read filter resnet_model/batch_normalization_26/beta, (256,)
Read filter resnet_model/batch_normalization_26/gamma, (256,)
Read filter resnet_model/batch_normalization_26/moving_mean, (256,)
Read filter resnet_model/batch_normalization_26/moving_variance, (256,)
Read filter resnet_model/conv2d_31/kernel, (1, 1, 1024, 256)
Read filter resnet_model/batch_normalization_27/beta, (1024,)
Read filter resnet_model/batch_normalization_27/gamma, (1024,)
Read filter resnet_model/batch_normalization_27/moving_mean, (1024,)
Read filter resnet_model/batch_normalization_27/moving_variance, (1024,)
Read filter resnet_model/conv2d_32/kernel, (3, 3, 256, 256)
Read filter resnet_model/batch_normalization_28/beta, (256,)
Read filter resnet_model/