<a href="https://colab.research.google.com/github/arifinnasif/Natural-Hazard-Prediction/blob/master/lightnet_on_torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch.nn as nn
import torch

## ConvLSTM

In [2]:
class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell (one layer, one time step).

    The input and the previous hidden state are concatenated along the channel
    axis and passed through ONE convolution that produces the pre-activations
    of all four gates at once: input (i), forget (f), output (o) and the
    candidate cell update (g), in that order.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel. Padding is ``k // 2`` per axis,
            so the spatial size is preserved only for odd kernel sizes.
        bias: bool
            Whether or not to add the bias.
        """

        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # One convolution computes all four gates -> 4 * hidden_dim output channels.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one time step.

        Parameters
        ----------
        input_tensor: Tensor
            Input at the current step, (b, input_dim, h, w).
        cur_state: (Tensor, Tensor)
            Previous hidden state and cell memory, each (b, hidden_dim, h, w),
            e.g. as produced by ``init_hidden``.

        Returns
        -------
        (h_next, c_next)
            Updated hidden state and cell memory.
        """
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis

        combined_conv = self.conv(combined)
        # Split the 4 * hidden_dim channels back into the four gate pre-activations.
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero-initialised ``(h, c)`` states of shape (batch_size, hidden_dim, H, W).

        Both tensors are created on the same device AND with the same dtype as
        the convolution weights, so a module cast with ``.double()`` /
        ``.half()`` does not start from float32 zeros (which would mix dtypes
        in the first step and break half-precision models).
        """
        height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))


class ConvLSTM(nn.Module):

    """
    Multi-layer convolutional LSTM built from stacked ``ConvLSTMCell``s.

    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels (int, or list with one entry per layer)
        kernel_size: Size of kernel in convolutions ((int, int) tuple, or list of
            such tuples with one entry per layer; a bare int is rejected)
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether dimension 0 of the input is the batch (True) or time (False)
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
        Note: Uses padding k // 2, which keeps H and W only for odd kernel sizes.

    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
            0 - layer_output_list: each element is a tensor of shape
                (B, T, hidden_dim, H, W) with that layer's hidden state at every step.
                NOTE: outputs are batch-first even when ``batch_first=False``.
            1 - last_state_list: each element is a pair [h, c] with the hidden state
                and memory after the last time step
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            # Layer 0 reads the raw input; deeper layers read the previous layer's hidden states.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        Run the stacked ConvLSTM over a whole sequence.

        Parameters
        ----------
        input_tensor:
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w),
            depending on ``batch_first``.
        hidden_state:
            Optional initial state: a sequence of ``num_layers`` (h, c) pairs,
            e.g. the ``last_state_list`` of a previous call made with
            ``return_all_layers=True``. ``None`` (default) starts every layer
            from zeros.

        Returns
        -------
        layer_output_list, last_state_list
            See the class docstring; outputs are always batch-first.

        Raises
        ------
        ValueError
            If ``hidden_state`` does not hold exactly one entry per layer.
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        b, seq_len, _, height, width = input_tensor.size()

        if hidden_state is None:
            # Since the init is done in forward, the image size is known here.
            hidden_state = self._init_hidden(batch_size=b, image_size=(height, width))
        elif len(hidden_state) != self.num_layers:
            raise ValueError('Expected {} (h, c) pairs in hidden_state, got {}.'.format(
                self.num_layers, len(hidden_state)))

        layer_output_list = []
        last_state_list = []

        cur_layer_input = input_tensor
        for layer_idx, cell in enumerate(self.cell_list):
            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h, c = cell(input_tensor=cur_layer_input[:, t, :, :, :], cur_state=[h, c])
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            # This layer's full hidden-state sequence is the next layer's input.
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        """Return one zero ``(h, c)`` pair per layer for the given batch and image size."""
        return [cell.init_hidden(batch_size, image_size) for cell in self.cell_list]

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        """Raise ValueError unless kernel_size is a tuple or a list of tuples."""
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all(isinstance(elem, tuple) for elem in kernel_size))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        """Repeat a non-list ``param`` once per layer; lists are returned unchanged."""
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

In [3]:
# 1 input channel, 8 hidden channels (value from Geng), 5x5 kernel, a single
# layer; batch_first=True because the batch size is the first input dimension.
conv_lstm = ConvLSTM(input_dim=1, hidden_dim=[8], kernel_size=[(5, 5)], num_layers=1,
                     batch_first=True)
# 32 samples per batch, the previous 6 hours, 1 channel, 25x25 grid.
x = torch.rand(32, 6, 1, 25, 25)
_, last_states = conv_lstm(x)
h, c = last_states[0]  # state of the (only) layer: hidden state h and cell memory c
print(c.size())
print("-----")
print(h.size())


