In [10]:
from librosa.filters import mel as librosa_mel_fn
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

import torch
from librosa.filters import mel as librosa_mel_fn
from audio_processing import dynamic_range_compression
from audio_processing import dynamic_range_decompression
from stft import STFT

ModuleNotFoundError: No module named 'audio_processing'

In [3]:
def _get_mask_from_lengths(lengths):
    r"""Returns a binary mask based on ``lengths``. The ``i``-th row and ``j``-th column of the mask
    is ``1`` if ``j`` is smaller than ``i``-th element of ``lengths.

    Args:
        lengths (Tensor): The length of each element in the batch, with shape (n_batch, ).

    Returns:
        mask (Tensor): The binary mask, with shape (n_batch, max of ``lengths``).
    """
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype)
    mask = (ids < lengths.unsqueeze(1)).byte()
    mask = torch.le(mask, 0)
    return mask

In [9]:
tensor = torch.LongTensor([3, 4, 4, 4])
_get_mask_from_lengths(tensor)

tensor([[False, False, False,  True],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]])

tensor([[False,  True,  True],
        [False, False, False]])

In [7]:
mel_basis = librosa_mel_fn(
            sr=256, n_fft=256, n_mels=32, fmin=0, fmax=256)
mel_basis

  mel_basis = librosa_mel_fn(


array([[-0.        ,  0.01661682,  0.03323364, ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ]], dtype=float32)

In [None]:
def SpectrogramCompressing(spectrogram, windowfunc):
    return

In [None]:



class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal


class TacotronSTFT(torch.nn.Module):
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, magnitudes):
        output = dynamic_range_compression(magnitudes)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert(torch.min(y.data) >= -1)
        assert(torch.max(y.data) <= 1)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output

