In [53]:
import os
import sys
import glob
import numpy as np
import pandas as pd
import torch
import torchaudio
import xml.etree.ElementTree as ET
from datetime import datetime
import torch.nn.functional as F

# For clustering speaker embeddings
from sklearn.cluster import AgglomerativeClustering

# For evaluation metrics (pyannote)
from pyannote.metrics.diarization import DiarizationErrorRate
from pydub import AudioSegment
# For ASR using OpenAI Whisper
import whisper

# For speaker embeddings via SpeechBrain
from speechbrain.pretrained import EncoderClassifier

import soundfile as sf

In [54]:
# -------------------------------
# Global Constants and Paths
# -------------------------------
AMI_AUDIO_PATH = "./AMI/ES2008a.wav"  # Mixed audio file for session ES2008a
AMI_ANNOTATIONS_DIR = "./AMI/ES2008a/"  # AMI annotation XML file
OUTPUT_CSV_PATH = "./Outputs/ES2008a_transcript.csv"  # Output CSV file for ASR results
OUTPUT_SPEAKER_EMBEDDINGS_DIR = "./Outputs/ES2008a_embeddings/"  # Directory for speaker embeddings
OUTPUT_SPEAKER_EMBEDDINGS_CSV = "./Outputs/ES2008a_embeddings.csv"  # CSV for speaker embeddings
ENROLLED_TEMPLATES_DIR = "./voxceleb2/"
CUSTOM_AUDIO_PATH = "./custom_hinglish_audio.wav"

In [55]:
TARGET_SR = 16000
DISTANCE_THRESHOLD = 0.015
VAD_THRESHOLD = 0.4
MIN_SPEECH_DURATION = 1.04
RECOGNITION_THRESHOLD = 0.45
GAP_THRESHOLD = 0.15

In [56]:
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [57]:
print("DEVICE", DEVICE)

DEVICE mps


In [58]:
# -------------------------------
# Function: Load Audio File
# -------------------------------
def load_audio(file_path, target_sr=TARGET_SR):
    """
    Load an audio file using torchaudio and resample to target sampling rate if necessary.

    Args:
        file_path (str): Path to the audio file.
        target_sr (int): Target sampling rate (Hz).

    Returns:
        waveform (Tensor): Audio waveform (1, N) as torch.float32.
        sr (int): Sampling rate.
    """
    try:
        file_path = os.path.expanduser(file_path)
        # Check if file exists
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        
        if file_path.lower().endswith((".m4a", ".mp4")):
            # try:
            # waveform, sr = torchaudio.load(file_path,format="mp4")
            audio = AudioSegment.from_file(file_path, format="m4a")
            # Convert AudioSegment to numpy array
            samples = np.array(audio.get_array_of_samples())
            waveform = torch.from_numpy(samples.astype(np.float32) / 32768.0).unsqueeze(0)
            sr = audio.frame_rate
            # except Exception as e:
                # print(f"torchaudio.load failed for {file_path} with format='mp4': {e}")
                # Fallback: use pydub to load the file.
        else:
            import soundfile as sf
            # Set the audio backend to sox_io (recommended on macOS)
            data, sr = sf.read(file_path)
            if len(data.shape) > 1:
                data = data[:, 0]  # use first channel if multi-channel
            waveform = torch.from_numpy(data).unsqueeze(0).float()
        
        # Resample if needed
        if sr != target_sr:
            resampler = torchaudio.transforms.Resample(sr, target_sr)
            waveform = resampler(waveform)
            sr = target_sr
        return waveform, sr
    except Exception as e:
        print(f"Error loading audio file {file_path}: {e}")
        sys.exit(1)

In [59]:
# -------------------------------
# Function: Perform Voice Activity Detection (VAD)
# -------------------------------
def perform_vad(waveform, sr, vad_threshold=VAD_THRESHOLD, min_speech_duration=MIN_SPEECH_DURATION):
    """
    Apply Silero VAD to detect and extract speech segments from the audio.

    Args:
        waveform (Tensor): Input audio waveform (1, N).
        sr (int): Sampling rate.
        vad_threshold (float): Probability threshold for speech.
        min_speech_duration (float): Minimum speech segment duration (in seconds).

    Returns:
        speech_segments (list of dict): Each dict has 'start' and 'end' (in seconds).
    """
    try:
        # Necessary for using a forked repo in Colab
        torch.hub._validate_not_a_forked_repo = lambda repo_owner, repo_name, ref: True
        # Load Silero VAD model and utilities
        models, utils = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad", trust_repo=True)
        (get_speech_ts, _, _, _, _) = utils

        # Run VAD to get speech timestamps
        speech_timestamps = get_speech_ts(waveform, model=models, sampling_rate=sr, threshold=vad_threshold)

        # Determine if the returned timestamps appear to be in samples rather than seconds.
        audio_duration_sec = waveform.shape[1] / sr  # total duration in seconds

        # Filter out very short segments
        speech_segments = []
        if any(seg["start"] > audio_duration_sec for seg in speech_timestamps):
            # Timestamps seem to be in samples; convert to seconds.
            print('Timestamps seem to be in samples; converting to seconds.')
            for seg in speech_timestamps:
                # print(seg)
                # print("Segment Start",seg["start"])
                # print("Segment End",seg["end"])
                seg["start"] = seg["start"] / sr
                seg["end"]   = seg["end"]   / sr
                duration = seg["end"] - seg["start"]
                # print("Duration",duration)
                if duration >= min_speech_duration:
                    speech_segments.append({"start": seg["start"], "end": seg["end"]})
        else:
            # Timestamps are already in seconds.
            print('Timestamps are already in seconds.')
            for seg in speech_timestamps:
                duration = seg["end"] - seg["start"]
                # print("Duration",duration)
                if duration >= min_speech_duration:
                    speech_segments.append({"start": seg["start"], "end": seg["end"]})
        return speech_segments
    except Exception as e:
        print(f"Error during VAD: {e}")
        sys.exit(1)


