# Speaker Clustering: Torch Scripted Module Example

Set `NEMO_BRANCH_PATH` to the path of your local NeMo branch.

In [1]:
import sys
# NEMO_BRANCH_PATH = '/your/path/to/diar_torch/NeMo/'
NEMO_BRANCH_PATH = '/home/taejinp/projects/diar_torch/NeMo/'
sys.path.insert(0, NEMO_BRANCH_PATH)
import nemo
print("Check NeMo PATH:", nemo.__path__)


Check NeMo PATH: ['/home/taejinp/projects/diar_torch/NeMo/nemo']


In [2]:
import torch
import time

In [3]:
from nemo.collections.asr.parts.utils.nmesc_clustering_export import SpeakerClustering

[NeMo W 2022-07-14 18:04:49 optimizers:55] Apex was not found. Using the lamb or fused_adam optimizer will error out.
    


Download an example input dictionary file `uniq_scale_dict`:
https://drive.google.com/file/d/1gQ7pqKnHk4v9zt52ECkJBeQL-aN38pEI/view?usp=sharing

Save it to a local path, for example:   
`example_file_path = "/home/taejinp/Downloads/uniq_scale_dict.pt"`

This file was created with scale lengths (in seconds) of:   
`[1.5, 1.25, 1.0, 0.75, 0.5]`  
and shift lengths of:  
`[0.75, 0.75, 0.5, 0.375, 0.25]`  

The scale indices are:   
`[0, 1, 2, 3, 4]`  

The base scale is the finest (shortest) scale, and it is also the unit of decision.   
In this example, the base scale index is `4`.  
The base scale has a length of 0.5 seconds and a shift (hop) length of 0.25 seconds.  
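To see why finer scales contain more segments, here is a minimal sketch that slices a recording into overlapping windows for each of the five scales. The helper `make_scale_timestamps` is hypothetical and assumes a single continuous speech region starting at 0 s; NeMo's own segmentation works on detected speech regions and may differ at the edges. Under this assumption, each scale yields roughly duration / shift segments, which agrees with the example file: 1577 × 0.25 s ≈ 788 × 0.5 s ≈ 394 s.

```python
import math

import torch

# Scale configuration used to create the example file (seconds).
SCALE_LENGTHS = [1.5, 1.25, 1.0, 0.75, 0.5]
SHIFT_LENGTHS = [0.75, 0.75, 0.5, 0.375, 0.25]


def make_scale_timestamps(duration: float, window: float, shift: float) -> torch.Tensor:
    """Hypothetical helper: cut [0, duration] into overlapping [start, end] segments.

    Assumes one continuous speech region; the last segment is clipped at `duration`.
    """
    if duration <= window:
        return torch.tensor([[0.0, duration]], dtype=torch.float64)
    num_segments = math.ceil((duration - window) / shift) + 1
    starts = torch.arange(num_segments, dtype=torch.float64) * shift
    ends = torch.clamp(starts + window, max=duration)
    return torch.stack([starts, ends], dim=1)


duration = 60.0  # a hypothetical one-minute recording
for scale_idx, (window, shift) in enumerate(zip(SCALE_LENGTHS, SHIFT_LENGTHS)):
    ts = make_scale_timestamps(duration, window, shift)
    print(f"scale {scale_idx}: window={window}s shift={shift}s -> {ts.shape[0]} segments")
```

Because the base scale (index 4) uses half the shift of scale index 2, it produces about twice as many segments, the same ratio as 1577 versus 788 in the example file.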

Each scale contains `torch.Tensor` objects of different sizes, as the following example shows.

In [4]:
example_file_path = "/home/taejinp/Downloads/uniq_scale_dict.pt"
uniq_scale_dict = torch.load(example_file_path)


# It contains integer scale indexes
print("uniq_scale_dict keys:\n", uniq_scale_dict.keys())

# Each scale key contains (1) embeddings (2) time_stamps
print("Base scale contents\n", uniq_scale_dict[4].keys())

# Shape: (number of segments at scale index 4, the 5th scale) x (embedding dimension, 192 in this case)
print(type(uniq_scale_dict[4]['embeddings']))
print(uniq_scale_dict[4]['embeddings'].shape)

# Shape: (number of segments at scale index 4, the 5th scale) x 2 (start and end time stamps)
print(type(uniq_scale_dict[4]['time_stamps']))
print(uniq_scale_dict[4]['time_stamps'].shape)


# Shape: (number of segments at scale index 2, the 3rd scale) x (embedding dimension, 192 in this case)
print(type(uniq_scale_dict[2]['embeddings']))
print(uniq_scale_dict[2]['embeddings'].shape)

# Shape: (number of segments at scale index 2, the 3rd scale) x 2 (start and end time stamps)
print(type(uniq_scale_dict[2]['time_stamps']))
print(uniq_scale_dict[2]['time_stamps'].shape)

uniq_scale_dict keys:
 dict_keys([0, 1, 2, 3, 4])
Base scale contents
 dict_keys(['embeddings', 'time_stamps'])
<class 'torch.Tensor'>
torch.Size([1577, 192])
<class 'torch.Tensor'>
torch.Size([1577, 2])
<class 'torch.Tensor'>
torch.Size([788, 192])
<class 'torch.Tensor'>
torch.Size([788, 2])
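Since every scale must pair one embedding with one `[start, end]` time stamp, a quick structural check can catch a malformed input dictionary before clustering. The sketch below is a hypothetical helper, not part of NeMo; it checks the layout printed above and returns the number of segments per scale. The usage builds a small synthetic dictionary with random embeddings, which is made-up data for illustration only.

```python
from typing import Dict, Optional

import torch


def summarize_scale_dict(scale_dict: Dict[int, Dict[str, torch.Tensor]]) -> Dict[int, int]:
    """Check the layout of a multiscale input dict and return {scale index: num segments}."""
    counts: Dict[int, int] = {}
    emb_dim: Optional[int] = None
    for scale_idx in sorted(scale_dict):
        emb = scale_dict[scale_idx]['embeddings']   # (num_segments, emb_dim)
        ts = scale_dict[scale_idx]['time_stamps']   # (num_segments, 2)
        if emb.dim() != 2 or ts.dim() != 2 or ts.shape[1] != 2:
            raise ValueError(f"scale {scale_idx}: unexpected shapes {tuple(emb.shape)}, {tuple(ts.shape)}")
        if emb.shape[0] != ts.shape[0]:
            raise ValueError(f"scale {scale_idx}: {emb.shape[0]} embeddings but {ts.shape[0]} time stamps")
        if not bool((ts[:, 0] <= ts[:, 1]).all()):
            raise ValueError(f"scale {scale_idx}: a segment ends before it starts")
        if emb_dim is None:
            emb_dim = int(emb.shape[1])
        elif emb.shape[1] != emb_dim:
            raise ValueError(f"scale {scale_idx}: embedding dim {emb.shape[1]} != {emb_dim}")
        counts[scale_idx] = int(emb.shape[0])
    return counts


# Synthetic two-scale dictionary (hypothetical data).
toy = {}
for idx, (window, shift, n) in enumerate([(1.0, 0.5, 10), (0.5, 0.25, 20)]):
    starts = torch.arange(n, dtype=torch.float32) * shift
    toy[idx] = {
        'embeddings': torch.randn(n, 192),
        'time_stamps': torch.stack([starts, starts + window], dim=1),
    }
print(summarize_scale_dict(toy))
# With the downloaded file: summarize_scale_dict(uniq_scale_dict)
```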


In [5]:
# Set up a multiscale weight vector with equal weights for all five scales.
device = torch.device("cuda")
multiscale_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]).unsqueeze(0).to(device)
multiscale_weights.shape

torch.Size([1, 5])
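The cell above only defines the weights. One plausible way such weights are used, stated here as an assumption rather than as `SpeakerClustering`'s exact implementation, is a weighted average of per-scale cosine-affinity matrices, where each scale's segments are first mapped onto the base-scale segments by nearest segment center. The helper names `nearest_scale_index`, `cosine_affinity`, and `fused_affinity` are hypothetical, and the weight order is assumed to follow the sorted scale indices.

```python
from typing import Dict

import torch


def nearest_scale_index(base_ts: torch.Tensor, scale_ts: torch.Tensor) -> torch.Tensor:
    """For each base-scale segment, the index of the scale segment with the closest center."""
    base_centers = base_ts.mean(dim=1)
    scale_centers = scale_ts.mean(dim=1)
    return (base_centers[:, None] - scale_centers[None, :]).abs().argmin(dim=1)


def cosine_affinity(emb: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarity between the rows of `emb`."""
    emb = torch.nn.functional.normalize(emb.float(), p=2.0, dim=1)
    return emb @ emb.t()


def fused_affinity(scale_dict: Dict[int, Dict[str, torch.Tensor]],
                   weights: torch.Tensor, base_idx: int) -> torch.Tensor:
    """Weighted average of per-scale affinities, expressed on the base-scale segments."""
    base_ts = scale_dict[base_idx]['time_stamps']
    device = base_ts.device
    w = weights.flatten().float().to(device)  # (1, num_scales) -> (num_scales,)
    n = base_ts.shape[0]
    fused = torch.zeros(n, n, device=device)
    for i, scale_idx in enumerate(sorted(scale_dict)):
        entry = scale_dict[scale_idx]
        idx = nearest_scale_index(base_ts, entry['time_stamps'].to(device))
        # Each coarse-scale embedding is repeated for every base segment mapped to it.
        fused += w[i] * cosine_affinity(entry['embeddings'].to(device)[idx])
    return fused / w.sum()


# Toy two-scale input with random embeddings (hypothetical data).
torch.manual_seed(0)
toy = {}
for idx, (window, shift, n) in enumerate([(1.0, 0.5, 8), (0.5, 0.25, 16)]):
    starts = torch.arange(n, dtype=torch.float32) * shift
    toy[idx] = {'embeddings': torch.randn(n, 192),
                'time_stamps': torch.stack([starts, starts + window], dim=1)}
affinity = fused_affinity(toy, torch.tensor([[1.0, 1.0]]), base_idx=1)
print(affinity.shape)
```

With equal weights, every scale contributes the same amount; raising one weight makes the fused affinity lean toward that scale's notion of similarity.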

Now create a `SpeakerClustering` instance and convert it to a TorchScript module with `torch.jit.script`. The result is a recursive script module, since all the sub-functions used in this class are decorated with `torch.jit.script`.

The first module is non-parallelized (`parallelism=False`) and the second is parallelized (`parallelism=True`); the parallelized version searches the p-value parameter in parallel.
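The recursive scripting behavior can be seen on a toy module. `ToyClustering` and `cosine_affinity` below are hypothetical stand-ins, not NeMo code: `torch.jit.script` compiles `forward` and the scripted function it calls, and the resulting module can be saved and reloaded without the original Python class.

```python
import io

import torch


@torch.jit.script
def cosine_affinity(emb: torch.Tensor) -> torch.Tensor:
    # Row-normalize, then take pairwise dot products (cosine similarity).
    emb = torch.nn.functional.normalize(emb, p=2.0, dim=1)
    return torch.mm(emb, emb.t())


class ToyClustering(torch.nn.Module):
    """Hypothetical module: for each segment, counts the segments similar to it."""

    def __init__(self, threshold: float):
        super().__init__()
        self.threshold = threshold

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        affinity = cosine_affinity(emb)
        return (affinity > self.threshold).sum(dim=1)


scripted = torch.jit.script(ToyClustering(threshold=0.5))
print(type(scripted).__name__)

# Round-trip through an in-memory buffer; the loaded module runs without the Python source.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

emb = torch.randn(6, 192)
print(reloaded(emb))
```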

In [33]:
device = torch.device("cuda")

sparse_search_volume=30
max_num_speaker=8
max_rp_threshold=0.15

# Single thread, WITHOUT parallelism for searching the p-value parameter.
speaker_clustering_singthrd = SpeakerClustering(
            max_num_speaker=max_num_speaker,
            max_rp_threshold=max_rp_threshold,
            sparse_search_volume=sparse_search_volume,
            multiscale_weights=multiscale_weights,
            parallelism=False,
            cuda=True)
scripted_singthrd = torch.jit.script(speaker_clustering_singthrd)

# Multi thread, WITH parallelism for searching the p-value parameter.
speaker_clustering_multhrd = SpeakerClustering(
            max_num_speaker=max_num_speaker,
            max_rp_threshold=max_rp_threshold,
            sparse_search_volume=sparse_search_volume,
            multiscale_weights=multiscale_weights,
            parallelism=True,
            cuda=True)
scripted_multhrd = torch.jit.script(speaker_clustering_multhrd)


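Now run both scripted modules on `uniq_scale_dict`. Passing `oracle_num_speakers=-1` means the number of speakers is not given and is estimated by the model. Each call returns the cluster labels as a `torch.Tensor`, and the set of distinct labels gives the estimated number of speakers.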

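Compare the run times of the two modules. The speed gain from `parallelism=True` grows with `sparse_search_volume`; with `sparse_search_volume=30`, the gain is approximately 30%, although the exact figure varies between runs (in the run recorded below, the elapsed time dropped from 4.728 s to 2.555 s).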
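The timings in this notebook come from a single call per module, and the single-thread module runs first. Two general PyTorch behaviors can skew such a comparison: CUDA kernels are launched asynchronously, and TorchScript's profiling executor optimizes a scripted function during its first few calls. A fairer measurement, sketched here as a suggestion rather than the author's procedure, warms each module up, synchronizes the GPU, and reports the median of several runs. `benchmark` is a hypothetical helper.

```python
import statistics
import time
from typing import Callable, List

import torch


def benchmark(fn: Callable[[], object], n_warmup: int = 2, n_runs: int = 5) -> float:
    """Median wall-clock time of fn() in seconds, measured after warm-up calls."""

    def sync() -> None:
        # Wait for queued CUDA kernels so the timer covers the actual work.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(n_warmup):
        fn()  # lets the TorchScript executor profile and optimize before timing
    sync()
    times: List[float] = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        sync()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


# Usage with the notebook's objects (not executed here):
# t_single = benchmark(lambda: scripted_singthrd.forward(uniq_scale_dict, oracle_num_speakers=-1))
# t_multi = benchmark(lambda: scripted_multhrd.forward(uniq_scale_dict, oracle_num_speakers=-1))
# print(f"speed-up: {t_single / t_multi:.2f}x")
print(f"{benchmark(lambda: sum(range(100_000))):.6f} s")
```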

In [34]:
start1 = time.time()
cluster_labels = scripted_singthrd.forward(
    uniq_scale_dict,
    oracle_num_speakers=-1,
    )
print(f"\nSingle Thread ETA: {(time.time()-start1):.3f} sec")
print("cluster labels:", cluster_labels)
print("Set of speakers", set(cluster_labels.cpu().numpy().tolist()))

start2 = time.time()
cluster_labels = scripted_multhrd.forward(
    uniq_scale_dict,
    oracle_num_speakers=-1,
    )
print(f"\nMulti Thread ETA: {(time.time()-start2):.3f} sec")

print("cluster labels:", cluster_labels)
print("Set of speakers", set(cluster_labels.cpu().numpy().tolist()))


Single Thread ETA: 4.728 sec
cluster labels: tensor([0, 1, 1,  ..., 0, 0, 0], device='cuda:0')
Set of speakers {0, 1}

Multi Thread ETA: 2.555 sec
cluster labels: tensor([0, 1, 1,  ..., 0, 0, 0], device='cuda:0')
Set of speakers {0, 1}
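Because the base scale is the unit of decision, each label can be read as the speaker of one base-scale segment, and runs of equal labels can be merged into speaker turns. The sketch below assumes the labels are aligned with `uniq_scale_dict[4]['time_stamps']` (base scale index 4) and places each speaker-change boundary halfway between the centers of the two neighboring overlapping segments; `labels_to_turns` is a hypothetical helper, not a NeMo API.

```python
from typing import List, Tuple

import torch


def labels_to_turns(time_stamps: torch.Tensor,
                    labels: torch.Tensor) -> List[Tuple[float, float, int]]:
    """Merge runs of equal labels into (start, end, speaker) turns.

    Assumes one label per row of `time_stamps` (base-scale segments sorted by start
    time, at least one segment).
    """
    ts = time_stamps.detach().cpu().double()
    labs = labels.detach().cpu().tolist()
    centers = ts.mean(dim=1)
    turns: List[Tuple[float, float, int]] = []
    start = float(ts[0, 0])
    for i in range(1, len(labs)):
        if labs[i] != labs[i - 1]:
            # Overlapping segments: cut halfway between the two segment centers.
            boundary = float(0.5 * (centers[i - 1] + centers[i]))
            turns.append((start, boundary, int(labs[i - 1])))
            start = boundary
    turns.append((start, float(ts[-1, 1]), int(labs[-1])))
    return turns


# Toy example: four overlapping 0.5 s segments with a 0.25 s shift.
toy_ts = torch.tensor([[0.0, 0.5], [0.25, 0.75], [0.5, 1.0], [0.75, 1.25]])
toy_labels = torch.tensor([0, 0, 1, 1])
print(labels_to_turns(toy_ts, toy_labels))

# With the notebook's objects:
# turns = labels_to_turns(uniq_scale_dict[4]['time_stamps'], cluster_labels)
# print("estimated number of speakers:", len({spk for _, _, spk in turns}))
```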
