[ASR] Add optimization util for linear sum assignment algorithm #6349
tango4j merged 37 commits into NVIDIA-NeMo:main from fix/clus_spk_util_jit
Conversation
nithinraok left a comment:
Minor review; I will do a thorough review tomorrow. Very neat improvement; I need to understand it better on my end.
laplacian = laplacian.float().to(device)
else:
    laplacian = laplacian.float().to(torch.device('cpu'))
    laplacian = laplacian.float()

def eigValueSh(laplacian: torch.Tensor, cuda: bool, device: torch.device = torch.device('cpu')) -> torch.Tensor:
def eigValueSh(laplacian: torch.Tensor, cuda: bool, device: torch.device) -> torch.Tensor:
why cuda and device? Isn't only one sufficient?
This was added a while back because there are users setting cuda=True while device=cpu. It adds some flexibility to avoid errors in such cases. If we need to remove this, it requires a separate PR, since it involves the whole diarization pipeline.
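The reconciliation described above can be sketched as a small helper. This is a hypothetical illustration (`resolve_device` is not the actual NeMo code): honor the `device` argument only when the `cuda` flag agrees with it and a GPU is actually available, otherwise fall back to CPU.

```python
import torch

def resolve_device(cuda: bool, device: torch.device) -> torch.device:
    # Hypothetical helper: when cuda=True but the passed device is CPU,
    # or when no GPU is available, quietly fall back to CPU instead of
    # raising an error.
    if cuda and device.type == 'cuda' and torch.cuda.is_available():
        return device
    return torch.device('cpu')

print(resolve_device(True, torch.device('cpu')))  # cpu, despite cuda=True
```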
laplacian = laplacian.float().to(torch.device('cpu'))
laplacian = laplacian.float()
Same here: laplacian.float() appears twice.
stacked = np.hstack((enc_P, enc_Q))
cost = -1 * linear_kernel(stacked.T)[spk_count:, :spk_count]
row_ind, col_ind = linear_sum_assignment(cost)
PandQ_list: List[int] = [int(x.item()) for x in PandQ]
minor: mentioning the dtype in a variable name should be avoided
Makes sense, since types are strictly annotated for jit-script functions. Fixed.
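The negated-kernel trick in the diff above can be illustrated end to end with scipy's solver, using a plain dot product in place of linear_kernel and toy one-hot embeddings (illustrative values, not NeMo code):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

spk_count = 3
enc_P = np.eye(8)[:, :spk_count]  # three distinct (one-hot) speaker embeddings
enc_Q = enc_P[:, [2, 0, 1]]       # same speakers, columns permuted

stacked = np.hstack((enc_P, enc_Q))
# Gram matrix of the stacked columns (a dot product standing in for
# linear_kernel); the [spk_count:, :spk_count] block holds the similarity
# of each Q column (row) to each P column (column). Negating similarity
# turns the maximization into a minimum-cost assignment.
cost = -1 * (stacked.T @ stacked)[spk_count:, :spk_count]
row_ind, col_ind = linear_sum_assignment(cost)
print(col_ind)  # [2 0 1]: each Q speaker is mapped back to its P speaker
```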
marked (Tensor): 2D matrix containing the marked zeros.
"""

def __init__(self, cost_matrix):
minor: mention the dtype of cost_matrix here. Isn't it necessary for jit scripting?
If there is no type annotation, the jit compiler assumes torch.Tensor; so in general, a type annotation is needed whenever an argument is not a torch.Tensor. Added type annotations.
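The annotation rule discussed here can be shown with a minimal scripted function (a toy example, not the NeMo solver): the Tensor argument needs no annotation, while the non-Tensor argument does.

```python
import torch

@torch.jit.script
def scaled_trace(mat: torch.Tensor, scale: float) -> torch.Tensor:
    # `mat` would default to torch.Tensor even without the annotation;
    # `scale` must be annotated because it is a float, not a Tensor.
    return torch.trace(mat) * scale

print(scaled_trace(torch.eye(3), 2.0))  # tensor(6.)
```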
if cost_matrix.shape[1] < cost_matrix.shape[0]:
    cost_matrix = cost_matrix.T
    transposed = True
else:
    transposed = False
why the extra transposed variable? Use the same col < row condition below?
This follows the original implementation in scipy. If we don't use the transposed variable, we would need another variable anyway to record that cost_matrix.shape[1] < cost_matrix.shape[0] held before the transpose.
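The role of the transposed flag can be sketched as follows (using scipy's solver for illustration; `solve_wide` is a hypothetical wrapper, not the PR's class): solve on the transposed matrix so that rows never outnumber columns, then swap the index arrays back.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def solve_wide(cost_matrix: np.ndarray):
    # Hungarian-style solvers often assume n_rows <= n_cols; when the
    # matrix is taller than wide, solve its transpose and swap the
    # resulting row/column indices back to the original orientation.
    transposed = cost_matrix.shape[1] < cost_matrix.shape[0]
    if transposed:
        cost_matrix = cost_matrix.T
    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    if transposed:
        row_ind, col_ind = col_ind, row_ind
    return row_ind, col_ind

# 3x2 matrix: only two of the three rows receive an assignment.
r, c = solve_wide(np.array([[4.0, 1.0], [2.0, 8.0], [3.0, 7.0]]))
print(sorted(zip(r, c)))  # [(0, 1), (1, 0)]
```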
# Copyright (c) 2008 Brian M. Clapper <bmc@clapper.org>, Gael Varoquaux
# Author: Brian M. Clapper, Gael Varoquaux
# License: 3-clause BSD
Do we have only one optimization algorithm so far? Thinking about whether we should move other funcs to this file as well.
I think we can add other algorithms below this (I mentioned the "Linear Sum Assignment solver"). Placing the copyright notice at the beginning of the code is the convention in most projects, so I followed it.
nemo/collections/asr/metrics/der.py
Outdated
for label in ref_labels:
    start, end, speaker = label.split()
    start, end = float(start), float(end)
    # If the current [start, end] interval is latching the last prediction time
Changed the expression (checked by Elena).
(NVIDIA-NeMo#6349)
* [ASR] Add optimization utils for cpWER, diarization training, online diarization
* Fixed GPU/CPU issues for clustering
* Fixed unreachable state
* Resolved jit script compile error for the LSA algorithm
* Fixed errors and bugs, checked tests
* Fixed docstrings
* Updated changes on test files
* Refactored functions
* Added docstrings for the functions in der.py
* Fixed wrong docstrings in der.py
* Changed np.array input to Tensor for the LSA solver in der.py
* Addressed code-QL issues and added unit tests for der.py functions
* Removed a print line in der.py
* Fixed code-QL redundant comparison
* Fixed a code-QL issue
* Added license for the reference code
* Added full license text of the original code
* Reflected review comments
Signed-off-by: Taejin Park <tango4j@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>
What does this PR do?
An LSA (linear sum assignment) problem solver is needed for the following tasks in NeMo:
(1) Permutation Invariant Loss (PIL) for diarization model training
(2) Label permutation matching for online speaker diarization
(3) Concatenated minimum-permutation Word Error Rate (cpWER) calculation
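All three tasks reduce to the same pattern: build a pairwise cost matrix and let the LSA solver pick the best speaker permutation. A sketch of task (1) follows; it is illustrative only (the helper name is hypothetical, and MSE stands in for the actual training loss), not the NeMo implementation.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def permutation_invariant_loss(preds: np.ndarray, refs: np.ndarray) -> float:
    """Compute a pairwise loss between every (predicted speaker, reference
    speaker) channel pair, then score the permutation with the lowest total."""
    n_spk = preds.shape[0]
    pair_loss = np.empty((n_spk, n_spk))
    for i in range(n_spk):
        for j in range(n_spk):
            pair_loss[i, j] = np.mean((preds[i] - refs[j]) ** 2)
    row_ind, col_ind = linear_sum_assignment(pair_loss)
    return float(pair_loss[row_ind, col_ind].mean())

# Predictions that match the reference up to a speaker swap incur zero loss.
refs = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
print(permutation_invariant_loss(refs[[1, 0]], refs))  # 0.0
```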
What is the LSA solver algorithm? See the Google OR-tools LSA Solver.
In the unit test for the NeMo LSA solver, the result of the NeMo linear_sum_assignment function is compared with the scipy version, scipy.optimize.linear_sum_assignment.
Removed the @torch.jit.script decorator in speaker_utils.py since it creates type errors when the code is not used for production purposes. Instead, all classes and functions that require torch.jit.script are tested in test_diar_utils.py. Take a look at these tests, which check jit_script = [True/False] and cuda = [True/False] (4 combinations in total).
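The property such a comparison test relies on can be shown with scipy alone: the solver's assignment achieves the minimum total cost, which a brute-force search over all permutations confirms (toy matrix, not the actual test data).

```python
import itertools
import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[4.0, 1.0, 3.0],
                 [2.0, 0.0, 5.0],
                 [3.0, 2.0, 2.0]])

row_ind, col_ind = linear_sum_assignment(cost)
lsa_total = cost[row_ind, col_ind].sum()

# Brute-force check over all 3! permutations, mirroring the idea of a
# unit test that compares two solver implementations on the same input.
brute_total = min(
    sum(cost[i, p[i]] for i in range(3))
    for p in itertools.permutations(range(3))
)
assert lsa_total == brute_total
print(lsa_total)  # 5.0
```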
Also refactored some of the functions for online diarization in online_clustering.py, and added a couple of functions in der.py for online diarization DER calculation.
Collection: [ASR]
Changelog
nemo/collections/asr/metrics/der.py:
- Replaced the scipy LSA solver with the NeMo LSA solver in the calculate_session_cpWER function.
- Added two functions for online diarization evaluation: get_partial_ref_labels and get_online_DER_stats.
nemo/collections/asr/models/online_diarizer.py:
- Simplified the _perform_online_clustering function by moving get_reduced_mat and match_labels into the online clustering function.
nemo/collections/asr/parts/utils/offline_clustering.py:
- Added laplacian = laplacian.float().to(torch.device('cpu')) to prevent the jit-scripted module from using the GPU when the CPU is specified, or vice versa. This behavior is always tested/checked in test_diar_utils.py.
nemo/collections/asr/parts/utils/online_clustering.py:
- Replaced the scipy LSA solver with the NeMo LSA solver in the get_lsa_speaker_mapping function.
- Modified the docstrings of update_speaker_history_buffer to make the example easier to follow.
nemo/collections/asr/parts/utils/optimization_utils.py:
- Added a fully torch-jit-scriptable linear sum assignment problem solver class and function.
nemo/collections/asr/parts/utils/speaker_utils.py:
- Removed @torch.jit.script decorators, since they create unnecessary warning messages and type-related errors when the code is used without scripting.
tests/collections/asr/test_diar_metrics.py:
- Added unit tests for the newly added functions get_partial_ref_labels and get_online_DER_stats.
tests/collections/asr/test_diar_utils.py:
- Added tests for offline clustering and online clustering, using the torch-jit-scripted NeMo linear_sum_assignment function and covering the cases [jit-script=True, cuda=True], [jit-script=True, cuda=False], [jit-script=False, cuda=True], and [jit-script=False, cuda=False].
Who can review?
Anyone in the community is free to review the PR once the checks have passed.