# SEGAN_OM

A GAN-based filtering method for speech enhancement.

### 1) Data pipeline

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np
import scipy.io.wavfile as wavfile
from io import *
from pathlib import Path
import os.path

In [None]:
# Directory holding the clean-speech WAV files.
# Forward slashes avoid backslash escape sequences in the literal ('\D', '\c'
# are invalid escapes) and are accepted by pathlib on every OS.
p = Path('./Dataset/clean')
# Show one file as a quick check that the directory is populated.
for file in p.glob('*.wav'):
    print(file)
    break
else:
    print(f'No .wav files found in {p.resolve()}')

In [None]:
# Temporary sanity check (delete once the pipeline is validated):
# print the sampling rate, the samples and the sample count of the first
# clean WAV file found under `p`.
for file in p.glob('*.wav'):
    fm, data = wavfile.read(file)
    header = f"fm: {fm}, wav_data: {data}, len_data: "
    print(header, len(data))
    break

In [None]:
# read each file and 
def read_and_slice(filename, wav_canvas_size, stride=0.5):
    fm, wav_data = wavfile.read(filename)
    if fm != 16000:
        raise ValueError('Sampling rate is expected to be 16kHz!')
    signals = slice_signal(wav_data, wav_canvas_size, stride)
    return signals

In [2]:
def slice_signal(signal, window_size, stride=0.5):
    """Cut a 1-D signal into fixed-length windows.

    Windows are ``window_size`` samples long and start every
    ``int(window_size * stride)`` samples, so ``stride=1`` yields consecutive
    non-overlapping windows and ``stride=0.5`` yields 50% overlap. If the
    signal runs out, the last window is zero-padded at its end; slicing stops
    as soon as a window reaches the end of the signal.

    Args:
        signal: 1-D array of samples.
        window_size: Window length in samples.
        stride: Hop between window starts, as a fraction of ``window_size``.

    Returns:
        ``np.int32`` array of shape ``(n_windows, window_size)``;
        ``n_windows`` is 0 for an empty signal. NOTE: the int32 cast
        truncates float-valued samples.

    Raises:
        AssertionError: If ``signal`` is not 1-D.
        ValueError: If the window or the hop would be shorter than 1 sample.
    """
    assert signal.ndim == 1, signal.ndim
    window_size = int(window_size)
    n_samples = signal.shape[0]
    # Hop between consecutive window starts, in samples.
    offset = int(window_size * stride)
    if window_size < 1 or offset < 1:
        raise ValueError(
            f'window_size and hop must be at least 1 sample, got '
            f'window_size={window_size}, stride={stride}')
    slices = []
    for beg_i in range(0, n_samples, offset):
        end_i = beg_i + window_size
        slice_ = signal[beg_i:end_i]
        missing = window_size - slice_.shape[0]
        if missing > 0:
            # Zero-pad the tail window; np.pad keeps the signal's dtype.
            slice_ = np.pad(slice_, (0, missing), mode='constant')
        slices.append(slice_)
        if end_i >= n_samples:
            # This window already reaches the end of the signal; any later
            # start would only repeat its samples followed by padding.
            break
    if not slices:
        return np.empty((0, window_size), dtype=np.int32)
    return np.array(slices, dtype=np.int32)

In [None]:
# Slice every clean WAV under `p` into 2**14-sample windows (stride=1, i.e.
# consecutive non-overlapping windows), printing each result's shape.
signals_clean = []
for file in p.glob('*.wav'):
    audio_serial = read_and_slice(file, 2 ** 14, stride=1)
    signals_clean.append(audio_serial)
    print(audio_serial.shape)

In [3]:
# Adapted from the TensorFlow "TFRecord and tf.train.Example" tutorial.
# The following functions convert a value to a type compatible with
# tf.train.Example.

