In [1]:
import models
import torch
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt

from tacotron2_common.utils import load_wav_to_torch
import tacotron2_common.layers as layers
import json


import umap

import dllogger as DLLogger
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def checkpoint_from_distributed(state_dict):
    """
    Checks whether checkpoint was generated by DistributedDataParallel. DDP
    wraps model in additional "module.", it needs to be unwrapped for single
    GPU inference.
    :param state_dict: model's state dict
    """
    ret = False
    for key, _ in state_dict.items():
        if key.find('module.') != -1:
            ret = True
            break
    return ret


def unwrap_distributed(state_dict):
    """
    Unwraps model from DistributedDataParallel.
    DDP wraps model in additional "module.", it needs to be removed for single
    GPU inference.
    :param state_dict: model's state dict
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key.replace('module.', '')
        new_state_dict[new_key] = value
    return new_state_dict


def load_and_setup_model(model_name, parser, checkpoint, fp16_run, cpu_run,
                         forward_is_infer=False, jittable=False):
    model_parser = models.model_parser(model_name, parser, add_help=False)
    model_args, _ = model_parser.parse_known_args()

    model_config = models.get_model_config(model_name, model_args)
    model = models.get_model(model_name, model_config, cpu_run=cpu_run,
                             forward_is_infer=forward_is_infer,
                             jittable=jittable)

    if checkpoint is not None:
        if cpu_run:
            state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))['state_dict']
        else:
            state_dict = torch.load(checkpoint)['state_dict']
        if checkpoint_from_distributed(state_dict):
            state_dict = unwrap_distributed(state_dict)

        model.load_state_dict(state_dict, strict=False)

    if model_name == "WaveGlow":
        model = model.remove_weightnorm(model)

    model.eval()

    if fp16_run:
        model.half()

    return model


# taken from tacotron2/data_function.py:TextMelCollate.__call__
def pad_sequences(batch):
    # Right zero-pad all one-hot text sequences to max input length
    input_lengths, ids_sorted_decreasing = torch.sort(
        torch.LongTensor([len(x) for x in batch]),
        dim=0, descending=True)
    max_input_len = input_lengths[0]

    text_padded = torch.LongTensor(len(batch), max_input_len)
    text_padded.zero_()
    for i in range(len(ids_sorted_decreasing)):
        text = batch[ids_sorted_decreasing[i]]
        text_padded[i, :text.size(0)] = text

    return text_padded, input_lengths


def prepare_input_sequence(texts, cpu_run=False):
    emotions = {}
    d = []
    for i,text in enumerate(texts):
        flag = True
        while flag:
            flag = False
            for emo in tag2ref.keys():
                pos = text.find(emo)
                if pos != -1:
                    text = text.replace(emo, '', 1)
                    if emo in emotions:
                        emotions[emo].append(pos)
                    else:
                        emotions[emo] = [pos]
                    flag = True

        d.append(torch.IntTensor(
            text_to_sequence(text, ['english_cleaners'])[:]))

    text_padded, input_lengths = pad_sequences(d)
    if not cpu_run:
        text_padded = text_padded.cuda().long()
        input_lengths = input_lengths.cuda().long()
    else:
        text_padded = text_padded.long()
        input_lengths = input_lengths.long()

    return text_padded, input_lengths, emotions

class MelLoader():
    def __init__(self, text_cleaners, max_wav_value, sampling_rate, filter_length, hop_length, win_length, n_mel_channels, mel_fmin, mel_fmax, segment_length=None):
        self.text_cleaners = text_cleaners
        self.max_wav_value = max_wav_value
        self.sampling_rate = sampling_rate
        self.segment_length = segment_length
        self.stft = layers.TacotronSTFT(
            filter_length, hop_length, win_length,
            n_mel_channels, sampling_rate, mel_fmin, mel_fmax)

    def get_mel(self, filename):
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.stft.sampling_rate:
            raise ValueError("{} {} SR doesn't match target {} SR".format(
                sampling_rate, self.stft.sampling_rate))
        audio_norm = audio / self.max_wav_value
        audio_norm = audio_norm.unsqueeze(0)
        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = torch.squeeze(melspec, 0)

        return melspec
    
    def get_mel_audio(self, filename):
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.stft.sampling_rate:
            raise ValueError("{} {} SR doesn't match target {} SR".format(
                sampling_rate, self.stft.sampling_rate))
        
        # Take segment
        if audio.size(0) >= self.segment_length:
            max_audio_start = audio.size(0) - self.segment_length
            audio_start = torch.randint(0, max_audio_start + 1, size=(1,)).item()
            audio_segment = audio[audio_start:audio_start+self.segment_length]
        else:
            audio_segment = torch.nn.functional.pad(
                audio, (0, self.segment_length - audio.size(0)), 'constant').data
        audio_segment = audio_segment / self.max_wav_value

        audio_norm = audio / self.max_wav_value
        audio_norm = audio_norm.unsqueeze(0)
        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = torch.squeeze(melspec, 0)
        
        return melspec, audio_segment

class MeasureTime():
    def __init__(self, measurements, key, cpu_run=False):
        self.measurements = measurements
        self.key = key
        self.cpu_run = cpu_run

    def __enter__(self):
        if not self.cpu_run:
            torch.cuda.synchronize()
        self.t0 = time.perf_counter()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if not self.cpu_run:
            torch.cuda.synchronize()
        self.measurements[self.key] = time.perf_counter() - self.t0

In [3]:
tacotron_path = 'output/checkpoint_Tacotron2_6420.pt'

class Args(argparse.Namespace):
    input='phrases/phrase.txt'
    output='output/'
    suffix=''
    tacotron2=tacotron_path
    waveglow='checkpoints/waveglow_1076430_14000_amp'
    sigma_infer=0.9
    denoising_strength=0.01
    sampling_rate=22050
    fp16=False
    cpu=False
    log_file='nvlog.json'
    include_warmup=False
    stft_hop_length=256
    ref_path='/home/madusov/vkr/data/ssw_esd_ljspeech_22050/wavs/5_50_2.wav'
    
class Args2(argparse.Namespace):
    input='phrases/phrase.txt'
    output='output/'
    suffix=''
    tacotron2=tacotron_path
    waveglow='checkpoints/waveglow_1076430_14000_amp'
    sigma_infer=0.9
    denoising_strength=0.01, 
    sampling_rate=22050, 
    fp16=False, 
    cpu=False 
    log_file='nvlog.json'
    include_warmup=False 
    stft_hop_length=256 
    ref_path='/home/madusov/vkr/data/ssw_esd_ljspeech_22050/wavs/5_50_2.wav'
    mask_padding=False
    n_mel_channels=80 
    n_symbols=148
    symbols_embedding_dim=512
    encoder_kernel_size=5
    encoder_n_convolutions=3
    encoder_embedding_dim=512
    n_frames_per_step=1
    decoder_rnn_dim=1024
    prenet_dim=256
    max_decoder_steps=2000
    gate_threshold=0.5
    p_attention_dropout=0.1
    p_decoder_dropout=0.1
    decoder_no_early_stopping=False
    attention_rnn_dim=1024
    attention_dim=128
    attention_location_n_filters=32
    attention_location_kernel_size=31 
    postnet_embedding_dim=512
    postnet_kernel_size=5
    postnet_n_convolutions=5

In [7]:
args = Args()
args_model = Args2()

In [5]:
log_file = os.path.join(args.output, args.log_file)
DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, log_file),
                        StdOutBackend(Verbosity.VERBOSE)])

In [6]:
for k,v in vars(args).items():
    DLLogger.log(step="PARAMETER", data={k:v})
DLLogger.log(step="PARAMETER", data={'model_name':'Tacotron2_PyT'})

args.segment_length = 50000
tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2,
                                 args.fp16, args.cpu, forward_is_infer=True)

# load tacotron
model_config = models.get_model_config('Tacotron2', args_model)
model = models.get_model('Tacotron2', model_config, cpu_run=cpu_run,
                         forward_is_infer=forward_is_infer,
                         jittable=jittable)

if checkpoint is not None:
    if cpu_run:
        state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))['state_dict']
    else:
        state_dict = torch.load(checkpoint)['state_dict']
    if checkpoint_from_distributed(state_dict):
        state_dict = unwrap_distributed(state_dict)

    model.load_state_dict(state_dict, strict=True)



waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow,
                                args.fp16, args.cpu, forward_is_infer=True,
                                jittable=True)
denoiser = Denoiser(waveglow)
if not args.cpu:
    denoiser.cuda()

waveglow.make_ts_scriptable()
jitted_waveglow = torch.jit.script(waveglow)
# jitted_tacotron2 = torch.jit.script(tacotron2)

texts = []
try:
    f = open(args.input, 'r')
    texts = f.readlines()
except:
    print("Could not read file")
    sys.exit(1)

with open('config.json') as f:
    audio_config = json.load(f)

loader = MelLoader(text_cleaners=['english_cleaners'], 
                   max_wav_value=audio_config['audio']['max-wav-value'], 
                   sampling_rate=audio_config['audio']['sampling-rate'], 
                   filter_length=audio_config['audio']['filter-length'], 
                   hop_length=audio_config['audio']['hop-length'], 
                   win_length=audio_config['audio']['win-length'], 
                   n_mel_channels=80, 
                   mel_fmin=audio_config['audio']['mel-fmin'], 
                   mel_fmax=audio_config['audio']['mel-fmax'],
                   segment_length=args.segment_length)
# ref_mel = loader.get_mel(args.ref_path)
# ref_mel = ref_mel.unsqueeze(0)

# load emotion mels
for emo, path in tag2ref.items():
    emo_mel, emo_audio_segment = loader.get_mel_audio(path)
    emo_mel = emo_mel.unsqueeze(0)

    if not args.cpu:
        emo_mel = emo_mel.to('cuda')

    tag2ref[emo] = {'mel': emo_mel, 'audio': emo_audio_segment}

# if not args.cpu:
#     ref_mel = ref_mel.to('cuda')


# if args.include_warmup:
#     sequence = torch.randint(low=0, high=148, size=(1,50)).long()
#     input_lengths = torch.IntTensor([sequence.size(1)]).long()
#     if not args.cpu:
#         sequence = sequence.cuda()
#         input_lengths = input_lengths.cuda()
#     for i in range(3):
#         with torch.no_grad():
#             mel, mel_lengths, _ = tacotron2(sequence, input_lengths)
#             _ = jitted_waveglow(mel)

measurements = {}

sequences_padded, input_lengths, emotions = prepare_input_sequence(texts, args.cpu)

for emo, positions in emotions.items():
    emotions[emo] = {'pos': positions, 'mel': tag2ref[emo]['mel'], 'audio': tag2ref[emo]['audio']}

if '<NEUTRAL>' not in emotions:
    emotions['<NEUTRAL>'] = {'pos': [], 'mel': tag2ref['<NEUTRAL>']['mel'], 'audio': tag2ref['<NEUTRAL>']['audio']}




DLL 2025-05-04 20:40:44.584222 - PARAMETER model_name : Tacotron2_PyT 


NameError: name 'parser' is not defined

In [None]:
with torch.no_grad(), MeasureTime(measurements, "tacotron2_time", args.cpu):
    mel, mel_lengths, alignments = tacotron2(sequences_padded, input_lengths, emotions)

with torch.no_grad(), MeasureTime(measurements, "waveglow_time", args.cpu):
    audios = jitted_waveglow(mel, sigma=args.sigma_infer)
    audios = audios.float()
with torch.no_grad(), MeasureTime(measurements, "denoiser_time", args.cpu):
    audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)

print("Stopping after",mel.size(2),"decoder steps")
tacotron2_infer_perf = mel.size(0)*mel.size(2)/measurements['tacotron2_time']
waveglow_infer_perf = audios.size(0)*audios.size(1)/measurements['waveglow_time']

DLLogger.log(step=0, data={"tacotron2_items_per_sec": tacotron2_infer_perf})
DLLogger.log(step=0, data={"tacotron2_latency": measurements['tacotron2_time']})
DLLogger.log(step=0, data={"waveglow_items_per_sec": waveglow_infer_perf})
DLLogger.log(step=0, data={"waveglow_latency": measurements['waveglow_time']})
DLLogger.log(step=0, data={"denoiser_latency": measurements['denoiser_time']})
DLLogger.log(step=0, data={"latency": (measurements['tacotron2_time']+measurements['waveglow_time']+measurements['denoiser_time'])})

for i, audio in enumerate(audios):

    plt.imshow(alignments[i].float().data.cpu().numpy().T, aspect="auto", origin="lower")
    figure_path = os.path.join(args.output,"alignment_"+str(i)+args.suffix+".png")
    plt.savefig(figure_path)

    audio = audio[:mel_lengths[i]*args.stft_hop_length]
    audio = audio/torch.max(torch.abs(audio))
    audio_path = os.path.join(args.output,"audio_"+str(i)+args.suffix+".wav")
    write(audio_path, args.sampling_rate, audio.cpu().numpy())

DLLogger.flush()