<a href="https://colab.research.google.com/github/ShkarupaDC/image_inpainting/blob/master/Image_Inpainting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [26]:
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
from collections import namedtuple

In [13]:
class PConv2d(nn.Conv2d):
  """Partial convolution layer built on top of ``nn.Conv2d``.

  The input is multiplied by a validity mask before the convolution and the
  result is re-normalised by ``window / valid_count``, where ``valid_count`` is
  the number of valid (mask == 1) entries under the kernel at each output
  position. Positions whose window holds no valid entry produce zero.

  Accepts every ``nn.Conv2d`` argument plus one keyword:

  Args:
    return_mask: when true, ``forward`` returns ``(output, updated_mask)``
      instead of just ``output``. Defaults to ``False``.

  Attributes:
    window: number of mask entries covered by one kernel application
      (``in_channels // groups * kernel_h * kernel_w``).
    mask_kernel: all-ones kernel used to count valid entries; registered as a
      non-persistent buffer so it follows ``.to(device)`` without adding a key
      to ``state_dict``.
  """

  def __init__(self, *args, **kwargs):
    # Remove our own option before nn.Conv2d sees the kwargs. Defaulting here
    # guarantees the attribute exists even when the caller omits it.
    return_mask = kwargs.pop("return_mask", False)

    super().__init__(*args, **kwargs)
    self.return_mask = return_mask

    kernel_h, kernel_w = self.kernel_size
    group_channels = self.in_channels // self.groups

    # Each output position sums the mask over one group's input channels and
    # the full kernel footprint; out_channels plays no part in that count.
    self.window = group_channels * kernel_h * kernel_w

    self.register_buffer(
        "mask_kernel",
        torch.ones(self.out_channels, group_channels, kernel_h, kernel_w),
        persistent=False)

  def forward(self, x, mask=None):
    """Apply the partial convolution.

    Args:
      x: input tensor accepted by ``nn.Conv2d``.
      mask: validity mask multiplied element-wise with ``x`` and convolved
        with ``mask_kernel``; ``None`` falls back to an ordinary convolution.

    Returns:
      ``output`` or, when ``return_mask`` is true, ``(output, updated_mask)``.
      ``updated_mask`` is ``None`` if no mask was supplied.
    """
    if mask is None:
      out = super().forward(x)
      return (out, None) if self.return_mask else out

    # The mask bookkeeping carries no learnable parameters.
    with torch.no_grad():
      valid_count = F.conv2d(mask, self.mask_kernel, None, self.stride,
          self.padding, self.dilation, self.groups)

      mask_ratio = self.window / (valid_count + 1e-6)

      updated_mask = torch.clamp(valid_count, 0, 1)
      # Zero the ratio where the window contained no valid entry at all.
      mask_ratio = torch.mul(updated_mask, mask_ratio)

    out = super().forward(torch.mul(x, mask))

    if self.bias is not None:
      bias_view = self.bias.view(1, self.out_channels, 1, 1)
      # nn.Conv2d already added the bias: take it out before re-normalising
      # so it is neither scaled nor counted twice, then add it back.
      out = torch.mul(out - bias_view, mask_ratio) + bias_view
      # Fully-masked windows must yield 0 rather than the bare bias.
      out = torch.mul(out, updated_mask)
    else:
      out = torch.mul(out, mask_ratio)

    if self.return_mask:
      return out, updated_mask
    return out

In [14]:
class EncoderLayer(nn.Module):
  """Encoder stage: partial convolution, optional batch norm, activation.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by the partial convolution.
    kernel_size: kernel size forwarded to ``PConv2d``.
    stride: stride forwarded to ``PConv2d``.
    padding: padding forwarded to ``PConv2d``.
    bn: batch normalisation is added only when this is exactly ``True``.
    activation: callable applied to the (optionally normalised) features.
  """

  def __init__(self, in_channels, out_channels, kernel_size,
      stride=1, padding=0, bn=True, activation=nn.ReLU()):
    super().__init__()

    conv_args = (in_channels, out_channels, kernel_size, stride, padding)
    self.pconv = PConv2d(*conv_args, return_mask=True)

    # The attribute is deliberately left undefined when normalisation is
    # disabled; forward() tests for its presence.
    if bn is True:
      self.bn = nn.BatchNorm2d(out_channels)

    self.activation = activation

  def forward(self, x, mask):
    """Return ``(features, updated_mask)`` for the given input and mask."""
    features, new_mask = self.pconv(x, mask)

    if hasattr(self, "bn"):
      features = self.bn(features)

    return self.activation(features), new_mask

