# SpeechBrain Diarization Pipeline
This notebook performs speaker diarization using the SpeechBrain ECAPA-TDNN model and evaluates the results.

In [20]:
import os, sys, pathlib, pandas as pd

GT_CSV = "../data/refined_dataset.csv"
SB_PRED_DIR = pathlib.Path("../results/speechbrain_predictions")
SB_SUMMARY  = "../results/speechbrain_summary.csv"

SB_PRED_DIR.mkdir(parents=True, exist_ok=True)

print("PY:", sys.version)
print("CWD:", os.getcwd())
print("GT_CSV exists:", os.path.exists(GT_CSV))

df = pd.read_csv(GT_CSV)
print("GT rows:", len(df))
df.head(2)[["audio","speaker_count"]]

PY: 3.9.23 (main, Jun  3 2025, 18:47:52) 
[Clang 16.0.0 (clang-1600.0.26.6)]
CWD: /Users/s.n.h/Voice-AI/Audio-AI/notebooks
GT_CSV exists: True
GT rows: 12


Unnamed: 0,audio,speaker_count
0,../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...,1.0
1,../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...,1.0


## 2. Embedding and Clustering Functions
This section defines functions for extracting embeddings and clustering for diarization.

In [21]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score
import torch

def windowed_embeddings(wav, sr, win_s=1.0, hop_s=0.5):
    W = int(sr*win_s); H = int(sr*hop_s)
    embs, times = [], []
    with torch.no_grad():
        for start in range(0, len(wav)-W+1, H):
            chunk = wav[start:start+W].unsqueeze(0).to(torch.device)
            emb = sb_enc.encode_batch(chunk).squeeze(0).squeeze(0).cpu().numpy() # type: ignore
            embs.append(emb)
            times.append((start/sr, (start+W)/sr))
    return np.array(embs), times

def cluster_auto(embs, k_min=1, k_max=4):
    if len(embs) < 2:
        return np.zeros(len(embs), dtype=int), 1
    best_k, best_score, best_labels = 1, -1.0, np.zeros(len(embs), dtype=int)
    for k in range(k_min, min(k_max, len(embs)) + 1):
        try:
            lab = AgglomerativeClustering(n_clusters=k, linkage="ward").fit_predict(embs)
            score = silhouette_score(embs, lab) if k > 1 else -1.0
            if score > best_score:
                best_k, best_score, best_labels = k, score, lab
        except Exception:
            pass
    return best_labels, best_k

def windows_to_segments(times, labels, min_seg=0.30, gap_merge=0.25):
    if not times: return []
    ordered = sorted(zip(times, labels), key=lambda x: x[0][0])
    out = []
    cs, ce, cl = ordered[0][0][0], ordered[0][0][1], ordered[0][1]
    for (t0, t1), lab in ordered[1:]:
        if lab == cl and t0 - ce <= gap_merge:
            ce = max(ce, t1)
        else:
            if ce - cs >= min_seg:
                out.append({"start": float(cs), "end": float(ce), "labels": [f"Speaker {int(cl)+1}"]})
            cs, ce, cl = t0, t1, lab
    if ce - cs >= min_seg:
        out.append({"start": float(cs), "end": float(ce), "labels": [f"Speaker {int(cl)+1}"]})
    return out

## 3. Audio Processing and VAD Functions
This section defines functions for reading audio and performing voice activity detection (VAD).

In [22]:
import torchaudio, torch, os

def read_wav(path, target_sr=16000):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Error loading audio file: not found {path}")
    wav, sr = torchaudio.load(path)
    if wav.dim() > 1:  # make mono
        wav = wav.mean(dim=0)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
        sr = target_sr
    return wav.squeeze(0), sr

