# Fine-tuning a pretrained speaker segmentation model

- [`Demo`](https://github.com/pyannote/pyannote-audio/blob/develop/tutorials/intro.ipynb)
- [`What is Speaker diarization?`](https://en.wikipedia.org/wiki/Speaker_diarisation)
- [`Docs`](https://docs.google.com/document/d/1I_fp6gQHGgkVju5kcXX4La-B-Ksu95oTrlTwB-WPW1I/edit?usp=sharing)

# Installation 

Libraries used:

  *   [pyannote](https://github.com/pyannote/pyannote-audio) (requires restarting the kernel after installation finishes)

In [None]:
# install pyannote.audio (currently 2.0)
!pip install -q pyannote.audio
# !pip install -q https://github.com/pyannote/pyannote-audio/archive/develop.zip 

In [None]:
# install pyannote.audio 1.0 (a commit from before the major 2.0 changes)
# !pip install https://github.com/pyannote/pyannote-audio/archive/871272c7c9c2ffacc4570cea9e487658ed145e70.zip

In [None]:
# # install pyannote.audio 2.0 on Colab, pinning its dependencies manually first
# # for speechbrain
# !pip install -qq torch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 torchtext==0.12.0
# !pip install -qq speechbrain==0.5.12

# # pyannote.audio
# !pip install -qq pyannote.audio

# # for visualization purposes
# !pip install -qq moviepy ipython==7.34.0
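After restarting the kernel, one way to confirm which versions actually got installed is the standard-library `importlib.metadata` module (Python 3.8+). The helper name `installed_version` is hypothetical; it only reads package metadata and does not import pyannote.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(distribution):
    """Return the installed version string of a distribution, or None if it is missing."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        return None


# distribution names as passed to pip in the cells above
for name in ("pyannote.audio", "torch", "speechbrain"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```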

# Resources directory structure

```
Finetuning-db-setup
├───database.yml
│
├───wav
│   ├───ami => EN2001a.wav, EN2001b.wav, ...
│   ├───cv02 => cv02_cct_00.wav, cv02_ovl_00.wav
│   └───...
│
├───rttm
│   ├───ami => EN2001a.rttm, EN2001b.rttm, ...
│   ├───amiF07 => EN2001a.rttm, EN2001b.rttm, ...
│   ├───cv02 => cv02_cct_00.rttm, cv02_ovl_00.rttm, ...
│   └───...
│
├───uem
│   ├───amiF07 => EN2001a.uem, EN2001b.uem, ...
│   ├───cv02 => cv02_cct_00.uem, cv02_ovl_00.uem, ...
│   └───...
│
└───uri
    ├───amiF07
    │   ├───train.amiF07.txt
    │   ├───dev.amiF07.txt
    │   └───test.amiF07.txt
    │
    ├───cv02
    │   ├───train.cv02.txt
    │   ├───dev.cv02.txt
    │   └───test.cv02.txt
    │
    └───...
```
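The tree keeps one RTTM and one UEM file per recording, while the lists under `uri/` name the recordings in each split. If your `database.yml` expects a single annotation file per split, a sketch like the following can concatenate the per-recording files. It assumes each `uri/<subset>/<split>.<subset>.txt` lists one URI per line; the `merged/` output directory and the `merge_split` name are hypothetical. It also reports URIs with no matching file, a cheap consistency check before training.

```python
from pathlib import Path


def merge_split(root, subset, split, kind):
    """Concatenate the per-URI <kind> files ("rttm" or "uem") listed in
    uri/<subset>/<split>.<subset>.txt into merged/<split>.<subset>.<kind>.

    Returns (output_path, missing_uris).
    """
    root = Path(root)
    uri_list = root / "uri" / subset / f"{split}.{subset}.txt"
    uris = [line.strip() for line in uri_list.read_text().splitlines() if line.strip()]

    chunks, missing = [], []
    for uri in uris:
        source = root / kind / subset / f"{uri}.{kind}"
        if not source.exists():
            missing.append(uri)
            continue
        text = source.read_text()
        # ensure each file ends with a newline so lines from two files never run together
        if text and not text.endswith("\n"):
            text += "\n"
        chunks.append(text)

    output = root / "merged" / f"{split}.{subset}.{kind}"
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text("".join(chunks))
    return output, missing
```

For example, `merge_split("Finetuning-db-setup", "cv02", "train", "rttm")` would write `merged/train.cv02.rttm` and return any URIs from `train.cv02.txt` that have no RTTM file. RTTM and UEM lines already carry the URI in their own fields, so plain concatenation keeps them unambiguous.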

# Set up protocol

- [`What is protocol?`](https://github.com/pyannote/pyannote-database/blob/develop/pyannote/database/protocol/protocol.py#L233)
- [`Format of Database Configuration File (database.yml)`](https://github.com/pyannote/pyannote-database#speaker-diarization)
- [`Meta-protocols`](https://github.com/pyannote/pyannote-database#meta-protocols)
- [`TLDR: example`](https://github.com/pyannote/pyannote-database/blob/develop/tests/data/database.yml)
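Following the format described in the linked pyannote.database README, a meta-protocol such as `X.SpeakerDiarization.cv02_amiF07` combines existing protocols split by split. The sketch below only builds the YAML text. The database names (`AMI`, `CV02`), protocol names, and the `merged/<split>.<subset>.rttm|uem` paths are hypothetical placeholders, not the author's actual `database_kg.yml`; it assumes one protocol per database and one RTTM and one UEM file per split. Check the README for the exact keys your pyannote.database version accepts.

```python
def build_database_yml(root, corpora, meta_name):
    """Return database.yml text with one protocol per corpus plus an X meta-protocol.

    corpora: list of (database, protocol, wav_dir, subset) tuples, e.g.
    ("AMI", "amiF07", "ami", "amiF07"); each database name must be unique.
    All names and paths here are placeholders.
    """
    # pyannote split name -> prefix used by the uri/<subset>/*.txt files
    splits = {"train": "train", "development": "dev", "test": "test"}

    lines = ["Databases:"]
    for database, _, wav_dir, _ in corpora:
        # {uri} is a template filled in per file when audio is looked up
        lines.append(f"  {database}: {root}/wav/{wav_dir}/{{uri}}.wav")

    lines += ["", "Protocols:"]
    for database, protocol, _, subset in corpora:
        lines += [f"  {database}:", "    SpeakerDiarization:", f"      {protocol}:"]
        for split, prefix in splits.items():
            lines += [
                f"        {split}:",
                f"          uri: {root}/uri/{subset}/{prefix}.{subset}.txt",
                f"          annotation: {root}/merged/{prefix}.{subset}.rttm",
                f"          annotated: {root}/merged/{prefix}.{subset}.uem",
            ]

    # meta-protocol: each split pulls the same split from every corpus
    lines += ["  X:", "    SpeakerDiarization:", f"      {meta_name}:"]
    for split in splits:
        lines.append(f"        {split}:")
        for database, protocol, _, _ in corpora:
            lines.append(f"          {database}.SpeakerDiarization.{protocol}: [{split}]")
    return "\n".join(lines) + "\n"
```

Writing the result with `Path("database.yml").write_text(build_database_yml(...))` and pointing `PYANNOTE_DATABASE_CONFIG` at that file is one way to keep the YAML consistent with the directory tree.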

In [1]:
# tell pyannote.database where to find partition, reference, and wav files
import os
os.environ["PYANNOTE_DATABASE_CONFIG"] = '../input/finetuningdbsetup/database_kg.yml'

In [2]:
# get protocol from PYANNOTE_DATABASE_CONFIG
from pyannote.database import get_protocol
dataset = get_protocol('X.SpeakerDiarization.cv02_amiF07')
# dataset = get_protocol('Base.SpeakerDiarization.Mini') # for test
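Before training, it can help to check what the protocol actually yields. pyannote.database protocols expose `train()`, `development()` and `test()` iterators of files whose `"annotation"` is a `pyannote.core.Annotation` and whose `"annotated"` is a `pyannote.core.Timeline`. The `summarize` helper below is a hypothetical sketch that relies only on `Annotation.labels()` and `Timeline.duration()`.

```python
def summarize(files):
    """Count files, distinct speaker labels and annotated hours in one subset.

    `files` is any iterable of dict-like protocol files with an "annotation"
    (exposing .labels()) and an "annotated" timeline (exposing .duration()).
    """
    n_files, labels, seconds = 0, set(), 0.0
    for file in files:
        n_files += 1
        labels.update(file["annotation"].labels())
        seconds += file["annotated"].duration()
    return {"files": n_files, "speakers": len(labels), "annotated_hours": seconds / 3600}
```

A possible usage is `for name in ("train", "development", "test"): print(name, summarize(getattr(dataset, name)()))`. Speaker labels are compared as plain strings, so the same label used in two recordings is counted once.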

# Train
- [`Parameters of Segmentation`](https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/tasks/segmentation/segmentation.py#L47)
- [`PL Custom Learning Rate Schedulers`](https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html#bring-your-own-custom-learning-rate-schedulers)
- [`Fix issue`](https://github.com/Lightning-AI/lightning/issues/4929)


In [3]:
## Configure the segmentation task (protocol and hyperparameters) for the pretrained model
from pyannote.audio import Model
from pyannote.audio.tasks import Segmentation
from copy import deepcopy

pretrained = Model.from_pretrained("pyannote/segmentation")
seg_task = Segmentation(
    protocol=dataset, 
    duration=5.0,  # chunk duration in seconds (pretrained: 2)
    max_num_speakers=4,
#     batch_size=128,
)
# copy the pretrained model so its weights stay untouched, then attach the new task
finetuned = deepcopy(pretrained)
finetuned.task = seg_task

In [None]:
# define learning rate scheduler: linear warmup from 0, then linear decay to 0
from torch.optim.lr_scheduler import LambdaLR
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )
    return LambdaLR(optimizer, lr_lambda, last_epoch, verbose=True)
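To see what this schedule does, the same multiplier can be evaluated without PyTorch. This standalone copy of `lr_lambda` (renamed `linear_warmup_decay` here for illustration) uses the 5% warmup from the next cell and an assumed total of 1000 steps; the learning rate is the multiplier times the base `lr` of 1e-4.

```python
def linear_warmup_decay(step, num_warmup_steps, num_training_steps):
    """Same multiplier as lr_lambda above: ramp 0 -> 1, then decay 1 -> 0."""
    if step < num_warmup_steps:
        return float(step) / float(max(1, num_warmup_steps))
    return max(
        0.0,
        float(num_training_steps - step) / float(max(1, num_training_steps - num_warmup_steps)),
    )


base_lr = 1e-4
total_steps = 1000                 # assumed value for illustration
warmup_steps = 0.05 * total_steps  # 5% warmup, as in the optimizer cell

for step in (0, 25, 50, 525, 1000):
    multiplier = linear_warmup_decay(step, warmup_steps, total_steps)
    print(f"step {step:4d}: multiplier {multiplier:.2f}, lr {base_lr * multiplier:.1e}")
```

The multiplier reaches 1.0 at the end of warmup (step 50) and is back to 0.5 halfway through the decay (step 525). If the run had 6780 steps in total, as the checkpoint name in the training cell suggests, warmup would last about 339 steps.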

In [None]:
# configure the optimizer and learning rate scheduler
finetuned.hparams["lr"] = 1e-4
last_epoch = -1

def new_optimizers(self):
    import torch
#     optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
    # "initial_lr" is required by LambdaLR when last_epoch != -1 (otherwise it raises a KeyError)
    optimizer = torch.optim.Adam([{
        "params": self.parameters(), 
        "initial_lr": self.hparams.lr,
        "lr": self.hparams.lr}])
    
    total_steps = self.trainer.estimated_stepping_batches
    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps = 0.05 * total_steps, 
        num_training_steps = total_steps,
        last_epoch = last_epoch,
    )
    return {
        'optimizer': optimizer, 
        'lr_scheduler': {
            'scheduler': scheduler, 
            'interval': 'step' # update lr every step (default is every epoch)
        }
    }

# TODO: find better ways than overriding methods
from types import MethodType
finetuned.configure_optimizers = MethodType(new_optimizers, finetuned)

In [None]:
# callbacks: log the learning rate and save a checkpoint every epoch
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
lr_monitor = LearningRateMonitor(logging_interval='step')
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints/', 
    save_top_k=-1,
#     monitor='Segmentation-XSpeakerDiarizationcv02_amiF07-OptimalDiarizationErrorRate',
#     save_top_k=4,
)

In [None]:
# start training
import pytorch_lightning as pl
trainer = pl.Trainer(gpus=-1, max_epochs=10, callbacks=[lr_monitor, checkpoint_callback])
trainer.fit(finetuned)
# trainer.fit(finetuned, ckpt_path="../input/checkpoints/epoch9-step6780.ckpt")
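With `save_top_k=-1`, `ModelCheckpoint` keeps every epoch's checkpoint, named `epoch=<E>-step=<S>.ckpt` by default in recent PyTorch Lightning versions (the commented-out resume line above uses a variant without `=`). To resume from the most recent one, a sketch like this picks the file with the highest step. The helper name and the regex are assumptions based on those two naming patterns.

```python
import re
from pathlib import Path

# matches both "epoch=9-step=6780.ckpt" and "epoch9-step6780.ckpt"
CKPT_PATTERN = re.compile(r"epoch=?(\d+)-step=?(\d+)\.ckpt$")


def latest_checkpoint(directory):
    """Return the checkpoint with the highest global step in `directory`, or None."""
    best, best_step = None, -1
    for path in Path(directory).glob("*.ckpt"):
        match = CKPT_PATTERN.search(path.name)
        if match and int(match.group(2)) > best_step:
            best, best_step = path, int(match.group(2))
    return best
```

A possible usage is `ckpt = latest_checkpoint("checkpoints/")` followed by `trainer.fit(finetuned, ckpt_path=str(ckpt))` when `ckpt` is not `None`. Files that do not follow either pattern, such as `last.ckpt`, are ignored.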

In [None]:
# archive the checkpoints and TensorBoard logs for download
!zip -r lightning_logs.zip lightning_logs
!zip -r checkpoints.zip checkpoints
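The `!zip` shell magic only works inside a notebook with `zip` installed. As a hedged alternative, the standard-library `shutil.make_archive` produces an equivalent archive from plain Python; the `zip_directory` wrapper is a hypothetical helper.

```python
import shutil
from pathlib import Path


def zip_directory(directory, archive_stem=None):
    """Zip `directory` into <archive_stem>.zip and return the archive path."""
    directory = Path(directory)
    stem = archive_stem or directory.name
    # root_dir/base_dir keep the top-level folder name inside the archive, like `zip -r`
    return shutil.make_archive(stem, "zip", root_dir=directory.parent, base_dir=directory.name)
```

Calling `zip_directory("lightning_logs")` and `zip_directory("checkpoints")` from the working directory should yield `lightning_logs.zip` and `checkpoints.zip`, matching the shell commands above.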