Adding long-form audio speaker diarization (clustering) class and functions #7737

Merged 40 commits on Nov 7, 2023
Changes from 17 commits
Commits (40)
3923de4
Adding long-form audio clustering for diarization
tango4j Oct 14, 2023
9829949
Adding unit test changes
tango4j Oct 16, 2023
f9c6141
Merge branch 'NVIDIA:main' into long_clus
tango4j Oct 16, 2023
26c61c4
Added tests for torch jit script
tango4j Oct 16, 2023
51904c6
Merge branch 'long_clus' of https://github.com/tango4j/NeMo into long…
tango4j Oct 16, 2023
67883f5
Added variable value checking line
tango4j Oct 17, 2023
15ab8cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 17, 2023
8f537f8
Added needed params to all yamls
tango4j Oct 17, 2023
dcaf06a
Consolidated long-form and short-form clustering methods
tango4j Oct 18, 2023
c739609
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2023
7ac2ecc
Merge remote-tracking branch 'origin' into long_clus
tango4j Oct 18, 2023
f62184a
Merged latest main and updated speaker utils
tango4j Oct 18, 2023
e7ce447
Fixed code formatting error in speaker_utils.py
tango4j Oct 18, 2023
f8ac688
Some minor fixes for doc-strings
tango4j Oct 18, 2023
31f57d2
Removed unnecessary comments
tango4j Oct 18, 2023
e810223
Merge branch 'main' into long_clus
stevehuang52 Oct 20, 2023
3a5b4f2
Merge branch 'main' into long_clus
tango4j Oct 20, 2023
e60c16a
Reflected comments and made changes
tango4j Oct 26, 2023
319d2d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2023
381f6c1
Merge branch 'main' into long_clus
tango4j Oct 26, 2023
57bec0e
Minor changes on typos and comments
tango4j Oct 26, 2023
72869a6
Minor changes on typos and comments
tango4j Oct 26, 2023
4871767
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2023
6bbc0a8
Merge branch 'main' into long_clus
tango4j Oct 26, 2023
b1124bb
Fixes for code QL
tango4j Oct 26, 2023
976121b
Merge branch 'main' into long_clus
tango4j Oct 26, 2023
20eb34a
Fixed docstring errors
tango4j Oct 26, 2023
afa7434
Merge branch 'long_clus' of https://github.com/tango4j/NeMo into long…
tango4j Oct 26, 2023
14774b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2023
5e5c896
Merge branch 'main' into long_clus
tango4j Oct 27, 2023
fe756d4
Merge branch 'main' into long_clus
tango4j Oct 30, 2023
4657c07
Reflected the second batch of comments
tango4j Nov 1, 2023
696c559
Merge branch 'long_clus' of https://github.com/tango4j/NeMo into long…
tango4j Nov 1, 2023
c90bce8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2023
e41acee
Updating all yamls for inference
tango4j Nov 1, 2023
bf7fe44
Merge branch 'long_clus' of https://github.com/tango4j/NeMo into long…
tango4j Nov 1, 2023
cd299ec
Added None-checker to forward to prevent type errors
tango4j Nov 1, 2023
2db9779
Merge branch 'main' into long_clus
tango4j Nov 1, 2023
22dee0a
Merge branch 'main' into long_clus
tango4j Nov 2, 2023
1416307
Merge branch 'main' into long_clus
nithinraok Nov 6, 2023
@@ -43,6 +43,8 @@ diarizer:
shift_length_in_sec: [0.95,0.6,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25]
multiscale_weights: [1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33]
save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`.
sub_cluster_n: 16 # Number of sub-clusters per sub-window for long-form audio clustering.
unit_window_len: 10000 # Window length for each chunking unit. Adjust according to the available GPU memory.

clustering:
parameters:
@@ -52,7 +54,9 @@ diarizer:
max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold.
sparse_search_volume: 10 # The higher the number, the more values will be examined with more time.
maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers.

sub_cluster_n: 50 # Number of sub-clusters per sub-window for long-form audio clustering.
unit_window_len: 10000 # Window length for each chunking unit. Adjust according to the available GPU memory.

msdd_model:
model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD)
parameters:
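The two new keys drive the long-form path: segment embeddings are processed in chunks of `unit_window_len` segments, and each chunk is first reduced to `sub_cluster_n` representative vectors so the final clustering stays within GPU memory. A minimal sketch of that reduction idea, assuming a naive k-means reducer (the helper name and the reduction algorithm here are illustrative, not the PR's actual implementation):

```python
import torch

def reduce_long_form(emb: torch.Tensor, unit_window_len: int = 10000,
                     sub_cluster_n: int = 50, iters: int = 10) -> torch.Tensor:
    """Reduce each window of `unit_window_len` embeddings to `sub_cluster_n` centroids."""
    centroids = []
    for start in range(0, emb.shape[0], unit_window_len):
        chunk = emb[start:start + unit_window_len]
        k = min(sub_cluster_n, chunk.shape[0])
        centers = chunk[:k].clone()  # initialize centers from the first k points
        for _ in range(iters):
            # assign every embedding in the chunk to its nearest center
            assign = torch.cdist(chunk, centers).argmin(dim=1)
            for j in range(k):
                mask = assign == j
                if mask.any():
                    centers[j] = chunk[mask].mean(dim=0)
        centroids.append(centers)
    # the concatenated per-window centroids are what the final clustering step consumes
    return torch.cat(centroids, dim=0)
```

With the values above, a recording producing 200,000 base segments would shrink to 200000 / 10000 * 50 = 1,000 vectors before the affinity matrix is built.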
@@ -52,6 +52,8 @@ diarizer:
max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold.
sparse_search_volume: 30 # The higher the number, the more values will be examined with more time.
maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers.
sub_cluster_n: 50 # Number of sub-clusters per sub-window for long-form audio clustering.
unit_window_len: 10000 # Window length for each chunking unit. Adjust according to the available GPU memory.

msdd_model:
model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD)
@@ -88,5 +90,4 @@ diarizer:
arpa_language_model: null # Provide a KenLM language model in .arpa format.
min_number_of_words: 3 # Min number of words for the left context.
max_number_of_words: 10 # Max number of words for the right context.
logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses.
@@ -52,6 +52,8 @@ diarizer:
max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold.
sparse_search_volume: 30 # The higher the number, the more values will be examined with more time.
maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers.
sub_cluster_n: 50 # Number of sub-clusters per sub-window for long-form audio clustering.
unit_window_len: 10000 # Window length for each chunking unit. Adjust according to the available GPU memory.

msdd_model:
model_path: diar_msdd_telephonic # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD)
@@ -88,5 +90,4 @@ diarizer:
arpa_language_model: null # Provide a KenLM language model in .arpa format.
min_number_of_words: 3 # Min number of words for the left context.
max_number_of_words: 10 # Max number of words for the right context.
logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses.
76 changes: 37 additions & 39 deletions nemo/collections/asr/parts/utils/offline_clustering.py
@@ -1119,6 +1119,7 @@ def __init__(
maj_vote_spk_count: bool = False,
parallelism: bool = False,
cuda: bool = False,
device=None,
):
"""
Clustering method for speaker diarization based on cosine similarity.
@@ -1149,7 +1150,10 @@ def __init__(
self.maj_vote_spk_count: bool = maj_vote_spk_count
self.embeddings_in_scales: List[torch.Tensor] = [torch.Tensor(0)]
self.timestamps_in_scales: List[torch.Tensor] = [torch.Tensor(0)]
if device is None or type(device) is not torch.device:
self.device = torch.device("cuda") if self.cuda else torch.device("cpu")
else:
self.device = device
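The new `device` argument above falls back to the `cuda` flag whenever no valid `torch.device` is passed. The same resolution logic can be sketched standalone (written with `isinstance`, the more idiomatic form of the type check; this is an illustration, not the PR's code):

```python
import torch

def resolve_device(device=None, cuda: bool = False) -> torch.device:
    """Return `device` if it is a torch.device; otherwise fall back to the `cuda` flag."""
    if not isinstance(device, torch.device):
        return torch.device("cuda") if cuda else torch.device("cpu")
    return device
```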

def forward(self, param_dict: Dict[str, torch.Tensor]) -> torch.LongTensor:
"""
@@ -1209,64 +1213,58 @@ def forward_infer(
kmeans_random_trials: int = 1,
) -> torch.LongTensor:
"""
Calculate the affinity matrix using timestamps and speaker embeddings, run NME analysis to estimate the best
p-value, and perform spectral clustering based on the estimated p-value and the calculated affinity matrix.

Caution:
For compatibility with libtorch, python boolean `False` has been replaced with `torch.LongTensor(-1)`.

Args:
Dict containing following keys associated with tensors.
embeddings_in_scales (Tensor):
List containing concatenated Torch tensor embeddings across multiple scales.
The length of the list is equal to the number of scales.
Each tensor has dimensions of (Number of base segments) x (Embedding Dimension).
timestamps_in_scales (Tensor):
List containing concatenated Torch tensor timestamps across multiple scales.
The length of the list is equal to the number of scales.
Each tensor has dimensions of (Total number of segments across all scales) x 2.
Example:
>>> timestamps_in_scales[0] = \
torch.Tensor([[0.4, 1.4], [0.9, 1.9], [1.4, 2.4], ... [121.2, 122.2]])
multiscale_segment_counts (torch.LongTensor):
A Torch tensor containing the number of segments for each scale.
The tensor has dimensions of (Number of scales).
Example:
>>> multiscale_segment_counts = torch.LongTensor([31, 52, 84, 105, 120])
multiscale_weights (torch.Tensor):
Multi-scale weights used when merging affinity scores.
Example:
>>> multiscale_weights = torch.tensor([1.4, 1.3, 1.2, 1.1, 1.0])
oracle_num_speakers (int):
The number of speakers in a session as given by the reference transcript.
max_num_speakers (int):
The upper bound for the number of speakers in each session.
max_rp_threshold (float):
Limits the range of parameter search.
The clustering performance can vary based on this range.
The default value is 0.15.
enhanced_count_thres (int):
For shorter audio recordings, the clustering algorithm might not accumulate enough speaker profiles for each cluster.
Thus, the function `getEnhancedSpeakerCount` uses anchor embeddings (dummy representations) to mitigate the effects of cluster sparsity.
A value of 80 is recommended for `enhanced_count_thres`.
sparse_search_volume (int):
The number of p_values considered during NME analysis.
The default is 30. Lower values speed up the NME-analysis but might lead to poorer parameter estimations. Values below 20 are not recommended.
fixed_thres (float):
If a `fixed_thres` value is provided, the NME-analysis process will be skipped.
This value should be optimized on a development set for best results.
By default, it is set to -1.0, and the function performs NME-analysis to estimate the threshold.
kmeans_random_trials (int):
The number of random trials for initializing k-means clustering. More trials can result in more stable clustering. The default is 1.

Returns:
Y (LongTensor):
Speaker labels for the segments in the provided input embeddings.
"""
self.embeddings_in_scales, self.timestamps_in_scales = split_input_data(
embeddings_in_scales, timestamps_in_scales, multiscale_segment_counts
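`forward_infer` begins by splitting the concatenated multiscale inputs back into per-scale lists via `split_input_data`. Consistent with the docstring above, a hypothetical re-implementation of that split step (NeMo's actual helper may differ in details):

```python
from typing import List, Tuple

import torch

def split_input_data(
    embeddings: torch.Tensor,       # (total segments across all scales) x (embedding dim)
    timestamps: torch.Tensor,       # (total segments across all scales) x 2
    multiscale_segment_counts: torch.LongTensor,  # number of segments per scale
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Split concatenated multiscale embeddings and timestamps into per-scale lists."""
    sizes = multiscale_segment_counts.tolist()
    # torch.split slices along dim 0, yielding one tensor per scale
    return list(torch.split(embeddings, sizes)), list(torch.split(timestamps, sizes))
```

Each returned list has one entry per scale, matching the `embeddings_in_scales` and `timestamps_in_scales` shapes described in the docstring.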