In [2]:
import os
import torch
from torch import nn
from torch.nn import functional as F

In [None]:
# Encoder configuration: a two-element list of parallel, three-entry lists.
#   [0] one single-key OrderedDict per stage, mapping a layer name to five integers
#       (presumably [in_channels, out_channels, kernel_size, stride, padding] —
#       TODO confirm against the code that consumes these dicts);
#   [1] one ConvLSTM cell per stage.
# In the values below, the second integer of each stage-[0] entry equals the
# `input_channel` of the ConvLSTM at the same position (8, 192, 192), and the first
# integer of the next stage-[0] entry equals that ConvLSTM's `num_filter` (64, 192).
# NOTE(review): `OrderedDict` and `batch_size` are not defined in the visible cells,
# and `ConvLSTM` is defined in a later cell — on a fresh kernel this cell depends on
# state that is out of view or not yet created; verify with Restart & Run All.
convlstm_encoder_params = [
    [
        OrderedDict({'conv1_leaky_1': [1, 8, 7, 5, 1]}),
        OrderedDict({'conv2_leaky_1': [64, 192, 5, 3, 1]}),
        OrderedDict({'conv3_leaky_1': [192, 192, 3, 2, 1]}),
    ],

    [
        # b_h_w supplies the spatial size of each cell's state: 96x96, 32x32, 16x16.
        ConvLSTM(input_channel=8, num_filter=64, b_h_w=(batch_size, 96, 96),
                 kernel_size=3, stride=1, padding=1),
        ConvLSTM(input_channel=192, num_filter=192, b_h_w=(batch_size, 32, 32),
                 kernel_size=3, stride=1, padding=1),
        ConvLSTM(input_channel=192, num_filter=192, b_h_w=(batch_size, 16, 16),
                 kernel_size=3, stride=1, padding=1),
    ]
]
# NOTE(review): this loop header has no body in this cell, so the cell cannot be
# parsed as written; `subnets` and `rnns` are also not defined in the visible cells.
# It looks like a pasted fragment — complete it or remove it.
for index, (params, rnn) in enumerate(zip(subnets, rnns), 1):

In [None]:

class ConvLSTM(nn.Module):
  """
  Convolutional LSTM layer with peephole connections, based on Xingjian Shi:
  "Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting"

  Adapted from https://github.com/Hzzone/Precipitation-Nowcasting/blob/master/nowcasting/models/convLSTM.py

  At every time step one convolution over the channel-wise concatenation of the
  input frame and the hidden state yields the pre-activations of the input gate,
  forget gate, candidate cell value and output gate. The three gates additionally
  receive an element-wise (peephole) contribution from the cell state.
  """
  def __init__(self, input_channel, num_filter, b_h_w, kernel_size, stride=1, padding=1):
    """
    Initiate the ConvLSTM layer

    Parameters
    ----------
    input_channel: int
      number of channels of each input frame
    num_filter: int
      number of channels of the hidden state h and of the cell state c
    b_h_w: tuple of (int, int, int)
      (batch size, state height, state width). Height and width size the peephole
      weights and the zero-initialised states; the batch size is stored but not
      otherwise used by this class.
    kernel_size: int
      kernel size of the gate convolution
    stride: int
      stride of the gate convolution
    padding: int
      padding of the gate convolution

    The convolution output is combined element-wise with the state, so kernel_size,
    stride and padding must leave the spatial size unchanged.
    """
    super().__init__()
    # One convolution produces all four pre-activations at once, hence
    # num_filter * 4 output channels: input gate, forget gate, candidate cell
    # value and output gate (split with torch.chunk in forward).
    self._conv = nn.Conv2d(in_channels=input_channel + num_filter,
                           out_channels=num_filter * 4,
                           kernel_size=kernel_size,
                           stride=stride,
                           padding=padding)
    self._batch_size, self._state_height, self._state_width = b_h_w
    # Peephole weights, initialised to zero; the leading dimension of 1 lets them
    # broadcast over the batch.
    self.Wci = nn.Parameter(torch.zeros(1, num_filter, self._state_height, self._state_width))
    self.Wcf = nn.Parameter(torch.zeros(1, num_filter, self._state_height, self._state_width))
    self.Wco = nn.Parameter(torch.zeros(1, num_filter, self._state_height, self._state_width))

    self._input_channel = input_channel
    self._num_filter = num_filter

  def forward(self, inputs=None, states=None, seq_len=None):
    """
    Run the cell for seq_len time steps.

    Parameters
    ----------
    inputs: torch.Tensor or None
      input sequence indexed as S*B*C*H*W (S = time step, B = batch, C = channel).
      When None, an all-zero frame is fed at every step and `states` must be given.
    states: tuple (h, c) or None
      initial hidden and cell state, each B*num_filter*H*W. When None both start
      at zero, with the batch size taken from inputs.size(1).
    seq_len: int or None
      number of steps to run; must be at least 1 (an empty output list cannot be
      stacked) and must not exceed inputs.size(0) when inputs is given. None means
      cfg.HKO.BENCHMARK.IN_LEN, looked up when forward is called.

    Returns
    -------
    (outputs, (h, c)): outputs stacks the hidden state of every step along a new
    leading dimension; h and c are the states after the last step.

    Raises
    ------
    ValueError
      if both inputs and states are None (the batch size cannot be determined).
    """
    if seq_len is None:
      # Looked up lazily so that defining the class does not require `cfg`.
      # NOTE(review): `cfg` is not imported in the visible cells — it must be in
      # scope whenever this default is relied on.
      seq_len = cfg.HKO.BENCHMARK.IN_LEN
    if inputs is None and states is None:
      raise ValueError("ConvLSTM.forward needs `inputs`, `states`, or both")

    # Allocate zeros with the module's own dtype/device so the layer keeps working
    # after .to(...) / .double() instead of depending on a global device setting.
    ref = self._conv.weight
    state_hw = (self._state_height, self._state_width)

    if states is None:
      state_shape = (inputs.size(1), self._num_filter) + state_hw
      h = torch.zeros(state_shape, dtype=ref.dtype, device=ref.device)
      c = torch.zeros(state_shape, dtype=ref.dtype, device=ref.device)
    else:
      h, c = states

    # The zero frame used when no inputs are given is the same at every step:
    # build it once rather than inside the loop.
    zero_frame = None
    if inputs is None:
      zero_frame = torch.zeros((h.size(0), self._input_channel) + state_hw,
                               dtype=ref.dtype, device=ref.device)

    outputs = []
    for index in range(seq_len):
      x = zero_frame if inputs is None else inputs[index, ...]
      conv_x = self._conv(torch.cat([x, h], dim=1))

      i, f, tmp_c, o = torch.chunk(conv_x, 4, dim=1)

      # Input and forget gates peek at the previous cell state ...
      i = torch.sigmoid(i + self.Wci * c)
      f = torch.sigmoid(f + self.Wcf * c)
      c = f * c + i * torch.tanh(tmp_c)
      # ... while the output gate peeks at the freshly updated one.
      o = torch.sigmoid(o + self.Wco * c)
      h = o * torch.tanh(c)
      outputs.append(h)
    return torch.stack(outputs), (h, c)