In [117]:
import torch
from torch import nn
from math import sqrt

from torch.autograd import Variable
from torch.nn import functional as F



In [118]:
""" from https://github.com/keithito/tacotron """

'''
Defines the set of symbols used in text input to the model.

The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
""" from https://github.com/keithito/tacotron """

import re


valid_symbols = [
  'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
  'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
  'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
  'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
  'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
  'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
  'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
]

_valid_symbol_set = set(valid_symbols)


class CMUDict:
  '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
  def __init__(self, file_or_path, keep_ambiguous=True):
    if isinstance(file_or_path, str):
      with open(file_or_path, encoding='latin-1') as f:
        entries = _parse_cmudict(f)
    else:
      entries = _parse_cmudict(file_or_path)
    if not keep_ambiguous:
      entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
    self._entries = entries


  def __len__(self):
    return len(self._entries)


  def lookup(self, word):
    '''Returns list of ARPAbet pronunciations of the given word.'''
    return self._entries.get(word.upper())



_alt_re = re.compile(r'\([0-9]+\)')


def _parse_cmudict(file):
  cmudict = {}
  for line in file:
    if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
      parts = line.split('  ')
      word = re.sub(_alt_re, '', parts[0])
      pronunciation = _get_pronunciation(parts[1])
      if pronunciation:
        if word in cmudict:
          cmudict[word].append(pronunciation)
        else:
          cmudict[word] = [pronunciation]
  return cmudict


def _get_pronunciation(s):
  parts = s.strip().split(' ')
  for part in parts:
    if part not in _valid_symbol_set:
      return None
  return ' '.join(parts)

_pad        = '_'
_punctuation = '!\'(),.:;? '
_special = '-'
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'

# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in valid_symbols]

# Export all symbols:
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet

class hps:
	################################
	# Data Parameters              #
	################################
	text_cleaners=['english_cleaners']

	################################
	# Audio                        #
	################################
	num_mels = 80
	num_freq = 1025
	sample_rate = 16000
	frame_length_ms = 50
	frame_shift_ms = 12.5
	preemphasis = 0.97
	min_level_db = -100
	ref_level_db = 20
	power = 1.5
	gl_iters = 100

	################################
	# Model Parameters             #
	################################
	n_symbols = len(symbols)
	symbols_embedding_dim = 512

	# Encoder parameters
	encoder_kernel_size = 5

	# Decoder parameters
	n_frames_per_step = 2
	decoder_rnn_dim = 1024
	prenet_dim = 256
	max_decoder_steps = 120
	gate_threshold = 0.5
	p_attention_dropout = 0.1
	p_decoder_dropout = 0.1

	# Attention parameters
	attention_rnn_dim = 1024
	attention_dim = 128

	# Location Layer parameters
	attention_location_n_filters = 32
	attention_location_kernel_size = 31

	# Mel-post processing network parameters
	postnet_embedding_dim = 512
	postnet_kernel_size = 5
	postnet_n_convolutions = 5

	################################
	# Train                        #
	################################
	is_cuda = False
	pin_mem = True
	n_workers = 8
	lr = 2e-3
	betas = (0.9, 0.999)
	eps = 1e-6
	sch = True
	sch_step = 4000
	max_iter = 1e6
	batch_size = 40
	iters_per_log = 50
	iters_per_sample = 500
	iters_per_ckpt = 1000
	weight_decay = 1e-6
	grad_clip_thresh = 1.0
	mask_padding = True
	p = 10 # mel spec loss penalty
	eg_text = 'Make America great again!'

	############# added
	iscrop = True
	encoder_embedding_dim = 384  # encoder_lstm_units
	encoder_n_convolutions = 5  # enc_conv_num_blocks

	num_init_filters= 24

	prenet_layers= [256, 256]
	decoder_layers= 2
	decoder_lstm_units= 256

	tacotron_teacher_forcing_start_decay= 29000
	tacotron_teacher_forcing_decay_steps= 130000

	T= 90 #90
	overlap= 15
	mel_overlap= 40
	mel_step_size= 240
	img_size = 96
	fps= 30


	use_lws = False
	# Mel spectrogram
	n_fft = 800  # Extra window size is filled with 0 paddings to match this parameter
	hop_size = 200  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
	win_size = 800  # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)

	# M-AILABS (and other datasets) trim params (these parameters are usually correct for any
	# data, but definitely must be tuned for specific speakers)
	trim_fft_size = 512
	trim_hop_size = 128
	trim_top_db = 23

	# Mel and Linear spectrograms normalization/scaling and clipping
	signal_normalization = True
	# Whether to normalize mel spectrograms to some predefined range (following below parameters)
	allow_clipping_in_normalization = True # Only relevant if mel_normalization = True
	symmetric_mels = True
	# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
	# faster and cleaner convergence)
	max_abs_value = 4.
	# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
	# be too big to avoid gradient explosion,
	# not too small for fast convergence)
	normalize_for_wavenet = True
	# whether to rescale to [0, 1] for wavenet. (better audio quality)
	clip_for_wavenet = True
	# whether to clip [-max, max] before training/synthesizing with wavenet (better audio quality)

	# Contribution by @begeekmyfriend
	# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
	# levels. Also allows for better G&L phase reconstruction)
	preemphasize = True # whether to apply filter

	fmin = 55
	# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
	# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
	fmax = 7600  # To be increased/reduced depending on data.

	# Griffin Lim
	# Only used in G&L inversion, usually values between 1.2 and 1.5 are a good choice.
	griffin_lim_iters = 60
# Number of G&L iterations, typically 30 is enough but we use 60 to ensure convergence.
###########################################################################################################################################

In [119]:
import torch
import numpy as np


def mode(obj, model = False):
	if model and hps.is_cuda:
		obj = obj
	elif hps.is_cuda:
		obj = obj.cuda(non_blocking = hps.pin_mem)
	return obj

def to_var(tensor):
	var = torch.autograd.Variable(tensor)
	return mode(var)

def to_arr(var):
	return var.cpu().detach().numpy().astype(np.float32)

def get_mask_from_lengths(lengths, pad = False):
	max_len = torch.max(lengths).int().item()
	if pad and max_len%hps.n_frames_per_step != 0:
		max_len += hps.n_frames_per_step - max_len%hps.n_frames_per_step
		assert max_len%hps.n_frames_per_step == 0
	ids = torch.arange(0, max_len, out = torch.LongTensor(max_len))
	ids = mode(ids)
	mask = (ids < lengths.unsqueeze(1))
	return mask

In [120]:
import torch


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal

class ConvNorm3D(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', activation=torch.nn.ReLU, residual=False):
        super(ConvNorm3D, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.residual = residual
        self.conv3d = torch.nn.Conv3d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)
        self.batched = torch.nn.BatchNorm3d(out_channels)
        self.activation = activation()

        torch.nn.init.xavier_uniform_(
            self.conv3d.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
        # torch.nn.init.xavier_uniform_(
        #     self.batched.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
        # torch.nn.init.xavier_uniform_(
        #     self.activation.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv3d(signal)

        batched = self.batched(conv_signal)

        if self.residual:
            batched = batched + signal
        activated = self.activation(batched)

        return activated

In [173]:

class Tacotron2Loss(nn.Module):
	def __init__(self):
		super(Tacotron2Loss, self).__init__()

	def forward(self, model_output, targets, iteration):
		mel_target, gate_target = targets[0], targets[1]
		mel_target.requires_grad = False
		gate_target.requires_grad = False
		slice = torch.arange(0, gate_target.size(1), hps.n_frames_per_step)
		gate_target = gate_target[:, slice].view(-1, 1)

		mel_out, mel_out_postnet, gate_out, _ = model_output
		gate_out = gate_out.view(-1, 1)
		p = hps.p
		mel_loss = nn.MSELoss()(p*mel_out, p*mel_target)
		mel_loss_post = nn.MSELoss()(p*mel_out_postnet, p*mel_target)
		gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)

		# added
		l1_loss = nn.L1Loss()(mel_target, mel_out)
		return mel_loss, mel_loss_post, l1_loss, gate_loss#, ((mel_loss+mel_loss_post)/(p**2)+gate_loss+l1_loss).item()


class LocationLayer(nn.Module):
	def __init__(self, attention_n_filters, attention_kernel_size,
				 attention_dim):
		super(LocationLayer, self).__init__()
		padding = int((attention_kernel_size - 1) / 2)
		self.location_conv = ConvNorm(2, attention_n_filters,
									  kernel_size=attention_kernel_size,
									  padding=padding, bias=False, stride=1,
									  dilation=1)
		self.location_dense = LinearNorm(attention_n_filters, attention_dim,
										 bias=False, w_init_gain='tanh')

	def forward(self, attention_weights_cat):
		processed_attention = self.location_conv(attention_weights_cat)
		processed_attention = processed_attention.transpose(1, 2)
		processed_attention = self.location_dense(processed_attention)
		return processed_attention


class Attention(nn.Module):
	def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
				 attention_location_n_filters, attention_location_kernel_size):
		super(Attention, self).__init__()
		self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
									  bias=False, w_init_gain='tanh')
		self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
									   w_init_gain='tanh')
		self.v = LinearNorm(attention_dim, 1, bias=False)
		self.location_layer = LocationLayer(attention_location_n_filters,
											attention_location_kernel_size,
											attention_dim)
		self.score_mask_value = -float('inf')

	def get_alignment_energies(self, query, processed_memory,
							   attention_weights_cat):
		'''
		PARAMS
		------
		query: decoder output (batch, num_mels * n_frames_per_step)
		processed_memory: processed encoder outputs (B, T_in, attention_dim)
		attention_weights_cat: cumulative and prev. att we;[ights (B, 2, max_time)

		RETURNS
		-------
		alignment (batch, max_time)
		'''

		processed_query = self.query_layer(query.unsqueeze(1))
		processed_attention_weights = self.location_layer(attention_weights_cat)
		energies = self.v(torch.tanh(
			processed_query + processed_attention_weights + processed_memory))

		energies = energies.squeeze(-1)
		return energies

	def forward(self, attention_hidden_state, memory, processed_memory,
				attention_weights_cat, mask):
		'''
		PARAMS
		------
		attention_hidden_state: attention rnn last output
		memory: encoder outputs
		processed_memory: processed encoder outputs
		attention_weights_cat: previous and cummulative attention weights
		mask: binary mask for padded data
		'''
		alignment = self.get_alignment_energies(
			attention_hidden_state, processed_memory, attention_weights_cat)

		if mask is not None:
			alignment.data.masked_fill_(mask, self.score_mask_value)

		attention_weights = F.softmax(alignment, dim=1)
		attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
		attention_context = attention_context.squeeze(1)

		return attention_context, attention_weights

class Prenet(nn.Module):
	def __init__(self, in_dim, sizes):
		super(Prenet, self).__init__()
		in_sizes = [in_dim] + sizes[:-1]
		self.layers = nn.ModuleList(
			[LinearNorm(in_size, out_size, bias=False)
			 for (in_size, out_size) in zip(in_sizes, sizes)])

	def forward(self, x):
		for linear in self.layers:
			x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
		return x


class Postnet(nn.Module):
	'''Postnet
		- Five 1-d convolution with 512 channels and kernel size 5
	'''

	def __init__(self):
		super(Postnet, self).__init__()
		self.convolutions = nn.ModuleList()

		self.convolutions.append(
			nn.Sequential(
				ConvNorm(hps.num_mels, hps.postnet_embedding_dim,
						 kernel_size=hps.postnet_kernel_size, stride=1,
						 padding=int((hps.postnet_kernel_size - 1) / 2),
						 dilation=1, w_init_gain='tanh'),
				nn.BatchNorm1d(hps.postnet_embedding_dim))
		)

		for i in range(1, hps.postnet_n_convolutions - 1):
			self.convolutions.append(
				nn.Sequential(
					ConvNorm(hps.postnet_embedding_dim,
							 hps.postnet_embedding_dim,
							 kernel_size=hps.postnet_kernel_size, stride=1,
							 padding=int((hps.postnet_kernel_size - 1) / 2),
							 dilation=1, w_init_gain='tanh'),
					nn.BatchNorm1d(hps.postnet_embedding_dim))
			)

		self.convolutions.append(
			nn.Sequential(
				ConvNorm(hps.postnet_embedding_dim, hps.num_mels,
						 kernel_size=hps.postnet_kernel_size, stride=1,
						 padding=int((hps.postnet_kernel_size - 1) / 2),
						 dilation=1, w_init_gain='linear'),
				nn.BatchNorm1d(hps.num_mels))
			)

	def forward(self, x):
		for i in range(len(self.convolutions) - 1):
			x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
		x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

		return x


class Encoder(nn.Module):
	'''Encoder module:
		- Three 1-d convolution banks
		- Bidirectional LSTM
	'''
	def __init__(self):
		super(Encoder, self).__init__()

		convolutions = []
		for _ in range(hps.encoder_n_convolutions):
			conv_layer = nn.Sequential(
				ConvNorm(hps.encoder_embedding_dim,
						 hps.encoder_embedding_dim,
						 kernel_size=hps.encoder_kernel_size, stride=1,
						 padding=int((hps.encoder_kernel_size - 1) / 2),
						 dilation=1, w_init_gain='relu'),
				nn.BatchNorm1d(hps.encoder_embedding_dim))
			convolutions.append(conv_layer)
		self.convolutions = nn.ModuleList(convolutions)

		self.lstm = nn.LSTM(hps.encoder_embedding_dim,
							int(hps.encoder_embedding_dim / 2), 1,
							batch_first=True, bidirectional=True)

	def forward(self, x, input_lengths):
		for conv in self.convolutions:
			x = F.dropout(F.relu(conv(x)), 0.5, self.training)

		x = x.transpose(1, 2)

		# pytorch tensor are not reversible, hence the conversion
		input_lengths = input_lengths.cpu().numpy()
		x = nn.utils.rnn.pack_padded_sequence(
			x, input_lengths, batch_first=True)

		self.lstm.flatten_parameters()
		outputs, _ = self.lstm(x)

		outputs, _ = nn.utils.rnn.pad_packed_sequence(
			outputs, batch_first=True)

		return outputs

	def inference(self, x):
		for conv in self.convolutions:
			x = F.dropout(F.relu(conv(x)), 0.5, self.training)

		x = x.transpose(1, 2)

		self.lstm.flatten_parameters()
		outputs, _ = self.lstm(x)

		return outputs


class Encoder3D(nn.Module):
	"""Encoder module:
        - Three 3-d convolution banks
        - Bidirectional LSTM
    """

	def __init__(self, hparams):
		super(Encoder3D, self).__init__()

		self.hparams = hparams
		self.out_channel = hps.num_init_filters
		self.in_channel = 3
		convolutions = []

		for i in range(hps.encoder_n_convolutions):
			if i == 0:
				conv_layer = nn.Sequential(
					ConvNorm3D(self.in_channel, self.out_channel,
								kernel_size=5, stride=(1, 2, 2),
								# padding=int((hparams.encoder_kernel_size - 1) / 2),
								dilation=1, w_init_gain='relu'),
					ConvNorm3D(self.out_channel, self.out_channel,
							   kernel_size=3, stride=1,
							   # padding=int((hparams.encoder_kernel_size - 1) / 2),
							   dilation=1, w_init_gain='relu', residual=True),
					ConvNorm3D(self.out_channel, self.out_channel,
							   kernel_size=3, stride=1,
							   # padding=int((hparams.encoder_kernel_size - 1) / 2),
							   dilation=1, w_init_gain='relu', residual=True)
				)
				convolutions.append(conv_layer)
			else:
				conv_layer = nn.Sequential(
					ConvNorm3D(self.in_channel, self.out_channel,
							   kernel_size=3, stride=(1, 2, 2),
							   # padding=int((hparams.encoder_kernel_size - 1) / 2),
							   dilation=1, w_init_gain='relu'),
					ConvNorm3D(self.out_channel, self.out_channel,
							   kernel_size=3, stride=1,
							   # padding=int((hparams.encoder_kernel_size - 1) / 2),
							   dilation=1, w_init_gain='relu', residual=True),
					ConvNorm3D(self.out_channel, self.out_channel,
							   kernel_size=3, stride=1,
							   # padding=int((hparams.encoder_kernel_size - 1) / 2),
							   dilation=1, w_init_gain='relu', residual=True)
				)
				convolutions.append(conv_layer)

			if i == hps.encoder_n_convolutions - 1:
				conv_layer = nn.Sequential(
					ConvNorm3D(self.out_channel, self.out_channel,
							   kernel_size=3, stride=(1, 3, 3),
							   # padding=int((hparams.encoder_kernel_size - 1) / 2),
							   dilation=1, w_init_gain='relu'))
				convolutions.append(conv_layer)

			self.in_channel = self.out_channel
			self.out_channel *= 2
		self.convolutions = nn.ModuleList(convolutions)

		self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
		                    int(hparams.encoder_embedding_dim / 2), 1,
		                    batch_first=True, bidirectional=True)

	def forward(self, x, input_lengths):
		for conv in self.convolutions:
			x = F.dropout(conv(x), 0.5, self.training)
		# for i in range(len(self.convolutions)):
		# 	if i==0 or i==1 or i ==2:
		# 		with torch.no_grad():
		# 			x = F.dropout(self.convolutions[i](x), 0.5, self.training)
		# 	else:
		# 		x = F.dropout(self.convolutions[i](x), 0.5, self.training)

		x = x.permute(0, 2, 1, 3, 4).squeeze(4).squeeze(3).contiguous()  # [bs x 90 x encoder_embedding_dim]
		print(x.size())
		# pytorch tensor are not reversible, hence the conversion
		input_lengths = input_lengths.cpu().numpy()
		# x = nn.utils.rnn.pack_padded_sequence(
		# 	x, input_lengths, batch_first=True)

		# self.lstm.flatten_parameters()
		outputs, _ = self.lstm(x)
		# print('outputs',outputs.size())
		# outputs, _ = nn.utils.rnn.pad_packed_sequence(
		# 	outputs, batch_first=True)
		# print('outputs', outputs.size())

		return outputs

	def inference(self, x):
		for conv in self.convolutions:
			x = F.dropout(conv(x), 0.5, self.training)

		x = x.permute(0, 2, 1, 3, 4).squeeze(4).squeeze(3).contiguous()
		# self.lstm.flatten_parameters()
		outputs, _ = self.lstm(x)	#x:B,T,C

		return outputs


class Decoder(nn.Module):
	def __init__(self):
		super(Decoder, self).__init__()
		self.num_mels = hps.num_mels
		self.n_frames_per_step = hps.n_frames_per_step
		self.encoder_embedding_dim = hps.encoder_embedding_dim
		self.attention_rnn_dim = hps.attention_rnn_dim
		self.decoder_rnn_dim = hps.decoder_rnn_dim
		self.prenet_dim = hps.prenet_dim
		self.max_decoder_steps = hps.max_decoder_steps
		self.gate_threshold = hps.gate_threshold
		self.p_attention_dropout = hps.p_attention_dropout
		self.p_decoder_dropout = hps.p_decoder_dropout

		self.prenet = Prenet(
			hps.num_mels * hps.n_frames_per_step,
			[hps.prenet_dim, hps.prenet_dim])

		self.attention_rnn = nn.LSTMCell(
			hps.prenet_dim + hps.encoder_embedding_dim,
			hps.attention_rnn_dim)

		self.attention_layer = Attention(
			hps.attention_rnn_dim, hps.encoder_embedding_dim,
			hps.attention_dim, hps.attention_location_n_filters,
			hps.attention_location_kernel_size)

		self.decoder_rnn = nn.LSTMCell(
			hps.attention_rnn_dim + hps.encoder_embedding_dim,
			hps.decoder_rnn_dim, 1)

		self.linear_projection = LinearNorm(
			hps.decoder_rnn_dim + hps.encoder_embedding_dim,
			hps.num_mels * hps.n_frames_per_step)

		self.gate_layer = LinearNorm(
			hps.decoder_rnn_dim + hps.encoder_embedding_dim, 1,
			bias=True, w_init_gain='sigmoid')

	def get_go_frame(self, memory):
		''' Gets all zeros frames to use as first decoder input
		PARAMS
		------
		memory: decoder outputs

		RETURNS
		-------
		decoder_input: all zeros frames
		'''
		B = memory.size(0)
		decoder_input = Variable(memory.data.new(
			B, self.num_mels * self.n_frames_per_step).zero_())
		# print(decoder_input)
		# print(decoder_input.size())
		return decoder_input

	def initialize_decoder_states(self, memory, mask):
		''' Initializes attention rnn states, decoder rnn states, attention
		weights, attention cumulative weights, attention context, stores memory
		and stores processed memory
		PARAMS
		------
		memory: Encoder outputs
		mask: Mask for padded data if training, expects None for inference
		'''
		B = memory.size(0)
		MAX_TIME = memory.size(1)

		self.attention_hidden = Variable(memory.data.new(
			B, self.attention_rnn_dim).zero_())
		self.attention_cell = Variable(memory.data.new(
			B, self.attention_rnn_dim).zero_())

		self.decoder_hidden = Variable(memory.data.new(
			B, self.decoder_rnn_dim).zero_())
		self.decoder_cell = Variable(memory.data.new(
			B, self.decoder_rnn_dim).zero_())

		self.attention_weights = Variable(memory.data.new(
			B, MAX_TIME).zero_())
		self.attention_weights_cum = Variable(memory.data.new(
			B, MAX_TIME).zero_())
		self.attention_context = Variable(memory.data.new(
			B, self.encoder_embedding_dim).zero_())

		self.memory = memory
		self.processed_memory = self.attention_layer.memory_layer(memory)
		self.mask = mask

	def parse_decoder_inputs(self, decoder_inputs):
		''' Prepares decoder inputs, i.e. mel outputs
		PARAMS
		------
		decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

		RETURNS
		-------
		inputs: processed decoder inputs

		'''
		# (B, num_mels, T_out) -> (B, T_out, num_mels)
		decoder_inputs = decoder_inputs.transpose(1, 2).contiguous()
		decoder_inputs = decoder_inputs.view(
			decoder_inputs.size(0),
			int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
		# (B, T_out, num_mels) -> (T_out, B, num_mels)
		decoder_inputs = decoder_inputs.transpose(0, 1)
		return decoder_inputs

	def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
		''' Prepares decoder outputs for output
		PARAMS
		------
		mel_outputs:
		gate_outputs: gate output energies
		alignments:

		RETURNS
		-------
		mel_outputs:
		gate_outpust: gate output energies
		alignments:
		'''
		# (T_out, B) -> (B, T_out)
		alignments = torch.stack(alignments).transpose(0, 1)
		# (T_out, B) -> (B, T_out)

		gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
		gate_outputs = gate_outputs.contiguous()
		# (T_out, B, num_mels) -> (B, T_out, num_mels)
		mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
		# decouple frames per step
		mel_outputs = mel_outputs.view(mel_outputs.size(0), -1, self.num_mels)
		# (B, T_out, num_mels) -> (B, num_mels, T_out)
		mel_outputs = mel_outputs.transpose(1, 2)

		return mel_outputs, gate_outputs, alignments

	def decode(self, decoder_input):
		''' Decoder step using stored states, attention and memory
		PARAMS
		------
		decoder_input: previous mel output

		RETURNS
		-------
		mel_output:
		gate_output: gate output energies
		attention_weights:
		'''
		cell_input = torch.cat((decoder_input, self.attention_context), -1)

		self.attention_hidden, self.attention_cell = self.attention_rnn(cell_input, (self.attention_hidden, self.attention_cell))

		self.attention_hidden = F.dropout(self.attention_hidden, self.p_attention_dropout, self.training)

		attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1)

		self.attention_context, self.attention_weights = self.attention_layer(self.attention_hidden, self.memory, self.processed_memory, attention_weights_cat, self.mask)

		self.attention_weights_cum += self.attention_weights

		decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)

		self.decoder_hidden, self.decoder_cell = self.decoder_rnn(decoder_input, (self.decoder_hidden, self.decoder_cell))
		self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training)

		decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)

		decoder_output = self.linear_projection(decoder_hidden_attention_context)

		gate_prediction = self.gate_layer(decoder_hidden_attention_context)

		return decoder_output, gate_prediction, self.attention_weights

	def forward(self, memory, decoder_inputs, memory_lengths):
		''' Decoder forward pass for training
		PARAMS
		------
		memory: Encoder outputs
		decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
		memory_lengths: Encoder output lengths for attention masking.

		RETURNS
		-------
		mel_outputs: mel outputs from the decoder
		gate_outputs: gate outputs from the decoder
		alignments: sequence of attention weights from the decoder
		'''
		# print('Encoder outputs', memory.size())
		decoder_input = self.get_go_frame(memory).unsqueeze(0)
		decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
		decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
		decoder_inputs = self.prenet(decoder_inputs)
		# print('decoder_input', decoder_input.size())

		self.initialize_decoder_states(
			memory, mask=~get_mask_from_lengths(memory_lengths))
		mel_outputs, gate_outputs, alignments = [], [], []
		while len(mel_outputs) < decoder_inputs.size(0) - 1:
			decoder_input = decoder_inputs[len(mel_outputs)]
			mel_output, gate_output, attention_weights = self.decode(
				decoder_input)
			# print('mel_output', mel_output.size())
			# print('gate_output', gate_output.size())
			# print('attention_weights', attention_weights.size())
			mel_outputs += [mel_output.squeeze(1)]
			gate_outputs += [gate_output.squeeze((1))]
			alignments += [attention_weights]

		mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
			mel_outputs, gate_outputs, alignments)

		return mel_outputs, gate_outputs, alignments

	def inference(self, memory):
		''' Decoder inference
		PARAMS
		------
		memory: Encoder outputs

		RETURNS
		-------
		mel_outputs: mel outputs from the decoder
		gate_outputs: gate outputs from the decoder
		alignments: sequence of attention weights from the decoder
		'''
		decoder_input = self.get_go_frame(memory)

		self.initialize_decoder_states(memory, mask=None)

		mel_outputs, gate_outputs, alignments = [], [], []
		while True:
			decoder_input = self.prenet(decoder_input)
			mel_output, gate_output, alignment = self.decode(decoder_input)
			mel_outputs += [mel_output.squeeze(1)]
			gate_outputs += [gate_output]
			alignments += [alignment]

			if sum(torch.sigmoid(gate_output.data))/len(gate_output.data) > self.gate_threshold:
				print('Terminated by gate.')
				break
			# elif len(mel_outputs) > 1 and is_end_of_frames(mel_output):
			# 	print('Warning: End with low power.')
			# 	break
			elif len(mel_outputs) == self.max_decoder_steps:
				print('Warning: Reached max decoder steps.')
				break

			decoder_input = mel_output

		print(mel_outputs.shape)

		mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
			mel_outputs, gate_outputs, alignments)
		return mel_outputs, gate_outputs, alignments

def is_end_of_frames(output, eps = 0.2):
    return (output.data <= eps).all()

class Tacotron2(nn.Module):
	def __init__(self):
		super(Tacotron2, self).__init__()
		self.num_mels = hps.num_mels
		self.mask_padding = hps.mask_padding
		self.n_frames_per_step = hps.n_frames_per_step
		self.embedding = nn.Embedding(
			hps.n_symbols, hps.symbols_embedding_dim)
		std = sqrt(2.0/(hps.n_symbols+hps.symbols_embedding_dim))
		val = sqrt(3.0)*std  # uniform bounds for std
		self.embedding.weight.data.uniform_(-val, val)
		# self.encoder = Encoder()
		self.encoder = Encoder3D(hps)
		self.decoder = Decoder()
		self.postnet = Postnet()

	def parse_batch(self, batch):
		text_padded, input_lengths, mel_padded, gate_padded, output_lengths = batch
		text_padded = to_var(text_padded).long()
		input_lengths = to_var(input_lengths).long()
		max_len = torch.max(input_lengths.data).item()
		mel_padded = to_var(mel_padded).float()
		gate_padded = to_var(gate_padded).float()
		output_lengths = to_var(output_lengths).long()

		return (
			(text_padded, input_lengths, mel_padded, max_len, output_lengths),
			(mel_padded, gate_padded))

	def parse_batch_vid(self, batch):
		vid_padded, input_lengths, mel_padded, gate_padded, target_lengths, split_infos, embed_targets = batch
		vid_padded = to_var(vid_padded).float()
		input_lengths = to_var(input_lengths).float()
		mel_padded = to_var(mel_padded).float()
		gate_padded = to_var(gate_padded).float()
		target_lengths = to_var(target_lengths).float()

		max_len_vid = split_infos[0].data.item()
		max_len_target = split_infos[1].data.item()


		mel_padded = to_var(mel_padded).float()

		return(
			(vid_padded, input_lengths, mel_padded, max_len_vid, target_lengths),
			(mel_padded, gate_padded))

	def parse_output(self, outputs, output_lengths=None):
		if self.mask_padding and output_lengths is not None:
			mask = ~get_mask_from_lengths(output_lengths, True) # (B, T)
			mask = mask.expand(self.num_mels, mask.size(0), mask.size(1)) # (80, B, T)
			mask = mask.permute(1, 0, 2) # (B, 80, T)
			
			outputs[0].data.masked_fill_(mask, 0.0) # (B, 80, T)
			outputs[1].data.masked_fill_(mask, 0.0) # (B, 80, T)
			slice = torch.arange(0, mask.size(2), self.n_frames_per_step)
			outputs[2].data.masked_fill_(mask[:, 0, slice], 1e3)  # gate energies (B, T//n_frames_per_step)

		return outputs

	def forward(self, inputs):
		vid_inputs = inputs["video"].transpose(-3, -4)
		mels = inputs["mel"]
		output_lengths = torch.tensor([mels.shape[-1]])
		vid_lengths = torch.tensor([vid_inputs.shape[-3]])
  
  
		vid_lengths, output_lengths = vid_lengths.data, output_lengths.data

		embedded_inputs = vid_inputs
		# print('vid_inputs',vid_inputs)

		encoder_outputs = self.encoder(embedded_inputs, vid_lengths)
		mel_outputs, gate_outputs, alignments = self.decoder(encoder_outputs, mels, memory_lengths=vid_lengths)

		mel_outputs_postnet = self.postnet(mel_outputs)
		mel_outputs_postnet = mel_outputs + mel_outputs_postnet

		mel_outputs, mel_outputs_postnet, gate_outputs, alignments = self.parse_output(
			[mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
			output_lengths)

		return mel_outputs, mel_outputs_postnet, gate_outputs, alignments

	def inference(self, inputs):
		vid_inputs = inputs["video"].transpose(-3, -4)
	
		embedded_inputs = vid_inputs

		encoder_outputs = self.encoder.inference(embedded_inputs)

		mel_outputs, gate_outputs, alignments = self.decoder.inference(
			encoder_outputs)
		mel_outputs_postnet = self.postnet(mel_outputs)

		mel_outputs_postnet = mel_outputs + mel_outputs_postnet
		mel_outputs, mel_outputs_postnet, gate_outputs, alignments = self.parse_output(
			[mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

		return mel_outputs, mel_outputs_postnet, gate_outputs, alignments

	def teacher_infer(self, inputs, mels):
		il, _ =  torch.sort(torch.LongTensor([len(x) for x in inputs]),
							dim = 0, descending = True)
		vid_lengths = to_var(il)

		embedded_inputs = self.embedding(inputs).transpose(1, 2)

		encoder_outputs = self.encoder(embedded_inputs, vid_lengths)

		mel_outputs, gate_outputs, alignments = self.decoder(
			encoder_outputs, mels, memory_lengths=vid_lengths)
		
		mel_outputs_postnet = self.postnet(mel_outputs)
		mel_outputs_postnet = mel_outputs + mel_outputs_postnet

		return self.parse_output(
			[mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

In [174]:
self = Tacotron2()

In [178]:
vid_inputs = torch.randn(1, 75, 3, 48, 48)
mels = torch.randn(1, 80, 150)

In [179]:
mel_outputs, mel_outputs_postnet, gate_outputs, alignments = self({"video": vid_inputs, "mel": mels})

torch.Size([1, 75, 384])


In [171]:
mel_outputs_postnet.shape

torch.Size([1, 80, 150])