# Convolution_LSTM_pytorch

https://github.com/automan000/Convolutional_LSTM_PyTorch

```
clstm = ConvLSTM(input_channels=512, hidden_channels=[128, 64, 64], kernel_size=5, step=9, effective_step=[2, 4, 8])
lstm_outputs = clstm(cnn_features)
hidden_states = lstm_outputs[0]
```

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

In [53]:
class ConvLSTMCell(nn.Module):
    """Convolutional LSTM cell with peephole terms on the cell state.

    Each gate (input ``i``, forget ``f``, candidate ``c``, output ``o``) owns
    one convolution over the input (with bias) and one over the previous
    hidden state (without bias).  All convolutions use stride 1 and padding
    ``int((kernel_size - 1) / 2)``, so an odd ``kernel_size`` preserves the
    spatial size.

    The peephole tensors ``Wci``, ``Wcf`` and ``Wco`` are created lazily by
    :meth:`init_hidden` because their shape depends on the spatial size of the
    input.  They are plain zero tensors, not registered parameters: they do
    not appear in ``state_dict()`` and no optimizer updates them.

    Args:
        input_channels (int): channels of the input feature map.
        hidden_channels (int): channels of the hidden/cell state; must be even.
        kernel_size (int): side length of the square convolution kernel.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMCell, self).__init__()

        assert hidden_channels % 2 == 0

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_features = 4

        self.padding = int((kernel_size - 1) / 2)

        # Creation order is kept as-is so seeded initialisation is unchanged.
        self.Wxi = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whi = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxf = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whf = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxc = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whc = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxo = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Who = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)

        # Peephole tensors; allocated on the first init_hidden() call.
        self.Wci = None
        self.Wcf = None
        self.Wco = None

    def forward(self, x, h, c):
        """Advance the cell by one time step.

        ``init_hidden`` must have been called first so that the peephole
        tensors exist.

        Args:
            x: input of shape (batch, input_channels, H, W).
            h: previous hidden state, (batch, hidden_channels, H, W).
            c: previous cell state, same shape as ``h``.

        Returns:
            Tuple ``(new_h, new_c)``.
        """
        ci = torch.sigmoid(self.Wxi(x) + self.Whi(h) + c * self.Wci)
        cf = torch.sigmoid(self.Wxf(x) + self.Whf(h) + c * self.Wcf)
        cc = cf * c + ci * torch.tanh(self.Wxc(x) + self.Whc(h))
        # The output gate peeks at the *updated* cell state.
        co = torch.sigmoid(self.Wxo(x) + self.Who(h) + cc * self.Wco)
        ch = co * torch.tanh(cc)
        return ch, cc

    def init_hidden(self, batch_size, hidden, shape):
        """Create zero hidden/cell states and, on first use, the peepholes.

        Every tensor is allocated on the device (and with the dtype) of the
        cell's convolution weights, so the cell works on CPU or any GPU the
        module has been moved to instead of assuming the default CUDA device.

        Args:
            batch_size (int): leading dimension of the returned states.
            hidden (int): number of state channels.
            shape: ``(height, width)`` of the feature map.

        Returns:
            Tuple ``(h, c)`` of zero tensors shaped
            (batch_size, hidden, height, width).

        Raises:
            AssertionError: if the peepholes already exist with a different
                height or width.
        """
        reference = self.Wxi.weight
        placement = {'device': reference.device, 'dtype': reference.dtype}
        height, width = shape[0], shape[1]
        peephole_names = ('Wci', 'Wcf', 'Wco')

        if self.Wci is None:
            # NOTE(review): these stay at zero because nothing trains them;
            # register them as nn.Parameter in __init__ if learnable peephole
            # weights are actually wanted.
            for name in peephole_names:
                setattr(self, name, torch.zeros(1, hidden, height, width, **placement))
        else:
            assert height == self.Wci.size()[2], 'Input Height Mismatched!'
            assert width == self.Wci.size()[3], 'Input Width Mismatched!'
            # Follow the module if it was moved or cast after the first call.
            for name in peephole_names:
                setattr(self, name, getattr(self, name).to(**placement))

        return (torch.zeros(batch_size, hidden, height, width, **placement),
                torch.zeros(batch_size, hidden, height, width, **placement))


class ConvLSTM(nn.Module):
    """Stack of ConvLSTM cells unrolled for a fixed number of steps.

    The same input tensor is fed to the first layer at every step; each
    further layer consumes the hidden state of the layer below it.

    Args:
        input_channels (int): channels of the input feature map.
        hidden_channels (sequence of int): hidden channels of each stacked
            layer, in order; its length is the number of layers.
        kernel_size (int): kernel size shared by all cells.
        step (int): number of time steps to unroll.
        effective_step: container of zero-based step indices whose top-layer
            output is collected.  Defaults to ``[1]``.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, step=1, effective_step=None):
        super(ConvLSTM, self).__init__()
        # list(...) lets callers pass a tuple of hidden sizes as well.
        self.input_channels = [input_channels] + list(hidden_channels)
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_layers = len(hidden_channels)
        self.step = step
        # A fresh list per instance: a literal ``[1]`` default in the
        # signature would be one object shared by every instance.
        self.effective_step = [1] if effective_step is None else effective_step
        self._all_layers = []
        for i in range(self.num_layers):
            name = 'cell{}'.format(i)
            cell = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size)
            # Registered as attributes ``cell0``, ``cell1``, ... so parameter
            # names in state_dict() keep that prefix.
            setattr(self, name, cell)
            self._all_layers.append(cell)

    def forward(self, input):
        """Unroll the stack for ``self.step`` steps.

        Args:
            input: tensor of shape (batch, input_channels, H, W).

        Returns:
            ``(outputs, (x, new_c))`` where ``outputs`` is the list of
            top-layer hidden states at the steps listed in
            ``effective_step`` and ``(x, new_c)`` are the hidden and cell
            state of the last layer after the final step.
        """
        internal_state = []
        outputs = []
        for step in range(self.step):
            x = input
            for i in range(self.num_layers):
                cell = getattr(self, 'cell{}'.format(i))
                # States are created during the first step because each
                # layer's spatial size is only known once its input exists.
                if step == 0:
                    bsize, _, height, width = x.size()
                    internal_state.append(
                        cell.init_hidden(batch_size=bsize,
                                         hidden=self.hidden_channels[i],
                                         shape=(height, width)))

                h, c = internal_state[i]
                x, new_c = cell(x, h, c)
                internal_state[i] = (x, new_c)
            # only record effective steps
            if step in self.effective_step:
                outputs.append(x)

        return outputs, (x, new_c)


# if __name__ == '__main__':
#     # gradient check
#     convlstm = ConvLSTM(input_channels=512, hidden_channels=[128, 64, 64, 32, 32], kernel_size=3, step=5,
#                         effective_step=[4]).cuda()
#     loss_fn = torch.nn.MSELoss()

#     input = Variable(torch.randn(1, 512, 64, 32)).cuda()
#     target = Variable(torch.randn(1, 32, 64, 32)).double().cuda()

#     output = convlstm(input)
#     output = output[0][0].double()
#     res = torch.autograd.gradcheck(loss_fn, (output, target), eps=1e-6, raise_exception=True)
#     print(res)

In [54]:
clstm = ConvLSTM(input_channels=100, hidden_channels=[128, 64, 64], kernel_size=5, step=9, effective_step=[2, 4, 8])
clstm = clstm.cuda()

In [55]:
input = torch.randn(2, 100, 25, 3, requires_grad=True).cuda()
lstm_outputs = clstm(input)

In [56]:
outputs, (x, new_c) = lstm_outputs

In [57]:
len(outputs)  # == len(effective_step)

3

In [58]:
outputs[0].shape # == [N, hidden_channels[-1], input.size(2), input.size(3)]

torch.Size([2, 64, 25, 3])

In [59]:
x.shape # final output after all LSTM Layers, shape: [N, hidden_channels[-1], input.size(2), input.size(3)]

torch.Size([2, 64, 25, 3])

