In [1]:
# !/usr/bin/env python3
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  


In [2]:
# Check GPU availability
import torch
torch.cuda.is_available(), torch.cuda.get_device_name(0)


(True, 'NVIDIA GeForce RTX 2080 Ti')

In [3]:
# Imports
import os
from pathlib import Path
import numpy as np
import pandas as pd
import soundfile as sf
import webrtcvad
import torch

from pyannote.audio import Pipeline, Model
from pyannote.audio import Inference
from pyannote.core import Segment

  from .autonotebook import tqdm as notebook_tqdm


In [56]:
# Configuration parameters
participant_id = "ABAN141223"
session_date   = "20250216"

session_dir = Path("/scratch/users/arunps/hindibabynet/audio_raw/ABAN141223/20250216")
wav_files = sorted(session_dir.glob("*.WAV")) + sorted(session_dir.glob("*.wav"))
wav_path = Path(wav_files[1]) 
info = sf.info(str(wav_path))
print(wav_path.name, info.samplerate, info.channels, info.duration/3600, "hours") 
recording_id = wav_path.stem
# diarization bounds 
MIN_SPEAKERS = 2
MAX_SPEAKERS = 4

# chunking
CHUNK_SEC = 15 * 60     # 15 min
OVERLAP_SEC = 10        # small overlap to avoid cutting speech

# VAD params 
VAD_AGGR = 2
VAD_FRAME_MS = 30
VAD_MIN_REGION_MS = 300

# intersection / post-filter
MIN_KEEP_SEC = 0.20     # drop tiny fragments after intersection

# speaker embedding / global clustering
EMB_MODEL_ID = "pyannote/embedding"     #  embedding checkpoint
MIN_EMB_SEG_SEC = 1.0                  # ignore too-short segments for embeddings
MIN_SPK_TOTAL_SEC_FOR_CENTROID = 10.0  # need enough speech to build a stable centroid
COS_SIM_MERGE_THRESHOLD = 0.78         # threshold for merging speakers based on cosine similarity


1739701628.WAV 16000 1 2.0364166666666668 hours


In [44]:
# Create recordings DataFrame
import soundfile as sf
import pandas as pd

rows = []
for p in wav_files:
    info = sf.info(str(p))
    rows.append({
        "participant_id": "ABAN141223",
        "session_date": "20250216",
        "recording_id": p.stem,
        "path": str(p),
        "duration_sec": float(info.duration),
        "sample_rate": int(info.samplerate),
        "channels": int(info.channels),
        "size_bytes": p.stat().st_size,
    })

recordings = pd.DataFrame(rows)
recordings


Unnamed: 0,participant_id,session_date,recording_id,path,duration_sec,sample_rate,channels,size_bytes
0,ABAN141223,20250216,1739683525,/scratch/users/arunps/hindibabynet/audio_raw/A...,17940.02,16000,1,574081152
1,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,7331.1,16000,1,234595712


In [45]:
# Load HF_TOKEN from .env
from dotenv import load_dotenv
import os

load_dotenv()  

assert os.getenv("HF_TOKEN") is not None, "HF_TOKEN not loaded"
print("HF_TOKEN loaded")

HF_TOKEN loaded


In [46]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [47]:
# Disable pyannote notebook mode
import os
os.environ["PYANNOTE_DISABLE_NOTEBOOK"] = "1"

import matplotlib
matplotlib.use("Agg")

In [48]:
# Set Hugging Face cache to scratch space
import os

scratch_cache = f"/scratch/users/{os.environ['USER']}/.cache/huggingface"
os.environ["HF_HOME"] = scratch_cache
os.environ["HF_HUB_CACHE"] = f"{scratch_cache}/hub"
os.environ["TRANSFORMERS_CACHE"] = f"{scratch_cache}/transformers"

In [49]:
# Load diarization pipeline
HF_TOKEN = os.environ["HF_TOKEN"]
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=HF_TOKEN
)
pipeline.to(device)

