## Description

##### Implementation of "Speech enhancement by LSTM-based noise suppression followed by CNN-based speech restoration" paper - https://link.springer.com/article/10.1186/s13634-020-00707-1

### Libraries

In [None]:
import tensorflow as tf
import numpy as np
import typing

### Input

##### Assume an input tensor of shape [batch_size, sequence_length, nr_features]

In [None]:
# Dummy input of shape [batch_size, sequence_length, nr_features].
input_shape = (1, 1024, 2)
# tf.random.normal already returns a tensor, so no extra tf.constant conversion is needed;
# the dtype is stated explicitly to keep the float32 intent visible.
x = tf.random.normal(input_shape, dtype=tf.float32)

In [None]:
x  # display the dummy input tensor

### Layers

##### Set up the individual layers and run each one on the test input to check that it works

#### Dense

In [None]:
class Dense(tf.Module):
    """Fully connected layer computing ``x_in @ w + b``.

    ``w`` and ``b`` are created lazily on the first call, so the input
    feature size is inferred from the last axis of the first input seen.
    ``w`` is drawn from ``tf.random.normal`` and ``b`` starts at zeros.

    Args:
        out_features: number of output features (size of the last axis of
            the result).
        name: optional module name forwarded to ``tf.Module``.
    """

    def __init__(self, out_features, name=None):
        super().__init__(name=name)
        self.is_built = False  # flipped once w and b exist (dynamic input-size inference)
        self.out_features = out_features

    def __call__(self, x_in):
        """Apply the layer to ``x_in``.

        Args:
            x_in: input tensor whose last axis holds the input features.

        Returns:
            ``tf.matmul(x_in, w) + b``: same leading axes as ``x_in``, last
            axis of size ``out_features``.
        """
        if not self.is_built:
            # Size the weights from the argument itself. The previous version
            # read the notebook-global ``x`` here and in the matmul, so every
            # Dense layer silently operated on the raw input instead of its argument.
            self.w = tf.Variable(
                tf.random.normal([x_in.shape[-1], self.out_features]), name='w')
            self.b = tf.Variable(tf.zeros([self.out_features]), name='b')
            self.is_built = True

        return tf.matmul(x_in, self.w) + self.b

In [None]:
# Standalone Dense layer with 425 output features, used only to test the implementation.
dense_layer = Dense(out_features=425)

In [None]:
# Forward pass through the standalone Dense layer; display the output shape.
dense_layer_output = dense_layer(x)
dense_layer_output.shape

#### LSTM

In [None]:
class LSTM(tf.Module):
    """Stack of ``tf.compat.v1`` LSTM cells unrolled with ``dynamic_rnn``.

    Args:
        num_units: hidden size of a single cell, or a list with one hidden
            size per stacked cell.
        use_peepholes: whether every ``LSTMCell`` uses peephole connections.
        name: optional module name forwarded to ``tf.Module``.
    """

    def __init__(self,
                 num_units: typing.Union[typing.List[int], int],
                 use_peepholes: bool,
                 name: str = None):
        super(LSTM, self).__init__(name)

        # Normalise a bare int to a one-element list so the stack below
        # is built the same way for both argument forms.
        self.num_units = [num_units] if isinstance(num_units, int) else num_units
        self.use_peepholes = use_peepholes

        cells = []
        for hidden_size in self.num_units:
            cells.append(
                tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=hidden_size,
                                                  use_peepholes=self.use_peepholes))
        self.lstm_layers = cells
        self.multi_lstm_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(self.lstm_layers)

    def __call__(self, x_in):
        """Run the stacked cells over ``x_in``.

        Returns:
            The pair returned by ``tf.compat.v1.nn.dynamic_rnn``; callers in
            this notebook unpack it as ``(outputs, state)``.
        """
        outputs_and_state = tf.compat.v1.nn.dynamic_rnn(
            cell=self.multi_lstm_cell, inputs=x_in, dtype=tf.float32)
        return outputs_and_state

In [None]:
# Single-cell peephole LSTM with 425 units, matching the Dense test layer's width.
lstm_layer = LSTM(num_units=425, use_peepholes=True)