In [60]:
# -------------------------------
# Function: Extract Audio Segments from VAD Output
# -------------------------------
def extract_audio_segments(waveform, sr, segments):
    """
    Extract waveform segments corresponding to the speech timestamps.

    Args:
        waveform (Tensor): Full audio waveform (1, N).
        sr (int): Sampling rate.
        segments (list of dict): List with 'start' and 'end' in seconds.

    Returns:
        list: List of Tensors corresponding to each speech segment.
    """
    segment_list = []
    for seg in segments:
        start_sample = int(seg["start"] * sr)
        end_sample = int(seg["end"] * sr)
        segment_list.append(waveform[:, start_sample:end_sample])
    return segment_list


In [61]:
# -------------------------------
# Function: Extract Speaker Embeddings for Each Segment
# -------------------------------
def extract_embeddings(segments_audio, spk_model, sr=TARGET_SR):
    """
    Extract speaker embeddings for each speech segment.
    Pads segments shorter than 1 second.
    """
    embeddings = []
    min_samples = int(1.0 * sr)  # minimum 1 second length
    for seg in segments_audio:
        try:
            # Pad segment if too short
            if seg.shape[1] < min_samples:
                pad_amount = min_samples - seg.shape[1]
                seg = F.pad(seg, (0, pad_amount))
            seg = seg.to(DEVICE)
            with torch.no_grad():
                emb = spk_model.encode_batch(seg)
            embeddings.append(emb.squeeze().cpu().numpy())
        except Exception as e:
            print(f"Error extracting embedding for a segment: {e}")
    if not embeddings:
        print("No embeddings extracted!")
        sys.exit(1)
    return np.vstack(embeddings)


In [62]:
# -------------------------------
# Function: Cluster Speaker Embeddings (Diarization)
# -------------------------------
def cluster_embeddings(embeddings, distance_threshold=DISTANCE_THRESHOLD):
    """
    Cluster speaker embeddings using Agglomerative Clustering with cosine distance.

    Args:
        embeddings (np.ndarray): Array of speaker embeddings (n_segments, embedding_dim).
        distance_threshold (float): Threshold to decide cluster merging.

    Returns:
        np.ndarray: Array of cluster labels for each segment.
    """
    try:
        # Using AgglomerativeClustering with distance_threshold requires n_clusters=None.
        clusterer = AgglomerativeClustering(
    n_clusters=None, metric="cosine", linkage="average", distance_threshold=distance_threshold
)
        labels = clusterer.fit_predict(embeddings)
        return labels
    except Exception as e:
        print(f"Error during clustering: {e}")
        sys.exit(1)


In [63]:
# -------------------------------
# Function: Assign Generic Speaker Labels
# -------------------------------
def assign_speaker_labels(cluster_labels):
    """
    Assign generic speaker names (e.g., Speaker 1, Speaker 2, ...) based on cluster labels.

    Args:
        cluster_labels (np.ndarray): Array of integer cluster labels.

    Returns:
        list: List of speaker label strings corresponding to each segment.
    """
    unique_labels = sorted(set(cluster_labels))
    label_map = {label: f"Speaker {i+1}" for i, label in enumerate(unique_labels)}
    speaker_labels = [label_map[label] for label in cluster_labels]
    return speaker_labels


In [64]:
def collect_voxceleb2_files(root_dir):
    """
    Recursively traverse the VoxCeleb2 AAC directory to collect all .m4a file paths 
    and map them to speaker IDs. This function assumes that the directory structure is:
        <root_dir>/aac/<speaker_id>/<subfolder>/.../<filename>.m4a
    and that the speaker ID is the folder immediately following the "aac" folder.
    
    Args:
        root_dir (str): Root directory of the VoxCeleb2 data.
    
    Returns:
        dict: Mapping from speaker_id (str) to a list of file paths (str).
    """
    templates_paths = {}
    aac_path = os.path.join(root_dir, "aac")
    if not os.path.exists(aac_path):
        print(f"AAC folder not found in VoxCeleb2 root directory: {root_dir}")
        return templates_paths
    # Walk through the 'aac' folder recursively
    for dirpath, _, filenames in os.walk(aac_path):
        for filename in filenames:
            if filename.lower().endswith(".m4a"):
                full_path = os.path.join(dirpath, filename)
                # Split the path and extract the speaker_id as the folder immediately after "aac"
                parts = os.path.normpath(full_path).split(os.sep)
                try:
                    idx = parts.index("aac")
                    # The speaker_id is assumed to be the next folder after "aac"
                    speaker_id = parts[idx + 1]
                except (ValueError, IndexError):
                    speaker_id = "Unknown"
                if speaker_id not in templates_paths:
                    templates_paths[speaker_id] = []
                templates_paths[speaker_id].append(full_path)
    return templates_paths

