In [17]:
#hparams.py

"""Hyper parameters."""
__author__ = 'Erdene-Ochir Tuguldur'


class HParams:
    """Hyper parameters for audio pre-processing and the Text2Mel / SSRN networks.

    Every value is a plain class attribute. Derived values (``hop_length``,
    ``win_length``) are computed once, at class-definition time, from ``sr``.
    """

    disable_progress_bar = False  # set True if you don't want the progress bar in the console

    logdir = "logdir"  # log dir where the checkpoints and tensorboard files are saved

    # audio.py options, these values are from https://github.com/Kyubyong/dc_tts/blob/master/hyperparams.py
    reduction_rate = 4  # melspectrogram reduction rate, don't change because SSRN is using this rate
    n_fft = 2048 # fft points (samples)
    n_mels = 80  # Number of Mel banks to generate
    power = 1.5  # Exponent for amplifying the predicted magnitude
    n_iter = 50  # Number of inversion iterations
    preemphasis = .97  # pre-emphasis filter coefficient
    max_db = 100  # dynamic range (dB) used when normalizing spectrograms to [0, 1]
    ref_db = 20  # reference level (dB) subtracted before normalization
    sr = 22050  # Sampling rate
    frame_shift = 0.0125  # seconds
    frame_length = 0.05  # seconds
    hop_length = int(sr * frame_shift)  # samples. =275 (int() truncates 275.625).
    win_length = int(sr * frame_length)  # samples. =1102 (int() truncates 1102.5).
    max_N = 180  # Maximum number of characters.
    max_T = 210  # Maximum number of mel frames.

    e = 128  # embedding dimension
    d = 256  # Text2Mel hidden unit dimension
    c = 512+128  # SSRN hidden unit dimension

    dropout_rate = 0.05  # dropout

    # Text2Mel network options
    text2mel_lr = 0.005  # learning rate
    text2mel_max_iteration = 3000  # max train step
    text2mel_weight_init = 'none'  # 'kaiming', 'xavier' or 'none'
    text2mel_normalization = 'layer'  # 'layer', 'weight' or 'none'
    text2mel_basic_block = 'gated_conv'  # 'highway', 'gated_conv' or 'residual'

    # SSRN network options
    ssrn_lr = 0.0005  # learning rate
    ssrn_max_iteration = 1500  # max train step
    ssrn_weight_init = 'kaiming'  # 'kaiming', 'xavier' or 'none'
    ssrn_normalization = 'weight'  # 'layer', 'weight' or 'none'
    ssrn_basic_block = 'residual'  # 'highway', 'gated_conv' or 'residual'


In [18]:
#audio.py

"""These methods are copied from https://github.com/Kyubyong/dc_tts/"""

import os
import copy
import librosa
import scipy.io.wavfile
import numpy as np
from tqdm import tqdm
from scipy import signal
# Shared hyper-parameter object read by the audio helpers and model definitions in the cells below.
hp = HParams()

def spectrogram2wav(mag):
    '''Turn a normalized linear magnitude spectrogram back into audio samples.
    Args:
      mag: A numpy array of (T, 1+n_fft//2), values expected in [0, 1]
           (anything outside is clipped).
    Returns:
      wav: A 1-D numpy array of dtype float32.
    '''
    # Work on the (1+n_fft//2, T) layout and undo the [0, 1] normalization -> decibels.
    decibels = np.clip(mag.T, 0, 1) * hp.max_db - hp.max_db + hp.ref_db

    # Decibels -> linear amplitude.
    amplitude = np.power(10.0, decibels * 0.05)

    # Phase reconstruction on the sharpened (power-raised) magnitudes.
    samples = griffin_lim(amplitude ** hp.power)

    # Inverse of the pre-emphasis filter y[n] = x[n] - preemphasis * x[n-1].
    samples = signal.lfilter([1], [1, -hp.preemphasis], samples)

    # Drop leading / trailing silence.
    trimmed, _ = librosa.effects.trim(samples)
    return trimmed.astype(np.float32)


def griffin_lim(spectrogram):
    '''Estimate a waveform from a magnitude spectrogram with the Griffin-Lim algorithm.

    Starting from the magnitudes alone, the signal is repeatedly inverted and
    re-analysed; each round keeps the phase of the re-analysis and re-imposes
    the given magnitudes. ``hp.n_iter`` rounds are run.

    Args:
      spectrogram: magnitude spectrogram of shape [1+n_fft//2, t].
    Returns:
      The real part of the final inverted signal (1-D array).
    '''
    X_best = copy.deepcopy(spectrogram)
    for _ in range(hp.n_iter):
        X_t = invert_spectrogram(X_best)
        # Keyword arguments only: librosa >= 0.10 rejects positional n_fft / hop_length.
        est = librosa.stft(y=X_t,
                           n_fft=hp.n_fft,
                           hop_length=hp.hop_length,
                           win_length=hp.win_length)
        # Unit-modulus phase; the floor avoids division by zero in silent bins.
        phase = est / np.maximum(1e-8, np.abs(est))
        X_best = spectrogram * phase
    X_t = invert_spectrogram(X_best)
    y = np.real(X_t)

    return y


