In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_io as tfio

import numpy as np
import random
import os
import glob
import multiprocessing
from math import ceil

from tensorflow.keras.layers import Input , Dense, Lambda, Normalization

In [None]:
import import_ipynb
try:
    from bandERB import ERBBand, ERB_pro_matrix
    from loss import as_complex, as_real
except:
    from bandERB import ERBBand, ERB_pro_matrix
    from loss import as_complex, as_real

In [None]:
from params import model_params
# Module-level configuration read from config.ini. Note that parser() and
# read_tfrecod_data() each load their own copy from the same file instead of
# using this object.
p = model_params('config.ini')

In [None]:
def band_mean_norm_erb(x, s, alpha=0.99):
    """Recurrently mean-normalise one frame of log-ERB features.

    The running mean is an exponential moving average,
    ``new_state = (1 - alpha) * x + alpha * s``. The frame is centred on that
    mean and divided by a fixed scale of 40.

    Args:
        x: current frame (scalar, array or tensor).
        s: running-mean state, broadcast-compatible with ``x``.
        alpha: smoothing factor; larger means a slower-moving mean.

    Returns:
        ``(normalised_frame, new_state)``.
    """
    running_mean = alpha * s + (1 - alpha) * x
    return (x - running_mean) / 40.0, running_mean

In [None]:
def band_unit_norm(x, s, alpha=0.99):
    """Recurrently unit-normalise one frame of a complex spectrum.

    For every bin, the state is an exponential moving average of that bin's
    magnitude, ``s <- |x| * (1 - alpha) + s * alpha``. The frame is then
    divided by ``sqrt(s) + 1e-12``; the epsilon guards against a zero state.

    Args:
        x: complex tensor holding one frame (``unit_norm`` passes [1, F]).
        s: real state tensor, broadcast-compatible with ``x``
            (``unit_norm`` seeds it with a per-bin linspace).
        alpha: smoothing factor of the moving average.

    Returns:
        ``(normalised_frame, new_state)``.
    """
    # tf.abs gives the element-wise magnitude |x| as a real tensor.
    # tf.linalg.norm(x) with its default axis=None reduced the whole frame to
    # one scalar, so every bin received the same update.
    s = tf.abs(x) * (1 - alpha) + s * alpha
    x = x / tf.complex(tf.sqrt(s) + 1e-12, 0.0)
    return x, s

In [None]:
def _recurrent_band_norm(x, init, band_fn):
    """Apply a stateful per-frame normaliser along the time axis of ``x``.

    Args:
        x: rank-3 tensor with static shape [T, F, C]. Axis 0 is walked frame
            by frame; every index of the last axis is an independent channel.
        init: pair ``(first, last)``. Each channel's initial state is a
            [1, F] tensor linearly spaced from ``first`` to ``last``.
        band_fn: callable ``(frame, state) -> (normalised_frame, new_state)``.
            It is applied to each [1, F] frame in time order, and the state
            is threaded from one frame to the next.

    Returns:
        Tensor of shape [T, F, C] holding the normalised frames.

    Raises:
        ValueError: if ``x`` is not rank 3.
    """
    shape = x.get_shape().as_list()
    if len(shape) != 3:
        raise ValueError('expected a rank-3 [T, F, C] tensor, got shape {}'.format(shape))
    n_frames, n_bands, _ = shape

    init_state = tf.reshape(tf.linspace(init[0], init[1], n_bands), (1, n_bands))

    # Python loops are unrolled when traced. Splitting each channel once keeps
    # the graph at O(T) ops. The original re-split the tensor on every frame,
    # which produced O(T^2) ops.
    normed_channels = []
    for channel in tf.unstack(x, axis=-1):                 # C tensors of [T, F]
        state = init_state
        normed_frames = []
        for frame in tf.split(channel, n_frames, axis=0):  # T tensors of [1, F]
            normed, state = band_fn(frame, state)
            normed_frames.append(normed)
        normed_channels.append(tf.concat(normed_frames, axis=0))  # [T, F]
    return tf.stack(normed_channels, axis=-1)                     # [T, F, C]


