In [None]:
import os
import sys
from getpass import getpass

sys.path.append('code/')
from pyannote.database import registry
from nb_functions import *
from pyannote.metrics.diarization import DiarizationErrorRate
from pyannote.audio import Model, Pipeline
from tqdm.notebook import tqdm_notebook as tqdm
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hugging Face access token, passed as `use_auth_token` when downloading the pyannote models.
# Never hard-code it in the notebook: it would end up in version control and in every shared copy.
# It is read from the HF_TOKEN environment variable, or prompted for without echoing.
# NOTE(review): the token previously written on this line has been exposed and should be revoked.
HF_TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")

## Création du dataset de finetuning avec les données annotées

`generate_dataset` calcule les segments de faible confiance et crée un nouveau jeu de données pour le fine-tuning :
$$ generate\_dataset(x\_train\_file\_path, dataset\_path, filename\_path, all\_uem\_filename , pipeline, mode, method,threshold=0.5, window\_size=5, annotated\_ratio=0.15)$$
- x_train_file_path : le chemin vers le fichier .txt recensant les noms des fichiers composant le jeu d'entraînement
- dataset_path : le chemin vers le dossier racine du dataset
- filename_path : le chemin et le nom du fichier .txt qui sera créé pour contenir les noms des fichiers composant le jeu de fine-tuning
- all_uem_filename : le nom du fichier qui contient l'ensemble des timelines et des scores de confiance
- mode : 'dataset' ou 'sample', pour spécifier si l'on veut X% du dataset ou X% des samples (X correspondant à `annotated_ratio`)
- method : 'random' ou 'lowest', pour sélectionner des segments aléatoirement ('random') ou les segments de plus faible confiance ('lowest')
- pipeline : le pipeline de segmentation, par exemple `segmentation.SoftSpeakerSegmentation(segmentation=model_seg, use_auth_token=HF_TOKEN)`
- threshold : le seuil de confiance minimal
- window_size : la taille de la fenêtre glissante
- annotated_ratio : la proportion de données à annoter pour le fine-tuning (entre 0 et 1, par exemple 0.15 pour 15 %)

In [None]:
# Build the option widgets (database, dataset generation, evaluation, validation) and show them.
# Pick the options before running the next cell: it reads the widgets' current `.value`.
choice_widgets = display_choices()
database_wildget, widget_generate_new_ds, eval_widget, widget_validate = choice_widgets
display(*choice_widgets)

In [None]:
# Read the choices made in the widgets of the previous cell.
evaluate = eval_widget.value == "Yes"
database = database_wildget.value

# pyannote protocol name and database.yml location for each supported widget value.
DATABASE_CONFIGS = {
    "AMI": ("AMI.SpeakerDiarization.mini", "datasets-pyannote/ami/pyannote/database.yml"),
    "Msdwild": ("MSDWILD.SpeakerDiarization.CustomFew", "datasets-pyannote/msdwild/database.yml"),
}
# Fail loudly instead of silently reusing a protocol left over from a previous run.
if database not in DATABASE_CONFIGS:
    raise ValueError(
        f"Unsupported database {database!r}; expected one of {sorted(DATABASE_CONFIGS)}"
    )
protocol, yaml_path = DATABASE_CONFIGS[database]

registry.load_database(yaml_path)
dataset = registry.get_protocol(protocol)

print("Checking that the 'annotation' key is present in all train files...")
for file in dataset.train():
    assert "annotation" in file, f"train file {file.get('uri', '?')} has no 'annotation'"
print("Checking that the 'annotation' key is present in all test files...")
for file in dataset.test():
    assert "annotation" in file, f"test file {file.get('uri', '?')} has no 'annotation'"

## Évaluation des prédictions avant fine-tuning

In [None]:
pretrained_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN)
# from_pretrained may return None instead of raising when the download fails
# (e.g. invalid token or model conditions not accepted); fail here with a clear message.
if pretrained_pipeline is None:
    raise RuntimeError(
        "Could not load 'pyannote/speaker-diarization-3.1'; check HF_TOKEN and model access."
    )
pretrained_pipeline.to(device)  # `device` is already a torch.device
torch.cuda.empty_cache()
metric_pretrained = DiarizationErrorRate()