def invert_spectrogram(spectrogram):
    '''Applies inverse fft.
    Args:
      spectrogram: [1+n_fft//2, t]
    Returns:
      The time-domain signal produced by ``librosa.istft``.
    '''
    # hop_length must be passed by keyword: librosa >= 0.10 made every istft
    # argument after the spectrogram keyword-only.
    return librosa.istft(spectrogram,
                         hop_length=hp.hop_length,
                         win_length=hp.win_length,
                         window="hann")


def get_spectrograms(fpath):
    '''Parse the wave file in `fpath` and
    Returns normalized melspectrogram and linear spectrogram.
    Args:
      fpath: A string. The full path of a sound file.
    Returns:
      mel: A 2d array of shape (T, n_mels) and dtype of float32.
      mag: A 2d array of shape (T, 1+n_fft//2) and dtype of float32.
    '''
    # Loading sound file (resampled to hp.sr)
    y, sr = librosa.load(fpath, sr=hp.sr)

    # Trimming leading / trailing silence
    y, _ = librosa.effects.trim(y)

    # Preemphasis: y[n] - preemphasis * y[n-1], first sample kept as is
    y = np.append(y[0], y[1:] - hp.preemphasis * y[:-1])

    # stft
    linear = librosa.stft(y=y,
                          n_fft=hp.n_fft,
                          hop_length=hp.hop_length,
                          win_length=hp.win_length)

    # magnitude spectrogram
    mag = np.abs(linear)  # (1+n_fft//2, T)

    # mel spectrogram
    # The filter bank must follow the hyper parameters: the models are built with
    # f=hp.n_mels input channels, so a hard-coded bank size (previously 128)
    # produces features the networks cannot consume.
    mel_basis = librosa.filters.mel(sr=hp.sr, n_fft=hp.n_fft, n_mels=hp.n_mels)  # (n_mels, 1+n_fft//2)
    mel = np.dot(mel_basis, mag)  # (n_mels, t)

    # to decibel (floor avoids log of zero)
    mel = 20 * np.log10(np.maximum(1e-5, mel))
    mag = 20 * np.log10(np.maximum(1e-5, mag))

    # normalize into (0, 1]
    mel = np.clip((mel - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1)
    mag = np.clip((mag - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1)

    # Transpose
    mel = mel.T.astype(np.float32)  # (T, n_mels)
    mag = mag.T.astype(np.float32)  # (T, 1+n_fft//2)

    return mel, mag


def save_to_wav(mag, filename):
    """Reconstruct audio from the linear spectrogram `mag` (Griffin-Lim) and write it to `filename` at hp.sr."""
    scipy.io.wavfile.write(filename, hp.sr, spectrogram2wav(mag))


def preprocess(dataset_path, speech_dataset):
    """Compute and store mel / linear spectrograms for every file of a dataset.

    Reads ``<dataset_path>/wavs/<fname>.wav`` for each name in
    ``speech_dataset.fnames`` and writes ``mels/<fname>.npy`` and
    ``mags/<fname>.npy`` next to it, creating the two output folders if needed.
    """
    wavs_path = os.path.join(dataset_path, 'wavs')
    mels_path = os.path.join(dataset_path, 'mels')
    mags_path = os.path.join(dataset_path, 'mags')
    for out_dir in (mels_path, mags_path):
        if not os.path.isdir(out_dir):
            os.mkdir(out_dir)

    rate = hp.reduction_rate
    for fname in tqdm(speech_dataset.fnames):
        wav_file = os.path.join(wavs_path, '%s.wav' % fname)
        mel, mag = get_spectrograms(wav_file)

        # Pad the time axis up to the next multiple of the reduction rate so the
        # reduced mel and the full-rate mag stay aligned (T_mag == rate * T_mel).
        extra = -mel.shape[0] % rate
        time_pad = [[0, extra], [0, 0]]
        mel = np.pad(mel, time_pad, mode="constant")
        mag = np.pad(mag, time_pad, mode="constant")

        # Keep every `rate`-th mel frame.
        mel = mel[::rate, :]

        np.save(os.path.join(mels_path, '%s.npy' % fname), mel)
        np.save(os.path.join(mags_path, '%s.npy' % fname), mag)


In [5]:
#ljspeech.py

"""Data loader for the LJSpeech dataset. See: https://keithito.com/LJ-Speech-Dataset/"""
import os
import re
import codecs
import unicodedata
import numpy as np

from torch.utils.data import Dataset

#vocab = "PE abcdefghijklmnopqrstuvwxyz'.?"  # P: Padding, E: EOS.
vocab = "PE অআইঈউঊঋএঐওঔা ি ী ু ূ ৃ ে ৈ ো ৌক খ গ ঘ ঙ চ ছ জ ঝ ঞ ট ঠ ড ঢ ণত থ দ ধ ন প ফ ব ভমযরলশষসহড়ঢ়য়ৎংঃঁ্ঽ‍্য‍  ‍্র'"
#vocab = "PE অআইঈউঊঋএঐওঔা ি ী ু ূ ৃ ে ৈ ো ৌক খ গ ঘ ঙ চ ছ জ ঝ ঞ ট ঠ ড ঢ ণত থ দ ধ ন প ফ ব ভমযরলশষসহড়ঢ়য়ৎংঃঁ্ঽ‍্য‍  ‍্র'.?"
char2idx = {char: idx for idx, char in enumerate(vocab)}
idx2char = {idx: char for idx, char in enumerate(vocab)}


def text_normalize(text):
    """Reduce `text` to the characters listed in `vocab`.

    Steps: NFD-decompose the text, drop combining marks that are NOT part of
    the vocabulary (accent stripping), replace every remaining out-of-vocabulary
    character with a space, then collapse runs of spaces.

    Args:
        text: raw transcript string.
    Returns:
        The normalized string; every character of it occurs in `vocab`.
    """
    # Only strip non-spacing marks (category 'Mn') that the vocabulary does not
    # know. Stripping all of them would also delete Bengali signs that `vocab`
    # lists on purpose, e.g. the hasanta U+09CD and the vowel signs U+09C1-U+09C3.
    text = ''.join(char for char in unicodedata.normalize('NFD', text)
                   if unicodedata.category(char) != 'Mn' or char in vocab)
    #print(text)
    #text = text.lower()
    # re.escape keeps the character class valid whatever characters vocab holds.
    text = re.sub("[^{}]".format(re.escape(vocab)), " ", text)
    text = re.sub("[ ]+", " ", text)
    return text

def read_metadata(metadata_file):
    """Load the transcript index and encode every sentence as vocabulary indices.

    Each non-empty line of `metadata_file` must be ``<file name>\\t<text>``.

    Args:
        metadata_file: path of the UTF-8 encoded, tab-separated transcript file.
    Returns:
        fnames: list of file names (first column).
        text_lengths: list with the encoded length of each sentence (EOS included).
        texts: list of 1-D float32 arrays holding the character indices.
    Raises:
        ValueError: if a non-empty line does not consist of exactly two
            tab-separated fields.
    """
    fnames, text_lengths, texts = [], [], []
    #transcript = "/content/drive/My Drive/TTS_B/datasets/LJSpeech-1.1/line_index.tsv"
    # The context manager closes the transcript even if parsing fails.
    with codecs.open(metadata_file, 'r', 'utf-8') as transcript:
        lines = transcript.readlines()
    for line in lines:
        line = line.strip()
        if not line:
            # tolerate blank lines (e.g. a trailing newline at the end of the file)
            continue
        fname, text = line.split("\t")

        fnames.append(fname)

        text = text_normalize(text) + "E"  # E: EOS
        text = [char2idx[char] for char in text]
        text_lengths.append(len(text))
        # NOTE(review): indices are stored as float32; an embedding lookup needs
        # integer indices, so presumably the consumer casts them - confirm.
        texts.append(np.array(text, np.float32))

    return fnames, text_lengths, texts


def get_test_data(sentences, max_n):
    """Encode `sentences` into a zero-padded (len(sentences), max_n + 1) float32 index matrix.

    Every sentence is normalized, stripped and terminated with the EOS marker
    "E" before its characters are mapped through `char2idx`; unused trailing
    positions stay 0 (the padding index).
    """
    encoded = [[char2idx[ch] for ch in text_normalize(raw).strip() + "E"]  # E: EOS
               for raw in sentences]
    texts = np.zeros((len(encoded), max_n + 1), np.float32)
    for row, indices in enumerate(encoded):
        texts[row, :len(indices)] = indices
    return texts


class LJSpeech(Dataset):
    """Dataset over a pre-processed speech corpus laid out like LJSpeech.

    The corpus folder must contain ``line_index.tsv`` and, depending on the
    requested keys, the ``mels/`` and ``mags/`` folders written by the
    pre-processing step.
    """

    def __init__(self, keys, dir_name='bn_bd', root="/home/nipu/ml"):
        """
        Args:
            keys: collection of item fields to produce; any of 'texts', 'mels',
                'mags', 'mel_gates', 'mag_gates'.
            dir_name: name of the corpus folder inside `root`.
            root: folder that contains the corpus folder (defaults to the
                location that used to be hard-coded).
        """
        self.keys = keys
        self.path = os.path.join(root, dir_name)
        self.fnames, self.text_lengths, self.texts = read_metadata(os.path.join(self.path, "line_index.tsv"))

    def slice(self, start, end):
        """Restrict the dataset in place to the samples in [start, end)."""
        self.fnames = self.fnames[start:end]
        self.text_lengths = self.text_lengths[start:end]
        self.texts = self.texts[start:end]

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, index):
        """Return a dict holding the fields listed in `self.keys` for sample `index`.

        'mel_gates' / 'mag_gates' are derived from the loaded arrays, so they
        require 'mels' / 'mags' to be requested as well.
        """
        data = {}
        if 'texts' in self.keys:
            data['texts'] = self.texts[index]
        if 'mels' in self.keys:
            data['mels'] = np.load(os.path.join(self.path, 'mels', "%s.npy" % self.fnames[index]))
        if 'mags' in self.keys:
            data['mags'] = np.load(os.path.join(self.path, 'mags', "%s.npy" % self.fnames[index]))
        # `np.int` was only an alias of the builtin `int` and no longer exists in
        # NumPy >= 1.24; the builtin yields the same dtype.
        if 'mel_gates' in self.keys:
            data['mel_gates'] = np.ones(data['mels'].shape[0], dtype=int)  # TODO: because pre processing!
        if 'mag_gates' in self.keys:
            data['mag_gates'] = np.ones(data['mags'].shape[0], dtype=int)  # TODO: because pre processing!
        return data


The librosa.filters.mel() function creates a Mel filter-bank. This produces a linear transformation matrix to project FFT bins onto Mel-frequency bins.

The Mel scale is a quasi-logarithmic function of acoustic frequency designed such that perceptually similar pitch intervals (e.g. octaves) appear equal in width over the full hearing range. This makes it a useful scale for representing and analyzing audio signals, as it more closely matches the way humans perceive sound.

Mel filter-banks are commonly used in a variety of audio processing tasks, such as speech recognition, music information retrieval, and automatic speaker recognition.

The librosa.filters.mel() function takes the following arguments:

sr: The sampling rate of the incoming signal (in Hz).
n_fft: The number of FFT components.
n_mels: The number of Mel bands to generate.
fmin: The lowest frequency (in Hz).
fmax: The highest frequency (in Hz). If None, use fmax = sr / 2.0.
htk: Whether to use the HTK formula instead of Slaney.
norm: The type of normalization to apply to the filters. Can be None, slaney, or a number.
dtype: The data type of the output basis.
The function returns a NumPy array of shape (n_mels, 1 + n_fft // 2), which represents the Mel filter-bank.

Here is an example of how to use the librosa.filters.mel() function:

Python
import librosa

# Create a Mel filter-bank with 128 Mel bands
melfb = librosa.filters.mel(sr=22050, n_fft=2048, n_mels=128)

# Compute the Mel spectrogram of an audio signal
audio_data, sr = librosa.load('audio.wav')
mel_spectrogram = librosa.feature.melspectrogram(audio_data, sr=sr, S=melfb)
Use code with caution. Learn more
The mel_spectrogram variable will now contain a NumPy array of shape (n_mels, n_frames), which represents the Mel spectrogram of the audio signal.

Mel filter-banks are a powerful tool for audio processing, and the librosa.filters.mel() function makes it easy to create them in Python.

In [6]:
import os
import re
import codecs
import unicodedata
import numpy as np
import pandas as pd

from torch.utils.data import Dataset

In [7]:
vocab = "PE অআইঈউঊঋএঐওঔা ি ী ু ূ ৃ ে ৈ ো ৌক খ গ ঘ ঙ চ ছ জ ঝ ঞ ট ঠ ড ঢ ণত থ দ ধ ন প ফ ব ভমযরলশষসহড়ঢ়য়ৎংঃঁ্ঽ‍্য‍  ‍্র'"
char2idx = {char: idx for idx, char in enumerate(vocab)}
idx2char = {idx: char for idx, char in enumerate(vocab)}

In [8]:
def text_normalize(text):
    """Reduce `text` to the characters listed in `vocab`.

    Steps: NFKC-normalize the text, drop combining marks that are NOT part of
    the vocabulary (accent stripping), replace every remaining out-of-vocabulary
    character with a space, then collapse runs of spaces / parentheses.

    Args:
        text: raw transcript string.
    Returns:
        The normalized string; every character of it occurs in `vocab`.
    """
    # Only strip non-spacing marks (category 'Mn') that the vocabulary does not
    # know. Stripping all of them would also delete Bengali signs that `vocab`
    # lists on purpose, e.g. the hasanta U+09CD and the vowel signs U+09C1-U+09C3.
    text = ''.join(char for char in unicodedata.normalize('NFKC', text)
                   if unicodedata.category(char) != 'Mn' or char in vocab)
    # print(text)
    # re.escape keeps the character class valid whatever characters vocab holds.
    text = re.sub("[^{}]".format(re.escape(vocab)), " ", text)
    text = re.sub("[ ()]+", " ", text)
    return text

In [9]:
ss = text_normalize("আমার,++ নাম :(শামীম) ক্কা? কক্ষ মুখ্য পত্র, বন্ধু")
print(unicodedata.normalize('NFKC', "আমার,++ নাম :(শামীম) ক্কা? কক্ষ মুখ্য পত্র, বন্ধু"))

আমার,++ নাম :(শামীম) ক্কা? কক্ষ মুখ্য পত্র, বন্ধু


In [10]:
# Peek at one raw transcript line to check the "<file name>\t<text>" layout.
# The context manager closes the file handle once the lines are read.
with codecs.open('bn_bd/line_index.tsv', 'r', 'utf-8') as transcript_file:
    lines = transcript_file.readlines()
len(lines[1])
line = lines[5]
print(line)

ban_00737_00107291991	কেয়া ডেভেলপারস দেশের বিভিন্ন স্থানে স্থাপনা তৈরি করে থাকে



In [11]:
# Split the sample line into its two tab-separated fields (file name, transcript) and show them.
label, text = line.strip().split("\t")
print(label.strip(), "\n", text.strip())

ban_00737_00107291991 
 কেয়া ডেভেলপারস দেশের বিভিন্ন স্থানে স্থাপনা তৈরি করে থাকে


In [12]:
# TTS_Bn
import os
import sys
import torch
from os.path import exists, join, expanduser


In [13]:
# Folder holding all datasets, and the 'bn_bd' corpus folder inside it.
datasets_path = '/home/nipu/ml'
dataset_path = os.path.join(datasets_path, 'bn_bd')
print(dataset_path)

# NOTE: the trailing `and False` makes this condition always false, so the message is never printed.
if os.path.isdir(dataset_path) and False:
  print("BN dataset folder already exists")


/home/nipu/ml/bn_bd


In [15]:
print("pre processing...")
# An empty key list is enough here: the constructor still loads the file names
# from the transcript, which is all preprocess() reads from the dataset.
lj_speech = LJSpeech([])
# print(lj_speech)
# print(dataset_path, lj_speech)
# Writes mels/<name>.npy and mags/<name>.npy for every file name of the dataset.
preprocess(dataset_path, lj_speech)

pre processing...


100%|██████████| 1891/1891 [01:29<00:00, 21.05it/s]


In [21]:
__author__ = 'Erdene-Ochir Tuguldur'
__all__ = ['E', 'D', 'C', 'HighwayBlock', 'GatedConvBlock', 'ResidualBlock']

#layers

import torch.nn as nn
import torch.nn.functional as F

# from hparams import HParams as hp


class LayerNorm(nn.LayerNorm):
    """Layer normalization over the channel axis of (B, C, T) tensors."""

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        # nn.LayerNorm normalizes the trailing axis, so move channels last,
        # normalize, and restore the (B, C, T) layout.
        channels_last = x.permute(0, 2, 1)
        normalized = super().forward(channels_last)
        return normalized.permute(0, 2, 1)


class D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation, weight_init='none', normalization='weight', nonlinearity='linear'):
        """1D Deconvolution.

        Args:
            in_channels: number of input channels
            out_channels: number of output channels
            kernel_size: kernel size of the transposed convolution
            dilation: dilation of the transposed convolution
            weight_init: 'kaiming', 'xavier' or 'none' (keep the PyTorch default)
            normalization: 'weight', 'layer' or 'none'
            nonlinearity: 'relu' applies a ReLU at the end of forward; the value
                is also used as the gain hint of the weight initialization
        """
        super(D, self).__init__()
        self.deconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size,
                                         stride=2,  # paper: stride of deconvolution is always 2
                                         dilation=dilation)

        # Initialize BEFORE wrapping with weight norm. Once wrapped, `weight` is
        # recomputed from weight_g / weight_v before every forward pass, so an
        # initialization applied afterwards would be silently discarded.
        self.nonlinearity = nonlinearity
        if weight_init == 'kaiming':
            nn.init.kaiming_normal_(self.deconv.weight, mode='fan_out', nonlinearity=nonlinearity)
        elif weight_init == 'xavier':
            nn.init.xavier_uniform_(self.deconv.weight, nn.init.calculate_gain(nonlinearity))

        if normalization == 'weight':
            self.deconv = nn.utils.weight_norm(self.deconv)
        elif normalization == 'layer':
            self.layer_norm = LayerNorm(out_channels)

    def forward(self, x, output_size=None):
        """Deconvolve `x`, then apply optional layer norm, dropout and optional ReLU."""
        y = self.deconv(x, output_size=output_size)
        if hasattr(self, 'layer_norm'):
            y = self.layer_norm(y)
        y = F.dropout(y, p=hp.dropout_rate, training=self.training, inplace=True)
        if self.nonlinearity == 'relu':
            y = F.relu(y, inplace=True)
        return y


