In [2]:
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class TimeDistributed(nn.Module):
    """Apply ``module`` independently to every element of the two leading axes.

    The two leading axes of the input are merged into one axis, ``module`` is
    applied to that merged batch, and the two axes are restored on the
    output. Those axes are ``(timesteps, samples, ...)`` by default, or
    ``(samples, timesteps, ...)`` when ``batch_first=True``. All trailing
    dimensions pass through unchanged, both of the input and of the module's
    output. The wrapper therefore works for any input of rank >= 3 and any
    module output rank. For example, a CNN that maps ``(N, C, H, W)`` to
    ``(N, F)`` turns ``(T, B, C, H, W)`` into ``(T, B, F)``.

    Args:
        module: module applied to tensors of shape ``(A * B, *rest)``, where
            ``A`` and ``B`` are the sizes of the two leading input axes.
        batch_first: whether the first axis is samples (True) or timesteps
            (False). Merging and then restoring the two leading axes in the
            same order is layout-independent, so this flag does not change
            the computation. It is kept for API compatibility and to
            document the caller's layout.
    """

    def __init__(self, module, batch_first=False):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        # With at most two dims there is no separate time axis to fold.
        if x.dim() <= 2:
            return self.module(x)

        lead = x.shape[:2]
        # (A, B, *rest) -> (A * B, *rest); returns a view when x allows it.
        y = self.module(x.flatten(0, 1))

        # (A * B, *out) -> (A, B, *out): restore both leading axes while
        # keeping every dimension the module produced.
        return y.reshape(lead[0], lead[1], *y.shape[1:])

In [28]:
class Encoder(nn.Module):
    """Per-image CNN encoder: five conv/BN/ReLU/avg-pool stages plus a linear head.

    Each stage is ``Conv2d(kernel_size=5, stride=1, padding=2)``, which keeps
    the spatial size. It is followed by ``BatchNorm2d``, ``ReLU`` and
    ``AvgPool2d(2, 2)``, which halves H and W (floored). Channels go
    ``in_channels -> 16 -> 32 -> 64 -> 128 -> 256``. The final feature map is
    flattened and projected to ``embed_dim`` features.

    Args:
        in_channels: number of input channels (default 11).
        input_size: height and width of the square input images (default 64).
            Used to size the linear layer. It must be >= 32 so that at least
            one pixel survives the five poolings.
        embed_dim: size of the output feature vector (default 100).

    Input ``(N, in_channels, input_size, input_size)``; output ``(N, embed_dim)``.

    Raises:
        ValueError: if ``input_size`` is too small for five 2x poolings.
    """

    def __init__(self, in_channels=11, input_size=64, embed_dim=100):
        super().__init__()

        self.in_planes = [16, 32, 64, 128, 256]

        # Spatial side length after the five 2x2/stride-2 average poolings.
        side = input_size // 2 ** len(self.in_planes)
        if side < 1:
            raise ValueError(
                f"input_size={input_size} is too small: it must be at least "
                f"{2 ** len(self.in_planes)} for {len(self.in_planes)} poolings"
            )

        # Build the stages in the same order as before, so Sequential indices
        # (and therefore state_dict keys) are unchanged.
        layers = []
        channels = in_channels
        for planes in self.in_planes:
            layers += [
                nn.Conv2d(channels, planes, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(planes),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
            ]
            channels = planes
        self.convolutions = nn.Sequential(*layers)

        # Defaults give 256 * 2 * 2 = 1024 input features.
        self.fc = nn.Linear(channels * side * side, embed_dim)

    def forward(self, x):
        out = self.convolutions(x)
        # (N, C, H, W) -> (N, C * H * W) for the linear layer.
        out = torch.flatten(out, 1)
        return self.fc(out)

In [31]:
class ConvLSTM(nn.Module):
    """CNN-encoder + stacked-LSTM model that emits one value per sequence.

    Input layout is ``(seq_len, batch, channels, height, width)``; the LSTM
    is not ``batch_first``. The model works in four steps:

    1. ``Encoder`` encodes every frame to a 100-d vector, applied through
       ``TimeDistributed``.
    2. The resulting sequence runs through a 4-layer LSTM with hidden size 100.
    3. A linear head maps each step to one value.
    4. The head's value at the final time step is returned, with shape
       ``(batch, 1)``.

    Frames must match what ``Encoder()`` is built for (11 channels, and a
    spatial size that flattens to its linear layer's input).
    """

    def __init__(self):
        super().__init__()

        frame_encoder = Encoder()
        self.Encoder = TimeDistributed(frame_encoder)
        self.RNN = nn.LSTM(input_size=100, hidden_size=100, num_layers=4)
        self.Output = nn.Linear(100, 1)

    def forward(self, x):
        frame_features = self.Encoder(x)          # (seq_len, batch, 100)
        lstm_seq, _ = self.RNN(frame_features)    # drop (h_n, c_n)
        step_preds = self.Output(lstm_seq)        # (seq_len, batch, 1)
        return step_preds[-1]                     # last time step

In [33]:
# Seed PyTorch's global RNG so the weight initialisation, and random tensors
# drawn afterwards, are reproducible on Restart & Run All.
torch.manual_seed(0)
net = ConvLSTM()

In [34]:
# Dummy batch in the (seq_len, batch, channels, height, width) layout that
# ConvLSTM expects: 7 time steps, 4 sequences, 11x64x64 frames.
test_input = torch.randn(7, 4, 11, 64, 64)

In [35]:
# Smoke-test forward pass. eval() stops BatchNorm from folding this random
# batch into its running statistics, so the cell is safe to re-run.
# no_grad() avoids keeping an autograd graph alive in the global
# `test_output`. The model's previous train/eval mode is restored afterwards.
net_was_training = net.training
net.eval()
try:
    with torch.no_grad():
        test_output = net(test_input)
finally:
    net.train(net_was_training)

In [36]:
# Expect one prediction per sequence in the batch (axis 1 of the input).
assert test_output.shape == (test_input.size(1), 1), test_output.shape
test_output.shape

torch.Size([4, 1])