# Add your own model

<a href="https://colab.research.google.com/github/pyannote/pyannote-audio/blob/develop/tutorials/add_your_own_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Tutorial setup

### `Google Colab` setup

If you are running this tutorial on `Colab`, execute the following commands to set up the `Colab` environment. These commands install `pyannote.audio` and download a mini version of the `AMI` corpus.

In [None]:
!pip install -qq pyannote.audio==3.1.1
!pip install -qq ipython==7.34.0
!git clone https://github.com/pyannote/AMI-diarization-setup.git
%cd ./AMI-diarization-setup/pyannote/
!bash ./download_ami_mini.sh
%cd /content

⚠ Restart the runtime (Runtime > Restart session).

### Non `Google Colab` setup

If you are not using `Colab`, this tutorial assumes that
* `pyannote.audio` has been installed
* the [AMI corpus](https://groups.inf.ed.ac.uk/ami/corpus/) has already been [set up for use with `pyannote`](https://github.com/pyannote/AMI-diarization-setup/tree/main/pyannote)

## Defining a custom model

A collection of models is readily available in `pyannote.audio.models` but you will eventually want to try your own architecture.

This tutorial explains how to define (and then use) your own model.  

In [18]:
from typing import Optional
import torch
import torch.nn as nn
from pyannote.audio import Model
from pyannote.core import SlidingWindow
from pyannote.audio.core.task import Task, Resolution
from torchaudio.transforms import MFCC

# Your custom model must be a subclass of `pyannote.audio.Model`,
# which is a subclass of `pytorch_lightning.LightningModule`,
# which is a subclass of `torch.nn.Module`.
class MyCustomModel(Model):
    """My custom model"""


    def __init__(
        self,
        sample_rate: int = 16000,
        num_channels: int = 1,
        task: Optional[Task] = None,
        param1: int = 32,
        param2: int = 16,
    ):

        # First three parameters (sample_rate, num_channels, and task)
        # must be there and passed to super().__init__()
        super().__init__(sample_rate=sample_rate,
                         num_channels=num_channels,
                         task=task)

        # Mark param1 and param2 as hyper-parameters.
        self.save_hyperparameters("param1", "param2")

        # They will be saved automatically into checkpoints.
        # They are now also available in self.hparams:
        #  - param1 == self.hparams.param1
        #  - param2 == self.hparams.param2

        # Layers that do not depend on the addressed task should be defined in '__init__'.
        self.mfcc = MFCC()
        self.linear1 = nn.Linear(self.mfcc.n_mfcc, self.hparams.param1)
        self.linear2 = nn.Linear(self.hparams.param1, self.hparams.param2)

    def num_frames(self, num_samples: int) -> int:
        # Compute number of output frames for a given number of input samples
        hop_length = self.mfcc.MelSpectrogram.spectrogram.hop_length
        n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft
        center = self.mfcc.MelSpectrogram.spectrogram.center
        return (
            1 + num_samples // hop_length
            if center
            else 1 + (num_samples - n_fft) // hop_length
        )

    def receptive_field_size(self, num_frames: int = 1) -> int:
        # Compute receptive field size
        hop_length = self.mfcc.MelSpectrogram.spectrogram.hop_length
        n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft
        center = self.mfcc.MelSpectrogram.spectrogram.center

        if center:
            return (num_frames - 1) * hop_length
        else:
            return (num_frames - 1) * hop_length + n_fft

    def receptive_field(self) -> SlidingWindow:
        # Compute receptive field

        # duration of the receptive field of each output frame
        duration = (
            self.mfcc.MelSpectrogram.spectrogram.win_length / self.hparams.sample_rate
        )

        # step between the receptive field region of two consecutive output frames
        step = (
            self.mfcc.MelSpectrogram.spectrogram.hop_length / self.hparams.sample_rate
        )

        return SlidingWindow(start=0.0, duration=duration, step=step)

    def build(self):
        # Add layers that depend on the specifications of the task addressed
        # by this model.

        # For instance, this simple model could be used for "speech vs. non-speech"
        # or "speech vs. music vs. other" classification and the only difference
        # would lie in the number of classes (2 or 3) in the final classifier.

        # Since task specifications are not available at the time '__init__' is called,
        # task-dependent layers can only be added at 'build' time (when task specifications
        # are available in the 'specifications' attribute).

        num_classes = len(self.specifications.classes)
        self.classifier = nn.Linear(self.hparams.param2, num_classes)

        # 'specifications' has several attributes describing what the task is:
        #  - classes: the list of classes
        #  - problem: the type of machine learning problem (e.g. binary
        #      classification or representation learning)
        #  - duration: the duration of input audio chunks, in seconds
        #  - resolution: the resolution of the output (e.g. frame-wise scores
        #      for voice activity detection or chunk-wise vector for speaker
        #      embedding)
        #  - permutation_invariant : whether classes are permutation-invariant
        #      (e.g. in the case of speaker diarization)

        # Depending on the type of 'problem', 'default_activation' can be used
        # to automatically guess what the final activation should be (e.g. softmax
        # for multi-class classification or sigmoid for multi-label classification).
        self.activation = self.default_activation()

        # You do not _have_ to use 'default_activation' and can choose any
        # activation you see fit (or even no activation layer at all). But note
        # that pyannote.audio tasks also define default loss functions that are
        # consistent with 'default_activation' (e.g. binary cross entropy with sigmoid
        # for binary classification tasks).

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:

        # Models are expected to work on batches of audio chunks provided as tensors
        # with shape (batch_size, num_channels, num_samples) and using the sample rate
        # passed to __init__. Resampling will be done automatically for you so you do
        # not have to bother about that when preparing the data.

        # Extract a sequence of MFCCs and pass them through two linear layers.
        # After squeeze/transpose, 'mfcc' has shape (batch_size, num_frames, n_mfcc).
        mfcc = self.mfcc(waveforms).squeeze(dim=1).transpose(1, 2)
        output = self.linear1(mfcc)
        output = self.linear2(output)

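        # Apply temporal pooling for tasks which need an output at chunk-level.
        # 'output' has shape (batch_size, num_frames, param2) at this point,
        # so the time axis to average over is dim=1.
        if self.specifications.resolution == Resolution.CHUNK:
            output = torch.mean(output, dim=1)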
        # Keep 'mfcc' frame resolution for frame-level tasks.
        elif self.specifications.resolution == Resolution.FRAME:
            pass

        # Apply final classifier and activation function
        output = self.classifier(output)
        return self.activation(output)
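
The frame arithmetic in `num_frames`, `receptive_field_size` and `receptive_field` can be checked by hand without instantiating the model. The sketch below re-implements the same formulas as plain functions. The constants (`n_fft = win_length = 400`, `hop_length = 200`, 16 kHz audio) are assumptions about what `torchaudio.transforms.MFCC()` uses by default, so check them against your installed version before relying on the numbers.

```python
# Stand-alone re-implementation of the frame arithmetic used by MyCustomModel.
# Assumed values (not read from torchaudio): 16 kHz audio,
# n_fft = win_length = 400 samples, hop_length = 200 samples.
SAMPLE_RATE = 16000
N_FFT = 400
WIN_LENGTH = 400
HOP_LENGTH = 200


def num_frames(num_samples: int, center: bool = True) -> int:
    """Number of output frames for a given number of input samples."""
    if center:
        return 1 + num_samples // HOP_LENGTH
    return 1 + (num_samples - N_FFT) // HOP_LENGTH


def receptive_field_size(num_frames: int = 1, center: bool = True) -> int:
    """Number of input samples covered by `num_frames` output frames."""
    if center:
        return (num_frames - 1) * HOP_LENGTH
    return (num_frames - 1) * HOP_LENGTH + N_FFT


chunk_samples = 2 * SAMPLE_RATE  # a 2-second chunk
frames_centered = num_frames(chunk_samples, center=True)
frames_uncentered = num_frames(chunk_samples, center=False)

# Duration and step of the receptive field, in seconds
frame_duration = WIN_LENGTH / SAMPLE_RATE
frame_step = HOP_LENGTH / SAMPLE_RATE

print(frames_centered, frames_uncentered, frame_duration, frame_step)
# → 161 159 0.025 0.0125
```

Under these assumptions the two functions are inverses of each other: feeding `receptive_field_size(n)` back into `num_frames` returns `n` for both settings of `center`, which is a quick sanity check to run on your own architecture.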

## Using your model with `pyannote.audio` API

Your model can now be used like any other built-in model.

In [None]:
# initialize your experimental protocol
from pyannote.database import registry, FileFinder

registry.load_database("./AMI-diarization-setup/pyannote/database.yml")
protocol = registry.get_protocol('AMI.SpeakerDiarization.mini', preprocessors={"audio": FileFinder()})

# initialize the task you want to address
from pyannote.audio.tasks import VoiceActivityDetection
task = VoiceActivityDetection(protocol)

# initialize the model
model = MyCustomModel(task=task)

# train the model
from pytorch_lightning import Trainer
trainer = Trainer(max_epochs=1)
trainer.fit(model)

## Using your model with `pyannote-audio-train` CLI

1. Define your model in a proper Python package:

```
/your/favorite/directory/
  your_package_name/
    __init__.py      # needs to be here but can be empty
    custom_model.py  # contains the above definition of your model
```

2. Add the package to your `PYTHONPATH`:

```bash
$ export PYTHONPATH=/your/favorite/directory
```

3. Check that you can import it from Python:

```python
>>> from your_package_name.custom_model import MyCustomModel
```

4. Tell `Hydra` (on which `pyannote-audio-train` is based) about this new model:

```
/your/favorite/directory/
  custom_config/
    model/
      MyCustomModel.yaml
```

where the content of `MyCustomModel.yaml` is as follows:

```yaml
# @package _group_
_target_: your_package_name.custom_model.MyCustomModel
param1: 32
param2: 16
```

5. Launch training with your model:

```bash
$ pyannote-audio-train --config-dir=/your/favorite/directory/custom_config \
                       protocol=Debug.SpeakerDiarization.Debug \
                       task=VoiceActivityDetection \
                       model=MyCustomModel \
                       model.param2=12
```

## Contributing your model to `pyannote-audio`

1. Add your model in `pyannote.audio.models`.

```
pyannote/
  audio/
    models/
      custom_model.py
```

2. Check that you can import it from Python:

```python
>>> from pyannote.audio.models.custom_model import MyCustomModel
```

3. Add the corresponding `Hydra` configuration file:

```
pyannote/
  audio/
    cli/
      train_config/
        model/
          MyCustomModel.yaml
```

where the content of `MyCustomModel.yaml` is as follows:

```yaml
# @package _group_
_target_: pyannote.audio.models.custom_model.MyCustomModel
param1: 32
param2: 16
```

4. Launch training with your model:

```bash
$ pyannote-audio-train protocol=Debug.SpeakerDiarization.Debug \
                       task=VoiceActivityDetection \
                       model=MyCustomModel \
                       model.param2=12
```