# Speaker diarization.
This notebook runs the pyannote speaker-diarization pipeline on the test set and evaluates its output against the provided ground-truth annotations.
    

### All imports 

In [1]:
# Typical imports
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
from dotenv import load_dotenv
import speechbrain as sb
import torchaudio

# Scipy
from scipy.spatial.distance import cdist


# All Pyannote imports
from pyannote.audio import Pipeline
from pyannote.core import Segment, Annotation
from pyannote.audio import Model, Inference
from pyannote.audio import Audio
from pyannote.audio.pipelines.utils.hook import ProgressHook

# Read auths.env into the process environment and fetch the token that is
# passed to Pipeline.from_pretrained(use_auth_token=...) in the next cell.
load_dotenv(dotenv_path="auths.env")
api_key = os.environ.get("API_KEY")

In [2]:
# Paths to the test audio, its ground-truth RTTMs, and where hypotheses are written.
test_audio = "../Dataset/Audio/Test"
test_rttm = "../Dataset/RTTMs/Test"
output_file_path = "../Results/Diarization_pipeline"

# Ensure the output directory exists
os.makedirs(output_file_path, exist_ok=True)

# Load the pretrained pipeline. Pipeline.from_pretrained returns None (after
# printing a hint) when the gated model cannot be downloaded, e.g. because the
# token is missing or invalid -- fail here with a clear message instead of an
# AttributeError on the next call.
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=api_key
)
if pipeline is None:
    raise RuntimeError(
        "Could not load 'pyannote/speaker-diarization-3.1'; check that API_KEY in "
        "auths.env is a valid Hugging Face token with access to this model."
    )

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline.to(device)

# Only .wav files are diarized; sorting gives a reproducible processing order
# and lets the progress bar count just the files that are actually processed.
wav_files = sorted(name for name in os.listdir(test_audio) if name.endswith(".wav"))

# Diarize every test file and save the result as <stem>.rttm in output_file_path.
for audio_file in tqdm(wav_files):
    stem = os.path.splitext(audio_file)[0]
    audio_file_path = os.path.join(test_audio, audio_file)

    waveform, sample_rate = torchaudio.load(audio_file_path)

    # Give the in-memory input an explicit URI: without one pyannote falls back
    # to a generic default ID, so every RTTM line would carry that ID instead of
    # the recording name. RTTM is space-separated and pyannote refuses URIs that
    # contain spaces, hence the substitution.
    uri = stem.replace(" ", "_")
    diarization = pipeline(
        {"waveform": waveform, "sample_rate": sample_rate, "uri": uri}
    )

    file_output_name = os.path.join(output_file_path, f"{stem}.rttm")
    with open(file_output_name, "w", encoding="utf-8") as f:
        diarization.write_rttm(f)

  1%|          | 2/232 [00:18<36:54,  9.63s/it]

### Calculate the diarization error rate
- Compare the RTTM output of the diarization pipeline with the corresponding ground-truth RTTM for each file.
- Report the diarization error rate (DER), purity, and coverage for each file in the test set.
- Report the average DER, purity, and coverage over the entire test set.



In [None]:
# Score every hypothesis RTTM against its ground-truth RTTM.
#
# Metrics are computed directly from the RTTM speaker turns with numpy/scipy.
# No forgiveness collar is applied and overlapped speech is scored.
#   DER      = (missed + false alarm + confusion) / total reference speech,
#              with hypothesis clusters mapped one-to-one onto reference
#              speakers so that the matched co-occurring time is maximal.
#   purity   = time each hypothesis cluster shares with its dominant reference
#              speaker, divided by all reference/hypothesis co-occurring time.
#   coverage = time each reference speaker shares with its dominant hypothesis
#              cluster, divided by the same co-occurring time.
from collections import defaultdict

from scipy.optimize import linear_sum_assignment


def load_rttm_turns(path):
    """Return ``{speaker_label: [(start, end), ...]}`` from an RTTM file.

    Only ``SPEAKER`` lines are read. The file-ID column is ignored, so each
    RTTM is assumed to describe a single recording.
    """
    turns = defaultdict(list)
    with open(path, encoding="utf-8") as rttm:
        for line in rttm:
            # TYPE FILE CHNL TBEG TDUR ORTHO STYPE NAME CONF SLAT
            fields = line.split()
            if len(fields) < 8 or fields[0] != "SPEAKER":
                continue
            start, duration = float(fields[3]), float(fields[4])
            turns[fields[7]].append((start, start + duration))
    return dict(turns)


