In [1]:
from conf.hydra_config import (
    TrainingUnitEncoderConfig_STEP1,
)
import logging
import torch


cfg = TrainingUnitEncoderConfig_STEP1

# Use CUDA only when a GPU is present *and* the config enables GPU usage.
_gpu_requested = torch.cuda.is_available() and cfg.train.on_GPU
device = torch.device("cuda") if _gpu_requested else torch.device("cpu")

# Lower the root logger's threshold so DEBUG records are not filtered out.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Normalization of Mel-Spectrogram

# Speaker embeddings

In [16]:
from unitspeech.speaker_encoder.ecapa_tdnn import ECAPA_TDNN


# Build the pre-trained speaker encoder in inference mode on `device`.
spkr_embedder = ECAPA_TDNN(
    feat_dim=cfg.spkr_embedder.feat_dim,
    channels=cfg.spkr_embedder.channels,
    emb_dim=cfg.spkr_embedder.spk_emb_dim,
    feat_type=cfg.spkr_embedder.feat_type,
    sr=cfg.spkr_embedder.sr,
    feature_selection=cfg.spkr_embedder.feature_selection,
    update_extract=cfg.spkr_embedder.update_extract,
    config_path=cfg.spkr_embedder.config_path,
).to(device).eval()

# Deserialize the checkpoint on the CPU; load_state_dict copies the values into
# the model's parameters, wherever they live. (A map_location callable receives
# (storage, location), so returning the first argument is the same as "cpu".)
state_dict = torch.load(cfg.spkr_embedder.checkpoint, map_location="cpu")

# strict=False is needed because the checkpoint carries entries the embedder
# does not use (e.g. `loss_calculator.projection.weight`). A *missing* key,
# however, would silently leave that layer at its initial values - fail loudly.
load_result = spkr_embedder.load_state_dict(state_dict["model"], strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"Speaker-embedder checkpoint is missing weights: {load_result.missing_keys}"
    )
print(f"Ignored checkpoint keys not used by the model: {load_result.unexpected_keys}")

Using cache found in /home/astanea/.cache/torch/hub/s3prl_s3prl_main
  from .autonotebook import tqdm as notebook_tqdm
