# SEGAN_OM

A GAN-based filtering method for speech enhancement.

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np
from io import *
import os.path

In [None]:
from IPython.display import clear_output

In [None]:
# Framing of the audio into fixed-length training windows.
WINDOWS_SIZE = 1 << 14  # samples per window (16384)
STRIDE = 0.5  # hop between consecutive windows, as a fraction of WINDOWS_SIZE

# Every WAV file must use this sampling rate (checked when writing the TFRecords).
sampling_rate = 16000
SAMPLING = tf.constant(sampling_rate, shape=(), dtype=tf.int32)

# Model / training hyper-parameters.
KERNEL_SIZE = 31  # (de)convolution kernel width; same as the default `kernel` of downsample/upsample
BATCH_SIZE = 10  # used for loading the data
EPOCHS = 50

### 1) Data pipeline

In [None]:
# Dataset layout: clean reference recordings and their noisy counterparts.
# os.path.join builds the separator for the current OS, so the paths are
# valid on POSIX systems as well as on Windows.
path = os.path.join(".", "Dataset", "clean")
path_noisy = os.path.join(".", "Dataset", "noisy")
files_number = len(os.listdir(path))

In [None]:
# Helpers adapted from the TensorFlow TFRecord tutorial.
# They convert a value to a tf.train.Feature compatible with tf.train.Example.
def _bytes_feature(value):
    """Wrap a single bytes/str value in a tf.train.Feature holding a BytesList.

    Eager tensors are first converted to their NumPy value, because a
    BytesList cannot unpack an EagerTensor directly.
    """
    eager_tensor_type = type(tf.constant(0))
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

