# ONNX model session tests (quantized models)

This notebook tests the individual ONNX Runtime sessions (preprocessor, encoder, decoder_joint) with small synthetic inputs, running each inference in a subprocess with a timeout so that hangs can be detected. Edit the model paths below to point to your (quantized) preprocessor, encoder and decoder_joint ONNX files.

Workflow:
1. Print available providers
2. Inspect model inputs/outputs
3. Build small dummy inputs (replace dynamic dims with small constants)
4. Run each model in a subprocess with a timeout and print its status (ok / error / timeout / no_result)

In [1]:
# Imports and helpers
import time
import numpy as np
import onnxruntime
from pathlib import Path
from multiprocessing import Process, Queue

def available_providers():
    """Return the execution provider names that onnxruntime reports as available."""
    provider_names = onnxruntime.get_available_providers()
    return provider_names

def _small_constant_for_dim(name, idx):
    # heuristics for picking small sizes for dynamic dims
    lname = name.lower()
    if 'length' in lname or 'len' in lname:
        return 1
    if 'signal' in lname or 'audio' in lname or 'samples' in lname:
        return 16000  # one second at 16k
    # common feature dims
    if idx == 1:
        return 257
    # time axis fallback
    return 10

def build_dummy_inputs_from_session(sess):
    inputs = {}
    for inp in sess.get_inputs():
        shape = inp.shape
        name = inp.name
        arr_shape = []
        for i, dim in enumerate(shape):
            if dim is None:
                arr_shape.append(_small_constant_for_dim(name, i))
            else:
                arr_shape.append(int(dim))
        dtype = np.float32 if 'float' in inp.type.lower() or 'tensor(float' in inp.type.lower() else np.int64
        if dtype == np.float32:
            inputs[name] = np.random.randn(*arr_shape).astype(np.float32)
        else:
            # lengths/indices
            inputs[name] = np.ones(arr_shape, dtype=np.int64)
    return inputs

def _session_worker(q, model_path, input_dict, providers, use_serialized):
    """Child-process body for ``run_model_in_subprocess``.

    Creates a session, runs it once on ``input_dict`` and puts exactly one
    ``(tag, payload)`` tuple on ``q``: ``('ok', {...})`` or ``('err', repr)``.

    Defined at module level rather than nested so the target can be pickled
    under the 'spawn'/'forkserver' start methods. NOTE(review): under those
    methods the child must also be able to import this function, which may
    require moving it into a .py module when running from a notebook.
    """
    try:
        import onnxruntime
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        if use_serialized:
            # onnx is only needed on this path; a lazy import keeps the default
            # path usable without the onnx package installed.
            import onnx
            model = onnx.load(model_path)
            sess = onnxruntime.InferenceSession(model.SerializeToString(), providers=providers, sess_options=sess_options)
        else:
            sess = onnxruntime.InferenceSession(str(model_path), providers=providers, sess_options=sess_options)
        start = time.perf_counter()
        out = sess.run(None, input_dict)
        duration = time.perf_counter() - start
        out_meta = []
        for i, o in enumerate(out):
            if hasattr(o, 'shape') and hasattr(o, 'dtype'):
                out_meta.append({'idx': i, 'shape': o.shape, 'dtype': str(o.dtype)})
            else:
                # Non-array outputs (e.g. lists/dicts): report length and type only.
                shape = (len(o),) if hasattr(o, '__len__') else ()
                out_meta.append({'idx': i, 'shape': shape, 'dtype': str(type(o))})
        q.put(('ok', {'duration': duration, 'out_meta': out_meta}))
    except Exception as e:
        q.put(('err', repr(e)))


