In [156]:
%load_ext autoreload
%autoreload 2

import torch
import tensorflow as tf
from mobnetv1 import MobileNetV1
import numpy as np
import cv2
from torch import nn


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def load_graph_def(frozen_graph_filename):
    """Parse a frozen GraphDef file and import it into the current default graph.

    Nodes are imported with ``name=''``, so tensor names keep no prefix
    (they can be looked up as e.g. 'input_1:0').

    Args:
        frozen_graph_filename: path to a serialized ``GraphDef`` protobuf (``.pb``).

    Returns:
        The parsed ``tf.compat.v1.GraphDef`` (callers that ignore it are unaffected).
    """
    with tf.io.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    # The proto is fully parsed, so the file can be closed before importing.
    tf.import_graph_def(graph_def, name='')
    return graph_def


# NOTE(review): `model_files` is not referenced in the visible cells; kept in case later code uses it.
model_files = {
    'model_weights/agre.pb': 'agr'
}

# Import the frozen graph into the current default graph.
load_graph_def('model_weights/agre.pb')

# Snapshot every Const op's value (the frozen weights) keyed by op name, in graph order.
constant_values = {}

with tf.compat.v1.Session() as sess:
    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
    # A single run fetches all constants at once instead of one graph execution per op;
    # list fetches come back in the same order, so the key order is unchanged.
    values = sess.run([op.outputs[0] for op in constant_ops])
    constant_values.update(zip((op.name for op in constant_ops), values))




In [3]:
# A separate, isolated copy of the frozen graph that the inference helpers below run against.
full_graph = tf.Graph()
with full_graph.as_default():
    load_graph_def('model_weights/agre.pb')

In [4]:
def get_image(file, mean_bgr=(103.939, 116.779, 123.68)):
    """Load one image and preprocess it into a single-image batch for the TF graph.

    OpenCV returns pixels in BGR channel order. A per-channel mean is subtracted in
    that order; the default values are the ImageNet BGR means used by Keras'
    caffe-style preprocessing. No resizing is done, so the file must already be
    the size the network expects.

    Args:
        file: path to the image.
        mean_bgr: per-channel means to subtract, in B, G, R order.

    Returns:
        float64 array of shape (1, H, W, 3).

    Raises:
        FileNotFoundError: the file is missing or could not be decoded.
    """
    image = cv2.imread(file)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f'could not read image: {file}')
    centered = image.astype(float) - np.asarray(mean_bgr, dtype=float)
    return np.expand_dims(centered, axis=0)

In [5]:
def get_layer(layer_name, image_path='data/0.jpg'):
    """Evaluate one tensor of ``full_graph`` on a single preprocessed image.

    Args:
        layer_name: tensor name in ``full_graph``, e.g. 'conv_dw_2/depthwise:0'.
        image_path: image fed to 'input_1:0' (defaults to the sample image).

    Returns:
        The numpy value of the requested tensor.
    """
    f_out = full_graph.get_tensor_by_name(layer_name)
    in_img = full_graph.get_tensor_by_name('input_1:0')
    x = get_image(image_path)
    # The context manager closes the session even if sess.run raises.
    with tf.compat.v1.Session(graph=full_graph) as sess:
        return sess.run(f_out, feed_dict={in_img: x})

In [6]:
# Reference outputs from the frozen TF graph for the sample image: pooled backbone
# features plus the age (softmax) and gender (sigmoid) heads.
features_out = full_graph.get_tensor_by_name('global_pooling/Mean:0')
age_out = full_graph.get_tensor_by_name('age_pred/Softmax:0')
gender_out = full_graph.get_tensor_by_name('gender_pred/Sigmoid:0')
in_img = full_graph.get_tensor_by_name('input_1:0')
x = get_image('data/0.jpg')
# `with` guarantees the session is closed even if the run fails.
with tf.compat.v1.Session(graph=full_graph) as sess:
    features, age, gender = sess.run([features_out, age_out, gender_out], feed_dict={in_img: x})


In [161]:
def get_age(age_score: np.ndarray) -> int:
    """Collapse age-bin probabilities into one integer age estimate.

    Computes the expectation ``sum_k (k + 0.5) * p_k``, where ``k + 0.5`` is
    presumably the centre of a one-year bin, then truncates it with ``int()``.

    Args:
        age_score: probabilities of shape (batch, n_bins) or (n_bins,). For
            batched input only the first sample is used.

    Returns:
        The truncated expected age.
    """
    scores = np.asarray(age_score)
    if scores.ndim > 1:
        scores = scores[0]
    bin_centres = np.arange(scores.shape[-1]) + 0.5
    # The expectation does not depend on bin order, so no sorting is needed.
    return int(np.dot(bin_centres, scores))

