# Speaker Clustering: Torch Scripted Module 

Set `NEMO_BRANCH_PATH` to the root directory of your NeMo checkout.

In [1]:
import os
import sys

# Root directory of the NeMo checkout to import from.
# Set the NEMO_BRANCH_PATH environment variable (e.g. '/your/path/to/diar_torch/NeMo/')
# to use a different checkout without editing this cell; the path below is only the default.
NEMO_BRANCH_PATH = os.environ.get(
    'NEMO_BRANCH_PATH', '/home/taejinp/projects/_streaming_mulspk_asr/NeMo/'
)
# Insert the checkout at the front of sys.path so that `import nemo` searches it first.
sys.path.insert(0, NEMO_BRANCH_PATH)
import nemo
# Show where the imported package actually lives, to verify the path above took effect.
print("Check NeMo PATH:", nemo.__path__)


Check NeMo PATH: ['/home/taejinp/projects/_streaming_mulspk_asr/NeMo/nemo']


In [2]:
import torch
import time

In [3]:
from nemo.collections.asr.parts.utils.online_clustering import OnlineSpeakerClustering
from nemo.collections.asr.parts.utils.speaker_utils import OnlineSegmentor
import nemo
# Print the package location to confirm which NeMo installation the classes above came from.
print("Check NeMo PATH:", nemo.__path__)

[NeMo W 2023-02-10 16:57:53 optimizers:55] Apex was not found. Using the lamb or fused_adam optimizer will error out.
[NeMo W 2023-02-10 16:57:53 experimental:27] Module <class 'nemo.collections.asr.models.audio_to_audio_model.AudioToAudioModel'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2023-02-10 16:57:54 experimental:27] Module <class 'nemo.collections.asr.modules.audio_modules.SpectrogramToMultichannelFeatures'> is experimental, not ready for production and is not fully supported. Use at your own risk.
    
[NeMo W 2023-02-10 16:57:54 experimental:27] Module <class 'nemo.collections.asr.data.audio_to_audio.BaseAudioDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2023-02-10 16:57:54 experimental:27] Module <class 'nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset'> is experimental, not ready for production and is not fully supported. Use at your own ris

Check NeMo PATH: ['/home/taejinp/projects/_streaming_mulspk_asr/NeMo/nemo']


In [4]:
# Build the regular (eager) segmentor with a sample rate of 16000.
online_segmentor = OnlineSegmentor(sample_rate=16000)

# Export and save torch jit script module.
# The module is compiled with torch.jit.script, written to 'online_segmentor.pt'
# (a path relative to the current working directory) and loaded back, so that
# `online_segmentor` ends up bound to the module restored from disk.
online_segmentor = torch.jit.script(online_segmentor)
torch.jit.save(online_segmentor, 'online_segmentor.pt')
online_segmentor = torch.jit.load('online_segmentor.pt')

Set up a toy audio signal to check that the online segmentor works without issues.

In [34]:
# Toy-signal and buffer configuration. Durations are in seconds.
n_secs = 7              # number of streaming steps simulated below
sample_rate = 16000     # samples per second of the toy signal
chunk_len = 1.0         # length of the chunk (frame) added per step
buffer_pad_half = 2.0   # offset of the frame start from the buffer start
buffer_len_sec = 5.0    # total buffer length

# Use the GPU when one is available and fall back to the CPU otherwise,
# so that this cell also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Uniform random noise stands in for real audio; it holds 2 * n_secs seconds of samples.
signal_source = torch.rand(sample_rate*(2*n_secs),).to(device)

# NOTE(review): this rebinds `online_segmentor` to a fresh eager instance, replacing the
# TorchScript module loaded earlier in the notebook -- confirm that this is intended.
online_segmentor = OnlineSegmentor(sample_rate=sample_rate)


# Segment memory carried from one streaming step to the next; it starts out empty.
segment_raw_audio = []
segment_range_ts = []
segment_indexes = []

# Segment window length and shift (hop) handed to the segmentor.
window = 0.5
shift = 0.25