torch.Size([32, 8, 25, 25])
-----
torch.Size([32, 8, 25, 25])


In [18]:
# Grouped-convolution experiment: with groups=6 every one of the 6 input
# channels gets its own 4 filters (24 outputs in total); stride 2 with
# padding 2 maps 50x50 to 25x25.
x_ = torch.rand(32, 6, 50, 50)
conv_ = nn.Conv2d(6, 6 * 4, 5, stride=2, padding=2, groups=6)
print(conv_.weight.data.size())  # (out_channels, in_channels // groups, kH, kW)
conv_(x_).size()

torch.Size([24, 1, 5, 5])


torch.Size([32, 24, 25, 25])

In [35]:
class Encoder(nn.Module):
    """Per-frame grouped downsampling followed by a one-layer ConvLSTM.

    Each of the ``prev_hours`` input frames is convolved independently
    (grouped 5x5 convolution, stride 2, padding 2) into ``channels_per_frame``
    feature maps. The resulting sequence of frames is then run through a
    batch-first ConvLSTM.

    Parameters
    ----------
    prev_hours: int
        Number of input frames / time steps. Default 6.
    channels_per_frame: int
        Feature maps produced per frame by the grouped convolution. Default 4.
    hidden_dim: int
        Hidden channels of the ConvLSTM. Default 8.
    """

    def __init__(self, prev_hours=6, channels_per_frame=4, hidden_dim=8):
        super(Encoder, self).__init__()
        self.prev_hours = prev_hours
        self.channels_per_frame = channels_per_frame
        # groups=prev_hours: each input frame gets its own filters and is
        # mapped to channels_per_frame output maps.
        self.conv_2 = nn.Conv2d(in_channels=self.prev_hours,
                                out_channels=self.prev_hours * self.channels_per_frame,
                                groups=self.prev_hours,
                                kernel_size=5,
                                stride=2,
                                padding=2)
        self.conv_lstm = ConvLSTM(input_dim=self.channels_per_frame,
                                  hidden_dim=[hidden_dim],
                                  kernel_size=[(5, 5)],
                                  num_layers=1,
                                  batch_first=True)

    def forward(self, input_tensor):
        """Encode a batch of frame sequences.

        ``input_tensor`` is assumed to be (B, prev_hours, 1, H, W). Merging
        dims 1 and 2 must yield ``prev_hours`` channels for ``conv_2``.
        Returns the ``(layer_output_list, last_state_list)`` tuple produced by
        the ConvLSTM.
        """
        # (B, T, 1, H, W) -> (B, T, H, W): frames become convolution channels.
        x = self.conv_2(input_tensor.flatten(1, 2))
        # (B, T * k, H', W') -> (B, T, k, H', W'). Grouped-conv outputs are
        # ordered by group, so consecutive k channels belong to one frame.
        x = torch.unflatten(x, dim=1, sizes=(self.prev_hours, self.channels_per_frame))
        return self.conv_lstm(x)





In [36]:
# Smoke run: 32 sequences of 6 single-channel 25x25 frames.
x = torch.rand(32, 6, 1, 25, 25)
enc = Encoder()
enc(x)  # ConvLSTM (outputs, states) tuple; the stride-2 conv shrinks 25x25 to 13x13

([tensor([[[[[ 2.5045e-02,  3.5153e-02,  1.8137e-02,  ...,  2.3473e-02,
               1.0929e-02, -5.4349e-04],
             [ 1.9460e-02,  1.6855e-02,  2.3034e-02,  ...,  1.1703e-02,
               9.9390e-03,  1.1204e-02],
             [ 5.2960e-02,  2.3356e-02,  1.9685e-02,  ...,  1.5769e-02,
               2.5054e-02,  3.2990e-03],
             ...,
             [ 2.7088e-02,  3.7276e-03,  3.9338e-02,  ...,  3.3646e-02,
               1.8028e-02,  1.6236e-02],
             [ 1.6262e-02,  2.1168e-02,  6.3282e-03,  ...,  1.8314e-02,
               8.9772e-03,  8.5005e-03],
             [ 2.3123e-02,  3.1544e-02,  1.7580e-02,  ...,  1.8130e-02,
               1.0790e-02,  1.7197e-02]],
  
            [[ 1.6854e-02,  2.7983e-02,  1.9915e-02,  ...,  3.1949e-02,
               1.9896e-02,  1.5842e-02],
             [ 9.4185e-03,  1.8120e-02,  2.5799e-02,  ...,  1.8145e-02,
               2.5882e-02,  2.0002e-02],
             [ 2.2740e-02,  1.1928e-02,  1.3540e-02,  ...,  1.6738e-02,
  