In [86]:
from model import UNet
import mindspore
from keras.models import load_model
import h5py
import numpy as np
from mindspore import Tensor, save_checkpoint,load_checkpoint,Model,load_param_into_net
import os
import cv2
from MyAcc import PixelAcc
from mindspore import nn
from dataloader import load_train_data

##### 遍历MindSpore Cell中的所有参数，打印每个参数的参数名和shape，并返回参数名列表

In [13]:
def mindspore_params(network):
    """Print every parameter of a MindSpore Cell and collect the parameter names.

    For each parameter yielded by ``network.get_parameters()`` one line
    ``<name> <shape>`` is printed.

    Args:
        network: a MindSpore ``nn.Cell`` (anything exposing ``get_parameters()``).

    Returns:
        list[str]: the parameter names, in the order ``get_parameters()``
        yields them.
    """
    names = []
    for param in network.get_parameters():
        # Read the shape from the Parameter itself; copying the whole tensor to
        # host with asnumpy() just to inspect its shape is unnecessary.
        print(param.name, param.shape)
        names.append(param.name)
    return names

##### 对于从hdf5文件导入的模型，仅保留含有权重（weights不为空）的层，过滤掉没有权重参数的层

In [3]:
def load_weights_from_hdf5_group(f, layers, reshape=False):
    """Return the names in ``layers`` whose hdf5 group stores at least one weight.

    Reads the Keras hdf5 layout directly: every layer group under
    ``model_weights`` carries a ``weight_names`` attribute listing the weights
    saved for that layer, and it is empty for layers without weights. Reading
    the file avoids depending on the notebook-global ``model`` variable, which
    may no longer be the Keras model when this function is called.

    Args:
        f: an open ``h5py.File`` of a full Keras model (containing a
            ``model_weights`` group), or the ``model_weights`` group itself.
        layers: iterable of layer names, e.g. decoded from
            ``f['model_weights'].attrs['layer_names']``.
        reshape: unused; kept so existing call sites keep working.

    Returns:
        list[str]: the subset of ``layers`` that have weights, order preserved.
    """
    weights_group = f['model_weights'] if 'model_weights' in f else f
    filtered_layers = []
    for layer in layers:
        # A missing attribute is treated like an empty one (no weights).
        # NOTE(review): Keras splits very large attributes into
        # 'weight_names0', 'weight_names1', ...; not handled here.
        weight_names = weights_group[layer].attrs.get('weight_names', [])
        if len(weight_names) > 0:
            filtered_layers.append(layer)
    return filtered_layers

##### 对卷积权重进行转置
MindSpore卷积层中weight的shape为`[out_channel, in_channel, kernel_height, kernel_width]`，而TensorFlow/Keras卷积层中weight的shape为`[kernel_height, kernel_width, in_channel, out_channel]`，因此需要用`np.transpose(weight, axes=[3, 2, 0, 1])`进行转置。

In [55]:
def hdf5_2_mindspore(h5_model, h5_name_list, ms_name_list, ms_ckpt_path,
                     ckpt_name='hdf2mindspore.ckpt'):
    """Convert Keras conv kernels into a MindSpore checkpoint file.

    Keras layer ``h5_name_list[i]`` is mapped by position onto the MindSpore
    parameter name ``ms_name_list[i]``. Keras conv kernels are laid out as
    (kernel_h, kernel_w, in_channel, out_channel) while MindSpore expects
    (out_channel, in_channel, kernel_h, kernel_w), hence the ``[3, 2, 0, 1]``
    transpose.

    Args:
        h5_model: loaded Keras model providing ``get_layer(name).get_weights()``.
        h5_name_list: Keras layer names, in mapping order.
        ms_name_list: MindSpore parameter names, same length and order.
        ms_ckpt_path: directory the checkpoint is written to.
        ckpt_name: checkpoint file name, default ``'hdf2mindspore.ckpt'``.

    Returns:
        str: path of the written checkpoint file.

    Raises:
        ValueError: if the two name lists have different lengths (positional
            pairing would otherwise silently drop or misalign entries).
    """
    if len(h5_name_list) != len(ms_name_list):
        raise ValueError(
            'h5_name_list has {} entries but ms_name_list has {}'.format(
                len(h5_name_list), len(ms_name_list)))

    new_params_list = []
    for h5_name, ms_name in zip(h5_name_list, ms_name_list):
        # Only the kernel (index 0) is converted. NOTE(review): a Keras bias at
        # index 1 is dropped - confirm the MindSpore UNet convs have no bias,
        # otherwise their bias parameters are not in this checkpoint.
        kernel = h5_model.get_layer(h5_name).get_weights()[0]
        new_params_list.append({
            'name': ms_name,
            'data': Tensor(np.transpose(kernel, axes=[3, 2, 0, 1])),
        })

    ckpt_file = os.path.join(ms_ckpt_path, ckpt_name)
    save_checkpoint(new_params_list, ckpt_file)
    return ckpt_file

