# ReDimNet: Reference vs. NoMel

In [None]:
import torch
print(torch.__version__)

import torchaudio
import torchaudio.transforms as T
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import copy

# Original model

In [None]:
# Arguments selecting which pretrained ReDimNet checkpoint torch.hub should return.
model_name='B0'
# Alternative train_type kept for reference (currently disabled):
# train_type='ft_lm'
train_type='ptn'
dataset='vox2'

# Point the torch.hub cache at a project directory so the downloaded repo and
# weights are stored (and reused) there.
torch.hub.set_dir('/data/proj/voice/redimnet/models')

# Load the model via the 'ReDimNet' entry point of the IDRnD/ReDimNet hub repo.
model = torch.hub.load('IDRnD/ReDimNet', 'ReDimNet', 
                       model_name=model_name, 
                       train_type=train_type, 
                       dataset=dataset)
# Put the model in evaluation mode; as the last expression of the cell this
# also displays the module in the notebook output.
model.eval()

In [None]:
# Show a per-layer summary of `model` for an input of shape (1, 32000),
# i.e. one waveform of 32000 samples.
from torchinfo import summary
summary(model, input_size=(1, 32000))

## Waveform preprocessing

In [None]:
def extract_speaker_embedding(wav_path, target_sample_rate=16000, target_length=32000):
    """
    Extract a speaker embedding from a WAV file using the module-level ``model``.

    The audio is loaded, down-mixed to mono, resampled to ``target_sample_rate``
    and then zero-padded or truncated to exactly ``target_length`` samples
    before being passed to ``model``.

    Parameters:
    - wav_path: Path to the WAV file
    - target_sample_rate: Sample rate the audio is resampled to (default: 16kHz)
    - target_length: Number of samples fed to the model (default: 32000 = 2 sec @ 16kHz)

    Returns:
    - speaker_embedding: The result of ``model(waveform)``, computed without gradients

    Note:
    - The model is not an argument of this function; the global ``model``
      defined earlier in this file is used.
    """
    # Load audio file; torchaudio returns (waveform, sample_rate)
    waveform, sample_rate = torchaudio.load(wav_path)

    # Convert to mono by averaging channels, keeping the leading dimension
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample only when the file's rate differs from the requested one
    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    # Ensure the waveform has exactly `target_length` samples.
    # Nothing is done (or logged) when the length already matches.
    num_samples = waveform.shape[1]
    if num_samples < target_length:
        # Right-pad with zeros if too short
        waveform = F.pad(waveform, (0, target_length - num_samples))
        print(f"Padding waveform to {target_length} samples.")
    elif num_samples > target_length:
        # Keep only the first `target_length` samples if too long
        waveform = waveform[:, :target_length]
        print(f"Trimming waveform to {target_length} samples.")

    # After the mono step the tensor is (1, target_length), used as (batch, samples)
    print(f"waveform Sample Shape: {waveform.shape} ; type : {type(waveform)}")

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        speaker_embedding = model(waveform)

    # Presumably (1, embedding_dim) -- depends on the loaded checkpoint
    print(f"Speaker Embedding Shape: {speaker_embedding.shape} ; type : {type(speaker_embedding)}")

    return speaker_embedding


In [None]:
# Compute similarity between two embeddings
def cosine_similarity(embedding1, embedding2):
    """
    Return the cosine similarity of two embeddings as a Python float.

    The similarity is computed with ``F.cosine_similarity`` using its default
    arguments; the result must contain a single element for ``.item()`` to
    succeed (e.g. two embeddings of shape (1, D)).
    """
    score = F.cosine_similarity(embedding1, embedding2)
    return score.item()


In [None]:
# Reference embedding: raw waveform of the test file passed through the original `model`.
embed_orig = extract_speaker_embedding(wav_path='test000.wav')




# NOMEL

In [None]:
########################################
# 2) Define a Model Class without MelBanks
########################################
import torch
import torch.nn as nn

class ReDimNetNoMel(nn.Module):
    """
    ReDimNet variant that starts at the backbone instead of the raw waveform.

    Only the ``backbone``, ``pool``, ``bn`` and ``linear`` submodules of the
    given wrapper are reused; its ``spec`` (MelBanks) front-end is left out.
    The submodules are referenced, not copied, so they stay shared with the
    original wrapper.

    The input is presumably a precomputed mel spectrogram of shape
    [B, 1, n_mels, time_frames] -- TODO confirm against the backbone.
    """
    def __init__(self, original_wrap):
        super().__init__()
        # Everything after the feature front-end, in execution order.
        self.backbone = original_wrap.backbone
        self.pool = original_wrap.pool
        self.bn = original_wrap.bn
        self.linear = original_wrap.linear

    def forward(self, x):
        """Run backbone -> pool -> bn -> linear on ``x`` and return the result."""
        features = self.backbone(x)    # likely [B, channels, frames] -- TODO confirm
        pooled = self.pool(features)   # likely one vector per utterance -- TODO confirm
        normalized = self.bn(pooled)
        # Final projection to the embedding size of the loaded checkpoint.
        return self.linear(normalized)


# Create an instance of our new model that skips the MelBanks front-end.
# The wrapper holds references to the submodules of `model` (nothing is
# copied), so both objects share the same weights.
model_no_mel = ReDimNetNoMel(model)
# Evaluation mode for the wrapper; as the last expression of the cell this
# also displays the module in the notebook output.
model_no_mel.eval()



In [None]:
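# TODO: define `redimnet_logmel(waveform)` in this cell. `example_inference`
# below calls it to turn a waveform into the log-mel input for `model_no_mel`.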

In [None]:
def example_inference(wav_path: str, target_sample_rate: int = 16000, target_length=None):
    """
    Compute an embedding for a WAV file via the mel-free model ``model_no_mel``.

    The audio is loaded, down-mixed to mono, resampled to
    ``target_sample_rate``, optionally padded/trimmed to ``target_length``
    samples, converted with ``redimnet_logmel`` and fed to ``model_no_mel``.
    Both ``redimnet_logmel`` and ``model_no_mel`` are module-level names
    defined elsewhere in this file.

    Parameters:
    - wav_path: Path to the WAV file
    - target_sample_rate: Sample rate the audio is resampled to (default: 16kHz)
    - target_length: Optional number of samples. ``None`` (default) keeps the
      full length. Pass the same value as ``extract_speaker_embedding`` uses
      (32000 by default) when the two embeddings are to be compared on
      identical input.

    Returns:
    - embedding: The result of ``model_no_mel(log_mel)``, computed without gradients
    """
    # (a) Load audio; torchaudio returns a [channels, time] tensor
    waveform, sample_rate = torchaudio.load(wav_path)
    # Average the channels if the file is not mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample only when the file's rate differs from the requested one
    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    # Optional length normalisation: right-pad with zeros or keep the first
    # `target_length` samples, mirroring extract_speaker_embedding
    if target_length is not None:
        num_samples = waveform.shape[1]
        if num_samples < target_length:
            waveform = F.pad(waveform, (0, target_length - num_samples))
        elif num_samples > target_length:
            waveform = waveform[:, :target_length]

    # (b) Convert to log-mel
    log_mel = redimnet_logmel(waveform)
    print('feeding logmel shape:', log_mel.shape)

    # (c) Forward pass, inference only
    with torch.no_grad():
        embedding = model_no_mel(log_mel)  # presumably [1, embedding_dim] -- TODO confirm

    print("Embedding shape:", embedding.shape)
    return embedding

In [None]:
# Embedding of the same test file via the mel-free path (log-mel computed outside the model).
embed_nomel = example_inference("test000.wav")

# compare

In [None]:
# Cosine similarity between the mel-free embedding and the reference embedding.
# NOTE(review): extract_speaker_embedding pads/trims the waveform to 32000
# samples by default, while example_inference does not crop the waveform it
# loads. Unless redimnet_logmel crops internally, the two embeddings come from
# inputs of different length -- confirm before reading this as a parity check.
print(f"Similarity (robot to robot): {cosine_similarity(embed_nomel, embed_orig)}")

In [None]:
# Sanity check: feed model_no_mel with features produced by the original
# model's own front-end and compare with the original end-to-end output.
with torch.no_grad():
    # (1) let the original model compute its own f-bank
    # NOTE(review): the file is used exactly as loaded. The sample rate is
    # discarded and channels are not down-mixed, so a multi-channel file would
    # be treated as several batch items -- confirm test000.wav is mono 16 kHz.
    wav, _ = torchaudio.load("test000.wav")
    # NOTE(review): this calls the `torchfbank` submodule of `model.spec`, not
    # `model.spec` itself. If `spec.forward` applies further steps after
    # `torchfbank` (e.g. log or mean normalisation), `mel_ref` is not the
    # tensor the backbone receives inside `model` -- confirm against the
    # ReDimNet source.
    mel_ref = model.spec.torchfbank(wav)          # author's note: [B, 60, 134] -- TODO confirm

    # (2) pass *exactly the same tensor* to the wrapper
    emb_from_nomel = model_no_mel(mel_ref.unsqueeze(1))   # add channel dim at index 1

    # (3) compare with the embedding the original model computes from the same waveform
    emb_from_orig  = model(wav)
    print(F.cosine_similarity(emb_from_nomel, emb_from_orig))