<a href="https://colab.research.google.com/github/ShkarupaDC/image_inpainting/blob/master/Image_Inpainting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [26]:
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
from collections import namedtuple

In [13]:
class PConv2d(nn.Conv2d):
  """2-D partial convolution: a convolution re-normalised by a validity mask.

  The input is multiplied by ``mask`` before the convolution, and every
  output position is rescaled by ``window / (sum of mask values under the
  kernel)``. The bias is excluded from that rescaling. Output positions
  whose window contains no valid input are set to zero.

  Accepts the same arguments as ``nn.Conv2d`` plus:

  Args:
    return_mask: keyword-only; when true ``forward`` returns
      ``(output, updated_mask)`` instead of just ``output``.
      Defaults to ``False``.
  """

  def __init__(self, *args, return_mask=False, **kwargs):
    super().__init__(*args, **kwargs)
    self.return_mask = return_mask

    # Number of mask entries summed by the all-ones kernel for one output
    # position: every input channel times the kernel area. The output
    # channel count is not part of the sum, so it must not be counted here.
    self.window = (self.in_channels *
          self.kernel_size[0] * self.kernel_size[1])

    # Registered as a non-persistent buffer so it follows the module across
    # .to()/.cuda() calls without adding a key to the state dict.
    self.register_buffer("mask_kernel", torch.ones(self.out_channels,
          self.in_channels, *self.kernel_size), persistent=False)

  def forward(self, x, mask=None):
    """Apply the partial convolution.

    Args:
      x: input tensor accepted by ``nn.Conv2d`` with ``in_channels`` channels.
      mask: tensor of the same shape as ``x``; non-zero marks valid
        entries. ``None`` is treated as "everything valid".

    Returns:
      The convolved tensor, or ``(convolved, updated_mask)`` when the layer
      was built with ``return_mask=True``. ``updated_mask`` has the shape of
      the output and is clamped to ``[0, 1]``.
    """
    if mask is None:
      mask = torch.ones_like(x)
    else:
      # Lets boolean / integer masks be convolved together with x.
      mask = mask.to(dtype=x.dtype)

    with torch.no_grad():
      kernel = self.mask_kernel.to(device=mask.device, dtype=mask.dtype)

      updated_mask = F.conv2d(mask, kernel, None,
          self.stride, self.padding, self.dilation)

      mask_ratio = self.window / (updated_mask + 1e-6)

      updated_mask = torch.clamp(updated_mask, 0, 1)
      # Zero the ratio where the window saw no valid entry at all.
      mask_ratio = torch.mul(updated_mask, mask_ratio)

    x = super().forward(torch.mul(x, mask))

    if self.bias is not None:
      bias_view = self.bias.view(1, self.out_channels, 1, 1)
      # super().forward already added the bias: take it out before the
      # rescale, add it back once, then blank fully-masked positions.
      x = torch.mul(x - bias_view, mask_ratio) + bias_view
      x = torch.mul(x, updated_mask)
    else:
      x = torch.mul(x, mask_ratio)

    if self.return_mask:
      return x, updated_mask

    return x

In [14]:
class EncoderLayer(nn.Module):
  """Encoder stage: partial convolution, optional batch norm, activation.

  Args:
    in_channels, out_channels, kernel_size, stride, padding: forwarded
      positionally to ``PConv2d``.
    bn: a ``BatchNorm2d`` over ``out_channels`` is attached only when this
      is exactly ``True``.
    activation: callable applied to the features last. The default instance
      is created once at definition time and shared by every layer that
      does not pass its own.
  """

  def __init__(self, in_channels, out_channels, kernel_size,
      stride=1, padding=0, bn=True, activation=nn.ReLU()):
    super().__init__()

    conv_args = (in_channels, out_channels, kernel_size, stride, padding)
    self.pconv = PConv2d(*conv_args, return_mask=True)

    # The attribute is created only on demand; forward() probes for it.
    if bn is True:
      self.bn = nn.BatchNorm2d(out_channels)

    self.activation = activation

  def forward(self, x, mask):
    """Return ``(activated_features, updated_mask)`` for ``x`` and ``mask``."""
    features, new_mask = self.pconv(x, mask)

    if hasattr(self, "bn"):
      features = self.bn(features)

    return self.activation(features), new_mask

In [139]:
class DecoderLayer(nn.Module):
  """Decoder stage: 2x upsample, concatenate the skip, partial convolution.

  The partial convolution is built with ``in_channels + out_channels``
  input channels, i.e. the channel count after the concatenation in
  ``forward`` (assumes ``x`` carries ``in_channels`` channels and the skip
  tensors carry ``out_channels`` channels; confirm against the caller).

  Args:
    in_channels, out_channels, kernel_size, stride, padding: used to build
      the ``PConv2d``.
    bn: a ``BatchNorm2d`` over ``out_channels`` is attached only when this
      is exactly ``True``.
    activation: callable applied to the features last. The default instance
      is created once at definition time and shared between layers.
  """

  def __init__(self, in_channels, out_channels, kernel_size,
    stride=1, padding=0, bn=True, activation=nn.LeakyReLU(0.2)):
    super().__init__()

    merged_channels = in_channels + out_channels
    self.pconv = PConv2d(merged_channels, out_channels,
      kernel_size, stride, padding, return_mask=True)

    # The attribute is created only on demand; forward() probes for it.
    if bn is True:
      self.bn = nn.BatchNorm2d(out_channels)

    self.activation = activation

  def forward(self, x, mask, encoded_x, encoded_mask):
    """Merge ``(x, mask)`` with the skip pair and return ``(features, mask)``."""
    upsample = dict(scale_factor=2, mode="nearest")

    # Features and mask get the same treatment: upsample, then stack the
    # encoder's tensors behind them along the channel axis.
    merged_x = torch.cat([F.interpolate(x, **upsample), encoded_x], dim=1)
    merged_mask = torch.cat(
      [F.interpolate(mask, **upsample), encoded_mask], dim=1)

    features, new_mask = self.pconv(merged_x, merged_mask)

    if hasattr(self, "bn"):
      features = self.bn(features)

    return self.activation(features), new_mask


