In [None]:
import sys
import numpy as np
from PIL import Image
import tensorflow as tf
import torch
import torch.nn as nn

# https://github.com/tensorflow/models
sys.path.append('/home/dan/work/models/research/slim/nets/')
from mobilenet_v1 import mobilenet_v1, mobilenet_v1_arg_scope, slim

sys.path.append('../')
from detector.backbone import MobileNet

# Get an image

In [None]:
# Load the test image as a (1, 224, 224, 3) float32 batch with values in [0, 1].
# convert('RGB') guards against grayscale / RGBA / CMYK / palette files, which
# would otherwise yield the wrong number of channels for both networks.
with Image.open('dog.jpg') as pil_image:
    rgb_image = pil_image.convert('RGB').resize((224, 224))
image = (np.expand_dims(np.array(rgb_image), 0) / 255.0).astype('float32')

# Extract weights from TensorFlow

In [None]:
# Build the slim MobileNet V1 graph on the test image, restore the pretrained
# checkpoint, and fetch both the class probabilities and every variable value.
tf.reset_default_graph()

inputs = tf.constant(image)
with slim.arg_scope(mobilenet_v1_arg_scope(is_training=False, weight_decay=0.0)):
    # The TF network is fed 2*x - 1, i.e. pixels rescaled from [0, 1] to [-1, 1].
    logits, _ = mobilenet_v1(2.0 * inputs - 1.0, num_classes=1001, is_training=False)
    outputs = tf.nn.softmax(logits, axis=1)[0]

# Variable name with the ':0' output suffix stripped -> tf.Variable.
tf_variables = {}
for variable in tf.global_variables():
    tf_variables[variable.name.rsplit(':', 1)[0]] = variable

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, '../pretrained/mobilenet_v1_1.0_224.ckpt')
    # sess.run on a dict of variables returns a dict of numpy arrays.
    tf_outputs, numpy_weights = sess.run([outputs, tf_variables])

# Form used by the later cells: TF variable name -> torch.FloatTensor.
weights = {}
for var_name, value in numpy_weights.items():
    weights[var_name] = torch.FloatTensor(value)

# Map PyTorch module/parameter names to TensorFlow variable names

In [None]:
# PyTorch BatchNorm2d module name -> slim BatchNorm scope (with trailing '/').
# PyTorch block `layers.{b}` (0-based) corresponds to TF layer Conv2d_{b + 1};
# inside a block, index 1 is the depthwise BatchNorm and index 4 the pointwise one.
NUM_BLOCKS = 13

batch_norms = {'beginning.1': 'MobilenetV1/Conv2d_0/BatchNorm/'}
for block in range(NUM_BLOCKS):
    tf_layer = block + 1
    batch_norms[f'layers.{block}.layers.1'] = f'MobilenetV1/Conv2d_{tf_layer}_depthwise/BatchNorm/'
    batch_norms[f'layers.{block}.layers.4'] = f'MobilenetV1/Conv2d_{tf_layer}_pointwise/BatchNorm/'

# PyTorch parameter name -> TF variable name.
# Convolution kernels: stem conv, then per block the depthwise (index 0) and
# pointwise (index 3) kernels.
mapping = {'beginning.0.weight': 'MobilenetV1/Conv2d_0/weights'}
for block in range(NUM_BLOCKS):
    tf_layer = block + 1
    mapping[f'layers.{block}.layers.0.weight'] = f'MobilenetV1/Conv2d_{tf_layer}_depthwise/depthwise_weights'
    mapping[f'layers.{block}.layers.3.weight'] = f'MobilenetV1/Conv2d_{tf_layer}_pointwise/weights'

# BatchNorm affine parameters: PyTorch weight/bias are slim's gamma/beta.
for module_name, tf_scope in batch_norms.items():
    mapping[module_name + '.weight'] = tf_scope + 'gamma'
    mapping[module_name + '.bias'] = tf_scope + 'beta'

# Create a PyTorch model and transfer the weights

In [None]:
# Build the PyTorch backbone in inference mode and load the converted TF weights.
net = MobileNet()
net.eval()

