In [1]:
import autoreload
%load_ext autoreload
%autoreload 2

In [2]:
%reload_ext autoreload

In [3]:
import os
# These variables are read when CUDA is initialised, so they have to be set
# before TensorFlow is imported in the next cell: number GPUs in PCI bus order
# and expose only device 0 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # see issue #152
    "CUDA_VISIBLE_DEVICES": "0",
})
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  

In [4]:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input , Dense, Lambda, Normalization, Concatenate
from tensorflow.keras.activations import relu, sigmoid, tanh
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam, SGD, schedules
from tensorflow_addons.optimizers import AdamW, extend_with_decoupled_weight_decay

from tensorflow.keras.callbacks import ModelCheckpoint, TerminateOnNaN, TensorBoard, \
                                        EarlyStopping, ReduceLROnPlateau, CSVLogger, Callback

import numpy as np
import random
import os

import gc

from tqdm import tqdm 

TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.


In [5]:
import import_ipynb

from modules import convkxf, GroupGRULayer, GroupFC
from loss import MaskLoss, LocalSnrTarget, DfAlphaLoss, SpectralLoss, SISNR_Loss
from utils import mask_operations, df_operations, synthesis_frame, df_operations_wo_alpha
from params import model_params
from dataloader import read_tfrecod_data

importing Jupyter notebook from modules.ipynb
importing Jupyter notebook from loss.ipynb
importing Jupyter notebook from bandERB.ipynb
importing Jupyter notebook from params.ipynb
importing Jupyter notebook from utils.ipynb
importing Jupyter notebook from dataloader.ipynb


In [6]:
from tensorflow.compat.v1 import ConfigProto, InteractiveSession, enable_eager_execution

# TF1-style session configuration: soft device placement, at most 60% of GPU
# memory for this process, and on-demand growth of the allocation.
config = ConfigProto(allow_soft_placement=True)
config.gpu_options.per_process_gpu_memory_fraction = 0.6
config.gpu_options.allow_growth = True
# The session object is not kept; it is created only for its side effects
# (it installs itself as the default session).
# NOTE(review): confirm this is still needed in a TF2 / eager notebook.
InteractiveSession(config=config)
# NOTE(review): eager execution is already on under TF2. If an eager context
# already exists, `config` is probably not applied here. TODO confirm the GPU
# memory limits really take effect; tf.config.experimental.set_memory_growth is
# the TF2 mechanism.
enable_eager_execution(config=config)

# tf.compat.v1.disable_eager_execution()

# Read config

In [7]:
# Load hyper-parameters from config.ini; later cells read them as attributes of
# `p` (e.g. p.batch_size, p.nb_erb, p.mask_only).
p = model_params('config.ini')

# Read data

In [8]:
# Root directory holding the TFRecord shards. Set the DATASET_PATH environment
# variable to point elsewhere instead of editing this absolute path; the default
# is the original location.
dataset_path = os.environ.get('DATASET_PATH', '/home2/user/myhsueh/dataset/tfrecord/')

# Subset directories (relative to dataset_path); each holds a `training` and a
# `validation` split.
DATASET_SUBSETS = ['rir', 'no_rir', 'rir/KB', 'no_rir/KB', 'no_rir/KB2']

training_set_path = [os.path.join(dataset_path, subset, 'training') for subset in DATASET_SUBSETS]
validation_set_path = [os.path.join(dataset_path, subset, 'validation') for subset in DATASET_SUBSETS]

# With more than one source directory, the step count below sums over all of them.
combine_train = len(training_set_path) > 1
combine_val = len(validation_set_path) > 1

In [9]:
# Training input pipeline built from every directory in `training_set_path`
# (the reader prints the file list and tensor shapes it resolves).
dataset = read_tfrecod_data(training_set_path)

['/home2/user/myhsueh/dataset/tfrecord/rir/training', '/home2/user/myhsueh/dataset/tfrecord/no_rir/training', '/home2/user/myhsueh/dataset/tfrecord/rir/KB/training', '/home2/user/myhsueh/dataset/tfrecord/no_rir/KB/training', '/home2/user/myhsueh/dataset/tfrecord/no_rir/KB2/training'] , file directory exist: True
384
TD_30-2-of-4.tfrecord, TD_41-2-of-4.tfrecord, TD_96-4-of-4.tfrecord, TD_112-2-of-4.tfrecord
TD_80-2-of-4.tfrecord, TD_57-4-of-4.tfrecord, TD_71-2-of-4.tfrecord, TD_102-4-of-4.tfrecord
TD_40-2-of-4.tfrecord, TD_66-3-of-4.tfrecord, TD_55-4-of-4.tfrecord, TD_111-1-of-4.tfrecord
TD_52-2-of-4.tfrecord, TD_34-2-of-4.tfrecord, TD_114-4-of-4.tfrecord, TD_100-1-of-4.tfrecord
TD_40-1-of-4.tfrecord, TD_118-3-of-4.tfrecord, TD_34-4-of-4.tfrecord, TD_71-1-of-4.tfrecord
TD_105-1-of-4.tfrecord, TD_34-1-of-4.tfrecord, TD_98-2-of-4.tfrecord, TD_32-3-of-4.tfrecord
TD_90-1-of-4.tfrecord, TD_105-4-of-4.tfrecord, TD_116-4-of-4.tfrecord, TD_99-4-of-4.tfrecord
TD_49-4-of-4.tfrecord, TD_73-4-of-4.

done erb
done spec
Tensor("Reshape_6:0", shape=(300, 481, 2), dtype=float32)
ERB shape: (300, 32, 1) Tensor("Squeeze_300:0", shape=(300, 32, 1), dtype=float32)
NOISY_SPEC_df shape: (300, 96, 2) Tensor("Reshape_4:0", shape=(300, 96, 2), dtype=float32)
CLEAN_SPEC shape: (300, 481, 2) Tensor("Reshape_6:0", shape=(300, 481, 2), dtype=float32)
NOISY_SPEC shape: (300, 481, 2) Tensor("Reshape_5:0", shape=(300, 481, 2), dtype=float32)
['/home2/user/myhsueh/dataset/tfrecord/rir/training', '/home2/user/myhsueh/dataset/tfrecord/no_rir/training', '/home2/user/myhsueh/dataset/tfrecord/rir/KB/training', '/home2/user/myhsueh/dataset/tfrecord/no_rir/KB/training', '/home2/user/myhsueh/dataset/tfrecord/no_rir/KB2/training'] , file directory exist: True
239
TD_30-2-of-4.tfrecord, TD_41-2-of-4.tfrecord, TD_57-4-of-4.tfrecord, TD_40-2-of-4.tfrecord
TD_66-3-of-4.tfrecord, TD_52-2-of-4.tfrecord, TD_1-2-of-4.tfrecord, TD_40-1-of-4.tfrecord
TD_34-4-of-4.tfrecord, TD_34-1-of-4.tfrecord, TD_10-1-of-4.tfrecord, T

