In [1]:
from scripts import *

# Training the Wave U Net

## Preparing Training DataLoader and Testing DataLoader

We point to the dataset we just made in `CreateDataset.ipynb` and create a Dataset object which, when indexed with an integer, returns a sample tuple of the form `(mixture_audio, separated_stems)`.

In [2]:
# Locations of the HDF5 files written by CreateDataset.ipynb.
data_folder = "./data"
hdf_dir_train = data_folder + "/training_data.h5"
hdf_dir_test = data_folder + "/testing_data.h5"

# One dataset object per split; indexing with an integer yields a
# (mixture, stems) sample tuple.
SSDTrain = SourceSeperationDataset(hdf_dir_train)
SSDTest = SourceSeperationDataset(hdf_dir_test)

We then wrap each dataset object in a PyTorch `DataLoader`:

In [3]:
from torch.utils.data import DataLoader

# Both splits are served as shuffled mini-batches of 16 samples.
DatasetTrainLoader = DataLoader(
    SSDTrain,
    batch_size=16,
    shuffle=True,
)
DatasetTestLoader = DataLoader(
    SSDTest,
    batch_size=16,
    shuffle=True,
)

We can now iterate through the dataloaders, which return minibatches of tensors. For example:

In [4]:
# Pull a single mini-batch and display the shapes of its (mixture, stems) tensors.
tuple(tensor.shape for tensor in next(iter(DatasetTrainLoader))[:2])

(torch.Size([16, 1, 1, 16384]), torch.Size([16, 4, 1, 16384]))

Note that the shapes of the input and output tensors follow the layout:

`(batch_size)x(instruments)x(audio_channels)x(audio_samples)`

For the input we have:  `(16)x(1)x(1)x(16384)`

For the output we have: `(16)x(4)x(1)x(16384)`

## Create the WaveUNet

We will try to run this on the GPU:

In [5]:
# Prefer Apple's Metal (MPS) GPU backend, then CUDA, and finally fall back to
# the CPU so the notebook still runs on machines without a supported GPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Let's define a WaveUNet with:

- 12 layers
- 24 additional filters per layer
- 1 input channel (because the input is a mono sound file)
- 4 output channels (because we're separating into 4 instruments)

In [6]:
# Build the network (mono mixture in, four stems out) and move its
# parameters onto the selected device.
WN_kevin = WaveUNet(
    L=12,
    Fc=24,
    in_channels=1,
    out_channels=4,
)
WN_kevin.to(device);

In [None]:
# One learning rate per epoch; the length of the table sets the number of epochs.
learning_rate_table = [1e-4, 1e-4, 1e-4] + [1e-4]*20

criterion = torch.nn.MSELoss()
model = WN_kevin  # Set the net you want to train to this model
model.to(device)

# NOTE(review): `tqdm` and `writer` are not defined in this notebook; presumably
# they are provided by `from scripts import *` -- confirm.

# Sample rate passed to every `add_audio` call.
audio_sample_rate = 44100//2

# (label, channel index) pairs used when logging the separated stems.
stem_channels = [("Vocals", -1), ("Drums", 0), ("Bass", 1), ("Other", -2)]


def _merge_instrument_and_channel_dims(batch):
    """Reshape a 4-D batch into the 3-D layout fed to the model.

    Args:
        batch: tensor of shape (batch_size, instruments, audio_channels, audio_samples).

    Returns:
        A view of shape (batch_size, instruments*audio_channels, audio_samples).
    """
    (batch_size, instruments, audio_channels, audio_samples) = batch.shape
    return batch.view((batch_size, instruments*audio_channels, audio_samples))


n = 0  # Global training-step counter used as the x-axis of the loss curve.
for epoch in tqdm(range(len(learning_rate_table)), desc=" outer"):
    learning_rate = learning_rate_table[epoch]
    # A fresh Adam optimizer is created every epoch so the epoch's learning rate
    # takes effect; note that this also discards Adam's running moment estimates.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Iterate over the *training* loader defined earlier in the notebook.
    for batch, (train_combined, train_seperated) in tqdm(enumerate(DatasetTrainLoader), total=len(DatasetTrainLoader)):
        X = _merge_instrument_and_channel_dims(train_combined).to(device)
        Y = _merge_instrument_and_channel_dims(train_seperated).to(device)

        Y_pred = model(X)
        loss = criterion(Y_pred, Y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n += 1
        # Log a plain Python float rather than a tensor attached to the autograd graph.
        writer.add_scalar("Loss/train", loss.item(), n)

    ### Produce a piece of Audio on test set
    # NOTE(review): the model is left in its current mode here; if WaveUNet
    # contains dropout/batch-norm layers, consider model.eval()/model.train().
    with torch.no_grad():
        test_combined, test_seperated = next(iter(DatasetTestLoader))
        X = _merge_instrument_and_channel_dims(test_combined).to(device)
        Y = _merge_instrument_and_channel_dims(test_seperated).to(device)
        Y_pred = model(X)

        # Log the first sample of the batch: the mixture, then predicted vs. actual stems.
        writer.add_audio("Sample Input", X[0, 0, :], sample_rate=audio_sample_rate)
        for stem_name, channel in stem_channels:
            writer.add_audio(f"Sample {stem_name}", Y_pred[0, channel, :], sample_rate=audio_sample_rate)
            writer.add_audio(f"Actual {stem_name}", Y[0, channel, :], sample_rate=audio_sample_rate)

        # Checkpoint every second epoch. `state_dict` must be *called*: passing the
        # bound method would pickle the method object instead of the weights.
        if (epoch % 2) == 0:
            torch.save(model.state_dict(), f"./the_wn1_updated.model_checkpoint_epoch{epoch}")

writer.close()