In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
%matplotlib notebook

from keras.layers import Input, Dense, Dropout,\
Conv2D, MaxPooling2D, Flatten, Activation, BatchNormalization, Conv2DTranspose, concatenate
from keras.models import Model
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

In [None]:
def double_conv_block(x, n_filters):
    """Apply two stacked 3x3 convolutions with ReLU activation.

    Both convolutions use 'same' padding (spatial size is preserved) and
    He-normal kernel initialization.

    Args:
        x: input tensor.
        n_filters: number of output channels of each convolution.

    Returns:
        The output tensor of the second convolution.
    """
    for _ in range(2):
        x = Conv2D(
            filters=n_filters,
            kernel_size=3,
            padding="same",
            activation="relu",
            kernel_initializer="he_normal",
        )(x)
    return x

def downsample_block(x, n_filters):
    """Encoder stage: double convolution, then 2x2 max-pooling and dropout.

    Args:
        x: input tensor.
        n_filters: channel count for the double convolution.

    Returns:
        (features, pooled): ``features`` is the pre-pooling tensor, kept for
        the decoder's skip connection; ``pooled`` is the downsampled tensor
        (after dropout with rate 0.4) passed to the next stage.
    """
    skip = double_conv_block(x, n_filters)
    halved = MaxPooling2D(pool_size=2)(skip)
    return skip, Dropout(rate=0.4)(halved)

def upsample_block(x, conv_features, n_filters):
    """Decoder stage: 2x transposed-conv upsampling, skip concat, dropout, double conv.

    Args:
        x: tensor coming from the deeper decoder stage (or the bottleneck).
        conv_features: encoder skip tensor; it must have the same height and
            width as ``x`` after upsampling.
        n_filters: channel count for the transposed conv and the double conv.

    Returns:
        The output tensor of the double convolution.
    """
    # Stride-2 transposed convolution with 'same' padding doubles height/width.
    upsampled = Conv2DTranspose(n_filters, kernel_size=3, strides=2, padding="same")(x)
    # Join with the skip features along the channel (last) axis.
    merged = concatenate([upsampled, conv_features])
    merged = Dropout(rate=0.4)(merged)
    return double_conv_block(merged, n_filters)

