In [8]:
!pip install musdb




In [9]:
import numpy as np
import librosa
import tensorflow as tf
import musdb
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Load MUSDB18 dataset
def load_musdb_data(sample_rate=44100):
    """Load the MUSDB18 training subset as per-track waveforms.

    Opens the dataset with ``musdb.DB(download=True, subsets='train')`` and
    collects, for every track, the mixture plus the ``vocals``, ``drums``,
    ``bass`` and ``other`` stems.

    Args:
        sample_rate: Target sampling rate in Hz. When a track's native rate
            differs, the mixture *and* all four stems are resampled so that
            they stay sample-aligned with each other.

    Returns:
        Tuple ``(mixtures, vocals, drums, bass, others)`` of NumPy arrays
        indexed by track. If every track has the same shape each array is a
        regular stacked array; if track lengths differ, each is a 1-D
        ``dtype=object`` array holding one waveform per track (a plain
        ``np.array`` call on ragged input raises on recent NumPy versions).
    """
    mus = musdb.DB(download=True, subsets='train')

    stem_names = ('vocals', 'drums', 'bass', 'other')
    mixtures = []
    stems = {name: [] for name in stem_names}

    def _at_target_rate(audio, sr):
        # Bring one waveform to `sample_rate`, leaving it untouched if it is already there.
        if sr == sample_rate:
            return audio
        # musdb documents audio as samples-first (n_samples, n_channels) while
        # librosa resamples along the last axis, hence the transposes.
        # orig_sr/target_sr are passed by keyword: they are keyword-only in
        # current librosa releases.
        resampled = librosa.resample(audio.T, orig_sr=sr, target_sr=sample_rate)
        return np.ascontiguousarray(resampled.T)

    def _stack_tracks(tracks):
        # Regular array when shapes agree, otherwise an explicit object array.
        if len({np.shape(t) for t in tracks}) <= 1:
            return np.array(tracks)
        ragged = np.empty(len(tracks), dtype=object)
        for i, t in enumerate(tracks):
            ragged[i] = t
        return ragged

    for track in mus:
        sr = track.rate
        mixtures.append(_at_target_rate(track.audio, sr))
        for name in stem_names:
            stems[name].append(_at_target_rate(track.targets[name].audio, sr))

    mixtures = _stack_tracks(mixtures)
    vocals = _stack_tracks(stems['vocals'])
    drums = _stack_tracks(stems['drums'])
    bass = _stack_tracks(stems['bass'])
    others = _stack_tracks(stems['other'])

    return mixtures, vocals, drums, bass, others


In [10]:
def audio_to_spectrogram(audio, n_fft=1024, hop_length=512):
    """Compute the STFT magnitude and phase of a waveform.

    Args:
        audio: Waveform as a 1-D array of samples, or a 2-D samples-first
            array ``(n_samples, n_channels)``. 2-D input is downmixed to mono
            by averaging the channels, because ``librosa.stft`` treats the
            *last* axis as time and would otherwise misread a samples-first
            array.
        n_fft: FFT window size. Clips shorter than this are zero-padded at
            the end to ``n_fft`` samples.
        hop_length: Hop between successive STFT frames, in samples.

    Returns:
        Tuple ``(magnitude, phase)``: ``np.abs`` and ``np.angle`` of the
        complex STFT, both with the same shape as the STFT.
    """
    audio = np.asarray(audio)

    if audio.ndim == 2:
        # Samples-first stereo/multichannel -> mono. The length check below
        # already assumes axis 0 is time, so this keeps the layout consistent.
        audio = audio.mean(axis=1)

    # Pad audio with zeros if it is shorter than one analysis window.
    if len(audio) < n_fft:
        audio = np.pad(audio, (0, n_fft - len(audio)), 'constant')

    stft = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
    magnitude, phase = np.abs(stft), np.angle(stft)
    return magnitude, phase


def spectrogram_to_audio(magnitude, phase, hop_length=512):
    """Rebuild a waveform from an STFT magnitude and phase pair.

    The two arrays are recombined into a complex STFT
    (``magnitude * exp(1j * phase)``) and inverted with ``librosa.istft``.

    Args:
        magnitude: STFT magnitudes.
        phase: STFT phase angles in radians, same shape as ``magnitude``.
        hop_length: Hop, in samples, forwarded to ``librosa.istft``.

    Returns:
        The waveform returned by ``librosa.istft``.
    """
    unit_phasors = np.exp(1j * phase)
    complex_stft = magnitude * unit_phasors
    return librosa.istft(complex_stft, hop_length=hop_length)


