In [1]:
####

In [2]:
import torch
from torch import nn

In [3]:
# Prefer the GPU when PyTorch reports one as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
class Conv(nn.Module):
    """2-D convolution with optional normalisation, pooling and activation.

    ``forward`` applies the enabled stages in a fixed order:
    ``Conv2d -> InstanceNorm2d -> MaxPool2d(2, stride 2) -> LeakyReLU(0.2)``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
        kernel_size: convolution kernel size (default ``(3, 3)``).
        stride: convolution stride (default ``(1, 1)``).
        padding: convolution padding (default ``1``).
        use_norm: apply ``InstanceNorm2d`` after the convolution.
        use_activation: finish with ``LeakyReLU(0.2)``.
        use_pool: apply 2x2 max-pooling with stride 2 (before the activation).
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
                 padding=1, use_norm=True, use_activation=True, use_pool=False):
        super().__init__()

        # The flags are kept on the instance because forward() consults them.
        self.use_norm = use_norm
        self.use_activation = use_activation
        self.use_pool = use_pool

        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding)

        # Optional stages are only created (and registered) when enabled.
        if use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if use_pool:
            self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        if use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Run the convolution followed by every enabled optional stage."""
        out = self.conv1(x)
        out = self.norm(out) if self.use_norm else out
        out = self.max_pool(out) if self.use_pool else out
        out = self.activation(out) if self.use_activation else out
        return out

In [None]:
# Shape check for Conv on a random batch of two 3-channel 512x512 inputs.
x = torch.randn(2 , 3 , 512 , 512).to(device)
# The default 3x3 / stride 1 / padding 1 convolution keeps 512x512; use_pool=True halves it.
conv = Conv(3 , 32 , use_pool=True).to(device)
z = conv(x)
# Expected: torch.Size([2, 32, 256, 256])
z.shape

In [6]:
class ConvT(nn.Module):
    """Transposed 2-D convolution with optional normalisation and activation.

    With the default ``kernel_size=(2, 2)``, ``stride=(2, 2)`` and ``padding=0`` the
    layer doubles the spatial size of its input. ``forward`` applies
    ``ConvTranspose2d -> InstanceNorm2d -> LeakyReLU(0.2)``, skipping disabled stages.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the transposed convolution.
        kernel_size: transposed-convolution kernel size.
        stride: transposed-convolution stride.
        padding: transposed-convolution padding.
        use_norm: apply ``InstanceNorm2d`` after the transposed convolution.
        use_activation: finish with ``LeakyReLU(0.2)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(2, 2), stride=(2, 2),
                 padding=0, use_norm=True, use_activation=True):
        super().__init__()

        # forward() reads these flags to decide which stages to run.
        self.use_norm = use_norm
        self.use_activation = use_activation

        self.convT = nn.ConvTranspose2d(in_channels, out_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding)

        # Optional stages are only created (and registered) when enabled.
        if use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Upsample ``x`` and run every enabled optional stage."""
        out = self.convT(x)
        out = self.norm(out) if self.use_norm else out
        out = self.activation(out) if self.use_activation else out
        return out

In [None]:
# Shape check for ConvT on a random batch of two 3-channel 512x512 inputs.
x = torch.randn(2 , 3 , 512 , 512).to(device)
# Default ConvT settings (2x2 kernel, stride 2, padding 0) double the spatial size.
conv = ConvT(3 , 32).to(device)
z = conv(x)
# Expected: torch.Size([2, 32, 1024, 1024])
z.shape

In [8]:
class FCN_Block(nn.Module):
    """Encoder stage made of two ``Conv`` units applied back to back.

    The first unit maps ``in_channels -> out_channels``; the second keeps
    ``out_channels`` and, when ``use_pool`` is true, includes the 2x2 max-pooling
    step of ``Conv``. Both units use the remaining ``Conv`` defaults.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by both units.
        use_pool: forwarded to the second ``Conv`` unit only.
    """

    def __init__(self, in_channels, out_channels, use_pool=True):
        super().__init__()

        self.conv1 = Conv(in_channels, out_channels)
        self.conv2 = Conv(out_channels, out_channels, use_pool=use_pool)

    def forward(self, x):
        """Apply ``conv1`` and then ``conv2`` to ``x``."""
        return self.conv2(self.conv1(x))