class C(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation, causal=False, weight_init='none', normalization='weight', nonlinearity='linear'):
        """1D convolution.
        The argument 'causal' indicates whether the causal convolution should be used or not.

        Args:
            in_channels: number of input channels
            out_channels: number of output channels
            kernel_size: kernel size of the convolution
            dilation: dilation of the convolution
            causal: pad so that an output frame only depends on current and past frames
            weight_init: 'kaiming', 'xavier' or 'none' (keep the PyTorch default)
            normalization: 'weight', 'layer' or 'none'
            nonlinearity: 'relu' applies a ReLU at the end of forward; the value
                is also used as the gain hint of the weight initialization
        """
        super(C, self).__init__()
        self.causal = causal
        if causal:
            # full receptive field on both sides; the right part is cut off in forward
            self.padding = (kernel_size - 1) * dilation
        else:
            self.padding = (kernel_size - 1) * dilation // 2

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=1,  # paper: 'The stride of convolution is always 1.'
                              padding=self.padding, dilation=dilation)

        # Initialize BEFORE wrapping with weight norm. Once wrapped, `weight` is
        # recomputed from weight_g / weight_v before every forward pass, so an
        # initialization applied afterwards would be silently discarded.
        self.nonlinearity = nonlinearity
        if weight_init == 'kaiming':
            nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity=nonlinearity)
        elif weight_init == 'xavier':
            nn.init.xavier_uniform_(self.conv.weight, nn.init.calculate_gain(nonlinearity))

        if normalization == 'weight':
            self.conv = nn.utils.weight_norm(self.conv)
        elif normalization == 'layer':
            self.layer_norm = LayerNorm(out_channels)

    def forward(self, x):
        """Convolve `x`, then apply optional layer norm, dropout and optional ReLU."""
        y = self.conv(x)
        padding = self.padding
        if self.causal and padding > 0:
            # drop the frames that were computed from right-hand (future) padding
            y = y[:, :, :-padding]

        if hasattr(self, 'layer_norm'):
            y = self.layer_norm(y)
        y = F.dropout(y, p=hp.dropout_rate, training=self.training, inplace=True)
        if self.nonlinearity == 'relu':
            y = F.relu(y, inplace=True)
        return y


