In [1]:
import autoreload
%load_ext autoreload
%autoreload 2

In [2]:
%reload_ext autoreload

In [3]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TF_KERAS"] = "1"
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  

In [4]:
import tensorflow as tf
import keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input , Dense, Lambda, Normalization, Concatenate
from tensorflow.keras.activations import relu, sigmoid, tanh
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam, SGD, schedules
from tensorflow_addons.optimizers import AdamW, extend_with_decoupled_weight_decay

from tensorflow.keras.callbacks import ModelCheckpoint, TerminateOnNaN, TensorBoard, \
                                        EarlyStopping, ReduceLROnPlateau, CSVLogger, Callback

import numpy as np
import random
import os

import gc

from tqdm import tqdm 



In [5]:
import import_ipynb

from modules import convkxf, GroupFC#, GroupGRULayer

from loss import MaskLoss, LocalSnrTarget, DfAlphaLoss, SpectralLoss, SISNR_Loss
from utils import mask_operations, df_operations, synthesis_frame, df_operations_wo_alpha
from params import model_params
from dataloader import read_tfrecod_data
from bandERB import ERBBand, ERB_pro_matrix

importing Jupyter notebook from modules.ipynb
importing Jupyter notebook from loss.ipynb
importing Jupyter notebook from bandERB.ipynb
importing Jupyter notebook from params.ipynb
importing Jupyter notebook from utils.ipynb
importing Jupyter notebook from dataloader.ipynb


In [6]:
stateful = False
gru_norm, fc_norm = True, True
if stateful: from modules import GroupGRULayer
else: from modules import GroupGRULayer_lite as GroupGRULayer

In [7]:
# from tensorflow.compat.v1 import ConfigProto, InteractiveSession, enable_eager_execution

# config = ConfigProto(allow_soft_placement=True)
# config.gpu_options.per_process_gpu_memory_fraction = 0.6
# config.gpu_options.allow_growth = True
# InteractiveSession(config=config)
# enable_eager_execution(config=config)

# # tf.compat.v1.disable_eager_execution()

# Read config

In [8]:
p = model_params('config.ini')

In [9]:
def as_complex(x):
    if x.dtype == np.complex64 or x.dtype == np.complex128:
        return x
    else:
        return np.complex(x[...,-2], x[..., -1])

def as_real(x):
    if x.dtype == np.complex64 or x.dtype == np.complex128:
        return np.concatenate([np.expand_dims(np.real(x),axis=-1), 
                               np.expand_dims(np.imag(x),axis=-1)],axis=-1)
    else: return x

# Model construction

In [10]:
erb_inputs = Input(shape=(3, p.nb_erb, 1), batch_size=1, name='ERB_input')
spec_inputs = Input(shape=(3, p.nb_df, 2), batch_size=1, name='spec_input') # complex

state = Input(shape=(256), batch_size=1, name='state0')

state10 = Input(shape=(256), batch_size=1, name='state10')
state11 = Input(shape=(256), batch_size=1, name='state11')
state1 = [state10, state11]

state20 = Input(shape=(256), batch_size=1, name='state20')
state21 = Input(shape=(256), batch_size=1, name='state21')
state2 = [state20,state21]

In [11]:
erb_conv0 = convkxf(erb_inputs, p.conv_out_ch, k=3, f=3, fstride=1, bias=False, batch_norm=True, 
                    training=False, infer=True, name='conv0_encoder')
erb_conv1 = convkxf(erb_conv0, p.conv_out_ch, k=1, f=3, fstride=2, bias=False, batch_norm=True, 
                    infer=True, name='conv1_encoder')
erb_conv2 = convkxf(erb_conv1, p.conv_out_ch, k=1, f=3, fstride=2, bias=False, batch_norm=True, 
                    infer=True, name='conv2_encoder')
erb_conv3 = convkxf(erb_conv2, p.conv_out_ch, k=1, f=3, fstride=1, bias=False, batch_norm=True, 
                    infer=True, name='conv3_encoder')

if not p.mask_only: 
    df_conv0 = convkxf(spec_inputs, p.conv_out_ch, k=3, f=3, fstride=1, batch_norm=True, 
                       training=False, infer=True, name='df_conv0_encoder')
    df_conv1 = convkxf(df_conv0, p.conv_out_ch, k=1, f=3, fstride=2, batch_norm=True, 
                       infer=True, name='df_conv1_encoder')
    
    shape = [tf.shape(df_conv1)[l] for l in range(4)]
    df_conv1 = tf.transpose(df_conv1,(0,1,3,2))
    df_conv1 = tf.reshape(df_conv1, [shape[0], shape[1], shape[2]*shape[3]])
    
    cemb = GroupFC(df_conv1, p.fc_hidden, p.fc_group, infer=True, name='GFC_encoder')
    
shape = [tf.shape(erb_conv3)[l] for l in range(4)]
if p.mask_only: 
    emb = tf.reshape(tf.transpose(erb_conv3,(0,1,3,2)), [shape[0], shape[1], shape[2]*shape[3]])
else:           
    emb = Concatenate()([cemb, tf.reshape(tf.transpose(erb_conv3,(0,1,3,2)), [shape[0], shape[1], shape[2]*shape[3]])])

if stateful: GRU_emb = GroupGRULayer(emb, p.gru_hidden, p.gru_group, num_layer=1, name='GGRU0', 
                               add_output=True, training=False, norm=gru_norm)
else: GRU_emb, lstate_o = GroupGRULayer(emb, p.gru_hidden, p.gru_group, num_layer=1, state=[state], name='GGRU0', 
                               add_output=True, training=False, norm=gru_norm)