if evaluate:
    for file in tqdm(dataset.test(), desc="Evaluating the pretrained pipeline"):
        # Each corpus stores its audio in a different folder with a different file-name suffix.
        if file["database"] == "AMI":
            path_to_wav = "datasets-pyannote/ami/wav/"
            suffixe = ".Mix-Headset"
        elif file["database"] == "MSDWILD":
            path_to_wav = "datasets-pyannote/msdwild/wav/"
            suffixe = ""
        else:
            # Without this, a stale path from a previous iteration/run would be reused.
            raise ValueError(f"No wav location configured for database {file['database']!r}")
        file["pretrained pipeline"] = pretrained_pipeline(path_to_wav + file["uri"] + suffixe + ".wav")
        # Accumulate DER over the test set, restricted to the annotated region (UEM).
        metric_pretrained(file["annotation"], file["pretrained pipeline"], uem=file["annotated"], detailed=True)
    print(f"\nThe pretrained pipeline reaches a Diarization Error Rate (DER) of {100 * abs(metric_pretrained):.1f}% on test set.")

## Fine-tuning de segmentation-3.0

Code récupéré sur https://github.com/pyannote/pyannote-audio/blob/develop/tutorials/adapting_pretrained_pipeline.ipynb

In [None]:
from types import MethodType

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    RichProgressBar,
)
from torch.optim import Adam, SGD

from pyannote.audio.tasks import Segmentation

# Fine-tuning hyper-parameters (same values as the original run).
SEED = 42
LEARNING_RATE = 5e-4
MAX_EPOCHS = 25
EARLY_STOPPING_PATIENCE = 10
GRADIENT_CLIP_VAL = 0.5

# Seed python/numpy/torch (and dataloader workers) so runs are reproducible.
seed_everything(SEED, workers=True)

model_seg = "pyannote/segmentation-3.0"
model = Model.from_pretrained(model_seg, use_auth_token=HF_TOKEN)
# from_pretrained may return None instead of raising when the download fails.
if model is None:
    raise RuntimeError(f"Could not load {model_seg!r}; check HF_TOKEN and model access.")
model.task = Segmentation(dataset, duration=5.0, max_speakers_per_chunk=3, max_speakers_per_frame=2)
model.setup("fit")


def configure_optimizers(self):
    """Optimizer override bound onto `model`: plain Adam at LEARNING_RATE."""
    return Adam(self.parameters(), lr=LEARNING_RATE)


model.configure_optimizers = MethodType(configure_optimizers, model)

# Keep only the best checkpoint according to the task's own validation metric.
monitor, direction = model.task.val_monitor
checkpoint = ModelCheckpoint(
    monitor=monitor,
    mode=direction,
    save_top_k=1,
    every_n_epochs=1,
    save_last=False,
    save_weights_only=False,
    filename="{epoch}",
    verbose=False,
)
early_stopping = EarlyStopping(
    monitor=monitor,
    mode=direction,
    min_delta=0.0,
    patience=EARLY_STOPPING_PATIENCE,
    strict=True,
    verbose=False,
)

callbacks = [RichProgressBar(), checkpoint, early_stopping]

trainer = Trainer(
    # Follow the device chosen in the config cell instead of requiring a GPU.
    accelerator="gpu" if device.type == "cuda" else "cpu",
    callbacks=callbacks,
    max_epochs=MAX_EPOCHS,
    gradient_clip_val=GRADIENT_CLIP_VAL,
)

trainer.fit(model)

In [None]:
from pyannote.audio.pipelines import SpeakerDiarization

torch.cuda.empty_cache()

# Best checkpoint kept by the ModelCheckpoint callback of the fine-tuning cell.
finetuned_model = checkpoint.best_model_path
if not finetuned_model:
    raise RuntimeError("No fine-tuned checkpoint was saved; run the fine-tuning cell first.")

# Same embedding model and clustering as the pretrained pipeline; only the segmentation changes.
finetuned_pipeline = SpeakerDiarization(
    segmentation=finetuned_model,
    embedding=pretrained_pipeline.embedding,
    embedding_exclude_overlap=pretrained_pipeline.embedding_exclude_overlap,
    clustering=pretrained_pipeline.klustering,
)

finetuned_pipeline.to(device)

# Dataset-specific clustering threshold (origin of these values not documented here — TODO).
# The widget yields "AMI" (upper case), so compare case-insensitively; the previous
# `database == "ami"` never matched and always picked the MSDWild threshold.
clustering_threshold = 0.6285824248662424 if database.lower() == "ami" else 0.8285487153337224