# webrtcvad: frames must be exactly 10, 20, or 30 ms; sr must be 8000/16000/32000
def frames_vad(wav, sr, frame_ms=30, vad_aggr=2):
    assert frame_ms in (10, 20, 30), "webrtcvad requires 10/20/30 ms frames"
    import webrtcvad
    from scipy.signal import medfilt

    vad = webrtcvad.Vad(vad_aggr)
    frame_len = int(sr * frame_ms / 1000)
    hop = frame_len

    speech = []
    for start in range(0, len(wav), hop):
        end = min(start + frame_len, len(wav))
        frm = wav[start:end]
        if len(frm) < frame_len:
            frm = torch.nn.functional.pad(frm, (0, frame_len - len(frm)))
        pcm16 = (frm.clamp(-1, 1) * 32767.0).to(torch.int16).cpu().numpy().tobytes()
        speech.append(1 if vad.is_speech(pcm16, sr) else 0)

    speech = medfilt(torch.tensor(speech, dtype=torch.int32).numpy(), kernel_size=5)

    segs = []
    i = 0
    n = len(speech)
    while i < n:
        if speech[i] == 1:
            j = i + 1
            while j < n and speech[j] == 1:
                j += 1
            segs.append((i, j))
            i = j
        else:
            i += 1

    segs_sec = []
    for s, e in segs:
        start = s * frame_len / sr
        end = e * frame_len / sr
        if end - start >= 0.20:
            segs_sec.append((start, end))
    return segs_sec, frame_len / sr

## 4. SpeechBrain Encoder Setup
This section sets up the SpeechBrain ECAPA-TDNN encoder and device.

In [23]:
# Setup SpeechBrain device and encoder
import torch
from speechbrain.pretrained import EncoderClassifier

sb_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sb_enc = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb",
    run_opts={"device": str(sb_device)}
)

## 5. Clustering and Segment Conversion Functions
This section defines clustering and segment conversion utilities.

In [24]:
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score

def windowed_embeddings(wav, sr, win_s=1.0, hop_s=0.5):
    W = int(sr*win_s); H = int(sr*hop_s)
    embs, times = [], []
    with torch.no_grad():
        for start in range(0, len(wav)-W+1, H):
            chunk = wav[start:start+W].unsqueeze(0).to(sb_device)
            emb = sb_enc.encode_batch(chunk).squeeze(0).squeeze(0).cpu().numpy()
            embs.append(emb)
            times.append((start/sr, (start+W)/sr))
    return np.array(embs), times

def cluster_auto(embs, k_min=1, k_max=4):
    if len(embs) < 2:
        return np.zeros(len(embs), dtype=int), 1
    best_k, best_score, best_labels = 1, -1.0, np.zeros(len(embs), dtype=int)
    for k in range(k_min, min(k_max, len(embs)) + 1):
        try:
            lab = AgglomerativeClustering(n_clusters=k, linkage="ward").fit_predict(embs)
            score = silhouette_score(embs, lab) if k > 1 else -1.0
            if score > best_score:
                best_k, best_score, best_labels = k, score, lab
        except Exception:
            pass
    return best_labels, best_k

def windows_to_segments(times, labels, min_seg=0.30, gap_merge=0.25):
    if not times: return []
    ordered = sorted(zip(times, labels), key=lambda x: x[0][0])
    out = []
    cs, ce, cl = ordered[0][0][0], ordered[0][0][1], ordered[0][1]
    for (t0, t1), lab in ordered[1:]:
        if lab == cl and t0 - ce <= gap_merge:
            ce = max(ce, t1)
        else:
            if ce - cs >= min_seg:
                out.append({"start": float(cs), "end": float(ce), "labels": [f"Speaker {int(cl)+1}"]})
            cs, ce, cl = t0, t1, lab
    if ce - cs >= min_seg:
        out.append({"start": float(cs), "end": float(ce), "labels": [f"Speaker {int(cl)+1}"]})
    return out

## 6. Batch Diarization and Result Saving
This section runs diarization on all files and saves the results to JSON and CSV.

In [25]:
import time, json, pathlib, numpy as np, pandas as pd
GT_CSV = "../data/refined_dataset.csv"
if 'df' not in globals():
    df = pd.read_csv(GT_CSV)