# """
# Frame is in the middle of the buffer.
# |___Buffer___[___________]____________|
# |____________[   Frame   ]____________|
# | <- buffer start
# |____________| <- frame start
# """


We assume that we have a 5-second buffer.
In the middle of that buffer, we have a 1-second frame (also called a chunk in NeMo).
In the very first segmentation step:
self.buffer_start = 0
self.buffer_end = 5
self.frame_start = 2

These values are increased by 1 second per step.
The 5-second buffer behaves like a queue:
a new 1-second chunk comes in from the right, and the leftmost 1 second is deleted.

Now let's run a loop that simulates incoming `audio_buffer` from the source signal `signal_source`.

In [38]:
# simulation of segmentation mechanism
# Simulation of the streaming segmentation mechanism.
# The rolling buffer starts as all zeros, shaped like the first `buffer_len_sec`
# seconds of the source signal (same dtype and device).
audio_buffer = torch.zeros_like(signal_source[: int(sample_rate * buffer_len_sec)])

# Number of samples that enter the buffer at every step.
# This is in "number of samples, integer".
update_len = int(sample_rate * chunk_len)

for k in range(n_secs):
    print(f"======= index {k} =====")

    # The VAD result is faked: everything from k to (k+5) is treated as speech.
    # This is in "seconds".
    vad_timestamps = torch.tensor([[k, k + buffer_len_sec]])

    # Take the next 1 second (16000 samples) from the source audio, shift the
    # buffer content to the left by that amount and write the new chunk at the end.
    incoming_chunk = signal_source[sample_rate * k : int(sample_rate * (k + chunk_len))]
    audio_buffer[:-update_len] = audio_buffer[update_len:].clone()
    audio_buffer[-update_len:] = incoming_chunk

    print("audio buffer shape", audio_buffer.shape, len(audio_buffer))

    # [ Important! ] The segmentor must be told the frame start, buffer start and
    # buffer end at every step.
    # This is implemented in _transfer_timestamps_to_segmentor() function in NeMo.
    # All three values are in "seconds".
    online_segmentor.frame_start = float(k + buffer_pad_half)
    online_segmentor.buffer_start = float(k)
    online_segmentor.buffer_end = float(k + buffer_len_sec)

    print("frame_start:", online_segmentor.frame_start,
          "buffer_start:", online_segmentor.buffer_start,
          "buffer_end:", online_segmentor.buffer_end)

    # The segment memory from the previous step goes in; the updated memory comes out.
    audio_sigs, segment_ranges, range_inds = online_segmentor.run_online_segmentation(
        audio_buffer=audio_buffer,
        vad_timestamps=vad_timestamps,
        segment_raw_audio=segment_raw_audio,
        segment_range_ts=segment_range_ts,
        segment_indexes=segment_indexes,
        window=window,
        shift=shift,
    )
    print("segment ranges time stamps", segment_ranges)

    # We do this scale-for-scale in the real implementation.
    # Keep the returned values as the online diarizer memory for the next step:
    # time-series signals, segment start-end times and segment indexes.
    segment_raw_audio, segment_range_ts, segment_indexes = audio_sigs, segment_ranges, range_inds

# Check out the segments from online segmentor module
print("Final segment indexes", segment_indexes)
print("Final segment ranges time stamps", segment_range_ts)


