## Base process

In [12]:
import json

from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel
from nemo.collections.asr.models.msdd_models import ClusteringDiarizer

from omegaconf import OmegaConf


# Load yaml file
with open("./config/nemo/diar_infer_telephonic.yaml") as f:
    cfg = OmegaConf.load(f)

meta = {
    "audio_filepath":  "mono_file.wav",
    "offset": 0,
    "duration": None,
    "label": "infer",
    "text": "-",
    "rttm_filepath": None,
    "uem_filepath": None,
}

manifest_path = "infer_manifest.json"
with open("infer_manifest.json", "w") as fp:
    json.dump(meta, fp)
    fp.write("\n")

cfg.diarizer.manifest_filepath = str(manifest_path)
cfg.diarizer.out_dir = "infer_out_dir"

speaker_model = EncDecSpeakerLabelModel.from_pretrained(
    model_name="titanet_large", map_location=None
)
speaker_params = {
    "window_length_in_sec": [1.5, 1.25, 1.0, 0.75, 0.5],
    "shift_length_in_sec": [0.75, 0.625, 0.5, 0.375, 0.25],
    "multiscale_weights": [1, 1, 1, 1, 1],
    "save_embeddings": True,
}
cluster_params = {
    "oracle_num_speakers": False,
    "max_num_speakers": 8,
    "enhanced_count_thres": 80,
    "max_rp_threshold": 0.25,
    "sparse_search_volume": 30,
    "maj_vote_spk_count": False,
}

clus_diar_model = ClusteringDiarizer(cfg=cfg, speaker_model=speaker_model)

[NeMo I 2023-07-25 09:06:34 cloud:58] Found existing object /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo.
[NeMo I 2023-07-25 09:06:34 cloud:64] Re-using file from: /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo
[NeMo I 2023-07-25 09:06:34 common:913] Instantiating model from pre-trained checkpoint


[NeMo W 2023-07-25 09:06:35 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: /manifests/combined_fisher_swbd_voxceleb12_librispeech/train.json
    sample_rate: 16000
    labels: null
    batch_size: 64
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    tarred_shard_strategy: scatter
    augmentor:
      noise:
        manifest_path: /manifests/noise/rir_noise_manifest.json
        prob: 0.5
        min_snr_db: 0
        max_snr_db: 15
      speed:
        prob: 0.5
        sr: 16000
        resample_type: kaiser_fast
        min_speed_rate: 0.95
        max_speed_rate: 1.05
    num_workers: 15
    pin_memory: true
    
[NeMo W 2023-07-25 09:06:35 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method 

[NeMo I 2023-07-25 09:06:35 features:291] PADDING: 16
[NeMo I 2023-07-25 09:06:35 save_restore_connector:249] Model EncDecSpeakerLabelModel was successfully restored from /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo.
[NeMo I 2023-07-25 09:06:35 clustering_diarizer:127] Loading pretrained vad_multilingual_marblenet model from NGC
[NeMo I 2023-07-25 09:06:35 cloud:58] Found existing object /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo.
[NeMo I 2023-07-25 09:06:35 cloud:64] Re-using file from: /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo
[NeMo I 2023-07-25 09:06:35 common:913] Instantiating model from pre-trained checkpoint


[NeMo W 2023-07-25 09:06:35 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: /manifests/ami_train_0.63.json,/manifests/freesound_background_train.json,/manifests/freesound_laughter_train.json,/manifests/fisher_2004_background.json,/manifests/fisher_2004_speech_sampled.json,/manifests/google_train_manifest.json,/manifests/icsi_all_0.63.json,/manifests/musan_freesound_train.json,/manifests/musan_music_train.json,/manifests/musan_soundbible_train.json,/manifests/mandarin_train_sample.json,/manifests/german_train_sample.json,/manifests/spanish_train_sample.json,/manifests/french_train_sample.json,/manifests/russian_train_sample.json
    sample_rate: 16000
    labels:
    - background
    - speech
    batch_size: 256
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    tarred_shard_strategy: sca

[NeMo I 2023-07-25 09:06:35 features:291] PADDING: 16
[NeMo I 2023-07-25 09:06:35 save_restore_connector:249] Model EncDecClassificationModel was successfully restored from /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo.


In [16]:
import librosa
import soundfile as sf

filepath = "./mono_file.mp3"
waveform, sample_rate = librosa.load(filepath, sr=None)
sf.write("./mono_file.wav", waveform, sample_rate, "PCM_16")

In [17]:
clus_diar_model.diarize()

[NeMo W 2023-07-25 09:37:44 clustering_diarizer:411] Deleting previous clustering diarizer outputs.


[NeMo I 2023-07-25 09:37:44 speaker_utils:93] Number of files to diarize: 1
[NeMo I 2023-07-25 09:37:44 clustering_diarizer:309] Split long audio file to avoid CUDA memory issue


splitting manifest: 100%|██████████| 1/1 [00:00<00:00, 11.75it/s]

[NeMo I 2023-07-25 09:37:44 vad_utils:101] The prepared manifest file exists. Overwriting!
[NeMo I 2023-07-25 09:37:44 classification_models:268] Perform streaming frame-level VAD
[NeMo I 2023-07-25 09:37:44 collections:298] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2023-07-25 09:37:44 collections:299] Dataset loaded with 3 items, total duration of  0.04 hours.
[NeMo I 2023-07-25 09:37:44 collections:301] # 3 files loaded accounting to # 1 labels





[NeMo I 2023-07-25 09:37:46 clustering_diarizer:250] Generating predictions with overlapping input segments


                                                               

[NeMo I 2023-07-25 09:37:47 clustering_diarizer:262] Converting frame level prediction to speech/no-speech segment in start and end times format.


creating speech segments: 100%|██████████| 1/1 [00:00<00:00,  3.86it/s]

[NeMo I 2023-07-25 09:37:47 clustering_diarizer:287] Subsegmentation for embedding extraction: scale0, infer_out_dir/speaker_outputs/subsegments_scale0.json
[NeMo I 2023-07-25 09:37:47 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2023-07-25 09:37:47 collections:298] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2023-07-25 09:37:47 collections:299] Dataset loaded with 104 items, total duration of  0.04 hours.
[NeMo I 2023-07-25 09:37:47 collections:301] # 104 files loaded accounting to # 1 labels





[NeMo I 2023-07-25 09:37:48 clustering_diarizer:389] Saved embedding files to infer_out_dir/speaker_outputs/embeddings
[NeMo I 2023-07-25 09:37:48 clustering_diarizer:287] Subsegmentation for embedding extraction: scale1, infer_out_dir/speaker_outputs/subsegments_scale1.json
[NeMo I 2023-07-25 09:37:48 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2023-07-25 09:37:48 collections:298] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2023-07-25 09:37:48 collections:299] Dataset loaded with 132 items, total duration of  0.04 hours.
[NeMo I 2023-07-25 09:37:48 collections:301] # 132 files loaded accounting to # 1 labels
[NeMo I 2023-07-25 09:37:48 clustering_diarizer:389] Saved embedding files to infer_out_dir/speaker_outputs/embeddings
[NeMo I 2023-07-25 09:37:48 clustering_diarizer:287] Subsegmentation for embedding extraction: scale2, infer_out_dir/speaker_outputs/subsegments_scale2.json
[NeMo I 2023-07-25 09:37:48 clustering_diarizer:343] Extrac

