In [1]:
from tensorflow.keras import layers, optimizers, Model, callbacks, initializers, utils
from tensorflow.keras.constraints import max_norm
from tensorflow import keras
import tensorflow.keras.backend as K
import tensorflow as tf
from pydub import AudioSegment
import matplotlib.pyplot as plt
import threading
import soundfile as sf
from scipy.io import wavfile
import numpy as np
from tqdm import tqdm
from pathlib import Path
from scipy import signal
import subprocess
import shutil
import time
import gc
import random
import os

DATA_PATH = r"..\\data"
TRAIN_DATASET_PATH = r"..\\train_dataset"
TEST_DATASET_PATH = r"..\\test_dataset"
VAL_SPLIT = 0.2

MAX_SAMPLE_FREQUENCY = 256_000

FRAME_SIZE = 4096
FRAME_OVERLAP = 2048

MAXIMUM_NUMBER_OF_NOISE_CHANNELS = 8
NOISE_APLITUDE_MAX = 4
NOISE_FREQ_MAX = 48_000

PRETRAIN_EPOCHS = 2
EPOCHS = 40
STEPES_PER_EPOCH = 5_000
BATCH_SIZE = 32

frame_shift = FRAME_SIZE - FRAME_OVERLAP

Loading data

In [2]:
def find_files(files, dirs=[], extensions=[]):
    new_dirs = []
    for d in dirs:
        if not os.path.exists(d): continue
        
        try:
            new_dirs += [ os.path.join(d, f) for f in os.listdir(d) ]
        except OSError:
            if os.path.splitext(d)[1] in extensions:
                files.append(d)

    if new_dirs:
        find_files(files, new_dirs, extensions)
    else:
        return

In [3]:
# Convert all mp3 files to wav
def convert_all_files():
    for file_path in file_paths:
        cleaned_name = file_path[:-4]

        if file_path[-4:] == ".mp3":
            new_name = cleaned_name + ".wav"

            sound = AudioSegment.from_file(file_path, format="mp3")
            sound.export(new_name, "wav")

            os.remove(file_path)
        elif file_path[-4:] == ".mp4":
            subprocess.Popen(f"ffmpeg -i {file_path} -vn -acodec pcm_s16le -ar 44100 -ac 2 {cleaned_name}.wav").wait()
            os.remove(file_path)

Process data

In [4]:
def normalization(samples:np.ndarray):
    smpl = samples.copy()
    max_abs_val = max(abs(smpl))

    # Normalization
    smpl /= max_abs_val
    return smpl

In [5]:
def noise_data(data, fs):
  sdata = data.copy()

  length_in_secs = sdata.size / fs
  time = np.linspace(0, length_in_secs, sdata.size, endpoint=False)

  for _ in range(random.randint(1, MAXIMUM_NUMBER_OF_NOISE_CHANNELS)):
    sdata += (random.random() * NOISE_APLITUDE_MAX) * np.cos(2 * np.pi * (random.random() * NOISE_FREQ_MAX) * time + (random.random() * 2 * np.pi))
  
  return sdata

In [6]:
def create_frames(data:np.ndarray):
  number_of_frames = data.size // frame_shift
  frames = [data[idx * frame_shift : FRAME_SIZE + idx * frame_shift] for idx in range(number_of_frames)]
  frames[len(frames) - 1] = np.pad(frames[len(frames) - 1], (0, FRAME_SIZE - frames[len(frames) - 1].shape[0]), "constant")
  return frames

In [7]:
def drawFFT(signal, frequency, idx=None):
  fft = np.fft.fft(signal)

  data_length = len(signal)
  freq = np.arange(data_length)/(data_length/frequency)

  half_length = data_length//2

  freq_one_side = freq[:half_length]
  fft = fft[:half_length]/half_length

  plt.figure(figsize=(18,8))
  plt.title(f"FFT {idx}" if idx is not None else "Generic FFT")
  plt.plot(freq_one_side, abs(fft))
  plt.xlabel('f[Hz]')
  plt.ylabel('Amplituda[-]')
  plt.tight_layout()
  plt.show()

