# ReDimNet: reference model vs. NoMel (mel front-end removed)

In [1]:
import torch
print(torch.__version__)

import torchaudio
import torchaudio.transforms as T
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import copy

2.6.0+cu124


# orig Model

In [2]:
# Which pretrained ReDimNet checkpoint to fetch through torch.hub.
model_name='B0'
# Alternative train_type value, kept for reference:
# train_type='ft_lm'
train_type='ptn'
dataset='vox2'

# Hub cache location. NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
torch.hub.set_dir('/data/proj/voice/redimnet/models')

# Load the model from the IDRnD/ReDimNet hub repository with the settings above.
model = torch.hub.load('IDRnD/ReDimNet', 'ReDimNet', 
                       model_name=model_name, 
                       train_type=train_type, 
                       dataset=dataset)
# Inference mode; as the last expression of the cell this also displays the module tree.
model.eval()

Using cache found in /data/proj/voice/redimnet/models/IDRnD_ReDimNet_master


/data/proj/voice/redimnet/models/IDRnD_ReDimNet_master
load_res : <All keys matched successfully>


ReDimNetWrap(
  (backbone): ReDimNet(
    (stem): Sequential(
      (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): LayerNorm(C=(10,), data_format=channels_first, eps=1e-06)
      (2): to1d()
    )
    (stage0): Sequential(
      (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
      (1): to2d(f=60,c=10)
      (2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
      (3): ConvBlock2d(
        (conv_block): ResBasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv1pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv2pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_sta

In [3]:
# Layer-by-layer summary for a (1, 32000) input, i.e. one clip of 32000 samples
# (2 s at the 16 kHz rate used throughout this notebook).
from torchinfo import summary
summary(model, input_size=(1, 32000))

  with torch.cuda.amp.autocast(enabled=False):


Layer (type:depth-idx)                                       Output Shape              Param #
ReDimNetWrap                                                 [1, 192]                  --
├─MelBanks: 1-1                                              [1, 60, 134]              --
│    └─Sequential: 2-1                                       [1, 60, 134]              --
│    │    └─Identity: 3-1                                    [1, 32000]                --
│    │    └─PreEmphasis: 3-2                                 [1, 32000]                --
│    │    └─MelSpectrogram: 3-3                              [1, 60, 134]              --
├─ReDimNet: 1-2                                              [1, 600, 134]             --
│    └─Sequential: 2-2                                       [1, 600, 134]             --
│    │    └─Conv2d: 3-4                                      [1, 10, 60, 134]          100
│    │    └─LayerNorm: 3-5                                   [1, 10, 60, 134]          20
│   

## Waveform preprocessing and reference embedding

In [4]:
def extract_speaker_embedding(wav_path, target_sample_rate=16000, target_length=32000):
    """
    Extract a speaker embedding from a WAV file with the notebook-level ``model``.

    The audio is down-mixed to mono, resampled to ``target_sample_rate`` and
    zero-padded or truncated to exactly ``target_length`` samples before it is
    fed to ``model``. Note that ``model`` is read from the notebook's global
    namespace; it is not a parameter of this function.

    Parameters:
    - wav_path: Path to the WAV file
    - target_sample_rate: Sample rate the model expects (default: 16kHz)
    - target_length: Number of samples the model expects (default: 32000 = 2 sec @ 16kHz)

    Returns:
    - speaker_embedding: The tensor returned by ``model(waveform)``
      (shape (1, 192) for the model loaded in this notebook)
    """
    # Load audio file -> (channels, samples)
    waveform, sample_rate = torchaudio.load(wav_path)

    # Convert to mono if needed by averaging the channels
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample if needed
    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    # Ensure the waveform has exactly `target_length` samples.
    # A clip that already has the right length is left alone and nothing is
    # reported (previously it was announced as "Trimming").
    num_samples = waveform.shape[1]
    if num_samples < target_length:
        # Pad with zeros on the right if too short
        waveform = F.pad(waveform, (0, target_length - num_samples))
        print(f"Padding waveform to {target_length} samples.")
    elif num_samples > target_length:
        # Keep only the leading samples if too long
        waveform = waveform[:, :target_length]
        print(f"Trimming waveform to {target_length} samples.")

    # Shape is now (1, target_length)
    print(f"waveform Sample Shape: {waveform.shape} ; type : {type(waveform)}")

    # Extract speaker embedding without tracking gradients
    with torch.no_grad():
        speaker_embedding = model(waveform)

    print(f"Speaker Embedding Shape: {speaker_embedding.shape} ; type : {type(speaker_embedding)}")  # Expected: (1, embedding_dim)

    return speaker_embedding


In [5]:
# Compute similarity between two embeddings
def cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two embeddings as a Python float.

    The similarity is taken along dim 1, so both inputs are expected to be
    single-row batches such as (1, D); ``.item()`` requires exactly one value.
    """
    score = F.cosine_similarity(embedding1, embedding2, dim=1)
    return score.item()


In [6]:
# Reference embedding from the full model (mel front-end included).
# With the default target_length the clip is padded/trimmed to 32000 samples before embedding.
embed_orig = extract_speaker_embedding(wav_path='test000.wav')




Trimming waveform to 32000 samples.
waveform Sample Shape: torch.Size([1, 32000]) ; type : <class 'torch.Tensor'>
Speaker Embedding Shape: torch.Size([1, 192]) ; type : <class 'torch.Tensor'>


# NoMel model (mel front-end removed)

In [7]:
########################################
# 2) Define a Model Class without MelBanks
########################################
import torch
import torch.nn as nn

class ReDimNetNoMel(nn.Module):
    """
    ReDimNet without its spectrogram front-end.

    Re-uses the ``backbone``, ``pool``, ``bn`` and ``linear`` submodules of an
    existing wrapper (the same module objects, not copies) and leaves out its
    ``spec`` (MelBanks) module, so the caller has to supply the features.

    Input to ``forward``: precomputed mel spectrogram, [B, 1, n_mels, time_frames].
    """
    def __init__(self, original_wrap):
        super().__init__()
        # Assigned left to right, so the registration order stays
        # backbone -> pool -> bn -> linear.
        self.backbone, self.pool, self.bn, self.linear = (
            original_wrap.backbone,
            original_wrap.pool,
            original_wrap.bn,
            original_wrap.linear,
        )

    def forward(self, x):
        """Run backbone -> pooling -> batch-norm -> linear projection on ``x``."""
        features = self.backbone(x)
        pooled = self.pool(features)
        return self.linear(self.bn(pooled))


# Create an instance of our new model that skips the MelBanks front-end.
# The wrapper stores references to `model`'s own submodules, so the two models share
# weights and eval() below also applies to those shared submodules.
model_no_mel = ReDimNetNoMel(model)
model_no_mel.eval()



ReDimNetNoMel(
  (backbone): ReDimNet(
    (stem): Sequential(
      (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): LayerNorm(C=(10,), data_format=channels_first, eps=1e-06)
      (2): to1d()
    )
    (stage0): Sequential(
      (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
      (1): to2d(f=60,c=10)
      (2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
      (3): ConvBlock2d(
        (conv_block): ResBasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv1pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv2pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_st

### inspect spec

In [8]:
# Check that model_no_mel reproduces the full model when it is fed the features
# produced by the model's own front-end (model.spec).
with torch.no_grad():
    # Load the clip and resample to 16 kHz if necessary (no mono down-mix is done here).
    wav, sr = torchaudio.load('test000.wav')
    if sr != 16000:
        wav = T.Resample(sr, 16000)(wav)
    # Reference: waveform through the complete model.
    emb_orig  = model(wav)
    
    # Reference features from the model's own front-end module.
    mel_ref = model.spec(wav)           # <-- COMPLETE pipeline
    print("wav shape:", wav.shape)
    print("mel_ref shape:", mel_ref.shape)  
    print("emb_orig shape:", emb_orig.shape) 
    
    print("===============================")
    # Insert the channel dimension expected by model_no_mel: [B, n_mels, frames] -> [B, 1, n_mels, frames].
    mel_ref_for_nomel = mel_ref.unsqueeze(1)
    emb_nomel = model_no_mel(mel_ref_for_nomel)
    print("mel for no mel shape:", mel_ref_for_nomel.shape)  # should be [1, 1, n_mels, time_frames]
    print("emb_nomel shape:", emb_nomel.shape)  # should be same as emb_orig
    
    # Same features and same weights on both paths, so this is expected to be ~1.
    print(F.cosine_similarity(emb_nomel, emb_orig))   # should be 0.999-ish

wav shape: torch.Size([1, 293699])
mel_ref shape: torch.Size([1, 60, 1224])
emb_orig shape: torch.Size([1, 192])
mel for no mel shape: torch.Size([1, 1, 60, 1224])
emb_nomel shape: torch.Size([1, 192])
tensor([1.0000])


In [9]:
# ------------------------------------------------------------------
#  ReDimNet front-end settings (taken from the IDRnD repo defaults)
#    • 16 kHz audio
#    • pre-emphasis α = 0.97
#    • 512-point FFT
#    • 25 ms window  (400 samples)
#    • 15 ms hop     (240 samples)  ➜ 134 frames for a 2-s clip
#    • 60 Mel bins, 20 Hz → 7.6 kHz
# ------------------------------------------------------------------
_PREEMPH  = 0.97
_SR       = 16_000
_N_FFT    = 512
_WIN_LEN  = 400
_HOP      = 240
_N_MELS   = 60
_F_MIN    = 20.0
_F_MAX    = 7600.0
_EPS      = 1e-6            # offset added to the mel power before the log (same value as used in redimnet_logmel)


# Module-level MelSpectrogram instance, built once and reused by redimnet_logmel
_mel_layer = T.MelSpectrogram(
    sample_rate=_SR,
    n_fft=_N_FFT,
    win_length=_WIN_LEN,
    hop_length=_HOP,
    f_min=_F_MIN,
    f_max=_F_MAX,
    n_mels=_N_MELS,
    power=2.0,               # power spectrogram; the (natural) log is applied later in redimnet_logmel
    center=True,
    pad_mode="reflect",
    window_fn=torch.hamming_window
)

def _pre_emphasis(wave: torch.Tensor, alpha: float = _PREEMPH) -> torch.Tensor:
    """First-order pre-emphasis along dim 1: out[:, n] = wave[:, n] − α·wave[:, n−1].

    Column 0 has no predecessor and is copied through unchanged. The input
    tensor is left untouched; the result is written into a clone of it.
    """
    current, previous = wave[:, 1:], wave[:, :-1]
    emphasized = wave.clone()
    emphasized[:, 1:] = current - alpha * previous
    return emphasized


@torch.no_grad()
def redimnet_logmel(wave: torch.Tensor) -> torch.Tensor:
    """
    Turn a mono waveform into the mean-normalized log-mel features fed to
    `model_no_mel`.

    Pipeline: pre-emphasis → mel power spectrogram (`_mel_layer`) →
    natural log of (power + `_EPS`) → subtraction of each mel bin's mean over
    time → extra channel dimension.

    Parameters
    ----------
    wave : Tensor [T] | [1, T]
        A single mono waveform, expected at 16 kHz. A 1-D tensor gets a
        leading batch dimension; a 2-D tensor with more than one row is
        rejected.

    Returns
    -------
    log_mel : Tensor [1, 1, 60, frames]
        Features in the layout expected by `model_no_mel`. In this notebook
        they matched `model.spec` exactly on the clip that was checked.

    Raises
    ------
    ValueError
        If `wave` is not 1-D or 2-D, or has more than one row.
    """
    # Normalize the input to shape (1, T)
    if wave.dim() == 1:      # (T,) → (1, T)
        wave = wave.unsqueeze(0)
    elif wave.dim() != 2:
        # _pre_emphasis slices dim 1 as the time axis, so any other rank would
        # be processed along the wrong axis without an error.
        raise ValueError(
            f"Expected a waveform of shape [T] or [1, T]; got {tuple(wave.shape)}."
        )
    elif wave.shape[0] > 1:
        raise ValueError("Input must be mono; got multi-channel tensor.")

    # ➊ pre-emphasis
    wave = _pre_emphasis(wave.float())

    # ➋ Mel power-spectrogram → [1, 60, frames]
    mel = _mel_layer(wave)

    # ➌ natural log, offset by _EPS so that empty bins stay finite
    mel = torch.log(mel + _EPS)

    # ➍ per-bin mean normalization over the time axis
    mel = mel - mel.mean(dim=-1, keepdim=True)

    # ➎ add the dummy channel dim expected by the Conv2d stem
    mel = mel.unsqueeze(1)                # → [1, 1, 60, frames]

    return mel

In [10]:
def example_inference(wav_path: str, target_length=None):
    """
    Embed a WAV file with the mel-free model (`model_no_mel`).

    The audio is down-mixed to mono, resampled to 16 kHz, optionally
    padded/trimmed, converted with `redimnet_logmel` and passed to
    `model_no_mel`. Both `redimnet_logmel` and `model_no_mel` are read from
    the notebook's global namespace.

    Parameters:
    - wav_path: Path to the WAV file
    - target_length: If given, the waveform is zero-padded or truncated to
      exactly this many samples, the same treatment that
      `extract_speaker_embedding` applies, so both paths can be run on
      identical input. The default (None) embeds the whole clip.

    Returns:
    - embedding: The tensor returned by `model_no_mel`
    """
    target_sample_rate = 16000  # Force to 16kHz as per model requirements

    # (a) Load audio -> [channels, time]; average the channels if not mono
    waveform, sample_rate = torchaudio.load(wav_path)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample if needed
    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    # (b) Optionally force a fixed number of samples
    if target_length is not None:
        num_samples = waveform.shape[1]
        if num_samples < target_length:
            waveform = F.pad(waveform, (0, target_length - num_samples))
        else:
            waveform = waveform[:, :target_length]

    # (c) Convert to log-mel
    log_mel = redimnet_logmel(waveform)
    print('feeding logmel shape:', log_mel.shape)

    # (d) Forward pass
    with torch.no_grad():
        embedding = model_no_mel(log_mel)  # [1, 192] for the model used in this notebook

    print("Embedding shape:", embedding.shape)
    return embedding

In [11]:
# Embedding from the mel-free path, computed on the whole clip (no padding/trimming requested).
embed_nomel = example_inference("test000.wav")

feeding logmel shape: torch.Size([1, 1, 60, 1224])
Embedding shape: torch.Size([1, 192])


# compare

In [12]:
# NOTE(review): the two embeddings were not computed from the same samples:
#   embed_orig  <- extract_speaker_embedding: clip cut to 32000 samples
#   embed_nomel <- example_inference: whole clip
# The score below therefore likely reflects the different input durations rather
# than a mismatch between the two front-ends — TODO compare on identical samples.
print(f"Similarity (robot to robot): {cosine_similarity(embed_nomel, embed_orig)}")

Similarity (robot to robot): 0.5167194604873657


In [13]:
# Same waveform through both paths: redimnet_logmel + model_no_mel vs. the full model.
with torch.no_grad():
    wav, sr = torchaudio.load('testRob1.wav')
    if sr != 16000:
        wav = T.Resample(sr, 16000)(wav)

    feats  = redimnet_logmel(wav)      # [1, 1, 60, frames]
    emb1   = model_no_mel(feats)

    emb2   = model(wav)                      # original pipeline

# Identical input on both sides, so a value of ~1 means the two front-ends agree.
print('cos-sim:', F.cosine_similarity(emb1, emb2).item())

cos-sim: 1.0


In [14]:
def embed_nomel_from_file(fname: str) -> torch.Tensor:
    """Load an audio file and embed it through the mel-free path.

    Multi-channel audio is averaged to mono and anything not sampled at
    16 kHz is resampled; the whole clip is then converted with
    `redimnet_logmel` and passed to `model_no_mel` without gradient tracking.
    """
    audio, native_sr = torchaudio.load(fname)

    # stereo → mono
    has_several_channels = audio.shape[0] > 1
    if has_several_channels:
        audio = audio.mean(0, keepdim=True)

    # bring everything to 16 kHz
    if native_sr != 16_000:
        to_16k = T.Resample(native_sr, 16_000)
        audio = to_16k(audio)

    log_mel = redimnet_logmel(audio)
    with torch.no_grad():
        embedding = model_no_mel(log_mel)
    return embedding

# --- sanity check against the original model --------------------------
# Both models in inference mode before comparing.
model.eval();  model_no_mel.eval()

# Full-model side: the raw clip, resampled if needed.
# NOTE(review): unlike embed_nomel_from_file, no mono down-mix is applied here, and this
# forward pass is not wrapped in torch.no_grad().
wav_raw, sr = torchaudio.load("testRob1.wav")
if sr != 16_000:
    wav_raw = T.Resample(sr, 16_000)(wav_raw)

emb_orig  = model(wav_raw)                 # full pipeline
emb_nomel = embed_nomel_from_file("testRob1.wav")

print("cos-sim:", F.cosine_similarity(emb_nomel, emb_orig).item())
# → expected: cos-sim ≈ 1.0 (both paths see the same clip)


cos-sim: 1.0


In [15]:
# Element-wise comparison of the model's own front-end with redimnet_logmel.
# NOTE: `wav` is whatever an earlier cell left in the kernel (last assigned from
# 'testRob1.wav' in In [13]); run that cell first.
mel_ref = model.spec(wav)           # [B, 60, F]
mel_new = redimnet_logmel(wav).squeeze(1)  # drop the channel dim to match mel_ref
print("max |Δ| =", (mel_ref - mel_new).abs().max())


max |Δ| = tensor(0.)
