<a href="https://colab.research.google.com/github/Vaibhavs10/scratchpad/blob/main/test_postnet_conv1d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
Taken from ESPNet
"""

import torch


class PostNet(torch.nn.Module):
    """
    From Tacotron2

    Postnet module for the spectrogram prediction network.

    This is the Postnet of the spectrogram prediction network
    described in `Natural TTS Synthesis by
    Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The Postnet refines the Mel-filterbank predicted by the decoder,
    which helps to compensate for the detail structure of the spectrogram.

    Unlike the batch normalization its ``use_batch_norm`` flag is named after,
    this implementation normalizes with ``torch.nn.GroupNorm``.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884
    """

    # Preferred GroupNorm group counts. They are used as-is whenever they divide
    # the channel count; otherwise the largest smaller divisor is used instead.
    _HIDDEN_NORM_GROUPS = 32
    _OUTPUT_NORM_GROUPS = 20

    def __init__(self, idim, odim, n_layers=5, n_chans=512, n_filts=5, dropout_rate=0.5, use_batch_norm=True):
        """
        Initialize postnet module.

        Args:
            idim (int): Dimension of the inputs. Not used when building the
                layers: the first convolution takes ``odim`` input channels.
            odim (int): Dimension of the outputs (and of the tensor fed to ``forward``).
            n_layers (int, optional): The number of layers.
            n_chans (int, optional): The number of filter channels of the hidden layers.
            n_filts (int, optional): The filter (kernel) size.
            dropout_rate (float, optional): Dropout rate.
            use_batch_norm (bool, optional): Whether to add a normalization layer
                (``torch.nn.GroupNorm``) after each convolution.
        """
        super().__init__()
        padding = (n_filts - 1) // 2

        def conv(ichans, ochans):
            # All convolutions share kernel size, stride and padding; only the widths differ.
            return torch.nn.Conv1d(ichans, ochans, n_filts, stride=1, padding=padding, bias=False)

        self.postnet = torch.nn.ModuleList()

        # Hidden layers: conv -> (norm) -> tanh -> dropout, always n_chans wide.
        for layer in range(n_layers - 1):
            ichans = odim if layer == 0 else n_chans
            modules = [conv(ichans, n_chans)]
            if use_batch_norm:
                groups = self._num_groups(n_chans, self._HIDDEN_NORM_GROUPS)
                modules.append(torch.nn.GroupNorm(num_groups=groups, num_channels=n_chans))
            modules.append(torch.nn.Tanh())
            modules.append(torch.nn.Dropout(dropout_rate))
            self.postnet.append(torch.nn.Sequential(*modules))

        # Output layer: conv -> (norm) -> dropout, projecting back to odim with no tanh.
        ichans = n_chans if n_layers != 1 else odim
        modules = [conv(ichans, odim)]
        if use_batch_norm:
            groups = self._num_groups(odim, self._OUTPUT_NORM_GROUPS)
            modules.append(torch.nn.GroupNorm(num_groups=groups, num_channels=odim))
        modules.append(torch.nn.Dropout(dropout_rate))
        self.postnet.append(torch.nn.Sequential(*modules))

    @staticmethod
    def _num_groups(num_channels, max_groups):
        """
        Return the largest group count not exceeding ``max_groups`` that divides ``num_channels``.

        GroupNorm requires the channel count to be divisible by the group count,
        so a fixed group count would reject many channel sizes.
        """
        for groups in range(min(max_groups, num_channels), 1, -1):
            if num_channels % groups == 0:
                return groups
        return 1

    def forward(self, xs):
        """
        Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the sequences of padded input tensors (B, odim, Tmax).

        Returns:
            Tensor: Batch of padded output tensor (B, odim, Tmax). The time length
                is preserved when ``n_filts`` is odd.
        """
        for layer in self.postnet:
            xs = layer(xs)
        return xs

In [2]:
# Random stand-in batch, created with shape (2, 100, 80).
_ = torch.randn(2, 100, 80)

In [4]:
# Build a postnet with the default depth, width and kernel size.
p = PostNet(odim=80, idim=62)

In [6]:
# Swap the last two axes so the 80-sized axis sits in the channel position.
p(torch.transpose(_, 1, 2))

tensor([[[ 0.0000e+00, -1.0777e+00, -5.4660e-01,  ...,  0.0000e+00,
          -2.7880e+00, -6.5015e-01],
         [ 2.1876e-01,  0.0000e+00,  2.5389e+00,  ...,  0.0000e+00,
          -9.2436e-01,  0.0000e+00],
         [ 0.0000e+00, -0.0000e+00,  0.0000e+00,  ..., -4.4936e+00,
          -0.0000e+00,  0.0000e+00],
         ...,
         [-0.0000e+00, -1.4846e+00, -7.0412e-01,  ..., -5.1961e-01,
           0.0000e+00, -0.0000e+00],
         [ 4.4329e-03, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
          -0.0000e+00, -0.0000e+00],
         [-0.0000e+00,  2.3490e-01, -0.0000e+00,  ...,  0.0000e+00,
          -1.9897e+00, -1.3159e+00]],

        [[ 6.9515e-01, -0.0000e+00, -0.0000e+00,  ...,  1.8862e-01,
          -1.8237e+00,  0.0000e+00],
         [ 1.7462e-01, -3.2134e-01,  1.8222e+00,  ...,  0.0000e+00,
          -1.4288e+00, -1.2201e-01],
         [-6.6500e-01,  0.0000e+00,  0.0000e+00,  ...,  1.4128e+00,
          -0.0000e+00, -3.7437e-01],
         ...,
         [ 8.4122e-01,  4