In [None]:
!pip install -U tensorflow-probability==0.15.0 
!pip install -U tensorflow-io
!pip install -U tensorflow_model_optimization

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_io as tfio
import os
# Expose only GPU 0 to TensorFlow; this takes effect only if it is set before
# TensorFlow initializes its GPU devices.
os.environ["CUDA_VISIBLE_DEVICES"]="0"


# Report the TensorFlow version and the GPU device in use ('' when none is found).
device_name = tf.test.gpu_device_name()
print(tf.__version__)
print(device_name)
# Execute @tf.function-decorated code eagerly: train_step further down calls
# .numpy() and NumPy functions inside the decorated function, which needs eager mode.
tf.config.run_functions_eagerly(True)

In [None]:
import pickle
import numpy as np
import os
from scipy import signal
import librosa
import datetime
import random
import math
from tensorflow.keras import layers
import time
import inspect


import matplotlib.pyplot as plt
import librosa.display
from scipy.io.wavfile import write

from tensorflow.keras.layers import Conv2D, Input, LeakyReLU, UpSampling2D, Flatten, Dropout, Dense, Reshape, Conv2DTranspose, BatchNormalization, Activation, MaxPooling2D
from tensorflow.keras import Model, Sequential, initializers # Data Generator


In [None]:
# Seed every RNG the pipeline draws from: TensorFlow (initializers, dropout),
# NumPy, and Python's `random` (dataset_gen picks its noise roll offset with
# random.randint, which np.random.seed does not cover).
SEED = 256
tf.random.set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Data Generator

In [None]:
snr_dB = 5
sampling_rate = 4000

def awgn(sinal, noise_signal, snr_db=None):
    """Rescale ``noise_signal`` so it sits ``snr_db`` dB below ``sinal``.

    Despite the name, no Gaussian noise is generated: the given noise recording
    is scaled so that 10*log10(P_signal / P_noise) == snr_db, where P is the
    mean squared absolute amplitude.

    Args:
        sinal: 1-D clean signal (array-like).
        noise_signal: 1-D noise recording (NumPy array).
        snr_db: target SNR in dB; defaults to the module-level ``snr_dB``.

    Returns:
        The rescaled noise, with the same shape as ``noise_signal``. A silent
        (all-zero) noise clip is returned as zeros instead of NaNs.

    Raises:
        ValueError: if either input is empty, or if ``sinal`` has zero power
            (the SNR is undefined for a silent signal).
    """
    if snr_db is None:
        snr_db = snr_dB
    noise = np.asarray(noise_signal)
    sig_mag = np.abs(np.asarray(sinal)).astype(np.float64)
    noise_mag = np.abs(noise).astype(np.float64)
    if sig_mag.size == 0 or noise_mag.size == 0:
        raise ValueError("awgn: signal and noise must be non-empty")

    sig_power = float(np.mean(sig_mag ** 2))
    if not sig_power > 0:
        raise ValueError("awgn: signal has zero power; SNR is undefined")
    noise_rms = float(np.sqrt(np.mean(noise_mag ** 2)))
    if noise_rms == 0:
        # Nothing to scale; dividing by zero here would produce NaNs.
        return np.zeros_like(noise, dtype=np.result_type(noise, 1.0))

    sig_db = 10 * math.log10(sig_power)
    # Amplitude (not power) scale for the noise level snr_db below the signal.
    noise_scale = 10 ** ((sig_db - snr_db) / 20)
    return noise_scale * noise / noise_rms

def get_stft(x, fs, n_fft, hop_length, only_real=True):
    """Short-time Fourier transform of ``x`` computed with librosa.

    ``fs`` is accepted for call-site compatibility and is not used.
    With ``only_real=True`` the magnitude |STFT| is returned (not the real
    part); otherwise the complex STFT is returned as is.
    """
    spectrum = librosa.stft(x, n_fft=n_fft, hop_length=hop_length)
    return np.abs(spectrum) if only_real else spectrum

In [None]:
volunteer_id = 1
# Directory of the code object defining this cell (not used further in this notebook).
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
clean_dir = "../Dataset/Interfered/Volunteer{}".format(volunteer_id)

# Overall-best and per-cycle-best checkpoint directories for this volunteer.
checkpoint_dir = './volunteer_{}_training_checkpoints'.format(volunteer_id)
cycle_checkpoint_dir = './volunteer_{}_cycle_training_checkpoints'.format(volunteer_id)

output_parent_folder_name = "volunteer_{}".format(volunteer_id)

In [None]:
## read the files.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted dataset files.
def _load_dataset_pickle(filename):
    """Unpickle ``filename`` from ``clean_dir``."""
    with open(os.path.join(clean_dir, filename), "rb") as handle:
        return pickle.load(handle)

all_train_mic_time, _, all_train_mic_stft, all_train_imu_stft, all_train_labels = _load_dataset_pickle("target_train.p")
all_test_mic_time, _, all_test_mic_stft, all_test_imu_stft, all_test_labels = _load_dataset_pickle("target_test.p")
all_train_noise_time, _, all_train_noise_labels = _load_dataset_pickle("noise_train.p")
all_test_noise_time, _, all_test_noise_labels = _load_dataset_pickle("noise_test.p")

In [None]:
# One line per group: train target arrays, test target arrays, train noise, test noise.
for shape_group in ((all_train_mic_time, all_train_mic_stft, all_train_imu_stft),
                    (all_test_mic_time, all_test_mic_stft, all_test_imu_stft),
                    (all_train_noise_time,),
                    (all_test_noise_time,)):
    print(*(arr.shape for arr in shape_group))

In [None]:
# read combination files
# Each entry pairs a target-utterance index (ind[0]) with a noise-clip index (ind[1]); see dataset_gen.
def _read_combination_index(filename):
    """Unpickle a combination-index file from ``clean_dir`` (trusted data only)."""
    with open(os.path.join(clean_dir, filename), "rb") as handle:
        return pickle.load(handle)

train_indexes = _read_combination_index("combination_index_train.p")
test_indexes = _read_combination_index("combination_index_test.p")

print(len(train_indexes), len(test_indexes))

In [None]:
def _with_channel_axis(items):
    """Stack a list of 2-D arrays and append a trailing channel axis: (B, H, W) -> (B, H, W, 1)."""
    arr = np.asarray(items)
    return arr.reshape(arr.shape[0], arr.shape[1], arr.shape[2], 1)


def dataset_gen(batch_size, ftype = "train", real_output = True, onlyAudio = False, batch_ind = 0):
    """Build one batch of noisy-audio / IMU / clean-audio examples.

    Each entry ``ind`` of ``train_indexes`` / ``test_indexes`` pairs a target
    utterance, addressed by ``ind[0][0..2]``, with a noise clip, addressed by
    ``ind[1][0..1]``. The noise is rescaled to ``snr_dB`` below the target
    (``awgn``), circularly shifted by a random offset, and added to the target.

    Args:
        batch_size: number of combinations per batch.
        ftype: "train" selects the training split; any other value selects test.
        real_output, onlyAudio: accepted for compatibility; unused.
        batch_ind: which ``batch_size``-sized slice of the index list to build.

    Returns:
        Tuple of:
          * complex STFT of the noisy audio, bins 1..200, shape (B, 200, T, 1);
          * |IMU STFT|, bins 1..20, shape (B, 20, T, 1);
          * |clean mic STFT|, bins 1..200, shape (B, 200, T, 1);
          * labels, one [target label, noise label] pair per example;
          * time-domain clean audio, scaled (rolled) noise, and noisy audio.

    Raises:
        IndexError: if ``batch_ind`` selects no combinations (past the end of the split).
    """
    if ftype == "train":
        all_mic_time = all_train_mic_time
        all_mic_stft = all_train_mic_stft
        all_imu_stft = all_train_imu_stft
        all_noise_time = all_train_noise_time
        all_target_labels = all_train_labels
        all_noise_labels = all_train_noise_labels
        all_indexes = train_indexes
    else:
        all_mic_time = all_test_mic_time
        all_mic_stft = all_test_mic_stft
        all_imu_stft = all_test_imu_stft
        all_noise_time = all_test_noise_time
        all_target_labels = all_test_labels
        all_noise_labels = all_test_noise_labels
        all_indexes = test_indexes

    start = batch_size * batch_ind
    # Iterate the slice directly. The entries appear to be nested pairs of
    # differently sized tuples (ind[0] has >= 3 items, ind[1] >= 2), which
    # np.asarray cannot turn into a regular array on recent NumPy versions.
    indxs = all_indexes[start:start + batch_size]
    if len(indxs) == 0:
        raise IndexError(
            "dataset_gen: batch_ind={} with batch_size={} is past the end of the "
            "'{}' split ({} combinations)".format(batch_ind, batch_size, ftype, len(all_indexes)))

    noisy_arr_stft = []
    target_arr_mic_stft = []
    target_arr_imu_stft = []
    label_arr = []
    noise_arr = []
    target_mic_arr = []
    noisy_mic_arr = []

    for ind in indxs:
        t0, t1, t2 = ind[0][0], ind[0][1], ind[0][2]
        n0, n1 = ind[1][0], ind[1][1]
        temp_noise_time = all_noise_time[n0][n1]
        temp_noise_labels = all_noise_labels[n0][n1]
        temp_mic_time = all_mic_time[t0][t1][t2]
        temp_mic_stft = all_mic_stft[t0][t1][t2]
        temp_imu_stft = all_imu_stft[t0][t1][t2]
        temp_target_labels = all_target_labels[t0][t1][t2]

        scaled_noise = awgn(temp_mic_time, temp_noise_time)
        # Random circular shift in [N//2, N] samples (a shift of N is no shift).
        roll_idx = random.randint(scaled_noise.shape[0]//2, scaled_noise.shape[0])
        scaled_noise_r = np.roll(scaled_noise, roll_idx)
        noisy_time = np.add(temp_mic_time, scaled_noise_r)

        # Complex STFT (n_fft=400 -> 201 bins); drop the DC bin and keep bins 1..200.
        noisy_stft = get_stft(noisy_time, fs=noisy_time.shape[0], n_fft=400, hop_length=200, only_real=False)
        noisy_arr_stft.append(noisy_stft[1:201, :])

        # Targets use the same bin ranges: audio bins 1..200, IMU bins 1..20.
        target_arr_mic_stft.append(temp_mic_stft[1:201, :])
        target_arr_imu_stft.append(temp_imu_stft[1:21, :])

        noisy_mic_arr.append(noisy_time)
        target_mic_arr.append(temp_mic_time)
        noise_arr.append(scaled_noise_r)
        label_arr.append([temp_target_labels, temp_noise_labels])

    noisy_arr_stft = _with_channel_axis(noisy_arr_stft)
    target_arr_mic_stft = _with_channel_axis(target_arr_mic_stft)
    target_arr_imu_stft = _with_channel_axis(target_arr_imu_stft)

    return (noisy_arr_stft, np.abs(target_arr_imu_stft), np.abs(target_arr_mic_stft),
            np.asarray(label_arr), np.asarray(target_mic_arr), np.asarray(noise_arr),
            np.asarray(noisy_mic_arr))

In [None]:
test_batch_size = 100
# A fixed test batch (batch_ind=0); its arrays are the globals used by the evaluation code below.
(noisy_audio, input_imu, clean_audio, labels,
 target_mic_arr, noise_arr, noisy_mic_arr) = dataset_gen(batch_size=test_batch_size, ftype="test")

In [None]:
# Label dictionary for KWS10: the ten keywords map to class ids 2..11, '_silence_'
# to 0, and every other Speech Commands word to 1 (presumably the "unknown" class).
_KWS10_KEYWORD_IDS = {'yes': 2, 'no': 3, 'up': 4, 'down': 5, 'left': 6,
                      'right': 7, 'on': 8, 'off': 9, 'stop': 10, 'go': 11,
                      '_silence_': 0}
_KWS10_WORD_ORDER = ('bird happy cat dog follow house forward bed backward sheila '
                     'tree two down four eight visual five marvin go learn wow left '
                     'one seven off nine right up stop zero three on yes six no '
                     '_silence_').split()
label_dict = {word: _KWS10_KEYWORD_IDS.get(word, 1) for word in _KWS10_WORD_ORDER}

In [None]:
first_label = labels[0][0]
print(label_dict[first_label], first_label)

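## Create the models

Both models are defined with the [Keras Functional API](https://www.tensorflow.org/guide/keras/functional_api). The denoiser maps the noisy audio magnitude spectrogram and the IMU spectrogram to a denoised spectrogram. The translator maps the IMU spectrogram to an audio spectrogram at two resolutions.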

In [None]:
def make_denoiser(l2_strength):
    """Build the audio denoiser as a Keras functional model.

    Inputs are the noisy audio magnitude spectrogram, shape (200, 21, 1), and the
    IMU spectrogram, shape (20, 21, 1). They are stacked along axis 1 into a
    (220, 21, 1) map.

    Every hidden layer is a Conv2D with a [k, 1] kernel and 'same' padding, so
    it mixes neighbouring rows only and keeps the (220, 21) size. Each conv is
    followed by ReLU and then BatchNormalization. Two 64-filter conv outputs
    (skip0, skip1) are added back later as residual connections.

    A final 'valid' [21, 1] conv with one filter reduces 220 rows to 200. The
    output therefore has shape (200, 21, 1), with no output activation.

    Args:
        l2_strength: L2 penalty on every conv kernel except the final one.

    Returns:
        tf.keras.Model taking [input_audio, input_imu] and returning the
        denoised spectrogram estimate.
    """


    input_imu = Input(shape=[20,21, 1])
    input_audio = Input(shape=[200,21,1])

    x = input_audio
    y = input_imu


    # ----- Concatenate Modalities
    # Stack the IMU rows below the audio rows: (200, 21, 1) + (20, 21, 1) -> (220, 21, 1).
    x = tf.concat([x, y], axis=1)



    # -----
    # Note the layer order used throughout: Conv2D -> ReLU -> BatchNormalization.
    x = Conv2D(filters=32, kernel_size=[7,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    x = Conv2D(filters=64, kernel_size=[9,1], strides=[1, 1], padding='same', use_bias=False,
                 kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    x = Conv2D(filters=32, kernel_size=[9,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    # Debug output: prints the symbolic shape while the model is being built.
    print(x.shape)


    # -----
    x = Conv2D(filters=32, kernel_size=[7,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)

    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    # skip0: pre-activation 64-channel output, added back near the end (x + skip0).
    skip0 = Conv2D(filters=64, kernel_size=[5,1], strides=[1, 1], padding='same', use_bias=False,
                 kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)


    x = Activation('relu')(skip0)
    x = BatchNormalization()(x)

    x = Conv2D(filters=32, kernel_size=[7,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)


    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    # -----
    x = Conv2D(filters=32, kernel_size=[7,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    # skip1: pre-activation 64-channel output, added back later (x + skip1).
    skip1 = Conv2D(filters=64, kernel_size=[5,1], strides=[1, 1], padding='same', use_bias=False,
                 kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(skip1)
    x = BatchNormalization()(x)

    x = Conv2D(filters=32, kernel_size=[7,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    # ----
    x = Conv2D(filters=32, kernel_size=[9,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    x = Conv2D(filters=64, kernel_size=[5,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    x = Conv2D(filters=32, kernel_size=[9,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    # ----
    x = Conv2D(filters=32, kernel_size=[9,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    x = Conv2D(filters=64, kernel_size=[5,1], strides=[1, 1], padding='same', use_bias=False,
             kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    # Residual connection from skip1 (both are 64-channel, same spatial size).
    x = x + skip1
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    x = Conv2D(filters=32, kernel_size=[9,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    # ----
    x = Conv2D(filters=32, kernel_size=[9,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    x = Conv2D(filters=64, kernel_size=[5,1], strides=[1, 1], padding='same', use_bias=False,
             kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    # Residual connection from skip0.
    x = x + skip0
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    x = Conv2D(filters=32, kernel_size=[9,1], strides=[1, 1], padding='same', use_bias=False,
              kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)

    # ----
    # SpatialDropout2D drops whole channels (feature maps) rather than single units.
    x = tf.keras.layers.SpatialDropout2D(0.2)(x)
    # 'valid' [21, 1] conv with one filter: 220 - 21 + 1 = 200 rows -> (200, 21, 1).
    # No activation and no L2 penalty on this output layer.
    x = Conv2D(filters=1, kernel_size=[21,1], strides=[1, 1], padding='valid')(x)

    denoiser = Model(inputs=[input_audio, input_imu], outputs=x)

    return denoiser

In [None]:
# L2 penalty applied to the denoiser's regularized conv kernels.
denoiser_l2_strength = 0.002
denoiser = make_denoiser(l2_strength=denoiser_l2_strength)
denoiser.summary()

In [None]:
# Smoke test: run the (untrained) denoiser on the test batch's magnitude spectrogram and IMU data.
noisy_magnitude_test = np.abs(noisy_audio)
denoised_audio = denoiser([noisy_magnitude_test, input_imu])
print(denoised_audio[0].shape)

In [None]:
def make_translator():
    """Build the IMU-to-audio translator as a Keras functional model.

    Input: IMU spectrogram of shape (20, 21, 1). The encoder pools by 5 along
    axis 1 (20 -> 4 rows). The decoder upsamples that axis by 2 and then 5
    (4 -> 8 -> 40 rows), and then by 5 again (40 -> 200 rows). Pooling and
    upsampling act on axis 1 only, so the 21 columns are preserved.

    Returns:
        tf.keras.Model with two outputs:
          * ``up2``: (40, 21, 1) low-resolution estimate, from a 1-filter conv
            head applied right after the 40-row upsampling;
          * ``x``: (200, 21, 1) full-resolution estimate.
        Neither output has an activation.
    """
    input_imu = Input(shape=[20,21, 1])

    # Encoder
    x = Conv2D(filters=128, kernel_size=3,padding='same', name="enc_conv1")(input_imu)
    x = Activation('relu')(x)

    x = Conv2D(filters=32, kernel_size=3, strides=1, padding='same', name="enc_conv2")(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((5,1), name="enc_maxpool1")(x)  # (20, 21) -> (4, 21)

    x = Conv2D(filters=16, kernel_size=3, strides=1, padding='same', name="enc_conv3")(x)
    x = Activation('relu')(x)


    # Decoder
    x = Conv2D(filters=16, kernel_size=3, strides=1, padding='same', name="dec_conv1")(x)
    x = Activation('relu')(x)
    x = Dropout(0.4)(x)
    x = UpSampling2D((2,1), name="dec_up1")(x)  # (4, 21) -> (8, 21)

    x = Conv2D(filters=32, kernel_size=3, strides=1, padding='same', name="dec_conv2")(x)
    x = Activation('relu')(x)
    x = Dropout(0.4)(x)

    # Squeeze to a single channel before upsampling (no activation here).
    x = Conv2D(filters=1, kernel_size=3,  padding='same', name="dec_conv3")(x)
    x = UpSampling2D((5,1), name="dec_up2")(x)  # (8, 21) -> (40, 21)
    # Low-resolution output head; it branches off before the dropout below.
    up2 = Conv2D(filters=1, kernel_size=3,  padding='same', name="up2")(x)
    x = Dropout(0.4)(x)

    x = Conv2D(filters=64, kernel_size=3, strides=1, padding='same', name="dec_conv4")(x)
    x = Activation('relu')(x)
    x = UpSampling2D((5,1), name="dec_up3")(x)  # (40, 21) -> (200, 21)
    x = Dropout(0.4)(x)

    x = Conv2D(filters=128, kernel_size=3, strides=1, padding='same', name="dec_conv5")(x)
    x = Activation('relu')(x)

    # Full-resolution single-channel output.
    x = Conv2D(filters=1, kernel_size=3, padding='same', name="dec_conv6")(x)

    translator = Model(inputs=input_imu, outputs=[up2, x])
    return translator


In [None]:
# Build the IMU-to-audio translator and print its layer summary.
translator = make_translator()
translator.summary()

In [None]:
# Smoke test on the test batch. The translator returns [low-res (40, 21, 1), full-res (200, 21, 1)] per sample.
translated_audio = translator(input_imu)

print(translated_audio[0].shape)

# Loss Functions

In [None]:
def mae(targets, outputs):
    """Mean absolute error using Keras' MeanAbsoluteError with its default reduction."""
    return tf.keras.losses.MeanAbsoluteError()(targets, outputs)


def mse(targets, outputs):
    """Mean squared error using Keras' MeanSquaredError with its default reduction."""
    return tf.keras.losses.MeanSquaredError()(targets, outputs)

In [None]:
def stft2time(target_stft, noisy_audio, hop_length=200, length=4000):
    """Resynthesize a waveform from a magnitude spectrogram using the noisy phase.

    Args:
        target_stft: magnitude spectrogram laid out as (freq bins without DC,
            frames[, 1]), e.g. (200, 21, 1).
        noisy_audio: complex STFT of the noisy mixture in the same layout; only
            its phase is used.
        hop_length: hop size of the forward STFT (default 200, as in dataset_gen).
        length: number of output samples (default 4000).

    Returns:
        1-D waveform of ``length`` samples from librosa.istft.
    """
    noisy = np.asarray(noisy_audio)
    target = np.asarray(target_stft)
    theta_n = np.angle(noisy.reshape(noisy.shape[0], noisy.shape[1]))
    magnitude = target.reshape(target.shape[0], target.shape[1])
    complex_stft = magnitude * np.exp(1j * theta_n)

    # dataset_gen dropped the DC bin; put back a zero row (sized to the frame
    # count) so istft sees n_fft // 2 + 1 bins.
    dc_row = np.zeros((1, complex_stft.shape[1]))
    full_stft = np.vstack((dc_row, complex_stft))
    return librosa.istft(full_stft, hop_length=hop_length, length=length)

In [None]:
def translator_loss(translator_output_arr, denoiser_output_arr, noisy_audio_arr, cycle):
    """Two-resolution MSE loss for the IMU-to-audio translator.

    Args:
        translator_output_arr: translator outputs ``[low_res, full_res]``; batch
            size is taken from ``low_res``.
        denoiser_output_arr: per-sample target magnitude spectrograms (e.g. |noisy
            STFT| in cycle 0, the previous denoiser's output later).
        noisy_audio_arr: per-sample complex noisy STFTs; only their phase is used.
        cycle: unused.

    The full-resolution target is ``denoiser_output_arr`` itself. The
    low-resolution target is built per sample: resynthesize with the noisy phase
    (stft2time), resample to 800 samples, and take an 80-point magnitude STFT
    (hop 40), keeping bins 1..40.

    Returns:
        Sum of the two MSE terms over the whole batch.
    """
    final_out = []
    down1_out = []
    for i in range(translator_output_arr[0].shape[0]):
        denoiser_output = denoiser_output_arr[i]
        reconstructed_noisy_mic = stft2time(denoiser_output, noisy_audio_arr[i])

        down_noisy_time1 = signal.resample(reconstructed_noisy_mic, 800)
        down1 = get_stft(down_noisy_time1, fs=down_noisy_time1.shape[0], n_fft=80, hop_length=40, only_real=True)
        down1 = down1[1:41, :]  # drop the DC bin
        down1_out.append(down1.reshape([down1.shape[0], down1.shape[1], 1]))

        final_out.append(denoiser_output)
    final_out = np.asarray(final_out, dtype=np.float32)
    down1_out = np.asarray(down1_out, dtype=np.float32)

    # Compare whole batches. Indexing the translator outputs with the loop
    # variable here would score only the last sample (broadcast against every
    # target) and would fail for an empty batch.
    t1_loss = mse(translator_output_arr[0], down1_out)
    t_loss = mse(translator_output_arr[1], final_out)
    t_loss += t1_loss

    return t_loss

In [None]:
def generate_mask(translator_output, noisy_or_denoiser_output, percentile1, percentile2):
    """Build the denoiser's pseudo ground truth by soft-masking a spectrogram.

    Both inputs are reshaped to (200 frequency bins, 21 frames) and converted to
    dB.

    Active and background frames:
      * A frame is "active" when its ``percentile1``-th percentile over frequency
        exceeds the midpoint of the smallest and largest such per-frame values.
      * Frames more than the global ``margin`` frames away from every active
        frame are "background".

    Soft mask: for each of four 50-bin frequency bands, the ``percentile2``-th
    percentile of the translator output (dB) over the background frames is a
    threshold, and sigmoid(translator_dB - threshold) is that band's mask.

    When no frame qualifies as active or as background (e.g. the per-frame
    levels are all equal), all frames are used for the thresholds instead of
    failing on an empty reduction.

    Returns:
        (200, 21) array: ``noisy_or_denoiser_output`` multiplied by the soft mask.
    """
    d_prev_out = np.reshape(noisy_or_denoiser_output, [200, 21])
    t_db = librosa.amplitude_to_db(np.reshape(translator_output, [200, 21]), ref=80)
    d_prev_db = librosa.amplitude_to_db(d_prev_out, ref=80)

    frame_level = np.percentile(d_prev_db, percentile1, axis=0)  # one value per frame
    ref = (np.max(frame_level) + np.min(frame_level)) / 2
    active = np.flatnonzero(frame_level > ref)
    frames = np.arange(d_prev_db.shape[1])
    if active.size:
        # Distance (in frames) from each frame to the nearest active frame.
        dist = np.min(np.abs(frames[np.newaxis, :] - active[:, np.newaxis]), axis=0)
        background = np.flatnonzero(dist > margin)
    else:
        background = frames
    if background.size == 0:
        background = frames
    truncated_t = t_db[:, background]

    bands = []
    for j in range(4):
        rows = slice(j * 50, j * 50 + 50)
        thr = np.percentile(truncated_t[rows], percentile2)
        bands.append(np.asarray(tf.math.sigmoid(t_db[rows] - thr)))
    stiched_data = np.concatenate(bands, axis=0)

    return d_prev_out * stiched_data

In [None]:
margin = 0          # generate_mask: background frames must be more than this many frames from any active frame
percentile1 = 80    # generate_mask: per-frame percentile used to find active frames
percentile2 = 90    # generate_mask: base percentile for the per-band threshold (the cycle index is added)

def denoiser_loss(translator_output, denoiser_output, denoised_static, noisy_data, epoch, batch_size, cycle):
    """Average per-sample MAE between the denoiser output and generated mask targets.

    For each of the first ``batch_size`` samples, generate_mask builds a (200, 21)
    target from the full-resolution translator output (``translator_output[1]``)
    and ``denoised_static``, using percentiles ``percentile1`` and
    ``percentile2 + cycle``. ``noisy_data`` and ``epoch`` are unused.
    """
    threshold_percentile = percentile2 + cycle
    per_sample_losses = [
        mae(denoiser_output[k],
            generate_mask(translator_output[1][k].numpy(), denoised_static[k],
                          percentile1, threshold_percentile).reshape([200, 21, 1]))
        for k in range(batch_size)
    ]
    return sum(per_sample_losses, 0) / batch_size

In [None]:
# Separate Adam optimizers: the translator learns 10x more slowly than the denoiser.
translator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
denoiser_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

### Save checkpoints
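The best-scoring models are kept in checkpoints. `checkpoint` stores the overall-best translator and denoiser together with their optimizers. `cycle_checkpoint` stores the best denoiser (and its optimizer) of the current training cycle, which the next cycle restores.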

In [None]:
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Overall-best checkpoint: both models and both optimizers. The keyword names
# are the keys each object is stored under, so keep them stable.
checkpoint_objects = dict(translator_optimizer=translator_optimizer,
                          denoiser_optimizer=denoiser_optimizer,
                          translator=translator,
                          denoiser=denoiser)
checkpoint = tf.train.Checkpoint(**checkpoint_objects)

In [None]:
cycle_checkpoint_prefix = os.path.join(cycle_checkpoint_dir, "ckpt")
# Per-cycle best denoiser and its optimizer. This tracks the same live
# `denoiser` object, so restoring it overwrites the model's current weights.
cycle_checkpoint = tf.train.Checkpoint(**{"denoiser_optimizer": denoiser_optimizer,
                                          "denoiser": denoiser})

### Set up the KWS10 keyword-spotting model

In [None]:
import os
from kws_streaming.models import model_params
from kws_streaming.models import model_flags
from kws_streaming.models import models

In [None]:
def kws_set_flag(model_dir_suffix="_40k/", wanted_words=None):
    """Return kws_streaming flags for the pretrained ``ds_tc_resnet`` keyword spotter.

    Args:
        model_dir_suffix: appended to ``<cwd>/models/ds_tc_resnet`` to form
            ``train_dir``, the directory the weights are loaded from. The default
            selects the KWS10 model.
        wanted_words: optional comma-separated keyword list. When None, the value
            from ``model_params.HOTWORD_MODEL_PARAMS`` is left unchanged.

    Returns:
        The result of ``model_flags.update_flags`` on the configured params.

    The ``HOTWORD_MODEL_PARAMS`` entry is deep-copied before it is modified, so
    configuring one model here does not change the defaults another
    configuration (or a re-run of this cell) starts from.
    """
    import copy

    model_name = 'ds_tc_resnet'
    flags = copy.deepcopy(model_params.HOTWORD_MODEL_PARAMS[model_name])

    # Speech feature extractor: 30 ms windows every 10 ms, 80 mel bins, 40 MFCCs.
    flags.window_size_ms = 30.0
    flags.window_stride_ms = 10.0
    flags.mel_num_bins = 80
    flags.dct_num_features = 40
    flags.feature_type = 'mfcc_tf'
    flags.preprocess = 'raw'

    # 1 makes streaming and non-streaming models numerically identical;
    # 0 is the setting for real streaming use.
    flags.causal_data_frame_padding = 0

    flags.use_tf_fft = True
    flags.mel_non_zero_only = not flags.use_tf_fft

    # Training schedule of the original model (this notebook only loads its weights).
    flags.train = 1
    flags.how_many_training_steps = '40000,40000,20000,20000'
    flags.learning_rate = '0.001,0.0005,0.0001,0.00002'
    flags.lr_schedule = 'linear'

    # Data augmentation parameters.
    flags.resample = 0.15
    flags.time_shift_ms = 100
    flags.use_spec_augment = 1
    flags.time_masks_number = 2
    flags.time_mask_max_size = 25
    flags.frequency_masks_number = 2
    flags.frequency_mask_max_size = 7
    flags.pick_deterministically = 1

    flags.train_dir = os.path.join(os.getcwd(), "models", model_name + model_dir_suffix)
    # 1 s clips at 4 kHz; mel filters extend up to the 2 kHz Nyquist frequency.
    flags.sample_rate = 4000
    flags.mel_upper_edge_hertz = 2000
    flags.clip_duration_ms = 1000
    flags.model_name = model_name

    # ds_tc_resnet topology; it must match the architecture of the saved weights.
    flags.activation = 'relu'
    flags.dropout = 0.0
    flags.ds_filters = '128, 64, 64, 64, 128, 128'
    flags.ds_filter_separable = '1, 1, 1, 1, 1, 1'
    flags.ds_repeat = '1, 1, 1, 1, 1, 1'
    flags.ds_residual = '0, 1, 1, 1, 0, 0'  # residual can not be applied with stride
    flags.ds_kernel_size = '11, 5, 15, 17, 15, 1'
    flags.ds_dilation = '1, 1, 1, 1, 2, 1'
    flags.ds_stride = '1, 1, 1, 1, 1, 1'
    flags.ds_pool = '1, 2, 1, 1, 1, 1'
    flags.ds_padding = "'causal', 'causal', 'causal', 'causal', 'causal', 'causal'"

    flags.batch_size = 100
    if wanted_words is not None:
        flags.wanted_words = wanted_words
    return model_flags.update_flags(flags)

In [None]:
flags = kws_set_flag()
# Build for single-clip inference: each evaluation call predicts one (1, 4000) waveform.
flags.batch_size = 1

weights_name = 'best_weights'
weights_path = os.path.join(flags.train_dir, weights_name)
build_kws_model = models.MODELS[flags.model_name]
kws_model = build_kws_model(flags)
kws_model.load_weights(weights_path).expect_partial()
kws_model.compile(run_eagerly=True)

In [None]:
# Sanity check: KWS10 class predicted for the first clean test clip (4000 samples).
temp_1 = np.reshape(target_mic_arr[0], (1, 4000))
out = kws_model.predict(temp_1)
predicted_class = int(np.argmax(out, axis=1)[0])
print(predicted_class)

## Define the training loop


In [None]:
# Per cycle: epochs [0, SPLIT_EPOCH) train the translator, the rest train the denoiser.
EPOCHS, SPLIT_EPOCH = 75, 25

In [None]:
# The step is decorated with @tf.function but only works eagerly: it calls
# np.abs and .numpy() and branches in Python on `cycle` / `epoch`. The setup
# cell enables eager execution for all tf.functions via
# tf.config.run_functions_eagerly(True).
@tf.function
def train_step(audio, imu, normalized_train_imu, latest, epoch, batch_size, cycle):
    """Run one optimisation step on either the translator or the denoiser.

    Phases:
      * translator phase (``cycle <= 4`` and ``epoch < SPLIT_EPOCH``): train the
        translator (translator_loss). In cycle 0 the targets come from |audio|.
        In later cycles they come from the denoiser restored from checkpoint
        ``latest``.
      * denoiser phase (otherwise): train the denoiser on masks built from the
        frozen translator's output (denoiser_loss).

    Args:
        audio: complex noisy STFT batch from dataset_gen.
        imu: IMU magnitude batch (translator input).
        normalized_train_imu: IMU batch rescaled to the audio level (denoiser input).
        latest: cycle-checkpoint path used in cycles > 0; ignored in cycle 0.
        epoch, batch_size, cycle: current training position and batch size.

    Returns:
        (d_loss, t_loss). The loss of the model not trained in this step stays
        tf.zeros([1]).
    """
    with tf.GradientTape() as t_tape, tf.GradientTape() as d_tape:
        d_loss = tf.zeros([1])
        t_loss = tf.zeros([1])
        abs_audio = np.abs(audio)

        if cycle > 4 or epoch >= SPLIT_EPOCH:
            # Denoiser phase: translator frozen (training=False).
            translated_audio = translator(imu, training=False)
            denoised_audio = denoiser([abs_audio, normalized_train_imu], training=True)
            d_loss = denoiser_loss(translated_audio, denoised_audio, abs_audio, abs_audio, epoch, batch_size, cycle)
            gradients_of_denoiser = d_tape.gradient(d_loss, denoiser.trainable_variables)
            denoiser_optimizer.apply_gradients(zip(gradients_of_denoiser, denoiser.trainable_variables))
        else:
            # Translator phase.
            translated_audio = translator(imu, training=True)
            if cycle == 0:
                t_loss = translator_loss(translated_audio, abs_audio, audio, cycle)
            else:
                # cycle_checkpoint tracks the live `denoiser`, so this loads the
                # previous cycle's best weights into it before using its output
                # as the translator target.
                cycle_checkpoint.restore(latest)
                prev_denoiser_output = cycle_checkpoint.denoiser([abs_audio, normalized_train_imu], training=False).numpy()
                t_loss = translator_loss(translated_audio, prev_denoiser_output, audio, cycle)
            gradients_of_translator = t_tape.gradient(t_loss, translator.trainable_variables)
            translator_optimizer.apply_gradients(zip(gradients_of_translator, translator.trainable_variables))

    return d_loss, t_loss
            
        

In [None]:
def train(epochs):
    """Two-cycle alternating training of the translator and the denoiser.

    Within each cycle, epochs below SPLIT_EPOCH train the translator and the
    remaining epochs train the denoiser (see train_step).

    After every epoch the models are scored on the fixed test batch by
    generate_and_save_images. The overall best is saved to ``checkpoint``. The
    best of the current cycle is saved to ``cycle_checkpoint``, which the next
    cycle restores as the translator's target.

    Args:
        epochs: number of epochs per cycle.

    Returns:
        The best KWS10 accuracy reached (also printed).

    Raises:
        ValueError: if the training split has fewer combinations than one batch.
    """
    total_batch = 100
    train_batch_size = 128
    n_batches = len(train_indexes) // train_batch_size
    if n_batches == 0:
        raise ValueError("train: need at least {} training combinations, got {}".format(
            train_batch_size, len(train_indexes)))
    best_acc = -1

    for cycle in range(2):
        best_acc_cycle = -1
        # Cycles after the first train the translator against the previous cycle's best denoiser.
        latest = tf.train.latest_checkpoint(cycle_checkpoint_dir) if cycle > 0 else None

        for epoch in range(epochs):
            for i in range(total_batch):
                # Walk through the training combinations sequentially across
                # epochs, wrapping at the end. (i*epoch would reuse batch 0 for
                # all of epoch 0 and run past the data in later epochs.)
                batch_ind = (epoch * total_batch + i) % n_batches
                train_audio, train_imu, _, _, _, _, _ = dataset_gen(batch_size=train_batch_size, ftype="train", batch_ind=batch_ind)
                # Scale the IMU batch to the peak *magnitude* of the complex audio
                # STFT, matching the normalization used at evaluation time.
                audio_peak = np.max(np.abs(train_audio).reshape([train_batch_size, 200, 21]))
                normalized_train_imu = (train_imu / np.max(train_imu.reshape([train_batch_size, 20, 21]))) * audio_peak
                d_loss, t_loss = train_step(train_audio, train_imu, normalized_train_imu, latest, epoch, train_batch_size, cycle)

            print("Epoch: "+str(epoch+1)+"\t d_loss: "+str(d_loss.numpy())+"\t t_loss: "+str(t_loss.numpy()))

            d_acc = generate_and_save_images(denoiser, translator, epoch+1, cycle)
            if d_acc > best_acc:
                checkpoint.save(file_prefix=checkpoint_prefix)
                best_acc = d_acc

            if d_acc > best_acc_cycle:
                cycle_checkpoint.save(file_prefix=cycle_checkpoint_prefix)
                best_acc_cycle = d_acc
    print("KWS10 Accuracy: " + str(best_acc))
    return best_acc

**Generate and save images**


In [None]:
os.makedirs(output_parent_folder_name, exist_ok=True)


def _plot_spectrogram_db(spec, ax, title):
    """Draw a (200, 21) magnitude spectrogram in dB on ``ax``.

    A zero row is prepended for the DC bin that dataset_gen drops, so the
    201-row image lines up with librosa's frequency axis at sr=4000.
    """
    magnitude = np.asarray(spec).reshape([200, 21])
    padded = np.vstack((np.zeros([1, 21]), magnitude))
    librosa.display.specshow(librosa.amplitude_to_db(padded, ref=80), sr=4000, x_axis='time', y_axis='hz', ax=ax)
    ax.set(title=title)


def generate_and_save_images(d_model, t_model, epoch, cycle):
    """Evaluate the models on the fixed test batch, plot some samples, and return accuracy.

    Reads the test-batch globals (noisy_audio, input_imu, clean_audio, labels,
    test_batch_size), ``kws_model`` and ``label_dict``.

    Plotting: every ``plot_freq``-th epoch, every 10th sample is plotted (clean,
    noisy, denoised, translated, masked target) and saved under
    ``output_parent_folder_name``.

    Scoring: each denoised spectrogram is resynthesized with the noisy phase
    (stft2time) and classified by ``kws_model``.

    Note: the models are only read here. The previous cycle's checkpoint is
    deliberately not restored. ``cycle_checkpoint`` tracks the live ``denoiser``,
    so restoring it would overwrite the weights being trained, and train()
    would then save the restored weights instead of the ones just evaluated.

    Returns:
        Fraction of test samples classified correctly, rounded to 2 decimals.
    """
    plot_freq = 20
    noisy_magnitude = np.abs(noisy_audio)
    # Rescale the IMU input to the peak magnitude of the noisy audio.
    normalized_test_imu = (input_imu / np.max(input_imu.reshape([test_batch_size, 20, 21]))) * np.max(noisy_magnitude.reshape([test_batch_size, 200, 21]))

    # training=False: BatchNorm uses its moving statistics and dropout is off.
    predictions_d = d_model([noisy_magnitude, normalized_test_imu], training=False)
    predictions_t = t_model(input_imu, training=False)
    percentile_temp = percentile2 + cycle

    denoised_correct_classification = 0
    for i in range(test_batch_size):
        mask = generate_mask(predictions_t[1][i].numpy(), noisy_magnitude[i], percentile1, percentile_temp)

        if epoch % plot_freq == 0 and i % 10 == 0:
            folder_name = os.path.join(output_parent_folder_name, "sample_"+str(i))
            os.makedirs(folder_name, exist_ok=True)
            fig, ax = plt.subplots(nrows=5, ncols=1, sharex=True, dpi=80, figsize=(7, 10))
            _plot_spectrogram_db(clean_audio[i], ax[0], 'Non-Interfered Audio')
            _plot_spectrogram_db(noisy_magnitude[i], ax[1], 'Interfered Audio')
            _plot_spectrogram_db(predictions_d[i], ax[2], 'Denoised Audio')
            _plot_spectrogram_db(predictions_t[1][i], ax[3], 'Translated Audio')
            _plot_spectrogram_db(mask, ax[4], 'Masked Audio as Denoiser Ground Truth')
            fig.tight_layout()
            fig.savefig(os.path.join(folder_name, 'cycle_{:04d}_epoch_{:04d}_sample_{:04d}.png'.format(cycle, epoch, i)))
            plt.close(fig)  # free the figure; otherwise figures accumulate across epochs

        reconstructed_denoised_mic = stft2time(predictions_d[i].numpy(), noisy_audio[i])
        predictions = kws_model.predict(reconstructed_denoised_mic.reshape([1, 4000]))
        predicted_labels = np.argmax(predictions, axis=1)[0]
        if predicted_labels == label_dict[labels[i][0]]:
            denoised_correct_classification += 1

    return round(denoised_correct_classification/test_batch_size, 2)

## Train the model

Run `train(EPOCHS)` in a new cell to train; this notebook does not include that call.

# KWS35 Evaluation

In [None]:
# KWS35 class ids: '_silence_' is 0, and the 35 words get ids 2..36 in the order
# of FLAGS.wanted_words in kws_set_flag_35 (id 1 is presumably "unknown").
_KWS35_WANTED_WORDS = ('visual,wow,learn,backward,dog,two,left,happy,nine,go,up,bed,stop,'
                       'one,zero,tree,seven,on,four,bird,right,eight,no,six,forward,house,'
                       'marvin,sheila,five,off,three,down,cat,follow,yes').split(',')
_KWS35_IDS = {word: idx for idx, word in enumerate(_KWS35_WANTED_WORDS, start=2)}
_KWS35_KEY_ORDER = ('follow learn backward visual dog cat house bird bed tree eight marvin '
                    'five go no forward down four seven wow on zero up three six left happy '
                    'sheila right nine one off two yes stop _silence_').split()
label_dict_35 = {word: _KWS35_IDS.get(word, 0) for word in _KWS35_KEY_ORDER}

In [None]:
def kws_set_flag_35():
    """Return kws_streaming flags for the pretrained 35-word ``ds_tc_resnet`` spotter.

    Identical to the KWS10 configuration except for the weights directory
    (``<cwd>/models/ds_tc_resnet_40k_fs_4k_35L/``) and ``wanted_words``, whose
    order defines class ids 2..36 (see label_dict_35).

    The ``HOTWORD_MODEL_PARAMS`` entry is deep-copied before it is modified, so
    this call (notably ``wanted_words``) does not leak into later configurations
    that start from the same shared defaults.
    """
    import copy

    model_name = 'ds_tc_resnet'
    flags = copy.deepcopy(model_params.HOTWORD_MODEL_PARAMS[model_name])

    # Speech feature extractor: 30 ms windows every 10 ms, 80 mel bins, 40 MFCCs.
    flags.window_size_ms = 30.0
    flags.window_stride_ms = 10.0
    flags.mel_num_bins = 80
    flags.dct_num_features = 40
    flags.feature_type = 'mfcc_tf'
    flags.preprocess = 'raw'

    # 1 makes streaming and non-streaming models numerically identical;
    # 0 is the setting for real streaming use.
    flags.causal_data_frame_padding = 0

    flags.use_tf_fft = True
    flags.mel_non_zero_only = not flags.use_tf_fft

    # Training schedule of the original model (this notebook only loads its weights).
    flags.train = 1
    flags.how_many_training_steps = '40000,40000,20000,20000'
    flags.learning_rate = '0.001,0.0005,0.0001,0.00002'
    flags.lr_schedule = 'linear'

    # Data augmentation parameters.
    flags.resample = 0.15
    flags.time_shift_ms = 100
    flags.use_spec_augment = 1
    flags.time_masks_number = 2
    flags.time_mask_max_size = 25
    flags.frequency_masks_number = 2
    flags.frequency_mask_max_size = 7
    flags.pick_deterministically = 1

    flags.train_dir = os.path.join(os.getcwd(), "models", model_name + "_40k_fs_4k_35L/")
    # 1 s clips at 4 kHz; mel filters extend up to the 2 kHz Nyquist frequency.
    flags.sample_rate = 4000
    flags.mel_upper_edge_hertz = 2000
    flags.clip_duration_ms = 1000
    flags.model_name = model_name

    # ds_tc_resnet topology; it must match the architecture of the saved weights.
    flags.activation = 'relu'
    flags.dropout = 0.0
    flags.ds_filters = '128, 64, 64, 64, 128, 128'
    flags.ds_filter_separable = '1, 1, 1, 1, 1, 1'
    flags.ds_repeat = '1, 1, 1, 1, 1, 1'
    flags.ds_residual = '0, 1, 1, 1, 0, 0'  # residual can not be applied with stride
    flags.ds_kernel_size = '11, 5, 15, 17, 15, 1'
    flags.ds_dilation = '1, 1, 1, 1, 2, 1'
    flags.ds_stride = '1, 1, 1, 1, 1, 1'
    flags.ds_pool = '1, 2, 1, 1, 1, 1'
    flags.ds_padding = "'causal', 'causal', 'causal', 'causal', 'causal', 'causal'"

    flags.batch_size = 100
    flags.wanted_words = 'visual,wow,learn,backward,dog,two,left,happy,nine,go,up,bed,stop,one,zero,tree,seven,on,four,bird,right,eight,no,six,forward,house,marvin,sheila,five,off,three,down,cat,follow,yes'
    return model_flags.update_flags(flags)

In [None]:
flags_35 = kws_set_flag_35()
# Single-clip inference, as for the KWS10 model.
flags_35.batch_size = 1

# NOTE(review): this rebinds the global `kws_model` that generate_and_save_images
# uses with the KWS10 `label_dict`; run train() before this cell.
weights_name = 'best_weights'
kws_model = models.MODELS[flags_35.model_name](flags_35)
kws_model.load_weights(os.path.join(flags_35.train_dir, weights_name)).expect_partial()
kws_model.compile(run_eagerly=True)

In [None]:
# Evaluate the best saved denoiser on the fixed test batch with the KWS35 classifier.
latest = tf.train.latest_checkpoint(checkpoint_dir)
if latest is None:
    # checkpoint.restore(None) restores nothing, so the in-memory weights would be scored silently.
    raise FileNotFoundError("no checkpoint found in {}; run train(EPOCHS) first".format(checkpoint_dir))
checkpoint.restore(latest)

noisy_magnitude = np.abs(noisy_audio)
normalized_test_imu = (input_imu/np.max(input_imu.reshape([test_batch_size, 20, 21])))*np.max(noisy_magnitude.reshape([test_batch_size, 200, 21]))
denoised_static = checkpoint.denoiser([noisy_magnitude, normalized_test_imu], training=False).numpy()

denoised_correct_classification = 0
for i in range(test_batch_size):
    # Resynthesize with the noisy phase and classify the 1 s (4000-sample) clip.
    reconstructed_denoised_mic = stft2time(denoised_static[i], noisy_audio[i])
    predictions = kws_model.predict(reconstructed_denoised_mic.reshape([1,4000]))
    predicted_labels = np.argmax(predictions, axis=1)[0]

    if predicted_labels == label_dict_35[labels[i][0]]:
        denoised_correct_classification += 1
print("KWS35 Accuracy: ", str(round(denoised_correct_classification/test_batch_size, 2)))