In [None]:
# Feed the Dense output through the LSTM; the module returns an (outputs, state) pair.
lstm_layer_output, lstm_layer_state = lstm_layer(x_in=dense_layer_output)
lstm_layer_output.shape

#### Conv1D

In [None]:
class Conv1D(tf.Module):
    """1-D convolution with ``SAME`` padding and lazily created weights.

    The filter bank has shape ``(kernel, in_channels, nr_filters)``, where
    ``in_channels`` is read from the last axis of the first input seen.

    Args:
        nr_filters: number of output channels.
        kernel: filter width.
        stride: convolution stride.
        use_bias: add a learned per-channel bias after the convolution.
        name: optional module name forwarded to ``tf.Module``.
    """

    def __init__(self,
                 nr_filters: int,
                 kernel: int,
                 stride: int,
                 use_bias: bool,
                 name=None):
        super(Conv1D, self).__init__(name)

        self.nr_filters = nr_filters
        self.kernel = kernel
        self.stride = stride
        self.use_bias = use_bias

        self.is_built: bool = False

        # Filled in by _create_variables on the first call.
        self.W: tf.Tensor = None
        self.b: tf.Tensor = None

    def _create_variables(self, in_channels):
        """Create the filter bank (and bias, if enabled) for ``in_channels`` inputs."""
        self.W = tf.Variable(
            tf.random.normal((self.kernel, in_channels, self.nr_filters), stddev=0.1),
            trainable=True,
            dtype=tf.float32,
            name="conv1d_filters")
        if self.use_bias:
            # NOTE(review): the bias uses a stddev-1 normal init while the filters
            # use 0.1; zeros is the usual choice. Confirm this is intentional.
            self.b = tf.Variable(tf.random.normal([self.nr_filters]))
        self.is_built = True

    def __call__(self, x_in):
        """Convolve ``x_in`` (channels on the last axis) and optionally add the bias."""
        if not self.is_built:
            self._create_variables(x_in.shape[-1])

        convolved = tf.nn.conv1d(
            input=x_in, filters=self.W, stride=self.stride, padding="SAME")
        if not self.use_bias:
            return convolved
        return tf.add(convolved, self.b, name="conv1d_layer_with_bias")

In [None]:
# Standalone Conv1D layer: 88 filters, kernel size 24, stride 1, with bias.
conv1d_layer = Conv1D(nr_filters=88, kernel=24, stride=1, use_bias=True)

In [None]:
# Run the Conv1D layer on the dummy input (SAME padding, stride 1).
conv1d_layer(x_in=x)

#### MaxPooling1D

In [None]:
# Max-pool over windows of 2 time steps.
max_pooling_1d = tf.keras.layers.MaxPool1D(pool_size=2)

In [None]:
# Apply the pooling layer to the dummy input.
max_pooling_1d(x)

#### Conv1D Transpose 

In [None]:
# Transposed 1-D convolution with 2*88 filters and kernel size 24.
# NOTE(review): unlike the transposed convs in SpeechRestorationNetworkBlock2, no
# padding="same" is passed here, so the output length presumably differs from the input length.
conv1d_transpose_layer = tf.keras.layers.Conv1DTranspose(filters=2*88, kernel_size=24, strides=1, use_bias=True)

In [None]:
# Display the layer object.
conv1d_transpose_layer

In [None]:
# Apply the transposed convolution to the dummy input.
conv1d_transpose_layer(x)

#### Upsampling

In [None]:
# Repeat each time step size=2 times along the sequence axis, then apply it to the dummy input.
upsampling_layer = tf.keras.layers.UpSampling1D(size=2)
upsampling_layer(x)

### Noise Suppression Module

##### The noise suppression module uses feed-forward (FF, i.e. dense) and LSTM layers

