In [1]:
import torch
from torch import nn

## GRU-RCN
GRU-RCN comes from Ballas et al., *Delving Deeper into Convolutional Networks for Learning Video Representations* (ICLR 2016): https://arxiv.org/pdf/1511.06432

"Model parameters $W, W^{l}_z, W^{l}_r$ and $U, U^{l}_z, U^{l}_r$ are 2d-convolutional kernels. Our model results in a hidden recurrent representation that preserves the spatial topology, $h_t^{l} = (h_t^{l}(i,j))$ where $h_t^{l}(i,j)$ is a feature vector defined at the location $(i,j)$. To ensure that the spatial size of the hidden representation remains fixed over time, we use zero-padding in the recurrent convolutions."

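"Using convolution, parameters $W, W^{l}_z, W^{l}_r$ have a size of $k_1 \times k_2 \times O_x \times O_h$ where $k_1 \times k_2$ is the convolutional kernel size (usually $3 \times 3$)." Here $O_x$ and $O_h$ are the numbers of input and hidden channels (`input_channel` and `hidden_channel` in the code below), so the recurrent kernels $U, U^{l}_z, U^{l}_r$ have size $k_1 \times k_2 \times O_h \times O_h$.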

DPFlow appears to use only `ConvGRUCell` as its forward and backward GRU.

In [2]:
class ConvGRUCell(nn.Module):
    """
    A GRU-RCN cell: a GRU whose input-to-hidden (W*) and hidden-to-hidden (U*)
    transforms are all 2-D convolutions.

    Every convolution uses symmetric zero-padding of kernel_size // 2, so for an
    odd kernel the hidden state keeps the spatial size (height, width) of the input:

        z_t     = sigmoid(W_z * x_t + U_z * h_{t-1})          # update gate
        r_t     = sigmoid(W_r * x_t + U_r * h_{t-1})          # reset gate
        h_hat_t = tanh(W * x_t + U * (r_t . h_{t-1}))         # candidate state
        h_t     = (1 - z_t) . h_{t-1} + z_t . h_hat_t
    """

    def __init__(self, input_channel: int, hidden_channel: int, kernel_size: int = 3):
        """
        Build the six gate convolutions.

        @type  input_channel: integer
        @param input_channel: number of channels of the input x
        @type  hidden_channel: integer
        @param hidden_channel: number of channels of the hidden state h
        @type  kernel_size: integer
        @param kernel_size: square kernel size of every convolution (default 3,
                            as in the paper); must be a positive odd integer
        @raise ValueError: if kernel_size is not a positive odd integer, because an
                           even kernel cannot preserve the spatial size with
                           symmetric zero-padding
        """
        super(ConvGRUCell, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        self.input_channel = input_channel
        self.hidden_channel = hidden_channel
        self.kernel_size = kernel_size
        # "Same" padding: output height/width equal input height/width for odd kernels.
        padding = kernel_size // 2

        # Creation order (and attribute names) match the original implementation so
        # seeded initialisation and state_dict keys stay identical.
        self.conv_Wz = nn.Conv2d(input_channel, hidden_channel, kernel_size=kernel_size, padding=padding)
        self.conv_Uz = nn.Conv2d(hidden_channel, hidden_channel, kernel_size=kernel_size, padding=padding)
        self.conv_Wr = nn.Conv2d(input_channel, hidden_channel, kernel_size=kernel_size, padding=padding)
        self.conv_Ur = nn.Conv2d(hidden_channel, hidden_channel, kernel_size=kernel_size, padding=padding)
        self.conv_W = nn.Conv2d(input_channel, hidden_channel, kernel_size=kernel_size, padding=padding)
        self.conv_U = nn.Conv2d(hidden_channel, hidden_channel, kernel_size=kernel_size, padding=padding)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor, h_prev: "torch.Tensor | None" = None) -> torch.Tensor:
        """
        Run one recurrent step.

        @type  x: torch.Tensor
        @param x: input of shape (batch, input_channel, height, width)
        @type  h_prev: torch.Tensor or None
        @param h_prev: previous hidden state of shape (batch, hidden_channel, height, width);
                       if None, an all-zero state is used
        @rtype:   torch.Tensor
        @return:  new hidden state with shape (batch, hidden_channel, height, width)
        """
        if h_prev is None:
            # Zero initial state with x's batch size, spatial size, dtype and device.
            batch, _, height, width = x.shape
            h_prev = x.new_zeros((batch, self.hidden_channel, height, width))

        z = self.sigmoid(self.conv_Wz(x) + self.conv_Uz(h_prev))
        r = self.sigmoid(self.conv_Wr(x) + self.conv_Ur(h_prev))
        # The reset gate decides how much of the previous state feeds the candidate.
        h_hat = torch.tanh(self.conv_W(x) + self.conv_U(r * h_prev))

        # Interpolate between the previous state and the candidate with the update gate.
        h = ((1 - z) * h_prev) + (z * h_hat)
        return h

In [18]:
# Sizes of the dummy batch. The cell's convolutions use zero-padding that keeps
# the spatial size fixed, so the input and hidden state must share height/width.
batch_size = 32
input_channel = 3
hidden_channel = 64
hidden_height, hidden_width = 32, 32

# Seed the default generator so the random draws in this and later cells
# are reproducible on Restart & Run All.
torch.manual_seed(0)

# Dummy input frame and an all-zero initial hidden state.
x = torch.randn((batch_size, input_channel, hidden_height, hidden_width))
h0 = torch.zeros((batch_size, hidden_channel, hidden_height, hidden_width))

In [19]:
# Build a cell whose input channels match the dummy input (no hard-coded 3),
# run one step from the zero state, and check the hidden-state shape is preserved.
convGRUCell = ConvGRUCell(x.shape[1], hidden_channel)
h = convGRUCell(x, h0)
assert h.shape == h0.shape, f"hidden state shape changed: {tuple(h.shape)} vs {tuple(h0.shape)}"
h.shape

torch.Size([32, 64, 32, 32])