In [12]:
# Layer plan consumed by FCN.__init__; the Python type of each entry selects the layer:
#   [out_channels, use_pool]  (list)  -> FCN_Block(in, out_channels, use_pool=use_pool)
#   'S'                       (str)   -> Save(), marks a feature map kept as a skip connection
#   (out_channels, 'C')       (tuple) -> ConvT(in, out_channels); the 'C' marker flags the
#                                        decoder stages meant to be combined with a saved map
#   out_channels              (int)   -> ConvT(in, out_channels)
# Every pooled FCN_Block halves the resolution and every ConvT doubles it, so each
# (.., 'C') stage must land on the resolution of the 'S' entry it is paired with
# (most recent 'S' first).
config = [
    # --- encoder ---
    [64, True],      # H/2
    [128, True],     # H/4
    [256, False],
    [256, True],     # H/8
    'S',             # skip saved at H/8, 256 channels
    [512, False],
    [512, True],     # H/16
    'S',             # skip saved at H/16, 512 channels
    [512, False],
    [512, True],     # H/32
    # No pooling in this block: a sixth pool would bring the map to H/64, and the first
    # 2x ConvT below would then reach only H/32 instead of the H/16 of the skip it joins.
    [4096, False],   # H/32
    # --- decoder ---
    (512, 'C'),      # H/16, pairs with the second 'S'
    (256, 'C'),      # H/8,  pairs with the first 'S'
    128,             # H/4
    64,              # H/2
    3,               # H, output channels
]

In [17]:
class Save(nn.Module):
    """Pass-through layer: ``forward`` returns its argument unchanged.

    It owns no parameters. ``FCN`` inserts one instance for every string entry of
    its layer plan and detects it with ``isinstance`` while iterating its layers.
    """

    def __init__(self):
        super().__init__()
        # Constant attribute; nothing in the visible code reads it.
        self.s = 2

    def forward(self, x):
        """Return ``x`` itself (the same object, no copy)."""
        return x

In [21]:
class FCN(nn.Module):
    """Encoder/decoder network assembled from a layer plan.

    Each plan entry is dispatched on its Python type:

    * ``[out_channels, use_pool]`` (list): an ``FCN_Block``.
    * any ``str`` (``'S'`` in the default plan): a ``Save`` marker; during
      ``forward`` the tensor at that point is pushed on a stack of skip tensors.
    * ``(out_channels, marker)`` (tuple): a ``ConvT`` whose output is summed with
      the most recently saved skip tensor (last in, first out). The marker value
      itself is not inspected.
    * ``int``: a plain ``ConvT`` with that many output channels.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: accepted for interface compatibility but not used;
            NOTE(review): the number of output channels comes from the last plan
            entry, confirm whether this argument should override it.
        config: the layer plan; defaults to the module-level ``config``.

    Raises:
        ValueError: a tuple entry has no earlier, still unmatched string entry.
        TypeError: a plan entry is not a list, tuple, str or int.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 config=config):
        super(FCN, self).__init__()

        self.layers = nn.ModuleList()
        # Positions in self.layers of the ConvT stages that merge a saved skip tensor.
        # Derived from the plan so the skip points follow the tuple entries instead of
        # relying on fixed layer numbers.
        self.skip_indices = set()

        open_saves = 0  # Save markers not yet matched by a tuple entry
        for entry in config:
            if isinstance(entry, str):
                # A marker does not change the channel count.
                self.layers.append(Save())
                open_saves += 1
                continue

            if isinstance(entry, list):
                width, use_pool = entry
                stage = FCN_Block(in_channels, width, use_pool=use_pool)
            elif isinstance(entry, tuple):
                width, _marker = entry
                if open_saves == 0:
                    raise ValueError(
                        'config entry %r needs a saved skip tensor, but no unmatched '
                        'string entry precedes it' % (entry,))
                open_saves -= 1
                self.skip_indices.add(len(self.layers))
                stage = ConvT(in_channels, width)
            elif isinstance(entry, int):
                width = entry
                stage = ConvT(in_channels, width)
            else:
                raise TypeError('unsupported config entry: %r' % (entry,))

            self.layers.append(stage)
            in_channels = width

    def forward(self, x):
        """Run ``x`` through all layers, merging saved skip tensors in the decoder.

        Returns:
            The output of the last layer.
        """
        saved = []
        # enumerate yields (index, module) pairs, in that order.
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if isinstance(layer, Save):
                saved.append(x)
            elif i in self.skip_indices:
                skip = saved.pop()
                # The 2x upsampling only reproduces the skip's size when every pooled
                # size was even; otherwise resize to the skip before adding.
                if x.shape[-2:] != skip.shape[-2:]:
                    x = nn.functional.interpolate(x, size=skip.shape[-2:])
                # Out-of-place add, so the layer output is not modified in place.
                x = x + skip
        return x

In [24]:
def test():
    """Smoke test: push a random 2x3x224x224 batch through a default ``FCN``.

    Prints the shape of the network output; nothing is returned.
    """
    batch = torch.randn(2, 3, 224, 224).to(device)
    model = FCN().to(device)
    print(model(batch).shape)

In [None]:
# Run the FCN smoke test; it prints the shape of the network output.
test()