def erb_norm(x, mean_init=(-60.0, -90.0)):
    """Mean-normalise log-ERB features over time with ``band_mean_norm_erb``.

    Args:
        x: real tensor with static shape [T, F, C].
        mean_init: initial running mean for the first and last band; the
            bands in between are linearly interpolated.

    Returns:
        Tensor of shape [T, F, C]. The final state is not returned.
    """
    return _recurrent_band_norm(x, mean_init, band_mean_norm_erb)


def unit_norm(x, unit_init=(0.001, 0.0001)):
    """Unit-normalise a complex spectrum over time with ``band_unit_norm``.

    Args:
        x: complex tensor with static shape [T, F, C].
        unit_init: initial state for the first and last bin; the bins in
            between are linearly interpolated.

    Returns:
        Complex tensor of shape [T, F, C]. The final state is not returned.
    """
    return _recurrent_band_norm(x, unit_init, band_unit_norm)

In [None]:
def biquad(x, mem=None, b=(-2, 1), a=(-1.99599, 0.99600)):
    """Filter a 1-D signal with a biquad in transposed direct form II.

    The leading coefficients are fixed at ``b0 = a0 = 1``, so for each sample::

        y[n] = x[n] + m0
        m0   = m1 + b[0] * x[n] - a[0] * y[n]
        m1   =      b[1] * x[n] - a[1] * y[n]

    Args:
        x: 1-D tensor with a static length; it is iterated sample by sample
            in Python, so the loop is unrolled when traced.
        mem: optional mutable 2-element list holding the filter state. If
            given, it is updated in place so a caller can carry state across
            calls. If omitted, each call starts from a fresh zero state; the
            old ``mem=[0, 0]`` default was a single shared list, so state
            leaked between calls.
        b: feed-forward coefficients (b1, b2).
        a: feedback coefficients (a1, a2).

    Returns:
        The filtered signal as a tensor (``tf.squeeze`` of the stacked samples).
    """
    if mem is None:
        mem = [0, 0]
    n_samples = x.get_shape().as_list()[0]
    outputs = []
    for i in range(n_samples):
        x_i = x[i]
        y_i = x_i + mem[0]
        mem[0] = mem[1] + (b[0] * x_i - a[0] * y_i)
        mem[1] = b[1] * x_i - a[1] * y_i
        outputs.append(y_i)
    return tf.squeeze(tf.stack(outputs, -1))

def rand_H():
    """Draw random second-order filter coefficients for ``biquad``.

    Four numbers are drawn in turn from Python's ``random`` module. Each is
    mapped through ``(u - 0.5) * 0.75``, so it lies in [-0.375, 0.375). The
    first two become ``a`` (feedback terms in ``biquad``) and the last two
    become ``b`` (feed-forward terms).

    Returns:
        ``(a, b)``: two float64 numpy arrays of length 2.
    """
    coeffs = np.array([(random.random() - 0.5) * 0.75 for _ in range(4)],
                      dtype=np.float64)
    a, b = coeffs[:2].copy(), coeffs[2:].copy()
    return a, b

In [None]:
def mic_FR(x, y):
    """Apply the same simulated microphone frequency response to a signal pair.

    Both signals pass first through ``biquad``'s default filter and then
    through one random biquad drawn by ``rand_H``. The random filter is shared
    by ``x`` and ``y``. Every ``biquad`` call gets its own zeroed state list,
    so no filter state carries over from ``x`` to ``y`` or from earlier calls.

    Args:
        x, y: 1-D signal tensors (e.g. noisy and clean waveforms).

    Returns:
        ``(x_filtered, y_filtered)``.
    """
    x = biquad(x, mem=[0, 0])
    y = biquad(y, mem=[0, 0])
    a, b = rand_H()
    x = biquad(x, mem=[0, 0], b=b, a=a)
    y = biquad(y, mem=[0, 0], b=b, a=a)
    return x, y