done spec
Tensor("Reshape_6:0", shape=(300, 481, 2), dtype=float32)
ERB shape: (300, 32, 1) Tensor("Squeeze_300:0", shape=(300, 32, 1), dtype=float32)
NOISY_SPEC_df shape: (300, 96, 2) Tensor("Reshape_4:0", shape=(300, 96, 2), dtype=float32)
CLEAN_SPEC shape: (300, 481, 2) Tensor("Reshape_6:0", shape=(300, 481, 2), dtype=float32)
NOISY_SPEC shape: (300, 481, 2) Tensor("Reshape_5:0", shape=(300, 481, 2), dtype=float32)
['/home2/user/myhsueh/dataset/tfrecord/rir/training', '/home2/user/myhsueh/dataset/tfrecord/no_rir/training', '/home2/user/myhsueh/dataset/tfrecord/rir/KB/training', '/home2/user/myhsueh/dataset/tfrecord/no_rir/KB/training', '/home2/user/myhsueh/dataset/tfrecord/no_rir/KB2/training'] , file directory exist: True
101
TD_90KB-2-of-4.tfrecord, TD_85KB-4-of-4.tfrecord, TD_80KB-4-of-4.tfrecord, TD_78KB-4-of-4.tfrecord
TD_73KB-2-of-4.tfrecord, TD_77KB-4-of-4.tfrecord, TD_94KB-2-of-4.tfrecord, TD_94KB-1-of-4.tfrecord
TD_70KB-1-of-4.tfrecord, TD_98KB-2-of-4.tfrecord, TD_88KB-4-of

In [10]:
# Validation input pipeline. `training=False` is forwarded to the reader
# (presumably disables shuffling/augmentation — TODO confirm in dataloader.ipynb).
val_dataset = read_tfrecod_data(validation_set_path, training=False)

['/home2/user/myhsueh/dataset/tfrecord/rir/validation', '/home2/user/myhsueh/dataset/tfrecord/no_rir/validation', '/home2/user/myhsueh/dataset/tfrecord/rir/KB/validation', '/home2/user/myhsueh/dataset/tfrecord/no_rir/KB/validation', '/home2/user/myhsueh/dataset/tfrecord/no_rir/KB2/validation'] , file directory exist: True
96
TD_87-4-of-4.tfrecord, TD_87-3-of-4.tfrecord, TD_89-1-of-4.tfrecord, TD_1-2-of-4.tfrecord
TD_10-1-of-4.tfrecord, TD_88-3-of-4.tfrecord, TD_15-4-of-4.tfrecord, TD_4-3-of-4.tfrecord
TD_106-3-of-4.tfrecord, TD_106-2-of-4.tfrecord, TD_5-1-of-4.tfrecord, TD_14-2-of-4.tfrecord
TD_12-2-of-4.tfrecord, TD_0-3-of-4.tfrecord, TD_2-2-of-4.tfrecord, TD_1-3-of-4.tfrecord
TD_7-4-of-4.tfrecord, TD_107-2-of-4.tfrecord, TD_86-2-of-4.tfrecord, TD_3-2-of-4.tfrecord
TD_8-4-of-4.tfrecord, TD_8-2-of-4.tfrecord, TD_14-1-of-4.tfrecord, TD_6-3-of-4.tfrecord
TD_16-4-of-4.tfrecord, TD_86-4-of-4.tfrecord, TD_11-3-of-4.tfrecord, TD_108-2-of-4.tfrecord
TD_106-1-of-4.tfrecord, TD_16-1-of-4.tfreco

done erb
done spec
Tensor("Reshape_6:0", shape=(300, 481, 2), dtype=float32)
ERB shape: (300, 32, 1) Tensor("Squeeze_300:0", shape=(300, 32, 1), dtype=float32)
NOISY_SPEC_df shape: (300, 96, 2) Tensor("Reshape_4:0", shape=(300, 96, 2), dtype=float32)
CLEAN_SPEC shape: (300, 481, 2) Tensor("Reshape_6:0", shape=(300, 481, 2), dtype=float32)
NOISY_SPEC shape: (300, 481, 2) Tensor("Reshape_5:0", shape=(300, 481, 2), dtype=float32)


In [11]:
def _count_entries(dirs):
    """Return the total number of directory entries across ``dirs``."""
    return sum(len(os.listdir(d)) for d in dirs)


# NOTE(review): examples per TFRecord file. 6000 // 4 presumably means a source
# file split into 4 shards (file names end in `-N-of-4`), cut into clips of
# p.length_sec; confirm against the TFRecord writer.
each_tfr_batch = 6000 // 4 // p.length_sec

# Without combining, only the first directory of each split is counted.
# NOTE(review): os.listdir counts every entry, not only *.tfrecord files.
train_count = _count_entries(training_set_path if combine_train else [training_set_path[0]])
val_count = _count_entries(validation_set_path if combine_val else [validation_set_path[0]])

steps_per_epoch = train_count * each_tfr_batch // p.batch_size
validation_steps = val_count * each_tfr_batch // p.batch_size
print('training step: {} , validation step: {}'.format(steps_per_epoch, validation_steps))

training step: 13312 , validation step: 2937


# Model construction

