# FastPitch Adapter Finetuning

This notebook is designed to provide a guide on how to run FastPitch Adapter Finetuning Pipeline. It contains the following sections:
1. **Transform pre-trained FastPitch checkpoint to adapter-compatible checkpoint**
2. **Fine-tune FastPitch on adaptation data**: fine-tune pre-trained multi-speaker FastPitch for a new speaker
* Dataset Preparation: download dataset and extract manifest files. (duration more than 15 mins)
* Preprocessing: add absolute audio paths in manifest and extract Supplementary Data.
* Training: fine-tune frozen multispeaker FastPitch with trainable adapters.
3. **Fine-tune HiFiGAN on adaptation data**: fine-tune a vocoder for the fine-tuned multi-speaker FastPitch
* Dataset Preparation: extract mel-spectrograms from fine-tuned FastPitch.
* Training: fine-tune HiFiGAN with fine-tuned adaptation data.
4. **Inference**: generate speech from adapted FastPitch
* Load Model: load pre-trained multi-speaker FastPitch with **fine-tuned adapters**.
* Output Audio: generate audio files.

# License

> Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
> 
> Licensed under the Apache License, Version 2.0 (the "License");
> you may not use this file except in compliance with the License.
> You may obtain a copy of the License at
> 
>     http://www.apache.org/licenses/LICENSE-2.0
> 
> Unless required by applicable law or agreed to in writing, software
> distributed under the License is distributed on an "AS IS" BASIS,
> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
> See the License for the specific language governing permissions and
> limitations under the License.

In [None]:
"""
You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.
Instructions for setting up Colab are as follows:
1. Open a new Python 3 notebook.
2. Import this notebook from GitHub (File -> Upload Notebook -> "GITHUB" tab -> copy/paste GitHub URL)
3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select "GPU" for hardware accelerator)
4. Run this cell to set up dependencies# .
"""
# # If you're using Colab and not running locally, uncomment and run this cell.
# BRANCH = 'main'
# !apt-get install sox libsndfile1 ffmpeg
# !pip install wget unidecode pynini==2.1.4 scipy==1.7.3
# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]

# Download a local copy of the NeMo repository; later cells run its scripts via `cd {code_dir}`.
# If you are running locally and want to use your own NeMo checkout, comment out only the
# `git clone` line below and set `code_dir` to the path of that checkout.
# NOTE: the clone does not use the BRANCH variable above, so it fetches the repository's default branch.
code_dir = 'NeMoTTS' 
!git clone https://github.com/NVIDIA/NeMo.git {code_dir}

In [None]:
# Log in to Weights & Biases; both training cells in this notebook enable the W&B logger.
# Replace `#PASTE_WANDB_APIKEY_HERE` with your API key before running.
# NOTE(review): do not save or share the notebook with a real key pasted on this line.
!wandb login #PASTE_WANDB_APIKEY_HERE

In [None]:
# .nemo files for your pre-trained FastPitch and HiFiGAN.
# Both paths must be filled in before continuing: the FastPitch file is restored with
# `FastPitchModel.restore_from` and the HiFiGAN file is copied with `shutil.copyfile`
# in section 1 below.
pretrained_fastpitch_checkpoint = ""
finetuned_hifigan_on_multispeaker_checkpoint = ""

In [None]:
# Audio sampling rate used for supplementary-data extraction, waveform loading and playback
sample_rate = 44100

# Working directories (relative here; a later cell creates them and makes them absolute)
data_dir = "NeMoTTS_dataset"   # manifests and audio files
supp_dir = "NeMoTTS_sup_data"  # supplementary files
logs_dir = "NeMoTTS_logs"      # training logs
mels_dir = "NeMoTTS_mels"      # mel-spectrograms for vocoder training

In [None]:
import os
import json
import shutil
import nemo
import torch
import numpy as np

from pathlib import Path
from tqdm import tqdm