In [None]:
def mic_FR_mono(x):
    """Single-signal version of ``mic_FR``.

    The signal passes through ``biquad``'s default filter and then through a
    random biquad from ``rand_H``. Each ``biquad`` call starts from its own
    zeroed state list, so no state leaks in from earlier calls.

    Args:
        x: 1-D signal tensor.

    Returns:
        The filtered signal.
    """
    x = biquad(x, mem=[0, 0])
    a, b = rand_H()
    return biquad(x, mem=[0, 0], b=b, a=a)

In [None]:
def _stft(signal, window, fft_size, hop_size, normalize, name):
    """Return the one-sided FFT of ``signal`` after end-padded framing and windowing.

    Args:
        signal: 1-D waveform tensor.
        window: [1, fft_size] window, broadcast over the frames.
        fft_size: frame length and FFT length in samples.
        hop_size: frame step in samples.
        normalize: if True, scale the spectrum by ``fft_size ** -0.5``.
        name: name given to the rfft op.

    Returns:
        Complex tensor of shape [n_frames, fft_size // 2 + 1].
    """
    frames = tf.signal.frame(signal, fft_size, hop_size, pad_end=True)
    spec = tf.signal.rfft(input_tensor=tf.multiply(frames, window),
                          fft_length=tf.constant([fft_size], dtype=tf.int32),
                          name=name)
    if normalize:
        spec *= fft_size ** -0.5
    return spec