print("Diarization pipeline loaded on", device)


/itf-fi-ml/home/arunps/Projects/HindiBabyNet/.venv/lib/python3.10/site-packages/lightning_fabric/utilities/cloud_io.py:73: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


Diarization pipeline loaded on cuda


In [50]:
# Load embedding model
emb_model = Model.from_pretrained(EMB_MODEL_ID, use_auth_token=HF_TOKEN)
emb_infer = Inference(emb_model, window="whole")  # one embedding per segment
emb_infer.to(device)

print("Embedding model loaded on", device)


/itf-fi-ml/home/arunps/Projects/HindiBabyNet/.venv/lib/python3.10/site-packages/pytorch_lightning/utilities/migration/migration.py:208: You have multiple `ModelCheckpoint` callback states in this checkpoint, but we found state keys that would end up colliding with each other after an upgrade, which means we can't differentiate which of your checkpoint callbacks needs which states. At least one of your `ModelCheckpoint` callbacks will not be able to reload the state.
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.6.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.6.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/

Model was trained with pyannote.audio 0.0.1, yours is 3.4.0. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.5.1+cu124. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.4.0. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.5.1+cu124. Bad things might happen unless you revert torch to 1.x.
Embedding model loaded on cuda


/itf-fi-ml/home/arunps/Projects/HindiBabyNet/.venv/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:197: Found keys that are not in the model state dict but in the checkpoint: ['loss_func.W']


In [51]:
# Utility functions 
def make_chunks(duration_sec: float, chunk_sec: float, overlap_sec: float):
    """Yield (chunk_id, chunk_start, chunk_end) with overlap."""
    step = chunk_sec - overlap_sec
    assert step > 0, "chunk_sec must be > overlap_sec"

    t = 0.0
    chunk_id = 0
    while t < duration_sec:
        s = t
        e = min(t + chunk_sec, duration_sec)
        yield chunk_id, s, e
        if e >= duration_sec:
            break
        t += step
        chunk_id += 1


In [52]:
# Helper: streaming WebRTC VAD on FULL AUDIO (no full-file load)
# Returns sorted speech intervals [(start_sec, end_sec), ...]
def webrtc_vad_regions_streaming(
    path: Path,
    aggressiveness: int = 2,
    frame_ms: int = 30,
    min_region_ms: int = 300,
):
    vad = webrtcvad.Vad(aggressiveness)
    info = sf.info(str(path))
    sr = info.samplerate
    ch = info.channels

    if sr not in (8000, 16000, 32000, 48000):
        raise ValueError(f"webrtcvad needs sr in (8k,16k,32k,48k). got: {sr}")

    frame_len = int(sr * frame_ms / 1000)

    speech_flags = []
    with sf.SoundFile(str(path), mode="r") as f:
        while True:
            frame = f.read(frames=frame_len, dtype="int16", always_2d=True)
            if frame.size == 0 or len(frame) < frame_len:
                break
            mono = frame[:, 0]  # take ch0
            speech_flags.append(vad.is_speech(mono.tobytes(), sr))

    # merge consecutive true flags to regions in frames
    regions = []
    in_speech = False
    start_i = 0
    for i, is_speech in enumerate(speech_flags):
        if is_speech and not in_speech:
            in_speech = True
            start_i = i
        elif (not is_speech) and in_speech:
            in_speech = False
            regions.append((start_i, i))
    if in_speech:
        regions.append((start_i, len(speech_flags)))

    # convert to seconds and filter
    out = []
    for s_i, e_i in regions:
        s = (s_i * frame_len) / sr
        e = (e_i * frame_len) / sr
        if (e - s) * 1000 >= min_region_ms:
            out.append((float(s), float(e)))

    
    return out


