In [1]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.io.wavfile as wav
from scipy.signal import stft, istft

In [2]:
import tensorflow as tf

In [9]:
import os

# Sampling rate in Hz. NOTE(review): the helpers below take their own
# `sample_rate` argument (and `read_track` binds a local of the same name), so
# this module-level value appears to be a convenience default only -- confirm.
sample_rate = 44100


def get_vocals_mask(Zxx, threshold=1):
    """Build a 0/1 mask of the STFT bins whose magnitude exceeds ``threshold``.

    Parameters
    ----------
    Zxx : array_like
        STFT coefficients (real or complex); only the magnitude is used.
    threshold : float, default 1
        Absolute magnitude a bin must strictly exceed to be marked 1.0.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``Zxx`` containing 1.0 / 0.0.
    """
    magnitude = np.abs(Zxx)
    return np.greater(magnitude, threshold).astype(float)

def get_ideal_binary_mask(Zxx_vocals, Zxx_mix, threshold=1):
    """Mark the bins where the vocals dominate the mixture.

    A bin is set to 1.0 when ``|Zxx_vocals| > threshold * |Zxx_mix|`` and to
    0.0 otherwise.

    Parameters
    ----------
    Zxx_vocals, Zxx_mix : array_like
        STFT coefficients (or magnitudes) of the isolated vocals and of the
        full mixture; they are compared element-wise.
    threshold : float, default 1
        Factor applied to the mixture magnitude before the comparison.

    Returns
    -------
    numpy.ndarray
        Float mask of 1.0 / 0.0 values.
    """
    vocals_magnitude = np.abs(Zxx_vocals)
    scaled_mix_magnitude = threshold * np.abs(Zxx_mix)
    return (vocals_magnitude > scaled_mix_magnitude).astype(float)

def get_inverse_mask(binary_mask):
    """Return the complement of a mask: 1.0 where ``binary_mask`` is 0, else 0.0.

    Any non-zero entry (not only 1.0) maps to 0.0 because the test is an
    equality check against zero.
    """
    is_zero = binary_mask == 0
    return is_zero.astype(float)

def get_track_list(root='data/musdb/musdb18hq/train'):
    """List the entries found under ``root``, sorted by name.

    Parameters
    ----------
    root : str, default 'data/musdb/musdb18hq/train'
        Directory to list. The default matches the directory that
        ``read_track`` reads from.

    Returns
    -------
    list of str
        Entry names in ``root``. ``os.listdir`` returns names in arbitrary
        order, so they are sorted to make runs reproducible.

    Raises
    ------
    FileNotFoundError
        If ``root`` does not exist.
    """
    return sorted(os.listdir(root))

def read_track(filename):
    """Load the mixture and vocals WAV files of one training track.

    Parameters
    ----------
    filename : str
        Name of the track directory under ``data/musdb/musdb18hq/train``.

    Returns
    -------
    tuple
        ``(sample_rate, mix, vocals)``. ``sample_rate`` is the rate read from
        ``mixture.wav``; the rate stored in ``vocals.wav`` is not returned.
    """
    directory = f'data/musdb/musdb18hq/train/{filename}'
    # Vocals are read first; their sample rate is discarded.
    _, vocals = wav.read(f'{directory}/vocals.wav')
    sample_rate, mix = wav.read(f'{directory}/mixture.wav')
    return sample_rate, mix, vocals