In [None]:
# Resolve every working directory to an absolute path first, because later cells
# `cd` into `code_dir` before using the other paths ...
code_dir = os.path.abspath(code_dir)
data_dir = os.path.abspath(data_dir)
supp_dir = os.path.abspath(supp_dir)
logs_dir = os.path.abspath(logs_dir)
mels_dir = os.path.abspath(mels_dir)

# ... then make sure each of them exists.
for directory in (code_dir, data_dir, supp_dir, logs_dir, mels_dir):
    os.makedirs(directory, exist_ok=True)

# 1. Transform pre-trained checkpoint to adapter-compatible checkpoint

In [None]:
from nemo.collections.tts.models import FastPitchModel
from nemo.core import adapter_mixins
from omegaconf import DictConfig, OmegaConf, open_dict

In [None]:
def update_model_config_to_support_adapter(config) -> DictConfig:
    """Point each FastPitch sub-module config at its adapter-capable class.

    For every sub-module listed below, the adapter metadata registered for the
    current `_target_` is looked up. When an entry exists, `_target_` is replaced
    by that entry's `adapter_class_path`; otherwise the sub-module is left as is.

    Args:
        config: model configuration exposing `input_fft`, `output_fft`,
            `pitch_predictor`, `duration_predictor` and `alignment_module`,
            each with a `_target_` entry.

    Returns:
        The same `config` object, updated in place.
    """
    adaptable_modules = (
        "input_fft",
        "output_fft",
        "pitch_predictor",
        "duration_predictor",
        "alignment_module",
    )

    with open_dict(config):
        for module_name in adaptable_modules:
            module_cfg = getattr(config, module_name)
            adapter_metadata = adapter_mixins.get_registered_adapter(module_cfg._target_)
            if adapter_metadata is not None:
                module_cfg._target_ = adapter_metadata.adapter_class_path

    return config

In [None]:
# Restore the pre-trained FastPitch, switch its config to the adapter-compatible classes,
# and save the result as a new .nemo file in the current working directory.
model = FastPitchModel.restore_from(pretrained_fastpitch_checkpoint)
model.cfg = update_model_config_to_support_adapter(model.cfg)
model.save_to('Pretrained-FastPitch.nemo')
# Keep a copy of the HiFiGAN checkpoint next to it under a fixed file name.
shutil.copyfile(finetuned_hifigan_on_multispeaker_checkpoint, "Pretrained-HifiGan.nemo")

# From here on both variables refer to the local copies, as absolute paths.
pretrained_fastpitch_checkpoint = os.path.abspath("Pretrained-FastPitch.nemo")
finetuned_hifigan_on_multispeaker_checkpoint = os.path.abspath("Pretrained-HifiGan.nemo")
# Not executed: alternative that patches the config stored inside a checkpoint loaded with `torch.load`.
#     state = torch.load(pretrained_fastpitch_checkpoint)
#     state['hyper_parameters']['cfg'] = update_model_config_to_support_adapter(state['hyper_parameters']['cfg'])
#     torch.save(state, pretrained_fastpitch_checkpoint)

# 2. Fine-tune FastPitch on adaptation data

## a. Data Preparation
For our tutorial, we use a small part of the VCTK dataset with a new target speaker (p267). Usually, the audio should have a total duration of more than 15 minutes.

In [None]:
# Download the VCTK subset archive into `data_dir` and unpack it there.
!cd {data_dir} && wget https://vctk-subset.s3.amazonaws.com/vctk_subset.tar.gz && tar zxf vctk_subset.tar.gz

In [None]:
# Folder of the downloaded subset inside `data_dir`; list it to check its contents.
manidir = f"{data_dir}/vctk_subset"
!ls {manidir}

In [None]:
# Absolute paths of the training and validation manifests inside the dataset folder.
train_manifest = os.path.abspath(os.path.join(manidir, 'train.json'))
valid_manifest = os.path.abspath(os.path.join(manidir, 'dev.json'))

## b. Preprocessing

### Add absolute file path in manifest
We use absolute paths for `audio_filepath` so that the audio files can be located during training.

