# DER calculation
This notebook calculates the Diarization Error Rate (DER) of each diarization system against the ground-truth RTTM annotations. It is calculated as follows:
1. We measure the duration of false alarm speech (speech in the system output where the reference has none), missed speech (reference speech that the system output does not cover) and speaker confusion (speech attributed to the wrong speaker under the optimal mapping between reference and system speakers).
2. We sum these three error durations.
3. We divide the sum by the total duration of reference speech.
4. The result is the DER.

The formula is as follows:
$$
DER = \frac{FA + MISS + CONF}{T_{ref}}
$$

Where:
- $FA$ is the total duration of false alarm speech
- $MISS$ is the total duration of missed speech
- $CONF$ is the total duration of speaker confusion
- $T_{ref}$ is the total duration of reference speech

The code below is used to calculate the DER.

In [1]:
import pandas as pd
from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate
import os
from tqdm import tqdm
import sys

# Paths to the different RTTM files
# Reference (ground-truth) RTTM directory for the test split.
RTTM_TRUTH = "../Dataset/RTTMs/Test"

# Prediction RTTM directories, one per system scored against RTTM_TRUTH.
RTTM_ORACLE_VAD = "../Results/Oracle_vad/Test"
RTTM_ORACLE_DECODER = "../Results/Oracle_decoder/Test"
# NOTE(review): "Ptannote" looks like a typo for "Pyannote" — confirm the
# on-disk directory name before changing this path.
RTTM_PYANNOTE = "../Results/Ptannote/Test"


# Directory that receives one DER CSV file per scored system.
RESULTS_DIRECTORY = "../Results/DER Results"


