In [11]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import random
from glob import glob
import librosa

import soundfile

from tensorflow.keras import layers, losses
from tensorflow.keras.datasets import fashion_mnist

from tensorflow.keras.layers import Conv2D, Input, LeakyReLU, Flatten, Dense, Reshape, Conv2DTranspose, BatchNormalization, Activation
from tensorflow.keras import Model, Sequential


In [12]:
# The CUDA-build query is evaluated here, but its result is neither stored nor shown.
tf.test.is_built_with_cuda()
# Report which device class TensorFlow can see: any visible GPU means "GPU".
visible_gpus = tf.config.list_physical_devices('GPU')
print("Training GPU" if len(visible_gpus) != 0 else "Training CPU")

Training GPU


In [13]:
def load_mel(filepth="data/train.part1/noisy/25/25_88353_25-88353-0017.npy"):
    """Load a spectrogram stored as a ``.npy`` file and return it as float64.

    Args:
        filepth: Path to the ``.npy`` file to read.

    Returns:
        A new ``numpy.ndarray`` with the file's contents cast to ``np.float64``.
    """
    stored = np.load(filepth)
    return stored.astype(np.float64)


def reconstruct_audio_from_mel(mel_spec, out='rec.flac', sr=16000, n_fft=1024,
                               hop_length=256, fmin=20, fmax=8000):
    """Invert a stored mel spectrogram to a waveform and write it to disk.

    The previous version declared ``sr``/``hop_length``/``fmin``/``fmax`` as
    locals (with a mistyped ``hop_length=2561``) but passed hard-coded literals
    to librosa. Those settings are now real keyword parameters whose defaults
    equal the values that were actually used, so existing calls behave the same.

    Args:
        mel_spec: Array whose first axis is transposed to the last before
            inversion (i.e. it is passed to librosa as ``mel_spec.T``).
        out: Output audio file path handed to ``soundfile.write``.
        sr: Sample rate used for both the inversion and the written file.
        n_fft: FFT size passed to ``mel_to_audio``.
        hop_length: Hop length passed to ``mel_to_audio``.
        fmin: Lowest mel filter frequency passed to ``mel_to_audio``.
        fmax: Highest mel filter frequency passed to ``mel_to_audio``.

    Returns:
        None. The reconstructed audio is written to ``out``.
    """
    # Undo the stored scaling: value -> exp((value - 1) * 10).
    # NOTE(review): presumably the inverse of the normalisation applied when the
    # .npy files were produced -- confirm against the data-preparation code.
    mel_linear = np.exp((mel_spec - 1) * 10).T
    y_inv = librosa.feature.inverse.mel_to_audio(
        M=mel_linear, sr=sr, n_fft=n_fft, hop_length=hop_length, fmin=fmin, fmax=fmax
    )
    soundfile.write(out, y_inv, samplerate=sr)


def show_mel_spectra(img_pth="data/train.part1/clean/31/31_121969_31-121969-0000.npy"):
    """Plot a standardised spectrogram from a ``.npy`` file.

    The array is shifted by its mean and divided by its standard deviation,
    drawn transposed with ``plt.imshow`` on a 20x6 figure, and the mean of the
    standardised array is printed.

    Args:
        img_pth: Path to the ``.npy`` file to display.
    """
    plt.figure(figsize=(20, 6))
    spectrogram = np.load(img_pth)
    centered = spectrogram - spectrogram.mean()
    standardized = centered / spectrogram.std()
    plt.imshow(standardized.astype(np.float64).T)
    print(standardized.mean())

# Training

In [14]:
# numFeatures is used as the first (height) dimension of the model input in build_model.
numFeatures = 80 # sliding window size
# numSegments is the number of consecutive spectrogram rows stacked into one input sample.
numSegments = 8 # number of Fourier vectors used for autoregression

In [None]:
# TODO: fix the loops in the data loaders