## VAD

In [1]:
from typing import List, Optional, Tuple, Union

import torch
import torchaudio
from faster_whisper.vad import VadOptions, get_speech_timestamps


class VadService:
    """VAD Service for audio files."""

    def __init__(self) -> None:
        """Initialize the VAD Service."""
        self.sample_rate = 16000
        self.options = VadOptions(
            threshold=0.5,
            min_speech_duration_ms=250,
            max_speech_duration_s=30,
            min_silence_duration_ms=100,
            window_size_samples=512,
            speech_pad_ms=30,
        )

    def __call__(
        self, waveform: torch.Tensor, group_timestamps: Optional[bool] = True
    ) -> Tuple[Union[List[dict], List[List[dict]]], torch.Tensor]:
        """
        Use the VAD model to get the speech timestamps. Dual channel pipeline.

        Args:
            waveform (torch.Tensor): Audio tensor.
            group_timestamps (Optional[bool], optional): Group timestamps. Defaults to True.

        Returns:
            Tuple[Union[List[dict], List[List[dict]]], torch.Tensor]: Speech timestamps and audio tensor.
        """
        if waveform.size(0) == 1:
            waveform = waveform.squeeze(0)

        speech_timestamps = get_speech_timestamps(
            audio=waveform, vad_options=self.options
        )

        _speech_timestamps_list = [
            {"start": ts["start"], "end": ts["end"]} for ts in speech_timestamps
        ]

        if group_timestamps:
            speech_timestamps_list = self.group_timestamps(_speech_timestamps_list)
        else:
            speech_timestamps_list = _speech_timestamps_list

        return speech_timestamps_list, waveform

    def group_timestamps(
        self, timestamps: List[dict], threshold: Optional[float] = 3.0
    ) -> List[List[dict]]:
        """
        Group timestamps based on a threshold.

        Args:
            timestamps (List[dict]): List of timestamps.
            threshold (float, optional): Threshold to use for grouping. Defaults to 3.0.

        Returns:
            List[List[dict]]: List of grouped timestamps.
        """
        grouped_segments = [[]]

        for i in range(len(timestamps)):
            if (
                i > 0
                and (timestamps[i]["start"] - timestamps[i - 1]["end"]) > threshold
            ):
                grouped_segments.append([])

            grouped_segments[-1].append(timestamps[i])

        return grouped_segments

    def save_audio(self, filepath: str, audio: torch.Tensor) -> None:
        """
        Save audio tensor to file.

        Args:
            filepath (str): Path to save the audio file.
            audio (torch.Tensor): Audio tensor.
        """
        torchaudio.save(
            filepath, audio.unsqueeze(0), self.sample_rate, bits_per_sample=16
        )

def read_audio(filepath: str, sample_rate: int = 16000) -> Tuple[torch.Tensor, float]:
    """
    Read an audio file and return the audio tensor.

    Args:
        filepath (str): Path to the audio file.
        sample_rate (int): The sample rate of the audio file. Defaults to 16000.

    Returns:
        Tuple[torch.Tensor, float]: The audio tensor and the audio duration.
    """
    wav, sr = torchaudio.load(filepath)

    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sr != sample_rate:
        transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        wav = transform(wav)
        sr = sample_rate

    audio_duration = float(wav.shape[1]) / sample_rate

    return wav.squeeze(0), audio_duration

