In [1]:
import autoreload
%load_ext autoreload
# %autoreload 2

In [2]:
%reload_ext autoreload

In [3]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  

In [4]:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input , Dense, Lambda, Concatenate, BatchNormalization
from tensorflow.keras.activations import relu, sigmoid, tanh
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam, SGD, schedules
from tensorflow_addons.optimizers import AdamW

from tensorflow.keras.callbacks import ModelCheckpoint, TerminateOnNaN, TensorBoard, \
                                        EarlyStopping, ReduceLROnPlateau, CSVLogger, Callback

import numpy as np
import random
import os

import gc

from tqdm import tqdm 

In [5]:
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [6]:
import import_ipynb

from modules import convkxf, Conv2DBNAct, GroupGRULayer, GroupFC
from loss import MaskLoss, LocalSnrTarget, DfAlphaLoss, SpectralLoss, SISDRloss, SpectralLoss_weight
from utils import mask_operations, df_operations, synthesis_frame
from params import model_params
from dataloader import read_tfrecod_data

importing Jupyter notebook from modules.ipynb
importing Jupyter notebook from loss.ipynb
importing Jupyter notebook from bandERB.ipynb
importing Jupyter notebook from params.ipynb
importing Jupyter notebook from utils.ipynb
importing Jupyter notebook from dataloader.ipynb


In [7]:
from Callbacks import WarmUpCosineDecayScheduler, BatchCallback, ClearMemory, ModelCheckpoint, TrainCallback, cosine_decay_with_warmup

importing Jupyter notebook from Callbacks.ipynb
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Read config

In [8]:
p = model_params('config.ini')
initial_epoch = 39

# Read data

In [9]:
dataset_path = '/home2/user/myhsueh/dataset/tfrecord/v3/'

training_set_path = [
    dataset_path + 'boost/training',
    dataset_path + 'random/training',
    dataset_path + 'training/random',
    dataset_path + 'training/car',
    dataset_path + 'training/typing',
    dataset_path + 'training/ambience'
    
]
validation_set_path = [
    dataset_path + 'boost/validation',
    dataset_path + 'random/validation',
    dataset_path + 'validation/random',
    dataset_path + 'validation/car',
    dataset_path + 'validation/typing',
    dataset_path + 'validation/ambience'
]
combine_train, combine_val = False, False
if len(training_set_path) > 1 : combine_train = True
if len(validation_set_path) > 1 : combine_val = True

In [10]:
dataset = read_tfrecod_data(training_set_path, training=True, repeat=p.epochs-initial_epoch)

['/home2/user/myhsueh/dataset/tfrecord/v3/boost/training', '/home2/user/myhsueh/dataset/tfrecord/v3/random/training', '/home2/user/myhsueh/dataset/tfrecord/v3/training/random', '/home2/user/myhsueh/dataset/tfrecord/v3/training/car', '/home2/user/myhsueh/dataset/tfrecord/v3/training/typing', '/home2/user/myhsueh/dataset/tfrecord/v3/training/ambience'] , file directory exist: True
160
TD_427-3-of-4.tfrecord, TD_444-3-of-4.tfrecord, TD_427-2-of-4.tfrecord, TD_433-3-of-4.tfrecord
TD_431-4-of-4.tfrecord, TD_421-3-of-4.tfrecord, TD_425-4-of-4.tfrecord, TD_430-2-of-4.tfrecord
TD_418-4-of-4.tfrecord, TD_422-1-of-4.tfrecord, TD_418-1-of-4.tfrecord, TD_448-3-of-4.tfrecord
TD_420-4-of-4.tfrecord, TD_406-4-of-4.tfrecord, TD_426-2-of-4.tfrecord, TD_404-2-of-4.tfrecord
TD_418-3-of-4.tfrecord, TD_415-3-of-4.tfrecord, TD_434-2-of-4.tfrecord, TD_428-3-of-4.tfrecord
TD_401-4-of-4.tfrecord, TD_405-2-of-4.tfrecord, TD_405-3-of-4.tfrecord, TD_437-3-of-4.tfrecord
TD_425-3-of-4.tfrecord, TD_400-3-of-4.tfreco

