## Conversation Simulator

In [14]:
#@title Conversation Simulator

import os
import json
import random
import librosa
import soundfile as sf
import numpy as np
import shutil
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Tuple


@dataclass
class SpeakerTurn:
    """One utterance placed on the synthetic conversation timeline.

    ``start_time`` and ``end_time`` are in seconds from the start of the
    mixed conversation audio.
    """

    speaker_id: str  # LibriSpeech speaker directory name
    # Source utterance file. ConversationSimulator stores pathlib.Path objects
    # here (collected via Path.rglob), so the annotation admits both types.
    audio_path: "str | Path"
    transcript: str
    start_time: float
    end_time: float
    overlap_with_previous: float = 0.0  # seconds shared with the preceding turn
    overlap_with_next: float = 0.0  # seconds shared with the following turn (0.0 unless a caller sets it)


class ConversationSimulator:
    """Synthesize overlapping multi-speaker conversations from LibriSpeech.

    Each conversation picks ``num_speakers`` random speakers, takes up to
    ``turns_per_speaker`` utterances from each, shuffles all utterances into
    one timeline, and mixes them so that consecutive turns by different
    speakers overlap by a random amount. Outputs written under ``output_dir``:

    * ``conversations/conversation_XXX.wav`` - the mixed audio,
    * ``reference/<speaker_id>.wav`` - one reference clip per speaker,
    * ``conversations_metadata.json`` - timings, transcripts and overlaps.
    """

    # LibriSpeech subset directories looked for directly under the root.
    _KNOWN_SUBSETS = (
        "test-clean",
        "dev-clean",
        "train-clean-100",
        "train-clean-360",
        "train-other-500",
    )
    # Utterances shorter than this many seconds are never used as turns.
    _MIN_UTTERANCE_SEC = 1.0
    # How many files per speaker are probed when screening speakers.
    _SCREEN_SAMPLE_SIZE = 5
    # An overlap may cover at most this fraction of the incoming utterance.
    _MAX_OVERLAP_FRACTION = 0.3

    def __init__(
        self,
        librispeech_root: str,
        output_dir: str,
        num_speakers: int = 3,
        turns_per_speaker: int = 3,
        sr: int = 16000,
        min_overlap_sec: float = 0.5,
        max_overlap_sec: float = 2.0,
        silence_between_speakers: float = 0.2,
        clean_output_dir: bool = True,
    ):
        """Store the configuration and prepare the output directories.

        Args:
            librispeech_root: LibriSpeech root. Known subset folders (e.g.
                ``test-clean``) beneath it are searched when present;
                otherwise the root itself is treated as a folder of speakers.
            output_dir: Destination for conversations, references, metadata.
            num_speakers: Speakers per conversation.
            turns_per_speaker: Utterances taken from each speaker.
            sr: Sample rate (Hz) every clip is resampled to.
            min_overlap_sec: Lower bound of the random overlap between turns.
            max_overlap_sec: Upper bound of the random overlap between turns.
            silence_between_speakers: Gap used when a turn does not overlap.
            clean_output_dir: Delete ``output_dir`` first if it exists.
        """
        self.librispeech_root = Path(librispeech_root)
        self.output_dir = Path(output_dir)
        self.num_speakers = num_speakers
        self.turns_per_speaker = turns_per_speaker
        self.sr = sr
        self.min_overlap_sec = min_overlap_sec
        self.max_overlap_sec = max_overlap_sec
        self.silence_between_speakers = silence_between_speakers

        # Clean and create output directories
        if clean_output_dir and self.output_dir.exists():
            print(f"Cleaning existing output directory: {self.output_dir}")
            shutil.rmtree(self.output_dir)

        self.output_dir.mkdir(parents=True, exist_ok=True)
        (self.output_dir / "conversations").mkdir(exist_ok=True)
        (self.output_dir / "reference").mkdir(exist_ok=True)

        # Speaker scan result, reused across conversations (see _analyze_speakers_once)
        self._speaker_cache = None

    def get_transcript(self, flac_path):
        """Return the LibriSpeech transcript of one utterance.

        LibriSpeech stores one ``<speaker>-<chapter>.trans.txt`` per chapter
        folder, with lines of the form ``<utterance-id> <TEXT>``.

        Returns:
            The transcript text, or ``""`` if the file or line is missing or
            unreadable.
        """
        flac_path = Path(flac_path)
        chapter_dir = flac_path.parent
        speaker_id = chapter_dir.parent.name
        chapter_id = chapter_dir.name
        txt_path = chapter_dir / f"{speaker_id}-{chapter_id}.trans.txt"
        search_key = f"{flac_path.stem} "

        try:
            with open(txt_path, "r", encoding="utf-8") as f:
                for line in f:
                    if line.startswith(search_key):
                        return line[len(search_key) :].strip()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Transcript error for {flac_path}: {str(e)}")
        return ""

    def _analyze_speakers_once(self) -> List[Tuple[str, List[Path]]]:
        """Scan the corpus for usable speakers once and cache the result.

        Returns:
            List of ``(speaker_id, flac_files)`` for every speaker directory
            that has at least one ``.flac`` file and passes
            ``_passes_screening``. Directory and file order is sorted so
            runs are reproducible under ``random.seed``.
        """
        if self._speaker_cache is not None:
            return self._speaker_cache

        print("Analyzing LibriSpeech speakers (one-time setup)...")

        # Prefer known subset folders; fall back to the root itself.
        subsets = []
        for subdir in self._KNOWN_SUBSETS:
            subdir_path = self.librispeech_root / subdir
            if subdir_path.is_dir():
                subsets.append(subdir_path)
                print(f"Found LibriSpeech subset: {subdir}")
        search_roots = subsets or [self.librispeech_root]

        print(f"Searching in: {[str(root) for root in search_roots]}")

        speaker_dirs = []
        for search_root in search_roots:
            print(f"Scanning directory: {search_root}")
            try:
                speaker_list = sorted(os.listdir(search_root))
            except OSError:
                print(f"Cannot read directory: {search_root}")
                continue
            print(f"Found {len(speaker_list)} entries in {search_root}")

            for speaker in speaker_list:
                speaker_path = search_root / speaker
                if not speaker_path.is_dir():
                    continue

                flac_files = sorted(speaker_path.rglob("*.flac"))
                # A speaker without audio can never contribute a turn.
                if not flac_files:
                    continue

                if self._passes_screening(flac_files):
                    speaker_dirs.append((speaker, flac_files))

        self._speaker_cache = speaker_dirs
        return speaker_dirs

    def _passes_screening(self, flac_files: List[Path]) -> bool:
        """Return True if at least half of a small random sample is long enough.

        Probing only a few files keeps the one-time scan fast.
        """
        sample = random.sample(
            flac_files, min(self._SCREEN_SAMPLE_SIZE, len(flac_files))
        )
        valid_count = 0
        for flac_path in sample:
            try:
                if librosa.get_duration(path=flac_path) >= self._MIN_UTTERANCE_SEC:
                    valid_count += 1
            except Exception as e:
                print(f"    Error checking {flac_path}: {e}")
        # "At least half", rounded up: 3 of 5, 1 of 1; an empty sample never passes.
        return valid_count > 0 and 2 * valid_count >= len(sample)

    def select_speakers_and_files(self) -> Dict[str, List[Path]]:
        """Pick ``num_speakers`` random speakers and their utterances.

        Files of each chosen speaker are probed in random order until
        ``turns_per_speaker`` files of at least ``_MIN_UTTERANCE_SEC`` are
        found. This stops early, so usually only a few files are probed.

        Returns:
            Mapping speaker id -> chosen utterance paths. A speaker may get
            fewer than ``turns_per_speaker`` files; a warning is printed.

        Raises:
            ValueError: if fewer speakers than ``num_speakers`` are available,
                or a chosen speaker has no usable file at all.
        """
        all_speakers = self._analyze_speakers_once()

        if len(all_speakers) < self.num_speakers:
            raise ValueError(
                f"Not enough speakers found. Need {self.num_speakers}, found {len(all_speakers)}"
            )

        selected_speakers = random.sample(all_speakers, self.num_speakers)

        audio_pool = {}
        for speaker_id, all_files in selected_speakers:
            candidates = list(all_files)
            random.shuffle(candidates)
            valid_files = []

            for flac_path in candidates:
                if len(valid_files) >= self.turns_per_speaker:
                    break
                try:
                    duration = librosa.get_duration(path=flac_path)
                except Exception:
                    # Unreadable file: just try the next one.
                    continue
                if duration >= self._MIN_UTTERANCE_SEC:
                    valid_files.append(flac_path)

            if not valid_files:
                raise ValueError(
                    f"Speaker {speaker_id} has no usable audio files "
                    f"(>= {self._MIN_UTTERANCE_SEC}s)"
                )
            if len(valid_files) < self.turns_per_speaker:
                print(
                    f"Warning: Speaker {speaker_id} only has {len(valid_files)} valid files"
                )

            audio_pool[speaker_id] = valid_files

        return audio_pool

    def create_reference_audios(
        self, audio_pool: Dict[str, List[Path]]
    ) -> Dict[str, Path]:
        """Write each speaker's longest selected utterance as its reference WAV.

        Returns:
            Mapping speaker id -> ``<output_dir>/reference/<speaker_id>.wav``.

        Raises:
            ValueError: if a speaker has no files in ``audio_pool``.
        """
        ref_audios = {}

        for speaker_id, files in audio_pool.items():
            if not files:
                raise ValueError(f"No audio files available for speaker {speaker_id}")
            # `path=` replaces librosa's deprecated `filename=` alias.
            ref_file = max(files, key=lambda p: librosa.get_duration(path=p))
            ref_path = self.output_dir / "reference" / f"{speaker_id}.wav"

            y, _ = librosa.load(ref_file, sr=self.sr)
            sf.write(ref_path, y, self.sr)
            ref_audios[speaker_id] = ref_path

        return ref_audios

    def generate_conversation_sequence(
        self, audio_pool: Dict[str, List[Path]]
    ) -> "List[SpeakerTurn]":
        """Lay every selected utterance out on one conversation timeline.

        All utterances of all speakers are shuffled together. Each turn
        normally starts ``silence_between_speakers`` after the previous one
        ends; this also applies to the first turn. When the speaker changes,
        the turn instead starts earlier by a random overlap. The overlap is
        drawn up to a cap of ``max_overlap_sec``, 30% of the incoming
        utterance, and the whole previous turn, whichever is smallest. Its
        floor is ``min_overlap_sec``, or the cap if that is smaller.

        Returns:
            SpeakerTurn objects in playback order. Each turn that is
            overlapped by the next gets ``overlap_with_next`` set.
        """
        all_turns = [
            {
                "speaker_id": speaker_id,
                "audio_path": audio_path,
                "transcript": self.get_transcript(audio_path),
            }
            for speaker_id, files in audio_pool.items()
            for audio_path in files
        ]
        random.shuffle(all_turns)

        turns = []
        current_time = 0.0
        for turn_data in all_turns:
            audio_path = turn_data["audio_path"]
            # Decode at self.sr so the duration matches what is mixed later.
            y, _ = librosa.load(audio_path, sr=self.sr)
            duration = len(y) / self.sr

            overlap_with_prev = 0.0
            prev = turns[-1] if turns else None
            if prev is not None and prev.speaker_id != turn_data["speaker_id"]:
                upper = min(
                    self.max_overlap_sec,
                    duration * self._MAX_OVERLAP_FRACTION,
                    prev.end_time - prev.start_time,
                )
                # random.uniform(a, b) with b < a can return values up to a,
                # which would break the caps above; lower the floor instead.
                lower = min(self.min_overlap_sec, upper)
                overlap_with_prev = random.uniform(lower, upper)
                prev.overlap_with_next = overlap_with_prev

            if overlap_with_prev > 0:
                start_time = current_time - overlap_with_prev
            else:
                start_time = current_time + self.silence_between_speakers
            end_time = start_time + duration

            turns.append(
                SpeakerTurn(
                    speaker_id=turn_data["speaker_id"],
                    audio_path=audio_path,
                    transcript=turn_data["transcript"],
                    start_time=start_time,
                    end_time=end_time,
                    overlap_with_previous=overlap_with_prev,
                )
            )
            current_time = end_time

        return turns

    def create_conversation_audio(
        self, turns: "List[SpeakerTurn]", output_path: Path
    ) -> float:
        """Mix all turns into one WAV file and return its duration in seconds.

        Overlapping regions are summed. The mix is scaled down to a peak of
        0.95 only if it would exceed that.

        Raises:
            ValueError: if ``turns`` is empty.
        """
        if not turns:
            raise ValueError("Cannot mix a conversation with no turns")

        total_duration = max(turn.end_time for turn in turns)
        total_samples = int(total_duration * self.sr)
        conversation_audio = np.zeros(total_samples)

        for turn in turns:
            y, _ = librosa.load(turn.audio_path, sr=self.sr)
            start_sample = int(turn.start_time * self.sr)
            end_sample = min(start_sample + len(y), total_samples)
            if end_sample > start_sample:
                conversation_audio[start_sample:end_sample] += y[: end_sample - start_sample]

        peak = np.max(np.abs(conversation_audio))
        if peak > 0.95:
            conversation_audio = conversation_audio * 0.95 / peak

        sf.write(output_path, conversation_audio, self.sr)
        return total_duration

    @staticmethod
    def _collect_speakers_info(audio_pool, ref_audios, turns):
        """Group each speaker's reference path and turn timings for the metadata."""
        return {
            speaker_id: {
                "reference_audio": ref_audios[speaker_id].as_posix(),
                "turns": [
                    {
                        "start_time": turn.start_time,
                        "end_time": turn.end_time,
                        "transcript": turn.transcript,
                        "overlap_with_previous": turn.overlap_with_previous,
                        "audio_source": Path(turn.audio_path).as_posix(),
                    }
                    for turn in turns
                    if turn.speaker_id == speaker_id
                ],
            }
            for speaker_id in audio_pool
        }

    @staticmethod
    def _find_overlapping_pairs(turns):
        """List adjacent turns by different speakers that overlap (for separation training)."""
        return [
            {
                "speaker_1": current.speaker_id,
                "speaker_2": nxt.speaker_id,
                "overlap_start": nxt.start_time,
                "overlap_end": current.end_time,
                "overlap_duration": nxt.overlap_with_previous,
            }
            for current, nxt in zip(turns, turns[1:])
            if current.speaker_id != nxt.speaker_id and nxt.overlap_with_previous > 0
        ]

    def generate_conversations(self, num_conversations: int = 5) -> List[Dict]:
        """Generate several conversations and write their metadata JSON.

        A conversation that raises is reported and skipped; the rest continue.
        Metadata of the successful ones is written to
        ``<output_dir>/conversations_metadata.json`` and returned.
        """
        metadata = []

        print(f"Starting generation of {num_conversations} conversations...")
        print(
            f"Configuration: {self.num_speakers} speakers, {self.turns_per_speaker} turns each"
        )

        for conv_id in range(num_conversations):
            try:
                print(f"\n--- Conversation {conv_id + 1}/{num_conversations} ---")

                print("  Selecting speakers and audio files...")
                audio_pool = self.select_speakers_and_files()
                print(f"  Selected speakers: {list(audio_pool.keys())}")

                print("  Creating reference audios...")
                ref_audios = self.create_reference_audios(audio_pool)

                print("  Generating conversation sequence...")
                turns = self.generate_conversation_sequence(audio_pool)
                print(f"  Created {len(turns)} speaker turns")

                print("  Mixing conversation audio...")
                conv_filename = f"conversation_{conv_id:03d}.wav"
                conv_path = self.output_dir / "conversations" / conv_filename
                total_duration = self.create_conversation_audio(turns, conv_path)
                print(f"  Duration: {total_duration:.1f} seconds")

                metadata.append(
                    {
                        "conversation_id": conv_id,
                        "conversation_audio": conv_path.as_posix(),
                        "total_duration": total_duration,
                        "num_speakers": len(audio_pool),
                        "speakers": self._collect_speakers_info(
                            audio_pool, ref_audios, turns
                        ),
                        "overlapping_pairs": self._find_overlapping_pairs(turns),
                        "full_transcript": " ".join(turn.transcript for turn in turns),
                        "speaker_sequence": [turn.speaker_id for turn in turns],
                    }
                )
                print(f"  ✓ Conversation {conv_id + 1} completed!")

            except Exception as e:
                # Best effort: one bad conversation must not abort the batch.
                # Numbered 1-based like the progress messages above.
                print(f"  ✗ Failed to generate conversation {conv_id + 1}: {str(e)}")

        metadata_path = self.output_dir / "conversations_metadata.json"
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

        print("\n=== Generation Complete ===")
        print(
            f"Successfully generated: {len(metadata)}/{num_conversations} conversations"
        )
        print(f"Metadata saved to: {metadata_path}")

        return metadata


