# Speaker Clustering: Torch Scripted Module 

Provide the NeMo path to `NEMO_BRANCH_PATH`.

In [4]:
import sys
# NEMO_BRANCH_PATH = '/your/path/to/diar_torch/NeMo/'
NEMO_BRANCH_PATH = '/home/taejinp/projects/diar_torch/NeMo/'
sys.path.insert(0, NEMO_BRANCH_PATH)
import nemo
print("Check NeMo PATH:", nemo.__path__)


Check NeMo PATH: ['/home/taejinp/projects/diar_torch/NeMo/nemo']


In [5]:
import torch
import time

In [6]:
from nemo.collections.asr.parts.utils.nmesc_clustering import SpeakerClustering

Download an example input dictionary file `uniq_embs_and_timestamps`:
https://drive.google.com/file/d/1249CTH6FgFbioBY1KkPOuA996LfbgqP_/view?usp=sharing

Please save it to your local path such as:   
`example_file_path = "/home/taejinp/Downloads/uniq_embs_and_timestamps.pt"`

This file has been created using scale lengths of :   
`[1.5, 1.25, 1.0, 0.75, 0.5]`  
and shift length of :  
`[0.75, 0.75, 0.5, 0.375, 0.25]`  
Scale indexes are:   
`[0, 1, 2, 3, 4]`  

Base scale is the finest (shortest) scale which is also the unit of decision.   
In this example, base scale index is `4`.  
In this example, base scale has length of 0.5 second and shift length (hop length) of 0.25 second.  

`multiscale_segment_counts` variable is needed for splitting into multiscale tensors.  
`embeddings_in_scales` contains concatenated embeddings of 5 scales.   
`timestamps_in_scales` contains concatenated timestamps of 5 scales.   

jit scripted module `speaker_clustering` splits the input tensors into `scale_n` number of tensors before run clustering. 
Each of scale contains torch.tensors with different sizes. Check out the following example.

In [7]:
example_file_path = "/home/taejinp/Downloads/uniq_embs_and_timestamps.pt"
uniq_embs_and_timestamps = torch.load(example_file_path)

# Python Dictionary indexed (keys) by integer, values are torch.Tensor
multiscale_segment_counts = uniq_embs_and_timestamps['multiscale_segment_counts']
embeddings_in_scales = uniq_embs_and_timestamps['embeddings']
timestamps_in_scales = uniq_embs_and_timestamps['time_stamps']

# Multiscale segment counts for each scale
# Dimension: (Number of index-2 (3rd) scale segmensts) x 2 (start and end time stamps)
print(f"Segment counts of {len(multiscale_segment_counts)} scales")
print(multiscale_segment_counts)

# Dimension: (Number of index-4 (5th) scale segmensts) x (embedding dimension, 192 in this case)
print(type(embeddings_in_scales[4]))
print(embeddings_in_scales[4].shape)

# Dimension: (Number of index-4 (5th) scale segmensts) x 2 (start and end time stamps)
print(type(timestamps_in_scales[4]))
print(timestamps_in_scales[4].shape)


# Dimension: (Number of index-2 (3rd) scale segmensts) x (embedding dimension, 192 in this case)
print(type(embeddings_in_scales[2]))
print(embeddings_in_scales[2].shape)

# Dimension: (Number of index-2 (3rd) scale segmensts) x 2 (start and end time stamps)
print(type(timestamps_in_scales[2]))
print(timestamps_in_scales[2].shape)



Segment counts of 5 scales
tensor([ 553,  609,  788, 1045, 1577])
<class 'torch.Tensor'>
torch.Size([192])
<class 'torch.Tensor'>
torch.Size([2])
<class 'torch.Tensor'>
torch.Size([192])
<class 'torch.Tensor'>
torch.Size([2])


In [8]:
# Setup a multiscale weight vector. Equal weights
device = torch.device("cuda")
multiscale_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]).unsqueeze(0).to(device)
multiscale_weights.shape

torch.Size([1, 5])

Now, create a `SpeakerClustering` class instance and convert it to torch.jit.script module. This will create recursive script module since all the sub-fucntions used in this class is all torch.jit.script-decorated. 

First module is non-parallelized instance and the second one is parallelized. `parallelism=False` or `True`.

In [15]:
# sparse_search_volume=30
# sparse_search_volume=15
sparse_search_volume=5
max_num_speaker=8
max_rp_threshold=0.15

