In [25]:
import torch.nn as nn
import torch
import torch.nn.functional as F


class ConvNorm3D(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', activation=torch.nn.ReLU, residual=False, ):
        super(ConvNorm3D, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.residual = residual
        self.conv3d = torch.nn.Conv3d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)
        self.batched = torch.nn.BatchNorm3d(out_channels)
        self.activation = activation()

        torch.nn.init.xavier_uniform_(
            self.conv3d.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
        # torch.nn.init.xavier_uniform_(
        #     self.batched.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
        # torch.nn.init.xavier_uniform_(
        #     self.activation.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        # pdb.set_trace()
        conv_signal = self.conv3d(signal)

        batched = self.batched(conv_signal)

        if self.residual:
            batched = batched + signal
        activated = self.activation(batched)

        return activated

import torchvision

class Encoder3D(nn.Module):
	"""Encoder module:
        - Three 3-d convolution banks
        - Bidirectional LSTM
    """

	def __init__(self, num_out_feat = 80, encoder_embedding_dim = 384, duration = 3.0, fps = 25, sr = 16000, encoder_n_convolutions = 5, num_init_filters= 24):
		super(Encoder3D, self).__init__()
  
		T = int(duration * fps)

		self.out_channel = num_init_filters
		self.in_channel = 3
		convolutions = []
  
		self.resize = torchvision.transforms.Resize((96, 96))

		for i in range(encoder_n_convolutions):
			if i == 0:
				conv_layer = nn.Sequential(
					ConvNorm3D(self.in_channel, self.out_channel,
								kernel_size=5, stride=(1, 2, 2),
								# padding=int((hparams.encoder_kernel_size - 1) / 2),
								dilation=1, w_init_gain='relu'),
					ConvNorm3D(self.out_channel, self.out_channel,
							   kernel_size=3, stride=1,
							   # padding=int((hparams.encoder_kernel_size - 1) / 2),
							   dilation=1, w_init_gain='relu', residual=True),
					ConvNorm3D(self.out_channel, self.out_channel,
							   kernel_size=3, stride=1,
							   # padding=int((hparams.encoder_kernel_size - 1) / 2),
							   dilation=1, w_init_gain='relu', residual=True)
				)
				convolutions.append(conv_layer)
			else:
				conv_layer = nn.Sequential(
					ConvNorm3D(self.in_channel, self.out_channel,
							   kernel_size=3, stride=(1, 2, 2),
							   # padding=int((hparams.encoder_kernel_size - 1) / 2),
							   dilation=1, w_init_gain='relu'),
					ConvNorm3D(self.out_channel, self.out_channel,
							   kernel_size=3, stride=1,
							   # padding=int((hparams.encoder_kernel_size - 1) / 2),
							   dilation=1, w_init_gain='relu', residual=True),
					ConvNorm3D(self.out_channel, self.out_channel,
							   kernel_size=3, stride=1,
							   # padding=int((hparams.encoder_kernel_size - 1) / 2),
							   dilation=1, w_init_gain='relu', residual=True)
				)
				convolutions.append(conv_layer)

			if i == encoder_n_convolutions - 1:
				conv_layer = nn.Sequential(
					ConvNorm3D(self.out_channel, self.out_channel,
							   kernel_size=3, stride=(1, 3, 3),
							   # padding=int((hparams.encoder_kernel_size - 1) / 2),
							   dilation=1, w_init_gain='relu'), 
     
     )
				convolutions.append(conv_layer)

			self.in_channel = self.out_channel
			self.out_channel *= 2
		self.convolutions = nn.ModuleList(convolutions)

		self.lstm = nn.LSTM(encoder_embedding_dim,
		                    num_out_feat , 1,
		                    batch_first=True, bidirectional=True)
		self.fc = nn.Linear(2*num_out_feat , num_out_feat )
		self.conv_out = nn.Sequential(
      			nn.ConvTranspose1d(T, int(((duration*sr)*22050/sr)/256) , kernel_size=1), 
				nn.Tanh(),
			)

	def forward(self, x, input_lengths = None):
		B, T, C, H, W = x.size() 
		x = x.reshape(B*T, C, H, W)
		x = self.resize(x).reshape(B, T, C, 96, 96)
		for conv in self.convolutions:
			x = F.dropout(conv(x), 0.5, self.training)
		# for i in range(len(self.convolutions)):
		# 	if i==0 or i==1 or i ==2:
		# 		with torch.no_grad():
		# 			x = F.dropout(self.convolutions[i](x), 0.5, self.training)
		# 	else:
		# 		x = F.dropout(self.convolutions[i](x), 0.5, self.training)

		x = x.permute(0, 2, 1, 3, 4).squeeze(4).squeeze(3).contiguous()  # [bs x 90 x encoder_embedding_dim]
		# print(x.size())
		# pytorch tensor are not reversible, hence the conversion
		# input_lengths = input_lengths.cpu().numpy()
		# x = nn.utils.rnn.pack_padded_sequence(
		# 	x, input_lengths, batch_first=True)

		# self.lstm.flatten_parameters()
		outputs, _ = self.lstm(x)
		outputs = self.fc(outputs)
		outputs = self.conv_out (outputs)


		return outputs.transpose(-1, -2).contiguous()


In [26]:
enc = Encoder3D()

In [27]:
# from text import symbols


class hparams:
	################################
	# Data Parameters              #
	################################
	text_cleaners=['english_cleaners']

	################################
	# Audio                        #
	################################
	num_mels = 80
	num_freq = 1025
	sample_rate = 16000
	frame_length_ms = 50
	frame_shift_ms = 12.5
	preemphasis = 0.97
	min_level_db = -100
	ref_level_db = 20
	power = 1.5
	gl_iters = 100

	################################
	# Model Parameters             #
	################################
	# n_symbols = len(symbols)
	symbols_embedding_dim = 512

	# Encoder parameters
	encoder_kernel_size = 5

	# Decoder parameters
	n_frames_per_step = 2
	decoder_rnn_dim = 1024
	prenet_dim = 256
	max_decoder_steps = 120
	gate_threshold = 0.5
	p_attention_dropout = 0.1
	p_decoder_dropout = 0.1

	# Attention parameters
	attention_rnn_dim = 1024
	attention_dim = 128

	# Location Layer parameters
	attention_location_n_filters = 32
	attention_location_kernel_size = 31

	# Mel-post processing network parameters
	postnet_embedding_dim = 512
	postnet_kernel_size = 5
	postnet_n_convolutions = 5

	################################
	# Train                        #
	################################
	is_cuda = True
	pin_mem = True
	n_workers = 8
	lr = 2e-3
	betas = (0.9, 0.999)
	eps = 1e-6
	sch = True
	sch_step = 4000
	max_iter = 1e6
	batch_size = 40
	iters_per_log = 50
	iters_per_sample = 500
	iters_per_ckpt = 1000
	weight_decay = 1e-6
	grad_clip_thresh = 1.0
	mask_padding = True
	p = 10 # mel spec loss penalty
	eg_text = 'Make America great again!'

	############# added
	iscrop = True
	encoder_embedding_dim = 384  # encoder_lstm_units
	encoder_n_convolutions = 5  # enc_conv_num_blocks

	num_init_filters= 24

	prenet_layers= [256, 256]
	decoder_layers= 2
	decoder_lstm_units= 256

	tacotron_teacher_forcing_start_decay= 29000
	tacotron_teacher_forcing_decay_steps= 130000

	T= 90 #90
	overlap= 15
	mel_overlap= 40
	mel_step_size= 240
	img_size = 96
	fps= 30


	use_lws = False
	# Mel spectrogram
	n_fft = 800  # Extra window size is filled with 0 paddings to match this parameter
	hop_size = 200  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
	win_size = 800  # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)

	# M-AILABS (and other datasets) trim params (these parameters are usually correct for any
	# data, but definitely must be tuned for specific speakers)
	trim_fft_size = 512
	trim_hop_size = 128
	trim_top_db = 23

	# Mel and Linear spectrograms normalization/scaling and clipping
	signal_normalization = True
	# Whether to normalize mel spectrograms to some predefined range (following below parameters)
	allow_clipping_in_normalization = True # Only relevant if mel_normalization = True
	symmetric_mels = True
	# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
	# faster and cleaner convergence)
	max_abs_value = 4.
	# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
	# be too big to avoid gradient explosion,
	# not too small for fast convergence)
	normalize_for_wavenet = True
	# whether to rescale to [0, 1] for wavenet. (better audio quality)
	clip_for_wavenet = True
	# whether to clip [-max, max] before training/synthesizing with wavenet (better audio quality)

	# Contribution by @begeekmyfriend
	# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
	# levels. Also allows for better G&L phase reconstruction)
	preemphasize = True # whether to apply filter

	fmin = 55
	# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
	# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
	fmax = 7600  # To be increased/reduced depending on data.

	# Griffin Lim
	# Only used in G&L inversion, usually values between 1.2 and 1.5 are a good choice.
	griffin_lim_iters = 60
# Number of G&L iterations, typically 30 is enough but we use 60 to ensure convergence.
###########################################################################################################################################


In [28]:
hps = hparams

In [29]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F

class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal

class ConvNorm3D(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', activation=torch.nn.ReLU, residual=False):
        super(ConvNorm3D, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.residual = residual
        self.conv3d = torch.nn.Conv3d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)
        self.batched = torch.nn.BatchNorm3d(out_channels)
        self.activation = activation()

        torch.nn.init.xavier_uniform_(
            self.conv3d.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
        # torch.nn.init.xavier_uniform_(
        #     self.batched.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
        # torch.nn.init.xavier_uniform_(
        #     self.activation.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv3d(signal)

        batched = self.batched(conv_signal)

        if self.residual:
            batched = batched + signal
        activated = self.activation(batched)

        return activated


class LocationLayer(nn.Module):
	def __init__(self, attention_n_filters, attention_kernel_size,
				 attention_dim):
		super(LocationLayer, self).__init__()
		padding = int((attention_kernel_size - 1) / 2)
		self.location_conv = ConvNorm(2, attention_n_filters,
									  kernel_size=attention_kernel_size,
									  padding=padding, bias=False, stride=1,
									  dilation=1)
		self.location_dense = LinearNorm(attention_n_filters, attention_dim,
										 bias=False, w_init_gain='tanh')

	def forward(self, attention_weights_cat):
		processed_attention = self.location_conv(attention_weights_cat)
		processed_attention = processed_attention.transpose(1, 2)
		processed_attention = self.location_dense(processed_attention)
		return processed_attention


class Attention(nn.Module):
	def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
				 attention_location_n_filters, attention_location_kernel_size):
		super(Attention, self).__init__()
		self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
									  bias=False, w_init_gain='tanh')
		self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
									   w_init_gain='tanh')
		self.v = LinearNorm(attention_dim, 1, bias=False)
		self.location_layer = LocationLayer(attention_location_n_filters,
											attention_location_kernel_size,
											attention_dim)
		self.score_mask_value = -float('inf')

	def get_alignment_energies(self, query, processed_memory,
							   attention_weights_cat):
		'''
		PARAMS
		------
		query: decoder output (batch, num_mels * n_frames_per_step)
		processed_memory: processed encoder outputs (B, T_in, attention_dim)
		attention_weights_cat: cumulative and prev. att we;[ights (B, 2, max_time)

		RETURNS
		-------
		alignment (batch, max_time)
		'''

		processed_query = self.query_layer(query.unsqueeze(1))
		processed_attention_weights = self.location_layer(attention_weights_cat)
		energies = self.v(torch.tanh(
			processed_query + processed_attention_weights + processed_memory))

		energies = energies.squeeze(-1)
		return energies

	def forward(self, attention_hidden_state, memory, processed_memory,
				attention_weights_cat, mask):
		'''
		PARAMS
		------
		attention_hidden_state: attention rnn last output
		memory: encoder outputs
		processed_memory: processed encoder outputs
		attention_weights_cat: previous and cummulative attention weights
		mask: binary mask for padded data
		'''
		alignment = self.get_alignment_energies(
			attention_hidden_state, processed_memory, attention_weights_cat)

		if mask is not None:
			alignment.data.masked_fill_(mask, self.score_mask_value)

		attention_weights = F.softmax(alignment, dim=1)
		attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
		attention_context = attention_context.squeeze(1)

		return attention_context, attention_weights

class Prenet(nn.Module):
	def __init__(self, in_dim, sizes):
		super(Prenet, self).__init__()
		in_sizes = [in_dim] + sizes[:-1]
		self.layers = nn.ModuleList(
			[LinearNorm(in_size, out_size, bias=False)
			 for (in_size, out_size) in zip(in_sizes, sizes)])

	def forward(self, x):
		for linear in self.layers:
			x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
		return x




class Decoder(nn.Module):
	def __init__(self):
		super(Decoder, self).__init__()
		self.num_mels = hps.num_mels
		self.n_frames_per_step = hps.n_frames_per_step
		self.encoder_embedding_dim = hps.encoder_embedding_dim
		self.attention_rnn_dim = hps.attention_rnn_dim
		self.decoder_rnn_dim = hps.decoder_rnn_dim
		self.prenet_dim = hps.prenet_dim
		self.max_decoder_steps = hps.max_decoder_steps
		self.gate_threshold = hps.gate_threshold
		self.p_attention_dropout = hps.p_attention_dropout
		self.p_decoder_dropout = hps.p_decoder_dropout

		self.prenet = Prenet(
			hps.num_mels * hps.n_frames_per_step,
			[hps.prenet_dim, hps.prenet_dim])

		self.attention_rnn = nn.LSTMCell(
			hps.prenet_dim + hps.encoder_embedding_dim,
			hps.attention_rnn_dim)

		self.attention_layer = Attention(
			hps.attention_rnn_dim, hps.encoder_embedding_dim,
			hps.attention_dim, hps.attention_location_n_filters,
			hps.attention_location_kernel_size)

		self.decoder_rnn = nn.LSTMCell(
			hps.attention_rnn_dim + hps.encoder_embedding_dim,
			hps.decoder_rnn_dim, 1)

		self.linear_projection = LinearNorm(
			hps.decoder_rnn_dim + hps.encoder_embedding_dim,
			hps.num_mels * hps.n_frames_per_step)

		self.gate_layer = LinearNorm(
			hps.decoder_rnn_dim + hps.encoder_embedding_dim, 1,
			bias=True, w_init_gain='sigmoid')

	def get_go_frame(self, memory):
		''' Gets all zeros frames to use as first decoder input
		PARAMS
		------
		memory: decoder outputs

		RETURNS
		-------
		decoder_input: all zeros frames
		'''
		B = memory.size(0)
		decoder_input = Variable(memory.data.new(
			B, self.num_mels * self.n_frames_per_step).zero_())
		print(decoder_input)
		print(decoder_input.size())
		return decoder_input

	def initialize_decoder_states(self, memory, mask):
		''' Initializes attention rnn states, decoder rnn states, attention
		weights, attention cumulative weights, attention context, stores memory
		and stores processed memory
		PARAMS
		------
		memory: Encoder outputs
		mask: Mask for padded data if training, expects None for inference
		'''
		B = memory.size(0)
		MAX_TIME = memory.size(1)

		self.attention_hidden = Variable(memory.data.new(
			B, self.attention_rnn_dim).zero_())
		self.attention_cell = Variable(memory.data.new(
			B, self.attention_rnn_dim).zero_())

		self.decoder_hidden = Variable(memory.data.new(
			B, self.decoder_rnn_dim).zero_())
		self.decoder_cell = Variable(memory.data.new(
			B, self.decoder_rnn_dim).zero_())

		self.attention_weights = Variable(memory.data.new(
			B, MAX_TIME).zero_())
		self.attention_weights_cum = Variable(memory.data.new(
			B, MAX_TIME).zero_())
		self.attention_context = Variable(memory.data.new(
			B, self.encoder_embedding_dim).zero_())

		self.memory = memory
		self.processed_memory = self.attention_layer.memory_layer(memory)
		self.mask = mask

	def parse_decoder_inputs(self, decoder_inputs):
		''' Prepares decoder inputs, i.e. mel outputs
		PARAMS
		------
		decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

		RETURNS
		-------
		inputs: processed decoder inputs

		'''
		# (B, num_mels, T_out) -> (B, T_out, num_mels)
		decoder_inputs = decoder_inputs.transpose(1, 2).contiguous()
		decoder_inputs = decoder_inputs.view(
			decoder_inputs.size(0),
			int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
		# (B, T_out, num_mels) -> (T_out, B, num_mels)
		decoder_inputs = decoder_inputs.transpose(0, 1)
		return decoder_inputs

	def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
		''' Prepares decoder outputs for output
		PARAMS
		------
		mel_outputs:
		gate_outputs: gate output energies
		alignments:

		RETURNS
		-------
		mel_outputs:
		gate_outpust: gate output energies
		alignments:
		'''
		# (T_out, B) -> (B, T_out)
		alignments = torch.stack(alignments).transpose(0, 1)
		# (T_out, B) -> (B, T_out)

		gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
		gate_outputs = gate_outputs.contiguous()
		# (T_out, B, num_mels) -> (B, T_out, num_mels)
		mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
		# decouple frames per step
		mel_outputs = mel_outputs.view(
			mel_outputs.size(0), -1, self.num_mels)
		# (B, T_out, num_mels) -> (B, num_mels, T_out)
		mel_outputs = mel_outputs.transpose(1, 2)

		return mel_outputs, gate_outputs, alignments

	def decode(self, decoder_input):
		''' Decoder step using stored states, attention and memory
		PARAMS
		------
		decoder_input: previous mel output

		RETURNS
		-------
		mel_output:
		gate_output: gate output energies
		attention_weights:
		'''
		cell_input = torch.cat((decoder_input, self.attention_context), -1)

		self.attention_hidden, self.attention_cell = self.attention_rnn(cell_input, (self.attention_hidden, self.attention_cell))

		self.attention_hidden = F.dropout(self.attention_hidden, self.p_attention_dropout, self.training)

		attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1)

		self.attention_context, self.attention_weights = self.attention_layer(self.attention_hidden, self.memory, self.processed_memory, attention_weights_cat, self.mask)

		self.attention_weights_cum += self.attention_weights

		decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)

		self.decoder_hidden, self.decoder_cell = self.decoder_rnn(decoder_input, (self.decoder_hidden, self.decoder_cell))
		self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training)

		decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)

		decoder_output = self.linear_projection(decoder_hidden_attention_context)

		gate_prediction = self.gate_layer(decoder_hidden_attention_context)

		return decoder_output, gate_prediction, self.attention_weights

	def forward(self, memory, decoder_inputs, memory_lengths):
		''' Decoder forward pass for training
		PARAMS
		------
		memory: Encoder outputs
		decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
		memory_lengths: Encoder output lengths for attention masking.

		RETURNS
		-------
		mel_outputs: mel outputs from the decoder
		gate_outputs: gate outputs from the decoder
		alignments: sequence of attention weights from the decoder
		'''
		print('Encoder outputs', memory.size())
		decoder_input = self.get_go_frame(memory).unsqueeze(0)
		decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
		decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
		decoder_inputs = self.prenet(decoder_inputs)
		# print('decoder_input', decoder_input.size())

		self.initialize_decoder_states(
			memory, mask=~get_mask_from_lengths(memory_lengths))
		mel_outputs, gate_outputs, alignments = [], [], []
		while len(mel_outputs) < decoder_inputs.size(0) - 1:
			decoder_input = decoder_inputs[len(mel_outputs)]
			mel_output, gate_output, attention_weights = self.decode(
				decoder_input)
			print('mel_output', mel_output.size())
			print('gate_output', gate_output.size())
			print('attention_weights', attention_weights.size())
			mel_outputs += [mel_output.squeeze(1)]
			gate_outputs += [gate_output.squeeze()]
			alignments += [attention_weights]
		mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
			mel_outputs, gate_outputs, alignments)

		return mel_outputs, gate_outputs, alignments

	def inference(self, memory):
		''' Decoder inference
		PARAMS
		------
		memory: Encoder outputs

		RETURNS
		-------
		mel_outputs: mel outputs from the decoder
		gate_outputs: gate outputs from the decoder
		alignments: sequence of attention weights from the decoder
		'''
		decoder_input = self.get_go_frame(memory)

		self.initialize_decoder_states(memory, mask=None)

		mel_outputs, gate_outputs, alignments = [], [], []
		while True:
			decoder_input = self.prenet(decoder_input)
			mel_output, gate_output, alignment = self.decode(decoder_input)
			mel_outputs += [mel_output.squeeze(1)]
			gate_outputs += [gate_output]
			alignments += [alignment]

			if sum(torch.sigmoid(gate_output.data))/len(gate_output.data) > self.gate_threshold:
				print('Terminated by gate.')
				break
			# elif len(mel_outputs) > 1 and is_end_of_frames(mel_output):
			# 	print('Warning: End with low power.')
			# 	break
			elif len(mel_outputs) == self.max_decoder_steps:
				print('Warning: Reached max decoder steps.')
				break

			decoder_input = mel_output

		mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
			mel_outputs, gate_outputs, alignments)
		return mel_outputs, gate_outputs, alignments


In [30]:
dec = Decoder()

In [31]:
enc = Encoder3D()

In [34]:
enc_out = enc(torch.randn(1, 3, 75, 160, 160))
enc_out.shape

torch.Size([1, 80, 258])

In [32]:
dec.inference(
    torch.randn(1, 75, 384)
)[0].shape

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
torch.Size([1, 160])
Terminated by gate.


torch.Size([1, 80, 2])