In [53]:
# --------------------------
# Helper: interval intersection (two-pointer, sorted lists)
# Inputs:
#   diar_df: columns [start_sec, end_sec, ...] sorted
#   vad_intervals: list of (start_sec, end_sec) sorted
# Output:
#   speech-only rows with same metadata + speaker columns
# --------------------------
def intersect_diar_with_vad(diar_df: pd.DataFrame, vad_intervals, min_keep_sec: float = 0.0):
    diar_arr = diar_df[["start_sec", "end_sec", "speaker_id_global"]].to_numpy()
    vad_arr = np.array(vad_intervals, dtype=float)

    i = 0
    j = 0
    rows = []

    def intersect(a_s, a_e, b_s, b_e):
        s = max(a_s, b_s)
        e = min(a_e, b_e)
        return (s, e) if s < e else None

    while i < len(diar_arr) and j < len(vad_arr):
        ds, de, spk = diar_arr[i]
        vs, ve = vad_arr[j]

        inter = intersect(ds, de, vs, ve)
        if inter is not None:
            s, e = inter
            dur = float(e - s)
            if dur >= min_keep_sec:
                rows.append({
                    "start_sec": float(s),
                    "end_sec": float(e),
                    "duration_sec": dur,
                    "speaker_id": spk
                })

        # advance the one that ends first
        if de <= ve:
            i += 1
        else:
            j += 1

    out = pd.DataFrame(rows)
    return out


In [54]:
# Helper: compute speaker centroid embedding from segments
import numpy as np

def speaker_centroid_embedding_from_path(
    wav_path,
    segments,
    emb_infer,
    min_seg_sec: float = 1.0,
    max_segments: int = 50,
):
    embs = []
    total = 0.0
    logged_error = False

    # longest segments first (more stable embeddings)
    segments = sorted(
        segments, key=lambda s: float(s.end - s.start), reverse=True
    )[:max_segments]

    for seg in segments:
        dur = float(seg.end - seg.start)
        if dur < min_seg_sec:
            continue
        try:
            v = emb_infer.crop(str(wav_path), seg)
            v = np.array(v, dtype=np.float32).reshape(-1)
            if np.linalg.norm(v) < 1e-6:
                continue
            embs.append(v)
            total += dur
        except Exception as e:
            if not logged_error:
                print("Embedding crop error (first one):", repr(e))
                logged_error = True
            continue

    if len(embs) == 0:
        return None, 0.0

    centroid = np.mean(np.stack(embs, axis=0), axis=0)
    centroid = centroid / (np.linalg.norm(centroid) + 1e-12)
    return centroid, total


In [55]:
# --------------------------
# 1) VAD ONCE (FULL AUDIO)
# --------------------------
info = sf.info(str(wav_path))
full_duration = float(info.duration)

vad_intervals = webrtc_vad_regions_streaming(
    wav_path,
    aggressiveness=VAD_AGGR,
    frame_ms=VAD_FRAME_MS,
    min_region_ms=VAD_MIN_REGION_MS
)

print("Full duration (hours):", full_duration / 3600)
print("VAD intervals:", len(vad_intervals), "first:", vad_intervals[:3])


Full duration (hours): 2.0364166666666668
VAD intervals: 1745 first: [(2.43, 2.76), (7.71, 10.02), (11.1, 11.91)]


In [57]:
# Create VAD DataFrame
vad_df = (
    pd.DataFrame([{
        "participant_id": participant_id,
        "session_date": session_date,
        "recording_id": recording_id,
        "wav_path": str(wav_path),
        "start_sec": s,
        "end_sec": e,
        "duration_sec": e - s,
    } for s, e in vad_intervals])
    .sort_values(["start_sec", "end_sec"])
    .reset_index(drop=True)
)


In [58]:
vad_df.head(5)

Unnamed: 0,participant_id,session_date,recording_id,wav_path,start_sec,end_sec,duration_sec
0,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,2.43,2.76,0.33
1,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,7.71,10.02,2.31
2,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,11.1,11.91,0.81
3,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,13.32,14.61,1.29
4,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,15.09,15.6,0.51