def sr2s(v: int) -> float:
    """
    Convert milliseconds to seconds.

    Args:
        v (int): Value in milliseconds.

    Returns:
        float: Value in seconds.
    """
    return v / 16000

In [51]:
waveform, _ = read_audio("./mono_file.wav")

vad_service = VadService()

speech_ts, _ = vad_service(waveform, True)
for ts in speech_ts:
    _ts = ts[0]
    print(
        f"Start: {sr2s(_ts['start'])}, End: {sr2s(_ts['end'])}"
    )

42
Start: 12.514, End: 12.99
Start: 13.25, End: 14.366
Start: 15.042, End: 16.862
Start: 17.762, End: 18.814
Start: 19.426, End: 20.766
Start: 21.666, End: 24.83
Start: 26.178, End: 29.886
Start: 30.786, End: 33.022
Start: 34.146, End: 37.214
Start: 38.338, End: 40.318
Start: 41.218, End: 42.782
Start: 43.682, End: 44.318
Start: 45.73, End: 46.494
Start: 47.778, End: 50.366
Start: 51.106, End: 52.926
Start: 53.954, End: 55.582
Start: 55.682, End: 56.926
Start: 57.954, End: 60.286
Start: 61.154, End: 64.254
Start: 65.026, End: 67.614
Start: 68.418, End: 68.99
Start: 69.922, End: 71.55
Start: 72.578, End: 75.838
Start: 76.61, End: 77.918
Start: 78.562, End: 79.454
Start: 79.746, End: 81.086
Start: 82.05, End: 83.902
Start: 84.738, End: 86.462
Start: 87.586, End: 90.782
Start: 91.746, End: 96.542
Start: 97.73, End: 98.27
Start: 99.586, End: 100.03
Start: 100.162, End: 100.862
Start: 101.794, End: 103.454
Start: 104.898, End: 107.486
Start: 108.226, End: 109.854
Start: 114.274, End: 114.75

## Segmentation

In [2]:
import math

def get_subsegments(segment_start: float, segment_end: float, window: float, shift: float) -> List[List[float]]:
    """
    Return a list of subsegments based on the segment start and end time and the window and shift length.

    Args:
        segment_start (float): Segment start time.
        segment_end (float): Segment end time.
        window (float): Window length.
        shift (float): Shift length.

    Returns:
        List[List[float]]: List of subsegments with start time and duration.
    """
    start = segment_start
    duration = segment_end - segment_start
    base = math.ceil((duration - window) / shift)
    
    subsegments: List[List[float]] = []
    slices = 1 if base < 0 else base + 1
    for slice_id in range(slices):
        end = start + window

        if end > segment_end:
            end = segment_end

        subsegments.append([start, end - start])

        start = segment_start + (slice_id + 1) * shift

    return subsegments

In [3]:
def _run_segmentation(
    vad_outputs: List[dict],
    window: float,
    shift: float,
    min_subsegment_duration: float = 0.05,
) -> List[dict]:
    """"""
    scale_segment = []
    for segment in vad_outputs:
        segment_start, segment_end = sr2s(segment["start"]), sr2s(segment["end"])
        subsegments = get_subsegments(segment_start, segment_end, window, shift)

        for subsegment in subsegments:
            start, duration = subsegment
            if duration > min_subsegment_duration:
                scale_segment.append({"offset": start, "duration": duration})

    return scale_segment

In [4]:
from nemo.collections.asr.models import EncDecSpeakerLabelModel

from torch.cuda.amp import autocast
from torch.utils.data import Dataset

speaker_model = EncDecSpeakerLabelModel.from_pretrained(
    model_name="titanet_large", map_location=None
)


class AudioSegmentDataset(Dataset):
    def __init__(self, waveform: torch.Tensor, segments: List[dict], sample_rate=16000) -> None:
        self.waveform = waveform
        self.segments = segments
        self.sample_rate = sample_rate

    def __len__(self) -> int:
        return len(self.segments)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        segment_info = self.segments[idx]
        offset_samples = int(segment_info["offset"] * self.sample_rate)
        duration_samples = int(segment_info["duration"] * self.sample_rate)

        segment = self.waveform[offset_samples:offset_samples + duration_samples]

        return segment, torch.tensor(segment.shape[0]).long()


def collate_fn(batch: List[Tuple[torch.Tensor, torch.Tensor]]):
    """"""
    _, audio_lengths = zip(*batch)

    has_audio = audio_lengths[0] is not None
    fixed_length = int(max(audio_lengths))

    audio_signal, new_audio_lengths = [], []
    for sig, sig_len in batch:
        if has_audio:
            sig_len = sig_len.item()
            chunck_len = sig_len - fixed_length

            if chunck_len < 0:
                repeat = fixed_length // sig_len
                rem = fixed_length % sig_len
                sub = sig[-rem:] if rem > 0 else torch.tensor([])
                rep_sig = torch.cat(repeat * [sig])
                sig = torch.cat((rep_sig, sub))
            new_audio_lengths.append(torch.tensor(fixed_length))

            audio_signal.append(sig)

    if has_audio:
        audio_signal = torch.stack(audio_signal)
        audio_lengths = torch.stack(new_audio_lengths)
    else:
        audio_signal, audio_lengths = None, None

    return audio_signal, audio_lengths


