In [5]:
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from collections import OrderedDict
from DeepVOG_torch import DeepVOG

In [60]:
def tf_conv2d_to_torch(h5_group, prefix=''):
    """Convert one stored conv layer into PyTorch state-dict entries.

    Args:
        h5_group: mapping that exposes the datasets ``'kernel:0'`` and
            ``'bias:0'``; each is read fully into memory with ``[()]``.
        prefix: string prepended to the ``'weight'`` / ``'bias'`` keys.

    Returns:
        OrderedDict with ``prefix + 'weight'`` followed by ``prefix + 'bias'``.
    """
    kernel = h5_group['kernel:0'][()]
    bias = h5_group['bias:0'][()]

    state = OrderedDict()
    # Axes are permuted (3, 2, 0, 1): the stored kernel's last axis becomes the
    # first. Assumes the stored layout is (kh, kw, in, out) so the result is
    # PyTorch's (out, in, kh, kw) -- TODO confirm for every layer type used.
    state[prefix + 'weight'] = torch.Tensor(kernel.transpose((3, 2, 0, 1)))
    state[prefix + 'bias'] = torch.Tensor(bias)
    return state

def tf_bn_to_torch(h5_group, prefix=''):
    """Convert one stored batch-norm layer into PyTorch state-dict entries.

    Args:
        h5_group: mapping that exposes the datasets ``'gamma:0'``,
            ``'beta:0'``, ``'moving_mean:0'`` and ``'moving_variance:0'``.
        prefix: string prepended to every produced key.

    Returns:
        OrderedDict with, in order, ``weight``, ``bias``, ``running_mean``,
        ``running_var`` and ``num_batches_tracked`` (each prefixed).

    NOTE(review): only the four arrays are transferred, not the epsilon.
    Keras BatchNormalization defaults to 1e-3 and torch BatchNorm to 1e-5 --
    confirm the target module's ``eps`` matches the source model.
    """
    # (PyTorch buffer/parameter name, dataset name in the HDF5 group)
    name_map = (
        ('weight', 'gamma:0'),
        ('bias', 'beta:0'),
        ('running_mean', 'moving_mean:0'),
        ('running_var', 'moving_variance:0'),
    )

    state = OrderedDict()
    for torch_name, tf_name in name_map:
        state[prefix + torch_name] = torch.Tensor(h5_group[tf_name][()])

    # Scalar (0-d) int64 counter initialised to zero.
    state[prefix + 'num_batches_tracked'] = torch.zeros((), dtype=torch.int64)
    return state

In [2]:
# Open the HDF5 weights file read-only. The handle is kept open on purpose:
# later cells list its contents and read weight groups through `f`.
f = h5py.File('DeepVOG/deepvog/model/DeepVOG_weights.h5', 'r')

In [97]:
def print_name_shape(name, obj):
    """Print the name and shape of an HDF5 dataset; ignore anything else.

    Intended as a ``visititems`` callback: it returns ``None`` in every case.

    Args:
        name: path of the visited object, printed left-aligned in 60 columns.
        obj: the visited HDF5 object; only datasets are printed.
    """
    # Use the public ``h5py.Dataset`` name rather than the private
    # ``h5py._hl.dataset`` module path, which is an implementation detail.
    if isinstance(obj, h5py.Dataset):
        print(f'{name:60} {str(tuple(obj.shape)):15}')

In [98]:
# List every dataset stored in the weights file (name and shape) so the
# layer naming scheme can be inspected before mapping it to PyTorch keys.
f.visititems(print_name_shape)

bn_down1_down/bn_down1_down/beta:0                           (32,)          
bn_down1_down/bn_down1_down/gamma:0                          (32,)          
bn_down1_down/bn_down1_down/moving_mean:0                    (32,)          
bn_down1_down/bn_down1_down/moving_variance:0                (32,)          
bn_down1_main_1/bn_down1_main_1/beta:0                       (16,)          
bn_down1_main_1/bn_down1_main_1/gamma:0                      (16,)          
bn_down1_main_1/bn_down1_main_1/moving_mean:0                (16,)          
bn_down1_main_1/bn_down1_main_1/moving_variance:0            (16,)          
bn_down2_down/bn_down2_down/beta:0                           (64,)          
bn_down2_down/bn_down2_down/gamma:0                          (64,)          
bn_down2_down/bn_down2_down/moving_mean:0                    (64,)          
bn_down2_down/bn_down2_down/moving_variance:0                (64,)          
bn_down2_main_1/bn_down2_main_1/beta:0                       (32,)          