In [59]:
# --------------------------
# 2) DIARIZATION PER CHUNK (with bounds) 
#    Store:
#      - turns per chunk (local speaker id)
#      - per-(chunk, local_speaker) centroid embedding
# --------------------------

from pathlib import Path
import soundfile as sf
from pyannote.core import Segment

def write_wav_chunk(wav_path: Path, chunk_path: Path, start_sec: float, end_sec: float):
    info = sf.info(str(wav_path))
    sr = info.samplerate

    start_frame = int(start_sec * sr)
    n_frames = int((end_sec - start_sec) * sr)

    audio, _ = sf.read(str(wav_path), start=start_frame, frames=n_frames)
    sf.write(str(chunk_path), audio, sr)
    return chunk_path


all_turn_rows = []
centroid_rows = []

# full wav used for embedding extraction
file_full = {"audio": str(wav_path)}

tmp_chunks_dir = Path("/scratch/users") / Path.home().name / "hindibabynet_tmp_chunks"
tmp_chunks_dir.mkdir(parents=True, exist_ok=True)

for chunk_id, chunk_start, chunk_end in make_chunks(full_duration, CHUNK_SEC, OVERLAP_SEC):

    # write chunk wav
    chunk_wav = tmp_chunks_dir / f"{wav_path.stem}_chunk{chunk_id:04d}_{int(chunk_start)}_{int(chunk_end)}.wav"
    write_wav_chunk(wav_path, chunk_wav, chunk_start, chunk_end)

    # diarize CHUNK wav 
    diar_chunk = pipeline(
        {"audio": str(chunk_wav)},
        min_speakers=MIN_SPEAKERS,
        max_speakers=MAX_SPEAKERS
    )

    # collect turns 
    local_segments_by_spk = {}
    for seg, _, spk in diar_chunk.itertracks(yield_label=True):
        s = float(seg.start) + float(chunk_start)
        e = float(seg.end) + float(chunk_start)
        if e <= s:
            continue

        all_turn_rows.append({
            "participant_id": participant_id,
            "session_date": session_date,
            "recording_id": recording_id,
            "wav_path": str(wav_path),
            "chunk_id": int(chunk_id),
            "chunk_start_sec": float(chunk_start),
            "chunk_end_sec": float(chunk_end),
            "speaker_id_local": spk,
            "start_sec": s,
            "end_sec": e,
            "duration_sec": float(e - s),
            "chunk_wav_path": str(chunk_wav), # for debugging
        })

        # store ORIGINAL timeline segments for embedding extraction
        local_segments_by_spk.setdefault(spk, []).append(Segment(s, e))

    # compute centroid embeddings per local speaker (on FULL wav)
    for spk, segs in local_segments_by_spk.items():
        centroid, total_sec = speaker_centroid_embedding_from_path(
            wav_path=wav_path,
            segments=segs,
            emb_infer=emb_infer,
            min_seg_sec=MIN_EMB_SEG_SEC
        )


        if centroid is None:
            continue

        centroid_rows.append({
            "chunk_id": int(chunk_id),
            "chunk_start_sec": float(chunk_start),
            "chunk_end_sec": float(chunk_end),
            "speaker_id_local": spk,
            "total_sec_used": float(total_sec),
            "centroid": centroid,  # numpy vector
        })

print("Turns collected:", len(all_turn_rows))
print("Centroids collected:", len(centroid_rows))


  std = sequences.std(dim=-1, correction=1)


Turns collected: 1178
Centroids collected: 24


In [60]:
turns_df = (
    pd.DataFrame(all_turn_rows)
      .sort_values(["start_sec", "end_sec"])
      .reset_index(drop=True)
)
turns_df.head()


