In [1]:
import tensorflow as tf
import numpy as np
%load_ext autoreload
%autoreload 2

## Create the ResNet-50 backbone

In [2]:
from kerod.model.backbone.resnet import ResNet50PytorchStyle

# Build the backbone without pretrained weights (weights=None); the tensorpack
# ImageNet values are assigned to its variables in a later cell.
model = ResNet50PytorchStyle(weights=None, input_shape=[None, None, 3])
# Dummy batch of one 800x800 RGB image.
# NOTE(review): not referenced by any visible cell — candidate for removal.
test = np.zeros(shape=(1, 800, 800, 3))


## Download the pretrained tensorpack ImageNet weights

In [3]:
!wget http://models.tensorpack.com/ResNet/ImageNet-ResNet50.npz

--2020-05-22 12:30:49--  http://models.tensorpack.com/ResNet/ImageNet-ResNet50.npz
Resolving models.tensorpack.com (models.tensorpack.com)... 185.207.105.29
Connecting to models.tensorpack.com (models.tensorpack.com)|185.207.105.29|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 95179469 (91M) [application/octet-stream]
Saving to: ‘ImageNet-ResNet50.npz.2’


2020-05-22 12:30:54 (18.6 MB/s) - ‘ImageNet-ResNet50.npz.2’ saved [95179469/95179469]



## Load the tensorpack weights into memory

In [3]:
# Copy every array out of the .npz archive into a plain dict so the values stay
# usable after the context manager closes the archive.
tensorpack_weights = {}
with np.load('ImageNet-ResNet50.npz') as npz_archive:
    tensorpack_weights.update(
        (array_name, npz_archive[array_name]) for array_name in npz_archive.keys()
    )

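## Map Keras variable names to tensorpack checkpoint keys

Each Keras variable name is rewritten with the substring replacements below, applied in order, to obtain the key of the matching array in the tensorpack `.npz` file.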

In [4]:
# Substring rewrites that turn a Keras variable name into the key of the
# matching array in the tensorpack .npz. The conversion loop applies them with
# str.replace in this dict's insertion order, so ORDER MATTERS:
#   * the five full 'conv1_*' names come first, so the first convolution is
#     renamed to tensorpack's 'conv0/...' before the generic rules below run;
#   * the generic rules then rename BatchNorm statistics and conv kernels in
#     every residual block;
#   * finally the 'resnet50/' scope prefix is stripped so the remainder
#     (e.g. 'group0/block0/conv1/W:0') matches the checkpoint key.
convert_to = {
    # First convolution + its BatchNorm: Keras 'conv1_*' -> tensorpack 'conv0/*'.
    'conv1_conv/kernel:0': 'conv0/W:0',
    'conv1_bn/gamma:0': 'conv0/bn/gamma:0',
    'conv1_bn/beta:0': 'conv0/bn/beta:0',
    'conv1_bn/moving_mean:0': 'conv0/bn/mean/EMA:0',
    'conv1_bn/moving_variance:0': 'conv0/bn/variance/EMA:0',
    # Generic suffix rules for the residual blocks (gamma/beta need no rename).
    'moving_variance:0': 'variance/EMA:0',
    'moving_mean:0':'mean/EMA:0',
    'kernel:0':'W:0',
    # Must stay last: drop the Keras model-scope prefix.
    'resnet50/': '', 
}

In [5]:
# Show the Keras variable names that the mapping above has to translate.
variable_names = [variable.name for variable in model.variables]
for variable_name in variable_names:
    print(variable_name)

conv1_conv/kernel:0
conv1_bn/gamma:0
conv1_bn/beta:0
conv1_bn/moving_mean:0
conv1_bn/moving_variance:0
resnet50/group0/block0/conv1/kernel:0
resnet50/group0/block0/conv1/bn/gamma:0
resnet50/group0/block0/conv1/bn/beta:0
resnet50/group0/block0/conv1/bn/moving_mean:0
resnet50/group0/block0/conv1/bn/moving_variance:0
resnet50/group0/block0/conv2/kernel:0
resnet50/group0/block0/conv2/bn/gamma:0
resnet50/group0/block0/conv2/bn/beta:0
resnet50/group0/block0/conv2/bn/moving_mean:0
resnet50/group0/block0/conv2/bn/moving_variance:0
resnet50/group0/block0/convshortcut/kernel:0
resnet50/group0/block0/conv3/kernel:0
resnet50/group0/block0/convshortcut/bn/gamma:0
resnet50/group0/block0/convshortcut/bn/beta:0
resnet50/group0/block0/convshortcut/bn/moving_mean:0
resnet50/group0/block0/convshortcut/bn/moving_variance:0
resnet50/group0/block0/conv3/bn/gamma:0
resnet50/group0/block0/conv3/bn/beta:0
resnet50/group0/block0/conv3/bn/moving_mean:0
resnet50/group0/block0/conv3/bn/moving_variance:0
resnet50/g

In [5]:
def to_tensorpack_name(keras_name, mapping):
    """Return the tensorpack checkpoint key for a Keras variable name.

    Every (keras_part, tensorpack_part) pair of *mapping* is applied with
    ``str.replace`` in insertion order, so earlier, more specific entries take
    effect before later, generic ones.
    """
    for keras_part, tensorpack_part in mapping.items():
        keras_name = keras_name.replace(keras_part, tensorpack_part)
    return keras_name


# Copy each tensorpack array into the matching Keras variable. `used` records the
# checkpoint keys already consumed, so a faulty mapping that sends two variables
# to the same key fails loudly instead of silently reusing one tensor.
used = set()
for var in model.variables:
    name_var = to_tensorpack_name(var.name, convert_to)
    if name_var in used:
        raise ValueError(
            '{} maps to {}, which was already assigned to another '
            'variable'.format(var.name, name_var))
    if name_var not in tensorpack_weights:
        raise KeyError(
            '{} maps to {}, which is not in the tensorpack checkpoint'.format(
                var.name, name_var))
    weights = tensorpack_weights[name_var]
    used.add(name_var)
    var.assign(weights)
    # Read the value back to confirm the assignment took effect.
    np.testing.assert_almost_equal(weights, var.numpy())


In [6]:
# Every model variable must have consumed its own checkpoint tensor; fewer
# distinct keys than variables means two variables shared one tensorpack key.
assert len(used) == len(model.variables), (
    'converted {} distinct tensors for {} model variables'.format(
        len(used), len(model.variables)))
len(used)

265

In [12]:
# Persist the converted weights; with no save_format given, Keras picks HDF5
# from the '.h5' suffix.
model.save_weights(filepath='resnet50_tensorpack_conversion.h5')

In [13]:
# Round-trip check: load the saved file into a *fresh* backbone and compare it
# with the converted model. Reloading into `model` itself, which already holds
# the converted values, could not reveal a broken or incomplete save.
reloaded_model = ResNet50PytorchStyle(input_shape=[None, None, 3], weights=None)
reloaded_model.load_weights('resnet50_tensorpack_conversion.h5')
assert len(reloaded_model.variables) == len(model.variables)
for converted_var, reloaded_var in zip(model.variables, reloaded_model.variables):
    np.testing.assert_array_equal(converted_var.numpy(), reloaded_var.numpy())

In [14]:
import hashlib

def hash_file(fpath, algorithm='sha256', chunk_size=65535):
    """Compute the hex digest of a file, reading it in chunks.

    Adapted from Keras' ``keras.utils.data_utils._hash_file`` [Copyright Keras].

    # Example
    ```python
        >>> hash_file('/path/to/empty_file')
        'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
    ```

    Arguments:

    - *fpath*: path to the file being hashed.
    - *algorithm*: any name accepted by ``hashlib.new`` (e.g. 'sha256', 'md5',
           'sha1'), or 'auto'. There is no reference digest to inspect here,
           so 'auto' resolves to 'sha256'. Defaults to 'sha256'.
    - *chunk_size*: bytes to read at a time, which keeps memory bounded for
           large files. Must be non-zero.

    Returns:
    The file hash as a lowercase hexadecimal string.

    Raises:
    ValueError: if *algorithm* is not supported by ``hashlib`` or
        *chunk_size* is 0.
    """
    if algorithm == 'auto':
        # Nothing to auto-detect from in this function: use the default.
        algorithm = 'sha256'
    if chunk_size == 0:
        # read(0) returns b'' immediately, which would silently hash no data.
        raise ValueError('chunk_size must be non-zero')
    # hashlib.new raises ValueError for unknown names rather than quietly
    # computing a digest with a different algorithm than requested.
    hasher = hashlib.new(algorithm)

    with open(fpath, 'rb') as fpath_file:
        for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
            hasher.update(chunk)

    return hasher.hexdigest()



In [15]:
# MD5 digest of the converted weight file, shown as the cell's output.
hash_file(fpath='resnet50_tensorpack_conversion.h5', algorithm='md5')

'3ffd584081cc56435a3689d12afd7cf9'