def _bytes_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` in a ``bytes_list``.

    Accepts ``bytes``, an eager tensor (unwrapped with ``.numpy()``), or a
    NumPy array with a non-object dtype, which is stored as its raw buffer
    via ``ndarray.tobytes()``. The array's dtype and shape are NOT recorded,
    so they must be known when decoding.
    """
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    if isinstance(value, np.ndarray) and value.dtype != object:
        # BytesList only accepts bytes; object arrays are left alone because
        # their buffer holds pointers, not data.
        value = value.tobytes()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Wrap a single float/double in a ``tf.train.Feature`` (``float_list``)."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wrap a single bool/enum/int/uint in a ``tf.train.Feature`` (``int64_list``)."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

In [None]:
# Number of files sliced into `signals_clean`.
print(f"{len(signals_clean)}")

In [None]:
# Output file for the serialized clean-speech examples.
# Forward slashes are required here: in a backslash literal such as
# '.\Dataset\records', "\r" is a carriage-return escape, not a separator.
out_filepath = Path('./Dataset/records')
# Make sure the target directory exists before the writer opens the file.
out_filepath.parent.mkdir(parents=True, exist_ok=True)
out_file = tf.io.TFRecordWriter(str(out_filepath))

In [None]:
# Serialize each file's window array (from `signals_clean`) as one
# tf.train.Example under the 'wav_raw' feature.
for wav in signals_clean:
    # ndarray.tostring() is deprecated; tobytes() returns the same raw bytes.
    wav_raw = wav.tobytes()
    example = tf.train.Example(features=tf.train.Features(feature={
        'wav_raw': _bytes_feature(wav_raw)}))
    out_file.write(example.SerializeToString())
# Close the writer so buffered records are flushed and the file is complete
# (re-open it in the previous cell before running this one again).
out_file.close()

In [None]:
def make_Record(path, file, wav_canvas_size=2 ** 14, stride=1):
    """Serialize one WAV file into its own ``<name>.tfrecords`` file.

    The WAV is read and cut into windows with ``read_and_slice``; the
    resulting window array is stored as raw bytes under the ``'wav_raw'``
    feature of a single ``tf.train.Example``.

    Args:
        path: Output directory; created if it does not exist.
        file: Path of the input ``.wav`` file.
        wav_canvas_size: Window length in samples (default ``2 ** 14``).
        stride: Hop between windows as a fraction of ``wav_canvas_size``
            (default ``1``: consecutive windows).

    Returns:
        The path of the written ``.tfrecords`` file.

    Raises:
        ValueError: If the output file already exists, or (from
            ``read_and_slice``) the WAV is not sampled at 16 kHz.
    """
    os.makedirs(path, exist_ok=True)
    # Keep only the base name so 'Dataset/clean/a.wav' maps to
    # '<path>/a.tfrecords' instead of re-nesting the input directories.
    name, _ = os.path.splitext(os.path.basename(file))
    output_file = os.path.join(path, name + '.tfrecords')
    if os.path.exists(output_file):
        raise ValueError(f'ERROR: {output_file} already exists')
    # Read and validate before opening the writer so a bad input does not
    # leave an empty record file behind.
    audio_serial = read_and_slice(file, wav_canvas_size, stride=stride)
    example = tf.train.Example(features=tf.train.Features(feature={
        # BytesList needs raw bytes, not an ndarray.
        'wav_raw': _bytes_feature(audio_serial.tobytes()),
    }))
    with tf.io.TFRecordWriter(output_file) as writer:
        writer.write(example.SerializeToString())
    return output_file

In [None]:
# Slice every clean WAV in `path` into 2**14-sample windows (stride=1).
path = "./Dataset/clean"
signals_clean = []
# Sort for a deterministic order; os.listdir returns entries arbitrarily.
for file in sorted(os.listdir(path)):
    if not file.endswith('.wav'):
        continue  # skip entries that are not WAV files
    # os.listdir yields bare names: join with the directory so the file is
    # found regardless of the current working directory.
    audio_serial = read_and_slice(os.path.join(path, file), 2 ** 14, stride=1)
    print(audio_serial.shape)
    signals_clean.append(audio_serial)

In [None]:
# Rebuild `signals_clean` from the WAV files under `p`: each entry holds one
# file cut into 2**14-sample windows with stride=1, and its shape is printed.
signals_clean = []
for file in p.glob('*.wav'):
    audio_serial = read_and_slice(file, 2 ** 14, stride=1)
    signals_clean.append(audio_serial)
    print(audio_serial.shape)