Unnamed: 0,participant_id,session_date,recording_id,wav_path,chunk_id,chunk_start_sec,chunk_end_sec,speaker_id_local,start_sec,end_sec,duration_sec,chunk_wav_path
0,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,0,0.0,900.0,SPEAKER_01,7.624719,9.885969,2.26125,/scratch/users/arunps/hindibabynet_tmp_chunks/...
1,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,0,0.0,900.0,SPEAKER_01,13.294719,14.273469,0.97875,/scratch/users/arunps/hindibabynet_tmp_chunks/...
2,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,0,0.0,900.0,SPEAKER_01,18.728469,20.095344,1.366875,/scratch/users/arunps/hindibabynet_tmp_chunks/...
3,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,0,0.0,900.0,SPEAKER_01,20.770344,22.002219,1.231875,/scratch/users/arunps/hindibabynet_tmp_chunks/...
4,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,0,0.0,900.0,SPEAKER_01,22.255344,23.807844,1.5525,/scratch/users/arunps/hindibabynet_tmp_chunks/...


In [61]:
centroids_df = pd.DataFrame(centroid_rows)
# keep only stable centroids (enough speech)
centroids_df = centroids_df[centroids_df["total_sec_used"] >= MIN_SPK_TOTAL_SEC_FOR_CENTROID].reset_index(drop=True)

print("Centroids after min speech filter:", len(centroids_df))
centroids_df


Centroids after min speech filter: 13


Unnamed: 0,chunk_id,chunk_start_sec,chunk_end_sec,speaker_id_local,total_sec_used,centroid
0,0,0.0,900.0,SPEAKER_01,27.151875,"[0.04787777, 0.030210864, -0.06240629, -0.0642..."
1,2,1780.0,2680.0,SPEAKER_01,10.243125,"[-0.02877247, 0.02127275, -0.035593122, -0.048..."
2,4,3560.0,4460.0,SPEAKER_00,25.75125,"[0.03520787, -0.0027887593, 0.0023050124, -0.0..."
3,4,3560.0,4460.0,SPEAKER_01,20.199375,"[-0.00941363, -0.025264097, -0.08223153, -0.02..."
4,5,4450.0,5350.0,SPEAKER_03,146.5425,"[0.020359423, 0.0022690473, -0.040677346, -0.0..."
5,5,4450.0,5350.0,SPEAKER_02,11.626875,"[-0.0037898656, 0.010166123, -0.05444234, -0.0..."
6,6,5340.0,6240.0,SPEAKER_01,182.1825,"[0.036714647, -0.0077541065, -0.030207438, -0...."
7,7,6230.0,7130.0,SPEAKER_02,228.301875,"[-0.0034936816, 0.08293282, 0.013763596, -0.01..."
8,7,6230.0,7130.0,SPEAKER_00,29.46375,"[-0.029796671, 0.08992742, -0.010794352, -0.04..."
9,7,6230.0,7130.0,SPEAKER_03,22.798125,"[0.0019078947, 0.070387736, 0.043161493, 0.048..."


In [62]:
# --------------------------
# GLOBAL SPEAKER MERGING (across chunks)
#    Cosine-threshold graph clustering on centroids
# --------------------------
def cosine_sim_matrix(X):
    # X should already be L2-normalized
    return X @ X.T

# build matrix
X = np.stack(centroids_df["centroid"].to_numpy(), axis=0).astype(np.float32)
S = cosine_sim_matrix(X)

# graph edges where sim >= threshold
thr = COS_SIM_MERGE_THRESHOLD
n = S.shape[0]
visited = np.zeros(n, dtype=bool)
global_labels = -np.ones(n, dtype=int)

gid = 0
for i in range(n):
    if visited[i]:
        continue
    # BFS/DFS
    stack = [i]
    visited[i] = True
    global_labels[i] = gid
    while stack:
        u = stack.pop()
        # neighbors above threshold
        nbrs = np.where((S[u] >= thr) & (~visited))[0]
        for v in nbrs:
            visited[v] = True
            global_labels[v] = gid
            stack.append(v)
    gid += 1

