In [1]:
# Print the installed PyTorch version first so the outputs below can be tied
# to a specific build.
import torch
print(torch.__version__)

# Audio I/O (torchaudio.load) and transforms (T.Resample) used by
# extract_speaker_embedding further down.
import torchaudio
import torchaudio.transforms as T
# NOTE(review): torch is already imported above; this second import is redundant.
import torch
# F provides pad() and cosine_similarity() for the cells below.
import torch.nn.functional as F

2.6.0+cu124


## Test files
 * `espeak "Hello, this is a test." -w test00.wav`
 * `espeak "this is the same voice" -w test01.wav`

 * `test2.wav` - something random from the internet

In [2]:


# Selectors forwarded to the ReDimNet hub entry point. They are kept as
# notebook-level names so later cells can refer to the chosen configuration.
model_name, train_type, dataset = 'B0', 'ft_lm', 'vox2'

# Point torch.hub's cache at a project directory (the repo checkout ends up
# in a sub-folder of it).
# NOTE(review): this is an absolute, machine-specific path - adjust when
# running elsewhere.
torch.hub.set_dir('/data/proj/voice/redimnet/models')

# Fetch (or reuse from the cache) the ReDimNet repo and build the requested model.
model = torch.hub.load(
    'IDRnD/ReDimNet',
    'ReDimNet',
    model_name=model_name,
    train_type=train_type,
    dataset=dataset,
)

# Put the network in evaluation mode. As the last expression of the cell this
# also displays the module tree.
model.eval()

Using cache found in /data/proj/voice/redimnet/models/IDRnD_ReDimNet_master


/data/proj/voice/redimnet/models/IDRnD_ReDimNet_master
load_res : <All keys matched successfully>