def _extract_embeddings(waveform: torch.Tensor, scale_segments: List[dict]):
    """
    This method extracts speaker embeddings from segments passed through manifest_file
    Optionally you may save the intermediate speaker embeddings for debugging or any use. 
    """
    all_embs = torch.empty([0])

    dataset = AudioSegmentDataset(waveform, scale_segments)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=64, shuffle=False, collate_fn=collate_fn
    )

    for batch in dataloader:
        _batch = [x.to(speaker_model.device) for x in batch]
        audio_signal, audio_signal_len = _batch

        with autocast():
            _, embeddings = speaker_model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
            embeddings = embeddings.view(-1, embeddings.shape[-1])
            all_embs = torch.cat((all_embs, embeddings.cpu().detach()), dim=0)
        del _batch, audio_signal, audio_signal_len, embeddings

    embeddings, time_stamps = [], []
    for i, segment in enumerate(scale_segments):
        if i == 0:
            embeddings = all_embs[i].view(1, -1)
        else:
            embeddings = torch.cat((embeddings, all_embs[i].view(1, -1)))

        time_stamps.append([segment['offset'], segment['duration']])

    return embeddings, time_stamps

[NeMo W 2023-07-31 07:04:19 experimental:27] Module <class 'nemo.collections.asr.modules.audio_modules.SpectrogramToMultichannelFeatures'> is experimental, not ready for production and is not fully supported. Use at your own risk.


[NeMo I 2023-07-31 07:04:19 cloud:58] Found existing object /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo.
[NeMo I 2023-07-31 07:04:19 cloud:64] Re-using file from: /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo
[NeMo I 2023-07-31 07:04:19 common:913] Instantiating model from pre-trained checkpoint


