In [1]:
####

In [2]:
import torch
from torch import nn

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [5]:

class Conv(nn.Module):
    """Conv2d block with optional LeakyReLU(0.2) and InstanceNorm2d.

    The forward pass applies, in order: convolution -> activation (if
    ``use_activation``) -> instance normalisation (if ``use_norm``).
    NOTE(review): the activation runs *before* the norm here, the reverse of
    the more common conv -> norm -> act ordering; confirm this is intended.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
        use_norm: create and apply ``nn.InstanceNorm2d(out_channels)``.
        use_activation: create and apply ``nn.LeakyReLU(0.2)``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=1,
                 use_norm=True,
                 use_activation=True):
        super().__init__()
        self.use_norm = use_norm
        self.use_activation = use_activation

        # Sub-module names (conv1 / activation / norm) and their registration
        # order are kept as-is because they determine the state_dict keys.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size,
                               stride=stride, padding=padding)
        if use_activation:
            self.activation = nn.LeakyReLU(0.2)
        if use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.activation(out) if self.use_activation else out
        return self.norm(out) if self.use_norm else out

In [6]:
class ConvT(nn.Module):
    """Upsampling block: ConvTranspose2d, then optional LeakyReLU(0.2) and InstanceNorm2d.

    The forward order is transposed conv -> activation (if ``use_activation``)
    -> instance norm (if ``use_norm``). With the default kernel 2 / stride 2 /
    padding 0, the output height and width are exactly twice the input's.
    NOTE(review): as in ``Conv``, the activation precedes the norm; confirm
    this ordering is intended.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the transposed convolution.
        kernel_size, stride, padding: forwarded to ``nn.ConvTranspose2d``.
        use_norm: create and apply ``nn.InstanceNorm2d(out_channels)``.
        use_activation: create and apply ``nn.LeakyReLU(0.2)``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(2, 2),
                 stride=(2, 2),
                 padding=0,
                 use_norm=True,
                 use_activation=True):
        super().__init__()
        self.use_norm = use_norm
        self.use_activation = use_activation

        # Sub-module names and registration order (convT, norm, activation)
        # are preserved so state_dict keys and module listings stay the same.
        self.convT = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                        stride=stride, padding=padding)
        if use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        out = self.convT(x)
        out = self.activation(out) if self.use_activation else out
        return self.norm(out) if self.use_norm else out

In [7]:
# Backbone stages as [out_channels, kernel_size, stride, padding]: three
# 2x2 / stride-2 convolutions without padding, each halving H and W (rounding
# down) while the channel count goes 64 -> 128 -> 256.
config_backbone = [[channels, (2, 2), (2, 2), 0] for channels in (64, 128, 256)]

In [8]:
class Backbone(nn.Module):
    """Chain of ``Conv`` blocks built from a config table.

    Each config row is ``[out_channels, kernel_size, stride, padding]``. Rows
    are chained so that every block's input channels equal the previous
    block's output channels. With the default ``config_backbone`` (three 2x2 /
    stride-2 stages), 256 channels come out and H, W are divided by 8 (for
    sizes divisible by 8).

    Args:
        in_channels: channels of the input batch (3 by default).
        config: list of rows as described above.
    """

    def __init__(self,
                 in_channels = 3,
                 config = config_backbone):
        super().__init__()

        blocks = []
        channels_in = in_channels
        for out_channels, kernel_size, stride, padding in config:
            blocks.append(Conv(channels_in, out_channels, kernel_size, stride, padding))
            channels_in = out_channels
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        features = x
        for block in self.layers:
            features = block(features)
        return features

In [None]:
# Smoke test: three 2x2 / stride-2 stages take a 512x512 RGB batch down to
# 64x64 with 256 channels. Inference only, so skip building the autograd graph.
x = torch.randn(2 , 3 , 512 , 512).to(device)
backbone = Backbone().to(device)
with torch.no_grad():
    z = backbone(x)
assert z.shape == (2 , 256 , 64 , 64) , z.shape
z.shape