centroids_df["speaker_id_global"] = [f"GSPK_{k:02d}" for k in global_labels]
print("Global speakers found:", centroids_df["speaker_id_global"].nunique())
centroids_df[["chunk_id","speaker_id_local","total_sec_used","speaker_id_global"]].head(10)


Global speakers found: 10


Unnamed: 0,chunk_id,speaker_id_local,total_sec_used,speaker_id_global
0,0,SPEAKER_01,27.151875,GSPK_00
1,2,SPEAKER_01,10.243125,GSPK_01
2,4,SPEAKER_00,25.75125,GSPK_02
3,4,SPEAKER_01,20.199375,GSPK_03
4,5,SPEAKER_03,146.5425,GSPK_03
5,5,SPEAKER_02,11.626875,GSPK_04
6,6,SPEAKER_01,182.1825,GSPK_03
7,7,SPEAKER_02,228.301875,GSPK_05
8,7,SPEAKER_00,29.46375,GSPK_06
9,7,SPEAKER_03,22.798125,GSPK_07


In [63]:
# --------------------------
# Create mapping: (chunk_id, speaker_id_local) -> speaker_id_global
#    For local speakers that didn't get a stable centroid (too little speech),
#    assign them to the nearest global centroid available (fallback).
# --------------------------
# Build lookup for stable mappings
map_df = centroids_df[["chunk_id", "speaker_id_local", "speaker_id_global"]].drop_duplicates()

# precompute global centroids (one vector per global speaker)
global_centroids = {}
for gspk, g in centroids_df.groupby("speaker_id_global"):
    G = np.stack(g["centroid"].to_numpy(), axis=0)
    c = np.mean(G, axis=0)
    c = c / (np.linalg.norm(c) + 1e-12)
    global_centroids[gspk] = c

gspk_list = sorted(global_centroids.keys())
Gmat = np.stack([global_centroids[g] for g in gspk_list], axis=0)  # (G, dim)

# Find chunk-local speakers missing in mapping
all_pairs = turns_df[["chunk_id", "speaker_id_local"]].drop_duplicates()
mapped_pairs = map_df[["chunk_id", "speaker_id_local"]].drop_duplicates()
missing_pairs = all_pairs.merge(mapped_pairs, on=["chunk_id","speaker_id_local"], how="left", indicator=True)
missing_pairs = missing_pairs[missing_pairs["_merge"] == "left_only"][["chunk_id","speaker_id_local"]]

# For each missing pair, build a centroid with relaxed constraints and assign nearest global
extra_rows = []
for r in missing_pairs.itertuples(index=False):
    cid = int(r.chunk_id)
    spk = r.speaker_id_local

    segs = turns_df[(turns_df["chunk_id"] == cid) & (turns_df["speaker_id_local"] == spk)]
    seg_objs = [Segment(float(a), float(b)) for a, b in segs[["start_sec","end_sec"]].to_numpy()]

    centroid, total_sec = speaker_centroid_embedding_from_path(
            wav_path=wav_path,
            segments=seg_objs,
            emb_infer=emb_infer,
            min_seg_sec=MIN_EMB_SEG_SEC
        )

    if centroid is None or len(gspk_list) == 0:
        continue

    # nearest by cosine (vectors normalized)
    sims = Gmat @ centroid
    best_idx = int(np.argmax(sims))
    best_g = gspk_list[best_idx]

    extra_rows.append({"chunk_id": cid, "speaker_id_local": spk, "speaker_id_global": best_g})

if extra_rows:
    map_df = pd.concat([map_df, pd.DataFrame(extra_rows)], ignore_index=True)

# final mapping dict
map_dict = {(int(r.chunk_id), r.speaker_id_local): r.speaker_id_global for r in map_df.itertuples(index=False)}

print("Total mapped (chunk, local):", len(map_dict))


Total mapped (chunk, local): 24