ReDimNetWrap(
  (backbone): ReDimNet(
    (stem): Sequential(
      (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): LayerNorm(C=(10,), data_format=channels_first, eps=1e-06)
      (2): to1d()
    )
    (stage0): Sequential(
      (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
      (1): to2d(f=60,c=10)
      (2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
      (3): ConvBlock2d(
        (conv_block): ResBasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv1pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv2pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_sta

In [3]:
# Layer-by-layer summary (output shapes and parameter counts) of the loaded
# model for a batch of one waveform with 32000 samples.
# NOTE(review): this import lives outside the notebook's import cell; consider
# moving it there so all dependencies are visible in one place.
from torchinfo import summary
summary(model, input_size=(1, 32000))

  with torch.cuda.amp.autocast(enabled=False):


Layer (type:depth-idx)                                       Output Shape              Param #
ReDimNetWrap                                                 [1, 192]                  --
├─MelBanks: 1-1                                              [1, 60, 134]              --
│    └─Sequential: 2-1                                       [1, 60, 134]              --
│    │    └─Identity: 3-1                                    [1, 32000]                --
│    │    └─PreEmphasis: 3-2                                 [1, 32000]                --
│    │    └─MelSpectrogram: 3-3                              [1, 60, 134]              --
├─ReDimNet: 1-2                                              [1, 600, 134]             --
│    └─Sequential: 2-2                                       [1, 600, 134]             --
│    │    └─Conv2d: 3-4                                      [1, 10, 60, 134]          100
│    │    └─LayerNorm: 3-5                                   [1, 10, 60, 134]          20
│   

In [4]:
def extract_speaker_embedding(wav_path, target_sample_rate=16000, target_length=32000):
    """
    Extracts a speaker embedding from a given WAV file using the ReDimNet model.

    The network itself is NOT a parameter: the function calls the notebook-level
    `model` object, which must be defined before this function is called.

    Parameters:
    - wav_path: Path to the WAV file
    - target_sample_rate: Sample rate the model expects (default: 16kHz)
    - target_length: Number of samples the model expects (default: 32000 = 2 sec @ 16kHz)

    Returns:
    - speaker_embedding: The extracted speaker embedding as a PyTorch tensor

    Side effects:
    - Prints progress information (padding/trimming, tensor shapes) to stdout.
    """
    # Load audio file; torchaudio returns (channels, samples) plus the file's sample rate
    waveform, sample_rate = torchaudio.load(wav_path)

    # Convert to mono if needed by averaging the channels (keeps the channel axis)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample if needed
    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    # Ensure the waveform has exactly `target_length` samples.
    # A clip that already has the right length is left untouched (and nothing is
    # reported), instead of being announced as "trimmed".
    num_samples = waveform.shape[1]
    if num_samples < target_length:
        # Pad with trailing zeros if too short.
        # NOTE(review): the zero padding is part of what the model sees; confirm
        # this does not skew embeddings of short clips.
        waveform = F.pad(waveform, (0, target_length - num_samples))
        print(f"Padding waveform to {target_length} samples.")
    elif num_samples > target_length:
        # Keep only the first `target_length` samples if too long
        waveform = waveform[:, :target_length]
        print(f"Trimming waveform to {target_length} samples.")

    # Report the shape handed to the model; after the mono mix-down this is
    # (1, target_length), i.e. (batch_size, num_samples)
    print(f"waveform Sample Shape: {waveform.shape} ; type : {type(waveform)}")

    # Extract speaker embedding; no gradients are needed for inference
    with torch.no_grad():
        speaker_embedding = model(waveform)

    print(f"Speaker Embedding Shape: {speaker_embedding.shape} ; type : {type(speaker_embedding)}")  # Expected: (1, embedding_dim)

    return speaker_embedding


In [5]:
# Compute similarity between two embeddings
def cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two embeddings as a Python float.

    The similarity is taken along dimension 1 and the result is unwrapped with
    `.item()`, so the comparison must yield exactly one value (e.g. two
    (1, D) tensors).
    """
    score = F.cosine_similarity(embedding1, embedding2, dim=1)
    return score.item()


In [6]:

# One embedding per recording, extracted in the listed order:
# two "Rob" clips followed by two "me" clips.
embed1, embed2, embed3, embed4 = (
    extract_speaker_embedding(wav_path=clip)
    for clip in ('testRob1.wav', 'testRob2.wav', 'testme1.wav', 'testme2.wav')
)


Padding waveform to 32000 samples.
waveform Sample Shape: torch.Size([1, 32000]) ; type : <class 'torch.Tensor'>
Speaker Embedding Shape: torch.Size([1, 192]) ; type : <class 'torch.Tensor'>
Padding waveform to 32000 samples.
waveform Sample Shape: torch.Size([1, 32000]) ; type : <class 'torch.Tensor'>
Speaker Embedding Shape: torch.Size([1, 192]) ; type : <class 'torch.Tensor'>
Padding waveform to 32000 samples.
waveform Sample Shape: torch.Size([1, 32000]) ; type : <class 'torch.Tensor'>
Speaker Embedding Shape: torch.Size([1, 192]) ; type : <class 'torch.Tensor'>
Trimming waveform to 32000 samples.
waveform Sample Shape: torch.Size([1, 32000]) ; type : <class 'torch.Tensor'>
Speaker Embedding Shape: torch.Size([1, 192]) ; type : <class 'torch.Tensor'>


In [7]:
# Score the three pairs of interest first, then report them in one place.
sim_robot_robot = cosine_similarity(embed1, embed2)
sim_robot_me = cosine_similarity(embed1, embed3)
sim_me_me = cosine_similarity(embed3, embed4)

print(f"Similarity (robot to robot): {sim_robot_robot}")
print(f"Similarity (robot to me   ): {sim_robot_me}")
print(f"Similarity (me 1 to me 2  ): {sim_me_me}")

Similarity (robot to robot): 0.765476644039154
Similarity (robot to me   ): -0.04124457761645317
Similarity (me 1 to me 2  ): 0.30778083205223083