{'filters': 64, 'groups': 1, 'kernel_size': (3, 3), 'strides': (1, 1), 'use_bias': False, 'padding': 'valid', 'kernel_initializer': 'he_normal', 'kernel_regularizer': <keras.regularizers.L2 object at 0x7f399af88b80>, 'name': 'conv0_encoder'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 3), 'strides': (1, 2), 'use_bias': False, 'padding': 'valid', 'kernel_initializer': 'he_normal', 'kernel_regularizer': <keras.regularizers.L2 object at 0x7f399af88b80>, 'name': 'conv1_encoder'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 3), 'strides': (1, 2), 'use_bias': False, 'padding': 'valid', 'kernel_initializer': 'he_normal', 'kernel_regularizer': <keras.regularizers.L2 object at 0x7f399af88b80>, 'name': 'conv2_encoder'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 3), 'strides': (1, 1), 'use_bias': False, 'padding': 'valid', 'kernel_initializer': 'he_normal', 'kernel_regularizer': <keras.regularizers.L2 object at 0x7f399af88b80>, 'name': 'conv3_encoder'}
{'filters': 64, 'groups': 2, 

In [12]:
if stateful: GRU_emb1 = GroupGRULayer(GRU_emb, p.gru_hidden, p.gru_group, num_layer=2, name='GGRU01', 
                                 add_output=True, training=False, norm=gru_norm)
else: GRU_emb1, lstate1_o = GroupGRULayer(GRU_emb, p.gru_hidden, p.gru_group, num_layer=2, state=state1, name='GGRU01', 
                                 add_output=True, training=False, norm=gru_norm)
# GRU_emb1 = BatchNormalization()(GRU_emb1)
de_emb = GroupFC(GRU_emb1, p.fc_hidden, p.fc_group, activation='relu', infer=True, name='GFC_decoder', norm=fc_norm)

shape = [tf.shape(erb_conv3)[l] for l in range(4)]
emb_decoder = tf.reshape(de_emb, shape = [shape[0], shape[1], shape[2], shape[3]], name='decoder_reshape')

kwargs = {
    "k": 1,
    "batch_norm": True,
    "infer": True
}
tkwargs = {
    "k": 1,
    "batch_norm": True,
    "infer": True,
    "mode": "transposed"
}
pkwargs = {
    "k": 1,
    "f": 1,
    "batch_norm": True,
    "infer": True
}

convp3 = convkxf(erb_conv3, out_ch=p.conv_out_ch, name='convp3', **pkwargs)  # Conv
vt3_in = convp3 + emb_decoder
convt3 = convkxf(vt3_in, out_ch=p.conv_out_ch, fstride=1, name='convt3', **kwargs) #ConvT

convp2 = convkxf(erb_conv2, out_ch=p.conv_out_ch, name='convp2', **pkwargs) # Conv
vt2_in = convp2 + convt3
convt2 = convkxf(vt2_in, out_ch=p.conv_out_ch, name='convt2', **tkwargs)

convp1 = convkxf(erb_conv1, out_ch=p.conv_out_ch, name='convp1', **pkwargs) # Conv
vt1_in = convp1 + convt2
convt1 = convkxf(vt1_in, out_ch=p.conv_out_ch, name='convt1', **tkwargs)

convp0 = convkxf(erb_conv0, out_ch=p.conv_out_ch, name='convp0', **pkwargs) # Conv
vt0_in = convp0 + convt1
mask_out = convkxf(vt0_in, out_ch=1, k=1, fstride=1, act='sigmoid', batch_norm=False, 
                   reshape=True, name='mask_out')

{'filters': 64, 'groups': 64, 'kernel_size': (1, 1), 'strides': 1, 'use_bias': False, 'padding': 'valid', 'kernel_initializer': 'he_normal', 'kernel_regularizer': <keras.regularizers.L2 object at 0x7f399af88b80>, 'name': 'convp3'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 3), 'strides': (1, 1), 'use_bias': False, 'padding': 'valid', 'kernel_initializer': 'he_normal', 'kernel_regularizer': <keras.regularizers.L2 object at 0x7f399af88b80>, 'name': 'convt3'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 1), 'strides': 1, 'use_bias': False, 'padding': 'valid', 'kernel_initializer': 'he_normal', 'kernel_regularizer': <keras.regularizers.L2 object at 0x7f399af88b80>, 'name': 'convp2'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 1), 'strides': 1, 'use_bias': False, 'padding': 'valid', 'kernel_initializer': 'he_normal', 'kernel_regularizer': <keras.regularizers.L2 object at 0x7f399af88b80>, 'name': 'convp1'}
{'filters': 64, 'groups': 64, 'kernel_size': (1, 1), 'strides': 1, 'use

In [13]:
from modules import constraint

In [14]:
if stateful: GRU_emb2 = GroupGRULayer(GRU_emb, p.gru_hidden, p.gru_group, num_layer=2, name='GGRU1', 
                                 add_output=True, training=False, norm=gru_norm)
else: GRU_emb2, lstate2_o = GroupGRULayer(GRU_emb, p.gru_hidden, p.gru_group, num_layer=2, state=state2, name='GGRU1', 
                                 add_output=True, training=False, norm=gru_norm)
# GRU_emb2 = BatchNormalization()(GRU_emb2)
convp = convkxf(df_conv0, 2*p.df_order, k=1, f=1, complex_in=True, batch_norm=True, name='convp_DfDecoder') 
convp = tf.transpose(convp, (0,1,3,2))

df_alpha = Dense(1, name='convp_alpha', activation='sigmoid')(GRU_emb2)
#                         kernel_constraint= constraint, 
#                         bias_constraint= constraint)(GRU_emb2)
# c = Dense(p.nb_df*p.df_order*2, name='convp_c', activation='tanh')(GRU_emb2)
#                         kernel_constraint= constraint, 
#                         bias_constraint= constraint)(GRU_emb2)
c = GroupFC(GRU_emb2, p.nb_df*p.df_order*2, p.fc_group, activation='tanh', name='GFC_c')

shape = [tf.shape(c)[k] for k in range(2)]
c = tf.reshape(c,(shape[0], shape[1], p.df_order*2, p.nb_df))

c = tf.reshape(c + convp,(shape[0], shape[1], p.df_order, 2, p.nb_df))

df_coeff = tf.transpose(c, [0, 1, 2, 4, 3], name='output_c')

{'filters': 10, 'groups': 2, 'kernel_size': (1, 1), 'strides': 1, 'use_bias': False, 'padding': 'valid', 'kernel_initializer': 'he_normal', 'kernel_regularizer': <keras.regularizers.L2 object at 0x7f399af88b80>, 'name': 'convp_DfDecoder'}


# Model compile

In [15]:
if stateful:
    inputs = [erb_inputs, spec_inputs]
    outputs = [mask_out, df_alpha, df_coeff]
else:
    inputs = [erb_inputs, spec_inputs, state, state1, state2]
    outputs = [mask_out, df_alpha, df_coeff, lstate_o, lstate1_o, lstate2_o]

In [16]:
model=Model(inputs=inputs, outputs=outputs, name='DfNet')

In [17]:
model.summary()

Model: "DfNet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 ERB_input (InputLayer)         [(1, 3, 32, 1)]      0           []                               
                                                                                                  
 tf.compat.v1.pad (TFOpLambda)  (1, 3, 34, 1)        0           ['ERB_input[0][0]']              
                                                                                                  
 spec_input (InputLayer)        [(1, 3, 96, 2)]      0           []                               
                                                                                                  
 conv0_encoder (Conv2D)         (1, 1, 32, 64)       576         ['tf.compat.v1.pad[0][0]']       
                                                                                              

                                                                                                  
 tf.compat.v1.transpose (TFOpLa  (1, 1, 64, 48)      0           ['tf.nn.relu_5[0][0]']           
 mbda)                                                                                            
                                                                                                  
 tf.__operators__.getitem (Slic  ()                  0           ['tf.compat.v1.shape[0][0]']     
 ingOpLambda)                                                                                     
                                                                                                  
 tf.__operators__.getitem_1 (Sl  ()                  0           ['tf.compat.v1.shape_1[0][0]']   
 icingOpLambda)                                                                                   
                                                                                                  
 tf.math.m

 da)                                                             ]                                
                                                                                                  
 tf.compat.v1.shape_5 (TFOpLamb  (4,)                0           ['tf.compat.v1.transpose_1[0][0]'
 da)                                                             ]                                
                                                                                                  
 tf.__operators__.getitem_7 (Sl  ()                  0           ['tf.compat.v1.shape_7[0][0]']   
 icingOpLambda)                                                                                   
                                                                                                  
 tf.__operators__.getitem_6 (Sl  ()                  0           ['tf.compat.v1.shape_6[0][0]']   
 icingOpLambda)                                                                                   
          

 GGRU011 (GRU)                  [(1, 1, 256),        394752      ['batch_normalization_1[0][0]',  
                                 (1, 256)]                        'state11[0][0]']                
                                                                                                  
 batch_normalization_2 (BatchNo  (1, 1, 256)         1024        ['GGRU011[0][0]']                
 rmalization)                                                                                     
                                                                                                  
 tf.__operators__.add (TFOpLamb  (1, 1, 256)         0           ['batch_normalization_1[0][0]',  
 da)                                                              'batch_normalization_2[0][0]']  
                                                                                                  
 tf.split_1 (TFOpLambda)        [(1, 1, 32),         0           ['tf.__operators__.add[0][0]']   
          

                                                                                                  
 convp3_1x1 (Conv2D)            (1, 1, 8, 64)        4096        ['convp3[0][0]']                 
                                                                                                  
 GFC_decoder_8BatchNorm (BatchN  (1, 1, 512)         2048        ['tf.reshape_3[0][0]']           
 ormalization)                                                                                    
                                                                                                  
 tf.compat.v1.shape_16 (TFOpLam  (4,)                0           ['tf.nn.relu_3[0][0]']           
 bda)                                                                                             
                                                                                                  
 tf.compat.v1.shape_17 (TFOpLam  (4,)                0           ['tf.nn.relu_3[0][0]']           
 bda)     

                                 (1, 1, 8, 1),                                                    
                                 (1, 1, 8, 1),                                                    
                                 (1, 1, 8, 1),                                                    
                                 (1, 1, 8, 1),                                                    
                                 (1, 1, 8, 1),                                                    
                                 (1, 1, 8, 1),                                                    
                                 (1, 1, 8, 1),                                                    
                                 (1, 1, 8, 1),                                                    
                                 (1, 1, 8, 1),                                                    
                                 (1, 1, 8, 1),                                                    
          

 spose)                                                                                           
                                                                                                  
 conv2d_transpose_8 (Conv2DTran  (1, 1, 16, 1)       3           ['tf.split_2[0][8]']             
 spose)                                                                                           
                                                                                                  
 conv2d_transpose_9 (Conv2DTran  (1, 1, 16, 1)       3           ['tf.split_2[0][9]']             
 spose)                                                                                           
                                                                                                  
 conv2d_transpose_10 (Conv2DTra  (1, 1, 16, 1)       3           ['tf.split_2[0][10]']            
 nspose)                                                                                          
          

 conv2d_transpose_35 (Conv2DTra  (1, 1, 16, 1)       3           ['tf.split_2[0][35]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_36 (Conv2DTra  (1, 1, 16, 1)       3           ['tf.split_2[0][36]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_37 (Conv2DTra  (1, 1, 16, 1)       3           ['tf.split_2[0][37]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_38 (Conv2DTra  (1, 1, 16, 1)       3           ['tf.split_2[0][38]']            
 nspose)  

                                                                                                  
 conv2d_transpose_63 (Conv2DTra  (1, 1, 16, 1)       3           ['tf.split_2[0][63]']            
 nspose)                                                                                          
                                                                                                  
 batch_normalization_3 (BatchNo  (1, 1, 256)         1024        ['GGRU10[0][0]']                 
 rmalization)                                                                                     
                                                                                                  
 state21 (InputLayer)           [(1, 256)]           0           []                               
                                                                                                  
 convp1 (Conv2D)                (1, 1, 16, 64)       64          ['tf.nn.relu_1[0][0]']           
          

 batch_normalization_4 (BatchNo  (1, 1, 256)         1024        ['GGRU11[0][0]']                 
 rmalization)                                                                                     
                                                                                                  
 convp1BatchNorm (BatchNormaliz  (1, 1, 16, 64)      256         ['convp1_1x1[0][0]']             
 ation)                                                                                           
                                                                                                  
 convt2BatchNorm (BatchNormaliz  (1, 1, 16, 64)      256         ['convt2_1x1[0][0]']             
 ation)                                                                                           
                                                                                                  
 tf.__operators__.add_5 (TFOpLa  (1, 1, 256)         0           ['batch_normalization_3[0][0]',  
 mbda)    

                                 (1, 1, 16, 1),                                                   
                                 (1, 1, 16, 1),                                                   
                                 (1, 1, 16, 1),                                                   
                                 (1, 1, 16, 1),                                                   
                                 (1, 1, 16, 1),                                                   
                                 (1, 1, 16, 1),                                                   
                                 (1, 1, 16, 1),                                                   
                                 (1, 1, 16, 1),                                                   
                                 (1, 1, 16, 1),                                                   
                                 (1, 1, 16, 1),                                                   
          

 conv2d_transpose_80 (Conv2DTra  (1, 1, 32, 1)       3           ['tf.split_3[0][16]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_81 (Conv2DTra  (1, 1, 32, 1)       3           ['tf.split_3[0][17]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_82 (Conv2DTra  (1, 1, 32, 1)       3           ['tf.split_3[0][18]']            
 nspose)                                                                                          
                                                                                                  
 conv2d_transpose_83 (Conv2DTra  (1, 1, 32, 1)       3           ['tf.split_3[0][19]']            
 nspose)  

                                                                                                  
 conv2d_transpose_108 (Conv2DTr  (1, 1, 32, 1)       3           ['tf.split_3[0][44]']            
 anspose)                                                                                         
                                                                                                  
 conv2d_transpose_109 (Conv2DTr  (1, 1, 32, 1)       3           ['tf.split_3[0][45]']            
 anspose)                                                                                         
                                                                                                  
 conv2d_transpose_110 (Conv2DTr  (1, 1, 32, 1)       3           ['tf.split_3[0][46]']            
 anspose)                                                                                         
                                                                                                  
 conv2d_tr

                                                                  'conv2d_transpose_81[0][0]',    
                                                                  'conv2d_transpose_82[0][0]',    
                                                                  'conv2d_transpose_83[0][0]',    
                                                                  'conv2d_transpose_84[0][0]',    
                                                                  'conv2d_transpose_85[0][0]',    
                                                                  'conv2d_transpose_86[0][0]',    
                                                                  'conv2d_transpose_87[0][0]',    
                                                                  'conv2d_transpose_88[0][0]',    
                                                                  'conv2d_transpose_89[0][0]',    
                                                                  'conv2d_transpose_90[0][0]',    
          

 )                                                               0]',                             
                                                                  'tf.__operators__.getitem_26[0][
                                                                 0]']                             
                                                                                                  
 tf.nn.relu_13 (TFOpLambda)     (1, 1, 32, 64)       0           ['convp0BatchNorm[0][0]']        
                                                                                                  
 tf.nn.relu_12 (TFOpLambda)     (1, 1, 32, 64)       0           ['convt1BatchNorm[0][0]']        
                                                                                                  
 tf.reshape_6 (TFOpLambda)      (1, 1, 960)          0           ['tf.compat.v1.transpose_5[0][0]'
                                                                 , 'tf.__operators__.getitem_24[0]
          

                                                                  'tf.__operators__.getitem_22[0][
                                                                 0]']                             
                                                                                                  
 tf.reshape_8 (TFOpLambda)      (1, 1, 5, 2, 96)     0           ['tf.__operators__.add_6[0][0]', 
                                                                  'tf.__operators__.getitem_28[0][
                                                                 0]',                             
                                                                  'tf.__operators__.getitem_29[0][
                                                                 0]']                             
                                                                                                  
 tf.math.sigmoid (TFOpLambda)   (1, 1, 32)           0           ['tf.reshape_5[0][0]']           
          

# Pre-train weight

In [18]:
path = "./weights2_6/"
h5_lists = os.listdir(path)
h5_lists.sort(key=lambda fn:os.path.getmtime(path + fn)
                if not os.path.isdir(path + fn) else 0)
new_file = os.path.join(path, h5_lists[-1])
print(new_file)

./weights2_6/weights.07-56.91.h5


In [19]:
# store weights before loading pre-trained weights
preloaded_layers = model.layers.copy()
preloaded_weights = []
for pre in preloaded_layers:
    preloaded_weights.append(pre.get_weights())

# load pre-trained weights
model.load_weights(new_file, by_name=True)

# compare previews weights vs loaded weights
for layer, pre in zip(model.layers, preloaded_weights):
    weights = layer.get_weights()

    if weights:
        if np.array_equal(weights, pre):
            print('not loaded', layer.name)
        else:
#             print('loaded', layer.name)
            pass

  a1, a2 = asarray(a1), asarray(a2)


In [20]:
# for i in range(len(model.layers)):
#     if model.layers[i].name[:4] == 'GGRU': # or model.layers[i].name[:5] == 'batch':
#         print(model.layers[i].name)
#         print(i)
# #         [gamma, beta, mov_mean, mov_var]=model.layers[i].get_weights()
# #         print('gamma min, max: {}, {}'.format(np.min(gamma),np.max(gamma)))
# #         print('beta min, max: {}, {}'.format(np.min(beta),np.max(beta)))
# #         print('mov_mean min, max: {}, {}'.format(np.min(mov_mean),np.max(mov_mean)))
# #         print('mov_var min, max: {}, {}'.format(np.min(mov_var),np.max(mov_var))) 
#         print()

# Load

In [21]:
# if p.mask_only: 
#     inputs = [erb_inputs, spec_inputs]
#     outputs = [mask_out]
# else: 
#     inputs = [erb_inputs, spec_inputs]
#     outputs = [mask_out, df_alpha, df_coeff]

# model = Model(inputs=inputs, outputs=outputs, name='DfNet')

In [22]:
# from keras_flops import get_flops
# flops = get_flops(model, batch_size=1)
# print(f"FLOPs: {flops / 10 ** 9:.03} G")

In [23]:
# inputs_all = [erb_inputs, spec_inputs, state, state1, state2]
# outputs_all = [mask_out, df_alpha, df_coeff, lstate_o, lstate1_o, lstate2_o]
# model2=Model(inputs=inputs, outputs=outputs, name='DfNet2')

# Norm

In [24]:
def band_unit_norm(x, s, alpha=0.99):
    s = np.linalg.norm(x) * (1-alpha) + s * alpha
    x = x / (np.sqrt(s)+1e-12)
    return x, s
def band_mean_norm_erb(x, s, alpha=0.99):
    s = x * (1-alpha) + s * alpha
    x = (x-s) / 40
    return x, s

def erb_norm(x, mean_init=[-60.0,-90.0]):
    # x : [T,F,C]
    # state: [F,C]
    x = np.reshape(x,(x.shape[0],x.shape[1],1))
    shape = x.shape
    state = np.linspace(mean_init[0],mean_init[1],shape[-2]) # [F,]
    state *= 0
    state = np.reshape(state, (1, shape[-2]))
    state = np.tile(state, (1, shape[-1])) # [C,F]
    
    x_i_list = []
    state_list = []
    
    for i in range(shape[-1]):
        x_i = np.split(x, shape[-1], axis=-1) 
        state_i = np.split(state, shape[-1], axis=-1)
        
        x_ij_list = []
        state_tmp = state_i[i]
        for j in range(shape[-3]):
            x_ij = np.split(x_i[i], shape[0], axis=0) 
            x_tmp, state_tmp = band_mean_norm_erb(np.squeeze(x_ij[j],-1), state_tmp)
            x_ij_list.append(x_tmp)
        
        x_i_list.append(np.stack(x_ij_list,1))
        state_list.append(state_tmp)
    x = np.squeeze(np.stack(x_i_list,-1),0)
    return x

def unit_norm(x, unit_init=[0.001, 0.0001]):
    # x : [T,F,C]
    # state: [F,C]
    x = np.reshape(x,(x.shape[0],x.shape[1],1))
    shape = x.shape
    state = np.linspace(unit_init[0],unit_init[1],shape[-2]) # [F,]
    state *= 0
    state = np.reshape(state, (1, shape[-2]))
    state = np.tile(state, (1, shape[-1])) # [C,F]
    
    x_i_list = []
    state_list = []
    
    for i in range(shape[-1]):
        x_i = np.split(x, shape[-1], axis=-1) 
        state_i = np.split(state, shape[-1], axis=-1)
        
        x_ij_list = []
        state_tmp = state_i[i]
        for j in range(shape[-3]):
            x_ij = np.split(x_i[i], shape[0], axis=0) 
            x_tmp, state_tmp = band_unit_norm(np.squeeze(x_ij[j],-1), state_tmp)
            x_ij_list.append(x_tmp)
        
        x_i_list.append(np.stack(x_ij_list,1))
        state_list.append(state_tmp)
    
    x = np.squeeze(np.stack(x_i_list,-1))
#     print('done spec')
    return x

In [25]:
ERBB = ERBBand(N=p.nb_erb, high_lim=p.sr//2, NFFT=p.fft_size)
ERB_Matrix = ERB_pro_matrix(ERBB, NFFT=p.fft_size, mode=0)
iERB_Matrix = ERB_pro_matrix(ERBB, NFFT=p.fft_size, mode=1)

In [26]:
import librosa
from soundfile import write
import matplotlib.pyplot as plt

In [27]:
# import tensorflow as tf
def vorbis_window(FRAME_SIZE, transpose=True):
    FRAME_SIZE = FRAME_SIZE//2
    win = np.zeros((FRAME_SIZE,))
    for i in range(FRAME_SIZE):
        win[i] = np.sin(.5*np.pi*np.sin(.5*np.pi*(i+.5)/FRAME_SIZE) * np.sin(.5*np.pi*(i+.5)/FRAME_SIZE))

    win = np.concatenate((win,np.flip(win)),0)
    if transpose: win = win.T
    return win

def analysis_frame(x, nfft=p.fft_size, hop=p.hop_size, normalize=False):
    length = len(x)
    n_frames = length // hop
    out = np.empty((n_frames, p.fft_size//2+1),dtype=complex)
    if not length % hop == 0:
        x = np.pad(x,(0, nfft - length%hop))
    for frame_idx in range(0, n_frames * hop, hop):
        frame = x[frame_idx : frame_idx + nfft]
        if len(frame)<nfft: frame = np.pad(frame,(0,nfft-len(frame)))
        win = vorbis_window(nfft)
        frame = frame.reshape(win.shape)
        frame_win = np.multiply(frame, win)
        x_fft = np.fft.rfft(frame_win, n=p.fft_size) 
        if normalize: x_fft *= (p.fft_size ** -0.5)
        out[frame_idx//hop,:] = x_fft
    return out 

def synthesis_frame(x, nfft=p.fft_size, hop=p.hop_size, windowing=False, normalize=False):
    n_frames, _ = x.shape
    length = (n_frames+1)*hop
    out = np.zeros((length,))

    win = vorbis_window(nfft, transpose=False)
    if normalize: x /= (p.fft_size ** -0.5)
    for frame_idx in range(0, n_frames * hop, hop):
        frame = np.fft.irfft(x[frame_idx//hop] , n=p.fft_size)
        
        if windowing:
            frame = frame.reshape(win.shape)
            out[frame_idx : frame_idx + nfft] += np.multiply(frame, win)
        else:
            out[frame_idx : frame_idx + nfft] += frame
    return out

In [28]:
process_path = '/home/myhsueh/DeepFilterNet/wav/proc/'
audiofile = [ 
#               '/home/myhsueh/DeepFilterNet/wav/orig/clean/p232_001.wav',
#               '/home/myhsueh/DeepFilterNet/wav/orig/clean/p232_002.wav',
#               '/home/myhsueh/DeepFilterNet/wav/orig/clean/p232_003.wav',
#               '/home/myhsueh/DeepFilterNet/wav/orig/clean/p232_005.wav',
#               '/home/myhsueh/DeepFilterNet/wav/orig/clean/p232_006.wav',
#               '/home/myhsueh/DeepFilterNet/wav/orig/clean/p232_007.wav',
#               '/home/myhsueh/DeepFilterNet/wav/orig/clean/p232_009.wav',
#               '/home/myhsueh/DeepFilterNet/wav/orig/clean/p232_010.wav',
        
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_sample3_01.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_sample1_01.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_sample2_01.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_sample6_01.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_sample7_01.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_sample9_01.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_sample11_01.wav',
              
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_testset/ms_realrec_english_usbmicrophone_APPND34J4EDS3_typing_near_fileid_5.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_testset/ms_realrec_english_speaker_A33BAVCUPSMTWJ_Typing_near_fileid_11.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_testset/ms_realrec_english_openspeaker_ASEW6NZHLI41K_typing_far_fileid_9.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_testset/ms_realrec_english_openspeaker_ASEW6NZHLI41K_openingchipspacket_near_fileid_1.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_testset/ms_realrec_english_openspeaker_A338ZXK723N9UH_dogbarking_near_fileid_si2192.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_testset/ms_realrec_english_openspeaker_A338ZXK723N9UH_babycrying_far_fileid_si1687.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_testset/ms_realrec_english_openspeaker_A338ZXK723N9UH_ArConditioner_near_fileid_si2313.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_testset/ms_realrec_english_laptopmic_A9AP3ED0K8LS8_heavybreathing_near_fileid_6.wav',
              '/home/myhsueh/DeepFilterNet/wav/orig/DNS_testset/ms_realrec_emotional_Desktopstandmic_AHLK9SWDJHBBZ_clatternoise_near_crying_fileid_8.wav',
    
#               '/home/myhsueh/DeepFilterNet/wav/orig/conference_01.wav',
#               '/home/myhsueh/DeepFilterNet/wav/orig/02-V4K-220511_record.wav',
#               '/home/myhsueh/DeepFilterNet/wav/orig/enc_01.wav',
            ]

In [29]:
def process(audiofile, process_path='/home/myhsueh/DeepFilterNet/wav/proc/dfn2/'):
    y, sr = librosa.load(audiofile, sr=None)
    
    new_filename = os.path.basename(audiofile)
    new_filename = new_filename.replace('.wav','dfn'+str(buffer_frame)+'.wav')
    
    Y = analysis_frame(y, normalize=True)
    
    Y_erb = (np.real(Y)**2 + np.imag(Y)**2) @ ERB_Matrix
    Y_erb_norm = erb_norm(np.sqrt(Y_erb))
    Y_spec_norm = as_real(unit_norm(Y[:,:p.nb_df]))
                            
    buffer = np.zeros((buffer_frame,p.fft_size//2+1), dtype=complex)
    buffer_erb = np.zeros((buffer_frame+2,32,1))
    buffer_spec = np.zeros((buffer_frame+2,96,2))
    Z = np.zeros_like(Y, dtype=complex)
    
    state0, state10, state11, state20, state21 = np.zeros((1,256)), np.zeros((1,256)), np.zeros((1,256)), \
                                                 np.zeros((1,256)), np.zeros((1,256))
    rstate0, rstate10, rstate11, rstate20, rstate21 = [],[],[],[],[]
    
    for i in range(Y.shape[0]-1):
        buffer = np.roll(buffer, -1, axis=0)
        buffer[-1] = Y[i]
        if buffer_frame==1:
            buffer[0] = Y[i]
        
        buffer_erb = np.roll(buffer_erb, -1, axis=0)
        buffer_erb[-1] = Y_erb_norm[i]
        
        buffer_spec = np.roll(buffer_spec, -1, axis=0)
        buffer_spec[-1] = Y_spec_norm[i]
        
        if stateful:
            buffer_feature = [np.expand_dims(buffer_erb,0), np.expand_dims(buffer_spec,0)]
            mask_out, df_alpha, df_coeff = model.predict(buffer_feature, verbose=0)
        else:
            buffer_feature = [np.expand_dims(buffer_erb,0), np.expand_dims(buffer_spec,0),
                         state0, [state10, state11], [state20, state21]]
            
            rstate0.append(state0)
            rstate10.append(state10)
            rstate11.append(state11)
            rstate20.append(state20)
            rstate21.append(state21)

            mask_out, df_alpha, df_coeff, \
            state0, [state10, state11], [state20, state21] = model.predict(buffer_feature, verbose=0)
            
            state0 = np.reshape(state0, (1,256))
            state10 = np.reshape(state10, (1,256))
            state11 = np.reshape(state11, (1,256))
            state20 = np.reshape(state20, (1,256))
            state21 = np.reshape(state21, (1,256))
        
        buffer_real = np.concatenate((np.expand_dims(np.real(buffer),-1),np.expand_dims(np.imag(buffer),-1)),axis=-1)
        buffer_spec_proc = mask_operations(buffer_real, mask_out, training=False) # mask gain
        enhanced = df_operations(buffer_spec_proc, df_coeff, df_alpha) # deep filter
        
        enhanced = np.squeeze(enhanced)
        enhanced_complex = np.squeeze(enhanced[..., 0]+1j*enhanced[..., 1])
        
        if buffer_frame==1:
            Z[i,:] = enhanced_complex
        else:
            Z[i,:] = enhanced_complex[-1]
        
    z = synthesis_frame(Z, windowing=True, normalize=True)
    if not os.path.exists(process_path): os.makedirs(process_path)
    y = np.reshape(y,(y.shape[0],1))
    z = np.reshape(z,(z.shape[0],1))
    if y.shape[0]>z.shape[0]:
        z = np.concatenate((np.zeros((y.shape[0]-z.shape[0],1)),z),axis=0)
    if y.shape[0]<z.shape[0]:
        y = np.concatenate((np.zeros((z.shape[0]-y.shape[0],1)),y),axis=0)
    write(os.path.join(process_path,new_filename), np.concatenate((y,z),axis=1), 48000)
    
    plt.figure(figsize=(10,5))
    plt.plot(y, color = 'b')
    plt.plot(z, color = 'r')
    plt.legend(['unproc','proc'])
    plt.show()
    
    return Y.shape[0], \
           Y_erb_norm, \
           Y_spec_norm, \
           np.concatenate(rstate0, 0),\
           np.concatenate(rstate10, 0),\
           np.concatenate(rstate11, 0),\
           np.concatenate(rstate20, 0),\
           np.concatenate(rstate21, 0)

In [30]:
# import time
# model.reset_states()  
# buffer_frame=1
    
# r_erb_list, r_spec_list = [],[]
# r_state_list, r_state10_list, r_state11_list, r_state20_list, r_state21_list = [],[],[],[],[]
# buffer_frame = 1
# for file in audiofile:
#     start = time.time()
#     num_frame, r_erb, r_spec, r_state, r_state10, r_state11, r_state20, r_state21 = process(file)
#     end = time.time()
#     print((end-start)/num_frame)
    
#     r_erb_list.append(r_erb)
#     r_spec_list.append(r_spec)
#     r_state_list.append(r_state)
#     r_state10_list.append(r_state10)
#     r_state11_list.append(r_state11)
#     r_state20_list.append(r_state20)
#     r_state21_list.append(r_state21)

# np.save('./quantize_npy/whole_erb.npy', np.concatenate(r_erb_list,0))
# np.save('./quantize_npy/whole_spec.npy',np.concatenate(r_spec_list,0))
# np.save('./quantize_npy/whole_state.npy',np.concatenate(r_state_list,0))
# np.save('./quantize_npy/whole_state10.npy',np.concatenate(r_state10_list,0))
# np.save('./quantize_npy/whole_state11.npy',np.concatenate(r_state11_list,0))
# np.save('./quantize_npy/whole_state20.npy',np.concatenate(r_state20_list,0))
# np.save('./quantize_npy/whole_state21.npy',np.concatenate(r_state21_list,0))

In [31]:
# def process_record(audiofile):
#     y, sr = librosa.load(audiofile, sr=None)
    
#     Y = analysis_frame(y, normalize=True)
    
#     Y_erb = (np.real(Y)**2 + np.imag(Y)**2) @ ERB_Matrix
#     Y_erb_norm = erb_norm(np.sqrt(Y_erb))
#     Y_spec_norm = as_real(unit_norm(Y[:,:p.nb_df]))
                            
#     buffer = np.zeros((buffer_frame,p.fft_size//2+1), dtype=complex)
#     buffer_erb = np.zeros((buffer_frame+2,32,1))
#     buffer_spec = np.zeros((buffer_frame+2,96,2))
    
#     state0, state10, state11, state20, state21 = np.zeros((1,256)), np.zeros((1,256)), np.zeros((1,256)), \
#                                                  np.zeros((1,256)), np.zeros((1,256))
#     rstate0, rstate10, rstate11, rstate20, rstate21 = [],[],[],[],[]
    
# #     inputs_all = [erb_inputs, spec_inputs, state, state1, state2]
# #     outputs_all = [mask_out, df_alpha, df_coeff, lstate_o, lstate1_o, lstate2_o]

#     for i in range(Y.shape[0]-1):
#         buffer[0] = Y[i]
        
#         buffer_erb = np.roll(buffer_erb, -1, axis=0)
#         buffer_erb[-1] = Y_erb_norm[i]
        
#         buffer_spec = np.roll(buffer_spec, -1, axis=0)
#         buffer_spec[-1] = Y_spec_norm[i]
        
#         buffer_feature = [np.expand_dims(buffer_erb,0), np.expand_dims(buffer_spec,0),
#                          state0, [state10, state11], [state20, state21]]
        
#         mask_out, df_alpha, df_coeff, \
#         state0, [state10, state11], [state20, state21] = model.predict(buffer_feature, verbose=0)
        
#         state0 = np.reshape(state0, (1,256))
#         state10 = np.reshape(state10, (1,256))
#         state11 = np.reshape(state11, (1,256))
#         state20 = np.reshape(state20, (1,256))
#         state21 = np.reshape(state21, (1,256))
        
#         rstate0.append(state0)
#         rstate10.append(state10)
#         rstate11.append(state11)
#         rstate20.append(state20)
#         rstate21.append(state21)
    
#     return Y_erb_norm, \
#            Y_spec_norm, \
#            np.concatenate(rstate0, 0),\
#            np.concatenate(rstate10, 0),\
#            np.concatenate(rstate11, 0),\
#            np.concatenate(rstate20, 0),\
#            np.concatenate(rstate21, 0)

In [32]:
# r_erb_list, r_spec_list = [],[]
# r_state_list, r_state10_list, r_state11_list, r_state20_list, r_state21_list = [],[],[],[],[]
# buffer_frame = 1
# for file in audiofile:
#     r_erb, r_spec, r_state, r_state10, r_state11, r_state20, r_state21 = process_record(file)
    
    
#     r_erb_list.append(r_erb)
#     r_spec_list.append(r_spec)
#     r_state_list.append(r_state)
#     r_state10_list.append(r_state10)
#     r_state11_list.append(r_state11)
#     r_state20_list.append(r_state20)
#     r_state21_list.append(r_state21)

# np.save('./quantize_npy/whole_erb.npy', np.concatenate(r_erb_list,0))
# np.save('./quantize_npy/whole_spec.npy',np.concatenate(r_spec_list,0))
# np.save('./quantize_npy/whole_state.npy',np.concatenate(r_state_list,0))
# np.save('./quantize_npy/whole_state10.npy',np.concatenate(r_state10_list,0))
# np.save('./quantize_npy/whole_state11.npy',np.concatenate(r_state11_list,0))
# np.save('./quantize_npy/whole_state20.npy',np.concatenate(r_state20_list,0))
# np.save('./quantize_npy/whole_state21.npy',np.concatenate(r_state21_list,0))

# FP16

In [33]:
# converter = tf.lite.TFLiteConverter.from_keras_model(model)
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
# converter.target_spec.supported_types = [tf.float16]
# # Set the input and output tensors to uint8 (APIs added in r2.3)
# converter.inference_input_type = tf.float32
# converter.inference_output_type = tf.float32

# tflite_model_quant = converter.convert()

In [34]:
# interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
# input_type = interpreter.get_input_details()[0]['dtype']
# print('input: ', input_type)
# output_type = interpreter.get_output_details()[0]['dtype']
# print('output: ', output_type)

In [35]:
# with open('quantized_FP16_wholemodel.tflite', 'wb') as f:
#     f.write(tflite_model_quant)

In [36]:
# for i in range(len(interpreter.get_input_details())):
#     print(interpreter.get_input_details()[i])

# INT8

In [37]:
quant_erb = np.load('./quantize_npy/whole_erb.npy')
quant_spec = np.load('./quantize_npy/whole_spec.npy')
quant_state0 = np.load('./quantize_npy/whole_state.npy')
quant_state10 = np.load('./quantize_npy/whole_state10.npy')
quant_state11 = np.load('./quantize_npy/whole_state11.npy')
quant_state20 = np.load('./quantize_npy/whole_state20.npy')
quant_state21 = np.load('./quantize_npy/whole_state21.npy')

In [38]:
quant_erb = np.concatenate((np.zeros((2,32,1)),quant_erb),axis=0)
quant_spec = np.concatenate((np.zeros((2,96,2)),quant_spec),axis=0)

In [39]:
range_list = [i for i in range(quant_state21.shape[0])]
random.shuffle(range_list)

In [40]:
def representative_data_gen():
    for j in range_list:
        yield [
               quant_spec[j:j+3].reshape(1,3,96,2).astype(np.float32), 
               quant_state11[j].reshape(1,256).astype(np.float32), 
               quant_state10[j].reshape(1,256).astype(np.float32), 
               quant_state21[j].reshape(1,256).astype(np.float32), 
               quant_state20[j].reshape(1,256).astype(np.float32),
               quant_erb[j:j+3].reshape(1,3,32,1).astype(np.float32), 
               quant_state0[j].reshape(1,256).astype(np.float32)
        ]

converter = tf.lite.TFLiteConverter.from_keras_model(model)

converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8] 
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.float32
converter.inference_output_type = tf.float32

tflite_model_quant = converter.convert()



INFO:tensorflow:Assets written to: /tmp/tmpg3k0no71/assets


INFO:tensorflow:Assets written to: /tmp/tmpg3k0no71/assets


In [41]:
with open('quantized_INT8_wholemodel.tflite', 'wb') as f:
    f.write(tflite_model_quant)

In [42]:
if time.localtime(time.time()).tm_hour>=18:
    %run Lite_inference.ipynb

NameError: name 'time' is not defined

In [None]:
#debugger
debugger = tf.lite.experimental.QuantizationDebugger(
    converter=converter,
    debug_dataset=representative_data_gen)
debugger.run()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
root_path = './quantize_npy'
# dump debug results as csv file
RESULTS_FILE = os.path.join(root_path, 'debugger_result_1000_int8_whole.csv')
with open(RESULTS_FILE, 'w') as f:
    debugger.layer_statistics_dump(f)

# compute rmse error
layer_stats = pd.read_csv(RESULTS_FILE)
scale_range = 255.0 * layer_stats['scale']
rmse_scale = layer_stats.apply(
    lambda row: np.sqrt(row['mean_squared_error']) / row['scale'], axis=1)
#display(HTML(layer_stats.to_html()))

# insert to the dataframe and re-dump back to file
layer_stats.insert(loc=8, column='range', value=scale_range)
layer_stats.insert(loc=10, column='rmse/scale', value=rmse_scale)

# dump again
layer_stats.to_csv(RESULTS_FILE, index=False)

# plot rmse error
plt.figure(figsize=(15, 10))
plt.style.use('default')
ax1 = plt.subplot(211)
ax1.bar(np.arange(len(layer_stats)), scale_range)
ax1.set_ylabel('range')
plt.grid()
ax2 = plt.subplot(212)
ax2.bar(np.arange(len(layer_stats)), rmse_scale)
ax2.plot(np.arange(len(layer_stats)+2), [0.289]*(len(layer_stats)+2),
         color='r', linewidth=5)
t = ax2.text(0.05, 0.73, '0.289', transform=ax2.transAxes, fontsize=20, color='red')
t.set_bbox(dict(facecolor='white', alpha=0.01))
ax2.set_ylabel('rmse/scale')
plt.grid()

fig_save_path = os.path.join(root_path, 'rmse_1000_ep35_wh.png')
plt.savefig(fig_save_path, dpi=100,
            bbox_inches='tight')
plt.show()


In [None]:
import h5py
keys = []
with h5py.File(new_file,'r') as f: # open file
    f.visit(keys.append) # append all keys to list
    for key in keys:
        if ':' in key: # contains data if ':' in key
            print(f[key].name)
    
f = h5py.File(new_file,'r')
group = f[key]

In [None]:
def get_param(folding_layer, weight_name):
    layer = '/' + folding_layer + '/' + folding_layer + '/'
    if folding_layer == 'GGRU00': weight_name = 'gru_cell/'+ weight_name
    if folding_layer == 'GGRU010': weight_name = 'gru_cell_1/'+ weight_name
    if folding_layer == 'GGRU011': weight_name = 'gru_cell_2/'+ weight_name
    if folding_layer == 'GGRU10': weight_name = 'gru_cell_3/'+ weight_name
    if folding_layer == 'GGRU11': weight_name = 'gru_cell_4/'+ weight_name
    return group[layer+weight_name+':0'][:]

In [None]:
for layer in model.layers:
    if layer.name[:11] == 'GFC_decoder':
        try:
            kernel_weight = get_param(layer.name, 'kernel')
            bias = get_param(layer.name, 'bias')
            plt.subplot(1,2,1)
            plt.hist(kernel_weight.flatten())
            plt.title(layer.name)
            plt.subplot(1,2,2)
            plt.hist(bias.flatten())
            plt.title(layer.name + 'bias')
            plt.show()
        except:
            pass