# Training Lab41's monaural source separation model

This notebook contains a detailed example of how to train Lab41's source separation model.  Filepaths to load training data must be filled in to run this notebook.

In [1]:
# Generic imports
import sys
import time

import numpy as np
import tensorflow as tf

# Plotting imports
import IPython
from IPython.display import Audio
from matplotlib import pyplot as plt
fig_size = [0,0]
fig_size[0] = 8
fig_size[1] = 4
plt.rcParams["figure.figsize"] = fig_size

# Import Lab41's separation model
from magnolia.dnnseparate.L41model import L41Model

# Import utilities for using the model
#from magnolia.utils.clustering_utils import clustering_separate, get_cluster_masks, process_signal
#from magnolia.iterate.supervised_iterator import SupervisedIterator, SupervisedMixer
#from magnolia.iterate.hdf5_iterator import SplitsIterator
from magnolia.iterate.mix_iterator import MixIterator

### Hyperparameters

* **batchsize**      : Number of examples per batch used in training
* **train_mixes**    : List of mix configuration settings used for training (total number of signals and noises must be two in all mixes)
* **validate_mixes** : List of mix configuration settings used for validation (total number of signals and noises must be two in all mixes)
* **model_save_path**: Directory to store the saved models

In [2]:
batchsize = 256
train_mixes = ['/local_data/magnolia/pipeline_data/date_2017_09_27_time_13_25/settings/mixing_LibriSpeech_UrbanSound8K_train.json']
train_from_disk = True
validate_mixes = ['/local_data/magnolia/pipeline_data/date_2017_09_27_time_13_25/settings/mixing_LibriSpeech_UrbanSound8K_validate.json']
validate_from_disk = True
model_save_path = '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/model_saves'

### Create an mixer that iterates over examples from the training and validation sets. 

MixIterator takes a list a `mix_data.py` setting JSON file names and constructs an iterator that will loop over all the mixes samples in the specified list.

Here, we'll create two separate iterators, one for the training mixes and another for the validation mixes.

The `MixIterator` class will automatically reset whenever it's reached the end of it's epoch.

In [7]:
training_mixer = MixIterator(mixes_settings_filenames=train_mixes,
                             batch_size=batchsize,
                             from_disk=train_from_disk)

validation_mixer = MixIterator(mixes_settings_filenames=validate_mixes,
                               batch_size=batchsize,
                               from_disk=validate_from_disk)

batch = next(training_mixer)
print(batch.shape)

AttributeError: 'tuple' object has no attribute 'shape'

### Create an instance of Lab41's model

Here an untrained model instance is created, and its variables are initialized

In [None]:
model = L41Model(nonlinearity='tanh', normalize='False', device='/gpu:0')
model.initialize()

### Variables needed to track the training progress of the model

During training, the number of iterations (number of processed batches) is tracked, along with the mean cost on examples from the training data and from the validation data.  The last iteration that the model was saved on can also be tracked.

In [None]:
nbatches = []
costs = []

t_costs = []
v_costs = []

last_saved = 0

### Training loop

Here the model is iteratively trained on batches generated by the mixer.  The model is saved every time the validation cost reaches a new minimum value.  The training can be configured to stop if the model has not been saved after a specified number of iterations have elapsed since the previous save.  Plots of the training cost and the validation set are created as well.

In [None]:
# Number of epochs
num_epochs = 2
# Threshold for stopping if the model hasn't improved for this many consecutive batches
stop_threshold = 10000

# Find the number of batches already elapsed (Useful for resuming training)
start = 0
if len(nbatches) != 0:
    start = nbatches[-1]