In [64]:
# --------------------------
# Relabel turns_df with global speakers
# --------------------------
def map_global(row):
    return map_dict.get((int(row["chunk_id"]), row["speaker_id_local"]), None)

turns_df["speaker_id_global"] = turns_df.apply(map_global, axis=1)

# drop any segments we couldn't map (rare)
turns_df = turns_df.dropna(subset=["speaker_id_global"]).reset_index(drop=True)

# Keep only columns needed downstream + global speaker id
diar_global_df = turns_df[[
    "participant_id","session_date","recording_id","wav_path",
    "start_sec","end_sec","duration_sec",
    "speaker_id_global"
]].sort_values(["start_sec","end_sec"]).reset_index(drop=True)

diar_global_df.head(10)


Unnamed: 0,participant_id,session_date,recording_id,wav_path,start_sec,end_sec,duration_sec,speaker_id_global
0,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,7.624719,9.885969,2.26125,GSPK_00
1,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,13.294719,14.273469,0.97875,GSPK_00
2,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,18.728469,20.095344,1.366875,GSPK_00
3,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,20.770344,22.002219,1.231875,GSPK_00
4,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,22.255344,23.807844,1.5525,GSPK_00
5,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,24.533469,25.090344,0.556875,GSPK_00
6,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,27.975969,30.675969,2.7,GSPK_00
7,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,33.021594,35.839719,2.818125,GSPK_00
8,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,34.624719,34.759719,0.135,GSPK_03
9,ABAN141223,20250216,1739701628,/scratch/users/arunps/hindibabynet/audio_raw/A...,35.215344,35.822844,0.6075,GSPK_03


In [65]:
# --------------------------
# INTERSECT diarization (global speakers) with VAD speech intervals
# --------------------------
speech_only_df = intersect_diar_with_vad(
    diar_df=diar_global_df,
    vad_intervals=vad_intervals,
    min_keep_sec=MIN_KEEP_SEC
)

# attach required metadata columns (exact ones you asked for) + speaker_id
speech_only_df.insert(0, "wav_path", str(wav_path))
speech_only_df.insert(0, "recording_id", recording_id)
speech_only_df.insert(0, "session_date", session_date)
speech_only_df.insert(0, "participant_id", participant_id)

# reorder columns
final_df_full = speech_only_df[[
    "participant_id","session_date","recording_id","wav_path",
    "start_sec","end_sec","duration_sec",
    "speaker_id"
]].sort_values(["start_sec","end_sec"]).reset_index(drop=True)

final_df_full.head(), len(final_df_full)


(  participant_id session_date recording_id  \
 0     ABAN141223     20250216   1739701628   
 1     ABAN141223     20250216   1739701628   
 2     ABAN141223     20250216   1739701628   
 3     ABAN141223     20250216   1739701628   
 4     ABAN141223     20250216   1739701628   
 
                                             wav_path  start_sec    end_sec  \
 0  /scratch/users/arunps/hindibabynet/audio_raw/A...   7.710000   9.885969   
 1  /scratch/users/arunps/hindibabynet/audio_raw/A...  13.320000  14.273469   
 2  /scratch/users/arunps/hindibabynet/audio_raw/A...  18.728469  20.095344   
 3  /scratch/users/arunps/hindibabynet/audio_raw/A...  20.790000  22.002219   
 4  /scratch/users/arunps/hindibabynet/audio_raw/A...  22.320000  23.807844   
 
    duration_sec speaker_id  
 0      2.175969    GSPK_00  
 1      0.953469    GSPK_00  
 2      1.366875    GSPK_00  
 3      1.212219    GSPK_00  
 4      1.487844    GSPK_00  ,
 983)

In [66]:
# --------------------------
# Quick sanity checks
# --------------------------
final_df_full.groupby("speaker_id")["duration_sec"].sum().sort_values(ascending=False)


