In [1]:
import nemo
import nemo.collections.asr as nemo_asr
import pytorch_lightning as pl
from omegaconf import DictConfig
import pathlib
import nemo.collections.asr as nemo_asr
import pytorch_lightning as pl
import os
import matplotlib.pyplot as plt
import re

import copy
import json
import os
import tempfile
from math import ceil
from typing import Dict, List, Optional, Union

import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
from torch.utils.data import ChainDataset
from tqdm.auto import tqdm

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.metrics.wer import WER
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
from nemo.collections.asr.parts.mixins import ASRModuleMixin
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType
from nemo.utils import logging
from nemo.collections.tts.models import FastPitchHifiGanE2EModel

import IPython.display as ipd
import librosa
import librosa.display
import numpy as np
import torch
from matplotlib import pyplot as plt
%matplotlib inline

# Reduce logging messages for this notebook
from nemo.utils import logging
logging.setLevel(logging.ERROR)

from nemo.collections.tts.models import FastPitchModel
from nemo.collections.tts.models import HifiGanModel
from nemo.collections.tts.helpers.helpers import regulate_len

import torch.nn.functional as F


[NeMo W 2021-12-20 05:32:28 optimizers:47] Apex was not found. Using the lamb optimizer will error out.
    
[NeMo W 2021-12-20 05:32:28 nmse_clustering:54] Using eigen decomposition from scipy, upgrade torch to 1.9 or higher for faster clustering
################################################################################
###          (please add 'export KALDI_ROOT=<your_path>' in your $HOME/.profile)
###          (or run as: KALDI_ROOT=<your_path> python <your_script>.py)
################################################################################

      '"sox" backend is being deprecated. '
    
[NeMo W 2021-12-20 05:32:28 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text_dali._AudioTextDALIDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.


In [2]:
try:
    from ruamel.yaml import YAML
except ModuleNotFoundError:
    # Some installs expose the package under the `ruamel_yaml` name instead.
    from ruamel_yaml import YAML

config_path = 'stt_en_citrinet_256_gamma_0_25'
config_name = 'model_config.yaml'
yaml = YAML(typ='safe')

# Model config that ships in the checkpoint directory.
with open(os.path.join(config_path, config_name)) as f:
    config = yaml.load(f)

# NOTE(review): `asr_model` is created with from_pretrained, and in the code shown
# only the train_ds / validation_ds sections are consumed (setup_training_data /
# setup_test_data). The tokenizer, spec_augment and optim edits appear unused — confirm.

# Tokenizer: SentencePiece BPE model directory.
config['tokenizer'].update({
    'dir': 'citrinet_tokenizer/tokenizer_spe_unigram_v1024',
    'type': 'bpe',
})

# Data: an4 manifests, one utterance per batch.
# LJSpeech alternatives:
#   ../../datasets/LJSpeech-1.1/train_manifest.json
#   ../../datasets/LJSpeech-1.1/test_manifest.json
config['train_ds'].update({
    'manifest_filepath': "../../datasets/an4/train_manifest.json",
    'batch_size': 1,
    'num_workers': 12,
    'pin_memory': True,
})
config['validation_ds'].update({
    'manifest_filepath': "../../datasets/an4/test_manifest.json",
    'batch_size': 1,
    'num_workers': 12,
    'pin_memory': True,
})

# SpecAugment disabled (no frequency or time masks).
config['spec_augment'].update({'freq_masks': 0, 'time_masks': 0})

# Optimizer and LR schedule.
config['optim'].update({
    'lr': 0.01,
    'name': 'novograd',
    'betas': [0.8, 0.25],
    'weight_decay': 0.001,
})
config['optim']['sched'].update({'warmup_steps': 1000, 'min_lr': 0.00001})

# Tokenizer artifact files stored inside the checkpoint directory.
config['tokenizer'].update({
    'model_path': 'stt_en_citrinet_256_gamma_0_25/3d20ebb793c84a64a20c7ad26fc64d62_tokenizer.model',
    'vocab_path': 'stt_en_citrinet_256_gamma_0_25/df5191f216004f10a268c44e90fdb63f_vocab.txt',
    'spe_tokenizer_vocab': 'stt_en_citrinet_256_gamma_0_25/b774eaac83804907843607272fde21a4_tokenizer.vocab',
})

In [3]:
# Start from the pretrained Citrinet-256 (gamma 0.25) BPE checkpoint.
# Building the model directly from `config` instead would be:
#   nemo_asr.models.EncDecCTCModelBPE(cfg=DictConfig(config))
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    'stt_en_citrinet_256_gamma_0_25'
)