In [8]:
def process_one_file(smpls, fs):
  # Split and normalize
  try:
    if smpls.shape[1] > 0:
      normalized_samples = []

      for i in range(smpls.shape[1]):
        normalized_samples.extend(normalization(smpls[:, i]))

      normalized_samples = np.array(normalized_samples)
    else:
      normalized_samples = normalization(smpls)
  except Exception as e:
    print(f"Exception on spliting data\n{e}")
    normalized_samples = normalization(smpls)

  # Create noised data and normalize them
  noised_samples = noise_data(normalized_samples, fs)
  noised_samples = normalization(noised_samples)

  # plt.figure(figsize=(18,8))
  # plt.plot(np.arange(normalized_samples.size) / fs, normalized_samples)
  # plt.gca().set_xlabel('$t[s]$')
  # plt.gca().set_ylabel('$Amplituda[-]$')
  # plt.show()

  # plt.figure(figsize=(18,8))
  # plt.plot(np.arange(noised_samples.size) / fs, noised_samples)
  # plt.gca().set_xlabel('$t[s]$')
  # plt.gca().set_ylabel('$Amplituda[-]$')
  # plt.show()

  # Power density
  # f1, Pxx1 = signal.periodogram(normalized_samples, fs, scaling="density")
  # print(f1.shape)
  # print(Pxx1.shape)

  # plt.semilogy(f1, Pxx1)
  # plt.title("original")
  # plt.xlabel('f[Hz]')
  # plt.ylabel('PSD [V**2/Hz]')
  # plt.show()

  # f2, Pxx2 = signal.periodogram(noised_samples, fs, scaling="density")
  # print(f2.shape)
  # print(Pxx2.shape)

  # plt.semilogy(f2, Pxx2)
  # plt.title("noised")
  # plt.xlabel('f[Hz]')
  # plt.ylabel('PSD [V**2/Hz]')
  # plt.show()

  # FFT
  # normal_fft = np.fft.fft(normalized_samples)
  # noise_fft = np.fft.fft(noised_samples)

  # print(normal_fft.shape)
  # print(noise_fft.shape)

  # data_length = len(normal_fft)
  # freq = np.arange(data_length)/(data_length/fs)

  # half_length = data_length//2

  # freq_one_side = freq[:half_length]
  # normal_fft = normal_fft[:half_length]/half_length
  # noise_fft = noise_fft[:half_length]/half_length

  # plt.figure(figsize=(18,8))
  # plt.title(f"Normal")
  # plt.plot(freq_one_side, abs(normal_fft))
  # plt.xlabel('f[Hz]')
  # plt.ylabel('Amplituda[-]')
  # plt.tight_layout()
  # plt.show()

  # plt.figure(figsize=(18,8))
  # plt.title(f"Noise")
  # plt.plot(freq_one_side, abs(noise_fft))
  # plt.xlabel('f[Hz]')
  # plt.ylabel('Amplituda[-]')
  # plt.tight_layout()
  # plt.show()

  # Set signals to be between 0 and 1
  normalized_samples += 1
  noised_samples += 1
  normalized_samples = normalization(normalized_samples)
  noised_samples = normalization(noised_samples)

  # plt.figure(figsize=(18,8))
  # plt.plot(np.arange(normalized_samples.size) / fs, normalized_samples)
  # plt.gca().set_xlabel('$t[s]$')
  # plt.gca().set_ylabel('$Amplituda[-]$')
  # plt.show()

  # plt.figure(figsize=(18,8))
  # plt.plot(np.arange(noised_samples.size) / fs, noised_samples)
  # plt.gca().set_xlabel('$t[s]$')
  # plt.gca().set_ylabel('$Amplituda[-]$')
  # plt.show()

  # Create frames
  normalized_samples = create_frames(normalized_samples)
  noised_samples = create_frames(noised_samples)

  # plt.figure(figsize=(18,8))
  # plt.plot(np.arange(normalized_samples[50].size) / fs, normalized_samples[50])
  # plt.gca().set_xlabel('$t[s]$')
  # plt.gca().set_ylabel('$Amplituda[-]$')
  # plt.show()

  # plt.figure(figsize=(18,8))
  # plt.plot(np.arange(noised_samples[50].size) / fs, noised_samples[50])
  # plt.gca().set_xlabel('$t[s]$')
  # plt.gca().set_ylabel('$Amplituda[-]$')
  # plt.show()

  return normalized_samples, noised_samples

Process data

In [9]:
def move_file(src, dest):
  name = Path(src).name
  shutil.move(src, os.path.join(dest, name))

def move_files(file_paths, target_path):
  for file_path in file_paths:
    move_file(file_path, target_path)

file_paths = []
find_files(file_paths, dirs=[DATA_PATH], extensions=[".mp3", ".mp4"])
print("Files to convert")
print(len(file_paths))