speaker_id
GSPK_03    659.785594
GSPK_05    546.152406
GSPK_06     34.040625
GSPK_08     32.440063
GSPK_09     31.674375
GSPK_02     30.663094
GSPK_01     28.447500
GSPK_00     28.093437
GSPK_07     27.453094
GSPK_04     17.114437
Name: duration_sec, dtype: float64

In [67]:
from pathlib import Path
import soundfile as sf
import pandas as pd
from praatio import textgrid as tgio


def _make_interval_tier(name, entries, xmin, xmax):
    """
    Create IntervalTier across praatio versions.
    """
    # IntervalTier(name, entries, minT, maxT)
    try:
        return tgio.IntervalTier(str(name), entries, xmin, xmax)
    except TypeError:
        pass

    # IntervalTier(name, entries=..., minT=..., maxT=...)
    try:
        return tgio.IntervalTier(name=str(name), entries=entries, minT=xmin, maxT=xmax)
    except TypeError:
        pass

    # IntervalTier(name, entryList=..., minT=..., maxT=...)
    return tgio.IntervalTier(name=str(name), entryList=entries, minT=xmin, maxT=xmax)


def df_to_textgrid_by_speaker(
    df: pd.DataFrame,
    wav_path: Path,
    out_textgrid_path: Path,
    start_col: str = "start_sec",
    end_col: str = "end_sec",
    speaker_col: str = "speaker_id",
    label_col: str | None = None,  # None -> label = speaker_id
):
    wav_path = Path(wav_path)
    out_textgrid_path = Path(out_textgrid_path)

    info = sf.info(str(wav_path))
    xmin = 0.0
    xmax = float(info.duration)

    df = df.copy()
    df = df[df[end_col] > df[start_col]].sort_values([speaker_col, start_col, end_col])

    tg = tgio.Textgrid()
    tg.minTimestamp = xmin
    tg.maxTimestamp = xmax

    for spk, g in df.groupby(speaker_col):
        entries = []
        for r in g.itertuples(index=False):
            s = float(getattr(r, start_col))
            e = float(getattr(r, end_col))

            # clamp
            s = max(xmin, min(s, xmax))
            e = max(xmin, min(e, xmax))
            if e <= s:
                continue

            lab = str(spk) if label_col is None else str(getattr(r, label_col))
            entries.append((s, e, lab))

        # ensure non-overlap within the speaker tier
        entries.sort(key=lambda x: (x[0], x[1]))
        cleaned = []
        last_end = -1.0
        for s, e, lab in entries:
            if s < last_end:
                s = last_end
            if e > s:
                cleaned.append((s, e, lab))
                last_end = e

        tier = _make_interval_tier(str(spk), cleaned, xmin, xmax)
        tg.addTier(tier)

    tg.save(str(out_textgrid_path), format="short_textgrid", includeBlankSpaces=True)
    return out_textgrid_path


In [68]:
import numpy as np

tmp_dir = Path("/scratch/users") / Path.home().name / "hindibabynet_tmp"
tmp_dir.mkdir(parents=True, exist_ok=True)

tg_path = df_to_textgrid_by_speaker(
    df=final_df_full,                         # diarization âˆ© VAD DataFrame
    wav_path=wav_path,                   # full audio file
    out_textgrid_path=tmp_dir / f"{wav_path.stem}.TextGrid",
)

tg_path

PosixPath('/scratch/users/arunps/hindibabynet_tmp/1739701628.TextGrid')

In [69]:
# --------------------------
# CLEANUP: remove all chunk WAVs at once
# --------------------------
import shutil
from pathlib import Path

tmp_chunks_dir = Path("/scratch/users") / Path.home().name / "hindibabynet_tmp_chunks"

if tmp_chunks_dir.exists():
    shutil.rmtree(tmp_chunks_dir)
    print(f"Deleted temporary chunk directory: {tmp_chunks_dir}")
else:
    print("No temporary chunk directory found.")

Deleted temporary chunk directory: /scratch/users/arunps/hindibabynet_tmp_chunks
