# In [None]:
import torch.nn as nn
import torch
from collections import OrderedDict

class ConvLSTMCell(nn.Module):
    """Convolutional LSTM cell for 2-D feature maps.

    The input frame and the previous hidden state are concatenated along the
    channel axis and passed through one convolution that produces all four
    gate pre-activations at once. Those pre-activations are smoothed with a
    3x3 stride-1 average pool before the usual LSTM gate arithmetic.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.
        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        # Half-kernel padding keeps height/width unchanged for odd kernels.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # Kept as a named Sequential so the parameters remain
        # ``conv.conv0.weight`` / ``conv.conv0.bias`` and previously saved
        # state dicts keep loading.
        self.conv = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(
                in_channels=self.input_dim + self.hidden_dim,
                out_channels=4 * self.hidden_dim,  # i, f, o, g gates
                kernel_size=self.kernel_size,
                padding=self.padding,
                bias=self.bias)),
        ]))

    def forward(self, input_tensor, cur_state):
        """
        Advance the cell by one time step.
        Parameters
        ----------
        input_tensor:
            4-D tensor (b, input_dim, h, w); concatenated with the hidden
            state along dim 1.
        cur_state:
            Pair ``(h_cur, c_cur)`` of the current hidden and cell state.
        Returns
        -------
        (h_next, c_next):
            The updated hidden state and cell state.
        """
        h_cur, c_cur = cur_state

        stacked = torch.cat([input_tensor, h_cur], dim=1)  # channel axis
        gates = nn.functional.avg_pool2d(
            self.conv(stacked),
            kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))
        cc_i, cc_f, cc_o, cc_g = torch.split(gates, self.hidden_dim, dim=1)

        # Gate activations are used inline so the temporaries are released
        # as soon as each expression has been evaluated.
        c_next = (torch.sigmoid(cc_f) * c_cur
                  + torch.sigmoid(cc_i) * torch.tanh(cc_g))
        h_next = torch.sigmoid(cc_o) * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """
        Create zero-filled initial states.
        Parameters
        ----------
        batch_size: int
            Size of the leading batch dimension.
        image_size: (int, int)
            Height and width of the state maps.
        Returns
        -------
        (h, c):
            Two zero tensors of shape (batch_size, hidden_dim, height, width)
            placed on the same device and using the same dtype as the
            convolution weights.
        """
        height, width = image_size
        weight = self.conv.conv0.weight
        shape = (batch_size, self.hidden_dim, height, width)
        # Follow the parameter dtype as well as the device so that models
        # converted with .half()/.double() get matching states.
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))
    
class EncoderDecoderConvLSTM(nn.Module):
    """Sequence-to-sequence model built from stacked ConvLSTM cells.

    ARCHITECTURE
    # Encoder: two ConvLSTM cells applied to every input frame
    # Encoder vector: hidden state of the second encoder cell after the
    #   last input frame
    # Decoder: two ConvLSTM cells unrolled for the requested number of
    #   steps, each step fed with the previous decoder output
    # Output head: Dropout followed by BatchNorm3d over the stacked decoder
    #   outputs, plus tanh when ``generator`` is set
    """

    def __init__(self, nf, in_chan, generator=False):
        """
        Parameters
        ----------
        nf: int
            Base channel count. The first cell of the encoder and of the
            decoder use ``2 * nf`` hidden channels, the second ones ``nf``.
        in_chan: int
            Number of channels of each input frame.
        generator: bool
            When True, tanh is applied to the final output.
        """
        super().__init__()

        self.generator = generator

        # Registration order and attribute names are kept as-is so that
        # state dicts and parameter initialisation order do not change.
        self.encoder_1_convlstm = ConvLSTMCell(input_dim=in_chan,
                                               hidden_dim=2 * nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        self.encoder_2_convlstm = ConvLSTMCell(input_dim=2 * nf,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        self.decoder_1_convlstm = ConvLSTMCell(input_dim=nf,
                                               hidden_dim=2 * nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        self.decoder_2_convlstm = ConvLSTMCell(input_dim=2 * nf,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        # NOTE(review): p=0.95 zeroes 95% of the activations in training
        # mode -- confirm this rate is intentional.
        self.decoder_CNN = nn.Sequential(
            nn.Dropout(0.95),
            nn.BatchNorm3d(nf),
        )

    def autoencoder(self, x, seq_len, future_step,
                    h_t, c_t,
                    h_t2, c_t2,
                    h_t5, c_t5,
                    h_t6, c_t6):
        """
        Run the encoder over ``x`` and unroll the decoder.
        Parameters
        ----------
        x:
            5-D tensor indexed as (b, t, c, h, w).
        seq_len: int
            Number of leading time steps of ``x`` to encode.
        future_step: int
            Number of decoder steps to generate.
        h_t, c_t / h_t2, c_t2:
            Initial states of the first / second encoder cell.
        h_t5, c_t5 / h_t6, c_t6:
            Initial states of the first / second decoder cell.
        Returns
        -------
        Tensor of shape (b, nf, future_step, h, w).
        """
        # encoder
        for t in range(seq_len):
            h_t, c_t = self.encoder_1_convlstm(input_tensor=x[:, t],
                                               cur_state=[h_t, c_t])
            h_t2, c_t2 = self.encoder_2_convlstm(input_tensor=h_t,
                                                 cur_state=[h_t2, c_t2])

        # encoder_vector: last hidden state of the second encoder cell
        encoder_vector = h_t2

        # decoder: every step consumes the previous step's output
        outputs = []
        for _ in range(future_step):
            h_t5, c_t5 = self.decoder_1_convlstm(input_tensor=encoder_vector,
                                                 cur_state=[h_t5, c_t5])
            h_t6, c_t6 = self.decoder_2_convlstm(input_tensor=h_t5,
                                                 cur_state=[h_t6, c_t6])
            encoder_vector = h_t6
            outputs.append(h_t6)  # predictions

        outputs = torch.stack(outputs, 1)         # (b, step, nf, h, w)
        outputs = outputs.permute(0, 2, 1, 3, 4)  # (b, nf, step, h, w)
        outputs = self.decoder_CNN(outputs)
        if self.generator:
            outputs = torch.tanh(outputs)

        return outputs

    def forward(self, x, future_seq=None, hidden_state=None):
        """
        Parameters
        ----------
        x:
            5-D Tensor of shape (b, t, c, h, w)        #   batch, time, channel, height, width
        future_seq: int, optional
            Number of decoder steps to generate. When omitted, the
            module-level name ``chan`` is used if it exists (this was the
            former default value); otherwise a ValueError is raised.
        hidden_state:
            Accepted for interface compatibility; currently ignored, the
            states are always initialised to zeros.
        Returns
        -------
        Tensor of shape (b, nf, future_seq, h, w).
        """
        if future_seq is None:
            # The default used to be bound to a global ``chan`` when the
            # class was defined, which made the class unusable without that
            # global. Resolve it lazily instead.
            future_seq = globals().get("chan")
            if future_seq is None:
                raise ValueError(
                    "future_seq must be provided (no module-level 'chan' "
                    "fallback is defined)")

        b, seq_len, _, h, w = x.size()

        # initialize hidden states, in the order autoencoder() expects them
        cells = (self.encoder_1_convlstm, self.encoder_2_convlstm,
                 self.decoder_1_convlstm, self.decoder_2_convlstm)
        states = [state
                  for cell in cells
                  for state in cell.init_hidden(batch_size=b,
                                                image_size=(h, w))]

        # autoencoder forward
        return self.autoencoder(x, seq_len, future_seq, *states)