2024-04-17 21:07:29 | INFO | s3prl.util.download | Requesting URL: https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
2024-04-17 21:07:29 | INFO | s3prl.util.download | Using URL's local file: /home/astanea/.cache/s3prl/download/f2d5200177fd6a33b278b7b76b454f25cd8ee866d55c122e69fccf6c7467d37d.wavlm_large.pt
2024-04-17 21:07:36 | INFO | s3prl.upstream.wavlm.WavLM | WavLM Config: {'extractor_mode': 'layer_norm', 'encoder_layers': 24, 'encoder_embed_dim': 1024, 'encoder_ffn_embed_dim': 4096, 'encoder_attention_heads': 16, 'activation_fn': 'gelu', 'layer_norm_first': True, 'conv_feature_layers': '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2', 'conv_bias': False, 'feature_grad_mult': 1.0, 'normalize': True, 'dropout': 0.0, 'attention_dropout': 0.0, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.0, 'dropout_input': 0.0, 'dropout_features': 0.0, '

_IncompatibleKeys(missing_keys=[], unexpected_keys=['loss_calculator.projection.weight'])

In [17]:
from unitspeech.util import parse_filelist
import os
import librosa
import torchaudio
from functools import lru_cache


@lru_cache(maxsize=None)
def get_resampler(orig_sr, target_sr, device="cuda"):
    """Return a torchaudio resampler from ``orig_sr`` to ``target_sr`` on ``device``.

    Memoized so the same Resample module is reused for every file that shares
    a sample rate instead of being rebuilt per call. ``device`` defaults to
    ``"cuda"`` for backward compatibility with the original ``.cuda()`` call.
    It is part of the cache key, so CPU and GPU resamplers are kept apart.
    """
    return torchaudio.transforms.Resample(orig_sr, target_sr).to(device)


def load_and_process_wav(filepath, device, target_sr=16_000):
    """Load an audio file as a float32 tensor on ``device`` at ``target_sr`` Hz.

    Args:
        filepath: path of the audio file to read.
        device: torch device (or device string) for the returned tensor.
        target_sr: output sample rate; defaults to 16 kHz.

    Returns:
        1-D float32 tensor (librosa downmixes to mono by default).

    Deliberately not memoized: each file is read once per pass, and caching the
    returned tensors would keep every waveform (possibly in GPU memory) alive
    for the lifetime of the kernel.
    """
    # sr=None keeps the file's native rate. librosa's default (sr=22050) would
    # resample here and then again below, and would make `sr` always 22050.
    wav, sr = librosa.load(filepath, sr=None)
    wav = torch.from_numpy(wav).to(device=device, dtype=torch.float32)
    if sr != target_sr:
        wav = get_resampler(sr, target_sr, device)(wav)
    return wav

def save_mean_emb(crnt_spkr_mean, crnt_speaker, all_mean_embs, dataset_name, root="resources"):
    """Record one speaker's mean embedding in memory and write it to disk.

    Args:
        crnt_spkr_mean: the speaker's mean embedding tensor.
        crnt_speaker: speaker id. It is converted with ``int()`` for the
            in-memory key (so it must be numeric) and used verbatim in the
            file name.
        all_mean_embs: dict updated in place as
            ``int(speaker) -> mean.unsqueeze(1)``.
        dataset_name: sub-directory of ``root`` to write into.
        root: base output directory (default ``"resources"``). The file goes to
            ``<root>/<dataset_name>/speaker_embs/<speaker>.pt``.

    The tensor is detached first. It may come straight out of the embedder with
    an autograd graph attached, and keeping that graph referenced from a
    long-lived dict (or saving ``requires_grad=True`` into the .pt file) only
    wastes memory.
    """
    mean_emb = crnt_spkr_mean.detach()
    all_mean_embs[int(crnt_speaker)] = mean_emb.unsqueeze(1)
    out_dir = os.path.join(root, dataset_name, "speaker_embs")
    os.makedirs(out_dir, exist_ok=True)
    torch.save(mean_emb, os.path.join(out_dir, f"{crnt_speaker}.pt"))

In [20]:
def get_mean_spkr_embs(filelist_path, dataset_name):
    """Compute, save and return the mean speaker embedding of every speaker.

    Reads a ``|``-separated filelist whose rows are ``(wav_path, text, speaker_id, ...)``.
    Each utterance is embedded with ``spkr_embedder``, and the embeddings are
    averaged (arithmetic mean) per speaker. Each finished speaker is handed to
    ``save_mean_emb``, which stores it in the module-level ``all_mean_embs``
    dict; that dict is returned.

    Rows are assumed to be grouped by speaker: a speaker is closed as soon as a
    different id appears. If a speaker shows up again later, a warning is
    printed because its earlier saved mean gets overwritten.
    """
    print(f"Loading filelist from {filelist_path}")
    filelist = parse_filelist(filelist_path, split_char='|')

    global all_mean_embs  # filled in place via save_mean_emb and returned
    finished_speakers = set()

    def _finish_speaker(speaker, emb_sum, num_samples):
        # Save the arithmetic mean of the speaker whose rows just ended.
        print(f"Number of sample: {num_samples}")
        if speaker in finished_speakers:
            print(f"WARNING: speaker {speaker} appears in more than one run of the "
                  "filelist; its previously saved mean is overwritten")
        save_mean_emb(emb_sum / num_samples, speaker, all_mean_embs, dataset_name=dataset_name)
        finished_speakers.add(speaker)

    crnt_speaker = None  # speaker whose utterances are being accumulated
    emb_sum = None       # running sum of that speaker's embeddings
    num_samples = 0

    # Pure inference: without no_grad every embedding (and the running sum)
    # would keep its autograd graph and the encoder activations alive.
    with torch.no_grad():
        for idx, line in enumerate(filelist, start=1):
            filepath, spk_id = line[0], line[2]
            if idx % 10 == 0:
                print(f"Processing line ({idx}|{len(filelist)})")

            if crnt_speaker is None:
                print(f"First speaker was: {spk_id}")
                crnt_speaker = spk_id
            elif spk_id != crnt_speaker:  # New speaker detected
                _finish_speaker(crnt_speaker, emb_sum, num_samples)
                crnt_speaker, emb_sum, num_samples = spk_id, None, 0
                print(f"Speaker change to {spk_id} | current line = {idx}")

            wav = load_and_process_wav(filepath, device)
            emb = spkr_embedder(wav.unsqueeze(0))
            emb_sum = emb if emb_sum is None else emb_sum + emb
            num_samples += 1

    # The last speaker is never followed by a "speaker change" row.
    if crnt_speaker is not None:
        _finish_speaker(crnt_speaker, emb_sum, num_samples)
    return all_mean_embs

## Libri-TTS

- For every speaker in the LibriTTS training filelist, we extract speaker embeddings with the pre-trained speaker encoder model.
- Each speaker embedding is a 256-dimensional vector: the mean of the embeddings of all of that speaker's utterances.
- Speaker IDs range from 1 to 256.

In [None]:
from conf.hydra_config import LibriTTSConfig as dataset_cfg

# Fresh module-level dict: get_mean_spkr_embs fills it via `global` and returns it.
all_mean_embs = {}
all_mean_embs = get_mean_spkr_embs(
    filelist_path=dataset_cfg.train_filelist_path,
    dataset_name=dataset_cfg.name,
)

In [21]:
# Persist the whole speaker-id -> mean-embedding dict in one file.
speaker_embs_path = f"resources/{dataset_cfg.name}/speaker_embs/speaker_embs.pt"
torch.save(all_mean_embs, speaker_embs_path)

## LJ-Speech

In [None]:
from conf.hydra_config import LJSPeechConfig as dataset_cfg

# Fresh module-level dict: get_mean_spkr_embs fills it via `global` and returns it.
all_mean_embs = {}
all_mean_embs = get_mean_spkr_embs(
    filelist_path=dataset_cfg.train_filelist_path,
    dataset_name=dataset_cfg.name,
)

In [None]:
# Persist the whole speaker-id -> mean-embedding dict in one file.
speaker_embs_path = f"resources/{dataset_cfg.name}/speaker_embs/speaker_embs.pt"
torch.save(all_mean_embs, speaker_embs_path)