# Usage example
def generate_conversation_dataset(
    librispeech_root: str,
    output_dir: str,
    num_conversations: int = 5,
    num_speakers: int = 3,
    turns_per_speaker: int = 3,
    clean_output_dir: bool = True,
    *,
    min_overlap_sec: float = 0.5,
    max_overlap_sec: float = 2.0,
    silence_between_speakers: float = 0.2,
):
    """
    Generate a conversation dataset and print a short summary.

    Args:
        librispeech_root: Path to LibriSpeech dataset
        output_dir: Output directory for conversations
        num_conversations: Number of conversations to generate
        num_speakers: Number of speakers per conversation
        turns_per_speaker: Number of audio files per speaker
        clean_output_dir: Whether to clean output directory first
        min_overlap_sec: Lower bound of the random overlap between turns
            of different speakers (keyword-only).
        max_overlap_sec: Upper bound of that overlap (keyword-only).
        silence_between_speakers: Gap inserted when turns do not overlap
            (keyword-only).

    Returns:
        The per-conversation metadata list returned by
        ``ConversationSimulator.generate_conversations``.
    """
    simulator = ConversationSimulator(
        librispeech_root=librispeech_root,
        output_dir=output_dir,
        num_speakers=num_speakers,
        turns_per_speaker=turns_per_speaker,
        min_overlap_sec=min_overlap_sec,
        max_overlap_sec=max_overlap_sec,
        silence_between_speakers=silence_between_speakers,
        clean_output_dir=clean_output_dir,
    )

    metadata = simulator.generate_conversations(num_conversations)

    print("\n=== Conversation Dataset Summary ===")
    print(f"Total conversations: {len(metadata)}")
    print(f"Speakers per conversation: {num_speakers}")
    print(f"Turns per speaker: {turns_per_speaker}")

    if metadata:
        avg_duration = np.mean([conv["total_duration"] for conv in metadata])
        total_overlaps = sum(len(conv["overlapping_pairs"]) for conv in metadata)
        print(f"Average conversation duration: {avg_duration:.1f} seconds")
        print(f"Total overlapping segments: {total_overlaps}")

    return metadata

In [15]:
# Build dataset version 1: 5 conversations of 4 speakers x 3 turns each,
# written under Pipeline_data/Conv/ver1 (the output dir is cleaned first
# because clean_output_dir defaults to True).
metadata = generate_conversation_dataset(
    num_conversations=5,
    num_speakers=4,
    turns_per_speaker=3,
    librispeech_root="LibriSpeech",
    output_dir="Pipeline_data/Conv/ver1",
)

Starting generation of 5 conversations...
Configuration: 4 speakers, 3 turns each

--- Conversation 1/5 ---
  Selecting speakers and audio files...
Analyzing LibriSpeech speakers (one-time setup)...
Found LibriSpeech subset: test-clean
Searching in: ['LibriSpeech\\test-clean']
Scanning directory: LibriSpeech\test-clean
Found 40 entries in LibriSpeech\test-clean
  Selected speakers: ['8463', '6829', '7021', '8224']
  Creating reference audios...


	This alias will be removed in version 1.0.
  ref_file = max(files, key=lambda x: librosa.get_duration(filename=x))


  Generating conversation sequence...
  Created 12 speaker turns
  Mixing conversation audio...
  Duration: 84.4 seconds
  ✓ Conversation 1 completed!

--- Conversation 2/5 ---
  Selecting speakers and audio files...
  Selected speakers: ['7021', '7729', '1284', '4992']
  Creating reference audios...
  Generating conversation sequence...
  Created 12 speaker turns
  Mixing conversation audio...
  Duration: 93.8 seconds
  ✓ Conversation 2 completed!

--- Conversation 3/5 ---
  Selecting speakers and audio files...
  Selected speakers: ['8230', '3575', '8224', '5142']
  Creating reference audios...
  Generating conversation sequence...
  Created 12 speaker turns
  Mixing conversation audio...
  Duration: 90.9 seconds
  ✓ Conversation 3 completed!

--- Conversation 4/5 ---
  Selecting speakers and audio files...
  Selected speakers: ['3729', '4077', '1580', '7127']
  Creating reference audios...
  Generating conversation sequence...
  Created 12 speaker turns
  Mixing conversation audio..

## Separator

In [2]:
import os
import json
import numpy as np
import librosa
import soundfile as sf
from pathlib import Path
import torch
from transformers import pipeline
from tqdm import tqdm
import re

from src.seperator import VoiceSeparator
from utils.audio import Audio
from utils.hparams import HParam
from model.embedder import SpeechEmbedder


