In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class _TemporalPooling(nn.Module):
    def __init__(self, in_features=3, out_features=1, kernel_size=2,
                 dropout=0.1):
        super(_TemporalPooling, self).__init__()
        self.kernel_size = kernel_size

        self.pool = nn.AvgPool1d(kernel_size=self.kernel_size,
                                 stride=self.kernel_size, padding=0)
        

        self.conv = nn.Conv1d(in_features, out_features, kernel_size=1,
                             
                              padding=0)
                              
        self.bn = nn.BatchNorm1d(out_features)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # TODO: verify that the inputs are in the right shape
        x = self.pool(x)
        print(x.shape)
        x = self.conv(x)
        
        x = self.bn(F.relu(x))
        
        x = self.drop(F.interpolate(x, scale_factor=self.kernel_size,
                                    mode='linear', align_corners=True))
        return x

# Demo: four pooling branches whose windows split a length-32 sequence into
# 6, 4, 3 and 2 pooled bins; two of them are concatenated along channels.
L = 32

x = torch.rand((64, 64, L))
layer1, layer2, layer3, layer4 = (
    _TemporalPooling(64, 16, L // divisor) for divisor in (6, 4, 3, 2)
)

torch.cat([layer1(x), layer3(x)], dim=1).shape


torch.Size([64, 64, 6])
torch.Size([64, 64, 3])


torch.Size([64, 32, 30])