In [1]:
# Project helpers (presumably SourceSeperationDataset, WaveUNet and AudioSection).
# Kept first so the explicit imports below take precedence over any same-named export.
# TODO: replace the star import with explicit names.
from scripts import *

from pathlib import Path

import librosa
import torch
from IPython.display import Audio, display
from ipywidgets import HBox, Label, VBox, widgets
from torch.utils.data import DataLoader

# Training the Wave U Net

## Preparing Training DataLoader and Testing DataLoader

We point to the datasets built in `CreateDataset.ipynb` and create a Dataset object for each split which, when indexed with an integer, returns a sample tuple of the form `(mixture_audio, separated_stems)`. 

In [2]:
# Both HDF5 files are produced by CreateDataset.ipynb and sit side by side in the data folder.
data_folder = "./data"
hdf_dir_train = data_folder + "/training_data.h5"
hdf_dir_test = data_folder + "/testing_data.h5"

# Indexing either dataset with an int yields a (mixture_audio, separated_stems) sample.
SSDTrain = SourceSeperationDataset(hdf_dir_train)
SSDTest = SourceSeperationDataset(hdf_dir_test)

We then wrap each dataset in a PyTorch `DataLoader`, which serves the samples in batches.

In [3]:
BATCH_SIZE = 16

# Reshuffle the training samples every epoch; keep the test order fixed so that
# evaluation results are reproducible from run to run.
DatasetTrainLoader = DataLoader(SSDTrain, batch_size=BATCH_SIZE, shuffle=True)
DatasetTestLoader = DataLoader(SSDTest, batch_size=BATCH_SIZE, shuffle=False)

## Create the WaveUNet

We will try to run this on the GPU, using PyTorch's `mps` (Apple Silicon) backend:

In [4]:
# Prefer the Apple-silicon GPU (MPS), then CUDA, and otherwise fall back to the CPU,
# so the notebook still runs on machines without an MPS device.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Let's define a WaveUNet with:

- 12 layers
- 24 additional filters per layer
- 1 input channel (because the input is a mono sound file)
- 4 output channels (because we're separating the mix into 4 instrument stems)

In [5]:
# 12 levels, 24 extra filters per level, mono input, one output channel per stem.
model = WaveUNet(
    L=12,
    Fc=24,
    in_channels=1,
    out_channels=4,
)
model.to(device);  # trailing ';' hides the module repr

## Load the model weights from a saved file

In [6]:
model_folder = "./models"
model_file = f"{model_folder}/WaveUNet_Full_Kevin.model"

# map_location remaps tensors saved on another device (e.g. CUDA) onto this machine's device.
# NOTE(review): on torch >= 2.6, torch.load defaults to weights_only=True, which refuses
# pickled objects other than plain tensors/containers; if this file holds a pickled
# callable, pass weights_only=False (only for files you trust).
checkpoint = torch.load(model_file, map_location=device)

# The original cell called the loaded object, which only works if the checkpoint is a
# pickled callable, such as a bound `model.state_dict` method saved without parentheses.
# Accept either that or an ordinary state dict.
state_dict = checkpoint() if callable(checkpoint) else checkpoint
model.load_state_dict(state_dict);

# Testing the model

In [88]:
# librosa.load needs an audio *file*; passing the folder itself cannot be decoded.
# Pick the first audio file (alphabetically) in the testing-songs folder.
test_song_dir = Path("./testing-songs")
AUDIO_SUFFIXES = {".wav", ".mp3", ".flac", ".ogg"}
test_song_files = sorted(
    path for path in test_song_dir.iterdir() if path.suffix.lower() in AUDIO_SUFFIXES
)
if not test_song_files:
    raise FileNotFoundError(f"No audio files ({sorted(AUDIO_SUFFIXES)}) found in {test_song_dir}")
test_song_path = test_song_files[0]  # TODO: point this at a specific song if desired

y, sr = librosa.load(str(test_song_path), sr=44100 // 2)

In [89]:
# Inference only: disable any training-time behaviour (e.g. dropout, batch-norm
# statistic updates) the model may have.
model.eval()

SECTION_SAMPLES = 2**14  # samples per chunk fed to the model
N_SECTIONS = 100         # only separate the first N_SECTIONS chunks to keep this quick

# AudioSection receives a (1, num_samples) tensor: the song as a single channel.
AS = AudioSection(torch.tensor(y).unsqueeze(0), sr)
audio_sections = AS.cut_into_sections_based_on_samples(SECTION_SAMPLES)
pieces = audio_sections[:N_SECTIONS]

with torch.no_grad():
    # Add a leading batch dimension to each chunk, run it through the model, and
    # stitch the per-chunk outputs back together along the last (time) axis.
    new_pieces = [model(p.audio.unsqueeze(0).to(device)) for p in pieces]
    new = torch.cat(new_pieces, dim=-1).to("cpu")

In [90]:
def get_audio_out(audio):
    """Render a rich-display object (e.g. an IPython ``Audio`` player) into an ``Output`` widget.

    Layout widgets such as ``HBox`` only accept widgets as children, so the player is
    displayed inside an ``Output`` widget that can sit next to a ``Label``.
    """
    out = widgets.Output()
    with out:
        display(audio)
    return out


def audio_torch(tens, rate=44100 // 2):
    """Return an ``Output`` widget holding an audio player for the tensor ``tens``.

    ``tens`` is moved to the CPU and converted to NumPy before playback. ``rate`` is the
    playback sample rate in Hz (default 22050, the rate used throughout this notebook).
    """
    return get_audio_out(Audio(tens.cpu().numpy(), rate=rate))


# Output channel order of the model: 0 = drums, 1 = bass, 2 = other, 3 = vocals.
STEM_NAMES = ["Drums", "Bass", "Other", "Vocals"]

# Play back exactly the span of input audio that was fed to the model
# (assumes each piece's audio is laid out as (channels, samples)).
n_input_samples = sum(p.audio.shape[-1] for p in pieces)

rows = [HBox([Label("Original: "), audio_torch(AS.audio[0, :n_input_samples], rate=sr)])]
rows += [
    HBox([Label(f"{name}: "), audio_torch(new[0, channel, :], rate=sr)])
    for channel, name in enumerate(STEM_NAMES)
]
display(VBox(rows))

VBox(children=(HBox(children=(Label(value='Original: '), Output())), HBox(children=(Label(value='Drums: '), Ou…