In [65]:
def load_voxceleb2_templates(root_dir, spk_model, target_sr=TARGET_SR):
    """
    Load enrolled speaker templates from VoxCeleb2 data.
    For each speaker, it uses the first available audio file (organized under the "aac" folder)
    to compute a reference embedding using the pretrained speaker embedding model.
    
    Args:
        root_dir (str): Root directory of VoxCeleb2 (should contain the "aac" subfolder).
        spk_model: Pretrained SpeechBrain speaker embedding model.
        target_sr (int): Desired sampling rate.
    
    Returns:
        dict: Mapping from speaker_id to enrollment embedding (numpy array).
    """
    enrolled_templates = {}
    templates_paths = collect_voxceleb2_files(root_dir)
    for speaker_id, file_list in templates_paths.items():
        if not file_list:
            continue
        embeddings_list = []
        # For each speaker, use the first audio file in the list for enrollment.
        for template_path in file_list[:10]:
            try:
                waveform, sr = load_audio(template_path, target_sr=target_sr)
                # If the audio is too long, extract a centered two-second chunk.
                if waveform.shape[1] > target_sr * 2:
                    mid = waveform.shape[1] // 2
                    waveform = waveform[:, mid - target_sr: mid + target_sr]
                with torch.no_grad():
                    emb = spk_model.encode_batch(waveform.to(DEVICE)).squeeze().cpu().numpy()
                enrolled_templates[speaker_id] = emb
            except Exception as e:
                print(f"Error processing VoxCeleb2 file {template_path}: {e}")
        if embeddings_list:
            # Compute the average embedding for a more robust enrollment representation.
            enrolled_templates[speaker_id] = np.mean(embeddings_list, axis=0)
    return enrolled_templates

In [66]:
# -------------------------------
# Function: load_enrolled_templates for speaker recognition
# -------------------------------
def load_enrolled_templates(enrolled_dir, spk_model, target_sr=TARGET_SR):
    """
    Load enrolled speaker audio files and extract embeddings.
    
    Args:
        enrolled_dir (str): Directory containing enrolled WAV files.
        spk_model: Pretrained speaker encoder.
        target_sr (int): Sampling rate.
    
    Returns:
        Dictionary mapping file base names to embedding vectors.
    """
    enrolled_templates = {}
    audio_files = glob.glob(os.path.join(enrolled_dir, "*.wav"))
    if not audio_files:
        print("No enrolled templates found. Speaker recognition will use generic labels.")
        return enrolled_templates
    for file in audio_files:
        try:
            waveform, sr = load_audio(file, target_sr=target_sr)
            # Use a central 2-second chunk if audio is too long
            if waveform.shape[1] > sr * 2:
                mid = waveform.shape[1] // 2
                waveform = waveform[:, mid - sr: mid + sr]
            with torch.no_grad():
                emb = spk_model.encode_batch(waveform.to(DEVICE)).squeeze().cpu().numpy()
            key = os.path.splitext(os.path.basename(file))[0]
            enrolled_templates[key] = emb
        except Exception as e:
            print(f"Error processing enrolled template {file}: {e}")
    return enrolled_templates


In [67]:
# -------------------------------
# Function: perform_speaker_recognition
# -------------------------------
def perform_speaker_recognition(embeddings, cluster_labels, enrolled_templates, recognition_threshold=RECOGNITION_THRESHOLD):
    """
    For each cluster, compute the mean embedding and compare to enrolled templates using cosine similarity.
    
    Args:
        embeddings (np.ndarray): All speaker embeddings.
        cluster_labels (np.ndarray): Cluster labels.
        enrolled_templates (dict): Mapping of enrolled identity to embedding.
        recognition_threshold (float): Cosine similarity threshold.
    
    Returns:
        Dictionary mapping each cluster label to a recognized identity.
    """
    recognized_ids = {}
    if not enrolled_templates:
        for label in sorted(set(cluster_labels)):
            recognized_ids[label] = f"Speaker {label+1}"
        return recognized_ids

    unique_labels = sorted(set(cluster_labels))
    for label in unique_labels:
        # Compute the mean embedding for the current cluster
        idx = np.where(cluster_labels == label)[0]
        cluster_embs = embeddings[idx]
        mean_emb = np.mean(cluster_embs, axis=0)

        # Normalize the mean embedding to unit vector
        norm_mean_emb = mean_emb / (np.linalg.norm(mean_emb) + 1e-8)

        best_score = -1.0
        best_id = None
        for identity, temp_emb in enrolled_templates.items():
            norm_temp_emb = temp_emb / (np.linalg.norm(temp_emb) + 1e-8)
            cos_sim = np.dot(norm_mean_emb, norm_temp_emb)
            if cos_sim > best_score:
                best_score = cos_sim
                best_id = identity
        
        if best_score >= recognition_threshold:
            recognized_ids[label] = best_id
        else:
            print(f"Low similarity score ({best_score:.3f}) for cluster {label}, assigning generic label.")
            recognized_ids[label] = f"Speaker {label+1}"
    return recognized_ids

