In [130]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [175]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1,       # picture size
                               out_channels=5,      # KernelSize in tf.js
                               kernel_size=3,       # filters in tf.js
                               bias=True)
        self.pool = nn.MaxPool2d(kernel_size=(2,2), 
                                 stride=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=5,
                               out_channels=5,
                               kernel_size=5,
                               bias=True)
        
        # self.fc1 = nn.Linear(5 * 5 * 3, 10)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        # x = F.softmax(self.fc1(x))
        return x


net = Net()

In [176]:
import json
import numpy as np
import pandas as pd

In [177]:
with open('test.json', 'r') as f:
    data = json.load(f)

## First layer

In [178]:
data[0]['model']

{'kept': False,
 'isDisposedInternal': False,
 'shape': [3, 3, 1, 5],
 'dtype': 'float32',
 'size': 45,
 'strides': [15, 5, 5],
 'dataId': {'id': 0},
 'id': 1,
 'rankType': '4',
 'trainable': True,
 'name': 'conv2d_Conv2D1/kernel'}

In [179]:
data[1]['model']

{'kept': False,
 'isDisposedInternal': False,
 'shape': [5],
 'dtype': 'float32',
 'size': 5,
 'strides': [],
 'dataId': {'id': 1},
 'id': 3,
 'rankType': '1',
 'trainable': True,
 'name': 'conv2d_Conv2D1/bias'}

In [180]:
len(data)

4

In [181]:
torch.tensor(pd.Series(data[2]['params']))