In [None]:
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest

In [None]:
# Rewrite both manifests in place so that every `audio_filepath` is an absolute
# path resolved against the dataset folder.
train_datas = read_manifest(train_manifest)
for entry in train_datas:
    joined_path = os.path.join(manidir, entry['audio_filepath'])
    entry['audio_filepath'] = os.path.abspath(joined_path)
write_manifest(train_manifest, train_datas)

valid_datas = read_manifest(valid_manifest)
for entry in valid_datas:
    joined_path = os.path.join(manidir, entry['audio_filepath'])
    entry['audio_filepath'] = os.path.abspath(joined_path)
write_manifest(valid_manifest, valid_datas)

### Extract Supplementary Data

As mentioned in the [FastPitch and MixerTTS training tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tts/FastPitch_MixerTTS_Training.ipynb), to accelerate and stabilize training we also need to extract the pitch of every audio file and estimate the pitch statistics (mean, std, min, and max). To do this, we only need to iterate over our data once, using the `extract_sup_data.py` script.

Note: This is an optional step, if skipped, it will be automatically executed within the first epoch of training FastPitch.

In [None]:
# Run NeMo's extract_sup_data.py from the local checkout on the training manifest,
# writing the supplementary data to `supp_dir`.
!cd {code_dir} && python scripts/dataset_processing/tts/extract_sup_data.py \
    manifest_filepath={train_manifest} \
    sup_data_path={supp_dir} \
    dataset.sample_rate={sample_rate} \
    dataset.n_fft=2048 \
    dataset.win_length=2048 \
    dataset.hop_length=512

After running the above command, you will see a new folder `NeMoTTS_sup_data/pitch` and a printout of the pitch statistics like the one below. Pass these values to the FastPitch training configuration; we will use them in the training section below.
```bash
PITCH_MEAN=175.48513793945312, PITCH_STD=42.3786735534668
PITCH_MIN=65.4063949584961, PITCH_MAX=270.8517761230469
```

In [None]:
# Same extraction for the validation manifest, written to the same `supp_dir`.
!cd {code_dir} && python scripts/dataset_processing/tts/extract_sup_data.py \
    manifest_filepath={valid_manifest} \
    sup_data_path={supp_dir} \
    dataset.sample_rate={sample_rate} \
    dataset.n_fft=2048 \
    dataset.win_length=2048 \
    dataset.hop_length=512

## c. Training

In [None]:
# Dictionary and heteronyms files (per their file names) located under scripts/tts_dataset_files of the NeMo checkout.
# NOTE(review): neither path is passed to the training command in this notebook — confirm whether they are still needed.
phoneme_dict_path = os.path.abspath(os.path.join(code_dir, "scripts", "tts_dataset_files", "cmudict-0.7b_nv22.10"))
heteronyms_path = os.path.abspath(os.path.join(code_dir, "scripts", "tts_dataset_files", "heteronyms-052722"))

# Copy and paste the PITCH_MEAN and PITCH_STD printed for train_manifest in the previous step
# to override the pitch_mean and pitch_std configs below.
PITCH_MEAN=175.48513793945312
PITCH_STD=42.3786735534668

### Important notes
* `+init_from_nemo_model`: initialize with the adapter-compatible multi-speaker FastPitch `.nemo` checkpoint created in section 1
* `~model.speaker_encoder.lookup_module`: remove the pre-trained looked-up speaker embedding
* Other optional arguments based on your preference:
    * batch_size
    * exp_manager
    * trainer

