In [None]:
#!/usr/bin/env python3
import os
# Must be set before torch initialises CUDA, otherwise it has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
# get_device_name(0) raises when no GPU is visible, so only query it when CUDA is available.
torch.cuda.is_available(), (torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)


In [None]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import soundfile as sf
import webrtcvad

from pyannote.audio import Pipeline



In [None]:
participant_id = "ABAN141223"
session_date   = "20250216"

# Raw audio is organised as <AUDIO_ROOT>/<participant_id>/<session_date>/; derive the
# session directory from the IDs so the path and the metadata columns cannot disagree.
AUDIO_ROOT = Path("/scratch/users/arunps/hindibabynet/audio_raw")
session_dir = AUDIO_ROOT / participant_id / session_date

# Upper-case extension first, then lower-case (same order as before). dict.fromkeys
# drops duplicates, which both patterns can return on case-insensitive filesystems.
wav_files = list(dict.fromkeys(sorted(session_dir.glob("*.WAV")) + sorted(session_dir.glob("*.wav"))))
if not wav_files:
    raise FileNotFoundError(f"No .WAV/.wav files found in {session_dir}")

WAV_INDEX = 0  # which recording of the session to process
wav_path = Path(wav_files[WAV_INDEX])
info = sf.info(str(wav_path))
print(wav_path.name, info.samplerate, info.channels, info.duration/3600, "hours")
recording_id = wav_path.stem

# diarization bounds
MIN_SPEAKERS = 2
MAX_SPEAKERS = 4

# chunking
CHUNK_SEC   = 15 * 60  # 15 min
OVERLAP_SEC = 10       # overlap between chunks to avoid boundary issues

# VAD params
VAD_AGGR = 2
VAD_FRAME_MS = 30
VAD_MIN_REGION_MS = 300

# intersection / post-filter
MIN_KEEP_SEC = 0.20


In [None]:
def _recording_row(wav: Path) -> dict:
    """Build one metadata row (IDs, path, duration, format, file size) for a session WAV."""
    meta = sf.info(str(wav))
    return {
        "participant_id": participant_id,
        "session_date": session_date,
        "recording_id": wav.stem,
        "path": str(wav),
        "duration_sec": float(meta.duration),
        "sample_rate": int(meta.samplerate),
        "channels": int(meta.channels),
        "size_bytes": wav.stat().st_size,
    }


# One row per WAV file found in the session directory.
recordings = pd.DataFrame([_recording_row(p) for p in wav_files])
recordings


In [None]:
from dotenv import load_dotenv
# Populate os.environ from a .env file (python-dotenv's default lookup), so the
# Hugging Face token is not hard-coded in the notebook.
load_dotenv()

# Fail early with a clear message if the token is missing; it is passed as
# use_auth_token when the pyannote pipeline is loaded below.
assert os.getenv("HF_TOKEN") is not None, "HF_TOKEN not loaded"
HF_TOKEN = os.environ["HF_TOKEN"]

# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


In [None]:
# Ask pyannote to disable its notebook-specific behaviour (exact effect depends on the
# pyannote.audio version — TODO confirm it is still honoured).
os.environ["PYANNOTE_DISABLE_NOTEBOOK"] = "1"

# Redirect Hugging Face caches to scratch storage.
# NOTE(review): huggingface_hub reads HF_HOME / HF_HUB_CACHE when it is first imported,
# and pyannote.audio (which imports it) was already imported in an earlier cell, so
# these assignments likely do not change where this kernel caches models. Set them
# before the first pyannote / huggingface import (or in the shell) — TODO confirm.
# os.environ['USER'] raises KeyError if $USER is unset.
scratch_cache = f"/scratch/users/{os.environ['USER']}/.cache/huggingface"
os.environ["HF_HOME"] = scratch_cache
os.environ["HF_HUB_CACHE"] = f"{scratch_cache}/hub"
os.environ["TRANSFORMERS_CACHE"] = f"{scratch_cache}/transformers"


In [None]:
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=HF_TOKEN
)
# Some pyannote.audio versions return None (after printing a message) instead of
# raising when the repo is gated or the token lacks access. Fail here with a clear
# error rather than an AttributeError on `.to()`.
if pipeline is None:
    raise RuntimeError(
        "Could not load pyannote/speaker-diarization-3.1: check HF_TOKEN and that the "
        "model's user conditions have been accepted on huggingface.co"
    )
pipeline.to(device)
print("Diarization pipeline loaded on", device)


In [None]:
def make_chunks(duration_sec: float, chunk_sec: float, overlap_sec: float):
    """Yield (chunk_id, chunk_start, chunk_end) windows covering [0, duration_sec].

    Each window is chunk_sec long and starts (chunk_sec - overlap_sec) after the previous one.
    The last window is clipped to duration_sec. Nothing is yielded when duration_sec <= 0.
    An AssertionError is raised (on first iteration) if chunk_sec <= overlap_sec.
    """
    step = chunk_sec - overlap_sec
    assert step > 0, "chunk_sec must be > overlap_sec"

    chunk_id, start = 0, 0.0
    while start < duration_sec:
        end = min(start + chunk_sec, duration_sec)
        yield chunk_id, start, end
        if end >= duration_sec:
            return
        start += step
        chunk_id += 1


