# NeMo ASR to ONNX Conversion for Triton Inference Server

## NeMo ASR ONNX Export

This notebook walks step by step through splitting a pretrained NeMo ASR model (`nvidia/stt_en_fastconformer_ctc_large`) into three modules and exporting the acoustic model to ONNX:

- **Preprocessor**: converts audio signal to Mel Spectrogram (kept in PyTorch due to limitations)
- **ASR Acoustic Model**: generates logits from Mel Spectrogram
- **CTC Decoder**: extracts text from logits (kept in PyTorch due to limitations)

These modules are designed to be deployed independently on Triton Inference Server. The preprocessor and the CTC decoder will use the PyTorch backend in Triton.

In [None]:
# Install dependencies (uncomment if needed)
# !pip install git+https://github.com/NVIDIA/NeMo.git@v2.2.0rc3#egg=nemo_toolkit[asr]
# !pip install onnxruntime-gpu==1.19.0 soundfile psutil

In [None]:
import numpy as np
import torch
import soundfile as sf
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor

## Step 1: Download and Initialize ASR Model

In [None]:
# Device: 'cuda' or 'cpu'
device = 'cuda'
model_name = "nvidia/stt_en_fastconformer_ctc_large"
# This checkpoint is a CTC model, so it is loaded with the CTC BPE model class.
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    model_name=model_name, map_location=device
)
asr_model.eval();

In [None]:
!wget https://dldata-public.s3.us-east-2.amazonaws.com/2086-149220-0033.wav

## Step 2: Preprocessor Module
Due to current limitations, we keep the Preprocessor as a PyTorch module.

In [None]:
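# Build a standalone preprocessor with 80 mel features.
# asr_model.preprocessor holds the settings the checkpoint was trained with;
# compare against it if transcripts differ from asr_model.transcribe().
preprocessor = AudioToMelSpectrogramPreprocessor(features=80)
preprocessor.to(device)
preprocessor.eval()  # inference mode: disables training-only steps such as dithering

# Test preprocessor on the sample file
audio, sr = sf.read('2086-149220-0033.wav')
audio_array = np.array([audio])  # add a batch dimension: [1, num_samples]
audio_signal = torch.from_numpy(audio_array).to(device)
audio_signal_len = torch.tensor([audio_signal.shape[1]]).to(device)

processed_signal, processed_signal_length = preprocessor(
    input_signal=audio_signal, length=audio_signal_len
)
print(processed_signal.shape, processed_signal_length.shape)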
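
As a sanity check on the printed shapes, the number of mel frames can be estimated from the audio length. The sketch below assumes a 16 kHz sample rate, a 10 ms window stride, and a centered STFT; these values are assumptions, not read from the model configuration, and the function name is hypothetical. The returned feature tensor may additionally be padded along the time axis, depending on the preprocessor's `pad_to` setting.

```python
def mel_frame_count(num_samples: int, sample_rate: int = 16000,
                    window_stride_s: float = 0.01) -> int:
    """Approximate number of spectrogram frames for a centered STFT.

    Assumed values: 16 kHz audio and a 10 ms hop (160 samples).
    """
    hop_length = round(sample_rate * window_stride_s)
    # A centered STFT emits one frame per hop, plus one for the first window.
    return num_samples // hop_length + 1


# One second of 16 kHz audio
print(mel_frame_count(16000))  # 101
```

Comparing this estimate with `processed_signal_length` is a quick way to confirm that the audio was read at the expected sample rate.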

## Step 3: Export ASR Acoustic Model to ONNX

In [None]:
class InferenceSTTEn(torch.nn.Module):
    def __init__(self, model_inference):
        super().__init__()
        self.asr_model = model_inference

    def forward(self, processed_signal):
        return self.asr_model.forward_for_export(processed_signal)

stt_module = InferenceSTTEn(asr_model)
stt_module.eval();

In [None]:
with torch.no_grad():
    torch.onnx.export(
        stt_module,
        processed_signal,
        'model.onnx',
        export_params=True,
        input_names=["signal"],
        output_names=["output"],
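        dynamic_axes={
            # Input is [batch, features, frames]; output is [batch, time_steps, classes].
            # The encoder subsamples in time, so the two time axes get distinct names.
            "signal": {0: "batch_size", 2: "num_frames"},
            "output": {0: "batch_size", 1: "num_time_steps"},
        },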
    )
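
Because the batch and time axes are exported as dynamic, the ONNX model can accept several utterances of different lengths in one call, provided they are padded to a common length first. The helper below is a minimal sketch of that padding step with NumPy; `pad_feature_batch` is a hypothetical name. Note that the exported graph receives no length input, so padded frames are processed like real ones: results for the shorter utterances may differ from unpadded inference, and their outputs should be trimmed afterwards.

```python
import numpy as np


def pad_feature_batch(features):
    """Stack [n_mels, T_i] arrays into one zero-padded [batch, n_mels, T_max] array."""
    lengths = np.array([f.shape[1] for f in features], dtype=np.int64)
    n_mels = features[0].shape[0]
    batch = np.zeros((len(features), n_mels, int(lengths.max())), dtype=np.float32)
    for i, f in enumerate(features):
        batch[i, :, : f.shape[1]] = f  # copy real frames, leave the tail as zeros
    return batch, lengths


# Two dummy feature matrices with 80 mel bins and different frame counts
feats = [np.ones((80, 120), dtype=np.float32), np.ones((80, 75), dtype=np.float32)]
batch, lengths = pad_feature_batch(feats)
print(batch.shape, lengths.tolist())  # (2, 80, 120) [120, 75]
```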

## Step 4: CTC Decoder
Due to current limitations, we keep the CTC decoder as a PyTorch module.

In [None]:
asr_model.decoding.strategy = 'greedy_batch'
ctc_decoder = asr_model.decoding.ctc_decoder_predictions_tensor
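
For intuition about what the decoder does with the logits, greedy CTC decoding can be sketched in a few lines: take the best label per frame, collapse consecutive repeats, and drop the blank label. This is a toy illustration with a hypothetical character vocabulary and an assumed blank index at the end of the vocabulary; it is not NeMo's implementation, which works on BPE tokens.

```python
def greedy_ctc_decode(log_probs, vocabulary, blank_id):
    """Greedy CTC decoding: per-frame argmax, collapse repeats, remove blanks."""
    best_ids = [max(range(len(frame)), key=frame.__getitem__) for frame in log_probs]
    tokens = []
    previous = None
    for token_id in best_ids:
        # Emit a token only when it differs from the previous frame and is not blank.
        if token_id != previous and token_id != blank_id:
            tokens.append(vocabulary[token_id])
        previous = token_id
    return "".join(tokens)


vocab = ["a", "b", "c"]  # hypothetical vocabulary; index 3 is the blank label
frames = [
    [0.7, 0.1, 0.1, 0.1],  # a
    [0.6, 0.2, 0.1, 0.1],  # a (repeat, collapsed)
    [0.1, 0.1, 0.1, 0.7],  # blank
    [0.7, 0.1, 0.1, 0.1],  # a (emitted again because a blank came in between)
    [0.1, 0.7, 0.1, 0.1],  # b
]
print(greedy_ctc_decode(frames, vocab, blank_id=3))  # aab
```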

## Step 5: ONNX Model Inference Test

In [None]:
import onnxruntime
import psutil

session_options = onnxruntime.SessionOptions()
session_options.intra_op_num_threads = psutil.cpu_count(logical=True)
session_options.log_severity_level = 1
# Providers are tried in order; CPU is the fallback when CUDA is unavailable.
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

ort_session = onnxruntime.InferenceSession(
    'model.onnx', sess_options=session_options, providers=providers
)
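
ONNX Runtime exposes `onnxruntime.get_available_providers()` for the installed build and `ort_session.get_providers()` for the providers a session actually uses. The sketch below shows one way to choose providers from such a list; `pick_providers` is a hypothetical helper, written as a pure function so it runs without a GPU.

```python
def pick_providers(available,
                   preferred=("CUDAExecutionProvider", "CPUExecutionProvider")):
    """Return the preferred providers that are available, in preference order."""
    chosen = [name for name in preferred if name in available]
    if not chosen:
        raise RuntimeError(f"None of {list(preferred)} found in {list(available)}")
    return chosen


# On a CPU-only installation, only the CPU provider is selected.
print(pick_providers(["CPUExecutionProvider"]))  # ['CPUExecutionProvider']
```

In the notebook, the argument would be `onnxruntime.get_available_providers()`, and the result can be passed as `providers` when creating the session.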

In [None]:
# Pipeline with ONNX model
audio, sr = sf.read('2086-149220-0033.wav')

audio_array = np.array([audio])
audio_signal = torch.from_numpy(audio_array).to(device)
audio_signal_len = torch.tensor([audio_signal.shape[1]]).to(device)

processed_signal, processed_signal_length = preprocessor(
    input_signal=audio_signal, length=audio_signal_len
)

output = ort_session.run(None, {"signal": processed_signal.cpu().numpy()})

pred_text = ctc_decoder(torch.from_numpy(output[0]))[0].text
print(pred_text)

In [None]:
text = asr_model.transcribe(['2086-149220-0033.wav'])[0].text

## Results

In [None]:
print(f"NeMo text: {text}")
print(f"ONNX text: {pred_text}")
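
Beyond comparing the two transcripts by eye, the difference can be quantified with a word error rate. The function below is a minimal sketch using a word-level edit distance; the example strings are placeholders, not output from the model.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    prev = list(range(len(hyp) + 1))  # distances against an empty reference
    for i, ref_word in enumerate(ref, start=1):
        cur = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cur.append(min(
                prev[j] + 1,                           # deletion
                cur[j - 1] + 1,                        # insertion
                prev[j - 1] + (ref_word != hyp_word),  # substitution or match
            ))
        prev = cur
    return prev[-1] / max(len(ref), 1)


# Placeholder strings; in the notebook, pass `text` and `pred_text` instead.
print(word_error_rate("the cat sat", "the cat sat"))  # 0.0
```

A value of 0.0 means the ONNX pipeline reproduces the NeMo transcript word for word.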

## Performance Comparison

Performance was measured with `%%timeit` (10 runs of 10 loops each):

- **Original PyTorch inference:**


In [None]:
%%timeit -n 10 -r 10
text = asr_model.transcribe(['2086-149220-0033.wav'])[0].text

Result: 96 ms ± 9.2 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)

- **ONNX optimized inference:**

In [None]:
%%timeit -n 10 -r 10
audio_array = np.array([audio])
audio_signal = torch.from_numpy(audio_array).to(device)
audio_signal_len = torch.tensor([audio_signal.shape[1]] * audio_array.shape[0]).to(device)

processed_signal, processed_signal_length = preprocessor(
    input_signal=audio_signal, length=audio_signal_len
)

output = ort_session.run(None, {"signal": processed_signal.cpu().numpy()})

pred_text = ctc_decoder(torch.from_numpy(output[0]), decoder_lengths=None)
pred_text = [i.text for i in pred_text]

Result: 16.7 ms ± 1.41 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)

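In this measurement, the ONNX pipeline runs in 16.7 ms per loop versus 96 ms for the original PyTorch inference, a speed-up of roughly **6x**. Note that `asr_model.transcribe` reads the audio file from disk on each call, while the ONNX pipeline reuses the audio already loaded in memory, so the two timings do not cover exactly the same work.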
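
Outside a notebook, a similar measurement can be taken with the standard library. The sketch below mirrors the `-n 10 -r 10` settings used above with a hypothetical `time_callable` helper and a stand-in workload, then computes the speed-up implied by the two reported timings.

```python
import statistics
import time


def time_callable(fn, loops=10, repeats=10):
    """Return (mean, stdev) of seconds per call over `repeats` runs of `loops` calls."""
    per_call = []
    for _ in range(repeats):
        start = time.perf_counter()
        for _ in range(loops):
            fn()
        per_call.append((time.perf_counter() - start) / loops)
    return statistics.mean(per_call), statistics.stdev(per_call)


# Stand-in workload; replace the lambda with the pipeline under test.
mean_s, std_s = time_callable(lambda: sum(range(1000)))
print(f"{mean_s * 1e3:.3f} ms ± {std_s * 1e3:.3f} ms per loop")

# Speed-up implied by the two timings reported above
print(f"speed-up: {96 / 16.7:.2f}x")  # speed-up: 5.75x
```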

## Step 6: Packaging for Triton

For Triton deployment, follow these resources:

- [Triton Inference Server Documentation](https://github.com/triton-inference-server/server)
- [Convert ONNX model to TensorRT](NeMo_convert_ONNX_to_TensorRT.md)
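
As a starting point for packaging, Triton expects a model repository with one directory per model, a `config.pbtxt`, and numbered version subdirectories. The sketch below writes such a layout for the exported acoustic model. The model name, `max_batch_size`, and `num_classes` value are placeholders chosen for illustration; the real class count should be read from the ONNX output (for example `output[0].shape[-1]`), and the configuration should be checked against the Triton documentation before use.

```python
import shutil
import tempfile
from pathlib import Path

CONFIG_TEMPLATE = """\
name: "{name}"
platform: "onnxruntime_onnx"
max_batch_size: {max_batch_size}
input [
  {{
    name: "signal"
    data_type: TYPE_FP32
    dims: [ 80, -1 ]
  }}
]
output [
  {{
    name: "output"
    data_type: TYPE_FP32
    dims: [ -1, {num_classes} ]
  }}
]
"""


def write_model_repository(repo_root, num_classes, onnx_path=None,
                           name="asr_acoustic_model", max_batch_size=8):
    """Create <repo_root>/<name>/config.pbtxt and <repo_root>/<name>/1/."""
    model_dir = Path(repo_root) / name
    version_dir = model_dir / "1"
    version_dir.mkdir(parents=True, exist_ok=True)
    config = CONFIG_TEMPLATE.format(
        name=name, max_batch_size=max_batch_size, num_classes=num_classes
    )
    (model_dir / "config.pbtxt").write_text(config)
    # Copy the exported model only if a path was given and the file exists.
    if onnx_path is not None and Path(onnx_path).exists():
        shutil.copy(onnx_path, version_dir / "model.onnx")
    return model_dir


with tempfile.TemporaryDirectory() as tmp:
    # num_classes=1025 is a placeholder, not a value taken from the model.
    model_dir = write_model_repository(tmp, num_classes=1025)
    print(sorted(p.name for p in model_dir.iterdir()))  # ['1', 'config.pbtxt']
```

The input and output names match those passed to `torch.onnx.export` above, and the batch dimension is omitted from `dims` because `max_batch_size` is greater than zero.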


## Next Steps
- Deploy modules to Triton.
- Benchmark performance.
- Integrate modules into your application pipeline.