In [None]:
# NOTE(review): this bare import only resolves when a top-level `autoreload` module is on
# sys.path; the `%load_ext` magic below loads the extension by name on its own — confirm
# whether this line is still needed in the target IPython version.
import autoreload
# Load the autoreload extension and switch it to mode 2 so edits to the project modules
# (modules, loss, utils, ...) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Re-initialise the autoreload extension (only matters when this cell is re-run in a live kernel).
%reload_ext autoreload

In [None]:
import os
# GPU selection: order devices by PCI bus id and expose only device "0" to this process.
# These are set in a cell that runs before the TensorFlow import cell.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # see issue #152
    "CUDA_VISIBLE_DEVICES": "0",
})
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

In [None]:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input , Dense, Lambda, Concatenate, BatchNormalization
from tensorflow.keras.activations import relu, sigmoid, tanh
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam, SGD, schedules
from tensorflow_addons.optimizers import AdamW

from tensorflow.keras.callbacks import ModelCheckpoint, TerminateOnNaN, TensorBoard, \
                                        EarlyStopping, ReduceLROnPlateau, CSVLogger, Callback

import numpy as np
import random
import os

import gc

from tqdm import tqdm 

In [None]:
import import_ipynb

from modules import convkxf, GroupGRULayer, GroupFC
from loss import MaskLoss, LocalSnrTarget, DfAlphaLoss, SpectralLoss, SISNR_Loss
from utils import mask_operations, df_operations, synthesis_frame, df_operations_wo_alpha
from params import model_params
from dataloader import read_tfrecod_data

In [None]:
from Callbacks import WarmUpCosineDecayScheduler, BatchCallback, ClearMemory, ModelCheckpoint

In [None]:
# from tensorflow.compat.v1 import ConfigProto, InteractiveSession, enable_eager_execution

# config = ConfigProto(allow_soft_placement=True)
# config.gpu_options.per_process_gpu_memory_fraction = 0.6
# config.gpu_options.allow_growth = True
# InteractiveSession(config=config)
# enable_eager_execution(config=config)

# tf.compat.v1.disable_eager_execution()

# Read config

In [None]:
# Load the hyper-parameters from config.ini; later cells read them as attributes of `p`
# (e.g. p.batch_size, p.lr, p.epochs, p.nb_erb, p.nb_df).
p = model_params('config.ini')

# Read data

In [None]:
# Root directory holding the TFRecord data. Set the DFNET_DATASET_DIR environment variable
# to point the notebook at another location; the default is the original hard-coded path.
dataset_path = os.environ.get('DFNET_DATASET_DIR', '/home2/user/myhsueh/dataset/tfrecord/')

# Sub-corpora under the root, in the order they are passed to the data loader.
# Each one contains a `training` and a `validation` directory.
dataset_subsets = ['rir', 'no_rir', 'rir/KB', 'no_rir/KB', 'no_rir/KB2']

training_set_path = [os.path.join(dataset_path, subset, 'training') for subset in dataset_subsets]
validation_set_path = [os.path.join(dataset_path, subset, 'validation') for subset in dataset_subsets]

# True when more than one directory contributes to the respective split.
combine_train = len(training_set_path) > 1
combine_val = len(validation_set_path) > 1

In [None]:
# Build the training dataset from every training directory
# (`read_tfrecod_data` is spelled this way in the project's dataloader module).
dataset = read_tfrecod_data(training_set_path)

In [None]:
# Build the validation dataset; training=False is forwarded to the loader
# (its exact effect is defined in dataloader.read_tfrecod_data — not visible here).
val_dataset = read_tfrecod_data(validation_set_path, training=False)

In [None]:
# Per-file multiplier used to turn a file count into a number of examples.
# NOTE(review): 6000 and 4 are undocumented constants — presumably they describe how the
# TFRecord files were generated; confirm against the dataset-creation script.
each_tfr_batch = 6000 // 4 // p.length_sec

# Count the entries of every dataset directory. Summing over the list covers both the
# single-directory and the multi-directory case, so no separate "combine" branch is needed.
train_count = sum(len(os.listdir(path)) for path in training_set_path)
val_count = sum(len(os.listdir(path)) for path in validation_set_path)

# Number of optimizer / evaluation steps that make up one pass over each split.
steps_per_epoch = train_count * each_tfr_batch // p.batch_size
validation_steps = val_count * each_tfr_batch // p.batch_size
print('training step: {} , validation step: {}'.format(steps_per_epoch, validation_steps))

