# Speaker adaptation for each of the 5 reference speakers: extract a speaker embedding, fine-tune the decoder, and synthesize the evaluation sentences

In [1]:
import argparse
import IPython.display as ipd
import json
import librosa
import os

# Change into the repository root exactly once: when this cell is re-executed the
# global `path` already exists, so the chdir is skipped instead of walking further up.
# NOTE(review): any other cell that assigns a global named `path` also disables this guard.
try:
    path
except NameError:  # only "name not defined yet" means "first run"; other errors must surface
    path = "../"
    os.chdir(path)
    
import phonemizer
import random
from scipy.io.wavfile import write
import torch
import torchaudio
from tqdm import tqdm
from transformers import HubertModel

from unitspeech.unitspeech import UnitSpeech
from unitspeech.duration_predictor import DurationPredictor
from unitspeech.encoder import Encoder
from unitspeech.speaker_encoder.ecapa_tdnn import ECAPA_TDNN_SMALL
from unitspeech.text import cleaned_text_to_sequence, phonemize, symbols
from unitspeech.textlesslib.textless.data.speech_encoder import SpeechEncoder
from unitspeech.util import HParams, fix_len_compatibility, intersperse, process_unit, generate_path, sequence_mask
from unitspeech.vocoder.env import AttrDict
from unitspeech.vocoder.meldataset import mel_spectrogram
from unitspeech.vocoder.models import BigVGAN

from conf.hydra_config import (
    MainConfig,
)
import pandas as pd

import soundfile as sf