with torch.no_grad():
    # BatchNorm running statistics are buffers rather than parameters, so they
    # are not covered by `mapping` and are copied here by module name.
    # NOTE(review): nn.BatchNorm2d defaults to eps=1e-5; confirm the backbone's
    # eps matches the epsilon of the slim BatchNorm layers being converted.
    loaded_batch_norms = set()
    for module_name, module in net.named_modules():
        if module_name not in batch_norms:
            continue
        assert isinstance(module, nn.BatchNorm2d), module_name
        tf_scope = batch_norms[module_name]
        for buffer, tf_name in ((module.running_mean, 'moving_mean'),
                                (module.running_var, 'moving_variance')):
            value = weights[tf_scope + tf_name]
            # Explicit check: copy_ would silently broadcast a mismatched shape.
            assert buffer.shape == value.shape, (module_name, tf_name, tuple(buffer.shape), tuple(value.shape))
            buffer.copy_(value)
        loaded_batch_norms.add(module_name)

    # Fail loudly instead of leaving BatchNorm layers with default statistics
    # (mean 0, var 1) when module names drift from the keys in `batch_norms`.
    missing = sorted(set(batch_norms) - loaded_batch_norms)
    assert not missing, f'BatchNorm modules missing from the PyTorch model: {missing}'

    # Every parameter must have a TF counterpart (KeyError otherwise).
    for param_name, param in net.named_parameters():
        w = weights[mapping[param_name]]
        if w.dim() == 4:
            if param_name.endswith('.layers.0.weight'):
                # Depthwise conv: TF (kh, kw, in_channels, multiplier) -> PyTorch (in_channels, multiplier, kh, kw).
                w = w.permute(2, 3, 0, 1)
            else:
                # Regular / pointwise conv: TF (kh, kw, in, out) -> PyTorch (out, in, kh, kw).
                w = w.permute(3, 2, 0, 1)
        assert param.shape == w.shape, (param_name, tuple(param.shape), tuple(w.shape))
        param.copy_(w)

In [None]:
# Classification head reproducing the TF logits layer: average pool with a 7x7
# window (presumably the backbone's c5 map is 7x7 for a 224x224 input — TODO
# confirm), a 1x1 conv to 1001 classes, and softmax over the channel dimension.
classifier = nn.Sequential(
    nn.AvgPool2d(7),
    nn.Conv2d(1024, 1001, 1),
    nn.Softmax(dim=1)
)
classifier.eval()

logits_scope = 'MobilenetV1/Logits/Conv2d_1c_1x1/'
# TF conv kernel (kh, kw, in, out) -> PyTorch (out, in, kh, kw).
logits_weight = weights[logits_scope + 'weights'].permute(3, 2, 0, 1)
logits_bias = weights[logits_scope + 'biases']

with torch.no_grad():
    # Explicit checks: copy_ would silently broadcast a mismatched shape.
    assert classifier[1].weight.shape == logits_weight.shape, tuple(logits_weight.shape)
    assert classifier[1].bias.shape == logits_bias.shape, tuple(logits_bias.shape)
    classifier[1].weight.copy_(logits_weight)
    classifier[1].bias.copy_(logits_bias)

# Compare predictions

In [None]:
# Run the PyTorch backbone + head on the same image, converting NHWC -> NCHW.
# no_grad: this is pure inference, so no autograd graph is needed.
# NOTE(review): the TF graph was fed 2*x - 1 (range [-1, 1]) while this tensor
# is in [0, 1]; presumably MobileNet rescales its input internally — confirm.
with torch.no_grad():
    batch = torch.tensor(image, dtype=torch.float32).permute(0, 3, 1, 2)
    features = net(batch)['c5']
    torch_outputs = classifier(features).squeeze().numpy()

In [None]:
# Top-1 prediction: class index on the first line, its probability on the
# second; the PyTorch value is printed first, the TensorFlow value second.
for summarize in (np.argmax, np.max):
    print(summarize(torch_outputs), summarize(tf_outputs))

In [None]:
# Ten most probable class indices (ordered least -> most probable),
# TensorFlow first, PyTorch second.
for probabilities in (tf_outputs, torch_outputs):
    top10 = np.argsort(probabilities)[-10:]
    print(top10)

In [None]:
# Element-wise agreement of the two probability vectors. Assert (rather than
# just display a boolean) so Restart & Run All stops before the checkpoint is
# saved if the conversion is wrong. The 0.1 tolerance is loose for
# probabilities, so the largest absolute difference is shown for inspection.
max_abs_diff = np.abs(torch_outputs - tf_outputs).max()
assert max_abs_diff < 1e-1, f'PyTorch and TF outputs differ by up to {max_abs_diff}'
max_abs_diff

# Save

In [None]:
# Persist only the converted backbone weights; the classifier head is not saved.
converted_state = net.state_dict()
torch.save(converted_state, '../pretrained/mobilenet.pth')