finetuned_pipeline.instantiate({
    "segmentation": {
        "min_duration_off": 0.0,
    },
    "clustering": {
        "method": "centroid",
        "min_cluster_size": 15,
        "threshold": clustering_threshold,
    },
})
metric_finetuned = DiarizationErrorRate()

for file in tqdm(dataset.test(), desc="Evaluating the finetuned pipeline"):
    # Each corpus stores its audio in a different folder with a different file-name suffix.
    if file["database"] == "AMI":
        path_to_wav = "datasets-pyannote/ami/wav/"
        suffixe = ".Mix-Headset"
    elif file["database"] == "MSDWILD":
        path_to_wav = "datasets-pyannote/msdwild/wav/"
        suffixe = ""
    else:
        # Without this, a stale path from a previous iteration/run would be reused.
        raise ValueError(f"No wav location configured for database {file['database']!r}")
    file["finetuned pipeline"] = finetuned_pipeline(path_to_wav + file["uri"] + suffixe + ".wav")
    # Accumulate DER over the test set, restricted to the annotated region (UEM).
    metric_finetuned(file["annotation"], file["finetuned pipeline"], uem=file["annotated"], detailed=True)
print(f"The finetuned pipeline reaches a Diarization Error Rate (DER) of {100 * abs(metric_finetuned):.1f}% on {database} test set.")

In [None]:
# Total duration of the manually annotated data, summed over every UEM file.
import os
from pyannote.database.util import load_rttm, load_uem

# NOTE(review): this folder is AMI-specific regardless of the database selected above.
uem_folder = 'datasets-pyannote/ami/manual_uems'
duration = 0
uem_files = sorted(f for f in os.listdir(uem_folder) if f.endswith('.uem'))
for uem_file in uem_files:
    # load_uem returns a {uri: Timeline} mapping; count every URI, not just one of them.
    annotated_by_uri = load_uem(os.path.join(uem_folder, uem_file))
    duration += sum(timeline.duration() for timeline in annotated_by_uri.values())
print(f"The total duration of the manually annotated data is {duration/60:.1f} minutes.")
    

# Test du modèle fine-tuné sur un autre dataset

In [None]:
# NOTE(review): disabled cell. When re-enabled, it reloads `dataset` for another database so the
# fine-tuned pipeline can be evaluated on a dataset different from the one used for fine-tuning.
# NOTE(review): it uses lowercase names ("ami"/"msdwild") while the widget above yields
# "AMI"/"Msdwild"; make them consistent before re-enabling.
# from pyannote.database import registry
# # Don't forget to change the train file in database.yml so it points to the right fine-tuning file
# torch.cuda.empty_cache()
# database = "msdwild"

# if database == "ami":
#     protocol = "AMI.SpeakerDiarization.mini"
#     yaml_path = "datasets-pyannote/ami/pyannote/database.yml"
# elif database == "msdwild":
#     protocol = "MSDWILD.SpeakerDiarization.CustomFew"
#     yaml_path = "datasets-pyannote/msdwild/database.yml"

# registry.load_database(yaml_path)
# dataset = registry.get_protocol(protocol)
# print("Checking that the 'annotation' key is present in all train files...")
# for file in dataset.train():
#    assert "annotation" in file
# print("Checking that the 'annotation' key is present in all test files...")
# for file in dataset.test():
#   assert "annotation" in file

In [None]:

# NOTE(review): disabled cell that evaluates the fine-tuned pipeline on the reloaded dataset.
# NOTE(review): it resets `metric_pretrained` but accumulates into `metric_finetuned`, which still
# holds the results of the earlier evaluation; the printed DER would mix both test sets.
# Reset `metric_finetuned = DiarizationErrorRate()` instead before re-enabling.
# from pyannote.audio.pipelines.utils.hook import ProgressHook
# metric_pretrained = DiarizationErrorRate()

# for file in tqdm(dataset.test()):
#     if file["database"] == "AMI":
#         path_to_wav = "datasets-pyannote/ami/wav/"
#         suffixe = ".Mix-Headset"
#     elif file["database"] == "MSDWILD":
#         path_to_wav = "datasets-pyannote/msdwild/wav/"
#         suffixe = ""

#     file["finetuned pipeline"]  = finetuned_pipeline(path_to_wav+file["uri"]+suffixe+".wav")
#     metric_finetuned(file["annotation"], file["finetuned pipeline"] ,uem=file["annotated"],detailed=True)
# print(f"The finetuned pipeline  reaches a Diarization Error Rate (DER) of {100 * abs(metric_finetuned):.1f}% on {database} test set.")