In [11]:
def preprocess_spectrograms(audio_list, n_fft=1024, hop_length=512):
    """Convert a collection of waveforms into magnitude and phase spectrograms.

    Each waveform is passed to ``audio_to_spectrogram`` (which itself
    zero-pads clips shorter than ``n_fft``). If that call raises, the error is
    printed and all-zero magnitude/phase arrays of shape
    ``(n_fft // 2 + 1, max(len(audio), n_fft) // hop_length + 1)`` are used
    for that clip instead, so one bad clip does not abort the whole batch.
    NOTE(review): zero spectrograms then flow into whatever consumes the
    result, so watch the printed messages.

    Args:
        audio_list: Iterable of waveforms.
        n_fft: FFT window size forwarded to ``audio_to_spectrogram``.
        hop_length: STFT hop forwarded to ``audio_to_spectrogram``.

    Returns:
        Tuple ``(magnitudes, phases)`` of NumPy arrays indexed by clip. When
        all spectrograms share one shape they are stacked into a regular
        array; when frame counts differ each result is a 1-D ``dtype=object``
        array with one spectrogram per clip (a plain ``np.array`` call on
        ragged input raises on recent NumPy versions).
    """
    n_bins = n_fft // 2 + 1

    def _stack(spectrograms):
        # Regular array when shapes agree, otherwise an explicit object array.
        if len({np.shape(s) for s in spectrograms}) <= 1:
            return np.array(spectrograms)
        ragged = np.empty(len(spectrograms), dtype=object)
        for i, s in enumerate(spectrograms):
            ragged[i] = s
        return ragged

    magnitudes = []
    phases = []

    for audio in audio_list:
        try:
            mag, ph = audio_to_spectrogram(audio, n_fft, hop_length)
        except Exception as e:
            print(f'Error processing audio: {e}')
            # Frame count an n_fft-padded clip of this length would produce.
            n_frames = max(len(audio), n_fft) // hop_length + 1
            mag = np.zeros((n_bins, n_frames))
            ph = np.zeros((n_bins, n_frames))

        magnitudes.append(mag)
        phases.append(ph)

    return _stack(magnitudes), _stack(phases)


In [12]:
def unet_model(input_shape):
    """Build a three-level U-Net producing a single-channel map the size of its input.

    Encoder: 3x3 ReLU convolutions with 64, 128 and 256 filters, each followed
    by 2x2 max-pooling, then a 512-filter bottleneck. Decoder: three stages of
    2x upsampling + 3x3 ReLU convolution (256, 128, 64 filters), each
    concatenated with the encoder feature map of the same resolution. A final
    1x1 convolution with sigmoid activation yields one output channel.

    The three poolings floor odd sizes, so the skip connections only line up
    when height and width are multiples of 8. When a statically known
    dimension is not (e.g. 513 frequency bins from ``n_fft=1024``), the input
    is zero-padded at the bottom/right up to the next multiple of 8 and the
    output is cropped back, so the model's output height/width always equal
    the input's. For sizes already divisible by 8 no extra layers are added.

    NOTE(review): the sigmoid bounds every output value to [0, 1]; targets
    outside that range cannot be matched exactly. Confirm the training targets
    are normalised accordingly (or that the output is meant as a mask).

    Args:
        input_shape: ``(height, width, channels)`` of one input sample.
            ``None`` entries (dynamic sizes) are left unpadded.

    Returns:
        An uncompiled ``tf.keras.models.Model``.
    """
    layers = tf.keras.layers
    size_multiple = 8  # 2 ** (number of 2x2 pooling stages)

    def _extra(dim):
        # Rows/columns needed to reach the next multiple of `size_multiple`.
        return 0 if dim is None else (-dim) % size_multiple

    def _conv3x3(filters):
        return layers.Conv2D(filters, kernel_size=3, activation='relu', padding='same')

    pad_h, pad_w = _extra(input_shape[0]), _extra(input_shape[1])
    needs_padding = bool(pad_h or pad_w)

    inputs = layers.Input(shape=input_shape)
    x = inputs
    if needs_padding:
        x = layers.ZeroPadding2D(padding=((0, pad_h), (0, pad_w)))(x)

    # Encoder
    conv1 = _conv3x3(64)(x)
    pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = _conv3x3(128)(pool1)
    pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = _conv3x3(256)(pool2)
    pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3)

    # Bottleneck
    conv4 = _conv3x3(512)(pool3)

    # Decoder with skip connections
    up5 = layers.UpSampling2D(size=(2, 2))(conv4)
    up5 = _conv3x3(256)(up5)
    up5 = layers.Concatenate()([up5, conv3])

    up6 = layers.UpSampling2D(size=(2, 2))(up5)
    up6 = _conv3x3(128)(up6)
    up6 = layers.Concatenate()([up6, conv2])

    up7 = layers.UpSampling2D(size=(2, 2))(up6)
    up7 = _conv3x3(64)(up7)
    up7 = layers.Concatenate()([up7, conv1])

    outputs = layers.Conv2D(1, kernel_size=1, activation='sigmoid')(up7)
    if needs_padding:
        # Drop the rows/columns that were added purely for size alignment.
        outputs = layers.Cropping2D(cropping=((0, pad_h), (0, pad_w)))(outputs)

    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    return model