def _float_feature(value):
    """Wrap a single float/double in a tf.train.Feature holding a FloatList."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wrap a single bool/enum/int/uint in a tf.train.Feature holding an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

In [None]:
def slice_signal(signal, window_size=2**14, stride=0.5):
    """Cut a signal into fixed-length, possibly overlapping windows.

    Windows start every ``int(window_size * stride)`` samples (the hop).
    Only windows that fit completely inside the signal are kept, so a
    trailing partial window is dropped.

    Arguments:
    signal -- array-like (NumPy array or eager tf.Tensor) whose first axis is time
    window_size -- number of samples per window
    stride -- hop between window starts, as a fraction of window_size

    Returns:
    float32 array of shape (num_windows, window_size, *signal.shape[1:]);
    num_windows is 0 when the signal is shorter than one window.

    Raises:
    ValueError -- if the hop ``int(window_size * stride)`` is smaller than 1.
    """
    samples = np.asarray(signal, dtype=np.float32)
    hop = int(window_size * stride)
    if hop < 1:
        raise ValueError(
            f'window_size * stride must be at least 1 sample, got '
            f'{window_size} * {stride}')
    n_samples = samples.shape[0]
    # Only start positions whose window fits completely inside the signal.
    starts = range(0, n_samples - window_size + 1, hop)
    if len(starts) == 0:
        # Keep the trailing dimensions so callers always get a consistent rank.
        return np.empty((0, window_size) + samples.shape[1:], dtype=np.float32)
    return np.stack([samples[s:s + window_size] for s in starts])

In [None]:
def make_tf_example(track_1, track_2):
    """Pack a clean/noisy pair of raw byte strings into a tf.train.Example.

    track_1 is stored under the 'clean' key and track_2 under 'noisy'.
    """
    clean_feature = _bytes_feature(track_1)
    noisy_feature = _bytes_feature(track_2)
    return tf.train.Example(
        features=tf.train.Features(
            feature={'clean': clean_feature, 'noisy': noisy_feature}))

In [None]:
# Serialise every (clean, noisy) pair of aligned windows into one TFRecord file.
# The noisy version of `<name>.wav` is expected at `<path_noisy>/<name>_CAFE_SNR_0.wav`.
NOISY_SUFFIX = '_CAFE_SNR_0.wav'
record_file = 'new_waves.tfrecords'


def _read_wav_checked(file_path, kind):
    """Decode a WAV file and verify that it uses the expected sampling rate.

    Returns the samples produced by tf.audio.decode_wav (a float32 tensor of
    shape (num_samples, channels)). Raises ValueError if the file's sampling
    rate differs from SAMPLING; `kind` ('clean'/'noisy') is used in the message.
    """
    audio, sample_rate = tf.audio.decode_wav(tf.io.read_file(file_path))
    if int(sample_rate) != int(SAMPLING):
        raise ValueError(
            f'Sampling rate of {kind} is expected to be {int(SAMPLING)}! '
            f'Got {int(sample_rate)} ({file_path})')
    return audio


# Sorted for a reproducible record order; anything that is not a WAV file is skipped.
clean_files = sorted(f for f in os.listdir(path) if f.lower().endswith('.wav'))
with tf.io.TFRecordWriter(record_file) as writer:
    for i, file_name in enumerate(clean_files, start=1):
        name, _ = os.path.splitext(file_name)
        clean = _read_wav_checked(os.path.join(path, file_name), 'clean')
        noisy = _read_wav_checked(os.path.join(path_noisy, name + NOISY_SUFFIX), 'noisy')
        seq_clean = slice_signal(clean, WINDOWS_SIZE, STRIDE)
        seq_noisy = slice_signal(noisy, WINDOWS_SIZE, STRIDE)
        # zip() stops at the shorter sequence: windows stay paired index by
        # index, and surplus windows of the longer recording are dropped.
        for track_clean, track_noisy in zip(seq_clean, seq_noisy):
            # tobytes() yields the same raw buffer as the deprecated tostring().
            tf_example = make_tf_example(track_clean.tobytes(), track_noisy.tobytes())
            writer.write(tf_example.SerializeToString())
        clear_output()
        print(f"file {i} from {len(clean_files)} written")

### 2) Layers & Model

#### 2.1: Build the Generator

In [None]:
def downsample(filter_width, kernel=31,
               strides=2, padding='same', init=None):
    """Build an encoder block: a strided Conv1D (without bias) followed by PReLU.

    Arguments:
    filter_width -- number of output channels (Conv1D ``filters``)
    kernel -- convolution kernel width (Conv1D ``kernel_size``), default 31
    strides -- convolution stride, default 2
    padding -- Conv1D padding mode, default 'same'
    init -- kernel initializer; He-normal is used when None

    Returns:
    block -- tf.keras.Sequential holding the Conv1D and PReLU layers
    """
    initializer = tf.keras.initializers.he_normal() if init is None else init
    conv = tf.keras.layers.Conv1D(
        filters=filter_width,
        kernel_size=kernel,
        strides=strides,
        padding=padding,
        kernel_initializer=initializer,
        use_bias=False,
    )
    return tf.keras.Sequential([conv, tf.keras.layers.PReLU()])

In [None]:
def upsample(filter_width, kernel=31,
             strides=2, padding='same', init=None):
    """Build a decoder block: a strided Conv2DTranspose (without bias) + LeakyReLU.

    The transposed convolution only acts along the first spatial axis
    (kernel ``(kernel, 1)``, strides ``(strides, 1)``), so the block expects a
    4-D input with a dummy width axis; Generator inserts that axis with
    expand_dims before calling the block and squeezes it out afterwards.

    Arguments:
    filter_width -- number of output channels (Conv2DTranspose ``filters``)
    kernel -- kernel length along the time axis, default 31
    strides -- stride along the time axis, default 2
    padding -- padding mode, default 'same'
    init -- kernel initializer; He-normal is used when None

    Returns:
    block -- tf.keras.Sequential holding the Conv2DTranspose and LeakyReLU layers
    """
    initializer = tf.keras.initializers.he_normal() if init is None else init
    deconv = tf.keras.layers.Conv2DTranspose(
        filters=filter_width,
        kernel_size=(kernel, 1),
        strides=(strides, 1),
        padding=padding,
        kernel_initializer=initializer,
        use_bias=False,
    )
    return tf.keras.Sequential([deconv, tf.keras.layers.LeakyReLU()])

In [None]:
def Generator(kernel_size=KERNEL_SIZE, window_size=WINDOWS_SIZE):
    """Build the SEGAN-style encoder/decoder generator.

    The encoder halves the time axis 11 times; the decoder mirrors it and,
    after each upsampling step, concatenates the encoder output of matching
    length (skip connection). A final single-channel upsampling block restores
    the input length.

    Arguments:
    kernel_size -- width of every (de)convolution kernel (default KERNEL_SIZE)
    window_size -- samples per input window (default WINDOWS_SIZE); must be
                   divisible by 2**11 so the skip-connection lengths line up

    Returns:
    tf.keras.Model mapping (batch, window_size, 1) to (batch, window_size, 1)
    """
    encoder_filters = [16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024]
    decoder_filters = [512, 256, 256, 128, 128, 64, 64, 32, 32, 16]

    inputs = tf.keras.layers.Input(shape=[window_size, 1])

    # `kernel` is passed by keyword: the second positional parameter of
    # downsample/upsample is the kernel width, not the sequence length.
    down_stack = [downsample(f, kernel=kernel_size) for f in encoder_filters]
    up_stack = [upsample(f, kernel=kernel_size) for f in decoder_filters]

    x = inputs

    # Downsampling through the model, remembering every encoder output.
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The bottleneck output itself is not used as a skip connection.
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections.
    for up, skip in zip(up_stack, skips):
        # upsample() is built on Conv2DTranspose: add, then remove, a dummy axis.
        x = tf.keras.backend.expand_dims(x, axis=2)
        x = up(x)
        x = tf.keras.backend.squeeze(x, axis=2)
        x = tf.keras.layers.Concatenate()([x, skip])

    # Back to the input length with a single output channel.
    x = tf.keras.backend.expand_dims(x, axis=2)
    x = upsample(1, kernel=kernel_size)(x)
    x = tf.keras.backend.squeeze(x, axis=2)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Instantiate the generator and print its layer-by-layer summary.
generator = Generator()
generator.summary()
#tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)


#### 2.2 Build Discriminator

### 3) Load Data

In [None]:
# Read back the serialized records written by the data-pipeline cell.
raw_waves_dataset = tf.data.TFRecordDataset('new_waves.tfrecords')
# Alternative input file:
#raw_waves_dataset = tf.data.TFRecordDataset('waves.tfrecords')
raw_waves_dataset

In [None]:
def _parse_wave_function(example_proto):
    """Parse one serialized tf.train.Example into its 'clean' and 'noisy' byte strings."""
    schema = {key: tf.io.FixedLenFeature([], tf.string) for key in ('clean', 'noisy')}
    return tf.io.parse_single_example(example_proto, schema)

In [None]:
def _decode_parsed_wave(parsed_wave):
    """Reinterpret the raw 'clean' and 'noisy' byte strings as float32 samples.

    Returns a (clean, noisy) tuple of decoded tensors.
    """
    clean, noisy = (tf.io.decode_raw(parsed_wave[key], tf.float32)
                    for key in ('clean', 'noisy'))
    return clean, noisy

In [None]:
def _correct_dim_wave(dec_clean, dec_noisy):
    """Insert a channel axis at position 1 in both waves, e.g. (N,) -> (N, 1)."""
    def _with_channel(wave):
        return tf.expand_dims(wave, axis=1)

    return _with_channel(dec_clean), _with_channel(dec_noisy)

In [None]:
# Parse every serialized record into its 'clean'/'noisy' byte strings.
parsed_waves_dataset = raw_waves_dataset.map(_parse_wave_function)
parsed_waves_dataset

In [None]:
# Decode the byte strings into float32 sample tensors: (clean, noisy) tuples.
decoded_waves_dataset = parsed_waves_dataset.map(_decode_parsed_wave)
decoded_waves_dataset

In [None]:
# Add a trailing channel axis so each wave matches the generator's [length, 1] input.
corrected_waves_dataset = decoded_waves_dataset.map(_correct_dim_wave)
corrected_waves_dataset

In [None]:
# Shuffle once with a fixed seed and keep that order on every pass
# (reshuffle_each_iteration=False). With the default (True), every iteration
# reshuffles, so dataset.take(n) / dataset.skip(n) would draw different
# windows on each epoch and the training and testing splits would overlap.
dataset = corrected_waves_dataset.shuffle(600, seed=42, reshuffle_each_iteration=False)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
dataset

In [None]:
# Split by batches: the first 3 batches are used for training, the rest for testing.
# NOTE(review): the two splits are only disjoint if `dataset` yields the same
# order on every iteration (i.e. it is not reshuffled per iteration).
training_set = dataset.take(3)
testing_set  = dataset.skip(3)

In [None]:
# Display the training split.
training_set

In [None]:
# Display the testing split.
testing_set

### 4) Train the model

To understand how the generator behaves on its own, we first train it alone for 50 epochs using the L1 loss.

In [None]:
def generator_loss(gen_output, target):
    """Return the L1 loss: mean absolute difference between target and generator output."""
    abs_error = tf.abs(target - gen_output)
    return tf.reduce_mean(abs_error)

In [None]:
# Adam optimizer for the generator: learning rate 2e-4, beta_1 = 0.5.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Checkpoints (generator weights + optimizer state) are written to
# ./training_checkpoints with the file prefix "ckpt".
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 generator=generator)

In [None]:
import datetime

# TensorBoard event files go to logs/fit/<YYYYmmdd-HHMMSS>, one folder per run.
log_dir = "logs/"

run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
summary_writer = tf.summary.create_file_writer(f"{log_dir}fit/{run_stamp}")

In [None]:
@tf.function
def train_step(input_wave, target, epoch):
    """Run one L1 optimisation step of the generator on a batch.

    Arguments:
    input_wave -- noisy input batch fed to the generator
    target -- clean reference batch the output is compared with
    epoch -- step value for the TensorBoard scalar; pass a tensor rather
             than a Python int to avoid retracing for every new value

    Returns:
    the batch L1 loss (scalar tensor)
    """
    with tf.GradientTape() as gen_tape:
        gen_output = generator(input_wave, training=True)
        gen_l1_loss = generator_loss(gen_output, target)

    # Gradients are computed and applied outside the tape context so that
    # these ops are not themselves recorded on the tape.
    generator_gradients = gen_tape.gradient(gen_l1_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))

    with summary_writer.as_default():
        tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)
    return gen_l1_loss

In [None]:
def fit(train_ds, epochs, test_ds, checkpoint_every=5):
    """Train the generator alone with the L1 loss.

    Arguments:
    train_ds -- dataset yielding (clean, noisy) batches; clean is the target
    epochs -- number of passes over train_ds
    test_ds -- currently unused; kept for interface compatibility
    checkpoint_every -- save a checkpoint every this many epochs (default 5);
                        a final checkpoint is written after the last epoch if
                        it was not already saved

    Raises:
    ValueError -- if checkpoint_every is smaller than 1.
    """
    import time

    if checkpoint_every < 1:
        raise ValueError(f'checkpoint_every must be >= 1, got {checkpoint_every}')

    for epoch in range(epochs):
        start = time.time()
        # A tensor (instead of a Python int) keeps the tf.function train_step
        # from being retraced for every new epoch value.
        step = tf.constant(epoch, dtype=tf.int64)
        for n, (target, input_wave) in enumerate(train_ds):
            print(f'epoch {epoch}, batch {n}')
            train_step(input_wave, target, step)
            clear_output()
        if (epoch + 1) % checkpoint_every == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
        print(f'Time taken for epoch {epoch + 1} is {time.time() - start} sec\n')

    # Make sure the final weights are always saved exactly once.
    if epochs > 0 and epochs % checkpoint_every != 0:
        checkpoint.save(file_prefix=checkpoint_prefix)

In [None]:
# Load the TensorBoard notebook extension and display the logs written under log_dir.
%load_ext tensorboard
%tensorboard --logdir {log_dir}

In [None]:
import time

In [None]:
# Train the generator on the training split for EPOCHS epochs.
fit(training_set, EPOCHS, testing_set) 