In [None]:
import os
import PIL
import torch
import torchvision
import torchsummary
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm

# Setting the device

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

# Custom Dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
  """Paired image/mask dataset read from two directories.

  Every entry of ``image_dir`` is treated as an image. Its mask is looked up
  in ``mask_dir`` under the same file name with ``'.jpg'`` replaced by
  ``'_mask.gif'`` (a name without ``'.jpg'`` is therefore looked up unchanged).

  Args:
    image_dir: Directory containing the input images.
    mask_dir: Directory containing the matching masks.
    image_size: ``(height, width)`` that both image and mask are resized to.
  """

  def __init__(self, image_dir, mask_dir, image_size=(256, 256)):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.image_size = image_size
    # os.listdir returns entries in arbitrary order; sorting makes an index
    # refer to the same file on every run and on every machine.
    self.images = sorted(os.listdir(image_dir))

  def __len__(self):
    """Return the number of entries found in ``image_dir``."""
    return len(self.images)

  def _load(self, path):
    """Read ``path``, resize it and return a ``(1, C, H, W)`` tensor on ``device``."""
    # PIL opens files lazily; the context manager guarantees the file handle
    # is released once the pixels have been converted to a tensor.
    with PIL.Image.open(path) as picture:
      resized = torchvision.transforms.Resize(self.image_size)(picture)
      tensor = torchvision.transforms.ToTensor()(resized)
    return tensor.unsqueeze(0).to(device)

  def __getitem__(self, idx):
    """Return the ``(image, mask)`` tensor pair for entry ``idx``.

    Both tensors already carry a leading singleton dimension and live on the
    module-level ``device``, so a single sample can be fed to a model directly.
    NOTE(review): because of that leading dimension, default DataLoader
    collation stacks samples into 5-D ``(B, 1, C, H, W)`` batches - confirm
    this is what the consumer expects.
    """
    file_name = self.images[idx]
    image_path = os.path.join(self.image_dir, file_name)
    mask_path = os.path.join(self.mask_dir, file_name.replace('.jpg', '_mask.gif'))
    return self._load(image_path), self._load(mask_path)

# Partial Convolution

In [None]:
class PartialConv2d(torch.nn.Module):
  """Mask-aware 2-D convolution that also returns an updated mask.

  Two convolutions with identical geometry are held:
    * ``input_conv`` - learnable, Kaiming-normal initialised, applied to ``X * M``.
    * ``mask_conv``  - bias-free, all weights fixed to 1 and frozen, applied to
      ``M`` to obtain the mask sum under every kernel window.

  Args:
    in_channels: Channels of both the input and the mask.
    out_channels: Channels of the produced output and mask.
    kernerl_size: Kernel size forwarded to ``torch.nn.Conv2d`` (the spelling
      is part of the existing signature).
    stride: Convolution stride.
    padding: Convolution padding.
  """

  def __init__(self, in_channels, out_channels, kernerl_size, stride, padding):
    super().__init__()
    geometry = (in_channels, out_channels, kernerl_size, stride, padding)
    self.input_conv = torch.nn.Conv2d(*geometry)
    self.mask_conv = torch.nn.Conv2d(*geometry, bias=False)
    torch.nn.init.kaiming_normal_(self.input_conv.weight, a=0, mode="fan_in")
    torch.nn.init.constant_(self.mask_conv.weight, 1.0)
    # The mask kernel only counts mask entries, so it must never be trained.
    # Its weight is the layer's single parameter because bias is disabled.
    self.mask_conv.weight.requires_grad = False

  def forward(self, X, M):
    """Apply the partial convolution.

    Args:
      X: Input tensor.
      M: Mask tensor, multiplied element-wise with ``X``.

    Returns:
      ``(output, new_mask)``. Where the window's mask sum is zero, both are 0.
      Elsewhere ``new_mask`` is 1 and
      ``output = (conv(X * M) - bias) / mask_sum + bias``.
    """
    features = self.input_conv(X * M)
    window_sum = self.mask_conv(M)
    bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(features)

    # Windows that saw no mask entry at all.
    no_support = window_sum == 0
    # Use 1 as a stand-in divisor there; those positions are zeroed below.
    divisor = window_sum.masked_fill(no_support, 1.0)

    rescaled = (features - bias) / divisor + bias
    rescaled = rescaled.masked_fill(no_support, 0.0)

    updated_mask = torch.ones_like(rescaled).masked_fill(no_support, 0.0)
    return rescaled, updated_mask

