# Define a basic ReDimNet model without the mel front-end (noMel)


===============================================================

* build a new noMel model based on the baseline
* run the test voices through the model and compare with the baseline
* export to ONNX (fp32) and compare torch with ONNX

===============================================================

In [1]:
%load_ext autoreload
%autoreload 2
## our utils
from utils.common_import import *
from utils.test_all_voices import *


2.6.0+cu124


* load the related modules (importing runs the reference notebooks once; their outputs are hidden)

In [2]:
%%capture --no-display          
import my_utils as myUtils
from play1_setBase_line_B0 import original_model,base_line_embedding

## New Model

We can see the MelSpectrogram inside the model; let's move it outside the model.


In [3]:
# List the top-level submodules to locate the mel front-end inside the wrapper.
for child_name, child_module in original_model.named_children():
    print(f"{child_name} => {child_module}")

backbone => ReDimNet(
  (stem): Sequential(
    (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): LayerNorm(C=(10,), data_format=channels_first, eps=1e-06)
    (2): to1d()
  )
  (stage0): Sequential(
    (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
    (1): to2d(f=60,c=10)
    (2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
    (3): ConvBlock2d(
      (conv_block): ResBasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
        (conv1pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
        (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
        (conv2pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
        (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      

* Default model with the MelSpectrogram moved outside the model

In [4]:
########################################
# 2) Define a Model Class without MelBanks
########################################
import torch
import torch.nn as nn

class ReDimNetNoMel(nn.Module):
    """
    ReDimNet embedding network without its mel-spectrogram front-end.

    Re-uses (does not copy) the ``backbone``, ``pool``, ``bn`` and ``linear``
    submodules of an original ReDimNetWrap instance. The wrapper's ``spec``
    (MelBanks) module is left out, so callers must pass precomputed log-mel
    features. In this notebook the input is shaped [B, 1, n_mels, time_frames]
    (e.g. [1, 1, 60, 134]) and the output is [B, 192].
    """

    def __init__(self, original_wrap):
        super().__init__()
        # Register the kept stages under the same attribute names (and in the same
        # order) as in the original wrapper, so parameter names keep the same prefixes.
        for part_name in ("backbone", "pool", "bn", "linear"):
            setattr(self, part_name, getattr(original_wrap, part_name))

    def forward(self, x):
        # backbone -> pool -> bn -> linear, applied innermost-first.
        return self.linear(self.bn(self.pool(self.backbone(x))))


# Wrap the loaded baseline without its MelBanks front-end; eval() switches to
# inference mode (BatchNorm uses running statistics) and returns the module itself.
model_no_mel = ReDimNetNoMel(original_model).eval()
model_no_mel



ReDimNetNoMel(
  (backbone): ReDimNet(
    (stem): Sequential(
      (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): LayerNorm(C=(10,), data_format=channels_first, eps=1e-06)
      (2): to1d()
    )
    (stage0): Sequential(
      (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
      (1): to2d(f=60,c=10)
      (2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
      (3): ConvBlock2d(
        (conv_block): ResBasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv1pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv2pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_st

# TORCH SIDE

In [5]:
def torch_inference(wav_path: str):
    """
    Compute a speaker embedding for one audio file with the PyTorch ``model_no_mel``.

    Multi-channel audio is downmixed to mono (channel mean), audio not at 16 kHz is
    resampled, and the waveform is converted with ``myUtils.waveform_to_logmel``
    before being passed through ``model_no_mel`` under ``torch.no_grad()``.

    Returns the model output tensor (printed shape in this notebook: [1, 192]).
    """
    model_rate = 16000  # the model is fed 16 kHz audio

    audio, file_rate = torchaudio.load(wav_path)  # [channels, time]
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)
    if file_rate != model_rate:
        audio = T.Resample(orig_freq=file_rate, new_freq=model_rate)(audio)

    features = myUtils.waveform_to_logmel(audio)
    print('feeding logmel shape:', features.shape)

    with torch.no_grad():
        embedding = model_no_mel(features)

    print("Embedding shape:", embedding.shape)
    return embedding

run test:

In [6]:
# Run every test voice through the PyTorch noMel model; results are keyed 'embed0'..'embed6'.
torch_embedding = test_all_voices(
    cosine_similarity_function=myUtils.cosine_similarity,
    extract_speaker_embedding_function=torch_inference,
)

Input waveform shape: torch.Size([1, 32000])
feeding logmel shape: torch.Size([1, 1, 60, 134])
Embedding shape: torch.Size([1, 192])
Input waveform shape: torch.Size([1, 25776])
Padding log_mel from 108 to 134 frames
feeding logmel shape: torch.Size([1, 1, 60, 134])
Embedding shape: torch.Size([1, 192])
Input waveform shape: torch.Size([1, 23570])
Padding log_mel from 99 to 134 frames
feeding logmel shape: torch.Size([1, 1, 60, 134])
Embedding shape: torch.Size([1, 192])
Input waveform shape: torch.Size([1, 32000])
feeding logmel shape: torch.Size([1, 1, 60, 134])
Embedding shape: torch.Size([1, 192])
Input waveform shape: torch.Size([1, 32000])
feeding logmel shape: torch.Size([1, 1, 60, 134])
Embedding shape: torch.Size([1, 192])
Input waveform shape: torch.Size([1, 28126])
Padding log_mel from 118 to 134 frames
feeding logmel shape: torch.Size([1, 1, 60, 134])
Embedding shape: torch.Size([1, 192])
Input waveform shape: torch.Size([1, 32000])
feeding logmel shape: torch.Size([1, 1, 6

## compare to baseline

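* Compare each voice's embedding from the current model with the baseline model's. The lower similarities (embed1, embed2, embed5) appear to be the clips whose log-mel was padded to 134 frames in the run above.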

In [7]:
# Cosine similarity between the baseline embeddings and the noMel torch embeddings,
# one line per test voice (test_all_voices returns keys 'embed0'..'embed6').
for voice_idx in range(7):
    key = f"embed{voice_idx}"
    similarity = myUtils.cosine_similarity(base_line_embedding[key], torch_embedding[key])
    print(f"Similarity {key}: {similarity}")

Similarity embde0: 1.0
Similarity embde1: 0.9607586860656738
Similarity embde2: 0.9563478231430054
Similarity embde3: 0.9997662901878357
Similarity embde4: 0.9997462034225464
Similarity embde5: 0.9846920967102051
Similarity embde6: 0.9999998807907104


# ONNX SIDE

In [8]:
myUtils.export_to_onnx(model_no_mel,onnx_path = "ReDimNet_no_mel.onnx")
!ls -lah ReDimNet_no_mel.onnx

Exported to ReDimNet_no_mel.onnx
-rw-rw-r-- 1 vlad vlad 4.1M Jun 18 18:32 ReDimNet_no_mel.onnx


### Store a half-precision (fp16) copy

In [9]:
# fp16_net = copy.deepcopy(model_no_mel).half().eval()
# fp16_dummy = dummy_input = torch.randn(1, 1, 60, 134).half()

# #  fixed-length segments 
# torch.onnx.export(
#    fp16_net,
#    fp16_dummy,
#    "ReDimNet_no_mel_half.onnx",
#    input_names=["log_mel"],
#    output_names=["embedding"],
#    opset_version=13
# )

PyTorch's `.half()` is unreliable for controlling the precision of the whole model in ONNX. Instead, convert the exported fp32 ONNX file:

In [10]:
# Convert the fp32 export to fp16 at the ONNX level (source file, destination file).
fp32_model_file, fp16_model_file = 'ReDimNet_no_mel.onnx', 'ReDimNet_no_mel_fp16.onnx'
myUtils.restore_in_half_precision(fp32_model_file, fp16_model_file)

Converted ReDimNet_no_mel.onnx to half precision and saved as ReDimNet_no_mel_fp16.onnx




## verify

In [11]:
# ONNX file used by the checks and by inference_onnx below (read at call time).
onnx_path = "ReDimNet_no_mel.onnx"
# To verify the half-precision conversion instead, uncomment (same variable name, so it takes effect):
# onnx_path = "ReDimNet_no_mel_fp16.onnx"

In [12]:
# Structural validation of the exported graph: check_model raises on an invalid
# model, so reaching the print means the file passed.
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")

ONNX model is valid!


In [13]:
def inference_onnx(wav_path):
    """
    Load an audio file, convert it to log-mel, and run it through an ONNX model.

    The model file is read from the notebook-level ``onnx_path`` at call time, so
    reassigning ``onnx_path`` switches the model used by later calls. The model is
    re-loaded, validated with ``onnx.checker`` and wrapped in a new ONNX Runtime
    session on every call.

    Side effect: saves the log-mel features, cast to float16, next to the audio
    file as ``<audio stem>_logmel.npy``.

    Args:
        wav_path: path of the audio file to embed.

    Returns:
        The first session output as a NumPy array (printed shape here: (1, 192)).
    """
    import os  # local import: the file's star-imports are not guaranteed to provide it

    print("===================================================")
    print("===========   run_inference_onnx   ================")
    print("===================================================")
    #######################################
    # 1) Load and validate the ONNX model
    #######################################
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print(f"Loaded and checked ONNX model from: {onnx_path}")

    session = ort.InferenceSession(onnx_path)

    # Only the first input and the first output of the graph are used.
    model_input = session.get_inputs()[0]
    input_name = model_input.name
    output_name = session.get_outputs()[0].name
    # Feed the element type the graph declares: a half-precision model may declare a
    # float16 input (depends on how it was converted); otherwise feed float32.
    input_dtype = np.float16 if model_input.type == "tensor(float16)" else np.float32

    #######################################
    # 2) Load audio, get log-mel
    #######################################
    print("loading audio from:", wav_path)
    waveform, sample_rate = torchaudio.load(wav_path)
    print(f"...Waveform rate {sample_rate}  ; shape : {waveform.shape}")

    # If multi-channel, downmix to mono by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample to the 16 kHz rate the model input is computed at.
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    log_mel = myUtils.waveform_to_logmel(waveform)

    #######################################
    # 3) ONNX inference
    #######################################
    log_mel_np = log_mel.cpu().numpy()
    print("logmelshape : ", log_mel_np.shape)

    # Keep a float16 copy of the features next to the audio file for later inspection.
    # splitext only touches the extension; str.replace(".wav", ...) would also rewrite
    # any ".wav" occurring in directory names.
    logmel_path = os.path.splitext(wav_path)[0] + "_logmel.npy"
    np.save(logmel_path, log_mel_np.astype(np.float16))

    outputs = session.run(
        [output_name], {input_name: log_mel_np.astype(input_dtype, copy=False)}
    )
    embedding = outputs[0]  # session.run returns one array per requested output

    print("Embedding shape:", embedding.shape)
    return embedding


In [14]:
# Run every test voice through the ONNX session; results are keyed 'embed0'..'embed6'.
onnx_embedding = test_all_voices(
    cosine_similarity_function=myUtils.cosine_similarity_numpys,
    extract_speaker_embedding_function=inference_onnx,
)

Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
loading audio from: /data/proj/voice/redimnet/wrkB0/utils/../audio/test000.wav
...Waveform rate 16000  ; shape : torch.Size([1, 293699])
Input waveform shape: torch.Size([1, 32000])
logmelshape :  (1, 1, 60, 134)
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
loading audio from: /data/proj/voice/redimnet/wrkB0/utils/../audio/testRob1.wav
...Waveform rate 22050  ; shape : torch.Size([1, 35522])
Input waveform shape: torch.Size([1, 25776])
Padding log_mel from 108 to 134 frames
logmelshape :  (1, 1, 60, 134)
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
loading audio from: /data/proj/voice/redimnet/wrkB0/utils/../audio/testRob2.wav
...Waveform rate 22050  ; shape : torch.Size([1, 32482])
Input waveform shape: torch.Size([1, 23570])
Padding log_mel from 99 to 134 frames
logmelshape :  (1, 1, 60, 134)
Embedding shape: (1, 192)
Loaded and checked ONNX model from:

### compare onnx with torch

In [15]:
# The ONNX export must reproduce the torch embeddings up to floating-point rounding.
# The threshold is loose enough to also tolerate a half-precision model.
MIN_EXPORT_SIMILARITY = 0.999

for voice_idx in range(7):
    key = f"embed{voice_idx}"
    similarity = myUtils.cosine_similarity_numpys(torch_embedding[key], onnx_embedding[key])
    print(f"Similarity {key}: {similarity}")
    assert similarity >= MIN_EXPORT_SIMILARITY, f"ONNX/torch mismatch for {key}: {similarity}"

Similarity embde0: 1.0000001192092896
Similarity embde1: 0.9999999403953552
Similarity embde2: 1.0
Similarity embde3: 1.0
Similarity embde4: 1.0
Similarity embde5: 1.0
Similarity embde6: 1.0


### compare onnx with base line

In [16]:
# Cosine similarity between the baseline embeddings and the ONNX embeddings,
# one line per test voice (test_all_voices returns keys 'embed0'..'embed6').
for voice_idx in range(7):
    key = f"embed{voice_idx}"
    similarity = myUtils.cosine_similarity_numpys(base_line_embedding[key], onnx_embedding[key])
    print(f"Similarity {key}: {similarity}")

Similarity embde0: 1.0000001192092896
Similarity embde1: 0.9607586860656738
Similarity embde2: 0.9563480615615845
Similarity embde3: 0.9997662305831909
Similarity embde4: 0.9997462034225464
Similarity embde5: 0.9846920371055603
Similarity embde6: 0.9999998211860657
