In [None]:
!pip3 install pyprind

Collecting pyprind
  Downloading PyPrind-2.11.3-py2.py3-none-any.whl (8.4 kB)
Installing collected packages: pyprind
Successfully installed pyprind-2.11.3


In [None]:
# Mount Google Drive (Colab only) so the /content/drive/MyDrive paths used below (PATH, CHECKPOINT) resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Root folder of the datasets on the mounted Drive; CreateDataset reads images from PATH/<dataset>/<split>.
PATH = "/content/drive/MyDrive/Projects/Clubs/Analytics/Coord Projects/Model Zoo/Inpainting/datasets"

In [None]:
# One-time download of the Decathlon ImageNet archive; disabled (presumably the data already lives on Drive).
#!wget https://image-net.org/data/decathlon-1.0-data-imagenet.tar

In [None]:
# One-time extraction of the archive into the same directory as PATH; disabled (uncomment to re-extract).
#!tar -xvf "/content/decathlon-1.0-data-imagenet.tar" -C "/content/drive/MyDrive/Projects/Clubs/Analytics/Coord Projects/Model Zoo/Inpainting/datasets/"

---

In [None]:
from google.colab.patches import cv2_imshow

In [None]:
import torch
import torchvision
import cv2
import numpy as np
import os, glob


class CreateDataset(torch.utils.data.Dataset):
    """Image dataset listing ``*.jpg`` files under ``PATH/dataset/mode``.

    Files are collected from every immediate sub-directory of the split folder
    (``sub_folder=True``) or directly from the split folder (``sub_folder=False``).
    The list is shuffled with numpy's global RNG and truncated to a per-split size.

    Args:
        PATH: Root directory containing the datasets.
        dataset: Dataset folder name under ``PATH``.
        mode: ``'train'`` (up to 32768 images, normalized), ``'val'`` (up to 1024,
            normalized) or ``'test'`` (up to 32, not normalized).
        sub_folder: Whether the images live in sub-directories of the split folder.
        img_size: Side length of the square images returned by ``__getitem__``.

    Raises:
        ValueError: If ``mode`` is not one of the supported splits.
    """

    # split name -> (maximum number of images kept, apply ImageNet normalization?)
    _SPLITS = {
        'train': (16384*2, True),
        'val': (1024, True),
        'test': (32, False),
    }

    def __init__(self, PATH, dataset, mode='train', sub_folder=True, img_size=256):
        # Validate first: previously an unknown mode left self.size/self.transform unset.
        if mode not in self._SPLITS:
            raise ValueError("mode must be one of {}, got {!r}".format(sorted(self._SPLITS), mode))
        self.PATH = PATH
        self.dataset = dataset
        self.mode = mode
        # Honor the requested size (it used to be hard-coded to 256).
        self.img_size = img_size
        self.size, self.transform = self._SPLITS[mode]
        self.normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        split_dir = os.path.join(self.PATH, self.dataset, self.mode)
        if sub_folder:
            search_dirs = [os.path.join(split_dir, d) for d in self.find_sub_folders(split_dir)]
        else:
            search_dirs = [split_dir]
        # glob already returns "<search_dir>/<file>.jpg" paths. Build one list and
        # convert once instead of np.append-ing per folder (quadratic copying).
        paths = [p for d in search_dirs for p in glob.glob(os.path.join(d, "*.jpg"))]
        self.images = np.array(paths)

        np.random.shuffle(self.images)
        self.images = self.images[:self.size]

    def find_sub_folders(self, directory):
        """Return the names (not full paths) of the immediate sub-directories of ``directory``."""
        directories = [dir for dir in os.listdir(directory) if os.path.isdir(os.path.join(directory, dir))]
        return directories

    def image_transform(self, image):
        """Scale an ``H x W x C`` image to [0, 1] and transpose it to ``C x H x W``.

        For train/val (``self.transform``) the result is a float64 tensor normalized
        with ImageNet mean/std; for test it is a float64 numpy array.
        NOTE(review): images come from ``cv2.imread`` (BGR channel order) while the
        mean/std lists look like the usual RGB-ordered ImageNet statistics — confirm
        whether a BGR->RGB swap is intended.
        """
        image = np.array(image)/255.
        image = image.transpose((2, 0, 1))
        if self.transform:
            # copy() makes the transposed view contiguous before wrapping it in a tensor.
            image = self.normalize(torch.from_numpy(image.copy()))
        return image

    def image_detransform(self, image):
        """Convert a ``C x H x W`` image (tensor or array) to ``H x W x C`` scaled by 255.

        Tensors may live on any device and may require grad; they are detached and
        copied to CPU first. NOTE(review): the ImageNet normalization applied for
        train/val is not undone here.
        """
        if torch.is_tensor(image):
            image = image.detach().cpu().numpy()
        else:
            image = np.asarray(image)
        if image.shape[0] == 3:
            image = np.moveaxis(image, 0, -1)
        return image*255

    def __getitem__(self, index):
        path = self.images[index]
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread signals unreadable/missing files by returning None.
        if image is None:
            raise OSError("Could not read image: {}".format(path))
        image = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_AREA)
        image = self.image_transform(image)

        return image

    def __len__(self):
        return len(self.images)

In [None]:
# Build the three imagenet12 splits: train/val are read from sub-folders, test from a flat folder.
train_data = CreateDataset(PATH, dataset="imagenet12", mode="train", sub_folder=True)
val_data = CreateDataset(PATH, dataset="imagenet12", mode="val", sub_folder=True)
test_data = CreateDataset(PATH, dataset="imagenet12", mode="test", sub_folder=False)

In [None]:
# Sanity check: number of images per split (train, val, test), one per line.
print(*(len(split) for split in (train_data, val_data, test_data)), sep="\n")

32768
1024
32


---