def transform_track(signal, sample_rate, nperseg=2048, noverlap=2048 // 2):
    """Compute a Hamming-window STFT for each channel of a two-channel signal.

    Parameters
    ----------
    signal : numpy.ndarray
        Array indexed as ``signal[:, channel]``; columns 0 and 1 are used.
    sample_rate : float
        Sampling frequency passed to ``scipy.signal.stft`` as ``fs``.
    nperseg : int, default 2048
        Segment length.
    noverlap : int, default 1024
        Number of overlapping samples between segments.

    Returns
    -------
    tuple
        ``(frequencies, times, Zxx_l, Zxx_r)`` where the two ``Zxx`` arrays
        are the STFTs of column 0 and column 1.
    """
    stft_options = dict(fs=sample_rate, nperseg=nperseg, noverlap=noverlap, window='hamming')
    # Both columns have the same length and settings, so the axes are taken
    # from the second call only, as in a plain sequential pair of calls.
    _, _, Zxx_l = stft(signal[:, 0], **stft_options)
    frequencies, times, Zxx_r = stft(signal[:, 1], **stft_options)
    return frequencies, times, Zxx_l, Zxx_r

def plot_spectogram(times, frequencies, Zxx, segment=None, freq_limit=4000, log=False):
    """Draw the magnitude of an STFT as a greyscale time/frequency image.

    Parameters
    ----------
    times, frequencies : array_like
        Axis values for the columns and rows of ``Zxx``.
    Zxx : numpy.ndarray
        STFT coefficients indexed as ``Zxx[frequency, time]``.
    segment : sequence of two ints, optional
        ``(start, stop)`` column indices to plot. A falsy value (``None`` or
        an empty sequence) plots every column.
    freq_limit : float, default 4000
        Upper limit of the y axis; the lower limit is 0.
    log : bool, default False
        If true, plot ``log10(|Zxx| + 1e-6)`` instead of ``|Zxx|``.
    """
    plt.figure(figsize=(12, 6))

    if segment:
        magnitude = np.abs(Zxx[:, segment[0]:segment[1]])
    else:
        magnitude = np.abs(Zxx)

    if log:
        # The small offset keeps log10 finite for zero-magnitude bins.
        magnitude = np.log10(magnitude + 1e-6)

    if segment:
        times_for_plot = times[segment[0]:segment[1]]
    else:
        times_for_plot = times

    plt.pcolormesh(times_for_plot, frequencies, magnitude, shading='gouraud', cmap='Greys_r')
    plt.title('STFT Magnitude of Music Track')
    plt.ylabel('Frequency [Hz]')
    plt.xlabel('Time [sec]')
    plt.colorbar(label='Magnitude')
    plt.ylim([0, freq_limit])
    plt.show()

def get_window_data(Zxx, window_size=5):
    """Cut a spectrogram into one zero-padded context window per time frame.

    Parameters
    ----------
    Zxx : numpy.ndarray
        2-D array indexed as ``Zxx[frequency, time]``.
    window_size : int, default 5
        Number of consecutive frames in each window.

    Returns
    -------
    numpy.ndarray
        ``float32`` array of shape ``(n_frames, window_size, n_bins)``.
        Window ``i`` holds padded frames ``i .. i + window_size - 1``, where
        the frame sequence is padded with ``window_size // 2`` all-zero frames
        on both ends (so window ``i`` is centred on frame ``i`` for odd sizes).
    """
    frames = Zxx.T
    n_frames, n_bins = frames.shape[0], frames.shape[1]

    edge_padding = np.zeros((window_size // 2, n_bins))
    # Cast once up front so the gathered result is already float32.
    padded_frames = np.concatenate([edge_padding, frames, edge_padding]).astype(np.float32)

    # Row i of the index grid is [i, i + 1, ..., i + window_size - 1].
    frame_index = np.arange(n_frames)[:, None] + np.arange(window_size)[None, :]
    return padded_frames[frame_index]


In [114]:
# Build two index-aligned lists: audio_files[i] is the mixture and
# vocals_files[i] the vocals stem of the same track.
#
# Each directory is listed exactly once and the names are sorted, so that
# (a) the mixture and vocals lists are guaranteed to refer to the same tracks
#     in the same order, and
# (b) the 30-track ccMixter subset is the same on every machine
#     (os.listdir returns names in arbitrary order).
# NOTE: because of the sort, the 30 ccMixter tracks picked may differ from an
# earlier run that relied on the raw listing order.
musdb_train_dir = 'data/musdb/musdb18hq/train'
ccmixter_dir = 'data/ccmixter_corpus'

musdb_tracks = sorted(os.listdir(musdb_train_dir))
ccmixter_tracks = sorted(os.listdir(ccmixter_dir))[:30]

audio_files = [f'{musdb_train_dir}/{track}/mixture.wav' for track in musdb_tracks]
vocals_files = [f'{musdb_train_dir}/{track}/vocals.wav' for track in musdb_tracks]
audio_files.extend([f'{ccmixter_dir}/{track}/mix.wav' for track in ccmixter_tracks])
vocals_files.extend([f'{ccmixter_dir}/{track}/source-02.wav' for track in ccmixter_tracks])


def read_tf_file(file_path):
    """Read a WAV file whose path arrives wrapped in a tensor-like object.

    Parameters
    ----------
    file_path : object
        Must expose ``.numpy()`` returning the UTF-8 encoded path as bytes
        (as the string tensors handed to ``tf.py_function`` callbacks do).

    Returns
    -------
    numpy.ndarray
        The audio samples converted to ``float32``. The file's sample rate is
        read but not returned.
    """
    decoded_path = file_path.numpy().decode("utf-8")
    _, samples = wav.read(decoded_path)
    return samples.astype(np.float32)

def stft_transform(signal, sample_rate=44100, nperseg=2048, noverlap=2048 // 2):
    """Return the Hamming-window STFTs of columns 0 and 1 of ``signal``.

    Same transform as ``transform_track`` but without the frequency and time
    axes in the result.

    Parameters
    ----------
    signal : numpy.ndarray
        Array indexed as ``signal[:, channel]``.
    sample_rate : float, default 44100
        Sampling frequency passed to ``scipy.signal.stft`` as ``fs``.
    nperseg : int, default 2048
        Segment length.
    noverlap : int, default 1024
        Number of overlapping samples between segments.

    Returns
    -------
    tuple
        ``(Zxx_l, Zxx_r)``: the STFT of column 0 and of column 1.
    """
    stft_options = dict(fs=sample_rate, nperseg=nperseg, noverlap=noverlap, window='hamming')
    # stft returns (frequencies, times, Zxx); only Zxx is kept.
    return tuple(stft(signal[:, channel], **stft_options)[2] for channel in (0, 1))

def process_audio(audio_file, vocal_file):
    """Load a mixture/vocals file pair and return their STFT magnitudes.

    Parameters
    ----------
    audio_file, vocal_file : tensor-like
        Path tensors accepted by ``read_tf_file`` (mixture and vocals).

    Returns
    -------
    tuple
        ``(|mix_left|, |mix_right|, |vocals_left|, |vocals_right|)``, each
        produced by ``tf.abs`` on the corresponding STFT.
    """
    mix_left, mix_right = stft_transform(read_tf_file(audio_file))
    vocals_left, vocals_right = stft_transform(read_tf_file(vocal_file))

    spectra = (mix_left, mix_right, vocals_left, vocals_right)
    return tuple(tf.abs(spectrum) for spectrum in spectra)

def get_ideal_mask_tf(stft, stft_vocals):
    """Compute the ideal binary mask and return it as a ``float32`` tensor.

    Parameters
    ----------
    stft : tensor-like
        Mixture STFT (or magnitude). Note that inside this function the name
        shadows the ``stft`` function imported from ``scipy.signal``.
    stft_vocals : tensor-like
        Vocals STFT (or magnitude), compared element-wise with ``stft``.

    Returns
    -------
    tf.Tensor
        ``float32`` tensor that is 1.0 where the vocals magnitude exceeds the
        mixture magnitude (``get_ideal_binary_mask`` with its default
        threshold) and 0.0 elsewhere.
    """
    ideal_mask = get_ideal_binary_mask(Zxx_vocals=stft_vocals, Zxx_mix=stft)
    return tf.convert_to_tensor(ideal_mask, dtype=tf.float32)

def get_window_data_tf(stft, window_size=9):
    """Tensor-returning wrapper around ``get_window_data``.

    The windowing itself (transpose to time-major, zero-pad
    ``window_size // 2`` frames on both ends, one window per frame) is
    delegated to ``get_window_data`` so the logic is not kept in two places.

    Parameters
    ----------
    stft : tensor-like
        2-D spectrogram indexed as ``stft[frequency, time]``; anything
        ``np.asarray`` can convert (e.g. the eager tensors passed to a
        ``tf.py_function`` callback).
    window_size : int, default 9
        Number of consecutive frames per window.

    Returns
    -------
    tf.Tensor
        ``float32`` tensor of shape ``(n_frames, window_size, n_bins)``.
    """
    windows = get_window_data(np.asarray(stft), window_size=window_size)
    return tf.convert_to_tensor(windows, dtype=tf.float32)


def process_with_tf(audio_file, vocal_file):
    """Turn one (mixture, vocals) file pair into a dataset of training examples.

    Intended as the function passed to ``Dataset.flat_map``. All NumPy/SciPy
    work runs through ``tf.py_function``, so the resulting tensors carry no
    static shape information; callers that need it must restore it (e.g. with
    ``tf.ensure_shape``).

    Parameters
    ----------
    audio_file, vocal_file : tf.Tensor
        Scalar string tensors holding the mixture and vocals WAV paths.

    Returns
    -------
    tf.data.Dataset
        Yields ``(window, mask)`` pairs: a context window of mixture
        magnitudes and the ideal binary mask of that window's frame. All
        left-channel examples come first, followed by the right-channel ones.
    """
    stft_left, stft_right, stft_left_vocals, stft_right_vocals = tf.py_function(
        func=process_audio, inp=[audio_file, vocal_file], Tout=(tf.float32, tf.float32, tf.float32, tf.float32)
    )
    mask_left = tf.py_function(
        func=get_ideal_mask_tf, inp=[stft_left, stft_left_vocals], Tout=tf.float32
    )
    mask_right = tf.py_function(
        func=get_ideal_mask_tf, inp=[stft_right, stft_right_vocals], Tout=tf.float32
    )
    left_windows = tf.py_function(
        func=get_window_data_tf, inp=[stft_left], Tout=tf.float32
    )
    right_windows = tf.py_function(
        func=get_window_data_tf, inp=[stft_right], Tout=tf.float32
    )

    left_ds = tf.data.Dataset.from_tensor_slices(left_windows)
    right_ds = tf.data.Dataset.from_tensor_slices(right_windows)
    # Masks are (frequency, time); transpose so slicing yields one mask per
    # time frame, matching the per-frame windows above.
    left_ds_vocals = tf.data.Dataset.from_tensor_slices(tf.transpose(mask_left))
    right_ds_vocals = tf.data.Dataset.from_tensor_slices(tf.transpose(mask_right))

    features_ds = left_ds.concatenate(right_ds)
    targets_ds = left_ds_vocals.concatenate(right_ds_vocals)

    # Pass the datasets as ONE tuple: this form is accepted by every TF 2.x
    # release, whereas `zip(a, b)` is only understood by newer versions (older
    # ones interpret the second positional argument as `name`).
    return tf.data.Dataset.zip((features_ds, targets_ds))

In [123]:
# Input pipeline: one (mixture, vocals) path pair -> many (window, mask) examples.
dataset = tf.data.Dataset.from_tensor_slices((audio_files, vocals_files))
dataset = dataset.flat_map(process_with_tf)
# py_function outputs have unknown static shape; pin the per-example shapes.
# 9 = default window_size of get_window_data_tf, 1025 = 2048 // 2 + 1 one-sided
# STFT bins for the default nperseg.
dataset = dataset.map(lambda x, y: (tf.ensure_shape(x, (9, 1025)), tf.ensure_shape(y, (1025, ))) )
dataset = dataset.batch(32)
# The batch dimension is left as None: the final batch of an epoch usually
# holds fewer than 32 examples, and asserting a fixed 32 here would raise at
# runtime once that partial batch is reached.
dataset = dataset.map(lambda x, y: (tf.ensure_shape(x, (None, 9, 1025)), tf.ensure_shape(y, (None, 1025))))
dataset = dataset.prefetch(tf.data.AUTOTUNE)

In [116]:
# Sanity check: pull a single element from the pipeline and print its shapes.
# NOTE(review): the saved output of this cell shows unbatched shapes
# ((9, 1025) / (1025,)), so it was presumably last run before `.batch(32)` was
# part of the pipeline -- re-run to refresh.
for batch in dataset.take(1):
    x_sample, y_sample = batch
    print("Sample input shape (X):", x_sample.shape)
    print("Sample output shape (Y):", y_sample.shape)

Sample input shape (X): (9, 1025)
Sample output shape (Y): (1025,)


In [105]:
# Display the dataset's repr (its element_spec).
# NOTE(review): the saved output reports shape=<unknown>, which presumably
# predates the tf.ensure_shape calls in the pipeline cell -- re-run to refresh.
dataset

<_PrefetchDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), TensorSpec(shape=<unknown>, dtype=tf.float32, name=None))>

## Training model

In [36]:
# Same shape check as in the data-pipeline section, repeated before training;
# the saved output here shows the batched shapes (32, 9, 1025) / (32, 1025).
for batch in dataset.take(1):
    x_sample, y_sample = batch
    print("Sample input shape (X):", x_sample.shape)
    print("Sample output shape (Y):", y_sample.shape)

Sample input shape (X): (32, 9, 1025)
Sample output shape (Y): (32, 1025)


In [10]:
from tensorflow.python.client import device_lib

def get_available_devices():
    """Return the names of the logical devices TensorFlow can use.

    Uses the public ``tf.config.list_logical_devices`` API instead of the
    private ``tensorflow.python.client.device_lib`` module, which is not part
    of TensorFlow's supported interface.

    Returns
    -------
    list of str
        Device names such as ``'/device:CPU:0'``.
    """
    return [device.name for device in tf.config.list_logical_devices()]

print(get_available_devices())

['/device:CPU:0']


In [26]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, 
    Conv2D, 
    MaxPooling2D, 
    UpSampling2D, 
    concatenate, 
    GlobalAveragePooling2D, 
    Dense, 
    Flatten, 
    Reshape, 
    ZeroPadding2D
)

In [74]:
freq_bins = 1025
window_size = 9

def create_model():
    """Build a U-Net-style convolutional model that predicts a per-bin vocal mask.

    The model takes a ``(window_size, freq_bins)`` context window and returns
    ``freq_bins`` sigmoid activations.

    Returns
    -------
    tensorflow.keras.Model
        Uncompiled model named ``"vocal_separator"``.
    """
    # Input. The Input tensor keeps its own name: the Model below must be
    # built from it, not from the reshaped tensor.
    inputs = Input(shape=(window_size, freq_bins))
    # Add a trailing channel axis so the 2-D convolutions can be applied.
    x = Reshape((window_size, freq_bins, 1))(inputs)

    # Encoder (each 2x2 pooling halves both axes, rounding down:
    # 9x1025 -> 4x512 -> 2x256 -> 1x128)
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    # Bottleneck
    conv4 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv4)

    # Decoder (each stage upsamples and concatenates the matching encoder output)
    up5 = UpSampling2D(size=(2, 2))(conv4)
    up5 = concatenate([up5, conv3])
    conv5 = Conv2D(256, (3, 3), activation='relu', padding='same')(up5)
    conv5 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv5)

    up6 = UpSampling2D(size=(2, 2))(conv5)
    up6 = concatenate([up6, conv2])
    conv6 = Conv2D(128, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv6)

    up7 = UpSampling2D(size=(2, 2))(conv6)
    # Upsampling yields 8x1024; pad one row/column to get back to the 9x1025
    # of conv1 (the odd sizes were rounded down by the first pooling).
    up7 = ZeroPadding2D(padding=((0, 1), (0, 1)))(up7)
    up7 = concatenate([up7, conv1])
    conv7 = Conv2D(64, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv7)

    # Output
    # NOTE(review): global average pooling over a single-channel map leaves
    # ONE scalar per example, so the final Dense layer sees a single number.
    # That discards the per-frequency structure produced by the decoder --
    # confirm this head is intended.
    final_conv = Conv2D(1, (3, 3), activation='relu', padding='same')(conv7)
    pooled = GlobalAveragePooling2D()(final_conv)
    outputs = Dense(freq_bins, activation='sigmoid')(pooled)

    model = Model(inputs=inputs, outputs=outputs, name="vocal_separator")
    return model

