# Example - Serializing An Asteroid Model for Audacity

[Asteroid](https://github.com/asteroid-team/asteroid) is a source separation library that contains recipes for training state-of-the-art source separation models on a variety of datasets. Its models trace into TorchScript without problems, so all we need to take care of is providing a wrapper that handles the I/O. The pretrained models are hosted on [HuggingFace](https://huggingface.co/models?filter=asteroid).

### Preliminaries

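Install the dependencies: Asteroid from source, plus `audacitorch` and `torchsummary`, which are imported later in this notebook.

In [None]:
!git clone https://github.com/asteroid-team/asteroid
!pip install ./asteroid/
!pip install audacitorch torchsummary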

In [None]:
%%capture
import os
import math
import torch
from torch import nn
from asteroid.models import ConvTasNet
import json
from pathlib import Path

# disable gradient tracking globally: we only run inference
torch.set_grad_enabled(False)

In [None]:
%%capture
# required for huggingface
!sudo apt-get install git-lfs
!git lfs install

### Let's serialize a pretrained Asteroid model!

In [None]:
# download pretrained model from Asteroid
model = ConvTasNet.from_pretrained('JorisCos/ConvTasNet_Libri2Mix_sepnoisy_16k')

In [None]:
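from torchsummary import summary
# summary() prints the table itself (and returns None); run it on the CPU, where the model lives
summary(model, (1, 4800), device='cpu')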

### Model Metadata



We need to create a `metadata.json` file for our model. Certain details about the model, such as its sample rate, effect type (`effect_type`, e.g. waveform-to-waveform or waveform-to-labels), and list of labels, must be provided in this separate metadata file. See the [contributing documentation](https://github.com/hugofloresgarcia/audacitorch) for the full metadata schema.

In [None]:
# create a dictionary with model metadata
args = model.get_model_args()
metadata = {
    'sample_rate': int(args['sample_rate']), 
    'domain_tags': ['speech'],
    'short_description': 'Use me for speech separation! Works with 2 speakers.',
    'long_description':  'This model was trained by Joris Cosentino using the librimix recipe in Asteroid. It was trained on the sep_noisy task of the Libri2Mix dataset.',
    'tags': ['speech separation', 'speech'],
    'labels': ['speaker-1', 'speaker-2'],
    'effect_type': 'waveform-to-waveform',
    'multichannel': False,
}
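Before handing this dictionary to `validate_metadata`, it can help to see what such a check involves. The sketch below is a hypothetical stand-in, not audacitorch's implementation: the field names and types are copied from the dictionary above, and the only effect types it accepts are the two named in this notebook. The contributing documentation remains the authority on the real schema.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical schema: field names and types mirror the metadata dict above.
REQUIRED_FIELDS = {
    'sample_rate': int,
    'domain_tags': list,
    'short_description': str,
    'long_description': str,
    'tags': list,
    'labels': list,
    'effect_type': str,
    'multichannel': bool,
}
# Assumption: only the two effect types mentioned in this notebook.
ALLOWED_EFFECT_TYPES = {'waveform-to-waveform', 'waveform-to-labels'}


def check_metadata(metadata: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """Return (ok, problems) for a metadata dictionary."""
    problems = []
    # every required field must be present and have the expected type
    for field, expected_type in REQUIRED_FIELDS.items():
        if field not in metadata:
            problems.append(f'missing field: {field}')
        elif not isinstance(metadata[field], expected_type):
            problems.append(f'{field} should be {expected_type.__name__}, '
                            f'got {type(metadata[field]).__name__}')
    if metadata.get('effect_type') not in ALLOWED_EFFECT_TYPES:
        problems.append(f"unknown effect_type: {metadata.get('effect_type')!r}")
    # tag-like lists should contain only strings
    for field in ('domain_tags', 'tags', 'labels'):
        values = metadata.get(field, [])
        if isinstance(values, list) and not all(isinstance(v, str) for v in values):
            problems.append(f'{field} must contain only strings')
    return len(problems) == 0, problems


example_metadata = {
    'sample_rate': 16000,
    'domain_tags': ['speech'],
    'short_description': 'Use me for speech separation! Works with 2 speakers.',
    'long_description': 'Trained on the sep_noisy task of Libri2Mix.',
    'tags': ['speech separation', 'speech'],
    'labels': ['speaker-1', 'speaker-2'],
    'effect_type': 'waveform-to-waveform',
    'multichannel': False,
}
ok, problems = check_metadata(example_metadata)
print(ok, problems)  # True []
```

Collecting every problem instead of stopping at the first one makes it easier to fix a hand-written metadata file in a single pass.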

### Prepare for Wrapping the Model

Because source separation models return audio waveforms as output, we'll need to use the `WaveformToWaveform` base class.

In [None]:
from audacitorch import WaveformToWaveform
from audacitorch.utils import save_model, test_run, validate_metadata

In [None]:
# look at the docstring for do_forward_pass
WaveformToWaveform.do_forward_pass?

Waveform-to-waveform models for Audacity need to be end-to-end. That is, our model needs to be able to receive a waveform tensor as input (shape `(n_channels, n_samples)`) and return a waveform tensor as output (shape `(n_src, n_samples)`), where `n_src` is the number of separated sources (here, the two speakers).
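The contract can be written down as a small check on shape tuples. This is a hypothetical helper, not part of audacitorch; it reads the shared `n_samples` as meaning the output keeps the input's length, and it assumes a single input channel when `multichannel` is `False`.

```python
from typing import Sequence


def check_waveform_contract(in_shape: Sequence[int],
                            out_shape: Sequence[int],
                            multichannel: bool = False) -> None:
    """Raise ValueError if shapes break (n_channels, n_samples) -> (n_src, n_samples)."""
    if len(in_shape) != 2:
        raise ValueError(f'input must be 2D (n_channels, n_samples), got {tuple(in_shape)}')
    if len(out_shape) != 2:
        raise ValueError(f'output must be 2D (n_src, n_samples), got {tuple(out_shape)}')
    n_channels, n_in = in_shape
    _, n_out = out_shape
    # assumption: a non-multichannel model receives exactly one channel
    if not multichannel and n_channels != 1:
        raise ValueError(f'expected a single channel, got {n_channels}')
    # assumption: separated sources keep the mixture's length
    if n_in != n_out:
        raise ValueError(f'output length {n_out} != input length {n_in}')


check_waveform_contract((1, 48000), (2, 48000))  # OK, returns None
try:
    # a leftover batch dimension breaks the contract
    check_waveform_contract((1, 48000), (1, 2, 48000))
except ValueError as err:
    print(err)  # output must be 2D (n_src, n_samples), got (1, 2, 48000)
```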

Lucky for us, Asteroid models already provide a `separate()` method that performs source separation directly from a waveform tensor to another waveform tensor. It returns a batched tensor of shape `(batch, n_src, n_samples)`, so all we need to do is remove the batch dimension!

In [None]:
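class AsteroidWrapper(WaveformToWaveform):

    def do_forward_pass(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n_channels, n_samples), which Asteroid reads as a batch of one
        # separate() returns (batch, n_src, n_samples); [0] drops the batch dim
        return self.model.separate(x)[0]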
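To see why indexing with `[0]` is enough, here is a sketch that runs without Asteroid or audacitorch. `FakeSeparator` is a hypothetical stand-in whose `separate()` mimics a batched `(batch, n_src, n_samples)` output, and `ToyWrapper` plays the role of `AsteroidWrapper` on a plain `nn.Module`.

```python
import torch
from torch import nn


class FakeSeparator(nn.Module):
    """Hypothetical stand-in for an Asteroid model: reads a 2D input as
    (batch, n_samples) and returns (batch, n_src, n_samples)."""

    def __init__(self, n_src: int = 2):
        super().__init__()
        self.n_src = n_src

    def separate(self, x: torch.Tensor) -> torch.Tensor:
        # fake "separation": each source is a scaled copy of the mixture
        scales = torch.arange(1, self.n_src + 1, dtype=x.dtype).view(1, -1, 1)
        return x.unsqueeze(1) * scales / self.n_src


class ToyWrapper(nn.Module):
    """Same idea as AsteroidWrapper, built on nn.Module instead of audacitorch."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (1, n_samples) -> (1, n_src, n_samples) -> (n_src, n_samples)
        return self.model.separate(x)[0]


mixture = torch.randn(1, 16000)
sources = ToyWrapper(FakeSeparator(n_src=2))(mixture)
print(sources.shape)  # torch.Size([2, 16000])
```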

## Serialize!

We now have an `AsteroidWrapper` class that satisfies the input/output constraints required by waveform-to-waveform models in Audacity. It's time to serialize it into a TorchScript model.

In [None]:
# compile!
wrapper = AsteroidWrapper(model)
example_inputs = wrapper.get_example_inputs()

serialized_model = torch.jit.trace(wrapper, example_inputs[0], 
                                   check_inputs=example_inputs)
serialized_model = torch.jit.script(serialized_model)

print(f'sample input shape: {example_inputs[0].shape}')
print(f'sample output shape: {serialized_model(example_inputs[0]).shape}')

# test run!
test_run(serialized_model)

# make sure our metadata is ok
success, msg = validate_metadata(metadata)
assert success, msg

save_model(serialized_model, metadata, Path('ConvTasNet-Libri2Mix-sepnoisy'))
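The trace, save, and load round trip can also be tried in isolation on a toy module. `ToyGain` below is hypothetical and stands in for the wrapped model. The sketch uses only `torch.jit.trace` (with `check_inputs` of other lengths, which helps catch a trace that only works for the example length), `torch.jit.save`, and `torch.jit.load`.

```python
import os
import tempfile

import torch
from torch import nn


class ToyGain(nn.Module):
    """Hypothetical waveform-to-waveform module: (1, n_samples) -> (2, n_samples)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # two "sources": the input and a half-gain copy, stacked on dim 0
        return torch.cat([x, 0.5 * x], dim=0)


module = ToyGain().eval()
example = torch.randn(1, 4800)
# re-run the trace on other lengths and compare against eager mode
traced = torch.jit.trace(module, example,
                         check_inputs=[(torch.randn(1, 8000),),
                                       (torch.randn(1, 16000),)])

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'model.pt')
    torch.jit.save(traced, path)
    loaded = torch.jit.load(path)
    test_input = torch.randn(1, 12345)
    out = loaded(test_input)

print(out.shape)  # torch.Size([2, 12345])
```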

## All set!

Your `model.pt` and `metadata.json` files are ready for upload to HuggingFace. Once your model has been uploaded, you will be able to access it in Audacity. See the [contributing documentation](https://github.com/hugofloresgarcia/audacitorch) for more information on uploading to HuggingFace.
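Before uploading, a quick local sanity check on the output folder can save a round trip. This hypothetical helper assumes `save_model` wrote `model.pt` and `metadata.json` directly into the directory you passed it, as described above; it only checks that both files exist and that the metadata parses as JSON.

```python
import json
import tempfile
from pathlib import Path
from typing import List


def check_model_dir(model_dir: Path) -> List[str]:
    """Hypothetical pre-upload check: return a list of problems (empty if OK)."""
    problems = []
    for name in ('model.pt', 'metadata.json'):
        if not (model_dir / name).is_file():
            problems.append(f'missing {name}')
    meta_path = model_dir / 'metadata.json'
    if meta_path.is_file():
        try:
            json.loads(meta_path.read_text())
        except json.JSONDecodeError as err:
            problems.append(f'metadata.json is not valid JSON: {err}')
    return problems


# demo on a temporary folder that only contains metadata.json
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp)
    (model_dir / 'metadata.json').write_text(json.dumps({'sample_rate': 16000}))
    problems = check_model_dir(model_dir)

print(problems)  # ['missing model.pt']
```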