In [4]:
# Pretrained TTS pair: FastPitch produces spectrograms from text tokens,
# and HiFi-GAN converts spectrograms into waveforms.
fastpitch = FastPitchModel.from_pretrained(
    "tts_en_fastpitch"
)
hifigan = HifiGanModel.from_pretrained(
    "tts_hifigan"
)

In [5]:
# Set every FastPitch config section that carries a sample rate to 16 kHz.
# NOTE(review): this only edits `fastpitch.cfg`; modules already built from the
# config at load time are not re-created here — confirm the new rate takes effect.
fastpitch.cfg['train_ds']['sample_rate'] = \
    fastpitch.cfg['preprocessor']['sample_rate'] = \
    fastpitch.cfg['validation_ds']['sample_rate'] = 16000

In [6]:
# Set the HiFi-GAN preprocessor config to 16 kHz.
# NOTE(review): this edits `hifigan.cfg` only; whether the already-built
# preprocessor module picks up the new rate is not shown here — confirm.
hifigan.cfg['preprocessor']['sample_rate'] = 16000

In [7]:
# Attach the ASR training data (read later via asr_model.train_dataloader()) and
# test data. The test side is built from the `validation_ds` section of `config`.
asr_model.setup_training_data(train_data_config=config['train_ds'])
asr_model.setup_test_data(test_data_config=config['validation_ds'])


In [8]:
def fastpitch_training_step(self, audio, audio_lens, text, text_lens, durs, pitch, batch_idx=1):
    """Compute the FastPitch training loss (mel + duration + pitch) for one batch.

    ``self`` is the FastPitch model. No logging or optimizer calls are made here.
    ``batch_idx`` is accepted but not used.

    Args:
        audio, audio_lens: waveforms and their lengths; ``self.preprocessor`` turns
            them into the target mel spectrograms.
        text, text_lens: input token ids and their lengths.
        durs: target durations, or ``None`` to use the hard-alignment durations
            returned by the forward pass.
        pitch: pitch passed into the forward pass. The pitch *target* for the loss
            is the pitch tensor returned by the forward pass.

    Returns:
        ``mel_loss + dur_loss + pitch_loss``.
    """
    target_mels, target_mel_lens = self.preprocessor(input_signal=audio, length=audio_lens)

    # Ground-truth spectrograms are fed to the forward pass only when the model
    # learns its own alignment.
    spec_input = target_mels if self.learn_alignment else None
    (mels_pred, _, _, log_durs_pred, pitch_pred,
     _, _, _, attn_hard_dur, pitch_target) = self(
        text=text,
        durs=durs,
        pitch=pitch,
        speaker=None,
        pace=1.0,
        spec=spec_input,
        attn_prior=None,
        mel_lens=target_mel_lens,
        input_lens=text_lens,
    )

    duration_target = attn_hard_dur if durs is None else durs

    mel_loss = self.mel_loss(spect_predicted=mels_pred, spect_tgt=target_mels)
    dur_loss = self.duration_loss(log_durs_predicted=log_durs_pred, durs_tgt=duration_target, len=text_lens)
    pitch_loss = self.pitch_loss(pitch_predicted=pitch_pred, pitch_tgt=pitch_target, len=text_lens)
    return mel_loss + dur_loss + pitch_loss