class E(nn.Module):
    """Character embedding table; index 0 is the padding entry."""

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)

    def forward(self, x):
        embedded = self.embedding(x)
        return embedded


class HighwayBlock(nn.Module):
    def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight'):
        """Highway-network style layer, see https://arxiv.org/abs/1505.00387
        Input and output have the same shape.
        Args:
            d: input channel
            k: kernel size
            delta: dilation
            causal: causal convolution or not
        """
        super().__init__()
        self.d = d
        # One convolution yields both the gate and the candidate (2 * d channels).
        self.C = C(d, 2 * d, k, delta, causal=causal, weight_init=weight_init, normalization=normalization)

    def forward(self, x):
        projected = self.C(x)
        gate = F.sigmoid(projected[:, :self.d, :])
        candidate = projected[:, self.d:, :]
        # Gate blends the transformed signal with the untouched input.
        return gate * candidate + (1 - gate) * x


class GatedConvBlock(nn.Module):
    def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight'):
        """Gated convolution with a skip connection, see https://arxiv.org/abs/1612.08083
        Input and output have the same shape.
        Args:
            d: input channel
            k: kernel size
            delta: dilation
            causal: causal convolution or not
        """
        super().__init__()
        # The convolution doubles the channels; GLU halves them again.
        self.C = C(d, 2 * d, k, delta, causal=causal,
                   weight_init=weight_init, normalization=normalization)
        self.glu = nn.GLU(dim=1)

    def forward(self, x):
        gated = self.glu(self.C(x))
        return gated + x


