In [2]:
# -*- coding: utf-8 -*-
'''
@Author: yichao.li
@Date:   2020-04-27
@Description: merge two pbs into one savedmodel
'''
import tensorflow as tf
import argparse

def denoiser(melgan_two_outputs, param_denoise):
    with tf.variable_scope('denoiser'):
        bias_audio = tf.identity(tf.squeeze(melgan_two_outputs[1], 1))
        bias_audio = tf.expand_dims(bias_audio, 0)

        real_audio = tf.identity(tf.squeeze(melgan_two_outputs[0],1))
        real_audio = tf.expand_dims(real_audio, 0)

        bias_spec = tf.abs(tf.contrib.signal.stft(bias_audio, 800, 200, 2048))

        audio_stft = tf.contrib.signal.stft(real_audio, 800, 200, 2048)
        audio_spec = tf.abs(audio_stft)
        audio_angles = audio_stft / tf.cast(tf.maximum(1e-8, audio_spec), tf.complex64)

        param = tf.constant([param_denoise], dtype=tf.float32) 
        audio_spec_denoised = tf.subtract(audio_spec, tf.multiply(bias_spec, param))
        audio_spec_denoised = tf.clip_by_value(audio_spec_denoised, 0.0, 999999999.0)
        S_complex = tf.cast(audio_spec_denoised, dtype=tf.complex64)

        denoiser_audio = tf.contrib.signal.inverse_stft(S_complex * audio_angles, 800, 200, 2048)
        return tf.squeeze(denoiser_audio, 0)

def merge(args):
    # load tacotron pb
    tf.reset_default_graph()
    graph_A = tf.GraphDef()
    with tf.gfile.GFile(args.tacotron, 'rb') as fid:
        serialized_graph = fid.read()
        graph_A.ParseFromString(serialized_graph)


    # load melgan pb
    tf.reset_default_graph()
    graph_B = tf.GraphDef()
    with tf.gfile.GFile(args.melgan, 'rb') as fid:
        serialized_graph = fid.read()
        graph_B.ParseFromString(serialized_graph)

    # some nodes in melgan is locked, you must free it at first
    for node in graph_B.node:
        if node.op == 'Assign':
            node.op = 'Identity'
        if 'use_locking' in node.attr: del node.attr['use_locking']
        if 'validate_shape' in node.attr: del node.attr['validate_shape']
        # if len(node.input) == 2:
        #     # input0: ref: Should be from a Variable node. May be uninitialized.
        #     # input1: value: The value to be assigned to the variable.
        #     node.input[0] = node.input[1]
        #     del node.input[1]

    # merge two graphs into one
    tf.reset_default_graph()
    with tf.Graph().as_default() as graphs_merged:
        with tf.Session(graph=graphs_merged) as sess:
            inputs = tf.placeholder(tf.int32, [1, None], name='inputs')
            input_lengths = tf.placeholder(tf.int32, [1], name='input_lengths')
            
            if args.split_infos:
                split_infos = tf.placeholder(tf.int32, shape=[1, None], name='split_infos')
                mel_out, = tf.import_graph_def(
                    graph_A, 
                    input_map={"inputs:0": inputs, "input_lengths:0": input_lengths, "split_infos:0": split_infos},
                    return_elements=["Tacotron_model/inference/add:0"], 
                    name="")
            else:
                mel_out, = tf.import_graph_def(
                    graph_A, 
                    input_map={"inputs:0": inputs, "input_lengths:0": input_lengths},
                    return_elements=["Tacotron-2/inference/Minimum_1:0"], 
                    name="")

            if args.denoise:
                mel_out = tf.pad(mel_out,[[0,1],[0,0],[0,0]], mode="CONSTANT", constant_values=0)

            print("------successed done build input and melspec!-----")

            audio, = tf.import_graph_def(
                graph_B, 
                input_map={"mel_G:0": mel_out},
                return_elements=["MelGAN/Generator/Tanh:0"], 
                name="")

            # make sure node is identity
            tf.identity(audio, "MelGAN/Generator/Tanh")

            print("------successed done build output audio!-----")
            

            # export to saved model
            if args.split_infos and args.denoise:
                tf.saved_model.simple_save(
                    sess,
                    args.output_dir,
                    inputs={"input_lengths": input_lengths,"inputs": inputs,"split_infos":split_infos},
                    outputs={"audio": denoiser(tf.get_default_graph().get_tensor_by_name("MelGAN/Generator/Tanh:0"), args.param_denoise)})
            elif args.split_infos and not args.denoise:
                tf.saved_model.simple_save(
                    sess,
                    args.output_dir,
                    inputs={"input_lengths": input_lengths,"inputs": inputs,"split_infos":split_infos},
                    outputs={"audio": tf.get_default_graph().get_tensor_by_name("MelGAN/Generator/Tanh:0")})
            elif not args.split_infos and args.denoise:
                tf.saved_model.simple_save(
                    sess,
                    args.output_dir,
                    inputs={"input_lengths": input_lengths,"inputs": inputs},
                    outputs={"audio": denoiser(tf.get_default_graph().get_tensor_by_name("MelGAN/Generator/Tanh:0"), args.param_denoise)})
            else:
                tf.saved_model.simple_save(
                    sess,
                    args.output_dir,
                    inputs={"input_lengths": input_lengths,"inputs": inputs},
                    outputs={"audio": tf.get_default_graph().get_tensor_by_name("MelGAN/Generator/Tanh:0")})

            print("------your denoise is {}.------".format(args.denoise))
            print("------your split_infos is {}.------".format(args.split_infos))
            print("------successed merge graph to {} !------".format(args.output_dir))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--tacotron",
        type=str,
        default='/home/yichao.li/pb/meltron/t2origin.pb')
    parser.add_argument(
        "-m",
        "--melgan",
        type=str,
        default= '/home/yichao.li/pb/melgan/melgan.pb')
    parser.add_argument(
        "-d",
        "--denoise",
        type=bool,
        default=False)
    parser.add_argument(
        "-s",
        "--split_infos",
        type=bool,
        default=False)
    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        default='/home/yichao.li/pb/12')
    parser.add_argument(
        "-p",
        "--param_denoise",
        type=int,
        default=3000)
    args = parser.parse_args()
    merge(args)



usage: ipykernel_launcher.py [-h] [-t TACOTRON] [-m MELGAN] [-d DENOISE]
                             [-s SPLIT_INFOS] [-o OUTPUT_DIR]
                             [-p PARAM_DENOISE]
ipykernel_launcher.py: error: unrecognized arguments: -f /run/user/1002/jupyter/kernel-c4dfb8e3-e64f-415a-b569-19128ef827d9.json


SystemExit: 2