Commit c5e48a0: Updates + fixes + comments

Rayhane-mamah committed Mar 20, 2018
1 parent 81b657d · commit c5e48a0
Showing 8 changed files with 75 additions and 39 deletions.
3 changes: 2 additions & 1 deletion tacotron/datasets/feeder.py
@@ -28,7 +28,8 @@ def __init__(self, coordinator, metadata_filename, hparams):
self._datadir = os.path.dirname(metadata_filename)
with open(metadata_filename, encoding='utf-8') as f:
self._metadata = [line.strip().split('|') for line in f]
-        hours = sum([int(x[1]) for x in self._metadata]) * hparams.frame_shift_ms / (3600 * 1000)
+        frame_shift = hparams.hop_size / hparams.sample_rate  # seconds per frame
+        hours = sum([int(x[1]) for x in self._metadata]) * frame_shift / 3600
log('Loaded metadata for {} examples ({:.2f} hours)'.format(len(self._metadata), hours))

# Create placeholders for inputs and targets. Don't specify batch size because we want
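The fix above derives the frame shift in seconds directly from `hop_size / sample_rate` instead of relying on the stale `frame_shift_ms` hparam, so the frame total only needs dividing by 3600. A minimal sketch of the corrected arithmetic, where the hop size, sample rate, and frame counts are assumed toy values, not the repo's hparams:

```python
# Sketch of the corrected duration bookkeeping (all values assumed/dummy).
hop_size = 275                        # assumed: samples between successive frames
sample_rate = 22050                   # assumed: Hz
frame_shift = hop_size / sample_rate  # seconds per frame (~12.5 ms)

frame_counts = [812, 945, 1030]       # dummy per-utterance mel frame counts
hours = sum(frame_counts) * frame_shift / 3600
print('{:.2f} hours'.format(hours))   # 0.01 hours for this toy metadata
```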
5 changes: 4 additions & 1 deletion tacotron/griffin_lim_synthesis_example.ipynb
@@ -4,6 +4,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
@@ -27,7 +28,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"wav = inv_mel_spectrogram(mel_spectro.T)\n",
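For context, the notebook's `inv_mel_spectrogram` call conceptually maps the mel spectrogram back to a linear-frequency magnitude spectrogram and then reconstructs phase with Griffin-Lim. A minimal, self-contained sketch of that loop using `librosa`; the `n_fft`, `hop_length`, and iteration count here are assumptions, not the repo's hparams:

```python
import numpy as np
import librosa

def griffin_lim(magnitudes, n_fft=2048, hop_length=275, iterations=60):
    """Griffin-Lim sketch: alternate between the magnitude constraint and
    STFT consistency to estimate phase. `magnitudes` is a linear-frequency
    magnitude spectrogram of shape [1 + n_fft // 2, frames]."""
    # Start from random phase.
    angles = np.exp(2j * np.pi * np.random.rand(*magnitudes.shape))
    y = librosa.istft(magnitudes * angles, hop_length=hop_length)
    for _ in range(iterations):
        # Keep only the phase of the current estimate, reimpose the magnitudes.
        stft = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
        angles = np.exp(1j * np.angle(stft))
        y = librosa.istft(magnitudes * angles, hop_length=hop_length)
    return y
```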
1 change: 1 addition & 0 deletions tacotron/hparams.py
@@ -44,6 +44,7 @@
enc_conv_channels=512, #number of encoder convolutions filters for each layer
encoder_lstm_units=256, #number of lstm units for each direction (forward and backward)

+    smoothing=False, #Whether to smooth the attention normalization function
attention_dim = 128, #dimension of attention space
attention_filters = 32, #number of attention convolution filters
attention_kernel = (31, ), #kernel size of attention convolution
6 changes: 3 additions & 3 deletions tacotron/models/Architecture_wrappers.py
@@ -157,15 +157,15 @@ def __call__(self, inputs, state):
#Compute the attention (context) vector and alignments using
#the top layer hidden state as query vector
#and previous alignments to extract location features
-        #Based on Luong et Al. (2015):
+        #Based on Luong et al. (2015) for the top layer choice:
#https://arxiv.org/pdf/1508.04025.pdf
first_lstm_state, last_lstm_state = state.cell_state
-        attention_inputs = last_lstm_state.h
+        last_hidden_state = last_lstm_state.h

previous_alignments = state.alignments
previous_alignment_history = state.alignment_history
context_vector, alignments, _ = _compute_attention(self._attention_mechanism,
-            attention_inputs,
+            last_hidden_state,
previous_alignments,
attention_layer=None)

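The rename above makes the query choice explicit: following Luong et al. (2015), with stacked decoder LSTMs the attention query is the top layer's hidden output. A hypothetical TF 1.x sketch of pulling that query out of a two-layer `MultiRNNCell` state (cell sizes and batch size are assumed, not the repo's hparams):

```python
import tensorflow as tf

# Assumed two-layer decoder; each state entry is an LSTMStateTuple (c, h).
cells = [tf.contrib.rnn.LSTMBlockCell(1024) for _ in range(2)]
decoder_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
state = decoder_cell.zero_state(batch_size=8, dtype=tf.float32)

first_lstm_state, last_lstm_state = state  # unpack per-layer states
query = last_lstm_state.h                  # top layer hidden state -> attention query
```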
47 changes: 45 additions & 2 deletions tacotron/models/attention.py
@@ -37,11 +37,32 @@ def _location_sensitive_score(W_query, W_fil, W_keys):
num_units = W_query.shape[-1].value or array_ops.shape(W_query)[-1]

v_a = tf.get_variable(
-        'v_a', shape=[num_units], dtype=tf.float32,
-        initializer=tf.contrib.layers.xavier_initializer())
+        'v_a', shape=[num_units], dtype=tf.float32)

return tf.reduce_sum(v_a * tf.tanh(W_keys + W_query + W_fil), axis=2)

+def _smoothing_normalization(e):
+    """Applies a smoothing normalization function instead of softmax
+    Introduced in:
+        J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Bengio,
+        “Attention-based models for speech recognition,” in Advances in
+        Neural Information Processing Systems, 2015, pp. 577–585.
+    #######################################################################
+                    Smoothing normalization function
+        a_{i, j} = sigmoid(e_{i, j}) / sum_j(sigmoid(e_{i, j}))
+    #######################################################################
+    Args:
+        e: matrix [batch_size, max_time(memory_time)]: expected to be energy (score)
+            values of an attention mechanism
+    Returns:
+        matrix [batch_size, max_time]: [0, 1] normalized alignments with possible
+            attendance to multiple memory time steps.
+    """
+    return tf.nn.sigmoid(e) / tf.reduce_sum(tf.nn.sigmoid(e), axis=-1, keepdims=True)


class LocationSensitiveAttention(BahdanauAttention):
"""Impelements Bahdanau-style (cumulative) scoring function.
@@ -63,6 +84,7 @@ def __init__(self,
num_units,
memory,
memory_sequence_length=None,
+    smoothing=False,
name='LocationSensitiveAttention'):
"""Construct the Attention mechanism.
Args:
@@ -72,12 +94,33 @@
memory_sequence_length (optional): Sequence lengths for the batch entries
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths.
+    smoothing (optional): Boolean. Determines which normalization function to use.
+        Default normalization function (probability_fn) is softmax. If smoothing is
+        enabled, we replace softmax with:
+               a_{i, j} = sigmoid(e_{i, j}) / sum_j(sigmoid(e_{i, j}))
+        Introduced in:
+            J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Bengio,
+            “Attention-based models for speech recognition,” in Advances in
+            Neural Information Processing Systems, 2015, pp. 577–585.
+        This is mainly used if the model wants to attend to multiple input parts
+        at the same decoding step. We probably won't be using it since multiple sound
+        frames may depend on the same character, probably not the other way around.
+        Note:
+            We still keep it implemented in case we want to test it. They used it in the
+            paper because they were doing speech recognition, where one phoneme may depend
+            on multiple subsequent sound frames.
name: Name to use when creating ops.
"""
+        #Create normalization function
+        #Setting it to None defaults to using softmax
+        normalization_function = _smoothing_normalization if smoothing else None
super(LocationSensitiveAttention, self).__init__(
num_units=num_units,
memory=memory,
memory_sequence_length=memory_sequence_length,
+            probability_fn=normalization_function,
name=name)

self.location_convolution = tf.layers.Conv1D(filters=hparams.attention_filters,
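A quick NumPy illustration of how the smoothing normalization above differs from softmax on the same (dummy) energies: both rows sum to 1, but the sigmoid-based variant keeps noticeably more mass on weaker positions, which is what lets attention spread over several memory steps at once:

```python
import numpy as np

def softmax(e):
    z = np.exp(e - e.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)

def smoothing_normalization(e):
    s = 1.0 / (1.0 + np.exp(-e))  # sigmoid
    return s / s.sum(axis=-1, keepdims=True)

e = np.array([[4.0, 4.0, -2.0, -2.0]])   # dummy attention energies
print(softmax(e))                  # ~[[0.499, 0.499, 0.001, 0.001]]
print(smoothing_normalization(e))  # ~[[0.446, 0.446, 0.054, 0.054]]
```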
38 changes: 16 additions & 22 deletions tacotron/models/modules.py
@@ -1,8 +1,7 @@
import tensorflow as tf
from .zoneout_LSTM import ZoneoutLSTMCell
-from tensorflow.contrib.rnn import RNNCell, LSTMBlockCell
+from tensorflow.contrib.rnn import LSTMBlockCell
from hparams import hparams
-from tensorflow.python.layers import base


def conv1d(inputs, kernel_size, channels, activation, is_training, scope):
@@ -14,15 +13,14 @@ def conv1d(inputs, kernel_size, channels, activation, is_training, scope):
filters=channels,
kernel_size=kernel_size,
activation=None,
-        padding='same',
-        kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False))
+        padding='same')
batched = tf.layers.batch_normalization(conv1d_output, training=is_training)
activated = activation(batched)
return tf.layers.dropout(activated, rate=drop_rate, training=is_training,
name='dropout_{}'.format(scope))


-class EncoderConvolutions(base.Layer):
+class EncoderConvolutions:
"""Encoder convolutional layers used to find local dependencies in inputs characters.
"""
def __init__(self, is_training, kernel_size=(5, ), channels=512, activation=tf.nn.relu, scope=None):
@@ -51,7 +49,7 @@ def __call__(self, inputs):
return x


-class EncoderRNN(RNNCell):
+class EncoderRNN:
"""Encoder bidirectional one layer LSTM
"""
def __init__(self, is_training, size=256, zoneout=0.1, scope=None):
@@ -70,10 +68,9 @@ def __init__(self, is_training, size=256, zoneout=0.1, scope=None):
self.scope = 'encoder_LSTM' if scope is None else scope

#Create LSTM Cell
-        # self._cell = ZoneoutLSTMCell(size, is_training,
-        #                              zoneout_factor_cell=zoneout,
-        #                              zoneout_factor_output=zoneout)
-        self._cell = LSTMBlockCell(size)
+        self._cell = ZoneoutLSTMCell(size, is_training,
+                                     zoneout_factor_cell=zoneout,
+                                     zoneout_factor_output=zoneout)

def __call__(self, inputs, input_lengths):
with tf.variable_scope(self.scope):
@@ -87,7 +84,7 @@ def __call__(self, inputs, input_lengths):
return tf.concat(outputs, axis=2) # Concat and return forward + backward outputs


-class Prenet(base.Layer):
+class Prenet:
"""Two fully connected layers used as an information bottleneck for the attention.
"""
def __init__(self, is_training, layer_sizes=[256, 256], activation=tf.nn.relu, scope=None):
@@ -114,7 +111,6 @@ def __call__(self, inputs):
with tf.variable_scope(self.scope):
for i, size in enumerate(self.layer_sizes):
dense = tf.layers.dense(x, units=size, activation=self.activation,
-                kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False),
name='dense_{}'.format(i + 1))
#The paper discussed introducing diversity in generation at inference time
#by using a dropout of 0.5 only in prenet layers.
@@ -123,7 +119,7 @@
return x


-class DecoderRNN(RNNCell):
+class DecoderRNN:
"""Decoder two uni directional LSTM Cells
"""
def __init__(self, is_training, layers=2, size=1024, zoneout=0.1, scope=None):
Expand All @@ -143,10 +139,9 @@ def __init__(self, is_training, layers=2, size=1024, zoneout=0.1, scope=None):
self.scope = 'decoder_rnn' if scope is None else scope

#Create a set of LSTM layers
-        # self.rnn_layers = [ZoneoutLSTMCell(size, is_training,
-        #                                    zoneout_factor_cell=zoneout,
-        #                                    zoneout_factor_output=zoneout) for i in range(layers)]
-        self.rnn_layers = [LSTMBlockCell(size) for i in range(layers)]
+        self.rnn_layers = [ZoneoutLSTMCell(size, is_training,
+                                           zoneout_factor_cell=zoneout,
+                                           zoneout_factor_output=zoneout) for i in range(layers)]

self._cell = tf.contrib.rnn.MultiRNNCell(self.rnn_layers, state_is_tuple=True)

Expand All @@ -155,7 +150,7 @@ def __call__(self, inputs, states):
return self._cell(inputs, states)


-class FrameProjection(base.Layer):
+class FrameProjection:
"""Projection layer to r * num_mels dimensions or num_mels dimensions
"""
def __init__(self, shape=80, activation=None, scope=None):
Expand All @@ -177,13 +172,12 @@ def __call__(self, inputs):
#If activation==None, this returns a simple Linear projection
#else the projection will be passed through an activation function
output = tf.layers.dense(inputs, units=self.shape, activation=self.activation,
-            kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False),
name='projection_{}'.format(self.scope))

return output


-class StopProjection(base.Layer):
+class StopProjection:
"""Projection to a scalar and through a sigmoid activation
"""
def __init__(self, is_training, shape=hparams.outputs_per_step, activation=tf.nn.sigmoid, scope=None):
Expand All @@ -204,7 +198,7 @@ def __init__(self, is_training, shape=hparams.outputs_per_step, activation=tf.nn

def __call__(self, inputs):
with tf.variable_scope(self.scope):
-            output = tf.layers.dense(inputs, units=self.shape, kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False),
+            output = tf.layers.dense(inputs, units=self.shape,
activation=None, name='projection_{}'.format(self.scope))

#During training, don't use activation as it is integrated inside the sigmoid_cross_entropy loss function
Expand All @@ -213,7 +207,7 @@ def __call__(self, inputs):
return self.activation(output)


-class Postnet(base.Layer):
+class Postnet:
"""Postnet that takes final decoder output and fine tunes it (using vision on past and future frames)
"""
def __init__(self, is_training, kernel_size=(5, ), channels=512, activation=tf.nn.tanh, scope=None):
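modules.py also switches the encoder and decoder cells back from `LSTMBlockCell` to `ZoneoutLSTMCell`. As a reminder of what zoneout does (Krueger et al., 2016), here is a hypothetical per-unit update sketch, not the repo's implementation: during training each unit keeps its previous value with probability `rate`, and at inference the stochastic mask is replaced by its expectation:

```python
import numpy as np

def zoneout_update(prev, new, rate, is_training, rng=np.random):
    # Training: randomly freeze each unit at its previous value.
    if is_training:
        keep_prev = (rng.rand(*new.shape) < rate).astype(new.dtype)
        return keep_prev * prev + (1.0 - keep_prev) * new
    # Inference: deterministic expectation of the training update.
    return rate * prev + (1.0 - rate) * new

h_prev = np.zeros(4, dtype=np.float32)
h_new = np.array([0.5, -0.3, 0.8, 0.1], dtype=np.float32)
print(zoneout_update(h_prev, h_new, rate=0.1, is_training=True))   # most units updated
print(zoneout_update(h_prev, h_new, rate=0.1, is_training=False))  # [0.45 -0.27 0.72 0.09]
```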
6 changes: 3 additions & 3 deletions tacotron/models/tacotron.py
@@ -41,8 +41,7 @@ def initialize(self, inputs, input_lengths, mel_targets=None, stop_token_targets=None):

# Embeddings ==> [batch_size, sequence_length, embedding_dim]
embedding_table = tf.get_variable(
-            'inputs_embedding', [len(symbols), hp.embedding_dim], dtype=tf.float32,
-            initializer=tf.contrib.layers.xavier_initializer(uniform=False))
+            'inputs_embedding', [len(symbols), hp.embedding_dim], dtype=tf.float32)
embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs)


@@ -63,7 +62,8 @@ def initialize(self, inputs, input_lengths, mel_targets=None, stop_token_targets=None):
#Attention Decoder Prenet
prenet = Prenet(is_training, layer_sizes=hp.prenet_layers, scope='decoder_prenet')
#Attention Mechanism
-        attention_mechanism = LocationSensitiveAttention(hp.attention_dim, encoder_outputs)
+        attention_mechanism = LocationSensitiveAttention(hp.attention_dim, encoder_outputs,
+            memory_sequence_length=input_lengths, smoothing=hp.smoothing)
#Decoder LSTM Cells
decoder_lstm = DecoderRNN(is_training, layers=hp.decoder_layers,
size=hp.decoder_lstm_units, zoneout=hp.zoneout_rate, scope='decoder_lstm')
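Passing `memory_sequence_length=input_lengths` matters because the attention mechanism masks energies past each entry's real encoder length before normalizing, so padded steps receive exactly zero weight. A dummy NumPy illustration of that masking idea (not the library's internals):

```python
import numpy as np

energies = np.array([[2.0, 1.0, 0.5, 0.3, 0.2]])  # [batch=1, max_time=5], dummy
length = 3                                         # real encoder steps for this entry

mask = np.arange(energies.shape[1]) < length
masked = np.where(mask, energies, -np.inf)         # padded positions -> -inf
z = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights = z / z.sum(axis=-1, keepdims=True)
print(weights)  # [[0.63 0.23 0.14 0.   0.  ]] -- padding gets zero attention
```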
8 changes: 1 addition & 7 deletions tacotron/utils/audio.py
@@ -60,12 +60,6 @@ def _stft(y):
def _istft(y):
return librosa.istft(y, hop_length=get_hop_size())

-def _stft_params():
-    n_fft = (hparams.num_freq - 1) * 2
-    hop_length = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
-    win_length = int(hparams.frame_length_ms / 1000 * hparams.sample_rate)
-    return n_fft, hop_length, win_length


# Conversions
_mel_basis = None
@@ -89,7 +83,7 @@ def _build_mel_basis():
fmin=hparams.fmin, fmax=hparams.fmax)

def _amp_to_db(x):
-    min_level = _db_to_amp(hparams.min_level_db)
+    min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))

def _db_to_amp(x):
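The `_amp_to_db` change inlines the amplitude/dB identity amp = 10^(db / 20) instead of calling the `_db_to_amp` helper (and the old `_stft_params`, made redundant by hop-size-based hparams, is dropped above). A worked check of the inlined expression, with `min_level_db = -100` as an assumed hparam value:

```python
import numpy as np

min_level_db = -100                                 # assumed hparams value
min_level = np.exp(min_level_db / 20 * np.log(10))  # same as 10 ** (min_level_db / 20)
print(min_level)                                    # 1e-05
print(10 ** (min_level_db / 20))                    # 1e-05 (identical)

# _amp_to_db then floors amplitudes at min_level before converting to dB:
x = np.array([0.5, 1e-8])
print(20 * np.log10(np.maximum(min_level, x)))      # [  -6.0206 -100.    ]
```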