OUT_DIR = pathlib.Path("../results/speechbrain_predictions"); OUT_DIR.mkdir(parents=True, exist_ok=True)
SUMMARY_CSV = "../results/speechbrain_summary.csv"

results = []
for i, row in df.iterrows():
    audio = row["audio"]; stem = pathlib.Path(audio).stem
    print(f"[SB] ({i+1}/{len(df)}) {stem}")
    t0 = time.time()
    try:
        wav, sr = read_wav(audio, 16000)
        vad_segs, _ = frames_vad(wav, sr, frame_ms=30, vad_aggr=2)

        all_embs, all_times = [], []
        for (s,e) in vad_segs:
            seg = wav[int(s*sr):int(e*sr)]
            embs, times = windowed_embeddings(seg, sr, win_s=1.0, hop_s=0.5)
            times = [(s+a, s+b) for (a,b) in times]
            if len(embs):
                all_embs.append(embs); all_times.extend(times)

        embs = np.vstack(all_embs) if len(all_embs) else np.zeros((0,192))
        if len(embs)==0:
            preds = []
        else:
            labels, k = cluster_auto(embs, k_min=1, k_max=4)
            preds = windows_to_segments(all_times, labels, min_seg=0.30, gap_merge=0.25)

        out_path = OUT_DIR / f"{stem}_speechbrain.json"
        with open(out_path, "w") as f: json.dump(preds, f, indent=2)

        dur = time.time() - t0
        results.append({"audio": audio, "n_segments": len(preds), "runtime_sec": dur, "output_file": str(out_path)})
        print(f"  -> {len(preds)} segs, {dur:.2f}s")
    except Exception as e:
        results.append({"audio": audio, "error": str(e)})
        print(f"  !! ERROR: {e}")

pd.DataFrame(results).to_csv(SUMMARY_CSV, index=False)
print(f"[SB] Done -> {SUMMARY_CSV}")

[SB] (1/12) solo10_ar
  -> 85 segs, 55.03s
[SB] (2/12) solo3_ar
  -> 85 segs, 55.03s
[SB] (2/12) solo3_ar
  -> 60 segs, 23.03s
[SB] (3/12) two_speakers7_ar
  -> 60 segs, 23.03s
[SB] (3/12) two_speakers7_ar
  -> 36 segs, 41.53s
[SB] (4/12) two_speakers10_ar
  -> 36 segs, 41.53s
[SB] (4/12) two_speakers10_ar
  -> 61 segs, 52.15s
[SB] (5/12) three_speakers5_ar
  -> 61 segs, 52.15s
[SB] (5/12) three_speakers5_ar
  -> 65 segs, 59.55s
[SB] (6/12) three_speakers1_ar
  -> 65 segs, 59.55s
[SB] (6/12) three_speakers1_ar
  -> 42 segs, 63.40s
[SB] (7/12) solo3_en
  -> 42 segs, 63.40s
[SB] (7/12) solo3_en
  -> 59 segs, 38.72s
[SB] (8/12) solo2_en
  -> 59 segs, 38.72s
[SB] (8/12) solo2_en
  -> 80 segs, 38.32s
[SB] (9/12) two_speakers8_en
  -> 80 segs, 38.32s
[SB] (9/12) two_speakers8_en
  -> 65 segs, 58.24s
[SB] (10/12) two_speakers7_en
  -> 65 segs, 58.24s
[SB] (10/12) two_speakers7_en
  -> 63 segs, 48.59s
[SB] (11/12) three_speakers2_en
  -> 63 segs, 48.59s
[SB] (11/12) three_speakers2_en
  -> 42 

## 7. Merge Ground Truth and Predictions
This section merges ground truth data with diarization predictions for evaluation.

In [29]:
import pandas as pd

# Load ground truth and summary
df = pd.read_csv("../data/refined_dataset.csv")
summary = pd.read_csv("../results/speechbrain_summary.csv")

# Use 'speaker' as 'segments' for evaluation
df["segments"] = df["speaker"]

# Merge ground truth with predictions
eval_df = summary.merge(df[["audio", "segments", "speaker_count"]], on="audio", how="left")