convert_all_files()
    
already_used_filenames = []
find_files(already_used_filenames, dirs=[TEST_DATASET_PATH, TRAIN_DATASET_PATH], extensions=[".npy"])

if not os.path.exists("tmp_dataset"):
  os.mkdir("tmp_dataset")

  file_paths = []
  find_files(file_paths, dirs=[DATA_PATH], extensions=[".wav"])

  print("Numbe of source files")
  print(len(file_paths))

  for file_path in tqdm(file_paths):
    cleaned_name = Path(file_path).name[:-4]

    if (any([cleaned_name in file_name for file_name in already_used_filenames]) or
        any([cleaned_name in file_name for file_name in os.listdir("tmp_dataset")])):
      print(f"Skipping {cleaned_name}")
      continue

    smpls, f = sf.read(file_path)
    gc.collect()

    print(f"Processing {cleaned_name}")
        
    random.seed()
    norm_s, nois_s = process_one_file(smpls, f)

    for idx, (a, b) in enumerate(zip(norm_s, nois_s)):
      if not os.path.exists(f"tmp_dataset/{idx}_{cleaned_name}_{f}"):
        np.save(f"tmp_dataset/{idx}_{cleaned_name}_{f}", np.array([a, b, np.fft.fft(b)]))
    gc.collect()

file_paths = []
find_files(file_paths, dirs=["tmp_dataset"], extensions=[".npy"])
number_of_files = len(file_paths)

print("Files to sort")
print(number_of_files)

if number_of_files > 0:
  random.shuffle(file_paths)

  if not os.path.exists(TEST_DATASET_PATH):
    os.mkdir(TEST_DATASET_PATH)

  if not os.path.exists(TRAIN_DATASET_PATH):
    os.mkdir(TRAIN_DATASET_PATH)

  print("Moving files")
  valid_file_count = int(number_of_files * VAL_SPLIT)
  move_files(file_paths[:valid_file_count], TEST_DATASET_PATH)
  move_files(file_paths[valid_file_count:], TRAIN_DATASET_PATH)

shutil.rmtree("tmp_dataset")

Files to convert
0
Files to sort
73745
Moving files


Data generator

In [12]:
def convert_imag_to_parts(array:np.ndarray):
    tmp = np.empty((array.shape[0], 2))

    for idx, a in enumerate(array):
        tmp[idx][0] = a.real
        tmp[idx][1] = a.imag

    return np.nan_to_num(tmp)

class DataGenerator(keras.utils.Sequence, threading.Thread):
    def __init__(self, path, dim, batch_size=32, shuffle=True):
        super(DataGenerator, self).__init__()

        self.dim = dim
        self.files = []
        find_files(self.files, dirs=[path], extensions=[".npy"])

        self.batch_size = batch_size
        self.shuffle = shuffle

        self.queue = []

        self.length = int(np.floor(len(self.files) / self.batch_size))
        self.index = 0
        self.daemon = True
        self.__terminate = False;
        self.shuffle_data()
        self.start()

    def __del__(self):
        self.__terminate = True

    def __len__(self):
        return self.length

    def run(self) -> None:
        while True:
            if self.__terminate: break

            if len(self.queue) < 20:
                if (self.index + 1) >= self.length:
                    self.shuffle_data()
                    self.index = 0

                files = self.files[self.index*self.batch_size:(self.index+1)*self.batch_size]
                self.index += 1

                # Generate data
                data = self.__data_generation(files)
                if data is not None:
                    self.queue.append(data)
            else: time.sleep(0.1)

    def __getitem__(self, _):
        while len(self.queue) == 0: time.sleep(0.01)
        return self.queue.pop()

    def shuffle_data(self):
        if self.shuffle == True:
            np.random.shuffle(self.files)

    def on_epoch_end(self):
        pass

    def __data_generation(self, files):
        # Initialization
        X = np.empty((self.batch_size, *self.dim))
        fft = np.empty((self.batch_size, self.dim[0], 2))
        f = np.empty((self.batch_size), dtype=float)
        y = np.empty((self.batch_size, *self.dim))

        # Generate data
        for idx, file in enumerate(files):
            freq = float(file[:-4].split("_")[-1])
            loaded_data = np.load(file)

            try:
                X[idx,] = loaded_data[1].real
                fft[idx,] = convert_imag_to_parts(loaded_data[2])
                f[idx] = freq / MAX_SAMPLE_FREQUENCY
                y[idx,] = loaded_data[0].real
            except Exception:
                return None

        return [X, fft, f], y