In [48]:
# TF depthwise kernels are laid out (kh, kw, in_channels, channel_multiplier).
np.shape(constant_values['conv_dw_1/depthwise_kernel'])

(3, 3, 32, 1)

In [47]:
# PyTorch counterpart of the TF depthwise kernel: (out_channels, in_channels // groups, kh, kw).
# A fresh model is built so this cell does not depend on `state_dict` from a later cell.
# Sub-layer 0 of each block is a parameter-free ZeroPad2d, so the depthwise conv weight is at index 1.
MobileNetV1(3, 7).state_dict()['model.1.1.weight'].shape

torch.Size([32, 1, 3, 3])

In [123]:
def load_bn(model, state_dict, sd_keys, constant_values, cv_keys):
    """Copy BatchNorm parameters from the TF constants into ``state_dict`` in place.

    Each key 'model.<idx1>.<idx2>.<param>' in ``sd_keys`` whose module
    ``model.model[idx1][idx2]`` is an ``nn.BatchNorm2d`` is filled from:

    * block 0: 'conv1_bn/<tf_param>';
    * later blocks: 'conv_dw_<idx1>_bn/...' at sub-layer 2 and
      'conv_pw_<idx1>_bn/...' at sub-layer 5.

    'num_batches_tracked' has no TF counterpart and is left untouched.

    Args:
        model: module exposing ``model.model`` as nested indexable containers.
        state_dict: mapping updated in place (typically ``model.state_dict()``).
        sd_keys: keys to consider; each must split into exactly four '.'-separated parts.
        constant_values: TF constant name -> numpy array.
        cv_keys: unused; kept for signature compatibility.

    Raises:
        ValueError: a BatchNorm2d sits at a position with no known TF name.
            (Previously this printed 'Error BN' and stopped, leaving a half-filled dict.)
    """
    torch_tf_params = {
        'weight': 'gamma',
        'bias': 'beta',
        'running_mean': 'moving_mean',
        'running_var': 'moving_variance',
    }

    for sd_k in sd_keys:
        _, idx1, idx2, label = sd_k.split('.')
        idx1, idx2 = int(idx1), int(idx2)
        if not isinstance(model.model[idx1][idx2], nn.BatchNorm2d):
            continue
        if label == 'num_batches_tracked':
            continue
        if idx1 == 0:
            tf_layer = 'conv1_bn'
        elif idx2 in (2, 5):
            conv_label = 'dw' if idx2 == 2 else 'pw'
            tf_layer = f'conv_{conv_label}_{idx1}_bn'
        else:
            raise ValueError(f'Error BN: no TF BatchNorm mapping for {sd_k!r}')
        tf_label = torch_tf_params[label]
        state_dict[sd_k] = torch.from_numpy(constant_values[f'{tf_layer}/{tf_label}'])

def load_conv(model, state_dict, sd_keys, constant_values, cv_keys):
    """Copy convolution kernels from the TF constants into ``state_dict`` in place.

    Each key 'model.<idx1>.<idx2>.weight' in ``sd_keys`` whose module is an
    ``nn.Conv2d`` is filled from:

    * block 0: 'conv1/kernel';
    * later blocks: 'conv_dw_<idx1>/depthwise_kernel' at sub-layer 1 and
      'conv_pw_<idx1>/kernel' at sub-layer 4.

    TF stores regular kernels as (kh, kw, in, out) and depthwise kernels as
    (kh, kw, in, multiplier). Both are rearranged to PyTorch's
    (out, in // groups, kh, kw).

    Args:
        model: module exposing ``model.model`` as nested indexable containers.
        state_dict: mapping updated in place (typically ``model.state_dict()``).
        sd_keys: keys to consider; each must split into exactly four '.'-separated parts.
        constant_values: TF constant name -> numpy array.
        cv_keys: unused; kept for signature compatibility.

    Raises:
        ValueError: a Conv2d sits at a position with no known TF name, or a
            non-'weight' parameter (e.g. a bias) would receive a kernel.
            (Previously this printed 'Error CV' and stopped silently.)
    """
    for sd_k in sd_keys:
        _, idx1, idx2, label = sd_k.split('.')
        idx1, idx2 = int(idx1), int(idx2)
        if not isinstance(model.model[idx1][idx2], nn.Conv2d):
            continue
        if label != 'weight':
            # Only kernels have a TF source here; writing one into a bias slot would corrupt it.
            raise ValueError(f'Error CV: no TF mapping for {sd_k!r}')
        if idx1 == 0:
            kernel = np.transpose(constant_values['conv1/kernel'], (3, 2, 0, 1))
        elif idx2 == 1:
            kernel = np.transpose(
                constant_values[f'conv_dw_{idx1}/depthwise_kernel'], (2, 3, 0, 1))
            # (in, mult, kh, kw) -> (in * mult, 1, kh, kw). TF numbers depthwise outputs
            # channel * mult + m, which matches PyTorch's grouped-conv order; a no-op for mult == 1.
            kernel = kernel.reshape(-1, 1, *kernel.shape[2:])
        elif idx2 == 4:
            kernel = np.transpose(constant_values[f'conv_pw_{idx1}/kernel'], (3, 2, 0, 1))
        else:
            raise ValueError(f'Error CV: no TF mapping for {sd_k!r}')
        state_dict[sd_k] = torch.from_numpy(kernel)