class Conv(torch.nn.Module):
  """Partial convolution followed by optional batch norm and activation.

  Args:
    in_channels, out_channels, kernel_size, stride, padding: Forwarded to
      ``PartialConv2d``.
    bn: When truthy, a ``BatchNorm2d(out_channels)`` follows the convolution.
    act: ``'relu'`` adds ``ReLU``, ``'leaky_relu'`` adds ``LeakyReLU(0.2)``;
      any other value adds no activation.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bn, act):
    super().__init__()
    stages = [PartialConv2d(in_channels, out_channels, kernel_size, stride, padding)]
    if bn:
      stages.append(torch.nn.BatchNorm2d(out_channels))
    if act == 'relu':
      stages.append(torch.nn.ReLU())
    elif act == 'leaky_relu':
      stages.append(torch.nn.LeakyReLU(0.2))
    self.layers = torch.nn.ModuleList(stages)

  def forward(self, X, M):
    """Run every stage in order and return ``(X, M)``.

    Only the partial convolution consumes and updates the mask; the remaining
    stages transform ``X`` alone.
    """
    for stage in self.layers:
      if not isinstance(stage, PartialConv2d):
        X = stage(X)
        continue
      X, M = stage(X, M)
    return X, M

# UNet

In [None]:
class UNet(torch.nn.Module):
  """U-Net built from partial-convolution ``Conv`` blocks.

  The encoder is a chain of stride-2 ``Conv`` blocks with ReLU; every block
  except the first uses batch norm. The decoder is a chain of stride-1 ``Conv``
  blocks with LeakyReLU; every block except the last uses batch norm. Before
  each decoder block the running tensors are upsampled by 2 and concatenated
  with a skip connection. The skips are the encoder outputs in reverse order
  (deepest one excluded), followed by the network input itself.

  Args:
    in_channels: Channels of the input image and of the input mask.
    encoder_kernels: Kernel size per encoder stage.
    encoder_channels: Output channels per encoder stage.
    decoder_kernels: Kernel size per decoder stage.
    decoder_channels: Output channels per decoder stage.

  NOTE(review): presumably the spatial size must be divisible by
  ``2 ** len(encoder_kernels)`` so that upsampled tensors match their skips -
  confirm against the data actually used.
  """

  def __init__(self, in_channels=3,
               encoder_kernels=[7,5,5,3,3,3,3,3], encoder_channels=[64,128,256,512,512,512,512,512],
               decoder_kernels=[3,3,3,3,3,3,3,3], decoder_channels=[512,512,512,512,256,128,64,3]):
    super().__init__()
    # Record the real input width instead of a hard-coded 3.
    self.in_channels = in_channels
    self.downs = torch.nn.ModuleList()
    self.ups = torch.nn.ModuleList()

    channels = in_channels
    for num, (out_channels, kernel_size) in enumerate(zip(encoder_channels, encoder_kernels)):
      use_bn = num != 0
      self.downs.append(Conv(channels, out_channels, kernel_size, 2, kernel_size//2, use_bn, 'relu'))
      channels = out_channels

    # Channel width of the skip concatenated in front of each decoder stage.
    # The last skip is the raw network input, so its width is ``in_channels``.
    # ``list(...)`` builds a fresh list, so tuple arguments work and the
    # caller's (or the default) sequence is never mutated.
    skip_channels = list(encoder_channels[:-1][::-1]) + [in_channels]

    last_stage = len(decoder_kernels) - 1
    for num, (out_channels, kernel_size) in enumerate(zip(decoder_channels, decoder_kernels)):
      use_bn = num != last_stage
      stage_in = channels + skip_channels[num]
      self.ups.append(Conv(stage_in, out_channels, kernel_size, 1, kernel_size//2, use_bn, 'leaky_relu'))
      channels = out_channels

  def forward(self, X, M):
    """Run the network.

    Args:
      X: Input image tensor.
      M: Mask tensor with the same shape as ``X``.

    Returns:
      ``(X, M)`` as produced by the last decoder stage.
    """
    # Kept as locals rather than attributes so the module does not keep the
    # most recent batch alive between calls.
    network_input = (X, M)

    skips = []
    for down in self.downs:
      X, M = down(X, M)
      skips.append((X, M))

    # Drop the deepest output (it is the decoder's starting point), reverse the
    # rest so they line up with the decoder, and finish with the raw input.
    skips = skips[:-1][::-1]
    skips.append(network_input)

    for idx, up in enumerate(self.ups):
      X = F.interpolate(X, scale_factor=2)
      M = F.interpolate(M, scale_factor=2)
      skip_X, skip_M = skips[idx]
      X = torch.cat((skip_X, X), dim=1)
      M = torch.cat((skip_M, M), dim=1)
      X, M = up(X, M)

    return X, M

# Loss

In [None]:
class Loss:
  """Weighted sum of valid, hole, perceptual, style and total-variation terms.

  Args:
    I_out: Network output.
    I_gt: Ground-truth image.
    I_in: Input image.
    M: Mask; ``I_comp = I_in * M + I_out * (1 - M)``, so ``I_in`` is used where
      the mask is 1 and ``I_out`` where it is 0.
    lambdas: Weights in the order ``[valid, hole, perceptual, style, tv]``.
    layer_nums: Indices into the truncated ``vgg16.features`` whose outputs are
      collected as features (the defaults are its first three max-pool layers).
  """

  def __init__(self, I_out, I_gt, I_in, M, lambdas=[1, 6, 0.05, 120, 0.1], layer_nums=[4, 9, 16]):
    self.I_out = I_out
    self.I_gt = I_gt
    self.I_in = I_in
    # The mask has to be stored before the composite image that reads it.
    self.M = M
    self.I_comp = self.I_in*self.M + self.I_out*(1-self.M)
    self.model = torchvision.models.vgg16(pretrained=True).eval().features[:17].to(device)
    # The feature extractor is fixed: gradients are needed with respect to the
    # images only, never with respect to the VGG weights.
    for param in self.model.parameters():
      param.requires_grad_(False)
    self.lambdas = lambdas
    self.layer_nums = layer_nums
    self.l1loss = torch.nn.L1Loss()

  def generate_features(self, x):
    """Return the outputs of the layers listed in ``layer_nums``, in order."""
    features = []
    for num, layer in enumerate(self.model):
      x = layer(x)
      if num in self.layer_nums:
        features.append(x)
    return features

  def gram_matrix(self, feature_matrix):
    """Return the normalised Gram matrix of a ``(B, C, H, W)`` tensor.

    The result has shape ``(B, C, C)`` and is divided by ``C * H * W``.
    """
    B, C, H, W = feature_matrix.size()
    feature_matrix = feature_matrix.view(B, C, H * W)
    feature_matrix_t = feature_matrix.transpose(1, 2)

    # (B, C, H * W) x (B, H * W, C) ==> (B, C, C)
    return torch.bmm(feature_matrix, feature_matrix_t) / (C*H*W)

  def total_loss(self):
    """Compute and return the weighted total loss as a scalar tensor."""
    I_comp_features = self.generate_features(self.I_comp)
    I_out_features  = self.generate_features(self.I_out)
    I_gt_features   = self.generate_features(self.I_gt)

    L_valid = self.l1loss(self.M*self.I_out, self.M*self.I_gt)
    L_hole  = self.l1loss((1-self.M)*self.I_out, (1-self.M)*self.I_gt)

    # Perceptual and style terms are accumulated layer by layer: L1Loss works
    # on tensors, not on lists, and the Gram matrices are taken of the VGG
    # features rather than of the raw images.
    L_perceptual = 0
    L_style = 0
    for f_out, f_comp, f_gt in zip(I_out_features, I_comp_features, I_gt_features):
      L_perceptual = L_perceptual + self.l1loss(f_out, f_gt) + self.l1loss(f_comp, f_gt)
      gram_gt = self.gram_matrix(f_gt)
      L_style = L_style + self.l1loss(self.gram_matrix(f_out), gram_gt) \
                        + self.l1loss(self.gram_matrix(f_comp), gram_gt)

    # Total variation of the composite: neighbouring columns plus neighbouring rows.
    L_tv = self.l1loss(self.I_comp[:, :, :, :-1], self.I_comp[:, :, :, 1:]) \
         + self.l1loss(self.I_comp[:, :, :-1, :], self.I_comp[:, :, 1:, :])

    return self.lambdas[0]*L_valid + self.lambdas[1]*L_hole + self.lambdas[2]*L_perceptual \
         + self.lambdas[3]*L_style + self.lambdas[4]*L_tv

In [None]:
def save_chkpt(model, optim, filename='drive/MyDrive/UNet.pth.tar'):
  """Write the model and optimizer state dicts to ``filename``.

  The file holds a dict with the keys ``'model'`` and ``'optim'``.
  """
  state = dict(model=model.state_dict(), optim=optim.state_dict())
  torch.save(state, filename)

def load_chkpt(model, optim, filename='drive/MyDrive/UNet.pth.tar', map_location='cpu'):
  """Restore model and optimizer state from a checkpoint file.

  Args:
    model: Module whose ``load_state_dict`` receives ``chkpt['model']``.
    optim: Optimizer whose ``load_state_dict`` receives ``chkpt['optim']``.
    filename: Checkpoint path; expected to hold the keys ``'model'`` and
      ``'optim'``.
    map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so that a
      checkpoint written on a GPU can still be opened on a CPU-only runtime.
  """
  # NOTE: torch.load unpickles the file - only open checkpoints you trust.
  chkpt = torch.load(filename, map_location=map_location)
  model.load_state_dict(chkpt['model'])
  optim.load_state_dict(chkpt['optim'])

In [None]:
# Put the network on the same device the dataset and the training loop place
# their tensors on; otherwise inputs and weights end up on different devices.
model = UNet().to(device)

In [None]:
# Mean training loss of every finished epoch, appended to by train().
epoch_losses = []
def train(x_train, y_train, model, optim, loss_fn, epochs, batch_size=16):
  """Train ``model`` and record the mean loss of each epoch in ``epoch_losses``.

  Args:
    x_train: Image directory handed to ``Dataset``.
    y_train: Mask directory handed to ``Dataset``.
    model: Module called as ``model(x)`` on every batch.
    optim: Optimizer stepping the model's parameters.
    loss_fn: Callable invoked as ``loss_fn(y_hat, y)``; must return a scalar tensor.
    epochs: Number of passes over the data.
    batch_size: Samples per batch (shuffled every epoch).

  A checkpoint is written with ``save_chkpt`` after every completed epoch,
  including the last one.
  """
  train_set = Dataset(x_train, y_train)
  train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
  for epoch in range(epochs):
    batch_losses = []
    loop = tqdm(train_loader, position=0, leave=True)
    for x, y in loop:
      x, y = x.to(device), y.to(device)
      # NOTE(review): UNet.forward in this notebook takes (X, M), and
      # Dataset.__getitem__ returns tensors with a leading singleton dimension,
      # so batches arrive as (B, 1, C, H, W). Confirm which model and input
      # layout this loop is meant to drive before using it with UNet.
      y_hat = model(x)
      loss = loss_fn(y_hat, y)
      optim.zero_grad()
      loss.backward()
      optim.step()

      # Report a plain float so the progress bar shows a number, not a tensor repr.
      loss_value = loss.item()
      batch_losses.append(loss_value)
      loop.set_postfix(loss=loss_value)

    epoch_losses.append(sum(batch_losses)/len(batch_losses))
    # Saving at the end of the epoch means the final epoch is checkpointed too.
    save_chkpt(model, optim)

In [None]:
# Print a layer-by-layer summary of the model for two inputs of size
# (3, 256, 256) each - one per argument of UNet.forward (X and M).
torchsummary.summary(model,[(3,256,256),(3,256,256)])

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           9,472
            Conv2d-2         [-1, 64, 128, 128]           9,408
     PartialConv2d-3  [[-1, 64, 128, 128], [-1, 64, 128, 128]]               0
              ReLU-4         [-1, 64, 128, 128]               0
              Conv-5  [[-1, 64, 128, 128], [-1, 64, 128, 128]]               0
            Conv2d-6          [-1, 128, 64, 64]         204,928
            Conv2d-7          [-1, 128, 64, 64]         204,800
     PartialConv2d-8  [[-1, 128, 64, 64], [-1, 128, 64, 64]]               0
       BatchNorm2d-9          [-1, 128, 64, 64]             256
             ReLU-10          [-1, 128, 64, 64]               0
             Conv-11  [[-1, 128, 64, 64], [-1, 128, 64, 64]]               0
           Conv2d-12          [-1, 256, 32, 32]         819,456
           Conv2d-13          [-1, 256, 32, 32]