In [1]:
import math
import os
import re
import typing as tp
import warnings
from pathlib import Path

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from pydub import AudioSegment

from common.text import cmudict
from fastpitch.model import FastPitch
from hifigan.models import Generator, Denoiser
from common.text.text_processing import get_text_processing


# Run on the first GPU when one is available; otherwise fall back to CPU so the
# notebook still executes (slowly) on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def load_fast_pitch(ckpt_path: tp.Union[str, Path]) -> FastPitch:
    """Build a FastPitch model from a checkpoint file, ready for inference.

    The checkpoint is expected to be a dict with the model constructor kwargs
    under ``'config'`` and the weights under ``'state_dict'``. A leading
    ``module.`` prefix is stripped from every weight name (presumably added by
    a DataParallel-style wrapper during training).

    Args:
        ckpt_path: path to the checkpoint file.

    Returns:
        The model in eval mode, with ``forward`` rebound to ``infer`` so that
        calling ``model(...)`` runs inference.

    Raises:
        ValueError: if the checkpoint has no ``'config'`` entry.
    """
    # Map tensors to CPU: they are only copied into the freshly built model,
    # which the caller moves to the target device afterwards. This avoids a
    # transient GPU copy and works on hosts lacking the checkpoint's device.
    ckpt_data = torch.load(ckpt_path, map_location='cpu')
    model_config = ckpt_data.get('config')
    if model_config is None:
        raise ValueError(f"{ckpt_path}: checkpoint has no 'config' entry")
    model = FastPitch(**model_config)
    # Make model(...) dispatch to the inference path.
    model.forward = model.infer
    state_dict = {re.sub(r'^module\.', '', key): value
                  for key, value in ckpt_data['state_dict'].items()}
    # strict=False tolerates key mismatches; surface them instead of silently
    # running with partially initialised weights.
    status = model.load_state_dict(state_dict, strict=False)
    if status.missing_keys or status.unexpected_keys:
        warnings.warn(f'{ckpt_path}: missing keys {status.missing_keys}, '
                      f'unexpected keys {status.unexpected_keys}')
    model.eval()
    return model


def load_hifigan(ckpt_path: tp.Union[str, Path]) -> Generator:
    """Build a HiFi-GAN generator from a checkpoint file, ready for inference.

    The checkpoint is expected to be a dict with the generator config under
    ``'config'`` and the generator weights under ``'generator'``. A leading
    ``module.`` prefix is stripped from every weight name (presumably added by
    a DataParallel-style wrapper during training).

    Args:
        ckpt_path: path to the checkpoint file.

    Returns:
        The generator in eval mode, after ``remove_weight_norm()``.

    Raises:
        ValueError: if the checkpoint has no ``'config'`` entry.
    """
    # Map tensors to CPU: they are only copied into the freshly built model,
    # which the caller moves to the target device afterwards.
    ckpt_data = torch.load(ckpt_path, map_location='cpu')
    model_config = ckpt_data.get('config')
    if model_config is None:
        raise ValueError(f"{ckpt_path}: checkpoint has no 'config' entry")
    model = Generator(model_config)
    state_dict = {re.sub(r'^module\.', '', key): value
                  for key, value in ckpt_data['generator'].items()}
    # strict=False tolerates key mismatches; surface them instead of silently
    # running with partially initialised weights.
    status = model.load_state_dict(state_dict, strict=False)
    if status.missing_keys or status.unexpected_keys:
        warnings.warn(f'{ckpt_path}: missing keys {status.missing_keys}, '
                      f'unexpected keys {status.unexpected_keys}')
    # Done after loading: the checkpoint presumably stores weight-norm
    # parameters, which this folds into plain weights — confirm in Generator.
    model.remove_weight_norm()
    model.eval()
    return model


def load_cmudict(cmudict_path: str = 'cmudict/cmudict-0.7b',
                 heteronyms_path: str = 'cmudict/heteronyms') -> None:
    """Initialise the project's ``cmudict`` module from dictionary files.

    Presumably this must run before text is encoded with ARPAbet lookups
    (``p_arpabet > 0`` in ``text_processor``) — confirm in ``common.text``.

    Args:
        cmudict_path: CMU pronouncing dictionary file. The default is relative
            to the current working directory.
        heteronyms_path: heteronyms list passed alongside the dictionary. The
            default is relative to the current working directory.
    """
    cmudict.initialize(cmudict_path, heteronyms_path)


