<a href="https://colab.research.google.com/github/pyannote/pyannote-audio/blob/develop/tutorials/add_your_own_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Defining a custom task

## Tutorial setup

### `Google Colab` setup

If you are running this tutorial on `Colab`, execute the following commands to set up the `Colab` environment. These commands install `pyannote.audio` and download a mini version of the `AMI` corpus.

In [None]:
!pip install -qq pyannote.audio==3.1.1
!pip install -qq ipython==7.34.0
!git clone https://github.com/pyannote/AMI-diarization-setup.git
%cd ./AMI-diarization-setup/pyannote/
!bash ./download_ami_mini.sh
%cd /content

⚠ Restart the runtime (Runtime > Restart session).

### Non `Google Colab` setup

If you are not using `Colab`, this tutorial assumes that
* `pyannote.audio` has been installed
* the [AMI corpus](https://groups.inf.ed.ac.uk/ami/corpus/) has already been [set up for use with `pyannote`](https://github.com/pyannote/AMI-diarization-setup/tree/main/pyannote)

## Task in `pyannote.audio`


In `pyannote.audio`, a *task* is a combination of a **_problem_** that needs to be addressed and an **experimental protocol**.

For example, one can address **_voice activity detection_** following the **AMI only_words** experimental protocol, by instantiating the following *task*:


In [None]:
# this assumes that the AMI corpus has been set up for diarization
# according to https://github.com/pyannote/AMI-diarization-setup

from pyannote.database import registry, FileFinder
registry.load_database("AMI-diarization-setup/pyannote/database.yml")
ami = registry.get_protocol('AMI.SpeakerDiarization.mini',
                   preprocessors={'audio': FileFinder()})

# address voice activity detection
from pyannote.audio.tasks import VoiceActivityDetection
task = VoiceActivityDetection(ami)

A growing collection of tasks is readily available in `pyannote.audio.tasks`...

In [2]:
from pyannote.audio.tasks import __all__ as TASKS; print('\n'.join(TASKS))

SpeakerDiarization
VoiceActivityDetection
OverlappedSpeechDetection
MultiLabelSegmentation
SpeakerEmbedding
Segmentation


... but you will eventually want to use `pyannote.audio` to address a different task.  
In this example, we will add a new task addressing the **sound event detection** problem.



## Problem specification

A problem is expected to be solved by a model $f$ that takes an audio chunk $X$ as input and returns its predicted solution $\hat{y} = f(X)$.

### Resolution

Depending on the addressed problem, you might expect the model to output just one prediction for the whole audio chunk (`Resolution.CHUNK`) or a temporal sequence of predictions (`Resolution.FRAME`).

In our particular case, we would like the model to provide one decision for the whole chunk:

In [3]:
from pyannote.audio.core.task import Resolution
resolution = Resolution.CHUNK
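
To make the difference concrete, the sketch below shows the output shapes one would expect for a single chunk under each resolution. The sizes (`num_classes`, `num_frames`) are hypothetical values chosen for illustration, and a real model would add a leading batch dimension.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration
num_classes = 10   # number of sound event classes
num_frames = 250   # number of output frames a frame-level model might emit

# Resolution.CHUNK: one score per class for the whole chunk
chunk_prediction = np.zeros((num_classes,))

# Resolution.FRAME: one score per class for every frame of the chunk
frame_prediction = np.zeros((num_frames, num_classes))

print(chunk_prediction.shape)  # (10,)
print(frame_prediction.shape)  # (250, 10)
```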

### Type of problem

Similarly, the type of your problem may fall into one of these generic machine learning categories:
* `Problem.BINARY_CLASSIFICATION` for binary classification
* `Problem.MONO_LABEL_CLASSIFICATION` for multi-class classification
* `Problem.MULTI_LABEL_CLASSIFICATION` for multi-label classification
* `Problem.REGRESSION` for regression
* `Problem.REPRESENTATION` for representation learning

In our particular case, we would like the model to do multi-label classification because one audio chunk may contain multiple sound events:

In [4]:
from pyannote.audio.core.task import Problem
problem = Problem.MULTI_LABEL_CLASSIFICATION
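
In a multi-label setting, the target for one chunk is typically a vector with one {0, 1} entry per class, where several entries may be 1 at once. The sketch below illustrates this encoding with plain `numpy`; the helper `multi_hot` and the shortened class list are hypothetical and only serve the example.

```python
import numpy as np

def multi_hot(active_labels, classes):
    """Encode a set of active labels as a {0, 1} vector aligned with `classes`."""
    y = np.zeros((len(classes),))
    for label in active_labels:
        y[classes.index(label)] = 1
    return y

classes = ["Speech", "Dog", "Cat"]  # hypothetical, shortened class list
y = multi_hot({"Speech", "Cat"}, classes)
print(y)  # [1. 0. 1.]
```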

In [5]:
from pyannote.audio.core.task import Specifications
specifications = Specifications(
    problem=problem,
    resolution=resolution,
    duration=5.0,
    classes=["Speech", "Dog", "Cat", "Alarm_bell_ringing", "Dishes",
             "Frying", "Blender", "Running_water", "Vacuum_cleaner",
             "Electric_shaver_toothbrush"],
)

A task is expected to be solved by a model $f$ that (usually) takes an audio chunk $X$ as input and returns its predicted solution $\hat{y} = f(X)$.

To help train the model $f$, the task $\mathcal{T}$ is in charge of
- generating $(X, y)$ training samples using the **dataset**
- defining the loss function $\mathcal{L}(y, \hat{y})$
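
For a multi-label problem, one natural choice for $\mathcal{L}(y, \hat{y})$ is binary cross-entropy applied to each class independently. The sketch below is a plain `numpy` illustration under that assumption, not the code used by `pyannote.audio`; the function name and the example values are hypothetical.

```python
import numpy as np

def binary_cross_entropy(y, y_hat, eps=1e-7):
    """Mean binary cross-entropy between {0, 1} targets and probabilities."""
    y_hat = np.clip(y_hat, eps, 1.0 - eps)  # avoid log(0)
    per_class = -(y * np.log(y_hat) + (1.0 - y) * np.log(1.0 - y_hat))
    return float(np.mean(per_class))

y = np.array([1.0, 0.0, 1.0])     # reference labels for three classes
good = np.array([0.9, 0.1, 0.8])  # predictions close to the reference
bad = np.array([0.2, 0.7, 0.3])   # predictions far from the reference

# predictions closer to the reference give a lower loss
print(binary_cross_entropy(y, good) < binary_cross_entropy(y, bad))  # True
```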


In [None]:
from math import ceil
from typing import Dict, Optional, Tuple, Union
import numpy as np
from pyannote.core import Segment, SlidingWindow
from pyannote.audio.utils.random import create_rng_for_worker
from pyannote.audio.core.task import Problem, Resolution, Specifications, Task
from pyannote.database import Protocol
from torchmetrics.classification import MultilabelAUROC

# Your custom task must be a subclass of `pyannote.audio.core.task.Task`
class SoundEventDetection(Task):
    """Sound event detection"""

    def __init__(
        self,
        protocol: Protocol,
        duration: float = 5.0,
        min_duration: float = 5.0,
        warm_up: Union[float, Tuple[float, float]] = 0.0,
        batch_size: int = 32,
        num_workers: Optional[int] = None,
        pin_memory: bool = False,
        augmentation = None,
        cache: Optional[str] = None,
        **other_params,
    ):

        super().__init__(
            protocol,
            duration=duration,
            min_duration=min_duration,
            warm_up=warm_up,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            augmentation=augmentation,
            cache=cache,
        )
        
    def prepare_data(self):
        # this method is called to prepare data from the specified protocol. 
        # For most tasks, calling Task.prepare_data() is sufficient. If you 
        # need to prepare task-specific data, define a post_prepare_data method for your task.
        super().prepare_data()

    def post_prepare_data(self, prepared_data: Dict):
        # this method is called at the end of Task.prepare_data() 
        # to complete data preparation with task-specific data, here 
        # the list of classes and some training metadata

        # load metadata for training subset
        prepared_data["train_metadata"] = list()
        for training_file in self.protocol.train():
            prepared_data["train_metadata"].append({
                # path to audio file (str)
                "audio": training_file["audio"],
                # duration of audio file (float)
                "duration": training_file["torchaudio.info"].num_frames / training_file["torchaudio.info"].sample_rate,
                # reference annotation (pyannote.core.Annotation)
                "annotation": training_file["annotation"],
            })

        # gather the list of classes
        classes = set()
        for training_file in prepared_data["train_metadata"]:
            classes.update(training_file["annotation"].labels())
        prepared_data["classes"] = sorted(classes)

        # `has_validation` is True if protocol defines a development set
        if not self.has_validation:
            return
    
    def prepare_validation(self, prepared_data : Dict):
        # this method is called at the end of Task.prepare_data(), to complete data preparation
        # with task validation elements
        
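        # load metadata for validation subset
        prepared_data["validation"] = list()
        for validation_file in self.protocol.development():
            prepared_data["validation"].append({
                "audio": validation_file["audio"],
                # number of non-overlapping validation chunks in this file
                # (one validation sample per chunk of `self.duration` seconds)
                "num_samples": int(
                    validation_file["torchaudio.info"].num_frames
                    / validation_file["torchaudio.info"].sample_rate
                    // self.duration
                ),
                "annotation": validation_file["annotation"],
            })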
     
        
    def setup(self, stage: Optional[str] = None):
        # this method assigns prepared data from task.prepare_data() to the task
        # and declares the task specifications

        super().setup(stage)
        
        # specify the addressed problem
        self.specifications = Specifications(
            # it is a multi-label classification problem
            problem=Problem.MULTI_LABEL_CLASSIFICATION,
            # we expect the model to output one prediction 
            # for the whole chunk
            resolution=Resolution.CHUNK,
            # the model will ingest chunks with that duration (in seconds)
            duration=self.duration,
            # human-readable names of classes
            classes=self.prepared_data["classes"])
    
    def default_metric(self):
        # this method defines the default metric used to evaluate the model
        # during training
        num_classes = len(self.specifications.classes)
        return MultilabelAUROC(num_classes, average="macro", compute_on_cpu=True)

    def train__iter__(self):
        # this method generates training samples, one at a time, "ad infinitum". each worker 
        # of the dataloader will run it, independently from other workers. pyannote.audio and
        # pytorch-lightning will take care of making batches out of it.

        # create worker-specific random number generator (RNG) to avoid this common bug:
        # tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
        rng = create_rng_for_worker(self.model)

        # load list and number of classes
        classes = self.specifications.classes
        num_classes = len(classes)

        # yield training samples "ad infinitum"
        while True:

            # select training file at random
            random_training_file, *_ = rng.choices(self.prepared_data["train_metadata"], k=1)

            # select one chunk at random 
            random_start_time = rng.uniform(0, random_training_file["duration"] - self.duration)
            random_chunk = Segment(random_start_time, random_start_time + self.duration)

            # load audio excerpt corresponding to random chunk
            X = self.model.audio.crop(random_training_file["audio"], 
                                      random_chunk, 
                                      fixed=self.duration)
            
            # load labels corresponding to random chunk as {0|1} numpy array
            # y[k] = 1 means that kth class is active
            y = np.zeros((num_classes,))
            active_classes = random_training_file["annotation"].crop(random_chunk).labels()
            for active_class in active_classes:
                y[classes.index(active_class)] = 1
        
            # yield training samples as a dict (use 'X' for input and 'y' for target)
            yield {'X': X, 'y': y}

    def train__len__(self):
        # since train__iter__ runs "ad infinitum", we need a way to define what an epoch is.
        # this is the purpose of this method. it outputs the number of training samples that
        # make an epoch.

        # we compute this number as the total duration of the training set divided by 
        # duration of training chunks. we make sure that an epoch is at least one batch long,
        # or pytorch-lightning will complain
        train_duration = sum(training_file["duration"] for training_file in self.prepared_data["train_metadata"])
        return max(self.batch_size, ceil(train_duration / self.duration))

    def val__getitem__(self, sample_idx):

        # load list and number of classes
        classes = self.specifications.classes
        num_classes = len(classes)


        # find which part of the validation set corresponds to sample_idx
        num_samples = np.cumsum([
            validation_file["num_samples"] for validation_file in self.prepared_data["validation"]])
        # first file whose cumulative sample count exceeds sample_idx
        file_idx = np.where(num_samples > sample_idx)[0][0]
        validation_file = self.prepared_data["validation"][file_idx]
        idx = sample_idx - (num_samples[file_idx] - validation_file["num_samples"]) 
        chunk = SlidingWindow(start=0., duration=self.duration, step=self.duration)[idx]

        # load audio excerpt corresponding to current chunk
        X = self.model.audio.crop(validation_file["audio"], chunk, fixed=self.duration)

        # load labels corresponding to current chunk as {0|1} numpy array
        # y[k] = 1 means that kth class is active
        y = np.zeros((num_classes,))
        active_classes = validation_file["annotation"].crop(chunk).labels()
        for active_class in active_classes:
            y[classes.index(active_class)] = 1

        return {'X': X, 'y': y}

    def val__len__(self):
        return sum(validation_file["num_samples"] 
                   for validation_file in self.prepared_data["validation"])

    # The `pyannote.audio.core.task.Task` base class provides `LightningModule.training_step` and
    # `LightningModule.validation_step` methods that rely on self.specifications to guess which
    # loss and metrics should be used. You can choose to customize them.
    # More details can be found in the pytorch-lightning documentation and in the
    # pyannote.audio.core.task.Task source code.

    # def training_step(self, batch, batch_idx: int):
    #    return loss

    # def validation_step(self, batch, batch_idx: int):
    #    return metric

    # pyannote.audio.tasks.segmentation.mixin also provides a convenient mixin
    # for "segmentation" tasks (i.e. with Resolution.FRAME) that already defines
    # a bunch of useful methods. You can use it by making your task inherit from
    # pyannote.audio.tasks.segmentation.mixin.SegmentationTask
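
The index arithmetic behind `train__len__` and `val__getitem__` can be checked in isolation. The sketch below assumes that each validation file contributes a whole number of non-overlapping chunks; the helper names `epoch_length` and `locate_chunk` and all numbers are hypothetical and not part of `pyannote.audio`.

```python
from math import ceil
import numpy as np

def epoch_length(file_durations, chunk_duration, batch_size):
    """Number of training chunks per epoch, never less than one batch."""
    return max(batch_size, ceil(sum(file_durations) / chunk_duration))

def locate_chunk(chunks_per_file, sample_idx):
    """Map a global validation index to (file index, chunk index in that file)."""
    cumulative = np.cumsum(chunks_per_file)
    # first file whose cumulative chunk count exceeds sample_idx
    file_idx = int(np.searchsorted(cumulative, sample_idx, side="right"))
    first_idx_in_file = cumulative[file_idx] - chunks_per_file[file_idx]
    return file_idx, int(sample_idx - first_idx_in_file)

# hypothetical numbers: three files of 60 s, 20 s and 100 s, cut into 5 s chunks
print(epoch_length([60.0, 20.0, 100.0], 5.0, batch_size=32))  # 36
print(locate_chunk([12, 4, 20], 0))   # (0, 0)
print(locate_chunk([12, 4, 20], 12))  # (1, 0)
print(locate_chunk([12, 4, 20], 35))  # (2, 19)
```

The chunk index returned by `locate_chunk` is what one would pass to a `SlidingWindow` whose `duration` and `step` both equal the chunk duration.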