In [1]:
import argparse
import IPython.display as ipd
import json
import librosa
import os

# To prevent the path from becoming corrupted when this cell is executed more than once.
try:
    path
except:
    path = "../"
    os.chdir(path)
    
import phonemizer
import random
from scipy.io.wavfile import write
import torch
import torchaudio
from tqdm import tqdm
from transformers import HubertModel

from unitspeech.unitspeech import UnitSpeech
from unitspeech.duration_predictor import DurationPredictor
from unitspeech.encoder import Encoder
from unitspeech.speaker_encoder.ecapa_tdnn import ECAPA_TDNN_SMALL
from unitspeech.text import cleaned_text_to_sequence, phonemize, symbols
from unitspeech.textlesslib.textless.data.speech_encoder import SpeechEncoder
from unitspeech.util import HParams, fix_len_compatibility, intersperse, process_unit, generate_path, sequence_mask
from unitspeech.vocoder.env import AttrDict
from unitspeech.vocoder.meldataset import mel_spectrogram
from unitspeech.vocoder.models import BigVGAN

import soundfile as sf

from unitspeech.util import (
    fix_len_compatibility,
    save_plot,
    sequence_mask,
)

  from .autonotebook import tqdm as notebook_tqdm


## Hyperparameters

In [24]:
# PREPARE ARGUMENTS, CHECKPOINT PATH

# reference audio path for finetuning
logdir = "notebooks/PS-logdir/"
reference_path = "reference-speech.wav"

# pretrained model path (Follow README)
encoder_path = "unitspeech/checkpoints/unit_encoder.pt"
decoder_path = "unitspeech/checkpoints/pretrained_decoder.pt"
speaker_encoder_path = "unitspeech/speaker_encoder/checkpts/speaker_encoder.pt"
finetune_config_path = "unitspeech/checkpoints/finetune.json"

# Arguments for finetuning

# If the voice is highly unique, increasing the number of iterations can be helpful. 
# However, excessively large iteration numbers can lead to a degradation in pronunciation. 
# We recommend starting with 500 iterations and, if the results are unsatisfactory, gradually increasing the number of iterations.
n_iters = 500
learning_rate = 2e-5
fp16_run = False

with open(finetune_config_path, "r") as f:
    data = f.read()
finetune_config = json.loads(data)

hps_finetune = HParams(**finetune_config)

segment_size = fix_len_compatibility(
    hps_finetune.train.out_size_second * hps_finetune.data.sampling_rate // hps_finetune.data.hop_length,
    len(hps_finetune.decoder.dim_mults) - 1
)

num_units = hps_finetune.data.n_units
print(f"Num of units: {num_units}")

Num of units: 1000


## Pretrained Model Checkpoints

In [3]:
# Vocoder
print('Initializing Vocoder...')
with open(hps_finetune.train.vocoder_config_path) as f:
    h = AttrDict(json.load(f))
vocoder = BigVGAN(h)
vocoder.load_state_dict(torch.load(hps_finetune.train.vocoder_ckpt_path, map_location=lambda loc, storage: loc)['generator'])
_ = vocoder.cuda().eval()
vocoder.remove_weight_norm()

# Speaker Encoder for extracting speaker embedding
print('Initializing Speaker Encoder...')
spk_embedder = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
state_dict = torch.load(speaker_encoder_path, map_location=lambda storage, loc: storage)
spk_embedder.load_state_dict(state_dict['model'], strict=False)
_ = spk_embedder.cuda().eval()

# Unit Extractor for extraction unit and duration, which are used for finetuning
print('Initializing Unit Extracter...')
dense_model_name = "mhubert-base-vp_en_es_fr"
quantizer_name, vocab_size = "kmeans", 1000

unit_extractor = SpeechEncoder.by_name(
    dense_model_name=dense_model_name,
    quantizer_model_name=quantizer_name,
    vocab_size=vocab_size,
    deduplicate=True,
    need_f0=False
)
_ = unit_extractor.cuda().eval()