# Model construction

In [None]:
# Batch-norm variant forwarded to every convkxf call below.
# Presumably the accepted values are "range" and "normal" — confirm in modules.convkxf.
BN_type = "normal" # range normal

In [None]:
# Symbolic model inputs. The leading None leaves that axis variable-length
# (presumably the time-frame axis — confirm against the data loader).
# The inputs marked "complex" carry a last axis of size 2 (presumably real/imag parts).
erb_inputs = Input(shape=(None, p.nb_erb, 1), name='ERB_input')
spec_inputs = Input(shape=(None, p.nb_df, 2), name='spec_input') # complex
# clean_spec is consumed only by the loss layers; noisy_spec also feeds the enhancement ops.
clean_spec = Input(shape=(None, p.fft_size//2+1, 2), name='clean_spec') # complex
noisy_spec = Input(shape=(None, p.fft_size//2+1, 2), name='noisy_spec') # complex

In [None]:
# --- Encoder ---------------------------------------------------------------------------
# ERB branch: four stacked convkxf layers; conv1 and conv2 use fstride=2, conv0 and conv3 use fstride=1.
# NOTE(review): only the two first-layer calls (conv0_encoder, df_conv0_encoder) pass
# training=True — confirm this asymmetry is intended.
erb_conv0 = convkxf(erb_inputs, p.conv_out_ch, k=3, f=3, fstride=1, bias=False, batch_norm=True, 
                    training=True, infer=True, BN_type = BN_type, name='conv0_encoder')
erb_conv1 = convkxf(erb_conv0, p.conv_out_ch, k=1, f=3, fstride=2, bias=False, batch_norm=True,
                    infer=True, BN_type = BN_type, name='conv1_encoder')
erb_conv2 = convkxf(erb_conv1, p.conv_out_ch, k=1, f=3, fstride=2, bias=False, batch_norm=True,
                    infer=True, BN_type = BN_type, name='conv2_encoder')
erb_conv3 = convkxf(erb_conv2, p.conv_out_ch, k=1, f=3, fstride=1, bias=False, batch_norm=True,
                    infer=True, BN_type = BN_type, name='conv3_encoder')

# Spectrum (DF) branch: two convkxf layers on the complex input, the second with fstride=2.
df_conv0 = convkxf(spec_inputs, p.conv_out_ch, k=3, f=3, fstride=1, batch_norm=True, training=True,
                   infer=True, BN_type = BN_type, name='df_conv0_encoder')
df_conv1 = convkxf(df_conv0, p.conv_out_ch, k=1, f=3, fstride=2, batch_norm=True,
                   infer=True, BN_type = BN_type, name='df_conv1_encoder')

# Dynamic 4-D shape of df_conv1, captured before the transpose.
shape = [tf.shape(df_conv1)[l] for l in range(4)]
# Swap the last two axes, then merge them into a single feature axis (rank 4 -> rank 3).
df_conv1 = tf.transpose(df_conv1,(0,1,3,2))
df_conv1 = tf.reshape(df_conv1, [shape[0], shape[1], shape[2]*shape[3]])

# Grouped fully-connected projection of the flattened DF features.
cemb = GroupFC(df_conv1, p.fc_hidden, p.fc_group, infer=True, name='GFC_encoder')

# Flatten erb_conv3 the same way (swap last two axes, merge) and concatenate with cemb.
shape = [tf.shape(erb_conv3)[l] for l in range(4)]
emb = Concatenate()([cemb, tf.reshape(tf.transpose(erb_conv3,(0,1,3,2)), [shape[0], shape[1], shape[2]*shape[3]])])

# Shared recurrent embedding; both decoders below start from GRU_emb.
GRU_emb = GroupGRULayer(emb, p.gru_hidden, p.gru_group, num_layer=1, name='GGRU0', add_output=True, norm=True)

In [None]:
# --- ERB (mask) decoder: embedding --------------------------------------------------------
# Decoder-specific recurrent stack on top of the shared embedding.
GRU_emb1 = GroupGRULayer(GRU_emb, p.gru_hidden, p.gru_group, num_layer=2, name='GGRU01', add_output=True, norm=True)
# GRU_emb1 = BatchNormalization()(GRU_emb1)
de_emb = GroupFC(GRU_emb1, p.fc_hidden, p.fc_group, activation='relu', infer=True, name='GFC_decoder', norm=True)

# Reshape the rank-3 embedding to the dynamic 4-D shape of erb_conv3 so it can be added to
# the skip path. This requires p.fc_hidden to equal the product of erb_conv3's last two dims.
# NOTE(review): the encoder flattened erb_conv3 after swapping its last two axes, while this
# reshape restores the un-swapped layout directly — confirm that is intended.
shape = [tf.shape(erb_conv3)[l] for l in range(4)]
# shape = erb_conv3.get_shape().as_list()
emb_decoder = tf.reshape(de_emb, shape = [shape[0], shape[1], shape[2], shape[3]], name='decoder_reshape')

# Keyword presets shared by the decoder convkxf calls.
# Base preset: k=1 convolution with batch norm.
kwargs = dict(k=1, batch_norm=True, BN_type=BN_type, infer=True)
# Same preset with mode="transposed" added.
tkwargs = dict(kwargs, mode="transposed")
# Same preset with f=1 added (used by the skip-path convolutions).
pkwargs = dict(kwargs, f=1)

# --- ERB (mask) decoder: convolution path -------------------------------------------------
# Each stage adds a convolved skip connection from the matching encoder layer (convpN, using
# pkwargs) to the output of the previous decoder stage, then applies the next decoder conv.
convp3 = convkxf(erb_conv3, out_ch=p.conv_out_ch, name='convp3', **pkwargs)  # Conv
vt3_in = convp3 + emb_decoder
# NOTE(review): despite the "ConvT" tag this call uses `kwargs`, which has no mode="transposed".
convt3 = convkxf(vt3_in, out_ch=p.conv_out_ch, fstride=1, name='convt3', **kwargs) #ConvT

convp2 = convkxf(erb_conv2, out_ch=p.conv_out_ch, name='convp2', **pkwargs) # Conv
vt2_in = convp2 + convt3
convt2 = convkxf(vt2_in, out_ch=p.conv_out_ch, name='convt2', **tkwargs)

convp1 = convkxf(erb_conv1, out_ch=p.conv_out_ch, name='convp1', **pkwargs) # Conv
vt1_in = convp1 + convt2
convt1 = convkxf(vt1_in, out_ch=p.conv_out_ch, name='convt1', **tkwargs)

convp0 = convkxf(erb_conv0, out_ch=p.conv_out_ch, name='convp0', **pkwargs) # Conv
vt0_in = convp0 + convt1
# Final single-channel, sigmoid-activated conv producing the mask; no batch norm here.
mask_out = convkxf(vt0_in, out_ch=1, k=1, fstride=1, act='sigmoid', batch_norm=False, 
                   reshape=True, name='mask_out')

In [None]:
from modules import constraint

In [None]:
# --- Deep-filter (DF) decoder --------------------------------------------------------------
# Decoder-specific recurrent stack on top of the shared embedding.
GRU_emb2 = GroupGRULayer(GRU_emb, p.gru_hidden, p.gru_group, num_layer=2, name='GGRU1', add_output=True, norm=True)
# GRU_emb2 = BatchNormalization()(GRU_emb2)
# Skip path from the first DF encoder layer with 2*p.df_order output channels;
# the transpose swaps its last two axes so it lines up with `c` below.
convp = convkxf(df_conv0, 2*p.df_order, k=1, f=1, complex_in=True, batch_norm=True, 
                BN_type = BN_type, infer=True, name='convp_DfDecoder') 
convp = tf.transpose(convp, (0,1,3,2))

# One sigmoid-activated value per step; kernel and bias use the project's `constraint`
# (imported from modules — its definition is not visible here).
df_alpha = Dense(1, name='convp_alpha', activation='sigmoid', # unconstrained alternative: close the call here
                        kernel_constraint= constraint, 
                        bias_constraint= constraint)(GRU_emb2)
# Earlier Dense-based variant of `c`, kept for reference:
# c = Dense(p.nb_df*p.df_order*2, name='convp_c', activation='tanh')(GRU_emb2)
#                         kernel_constraint= constraint, 
#                         bias_constraint= constraint)(GRU_emb2)
# tanh-activated grouped FC producing p.nb_df*p.df_order*2 values per step.
c = GroupFC(GRU_emb2, p.nb_df*p.df_order*2, p.fc_group, activation='tanh', name='GFC_c')

# Only the first two (dynamic) dims are read; the remaining dims come from the config.
shape = [tf.shape(c)[k] for k in range(2)]
c = tf.reshape(c,(shape[0], shape[1], p.df_order*2, p.nb_df))

# Add the skip path, then split the df_order*2 axis into (df_order, 2).
c = tf.reshape(c + convp,(shape[0], shape[1], p.df_order, 2, p.nb_df))

# Move the size-2 axis last: (..., df_order, 2, nb_df) -> (..., df_order, nb_df, 2).
df_coeff = tf.transpose(c, [0, 1, 2, 4, 3], name='output_c')

# Enhance operation

In [None]:
# Two-stage enhancement: apply the predicted mask to the noisy spectrum, then run the
# deep-filter step with the predicted coefficients and alpha (both helpers live in utils).
spec = mask_operations(noisy_spec, mask_out) # mask gain
enhanced = df_operations(spec, df_coeff, df_alpha) # deep filter

# Callbacks

In [None]:
# Abort training as soon as a NaN loss is encountered.
callbacks_nan = TerminateOnNaN()
# Stop after 5 epochs without improvement and restore the best weights. No `monitor`
# argument is passed, so the callback's default monitored quantity applies.
callbacks_earlystop = EarlyStopping(patience=5, mode='min', restore_best_weights=True)

In [None]:
# Tensorboard
logdir = "./logs_6"
callbacks_tensorboard = TensorBoard(log_dir=logdir)
tb_callback = BatchCallback(callbacks_tensorboard, logdir)

# Memory clear
Memory_callback=ClearMemory()

# Loss

In [None]:
# The losses are computed inside the graph: each Lambda layer wraps a project loss function,
# and the resulting tensors are used as the model outputs.
# Mask loss: compares the predicted mask against the clean and noisy spectra.
maskloss = Lambda(lambda x:MaskLoss(*x, factor=p.mask_factor, r=p.mask_gamma), 
                  name='maskloss')([mask_out, clean_spec, noisy_spec])

# Spectral loss between the enhanced and the clean spectrum; the same p.df_factor is used
# for both factor_mag and factor_img.
spectralloss = Lambda(lambda x:SpectralLoss(*x, gamma=p.df_gamma, factor_mag=p.df_factor, factor_img=p.df_factor), 
                      name='spectralloss')([enhanced, clean_spec])

# Model compile

In [None]:
# The clean and noisy spectra are model inputs because the losses are evaluated in-graph;
# the model outputs are the two loss tensors rather than the enhanced spectrum.
inputs = [erb_inputs, spec_inputs, clean_spec, noisy_spec]
outputs = [maskloss, spectralloss]

In [None]:
# Assemble the functional Keras model from the tensors defined above.
model=Model(inputs=inputs, outputs=outputs, name='DfNet')

In [None]:
# Print the layer table and parameter counts as a sanity check.
model.summary()

In [None]:
# Learning-rate schedule (project callback from Callbacks): warm-up from 0.0 to p.lr over the
# first two epochs' worth of steps, with no hold phase; the schedule spans
# p.epochs * steps_per_epoch steps in total and uses mini_lr=1e-7
# (presumably a lower bound on the decayed rate — confirm in Callbacks).
# global_step_init is written as N*steps_per_epoch so that N can be raised when resuming.
warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=p.lr,
                                        total_steps=p.epochs*steps_per_epoch,
                                        warmup_learning_rate=0.0,
                                        warmup_steps=2*steps_per_epoch,
                                        hold_base_rate_steps=0,
                                        mini_lr=1e-7,
                                        steps_per_epoch=steps_per_epoch,
                                        global_step_init=0*steps_per_epoch)

In [None]:
# Checkpoint
cp_callback = ModelCheckpoint(model, './weights2_6/weights.{epoch:02d}-{loss:.2f}.h5')
callbacks_overall = [tb_callback, cp_callback, warm_up_lr,
                     callbacks_nan, callbacks_earlystop, Memory_callback]

In [None]:
# Report both loss tensors as metrics so they show up in the training logs.
model.add_metric(maskloss, name = "maskloss")
model.add_metric(spectralloss, name = "spectralloss")
# Only the spectral loss is registered as a loss here; compile() below passes no `loss`.
# NOTE(review): maskloss is tracked as a metric but not added as a loss — confirm it is
# intentionally excluded from the objective.
model.add_loss(spectralloss)

In [None]:
# AdamW from tensorflow_addons with weight decay set to 0.0 and clipnorm=1.
# The learning rate passed here is the base rate also given to the warm-up scheduler callback.
optimizer = AdamW(learning_rate=p.lr, weight_decay=0.0, clipnorm=1)
# optimizer = Adam(learning_rate=p.lr)

In [None]:
# No `loss` argument: the training objective comes from model.add_loss.
# run_eagerly=True makes Keras execute the train step eagerly instead of as a compiled graph.
model.compile(optimizer=optimizer, run_eagerly=True) #, loss_weights={"dfalphaloss": 1, "spectralloss": 20}) 

# Pre-trained weights

In [None]:
# h5_lists = os.listdir("./weights2_6")
# h5_lists.sort(key=lambda fn:os.path.getmtime("./weights2_6/" + fn)
#                 if not os.path.isdir("./weights2_6/" + fn) else 0)
# new_file = os.path.join("./weights2_6/", h5_lists[-1])
# print(new_file)

# model.load_weights(new_file, by_name=True)
# model.save("weights.hdf5")

In [None]:
# import h5py
# keys = []
# with h5py.File(new_file,'r') as f: # open file
#     f.visit(keys.append) # append all keys to list
#     for key in keys:
#         if ':' in key: # contains data if ':' in key
#             print(f[key].name)
    
# f = h5py.File(new_file,'r')
# group = f[key]

In [None]:
# def get_param(folding_layer, weight_name):
#     layer = '/' + folding_layer + '/' + folding_layer + '/'
#     if folding_layer == 'GGRU00': weight_name = 'gru_cell/'+ weight_name
#     if folding_layer == 'GGRU010': weight_name = 'gru_cell_1/'+ weight_name
#     if folding_layer == 'GGRU011': weight_name = 'gru_cell_2/'+ weight_name
#     if folding_layer == 'GGRU10': weight_name = 'gru_cell_3/'+ weight_name
#     if folding_layer == 'GGRU11': weight_name = 'gru_cell_4/'+ weight_name
#     return group[layer+weight_name+':0'][:]

In [None]:
# # store weights before loading pre-trained weights
# preloaded_layers = model.layers.copy()
# preloaded_weights = []
# for pre in preloaded_layers:
#     preloaded_weights.append(pre.get_weights())

In [None]:
# for layer in model.layers:
#     if layer.name[:4] == 'GGRU':
#         kernel_weight = get_param(layer.name, 'kernel')
#         recurrent_weights = get_param(layer.name, 'recurrent_kernel')
#         bias = get_param(layer.name, 'bias')
#         layer.set_weights([kernel_weight, recurrent_weights, bias])
#     else:
#         try:
#             kernel_weight = get_param(layer.name, 'kernel')
#             try:
#                 bias = get_param(layer.name, 'bias')
#                 layer.set_weights([kernel_weight, bias])
#             except:
#                 layer.set_weights([kernel_weight])
#         except:
#              if isinstance(layer, tf.keras.layers.BatchNormalization):
#                 gamma = get_param(layer.name, 'gamma')
#                 beta = get_param(layer.name, 'beta')
#                 moving_mean = get_param(layer.name, 'moving_mean')
#                 moving_variance = get_param(layer.name, 'moving_variance') 
#                 layer.set_weights([gamma, beta, moving_mean, moving_variance])

In [None]:
# # # store weights before loading pre-trained weights
# # preloaded_layers = model.layers.copy()
# # preloaded_weights = []
# # for pre in preloaded_layers:
# #     preloaded_weights.append(pre.get_weights())

# # load pre-trained weights
# # model.load_weights(new_file, by_name=True)

# # # compare previews weights vs loaded weights
# for layer, pre in zip(model.layers, preloaded_weights):
#     weights = layer.get_weights()

#     if weights:
#         if np.array_equal(weights, pre):
#             print('not loaded', layer.name)
#         else:
#             print('loaded', layer.name)
#             pass

# Training

In [None]:
# for layer in model.layers:
#     layer.trainable = True

In [None]:
# Run training. steps_per_epoch / validation_steps are passed explicitly, using the step
# counts derived from the directory listings earlier in the notebook.
# Raise initial_epoch when resuming from a checkpoint.
history = model.fit(dataset,
                    epochs=p.epochs,
                    validation_data=val_dataset,
                    steps_per_epoch=steps_per_epoch,
                    validation_steps=validation_steps,
                    callbacks=callbacks_overall,
                    initial_epoch=0)

In [None]:
# 4_1 continue training w/ batch 24
# 4_2 continue training w/ batch 32 w/o constraint

# 5 new training bias=False for convkxf
# 6 new training w/ BN for FC and GRU