In [39]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [40]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

from layers import GatedConv, ResizeGatedConv
from torchsummary import summary
import ssl

ssl._create_default_https_context = ssl._create_unverified_context

In [41]:
def get_images(batch_size):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ])

    cifar10_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    data_loader = torch.utils.data.DataLoader(cifar10_train, batch_size=batch_size, shuffle=True)

    return data_loader


def get_mask(image_size, square_size):
    mask = np.zeros(image_size, dtype=np.uint8)
    start = (image_size[0] - square_size) // 2
    end = start + square_size
    mask[start:end, start:end] = 1
    mask = np.asarray(mask, np.float32)
    mask = torch.from_numpy(mask)
    return mask

In [42]:
class FreeFormImageInpaint(nn.Module):
    def __init__(self, in_channels):
        super(FreeFormImageInpaint, self).__init__()

        self.coarse_network = nn.Sequential(
            GatedConv(in_channels, out_channels=32, kernel_size=5, stride=1, padding=2), # batch_size x 32 x 256 x 256
            GatedConv(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), # batch_size x 64 x 128 x 128
            GatedConv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), # batch_size x 64 x 128 x 128
            GatedConv(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1), # batch_size x 128 x 64 x 64
            GatedConv(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # batch_size x 128 x 64 x 64
            GatedConv(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # batch_size x 128 x 64 x 64
            GatedConv(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2), # batch_size x 128 x 64 x 64
            GatedConv(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4), # batch_size x 128 x 64 x 64
            GatedConv(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=8, dilation=8), # batch_size x 128 x 64 x 64
            GatedConv(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=16, dilation=16), # batch_size x 128 x 64 x 64
            GatedConv(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # batch_size x 128 x 64 x 64
            GatedConv(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # batch_size x 128 x 64 x 64
            ResizeGatedConv(in_channels=128, out_channels=64, padding=1, scale_factor=2), # batch_size x 64 x 128 x 128
            GatedConv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), # batch_size x 64 x 128 x 128
            ResizeGatedConv(in_channels=64, out_channels=32, padding=1, scale_factor=2), # batch_size x 32 x 256 x 256
            GatedConv(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1), # batch_size x 16 x 256 x 256
            GatedConv(in_channels=16, out_channels=3, kernel_size=3, stride=1, padding=1, feature_act=None) # batch_size x 3 x 256 x 256
        )


    def forward(self, x, masks):
        """
        dim of x: batch_size x 256 x 256 x channels
        dim of mask: batch_size x 256 x 256
        """
        #print("shape of images (B x C x H x W):", x.shape)
        #print("shape of masks (B x H x W):", masks.shape)
        # TODO: normalize images and pair images with corresponding masks as input
        # input will contain masked images
        #x = x.permute(0, 3, 1, 2) # batch_size x channels x 256 x 256
        masks = masks.unsqueeze(1) # batch_size x 1 x 256 x 256
        masked_imgs = x * (1 - masks)
        #print(masks.shape)
        #print(masked_imgs.shape)
        input = torch.cat([masked_imgs, masks], dim=1) # batch_size x (channels + 1) x 256 x 256
        #print("shape of input into coarse network:", input.shape)

        # coarse network
        coarse_out = self.coarse_network(input)
        # clip output so values are between -1 and 1
        coarse_clip = torch.clamp(coarse_out, -1, 1)

        return coarse_clip

    def loss_function(self, x_hat, x, masks, alpha):
        '''
        dim of x_hat & x: batch_size x 3 x 256 x 256
        dim of masks: batch_size x 256 x 256
        '''

        # TODO: convert x/x_hat to just masked and unmasked portion
        #print("shape of x:", x.shape)
        #print("shape of xhat:", x_hat.shape)
        masks = masks.unsqueeze(1) # batch_size x 1 x 256 x 256
        #print("shape of masks:", masks.shape)
        unmasked = x * masks
        unmasked_hat = x_hat * masks
        masked = x * (1 - masks)
        masked_hat = x_hat * (1 - masks)

        mask_bit_ratio = torch.mean(torch.mean(masks, -1), -1) #take the ratio of masked to unmasked bits
        mask_bit_ratio = mask_bit_ratio.unsqueeze(-1)
        mask_bit_ratio = mask_bit_ratio.unsqueeze(-1)
        #print(mask_bit_ratio.shape)
        bit_mask_ratio = torch.mean(torch.mean(1-masks, -1), -1) #take the ratio of unmasked to masked bits
        bit_mask_ratio = bit_mask_ratio.unsqueeze(-1)
        bit_mask_ratio = bit_mask_ratio.unsqueeze(-1)
        masked_loss = alpha * torch.mean(torch.abs(masked - masked_hat) / mask_bit_ratio)
        unmasked_loss = alpha * torch.mean(torch.abs(unmasked - unmasked_hat) / bit_mask_ratio)
        loss = masked_loss + unmasked_loss
        return loss

In [43]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [44]:
n_epochs = 20
alpha = 0.02
learning_rate = 0.001
batch_size = 256

In [45]:
image_loader = get_images(batch_size)
image_size = (32, 32)
square_size = 10
binary_mask = get_mask(image_size, square_size)
# reshape binary mask to add batch_size dimension
binary_mask = binary_mask.unsqueeze(0)
binary_mask = binary_mask.expand(batch_size, -1, -1)
binary_mask = binary_mask.to(device)

print(binary_mask.shape)

Files already downloaded and verified
torch.Size([256, 32, 32])


In [46]:
from torchsummary import summary
inpaint = FreeFormImageInpaint(in_channels=4).to(device)
summary(inpaint, [(3, 32, 32), (32, 32)])

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]           3,232
            Conv2d-2           [-1, 32, 32, 32]           3,232
           Sigmoid-3           [-1, 32, 32, 32]               0
         LeakyReLU-4           [-1, 32, 32, 32]               0
         LeakyReLU-5           [-1, 32, 32, 32]               0
         LeakyReLU-6           [-1, 32, 32, 32]               0
         LeakyReLU-7           [-1, 32, 32, 32]               0
         LeakyReLU-8           [-1, 32, 32, 32]               0
         LeakyReLU-9           [-1, 32, 32, 32]               0
        LeakyReLU-10           [-1, 32, 32, 32]               0
        LeakyReLU-11           [-1, 32, 32, 32]               0
        LeakyReLU-12           [-1, 32, 32, 32]               0
        LeakyReLU-13           [-1, 32, 32, 32]               0
        LeakyReLU-14           [-1, 32,

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.

In [47]:
inpaint = FreeFormImageInpaint(in_channels=4).to(device)
optimizer = torch.optim.Adam(inpaint.parameters(), lr=learning_rate)

In [48]:
len(image_loader)

196

In [49]:
for epoch in range(n_epochs):
    train_loss = 0
    loss = 0
    for batch_idx, (data, _) in enumerate(image_loader):
        #print(batch_idx)
        if batch_idx == 195:
            break
        data = data.to(device)
        binary_mask = binary_mask.to(device)
        # clear the gradients of all optimized variables
        optimizer.zero_grad()
        # model forward
        x_hat = inpaint(data, binary_mask)
        #print("shape of output:", x_hat.shape)
        # compute the loss
        loss = inpaint.loss_function(x_hat, data, binary_mask, alpha)
        # model backward
        loss.backward()
        # update the model paramters
        optimizer.step()
        # update running training loss
        train_loss += loss
    train_loss = train_loss/len(image_loader)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))

Epoch: 0 	Training Loss: 0.024470
Epoch: 1 	Training Loss: 0.014601
Epoch: 2 	Training Loss: 0.012335
Epoch: 3 	Training Loss: 0.010665
Epoch: 4 	Training Loss: 0.009768
Epoch: 5 	Training Loss: 0.009265
Epoch: 6 	Training Loss: 0.008589
Epoch: 7 	Training Loss: 0.008274
Epoch: 8 	Training Loss: 0.007982
Epoch: 9 	Training Loss: 0.007741
Epoch: 10 	Training Loss: 0.007501
Epoch: 11 	Training Loss: 0.007263
Epoch: 12 	Training Loss: 0.007117
Epoch: 13 	Training Loss: 0.007022
Epoch: 14 	Training Loss: 0.006948
Epoch: 15 	Training Loss: 0.006930
Epoch: 16 	Training Loss: 0.006889
Epoch: 17 	Training Loss: 0.006776
Epoch: 18 	Training Loss: 0.006793
Epoch: 19 	Training Loss: 0.006731
