# 1 - Imports

In [1]:
import os
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import soundfile

from torch import Tensor
from typing import List, Tuple, Optional

Import requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.[0m
  from numba.decorators import jit as optional_jit
Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.[0m
  from numba.decorators import jit as optional_jit


# 2 - Define some basic functions

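Helpers used by the rest of the notebook:

- `load_untts_checkpoint(path)` – load the UnTTS model, its hparams and its speaker-id lookup.
- `load_waveglow_checkpoint(vocoder_path, config_fpath)` – build the WaveGlow vocoder from its JSON config and load its weights.
- `preprocess` – a callable class that turns raw text, speaker ids and `sylps` values into UnTTS input tensors.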

In [2]:
def load_untts_checkpoint(path):
    """Load an UnTTS checkpoint and return the model ready for fp16 GPU inference.

    Args:
        path: checkpoint file holding a dict with the keys 'speaker_id_lookup',
            'hparams', 'iteration' and 'state_dict'.

    Returns:
        (model, hparams, speaker_id_lookup). The model has been moved to CUDA,
        switched to eval mode and cast to half precision.

    Raises:
        AssertionError: if ``path`` does not exist.
    """
    print("Loading UnTTS... ")
    from CookieTTS._2_ttm.untts.model import load_model

    assert os.path.exists(path), f"untts checkpoint file at '{path}' does not exist."
    ckpt = torch.load(path)

    lookup = ckpt['speaker_id_lookup']
    hp = ckpt['hparams']
    n_iters = ckpt['iteration']
    weights = ckpt['state_dict']

    net = load_model(hp)
    net.load_state_dict(weights)
    # nn.Module.cuda/eval/half modify the module in place; the returned self is not needed.
    net.cuda().eval().half()

    print("Loaded UnTTS!")
    print(f"This UnTTS model has been trained for {n_iters} Iterations.")
    return net, hp, lookup


def _read_waveglow_config(config_fpath):
    """Parse a WaveGlow JSON config; return (train_config, vocoder_config).

    Missing optional data_config keys are filled with "off" defaults, and
    ``n_mel_channels`` defaults to 160 when the waveglow_config omits it.
    """
    import json

    with open(config_fpath) as f:
        config = json.load(f)
    train_config = config["train_config"]
    data_config = config["data_config"]
    waveglow_config = config["waveglow_config"]

    # Older configs predate these options. The (misspelled) key names are passed
    # straight through as WaveGlow keyword arguments, so they must stay as-is.
    data_config.setdefault('preempthasis', 0.0)
    data_config.setdefault('use_logvar_channels', False)
    data_config.setdefault('load_hidden_from_disk', False)
    data_config.setdefault('iso226_empthasis', False)
    data_config['n_mel_channels'] = waveglow_config.get('n_mel_channels', 160)

    vocoder_config = {
        **waveglow_config,
        'win_length': data_config['win_length'],
        'hop_length': data_config['hop_length'],
        'preempthasis': data_config['preempthasis'],
        'n_mel_channels': data_config["n_mel_channels"],
        'use_logvar_channels': data_config["use_logvar_channels"],
        'load_hidden_from_disk': data_config["load_hidden_from_disk"],
        'iso226_empthasis': data_config["iso226_empthasis"],
    }
    return train_config, vocoder_config


def load_waveglow_checkpoint(vocoder_path, config_fpath):
    """Build a WaveGlow vocoder from its JSON config and load checkpointed weights.

    Args:
        vocoder_path: checkpoint file containing 'model' (state dict),
            'iteration' and 'speaker_lookup'.
        config_fpath: JSON config with 'train_config', 'data_config' and
            'waveglow_config' sections.

    Returns:
        (waveglow, denoiser, training_sigma, speaker_lookup, vocoder_config).
        ``waveglow`` is fp16, on CUDA and in eval mode. ``denoiser`` is always
        None because the Denoiser is currently disabled.
    """
    import gc

    train_config, vocoder_config = _read_waveglow_config(config_fpath)
    print(vocoder_config)
    print(f"Config File from '{config_fpath}' successfully loaded.")

    # initialize model
    print("initializing WaveGlow model... ", end="")
    from CookieTTS._4_mtw.waveglow.efficient_model_ax import WaveGlow
    waveglow = WaveGlow(**vocoder_config)
    print("Done!")
    print(f"{sum(p.numel() for p in waveglow.parameters())/1e6:.3f}M Parameters")

    # load checkpoint from file and overwrite the freshly initialized weights
    print("loading WaveGlow checkpoint... ", end="")
    checkpoint = torch.load(vocoder_path)
    waveglow.load_state_dict(checkpoint['model'])
    waveglow.half()         # half precision for faster inference
    waveglow.cuda().eval()  # move to GPU, inference mode
    if hasattr(waveglow, 'remove_weightnorm'):
        waveglow.remove_weightnorm()  # weight norm is only needed for training
    print("Done!")

    # The Denoiser is disabled; the slot is kept so the return tuple stays stable.
    # To re-enable: Denoiser(waveglow, n_mel_channels=vocoder_config['n_mel_channels']
    #               * (vocoder_config['use_logvar_channels'] + 1))
    denoiser = None

    vocoder_iters = checkpoint['iteration']
    print(f"WaveGlow trained for {vocoder_iters} iterations")
    speaker_lookup = checkpoint['speaker_lookup']  # ids lookup
    training_sigma = train_config['sigma']

    print("Clearing CUDA Cache... ", end='')
    del checkpoint
    # Collect unreachable objects first so memory held only by reference cycles
    # is released before the CUDA cache is emptied. This replaces a former
    # no-op gc.get_objects() loop that swallowed every exception.
    gc.collect()
    torch.cuda.empty_cache()
    print("Done!")

    return waveglow, denoiser, training_sigma, speaker_lookup, vocoder_config

In [3]:
import json
from CookieTTS.utils.text import text_to_sequence
from CookieTTS.utils.text.ARPA import ARPA

# torchMoji imports
import csv
import numpy as np
import os
from CookieTTS.utils.torchmoji.sentence_tokenizer import SentenceTokenizer
from CookieTTS.utils.torchmoji.model_def import torchmoji_feature_encoding
from CookieTTS.utils.torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
from CookieTTS.utils.dataset.utils import load_filepaths_and_text

class preprocess():
    """Convert raw text, speaker ids and ``sylps`` values into UnTTS input tensors.

    Args:
        hparams: UnTTS hparams; ``text_cleaners``, ``start_token``,
            ``stop_token`` and ``torchMoji_attDim`` are read from it.
        dict_path: pronunciation dictionary file loaded with ``ARPA``.
        speaker_ids: lookup from external speaker id (int key) to the model's
            internal speaker index.
        torchmoji_maxlen: maximum token length for the torchMoji tokenizer.
    """
    def __init__(self, hparams, dict_path, speaker_ids, torchmoji_maxlen=128):
        self.text_cleaners = hparams.text_cleaners
        self.arpa = ARPA(dict_path)
        self.speaker_ids = speaker_ids
        self.start_token = hparams.start_token
        self.stop_token = hparams.stop_token

        print(f'Tokenizing using dictionary from {VOCAB_PATH}')
        with open(VOCAB_PATH, 'r') as f:
            vocabulary = json.load(f)
        self.tm_tokenizer = SentenceTokenizer(vocabulary, torchmoji_maxlen)

        print('Loading TorchMoji model from {}.'.format(PRETRAINED_PATH))
        self.tm_model = torchmoji_feature_encoding(PRETRAINED_PATH)
        self.tm_embedding = hparams.torchMoji_attDim

    def get_item(self, texts, use_arpa):
        """Encode one text or a sequence of texts as a zero-padded LongTensor[B, enc_T].

        Each text is wrapped in the start/stop tokens, optionally passed through
        ``self.arpa.get``, then converted to symbol ids with ``get_text``.

        Raises:
            ValueError: if ``texts`` is an empty sequence.
        """
        if isinstance(texts, str):
            texts = [texts]
        if len(texts) == 0:
            raise ValueError("get_item() needs at least one text.")

        encoded = []
        for text in texts:
            text = f'{self.start_token}{text}{self.stop_token}'
            if use_arpa:
                text = self.arpa.get(text)
            encoded.append(self.get_text(text))  # convert text into tensor representation

        # Right-pad shorter sequences with 0 -> LongTensor[B, enc_T]
        return torch.nn.utils.rnn.pad_sequence(encoded, batch_first=True, padding_value=0).long()

    def get_text(self, text):
        """Return ``text`` as a LongTensor of symbol ids (after the configured cleaners)."""
        text_norm = torch.LongTensor(text_to_sequence(text, self.text_cleaners))
        return text_norm

    def get_speaker_id(self, ext_speaker_ids):
        """Map external speaker id(s) to internal speaker indices.

        A single id (int, numpy integer or numeric str) gives LongTensor[1]; an
        id missing from the lookup falls back to internal speaker 0. A sequence
        of ids gives LongTensor[B]; there an unknown id raises KeyError.
        """
        if isinstance(ext_speaker_ids, (str, int, np.integer)):
            ext_id = int(ext_speaker_ids)
            internal_id = self.speaker_ids[ext_id] if ext_id in self.speaker_ids else 0
            return torch.LongTensor([internal_id])  # [1]
        return torch.LongTensor([self.speaker_ids[int(x)] for x in ext_speaker_ids])  # [B]

    def get_sylps(self, sylps):
        """Normalise ``sylps`` to FloatTensor[B]; None and tensors pass through unchanged.

        Accepts a single number (Python or numpy scalar) or an iterable of
        numbers / numeric strings.
        """
        if sylps is None or isinstance(sylps, torch.Tensor):
            return sylps
        if isinstance(sylps, (int, float, np.integer, np.floating)):
            return torch.FloatTensor([float(sylps)])
        return torch.FloatTensor([float(x) for x in sylps])

    @staticmethod
    def _strip_pipe_prefix(text):
        """Return the part after '|' when ``text`` contains exactly one '|', else ``text``."""
        parts = text.split("|")
        return parts[1] if len(parts) == 2 else text

    def get_torchmoji_hidden(self, texts, BATCH_SIZE=64):
        """Return torchMoji sentence embeddings as FloatTensor[B, tm_embedding].

        Texts containing exactly one '|' are reduced to the part after it before
        tokenizing. The model is run on chunks of ``BATCH_SIZE`` texts.
        """
        if isinstance(texts, str):
            texts = [texts]
        n_texts = len(texts)

        torchmoji_hdn = torch.zeros(n_texts, int(self.tm_embedding))
        for start in range(0, n_texts, BATCH_SIZE):
            chunk = [self._strip_pipe_prefix(t) for t in texts[start:start + BATCH_SIZE]]
            tokenized, _, _ = self.tm_tokenizer.tokenize_sentences(chunk)
            embedding = self.tm_model(tokenized)  # returns np array [B, Embed]
            torchmoji_hdn[start:start + BATCH_SIZE] = torch.from_numpy(embedding)
        return torchmoji_hdn

    def __call__(self,
            texts,
            ext_speaker_ids,
            sylps=None, # List of str or List of float or FloatTensor[B]
            use_arpa=True,
            ):
        """Build model inputs.

        Returns:
            (texts LongTensor[B, enc_T],
             speaker_ids LongTensor[B] (LongTensor[1] for a single id),
             torchmoji_hdn FloatTensor[B, torchMoji_dim],
             sylps FloatTensor[B] or None)
        """
        # torchMoji sees the raw text, before start/stop tokens or ARPA conversion.
        torchmoji_hdn = self.get_torchmoji_hidden(texts)
        text_ids = self.get_item(texts, use_arpa)
        speaker_ids = self.get_speaker_id(ext_speaker_ids)  # LongTensor
        sylps = self.get_sylps(sylps)
        return text_ids, speaker_ids, torchmoji_hdn, sylps

# 3 - Inference Config

In [4]:
# Default locations of the pronunciation dictionary and model checkpoints.
# The defaults are machine-specific; each can be overridden without editing this
# cell through the environment variables COOKIETTS_DICT_PATH, COOKIETTS_UNTTS_PATH,
# COOKIETTS_VOCODER_PATH and COOKIETTS_VOCODER_CONFIG.
_DEFAULT_PATHS = {
    "dict_path":      "../../dict/merged.dict.txt",
    "untts_path":    r"G:\TwiBot\CookiePPPTTS\CookieTTS\_2_ttm\untts\outdir\best_model",
    "vocoder_path":  r"H:\TTCheckpoints\waveflow\4thLargeKernels\WG_24_Flow_AEF6\best_val_model",
    "vocoder_config":r"H:\TTCheckpoints\waveflow\4thLargeKernels\WG_24_Flow_AEF6\config.json",
}
config = {
    key: os.environ.get(f"COOKIETTS_{key.upper()}", default)
    for key, default in _DEFAULT_PATHS.items()
}

In [5]:
# Load the vocoder; the second return value (the denoiser slot) is discarded.
waveglow, _, wg_sigma, wg_speaker_ids, wg_config = load_waveglow_checkpoint(
    vocoder_path=config["vocoder_path"],
    config_fpath=config["vocoder_config"],
)

{'n_mel_channels': 256, 'speaker_embed': 32, 'shift_spect': 0.0, 'scale_spect': 1.0, 'preceived_vol_scaling': False, 'waveflow': False, 'channel_mixing': '1x1conv', 'mix_first': True, 'n_flows': 24, 'n_group': 24, 'n_early_every': 24, 'n_early_size': 2, 'memory_efficient': 1.0, 'spect_scaling': False, 'upsample_mode': 'normal', 'WN_config': {'gated_unit': 'GTU', 'n_layers': 8, 'n_channels': 384, 'kernel_size_w': 3, 'n_layers_dilations_w': None, 'n_layers_dilations_h': 1, 'speaker_embed_dim': 0, 'rezero': False, 'transposed_conv_hidden_dim': 256, 'transposed_conv_kernel_size': [2, 3, 5], 'transposed_conv_scales': None, 'cond_layers': 1, 'cond_activation_func': 'lrelu', 'cond_out_activation_func': False, 'negative_slope': 0.5, 'cond_hidden_channels': 256, 'cond_kernel_size': 1, 'cond_padding_mode': 'zeros', 'seperable_conv': False, 'res_skip': True, 'merge_res_skip': False, 'upsample_mode': 'linear'}, 'cond_layers': 4, 'cond_activation_func': 'lrelu', 'negative_slope': 0.1, 'cond_hidden_

In [62]:
# Load UnTTS together with its hparams and external->internal speaker-id lookup.
untts, un_hparams, speaker_ids = load_untts_checkpoint(path=config['untts_path'])

Loading UnTTS... 
Loaded UnTTS!
This UnTTS model has been trained for 47500 Iterations.


In [7]:
# Text / speaker / torchMoji front-end for UnTTS.
Process = preprocess(
    hparams=un_hparams,
    dict_path=config["dict_path"],
    speaker_ids=speaker_ids,
)

Tokenizing using dictionary from g:\twibot\cookieppptts\CookieTTS\utils/torchmoji/model/vocabulary.json
Loading TorchMoji model from g:\twibot\cookieppptts\CookieTTS\utils/torchmoji/model/pytorch_model.bin.
Loading weights for embed.weight
Loading weights for lstm_0.weight_ih_l0
Loading weights for lstm_0.weight_hh_l0
Loading weights for lstm_0.bias_ih_l0
Loading weights for lstm_0.bias_hh_l0
Loading weights for lstm_0.weight_ih_l0_reverse
Loading weights for lstm_0.weight_hh_l0_reverse
Loading weights for lstm_0.bias_ih_l0_reverse
Loading weights for lstm_0.bias_hh_l0_reverse
Loading weights for lstm_1.weight_ih_l0
Loading weights for lstm_1.weight_hh_l0
Loading weights for lstm_1.bias_ih_l0
Loading weights for lstm_1.bias_hh_l0
Loading weights for lstm_1.weight_ih_l0_reverse
Loading weights for lstm_1.weight_hh_l0_reverse
Loading weights for lstm_1.bias_ih_l0_reverse
Loading weights for lstm_1.bias_hh_l0_reverse
Loading weights for attention_layer.attention_vector
Ignoring weights fo

# 4 - Inference

In [66]:
text = """
Aeron, a human man is transported to Equestria through means unknown. After appearing in the Everfree forest, he is attacked and killed by the malicious Timber Wolves. But his story does not end there. For he is reborn as something new to Equestria. But will this new being find acceptance among the Ponies? And how will his presence change the course of events as we know them? Only time will tell!
""".replace("\n","")
speaker_id = 167
sylps = None

texts, speaker_ids, torchmoji_hdn, sylps = Process(text, speaker_id, sylps, use_arpa=False)

ttm_dict = untts.inference(texts, speaker_ids, torchmoji_hdn, sylps,
                          mel_sigma=0.99, dur_sigma=0.95, var_sigma=0.8)

print(ttm_dict['spect'].shape)

audio_batch = waveglow.infer(ttm_dict['spect'], speaker_ids=speaker_ids.cuda(), sigma=wg_sigma*0.95)
audio_batch = audio_batch.cpu().float()

import IPython.display as ipd
ipd.display(ipd.Audio(audio_batch[0].cpu().numpy(), rate=48000))

torch.set_printoptions(sci_mode=False)
char_durs   = ttm_dict['char_durs'].view(-1)
char_voiced = ttm_dict['char_voiced'].view(-1)
char_f0     = ttm_dict['char_f0'].view(-1)
char_energy = ttm_dict['char_energy'].view(-1)
for i, (char, dur, voiced, f0, energy) in enumerate(zip(text, char_durs, char_voiced, char_f0, char_energy)):
    print(f"{i:>3} | '{char}' | {(1000.*dur.item()*un_hparams.hop_length)/un_hparams.sampling_rate:4.0f}ms | {voiced.item():.1f}%voiced | {f0.item() if voiced.item() > 0.1 else 0.0:5.1f}hz | {energy*4:5.1f}RMS Energy dB")
torch.set_printoptions(sci_mode=True)

torch.Size([1, 256, 1623])


  0 | 'A' |   43ms | 0.6%voiced | 352.0hz |  57.5RMS Energy dB
  1 | 'e' |  103ms | 0.7%voiced | 361.8hz |  79.0RMS Energy dB
  2 | 'r' |   82ms | 1.0%voiced | 379.0hz |  65.8RMS Energy dB
  3 | 'o' |   77ms | 1.0%voiced | 434.5hz |  53.4RMS Energy dB
  4 | 'n' |  100ms | 0.9%voiced | 397.2hz |  64.8RMS Energy dB
  5 | ',' |   92ms | 1.0%voiced | 431.2hz |  53.3RMS Energy dB
  6 | ' ' |  110ms | 0.6%voiced | 334.5hz |  58.7RMS Energy dB
  7 | 'a' |   25ms | 0.6%voiced | 319.2hz |  63.1RMS Energy dB
  8 | ' ' |   45ms | 0.6%voiced | 295.2hz |  53.0RMS Energy dB
  9 | 'h' |   75ms | 0.7%voiced | 280.5hz |  65.3RMS Energy dB
 10 | 'u' |   72ms | 0.2%voiced | 299.8hz |  51.0RMS Energy dB
 11 | 'm' |   55ms | 0.8%voiced | 321.8hz |  50.0RMS Energy dB
 12 | 'a' |  133ms | 1.0%voiced | 331.8hz |  56.9RMS Energy dB
 13 | 'n' |   55ms | 1.0%voiced | 303.2hz |  62.0RMS Energy dB
 14 | ' ' |   62ms | 1.0%voiced | 329.2hz |  50.7RMS Energy dB
 15 | 'm' |   24ms | 1.0%voiced | 313.0hz |  61.0RMS En

---