class WorkingConversationPipeline:
    """
    Pipeline using your tested separator and ASR components
    """
    
    def __init__(self, separator_config, asr_config, tuning_params=None):
        self.separator_config = separator_config
        self.asr_config = asr_config
        self.sr = 16000
        self.chunk_size = 3.0  # seconds
        self.overlap = 0.05     # seconds
        
        # TUNABLE PARAMETERS 
        default_params = {
            'similarity_threshold': 0.3,      # D-vector similarity threshold
            'energy_threshold': 0.001,        # Audio energy threshold 
            'min_text_length': 2,             # Minimum text length
            'confidence_threshold': 0.0,      # ASR confidence threshold (if available)
            'time_gap_threshold': 1.0,        # Time gap for merging segments
            'min_audio_length': 0.5,          # Minimum audio length for d-vector extraction
            'debug_similarities': True,        # Show similarity scores for debugging
            'contrast_threshold': 0.1,        # Minimum contrast between speakers
        }
        
      
        self.params = default_params.copy()
        if tuning_params:
            self.params.update(tuning_params)
        
        # print(f"TUNING PARAMETERS:")
        # for key, value in self.params.items():
        #     print(f"   {key}: {value}")
        
        # Initialize your working models
        self._init_separator()
        self._init_asr()
        self._init_embedder()
        
        # Cache for reference d-vectors
        self.reference_dvecs = {}
    
    def _init_separator(self):
        """Build ``self.separator`` from the paths in ``separator_config``.

        'return_dvec' is optional in the config and defaults to False.
        """
        cfg = self.separator_config
        config_path = cfg["config_path"]
        embedder_path = cfg["embedder_path"]
        checkpoint_path = cfg["checkpoint_path"]
        want_dvec = cfg.get("return_dvec", False)
        self.separator = VoiceSeparator(
            config_path=config_path,
            embedder_path=embedder_path,
            checkpoint_path=checkpoint_path,
            return_dvec=want_dvec,
        )
    
    def _init_asr(self):
        """Create the Hugging Face speech-recognition pipeline.

        Runs on CUDA when available, otherwise on CPU. Extra constructor
        arguments may be supplied under ``asr_config["init_kwargs"]``.
        """
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        model_name = self.asr_config["model"]
        extra_kwargs = self.asr_config.get("init_kwargs", {})
        self.stt_pipe = pipeline(
            "automatic-speech-recognition",
            model=model_name,
            device=target_device,
            **extra_kwargs,
        )
    
    def _init_embedder(self):
        """Load the speaker embedder used for d-vector matching.

        Hyper-parameters come from ``separator_config["config_path"]`` and
        weights from ``separator_config["embedder_path"]``. The model is put
        in eval mode, and the mel front end (``self.audio_processor``) and
        ``self.device`` are set up for later embedding calls.
        """
        run_on = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cfg = self.separator_config
        self.hp = HParam(cfg["config_path"])
        self.embedder = SpeechEmbedder(self.hp).to(run_on)
        # NOTE(review): torch.load unpickles the checkpoint; only point
        # embedder_path at trusted files (consider weights_only=True where
        # the installed torch supports it).
        weights = torch.load(cfg["embedder_path"], map_location=run_on)
        self.embedder.load_state_dict(weights)
        self.embedder.eval()
        self.audio_processor = Audio(self.hp)
        self.device = run_on
    
    def extract_dvec(self, audio_path):
        """
        Compute the speaker d-vector (embedding) of an audio file.

        Args:
            audio_path: path to an audio file readable by librosa.

        Returns:
            Embedding tensor from ``self.embedder`` with an extra leading
            batch axis added by ``unsqueeze(0)`` (presumably [1, emb_dim]),
            on ``self.device`` and detached from autograd.
        """
        # Load at the pipeline rate so reference embeddings come from the
        # same signal rate as the separated chunks they are compared with.
        wav, _ = librosa.load(audio_path, sr=self.sr)
        mel = self.audio_processor.get_mel(wav)
        mel = torch.from_numpy(mel).float().to(self.device)
        # Inference only: without no_grad the cached reference d-vectors
        # would keep their autograd graph (and activations) alive.
        with torch.no_grad():
            return self.embedder(mel).unsqueeze(0)
    
    def get_reference_dvecs(self, speaker_references):
        """Cache reference d-vectors for all speakers"""
        print(" Extracting reference d-vectors...")
        for speaker_id, ref_path in speaker_references.items():
            if speaker_id not in self.reference_dvecs:
                self.reference_dvecs[speaker_id] = self.extract_dvec(ref_path)
                print(f"  ✓ {speaker_id}: {Path(ref_path).name}")
    
    def chunk_audio(self, audio_data):
        """Split audio into overlapping chunks"""
        chunk_samples = int(self.chunk_size * self.sr)
        hop_samples = int((self.chunk_size - self.overlap) * self.sr)
        
        chunks = []
        for i in range(0, len(audio_data) - chunk_samples + 1, hop_samples):
            chunk = audio_data[i:i + chunk_samples]
            start_time = i / self.sr
            end_time = (i + chunk_samples) / self.sr
            
            chunks.append({
                'data': chunk,
                'start_time': start_time,
                'end_time': end_time,
                'chunk_id': len(chunks)
            })
        
        return chunks
    
    def separate_chunk_for_all_speakers(self, chunk, speaker_references, temp_dir=None):
        """
        Separate one chunk once per speaker, conditioned on each reference.

        The separator reads files, so the chunk is written to a uniquely named
        temporary WAV. That file is removed afterwards even if something fails.

        Args:
            chunk: dict from ``chunk_audio`` (uses 'data' and 'chunk_id').
            speaker_references: mapping speaker id -> reference audio path.
            temp_dir: directory for the temporary WAV. Defaults to the
                platform temp dir (``tempfile.gettempdir()``), which also
                works on Windows, where "/tmp" usually does not exist.

        Returns:
            dict speaker id -> separated audio. If separation fails for a
            speaker, a zero array shaped like the chunk is used instead.
        """
        import tempfile

        if temp_dir is None:
            temp_dir = tempfile.gettempdir()
        # mkstemp yields a unique name, so concurrent runs cannot clobber each other.
        fd, chunk_path = tempfile.mkstemp(
            prefix=f"chunk_{chunk['chunk_id']}_", suffix=".wav", dir=temp_dir
        )
        os.close(fd)  # soundfile reopens the path itself

        separated_audios = {}
        try:
            sf.write(chunk_path, chunk['data'], self.sr)
            for speaker_id, ref_path in speaker_references.items():
                try:
                    est_audio, _dvec = self.separator.separate(
                        reference_file=ref_path,
                        mixed_file=chunk_path,
                        out_dir=None,  # don't let the separator save files
                    )
                    separated_audios[speaker_id] = est_audio
                except Exception as e:
                    # Best effort: a failed speaker becomes silence, others continue.
                    print(f"Separation failed for {speaker_id} in chunk {chunk['chunk_id']}: {e}")
                    separated_audios[speaker_id] = np.zeros_like(chunk['data'])
        finally:
            if os.path.exists(chunk_path):
                os.remove(chunk_path)

        return separated_audios
    
    def diarize_using_dvector_matching(self, separated_audios, chunk_info, speaker_references):
        """
        Decide which speakers are active in one chunk by d-vector matching.

        For each speaker, that speaker's separated audio is embedded with
        ``self.embedder``. The embedding is compared by cosine similarity with
        the cached reference d-vector of the same speaker and with those of
        all other speakers. A speaker counts as active only if all three
        checks pass:

        * similarity to its own reference > params['similarity_threshold']
        * contrast (own similarity minus the highest similarity to any other
          reference) > params['contrast_threshold']
        * mean of the squared separated samples > params['energy_threshold']

        Args:
            separated_audios: mapping speaker id -> separated waveform for
                this chunk (presumably a 1-D numpy array at self.sr - confirm
                against VoiceSeparator.separate). Must contain every key of
                ``speaker_references``, otherwise a KeyError is raised below.
            chunk_info: dict with 'start_time', 'end_time' and 'chunk_id'.
            speaker_references: mapping speaker id -> reference path; only
                the keys are used here.

        Returns:
            dict speaker id -> {'separated_audio', 'similarity_score',
            'contrast_score', 'energy_score'}, for active speakers only.

        Note:
            ``self.reference_dvecs`` must already contain every speaker (see
            ``get_reference_dvecs``); otherwise a KeyError is raised.
        """
        # NOTE(review): chunk_end and chunk_id are read but not used below.
        chunk_start = chunk_info['start_time']
        chunk_end = chunk_info['end_time']
        chunk_id = chunk_info['chunk_id']
        
        # Extract d-vectors from separated audios
        # (None marks audio that was too short or failed to embed)
        chunk_dvecs = {}
        for speaker_id, separated_audio in separated_audios.items():
            if len(separated_audio) > self.params['min_audio_length'] * self.sr:
                try:
                    mel = self.audio_processor.get_mel(separated_audio)
                    mel_tensor = torch.from_numpy(mel).float().to(self.device)
                    with torch.no_grad():
                        emb = self.embedder(mel_tensor).unsqueeze(0)  # [1, emb_dim]
                    chunk_dvecs[speaker_id] = emb
                except Exception as e:
                    if self.params['debug_similarities']:
                        print(f"     D-vector extraction failed for {speaker_id}: {e}")
                    chunk_dvecs[speaker_id] = None
            else:
                if self.params['debug_similarities']:
                    print(f"  {speaker_id}: Audio too short for d-vector extraction "
                          f"({len(separated_audio)/self.sr:.2f}s < {self.params['min_audio_length']}s)")
                chunk_dvecs[speaker_id] = None
        
        # Compare with reference d-vectors to determine active speakers
        active_speakers = {}
        
        if self.params['debug_similarities']:
            print(f"     SIMILARITY ANALYSIS:")
        
        for speaker_id in speaker_references.keys():
            # KeyError if this speaker is missing from separated_audios
            if chunk_dvecs[speaker_id] is not None:
                ref_dvec = self.reference_dvecs[speaker_id]
                
                # Compute similarity with correct speaker
                sim_correct = torch.nn.functional.cosine_similarity(
                    chunk_dvecs[speaker_id], ref_dvec
                ).item()
                
                # Compute similarity with other speakers (for contrast)
                sim_others = []
                sim_details = {}
                for other_speaker_id in speaker_references.keys():
                    if other_speaker_id != speaker_id:
                        other_ref_dvec = self.reference_dvecs[other_speaker_id]
                        sim_other = torch.nn.functional.cosine_similarity(
                            chunk_dvecs[speaker_id], other_ref_dvec
                        ).item()
                        sim_others.append(sim_other)
                        sim_details[other_speaker_id] = sim_other
                
                # Margin of the own reference over the closest competitor
                # (0.0 stands in when there is only one speaker)
                max_sim_other = max(sim_others) if sim_others else 0.0
                contrast_score = sim_correct - max_sim_other
                
                # Check energy level (mean of squared samples)
                energy = np.mean(separated_audios[speaker_id] ** 2)
                
                if self.params['debug_similarities']:
                    print(f"      {speaker_id}: sim_self={sim_correct:.3f}, "
                          f"sim_others={max_sim_other:.3f}, contrast={contrast_score:.3f}, "
                          f"energy={energy:.6f}")
                    for other_id, sim_val in sim_details.items():
                        print(f"        vs {other_id}: {sim_val:.3f}")
                
                # Decision logic (all comparisons are strict)
                passes_similarity = sim_correct > self.params['similarity_threshold']
                passes_contrast = contrast_score > self.params['contrast_threshold']
                passes_energy = energy > self.params['energy_threshold']
                
                if self.params['debug_similarities']:
                    print(f"        Tests: similarity={passes_similarity} "
                          f"({sim_correct:.3f} > {self.params['similarity_threshold']}), "
                          f"contrast={passes_contrast} "
                          f"({contrast_score:.3f} > {self.params['contrast_threshold']}), "
                          f"energy={passes_energy} "
                          f"({energy:.6f} > {self.params['energy_threshold']})")
                
                # Speaker is active if passes all tests
                if passes_similarity and passes_contrast and passes_energy:
                    active_speakers[speaker_id] = {
                        'separated_audio': separated_audios[speaker_id],
                        'similarity_score': sim_correct,
                        'contrast_score': contrast_score,
                        'energy_score': energy
                    }
                    # Printed even when debug_similarities is False
                    print(f"   [{chunk_start:5.1f}s] {speaker_id} ACTIVE "
                          f"(sim: {sim_correct:.3f}, contrast: {contrast_score:.3f}, energy: {energy:.6f})")
                else:
                    if self.params['debug_similarities']:
                        print(f"   [{chunk_start:5.1f}s] {speaker_id} INACTIVE "
                              f"(failed: {[t for t, p in [('sim', passes_similarity), ('contrast', passes_contrast), ('energy', passes_energy)] if not p]})")
            else:
                if self.params['debug_similarities']:
                    print(f"      {speaker_id}: No d-vector extracted")
        
        return active_speakers
    
    def transcribe_active_speakers(self, active_speakers, chunk_info):
        """
        Run ASR on each active speaker's separated audio for a single chunk.

        Args:
            active_speakers: Mapping speaker_id -> dict with 'separated_audio',
                'energy_score', 'similarity_score' and 'contrast_score', as
                produced by the diarization step.
            chunk_info: Chunk dict providing 'start_time', 'end_time', 'chunk_id'.

        Returns:
            A list of transcription dicts, one per speaker whose audio passes the
            energy gate and whose text passes ``_is_valid_transcription``.
            Any per-speaker exception is printed and that speaker is skipped.
        """
        kept = []

        for spk, info in active_speakers.items():
            try:
                audio = info['separated_audio']
                energy = info['energy_score']

                # Second energy gate: skip (near-)silent separated output.
                too_quiet = energy < self.params['energy_threshold']
                if too_quiet:
                    if self.params['debug_similarities']:
                        print(f"      {spk}: Audio energy too low for transcription")
                    continue

                asr_out = self.stt_pipe(audio)
                text = asr_out["text"].strip()
                # Fall back to 1.0 when the ASR output carries no 'confidence' key.
                confidence = asr_out.get("confidence", 1.0)

                if self.params['debug_similarities']:
                    print(f"     {spk} raw transcription: '{text}' (conf: {confidence:.3f})")

                if not self._is_valid_transcription(text, confidence):
                    if self.params['debug_similarities']:
                        print(f"     {spk}: Transcription filtered out")
                    continue

                record = {
                    'speaker_id': spk,
                    'start_time': chunk_info['start_time'],
                    'end_time': chunk_info['end_time'],
                    'text': text,
                    'chunk_id': chunk_info['chunk_id'],
                    'similarity_score': info['similarity_score'],
                    'contrast_score': info['contrast_score'],
                    'energy_score': energy,
                    'confidence': confidence,
                }
                kept.append(record)
                print(f"     {spk}: {text}")

            except Exception as e:
                # Best-effort: one failing speaker must not abort the whole chunk.
                print(f" Transcription failed for {spk}: {e}")

        return kept
    
    def _is_valid_transcription(self, text, confidence=1.0):
        """Filter transcriptions using tunable parameters"""
        if not text:
            return False
        
        # Length check
        if len(text) < self.params['min_text_length']:
            if self.params['debug_similarities']:
                print(f"        Filter: Text too short ({len(text)} < {self.params['min_text_length']})")
            return False
        
        # Confidence check
        if confidence < self.params['confidence_threshold']:
            if self.params['debug_similarities']:
                print(f"        Filter: Confidence too low ({confidence:.3f} < {self.params['confidence_threshold']})")
            return False
        
        # Filter out common noise patterns
        text_lower = text.lower().strip()
        noise_patterns = [
            r'^[\s\-\.]+$',                    # Only punctuation/spaces
            r'^(uh|um|ah|er|mm|hmm)[\s\-]*$',  # Filler words only
            r'^\[.*\]$',                       # Bracket annotations
            r'^[\(\)]+$',                      # Only parentheses
        ]
        
        for pattern in noise_patterns:
            if re.match(pattern, text_lower):
                if self.params['debug_similarities']:
                    print(f"        Filter: Matches noise pattern '{pattern}'")
                return False
        
        return True
    
    def _normalize_text(self, text):
        """Your text normalization"""
        text = text.lower()
        text = re.sub(r'[^\w\s]', '', text) 
        return " ".join(text.split())
    
    def process_conversation(self, conversation_metadata, output_dir):
        """
        Run separation -> d-vector diarization -> ASR over one conversation.

        Args:
            conversation_metadata: One metadata entry providing 'conversation_id',
                'conversation_audio' and 'speakers' (each speaker entry must
                have a 'reference_audio' path).
            output_dir: Directory for the per-speaker JSONs and the summary
                (created if missing).

        Returns:
            (speaker_transcriptions, json_paths): the merged segments per
            speaker id, and the JSON path written for each speaker.
        """
        conv_id = conversation_metadata['conversation_id']
        audio_path = conversation_metadata['conversation_audio']
        speakers = conversation_metadata['speakers']
        rule = '=' * 60

        print(f"\n{rule}")
        print(f" PROCESSING CONVERSATION {conv_id} WITH YOUR WORKING MODELS")
        print(rule)
        print(f" Audio: {Path(audio_path).name}")
        print(f" Speakers: {list(speakers)}")

        # One reference recording per expected speaker.
        speaker_refs = {sid: meta['reference_audio'] for sid, meta in speakers.items()}
        self.get_reference_dvecs(speaker_refs)  # return value unused here

        audio_data, _ = librosa.load(audio_path, sr=self.sr)
        print(f"  Duration: {len(audio_data) / self.sr:.1f} seconds")

        chunks = self.chunk_audio(audio_data)
        print(f" Created {len(chunks)} chunks ({self.chunk_size}s each, {self.overlap}s overlap)")

        speaker_transcriptions = {sid: [] for sid in speaker_refs}

        print("\n Processing chunks...")
        for chunk in tqdm(chunks, desc="Processing"):
            print(f"\nChunk {chunk['chunk_id']} [{chunk['start_time']:.1f}s - {chunk['end_time']:.1f}s]:")

            separated = self.separate_chunk_for_all_speakers(chunk, speaker_refs)
            active = self.diarize_using_dvector_matching(separated, chunk, speaker_refs)
            if not active:
                print("   No active speakers detected")
                continue

            for item in self.transcribe_active_speakers(active, chunk):
                speaker_transcriptions[item['speaker_id']].append(item)

        print("\n Merging consecutive segments...")
        for sid in speaker_transcriptions:
            before = len(speaker_transcriptions[sid])
            speaker_transcriptions[sid] = self._merge_consecutive_segments(speaker_transcriptions[sid])
            after = len(speaker_transcriptions[sid])
            print(f"  {sid}: {before} → {after} segments "
                  f"(gap threshold: {self.params['time_gap_threshold']}s)")

        print("\n Saving speaker JSONs...")
        json_paths = {}
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        for sid, segs in speaker_transcriptions.items():
            json_paths[sid] = self._save_speaker_json(sid, segs, output_dir, conv_id)
            covered = sum(seg['end_time'] - seg['start_time'] for seg in segs)
            print(f"   {sid}: {len(segs)} segments, {covered:.1f}s")

        self._create_summary(speaker_transcriptions, output_dir, conv_id, conversation_metadata)

        print("\n PROCESSING COMPLETE!")
        print(f" Output: {output_dir}")

        return speaker_transcriptions, json_paths
    
    def _merge_consecutive_segments(self, transcriptions):
        """Merge consecutive segments from same speaker using tunable threshold"""
        if not transcriptions:
            return []
        
        # Sort by start time
        transcriptions.sort(key=lambda x: x['start_time'])
        
        merged = []
        current = transcriptions[0].copy()
        
        for next_trans in transcriptions[1:]:
            # If segments are close in time, merge them
            time_gap = next_trans['start_time'] - current['end_time']
            
            if time_gap <= self.params['time_gap_threshold']:
                current['end_time'] = next_trans['end_time']
                current['text'] = current['text'] + ' ' + next_trans['text']
                current['similarity_score'] = max(
                    current['similarity_score'], next_trans['similarity_score']
                )
                if self.params['debug_similarities']:
                    print(f"      Merged segments (gap: {time_gap:.2f}s <= {self.params['time_gap_threshold']}s)")
            else:
                merged.append(current)
                current = next_trans.copy()
        
        merged.append(current)
        return merged
    
    def _save_speaker_json(self, speaker_id, transcriptions, output_dir, conversation_id):
        """Save transcriptions for one speaker to JSON"""
       
        segments = []
        for trans in transcriptions:
            segments.append({
                'start_time': trans['start_time'],
                'end_time': trans['end_time'],
                'duration': trans['end_time'] - trans['start_time'],
                'text': trans['text'],
                'chunk_id': trans['chunk_id'],
                'similarity_score': trans.get('similarity_score', 0.0),
                'contrast_score': trans.get('contrast_score', 0.0)
            })
        
        # Create speaker JSON
        speaker_data = {
            'speaker_id': speaker_id,
            'conversation_id': conversation_id,
            'total_segments': len(segments),
            'total_duration': sum(seg['duration'] for seg in segments),
            'segments': segments
        }
        
        # Save JSON
        json_path = output_dir / f"{speaker_id}_transcription.json"
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(speaker_data, f, indent=2, ensure_ascii=False)
        
        return json_path
    
    def _create_summary(self, speaker_transcriptions, output_dir, conv_id, conversation_metadata):
        """Create summary with timeline and ground truth comparison"""
        # Create timeline
        all_segments = []
        for speaker_id, transcriptions in speaker_transcriptions.items():
            for trans in transcriptions:
                all_segments.append({
                    'speaker_id': speaker_id,
                    'start_time': trans['start_time'],
                    'end_time': trans['end_time'],
                    'text': trans['text'],
                    'similarity_score': trans.get('similarity_score', 0.0)
                })
        
        # Sort by time
        all_segments.sort(key=lambda x: x['start_time'])
        
        # Create summary
        summary_data = {
            'conversation_id': conv_id,
            'processing_method': 'working_separator_asr_pipeline',
            'speakers': {
                speaker_id: {
                    'total_segments': len(transcriptions),
                    'total_duration': sum(trans['end_time'] - trans['start_time'] for trans in transcriptions)
                }
                for speaker_id, transcriptions in speaker_transcriptions.items()
            },
            'timeline': all_segments,
            'ground_truth': conversation_metadata['speakers']
        }
        
        # Save summary
        summary_path = output_dir / f"conversation_{conv_id:03d}_summary.json"
        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary_data, f, indent=2, ensure_ascii=False)
        
        print(f" Summary saved: {summary_path}")