## 8. Evaluation Functions and Metrics
This section defines evaluation functions and computes diarization metrics.

In [30]:
import numpy as np
from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate
from scipy.optimize import linear_sum_assignment
import json, pathlib

def segments_to_annotation(segments):
    ann = Annotation()
    for seg in segments:
        start = float(seg["start"]); end = float(seg["end"])
        if end > start:
            ann[Segment(start, end)] = str(seg["labels"][0])
    return ann

def load_pred_segments(path):
    with open(path) as f:
        return json.load(f)

def extract_boundaries(segments):
    return sorted(set([float(s["start"]) for s in segments] + [float(s["end"]) for s in segments]))

def match_boundaries(pred_b, ref_b, tol=0.5):
    ref_used = [False]*len(ref_b); matches=[]; TP=0
    for p in pred_b:
        best=None; best_abs=None; best_idx=None
        for i,r in enumerate(ref_b):
            if ref_used[i]: continue
            d=p-r
            if abs(d)<=tol and (best_abs is None or abs(d)<best_abs):
                best=(p,r,d); best_abs=abs(d); best_idx=i
        if best is not None:
            matches.append(best); ref_used[best_idx]=True; TP+=1
    FP=len(pred_b)-TP; FN=len(ref_b)-TP
    return matches,TP,FP,FN

def speaker_assignment_accuracy(ref_ann, hyp_ann):
    total_ref = ref_ann.get_timeline().duration()
    if total_ref <= 1e-9:
        return np.nan
    ref_labels = sorted(set(ref_ann.labels()))
    hyp_labels = sorted(set(hyp_ann.labels()))
    if not ref_labels or not hyp_labels:
        return 0.0
    M = np.zeros((len(ref_labels), len(hyp_labels)), dtype=float)
    for ref_seg, _, r_lab in ref_ann.itertracks(yield_label=True):
        for hyp_seg, _, h_lab in hyp_ann.itertracks(yield_label=True):
            inter = min(ref_seg.end, hyp_seg.end) - max(ref_seg.start, hyp_seg.start)
            if inter > 1e-9:
                i = ref_labels.index(r_lab)
                j = hyp_labels.index(h_lab)
                M[i, j] += inter
    if M.size == 0:
        return 0.0
    r_ind, h_ind = linear_sum_assignment(-M)
    matched_overlap = M[r_ind, h_ind].sum()
    return float(matched_overlap / total_ref)

# --- Evaluation ---
der_metric = DiarizationErrorRate(collar=0.5, skip_overlap=False)

records = []
for _, row in eval_df.iterrows():
    audio = row["audio"]
    out_file = row.get("output_file")
    if not isinstance(out_file, str) or not pathlib.Path(out_file).exists():
        records.append({"audio": audio, "ok": False, "error": "missing pred"}); continue
    hyp_segments = load_pred_segments(out_file)
    ref_segments = json.loads(row["segments"]) if isinstance(row["segments"], str) else row["segments"]
    ref_ann = segments_to_annotation(ref_segments)
    hyp_ann = segments_to_annotation(hyp_segments)
    # DER
    der = der_metric(ref_ann, hyp_ann)
    # Boundary metrics
    ref_b = extract_boundaries(ref_segments)
    hyp_b = extract_boundaries(hyp_segments)
    matches, TP, FP, FN = match_boundaries(hyp_b, ref_b, tol=0.5)
    prec = TP/(TP+FP) if (TP+FP) else 0.0
    rec  = TP/(TP+FN) if (TP+FN) else 0.0
    f1   = 2*prec*rec/(prec+rec) if (prec+rec) else 0.0
    offsets = [m[2] for m in matches]
    mad = float(np.median(np.abs(offsets))) if offsets else np.nan
    # Speaker assignment accuracy
    assign_acc = speaker_assignment_accuracy(ref_ann, hyp_ann)
    records.append({
        "audio": audio,
        "ok": True,
        "DER": float(der),
        "Boundary_precision": float(prec),
        "Boundary_recall": float(rec),
        "Boundary_F1": float(f1),
        "Boundary_median_abs_offset": mad,
        "Speaker_assign_acc": assign_acc,
        "Runtime_sec": float(row.get("runtime_sec", np.nan)),
        "N_ref": int(len(ref_segments)),
        "N_pred": int(len(hyp_segments)),
    })