In [64]:
freq_bins = 1025
window_size = 9

def create_model2():
    """Build a small conv + dense model mapping a context window to ``freq_bins`` values.

    The output layer uses a linear activation, so the outputs are unbounded
    scores, not values in [0, 1].
    NOTE(review): if this model is trained against a 0/1 mask with a
    cross-entropy loss, the loss presumably needs ``from_logits=True``.

    Returns
    -------
    tf.keras.Model
        Uncompiled model named ``"vocal_separator"``.
    """
    # The Input tensor keeps its own name: the Model below must be built from
    # it, not from the reshaped tensor.
    inputs = tf.keras.Input(shape=(window_size, freq_bins), name="input_audio")
    # Add a trailing channel axis for Conv2D (sizes taken from the constants
    # above, not repeated as literals).
    x = tf.keras.layers.Reshape((window_size, freq_bins, 1))(inputs)

    x = tf.keras.layers.Conv2D(32, (3, 3), activation="relu", padding="same")(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)

    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(256, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.3)(x)

    outputs = tf.keras.layers.Dense(freq_bins, activation="linear")(x)

    model = tf.keras.Model(inputs, outputs, name="vocal_separator")
    return model

In [117]:
freq_bins = 1025
window_size = 9

def create_model3():
    """Build a fully-connected baseline: flatten -> Dense(256) -> dropout -> sigmoid.

    Returns
    -------
    tf.keras.Model
        Uncompiled model named ``"vocal_separator"`` that maps a
        ``(window_size, freq_bins)`` window to ``freq_bins`` sigmoid outputs.
    """
    layers = tf.keras.layers

    inputs = tf.keras.Input(shape=(window_size, freq_bins), name="input_audio")

    hidden = layers.Flatten()(inputs)
    hidden = layers.Dense(256, activation="relu")(hidden)
    hidden = layers.Dropout(0.1)(hidden)

    outputs = layers.Dense(freq_bins, activation="sigmoid")(hidden)

    return tf.keras.Model(inputs, outputs, name="vocal_separator")