In [15]:
class DenoisingDataGen(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding (noisy windows, clean frames) training batches.

    Files are discovered under ``<data_folder>/noisy/*/*.npy`` and
    ``<data_folder>/clean/*/*.npy``. A noisy file is paired with the clean file
    that has the same ``<subdir>/<filename>`` tail; noisy files without a clean
    counterpart are skipped. (Pairing by name replaces zipping two unsorted
    ``glob`` results, whose relative order is not guaranteed.)

    One batch is built from ``batch_size`` *files*. Every file contributes one
    sample per sliding window of ``numSegments`` consecutive rows:

    * ``X`` sample: rows ``i .. i + numSegments - 1`` of the noisy array,
      transposed, with a trailing channel axis -> ``(features, numSegments, 1)``.
    * ``y`` sample: row ``i`` of the clean array -> ``(features, 1, 1)``.

    NOTE(review): the target is the *first* row of its window, as in the
    original loaders; confirm this is the intended alignment.

    Args:
        data_folders: List of dataset root folders.
        batch_size: Number of files combined into one batch.
        numSegments: Window length in rows. The per-instance value is honoured
            (previously the loaders always read the class attribute).
        shuffle: If true, file order is shuffled at construction and after
            every epoch; if false, the sorted order is kept.
    """

    # Class-level default, kept so ``DenoisingDataGen.numSegments`` still resolves.
    numSegments = numSegments

    def __init__(self, data_folders : list,
                 batch_size,
                 numSegments=numSegments,
                 shuffle=True):
        super().__init__()

        self.batch_size = batch_size
        self.numSegments = numSegments
        self.shuffle = shuffle

        self.file_paths = []
        for data_folder in data_folders:
            clean_by_key = {
                self.__pair_key(clean_path): clean_path
                for clean_path in glob(f"{data_folder}/clean/*/*.npy")
            }
            # Sorted so the unshuffled order is reproducible across runs.
            for noisy_path in sorted(glob(f"{data_folder}/noisy/*/*.npy")):
                clean_path = clean_by_key.get(self.__pair_key(noisy_path))
                if clean_path is not None:
                    self.file_paths.append((noisy_path, clean_path))

        # Shuffle once, after all folders are collected, and only on request.
        if self.shuffle:
            random.shuffle(self.file_paths)

        self.n = len(self.file_paths)

    @staticmethod
    def __pair_key(path):
        """Return ``(subdir, filename)`` -- the part shared by a noisy/clean pair."""
        return tuple(path.replace("\\", "/").rsplit("/", 2)[1:])

    def on_epoch_end(self):
        """Reshuffle the file order between epochs when shuffling is enabled."""
        if self.shuffle:
            random.shuffle(self.file_paths)

    def __window_count(self, num_rows):
        """Number of full ``numSegments``-row windows in an array of ``num_rows`` rows."""
        return max(num_rows - self.numSegments + 1, 0)

    def __load_noisy(self, path):
        """Load a noisy array and return all windows as ``(n, features, numSegments, 1)``."""
        mel_image = np.load(path)
        starts = np.arange(self.__window_count(len(mel_image)))
        # Row indices of every window: shape (n, numSegments).
        row_index = starts[:, np.newaxis] + np.arange(self.numSegments)
        windows = mel_image[row_index]                # (n, numSegments, features)
        return np.expand_dims(windows.transpose(0, 2, 1), axis=-1)

    def __load_clean(self, path):
        """Load a clean array and return per-window targets as ``(n, features, 1, 1)``."""
        mel_image = np.load(path)
        targets = mel_image[:self.__window_count(len(mel_image))]
        return targets[:, :, np.newaxis, np.newaxis]

    def __get_data(self, file_path_batches):
        """Build ``(X, y)`` arrays from a list of ``(noisy_path, clean_path)`` pairs.

        Raises:
            ValueError: if a pair yields a different number of samples for the
                noisy and the clean file (their lengths disagree).
        """
        noisy_parts = []
        clean_parts = []

        for noisy_path, clean_path in file_path_batches:
            noisy = self.__load_noisy(noisy_path)
            clean = self.__load_clean(clean_path)
            if len(noisy) != len(clean):
                raise ValueError(
                    f"Length mismatch between {noisy_path} ({len(noisy)} windows) "
                    f"and {clean_path} ({len(clean)} windows)"
                )
            noisy_parts.append(noisy)
            clean_parts.append(clean)

        if not noisy_parts:
            # No files requested: same empty arrays as before.
            return np.array([]), np.array([])

        X_batch = np.concatenate(noisy_parts, axis=0)
        y_batch = np.concatenate(clean_parts, axis=0)

        return X_batch, y_batch

    def __getitem__(self, index):
        """Return the ``index``-th batch, built from ``batch_size`` consecutive file pairs."""
        file_path_batches = self.file_paths[index * self.batch_size:(index + 1) * self.batch_size]
        X, y = self.__get_data(file_path_batches)
        return X, y

    def __len__(self):
        """Number of full batches per epoch (a trailing partial batch is dropped)."""
        return self.n // self.batch_size


In [16]:
# One generator per split; each batch is assembled from three spectrogram files.
traingen = DenoisingDataGen(
    data_folders=["data/train.part1"],
    batch_size=3,
)
valgen = DenoisingDataGen(
    data_folders=["data/val"],
    batch_size=3,
)

In [17]:
def _conv_relu_bn(x, filters, kernel_size, l2_strength, padding='same', residual=None):
    """Apply Conv2D -> (optional skip addition) -> ReLU -> BatchNormalization.

    Args:
        x: Input Keras tensor.
        filters: Number of convolution filters.
        kernel_size: ``[height, width]`` of the convolution kernel.
        l2_strength: L2 regularisation factor for the convolution kernel.
        padding: Convolution padding mode.
        residual: Optional tensor added to the convolution output before ReLU.

    Returns:
        ``(pre_activation, output)``: the tensor fed to the ReLU (usable as a
        skip connection) and the tensor after batch normalisation.
    """
    pre_activation = Conv2D(filters=filters, kernel_size=kernel_size, strides=[1, 1],
                            padding=padding, use_bias=False,
                            kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(x)
    if residual is not None:
        pre_activation = pre_activation + residual
    output = Activation('relu')(pre_activation)
    output = BatchNormalization()(output)
    return pre_activation, output


def build_model(l2_strength):
    """Build and compile the convolutional denoising network.

    The input has shape ``[numFeatures, numSegments, 1]``. The first convolution
    spans the full input width, collapsing it to 1; five groups of
    18 -> 30 -> 8 filter convolutions follow, with the 30-filter outputs of
    groups 1 and 2 added back in groups 5 and 4 respectively. The final
    convolution yields one channel, i.e. an output of ``[numFeatures, 1, 1]``.

    The first kernel's width now follows ``numSegments`` instead of a literal
    ``8``; for ``numSegments == 8`` the layers, their creation order and their
    weight shapes are the same as before.

    Args:
        l2_strength: L2 regularisation factor applied to every regularised
            convolution kernel.

    Returns:
        A compiled ``Model`` (Adam with learning rate 3e-4, MSE loss, RMSE metric).
    """
    inputs = Input(shape=[numFeatures, numSegments, 1])

    # 1 -----
    # Pad 4 rows on each side so the 'valid' 9-row kernel keeps numFeatures rows.
    x = tf.keras.layers.ZeroPadding2D(((4, 4), (0, 0)))(inputs)
    _, x = _conv_relu_bn(x, 18, [9, numSegments], l2_strength, padding='valid')
    skip0, x = _conv_relu_bn(x, 30, [5, 1], l2_strength)
    _, x = _conv_relu_bn(x, 8, [9, 1], l2_strength)

    # 2 -----
    _, x = _conv_relu_bn(x, 18, [9, 1], l2_strength)
    skip1, x = _conv_relu_bn(x, 30, [5, 1], l2_strength)
    _, x = _conv_relu_bn(x, 8, [9, 1], l2_strength)

    # 3 -----
    _, x = _conv_relu_bn(x, 18, [9, 1], l2_strength)
    _, x = _conv_relu_bn(x, 30, [5, 1], l2_strength)
    _, x = _conv_relu_bn(x, 8, [9, 1], l2_strength)

    # 4 ----- (adds the group-2 skip before the ReLU)
    _, x = _conv_relu_bn(x, 18, [9, 1], l2_strength)
    _, x = _conv_relu_bn(x, 30, [5, 1], l2_strength, residual=skip1)
    _, x = _conv_relu_bn(x, 8, [9, 1], l2_strength)

    # 5 ----- (adds the group-1 skip before the ReLU)
    _, x = _conv_relu_bn(x, 18, [9, 1], l2_strength)
    _, x = _conv_relu_bn(x, 30, [5, 1], l2_strength, residual=skip0)
    _, x = _conv_relu_bn(x, 8, [9, 1], l2_strength)

    # 6 -----
    x = tf.keras.layers.SpatialDropout2D(0.2)(x)
    # NOTE(review): the 129-row kernel is taller than the numFeatures-row input;
    # it is left unchanged so previously saved weights keep their shape.
    x = Conv2D(filters=1, kernel_size=[129, 1], strides=[1, 1], padding='same')(x)

    model = Model(inputs=inputs, outputs=x)

    optimizer = tf.keras.optimizers.Adam(3e-4)

    model.compile(optimizer=optimizer, loss='mse',
                  metrics=[tf.keras.metrics.RootMeanSquaredError('rmse')])
    return model

In [18]:
# Build and compile the network with the L2 regularisation factor set to 0.0.
model = build_model(l2_strength=0.0)

In [9]:
# Train for three epochs, validating on the held-out generator after each one.
model.fit(traingen, validation_data=valgen, epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x1539466c250>

In [10]:
# Persist the model's current weights under the checkpoint prefix 'my_checkpoint'.
model.save_weights('my_checkpoint')

In [19]:
# Restore weights from the 'my_checkpoint' prefix into `model`.
model.load_weights('my_checkpoint')


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x154a060e070>

# Testing

In [60]:
def preprocess_noisy(mel_image, numSegments=8):
    """Cut a spectrogram into overlapping windows for model inference.

    Window ``i`` holds rows ``i .. i + numSegments - 1`` of ``mel_image``,
    transposed and given a trailing channel axis. Start indices run from 0 up
    to (but not including) ``len(mel_image) - numSegments``, so the final
    possible window is not produced.

    Args:
        mel_image: Array indexed by row along its first axis.
        numSegments: Number of consecutive rows per window.

    Returns:
        A tensor stacking all windows along a new leading axis.
    """
    window_count = len(mel_image) - numSegments
    windows = [
        np.expand_dims(mel_image[start:start + numSegments].T, axis=-1)
        for start in range(window_count)
    ]
    return tf.convert_to_tensor(windows)

In [51]:
# Noisy example: load the spectrogram and write its audio reconstruction.
fpth_noisy = "data/train.part1/noisy/25/25_88353_25-88353-0017.npy"
mel_noisy = load_mel(fpth_noisy)
reconstruct_audio_from_mel(mel_noisy, out="mel_noisy.flac")

# Matching clean example: same path with "noisy" swapped for "clean".
fpth_clean = fpth_noisy.replace("noisy", "clean")
mel_clean = load_mel(fpth_clean)
reconstruct_audio_from_mel(mel_clean, out="mel_clean.flac")


In [57]:
# Length of the noisy spectrogram along its first axis.
len(mel_noisy)

903

In [61]:
# Slice the noisy spectrogram into overlapping windows for the model.
preprocessed_noisy = preprocess_noisy(mel_noisy)

In [63]:
# Inspect the resulting tensor shape (windows first, then one window's dimensions).
preprocessed_noisy.shape

TensorShape([895, 80, 8, 1])

In [64]:
# The model emits one (features, 1, 1) block per window. Drop only those two
# trailing singleton axes: a bare np.squeeze would also remove the window axis
# when exactly one window is predicted, breaking the later transpose.
mel_filtered = np.squeeze(model.predict(preprocessed_noisy), axis=(2, 3))



In [65]:
# Write the audio reconstruction of the model's output to "mel_filtered.flac".
reconstruct_audio_from_mel(mel_filtered, "mel_filtered.flac")