In [None]:
# Fine-tune the adapters with NeMo's fastpitch_finetune_adapters.py, starting from the
# adapter-compatible checkpoint and logging to `logs_dir` and to Weights & Biases.
# Normally 100 epochs; `trainer.max_epochs=10` keeps this tutorial run short.
!cd {code_dir} && python examples/tts/fastpitch_finetune_adapters.py \
--config-name=fastpitch_align_44100_adapter.yaml \
+init_from_nemo_model={pretrained_fastpitch_checkpoint} \
train_dataset={train_manifest} \
validation_datasets={valid_manifest} \
sup_data_types="['align_prior_matrix', 'pitch', 'speaker_id', 'reference_audio']" \
sup_data_path={supp_dir} \
pitch_mean={PITCH_MEAN} \
pitch_std={PITCH_STD} \
~model.speaker_encoder.lookup_module \
model.train_ds.dataloader_params.batch_size=8 \
model.validation_ds.dataloader_params.batch_size=8 \
model.optim.name=adam \
model.optim.lr=2e-4 \
~model.optim.sched \
exp_manager.exp_dir={logs_dir} \
+exp_manager.create_wandb_logger=True \
+exp_manager.wandb_logger_kwargs.name="tutorial-FastPitch-finetune-adaptation" \
+exp_manager.wandb_logger_kwargs.project="NeMo" \
+exp_manager.checkpoint_callback_params.save_top_k=-1 \
trainer.max_epochs=10 \
trainer.check_val_every_n_epoch=10 \
trainer.log_every_n_steps=1 \
trainer.devices=1 \
trainer.strategy=ddp \
trainer.precision=32

In [None]:
# e.g. NeMoTTS_logs/FastPitch/Y-M-D_H-M-S/checkpoints/adapters.pt
# Take the run directory whose name sorts last (the latest run, assuming timestamped names).
fastpitch_run_dirs = [run_dir for run_dir in (Path(logs_dir) / "FastPitch").iterdir() if run_dir.is_dir()]
last_checkpoint_dir = sorted(fastpitch_run_dirs)[-1] / "checkpoints"
adapter_files = list(last_checkpoint_dir.glob('adapters.pt'))
finetuned_adapter_checkpoint = adapter_files[0]
print(finetuned_adapter_checkpoint)

# 3. Fine-tune HiFiGAN on adaptation data

## a. Dataset Preparation
Generate mel-spectrograms for HiFiGAN training.

In [None]:
from nemo.collections.tts.parts.utils.tts_dataset_utils import BetaBinomialInterpolator
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from collections import defaultdict
import random
# Fixed seed so the random reference-sample choice made during mel generation is repeatable.
random.seed(100)

