## 1. Setup and Imports
This section imports the required libraries, including pyannote's `Pipeline` class.

In [2]:
import os
import json
import pandas as pd
import time

from pyannote.audio import Pipeline


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# rich.live warns that ipywidgets is missing for Jupyter support; ignore only
# that one message so the rest of the UserWarnings still surface.
import warnings

warnings.filterwarnings(action="ignore", message='install "ipywidgets" for Jupyter support')

## 2. Load Dataset
This section loads the refined dataset CSV file.

In [4]:
# Ground truth: one row per audio file with its path (`audio`), the annotated
# segments as a JSON string (`speaker`) and the annotated `speaker_count`.
df = pd.read_csv(filepath_or_buffer="../data/refined_dataset.csv")
df.head(n=2)


Unnamed: 0,audio,speaker,speaker_count
0,../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...,"[{""start"":0.025320884681576335,""end"":11.079020...",1.0
1,../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...,"[{""start"":0.1782457853618421,""end"":116.0241887...",1.0


## 3. Parse and Inspect Segments
This section parses the speaker segments and inspects the first row.

In [5]:
import ast

# `speaker` holds a JSON array of {start, end, channel, labels} dicts per file;
# decode it once into native Python lists for the rest of the notebook.
df["segments"] = df["speaker"].map(json.loads)

# Sanity-check the decoded structure on the first file.
first_row = df.loc[0]
print("Audio file:", first_row["audio"])
print("Speaker count (ground truth):", first_row["speaker_count"])
print("First 2 segments:", first_row["segments"][:2])


Audio file: ../audios-wav/12-audios-ar-en/6-audios-ar/1_speaker_ar/solo10_ar.wav
Speaker count (ground truth): 1.0
First 2 segments: [{'start': 0.025320884681576335, 'end': 11.079020083885514, 'channel': 0, 'labels': ['Speaker 1']}, {'start': 11.353187839068042, 'end': 28.60528196857287, 'channel': 0, 'labels': ['Speaker 1']}]


## 4. Load and Run PyAnnote Pipeline
This section loads the pyannote `speaker-diarization-3.1` pipeline, runs it on the first audio file of the dataset as a quick check, and previews the predicted segments.

In [6]:
import os, time, json
import torch
from dotenv import load_dotenv
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook
# Run on GPU when available; the pipeline is moved to this device below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")
# NOTE(review): speechbrain file-fetching strategy; confirm it still takes
# effect when set after pyannote has already been imported above.
os.environ["SPEECHBRAIN_LOCAL_STRATEGY"] = "copy"

# Load the Hugging Face token from .env. Raise explicitly instead of using
# `assert`, which is skipped entirely when Python runs with -O.
load_dotenv()
hf_token = os.getenv("HUGGINGFACE_TOKEN")
if not hf_token:
    raise RuntimeError("HUGGINGFACE_TOKEN not found in .env")

# Load diarization pipeline
print("[INFO] Loading pyannote pipeline...")
start_time = time.time()
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=hf_token)
# NOTE(review): pyannote.audio 3.x may return None instead of raising when the
# download fails (e.g. gated-model terms not accepted); fail loudly here
# rather than with an AttributeError on `.to()` below.
if pipeline is None:
    raise RuntimeError(
        "Could not load 'pyannote/speaker-diarization-3.1'; check HUGGINGFACE_TOKEN "
        "and that the model's terms were accepted on Hugging Face."
    )
print(f"[INFO] Pipeline loaded in {time.time() - start_time:.2f} sec")
pipeline.to(device)

# Quick check on the first audio file before the batch run.
test_audio = df.loc[0, "audio"]
print(f"[INFO] Starting diarization for: {test_audio}")

file_start = time.time()
with ProgressHook() as hook:
    diarization = pipeline(test_audio, hook=hook)
    file_end = time.time()

# Keep pyannote's raw speaker labels (e.g. SPEAKER_00) for this preview.
pred_segments = [
    {"start": float(turn.start), "end": float(turn.end), "labels": [speaker]}
    for turn, _, speaker in diarization.itertracks(yield_label=True)
]

print(f"[INFO] Finished diarization in {file_end - file_start:.2f} sec")
print(f"[INFO] Total segments detected: {len(pred_segments)}")
print("Preview (first 5):", json.dumps(pred_segments[:5], indent=2))


[INFO] Using device: cpu
[INFO] Loading pyannote pipeline...
[INFO] Pipeline loaded in 2.73 sec
[INFO] Starting diarization for: ../audios-wav/12-audios-ar-en/6-audios-ar/1_speaker_ar/solo10_ar.wav


[INFO] Finished diarization in 243.46 sec
[INFO] Total segments detected: 7
Preview (first 5): [
  {
    "start": 0.03096875,
    "end": 208.97721875000002,
    "labels": [
      "SPEAKER_00"
    ]
  },
  {
    "start": 17.918468750000002,
    "end": 18.812843750000003,
    "labels": [
      "SPEAKER_01"
    ]
  },
  {
    "start": 71.12534375,
    "end": 71.95221875,
    "labels": [
      "SPEAKER_01"
    ]
  },
  {
    "start": 83.14034375,
    "end": 83.32596875,
    "labels": [
      "SPEAKER_01"
    ]
  },
  {
    "start": 115.13534375,
    "end": 115.37159375,
    "labels": [
      "SPEAKER_01"
    ]
  }
]


## 5. Batch Diarization and Save Results
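This section runs diarization on every file in the dataset. It saves each file's predicted segments to its own JSON file and writes a per-file summary (segment count, runtime, output path, or the error message) to CSV.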

In [7]:
import pathlib

results = []
output_dir = "../results/pyannote_predictions"
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)


def _to_named_segments(annotation):
    """Flatten a diarization result into JSON-ready segment dicts.

    Raw pyannote labels are renamed to "Speaker 1", "Speaker 2", ... in
    sorted-label order, mirroring the ground-truth label format. Tracks are
    read once and reused for both the renaming table and the output list.
    """
    tracks = [(turn, label) for turn, _, label in annotation.itertracks(yield_label=True)]
    ordered_labels = sorted({label for _, label in tracks})
    rename = {label: f"Speaker {rank}" for rank, label in enumerate(ordered_labels, start=1)}
    return [
        {"start": float(turn.start), "end": float(turn.end), "labels": [rename[label]]}
        for turn, label in tracks
    ]


for idx, row in df.iterrows():
    audio_path = row["audio"]
    audio_name = pathlib.Path(audio_path).stem
    print(f"\n[INFO] ({idx+1}/{len(df)}) Processing {audio_name} ...")

    start_time = time.time()
    try:
        with ProgressHook() as hook:
            diarization = pipeline(audio_path, hook=hook)
        pred_segments = _to_named_segments(diarization)

        duration = time.time() - start_time
        print(f"[INFO] Finished {audio_name} in {duration:.2f} sec ({len(pred_segments)} segments)")

        # One JSON file of predicted segments per audio file.
        out_file = f"{output_dir}/{audio_name}_pyannote.json"
        with open(out_file, "w") as f:
            json.dump(pred_segments, f, indent=2)

        summary_row = {
            "audio": audio_path,
            "n_segments": len(pred_segments),
            "runtime_sec": duration,
            "output_file": out_file,
        }
    except Exception as e:
        # Best effort: record the failure and keep processing the other files.
        print(f"[ERROR] Failed on {audio_name}: {e}")
        summary_row = {"audio": audio_path, "error": str(e)}
    results.append(summary_row)

# Per-file run summary; the evaluation cell joins it back to the ground truth.
results_df = pd.DataFrame(results)
results_df.to_csv("../results/pyannote_summary.csv", index=False)

print("\n[INFO] All files processed. Summary saved to ../results/pyannote_summary.csv")
results_df.head()



[INFO] (1/12) Processing solo10_ar ...


[INFO] Finished solo10_ar in 242.92 sec (7 segments)

[INFO] (2/12) Processing solo3_ar ...


[INFO] Finished solo3_ar in 219.94 sec (62 segments)

[INFO] (3/12) Processing two_speakers7_ar ...


[INFO] Finished two_speakers7_ar in 215.25 sec (37 segments)

[INFO] (4/12) Processing two_speakers10_ar ...


[INFO] Finished two_speakers10_ar in 282.31 sec (80 segments)

[INFO] (5/12) Processing three_speakers5_ar ...


[INFO] Finished three_speakers5_ar in 362.47 sec (107 segments)

[INFO] (6/12) Processing three_speakers1_ar ...


[INFO] Finished three_speakers1_ar in 296.23 sec (42 segments)

[INFO] (7/12) Processing solo3_en ...


[INFO] Finished solo3_en in 205.73 sec (11 segments)

[INFO] (8/12) Processing solo2_en ...


[INFO] Finished solo2_en in 202.76 sec (43 segments)

[INFO] (9/12) Processing two_speakers8_en ...


[INFO] Finished two_speakers8_en in 301.70 sec (85 segments)

[INFO] (10/12) Processing two_speakers7_en ...


[INFO] Finished two_speakers7_en in 303.20 sec (35 segments)

[INFO] (11/12) Processing three_speakers2_en ...


[INFO] Finished three_speakers2_en in 280.54 sec (111 segments)

[INFO] (12/12) Processing three_speakers8_en ...


[INFO] Finished three_speakers8_en in 395.84 sec (68 segments)

[INFO] All files processed. Summary saved to ../results/pyannote_summary.csv


Unnamed: 0,audio,n_segments,runtime_sec,output_file
0,../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...,7,242.921457,../results/pyannote_predictions/solo10_ar_pyan...
1,../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...,62,219.935136,../results/pyannote_predictions/solo3_ar_pyann...
2,../audios-wav/12-audios-ar-en/6-audios-ar/2_sp...,37,215.254184,../results/pyannote_predictions/two_speakers7_...
3,../audios-wav/12-audios-ar-en/6-audios-ar/2_sp...,80,282.313628,../results/pyannote_predictions/two_speakers10...
4,../audios-wav/12-audios-ar-en/6-audios-ar/3_sp...,107,362.471811,../results/pyannote_predictions/three_speakers...


## 6. Evaluation and Metrics
This section evaluates diarization results using DER, boundary metrics, and speaker assignment accuracy.

In [8]:
import os, json, math, pathlib, numpy as np, pandas as pd
from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate
from scipy.optimize import linear_sum_assignment

# === Load GT ===
# Re-read the ground truth so this evaluation cell does not rely on earlier cells.
df = pd.read_csv("../data/refined_dataset.csv")
df["segments"] = df["speaker"].map(json.loads)

# === Load summary of predictions ===
summary = pd.read_csv("../results/pyannote_summary.csv")

# Left join: every prediction row is kept; GT columns are NaN for any audio
# path that does not appear in the ground-truth CSV.
gt_cols = ["audio", "segments", "speaker_count"]
eval_df = summary.merge(df[gt_cols], how="left", on="audio")

# === Helpers ===
def segments_to_annotation(segments):
    """Convert a list of segment dicts into a pyannote ``Annotation``.

    Each dict provides ``start`` and ``end`` (seconds) and ``labels`` (a list
    of speaker names). Segments whose end is not after their start are skipped.

    Every label of a segment becomes its own track, so a region tagged with
    several speakers (overlapping speech) contributes all of them rather than
    only the first. Segments with a missing or empty ``labels`` list are
    skipped instead of raising ``IndexError``. The label doubles as the track
    name, so two speakers sharing identical boundaries do not overwrite each
    other (as they would on a single shared default track).
    """
    ann = Annotation()
    for seg in segments:
        start = float(seg["start"])
        end = float(seg["end"])
        if end <= start:
            continue
        segment = Segment(start, end)
        for label in seg.get("labels") or []:
            label = str(label)
            ann[segment, label] = label
    return ann

def load_pred_segments(path):
    """Read a predictions JSON file and return its decoded contents.

    The batch diarization cell writes these files as a list of
    ``{"start", "end", "labels"}`` dicts.
    """
    return json.loads(pathlib.Path(path).read_text())

def extract_boundaries(segments):
    """Return every distinct segment start/end time, as floats, in ascending order."""
    edges = set()
    for seg in segments:
        edges.add(float(seg["start"]))
        edges.add(float(seg["end"]))
    return sorted(edges)

def match_boundaries(pred_b, ref_b, tol=0.5):
    """Greedily pair predicted boundaries with reference boundaries.

    Predicted boundaries are visited in order. Each one is paired with the
    closest still-unmatched reference boundary whose distance is at most
    ``tol`` seconds; on a tie the earlier reference wins. Each reference is
    matched at most once.

    Returns ``(matches, TP, FP, FN)``. ``matches`` is a list of
    ``(pred, ref, pred - ref)`` tuples. TP is the number of pairs, FP the
    number of unmatched predictions, and FN the number of unmatched references.
    """
    taken = set()
    matches = []
    for p in pred_b:
        best_idx = None
        best_gap = None
        for idx, r in enumerate(ref_b):
            if idx in taken:
                continue
            gap = abs(p - r)
            if gap <= tol and (best_gap is None or gap < best_gap):
                best_idx, best_gap = idx, gap
        if best_idx is not None:
            taken.add(best_idx)
            ref_value = ref_b[best_idx]
            matches.append((p, ref_value, p - ref_value))
    tp = len(matches)
    return matches, tp, len(pred_b) - tp, len(ref_b) - tp

def _merged_label_intervals(ann):
    """Map each label of ``ann`` to its sorted, non-overlapping ``[start, end]`` intervals.

    Overlapping or touching tracks that share a label are merged, so a stretch
    of time is never counted twice for the same speaker.
    """
    spans = {}
    for seg, _, label in ann.itertracks(yield_label=True):
        start, end = float(seg.start), float(seg.end)
        if end > start:
            spans.setdefault(label, []).append((start, end))
    merged = {}
    for label, items in spans.items():
        items.sort()
        out = [list(items[0])]
        for start, end in items[1:]:
            if start <= out[-1][1]:
                out[-1][1] = max(out[-1][1], end)
            else:
                out.append([start, end])
        merged[label] = out
    return merged


def _intersection_duration(a, b):
    """Total overlap between two sorted lists of disjoint intervals (linear two-pointer sweep)."""
    i = j = 0
    total = 0.0
    while i < len(a) and j < len(b):
        lo = max(a[i][0], b[j][0])
        hi = min(a[i][1], b[j][1])
        if hi > lo:
            total += hi - lo
        # Advance whichever interval finishes first; it cannot overlap anything later.
        if a[i][1] <= b[j][1]:
            i += 1
        else:
            j += 1
    return total


def speaker_assignment_accuracy(ref_ann, hyp_ann):
    """Permutation-invariant assignment via Hungarian on overlap duration.

    Builds a reference-label x hypothesis-label matrix of overlap seconds and
    finds the one-to-one label mapping that maximizes total overlap. Returns
    that matched overlap divided by the total reference speaker time (sum of
    per-label speech durations).

    Per-label tracks are merged before overlapping, and the denominator counts
    per-speaker time rather than the union of all speech. This keeps the score
    within [0, 1] even when reference speakers overlap or a label has
    duplicated/overlapping tracks. Without overlap it equals the union-based
    definition.

    Returns ``np.nan`` when the reference has no speech, and ``0.0`` when the
    hypothesis has no speech.
    """
    ref = _merged_label_intervals(ref_ann)
    hyp = _merged_label_intervals(hyp_ann)

    total_ref = sum(end - start for spans in ref.values() for start, end in spans)
    if total_ref <= 1e-9:
        return np.nan
    if not hyp:
        return 0.0

    ref_labels = sorted(ref)
    hyp_labels = sorted(hyp)
    M = np.array(
        [[_intersection_duration(ref[r], hyp[h]) for h in hyp_labels] for r in ref_labels],
        dtype=float,
    )

    r_ind, h_ind = linear_sum_assignment(-M)  # maximize overlap
    return float(M[r_ind, h_ind].sum() / total_ref)

# === Evaluation ===
# A 0.5 s collar around each reference boundary is excluded from DER scoring;
# skip_overlap=False keeps overlapped speech in the score.
der_metric = DiarizationErrorRate(collar=0.5, skip_overlap=False)
BOUNDARY_TOL = 0.5  # seconds: max |pred - ref| for a boundary to count as matched

records = []
for _, row in eval_df.iterrows():
    audio = row["audio"]
    out_file = row.get("output_file")

    # Files that failed in the batch cell have no output_file (NaN).
    if not isinstance(out_file, str) or not pathlib.Path(out_file).exists():
        records.append({"audio": audio, "ok": False, "error": "missing pred"})
        continue

    # The left merge leaves NaN here when the audio path is absent from the GT CSV.
    ref_segments = row["segments"]
    if not isinstance(ref_segments, list):
        records.append({"audio": audio, "ok": False, "error": "missing ref"})
        continue

    hyp_segments = load_pred_segments(out_file)

    ref_ann = segments_to_annotation(ref_segments)
    hyp_ann = segments_to_annotation(hyp_segments)

    # DER
    der = der_metric(ref_ann, hyp_ann)

    # Boundary metrics
    ref_b = extract_boundaries(ref_segments)
    hyp_b = extract_boundaries(hyp_segments)
    matches, TP, FP, FN = match_boundaries(hyp_b, ref_b, tol=BOUNDARY_TOL)
    prec = TP / (TP + FP) if (TP + FP) else 0.0
    rec = TP / (TP + FN) if (TP + FN) else 0.0
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
    offsets = [m[2] for m in matches]
    mad = float(np.median(np.abs(offsets))) if offsets else np.nan

    # Speaker assignment accuracy (permutation invariant)
    assign_acc = speaker_assignment_accuracy(ref_ann, hyp_ann)

    records.append({
        "audio": audio,
        "ok": True,
        "DER": float(der),
        "Boundary_precision": float(prec),
        "Boundary_recall": float(rec),
        "Boundary_F1": float(f1),
        "Boundary_median_abs_offset": mad,
        "Speaker_assign_acc": assign_acc,
        "Runtime_sec": float(row.get("runtime_sec", np.nan)),
        "N_ref": int(len(ref_segments)),
        "N_pred": int(len(hyp_segments)),
    })

py_eval = pd.DataFrame(records)
py_eval.to_csv("../results/pyannote_eval.csv", index=False)
print("[INFO] Saved to ../results/pyannote_eval.csv")

# reindex tolerates metric columns that are absent when no file succeeded.
print("\nPer-file (first 8):")
print(py_eval.reindex(columns=["audio", "DER", "Boundary_F1", "Speaker_assign_acc", "Runtime_sec", "N_ref", "N_pred"]).head(8))

print("\nAggregate means:")
print(py_eval[py_eval["ok"]].reindex(columns=["DER", "Boundary_F1", "Speaker_assign_acc", "Runtime_sec"]).mean())



[INFO] Saved to ../results/pyannote_eval.csv

Per-file (first 8):
                                               audio       DER  Boundary_F1  \
0  ../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...  0.011285     0.181818   
1  ../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...  0.254558     0.368421   
2  ../audios-wav/12-audios-ar-en/6-audios-ar/2_sp...  0.074083     0.375000   
3  ../audios-wav/12-audios-ar-en/6-audios-ar/2_sp...  0.145053     0.490909   
4  ../audios-wav/12-audios-ar-en/6-audios-ar/3_sp...  0.222181     0.202532   
5  ../audios-wav/12-audios-ar-en/6-audios-ar/3_sp...  0.094761     0.348485   
6  ../audios-wav/12-audios-ar-en/6-audios-en/1_sp...  0.026526     0.533333   
7  ../audios-wav/12-audios-ar-en/6-audios-en/1_sp...  0.049459     0.712329   

   Speaker_assign_acc  Runtime_sec  N_ref  N_pred  
0            0.999973   242.921457      4       7  
1            0.753860   219.935136     14      62  
2            0.920906   215.254184     13      37  
3            0

## 7. Error Analysis and Aggregate Statistics
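This section reloads the saved evaluation results and lists any files with missing (NaN) metrics. It then ranks files by over-segmentation, measured as the ratio of predicted to reference segment counts.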

In [9]:
import pandas as pd, numpy as np, json, pathlib

py_eval = pd.read_csv("../results/pyannote_eval.csv")

# If any file failed, the CSV carries an `error` column that is NaN on every
# successful row. Ignore it here, otherwise every good row is flagged as NaN.
metric_view = py_eval.drop(columns=["error"], errors="ignore")
nan_mask = metric_view.isna().any(axis=1)
print("[NaN rows]\n", py_eval[nan_mask][["audio","DER","Boundary_F1","N_ref","N_pred"]])

# Show top over-segmented files. Only the segment counts are required, so a
# blanket dropna() (which would also drop rows over `error` or NaN offsets) is avoided.
overseg = (
    py_eval.dropna(subset=["N_ref", "N_pred"])
    .assign(overseg_ratio=lambda d: d["N_pred"] / d["N_ref"].replace(0, np.nan))
)
print("\n[Most over-segmented]\n", overseg.sort_values("overseg_ratio", ascending=False).head(5)[["audio","N_ref","N_pred","overseg_ratio"]])

[NaN rows]
 Empty DataFrame
Columns: [audio, DER, Boundary_F1, N_ref, N_pred]
Index: []

[Most over-segmented]
                                                 audio  N_ref  N_pred  \
4   ../audios-wav/12-audios-ar-en/6-audios-ar/3_sp...     18     107   
1   ../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...     14      62   
2   ../audios-wav/12-audios-ar-en/6-audios-ar/2_sp...     13      37   
3   ../audios-wav/12-audios-ar-en/6-audios-ar/2_sp...     33      80   
11  ../audios-wav/12-audios-ar-en/6-audios-en/3_sp...     30      68   

    overseg_ratio  
4        5.944444  
1        4.428571  
2        2.846154  
3        2.424242  
11       2.266667  