def load_model_backbone(model, state_dict, constant_values):
    """Fill the backbone ('model.*') entries of ``state_dict`` from the TF constants.

    BatchNorm parameters and conv kernels are written in place by ``load_bn``
    and ``load_conv``. Head parameters (fc/age/gender/race) are not touched.

    Args:
        model: the PyTorch MobileNetV1 whose layout the keys describe.
        state_dict: mapping updated in place (typically ``model.state_dict()``).
        constant_values: TF constant name -> numpy array.
    """
    # Select backbone keys by prefix instead of dropping a hard-coded number of
    # trailing head entries, so a change in head size cannot shift the split.
    sd_keys = [k for k in state_dict if k.startswith('model.')]
    # Not read by the loaders; still passed to keep their signatures unchanged.
    cv_keys = list(constant_values.keys())[:-9]
    load_bn(model, state_dict, sd_keys, constant_values, cv_keys)
    load_conv(model, state_dict, sd_keys, constant_values, cv_keys)

In [157]:
model = MobileNetV1(3, 7)
state_dict = model.state_dict()
# Backbone keys (everything but the 8 head entries); displayed by a later cell.
sd_keys = list(state_dict.keys())[:-8]
load_model_backbone(model, state_dict, constant_values)

# PyTorch head prefix -> TF Dense layer name. Keras Dense kernels are (in, out) while
# nn.Linear weights are (out, in), hence the transpose.
HEAD_LAYERS = {
    'fc.1': 'feats',
    'age.0': 'age_pred',
    'gender.0': 'gender_pred',
    'race.0': 'ethnicity_pred',
}
for pt_prefix, tf_prefix in HEAD_LAYERS.items():
    state_dict[f'{pt_prefix}.weight'] = torch.from_numpy(
        np.transpose(constant_values[f'{tf_prefix}/kernel']))
    state_dict[f'{pt_prefix}.bias'] = torch.from_numpy(constant_values[f'{tf_prefix}/bias'])

# Backbone and heads are loaded in one strict pass.
model.load_state_dict(state_dict)
sd_new = model.state_dict()
model.eval()