In [140]:
# Constructor arguments for one encoder / decoder layer, in positional order.
UNetBlock = namedtuple(
  "UNetBlock",
  ["in_channels", "out_channels", "kernel_size", "stride", "padding", "bn"],
)

# Encoder: eight stride-2 stages; only the first one has no batch norm.
EncoderBlock = [
  UNetBlock(*spec) for spec in (
    (3, 64, 7, 2, 3, False),
    (64, 128, 5, 2, 2, True),
    (128, 256, 3, 2, 1, True),
    (256, 512, 3, 2, 1, True),
    (512, 512, 3, 2, 1, True),
    (512, 512, 3, 2, 1, True),
    (512, 512, 3, 2, 1, True),
    (512, 512, 3, 2, 1, True),
  )
]

# Decoder: eight stride-1 stages; only the last one has no batch norm.
DecoderBlock = [
  UNetBlock(*spec) for spec in (
    (512, 512, 3, 1, 1, True),
    (512, 512, 3, 1, 1, True),
    (512, 512, 3, 1, 1, True),
    (512, 512, 3, 1, 1, True),
    (512, 256, 3, 1, 1, True),
    (256, 128, 3, 1, 1, True),
    (128, 64, 3, 1, 1, True),
    (64, 3, 3, 1, 1, False),
  )
]

In [141]:
class UNetModule(nn.Module):
  """U-Net built from paired ``EncoderLayer`` / ``DecoderLayer`` stages.

  Args:
    e_blocks: iterable of argument tuples, each unpacked into
      ``EncoderLayer``.
    d_blocks: iterable of argument tuples, each unpacked into
      ``DecoderLayer``. The two iterables are zipped, so surplus entries in
      the longer one are ignored.
  """

  def __init__(self, e_blocks, d_blocks):
    super().__init__()

    encoders, decoders = [], []

    # Layers are created pairwise (encoder, decoder, encoder, ...) rather
    # than all encoders first.
    for enc_spec, dec_spec in zip(e_blocks, d_blocks):
      encoders.append(EncoderLayer(*enc_spec))
      decoders.append(DecoderLayer(*dec_spec))

    self.e_layers = nn.ModuleList(encoders)
    self.d_layers = nn.ModuleList(decoders)

  def forward(self, x, mask):
    """Run the encoder, then the decoder with skips; return ``(x, mask)``."""
    # The raw input pair is the skip connection of the last decoder stage.
    skips = [(x, mask)]

    for encode in self.e_layers:
      x, mask = encode(x, mask)
      skips.append((x, mask))

    # The final entry is the bottleneck itself, so it is left out and the
    # remaining pairs are consumed deepest-first.
    for decode, (skip_x, skip_mask) in zip(self.d_layers,
        reversed(skips[:-1])):
      x, mask = decode(x, mask, skip_x, skip_mask)

    return x, mask


In [142]:
# Instantiate the network from the block tables defined above.
unet = UNetModule(e_blocks=EncoderBlock, d_blocks=DecoderBlock)

In [143]:
# Smoke test on a random image with an all-valid mask. torch.rand is used
# because torch.Tensor(1, 3, 512, 512) only allocates uninitialised memory,
# which makes the input (and the printed output) arbitrary.
unet(torch.rand(1, 3, 512, 512), torch.ones((1, 3, 512, 512)))

(tensor([[[[ 8.6368e-01,  2.1252e+00,  1.3758e+00,  ...,  1.7904e+00,
             2.1906e+00,  2.5858e+00],
           [ 1.0734e-02, -3.9869e-03, -1.2115e-01,  ..., -7.0984e-02,
            -6.9399e-02,  2.7649e-02],
           [-9.7896e-02, -3.8180e-02, -1.5243e-01,  ..., -4.9820e-02,
            -4.3078e-02,  2.3399e-01],
           ...,
           [-6.1077e+00, -9.7851e+00, -1.7425e+01,  ..., -9.2996e+00,
            -8.5264e+00, -9.2138e+00],
           [-4.9272e+00, -1.4002e+01, -1.6254e+01,  ..., -1.0073e+01,
            -6.9511e+00, -1.6341e+01],
           [-1.0363e+01, -2.3281e+01, -2.2901e+01,  ..., -1.3245e+01,
            -1.6378e+01, -1.2668e+01]],
 
          [[-4.9273e-01, -2.3594e-01, -2.0951e-01,  ..., -2.3786e-01,
            -1.8807e-01,  9.2741e-01],
           [-5.8254e-01, -1.7978e-01, -1.7195e-01,  ..., -1.5349e-01,
            -1.2689e-01,  2.7418e-01],
           [-4.6452e-01, -2.2543e-01, -1.7828e-01,  ..., -1.0264e-01,
            -7.2852e-02,  5.1800e-01],