In [None]:
def webrtc_vad_regions_streaming(
    path: Path,
    aggressiveness: int = 2,
    frame_ms: int = 30,
    min_region_ms: int = 300,
    frames_per_read: int = 2000,
):
    """
    Run WebRTC VAD over channel 0 of `path` and return speech regions in seconds.

    The file is streamed in blocks of `frames_per_read` VAD frames and never loaded
    whole. Each frame_ms frame is classified, consecutive speech frames are merged,
    and regions shorter than `min_region_ms` are discarded. A trailing partial frame
    at the end of the file is ignored.

    Args:
        path: audio file readable by soundfile.
        aggressiveness: webrtcvad mode, passed to webrtcvad.Vad.
        frame_ms: VAD frame length; webrtcvad only accepts 10, 20 or 30 ms.
        min_region_ms: minimum kept region length in milliseconds.
        frames_per_read: VAD frames decoded per file read. Larger blocks mean fewer
            read calls; results do not depend on this value.

    Returns:
        Sorted, non-overlapping list of (start_sec, end_sec) float tuples.

    Raises:
        ValueError: unsupported sample rate or frame length.
    """
    vad = webrtcvad.Vad(aggressiveness)
    info = sf.info(str(path))
    sr = info.samplerate

    if sr not in (8000, 16000, 32000, 48000):
        raise ValueError(f"webrtcvad needs sr in (8k,16k,32k,48k). got: {sr}")
    # Any other frame length makes vad.is_speech() fail deep inside the loop.
    if frame_ms not in (10, 20, 30):
        raise ValueError(f"webrtcvad needs frame_ms in (10, 20, 30). got: {frame_ms}")

    frame_len = int(sr * frame_ms / 1000)
    block_len = frame_len * max(1, int(frames_per_read))

    speech_flags = []
    with sf.SoundFile(str(path), mode="r") as f:
        while True:
            block = f.read(frames=block_len, dtype="int16", always_2d=True)
            n_full = len(block) // frame_len
            if n_full:
                # channel 0 only; copy to a contiguous buffer once per block
                mono = np.ascontiguousarray(block[: n_full * frame_len, 0])
                for k in range(n_full):
                    frame = mono[k * frame_len:(k + 1) * frame_len]
                    speech_flags.append(vad.is_speech(frame.tobytes(), sr))
            # a short read means end of file (any leftover partial frame is dropped)
            if len(block) < block_len:
                break

    # merge consecutive True flags into [start_frame, end_frame) runs
    regions = []
    in_speech = False
    start_i = 0
    for i, is_speech in enumerate(speech_flags):
        if is_speech and not in_speech:
            in_speech = True
            start_i = i
        elif (not is_speech) and in_speech:
            in_speech = False
            regions.append((start_i, i))
    if in_speech:
        regions.append((start_i, len(speech_flags)))

    out = []
    for s_i, e_i in regions:
        s = (s_i * frame_len) / sr
        e = (e_i * frame_len) / sr
        if (e - s) * 1000 >= min_region_ms:
            out.append((float(s), float(e)))
    return out


In [None]:
info = sf.info(str(wav_path))
full_duration = float(info.duration)

vad_intervals = webrtc_vad_regions_streaming(
    wav_path,
    aggressiveness=VAD_AGGR,
    frame_ms=VAD_FRAME_MS,
    min_region_ms=VAD_MIN_REGION_MS
)

print("Full duration (hours):", full_duration / 3600)
print("VAD intervals:", len(vad_intervals), "first:", vad_intervals[:3])

# Explicit columns so the frame is well-formed (and sortable) even when VAD finds no speech.
VAD_COLUMNS = [
    "participant_id", "session_date", "recording_id", "wav_path",
    "start_sec", "end_sec", "duration_sec",
]
vad_df = (
    pd.DataFrame(
        [{
            "participant_id": participant_id,
            "session_date": session_date,
            "recording_id": recording_id,
            "wav_path": str(wav_path),
            "start_sec": s,
            "end_sec": e,
            "duration_sec": e - s,
        } for s, e in vad_intervals],
        columns=VAD_COLUMNS,
    )
    .sort_values(["start_sec", "end_sec"])
    .reset_index(drop=True)
)
vad_df.head()


In [None]:
vad_df.iloc[:20]  # first 20 VAD regions

In [None]:
import soundfile as sf
import numpy as np
from pathlib import Path

def write_wav_chunk(
    wav_path: Path,
    chunk_path: Path,
    start_sec: float,
    end_sec: float
):
    """Copy the [start_sec, end_sec) span of `wav_path` into a 16-bit PCM WAV at `chunk_path`.

    The span is converted to sample indices (truncated) and clamped to the file.
    Returns `chunk_path` on success. Returns None when the span contains no samples,
    or when any step fails; a warning is printed in the failure case.
    """
    try:
        meta = sf.info(str(wav_path))
        sr = meta.samplerate

        first = max(0, int(start_sec * sr))
        last = min(meta.frames, int(end_sec * sr))
        if last <= first:  # nothing to write
            return None

        audio, _ = sf.read(str(wav_path), start=first, frames=last - first, dtype="float32")
        if audio is None or audio.size == 0:
            return None

        chunk_path.parent.mkdir(parents=True, exist_ok=True)
        sf.write(str(chunk_path), audio, sr, format="WAV", subtype="PCM_16")
        return chunk_path

    except Exception as e:
        print(f"[WARN] write_wav_chunk failed ({chunk_path.name}): {e}")
        return None


In [None]:
import pandas as pd
from pathlib import Path

all_turn_rows = []

tmp_chunks_dir = Path("/scratch/users") / Path.home().name / "hindibabynet_tmp_chunks"
tmp_chunks_dir.mkdir(parents=True, exist_ok=True)

# Chunks overlap by OVERLAP_SEC so the model has context at chunk edges. Without
# trimming, every turn inside an overlap is emitted twice (once per chunk, with
# unrelated chunk-local speaker labels) and double-counted in every statistic and in
# the TextGrid. With trimming on, each chunk keeps only the timeline up to the midpoint
# of its overlaps, so the kept parts tile the recording exactly once.
TRIM_CHUNK_OVERLAP = True

TURN_COLUMNS = [
    "participant_id", "session_date", "recording_id", "wav_path",
    "chunk_id", "chunk_start_sec", "chunk_end_sec", "speaker_id_local",
    "start_sec", "end_sec", "duration_sec", "chunk_wav_path",
]

for chunk_id, chunk_start, chunk_end in make_chunks(full_duration, CHUNK_SEC, OVERLAP_SEC):

    chunk_wav = tmp_chunks_dir / f"{wav_path.stem}_chunk{chunk_id:04d}_{int(chunk_start)}_{int(chunk_end)}.wav"

    if TRIM_CHUNK_OVERLAP:
        half_overlap = OVERLAP_SEC / 2.0
        keep_start = chunk_start if chunk_id == 0 else chunk_start + half_overlap
        keep_end = chunk_end if chunk_end >= full_duration else chunk_end - half_overlap
    else:
        keep_start, keep_end = float("-inf"), float("inf")

    # skip if zero samples / empty / write failed
    out = write_wav_chunk(wav_path, chunk_wav, chunk_start, chunk_end)
    if out is None:
        continue

    try:
        diar_chunk = pipeline(
            {"audio": str(chunk_wav)},
            min_speakers=MIN_SPEAKERS,
            max_speakers=MAX_SPEAKERS
        )

        for seg, _, spk in diar_chunk.itertracks(yield_label=True):
            # convert chunk-local to GLOBAL times, then clip to this chunk's owned span
            s = max(float(seg.start) + float(chunk_start), keep_start)
            e = min(float(seg.end) + float(chunk_start), keep_end)
            if e <= s:
                continue

            all_turn_rows.append({
                "participant_id": participant_id,
                "session_date": session_date,
                "recording_id": recording_id,
                "wav_path": str(wav_path),
                "chunk_id": int(chunk_id),
                "chunk_start_sec": float(chunk_start),
                "chunk_end_sec": float(chunk_end),
                "speaker_id_local": spk,
                "start_sec": s,
                "end_sec": e,
                "duration_sec": float(e - s),
                # NOTE: the file is deleted in the finally block below; kept for provenance only
                "chunk_wav_path": str(chunk_wav),
            })

    except Exception as e:
        print(f"[WARN] Diarization failed for chunk {chunk_id} ({chunk_wav.name}): {e}")

    finally:
        # delete chunk immediately to avoid scratch filling up
        try:
            chunk_wav.unlink(missing_ok=True)
        except Exception:
            pass