In [12]:
# Symbolic model inputs. Axis 1 (time) is None, so sequences of any length can
# be fed. A trailing axis of size 2 carries a complex value as two real channels.
erb_inputs = Input(name='ERB_input', shape=(None, p.nb_erb, 1))
spec_inputs = Input(name='spec_input', shape=(None, p.nb_df, 2))
clean_spec = Input(name='clean_spec', shape=(None, p.fft_size // 2 + 1, 2))
noisy_spec = Input(name='noisy_spec', shape=(None, p.fft_size // 2 + 1, 2))

In [13]:
# ERB encoder: four conv blocks; conv1/conv2 use fstride=2 and the others fstride=1.
# NOTE: layer creation order determines auto-generated layer names (and thus
# saved-weight names), so avoid reordering the graph-construction cells.
erb_conv0 = convkxf(erb_inputs, p.conv_out_ch, k=3, f=3, fstride=1, bias=False, name='conv0_encoder')
erb_conv1 = convkxf(erb_conv0, p.conv_out_ch, k=1, f=3, fstride=2, bias=False, name='conv1_encoder')
erb_conv2 = convkxf(erb_conv1, p.conv_out_ch, k=1, f=3, fstride=2, bias=False, name='conv2_encoder')
erb_conv3 = convkxf(erb_conv2, p.conv_out_ch, k=1, f=3, fstride=1, bias=False, name='conv3_encoder')

if not p.mask_only: 
    # Deep-filter (DF) encoder over the complex spectrum input.
    df_conv0 = convkxf(spec_inputs, p.conv_out_ch, k=3, f=3, fstride=1, name='df_conv0_encoder')
    df_conv1 = convkxf(df_conv0, p.conv_out_ch, k=1, f=3, fstride=2, name='df_conv1_encoder')
    
#     shape = df_conv1.get_shape().as_list()
    # Dynamic shape: batch and time are unknown when the graph is built.
    shape = [tf.shape(df_conv1)[l] for l in range(4)]
    # Swap axes 2 and 3, then flatten them into one feature axis -> (B, T, features).
    df_conv1 = tf.transpose(df_conv1,(0,1,3,2))
    df_conv1 = tf.reshape(df_conv1, [shape[0], shape[1], shape[2]*shape[3]])
    
    # Grouped fully-connected projection of the DF features.
    cemb = GroupFC(df_conv1, p.fc_hidden, p.fc_group, name='GFC_encoder')
    
#shape = erb_conv3.get_shape().as_list()    
# Flatten erb_conv3 the same way (swap axes 2/3, merge them) and, when the DF
# branch exists, concatenate it with the DF embedding along the feature axis.
shape = [tf.shape(erb_conv3)[l] for l in range(4)]
if p.mask_only: 
    emb = tf.reshape(tf.transpose(erb_conv3,(0,1,3,2)), [shape[0], shape[1], shape[2]*shape[3]])
else:           
    emb = Concatenate()([cemb, tf.reshape(tf.transpose(erb_conv3,(0,1,3,2)), [shape[0], shape[1], shape[2]*shape[3]])])

# First grouped-GRU stack over the per-frame embedding.
GRU_emb = GroupGRULayer(emb, p.gru_hidden, p.gru_group, num_layer=1, name='GGRU0', add_output=True)

In [14]:
# Second grouped-GRU stack (num_layer=2) on top of GRU_emb, then a grouped FC
# layer with ReLU whose output is unflattened below for the ERB decoder.
GRU_emb2 = GroupGRULayer(GRU_emb, p.gru_hidden, p.gru_group, num_layer=2, name='GGRU01', add_output=True)
de_emb = GroupFC(GRU_emb2, p.fc_hidden, p.fc_group, activation='relu', name='GFC_decoder')

# Reshape the decoder embedding to erb_conv3's dynamic 4-D shape so it can be
# added to the level-3 pathway conv below.
# NOTE(review): the encoder flattened erb_conv3 *after* transposing axes 2/3,
# but this reshape does not transpose back, so the feature layout differs from
# the encoder's. This is harmless if intentional (de_emb is a learned
# projection), but confirm it, especially when porting weights.
shape = [tf.shape(erb_conv3)[l] for l in range(4)]
# shape = erb_conv3.get_shape().as_list()
emb_decoder = tf.reshape(de_emb, shape = [shape[0], shape[1], shape[2], shape[3]], name='decoder_reshape')

# Shared keyword sets for the decoder convs:
#   kwargs  - k=1 conv with batch norm
#   tkwargs - same, in "transposed" mode (levels whose encoder conv had fstride=2)
#   pkwargs - k=1, f=1 "pathway" conv with batch norm, applied to encoder outputs
kwargs = {
    "k": 1,
    "batch_norm": True,
}
tkwargs = {
    "k": 1,
    "batch_norm": True,
    "mode": "transposed"
}
pkwargs = {
    "k": 1,
    "f": 1,
    "batch_norm": True,
}

# Each level adds a pathway conv of the matching encoder output to the output of
# the level below, then convolves the sum.
convp3 = convkxf(erb_conv3, out_ch=p.conv_out_ch, name='convp3', **pkwargs)  # Conv
vt3_in = convp3 + emb_decoder
convt3 = convkxf(vt3_in, out_ch=p.conv_out_ch, fstride=1, name='convt3', **kwargs) #ConvT

convp2 = convkxf(erb_conv2, out_ch=p.conv_out_ch, name='convp2', **pkwargs) # Conv
vt2_in = convp2 + convt3
convt2 = convkxf(vt2_in, out_ch=p.conv_out_ch, name='convt2', **tkwargs)

convp1 = convkxf(erb_conv1, out_ch=p.conv_out_ch, name='convp1', **pkwargs) # Conv
vt1_in = convp1 + convt2
convt1 = convkxf(vt1_in, out_ch=p.conv_out_ch, name='convt1', **tkwargs)

# Output level: 1-channel sigmoid output `mask_out`, the predicted gain mask.
# NOTE(review): only convp0 and mask_out pass training=True. If convkxf forwards
# this to batch normalization, these layers would use batch statistics even at
# inference. Confirm this is intended and consistent with the other levels.
convp0 = convkxf(erb_conv0, out_ch=p.conv_out_ch, training=True, name='convp0', **pkwargs) # Conv
vt0_in = convp0 + convt1
mask_out = convkxf(vt0_in, out_ch=1, k=1, fstride=1, act='sigmoid', training=True, reshape=True, name='mask_out')

In [15]:
if not p.mask_only:
    # Deep-filter (DF) decoder. This rebinds GRU_emb2 to a separate GRU stack
    # ('GGRU1') fed from GRU_emb; the mask decoder above already consumed the
    # earlier GRU_emb2 ('GGRU01').
    GRU_emb2 = GroupGRULayer(GRU_emb, p.gru_hidden, p.gru_group, num_layer=2, name='GGRU1', add_output=True)

    # Pathway conv on the DF encoder features with 2*df_order output channels,
    # then swap the last two axes so channels come before frequency.
    convp = convkxf(df_conv0, 2*p.df_order, k=1, f=1, complex_in=True, batch_norm=True, name='convp_DfDecoder') 
    convp = tf.transpose(convp, (0,1,3,2))

    # df_alpha: one sigmoid value in (0, 1) per frame.
    # c: tanh-bounded DF coefficients, nb_df * df_order * 2 values per frame.
    df_alpha = sigmoid(Dense(1, name='convp_alpha')(GRU_emb2))
    c = tanh(Dense(p.nb_df*p.df_order*2, name='convp_c')(GRU_emb2))

    # Dynamic (batch, time) sizes.
    shape = [tf.shape(c)[k] for k in range(2)]
#     shape = c.get_shape().as_list()
    # (B, T, 2*df_order, nb_df), to line up with the transposed pathway conv
    # (assumes df_conv0 keeps nb_df frequency bins — TODO confirm).
    c = tf.reshape(c,(shape[0], shape[1], p.df_order*2, p.nb_df))
        
    # Add the pathway term, then split the 2*df_order axis into (df_order, 2).
    c = tf.reshape(c + convp,(shape[0], shape[1], p.df_order, 2, p.nb_df))

    # Final layout (B, T, df_order, nb_df, 2); the last axis presumably holds
    # the (real, imag) pair.
    df_coeff = tf.transpose(c, [0, 1, 2, 4, 3], name='output_c')

# Enhancement: mask gain followed by deep filtering

In [16]:
# Stage 1: apply the predicted gain mask to the noisy spectrum.
spec = mask_operations(noisy_spec, mask_out) # mask gain
# Stage 2: deep-filter the masked spectrum with the predicted coefficients and
# mixing factor. In mask-only mode the DF branch is never built (df_coeff and
# df_alpha do not exist), so the masked spectrum is the enhanced output;
# previously this line raised NameError in that mode.
if p.mask_only:
    enhanced = spec
else:
    enhanced = df_operations(spec, df_coeff, df_alpha) # deep filter

# Callbacks

In [19]:
# Abort training as soon as a NaN loss shows up.
callbacks_nan = TerminateOnNaN()
# Stop after 5 epochs without improvement (lower is better) and roll back to the
# best epoch's weights. The monitored quantity is left at EarlyStopping's
# default ('val_loss').
callbacks_earlystop = EarlyStopping(
    mode='min',
    patience=5,
    restore_best_weights=True,
)
# callbacks_csvlog = CSVLogger('train.log')

In [20]:
class ModelCheckpoint(Callback):
    """Save ``model``'s weights whenever the validation loss reaches a new minimum.

    NOTE: this class shadows ``tensorflow.keras.callbacks.ModelCheckpoint``
    imported at the top of the notebook.

    Args:
        model: Model whose ``save_weights`` is called.
        path: Target file path/prefix for ``save_weights``. It is used as a
            ``str.format`` template with ``epoch`` (the index Keras passes to
            ``on_epoch_end``) and ``loss`` (the validation loss).
    """

    def __init__(self, model, path):
        super().__init__()

        self.model = model
        self.path = path
        # Create the checkpoint directory up front when it is a fixed path (no
        # format placeholders). Otherwise it is created at save time from the
        # formatted target.
        parent = os.path.dirname(path)
        if parent and '{' not in parent:
            os.makedirs(parent, exist_ok=True)
        self.best_loss = np.inf

    def on_epoch_end(self, epoch, logs=None):
        """Save weights when ``logs['val_loss']`` improves on the best value so far."""
        loss = (logs or {}).get('val_loss')
        if loss is None:
            # No validation result this epoch (no validation data, or skipped
            # by validation_freq): nothing to compare against.
            return
        if loss < self.best_loss:
            target = self.path.format(epoch=epoch, loss=loss)
            # Ensure the target's parent directory exists. `path` itself must not
            # be created as a directory, because it would collide with the file
            # save_weights writes whenever the template has no placeholders.
            target_dir = os.path.dirname(target)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            print('Saving model to {}'.format(target))
            print(', Validation loss decreased from {} to {}.\n'.format(self.best_loss, loss))
            self.model.save_weights(target, overwrite=True)
            self.best_loss = loss

In [21]:
import resource
from tensorflow.compat.v1 import set_random_seed
class ClearMemory(Callback):
    """Keras callback that forces Python garbage collection during training.

    A full ``gc.collect()`` runs at the end of every epoch, followed by a
    printout of the process's peak resident set size. It also runs after every
    batch whose optimizer step count is a multiple of 100.
    """

    def on_epoch_end(self, epoch, logs=None):
        gc.collect()
        # ru_maxrss is the *maximum* resident set size reached so far, not the
        # current usage (units per getrusage(2): kilobytes on Linux).
        usage = resource.getrusage(resource.RUSAGE_SELF)
        print('Memory usage: ', usage.ru_maxrss)

    def on_batch_end(self, batch, logs=None):
        step = int(K.get_value(self.model.optimizer.iterations))
        if step % 100:
            return
        gc.collect()

In [22]:
class BatchCallback(Callback):
    """Write per-batch and per-epoch loss summaries for TensorBoard.

    Per batch, the scalars 'loss', 'maskloss', 'dfalphaloss' and 'spectralloss'
    found in the Keras logs are written to ``<logdir>/training``, indexed by the
    optimizer's global step. Keys that are absent are simply not written.

    Per epoch, the training 'loss' goes to ``<logdir>/training`` and 'val_loss'
    goes to ``<logdir>/validation``. Both use the tag 'loss', so TensorBoard
    overlays the two curves.

    Args:
        callbacks_tensorboard: A ``TensorBoard`` callback; stored as
            ``self.tb_callback`` only.
        logdir: Root directory for the summary writers.
    """

    def __init__(self, callbacks_tensorboard, logdir):
        super().__init__()
        self.tb_callback = callbacks_tensorboard
        self.logdir = logdir
        self.train_writer = tf.summary.create_file_writer(os.path.join(self.logdir, "training"))
        self.val_writer = tf.summary.create_file_writer(os.path.join(self.logdir, "validation"))

        self.loss_tag = 'Epoch Summary/'

    def on_train_begin(self, logs=None):
        global_step = int(K.get_value(self.model.optimizer.iterations))
        print('[INFO][BatchCallback] global step:{}'.format(global_step))

    def on_train_batch_end(self, batch, logs=None):
        # Work on a copy: Keras may pass the same logs dict on to other
        # callbacks, and the extra 'global_step' key must not leak into them.
        record = dict(logs or {})
        record['global_step'] = K.get_value(self.model.optimizer.iterations)
        self._write_log(record)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        with self.train_writer.as_default():
            if 'loss' in logs:
                tf.summary.scalar('loss', logs['loss'], step=epoch)
        self.train_writer.flush()

        with self.val_writer.as_default():
            # Only the validation loss belongs in the validation run. Writing the
            # training loss here too put two different values at the same step.
            if 'val_loss' in logs:
                tf.summary.scalar('loss', logs['val_loss'], step=epoch)
        self.val_writer.flush()

    def _write_log(self, logs):
        """Write metric and total-loss scalars from ``logs`` at ``logs['global_step']``."""
        step = logs['global_step']
        with self.train_writer.as_default():
            for key, val in logs.items():
                if key in ('maskloss', 'dfalphaloss', 'spectralloss'):
                    tf.summary.scalar('Train(Metric)/' + key, val, step=step)
                elif key == 'loss':
                    tf.summary.scalar('Train(Total)/' + key.upper(), val, step=step)
        self.train_writer.flush()

In [23]:
# Tensorboard: the stock Keras callback plus the custom BatchCallback, whose
# writers log under <logdir>/training and <logdir>/validation.
logdir = "./logs_2"
callbacks_tensorboard = TensorBoard(log_dir=logdir)
tb_callback = BatchCallback(callbacks_tensorboard, logdir=logdir)

# Memory clear: periodic garbage collection during training.
Memory_callback = ClearMemory()

# Loss

In [24]:
# Mask loss as a graph tensor: a Lambda layer named 'maskloss' applies MaskLoss to
# (predicted mask, clean spectrum, noisy spectrum) with factor=p.mask_factor and
# r=p.mask_gamma. The tensor is used as a model output in the compile cell.
maskloss = Lambda(lambda x:MaskLoss(*x, factor=p.mask_factor, r=p.mask_gamma), name='maskloss')([mask_out, clean_spec, noisy_spec])

In [25]:
# Local-SNR target computed from the clean spectrum and the noise
# (noisy - clean), with max_bin=p.nb_df. The target range is
# [lsnr_min, lsnr_max] widened by 5 on each side.
# NOTE(review): lsnr_gt is only consumed by the commented-out DfAlphaLoss in the
# next cell, so this currently adds unused ops to the graph.
lsnr_min, lsnr_max = -15, 35
lsnr = LocalSnrTarget(ws=20, target_snr_range=[lsnr_min - 5, lsnr_max + 5])
lsnr_gt = lsnr.forward(clean_spec, noise= noisy_spec - clean_spec, max_bin=p.nb_df)

In [26]:
# DfAlphaLoss is currently disabled:
# if not p.mask_only: 
#     dfalphaloss = Lambda(lambda x:DfAlphaLoss(*x, factor=p.alpha_factor), name='dfalphaloss')([df_alpha, lsnr_gt])
# Spectral loss between the enhanced and clean spectra, with gamma=p.df_gamma and
# both factor_mag and factor_img set to p.df_factor.
spectralloss = Lambda(lambda x:SpectralLoss(*x, gamma=p.df_gamma, factor_mag=p.df_factor, factor_img=p.df_factor), name='spectralloss')([enhanced, clean_spec])

# Model compile

In [28]:
# Model inputs: the deep-filter branch additionally consumes spec_inputs.
if p.mask_only:
    inputs = [erb_inputs, clean_spec, noisy_spec]
else:
    inputs = [erb_inputs, spec_inputs, clean_spec, noisy_spec]

# Model outputs are the loss tensors.
# FIXME(review): `sisnr` is not defined in any visible cell (execution jumps from
# In [26] to In [28]). With p.sisnr enabled this raises NameError on a fresh kernel.
if p.mask_only:
    outputs = [sisnr, maskloss] if p.sisnr else [spectralloss, maskloss]
else:
    outputs = [sisnr, maskloss, spectralloss] if p.sisnr else [maskloss, spectralloss]
#     outputs = [sisnr, maskloss, dfalphaloss, spectralloss]
#     outputs = [maskloss, dfalphaloss, spectralloss]

In [29]:
# Functional model from the selected inputs to the loss-tensor outputs.
model=Model(inputs=inputs, outputs=outputs, name='DfNet')

In [30]:
# Layer table (the saved output below is truncated).
model.summary()

Model: "DfNet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 ERB_input (InputLayer)         [(None, None, 32, 1  0           []                               
                                )]                                                                
                                                                                                  
 tf.compat.v1.pad (TFOpLambda)  (None, None, 32, 1)  0           ['ERB_input[0][0]']              
                                                                                                  
 spec_input (InputLayer)        [(None, None, 96, 2  0           []                               
                                )]                                                                
                                                                                              

                                                                                                  
 tf.math.less_2 (TFOpLambda)    (None, None, 8, 64)  0           ['conv2_encoder_1x1[0][0]']      
                                                                                                  
 tf.compat.v1.shape (TFOpLambda  (4,)                0           ['tf.where_5[0][0]']             
 )                                                                                                
                                                                                                  
 tf.compat.v1.shape_1 (TFOpLamb  (4,)                0           ['tf.where_5[0][0]']             
 da)                                                                                              
                                                                                                  
 tf.__operators__.getitem_2 (Sl  ()                  0           ['tf.compat.v1.shape_2[0][0]']   
 icingOpLa

                                                                                                  
 tf.compat.v1.transpose_1 (TFOp  (None, None, 8, 64)  0          ['tf.stack[0][0]']               
 Lambda)                                                                                          
                                                                                                  
 tf.compat.v1.shape_7 (TFOpLamb  (4,)                0           ['tf.compat.v1.transpose_1[0][0]'
 da)                                                             ]                                
                                                                                                  
 tf.compat.v1.shape_6 (TFOpLamb  (4,)                0           ['tf.compat.v1.transpose_1[0][0]'
 da)                                                             ]                                
                                                                                                  
 tf.compat

                                ]                                                                 
                                                                                                  
 GGRU00_0 (GRU)                 (None, None, 256)    984576      ['tf.split_1[0][0]']             
                                                                                                  
 tf.stack_1 (TFOpLambda)        (None, None, 256, 1  0           ['GGRU00_0[0][0]']               
                                )                                                                 
                                                                                                  
 tf.compat.v1.shape_15 (TFOpLam  (4,)                0           ['tf.stack_1[0][0]']             
 bda)                                                                                             
                                                                                                  
 tf.compat

                                                                 , 'tf.__operators__.getitem_16[0]
                                                                 [0]',                            
                                                                  'tf.__operators__.getitem_17[0][
                                                                 0]',                             
                                                                  'tf.math.multiply_4[0][0]']     
                                                                                                  
 tf.split_3 (TFOpLambda)        [(None, None, 256)]  0           ['tf.reshape_4[0][0]']           
                                                                                                  
 GGRU011_0 (GRU)                (None, None, 256)    394752      ['tf.split_3[0][0]']             
                                                                                                  
 tf.stack_

                                                                  'GFC_decoder_6[0][0]',          
                                                                  'GFC_decoder_7[0][0]']          
                                                                                                  
 tf.compat.v1.transpose_4 (TFOp  (None, None, 8, 64)  0          ['tf.stack_4[0][0]']             
 Lambda)                                                                                          
                                                                                                  
 tf.compat.v1.shape_27 (TFOpLam  (4,)                0           ['tf.compat.v1.transpose_4[0][0]'
 bda)                                                            ]                                
                                                                                                  
 tf.compat.v1.shape_26 (TFOpLam  (4,)                0           ['tf.compat.v1.transpose_4[0][0]'
 bda)     

 tf.reshape_7 (TFOpLambda)      (None, None, 8, 64)  0           ['tf.where_6[0][0]',             
                                                                  'tf.__operators__.getitem_28[0][
                                                                 0]',                             
                                                                  'tf.__operators__.getitem_29[0][
                                                                 0]',                             
                                                                  'tf.__operators__.getitem_30[0][
                                                                 0]',                             
                                                                  'tf.__operators__.getitem_31[0][
                                                                 0]']                             
                                                                                                  
 tf.__oper

                                , (None, None, 8, 1                                               
                                ),                                                                
                                 (None, None, 8, 1)                                               
                                , (None, None, 8, 1                                               
                                ),                                                                
                                 (None, None, 8, 1)                                               
                                , (None, None, 8, 1                                               
                                ),                                                                
                                 (None, None, 8, 1)                                               
                                , (None, None, 8, 1                                               
          

                                                                                                  
 conv2d_transpose_9 (Conv2DTran  (None, None, 16, 1)  4          ['tf.split_5[0][9]']             
 spose)                                                                                           
                                                                                                  
 conv2d_transpose_10 (Conv2DTra  (None, None, 16, 1)  4          ['tf.split_5[0][10]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_11 (Conv2DTra  (None, None, 16, 1)  4          ['tf.split_5[0][11]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_tr

 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_37 (Conv2DTra  (None, None, 16, 1)  4          ['tf.split_5[0][37]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_38 (Conv2DTra  (None, None, 16, 1)  4          ['tf.split_5[0][38]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_39 (Conv2DTra  (None, None, 16, 1)  4          ['tf.split_5[0][39]']            
 nspose)                                                                                          
          

 convp1 (Conv2D)                (None, None, 16, 64  128         ['tf.where_1[0][0]']             
                                )                                                                 
                                                                                                  
 tf.concat (TFOpLambda)         (None, None, 16, 64  0           ['conv2d_transpose[0][0]',       
                                )                                 'conv2d_transpose_1[0][0]',     
                                                                  'conv2d_transpose_2[0][0]',     
                                                                  'conv2d_transpose_3[0][0]',     
                                                                  'conv2d_transpose_4[0][0]',     
                                                                  'conv2d_transpose_5[0][0]',     
                                                                  'conv2d_transpose_6[0][0]',     
          

 tf.math.less_10 (TFOpLambda)   (None, None, 16, 64  0           ['batch_normalization_3[0][0]']  
                                )                                                                 
                                                                                                  
 tf.where_11 (TFOpLambda)       (None, None, 16, 64  0           ['tf.math.less_11[0][0]',        
                                )                                 'batch_normalization_4[0][0]']  
                                                                                                  
 tf.where_10 (TFOpLambda)       (None, None, 16, 64  0           ['tf.math.less_10[0][0]',        
                                )                                 'batch_normalization_3[0][0]']  
                                                                                                  
 tf.split_7 (TFOpLambda)        [(None, None, 256)]  0           ['tf.reshape_3[0][0]']           
          

                                ),                                                                
                                 (None, None, 16, 1                                               
                                ),                                                                
                                 (None, None, 16, 1                                               
                                ),                                                                
                                 (None, None, 16, 1                                               
                                ),                                                                
                                 (None, None, 16, 1                                               
                                ),                                                                
                                 (None, None, 16, 1                                               
          

 conv2d_transpose_70 (Conv2DTra  (None, None, 32, 1)  4          ['tf.split_6[0][6]']             
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_71 (Conv2DTra  (None, None, 32, 1)  4          ['tf.split_6[0][7]']             
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_72 (Conv2DTra  (None, None, 32, 1)  4          ['tf.split_6[0][8]']             
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_73 (Conv2DTra  (None, None, 32, 1)  4          ['tf.split_6[0][9]']             
 nspose)  

                                                                                                  
 conv2d_transpose_98 (Conv2DTra  (None, None, 32, 1)  4          ['tf.split_6[0][34]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_99 (Conv2DTra  (None, None, 32, 1)  4          ['tf.split_6[0][35]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_100 (Conv2DTr  (None, None, 32, 1)  4          ['tf.split_6[0][36]']            
 anspose)                                                                                         
                                                                                                  
 conv2d_tr

 anspose)                                                                                         
                                                                                                  
 conv2d_transpose_126 (Conv2DTr  (None, None, 32, 1)  4          ['tf.split_6[0][62]']            
 anspose)                                                                                         
                                                                                                  
 conv2d_transpose_127 (Conv2DTr  (None, None, 32, 1)  4          ['tf.split_6[0][63]']            
 anspose)                                                                                         
                                                                                                  
 tf.compat.v1.transpose_5 (TFOp  (None, None, 1, 256  0          ['tf.stack_5[0][0]']             
 Lambda)                        )                                                                 
          

 bda)                                                            ]                                
                                                                                                  
 convp0_1x1 (Conv2D)            (None, None, 32, 64  4096        ['convp0[0][0]']                 
                                )                                                                 
                                                                                                  
 convt1_1x1 (Conv2D)            (None, None, 32, 64  4096        ['tf.concat_1[0][0]']            
                                )                                                                 
                                                                                                  
 tf.compat.v1.shape_32 (TFOpLam  (4,)                0           ['tf.compat.v1.transpose_5[0][0]'
 bda)                                                            ]                                
          

 bda)                                                                                             
                                                                                                  
 tf.__operators__.getitem_39 (S  ()                  0           ['tf.compat.v1.shape_39[0][0]']  
 licingOpLambda)                                                                                  
                                                                                                  
 tf.__operators__.getitem_38 (S  ()                  0           ['tf.compat.v1.shape_38[0][0]']  
 licingOpLambda)                                                                                  
                                                                                                  
 tf.compat.v1.squeeze (TFOpLamb  (None, None, 32)    0           ['tf.math.sigmoid[0][0]']        
 da)                                                                                              
          

 Lambda)                        5)                                                                
                                                                                                  
 tf.image.extract_patches_1 (TF  (None, None, None,   0          ['tf.split_9[0][1]']             
 OpLambda)                      5)                                                                
                                                                                                  
 tf.reshape_10 (TFOpLambda)     (None, None, 10, 96  0           ['tf.math.tanh[0][0]',           
                                )                                 'tf.__operators__.getitem_40[0][
                                                                 0]',                             
                                                                  'tf.__operators__.getitem_41[0][
                                                                 0]']                             
          

 tf.math.subtract (TFOpLambda)  (None, None, 5, 96)  0           ['tf.math.multiply_10[0][0]',    
                                                                  'tf.math.multiply_11[0][0]']    
                                                                                                  
 tf.__operators__.add_7 (TFOpLa  (None, None, 5, 96)  0          ['tf.math.multiply_12[0][0]',    
 mbda)                                                            'tf.math.multiply_13[0][0]']    
                                                                                                  
 convp_alpha (Dense)            (None, None, 1)      257         ['tf.__operators__.add_5[0][0]'] 
                                                                                                  
 tf.expand_dims_3 (TFOpLambda)  (None, None, 5, 96,  0           ['tf.math.subtract[0][0]']       
                                 1)                                                               
          

In [31]:
def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0,
                             decay_steps=None,
                             decay_rate=0.95):
    """Learning-rate schedule: linear warm-up, optional hold, then exponential decay.

    NOTE: despite its name, this function does not apply cosine annealing
    (Loshchilov & Hutter, SGDR, ICLR 2017) any more. After warm-up and the
    optional hold period the rate decays exponentially:

        lr = learning_rate_base * decay_rate ** (
                 (global_step - warmup_steps - hold_base_rate_steps) / decay_steps)

    The name is kept so existing callers keep working.

    Arguments:
        global_step {int} -- current global step.
        learning_rate_base {float} -- learning rate reached at the end of warm-up.
        total_steps {int} -- total number of training steps; after it the rate is 0.0.
    Keyword Arguments:
        warmup_learning_rate {float} -- learning rate at step 0 of warm-up. (default: {0.0})
        warmup_steps {int} -- number of linear warm-up steps. (default: {0})
        hold_base_rate_steps {int} -- steps to hold learning_rate_base after warm-up
                                      before decaying. (default: {0})
        decay_steps {float} -- number of steps over which the rate is multiplied by
                               decay_rate once. (default: {None}, meaning
                               ``steps_per_epoch / 25`` read from the notebook's
                               global namespace -- the original behaviour)
        decay_rate {float} -- multiplicative decay factor per decay_steps. (default: {0.95})
    Returns:
      a 0-d numpy array holding the learning rate.
    Raises:
      ValueError: if warmup_steps is larger than total_steps, if decay_steps is
        not positive, or if warmup_steps > 0 and warmup_learning_rate is larger
        than learning_rate_base.
    """
    if total_steps < warmup_steps:
        raise ValueError('total_steps must be larger or equal to '
                         'warmup_steps.')
    if decay_steps is None:
        # Backward-compatible default: the original implementation read this
        # notebook-level global directly.
        decay_steps = steps_per_epoch / 25
    if decay_steps <= 0:
        raise ValueError('decay_steps must be positive.')

    # Steps elapsed since the end of warm-up + hold (negative before that point;
    # those values are masked out by the np.where calls below).
    decay_progress = (global_step - warmup_steps - hold_base_rate_steps) / decay_steps
    learning_rate = learning_rate_base * decay_rate ** decay_progress

    if hold_base_rate_steps > 0:
        learning_rate = np.where(global_step > warmup_steps + hold_base_rate_steps,
                                 learning_rate, learning_rate_base)
    if warmup_steps > 0:
        if learning_rate_base < warmup_learning_rate:
            raise ValueError('learning_rate_base must be larger or equal to '
                             'warmup_learning_rate.')
        # Linear ramp from warmup_learning_rate (step 0) to learning_rate_base.
        slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
        warmup_rate = slope * global_step + warmup_learning_rate
        learning_rate = np.where(global_step < warmup_steps, warmup_rate,
                                 learning_rate)
    return np.where(global_step > total_steps, 0.0, learning_rate)


class WarmUpCosineDecayScheduler(keras.callbacks.Callback):
    """Keras callback that sets the optimizer learning rate before each training batch.

    The rate for the current step comes from ``cosine_decay_with_warmup``
    (linear warm-up, optional hold, then decay -- exponential in the current
    implementation, despite the class name) and is clamped from below at
    ``mini_lr``. The step counter starts at ``global_step_init`` and is advanced
    after every training batch, so a resumed run can continue the schedule.
    """

    def __init__(self,
                 learning_rate_base,
                 total_steps,
                 global_step_init=0,
                 warmup_learning_rate=0.0,
                 warmup_steps=0,
                 hold_base_rate_steps=0,
                 mini_lr=0,
                 verbose=0):
        """Constructor for the warm-up + decay learning rate scheduler.

        Arguments:
            learning_rate_base {float} -- learning rate reached at the end of warm-up.
            total_steps {int} -- total number of training steps.
        Keyword Arguments:
            global_step_init {int} -- initial global step, e.g. from a previous checkpoint. (default: {0})
            warmup_learning_rate {float} -- learning rate at the start of warm-up. (default: {0.0})
            warmup_steps {int} -- number of warm-up steps. (default: {0})
            hold_base_rate_steps {int} -- steps to hold learning_rate_base before
                                          decaying. (default: {0})
            mini_lr {float} -- lower bound applied to the scheduled rate; after
                               total_steps (where the schedule returns 0.0) the
                               rate therefore becomes mini_lr. (default: {0})
            verbose {int} -- 0: quiet, 1: print the rate set at every batch. (default: {0})
        """
        super().__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.global_step = global_step_init
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.hold_base_rate_steps = hold_base_rate_steps
        self.mini_lr = mini_lr
        self.verbose = verbose

    def on_train_batch_begin(self, batch, logs=None):
        """Compute the scheduled rate for the current step and write it to the optimizer."""
        lr = cosine_decay_with_warmup(global_step=self.global_step,
                                      learning_rate_base=self.learning_rate_base,
                                      total_steps=self.total_steps,
                                      warmup_learning_rate=self.warmup_learning_rate,
                                      warmup_steps=self.warmup_steps,
                                      hold_base_rate_steps=self.hold_base_rate_steps)
        # The schedule returns a 0-d numpy array; convert it and apply the floor.
        lr = max(float(lr), self.mini_lr)
        # `learning_rate` is the canonical optimizer attribute; `lr` is a deprecated alias.
        K.set_value(self.model.optimizer.learning_rate, lr)
        if self.verbose > 0:
            print('\nBatch %05d: setting learning '
                  'rate to %s.' % (self.global_step + 1, lr))

    def on_train_batch_end(self, batch, logs=None):
        """Advance the global step counter after each training batch."""
        self.global_step += 1

In [32]:
# Learning-rate scheduler: linear warm-up over the first epoch (from 0.0 up to
# p.lr), then decay for the rest of training, never going below 1e-6.
warm_up_lr = WarmUpCosineDecayScheduler(
    p.lr,                          # learning_rate_base
    p.epochs * steps_per_epoch,    # total_steps
    warmup_steps=steps_per_epoch,
    warmup_learning_rate=0.0,
    hold_base_rate_steps=0,
    mini_lr=1e-6,
)

In [33]:
# Checkpoint: write weights to ./weights2_2/ with epoch and loss in the filename.
# NOTE(review): tf.keras' ModelCheckpoint signature is (filepath, monitor=...);
# passing `model` first only works if ModelCheckpoint was redefined earlier in
# this notebook with a (model, filepath) signature -- confirm.
cp_callback = ModelCheckpoint(model, './weights2_2/weights.{epoch:02d}-{loss:.2f}.h5')

callbacks_overall = [
    warm_up_lr,
    tb_callback,
    cp_callback,
    callbacks_nan,
    callbacks_earlystop,
    Memory_callback,
]

In [34]:
# Training objective: SI-SNR when p.sisnr is set, otherwise the spectral loss.
LOSS = sisnr if p.sisnr else spectralloss

# The mask loss is always reported as a metric.
model.add_metric(maskloss, name="maskloss")
# The spectral loss is reported as well, except in the mask-only SI-SNR setup
# (the only configuration in which the original four-way branch omitted it).
if not (p.mask_only and p.sisnr):
    model.add_metric(spectralloss, name="spectralloss")

model.add_loss(LOSS)

In [35]:
# TensorFlow Addons AdamW with weight_decay=0.0 (no decoupled weight decay).
# The learning rate given here is only the initial value: the warm_up_lr
# callback overwrites it before every training batch.
optimizer = AdamW(weight_decay=0.0, learning_rate=p.lr)

In [36]:
# Loss and metrics were attached via add_loss/add_metric, so only the optimizer
# is passed here.
# NOTE(review): run_eagerly=True disables graph compilation of the train step
# (slower); confirm it is still required.
model.compile(
    optimizer=optimizer,
    run_eagerly=True,
)

# Load pre-trained weights

In [37]:
# Locate the most recently modified checkpoint file from the previous training
# stage (directory ./weights2_1). Subdirectories are ignored.
pretrain_dir = "./weights2_1"
h5_lists = sorted(
    (fn for fn in os.listdir(pretrain_dir)
     if os.path.isfile(os.path.join(pretrain_dir, fn))),
    key=lambda fn: os.path.getmtime(os.path.join(pretrain_dir, fn)),
)
if not h5_lists:
    raise FileNotFoundError(f"No checkpoint files found in {pretrain_dir}")
new_file = os.path.join(pretrain_dir, h5_lists[-1])
print(new_file)

./weights2_1/weights.07-56.01.h5


In [38]:
# store weights before loading pre-trained weights
preloaded_layers = model.layers.copy()
preloaded_weights = []
for pre in preloaded_layers:
    preloaded_weights.append(pre.get_weights())

# load pre-trained weights
model.load_weights(new_file, by_name=True)

# compare previews weights vs loaded weights
for layer, pre in zip(model.layers, preloaded_weights):
    weights = layer.get_weights()

    if weights:
        if np.array_equal(weights, pre):
            print('not loaded', layer.name)
        else:
            print('loaded', layer.name)
            pass

loaded conv0_encoder
loaded df_conv0_encoder
loaded df_conv0_encoder_1x1
loaded conv1_encoder
loaded conv1_encoder_1x1
loaded df_conv1_encoder
loaded df_conv1_encoder_1x1
loaded conv2_encoder
loaded conv2_encoder_1x1
loaded conv3_encoder
loaded conv3_encoder_1x1
loaded GFC_encoder_0
loaded GFC_encoder_1
loaded GFC_encoder_2
loaded GFC_encoder_3
loaded GFC_encoder_4
loaded GFC_encoder_5
loaded GFC_encoder_6
loaded GFC_encoder_7
loaded GGRU00_0
loaded GGRU010_0
loaded GGRU011_0
loaded GFC_decoder_0
loaded GFC_decoder_1
loaded GFC_decoder_2
loaded GFC_decoder_3
loaded GFC_decoder_4
loaded GFC_decoder_5
loaded GFC_decoder_6
loaded GFC_decoder_7
loaded convp3
loaded convp3_1x1
loaded batch_normalization
loaded convp2
loaded convt3
loaded convp2_1x1
loaded convt3_1x1
loaded batch_normalization_2
loaded batch_normalization_1
loaded conv2d_transpose
loaded conv2d_transpose_1
loaded conv2d_transpose_2
loaded conv2d_transpose_3
loaded conv2d_transpose_4
loaded conv2d_transpose_5
loaded conv2d_tr

  a1, a2 = asarray(a1), asarray(a2)


# Training

In [None]:
# Train with the callbacks assembled in callbacks_overall.
# NOTE(review): workers/use_multiprocessing only affect generator or
# keras.utils.Sequence inputs; if `dataset` is a tf.data.Dataset they are
# ignored -- confirm.
history = model.fit(
    dataset,
    validation_data=val_dataset,
    epochs=p.epochs,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=callbacks_overall,
    workers=2,
    use_multiprocessing=True,
)

[INFO][BatchCallback] global step:0
Epoch 1/20
, Validation loss decreased from inf to 60.84336853027344.

Memory usage:  33920744
Epoch 2/20
, Validation loss decreased from 60.84336853027344 to 58.52476501464844.

Memory usage:  33995892
Epoch 3/20
  644/13312 [>.............................] - ETA: 2:07:30 - loss: 60.7838 - maskloss: 0.2198 - spectralloss: 60.7838 - global_step: 26946.5000

In [None]:
# Save the final in-memory weights to weights.hdf5 in the working directory.
model.save_weights(filepath="weights.hdf5")

In [None]:
# Reload the most recently modified checkpoint file from ./weights2_2 (the
# directory used by the checkpoint filename pattern above). Subdirectories
# are ignored.
ckpt_dir = "./weights2_2"
h5_lists = sorted(
    (fn for fn in os.listdir(ckpt_dir)
     if os.path.isfile(os.path.join(ckpt_dir, fn))),
    key=lambda fn: os.path.getmtime(os.path.join(ckpt_dir, fn)),
)
if not h5_lists:
    raise FileNotFoundError(f"No checkpoint files found in {ckpt_dir}")
new_file = os.path.join(ckpt_dir, h5_lists[-1])
print(new_file)

model.load_weights(new_file)

In [None]:
# Run the evaluation notebook.
# NOTE(review): without `-i`, %run executes the target in a fresh namespace and
# only copies its resulting variables back here, so DfNet2_test.ipynb does not
# see `model` / `p` from this notebook -- confirm it builds/loads its own model.
%run DfNet2_test.ipynb