In [1]:
import torch
print(torch.__version__)

import torchaudio
import torchaudio.transforms as T
import torch
import torch.nn.functional as F

2.6.0+cu124


## Test files
 * `espeak "Hello, this is a test." -w test00.wav`
 * `espeak "this is the same voice" -w test01.wav`
 * `test2.wav` - a random recording taken from the internet.

In [2]:


# Checkpoint selector forwarded to the hub entry point as keyword arguments.
model_name, train_type, dataset = 'M', 'ft_mix', 'vb2+vox2+cnc'  # 'M': ~b3-b4 size

# Directory where torch.hub keeps the downloaded repository.
# NOTE(review): machine-specific absolute path - adjust when running elsewhere.
torch.hub.set_dir('/data/deep/redimnet/models')

model = torch.hub.load(
    repo_or_dir='IDRnD/ReDimNet',
    model='ReDimNet',
    model_name=model_name,
    train_type=train_type,
    dataset=dataset,
)

# Put the network in evaluation mode; being the last expression of the cell,
# this call also displays the module structure.
model.eval()

Using cache found in /data/deep/redimnet/models/IDRnD_ReDimNet_master


/data/deep/redimnet/models/IDRnD_ReDimNet_master
load_res : <All keys matched successfully>


ReDimNetWrap(
  (backbone): ReDimNet(
    (stem): Sequential(
      (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): LayerNorm(C=(24,), data_format=channels_first, eps=1e-06)
      (2): to1d()
    )
    (stage0): Sequential(
      (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
      (1): to2d(f=72,c=24)
      (2): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1))
      (3): ConvBlock2d(
        (conv_block): ResBasicBlock(
          (conv1): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=6, bias=False)
          (conv1pw): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1))
          (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=6, bias=False)
          (conv2pw): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1))
          (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats

In [None]:
# Optional inspection cell: summarise the loaded `model` with the third-party
# `torchinfo` package. input_size=(1, 32000) is the same (batch, samples) shape
# that the embedding helper in this notebook feeds to the model.
from torchinfo import summary
summary(model, input_size=(1, 32000))

In [3]:
def extract_speaker_embedding(wav_path, target_sample_rate=16000, target_length=32000, model=None):
    """
    Extracts a speaker embedding from a given WAV file using the ReDimNet model.

    The audio is loaded, down-mixed to mono, resampled to `target_sample_rate`
    and then zero-padded or truncated to exactly `target_length` samples
    before it is passed to the model.

    Parameters:
    - wav_path: Path to the WAV file
    - target_sample_rate: Sample rate the model expects (default: 16kHz)
    - target_length: Number of samples the model expects (default: 32000 = 2 sec @ 16kHz)
    - model: Optional model to compute the embedding with. When omitted, the
      notebook-level `model` variable is used, so existing calls behave as before.

    Returns:
    - speaker_embedding: The extracted speaker embedding as a PyTorch tensor

    Raises:
    - NameError: if no `model` is passed and no notebook-level `model` exists.
    """
    # Resolve the model explicitly instead of silently relying on hidden state.
    if model is None:
        try:
            model = globals()["model"]
        except KeyError:
            raise NameError("name 'model' is not defined") from None

    # Load audio file
    waveform, sample_rate = torchaudio.load(wav_path)

    # Convert to mono if needed (average the channels, keep the channel axis)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample if needed
    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    # Ensure the waveform has exactly `target_length` samples
    num_samples = waveform.shape[1]
    if num_samples < target_length:
        # Too short: pad with zeros at the end
        waveform = F.pad(waveform, (0, target_length - num_samples))
    else:
        # Too long (or exact): keep only the first `target_length` samples
        waveform = waveform[:, :target_length]

    # Ensure correct shape (batch_size, num_samples)
    print(f"waveform Sample Shape: {waveform.shape} ; type : {type(waveform)}")

    # Extract speaker embedding; gradients are not needed for inference
    with torch.no_grad():
        speaker_embedding = model(waveform)

    print(f"Speaker Embedding Shape: {speaker_embedding.shape} ; type : {type(speaker_embedding)}")  # Expected: (1, embedding_dim)

    return speaker_embedding


In [4]:
# Compute similarity between two embeddings
def cosine_similarity(embedding1, embedding2):
    """
    Return the cosine similarity of two speaker embeddings as a Python float.

    The similarity is computed along the last (feature) dimension, so both
    `(1, embedding_dim)` tensors and plain 1-D `(embedding_dim,)` vectors
    are accepted.

    Parameters:
    - embedding1: Tensor holding a single embedding
    - embedding2: Tensor holding a single embedding of the same size

    Returns:
    - The cosine similarity as a float

    Note: the result must contain exactly one element; passing a batch of
    several embeddings makes `.item()` fail.
    """
    # dim=-1 equals the default dim=1 for (1, D) inputs and also works for (D,).
    return F.cosine_similarity(embedding1, embedding2, dim=-1).item()


In [5]:

# Embed the three test recordings in order; results map to embed1..embed3.
embed1, embed2, embed3 = [
    extract_speaker_embedding(wav_path=wav_file)
    for wav_file in ('test00.wav', 'test01.wav', 'test2.wav')
]
    

waveform Sample Shape: torch.Size([1, 32000]) ; type : <class 'torch.Tensor'>


  with torch.cuda.amp.autocast(enabled=False):


Speaker Embedding Shape: torch.Size([1, 192]) ; type : <class 'torch.Tensor'>
waveform Sample Shape: torch.Size([1, 32000]) ; type : <class 'torch.Tensor'>
Speaker Embedding Shape: torch.Size([1, 192]) ; type : <class 'torch.Tensor'>
waveform Sample Shape: torch.Size([1, 32000]) ; type : <class 'torch.Tensor'>
Speaker Embedding Shape: torch.Size([1, 192]) ; type : <class 'torch.Tensor'>


In [6]:
# Compare the first recording against each of the other two, in order.
for other_embedding in (embed2, embed3):
    print(f"Similarity: {cosine_similarity(embed1, other_embedding)}")

Similarity: 0.8862158060073853
Similarity: 0.17124216258525848