In [None]:
def gen_spectrogram(index, manifest, speaker_to_index):
    """Generate and cache the FastPitch mel-spectrogram for one manifest record.

    The spectrogram is predicted by the notebook-level `spec_model` from the record's
    text, its ground-truth audio, and a reference utterance of the same speaker.
    It is saved as `<audio file stem>.npy` in `mels_dir`; if that file already
    exists it is reused without running the model.

    Relies on the notebook globals `spec_model`, `wave_model`,
    `beta_binomial_interpolator` and `mels_dir`. Note that the inference section of
    this notebook defines another `gen_spectrogram` with a different signature;
    re-run this cell before calling this version again.

    Args:
        index: position of the record in `manifest`.
        manifest: list of manifest records (dicts) with `audio_filepath`, either
            `normalized_text` or `text`, and optionally `speaker`.
        speaker_to_index: mapping from speaker (or None when a record has no
            `speaker` key) to the set of manifest indices of that speaker.

    Returns:
        Absolute path of the saved `.npy` mel-spectrogram file.
    """
    record = manifest[index]
    audio_file = record["audio_filepath"]

    # Derive the output name from the file stem so that any audio extension works;
    # the previous '.wav'/'.flac' substring checks left `save_path` undefined otherwise.
    audio_stem = os.path.splitext(os.path.basename(audio_file))[0]
    save_path = os.path.abspath(os.path.join(mels_dir, audio_stem + ".npy"))

    # Reuse a previously generated spectrogram.
    if os.path.exists(save_path):
        return save_path

    if "normalized_text" in record:
        text = spec_model.parse(record["normalized_text"], normalize=False)
    else:
        text = spec_model.parse(record['text'])

    text_len = torch.tensor(text.shape[-1], dtype=torch.long, device=spec_model.device).unsqueeze(0)

    audio = wave_model.process(audio_file).unsqueeze(0).to(device=spec_model.device)
    audio_len = torch.tensor(audio.shape[1]).long().unsqueeze(0).to(device=spec_model.device)
    spect, spect_len = spec_model.preprocessor(input_signal=audio, length=audio_len)

    attn_prior = torch.from_numpy(
        beta_binomial_interpolator(spect_len.item(), text_len.item())
    ).unsqueeze(0).to(spec_model.device)

    # Use .get() so records without a `speaker` key map to None, matching how the
    # caller builds `speaker_to_index`.
    same_speaker = speaker_to_index[record.get("speaker")]
    # Prefer a different utterance of the same speaker; fall back to the record
    # itself when it is the speaker's only utterance.
    reference_pool = same_speaker - {index} if len(same_speaker) > 1 else same_speaker
    # random.sample() rejects sets on Python 3.11+; sorting also makes the pick
    # depend only on the random seed rather than on set iteration order.
    reference_sample = manifest[random.choice(sorted(reference_pool))]
    reference_audio = wave_model.process(reference_sample["audio_filepath"]).unsqueeze(0).to(device=spec_model.device)
    reference_audio_length = torch.tensor(reference_audio.shape[1]).long().unsqueeze(0).to(device=spec_model.device)
    reference_spec, reference_spec_len = spec_model.preprocessor(
        input_signal=reference_audio, length=reference_audio_length
    )

    with torch.no_grad():
        spectrogram = spec_model.forward(
            text=text,
            input_lens=text_len,
            spec=spect,
            mel_lens=spect_len,
            attn_prior=attn_prior,
            reference_spec=reference_spec,
            reference_spec_lens=reference_spec_len,
        )[0]

    spec = spectrogram[0].to('cpu').numpy()
    np.save(save_path, spec)
    return save_path

In [None]:
# Helpers used by `gen_spectrogram`: alignment-prior interpolator and waveform loader
beta_binomial_interpolator = BetaBinomialInterpolator()
wave_model = WaveformFeaturizer(sample_rate=sample_rate)

# Adapter-compatible pre-trained FastPitch with the fine-tuned adapter weights loaded,
# switched to eval mode and moved to the GPU
spec_model = FastPitchModel.restore_from(pretrained_fastpitch_checkpoint)
spec_model.load_adapters(finetuned_adapter_checkpoint)
spec_model = spec_model.eval().cuda()

In [None]:
os.makedirs(mels_dir, exist_ok=True)

# Train: group record indices by speaker, generate (or reuse) one mel file per
# record, and write its path back into the manifest.
train_datas = read_manifest(train_manifest)
speaker_to_index = {}
for idx, entry in enumerate(train_datas):
    speaker_to_index.setdefault(entry.get('speaker', None), set()).add(idx)

for idx, entry in enumerate(tqdm(train_datas)):
    entry["mel_filepath"] = gen_spectrogram(idx, train_datas, speaker_to_index)

write_manifest(train_manifest, train_datas)


# Valid: same procedure for the validation manifest.
valid_datas = read_manifest(valid_manifest)
speaker_to_index = {}
for idx, entry in enumerate(valid_datas):
    speaker_to_index.setdefault(entry.get('speaker', None), set()).add(idx)

for idx, entry in enumerate(tqdm(valid_datas)):
    entry["mel_filepath"] = gen_spectrogram(idx, valid_datas, speaker_to_index)

write_manifest(valid_manifest, valid_datas)

## b. Training

