# small model

In [1]:
import torch
print(torch.__version__)

import torchaudio
import torchaudio.transforms as T

import torch.nn as nn
import torch.nn.functional as F

import numpy as np


2.6.0+cu124


In [2]:
# Arguments for the ReDimNet hub entry point.
# (Alternative configuration tried earlier: model_name='B0', train_type='ft_lm', dataset='vox2'.)
model_name = 'b2'  # ~b2
train_type = 'ptn'
dataset = 'vox2'

# Keep the torch.hub cache (repository checkout and weights) under the project directory.
torch.hub.set_dir('/data/proj/voice/redimnet/models')

# Build the model from the IDRnD/ReDimNet hub repository (the cached copy is reused when present).
original_model = torch.hub.load(
    'IDRnD/ReDimNet',
    'ReDimNet',
    model_name=model_name,
    train_type=train_type,
    dataset=dataset,
)

Using cache found in /data/proj/voice/redimnet/models/IDRnD_ReDimNet_master


/data/proj/voice/redimnet/models/IDRnD_ReDimNet_master
load_res : <All keys matched successfully>


In [3]:
# torchinfo's summary() prints a per-layer table of output shapes and parameter counts.
from torchinfo import summary
# Trace the complete model (front-end included) on a dummy input of shape (1, 32000).
summary(original_model, input_size=(1, 32000))

  with torch.cuda.amp.autocast(enabled=False):


Layer (type:depth-idx)                                       Output Shape              Param #
ReDimNetWrap                                                 [1, 192]                  --
├─MelBanks: 1-1                                              [1, 72, 134]              --
│    └─Sequential: 2-1                                       [1, 72, 134]              --
│    │    └─Identity: 3-1                                    [1, 32000]                --
│    │    └─PreEmphasis: 3-2                                 [1, 32000]                --
│    │    └─MelSpectrogram: 3-3                              [1, 72, 134]              --
├─ReDimNet: 1-2                                              [1, 1152, 134]            --
│    └─Sequential: 2-2                                       [1, 1152, 134]            --
│    │    └─Conv2d: 3-4                                      [1, 16, 72, 134]          160
│    │    └─LayerNorm: 3-5                                   [1, 16, 72, 134]          32
│   

The summary shows that the MelSpectrogram front-end is part of the model; let's move it outside the model.


In [4]:
# Show each direct submodule of the wrapper as "name => module repr".
for child_name, child_module in original_model.named_children():
    print(child_name, "=>", child_module)