Model

In [10]:
max_norm_value = 2.0

def create_model():
  inp1 = layers.Input(shape=(FRAME_SIZE,1), name="frame_input")
  inp2 = layers.Input(shape=(FRAME_SIZE,2), name="fft_input")
  inp3 = layers.Input(shape=(1,), name="frequency_input")

  y = layers.Dense(FRAME_SIZE)(inp3)
  y = layers.LeakyReLU(0.2)(y)
  y = layers.Dropout(0.1)(y)
  y = layers.Reshape((FRAME_SIZE, 1))(y)

  x = layers.concatenate([inp1, inp2, y])

  # Encoder
  x = layers.Conv1D(256, kernel_size=3, strides=2, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer=initializers.RandomNormal(stddev=0.02))(x)
  x1 = layers.LeakyReLU(0.2)(x)

  x = layers.Conv1D(256, kernel_size=3, strides=1, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer=initializers.RandomNormal(stddev=0.02))(x)
  x = layers.LeakyReLU(0.2)(x)

  x = layers.Conv1D(128, kernel_size=3, strides=2, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer=initializers.RandomNormal(stddev=0.02))(x1)
  x2 = layers.LeakyReLU(0.2)(x)

  x = layers.Conv1D(128, kernel_size=3, strides=1, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer=initializers.RandomNormal(stddev=0.02))(x)
  x = layers.LeakyReLU(0.2)(x)

  x = layers.Conv1D(64, kernel_size=3, strides=2, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer=initializers.RandomNormal(stddev=0.02))(x2)
  x3 = layers.LeakyReLU(0.2)(x)
  
  # Encoded value

  # Edecoder
  x = layers.Conv1D(64, kernel_size=3, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer=initializers.RandomNormal(stddev=0.02))(x3)
  x = layers.LeakyReLU(0.2)(x)

  x = layers.Add()([x, x3])

  x = layers.Conv1D(64, kernel_size=3, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer=initializers.RandomNormal(stddev=0.02))(x)
  x = layers.UpSampling1D(2)(x)
  x = layers.LeakyReLU(0.2)(x)

  x = layers.Conv1D(128, kernel_size=3, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer=initializers.RandomNormal(stddev=0.02))(x)
  x = layers.LeakyReLU(0.2)(x)

  x = layers.Add()([x, x2])

  x = layers.Conv1D(128, kernel_size=3, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer=initializers.RandomNormal(stddev=0.02))(x)
  x = layers.UpSampling1D(2)(x)
  x = layers.LeakyReLU(0.2)(x)

  x = layers.Conv1D(256, kernel_size=3, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer=initializers.RandomNormal(stddev=0.02))(x)
  x = layers.LeakyReLU(0.2)(x)

  x = layers.Add()([x, x1])

  x = layers.Conv1D(256, kernel_size=3, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer=initializers.RandomNormal(stddev=0.02))(x)
  x = layers.UpSampling1D(2)(x)
  x = layers.LeakyReLU(0.2)(x)

  # Output
  x = layers.Conv1D(1, kernel_size=1, kernel_constraint=max_norm(max_norm_value), activation='sigmoid', padding='same')(x)
  return Model([inp1, inp2, inp3], x)

In [13]:
train_data_generator = DataGenerator(TRAIN_DATASET_PATH, (FRAME_SIZE,), BATCH_SIZE)
test_data_generator = DataGenerator(TEST_DATASET_PATH, (FRAME_SIZE,), BATCH_SIZE)

print(len(train_data_generator))
print(len(test_data_generator))

33389
8347


Training

In [14]:
def noiseToSignalLoss(y_true, y_pred):
    losses = tf.math.divide(tf.math.reduce_sum(tf.math.pow(tf.math.abs(tf.math.subtract(y_true,y_pred)),2)),
                            tf.math.reduce_sum(tf.math.pow(tf.math.abs(y_true),2)))
    return tf.reduce_mean(losses)

In [15]:
def SNR(y_true, y_pred):
  return -10.0 * K.log(K.mean(K.square(y_pred - y_true))) / K.log(10.0)

