<a href="https://colab.research.google.com/github/TimofeyKulakov/NeuralNets/blob/master/Unet_with_unpooling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn


# Model

In [None]:
def test_module(module, size=(1, 3, 224, 224)):
  """Run ``module`` on a random tensor of shape ``size`` and print the output shape.

  Args:
    module: the ``nn.Module`` to smoke-test.
    size: shape of the random input drawn with ``torch.randn``.

  Returns:
    The output's ``torch.Size`` (the same value that is printed).
  """
  sample = torch.randn(size)
  # Call the module itself rather than ``module.forward`` so registered hooks
  # run. Use no_grad because only the shape is inspected, so no autograd
  # graph is needed.
  with torch.no_grad():
    output = module(sample)
  print(output.shape)
  return output.shape


# The name matches pytest's ``test_*`` pattern. Opt out of collection so pytest
# does not try to run this helper and fail on the missing ``module`` fixture.
test_module.__test__ = False

In [None]:
class DownBlock(nn.Module):
  """Encoder stage: two 3x3 conv + ReLU layers followed by 2x2 max pooling.

  The pooling indices are returned as well so that a decoder stage can invert
  the pooling with ``nn.MaxUnpool2d``.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    conv_in = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    conv_out = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
    self.block = nn.Sequential(conv_in, nn.ReLU(), conv_out, nn.ReLU())
    self.maxpool = nn.MaxPool2d(2, return_indices=True)

  def forward(self, x):
    """Return ``(pooled, features_before_pooling, pooling_indices)``."""
    features = self.block(x)
    pooled, indices = self.maxpool(features)
    return pooled, features, indices

In [None]:
class UpBlock(nn.Module):
  """Decoder stage: max-unpool, concatenate the skip tensor, two 3x3 conv + ReLU.

  The first convolution takes ``2 * in_channels`` channels, so the incoming
  tensor and the skip tensor are each expected to have ``in_channels`` channels.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.unpool = nn.MaxUnpool2d(2)

    layers = [
        nn.Conv2d(2 * in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
    ]
    self.convs = nn.Sequential(*layers)

  def forward(self, x, x_skipped, mask):
    """Unpool ``x`` with the pooling indices ``mask`` to ``x_skipped``'s shape,
    concatenate both along the channel axis and apply the convolutions."""
    upsampled = self.unpool(x, mask, output_size=x_skipped.shape)
    merged = torch.cat((upsampled, x_skipped), dim=1)
    return self.convs(merged)

In [None]:
class Unet(nn.Module):
  """U-Net whose decoder upsamples with max-unpooling (encoder pooling indices).

  Args:
    n_in_channels: number of channels of the input tensor.
    n_base_channels: width of the first encoder stage. Every following stage
      doubles it, except the bottleneck (last stage), which keeps the width
      of the stage before it.
    n_blocks: number of encoder stages, bottleneck included. Must be >= 2.

  ``forward`` returns a single-channel sigmoid map. Each decoder stage unpools
  to the exact shape of its skip tensor, so the output has the input's
  spatial size.

  Raises:
    ValueError: if ``n_blocks < 2``.
  """

  def __init__(self, n_in_channels, n_base_channels, n_blocks):
    super().__init__()

    if n_blocks < 2:
      # A single stage would need a fractional width for the bottleneck and
      # leaves no decoder to bring the width back to n_base_channels for
      # ``final_block``. Zero stages would break ``forward`` outright.
      raise ValueError(f"n_blocks must be >= 2, got {n_blocks}")

    # Widths of the encoder stages above the bottleneck:
    # n_base, 2*n_base, ..., 2**(n_blocks-2) * n_base.
    widths = [n_base_channels * 2**k for k in range(n_blocks - 1)]
    # The bottleneck repeats the deepest width so its output matches the
    # channel count of the skip tensor it is concatenated with.
    down_out = widths + [widths[-1]]
    down_in = [n_in_channels] + down_out[:-1]

    self.down = nn.ModuleList(
        [DownBlock(c_in, c_out) for c_in, c_out in zip(down_in, down_out)]
    )

    # Decoder stages go from deep to shallow. Each halves the width, except the
    # shallowest, which stays at n_base_channels.
    self.up = nn.ModuleList(
        [UpBlock(widths[k], widths[k - 1] if k > 0 else widths[0])
         for k in range(n_blocks - 2, -1, -1)]
    )

    self.final_block = nn.Sequential(
        nn.Conv2d(n_base_channels, 1, kernel_size=3, padding=1),
        nn.Sigmoid(),
    )

    self.initialize_weights()

  def forward(self, x):
    """Encode ``x``, decode with skip connections, return the sigmoid map."""
    skips = []
    indices = []
    out = x
    for block in self.down:
      out, before_pooling, mask = block(out)
      skips.append(before_pooling)
      indices.append(mask)

    # The bottleneck's pooled output (and its indices) are not used: decoding
    # starts from its pre-pooling features.
    out = skips[-1]

    # Pair decoder stages with encoder stages from deep to shallow, excluding
    # the bottleneck itself.
    for block, skip, mask in zip(self.up, reversed(skips[:-1]), reversed(indices[:-1])):
      out = block(out, skip, mask)

    return self.final_block(out)

  def initialize_weights(self):
    """Re-initialise every Conv2d: Kaiming-uniform weights, zero biases."""
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        # nonlinearity='relu' gives the same gain as the default
        # ('leaky_relu' with a=0), but states the intent explicitly.
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
        if m.bias is not None:
          nn.init.zeros_(m.bias)

In [None]:
# 3 input channels, 16 base channels, 5 encoder stages (the last is the bottleneck).
unet = Unet(n_in_channels=3, n_base_channels=16, n_blocks=5)

In [None]:
# Smoke test: a single 3x224x224 input should yield a (1, 1, 224, 224) map.
test_module(unet, size=(1, 3, 224, 224))

torch.Size([1, 1, 224, 224])