def hifigan_training_step(self, batch, batch_idx, optimizer_idx):
    """Compute the HiFi-GAN discriminator loss (MPD + MSD) for one batch.

    Only the discriminator half of a HiFi-GAN step is computed. No generator loss
    is returned and no optimizer is stepped here. ``self`` is the HiFi-GAN model;
    ``batch_idx`` and ``optimizer_idx`` are unused.

    Args:
        batch: ``(audio, audio_len, audio_mel)`` when ``self.input_as_mel`` is set
            (mels precomputed by a spectrogram generator). Otherwise it is
            ``(audio, audio_len)``, and the generator's input mel is computed from
            ``audio``.

    Returns:
        Sum of the multi-period (``self.mpd``) and multi-scale (``self.msd``)
        discriminator losses.
    """
    if self.input_as_mel:
        audio, audio_len, generator_input = batch
    else:
        audio, audio_len = batch
        generator_input, _ = self.audio_to_melspec_precessor(audio, audio_len)

    # Mel targets/predictions for an L1 mel loss; not used by the discriminator loss.
    audio_trg_mel, _ = self.trg_melspec_fn(audio, audio_len)
    real_audio = audio.unsqueeze(1)

    generated_audio = self.generator(x=generator_input)
    audio_pred_mel, _ = self.trg_melspec_fn(generated_audio.squeeze(1), audio_len)

    # Discriminators score a detached copy, so this loss does not reach the generator.
    fake_audio = generated_audio.detach()
    loss_d = 0
    for discriminator in (self.mpd, self.msd):
        score_real, score_gen, _, _ = discriminator(y=real_audio, y_hat=fake_audio)
        disc_loss, _, _ = self.discriminator_loss(
            disc_real_outputs=score_real, disc_generated_outputs=score_gen
        )
        loss_d = loss_d + disc_loss
    return loss_d
#     self.manual_backward(loss_d)
#     self.optim_d.step()

#     # train generator
#     self.optim_g.zero_grad()
#     loss_mel = F.l1_loss(audio_pred_mel, audio_trg_mel)
#     _, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd(y=audio, y_hat=audio_pred)
#     _, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd(y=audio, y_hat=audio_pred)
#     loss_fm_mpd = self.feature_loss(fmap_r=fmap_mpd_real, fmap_g=fmap_mpd_gen)
#     loss_fm_msd = self.feature_loss(fmap_r=fmap_msd_real, fmap_g=fmap_msd_gen)
#     loss_gen_mpd, _ = self.generator_loss(disc_outputs=mpd_score_gen)
#     loss_gen_msd, _ = self.generator_loss(disc_outputs=msd_score_gen)
#     loss_g = loss_gen_msd + loss_gen_mpd + loss_fm_msd + loss_fm_mpd + loss_mel * self.l1_factor
#     self.manual_backward(loss_g)
#     self.optim_g.step()

def citrinet_training_step(self, batch, batch_nb):
    signal, signal_len, transcript, transcript_len = batch
    
    log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)

    loss_value = self.loss(
        log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len
    )

    return loss_value

# for batch in asr_model.test_dataloader():
#     signal, signal_len, transcript, transcript_len = batch
        
#     target = asr_model._wer.ctc_decoder_predictions_tensor(
#         transcript, predictions_len=transcript_len, return_hypotheses=False,
#     )
    
#     parsed = fastpitch.parse(target[0])
    
#     spec, _, durs_pred, _, pitch_pred, *_ = fastpitch.eval()(text=parsed,
#                                                              durs=None, pitch=None, speaker=None, pace=1.0)
        
#     loss = fastpitch_training_step(fastpitch.cpu().train(), signal, signal_len, parsed,
#                                    torch.tensor([len(parsed)]), durs_pred, pitch_pred, batch_idx=1).cpu()
    
#     hifigan.input_as_mel = True
#     loss_2 = hifigan_training_step(hifigan.cpu().train(), batch=[signal, signal_len, spec], batch_idx=1, optimizer_idx=1)
    
#     loss_asr = citrinet_training_step(asr_model.cpu().train(), batch=batch, batch_nb=1)
    
#     print(loss, loss_2, loss_asr)
#     break