[NeMo W 2023-07-31 07:04:22 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: /manifests/combined_fisher_swbd_voxceleb12_librispeech/train.json
    sample_rate: 16000
    labels: null
    batch_size: 64
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    tarred_shard_strategy: scatter
    augmentor:
      noise:
        manifest_path: /manifests/noise/rir_noise_manifest.json
        prob: 0.5
        min_snr_db: 0
        max_snr_db: 15
      speed:
        prob: 0.5
        sr: 16000
        resample_type: kaiser_fast
        min_speed_rate: 0.95
        max_speed_rate: 1.05
    num_workers: 15
    pin_memory: true
    
[NeMo W 2023-07-31 07:04:22 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method 

[NeMo I 2023-07-31 07:04:22 features:291] PADDING: 16
[NeMo I 2023-07-31 07:04:24 save_restore_connector:249] Model EncDecSpeakerLabelModel was successfully restored from /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo.


## Clustering

In [5]:
from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering


def get_contiguous_stamps(stamps: list):
    """
    Return contiguous time stamps
    """
    contiguous_stamps = []
    for i in range(len(stamps) - 1):
        start, end, speaker = stamps[i]
        next_start, next_end, next_speaker = stamps[i + 1]

        if end > next_start:
            avg = (next_start + end) / 2.0
            stamps[i + 1] = (avg, next_end, next_speaker)
            contiguous_stamps.append((start, avg, speaker))
        else:
            contiguous_stamps.append((start, end, speaker))

    start, end, speaker = stamps[-1]
    contiguous_stamps.append((start, end, speaker))

    return contiguous_stamps


def merge_stamps(stamps: list):
    """
    Merge time stamps of the same speaker.
    """
    overlap_stamps = []
    for i in range(len(stamps) - 1):
        start, end, speaker = stamps[i]
        next_start, next_end, next_speaker = stamps[i + 1]

        if end == next_start and speaker == next_speaker:
            stamps[i + 1] = (start, next_end, next_speaker)
        else:
            overlap_stamps.append((start, end, speaker))

    start, end, speaker = stamps[-1]
    overlap_stamps.append((start, end, speaker))

    return overlap_stamps


def perform_clustering(embs_and_timestamps, clustering_params):
    """
    Performs spectral clustering on embeddings with time stamps generated from VAD output.
    """
    speaker_clustering = SpeakerClustering(cuda=True)

    base_scale_idx = embs_and_timestamps["multiscale_segment_counts"].shape[0] - 1
    cluster_labels = speaker_clustering.forward_infer(
        embeddings_in_scales=embs_and_timestamps["embeddings"],
        timestamps_in_scales=embs_and_timestamps["timestamps"],
        multiscale_segment_counts=embs_and_timestamps["multiscale_segment_counts"],
        multiscale_weights=embs_and_timestamps["multiscale_weights"],
        oracle_num_speakers=-1,
        max_num_speakers=int(clustering_params["max_num_speakers"]),
        max_rp_threshold=float(clustering_params["max_rp_threshold"]),
        sparse_search_volume=int(clustering_params["sparse_search_volume"]),
    )

    del embs_and_timestamps
    torch.cuda.empty_cache()

    timestamps = speaker_clustering.timestamps_in_scales[base_scale_idx]
    cluster_labels = cluster_labels.cpu().numpy()
    if len(cluster_labels) != timestamps.shape[0]:
        raise ValueError("Mismatch of length between cluster_labels and timestamps.")

    clustering_labels = []
    for idx, label in enumerate(cluster_labels):
        stt, end = timestamps[idx]
        clustering_labels.append((float(stt), float(stt + end), int(label)))

    return clustering_labels

## Mapping between embeddings and timestamps

In [6]:
from statistics import mode
from typing import List

import numpy as np
import torch


def get_argmin_mat(timestamps_in_scales: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Calculate the mapping between the base scale and other scales. A segment from a longer scale is
    repeatedly mapped to a segment from a shorter scale or the base scale.

    Args:
        timestamps_in_scales (list):
            List containing timestamp tensors for each scale.
            Each tensor has dimensions of (Number of base segments) x 2.

    Returns:
        session_scale_mapping_list (list):
            List containing argmin arrays indexed by scale index.
    """
    scale_list = list(range(len(timestamps_in_scales)))
    segment_anchor_list = [torch.mean(timestamps_in_scales[scale_idx], dim=1) for scale_idx in scale_list]

    base_scale_idx = max(scale_list)
    base_scale_anchor = segment_anchor_list[base_scale_idx]
    base_scale_anchor = base_scale_anchor.view(-1, 1)

    session_scale_mapping_list = []
    for scale_idx in scale_list:
        curr_scale_anchor = segment_anchor_list[scale_idx].view(1, -1)
        distance = torch.abs(curr_scale_anchor - base_scale_anchor)
        argmin_mat = torch.argmin(distance, dim=1)
        session_scale_mapping_list.append(argmin_mat)

    return session_scale_mapping_list


def assign_labels_to_longer_segs(clustering_labels: list, session_scale_mapping_list: list, scale_n: int):
    """
    In multi-scale speaker diarization system, clustering result is solely based on the base-scale (the shortest scale).
    To calculate cluster-average speaker embeddings for each scale that are longer than the base-scale, this function assigns
    clustering results for the base-scale to the longer scales by measuring the distance between subsegment timestamps in the
    base-scale and non-base-scales.

    Args:
        base_clus_label_dict (dict):
            Dictionary containing clustering results for base-scale segments. Indexed by `uniq_id` string.
        session_scale_mapping_dict (dict):
            Dictionary containing multiscale mapping information for each session. Indexed by `uniq_id` string.

    Returns:
        all_scale_clus_label_dict (dict):
            Dictionary containing clustering labels of all scales. Indexed by scale_index in integer format.

    """
    base_scale_clus_label = np.array([x[-1] for x in clustering_labels])
    
    all_scale_clus_label_dict = {}
    all_scale_clus_label_dict[scale_n - 1] = base_scale_clus_label

    for scale_index, scale_mapping_tensor in enumerate(session_scale_mapping_list[:-1]):
        new_clus_label = []
        max_index = max(scale_mapping_tensor)

        for seg_idx in range(max_index + 1):
            if seg_idx in scale_mapping_tensor:
                seg_clus_label = mode(base_scale_clus_label[scale_mapping_tensor == seg_idx])
            else:
                seg_clus_label = 0 if len(new_clus_label) == 0 else new_clus_label[-1]

            new_clus_label.append(seg_clus_label)

        all_scale_clus_label_dict[scale_index] = new_clus_label

    return all_scale_clus_label_dict


# Check https://github.com/NVIDIA/NeMo/blob/2cc09425aba3e9b3cfdba43a3188eaef58227055/nemo/collections/asr/models/msdd_models.py#L756
def get_cluster_avg_embs(
    emb_scale_seq_dict: dict,
    clustering_labels: list,
    session_scale_mapping_list: list,
    scale_n: int,
    max_num_speakers: int,
):
    """
    MSDD requires cluster-average speaker embedding vectors for each scale. This function calculates an average embedding vector for each cluster (speaker)
    and each scale.

    Args:
        emb_scale_seq_dict (dict):
            Dictionary containing embedding sequence for each scale. Keys are scale index in integer.
        clus_labels (list):
            Clustering results from clustering diarizer including all the sessions provided in input manifest files.
        session_scale_mapping_dict (list):
            List containing argmin arrays indexed by scale index.

    Returns:
        emb_sess_avg_dict (dict):
            Dictionary containing speaker mapping information and cluster-average speaker embedding vector.
            Each session-level dictionary is indexed by scale index in integer.
        output_clus_label_dict (dict):
            Subegmentation timestamps in float type and Clustering result in integer type. Indexed by `uniq_id` keys.
    """    
    embeddings_session_average_dict = {}

    all_scale_clus_label_dict = assign_labels_to_longer_segs(
        clustering_labels, session_scale_mapping_list, scale_n
    )
    
    for scale_index, embeddings_tensor in emb_scale_seq_dict.items():
        clustering_labels_list = all_scale_clus_label_dict[scale_index]
        speaker_set = set(clustering_labels_list)

        clustering_labels_tensor = torch.Tensor(clustering_labels_list)
        average_embeddings = torch.zeros(embeddings_tensor[0].shape[0], max_num_speakers)
        for speaker_idx in speaker_set:
            selected_embeddings = embeddings_tensor[clustering_labels_tensor == speaker_idx]
            average_embeddings[:, speaker_idx] = torch.mean(selected_embeddings, dim=0)

        embeddings_session_average_dict[scale_index] = average_embeddings

    return embeddings_session_average_dict


## MSDD Inference

In [49]:
from itertools import combinations
from typing import Dict


class AudioMSDDDataset(Dataset):
    def __init__(
        self,
        emb_sess_avg_dict: Dict[str, torch.Tensor],
        emb_scale_seq_dict: Dict[str, torch.Tensor],
        clustering_labels: Dict[str, torch.Tensor],
        sess_scale_mapping_list: List[torch.Tensor],
        scale_n: int,
    ) -> None:
        self.emb_dict = emb_sess_avg_dict
        self.emb_seq = emb_scale_seq_dict
        self.clus_label_list = clustering_labels
        self.sess_scale_mapping = sess_scale_mapping_list
        self.scale_n = scale_n

        self.clus_speaker_digits = sorted(list(set([x[-1] for x in self.clus_label_list])))
        if len(self.clus_speaker_digits) <= 2:
            self.speaker_combinations = [(0, 1)]
        else:
            self.speaker_combinations = [x for x in combinations(self.clus_speaker_digits, 2)]

    def __len__(self) -> int:
        return 1

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        _avg_embs = torch.stack(
            [
                self.emb_dict[scale_index] 
                for scale_index in range(self.scale_n)
            ]
        )  # (scale_n, num_segments, max_num_speakers)

        selected_speakers = torch.tensor(self.speaker_combinations).flatten()
        avg_embs = _avg_embs[:, :, selected_speakers]
        

        if avg_embs.shape[2] > 2:
            raise ValueError(
                f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to self.max_num_speakers {2}"
            )

        feats = []
        for scale_index in range(self.scale_n):
            repeat_mat = self.sess_scale_mapping[scale_index]
            feats.append(self.emb_seq[scale_index][repeat_mat, :])

        features = torch.stack(feats).permute(1, 0, 2)
        features_length = features.shape[0]

        targets = torch.zeros(features_length, 2)

        return features, features_length, targets, avg_embs


def msdd_infer_collate_fn(batch):
    """
    Collate batch of feats (speaker embeddings), feature lengths, target label sequences and cluster-average embeddings.

    Args:
        batch (tuple):
            Batch tuple containing feats, feats_len, targets and ms_avg_embs.
    Returns:
        feats (torch.tensor):
            Collated speaker embedding with unified length.
        feats_len (torch.tensor):
            The actual length of each embedding sequence without zero padding.
        targets (torch.tensor):
            Groundtruth Speaker label for the given input embedding sequence.
        ms_avg_embs (torch.tensor):
            Cluster-average speaker embedding vectors.
    """

    packed_batch = list(zip(*batch))
    _, feats_len, targets, _ = packed_batch
    max_audio_len = max(feats_len)
    max_target_len = max([x.shape[0] for x in targets])

    feats_list, flen_list, targets_list, ms_avg_embs_list = [], [], [], []
    for feature, feat_len, target, ivector in batch:
        flen_list.append(feat_len)
        ms_avg_embs_list.append(ivector)

        if feat_len < max_audio_len:
            feats_list.append(
                torch.nn.functional.pad(feature, (0, 0, 0, 0, 0, max_audio_len - feat_len))
            )
            targets_list.append(
                torch.nn.functional.pad(target, (0, 0, 0, max_target_len - target.shape[0]))
            )
        else:
            targets_list.append(target.clone().detach())
            feats_list.append(feature.clone().detach())

    return (
        torch.stack(feats_list),  # Features
        torch.tensor(flen_list),  # Features length
        torch.stack(targets_list),  # Targets
        torch.stack(ms_avg_embs_list),  # Cluster-average embeddings
    )

In [47]:
from nemo.collections.asr.models import EncDecDiarLabelModel
from omegaconf import OmegaConf


msdd_cfg = OmegaConf.create({
    "model_path": "diar_msdd_telephonic",
    "parameters": {
        "use_speaker_model_from_ckpt": True,
        "infer_batch_size": 25,
        "sigmoid_threshold": [0.7],
        "seq_eval_mode": False,
        "split_infer": True,
        "diar_window_length": 50,
        "overlap_infer_spk_limit": 5,
    }
})
# msdd_model = EncDecDiarLabelModel.from_config_dict(msdd_cfg)
msdd_model = EncDecDiarLabelModel.from_pretrained(model_name=msdd_cfg.model_path)
msdd_model.eval()

[NeMo I 2023-07-31 08:26:51 cloud:58] Found existing object /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo.
[NeMo I 2023-07-31 08:26:51 cloud:64] Re-using file from: /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo
[NeMo I 2023-07-31 08:26:51 common:913] Instantiating model from pre-trained checkpoint


[NeMo W 2023-07-31 08:26:52 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: null
    emb_dir: null
    sample_rate: 16000
    num_spks: 2
    soft_label_thres: 0.5
    labels: null
    batch_size: 15
    emb_batch_size: 0
    shuffle: true
    
[NeMo W 2023-07-31 08:26:52 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath: null
    emb_dir: null
    sample_rate: 16000
    num_spks: 2
    soft_label_thres: 0.5
    labels: null
    batch_size: 15
    emb_batch_size: 0
    shuffle: false
    
[NeMo W 2023-07-31 08:26:52 modelPT:174] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple

[NeMo I 2023-07-31 08:26:52 features:291] PADDING: 16
[NeMo I 2023-07-31 08:26:52 features:291] PADDING: 16
[NeMo I 2023-07-31 08:26:53 save_restore_connector:249] Model EncDecDiarLabelModel was successfully restored from /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo.


EncDecDiarLabelModel(
  (preprocessor): AudioToMelSpectrogramPreprocessor(
    (featurizer): FilterbankFeatures()
  )
  (msdd): MSDD_module(
    (softmax): Softmax(dim=2)
    (cos_dist): CosineSimilarity()
    (lstm): LSTM(256, 256, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
    (conv): ModuleList(
      (0): ConvLayer(
        (cnn): Sequential(
          (0): Conv2d(1, 16, kernel_size=(15, 1), stride=(1, 1))
          (1): ReLU()
          (2): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
        )
      )
      (1): ConvLayer(
        (cnn): Sequential(
          (0): Conv2d(1, 16, kernel_size=(16, 1), stride=(1, 1))
          (1): ReLU()
          (2): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
        )
      )
    )
    (conv_bn): ModuleList(
      (0-1): 2 x BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    )
    (conv_to_linear): Linear(in_features

## Map MSDD + Clustering

In [123]:
def get_overlap_stamps(contiguous_stamps: List[str], overlap_speaker_index: List[str]):
    """
    Generate timestamps that include overlap speech. Overlap-including timestamps are created based on the segments that are
    created for clustering diarizer. Overlap speech is assigned to the existing speech segments in `cont_stamps`.

    Args:
        cont_stamps (list):
            Non-overlapping (single speaker per segment) diarization output in string format.
            Each line contains the start and end time of segments and corresponding speaker labels.
        ovl_spk_idx (list):
            List containing segment index of the estimated overlapped speech. The start and end of segments are based on the
            single-speaker (i.e., non-overlap-aware) RTTM generation.
    Returns:
        total_ovl_cont_list (list):
            Rendered diarization output in string format. Each line contains the start and end time of segments and
            corresponding speaker labels. This format is identical to `cont_stamps`.
    """
    overlap_speaker_contiguous_list = [[] for _ in range(len(overlap_speaker_index))]
    
    for speaker_index in range(len(overlap_speaker_index)):
        for index, segment in enumerate(contiguous_stamps):
            start, end, _ = segment
            if index in overlap_speaker_index[speaker_index]:
                overlap_speaker_contiguous_list[speaker_index].append((start, end, speaker_index))

    total_overlap_contiguous_list = []

    for overlap_contiguous_list in overlap_speaker_contiguous_list:
        if len(overlap_contiguous_list) > 0:
            total_overlap_contiguous_list.extend(merge_stamps(overlap_contiguous_list))

    return total_overlap_contiguous_list


def generate_speaker_timestamps(
    clustering_labels: List[Union[float, int]],
    msdd_preds: torch.Tensor,
    threshold: float = 0.7,
    overlap_infer_speaker_limit: int = 5,
    max_overlap_speakers: int = 2,
) -> Tuple[List[str], List[str]]:
    '''
    Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use clustering result for main speaker
    labels and use timestamps from the predicted sigmoid values. In this function, the main speaker labels in `maj_labels` exist for
    every subsegment steps while overlap speaker labels in `ovl_labels` only exist for segments where overlap-speech is occuring.

    Args:
        clus_labels (list):
            List containing integer-valued speaker clustering results.
        msdd_preds (list):
            List containing tensors of the predicted sigmoid values.
            Each tensor has shape of: (Session length, estimated number of speakers).
        params:
            Parameters for generating RTTM output and evaluation. Parameters include:
                infer_overlap (bool): If False, overlap-speech will not be detected.
                use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. If False, only MSDD output
                                         is used for constructing output RTTM files.
                overlap_infer_spk_limit (int): Above this limit, overlap-speech detection is bypassed.
                use_adaptive_thres (bool): Boolean that determines whehther to use adaptive_threshold depending on the estimated
                                           number of speakers.
                max_overlap_spks (int): Maximum number of overlap speakers detected. Default is 2.
                threshold (float): Sigmoid threshold for MSDD output.

    Returns:
        maj_labels (list):
            List containing string-formated single-speaker speech segment timestamps and corresponding speaker labels.
            Example: [..., '551.685 552.77 speaker_1', '552.99 554.43 speaker_0', '554.97 558.19 speaker_0', ...]
        ovl_labels (list):
            List containing string-formated additional overlapping speech segment timestamps and corresponding speaker labels.
            Note that `ovl_labels` includes only overlapping speech that is not included in `maj_labels`.
            Example: [..., '152.495 152.745 speaker_1', '372.71 373.085 speaker_0', '554.97 555.885 speaker_1', ...]
    '''
    estimated_num_of_spks = msdd_preds.shape[-1]
    overlap_speaker_list = [[] for _ in range(estimated_num_of_spks)]
    infer_overlap = estimated_num_of_spks < int(overlap_infer_speaker_limit)

    main_speaker_lines = []
    _threshold = threshold - (estimated_num_of_spks - 2) * (threshold - 1) / (
        overlap_infer_speaker_limit - 2
    )

    for segment_index, cluster_label in enumerate(clustering_labels):
        speaker_for_segment = (msdd_preds[0, segment_index] > _threshold).int().tolist()
        softmax_predictions = msdd_preds[0, segment_index]

        main_speaker_index = torch.argmax(msdd_preds[0, segment_index]).item()

        if sum(speaker_for_segment) > 1 and infer_overlap:
            index_array = torch.argsort(softmax_predictions, descending=True)

            for overlap_speaker_index in index_array[: max_overlap_speakers].tolist():
                if overlap_speaker_index != int(main_speaker_index):
                    overlap_speaker_list[overlap_speaker_index].append(segment_index)

        main_speaker_lines.append((cluster_label[0], cluster_label[1], main_speaker_index))

    contiguous_stamps = get_contiguous_stamps(main_speaker_lines)
    main_labels = merge_stamps(contiguous_stamps)

    overlap_labels = get_overlap_stamps(contiguous_stamps, overlap_speaker_list)

    return main_labels, overlap_labels


def make_rttm_with_overlap(
    clustering_labels: List[Union[float, int]],
    msdd_preds: torch.Tensor,
):
    """
    """
    main_labels, overlap_labels = generate_speaker_timestamps(clustering_labels, msdd_preds)

    # _hypothesis_labels = main_labels + overlap_labels
    _hypothesis_labels = main_labels
    hypothesis_labels = sorted(_hypothesis_labels, key=lambda x: x[0])

    return hypothesis_labels

## Real Diarization process

In [127]:
max_num_speakers = 8
window_lengths, shift_lengths, multiscale_weights = (
    [1.5, 1.25, 1.0, 0.75, 0.5],
    [0.75, 0.625, 0.5, 0.375, 0.25],
    [1, 1, 1, 1, 1],
)
scale_dict = {k: (w, s) for k, (w, s) in enumerate(zip(window_lengths, shift_lengths))}

# VAD
waveform, _ = read_audio("./mono_file.wav")
vad_service = VadService()

vad_outputs, _ = vad_service(waveform, False)

# Segmentation
all_embeddings, all_timestamps, all_segment_indexes = [], [], []

scales = scale_dict.items()
for _, (window, shift) in scales:
    scale_segments = _run_segmentation(vad_outputs, window, shift)

    _embeddings, _timestamps = _extract_embeddings(waveform, scale_segments)

    if len(_embeddings) != len(_timestamps):
        raise ValueError("Mismatch of counts between embedding vectors and timestamps")

    all_embeddings.append(_embeddings)
    all_segment_indexes.append(_embeddings.shape[0])
    all_timestamps.append(torch.tensor(_timestamps))

multiscale_embeddings_and_timestamps = {
    "embeddings": torch.cat(all_embeddings, dim=0),
    "timestamps": torch.cat(all_timestamps, dim=0),
    "multiscale_segment_counts": torch.tensor(all_segment_indexes),
    "multiscale_weights": torch.tensor([1, 1, 1, 1, 1]).unsqueeze(0).float(),
}

# Clustering
clustering_params = dict(
    oracle_num_speakers=False,
    max_num_speakers=max_num_speakers,
    enhanced_count_thres=80,
    max_rp_threshold=0.25,
    sparse_search_volume=30,
    maj_vote_spk_count=False,
)
clustering_labels = perform_clustering(
    embs_and_timestamps=multiscale_embeddings_and_timestamps,
    clustering_params=clustering_params,
)

# Mapping between embeddings and timestamps on different scales
split_index = multiscale_embeddings_and_timestamps["multiscale_segment_counts"].tolist()
embeddings_in_scales = list(torch.split(
    multiscale_embeddings_and_timestamps["embeddings"], split_index, dim=0
))
timestamps_in_scales = list(torch.split(
    multiscale_embeddings_and_timestamps["timestamps"], split_index, dim=0
))
session_scale_mapping_list = get_argmin_mat(timestamps_in_scales)

scale_mapping_argmat, emb_scale_seq_dict = {}, {}
for scale_idx in range(len(session_scale_mapping_list)):
    mapping_argmat = session_scale_mapping_list[scale_idx]
    scale_mapping_argmat[scale_idx] = mapping_argmat

    emb_scale_seq_dict[scale_idx] = embeddings_in_scales[scale_idx]

emb_sess_avg_dict = get_cluster_avg_embs(
    emb_scale_seq_dict, clustering_labels, session_scale_mapping_list, len(scale_dict), max_num_speakers
)

# MSDD algorithm
preds_list, targets_list, signal_lengths_list = [], [], []
dataset = AudioMSDDDataset(
    emb_sess_avg_dict=emb_sess_avg_dict,
    emb_scale_seq_dict=emb_scale_seq_dict,
    sess_scale_mapping_list=session_scale_mapping_list,
    clustering_labels=clustering_labels,
    scale_n=len(scale_dict),
)

dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=1,
    collate_fn=msdd_infer_collate_fn,
    drop_last=False,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
)

for batch in dataloader:
    signals, signal_lengths, _, emb_vectors = batch

    # Convert data to float16
    signals = signals.half().to(msdd_model.device)
    signal_lengths = signal_lengths.half().to(msdd_model.device)
    emb_vectors = emb_vectors.half().to(msdd_model.device)

    with autocast():
        _preds, scale_weights = msdd_model.forward_infer(
            input_signal=signals,
            input_signal_length=signal_lengths,
            emb_vectors=emb_vectors,
            targets=None,
        )
        _preds = _preds.cpu().detach()
        scale_weights = scale_weights.cpu().detach()

    max_pred_length = max(_preds.shape[1], 0)
    preds = torch.zeros(_preds.shape[0], max_pred_length, _preds.shape[2])
    targets = torch.zeros(_preds.shape[0], max_pred_length, _preds.shape[2])

    preds[:, : _preds.shape[1], :] = _preds

all_hypothesis = make_rttm_with_overlap(clustering_labels, preds)

contiguous_cluster = get_contiguous_stamps(clustering_labels)
last_cluster = merge_stamps(contiguous_cluster)

# print(len(all_hypothesis), len(last_cluster))
# for h, c in zip(all_hypothesis, last_cluster):
#     start, end, speaker = h
#     print(f"{h} | {c}")

0 1.5 0.75
1 1.25 0.625
2 1.0 0.5
3 0.75 0.375
4 0.5 0.25
