# Prepare an r-vector (speaker embedding) for each speaker in each session with pyannote, for training

In [None]:
import logging
from pathlib import Path
import numpy as np
import soundfile as sf
from pyannote.audio import Inference
from pyannote.audio import Audio
from pyannote.core import Timeline
from pyannote.database.util import load_rttm
from tqdm import tqdm
import random
import torch
from pyannote.core import Segment
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding

# Data directory of one split; it must contain wav.scp and init_rttms/.
# Example: /{BASE_PATH}/AED-TSVAD/recipes/diar_wavlm/data/Compound/dev
YOUR_PATH = 'path/to/your/metadata_dir'
# Sub-directory of init_rttms/ to read; oracle RTTMs are used for initialization in training.
init_type = 'oracle'

def setup_file_logger(log_path: str | Path, level=logging.INFO):
    log_path = Path(log_path)
    log_path.parent.mkdir(parents=True, exist_ok=True)

    root = logging.getLogger()
    root.setLevel(level)

    for h in list(root.handlers):
        root.removeHandler(h)

    fh = logging.FileHandler(log_path, mode="w", encoding="utf-8")
    fmt = "%(asctime)s %(levelname)s %(message)s"
    fh.setFormatter(logging.Formatter(fmt))
    root.addHandler(fh)

# Log this notebook's progress and warnings to a file inside the data directory.
setup_file_logger(log_path=f"{YOUR_PATH}/prepare_single_speaker.log")

## Concatenate each speaker's non-overlapped (single-speaker) segments in each session