In [15]:
# Dual Transformation: text -> predicted_speech -> predicted_text -> loss
def dt_text_training_step(batch, batch_idx, optimizer_idx=None):
    """Run one text-side dual-transformation step and return the ASR loss.

    Steps:
    1. The ground-truth transcript of the first utterance in ``batch`` is turned
       back into text.
    2. That text is synthesized with FastPitch + HiFi-GAN.
    3. The synthetic audio is scored by ``asr_model`` against the original
       transcript tokens.

    Uses the notebook globals ``asr_model``, ``fastpitch`` and ``hifigan`` and the
    helpers ``citrinet_training_step``, ``fastpitch_training_step`` and
    ``hifigan_training_step``. Only the first sample of the batch is used.

    Args:
        batch: ``(signal, signal_len, transcript, transcript_len)``; only the
            transcript tensors are read.
        batch_idx: step index, forwarded to the per-model training steps.
        optimizer_idx: forwarded to ``hifigan_training_step``; defaults to
            ``batch_idx`` when omitted.

    Returns:
        The ASR loss (``citrinet_training_step``) on the synthesized audio.
    """
    if optimizer_idx is None:
        optimizer_idx = batch_idx

    _, _, transcript, transcript_len = batch

    # Ground-truth token ids are detokenized directly. Passing them through the CTC
    # greedy decoder (as if they were predictions) would merge consecutive repeated
    # tokens and drop blank ids, altering the reference text.
    num_tokens = int(transcript_len[0])
    text = asr_model.tokenizer.ids_to_text(transcript[0][:num_tokens].tolist())

    parsed = fastpitch.parse(text).cpu()
    # `parsed` is used as a one-item batch (passed straight to fastpitch as `text`),
    # so the token count is len(parsed[0]); len(parsed) would be the batch size.
    parsed_len = torch.tensor([len(parsed[0])])

    spec, _, durs_pred, _, pitch_pred, *_ = fastpitch.cpu().eval()(
        text=parsed, durs=None, pitch=None, speaker=None, pace=1.0
    )

    signal = hifigan.cpu().eval().convert_spectrogram_to_audio(spec=spec).to('cpu')
    signal_len = torch.tensor([len(signal[0])])

    loss = citrinet_training_step(
        asr_model.cpu().train(),
        batch=[signal, signal_len, transcript, transcript_len],
        batch_nb=batch_idx,
    )

    # NOTE(review): the FastPitch and HiFi-GAN losses below are computed but neither
    # returned nor back-propagated; only the ASR loss leaves this function. Decide
    # whether they belong in the returned loss.
    fastpitch_training_step(
        fastpitch.cpu().train(), signal, signal_len, parsed, parsed_len,
        durs_pred, pitch_pred, batch_idx=batch_idx,
    )

    hifigan.input_as_mel = True
    hifigan_training_step(
        hifigan.cpu().train(),
        batch=[signal, signal_len, spec], batch_idx=batch_idx, optimizer_idx=optimizer_idx,
    )

    return loss

