# Example - Serializing An Asteroid Model for Audacity

[Asteroid](https://github.com/asteroid-team/asteroid) is a source separation library that contains recipes for training state-of-the-art source separation models on a variety of datasets. Asteroid models trace into TorchScript without issues, so all we need to provide is a wrapper that handles the I/O. The pretrained models are hosted on [HuggingFace](https://huggingface.co/models?filter=asteroid).

### Preliminaries

Install the dependencies: a local clone of Asteroid, pinned versions of `torch` and `torchaudio`, and `torchaudacity`.

In [None]:
!git clone https://github.com/asteroid-team/asteroid
!pip install ./asteroid/
!pip install "torch==1.8.1"
!pip install "torchaudio==0.8.1"  # torchaudio 0.8.1 is the release paired with torch 1.8.1
!pip install torchsummary
!rm -rf torchaudacity
!git clone https://github.com/hugofloresgarcia/torchaudacity
!pip install ./torchaudacity



In [29]:
%%capture
import os
import math
import torch
from torch import nn
from asteroid.models import ConvTasNet
import json
from pathlib import Path

# disable gradient tracking globally: we only run inference
torch.set_grad_enabled(False)

In [30]:
%%capture
# required for huggingface
!sudo apt-get install git-lfs
!git lfs install

### Let's serialize a pretrained Asteroid model!

In [31]:
# download pretrained model from Asteroid
model = ConvTasNet.from_pretrained('JorisCos/ConvTasNet_Libri2Mix_sepnoisy_16k')

In [32]:
from torchsummary import summary
# summary() prints the table itself (and returns None); the model lives on the CPU
summary(model, (1, 4800), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Encoder-1             [-1, 512, 299]               0
          Identity-2             [-1, 512, 299]               0
            GlobLN-3             [-1, 512, 299]               0
            Conv1d-4             [-1, 128, 299]          65,664
            Conv1d-5             [-1, 512, 299]          66,048
             PReLU-6             [-1, 512, 299]               1
            GlobLN-7             [-1, 512, 299]               0
            Conv1d-8             [-1, 512, 299]           2,048
             PReLU-9             [-1, 512, 299]               1
           GlobLN-10             [-1, 512, 299]               0
           Conv1d-11             [-1, 128, 299]          65,664
           Conv1d-12             [-1, 128, 299]          65,664
      Conv1DBlock-13  [[-1, 128, 299], [-1, 128, 299]]               0
               ...
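The 299 frames in the summary follow from the output-length formula for an unpadded strided 1-D convolution, `floor((L - K) / S) + 1`. The quick check below assumes an encoder kernel size `K = 32` and stride `S = 16`; these are one combination consistent with the output above, not values read from the checkpoint (the actual ones should be available from `model.get_model_args()`).

```python
def conv1d_output_frames(n_samples: int, kernel_size: int, stride: int, padding: int = 0) -> int:
    """Number of output frames of a 1-D convolution with dilation 1."""
    return (n_samples + 2 * padding - kernel_size) // stride + 1

# Assumed encoder settings (kernel_size=32, stride=16), not taken from the pretrained model
print(conv1d_output_frames(4800, kernel_size=32, stride=16))  # 299
```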

### Model Metadata



We need to create a `metadata.json` file for our model. Certain details about the model, such as its sample rate, effect type (e.g. waveform-to-waveform or waveform-to-labels), and list of labels, must be provided in this separate file. See the [contributing documentation](https://interactiveaudiolab.github.io/audacity/contrib) for the full metadata schema.
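In this notebook the file is written by torchaudacity's `save_model`, but the standard library is enough to illustrate what ends up on disk. The sketch below is hypothetical: `write_metadata` is not part of torchaudacity, and `REQUIRED_KEYS` simply mirrors the fields used in the metadata dictionary of this example rather than the official schema.

```python
import json
from pathlib import Path

# Assumption: keys mirror this example's metadata dict; the contributing docs define the real schema.
REQUIRED_KEYS = {
    "sample_rate", "domain_tags", "short_description", "long_description",
    "tags", "labels", "effect_type", "multichannel",
}

def write_metadata(metadata: dict, out_dir: Path) -> Path:
    """Hypothetical helper: check a few fields, then write out_dir/metadata.json."""
    missing = REQUIRED_KEYS - set(metadata)
    if missing:
        raise ValueError(f"missing metadata keys: {sorted(missing)}")
    if not isinstance(metadata["sample_rate"], int):
        raise TypeError("sample_rate must be an int")
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / "metadata.json"
    path.write_text(json.dumps(metadata, indent=2))
    return path
```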

In [33]:
# create a dictionary with model metadata
args = model.get_model_args()
metadata = {
    'sample_rate': int(args['sample_rate']), 
    'domain_tags': ['speech'],
    'short_description': 'Use me for speech separation! Works with 2 speakers.',
    'long_description':  'This model was trained by Joris Cosentino using the librimix recipe in Asteroid. It was trained on the sep_noisy task of the Libri2Mix dataset.',
    'tags': ['speech separation', 'speech'],
    'labels': ['speaker-1', 'speaker-2'],
    'effect_type': 'waveform-to-waveform',
    'multichannel': False,
}

Waveform-to-waveform models for Audacity need to be end-to-end. That is, our model must receive a waveform tensor as input (shape `(n_channels, n_samples)`) and return a waveform tensor as output (shape `(n_src, n_samples)`).
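A small shape check can catch violations of this contract before tracing. `check_waveform_io` is a hypothetical helper written for this example, and the toy lambda stands in for a real separator; neither comes from torchaudacity.

```python
import torch

def check_waveform_io(fn, n_channels: int, n_samples: int, n_src: int) -> torch.Size:
    """Run fn on a random (n_channels, n_samples) waveform and verify
    that the output has shape (n_src, n_samples)."""
    x = torch.randn(n_channels, n_samples)
    y = fn(x)
    expected = (n_src, n_samples)
    if tuple(y.shape) != expected:
        raise ValueError(f"expected output shape {expected}, got {tuple(y.shape)}")
    return y.shape

# Toy "separator": copies a mono input into two identical sources.
toy_separator = lambda x: x.repeat(2, 1)
print(check_waveform_io(toy_separator, n_channels=1, n_samples=4800, n_src=2))  # torch.Size([2, 4800])
```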

Luckily, Asteroid already provides a `separate()` method that performs source separation directly from one waveform tensor to another. Its output carries a batch dimension, so all we need to do is remove it!
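To see why indexing with `[0]` is enough, here is a toy stand-in that only mimics the shape behaviour assumed above: a 2-D input is treated as `(batch, time)` and the output is `(batch, n_src, time)`. `FakeSeparator` is hypothetical, not Asteroid's implementation, and its "sources" are just scaled copies of the input.

```python
import torch
from torch import nn

class FakeSeparator(nn.Module):
    """Hypothetical stand-in mimicking the assumed shapes of separate():
    (batch, time) in, (batch, n_src, time) out."""

    def __init__(self, n_src: int = 2):
        super().__init__()
        self.n_src = n_src

    def separate(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, time) -> (batch, 1, time), broadcast against (1, n_src, 1) gains
        gains = torch.linspace(0.5, 1.0, self.n_src).view(1, -1, 1)
        return x.unsqueeze(1) * gains

toy_model = FakeSeparator(n_src=2)
mono = torch.randn(1, 4800)              # (n_channels=1, n_samples)
batched_out = toy_model.separate(mono)   # (1, 2, 4800): the single channel acts as the batch
per_source = batched_out[0]              # (2, 4800): batch dimension removed
print(batched_out.shape, per_source.shape)
```

Because the model is not multichannel, `n_channels` is 1 and the leading dimension is safely indexed away.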

In [34]:
from torchaudacity import WaveformToWaveform
from torchaudacity.utils import save_model


In [None]:
# look at the docstring for do_forward_pass
WaveformToWaveform.do_forward_pass?

In [None]:
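class AsteroidWrapper(WaveformToWaveform):

    def do_forward_pass(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n_channels, n_samples); the single channel is treated as the batch
        # separate() returns (batch, n_src, n_samples), so keep the first item
        return self.model.separate(x)[0]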

## Serialize!

We now have an `AsteroidWrapper` class that satisfies the input/output constraints required by waveform-to-waveform models in Audacity. It's time to serialize it into a TorchScript model.
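Tracing records the operations executed for one example input, so it is worth confirming that the traced module also works on inputs of other lengths and survives a save/load round trip. A minimal sketch on a hypothetical `ToyWrapper` (not the `AsteroidWrapper` above), using only `torch.jit.trace`, `torch.jit.save` and `torch.jit.load`:

```python
import io

import torch
from torch import nn

class ToyWrapper(nn.Module):
    """Hypothetical stand-in: (n_channels, n_samples) -> (2, n_samples)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mono = x.mean(dim=0, keepdim=True)      # downmix to (1, n_samples)
        return torch.cat([mono, -mono], dim=0)  # two toy "sources"

toy = ToyWrapper().eval()
example = torch.randn(1, 4800)
# check_inputs takes a list of argument tuples; different lengths test shape generality
traced = torch.jit.trace(toy, example,
                         check_inputs=[(torch.randn(1, 16000),), (torch.randn(1, 2400),)])

# Round-trip through an in-memory buffer instead of a file
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
restored = torch.jit.load(buffer)

x = torch.randn(1, 8000)
assert torch.allclose(restored(x), toy(x))
print(restored(x).shape)  # torch.Size([2, 8000])
```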

In [None]:
# compile!
wrapper = AsteroidWrapper(model)
example_inputs = wrapper.get_example_inputs()

serialized_model = torch.jit.trace(wrapper, example_inputs[0], 
                                   check_inputs=example_inputs)
serialized_model = torch.jit.script(serialized_model)

print(f'sample input shape: {example_inputs[0].shape}')
print(f'sample output shape: {serialized_model(example_inputs[0]).shape}')

save_model(serialized_model, metadata, Path('ConvTasNet-Libri2Mix-sepnoisy'))

## All set!

Your `model.pt` and `metadata.json` files are ready for upload to HuggingFace. Once your model has been uploaded, you will be able to access it in Audacity. See the [contributing documentation](https://interactiveaudiolab.github.io/audacity/contrib) for more information on uploading to HuggingFace.
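Before uploading, it can help to reload both files and run the model once. The sketch below assumes `save_model` wrote `model.pt` and `metadata.json` into the output directory, as described above; `verify_export` is a hypothetical helper, and its one-source-per-label check is an assumption based on this example's metadata.

```python
import json
from pathlib import Path

import torch

def verify_export(out_dir: Path, n_samples: int = 16000) -> torch.Size:
    """Hypothetical pre-upload check: reload model.pt and metadata.json,
    then run one forward pass on a random mono waveform."""
    model_path = out_dir / "model.pt"
    meta_path = out_dir / "metadata.json"
    for path in (model_path, meta_path):
        if not path.is_file():
            raise FileNotFoundError(path)

    meta = json.loads(meta_path.read_text())
    scripted = torch.jit.load(str(model_path))
    with torch.no_grad():
        out = scripted(torch.randn(1, n_samples))

    # Waveform-to-waveform contract: output is (n_src, n_samples)
    if out.dim() != 2 or out.shape[-1] != n_samples:
        raise ValueError(f"unexpected output shape {tuple(out.shape)}")
    # Assumption: one output source per entry in "labels"
    n_labels = len(meta.get("labels", []))
    if n_labels and out.shape[0] != n_labels:
        raise ValueError(f"{out.shape[0]} sources but {n_labels} labels")
    return out.shape
```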