print("Turns collected:", len(all_turn_rows))

# Explicit columns keep the frame sortable even if no turns were collected.
turns_df = (
    pd.DataFrame(all_turn_rows, columns=TURN_COLUMNS)
      .sort_values(["start_sec", "end_sec"])
      .reset_index(drop=True)
)

turns_df.head()


In [None]:
def intersect_turns_with_vad(turns_df: pd.DataFrame, vad_intervals, min_keep_sec: float = 0.0):
    """
    Intersect every diarization turn with every VAD speech region it overlaps.

    Diarization turns may overlap each other (overlapped speech, chunk overlaps).
    Each turn is therefore matched independently against the VAD regions, instead of
    using a single two-pointer sweep, which skips turns nested inside an earlier,
    longer turn.

    Args:
        turns_df: needs columns start_sec, end_sec, chunk_id, speaker_id_local.
        vad_intervals: iterable of (start_sec, end_sec) pairs. Binary search is used
            when their ends are monotone after sorting by start (true for
            non-overlapping VAD output); otherwise every region is checked.
        min_keep_sec: drop intersection pieces shorter than this.

    Returns:
        DataFrame with columns start_sec, end_sec, duration_sec, chunk_id,
        speaker_id_local, with one row per (turn, VAD region) piece. It is empty (but
        with these columns) when either input is empty.
    """
    out_cols = ["start_sec", "end_sec", "duration_sec", "chunk_id", "speaker_id_local"]

    turns_df = turns_df.sort_values(["start_sec", "end_sec"]).reset_index(drop=True)
    vad_arr = np.array(sorted(vad_intervals), dtype=float).reshape(-1, 2)
    if turns_df.empty or len(vad_arr) == 0:
        return pd.DataFrame([], columns=out_cols)

    vad_s = vad_arr[:, 0]
    vad_e = vad_arr[:, 1]
    ends_monotone = bool(np.all(np.diff(vad_e) >= 0))

    rows = []
    cols = turns_df[["start_sec", "end_sec", "chunk_id", "speaker_id_local"]]
    for ds, de, cid, spk in cols.itertuples(index=False, name=None):
        ds, de = float(ds), float(de)
        # candidate regions: vad_e > ds (lo onward) and vad_s < de (before hi)
        lo = int(np.searchsorted(vad_e, ds, side="right")) if ends_monotone else 0
        hi = int(np.searchsorted(vad_s, de, side="left"))
        for k in range(lo, hi):
            s = max(ds, float(vad_s[k]))
            e = min(de, float(vad_e[k]))
            if s >= e:
                continue
            dur = e - s
            if dur >= min_keep_sec:
                rows.append({
                    "start_sec": s,
                    "end_sec": e,
                    "duration_sec": dur,
                    "chunk_id": int(cid),
                    "speaker_id_local": str(spk),
                })

    return pd.DataFrame(rows, columns=out_cols)


In [None]:
speech_only_df = intersect_turns_with_vad(
    turns_df=turns_df,
    vad_intervals=vad_intervals,
    min_keep_sec=MIN_KEEP_SEC,
)

# Prepend the recording-level metadata (constant for every row), in this order.
recording_meta = {
    "participant_id": participant_id,
    "session_date": session_date,
    "recording_id": recording_id,
    "wav_path": str(wav_path),
}
for position, (column, value) in enumerate(recording_meta.items()):
    speech_only_df.insert(position, column, value)

FINAL_COLUMNS = [
    "participant_id", "session_date", "recording_id", "wav_path",
    "chunk_id", "start_sec", "end_sec", "duration_sec",
    "speaker_id_local",
]
final_df_full = (
    speech_only_df[FINAL_COLUMNS]
    .sort_values(["start_sec", "end_sec"])
    .reset_index(drop=True)
)

final_df_full.head(), len(final_df_full)


In [None]:
final_df_full.iloc[:20]  # first 20 diarization∩VAD segments

In [None]:
# Total kept speech per (chunk, chunk-local speaker), largest first.
chunk_speaker_totals = final_df_full.groupby(["chunk_id", "speaker_id_local"])["duration_sec"].sum()
chunk_speaker_totals.sort_values(ascending=False).head(20)

In [None]:
import numpy as np
import pandas as pd

# Segment durations (seconds) of the diarization∩VAD output; reused by later cells.
dur = final_df_full["duration_sec"].astype(float)

DUR_PERCENTILES = [0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99]
summary = dur.describe(percentiles=DUR_PERCENTILES).to_frame().T
summary


In [None]:
# How many segments (count and %) are shorter than each candidate minimum duration.
thresholds = [0.2, 0.3, 0.5, 0.8, 1.0, 2.0, 5.0]

counts = []
for t in thresholds:
    below = dur < t  # computed once per threshold
    counts.append({
        "threshold_sec": t,
        "n_segments": int(below.sum()),
        "pct_segments": float(below.mean() * 100),
    })

pd.DataFrame(counts)


In [None]:
# Half-open duration bins [lo, hi) in seconds; the last bin is open-ended.
DUR_BINS = [0, 0.2, 0.5, 1, 2, 5, 10, np.inf]
DUR_LABELS = ["<0.2", "0.2-0.5", "0.5-1", "1-2", "2-5", "5-10", "10+"]

binned_df = final_df_full.assign(
    dur_bin=pd.cut(final_df_full["duration_sec"], bins=DUR_BINS, labels=DUR_LABELS, right=False)
)