In [None]:
# Fine-tune HiFiGAN with NeMo's hifigan_finetune.py on the manifests that now carry
# `mel_filepath`, starting from the multi-speaker HiFiGAN checkpoint.
# Normally 500 epochs; `+trainer.max_epochs=5` keeps this tutorial run short.
!cd {code_dir} && python examples/tts/hifigan_finetune.py \
--config-name=hifigan_44100.yaml \
train_dataset={train_manifest} \
validation_datasets={valid_manifest} \
+init_from_nemo_model={finetuned_hifigan_on_multispeaker_checkpoint} \
model.train_ds.dataloader_params.batch_size=32 \
model.optim.lr=0.0001 \
model/train_ds=train_ds_finetune \
model/validation_ds=val_ds_finetune \
+trainer.max_epochs=5 \
trainer.check_val_every_n_epoch=5 \
trainer.devices=-1 \
trainer.strategy='ddp' \
trainer.precision=16 \
exp_manager.exp_dir={logs_dir} \
exp_manager.create_wandb_logger=True \
exp_manager.wandb_logger_kwargs.name="tutorial-HiFiGAN-finetune-multispeaker" \
exp_manager.wandb_logger_kwargs.project="NeMo"

In [None]:
# e.g. NeMoTTS_logs/HifiGan/Y-M-D_H-M-S/checkpoints/HifiGan.nemo
# Take the run directory whose name sorts last (the latest run, assuming timestamped names).
hifigan_run_dirs = [run_dir for run_dir in (Path(logs_dir) / "HifiGan").iterdir() if run_dir.is_dir()]
last_checkpoint_dir = sorted(hifigan_run_dirs)[-1] / "checkpoints"
nemo_files = list(last_checkpoint_dir.glob('*.nemo'))
finetuned_hifigan_on_adaptation_checkpoint = nemo_files[0]
finetuned_hifigan_on_adaptation_checkpoint

# 4. Inference

In [None]:
from nemo.collections.tts.models import HifiGanModel
import IPython.display as ipd
import matplotlib.pyplot as plt

## a. Load Model

In [None]:
# Waveform loader for the inference section, created with the notebook's sample rate.
wave_model = WaveformFeaturizer(sample_rate=sample_rate)

In [None]:
# FastPitch: restore the adapter-compatible pre-trained model, load the fine-tuned
# adapter weights, then switch to eval mode on the GPU.
spec_model = FastPitchModel.restore_from(pretrained_fastpitch_checkpoint)
spec_model.load_adapters(finetuned_adapter_checkpoint)
spec_model = spec_model.eval().cuda()

In [None]:
# HiFiGAN: restore the vocoder fine-tuned on the adaptation data, in eval mode on the GPU.
vocoder_model = HifiGanModel.restore_from(finetuned_hifigan_on_adaptation_checkpoint).eval().cuda()

## b. Output Audio

In [None]:
def gt_spectrogram(audio_path, wave_model, spec_gen_model):
    """Compute the spectrogram of an audio file with the model's own preprocessor.

    Args:
        audio_path: path of the audio file to load.
        wave_model: object whose `process(path, trim=False)` returns the waveform tensor.
        spec_gen_model: model exposing `device` and `preprocessor(input_signal=..., length=...)`.

    Returns:
        Tuple `(spectrogram, spec_len)` as returned by the preprocessor for a batch of one.
    """
    samples = wave_model.process(audio_path, trim=False)
    num_samples = torch.tensor(samples.shape[0]).long()

    # Add a batch dimension and move both inputs to the model's device.
    target_device = spec_gen_model.device
    batched_audio = samples.unsqueeze(0).to(device=target_device)
    batched_length = num_samples.unsqueeze(0).to(device=target_device)

    with torch.no_grad():
        spectrogram, spec_len = spec_gen_model.preprocessor(input_signal=batched_audio, length=batched_length)
    return spectrogram, spec_len