# Single thread, WITHOUT parallelism for searching the p-value parameter.
speaker_clustering_singthrd = SpeakerClustering(
            max_num_speaker=max_num_speaker,
            max_rp_threshold=max_rp_threshold,
            sparse_search_volume=sparse_search_volume,
            multiscale_weights=multiscale_weights,
            parallelism=False,
            cuda=True)
scripted_singthrd = torch.jit.script(speaker_clustering_singthrd).to(device)

# Multi thread, WITH parallelism for searching the p-value parameter.
speaker_clustering_multhrd = SpeakerClustering(
            max_num_speaker=max_num_speaker,
            max_rp_threshold=max_rp_threshold,
            sparse_search_volume=sparse_search_volume,
            multiscale_weights=multiscale_weights,
            parallelism=True,
            cuda=True)
scripted_multhrd = torch.jit.script(speaker_clustering_multhrd).to(device)


Now, run the speaker clustering model with the following line. 

- Input is all `torch.tensor` type. It will be split into each scale in the scripted module.
- You can check the clustered labels and estimated number of speakers. The output is also torch.tensor format.

Check out the speed gain from parallelism=True. The bigger the size of `sparse_search_volume`, the more the speed gain is. For example `sparse_search_volume=30`, it has approximately 30% speed gain.

In [16]:
start1 = time.time()
cluster_labels = scripted_singthrd.forward(
    embeddings_in_scales,
    timestamps_in_scales,
    multiscale_segment_counts,
    oracle_num_speakers=-1,
    )
print(f"\nSingle Thread ETA: {(time.time()-start1):.3f} sec")
print("cluster labels:", cluster_labels)
print("Set of speakers", set(cluster_labels.cpu().numpy().tolist()))

start2 = time.time()
cluster_labels = scripted_multhrd.forward(
    embeddings_in_scales,
    timestamps_in_scales,
    multiscale_segment_counts,
    oracle_num_speakers=-1,
    )
print(f"\nMulti Thread ETA: {(time.time()-start2):.3f} sec")

print("cluster labels:", cluster_labels)
print("Set of speakers", set(cluster_labels.cpu().numpy().tolist()))


Single Thread ETA: 1.052 sec
cluster labels: tensor([0, 1, 1,  ..., 0, 0, 0], device='cuda:0')
Set of speakers {0, 1}

Multi Thread ETA: 0.863 sec
cluster labels: tensor([0, 1, 1,  ..., 0, 0, 0], device='cuda:0')
Set of speakers {0, 1}


