In [1]:
import os
from dotenv import load_dotenv

# Let python-dotenv populate the process environment first, then read the
# Hugging Face access token from it. `hf_token` is None when HF_TOKEN is unset.
load_dotenv()
hf_token = os.getenv('HF_TOKEN')

In [2]:
from datasets import load_dataset, DatasetDict

# Which benchmark to evaluate on: "callhome", "ami" or "voxconverse".
EVAL_DATASET = "voxconverse"

if EVAL_DATASET == "callhome":
    # The only branch that passes the HF token to load_dataset.
    ds = load_dataset("talkbank/callhome", "eng", token=hf_token)

    # Only the 'data' split is used here; carve it into 80/10/10
    # train/validation/test with a fixed seed so the split is reproducible.
    train_testvalid = ds['data'].train_test_split(test_size=0.2, seed=0)
    test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=0)

    ds = DatasetDict({
        'train': train_testvalid['train'],
        'validation': test_valid['test'],
        'test': test_valid['train']
    })

elif EVAL_DATASET == "ami":
    ds = load_dataset("diarizers-community/ami", "ihm")
elif EVAL_DATASET == "voxconverse":
    ds = load_dataset("diarizers-community/voxconverse")
else:
    # Fail loudly instead of hitting a NameError below (fresh kernel) or
    # silently reusing a stale `ds` left over from a previous run of this cell.
    raise ValueError(
        f"Unknown EVAL_DATASET {EVAL_DATASET!r}; "
        "expected 'callhome', 'ami' or 'voxconverse'"
    )

# From here on `ds` is the test split only.
ds = ds['test']

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from itertools import permutations

import numpy as np
import pandas as pd
from scipy import stats

from pyannote.metrics.diarization import DiarizationErrorRate
from pyannote.core import Segment, Annotation

from pipeline import OnlinePipeline, OnlinePipelineConfig

from tqdm import tqdm

INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [disable_jit_profiling, allow_tf32]
INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []
  torchaudio.set_audio_backend("soundfile")


In [4]:
# Build the online diarization pipeline from a default-constructed config.
config = OnlinePipelineConfig()
diar_pipeline = OnlinePipeline(config)
# Override the batch size on the pipeline instance after construction.
# NOTE(review): the value 2 is not explained anywhere in this notebook — confirm.
diar_pipeline.batch_size = 2

INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
  wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
INFO:speechbrain.utils.fetching:Fetch embedding_model.ckpt: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
INFO:speechbrain.utils.fetching:Fetch mean_var_norm_emb.ckpt: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
INFO:speechbrain.utils.fetching:Fetch classifier.ckpt: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
INFO:speechbrain.utils.fetching:Fetch label_encoder.txt: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
INFO:speechbrain.utils.parameter_transfer:Loading pretrained files for: embedding_model, mean_var_norm_emb, c

In [5]:
def create_reference_annotation(speakers, starts, ends):
    """
    Build a pyannote ``Annotation`` from parallel lists of speaker turns.

    Args:
        speakers: Speaker label of each turn.
        starts: Start time of each turn.
        ends: End time of each turn.

    Returns:
        Annotation with one track per input turn, labelled with its speaker.

    The three sequences are consumed with ``zip``, so extra items in a longer
    sequence are ignored.
    """
    reference = Annotation()
    for track, (speaker, start, end) in enumerate(zip(speakers, starts, ends)):
        # Give every turn its own track. With `reference[segment] = speaker`
        # all turns share one default track, so two speakers whose turns have
        # exactly the same (start, end) would overwrite each other and the
        # overlapped speech would vanish from the reference.
        reference[Segment(start, end), track] = speaker
    return reference

