In [None]:
# requirements
# %pip (unlike !pip) installs into the environment of the interpreter running this
# kernel, so the packages are importable without restarting into another Python.
# TODO: pin exact versions (pkg==x.y.z) once a known-good set is recorded.
%pip install -q awscli einops museval fast_bss_eval torch_audiomentations

Notes:

Upload all of the project's `.py` files (at least `model.py`, `trainer.py` and `eval.py`, which are imported below) before proceeding.

To evaluate on the entire dataset, run the cells below that download and unzip `musdb18hq.zip` (about 27 GB).

Load the model variant(s) you want to test with `load_model` and the corresponding checkpoint files.

All the necessary code is provided below.

To test the example track shown in the paper, download `mixture.wav` and `vocals.wav` from the project repository and place them in `/content/test/Mu - Too Bright/` (the paths used below).

In [None]:
# import needed modules and functions
import torch
import torchaudio
from IPython.display import Audio
import tqdm.notebook as tq
from model import Model, ChromaModel
from trainer import load_model
from trainer import hparams_def, hparams_chroma
from eval import eval_dir, one_song_from_filepath
import numpy as np
from google.colab import drive, files

In [None]:
# Unzip a manually uploaded copy of the dataset, if there is one (the S3 download cell
# below fetches and unzips its own copy). -n never overwrites existing files, so
# re-running does not stall on unzip's interactive "replace?" prompt.
!if [ -f musdb18hq.zip ]; then unzip -n -q musdb18hq.zip; else echo "musdb18hq.zip not found - skipping"; fi

In [None]:
# Provide AWS credentials without writing them into the notebook: the aws CLI started
# by `!` inherits these environment variables from the kernel process.
# (The previous `aws configure set aws_access_key_id` lines had no value argument.)
import os
from getpass import getpass

os.environ['AWS_ACCESS_KEY_ID'] = getpass('AWS access key id: ')
os.environ['AWS_SECRET_ACCESS_KEY'] = getpass('AWS secret access key: ')
os.environ['AWS_DEFAULT_REGION'] = 'us-east-2'
!aws s3 cp s3://mymusicdatasets/chroma_attention_checkpoint_latest.pt ./
!aws s3 cp s3://mymusicdatasets/chroma_fc_group_checkpoint_latest.pt ./
!aws s3 cp s3://mymusicdatasets/checkpt_latest.pt ./


In [None]:
# Download and unzip the full dataset. !!!!! 27GB !!!!! - the copy might not show progress.
# The download is skipped when musdb18hq.zip already exists, and unzip -n never overwrites,
# so re-running this cell neither re-downloads nor blocks on unzip's "replace?" prompt.
!if [ ! -f musdb18hq.zip ]; then aws s3 cp s3://mymusicdatasets/musdb18hq.zip ./; fi
!unzip -n -q musdb18hq.zip

In [None]:
# Baseline BSRNN: build Model from hparams_def, switch it to eval mode, then load the
# weights from 'checkpt_latest.pt' (load_model comes from trainer.py; presumably it
# loads the checkpoint into `model` in place - confirm there).
# NOTE(author): consider loading the BSRNN and the chroma models in separate notebooks;
# their source file names are the same and Colab does weird things with imports.
model = Model(hparams_def).eval()
load_model('checkpt_latest.pt', model)

In [None]:
# Chroma BSRNN variant: ChromaModel built from hparams_chroma in its 'attention' mode,
# switched to eval mode, with weights from chroma_attention_checkpoint_latest.pt
# (presumably loaded in place by trainer.load_model - confirm there).
chroma_attention = ChromaModel(hparams_chroma, 'attention').eval()
load_model('chroma_attention_checkpoint_latest.pt', chroma_attention)

In [None]:
# Chroma BSRNN variant: ChromaModel built from hparams_chroma in its 'group_fc' mode,
# switched to eval mode, with weights from chroma_fc_group_checkpoint_latest.pt
# (presumably loaded in place by trainer.load_model - confirm there).
chroma_fc = ChromaModel(hparams_chroma, 'group_fc').eval()
load_model('chroma_fc_group_checkpoint_latest.pt', chroma_fc)