class ResidualBlock(nn.Module):
    def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight',
                 widening_factor=2):
        """Two ReLU convolutions with a skip connection, see https://arxiv.org/abs/1512.03385
        Input and output have the same shape.
        Args:
            d: input channel
            k: kernel size
            delta: dilation
            causal: causal convolution or not
            widening_factor: channel multiplier of the hidden convolution
        """
        super().__init__()
        hidden = widening_factor * d
        shared = dict(kernel_size=k, dilation=delta, causal=causal,
                      weight_init=weight_init, normalization=normalization, nonlinearity='relu')
        self.C1 = C(in_channels=d, out_channels=hidden, **shared)
        self.C2 = C(in_channels=hidden, out_channels=d, **shared)

    def forward(self, x):
        widened = self.C1(x)
        return self.C2(widened) + x


In [23]:
#ssrn.py

"""
Hideyuki Tachibana, Katsuya Uenoyama, Shunsuke Aihara
Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention
https://arxiv.org/abs/1710.08969

SSRN Network.
"""
__author__ = 'Erdene-Ochir Tuguldur'
__all__ = ['SSRN']

import torch.nn as nn
import torch.nn.functional as F



def Conv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'):
    """Non-causal convolution configured with the SSRN weight-init / normalization options."""
    return C(in_channels, out_channels, kernel_size, dilation, causal=False,
             weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization, nonlinearity=nonlinearity)


