This notebook is the second part of a speech-enhancement project built with TensorFlow. Please see https://www.kaggle.com/tariqblecher/speech-enhancement-tensorflow-unet-softmask for more information.

Contact: tariq.blecher@gmail.com

# Imports

In [None]:
pip install pypesq

In [None]:
import pandas as pd
import numpy as np
import os
import sys
import librosa
import librosa.display
import matplotlib.pyplot as plt
from IPython.display import Audio
import tensorflow as tf
import tensorflow_io as tfio
from tensorflow.keras import layers
from tensorflow.keras.models import Model
import keras
from keras.models import Sequential
import warnings
import glob
import time
import datetime
from pypesq import pesq
import soundfile as sf

# Hide every UserWarning for the rest of the notebook.
warnings.filterwarnings("ignore", category=UserWarning)

# Load Data

In [None]:
# Audio / STFT configuration shared by every cell below.
sr = 8000                      # sample rate requested from librosa.load
speech_length_pix_sec = 27e-3  # not referenced by the cells of this notebook
total_length = 3.6             # not referenced by the cells of this notebook
n_fft = 255                    # FFT length -> n_fft // 2 + 1 = 128 frequency bins
frame_length = 255             # STFT window length in samples
frame_step = 110               # STFT hop in samples
# Crop length chosen so tf.signal.stft yields exactly 256 frames:
# frame_length + (256 - 1) * frame_step == 28305 samples.
trim_length = frame_length + (256 - 1) * frame_step

# glob's result order depends on the filesystem. Sort both lists so that files[0], files[:num_test]
# and rs.choice(noisefiles) select the same recordings on every run of this seeded notebook.
noisefiles = sorted(glob.glob('/kaggle/input/hospital-ambient-noise/Hospita*original/**/*.wav', recursive=True))
files = sorted(glob.glob('/kaggle/input/speech-accent-archive/**/*.mp3', recursive=True))

print(len(files), 'clean data files')
print(len(noisefiles), 'noise data files')


from numpy.random import MT19937
from numpy.random import RandomState, SeedSequence

# One seeded legacy RandomState (Mersenne Twister bit generator) shared by the noise and
# cropping helpers below, so the notebook draws a reproducible random stream.
_bit_generator = MT19937(SeedSequence(123456789))
rs = RandomState(_bit_generator)

def white_noise(data, factor=0.035):
    """Return ``data`` plus i.i.d. Gaussian noise.

    The noise amplitude is ``factor * max(data)`` multiplied by one standard-normal draw,
    so both its size and its sign vary from call to call. All randomness comes from the
    module-level ``rs``.
    """
    scale = factor * np.max(data) * rs.normal()
    gaussian = rs.normal(size=data.shape)
    return data + scale * gaussian

def urban_noise(data, sr=sr, factor=0.1):
    """Mix a randomly chosen background recording from ``noisefiles`` into ``data``.

    Files are drawn with ``rs.choice`` until one whose silence-trimmed signal, tiled or
    truncated to ``data.shape`` by ``np.resize``, has a positive peak. That noise is then
    rescaled so its peak equals ``factor * data.max()`` and added to ``data``.
    Note: the loop never terminates if no noise file yields a positive peak.
    """
    noise_peak = 0
    while noise_peak <= 0:
        noise_path = rs.choice(noisefiles)
        noise, sr = librosa.load(noise_path, sr=sr)
        noise, _ = librosa.effects.trim(noise, top_db=38)
        noise = np.resize(noise, data.shape)
        noise_peak = noise.max()
    return noise * factor * data.max() / noise_peak + data