class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """

        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cummulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights


class Prenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [LinearNorm(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
            x = F.dropout(F.relu(linear(x)), p=0.2, training=True) #was 0.5
        return x


class Postnet(nn.Module):
    """Postnet
        - Five 1-d convolution with 512 channels and kernel size 5
    """

    def __init__(self, hparams):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.postnet_embedding_dim))
        )

        for i in range(1, hparams.postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(hparams.postnet_embedding_dim,
                             hparams.postnet_embedding_dim,
                             kernel_size=hparams.postnet_kernel_size, stride=1,
                             padding=int((hparams.postnet_kernel_size - 1) / 2),
                             dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hparams.postnet_embedding_dim))
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(hparams.n_mel_channels))
            )

    def forward(self, x):
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.2, self.training) # was 0.5
        x = F.dropout(self.convolutions[-1](x), 0.2, self.training) # was 0.5

        return x


class Encoder(nn.Module):
    """Encoder module:
        - Three 1-d convolution banks
        - Bidirectional LSTM
    """
    def __init__(self, hparams):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(hparams.encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(hparams.encoder_embedding_dim,
                         hparams.encoder_embedding_dim,
                         kernel_size=hparams.encoder_kernel_size, stride=1,
                         padding=int((hparams.encoder_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(hparams.encoder_embedding_dim))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
                            int(hparams.encoder_embedding_dim / 2), 1,
                            batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.2, self.training) # was 0.5

        x = x.transpose(1, 2)

        # pytorch tensor are not reversible, hence the conversion
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)

        return outputs

    def inference(self, x):
        for conv in self.convolutions:
            x = F.relu(conv(x))#F.dropout(F.relu(conv(x)), 0.2, self.training) # was 0.5

        x = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs


class Decoder(nn.Module):
    def __init__(self, hparams):
        super(Decoder, self).__init__()
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.encoder_embedding_dim = hparams.encoder_embedding_dim
        self.attention_rnn_dim = hparams.attention_rnn_dim
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        self.prenet_dim = hparams.prenet_dim
        self.max_decoder_steps = hparams.max_decoder_steps
        self.gate_threshold = hparams.gate_threshold
        self.p_attention_dropout = hparams.p_attention_dropout
        self.p_decoder_dropout = hparams.p_decoder_dropout

        self.prenet = Prenet(
            hparams.n_mel_channels * hparams.n_frames_per_step,
            [hparams.prenet_dim, hparams.prenet_dim])

        self.attention_rnn = nn.LSTMCell(
            hparams.prenet_dim + hparams.encoder_embedding_dim,
            hparams.attention_rnn_dim)

        self.attention_layer = Attention(
            hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)

        self.decoder_rnn = nn.LSTMCell(
            hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
            hparams.decoder_rnn_dim, 1)

        self.linear_projection = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
            hparams.n_mel_channels * hparams.n_frames_per_step)

        self.gate_layer = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
            bias=True, w_init_gain='sigmoid')

    def get_go_frame(self, memory):
        """ Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: decoder outputs

        RETURNS
        -------
        decoder_input: all zeros frames
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(
            B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

        RETURNS
        -------
        inputs: processed decoder inputs

        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs:
        gate_outputs: gate output energies
        alignments:

        RETURNS
        -------
        mel_outputs:
        gate_outpust: gate output energies
        alignments:
        """
        # (T_out, B) -> (B, T_out)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B) -> (B, T_out)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output

        RETURNS
        -------
        mel_output:
        gate_output: gate output energies
        attention_weights:
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)

        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)

        self.attention_weights_cum += self.attention_weights
        decoder_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(
            self.decoder_hidden, self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat(
            (self.decoder_hidden, self.attention_context), dim=1)
        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """

        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments

    def inference(self, memory):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments


class Tacotron2(nn.Module):
    def __init__(self, hparams):
        super(Tacotron2, self).__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.embedding = nn.Embedding(
            hparams.n_symbols, hparams.symbols_embedding_dim)
        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams)
        self.postnet = Postnet(hparams)

    def parse_batch(self, batch):
        text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths = batch
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()

        return (
            (text_padded, input_lengths, mel_padded, max_len, output_lengths),
            (mel_padded, gate_padded))

    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return outputs

    def forward(self, inputs):
        text_inputs, text_lengths, mels, max_len, output_lengths = inputs
        text_lengths, output_lengths = text_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)

        encoder_outputs = self.encoder(embedded_inputs, text_lengths)

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, mels, memory_lengths=text_lengths)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
            output_lengths)

    def inference(self, inputs):
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)
        mel_outputs, gate_outputs, alignments = self.decoder.inference(
            encoder_outputs)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

        return outputs

In [None]:
#!L
N_SYMBOLS = 148
SYMBOLS_EMBEDDING_DIM = 512
N_MEL_CHANNELS = 80
N_FRAMES_PER_STEP = 1
ENCODER_EMBEDDING_DIM = 512
ATTENTION_DIM = 128
PRENET_DIM = 256
ATTENTION_RNN_DIM = 1024
ATTENTION_LOCATION_N_FILTERS = 32
ATTENTION_LOCATION_KERNEL_SIZE = 31
DECODER_RNN_DIM = 1024
P_ATTENTION_DROPOUT = 0.1
P_DECODER_DROPOUT = 0.1
GATE_THRESHOLD = 0.5
MAX_DECODER_STEPS = 1000

In [14]:
class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal

class Encoder(nn.Module):
    """Encoder module:
        - Three 1-d convolution banks
        - Bidirectional LSTM
    """
    def __init__(self, hidden_dim=512):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(3): #3 is param
            conv_layer = nn.Sequential(
                ConvNorm(hidden_dim,
                         hidden_dim,
                         kernel_size=5, stride=1, #kernel size is param
                         padding=int((5 - 1) / 2), 
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(hidden_dim))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hidden_dim,
                            int(hidden_dim / 2), 1,
                            batch_first=True, bidirectional=True)
        
    def forward(self,  x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.2, self.training) # was 0.5

        x = x.transpose(1, 2)
        #print("x", x.size())

        # pytorch tensor are not reversible, hence the conversion
        
        outputs, _ = self.lstm(x)
        return outputs
        
class Prenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [LinearNorm(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
            x = F.dropout(F.relu(linear(x)), 0.2, self.training) #was 0.5
        return x
    
class Postnet(nn.Module):
    """Postnet
        - Five 1-d convolution with 512 channels and kernel size 5
    """

    def __init__(self, hidden_dim=512):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(80, hidden_dim,
                         kernel_size=5, stride=1,
                         padding=int((5 - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hidden_dim))
        )

        for i in range(1, 5 - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(hidden_dim,
                             hidden_dim,
                             kernel_size=5, stride=1,
                             padding=int((5 - 1) / 2),
                             dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hidden_dim))
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(512, 80,
                         kernel_size=5, stride=1,
                         padding=int((5 - 1) / 2),
                         dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(80))
            )

    def forward(self, x):
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.2, self.training) # was 0.5
        x = F.dropout(self.convolutions[-1](x), 0.2, self.training) # was 0.5

        return x

class Tacotron2(nn.Module):
    def __init__(self, feature_dim=27, win_len=4000):
        super(Tacotron2, self).__init__()
        self.win_len = win_len
        
        self.linear1 = nn.Conv1d(feature_dim, 512, 1) #512 is param
        self.encoder = Encoder()
        self.prenet = Prenet(80, [256, 256]) #params
        self.attention_rnn = nn.LSTMCell(256 + 512, 1024) #prenet dim + encoder dim, hidden dim
        
        self.attention_v = nn.Linear(128, 1, bias=False) #param
        self.attention_query = nn.Linear(1024, 128) #param
        self.attention_memory = nn.Linear(512, 128) #param
        self.attention_location_fc = nn.Linear(32, 128) # param
        self.attention_location_conv = nn.Conv1d(2, 32, kernel_size=31, padding=int((31 - 1) / 2), bias=False) #params
        
        self.decoder_rnn = nn.LSTMCell(1024 + 512, 1024) #param
        self.decoder_output_projection = nn.Linear(1024 + 512, 80)#param
        
        self.postnet = Postnet()#param
        
        
        self.classifier1 = nn.Linear(80, 2)#param
        self.classifier2 = nn.Linear(80, 2)#param
        

   

    def forward(self, inputs):
        
        encoder_inputs = F.relu(self.linear1(inputs))
        print("encoder_inputs", encoder_inputs.size())
        encoder_outputs = self.encoder(encoder_inputs)

        B, N, *_ = encoder_outputs.size()
        print("encoder_outputs", encoder_outputs.size())

        # prepare decoder for inference
        # zero decoder states
        attention_hidden = Variable(encoder_outputs.data.new(B, 1024).zero_()) #param
        attention_cell = Variable(encoder_outputs.data.new(B, 1024).zero_())#param

        decoder_hidden = Variable(encoder_outputs.data.new(B, 1024).zero_()) #param
        decoder_cell = Variable(encoder_outputs.data.new(B, 1024).zero_()) #param

        attention_weights = Variable(encoder_outputs.data.new(B, N).zero_())
        attention_weights_sum = Variable(encoder_outputs.data.new(B, N).zero_())
        attention_context = Variable(encoder_outputs.data.new(B, 512).zero_()) #param

        # initialize memory
        decoder_memory = encoder_outputs
        decoder_processed_memory = self.attention_memory(decoder_memory)

        # prepare all-zero input frame
        decoder_input = Variable(encoder_outputs.data.new(
            B, 80  #param
        ).zero_())

        # start inference
        mel_outputs, alignments = [], []
        print("decoder_input", decoder_input.shape)
        for _ in range(self.win_len):
            # prenet   

            prenet_output = self.prenet(decoder_input)
            # attention

            attention_rnn_input = torch.cat((prenet_output, attention_context), dim=-1)
            print("prenet_output", prenet_output.shape)
            print("attention_context", attention_context.shape)
            print("attention_rnn_input", attention_rnn_input.shape)

            attention_hidden, attention_cell = self.attention_rnn(attention_rnn_input, (attention_hidden, attention_cell))
            attention_hidden = F.dropout(attention_hidden, p = 0.2, training = False)

            calculate_attention_query = self.attention_query(attention_hidden.unsqueeze(1))

            attention_weights_cat = torch.cat((attention_weights.unsqueeze(1), attention_weights_sum.unsqueeze(1)), dim = 1)
            attention_loc = self.attention_location_fc(self.attention_location_conv(attention_weights_cat).transpose(1, 2))

            print("calculate_attention_query", calculate_attention_query.shape)
            print("decoder_processed_memory", decoder_processed_memory.shape)
            print("attention_loc", attention_loc.shape)

            #raise Exception("exit")

            energy = self.attention_v(torch.tanh(calculate_attention_query + attention_loc + decoder_processed_memory))
            energy = energy.squeeze(-1)

            attention_weights = F.softmax(energy, dim=1)
            print("attention_weights", attention_weights.shape)
            
            attention_context = torch.bmm(attention_weights.unsqueeze(1), decoder_memory)
            print("attention_context", attention_context.shape)
            
            attention_context = attention_context.squeeze(1)
            print("attention_context2", attention_context.shape)
            

            attention_weights_sum += attention_weights

            # decoder rnn

            decoder_rnn_input = torch.cat((attention_hidden, attention_context), dim=-1)
            decoder_hidden, decoder_cell = self.decoder_rnn(decoder_rnn_input, (decoder_hidden, decoder_cell))
            decoder_hidden = F.dropout(decoder_hidden, p = 0.2, training = False)
            #..........

            #..........
            decoder_data = torch.cat((decoder_hidden, attention_context), dim=1)
            decoder_output = self.decoder_output_projection(decoder_data)

            #stop_gate_prediction = decoder_stop_gate_projection(decoder_data)

            mel_outputs += [decoder_output.squeeze(1)]
            alignments += [attention_weights]

            decoder_input = decoder_output

        alignments = torch.stack(alignments).transpose(0, 1)
        #gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        #gate_outputs = gate_outputs.contiguous()
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        mel_outputs = mel_outputs.view(mel_outputs.size(0), -1, 80)
        classifier_mel = self.classifier1(mel_outputs)
        
        mel_outputs = mel_outputs.transpose(1, 2)

        mel_outputs_postnet = self.postnet(mel_outputs)
        classifier_mel_postnet = self.classifier1(mel_outputs_postnet.transpose(1, 2))
        

        #synthesized_mels = mel_outputs + mel_outputs_postnet
        return classifier_mel_postnet, classifier_mel, alignments

    

Можно вынести только prenet из forward(((

In [15]:
test_input = torch.FloatTensor(np.ones((2, 27, 4000)))
taco = Tacotron2()
post, preds, align = taco(test_input)

encoder_inputs torch.Size([2, 512, 4000])
encoder_outputs torch.Size([2, 4000, 512])
decoder_input torch.Size([2, 80])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calcul

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Siz

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
cal

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_conte

attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1, 128])
decoder_processed_memory torch.Size([2, 4000, 128])
attention_loc torch.Size([2, 4000, 128])
attention_weights torch.Size([2, 4000])
attention_context torch.Size([2, 1, 512])
attention_context2 torch.Size([2, 512])
prenet_output torch.Size([2, 256])
attention_context torch.Size([2, 512])
attention_rnn_input torch.Size([2, 768])
calculate_attention_query torch.Size([2, 1,

In [72]:
preds.shape

torch.Size([2, 4000, 2])

In [73]:
align.shape

torch.Size([2, 4000, 4000])