## Description

##### Implementation of "Speech enhancement by LSTM-based noise suppression followed by CNN-based speech restoration" paper - https://link.springer.com/article/10.1186/s13634-020-00707-1

##### Implementation of individual model layers

### Libraries

In [None]:
import tensorflow as tf
import numpy as np
import typing

In [None]:
# Raise the TF1-compat logger threshold to ERROR for the rest of the notebook.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

### Input

##### Let's assume we have an input of shape [batch_size, sequence_length, nr_features]

In [None]:
# Random stand-in for a real batch, laid out as
# [batch_size, sequence_length, nr_features].
input_shape = (1, 1024, 2)
x = tf.constant(tf.random.normal(input_shape), dtype=tf.float32)

In [None]:
# Display the input tensor as the cell output.
x

### Layers

##### Set up individual layers and test that they work

#### Dense

In [None]:
class Dense(tf.Module):
    """Fully connected layer computing ``x_in @ w + b``.

    The variables are created lazily on the first call, so the number of
    input features is inferred from the last dimension of the first input.
    """

    def __init__(self, out_features, name=None):
        """Store the layer configuration; no variables are created here.

        Args:
            out_features: Size of the last dimension of the layer output.
            name: Optional module name forwarded to ``tf.Module``.
        """
        super().__init__(name=name)
        self.is_built = False  # is built flag for dynamic input size inference
        self.out_features = out_features

    def __call__(self, x_in):
        """Apply the affine transformation to ``x_in``.

        Args:
            x_in: Input tensor whose last dimension holds the features.

        Returns:
            The result of ``tf.matmul(x_in, w) + b``; its last dimension has
            size ``out_features``.
        """
        if not self.is_built:
            # Size the weights from the argument itself, not from a notebook
            # global, so the layer works for any input it is first called on.
            in_features = x_in.shape[-1]
            self.w = tf.Variable(
                tf.random.normal([in_features, self.out_features]), name='w')
            self.b = tf.Variable(tf.zeros([self.out_features]), name='b')
            self.is_built = True

        x_hat = tf.matmul(x_in, self.w) + self.b
        return x_hat

In [None]:
# 425 output features; the weights are not created until the first call.
dense_layer = Dense(out_features=425)

In [None]:
# First call builds the layer's variables; the last expression displays the
# shape of the result as the cell output.
dense_layer_output = dense_layer(x)
dense_layer_output.shape

#### LSTM

In [None]:
class LSTM(tf.Module):
    """Stack of TF1-compat ``LSTMCell`` layers run through ``dynamic_rnn``."""

    def __init__(self,
                 num_units: typing.Union[typing.List[int], int],
                 use_peepholes: bool,
                 name: str = None):
        """Create one ``LSTMCell`` per entry of ``num_units`` and stack them.

        Args:
            num_units: Units per layer; a single int means one layer.
            use_peepholes: Forwarded to every ``LSTMCell``.
            name: Optional module name forwarded to ``tf.Module``.
        """
        super(LSTM, self).__init__(name)

        # A bare int describes a single layer; normalise it to a list so the
        # rest of the constructor only has to deal with one representation.
        self.num_units = [num_units] if isinstance(num_units, int) else num_units
        self.use_peepholes = use_peepholes

        cells = []
        for layer_size in self.num_units:
            cell = tf.compat.v1.nn.rnn_cell.LSTMCell(
                num_units=layer_size,
                use_peepholes=self.use_peepholes)
            cells.append(cell)
        self.lstm_layers = cells

        self.multi_lstm_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(self.lstm_layers)

    def __call__(self, x_in):
        """Run the stacked cell over ``x_in`` and return ``dynamic_rnn``'s result."""
        rnn_result = tf.compat.v1.nn.dynamic_rnn(
            cell=self.multi_lstm_cell,
            inputs=x_in,
            dtype=tf.float32)
        return rnn_result

In [None]:
# A single layer of 425 units (the int is wrapped into a one-element list)
# with peephole connections enabled.
lstm_layer = LSTM(num_units=425, use_peepholes=True)

In [None]:
# The layer returns dynamic_rnn's result, unpacked here into outputs and
# state; the last expression displays the output shape.
lstm_layer_output, lstm_layer_state = lstm_layer(x_in=dense_layer_output)
lstm_layer_output.shape

#### Conv1D

In [None]:
class Conv1D(tf.Module):
    """1-D convolution with "SAME" padding and lazily created filters."""

    def __init__(self,
               nr_filters: int,
               kernel: int,
               stride: int,
               use_bias: bool,
               name = None
              ):
        """Store the configuration; variables are created on the first call.

        Args:
            nr_filters: Number of output channels.
            kernel: Width of the convolution kernel.
            stride: Stride passed to ``tf.nn.conv1d``.
            use_bias: Whether a bias vector is added to the convolution.
            name: Optional module name forwarded to ``tf.Module``.
        """
        super(Conv1D, self).__init__(name)

        self.nr_filters = nr_filters
        self.kernel = kernel
        self.stride = stride
        self.use_bias = use_bias

        # Filled in by the first call, once the input channel count is known.
        self.W: tf.Tensor = None
        self.b: tf.Tensor = None
        self.is_built: bool = False

    def _build(self, in_channels):
        """Create the filter (and optional bias) variables for ``in_channels``."""
        self.W = tf.Variable(
            tf.random.normal((self.kernel, in_channels, self.nr_filters), stddev=0.1),
            trainable=True,
            dtype=tf.float32,
            name="conv1d_filters",
        )
        if self.use_bias:
            self.b = tf.Variable(tf.random.normal([self.nr_filters]))
        self.is_built = True

    def __call__(self, x_in):
        """Convolve ``x_in`` and, if configured, add the bias.

        Args:
            x_in: Input tensor; its last dimension is taken as the channel
                count when the variables are first created.

        Returns:
            The ``tf.nn.conv1d`` output, with the bias added when
            ``use_bias`` is true.
        """
        if not self.is_built:
            self._build(x_in.shape[-1])

        convolved = tf.nn.conv1d(
            input=x_in,
            filters=self.W,
            stride=self.stride,
            padding="SAME",
        )
        if not self.use_bias:
            return convolved
        return tf.add(convolved, self.b, name="conv1d_layer_with_bias")

In [None]:
# 88 filters, kernel width 24, stride 1, with a bias term.
conv1d_layer = Conv1D(nr_filters=88, kernel=24, stride=1, use_bias=True)

In [None]:
# Apply the convolution to the raw input; the filters are created on this
# first call and the result is shown as the cell output.
conv1d_layer(x_in=x)

#### MaxPooling1D

In [None]:
# Keras max-pooling layer with a pool size of 2.
max_pooling_1d = tf.keras.layers.MaxPool1D(pool_size=2)

In [None]:
# Apply the pooling layer to the input tensor and display the result.
max_pooling_1d(x)

#### Conv1D Transpose 

In [None]:
# Keras transposed 1-D convolution: 2*88 = 176 filters, kernel size 24,
# stride 1, with a bias term.
conv1d_transpose_layer = tf.keras.layers.Conv1DTranspose(filters=2*88, kernel_size=24, strides=1, use_bias=True)

In [None]:
# Display the layer object itself (it has not been applied to any input yet).
conv1d_transpose_layer

In [None]:
# Apply the transposed convolution to the input tensor and display the result.
conv1d_transpose_layer(x)

#### Upsampling

In [None]:
# Keras 1-D upsampling layer with size=2, applied to the input tensor; the
# result is shown as the cell output.
upsampling_layer = tf.keras.layers.UpSampling1D(size=2)
upsampling_layer(x)