In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [2]:
class Conv2dTF(nn.Conv2d):
    """Conv2d with padding behavior from Tensorflow

    adapted from
    https://github.com/mlperf/inference/blob/16a5661eea8f0545e04c86029362e22113c2ec09/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py#L40
    as referenced in this issue:
    https://github.com/pytorch/pytorch/issues/3867#issuecomment-507025011

    used to maintain behavior of original implementation of TweetyNet that used Tensorflow 1.0 low-level API

    ``padding`` selects a TensorFlow padding mode, ``"SAME"`` (default) or
    ``"VALID"``, compared case-insensitively. It may be passed as a keyword or as
    the fifth positional argument, matching ``nn.Conv2d``'s signature.

    - ``"VALID"``: no padding at all.
    - ``"SAME"``: pad so each spatial output size is ``ceil(input_size / stride)``.
      When the total padding along a dimension is odd, the extra row/column is
      added at the bottom/right.

    The normalized (upper-case) mode is stored in ``self.padding``. The parent
    ``nn.Conv2d`` is always built with ``padding=0`` because padding is computed
    per input in ``forward``.

    Raises:
        ValueError: if ``padding`` is not one of the two TensorFlow modes.
    """

    _PADDING_MODES = ("SAME", "VALID")
    # Position of ``padding`` in nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, ...)
    _PADDING_ARG_INDEX = 4

    def __init__(self, *args, **kwargs):
        args = list(args)
        if len(args) > self._PADDING_ARG_INDEX:
            padding = args[self._PADDING_ARG_INDEX]
            args[self._PADDING_ARG_INDEX] = 0
        else:
            padding = kwargs.pop("padding", "SAME")

        mode = padding.upper() if isinstance(padding, str) else padding
        if mode not in self._PADDING_MODES:
            raise ValueError(
                f"Conv2dTF padding must be 'SAME' or 'VALID' (case-insensitive), got {padding!r}"
            )

        # The TF mode is never handed to nn.Conv2d: it rejects upper-case "SAME"/"VALID"
        # and also rejects padding='same' combined with stride > 1.
        super(Conv2dTF, self).__init__(*args, **kwargs)
        self.padding = mode

    def _compute_padding(self, input, dim):
        """Compute TensorFlow "SAME" padding for one spatial dimension.

        Args:
            input: batched input indexed as (N, C, H, W); ``dim`` 0 reads
                ``input.size(2)`` and ``dim`` 1 reads ``input.size(3)``.
            dim: spatial dimension index, 0 (rows) or 1 (columns).

        Returns:
            ``(additional_padding, total_padding)``, where ``additional_padding`` is 1
            when ``total_padding`` is odd and 0 otherwise.
        """
        input_size = input.size(dim + 2)
        filter_size = self.weight.size(dim + 2)
        effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
        # ceil(input_size / stride): the output size TF's SAME mode targets
        out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
        total_padding = max(
            0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
        )
        additional_padding = int(total_padding % 2 != 0)

        return additional_padding, total_padding

    def forward(self, input):
        """Convolve ``input`` using the TensorFlow padding mode in ``self.padding``."""
        if self.padding == "VALID":
            return F.conv2d(
                input,
                self.weight,
                self.bias,
                self.stride,
                padding=0,
                dilation=self.dilation,
                groups=self.groups,
            )
        rows_odd, padding_rows = self._compute_padding(input, dim=0)
        cols_odd, padding_cols = self._compute_padding(input, dim=1)
        if rows_odd or cols_odd:
            # F.pad's list is (left, right, top, bottom) for the last two dims: the odd
            # extra row/column goes on the bottom/right, as in TensorFlow.
            input = F.pad(input, [0, cols_odd, 0, rows_odd])

        return F.conv2d(
            input,
            self.weight,
            self.bias,
            self.stride,
            padding=(padding_rows // 2, padding_cols // 2),
            dilation=self.dilation,
            groups=self.groups,
        )

In [3]:
# Per-example input shape fed to the first conv layer (the batch axis is added later).
# Presumably (channels, frequency bins, time bins) -- TODO confirm.
input_shape = (1, 513, 88)

# Conv stages: number of filters and kernel size for each.
conv1_filters, conv1_kernel_size = 32, (5, 5)
conv2_filters, conv2_kernel_size = 64, (5, 5)

# Pooling windows of (8, 1) shrink the height axis by 8 and leave the width axis unchanged.
pool1_size, pool1_stride = (8, 1), (8, 1)
pool2_size, pool2_stride = (8, 1), (8, 1)

# Two conv -> ReLU -> max-pool stages. 'same' padding with the default stride of 1
# keeps each conv output at its input size, so only the pooling layers change the shape.
cnn = nn.Sequential(
    Conv2dTF(input_shape[0], conv1_filters, conv1_kernel_size, padding='same'),
    nn.ReLU(True),
    nn.MaxPool2d(pool1_size, stride=pool1_stride),
    Conv2dTF(conv1_filters, conv2_filters, conv2_kernel_size, padding='same'),
    nn.ReLU(True),
    nn.MaxPool2d(pool2_size, stride=pool2_stride),
)

In [4]:
# Push one random example (batch size 1) through the CNN to discover its output shape.
batch_shape = (1, *input_shape)
tmp = torch.rand(batch_shape)
tmp_out = cnn(tmp)

In [5]:
# CNN output shape, laid out (batch, channels, height, width) as nn.Conv2d produces.
tmp_out.size()

torch.Size([1, 64, 8, 88])

In [6]:
# Channels and pooled height are merged into a single feature axis for the LSTM.
n_features = tmp_out.size(1) * tmp_out.size(2)

In [9]:
# Merge channels and pooled height into one feature axis, then move the remaining (width)
# axis next to the batch axis, giving (batch, width, n_features).
# The batch size is read from tmp_out instead of being hard-coded to 1, and reshape
# (unlike view) also accepts non-contiguous tensors.
x = tmp_out.reshape(tmp_out.size(0), n_features, -1).permute(0, 2, 1)

In [10]:
# LSTM input shape: (batch, width, n_features).
x.size()

torch.Size([1, 88, 512])

In [8]:
# Bidirectional LSTM over the width axis of x.
# x is laid out (batch, seq, features), e.g. (1, 88, 512), so batch_first=True is required.
# With the default batch_first=False, the LSTM would read axis 0 (size 1) as the sequence
# and treat each of the 88 steps as an independent length-1 sequence. The output shape
# looks the same either way, which hides the bug.
rnn = nn.LSTM(
    input_size=n_features,
    hidden_size=n_features,
    num_layers=1,
    dropout=0,
    bidirectional=True,
    batch_first=True,
)

In [12]:
# nn.LSTM returns the per-step outputs plus the final (hidden, cell) state pair.
rnn_out, (hidden, cell_state) = rnn(input=x)

In [13]:
# Last axis holds both LSTM directions concatenated.
rnn_out.size()

torch.Size([1, 88, 1024])

In [16]:
num_classes: int = 10  # width of the final linear layer's output

In [17]:
# Per-step classifier over the LSTM output. The input width is derived from the LSTM's own
# configuration: each direction emits proj_size features when a projection is set
# (hidden_size otherwise), and a bidirectional LSTM concatenates both directions.
# Hard-coding `2 * n_features` would silently break if hidden_size stopped equalling n_features.
fc = nn.Linear(
    in_features=(getattr(rnn, "proj_size", 0) or rnn.hidden_size) * (2 if rnn.bidirectional else 1),
    out_features=num_classes,
)

In [19]:
# nn.Linear acts on the last axis, producing num_classes scores per step.
x = fc(input=rnn_out)

In [20]:
# Classifier output shape: (batch, steps, num_classes).
x.size()

torch.Size([1, 88, 10])