def run_model_in_subprocess(model_path, input_dict, providers=None, timeout=10, use_serialized=False):
    """Load and run an ONNX model in a child process so hangs can be detected.

    Args:
        model_path: Path (str or Path) of the ONNX model file.
        input_dict: Feed dict for ``session.run`` (input name -> array). It must
            be picklable when a non-fork start method is used.
        providers: Execution providers for the session. ``None`` leaves the
            choice to onnxruntime.
        timeout: Seconds to wait for a result before terminating the child.
            ``None`` waits indefinitely.
        use_serialized: If true, load the model with ``onnx.load`` and build the
            session from its serialized bytes instead of the file path.

    Returns:
        One of:
        ``{'status': 'ok', 'result': {'duration': ..., 'out_meta': [...]}}``,
        ``{'status': 'error', 'error': <repr of the exception>}``,
        ``{'status': 'timeout'}`` (the child was terminated), or
        ``{'status': 'no_result', 'exitcode': ...}`` when the child exited
        without reporting (e.g. it crashed).
    """
    import queue  # only needed for queue.Empty

    poll_s = 0.1
    q = Queue()
    p = Process(target=_session_worker, args=(q, str(model_path), input_dict, providers, use_serialized))
    p.start()
    deadline = None if timeout is None else time.monotonic() + timeout

    # Read the result *before* joining. Per the multiprocessing guidelines, a
    # child that has put items on a Queue may not terminate until they are
    # consumed, so joining first can turn a finished run into a false timeout.
    msg = None
    while msg is None:
        try:
            msg = q.get(timeout=poll_s)
        except queue.Empty:
            if not p.is_alive():
                break  # child exited before we saw a message
            if deadline is not None and time.monotonic() >= deadline:
                break
    if msg is None:
        # The child may have flushed its message just before exiting.
        try:
            msg = q.get(timeout=poll_s)
        except queue.Empty:
            pass

    if msg is None and p.is_alive():
        p.terminate()
        p.join()  # reap the terminated child rather than leaving a zombie
        q.close()
        return {'status': 'timeout'}

    p.join(timeout)
    if p.is_alive():
        # The result arrived but the child is stuck while shutting down.
        p.terminate()
        p.join()
    q.close()

    if msg is None:
        return {'status': 'no_result', 'exitcode': p.exitcode}
    status, payload = msg
    if status == 'ok':
        return {'status': 'ok', 'result': payload}
    return {'status': 'error', 'error': payload}

In [2]:
# Model file locations: edit these to point at your own ONNX exports.
BASE = Path('.')
encoder_path = Path(BASE, 'quantized_models', 'encoder-stt_de_fastconformer_hybrid_large_pc_qint8_not_per_channel.onnx')
decoder_joint_path = Path(BASE, 'models', 'decoder_joint-stt_de_fastconformer_hybrid_large_pc.onnx')
preprocessor_path = Path(BASE, 'models', 'preprocessor-stt_de_fastconformer_hybrid_large_pc.onnx')

# Show what onnxruntime can run on and whether each model file is present.
print('Available providers:', available_providers())
for _exists_label, _model_file in (
    ('Preprocessor exists?', preprocessor_path),
    ('Encoder exists?', encoder_path),
    ('Decoder-joint exists?', decoder_joint_path),
):
    print(_exists_label, _model_file.exists())

Available providers: ['AzureExecutionProvider', 'CPUExecutionProvider']
Preprocessor exists? True
Encoder exists? True
Decoder-joint exists? True


In [3]:
# Test preprocessor model: create a session, print its I/O signature, build
# dummy inputs, then run it in a subprocess with a timeout. `stage` records
# which step failed, so e.g. a dummy-input error is not reported as an
# inspection failure.
if preprocessor_path.exists():
    stage = 'create'
    try:
        sess = onnxruntime.InferenceSession(str(preprocessor_path))
        stage = 'inspect'
        print('Preprocessor inputs:')
        for inp in sess.get_inputs():
            print(' ', inp.name, inp.shape, inp.type)
        print('Preprocessor outputs:')
        for out in sess.get_outputs():
            print(' ', out.name, out.shape, out.type)

        stage = 'build dummy inputs for'
        dummy = build_dummy_inputs_from_session(sess)
        stage = 'run'
        print('Running preprocessor in subprocess with timeout=10s...')
        res = run_model_in_subprocess(preprocessor_path, dummy, providers=None, timeout=10)
        print('Preprocessor result:', res)
    except Exception as e:
        print(f'Failed to {stage} preprocessor session:', repr(e))
else:
    print('Preprocessor model file not found; update preprocessor_path')