def process_conversation_with_your_models(conversations_metadata_path,
                                          conversation_id=0,
                                          separator_config=None,
                                          asr_config=None,
                                          tuning_params=None,
                                          output_dir="working_transcriptions"):
    """
    Process one conversation with the separator + ASR pipeline and tunable thresholds.

    Args:
        conversations_metadata_path: Path to the JSON list of conversation metadata.
        conversation_id: Zero-based index into that list.
        separator_config: Voice separator configuration (required).
        asr_config: ASR configuration (required).
        tuning_params: Dict of tunable thresholds, passed to the pipeline.
        output_dir: Root output directory. Results go to
            ``<output_dir>/conversation_<id:03d>``.

    Returns:
        ``(transcriptions, json_paths)`` from
        ``WorkingConversationPipeline.process_conversation``, or ``(None, None)``
        when ``conversation_id`` is out of range.

    Raises:
        ValueError: If ``separator_config`` or ``asr_config`` is missing.
    """
    # Both configs are mandatory; there are no built-in defaults.
    if separator_config is None:
        raise ValueError("Please provide a valid separator_config dictionary.")

    if asr_config is None:
        raise ValueError("Please provide a valid asr_config dictionary.")

    with open(conversations_metadata_path, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    # Reject negative ids as well. Python would otherwise silently index from the
    # end and write to a "conversation_-01" directory.
    if not 0 <= conversation_id < len(metadata):
        print(f" Conversation {conversation_id} not found! Available: 0-{len(metadata)-1}")
        return None, None

    conversation = metadata[conversation_id]

    conv_output_dir = Path(output_dir) / f"conversation_{conversation_id:03d}"

    pipeline = WorkingConversationPipeline(separator_config, asr_config, tuning_params)

    transcriptions, json_paths = pipeline.process_conversation(conversation, conv_output_dir)

    return transcriptions, json_paths



 

In [3]:
separator_config = {
    "config_path": "config/inference.yaml",
    "embedder_path": "ckpt/embedder.pt",
    "checkpoint_path": "ckpt/seperator_best_checkpoint.pt",
    "return_dvec": False,
}

asr_config = {
    "model": "ckpt/whisper-small",
    "init_kwargs": {},
}

# TUNABLE PARAMETERS
tuning_params = {
    "similarity_threshold": 0.2,  # Lower = more permissive (try 0.1-0.5)
    "energy_threshold": 0.0005,  # Lower = more permissive (try 0.0001-0.01)
    "min_text_length": 1,  # Lower = more permissive (try 1-5)
    "confidence_threshold": 0.0,  # ASR confidence threshold (try 0.0-0.5)
    "time_gap_threshold": 1.5,  # Time gap for merging (try 0.5-3.0)
    "min_audio_length": 0.3,  # Minimum audio for d-vector (try 0.1-1.0)
    "debug_similarities": False,  # Show detailed similarity scores
    "contrast_threshold": 0.05,  # Lower = more permissive (try 0.0-0.3)
}

transcriptions, json_paths = process_conversation_with_your_models(
    conversations_metadata_path="Pipeline_data/Conv/ver1/conversations_metadata.json",
    conversation_id=1,
    separator_config=separator_config,
    asr_config=asr_config,
    tuning_params=tuning_params,  # Pass tuning parameters
    output_dir="pipeline_data/Conv/ver1/seperated_transcriptions",
)

# `transcriptions` maps every expected speaker to a (possibly empty) list, so a
# plain truthiness check would report success even when nothing was transcribed.
if transcriptions and any(transcriptions.values()):
    print(f"\n SUCCESS! Generated speaker JSONs:")
    for speaker_id, json_path in json_paths.items():
        count = len(transcriptions[speaker_id])
        print(f" {speaker_id}: {json_path} ({count} segments)")
else:
    print(f"\n No transcriptions generated. Try adjusting tuning_params:")
    for param_name, param_value in tuning_params.items():
        print(f"   {param_name} = {param_value}")

Device set to use cuda



 PROCESSING CONVERSATION 1 WITH YOUR WORKING MODELS
 Audio: conversation_001.wav
 Speakers: ['7021', '7729', '1284', '4992']
 Extracting reference d-vectors...
  ✓ 7021: 7021.wav
  ✓ 7729: 7729.wav
  ✓ 1284: 1284.wav
  ✓ 4992: 4992.wav
  Duration: 93.8 seconds
 Created 31 chunks (3.0s each, 0.05s overlap)

 Processing chunks...


Processing:   0%|          | 0/31 [00:00<?, ?it/s]


Chunk 0 [0.0s - 3.0s]:


Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.


   [  0.0s] 4992 ACTIVE (sim: 0.747, contrast: 0.340, energy: 0.000756)


Processing:   3%|▎         | 1/31 [00:02<01:10,  2.34s/it]

     4992: Nancy's Curly Chestnut Crop

Chunk 1 [3.0s - 6.0s]:
   [  3.0s] 4992 ACTIVE (sim: 0.681, contrast: 0.222, energy: 0.001619)


Processing:   6%|▋         | 2/31 [00:03<00:50,  1.74s/it]

     4992: shown in the sun and olives thick black plates looked

Chunk 2 [5.9s - 8.9s]:
   [  5.9s] 7729 ACTIVE (sim: 0.705, contrast: 0.108, energy: 0.003128)


Processing:  10%|▉         | 3/31 [00:04<00:41,  1.47s/it]

     7729: But the affair was magnified.

Chunk 3 [8.8s - 11.8s]:
   [  8.8s] 7729 ACTIVE (sim: 0.882, contrast: 0.404, energy: 0.002855)


Processing:  13%|█▎        | 4/31 [00:06<00:37,  1.37s/it]

     7729: as a crowning proof that the Free State

Chunk 4 [11.8s - 14.8s]:
   [ 11.8s] 7729 ACTIVE (sim: 0.833, contrast: 0.403, energy: 0.001997)


Processing:  16%|█▌        | 5/31 [00:07<00:34,  1.32s/it]

     7729: and were insurrectionist and outlaws.

Chunk 5 [14.8s - 17.8s]:
   [ 14.8s] 7729 ACTIVE (sim: 0.868, contrast: 0.317, energy: 0.003991)


Processing:  19%|█▉        | 6/31 [00:08<00:32,  1.30s/it]

     7729: In a few days an officer came with a

Chunk 6 [17.7s - 20.7s]:
   [ 17.7s] 7729 ACTIVE (sim: 0.892, contrast: 0.485, energy: 0.002513)


Processing:  23%|██▎       | 7/31 [00:09<00:30,  1.27s/it]

     7729: requisition from Governor Shannon and to the

Chunk 7 [20.6s - 23.6s]:
   [ 20.6s] 7729 ACTIVE (sim: 0.857, contrast: 0.365, energy: 0.002182)


Processing:  26%|██▌       | 8/31 [00:10<00:28,  1.25s/it]

     7729: prisoner by land to Westport, and afterward

Chunk 8 [23.6s - 26.6s]:
   [ 23.6s] 7021 ACTIVE (sim: 0.627, contrast: 0.085, energy: 0.002389)
   [ 23.6s] 7729 ACTIVE (sim: 0.795, contrast: 0.325, energy: 0.001431)
     7021: It truly is


Processing:  29%|██▉       | 9/31 [00:12<00:30,  1.39s/it]

     7729: from there to Kansas City and Leavenworth.

Chunk 9 [26.6s - 29.6s]:


You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


   [ 26.6s] 1284 ACTIVE (sim: 0.648, contrast: 0.270, energy: 0.000555)


Processing:  32%|███▏      | 10/31 [00:13<00:27,  1.29s/it]

     1284: is asserted the

Chunk 10 [29.5s - 32.5s]:
   [ 29.5s] 1284 ACTIVE (sim: 0.708, contrast: 0.249, energy: 0.002833)
   [ 29.5s] 4992 ACTIVE (sim: 0.696, contrast: 0.231, energy: 0.000769)
     1284: was not so easy as you may suppose.


Processing:  35%|███▌      | 11/31 [00:15<00:28,  1.43s/it]

     4992: And yet you might own her.

Chunk 11 [32.5s - 35.5s]:
   [ 32.5s] 4992 ACTIVE (sim: 0.938, contrast: 0.461, energy: 0.001968)


Processing:  39%|███▊      | 12/31 [00:16<00:26,  1.38s/it]

     4992: her behavior has warranted them? Has it not been

Chunk 12 [35.4s - 38.4s]:
   [ 35.4s] 4992 ACTIVE (sim: 0.961, contrast: 0.459, energy: 0.001369)


Processing:  42%|████▏     | 13/31 [00:18<00:24,  1.37s/it]

     4992: in this particular incoherent and unaccountable.

Chunk 13 [38.4s - 41.4s]:
   [ 38.4s] 1284 ACTIVE (sim: 0.936, contrast: 0.494, energy: 0.003997)


Processing:  45%|████▌     | 14/31 [00:19<00:23,  1.38s/it]

     1284: prescribed by the civil and ecclesiastical powers of the Empire.

Chunk 14 [41.3s - 44.3s]:
   [ 41.3s] 1284 ACTIVE (sim: 0.930, contrast: 0.409, energy: 0.002537)


Processing:  48%|████▊     | 15/31 [00:20<00:21,  1.34s/it]

     1284: The Donatist still maintained in some proper

Chunk 15 [44.2s - 47.2s]:
   [ 44.2s] 1284 ACTIVE (sim: 0.928, contrast: 0.487, energy: 0.003048)


Processing:  52%|█████▏    | 16/31 [00:21<00:19,  1.31s/it]

     1284: particularly in the media, their superior numbers.

Chunk 16 [47.2s - 50.2s]:
   [ 47.2s] 1284 ACTIVE (sim: 0.942, contrast: 0.420, energy: 0.002828)


Processing:  55%|█████▍    | 17/31 [00:23<00:18,  1.29s/it]

     1284: numbers and 400 bishops acknowledged the jurisdiction.

Chunk 17 [50.1s - 53.1s]:
   [ 50.1s] 7021 ACTIVE (sim: 0.740, contrast: 0.391, energy: 0.002037)
   [ 50.1s] 1284 ACTIVE (sim: 0.842, contrast: 0.329, energy: 0.000918)
     7021: But his mother hugged him


Processing:  58%|█████▊    | 18/31 [00:24<00:18,  1.40s/it]

     1284: restriction of their primate.

Chunk 18 [53.1s - 56.1s]:
   [ 53.1s] 7021 ACTIVE (sim: 0.856, contrast: 0.462, energy: 0.001570)


Processing:  61%|██████▏   | 19/31 [00:25<00:15,  1.31s/it]

     7021: close. For instance,

Chunk 19 [56.0s - 59.0s]:
   [ 56.0s] 7021 ACTIVE (sim: 0.928, contrast: 0.477, energy: 0.004128)


Processing:  65%|██████▍   | 20/31 [00:27<00:14,  1.30s/it]

     7021: One day, the children had been playing a

Chunk 20 [59.0s - 62.0s]:
   [ 59.0s] 7021 ACTIVE (sim: 0.934, contrast: 0.547, energy: 0.003052)


Processing:  68%|██████▊   | 21/31 [00:28<00:13,  1.32s/it]

     7021: on the Piazza with blocks and other playthings.

Chunk 21 [62.0s - 65.0s]:
   [ 62.0s] 7021 ACTIVE (sim: 0.923, contrast: 0.565, energy: 0.004263)


Processing:  71%|███████   | 22/31 [00:29<00:11,  1.28s/it]

     7021: and finally had gone into the house.

Chunk 22 [64.9s - 67.9s]:
   [ 64.9s] 7021 ACTIVE (sim: 0.895, contrast: 0.512, energy: 0.006593)


Processing:  74%|███████▍  | 23/31 [00:31<00:10,  1.30s/it]

     7021: leaving all the things on the floor of the piazza.

Chunk 23 [67.8s - 70.8s]:
   [ 67.8s] 7021 ACTIVE (sim: 0.930, contrast: 0.538, energy: 0.004258)


Processing:  77%|███████▋  | 24/31 [00:32<00:08,  1.28s/it]

     7021: instead of putting them away in their place.

Chunk 24 [70.8s - 73.8s]:
   [ 70.8s] 7021 ACTIVE (sim: 0.902, contrast: 0.531, energy: 0.002554)
   [ 70.8s] 7729 ACTIVE (sim: 0.799, contrast: 0.280, energy: 0.003097)
     7021: as they ought to have done.


Processing:  81%|████████  | 25/31 [00:34<00:08,  1.41s/it]

     7729: The whole proceeding was so childish.

Chunk 25 [73.8s - 76.8s]:
   [ 73.8s] 7729 ACTIVE (sim: 0.937, contrast: 0.597, energy: 0.004132)


Processing:  84%|████████▍ | 26/31 [00:35<00:06,  1.35s/it]

     7729: The miserable plot, so transparent.

Chunk 26 [76.7s - 79.7s]:
   [ 76.7s] 7729 ACTIVE (sim: 0.943, contrast: 0.533, energy: 0.003615)


Processing:  87%|████████▋ | 27/31 [00:36<00:05,  1.31s/it]

     7729: the outrage so gross at spring disgust.

Chunk 27 [79.7s - 82.7s]:
   [ 79.7s] 7729 ACTIVE (sim: 0.928, contrast: 0.482, energy: 0.003854)


Processing:  90%|█████████ | 28/31 [00:37<00:03,  1.32s/it]

     7729: to the better class of border ruffians who were witness

Chunk 28 [82.6s - 85.6s]:
   [ 82.6s] 7729 ACTIVE (sim: 0.771, contrast: 0.162, energy: 0.001145)
   [ 82.6s] 4992 ACTIVE (sim: 0.742, contrast: 0.251, energy: 0.002557)
     7729: and accessories.


Processing:  94%|█████████▎| 29/31 [00:39<00:02,  1.45s/it]

     4992: What can you mean by that, Miss Woodward?

Chunk 29 [85.5s - 88.5s]:
   [ 85.5s] 7021 ACTIVE (sim: 0.658, contrast: 0.297, energy: 0.001460)
   [ 85.5s] 4992 ACTIVE (sim: 0.749, contrast: 0.296, energy: 0.001925)
     7021: You have come!


Processing:  97%|█████████▋| 30/31 [00:41<00:01,  1.47s/it]

     4992: You talk mysteriously.

Chunk 30 [88.5s - 91.5s]:
   [ 88.5s] 7021 ACTIVE (sim: 0.816, contrast: 0.358, energy: 0.001007)


Processing: 100%|██████████| 31/31 [00:42<00:00,  1.37s/it]

     7021: Andella. Andella was the name of Gene's.

 Merging consecutive segments...
  7021: 11 → 3 segments (gap threshold: 1.5s)
  7729: 12 → 2 segments (gap threshold: 1.5s)
  1284: 7 → 2 segments (gap threshold: 1.5s)
  4992: 7 → 3 segments (gap threshold: 1.5s)

 Saving speaker JSONs...
   7021: 3 segments, 32.6s
   7729: 2 segments, 35.5s
   1284: 2 segments, 20.7s
   4992: 3 segments, 20.8s
 Summary saved: pipeline_data\Conv\ver1\seperated_transcriptions\conversation_001\conversation_001_summary.json

 PROCESSING COMPLETE!
 Output: pipeline_data\Conv\ver1\seperated_transcriptions\conversation_001

 SUCCESS! Generated speaker JSONs:
 7021: pipeline_data\Conv\ver1\seperated_transcriptions\conversation_001\7021_transcription.json (3 segments)
 7729: pipeline_data\Conv\ver1\seperated_transcriptions\conversation_001\7729_transcription.json (2 segments)
 1284: pipeline_data\Conv\ver1\seperated_transcriptions\conversation_001\1284_transcription.json (2 segments)
 4992: pipeline_data\C




In [18]:
#!pip install seaborn

## Evaluator

In [4]:
import os
import json
import re
import numpy as np
from pathlib import Path
from jiwer import wer, mer, wil
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns


class ConversationPipelineEvaluator:
    """
    Evaluator for the WorkingConversationPipeline
    Assesses both transcription quality and speaker diarization accuracy
    """
    
    def __init__(self, config):
        self.config = config
        self.results = []
        self.speaker_results = {}
        self.conversation_results = {}
        
    def _normalize_text(self, text):
        """Normalize text for comparison"""
        if not text:
            return ""
        text = text.lower()
        text = re.sub(r"[^\w\s]", "", text)
        return " ".join(text.split())
    
    def _load_ground_truth(self):
        """Load ground truth conversation metadata"""
        with open(self.config["ground_truth_path"], "r") as f:
            return json.load(f)
    
    def _load_pipeline_output(self, conversation_id, output_dir):
        """Load pipeline output for a specific conversation"""
        output_dir = Path(output_dir)
        
        # Load summary
        summary_path = output_dir / f"conversation_{conversation_id:03d}_summary.json"
        if not summary_path.exists():
            return None, {}
        
        with open(summary_path, "r") as f:
            summary = json.load(f)
        
        # Load individual speaker JSONs
        speaker_data = {}
        for speaker_id in summary.get("speakers", {}).keys():
            speaker_path = output_dir / f"{speaker_id}_transcription.json"
            if speaker_path.exists():
                with open(speaker_path, "r") as f:
                    speaker_data[speaker_id] = json.load(f)
        
        return summary, speaker_data
    
    def _extract_ground_truth_segments(self, gt_conversation):
        """Extract ground truth segments by speaker"""
        gt_segments = {}
        
        # Check if ground truth has detailed segments or turns structure
        if "detailed_segments" in gt_conversation:
            # Use detailed ground truth segments (original format)
            for segment in gt_conversation["detailed_segments"]:
                speaker_id = segment["speaker_id"]
                if speaker_id not in gt_segments:
                    gt_segments[speaker_id] = []
                gt_segments[speaker_id].append({
                    "start_time": segment["start_time"],
                    "end_time": segment["end_time"],
                    "text": segment["text"]
                })
        elif "speakers" in gt_conversation:
            # Use speakers with turns structure (your format)
            for speaker_id, speaker_info in gt_conversation["speakers"].items():
                if "turns" in speaker_info:
                    # Extract segments from turns
                    gt_segments[speaker_id] = []
                    for turn in speaker_info["turns"]:
                        gt_segments[speaker_id].append({
                            "start_time": turn["start_time"],
                            "end_time": turn["end_time"], 
                            "text": turn["transcript"]
                        })
                elif "full_transcript" in speaker_info:
                    # Use overall speaker transcripts (create single segment per speaker)
                    gt_segments[speaker_id] = [{
                        "start_time": 0,
                        "end_time": gt_conversation.get("total_duration", 300),
                        "text": speaker_info["full_transcript"]
                    }]
        
        return gt_segments
    
    def _calculate_speaker_wer(self, predicted_segments, ground_truth_segments):
        """Calculate WER for a specific speaker"""
        if not predicted_segments or not ground_truth_segments:
            return {
                "wer": 1.0,  # 100% error if no segments
                "mer": 1.0,
                "wil": 1.0,
                "predicted_text": "",
                "ground_truth_text": "",
                "segment_count": 0
            }
        
        # Combine all segments into full text
        pred_text = " ".join([seg["text"] for seg in predicted_segments])
        gt_text = " ".join([seg["text"] for seg in ground_truth_segments])
        
        # Normalize texts
        pred_norm = self._normalize_text(pred_text)
        gt_norm = self._normalize_text(gt_text)
        
        # Calculate metrics
        if not gt_norm:
            return {
                "wer": 1.0 if pred_norm else 0.0,
                "mer": 1.0 if pred_norm else 0.0,
                "wil": 1.0 if pred_norm else 0.0,
                "predicted_text": pred_text,
                "ground_truth_text": gt_text,
                "segment_count": len(predicted_segments)
            }
        
        try:
            speaker_wer = wer(gt_norm, pred_norm)
            speaker_mer = mer(gt_norm, pred_norm)
            speaker_wil = wil(gt_norm, pred_norm)
        except:
            # Handle edge cases
            speaker_wer = 1.0
            speaker_mer = 1.0
            speaker_wil = 1.0
        
        return {
            "wer": speaker_wer,
            "mer": speaker_mer,
            "wil": speaker_wil,
            "predicted_text": pred_text,
            "ground_truth_text": gt_text,
            "normalized_predicted": pred_norm,
            "normalized_ground_truth": gt_norm,
            "segment_count": len(predicted_segments)
        }
    
    def _calculate_diarization_metrics(self, predicted_timeline, ground_truth_segments):
        """Calculate diarization accuracy metrics"""

        total_time = 0
        correct_time = 0
        
        sampling_rate = 0.1
        max_time = max([seg["end_time"] for segments in ground_truth_segments.values() 
                       for seg in segments], default=300)
        
        for t in np.arange(0, max_time, sampling_rate):
            # Find ground truth speaker at time t
            gt_speaker = None
            for speaker_id, segments in ground_truth_segments.items():
                for seg in segments:
                    if seg["start_time"] <= t <= seg["end_time"]:
                        gt_speaker = speaker_id
                        break
                if gt_speaker:
                    break
            
            # Find predicted speaker at time t
            pred_speaker = None
            for seg in predicted_timeline:
                if seg["start_time"] <= t <= seg["end_time"]:
                    pred_speaker = seg["speaker_id"]
                    break
            
            if gt_speaker is not None:  # Only count time where there should be speech
                total_time += sampling_rate
                if gt_speaker == pred_speaker:
                    correct_time += sampling_rate
        
        diarization_accuracy = correct_time / total_time if total_time > 0 else 0.0
        
        return {
            "diarization_accuracy": diarization_accuracy,
            "total_evaluated_time": total_time,
            "correctly_assigned_time": correct_time
        }
    
    def _calculate_speaker_coverage(self, predicted_speakers, ground_truth_speakers):
        """Calculate how well we detected each speaker"""
        coverage = {}
        
        for gt_speaker in ground_truth_speakers:
            if gt_speaker in predicted_speakers:
                pred_duration = sum(seg["duration"] for seg in predicted_speakers[gt_speaker])
                # Rough estimate of ground truth duration
                gt_duration = 60  # Default assumption, should be calculated from GT
                coverage[gt_speaker] = {
                    "detected": True,
                    "predicted_duration": pred_duration,
                    "coverage_ratio": min(pred_duration / gt_duration, 1.0)
                }
            else:
                coverage[gt_speaker] = {
                    "detected": False,
                    "predicted_duration": 0,
                    "coverage_ratio": 0.0
                }
        
        # Check for false positive speakers
        for pred_speaker in predicted_speakers:
            if pred_speaker not in ground_truth_speakers:
                coverage[f"{pred_speaker}_FP"] = {
                    "false_positive": True,
                    "predicted_duration": sum(seg["duration"] for seg in predicted_speakers[pred_speaker])
                }
        
        return coverage
    
    def evaluate_conversation(self, conversation_id, output_dir):
        """Evaluate one conversation's pipeline output against its ground truth.

        Loads the ground truth and the pipeline output through the evaluator's
        loader helpers, computes per-speaker WER, diarization accuracy and
        speaker coverage, stores the result in ``self.conversation_results``
        and returns it.

        Args:
            conversation_id: Index into the ground-truth collection returned by
                ``_load_ground_truth``.
            output_dir: Directory holding this conversation's pipeline output.

        Returns:
            The result dict, or None when the ground truth or the pipeline
            output for this conversation is missing.
        """
        print(f"\n Evaluating Conversation {conversation_id}")

        # Load ground truth
        ground_truth = self._load_ground_truth()
        if conversation_id >= len(ground_truth):
            print(f" Conversation {conversation_id} not found in ground truth")
            return None

        gt_conversation = ground_truth[conversation_id]
        print(f"   Ground truth loaded: {len(gt_conversation.get('speakers', {}))} speakers")

        # Load pipeline output
        summary, speaker_data = self._load_pipeline_output(conversation_id, output_dir)
        if summary is None:
            print(f" Pipeline output not found for conversation {conversation_id}")
            return None

        print(f"   Pipeline output loaded: {len(speaker_data)} speaker files")

        # Extract ground truth segments
        gt_segments = self._extract_ground_truth_segments(gt_conversation)
        print(f"   Ground truth segments extracted for speakers: {list(gt_segments.keys())}")

        # Debug: show segment counts
        for speaker_id, segments in gt_segments.items():
            total_duration = sum(seg["end_time"] - seg["start_time"] for seg in segments)
            print(f"    {speaker_id}: {len(segments)} segments, {total_duration:.1f}s total")

        # Evaluate each ground-truth speaker
        speaker_metrics = {}
        for speaker_id in gt_segments.keys():
            print(f"   Evaluating {speaker_id}...")

            predicted_segments = []
            if speaker_id in speaker_data:
                predicted_segments = speaker_data[speaker_id]["segments"]
                print(f"    Found {len(predicted_segments)} predicted segments")
            else:
                print(f"    No predicted segments found for {speaker_id}")

            speaker_wer_result = self._calculate_speaker_wer(
                predicted_segments, gt_segments[speaker_id]
            )
            speaker_metrics[speaker_id] = speaker_wer_result

            print(f"    WER: {speaker_wer_result['wer']:.2%}")
            print(f"    Segments: {speaker_wer_result['segment_count']}")

        # Diarization metrics
        predicted_timeline = summary.get("timeline", [])
        print(f"   Timeline has {len(predicted_timeline)} segments")
        diarization_metrics = self._calculate_diarization_metrics(predicted_timeline, gt_segments)

        # Speaker coverage. Pass the full segment dict (not just its keys) so the
        # coverage calculation can use the real ground-truth speech durations.
        predicted_speakers = {sid: data["segments"] for sid, data in speaker_data.items()}
        coverage_metrics = self._calculate_speaker_coverage(predicted_speakers, gt_segments)

        # Ground-truth speakers that actually received output; extra (false
        # positive) predicted speakers must not inflate this count.
        detected_gt_speakers = sum(1 for sid in gt_segments if sid in speaker_data)

        conversation_result = {
            "conversation_id": conversation_id,
            "speaker_metrics": speaker_metrics,
            "diarization_metrics": diarization_metrics,
            "coverage_metrics": coverage_metrics,
            "overall_wer": float(np.mean([m["wer"] for m in speaker_metrics.values()]) if speaker_metrics else 1.0),
            "total_predicted_speakers": len(speaker_data),
            "total_ground_truth_speakers": len(gt_segments),
            "predicted_segments_total": sum(len(data["segments"]) for data in speaker_data.values()),
            "detected_ground_truth_speakers": detected_gt_speakers,
        }

        self.conversation_results[conversation_id] = conversation_result

        print(f"   Overall WER: {conversation_result['overall_wer']:.2%}")
        print(f"   Diarization Accuracy: {diarization_metrics['diarization_accuracy']:.2%}")
        print(f"   Speaker Detection: {detected_gt_speakers}/{len(gt_segments)}")

        return conversation_result
    
    def evaluate_all_conversations(self, output_base_dir):
        """Evaluate every ``conversation_<id>`` directory under ``output_base_dir``.

        The numeric ID is parsed from each directory name, so both
        ``conversation_001`` and ``conversation_1`` map to ID 1. Entries that
        are not directories, or whose suffix is not an integer, are ignored,
        and each ID is evaluated once. Afterwards overall statistics are
        computed and the JSON results and text report are written.

        Args:
            output_base_dir: Directory containing the per-conversation folders.

        Returns:
            List of per-conversation result dicts. The list is empty when
            nothing could be evaluated; in that case no statistics or report
            are produced.
        """
        print(f" Starting comprehensive evaluation...")

        output_base_path = Path(output_base_dir)
        conversation_dirs = [
            path for path in output_base_path.glob("conversation_*") if path.is_dir()
        ]

        if not conversation_dirs:
            print(f" No conversation directories found in {output_base_dir}")
            return []

        # Remember the directory that was actually found for each ID instead of
        # rebuilding the path from the ID: names that are not zero-padded to three
        # digits would otherwise point at a non-existent folder. When both
        # spellings exist, prefer the canonical "conversation_%03d" one.
        dirs_by_id = {}
        for dir_path in conversation_dirs:
            try:
                conv_id = int(dir_path.name.split("_")[1])
            except (IndexError, ValueError):
                continue  # not a numbered conversation folder
            canonical_name = f"conversation_{conv_id:03d}"
            if conv_id not in dirs_by_id or dir_path.name == canonical_name:
                dirs_by_id[conv_id] = dir_path

        conversation_ids = sorted(dirs_by_id)
        print(f" Found {len(conversation_ids)} conversations to evaluate")

        all_results = []
        for conv_id in tqdm(conversation_ids, desc="Evaluating conversations"):
            result = self.evaluate_conversation(conv_id, dirs_by_id[conv_id])
            if result:
                all_results.append(result)

        if not all_results:
            # Statistics/report need at least one result (the report reads them).
            print(" No conversations could be evaluated; skipping statistics and report.")
            return all_results

        self._calculate_overall_statistics(all_results)
        self._save_evaluation_results(all_results)
        self._generate_evaluation_report(all_results)

        return all_results
    
    def _calculate_overall_statistics(self, all_results):
        """Calculate overall statistics across all conversations"""
        if not all_results:
            return
        
        # WER statistics
        all_wers = []
        speaker_wers = {}
        
        for result in all_results:
            all_wers.append(result["overall_wer"])
            
            for speaker_id, metrics in result["speaker_metrics"].items():
                if speaker_id not in speaker_wers:
                    speaker_wers[speaker_id] = []
                speaker_wers[speaker_id].append(metrics["wer"])
        
        # Diarization statistics
        diarization_scores = [r["diarization_metrics"]["diarization_accuracy"] for r in all_results]
        
        # Speaker detection statistics
        detection_rates = []
        for result in all_results:
            detected = result["total_predicted_speakers"]
            total = result["total_ground_truth_speakers"]
            detection_rates.append(detected / total if total > 0 else 0)
        
        # Convert numpy types to Python types for JSON serialization
        self.overall_stats = {
            "wer": {
                "mean": float(np.mean(all_wers)),
                "std": float(np.std(all_wers)),
                "min": float(np.min(all_wers)),
                "max": float(np.max(all_wers)),
                "median": float(np.median(all_wers))
            },
            "diarization_accuracy": {
                "mean": float(np.mean(diarization_scores)),
                "std": float(np.std(diarization_scores)), 
                "min": float(np.min(diarization_scores)),
                "max": float(np.max(diarization_scores)),
                "median": float(np.median(diarization_scores))
            },
            "speaker_detection_rate": {
                "mean": float(np.mean(detection_rates)),
                "std": float(np.std(detection_rates)),
                "min": float(np.min(detection_rates)),
                "max": float(np.max(detection_rates)),
                "median": float(np.median(detection_rates))
            },
            "per_speaker_wer": {
                speaker_id: {
                    "mean": float(np.mean(wers)),
                    "count": int(len(wers))
                } for speaker_id, wers in speaker_wers.items()
            },
            "total_conversations": int(len(all_results))
        }
    
    def _convert_numpy_types(self, obj):
        """Recursively convert numpy types to Python native types for JSON serialization"""
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, dict):
            return {key: self._convert_numpy_types(value) for key, value in obj.items()}
        elif isinstance(obj, list):
            return [self._convert_numpy_types(item) for item in obj]
        else:
            return obj
    
    def _save_evaluation_results(self, all_results):
        """Save detailed evaluation results to JSON"""
        output_dir = Path(self.config["output_dir"])
        output_dir.mkdir(parents=True, exist_ok=True)
        
        # Convert numpy types to Python types
        safe_results = self._convert_numpy_types({
            "config": self.config,
            "overall_statistics": self.overall_stats,
            "conversation_results": all_results
        })
        
        # Save detailed results
        results_path = output_dir / "detailed_evaluation_results.json"
        with open(results_path, "w") as f:
            json.dump(safe_results, f, indent=2)
        
        print(f" Detailed results saved to: {results_path}")
    
    def _generate_evaluation_report(self, all_results):
        """Generate a human-readable evaluation report"""
        output_dir = Path(self.config["output_dir"])
        report_path = output_dir / "evaluation_report.txt"
        
        with open(report_path, "w") as f:
            f.write("="*80 + "\n")
            f.write("CONVERSATION PIPELINE EVALUATION REPORT\n")
            f.write("="*80 + "\n\n")
            
            # Overall Statistics
            f.write("OVERALL PERFORMANCE METRICS\n")
            f.write("-"*40 + "\n")
            f.write(f"Total Conversations Evaluated: {self.overall_stats['total_conversations']}\n\n")
            
            # WER Results
            wer_stats = self.overall_stats['wer']
            f.write("WORD ERROR RATE (WER)\n")
            f.write(f"  Mean WER: {wer_stats['mean']:.2%}\n")
            f.write(f"  Median WER: {wer_stats['median']:.2%}\n")
            f.write(f"  Best WER: {wer_stats['min']:.2%}\n")
            f.write(f"  Worst WER: {wer_stats['max']:.2%}\n")
            f.write(f"  Std Dev: ±{wer_stats['std']:.2%}\n\n")
            
            # Diarization Results
            diar_stats = self.overall_stats['diarization_accuracy']
            f.write("SPEAKER DIARIZATION ACCURACY\n")
            f.write(f"  Mean Accuracy: {diar_stats['mean']:.2%}\n")
            f.write(f"  Median Accuracy: {diar_stats['median']:.2%}\n")
            f.write(f"  Best Accuracy: {diar_stats['max']:.2%}\n")
            f.write(f"  Worst Accuracy: {diar_stats['min']:.2%}\n")
            f.write(f"  Std Dev: ±{diar_stats['std']:.2%}\n\n")
            
            # Speaker Detection
            detect_stats = self.overall_stats['speaker_detection_rate']
            f.write("SPEAKER DETECTION RATE\n")
            f.write(f"  Mean Detection Rate: {detect_stats['mean']:.2%}\n")
            f.write(f"  Median Detection Rate: {detect_stats['median']:.2%}\n")
            f.write(f"  Perfect Detection Rate: {detect_stats['max']:.2%}\n")
            f.write(f"  Worst Detection Rate: {detect_stats['min']:.2%}\n\n")
            
            # Per-Speaker Analysis
            f.write("PER-SPEAKER WER ANALYSIS\n")
            f.write("-"*40 + "\n")
            for speaker_id, stats in self.overall_stats['per_speaker_wer'].items():
                f.write(f"  {speaker_id}: {stats['mean']:.2%} WER ({stats['count']} conversations)\n")
            f.write("\n")
            
            # Conversation-by-Conversation Results
            f.write("CONVERSATION-BY-CONVERSATION RESULTS\n")
            f.write("-"*40 + "\n")
            for result in all_results:
                conv_id = result['conversation_id']
                overall_wer = result['overall_wer']
                diar_acc = result['diarization_metrics']['diarization_accuracy']
                speakers_detected = result['total_predicted_speakers']
                speakers_total = result['total_ground_truth_speakers']
                
                f.write(f"Conversation {conv_id:03d}: WER={overall_wer:.2%}, "
                       f"Diarization={diar_acc:.2%}, Speakers={speakers_detected}/{speakers_total}\n")
        
        print(f" Evaluation report saved to: {report_path}")
        
        # Print summary to console
        print(f"\n EVALUATION SUMMARY")
        print(f"{'='*50}")
        print(f"Mean WER: {wer_stats['mean']:.2%}")
        print(f"Mean Diarization Accuracy: {diar_stats['mean']:.2%}")
        print(f"Mean Speaker Detection Rate: {detect_stats['mean']:.2%}")
        print(f"Total Conversations: {self.overall_stats['total_conversations']}")