batch_count = 0
# Total training epoch loop
for epoch_num in range(num_epochs):
    
    # Training epoch loop
    for batch in iter(training_mixer):
        # dimensions of (batch size, time frame, frequency)
        spectral_sum_batch = batch[0].transpose(0, 2, 1)
        # dimensions of (batch size, time frame, frequency, source)
        spectral_masks_batch = batch[1].transpose(0, 3, 2, 1)
        # dimensions of (batch size, source)
        uids_batch = batch[3]
        
        # scale spectral inputs
        spectral_sum_batch = np.sqrt(np.abs(spectral_sum_batch))
        spectral_sum_batch = (spectral_sum_batch - spectral_sum_batch.min())/(spectral_sum_batch.max() - spectral_sum_batch.min())
        # convert and scale {-1.0, 1.0} spectral masks
        spectral_masks_batch = 2.0*spectral_masks_batch.astype(float) - 1.0
        
        # Train the model on one batch and get the cost
        c = model.train_on_batch(spectral_sum_batch, spectral_masks_batch, uids_batch)
        
        # Store the training cost
        costs.append(c)
        
        # Store the current batch_count number
        
        # Every 10 batches, evaluate the model on the validation data and plot the cost curves
        if (batch_count + 1) % 10 == 0:
            IPython.display.clear_output(wait=True)
            
            # Store the training cost
            t_costs.append(np.mean(costs))
            # Reset the cost over the last 10 batches
            costs = []
            
            # Compute average validation score
            all_c_v = []
            for vbatch in iter(validation_mixer):
                # dimensions of (batch size, time frame, frequency)
                spectral_sum_batch = vbatch[0].transpose(0, 2, 1)
                # dimensions of (batch size, time frame, frequency, source)
                spectral_masks_batch = vbatch[1].transpose(0, 3, 2, 1)
                # dimensions of (batch size, source)
                uids_batch = vbatch[3]

                # scale spectral inputs
                spectral_sum_batch = np.sqrt(np.abs(spectral_sum_batch))
                spectral_sum_batch = (spectral_sum_batch - spectral_sum_batch.min())/(spectral_sum_batch.max() - spectral_sum_batch.min())
                # convert and scale {-1.0, 1.0} spectral masks
                spectral_masks_batch = 2.0*spectral_masks_batch.astype(float) - 1.0

                # Get the cost on the validation batch
                c_v = model.get_cost(spectral_sum_batch, spectral_masks_batch, uids_batch)
                all_c_v.append(c_v)
            
            ave_c_v = np.mean(all_c_v)
            
            # Check if the validation cost is below the minimum validation cost, and if so, save it.
            if len(v_costs) > 0 and ave_c_v < min(v_costs) and len(nbatches) > 0:
                print("Saving the model because validation score is ", min(v_costs) - ave_c_v, " below the old minimum.")
                
                # Save the model to the specified path
                model.save(model_save_path)
                
                # Record the batch that the model was last saved on
                last_saved = nbatches[-1]
            
            # Store the validation cost
            v_costs.append(ave_c_v)
            
            # Store the current batch number
            nbatches.append(batch_count + 1 + start)
            
            # Compute scale quantities for plotting
            length = len(nbatches)
            cutoff = int(0.5*length)
            lowline = [min(v_costs)]*length
        
            # Generate the plots and show them
            f, (ax1, ax2) = plt.subplots(2,1)
        
            ax1.plot(nbatches, t_costs)
            ax1.plot(nbatches, v_costs)
            ax1.plot(nbatches, lowline)
        
            y_u = max(max(t_costs[cutoff:]), max(v_costs[cutoff:]))
            y_l = min(min(t_costs[cutoff:]), min(v_costs[cutoff:]))
        
            ax2.set_ylim(y_l, y_u)
        
            ax2.plot(nbatches[cutoff:], t_costs[cutoff:])
            ax2.plot(nbatches[cutoff:], v_costs[cutoff:])
            ax2.plot(nbatches[cutoff:], lowline[cutoff:])
            plt.show()
        
            print("Cost on batch ", nbatches[-1], " is ", ave_c_v, ".")
            print("Last saved ",nbatches[-1] - last_saved," batches ago.")
        
            # Stop training if the number of iterations since the last save point exceeds the threshold
            if nbatches[-1] - last_saved > stop_threshold:
                print("Done!")
                break
        
        batch_count += 1