In [13]:
!pip install keras



In [None]:
import keras
import keras.backend as K
import tensorflow as tf
from tensorflow.keras.utils import plot_model

# Fetch the MUSDB18 training tracks: the mixture plus the four stems.
mixtures, vocals, drums, bass, others = load_musdb_data()

# Spectrograms: magnitude and phase for the mixtures, magnitude only for the
# vocal stems (their phase is discarded).
mixture_mags, mixture_phases = preprocess_spectrograms(mixtures)
vocals_mags, _ = preprocess_spectrograms(vocals)

# Hold out 20% of the tracks for validation; fixed seed for a repeatable split.
X_train, X_val, y_train, y_val = train_test_split(
    mixture_mags,
    vocals_mags,
    test_size=0.2,
    random_state=42,
)

# One sample is a (dim 1, dim 2) spectrogram with a trailing channel axis.
input_shape = (mixture_mags.shape[1], mixture_mags.shape[2], 1)

# Build the U-Net and train it with Adam on mean absolute error.
model = unet_model(input_shape)
model.compile(optimizer='adam', loss='mae')

# Render the architecture diagram as this cell's output.
plot_model(
    model,
    show_shapes=True,
    show_dtype=False,
    show_layer_names=True,
    rankdir='TB',
    expand_nested=False,
    dpi=70,
)




In [None]:

# Fit the U-Net: mixture magnitudes in, vocal magnitudes out. A trailing
# channel axis is appended to every array before it is handed to Keras.
history = model.fit(
    x=X_train[..., np.newaxis],
    y=y_train[..., np.newaxis],
    validation_data=(X_val[..., np.newaxis], y_val[..., np.newaxis]),
    batch_size=8,
    epochs=50,
)

In [None]:
def predict_and_save_audio(test_file_path, model, n_fft=1024, hop_length=512,
                           output_path='predicted_vocal.wav', sample_rate=44100):
    """Run the model on an audio file and write the predicted vocals as a WAV file.

    The file is loaded at ``sample_rate``, converted to an STFT magnitude and
    phase, the magnitude is passed through ``model``, and the prediction is
    recombined with the *mixture's* phase and inverted back to a waveform.

    NOTE(review): ``model`` must accept the spectrogram shape produced for
    this file; a model built for a fixed number of frames will reject clips
    of a different length.

    Args:
        test_file_path: Path of the mixture audio file to separate.
        model: Object with a Keras-style ``predict`` that maps a
            ``(1, freq, frames, 1)`` array to an array of the same shape.
        n_fft: FFT window size for the forward STFT.
        hop_length: STFT hop, used for both the forward and inverse transform.
        output_path: Destination WAV file.
        sample_rate: Rate in Hz the input is resampled to and the output is
            written at.

    Returns:
        None. The result is written to ``output_path``.
    """
    # librosa.output.write_wav no longer exists in current librosa releases;
    # SciPy's WAV writer is used instead (float32 data -> 32-bit float WAV).
    from scipy.io import wavfile

    # Let librosa.load do the resampling itself; this avoids the
    # librosa.resample call whose rate arguments are now keyword-only.
    test_mixture, _ = librosa.load(test_file_path, sr=sample_rate)

    # audio_to_spectrogram zero-pads clips shorter than n_fft on its own.
    test_mixture_mag, test_mixture_phase = audio_to_spectrogram(test_mixture, n_fft, hop_length)
    model_input = np.expand_dims(test_mixture_mag, axis=(0, -1))

    predicted_vocal_mag = model.predict(model_input)
    # Remove only the batch and channel axes; a bare squeeze() would also
    # drop a length-1 frame axis.
    predicted_vocal_mag = np.squeeze(predicted_vocal_mag, axis=(0, -1))

    # Forward the caller's hop_length so analysis and synthesis agree.
    predicted_vocal_audio = spectrogram_to_audio(
        predicted_vocal_mag, test_mixture_phase, hop_length=hop_length
    )

    wavfile.write(output_path, sample_rate,
                  np.asarray(predicted_vocal_audio, dtype=np.float32))

# Separate one file with the trained model; point the path at a real mixture.
predict_and_save_audio(test_file_path='path_to_test_mixture.wav', model=model)


In [None]:
# Training vs. validation loss per epoch, taken from the Keras History object.
for history_key, curve_label in (('loss', 'Training Loss'),
                                 ('val_loss', 'Validation Loss')):
    plt.plot(history.history[history_key], label=curve_label)

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