def evaluate_pipeline_results(
    ground_truth_path,
    pipeline_output_dir,
    output_dir="evaluation_results",
    conversation_ids=None
):
    """
    Evaluate pipeline results against ground truth.

    Args:
        ground_truth_path: Path to the conversation metadata JSON (ground truth).
        pipeline_output_dir: Directory containing the per-conversation pipeline outputs.
        output_dir: Directory to save the evaluation results and report into.
        conversation_ids: Specific conversation IDs to evaluate (None for all).

    Returns:
        List of per-conversation result dicts (empty if nothing could be
        evaluated; statistics and report are only written when it is non-empty).
    """
    config = {
        "ground_truth_path": ground_truth_path,
        "pipeline_output_dir": pipeline_output_dir,
        "output_dir": output_dir,
        "conversation_ids": conversation_ids,
    }

    evaluator = ConversationPipelineEvaluator(config)

    if conversation_ids is None:
        # Evaluate every conversation directory found under pipeline_output_dir.
        results = evaluator.evaluate_all_conversations(pipeline_output_dir)
        # Some evaluator versions return None when no directories are found.
        return results if results is not None else []

    # Evaluate only the requested conversations.
    results = []
    for conv_id in conversation_ids:
        conv_dir = Path(pipeline_output_dir) / f"conversation_{conv_id:03d}"
        result = evaluator.evaluate_conversation(conv_id, conv_dir)
        if result:
            results.append(result)

    if not results:
        # The statistics/report steps need at least one result to work with.
        print(" No conversations could be evaluated; skipping statistics and report.")
        return results

    evaluator._calculate_overall_statistics(results)
    evaluator._save_evaluation_results(results)
    evaluator._generate_evaluation_report(results)

    return results