In [None]:
# Excerpt of the example track to process: start offset and duration, in seconds.
start_in_seconds, length_in_seconds = 45, 45

# Example track from the paper: the mixture and its ground-truth vocal stem.
TRACK_DIR = '/content/test/Mu - Too Bright'
filepath = f'{TRACK_DIR}/mixture.wav'
source_path = f'{TRACK_DIR}/vocals.wav'

In [None]:
import matplotlib.pyplot as plt
import librosa
def plot_spectrogram(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None):
    """Show ``spec`` as an image on a decibel scale, with a colour bar.

    ``spec`` is treated as a power spectrogram. It is converted with
    ``librosa.power_to_db`` and drawn with frequency bins on the y axis (origin at
    the bottom) and frames on the x axis.

    Args:
        spec: power spectrogram, assumed 2-D (bins x frames).
        title: axes title; falls back to 'Spectrogram (db)' when falsy.
        ylabel: y-axis label.
        aspect: forwarded to ``imshow``.
        xmax: when truthy, the x axis is limited to ``(0, xmax)``.
    """
    figure, ax = plt.subplots(1, 1)
    ax.set_title(title if title else 'Spectrogram (db)')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('frame')
    spec_db = librosa.power_to_db(spec)
    image = ax.imshow(spec_db, origin='lower', aspect=aspect)
    if xmax:
        ax.set_xlim((0, xmax))
    figure.colorbar(image, ax=ax)
    plt.show(block=False)

