In [1]:
import pickle as pkl
import tensorflow as tf
from model import *
import numpy as np
import os
from tqdm import tqdm 
from sklearn.model_selection import train_test_split

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [3]:
embedding_dim = 512
number_of_conv = 5
encoder_filters = 512
kernel_size =  5
encoder_rnn_unit = 256
zoneout = 0.1
dropout_rate = 0.5

prenet_dim = 256
n_mel_channels = 80
decoder_rnn_unit = 1024
linear_unit = n_mel_channels
attention_location_filters = 32
attention_dim = 128
attention_kernel_size = 34

postnet_embedding_dim = 128

In [4]:
text_file = ['./text/' + file_name for file_name in os.listdir('./text/')]
wav_file = ['./wav/' + file_name for file_name in os.listdir('./wav/')]
gate_file = ['./gate/' + file_name for file_name in os.listdir('./gate/')]

In [5]:
text = []
wav = []
gate = []
for text_fn, wav_fn, gate_fn in tqdm(zip(text_file[:100], wav_file[:100], gate_file[:100])):
    with open(text_fn, 'rb') as f:
        text.append(pkl.load(f))
    with open(wav_fn, 'rb') as f:
        wav.append(pkl.load(f).T)
    with open(gate_fn, 'rb') as f:
        gate.append(pkl.load(f))
        
with open('./vocab.pkl', 'rb') as f:
    vocab = pkl.load(f)
    
text = np.array(text)
wav = np.array(wav)
gate = np.array(gate)
inp_vocab_size = len(vocab)

100it [00:00, 699.24it/s]


In [6]:
wav.shape

(100, 810, 80)

In [7]:
train_input, dev_input, train_mel, dev_mel = train_test_split(text, wav, test_size = 0.1, random_state = 255)
train_gate, dev_gate = train_test_split(gate, test_size = 0.1, random_state = 255)

In [8]:
print(train_input.shape, dev_input.shape)

(90, 187) (10, 187)


In [9]:
optimizer = tf.keras.optimizers.Adam()
def mean_abs_error(x,y):
    '''
    x : model's predictions (B, T)
    y : label mel output (B, T)
    '''
    return tf.reduce_mean(tf.abs(x-y))
gate_loss = tf.keras.losses.BinaryCrossentropy()

def loss_function(pred, target):
    mel_output, post_mel_output, pred_gate = pred
    mel_target, target_gate = target
    
    mae = mean_abs_error(mel_output, mel_target) + mean_abs_error(post_mel_output, mel_target)
    bloss =  tf.cast(gate_loss(pred_gate, target_gate), tf.float32)
    return mae + bloss

In [10]:
mean_abs_error(tf.zeros((128,32,128)), tf.ones((128,32,128)))

<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

In [11]:
encoder = Encoder(input_vocab_size = inp_vocab_size,
                  embedding_dim = embedding_dim,
                  num_of_conv_layer = number_of_conv,
                  filters = encoder_filters,
                  kernel_size = kernel_size,
                  rnn_unit = encoder_rnn_unit,
                  zoneout_prob = zoneout,
                  dropout_rate = dropout_rate)

decoder = Decoder(prenet_dim1 = prenet_dim, 
                  prenet_dim2 = prenet_dim, 
                  n_mel_channels = n_mel_channels, 
                  rnn_unit = decoder_rnn_unit, 
                  linear_unit = linear_unit, 
                  attention_location_filters = attention_location_filters, 
                  attention_dim = attention_dim, 
                  attention_loc_kernel_size = attention_kernel_size, 
                  dropout=dropout_rate)

postnet = PostNet(n_mel_channels=n_mel_channels,
                 postnet_embedding_dim = postnet_embedding_dim,
                 kernel_size = kernel_size,
                 dropout_rate = dropout_rate)

In [12]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder,
                                 postnet = postnet)

In [17]:
def train_step(batch_input, batch_target, batch_gate, encoder_hidden):
    loss = 0
    with tf.GradientTape() as tape:
        enc_output = encoder(batch_input, encoder_hidden)
        print(enc_output.shape)
        lstm1_hidden, lstm2_hidden = decoder.get_initialize(batch_target.shape[0])
        
        dec_input = tf.zeros((batch_target.shape[0], 1, batch_target.shape[-1]), tf.float32)
        attention_weights_cum = tf.zeros((batch_target.shape[0], batch_input.shape[1]))
        # initialize attention weights cum (batch_size, encoder_output_length)
        print(attention_weights_cum.shape)
        gate_output = []
        mel_output = []
        for i in range(batch_target.shape[1]-1):
            dec_input, stop_token_pred, lstm1_hidden, lstm2_hidden, attention_weights = decoder(dec_input,
                                                                             enc_output,
                                                                             lstm1_hidden,
                                                                             lstm2_hidden,
                                                                             attention_weights_cum)
            attention_weights_cum += attention_weights # attention culmulative
            gate_output.append(stop_token_pred)
            mel_output.append(dec_input)
            dec_input = tf.expand_dims(batch_target[:,i,:], 1) # new input of mel spectrogram (Batch_size, 1, mel_spectrogram dim)
        mel_output = tf.stack(mel_output, axis = 1)
        mel_output = tf.reshape(mel_output, (mel_output.shape[0], -1, mel_output.shape[-1]))
        # mel_output : mel_spectrogram (batch_size, mel_spectrogram(time dim), mel_spectrogram(mel dim))
        gate_output = tf.stack(gate_output, axis = -1)
        gate_output = tf.reshape(gate_output, (gate_output.shape[0], -1))
        # gate_output : stop token prediction : (batch_size, mel_spectrogram(time dim))
        print('mel_output :',mel_output.shape)
        print('gate_output :',gate_output.shape)
        postnet_output = postnet(mel_output)
        loss += loss_function(pred = (mel_output, postnet_output, gate_output),
                              target = (batch_target[:,:-1, :], batch_gate[...,:-1]))
    variables = encoder.trainable_variables + decoder.trainable_variables
    
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    
    return loss

In [18]:
train_step(tf.convert_to_tensor(train_input[:10]), 
           tf.convert_to_tensor(train_mel[:10]), 
           tf.convert_to_tensor(train_gate[:10]), encoder.get_initialization(10))

(10, 187, 512)
(10, 187)
mel_output : (10, 809, 80)
gate_output : (10, 809)


<tf.Tensor: shape=(), dtype=float32, numpy=0.46722493>