# small model

In [1]:
import torch
print(torch.__version__)
from torchsummary import summary


import torchaudio
import torchaudio.transforms as T

import torch.nn as nn
import torch.nn.functional as F

import numpy as np


2.6.0+cu124


In [2]:
model_name='B0'
train_type='ft_lm'
dataset='vox2'

torch.hub.set_dir('/data/deep/redimnet/models')

original_model = torch.hub.load('IDRnD/ReDimNet', 'ReDimNet', 
                       model_name=model_name, 
                       train_type=train_type, 
                       dataset=dataset)

Using cache found in /data/deep/redimnet/models/IDRnD_ReDimNet_master


/data/deep/redimnet/models/IDRnD_ReDimNet_master
load_res : <All keys matched successfully>


*  ReDimNetWrap expects raw 16 kHz mono audio, exactly 32 000 samples

In [3]:
from torchinfo import summary
summary(original_model, input_size=(1, 32000))

  with torch.cuda.amp.autocast(enabled=False):


Layer (type:depth-idx)                                       Output Shape              Param #
ReDimNetWrap                                                 [1, 192]                  --
├─MelBanks: 1-1                                              [1, 60, 134]              --
│    └─Sequential: 2-1                                       [1, 60, 134]              --
│    │    └─Identity: 3-1                                    [1, 32000]                --
│    │    └─PreEmphasis: 3-2                                 [1, 32000]                --
│    │    └─MelSpectrogram: 3-3                              [1, 60, 134]              --
├─ReDimNet: 1-2                                              [1, 600, 134]             --
│    └─Sequential: 2-2                                       [1, 600, 134]             --
│    │    └─Conv2d: 3-4                                      [1, 10, 60, 134]          100
│    │    └─LayerNorm: 3-5                                   [1, 10, 60, 134]          20
│   

we can see MelSpectrogram inside the model ; lets take it outside the model;


In [4]:
for name, module in original_model.named_children():
    print(name, "=>", module)

