In [1]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.io.wavfile as wav
from scipy.signal import stft, istft

In [2]:
import tensorflow as tf

In [10]:
import os

sample_rate = 44100


def get_vocals_mask(Zxx, threshold=1):
    binary_mask = (np.abs(Zxx) > threshold).astype(float)
    return binary_mask

def get_ideal_binary_mask(Zxx_vocals, Zxx_mix, threshold=1):
    binary_mask = (np.abs(Zxx_vocals) > threshold * np.abs(Zxx_mix)).astype(float)
    return binary_mask

def get_inverse_mask(binary_mask):
    return (binary_mask == 0).astype(float)

def get_track_list():
    return os.listdir('data/musdb/musdb18hq/train')

def read_track(filename):
    directory = f'data/musdb/musdb18hq/train/{filename}'
    sample_rate, vocals = wav.read(f'{directory}/vocals.wav')
    sample_rate, mix = wav.read(f'{directory}/mixture.wav')

    return sample_rate, mix, vocals

def transform_track(signal, sample_rate, nperseg=2048, noverlap=2048 // 2):
    frequencies, times, Zxx_l = stft(signal[:, 0], fs=sample_rate, nperseg=nperseg, noverlap=noverlap, window='hamming')
    frequencies, times, Zxx_r = stft(signal[:, 1], fs=sample_rate, nperseg=nperseg, noverlap=noverlap, window='hamming')

    return frequencies, times, Zxx_l, Zxx_r

def plot_spectogram(times, frequencies, Zxx, segment=None, freq_limit=4000, log=False):
    plt.figure(figsize=(12, 6))
    Zxx_for_plot = np.abs(Zxx[:, segment[0]:segment[1]]) if segment else np.abs(Zxx)
    Zxx_for_plot = np.log10(Zxx_for_plot + 1e-6) if log else Zxx_for_plot
    times_for_plot = times[segment[0]:segment[1]] if segment else times
    plt.pcolormesh(times_for_plot, frequencies, Zxx_for_plot, shading='gouraud', cmap='Greys_r')
    plt.title('STFT Magnitude of Music Track')
    plt.ylabel('Frequency [Hz]')
    plt.xlabel('Time [sec]')
    plt.colorbar(label='Magnitude')
    plt.ylim([0, freq_limit])
    plt.show()

def get_window_data(Zxx, window_size=5):
    transposed_stft = Zxx.T
    zeros_for_window = np.zeros((window_size // 2, transposed_stft.shape[1]))
    windows = np.zeros((transposed_stft.shape[0], window_size, transposed_stft.shape[1]), dtype=np.float32)
    data_for_windows = np.concatenate([zeros_for_window, transposed_stft, zeros_for_window])
    
    for i in range(transposed_stft.shape[0]):
        windows[i] = data_for_windows[i:i+window_size]

    return windows


In [164]:
audio_files = [f'data/musdb/musdb18hq/train/{track}/mixture.wav' for track in os.listdir('data/musdb/musdb18hq/train')]
vocals_files = [f'data/musdb/musdb18hq/train/{track}/vocals.wav' for track in os.listdir('data/musdb/musdb18hq/train')]
audio_files.extend([f'data/ccmixter_corpus/{track}/mix.wav' for track in os.listdir('data/ccmixter_corpus')[:30]])
vocals_files.extend([f'data/ccmixter_corpus/{track}/source-02.wav' for track in os.listdir('data/ccmixter_corpus')[:30]])


def read_tf_file(file_path):
    file_path = file_path.numpy().decode("utf-8")
    sr, audio = wav.read(file_path)
    
    return audio.astype(np.float32)

def stft_transform(signal, sample_rate=44100, nperseg=2048, noverlap=2048 // 2):
    frequencies, times, Zxx_l = stft(signal[:, 0], fs=sample_rate, nperseg=nperseg, noverlap=noverlap, window='hamming')
    frequencies, times, Zxx_r = stft(signal[:, 1], fs=sample_rate, nperseg=nperseg, noverlap=noverlap, window='hamming')
    return Zxx_l, Zxx_r

def process_audio(audio_file, vocal_file):
    audio = read_tf_file(audio_file)
    Zxx_l, Zxx_r = stft_transform(audio)
    vocals_audio = read_tf_file(vocal_file)
    Zxx_l_v, Zxx_r_v = stft_transform(vocals_audio)
    return tf.abs(Zxx_l), tf.abs(Zxx_r), tf.abs(Zxx_l_v), tf.abs(Zxx_r_v)

def get_ideal_mask_tf(stft, stft_vocals):
    mask = get_ideal_binary_mask(stft_vocals, stft)
    return tf.convert_to_tensor(mask, dtype=tf.float32)

def get_window_data_tf(stft, window_size=9):
    transposed_stft = tf.transpose(stft)
    zeros_for_window = np.zeros((window_size // 2, transposed_stft.shape[1]))
    windows = np.zeros((transposed_stft.shape[0], window_size, transposed_stft.shape[1]), dtype=np.float32)
    data_for_windows = np.concatenate([zeros_for_window, transposed_stft, zeros_for_window])
    
    for i in range(transposed_stft.shape[0]):
        windows[i] = data_for_windows[i:i+window_size]

    return tf.convert_to_tensor(windows, dtype=tf.float32)


def process_with_tf(audio_file, vocal_file):
    stft_left, stft_right, stft_left_vocals, stft_right_vocals = tf.py_function(
        func=process_audio, inp=[audio_file, vocal_file], Tout=(tf.float32, tf.float32, tf.float32, tf.float32)
    )
    mask_left = tf.py_function(
        func=get_ideal_mask_tf, inp=[stft_left, stft_left_vocals], Tout=tf.float32
    )
    mask_right = tf.py_function(
        func=get_ideal_mask_tf, inp=[stft_right, stft_right_vocals], Tout=tf.float32
    )
    left_windows = tf.py_function(
        func=get_window_data_tf, inp=[stft_left], Tout=tf.float32
    )
    right_windows = tf.py_function(
        func=get_window_data_tf, inp=[stft_right], Tout=tf.float32
    )
    
    left_ds = tf.data.Dataset.from_tensor_slices(left_windows)
    right_ds = tf.data.Dataset.from_tensor_slices(right_windows)
    left_ds_vocals = tf.data.Dataset.from_tensor_slices(tf.transpose(mask_left))
    right_ds_vocals = tf.data.Dataset.from_tensor_slices(tf.transpose(mask_right))

    return tf.data.Dataset.zip(left_ds.concatenate(right_ds), left_ds_vocals.concatenate(right_ds_vocals))

In [165]:
dataset = tf.data.Dataset.from_tensor_slices((audio_files, vocals_files))
dataset = dataset.flat_map(process_with_tf)
dataset = dataset.map(lambda x, y: (tf.ensure_shape(x, (9, 1025)), tf.ensure_shape(y, (1025, ))) )
dataset = dataset.batch(32)
dataset = dataset.map(lambda x, y: (tf.ensure_shape(x, (32, 9, 1025)), tf.ensure_shape(y, (32, 1025))))
dataset = dataset.prefetch(tf.data.AUTOTUNE)

In [166]:
for batch in dataset.take(1):
    x_sample, y_sample = batch
    print("Sample input shape (X):", x_sample.shape)  # (batch_size, window_size, freq_bins, 1)
    print("Sample output shape (Y):", y_sample.shape)  # (batch_size, freq_bins, 1)

Sample input shape (X): (32, 9, 1025)
Sample output shape (Y): (32, 1025)


In [167]:
dataset

<_PrefetchDataset element_spec=(TensorSpec(shape=(32, 9, 1025), dtype=tf.float32, name=None), TensorSpec(shape=(32, 1025), dtype=tf.float32, name=None))>

## Training model

In [168]:
for batch in dataset.take(1):
    x_sample, y_sample = batch
    print("Sample input shape (X):", x_sample.shape)  # (batch_size, window_size, freq_bins, 1)
    print("Sample output shape (Y):", y_sample.shape)  # (batch_size, freq_bins, 1)

Sample input shape (X): (32, 9, 1025)
Sample output shape (Y): (32, 1025)


In [169]:
import tensorflow as tf
from tensorflow.keras import layers

In [170]:
freq_bins = 1025
window_size = 9

def create_model():
    inputs = tf.keras.Input(shape=(window_size, freq_bins), name="input_audio")
    inputs = tf.keras.layers.Reshape((9, 1025, 1))(inputs)  # Теперь данные имеют форму (9, 1025, 1)
    
    x = tf.keras.layers.Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(256, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    
    outputs = tf.keras.layers.Dense(freq_bins, activation="linear")(x)
    outputs = tf.keras.layers.Reshape((freq_bins, 1))(outputs)
    
    model = tf.keras.Model(inputs, outputs, name="vocal_separator")
    return model

In [171]:
nn = create_model()

In [172]:
nn.summary()

In [175]:
nn.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss="categorical_crossentropy",
    metrics=["mse"]
)
epochs = 1

In [176]:
history = nn.fit(dataset, epochs=epochs)

    522/Unknown [1m241s[0m 454ms/step - loss: 1281.5637 - mse: 4025449728.0000

KeyboardInterrupt: 