In [8]:

# Evaluate every conversation found under pipeline_output_dir against the
# ground-truth metadata and write the JSON results and text report to
# output_dir.
# NOTE(review): ground_truth_path starts with "Pipeline_data" while
# pipeline_output_dir starts with "pipeline_data". The run below appears to
# be on Windows (backslash paths in its output), where paths are
# case-insensitive; on a case-sensitive filesystem one of them would not
# resolve. Confirm the intended directory name.
results = evaluate_pipeline_results(
    ground_truth_path="Pipeline_data/Conv/ver1/conversations_metadata.json",
    pipeline_output_dir="pipeline_data/Conv/ver1/seperated_transcriptions",
    output_dir="evaluation_results",
)

 Starting comprehensive evaluation...
 Found 1 conversations to evaluate


Evaluating conversations: 100%|██████████| 1/1 [00:00<00:00, 88.95it/s]


 Evaluating Conversation 1
   Ground truth loaded: 4 speakers
   Pipeline output loaded: 4 speaker files
   Ground truth segments extracted for speakers: ['7021', '7729', '1284', '4992']
    7021: 3 segments, 27.7s total
    7729: 3 segments, 31.4s total
    1284: 3 segments, 20.6s total
    4992: 3 segments, 20.1s total
   Evaluating 7021...
    Found 3 predicted segments
    WER: 18.57%
    Segments: 3
   Evaluating 7729...
    Found 2 predicted segments
    WER: 10.13%
    Segments: 2
   Evaluating 1284...
    Found 2 predicted segments
    WER: 30.00%
    Segments: 2
   Evaluating 4992...
    Found 3 predicted segments
    WER: 18.75%
    Segments: 3
   Timeline has 10 segments
   Overall WER: 19.36%
   Diarization Accuracy: 89.03%
   Speaker Detection: 4/4
 Detailed results saved to: evaluation_results\detailed_evaluation_results.json
 Evaluation report saved to: evaluation_results\evaluation_report.txt

 EVALUATION SUMMARY
