# Using a trained model to separate audio files

In order to use a trained model to run separation on an audio file, we need to:
- load the model from checkpoint
- apply the same audio processing that was used to transform the audio into features during training

The AudioSeparator class loads the model from a checkpoint and instantiates the validation set of the dataset that was used to train the model. The validation set implements the audio processing performed during training, so we can reuse it for our purpose.  
*Note*: instantiating the AudioSeparator class creates the validation set, which can be slow if creating it requires a lot of work (e.g. loading many files into RAM). This is acceptable for most datasets, but it should be avoided in practical applications.

In [None]:
# Some imports that we will need
import librosa  # for audio saving
import torch
import numpy as np

from separator import AudioSeparator

The AudioSeparator class needs 2 parameters: the checkpoint of the model to load, and the path to a folder to store the separated audio. We won't need to use the folder, so we can pass any string for this argument.

In [None]:
separated_audio_folder = ""  # anything will do
# Path to the trained model checkpoint
model_ckpt = 'path_to_model_checkpoint.ckpt'

In [None]:
# Instantiate the AudioSeparator
separator = AudioSeparator.from_checkpoint({"checkpoint_path": model_ckpt, "separated_audio_folder": separated_audio_folder})

Now load the audio that we want to separate:

In [None]:
# Load the audio in a similar way as during training
audio = separator.data_set.load_audio("path_to_wav_to_separate.wav")

Compute the audio features in the same way as during training:

In [None]:
# Compute short-time Fourier transform
magnitude, phase = separator.data_set.separated_stft(audio)

# Go from magnitude spectrogram to actual features used during training
features = separator.data_set.stft_magnitude_to_features(magnitude=magnitude)
features = torch.tensor(features).unsqueeze(0)  # convert to torch tensor and add channel dimension
# Scale the features as done during the training
if separator.data_set.config['scaling_type'].lower() != "none":
    features = separator.data_set.shift_and_scale_features(features,
                                                           separator.data_set.config['shift'],
                                                           separator.data_set.config['scaling'])

(Most) models can only process input features of a fixed shape, so the features need to be split into chunks of the right shape. The frequency and channel dimensions are determined by the processing, so we only need to split along the time dimension.

In [None]:
features_shape = separator.data_set.features_shape()  # (channel, frequency, time)
# Make chunks along the time dimension, and stack them in a newly created batch dimension
# Note: a last chunk that is shorter than required is discarded (equivalent to truncating the input audio)
# Shape of batch: [n_chunks, channel, frequency, time]
batch = torch.stack([features[..., i*features_shape[-1]:(i+1)*features_shape[-1]] 
                     for i in range(features.shape[-1]//features_shape[-1])], 0)
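
To make the truncation behaviour concrete, here is a minimal sketch of the same chunking logic on dummy data. It uses NumPy arrays in place of torch tensors, and `split_into_chunks`, `dummy_features` and `chunk_len` are illustrative names that are not part of the AudioSeparator code.

```python
import numpy as np


def split_into_chunks(features, chunk_len):
    """Split (channel, frequency, time) features into an array of shape
    (n_chunks, channel, frequency, chunk_len)."""
    n_chunks = features.shape[-1] // chunk_len  # an incomplete last chunk is dropped
    if n_chunks == 0:
        raise ValueError("Input is shorter than one chunk")
    return np.stack([features[..., i * chunk_len:(i + 1) * chunk_len]
                     for i in range(n_chunks)], axis=0)


# Hypothetical features: 1 channel, 4 frequency bins, 10 time frames
dummy_features = np.arange(1 * 4 * 10, dtype=np.float32).reshape(1, 4, 10)
chunks = split_into_chunks(dummy_features, chunk_len=3)
print(chunks.shape)  # (3, 1, 4, 3): 10 frames give 3 chunks of 3 frames, the last frame is dropped
```

The explicit length check is a design choice of this sketch: in the cell above, a recording shorter than one chunk would produce an empty list, and `torch.stack` raises an error on an empty list.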

In [None]:
_, masks = separator.model(batch)  # The predicted labels are not needed for separation

Shape of masks: (n_chunks, n_classes, frequency, time).  
To separate a specific class, we need to know which mask to select. The classes used in training are listed in separator.data_set.classes:

In [None]:
# Print each class name with its index
print('\n'.join("%s: %s" % (class_name, idx)
                for idx, class_name in enumerate(separator.data_set.classes)))

In [None]:
class_idx = 9  # Set this to the index of the class you are interested in

In [None]:
# Example to plot the masks and spectrograms

# import matplotlib.pyplot as plt

# chunk_idx = 4

# fig, axs = plt.subplots(1, 2, figsize=(10, 5))
# h0 = axs[0].imshow(masks[chunk_idx][class_idx].detach(), aspect='auto', origin='lower')
# h1 = axs[1].imshow(batch[chunk_idx].detach().squeeze(), aspect='auto', origin='lower')
# plt.tight_layout()
# plt.show()

Get the separated spectrograms for all the sources in the data set:

In [None]:
spectrograms = [separator.separate_spectrogram_in_lin_scale(masks[i].detach(),
                                                            features_shape,
                                                            magnitude[..., i*features_shape[-1]:(i+1)*features_shape[-1]])
                for i in range(batch.shape[0])]
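
As a rough illustration of what mask-based separation does to one chunk, the sketch below multiplies a linear-scale mixture magnitude by one soft mask per class. This is an assumption made for illustration only: the actual `separate_spectrogram_in_lin_scale` method may perform additional steps (for example undoing the feature scaling), and all array names and sizes here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes, n_freq, n_frames = 3, 5, 4

# Hypothetical linear-scale magnitude spectrogram of one mixture chunk
mixture_magnitude = rng.random((n_freq, n_frames))
# Hypothetical soft masks with values in [0, 1), one per class
chunk_masks = rng.random((n_classes, n_freq, n_frames))

# Broadcasting applies every class mask to the same mixture magnitude
separated = chunk_masks * mixture_magnitude[np.newaxis, ...]
print(separated.shape)  # (3, 5, 4): one separated spectrogram per class
```

With masks bounded by 1, each separated spectrogram never exceeds the mixture magnitude in any time-frequency bin.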

Now select the class we are interested in:

In [None]:
# Select the class we are interested in: 
class_spectrograms = [spec[class_idx].squeeze() for spec in spectrograms]

Put the spectrograms together to obtain a single spectrogram for the entire recording:

In [None]:
# concatenate along time dimension to produce a single spectrogram for the entire recording:
source_spectrogram = np.concatenate(class_spectrograms, axis=-1)

Synthesize the separated audio from the separated spectrogram and the mixture phase:

In [None]:
# We need to truncate the phase too.
separated_audio = separator.spectrogram_to_audio(source_spectrogram, phase[..., :source_spectrogram.shape[-1]])
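
The reason the phase has to be truncated is that magnitude and phase are recombined bin by bin, so both arrays must cover the same frames. The following sketch shows that recombination on dummy data. It is not the implementation of `spectrogram_to_audio`, whose internals are not shown in this notebook, and the array sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_freq, n_frames_full, n_frames_kept = 5, 10, 9

# Hypothetical separated magnitude, already truncated to whole chunks
magnitude = rng.random((n_freq, n_frames_kept))
# Hypothetical mixture phase, still covering the full recording
phase = rng.uniform(-np.pi, np.pi, size=(n_freq, n_frames_full))

# Truncate the phase to the same number of frames, then rebuild the complex STFT
complex_stft = magnitude * np.exp(1j * phase[..., :magnitude.shape[-1]])
print(complex_stft.shape)  # (5, 9)
```

A complex spectrogram built this way could then be passed to an inverse STFT (for instance `librosa.istft`) to obtain a waveform.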

To save the audio to file:

In [None]:
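# Note: librosa.output.write_wav was removed in librosa 0.8; with newer versions,
# soundfile.write("path_to_output.wav", separated_audio, samplerate) can be used instead
librosa.output.write_wav("path_to_output.wav", separated_audio, sr=separator.data_set.config['sampling_rate'])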

Remark: To play audio in a Jupyter notebook, one can use  
        - IPython.display.Audio  (not installed in the environment by default)  
However, the playback backend automatically normalizes the audio (so that the maximum value is 1), so the amplitude of the sound played this way is not meaningful.
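
To see why the played amplitude carries no information, consider this small sketch of peak normalization. It assumes the player divides the signal by its maximum absolute value; `peak_normalize` is an illustrative helper, not a function from IPython.

```python
import numpy as np


def peak_normalize(signal):
    """Scale a signal so that its maximum absolute value is 1."""
    peak = np.max(np.abs(signal))
    return signal if peak == 0 else signal / peak


t = np.linspace(0, 1, 8000, endpoint=False)
loud = 0.8 * np.sin(2 * np.pi * 440 * t)  # hypothetical loud source
quiet = 0.1 * loud                        # same source, ten times quieter

# After normalization the two signals are indistinguishable
print(np.allclose(peak_normalize(loud), peak_normalize(quiet)))  # True
```

To compare the loudness of separated sources, it is therefore safer to inspect the saved files or the raw arrays than to rely on notebook playback.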