- 全ての手順については[こちら](nvc_train_v4.ipynb)
- 上から順に実行してください
- ファイルのアップロードは左メニューのフォルダのアイコンから行えます

## 3. 音源の学習処理

- この処理はTPU専用です。上部メニューの「ランタイム→ランタイムのタイプを変更」からTPUを使用するように設定してください
- この処理は自動では終了しません。また、再度実行すると途中から再開することができます。「4.」で確認して適宜実行や中断を行ってください

In [None]:
# Mount Google Drive at ./drive; later cells read and write paths under 'drive/My Drive/'.
from google.colab import drive
drive.mount('drive')

In [None]:
# Clone the helper repository into ./nvc_train_v4 (later cells load nvc_train_v4/phoneme.h5 and nvc_train_v4/silent.mc).
# NOTE(review): the active URL embeds credentials ("user:token@host"). Never commit a real token here;
# prefer the GitHub URL on the commented-out line or a git credential helper.
#!git clone --depth 1 "https://github.com/NON906/nvc_train_v4.git"
!git clone --depth 1 "https://NON906:****@gitlab.com/NON906/nvc_train_v4.git"

In [None]:
# ---
# 設定パラメータ（※実行前に設定してください）

# 入力の解析済みファイル
targets_zip_file = 'drive/My Drive/targets_UnityChan.zip' # 'drive/My Drive/targets.zip'

# ---

!unzip "{targets_zip_file}" -d targets > /dev/null

In [None]:
# ---
# Settings (set these before running)

# Output directory (checkpoints and epoch.txt are written here)
model_dir_path = 'drive/My Drive/nvc'

# ---

# Resume from the last saved epoch if epoch.txt exists; otherwise start from scratch.
try:
    with open(model_dir_path + '/epoch.txt', 'r') as epoch_file:
        start_file_epoch = int(epoch_file.read())
except FileNotFoundError:
    start_file_epoch = 0
    start_file_dir = None
else:
    start_file_dir = model_dir_path
# True: also restore the pickled optimizer state when resuming.
load_optim_weights = True
phoneme_model_path = 'nvc_train_v4/phoneme.h5'

# Colab-only magic selecting TensorFlow 2.x; it is placed before the tensorflow imports below.
# NOTE(review): newer Colab runtimes may no longer support this magic; confirm on the target runtime.
%tensorflow_version 2.x

import os
from tensorflow.keras.models import Sequential, load_model, Model
from tensorflow.keras.layers import Dense, Activation, LSTM, Reshape, Lambda, Layer, Input, Concatenate
from tensorflow.keras import backend as K
from tensorflow.keras.utils import Sequence
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import initializers
from tensorflow.keras.constraints import unit_norm
from tensorflow.keras.utils import CustomObjectScope
from tensorflow.keras.losses import Loss, MeanSquaredError, CategoricalCrossentropy
import glob
import struct
import numpy as np
import sys
import random
import math
import pickle
import tensorflow as tf
from tensorflow.python.framework import tensor_shape