done erb
done spec
Tensor("Reshape_6:0", shape=(282, 513, 2), dtype=float32)
ERB shape: (282, 32, 1) Tensor("Squeeze_282:0", shape=(282, 32, 1), dtype=float32)
NOISY_SPEC_df shape: (282, 96, 2) Tensor("Reshape_4:0", shape=(282, 96, 2), dtype=float32)
CLEAN_SPEC shape: (282, 513, 2) Tensor("Reshape_6:0", shape=(282, 513, 2), dtype=float32)
NOISY_SPEC shape: (282, 513, 2) Tensor("Reshape_5:0", shape=(282, 513, 2), dtype=float32)
['/home2/user/myhsueh/dataset/tfrecord/v3/boost/training', '/home2/user/myhsueh/dataset/tfrecord/v3/random/training', '/home2/user/myhsueh/dataset/tfrecord/v3/training/random', '/home2/user/myhsueh/dataset/tfrecord/v3/training/car', '/home2/user/myhsueh/dataset/tfrecord/v3/training/typing', '/home2/user/myhsueh/dataset/tfrecord/v3/training/ambience'] , file directory exist: True
32
TD_55-4-of-4.tfrecord, TD_52-2-of-4.tfrecord, TD_57-3-of-4.tfrecord, TD_50-3-of-4.tfrecord
TD_58-4-of-4.tfrecord, TD_51-2-of-4.tfrecord, TD_57-2-of-4.tfrecord, TD_58-1-of-4.tfrecord
TD

In [11]:
val_dataset = read_tfrecod_data(validation_set_path, training=False, repeat=p.epochs-initial_epoch)

['/home2/user/myhsueh/dataset/tfrecord/v3/boost/validation', '/home2/user/myhsueh/dataset/tfrecord/v3/random/validation', '/home2/user/myhsueh/dataset/tfrecord/v3/validation/random', '/home2/user/myhsueh/dataset/tfrecord/v3/validation/car', '/home2/user/myhsueh/dataset/tfrecord/v3/validation/typing', '/home2/user/myhsueh/dataset/tfrecord/v3/validation/ambience'] , file directory exist: True
40
TD_421-1-of-4.tfrecord, TD_423-2-of-4.tfrecord, TD_438-3-of-4.tfrecord, TD_406-3-of-4.tfrecord
TD_403-4-of-4.tfrecord, TD_445-4-of-4.tfrecord, TD_447-1-of-4.tfrecord, TD_404-4-of-4.tfrecord
TD_424-1-of-4.tfrecord, TD_437-1-of-4.tfrecord, TD_449-1-of-4.tfrecord, TD_436-3-of-4.tfrecord
TD_405-4-of-4.tfrecord, TD_406-1-of-4.tfrecord, TD_431-1-of-4.tfrecord, TD_408-1-of-4.tfrecord
TD_403-3-of-4.tfrecord, TD_411-2-of-4.tfrecord, TD_441-2-of-4.tfrecord, TD_438-4-of-4.tfrecord
TD_442-2-of-4.tfrecord, TD_421-4-of-4.tfrecord, TD_418-2-of-4.tfrecord, TD_448-2-of-4.tfrecord
TD_414-3-of-4.tfrecord, TD_431-3-

done erb
done spec
Tensor("Reshape_6:0", shape=(282, 513, 2), dtype=float32)
ERB shape: (282, 32, 1) Tensor("Squeeze_282:0", shape=(282, 32, 1), dtype=float32)
NOISY_SPEC_df shape: (282, 96, 2) Tensor("Reshape_4:0", shape=(282, 96, 2), dtype=float32)
CLEAN_SPEC shape: (282, 513, 2) Tensor("Reshape_6:0", shape=(282, 513, 2), dtype=float32)
NOISY_SPEC shape: (282, 513, 2) Tensor("Reshape_5:0", shape=(282, 513, 2), dtype=float32)