def preprocess(filepath, sr=sr, add_white_noise=False, white_noise_factor=0.035, add_urban_noise=False,
               urban_noise_factor=0.1, return_wav=False, fixed_start=False):
    """Load one recording, crop a ``trim_length``-sample window, optionally add noise, return its STFT.

    Args:
        filepath: audio file readable by ``librosa.load``.
        sr: sample rate to load at (defaults to the notebook-wide ``sr``).
        add_white_noise, white_noise_factor: apply ``white_noise`` to the crop.
        add_urban_noise, urban_noise_factor: mix in a background recording via ``urban_noise``.
        return_wav: if True, return the (possibly noisy) cropped waveform instead of the STFT.
        fixed_start: sample offset of the crop. ``False`` (default) or ``None`` means "use a
            random offset"; any other value, including ``0``, is used as the offset.

    Returns:
        The cropped waveform when ``return_wav`` is True, otherwise the complex STFT with a
        trailing channel axis, i.e. ``(frames, bins, 1)``.

    Raises:
        ValueError: if the silence-trimmed recording has at most ``trim_length + 1`` samples.
    """
    wav, sr = librosa.load(filepath, sr=sr)
    # Drop leading/trailing audio quieter than 38 dB below the peak.
    wav, _ = librosa.effects.trim(wav, top_db=38)
    size = wav.shape[0]
    max_start = size - trim_length - 1
    if max_start <= 0:
        raise ValueError(
            f'{filepath}: only {size} samples left after trimming silence; need more than {trim_length + 1}'
        )
    # Drawn unconditionally, even when fixed_start is given, so the shared seeded `rs` stream
    # advances exactly as it always has.
    random_start = rs.randint(0, max_start)
    # Compare against the "not given" sentinels explicitly: a truthiness test would silently
    # replace a legitimate offset of 0 with a fresh random one.
    if fixed_start is not False and fixed_start is not None:
        random_start = fixed_start
    wav = wav[random_start:random_start + trim_length]
    if add_white_noise:
        wav = white_noise(wav, factor=white_noise_factor)
    if add_urban_noise:
        wav = urban_noise(wav, sr=sr, factor=urban_noise_factor)
    if return_wav:
        return wav

    spectrogram = tf.signal.stft(wav, frame_length=frame_length, fft_length=n_fft,
                                 frame_step=frame_step)
    # Add a channel axis: (frames, bins) -> (frames, bins, 1).
    spectrogram = tf.expand_dims(spectrogram, axis=2).numpy()
    return spectrogram

# Sanity check on the first clean file: shape of its magnitude spectrogram, (frames, bins, 1).
first_spectrogram = preprocess(files[0])
specclean = np.abs(first_spectrogram)
print('spectrum shape', specclean.shape)

In [None]:
num_test = 100
files = files[:num_test]
# This recording is deliberately excluded (the reason is not recorded here). list.remove raises
# ValueError when the path is not among the first `num_test` files, which is fine to ignore;
# any other error should surface.
try:
    files.remove('/kaggle/input/speech-accent-archive/recordings/recordings/maltese2.mp3')
except ValueError:
    pass

# Complex buffers shaped like one spectrogram, (frames, bins, 1), per file.
spectest = preprocess(files[0])
spectogram_data_clean = np.zeros((len(files), *spectest.shape), dtype=complex)
spectogram_data_corrupted = np.zeros((len(files), *spectest.shape), dtype=complex)

for file_ind, afile in enumerate(files):
    # Choose one crop offset per file and pass it to both preprocess calls, so the clean and the
    # corrupted spectrograms cover the same samples. (Each file is decoded three times here; the
    # sequence of `rs` draws is left untouched so results stay reproducible.)
    wav, sr = librosa.load(afile, sr=sr)
    wav, _ = librosa.effects.trim(wav, top_db=38)
    size = wav.shape[0]
    random_start = rs.randint(0, size - trim_length - 1)
    spectogram_data_clean[file_ind] = preprocess(afile, fixed_start=random_start)
    spectogram_data_corrupted[file_ind] = preprocess(afile, add_urban_noise=True, urban_noise_factor=0.6,
                                                     white_noise_factor=0.035, add_white_noise=True,
                                                     fixed_start=random_start)

# Magnitudes feed the model; the complex arrays keep the phase needed to rebuild waveforms.
spectogram_data_corrupted_abs = np.abs(spectogram_data_corrupted)
spectogram_data_clean_abs = np.abs(spectogram_data_clean)

# Create Model

In [None]:
from copy import deepcopy
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    MaxPooling2D,
    Dropout,
    SpatialDropout2D,
    UpSampling2D,
    Input,
    concatenate,
    multiply,
    add,
    Activation,
)


def upsample_conv(filters, kernel_size, strides, padding):
    """Return a learnable upsampling layer: a ``Conv2DTranspose`` with the given geometry."""
    layer = Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding)
    return layer


def upsample_simple(filters, kernel_size, strides, padding):
    """Return a parameter-free ``UpSampling2D`` layer with factor ``strides``.

    ``filters``, ``kernel_size`` and ``padding`` are accepted only so the signature matches
    ``upsample_conv``; they are ignored.
    """
    del filters, kernel_size, padding  # unused; present for signature parity with upsample_conv
    return UpSampling2D(strides)