In [68]:
# -------------------------------
# Function: Transcribe Speech Segments Using Whisper ASR
# -------------------------------
def transcribe_segments(waveform, sr, segments, speaker_labels, asr_model:whisper.Whisper, language=None):
    """
    Transcribe each speech segment with Whisper ASR.

    Args:
        waveform (Tensor): Full audio waveform (1, N).
        sr (int): Sampling rate.
        segments (list of dict): List of speech segments with 'start' and 'end' times.
        speaker_labels (list): List of speaker label strings for each segment.
        asr_model: Loaded Whisper model.
        language (str): Optional language parameter for ASR.

    Returns:
        list: List of tuples (start_time, end_time, speaker, transcript).
    """
    transcript_list = []
    for seg, spk in zip(segments, speaker_labels):
        try:
            start_sample = int(seg["start"] * sr)
            end_sample = int(seg["end"] * sr)
            # Extract segment and convert to numpy array in float32, normalized between -1 and 1.
            seg_wave = waveform[:, start_sample:end_sample]
            seg_audio = seg_wave.squeeze().cpu().numpy().astype(np.float32)
            # Whisper expects mono audio at 16 kHz.
            # Transcribe the segment (the model auto-detects language if not provided)
            result = asr_model.transcribe(seg_audio, fp16=torch.cuda.is_available(), language=language)
            transcript = result["text"].strip()
            transcript_list.append((seg["start"], seg["end"], spk, transcript))
        except Exception as e:
            print(f"Error during transcription of segment {seg}: {e}")
            transcript_list.append((seg["start"], seg["end"], spk, ""))
    return transcript_list


In [69]:
# -------------------------------
# Function: merge_segments
# -------------------------------
def merge_segments(segments, gap_threshold=GAP_THRESHOLD):
    """
    Merge adjacent segments with the same speaker if the gap between segments is less than gap_threshold seconds.
    
    Args:
        segments (list of dict): Each dict contains keys "start", "end", and optionally "transcript".
        gap_threshold (float): Maximum gap (in seconds) to consider merging.
    
    Returns:
        List of merged segments.
    """
    if not segments:
        return []
    # Sort segments by start time
    segments = sorted(segments, key=lambda x: x["start"])
    merged = [segments[0].copy()]
    for seg in segments[1:]:
        last_seg = merged[-1]
        # Check if the same speaker and if gap is within threshold
        if seg.get("speaker") == last_seg.get("speaker") and (seg["start"] - last_seg["end"]) <= gap_threshold:
            # Merge segments: update end time and concatenate transcripts if available
            last_seg["end"] = seg["end"]
            if "transcript" in last_seg and last_seg["transcript"] and seg.get("transcript"):
                last_seg["transcript"] = last_seg["transcript"].strip() + " " + seg["transcript"].strip()
            elif "transcript" in seg:
                last_seg["transcript"] = seg["transcript"]
        else:
            merged.append(seg.copy())
    return merged

In [70]:
# -------------------------------
# Function: Save Transcript to CSV
# -------------------------------
def save_transcript_csv(transcript_list, output_csv_path):
    """
    Save the final transcript to a CSV file.
    
    Args:
        transcript_list (list): List of tuples (start_time, end_time, speaker, transcript).
        output_csv_path (str): Output file path for the CSV.
    """
    # Ensure the output directory exists
    output_dir = os.path.dirname(output_csv_path)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    df = pd.DataFrame(transcript_list, columns=["start_time", "end_time", "speaker", "transcript"])
    df.to_csv(output_csv_path, index=False)
    print(f"Transcript saved to {output_csv_path}")


In [71]:
# -------------------------------
# Function: Parse a Single AMI Annotation XML File for Ground Truth
# -------------------------------
def parse_ami_annotation(xml_path):
    """
    Parse a single AMI annotation XML file to extract ground truth speech segments.

    Args:
        xml_path (str): Path to the annotation XML file.

    Returns:
        list: List of dict objects, each with 'start' and 'end' keys (floats in seconds)
              and optionally 'speaker' if available.
    """
    gt_segments = []
    try:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        for segment in root.iter('segment'):
            # print(segment.attrib)
            start = float(segment.attrib.get("transcriber_start", "0"))
            end = float(segment.attrib.get("transcriber_end", "0"))
            speaker = segment.attrib.get("{http://nite.sourceforge.net/}id", "Unknown")
            gt_segments.append({"start": start, "end": end, "speaker": speaker})
        return gt_segments
    except Exception as e:
        print(f"Error parsing AMI XML annotation {xml_path}: {e}")
        return []


In [72]:
# -------------------------------
# Function: Parse All AMI Annotation XML Files from a Directory
# -------------------------------
def parse_all_ami_annotations(annotations_dir):
    """
    Parse all AMI annotation XML files in a directory and aggregate ground truth segments.

    Args:
        annotations_dir (str): Directory path containing AMI annotation XML files.

    Returns:
        list: Combined list of ground truth segments from all files.
    """
    all_segments = []
    xml_files = glob.glob(os.path.join(annotations_dir, "*.segments.xml"))
    if not xml_files:
        print(f"No XML annotation files found in {annotations_dir}")
        return all_segments
    for xml_file in sorted(xml_files):
        segments = parse_ami_annotation(xml_file)
        all_segments.extend(segments)
    return all_segments