In [None]:
class NoiseSuppressor(tf.Module):
    """LSTM-based noise suppression stage: Dense -> LSTM -> LSTM -> Dense -> Dense -> Dense.

    All hidden layers are 425 units wide. The first three dense layers use
    ReLU, and the final dense layer (``output_size`` units) uses tanh, so
    outputs lie in (-1, 1).

    Args:
        output_size: width of the final dense layer.
        name: optional module name forwarded to ``tf.Module``.
    """

    def __init__(self,
                 output_size: int,
                 name: str = None):
        super(NoiseSuppressor, self).__init__(name)

        hidden = 425
        self.dense_layer_1 = Dense(out_features=hidden)
        self.lstm_layer_1 = LSTM(num_units=hidden, use_peepholes=True)
        self.lstm_layer_2 = LSTM(num_units=hidden, use_peepholes=True)
        self.dense_layer_2 = Dense(out_features=hidden)
        self.dense_layer_3 = Dense(out_features=hidden)
        self.dense_layer_4 = Dense(out_features=output_size)

    def __call__(self, x_in: tf.Tensor):
        """Map ``x_in`` to the suppressed representation (tanh-bounded)."""
        h = tf.nn.relu(self.dense_layer_1(x_in=x_in))

        # Two stacked recurrent layers; only the outputs are kept, the states are discarded.
        for recurrent in (self.lstm_layer_1, self.lstm_layer_2):
            h, _ = recurrent(x_in=h)

        for feed_forward in (self.dense_layer_2, self.dense_layer_3):
            h = tf.nn.relu(feed_forward(x_in=h))

        return tf.math.tanh(self.dense_layer_4(x_in=h))

In [None]:
# output_size=500 sets the width of the suppressor's final dense layer.
noise_suppressor = NoiseSuppressor(output_size=500)

In [None]:
# Run the noise suppressor on the dummy input; this result is later fed to restoration block 1.
test = noise_suppressor(x_in=x)

In [None]:
# Display the noise-suppressor output.
test

### Speech Restoration Network

##### The speech restoration module uses conv1d, conv1d transpose, max pooling and upsampling layers

###### Below is the class for the first block in figure 3 of the paper

In [None]:
class SpeechRestorationNetworkBlock1(tf.Module):
    """Restoration block with max-pool downsampling and upsampling (figure 3 of the paper).

    Layout: two conv layers, max-pool, two conv layers (2x filters), max-pool,
    one conv layer, upsample, two conv layers (2x filters) plus a skip add,
    upsample, two conv layers plus a skip add, and a final 2-filter conv.
    Every conv is followed by leaky ReLU.

    Args:
        nr_conv_filters: base filter count (inner layers use twice this).
        conv_filter_size: kernel width shared by all conv layers.
        name: optional module name forwarded to ``tf.Module``.
    """

    def __init__(self,
                 nr_conv_filters: int,
                 conv_filter_size: int,
                 name: str = None
                 ):
        super(SpeechRestorationNetworkBlock1, self).__init__(name)

        def conv(filters):
            # All convs in this block share kernel width, stride 1 and a bias.
            return Conv1D(nr_filters=filters, kernel=conv_filter_size, stride=1, use_bias=True)

        base, wide = nr_conv_filters, 2 * nr_conv_filters

        self.conv1D_1 = conv(base)
        self.conv1D_2 = conv(base)

        self.max_pooling_1 = tf.keras.layers.MaxPool1D(pool_size=2)
        self.conv1D_3 = conv(wide)
        self.conv1D_4 = conv(wide)

        self.max_pooling_2 = tf.keras.layers.MaxPool1D(pool_size=2)

        self.conv1D_5 = conv(base)
        self.upsampling_1 = tf.keras.layers.UpSampling1D(size=2)

        self.conv1D_6 = conv(wide)
        self.conv1D_7 = conv(wide)

        self.upsampling_2 = tf.keras.layers.UpSampling1D(size=2)
        self.conv1D_8 = conv(base)
        self.conv1D_9 = conv(base)

        self.conv1D_last = conv(2)

    def __call__(self, x_in: tf.Tensor):
        """Restore ``x_in`` and return a 2-channel, leaky-ReLU-activated output.

        NOTE(review): the skip additions appear to require the sequence length
        to be divisible by 4 (two pool/upsample stages); confirm.
        """
        act = tf.nn.leaky_relu

        # Encoder stage 1; its activation is reused as a skip connection.
        skip_outer = act(self.conv1D_2(act(self.conv1D_1(x_in))))
        # Encoder stage 2 after the first pooling; also reused as a skip connection.
        skip_inner = act(self.conv1D_4(act(self.conv1D_3(self.max_pooling_1(skip_outer)))))

        # Bottleneck after the second pooling, then the first upsampling.
        h = act(self.conv1D_5(self.max_pooling_2(skip_inner)))
        h = self.upsampling_1(h)
        h = act(self.conv1D_7(act(self.conv1D_6(h))))

        # Merge the inner skip, upsample back to the input length, merge the outer skip.
        h = self.upsampling_2(tf.add(h, skip_inner))
        h = act(self.conv1D_9(act(self.conv1D_8(h))))
        return act(self.conv1D_last(tf.add(h, skip_outer)))