In [16]:
model = create_model()
model.summary()
utils.plot_model(model, show_shapes=True, expand_nested=True, dpi=312)
print(model.input_shape)

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 frequency_input (InputLayer)   [(None, 1)]          0           []                               
                                                                                                  
 dense (Dense)                  (None, 4096)         8192        ['frequency_input[0][0]']        
                                                                                                  
 leaky_re_lu (LeakyReLU)        (None, 4096)         0           ['dense[0][0]']                  
                                                                                                  
 dropout (Dropout)              (None, 4096)         0           ['leaky_re_lu[0][0]']            
                                                                                              

In [17]:
model.compile(optimizer=optimizers.Adam(1e-4, 0.5), loss=noiseToSignalLoss, metrics=[SNR])
# model.compile(optimizer=optimizers.Adam(learning_rate=1e-4), loss="mse", metrics=[SNR])

In [18]:
number_of_batches = len(train_data_generator) // STEPES_PER_EPOCH
checkpoint = callbacks.ModelCheckpoint("best_weights", monitor="val_SNR", verbose=1, save_best_only=True, mode="max")

try:
  if os.path.exists("checkpoint.h5"):
    print("Loading checkpoint")
    model.load_weights("checkpoint.h5")
  else:
    print("Starting pretrain")
    model.fit(train_data_generator, steps_per_epoch=STEPES_PER_EPOCH, epochs=PRETRAIN_EPOCHS)
    model.save_weights("checkpoint.h5")

  print("Starting training")
  model.fit(train_data_generator, steps_per_epoch=STEPES_PER_EPOCH, epochs=EPOCHS * number_of_batches,
            callbacks=[callbacks.ReduceLROnPlateau(monitor='val_loss',patience=5,factor=0.7,verbose=1), callbacks.TensorBoard("logs"), checkpoint],
            validation_data=test_data_generator, validation_steps=(STEPES_PER_EPOCH // 4) if STEPES_PER_EPOCH is not None else None)
except KeyboardInterrupt:
  model.save_weights("checkpoint.h5")

Loading checkpoint
Starting training


  layer_config = serialize_layer_fn(layer)


Epoch 1/240
Epoch 00001: val_SNR improved from -inf to 29.65762, saving model to best_weights
INFO:tensorflow:Assets written to: best_weights\assets
Epoch 2/240

  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)



Epoch 00002: val_SNR improved from 29.65762 to 30.34129, saving model to best_weights
INFO:tensorflow:Assets written to: best_weights\assets
Epoch 3/240


  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)




In [None]:
if os.path.exists("best_weights"): model.load_weights("best_weights")

Predictions

In [None]:
def denoise(sample, fs):
  number_of_frames = sample.size // FRAME_SIZE
  frames = [sample[idx * FRAME_SIZE : FRAME_SIZE + idx * FRAME_SIZE] for idx in range(number_of_frames)]

  frames = np.array(frames)

  fft = np.array([convert_imag_to_parts(np.fft.fft(frame)) for frame in frames])

  output_frames = model.predict([frames, fft, np.ones((frames.shape[0], 1)) * fs])

  final_frame = None
  
  for frame in output_frames:
    frame = np.reshape(frame, (FRAME_SIZE,))

    if final_frame is None:
      final_frame = frame
    else:
      final_frame = np.concatenate([final_frame, frame], axis=0)

  return final_frame

In [None]:
samples_orig, sample_freq = sf.read("../audio/xdousa00.wav")

samples_normal = normalization(samples_orig) + 1
samples_normal = normalization(samples_normal)

plt.figure(figsize=(18,8))
plt.title("Normalizovaný vstupní signál")
plt.plot(np.arange(samples_normal.size) / sample_freq, samples_normal)
plt.gca().set_xlabel('$t[s]$')
plt.gca().set_ylabel('$Amplituda[-]$')
plt.show()

cleared_signal = denoise(samples_normal, sample_freq)
norm_cleared_signal = normalization(cleared_signal)

print(norm_cleared_signal.shape)

plt.figure(figsize=(18,8))
plt.title("Vyčištěný signál")
plt.plot(np.arange(norm_cleared_signal.size) / sample_freq, norm_cleared_signal)
plt.gca().set_xlabel('$t[s]$')
plt.gca().set_ylabel('$Amplituda[-]$')
plt.show()

wavfile.write("clean_test.wav", sample_freq, norm_cleared_signal)

Cleanup

In [None]:
if input("Are you sure you want to delete datasets?\n") == "y":
  shutil.rmtree(TRAIN_DATASET_PATH)
  shutil.rmtree(TEST_DATASET_PATH)