In [16]:
# Dual Transformation: speech -> predicted_text -> predicted_speech -> loss
def dt_speech_training_step(batch, batch_idx, optimizer_idx=None):
    """Run one speech-side dual-transformation step and return a waveform error.

    Steps:
    1. The first utterance in ``batch`` is transcribed greedily by ``asr_model``.
    2. The hypothesis is re-synthesized with FastPitch + HiFi-GAN.
    3. The per-model training steps are run on the resulting pairs.

    The function depends only on its arguments and the notebook models/helpers.
    Unlike the earlier version, it does not read the loop variable ``train_batch``.

    Args:
        batch: ``(signal, signal_len, transcript, transcript_len)``; only the
            signal tensors are read.
        batch_idx: step index, forwarded to the per-model training steps.
        optimizer_idx: forwarded to ``hifigan_training_step``; defaults to
            ``batch_idx`` when omitted.

    Returns:
        Sum of squared sample differences between the original and re-synthesized
        waveform over their common length. The prediction is detached, so no
        gradient flows back through the synthesized audio.
    """
    if optimizer_idx is None:
        optimizer_idx = batch_idx

    signal, signal_len, _, _ = batch

    _, encoded_len, greedy_predictions = asr_model.eval().cpu()(
        input_signal=signal, input_signal_length=signal_len
    )
    hypotheses = asr_model._wer.ctc_decoder_predictions_tensor(
        greedy_predictions, predictions_len=encoded_len
    )
    hypothesis = hypotheses[0]

    # FastPitch input tokens for the hypothesis, used as a one-item batch.
    tts_tokens = fastpitch.parse(hypothesis)
    tts_tokens_len = torch.tensor([len(tts_tokens[0])])

    spec, _, durs_pred, _, pitch_pred, *_ = fastpitch.eval()(
        text=tts_tokens, durs=None, pitch=None, speaker=None, pace=1.0
    )
    pred = hifigan.cpu().eval().convert_spectrogram_to_audio(spec=spec.cpu()).to('cpu')

    # The ASR loss scores ASR-tokenizer ids, so the hypothesis is re-tokenized with
    # the ASR tokenizer rather than reusing FastPitch's input tokens.
    asr_tokens = torch.tensor([asr_model.tokenizer.text_to_ids(hypothesis)], dtype=torch.long)
    asr_tokens_len = torch.tensor([asr_tokens.shape[1]])
    citrinet_training_step(
        asr_model.cpu().train(),
        batch=[signal, signal_len, asr_tokens, asr_tokens_len], batch_nb=batch_idx,
    )

    fastpitch_training_step(
        fastpitch.cpu().train(), signal, signal_len, tts_tokens, tts_tokens_len,
        durs_pred, pitch_pred, batch_idx=batch_idx,
    )

    hifigan.input_as_mel = True
    hifigan_training_step(
        hifigan.cpu().train(),
        batch=[signal, signal_len, spec], batch_idx=batch_idx, optimizer_idx=optimizer_idx,
    )

    # NOTE(review): the vocoder's native output rate may differ from the input audio
    # rate; confirm they match before reading this as a reconstruction error.
    common_len = min(len(signal[0]), len(pred[0]))
    return ((signal[0][:common_len] - pred[0][:common_len].detach()) ** 2).sum()
    

tensor(74.8404, grad_fn=<MeanBackward0>)


In [34]:
def flip_batch(batch):
    """Return a copy of ``batch`` with every sample reversed along its first axis.

    Each element of ``batch`` is treated as a batched tensor:
    - Tensors with a per-sample axis (ndim >= 2) are reversed along dim 1. This is
      the same as flipping each ``tensor[j]`` along its own dim 0.
    - Tensors without one (ndim <= 1, e.g. per-sample lengths) are copied
      unchanged, because reversing a scalar is a no-op.

    The input tensors are never modified, so the caller's batch stays usable
    after flipping.

    NOTE(review): padded samples are reversed together with their padding, which
    moves the padding to the front. This is harmless for batch_size=1; revisit for
    larger batches.

    Args:
        batch: sequence of tensors, e.g. ``(signal, signal_len, transcript, transcript_len)``.

    Returns:
        list of tensors in the same order as ``batch``.
    """
    return [
        torch.flip(tensor, dims=[1]) if tensor.dim() >= 2 else tensor.clone()
        for tensor in batch
    ]

In [None]:
for step, train_batch in enumerate(asr_model.train_dataloader()):
    # Time-reversed copy of the batch, built once from clones. The original batch
    # is never modified, even if flip_batch writes into its input.
    flipped_batch = flip_batch([tensor.clone() for tensor in train_batch])

    loss = 0

    # Text -> speech -> text, on the original and the reversed batch.
    loss += dt_text_training_step(train_batch, step)
    loss += dt_text_training_step(flipped_batch, step)

    # Speech -> text -> speech, on the original and the reversed batch.
    loss += dt_speech_training_step(train_batch, step)
    loss += dt_speech_training_step(flipped_batch, step)

    ###
    # Add loss for paired data
    ###

    ###
    # Add loss for Denoising Auto-Encoder
    ###

    # NOTE(review): `loss` is accumulated but never back-propagated and no optimizer
    # is stepped, so this loop does not update any model yet.
    # print(step, loss)
    