In [25]:
# FlowNet fusion/decoder layers, consumed in order by FlowNet.__init__:
#   [out_channels, kernel_size, stride, padding] -> Conv block
#   'S'                                          -> save a skip connection
#   int                                          -> ConvT (x2 upsample) to that many
#                                                   channels, then concat the latest skip
# Sizes in the comments assume 512x512 inputs (64x64 x 256 ch per image after
# Backbone, so 512 ch after the two images are concatenated).
config = [
    [256, (2, 2), (2, 2), 0],    # 512 -> 256 ch, 64 -> 32
    'S',                         # skip: 256 ch @ 32
    [512, (2, 2), (2, 2), 0],    # 32 -> 16
    [512, (3, 3), (1, 1), 1],    # stays 16
    'S',                         # skip: 512 ch @ 16
    [512, (2, 2), (2, 2), 0],    # 16 -> 8
    [512, (3, 3), (1, 1), 1],    # stays 8
    'S',                         # skip: 512 ch @ 8
    [1024, (2, 2), (2, 2), 0],   # 8 -> 4
    512,                         # up to 8,  + skip -> 1024 ch
    512,                         # up to 16, + skip -> 1024 ch
    256,                         # up to 32, + skip -> 512 ch
]

In [26]:
class FlowNet(nn.Module):
    """Siamese encoder/decoder that fuses features from two images.

    Both inputs go through one shared ``Backbone`` (same weights). The two
    feature maps are concatenated along dim 1 and then passed through the
    layers described by ``config``:

    * ``[out_channels, kernel_size, stride, padding]`` (list or tuple) ->
      a ``Conv`` block;
    * any ``str`` (``'S'`` by convention) -> an ``nn.Identity`` whose output
      is saved as a skip connection;
    * an ``int`` -> a ``ConvT`` upsampling block with that many output
      channels, whose output is concatenated (dim 1) with the most recently
      saved skip connection.

    Note that the ``config`` default is bound to the module-level ``config``
    list at class-definition time.

    Args:
        in_channels_model: channels the backbone produces for ONE image; the
            first config layer therefore receives ``2 * in_channels_model``.
        config: layer specification as described above.

    Raises:
        ValueError: if an upsampling entry has no saved skip connection left
            to pair with.
        TypeError: if a config entry is not a list/tuple, str or int.
    """

    def __init__(self ,
                 in_channels_model = 256 ,
                 config = config):
        super().__init__()

        self.backbone = Backbone()

        self.layers = nn.ModuleList()
        # Backbone features of the two images are concatenated channel-wise.
        in_channels = in_channels_model * 2
        # Channel counts of the skips forward() will have saved at this point.
        # Mirroring forward()'s stack sizes every layer that follows a ConvT
        # from the concatenated tensor it actually receives (previously this
        # was hard-coded to config index 9 and only matched the default config).
        skip_channels = []
        for i, layer in enumerate(config):
            if isinstance(layer, (list, tuple)):
                out_channels, kernel_size, stride, padding = layer
                self.layers.append(Conv(in_channels, out_channels, kernel_size, stride, padding))
                in_channels = out_channels
            elif isinstance(layer, str):
                self.layers.append(nn.Identity())
                skip_channels.append(in_channels)
            elif isinstance(layer, int):
                if not skip_channels:
                    raise ValueError(
                        f"config entry {i} ({layer!r}) upsamples but no 'S' skip "
                        f"connection is available to concatenate with")
                self.layers.append(ConvT(in_channels, layer))
                # forward() concatenates the ConvT output with the latest skip.
                in_channels = layer + skip_channels.pop()
            else:
                raise TypeError(f"unsupported config entry {i}: {layer!r}")

    def forward(self , x , y):
        """Encode both images with the shared backbone and run the fusion layers.

        Args:
            x, y: image batches for ``Backbone`` (3 input channels by default).
                Presumably (batch, 3, H, W) with H, W divisible by the total
                stride, so skip shapes line up. TODO confirm.

        Returns:
            The tensor produced by the last config layer, including its skip
            concatenation if it is an upsampling layer. For the default config
            and 512x512 inputs this is (batch, 512, 32, 32).
        """
        skips = []
        x = self.backbone(x)
        y = self.backbone(y)
        x = torch.cat([x, y], dim=1)

        for layer in self.layers:
            x = layer(x)
            if isinstance(layer, nn.Identity):
                skips.append(x)
            elif isinstance(layer, ConvT):
                x = torch.cat([x, skips.pop()], dim=1)
        # Previously the result was only printed and forward() returned None.
        return x

In [None]:
# Smoke test. The inputs must live on the same device as the model, otherwise
# the first convolution fails whenever CUDA is available. Inference only, so
# skip building the autograd graph.
x = torch.randn(2 , 3 , 512 , 512).to(device)
y = torch.randn(2 , 3 , 512 , 512).to(device)
flownet = FlowNet().to(device)
with torch.no_grad():
    z = flownet(x , y)