backbone => ReDimNet(
  (stem): Sequential(
    (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): LayerNorm(C=(10,), data_format=channels_first, eps=1e-06)
    (2): to1d()
  )
  (stage0): Sequential(
    (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
    (1): to2d(f=60,c=10)
    (2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
    (3): ConvBlock2d(
      (conv_block): ResBasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
        (conv1pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
        (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
        (conv2pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
        (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      

# create new model

In [5]:
class MyLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps

    def forward(self, x):
        # LayerNorm across last dimension by default
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta


def replace_layernorm(module):
    for name, child in module.named_children():
        if isinstance(child, nn.LayerNorm):
            print(f"Replacing {name} with MyLayerNorm")
            # Replace with ONNX-friendly version using the same shape
            setattr(module, name, MyLayerNorm(child.normalized_shape, child.eps))
        else:
            replace_layernorm(child)  # Recurse into children

In [6]:
class ASTP(nn.Module):
    """
    Safe version of Adaptive Temporal Pooling for RKNN.
    Replaces problematic ReduceSum, LayerNorm, or std-based normalization.

    Input: [B, C, T] — e.g. after Conv1d from [B, C, F, T] → flatten → linear
    Output: [B, final_dim]
    """
    def __init__(self, in_channels=1800, hidden_channels=128, out_channels=600):
        super().__init__()
        self.linear1 = nn.Conv1d(in_channels, hidden_channels, kernel_size=1)
        self.linear2 = nn.Conv1d(hidden_channels, out_channels, kernel_size=1)

    # def forward(self, x):
    #     # Input: [B, C, T]
    #     x = self.linear1(x)           # [B, H, T]
    #     x = F.relu(x)                 # or GELU if supported
    #     x = F.adaptive_avg_pool1d(x, 1)  # [B, H, 1]
    #     x = self.linear2(x)           # [B, C, 1]
    #     x = x.view(x.size(0), -1)     # [B, C]
    #     return x
    
    def forward(self, x):
        print("  ASTP input:", x.shape)
        x = self.linear1(x)
        print("  after linear1:", x.shape)
        x = F.relu(x)
        x = F.adaptive_avg_pool1d(x, 1)
        print("  after avgpool:", x.shape)
        x = self.linear2(x)
        print("  after linear2:", x.shape)
        x = x.view(x.size(0), -1)
        print("  final output:", x.shape)
        return x

In [7]:
########################################
# 2) Define a Model Class without MelBanks
########################################
import torch
import torch.nn as nn

class ReDimNetNoMel(nn.Module):
    """
    A wrapper around the original ReDimNetWrap that:
      - Excludes the 'spec' (MelBanks) module
      - Uses 'backbone', 'pool', 'bn', and 'linear'
    We expect a precomputed mel spectrogram as input with shape [B, 1, n_mels, time_frames].
    """
    def __init__(self, original_wrap):
        super().__init__()
        
        # Replace all LayerNorms inside the backbone
        replace_layernorm(original_wrap.backbone)
        
        # Grab references to the submodules we want to keep
        self.backbone = original_wrap.backbone
        
        ## ???? TESTTT
        self.backbone.stage0[6] = nn.Identity()
        self.backbone.stage1[8] = nn.Identity()
        self.backbone.stage2[8] = nn.Identity()
        self.backbone.stage3[9] = nn.Identity()
        self.backbone.stage4[7] = nn.Identity()
        
        
        # Replace ASTP with RKNN-safe version:
        # self.pool = original_wrap.pool
        self.pool = ASTP(
            in_channels=600,    # use your actual input channels to ASTP
            hidden_channels=128,
            out_channels=600
        )
        
        # self.bn = original_wrap.bn
        self.bn = nn.BatchNorm1d(600)
        
        # self.linear = original_wrap.linear
        self.linear = nn.Linear(600, 192)

    def forward(self, x):
        # x: shape [B, 1, n_mels, time_frames]
        # (1) Pass through the backbone
        x = self.backbone(x)    # shape might become [B, channels, frames] or similar
        print("Backbone output shape:", x.shape)  # ADD THIS LINE
        # (2) Pooling
        x = self.pool(x)        # ASTP => shape likely [B, embedding_dim]
        # (3) BatchNorm
        x = self.bn(x)
        # (4) Final linear => 192-dim (if that's your embedding size)
        x = self.linear(x)
        return x


# Create an instance of our new model that skips the MelBanks front-end
model_no_mel = ReDimNetNoMel(original_model)



Replacing layer_norm with MyLayerNorm
Replacing final_layer_norm with MyLayerNorm
Replacing layer_norm with MyLayerNorm
Replacing final_layer_norm with MyLayerNorm
Replacing layer_norm with MyLayerNorm
Replacing final_layer_norm with MyLayerNorm
Replacing layer_norm with MyLayerNorm
Replacing final_layer_norm with MyLayerNorm
Replacing layer_norm with MyLayerNorm
Replacing final_layer_norm with MyLayerNorm


In [8]:
model_no_mel.eval()  # <- this line is critical!
dummy = torch.randn(1, 1, 60, 200)
model_no_mel(dummy)

Backbone output shape: torch.Size([1, 600, 200])
  ASTP input: torch.Size([1, 600, 200])
  after linear1: torch.Size([1, 128, 200])
  after avgpool: torch.Size([1, 128, 1])
  after linear2: torch.Size([1, 600, 1])
  final output: torch.Size([1, 600])


tensor([[ 1.1409e-02, -4.9671e-02,  8.9055e-03,  5.2458e-02,  1.4925e-02,
          4.4221e-02, -3.7378e-02, -4.1301e-02,  1.2994e-02, -1.3779e-03,
         -5.0147e-02, -9.6620e-02, -2.3057e-02,  3.9165e-02,  1.5018e-02,
         -4.9623e-02, -4.9700e-02, -1.5956e-02, -3.1548e-02, -1.9897e-02,
          9.6088e-03,  3.4397e-02,  1.2584e-02, -3.0539e-03,  8.3607e-02,
         -4.3442e-02, -3.6923e-02, -5.3017e-02,  3.2926e-02, -3.1123e-03,
          3.4564e-02, -3.4400e-02,  7.0450e-03,  1.7432e-03,  3.0956e-02,
          7.8028e-02,  1.3909e-03,  6.5870e-02, -1.0717e-02, -1.9195e-02,
          5.9924e-03,  8.0474e-03,  1.5320e-02,  4.1094e-02, -3.1212e-02,
         -7.5346e-02,  9.9045e-03, -1.1411e-01,  1.6414e-02,  2.7966e-03,
         -6.5407e-02,  2.5251e-03, -2.3489e-02, -2.3065e-02,  4.2556e-02,
         -6.2216e-02, -5.2939e-02,  4.1639e-02, -2.0117e-02,  5.5216e-02,
         -2.0438e-02, -4.6595e-02, -3.4434e-02,  2.8327e-02,  3.0313e-02,
          1.6845e-02,  7.2600e-02,  6.

In [9]:
for name, module in model_no_mel.named_modules():
    if isinstance(module, nn.LayerNorm):
        print("❌ Still has LayerNorm at:", name)

## info

In [10]:
print("stage4.7 =", model_no_mel.backbone.stage4[7])

stage4.7 = Identity()


In [11]:
model_no_mel.eval()


ReDimNetNoMel(
  (backbone): ReDimNet(
    (stem): Sequential(
      (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): LayerNorm(C=(10,), data_format=channels_first, eps=1e-06)
      (2): to1d()
    )
    (stage0): Sequential(
      (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
      (1): to2d(f=60,c=10)
      (2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
      (3): ConvBlock2d(
        (conv_block): ResBasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv1pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv2pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [12]:
summary(model_no_mel, (1, 1, 60, 200))


Backbone output shape: torch.Size([1, 600, 200])
  ASTP input: torch.Size([1, 600, 200])
  after linear1: torch.Size([1, 128, 200])
  after avgpool: torch.Size([1, 128, 1])
  after linear2: torch.Size([1, 600, 1])
  final output: torch.Size([1, 600])


Layer (type:depth-idx)                        Output Shape              Param #
ReDimNetNoMel                                 [1, 192]                  --
├─ReDimNet: 1-1                               [1, 600, 200]             --
│    └─Sequential: 2-1                        [1, 600, 200]             --
│    │    └─Conv2d: 3-1                       [1, 10, 60, 200]          100
│    │    └─LayerNorm: 3-2                    [1, 10, 60, 200]          20
│    │    └─to1d: 3-3                         [1, 600, 200]             --
│    └─Sequential: 2-2                        [1, 600, 200]             --
│    │    └─weigth1d: 3-4                     [1, 600, 200]             (1)
│    │    └─to2d: 3-5                         [1, 10, 60, 200]          --
│    │    └─Conv2d: 3-6                       [1, 10, 60, 200]          110
│    │    └─ConvBlock2d: 3-7                  [1, 10, 60, 200]          440
│    │    └─ConvBlock2d: 3-8                  [1, 10, 60, 200]          440
│    │    └─to1

## Utility Function for WAV -> MelSpectrogram

In [13]:

def waveform_to_logmel(
    waveform: torch.Tensor,
    sample_rate: int = 16000,
    n_fft: int = 512,
    hop_length: int = 160,
    n_mels: int = 60,       ## 72 for vox2 ;  60 for B0
    f_min: float = 20.0,
    f_max: float = 8000.0,
    preemphasis_alpha: float = 0.97
):
    """
    Reproduces the main logic of 'NormalizeAudio', 'PreEmphasis',
    and 'MelSpectrogram' from the 'MelBanks' layer.
    """

    # 1) NormalizeAudio
    waveform = waveform / (waveform.abs().max() + 1e-8)

    # 2) PreEmphasis
    shifted = torch.roll(waveform, shifts=1, dims=1)
    waveform_preemph = waveform - preemphasis_alpha * shifted
    # fix first sample
    waveform_preemph[:, 0] = waveform[:, 0]

    # 3) MelSpectrogram
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
        power=2.0,
        center=False
    )
    mel_spec = mel_transform(waveform_preemph)  # shape: [channel=1, n_mels, time_frames]

    # Log scale
    log_mel = torch.log(mel_spec + 1e-6)
    return log_mel  # shape: [1, n_mels, time_frames]

In [14]:
def example_inference(wav_path: str):
    # (a) Load audio
    waveform, sr = torchaudio.load(wav_path)  # shape: [channels, time]
    # If stereo, select one channel, or average:
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # (b) Convert to log-mel
    log_mel = waveform_to_logmel(waveform, sample_rate=sr)
    # Now log_mel shape is [1, n_mels, time_frames].

    # (c) Model expects a batch, so add batch dim => [B=1, 1, n_mels, time_frames]
    log_mel = log_mel.unsqueeze(0)

    # (d) Forward pass
    with torch.no_grad():
        embedding = model_no_mel(log_mel)  # shape typically [1, 192] or so

    print("Embedding shape:", embedding.shape)
    #print("Embedding:", embedding)
    return embedding

In [15]:
# Compute similarity between two embeddings
def cosine_similarity(embedding1, embedding2):
    return F.cosine_similarity(embedding1, embedding2).item()

def cosine_similarity_numpys(emb1: np.ndarray, emb2: np.ndarray) -> float:
    """
    Compute cosine similarity between two vectors of shape (D,) or (1, D).
    """
    # If shape is (1, D), flatten to (D,)
    v1 = emb1.flatten()
    v2 = emb2.flatten()

    # dot product
    dot = np.dot(v1, v2)
    # norms
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)

    # Add a small epsilon in case of very small norms
    sim = dot / (norm1 * norm2 + 1e-8)
    return sim


In [16]:
embed1 = example_inference("test00.wav")
embed2 = example_inference("test01.wav")
embed3 = example_inference("test02.wav")

print(f"Similarity: {cosine_similarity(embed1, embed2)}")
print(f"Similarity: {cosine_similarity(embed1, embed3)}")


Backbone output shape: torch.Size([1, 600, 219])
  ASTP input: torch.Size([1, 600, 219])
  after linear1: torch.Size([1, 128, 219])
  after avgpool: torch.Size([1, 128, 1])
  after linear2: torch.Size([1, 600, 1])
  final output: torch.Size([1, 600])
Embedding shape: torch.Size([1, 192])
Backbone output shape: torch.Size([1, 600, 200])
  ASTP input: torch.Size([1, 600, 200])
  after linear1: torch.Size([1, 128, 200])
  after avgpool: torch.Size([1, 128, 1])
  after linear2: torch.Size([1, 600, 1])
  final output: torch.Size([1, 600])
Embedding shape: torch.Size([1, 192])
Backbone output shape: torch.Size([1, 600, 1833])
  ASTP input: torch.Size([1, 600, 1833])
  after linear1: torch.Size([1, 128, 1833])
  after avgpool: torch.Size([1, 128, 1])
  after linear2: torch.Size([1, 600, 1])
  final output: torch.Size([1, 600])
Embedding shape: torch.Size([1, 192])
Similarity: 0.9917510151863098
Similarity: 0.9893390536308289


## store

In [17]:
import onnx

def export_to_onnx(model, onnx_path="ReDimNet_no_mel.onnx"):
    model.eval()
    
    # Create a dummy input with shape matching [B=1, 1, n_mels=72, time_frames=200] (example)
    dummy_input = torch.randn(1, 1, 60, 200)
    
    #  fixed-length segments 
    torch.onnx.export(
        model_no_mel,
        dummy_input,
        onnx_path,
        input_names=["log_mel"],
        output_names=["embedding"],
        opset_version=11
    )
    print("Exported to", onnx_path)
    
    # ## Export with Dynamic Axis
    # torch.onnx.export(
    #     model_no_mel,                 
    #     dummy_input,                 
    #     onnx_path,
    #     input_names=["log_mel"],
    #     output_names=["embedding"],
    #     opset_version=19,
    #     dynamic_axes={
    #         "log_mel": {0: "batch_size", 3: "time_frames"},
    #         "embedding": {0: "batch_size"}
    #     }
    # )
    # print("Exported to", onnx_path)

# Example usage
export_to_onnx(model_no_mel)

Backbone output shape: torch.Size([1, 600, 200])
  ASTP input: torch.Size([1, 600, 200])
  after linear1: torch.Size([1, 128, 200])
  after avgpool: torch.Size([1, 128, 1])
  after linear2: torch.Size([1, 600, 1])
  final output: torch.Size([1, 600])
Exported to ReDimNet_no_mel.onnx


In [18]:
!ls -lah ReDimNet_no_mel.onnx

-rw-rw-r-- 1 vlad vlad 1.7M Jun 10 13:51 ReDimNet_no_mel.onnx


## verify

In [19]:
import onnx
onnx_model = onnx.load("ReDimNet_no_mel.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")


ONNX model is valid!


In [20]:
print(type(model_no_mel))
print(isinstance(model_no_mel, nn.Module))

<class '__main__.ReDimNetNoMel'>
True


In [21]:
type(onnx_model)

onnx.onnx_ml_pb2.ModelProto

In [22]:
def pad_or_crop_logmel(log_mel, target_frames=200):
    """
    Ensures log_mel is shaped [1, n_mels, target_frames] by:
    - Padding with zeros on the right if too short
    - Center-cropping if too long
    """
    B, M, T = log_mel.shape
    if T < target_frames:
        pad_amt = target_frames - T
        log_mel = F.pad(log_mel, (0, pad_amt))  # pad at end
    elif T > target_frames:
        start = (T - target_frames) // 2
        log_mel = log_mel[:, :, start:start + target_frames]
    return log_mel


def waveform_to_logmel(
    waveform: torch.Tensor,
    sample_rate=16000,
    n_fft=512,
    hop_length=160,
    n_mels=60,         # match whatever your model expects
    f_min=20.0,
    f_max=8000.0,
    preemphasis_alpha=0.97,
    target_frames=200
):
    """
    Converts waveform -> normalized, preemphasized -> log-mel spectrogram
    Returns shape: [1, 1, n_mels, target_frames]
    """
    # 1) Normalize
    waveform = waveform / (waveform.abs().max() + 1e-8)
    # 2) PreEmphasis
    shifted = torch.roll(waveform, shifts=1, dims=1)
    waveform_preemph = waveform - preemphasis_alpha * shifted
    waveform_preemph[:, 0] = waveform[:, 0]
    # 3) MelSpectrogram
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
        power=2.0,
        center=False
    )
    mel_spec = mel_transform(waveform_preemph)
    # 4) Log scale
    log_mel = torch.log(mel_spec + 1e-6)
    
    # 5) Pad/crop to fixed number of frames
    log_mel = pad_or_crop_logmel(log_mel, target_frames=target_frames)

    print("Log-mel shape:", log_mel.shape)  # Expect [1, 1, 60, 200]
    return log_mel

In [23]:
import onnxruntime as ort

def run_inference_onnx(onnx_path, wav_path):
    """
    Loads an audio file, converts to log-mel, and runs inference
    in an ONNX session. Returns the embedding as a NumPy array.
    """
    #######################################
    # 1) Load your ONNX model
    #######################################
    # (Optional) onnx.checker to confirm it’s valid
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print(f"Loaded and checked ONNX model from: {onnx_path}")

    # Create an inference session
    session = ort.InferenceSession(onnx_path)

    # Usually we retrieve the first input & output name
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    #######################################
    # 2) Load audio, get log-mel
    #######################################
    waveform, sr = torchaudio.load(wav_path)
    # If multi-channel, downmix:
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    log_mel = waveform_to_logmel(waveform, sample_rate=sr)
    # Insert a batch dimension => shape [B, 1, n_mels, frames]
    log_mel = log_mel.unsqueeze(0)  # => [1, 1, n_mels, time_frames]

    #######################################
    # 3) ONNX Inference
    #######################################
    # Convert to NumPy for ONNX runtime
    log_mel_np = log_mel.cpu().numpy()
    # Run inference
    outputs = session.run([output_name], {input_name: log_mel_np})
    # outputs is a list; typically we want the first item
    embedding = outputs[0]  # shape is [1, embedding_dim]

    print("Embedding shape:", embedding.shape)
    # print("Embedding data:\n", embedding)
    return embedding

In [24]:
onnx_model_path = "ReDimNet_no_mel.onnx"

embed1 = run_inference_onnx(onnx_model_path, "test00.wav")
embed2 = run_inference_onnx(onnx_model_path, "test01.wav")
embed3 = run_inference_onnx(onnx_model_path, "testme.wav")

print(f"Similarity: {cosine_similarity_numpys(embed1, embed2)}")
print(f"Similarity: {cosine_similarity_numpys(embed1, embed3)}")


Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
Log-mel shape: torch.Size([1, 60, 200])
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
Log-mel shape: torch.Size([1, 60, 200])
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
Log-mel shape: torch.Size([1, 60, 200])
Embedding shape: (1, 192)
Similarity: 0.9922904372215271
Similarity: 0.9639715552330017