audio buffer shape torch.Size([80000]) 80000
frame_start: 2.0 buffer_start: 0.0 buffer_end: 5.0
segment ranges time stamps [[0.0, 0.5], [0.25, 0.75], [0.5, 1.0], [0.75, 1.25], [1.0, 1.5], [1.25, 1.75], [1.5, 2.0]]
audio buffer shape torch.Size([80000]) 80000
frame_start: 3.0 buffer_start: 1.0 buffer_end: 6.0
segment ranges time stamps [[0.0, 0.5], [0.25, 0.75], [0.5, 1.0], [0.75, 1.25], [1.0, 1.5], [1.25, 1.75], [1.5, 2.0], [3.0, 3.5], [3.25, 3.75], [3.5, 4.0], [3.75, 4.25], [4.0, 4.5], [4.25, 4.75], [4.5, 5.0]]
audio buffer shape torch.Size([80000]) 80000
frame_start: 4.0 buffer_start: 2.0 buffer_end: 7.0
segment ranges time stamps [[0.0, 0.5], [0.25, 0.75], [0.5, 1.0], [0.75, 1.25], [1.0, 1.5], [1.25, 1.75], [1.5, 2.0], [3.0, 3.5], [3.25, 3.75], [3.5, 4.0], [3.75, 4.25], [4.0, 4.5], [4.25, 4.75], [4.5, 5.0], [4.75, 5.25], [5.0, 5.5], [5.25, 5.75], [5.5, 6.0]]
audio buffer shape torch.Size([80000]) 80000
frame_start: 5.0 buffer_start: 3.0 buffer_end: 8.0
segment ranges time stamps [[0

Now that we have checked the segmentor, let's work on the online clustering module.

In [39]:
# Build the regular (eager) online clustering module with its configuration values.
online_clus = OnlineSpeakerClustering(
    max_num_speakers=4,
    max_rp_threshold=0.15,
    sparse_search_volume=5,
    history_buffer_size=100,
    current_buffer_size=100,
)

# Export and save torch jit script module.
# The module is compiled with torch.jit.script, written to 'online_clus.pt'
# (a path relative to the current working directory) and loaded back, so that
# `online_clus` ends up bound to the module restored from disk.
online_clus = torch.jit.script(online_clus)
torch.jit.save(online_clus, 'online_clus.pt')
online_clus = torch.jit.load('online_clus.pt')

The following script tests the clustering algorithm with toy data.
We can quickly check that the scripted module works without a problem.

In [40]:
from nemo.collections.asr.data.audio_to_label import repeat_signal
from nemo.collections.asr.parts.utils.offline_clustering import (
    get_scale_interpolated_embs,
    getCosAffinityMatrix,
    split_input_data,
)
from nemo.collections.asr.parts.utils.online_clustering import (
    OnlineSpeakerClustering,
    get_closest_embeddings,
    get_merge_quantity,
    get_minimal_indices,
    merge_vectors,
    run_reducer,
    stitch_cluster_labels,
)
from nemo.collections.asr.parts.utils.speaker_utils import (
    check_ranges,
    get_new_cursor_for_update,
    get_online_subsegments_from_buffer,
    get_speech_labels_for_update,
    get_subsegments,
    get_target_sig,
    merge_float_intervals,
    merge_int_intervals,
)

import nemo
print("Check NeMo PATH:", nemo.__path__)


import numpy as np

def generate_orthogonal_embs(total_spks, perturb_sigma, emb_dim):
    """Create artificial, mutually orthogonal embedding vectors from random numbers.

    A random Gaussian square matrix is decomposed with SVD, and the product of the
    two returned factor matrices is used as an orthogonal matrix. Its first
    `total_spks` rows are returned as the embeddings.

    Args:
        total_spks: Number of embedding vectors (rows) to return.
        perturb_sigma: Not used inside this function; kept in the signature.
        emb_dim: Dimensionality of each embedding vector.

    Returns:
        Tensor with the first `total_spks` rows of the emb_dim x emb_dim matrix.

    Raises:
        AssertionError: If the affinity matrix of the selected rows deviates from
            the identity matrix by 1e-4 or more in total absolute difference.
    """
    random_matrix = torch.randn(emb_dim, emb_dim)
    left_factor, _, right_factor = torch.linalg.svd(random_matrix)
    orthogonal_matrix = left_factor @ right_factor
    speaker_embs = orthogonal_matrix[:total_spks]
    # Sanity check: the selected rows must be orthogonal to each other.
    identity = torch.diag(torch.ones(total_spks))
    assert torch.abs(getCosAffinityMatrix(speaker_embs) - identity).sum() < 1e-4
    return speaker_embs


def generate_toy_data(
    n_spks=2,
    spk_dur=3,
    emb_dim=192,
    perturb_sigma=0.0,
    ms_window=[1.5, 1.0, 0.5],
    ms_shift=[0.75, 0.5, 0.25],
    torch_seed=0,
):
    """Build multi-scale toy embeddings and segments for testing clustering algorithms.

    Each speaker occupies one contiguous block of `spk_dur`, one speaker after another.
    For every (window, shift) scale and every speaker, the speaker block is cut into
    sub-segments, and each sub-segment gets the speaker's orthogonal centroid plus
    uniform noise scaled by the fixed factor 0.1.

    Args:
        n_spks: Number of speakers.
        spk_dur: Duration assigned to each speaker.
        emb_dim: Dimensionality of the embeddings.
        perturb_sigma: Only forwarded to `generate_orthogonal_embs`; it does not
            scale the noise added here.
        ms_window: Window length for each scale.
        ms_shift: Shift length for each scale, paired with `ms_window`.
        torch_seed: Seed passed to `torch.manual_seed` before any random draw.

    Returns:
        Tuple of (emb_tensor, segm_tensor, multiscale_segment_counts,
        multiscale_weights, spk_timestamps, ground_truth). `ground_truth` holds
        speaker indexes for the segments of the last scale only, and
        `multiscale_weights` is a (1, number of scales) tensor of ones.
    """
    torch.manual_seed(torch_seed)
    num_scales = len(ms_window)
    # (offset, duration) of each speaker's block.
    spk_timestamps = [(spk_dur * spk, spk_dur) for spk in range(n_spks)]
    emb_list, seg_list, ground_truth = [], [], []
    multiscale_segment_counts = [0] * num_scales
    random_orthogonal_embs = generate_orthogonal_embs(n_spks, perturb_sigma, emb_dim)

    for scale_idx, (window, shift) in enumerate(zip(ms_window, ms_shift)):
        is_last_scale = scale_idx == num_scales - 1
        for spk_idx, (offset, dur) in enumerate(spk_timestamps):
            start_dur_pairs = get_subsegments(offset=offset, window=window, shift=shift, duration=dur)
            # Convert [start, duration] into [start, end].
            segments = [[pair[0], pair[0] + pair[1]] for pair in start_dur_pairs]
            n_segments = len(segments)

            centroid = random_orthogonal_embs[spk_idx, :]
            noise = 0.1 * torch.rand(n_segments, emb_dim)
            emb = centroid.tile((n_segments, 1)) + noise

            seg_list += segments
            emb_list.append(emb)
            multiscale_segment_counts[scale_idx] += emb.shape[0]

            # Reference labels are collected at the last scale only.
            if is_last_scale:
                ground_truth += [spk_idx] * emb.shape[0]

    return (
        torch.concat(emb_list),
        torch.tensor(seg_list),
        torch.tensor(multiscale_segment_counts),
        torch.ones(1, num_scales),
        spk_timestamps,
        torch.tensor(ground_truth),
    )


def test_online_speaker_clustering(n_spks, total_sec, buffer_size, sigma, seed):
    """Run the online speaker clustering algorithm on toy data and report label accuracy.

    Toy embeddings are generated for `n_spks` speakers and the embeddings of the last
    scale are fed to `OnlineSpeakerClustering` in growing prefixes, two more
    embeddings per step. After every step the predicted labels are aligned to the
    reference labels with `stitch_cluster_labels`, and the running cumulative accuracy
    is printed.

    Args:
        n_spks: Number of speakers in the toy data.
        total_sec: Total duration; each speaker gets `total_sec / n_spks`.
        buffer_size: Value used for both `history_buffer_size` and `current_buffer_size`.
        sigma: Forwarded to `generate_toy_data` as `perturb_sigma`.
        seed: Forwarded to `generate_toy_data` as `torch_seed`.

    Raises:
        AssertionError: If `online_clus.is_online` is falsy after the simulation.
    """
    step_per_frame = 2
    spk_dur = total_sec / n_spks
    em, ts, mc, _, _, gt = generate_toy_data(n_spks, spk_dur=spk_dur, perturb_sigma=sigma, torch_seed=seed)
    em_s, ts_s = split_input_data(em, ts, mc)

    # Only the embeddings of the last scale are clustered.
    emb_gen = em_s[-1]
    cuda = torch.cuda.is_available()
    if cuda:
        emb_gen = emb_gen.to("cuda")

    online_clus = OnlineSpeakerClustering(
        max_num_speakers=8,
        max_rp_threshold=0.15,
        sparse_search_volume=30,
        history_buffer_size=buffer_size,
        current_buffer_size=buffer_size,
    )
    n_frames = int(emb_gen.shape[0] / step_per_frame)
    evaluation_list = []

    # Simulate online speaker clustering
    for frame_index in range(n_frames):
        # detach().clone() makes an independent copy of the prefix, as torch.tensor(tensor)
        # did, but without the copy-construct UserWarning that call emits.
        curr_emb = emb_gen[0 : (frame_index + 1) * step_per_frame].detach().clone()
        # One index per embedding, built directly as a tensor.
        base_segment_indexes = torch.arange(curr_emb.shape[0])

        # Call clustering function
        merged_clus_labels = online_clus.forward(
            curr_emb=curr_emb,
            base_segment_indexes=base_segment_indexes,
            max_num_speakers=4,
            max_rp_threshold=0.15,
            enhanced_count_thres=40,
            sparse_search_volume=5,
            frame_index=frame_index,
            cuda=cuda,
        )

        gt = gt.to(merged_clus_labels.device)
        reference_labels = gt[: len(merged_clus_labels)]

        # Fix permutation to evaluate clustering labels
        merged_clus_labels = stitch_cluster_labels(Y_old=reference_labels, Y_new=merged_clus_labels)

        evaluation_list.extend(list(merged_clus_labels == reference_labels))
        cumul_label_acc = sum(evaluation_list) / len(evaluation_list)
        print(f"Running Cumulative Label Acc. index-{frame_index} {100*cumul_label_acc.item():.4f}% Acc.")

    assert online_clus.is_online
    cumul_label_acc = sum(evaluation_list) / len(evaluation_list)
    print("Final cumulative label accuracy", cumul_label_acc)




Check NeMo PATH: ['/home/taejinp/projects/_streaming_mulspk_asr/NeMo/nemo']


Test if online speaker clustering works without error. 

In [41]:
# Two speakers share total_sec=30 (the test divides total_sec by n_spks), and
# buffer_size=40 is used for both the history buffer and the current buffer.
test_online_speaker_clustering(n_spks=2, 
                               total_sec=30, 
                               buffer_size=40, 
                               sigma=0.05, 
                               seed=0)

      curr_emb = torch.tensor(curr_emb)
    


Running Cumulative Label Acc. index-0 100.0000% Acc.
Running Cumulative Label Acc. index-1 100.0000% Acc.
Running Cumulative Label Acc. index-2 100.0000% Acc.
Running Cumulative Label Acc. index-3 100.0000% Acc.
Running Cumulative Label Acc. index-4 100.0000% Acc.
Running Cumulative Label Acc. index-5 100.0000% Acc.
Running Cumulative Label Acc. index-6 100.0000% Acc.
Running Cumulative Label Acc. index-7 100.0000% Acc.
Running Cumulative Label Acc. index-8 100.0000% Acc.
Running Cumulative Label Acc. index-9 100.0000% Acc.
Running Cumulative Label Acc. index-10 100.0000% Acc.
Running Cumulative Label Acc. index-11 100.0000% Acc.
Running Cumulative Label Acc. index-12 100.0000% Acc.
Running Cumulative Label Acc. index-13 100.0000% Acc.
Running Cumulative Label Acc. index-14 100.0000% Acc.
Running Cumulative Label Acc. index-15 100.0000% Acc.
Running Cumulative Label Acc. index-16 100.0000% Acc.
Running Cumulative Label Acc. index-17 100.0000% Acc.
Running Cumulative Label Acc. index-18

If the overall accuracy is close to 1.0, the online clustering algorithm is working well. The accuracy is affected by the buffer size: a bigger buffer leads to better performance but takes longer to run.