##### 导入Keras模型，查看模型结构以及hdf5文件中各层的权重

In [None]:
# Load the trained Keras UNet and show its layer table. model.summary() prints
# the table itself and returns None, so wrapping it in print() would also
# print a stray "None".
model = load_model('unet.hdf5')
model.summary()

In [None]:
# Open the hdf5 file read-only (older h5py defaulted to writable mode 'a') and
# close it once the layer names have been read.
with h5py.File('unet.hdf5', 'r') as f:
    # Layer names are stored as an attribute of the model_weights group; they
    # are usually bytes, but decode only when needed.
    layer_names = [n.decode('utf8') if isinstance(n, bytes) else n
                   for n in f['model_weights'].attrs['layer_names']]
    layer_with_weight = load_weights_from_hdf5_group(f, layer_names)

# Print each layer that has weights together with the shape of its kernel.
for name in layer_with_weight:
    print(name)
    # Index 0 is the kernel; indexing (instead of unpacking kernel, bias)
    # also works for layers that do not have exactly two weights.
    weight = model.get_layer(name).get_weights()[0]
    print(weight.shape)

In [None]:
# Build the MindSpore UNet. mindspore_params prints "<name> <shape>" for every
# parameter and returns the list of names; it is shown as the cell output so
# it can be matched by hand against the Keras layer names.
network = UNet()
network_param_ms = mindspore_params(network)
network_param_ms

##### 由于层数较多，且Keras与MindSpore的命名方式不同，这里需要手动建立两者的对应关系（两个列表按位置一一对应）

In [None]:
# Keras layer names conv2d_49 ... conv2d_72, in network order. The numbering is
# whatever Keras auto-assigned when unet.hdf5 was built; compare it with the
# layer list printed above.
keras_name_list = ['conv2d_{}'.format(i) for i in range(49, 73)]

# MindSpore weight names, paired by position with keras_name_list.
mindspore_name_list = [
    'downsample1.0.weight', 'downsample1.2.weight',
    'downsample2.0.weight', 'downsample2.2.weight',
    'downsample3.0.weight', 'downsample3.2.weight',
    'downsample4.0.weight', 'downsample4.2.weight',
    'downsample5.0.weight', 'downsample5.2.weight',
    'upconv1.0.weight', 'upsample1.0.weight', 'upsample1.2.weight',
    'upconv2.0.weight', 'upsample2.0.weight', 'upsample2.2.weight',
    'upconv3.0.weight', 'upsample3.0.weight', 'upsample3.2.weight',
    'upconv4.0.weight', 'upsample4.0.weight', 'upsample4.2.weight',
    'outconv.weight', 'sigmoid_conv.0.weight',
]

# The conversion pairs the two lists by position, so they must line up.
assert len(keras_name_list) == len(mindspore_name_list), (
    len(keras_name_list), len(mindspore_name_list))

In [62]:
# Convert the mapped Keras kernels and write the checkpoint into the current directory.
hdf5_2_mindspore(
    h5_model=model,
    h5_name_list=keras_name_list,
    ms_name_list=mindspore_name_list,
    ms_ckpt_path='./',
)

##### 导入转换后的权重，检查其是否正确保存、能否正确加载到网络中

In [64]:
# Reload the checkpoint just written by hdf5_2_mindspore (for ms_ckpt_path='./'
# it writes './hdf2mindspore.ckpt') and list the network parameters that were
# not found in it. net= is not passed to load_checkpoint because
# load_param_into_net below already loads the parameters into the network.
param_dict = mindspore.load_checkpoint("./hdf2mindspore.ckpt")
# NOTE(review): the 2-tuple unpacking assumes a MindSpore version where
# load_param_into_net returns (param_not_load, ckpt_not_load).
param_not_load, _ = mindspore.load_param_into_net(network, param_dict)
print(param_not_load)


[]


##### 推理并评价

In [89]:
# Evaluate the converted weights with MindSpore. The high-level wrapper is named
# ms_model so it does not overwrite the Keras `model` used by earlier cells
# (re-running those cells would otherwise fail).
network = UNet()
loss_fn = nn.BCELoss()
ms_model = Model(network, loss_fn, metrics={"Accuracy": PixelAcc()})

print("============== Starting Testing ==============")
# Load the checkpoint written by hdf5_2_mindspore with ms_ckpt_path='./'.
param_dict = load_checkpoint("./hdf2mindspore.ckpt")
load_param_into_net(network, param_dict)
# NOTE(review): this evaluates on the './train/' directory; 512, 512, 1 are
# passed straight to load_train_data (presumably height, width and batch
# size - confirm against dataloader).
dataset = load_train_data('./train/', 512, 512, 1, shuffle=True)
acc = ms_model.eval(dataset)
print("============== {} ==============".format(acc))