def parser(record, normalize=True):
    """Decode one serialized ``tf.train.Example`` into model features.

    The example must contain two float features, ``'X'`` (noisy) and ``'Y'``
    (clean). Each is ``length_sec * sr`` samples long, using settings read
    from config.ini. Both are converted to windowed, end-padded STFTs.

    Args:
        record: scalar string tensor holding the serialized example.
        normalize: if True, scale both spectra by ``fft_size ** -0.5``.

    Returns:
        A 1-tuple wrapping the feature tuple. With ``p.mask_only`` the inner
        tuple is ``(ERB, CLEAN_SPEC, NOISY_SPEC)``; otherwise it is
        ``(ERB, NOISY_SPEC_df, CLEAN_SPEC, NOISY_SPEC)``. With T frames:
          - ERB: [T, nb_erb, 1], noisy ERB-band power in dB, normalised by
            ``erb_norm``.
          - NOISY_SPEC_df: [T, nb_df, 2], the lowest ``nb_df`` noisy bins
            after ``unit_norm``, converted by ``as_real``.
          - CLEAN_SPEC / NOISY_SPEC: [T, fft_size // 2 + 1, 2], ``as_real``
            of the clean / noisy spectrum.
    """
    p = model_params('config.ini')
    erb_bands = ERBBand(N=p.nb_erb, high_lim=p.sr // 2, NFFT=p.fft_size)
    # Maps power spectra to ERB bands. The matmul and reshape below need it
    # to be [fft_size // 2 + 1, nb_erb].
    erb_matrix = tf.convert_to_tensor(ERB_pro_matrix(erb_bands, NFFT=p.fft_size, mode=0),
                                      dtype=tf.float32)

    n_samples = p.length_sec * p.sr
    n_frames = ceil(n_samples / p.hop_size)  # frame count of tf.signal.frame(pad_end=True)
    n_bins = p.fft_size // 2 + 1

    features = {'X': tf.io.FixedLenFeature([n_samples], tf.float32),
                'Y': tf.io.FixedLenFeature([n_samples], tf.float32)}
    example = tf.io.parse_single_example(record, features)
    x = example['X']
    y = example['Y']
    tf.debugging.check_numerics(x, message='Error number(x)')
    tf.debugging.check_numerics(y, message='Error number(y)')
    # Optional microphone frequency-response augmentation (disabled):
    # x, y = mic_FR(x, y)

    window = tf.reshape(tf.signal.vorbis_window(window_length=p.fft_size), (1, p.fft_size))

    # ---- noisy input ----
    X_fft = _stft(x, window, p.fft_size, p.hop_size, normalize, name='X_fft')
    noisy_power = tf.math.real(X_fft) ** 2 + tf.math.imag(X_fft) ** 2
    ERB = tf.reshape(noisy_power @ erb_matrix, shape=(n_frames, p.nb_erb, 1))
    ERB = 10 * tf.experimental.numpy.log10(ERB + 1e-10)  # dB; offset keeps log10 finite on silence
    ERB = erb_norm(ERB)
    NOISY_SPEC_df = tf.reshape(as_real(unit_norm(tf.expand_dims(X_fft[..., :p.nb_df], -1))),
                               shape=(n_frames, p.nb_df, 2))
    NOISY_SPEC = tf.reshape(as_real(X_fft), shape=(n_frames, n_bins, 2))

    # ---- clean target ----
    Y_fft = _stft(y, window, p.fft_size, p.hop_size, normalize, name='Y_fft')
    CLEAN_SPEC = tf.reshape(as_real(Y_fft), shape=(n_frames, n_bins, 2))

    # Python-side prints; under Dataset.map they fire when the function is
    # traced, not once per element.
    for label, tensor in (('ERB', ERB), ('NOISY_SPEC_df', NOISY_SPEC_df),
                          ('CLEAN_SPEC', CLEAN_SPEC), ('NOISY_SPEC', NOISY_SPEC)):
        print(label + ' shape:', tensor.get_shape(), tensor)
        tf.debugging.check_numerics(tensor, message='Error number({})'.format(label))

    # The trailing commas make each return value a 1-tuple wrapping the
    # feature tuple. Presumably this makes Keras treat the whole tuple as the
    # model input with no separate target; confirm against the training loop
    # before changing it.
    if p.mask_only:
        return (ERB, CLEAN_SPEC, NOISY_SPEC),
    return (ERB, NOISY_SPEC_df, CLEAN_SPEC, NOISY_SPEC),


In [None]:
def _list_tfrecords(root):
    """Return the sorted ``*.tfrecord`` paths directly inside ``root``.

    Also prints whether ``root`` exists, how many files were found, and the
    file base names, four per line.
    """
    print(root, ", file directory exist: {}".format(os.path.exists(root)))
    # Sorted because glob order depends on the filesystem; sorting keeps runs reproducible.
    filenames = sorted(glob.glob(os.path.join(root, '*.tfrecord')))
    print(len(filenames))
    for idx, filename in enumerate(filenames):
        name = os.path.basename(filename)
        if idx % 4 == 3:
            print(name)
        else:
            print(name, end=', ')
    return filenames


def read_tfrecod_data(TFR_ROOT, training=True, batch_size=None, repeat=None):
    """Build a batched tf.data pipeline from the TFRecord files of one or more directories.

    Each directory's files are read in parallel and mapped through ``parser``.
    The per-directory datasets are concatenated in the order given.

    Args:
        TFR_ROOT: one directory path, or a sequence of directory paths.
        training: if True, shuffle examples (buffer 5000, reshuffled on
            every iteration) before repeating.
        batch_size: batch size; defaults to ``batch_size`` from config.ini.
        repeat: count passed to ``Dataset.repeat``; defaults to ``epochs``
            from config.ini.

    Returns:
        A prefetched ``tf.data.Dataset`` of batched ``parser`` outputs.

    Raises:
        ValueError: if ``TFR_ROOT`` is an empty sequence.
    """
    p = model_params('config.ini')

    # A bare path used to be indexed character by character; wrap it instead.
    if isinstance(TFR_ROOT, (str, os.PathLike)):
        roots = [TFR_ROOT]
    else:
        roots = list(TFR_ROOT)
    if not roots:
        raise ValueError('TFR_ROOT must name at least one directory')

    threads = multiprocessing.cpu_count()
    combined_dataset = None
    for root in roots:
        filenames = _list_tfrecords(root)
        dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=threads)
        dataset = dataset.map(parser, num_parallel_calls=threads)
        combined_dataset = (dataset if combined_dataset is None
                            else combined_dataset.concatenate(dataset))

    if training:
        combined_dataset = combined_dataset.shuffle(buffer_size=5000, reshuffle_each_iteration=True)
    if batch_size is None:
        batch_size = p.batch_size
    if repeat is None:
        repeat = p.epochs
    combined_dataset = combined_dataset.repeat(repeat)
    combined_dataset = combined_dataset.batch(batch_size)
    return combined_dataset.prefetch(tf.data.AUTOTUNE)