In [6]:
def calculate_optimal_der(reference, hypothesis, collar=0.25, skip_overlap=False):
    """
    Calculate DER with optimal mapping between reference and hypothesis speakers.

    Args:
        reference: Ground-truth annotation.
        hypothesis: Predicted annotation.
        collar: Passed to ``DiarizationErrorRate(collar=...)``.
        skip_overlap: Passed to ``DiarizationErrorRate(skip_overlap=...)``.

    Returns:
        Tuple ``(der, mapping, metrics)``:
        - der: the error rate as a float,
        - mapping: dict from hypothesis labels to reference labels,
        - metrics: the component dict returned by ``compute_components``
          (callers read 'confusion', 'false alarm', 'missed detection',
          'correct' and 'total' from it).
    """
    # A fresh metric per call, so nothing accumulates across files.
    der_metric = DiarizationErrorRate(collar=collar, skip_overlap=skip_overlap)

    # If pyannote's optimal_mapping method is available, use it
    if hasattr(der_metric, 'optimal_mapping'):
        # The mapping is returned for inspection only. Scoring is done on the
        # ORIGINAL hypothesis: DiarizationErrorRate searches for the optimal
        # one-to-one mapping itself (per the pyannote.metrics documentation),
        # so renaming beforehand is redundant. It is also unsafe: a hypothesis
        # label left out of the mapping keeps its name, and if that name
        # equals a reference label it gets merged with the speaker that was
        # mapped onto it, distorting the score.
        mapping = der_metric.optimal_mapping(reference, hypothesis)

        # Compute the components once and derive the rate from them, instead
        # of running the whole evaluation twice.
        metrics = der_metric.compute_components(reference, hypothesis)
        der = der_metric.compute_metric(metrics)
        return der, mapping, metrics

    # Otherwise, implement custom optimal mapping search.
    # NOTE(review): this exhaustive search is factorial in the number of
    # speakers, and the relabelling below can still collide with reference
    # labels for hypothesis speakers that are left unmapped.
    ref_speakers = list(reference.labels())
    hyp_speakers = list(hypothesis.labels())

    # If we have more hypothesis speakers than reference speakers,
    # only the most active hypothesis speakers are mapped.
    if len(hyp_speakers) > len(ref_speakers):
        # Total speaking time of each hypothesis speaker
        speaker_durations = {}
        for segment, _, label in hypothesis.itertracks(yield_label=True):
            speaker_durations[label] = speaker_durations.get(label, 0) + segment.duration

        # Keep only the most active speakers
        hyp_speakers = sorted(hyp_speakers,
                              key=lambda spk: speaker_durations.get(spk, 0),
                              reverse=True)[:len(ref_speakers)]

    best_der = float('inf')
    best_mapping = None
    best_metrics = None

    # Try all possible mappings between hypothesis and reference speakers
    for perm in permutations(ref_speakers, len(hyp_speakers)):
        mapping = dict(zip(hyp_speakers, perm))

        # Relabel a copy, reading labels from the untouched original so that
        # swaps (A->B, B->A) are applied simultaneously.
        remapped = hypothesis.copy()
        for segment, track, label in hypothesis.itertracks(yield_label=True):
            if label in mapping:
                remapped[segment, track] = mapping[label]

        # Single evaluation per candidate mapping
        current_metrics = der_metric.compute_components(reference, remapped)
        current_der = der_metric.compute_metric(current_metrics)

        # Update if this is the best mapping so far
        if current_der < best_der:
            best_der = current_der
            best_mapping = mapping
            best_metrics = current_metrics

    return best_der, best_mapping, best_metrics

In [7]:
# Initialize DER metric
# NOTE(review): nothing below reads `der_metric` (calculate_optimal_der builds
# its own metric per call); the name is kept in case another cell relies on it.
der_metric = DiarizationErrorRate(collar=0.25, skip_overlap=False)


def _stage_record(der, components, stage):
    """Flatten one DER result into columns suffixed with `stage` ('before'/'after')."""
    return {
        f'der_{stage}': der,
        f'confusion_{stage}': components['confusion'],
        f'false_alarm_{stage}': components['false alarm'],
        f'missed_detection_{stage}': components['missed detection'],
        f'correct_{stage}': components['correct'],
        f'total_{stage}': components['total'],
    }