Initializing Vocoder...
Removing weight norm...
Initializing Speaker Encoder...


Using cache found in /home/astanea/.cache/torch/hub/s3prl_s3prl_main
2024-05-20 12:56:12 | INFO | s3prl.util.download | Requesting URL: https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
2024-05-20 12:56:12 | INFO | s3prl.util.download | Using URL's local file: /home/astanea/.cache/s3prl/download/f2d5200177fd6a33b278b7b76b454f25cd8ee866d55c122e69fccf6c7467d37d.wavlm_large.pt
2024-05-20 12:56:26 | INFO | s3prl.upstream.wavlm.WavLM | WavLM Config: {'extractor_mode': 'layer_norm', 'encoder_layers': 24, 'encoder_embed_dim': 1024, 'encoder_ffn_embed_dim': 4096, 'encoder_attention_heads': 16, 'activation_fn': 'gelu', 'layer_norm_first': True, 'conv_feature_layers': '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2', 'conv_bias': False, 'feature_grad_mult': 1.0, 'normalize': True, 'dropout': 0.0, 'attention_dropout': 0.0, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.0, 'dropout_input': 0.0, 'dropout_features': 0.0, 'mask_length': 10, 'mask_prob': 0.8, 'mask_selectio

Initializing Unit Extracter...


2024-05-20 12:56:55 | INFO | fairseq.tasks.hubert_pretraining | current directory is /home/astanea/dev/UnitSpeech
2024-05-20 12:56:55 | INFO | fairseq.tasks.hubert_pretraining | HubertPretrainingTask Config {'_name': 'hubert_pretraining', 'data': '/checkpoint/annl/s2st/data/voxpopuli/mHuBERT/en_es_fr', 'fine_tuning': False, 'labels': ['km'], 'label_dir': '/checkpoint/wnhsu/experiments/hubert/kmeans/mhubert_vp_en_es_fr_it2_400k/en_es_fr.layer9.km500', 'label_rate': 50.0, 'sample_rate': 16000, 'normalize': False, 'enable_padding': False, 'max_keep_size': None, 'max_sample_size': 250000, 'min_sample_size': 32000, 'single_target': False, 'random_crop': True, 'pad_audio': False}
2024-05-20 12:56:55 | INFO | fairseq.models.hubert.hubert | HubertModel Config: {'_name': 'hubert', 'label_rate': 50.0, 'extractor_mode': default, 'encoder_layers': 12, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': gelu, 'layer_type': transformer, 'dropout':

## Load and process reference speech

- Load reference audio sample to do **One-shot speaker adaptation** of the Diffusion-based decoder
- Normalize mel spectrogram to [-1 1] interval
- Encode units and durations which will be used to finetune de Decoder to the target speaker voice

In [25]:
# Preprocess the reference audio in a format suitable for fine-tuning
wav, sr = librosa.load(reference_path)
wav = torch.FloatTensor(wav).unsqueeze(0)
mel = mel_spectrogram(wav, hps_finetune.data.n_fft, hps_finetune.data.n_feats, hps_finetune.data.sampling_rate, hps_finetune.data.hop_length,
                      hps_finetune.data.win_length, hps_finetune.data.mel_fmin, hps_finetune.data.mel_fmax, center=False)

save_plot(mel.squeeze().cpu(), f'{logdir}original_mel-UNSCALED.png',
          title="Reference mel - UNSCALED") 


with torch.no_grad():
    reference_audio = vocoder.forward(mel.cuda()).cpu().squeeze().clamp(-1, 1).numpy()

# Load the normalization parameters for mel-spectrogram normalization.
mel_min = mel.min(-1, keepdim=True)[0]
mel_max = mel.max(-1, keepdim=True)[0]

# normalize mel-spectrogram in range [-1, 1]
mel = (mel - mel_min) / (mel_max - mel_min) * 2 - 1 
save_plot(mel.squeeze().cpu(), f'{logdir}original_mel-SCALED.png',
          title="Reference mel - SCALED") 

mel = mel.cuda()

# Speaker embedder expects 16KHz audio samples
resample_fn = torchaudio.transforms.Resample(sr, 16000).cuda()
wav = resample_fn(wav.cuda())
spk_emb = spk_embedder(wav)

spk_emb = spk_emb / spk_emb.norm()

# Extract the units and unit durations to be used for fine-tuning.
encoded = unit_extractor(wav.to("cuda")) # => units with f_unit freq: 16Khz 

# Upsample unit and durations from f_unit to f_mel
unit, duration = process_unit(encoded, hps_finetune.data.sampling_rate, hps_finetune.data.hop_length)



In [26]:
print("Referece synthesisez speech:")
ipd.display(ipd.Audio(reference_audio, rate=sr)) # Decoder generates spectrograms which were samples with sr=22050
# Save to file
sf.write(f'{logdir}synthesized-reference.wav', reference_audio, sr)

Referece synthesisez speech:


In [27]:
# Initialize model and optimizer
unit_encoder = Encoder(
    n_vocab=num_units,
    n_feats=hps_finetune.data.n_feats,
    **hps_finetune.encoder
)

unit_encoder_dict = torch.load(encoder_path, map_location=lambda loc, storage: loc)
unit_encoder.load_state_dict(unit_encoder_dict['model'])
_ = unit_encoder.cuda().eval()

unitspeech = UnitSpeech(
    n_feats=hps_finetune.data.n_feats,
    **hps_finetune.decoder
)

decoder_dict = torch.load(decoder_path, map_location=lambda loc, storage: loc)
unitspeech.load_state_dict(decoder_dict['model'])
_ = unitspeech.cuda().train()

# NOTE: During fine-tunning we say the rest of params are frozen because we only update the decoder params in the optimizer
optimizer = torch.optim.Adam(params=unitspeech.parameters(), lr=learning_rate)

if fp16_run:
    scaler = torch.cuda.amp.GradScaler()

In [28]:
# Reshape the input to match the dimensions and convert it to a PyTorch tensor.
unit = unit.unsqueeze(0).cuda()
duration = duration.unsqueeze(0).cuda()
mel = mel.cuda()

print(f"Unit shape: {unit.shape}")
print(f"Duration shape: {duration.shape}")
print(f"Mel shape: {mel.shape} \n")

unit_lengths = torch.LongTensor([unit.shape[-1]]).cuda()
mel_lengths = torch.LongTensor([mel.shape[-1]]).cuda()
spk_emb = spk_emb.cuda().unsqueeze(1)

print(f"Unit lengths: {unit_lengths}")
print(f"Mel lengths: {mel_lengths}")
print(f"Speaker embedding shape: {spk_emb.shape}")

# Prepare unit encoder output for finetuning
with torch.no_grad():
    cond_x, x, x_mask = unit_encoder(unit, unit_lengths)

#     save_plot(cond_x.squeeze().cpu(), f'{logdir}Unit condx.png',
#           title="Reference mel - UNSCALED") 
#     save_plot(x.squeeze().cpu(), f'{logdir}Unit x.png',
#           title="Reference mel - UNSCALED") 
    print(f"\ncond_x shape: {cond_x.shape}")
    print(f"x shape: {x.shape}")
    print(f"x_mask shape: {x_mask.shape}")

mel_max_length = mel.shape[-1]
mel_mask = sequence_mask(mel_lengths, mel_max_length).unsqueeze(1).to(x_mask)
attn_mask = x_mask.unsqueeze(-1) * mel_mask.unsqueeze(2)

print(f"\nMel max length: {mel_max_length}")
print(f"mel_mask shape: {mel_mask.shape}")
print(f"attn_mask shape: {attn_mask.shape}")

attn = generate_path(duration, attn_mask.squeeze(1))
print(f"\nattn shape: {attn.shape}")

Unit shape: torch.Size([1, 189])
Duration shape: torch.Size([1, 189])
Mel shape: torch.Size([1, 80, 429]) 

Unit lengths: tensor([189], device='cuda:0')
Mel lengths: tensor([429], device='cuda:0')
Speaker embedding shape: torch.Size([1, 1, 256])

cond_x shape: torch.Size([1, 80, 189])
x shape: torch.Size([1, 192, 189])
x_mask shape: torch.Size([1, 1, 189])

Mel max length: 429
mel_mask shape: torch.Size([1, 1, 429])
attn_mask shape: torch.Size([1, 1, 189, 429])

attn shape: torch.Size([1, 189, 429])


- Adapt the diffusion decoder to the target speaker voice using the reference audio sample using the diffusion loss

In [29]:
def fine_tune(cond_x, y, y_mask, y_lengths, y_max_length, attn, spk_emb, segment_size, n_feats, decoder):
    if y_max_length < segment_size:
        pad_size = segment_size - y_max_length
        y = torch.cat([y, torch.zeros_like(y)[:, :, :pad_size]], dim=-1)
        y_mask = torch.cat([y_mask, torch.zeros_like(y_mask)[:, :, :pad_size]], dim=-1)

    max_offset = (y_lengths - segment_size).clamp(0)
    offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
    out_offset = torch.LongTensor([
        torch.tensor(random.choice(range(start, end)) if end > start else 0)
        for start, end in offset_ranges
    ]).to(y_lengths)

    attn_cut = torch.zeros(attn.shape[0], attn.shape[1], segment_size, dtype=attn.dtype, device=attn.device)
    y_cut = torch.zeros(y.shape[0], n_feats, segment_size, dtype=y.dtype, device=y.device)
    y_cut_lengths = []
    for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
        y_cut_length = segment_size + (y_lengths[i] - segment_size).clamp(None, 0)
        y_cut_lengths.append(y_cut_length)
        cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
        y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
        attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
    y_cut_lengths = torch.LongTensor(y_cut_lengths)
    y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)

    if y_cut_mask.shape[-1] < segment_size:
        y_cut_mask = torch.nn.functional.pad(y_cut_mask, (0, segment_size - y_cut_mask.shape[-1]))

    attn = attn_cut
    y = y_cut
    y_mask = y_cut_mask

    # Align encoded text with mel-spectrogram and get cond_y segment
    cond_y = torch.matmul(attn.squeeze(1).transpose(1, 2).contiguous(), cond_x.transpose(1, 2).contiguous())
    cond_y = cond_y.transpose(1, 2).contiguous()
    cond_y = cond_y * y_mask

    # Compute loss of score-based decoder
    diff_loss, xt = decoder.compute_loss(y, y_mask, cond_y, spk_emb=spk_emb)

    return diff_loss

# Text-to-Speech

In [30]:
# Please provide the transcript you would like to synthesize with the desired voice. 
# We recommend entering the transcript in normalized text format.
# text = "If we are lucky, once in a lifetime we might achieve greatness"
text = "Does the quick brown fox jump over the lazy dog?"

# The text gradient scale is responsible for pronunciation and audio quality. 
# The default value is 1, and increasing the value improves pronunciation accuracy but may reduce speaker similarity. 
# We recommend starting with 0 and gradually increasing it if the pronunciation is not satisfactory.
# text_gradient_scale = 1.0
text_gradient_scale = 0.0


# The speaker gradient scale is responsible for speaker similarity. 
# Increasing the value enhances speaker similarity but may slightly degrade pronunciation and audio quality. 
# For unique voices, we recommend using a larger value for the speaker gradient scale.
# spk_gradient_scale = 2.0
spk_gradient_scale = 0.0


# We have confirmed that our duration predictor is not accurately following the duration of the reference audio as expected.
# As a result, while the reference audio's tone and speaking style are well adapted, there are differences in speech rate. 
# To address this issue, we use the "length_scale" argument as in Grad-TTS to mitigate the discrepancy.
# If the value of "length_scale" is greater than 1, the speech rate will be slower. 
# Conversely, if the value is less than 1, the speech rate will be faster.
length_scale = 1.0

# The number of diffusion steps during sampling refers to the number of iterations performed to improve audio quality.
# Generally, larger values lead to better audio quality but slower sampling speeds. 
# Conversely, smaller values allow for faster sampling but may result in lower audio quality.
# We recommend using a value of 50 for this parameter.
diffusion_step = 50 # 100

In [31]:
# Load modules for one-shot text-to-speech
text_encoder_path = "unitspeech/checkpoints/text_encoder.pt"
duration_predictor_path = "unitspeech/checkpoints/duration_predictor.pt"
tts_config_path = "unitspeech/checkpoints/text-to-speech.json"

with open(tts_config_path, "r") as f:
    data = f.read()
tts_config = json.loads(data)

hps_tts = HParams(**tts_config)

global_phonemizer = phonemizer.backend.EspeakBackend(
    language='en-us', preserve_punctuation=True, with_stress=True, language_switch="remove-flags"
)

In [32]:
# Initialize & load model
text_encoder = Encoder(
    n_vocab=len(symbols) + 1,
    n_feats=hps_tts.data.n_feats,
    **hps_tts.encoder
)

text_encoder_dict = torch.load(text_encoder_path, map_location=lambda loc, storage: loc)
text_encoder.load_state_dict(text_encoder_dict['model'])
_ = text_encoder.cuda().eval()

duration_predictor = DurationPredictor(
    **hps_tts.duration_predictor
)

duration_predictor_dict = torch.load(duration_predictor_path, map_location=lambda loc, storage: loc)
duration_predictor.load_state_dict(duration_predictor_dict['model'])
_ = duration_predictor.cuda().eval()

_ = unitspeech.cuda().eval()

- Phonemize the input text. The text encoder recognizes phonemes and uses a transformer architecture to encode features used as condition for the decoder

In [33]:
phoneme = phonemize(text, global_phonemizer)
phoneme = cleaned_text_to_sequence(phoneme)
phoneme = intersperse(phoneme, len(symbols))  # add a blank token, whose id number is len(symbols)
phoneme = torch.LongTensor(phoneme).cuda().unsqueeze(0)
phoneme_lengths = torch.LongTensor([phoneme.shape[-1]]).cuda()

phoneme.shape, phoneme_lengths

(torch.Size([1, 115]), tensor([115], device='cuda:0'))

## Compare original with finetuned results

### Original

In [39]:
logdir_original = logdir + "ORIGINAL_TTS/"
with torch.no_grad():
    y_enc, y_dec, _attn = unitspeech.execute_text_to_speech(
        phoneme=phoneme,
        phoneme_lengths=phoneme_lengths,
        spk_emb=spk_emb,
        text_encoder=text_encoder,
        duration_predictor=duration_predictor,
        num_downsamplings_in_unet=len(hps_tts.decoder.dim_mults) - 1,
        diffusion_steps=diffusion_step,
        length_scale=length_scale,
        text_gradient_scale=text_gradient_scale,
        spk_gradient_scale=spk_gradient_scale,
    )
    save_plot(y_enc.squeeze().cpu(), f"{logdir_original}encoder-output.png", title="Encoder output")
    save_plot(y_dec.squeeze().cpu(), f"{logdir_original}decoder-output-SCALED.png", title="Decoder output")
    save_plot(_attn.squeeze().cpu(), f"{logdir_original}MAS-attention-alignment.png", title="MAS alignment")

    y_dec = (y_dec + 1) / 2 * (mel_max.to(y_dec.device) - mel_min.to(y_dec.device)) + mel_min.to(y_dec.device)
    save_plot(y_dec.squeeze().cpu(), f"{logdir_original}decoder-output.png", title="Decoder output")
    synthesized_audio = vocoder.forward(y_dec).cpu().squeeze().clamp(-1, 1).numpy()  # (60160)

In [40]:
print('Generated audio')
ipd.display(ipd.Audio(synthesized_audio, rate=sr)) # Decoder generates spectrograms which were samples with sr=22050
sf.write(f'{logdir_original}generated.wav', synthesized_audio, sr)

Generated audio


### ADAPTATION: Finetune the decoder to the target speaker

In [41]:
# Finetune the decoder
for _ in tqdm(range(n_iters)):
    cond_x = cond_x.detach()
    mel = mel.detach()
    mel_mask = mel_mask.detach()
    mel_lengths = mel_lengths.detach()
    spk_emb = spk_emb.detach()
    attn = attn.detach()

    unitspeech.zero_grad()
      
    with torch.cuda.amp.autocast(enabled=fp16_run):
        diff_loss = fine_tune(cond_x, mel, mel_mask, mel_lengths, mel_max_length, attn, spk_emb, segment_size, hps_finetune.data.n_feats, unitspeech)

    loss = sum([diff_loss])

    if fp16_run:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        _ = torch.nn.utils.clip_grad_norm_(unitspeech.parameters(), max_norm=1)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        _ = torch.nn.utils.clip_grad_norm_(unitspeech.parameters(), max_norm=1)
        optimizer.step()

100%|██████████| 500/500 [01:49<00:00,  4.55it/s]


In [54]:
finetune_original = logdir + "FINETUNE_TTS/"
with torch.no_grad():
    y_enc, y_dec, _attn = unitspeech.execute_text_to_speech(
        phoneme=phoneme,
        phoneme_lengths=phoneme_lengths,
        spk_emb=spk_emb,
        text_encoder=text_encoder,
        duration_predictor=duration_predictor,
        num_downsamplings_in_unet=len(hps_tts.decoder.dim_mults) - 1,
        diffusion_steps=diffusion_step,
        length_scale=length_scale,
        text_gradient_scale=text_gradient_scale,
        spk_gradient_scale=spk_gradient_scale,
    )
    save_plot(y_enc.squeeze().cpu(), f"{finetune_original}encoder-output.png", title="Encoder output")
    save_plot(y_dec.squeeze().cpu(), f"{finetune_original}decoder-output-SCALED.png", title="Decoder output")
    save_plot(_attn.squeeze().cpu(), f"{finetune_original}MAS-attention-alignment.png", title="MAS alignment")

    y_dec = (y_dec + 1) / 2 * (mel_max.to(y_dec.device) - mel_min.to(y_dec.device)) + mel_min.to(y_dec.device)
    save_plot(y_dec.squeeze().cpu(), f"{finetune_original}decoder-output.png", title="Decoder output")
    synthesized_audio = vocoder.forward(y_dec).cpu().squeeze().clamp(-1, 1).numpy()  # (60160)

In [55]:
print('Generated audio')
ipd.display(ipd.Audio(synthesized_audio, rate=sr)) # Decoder generates spectrograms which were samples with sr=22050
sf.write(f'{finetune_original}generated.wav', synthesized_audio, sr)

Generated audio


In [56]:
unitspeech.text_uncon

Parameter containing:
tensor([[[-0.0256],
         [ 0.0366],
         [ 0.0836],
         [ 0.1462],
         [ 0.1891],
         [ 0.1528],
         [ 0.1260],
         [ 0.1020],
         [ 0.1140],
         [ 0.1088],
         [ 0.1122],
         [ 0.1031],
         [ 0.0954],
         [ 0.0740],
         [ 0.0598],
         [ 0.0297],
         [ 0.0179],
         [ 0.0009],
         [-0.0201],
         [-0.0285],
         [-0.0450],
         [-0.0383],
         [-0.0687],
         [-0.0645],
         [-0.0831],
         [-0.0888],
         [-0.0901],
         [-0.0977],
         [-0.1022],
         [-0.0972],
         [-0.1014],
         [-0.1015],
         [-0.0948],
         [-0.0941],
         [-0.1133],
         [-0.0739],
         [-0.0748],
         [-0.0839],
         [-0.0726],
         [-0.0707],
         [-0.0813],
         [-0.0779],
         [-0.0681],
         [-0.0808],
         [-0.0792],
         [-0.0929],
         [-0.0779],
         [-0.0705],
         [-0.0815]

: 

In [45]:
unitspeech.spk_uncon

torch.Size([1, 1, 256])