This is a noteboook used to generate the speaker embeddings with the  GE2E model.

In [1]:
import sys 
sys.path.insert(0, "../")

In [2]:
from utils.audio_processor import WrapperAudioProcessor as AudioProcessor 
from utils.generic_utils import load_config
import librosa
import os
import numpy as np
import torch
from glob import glob
from tqdm import tqdm

Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit


In [3]:
#Download encoder Checkpoint
#!wget https://github.com/Edresson/GE2E-Speaker-Encoder/releases/download/checkpoints/checkpoint-voicefilter-seungwonpark.pt -O embedder.pt

In [4]:
# speaker_encoder parameters 
num_mels = 40
n_fft = 512
emb_dim = 256
lstm_hidden = 768
lstm_layers = 3
window = 80
stride = 40

checkpoint_dir = "embedder.pt"

In [5]:
import torch
import torch.nn as nn

class LinearNorm(nn.Module):
    def __init__(self, lstm_hidden, emb_dim):
        super(LinearNorm, self).__init__()
        self.linear_layer = nn.Linear(lstm_hidden, emb_dim)

    def forward(self, x):
        return self.linear_layer(x)


class SpeakerEncoder(nn.Module):
    def __init__(self, num_mels, lstm_layers, lstm_hidden, window, stride):
        super(SpeakerEncoder, self).__init__()
        self.lstm = nn.LSTM(num_mels, lstm_hidden,
                            num_layers=lstm_layers,
                            batch_first=True)
        self.proj = LinearNorm(lstm_hidden, emb_dim)
        self.num_mels = num_mels
        self.lstm_layers = lstm_layers
        self.lstm_hidden = lstm_hidden
        self.window = window
        self.stride = stride

    def forward(self, mel):
        # (num_mels, T)
        mels = mel.unfold(1, self.window, self.stride) # (num_mels, T', window)
        mels = mels.permute(1, 2, 0) # (T', window, num_mels)
        x, _ = self.lstm(mels) # (T', window, lstm_hidden)
        x = x[:, -1, :] # (T', lstm_hidden), use last frame only
        x = self.proj(x) # (T', emb_dim)
        x = x / torch.norm(x, p=2, dim=1, keepdim=True) # (T', emb_dim)
        x = x.sum(0) / x.size(0) # (emb_dim), average pooling over time frames
        return x


In [6]:
embedder = SpeakerEncoder(num_mels, lstm_layers, lstm_hidden, window, stride).cuda()
chkpt_embed = torch.load(checkpoint_dir)
embedder.load_state_dict(chkpt_embed)
embedder.eval()

SpeakerEncoder(
  (lstm): LSTM(40, 768, num_layers=3, batch_first=True)
  (proj): LinearNorm(
    (linear_layer): Linear(in_features=768, out_features=256, bias=True)
  )
)

In [7]:
# Set constants
DATA_ROOT_PATH = '../../../LibriSpeech/voicefilter-open-fiel-ao-paper-data/'
TRAIN_DATA = os.path.join(DATA_ROOT_PATH, 'train')
TEST_DATA = os.path.join(DATA_ROOT_PATH, 'test')
glob_re_wav_emb = '*-ref_emb.wav'
glob_re_emb = '*-emb.pt'

In [8]:
# load ap compativel with speaker encoder
config = {"backend":"voicefilter", "mel_spec": False,  "audio_len": 3, 
          "voicefilter":{"n_fft": 1200,"num_mels":40,"num_freq": 601,"sample_rate": 16000,"hop_length": 160,
                         "win_length": 400,"min_level_db": -100.0, "ref_level_db": 20.0, "preemphasis": 0.97,
                         "power": 1.5, "griffin_lim_iters": 60}}
ap = AudioProcessor(config)

In [9]:
os.listdir(TEST_DATA)