by_bin = (
    # observed=False keeps empty bins; passing it explicitly avoids pandas' FutureWarning
    binned_df.groupby("dur_bin", observed=False)
       .agg(
           n_segments=("duration_sec", "size"),
           total_sec=("duration_sec", "sum"),
           mean_sec=("duration_sec", "mean"),
           median_sec=("duration_sec", "median"),
       )
       .reset_index()
)

# Share of segments and of total speech time falling in each bin.
by_bin["pct_segments"] = by_bin["n_segments"] / by_bin["n_segments"].sum() * 100
by_bin["pct_time"] = by_bin["total_sec"] / by_bin["total_sec"].sum() * 100

by_bin.sort_values("dur_bin")


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.figure()
plt.hist(dur, bins=80)
plt.xlabel("duration_sec")
plt.ylabel("count")
plt.title("Diar∩VAD segment duration distribution")
plt.show()


In [None]:
# Per (chunk, chunk-local speaker) duration statistics, ordered by total speech time.
SPEAKER_STATS = ["count", "sum", "mean", "median", "max"]
per_speaker = final_df_full.groupby(["chunk_id", "speaker_id_local"])["duration_sec"]
grp = per_speaker.agg(SPEAKER_STATS).sort_values("sum", ascending=False)

grp.head(20)


In [None]:
import pandas as pd

df = final_df_full.copy()

overlap_rows = []

# Pairwise cross-speaker overlaps, searched within each chunk only: speaker labels are
# chunk-local, so comparing labels across chunks would be meaningless.
for chunk_id, g in df.groupby("chunk_id"):
    g = g.sort_values("start_sec").reset_index(drop=True)
    # Pull the columns into arrays once. Per-cell g.loc lookups inside this double
    # loop are much slower on long recordings.
    starts = g["start_sec"].to_numpy()
    ends = g["end_sec"].to_numpy()
    speakers = g["speaker_id_local"].to_numpy()
    n_seg = len(g)

    for i in range(n_seg):
        si, ei, spki = starts[i], ends[i], speakers[i]

        for j in range(i + 1, n_seg):
            sj, ej, spkj = starts[j], ends[j], speakers[j]

            # rows are sorted by start, so no later segment can overlap segment i
            if sj >= ei:
                break

            if spki != spkj:
                overlap_start = max(si, sj)
                overlap_end = min(ei, ej)

                if overlap_end > overlap_start:
                    overlap_rows.append({
                        "chunk_id": chunk_id,
                        "speaker_1": spki,
                        "speaker_2": spkj,
                        "seg1_start": si,
                        "seg1_end": ei,
                        "seg2_start": sj,
                        "seg2_end": ej,
                        "overlap_start": overlap_start,
                        "overlap_end": overlap_end,
                        "overlap_dur": overlap_end - overlap_start,
                    })


In [None]:
# Explicit columns so the following cells (describe, binning, ...) also work when no
# cross-speaker overlap was found.
OVERLAP_COLUMNS = [
    "chunk_id", "speaker_1", "speaker_2",
    "seg1_start", "seg1_end", "seg2_start", "seg2_end",
    "overlap_start", "overlap_end", "overlap_dur",
]
overlap_df = pd.DataFrame(overlap_rows, columns=OVERLAP_COLUMNS)

len(overlap_df), overlap_df.head()


In [None]:
overlap_df.iloc[:20]  # first 20 overlapping segment pairs

In [None]:
OVERLAP_PERCENTILES = [0.5, 0.75, 0.9, 0.95, 0.99]
overlap_df["overlap_dur"].describe(percentiles=OVERLAP_PERCENTILES)


In [None]:
# Each overlap involves two segments. Stack both sides into one
# (chunk_id, speaker_id_local, start_sec, end_sec) table and keep unique segments.
segments_in_overlap = pd.concat(
    [
        overlap_df[["chunk_id", f"speaker_{side}", f"seg{side}_start", f"seg{side}_end"]].rename(
            columns={
                f"speaker_{side}": "speaker_id_local",
                f"seg{side}_start": "start_sec",
                f"seg{side}_end": "end_sec",
            }
        )
        for side in (1, 2)
    ]
).drop_duplicates()

len(segments_in_overlap)


In [None]:
# Share (%) of diar∩VAD segments involved in at least one cross-speaker overlap;
# NaN instead of ZeroDivisionError when there are no segments.
(len(segments_in_overlap) / len(df) * 100) if len(df) else float("nan")


In [None]:
# Half-open bins [lo, hi). The last edge is +inf so overlaps >= 10 s are counted in
# "5+" rather than silently becoming NaN and dropping out of the table. Separate names
# avoid clobbering the duration-bin globals from the earlier cell.
OVERLAP_BINS = [0, 0.2, 0.5, 1, 2, 5, np.inf]
OVERLAP_LABELS = ["<0.2", "0.2-0.5", "0.5-1", "1-2", "2-5", "5+"]

overlap_df["dur_bin"] = pd.cut(
    overlap_df["overlap_dur"],
    bins=OVERLAP_BINS,
    labels=OVERLAP_LABELS,
    right=False
)

(
    overlap_df
    # observed=False keeps empty bins; explicit to avoid pandas' FutureWarning
    .groupby("dur_bin", observed=False)
    .agg(
        n_overlaps=("overlap_dur", "size"),
        total_overlap_sec=("overlap_dur", "sum"),
    )
    .assign(
        pct_overlaps=lambda x: x["n_overlaps"] / x["n_overlaps"].sum() * 100,
        pct_time=lambda x: x["total_overlap_sec"] / x["total_overlap_sec"].sum() * 100,
    )
    .sort_index()
)


In [None]:
# Ten longest cross-speaker overlaps, for manual inspection.
overlap_df.sort_values("overlap_dur", ascending=False).head(10)


In [None]:
from pathlib import Path
import soundfile as sf
import pandas as pd
from praatio import textgrid as tgio


