[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/asteroid-team/asteroid/blob/master/notebooks/01_APIOverview.ipynb)

### Introduction
Asteroid is an open-source, community-based toolkit made to design, train, evaluate, use and share audio source separation models such as Deep Clustering ([Hershey et al.](https://arxiv.org/abs/1508.04306)), ConvTasNet ([Luo et al.](https://arxiv.org/abs/1809.07454)) and DPRNN ([Luo et al.](https://arxiv.org/abs/1910.06379)).  
Along with the models, Asteroid provides building blocks, losses, metrics and datasets commonly used in source separation. This makes it easy to design new source separation models and benchmark them against others!

For training, Asteroid relies on the great [PyTorchLightning](https://github.com/PyTorchLightning/pytorch-lightning), which handles automatic distributed training, logging, experiment resuming and much more; be sure to check it out! Everything else is native [PyTorch](https://pytorch.org).

Enough talking, let's start!

In [1]:
# First off, install asteroid
!pip install git+https://github.com/asteroid-team/asteroid --quiet




### After installing the requirements, restart the runtime (Ctrl + M).

Otherwise, importing asteroid will fail.

### Waveform transformations & features
Time-frequency transformations are often applied to waveforms before they are fed to source separation models. Most of them can be formulated as convolutions with a specific (learned or fixed) filterbank, and their inverses, which map back to the time domain, as transposed convolutions.
Asteroid proposes a unified view of these transformations, implemented with the classes `Filterbank`, `Encoder` and `Decoder`.

The `Filterbank` object holds the actual filters used to compute the transforms. `Encoder` and `Decoder` are applied on top of it and provide the methods to go back and forth between the waveform and the time-frequency domain.
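To make the convolution view concrete, here is a minimal NumPy sketch (not Asteroid's implementation) of a fixed STFT filterbank. Each frequency bin contributes a windowed cosine filter (real part) and a windowed negative-sine filter (imaginary part), and the transform is a strided correlation of the waveform with every filter, which is what an unpadded `torch.nn.Conv1d` computes. The Hann window, the absence of padding and normalization, and the helper names `stft_filterbank`, `frame` and `encode` are assumptions for illustration; Asteroid's `STFTFB` may choose the window and scaling differently. The parameters loosely mirror the `STFTFB(n_filters=256, kernel_size=128, stride=64)` call used in this notebook, assuming `n_filters` plays the role of the FFT size: 128-sample filters evaluated at 256 frequency points, i.e. 129 bins and 258 real channels.

```python
import numpy as np


def stft_filterbank(n_fft: int, window: np.ndarray) -> np.ndarray:
    """Stack windowed cosine (real part) and negative sine (imaginary part) filters."""
    kernel_size = len(window)
    n_bins = n_fft // 2 + 1
    t = np.arange(kernel_size)[None, :]          # (1, kernel_size)
    k = np.arange(n_bins)[:, None]               # (n_bins, 1)
    phase = 2 * np.pi * k * t / n_fft            # (n_bins, kernel_size)
    real = window * np.cos(phase)
    imag = -window * np.sin(phase)
    return np.concatenate([real, imag], axis=0)  # (2 * n_bins, kernel_size)


def frame(wav: np.ndarray, kernel_size: int, stride: int) -> np.ndarray:
    """Cut a 1D signal into overlapping frames without padding: (n_frames, kernel_size)."""
    n_frames = (len(wav) - kernel_size) // stride + 1
    return np.stack([wav[i * stride:i * stride + kernel_size] for i in range(n_frames)])


def encode(wav: np.ndarray, filters: np.ndarray, stride: int) -> np.ndarray:
    """Strided correlation with every filter, i.e. what an unpadded Conv1d computes."""
    frames = frame(wav, filters.shape[1], stride)
    return filters @ frames.T                    # (n_filters, n_frames)


rng = np.random.default_rng(0)
wav = rng.standard_normal(16000)
window = np.hanning(128)
filters = stft_filterbank(n_fft=256, window=window)
tf_rep = encode(wav, filters, stride=64)
print(tf_rep.shape)  # (258, 249)
```

The first 129 rows of `tf_rep` equal the real part, and the last 129 rows the imaginary part, of `np.fft.rfft` applied to each Hann-windowed frame zero-padded to 256 points. A learned filterbank would simply replace the output of `stft_filterbank` with trainable weights.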

A common example is the STFT, which can be defined as follows:


In [2]:
from asteroid_filterbanks import STFTFB, Encoder, Decoder
# First, instantiate the STFT filterbank
fb = STFTFB(n_filters=256, kernel_size=128, stride=64)
# Make an encoder out of it.
encoder = Encoder(fb)
# Same for decoder
decoder_fb = STFTFB(n_filters=256, kernel_size=128, stride=64)
decoder = Decoder(decoder_fb)

# Equivalently, build the encoder/decoder pair in a single call
from asteroid_filterbanks import make_enc_dec
encoder, decoder = make_enc_dec('stft', n_filters=256, 
                                kernel_size=128, stride=64)



From there, `Encoder` has the same interface as `torch.nn.Conv1d`, and `Decoder` the same as `torch.nn.ConvTranspose1d`, so a waveform-like tensor can be transformed like this:

In [3]:
import torch
# Waveform-like
wav = torch.randn(2, 1, 16000)
# Time-frequency representation
tf_rep = encoder(wav)
# Back to time domain
wav_back = decoder(tf_rep)
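Since `Encoder` and `Decoder` follow the `Conv1d`/`ConvTranspose1d` interface, the time dimensions in the cell above can be predicted with the standard PyTorch output-length formulas. The sketch below assumes zero padding and dilation 1 (the PyTorch defaults); if your encoder pads its input, pass the corresponding `padding`. The function names are illustrative.

```python
def conv1d_out_len(length: int, kernel_size: int, stride: int, padding: int = 0) -> int:
    """Output length of torch.nn.Conv1d with dilation=1."""
    return (length + 2 * padding - kernel_size) // stride + 1


def conv_transpose1d_out_len(length: int, kernel_size: int, stride: int,
                             padding: int = 0, output_padding: int = 0) -> int:
    """Output length of torch.nn.ConvTranspose1d with dilation=1."""
    return (length - 1) * stride - 2 * padding + kernel_size + output_padding


# Frames produced from a 16000-sample waveform, then samples recovered from them.
n_frames = conv1d_out_len(16000, kernel_size=128, stride=64)
n_samples = conv_transpose1d_out_len(n_frames, kernel_size=128, stride=64)
print(n_frames, n_samples)  # 249 16000
```

With 16000 samples, the round trip is exact in length. With 16001 samples you would still get 249 frames and 16000 output samples: trailing samples that do not complete a stride are dropped, so you may need to pad or trim before comparing the output with the input.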

More information on the automatic pseudo-inverse, how to define your own filterbanks, etc. can be found in the
[Filterbank notebook](https://github.com/asteroid-team/asteroid/blob/master/notebooks/02_Filterbank.ipynb).

### Masker network & Separation models 
Asteroid aims to provide most state-of-the-art masking networks.
Some of these masking networks and separation models share building blocks, such as residual LSTMs or D-Conv (depthwise convolution) blocks.
Asteroid provides these building blocks as well as common masker networks with the blocks already assembled (e.g. `TDConvNet` or `DPRNN`).

These networks default to the configurations reported in the corresponding papers: just import them and run!
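One reason these blocks handle long contexts well is their dilation schedule. In ConvTasNet's temporal convolutional network, each repeat stacks blocks whose depthwise convolutions use dilations 1, 2, ..., 2^(X-1), so the receptive field grows exponentially with the number of blocks. The sketch below computes it in frames, ignoring the 1x1 convolutions (which do not widen it). The values `kernel_size=3`, `n_blocks=8` and `n_repeats=3` are assumed to match `TDConvNet`'s defaults; check the docstring of your installed version.

```python
def tcn_receptive_field(kernel_size: int, n_blocks: int, n_repeats: int) -> int:
    """Receptive field (in frames) of repeated stacks of dilated convolutions.

    Block x in each repeat uses dilation 2**x, so it widens the field
    by (kernel_size - 1) * 2**x frames.
    """
    per_repeat = sum((kernel_size - 1) * 2 ** x for x in range(n_blocks))
    return 1 + n_repeats * per_repeat


print(tcn_receptive_field(kernel_size=3, n_blocks=8, n_repeats=3))  # 1531
print(tcn_receptive_field(kernel_size=3, n_blocks=8, n_repeats=1))  # 511
```

Cutting to a single repeat, as the training script further down does with `n_repeats=1`, shrinks the context from 1531 to 511 frames.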

In [4]:
from asteroid.masknn import TDConvNet
# We only need to specify the number of input channels
# and the number of sources we want to estimate.
masker = TDConvNet(in_chan=128, n_src=2)

# Now, we can use it to estimate some masks!
tf_rep = torch.randn(2, 128, 10)
masks = masker(tf_rep)
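The masker returns masks, not waveforms. Assuming `TDConvNet` returns one mask per source with shape `(batch, n_src, n_feats, n_frames)`, applying them requires adding a source axis to the encoder output so that it broadcasts against every mask. The NumPy sketch below uses random stand-ins for both tensors; the sigmoid is an assumed mask nonlinearity, chosen only to keep the values in [0, 1].

```python
import numpy as np

batch, n_src, n_feats, n_frames = 2, 2, 128, 10
rng = np.random.default_rng(0)

# Stand-in for the encoder output: (batch, n_feats, n_frames).
tf_rep = rng.standard_normal((batch, n_feats, n_frames))
# Stand-in masks, one per source, squashed to [0, 1] with a sigmoid.
masks = 1 / (1 + np.exp(-rng.standard_normal((batch, n_src, n_feats, n_frames))))

# Add a source axis so the mixture representation broadcasts against every mask.
masked = tf_rep[:, None] * masks  # (batch, n_src, n_feats, n_frames)
print(masked.shape)  # (2, 2, 128, 10)
```

`tf_rep[:, None]` is the NumPy equivalent of `tf_rep.unsqueeze(1)`, the operation the model class below applies before decoding.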

Let's put the encoder, masker and decoder together in an `nn.Module` to make it all simple.

In [5]:
import torch
from asteroid_filterbanks import make_enc_dec
from asteroid.masknn import TDConvNet

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder and decoder in "one line"
        self.enc, self.dec = make_enc_dec(
            'stft', n_filters=256, kernel_size=128, stride=64
            )
        # Mask network from ConvTasNet in one line.
        self.masker = TDConvNet(in_chan=self.enc.n_feats_out, 
                                n_src=2)
    
    def forward(self, wav):
        # Simplified forward
        tf_rep = self.enc(wav)       # (batch, n_feats, n_frames)
        masks = self.masker(tf_rep)  # (batch, n_src, n_feats, n_frames)
        # Apply each mask to the mixture representation, then decode.
        wavs_out = self.dec(tf_rep.unsqueeze(1) * masks)
        return wavs_out


# Define and forward 
stft_conv_tasnet = Model()
wav_out = stft_conv_tasnet(torch.randn(1, 1, 16000))

Models such as ConvTasNet can also be imported directly from asteroid and used like this:


In [6]:
from asteroid import ConvTasNet
model = ConvTasNet(n_src=2)

### Datasets and DataLoader
We support several source separation datasets; you can find more information about them in the docs. 
Note that there is no common API between them: preparing the data in the format expected by each `Dataset` is the role of the recipes.

To make experimenting easy, a small subset of LibriMix can be downloaded directly.
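Recipes are responsible for turning raw audio into what a `Dataset` returns. As a hypothetical illustration (random noise instead of audio), the class below returns `(mixture, sources)` pairs, the same two-element structure that `System.common_step` unpacks as `inputs, targets` further down. The class name and the shapes `(n_samples,)` for the mixture and `(n_src, n_samples)` for the sources are assumptions; check the dataset you actually use.

```python
import numpy as np


class ToyMixtureDataset:
    """Hypothetical map-style dataset returning (mixture, sources) pairs.

    Real recipes load audio files instead of drawing random noise.
    """

    def __init__(self, n_examples: int, n_src: int = 2, n_samples: int = 16000, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.sources = rng.standard_normal((n_examples, n_src, n_samples)).astype(np.float32)

    def __len__(self) -> int:
        return len(self.sources)

    def __getitem__(self, idx: int):
        sources = self.sources[idx]    # (n_src, n_samples)
        mixture = sources.sum(axis=0)  # (n_samples,): mixture = sum of the sources
        return mixture, sources


dataset = ToyMixtureDataset(n_examples=8)
mixture, sources = dataset[0]
print(len(dataset), mixture.shape, sources.shape)  # 8 (16000,) (2, 16000)
```

Because it implements `__len__` and `__getitem__`, such an object can be passed to `torch.utils.data.DataLoader`, whose default collate function stacks the NumPy arrays into batched tensors.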

In [7]:
from asteroid.data import LibriMix

train_set, val_set = LibriMix.mini_from_download(task='sep_clean')

Drop 0 utterances from 800 (shorter than 3 seconds)
Drop 0 utterances from 200 (shorter than 3 seconds)


### Loss functions
Asteroid provides several loss functions that are commonly used for source separation or speech enhancement. More importantly, we also provide `PITLossWrapper`, an efficient wrapper that can turn any loss function into a permutation-invariant loss. 
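The idea behind permutation invariance: the order of the estimated sources is arbitrary, so the loss should be computed under the best assignment of estimates to references. Below is a minimal sketch, assuming a pairwise loss matrix (here a hypothetical mean squared error on arrays of shape `(n_src, time)`) and a brute-force search over all `n_src!` permutations. `PITLossWrapper` operates on batched tensors and its internals may differ.

```python
import itertools

import numpy as np


def pairwise_mse(est: np.ndarray, ref: np.ndarray) -> np.ndarray:
    """Loss matrix M[i, j] = MSE between estimate i and reference j."""
    return ((est[:, None, :] - ref[None, :, :]) ** 2).mean(axis=-1)


def pit_loss(est: np.ndarray, ref: np.ndarray):
    """Return the minimal mean loss over all permutations and the best permutation."""
    loss_mat = pairwise_mse(est, ref)
    n_src = loss_mat.shape[0]
    best_loss, best_perm = np.inf, None
    for perm in itertools.permutations(range(n_src)):
        # perm[j] is the index of the estimate assigned to reference j.
        loss = np.mean([loss_mat[perm[j], j] for j in range(n_src)])
        if loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm


rng = np.random.default_rng(0)
ref = rng.standard_normal((3, 1000))
est = ref[[2, 0, 1]]  # the same signals in a shuffled order
loss, perm = pit_loss(est, ref)
print(loss, perm)  # 0.0 (1, 2, 0)
```

Brute force is fine for two or three sources, but the number of permutations grows factorially. Computing the pairwise matrix once keeps each permutation cheap, since it only sums precomputed entries.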

For example, to define a permutation-invariant SI-SDR loss, run:

In [8]:
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr


loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
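`pairwise_neg_sisdr` plugs into this mechanism by filling the pairwise matrix with negative SI-SDR values. The NumPy sketch below follows the usual SI-SDR definition: remove the means, project the estimate onto the reference to obtain the scaled target, and compare the target energy with the residual energy in dB. The `eps` guard and the function names are assumptions; Asteroid's version works on batched torch tensors.

```python
import numpy as np


def si_sdr(est: np.ndarray, ref: np.ndarray, eps: float = 1e-8) -> float:
    """Scale-invariant SDR in dB between two 1D signals, after mean removal."""
    est = est - est.mean()
    ref = ref - ref.mean()
    # Project the estimate onto the reference to get the scaled target.
    scale = np.dot(est, ref) / (np.dot(ref, ref) + eps)
    target = scale * ref
    residual = est - target
    return 10 * np.log10((np.dot(target, target) + eps) / (np.dot(residual, residual) + eps))


def pairwise_neg_si_sdr(est: np.ndarray, ref: np.ndarray) -> np.ndarray:
    """Matrix M[i, j] = -SI-SDR(estimate i, reference j); inputs are (n_src, time)."""
    return np.array([[-si_sdr(e, r) for r in ref] for e in est])


rng = np.random.default_rng(0)
ref = rng.standard_normal((2, 8000))
est = ref + 0.1 * rng.standard_normal((2, 8000))  # noisy but correctly ordered estimates
print(np.round(pairwise_neg_si_sdr(est, ref), 1))
```

Rescaling an estimate leaves its SI-SDR unchanged, which is the point of the scale-invariant definition: the model is not penalized for an arbitrary output gain.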

You can find more info about this in the [PIT loss tutorial](https://github.com/asteroid-team/asteroid/blob/master/notebooks/03_PITLossWrapper.ipynb).

### Training

For training, Asteroid relies on PyTorchLightning, which automates most of the training loop. We provide a thin wrapper around it to make things even simpler.


#### Putting all ingredients together with `System`
To use PyTorchLightning, we need to gather all the ingredients (dataloaders, model, loss functions, optimizers, etc.) in one object, the `LightningModule`. To keep things separate and reusable, and to reduce boilerplate, we define a subclass, `System`, which takes these ingredients separately. 



Additionally, a `LightningModule` needs to implement the `training_step` and `validation_step` methods. These two are usually identical or very similar, so we grouped their shared logic under `common_step`:
```python
import pytorch_lightning as pl


class System(pl.LightningModule):
    def __init__(self, model, optimizer, loss_func, train_loader,
                 val_loader=None, scheduler=None, config=None):
        ...

    def common_step(self, batch, batch_nb, train=True):
        # Unpack (mixture, sources), run the model and compute the loss.
        inputs, targets = batch
        est_targets = self(inputs)
        loss = self.loss_func(est_targets, targets)
        return loss
```

#### Example training script

In [9]:
from torch.optim import Adam
from torch.utils.data import DataLoader
import pytorch_lightning as pl

from asteroid.data import LibriMix
from asteroid.engine.system import System
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
from asteroid import ConvTasNet

train_set, val_set = LibriMix.mini_from_download(task='sep_clean')
train_loader = DataLoader(train_set, batch_size=4, drop_last=True)
val_loader = DataLoader(val_set, batch_size=4, drop_last=True)

# Define model and optimizer (one repeat to be faster)
model = ConvTasNet(n_src=2, n_repeats=1)
optimizer = Adam(model.parameters(), lr=1e-3)
# Define Loss function.
loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
# Define System
system = System(model=model, loss_func=loss_func, optimizer=optimizer,
                train_loader=train_loader, val_loader=val_loader)
# Define lightning trainer, and train
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(system)

INFO:lightning:Running in fast_dev_run mode: will run a full train, val and test loop using a single batch
INFO:lightning:GPU available: False, used: False
INFO:lightning:
   | Name                              | Type           | Params
-----------------------------------------------------------------
0  | model                             | ConvTasNet     | 1 M   
1  | model.encoder                     | Encoder        | 8 K   
2  | model.encoder.filterbank          | FreeFB         | 8 K   
3  | model.masker                      | TDConvNet      | 1 M   
4  | model.masker.bottleneck           | Sequential     | 66 K  
5  | model.masker.bottleneck.0         | GlobLN         | 1 K   
6  | model.masker.bottleneck.1         | Conv1d         | 65 K  
7  | model.masker.TCN                  | ModuleList     | 1 M   
8  | model.masker.TCN.0                | Conv1DBlock    | 201 K 
9  | model.masker.TCN.0.shared_block   | Sequential     | 70 K  
10 | model.masker.TCN.0.shared_block.0 | Conv1d

Drop 0 utterances from 800 (shorter than 3 seconds)
Drop 0 utterances from 200 (shorter than 3 seconds)

INFO:lightning:Detected KeyboardInterrupt, attempting graceful shutdown...

#### Extending `System`
If your model or data is a bit different, changing `System` is easy: just override the `common_step` method.

In [10]:
# Example of how simple it is to define a new System with 
# different training dynamic.
class YourSystem(System):
    def common_step(self, batch, batch_nb, train=True):
        # Your DataLoader returns three tensors
        inputs, some_other_input, targets = batch
        # Your model returns two.
        est_targets, some_other_output = self(inputs, some_other_input)
        if train:
            # Your loss takes three arguments
            loss = self.loss_func(est_targets, targets, cond=some_other_output)
        else:
            # At validation time, you don't want cond 
            loss = self.loss_func(est_targets, targets)
        return loss

Of course, Asteroid is not limited to `System`: this is pure PyTorchLightning, and more complicated use cases might not benefit from `System`. In that case, writing your own `LightningModule` is the way to go!