In [18]:
class DecoderLayer(nn.Module):
  """Decoder stage: upsample, concatenate the skip, partial conv, BN, activation.

  Args:
    in_channels: channels seen by the partial convolution, i.e. the channels
      of the upsampled input plus those of the skip connection.
    out_channels: channels produced by the partial convolution.
    kernel_size: kernel size forwarded to ``PConv2d``.
    stride: stride forwarded to ``PConv2d``.
    padding: padding forwarded to ``PConv2d``.
    bn: batch normalisation is added only when this is exactly ``True``.
    activation: callable applied to the (optionally normalised) features.
  """

  def __init__(self, in_channels, out_channels, kernel_size,
    stride=1, padding=0, bn=True, activation=nn.LeakyReLU(0.2)):
    # nn.Module must be initialised before any sub-module can be assigned.
    super().__init__()

    self.pconv = PConv2d(in_channels, out_channels,
      kernel_size, stride, padding, return_mask=True)

    if bn is True:
      # Normalisation runs on the convolution output, hence out_channels.
      self.bn = nn.BatchNorm2d(out_channels)

    self.activation = activation


  def forward(self, x, mask, encoded_x, encoded_mask):
    """Merge ``x``/``mask`` with an encoder skip and convolve the result.

    Args:
      x: features coming from the previous (coarser) stage.
      mask: mask accompanying ``x``.
      encoded_x: encoder features used as the skip connection.
      encoded_mask: mask accompanying ``encoded_x``.

    Returns:
      ``(features, updated_mask)`` as produced by the partial convolution,
      with batch norm (if enabled) and the activation applied to the features.
    """
    # Resize to the skip tensor's spatial size: this is the usual 2x nearest
    # upsampling, and still lines up when an encoder stage rounded an odd size.
    x = F.interpolate(x, size=encoded_x.shape[2:], mode="nearest")
    mask = F.interpolate(mask, size=encoded_mask.shape[2:], mode="nearest")

    # Skip connections are stacked along the channel axis.
    x = torch.cat([x, encoded_x], dim=1)
    mask = torch.cat([mask, encoded_mask], dim=1)

    x, mask = self.pconv(x, mask)

    if hasattr(self, "bn"): x = self.bn(x)
    x = self.activation(x)

    return x, mask


In [32]:
# One row per layer: the convolution arguments plus a batch-norm flag.
UNetBlock = namedtuple("UNetBlock", ["in_channels", "out_channels",
    "kernel_size", "stride", "padding", "dilation", "bn"])

# Every encoder stage uses stride 2.
EncoderBlock = [
  UNetBlock(3, 64, 7, 2, 3, 1, False), UNetBlock(64, 128, 5, 2, 2, 1, True),
  UNetBlock(128, 256, 3, 2, 1, 1, True), UNetBlock(256, 512, 3, 2, 1, 1, True),
  UNetBlock(512, 512, 3, 2, 1, 1, True), UNetBlock(512, 512, 3, 2, 1, 1, True),
  UNetBlock(512, 512, 3, 2, 1, 1, True), UNetBlock(512, 512, 3, 2, 1, 1, True)
]

# Decoder rows mirror the encoder: in_channels is the previous stage's output
# plus the matching encoder skip (the last row's skip is the 3-channel input).
# Stride is 1 because the decoder layer already upsamples before convolving;
# a strided convolution here would undo that and break the next skip concat.
DecoderBlock = [
  UNetBlock(1024, 512, 3, 1, 1, 1, True), UNetBlock(1024, 512, 3, 1, 1, 1, True),
  UNetBlock(1024, 512, 3, 1, 1, 1, True), UNetBlock(1024, 512, 3, 1, 1, 1, True),
  UNetBlock(768, 256, 3, 1, 1, 1, True), UNetBlock(384, 128, 3, 1, 1, 1, True),
  UNetBlock(192, 64, 3, 1, 1, 1, True), UNetBlock(67, 3, 3, 1, 1, 1, False)
]

In [20]:
class UNetModule(nn.Module):
  """U-Net assembled from paired encoder and decoder block descriptions.

  Args:
    e_blocks: sequence of 7-field blocks (in_channels, out_channels,
      kernel_size, stride, padding, dilation, bn) describing encoder layers.
    d_blocks: sequence of blocks with the same fields describing decoder
      layers. Blocks are paired with ``zip``; surplus blocks on either side
      are ignored.

  Raises:
    ValueError: if a block requests a dilation other than 1, which the
      encoder/decoder layers cannot be configured with.
  """

  def __init__(self, e_blocks, d_blocks):
    super().__init__()

    encoders, decoders = [], []

    for e_block, d_block in zip(e_blocks, d_blocks):
      encoders.append(EncoderLayer(**self._layer_kwargs(e_block)))
      decoders.append(DecoderLayer(**self._layer_kwargs(d_block)))

    self.e_layers = nn.ModuleList(encoders)
    self.d_layers = nn.ModuleList(decoders)

  @staticmethod
  def _layer_kwargs(block):
    """Translate a block description into layer keyword arguments."""
    (in_channels, out_channels, kernel_size,
        stride, padding, dilation, bn) = block

    # The layers have no dilation parameter. Passing the block positionally
    # would shift ``dilation`` into ``bn`` and ``bn`` into ``activation``, so
    # fields are bound by name and an unsupported dilation is rejected.
    if dilation != 1:
      raise ValueError(
          "dilation=%r is not supported by the U-Net layers" % (dilation,))

    return {
      "in_channels": in_channels, "out_channels": out_channels,
      "kernel_size": kernel_size, "stride": stride,
      "padding": padding, "bn": bool(bn),
    }

  def forward(self, x, mask):
    """Run the encoder, then the decoder with skip connections.

    Returns:
      ``(output, mask)`` from the last decoder layer.
    """
    # Candidate skips: the raw input followed by every encoder output.
    skips = [(x, mask)]

    for e_layer in self.e_layers:
      x, mask = e_layer(x, mask)
      skips.append((x, mask))

    # The deepest encoder output is the decoder's input, not a skip.
    skips.pop()

    # Walk back up: each decoder consumes the next-shallower encoder state,
    # finishing with the original input.
    for d_layer in self.d_layers:
      skip_x, skip_mask = skips.pop()
      x, mask = d_layer(x, mask, skip_x, skip_mask)

    return x, mask