In [73]:
# -------------------------------
# Function: Compute VAD F1-Score Using pyannote.metrics
# -------------------------------
def compute_vad_f1(detected_segments, reference_segments, total_duration):
    """
    Compute VAD F1 score by comparing detected speech segments with reference ground truth.

    Args:
        detected_segments (list of dict): Speech segments from VAD with 'start' and 'end'.
        reference_segments (list of dict): Ground truth speech segments.
        total_duration (float): Total duration of the audio in seconds.
        tolerance (float): Forgiveness collar in seconds.

    Returns:
        float: F1 score.
    """
    # Create binary time series for reference and hypothesis
    # Discretize time with resolution e.g., 0.1 sec
    resolution = 0.1
    times = np.arange(0, total_duration, resolution)
    ref = np.zeros_like(times)
    hyp = np.zeros_like(times)

    for seg in reference_segments:
        start_idx = int(seg["start"] / resolution)
        end_idx = int(seg["end"] / resolution)
        ref[start_idx:end_idx] = 1
    for seg in detected_segments:
        start_idx = int(seg["start"] / resolution)
        end_idx = int(seg["end"] / resolution)
        hyp[start_idx:end_idx] = 1

    # Compute true positives, false positives, false negatives
    tp = np.sum((ref == 1) & (hyp == 1))
    fp = np.sum((ref == 0) & (hyp == 1))
    fn = np.sum((ref == 1) & (hyp == 0))

    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)

    return f1


In [74]:
# -------------------------------
# Function: Compute Diarization Error Rate (DER) Using pyannote.metrics
# -------------------------------
def compute_der(hyp_segments, ref_segments, collar=0.25):
    """
    Compute Diarization Error Rate (DER) using pyannote.metrics.

    Args:
        hyp_segments (list of dict): Hypothesis segments (each with 'start', 'end', and 'speaker').
        ref_segments (list of dict): Reference segments (each with 'start', 'end', and 'speaker').
        collar (float): Forgiveness collar in seconds.

    Returns:
        float: DER percentage.
    """
    # Prepare segments in the format required by pyannote (list of tuples: (start, end, speaker))
    hyp = [(seg["start"], seg["end"], seg["speaker"]) for seg in hyp_segments]
    ref = [(seg["start"], seg["end"], seg["speaker"]) for seg in ref_segments]

    # Convert to pyannote annotation format
    from pyannote.core import Annotation, Segment
    ann_ref = Annotation()
    for start, end, spk in ref:
        ann_ref[Segment(start, end)] = spk
    ann_hyp = Annotation()
    for start, end, spk in hyp:
        ann_hyp[Segment(start, end)] = spk

    der_metric = DiarizationErrorRate(collar=collar, skip_overlap=False, ignore_overlap=False)
    der = der_metric(ann_ref, ann_hyp)
    return der * 100  # as percentage

In [75]:
def build_ami_speaker_templates(annotations_dir, audio_file, spk_model, target_sr=TARGET_SR):
    """
    Build speaker enrollment templates from AMI ground truth segments.

    This function parses the AMI annotation XML files to obtain ground truth segments,
    groups segments by speaker, extracts the corresponding portions of the main AMI audio,
    computes speaker embeddings using the pre-trained SpeechBrain model, and averages them
    to produce robust templates.

    Args:
        annotations_dir (str): Directory containing AMI annotation XML files.
        audio_file (str): Path to the main AMI audio file.
        spk_model: Pre-trained speaker embedding model.
        target_sr (int): Target sampling rate.

    Returns:
        dict: Mapping from AMI speaker ID (e.g., 'ES2008a.sync.3') to an averaged embedding (numpy array).
    """
    # Parse all AMI annotations
    ami_segments = parse_all_ami_annotations(annotations_dir)
    if not ami_segments:
        print("No AMI ground truth segments found.")
        return {}

    # Load the main AMI audio file
    waveform, sr = load_audio(audio_file, target_sr=target_sr)

    # Group segments by speaker
    speaker_segments = {}
    for seg in ami_segments:
        speaker = seg.get("speaker", "Unknown")
        speaker_segments.setdefault(speaker, []).append(seg)

    ami_templates = {}
    for speaker, segments in speaker_segments.items():
        embeddings_list = []
        for seg in segments:
            start_sample = int(seg["start"] * sr)
            end_sample = int(seg["end"] * sr)
            seg_waveform = waveform[:, start_sample:end_sample]
            # Ensure a minimum segment length by padding if necessary
            min_samples = int(1.0 * target_sr)
            if seg_waveform.shape[1] < min_samples:
                pad_amount = min_samples - seg_waveform.shape[1]
                seg_waveform = F.pad(seg_waveform, (0, pad_amount))
            try:
                with torch.no_grad():
                    emb = spk_model.encode_batch(seg_waveform.to(DEVICE)).squeeze().cpu().numpy()
                embeddings_list.append(emb)
            except Exception as e:
                print(f"Error extracting embedding for AMI segment {seg}: {e}")
        if embeddings_list:
            # Average embeddings for a robust template
            ami_templates[speaker] = np.mean(embeddings_list, axis=0)
    return ami_templates