tensor([ 1.5882e-01,  8.1480e-02,  1.7400e-01,  9.5683e-03, -1.2598e-01,
        -5.7142e-02,  6.9495e-03,  7.7989e-02,  7.2675e-02, -4.6343e-02,
        -7.6225e-02,  6.8786e-02,  3.0958e-02,  1.9353e-02, -8.9495e-02,
        -1.9291e-02,  1.6426e-01,  9.9582e-02, -1.2810e-01,  1.8971e-02,
        -4.6431e-02, -5.8556e-02,  8.2508e-02, -5.7107e-02,  1.4571e-01,
         3.8204e-02,  4.8533e-02,  4.4650e-02,  1.8825e-02, -1.3652e-01,
        -9.4810e-02, -8.2480e-02,  7.4900e-02, -2.1169e-02,  3.0716e-02,
         9.6078e-02,  1.5537e-01,  7.8903e-02, -1.3837e-02,  1.4147e-03,
         1.7427e-01,  4.9713e-02, -8.2639e-03,  5.3467e-02,  7.4937e-03,
         3.8252e-02, -5.9227e-02,  6.8422e-02,  3.0127e-02, -5.4855e-02,
        -1.0487e-01, -2.4225e-02, -1.2204e-01,  1.0031e-01, -6.8431e-02,
        -6.5978e-02,  1.2994e-02, -1.9877e-04, -8.2085e-02,  4.1204e-03,
        -3.2709e-02, -1.6142e-01,  6.3192e-02, -1.8026e-02,  1.3425e-01,
        -1.5689e-01,  1.1534e-01,  3.6959e-02, -9.3

## Fill with params from tf model

In [182]:
shape_layer1 = net.state_dict()['conv1.weight'].shape

In [183]:
net.state_dict()['conv1.weight']

tensor([[[[-0.2665, -0.0781, -0.2384],
          [-0.0520, -0.3037,  0.2831],
          [ 0.2031, -0.1981, -0.2665]]],


        [[[-0.2047, -0.0028,  0.2103],
          [-0.0865,  0.1546,  0.2941],
          [ 0.1939,  0.1886, -0.2259]]],


        [[[-0.1681, -0.2650, -0.2748],
          [-0.2677, -0.0758, -0.3130],
          [ 0.0768,  0.1842,  0.1754]]],


        [[[ 0.1313, -0.0620,  0.1487],
          [ 0.0025,  0.0166, -0.0976],
          [-0.3277, -0.0744, -0.0361]]],


        [[[ 0.1045,  0.1454,  0.1887],
          [ 0.1511, -0.2129,  0.0377],
          [ 0.2243,  0.1870, -0.2181]]]])

In [184]:
tf_params_layer1 = torch.tensor(pd.Series(data[0]['params']))

In [185]:
tensor_input = torch.reshape(tf_params_layer1, shape_layer1)

In [186]:
list(net.state_dict().keys())

['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias']

In [187]:
state_dict = net.state_dict()
state_dict['conv1.weight'] = tensor_input
net.load_state_dict(state_dict)

<All keys matched successfully>

In [188]:
tensor_input

tensor([[[[-0.0795,  0.3412, -0.0787],
          [-0.2766, -0.1191, -0.3615],
          [-0.0822, -0.5693,  0.5850]]],


        [[[-0.0629,  0.3145,  0.3474],
          [ 0.3322, -0.4905,  0.0166],
          [ 0.1015,  0.0281,  0.2332]]],


        [[[-0.6460, -0.4513,  0.3746],
          [-0.5687,  0.3402, -0.0768],
          [ 0.2446,  0.6551,  0.0488]]],


        [[[ 0.5565,  0.2835,  0.2720],
          [-0.3204,  0.3552,  0.3212],
          [-0.0397,  0.5283,  0.3701]]],


        [[[-0.1996,  0.1597,  0.3600],
          [ 0.2170, -0.1868,  0.1647],
          [ 0.3011, -0.4380, -0.5455]]]], dtype=torch.float64)

In [189]:
net.state_dict()['conv1.weight']

tensor([[[[-0.0795,  0.3412, -0.0787],
          [-0.2766, -0.1191, -0.3615],
          [-0.0822, -0.5693,  0.5850]]],


        [[[-0.0629,  0.3145,  0.3474],
          [ 0.3322, -0.4905,  0.0166],
          [ 0.1015,  0.0281,  0.2332]]],


        [[[-0.6460, -0.4513,  0.3746],
          [-0.5687,  0.3402, -0.0768],
          [ 0.2446,  0.6551,  0.0488]]],


        [[[ 0.5565,  0.2835,  0.2720],
          [-0.3204,  0.3552,  0.3212],
          [-0.0397,  0.5283,  0.3701]]],


        [[[-0.1996,  0.1597,  0.3600],
          [ 0.2170, -0.1868,  0.1647],
          [ 0.3011, -0.4380, -0.5455]]]])

In [199]:
def load_params_from_tf(py_model:nn.Module, tf_model:list):
    """
    Load and update the parameters from tensorflow.js to pytorch nn.Module

    Args:
        py_model: An nn.Moudule network structure from pytorch
        tf_module: A list read from JSON file which stored the meta data of tensorflow.js model 
                   (length is number of layers, and has two keys in each layer, 'model' and 'params' respectively)

    Returns:
        An updated nn.Module network structure

    Raises:
        Exception: Certain layer structure is not aligned
        KeyError: Model layer is not aligned
    """
    state_dict = py_model.state_dict()
    py_layers = list(state_dict.keys())
    tf_layers = [d['model']['name'] for d in tf_model]
    tf_params_dict = {d['model']['name'] : torch.tensor(pd.Series(d['params'])) for d in tf_model}
    py_nlayers = len(py_layers)
    tf_nlayers = len(tf_layers)
    if tf_nlayers == py_nlayers:
        try:
            for py_layer, tf_layer in zip(py_layers, tf_layers):
                layer_shape = state_dict[py_layer].shape
                params_in = tf_params_dict[tf_layer]
                params_in = torch.reshape(params_in, layer_shape)

                state_dict[py_layer] = params_in
            py_model.load_state_dict(state_dict)
            return py_model
        except:
            raise Exception(f"Sorry, model structure did not align in pytorch layer {py_layer}, and tensorflow.js layer {tf_layer}!")
    else:
        raise TypeError("The model structure of pytorch and tensorflow.js is not aligned! Cannot transfer parameters accordingly.")

In [195]:
new_net = load_params_from_tf(net, data)

In [197]:
net.state_dict()['conv1.weight']

tensor([[[[-0.0795,  0.3412, -0.0787],
          [-0.2766, -0.1191, -0.3615],
          [-0.0822, -0.5693,  0.5850]]],


        [[[-0.0629,  0.3145,  0.3474],
          [ 0.3322, -0.4905,  0.0166],
          [ 0.1015,  0.0281,  0.2332]]],


        [[[-0.6460, -0.4513,  0.3746],
          [-0.5687,  0.3402, -0.0768],
          [ 0.2446,  0.6551,  0.0488]]],


        [[[ 0.5565,  0.2835,  0.2720],
          [-0.3204,  0.3552,  0.3212],
          [-0.0397,  0.5283,  0.3701]]],


        [[[-0.1996,  0.1597,  0.3600],
          [ 0.2170, -0.1868,  0.1647],
          [ 0.3011, -0.4380, -0.5455]]]])

In [153]:
type(data)

list

In [158]:
[i['model']['name'] for i in data]

['conv2d_Conv2D1/kernel',
 'conv2d_Conv2D1/bias',
 'conv2d_Conv2D2/kernel',
 'conv2d_Conv2D2/bias']

In [160]:
{d['model']['name'] : torch.tensor(pd.Series(d['params'])) for d in data}

{'conv2d_Conv2D1/kernel': tensor([-0.0795,  0.3412, -0.0787, -0.2766, -0.1191, -0.3615, -0.0822, -0.5693,
          0.5850, -0.0629,  0.3145,  0.3474,  0.3322, -0.4905,  0.0166,  0.1015,
          0.0281,  0.2332, -0.6460, -0.4513,  0.3746, -0.5687,  0.3402, -0.0768,
          0.2446,  0.6551,  0.0488,  0.5565,  0.2835,  0.2720, -0.3204,  0.3552,
          0.3212, -0.0397,  0.5283,  0.3701, -0.1996,  0.1597,  0.3600,  0.2170,
         -0.1868,  0.1647,  0.3011, -0.4380, -0.5455], dtype=torch.float64),
 'conv2d_Conv2D1/bias': tensor([0, 0, 0, 0, 0]),
 'conv2d_Conv2D2/kernel': tensor([ 1.5882e-01,  8.1480e-02,  1.7400e-01,  9.5683e-03, -1.2598e-01,
         -5.7142e-02,  6.9495e-03,  7.7989e-02,  7.2675e-02, -4.6343e-02,
         -7.6225e-02,  6.8786e-02,  3.0958e-02,  1.9353e-02, -8.9495e-02,
         -1.9291e-02,  1.6426e-01,  9.9582e-02, -1.2810e-01,  1.8971e-02,
         -4.6431e-02, -5.8556e-02,  8.2508e-02, -5.7107e-02,  1.4571e-01,
          3.8204e-02,  4.8533e-02,  4.4650e-02,  

# Clean test

In [3]:
import torch
import torch.nn as nn
import json

In [4]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1,       # picture size
                               out_channels=5,      # KernelSize in tf.js
                               kernel_size=3,       # filters in tf.js
                               bias=True)
        self.pool = nn.MaxPool2d(kernel_size=(2,2), 
                                 stride=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=5,
                               out_channels=5,
                               kernel_size=5,
                               bias=True)
        
        # self.fc1 = nn.Linear(5 * 5 * 3, 10)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        # x = F.softmax(self.fc1(x))
        return x


py_net = Net()

In [5]:
with open('test.json', 'r') as f:
    tf_net = json.load(f)

In [6]:
py_net.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[ 0.1679,  0.0853,  0.0358],
                        [-0.1297,  0.2871, -0.0509],
                        [ 0.0134, -0.1860,  0.1001]]],
              
              
                      [[[-0.0046,  0.2067, -0.3076],
                        [ 0.1874,  0.0051,  0.2192],
                        [-0.0480, -0.0578, -0.2316]]],
              
              
                      [[[ 0.3015, -0.1672, -0.0485],
                        [-0.0246,  0.2526,  0.1033],
                        [ 0.1496,  0.1407,  0.2045]]],
              
              
                      [[[-0.0268,  0.1738, -0.1870],
                        [-0.2139,  0.2267,  0.2509],
                        [-0.0293, -0.2276,  0.0493]]],
              
              
                      [[[-0.3029, -0.1292, -0.2790],
                        [ 0.3237,  0.0503, -0.0176],
                        [-0.1855,  0.2258, -0.2814]]]])),
             ('conv1.bias',
              

In [9]:
from utils import load_params_from_tf

In [10]:
load_params_from_tf(py_net, tf_net)

Net(
  (conv1): Conv2d(1, 5, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(5, 5, kernel_size=(5, 5), stride=(1, 1))
)

In [12]:
py_net.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-0.0795,  0.3412, -0.0787],
                        [-0.2766, -0.1191, -0.3615],
                        [-0.0822, -0.5693,  0.5850]]],
              
              
                      [[[-0.0629,  0.3145,  0.3474],
                        [ 0.3322, -0.4905,  0.0166],
                        [ 0.1015,  0.0281,  0.2332]]],
              
              
                      [[[-0.6460, -0.4513,  0.3746],
                        [-0.5687,  0.3402, -0.0768],
                        [ 0.2446,  0.6551,  0.0488]]],
              
              
                      [[[ 0.5565,  0.2835,  0.2720],
                        [-0.3204,  0.3552,  0.3212],
                        [-0.0397,  0.5283,  0.3701]]],
              
              
                      [[[-0.1996,  0.1597,  0.3600],
                        [ 0.2170, -0.1868,  0.1647],
                        [ 0.3011, -0.4380, -0.5455]]]])),
             ('conv1.bias', tensor([0., 0.