# Training (or fine-tuning) a model

The objective of this tutorial is to learn how to train a model from scratch, or fine-tune a pretrained model.  
**Warning:** follow the [data preparation]() tutorial first and make sure the `PYANNOTE_DATABASE_CONFIG` environment variable is set accordingly.

In [None]:
import os
os.environ['PYANNOTE_DATABASE_CONFIG'] = '/people/bredin/dev/pyannote/pyannote-db/AMI-diarization-setup/pyannote/database.yml'
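
Everything below depends on this variable, so it can help to fail early when it is unset or points to a file that does not exist. The following is a minimal sketch using only the standard library; `check_database_config` is a hypothetical helper written for this tutorial, not part of `pyannote.database`.

```python
import os
from pathlib import Path


def check_database_config(var="PYANNOTE_DATABASE_CONFIG"):
    """Return the configured path, or raise if it is unset or missing.

    Hypothetical helper for this tutorial; not part of pyannote.
    """
    value = os.environ.get(var)
    if value is None:
        raise RuntimeError(f"{var} is not set")
    path = Path(value).expanduser()
    if not path.is_file():
        raise FileNotFoundError(f"{var} points to a missing file: {path}")
    return path
```

Calling `check_database_config()` right after setting the variable would surface a wrong path before any protocol is loaded.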

We start by defining which `task` the `model` will address.  
Here, we want the `model` to address voice activity detection using the AMI dataset.

In [None]:
from pyannote.database import get_protocol
ami = get_protocol('AMI.SpeakerDiarization.only_words')

from pyannote.audio.tasks import VoiceActivityDetection
vad = VoiceActivityDetection(ami)

For the purpose of this tutorial, we define a `compute_model_fscore` function that runs a model on the first file of the AMI test set and prints the voice activity detection F-score. It also returns a preview of the model output on the second minute of this file.

In [None]:
from pyannote.audio.core.inference import Inference
from pyannote.audio.utils.signal import Binarize
from pyannote.metrics.detection import DetectionPrecisionRecallFMeasure
from pyannote.core import Segment
from pyannote.audio.utils.preview import preview
import torch

def compute_model_fscore(model):

    file = next(ami.test())    
    
    # use model to extract speech probability
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    inference = Inference(model, progress_hook='Processing...', device=device)
    speech_prob = inference(file)
    
    # take hard decision using a 0.5 threshold
    binarize = Binarize()
    speech = binarize(speech_prob)
    
    # compute detection f-score
    metric = DetectionPrecisionRecallFMeasure()
    fscore = metric(
        file['annotation'],     # this is the reference annotation
        speech,                 # this is the hypothesized annotation
        uem=file['annotated'])  # this is the part of the file that should be evaluated
    
    print(f'F-score = {100 * fscore:.1f}%')
    
    # preview results
    # (comment this out if you don't care about visualization, as it takes a relatively long time to generate)
    second_minute = Segment(60, 120)
    return preview(file, 
                   segment=second_minute, 
                   video_fps=5., 
                   reference=file['annotation'], 
                   probability=speech_prob, 
                   speech=speech.get_timeline())
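
To make the metric concrete, here is a simplified, frame-level sketch of the two steps performed inside `compute_model_fscore`: thresholding speech probabilities at 0.5, then computing a detection F-score. It is an illustration only. `binarize_frames` and `detection_fscore` are hypothetical helpers, the numbers are made up, and the actual `Binarize` and `DetectionPrecisionRecallFMeasure` classes work on pyannote objects rather than plain lists.

```python
def binarize_frames(probabilities, threshold=0.5):
    """Mark a frame as speech when its probability exceeds the threshold."""
    return [p > threshold for p in probabilities]


def detection_fscore(reference, hypothesis):
    """F-score of speech detection, computed over frames."""
    true_pos = sum(r and h for r, h in zip(reference, hypothesis))
    retrieved = sum(hypothesis)  # frames detected as speech
    relevant = sum(reference)    # frames that really are speech
    if true_pos == 0:
        return 0.0
    precision = true_pos / retrieved
    recall = true_pos / relevant
    return 2 * precision * recall / (precision + recall)


# toy example with made-up values
reference = [False, True, True, True, False, False]
probabilities = [0.1, 0.8, 0.6, 0.4, 0.7, 0.2]
hypothesis = binarize_frames(probabilities)
print(f"F-score = {100 * detection_fscore(reference, hypothesis):.1f}%")
# prints "F-score = 66.7%"
```

In this toy example, one speech frame is missed and one non-speech frame is falsely detected, so precision and recall are both 2/3.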

## Using a pretrained model

To serve as our baseline, we load a voice activity detection model pretrained on the DIHARD III dataset.

In [None]:
from pyannote.audio import Model
pretrained = Model.from_pretrained('hbredin/VoiceActivityDetection-PyanNet-DIHARD')

This `pretrained` model relies on the `PyanNet` architecture available in `pyannote.audio`, which combines (trainable) SincNet feature extraction, a few LSTM layers, a few linear layers, and a final classification layer.

In [None]:
_ = pretrained.summarize()

In [None]:
compute_model_fscore(pretrained)

## Training a model from scratch

We will now train a voice activity detection model from scratch, using the AMI training set.

To make sure we use the exact same architecture, we rely on `pretrained.hparams`, which conveniently keeps track of the hyper-parameters used to instantiate the architecture of the `pretrained` model.

In [None]:
pretrained.hparams

In [None]:
from pyannote.audio.models.segmentation import PyanNet
from_scratch = PyanNet(task=vad, **pretrained.hparams)

👀  Notice how we passed `vad` as the `task` argument of our `from_scratch` model.  
This allows `pyannote.audio` to automagically register the right `classifier` and `activation` layers into the `PyanNet` model.

> Look ma, no hands!

This magic trick is possible because every task in `pyannote.audio` exposes its specifications.

In [None]:
vad.specifications

These specifications tell us that voice activity detection is a *binary classification* problem, and that the model is trained on *2s* audio chunks.
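
As a rough mental model of this trick, a specification can be thought of as a small record from which the final layers are derived. The sketch below is purely illustrative: `ToySpecifications` and `output_layers_for` are hypothetical names, and the mapping from problem type to output size and activation is an assumption, not a description of the actual `pyannote.audio` internals.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class ToySpecifications:
    """Hypothetical stand-in for a task specification."""
    problem: str                             # e.g. "binary_classification"
    duration: float                          # chunk duration, in seconds
    classes: Optional[Sequence[str]] = None  # only needed for multi-class


def output_layers_for(spec):
    """Return (number of outputs, activation name) for a toy specification."""
    if spec.problem == "binary_classification":
        return 1, "sigmoid"
    if spec.problem == "multi_class_classification":
        return len(spec.classes), "log_softmax"
    raise ValueError(f"unsupported problem: {spec.problem}")


vad_spec = ToySpecifications(problem="binary_classification", duration=2.0)
print(output_layers_for(vad_spec))  # (1, 'sigmoid')
```

The point of the design is that the architecture never hard-codes its output layer: swapping the task swaps the specification, and the final layers follow.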

In [None]:
import pytorch_lightning as pl
trainer = pl.Trainer(gpus=1, max_epochs=2)
trainer.fit(from_scratch)

In [None]:
compute_model_fscore(from_scratch)

## Fine-tuning a pretrained model

Can we do better (or at least faster) by fine-tuning the pretrained DIHARD model?

In [None]:
fine_tuned = Model.from_pretrained('hbredin/VoiceActivityDetection-PyanNet-DIHARD')
fine_tuned.task = vad

In [None]:
# this callback freezes all layers, except:
# * 'classifier', which is always trained;
# * 'linear', which the callback is asked to unfreeze from epoch 0 onwards
from pyannote.audio.core.callback import GraduallyUnfreeze
callback = GraduallyUnfreeze({"linear": 0})

trainer = pl.Trainer(gpus=1, max_epochs=2, 
                     callbacks=[callback])
trainer.fit(fine_tuned)
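
The freezing schedule described in the comments above can be pictured with a few lines of plain Python. This is a sketch of the idea only, not the `GraduallyUnfreeze` implementation: `trainable_layers` is a hypothetical helper, the layer names are assumed from the architecture description earlier in this tutorial, and the second schedule is a made-up example.

```python
def trainable_layers(layers, schedule, epoch, always=("classifier",)):
    """List the layers that are trainable at a given epoch.

    `schedule` maps a layer name to the epoch at which it gets unfrozen.
    Layers in `always` are never frozen; every other layer stays frozen.
    """
    return [
        name
        for name in layers
        if name in always or (name in schedule and epoch >= schedule[name])
    ]


layers = ["sincnet", "lstm", "linear", "classifier"]  # assumed layer names

# schedule used in this tutorial: 'linear' is unfrozen right away
print(trainable_layers(layers, {"linear": 0}, epoch=0))
# ['linear', 'classifier']

# hypothetical schedule that also unfreezes 'lstm' one epoch later
print(trainable_layers(layers, {"linear": 0, "lstm": 1}, epoch=1))
# ['lstm', 'linear', 'classifier']
```

With the schedule used here, only the last layers are updated, which keeps most of the pretrained weights intact while adapting the model to AMI.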

In [None]:
compute_model_fscore(fine_tuned)