In [2]:
import torch
from torch import nn

In [3]:
class BidirEncoder(nn.Module):
    """Bidirectional recurrent feature-pyramid encoder (work in progress).

    Each image is first embedded by a convolutional stem, then processed by a
    forward recurrent pass (high resolution -> low resolution) followed by a
    backward recurrent pass (low resolution -> high resolution).

    The hidden_channels is always [64, 96, 128] in the DPFlow paper, and
    pyramid_level = 3.

    NOTE(review): ``ResStem`` and ``ConvGRUCell`` are defined outside this
    block. The code assumes ``ResStem`` takes a list of channel sizes and that
    ``ConvGRUCell`` is called as ``cell(input, hidden) -> new_hidden`` with
    outputs shaped like ``hidden`` -- confirm against their definitions.
    """

    def __init__(self, hidden_channels=(64, 96, 128)):
        """
        @param hidden_channels: channel sizes of the stem stages; the last
            entry is the feature/hidden-state width used by every GRU.
        """
        super().__init__()

        hidden_channels = list(hidden_channels)
        if not hidden_channels:
            raise ValueError("hidden_channels must contain at least one entry")
        feat_channels = hidden_channels[-1]

        ### Conv Stem: takes the raw image as input and outputs X0 and H0 ###
        # The stem emits twice the feature width so that its output can be
        # split into X0 and H0 in _get_init_stat.
        self.conv_stem = ResStem(hidden_channels[:-1] + [2 * feat_channels])
        # NOTE(review): lower_stem is created but not used by any method of
        # this class yet; it is kept so the module's parameters are unchanged.
        self.lower_stem = ResStem(hidden_channels[:-1] + [feat_channels])

        ### Forward GRU ###
        self.forward_gru = ConvGRUCell(feat_channels, feat_channels)
        # for passing Hf and Xf to the lower scale level. H_out = ceil(H_in / 2)
        self.down_gru = nn.Conv2d(
            feat_channels, feat_channels, kernel_size=3, stride=2, padding=1, bias=True
        )

        ### Backward GRU ###
        self.backward_gru = ConvGRUCell(feat_channels, feat_channels)
        # for passing Hb to the higher scale level. H_out = 2 * H_in
        self.up_gru = nn.ConvTranspose2d(
            feat_channels, feat_channels, kernel_size=4, stride=2, padding=1, bias=True
        )

        ### Forward CGU ###
        # TODO: forward cgu blocks. Identity placeholder so the encoder runs
        # end-to-end until the real block is implemented.
        self.forward_cgu = nn.Identity()

        ### Backward CGU ###
        # TODO: backward cgu block. Identity placeholder (see above).
        self.backward_cgu = nn.Identity()

    def forward(self, x: torch.Tensor, y: torch.Tensor, pyr_levels: int = 3):
        """
        Takes two frames of image x and y as input. x and y will go through
        the same process separately.

        @param x: a raw image
        @param y: a raw image
        @param pyr_levels: number of pyramid levels to produce (>= 1)

        @return: Two feature pyramids x_pyramid[::-1], y_pyramid[::-1] as the
            embeddings to pass to the decoder. Because _encode indexes level 0
            as the highest resolution, the returned lists are ordered from the
            lowest resolution to the highest.
        @raise ValueError: if pyr_levels < 1
        """
        if pyr_levels < 1:
            raise ValueError(f"pyr_levels must be >= 1, got {pyr_levels}")

        x0, hx0 = self._get_init_stat(x)
        y0, hy0 = self._get_init_stat(y)

        x_pyramid, y_pyramid = self._encode(x0, hx0, y0, hy0, pyr_levels)

        return x_pyramid[::-1], y_pyramid[::-1]

    def _encode(self, x0, hx0, y0, hy0, pyr_levels):
        """
        Run the dual-pyramid encoder on both frames with shared weights.

        Go through the forward process (high scale to low scale) and then the
        backward process (low scale to high scale) for each frame.

        @return: (x_pyramid, y_pyramid), each a list of length pyr_levels in
            which index 0 is the highest-resolution level.
        """
        x_pyramid = self._encode_single(x0, hx0, pyr_levels)
        y_pyramid = self._encode_single(y0, hy0, pyr_levels)
        return x_pyramid, y_pyramid

    def _encode_single(self, x0, h0, pyr_levels):
        """
        Build the feature pyramid of one frame from its initial feature x0
        and initial hidden state h0.

        @return: list of tensors, one per level (index 0 = highest
            resolution), each the channel-wise concatenation of the forward
            and backward features of that level.
        """
        ####### Forward Start #######
        feat, hidden = x0, h0
        forward_feats = []  # Xf of every level, index 0 = highest resolution

        for level in range(pyr_levels):
            hidden = self.forward_gru(feat, hidden)
            # Xf of this level; also the input of the next (lower) level.
            feat = self.forward_cgu(hidden)
            forward_feats.append(feat)

            if level < pyr_levels - 1:  # don't do the down_gru for the lowest level
                hidden = torch.tanh(self.down_gru(hidden))
                # NOTE(review): Xf must follow Hf to the lower scale or the
                # next GRU call receives mismatched spatial sizes. Reusing
                # down_gru for Xf follows the comment in __init__ -- confirm
                # against the paper.
                feat = self.down_gru(feat)
        ####### Forward End #######

        ####### Backward Start #######
        hidden = torch.zeros_like(forward_feats[-1])
        pyramid = [None] * pyr_levels

        for level in reversed(range(pyr_levels)):
            feat = forward_feats[level]
            hidden = self.backward_gru(feat, hidden)
            back = self.backward_cgu(hidden)

            # TODO: also concatenate the per-scale input feature (xi) once it
            # is available; only Xf and Xb are merged for now.
            pyramid[level] = torch.cat([feat, back], dim=1)

            if level > 0:  # don't do the up_gru for the highest level
                target_h, target_w = forward_feats[level - 1].shape[-2:]
                # up_gru exactly doubles the size while down_gru rounds up, so
                # crop the extra row/column that appears for odd sizes.
                hidden = torch.tanh(self.up_gru(hidden))[..., :target_h, :target_w]
        ####### Backward End #######

        return pyramid

    def _get_init_stat(self, x):
        """
        Pass the input image x to the conv_stem, and return x0, h0.

        The stem output is split in half along the channel dimension: the
        first half is x0, the second half (squashed by tanh) is h0.
        """
        stem_out = self.conv_stem(x)
        half = stem_out.shape[1] // 2
        x0, h0 = torch.split(stem_out, [half, half], 1)
        return x0, torch.tanh(h0)
        