In [76]:
# -------------------------------
# Main Processing Pipeline
# -------------------------------

In [77]:
start_time = datetime.now()
print(f"Pipeline started at {start_time}")

Pipeline started at 2025-04-14 16:00:22.641536


In [78]:
# Load audio file
print("Loading audio file...")
waveform, sr = load_audio(AMI_AUDIO_PATH, target_sr=TARGET_SR)
total_audio_duration = waveform.shape[1] / sr
print(f"Audio loaded, duration: {total_audio_duration:.2f} seconds, sampling rate: {sr} Hz")

Loading audio file...
Audio loaded, duration: 1043.36 seconds, sampling rate: 16000 Hz


In [79]:
# Perform Voice Activity Detection (VAD)
print("Running VAD...")
detected_speech_segments = perform_vad(waveform, sr, vad_threshold=VAD_THRESHOLD, min_speech_duration=MIN_SPEECH_DURATION)
print(f"Detected {len(detected_speech_segments)} speech segments via VAD.")

Running VAD...


Using cache found in /Users/aryan-kumar/.cache/torch/hub/snakers4_silero-vad_master


Timestamps seem to be in samples; converting to seconds.
Detected 192 speech segments via VAD.


In [80]:
# Extract audio segments corresponding to VAD output
segments_audio = extract_audio_segments(waveform, sr, detected_speech_segments)

In [81]:
# Load pretrained speaker embedding model (SpeechBrain ECAPA-TDNN)
print("Loading speaker embedding model (SpeechBrain ECAPA)...")
try:
    spk_model = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb", run_opts={"device": DEVICE})
    print("MODEL 'SpeechBrain ECAPA' LOADED SUCCESSFULLY")
except Exception as e:
    print(f"Error loading speaker embedding model: {e}")
    sys.exit(1)

Loading speaker embedding model (SpeechBrain ECAPA)...
MODEL 'SpeechBrain ECAPA' LOADED SUCCESSFULLY


In [82]:
# Extract speaker embeddings for each speech segment
print("Extracting speaker embeddings...")
embeddings = extract_embeddings(segments_audio, spk_model, sr=sr)
print(f"Extracted embeddings for {embeddings.shape[0]} segments.")

Extracting speaker embeddings...
Extracted embeddings for 192 segments.


In [83]:
# Cluster embeddings to perform speaker diarization
print("Clustering embeddings for diarization...")
cluster_labels = cluster_embeddings(embeddings, distance_threshold=DISTANCE_THRESHOLD)

Clustering embeddings for diarization...


In [84]:
generic_labels = assign_speaker_labels(cluster_labels)
print(f"Assigned speaker labels: {set(generic_labels)}")