In [None]:
import torch
import torchaudio
import numpy as np
import os
import fast_bss_eval
import museval
import tqdm.notebook as tq
def zero_pad(signal, segment_samples = 44100 * 6, overlap = 0.5):
    """Right-pad ``signal`` with zeros so that segmenting it covers every sample.

    The padded length is ``segment_samples + k * hop`` for the smallest ``k >= 0``
    that reaches the original length. Here ``hop = int(segment_samples * (1 - overlap))``,
    the same hop that ``split_to_segments`` uses.

    With the default 50% overlap and an even ``segment_samples``, this is simply the
    next multiple of the hop. Signals shorter than one segment are padded to a full
    segment, so at least one segment can always be extracted.

    Args:
        signal: tensor of shape (channels, samples).
        segment_samples: segment length in samples.
        overlap: fraction of overlap between consecutive segments, in [0, 1).

    Returns:
        Tensor of shape (channels, padded_samples) with the same dtype and device as
        ``signal``. ``signal`` itself is returned when no padding is needed.

    Raises:
        ValueError: if ``overlap`` gives a non-positive hop length.
    """
    hop_length = int(segment_samples * (1 - overlap))
    if hop_length <= 0:
        raise ValueError(f"overlap={overlap} gives a non-positive hop for segment_samples={segment_samples}")
    num_samples = signal.shape[1]
    # Samples beyond the first segment must be reached by whole hops.
    extra = max(num_samples - segment_samples, 0)
    num_hops = -(-extra // hop_length)  # ceiling division
    target_samples = segment_samples + num_hops * hop_length
    if target_samples == num_samples:
        return signal
    padding = signal.new_zeros(signal.shape[0], target_samples - num_samples)
    return torch.cat([signal, padding], dim = 1)

def split_to_segments(signal, segment_samples = 44100 * 6, overlap = 0.5):
    """Cut ``signal`` into overlapping fixed-length segments.

    Segments start at ``0, hop, 2*hop, ...`` where ``hop = int(segment_samples * (1 - overlap))``.
    Only segments that fit entirely inside the signal are kept, so pad first (see
    ``zero_pad``) if the tail must be covered.

    Args:
        signal: tensor of shape (channels, samples).
        segment_samples: segment length in samples.
        overlap: fraction of overlap between consecutive segments, in [0, 1).

    Returns:
        Tensor of shape (channels, num_segments, segment_samples).

    Raises:
        ValueError: if the hop length is not positive (the old loop never terminated),
            or if the signal is shorter than one segment (nothing to stack).
    """
    hop_length = int(segment_samples * (1 - overlap))
    if hop_length <= 0:
        raise ValueError(f"overlap={overlap} gives a non-positive hop for segment_samples={segment_samples}")
    num_samples = signal.shape[1]
    if num_samples < segment_samples:
        raise ValueError(f"signal has {num_samples} samples, fewer than one segment ({segment_samples}); pad it first")
    starts = range(0, num_samples - segment_samples + 1, hop_length)
    return torch.stack([signal[:, start:start + segment_samples] for start in starts], dim = 1)

def predict_and_overlap(model, full_signal, segment_samples = 6 * 44100, overlap = 0.5, n_fft = 2048, hop_length = 512, win_length = 2048, show_progress = True):
    """Separate a full-length signal by running ``model`` on overlapping segments.

    Processing steps:
    1. Zero-pad the signal and cut it into ``segment_samples``-long segments with the
       given ``overlap``.
    2. For each segment, multiply the model's mask by its input spectrogram and
       invert the result with ``torch.istft``.
    3. Overlap-add the per-segment predictions, then divide by the number of segments
       covering each sample. This keeps the level of overlapped regions equal to that
       of non-overlapped ones; plain summation doubled their amplitude at 50% overlap.

    Args:
        model: separation model, called as ``model(x)`` with ``x`` of shape
            (1, channels, segment_samples). It must return ``(mask, input_spectrogram)``,
            whose product is indexed below as [batch, ..., (real, imag)].
        full_signal: tensor of shape (channels, samples).
        segment_samples: segment length in samples, used for padding and splitting.
        overlap: overlap fraction between segments, used for padding and splitting.
        n_fft, hop_length, win_length: ``torch.istft`` parameters.
            NOTE(review): these must match the model's internal STFT. ``istft`` uses
            its default (rectangular) window here - confirm against the model.
        show_progress: show a tqdm progress bar over the segments.

    Returns:
        Tuple ``(output, padded_signal)``: the predicted source and the zero-padded
        input. Both have shape (channels, padded_samples) and live on the model's device.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.eval().to(device)
    # Pad and split with the caller's segment length / overlap (previously the helpers'
    # defaults were used regardless of these arguments).
    full_signal = zero_pad(full_signal.to(device), segment_samples)
    segments = split_to_segments(full_signal, segment_samples, overlap)
    num_segments = segments.shape[1]
    hop_samples = int(segment_samples * (1 - overlap))
    output = torch.zeros_like(full_signal)
    # Number of segment predictions accumulated into each sample.
    coverage = torch.zeros(full_signal.shape[1], dtype = output.dtype, device = output.device)
    with torch.no_grad():
        with tq.trange(num_segments, desc = "Segment ", disable = not show_progress) as segments_tq:
            for num_segment in segments_tq:
                start = num_segment * hop_samples
                end = start + segment_samples
                input_segment = segments[:, num_segment, :].unsqueeze(0)
                mask_prediction, input_spectrogram = model(input_segment)
                stft_prediction = mask_prediction * input_spectrogram
                # Drop the batch axis; the last axis holds (real, imag).
                stft_prediction = torch.complex(stft_prediction[0, :, :, :, 0], stft_prediction[0, :, :, :, 1])
                signal_prediction = torch.istft(stft_prediction, n_fft = n_fft, hop_length = hop_length, win_length = win_length, length = segment_samples)
                output[:, start:end] += signal_prediction
                coverage[start:end] += 1
    # Average overlapping predictions. clamp keeps samples that no segment reached at
    # zero instead of dividing by zero.
    output = output / coverage.clamp(min = 1)
    return output, full_signal

def _chunked_csdr(reference, estimate, chunk_samples):
    """Return the median fast_bss_eval SDR over non-overlapping chunks of ``chunk_samples``.

    ``reference`` and ``estimate`` are (channels, samples) tensors of equal length.
    Trailing samples that do not fill a whole chunk are dropped.
    """
    num_channels = reference.shape[0]
    usable = reference.shape[1] // chunk_samples * chunk_samples
    # (channels, usable) -> (num_chunks, channels, chunk_samples)
    ref_chunks = reference[:, :usable].reshape(num_channels, -1, chunk_samples).transpose(0, 1)
    est_chunks = estimate[:, :usable].reshape(num_channels, -1, chunk_samples).transpose(0, 1)
    sdr = fast_bss_eval.sdr(ref_chunks, est_chunks, use_cg_iter = 20, clamp_db = 30, load_diag = 1e-5)
    return torch.nanmedian(sdr).item()


def eval_dir(model, dir_path = 'test/', out_path = 'outputs/', full_test_mode = False):
    """Separate vocals for every track folder in ``dir_path`` and score each one.

    Each track folder must contain ``mixture.wav`` and ``vocals.wav`` sampled at
    44.1 kHz; entries without a ``mixture.wav`` are skipped. The prediction for each
    track is written to ``<out_path>/<track>_vocal.wav``.

    Args:
        model: separation model passed to ``predict_and_overlap``.
        dir_path: directory whose entries are track folders.
        out_path: output directory; created if it does not exist.
        full_test_mode: if True, score with ``museval.evaluate`` (mean over channels
            of the median frame SDR). Otherwise use the median ``fast_bss_eval.sdr``
            over 1-second chunks (previously this path raised NameError).

    Returns:
        List of per-track cSDR values, in sorted track-name order.
    """
    os.makedirs(out_path, exist_ok = True)
    cSDRs = []
    for filename in sorted(os.listdir(dir_path)):
        track_dir = os.path.join(dir_path, filename)
        mixture_path = os.path.join(track_dir, 'mixture.wav')
        if not os.path.isfile(mixture_path):
            continue
        print("processing " + filename)
        mixture, sr1 = torchaudio.load(mixture_path)
        source, sr2 = torchaudio.load(os.path.join(track_dir, 'vocals.wav'))
        assert(sr1 == sr2 and sr1 == 44100)
        # predict_and_overlap returns (estimate, padded_mixture). Trim the zero padding
        # so the estimate lines up sample-for-sample with the reference.
        estimate, _ = predict_and_overlap(model, mixture)
        num_samples = min(source.shape[1], estimate.shape[1])
        estimate = estimate[:, :num_samples].cpu()
        source = source[:, :num_samples]
        torchaudio.save(os.path.join(out_path, filename + '_vocal.wav'), estimate, 44100, channels_first = True,)
        if full_test_mode:
            sdr, isr, sir, sar = museval.evaluate(source.unsqueeze(0).permute(0, 2, 1).numpy(), estimate.unsqueeze(0).permute(0, 2, 1).numpy())
            cSDR = np.mean(np.nanmedian(sdr, axis = 1))
        else:
            cSDR = _chunked_csdr(source, estimate, 44100)
        print("cSDR: ", cSDR)
        cSDRs.append(cSDR)
    return cSDRs


def one_song_from_filepath(filepath, model, source_path = "", offset = 0.0, length = -1, sampling_rate = 44100, force_mono = True, eval = False, full_eval_mode = False, plot_spectrograms = False):
    """Separate vocals from one audio file and optionally score/plot against a reference.

    The prediction is saved next to the input as ``<name>_vocals_pred.wav``.

    Args:
        filepath: path of the mixture audio file.
        model: separation model passed to ``predict_and_overlap``.
        source_path: path of the reference vocals; needed for evaluation.
        offset: start of the excerpt, in seconds.
        length: duration of the excerpt in seconds, or -1 for the rest of the file.
        sampling_rate: expected sample rate of both files. Used to convert seconds to
            frames and to save the prediction.
        force_mono: average the channels before separation.
        eval: compute scores when True and ``source_path`` is non-empty.
        full_eval_mode: also compute the (slow) museval cSDR.
        plot_spectrograms: plot mixture / reference / prediction spectrograms (eval only).

    Returns:
        List of scores: the median fast_bss_eval SDR over 1-second chunks (a tensor),
        followed by the museval cSDR when ``full_eval_mode``. The list is empty when
        no evaluation is performed.

    Raises:
        ValueError: if a loaded file's sample rate differs from ``sampling_rate``.
    """
    frame_offset = int(offset * sampling_rate)
    num_frames = int(length * sampling_rate) if length != -1 else -1
    signal, sr = torchaudio.load(filepath, frame_offset = frame_offset, num_frames = num_frames)
    if sr != sampling_rate:
        raise ValueError(f"{filepath} is sampled at {sr} Hz, expected {sampling_rate} Hz")
    scores = []
    if force_mono:
        signal = torch.mean(signal, dim = 0).unsqueeze(0)
    # `signal` becomes the zero-padded mixture, the same length as `estimate`.
    estimate, signal = predict_and_overlap(model, signal)
    # splitext removes the whole extension; the old filepath[:-3] slice left a stray
    # dot ('mixture._vocals_pred.wav').
    pred_path = os.path.splitext(filepath)[0] + '_vocals_pred.wav'
    torchaudio.save(pred_path, estimate.cpu().detach(), sampling_rate, channels_first = True,)
    # Nothing to score without a reference. Previously this path fell through and returned None.
    if not eval or source_path == "":
        return scores
    source, source_sr = torchaudio.load(source_path, frame_offset = frame_offset, num_frames = num_frames)
    if source_sr != sampling_rate:
        raise ValueError(f"{source_path} is sampled at {source_sr} Hz, expected {sampling_rate} Hz")
    if force_mono:
        source = torch.mean(source, dim = 0).unsqueeze(0)
    # Pad the reference exactly like predict_and_overlap padded the mixture.
    source = zero_pad(source).to(signal.device)
    ref_chunks = torch.stack(torch.chunk(source, source.shape[1] // sampling_rate, dim = 1), dim = 0)
    est_chunks = torch.stack(torch.chunk(estimate, estimate.shape[1] // sampling_rate, dim = 1), dim = 0)
    csdr = fast_bss_eval.sdr(ref_chunks, est_chunks, use_cg_iter = 20, clamp_db = 30, load_diag = 1e-5)
    scores.append(torch.nanmedian(csdr))
    if full_eval_mode:
        sdr, isr, sir, sar = museval.evaluate(source.unsqueeze(0).permute(0, 2, 1).cpu().numpy(), estimate.unsqueeze(0).permute(0, 2, 1).cpu().numpy())
        scores.append(np.mean(np.nanmedian(sdr, axis = 1)))
    if plot_spectrograms:
        window = torch.hann_window(2048, device = signal.device)

        def _power_spectrogram(audio):
            # |STFT|^2: plot_spectrogram converts with librosa.power_to_db, which expects
            # power (passing magnitudes halved every dB value).
            stft = torch.stft(audio.squeeze(0), n_fft = 2048, hop_length = 512, window = window, return_complex = True)
            return (stft.abs() ** 2).cpu()

        plot_spectrogram(_power_spectrogram(signal), title = "Mixture spectrogram (dB)")
        plot_spectrogram(_power_spectrogram(source), title = "Source spectrogram (dB)")
        plot_spectrogram(_power_spectrogram(estimate), title = "Predicted source spectrogram (dB)")
    return scores


In [None]:
# Baseline BSRNN on the paper's example excerpt; shown as [fast SDR, museval cSDR].
bsrnn_scores = one_song_from_filepath(filepath, model, offset = start_in_seconds, source_path = source_path, length = length_in_seconds, eval = True, full_eval_mode=True, plot_spectrograms = True)
bsrnn_scores



In [None]:
# Chroma attention variant on the same excerpt; shown as [fast SDR, museval cSDR].
chroma_attention_scores = one_song_from_filepath(filepath, chroma_attention, offset = start_in_seconds, source_path = source_path, length = length_in_seconds, eval = True, full_eval_mode=True, plot_spectrograms = True)
chroma_attention_scores


In [None]:
# Chroma group-FC variant on the same excerpt; shown as [fast SDR, museval cSDR].
chroma_fc_scores = one_song_from_filepath(filepath, chroma_fc, offset = start_in_seconds, source_path = source_path, length = length_in_seconds, eval = True, full_eval_mode=True, plot_spectrograms = True)
chroma_fc_scores