['005549-target.wav',
 '004185-target.pt',
 '000492-ref_emb.wav',
 '004125-mixed.wav',
 '001272-mixed.wav',
 '004410-ref_emb.wav',
 '004563-target.pt',
 '005316-target.pt',
 '002196-mixed.pt',
 '005060-target.wav',
 '001233-target.wav',
 '005520-mixed.pt',
 '002855-mixed.pt',
 '005006-mixed.pt',
 '001233-target.pt',
 '005390-mixed.wav',
 '002673-target.pt',
 '004728-ref_emb.wav',
 '004270-mixed.wav',
 '002485-mixed.wav',
 '003695-mixed.wav',
 '002860-ref_emb.wav',
 '003389-mixed.pt',
 '000135-target.pt',
 '002448-target.wav',
 '000128-ref_emb.wav',
 '000328-target.pt',
 '002688-target.pt',
 '001498-target.wav',
 '004841-target.wav',
 '001116-target.wav',
 '002646-target.wav',
 '002876-mixed.wav',
 '001722-ref_emb.wav',
 '002453-ref_emb.wav',
 '000378-ref_emb.wav',
 '004537-ref_emb.wav',
 '000542-mixed.pt',
 '001553-mixed.pt',
 '004414-mixed.pt',
 '002534-mixed.pt',
 '001932-mixed.pt',
 '002633-target.pt',
 '003422-ref_emb.wav',
 '001864-target.pt',
 '000666-ref_emb.wav',
 '001519-mixed

In [12]:
#Preprocess dataset
train_files = sorted(glob(os.path.join(TRAIN_DATA, glob_re_wav_emb)))
test_files = sorted(glob(os.path.join(TEST_DATA, glob_re_wav_emb)))

if len(train_files) == 0 or len(test_files):
    print("check train and test path files not in directory")
files  = train_files+test_files

for i in tqdm(range(len(files))):
    try:
        wave_file_path = files[i]
        wav_file_name = os.path.basename(wave_file_path)
        # Extract Embedding
        with open(wave_file_path, 'r') as f:
            LB_wave_file_path = f.readline().strip()
        emb_wav, _ = librosa.load(os.path.join('../',LB_wave_file_path), sr=16000)
        mel = torch.from_numpy(ap.get_mel(emb_wav)).cuda()
        #print(mel.shape)
        file_embedding = embedder(mel).cpu().detach().numpy()
    except:
        # if is not possible extract embedding because wav lenght is very small
        file_embedding = np.array([0]) # its make a error in training
        print("Embedding reference is very sort")
    output_name = wave_file_path.replace(glob_re_wav_emb.replace('*',''),'')+glob_re_emb.replace('*','')
    torch.save(torch.from_numpy(file_embedding.reshape(-1)), output_name)

  0%|          | 2/259819 [00:00<4:37:07, 15.63it/s]

check train and test path files not in directory


  2%|▏         | 4812/259819 [04:44<4:11:11, 16.92it/s]


KeyboardInterrupt: 

In [14]:
#file_embedding

array([-9.30385757e-03,  5.41789979e-02,  9.06330049e-02,  6.22695163e-02,
        2.20813379e-02,  8.56699888e-04, -2.18634997e-02, -5.71655575e-03,
        2.30765045e-02,  6.84375130e-03,  5.21903634e-02,  2.06298940e-02,
       -1.68322830e-03,  2.84168012e-02, -4.99293301e-03, -2.33015195e-02,
        2.07366832e-02, -4.94350716e-02,  6.57151267e-02,  3.51618789e-02,
       -4.54175659e-02, -2.13632826e-02, -2.04372802e-03, -3.49387787e-02,
       -5.92319779e-02,  5.97215481e-02,  3.63193341e-02, -1.33465361e-02,
       -4.25689854e-03,  8.81210715e-03, -9.37813297e-02,  2.60864999e-02,
       -7.97466747e-03, -2.80721323e-03, -6.61637560e-02, -6.85963454e-03,
        1.76966004e-02, -2.56214738e-02, -9.70787834e-03, -4.84729782e-02,
        7.53615517e-03,  1.79857463e-02, -1.18213892e-01, -4.07752581e-02,
       -9.56033915e-02, -1.29147740e-02,  6.45934939e-02, -1.45624071e-01,
       -6.98695146e-03, -6.02206700e-02,  3.54489870e-02, -4.85779420e-02,
       -1.34030664e-02, -