In [None]:
class PrepareSingleSpeaker:
    """Write one wav per (recording, speaker) containing only that speaker's clean speech.

    For every recording in ``wav.scp`` the matching ``{rttm_dir}/{rec_id}.rttm``
    is loaded. For each speaker, the regions where that speaker talks and no
    other speaker overlaps are cut from the audio. Regions no longer than
    ``min_segment_dur`` are dropped. The remaining regions are optionally
    shuffled and joined with random silences of ``min_silence_dur`` to
    ``max_silence_dur`` seconds. The result is saved as
    ``{out_dir}/{rec_id}-{spk}.wav``.
    """

    def __init__(
        self,
        wav_scp_path,
        rttm_dir,
        out_dir,
        min_segment_dur: float = 0.3,
        min_silence_dur: float = 0.02,
        max_silence_dur: float = 0.08,
        shuffle_segments: bool = True,
    ):
        # Audio is resampled to 16 kHz and downmixed to mono on load.
        self.audio = Audio(sample_rate=16000, mono="downmix")
        self.file_list = self.make_file_list(wav_scp_path, rttm_dir)
        self.out_dir = Path(out_dir)

        self.min_segment_dur = min_segment_dur
        self.min_silence_dur = min_silence_dur
        self.max_silence_dur = max_silence_dur
        self.shuffle_segments = shuffle_segments

    def make_file_list(self, wav_scp_path, rttm_dir):
        """Pair every entry of ``wav.scp`` with its RTTM path.

        Each non-blank line is ``<rec_id> <wav_path>``; the wav path may contain
        spaces. Blank lines (e.g. a trailing empty line) are skipped.

        Returns:
            List of dicts with keys ``"wav_path"`` and ``"rttm_path"`` (both str).
        """
        wav_scp_path = Path(wav_scp_path)
        rttm_dir = Path(rttm_dir)
        file_list = []

        with wav_scp_path.open("r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                rec_id, wav_path = line.split(maxsplit=1)
                file_list.append({
                    "wav_path": wav_path,
                    "rttm_path": str(rttm_dir / f"{rec_id}.rttm"),
                })

        return file_list

    def extract_clean_single_speaker_timeline(self, annotation, spk):
        """Return the part of ``spk``'s timeline that no other speaker overlaps."""
        spk_tl = annotation.label_timeline(spk)

        # Union of every other speaker's speech.
        others_tl = Timeline()
        for other_spk in annotation.labels():
            if other_spk != spk:
                others_tl = others_tl | annotation.label_timeline(other_spk)

        overlap_with_others = spk_tl.crop(others_tl, mode="intersection")
        clean_tl = spk_tl.extrude(overlap_with_others)

        return clean_tl

    def _concat_segments(self, wav_path, segments, audio_duration):
        """Crop ``segments`` from ``wav_path`` and join them with short random silences.

        Segments are clipped to ``audio_duration``. Segments that start at or
        after the end of the audio, or become shorter than ``min_segment_dur``
        after clipping, are skipped. Silence is inserted only between kept
        chunks, never after the last one.

        Returns:
            The concatenated waveform, or ``None`` if no segment survived.
        """
        chunks = []
        for segment in segments:
            if segment.start >= audio_duration:
                logging.warning(f"Segment [{segment.start:.2f}s, {segment.end:.2f}s] starts after audio file ends ({audio_duration:.2f}s) for {wav_path}. Skipping...")
                continue

            if segment.end > audio_duration:
                original_end = segment.end
                segment = Segment(start=segment.start, end=audio_duration)
                logging.info(f"Segment end time clipped from {original_end:.2f}s to {audio_duration:.2f}s for {wav_path}")

            if segment.end - segment.start < self.min_segment_dur:
                logging.warning(f"Segment [{segment.start:.2f}s, {segment.end:.2f}s] too short after clipping for {wav_path}. Skipping...")
                continue

            chunk, _ = self.audio.crop(wav_path, segment, mode="raise")
            chunks.append(chunk.squeeze(0))

        if not chunks:
            return None

        pieces = [chunks[0]]
        for chunk in chunks[1:]:
            silence_duration = random.uniform(self.min_silence_dur, self.max_silence_dur)
            pieces.append(np.zeros(int(silence_duration * self.audio.sample_rate), dtype=np.float32))
            pieces.append(chunk)
        return np.concatenate(pieces)

    def process_single_file(self, wav_path, rttm_path, out_dir):
        """Write one clean single-speaker wav per speaker annotated in ``rttm_path``.

        Args:
            wav_path: Recording path (``Path``). Its stem is the recording id
                looked up in the RTTM and used in output file names.
            rttm_path: RTTM file holding this recording's annotation.
            out_dir: Output directory (``Path``), created if missing.

        Raises:
            ValueError: if the RTTM has no annotation for ``wav_path.stem``.
        """
        rec_id = wav_path.stem
        out_dir.mkdir(parents=True, exist_ok=True)

        ann_by_file = load_rttm(rttm_path)
        if rec_id not in ann_by_file:
            raise ValueError(f"Recording {rec_id!r} not found in {rttm_path}")
        ann = ann_by_file[rec_id]

        # The file's duration does not depend on speaker or segment: query it once.
        audio_duration = self.audio.get_duration(wav_path)

        for spk in ann.labels():
            clean_tl = self.extract_clean_single_speaker_timeline(ann, spk)

            filtered_segments = [segment for segment in clean_tl if segment.duration > self.min_segment_dur]
            if not filtered_segments:
                logging.warning(f"No segments longer than {self.min_segment_dur}s for {rec_id} {spk}")
                continue

            if self.shuffle_segments:
                random.shuffle(filtered_segments)

            spk_wav = self._concat_segments(wav_path, filtered_segments, audio_duration)
            if spk_wav is None:
                # Every segment was past the end of the audio or too short after clipping.
                logging.warning(f"No usable segments left for {rec_id} {spk} after clipping to the audio duration")
                continue

            out_file = out_dir / f"{rec_id}-{spk}.wav"
            sf.write(out_file, spk_wav, self.audio.sample_rate)
            logging.info(f"Saved {out_file} ({len(spk_wav)/self.audio.sample_rate:.1f}s)")

    def process(self):
        """Run ``process_single_file`` on every recording whose wav and RTTM both exist."""
        for item in tqdm(self.file_list, ncols=100):
            wav_path = Path(item["wav_path"])
            rttm_path = Path(item["rttm_path"])
            if not wav_path.exists() or not rttm_path.exists():
                logging.warning(f"Missing file: {wav_path} or {rttm_path}")
                continue
            self.process_single_file(wav_path, rttm_path, self.out_dir)

In [None]:
# Inputs: the recording list and the directory of per-recording RTTMs ({rec_id}.rttm).
wav_scp_path = f"{YOUR_PATH}/wav.scp"
rttm_path = f"{YOUR_PATH}/init_rttms/{init_type}"
# Output: one concatenated clean-speech wav per (recording, speaker).
out_dir = f"{YOUR_PATH}/wavs_single_spk/{init_type}"

preparer = PrepareSingleSpeaker(
    wav_scp_path=wav_scp_path,
    rttm_dir=rttm_path,
    out_dir=out_dir,
    min_segment_dur=0.3,
    shuffle_segments=True,
)
preparer.process()

## Extract embeddings using Pyannote

In [None]:
class ExtractSingleSpeakerEmbedding:
    """Compute speaker embeddings for every ``*.wav`` file in ``wav_dir``.

    Files at least ``duration`` seconds long are cut into ``duration``-second
    windows every ``step`` seconds and embedded per window with pyannote
    ``Inference``. Shorter files are embedded in one pass with
    ``PretrainedSpeakerEmbedding``. The result for ``<name>.wav`` is stored
    with ``torch.save`` in ``out_dir/<name>.pt``.

    NOTE(review): the two paths may save different object types (the
    ``Inference`` output vs. the raw ``PretrainedSpeakerEmbedding`` output);
    confirm the downstream loader handles both.
    """

    def __init__(
        self,
        wav_dir,
        out_dir,
        model_path,
        duration=5.0,
        step=2.5,
        batch_size=32,
        device=torch.device("cpu"),
    ):
        self.wav_dir = Path(wav_dir)
        self.out_dir = Path(out_dir)
        self.duration = duration
        self.step = step
        self.batch_size = batch_size

        # Sliding-window embedding; skip_aggregation keeps one output per window.
        self.infer = Inference(
            model=model_path,
            window="sliding",
            duration=duration,
            step=step,
            skip_aggregation=True,
            batch_size=batch_size,
            device=device,
        )
        # Fallback for files shorter than one window, which cannot be chunked.
        self.model = PretrainedSpeakerEmbedding(embedding=model_path, device=device)

        self.file_list = self._make_file_list()

    def _make_file_list(self):
        """List ``*.wav`` files in ``wav_dir`` (sorted) with their stems as base names."""
        file_list = []
        for wav_file in sorted(self.wav_dir.glob("*.wav")):
            basename = wav_file.stem
            file_list.append({
                "wav_path": wav_file,
                "base_name": basename
            })
        return file_list

    def crop_to_integer_chunks(self, waveform, sample_rate, duration, step):
        """Make a ``(channels, samples)`` waveform tile exactly into sliding windows.

        The signal is truncated after the last full ``duration``-second window
        (windows start every ``step`` seconds). If the truncated remainder is
        longer than half a step, the final ``duration`` seconds of the original
        signal are appended, so the tail is still covered by a full window.

        Raises:
            ValueError: if the waveform is shorter than one window.
        """
        total_len = waveform.shape[1]
        chunk_size = int(duration * sample_rate)
        step_size = int(step * sample_rate)

        if total_len < chunk_size:
            raise ValueError(
                f"Waveform has {total_len} samples, fewer than one {duration}s window ({chunk_size} samples)"
            )

        num_chunks = 1 + (total_len - chunk_size) // step_size
        final_len = (num_chunks - 1) * step_size + chunk_size

        residual = total_len - final_len
        if residual > step_size // 2:
            return np.concatenate([
                waveform[:, :final_len],
                waveform[:, -chunk_size:]
            ], axis=1)
        else:
            return waveform[:, :final_len]

    def extract_all(self):
        """Embed every file found in ``wav_dir``."""
        for item in tqdm(self.file_list, ncols=100, desc="Extracting embeddings"):
            self.extract_one(item["wav_path"], item["base_name"])

    def extract_one(self, wav_path, base_name):
        """Embed one wav file and save the result to ``out_dir/<base_name>.pt``.

        Multi-channel files are averaged down to a single channel first.
        NOTE(review): ``sr`` is used as read and nothing here resamples on the
        short-file path; confirm inputs are at the model's expected rate.
        """
        out_file = self.out_dir / (base_name + ".pt")
        out_file.parent.mkdir(parents=True, exist_ok=True)

        # always_2d yields (num_samples, num_channels) for mono and multi-channel
        # files alike; everything below expects (channels, num_samples).
        waveform, sr = sf.read(wav_path, dtype='float32', always_2d=True)
        waveform = np.ascontiguousarray(waveform.T)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(axis=0, keepdims=True)

        if waveform.shape[-1] < self.duration * sr:
            logging.info(f"Waveform length {waveform.shape[-1]} is shorter than chunk size")
            embeddings = self.model(torch.from_numpy(waveform).unsqueeze(dim=0))  # [batch, channels, num_samples]
        else:
            waveform = self.crop_to_integer_chunks(waveform, sr, self.duration, self.step)
            embeddings = self.infer({"waveform": torch.from_numpy(waveform), "sample_rate": sr})
        torch.save(embeddings, out_file)


In [None]:
# Embed every wav under wavs_single_spk/{init_type}, writing .pt files to embs_single_spk/{init_type}.
# Set model_path to your local copy of the pyannote WeSpeaker ResNet34 checkpoint.
embedding_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
extractor = ExtractSingleSpeakerEmbedding(
    wav_dir=f"{YOUR_PATH}/wavs_single_spk/{init_type}",
    out_dir=f"{YOUR_PATH}/embs_single_spk/{init_type}",
    model_path="path/to/your/downloads/pyannote/wespeaker-voxceleb-resnet34-LM/pytorch_model.bin",
    duration=5.0,
    step=2.5,
    batch_size=32,
    device=embedding_device,
)
extractor.extract_all()