MobileNetV1(
  (model): Sequential(
    (0): Sequential(
      (0): ZeroPad2d((0, 1, 0, 1))
      (1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (2): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (3): ReLU6(inplace=True)
    )
    (1): Sequential(
      (0): ZeroPad2d((1, 1, 1, 1))
      (1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
      (2): BatchNorm2d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (3): ReLU6(inplace=True)
      (4): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (5): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (6): ReLU6(inplace=True)
    )
    (2): Sequential(
      (0): ZeroPad2d((0, 1, 0, 1))
      (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), groups=64, bias=False)
      (2): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (3):

In [143]:
# Width of the TF 'feats' Dense layer, to check against the PyTorch fc head.
np.shape(constant_values['feats/bias'])

(256,)

In [145]:
# PyTorch fc head bias size; should match the TF 'feats/bias' shape.
model.fc[1].bias.size()

torch.Size([256])

In [60]:
# The preprocessed NHWC numpy batch becomes an NCHW float32 tensor for the PyTorch model.
x_tensor = torch.from_numpy(x).float().permute(0, 3, 1, 2)

In [61]:
# Names of all Const ops captured from the frozen graph (the TF weight names the loaders look up).
constant_values.keys()

dict_keys(['conv1_pad/Pad/paddings', 'conv1/kernel', 'conv1_bn/gamma', 'conv1_bn/beta', 'conv1_bn/moving_mean', 'conv1_bn/moving_variance', 'conv1_bn/moving_mean/biased', 'conv1_bn/moving_mean/local_step', 'conv1_bn/moving_variance/biased', 'conv1_bn/moving_variance/local_step', 'conv_dw_1/depthwise_kernel', 'conv_dw_1_bn/gamma', 'conv_dw_1_bn/beta', 'conv_dw_1_bn/moving_mean', 'conv_dw_1_bn/moving_variance', 'conv_dw_1_bn/moving_mean/biased', 'conv_dw_1_bn/moving_mean/local_step', 'conv_dw_1_bn/moving_variance/biased', 'conv_dw_1_bn/moving_variance/local_step', 'conv_pw_1/kernel', 'conv_pw_1_bn/gamma', 'conv_pw_1_bn/beta', 'conv_pw_1_bn/moving_mean', 'conv_pw_1_bn/moving_variance', 'conv_pw_1_bn/moving_mean/biased', 'conv_pw_1_bn/moving_mean/local_step', 'conv_pw_1_bn/moving_variance/biased', 'conv_pw_1_bn/moving_variance/local_step', 'conv_pad_2/Pad/paddings', 'conv_dw_2/depthwise_kernel', 'conv_dw_2_bn/gamma', 'conv_dw_2_bn/beta', 'conv_dw_2_bn/moving_mean', 'conv_dw_2_bn/moving_var

In [107]:
# Check that the converted PyTorch model matches TF up to an intermediate layer:
# run every sub-layer of blocks 0..jj-1, then sub-layers 0..ii of block jj,
# and compare with the TF tensor of the same name.
f1_tf = get_layer('conv_dw_2/depthwise:0')
model = model.eval()
with torch.no_grad():
    jj = 2
    ii = 1
    f1_pt = x_tensor
    for j in range(jj + 1):
        block = model.model[j]
        n_layers = ii + 1 if j == jj else len(block)
        for i in range(n_layers):
            f1_pt = block[i](f1_pt)
            print(f1_pt.shape)
    # TF activations are NHWC; transpose to NCHW before comparing.
    t = np.allclose(f1_pt.numpy(), f1_tf.transpose(0, 3, 1, 2), rtol=0, atol=1e-3)
    print('EQUAL' if t else 'ERROR')

torch.Size([1, 3, 225, 225])
torch.Size([1, 32, 112, 112])
torch.Size([1, 32, 112, 112])
torch.Size([1, 32, 112, 112])
torch.Size([1, 32, 114, 114])
torch.Size([1, 32, 112, 112])
torch.Size([1, 32, 112, 112])
torch.Size([1, 32, 112, 112])
torch.Size([1, 64, 112, 112])
torch.Size([1, 64, 112, 112])
torch.Size([1, 64, 112, 112])
torch.Size([1, 64, 113, 113])
torch.Size([1, 64, 56, 56])
EQUAL


In [94]:
# Module where the layer-by-layer comparison stopped (block `jj`, sub-layer `ii`).
# NOTE(review): the saved output shows padding=(1, 1), which does not match the current
# model repr (explicit ZeroPad2d followed by an unpadded conv); re-run to refresh it.
model.model[jj][ii]

Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)

In [54]:
# Stricter re-check of the comparison (atol=1e-6 versus 1e-3 in the comparison cell).
np.allclose(f1_pt.numpy(), f1_tf.transpose(0, 3, 1, 2), rtol=0, atol=1e-6)

True

In [105]:
# PyTorch activation (NCHW) produced by the layer-by-layer comparison.
f1_pt.numpy()

array([[[[-4.69197303e-01, -4.19004774e+00, -2.99021244e+00, ...,
          -4.14173088e+01, -9.72289562e+00, -4.47562456e+00],
         [ 0.00000000e+00, -4.45627403e+00, -8.65003014e+00, ...,
          -3.83823776e+01, -2.18846111e+01,  0.00000000e+00],
         [-1.94530821e+00, -1.76159871e+00, -3.08182240e+00, ...,
          -2.31840973e+01, -3.61757317e+01, -7.18732452e+00],
         ...,
         [-3.50750709e+00, -1.38862028e+01, -1.38362713e+01, ...,
          -1.58023643e+00, -1.01443839e+00, -8.83172393e-01],
         [-2.56322384e+01, -1.93676567e+01, -1.13890524e+01, ...,
          -2.15329337e+00, -9.83235240e-01, -2.08148861e+00],
         [-1.03327265e+01, -3.81133842e+01, -1.33302546e+01, ...,
          -9.58585203e-01, -9.91503358e-01,  0.00000000e+00]],

        [[-2.92701721e+01, -2.65249634e+01, -2.60478077e+01, ...,
          -2.61965237e+01, -2.61496181e+01, -2.22151489e+01],
         [-2.65257416e+01, -2.42798977e+01, -2.37607441e+01, ...,
          -2.36768036e

In [106]:
# TF activation reordered from NHWC to NCHW, for side-by-side comparison with f1_pt.
f1_tf.transpose(0, 3, 1, 2)

array([[[[-4.6919739e-01, -4.1900697e+00, -2.9902081e+00, ...,
          -4.1417309e+01, -9.7228985e+00, -4.4756184e+00],
         [ 0.0000000e+00, -4.4562550e+00, -8.6500664e+00, ...,
          -3.8382397e+01, -2.1884626e+01,  0.0000000e+00],
         [-1.9453217e+00, -1.7616110e+00, -3.0818253e+00, ...,
          -2.3184137e+01, -3.6175713e+01, -7.1873274e+00],
         ...,
         [-3.5075128e+00, -1.3886192e+01, -1.3836237e+01, ...,
          -1.5802236e+00, -1.0144274e+00, -8.8317722e-01],
         [-2.5632265e+01, -1.9367691e+01, -1.1389025e+01, ...,
          -2.1533000e+00, -9.8324615e-01, -2.0815136e+00],
         [-1.0332725e+01, -3.8113403e+01, -1.3330258e+01, ...,
          -9.5858854e-01, -9.9151397e-01,  0.0000000e+00]],

        [[-2.9270166e+01, -2.6524956e+01, -2.6047808e+01, ...,
          -2.6196518e+01, -2.6149616e+01, -2.2215147e+01],
         [-2.6525738e+01, -2.4279896e+01, -2.3760742e+01, ...,
          -2.3676804e+01, -2.3374460e+01, -2.0582050e+01],
        

In [36]:
# PyTorch activation reordered from NCHW to NHWC (TF layout).
f1_np = np.transpose(f1_pt.numpy(), (0, 2, 3, 1))
f1_np

array([[[[ 0.00000000e+00,  3.46460985e-03,  1.18973923e+00, ...,
          -1.70606792e+00,  2.94960767e-01,  4.65372950e-01],
         [ 0.00000000e+00,  1.11057386e-01,  9.13528085e-01, ...,
          -1.12143290e+00, -1.05854440e+00, -1.12338027e-03],
         [ 0.00000000e+00,  1.05382450e-01,  9.27623808e-01, ...,
          -1.12258232e+00, -7.29063094e-01, -7.53029808e-02],
         ...,
         [ 0.00000000e+00,  1.05851233e-01,  9.25442338e-01, ...,
          -1.11579859e+00, -7.29926705e-01, -7.81948864e-02],
         [ 0.00000000e+00,  1.20268419e-01,  9.38564062e-01, ...,
          -1.11409938e+00, -7.19701111e-01, -9.41056013e-02],
         [ 0.00000000e+00,  1.24741822e-01,  4.72871274e-01, ...,
          -9.58056688e-01, -5.03755212e-01, -7.70508170e-01]],

        [[ 0.00000000e+00,  2.74902750e-02,  1.47244239e+00, ...,
          -1.02367294e+00,  5.78671873e-01,  3.13187271e-01],
         [ 0.00000000e+00,  2.20121872e-02,  9.41969931e-01, ...,
          -9.17371035e

In [17]:
# Shape of the preprocessed TF input batch.
np.shape(x)

(1, 224, 224, 3)

In [158]:
# Full PyTorch forward pass on the sample image, with autograd disabled.
with torch.set_grad_enabled(False):
    y = model(x_tensor)

In [162]:
# Integer age estimate from the first model output (the age-bin probabilities fed to get_age).
get_age(y[0].numpy())

14

In [160]:
# Raw PyTorch outputs; in this run a tuple of per-head tensors.
y

(tensor([[2.9300e-02, 2.5816e-04, 2.2002e-04, 1.2720e-03, 9.9148e-02, 9.0066e-03,
          1.5967e-02, 3.2498e-02, 3.1650e-02, 3.7275e-01, 3.4325e-02, 2.0501e-02,
          1.8305e-02, 2.3174e-02, 1.4434e-02, 8.9328e-03, 6.0240e-02, 9.7695e-03,
          7.1496e-03, 9.1742e-03, 1.0052e-02, 1.1716e-02, 9.7114e-03, 1.0830e-02,
          7.9529e-03, 7.1245e-03, 7.8654e-03, 5.0926e-03, 7.0012e-03, 4.2080e-02,
          5.6640e-03, 5.3062e-03, 4.5887e-03, 5.7502e-03, 3.5740e-03, 2.8284e-03,
          3.5926e-03, 2.1025e-03, 2.7359e-03, 1.5336e-02, 2.5231e-03, 2.0922e-03,
          1.1886e-03, 1.5709e-03, 1.4799e-03, 1.4977e-03, 1.1616e-03, 1.2737e-03,
          7.8395e-04, 3.2415e-03, 7.6875e-04, 8.2468e-04, 6.2688e-04, 1.0098e-03,
          5.9829e-04, 5.6284e-04, 4.9292e-04, 6.9037e-04, 2.8905e-04, 3.9989e-04,
          5.2712e-04, 3.9842e-04, 4.0445e-04, 3.7703e-04, 7.3918e-04, 2.7997e-04,
          1.7572e-04, 2.9022e-04, 2.5982e-04, 9.9037e-04, 1.9045e-04, 3.7280e-04,
          1.5591

In [126]:
# NOTE(review): in the latest run `y` is a tuple of head outputs (see the `y` cell), so
# `.squeeze()` raises AttributeError; the saved output predates the multi-head model and
# presumably showed backbone features to compare with TF `features` — confirm which element to use.
y.squeeze()

tensor([0.0534, 0.0000, 0.6324,  ..., 0.2799, 0.0330, 0.0000])

In [111]:
# NOTE(review): this saves the whole model's state_dict (heads included when run after the
# head weights are loaded) under a "backbone" filename; confirm the intended contents.
torch.save(model.state_dict(), 'model_weights/backbone_state_dict.pth')

In [112]:
# Round-trip check: reload the saved weights into `model`.
# NOTE(review): torch.load unpickles the file; only load checkpoints produced locally
# (newer PyTorch versions accept weights_only=True to restrict this).
model.load_state_dict(torch.load('model_weights/backbone_state_dict.pth'))

<All keys matched successfully>

In [163]:
# Persist the converted model's full state_dict for use outside this notebook.
torch.save(model.state_dict(), 'model_weights/agre.pth')

In [110]:
# TF reference features ('global_pooling/Mean:0') for the sample image.
np.array(features)

array([[0.05342963, 0.        , 0.6324378 , ..., 0.27987874, 0.03298753,
        0.        ]], dtype=float32)

In [22]:
# Backbone state_dict keys (all but the last 8 head entries).
# NOTE(review): the saved output lists e.g. 'model.1.0.weight', which predates the ZeroPad2d
# sub-layers in the current model; re-run to refresh.
sd_keys

['model.0.0.weight',
 'model.0.1.weight',
 'model.0.1.bias',
 'model.0.1.running_mean',
 'model.0.1.running_var',
 'model.0.1.num_batches_tracked',
 'model.1.0.weight',
 'model.1.1.weight',
 'model.1.1.bias',
 'model.1.1.running_mean',
 'model.1.1.running_var',
 'model.1.1.num_batches_tracked',
 'model.1.3.weight',
 'model.1.4.weight',
 'model.1.4.bias',
 'model.1.4.running_mean',
 'model.1.4.running_var',
 'model.1.4.num_batches_tracked',
 'model.2.0.weight',
 'model.2.1.weight',
 'model.2.1.bias',
 'model.2.1.running_mean',
 'model.2.1.running_var',
 'model.2.1.num_batches_tracked',
 'model.2.3.weight',
 'model.2.4.weight',
 'model.2.4.bias',
 'model.2.4.running_mean',
 'model.2.4.running_var',
 'model.2.4.num_batches_tracked',
 'model.3.0.weight',
 'model.3.1.weight',
 'model.3.1.bias',
 'model.3.1.running_mean',
 'model.3.1.running_var',
 'model.3.1.num_batches_tracked',
 'model.3.3.weight',
 'model.3.4.weight',
 'model.3.4.bias',
 'model.3.4.running_mean',
 'model.3.4.running_var'