In [12]:
each_tfr_batch = 6000 // 4 // p.length_sec
train_count, val_count = 0, 0
if combine_train:
    for i in training_set_path:
        train_count += len(os.listdir(i))
else:
    train_count = len(os.listdir(training_set_path[0]))
    
if combine_val:
    for i in validation_set_path:
        val_count += len(os.listdir(i))
else:
    val_count = len(os.listdir(validation_set_path[0]))

steps_per_epoch = train_count*each_tfr_batch//p.batch_size 
validation_steps = val_count*each_tfr_batch//p.batch_size 
print('training step: {} , validation step: {}'.format(steps_per_epoch,validation_steps))

training step: 7609 , validation step: 1843


# Model construction

In [13]:
BN_type = "normal" # range normal
p.fc_group = 8
p.gru_group = 1

In [14]:
erb_inputs = Input(shape=(None, p.nb_erb, 1), name='ERB_input')
spec_inputs = Input(shape=(None, p.nb_df, 2), name='spec_input') # complex
clean_spec = Input(shape=(None, p.fft_size//2+1, 2), name='clean_spec') # complex
noisy_spec = Input(shape=(None, p.fft_size//2+1, 2), name='noisy_spec') # complex

In [15]:
erb_conv0 = convkxf(erb_inputs, p.conv_out_ch, k=3, f=3, fstride=1, bias=False, batch_norm=True, training=True, 
                     BN_type = BN_type, name='conv0_encoder')
erb_conv1 = convkxf(erb_conv0, p.conv_out_ch, k=1, f=3, fstride=2, bias=False, batch_norm=True,
                     BN_type = BN_type, name='conv1_encoder')
erb_conv2 = convkxf(erb_conv1, p.conv_out_ch, k=1, f=3, fstride=2, bias=False, batch_norm=True,
                     BN_type = BN_type, name='conv2_encoder')
erb_conv3 = convkxf(erb_conv2, p.conv_out_ch, k=1, f=3, fstride=1, bias=False, batch_norm=True,
                     BN_type = BN_type, name='conv3_encoder')

df_conv0 = convkxf(spec_inputs, p.conv_out_ch, k=3, f=3, fstride=1, batch_norm=True, training=True,
                    BN_type = BN_type, name='df_conv0_encoder')
df_conv1 = convkxf(df_conv0, p.conv_out_ch, k=1, f=3, fstride=2, batch_norm=True,
                    BN_type = BN_type, name='df_conv1_encoder')

shape = [tf.shape(df_conv1)[l] for l in range(4)]
df_conv1 = tf.transpose(df_conv1,(0,1,3,2))
df_conv1 = tf.reshape(df_conv1, [shape[0], shape[1], shape[2]*shape[3]])

cemb = GroupFC(df_conv1, p.fc_hidden, p.fc_group, name='GFC_encoder')
    
shape = [tf.shape(erb_conv3)[l] for l in range(4)]
emb = Concatenate()([tf.reshape(tf.transpose(erb_conv3,(0,1,3,2)), [shape[0], shape[1], shape[2]*shape[3]]), cemb])

if p.model_ver == "3.0":
    embfc = GroupFC(emb, p.gru_hidden, p.fc_group, name='GFC_emb')
    GRU_emb = GroupGRULayer(embfc, p.gru_hidden, p.gru_group, num_layer=1, name='GGRU0', add_output=True)
else:
    GRU_emb = GroupGRULayer(emb, p.gru_hidden, p.gru_group, num_layer=1, name='GGRU0', add_output=True)

{'filters': 64, 'groups': 1, 'kernel_size': (3, 3), 'strides': (1, 1), 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid', 'name': 'conv0_encoder'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 3), 'strides': (1, 2), 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid', 'name': 'conv1_encoder'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 3), 'strides': (1, 2), 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid', 'name': 'conv2_encoder'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 3), 'strides': (1, 1), 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid', 'name': 'conv3_encoder'}
{'filters': 64, 'groups': 2, 'kernel_size': (3, 3), 'strides': (1, 1), 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid', 'name': 'df_conv0_encoder'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 3), 'strides': (1, 2), 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid'

In [16]:
GRU_emb1 = GroupGRULayer(GRU_emb, p.gru_hidden, p.gru_group, num_layer=2, name='GGRU01', add_output=True)
de_emb = GroupFC(GRU_emb1, p.fc_hidden, p.fc_group, activation='relu', 
                 name='GFC_decoder')

shape = [tf.shape(erb_conv3)[l] for l in range(4)]
emb_decoder = tf.reshape(de_emb, shape = [shape[0], shape[1], shape[2], -1], name='decoder_reshape')

kwargs = {
    "k": 1,
    "batch_norm": True,
    "BN_type": BN_type,
}
tkwargs = {
    "k": 1,
    "batch_norm": True,
    "mode": "transposed",
    "BN_type": BN_type,
}
pkwargs = {
    "k": 1,
    "f": 1,
    "batch_norm": True,
    "BN_type": BN_type,
}

convp3 = convkxf(erb_conv3, out_ch=p.conv_out_ch, name='convp3', **pkwargs)  # Conv
vt3_in = emb_decoder + convp3
convt3 = convkxf(vt3_in, out_ch=p.conv_out_ch, fstride=1, name='convt3', **kwargs) #ConvT

convp2 = convkxf(erb_conv2, out_ch=p.conv_out_ch, name='convp2', **pkwargs) # Conv
vt2_in = convt3 + convp2
convt2 = convkxf(vt2_in, out_ch=p.conv_out_ch, name='convt2', **tkwargs)

convp1 = convkxf(erb_conv1, out_ch=p.conv_out_ch, name='convp1', **pkwargs) # Conv
vt1_in = convt2  + convp1
convt1 = convkxf(vt1_in, out_ch=p.conv_out_ch, name='convt1', **tkwargs)

convp0 = convkxf(erb_conv0, out_ch=p.conv_out_ch, name='convp0', **pkwargs) # Conv
vt0_in = convt1 + convp0
mask_out = convkxf(vt0_in, out_ch=1, k=1, fstride=1, act='sigmoid', batch_norm=False, 
                   reshape=True, name='mask_out')

{'filters': 64, 'groups': 64, 'kernel_size': (1, 1), 'strides': 1, 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid', 'name': 'convp3'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 3), 'strides': (1, 1), 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid', 'name': 'convt3'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 1), 'strides': 1, 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid', 'name': 'convp2'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 1), 'strides': 1, 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid', 'name': 'convp1'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 1), 'strides': 1, 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid', 'name': 'convp0'}
{'filters': 1, 'groups': 1, 'kernel_size': (1, 3), 'strides': (1, 1), 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid', 'name': 'mask_out'}


In [17]:
GRU_emb2 = GroupGRULayer(GRU_emb, p.gru_hidden, p.gru_group, num_layer=2, name='GGRU1', add_output=True)
convp = convkxf(df_conv0, 2*p.df_order, k=1, f=1, complex_in=True, batch_norm=True, 
                BN_type = BN_type, name='convp_DfDecoder') 
convp = tf.transpose(convp, (0,1,3,2))

df_alpha = Dense(1, name='convp_alpha', activation='sigmoid')(GRU_emb2)
c = GroupFC(GRU_emb2, p.nb_df*p.df_order*2, p.fc_group, activation='tanh', name='GFC_c')

shape = [tf.shape(c)[k] for k in range(2)]
c = tf.reshape(c,(shape[0], shape[1], p.df_order*2, p.nb_df))

c = tf.reshape(c + convp,(shape[0], shape[1], p.df_order, 2, p.nb_df))

df_coeff = tf.transpose(c, [0, 1, 2, 4, 3], name='output_c')

{'filters': 10, 'groups': 2, 'kernel_size': (1, 1), 'strides': 1, 'use_bias': False, 'kernel_initializer': 'he_normal', 'padding': 'valid', 'name': 'convp_DfDecoder'}


# Enhance operation

In [18]:
spec = mask_operations(noisy_spec, mask_out) # mask gain
enhanced = df_operations(spec, df_coeff, df_alpha) # deep filter

# Callbacks

In [19]:
callbacks_nan = TerminateOnNaN()
callbacks_earlystop = EarlyStopping(patience=10, mode='min', restore_best_weights=True)
# Memory clear
Memory_callback=ClearMemory()

# Loss

In [20]:
maskloss = Lambda(lambda x:MaskLoss(*x, factor=p.mask_factor, r=p.mask_gamma), 
                  name='maskloss')([mask_out, clean_spec, noisy_spec])
spectralloss = Lambda(lambda x:SpectralLoss(*x, gamma=p.df_gamma, factor_mag=p.df_factor, factor_img=p.df_factor), 
                      name='spectralloss')([enhanced, clean_spec])
# spectralloss2 = Lambda(lambda x:SpectralLoss(*x, gamma=p.df_gamma*2, factor_mag=p.df_factor, factor_img=p.df_factor), 
#                       name='spectralloss2')([enhanced, clean_spec])

# Model compile

In [21]:
inputs = [erb_inputs, spec_inputs, clean_spec, noisy_spec]
outputs = [maskloss,spectralloss]
model=Model(inputs=inputs, outputs=outputs, name='DfNet')

In [22]:
model.summary()

Model: "DfNet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 ERB_input (InputLayer)         [(None, None, 32, 1  0           []                               
                                )]                                                                
                                                                                                  
 tf.compat.v1.pad (TFOpLambda)  (None, None, 32, 1)  0           ['ERB_input[0][0]']              
                                                                                                  
 spec_input (InputLayer)        [(None, None, 96, 2  0           []                               
                                )]                                                                
                                                                                              

                                                                                                  
 conv2_encoderBatchNorm (BatchN  (None, None, 8, 64)  256        ['conv2_encoder_1x1[0][0]']      
 ormalization)                                                                                    
                                                                                                  
 tf.compat.v1.shape (TFOpLambda  (4,)                0           ['tf.nn.relu_5[0][0]']           
 )                                                                                                
                                                                                                  
 tf.compat.v1.shape_1 (TFOpLamb  (4,)                0           ['tf.nn.relu_5[0][0]']           
 da)                                                                                              
                                                                                                  
 tf.__oper

                                                                                                  
 tf.compat.v1.transpose_1 (TFOp  (None, None, 8, 64)  0          ['tf.stack[0][0]']               
 Lambda)                                                                                          
                                                                                                  
 tf.compat.v1.shape_10 (TFOpLam  (4,)                0           ['tf.nn.relu_3[0][0]']           
 bda)                                                                                             
                                                                                                  
 tf.compat.v1.shape_11 (TFOpLam  (4,)                0           ['tf.nn.relu_3[0][0]']           
 bda)                                                                                             
                                                                                                  
 tf.compat

                                                                                                  
 tf.nn.relu_6 (TFOpLambda)      (None, None, 256)    0           ['GGRU00[0][0]']                 
                                                                                                  
 GGRU010 (GRU)                  (None, None, 256)    394752      ['tf.nn.relu_6[0][0]']           
                                                                                                  
 tf.nn.relu_7 (TFOpLambda)      (None, None, 256)    0           ['GGRU010[0][0]']                
                                                                                                  
 GGRU011 (GRU)                  (None, None, 256)    394752      ['tf.nn.relu_7[0][0]']           
                                                                                                  
 tf.nn.relu_8 (TFOpLambda)      (None, None, 256)    0           ['GGRU011[0][0]']                
          

                                                                  'tf.__operators__.getitem_13[0][
                                                                 0]',                             
                                                                  'tf.math.multiply_3[0][0]']     
                                                                                                  
 tf.compat.v1.shape_16 (TFOpLam  (4,)                0           ['tf.nn.relu_3[0][0]']           
 bda)                                                                                             
                                                                                                  
 tf.compat.v1.shape_17 (TFOpLam  (4,)                0           ['tf.nn.relu_3[0][0]']           
 bda)                                                                                             
                                                                                                  
 tf.compat

                                , (None, None, 8, 1                                               
                                ),                                                                
                                 (None, None, 8, 1)                                               
                                , (None, None, 8, 1                                               
                                ),                                                                
                                 (None, None, 8, 1)                                               
                                , (None, None, 8, 1                                               
                                ),                                                                
                                 (None, None, 8, 1)                                               
                                , (None, None, 8, 1                                               
          

                                                                                                  
 conv2d_transpose_2 (Conv2DTran  (None, None, 16, 1)  3          ['tf.split_2[0][2]']             
 spose)                                                                                           
                                                                                                  
 conv2d_transpose_3 (Conv2DTran  (None, None, 16, 1)  3          ['tf.split_2[0][3]']             
 spose)                                                                                           
                                                                                                  
 conv2d_transpose_4 (Conv2DTran  (None, None, 16, 1)  3          ['tf.split_2[0][4]']             
 spose)                                                                                           
                                                                                                  
 conv2d_tr

 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_30 (Conv2DTra  (None, None, 16, 1)  3          ['tf.split_2[0][30]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_31 (Conv2DTra  (None, None, 16, 1)  3          ['tf.split_2[0][31]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_32 (Conv2DTra  (None, None, 16, 1)  3          ['tf.split_2[0][32]']            
 nspose)                                                                                          
          

 conv2d_transpose_57 (Conv2DTra  (None, None, 16, 1)  3          ['tf.split_2[0][57]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_58 (Conv2DTra  (None, None, 16, 1)  3          ['tf.split_2[0][58]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_59 (Conv2DTra  (None, None, 16, 1)  3          ['tf.split_2[0][59]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_60 (Conv2DTra  (None, None, 16, 1)  3          ['tf.split_2[0][60]']            
 nspose)  

                                                                  'conv2d_transpose_62[0][0]',    
                                                                  'conv2d_transpose_63[0][0]']    
                                                                                                  
 convt2_1x1 (Conv2D)            (None, None, 16, 64  4096        ['tf.concat[0][0]']              
                                )                                                                 
                                                                                                  
 convp1 (Conv2D)                (None, None, 16, 64  64          ['tf.nn.relu_1[0][0]']           
                                )                                                                 
                                                                                                  
 convt2BatchNorm (BatchNormaliz  (None, None, 16, 64  256        ['convt2_1x1[0][0]']             
 ation)   

                                ),                                                                
                                 (None, None, 16, 1                                               
                                ),                                                                
                                 (None, None, 16, 1                                               
                                ),                                                                
                                 (None, None, 16, 1                                               
                                ),                                                                
                                 (None, None, 16, 1                                               
                                ),                                                                
                                 (None, None, 16, 1                                               
          

 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_69 (Conv2DTra  (None, None, 32, 1)  3          ['tf.split_3[0][5]']             
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_70 (Conv2DTra  (None, None, 32, 1)  3          ['tf.split_3[0][6]']             
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_71 (Conv2DTra  (None, None, 32, 1)  3          ['tf.split_3[0][7]']             
 nspose)                                                                                          
          

 conv2d_transpose_96 (Conv2DTra  (None, None, 32, 1)  3          ['tf.split_3[0][32]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_97 (Conv2DTra  (None, None, 32, 1)  3          ['tf.split_3[0][33]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_98 (Conv2DTra  (None, None, 32, 1)  3          ['tf.split_3[0][34]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_99 (Conv2DTra  (None, None, 32, 1)  3          ['tf.split_3[0][35]']            
 nspose)  

                                                                                                  
 conv2d_transpose_124 (Conv2DTr  (None, None, 32, 1)  3          ['tf.split_3[0][60]']            
 anspose)                                                                                         
                                                                                                  
 conv2d_transpose_125 (Conv2DTr  (None, None, 32, 1)  3          ['tf.split_3[0][61]']            
 anspose)                                                                                         
                                                                                                  
 conv2d_transpose_126 (Conv2DTr  (None, None, 32, 1)  3          ['tf.split_3[0][62]']            
 anspose)                                                                                         
                                                                                                  
 conv2d_tr

                                                                                                  
 GGRU10 (GRU)                   (None, None, 256)    394752      ['tf.nn.relu_6[0][0]']           
                                                                                                  
 convt1BatchNorm (BatchNormaliz  (None, None, 32, 64  256        ['convt1_1x1[0][0]']             
 ation)                         )                                                                 
                                                                                                  
 convp0BatchNorm (BatchNormaliz  (None, None, 32, 64  256        ['convp0[0][0]']                 
 ation)                         )                                                                 
                                                                                                  
 tf.nn.relu_17 (TFOpLambda)     (None, None, 256)    0           ['GGRU10[0][0]']                 
          

                                                                                                  
 tf.reshape_5 (TFOpLambda)      (None, None, 32)     0           ['mask_out[0][0]',               
                                                                  'tf.__operators__.getitem_20[0][
                                                                 0]',                             
                                                                  'tf.__operators__.getitem_21[0][
                                                                 0]',                             
                                                                  'tf.__operators__.getitem_22[0][
                                                                 0]']                             
                                                                                                  
 tf.compat.v1.transpose_5 (TFOp  (None, None, 8, 120  0          ['tf.stack_2[0][0]']             
 Lambda)  

                                 (None, None, 96, 1                                               
                                )]                                                                
                                                                                                  
 tf.__operators__.getitem_28 (S  ()                  0           ['tf.compat.v1.shape_28[0][0]']  
 licingOpLambda)                                                                                  
                                                                                                  
 tf.__operators__.getitem_29 (S  ()                  0           ['tf.compat.v1.shape_29[0][0]']  
 licingOpLambda)                                                                                  
                                                                                                  
 tf.nn.relu_19 (TFOpLambda)     (None, None, 96, 10  0           ['convp_DfDecoderBatchNorm[0][0]'
          

                                                                  'tf.__operators__.getitem_34[0][
                                                                 0]']                             
                                                                                                  
 tf.math.multiply_8 (TFOpLambda  (None, None, 5, 96)  0          ['tf.__operators__.getitem_35[0][
 )                                                               0]',                             
                                                                  'tf.__operators__.getitem_36[0][
                                                                 0]']                             
                                                                                                  
 tf.math.multiply_9 (TFOpLambda  (None, None, 5, 96)  0          ['tf.__operators__.getitem_37[0][
 )                                                               0]',                             
          

In [23]:
# warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=p.lr,
#                                         total_steps=p.epochs*steps_per_epoch,
#                                         warmup_steps=(p.epochs//10)*steps_per_epoch,
#                                         warmup_learning_rate = 1e-4,
#                                         mini_lr=1e-7,
#                                         steps_per_epoch=steps_per_epoch,
#                                         global_step_init=initial_epoch*steps_per_epoch)
warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=p.lr,
                                        total_steps=p.epochs*steps_per_epoch,
                                        warmup_steps=0*steps_per_epoch,
                                        warmup_learning_rate = 1e-4,
                                        mini_lr=1e-7,
                                        steps_per_epoch=steps_per_epoch,
                                        global_step_init=initial_epoch*steps_per_epoch)

In [24]:
# plt_lr = []
# for i in range(p.epochs):
#     plt_lr.append(cosine_decay_with_warmup(global_step=i,
#                               learning_rate_base=1e-3,
#                               total_steps=30*1,
#                               warmup_learning_rate=1e-4,
#                               warmup_steps=3*1,
#                               hold_base_rate_steps=0,
#                               steps_per_epoch=1))
# for i in range(3):
#     for j in range(10):
#         print('%.7f'%plt_lr[i*10+j], end=' ')
#     print()
# plt_lr[-1]

In [25]:
# Checkpoint
cp_callback = ModelCheckpoint(model, './weights8/weights_newlossadam.{epoch:02d}-{loss:.2f}.h5')
callbacks_overall = [cp_callback, callbacks_nan]#, warm_up_lr]

In [26]:
model.add_metric(maskloss, name = "maskloss")
model.add_metric(spectralloss, name="spectralloss")
# model.add_metric(spectralloss2, name="spectralloss2")
# model.add_loss(maskloss)
model.add_loss(spectralloss)
# model.add_loss(spectralloss2)

In [27]:
# optimizer = AdamW(learning_rate=p.lr, weight_decay=0.0, clipnorm=1.0)
# optimizer = Adam(learning_rate=p.lr, clipnorm=1.0)
optimizer = SGD(learning_rate=p.lr, clipnorm=1.0)

In [28]:
model.compile(optimizer=optimizer, run_eagerly=True)

# Load weights

In [29]:
path = "./weights8/"
h5_lists = os.listdir(path)
h5_lists.sort(key=lambda fn:os.path.getmtime(path + fn)
                if not os.path.isdir(path + fn) else 0)
new_file = os.path.join(path, h5_lists[-1])
# new_file = "weights.19-58.49.h5"
print(new_file)

./weights8/weights_newlossadam.38-75.06.h5


In [30]:
# store weights before loading pre-trained weights
preloaded_layers = model.layers.copy()
preloaded_weights = []
for pre in preloaded_layers:
    preloaded_weights.append(pre.get_weights())
    
# load pre-trained weights
model.load_weights(new_file, by_name=True)

# compare previews weights vs loaded weights
for layer, pre in zip(model.layers, preloaded_weights):
    weights = layer.get_weights()

    if weights:
        if np.array_equal(weights, pre):
            print('not loaded', layer.name)
        else:
            pass

  a1, a2 = asarray(a1), asarray(a2)


# TRAIN

In [None]:
history = model.fit(dataset,
                    epochs=p.epochs,
                    validation_data=val_dataset,
                    steps_per_epoch=steps_per_epoch,
                    validation_steps=validation_steps,
                    callbacks=callbacks_overall,
                    initial_epoch=initial_epoch)

Epoch 40/50
Saving model to ./weights8/weights_newlossadam.39-75.52.h5
, Validation loss decreased from inf to 75.51563262939453.

Epoch 41/50
Saving model to ./weights8/weights_newlossadam.40-75.38.h5
, Validation loss decreased from 75.51563262939453 to 75.38206481933594.

Epoch 42/50
Saving model to ./weights8/weights_newlossadam.41-75.25.h5
, Validation loss decreased from 75.38206481933594 to 75.24954986572266.

Epoch 43/50
Saving model to ./weights8/weights_newlossadam.42-75.16.h5
, Validation loss decreased from 75.24954986572266 to 75.16417694091797.

Epoch 44/50
Saving model to ./weights8/weights_newlossadam.43-75.08.h5
, Validation loss decreased from 75.16417694091797 to 75.07879638671875.

Epoch 45/50
Saving model to ./weights8/weights_newlossadam.44-75.02.h5
, Validation loss decreased from 75.07879638671875 to 75.0167007446289.

Epoch 46/50