def _summary_column(df, stage):
    """Return the summary values for one stage, in the order of the 'metric' labels."""
    confusion = df[f'confusion_{stage}'].sum()
    false_alarm = df[f'false_alarm_{stage}'].sum()
    missed_detection = df[f'missed_detection_{stage}'].sum()
    correct = df[f'correct_{stage}'].sum()
    total = df[f'total_{stage}'].sum()
    per_sample_der = df[f'der_{stage}']
    return [
        # Global DER pools the error durations over all samples, so long
        # recordings weigh more; the mean/median below treat samples equally.
        (confusion + false_alarm + missed_detection) / total,
        confusion,
        false_alarm,
        missed_detection,
        correct,
        total,
        per_sample_der.median(),
        per_sample_der.mean(),
        per_sample_der.std(),
    ]


sample_count = 0
pbar = tqdm(ds)

sample_metrics = []

for sample_id, sample in enumerate(pbar):
    audio, starts, ends, speakers = sample['audio'], sample['timestamps_start'], sample['timestamps_end'], sample['speakers']
    waveform, sample_rate = audio['array'], audio['sampling_rate']
    # Insert a new axis at position 1 before handing the audio to the pipeline.
    # NOTE(review): presumably turns a 1-D mono signal into (samples, 1) — confirm
    # against what OnlinePipeline expects.
    waveform = np.expand_dims(waveform, 1)

    try:
        diar_pipeline(waveform, sample_rate)

        # Score the pipeline output as produced during the online pass ...
        hypothesis_before = diar_pipeline.get_annotation()
        reference = create_reference_annotation(speakers, starts, ends)
        der_before, mapping_before, metrics_before = calculate_optimal_der(reference, hypothesis_before)

        # ... and again after the pipeline's reannotate() step.
        diar_pipeline.reannotate()
        hypothesis_after = diar_pipeline.get_annotation()
        der_after, mapping_after, metrics_after = calculate_optimal_der(reference, hypothesis_after)

        sample_metrics.append({
            'sample_id': sample_id,
            **_stage_record(der_before, metrics_before, 'before'),
            **_stage_record(der_after, metrics_after, 'after'),
        })

        sample_count = sample_id + 1
        pbar.set_postfix({'sample_no': sample_count, 'der_before': der_before, 'der_after': der_after})
    finally:
        # Always clear the pipeline, even when this sample raised: otherwise
        # re-running the cell would start with state left over from the
        # failed sample.
        diar_pipeline.reset()

df_samples = pd.DataFrame(sample_metrics)

global_metrics = {
    'metric': [
        'Global DER',
        'Confusion (s)',
        'False Alarm (s)',
        'Missed Detection (s)',
        'Correct (s)',
        'Total Reference (s)',
        'Median DER',
        'Mean DER',
        'Std Dev DER'
    ],
    'before': _summary_column(df_samples, 'before'),
    'after': _summary_column(df_samples, 'after'),
}

df_summary = pd.DataFrame(global_metrics)

print("Per-sample metrics DataFrame:")
print(df_samples.head())
print("\nSummarized statistics DataFrame:")
print(df_summary)

100%|█████████████████████████████| 232/232 [3:23:45<00:00, 52.70s/it, sample_no=232, der_before=0.136, der_after=0.106]

Per-sample metrics DataFrame:
   sample_id  der_before  confusion_before  false_alarm_before  \
0          0    0.217296        204.667059           20.930000   
1          1    0.266398        227.343333           39.806667   
2          2    0.131683         47.040000            5.693134   
3          3    0.142952         20.993333           27.556667   
4          4    0.230829        185.413333            6.639801   

   missed_detection_before  correct_before  total_before  der_after  \
0                13.002941      880.370000       1098.04   0.214585   
1                19.550000      829.316667       1076.21   0.431914   
2                 1.170000      361.130000        409.34   0.065886   
3                16.990700      420.495967        458.48   0.134060   
4                82.333333      920.953333       1188.70   0.225103   

   confusion_after  false_alarm_after  missed_detection_after  correct_after  \
0       201.623726          20.613333               13.386274     




In [8]:
# Persist both result tables as CSV, without the DataFrame index column.
for _frame, _csv_path in (
    (df_samples, 'sample_metrics.csv'),
    (df_summary, 'summary_metrics.csv'),
):
    _frame.to_csv(_csv_path, index=False)