class VoiceGenerator():
    """Build fixed-length, silence-padded feature sequences from analyzed voice files.

    For every ``*.pitch`` file found under ``dir_path``, the matching ``.mc`` file
    is read as frames of ``_ORDER + 1`` little-endian float32 values. Each frame
    then gets one scaled pitch value from the ``.pitch`` file appended, giving
    ``_ORDER + 2`` values per frame. Each file's frames are split into chunks of
    ``max_size`` frames (``_MAX_SIZE`` when ``max_size`` is None). Every chunk is
    right-padded with frames read from ``nvc_train_v4/silent.mc``, whose pitch
    slot is 0.0.

    Args:
        dir_path: directory searched recursively for ``*.pitch`` files.
        val_file: number of files at the end of the sorted file list reserved
            for ``train=False``; 0 means every file is used in both modes.
        batch_size: samples per batch; None sets it to the number of input files.
        length: minimum number of chunks collected by each ``get_input_voices``
            call; None reads every file exactly once.
        train: training mode; shuffles the file order on every pass.
        start_cut: initial value of the pass counter ``self.cut``.
        max_size: chunk length in frames; None derives it from the data (capped
            at ``_MAX_SIZE``). After loading it holds the chunk length used.
        gender: stored on the instance; not used by this class.

    Raises:
        ValueError: if a file must be read but no ``*.pitch`` files were found
            (e.g. ``length`` is set and ``dir_path`` contains no inputs).
    """

    _ORDER = 32           # each .mc frame holds _ORDER + 1 float32 values
    _CONCAT_MUL = 1       # frames per read group; trailing partial groups are ignored
    _MAX_SIZE = 512       # default chunk length (frames) when max_size is None
    _SILENT_PATH = 'nvc_train_v4/silent.mc'

    def __init__(self, dir_path, val_file, batch_size, length=None, train=True, start_cut=0, max_size=None, gender=None):
        self.length = length
        self.batch_size = batch_size
        self.train = train
        self.val_file = val_file

        self.index = 0
        self.gender = gender
        self.get_input_files(dir_path)

        self.max_size = max_size

        self.cut = start_cut

        # Padding frames: silent.mc frames with a 0.0 pitch slot appended.
        self.silent_array = [frame + [0.0] for frame in self._read_frames(self._SILENT_PATH, self._ORDER + 1)]

        self.get_input_voices()

    @classmethod
    def _read_frames(cls, path, floats_per_frame):
        """Return the complete float32 frames of ``path`` as lists of Python floats."""
        with open(path, 'rb') as f:
            data = f.read()
        group_bytes = 4 * floats_per_frame * cls._CONCAT_MUL
        # Drop trailing bytes that do not fill a whole read group.
        usable = (len(data) // group_bytes) * group_bytes
        fmt = '<' + str(floats_per_frame) + 'f'
        return [list(values) for values in struct.iter_unpack(fmt, data[:usable])]

    def get_input_files(self, dir_path):
        """Collect the ``*.pitch`` files for this mode (train / validation split)."""
        self.input_voices = glob.glob(dir_path + '/**/*.pitch', recursive=True)
        if self.val_file == 0:
            self.input_voices = sorted(self.input_voices)
        else:
            if self.train:
                self.input_voices = sorted(self.input_voices)[:-self.val_file]
            else:
                self.input_voices = sorted(self.input_voices)[-self.val_file:]
        if self.train:
            random.shuffle(self.input_voices)
        if self.batch_size is None:
            self.batch_size = len(self.input_voices)

    def get_input_voices(self):
        """(Re)load chunks into ``self.data_array``, continuing from ``self.index``."""
        self.data_array = None

        chunks = []
        max_array_size = 0

        while (self.length is None) or (len(chunks) < self.length):
            if self.index >= len(self.input_voices):
                # One full pass over the file list is complete.
                self.index = 0
                if self.train:
                    random.shuffle(self.input_voices)
                self.cut += 1
                if self.cut >= 1:
                    self.cut = 0
                    if self.length is None:
                        break
                if not self.input_voices:
                    raise ValueError('no *.pitch input files found')

            name, _ = os.path.splitext(self.input_voices[self.index])

            frames = self._read_frames(name + '.mc', self._ORDER + 1)
            for frame_idx, (pitch,) in enumerate(self._read_frames(name + '.pitch', 1)):
                # NOTE(review): scale constants are unexplained here; confirm their origin.
                frames[frame_idx].append(pitch / (24000.0 / 71.0) * 4.0)

            if self.max_size is None:
                chunk_size = self._MAX_SIZE
                if len(frames) > self._MAX_SIZE:
                    max_array_size = self._MAX_SIZE
                elif max_array_size < len(frames):
                    max_array_size = len(frames)
            else:
                chunk_size = self.max_size
                max_array_size = self.max_size
            # Ceiling split: no empty trailing chunk (which padding would turn into
            # an all-silence sample) when len(frames) is a multiple of chunk_size.
            for start in range(0, len(frames), chunk_size):
                chunks.append(frames[start:start + chunk_size])

            self.index += 1

        for chunk in chunks:
            chunk.extend(self.silent_array[:max_array_size - len(chunk)])

        self.data_array = chunks

        self.max_size = max_array_size

    def __len__(self):
        """Number of full batches in ``self.data_array``."""
        return len(self.data_array) // self.batch_size

    def on_epoch_end(self):
        """Load the next set of chunks when a per-epoch ``length`` is configured."""
        if self.length is not None:
            self.get_input_voices()

    def get_inputs(self):
        """Return ``(x, x)``: the chunks trimmed to whole batches, as one float32 array."""
        usable = (len(self.data_array) // self.batch_size) * self.batch_size
        batch_inputs = np.array(self.data_array[:usable], dtype='float32')
        return batch_inputs, batch_inputs


class AngleLoss(Loss):
    """Mean angle between paired 2-D vectors of ``y_true`` and ``y_pred``, scaled to [0, 1].

    The last axis of each input is regrouped into pairs, giving shape
    ``(-1, dim1, dim2 // 2, 2)``. The pairwise dot product is clipped just
    inside [-1, 1] and passed through ``acos``; the result is divided by pi
    and averaged over the pair axis. The dot product equals the cosine only
    when the pairs are unit vectors, as produced by the l2-normalized phoneme
    heads in this notebook (presumably the intended inputs; confirm at call sites).
    """

    def call(self, y_true, y_pred):
        def _as_pairs(tensor):
            static_shape = K.int_shape(tensor)
            return K.reshape(tensor, (-1, static_shape[1], static_shape[2] // 2, 2))

        dot = K.sum(_as_pairs(y_true) * _as_pairs(y_pred), axis=-1)
        eps = K.epsilon()
        angle = tf.acos(K.clip(dot, -1.0 + eps, 1.0 - eps))
        return K.mean(angle / math.pi, axis=-1)


os.makedirs(model_dir_path, exist_ok=True)

# A missing COLAB_TPU_ADDR is treated as "no TPU runtime selected". Only the env lookup
# is guarded, so KeyErrors raised inside the TPU initialization are not misreported.
try:
    tpu_address = os.environ["COLAB_TPU_ADDR"]
except KeyError:
    # For CPU debugging: strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
    print('TPUが設定されていません')
    # sys.exit raises SystemExit. The site-provided exit() helper is meant for the interactive
    # shell only; a nonzero status marks this as a failure.
    sys.exit(1)

tpu_grpc_url = "grpc://" + tpu_address
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
# NOTE(review): newer TF releases expose this as tf.distribute.TPUStrategy.
strategy = tf.distribute.experimental.TPUStrategy(tpu_cluster_resolver)

# Adapted from:
# https://blog.shikoan.com/distributed-train-decorator-in-tf20/
from enum import Enum

class Reduction(Enum):
    """How one per-replica result of a distributed step is combined."""
    NONE = 0    # return the per-replica value unchanged
    SUM = 1     # strategy.reduce(ReduceOp.SUM, ..., axis=None)
    MEAN = 2    # strategy.reduce(ReduceOp.MEAN, ..., axis=None)
    CONCAT = 3  # concatenate the local per-replica results along axis 0

def distrtibuted(*reduction_flags):
    """Decorator factory: run a step function on every replica of the global ``strategy``.

    Args:
        *reduction_flags: one ``Reduction`` per value returned by the decorated
            function; with no flags the function must return None.

    Returns:
        A decorator that wraps the function in a ``tf.function`` calling
        ``strategy.run`` and reducing each returned value by its flag.
    """
    def _decorator(fun):
        def per_replica_reduction(z, flag):
            if flag == Reduction.NONE:
                return z
            elif flag == Reduction.SUM:
                return strategy.reduce(tf.distribute.ReduceOp.SUM, z, axis=None)
            elif flag == Reduction.MEAN:
                return strategy.reduce(tf.distribute.ReduceOp.MEAN, z, axis=None)
            elif flag == Reduction.CONCAT:
                z_list = strategy.experimental_local_results(z)
                return tf.concat(z_list, axis=0)
            else:
                raise NotImplementedError()

        @tf.function
        def _decorated_fun(*args, **kwargs):
            fun_result = strategy.run(fun, args=args, kwargs=kwargs)
            if len(reduction_flags) == 0:
                assert fun_result is None
                return
            elif len(reduction_flags) == 1:
                assert type(fun_result) is not tuple and fun_result is not None
                return per_replica_reduction(fun_result, *reduction_flags)
            else:
                assert type(fun_result) is tuple
                # zip() would silently drop unmatched results or flags.
                assert len(fun_result) == len(reduction_flags), (
                    'got {} results for {} reduction flags'.format(len(fun_result), len(reduction_flags)))
                return tuple(per_replica_reduction(fr, rf) for fr, rf in zip(fun_result, reduction_flags))
        return _decorated_fun
    return _decorator

# Correctly spelled alias; the misspelled name is kept for existing callers.
distributed = distrtibuted

def _checkpoint_path(directory, kind, ckpt_epoch, ext):
    """Return the checkpoint file name used by this notebook: '<dir>/<kind>_<epoch:09d>.<ext>'."""
    return '{}/{}_{:09d}.{}'.format(directory, kind, ckpt_epoch, ext)


with strategy.scope():
    # Each epoch trains on at least `length` freshly loaded 512-frame chunks (trimmed to whole batches).
    length = 3200 #6400
    batch_size = 200  # global batch size; the distributed dataset splits it across replicas
    gen = VoiceGenerator('targets', 0, batch_size, length, train=True, max_size=512)
    shape0 = gen.get_inputs()[0].shape[1]  # frames per sample

    optim_gen = Adam(0.0002, 0.5)
    optim_dis = Adam(0.0002, 0.5)

    # Reduction.NONE: per-element losses are summed and divided by the global batch size below.
    loss_func_dis_real = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
    loss_func_dis_fake = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
    loss_func_dis_common = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
    loss_func_gen = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
    loss_func_gen_img = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
    loss_func_gen_phoneme = AngleLoss(reduction=tf.keras.losses.Reduction.NONE)

    # Generator: 128 phoneme features per frame -> [power (1), features (32), pitch (1)] = 34 values.
    input_layer = Input(shape=(shape0, 128), name='gen_input')

    gen_layers_pitch = LSTM(64, return_sequences=True, name='gen_pitch_lstm0')(input_layer)
    gen_layers_pitch = LSTM(64, return_sequences=True, name='gen_pitch_lstm1')(gen_layers_pitch)
    gen_layers_pitch = Dense(1, name='gen_pitch_dense')(gen_layers_pitch)

    # The main branch sees the predicted pitch but does not backpropagate into the pitch branch.
    pitch_stopgrad_layers = Lambda(lambda x: K.stop_gradient(x), name='gen_stopgrad')(gen_layers_pitch)
    concat_layers = Concatenate(name='gen_concat_0')([input_layer, pitch_stopgrad_layers])

    gen_layers_power = LSTM(64, return_sequences=True, name='gen_power_lstm0')(input_layer)
    gen_layers_power = LSTM(64, return_sequences=True, name='gen_power_lstm1')(gen_layers_power)
    gen_layers_power = Dense(1, name='gen_power_dense')(gen_layers_power)

    gen_layers = LSTM(128, return_sequences=True, name='gen_lstm0')(concat_layers)
    gen_layers = LSTM(128, return_sequences=True, name='gen_lstm1')(gen_layers)
    gen_layers = Dense(32, name='gen_dense')(gen_layers)

    gen_layers = Concatenate(name='gen_concat_1')([gen_layers_power, gen_layers, gen_layers_pitch])

    gen_model = Model(inputs=input_layer, outputs=gen_layers, name='gen_model')
    gen_model.summary()

    # Discriminator: each Reshape merges adjacent frame pairs (time halves three times),
    # so shape0 must be divisible by 8.
    input_layer = Input(shape=(shape0, 34), name='dis_input')

    dis_layers = LSTM(128, return_sequences=True, name='dis_lstm_s')(input_layer)
    dis_layers = LSTM(64, return_sequences=True, name='dis_lstm0')(dis_layers)
    dis_layers = Reshape((shape0 // 2, 128), name='dis_reshape0')(dis_layers)
    dis_layers = LSTM(64, return_sequences=True, name='dis_lstm1')(dis_layers)
    dis_layers = Reshape((shape0 // 4, 128), name='dis_reshape1')(dis_layers)
    dis_layers = LSTM(64, return_sequences=True, name='dis_lstm2')(dis_layers)
    dis_layers = Reshape((shape0 // 8, 128), name='dis_reshape2')(dis_layers)
    dis_layers = Dense(1, name='dis_dense')(dis_layers)
    #dis_layers = Activation('sigmoid', name='dis_sigmoid')(dis_layers)

    dis_model = Model(inputs=input_layer, outputs=dis_layers, name='dis_model')
    dis_model.summary()

    # Frozen, pre-trained phoneme model: 64 heads of l2-normalized 2-D vectors -> 128 features.
    input_layer = Input(shape=(shape0, 34), name='phoneme_input')

    f_layers = LSTM(128, return_sequences=True, name='phoneme_lstm0')(input_layer)
    f_layers = LSTM(128, return_sequences=True, name='phoneme_lstm1')(f_layers)
    loop_layers = []
    for head in range(64):
        # Layer names must match phoneme.h5 because weights are loaded by name.
        f_layers_loop = Dense(2, name='phoneme_dense_l' + str(head))(f_layers)
        f_layers_loop = Lambda(lambda x: K.l2_normalize(x, axis=-1), name='phoneme_norm_l' + str(head))(f_layers_loop)
        loop_layers.append(f_layers_loop)
    phoneme_layers = Concatenate(name='phoneme_concat')(loop_layers)

    phoneme_model = Model(inputs=input_layer, outputs=phoneme_layers, name='phoneme_model')
    phoneme_model.load_weights(phoneme_model_path, by_name=True)
    phoneme_model.trainable = False
    phoneme_model.summary()

    # NOTE(review): each replica's loss is already divided by the global batch_size, so
    # Reduction.SUM would give the global loss. MEAN prints it divided by the replica count.
    # Only the logged values are affected; confirm before changing (it alters log scale).
    @distrtibuted(Reduction.MEAN, Reduction.MEAN)
    def train_on_batch(real_img, real_img_out):
        # Both tapes record the shared generator forward pass; each is re-entered below
        # to record its own loss.
        with tf.GradientTape() as d_tape, tf.GradientTape() as g_tape:
            phoneme_true = phoneme_model(real_img)
            fake_img = gen_model(phoneme_true)

        with d_tape:
            real_out = dis_model(real_img_out)
            fake_out = dis_model(fake_img)

            d_real_loss = loss_func_dis_real(K.ones_like(real_out), real_out)
            d_fake_loss = loss_func_dis_fake(K.zeros_like(fake_out), fake_out)
            d_loss = d_real_loss * 0.4 + d_fake_loss * 0.8
            d_loss = tf.reduce_sum(d_loss) * (1.0 / batch_size)
        gradients = d_tape.gradient(d_loss, dis_model.trainable_weights)
        optim_dis.apply_gradients(zip(gradients, dis_model.trainable_weights))

        with g_tape:
            # The discriminator is re-evaluated after its update above.
            fake_out = dis_model(fake_img)
            #phoneme_pred = phoneme_model(fake_img)

            g_dis_loss = loss_func_gen(K.ones_like(fake_out), fake_out)
            g_dis_loss = tf.reduce_sum(g_dis_loss) * (1.0 / batch_size)
            g_img_loss = loss_func_gen_img(real_img, fake_img)
            g_img_loss = tf.reduce_sum(g_img_loss) * (1.0 / batch_size)
            #g_phoneme_loss = loss_func_gen_phoneme(phoneme_true, phoneme_pred)
            #g_phoneme_loss = tf.reduce_sum(g_phoneme_loss) * (1.0 / batch_size)
            g_loss = g_dis_loss * 0.3 + g_img_loss * 0.5 #g_dis_loss * 0.1 + g_img_loss * 0.5 + g_phoneme_loss * 0.1
        gradients = g_tape.gradient(g_loss, gen_model.trainable_weights)
        optim_gen.apply_gradients(zip(gradients, gen_model.trainable_weights))

        return d_loss, g_loss

    # Resume model weights only (no optimizer state): load ONCE, before training.
    # Doing this inside the epoch loop would reset both models at the start of every epoch.
    if not load_optim_weights and start_file_dir is not None:
        gen_model.load_weights(_checkpoint_path(start_file_dir, 'gen', start_file_epoch, 'h5'), by_name=True)
        dis_model.load_weights(_checkpoint_path(start_file_dir, 'dis', start_file_epoch, 'h5'), by_name=True)

    start_train = True
    epoch = start_file_epoch + 1
    while True:
        real_img, real_img_out = gen.get_inputs()
        gen.on_epoch_end()  # preload the next epoch's chunks

        trainset = tf.data.Dataset.from_tensor_slices((real_img, real_img_out))
        trainset = trainset.shuffle(buffer_size=length).batch(batch_size)
        trainset = strategy.experimental_distribute_dataset(trainset)

        for X, y in trainset:
            dis_loss_val, gen_loss_val = train_on_batch(X, y)
            if start_train and start_file_dir is not None and load_optim_weights:
                # Full resume (weights + optimizer state) happens after the first step.
                # Presumably this is because Keras optimizers create their slot variables on
                # the first apply_gradients, so set_weights needs them to exist.
                # That first step's update is overwritten by the reload.
                gen_model.load_weights(_checkpoint_path(start_file_dir, 'gen', start_file_epoch, 'h5'), by_name=True)
                dis_model.load_weights(_checkpoint_path(start_file_dir, 'dis', start_file_epoch, 'h5'), by_name=True)
                # Pickles written by this notebook; do not point this at untrusted files.
                with open(_checkpoint_path(start_file_dir, 'gen', start_file_epoch, 'pkl'), 'rb') as f:
                    weight_values = pickle.load(f)
                optim_gen.set_weights(weight_values)
                with open(_checkpoint_path(start_file_dir, 'dis', start_file_epoch, 'pkl'), 'rb') as f:
                    weight_values = pickle.load(f)
                optim_dis.set_weights(weight_values)
                start_train = False

        # Losses of the last batch of this epoch.
        print('{:09d},{:.6f},{:.6f}'.format(epoch, float(gen_loss_val.numpy()), float(dis_loss_val.numpy())))

        if epoch % 10 == 0:
            weight_values_combined = K.batch_get_value(optim_gen.weights)
            with open(_checkpoint_path(model_dir_path, 'gen', epoch, 'pkl'), 'wb') as f:
                pickle.dump(weight_values_combined, f)

            weight_values_dis = K.batch_get_value(optim_dis.weights)
            with open(_checkpoint_path(model_dir_path, 'dis', epoch, 'pkl'), 'wb') as f:
                pickle.dump(weight_values_dis, f)

            gen_model.save(_checkpoint_path(model_dir_path, 'gen', epoch, 'h5'), include_optimizer=False)
            dis_model.save(_checkpoint_path(model_dir_path, 'dis', epoch, 'h5'), include_optimizer=False)

            # Written last so a resume only ever points at a fully saved checkpoint.
            with open(model_dir_path + '/epoch.txt', 'w') as epoch_file:
                epoch_file.write(str(epoch))

        epoch += 1