def _make_interval_tier(name, entries, xmin, xmax):
    """
    Create IntervalTier across praatio versions.

    Tries three constructor call styles in order, moving to the next only when the
    previous one raises TypeError:
      1. positional (name, entries, minT, maxT)
      2. keywords name= / entries= / minT= / maxT=
      3. keywords name= / entryList= / minT= / maxT=  (presumably older praatio — TODO confirm)

    NOTE: a TypeError raised for any other reason inside attempts 1-2 also triggers
    the fallback. A TypeError from attempt 3 propagates to the caller.

    Args:
        name: tier name (converted to str).
        entries: list of (start, end, label) tuples, as built by df_to_textgrid_by_speaker.
        xmin, xmax: tier time span in seconds.
    """
    # IntervalTier(name, entries, minT, maxT)
    try:
        return tgio.IntervalTier(str(name), entries, xmin, xmax)
    except TypeError:
        pass

    # IntervalTier(name, entries=..., minT=..., maxT=...)
    try:
        return tgio.IntervalTier(name=str(name), entries=entries, minT=xmin, maxT=xmax)
    except TypeError:
        pass

    # IntervalTier(name, entryList=..., minT=..., maxT=...)
    return tgio.IntervalTier(name=str(name), entryList=entries, minT=xmin, maxT=xmax)


def df_to_textgrid_by_speaker(
    df: pd.DataFrame,
    wav_path: Path,
    out_textgrid_path: Path,
    start_col: str = "start_sec",
    end_col: str = "end_sec",
    speaker_col: str = "speaker_id",
    label_col: str | None = None,  # None -> label = speaker_id
):
    """
    Write a Praat TextGrid with one interval tier per distinct value of `speaker_col`.

    The TextGrid spans [0, duration of wav_path]. Rows with end <= start are dropped,
    times are clamped to that span, and within each tier intervals are trimmed so they
    never overlap: an interval starting before the previous one ended is moved to
    start at that end, and dropped if nothing remains. Labels are the speaker value,
    or the `label_col` value when given. includeBlankSpaces=True asks praatio to
    fill the gaps between intervals.

    Returns:
        Path of the written TextGrid.
    """
    wav_path = Path(wav_path)
    out_textgrid_path = Path(out_textgrid_path)

    xmin, xmax = 0.0, float(sf.info(str(wav_path)).duration)

    def _clamp(t: float) -> float:
        return max(xmin, min(t, xmax))

    valid = df[df[end_col] > df[start_col]].sort_values([speaker_col, start_col, end_col])

    tg = tgio.Textgrid()
    tg.minTimestamp = xmin
    tg.maxTimestamp = xmax

    for spk, group in valid.groupby(speaker_col):
        intervals = []
        for rec in group.itertuples(index=False):
            s = _clamp(float(getattr(rec, start_col)))
            e = _clamp(float(getattr(rec, end_col)))
            if e > s:
                label = str(spk) if label_col is None else str(getattr(rec, label_col))
                intervals.append((s, e, label))

        intervals.sort(key=lambda iv: (iv[0], iv[1]))

        # trim so each interval starts no earlier than the previous kept one ended
        non_overlapping = []
        prev_end = -1.0
        for s, e, label in intervals:
            s = max(s, prev_end)
            if e > s:
                non_overlapping.append((s, e, label))
                prev_end = e

        tg.addTier(_make_interval_tier(str(spk), non_overlapping, xmin, xmax))

    tg.save(str(out_textgrid_path), format="short_textgrid", includeBlankSpaces=True)
    return out_textgrid_path


In [None]:
import numpy as np

tmp_dir = Path("/scratch/users") / Path.home().name / "hindibabynet_tmp"
tmp_dir.mkdir(parents=True, exist_ok=True)

# final_df_full has no "speaker_id" column (the function's default speaker_col), so it
# must be told which column to use. speaker_id_local labels (SPEAKER_00, ...) are
# assigned independently per chunk, so the same label in two chunks need not be the
# same person. Use a chunk-qualified tier name so each (chunk, speaker) gets its own tier.
textgrid_df = final_df_full.assign(
    speaker_tier=(
        final_df_full["chunk_id"].astype(int).map("chunk{:04d}".format)
        + "_"
        + final_df_full["speaker_id_local"].astype(str)
    )
)

tg_path = df_to_textgrid_by_speaker(
    df=textgrid_df,                      # diarization ∩ VAD DataFrame
    wav_path=wav_path,                   # full audio file
    out_textgrid_path=tmp_dir / f"{wav_path.stem}.TextGrid",
    speaker_col="speaker_tier",
)

tg_path

In [None]:
# --------------------------
# CLEANUP: remove all chunk WAVs at once
# --------------------------
import shutil
from pathlib import Path

tmp_chunks_dir = Path("/scratch/users") / Path.home().name / "hindibabynet_tmp_chunks"

# Chunk WAVs are deleted one by one during diarization; this removes the directory itself
# and anything left behind.
if not tmp_chunks_dir.exists():
    print("No temporary chunk directory found.")
else:
    shutil.rmtree(tmp_chunks_dir)
    print(f"Deleted temporary chunk directory: {tmp_chunks_dir}")

In [None]:
import pandas as pd
import numpy as np

def merge_close_segments(
    df: pd.DataFrame,
    gap_thresh: float = 0.5
) -> pd.DataFrame:
    """
    Merge consecutive segments of the same speaker that are separated by small gaps.

    Rows are ordered by recording keys, chunk, speaker and time. A row continues the
    previous merged segment when all keys (participant, session, recording, wav,
    chunk, speaker) match and 0 <= gap <= gap_thresh. Otherwise, including a
    negative gap, it starts a new segment.

    Returns a new DataFrame (input is not modified) with one row per merged segment:
    _merge_group, the key columns, start_sec (min), end_sec (max), n_segments,
    orig_rows (input index labels merged into it) and duration_sec. It is sorted by
    participant, session, recording, chunk and start time.
    """
    key_cols = [
        "participant_id", "session_date", "recording_id", "wav_path",
        "chunk_id", "speaker_id_local",
    ]

    work = df.copy()
    work["_orig_row"] = work.index  # traceability back to the input rows
    work = work.sort_values(key_cols + ["start_sec", "end_sec"]).reset_index(drop=True)

    gap = work["start_sec"] - work["end_sec"].shift(1)
    key_changed = (work[key_cols] != work[key_cols].shift(1)).any(axis=1)
    starts_new = key_changed | gap.isna() | gap.lt(0) | gap.gt(gap_thresh)
    work["_merge_group"] = starts_new.cumsum()

    merged = work.groupby("_merge_group", as_index=False).agg(
        participant_id=("participant_id", "first"),
        session_date=("session_date", "first"),
        recording_id=("recording_id", "first"),
        wav_path=("wav_path", "first"),
        chunk_id=("chunk_id", "first"),
        speaker_id_local=("speaker_id_local", "first"),
        start_sec=("start_sec", "min"),
        end_sec=("end_sec", "max"),
        n_segments=("_orig_row", "count"),
        orig_rows=("_orig_row", lambda x: list(x)),
    )
    merged["duration_sec"] = merged["end_sec"] - merged["start_sec"]

    order = ["participant_id", "session_date", "recording_id", "chunk_id", "start_sec"]
    return merged.sort_values(order).reset_index(drop=True)