Mean WER: 19.36%
Mean Diarization Accuracy: 89.03%
Mean 




## End to End Pipeline

In [9]:
import os
import json
import numpy as np
import librosa
import soundfile as sf
from pathlib import Path
import torch
from transformers import pipeline
from tqdm import tqdm
import re

from src.seperator import VoiceSeparator
from utils.audio import Audio
from utils.hparams import HParam
from model.embedder import SpeechEmbedder


class AudioSeparatorTranscriber:
    """
    Ready-made setup for audio separation and transcription.
    Takes a mixed WAV file and reference audio paths, then processes everything.
    """

    def __init__(self, separator_config, asr_config, processing_params=None):
        """
        Initialize the separator and transcriber.

        Args:
            separator_config: Dict with separator model paths
            asr_config: Dict with ASR model configuration
            processing_params: Dict with processing parameters
        """
        self.separator_config = separator_config
        self.asr_config = asr_config
        self.target_sr = 16000
        self.chunk_size = 3.0  # seconds
        self.overlap = 0.05  # seconds

        # Processing parameters
        default_params = {
            "similarity_threshold": 0.3,  # D-vector similarity threshold
            "energy_threshold": 0.001,  # Audio energy threshold
            "min_text_length": 2,  # Minimum text length
            "confidence_threshold": 0.0,  # ASR confidence threshold
            "time_gap_threshold": 1.0,  # Time gap for merging segments
            "min_audio_length": 0.5,  # Minimum audio length for d-vector
            "contrast_threshold": 0.1,  # Minimum contrast between speakers
            "debug_mode": False,  # Show detailed debug info
        }

        self.params = default_params.copy()
        if processing_params:
            self.params.update(processing_params)

        # Initialize models
        self._init_separator()
        self._init_asr()
        self._init_embedder()

        # Cache for reference d-vectors
        self.reference_dvecs = {}

    def _init_separator(self):
        """Create the VoiceSeparator from the paths in ``separator_config``."""
        cfg = self.separator_config
        self.separator = VoiceSeparator(
            config_path=cfg["config_path"],
            embedder_path=cfg["embedder_path"],
            checkpoint_path=cfg["checkpoint_path"],
            return_dvec=cfg.get("return_dvec", False),
        )

    def _init_asr(self):
        """Create the transformers speech-recognition pipeline (CUDA if available, else CPU).

        ``asr_config["model"]`` selects the model; ``asr_config["init_kwargs"]``
        (optional) is forwarded to ``pipeline`` unchanged.
        """
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        cfg = self.asr_config
        self.stt_pipe = pipeline(
            "automatic-speech-recognition",
            model=cfg["model"],
            device=target_device,
            **cfg.get("init_kwargs", {}),
        )

    def _init_embedder(self):
        """Load the SpeechEmbedder used for d-vectors and its mel-spectrogram front-end.

        Hyper-parameters are read from ``separator_config["config_path"]`` and
        weights from ``separator_config["embedder_path"]``. The model is put in
        eval mode on CUDA when available, otherwise on CPU.
        """
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cfg = self.separator_config

        self.hp = HParam(cfg["config_path"])
        self.embedder = SpeechEmbedder(self.hp).to(target_device)
        # torch.load unpickles the checkpoint: only point embedder_path at trusted files.
        weights = torch.load(cfg["embedder_path"], map_location=target_device)
        self.embedder.load_state_dict(weights)
        self.embedder.eval()

        self.audio_processor = Audio(self.hp)
        self.device = target_device

    def preprocess_audio(self, audio_path, output_path=None):
        """
        Load audio, resample it to ``target_sr`` and peak-normalise it.

        Args:
            audio_path: Path to the input audio file.
            output_path: Optional path; when given, the processed audio is also
                written there at ``target_sr``.

        Returns:
            audio_data: Processed sample array. Its peak magnitude is 1.0 unless
                the input is silent or empty, in which case it is left unscaled.
            duration: Duration in seconds at ``target_sr``.
        """
        print(f"Preprocessing audio: {Path(audio_path).name}")

        # sr=None keeps the file's native rate so it can be reported and resampled explicitly.
        audio_data, original_sr = librosa.load(audio_path, sr=None)
        print(f"  Original: {original_sr}Hz, {len(audio_data)/original_sr:.2f}s")

        if original_sr != self.target_sr:
            audio_data = librosa.resample(
                audio_data, orig_sr=original_sr, target_sr=self.target_sr
            )
            print(f"  Resampled to {self.target_sr}Hz")

        # Peak-normalise. Silent (all-zero) or empty input has no peak to divide
        # by: the division would turn every sample into NaN, and np.max raises
        # on an empty array, so such input is left as-is.
        peak = float(np.max(np.abs(audio_data))) if audio_data.size else 0.0
        if peak > 0:
            audio_data = audio_data / peak
            print(f"  Normalized to [-1, 1]")
        else:
            print(f"  Skipped normalization (silent or empty audio)")

        duration = len(audio_data) / self.target_sr
        print(f"  Final: {self.target_sr}Hz, {duration:.2f}s")

        if output_path:
            sf.write(output_path, audio_data, self.target_sr)
            print(f"  Saved to: {output_path}")

        return audio_data, duration

    def extract_dvec(self, audio_path):
        """Return the embedder's d-vector for ``audio_path`` with a leading batch axis.

        The file is loaded at ``target_sr`` and converted to a mel spectrogram by
        ``self.audio_processor``. The embedding is computed without gradient
        tracking.
        """
        waveform = librosa.load(audio_path, sr=self.target_sr)[0]
        mel_tensor = (
            torch.from_numpy(self.audio_processor.get_mel(waveform))
            .float()
            .to(self.device)
        )
        with torch.no_grad():
            embedding = self.embedder(mel_tensor)
            return embedding.unsqueeze(0)

    def prepare_references(self, reference_paths):
        """
        Validate the reference recordings and cache one d-vector per speaker.

        Args:
            reference_paths: Dict ``{speaker_id: audio_path}``, or a list/tuple of
                audio paths (speakers are then named ``Speaker_1``, ``Speaker_2``, ...).

        Returns:
            reference_info: Dict ``{speaker_id: audio_path}``. The d-vectors are
            stored in ``self.reference_dvecs`` under the same keys.

        Raises:
            ValueError: if ``reference_paths`` is neither a dict nor a list/tuple.
            FileNotFoundError: if a reference file does not exist.
        """
        import tempfile  # only needed for the scratch files written below

        print("\nPreparing reference audio files...")

        if isinstance(reference_paths, (list, tuple)):
            reference_info = {
                f"Speaker_{i+1}": path for i, path in enumerate(reference_paths)
            }
        elif isinstance(reference_paths, dict):
            reference_info = reference_paths.copy()
        else:
            raise ValueError("reference_paths must be a list or dict")

        for speaker_id, ref_path in reference_info.items():
            if not os.path.exists(ref_path):
                raise FileNotFoundError(f"Reference audio not found: {ref_path}")

            ref_audio, _ = self.preprocess_audio(ref_path)

            # extract_dvec takes a file path, so the preprocessed audio goes to a
            # private temp file. mkstemp works on every OS (a hard-coded /tmp does
            # not exist on Windows) and does not put the speaker ID in the name.
            fd, temp_ref_path = tempfile.mkstemp(prefix="ref_", suffix=".wav")
            os.close(fd)
            try:
                sf.write(temp_ref_path, ref_audio, self.target_sr)
                self.reference_dvecs[speaker_id] = self.extract_dvec(temp_ref_path)
            finally:
                # Clean up even if writing or d-vector extraction fails.
                if os.path.exists(temp_ref_path):
                    os.remove(temp_ref_path)

            print(f"   {speaker_id}: {Path(ref_path).name}")

        return reference_info

    def chunk_audio(self, audio_data):
        """Split audio into overlapping chunks"""
        chunk_samples = int(self.chunk_size * self.target_sr)
        hop_samples = int((self.chunk_size - self.overlap) * self.target_sr)

        chunks = []
        for i in range(0, len(audio_data) - chunk_samples + 1, hop_samples):
            chunk = audio_data[i : i + chunk_samples]
            start_time = i / self.target_sr
            end_time = (i + chunk_samples) / self.target_sr

            chunks.append(
                {
                    "data": chunk,
                    "start_time": start_time,
                    "end_time": end_time,
                    "chunk_id": len(chunks),
                }
            )

        return chunks

    def separate_chunk(self, chunk, reference_info, temp_dir=None):
        """Separate one audio chunk once per reference speaker.

        The chunk and each (peak-normalised) reference are written to scratch
        WAV files because the separator works on file paths. Separation is best
        effort: if it fails for a speaker, that speaker gets an all-zero track
        of the chunk's length (the error is printed only in debug mode).

        Args:
            chunk: Chunk dict from ``chunk_audio`` (uses ``data`` and ``chunk_id``).
            reference_info: Dict ``{speaker_id: reference_audio_path}``.
            temp_dir: Directory for scratch files; defaults to the system temp
                directory (``tempfile.gettempdir()``), which also works on Windows.

        Returns:
            Dict ``{speaker_id: separated_audio}``.
        """
        import tempfile  # only needed to locate the default scratch directory

        if temp_dir is None:
            temp_dir = tempfile.gettempdir()

        chunk_path = os.path.join(temp_dir, f"chunk_{chunk['chunk_id']}.wav")
        separated_audios = {}

        try:
            sf.write(chunk_path, chunk["data"], self.target_sr)

            for speaker_id, ref_path in reference_info.items():
                temp_ref_path = os.path.join(temp_dir, f"temp_ref_{speaker_id}.wav")
                try:
                    ref_audio, _ = librosa.load(ref_path, sr=self.target_sr)
                    # Peak-normalise; a silent reference would otherwise become NaN.
                    peak = float(np.max(np.abs(ref_audio))) if ref_audio.size else 0.0
                    if peak > 0:
                        ref_audio = ref_audio / peak
                    sf.write(temp_ref_path, ref_audio, self.target_sr)

                    est_audio, _dvec = self.separator.separate(
                        reference_file=temp_ref_path,
                        mixed_file=chunk_path,
                        out_dir=None,
                    )
                    separated_audios[speaker_id] = est_audio
                except Exception as e:
                    # Deliberate best effort: treat a failed separation as silence.
                    if self.params["debug_mode"]:
                        print(f"Separation failed for {speaker_id}: {e}")
                    separated_audios[speaker_id] = np.zeros_like(chunk["data"])
                finally:
                    if os.path.exists(temp_ref_path):
                        os.remove(temp_ref_path)
        finally:
            # Always remove the chunk file, even if writing or a speaker failed.
            if os.path.exists(chunk_path):
                os.remove(chunk_path)

        return separated_audios

    def identify_active_speakers(self, separated_audios, chunk_info):
        """Identify active speakers using d-vector similarity"""
        active_speakers = {}

        if self.params["debug_mode"]:
            print(f"    Analyzing chunk {chunk_info['chunk_id']}:")

        for speaker_id, separated_audio in separated_audios.items():
            if len(separated_audio) < self.params["min_audio_length"] * self.target_sr:
                continue

            try:
                # Extract d-vector from separated audio
                mel = self.audio_processor.get_mel(separated_audio)
                mel_tensor = torch.from_numpy(mel).float().to(self.device)
                with torch.no_grad():
                    chunk_dvec = self.embedder(mel_tensor).unsqueeze(0)

                # Compare with reference
                ref_dvec = self.reference_dvecs[speaker_id]
                similarity = torch.nn.functional.cosine_similarity(
                    chunk_dvec, ref_dvec
                ).item()

                # Calculate contrast with other speakers
                other_similarities = []
                for other_speaker_id in self.reference_dvecs.keys():
                    if other_speaker_id != speaker_id:
                        other_ref = self.reference_dvecs[other_speaker_id]
                        other_sim = torch.nn.functional.cosine_similarity(
                            chunk_dvec, other_ref
                        ).item()
                        other_similarities.append(other_sim)

                max_other_sim = max(other_similarities) if other_similarities else 0.0
                contrast = similarity - max_other_sim

                # Calculate energy
                energy = np.mean(separated_audio**2)

                # Check thresholds
                if (
                    similarity > self.params["similarity_threshold"]
                    and contrast > self.params["contrast_threshold"]
                    and energy > self.params["energy_threshold"]
                ):

                    active_speakers[speaker_id] = {
                        "audio": separated_audio,
                        "similarity": similarity,
                        "contrast": contrast,
                        "energy": energy,
                    }

                    if self.params["debug_mode"]:
                        print(
                            f"      {speaker_id}: ACTIVE (sim={similarity:.3f}, "
                            f"contrast={contrast:.3f}, energy={energy:.6f})"
                        )
                elif self.params["debug_mode"]:
                    print(
                        f"      {speaker_id}: inactive (sim={similarity:.3f}, "
                        f"contrast={contrast:.3f}, energy={energy:.6f})"
                    )

            except Exception as e:
                if self.params["debug_mode"]:
                    print(f"      {speaker_id}: Error in analysis: {e}")

        return active_speakers

    def transcribe_speakers(self, active_speakers, chunk_info):
        """Transcribe audio for active speakers"""
        transcriptions = []

        for speaker_id, speaker_data in active_speakers.items():
            try:
                result = self.stt_pipe(speaker_data["audio"])
                text = result["text"].strip()
                confidence = result.get("confidence", 1.0)

                if self._is_valid_text(text, confidence):
                    transcriptions.append(
                        {
                            "speaker_id": speaker_id,
                            "start_time": chunk_info["start_time"],
                            "end_time": chunk_info["end_time"],
                            "text": text,
                            "confidence": confidence,
                            "similarity": speaker_data["similarity"],
                            "contrast": speaker_data["contrast"],
                            "energy": speaker_data["energy"],
                        }
                    )

                    if self.params["debug_mode"]:
                        print(f"      {speaker_id}: '{text}' (conf={confidence:.3f})")

            except Exception as e:
                if self.params["debug_mode"]:
                    print(f"      {speaker_id}: Transcription error: {e}")

        return transcriptions

    def _is_valid_text(self, text, confidence):
        """Check if transcribed text is valid"""
        if not text or len(text) < self.params["min_text_length"]:
            return False

        if confidence < self.params["confidence_threshold"]:
            return False

        # Filter noise patterns
        text_lower = text.lower().strip()
        noise_patterns = [
            r"^[\s\-\.]+$",
            r"^(uh|um|ah|er|mm|hmm)[\s\-]*$",
            r"^\[.*\]$",
            r"^[\(\)]+$",
        ]

        for pattern in noise_patterns:
            if re.match(pattern, text_lower):
                return False

        return True

    def merge_segments(self, transcriptions):
        """Merge consecutive segments from same speaker"""
        if not transcriptions:
            return []

        # Group by speaker
        speaker_segments = {}
        for trans in transcriptions:
            speaker_id = trans["speaker_id"]
            if speaker_id not in speaker_segments:
                speaker_segments[speaker_id] = []
            speaker_segments[speaker_id].append(trans)

        # Merge segments for each speaker
        merged_segments = []
        for speaker_id, segments in speaker_segments.items():
            segments.sort(key=lambda x: x["start_time"])

            if not segments:
                continue

            current = segments[0].copy()

            for next_seg in segments[1:]:
                time_gap = next_seg["start_time"] - current["end_time"]

                if time_gap <= self.params["time_gap_threshold"]:
                    # Merge segments
                    current["end_time"] = next_seg["end_time"]
                    current["text"] += " " + next_seg["text"]
                    current["confidence"] = max(
                        current["confidence"], next_seg["confidence"]
                    )
                else:
                    # Start new segment
                    merged_segments.append(current)
                    current = next_seg.copy()

            merged_segments.append(current)

        # Sort all segments by time
        merged_segments.sort(key=lambda x: x["start_time"])
        return merged_segments

    def process(self, mixed_audio_path, reference_paths, output_dir="output"):
        """
        Run the full separate-then-transcribe pipeline on one mixed recording.

        Steps:

        1. Preprocess the mix.
        2. Compute the reference d-vectors.
        3. For every chunk: separate per speaker, keep the active speakers and
           transcribe them.
        4. Merge each speaker's consecutive segments.
        5. Save everything via ``_save_results``.

        Args:
            mixed_audio_path: Path to the mixed audio file.
            reference_paths: Dict ``{speaker_id: ref_path}`` or list of reference paths.
            output_dir: Output directory for results (created if missing).

        Returns:
            Whatever ``_save_results`` returns for the merged transcriptions.
        """
        banner = "=" * 60
        print(banner)
        print("AUDIO SEPARATOR AND TRANSCRIBER")
        print(banner)

        out_path = Path(output_dir)
        out_path.mkdir(parents=True, exist_ok=True)

        print("\n1. PREPROCESSING MIXED AUDIO")
        mixed_audio, duration = self.preprocess_audio(mixed_audio_path)

        print("\n2. PREPARING REFERENCE AUDIO")
        reference_info = self.prepare_references(reference_paths)

        print(f"\n3. PROCESSING AUDIO ({duration:.1f}s)")
        chunks = self.chunk_audio(mixed_audio)
        print(f"Created {len(chunks)} chunks ({self.chunk_size}s each)")

        collected = []
        for chunk in tqdm(chunks, desc="Processing chunks"):
            if self.params["debug_mode"]:
                print(
                    f"\n  Chunk {chunk['chunk_id']} "
                    f"[{chunk['start_time']:.1f}s - {chunk['end_time']:.1f}s]:"
                )

            separated = self.separate_chunk(chunk, reference_info)
            active = self.identify_active_speakers(separated, chunk)
            if not active:
                continue
            collected.extend(self.transcribe_speakers(active, chunk))

        print(f"\n4. MERGING SEGMENTS")
        merged = self.merge_segments(collected)
        print(f"Merged {len(collected)} → {len(merged)} segments")

        print(f"\n5. SAVING RESULTS")
        results = self._save_results(merged, reference_info, out_path)

        print(f"\nPROCESSING COMPLETE!")
        print(f"Output directory: {out_path}")

        return results

    def _save_results(self, transcriptions, reference_info, output_dir):
        """
        Write the merged transcriptions to disk and return the summary dict.

        Files written into ``output_dir``:
          * ``transcription_results.json``: the returned results dict;
          * ``<speaker_id>_transcription.json``: one per speaker that has at
            least one segment (speakers with no segments get no file);
          * ``transcript.txt``: one human-readable line per segment.

        Args:
            transcriptions: Segment dicts with ``speaker_id``, ``start_time``,
                ``end_time``, ``text`` and ``confidence`` keys, in the order
                they should appear in the timeline.
            reference_info: Mapping whose keys are reported under
                ``processing_info["speakers"]``.
            output_dir: ``pathlib.Path`` of an existing directory (paths are
                built with ``/``, so a plain ``str`` is not accepted).

        Returns:
            Dict with ``processing_info``, ``speaker_statistics`` (keyed by
            speaker in first-appearance order) and ``timeline``.
        """
        timeline = []
        per_speaker = {}

        for seg in transcriptions:
            who = seg["speaker_id"]
            begin = seg["start_time"]
            finish = seg["end_time"]
            span = finish - begin

            timeline.append(
                {
                    "speaker": who,
                    "start_time": begin,
                    "end_time": finish,
                    "duration": span,
                    "text": seg["text"],
                    "confidence": seg["confidence"],
                }
            )

            # Key order of this dict is what ends up in the JSON files.
            entry = per_speaker.setdefault(
                who,
                {"total_segments": 0, "total_duration": 0.0, "segments": []},
            )
            entry["total_segments"] += 1
            entry["total_duration"] += span
            entry["segments"].append(
                {
                    "start_time": begin,
                    "end_time": finish,
                    "text": seg["text"],
                    "confidence": seg["confidence"],
                }
            )

        results = {
            "processing_info": {
                "total_segments": len(transcriptions),
                "speakers": list(reference_info.keys()),
                "parameters": self.params,
            },
            "speaker_statistics": per_speaker,
            "timeline": timeline,
        }

        dump_opts = {"indent": 2, "ensure_ascii": False}

        summary_file = output_dir / "transcription_results.json"
        with open(summary_file, "w", encoding="utf-8") as fh:
            json.dump(results, fh, **dump_opts)
        print(f"  Results: {summary_file}")

        for who, entry in per_speaker.items():
            per_file = output_dir / f"{who}_transcription.json"
            # speaker_id first, then the stats keys in their stored order.
            payload = {"speaker_id": who, **entry}
            with open(per_file, "w", encoding="utf-8") as fh:
                json.dump(payload, fh, **dump_opts)
            print(f"  {who}: {per_file}")

        txt_file = output_dir / "transcript.txt"
        with open(txt_file, "w", encoding="utf-8") as fh:
            fh.write("AUDIO TRANSCRIPTION\n")
            fh.write("=" * 50 + "\n\n")
            fh.writelines(
                f"[{row['start_time']:6.1f}s - {row['end_time']:6.1f}s] "
                f"{row['speaker']}: {row['text']}\n"
                for row in timeline
            )
        print(f"  Transcript: {txt_file}")

        return results


