In [1]:
import torch
print(torch.__version__)

import torchaudio
import torchaudio.transforms as T

import torch.nn as nn
import torch.nn.functional as F

import numpy as np

import onnx
import onnxruntime as ort



2.6.0+cu124


In [2]:
# Input files, resolved relative to the notebook's working directory.
test_wave_file = "test00.wav"        # audio clip used for the inference smoke test
onnx_path = "ReDimNet_no_mel.onnx"   # exported model; it is fed precomputed log-mel features below

## Helper functions

In [3]:
def waveform_to_logmel(
    waveform: torch.Tensor,
    sample_rate=16000,
    n_fft=512,
    hop_length=160,
    n_mels=60,         # match whatever your model expects
    f_min=20.0,
    f_max=8000.0,
    preemphasis_alpha=0.97
):
    """Compute log-mel features from a raw waveform.

    Steps: peak-normalize -> pre-emphasis -> mel power spectrogram
    (``center=False``, torchaudio's default Hann window of length ``n_fft``)
    -> natural log.

    NOTE(review): these settings must mirror the front-end that was stripped
    from the exported ONNX model (window type, ``win_length``, any
    per-utterance mean normalization). They are not verified against the
    model here; TODO confirm.

    Args:
        waveform: audio with time on the last axis, e.g. ``[1, time]`` (mono,
            as produced by the loading code) or ``[time]``. Peak normalization
            uses the maximum over the whole tensor.
        sample_rate: sampling rate of ``waveform`` in Hz. ``n_fft`` and
            ``hop_length`` are counted in samples, so the frame timing depends
            on this rate.
        n_fft: FFT size (and window length) in samples.
        hop_length: hop between consecutive frames in samples.
        n_mels: number of mel bands.
        f_min: lowest mel filter edge in Hz.
        f_max: highest mel filter edge in Hz. It is clipped to the Nyquist
            frequency (``sample_rate / 2``). ``None`` lets torchaudio choose
            its default.
        preemphasis_alpha: coefficient ``a`` in ``y[t] = x[t] - a * x[t-1]``.

    Returns:
        Log-mel tensor of shape ``[..., n_mels, frames]``. This is
        ``[1, n_mels, frames]`` for a ``[1, time]`` input.
    """
    # Filters lying above Nyquist would be all zero. Their log would then
    # collapse to the constant log(1e-6), so clip the filterbank range.
    if f_max is not None:
        f_max = min(f_max, sample_rate / 2)

    # 1) Peak-normalize; the epsilon avoids division by zero on silent input.
    waveform = waveform / (waveform.abs().max() + 1e-8)

    # 2) Pre-emphasis along the time (last) axis. The first sample has no
    #    predecessor and is passed through unchanged.
    waveform_preemph = torch.cat(
        [
            waveform[..., :1],
            waveform[..., 1:] - preemphasis_alpha * waveform[..., :-1],
        ],
        dim=-1,
    )

    # 3) Mel power spectrogram. center=False means no edge padding, so the
    #    input must contain at least n_fft samples.
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
        power=2.0,
        center=False
    ).to(waveform.device)  # keep the window/filterbank buffers on the input's device
    mel_spec = mel_transform(waveform_preemph)

    # 4) Log compression; 1e-6 keeps the log finite for empty bins.
    log_mel = torch.log(mel_spec + 1e-6)
    return log_mel  # shape: [..., n_mels, frames]

## Load model

In [4]:
# Parse the exported graph and run ONNX's structural validator on it.
# check_model raises if the model is malformed, so reaching the print means it passed.
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("Loaded and checked ONNX model from:", onnx_path)

Loaded and checked ONNX model from: ReDimNet_no_mel.onnx


## Inference Test

In [5]:

# Create an ONNX Runtime inference session for the exported model.
session = ort.InferenceSession(onnx_path)

# Feed the model's first input and read back its first output.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

#######################################
# 2) Load audio, get log-mel
#######################################
# waveform_to_logmel counts n_fft / hop_length in samples and defaults to
# sample_rate=16000. Audio at any other rate must be resampled first.
# Otherwise the frames would cover a different time span than the model
# presumably saw at export time (e.g. ~2.8x shorter hops for 44.1 kHz).
target_sr = 16000
waveform, sr = torchaudio.load(test_wave_file)
# If multi-channel, downmix to mono: [channels, time] -> [1, time]
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)
if sr != target_sr:
    waveform = T.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
    sr = target_sr

log_mel = waveform_to_logmel(waveform, sample_rate=sr)
# Insert a batch dimension: [1, n_mels, frames] -> [1, 1, n_mels, frames]
log_mel = log_mel.unsqueeze(0)

#######################################
# 3) ONNX Inference
#######################################
# ONNX Runtime consumes NumPy arrays.
log_mel_np = log_mel.cpu().numpy()
# session.run returns a list with one array per requested output name.
outputs = session.run([output_name], {input_name: log_mel_np})
embedding = outputs[0]  # shape is [1, embedding_dim]

In [7]:
# Sanity-check the result: one embedding vector for the single batched clip.
assert embedding.ndim == 2 and embedding.shape[0] == 1, f"unexpected embedding shape {embedding.shape}"
# Show the shape via the cell's rich output rather than print().
embedding.shape

(1, 192)
