Skip to content

Commit

Permalink
move AED chunked infer script (#9367)
Browse files Browse the repository at this point in the history
Signed-off-by: stevehuang52 <heh@nvidia.com>
  • Loading branch information
stevehuang52 authored Jun 4, 2024
1 parent 63833cd commit df49143
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
8 changes: 6 additions & 2 deletions examples/asr/asr_chunked_inference/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Streaming / Buffered ASR
# Streaming / Buffered / Chunked ASR

Contained within this directory are scripts to perform streaming or buffered inference of audio files using CTC / Transducer ASR models.
Contained within this directory are scripts to perform streaming or buffered inference of audio files using CTC / Transducer ASR models, and chunked inference for MultitaskAED models (e.g., "nvidia/canary-1b").

## Difference between streaming and buffered ASR

Expand All @@ -9,3 +9,7 @@ While we primarily showcase the defaults of these models in buffering mode, note
If you reduce your chunk size, the latency for your first prediction is reduced, and the model appears to predict the text with shorter delay. On the other hand, since the amount of information in the chunk is reduced, it causes higher WER.

On the other hand, if you increase your chunk size, then the delay between spoken sentence and the transcription increases (this is buffered ASR). While the latency is increased, you are able to obtain more accurate transcripts since the model has more context to properly transcribe the text.

## Chunked Inference

For MultitaskAED models, we provide a script to perform chunked inference. This script will split the input audio into non-overlapping chunks and perform inference on each chunk. The script will then concatenate the results to provide the final transcript.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ class TranscriptionConfig:

# Chunked configs
chunk_len_in_secs: float = 40.0 # Chunk length in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet and FasConformer models and 4 for Conformer models.
model_stride: int = (
8 # Model downsampling factor, 8 for Citrinet and FasConformer models and 4 for Conformer models.
)

# Decoding strategy for MultitaskAED models
decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig()
Expand Down Expand Up @@ -209,7 +211,12 @@ def autocast(*args, **kwargs):
with autocast(dtype=amp_dtype):
with torch.no_grad():
hyps = get_buffered_pred_feat_multitaskAED(
frame_asr, model_cfg.preprocessor, model_stride_in_secs, asr_model.device, manifest, filepaths,
frame_asr,
model_cfg.preprocessor,
model_stride_in_secs,
asr_model.device,
manifest,
filepaths,
)

output_filename, pred_text_attr_name = write_transcription(
Expand Down

0 comments on commit df49143

Please sign in to comment.