def attention_gate(inp_1, inp_2, n_intermediate_filters):
    """Attention gate as in Oktay et al., Attention U-Net (https://arxiv.org/abs/1804.03999).

    Both inputs are projected to ``n_intermediate_filters`` channels by 1x1 convolutions,
    summed and passed through ReLU. A further 1x1 convolution to a single channel plus a
    sigmoid gives an attention map that multiplies ``inp_1``.
    """

    def conv_1x1(tensor, n_filters):
        # 1x1 projection, He-normal initialised, spatial size preserved.
        return Conv2D(
            n_filters,
            kernel_size=1,
            strides=1,
            padding="same",
            kernel_initializer="he_normal",
        )(tensor)

    projected_1 = conv_1x1(inp_1, n_intermediate_filters)
    projected_2 = conv_1x1(inp_2, n_intermediate_filters)
    summed = add([projected_1, projected_2])
    rectified = Activation("relu")(summed)
    logits = conv_1x1(rectified, 1)
    attention_map = Activation("sigmoid")(logits)
    return multiply([inp_1, attention_map])


def attention_concat(conv_below, skip_connection):
    """Concatenate the upsampled ``conv_below`` with an attention-gated ``skip_connection``.

    The gate's intermediate width equals the channel count of ``conv_below``.
    """
    n_channels = conv_below.get_shape().as_list()[-1]
    gated_skip = attention_gate(skip_connection, conv_below, n_channels)
    return concatenate([conv_below, gated_skip])


def conv2d_block(
    inputs,
    use_batch_norm=True,
    dropout=0.3,
    dropout_type="spatial",
    filters=16,
    kernel_size=(3, 3),
    activation="relu",
    kernel_initializer="he_normal",
    padding="same",
):
    """Two stacked ``Conv2D`` layers, each optionally followed by ``BatchNormalization``.

    Dropout (``SpatialDropout2D`` for ``"spatial"``, ``Dropout`` for ``"standard"``) is
    inserted between the two convolutions when ``dropout > 0``. When batch norm is used, the
    convolutions are built without a bias.

    Raises:
        ValueError: if ``dropout_type`` is neither ``"spatial"`` nor ``"standard"``.
    """
    for type_name, layer_cls in (("spatial", SpatialDropout2D), ("standard", Dropout)):
        if dropout_type == type_name:
            dropout_layer = layer_cls
            break
    else:
        raise ValueError(
            f"dropout_type must be one of ['spatial', 'standard'], got {dropout_type}"
        )

    x = inputs
    for stage in range(2):
        x = Conv2D(
            filters,
            kernel_size,
            activation=activation,
            kernel_initializer=kernel_initializer,
            padding=padding,
            use_bias=not use_batch_norm,
        )(x)
        if use_batch_norm:
            x = BatchNormalization()(x)
        # Dropout only sits between the first and the second convolution.
        if stage == 0 and dropout > 0.0:
            x = dropout_layer(dropout)(x)
    return x