def DeConv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'):
    """Stride-2 deconvolution configured with the SSRN weight-init / normalization options."""
    return D(in_channels, out_channels, kernel_size, dilation,
             weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization, nonlinearity=nonlinearity)


def BasicBlock(d, k, delta):
    """Build the block type selected by hp.ssrn_basic_block (non-causal)."""
    if hp.ssrn_basic_block == 'gated_conv':
        return GatedConvBlock(d, k, delta, causal=False,
                              weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization)
    elif hp.ssrn_basic_block == 'highway':
        return HighwayBlock(d, k, delta, causal=False,
                            weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization)
    else:
        return ResidualBlock(d, k, delta, causal=False,
                             weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization,
                             widening_factor=1)


# Keep private references to the SSRN factories. The Text2Mel code defines
# functions with the same public names (Conv, BasicBlock); when both live in one
# namespace the later definitions win, and an SSRN built afterwards would
# silently be assembled from Text2Mel layers. The aliases are bound now, so
# SSRN always uses the factories above.
_ssrn_conv = Conv
_ssrn_deconv = DeConv
_ssrn_basic_block = BasicBlock


class SSRN(nn.Module):
    def __init__(self, c=hp.c, f=hp.n_mels, f_prime=(1 + hp.n_fft // 2)):
        """Spectrogram super-resolution network.
        Args:
            c: SSRN dim
            f: Number of mel bins
            f_prime: full spectrogram dim
        Input:
            Y: (B, f, T) predicted melspectrograms
        Outputs:
            Z_logit: logit of Z
            Z: (B, f_prime, 4*T) full spectrograms
        """
        super(SSRN, self).__init__()
        self.layers = nn.Sequential(
            _ssrn_conv(f, c, 1, 1),

            _ssrn_basic_block(c, 3, 1), _ssrn_basic_block(c, 3, 3),

            # two stride-2 deconvolutions: T -> 2T -> 4T
            _ssrn_deconv(c, c, 2, 1), _ssrn_basic_block(c, 3, 1), _ssrn_basic_block(c, 3, 3),
            _ssrn_deconv(c, c, 2, 1), _ssrn_basic_block(c, 3, 1), _ssrn_basic_block(c, 3, 3),

            _ssrn_conv(c, 2 * c, 1, 1),

            _ssrn_basic_block(2 * c, 3, 1), _ssrn_basic_block(2 * c, 3, 1),

            _ssrn_conv(2 * c, f_prime, 1, 1),

            # Conv(f_prime, f_prime, 1, 1, nonlinearity='relu'),
            # Conv(f_prime, f_prime, 1, 1, nonlinearity='relu'),
            _ssrn_basic_block(f_prime, 1, 1),

            _ssrn_conv(f_prime, f_prime, 1, 1)
        )

    def forward(self, x):
        """Return (Z_logit, Z) where Z = sigmoid(Z_logit)."""
        Z_logit = self.layers(x)
        Z = F.sigmoid(Z_logit)
        return Z_logit, Z

In [25]:
#text2mel.py

"""
Hideyuki Tachibana, Katsuya Uenoyama, Shunsuke Aihara
Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention
https://arxiv.org/abs/1710.08969

Text2Mel Network.
"""
__author__ = 'Erdene-Ochir Tuguldur'
__all__ = ['Text2Mel']

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F



def Conv(in_channels, out_channels, kernel_size, dilation, causal=False, nonlinearity='linear'):
    """Convolution configured with the Text2Mel weight-init / normalization options."""
    return C(in_channels, out_channels, kernel_size, dilation,
             causal=causal,
             weight_init=hp.text2mel_weight_init,
             normalization=hp.text2mel_normalization,
             nonlinearity=nonlinearity)


def BasicBlock(d, k, delta, causal=False):
    """Build the block type selected by hp.text2mel_basic_block.

    'gated_conv' and 'highway' map to their blocks; any other value yields a
    ResidualBlock with widening factor 2.
    """
    kind = hp.text2mel_basic_block
    if kind == 'gated_conv':
        return GatedConvBlock(d, k, delta, causal=causal,
                              weight_init=hp.text2mel_weight_init,
                              normalization=hp.text2mel_normalization)
    if kind == 'highway':
        return HighwayBlock(d, k, delta, causal=causal,
                            weight_init=hp.text2mel_weight_init,
                            normalization=hp.text2mel_normalization)
    return ResidualBlock(d, k, delta, causal=causal,
                         weight_init=hp.text2mel_weight_init,
                         normalization=hp.text2mel_normalization,
                         widening_factor=2)


def CausalConv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'):
    """Text2Mel convolution with causal padding switched on."""
    return Conv(in_channels, out_channels, kernel_size, dilation,
                causal=True, nonlinearity=nonlinearity)


def CausalBasicBlock(d, k, delta):
    """Text2Mel basic block with causal padding switched on."""
    block = BasicBlock(d, k, delta, causal=True)
    return block


class TextEnc(nn.Module):

    def __init__(self, vocab, e=hp.e, d=hp.d):
        """Text encoder network.
        Args:
            vocab: vocabulary
            e: embedding dim
            d: Text2Mel dim
        Input:
            L: (B, N) text inputs
        Outputs:
            K: (B, d, N) keys
            V: (B, d, N) values
        """
        super().__init__()
        self.d = d
        self.embedding = E(len(vocab), e)

        width = 2 * d  # keys and values are produced side by side
        stack = [
            Conv(e, width, 1, 1, nonlinearity='relu'),
            Conv(width, width, 1, 1),
        ]
        # two passes over exponentially growing dilations
        for _ in range(2):
            for delta in (1, 3, 9, 27):
                stack.append(BasicBlock(width, 3, delta))
        # two kernel-3 blocks followed by two kernel-1 blocks, all undilated
        for kernel in (3, 3, 1, 1):
            stack.append(BasicBlock(width, kernel, 1))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        embedded = self.embedding(x).permute(0, 2, 1)  # (B, N, e) -> (B, e, N)
        encoded = self.layers(embedded)  # (B, 2*d, N)
        # first d channels are the keys, the remaining ones the values
        return encoded[:, :self.d, :], encoded[:, self.d:, :]


class AudioEnc(nn.Module):
    def __init__(self, d=hp.d, f=hp.n_mels):
        """Audio encoder network (causal).
        Args:
            d: Text2Mel dim
            f: Number of mel bins
        Input:
            S: (B, f, T) melspectrograms
        Output:
            Q: (B, d, T) queries
        """
        super().__init__()
        stack = [
            CausalConv(f, d, 1, 1, nonlinearity='relu'),
            CausalConv(d, d, 1, 1, nonlinearity='relu'),
            CausalConv(d, d, 1, 1),
        ]
        # two passes over exponentially growing dilations
        for _ in range(2):
            for delta in (1, 3, 9, 27):
                stack.append(CausalBasicBlock(d, 3, delta))
        # two closing blocks with dilation 3
        for _ in range(2):
            stack.append(CausalBasicBlock(d, 3, 3))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        queries = self.layers(x)
        return queries


class AudioDec(nn.Module):
    def __init__(self, d=hp.d, f=hp.n_mels):
        """Audio decoder network (causal).
        Args:
            d: Text2Mel dim
            f: Number of mel bins
        Input:
            R_prime: (B, 2d, T) [V*Attention, Q] paper says: "we found it beneficial in our pilot study."
        Output:
            Y: (B, f, T)
        """
        super().__init__()
        stack = [CausalConv(2 * d, d, 1, 1)]
        # one pass over exponentially growing dilations
        for delta in (1, 3, 9, 27):
            stack.append(CausalBasicBlock(d, 3, delta))
        # two undilated kernel-3 blocks
        for _ in range(2):
            stack.append(CausalBasicBlock(d, 3, 1))
        stack += [
            CausalBasicBlock(d, 1, 1),
            CausalConv(d, d, 1, 1, nonlinearity='relu'),
            # projection to the mel bins (logits)
            CausalConv(d, f, 1, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        logits = self.layers(x)
        return logits


class Text2Mel(nn.Module):
    def __init__(self, vocab, d=hp.d):
        """Text to melspectrogram network.
        Args:
            vocab: vocabulary
            d: Text2Mel dim
        Input:
            L: (B, N) text inputs
            S: (B, f, T) melspectrograms
        Outputs:
            Y_logit: logit of Y
            Y: predicted melspectrograms
            A: (B, N, T) attention matrix
        """
        super(Text2Mel, self).__init__()
        self.d = d
        # Forward `d` to every sub-network: the attention scale below and the
        # key/value split inside the text encoder both rely on the same value.
        self.text_enc = TextEnc(vocab, d=d)
        self.audio_enc = AudioEnc(d=d)
        self.audio_dec = AudioDec(d=d)

    def forward(self, L, S, monotonic_attention=False):
        """Run the network.

        Args:
            L: (B, N) text inputs
            S: (B, f, T) melspectrograms
            monotonic_attention: if True, force the attention peak of each frame
                to stay within [-1, +3] positions of the previous frame's peak.
        Returns:
            (Y_logit, Y, A) as described in the class constructor docstring.
        """
        K, V = self.text_enc(L)
        Q = self.audio_enc(S)
        # scaled dot-product scores, (B, N, T)
        A = torch.bmm(K.permute(0, 2, 1), Q) / np.sqrt(self.d)

        if monotonic_attention:
            # TODO: vectorize instead of loops
            B, N, T = A.size()
            for i in range(B):
                prva = -1  # previous attention
                for t in range(T):
                    _, n = torch.max(A[i, :, t], 0)
                    if not (-1 <= n - prva <= 3):
                        # peak jumped too far: suppress the whole column and put
                        # the peak right after the previous position
                        A[i, :, t] = -2 ** 20  # some small numbers
                        A[i, min(N - 1, prva + 1), t] = 1
                    _, prva = torch.max(A[i, :, t], 0)

        A = F.softmax(A, dim=1)  # normalize over the text positions
        R = torch.bmm(V, A)
        R_prime = torch.cat((R, Q), 1)
        Y_logit = self.audio_dec(R_prime)
        Y = F.sigmoid(Y_logit)
        return Y_logit, Y, A


In [27]:
"""Wrapper class for logging into the TensorBoard and comet.ml"""
__author__ = 'Erdene-Ochir Tuguldur'
__all__ = ['Logger']

import os
from tensorboardX import SummaryWriter




class Logger(object):
    """Thin wrapper around a tensorboardX SummaryWriter for one dataset / model pair."""

    def __init__(self, dataset_name, model_name):
        self.model_name = model_name
        self.project_name = "%s-%s" % (dataset_name, self.model_name)
        self.logdir = os.path.join(hp.logdir, self.project_name)
        self.writer = SummaryWriter(log_dir=self.logdir)

    def log_step(self, phase, step, loss_dict, image_dict):
        """Log per-step data; only the 'train' phase is recorded.

        Losses are written every 50 steps, images every 1000 steps.
        """
        if phase != 'train':
            return

        if step % 50 == 0:
            for name in sorted(loss_dict):
                tag = '%s-step/%s' % (phase, name)
                self.writer.add_scalar(tag, loss_dict[name], step)

        if step % 1000 == 0:
            for name in sorted(image_dict):
                tag = '%s/%s' % (self.model_name, name)
                self.writer.add_image(tag, image_dict[name], step)

    def log_epoch(self, phase, step, loss_dict):
        """Write every loss of `loss_dict` under '<phase>/<name>' at `step`."""
        for name in sorted(loss_dict):
            tag = '%s/%s' % (phase, name)
            self.writer.add_scalar(tag, loss_dict[name], step)




In [30]:
import sys
import time
import argparse
from tqdm import *

import torch
import torch.nn.functional as F

# project imports
# from models import SSRN
# from hparams import HParams as hp
# from logger import Logger
# from utils import get_last_checkpoint_file_name, load_checkpoint, save_checkpoint
# from datasets.data_loader import SSRNDataLoader

In [None]:
# Inference only: switch autograd off globally for everything that follows.
torch.set_grad_enabled(False)
# Build Text2Mel for the current vocabulary and restore its trained weights.
text2mel = Text2Mel(vocab)
#print(text2mel)
# The checkpoint is a dict whose 'state_dict' entry holds the model weights.
text2mel.load_state_dict(torch.load("/content/drive/My Drive/TTSBn/logdir/ljspeech-text2mel/step-010K.pth")['state_dict'])
text2mel = text2mel.eval()  # eval mode disables dropout
# Same procedure for the spectrogram super-resolution network.
ssrn = SSRN()
ssrn.load_state_dict(torch.load("/content/drive/My Drive/TTSBn/logdir/ljspeech-ssrn/step-005K.pth")['state_dict'])
ssrn = ssrn.eval()