sb_eval = pd.DataFrame(records)
sb_eval.to_csv("../results/speechbrain_eval.csv", index=False)
print("[INFO] Saved to ../results/speechbrain_eval.csv")

print("\nPer-file (first 8):")
print(sb_eval[["audio","DER","Boundary_F1","Speaker_assign_acc","Runtime_sec","N_ref","N_pred"]].head(8))

print("\nAggregate means:")
print(sb_eval[sb_eval["ok"]==True][["DER","Boundary_F1","Speaker_assign_acc","Runtime_sec"]].mean())





[INFO] Saved to ../results/speechbrain_eval.csv

Per-file (first 8):
                                               audio       DER  Boundary_F1  \
0  ../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...  0.538149     0.070588   
1  ../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...  0.543405     0.303448   
2  ../audios-wav/12-audios-ar-en/6-audios-ar/2_sp...  0.127357     0.326531   
3  ../audios-wav/12-audios-ar-en/6-audios-ar/2_sp...  0.190055     0.382979   
4  ../audios-wav/12-audios-ar-en/6-audios-ar/3_sp...  0.238083     0.206061   
5  ../audios-wav/12-audios-ar-en/6-audios-ar/3_sp...  0.109520     0.477612   
6  ../audios-wav/12-audios-ar-en/6-audios-en/1_sp...  0.375217     0.428571   
7  ../audios-wav/12-audios-ar-en/6-audios-en/1_sp...  0.514689     0.481132   

   Speaker_assign_acc  Runtime_sec  N_ref  N_pred  
0            0.659828    55.034382      4      85  
1            0.508158    23.028747     14      60  
2            0.871193    41.525807     13      36  
3          

## 9. Run Evaluation and Save Results
This section runs the evaluation, computes metrics, and saves the results to CSV.

In [31]:
import pandas as pd
import numpy as np

sb_eval = pd.read_csv("../results/speechbrain_eval.csv")

# Show rows with any NaN values (potential errors)
print("[NaN rows]\n", sb_eval[sb_eval.isna().any(axis=1)][["audio", "DER", "Boundary_F1", "N_ref", "N_pred"]])

# Show top over-segmented files
overseg = sb_eval.dropna().assign(overseg_ratio=lambda d: d["N_pred"]/d["N_ref"].replace(0, np.nan))
print("\n[Most over-segmented]\n", overseg.sort_values("overseg_ratio", ascending=False).head(5)[["audio", "N_ref", "N_pred", "overseg_ratio"]])

# Aggregate statistics for all successful files
print("\n[Aggregate means:]")
print(sb_eval[sb_eval["ok"]==True][["DER","Boundary_F1","Speaker_assign_acc","Runtime_sec"]].mean())

[NaN rows]
 Empty DataFrame
Columns: [audio, DER, Boundary_F1, N_ref, N_pred]
Index: []

[Most over-segmented]
                                                audio  N_ref  N_pred  \
0  ../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...      4      85   
1  ../audios-wav/12-audios-ar-en/6-audios-ar/1_sp...     14      60   
4  ../audios-wav/12-audios-ar-en/6-audios-ar/3_sp...     18      65   
6  ../audios-wav/12-audios-ar-en/6-audios-en/1_sp...     19      59   
2  ../audios-wav/12-audios-ar-en/6-audios-ar/2_sp...     13      36   

   overseg_ratio  
0      21.250000  
1       4.285714  
4       3.611111  
6       3.105263  
2       2.769231  

[Aggregate means:]
DER                    0.289385
Boundary_F1            0.345201
Speaker_assign_acc     0.750636
Runtime_sec           52.955408
dtype: float64
