# NeMo ASR to ONNX Conversion for Triton Inference Server

## NeMo ASR ONNX Export

This notebook walks step by step through splitting a pretrained NeMo ASR model (`nvidia/stt_en_fastconformer_ctc_large`) into three modules and exporting the acoustic model to ONNX:

- **Preprocessor**: converts the audio signal to a Mel spectrogram (kept in PyTorch)
- **ASR Acoustic Model**: generates logits from the Mel spectrogram (exported to ONNX)
- **CTC Decoder**: extracts text from the logits (kept in PyTorch due to current limitations)

These modules are designed to be deployed independently using Triton Inference Server. The preprocessor and CTC decoder will use the PyTorch backend in Triton.

In [3]:
# Install dependencies (uncomment if needed)
# !pip install git+https://github.com/NVIDIA/NeMo.git@v2.2.0rc3#egg=nemo_toolkit[asr]
# !pip install onnxruntime-gpu==1.19.0 soundfile psutil

In [4]:
import numpy as np
import torch
import soundfile as sf
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor

## Step 1: Download and Initialize ASR Model

In [7]:
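# Device: 'cuda' or 'cpu'
device = 'cuda'
model_name = "nvidia/stt_en_fastconformer_ctc_large"
# The checkpoint is a CTC-only model (the log below reports EncDecCTCModelBPE),
# so load it with the matching class and reuse `device` for map_location.
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    model_name=model_name, map_location=device
)
asr_model.eval();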

[NeMo I 2025-03-25 12:14:49 mixins:181] Tokenizer SentencePieceTokenizer initialized with 1024 tokens


[NeMo W 2025-03-25 12:14:50 modelPT:176] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: null
    sample_rate: 16000
    batch_size: 1
    shuffle: true
    num_workers: 8
    pin_memory: true
    use_start_end_token: false
    trim_silence: false
    max_duration: 20
    min_duration: 0.1
    is_tarred: false
    tarred_audio_filepaths: null
    shuffle_n: 2048
    bucketing_strategy: fully_randomized
    bucketing_batch_size: null
    
[NeMo W 2025-03-25 12:14:50 modelPT:183] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath: null
    sample_rate: 16000
    batch_size: 32
    shuffle: false
    num_workers: 8
    pin_m

[NeMo I 2025-03-25 12:14:50 features:305] PADDING: 0
[NeMo I 2025-03-25 12:14:52 save_restore_connector:275] Model EncDecCTCModelBPE was successfully restored from /root/.cache/huggingface/hub/models--nvidia--stt_en_fastconformer_ctc_large/snapshots/5a84a7a3bee8d9bd414c6719ddfea7bc723e3961/stt_en_fastconformer_ctc_large.nemo.


In [8]:
!wget https://dldata-public.s3.us-east-2.amazonaws.com/2086-149220-0033.wav

--2025-03-25 12:15:46--  https://dldata-public.s3.us-east-2.amazonaws.com/2086-149220-0033.wav
Resolving dldata-public.s3.us-east-2.amazonaws.com (dldata-public.s3.us-east-2.amazonaws.com)... 52.219.108.66, 52.219.176.138, 52.219.98.98, ...
Connecting to dldata-public.s3.us-east-2.amazonaws.com (dldata-public.s3.us-east-2.amazonaws.com)|52.219.108.66|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 237964 (232K) [audio/wav]
Saving to: ‘2086-149220-0033.wav’


2025-03-25 12:15:48 (185 KB/s) - ‘2086-149220-0033.wav’ saved [237964/237964]



## Step 2: Preprocessor Module
Due to current limitations, we keep the Preprocessor as a PyTorch module.

In [30]:
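# Standalone preprocessor with 80 Mel bins. Only `features` is set here; the
# remaining parameters fall back to NeMo defaults, which may differ from the
# model's own preprocessor config (asr_model.cfg.preprocessor).
preprocessor = AudioToMelSpectrogramPreprocessor(features=80)
preprocessor.to(device)

# Test preprocessor on a single clip (batch of size 1)
audio, sr = sf.read('2086-149220-0033.wav')
audio_array = np.array([audio])
audio_signal = torch.from_numpy(audio_array).to(device)
audio_signal_len = torch.tensor([audio_signal.shape[1]]).to(device)

processed_signal, processed_signal_length = preprocessor(
    input_signal=audio_signal, length=audio_signal_len
)
print(processed_signal.shape, processed_signal_length.shape)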

[NeMo I 2025-03-25 12:26:53 features:305] PADDING: 16
torch.Size([1, 80, 752]) torch.Size([1])
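
The time dimension of 752 can be sanity-checked by hand. The sketch below assumes a centered STFT with a hop of 160 samples (10 ms at 16 kHz) and padding of the frame count to a multiple of 16, which is consistent with the `PADDING: 16` log line above. The sample count of 118,960 is also an assumption, inferred from the 237,964-byte file size for 16-bit mono audio with a 44-byte header. `mel_frame_count` is a hypothetical helper, not a NeMo function.

```python
def mel_frame_count(num_samples, hop_length=160, pad_to=16):
    """Return (valid_frames, padded_frames) for a centered STFT.

    Assumptions: hop_length=160 (10 ms at 16 kHz) and pad_to=16.
    """
    # A centered STFT yields one frame per hop, plus one.
    valid = num_samples // hop_length + 1
    # Round the frame count up to the next multiple of pad_to.
    remainder = valid % pad_to
    padded = valid if remainder == 0 else valid + (pad_to - remainder)
    return valid, padded


valid, padded = mel_frame_count(118_960)
print(valid, padded)  # 744 752
```

Under these assumptions only the first 744 frames carry signal; the remaining 8 are padding, which is why the preprocessor also returns `processed_signal_length`.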


## Step 3: Export ASR Acoustic Model to ONNX

In [31]:
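class InferenceSTT(torch.nn.Module):
    """Thin wrapper that exposes only the export-friendly forward pass."""

    def __init__(self, model_inference):
        super().__init__()
        self.asr_model = model_inference

    def forward(self, processed_signal):
        # Takes Mel features of shape [batch, features, time] and returns the acoustic model output
        return self.asr_model.forward_for_export(processed_signal)

stt_module = InferenceSTT(asr_model)
stt_module.eval();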

In [32]:
with torch.no_grad():
    torch.onnx.export(
        stt_module,
        processed_signal,
        'model.onnx',
        export_params=True,
        input_names=["signal"],
        output_names=["output"],
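        dynamic_axes={
            # The input and output time axes have different lengths (the encoder
            # subsamples in time), so they get distinct symbolic names.
            "signal": {0: "batch_size", 2: "num_frames"},
            "output": {0: "batch_size", 1: "num_output_frames"},
        },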
    )
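
For the Triton deployment mentioned in the introduction, the exported `model.onnx` needs a model configuration whose tensor names match the `input_names` and `output_names` used above. The following is a minimal sketch that renders such a configuration as text. The model name, `max_batch_size`, and the FP32 data type are assumptions, and `num_classes=1025` assumes the 1024 tokenizer tokens plus one CTC blank; check the exported graph before relying on these values. `triton_onnx_config` is a hypothetical helper.

```python
def triton_onnx_config(model_name, n_mels, num_classes, max_batch_size=8):
    """Render a Triton config.pbtxt for the exported acoustic model.

    With max_batch_size > 0 Triton adds the batch dimension itself,
    so `dims` lists only the non-batch axes; -1 marks a dynamic axis.
    """
    return f"""name: "{model_name}"
platform: "onnxruntime_onnx"
max_batch_size: {max_batch_size}
input [
  {{
    name: "signal"
    data_type: TYPE_FP32
    dims: [ {n_mels}, -1 ]
  }}
]
output [
  {{
    name: "output"
    data_type: TYPE_FP32
    dims: [ -1, {num_classes} ]
  }}
]
"""


# Hypothetical model name; 1025 = 1024 tokens + 1 blank (assumption).
print(triton_onnx_config("asr_acoustic_model", n_mels=80, num_classes=1025))
```

The printed text would be saved as `config.pbtxt` next to a version directory containing `model.onnx`.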

## Step 4: CTC Decoder
Due to current limitations, we keep the CTC decoder as a PyTorch module.

In [36]:
asr_model.decoding.strategy = 'greedy_batch'
ctc_decoder = asr_model.decoding.ctc_decoder_predictions_tensor
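
To make the decoder's role concrete, here is a minimal pure-Python sketch of greedy CTC decoding: take the best class per frame, collapse consecutive repeats, and drop blanks. It illustrates the general technique only and is not NeMo's implementation; `ctc_greedy_decode` and the toy scores are hypothetical, and the blank index is passed in explicitly rather than assumed.

```python
def ctc_greedy_decode(frame_scores, blank_id):
    """Greedy CTC decoding over a list of per-frame score lists.

    Returns the token ids left after collapsing repeats and removing blanks.
    """
    tokens = []
    previous = None
    for scores in frame_scores:
        # Index of the highest-scoring class for this frame.
        best = max(range(len(scores)), key=scores.__getitem__)
        # Emit only when the class changes and is not the blank symbol.
        if best != previous and best != blank_id:
            tokens.append(best)
        previous = best
    return tokens


# Toy example: 3 classes, with class 2 acting as the blank.
frames = [
    [0.90, 0.05, 0.05],  # class 0
    [0.80, 0.10, 0.10],  # class 0 again (repeat, collapsed)
    [0.10, 0.10, 0.80],  # blank
    [0.70, 0.20, 0.10],  # class 0 (new emission after blank)
    [0.10, 0.80, 0.10],  # class 1
]
print(ctc_greedy_decode(frames, blank_id=2))  # [0, 0, 1]
```

The resulting token ids would then be mapped back to text by the tokenizer, which `ctc_decoder_predictions_tensor` handles in the notebook.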

## Step 5: ONNX Model Inference Test

In [49]:
import onnxruntime
import psutil

session_options = onnxruntime.SessionOptions()
session_options.intra_op_num_threads = psutil.cpu_count(logical=True)
session_options.log_severity_level = 1
providers = ["CUDAExecutionProvider"]  # Use ["CPUExecutionProvider"] if no GPU is available

ort_session = onnxruntime.InferenceSession(
    'model.onnx', sess_options=session_options, providers=providers
);

2025-03-25 12:33:29.550425027 [I:onnxruntime:, inference_session.cc:583 TraceSessionOptions] Session Options {  execution_mode:0 execution_order:DEFAULT enable_profiling:0 optimized_model_filepath:"" enable_mem_pattern:1 enable_mem_reuse:1 enable_cpu_mem_arena:1 profile_file_prefix:onnxruntime_profile_ session_logid: session_log_severity_level:1 session_log_verbosity_level:0 max_num_graph_transformation_steps:10 graph_optimization_level:3 intra_op_param:OrtThreadPoolParams { thread_pool_size: 12 auto_set_affinity: 0 allow_spinning: 1 dynamic_block_base_: 0 stack_size: 0 affinity_str:  set_denormal_as_zero: 0 } inter_op_param:OrtThreadPoolParams { thread_pool_size: 0 auto_set_affinity: 0 allow_spinning: 1 dynamic_block_base_: 0 stack_size: 0 affinity_str:  set_denormal_as_zero: 0 } use_per_session_threads:1 thread_pool_allow_spinning:1 use_deterministic_compute:0 config_options: {  } }
2025-03-25 12:33:29.550521408 [I:onnxruntime:, inference_session.cc:491 ConstructorCommon] Creating an

2025-03-25 12:33:30.047512114 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer BiasGeluFusion modified: 0 with status: OK
2025-03-25 12:33:30.048279810 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer SkipLayerNormFusion modified: 0 with status: OK
2025-03-25 12:33:30.049007836 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer FastGeluFusion modified: 0 with status: OK
2025-03-25 12:33:30.050179773 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer QuickGeluFusion modified: 1 with status: OK
2025-03-25 12:33:30.065821989 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer BiasSoftmaxFusion modified: 0 with status: OK
2025-03-25 12:33:30.066459221 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer BiasDropoutFusion modified: 0 with status: OK
2025-03-25 12:33:30.067121876 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer MatMulScaleFusion modified: 0 with status: OK
2025-03-25 12:33:30
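
Rather than hard-coding the provider list, it can be derived from what the installed runtime supports. The sketch below keeps the selection logic in a plain function so it runs without a GPU; `pick_providers` is a hypothetical helper, and in the notebook its argument would come from `onnxruntime.get_available_providers()`.

```python
def pick_providers(available,
                   preferred=("CUDAExecutionProvider", "CPUExecutionProvider")):
    """Return the preferred providers that are actually available, in order."""
    chosen = [name for name in preferred if name in available]
    if not chosen:
        raise RuntimeError(f"None of {preferred} found in {available}")
    return chosen


# In the notebook: available = onnxruntime.get_available_providers()
available = ["CPUExecutionProvider"]  # stand-in value for a CPU-only machine
print(pick_providers(available))  # ['CPUExecutionProvider']
```

Passing the result as `providers=` lets the same cell run on GPU and CPU-only machines.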

In [51]:
audio, sr = sf.read('2086-149220-0033.wav')

audio_array = np.array([audio])
audio_signal = torch.from_numpy(audio_array).to(device)
audio_signal_len = torch.tensor([audio_signal.shape[1]]).to(device)

processed_signal, processed_signal_length = preprocessor(input_signal = audio_signal, 
                                                         length= audio_signal_len)

output = ort_session.run(None, {"signal": processed_signal.cpu().numpy()})

pred_text = ctc_decoder(torch.from_numpy(output[0]))[0].text
print(pred_text)

well i don't wish to see it any more observed phoebe turning away her eyes it is certainly very like the old portrait


## Results

In [60]:
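# Reference transcript from the original PyTorch model, for comparison
text = asr_model.transcribe(['2086-149220-0033.wav'])[0].text
text, pred_text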

("well i don't wish to see it any more observed phoebe turning away her eyes it is certainly very like the old portrait",
 "well i don't wish to see it any more observed phoebe turning away her eyes it is certainly very like the old portrait")

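Comparing transcripts by eye does not scale beyond one clip. A common way to quantify agreement is the word error rate (WER): the word-level edit distance divided by the reference length. The following is a minimal sketch of that metric in plain Python; `word_error_rate` is a hypothetical helper, not part of the notebook or of NeMo.

```python
def word_error_rate(reference, hypothesis):
    """Word-level Levenshtein distance divided by the reference word count."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j] = edit distance between the processed reference prefix and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1] / max(len(ref), 1)


same = "it is certainly very like the old portrait"
print(word_error_rate(same, same))  # 0.0
```

For the two transcripts above, which are identical, this metric is 0.
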
## Performance Comparison

Performance measured using %%timeit:

- **Original PyTorch inference:**


In [None]:
%%timeit -n 10 -r 10
text = asr_model.transcribe(['2086-149220-0033.wav'])[0].text

### 96 ms ± 9.2 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)

- **ONNX optimized inference:**

In [None]:
%%timeit -n 10 -r 10
audio_array = np.array([audio])
audio_signal = torch.from_numpy(audio_array).to(device)
audio_signal_len = torch.tensor([audio_signal.shape[1]]).to(device)

processed_signal, processed_signal_length = preprocessor(input_signal = audio_signal, 
                                                         length= audio_signal_len)

output = ort_session.run(None, {"signal": processed_signal.cpu().numpy()})

pred_text = ctc_decoder(torch.from_numpy(output[0]))[0].text

### 16.7 ms ± 1.41 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)

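On this single clip, the ONNX pipeline runs roughly **6x faster** than the original PyTorch inference (16.7 ms vs. 96 ms per loop). The two measurements are not strictly equivalent: `asr_model.transcribe` also reads the audio file from disk, while the ONNX cell reuses the already loaded `audio` array.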
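
The speed-up figure follows directly from the two mean timings reported above. A small sketch of the arithmetic, where `speedup` is a hypothetical helper:

```python
def speedup(baseline_ms, optimized_ms):
    """Ratio of baseline latency to optimized latency."""
    return baseline_ms / optimized_ms


pytorch_ms = 96.0  # mean per loop, original PyTorch inference
onnx_ms = 16.7     # mean per loop, ONNX pipeline
print(f"{speedup(pytorch_ms, onnx_ms):.2f}x")  # 5.75x
```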