In [60]:
new_c.shape

torch.Size([2, 64, 25, 3])

# Video ConvLSTM_pytorch

https://github.com/ndrplz/ConvLSTM_pytorch

```
model = ConvLSTM(input_dim=channels,
                 hidden_dim=[64, 64, 128],
                 kernel_size=(3, 3),
                 num_layers=3,
                 batch_first=True,
                 bias=True,
                 return_all_layers=False)
```

In [20]:
import torch.nn as nn
import torch

In [21]:
class ConvLSTMCell(nn.Module):
    """ConvLSTM cell that computes all four gates with one convolution."""

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Build the cell.
        Parameters
        ----------
        input_dim: int
            Channel count of the input tensor.
        hidden_dim: int
            Channel count of the hidden and cell states.
        kernel_size: (int, int)
            Height and width of the convolution kernel.
        bias: bool
            Whether the convolution has a bias term.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size

        # Half the kernel extent per axis; keeps the spatial size for odd kernels.
        kernel_h, kernel_w = kernel_size[0], kernel_size[1]
        self.padding = (kernel_h // 2, kernel_w // 2)
        self.bias = bias

        # Input and hidden state are concatenated on the channel axis and the
        # output carries the four gate pre-activations stacked on that axis.
        self.conv = nn.Conv2d(input_dim + hidden_dim,
                              4 * hidden_dim,
                              kernel_size=kernel_size,
                              padding=self.padding,
                              bias=bias)

    def forward(self, input_tensor, cur_state):
        """
        Advance one time step.
        Parameters
        ----------
        input_tensor:
            Tensor of shape (batch, input_dim, height, width).
        cur_state:
            Pair ``(h, c)``, each (batch, hidden_dim, height, width).
        Returns
        -------
        h_next, c_next
        """
        prev_h, prev_c = cur_state

        stacked = torch.cat((input_tensor, prev_h), dim=1)
        gates = self.conv(stacked)

        # Channel order of the stacked gates: input, forget, output, candidate.
        pre_i, pre_f, pre_o, pre_g = gates.split(self.hidden_dim, dim=1)

        c_next = torch.sigmoid(pre_f) * prev_c + torch.sigmoid(pre_i) * torch.tanh(pre_g)
        h_next = torch.sigmoid(pre_o) * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero ``(h, c)`` states on the device of the conv weights."""
        height, width = image_size
        device = self.conv.weight.device
        state_shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(state_shape, device=device),
                torch.zeros(state_shape, device=device))


class ConvLSTM(nn.Module):

    """
    Multi-layer ConvLSTM over a sequence of feature maps.

    Parameters:
        input_dim: Number of channels in the input
        hidden_dim: Hidden channels; an int (shared by all layers) or a list
            with one entry per layer
        kernel_size: A tuple, or a list of tuples with one entry per layer
        num_layers: Number of stacked LSTM layers
        batch_first: Whether dimension 0 of the input is the batch
        bias: Whether the convolutions have a bias
        return_all_layers: Return results of every layer instead of only the last
        Note: padding is kernel_size // 2 per axis.
    Input:
        A tensor of size B, T, C, H, W (batch_first=True) or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or at most 1 if
        return_all_layers is False).
            0 - layer_output_list: per layer, the hidden states of all time
                steps stacked into a tensor of size B, T, hidden, H, W
                (batch-first regardless of `batch_first`)
            1 - last_state_list: per layer, the list [h, c] after the last step
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super().__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Broadcast scalar settings so that there is one entry per layer.
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not (len(kernel_size) == len(hidden_dim) == num_layers):
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        # Layer 0 reads the raw input; every later layer reads the hidden
        # state of the layer directly below it.
        layer_inputs = [input_dim] + hidden_dim[:-1]
        cells = [
            ConvLSTMCell(input_dim=in_channels,
                         hidden_dim=out_channels,
                         kernel_size=kernel,
                         bias=self.bias)
            for in_channels, out_channels, kernel
            in zip(layer_inputs, hidden_dim, kernel_size)
        ]
        self.cell_list = nn.ModuleList(cells)

    def forward(self, input_tensor, hidden_state=None):
        """
        Parameters
        ----------
        input_tensor:
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state:
            Must be None; passing initial states is not implemented.
        Returns
        -------
        layer_output_list, last_state_list
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        batch, seq_len, _, height, width = input_tensor.size()

        if hidden_state is not None:
            raise NotImplementedError()
        # The spatial size is only known here, so states are created per call.
        hidden_state = self._init_hidden(batch_size=batch,
                                         image_size=(height, width))

        layer_output_list, last_state_list = [], []
        cur_layer_input = input_tensor

        for cell, (h, c) in zip(self.cell_list, hidden_state):
            per_step = []
            for t in range(seq_len):
                h, c = cell(input_tensor=cur_layer_input[:, t],
                            cur_state=(h, c))
                per_step.append(h)

            # Re-assemble the time axis; this becomes the next layer's input.
            cur_layer_input = torch.stack(per_step, dim=1)

            layer_output_list.append(cur_layer_input)
            last_state_list.append([h, c])

        if self.return_all_layers:
            return layer_output_list, last_state_list
        return layer_output_list[-1:], last_state_list[-1:]

    def _init_hidden(self, batch_size, image_size):
        """Collect the zero initial (h, c) pair of every layer."""
        return [cell.init_hidden(batch_size, image_size) for cell in self.cell_list]

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        """Raise ValueError unless `kernel_size` is a tuple or a list of tuples."""
        if isinstance(kernel_size, tuple):
            return
        if isinstance(kernel_size, list) and all(isinstance(elem, tuple) for elem in kernel_size):
            return
        raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        """Return `param` if it is a list, else `num_layers` copies of it in a list."""
        return param if isinstance(param, list) else [param] * num_layers

In [45]:
channels = 100
model = ConvLSTM(input_dim=channels,
                 hidden_dim=[64, 64, 128],
                 kernel_size=(3, 3),
                 num_layers=3,
                 batch_first=True,
                 bias=True,
                 return_all_layers=False)
model = model.cuda()

In [46]:
input = torch.randn(2, 1, 100, 25, 4, requires_grad=True).cuda()  # B, T, C, H, W
layer_output_list, last_state_list = model(input)

In [47]:
layer_output_list[0].shape  # [B, T, hidden_dim[-1], H, W]

torch.Size([2, 1, 128, 25, 4])

In [48]:
len(last_state_list[0])  # == 2, return of ConvLSTMCell: h_next, c_next

2

In [49]:
last_state_list[0][0].shape  # last state of h_next

torch.Size([2, 128, 25, 4])

In [50]:
last_state_list[0][1].shape # last state of c_next

torch.Size([2, 128, 25, 4])

# Deep Audio Net

In [5]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F


def init_layer(layer):
    """Initialize a Linear or Convolutional layer.
    Ref: He, Kaiming, et al. "Delving deep into rectifiers: Surpassing
    human-level performance on imagenet classification." Proceedings of the
    IEEE international conference on computer vision. 2015.

    The weight is drawn uniformly from ``[-scale, scale]`` with
    ``scale = sqrt(2 / fan_in) * sqrt(3)``, where ``fan_in`` is the product
    of all weight dimensions except the first (output) one.  This covers
    Linear and 1-D/2-D convolution weights exactly as before and additionally
    any higher-rank weight such as ``Conv3d``.  The bias, if present, is
    zeroed.

    Input
        layer: module exposing ``weight`` (rank >= 2) and ``bias``
            (a tensor or None)

    Raises
        ValueError: if the weight has fewer than two dimensions, because no
            fan-in can be derived from it.
    """
    weight = layer.weight
    if weight.ndimension() < 2:
        raise ValueError(
            'init_layer expects a weight with at least 2 dimensions, got '
            '{}'.format(weight.ndimension()))

    # fan-in = inputs feeding one output unit = product of the trailing dims
    n = 1
    for extent in weight.size()[1:]:
        n *= extent

    std = math.sqrt(2. / n)
    scale = std * math.sqrt(3.)
    weight.data.uniform_(-scale, scale)

    if layer.bias is not None:
        layer.bias.data.fill_(0.)


def init_lstm(layer):
    """
    Re-initialises the first-layer weights and biases of a recurrent layer.

    For both ``weight_ih_l0`` and ``weight_hh_l0`` the values are drawn
    uniformly from ``[-scale, scale]`` with
    ``scale = sqrt(2 / (rows * cols)) * sqrt(3)`` and the matching bias, if
    the layer has one, is set to zero.  Only the ``*_l0`` tensors are
    touched; further layers and the reverse direction keep their existing
    values.

    Input
        layer: recurrent module exposing ``weight_ih_l0`` and
            ``weight_hh_l0`` (e.g. ``nn.LSTM``)
    """
    for kind in ('ih', 'hh'):
        weight = getattr(layer, 'weight_{}_l0'.format(kind))
        rows, cols = weight.size()

        std = math.sqrt(2. / (rows * cols))
        scale = std * math.sqrt(3.)
        weight.data.uniform_(-scale, scale)

        # A layer built with bias=False has no bias attribute at all, so a
        # plain attribute access would raise AttributeError.
        bias = getattr(layer, 'bias_{}_l0'.format(kind), None)
        if bias is not None:
            bias.data.fill_(0.)


def init_att_layer(layer):
    """
    Set every weight of an attention layer to 1 and every bias entry to 0,
    so that on the first pass through the attention mechanism each time
    step is weighted equally.

    Input
        layer: module exposing ``weight`` and ``bias`` (bias may be None)
    """
    layer.weight.data.fill_(1.)

    bias = layer.bias
    if bias is None:
        return
    bias.data.fill_(0.)


def init_bn(bn):
    """
    Reset a batch-norm layer to the identity affine transform:
    bias 0 and weight 1.

    Input
        bn: batch normalisation layer exposing ``weight`` and ``bias``
    """
    for attribute, value in (('bias', 0.), ('weight', 1.)):
        getattr(bn, attribute).data.fill_(value)


class ConvBlock1d(nn.Module):
    """
    1D convolution followed by optional batch normalisation and a ReLU.

    ``normalisation`` selects the variant: ``'bn'`` adds a ``BatchNorm1d``
    after the convolution, ``'wn'`` wraps the convolution in weight
    normalisation, and any other value uses the bare convolution.  The
    weights are initialised on construction.
    """
    def __init__(self, in_channels, out_channels, kernel, stride, pad,
                 normalisation, dil=1):
        super(ConvBlock1d, self).__init__()
        self.norm = normalisation
        self.conv1 = nn.Conv1d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel,
                               stride=stride,
                               padding=pad,
                               dilation=dil)
        if self.norm == 'bn':
            self.bn1 = nn.BatchNorm1d(out_channels)
        elif self.norm == 'wn':
            self.conv1 = nn.utils.weight_norm(self.conv1, name='weight')
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """
        Initialises the convolution and, when present, the batch-norm layer
        """
        # NOTE(review): with 'wn' the effective weight appears to be
        # recomputed from weight_g/weight_v on every forward pass, so this
        # initialisation may have no lasting effect -- confirm.
        init_layer(self.conv1)
        # ``bn1`` only exists for the 'bn' variant; touching it otherwise
        # raised AttributeError during construction.
        if self.norm == 'bn':
            init_bn(self.bn1)

    def forward(self, input):
        """
        Passes the input through the convolution, the optional batch norm
        and the ReLU

        Input
            input: torch.Tensor - The current input at this stage of the network
        """
        x = self.conv1(input)
        if self.norm == 'bn':
            x = self.bn1(x)

        return self.relu(x)


class ConvBlock2d(nn.Module):
    """
    2D convolution used either as a feature layer or as an attention layer.

    Without ``att`` the block is convolution -> optional ``BatchNorm2d``
    (``normalisation == 'bn'``) -> ReLU; ``normalisation == 'wn'`` wraps the
    convolution in weight normalisation instead.

    With a truthy ``att`` normalisation is disabled in the forward pass and
    the activation depends on the value: ``'softmax'`` applies a softmax over
    the last dimension, ``'global'`` applies no activation, and anything else
    applies a sigmoid.  The weights are initialised on construction.
    """
    def __init__(self, in_channels, out_channels, kernel, stride, pad,
                 normalisation, att=None):
        super(ConvBlock2d, self).__init__()
        self.norm = normalisation
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel,
                               stride=stride,
                               padding=pad)
        if self.norm == 'bn':
            self.bn1 = nn.BatchNorm2d(out_channels)
        elif self.norm == 'wn':
            self.conv1 = nn.utils.weight_norm(self.conv1, name='weight')
        self.att = att
        if not self.att:
            self.act = nn.ReLU()
        else:
            # Attention layers skip normalisation in forward().
            self.norm = None
            if self.att == 'softmax':
                self.act = nn.Softmax(dim=-1)
            elif self.att == 'global':
                self.act = None
            else:
                self.act = nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        """
        Initialises the convolution and, when present, the batch-norm layer
        """
        # NOTE(review): with 'wn' the effective weight appears to be
        # recomputed from weight_g/weight_v on every forward pass, so this
        # initialisation may have no lasting effect -- confirm.
        if self.att:
            init_att_layer(self.conv1)
        else:
            init_layer(self.conv1)
        # ``bn1`` only exists when the block was built with 'bn'; touching it
        # unconditionally raised AttributeError during construction.
        batch_norm = getattr(self, 'bn1', None)
        if batch_norm is not None:
            init_bn(batch_norm)

    def forward(self, input):
        """
        Passes the input through the convolution and the configured
        normalisation / activation

        Input
            input: torch.Tensor - The current input at this stage of the network
        """
        x = self.conv1(input)
        if self.att:
            # ``act`` is None for 'global' attention.  It must be tested, not
            # called: invoking it without arguments raises TypeError.
            if self.act is not None:
                x = self.act(x)
            return x

        if self.norm == 'bn':
            x = self.bn1(x)
        return self.act(x)


class FullyConnected(nn.Module):
    """
    Linear layer with a selectable activation and optional normalisation.

    ``activation`` of ``'sigmoid'``, ``'softmax'`` (over the last dimension)
    or ``'global'`` (no activation) disables normalisation.  Any other value
    selects a ReLU, optionally combined with ``normalisation`` ``'bn'``
    (``BatchNorm1d`` after the linear layer) or ``'wn'`` (weight-normalised
    linear layer).  A truthy ``att`` switches the weight initialisation to
    the attention scheme.
    """
    def __init__(self, in_channels, out_channels, activation, normalisation,
                 att=None):
        super().__init__()
        self.att = att
        self.norm = normalisation
        self.fc = nn.Linear(in_features=in_channels,
                            out_features=out_channels)

        uses_relu = False
        if activation == 'sigmoid':
            self.act = nn.Sigmoid()
        elif activation == 'softmax':
            self.act = nn.Softmax(dim=-1)
        elif activation == 'global':
            self.act = None
        else:
            self.act = nn.ReLU()
            uses_relu = True

        if not uses_relu:
            # Output-style activations never use normalisation.
            self.norm = None
        elif self.norm == 'bn':
            self.bnf = nn.BatchNorm1d(out_channels)
        elif self.norm == 'wn':
            self.wnf = nn.utils.weight_norm(self.fc, name='weight')

        self.init_weights()

    def init_weights(self):
        """
        Initialises the linear layer and, for 'bn', the batch-norm layer
        """
        initialise = init_att_layer if self.att else init_layer
        initialise(self.fc)
        if self.norm == 'bn':
            init_bn(self.bnf)

    def forward(self, input):
        """
        Passes the input through the fully-connected layer

        Input
            input: torch.Tensor - The current input at this stage of the network
        """
        if self.norm is None:
            # Same computation whether or not this is an attention layer.
            projected = self.fc(input)
            return self.act(projected) if self.act else projected

        if self.norm == 'bn':
            return self.act(self.bnf(self.fc(input)))
        return self.act(self.wnf(input))


# def lstm_with_attention(net_params):
#     if 'LSTM_1' in net_params:
#         arguments = net_params['LSTM_1']
#     else:
#         arguments = net_params['GRU_1']
#     if 'ATTENTION_1' in net_params and 'ATTENTION_Global' not in net_params:
#         if arguments[-1]:
#             return 'forward'
#         else:
#             return 'whole'
#     if 'ATTENTION_1' in net_params and 'ATTENTION_Global' in net_params:
#         if arguments[-1]:
#             return 'forward'
#         else:
#             return 'whole'
#     if 'ATTENTION_1' not in net_params and 'ATTENTION_Global' in net_params:
#         if arguments[-1]:
#             return 'forward_only'
#         else:
#             return 'forward_only'


def reshape_x(x):
    """
    Drops one singleton dimension (other than the batch dimension) from a
    3-D or 4-D tensor.

    Rules, checked in this order:
      * rank >= 4 with dimensions 1, 2 and 3 all of length 1 -> (batch, 1)
      * rank 4 -> the first of dimensions 1, 2, 3 that has length 1 is removed
      * rank 3 -> the first of dimensions 1, 2 that has length 1 is removed
    A tensor that matches none of the rules is returned unchanged.

    Input:
        x: torch.Tensor - The input

    Output:
        x: torch.Tensor - Reshaped
    """
    dims = x.dim()
    # The rank is checked first so that shape[2] / shape[3] are never read
    # on tensors that do not have those dimensions.
    if dims >= 4 and x.shape[1] == 1 and x.shape[2] == 1 and x.shape[3] == 1:
        x = torch.reshape(x, (x.shape[0], 1))
    elif dims == 4:
        first, second, third, fourth = x.shape
        if second == 1:
            x = torch.reshape(x, (first, third, fourth))
        elif third == 1:
            x = torch.reshape(x, (first, second, fourth))
        elif fourth == 1:
            x = torch.reshape(x, (first, second, third))
        # No singleton dimension: nothing to drop, keep the tensor as is
        # (the same treatment the 3-D case gives such input).
    elif dims == 3:
        first, second, third = x.shape
        if second == 1:
            x = torch.reshape(x, (first, third))
        elif third == 1:
            x = torch.reshape(x, (first, second))

    return x


class ConvLSTM_2D(nn.Module):
    """
    2D convolution -> 1D convolution -> bidirectional LSTM -> attention ->
    fully-connected head.

    The 1D convolution is built for ``22 * conv2D_hidden`` input channels,
    i.e. it expects the pooled feature map to have height 22.  The 2D
    convolution (kernel 68, padding 33) reduces the height by one and the
    pooling divides it by three, so an input height of 68 gives 67 -> 22.
    """

    def __init__(self, input_dim, output_dim, conv2D_hidden, conv1D_hidden, 
                 lstm_hidden, num_layers, activation, norm, dropout):
        super().__init__()
        self.conv2D = ConvBlock2d(in_channels=input_dim,
                                  out_channels=conv2D_hidden,
                                  kernel=(68, 5),
                                  stride=(1, 1),
                                  pad=(33, 2),
                                  normalisation='bn')
        # Channels and pooled height are flattened together before this layer.
        flattened_channels = 22 * conv2D_hidden
        self.conv1D = ConvBlock1d(in_channels=flattened_channels,
                                  out_channels=conv1D_hidden,
                                  kernel=11,
                                  stride=1,
                                  pad=5,
                                  normalisation='bn')
        self.pool2D = nn.MaxPool2d(kernel_size=3, stride=3, padding=0)
        self.pool1D = nn.MaxPool1d(kernel_size=5, stride=5, padding=0)
        self.drop = nn.Dropout(dropout)
        self.lstm = nn.LSTM(input_size=conv1D_hidden,
                            hidden_size=lstm_hidden,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=True)
        self.attention_layer = nn.Sequential(
            nn.Linear(lstm_hidden, lstm_hidden),
            nn.ReLU(inplace=True),
        )
        self.fc = FullyConnected(in_channels=lstm_hidden,
                                 out_channels=output_dim,
                                 activation=activation,
                                 normalisation=norm)

    def attention_net_with_w(self, lstm_out, lstm_hidden):
        """
        Attention-pool the LSTM outputs over time.

        Input
            lstm_out: (N, T, 2 * lstm_hidden) outputs of the bidirectional LSTM
            lstm_hidden: (N, num_layers * 2, lstm_hidden) final hidden states

        Output
            (N, lstm_hidden) context vector
        """
        # Add the two halves of the feature axis (forward + backward outputs).
        halves = torch.chunk(lstm_out, 2, -1)
        summed = halves[0] + halves[1]

        # One query per sample: final hidden states summed over layers/directions.
        query = self.attention_layer(torch.sum(lstm_hidden, dim=1).unsqueeze(1))

        scores = torch.bmm(query, torch.tanh(summed).transpose(1, 2))
        weights = F.softmax(scores, dim=-1)
        return torch.bmm(weights, summed).squeeze(1)

    def forward(self, x):
        """
        Input
            x: 4-D tensor; the original code names its axes (N, C, V, T)

        Output
            whatever ``self.fc`` produces for the (N, lstm_hidden) context
        """
        _batch, _channels, _height, _frames = x.shape  # input must be 4-D
        x = self.pool2D(self.conv2D(x))          # (N, conv2D_hidden, V', T')
        batch, channels, height, frames = x.shape
        x = x.view(batch, channels * height, frames)
        x = self.pool1D(self.conv1D(x))          # (N, conv1D_hidden, T' // 5)
        x = self.drop(x)
        x = x.permute(0, 2, 1).contiguous()      # time becomes dimension 1
        x, (last_hidden, _last_cell) = self.lstm(x)       # (N, T'', 2 * lstm_hidden)
        last_hidden = last_hidden.permute(1, 0, 2).contiguous()  # batch first
        pooled = self.attention_net_with_w(x, last_hidden)       # (N, lstm_hidden)
        return self.fc(pooled)
    

class ConvLSTM_1D(nn.Module):
    """
    2D convolution that collapses the height axis -> bidirectional LSTM ->
    attention -> fully-connected head.

    The convolution uses a kernel of height 68 with no vertical padding, so
    an input of height 68 leaves a height of exactly 1, which is then
    removed before the 1D pooling.
    """

    def __init__(self, input_dim, output_dim, conv_hidden,
                 lstm_hidden, num_layers, activation, norm, dropout):
        super(ConvLSTM_1D, self).__init__()
        self.conv2D = ConvBlock2d(in_channels=input_dim,
                                  out_channels=conv_hidden,
                                  kernel=(68, 11),
                                  stride=(1, 1),
                                  pad=(0, 5),
                                  normalisation='bn')
        self.pool1D = nn.MaxPool1d(kernel_size=5,
                                   stride=5,
                                   padding=0)
        self.drop = nn.Dropout(dropout)
        self.lstm = nn.LSTM(input_size=conv_hidden,
                            hidden_size=lstm_hidden,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=True)
        self.attention_layer = nn.Sequential(nn.Linear(lstm_hidden, lstm_hidden),
                                             nn.ReLU(inplace=True))
        self.fc = FullyConnected(in_channels=lstm_hidden,
                                 out_channels=output_dim,
                                 activation=activation,
                                 normalisation=norm)

    def attention_net_with_w(self, lstm_out, lstm_hidden):
        """
        Attention-pool the LSTM outputs over time.

        Input
            lstm_out: (N, T, 2 * lstm_hidden) outputs of the bidirectional LSTM
            lstm_hidden: (N, num_layers * 2, lstm_hidden) final hidden states

        Output
            (N, lstm_hidden) context vector
        """
        # Add the two halves of the feature axis (forward + backward outputs).
        lstm_tmp_out = torch.chunk(lstm_out, 2, -1)
        h = lstm_tmp_out[0] + lstm_tmp_out[1]
        # One query per sample: final hidden states summed over layers/directions.
        lstm_hidden = torch.sum(lstm_hidden, dim=1)
        lstm_hidden = lstm_hidden.unsqueeze(1)
        atten_w = self.attention_layer(lstm_hidden)
        m = nn.Tanh()(h)
        atten_context = torch.bmm(atten_w, m.transpose(1, 2))
        softmax_w = F.softmax(atten_context, dim=-1)
        context = torch.bmm(softmax_w, h)
        result = context.squeeze(1)
        return result

    def forward(self, x):
        """
        Input
            x: 4-D tensor; the original code names its axes (N, C, V, T).
               V must be 68 so that the convolution leaves a height of 1.

        Output
            whatever ``self.fc`` produces for the (N, lstm_hidden) context
        """
        N, C, V, T = x.shape
        x = self.conv2D(x)  # (N, conv_hidden, 1, T)
        # Remove only the collapsed height axis.  A bare squeeze() would also
        # drop the batch axis when N == 1 and break the permute below.
        x = self.pool1D(x.squeeze(2))  # (N, conv_hidden, T // 5)
        x = self.drop(x)
        x = x.permute(0, 2, 1).contiguous()
        x, (final_hidden_state, final_cell_state) = self.lstm(x)               # x: (N, T // 5, 2 * lstm_hidden)
        final_hidden_state = final_hidden_state.permute(1, 0, 2).contiguous()  # (N, num_layers * 2, lstm_hidden)
        x = self.attention_net_with_w(x, final_hidden_state)                   # (N, lstm_hidden)
        x = self.fc(x)

        return x


class CustomMel(nn.Module):
    """
    1D convolution -> max pooling -> dropout -> unidirectional LSTM ->
    fully-connected head applied to the last time step.
    """

    def __init__(self, input_dim, output_dim, conv_hidden, lstm_hidden, num_layers, activation, norm, dropout):
        super().__init__()
        # Kernel 3 / stride 1 / padding 1 keeps the sequence length.
        self.conv = ConvBlock1d(in_channels=input_dim,
                                out_channels=conv_hidden,
                                kernel=3,
                                stride=1,
                                pad=1,
                                normalisation='bn')
        # Non-overlapping pooling: the sequence length is divided by 3.
        self.pool = nn.MaxPool1d(kernel_size=3, stride=3, padding=0)
        self.drop = nn.Dropout(dropout)
        self.lstm = nn.LSTM(input_size=conv_hidden,
                            hidden_size=lstm_hidden,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=False)
        # ``activation`` and ``norm`` are forwarded unchanged to the head.
        self.fc = FullyConnected(in_channels=lstm_hidden,
                                 out_channels=output_dim,
                                 activation=activation,
                                 normalisation=norm)

    def forward(self, net_input):
        """
        Input
            net_input: 3-D tensor (batch, input_dim, width)

        Output
            result of the head applied to the LSTM output of the last step
        """
        batch, _freq, _width = net_input.shape   # input must be 3-D
        features = self.drop(self.pool(self.conv(net_input)))
        # The LSTM is batch-first, so time has to be dimension 1.
        sequence, _ = self.lstm(torch.transpose(features, 1, 2))
        last_step = sequence[:, -1, :].reshape(batch, -1)   # (batch, lstm_hidden)
        return self.fc(last_step)


class CustomRaw(nn.Module):
    """
    Sequence model with two stacked 1D convolution blocks followed by
    max-pooling, dropout, a unidirectional LSTM and a fully-connected head
    (no normalisation on the head). Only the LSTM output of the last time
    step is passed to the head.

    Input
        input_dim: int - Channels of the network input (conv1 in_channels)
        output_dim: int - Size of the fully-connected output
        conv_hidden: int - Channels of both conv blocks, also the LSTM input size
        lstm_hidden: int - Hidden size of the LSTM
        num_layers: int - Number of stacked LSTM layers
        activation: str - Forwarded to FullyConnected
        dropout: float - Dropout probability applied after pooling
    """
    def __init__(self, input_dim, output_dim, conv_hidden, lstm_hidden, num_layers, activation, dropout):
        # The two-argument form is required: super(CustomRaw) on its own never
        # runs nn.Module.__init__, so the first sub-module assignment below
        # would raise.
        super(CustomRaw, self).__init__()
        # Conv output length:
        #   floor((in + 2*pad - dil*(kernel-1) - 1) / stride) + 1
        # Strided convolution with a 1024-wide kernel and a hop of 512
        self.conv1 = ConvBlock1d(in_channels=input_dim,
                                 out_channels=conv_hidden,
                                 kernel=1024,
                                 stride=512,
                                 pad=0,
                                 dil=1,
                                 normalisation='bn')

        # kernel 3 / stride 1 / pad 1 leaves the length of the last axis unchanged
        self.conv2 = ConvBlock1d(in_channels=conv_hidden,
                                 out_channels=conv_hidden,
                                 kernel=3,
                                 stride=1,
                                 pad=1,
                                 normalisation='bn')

        self.pool1 = nn.MaxPool1d(kernel_size=3,
                                  stride=3,
                                  padding=0)

        self.drop = nn.Dropout(dropout)

        self.lstm = nn.LSTM(input_size=conv_hidden,
                            hidden_size=lstm_hidden,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=False)

        self.fc = FullyConnected(in_channels=lstm_hidden,
                                 out_channels=output_dim,
                                 activation=activation,     # ['sigmoid', 'softmax', 'global', else]
                                 normalisation=None)

    def forward(self, net_input):
        """
        Runs the input through conv1 -> conv2 -> pool -> dropout -> LSTM -> FC

        Input
            net_input: torch.Tensor - 3D input; the first axis is the batch
        """
        x = net_input
        # The unpacking also rejects anything that is not 3-dimensional
        batch, _, _ = x.shape
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pool1(x)
        x = self.drop(x)
        # batch_first LSTM wants (batch, steps, channels)
        x = torch.transpose(x, 1, 2)
        x, _ = self.lstm(x)
        # Keep only the last time step: (batch, lstm_hidden)
        x = self.fc(x[:, -1, :].reshape(batch, -1))

        return x

## test conv1D

In [82]:
input_dim = 80
conv_hidden = 128
kernel = 3
stride = 1
pad = 1
normalisation = 'bn'

input_1D = torch.randn(2, 80, 300)   # (batch_size, freq(height), width)

In [83]:
conv_1D = ConvBlock1d(in_channels=input_dim,      # 80
                      out_channels=conv_hidden,   # 128
                      kernel=kernel,
                      stride=stride,
                      pad=pad,
                      normalisation=normalisation)         # ['bn', 'wn', else]
conv_1D = conv_1D.cuda()

In [84]:
result_1D = conv_1D(input_1D.cuda())

In [87]:
result_1D.shape  # (2, 80, 300) -> (2, conv_hidden, 300)

torch.Size([2, 128, 300])

In [88]:
pool_1D = nn.MaxPool1d(kernel_size=3, stride=3, padding=0).cuda()

In [89]:
pool_result_1D = pool_1D(result_1D)

In [90]:
pool_result_1D.shape

torch.Size([2, 128, 100])

In [92]:
torch.transpose(pool_result_1D, 1, 2).shape

torch.Size([2, 100, 128])

## test conv2D

### conv2D + pool1D

In [229]:
# Conv2D set-up for the "conv2D + pool1D" experiment: the kernel is 68 rows
# tall with no vertical padding, so the 68-row axis collapses to a single row.
input_dim = 4
conv_hidden = 16
kernel = (68, 11)
stride= 1
pad = (0, 5)
normalisation = 'bn'

input_2D = torch.randn(2, 4, 68, 1800)   # (batch_size, channel, height(num_KPs), width(time series))

In [230]:
conv_2D = ConvBlock2d(in_channels=input_dim,      
                      out_channels=conv_hidden,   
                      kernel=kernel,
                      stride=stride,
                      pad=pad,
                      normalisation=normalisation)         # ['bn', 'wn', else]
conv_2D = conv_2D.cuda()

In [231]:
result_2D = conv_2D(input_2D.cuda())

In [232]:
result_2D.shape

torch.Size([2, 16, 1, 1800])

In [233]:
result_2D.squeeze().shape

torch.Size([2, 16, 1800])

In [235]:
pool_1D = nn.MaxPool1d(kernel_size=3, stride=3, padding=0).cuda()
pool_result_1D = pool_1D(result_2D.squeeze())
pool_result_1D.shape

torch.Size([2, 16, 600])

### conv2D + pool2D

In [115]:
# Conv2D set-up for the "conv2D + pool2D" experiment: a 68-row kernel with a
# vertical padding of 33 yields (68 + 2*33 - 68) + 1 = 67 output rows.
input_dim = 3
conv_hidden = 16
kernel = (68, 5)
stride= (1, 1)
pad = (33, 2)
normalisation = 'bn'

input_2D = torch.randn(2, 3, 68, 1800)   # (batch_size, channel, height(num_KPs), width(time series))

In [116]:
conv_2D = ConvBlock2d(in_channels=input_dim,      
                      out_channels=conv_hidden,   
                      kernel=kernel,
                      stride=stride,
                      pad=pad,
                      normalisation=normalisation)         # ['bn', 'wn', else]
conv_2D = conv_2D.cuda()

In [117]:
result_2D = conv_2D(input_2D.cuda())
result_2D.shape

torch.Size([2, 16, 67, 1800])

In [122]:
pool_2D = nn.MaxPool2d(kernel_size=3, stride=3, padding=0).cuda()
pool_result_2D = pool_2D(result_2D)
pool_result_2D.shape

torch.Size([2, 16, 22, 600])

### test ConvLSTM 2D, whole class together

In [225]:
class ConvLSTM_2D(nn.Module):
    """
    2D convolution + 2D pooling, followed by a 1D convolution + 1D pooling
    over the flattened (channel x height) axis, a bidirectional LSTM, an
    attention read-out and a fully-connected head.

    The in_channels of the 1D convolution are fixed to 22 * conv2D_hidden,
    which matches an input height of 68: the 2D convolution yields 67 rows
    and the 3x3 pooling reduces them to 22.
    """
    def __init__(self, input_dim, output_dim, conv2D_hidden, conv1D_hidden, 
                 lstm_hidden, num_layers, activation, norm, dropout):
        super(ConvLSTM_2D, self).__init__()
        self.conv2D = ConvBlock2d(in_channels=input_dim, out_channels=conv2D_hidden,
                                  kernel=(68, 5), stride=(1, 1), pad=(33, 2),
                                  normalisation='bn')
        # rows after conv2D and pool2D: (((68 + 2*33) - 68)//1 + 1) // 3 = 22
        self.conv1D = ConvBlock1d(in_channels=22*conv2D_hidden, out_channels=conv1D_hidden,
                                  kernel=11, stride=1, pad=5,
                                  normalisation='bn')
        self.pool2D = nn.MaxPool2d(kernel_size=3, stride=3, padding=0)
        self.pool1D = nn.MaxPool1d(kernel_size=5, stride=5, padding=0)
        self.drop = nn.Dropout(dropout)
        self.lstm = nn.LSTM(input_size=conv1D_hidden, hidden_size=lstm_hidden,
                            num_layers=num_layers, batch_first=True,
                            bidirectional=True)
        self.attention_layer = nn.Sequential(nn.Linear(lstm_hidden, lstm_hidden),
                                             nn.ReLU(inplace=True))
        self.fc = FullyConnected(in_channels=lstm_hidden, out_channels=output_dim,
                                 activation=activation, normalisation=norm)

    def attention_net_with_w(self, lstm_out, lstm_hidden):
        """
        Attention read-out over the LSTM outputs

        Input
            lstm_out: torch.Tensor - LSTM outputs, (N, steps, 2*lstm_hidden)
            lstm_hidden: torch.Tensor - Final hidden states,
                         (N, num_layers*num_directions, lstm_hidden)

        Output
            torch.Tensor - Attention-weighted context, (N, lstm_hidden)
        """
        # Add the two halves of the feature axis: (N, steps, lstm_hidden)
        directions = lstm_out.chunk(2, dim=-1)
        summed = directions[0] + directions[1]
        # Collapse all layers/directions into a single query: (N, 1, lstm_hidden)
        query = self.attention_layer(lstm_hidden.sum(dim=1, keepdim=True))
        # One score per time step: (N, 1, steps)
        scores = torch.bmm(query, torch.tanh(summed).transpose(1, 2))
        weights = F.softmax(scores, dim=-1)
        # Weighted sum over the time steps
        return torch.bmm(weights, summed).squeeze(1)

    def forward(self, x):
        """
        Input
            x: torch.Tensor - 4D input, (N, channel, height, width)
        """
        # The unpacking rejects anything that is not 4-dimensional
        N, C, V, T = x.shape
        pooled = self.pool2D(self.conv2D(x))      # e.g. (2, 16, 67, 1800) -> (2, 16, 22, 600)
        n, c, v, t = pooled.shape
        # Fold the height axis into the channels so a 1D convolution can run over time
        flat = pooled.view(n, c * v, t)
        seq = self.drop(self.pool1D(self.conv1D(flat)))   # (N, conv1D_hidden, t//5)
        seq = seq.permute(0, 2, 1).contiguous()           # (N, t//5, conv1D_hidden)
        lstm_out, (final_hidden_state, final_cell_state) = self.lstm(seq)   # (N, t//5, 2*lstm_hidden)
        # (num_layers*num_directions, N, lstm_hidden) -> batch first
        final_hidden_state = final_hidden_state.permute(1, 0, 2).contiguous()
        context = self.attention_net_with_w(lstm_out, final_hidden_state)   # (N, lstm_hidden)

        return self.fc(context)
    

In [12]:
model = ConvLSTM_2D(input_dim = 3, 
                    output_dim = 512, 
                    conv2D_hidden = 16, 
                    conv1D_hidden = 512, 
                    lstm_hidden = 512,
                    num_layers = 4,
                    activation = 'relu', 
                    norm = 'bn', 
                    dropout = 0.5)

model = model.cuda()

In [13]:
input = torch.randn(2, 3, 68, 1800).cuda()   # (batch_size, channel, height(num_KPs), width(time series))

In [14]:
result = model(input)
result.shape

torch.Size([2, 512])

### test ConvLSTM 1D, whole class together

In [243]:
class ConvLSTM_1D(nn.Module):
    """
    2D convolution whose kernel spans the whole height axis (68 rows, no
    vertical padding), followed by 1D pooling over time, a bidirectional
    LSTM, an attention read-out and a fully-connected head.

    Input
        input_dim: int - Channels of the network input
        output_dim: int - Size of the fully-connected output
        conv_hidden: int - Conv output channels, also the LSTM input size
        lstm_hidden: int - Hidden size of the LSTM (per direction)
        num_layers: int - Number of stacked LSTM layers
        activation: str - Forwarded to FullyConnected
        norm: str - Normalisation forwarded to FullyConnected
        dropout: float - Dropout probability applied after pooling
    """
    def __init__(self, input_dim, output_dim, conv_hidden,
                 lstm_hidden, num_layers, activation, norm, dropout):
        super(ConvLSTM_1D, self).__init__()
        self.conv2D = ConvBlock2d(in_channels=input_dim,
                                  out_channels=conv_hidden,
                                  kernel=(68, 11),
                                  stride=(1, 1),
                                  pad=(0, 5),
                                  normalisation='bn')
        self.pool1D = nn.MaxPool1d(kernel_size=5,
                                   stride=5,
                                   padding=0)
        self.drop = nn.Dropout(dropout)
        self.lstm = nn.LSTM(input_size=conv_hidden,
                            hidden_size=lstm_hidden,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=True)
        self.attention_layer = nn.Sequential(nn.Linear(lstm_hidden, lstm_hidden),
                                             nn.ReLU(inplace=True))
        self.fc = FullyConnected(in_channels=lstm_hidden,
                                 out_channels=output_dim,
                                 activation=activation,
                                 normalisation=norm)
        
    def attention_net_with_w(self, lstm_out, lstm_hidden):
        """
        Attention read-out over the LSTM outputs

        Input
            lstm_out: torch.Tensor - LSTM outputs, (N, steps, 2*lstm_hidden)
            lstm_hidden: torch.Tensor - Final hidden states,
                         (N, num_layers*num_directions, lstm_hidden)

        Output
            torch.Tensor - Attention-weighted context, (N, lstm_hidden)
        """
        # Add the two halves of the feature axis: (N, steps, lstm_hidden)
        lstm_tmp_out = torch.chunk(lstm_out, 2, -1)
        h = lstm_tmp_out[0] + lstm_tmp_out[1]
        # Collapse all layers/directions into a single query: (N, 1, lstm_hidden)
        lstm_hidden = torch.sum(lstm_hidden, dim=1)
        lstm_hidden = lstm_hidden.unsqueeze(1)
        atten_w = self.attention_layer(lstm_hidden)
        m = nn.Tanh()(h)
        # One score per time step: (N, 1, steps)
        atten_context = torch.bmm(atten_w, m.transpose(1, 2))
        softmax_w = F.softmax(atten_context, dim=-1)
        # Weighted sum over the time steps
        context = torch.bmm(softmax_w, h)
        result = context.squeeze(1)
        return result
    
    def forward(self, x):
        """
        Input
            x: torch.Tensor - 4D input, (N, channel, height, width); the
               height has to be 68 so the convolution collapses it to 1
        """
        # The unpacking rejects anything that is not 4-dimensional
        N, C, V, T = x.shape
        x = self.conv2D(x)  # output x: (N, conv_hidden, 1, T)
        # Drop ONLY the collapsed height axis. A bare squeeze() would also
        # remove the batch axis when N == 1 and break the permute below.
        x = self.pool1D(x.squeeze(2))  # output x: (N, conv_hidden, T//5)
        x = self.drop(x)
        x = x.permute(0, 2, 1).contiguous()
        x, (final_hidden_state, final_cell_state) = self.lstm(x)               # output x: (N, T//5, 2*lstm_hidden)
        final_hidden_state = final_hidden_state.permute(1, 0, 2).contiguous()  # output final_hidden_state: (N, num_layer*num_direction, lstm_hidden)
        x = self.attention_net_with_w(x, final_hidden_state)                   # output x: (N, lstm_hidden)
        x = self.fc(x)
        
        return x

In [9]:
model = ConvLSTM_1D(input_dim = 3, 
                    output_dim = 512, 
                    conv_hidden = 16,
                    lstm_hidden = 512,
                    num_layers = 4,
                    activation = 'relu', 
                    norm = 'bn', 
                    dropout = 0.5)

model = model.cuda()

In [10]:
input = torch.randn(2, 3, 68, 1800).cuda()   # (batch_size, channel, height(num_KPs), width(time series))

In [11]:
result = model(input)
result.shape

torch.Size([2, 512])

In [11]:
import math
import torch
import torch.nn as nn


def init_layer(layer):
    """Initialize a Linear or Convolutional layer.
    Ref: He, Kaiming, et al. "Delving deep into rectifiers: Surpassing
    human-level performance on imagenet classification." Proceedings of the
    IEEE international conference on computer vision. 2015.

    The weight is drawn from U(-scale, scale) with
    scale = sqrt(2 / fan_in) * sqrt(3); the bias, if present, is set to 0.

    Input
        layer: a layer exposing `weight` and `bias` - The current layer of
               the neural network. The weight must have at least 2
               dimensions, laid out as (n_out, n_in, *kernel_dims).

    Raises
        ValueError - if the weight has fewer than 2 dimensions
    """
    weight = layer.weight
    if weight.ndimension() < 2:
        raise ValueError(
            "init_layer expects a weight with at least 2 dimensions, got "
            "{}".format(weight.ndimension()))

    # fan-in = n_in * (product of the kernel dims). Covers Linear (2D),
    # Conv1d (3D), Conv2d (4D) and higher-rank convolutions alike.
    n = 1
    for dim in weight.shape[1:]:
        n *= dim

    std = math.sqrt(2. / n)
    scale = std * math.sqrt(3.)
    weight.data.uniform_(-scale, scale)

    if layer.bias is not None:
        layer.bias.data.fill_(0.)


def init_lstm(layer):
    """
    Re-initialises the layer-0 parameters of an LSTM: the input-hidden and
    hidden-hidden weight matrices are drawn from U(-scale, scale) with
    scale = sqrt(2 / (rows * cols)) * sqrt(3), and the matching biases are
    set to 0. Parameters of any further layers or of a reverse direction
    are not touched.

    Input
        layer: torch.nn.LSTM - The LSTM layer
    """
    gain = math.sqrt(3.)
    # Input-hidden first, then hidden-hidden
    for kind in ('ih', 'hh'):
        weight = getattr(layer, 'weight_{}_l0'.format(kind))
        rows, cols = weight.size()
        bound = math.sqrt(2. / (rows * cols)) * gain
        weight.data.uniform_(-bound, bound)

        bias = getattr(layer, 'bias_{}_l0'.format(kind))
        if bias is not None:
            bias.data.fill_(0.)


def init_att_layer(layer):
    """
    Initialise an attention layer with all weights equal to 1 and, when a
    bias exists, all biases equal to 0, so that the first pass through the
    attention mechanism weights every time step equally.

    Input
        layer: a layer exposing `weight` and `bias` - The current layer of
               the neural network
    """
    layer.weight.data.fill_(1.)

    bias = layer.bias
    if bias is None:
        return
    bias.data.fill_(0.)


def init_bn(bn):
    """
    Reset a batch normalisation layer: bias to 0 and weight to 1.

    Input
        bn: a batch-norm layer exposing `weight` and `bias` - The batch
            normalisation layer
    """
    # Bias first, then weight
    for param, value in ((bn.bias, 0.), (bn.weight, 1.)):
        param.data.fill_(value)


class ConvBlock1d(nn.Module):
    """
    Creates an instance of a 1D convolutional layer. This includes the
    convolutional filter but also the type of normalisation "batch" or
    "weight", the activation function, and initialises the weights.

    Input
        in_channels: int - Channels of the input
        out_channels: int - Number of convolutional filters
        kernel: int - Kernel size
        stride: int - Stride of the convolution
        pad: int - Zero padding added to both sides
        normalisation: str - 'bn' for batch norm, 'wn' for weight norm,
                       anything else for no normalisation
        dil: int - Dilation of the kernel
    """
    def __init__(self, in_channels, out_channels, kernel, stride, pad,
                 normalisation, dil=1):
        super(ConvBlock1d, self).__init__()
        self.norm = normalisation
        self.conv1 = nn.Conv1d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel,
                               stride=stride,
                               padding=pad,
                               dilation=dil)
        if self.norm == 'bn':
            self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        # Initialise BEFORE weight norm is attached: weight norm derives its
        # parameters from the current weight and recomputes `weight` on every
        # forward pass, so values written afterwards would be thrown away.
        self.init_weights()
        if self.norm == 'wn':
            self.conv1 = nn.utils.weight_norm(self.conv1, name='weight')

    def init_weights(self):
        """
        Initialises the weights of the current layer
        """
        init_layer(self.conv1)
        # bn1 only exists for normalisation == 'bn'
        if self.norm == 'bn':
            init_bn(self.bn1)

    def forward(self, input):
        """
        Passes the input through the convolutional filter

        Input
            input: torch.Tensor - The current input at this stage of the network
        """
        x = self.conv1(input)
        if self.norm == 'bn':
            x = self.bn1(x)

        return self.relu(x)


class ConvBlock2d(nn.Module):
    """
    Creates an instance of a 2D convolutional layer. This includes the
    convolutional filter but also the type of normalisation "batch" or
    "weight", the activation function, and initialises the weights.

    Input
        in_channels: int - Channels of the input
        out_channels: int - Number of convolutional filters
        kernel: int or tuple - Kernel size
        stride: int or tuple - Stride of the convolution
        pad: int or tuple - Zero padding
        normalisation: str - 'bn' for batch norm, 'wn' for weight norm,
                       anything else for no normalisation
        att: str or None - When set, the block acts as an attention layer:
             normalisation is skipped in forward and the activation is
             softmax ('softmax'), none ('global') or sigmoid (anything else)
    """
    def __init__(self, in_channels, out_channels, kernel, stride, pad,
                 normalisation, att=None):
        super(ConvBlock2d, self).__init__()
        self.norm = normalisation
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel,
                               stride=stride,
                               padding=pad)
        if self.norm == 'bn':
            self.bn1 = nn.BatchNorm2d(out_channels)
        self.att = att
        if not self.att:
            self.act = nn.ReLU()
        else:
            # Attention layers never apply normalisation in forward
            self.norm = None
            if self.att == 'softmax':
                self.act = nn.Softmax(dim=-1)
            elif self.att == 'global':
                self.act = None
            else:
                self.act = nn.Sigmoid()
        # Initialise BEFORE weight norm is attached: weight norm derives its
        # parameters from the current weight and recomputes `weight` on every
        # forward pass, so values written afterwards would be thrown away.
        self.init_weights()
        if normalisation == 'wn':
            self.conv1 = nn.utils.weight_norm(self.conv1, name='weight')

    def init_weights(self):
        """
        Initialises the weights of the current layer
        """
        if self.att:
            init_att_layer(self.conv1)
        else:
            init_layer(self.conv1)
        # bn1 only exists when the block was built with normalisation == 'bn'
        if hasattr(self, 'bn1'):
            init_bn(self.bn1)

    def forward(self, input):
        """
        Passes the input through the convolutional filter

        Input
            input: torch.Tensor - The current input at this stage of the network
        """
        x = self.conv1(input)
        if self.att:
            # act is None for att == 'global': return the raw convolution
            if self.act is not None:
                x = self.act(x)
            return x

        if self.norm == 'bn':
            x = self.bn1(x)

        return self.act(x)


class FullyConnected(nn.Module):
    """
    Creates an instance of a fully-connected layer. This includes the
    hidden layers but also the type of normalisation "batch" or
    "weight", the activation function, and initialises the weights.

    Input
        in_channels: int - Size of the input features
        out_channels: int - Size of the output features
        activation: str - 'sigmoid', 'softmax', 'global' (no activation) or
                    anything else for ReLU. Normalisation is only used
                    together with ReLU.
        normalisation: str - 'bn' for batch norm, 'wn' for weight norm,
                       anything else for no normalisation
        att: truthy to initialise the layer as an attention layer
             (weights 1, bias 0)
    """
    def __init__(self, in_channels, out_channels, activation, normalisation,
                 att=None):
        super(FullyConnected, self).__init__()
        self.att = att
        self.norm = normalisation
        self.fc = nn.Linear(in_features=in_channels,
                            out_features=out_channels)
        if activation == 'sigmoid':
            self.act = nn.Sigmoid()
            self.norm = None
        elif activation == 'softmax':
            self.act = nn.Softmax(dim=-1)
            self.norm = None
        elif activation == 'global':
            self.act = None
            self.norm = None
        else:
            self.act = nn.ReLU()
            if self.norm == 'bn':
                self.bnf = nn.BatchNorm1d(out_channels)

        # Initialise BEFORE weight norm is attached: weight norm derives its
        # parameters from the current weight and recomputes `weight` on every
        # forward pass, so values written afterwards would be thrown away.
        self.init_weights()

        if self.norm == 'wn':
            # weight_norm returns the very same module; it stays registered
            # under both `fc` and `wnf`
            self.wnf = nn.utils.weight_norm(self.fc, name='weight')

    def init_weights(self):
        """
        Initialises the weights of the current layer
        """
        if self.att:
            init_att_layer(self.fc)
        else:
            init_layer(self.fc)
        if self.norm == 'bn':
            init_bn(self.bnf)

    def forward(self, input):
        """
        Passes the input through the fully-connected layer

        Input
            input: torch.Tensor - The current input at this stage of the network
        """
        if self.norm == 'bn':
            return self.act(self.bnf(self.fc(input)))
        if self.norm == 'wn':
            return self.act(self.wnf(input))

        # No (or unrecognised) normalisation: plain linear layer, followed by
        # the activation unless it is disabled ('global')
        x = self.fc(input)
        if self.act is not None:
            x = self.act(x)

        return x


class ConvLSTM_PR(nn.Module):
    """
    2D convolution whose kernel spans a height axis of 2 (no vertical
    padding), followed by 1D pooling over time, dropout, a bidirectional
    LSTM and a fully-connected head fed with the LSTM output of the last
    time step.

    Input
        input_dim: int - Channels of the network input
        output_dim: int - Size of the fully-connected output
        conv_hidden: int - Conv output channels, also the LSTM input size
        lstm_hidden: int - Hidden size of the LSTM (per direction)
        num_layers: int - Number of stacked LSTM layers
        activation: str - Forwarded to FullyConnected
        norm: str - Normalisation forwarded to FullyConnected
        dropout: float - Dropout probability applied after pooling
    """
    def __init__(self, input_dim, output_dim, conv_hidden, lstm_hidden, num_layers, activation, norm, dropout):
        super(ConvLSTM_PR, self).__init__()
        self.conv = ConvBlock2d(in_channels=input_dim,
                                out_channels=conv_hidden,
                                kernel=(2, 3),
                                stride=(1, 1),
                                pad=(0, 1),
                                normalisation='bn')         # ['bn', 'wn', else]
        self.pool = nn.MaxPool1d(kernel_size=3,
                                 stride=3,
                                 padding=0)
        self.drop = nn.Dropout(dropout)
        self.lstm = nn.LSTM(input_size=conv_hidden,
                            hidden_size=lstm_hidden,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=True)
        # Bidirectional LSTM: forward and backward features are concatenated
        self.fc = FullyConnected(in_channels=lstm_hidden*2,
                                 out_channels=output_dim,
                                 activation=activation,     # ['sigmoid', 'softmax', 'global', else]
                                 normalisation=norm)        # ['bn', 'wn']: nn.BatchNorm1d, nn.utils.weight_norm

    def forward(self, net_input):
        """
        Input
            net_input: torch.Tensor - 4D input, (batch, channel, height,
                       width); the height has to be 2 so the convolution
                       collapses it to 1
        """
        x = net_input
        # The unpacking also rejects anything that is not 4-dimensional
        batch, _, _, _ = x.shape
        x = self.conv(x)                                    # output shape: (batch, conv_hidden, 1, width)
        # Drop ONLY the collapsed height axis. A bare squeeze() would also
        # remove the batch axis when batch == 1 and break the permute below.
        x = self.pool(x.squeeze(2))                         # output shape: (batch, conv_hidden, width//3)
        x = self.drop(x)
        x = x.permute(0, 2, 1).contiguous()
        x, _ = self.lstm(x)                                 # output shape: (batch, width//3, lstm_hidden*2)
        x = self.fc(x[:, -1, :].reshape(batch, -1))         # output shape: (batch, output_dim)

        return x

In [5]:
input = torch.randn(2, 300, 2, 3)   # (batch_size, width(time series), height, channel)
input = input.permute(0, 3, 2, 1).contiguous()   # -> (batch_size, channel, height, width(time series))

In [12]:
model = ConvLSTM_PR(input_dim = 3, 
                    output_dim = 128, 
                    conv_hidden = 128,
                    lstm_hidden = 128,
                    num_layers = 4,
                    activation = 'relu', 
                    norm = 'bn', 
                    dropout = 0.5)

model = model.cuda()

In [13]:
result = model(input.cuda())
result.shape

torch.Size([2, 128])