In [None]:
import os
import sys
import numpy as np
import IPython.display

import librosa
import seaborn as sns
import matplotlib.pyplot as plt
from collections import OrderedDict

# load other modules --> repo root path
sys.path.insert(0, "../")

import torch
from utils import text, audio
from utils.logging import Logger
from params.params import Params as hp
from modules.tacotron2 import Tacotron
from dataset.dataset import TextToSpeechDataset, TextToSpeechDatasetCollection

In [None]:
def remove_dataparallel_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with the DataParallel key prefix removed.

    Keys that start with `prefix` have it stripped; all other keys are
    copied unchanged, so a checkpoint saved without the wrapper is not
    corrupted by blindly cutting the first characters of every key.

    Arguments:
        state_dict -- mapping of parameter names to values
        prefix -- key prefix to strip (defaults to "module.")

    Returns:
        OrderedDict with the same values and the original key order.
    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Only strip when the prefix is really there.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict

In [None]:
def build_model(checkpoint):
    """Restore a Tacotron model from a checkpoint file.

    The checkpoint is read with `torch.load` and must contain the keys
    'parameters' (loaded into the global hyper-parameters `hp`) and
    'model' (the saved state dict). Only saved entries whose names also
    exist in the freshly constructed model are copied over.

    Arguments:
        checkpoint -- path of the checkpoint file

    Returns:
        The model, moved to CUDA when available and to CPU otherwise.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    snapshot = torch.load(checkpoint, map_location=target)

    # Hyper-parameters are restored before the model is constructed;
    # presumably Tacotron() reads `hp` while building its layers -- confirm.
    hp.load_state_dict(snapshot['parameters'])

    model = Tacotron()
    weights = model.state_dict()

    # Overwrite the fresh weights with every saved tensor the model knows about;
    # saved entries with unknown names are ignored.
    restored = remove_dataparallel_prefix(snapshot['model'])
    for key, value in restored.items():
        if key in weights:
            weights[key] = value

    model.load_state_dict(weights)
    model.to(target)

    return model

In [None]:
def inference(model, inputs):
    """Generate one spectrogram per input line.

    Each non-empty input is a string of '|'-separated fields: field 0 is
    the text, field 1 the speaker (looked up in `hp.unique_speakers` when
    `hp.multi_speaker` is set) and field 2 the language (looked up in
    `hp.languages` when `hp.multi_language` is set).

    Arguments:
        model -- object exposing `inference(text, speaker=..., language=...)`
        inputs -- iterable of input strings; empty strings are skipped

    Returns:
        List of numpy arrays, one per processed input.
    """
    entries = [line.rstrip().split('|') for line in inputs if line]
    on_gpu = torch.cuda.is_available()

    spectrograms = []
    # No gradients are needed for synthesis, so skip building the autograd graph.
    with torch.no_grad():
        for fields in entries:
            t = torch.LongTensor(text.to_sequence(fields[0], use_phonemes=hp.use_phonemes))
            l = torch.LongTensor([hp.languages.index(fields[2])]) if hp.multi_language else None
            s = torch.LongTensor([hp.unique_speakers.index(fields[1])]) if hp.multi_speaker else None

            if on_gpu:
                t = t.cuda(non_blocking=True)
                # Compare against None explicitly: a one-element tensor holding
                # index 0 is falsy, so a plain truth test would leave the first
                # language/speaker on the CPU.
                if l is not None:
                    l = l.cuda(non_blocking=True)
                if s is not None:
                    s = s.cuda(non_blocking=True)

            spectrograms.append(model.inference(t, speaker=s, language=l).cpu().detach().numpy())

    return spectrograms

# Synthesis

In [None]:
# Checkpoint used for synthesis (path relative to this notebook).
# The commented-out line is an alternative checkpoint; swap the two to use it.
#checkpoint = "../checkpoints/FRGE2B_loss-129-0.085"
checkpoint = "../checkpoints/pretrain/CSS-ACU_loss-99-0.145"

In [None]:
# Display the hyper-parameters stored in the checkpoint (loaded onto the CPU,
# so this inspection does not need a GPU).
torch.load(checkpoint, map_location="cpu")['parameters']

In [None]:
# Restore the model and switch it to evaluation mode.
model = build_model(checkpoint)
model.eval()

# build_model() loaded the checkpoint's hyper-parameters into `hp`;
# show the ones relevant for choosing inputs, one per line.
print(hp.encoder_type, hp.encoder_dimension, hp.languages, sep="\n")

In [None]:
# Sentences to synthesize, one per entry, in the format "text|speaker|language"
# (the speaker field is left empty here).
# NOTE(review): the language labels do not match the language of the text --
# presumably intentional, to synthesize each sentence with another language's
# voice; confirm.
inputs = ["erlauben sie bitte, dass ich mich kurz vorstelle. ich heiße jana novakova.||hungarian",
          "les socialistes et les républicains sont venus apporter leurs voix à la majorité pour ce texte.||spanish",
          "and when the first municipal authority of our land will be no longer subjected to the reproach||german"]

In [None]:
# Run the model on every input line; the result is a list of numpy spectrograms.
generated_spectrograms = inference(model, inputs)

In [None]:
# Turn each generated spectrogram into audio and embed a player for it.
for spectrogram in generated_spectrograms:
    denormalized = audio.denormalize_spectrogram(spectrogram, not hp.predict_linear)
    waveform = audio.inverse_spectrogram(denormalized, not hp.predict_linear)
    IPython.display.display(IPython.display.Audio(data=waveform, rate=hp.sample_rate))

## Changing language embedding

In [None]:
from modules.layers import ConstantEmbedding

In [None]:
# Take row 1 of the decoder's language-embedding weight and replace the
# embedding layer with a ConstantEmbedding built from that row. The trailing
# `.mean(dim=0)` comment is an alternative left by the author.
# NOTE(review): presumably every input then uses this single language
# embedding regardless of its language field, and row 1 presumably corresponds
# to hp.languages[1] -- confirm.
# NOTE(review): this mutates `model` in place; re-running the cell may fail if
# ConstantEmbedding exposes no `.weight` -- rebuild the model first if unsure.
embedding = model._decoder._language_embedding.weight[1,:] # .mean(dim=0)
model._decoder._language_embedding = ConstantEmbedding(embedding)

In [None]:
# Synthesize the same inputs again, now with the replaced language embedding;
# this overwrites the earlier `generated_spectrograms`.
generated_spectrograms = inference(model, inputs)

In [None]:
# Turn each regenerated spectrogram into audio and embed a player for it.
for spectrogram in generated_spectrograms:
    denormalized = audio.denormalize_spectrogram(spectrogram, not hp.predict_linear)
    waveform = audio.inverse_spectrogram(denormalized, not hp.predict_linear)
    IPython.display.display(IPython.display.Audio(data=waveform, rate=hp.sample_rate))