In [None]:
# Block 1 with 88 base filters (176 in the inner layers) and kernel width 24.
speech_restoration_network_block1 = SpeechRestorationNetworkBlock1(nr_conv_filters=88, conv_filter_size=24)

In [None]:
# Feed the noise-suppressor output into restoration block 1.
block1_output = speech_restoration_network_block1(x_in=test)

###### Below is the class for the first block in figure 4 of the paper

In [None]:
class SpeechRestorationNetworkBlock2(tf.Module):
    """Restoration block using strided convs and transposed convs (figure 4 of the paper).

    Downsampling is done by stride-2 convolutions and upsampling by stride-2
    transposed convolutions (``padding="same"``), with two skip additions.
    Every layer is followed by leaky ReLU.

    Args:
        nr_conv_filters: base filter count (inner layers use twice this).
        conv_filter_size: kernel width shared by all conv / transposed-conv layers.
        name: optional module name forwarded to ``tf.Module``.
    """

    def __init__(self,
                 nr_conv_filters: int,
                 conv_filter_size: int,
                 name: str = None
                 ):
        super(SpeechRestorationNetworkBlock2, self).__init__(name)

        def conv(filters, stride):
            return Conv1D(nr_filters=filters, kernel=conv_filter_size, stride=stride, use_bias=True)

        def deconv(filters, stride):
            # padding must be "same" so a stride-1 transposed conv keeps the length
            # and a stride-2 one doubles it.
            return tf.keras.layers.Conv1DTranspose(filters=filters, kernel_size=conv_filter_size,
                                                   strides=stride, padding="same")

        base, wide = nr_conv_filters, 2 * nr_conv_filters

        # Creation order matches the original so Keras auto-generated layer names are unchanged.
        self.conv1D_1 = conv(base, 1)
        self.conv1D_2 = conv(base, 2)

        self.conv1D_3 = conv(wide, 1)
        self.conv1D_4 = conv(wide, 2)

        self.conv1D_5 = conv(base, 1)
        self.conv1D_transpose_1 = deconv(wide, 2)

        self.conv1D_transpose_2 = deconv(wide, 1)
        self.conv1D_transpose_3 = deconv(base, 2)

        self.conv1D_transpose_4 = deconv(base, 1)

        self.conv1D_last = conv(2, 1)

    def __call__(self, x_in: tf.Tensor):
        """Restore ``x_in`` and return a 2-channel, leaky-ReLU-activated output.

        NOTE(review): the skip additions presumably require the sequence length
        to be divisible by 4 (two stride-2 stages); confirm.
        """
        act = tf.nn.leaky_relu

        skip_outer = act(self.conv1D_1(x_in))
        h = act(self.conv1D_2(skip_outer))            # stride-2 downsample
        skip_inner = act(self.conv1D_3(h))
        h = act(self.conv1D_4(skip_inner))            # stride-2 downsample
        h = act(self.conv1D_5(h))

        h = act(self.conv1D_transpose_1(h))           # stride-2 upsample
        h = act(self.conv1D_transpose_2(tf.add(h, skip_inner)))
        h = act(self.conv1D_transpose_3(h))           # stride-2 upsample
        h = act(self.conv1D_transpose_4(tf.add(h, skip_outer)))

        return act(self.conv1D_last(h))

In [None]:
# Block 2 uses the same base filter count and kernel width as block 1.
speech_restoration_network_block2 = SpeechRestorationNetworkBlock2(nr_conv_filters=88, conv_filter_size=24)

In [None]:
# Run restoration block 2 and display its output.
speech_restoration_network_block2(x_in=block1_output) # input to block 2 is output from block 1