def build_unet_model(input_shape=(80, 48, 1), init_n_filters=16):
    """Build a 4-level U-Net with a single-channel sigmoid output.

    The output has the same height and width as the input and one channel
    (1x1 convolution with sigmoid activation).

    Args:
        input_shape: (height, width, channels) of the input. Height and width
            must each be divisible by 16: there are four 2x max-poolings, and
            otherwise the stride-2 upsampled tensors will not line up with the
            encoder skip connections. The default (80, 48, 1) satisfies this.
        init_n_filters: filter count of the first encoder stage; deeper stages
            use multiples of it.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inputs = Input(shape=input_shape)

    # Encoder: each stage returns (skip features, pooled tensor).
    f1, p1 = downsample_block(inputs, init_n_filters)
    f2, p2 = downsample_block(p1, 2 * init_n_filters)
    f3, p3 = downsample_block(p2, 4 * init_n_filters)
    f4, p4 = downsample_block(p3, 4 * init_n_filters)

    bottleneck = double_conv_block(p4, 8 * init_n_filters)

    # Decoder mirrors the encoder, concatenating the matching skip features.
    u6 = upsample_block(bottleneck, f4, 4 * init_n_filters)
    u7 = upsample_block(u6, f3, 4 * init_n_filters)
    u8 = upsample_block(u7, f2, 2 * init_n_filters)
    u9 = upsample_block(u8, f1, init_n_filters)

    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(u9)
    # Use `Model` from the same `keras` package the layers come from, rather
    # than `tf.keras.Model`; mixing the two packages fails in some versions.
    return Model(inputs=inputs, outputs=outputs)

# Instantiate the network and configure it for pixel-wise regression with MSE.
# The history keys read later in the notebook depend on the metric name 'MSE'.
unet_model = build_unet_model()
unet_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=['MSE'],
)
unet_model.summary()

Перед обучением укажите в следующей ячейке пути к соответствующим частям датасета (train/val, clean/noisy).

In [None]:
# Dataset root. Set the DENOISE_DATA_DIR environment variable to point at your
# copy of the dataset instead of editing a machine-specific absolute path.
# Expected layout: <root>/train/{clean,noisy} and <root>/val/val/{clean,noisy}.
DATA_DIR = os.environ.get("DENOISE_DATA_DIR", "C:/Users/a.aspidov/Desktop/attachments")

path_clean = os.path.join(DATA_DIR, "train", "clean")
path_noisy = os.path.join(DATA_DIR, "train", "noisy")
path_val_clean = os.path.join(DATA_DIR, "val", "val", "clean")
path_val_noisy = os.path.join(DATA_DIR, "val", "val", "noisy")

In [None]:
def normalize_samples(samples_raw, chunk_len=48):
    """Min-max normalize each sample, transpose it, and cut it into fixed-width chunks.

    Steps for each input array:
    1. Shift it so its minimum is 0, then scale it so its maximum is 1.
    2. Transpose it.
    3. Split it along its (new) second axis into consecutive chunks of
       ``chunk_len`` columns. A trailing partial chunk is right-padded with
       zeros up to ``chunk_len``.

    Args:
        samples_raw: iterable of 2-D numpy arrays. NOTE(review): presumably
            (time, n_mels) spectrograms, so that the chunks come out as
            (n_mels, chunk_len) — confirm against the dataset.
        chunk_len: width of each output chunk. Defaults to 48, the input
            width of the model built in this notebook.

    Returns:
        ``np.array`` of all chunks in order. It has shape
        (n_chunks, rows, chunk_len) when every sample has the same number of
        rows after transposing.
    """
    chunks = []
    for sample in samples_raw:
        shifted = sample - sample.min()
        peak = shifted.max()
        # A constant sample has peak == 0 (shifted is all zeros). Dividing by
        # it would produce NaNs that poison training, so divide by 1 instead
        # and keep the sample all-zero.
        scaled = (shifted / (peak if peak != 0 else 1)).T

        n_cols = scaled.shape[1]
        n_chunks = -(-n_cols // chunk_len)  # ceil(n_cols / chunk_len)
        for i in range(n_chunks):
            chunk = scaled[:, i * chunk_len:(i + 1) * chunk_len]
            short_by = chunk_len - chunk.shape[1]
            if short_by > 0:
                # Only the last chunk can be short; zero-pad it on the right.
                chunk = np.pad(chunk, ((0, 0), (0, short_by)), 'constant')
            chunks.append(chunk)
    return np.array(chunks)

def load_dataset_part(path_base):
    """Load every file in each per-speaker subdirectory of ``path_base`` with ``np.load``.

    Entries are visited in sorted order: speaker directories first, then the
    files within each. ``os.listdir`` alone returns entries in arbitrary,
    filesystem-dependent order. Sorting guarantees that two parallel trees
    (e.g. the noisy and clean copies of the same split) yield lists whose
    i-th elements correspond, provided both trees contain the same
    speaker/file names.

    Args:
        path_base: directory whose immediate children are subdirectories,
            each containing files readable by ``np.load``.

    Returns:
        list of loaded arrays, one per file.
    """
    path_list = [
        os.path.join(path_base, speaker, file)
        for speaker in sorted(os.listdir(path_base))
        for file in sorted(os.listdir(os.path.join(path_base, speaker)))
    ]
    return [np.load(path) for path in path_list]

# Noisy (inputs) and clean (targets) lists are paired by position. Before
# chunking, check that every noisy file has a clean counterpart of the same
# shape. The chunk-level shape check further down cannot detect misaligned
# pairs whose total chunk counts happen to agree.
x_train_raw = load_dataset_part(path_noisy)
y_train_raw = load_dataset_part(path_clean)
assert [s.shape for s in x_train_raw] == [s.shape for s in y_train_raw], \
    "train: noisy/clean files do not pair up (count or per-file shape differs)"
x_train = normalize_samples(x_train_raw)
y_train = normalize_samples(y_train_raw)
del x_train_raw, y_train_raw  # free the raw copies before loading val

x_val_raw = load_dataset_part(path_val_noisy)
y_val_raw = load_dataset_part(path_val_clean)
assert [s.shape for s in x_val_raw] == [s.shape for s in y_val_raw], \
    "val: noisy/clean files do not pair up (count or per-file shape differs)"
x_val = normalize_samples(x_val_raw)
y_val = normalize_samples(y_val_raw)
del x_val_raw, y_val_raw

assert x_train.shape == y_train.shape
assert x_val.shape == y_val.shape

# Add a trailing channel axis for the single-channel Conv2D input.
x_train = x_train[..., np.newaxis]
y_train = y_train[..., np.newaxis]

x_val = x_val[..., np.newaxis]
y_val = y_val[..., np.newaxis]

In [None]:
# Train on (noisy -> clean) pairs, evaluating on the validation split each epoch.
n_epochs = 30
history = unet_model.fit(
    x_train,
    y_train,
    epochs=n_epochs,
    validation_data=(x_val, y_val),
    verbose=2,
)

In [None]:
# Per-epoch curves recorded by fit(): the 'MSE' metric and the loss,
# each for the training and the validation data.
mse, val_mse = history.history['MSE'], history.history['val_MSE']
loss, val_loss = history.history['loss'], history.history['val_loss']

In [None]:
# Apply the ggplot style before creating the figure so the figure itself is
# created with it; the style stays active for later plots in the notebook.
plt.style.use('ggplot')
fig, ax = plt.subplots()
# 1-based epoch numbers, sized from the recorded history rather than n_epochs.
epochs = range(1, len(loss) + 1)
ax.plot(epochs, loss, label='loss')
ax.plot(epochs, val_loss, label='val_loss')
ax.set(title='Training and validation loss', xlabel='epoch', ylabel='loss (MSE)')
ax.legend();

In [None]:
fig, ax = plt.subplots()
# 1-based epoch numbers, sized from the recorded history rather than n_epochs.
epochs = range(1, len(mse) + 1)
ax.plot(epochs, mse, label='MSE')
ax.plot(epochs, val_mse, label='val_MSE')
ax.set(title='Training and validation MSE metric', xlabel='epoch', ylabel='MSE')
ax.legend();

In [None]:
# Save the trained model; the .h5 extension makes Keras use its HDF5 format.
unet_model.save(filepath='model_22.h5')