def process_audio(
    mixed_audio_path,
    reference_paths,
    separator_config,
    asr_config,
    processing_params=None,
    output_dir="output",
):
    """
    One-shot helper: build an ``AudioSeparatorTranscriber`` and run it once.

    A new transcriber instance is constructed on every call. When several
    files need processing, create one instance and call ``process`` on it
    directly instead.

    Args:
        mixed_audio_path: Path to mixed audio file
        reference_paths: Dict {speaker_id: ref_path} or List of ref paths
        separator_config: Separator model configuration (required)
        asr_config: ASR model configuration (required)
        processing_params: Optional processing parameters, forwarded as-is
        output_dir: Output directory

    Returns:
        results: Transcription results returned by ``process``.
    """
    pipeline = AudioSeparatorTranscriber(
        separator_config=separator_config,
        asr_config=asr_config,
        processing_params=processing_params,
    )
    return pipeline.process(
        mixed_audio_path,
        reference_paths,
        output_dir=output_dir,
    )


In [10]:
# Model locations. "seperator" (sic) is kept verbatim: it is the path that gets
# opened, so correcting the spelling here would point at a different file.
separator_config = {
    "config_path": "config/inference.yaml",
    "embedder_path": "ckpt/embedder.pt",
    "checkpoint_path": "ckpt/seperator_best_checkpoint.pt",
    "return_dvec": False,
}

asr_config = {
    "model": "ckpt/whisper-small",
    "init_kwargs": {},
}

# Forwarded to the transcriber as ``processing_params``.
processing_params = {
    "similarity_threshold": 0.2,
    "energy_threshold": 0.0005,
    "debug_mode": False,
}

# Single root for every data path in this cell. Previously the mixture was read
# from "pipeline_data/..." while the references and outputs used
# "Pipeline_data/..."; that mismatch only resolves on case-insensitive
# filesystems (e.g. Windows) and fails on case-sensitive ones (typical Linux).
DATA_ROOT = "Pipeline_data/Conv/ver1"

# Output label -> reference clip id (file name stem under DATA_ROOT/reference/).
REFERENCE_IDS = {
    "Speaker_A": "8463",
    "Speaker_B": "6829",
    "Speaker_C": "7021",
    "Speaker_D": "8224",
}

reference_paths = {
    label: f"{DATA_ROOT}/reference/{ref_id}.wav"
    for label, ref_id in REFERENCE_IDS.items()
}

results = process_audio(
    mixed_audio_path=f"{DATA_ROOT}/conversations/conversation_000.wav",
    reference_paths=reference_paths,
    separator_config=separator_config,
    asr_config=asr_config,
    processing_params=processing_params,
    output_dir=f"{DATA_ROOT}/end_to_end_results",
)

# Short per-speaker summary of what was transcribed.
print("Processing completed!")
print(f"Found {len(results['timeline'])} segments")
for speaker, stats in results["speaker_statistics"].items():
    print(
        f"{speaker}: {stats['total_segments']} segments, "
        f"{stats['total_duration']:.1f}s total"
    )

Device set to use cuda


AUDIO SEPARATOR AND TRANSCRIBER

1. PREPROCESSING MIXED AUDIO
Preprocessing audio: conversation_000.wav
  Original: 16000Hz, 84.44s
  Normalized to [-1, 1]
  Final: 16000Hz, 84.44s

2. PREPARING REFERENCE AUDIO

Preparing reference audio files...
Preprocessing audio: 8463.wav
  Original: 16000Hz, 5.98s
  Normalized to [-1, 1]
  Final: 16000Hz, 5.98s
   Speaker_A: 8463.wav
Preprocessing audio: 6829.wav
  Original: 16000Hz, 5.32s
  Normalized to [-1, 1]
  Final: 16000Hz, 5.32s
   Speaker_B: 6829.wav
Preprocessing audio: 7021.wav
  Original: 16000Hz, 18.41s
  Normalized to [-1, 1]
  Final: 16000Hz, 18.41s
   Speaker_C: 7021.wav
Preprocessing audio: 8224.wav
  Original: 16000Hz, 26.16s
  Normalized to [-1, 1]
  Final: 16000Hz, 26.16s
   Speaker_D: 8224.wav

3. PROCESSING AUDIO (84.4s)
Created 28 chunks (3.0s each)


Processing chunks: 100%|██████████| 28/28 [00:38<00:00,  1.39s/it]


4. MERGING SEGMENTS
Merged 35 → 10 segments

5. SAVING RESULTS
  Results: Pipeline_data\Conv\ver1\end_to_end_results\transcription_results.json
  Speaker_D: Pipeline_data\Conv\ver1\end_to_end_results\Speaker_D_transcription.json
  Speaker_C: Pipeline_data\Conv\ver1\end_to_end_results\Speaker_C_transcription.json
  Speaker_B: Pipeline_data\Conv\ver1\end_to_end_results\Speaker_B_transcription.json
  Speaker_A: Pipeline_data\Conv\ver1\end_to_end_results\Speaker_A_transcription.json
  Transcript: Pipeline_data\Conv\ver1\end_to_end_results\transcript.txt

PROCESSING COMPLETE!
Output directory: Pipeline_data\Conv\ver1\end_to_end_results
Processing completed!
Found 10 segments
Speaker_D: 2 segments, 41.4s total
Speaker_C: 3 segments, 26.7s total
Speaker_B: 3 segments, 20.8s total
Speaker_A: 2 segments, 14.8s total