# Load the RTTM files
def load_rttm(filename):
    """Load an RTTM file and convert it to a pyannote.core.Annotation.

    Two line layouts are accepted:

    * Standard NIST RTTM (the layout pyannote itself writes), whose first
      field is the record type::

          SPEAKER <file> <chnl> <tbeg> <tdur> <ortho> <stype> <name> <conf> <slat>

      The onset, duration and speaker name are read from fields 3, 4 and 7.
    * The legacy layout without a leading record-type token: onset and
      duration are fields 2 and 3, and the speaker label is every field from
      index 6 onward, joined with single spaces.

    Lines with fewer than 9 whitespace-separated fields are skipped.

    Args:
        filename: Path to the RTTM file.

    Returns:
        An Annotation mapping each speech segment to its speaker label.
    """
    annotation = Annotation()
    with open(filename, "r", encoding="utf-8") as file:
        for line in file:
            parts = line.split()
            if len(parts) < 9:
                continue
            if parts[0] == "SPEAKER":
                # Standard RTTM: field 0 is the record type, so every field is
                # shifted one column right compared with the legacy layout and
                # the trailing <conf>/<slat> fields are not part of the label.
                start_time, duration, speaker = parts[3], parts[4], parts[7]
            else:
                # Legacy layout: file, channel, onset, duration, two ignored
                # fields, then a (possibly multi-token) speaker label.
                start_time, duration = parts[2], parts[3]
                speaker = " ".join(parts[6:])
            start_time = float(start_time)
            end_time = start_time + float(duration)
            annotation[Segment(start_time, end_time)] = speaker
    return annotation


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when ``denominator`` is zero."""
    return numerator / denominator if denominator else float("nan")


def calculate_der_metrics(
    ground_truth_path, prediction_path, output_path, output_csv="der_results.csv"
):
    """
    Calculate DER metrics for every prediction file and save them to a CSV file.

    Each regular file in ``prediction_path`` (processed in sorted name order)
    is scored against the same-named file in ``ground_truth_path``. One row is
    written per file, followed by an "Average" row when any reference speech
    was found.

    Every column except "File" is a rate relative to the reference speech
    duration. The "Duration" suffix in the column names is kept so existing
    CSV consumers keep working.

    * "DER": diarization error rate reported by pyannote.
    * "Miss Duration": missed detection / reference speech.
    * "Speaker Error Duration": speaker confusion / reference speech.
    * "False Alarm Duration": false alarm / reference speech.

    The "Average" row pools durations over all files: each error duration is
    summed and then divided by the summed reference duration. Its three
    component columns therefore add up to its DER column. Per-file ratios
    whose reference duration is zero are written as NaN.

    Args:
        ground_truth_path: Directory containing the ground-truth RTTM files.
        prediction_path: Directory containing the prediction RTTM files.
        output_path: Path (including filename) of the output CSV file. Missing
            parent directories are created. If ``None``, ``output_csv`` is
            used instead.
        output_csv: Fallback output path, used only when ``output_path`` is
            ``None``.

    Raises:
        FileNotFoundError: If a prediction file has no ground-truth counterpart.
    """
    if output_path is None:
        output_path = output_csv

    columns = [
        "File",
        "DER",
        "Miss Duration",
        "Speaker Error Duration",
        "False Alarm Duration",
    ]

    der_metric = DiarizationErrorRate()

    results = []
    total_miss_duration = 0.0
    total_confusion_duration = 0.0
    total_false_alarm_duration = 0.0
    total_duration = 0.0

    # Sorted for a deterministic row order; sub-directories are not RTTM
    # files and would fail to open, so they are skipped.
    file_names = sorted(
        name
        for name in os.listdir(prediction_path)
        if os.path.isfile(os.path.join(prediction_path, name))
    )

    for file_name in tqdm(file_names):
        truth_file = os.path.join(ground_truth_path, file_name)
        prediction_file = os.path.join(prediction_path, file_name)
        if not os.path.exists(truth_file):
            raise FileNotFoundError(
                f"File {file_name} not found in ground truth directory "
                f"{ground_truth_path}"
            )

        ground_truth = load_rttm(truth_file)
        prediction = load_rttm(prediction_file)

        # detailed=True returns the individual error durations alongside the rate.
        detailed = der_metric(ground_truth, prediction, detailed=True)
        der = detailed["diarization error rate"]
        missed_speech = detailed["missed detection"]
        confusion_duration = detailed["confusion"]
        false_alarm = detailed["false alarm"]
        total_file_duration = detailed["total"]

        total_miss_duration += missed_speech
        total_confusion_duration += confusion_duration
        total_false_alarm_duration += false_alarm
        total_duration += total_file_duration

        results.append(
            {
                "File": file_name,
                "DER": der,
                "Miss Duration": _safe_ratio(missed_speech, total_file_duration),
                "Speaker Error Duration": _safe_ratio(
                    confusion_duration, total_file_duration
                ),
                "False Alarm Duration": _safe_ratio(
                    false_alarm, total_file_duration
                ),
            }
        )

    if total_duration > 0:  # Avoid division by zero when nothing was scored.
        results.append(
            {
                "File": "Average",
                "DER": (
                    total_miss_duration
                    + total_confusion_duration
                    + total_false_alarm_duration
                )
                / total_duration,
                "Miss Duration": total_miss_duration / total_duration,
                "Speaker Error Duration": total_confusion_duration / total_duration,
                "False Alarm Duration": total_false_alarm_duration / total_duration,
            }
        )

    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    pd.DataFrame(results, columns=columns).to_csv(output_path, index=False)

In [2]:
# Score each system's predictions against the same ground truth, writing
# one CSV per system into RESULTS_DIRECTORY (Oracle VAD, Oracle decoder,
# then pyannote).
for _prediction_dir, _csv_name in (
    (RTTM_ORACLE_VAD, "der_oracle_vad.csv"),
    (RTTM_ORACLE_DECODER, "der_oracle_decoder.csv"),
    (RTTM_PYANNOTE, "der_pyannote.csv"),
):
    calculate_der_metrics(
        RTTM_TRUTH, _prediction_dir, os.path.join(RESULTS_DIRECTORY, _csv_name)
    )

    

 24%|██▍       | 56/232 [00:29<00:46,  3.79it/s]