# small model

In [1]:
import torch
print(torch.__version__)
from torchsummary import summary


import torchaudio
import torchaudio.transforms as T

import torch.nn as nn
import torch.nn.functional as F

import numpy as np


2.6.0+cu124


In [2]:
model_name='B0'
train_type='ft_lm'
dataset='vox2'

torch.hub.set_dir('/data/deep/redimnet/models')

original_model = torch.hub.load('IDRnD/ReDimNet', 'ReDimNet', 
                       model_name=model_name, 
                       train_type=train_type, 
                       dataset=dataset)

Using cache found in /data/deep/redimnet/models/IDRnD_ReDimNet_master


/data/deep/redimnet/models/IDRnD_ReDimNet_master
load_res : <All keys matched successfully>


*  ReDimNetWrap expects raw 16 kHz mono audio, exactly 32 000 samples

In [3]:
from torchinfo import summary
summary(original_model, input_size=(1, 32000))

  with torch.cuda.amp.autocast(enabled=False):


Layer (type:depth-idx)                                       Output Shape              Param #
ReDimNetWrap                                                 [1, 192]                  --
├─MelBanks: 1-1                                              [1, 60, 134]              --
│    └─Sequential: 2-1                                       [1, 60, 134]              --
│    │    └─Identity: 3-1                                    [1, 32000]                --
│    │    └─PreEmphasis: 3-2                                 [1, 32000]                --
│    │    └─MelSpectrogram: 3-3                              [1, 60, 134]              --
├─ReDimNet: 1-2                                              [1, 600, 134]             --
│    └─Sequential: 2-2                                       [1, 600, 134]             --
│    │    └─Conv2d: 3-4                                      [1, 10, 60, 134]          100
│    │    └─LayerNorm: 3-5                                   [1, 10, 60, 134]          20
│   

we can see MelSpectrogram inside the model ; lets take it outside the model;


In [4]:
for name, module in original_model.named_children():
    print(name, "=>", module)

