VoiceChat EA STT training reproducible features #15558
VoiceChat EA STT training reproducible features #15558 — ankitapasad wants to merge 2 commits into NVIDIA-NeMo:main from
Conversation
Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Ankita Pasad <apasad@nvidia.com>
…ization, clean-up token ID init, and corresponding tests Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Ankita Pasad <apasad@nvidia.com>
| import os | ||
|
|
||
| import pytest | ||
| import torch |
Check notice
Code scanning / CodeQL
Unused import Note test
| assert (target_tokens == eos).sum().item() == 0, "skip_eos=True should not place any EOS" | ||
|
|
||
| # Now collate source tokens, passing in the target channel for EOS placement | ||
| source_tokens, source_token_lens = collate_token_channel( |
Check notice
Code scanning / CodeQL
Unused local variable Note test
| assert (target_tokens == eos).sum().item() == 0, "skip_eos=True should not place any EOS" | ||
|
|
||
| # Now collate source tokens, passing in the target channel for EOS placement | ||
| source_tokens, source_token_lens = collate_token_channel( |
Check notice
Code scanning / CodeQL
Unused local variable Note test
| skip_eos=True, | ||
| ) | ||
|
|
||
| source_tokens, source_token_lens = collate_token_channel( |
Check notice
Code scanning / CodeQL
Unused local variable Note test
| skip_eos=True, | ||
| ) | ||
|
|
||
| source_tokens, source_token_lens = collate_token_channel( |
Check notice
Code scanning / CodeQL
Unused local variable Note test
|
|
||
| from nemo.collections.common.tokenizers import AutoTokenizer | ||
| from nemo.collections.speechlm2.data.duplex_stt_dataset import DuplexSTTDataset | ||
| from nemo.collections.speechlm2.data.utils import get_pad_id |
Check notice
Code scanning / CodeQL
Unused import Note test
| train_batch = train_ds[cuts] | ||
| val_batch = val_ds[cuts] | ||
|
|
||
| train_targets = train_batch["audio_data"]["target_tokens"] |
Check notice
Code scanning / CodeQL
Unused local variable Note test
|
|
||
| # Force aligner should be created but never called during validation | ||
| val_ds.force_aligner = MagicMock() | ||
| val_ds[cuts] |
Check notice
Code scanning / CodeQL
Statement has no effect Note test
| # Mock the force aligner to avoid loading wav2vec2 | ||
| train_ds.force_aligner = MagicMock() | ||
| train_ds.force_aligner.batch_force_align_user_audio.side_effect = lambda cuts, **kwargs: cuts | ||
| train_ds[cuts] |
Check notice
Code scanning / CodeQL
Statement has no effect Note test
| - is_mcq_cut_train / is_mcq_cut_val / is_asr_cut | ||
| """ | ||
|
|
||
| import pytest |
Check notice
Code scanning / CodeQL
Unused import Note test
What does this PR do?
Adds the following features to the dataset class to support VoiceChat EA STT training and fine-tuning.
Collection: speechlm2
Usage
# Add a code snippet demonstrating how to use this
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.