def custom_unet(
    input_shape,
    num_classes=1,
    activation="relu",
    use_batch_norm=True,
    upsample_mode="deconv",  # 'deconv' or 'simple'
    dropout=0.3,
    dropout_change_per_layer=0.0,
    dropout_type="spatial",
    use_dropout_on_upsampling=False,
    use_attention=False,
    filters=16,
    num_layers=4,
    output_activation="sigmoid",
):  # 'sigmoid' or 'softmax'
    """Build a U-Net whose output is a predicted mask multiplied by the model input.

    Args:
        input_shape: shape of one input sample, e.g. ``(256, 128, 1)``. Each spatial
            dimension presumably must be divisible by ``2 ** num_layers`` so the pooled and
            upsampled tensors line up for concatenation.
        num_classes: number of channels of the mask produced by the final 1x1 convolution.
        activation: activation of every convolution in the encoder/decoder blocks.
        use_batch_norm: add BatchNormalization after each convolution.
        upsample_mode: ``'deconv'`` uses ``Conv2DTranspose``; any other value ``UpSampling2D``.
        dropout: dropout rate of the first encoder block.
        dropout_change_per_layer: added to the rate after each encoder level and subtracted
            again on each decoder level.
        dropout_type: ``'spatial'`` or ``'standard'`` (see ``conv2d_block``).
        use_dropout_on_upsampling: if False, decoder blocks use no dropout.
        use_attention: gate skip connections with ``attention_gate``.
        filters: filters of the first level; doubled at every deeper level.
        num_layers: number of down/up-sampling levels.
        output_activation: activation of the mask-producing 1x1 convolution.

    Returns:
        An uncompiled ``tf.keras`` ``Model``.
    """
    upsample = upsample_conv if upsample_mode == "deconv" else upsample_simple

    # Build U-Net model
    inputs = Input(input_shape)
    inputs_copy = tf.identity(inputs)
    # Scale by the maximum over the whole batch. divide_no_nan returns 0 instead of NaN when that
    # maximum is 0 (all-silent input) and is identical to `/` otherwise.
    x = tf.math.divide_no_nan(inputs, tf.reduce_max(inputs))

    down_layers = []
    for _ in range(num_layers):
        x = conv2d_block(
            inputs=x,
            filters=filters,
            use_batch_norm=use_batch_norm,
            dropout=dropout,
            dropout_type=dropout_type,
            activation=activation,
        )
        down_layers.append(x)
        x = MaxPooling2D((2, 2))(x)
        dropout += dropout_change_per_layer
        filters = filters * 2  # double the number of filters with each layer

    # Bottleneck.
    x = conv2d_block(
        inputs=x,
        filters=filters,
        use_batch_norm=use_batch_norm,
        dropout=dropout,
        dropout_type=dropout_type,
        activation=activation,
    )

    if not use_dropout_on_upsampling:
        dropout = 0.0
        dropout_change_per_layer = 0.0

    for conv in reversed(down_layers):
        filters //= 2  # decreasing number of filters with each layer
        dropout -= dropout_change_per_layer
        x = upsample(filters, (2, 2), strides=(2, 2), padding="same")(x)
        if use_attention:
            x = attention_concat(conv_below=x, skip_connection=conv)
        else:
            x = concatenate([x, conv])
        x = conv2d_block(
            inputs=x,
            filters=filters,
            use_batch_norm=use_batch_norm,
            dropout=dropout,
            dropout_type=dropout_type,
            activation=activation,
        )

    output_mask = Conv2D(num_classes, (1, 1), activation=output_activation)(x)
    # tf.keras `multiply` (same as Multiply()(...)) keeps every layer of this model from tf.keras
    # instead of mixing in the standalone `keras` package.
    outputs = multiply([output_mask, inputs_copy])
    model = Model(inputs=[inputs], outputs=[outputs])
    return model



In [None]:
# Input is one magnitude spectrogram: (frames, frequency bins, channel).
model = custom_unet(
    input_shape=(256, 128, 1),
    num_classes=1,
    num_layers=4,
    filters=16,
    use_batch_norm=True,
    dropout=0.2,
    output_activation='sigmoid',
)


def signal_enhancement_loss(y_true, y_pred):
    """Mean absolute error plus a speech-weighted term, each averaged over the last axis.

    The second term, ``2 * |y_true**2 - y_pred*y_true|``, equals ``2 * |y_true| * |y_true - y_pred|``,
    so errors on high-energy bins are penalised more heavily.
    """
    abs_error = tf.abs(y_true - y_pred)
    weighted_error = 2 * tf.abs(y_true**2 - y_pred*y_true)
    mean_abs_error = tf.reduce_mean(abs_error, axis=-1)
    mean_weighted_error = tf.reduce_mean(weighted_error, axis=-1)
    return mean_abs_error + mean_weighted_error

model.compile(loss=signal_enhancement_loss, optimizer='adam')
# Restore pretrained weights (presumably from part one of the project); this notebook does no training.
model.load_weights(filepath='/kaggle/input/speech-mask-model/model_weights_custom_loss2.h5')


In [None]:
# The model is built with tf.keras, so use its public plot_model. keras.utils.vis_utils is an
# internal module path that newer Keras releases do not provide. Requires pydot and graphviz.
tf.keras.utils.plot_model(model, show_shapes=True)

# Inspect Results

In [None]:
# Scalar Keras loss over the whole corrupted set (only a loss is compiled, so no metrics).
# evaluate() batches the clips, and the model normalises by the max of each batch, so this can
# differ from the per-clip losses computed in the next cell.
history = model.evaluate(spectogram_data_corrupted_abs, spectogram_data_clean_abs)