backbone => ReDimNet(
  (stem): Sequential(
    (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): LayerNorm(C=(10,), data_format=channels_first, eps=1e-06)
    (2): to1d()
  )
  (stage0): Sequential(
    (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
    (1): to2d(f=60,c=10)
    (2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
    (3): ConvBlock2d(
      (conv_block): ResBasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
        (conv1pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
        (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
        (conv2pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
        (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      

# create new model

In [5]:
########################################
# 2) Define a Model Class without MelBanks
########################################
import torch
import torch.nn as nn

class ReDimNetNoMel(nn.Module):
    """
    A wrapper around the original ReDimNetWrap that:
      - Excludes the 'spec' (MelBanks) module
      - Uses 'backbone', 'pool', 'bn', and 'linear'
    We expect a precomputed mel spectrogram as input with shape [B, 1, n_mels, time_frames].
    """
    def __init__(self, original_wrap):
        super().__init__()
        
        # Grab references to the submodules we want to keep
        self.backbone = original_wrap.backbone
        
        ## ???? TESTTT
        # self.backbone.stage0[6] = nn.Identity()
        # self.backbone.stage1[8] = nn.Identity()
        self.backbone.stage2[8] = nn.Identity()
        self.backbone.stage3[9] = nn.Identity()
        self.backbone.stage4[7] = nn.Identity()
        
        
        # Replace ASTP with RKNN-safe version:
        self.pool = original_wrap.pool
        self.bn = original_wrap.bn
        self.linear = original_wrap.linear

    def forward(self, x):
        # x: shape [B, 1, n_mels, time_frames]
        # (1) Pass through the backbone
        x = self.backbone(x)    # shape might become [B, channels, frames] or similar
        print("Backbone output shape:", x.shape)  # ADD THIS LINE
        # (2) Pooling
        x = self.pool(x)        # ASTP => shape likely [B, embedding_dim]
        # (3) BatchNorm
        x = self.bn(x)
        # (4) Final linear => 192-dim (if that's your embedding size)
        x = self.linear(x)
        return x


# Create an instance of our new model that skips the MelBanks front-end
model_no_mel = ReDimNetNoMel(original_model)



In [6]:
model_no_mel.eval()  # <- this line is critical!
dummy = torch.randn(1, 1, 60, 200)
model_no_mel(dummy)

Backbone output shape: torch.Size([1, 600, 200])


tensor([[ 4.3170e-01, -3.8658e+00, -5.2461e-01, -2.5823e+00,  2.6725e+00,
         -1.0786e+00, -3.7065e-01,  1.8964e+00,  8.7669e-02,  4.0358e+00,
          4.7946e+00, -2.2716e+00, -2.1222e+00, -6.3212e+00, -5.5915e+00,
          1.5326e+00, -4.5874e+00, -4.8371e-01,  1.3838e+00, -2.5561e+00,
          1.4633e+00,  2.0870e+00,  4.1246e+00,  4.5188e+00, -1.4006e-01,
         -2.1214e-01,  2.1208e+00, -2.3666e+00,  9.8855e-01, -2.6717e+00,
         -3.7827e+00, -4.5494e-01, -2.1038e+00,  1.7126e+00,  6.2170e+00,
          2.7851e+00, -1.2936e+00, -1.7676e+00,  4.1724e+00, -1.9700e+00,
         -7.7161e-01,  5.9109e+00,  1.2310e+00, -5.6595e+00,  8.3462e+00,
          6.9083e-02, -4.8812e+00,  1.0612e-01,  2.9877e+00,  2.8394e+00,
         -2.6745e+00, -3.8817e+00, -5.7622e-01,  7.0919e+00, -7.7118e+00,
         -4.2552e+00, -6.1277e+00, -1.8548e+00, -2.3931e+00,  1.9095e+00,
         -2.5748e+00, -3.9921e+00,  2.7086e+00,  1.7395e+00,  1.3864e+00,
          2.2903e+00,  3.4666e+00,  4.

## layres debug

In [7]:
for name, module in model_no_mel.named_modules():
    if isinstance(module, nn.LayerNorm):
        print("❌ Still has LayerNorm at:", name)

❌ Still has LayerNorm at: backbone.stage0.6.tcm.4.layer_norm
❌ Still has LayerNorm at: backbone.stage0.6.tcm.4.final_layer_norm
❌ Still has LayerNorm at: backbone.stage1.8.tcm.4.layer_norm
❌ Still has LayerNorm at: backbone.stage1.8.tcm.4.final_layer_norm


In [8]:
print("stage0.6 =", model_no_mel.backbone.stage0[6])

stage0.6 = TimeContextBlock1d(
  (red_dim_conv): Sequential(
    (0): Conv1d(600, 20, kernel_size=(1,), stride=(1,))
    (1): LayerNorm(C=(20,), data_format=channels_first, eps=1e-06)
  )
  (tcm): Sequential(
    (0): ConvNeXtLikeBlock(
      (dwconvs): ModuleList(
        (0): Conv1d(20, 20, kernel_size=(7,), stride=(1,), padding=same, groups=20)
      )
      (norm): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): GELU(approximate='none')
      (pwconv1): Conv1d(20, 20, kernel_size=(1,), stride=(1,))
    )
    (1): ConvNeXtLikeBlock(
      (dwconvs): ModuleList(
        (0): Conv1d(20, 20, kernel_size=(19,), stride=(1,), padding=same, groups=20)
      )
      (norm): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): GELU(approximate='none')
      (pwconv1): Conv1d(20, 20, kernel_size=(1,), stride=(1,))
    )
    (2): ConvNeXtLikeBlock(
      (dwconvs): ModuleList(
        (0): Conv1d(20, 20, kerne

## info

In [9]:
model_no_mel.eval()


ReDimNetNoMel(
  (backbone): ReDimNet(
    (stem): Sequential(
      (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): LayerNorm(C=(10,), data_format=channels_first, eps=1e-06)
      (2): to1d()
    )
    (stage0): Sequential(
      (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
      (1): to2d(f=60,c=10)
      (2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
      (3): ConvBlock2d(
        (conv_block): ResBasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv1pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv2pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [10]:
summary(model_no_mel, (1, 1, 60, 200))


Backbone output shape: torch.Size([1, 600, 200])


Layer (type:depth-idx)                                       Output Shape              Param #
ReDimNetNoMel                                                [1, 192]                  --
├─ReDimNet: 1-1                                              [1, 600, 200]             --
│    └─Sequential: 2-1                                       [1, 600, 200]             --
│    │    └─Conv2d: 3-1                                      [1, 10, 60, 200]          100
│    │    └─LayerNorm: 3-2                                   [1, 10, 60, 200]          20
│    │    └─to1d: 3-3                                        [1, 600, 200]             --
│    └─Sequential: 2-2                                       [1, 600, 200]             --
│    │    └─weigth1d: 3-4                                    [1, 600, 200]             (1)
│    │    └─to2d: 3-5                                        [1, 10, 60, 200]          --
│    │    └─Conv2d: 3-6                                      [1, 10, 60, 200]          110
│ 

## Utility Function for WAV -> MelSpectrogram

In [11]:

def waveform_to_logmel(
    waveform: torch.Tensor,
    sample_rate: int = 16000,
    n_fft: int = 512,
    hop_length: int = 160,
    n_mels: int = 60,       ## 72 for vox2 ;  60 for B0
    f_min: float = 20.0,
    f_max: float = 8000.0,
    preemphasis_alpha: float = 0.97
):
    """
    Reproduces the main logic of 'NormalizeAudio', 'PreEmphasis',
    and 'MelSpectrogram' from the 'MelBanks' layer.
    """

    # 1) NormalizeAudio
    waveform = waveform / (waveform.abs().max() + 1e-8)

    # 2) PreEmphasis
    shifted = torch.roll(waveform, shifts=1, dims=1)
    waveform_preemph = waveform - preemphasis_alpha * shifted
    # fix first sample
    waveform_preemph[:, 0] = waveform[:, 0]

    # 3) MelSpectrogram
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
        power=2.0,
        center=False
    )
    mel_spec = mel_transform(waveform_preemph)  # shape: [channel=1, n_mels, time_frames]

    # Log scale
    log_mel = torch.log(mel_spec + 1e-6)
    return log_mel  # shape: [1, n_mels, time_frames]

In [12]:
def example_inference(wav_path: str):
    # (a) Load audio
    waveform, sr = torchaudio.load(wav_path)  # shape: [channels, time]
    # If stereo, select one channel, or average:
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # (b) Convert to log-mel
    log_mel = waveform_to_logmel(waveform, sample_rate=sr)
    # Now log_mel shape is [1, n_mels, time_frames].

    # (c) Model expects a batch, so add batch dim => [B=1, 1, n_mels, time_frames]
    log_mel = log_mel.unsqueeze(0)

    # (d) Forward pass
    with torch.no_grad():
        embedding = model_no_mel(log_mel)  # shape typically [1, 192] or so

    print("Embedding shape:", embedding.shape)
    #print("Embedding:", embedding)
    return embedding

In [13]:
# Compute similarity between two embeddings
def cosine_similarity(embedding1, embedding2):
    return F.cosine_similarity(embedding1, embedding2).item()

def cosine_similarity_numpys(emb1: np.ndarray, emb2: np.ndarray) -> float:
    """
    Compute cosine similarity between two vectors of shape (D,) or (1, D).
    """
    # If shape is (1, D), flatten to (D,)
    v1 = emb1.flatten()
    v2 = emb2.flatten()

    # dot product
    dot = np.dot(v1, v2)
    # norms
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)

    # Add a small epsilon in case of very small norms
    sim = dot / (norm1 * norm2 + 1e-8)
    return sim


In [14]:
embed1 = example_inference("testRob1.wav")
embed2 = example_inference("testRob2.wav")
embed3 = example_inference("testme1.wav")
embed4 = example_inference("testme2.wav")


Backbone output shape: torch.Size([1, 600, 219])
Embedding shape: torch.Size([1, 192])
Backbone output shape: torch.Size([1, 600, 200])
Embedding shape: torch.Size([1, 192])




Backbone output shape: torch.Size([1, 600, 525])
Embedding shape: torch.Size([1, 192])
Backbone output shape: torch.Size([1, 600, 993])
Embedding shape: torch.Size([1, 192])


In [15]:
print(f"Similarity (robot to robot): {cosine_similarity_numpys(embed1, embed2)}")
print(f"Similarity (robot to me   ): {cosine_similarity_numpys(embed1, embed3)}")
print(f"Similarity (me 1 to me 2  ): {cosine_similarity_numpys(embed3, embed4)}")

Similarity (robot to robot): 0.9700368642807007
Similarity (robot to me   ): 0.7867892980575562
Similarity (me 1 to me 2  ): 0.9767839908599854


  dot = np.dot(v1, v2)


## store

In [16]:
import onnx

def export_to_onnx(model, onnx_path="ReDimNet_no_mel.onnx"):
    model.eval()
    
    # Create a dummy input with shape matching [B=1, 1, n_mels=72, time_frames=200] (example)
    dummy_input = torch.randn(1, 1, 60, 200)
    
    #  fixed-length segments 
    torch.onnx.export(
        model_no_mel,
        dummy_input,
        onnx_path,
        input_names=["log_mel"],
        output_names=["embedding"],
        opset_version=11
    )
    print("Exported to", onnx_path)
    
    # ## Export with Dynamic Axis
    # torch.onnx.export(
    #     model_no_mel,                 
    #     dummy_input,                 
    #     onnx_path,
    #     input_names=["log_mel"],
    #     output_names=["embedding"],
    #     opset_version=19,
    #     dynamic_axes={
    #         "log_mel": {0: "batch_size", 3: "time_frames"},
    #         "embedding": {0: "batch_size"}
    #     }
    # )
    # print("Exported to", onnx_path)

# Example usage
export_to_onnx(model_no_mel)

Backbone output shape: torch.Size([1, 600, 200])
Exported to ReDimNet_no_mel.onnx


In [17]:
!ls -lah ReDimNet_no_mel.onnx

-rw-rw-r-- 1 vlad vlad 3.0M Jun 11 08:26 ReDimNet_no_mel.onnx


## verify

In [18]:
import onnx
onnx_model = onnx.load("ReDimNet_no_mel.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")


ONNX model is valid!


In [19]:
print(type(model_no_mel))
print(isinstance(model_no_mel, nn.Module))

<class '__main__.ReDimNetNoMel'>
True


In [20]:
type(onnx_model)

onnx.onnx_ml_pb2.ModelProto

In [21]:
def pad_or_crop_logmel(log_mel, target_frames=200):
    """
    Ensures log_mel is shaped [1, n_mels, target_frames] by:
    - Padding with zeros on the right if too short
    - Center-cropping if too long
    """
    B, M, T = log_mel.shape
    if T < target_frames:
        pad_amt = target_frames - T
        log_mel = F.pad(log_mel, (0, pad_amt))  # pad at end
        print(f"Padding log_mel from {T} to {target_frames} frames")
    elif T > target_frames:
        start = (T - target_frames) // 2
        log_mel = log_mel[:, :, start:start + target_frames]
        print(f"Cropping log_mel from {T} to {target_frames} frames")
    return log_mel


def waveform_to_logmel(
    waveform: torch.Tensor,
    sample_rate=16000,
    n_fft=512,
    hop_length=160,
    n_mels=60,         # match whatever your model expects
    f_min=20.0,
    f_max=8000.0,
    preemphasis_alpha=0.97,
    target_frames=200
):
    """
    Converts waveform -> normalized, preemphasized -> log-mel spectrogram
    Returns shape: [1, 1, n_mels, target_frames]
    """
    # 1) Normalize
    waveform = waveform / (waveform.abs().max() + 1e-8)
    # 2) PreEmphasis
    shifted = torch.roll(waveform, shifts=1, dims=1)
    waveform_preemph = waveform - preemphasis_alpha * shifted
    waveform_preemph[:, 0] = waveform[:, 0]
    # 3) MelSpectrogram
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
        power=2.0,
        center=False
    )
    mel_spec = mel_transform(waveform_preemph)
    # 4) Log scale
    log_mel = torch.log(mel_spec + 1e-6)
    
    # 5) Pad/crop to fixed number of frames
    log_mel = pad_or_crop_logmel(log_mel, target_frames=target_frames)

    print("Log-mel shape:", log_mel.shape)  # Expect [1, 1, 60, 200]
    return log_mel

In [22]:
import onnxruntime as ort

def run_inference_onnx(onnx_path, wav_path):
    """
    Loads an audio file, converts to log-mel, and runs inference
    in an ONNX session. Returns the embedding as a NumPy array.
    """
    #######################################
    # 1) Load your ONNX model
    #######################################
    # (Optional) onnx.checker to confirm it’s valid
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print(f"Loaded and checked ONNX model from: {onnx_path}")

    # Create an inference session
    session = ort.InferenceSession(onnx_path)

    # Usually we retrieve the first input & output name
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    #######################################
    # 2) Load audio, get log-mel
    #######################################
    waveform, sr = torchaudio.load(wav_path)
    # If multi-channel, downmix:
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    log_mel = waveform_to_logmel(waveform, sample_rate=sr)
    # Insert a batch dimension => shape [B, 1, n_mels, frames]
    log_mel = log_mel.unsqueeze(0)  # => [1, 1, n_mels, time_frames]

    #######################################
    # 3) ONNX Inference
    #######################################
    # Convert to NumPy for ONNX runtime
    log_mel_np = log_mel.cpu().numpy()
    # Run inference
    outputs = session.run([output_name], {input_name: log_mel_np})
    # outputs is a list; typically we want the first item
    embedding = outputs[0]  # shape is [1, embedding_dim]

    print("Embedding shape:", embedding.shape)
    # print("Embedding data:\n", embedding)
    return embedding

In [23]:
onnx_model_path = "ReDimNet_no_mel.onnx"

embed1 = run_inference_onnx(onnx_model_path, "testRob1.wav")
embed2 = run_inference_onnx(onnx_model_path, "testRob2.wav")
embed3 = run_inference_onnx(onnx_model_path, "testme1.wav")
embed4 = run_inference_onnx(onnx_model_path, "testme2.wav")






Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
Cropping log_mel from 219 to 200 frames
Log-mel shape: torch.Size([1, 60, 200])
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
Log-mel shape: torch.Size([1, 60, 200])
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
Cropping log_mel from 525 to 200 frames
Log-mel shape: torch.Size([1, 60, 200])
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
Cropping log_mel from 993 to 200 frames
Log-mel shape: torch.Size([1, 60, 200])
Embedding shape: (1, 192)


In [24]:
print(f"Similarity (robot to robot): {cosine_similarity_numpys(embed1, embed2)}")
print(f"Similarity (robot to me   ): {cosine_similarity_numpys(embed1, embed3)}")
print(f"Similarity (me 1 to me 2  ): {cosine_similarity_numpys(embed3, embed4)}")

Similarity (robot to robot): 0.9725348949432373
Similarity (robot to me   ): 0.7354221343994141
Similarity (me 1 to me 2  ): 0.9561303853988647


## cal fake data

* run in rknn docker:

``` bash

import os
import numpy as np
import torch

# Directory for calibration inputs
os.makedirs("calib_npy", exist_ok=True)

# Create 100 dummy log-mel tensors
for i in range(10):
    log_mel = torch.randn(1, 1, 60, 200).numpy().astype(np.float32)
    np.save(f"calib_npy/sample_{i}.npy", log_mel)

# Write dataset.txt listing all paths
with open("dataset.txt", "w") as f:
    for i in range(10):
        f.write(f"calib_npy/sample_{i}.npy\n")


````

## converts

* python convert.py ReDimNet_no_mel.onnx  rk3588 fp ReDimNet_no_mel.rknn 
* python convert.py ReDimNet_no_mel.onnx  rv1106 i8  ReDimNet_no_mel.rknn