In [None]:
import os

import numpy as np
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.tools import inspect_checkpoint as chkp

# Checkpoint path of the source ResNet-50 v2 ImageNet weights; opened with
# pywrap_tensorflow.NewCheckpointReader in the conversion cell.
imgnet_ckp = "../data/resnet_imagenet_v2_fp32_20181001/model.ckpt-225207"
# Destination path for the renamed (`backbone/...`) checkpoint written by
# tf.train.Saver; the verification cell reads the same checkpoint back.
save_path = '../data/resnet_v2_imagenet_transformed/resnet50_v2.ckpt'

In [None]:
#print('Original resnet-50-v2 weights: ')
#chkp.print_tensors_in_checkpoint_file(imgnet_ckp, tensor_name='', all_tensors=False)

In [None]:

def get_filters(reader, tensor_key):
    """Load a single tensor from a checkpoint reader as a NumPy array.

    Args:
        reader: any object exposing ``get_tensor(name)``, such as the result
            of ``pywrap_tensorflow.NewCheckpointReader``.
        tensor_key: full name of the tensor stored in the checkpoint.

    Returns:
        A freshly copied ``np.ndarray`` holding the tensor's values. The key
        and shape are printed as a progress log line.
    """
    weights = np.array(reader.get_tensor(tensor_key))
    print('Read filter {0}, {1}'.format(tensor_key, weights.shape))
    return weights

In [None]:
# Copy every tensor of the source ResNet-50 v2 checkpoint into a variable
# named with the `backbone/...` scheme, then save those variables as a new
# checkpoint.
#
# The source checkpoint names layers by a running counter (`conv2d`,
# `conv2d_1`, ..., `batch_normalization`, `batch_normalization_1`, ...).
# Per stage there is one shortcut conv, then `n_blocks` blocks of three
# (conv kernel, batch-norm) pairs; walking the counters in that order
# reproduces the hand-written mapping: C2 uses conv2d_1..conv2d_10 and
# batch_normalization..batch_normalization_8, C3 conv2d_11..23 / bn 9..20,
# C4 conv2d_24..42 / bn 21..38, C5 conv2d_43..52 / bn 39..47, tail bn 48.

BN_PARAMS = ('beta', 'gamma', 'moving_mean', 'moving_variance')

# (stage name, number of bottleneck blocks) for ResNet-50.
STAGES = (('C2', 3), ('C3', 4), ('C4', 6), ('C5', 3))


def _layer_key(layer_type, index):
    """Return the source scope of the `index`-th auto-named `layer_type` layer.

    The first layer of a type is unsuffixed (`conv2d`); later ones are
    numbered (`conv2d_1`, `conv2d_2`, ...).
    """
    suffix = '' if index == 0 else '_{}'.format(index)
    return 'resnet_model/{}{}'.format(layer_type, suffix)


def _copy(src_reader, src_key, dst_name):
    """Create variable `dst_name` initialised from source tensor `src_key`."""
    return tf.Variable(get_filters(src_reader, src_key), name=dst_name)


def _copy_batch_norm(src_reader, bn_index, dst_scope):
    """Copy the four parameters of one batch-norm layer under `dst_scope`."""
    bn_key = _layer_key('batch_normalization', bn_index)
    for param in BN_PARAMS:
        _copy(src_reader, '{}/{}'.format(bn_key, param),
              '{}/{}'.format(dst_scope, param))


def _build_renamed_variables(src_reader):
    """Create all renamed variables in the current default graph.

    Returns:
        (next unused conv index, index of the tail batch-norm) so the caller
        can check that the expected number of layers was consumed.
    """
    # Stem convolution (no batch-norm is mapped with it).
    _copy(src_reader, _layer_key('conv2d', 0) + '/kernel', 'backbone/kernel')

    conv_idx, bn_idx = 1, 0
    for stage, n_blocks in STAGES:
        # Block 1 of each stage has a shortcut conv, stored just before that
        # block's conv1 in the source numbering.
        _copy(src_reader, _layer_key('conv2d', conv_idx) + '/kernel',
              'backbone/{}/block1/shortcut/kernel'.format(stage))
        conv_idx += 1
        for block in range(1, n_blocks + 1):
            for conv in range(1, 4):
                dst_scope = 'backbone/{}/block{}/conv{}'.format(stage, block, conv)
                _copy(src_reader, _layer_key('conv2d', conv_idx) + '/kernel',
                      dst_scope + '/kernel')
                _copy_batch_norm(src_reader, bn_idx, dst_scope)
                conv_idx += 1
                bn_idx += 1

    # Final batch-norm and dense classifier.
    _copy_batch_norm(src_reader, bn_idx, 'backbone/tail')
    _copy(src_reader, 'resnet_model/dense/bias', 'backbone/tail/dense/bias')
    _copy(src_reader, 'resnet_model/dense/kernel', 'backbone/tail/dense/kernel')
    return conv_idx, bn_idx


reader = pywrap_tensorflow.NewCheckpointReader(imgnet_ckp)

# Build in a fresh graph: creating the variables in the default graph means a
# second run of this cell adds uniquified duplicates (`backbone/kernel_1`, ...)
# that tf.train.Saver() would also write into the checkpoint.
renamed_graph = tf.Graph()
with renamed_graph.as_default():
    n_conv, tail_bn = _build_renamed_variables(reader)
    # conv2d..conv2d_52 consumed; tail batch-norm is batch_normalization_48.
    assert (n_conv, tail_bn) == (53, 48), (n_conv, tail_bn)
    print('Renamed kernels created.')
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

# Saver.save cannot create missing parent directories.
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

with tf.Session(graph=renamed_graph) as sess:
    sess.run(init_op)
    saver.save(sess=sess, save_path=save_path, write_meta_graph=False)

print('Saved ckpt to {}'.format(save_path))

In [None]:
# Sanity-check the written checkpoint: list its contents, then confirm that a
# renamed tensor is identical to the source tensor it was copied from.
print('Renamed weights: ')
chkp.print_tensors_in_checkpoint_file(save_path, tensor_name='', all_tensors=False)
# Separate name so `reader` (the source checkpoint) is not clobbered; otherwise
# re-running this cell would look up source keys in the renamed checkpoint.
renamed_reader = pywrap_tensorflow.NewCheckpointReader(save_path)
nu_a = np.array(renamed_reader.get_tensor('backbone/C2/block1/conv1/kernel'))
# backbone/C2/block1/conv1/kernel was copied from resnet_model/conv2d_2/kernel.
src_a = np.array(reader.get_tensor('resnet_model/conv2d_2/kernel'))
np.testing.assert_array_equal(nu_a, src_a)
nu_a.shape