In [125]:
# Instantiate the model to train; create_model2() / create_model3() are the
# alternative architectures defined above.
nn = create_model()

In [126]:
# Print the layer-by-layer summary of the selected model.
nn.summary()

In [119]:
# The target is a 0/1 mask with one independent decision per frequency bin, so
# the loss must be binary cross-entropy. Categorical cross-entropy would treat
# the 1025 bins as mutually exclusive classes of a single distribution.
# With this loss the 'accuracy' metric resolves to per-bin binary accuracy.
# NOTE(review): this assumes the model ends in a sigmoid (create_model and
# create_model3 do); create_model2 has a linear output and would need a
# from_logits loss.
nn.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss="binary_crossentropy",
    metrics=['accuracy', "mse"]
)
epochs = 1

In [121]:
# Train on the batched pipeline for `epochs` passes.
# NOTE(review): the saved output of this cell is a ValueError ("Creating
# variables on a non-first call to a function decorated with tf.function").
# The cause is not visible in this notebook; given the out-of-order execution
# counts, re-run from a fresh kernel (Restart & Run All) to check it persists.
history = nn.fit(dataset, epochs=epochs)

ValueError: Creating variables on a non-first call to a function decorated with tf.function.

In [127]:
# Re-split the batched dataset into single-example batches and print the
# model's prediction for the first five examples (i[0] is the input window).
# NOTE(review): the saved predictions are all close to 0.5; judging by the
# execution counts the model was re-created after the failed fit, so these
# presumably come from untrained weights -- confirm after a successful fit.
for i in dataset.unbatch().take(5).batch(1):
    print(nn.predict(i[0]))

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 738ms/step
[[0.49917254 0.49966943 0.50107074 ... 0.50093186 0.50136745 0.50158906]]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 235ms/step
[[0.49873325 0.49949393 0.5016391  ... 0.5014266  0.50209343 0.5024327 ]]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 252ms/step
[[0.49821535 0.499287   0.5023093  ... 0.50200987 0.50294936 0.5034272 ]]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 251ms/step
[[0.49761778 0.49904826 0.5030825  ... 0.5026828  0.5039369  0.5045748 ]]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 255ms/step
[[0.49702236 0.49881035 0.503853   ... 0.50335336 0.50492084 0.5057181 ]]