Preprocessor inputs:
  input_signal ['input_signal_dynamic_axes_1', 'input_signal_dynamic_axes_2'] tensor(float)
  length ['length_dynamic_axes_1'] tensor(int64)
Preprocessor outputs:
  processed_signal ['Castprocessed_signal_dim_0', 80, 'Castprocessed_signal_dim_2'] tensor(float)
  processed_length ['length_dynamic_axes_1'] tensor(int64)
Failed to inspect preprocessor session: ValueError("invalid literal for int() with base 10: 'input_signal_dynamic_axes_1'")


In [4]:
# Test encoder model standalone with dummy inputs (the preprocessor output is not
# fed in here). `stage` records which step failed so errors are labelled
# accurately.
# NOTE(review): a NOT_IMPLEMENTED error for a ConvInteger node (seen with the
# qint8 encoder) means the selected provider has no kernel for that quantized
# conv; re-quantizing with Conv excluded or with uint8 weights may help — verify.
if encoder_path.exists():
    stage = 'create'
    try:
        sess_enc = onnxruntime.InferenceSession(str(encoder_path))
        stage = 'inspect'
        print('Encoder inputs:')
        for inp in sess_enc.get_inputs():
            print(' ', inp.name, inp.shape, inp.type)
        print('Encoder outputs:')
        for out in sess_enc.get_outputs():
            print(' ', out.name, out.shape, out.type)

        stage = 'build dummy inputs for'
        enc_dummy = build_dummy_inputs_from_session(sess_enc)
        stage = 'run'
        print('Running encoder in subprocess with timeout=10s...')
        res_enc = run_model_in_subprocess(encoder_path, enc_dummy, providers=None, timeout=10)
        print('Encoder result:', res_enc)
    except Exception as e:
        print(f'Failed to {stage} encoder session:', repr(e))
else:
    print('Encoder model file not found; update encoder_path')

Failed to inspect encoder session: NotImplemented("[ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for ConvInteger(10) node with name '/pre_encode/conv/conv.0/Conv_quant'")


In [None]:
# Test decoder_joint model (single-step): create a session, print its I/O
# signature, build dummy inputs, then run it in a subprocess with a timeout.
# `stage` records which step failed so errors are labelled accurately.
if decoder_joint_path.exists():
    stage = 'create'
    try:
        sess_dec = onnxruntime.InferenceSession(str(decoder_joint_path))
        stage = 'inspect'
        print('Decoder-joint inputs:')
        for inp in sess_dec.get_inputs():
            print(' ', inp.name, inp.shape, inp.type)
        print('Decoder-joint outputs:')
        for out in sess_dec.get_outputs():
            print(' ', out.name, out.shape, out.type)

        stage = 'build dummy inputs for'
        # build_dummy_inputs_from_session already creates one entry per
        # declared input (targets, lengths, states, ...) under its exact name.
        # Do not add extra keys: onnxruntime rejects feed names that the
        # model does not declare.
        dec_dummy = build_dummy_inputs_from_session(sess_dec)
        stage = 'run'
        print('Running decoder-joint in subprocess with timeout=10s...')
        res_dec = run_model_in_subprocess(decoder_joint_path, dec_dummy, providers=None, timeout=10)
        print('Decoder-joint result:', res_dec)
    except Exception as e:
        print(f'Failed to {stage} decoder-joint session:', repr(e))
else:
    print('Decoder-joint model file not found; update decoder_joint_path')

Next steps / troubleshooting tips:
- If a model times out, re-run it with explicit providers: CPU first (`['CPUExecutionProvider']`) and then, if available, GPU providers.
- Try changing the session options used by the subprocess worker: lower `graph_optimization_level` (e.g. `ORT_DISABLE_ALL` to disable optimizations) to check whether optimizations around the quantized ops cause the problem.
- If a subprocess times out without reporting an exception, run the same `session.run` call in a fresh Python script (outside the notebook) to reproduce it.
- If the preprocessor runs but the encoder hangs, feed the actual preprocessor outputs to the encoder to confirm the exact input shapes.
- Enable ORT profiling in a separate run: set `sess_options.enable_profiling = True` before creating the session to capture runtime traces; `sess.end_profiling()` returns the path of the written trace file.