def gen_spectrogram(text, spec_gen_model, reference_spec, reference_spec_lens):
    """Synthesize a spectrogram for `text`, conditioned on a reference spectrogram.

    Note: this replaces the `gen_spectrogram(index, manifest, speaker_to_index)`
    defined earlier in the notebook for HiFiGAN data preparation.

    Args:
        text: raw input text, tokenized with `spec_gen_model.parse`.
        spec_gen_model: model exposing `parse` and `generate_spectrogram`.
        reference_spec: reference spectrogram passed through to the model.
        reference_spec_lens: length(s) of the reference spectrogram.

    Returns:
        Whatever `spec_gen_model.generate_spectrogram` returns for the parsed tokens.
    """
    tokens = spec_gen_model.parse(text)
    generation_kwargs = dict(
        tokens=tokens,
        reference_spec=reference_spec,
        reference_spec_lens=reference_spec_lens,
    )
    with torch.no_grad():
        return spec_gen_model.generate_spectrogram(**generation_kwargs)
  
def synth_audio(vocoder_model, spectrogram):
    """Convert a spectrogram to audio with the vocoder.

    Args:
        vocoder_model: model exposing `convert_spectrogram_to_audio(spec=...)`.
        spectrogram: spectrogram to vocode.

    Returns:
        A NumPy array when the vocoder returns a torch tensor; otherwise the
        vocoder's output unchanged.
    """
    with torch.no_grad():
        waveform = vocoder_model.convert_spectrogram_to_audio(spec=spectrogram)
    return waveform.to('cpu').numpy() if isinstance(waveform, torch.Tensor) else waveform

In [None]:
# Reference audio: the first record of the training manifest
with open(train_manifest, "r") as manifest_file:
    for raw_line in manifest_file:
        reference_record = json.loads(raw_line)
        break

# Validation audio: the first `num_val` records of the validation manifest
num_val = 3
val_records = []
with open(valid_manifest, "r") as manifest_file:
    for raw_line in manifest_file:
        val_records.append(json.loads(raw_line))
        if len(val_records) >= num_val:
            break

In [None]:
# The reference utterance is the same for every validation sample, so its
# spectrogram and audio widget are built once instead of once per iteration.
reference_spec, reference_spec_lens = gt_spectrogram(reference_record['audio_filepath'], wave_model, spec_model)
reference_spec = reference_spec.to(spec_model.device)
audio_ref = ipd.Audio(reference_record['audio_filepath'], rate=sample_rate)

for val_record in val_records:
    # Predict the spectrogram for the validation text, conditioned on the reference.
    spec_pred = gen_spectrogram(val_record['text'], spec_model,
                                reference_spec=reference_spec,
                                reference_spec_lens=reference_spec_lens)

    # Keep the raw waveform and the playable widget under separate names.
    waveform_gen = synth_audio(vocoder_model, spec_pred)

    audio_gt = ipd.Audio(val_record['audio_filepath'], rate=sample_rate)
    audio_gen = ipd.Audio(waveform_gen, rate=sample_rate)

    print("------")
    print(f"Text: {val_record['text']}")
    print('Reference Audio')
    ipd.display(audio_ref)
    print('Ground Truth Audio')
    ipd.display(audio_gt)
    print('Synthesized Audio')
    ipd.display(audio_gen)
    # Show the predicted spectrogram of the first (only) batch element.
    plt.imshow(spec_pred[0].to('cpu').numpy(), origin="lower", aspect="auto")
    plt.show()

In [None]:
# Export the FastPitch model (with adapters loaded) and the fine-tuned HiFiGAN to
# .nemo files in the current working directory.
fintuned_fastpitch = 'fastpitch.nemo'
fintuned_hifigan = 'hifigan.nemo'
spec_model.save_to(fintuned_fastpitch)
vocoder_model.save_to(fintuned_hifigan)

In [None]:
# Summary of the checkpoints used for inference above.
print(f"FastPitch checkpoint: {pretrained_fastpitch_checkpoint}")
print(f"Adapter checkpoint: {finetuned_adapter_checkpoint}")
print(f"HiFi-Gan checkpoint: {finetuned_hifigan_on_adaptation_checkpoint}")

In [None]:
# Summary of the exported .nemo files (paths relative to the current working directory).
print(f"FastPitch nemo file: {fintuned_fastpitch}")
print(f"HiFi-Gan nemo file: {fintuned_hifigan}")