Save torch.jit.script module with `torch.jit.save`. Check out [this page](https://pytorch.org/docs/stable/generated/torch.jit.save.html) where `torch.jit.save` is explained as below.

> Save an offline version of this module for use in a separate process. The saved module serializes all of the methods, submodules, parameters, and attributes of this module. It can be loaded into the C++ API using torch::jit::load(filename) or into the Python API with torch.jit.load.


In [None]:
torch.jit.save(scripted_multhrd, 'speaker_clustering_multithread.pt')

# Speaker Clustering: Speed Test for Torch Scripted Module 

`addAnchorEmb()` Is a function that creates dummy speaker embedding. Let's create dummy embeddings to measure the speed of speaker clustering.

In [None]:
def addAnchorEmb(emb: torch.Tensor, anchor_sample_n: int, anchor_spk_n: int, sigma: float):
    emb_dim = 192
    std_org = torch.std(emb, dim=0)
    new_emb_list = []
    for _ in range(anchor_spk_n):
        emb_m = torch.tile(torch.randn(1, emb_dim), (anchor_sample_n, 1))
        emb_noise = torch.randn(anchor_sample_n, emb_dim).T
        emb_noise = torch.matmul(
            torch.diag(std_org), emb_noise / torch.max(torch.abs(emb_noise), dim=0)[0].unsqueeze(0)
        ).T
        emb_gen = emb_m + sigma * emb_noise
        new_emb_list.append(emb_gen)

#     new_emb_list.append(emb)
    new_emb_np = torch.vstack(new_emb_list)
    return new_emb_np

scale_n = 1
mat_size = 2**13

anchor_spk_n = 2

embeddings = addAnchorEmb(embeddings_in_scales, anchor_sample_n=int(mat_size/anchor_spk_n), anchor_spk_n=anchor_spk_n, sigma=1)
print("embeddings:\n", embeddings.shape)
embeddings_in_scales_gen = embeddings.tile((scale_n, 1))
print("embeddings_in_scales_gen.shape \n", embeddings_in_scales_gen.shape)

timestamps = torch.tensor([[float(stt/2), float(stt/2+1)] for stt in range(mat_size)])
timestamps_in_scales_gen = timestamps.tile((scale_n, 1))
print(embeddings_in_scales_gen)
print(timestamps_in_scales_gen)
print("timestamps_in_scales_gen.shape \n", timestamps_in_scales_gen.shape)
multiscale_segment_counts_gen = torch.tensor([mat_size for x in range(scale_n)])
print(multiscale_segment_counts_gen)


Measure the speed of speaker clustering jit scripted module. 

In [None]:
# Setup a multiscale weight vector. Equal weights
device = torch.device("cuda")
multiscale_weights = torch.tensor([1.0 for _ in range(scale_n)]).unsqueeze(0).to(device)

sparse_search_volume=10
max_num_speaker=8
max_rp_threshold=0.25
scale_n = 1
anchor_spk_n = 2
repeat_n = 10
eta_list_single = []
eta_list_mutli = []
mat_size_list = []
size_list=[4,5,6,7,8,9,10,11,12,13]

for pos in size_list:
    mat_size = int(2 ** pos)
    mat_size_list.append(mat_size)
    embeddings = addAnchorEmb(embeddings_in_scales, anchor_sample_n=int(mat_size/anchor_spk_n), anchor_spk_n=anchor_spk_n, sigma=1)
    embeddings_in_scales_gen = embeddings.tile((scale_n, 1))
    timestamps = torch.tensor([[float(stt/2), float(stt/2+1)] for stt in range(mat_size)])
    timestamps_in_scales_gen = timestamps.tile((scale_n, 1))
    multiscale_segment_counts_gen = torch.tensor([mat_size for x in range(scale_n)])


    # Single thread, WITHOUT parallelism for searching the p-value parameter.
    speaker_clustering_singthrd = SpeakerClustering(
                max_num_speaker=max_num_speaker,
                max_rp_threshold=max_rp_threshold,
                sparse_search_volume=sparse_search_volume,
                multiscale_weights=multiscale_weights,
                parallelism=False,
                cuda=True)
    scripted_singthrd = torch.jit.script(speaker_clustering_singthrd).to(device)

    print("Running segment volume with: \n", multiscale_segment_counts_gen)
    # Multi thread, WITH parallelism for searching the p-value parameter.
    speaker_clustering_multhrd = SpeakerClustering(
                max_num_speaker=max_num_speaker,
                max_rp_threshold=max_rp_threshold,
                sparse_search_volume=sparse_search_volume,
                multiscale_weights=multiscale_weights,
                parallelism=True,
                cuda=True)
    scripted_multhrd = torch.jit.script(speaker_clustering_multhrd).to(device)


    start_gen1 = time.time()
    for _ in range(repeat_n):
        cluster_labels = scripted_singthrd.forward(
            embeddings_in_scales_gen,
            timestamps_in_scales_gen,
            multiscale_segment_counts_gen,
            oracle_num_speakers=-1,
            )
    eta1 = (time.time()-start_gen1)/repeat_n
    print(f"\nSingle Thread ETA: {eta1:.3f} sec")
    eta_list_single.append(eta1)
    
    start_gen2 = time.time()
    for _ in range(repeat_n):
        cluster_labels = scripted_multhrd.forward(
            embeddings_in_scales_gen,
            timestamps_in_scales_gen,
            multiscale_segment_counts_gen,
            oracle_num_speakers=-1,
            )
    eta2 = (time.time()-start_gen2)/repeat_n
    print(f"\nMulti Thread ETA: {eta2:.3f} sec")
    eta_list_mutli.append(eta2)
    
    print("\n\n")

print("eta_list_single\n")
print(eta_list_single)
print("eta_list_multi\n")
print(eta_list_mutli)
print("mat_size_list\n")
print(mat_size_list)

Save the result in csv file format.

In [None]:
import csv
   
# field names 
fields = ['Matrix size', 'Single Thread', 'Multi Thread']
    
# data rows of csv file 
mat_size_list.insert(0,fields[0])
eta_list_single.insert(0, fields[1])
eta_list_mutli.insert(0, fields[2])

rows = [mat_size_list,
        eta_list_single,
        eta_list_mutli]
print(rows)
print(mat_size_list, eta_list_single, eta_list_mutli)
path = f"/home/taejinp/gdrive/result_data/clustering_speed_searchvol{sparse_search_volume}.csv"

with open(path, 'w') as f:
    # using csv.writer method from CSV package
    write = csv.writer(f)
    write.writerows(rows)