In [None]:
# Same-speaker segments separated by at most this many seconds are merged.
MERGE_GAP_SEC = 0.7
merged_df = merge_close_segments(final_df_full, gap_thresh=MERGE_GAP_SEC)
merged_df  # rich display instead of print()

In [None]:
merged_df.iloc[-20:]  # last 20 merged segments

In [None]:
# Show the current working directory (before the chdir below).
%pwd

In [None]:
import os

# The classifier is loaded from models/ relative to the working directory, which the
# original code reached with a bare chdir(".."). Only step up when models/ is not
# already here but is one level up, so re-running this cell (or Run All after a
# manual chdir) does not keep climbing the directory tree.
if not os.path.isdir("models") and os.path.isdir(os.path.join("..", "models")):
    os.chdir("..")

In [None]:
# Confirm the working directory after the chdir.
%pwd

In [None]:
# Quick check that the classifier file loads; it is loaded again as xgb_model a few cells below.
# NOTE: joblib.load unpickles, which can execute arbitrary code — only load trusted model files.
import joblib
m = joblib.load("models/xgb_egemaps.pkl")
print(type(m))

In [None]:
import os
from pathlib import Path
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
import pandas as pd

import joblib
import opensmile

from tqdm.auto import tqdm


In [None]:
# Path is relative to the working directory set by the chdir cell above.
MODEL_PATH = Path("models/xgb_egemaps.pkl")

assert MODEL_PATH.exists(), f"Model not found: {MODEL_PATH}"
# NOTE: joblib.load unpickles (arbitrary code execution) — only load model files from a trusted source.
xgb_model = joblib.load(MODEL_PATH)

print("Loaded model:", type(xgb_model))


In [None]:
from pathlib import Path
import numpy as np

# ===== Class mapping (match training) =====
# Class name <-> integer id; must match the label encoding used at training time.
LABEL2ID = {
    "adult_male": 0,
    "adult_female": 1,
    "child": 2,
    "background": 3,
}
ID2LABEL = {idx: name for name, idx in LABEL2ID.items()}

# Class names in id order, assumed to be the column order of predict_proba.
CLASS_NAMES = [ID2LABEL[idx] for idx in range(len(ID2LABEL))]
print("CLASS_NAMES:", CLASS_NAMES)

# Sanity-check the loaded model on a single all-zero feature vector.
egemaps_dim = 88  # eGeMAPS functionals dimension (must match training)
dummy_X = np.zeros((1, egemaps_dim), dtype=np.float32)
proba = xgb_model.predict_proba(dummy_X)
print("predict_proba shape:", proba.shape)

n_model_classes = proba.shape[1]
assert n_model_classes == len(CLASS_NAMES), (
    f"Model returns {n_model_classes} classes, "
    f"but CLASS_NAMES has {len(CLASS_NAMES)}"
)


In [None]:
# Derive from the training label map (ID2LABEL, defined above) instead of hard-coding a
# second copy that could silently drift. The assert pins the expected order.
CLASS_NAMES = [ID2LABEL[i] for i in range(len(ID2LABEL))]
assert CLASS_NAMES == ["adult_male", "adult_female", "child", "background"], CLASS_NAMES
N_CLASSES = len(CLASS_NAMES)


In [None]:
import soundfile as sf
from scipy.signal import resample_poly

def load_audio_mono(path: str | Path) -> Tuple[np.ndarray, int]:
    """Read a whole audio file and return (float32 mono samples, sample_rate).

    Multi-channel audio is averaged across channels.
    """
    audio, sr = sf.read(str(path), always_2d=False)
    mono = audio.mean(axis=1) if audio.ndim == 2 else audio
    return mono.astype(np.float32, copy=False), sr