# English text front end: the 'english_basic' symbol set with the
# 'english_cleaners_v2' normaliser. p_arpabet=1.0 presumably means every word
# found in cmudict is emitted as ARPAbet — confirm in get_text_processing.
text_processor = get_text_processing(
    symbol_set='english_basic',
    text_cleaners=['english_cleaners_v2'],
    p_arpabet=1.0,
)


def prepare_batch(texts: tp.Sequence[tp.Union[str, tp.Sequence[str]]],
                  voice: str) -> tp.Tuple[tp.Dict[str, tp.Union[torch.Tensor, tp.List[int]]], tp.List[int]]:
    """Encode every text segment and pad them into a single batch on ``device``.

    Each entry of ``texts`` is a list of segments. A bare string is treated as
    one segment; iterating it directly would produce one segment per character.

    Args:
        texts: per-utterance text segments.
        voice: currently unused; kept for interface compatibility.

    Returns:
        ``(batch, refs_to_origin)``. ``batch['text']`` is a 2-D LongTensor on
        ``device`` with one row per segment, right-padded with 0. It is
        ``(0, 0)`` when there are no segments. ``batch['text_lens']`` is a list
        of each segment's unpadded length. ``refs_to_origin[i]`` is the index in
        ``texts`` that segment ``i`` came from.
    """
    encoded: tp.List[torch.Tensor] = []
    refs_to_origin: tp.List[int] = []
    for origin_idx, segments in enumerate(texts):
        if isinstance(segments, str):
            segments = [segments]
        for segment in segments:
            encoded.append(torch.as_tensor(text_processor.encode_text(segment),
                                           dtype=torch.long))
            refs_to_origin.append(origin_idx)

    if encoded:
        padded = pad_sequence(encoded, batch_first=True)
    else:
        # pad_sequence rejects an empty list; return a well-formed empty batch.
        padded = torch.zeros((0, 0), dtype=torch.long)

    batch = {
        'text': padded.to(device),
        'text_lens': [ids.size(0) for ids in encoded],
    }
    return batch, refs_to_origin
    

# Checkpoint locations. WEIGHTS_DIR can be overridden through the
# TTS_WEIGHTS_DIR environment variable instead of editing the machine-specific
# default below.
WEIGHTS_DIR = Path(os.environ.get('TTS_WEIGHTS_DIR', '/home/server2/weights'))
FASTPITCH_CKPT = Path('output') / 'FastPitch_checkpoint_760.pt'
HIFIGAN_CKPT = WEIGHTS_DIR / 'hifigan' / 'hifigan_gen_checkpoint_10000_ft.pt'

# Inference settings used by do_tts. samplerate/hop_length presumably must match
# the checkpoints' training config — this is not checked here.
denoising_strength = 0.01  # a value <= 0 disables the denoiser (denoiser = None)
samplerate = 22050         # output sample rate in Hz
hop_length = 256           # audio samples per mel frame
batch_size = 8             # segments per FastPitch call in do_tts' first pass

fast_pitch_model = load_fast_pitch(FASTPITCH_CKPT)
hifigan_model = load_hifigan(HIFIGAN_CKPT)
load_cmudict()
fast_pitch_model = fast_pitch_model.to(device)
hifigan_model = hifigan_model.to(device)
denoiser = Denoiser(hifigan_model, device=device) if denoising_strength > 0 else None