def _activity(turns, boundaries):
    """Boolean matrix (n_speakers, n_intervals): is each speaker active in each interval?"""
    labels = sorted(turns)
    active = np.zeros((len(labels), len(boundaries) - 1), dtype=bool)
    for row, label in enumerate(labels):
        for start, end in turns[label]:
            # start/end are themselves boundaries, so each elementary interval
            # is either fully inside or fully outside this turn.
            lo, hi = np.searchsorted(boundaries, [start, end])
            active[row, lo:hi] = True
    return active


def score_recording(reference, hypothesis):
    """Return the summed DER / purity / coverage components (in seconds) for one file.

    ``reference`` and ``hypothesis`` map speaker labels to lists of
    ``(start, end)`` turns, as produced by ``load_rttm_turns``.
    """
    components = dict.fromkeys(
        ("total", "missed", "false_alarm", "confusion",
         "purity_correct", "coverage_correct", "cooccurrence"),
        0.0,
    )
    # Split the timeline at every turn edge so speaker sets are constant per interval.
    edges = [t for turns in (reference, hypothesis)
             for spans in turns.values() for span in spans for t in span]
    boundaries = np.unique(np.asarray(edges, dtype=float))
    if boundaries.size < 2:
        return components
    widths = np.diff(boundaries)

    ref = _activity(reference, boundaries)
    hyp = _activity(hypothesis, boundaries)
    n_ref = ref.sum(axis=0)
    n_hyp = hyp.sum(axis=0)

    # overlap[r, h] = seconds during which reference r and hypothesis h both speak.
    overlap = (ref * widths) @ hyp.T.astype(float)
    if overlap.size:
        rows, cols = linear_sum_assignment(overlap, maximize=True)
        correct = overlap[rows, cols].sum()
        components["purity_correct"] = float(overlap.max(axis=0).sum())
        components["coverage_correct"] = float(overlap.max(axis=1).sum())
        components["cooccurrence"] = float(overlap.sum())
    else:
        correct = 0.0

    components["total"] = float((n_ref * widths).sum())
    components["missed"] = float((np.clip(n_ref - n_hyp, 0, None) * widths).sum())
    components["false_alarm"] = float((np.clip(n_hyp - n_ref, 0, None) * widths).sum())
    components["confusion"] = float((np.minimum(n_ref, n_hyp) * widths).sum() - correct)
    return components


def diarization_ratios(components):
    """Return ``(DER, purity, coverage)`` as fractions; NaN where a denominator is zero."""
    total = components["total"]
    shared = components["cooccurrence"]
    errors = components["missed"] + components["false_alarm"] + components["confusion"]
    der_value = errors / total if total else float("nan")
    purity_value = components["purity_correct"] / shared if shared else float("nan")
    coverage_value = components["coverage_correct"] / shared if shared else float("nan")
    return der_value, purity_value, coverage_value


# __name__ is "__main__" when this runs as a notebook cell; the guard only lets
# the helpers above be imported elsewhere without running the evaluation.
if __name__ == "__main__":
    per_file = {}
    for audio_file in sorted(os.listdir(test_audio)):
        if not audio_file.endswith(".wav"):
            continue
        stem = os.path.splitext(audio_file)[0]
        reference = load_rttm_turns(os.path.join(test_rttm, f"{stem}.rttm"))
        hypothesis = load_rttm_turns(os.path.join(output_file_path, f"{stem}.rttm"))
        per_file[stem] = score_recording(reference, hypothesis)

    for stem, components in per_file.items():
        file_der, file_purity, file_coverage = diarization_ratios(components)
        print(f"{stem}: DER={file_der:.2%}  purity={file_purity:.2%}  "
              f"coverage={file_coverage:.2%}")

    if not per_file:
        print(f"No .wav files found in {test_audio}")
    else:
        # Pooled scores weight each file by its duration (sum components, then divide).
        pooled = {key: sum(c[key] for c in per_file.values())
                  for key in next(iter(per_file.values()))}
        der, purity, coverage = diarization_ratios(pooled)
        mean_der, mean_purity, mean_coverage = np.nanmean(
            [diarization_ratios(c) for c in per_file.values()], axis=0
        )
        print(f"DER: {der:.2%}  purity: {purity:.2%}  coverage: {coverage:.2%}  "
              "(pooled over all files)")
        print(f"DER: {mean_der:.2%}  purity: {mean_purity:.2%}  "
              f"coverage: {mean_coverage:.2%}  (mean of per-file scores)")