from unitspeech.util import (
    fix_len_compatibility,
    save_plot,
    sequence_mask,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Binds the MainConfig class itself (no instantiation); settings are read as class attributes.
cfg = MainConfig
# GPU only when CUDA is present AND enabled in the training config.
use_gpu = torch.cuda.is_available() and cfg.train.on_GPU
device = torch.device("cuda" if use_gpu else "cpu")

for status_line in (f"Running from {os.getcwd()}", f"Device: {device}"):
    print(status_line)

Running from /workspace/local
Device: cuda


In [6]:
# Both metadata files are pipe-separated without a header row: path|transcript|speaker_id.
column_names = ['path', 'transcript', 'speaker_id']
reference_speech_samples = pd.read_csv('reference_speech_samples.csv', delimiter="|", header=None, names=column_names)
eval_speech_samples = pd.read_csv('evaluation.csv', delimiter="|", header=None, names=column_names)

# Each reference row drives one fine-tuning run, and outputs are named after the evaluation
# samples, so a repeated speaker would redo that speaker's run and overwrite its outputs.
assert reference_speech_samples["speaker_id"].is_unique, "Duplicate speaker_id in reference_speech_samples.csv"

# Display the reference table (a bare expression only renders as the last line of a cell).
reference_speech_samples

In [7]:
# Speaker IDs of the reference clips; the fine-tuning loop runs once per row.
reference_speech_samples.speaker_id

0     1
1    19
2    39
3    20
4    10
Name: speaker_id, dtype: int64

# Load fine-tuning configuration and checkpoint paths

In [8]:
# Mixed-precision switch and Adam learning rate for the per-speaker fine-tuning.
fp16_run = False
learning_rate = 2e-5

# Fine-tuning hyper-parameters shipped next to the UnitSpeech checkpoints.
finetune_config_path = "unitspeech/checkpoints/finetune.json"
with open(finetune_config_path, "r") as f:
    data = f.read()
finetune_config = json.loads(data)
hps_finetune = HParams(**finetune_config)

# Runtime HYPERPARAMS: frame counts (seconds * sampling_rate // hop_length) are passed
# through `fix_len_compatibility` for a U-Net with `len(dim_mults) - 1` down-samplings.
num_downsamplings_in_unet = len(cfg.decoder.dim_mults) - 1
# NOTE(review): `out_size` is not referenced by any cell shown here.
out_size = fix_len_compatibility(
    cfg.train.out_size_second * cfg.data.sampling_rate // cfg.data.hop_length,
    num_downsamplings_in_unet=num_downsamplings_in_unet,
)
# Segment length in mel frames used by `fine_tune`, taken from the fine-tune config.
segment_size = fix_len_compatibility(
    hps_finetune.train.out_size_second * hps_finetune.data.sampling_rate // hps_finetune.data.hop_length,
    num_downsamplings_in_unet=len(hps_finetune.decoder.dim_mults) - 1,
)

# NOTE(review): absolute path, unlike the repo-relative checkpoint paths above.
speaker_encoder_path = "/checkpoints/EVALUATION/speaker_encoder/checkpts/speaker_encoder.pt"

## Data processing modules

In [9]:
# ---------------------------------------------------------------------------
# Vocoder (BigVGAN): turns generated mel-spectrograms into waveforms.
# TODO Vocoder - MISSING FILES
#   NOTE(review): the recorded output shows this cell loading; confirm the TODO is still relevant.
# ---------------------------------------------------------------------------
print('Initializing Vocoder...')
with open(hps_finetune.train.vocoder_config_path) as f:
    h = AttrDict(json.load(f))
vocoder = BigVGAN(h)
# torch.load calls map_location(storage, location); returning `storage` keeps tensors on CPU.
vocoder.load_state_dict(
    torch.load(hps_finetune.train.vocoder_ckpt_path, map_location=lambda storage, loc: storage)['generator']
)
vocoder.cuda().eval()
vocoder.remove_weight_norm()

# ---------------------------------------------------------------------------
# Speaker encoder (ECAPA-TDNN over WavLM-large features): audio -> speaker embedding.
# ---------------------------------------------------------------------------
print('Initializing Speaker Encoder...')
spk_embedder = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
state_dict = torch.load(speaker_encoder_path, map_location=lambda storage, loc: storage)
# strict=False: checkpoint/model key mismatches are skipped without error.
spk_embedder.load_state_dict(state_dict['model'], strict=False)
spk_embedder.cuda().eval()

# ---------------------------------------------------------------------------
# Unit extractor: audio -> discrete units and durations used for fine-tuning.
# ---------------------------------------------------------------------------
print('Initializing Unit Extracter...')
unit_extractor = SpeechEncoder.by_name(
    dense_model_name=cfg.unit_extractor.dense_model_name,
    quantizer_model_name=cfg.unit_extractor.quantizer_name,
    vocab_size=cfg.unit_extractor.vocab_size,
    deduplicate=cfg.unit_extractor.deduplicate,
    need_f0=cfg.unit_extractor.need_f0,
)
_ = unit_extractor.cuda().eval()

Initializing Vocoder...


Removing weight norm...
Initializing Speaker Encoder...


Using cache found in /root/.cache/torch/hub/s3prl_s3prl_main
2024-06-19 10:06:28 | INFO | s3prl.util.download | Requesting URL: https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
2024-06-19 10:06:28 | INFO | s3prl.util.download | Using URL's local file: /root/.cache/s3prl/download/f2d5200177fd6a33b278b7b76b454f25cd8ee866d55c122e69fccf6c7467d37d.wavlm_large.pt
2024-06-19 10:06:29 | INFO | s3prl.upstream.wavlm.WavLM | WavLM Config: {'extractor_mode': 'layer_norm', 'encoder_layers': 24, 'encoder_embed_dim': 1024, 'encoder_ffn_embed_dim': 4096, 'encoder_attention_heads': 16, 'activation_fn': 'gelu', 'layer_norm_first': True, 'conv_feature_layers': '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2', 'conv_bias': False, 'feature_grad_mult': 1.0, 'normalize': True, 'dropout': 0.0, 'attention_dropout': 0.0, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.0, 'dropout_input': 0.0, 'dropout_features': 0.0, 'mask_length': 10, 'mask_prob': 0.8, 'mask_selection': 'static', 'm

Initializing Unit Extracter...


2024-06-19 10:06:38 | INFO | fairseq.tasks.hubert_pretraining | current directory is /workspace/local
2024-06-19 10:06:38 | INFO | fairseq.tasks.hubert_pretraining | HubertPretrainingTask Config {'_name': 'hubert_pretraining', 'data': '/checkpoint/annl/s2st/data/voxpopuli/mHuBERT/en_es_fr', 'fine_tuning': False, 'labels': ['km'], 'label_dir': '/checkpoint/wnhsu/experiments/hubert/kmeans/mhubert_vp_en_es_fr_it2_400k/en_es_fr.layer9.km500', 'label_rate': 50.0, 'sample_rate': 16000, 'normalize': False, 'enable_padding': False, 'max_keep_size': None, 'max_sample_size': 250000, 'min_sample_size': 32000, 'single_target': False, 'random_crop': True, 'pad_audio': False}
2024-06-19 10:06:38 | INFO | fairseq.models.hubert.hubert | HubertModel Config: {'_name': 'hubert', 'label_rate': 50.0, 'extractor_mode': default, 'encoder_layers': 12, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': gelu, 'layer_type': transformer, 'dropout': 0.1, 'atten

## TTS modules

In [10]:
# Text encoder used by `execute_text_to_speech` at inference; loaded pretrained, never fine-tuned.
# Every architecture field except n_feats is read from cfg.encoder under the same name.
text_encoder = Encoder(
    n_feats=cfg.data.n_feats,
    **{
        field: getattr(cfg.encoder, field)
        for field in (
            "n_vocab", "n_channels", "filter_channels", "n_heads",
            "n_layers", "kernel_size", "p_dropout", "window_size",
        )
    },
)
if not os.path.exists(cfg.encoder.checkpoint):
    raise FileNotFoundError(f"Checkpoint for encoder not found: {cfg.encoder.checkpoint}")
# map_location receives (storage, location); returning storage keeps the tensors on CPU.
text_encoder_dict = torch.load(cfg.encoder.checkpoint, map_location=lambda storage, loc: storage)
text_encoder.load_state_dict(text_encoder_dict["model"])
_ = text_encoder.cuda().eval()

In [11]:
# Duration predictor used by `execute_text_to_speech` at inference; loaded pretrained, never fine-tuned.
# All constructor arguments come from cfg.duration_predictor under the same names.
duration_predictor = DurationPredictor(
    **{
        field: getattr(cfg.duration_predictor, field)
        for field in ("in_channels", "filter_channels", "kernel_size", "p_dropout", "spk_emb_dim")
    }
)
if not os.path.exists(cfg.duration_predictor.checkpoint):
    raise FileNotFoundError(f"Checkpoint for duration predictor not found: {cfg.duration_predictor.checkpoint}")
# map_location receives (storage, location); returning storage keeps the tensors on CPU.
duration_predictor_dict = torch.load(cfg.duration_predictor.checkpoint, map_location=lambda storage, loc: storage)
duration_predictor.load_state_dict(duration_predictor_dict["model"])
_ = duration_predictor.cuda().eval()

In [12]:
# Unit encoder: turns discrete speech units into the conditioning (`cond_x`) used while
# fine-tuning the decoder. It is loaded pretrained and only run under no_grad (frozen).
unit_encoder_path = "unitspeech/checkpoints/unit_encoder.pt"
# Same architecture fields as the text encoder, but the vocabulary is the unit inventory.
unit_encoder = Encoder(
    n_vocab=cfg.data.n_units,
    n_feats=cfg.data.n_feats,
    n_channels=cfg.encoder.n_channels,
    filter_channels=cfg.encoder.filter_channels,
    n_heads=cfg.encoder.n_heads,
    n_layers=cfg.encoder.n_layers,
    kernel_size=cfg.encoder.kernel_size,
    p_dropout=cfg.encoder.p_dropout,
    window_size=cfg.encoder.window_size,
)
# Fail with an explicit message, consistent with the text-encoder and duration-predictor cells.
if not os.path.exists(unit_encoder_path):
    raise FileNotFoundError(f"Checkpoint for unit encoder not found: {unit_encoder_path}")
unit_encoder_dict = torch.load(unit_encoder_path, map_location=lambda storage, loc: storage)
unit_encoder.load_state_dict(unit_encoder_dict['model'])
_ = unit_encoder.cuda().eval()

In [13]:
# Normalization parameters for mel spectrogram
# The decoder checkpoint stores, next to its weights ('model'), the mel-spectrogram
# bounds used to scale mels to [-1, 1] for the decoder and back after sampling.
decoder_dict = torch.load(cfg.decoder.checkpoint, map_location=lambda storage, loc: storage)
mel_max, mel_min = decoder_dict['mel_max'], decoder_dict['mel_min']

# eSpeak phonemizer for Romanian ('ro') transcripts: punctuation and stress marks are kept,
# language-switch flags are removed, and word-count mismatches are ignored.
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='ro',
    preserve_punctuation=True,
    with_stress=True,
    language_switch="remove-flags",
    words_mismatch='ignore',
)

In [14]:
# ---- Sampling knobs for `execute_text_to_speech` ----

# Text gradient scale: trades pronunciation accuracy against speaker similarity.
#   Larger -> clearer pronunciation, possibly less similar voice. Upstream default is 1;
#   if pronunciation is the concern, start low (e.g. 0) and raise it gradually.
text_gradient_scale = 1.0

# Speaker gradient scale: larger values push the output closer to the reference speaker,
#   at some cost in pronunciation and audio quality. Unusual voices tend to need more.
spk_gradient_scale = 1.0

# Speech-rate correction (as in Grad-TTS). The duration predictor does not track the
#   reference speaker's tempo closely, so this rescales predicted durations:
#   > 1 slows speech down, < 1 speeds it up.
length_scale = 1.0

# Reverse-diffusion iterations: more steps -> better audio but slower sampling.
diffusion_steps = 500

In [15]:
def fine_tune(cond_x, y, y_mask, y_lengths, y_max_length, attn, spk_emb, segment_size, n_feats, decoder):
    """Return the decoder's diffusion loss on one random mel segment per batch item.

    For every item, ``min(y_lengths[i], segment_size)`` frames are copied from ``y`` (and
    the matching frame columns of ``attn``) starting at a uniformly random offset, into
    zero-initialised buffers ``segment_size`` frames wide. The conditioning ``cond_x`` is
    then expanded to frame level through the cut alignment and passed, together with the
    segment, to ``decoder.compute_loss``.

    Args:
        cond_x: encoder output; the matmul below requires (batch, channels, n_tokens).
        y: target mel-spectrogram, indexed (batch, n_feats, frames).
        y_mask: frame mask of ``y``; only its dtype/device are reused for the segment mask.
        y_lengths: LongTensor with the number of valid frames per batch item.
        y_max_length: frame length of ``y``. Unused — short items are padded by the
            zero-initialised buffers — but kept so existing calls keep working.
        attn: hard alignment indexed (batch, n_tokens, frames).
        spk_emb: speaker embedding, forwarded unchanged to ``decoder.compute_loss``.
        segment_size: width of the training segment, in frames.
        n_feats: number of mel channels in ``y``.
        decoder: model exposing ``compute_loss(y, y_mask, cond_y, spk_emb=...)``.

    Returns:
        The diffusion loss (first element of ``decoder.compute_loss``'s result).
    """
    # Frames copied per item: the whole utterance if shorter than a segment, else one segment.
    y_cut_lengths = y_lengths.clamp(max=segment_size)

    # Random start offset per item. randint is inclusive, so the window ending exactly on
    # the last valid frame can be drawn (range(0, max_offset) never produced it).
    max_offset = (y_lengths - segment_size).clamp(min=0)
    out_offset = [random.randint(0, int(m)) for m in max_offset.tolist()]

    attn_cut = torch.zeros(attn.shape[0], attn.shape[1], segment_size, dtype=attn.dtype, device=attn.device)
    y_cut = torch.zeros(y.shape[0], n_feats, segment_size, dtype=y.dtype, device=y.device)
    for i, (offset, length) in enumerate(zip(out_offset, y_cut_lengths.tolist())):
        y_cut[i, :, :length] = y[i, :, offset:offset + length]
        attn_cut[i, :, :length] = attn[i, :, offset:offset + length]
    # Mask is exactly segment_size wide; frames past each item's length are masked out.
    y_cut_mask = sequence_mask(y_cut_lengths, segment_size).unsqueeze(1).to(y_mask)

    # Align encoded units with the mel segment:
    # (batch, frames, n_tokens) @ (batch, n_tokens, channels) -> (batch, channels, frames).
    # attn is 3-D here, so no squeeze(1) (which would collapse a single-token alignment).
    cond_y = torch.matmul(attn_cut.transpose(1, 2), cond_x.transpose(1, 2)).transpose(1, 2).contiguous()
    cond_y = cond_y * y_cut_mask

    # Score-based decoder loss; the second return value (noisy sample) is not needed.
    diff_loss, _ = decoder.compute_loss(y_cut, y_cut_mask, cond_y, spk_emb=spk_emb)
    return diff_loss

## Fine-tuning and inference loop for each speaker

In [None]:
# For every reference speaker: fine-tune a fresh copy of the pretrained decoder on the
# speaker's reference clip, then synthesize that speaker's evaluation transcripts with it
# and write them as WAV files.
finetune_steps = 2_000                 # optimisation steps per speaker
max_eval_samples_per_speaker = 6       # evaluation sentences synthesized per speaker
base_path = "/workspace/local/evaluation/outputs/with-finetune_AWGN"
os.makedirs(base_path, exist_ok=True)  # scipy's `write` does not create missing directories

# The pretrained checkpoint is identical for every speaker, so read it once;
# load_state_dict copies the tensors into the model, so fine-tuning never mutates this dict.
decoder_dict = torch.load(cfg.decoder.checkpoint, map_location=lambda storage, loc: storage)
# Mel normalisation bounds on the GPU, used both to normalise and to de-normalise.
mel_max = decoder_dict['mel_max'].cuda()
mel_min = decoder_dict['mel_min'].cuda()

for outer_idx, outer_row in reference_speech_samples.iterrows():
    # =====================================================================================
    # Named `ref_path`, not `path`: the global `path` is the marker of the chdir guard.
    ref_path = outer_row["path"]
    transcript = outer_row["transcript"]
    speaker_id = outer_row["speaker_id"]
    print(f"ID: {speaker_id}")
    # =====================================================================================
    # FINETUNE
    # Diffusion-based acoustic model, re-initialised from the pretrained weights per speaker
    unitspeech = UnitSpeech(n_feats=cfg.data.n_feats,
                            dim=cfg.decoder.dim,
                            dim_mults=cfg.decoder.dim_mults,
                            beta_min=cfg.decoder.beta_min,
                            beta_max=cfg.decoder.beta_max,
                            pe_scale=cfg.decoder.pe_scale,
                            spk_emb_dim=cfg.decoder.spk_emb_dim)
    unitspeech.load_state_dict(decoder_dict['model'])
    _ = unitspeech.cuda().train()

    optimizer = torch.optim.Adam(params=unitspeech.parameters(), lr=learning_rate)
    if fp16_run:
        scaler = torch.cuda.amp.GradScaler()

    # DATA PROCESSING
    # Load at the rate the mel front-end is configured for (librosa otherwise resamples to 22050 Hz).
    wav, sr = librosa.load(ref_path, sr=hps_finetune.data.sampling_rate)
    wav = torch.FloatTensor(wav).unsqueeze(0)
    mel = mel_spectrogram(
        wav,
        hps_finetune.data.n_fft,
        hps_finetune.data.n_feats,
        hps_finetune.data.sampling_rate,
        hps_finetune.data.hop_length,
        hps_finetune.data.win_length,
        hps_finetune.data.mel_fmin,
        hps_finetune.data.mel_fmax,
        center=False,
    ).cuda()
    # Scale to [-1, 1] with the checkpoint's bounds.
    mel = (mel - mel_min) / (mel_max - mel_min) * 2 - 1

    # Speaker embedder and unit extractor expect 16 kHz audio (cfg.spkr_embedder.sr).
    resample_fn = torchaudio.transforms.Resample(sr, cfg.spkr_embedder.sr).cuda()
    wav = resample_fn(wav.cuda())
    # Both models are frozen; without no_grad an unused autograd graph would hold GPU memory.
    with torch.no_grad():
        spk_emb = spk_embedder(wav)
        # Units with their durations at the unit extractor's frame rate.
        encoded = unit_extractor(wav)
    # Use a speaker embedding with norm = 1
    spk_emb = spk_emb / spk_emb.norm()
    # Upsample unit and durations from the unit frame rate to the mel frame rate
    unit, duration = process_unit(encoded, cfg.spkr_embedder.sr, cfg.data.hop_length)
    # Add the batch dimension and move everything to the GPU.
    unit = unit.unsqueeze(0).cuda()
    duration = duration.unsqueeze(0).cuda()
    unit_lengths = torch.LongTensor([unit.shape[-1]]).cuda()
    mel_lengths = torch.LongTensor([mel.shape[-1]]).cuda()
    spk_emb = spk_emb.cuda().unsqueeze(1)
    # Prepare unit encoder output for finetuning (the unit encoder stays frozen)
    with torch.no_grad():
        cond_x, x, x_mask = unit_encoder(unit, unit_lengths)
    mel_max_length = mel.shape[-1]
    mel_mask = sequence_mask(mel_lengths, mel_max_length).unsqueeze(1).to(x_mask)
    attn_mask = x_mask.unsqueeze(-1) * mel_mask.unsqueeze(2)
    attn = generate_path(duration, attn_mask.squeeze(1))

    # Finetune the decoder. None of the inputs above carry gradients, so no detach is needed.
    for _ in tqdm(range(finetune_steps)):
        unitspeech.zero_grad()

        with torch.cuda.amp.autocast(enabled=fp16_run):
            loss = fine_tune(cond_x, mel, mel_mask, mel_lengths, mel_max_length, attn, spk_emb,
                             segment_size, hps_finetune.data.n_feats, unitspeech)

        if fp16_run:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            _ = torch.nn.utils.clip_grad_norm_(unitspeech.parameters(), max_norm=1)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            _ = torch.nn.utils.clip_grad_norm_(unitspeech.parameters(), max_norm=1)
            optimizer.step()
    # =====================================================================================
    # NOW USE THE FINETUNED DECODER FOR INFERENCE OVER THE EVALUATION SAMPLES
    # The decoder was left in train mode by the fine-tuning above; sample in eval mode.
    unitspeech.eval()
    eval_samples_crnt_speaker = eval_speech_samples[eval_speech_samples["speaker_id"] == speaker_id]
    print(eval_samples_crnt_speaker["speaker_id"].value_counts())

    # Limit by position: index labels come from the full evaluation table, so they are
    # not a per-speaker counter.
    for inner_idx, inner_row in eval_samples_crnt_speaker.head(max_eval_samples_per_speaker).iterrows():
        inner_path = inner_row["path"]
        # Output name = input file name without directory or (last) extension
        crnt_sample = os.path.splitext(os.path.basename(inner_path))[0]

        inner_transcript = inner_row["transcript"]
        inner_speaker_id = inner_row["speaker_id"]
        assert inner_speaker_id == speaker_id, "Running inference on a different speaker ID than current finetuned model."
        print(f"Inference on speaker ID {inner_speaker_id}")
        print(f"\tPath: {inner_path}")
        print(f"\tTranscript: {inner_transcript}")

        phoneme = phonemize(inner_transcript, global_phonemizer)
        print(f"Running inference on: {phoneme}")
        phoneme = cleaned_text_to_sequence(phoneme)
        phoneme = intersperse(phoneme, len(symbols))  # add a blank token, whose id number is len(symbols)
        phoneme = torch.LongTensor(phoneme).cuda().unsqueeze(0)
        phoneme_lengths = torch.LongTensor([phoneme.shape[-1]]).cuda()

        with torch.no_grad():
            y_enc, y_dec, _attn = unitspeech.execute_text_to_speech(
                phoneme=phoneme,
                phoneme_lengths=phoneme_lengths,
                spk_emb=spk_emb,
                text_encoder=text_encoder,
                duration_predictor=duration_predictor,
                num_downsamplings_in_unet=num_downsamplings_in_unet,
                diffusion_steps=diffusion_steps,
                length_scale=length_scale,
                text_gradient_scale=text_gradient_scale,
                spk_gradient_scale=spk_gradient_scale,
            )
            # Undo the [-1, 1] normalisation before vocoding.
            mel_generated = (y_dec + 1) / 2 * (mel_max - mel_min) + mel_min
            synthesized_audio = vocoder.forward(mel_generated).cpu().squeeze().clamp(-1, 1).numpy()

        write(f"{base_path}/{crnt_sample}.wav", cfg.data.sampling_rate, synthesized_audio)