In [None]:
files_to_test = files
n_test = len(files_to_test)
pesq_with_noise = np.zeros(n_test)
pesq_denoised = np.zeros(n_test)

wav_clean_array = np.zeros((n_test, trim_length))
wav_corrupt_array = np.zeros((n_test, trim_length))
wav_correct_array = np.zeros((n_test, trim_length))
spec_clean_array = np.zeros((n_test, 256, 128))
spec_corrupt_array = np.zeros((n_test, 256, 128))
spec_correct_array = np.zeros((n_test, 256, 128))
loss_with_noise = np.zeros(n_test)
loss_denoised = np.zeros(n_test)

# tf.signal.inverse_stft only inverts tf.signal.stft when given the matching synthesis window.
# With its default window_fn the overlapping frames are not renormalised, which distorts the waveform.
istft_window_fn = tf.signal.inverse_stft_window_fn(frame_step)


def _spec_to_wav(spec_2d):
    """Invert one complex (frames, bins) STFT to a 1-D NumPy waveform."""
    return tf.signal.inverse_stft(
        spec_2d, frame_length=frame_length, fft_length=n_fft, frame_step=frame_step,
        window_fn=istft_window_fn,
    ).numpy()


for ind, test_file in enumerate(files_to_test):
    corr = spectogram_data_corrupted[ind]
    clean = spectogram_data_clean[ind]
    corr_wav = _spec_to_wav(corr[:, :, 0])
    clean_wav = _spec_to_wav(clean[:, :, 0])
    corr_amp = np.abs(corr)
    # One clip per call: the model scales its input by the max over the whole batch, so
    # predicting several clips together would change each clip's output.
    corrected_amp = model.predict(np.expand_dims(corr_amp, 0))
    # Re-attach the noisy phase to the predicted magnitude.
    corrected_spec = corrected_amp * np.exp(1j * np.angle(np.expand_dims(corr, 0)))
    corrected_wav = _spec_to_wav(corrected_spec[0, :, :, 0])

    wav_clean_array[ind] = clean_wav
    wav_corrupt_array[ind] = corr_wav
    wav_correct_array[ind] = corrected_wav
    spec_clean_array[ind] = np.abs(clean[:, :, 0])
    spec_corrupt_array[ind] = np.abs(corr[:, :, 0])
    spec_correct_array[ind] = corrected_amp[0, :, :, 0]
    # Both losses use the same (frames, bins) slices so they are directly comparable.
    loss_with_noise[ind] = tf.reduce_mean(
        signal_enhancement_loss(np.abs(clean[:, :, 0]), corr_amp[:, :, 0])).numpy()
    loss_denoised[ind] = tf.reduce_mean(
        signal_enhancement_loss(np.abs(clean[:, :, 0]), corrected_amp[0, :, :, 0])).numpy()

    pesq_with_noise[ind] = pesq(clean_wav, corr_wav, sr)
    pesq_denoised[ind] = pesq(clean_wav, corrected_wav, sr)
    # Keep going on NaN: stopping here would leave the remaining entries at 0 and bias every
    # statistic below. NaN scores are skipped by the nan-aware reductions.
    if np.isnan(pesq_with_noise[ind]) or np.isnan(pesq_denoised[ind]):
        print(f'nan detected for {test_file}; excluded from the PESQ means')

pesq_diff = pesq_denoised - pesq_with_noise
print(np.nanmean(pesq_with_noise), np.nanmean(pesq_denoised))
f'{np.nanmean(pesq_with_noise):.2f}, {np.nanmean(pesq_denoised):.2f}'

In [None]:
import shutil

results_dir = 'results_unseen_test'
# Start from an empty directory. ignore_errors covers the first run, when it does not exist yet,
# without shelling out to `rm`.
shutil.rmtree(results_dir, ignore_errors=True)
os.mkdir(results_dir)

In [None]:
# np.histogram cannot auto-range data containing NaN, so plot only finite PESQ differences.
finite_pesq_diff = pesq_diff[np.isfinite(pesq_diff)]
fig = plt.figure()
plt.title('PESQ improvement')
plt.hist(finite_pesq_diff)
plt.xlabel('PESQ corrected - PESQ corrupted')
plt.ylabel('Number')
fig.savefig(os.path.join(results_dir, 'pesq_hist.png'), bbox_inches='tight')