def do_tts(texts: tp.Sequence[tp.Union[str, tp.Sequence[str]]],
           voice: str,
           durations: tp.List[tp.Optional[float]]) -> tp.List[AudioSegment]:
    """Synthesize each entry of ``texts`` into a pydub ``AudioSegment``.

    Each entry is a list of text segments; a bare string counts as a single
    segment.

    Processing steps:

    1. Synthesize all segments at pace 1.0, in chunks of ``batch_size``.
    2. For every entry whose ``durations`` value is not ``None``, derive a
       pace from natural length / target length. Re-synthesize each of that
       entry's segments individually at that pace.
    3. Vocode every mel with ``hifigan_model``, optionally denoise it, and
       trim it to its valid length.
    4. Fade out the last ``fade_out * hop_length`` samples of each segment and
       append the segment to its entry.

    Args:
        texts: utterances to synthesize; one result per entry.
        voice: currently unused here; forwarded to ``prepare_batch``.
        durations: optional target length in seconds per entry. ``None`` keeps
            the model's natural pace.

    Returns:
        One 16-bit ``AudioSegment`` at ``samplerate`` per entry of ``texts``.

    Raises:
        ValueError: if a target duration is not positive.
    """
    max_wav_value = 2 ** 15
    fade_out = 10  # fade length in mel frames; 0 disables the fade

    # A bare string is one segment. Without wrapping, prepare_batch would
    # iterate it and create one segment per character.
    texts = [[text] if isinstance(text, str) else list(text) for text in texts]

    batch, refs_to_origin = prepare_batch(texts, voice)
    with torch.no_grad():
        # Pass 1: every segment at natural pace, batch_size segments per call.
        mel, mel_lens = [], []
        n_segments = batch['text'].shape[0]
        for start in range(0, n_segments, batch_size):
            part_text = batch['text'][start:start + batch_size]
            part_mel, part_mel_lens, *_ = fast_pitch_model(part_text, pace=1.0)
            mel.extend(part_mel)
            mel_lens.extend(part_mel_lens)

        # Natural length (seconds) of each entry: sum over its segments of
        # mel frames * hop_length samples.
        out_durations = [0.0] * len(texts)
        for origin, mel_len in zip(refs_to_origin, mel_lens):
            out_durations[origin] += (mel_len.item() * hop_length) / samplerate

        # pace = natural / target length, quantized upward to a multiple of
        # hop_length / samplerate (same rounding as the original code).
        paces: tp.List[tp.Optional[float]] = [None] * len(texts)
        for idx, (target, natural) in enumerate(zip(durations, out_durations)):
            if target is None:
                continue
            if target <= 0:
                raise ValueError(f'durations[{idx}] must be positive, got {target!r}')
            pace = natural / target
            paces[idx] = math.ceil(pace * samplerate / hop_length) * hop_length / samplerate

        # Pass 2: re-synthesize, one at a time, the segments of entries that have
        # a target duration, replacing their natural-pace mels.
        for seg_idx, (text_item, origin) in enumerate(zip(batch['text'], refs_to_origin)):
            pace = paces[origin]
            if pace is None:
                continue
            part_mel, part_mel_lens, *_ = fast_pitch_model(text_item.unsqueeze(0), pace=pace)
            mel[seg_idx] = part_mel.squeeze(0)
            mel_lens[seg_idx] = part_mel_lens.squeeze(0)

        # Vocode, optionally denoise, scale to the int16 range, and keep only
        # mel_len * hop_length samples (dropping audio from padded mel frames).
        audios = []
        for mel_item, mel_len in zip(mel, mel_lens):
            audio = hifigan_model(mel_item.unsqueeze(0)).float().squeeze(1)
            if denoiser is not None:
                audio = denoiser(audio.float(), denoising_strength)
            audio = audio.squeeze(1).squeeze(0) * max_wav_value
            audios.append(audio[:mel_len.item() * hop_length])

    # Use the output rate so entries that receive no audio still report it.
    segments = [AudioSegment.silent(0, frame_rate=samplerate) for _ in texts]
    fade_len = fade_out * hop_length
    for audio, origin in zip(audios, refs_to_origin):
        if audio.shape[0] == 0:
            continue  # nothing to append

        if fade_out and audio.shape[-1] >= fade_len:
            audio[-fade_len:] *= torch.linspace(1.0, 0.0, fade_len, device=audio.device)

        # Clamp before casting: float -> int16 is not saturating, so a
        # full-scale sample (1.0 * 2**15) would overflow instead of clipping.
        pcm = audio.clamp(-max_wav_value, max_wav_value - 1).cpu().numpy().astype(np.int16)
        segments[origin] += AudioSegment(pcm.tobytes(),
                                         frame_rate=samplerate,
                                         sample_width=pcm.dtype.itemsize,
                                         channels=pcm.shape[0] if pcm.ndim > 1 else 1)

    return segments

LU, pivots = torch.lu(A, compute_pivots)
should be replaced with
LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
and
LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)
should be replaced with
LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots) (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:1991.)
  return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))


HiFi-GAN: Removing weight norm.


In [2]:
from pathlib import Path
import IPython.display as ipd

temp_folder = Path('temp')
temp_folder.mkdir(exist_ok=True)
test_mp3 = temp_folder / 'test.mp3'

# do_tts expects a list of segments per utterance. Wrapping the sentence in its
# own list makes it a single segment rather than one segment per character.
test_text = ("Young Belgian couple cried a lot when their one-year-old baby Emin slept with them "
             "they realised that he only got quiet when he slept in another room with his dogs. ")
audio = do_tts([[test_text]], voice='lj', durations=[None])[0]

# export() returns the open output file; close it so the handle is not leaked.
audio.export(test_mp3, format='mp3').close()
ipd.display(ipd.Audio(filename=test_mp3))

  audio = audio[:mel_len_item.item() * hop_length]