Assigned speaker labels: {'Speaker 61', 'Speaker 45', 'Speaker 189', 'Speaker 3', 'Speaker 125', 'Speaker 40', 'Speaker 128', 'Speaker 105', 'Speaker 31', 'Speaker 141', 'Speaker 64', 'Speaker 132', 'Speaker 183', 'Speaker 37', 'Speaker 30', 'Speaker 67', 'Speaker 88', 'Speaker 187', 'Speaker 69', 'Speaker 131', 'Speaker 170', 'Speaker 68', 'Speaker 53', 'Speaker 168', 'Speaker 136', 'Speaker 16', 'Speaker 74', 'Speaker 71', 'Speaker 164', 'Speaker 171', 'Speaker 138', 'Speaker 39', 'Speaker 6', 'Speaker 66', 'Speaker 50', 'Speaker 78', 'Speaker 188', 'Speaker 43', 'Speaker 140', 'Speaker 111', 'Speaker 13', 'Speaker 112', 'Speaker 77', 'Speaker 79', 'Speaker 29', 'Speaker 104', 'Speaker 147', 'Speaker 12', 'Speaker 107', 'Speaker 133', 'Speaker 99', 'Speaker 87', 'Speaker 19', 'Speaker 26', 'Speaker 33', 'Speaker 192', 'Speaker 143', 'Speaker 23', 'Speaker 95', 'Speaker 173', 'Speaker 5', 'Speaker 25', 'Speaker 91', 'Speaker 113', 'Speaker 49', 'Speaker 182', 'Speaker 96', 'Speaker 18

In [85]:
# Load enrolled speaker templates (optional)
# print("Loading enrolled speaker templates...")
# enrolled_templates = load_voxceleb2_templates(ENROLLED_TEMPLATES_DIR, spk_model, target_sr=TARGET_SR)

In [86]:
# Load AMI enrolled speaker templates to reduce domain mismatch for speaker recognition.
print("Loading AMI enrolled speaker templates...")
enrolled_templates = build_ami_speaker_templates(AMI_ANNOTATIONS_DIR, AMI_AUDIO_PATH, spk_model, target_sr=TARGET_SR)

Loading AMI enrolled speaker templates...


In [87]:
print(f"Loaded {len(enrolled_templates)} enrolled templates.")

Loaded 194 enrolled templates.


In [88]:
# Perform speaker recognition, if templates are available; otherwise use generic labels
if enrolled_templates:
    recognized_ids = perform_speaker_recognition(embeddings, cluster_labels, enrolled_templates, recognition_threshold=RECOGNITION_THRESHOLD)
    final_speaker_labels = [recognized_ids[label] for label in cluster_labels]
    print("Speaker recognition completed. Recognized IDs:", set(recognized_ids.values()))
else:
    final_speaker_labels = generic_labels
    print("No enrolled templates found; using generic speaker labels.")

Low similarity score (0.394) for cluster 11, assigning generic label.
Speaker recognition completed. Recognized IDs: {'ES2008a.sync.79', 'ES2008a.sync.75', 'ES2008a.sync.248', 'ES2008a.sync.19', 'ES2008a.sync.167', 'ES2008a.sync.125', 'Speaker 12', 'ES2008a.sync.143', 'ES2008a.sync.77', 'ES2008a.sync.165', 'ES2008a.sync.278', 'ES2008a.sync.359', 'ES2008a.sync.163', 'ES2008a.sync.270', 'ES2008a.sync.290', 'ES2008a.sync.53', 'ES2008a.sync.91', 'ES2008a.sync.355', 'ES2008a.sync.123', 'ES2008a.sync.351', 'ES2008a.sync.181', 'ES2008a.sync.365', 'ES2008a.sync.129', 'ES2008a.sync.137', 'ES2008a.sync.246', 'ES2008a.sync.345', 'ES2008a.sync.234', 'ES2008a.sync.13', 'ES2008a.sync.383', 'ES2008a.sync.27', 'ES2008a.sync.67', 'ES2008a.sync.131', 'ES2008a.sync.347', 'ES2008a.sync.361', 'ES2008a.sync.101', 'ES2008a.sync.268', 'ES2008a.sync.329', 'ES2008a.sync.11', 'ES2008a.sync.35', 'ES2008a.sync.9', 'ES2008a.sync.57', 'ES2008a.sync.103', 'ES2008a.sync.256', 'ES2008a.sync.221', 'ES2008a.sync.385', 'E

In [89]:
# Load Whisper ASR model (adjust model size as needed)
ASR_DEVICE = torch.device("cpu")
print("Loading Whisper ASR model on CPU (MPS not supported for Whisper)...")
try:
    asr_model = whisper.load_model("medium", device=ASR_DEVICE)
except Exception as e:
    print(f"Error loading Whisper ASR model: {e}")
    sys.exit(1)

Loading Whisper ASR model on CPU (MPS not supported for Whisper)...


In [90]:
# Transcribe each speech segment with assigned speaker labels
print("Running ASR on each speech segment...")
transcript_entries = transcribe_segments(waveform, sr, detected_speech_segments, final_speaker_labels, asr_model, language="en")

Running ASR on each speech segment...


In [91]:
# Convert transcript_entries (tuples) to dicts for merging and evaluation
hyp_segments = []
for start, end, speaker, transcript in transcript_entries:
    hyp_segments.append({"start": start, "end": end, "speaker": speaker, "transcript": transcript})

In [92]:
# Merge adjacent hypothesis segments with the same speaker if gap is below threshold
print("Merging adjacent segments with the same speaker...")
merged_hyp_segments = merge_segments(hyp_segments, gap_threshold=GAP_THRESHOLD)

Merging adjacent segments with the same speaker...


In [93]:
final_transcript = [(seg["start"], seg["end"], seg["speaker"], seg.get("transcript", "")) for seg in merged_hyp_segments]

In [94]:
final_transcript[:2]  # Display first 2 entries for verification

[(32.706,
  36.926,
  'ES2008a.sync.5',
  "Okay, good morning everybody. I'm glad you could all come. I'm really excited to start this team."),
 (37.538,
  41.374,
  'ES2008a.sync.5',
  "I'm just gonna have a little PowerPoint presentation for us for our kickoff meeting.")]

In [95]:
# Save the transcript to a CSV file
print("Saving transcript to CSV...")
save_transcript_csv(final_transcript, OUTPUT_CSV_PATH)

Saving transcript to CSV...
Transcript saved to ./Outputs/ES2008a_transcript.csv


In [96]:
# Evaluation: Aggregate ground truth segments from all AMI annotation XML files
print("Aggregating ground truth annotations from all XML files...")
reference_segments = parse_all_ami_annotations(AMI_ANNOTATIONS_DIR)

Aggregating ground truth annotations from all XML files...


In [97]:
reference_segments[:5]

[{'start': 27.863, 'end': 28.533, 'speaker': 'ES2008a.sync.3'},
 {'start': 32.265, 'end': 44.976, 'speaker': 'ES2008a.sync.5'},
 {'start': 46.112, 'end': 61.813, 'speaker': 'ES2008a.sync.7'},
 {'start': 63.84, 'end': 99.762, 'speaker': 'ES2008a.sync.9'},
 {'start': 101.136, 'end': 122.88, 'speaker': 'ES2008a.sync.11'}]

In [98]:
if reference_segments:
    vad_f1 = compute_vad_f1(detected_speech_segments, reference_segments, total_audio_duration)
    print(f"VAD F1-Score: {vad_f1:.3f}")
else:
    print("No ground truth VAD annotations available for evaluation.")

VAD F1-Score: 0.812


In [99]:
hyp_segments_for_der = [{"start": entry[0], "end": entry[1], "speaker": entry[2]} for entry in transcript_entries]

In [100]:
hyp_segments_for_der[:5]

[{'start': 32.706, 'end': 36.926, 'speaker': 'ES2008a.sync.5'},
 {'start': 37.538, 'end': 41.374, 'speaker': 'ES2008a.sync.5'},
 {'start': 41.986, 'end': 43.934, 'speaker': 'ES2008a.sync.7'},
 {'start': 46.146, 'end': 54.046, 'speaker': 'ES2008a.sync.7'},
 {'start': 54.37, 'end': 56.35, 'speaker': 'ES2008a.sync.7'}]

In [101]:
# Diarization Error Rate (DER) Evaluation
# Prepare hypothesis segments from transcript (using start, end, and speaker)
if reference_segments:
    der_value = compute_der(hyp_segments_for_der, reference_segments, collar=0.25)
    print(f"Diarization Error Rate (DER): {der_value:.2f}%")
else:
    print("No reference segments available for DER evaluation.")

Diarization Error Rate (DER): 38.75%




In [102]:
end_time = datetime.now()
print(f"Pipeline completed at {end_time}, elapsed time: {(end_time - start_time)}")

Pipeline completed at 2025-04-14 16:18:04.474946, elapsed time: 0:17:41.833410


In [103]:
def demo_inference_custom(audio_file, spk_model, asr_model, enrolled_templates=None):
    """
    Run the entire pipeline on a custom audio file and return a merged transcript.

    Args:
        audio_file (str): Path to the custom audio file.
        spk_model: Pretrained speaker embedding model.
        asr_model: Pretrained Whisper ASR model.
        enrolled_templates (dict, optional): Enrolled speaker templates for recognition.

    Returns:
        list: Merged transcript segments with 'start', 'end', 'speaker', and 'transcript'.
    """
    # Load custom audio file
    waveform, sr = load_audio(audio_file, target_sr=TARGET_SR)
    print(f"Loaded custom audio: duration {waveform.shape[1]/sr:.2f} sec, sampling rate: {sr}")

    # Perform Voice Activity Detection (VAD)
    speech_segments = perform_vad(waveform, sr, vad_threshold=VAD_THRESHOLD, min_speech_duration=MIN_SPEECH_DURATION)
    print(f"Detected {len(speech_segments)} speech segments.")

    # Extract audio segments from the detected speech regions
    segments_audio = extract_audio_segments(waveform, sr, speech_segments)

    # Extract speaker embeddings for each speech segment
    embeddings = extract_embeddings(segments_audio, spk_model, sr=sr)

    # Cluster the embeddings for diarization
    cluster_labels = cluster_embeddings(embeddings, distance_threshold=DISTANCE_THRESHOLD)

    # Assign speaker labels (generic or using enrolled templates if available)
    speaker_labels = assign_speaker_labels(cluster_labels)
    if enrolled_templates:
        recognized_ids = perform_speaker_recognition(embeddings, cluster_labels, enrolled_templates, recognition_threshold=RECOGNITION_THRESHOLD)
        speaker_labels = [recognized_ids[label] for label in cluster_labels]
    print("Speaker labels assigned.")

    # Transcribe the segments using Whisper ASR
    transcripts = transcribe_segments(waveform, sr, speech_segments, speaker_labels, asr_model, language='en')

    # Merge adjacent segments if they belong to the same speaker
    merged_segments = merge_segments([{"start": s[0], "end": s[1], "speaker": s[2], "transcript": s[3]} for s in transcripts], gap_threshold=GAP_THRESHOLD)
    return merged_segments

In [104]:
# Run demo inference with the custom audio file
custom_output = demo_inference_custom(CUSTOM_AUDIO_PATH, spk_model, asr_model, enrolled_templates)
print("Demo Transcript for Custom Hinglish Speech:")
for seg in custom_output:
    print(f"[{seg['start']:.2f}-{seg['end']:.2f}] {seg['speaker']}: {seg['transcript']}")

Loaded custom audio: duration 24.00 sec, sampling rate: 16000


Using cache found in /Users/aryan-kumar/.cache/torch/hub/snakers4_silero-vad_master


Timestamps seem to be in samples; converting to seconds.
Detected 7 speech segments.
Low similarity score (0.271) for cluster 0, assigning generic label.
Low similarity score (0.234) for cluster 1, assigning generic label.
Low similarity score (0.275) for cluster 2, assigning generic label.
Low similarity score (0.258) for cluster 3, assigning generic label.
Low similarity score (0.162) for cluster 4, assigning generic label.
Low similarity score (0.176) for cluster 5, assigning generic label.
Low similarity score (0.149) for cluster 6, assigning generic label.
Speaker labels assigned.
Demo Transcript for Custom Hinglish Speech:
[2.18-3.49] Speaker 5: This is Aryan Kumar.
[4.90-6.53] Speaker 6: I am speaking from India.
[8.26-9.50] Speaker 7: This is Robby Sharma.
[10.40-12.09] Speaker 4: I am also speaking from India
[13.12-14.24] Speaker 2: I'm in Bangalore.
[15.39-19.20] Speaker 3: This is Humana Rorah and I'm also talking from Bangalore.
[19.30-23.01] Speaker 1: Nice to meet you al