backbone => ReDimNet(
  (stem): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): LayerNorm(C=(16,), data_format=channels_first, eps=1e-06)
    (2): to1d()
  )
  (stage0): Sequential(
    (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
    (1): to2d(f=72,c=16)
    (2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
    (3): ConvBlock2d(
      (conv_block): ConvNeXtLikeBlock(
        (dwconvs): ModuleList(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=same, groups=4)
        )
        (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): GELU(approximate='none')
        (pwconv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (4): ConvBlock2d(
      (conv_block): ConvNeXtLikeBlock(
        (dwconvs): ModuleList(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=same, groups=4)
        )
        (norm): BatchNorm2d(16, eps=1

In [5]:
########################################
# 2) Define a Model Class without MelBanks
########################################
import torch
import torch.nn as nn

class ReDimNetNoMel(nn.Module):
    """
    The original ReDimNetWrap with its 'spec' (MelBanks) front-end left out.

    Only the 'backbone', 'pool', 'bn' and 'linear' submodules of the wrapped
    model are used. They are referenced, not copied, so their parameters stay
    shared with ``original_wrap``.

    The input is expected to be a precomputed mel spectrogram of shape
    [B, 1, n_mels, time_frames].
    """

    def __init__(self, original_wrap):
        super().__init__()
        # Register the reused submodules in the order forward() applies them.
        for part in ("backbone", "pool", "bn", "linear"):
            setattr(self, part, getattr(original_wrap, part))

    def forward(self, x):
        """Apply backbone -> pool -> bn -> linear to ``x`` and return the result.

        With the model loaded in this notebook the result has shape [B, 192].
        """
        pipeline = (self.backbone, self.pool, self.bn, self.linear)
        for stage in pipeline:
            x = stage(x)
        return x


# Wrap the loaded model so that it consumes mel spectrograms directly (no MelBanks front-end).
model_no_mel = ReDimNetNoMel(original_model)
# Switch to eval mode; as the cell's last expression this also displays the module tree.
model_no_mel.eval()



ReDimNetNoMel(
  (backbone): ReDimNet(
    (stem): Sequential(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): LayerNorm(C=(16,), data_format=channels_first, eps=1e-06)
      (2): to1d()
    )
    (stage0): Sequential(
      (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
      (1): to2d(f=72,c=16)
      (2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
      (3): ConvBlock2d(
        (conv_block): ConvNeXtLikeBlock(
          (dwconvs): ModuleList(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=same, groups=4)
          )
          (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): GELU(approximate='none')
          (pwconv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (4): ConvBlock2d(
        (conv_block): ConvNeXtLikeBlock(
          (dwconvs): ModuleList(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=

In [6]:
# Run only the model's own front-end ('spec') on a random waveform to see
# the feature shape that it produces for 32000 samples.
mel = original_model.spec(torch.randn(1, 32000))  # input: fake waveform
print(mel.shape)

torch.Size([1, 72, 134])


In [7]:
# Random stand-in for a mel spectrogram, laid out as [B, 1, n_mels, time_frames];
# 72 x 134 matches the front-end output shape printed in the previous cell.
example_input = torch.randn(1, 1, 72, 134)
summary(model_no_mel, input_data=example_input)

Layer (type:depth-idx)                                       Output Shape              Param #
ReDimNetNoMel                                                [1, 192]                  --
├─ReDimNet: 1-1                                              [1, 1152, 134]            --
│    └─Sequential: 2-1                                       [1, 1152, 134]            --
│    │    └─Conv2d: 3-1                                      [1, 16, 72, 134]          160
│    │    └─LayerNorm: 3-2                                   [1, 16, 72, 134]          32
│    │    └─to1d: 3-3                                        [1, 1152, 134]            --
│    └─Sequential: 2-2                                       [1, 1152, 134]            --
│    │    └─weigth1d: 3-4                                    [1, 1152, 134]            (1)
│    │    └─to2d: 3-5                                        [1, 16, 72, 134]          --
│    │    └─Conv2d: 3-6                                      [1, 16, 72, 134]          272
│ 

## Utility Function for WAV -> MelSpectrogram

In [8]:

def waveform_to_logmel(
    waveform: torch.Tensor,
    sample_rate: int = 16000,
    n_fft: int = 512,
    hop_length: int = 160,
    n_mels: int = 72,       ## 72 for vox2 ;  60 for B0
    f_min: float = 20.0,
    f_max: float = 8000.0,
    preemphasis_alpha: float = 0.97
) -> torch.Tensor:
    """
    Convert a waveform into a log-mel spectrogram.

    Steps: peak normalization, first-order pre-emphasis, power mel spectrogram
    (torchaudio ``MelSpectrogram`` with ``center=False``) and ``log(x + 1e-6)``.

    NOTE(review): this is intended to mirror the model's 'MelBanks' front-end,
    but the two have not been checked against each other. In this notebook
    ``original_model.spec`` returns 134 frames for 32000 samples, whereas the
    defaults below yield 197 frames for that length, and the torchinfo summary
    lists an ``Identity`` (not a normalization layer) as the first front-end
    step. Confirm the parameters against ``original_model.spec`` before
    relying on the resulting embeddings.

    Args:
        waveform: Audio samples with time on the last axis, e.g. [T],
            [channels, T] or [B, channels, T].
        sample_rate: Sample rate of ``waveform`` in Hz.
        n_fft: FFT size (also the window length, torchaudio's default).
        hop_length: Hop between successive frames, in samples.
        n_mels: Number of mel bins; must match what the model expects.
        f_min: Lowest mel filterbank frequency in Hz.
        f_max: Highest mel filterbank frequency in Hz.
        preemphasis_alpha: Pre-emphasis coefficient.

    Returns:
        Log-mel tensor of shape [..., n_mels, time_frames], where ``...`` are
        the leading dimensions of ``waveform``.

    Raises:
        ValueError: If ``waveform`` contains no samples.
    """
    if waveform.numel() == 0:
        raise ValueError("waveform is empty")

    # 1) Peak normalization; the epsilon keeps all-zero input from dividing by zero.
    waveform = waveform / (waveform.abs().max() + 1e-8)

    # 2) Pre-emphasis y[t] = x[t] - alpha * x[t-1], applied along the time (last) axis
    #    so that 1-D and batched inputs work as well as [channels, T].
    shifted = torch.roll(waveform, shifts=1, dims=-1)
    waveform_preemph = waveform - preemphasis_alpha * shifted
    # torch.roll wraps the final sample round to t=0; restore the true first sample.
    waveform_preemph[..., 0] = waveform[..., 0]

    # 3) Power mel spectrogram, created on the same device as the audio.
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
        power=2.0,
        center=False
    ).to(waveform_preemph.device)
    mel_spec = mel_transform(waveform_preemph)

    # 4) Log scale; the offset avoids log(0) on silent frames.
    log_mel = torch.log(mel_spec + 1e-6)
    return log_mel

In [9]:
def example_inference(wav_path: str, model=None, target_sample_rate: int = 16000):
    """
    Compute a speaker embedding for one audio file with the PyTorch model.

    The audio is downmixed to mono, resampled to ``target_sample_rate`` when
    the file uses a different rate, converted with ``waveform_to_logmel`` and
    passed through the model without gradient tracking.

    Args:
        wav_path: Path of the audio file to load.
        model: Module that maps [B, 1, n_mels, time_frames] to embeddings.
            Defaults to the notebook-level ``model_no_mel``.
        target_sample_rate: Rate in Hz the features are computed at. 16000 is
            the default rate of ``waveform_to_logmel``.
            NOTE(review): confirm it is the rate the checkpoint was trained on.

    Returns:
        The embedding tensor produced by the model ([1, 192] in this notebook).
    """
    if model is None:
        model = model_no_mel

    # (a) Load audio => [channels, time]; average the channels if there are several.
    waveform, sr = torchaudio.load(wav_path)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Bring the audio to the expected rate instead of silently computing
    # features at whatever rate the file happens to use.
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
        waveform = resampler(waveform)
        sr = target_sample_rate

    # (b) Convert to log-mel => [1, n_mels, time_frames]
    log_mel = waveform_to_logmel(waveform, sample_rate=sr)

    # (c) Add the batch dimension => [B=1, 1, n_mels, time_frames]
    log_mel = log_mel.unsqueeze(0)
    print("log_mel shape:", log_mel.shape)

    # (d) Forward pass without autograd bookkeeping
    with torch.no_grad():
        embedding = model(log_mel)

    print("Embedding shape:", embedding.shape)
    return embedding

In [10]:
# Compute similarity between two embeddings
def cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two embedding tensors as a Python float.

    The similarity is computed by ``F.cosine_similarity`` with its default
    arguments; ``.item()`` requires the result to hold exactly one value,
    as it does for two [1, D] embeddings.
    """
    similarity = F.cosine_similarity(embedding1, embedding2)
    return similarity.item()

def cosine_similarity_numpys(emb1: np.ndarray, emb2: np.ndarray) -> float:
    """
    Cosine similarity of two NumPy embeddings.

    Both inputs are flattened first, so shapes (D,) and (1, D) are treated
    alike. A small epsilon is added to the product of the norms so that
    zero-length vectors give 0 instead of a division by zero.
    """
    vec_a = emb1.ravel()
    vec_b = emb2.ravel()

    # Product of the two Euclidean norms, padded with the epsilon.
    denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b) + 1e-8
    return np.dot(vec_a, vec_b) / denominator


In [11]:
# Embed three recordings with the PyTorch model (mel-free wrapper).
embed1 = example_inference("test00.wav")
embed2 = example_inference("test01.wav")
embed3 = example_inference("test2.wav")

# Compare the first recording with each of the other two.
print(f"Similarity: {cosine_similarity(embed1, embed2)}")
print(f"Similarity: {cosine_similarity(embed1, embed3)}")


log_mel shape: torch.Size([1, 1, 72, 219])
Embedding shape: torch.Size([1, 192])
log_mel shape: torch.Size([1, 1, 72, 200])
Embedding shape: torch.Size([1, 192])
log_mel shape: torch.Size([1, 1, 72, 5057])




Embedding shape: torch.Size([1, 192])
Similarity: 0.868979275226593
Similarity: 0.49476510286331177


## Store (export to ONNX)

In [12]:
import onnx

def export_to_onnx(model, onnx_path="ReDimNet_no_mel.onnx"):
    """
    Export ``model`` to an ONNX file with dynamic batch and time axes.

    Args:
        model: Module that takes a log-mel tensor of shape
            [B, 1, n_mels, time_frames]. It is put into eval mode before export.
        onnx_path: Destination path of the ONNX file.

    The exported graph names its input "log_mel" and its output "embedding".
    Axis 0 (batch) of both and axis 3 (time frames) of the input are dynamic;
    the number of mel bins is fixed by the dummy input (72).
    """
    model.eval()

    # Example input [B=1, 1, n_mels=72, time_frames=200]; only used for tracing,
    # batch size and frame count stay variable through dynamic_axes below.
    dummy_input = torch.randn(1, 1, 72, 200)

    # Export the model that was passed in, to the path that was passed in
    # (not the notebook-level model and a fixed file name).
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["log_mel"],
        output_names=["embedding"],
        opset_version=19,
        dynamic_axes={
            "log_mel": {0: "batch_size", 3: "time_frames"},
            "embedding": {0: "batch_size"}
        }
    )
    print("Exported to", onnx_path)

# Export the mel-free wrapper, using the default output path.
export_to_onnx(model_no_mel)

Exported to ReDimNet_no_mel.onnx


In [13]:
!ls -lah ReDimNet_no_mel.onnx

-rw-rw-r-- 1 vlad vlad 20M May 30 17:30 ReDimNet_no_mel.onnx


## Verify the exported ONNX model

In [14]:
import onnx
# Reload the exported file and run ONNX's model checker on it.
onnx_model = onnx.load("ReDimNet_no_mel.onnx")
# check_model raises on an invalid model, so the message below is only printed on success.
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")


ONNX model is valid!


In [15]:
def waveform_to_logmel(
    waveform: torch.Tensor,
    sample_rate: int = 16000,
    n_fft: int = 512,
    hop_length: int = 160,
    n_mels: int = 72,         # match whatever your model expects
    f_min: float = 20.0,
    f_max: float = 8000.0,
    preemphasis_alpha: float = 0.97
) -> torch.Tensor:
    """
    Convert a waveform into a log-mel spectrogram.

    This cell re-defines the helper of the same name from earlier in the
    notebook so that the ONNX section can be run on its own.

    Steps: peak normalization, first-order pre-emphasis, power mel spectrogram
    (torchaudio ``MelSpectrogram`` with ``center=False``) and ``log(x + 1e-6)``.

    NOTE(review): this is intended to mirror the model's 'MelBanks' front-end,
    but the two have not been checked against each other. In this notebook
    ``original_model.spec`` returns 134 frames for 32000 samples, whereas the
    defaults below yield 197 frames for that length, and the torchinfo summary
    lists an ``Identity`` (not a normalization layer) as the first front-end
    step. Confirm the parameters against ``original_model.spec`` before
    relying on the resulting embeddings.

    Args:
        waveform: Audio samples with time on the last axis, e.g. [T],
            [channels, T] or [B, channels, T].
        sample_rate: Sample rate of ``waveform`` in Hz.
        n_fft: FFT size (also the window length, torchaudio's default).
        hop_length: Hop between successive frames, in samples.
        n_mels: Number of mel bins; must match what the model expects.
        f_min: Lowest mel filterbank frequency in Hz.
        f_max: Highest mel filterbank frequency in Hz.
        preemphasis_alpha: Pre-emphasis coefficient.

    Returns:
        Log-mel tensor of shape [..., n_mels, time_frames], where ``...`` are
        the leading dimensions of ``waveform``.

    Raises:
        ValueError: If ``waveform`` contains no samples.
    """
    if waveform.numel() == 0:
        raise ValueError("waveform is empty")

    # 1) Peak normalization; the epsilon keeps all-zero input from dividing by zero.
    waveform = waveform / (waveform.abs().max() + 1e-8)

    # 2) Pre-emphasis y[t] = x[t] - alpha * x[t-1], applied along the time (last) axis
    #    so that 1-D and batched inputs work as well as [channels, T].
    shifted = torch.roll(waveform, shifts=1, dims=-1)
    waveform_preemph = waveform - preemphasis_alpha * shifted
    # torch.roll wraps the final sample round to t=0; restore the true first sample.
    waveform_preemph[..., 0] = waveform[..., 0]

    # 3) Power mel spectrogram, created on the same device as the audio.
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
        power=2.0,
        center=False
    ).to(waveform_preemph.device)
    mel_spec = mel_transform(waveform_preemph)

    # 4) Log scale; the offset avoids log(0) on silent frames.
    log_mel = torch.log(mel_spec + 1e-6)
    return log_mel

In [16]:
import onnxruntime as ort

def run_inference_onnx(onnx_path, wav_path, target_sample_rate: int = 16000):
    """
    Loads an audio file, converts to log-mel, and runs inference
    in an ONNX session. Returns the embedding as a NumPy array.

    Args:
        onnx_path: Path of the exported ONNX model.
        wav_path: Path of the audio file to embed.
        target_sample_rate: Rate in Hz the features are computed at; audio
            stored at another rate is resampled first. 16000 is the default
            rate of ``waveform_to_logmel``.
            NOTE(review): confirm it is the rate the checkpoint was trained on.

    Returns:
        The first output of the session as a NumPy array
        (shape (1, 192) for the model exported in this notebook).

    Note: the model file is re-checked and a new session is created on every
    call, so a re-exported file is always picked up.
    """
    #######################################
    # 1) Load your ONNX model
    #######################################
    # Validate the file before handing it to the runtime.
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print(f"Loaded and checked ONNX model from: {onnx_path}")

    # Create an inference session
    session = ort.InferenceSession(onnx_path)

    # The exported graph has one input and one output; look their names up.
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    #######################################
    # 2) Load audio, get log-mel
    #######################################
    waveform, sr = torchaudio.load(wav_path)
    # If multi-channel, downmix:
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Bring the audio to the expected rate instead of silently computing
    # features at whatever rate the file happens to use.
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
        waveform = resampler(waveform)
        sr = target_sample_rate

    log_mel = waveform_to_logmel(waveform, sample_rate=sr)
    # Insert a batch dimension => [1, 1, n_mels, time_frames]
    log_mel = log_mel.unsqueeze(0)

    #######################################
    # 3) ONNX Inference
    #######################################
    # ONNX Runtime takes NumPy arrays, not torch tensors.
    log_mel_np = log_mel.cpu().numpy()
    outputs = session.run([output_name], {input_name: log_mel_np})
    # session.run returns a list with one entry per requested output.
    embedding = outputs[0]

    print("Embedding shape:", embedding.shape)
    return embedding

In [17]:
# Path of the model file written by the export step.
onnx_model_path = "ReDimNet_no_mel.onnx"

# Embed the same three recordings, this time through ONNX Runtime.
# Note: this rebinds embed1..embed3 (torch tensors in the PyTorch section) to NumPy arrays.
embed1 = run_inference_onnx(onnx_model_path, "test00.wav")
embed2 = run_inference_onnx(onnx_model_path, "test01.wav")
embed3 = run_inference_onnx(onnx_model_path, "test2.wav")

# Same comparisons as in the PyTorch section, using the NumPy helper.
print(f"Similarity: {cosine_similarity_numpys(embed1, embed2)}")
print(f"Similarity: {cosine_similarity_numpys(embed1, embed3)}")


Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
Embedding shape: (1, 192)
Loaded and checked ONNX model from: ReDimNet_no_mel.onnx
Embedding shape: (1, 192)
Similarity: 0.8689795136451721
Similarity: 0.49476543068885803


## Calibration data


* Placeholder (fake) calibration data; real calibration samples are TBD.

In [18]:
import os

import numpy as np

# Placeholder calibration set: random float32 tensors laid out like the model
# input [B, 1, n_mels, time_frames]. Real calibration samples are still TBD.
CAL_DIR = 'cal'
N_CAL_SAMPLES = 100
# n_mels is 72 to match the exported model (its mel axis is fixed at 72);
# the time axis is dynamic in the export, 134 is just an example length.
CAL_SHAPE = (1, 1, 72, 134)

# np.save and open() do not create missing directories.
os.makedirs(CAL_DIR, exist_ok=True)

sample_names = [f'sample_{i}.npy' for i in range(N_CAL_SAMPLES)]
for sample_name in sample_names:
    dummy = np.random.rand(*CAL_SHAPE).astype('float32')
    np.save(os.path.join(CAL_DIR, sample_name), dummy)

# Index file: one sample file name per line, relative to CAL_DIR.
with open(os.path.join(CAL_DIR, 'dataset.txt'), 'w') as f:
    f.writelines(f'{sample_name}\n' for sample_name in sample_names)