def resample_audio(x: np.ndarray, sr: int, target_sr: int) -> np.ndarray:
    """Polyphase-resample `x` from `sr` to `target_sr` (float32); returns `x` itself if rates match."""
    if sr != target_sr:
        g = np.gcd(sr, target_sr)
        return resample_poly(x, target_sr // g, sr // g).astype(np.float32, copy=False)
    return x

def crop_or_pad(x: np.ndarray, target_len: int) -> np.ndarray:
    """Return `x` truncated to `target_len`, or zero-padded at the end to that length (float32).

    `x` itself is returned when it already has the target length.
    """
    n = len(x)
    if n >= target_len:
        return x if n == target_len else x[:target_len]
    padded = np.zeros(target_len, dtype=np.float32)
    padded[:n] = x
    return padded


In [None]:
def generate_windows(start: float, end: float, win: float = 1.0, hop: float = 0.5) -> List[Tuple[float, float, float]]:
    """
    Split the segment [start, end] into fixed-length analysis windows.

    Returns a list of (w_start, w_end, weight_duration) tuples. Every window is exactly
    `win` seconds long and has weight `win`.

    - Segments shorter than `win` (or with end <= start) get a single window
      (start, start + win). NOTE: that window extends past `end`, so it contains
      whatever audio follows the segment. It is zero-padded only where it runs off the
      end of the file.
    - Longer segments get windows at start, start + hop, start + 2*hop, ... that fit
      inside [start, end], plus one end-anchored window (end - win, end) if the last
      regular window stops short of `end`.

    Raises:
        ValueError: if win <= 0 or hop <= 0 (hop <= 0 would loop forever).
    """
    if win <= 0 or hop <= 0:
        raise ValueError(f"win and hop must be > 0 (got win={win}, hop={hop})")

    dur = end - start

    # too short -> single window starting at the segment start
    if dur <= 0 or dur < win:
        return [(start, start + win, win)]

    # start + k * hop (instead of repeatedly adding hop) avoids floating-point drift
    windows = []
    k = 0
    while start + k * hop + win <= end:
        w_start = start + k * hop
        windows.append((w_start, w_start + win, win))
        k += 1

    # End-anchored last window if needed
    if not windows or windows[-1][1] < end:
        windows.append((end - win, end, win))

    return windows


In [None]:
@dataclass
class EGemapsExtractor:
    """Holds an openSMILE eGeMAPSv02 functionals extractor plus the settings used around it.

    Fields:
        egemaps_dim: expected length of a feature vector (must match training).
        target_sr: sample rate audio is resampled to before extraction.
        win_sec: analysis-window length in seconds.
    """
    egemaps_dim: int = 88
    target_sr: int = 16000
    win_sec: float = 1.0

    def __post_init__(self):
        # openSMILE extractor, built once per instance
        self.smile = opensmile.Smile(
            feature_set=opensmile.FeatureSet.eGeMAPSv02,
            feature_level=opensmile.FeatureLevel.Functionals,
        )

    def _fix_dim(self, vec: np.ndarray) -> np.ndarray:
        """Flatten to float32 and zero-pad / truncate to exactly `egemaps_dim` values."""
        flat = vec.astype(np.float32).flatten()
        n = flat.shape[0]
        if n == self.egemaps_dim:
            return flat
        fixed = np.zeros(self.egemaps_dim, dtype=np.float32)
        keep = min(n, self.egemaps_dim)
        fixed[:keep] = flat[:keep]
        return fixed

# {(wav_path, target_sr): (audio_at_target_sr, target_sr)}
# The key includes the sample rate. A path-only key would hand back audio at whatever
# rate was cached first, even when the caller asks for a different target_sr.
# NOTE: entries are never evicted; each holds a whole recording as float32.
_AUDIO_CACHE: Dict[Tuple[str, int], Tuple[np.ndarray, int]] = {}

def load_audio_16k_cached(wav_path: str | Path, target_sr: int = 16000) -> Tuple[np.ndarray, int]:
    """Return (mono float32 audio resampled to target_sr, target_sr), decoding each file once per rate."""
    key = (str(wav_path), int(target_sr))
    cached = _AUDIO_CACHE.get(key)
    if cached is not None:
        return cached

    x, sr = load_audio_mono(key[0])
    x = resample_audio(x, sr, target_sr)
    _AUDIO_CACHE[key] = (x, target_sr)
    return x, target_sr

def extract_egemaps_for_window(
    extractor: EGemapsExtractor,
    wav_path: str | Path,
    start_sec: float,
    end_sec: float
) -> np.ndarray:
    """
    Extract one eGeMAPS functionals vector for the audio in [start_sec, end_sec].

    The slice comes from the cached, resampled recording (clamped to the file) and is
    cropped or zero-padded to exactly `extractor.win_sec` seconds before extraction.

    Returns:
        float32 vector of length `extractor.egemaps_dim`. If extraction fails, an
        all-zero vector is returned and a RuntimeWarning is emitted. This stays
        best-effort so one bad window does not abort a long run.
    """
    import warnings

    x16, sr = load_audio_16k_cached(wav_path, extractor.target_sr)

    s = max(0, int(round(start_sec * sr)))
    e = min(len(x16), int(round(end_sec * sr)))

    # x16[s:e] is empty when e <= s; crop_or_pad then yields an all-zero window
    seg = crop_or_pad(x16[s:e], int(sr * extractor.win_sec))

    try:
        feats = extractor.smile.process_signal(seg, sr)
        return extractor._fix_dim(feats.values.flatten())
    except Exception as err:
        # keep going, but make failures visible instead of swallowing them silently
        warnings.warn(
            f"eGeMAPS extraction failed ({type(err).__name__}: {err}); using zero features",
            RuntimeWarning,
        )
        return np.zeros(extractor.egemaps_dim, dtype=np.float32)


In [None]:
def weighted_mean_probs(P: np.ndarray, weights: List[float]) -> np.ndarray:
    """
    Weighted average of per-window class probabilities.

    P: (nwin, C) probabilities; weights: one weight per window.
    Returns a (C,) array; the tiny epsilon only guards against an all-zero weight sum.
    """
    w = np.asarray(weights, dtype=np.float32)[:, None]
    weighted_sum = np.sum(P * w, axis=0)
    return weighted_sum / (np.sum(w) + 1e-12)

def predict_segment_probs(
    row: pd.Series,
    model,
    extractor: EGemapsExtractor,
    win: float = 1.0,
    hop: float = 0.5
) -> Dict[str, Any]:
    """
    Classify one segment by averaging window-level predictions.

    The segment [row.start_sec, row.end_sec] of row.wav_path is split into windows
    with generate_windows. eGeMAPS features are extracted per window and scored with
    `model.predict_proba`, then window probabilities are averaged, weighted by window
    duration.

    Returns dict with:
      - probs: (C,) averaged probabilities
      - n_windows: number of windows used
      - window_durations: per-window weights
    """
    windows = generate_windows(float(row.start_sec), float(row.end_sec), win=win, hop=hop)

    features = [extract_egemaps_for_window(extractor, row.wav_path, ws, we) for ws, we, _ in windows]
    weights = [wdur for _, _, wdur in windows]

    Xw = np.stack(features, axis=0).astype(np.float32)      # (nwin, D)
    Pw = model.predict_proba(Xw).astype(np.float32)         # (nwin, C)

    return {
        "probs": weighted_mean_probs(Pw, weights),          # (C,)
        "n_windows": len(windows),
        "window_durations": weights,
    }


In [None]:
df_in = merged_df.copy()

# sanity check required columns
required_cols = ["wav_path", "start_sec", "end_sec", "duration_sec"]
missing = [c for c in required_cols if c not in df_in.columns]
assert not missing, f"Missing columns in df: {missing}"

extractor = EGemapsExtractor(egemaps_dim=88, target_sr=16000, win_sec=1.0)
# Window length must equal the extractor's crop/pad length, so derive it from there.
WIN_SEC = extractor.win_sec
HOP_SEC = 0.5

probs_out = []
nwin_out = []
wdur_out = []

for _, r in tqdm(df_in.iterrows(), total=len(df_in), desc="Classifying segments"):
    out = predict_segment_probs(r, xgb_model, extractor, win=WIN_SEC, hop=HOP_SEC)
    probs_out.append(out["probs"])
    nwin_out.append(out["n_windows"])
    wdur_out.append(out["window_durations"])

# np.vstack([]) raises, so build an explicit (0, C) array when there are no segments.
P = np.vstack(probs_out) if probs_out else np.empty((0, len(CLASS_NAMES)), dtype=np.float32)  # (N, C)
assert P.shape[1] == len(CLASS_NAMES), f"got {P.shape[1]} probability columns for {len(CLASS_NAMES)} classes"

df_out = df_in.copy()
df_out["n_windows"] = nwin_out
df_out["window_durations"] = wdur_out

for i, cname in enumerate(CLASS_NAMES):
    df_out[f"probs_{cname}"] = P[:, i].astype(float)

pred_idx = np.argmax(P, axis=1)
df_out["predicted_class"] = [CLASS_NAMES[i] for i in pred_idx]
df_out["predicted_confidence"] = P[np.arange(len(df_out)), pred_idx].astype(float)

df_out.head()


In [None]:
df_out.iloc[-20:]  # last 20 classified segments

In [None]:
import numpy as np
import pandas as pd
import soundfile as sf
from scipy.signal import resample_poly
from IPython.display import Audio, display

# _load_mono / _resample were verbatim copies of load_audio_mono / resample_audio from
# the feature-extraction section. Alias them so there is a single implementation.
_load_mono = load_audio_mono
_resample = resample_audio


def _slice_audio(x, sr, start_sec, end_sec):
    """Return the samples of `x` between start_sec and end_sec (rounded to samples, clamped to the array)."""
    s = max(0, int(round(start_sec * sr)))
    e = min(len(x), int(round(end_sec * sr)))
    return x[s:e]


In [None]:
def build_combined_audio(
    df: pd.DataFrame,
    target_sr: int = 16000,
    gap_sec: float = 0.20,
    max_total_sec: Optional[float] = None,     # prevent huge memory usage; set None for no limit
    min_conf: float = 0.0,
    sort_by_time: bool = True,
):
    """
    Concatenate the audio of the selected segments into one clip for listening.

    df must have: wav_path, start_sec, end_sec. If it also has predicted_confidence,
    rows below `min_conf` are dropped. Each recording is loaded once as mono and
    resampled to `target_sr`. Segments are appended in (wav_path, start_sec) order
    when sort_by_time, with `gap_sec` of silence after each one. Collection stops
    entirely at the first segment that would push the total past `max_total_sec`.

    Returns: (audio_array, sr, used_rows_df)

    Raises:
        ValueError: no rows pass the filters, or no audio could be collected.
    """
    use_df = df.copy()

    if min_conf is not None and "predicted_confidence" in use_df.columns:
        use_df = use_df[use_df["predicted_confidence"] >= float(min_conf)]

    if sort_by_time:
        use_df = use_df.sort_values(["wav_path", "start_sec"])

    if use_df.empty:
        raise ValueError("No rows selected to build combined audio.")

    gap = np.zeros(int(target_sr * gap_sec), dtype=np.float32) if gap_sec and gap_sec > 0 else None

    pieces = []
    total = 0.0
    used_rows = []
    limit_hit = False  # once the budget is exhausted, skip the remaining recordings too

    for wav_path, gdf in use_df.groupby("wav_path"):
        x, sr = _load_mono(wav_path)
        x = _resample(x, sr, target_sr)

        for _, r in gdf.iterrows():
            seg = _slice_audio(x, target_sr, float(r.start_sec), float(r.end_sec))
            if len(seg) == 0:
                continue

            seg_dur = len(seg) / target_sr
            if max_total_sec is not None and (total + seg_dur) > float(max_total_sec):
                limit_hit = True
                break

            pieces.append(seg)
            if gap is not None:
                pieces.append(gap)

            total += seg_dur + (gap_sec if gap is not None else 0.0)
            used_rows.append(r)

        if limit_hit or (max_total_sec is not None and total >= float(max_total_sec)):
            break

    if not pieces:
        raise ValueError("No audio collected (all segments empty or max_total_sec too small).")

    y = np.concatenate(pieces).astype(np.float32, copy=False)
    used_rows_df = pd.DataFrame(used_rows)

    print(f"Combined audio duration ≈ {len(y)/target_sr:.1f}s  | segments used = {len(used_rows_df)}")
    return y, target_sr, used_rows_df


In [None]:
# Listen to every segment predicted as adult_female (0.15 s silence between segments).
# NOTE: Audio() embeds the whole clip in the notebook, so with max_total_sec=None a long
# recording can make the .ipynb very large.
sel = df_out[df_out["predicted_class"] == "adult_female"].copy()
if sel.empty:
    # build_combined_audio raises on an empty selection; skip so Run All keeps going
    print("No segments predicted as 'adult_female'.")
else:
    y, sr, used = build_combined_audio(sel, gap_sec=0.15, max_total_sec=None, min_conf=0.0)
    display(Audio(y, rate=sr))

In [None]:
# Listen to every segment predicted as child (0.15 s silence between segments).
# NOTE: Audio() embeds the whole clip in the notebook; long selections inflate the .ipynb.
sel = df_out[df_out["predicted_class"] == "child"].copy()
if sel.empty:
    # build_combined_audio raises on an empty selection; skip so Run All keeps going
    print("No segments predicted as 'child'.")
else:
    y, sr, used = build_combined_audio(sel, gap_sec=0.15, max_total_sec=None, min_conf=0.0)
    display(Audio(y, rate=sr))

In [None]:
# Listen to every segment predicted as background (0.15 s silence between segments).
# NOTE: Audio() embeds the whole clip in the notebook; long selections inflate the .ipynb.
sel = df_out[df_out["predicted_class"] == "background"].copy()
if sel.empty:
    # build_combined_audio raises on an empty selection; skip so Run All keeps going
    print("No segments predicted as 'background'.")
else:
    y, sr, used = build_combined_audio(sel, gap_sec=0.15, max_total_sec=None, min_conf=0.0)
    display(Audio(y, rate=sr))

In [None]:
# Listen to every segment predicted as adult_male (0.15 s silence between segments).
# NOTE: Audio() embeds the whole clip in the notebook; long selections inflate the .ipynb.
sel = df_out[df_out["predicted_class"] == "adult_male"].copy()
if sel.empty:
    # build_combined_audio raises on an empty selection; skip so Run All keeps going
    print("No segments predicted as 'adult_male'.")
else:
    y, sr, used = build_combined_audio(sel, gap_sec=0.15, max_total_sec=None, min_conf=0.0)
    display(Audio(y, rate=sr))

In [None]:
# NOTE(review): clear_output() clears only the output area of the cell it runs in
# (this one), not the outputs of earlier cells. To strip all outputs before sharing,
# use the notebook's "Clear All Outputs" command instead.
from IPython.display import clear_output
clear_output(wait=False)