In [None]:
# Clip with the largest PESQ gain. nanargmax ignores NaN scores, where max()/== would match nothing.
ind = int(np.nanargmax(pesq_diff))
sf.write(os.path.join(results_dir, 'clean_best.wav'), wav_clean_array[ind], sr)
sf.write(os.path.join(results_dir, 'corrupt_best.wav'), wav_corrupt_array[ind], sr)
sf.write(os.path.join(results_dir, 'correct_best.wav'), wav_correct_array[ind], sr)

In [None]:
# Play the clean reference of the selected clip (index `ind`).
Audio(data=wav_clean_array[ind], rate=sr)

In [None]:
# Play the noisy input of the selected clip (index `ind`).
Audio(data=wav_corrupt_array[ind], rate=sr)

In [None]:
# Play the model-denoised version of the selected clip (index `ind`).
Audio(data=wav_correct_array[ind], rate=sr)

In [None]:
# Clip with the largest PESQ gain (NaN scores ignored).
ind = int(np.nanargmax(pesq_diff))
fig, axes = plt.subplots(ncols=3, figsize=(20, 10))
# One colour scale for all panels, saturating at a third of the clean spectrogram's peak.
vmax = spec_clean_array[ind].max() / 3
vmin = 0
panels = [
    ('Ground Truth', spec_clean_array[ind]),
    ('Ground Truth + Noise', spec_corrupt_array[ind]),
    ('Corrected', spec_correct_array[ind]),
]
# Draw directly on the axes returned by subplots instead of re-selecting them via plt.subplot.
for ax, (title, spec) in zip(axes, panels):
    ax.set_title(title)
    image = ax.imshow(spec, origin='lower', vmax=vmax, vmin=vmin)
# A single colour bar beside the 'Corrected' panel.
fig.colorbar(image, ax=axes[-1])
fig.savefig(os.path.join(results_dir, 'best_spec.png'), bbox_inches='tight')

In [None]:
# Clip with the smallest (most negative) PESQ gain. nanargmin ignores NaN scores.
ind = int(np.nanargmin(pesq_diff))
sf.write(os.path.join(results_dir, 'clean_worst.wav'), wav_clean_array[ind], sr)
sf.write(os.path.join(results_dir, 'corrupt_worst.wav'), wav_corrupt_array[ind], sr)
sf.write(os.path.join(results_dir, 'correct_worst.wav'), wav_correct_array[ind], sr)

In [None]:
# Play the clean reference of the selected clip (index `ind`).
Audio(data=wav_clean_array[ind], rate=sr)

In [None]:
# Play the noisy input of the selected clip (index `ind`).
Audio(data=wav_corrupt_array[ind], rate=sr)

In [None]:
# Play the model-denoised version of the selected clip (index `ind`).
Audio(data=wav_correct_array[ind], rate=sr)

In [None]:
# Clip with the smallest PESQ gain (NaN scores ignored).
ind = int(np.nanargmin(pesq_diff))
fig, axes = plt.subplots(ncols=3, figsize=(20, 10))
# One colour scale for all panels, saturating at a third of the clean spectrogram's peak.
vmax = spec_clean_array[ind].max() / 3
vmin = 0
panels = [
    ('Ground Truth', spec_clean_array[ind]),
    ('Ground Truth + Noise', spec_corrupt_array[ind]),
    ('Corrected', spec_correct_array[ind]),
]
# Draw directly on the axes returned by subplots instead of re-selecting them via plt.subplot.
for ax, (title, spec) in zip(axes, panels):
    ax.set_title(title)
    image = ax.imshow(spec, origin='lower', vmax=vmax, vmin=vmin)
# A single colour bar beside the 'Corrected' panel.
fig.colorbar(image, ax=axes[-1])
fig.savefig(os.path.join(results_dir, 'worst_spec.png'), bbox_inches='tight')

In [None]:
# Remove the archive from a previous run, if any. A missing file (e.g. the first run) is not an error.
try:
    os.remove('unseentest_results.tar.gz')
except FileNotFoundError:
    pass

In [None]:
import tarfile

# Pack results_dir into a gzip-compressed tarball, like `tar -czf`, but in-process: failures raise
# instead of becoming an ignored shell exit status.
with tarfile.open('unseentest_results.tar.gz', 'w:gz') as archive:
    archive.add(results_dir)