In [None]:
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """Extract sliding patches from a 4-D batch (TensorFlow ``extract_image_patches`` style).

    Args:
        images: Tensor of shape ``[N, C, H, W]``.
        ksizes, strides, rates: ``[row, col]`` kernel size, stride and dilation.
        padding: ``'same'`` (zero-pad with ``same_padding`` first) or ``'valid'``
            (no padding).

    Returns:
        Tensor ``[N, C*k*k, L]`` where ``L`` is the number of extracted blocks.
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']

    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding != 'valid':
        # Only reachable when asserts are stripped (python -O).
        raise NotImplementedError('Unsupported padding type: {}.\
                Only "same" or "valid" are supported.'.format(padding))

    patch_extractor = torch.nn.Unfold(kernel_size=ksizes, dilation=rates,
                                      padding=0, stride=strides)
    return patch_extractor(images)

In [None]:
def same_padding(images, ksizes, strides, rates):
    """Zero-pad ``[N, C, H, W]`` images the way TensorFlow's ``'SAME'`` padding does.

    Per spatial dim, the total padding makes a kernel of dilated extent
    ``(k - 1) * rate + 1`` produce ``ceil(size / stride)`` outputs; an odd
    remainder goes to the bottom / right side.
    """
    assert len(images.size()) == 4
    _, _, rows, cols = images.size()

    def _total_pad(size, k, stride, rate):
        out_size = -(-size // stride)  # ceil(size / stride)
        extent = (k - 1) * rate + 1
        return max(0, (out_size - 1) * stride + extent - size)

    pad_h = _total_pad(rows, ksizes[0], strides[0], rates[0])
    pad_w = _total_pad(cols, ksizes[1], strides[1], rates[1])
    top, left = pad_h // 2, pad_w // 2
    # ZeroPad2d order: (left, right, top, bottom)
    return torch.nn.ZeroPad2d((left, pad_w - left, top, pad_h - top))(images)

In [None]:
def reduce_mean(x, axis=None, keepdim=False):
    """Mean of ``x`` over ``axis`` (TensorFlow ``reduce_mean`` style).

    Args:
        x: Input tensor.
        axis: ``None`` for all dimensions, a single int, or an iterable of ints.
            (Previously ``axis=0`` was mistaken for "all dims" and an int crashed.)
        keepdim: Keep the reduced dimensions with size 1.

    Returns:
        The reduced tensor; ``x`` unchanged when there is nothing to reduce.
    """
    if axis is None:
        axis = range(x.dim())
    elif isinstance(axis, int):
        axis = [axis]
    dims = tuple(axis)
    if not dims:
        # torch treats dim=() as "reduce everything"; nothing requested -> no-op.
        return x
    return torch.mean(x, dim=dims, keepdim=keepdim)


def reduce_std(x, axis=None, keepdim=False):
    """Standard deviation (torch default, unbiased) of ``x`` jointly over ``axis``.

    Reducing several axes one at a time computes a std-of-stds, not the std over
    the combined axes; this version reduces them in a single ``torch.std`` call.

    Args:
        x: Input tensor.
        axis: ``None`` for all dimensions, a single int, or an iterable of ints.
        keepdim: Keep the reduced dimensions with size 1.

    Returns:
        The reduced tensor; ``x`` unchanged when there is nothing to reduce.
    """
    if axis is None:
        axis = range(x.dim())
    elif isinstance(axis, int):
        axis = [axis]
    dims = tuple(axis)
    if not dims:
        return x
    return torch.std(x, dim=dims, keepdim=keepdim)


def reduce_sum(x, axis=None, keepdim=False):
    """Sum of ``x`` over ``axis`` (TensorFlow ``reduce_sum`` style).

    Args:
        x: Input tensor.
        axis: ``None`` for all dimensions, a single int, or an iterable of ints.
            (Previously ``axis=0`` was mistaken for "all dims" and an int crashed.)
        keepdim: Keep the reduced dimensions with size 1.

    Returns:
        The reduced tensor; ``x`` unchanged when there is nothing to reduce.
    """
    if axis is None:
        axis = range(x.dim())
    elif isinstance(axis, int):
        axis = [axis]
    dims = tuple(axis)
    if not dims:
        return x
    return torch.sum(x, dim=dims, keepdim=keepdim)


In [None]:
def make_color_wheel():
    """Build the 55-entry RGB color wheel used by ``compute_color`` to colorize flow.

    Returns:
        Float array of shape ``(55, 3)`` with values in 0..255, sweeping
        red -> yellow -> green -> cyan -> blue -> magenta in segments of
        15, 6, 4, 11, 13 and 6 entries.
    """
    # (segment length, channel held at 255, channel that ramps, ramp rises?)
    segments = [
        (15, 0, 1, True),   # RY: red fixed, green rises
        (6, 1, 0, False),   # YG: green fixed, red falls
        (4, 1, 2, True),    # GC: green fixed, blue rises
        (11, 2, 1, False),  # CB: blue fixed, green falls
        (13, 2, 0, True),   # BM: blue fixed, red rises
        (6, 0, 2, False),   # MR: red fixed, blue falls
    ]
    wheel = np.zeros([sum(seg[0] for seg in segments), 3])
    start = 0
    for length, full_channel, ramp_channel, rising in segments:
        ramp = np.floor(255 * np.arange(0, length) / length)
        stop = start + length
        wheel[start:stop, full_channel] = 255
        wheel[start:stop, ramp_channel] = ramp if rising else 255 - ramp
        start = stop
    return wheel

In [None]:
def compute_color(u, v):
    """Colorize one 2-D flow field ``(u, v)`` using the wheel from ``make_color_wheel``.

    The wheel position comes from ``atan2(-v, -u)``. For vectors with magnitude
    <= 1 the color is blended toward white as the magnitude shrinks; longer
    vectors are darkened to 75%. NaN vectors are rendered black.

    Args:
        u, v: Same-shape 2-D arrays of flow components (callers normally pre-scale
            them so the largest magnitude is about 1). They are not modified.

    Returns:
        ``(h, w, 3)`` float array holding integer values in ``[0, 255]``.
    """
    h, w = u.shape
    img = np.zeros([h, w, 3])
    nanIdx = np.isnan(u) | np.isnan(v)
    # np.where builds new arrays, so the caller's u/v are no longer zeroed in place.
    u = np.where(nanIdx, 0, u)
    v = np.where(nanIdx, 0, v)
    colorwheel = make_color_wheel()
    ncols = np.size(colorwheel, 0)
    rad = np.sqrt(u ** 2 + v ** 2)
    # Angle in units of pi ([-1, 1]) -> fractional 1-based wheel index in [1, ncols].
    a = np.arctan2(-v, -u) / np.pi
    fk = (a + 1) / 2 * (ncols - 1) + 1
    k0 = np.floor(fk).astype(int)
    k1 = k0 + 1
    k1[k1 == ncols + 1] = 1  # wrap around to the first wheel entry
    f = fk - k0
    # These masks depend only on rad, so compute them once instead of per channel.
    inside = rad <= 1
    outside = np.logical_not(inside)
    for i in range(np.size(colorwheel, 1)):
        tmp = colorwheel[:, i]
        col0 = tmp[k0 - 1] / 255
        col1 = tmp[k1 - 1] / 255
        col = (1 - f) * col0 + f * col1  # linear interpolation between wheel entries
        col[inside] = 1 - rad[inside] * (1 - col[inside])
        col[outside] *= 0.75
        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
    return img

In [None]:
def flow_to_image(flow):
    """Convert a batch of flow fields to color images via ``compute_color``.

    Args:
        flow: Array of shape ``(N, H, W, 2)`` holding ``(u, v)`` per pixel.
            Pixels where either component exceeds 1e7 in magnitude are treated as
            unknown and zeroed. The input array is not modified.

    Returns:
        ``float32`` array of shape ``(N, H, W, 3)`` with integer values in ``[0, 255]``.

    Note: the normalizing radius is a running maximum over the batch, so sample
    ``i`` is scaled by the largest magnitude seen in samples ``0..i``.
    """
    out = []
    maxrad = -1
    eps = np.finfo(float).eps
    for i in range(flow.shape[0]):
        u = flow[i, :, :, 0]
        v = flow[i, :, :, 1]
        idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7)
        # u/v are views into `flow`; np.where returns copies so the input stays intact.
        u = np.where(idxunknow, 0, u)
        v = np.where(idxunknow, 0, v)
        rad = np.sqrt(u ** 2 + v ** 2)
        maxrad = max(maxrad, np.max(rad))
        u = u / (maxrad + eps)
        v = v / (maxrad + eps)
        out.append(compute_color(u, v))
    return np.float32(np.uint8(out))

In [None]:
def random_bbox(batch_size, image_shape=(256,256,3), mask_shape=(128, 128), margin=(0,0), mask_batch_same=True):
    """Sample random mask rectangles ``(top, left, height, width)``.

    Args:
        batch_size: Number of boxes to return.
        image_shape: ``(height, width, channels)`` of the images.
        mask_shape: ``(height, width)`` of every box.
        margin: ``(vertical, horizontal)`` distance the box keeps from every image border.
        mask_batch_same: Reuse one sampled box for the whole batch (True) or
            sample independently per item (False).

    Returns:
        ``int64`` tensor of shape ``[batch_size, 4]``.

    Raises:
        ValueError: If the box plus margins does not fit inside the image.
    """
    img_height, img_width, _ = image_shape
    h, w = mask_shape
    margin_height, margin_width = margin
    # Largest valid top / left offsets (inclusive).
    maxt = img_height - margin_height - h
    maxl = img_width - margin_width - w
    if maxt < margin_height or maxl < margin_width:
        raise ValueError("mask_shape {} with margin {} does not fit in image_shape {}".format(
            mask_shape, margin, image_shape))

    def _sample():
        # randint's upper bound is exclusive: +1 makes maxt/maxl reachable and
        # allows a mask that exactly fills the available area.
        t = np.random.randint(margin_height, maxt + 1)
        l = np.random.randint(margin_width, maxl + 1)
        return (t, l, h, w)

    if mask_batch_same:
        bbox_list = [_sample()] * batch_size
    else:
        bbox_list = [_sample() for _ in range(batch_size)]

    return torch.tensor(bbox_list, dtype=torch.int64)

In [None]:
def bbox2mask(bboxes, height, width, max_delta_h, max_delta_w):
    """Rasterize boxes into a binary mask, shrinking each box by a random amount.

    Args:
        bboxes: ``[B, 4]`` integer tensor of ``(top, left, height, width)`` boxes.
        height, width: Spatial size of the mask.
        max_delta_h, max_delta_w: Each box loses a random ``[0, max_delta // 2]``
            pixels on both opposite sides (drawn once per box with numpy's RNG).

    Returns:
        ``float32`` tensor ``[B, 1, height, width]``: 1 inside the shrunk box, 0 elsewhere.
    """
    n_boxes = bboxes.size(0)
    mask = torch.zeros((n_boxes, 1, height, width), dtype=torch.float32)
    for idx in range(n_boxes):
        top, left, box_h, box_w = bboxes[idx]
        dh = np.random.randint(max_delta_h // 2 + 1)
        dw = np.random.randint(max_delta_w // 2 + 1)
        mask[idx, :, top + dh:top + box_h - dh, left + dw:left + box_w - dw] = 1.
    return mask

In [None]:
def mask_image(x, bboxes, image_shape=(256,256,3), max_delta_shape=(32,32), mask_type='hole'):
    """Mask a batch of images with (randomly shrunk) rectangles from ``bbox2mask``.

    Args:
        x: ``[B, C, H, W]`` image batch; H, W should match ``image_shape``.
        bboxes: ``[B, 4]`` tensor of ``(top, left, height, width)`` boxes.
        image_shape: ``(height, width, channels)`` used to size the mask.
        max_delta_shape: Maximum random shrink per box (see ``bbox2mask``).
        mask_type: ``'hole'`` zeroes the masked region; ``'mosaic'`` fills it with a
            pixelated (nearest-neighbour, 12-pixel units) copy of the image.

    Returns:
        ``(result, mask)``; ``mask`` is ``[B, 1, H, W]`` on ``x``'s device.

    Raises:
        NotImplementedError: For an unknown ``mask_type``.
    """
    height, width, _ = image_shape
    max_delta_h, max_delta_w = max_delta_shape
    mask = bbox2mask(bboxes, height, width, max_delta_h, max_delta_w)
    # Follow x's actual device (any GPU index or CPU), not just the default GPU.
    mask = mask.to(x.device)

    if mask_type == 'hole':
        result = x * (1. - mask)
    elif mask_type == 'mosaic':
        # TODO: Matching the mosaic patch size and the mask size
        mosaic_unit_size = 12
        # torch.nn.functional is spelled out so this cell does not depend on `F`
        # being imported by a later cell.
        downsampled_image = torch.nn.functional.interpolate(x, scale_factor=1. / mosaic_unit_size, mode='nearest')
        upsampled_image = torch.nn.functional.interpolate(downsampled_image, size=(height, width), mode='nearest')
        result = upsampled_image * mask + x * (1. - mask)
    else:
        raise NotImplementedError('Not implemented mask type.')

    return result, mask

In [None]:
def spatial_discounting_mask(spatial_discounting_gamma=0.9, mask_shape=(128, 128), discounted_mask=True, use_cuda=False):
    """Loss weight map that decays geometrically away from the mask border.

    Entry ``(i, j)`` is ``max(gamma ** min(i, H - i), gamma ** min(j, W - j))``:
    1 on the top row / left column, shrinking toward the center. Computed with
    numpy broadcasting instead of a Python double loop.

    Args:
        spatial_discounting_gamma: Decay factor ``gamma``.
        mask_shape: ``(H, W)`` of the map.
        discounted_mask: If False, return an all-ones map instead.
        use_cuda: Move the result to the default CUDA device.

    Returns:
        ``float32`` tensor of shape ``[1, 1, H, W]``.
    """
    gamma = spatial_discounting_gamma
    height, width = mask_shape
    if discounted_mask:
        rows = np.arange(height)
        cols = np.arange(width)
        row_weight = gamma ** np.minimum(rows, height - rows)
        col_weight = gamma ** np.minimum(cols, width - cols)
        # Outer max of the per-row and per-column weights, plus two leading singleton dims.
        mask_values = np.maximum(row_weight[:, None], col_weight[None, :])[None, None]
    else:
        mask_values = np.ones([1, 1, height, width])
    spatial_discounting_mask_tensor = torch.tensor(mask_values, dtype=torch.float32)
    if use_cuda:
        spatial_discounting_mask_tensor = spatial_discounting_mask_tensor.cuda()
    return spatial_discounting_mask_tensor

In [None]:
def local_patch(x, bbox_list):
    """Crop box ``i`` out of ``x[i]`` for every box and stack the crops.

    Args:
        x: ``[B, C, H, W]`` tensor.
        bbox_list: Sequence of ``(top, left, height, width)`` boxes; all boxes must
            share the same height/width so the crops can be stacked.

    Returns:
        Tensor ``[len(bbox_list), C, height, width]``.
    """
    assert len(x.size()) == 4
    crops = [x[i, :, t:t + h, l:l + w] for i, (t, l, h, w) in enumerate(bbox_list)]
    return torch.stack(crops, dim=0)

---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm as spectral_norm_fn
from torch.nn.utils import weight_norm as weight_norm_fn
from torchvision import transforms
from torchvision import utils as vutils

In [None]:
class Conv2dBlock(nn.Module):
    """Explicit padding -> (transposed) conv -> optional norm -> optional activation.

    Args:
        input_dim, output_dim: Input / output channel counts.
        kernel_size, stride, dilation: Forwarded to the convolution.
        padding: Amount of explicit padding applied by ``self.pad`` before the conv.
        conv_padding: The convolution's own ``padding`` (also its ``output_padding``
            when ``transpose`` is True).
        weight_norm: ``'sn'`` (spectral norm), ``'wn'`` (weight norm) or ``'none'``.
        norm: ``'bn'``, ``'in'`` or ``'none'``.
        activation: ``'relu'``, ``'elu'``, ``'lrelu'``, ``'prelu'``, ``'selu'``,
            ``'tanh'`` or ``'none'``.
        pad_type: ``'reflect'``, ``'replicate'``, ``'zero'`` or ``'none'``.
        transpose: Use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
    """

    def __init__(self, input_dim, output_dim, kernel_size, stride, padding=0,
                 conv_padding=0, dilation=1, weight_norm='none', norm='none',
                 activation='relu', pad_type='zero', transpose=False):
        super(Conv2dBlock, self).__init__()
        self.use_bias = True

        # Factories so only the selected layer is ever constructed.
        pad_factories = {
            'reflect': lambda: nn.ReflectionPad2d(padding),
            'replicate': lambda: nn.ReplicationPad2d(padding),
            'zero': lambda: nn.ZeroPad2d(padding),
            'none': lambda: None,
        }
        if pad_type not in pad_factories:
            assert 0, "Unsupported padding type: {}".format(pad_type)
        self.pad = pad_factories[pad_type]()

        norm_factories = {
            'bn': lambda: nn.BatchNorm2d(output_dim),
            'in': lambda: nn.InstanceNorm2d(output_dim),
            'none': lambda: None,
        }
        if norm not in norm_factories:
            assert 0, "Unsupported normalization: {}".format(norm)
        self.norm = norm_factories[norm]()

        weight_norm_fns = {'sn': spectral_norm_fn, 'wn': weight_norm_fn, 'none': None}
        if weight_norm not in weight_norm_fns:
            assert 0, "Unsupported normalization: {}".format(weight_norm)
        self.weight_norm = weight_norm_fns[weight_norm]

        activation_factories = {
            'relu': lambda: nn.ReLU(inplace=True),
            'elu': lambda: nn.ELU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': lambda: nn.PReLU(),
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': lambda: nn.Tanh(),
            'none': lambda: None,
        }
        if activation not in activation_factories:
            assert 0, "Unsupported activation: {}".format(activation)
        self.activation = activation_factories[activation]()

        conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
        extra_kwargs = {'output_padding': conv_padding} if transpose else {}
        self.conv = conv_cls(input_dim, output_dim, kernel_size, stride,
                             padding=conv_padding, dilation=dilation,
                             bias=self.use_bias, **extra_kwargs)

        # Wrap the conv after creation (spectral/weight norm re-parameterize its weight).
        if self.weight_norm:
            self.conv = self.weight_norm(self.conv)

    def forward(self, x):
        if self.pad:
            x = self.pad(x)
        x = self.conv(x)
        for post_layer in (self.norm, self.activation):
            if post_layer:
                x = post_layer(x)
        return x

In [None]:
def gen_conv(input_dim, output_dim, kernel_size=3, stride=1, padding=0, rate=1,
             activation='elu'):
    """Generator layer: ``Conv2dBlock`` with conv-internal ``padding``, dilation ``rate``, ELU by default."""
    block_kwargs = dict(conv_padding=padding, dilation=rate, activation=activation)
    return Conv2dBlock(input_dim, output_dim, kernel_size, stride, **block_kwargs)


def dis_conv(input_dim, output_dim, kernel_size=5, stride=2, padding=0, rate=1,
             activation='lrelu'):
    """Discriminator layer: ``Conv2dBlock`` (5x5, stride 2 by default) with LeakyReLU(0.2)."""
    block_kwargs = dict(conv_padding=padding, dilation=rate, activation=activation)
    return Conv2dBlock(input_dim, output_dim, kernel_size, stride, **block_kwargs)

---

In [None]:
class CoarseGenerator(nn.Module):
    """First-stage (coarse) inpainting network.

    Two stride-2 encoder convs, four dilated convs (rates 2, 4, 8, 16), then a
    decoder with two nearest-neighbour 2x upsamplings back to the input size. The
    output has ``input_dim`` channels clamped to ``[-1, 1]``.

    Args:
        input_dim: Number of image channels.
        cnum: Base channel width.
        use_cuda: Kept for backward compatibility; the helper tensors created in
            ``forward`` now follow the device and dtype of ``x``.
    """

    def __init__(self, input_dim, cnum, use_cuda=False):
        super(CoarseGenerator, self).__init__()
        self.use_cuda = use_cuda

        self.conv1 = gen_conv(input_dim + 2, cnum, 5, 1, 2)
        self.conv2_downsample = gen_conv(cnum, cnum*2, 3, 2, 1)
        self.conv3 = gen_conv(cnum*2, cnum*2, 3, 1, 1)
        self.conv4_downsample = gen_conv(cnum*2, cnum*4, 3, 2, 1)
        self.conv5 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
        self.conv6 = gen_conv(cnum*4, cnum*4, 3, 1, 1)

        self.conv7_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 2, rate=2)
        self.conv8_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 4, rate=4)
        self.conv9_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 8, rate=8)
        self.conv10_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 16, rate=16)

        self.conv11 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
        self.conv12 = gen_conv(cnum*4, cnum*4, 3, 1, 1)

        self.conv13 = gen_conv(cnum*4, cnum*2, 3, 1, 1)
        self.conv14 = gen_conv(cnum*2, cnum*2, 3, 1, 1)
        self.conv15 = gen_conv(cnum*2, cnum, 3, 1, 1)
        self.conv16 = gen_conv(cnum, cnum//2, 3, 1, 1)
        self.conv17 = gen_conv(cnum//2, input_dim, 3, 1, 1, activation='none')

    def forward(self, x, mask):
        """Predict a coarse completion.

        Args:
            x: ``[N, input_dim, H, W]`` masked images (output matches H, W when
                they are divisible by 4).
            mask: ``[N, 1, H, W]`` hole mask; moved to ``x``'s device and dtype.

        Returns:
            ``[N, input_dim, H, W]`` prediction clamped to ``[-1, 1]``.
        """
        # Constant channel indicating the image area (helps the network see image boundaries).
        # Created directly on x's device/dtype so CPU and any GPU work regardless of use_cuda.
        ones = torch.ones(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
        mask = mask.to(device=x.device, dtype=x.dtype)
        # (input_dim + 2) x H x W
        x = self.conv1(torch.cat([x, ones, mask], dim=1))
        x = self.conv2_downsample(x)
        # cnum*2 x H/2 x W/2
        x = self.conv3(x)
        x = self.conv4_downsample(x)
        # cnum*4 x H/4 x W/4
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7_atrous(x)
        x = self.conv8_atrous(x)
        x = self.conv9_atrous(x)
        x = self.conv10_atrous(x)
        x = self.conv11(x)
        x = self.conv12(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        # cnum*4 x H/2 x W/2 -> conv13 -> cnum*2
        x = self.conv13(x)
        x = self.conv14(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        # cnum*2 x H x W -> conv15 -> cnum
        x = self.conv15(x)
        x = self.conv16(x)
        x = self.conv17(x)
        # input_dim x H x W
        x_stage1 = torch.clamp(x, -1., 1.)

        return x_stage1

In [None]:
class FineGenerator(nn.Module):
    """Second-stage (refinement) inpainting network.

    The coarse prediction is pasted into the hole and processed by two parallel
    encoders: a dilated-conv ("hallucination") branch and a contextual-attention
    branch. Their ``cnum*4``-channel features are concatenated and decoded back to
    the input resolution, clamped to ``[-1, 1]``.

    Args:
        input_dim: Number of image channels.
        cnum: Base channel width.
        use_cuda: Forwarded to ``ContextualAttention``. Tensors created in
            ``forward`` itself follow the device and dtype of ``xin``.
    """

    def __init__(self, input_dim, cnum, use_cuda=False):
        super(FineGenerator, self).__init__()
        self.use_cuda = use_cuda
        # conv branch; input is (input_dim + 2) x H x W
        self.conv1 = gen_conv(input_dim + 2, cnum, 5, 1, 2)
        self.conv2_downsample = gen_conv(cnum, cnum, 3, 2, 1)
        # cnum x H/2 x W/2
        self.conv3 = gen_conv(cnum, cnum*2, 3, 1, 1)
        self.conv4_downsample = gen_conv(cnum*2, cnum*2, 3, 2, 1)
        # cnum*2 x H/4 x W/4
        self.conv5 = gen_conv(cnum*2, cnum*4, 3, 1, 1)
        self.conv6 = gen_conv(cnum*4, cnum*4, 3, 1, 1)

        self.conv7_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 2, rate=2)
        self.conv8_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 4, rate=4)
        self.conv9_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 8, rate=8)
        self.conv10_atrous = gen_conv(cnum*4, cnum*4, 3, 1, 16, rate=16)

        # attention branch; input is (input_dim + 2) x H x W
        self.pmconv1 = gen_conv(input_dim + 2, cnum, 5, 1, 2)
        self.pmconv2_downsample = gen_conv(cnum, cnum, 3, 2, 1)
        # cnum x H/2 x W/2
        self.pmconv3 = gen_conv(cnum, cnum*2, 3, 1, 1)
        self.pmconv4_downsample = gen_conv(cnum*2, cnum*4, 3, 2, 1)
        # cnum*4 x H/4 x W/4
        self.pmconv5 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
        self.pmconv6 = gen_conv(cnum*4, cnum*4, 3, 1, 1, activation='relu')
        self.contextul_attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10, fuse=True, use_cuda=self.use_cuda)
        self.pmconv9 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
        self.pmconv10 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
        # merged decoder: cnum*4 (conv branch) + cnum*4 (attention branch) channels
        self.allconv11 = gen_conv(cnum*8, cnum*4, 3, 1, 1)
        self.allconv12 = gen_conv(cnum*4, cnum*4, 3, 1, 1)
        self.allconv13 = gen_conv(cnum*4, cnum*2, 3, 1, 1)
        self.allconv14 = gen_conv(cnum*2, cnum*2, 3, 1, 1)
        self.allconv15 = gen_conv(cnum*2, cnum, 3, 1, 1)
        self.allconv16 = gen_conv(cnum, cnum//2, 3, 1, 1)
        self.allconv17 = gen_conv(cnum//2, input_dim, 3, 1, 1, activation='none')

    def forward(self, xin, x_stage1, mask):
        """Refine the coarse prediction.

        Args:
            xin: ``[N, input_dim, H, W]`` masked input images.
            x_stage1: Coarse prediction of the same shape.
            mask: ``[N, 1, H, W]`` hole mask (1 = missing); moved to ``xin``'s device/dtype.

        Returns:
            ``(x_stage2, offset_flow)``: the refined images clamped to ``[-1, 1]`` and
            the attention-offset visualization from ``ContextualAttention``.
        """
        # Move the mask before it is first used, so a CPU mask works with GPU inputs.
        mask = mask.to(device=xin.device, dtype=xin.dtype)
        # Keep known pixels from the input, take the hole from the coarse prediction.
        x1_inpaint = x_stage1 * mask + xin * (1. - mask)
        # Constant channel indicating the image area (helps the network see image boundaries).
        ones = torch.ones(xin.size(0), 1, xin.size(2), xin.size(3), device=xin.device, dtype=xin.dtype)

        # conv branch
        xnow = torch.cat([x1_inpaint, ones, mask], dim=1)
        x = self.conv1(xnow)
        x = self.conv2_downsample(x)
        x = self.conv3(x)
        x = self.conv4_downsample(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7_atrous(x)
        x = self.conv8_atrous(x)
        x = self.conv9_atrous(x)
        x = self.conv10_atrous(x)
        x_hallu = x
        # attention branch
        x = self.pmconv1(xnow)
        x = self.pmconv2_downsample(x)
        x = self.pmconv3(x)
        x = self.pmconv4_downsample(x)
        x = self.pmconv5(x)
        x = self.pmconv6(x)
        x, offset_flow = self.contextul_attention(x, x, mask)
        x = self.pmconv9(x)
        x = self.pmconv10(x)
        pm = x
        # merge two branches
        x = torch.cat([x_hallu, pm], dim=1)
        x = self.allconv11(x)
        x = self.allconv12(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.allconv13(x)
        x = self.allconv14(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.allconv15(x)
        x = self.allconv16(x)
        x = self.allconv17(x)
        x_stage2 = torch.clamp(x, -1., 1.)

        return x_stage2, offset_flow

In [None]:
class ContextualAttention(nn.Module):
    """Contextual attention: fill foreground features with matched background patches.

    Matching is done on ``f``/``b`` downscaled by ``rate``: every background patch
    (``ksize`` x ``ksize``) is L2-normalized and correlated with the foreground via a
    convolution, scores are optionally fused with an identity kernel, masked
    background patches are excluded, and a scaled softmax gives attention weights.
    Those weights paste raw ``2*rate`` x ``2*rate`` background patches back with a
    transposed convolution.

    Args:
        ksize: Matching patch size (on the downscaled maps).
        stride: Stride between matching patches.
        rate: Downscaling factor for matching; also sets the raw patch size ``2*rate``.
        fuse_k: Size of the identity kernel used to fuse scores.
        softmax_scale: Multiplier applied to scores before the softmax.
        fuse: Whether to apply score fusion.
        use_cuda: Move tensors created inside ``forward`` to the default CUDA device.
    """

    def __init__(self, ksize=3, stride=1, rate=1, fuse_k=3, softmax_scale=10,
                 fuse=False, use_cuda=False):
        super(ContextualAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.rate = rate
        self.fuse_k = fuse_k
        self.softmax_scale = softmax_scale
        self.fuse = fuse
        self.use_cuda = use_cuda

    def forward(self, f, b, mask=None):
        """Attend from foreground ``f`` to background ``b``.

        Args:
            f: Foreground features ``[N, C, H, W]``.
            b: Background features ``[N, C, H, W]`` (FineGenerator passes the same
                tensor for ``f`` and ``b``).
            mask: Optional hole mask ``[N, 1, ...]`` (nonzero = missing). It is
                downscaled by ``4 * rate``, so it presumably arrives at 4x the
                resolution of ``f`` (true for FineGenerator, where ``f`` is at 1/4 of
                the image size) — TODO confirm for other callers. Only the first
                sample's mask is used for the whole batch.

        Returns:
            ``(y, flow)``: ``y`` has ``C`` channels and, for ``rate=2`` with even
            H/W, the same spatial size as ``f``; ``flow`` is an ``[N, 3, ...]`` color
            visualization of the attention offsets, upsampled by ``4 * rate``.
        """
        # get shapes
        raw_int_fs = list(f.size())   # b*c*h*w
        raw_int_bs = list(b.size())   # b*c*h*w

        # extract patches from background with stride and rate
        kernel = 2 * self.rate
        # raw_w is extracted for reconstruction
        raw_w = extract_image_patches(b, ksizes=[kernel, kernel],
                                      strides=[self.rate*self.stride,
                                               self.rate*self.stride],
                                      rates=[1, 1],
                                      padding='same') # [N, C*k*k, L]
        # raw_shape: [N, C, k, k, L]
        raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
        raw_w = raw_w.permute(0, 4, 1, 2, 3)    # raw_shape: [N, L, C, k, k]
        raw_w_groups = torch.split(raw_w, 1, dim=0)

        # downscaling foreground option: downscaling both foreground and
        # background for matching and use original background for reconstruction.
        f = F.interpolate(f, scale_factor=1./self.rate, mode='nearest')
        b = F.interpolate(b, scale_factor=1./self.rate, mode='nearest')
        int_fs = list(f.size())     # b*c*h*w
        int_bs = list(b.size())
        f_groups = torch.split(f, 1, dim=0)  # split tensors along the batch dimension
        # w shape: [N, C*k*k, L]
        w = extract_image_patches(b, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')
        # w shape: [N, C, k, k, L]
        w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
        w = w.permute(0, 4, 1, 2, 3)    # w shape: [N, L, C, k, k]
        w_groups = torch.split(w, 1, dim=0)

        # process mask
        if mask is None:
            # No mask: all-zero mask at the downscaled background size, so every
            # background patch is a valid match candidate.
            mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]])
            if self.use_cuda:
                mask = mask.cuda()

        else:
            mask = F.interpolate(mask, scale_factor=1./(4*self.rate), mode='nearest')
        int_ms = list(mask.size())
        # m shape: [N, C*k*k, L]
        m = extract_image_patches(mask, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')
        # m shape: [N, C, k, k, L]
        m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
        m = m.permute(0, 4, 1, 2, 3)    # m shape: [N, L, C, k, k]
        m = m[0]    # m shape: [L, C, k, k]; NOTE: only the first sample's mask is used
        # mm shape: [L, 1, 1, 1]; 1.0 for background patches containing no masked pixel
        mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True)==0.).to(torch.float32)
        mm = mm.permute(1, 0, 2, 3) # mm shape: [1, L, 1, 1]

        y = []
        offsets = []
        k = self.fuse_k
        scale = self.softmax_scale    # to fit the PyTorch tensor image value range
        # Identity kernel: fusing along the diagonal rewards neighbouring foreground
        # pixels that match neighbouring background patches (coherent shifts).
        fuse_weight = torch.eye(k).view(1, 1, k, k)  # 1*1*k*k
        if self.use_cuda:
            fuse_weight = fuse_weight.cuda()

        # Process one sample at a time: each sample uses its own background patches as conv weights.
        for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
            # conv for compare
            escape_NaN = torch.FloatTensor([1e-4])
            if self.use_cuda:
                escape_NaN = escape_NaN.cuda()
            
            wi = wi[0]  # [L, C, k, k]
            # L2 norm of every patch (plus a small epsilon against division by zero).
            max_wi = torch.sqrt(reduce_sum(torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True))
            wi_normed = wi / max_wi
            # xi shape: [1, C, H, W], yi shape: [1, L, H, W]
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)   # [1, L, H, W]
            # conv implementation for fuse scores to encourage large patches
            if self.fuse:
                # make all of depth to spatial resolution
                # Shape comments below use a 32x32 map as an example.
                yi = yi.view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3])  # (B=1, I=1, H=32*32, W=32*32)
                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)  # (B=1, C=1, H=32*32, W=32*32)
                yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2], int_fs[3])  # (B=1, 32, 32, 32, 32)
                # Swap row/column roles so the second fusion pass runs along the other axis.
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3])
                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)
                yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3], int_fs[2])
                yi = yi.permute(0, 2, 1, 4, 3).contiguous()
            yi = yi.view(1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3])  # (B=1, C=32*32, H=32, W=32)
            # softmax to match
            # Zero masked candidates before and after the softmax so they get no attention weight.
            yi = yi * mm
            yi = F.softmax(yi*scale, dim=1)
            yi = yi * mm  # [1, L, H, W]

            # Index of the best-matching background patch per foreground pixel.
            offset = torch.argmax(yi, dim=1, keepdim=True)  # 1*1*H*W

            if int_bs != int_fs:
                # Normalize the offset value to match foreground dimension
                times = float(int_fs[2] * int_fs[3]) / float(int_bs[2] * int_bs[3])
                offset = ((offset + 1).float() * times - 1).to(torch.int64)
            # Flat index -> (row, col).
            offset = torch.cat([offset//int_fs[3], offset%int_fs[3]], dim=1)  # 1*2*H*W

            # deconv for patch pasting
            wi_center = raw_wi[0]
            # yi = F.pad(yi, [0, 1, 0, 1])    # here may need conv_transpose same padding
            # Paste attention-weighted raw patches. The "/ 4." appears tuned for rate=2
            # (4x4 patches at stride 2 overlap 2x2 = 4 times per output pixel).
            yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4.  # (B=1, C=128, H=64, W=64)
            y.append(yi)
            offsets.append(offset)

        y = torch.cat(y, dim=0)  # back to the mini-batch
        # NOTE(review): the result of this view is discarded, so it is a no-op; y keeps
        # the shape produced by conv_transpose2d.
        y.contiguous().view(raw_int_fs)

        offsets = torch.cat(offsets, dim=0)
        offsets = offsets.view(int_fs[0], 2, *int_fs[2:])

        # case1: visualize optical flow: minus current position
        h_add = torch.arange(int_fs[2]).view([1, 1, int_fs[2], 1]).expand(int_fs[0], -1, -1, int_fs[3])
        w_add = torch.arange(int_fs[3]).view([1, 1, 1, int_fs[3]]).expand(int_fs[0], -1, int_fs[2], -1)
        ref_coordinate = torch.cat([h_add, w_add], dim=1)
        if self.use_cuda:
            ref_coordinate = ref_coordinate.cuda()

        offsets = offsets - ref_coordinate
        # flow = pt_flow_to_image(offsets)

        # Colorization happens in numpy on the CPU (flow_to_image), then comes back as [N, 3, h, w] in [0, 1].
        flow = torch.from_numpy(flow_to_image(offsets.permute(0, 2, 3, 1).cpu().data.numpy())) / 255.
        flow = flow.permute(0, 3, 1, 2)
        if self.use_cuda:
            flow = flow.cuda()
        # case2: visualize which pixels are attended
        # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy()))

        if self.rate != 1:
            flow = F.interpolate(flow, scale_factor=self.rate*4, mode='nearest')

        return y, flow

In [None]:
class Generator(nn.Module):
    """Two-stage inpainting generator: coarse network followed by the attention-based refiner.

    ``forward(x, mask)`` returns ``(x_stage1, x_stage2, offset_flow)``.
    """

    def __init__(self, input_dim=3, cnum=32, use_cuda=False):
        super(Generator, self).__init__()
        self.input_dim, self.cnum, self.use_cuda = input_dim, cnum, use_cuda

        self.coarse_generator = CoarseGenerator(input_dim, cnum, use_cuda)
        self.fine_generator = FineGenerator(input_dim, cnum, use_cuda)

    def forward(self, x, mask):
        coarse = self.coarse_generator(x, mask)
        refined, attention_flow = self.fine_generator(x, coarse, mask)
        return coarse, refined, attention_flow

---

In [None]:
class DisConvModule(nn.Module):
    """Four 5x5, stride-2 ``dis_conv`` layers with widths cnum, 2*cnum, 4*cnum, 4*cnum.

    Each layer roughly halves the spatial size (kernel 5, stride 2, padding 2), so
    an ``H x W`` input gives ``cnum*4 x H/16 x W/16`` features when H, W are
    divisible by 16.
    """

    def __init__(self, input_dim, cnum):
        super(DisConvModule, self).__init__()
        widths = [input_dim, cnum, cnum*2, cnum*4, cnum*4]
        # Registered as conv1..conv4 (same attribute names, so checkpoints still load).
        for idx in range(4):
            setattr(self, 'conv{}'.format(idx + 1), dis_conv(widths[idx], widths[idx + 1], 5, 2, 2))

    def forward(self, x):
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = layer(x)
        return x

In [None]:
class LocalDis(nn.Module):
    """Discriminator for the local (hole-region) patch.

    The linear head expects ``cnum*4 x 8 x 8`` features, e.g. a 128x128 patch after
    the four stride-2 convs of ``DisConvModule``.
    """

    def __init__(self, input_dim=3, cnum=64, use_cuda=False):
        super(LocalDis, self).__init__()
        self.input_dim = input_dim
        self.cnum = cnum
        self.use_cuda = use_cuda  # stored for API symmetry; not used in forward

        self.dis_conv_module = DisConvModule(input_dim, cnum)
        self.linear = nn.Linear(cnum*4*8*8, 1)

    def forward(self, x):
        features = self.dis_conv_module(x)
        flat = features.view(features.size(0), -1)
        return self.linear(flat)

In [None]:
class GlobalDis(nn.Module):
    """Discriminator for the whole image.

    The linear head expects ``cnum*4 x 16 x 16`` features, e.g. a 256x256 image after
    the four stride-2 convs of ``DisConvModule``.
    """

    def __init__(self, input_dim=3, cnum=64, use_cuda=False):
        super(GlobalDis, self).__init__()
        self.use_cuda = use_cuda  # stored for API symmetry; not used in forward
        self.input_dim = input_dim
        self.cnum = cnum

        self.dis_conv_module = DisConvModule(input_dim, cnum)
        self.linear = nn.Linear(cnum*4*16*16, 1)

    def forward(self, x):
        features = self.dis_conv_module(x)
        flat = features.view(features.size(0), -1)
        return self.linear(flat)

---

In [None]:
from torchsummary import summary
import torch.optim as optim
from torch.utils.data import DataLoader
from torch import autograd
import gc

import time
import pyprind

import matplotlib.pyplot as plt

In [None]:
model_train = True  # presumably switches between training and inference later on — TODO confirm
batch_size = 6  # images per batch for the DataLoaders built below
start_epochs = 0  # first epoch index; presumably > 0 when resuming from a checkpoint
total_epochs = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU when available, otherwise CPU
CHECKPOINT = "/content/drive/MyDrive/Projects/Clubs/Analytics/Coord Projects/Model Zoo/Inpainting/checkpoints"  # checkpoint folder on Drive (presumably for saving/loading weights)

In [None]:
# Tie use_cuda to the selected device: a hard-coded True makes the generator call
# .cuda() on internal tensors and crash on CPU-only runtimes.
net_gen = Generator(use_cuda=(device.type == "cuda"))
net_local_dis = LocalDis()
net_global_dis = GlobalDis()

In [None]:
# Both discriminators are optimized jointly by a single optimizer.
d_params = [*net_local_dis.parameters(), *net_global_dis.parameters()]

In [None]:
# Adam (lr 1e-4, betas (0.5, 0.9)) for the generator and for both discriminators.
optimizer_g = optim.Adam(params=net_gen.parameters(), lr=1e-4, betas=(0.5, 0.9))
optimizer_d = optim.Adam(params=d_params, lr=1e-4, betas=(0.5, 0.9))

In [None]:
# Move the networks to the selected device (Module.to works in place, so the
# optimizers above still reference the same parameters). Trailing ';' hides the repr.
net_gen.to(device)
net_local_dis.to(device)
net_global_dis.to(device);

In [None]:
# Mean absolute (L1) reconstruction loss.
criterionL1 = nn.L1Loss(reduction='mean').to(device)

In [None]:
def calc_gradient_penalty(netD, real_data, fake_data, device):
    """WGAN-GP gradient penalty for critic ``netD``.

    Draws one weight ``alpha ~ U[0, 1)`` per sample and evaluates the critic at
    ``x_hat = alpha * real + (1 - alpha) * fake``. Returns
    ``mean((||d netD(x_hat) / d x_hat||_2 - 1) ** 2)``, with the L2 norm taken
    per sample over all non-batch dims.

    Args:
        netD: critic, called as ``netD(x)``; its output is differentiated with
            an all-ones ``grad_outputs``.
        real_data: real batch of shape ``(B, ...)`` (any rank >= 1).
        fake_data: generated batch of the same shape. Pass it detached when only
            the critic is being trained.
        device: device the interpolation weights are moved to.

    Returns:
        Scalar tensor. It is built with ``create_graph=True``, so it can be
        backpropagated into the critic's parameters.
    """
    batch_size = real_data.size(0)
    # One weight per sample, broadcast over every remaining dimension; this
    # works for image batches (B, C, H, W) as well as any other rank.
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape).to(device)

    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates.requires_grad_(True)

    disc_interpolates = netD(interpolates.float())

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True,
                              only_inputs=True)[0]

    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

In [None]:
 def dis_forward(netD, ground_truth, x_inpaint):
    #assert ground_truth.size() == x_inpaint.size()
    batch_size = ground_truth.size(0)
    batch_data = torch.cat([ground_truth, x_inpaint], dim=0)
    batch_output = netD(batch_data.float())
    real_pred, fake_pred = torch.split(batch_output, batch_size, dim=0)

    return real_pred, fake_pred

In [None]:
#summary(net_gen, [(3,256,256),(1,256,256)])

In [None]:
#summary(net_local_dis, (3,128,128))

In [None]:
#summary(net_global_dis, (3,256,256))

In [None]:
# One DataLoader per split. Only training batches are shuffled; validation and
# test batches are iterated in a fixed order.
trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
testloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [None]:
# Download the pretrained discriminator/generator weights (dis_00430000.pt,
# gen_00430000.pt). The checkpoint cell below loads them when model_train is
# False. The full "uc?id=" URL is passed instead of `--id`, which newer gdown
# releases deprecate and later remove.
!gdown "https://drive.google.com/uc?id=1a3jdE3Mzz_JiVP6CatF4m58zeE99TYKX"
!gdown "https://drive.google.com/uc?id=1JXlWTIZBoYZrSQl7KnKIiiEtOiX58A-v"

Downloading...
From: https://drive.google.com/uc?id=1a3jdE3Mzz_JiVP6CatF4m58zeE99TYKX
To: /content/dis_00430000.pt
21.7MB [00:00, 190MB/s]
Downloading...
From: https://drive.google.com/uc?id=1JXlWTIZBoYZrSQl7KnKIiiEtOiX58A-v
To: /content/gen_00430000.pt
14.4MB [00:00, 127MB/s]


In [None]:
# Restore weights.
# - Training: resume from this notebook's own Drive checkpoints if they exist.
# - Otherwise: load the downloaded pretrained weights.
# map_location remaps tensors saved on another device (e.g. a GPU session) onto
# the current one, so loading also works on a CPU-only runtime.
if model_train:
    gen_ckpt_path = os.path.join(CHECKPOINT, "net_gen.pth")
    dis_ckpt_path = os.path.join(CHECKPOINT, "net_dis.pth")

    if os.path.exists(gen_ckpt_path):
        checkpoints = torch.load(gen_ckpt_path, map_location=device)
        net_gen.load_state_dict(checkpoints['net_gen_state_dict'])
        optimizer_g.load_state_dict(checkpoints['optimizer_g_state_dict'])
        start_epochs = checkpoints['epoch']

    if os.path.exists(dis_ckpt_path):
        checkpoints = torch.load(dis_ckpt_path, map_location=device)
        net_local_dis.load_state_dict(checkpoints['net_local_dis_state_dict'])
        net_global_dis.load_state_dict(checkpoints['net_global_dis_state_dict'])
        optimizer_d.load_state_dict(checkpoints['optimizer_d_state_dict'])
        # The training loop writes the same epoch into both files; if they ever
        # disagree, the discriminator file's value is the one kept.
        start_epochs = checkpoints['epoch']
else:
    checkpoint_dis = torch.load("/content/dis_00430000.pt", map_location=device)
    net_gen.load_state_dict(torch.load("/content/gen_00430000.pt", map_location=device))
    net_local_dis.load_state_dict(checkpoint_dis['localD'])
    net_global_dis.load_state_dict(checkpoint_dis['globalD'])

In [None]:
def epoch_time(epoch_end, epoch_start):
    """Split the interval ``epoch_end - epoch_start`` into minutes and seconds.

    Returns:
        ``(minutes, seconds)``. ``minutes`` is the floor-divided number of whole
        minutes (a float when the inputs are floats, e.g. ``time.time()`` stamps)
        and ``seconds`` is the remainder.
    """
    elapsed = epoch_end - epoch_start
    whole_minutes = elapsed // 60
    return whole_minutes, elapsed - 60 * whole_minutes

In [None]:
def train(net_gen, net_local_dis, net_global_dis, iterator, optimizer_g, optimizer_d, criterionL1, n_critic=5):
    """Run one WGAN-GP training epoch over ``iterator``.

    Each batch is masked with boxes from ``random_bbox``/``mask_image`` and
    inpainted by ``net_gen``. With ``idx`` counting batches from 1, each step
    does exactly one update:

    * ``idx % n_critic != 0``: critic step. Both critics are trained on
      ``wgan_d + 10 * wgan_gp``, with the generator output detached.
    * ``idx % n_critic == 0``: generator step. The generator is trained on
      ``1.2 * l1 + 1.2 * ae + 0.001 * wgan_g``.

    Only the losses needed for the current step are computed.

    Args:
        net_gen: generator; ``net_gen(x, mask)`` returns ``(x1, x2, offset_flow)``,
            of which ``x1`` and ``x2`` are used.
        net_local_dis: critic applied to the box crop (via ``local_patch``).
        net_global_dis: critic applied to the whole image.
        iterator: DataLoader yielding ground-truth image batches.
        optimizer_g: optimizer over the generator parameters.
        optimizer_d: optimizer over both critics' parameters.
        criterionL1: L1 loss module.
        n_critic: period, in batches, of generator updates (default 5).

    Returns:
        ``(epoch_loss, epoch_length)``.
        ``epoch_loss`` maps each loss name to its mean over the steps on which
        it was computed:
        - 'wgan_d', 'wgan_gp', 'd' on critic steps;
        - 'l1', 'ae', 'wgan_g', 'g' on generator steps;
        - 0.0 for a loss that was never computed.
        ``epoch_length`` is ``epoch_time``'s ``(minutes, seconds)``.
    """
    def _scalar(loss):
        # Reduce a non-scalar loss (e.g. from a criterion with reduction='none') to its mean.
        return loss if loss.dim() == 0 else torch.mean(loss)

    loss_sums = dict.fromkeys(('l1', 'ae', 'wgan_g', 'wgan_d', 'wgan_gp', 'g', 'd'), 0.)
    loss_counts = dict.fromkeys(loss_sums, 0)

    epoch_start = time.time()
    # Release leftovers once per epoch. Flushing gc and the CUDA cache on every
    # batch only forces repeated re-allocation.
    gc.collect()
    torch.cuda.empty_cache()

    net_gen.train()
    net_local_dis.train()
    net_global_dis.train()

    # The discounting weights do not depend on the batch, so build them once.
    # NOTE(review): assumes spatial_discounting_mask() is deterministic and that
    # its use_cuda flag should follow the training device -- confirm.
    sd_mask = spatial_discounting_mask(use_cuda=(device.type == 'cuda'))

    # Autograd anomaly detection is deliberately not enabled: it is a debugging
    # aid that slows down every backward pass.
    bar = pyprind.ProgBar(len(iterator), bar_char='█')
    for idx, ground_truth in enumerate(iterator, 1):
        generator_step = idx % n_critic == 0

        bboxes = random_bbox(batch_size=ground_truth.size(0))
        x, mask = mask_image(ground_truth, bboxes)
        ground_truth = ground_truth.to(device)
        x = x.to(device)
        mask = mask.to(device)

        # Critic steps never backpropagate into the generator, so its forward
        # pass only records a graph on generator steps.
        with torch.set_grad_enabled(generator_step):
            x1, x2, _ = net_gen(x.float(), mask)
        # Keep known pixels from x; take the predictions only where mask is set.
        x1_inpaint = x1 * mask + x * (1. - mask)
        x2_inpaint = x2 * mask + x * (1. - mask)
        local_patch_gt = local_patch(ground_truth, bboxes)
        local_patch_x1_inpaint = local_patch(x1_inpaint, bboxes)
        local_patch_x2_inpaint = local_patch(x2_inpaint, bboxes)

        if generator_step:
            # Generator losses:
            # - 'l1': sd_mask-weighted L1 on the box crop;
            # - 'ae': L1 where mask is 0;
            # - 'wgan_g': adversarial term from both critics' fake scores.
            _, local_fake_pred = dis_forward(net_local_dis, local_patch_gt, local_patch_x2_inpaint)
            _, global_fake_pred = dis_forward(net_global_dis, ground_truth, x2_inpaint)
            losses = {
                'l1': 1.2 * criterionL1(local_patch_x1_inpaint * sd_mask, local_patch_gt * sd_mask)
                      + criterionL1(local_patch_x2_inpaint * sd_mask, local_patch_gt * sd_mask),
                'ae': 1.2 * criterionL1(x1 * (1. - mask), ground_truth * (1. - mask))
                      + criterionL1(x2 * (1. - mask), ground_truth * (1. - mask)),
                'wgan_g': -torch.mean(local_fake_pred) - torch.mean(global_fake_pred),
            }
            losses = {name: _scalar(value) for name, value in losses.items()}
            losses['g'] = losses['l1'] * 1.2 + losses['ae'] * 1.2 + losses['wgan_g'] * 0.001

            optimizer_g.zero_grad()
            losses['g'].backward()
            optimizer_g.step()
        else:
            # Critic losses: Wasserstein estimate plus gradient penalty, for each critic.
            fake_patch = local_patch_x2_inpaint.detach()
            fake_image = x2_inpaint.detach()
            local_real_pred, local_fake_pred = dis_forward(net_local_dis, local_patch_gt, fake_patch)
            global_real_pred, global_fake_pred = dis_forward(net_global_dis, ground_truth, fake_image)
            losses = {
                'wgan_d': torch.mean(local_fake_pred - local_real_pred)
                          + torch.mean(global_fake_pred - global_real_pred),
                'wgan_gp': calc_gradient_penalty(net_local_dis, local_patch_gt, fake_patch, device)
                           + calc_gradient_penalty(net_global_dis, ground_truth, fake_image, device),
            }
            losses = {name: _scalar(value) for name, value in losses.items()}
            losses['d'] = losses['wgan_d'] + losses['wgan_gp'] * 10

            optimizer_d.zero_grad()
            losses['d'].backward()
            optimizer_d.step()

        # Accumulate per loss so each mean uses only the steps that produced it.
        for name, value in losses.items():
            loss_sums[name] += value.item()
            loss_counts[name] += 1
        bar.update()

    epoch_end = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    epoch_loss = {name: loss_sums[name] / loss_counts[name] if loss_counts[name] else 0.
                  for name in loss_sums}
    return epoch_loss, epoch_time(epoch_end, epoch_start)

In [None]:
def evaluate(net_gen, iterator, criterionL1, local_dis=None, global_dis=None):
    """Compute mean generator losses over ``iterator`` without updating any weights.

    Each batch is masked with boxes from ``random_bbox``/``mask_image``. Nothing
    is seeded here, so results vary between calls unless the caller seeds the
    RNGs.

    Args:
        net_gen: generator; ``net_gen(x, mask)`` returns ``(x1, x2, offset_flow)``,
            of which ``x1`` and ``x2`` are used.
        iterator: DataLoader yielding ground-truth image batches.
        criterionL1: L1 loss module.
        local_dis: critic for the box crop. Defaults to the notebook-level
            ``net_local_dis``.
        global_dis: critic for the whole image. Defaults to the notebook-level
            ``net_global_dis``.

    Returns:
        ``(epoch_loss, epoch_length)``. ``epoch_loss`` holds the per-batch mean
        of 'l1', 'ae', 'wgan_g' and 'g' (0.0 for an empty iterator).
        ``epoch_length`` is ``epoch_time``'s ``(minutes, seconds)``.
    """
    # Explicit critics are optional; otherwise use the notebook-level ones.
    if local_dis is None:
        local_dis = net_local_dis
    if global_dis is None:
        global_dis = net_global_dis

    def _scalar(loss):
        # Reduce a non-scalar loss (e.g. from a criterion with reduction='none') to its mean.
        return loss if loss.dim() == 0 else torch.mean(loss)

    loss_sums = dict.fromkeys(('l1', 'ae', 'wgan_g', 'g'), 0.)
    n_batches = 0

    epoch_start = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    net_gen.eval()
    local_dis.eval()
    global_dis.eval()

    # The discounting weights do not depend on the batch, so build them once.
    # NOTE(review): assumes spatial_discounting_mask() is deterministic and that
    # its use_cuda flag should follow the device -- confirm.
    sd_mask = spatial_discounting_mask(use_cuda=(device.type == 'cuda'))

    with torch.no_grad():
        bar = pyprind.ProgBar(len(iterator), bar_char='█')
        for ground_truth in iterator:
            bboxes = random_bbox(batch_size=ground_truth.size(0))
            x, mask = mask_image(ground_truth, bboxes)
            ground_truth = ground_truth.to(device)
            x = x.to(device)
            mask = mask.to(device)

            x1, x2, _ = net_gen(x.float(), mask)
            # Keep known pixels from x; take the predictions only where mask is set.
            x1_inpaint = x1 * mask + x * (1. - mask)
            x2_inpaint = x2 * mask + x * (1. - mask)
            local_patch_gt = local_patch(ground_truth, bboxes)
            local_patch_x1_inpaint = local_patch(x1_inpaint, bboxes)
            local_patch_x2_inpaint = local_patch(x2_inpaint, bboxes)

            _, local_fake_pred = dis_forward(local_dis, local_patch_gt, local_patch_x2_inpaint)
            _, global_fake_pred = dis_forward(global_dis, ground_truth, x2_inpaint)

            losses = {
                'l1': 1.2 * criterionL1(local_patch_x1_inpaint * sd_mask, local_patch_gt * sd_mask)
                      + criterionL1(local_patch_x2_inpaint * sd_mask, local_patch_gt * sd_mask),
                'ae': 1.2 * criterionL1(x1 * (1. - mask), ground_truth * (1. - mask))
                      + criterionL1(x2 * (1. - mask), ground_truth * (1. - mask)),
                'wgan_g': -torch.mean(local_fake_pred) - torch.mean(global_fake_pred),
            }
            losses = {name: _scalar(value) for name, value in losses.items()}
            losses['g'] = losses['l1'] * 1.2 + losses['ae'] * 1.2 + losses['wgan_g'] * 0.001

            for name, value in losses.items():
                loss_sums[name] += value.item()
            n_batches += 1
            bar.update()

    epoch_end = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    epoch_loss = {name: (total / n_batches if n_batches else 0.) for name, total in loss_sums.items()}
    return epoch_loss, epoch_time(epoch_end, epoch_start)

In [None]:
# Training driver: one train + validation pass per epoch, checkpointing both
# networks to Drive after every epoch so an interrupted session can resume.
train_loss = []
val_loss = []


def _save_checkpoint_atomically(state, path):
    """``torch.save`` ``state`` to ``path`` without leaving a truncated file behind.

    The state is written to a temporary sibling file, which ``os.replace`` then
    moves over ``path``. An interruption mid-write therefore keeps the previous
    checkpoint intact.
    """
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


if model_train:
    last_epoch = total_epochs + start_epochs
    for epoch in range(start_epochs + 1, last_epoch + 1):
        print("Starting Epoch[{0}/{1}]".format(epoch, last_epoch))

        epoch_start = time.time()

        train_epoch_loss, _ = train(net_gen, net_local_dis, net_global_dis, trainloader, optimizer_g, optimizer_d, criterionL1)
        train_loss.append(train_epoch_loss)
        print(" | Train Loss: Generator: {0}  |  Discriminator: {1}".format(train_epoch_loss['g'], train_epoch_loss['d']))

        val_epoch_loss, _ = evaluate(net_gen, valloader, criterionL1)
        val_loss.append(val_epoch_loss)
        print(" | Validation Loss: Generator: {0}".format(val_epoch_loss['g']))

        _save_checkpoint_atomically({
                'epoch': epoch,
                'net_gen_state_dict': net_gen.state_dict(),
                'optimizer_g_state_dict': optimizer_g.state_dict(),
                }, os.path.join(CHECKPOINT, "net_gen.pth"))
        _save_checkpoint_atomically({
                'epoch': epoch,
                'net_local_dis_state_dict': net_local_dis.state_dict(),
                'net_global_dis_state_dict': net_global_dis.state_dict(),
                'optimizer_d_state_dict': optimizer_d.state_dict(),
                }, os.path.join(CHECKPOINT, "net_dis.pth"))

        minutes, seconds = epoch_time(time.time(), epoch_start)
        print("Finished Epoch[{0}/{1}] in {2:.0f}m {3:.1f}s".format(epoch, last_epoch, minutes, seconds))

Starting Epoch[27/90]


  "The default behavior for interpolate/upsample with float scale_factor changed "
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
0% [██████████████████████████████] 100% | ETA: 00:00:00

 | Train Loss: Generator: -0.07538441772653824  |  Disctiminator: -244.9920135553503



Total time elapsed: 02:39:17
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:04:48


 | Validation Loss: Generator: -0.07023869422199909
Finished Epoch[27/90]
Starting Epoch[28/90]


0% [██████████████████████████████] 100% | ETA: 00:00:00

 | Train Loss: Generator: -0.036059386844847476  |  Disctiminator: -240.7046226423151



Total time elapsed: 01:10:02
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:48


 | Validation Loss: Generator: -0.19356585787272865
Finished Epoch[28/90]
Starting Epoch[29/90]


0% [██████████████████████████████] 100% | ETA: 00:00:00

 | Train Loss: Generator: -0.0636371394320836  |  Disctiminator: -242.54500389005872



Total time elapsed: 01:04:36
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:47


 | Validation Loss: Generator: -0.05802334587382475
Finished Epoch[29/90]
Starting Epoch[30/90]


KeyboardInterrupt: ignored

---

In [None]:
# Inpaint every test batch with freshly sampled boxes. For each image, keep the
# ground truth, the masked input and the inpainted result, each with channels
# moved last and multiplied by 255, for display in the next cell.
preds = []

# Inference mode for the generator. train() switches it back to training mode
# at the start of its next epoch.
net_gen.eval()
with torch.no_grad():  # no autograd graph is needed for inference
    for ground_truth in testloader:
        # Use a dedicated name: assigning `batch_size` at notebook level would
        # overwrite the batch_size setting with the last test batch's size.
        n_images = ground_truth.size(0)
        bboxes = random_bbox(batch_size=n_images)
        x, mask = mask_image(ground_truth, bboxes)

        ground_truth = ground_truth.to(device)
        x = x.to(device)
        mask = mask.to(device)

        x1, x2, offset_flow = net_gen(x.float(), mask)
        # Known pixels from the masked input, predictions where mask is set.
        x2_inpaint = x2 * mask + x * (1. - mask)

        for index in range(n_images):
            ground = np.moveaxis(ground_truth[index].cpu().numpy(), 0, -1) * 255
            masked = np.moveaxis(x[index].cpu().numpy(), 0, -1) * 255
            image = np.moveaxis(x2_inpaint[index].cpu().numpy(), 0, -1) * 255
            preds.append({'ground': ground, 'masked': masked, 'image': image})

In [None]:
# Show each test sample as one row, left to right: ground truth | masked input | inpainted output.
for render in preds:
    cv2_imshow(np.concatenate([render['ground'], render['masked'], render['image']], axis=1))