In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell executes, so edits to the local project modules take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Configure the root logger: INFO threshold, records rendered as "LEVEL:message".
import logging

logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(message)s")

In [None]:
# Set cwd to project root directory, so that waveUNet imports can be resolved.
# The notebook's starting directory is captured only once, so re-running this
# cell resolves to the same root instead of climbing one directory higher on
# every execution.
import os

if "NOTEBOOK_DIR" not in globals():
    # `__file__` is not defined inside a notebook, so the kernel's initial
    # working directory is used instead.
    # NOTE(review): assumes the kernel starts in the folder that holds this
    # notebook, one level below the project root — confirm for your setup.
    NOTEBOOK_DIR = os.path.realpath(os.getcwd())
project_root = os.path.dirname(NOTEBOOK_DIR)
os.chdir(project_root)
os.getcwd()

In [None]:
import musdb_loader
import viz
import dataset
import torch

In [None]:
# Load the MUSDB folds from 'data/musdb' (a relative path, so it is resolved
# against the current working directory). The result is indexed by 'train' and
# 'test' in the cells below.
# NOTE(review): the meaning of version=None is defined in musdb_loader — confirm there.
data = musdb_loader.get_musdb_folds('data/musdb', version=None)

In [None]:
# Wrap each fold in a SeparationDataset (no transforms): train first, then test.
train, test = (
    dataset.SeparationDataset(data[fold_name])
    for fold_name in ('train', 'test')
)

In [None]:
# Pick one training item (index 9); it unpacks into the input audio and a
# collection of target audios that later cells index by source name.
train_audio, target_audios = train[9]

In [None]:
# Display the raw objects for the selected item as the cell output.
train_audio, target_audios

In [None]:
# Listen to the input audio of the selected training item.
# NOTE(review): 22050 is presumably the sample rate the dataset delivers — confirm.
viz.play_audio(train_audio, sample_rate=22050)

In [None]:
# Listen to the 'vocals' target of the same item.
viz.play_audio(target_audios['vocals'], sample_rate=22050)

In [None]:
# Spectrogram of the input audio.
viz.plot_specgram(train_audio, sample_rate=22050)

In [None]:
# Spectrogram of the 'vocals' target.
viz.plot_specgram(target_audios['vocals'], sample_rate=22050)

In [None]:
# Spectrogram of the 'bass' target.
viz.plot_specgram(target_audios['bass'], sample_rate=22050)

In [None]:
# Time-domain waveform of the input audio.
viz.plot_waveform(train_audio, sample_rate=22050)

In [None]:
# Sweep plot of the input audio.
# NOTE(review): the positional arguments are presumably (audio, sample_rate, title) — confirm against viz.plot_sweep.
viz.plot_sweep(train_audio, 22050, 'sweep')

In [None]:
# Sweep plot of the 'bass' target.
# NOTE(review): the positional arguments are presumably (audio, sample_rate, title) — confirm against viz.plot_sweep.
viz.plot_sweep(target_audios['bass'], 22050, 'sweep')

### Adding a transformation used in the paper

In [None]:
import numpy as np
def random_amplify(mix, targets, min, max):
    '''
    Data augmentation by randomly amplifying sources before adding them to form a new mixture.

    The inputs are left untouched: neither ``mix`` nor the ``targets`` dict is
    modified. A new mixture and a new dict are returned instead.

    :param mix: Original mixture
    :param targets: Source targets, a dict keyed by source name. A "mix" entry, if
        present, is not treated as a source and is passed through unchanged.
    :param min: Minimum possible amplification (name shadows the builtin; kept for callers)
    :param max: Maximum possible amplification (name shadows the builtin; kept for callers)
    :return: New data point as tuple (mix, targets)
    '''
    # A "mix" entry is not an instrument; skip it both when computing the
    # residual and when re-mixing, so the two passes agree.
    source_keys = [key for key in targets.keys() if key != "mix"]

    # Part of the mixture not explained by the listed sources (zero if all
    # instruments add up to the mix). Computed out-of-place so the caller's
    # `mix` object is never overwritten.
    residual = mix
    for key in source_keys:
        residual = residual - targets[key]

    # The residual receives its own random gain (drawn before the per-source gains).
    new_mix = residual * np.random.uniform(min, max)

    # Shallow copy: entries are rebound below, the caller's dict stays as it was.
    new_targets = dict(targets)
    for key in source_keys:
        new_targets[key] = targets[key] * np.random.uniform(min, max)
        new_mix = new_mix + new_targets[key]  # add instrument with gain data augmentation to mix

    # TODO: clip the new mixture to [-1.0, 1.0] (was: torch.clip(mix, -1.0, 1.0), disabled for now).
    return new_mix, new_targets


In [None]:
from functools import partial

# Bind the gain range up front so the transform can be called with just
# (mix, targets).
# NOTE(review): presumably that is the call shape SeparationDataset uses for
# its transforms — confirm in dataset.py.
# NOTE(review): max=10.0 allows up to a 10x gain — confirm this is intentional.
amplify_function = partial(random_amplify, min=0.7, max=10.0)
transforms = [amplify_function]
transformed_train = dataset.SeparationDataset(data['train'], transforms=transforms)
transformed_test = dataset.SeparationDataset(data['test'], transforms=transforms)
# Keep the original (misspelled) name bound so any cell still using it keeps working.
trasnformed_test = transformed_test

In [None]:
# Fetch the same index (9) from the dataset built with the amplification
# transform, for a side-by-side comparison with the untransformed item.
transformed_train_audio, transformed_target_audios = transformed_train[9]

In [None]:
# Spectrogram of the untransformed input audio (reference for the next plot).
viz.plot_specgram(train_audio, sample_rate=22050)

In [None]:
# Spectrogram of the input audio from the transformed dataset.
viz.plot_specgram(transformed_train_audio, sample_rate=22050)

In [None]:
# Listen to the untransformed input audio (reference for the next player).
viz.play_audio(train_audio, sample_rate=22050)

In [None]:
# Listen to the input audio from the transformed dataset.
viz.play_audio(transformed_train_audio, sample_rate=22050)