In [6]:
# Instantiate the PyTorch DeepVOG network with its default constructor
# arguments; its parameters are overwritten by the converted weights later.
model = DeepVOG()

In [65]:
# Assemble the PyTorch state dict `d` from the HDF5 weight groups.
# Every layer lives in a group whose name is repeated twice
# ('<layer>/<layer>'), and stages are numbered from 1 on both sides.
d = OrderedDict()

# Encoder stages: a "main" conv+bn pair followed by a "down" conv+bn pair.
for stage in range(1, model.layers_down + 1):
    for tf_branch, torch_branch in (('main_1', 'main'), ('down', 'down')):
        conv_name = f'conv_down{stage}_{tf_branch}'
        bn_name = f'bn_down{stage}_{tf_branch}'
        d.update(tf_conv2d_to_torch(f[f'{conv_name}/{conv_name}'],
                                    f'down_{stage}.conv_{torch_branch}.'))
        d.update(tf_bn_to_torch(f[f'{bn_name}/{bn_name}'],
                                f'down_{stage}.bn_{torch_branch}.'))

# Decoder stages: always a "main" pair; an "up" pair only for stages below 5.
# NOTE(review): the literal 5 presumably equals model.layers_up (i.e. the last
# decoder stage has no up-sampling conv) -- confirm before changing either.
for stage in range(1, model.layers_up + 1):
    branches = [('main_1', 'main')]
    if stage < 5:
        branches.append(('up', 'up'))
    for tf_branch, torch_branch in branches:
        conv_name = f'conv_up{stage}_{tf_branch}'
        bn_name = f'bn_up{stage}_{tf_branch}'
        d.update(tf_conv2d_to_torch(f[f'{conv_name}/{conv_name}'],
                                    f'up_{stage}.conv_{torch_branch}.'))
        d.update(tf_bn_to_torch(f[f'{bn_name}/{bn_name}'],
                                f'up_{stage}.bn_{torch_branch}.'))

# Output convolution (no batch-norm entries are converted for it).
d.update(tf_conv2d_to_torch(f['conv_out/conv_out'], 'conv_out.'))

In [90]:
# Copy the converted tensors into the model. With load_state_dict's default
# strict key checking, a missing or unexpected key raises here, so a clean
# return doubles as a check of the name mapping built above.
model.load_state_dict(d)

<All keys matched successfully>

In [91]:
# Persist the converted weights as a PyTorch state-dict checkpoint
# (relative path, so it is written to the current working directory).
torch.save(model.state_dict(), 'DeepVOG_weights.pt')

In [94]:
# Print every converted entry as "<key> <shape>" in insertion order, for a
# side-by-side comparison with the HDF5 listing.
for param_name, tensor in d.items():
    shape_text = str(tuple(tensor.shape))
    print(f'{param_name:40} {shape_text:15}')

down_1.conv_main.weight                  (16, 3, 10, 10)
down_1.conv_main.bias                    (16,)          
down_1.bn_main.weight                    (16,)          
down_1.bn_main.bias                      (16,)          
down_1.bn_main.running_mean              (16,)          
down_1.bn_main.running_var               (16,)          
down_1.bn_main.num_batches_tracked       ()             
down_1.conv_down.weight                  (32, 16, 2, 2) 
down_1.conv_down.bias                    (32,)          
down_1.bn_down.weight                    (32,)          
down_1.bn_down.bias                      (32,)          
down_1.bn_down.running_mean              (32,)          
down_1.bn_down.running_var               (32,)          
down_1.bn_down.num_batches_tracked       ()             
down_2.conv_main.weight                  (32, 32, 10, 10)
down_2.conv_main.bias                    (32,)          
down_2.bn_main.weight                    (32,)          
down_2.bn_main.bias           