<a href="https://colab.research.google.com/github/adubowski/redi-xai/blob/main/inpainting/inpainting_gmcnn_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inpaint Various Datasets Using a Trained Inpainting Model
Most code here is taken directly from https://github.com/shepnerd/inpainting_gmcnn/tree/master/pytorch with minor adjustments and refactoring into a Jupyter notebook, including a convenient way of providing arguments for test_options.

The code cell under "Create elliptical masks" was written for this project (it is not taken from the repository above), and the code from "test.py" onwards has been significantly adapted from the repository's version.

Otherwise, each cell title names the module in the GitHub repository linked above from which that cell's code was taken.


### Load libraries

In [None]:
from google.colab import drive

## model.basemodel
import os
import torch
import torch.nn as nn

## model.basenet
# import os
# import torch
# import torch.nn as nn

## model.layer
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from util.utils import gauss_kernel
import torchvision.models as models
import numpy as np

## model.loss
# import torch
# import torch.nn as nn
import torch.autograd as autograd
import torch.nn.functional as F
# from model.layer import VGG19FeatLayer
from functools import reduce

## model.net
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from model.basemodel import BaseModel
# from model.basenet import BaseNet
# from model.loss import WGANLoss, IDMRFLoss
# from model.layer import init_weights, PureUpsampling, ConfidenceDrivenMaskLayer, SpectralNorm
# import numpy as np


## options.test_options
import argparse
# import os
import time

## original code for ellipse masks
import cv2
# import numpy as np
from numpy import random
from numpy.random import randint
# from matplotlib import pyplot as plt
import math

## utils.utils
# import numpy as np
import scipy.stats as st
# import cv2
# import time
# import os
import glob

## Dependencies from test.py
# import numpy as np
# import cv2
# import os
import subprocess
# import glob
# from options.test_options import TestOptions
# from model.net import InpaintingModel_GMCNN
# from util.utils import generate_rect_mask, generate_stroke_mask, getLatest

In [None]:
# Colab mount point for Google Drive; the project folder lives under it.
_DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(_DRIVE_MOUNT_POINT)

# Project root on the mounted Drive; TestOptions builds its default data/model/result paths from it.
dir_path = _DRIVE_MOUNT_POINT + "/MyDrive/redi-detecting-cheating"

### model.basemodel

In [None]:
class BaseModel(nn.Module):
    """A composite model made of several nets, each defined in its own class.

    Subclasses list net names in ``self.model_names``. A name ``X`` refers to the
    attribute ``self.netX``, and the save/load/print helpers iterate over those names.
    """

    def __init__(self):
        super(BaseModel, self).__init__()

    def init(self, opt):
        """Store the options and derive the save directory and the target device.

        Args:
            opt: parsed options; ``gpu_ids`` and ``model_folder`` are read here.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.save_dir = opt.model_folder
        # Use the first listed GPU if any are given, otherwise the CPU.
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.model_names = []

    def setInput(self, inputData):
        self.input = inputData

    # The hooks below are no-ops meant to be overridden by subclasses.
    def forward(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        pass

    def get_current_losses(self):
        pass

    def update_learning_rate(self):
        pass

    def test(self):
        """Run ``forward`` without tracking gradients."""
        with torch.no_grad():
            self.forward()

    def save_networks(self, which_epoch):
        """Save each registered net's state dict to ``<save_dir>/<which_epoch>_net_<name>.pth``."""
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (which_epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                # The former GPU and CPU branches were identical, so one save suffices.
                torch.save(net.state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Drop running_mean/running_var entries that an InstanceNorm layer does not track.

        This is for checkpoints written by PyTorch versions older than 0.4. ``keys``
        is the dotted state-dict key split on '.', and it is walked recursively.
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, load_path):
        """Load the checkpoint at ``load_path`` into every registered net.

        NOTE(review): the same file is loaded into every name in ``model_names``, so
        this only makes sense when a single net is registered.
        """
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # Map the stored tensors onto this model's device. A checkpoint saved
                # on a GPU can then be loaded on a CPU-only host. Without ``init()``
                # there is no device yet, so None keeps torch's default behaviour.
                state_dict = torch.load(load_path, map_location=getattr(self, 'device', None))
                # Patch InstanceNorm checkpoints written before PyTorch 0.4.
                for key in list(state_dict.keys()):  # copy the keys: the loop mutates the dict
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose=True):
        """Print each registered net's parameter count, and the net itself if ``verbose``."""
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on all parameters of ``nets`` (a net or a list of nets).

        The default (False) freezes the nets so that no gradients are computed for them.
        None entries are skipped.
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad


### model.basenet

In [None]:
class BaseNet(nn.Module):
    """Base class for a single network, with simple checkpoint save/load helpers."""

    def __init__(self):
        super(BaseNet, self).__init__()

    def init(self, opt):
        """Store the options and derive the checkpoint directory and the target device.

        Args:
            opt: parsed options; ``gpu_ids`` and ``checkpoint_dir`` are read here.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.save_dir = opt.checkpoint_dir
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')

    def forward(self, *input):
        return super(BaseNet, self).forward(*input)

    def test(self, *input):
        """Run ``forward`` without tracking gradients. The output is discarded."""
        with torch.no_grad():
            self.forward(*input)

    def save_network(self, network_label, epoch_label):
        """Save the weights to ``<save_dir>/<epoch_label>_net_<network_label>.pth`` as CPU tensors."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        # Copy the tensors to the CPU without moving the module itself. The former
        # ``self.cpu()`` left the network on the CPU after every save. The
        # OrderedDict is updated in place so its version metadata is kept.
        state_dict = self.state_dict()
        for key in list(state_dict.keys()):
            state_dict[key] = state_dict[key].cpu()
        torch.save(state_dict, save_path)

    def load_network(self, network_label, epoch_label):
        """Load ``<save_dir>/<epoch_label>_net_<network_label>.pth`` as far as it fits.

        The method tries a strict load first. If that fails, it loads only the keys the
        network has. If that also fails, it copies every size-matching tensor and prints
        the top-level names of the layers left uninitialised. A missing file is reported
        and ignored.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        if not os.path.isfile(save_path):
            print('%s not exists yet!' % save_path)
            return

        # Read the file once, onto the CPU. load_state_dict copies the values into
        # the parameters wherever the network currently lives.
        pretrained_dict = torch.load(save_path, map_location='cpu')
        try:
            self.load_state_dict(pretrained_dict)
            return
        except RuntimeError:  # missing/unexpected keys or size mismatches
            pass

        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        try:
            self.load_state_dict(pretrained_dict)
            print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
        except RuntimeError:
            print('Pretrained network %s has fewer layers; The following are not initialized: ' % network_label)
            for k, v in pretrained_dict.items():
                if v.size() == model_dict[k].size():
                    model_dict[k] = v

            for k, v in model_dict.items():
                if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                    print(k.split('.')[0])
            self.load_state_dict(model_dict)


### model.layer

In [None]:
class Conv2d_BN(nn.Module):
    """A Conv2d layer followed by BatchNorm2d over its output channels.

    The constructor arguments are forwarded positionally to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d_BN, self).__init__()
        # nn.Sequential takes modules as positional arguments; passing a list
        # raises TypeError.
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, *input):
        return self.model(*input)


class upsampling(nn.Module):
    """Nearest-neighbour upsampling by an integer ``scale``, followed by a Conv2d.

    The conv-related arguments are forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, scale=2):
        super(upsampling, self).__init__()
        assert isinstance(scale, int)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.scale = scale

    def forward(self, x):
        h, w = x.size(2) * self.scale, x.size(3) * self.scale
        # align_corners is only valid for the linear-family modes. Passing it
        # together with mode='nearest' makes F.interpolate raise ValueError.
        xout = self.conv(F.interpolate(input=x, size=(h, w), mode='nearest'))
        return xout


class PureUpsampling(nn.Module):
    """Parameter-free upsampling of the last two dims (dims 2 and 3) by an integer factor.

    Args:
        scale: integer upsampling factor.
        mode: interpolation mode for ``F.interpolate``. ``align_corners=True`` is
            passed for every mode except 'nearest', which does not accept it.
    """

    def __init__(self, scale=2, mode='bilinear'):
        super(PureUpsampling, self).__init__()
        assert isinstance(scale, int)
        self.scale = scale
        self.mode = mode

    def forward(self, x):
        height = x.size(2) * self.scale
        width = x.size(3) * self.scale
        extra_kwargs = {} if self.mode == 'nearest' else {'align_corners': True}
        return F.interpolate(input=x, size=(height, width), mode=self.mode, **extra_kwargs)


class GaussianBlurLayer(nn.Module):
    """Reflection-pad the input, then convolve it with a Gaussian kernel.

    Args:
        size: kernel size passed to ``gauss_kernel``.
        sigma: Gaussian sigma passed to ``gauss_kernel``.
        in_channels: channel count used for both kernel channel arguments.
        stride: convolution stride.
        pad: reflection padding applied on every side before the convolution.
    """

    def __init__(self, size, sigma, in_channels=1, stride=1, pad=1):
        super(GaussianBlurLayer, self).__init__()
        self.size = size
        self.sigma = sigma
        self.ch = in_channels
        self.stride = stride
        self.pad = nn.ReflectionPad2d(pad)

    def forward(self, x):
        # NOTE(review): gauss_kernel is defined in another cell. It is assumed to
        # return a float32 numpy array usable directly as a conv2d weight.
        kernel = gauss_kernel(self.size, self.sigma, self.ch, self.ch)
        # Put the kernel on the input's device instead of the hard-coded .cuda(),
        # so CPU inputs (and non-default GPUs) also work.
        kernel_tensor = torch.from_numpy(kernel).to(x.device)
        x = self.pad(x)
        blurred = F.conv2d(x, kernel_tensor, stride=self.stride)
        return blurred


class ConfidenceDrivenMaskLayer(nn.Module):
    """Spread confidence from the valid pixels into the hole by repeated Gaussian blurring.

    Args:
        size: Gaussian kernel size for the propagation blur.
        sigma: Gaussian sigma for the propagation blur.
        iters: number of blur-and-reinject iterations. With 0, ``forward`` returns None.
    """

    def __init__(self, size=65, sigma=1.0/40, iters=7):
        super(ConfidenceDrivenMaskLayer, self).__init__()
        self.size, self.sigma, self.iters = size, sigma, iters
        self.propagationLayer = GaussianBlurLayer(size, sigma, pad=32)

    def forward(self, mask):
        # In mask, 1 marks missing pixels and 0 marks valid ones.
        known = 1 - mask
        current = known
        confidence = None
        for _ in range(self.iters):
            # Blur, keep the result only inside the hole, then add the known pixels
            # back (at full confidence) as the input to the next round.
            confidence = self.propagationLayer(current) * mask
            current = confidence + known
        return confidence


class VGG19(nn.Module):
    """Randomly-initialised VGG19 convolutional trunk that returns every intermediate activation.

    Args:
        pool: 'max' for max pooling or 'avg' for average pooling between the blocks.

    Raises:
        ValueError: if ``pool`` is neither 'max' nor 'avg'.

    ``forward`` returns a dict keyed 'r{block}{conv}' (ReLU outputs) and 'p{block}'
    (pool outputs), in network order.
    """

    def __init__(self, pool='max'):
        super(VGG19, self).__init__()
        # Validate first. Previously an unknown value left pool1..pool5 undefined
        # and only failed later, inside forward(), with an AttributeError.
        pool_layers = {'max': nn.MaxPool2d, 'avg': nn.AvgPool2d}
        if pool not in pool_layers:
            raise ValueError("pool must be 'max' or 'avg', got %r" % (pool,))

        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        # pool1 .. pool5, each halving the spatial size.
        for i in range(1, 6):
            setattr(self, 'pool%d' % i, pool_layers[pool](kernel_size=2, stride=2))

    def forward(self, x):
        out = {}
        out['r11'] = F.relu(self.conv1_1(x))
        out['r12'] = F.relu(self.conv1_2(out['r11']))
        out['p1'] = self.pool1(out['r12'])
        out['r21'] = F.relu(self.conv2_1(out['p1']))
        out['r22'] = F.relu(self.conv2_2(out['r21']))
        out['p2'] = self.pool2(out['r22'])
        out['r31'] = F.relu(self.conv3_1(out['p2']))
        out['r32'] = F.relu(self.conv3_2(out['r31']))
        out['r33'] = F.relu(self.conv3_3(out['r32']))
        out['r34'] = F.relu(self.conv3_4(out['r33']))
        out['p3'] = self.pool3(out['r34'])
        out['r41'] = F.relu(self.conv4_1(out['p3']))
        out['r42'] = F.relu(self.conv4_2(out['r41']))
        out['r43'] = F.relu(self.conv4_3(out['r42']))
        out['r44'] = F.relu(self.conv4_4(out['r43']))
        out['p4'] = self.pool4(out['r44'])
        out['r51'] = F.relu(self.conv5_1(out['p4']))
        out['r52'] = F.relu(self.conv5_2(out['r51']))
        out['r53'] = F.relu(self.conv5_3(out['r52']))
        out['r54'] = F.relu(self.conv5_4(out['r53']))
        out['p5'] = self.pool5(out['r54'])
        return out


class VGG19FeatLayer(nn.Module):
    """Feature extractor over torchvision's pretrained VGG19 ``features`` stack.

    ``forward`` returns a dict with every intermediate activation. The keys are
    ``conv{block}_{i}`` and ``relu{block}_{i}`` (``i`` counts the convs inside a
    block, and a ReLU shares the index of the conv it follows), ``pool_{block}``,
    and ``bn_{block}`` if BatchNorm layers are present.

    Args:
        device: device for the VGG weights and the mean tensor. The default 'cuda'
            matches the former hard-coded ``.cuda()``.
    """

    def __init__(self, device='cuda'):
        super(VGG19FeatLayer, self).__init__()
        self.vgg19 = models.vgg19(pretrained=True).features.eval().to(device)
        # Per-channel means subtracted from the input.
        # NOTE(review): the input is only mean-shifted, not divided by a std.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)

    def forward(self, x):
        out = {}
        x = x - self.mean
        ci = 1  # block index, advanced after each max-pool
        ri = 0  # conv index inside the current block
        for layer in self.vgg19.children():
            if isinstance(layer, nn.Conv2d):
                ri += 1
                name = 'conv{}_{}'.format(ci, ri)
            elif isinstance(layer, nn.ReLU):
                # Reuse the conv's index (relu3_2 follows conv3_2). Incrementing
                # here mislabelled the layers: the key 'relu3_2' used by the losses
                # actually pointed at the ReLU after conv3_1.
                name = 'relu{}_{}'.format(ci, ri)
                # Use a non-inplace ReLU so the conv output already stored in
                # ``out`` is not overwritten.
                layer = nn.ReLU(inplace=False)
            elif isinstance(layer, nn.MaxPool2d):
                ri = 0
                name = 'pool_{}'.format(ci)
                ci += 1
            elif isinstance(layer, nn.BatchNorm2d):
                name = 'bn_{}'.format(ci)
            else:
                raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
            x = layer(x)
            out[name] = x
        return out


def init_weights(net, init_type='normal', gain=0.02):
    """Re-initialise every Conv*/Linear and BatchNorm2d submodule of ``net`` in place.

    Args:
        net: module whose submodules are visited via ``net.apply``.
        init_type: scheme for Conv/Linear weights: 'normal', 'xavier', 'kaiming'
            or 'orthogonal'. Their biases are zeroed.
        gain: std for 'normal', gain for 'xavier'/'orthogonal', and the std of
            BatchNorm2d weights around 1.0.

    Raises:
        NotImplementedError: when a Conv/Linear module is visited and ``init_type``
            is not one of the supported schemes.
    """
    weight_initialisers = {
        'normal': lambda weight: nn.init.normal_(weight, 0.0, gain),
        'xavier': lambda weight: nn.init.xavier_normal_(weight, gain=gain),
        'kaiming': lambda weight: nn.init.kaiming_normal_(weight, a=0, mode='fan_in'),
        'orthogonal': lambda weight: nn.init.orthogonal_(weight, gain=gain),
    }

    def _init_module(module):
        classname = module.__class__.__name__
        is_conv_or_linear = 'Conv' in classname or 'Linear' in classname
        if is_conv_or_linear and hasattr(module, 'weight'):
            initialise = weight_initialisers.get(init_type)
            if initialise is None:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            initialise(module.weight.data)
            if getattr(module, 'bias', None) is not None:
                nn.init.constant_(module.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(module.weight.data, 1.0, gain)
            nn.init.constant_(module.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(_init_module)


def init_net(net, init_type='normal', gpu_ids=None):
    """Optionally place ``net`` on the GPU(s), then initialise its weights.

    Args:
        net: module to prepare.
        init_type: scheme forwarded to ``init_weights``.
        gpu_ids: GPU ids. When non-empty, ``net`` is moved to ``gpu_ids[0]`` and
            wrapped in ``torch.nn.DataParallel`` over all listed ids. ``None`` (the
            default) or an empty sequence leaves ``net`` where it is.

    Returns:
        The initialised (and possibly DataParallel-wrapped) network.
    """
    # Use None instead of a shared mutable ``[]`` default.
    if gpu_ids is None:
        gpu_ids = []
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type)
    return net


def l2normalize(v, eps=1e-12):
    """Return ``v`` divided by its L2 norm; ``eps`` keeps a zero vector from dividing by zero."""
    denominator = v.norm() + eps
    return v / denominator


class SpectralNorm(nn.Module):
    """Spectral-normalisation wrapper that estimates the largest singular value by power iteration.

    On construction, the wrapped module's ``<name>`` parameter is replaced by
    ``<name>_bar`` (the raw weight) and two fixed (requires_grad=False) vectors,
    ``<name>_u`` and ``<name>_v``. Before every forward pass, ``<name>`` is set to
    ``<name>_bar / sigma``, where sigma is estimated from the weight reshaped to
    (rows, -1).

    Args:
        module: module to wrap, for example a conv layer.
        name: name of the weight parameter to normalise.
        power_iteration: number of power-iteration steps per forward call.
    """

    def __init__(self, module, name='weight', power_iteration=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iteration = power_iteration
        # Reuse the u/v/bar parameters if the module already carries them.
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Refine u and v by power iteration, then write the normalised weight back."""
        module, base = self.module, self.name
        u = getattr(module, base + '_u')
        v = getattr(module, base + '_v')
        w = getattr(module, base + '_bar')

        rows = w.data.shape[0]
        for _ in range(self.power_iteration):
            w_mat = w.view(rows, -1).data
            v.data = l2normalize(torch.mv(torch.t(w_mat), u.data))
            u.data = l2normalize(torch.mv(w_mat, v.data))

        # sigma goes through w itself (not .data), so gradients reach <name>_bar.
        sigma = u.dot(w.view(rows, -1).mv(v))
        setattr(module, base, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True if ``<name>_u``, ``<name>_v`` and ``<name>_bar`` already exist."""
        return all(hasattr(self.module, self.name + suffix) for suffix in ('_u', '_v', '_bar'))

    def _make_params(self):
        """Replace ``<name>`` with ``<name>_bar`` plus random unit vectors ``<name>_u`` and ``<name>_v``."""
        weight = getattr(self.module, self.name)

        rows = weight.data.shape[0]
        cols = weight.view(rows, -1).data.shape[1]

        u = nn.Parameter(weight.data.new(rows).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(weight.data.new(cols).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(weight.data)

        del self.module._parameters[self.name]

        for suffix, param in (('_u', u), ('_v', v), ('_bar', w_bar)):
            self.module.register_parameter(self.name + suffix, param)

    def forward(self, *input):
        self._update_u_v()
        return self.module.forward(*input)


class PartialConv(nn.Module):
    """Mask-aware convolution: renormalise the masked input by local mask coverage, then convolve.

    Args:
        in_channels: input channels of the wrapped Conv2d.
        out_channels: output channels of the wrapped Conv2d.
        ksize: square kernel size. Reflection padding of ``ksize // 2`` is applied
            before the convolution.
        stride: stride of both the convolution and the mask max-pool.
    """

    def __init__(self, in_channels=3, out_channels=32, ksize=3, stride=1):
        super(PartialConv, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.fnum = 32  # NOTE(review): not used inside this class; kept for compatibility.
        self.padSize = self.ksize // 2
        self.pad = nn.ReflectionPad2d(self.padSize)
        self.eplison = 1e-5  # (sic) small constant that avoids division by zero
        self.conv = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize)

    def forward(self, x, mask):
        """Return ``(features, updated_mask)``.

        The updated mask is ``mask`` max-pooled with the same kernel size and stride
        as the convolution.
        """
        mask_ch = mask.size(1)
        # All-ones kernel that sums the mask over every window. It is created on the
        # mask's device and dtype instead of via the hard-coded .cuda(), so CPU
        # tensors work too.
        sum_kernel = torch.ones((mask_ch, mask_ch, self.ksize, self.ksize),
                                dtype=mask.dtype, device=mask.device)

        x = x * mask / (F.conv2d(mask, sum_kernel, stride=1, padding=self.padSize) + self.eplison)
        x = self.pad(x)
        x = self.conv(x)
        mask = F.max_pool2d(mask, self.ksize, stride=self.stride, padding=self.padSize)
        return x, mask


class GatedConv(nn.Module):
    """Gated convolution: ``act(convf(p)) * sigmoid(convm(p))``, where ``p`` is the padded input.

    Args:
        in_channels: input channels.
        out_channels: output channels of both the feature and the gate branch.
        ksize: square kernel size. Reflection padding of ``ksize // 2`` is applied first.
        stride: stride of both convolutions.
        act: activation applied to the feature branch.
    """

    def __init__(self, in_channels=3, out_channels=32, ksize=3, stride=1, act=F.elu):
        super(GatedConv, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.act = act
        self.padSize = self.ksize // 2
        self.pad = nn.ReflectionPad2d(self.padSize)
        self.convf = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize)
        # The gate reads the same reflection-padded input as convf, so it needs no
        # padding of its own; this keeps both branches the same spatial size.
        self.convm = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize)

    def forward(self, x):
        x = self.pad(x)
        # Previously the gate was computed from convf's activated *output*. That
        # crashed whenever in_channels != out_channels, including the defaults 3 -> 32.
        feat = self.act(self.convf(x))
        gate = torch.sigmoid(self.convm(x))
        return feat * gate


class GatedDilatedConv(nn.Module):
    """Dilated gated convolution: ``act(convf(p)) * sigmoid(convm(p))``, where ``p`` is the padded input.

    With stride 1 the output spatial size is ``H + 2*pad - dilation*(ksize-1)``. Pass
    ``pad = dilation*(ksize-1)//2`` to preserve the size; the defaults (pad=1,
    dilation=2) shrink each spatial dim by 2.

    Args:
        in_channels: input channels.
        out_channels: output channels of both branches.
        ksize: square kernel size.
        stride: stride of both convolutions.
        pad: reflection padding applied before the convolutions.
        dilation: dilation of both convolutions.
        act: activation applied to the feature branch.
    """

    def __init__(self, in_channels, out_channels, ksize=3, stride=1, pad=1, dilation=2, act=F.elu):
        super(GatedDilatedConv, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.act = act
        self.padSize = pad
        self.pad = nn.ReflectionPad2d(self.padSize)
        self.convf = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize, dilation=dilation)
        # The gate reads the same padded input as convf, with no extra padding, so
        # both branches produce maps of the same size.
        self.convm = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=ksize, dilation=dilation)

    def forward(self, x):
        x = self.pad(x)
        # Previously the gate was computed from convf's output with its own padding.
        # The two branches then had different spatial sizes (and channel counts when
        # in_channels != out_channels), so the product always failed.
        feat = self.act(self.convf(x))
        gate = torch.sigmoid(self.convm(x))
        return feat * gate


### model.loss

In [None]:
class WGANLoss(nn.Module):
    """Wasserstein GAN losses computed from critic scores.

    Calling ``loss(input, target)`` returns a dict with:
        'd_loss': ``mean(input - target)``
        'g_loss': ``-mean(input)``
    NOTE(review): presumably ``input`` scores generated samples and ``target`` scores
    real ones; confirm with the caller.
    """

    def __init__(self):
        super(WGANLoss, self).__init__()

    # Defined as forward() rather than overriding __call__, so that
    # nn.Module.__call__ (and any registered hooks) still runs.
    def forward(self, input, target):
        d_loss = (input - target).mean()
        g_loss = -input.mean()
        return {'g_loss': g_loss, 'd_loss': d_loss}


def gradient_penalty(xin, yout, mask=None):
    """WGAN-GP penalty: the batch mean of ``(||d yout / d xin||_2 - 1) ** 2``.

    Args:
        xin: tensor with ``requires_grad=True`` from which ``yout`` was computed.
            Dim 0 is treated as the batch; the gradient is flattened per sample.
        yout: critic output.
        mask: optional tensor multiplied into the gradients before the norm.

    Returns:
        A scalar tensor. The graph is kept (``create_graph=True``) so the penalty
        can itself be back-propagated.
    """
    # torch.ones_like matches yout's device and dtype; the former .cuda() call
    # failed on CPU. The deprecated, ignored ``only_inputs`` flag is dropped.
    gradients = autograd.grad(yout, xin, create_graph=True,
                              grad_outputs=torch.ones_like(yout), retain_graph=True)[0]
    if mask is not None:
        gradients = gradients * mask
    gradients = gradients.view(gradients.size(0), -1)
    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gp


def random_interpolate(gt, pred):
    """Return ``gt * a + pred * (1 - a)`` with one random weight ``a`` ~ U[0, 1) per sample.

    Dim 0 is treated as the batch. ``a`` has shape (batch, 1, ..., 1), so it
    broadcasts over all remaining dims for inputs of any rank.
    """
    batch_size = gt.size(0)
    # Draw from the default generator as before, then move the weights to gt's
    # device instead of the hard-coded .cuda(). For 4-D inputs the shape stays
    # (batch, 1, 1, 1), exactly as before.
    alpha = torch.rand(batch_size, *([1] * (gt.dim() - 1))).to(gt.device)
    interpolated = gt * alpha + pred * (1 - alpha)
    return interpolated


class IDMRFLoss(nn.Module):
    """Implicit-diversified MRF loss between VGG features of a generated and a target image.

    The total is ``lambda_style * sum(style-layer MRF terms) + lambda_content *
    sum(content-layer MRF terms)``. Both terms are stored on the instance as
    ``style_loss`` and ``content_loss``.
    """
    # featlayer: class instantiated once with no arguments. Its instances map an
    # image batch to a dict of feature maps keyed by layer name.
    def __init__(self, featlayer=VGG19FeatLayer):
        super(IDMRFLoss, self).__init__()
        self.featlayer = featlayer()
        self.feat_style_layers = {'relu3_2': 1.0, 'relu4_2': 1.0}    # layer name -> weight
        self.feat_content_layers = {'relu4_2': 1.0}                  # layer name -> weight
        self.bias = 1.0                 # offset inside exp((bias - d) / sigma)
        self.nn_stretch_sigma = 0.5     # bandwidth of that exponential
        self.lambda_style = 1.0
        self.lambda_content = 1.0

    def sum_normalize(self, featmaps):
        """Divide ``featmaps`` by its sum over dim 1."""
        reduce_sum = torch.sum(featmaps, dim=1, keepdim=True)
        return featmaps / reduce_sum

    def patch_extraction(self, featmaps):
        """Turn every spatial position of ``featmaps`` into a 1x1 conv filter.

        With patch_size = stride = 1, a (N, C, H, W) input becomes (N*H*W, C, 1, 1).
        This is usable as the weight of F.conv2d. The result is also cached on
        ``self.patches_OIHW``.
        """
        patch_size = 1
        patch_stride = 1
        patches_as_depth_vectors = featmaps.unfold(2, patch_size, patch_stride).unfold(3, patch_size, patch_stride)
        self.patches_OIHW = patches_as_depth_vectors.permute(0, 2, 3, 1, 4, 5)
        dims = self.patches_OIHW.size()
        self.patches_OIHW = self.patches_OIHW.view(-1, dims[3], dims[4], dims[5])
        return self.patches_OIHW

    def compute_relative_distances(self, cdist):
        """Divide each distance by the minimum over dim 1 (plus 1e-5 to avoid division by zero)."""
        epsilon = 1e-5
        div = torch.min(cdist, dim=1, keepdim=True)[0]
        relative_dist = cdist / (div + epsilon)
        return relative_dist

    def exp_norm_relative_dist(self, relative_dist):
        """Map distances to ``exp((bias - d) / sigma)`` and normalise over dim 1.

        The result is also cached on ``self.cs_NCHW``.
        """
        scaled_dist = relative_dist
        dist_before_norm = torch.exp((self.bias - scaled_dist)/self.nn_stretch_sigma)
        self.cs_NCHW = self.sum_normalize(dist_before_norm)
        return self.cs_NCHW

    def mrf_loss(self, gen, tar):
        """MRF term between two 4-D feature maps ``gen`` and ``tar``, summed over the batch.

        Both maps are centred by the target's mean over dim 1 and L2-normalised over
        dim 1. Each generated sample is then convolved with the 1x1 patches of its
        target, which gives cosine similarities with one output channel per target
        position. NOTE(review): the norm division has no epsilon, so a zero feature
        vector yields NaN.
        """
        meanT = torch.mean(tar, 1, keepdim=True)
        gen_feats, tar_feats = gen - meanT, tar - meanT

        gen_feats_norm = torch.norm(gen_feats, p=2, dim=1, keepdim=True)
        tar_feats_norm = torch.norm(tar_feats, p=2, dim=1, keepdim=True)

        gen_normalized = gen_feats / gen_feats_norm
        tar_normalized = tar_feats / tar_feats_norm

        cosine_dist_l = []
        BatchSize = tar.size(0)

        # Each sample is compared only with its own target, hence the per-sample loop.
        for i in range(BatchSize):
            tar_feat_i = tar_normalized[i:i+1, :, :, :]
            gen_feat_i = gen_normalized[i:i+1, :, :, :]
            patches_OIHW = self.patch_extraction(tar_feat_i)

            cosine_dist_i = F.conv2d(gen_feat_i, patches_OIHW)
            cosine_dist_l.append(cosine_dist_i)
        cosine_dist = torch.cat(cosine_dist_l, dim=0)
        # Map cosine similarity in [-1, 1] to a distance in [0, 1].
        cosine_dist_zero_2_one = - (cosine_dist - 1) / 2
        relative_dist = self.compute_relative_distances(cosine_dist_zero_2_one)
        rela_dist = self.exp_norm_relative_dist(relative_dist)
        dims_div_mrf = rela_dist.size()
        # For every target patch (dim 1), take the maximum over the generated positions.
        k_max_nc = torch.max(rela_dist.view(dims_div_mrf[0], dims_div_mrf[1], -1), dim=2)[0]
        div_mrf = torch.mean(k_max_nc, dim=1)
        div_mrf_sum = -torch.log(div_mrf)
        div_mrf_sum = torch.sum(div_mrf_sum)
        return div_mrf_sum

    def forward(self, gen, tar):
        """Return the weighted style + content MRF loss between images ``gen`` and ``tar``."""
        gen_vgg_feats = self.featlayer(gen)
        tar_vgg_feats = self.featlayer(tar)

        style_loss_list = [self.feat_style_layers[layer] * self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for layer in self.feat_style_layers]
        self.style_loss = reduce(lambda x, y: x+y, style_loss_list) * self.lambda_style

        content_loss_list = [self.feat_content_layers[layer] * self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for layer in self.feat_content_layers]
        self.content_loss = reduce(lambda x, y: x+y, content_loss_list) * self.lambda_content

        return self.style_loss + self.content_loss


class StyleLoss(nn.Module):
    """Weighted L1 distance between the Gram matrices of feature maps of two images.

    Args:
        featlayer: class instantiated once with no arguments. Its instances map an
            image batch to a dict of feature maps keyed by layer name.
        style_layers: {layer_name: weight}. Defaults to relu2_2, relu3_2 and relu4_2
            with weight 1.0 each.
    """

    def __init__(self, featlayer=VGG19FeatLayer, style_layers=None):
        super(StyleLoss, self).__init__()
        self.featlayer = featlayer()
        if style_layers is not None:
            self.feat_style_layers = style_layers
        else:
            self.feat_style_layers = {'relu2_2': 1.0, 'relu3_2': 1.0, 'relu4_2': 1.0}

    def gram_matrix(self, x):
        """Per-sample Gram matrices of a (B, C, H, W) map, normalised by C*H*W.

        Returns a (B, C, C) tensor. The earlier version flattened the batch into the
        channel axis, which produced one (B*C, B*C) matrix that mixed features of
        different samples whenever B > 1. For B == 1 the values are unchanged.
        """
        b, c, h, w = x.size()
        feats = x.view(b, c, h * w)
        g = torch.bmm(feats, feats.transpose(1, 2))
        return g.div(c * h * w)

    def _l1loss(self, gen, tar):
        return torch.abs(gen-tar).mean()

    def forward(self, gen, tar):
        gen_vgg_feats = self.featlayer(gen)
        tar_vgg_feats = self.featlayer(tar)
        style_loss_list = [self.feat_style_layers[layer] * self._l1loss(self.gram_matrix(gen_vgg_feats[layer]), self.gram_matrix(tar_vgg_feats[layer])) for
                           layer in self.feat_style_layers]
        style_loss = reduce(lambda x, y: x + y, style_loss_list)
        return style_loss


class ContentLoss(nn.Module):
    """Weighted L1 distance between feature maps of a generated and a target image.

    Args:
        featlayer: class instantiated once with no arguments. Its instances map an
            image batch to a dict of feature maps keyed by layer name.
        content_layers: {layer_name: weight}; defaults to {'relu4_2': 1.0}.
    """

    def __init__(self, featlayer=VGG19FeatLayer, content_layers=None):
        super(ContentLoss, self).__init__()
        self.featlayer = featlayer()
        self.feat_content_layers = {'relu4_2': 1.0} if content_layers is None else content_layers

    def _l1loss(self, gen, tar):
        return (gen - tar).abs().mean()

    def forward(self, gen, tar):
        gen_feats = self.featlayer(gen)
        tar_feats = self.featlayer(tar)
        layers = self.feat_content_layers
        weighted_terms = [layers[name] * self._l1loss(gen_feats[name], tar_feats[name])
                          for name in layers]
        # Left fold over the terms; an empty layer dict raises TypeError, as before.
        return reduce(lambda total, term: total + term, weighted_terms)


class TVLoss(nn.Module):
    """Anisotropic total variation: the summed absolute differences between vertically
    and horizontally adjacent values of a 4-D tensor (dims 2 and 3)."""

    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, x):
        # Unpacking enforces a 4-D input (other ranks raise ValueError).
        _, _, height, width = x.shape
        vertical = (x[:, :, 1:, :] - x[:, :, :height - 1, :]).abs()
        horizontal = (x[:, :, :, 1:] - x[:, :, :, :width - 1]).abs()
        return vertical.sum() + horizontal.sum()


### options.test_options

In [None]:
class TestOptions:
    """Command-line style options for the GMCNN inpainting test run.

    ``parse`` takes an explicit token list or a dict, so it works inside a
    notebook. It post-processes the values and creates the output directories.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.initialized = False

    def initialize(self):
        """Register every option on ``self.parser``. This runs once, from ``parse``."""
        self.parser.add_argument('--dataset', type=str, default='skin_test',
                                 help='The dataset of the experiment.')
        self.parser.add_argument('--data_file', type=str, default=os.path.join(dir_path, 'data', 'processed', 'cancer'), help='the file storing testing file paths')
        self.parser.add_argument('--mask_dir', type=str, default=os.path.join(dir_path, 'data', 'masks', 'dilated-masks-224'), help='directory with saved masks, if applicable')
        self.parser.add_argument('--test_dir', type=str, default=os.path.join(dir_path, 'data', 'results_gmcnn'), help='models are saved here')
        self.parser.add_argument('--load_model_dir', type=str,
                                 default=os.path.join(dir_path, 'models', 'inpainting_gmcnn',
                                                      '20210601-112529_GMCNN_isic_b8_s224x224_gc32_dc64_randmask-ellipse'),
                                 help='pretrained models are given here')
        self.parser.add_argument('--seed', type=int, default=1, help='random seed')
        self.parser.add_argument('--gpu_ids', type=str, default='0')

        self.parser.add_argument('--model', type=str, default='gmcnn')
        self.parser.add_argument('--random_mask', type=int, default=1,
                                 help='using random mask')

        self.parser.add_argument('--img_shapes', type=str, default='224,224,3',
                                 help='given shape parameters: h,w,c or h,w')
        self.parser.add_argument('--mask_shapes', type=str, default='40',
                                 help='given mask parameters: h,w  or if mask_type==ellipse then should be number representing ellipse width.')
        self.parser.add_argument('--mask_type', type=str, default='ellipse')
        self.parser.add_argument('--test_num', type=int, default=-1)
        self.parser.add_argument('--mode', type=str, default='save')
        self.parser.add_argument('--phase', type=str, default='test')

        # for generator
        self.parser.add_argument('--g_cnum', type=int, default=32,
                                 help='# of generator filters in first conv layer')
        self.parser.add_argument('--d_cnum', type=int, default=32,
                                 help='# of discriminator filters in first conv layer')

        # Without this flag, a second parse() re-registered every option and
        # argparse raised a "conflicting option string" error.
        self.initialized = True

    def parse(self, args=None):
        """Parse ``args``, post-process the values, create output dirs and return the namespace.

        Args:
            args: a list of command-line tokens, or a dict mapping option flags
                (e.g. ``'--seed'``) to values. ``None`` means no arguments, i.e. all
                defaults. ``sys.argv`` is never read, because in a notebook it
                holds the kernel's own arguments.

        Raises:
            TypeError: if ``args`` is not a dict, a list or None.
        """
        if not self.initialized:
            self.initialize()

        if args is None:
            args = []
        elif isinstance(args, dict):
            # Flatten {'--flag': value} to ['--flag', 'value']. argparse only accepts
            # string tokens, so non-string values (e.g. ints) are converted.
            args = [str(item) for pair in args.items() for item in pair]
        elif not isinstance(args, list):
            # The former ``raise('...')`` raised a str, which itself fails with an
            # unrelated TypeError message.
            raise TypeError('args should be a dict or a list.')

        self.opt = self.parser.parse_args(args=args)

        if self.opt.data_file != '':
            self.opt.dataset_path = self.opt.data_file

        # makedirs also creates missing parents and tolerates an existing directory.
        os.makedirs(self.opt.test_dir, exist_ok=True)

        assert self.opt.random_mask in [0, 1]
        self.opt.random_mask = self.opt.random_mask == 1

        assert self.opt.mask_type in ['rect', 'stroke', 'ellipse', 'saved']     # 'ellipse' added to the original options

        self.opt.img_shapes = [int(x) for x in self.opt.img_shapes.split(',')]

        if self.opt.mask_type == 'ellipse':
            # Ellipse masks are described by a single width value.
            self.opt.mask_shapes = int(self.opt.mask_shapes)
        elif self.opt.mask_type == 'saved':
            pass  # masks are read from mask_dir; mask_shapes stays a string
        else:
            self.opt.mask_shapes = [int(x) for x in self.opt.mask_shapes.split(',')]

        # Result folder name: test date, dataset, model, image size, generator width,
        # and (for random masks) the mask type and seed.
        self.opt.date_str = 'test_' + time.strftime('%Y%m%d-%H%M%S')
        self.opt.model_folder = self.opt.date_str + '_' + self.opt.dataset + '_' + self.opt.model
        self.opt.model_folder += '_s' + str(self.opt.img_shapes[0]) + 'x' + str(self.opt.img_shapes[1])
        self.opt.model_folder += '_gc' + str(self.opt.g_cnum)
        if self.opt.random_mask:
            self.opt.model_folder += '_randmask-' + self.opt.mask_type
            self.opt.model_folder += '_seed-' + str(self.opt.seed)
        self.opt.saving_path = os.path.join(self.opt.test_dir, self.opt.model_folder)

        if self.opt.mode == 'save':
            # Creating the two sub-folders also creates saving_path itself.
            os.makedirs(os.path.join(self.opt.saving_path, "combined"), exist_ok=True)
            os.makedirs(os.path.join(self.opt.saving_path, "inpainted"), exist_ok=True)

        opt_items = vars(self.opt)

        print('------------ Options -------------')
        for k, v in sorted(opt_items.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')

        return self.opt


### Create elliptical masks
Original code (written for this project, not taken from the GMCNN repository).

In [None]:
def find_angle(pos1, pos2, ret_type='deg'):
    """Return the angle of the vector pointing from pos1 to pos2.

    The angle is computed as atan2(dy, dx) with dx = pos2[0] - pos1[0] and
    dy = pos2[1] - pos1[1], so it lies in (-180, 180] degrees or (-pi, pi]
    radians.

    Args:
        pos1: (x, y) start point.
        pos2: (x, y) end point.
        ret_type: 'deg' (default) for degrees, 'rads' for radians.

    Raises:
        ValueError: if ret_type is neither 'deg' nor 'rads'.
    """
    # Previously the x-difference subtracted pos1[1] instead of pos1[0]; that
    # only gave the right answer when pos1 had equal coordinates.
    dx = pos2[0] - pos1[0]
    dy = pos2[1] - pos1[1]
    angle_rads = math.atan2(dy, dx)

    if ret_type == 'rads':
        return angle_rads
    elif ret_type == 'deg':
        return math.degrees(angle_rads)                                     # Convert from radians to degrees.
    raise ValueError("ret_type must be 'deg' or 'rads', got {!r}".format(ret_type))


def sample_centre_pts(n, imsize, xlimits=(50,250), ylimits=(50,250)):
    """Rejection-sample n integer points to use as elliptical-mask centres.

    Both coordinates of each candidate are drawn uniformly from
    [0, imsize[0]), so imsize is assumed to be square. A candidate is accepted
    when at least one coordinate lies strictly outside its limits. Accepted
    points therefore fall in the border band around the rectangle
    xlimits x ylimits, not inside it.

    Returns:
        float array of shape (n, 2) holding the accepted (x, y) points.
    """
    centres = np.empty((n, 2))
    accepted = 0
    while accepted < n:
        # n candidate pairs are drawn but only the first one is used; this is
        # kept so the random stream matches previous runs exactly.
        x, y = randint(0, imsize[0], (n, 2))[0]
        inside_x = xlimits[0] <= x <= xlimits[1]
        inside_y = ylimits[0] <= y <= ylimits[1]
        if not (inside_x and inside_y):
            centres[accepted] = (x, y)
            accepted += 1
    return centres

def generate_ellipse_mask(imsize, mask_size, seed=None):
    """Create a random mask made of one or two filled ellipses.

    The number of ellipses is 1 + Binomial(1, 0.3): one 70% of the time, two
    30% of the time. Centres come from sample_centre_pts using bounds at 10% of
    the image size from each edge. Each ellipse gets a first axis of
    |Normal(mask_size, mask_size / 5)| and a second axis of that times
    Uniform(1, 3). It is rotated roughly along the direction from the image
    centre to its own centre, plus Normal(0, 5) degrees of jitter.

    Args:
        imsize: (rows, cols) of the mask; treated as square when sampling centres.
        mask_size: mean length of the first axis passed to cv2.ellipse.
        seed: if given, numpy's global random state is seeded first.

    Returns:
        float32 array of shape (1, 1, imsize[0], imsize[1]) with 1 inside the
        ellipses and 0 elsewhere.
    """
    im_centre = (int(imsize[0]/2), int(imsize[1]/2))
    x_bounds =  (int(0.1*imsize[0]), int(imsize[0] - 0.1*imsize[0]))        # Bounds passed to sample_centre_pts for the mask centres.
    y_bounds =  (int(0.1*imsize[1]), int(imsize[1] - 0.1*imsize[1]))

    if seed is not None:
        random.seed(seed)   # Set seed for repeatability

    n = 1 + random.binomial(1, 0.3)                                         # The number of masks per image either 1 (70% of the time) or 2 (30% of the time)
    centre_pts = sample_centre_pts(n, imsize, x_bounds, y_bounds)           # Get a random sample for the mask centres.

    startAngle = 0.0
    endAngle = 360.0                                                        # Draw full ellipses (although part may fall outside the image)

    mask = np.zeros((imsize[0], imsize[1], 1), np.float32)                  # Create blank canvas for the mask.

    for pt in centre_pts:
        size = abs(int(random.normal(mask_size, mask_size/5.0)))            # Randomness introduced in the mask size.
        # Scalar draw: random.random() consumes the same single value as
        # random.random(1), but it yields a float instead of a 1-element array.
        # int() on an array with ndim > 0 is deprecated since NumPy 1.25.
        ratio = 2*random.random() + 1                                       # Ratio between length and width. Sample from Unif(1,3).

        centrex = int(pt[0])
        centrey = int(pt[1])

        angle = find_angle(im_centre, (centrex, centrey))                   # Get the angle between the centre of the image and the mask centre.
        angle = int(angle + random.normal(0.0, 5.0))                        # Base the angle of rotation on the above angle.

        mask = cv2.ellipse(mask, (centrex,centrey), (size, int(size*ratio)),
                           angle, startAngle, endAngle,
                           color=1, thickness=-1)                           # Insert a filled ellipse with the parameters defined above.

    mask = np.minimum(mask, 1.0)                                            # Guarantees values <= 1 (ellipses are drawn with 1).
    mask = np.transpose(mask, [2, 0, 1])                                    # bring the 'channel' axis to the first axis.
    mask = np.expand_dims(mask, 0)                                          # Add in extra axis at axis=0 - resulting shape (1, 1, imsize[0],imsize[1])

    return mask

# test_mask = generate_ellipse_mask((224,224))
# from matplotlib import pyplot as plt
# plt.imshow(test_mask[0][0], cmap='Greys_r')
# plt.show()

### utils.utils

In [None]:
def gauss_kernel(size=21, sigma=3, inchannels=3, outchannels=3):
    """Build a normalised 2-D Gaussian blur kernel laid out for conv2d weights.

    The interval [-sigma - step/2, sigma + step/2] is split into `size` bins,
    where step = (2 * sigma + 1) / size. Each bin's standard-normal probability
    mass is computed, and the square root of the outer product of those masses
    is normalised to sum to 1.

    Returns:
        float32 array of shape (outchannels, inchannels, size, size) with the
        same 2-D kernel in every (out, in) slot.
    """
    step = (2 * sigma + 1.0) / size
    edges = np.linspace(-sigma - step / 2, sigma + step / 2, size + 1)
    # Probability mass of a standard normal inside each of the `size` bins.
    bin_mass = np.diff(st.norm.cdf(edges))
    raw = np.sqrt(np.outer(bin_mass, bin_mass))
    normalised = np.asarray(raw / raw.sum(), dtype=np.float32)
    return np.tile(normalised[np.newaxis, np.newaxis], (outchannels, inchannels, 1, 1))


def np_free_form_mask(maxVertex, maxLength, maxBrushWidth, maxAngle, h, w):
    """Draw one random free-form brush stroke into an (h, w, 1) float32 mask.

    A random number (0..maxVertex) of line segments is drawn from a random
    start point. Each segment has a random angle (0..maxAngle degrees,
    mirrored on even steps), a random length (0..maxLength) and an even brush
    width in [10, maxBrushWidth]. End points are clipped into the image. Lines
    are drawn with value 1. Small circle outlines are drawn with value 2 at the
    vertices; generate_stroke_mask clips the summed masks to 1 afterwards.

    Returns:
        float32 array of shape (h, w, 1).
    """
    mask = np.zeros((h, w, 1), np.float32)
    numVertex = np.random.randint(maxVertex + 1)
    # Plain ints are passed to OpenCV; numpy integer scalars are rejected by
    # some cv2 versions.
    startY = int(np.random.randint(h))
    startX = int(np.random.randint(w))
    brushWidth = 0
    for i in range(numVertex):
        angle = np.random.randint(maxAngle + 1)
        angle = angle / 360.0 * 2 * np.pi
        if i % 2 == 0:
            angle = 2 * np.pi - angle
        length = np.random.randint(maxLength + 1)
        brushWidth = int(np.random.randint(10, maxBrushWidth + 1) // 2 * 2)
        nextY = startY + length * np.cos(angle)
        nextX = startX + length * np.sin(angle)

        # `np.int` was removed in NumPy 1.24. The clipped values are
        # non-negative, so int() truncates exactly like the old astype(np.int).
        nextY = int(np.clip(nextY, 0, h - 1))
        nextX = int(np.clip(nextX, 0, w - 1))

        # NOTE(review): cv2 points are (x, y), but (Y, X) is passed here, as in
        # the upstream code. This is harmless for square masks.
        cv2.line(mask, (startY, startX), (nextY, nextX), 1, brushWidth)
        cv2.circle(mask, (startY, startX), brushWidth // 2, 2)

        startY, startX = nextY, nextX
    cv2.circle(mask, (startY, startX), brushWidth // 2, 2)
    return mask


def generate_rect_mask(im_size, mask_size, margin=8, rand_mask=True):
    """Create a rectangular mask of size mask_size inside an im_size canvas.

    With rand_mask the top-left corner is drawn uniformly so that the rectangle
    keeps at least `margin` pixels from the top/left edges and more than that
    from the bottom/right edges. Otherwise the rectangle is centred.

    Returns:
        (mask, rect): mask is a float32 array of shape
        (1, 1, im_size[0], im_size[1]) with 1 inside the rectangle; rect is an
        int array [[top, height, left, width]].
    """
    height, width = im_size[0], im_size[1]
    sz0, sz1 = mask_size[0], mask_size[1]
    if not rand_mask:
        top = (height - sz0) // 2
        left = (width - sz1) // 2
    else:
        top = np.random.randint(margin, height - sz0 - margin)
        left = np.random.randint(margin, width - sz1 - margin)

    canvas = np.zeros((height, width), dtype=np.float32)
    canvas[top:top + sz0, left:left + sz1] = 1
    rect = np.array([[top, sz0, left, sz1]], dtype=int)
    return canvas[np.newaxis, np.newaxis], rect


def generate_stroke_mask(im_size, parts=10, maxVertex=20, maxLength=100, maxBrushWidth=24, maxAngle=360):
    """Combine `parts` random free-form strokes into a single binary mask.

    The strokes from np_free_form_mask are summed in order and then clipped to
    a maximum of 1.

    Returns:
        float32 array of shape (1, 1, im_size[0], im_size[1]).
    """
    height, width = im_size[0], im_size[1]
    canvas = np.zeros((height, width, 1), dtype=np.float32)
    for _ in range(parts):
        canvas = canvas + np_free_form_mask(maxVertex, maxLength, maxBrushWidth, maxAngle, height, width)
    canvas = np.minimum(canvas, 1.0)
    # (H, W, 1) -> (1, 1, H, W): channel axis first, then a leading batch axis.
    return canvas.transpose(2, 0, 1)[np.newaxis]


def generate_mask(type, im_size, mask_size):
    """Create a mask of the requested kind and return it with its rectangle.

    Args:
        type: 'rect', 'ellipse', or anything else. NOTE: every other value,
            including 'saved', falls back to a random stroke mask.
        im_size: mask (rows, cols).
        mask_size: rectangle size for 'rect', mean ellipse size for 'ellipse';
            ignored for stroke masks.

    Returns:
        (mask, rect) for 'rect', as produced by generate_rect_mask;
        (mask, None) for every other type.
    """
    if type == 'ellipse':
        return generate_ellipse_mask(im_size, mask_size), None
    if type != 'rect':
        return generate_stroke_mask(im_size), None
    return generate_rect_mask(im_size, mask_size)


def getLatest(folder_path):
    """Return the most recently created path matching a glob pattern.

    Args:
        folder_path: glob pattern, e.g. os.path.join(model_dir, '*.pth').

    Returns:
        The matching path with the largest os.path.getctime value.

    Raises:
        FileNotFoundError: if the pattern matches nothing.
    """
    import glob  # Local import keeps this helper self-contained.

    files = glob.glob(folder_path)
    if not files:
        raise FileNotFoundError('No files match {}'.format(folder_path))
    # Compare numeric timestamps. The previous version sorted time.ctime()
    # strings such as 'Mon Jun  7 16:56:07 2021', which orders by weekday
    # name rather than by time.
    return max(files, key=os.path.getctime)


### model.net

In [None]:
# generative multi-column convolutional neural net
class GMCNN(BaseNet):
    """Generative multi-column CNN generator used for inpainting.

    Three encoder columns (EB1, EB2, EB3) process the same input with 7x7, 5x5
    and 3x3 kernels respectively. Each column downsamples twice with stride-2
    convolutions, applies dilated convolutions (dilation 2, 4, 8, 16) and
    upsamples back:
      * EB1 ends with PureUpsampling(scale=4) and outputs 4 * cnum channels;
      * EB2 upsamples twice by 2 with convolutions in between and outputs 2 * cnum;
      * EB3 upsamples twice by 2 with convolutions after each and outputs cnum.
    The three outputs are concatenated (7 * cnum channels) and decoded by two
    3x3 convolutions into `out_channels`. The result is clamped to [-1, 1].
    NOTE(review): the two stride-2 stages suggest input height/width should be
    multiples of 4 so the upsampled columns line up; confirm.
    """
    def __init__(self, in_channels, out_channels, cnum=32, act=F.elu, norm=F.instance_norm, using_norm=False):
        super(GMCNN, self).__init__()
        self.act = act
        self.using_norm = using_norm
        # `norm` is only kept when using_norm is exactly True; otherwise it is ignored.
        if using_norm is True:
            self.norm = norm
        else:
            self.norm = None
        ch = cnum

        # network structure
        self.EB1 = []
        self.EB2 = []
        self.EB3 = []
        self.decoding_layers = []

        # Per-layer reflection-padding amounts, aligned index-by-index with the
        # layer lists. 0 marks the PureUpsampling entries.
        self.EB1_pad_rec = []
        self.EB2_pad_rec = []
        self.EB3_pad_rec = []

        # --- Column 1: 7x7 kernels ---
        self.EB1.append(nn.Conv2d(in_channels, ch, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch, ch * 2, kernel_size=7, stride=2))
        self.EB1.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=7, stride=2))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=2))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=4))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=8))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=16))

        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))

        self.EB1.append(PureUpsampling(scale=4))

        # 3 = (7 - 1) / 2; the dilated layers use 3 * dilation (6, 12, 24, 48).
        self.EB1_pad_rec = [3, 3, 3, 3, 3, 3, 6, 12, 24, 48, 3, 3, 0]

        # --- Column 2: 5x5 kernels ---
        self.EB2.append(nn.Conv2d(in_channels, ch, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch, ch * 2, kernel_size=5, stride=2))
        self.EB2.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, stride=2))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=2))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=4))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=8))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=16))

        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))

        self.EB2.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 2, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=5, stride=1))
        self.EB2.append(PureUpsampling(scale=2))
        self.EB2_pad_rec = [2, 2, 2, 2, 2, 2, 4, 8, 16, 32, 2, 2, 0, 2, 2, 0]

        # --- Column 3: 3x3 kernels ---
        self.EB3.append(nn.Conv2d(in_channels, ch, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch, ch * 2, kernel_size=3, stride=2))
        self.EB3.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=3, stride=2))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=2))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=4))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=8))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=16))

        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))

        self.EB3.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 2, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=3, stride=1))
        self.EB3.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB3.append(nn.Conv2d(ch * 2, ch, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch, ch, kernel_size=3, stride=1))

        self.EB3_pad_rec = [1, 1, 1, 1, 1, 1, 2, 4, 8, 16, 1, 1, 0, 1, 1, 0, 1, 1]

        # Decoder input = 4*ch (EB1) + 2*ch (EB2) + ch (EB3) = 7*ch channels.
        self.decoding_layers.append(nn.Conv2d(ch * 7, ch // 2, kernel_size=3, stride=1))
        self.decoding_layers.append(nn.Conv2d(ch // 2, out_channels, kernel_size=3, stride=1))

        self.decoding_pad_rec = [1, 1]

        # Register the lists as submodules. These attribute names become the
        # state_dict keys used when loading checkpoints.
        self.EB1 = nn.ModuleList(self.EB1)
        self.EB2 = nn.ModuleList(self.EB2)
        self.EB3 = nn.ModuleList(self.EB3)
        self.decoding_layers = nn.ModuleList(self.decoding_layers)

        # padding operations
        # One ReflectionPad2d per amount 0..48 (48 is the largest value in the
        # *_pad_rec lists), indexed directly by the pad amount.
        padlen = 49
        self.pads = [0] * padlen
        for i in range(padlen):
            self.pads[i] = nn.ReflectionPad2d(i)
        self.pads = nn.ModuleList(self.pads)

    def forward(self, x):
        """Run the three columns on `x`, fuse them and decode to [-1, 1]."""
        x1, x2, x3 = x, x, x
        for i, layer in enumerate(self.EB1):
            pad_idx = self.EB1_pad_rec[i]
            x1 = layer(self.pads[pad_idx](x1))
            if self.using_norm:
                x1 = self.norm(x1)
            # Pad index 0 marks an upsampling layer, which gets no activation.
            if pad_idx != 0:
                x1 = self.act(x1)

        for i, layer in enumerate(self.EB2):
            pad_idx = self.EB2_pad_rec[i]
            x2 = layer(self.pads[pad_idx](x2))
            if self.using_norm:
                x2 = self.norm(x2)
            if pad_idx != 0:
                x2 = self.act(x2)

        for i, layer in enumerate(self.EB3):
            pad_idx = self.EB3_pad_rec[i]
            x3 = layer(self.pads[pad_idx](x3))
            if self.using_norm:
                x3 = self.norm(x3)
            if pad_idx != 0:
                x3 = self.act(x3)

        # Fuse the columns along the channel axis, then decode. The final
        # convolution has no activation; its output is clamped instead.
        x_d = torch.cat((x1, x2, x3), 1)
        x_d = self.act(self.decoding_layers[0](self.pads[self.decoding_pad_rec[0]](x_d)))
        x_d = self.decoding_layers[1](self.pads[self.decoding_pad_rec[1]](x_d))
        x_out = torch.clamp(x_d, -1, 1)
        return x_out


# return one dimensional output indicating the probability of realness or fakeness
class Discriminator(BaseNet):
    """Strided-convolution discriminator producing one logit per sample.

    Four 5x5, stride-2, padding-2 convolutions (cnum, 2*cnum, 4*cnum, 4*cnum
    output channels) are each followed by the optional `norm` and by `act`.
    A linear layer then maps the flattened features to one output.
    `fc_channels` must equal the flattened size of the last convolution's
    output. When `spectral_norm` is true every layer is wrapped in
    SpectralNorm.
    """
    def __init__(self, in_channels, cnum=32, fc_channels=8*8*32*4, act=F.elu, norm=None, spectral_norm=True):
        super(Discriminator, self).__init__()
        self.act = act
        self.norm = norm
        # Filled in by forward(): flattened features and the output logit.
        self.embedding = None
        self.logit = None

        ch = cnum
        wrap = SpectralNorm if spectral_norm else (lambda module: module)
        conv_channels = [(in_channels, ch), (ch, ch * 2), (ch * 2, ch * 4), (ch * 4, ch * 4)]
        # Each layer is built and wrapped immediately, in order, so parameter
        # initialisation happens in the same sequence as before.
        modules = [wrap(nn.Conv2d(c_in, c_out, kernel_size=5, padding=2, stride=2))
                   for c_in, c_out in conv_channels]
        modules.append(wrap(nn.Linear(fc_channels, 1)))
        self.layers = nn.ModuleList(modules)

    def forward(self, x):
        """Return the (batch, 1) logits and cache embedding/logit on self."""
        *hidden, head = self.layers
        for layer in hidden:
            x = layer(x)
            if self.norm is not None:
                x = self.norm(x)
            x = self.act(x)
        self.embedding = x.view(x.size(0), -1)
        self.logit = head(self.embedding)
        return self.logit


class GlobalLocalDiscriminator(BaseNet):
    """Pair of Discriminators: one scores the full image, one a local region.

    Both share in_channels, cnum, act, norm and spectral_norm. They differ
    only in the flattened feature size of their final linear layer
    (g_fc_channels for the global one, l_fc_channels for the local one).
    """
    def __init__(self, in_channels, cnum=32, g_fc_channels=16*16*32*4, l_fc_channels=8*8*32*4, act=F.elu, norm=None,
                 spectral_norm=True):
        super(GlobalLocalDiscriminator, self).__init__()
        self.act = act
        self.norm = norm

        shared = dict(in_channels=in_channels, cnum=cnum, act=act, norm=norm, spectral_norm=spectral_norm)
        self.global_discriminator = Discriminator(fc_channels=g_fc_channels, **shared)
        self.local_discriminator = Discriminator(fc_channels=l_fc_channels, **shared)

    def forward(self, x_g, x_l):
        """Return (global_logits, local_logits) for the full image and the local crop."""
        return self.global_discriminator(x_g), self.local_discriminator(x_l)


# from util.utils import generate_mask


class InpaintingModel_GMCNN(BaseModel):
    """GMCNN inpainting model: generator, optional discriminator, losses and training step.

    The generator `netGM` (a GMCNN placed on the GPU with `.cuda()`) takes the
    masked image concatenated with the mask and predicts 3 output channels.
    In the 'test' phase only the generator is built. Otherwise the Adam
    optimiser and the L1 reconstruction/auto-encoder losses are also created.
    Unless `opt.pretrain_network` is set, a GlobalLocalDiscriminator is added
    together with its optimiser, WGANLoss and IDMRFLoss.

    Args:
        in_channels: generator input channels (4 in this notebook: RGB + mask).
        act: activation passed to the generator and discriminator.
        norm: normalisation passed to GMCNN. GMCNN is built without
            using_norm=True here, so it ignores it.
        opt: parsed options. Fields read include g_cnum, phase, lr, lambda_adv,
            lambda_rec, lambda_ae, lambda_gp, lambda_mrf, pretrain_network,
            mask_type, img_shapes, mask_shapes, d_cnum, spectral_norm,
            batch_size and D_max_iters.
    """
    def __init__(self, in_channels, act=F.elu, norm=None, opt=None):
        super(InpaintingModel_GMCNN, self).__init__()
        self.opt = opt
        self.init(opt)

        self.confidence_mask_layer = ConfidenceDrivenMaskLayer()

        self.netGM = GMCNN(in_channels, out_channels=3, cnum=opt.g_cnum, act=act, norm=norm).cuda()
        # self.netGM = GMCNN(in_channels, out_channels=3, cnum=opt.g_cnum, act=act, norm=norm).cpu()

        init_weights(self.netGM)
        self.model_names = ['GM']
        # Inference only needs the generator.
        if self.opt.phase == 'test':
            return

        self.netD = None

        self.optimizer_G = torch.optim.Adam(self.netGM.parameters(), lr=opt.lr, betas=(0.5, 0.9))
        self.optimizer_D = None

        self.wganloss = None
        self.recloss = nn.L1Loss()
        self.aeloss = nn.L1Loss()
        self.mrfloss = None
        self.lambda_adv = opt.lambda_adv
        self.lambda_rec = opt.lambda_rec
        self.lambda_ae = opt.lambda_ae
        self.lambda_gp = opt.lambda_gp
        self.lambda_mrf = opt.lambda_mrf
        self.G_loss = None
        self.G_loss_reconstruction = None
        self.G_loss_mrf = None
        self.G_loss_adv, self.G_loss_adv_local = None, None
        self.G_loss_ae = None
        self.D_loss, self.D_loss_local = None, None
        self.GAN_loss = None

        self.gt, self.gt_local = None, None
        self.mask, self.mask_01 = None, None
        self.rect = None
        self.im_in, self.gin = None, None

        self.completed, self.completed_local = None, None
        self.completed_logit, self.completed_local_logit = None, None
        self.gt_logit, self.gt_local_logit = None, None

        self.pred = None

        if self.opt.pretrain_network is False:
            # The local discriminator sees the rectangle crop for 'rect' masks
            # and the full image otherwise. Its linear layer is sized from
            # shape // 16 to match the four stride-2 convolutions.
            if self.opt.mask_type == 'rect':
                self.netD = GlobalLocalDiscriminator(3, cnum=opt.d_cnum, act=act,
                                                     g_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4,
                                                     l_fc_channels=opt.mask_shapes[0]//16*opt.mask_shapes[1]//16*opt.d_cnum*4,
                                                     spectral_norm=self.opt.spectral_norm).cuda()
            else:
                self.netD = GlobalLocalDiscriminator(3, cnum=opt.d_cnum, act=act,
                                                     spectral_norm=self.opt.spectral_norm,
                                                     g_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4,
                                                     l_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4).cuda()
            init_weights(self.netD)
            self.optimizer_D = torch.optim.Adam(filter(lambda x: x.requires_grad, self.netD.parameters()), lr=opt.lr,
                                                betas=(0.5, 0.9))
            self.wganloss = WGANLoss()
            self.mrfloss = IDMRFLoss()

    def initVariables(self):
        """Build one mask for the batch and the masked generator input.

        Reads self.input['gt'] (presumably set by BaseModel.set_input — verify).
        """
        self.gt = self.input['gt']
        mask, rect = generate_mask(self.opt.mask_type, self.opt.img_shapes, self.opt.mask_shapes)
        # The same mask is repeated across the whole batch.
        self.mask_01 = torch.from_numpy(mask).cuda().repeat([self.opt.batch_size, 1, 1, 1])
        self.mask = self.confidence_mask_layer(self.mask_01)
        if self.opt.mask_type == 'rect':
            self.rect = [rect[0, 0], rect[0, 1], rect[0, 2], rect[0, 3]]
            self.gt_local = self.gt[:, :, self.rect[0]:self.rect[0] + self.rect[1],
                            self.rect[2]:self.rect[2] + self.rect[3]]
        else:
            self.gt_local = self.gt
        self.im_in = self.gt * (1 - self.mask_01)
        self.gin = torch.cat((self.im_in, self.mask_01), 1)

    def forward_G(self):
        """Compute the generator loss: reconstruction + AE, plus MRF/adversarial terms when not pretraining."""
        # Losses are normalised by the masked (resp. unmasked) area fraction.
        self.G_loss_reconstruction = self.recloss(self.completed * self.mask, self.gt.detach() * self.mask)
        self.G_loss_reconstruction = self.G_loss_reconstruction / torch.mean(self.mask_01)
        self.G_loss_ae = self.aeloss(self.pred * (1 - self.mask_01), self.gt.detach() * (1 - self.mask_01))
        self.G_loss_ae = self.G_loss_ae / torch.mean(1 - self.mask_01)
        self.G_loss = self.lambda_rec * self.G_loss_reconstruction + self.lambda_ae * self.G_loss_ae
        if self.opt.pretrain_network is False:
            # discriminator
            self.completed_logit, self.completed_local_logit = self.netD(self.completed, self.completed_local)
            # Images are mapped from [-1, 1] to [0, 1] for the MRF loss.
            self.G_loss_mrf = self.mrfloss((self.completed_local+1)/2.0, (self.gt_local.detach()+1)/2.0)
            self.G_loss = self.G_loss + self.lambda_mrf * self.G_loss_mrf

            self.G_loss_adv = -self.completed_logit.mean()
            self.G_loss_adv_local = -self.completed_local_logit.mean()
            self.G_loss = self.G_loss + self.lambda_adv * (self.G_loss_adv + self.G_loss_adv_local)

    def forward_D(self):
        """Compute the hinge discriminator loss over global and local outputs."""
        self.completed_logit, self.completed_local_logit = self.netD(self.completed.detach(), self.completed_local.detach())
        self.gt_logit, self.gt_local_logit = self.netD(self.gt, self.gt_local)
        # hinge loss
        self.D_loss_local = nn.ReLU()(1.0 - self.gt_local_logit).mean() + nn.ReLU()(1.0 + self.completed_local_logit).mean()
        self.D_loss = nn.ReLU()(1.0 - self.gt_logit).mean() + nn.ReLU()(1.0 + self.completed_logit).mean()
        self.D_loss = self.D_loss + self.D_loss_local

    def backward_G(self):
        self.G_loss.backward()

    def backward_D(self):
        self.D_loss.backward(retain_graph=True)

    def optimize_parameters(self):
        """One training step: D_max_iters discriminator updates (unless pretraining), then one generator update."""
        self.initVariables()

        self.pred = self.netGM(self.gin)
        # Keep known pixels from the ground truth and take predictions only inside the mask.
        self.completed = self.pred * self.mask_01 + self.gt * (1 - self.mask_01)
        if self.opt.mask_type == 'rect':
            self.completed_local = self.completed[:, :, self.rect[0]:self.rect[0] + self.rect[1],
                                   self.rect[2]:self.rect[2] + self.rect[3]]
        else:
            self.completed_local = self.completed

        if self.opt.pretrain_network is False:
            for i in range(self.opt.D_max_iters):
                self.optimizer_D.zero_grad()
                self.optimizer_G.zero_grad()
                self.forward_D()
                self.backward_D()
                self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.forward_G()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_losses(self):
        """Return the latest loss values as Python floats keyed by name."""
        l = {'G_loss': self.G_loss.item(), 'G_loss_rec': self.G_loss_reconstruction.item(),
             'G_loss_ae': self.G_loss_ae.item()}
        if self.opt.pretrain_network is False:
            l.update({'G_loss_adv': self.G_loss_adv.item(),
                      'G_loss_adv_local': self.G_loss_adv_local.item(),
                      'D_loss': self.D_loss.item(),
                      'G_loss_mrf': self.G_loss_mrf.item()})
        return l

    def get_current_visuals(self):
        return {'input': self.im_in.cpu().detach().numpy(), 'gt': self.gt.cpu().detach().numpy(),
                'completed': self.completed.cpu().detach().numpy()}

    def get_current_visuals_tensor(self):
        return {'input': self.im_in.cpu().detach(), 'gt': self.gt.cpu().detach(),
                'completed': self.completed.cpu().detach()}

    def evaluate(self, im_in, mask):
        """Inpaint one batch at inference time.

        Args:
            im_in: image array with values in 0..255, channels on axis 1
                (e.g. (1, 3, H, W) as built by inpaint_ims).
            mask: array broadcastable to im_in with 1 where pixels are to be
                inpainted (bool or float).

        Returns:
            uint8 array of the same shape as im_in. Unmasked pixels come from
            the input; masked pixels come from the generator.
        """
        # Map 0..255 to [-1, 1], the generator's value range.
        im_in = torch.from_numpy(im_in).type(torch.FloatTensor).cuda() / 127.5 - 1
        mask = torch.from_numpy(mask).type(torch.FloatTensor).cuda()
        # Inference only: no_grad avoids building an autograd graph and keeping
        # its intermediate activations in GPU memory for every image.
        with torch.no_grad():
            im_in = im_in * (1-mask)
            xin = torch.cat((im_in, mask), 1)
            ret = self.netGM(xin) * mask + im_in * (1-mask)
        ret = (ret.cpu().detach().numpy() + 1) * 127.5
        return ret.astype(np.uint8)


### test.py
Based on `test.py` from inpainting_gmcnn, but significantly adjusted.

##### Set up the inpainting parameters for each run. The first run inpaints the coloured patches using the saved masks.

In [None]:
# All four runs load the same trained GMCNN checkpoint folder.
trained_model_dir = os.path.join(dir_path, 'models', 'inpainting_gmcnn',
                                 '20210607-165607_GMCNN_expanded_isic_no_patch_fifth_run_b8_s224x224_gc32_dc64_randmask-ellipse')

# Arguments for inpainting the coloured patches in the test set, using the saved masks.
args_patches = {'--dataset': 'inpaint_coloured_patches',
                '--test_num': '-1',
                '--data_file': os.path.join(dir_path, 'models', 'test_files.txt'),
                '--mask_type': 'saved',
                '--random_mask': '0',
                '--load_model_dir': trained_model_dir,
}

# Arguments for inpainting test images without patches, using random ellipse masks.
args_no_patches = {'--dataset': 'inpaint_no_patches',
                   '--test_num': '-1',
                   '--data_file': os.path.join(dir_path, 'models', 'test_files.txt'),
                   '--mask_type': 'ellipse',
                   '--random_mask': '1',
                   '--load_model_dir': trained_model_dir,
}

# Arguments for inpainting the patches in the training set.
args_train_patches = {'--dataset': 'inpaint_train_patches',
                      '--test_num': '-1',
                      '--data_file': os.path.join(dir_path, 'models', 'train_files.txt'),
                      '--mask_type': 'saved',
                      '--random_mask': '0',
                      '--load_model_dir': trained_model_dir,
}

# Arguments for inpainting the relevant patch sections for the malignant experiment.
args_malignant = {'--dataset': 'inpaint_malignant',
                  '--test_num': '-1',
                  '--data_file': os.path.join(dir_path, 'data', 'malignant-patches', 'manually-adjusted'),
                  '--mask_type': 'saved',
                  '--mask_dir': os.path.join(dir_path, 'data', 'masks', 'malignant-patches'),
                  '--random_mask': '0',
                  '--load_model_dir': trained_model_dir,
}

config_patches    = TestOptions().parse(args=args_patches)
config_no_patches = TestOptions().parse(args=args_no_patches)
config_train_patches = TestOptions().parse(args=args_train_patches)
config_malignant = TestOptions().parse(args=args_malignant)

##### Set up the trained inpainting model.

In [None]:
# Choose the GPU with the most free memory. Each line from the shell pipeline
# is a "Free" row of `nvidia-smi -q -d Memory`; its third whitespace-separated
# token is used as the free-memory figure (presumably MiB — confirm against
# the nvidia-smi output).
free_memory_lines = subprocess.Popen(
    "nvidia-smi -q -d Memory | grep -A4 GPU | grep Free", shell=True, stdout=subprocess.PIPE).stdout.readlines()
free_memory = [int(line.split()[2]) for line in free_memory_lines]
os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax(free_memory))

# Build the inpainting model (4 input channels: masked RGB image + mask). If a
# checkpoint folder is configured, load the .pth file that getLatest selects.
print('configuring model..')
ourModel = InpaintingModel_GMCNN(in_channels=4, opt=config_patches)
ourModel.print_networks()
checkpoint_dir = config_patches.load_model_dir
if checkpoint_dir != '':
    print('Loading trained model from {}'.format(checkpoint_dir))
    ourModel.load_networks(getLatest(os.path.join(checkpoint_dir, '*.pth')))
    print('Loading done.')

##### Extract the relevant file paths.

In [None]:
# Test images: either a text file listing one path per line, or a folder of .jpg images.
if os.path.isfile(config_patches.dataset_path):
    with open(config_patches.dataset_path, 'rt') as path_list:
        pathfile = path_list.read().splitlines()
elif os.path.isdir(config_patches.dataset_path):
    pathfile = glob.glob(os.path.join(config_patches.dataset_path, '*.jpg'))   # Changed from png.
else:
    print('Invalid testing data file/folder path.')
    # Same effect as exit(1) in a script, without relying on the `exit`
    # helper that the site module only provides for interactive use.
    raise SystemExit(1)

mask_files = os.listdir(config_patches.mask_dir)        # Get list of all the mask image names.
mask_names = set(mask_files)                             # Set for O(1) membership tests below.

# Separate images without patches and with patches (i.e. with corresponding mask)
patch_ind = [os.path.basename(file) in mask_names for file in pathfile]
path_ims_patches    = [file for i, file in enumerate(pathfile) if patch_ind[i]]
path_ims_no_patches = [file for i, file in enumerate(pathfile) if not patch_ind[i]]

# Extract the paths for the training images with patches.
with open(config_train_patches.dataset_path, 'rt') as train_list:
    pathfile_train = train_list.read().splitlines()

path_train_patches = [file for file in pathfile_train if os.path.basename(file) in mask_names]

# Extract the paths for the relevant malignant images.
pathfile_malignant = glob.glob(os.path.join(config_malignant.dataset_path, '*.jpg'))

##### Function for looping through the images, inpainting & saving the results.

In [None]:
def inpaint_ims(config, pathfile, model=None):
    """Inpaint each listed image and write the results under config.saving_path.

    For each of the first config.test_num images (all of them when it is -1):
      * get a mask. With config.mask_type == 'saved' it is read in grayscale
        from config.mask_dir under the image's file name, and pixels > 100 are
        masked; images without a readable mask are skipped. Otherwise
        generate_mask creates one.
      * centre-crop the image to config.img_shapes. If the image is smaller,
        crop its central square and resize it instead.
      * run model.evaluate and write the result to <saving_path>/inpainted/.
      * write original | masked input | result side by side to
        <saving_path>/combined/ when the three shapes match.
    Unreadable images are skipped.

    Args:
        config: parsed test options (from TestOptions().parse).
        pathfile: list of image paths.
        model: object with an evaluate(image, mask) method. Defaults to the
            notebook-level `ourModel`.
    """
    if model is None:
        model = ourModel

    print("-" * 30, "\n Inpainting on {}".format(config.dataset))

    if config.random_mask:
        np.random.seed(config.seed)

    total_number = len(pathfile)
    test_num = total_number if config.test_num == -1 else min(total_number, config.test_num)
    print('The total number of testing images is {}, and we take {} for test.'.format(total_number, test_num))

    for i in range(test_num):
        path = pathfile[i]
        filename = os.path.basename(path)                                   # Extract the filename from the full path.

        # The mask is obtained before the image is read, so random masks
        # consume the RNG in the same order as before.
        if config.mask_type == 'saved':                                     # Use a saved mask for this project, rather than randomly generating one.
            mask_img = cv2.imread(os.path.join(config.mask_dir, filename), 0)   # Read the mask in grayscale.
            if mask_img is None:                                            # Missing/unreadable mask: skip instead of crashing on None > 100.
                print('No readable mask for {}, skipping.'.format(filename))
                continue
            mask = (mask_img > 100)                                         # Threshold the mask for intensities from 100-255.
            # NOTE(review): assumes the saved mask already has the size of the cropped image (config.img_shapes).
            # Add two extra axes at the start of the array to match the expected shape for the model: (1, 1, H, W)
            mask = np.expand_dims(mask, axis=(0, 1))
        else:
            mask, _ = generate_mask(config.mask_type, config.img_shapes, config.mask_shapes)

        image = cv2.imread(path)
        if image is None:                                                   # Added because some of the images in our directory may be empty.
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]

        if h >= config.img_shapes[0] and w >= config.img_shapes[1]:
            h_start = (h - config.img_shapes[0]) // 2
            w_start = (w - config.img_shapes[1]) // 2
            image = image[h_start: h_start + config.img_shapes[0], w_start: w_start + config.img_shapes[1], :]
        else:
            t = min(h, w)
            image = image[(h - t) // 2:(h - t) // 2 + t, (w - t) // 2:(w - t) // 2 + t, :]
            image = cv2.resize(image, (config.img_shapes[1], config.img_shapes[0]))

        image = np.transpose(image, [2, 0, 1])                              # HWC -> CHW
        image = np.expand_dims(image, axis=0)                               # add batch axis -> (1, C, H, W)
        # Visualisation of the model input: masked pixels shown as white.
        # Values stay within 0..255, so cast back to uint8. Otherwise a bool
        # or float mask would promote this to int64/float and make the
        # combined image below non-uint8.
        image_vis = (image * (1 - mask) + 255 * mask).astype(np.uint8)
        image_vis = np.transpose(image_vis[0][::-1, :, :], [1, 2, 0])       # RGB -> BGR, CHW -> HWC for cv2
        # cv2.imwrite(os.path.join(config.saving_path, 'input_{}'.format(filename)), image_vis)

        # GMCNN downsamples twice with stride 2, so crop to multiples of 4.
        h, w = image.shape[2:]
        grid = 4
        image = image[:, :, :h // grid * grid, :w // grid * grid]
        mask = mask[:, :, :h // grid * grid, :w // grid * grid]
        result = model.evaluate(image, mask)
        result = np.transpose(result[0][::-1, :, :], [1, 2, 0])
        cv2.imwrite(os.path.join(config.saving_path, "inpainted", filename), result)   # The extension '.jpg' is already included in the filename variable.

        image = np.transpose(image[0][::-1, :, :], [1, 2, 0])
        if image.shape == result.shape and image.shape == image_vis.shape:
            im_combined = np.concatenate((image, image_vis, result), axis=1)    # Combine the original, masked & output images and write to file.
            cv2.imwrite(os.path.join(config.saving_path, "combined", filename), im_combined)
        else:
            print('Mismatched shapes, images not combined. \n\toriginal: {}, input: {}, result: {}'.format(image.shape, image_vis.shape, result.shape))

        print(' > {} / {}'.format(i + 1, test_num))
    print('done.')

##### Run the inpainting for all four sets.

In [None]:
# Run each configuration in turn:
#   1. test images with coloured patches (saved masks),
#   2. test images without patches (random ellipse masks),
#   3. training images with patches,
#   4. the relevant patch sections for the malignant experiment.
inpainting_runs = (
    (config_patches, path_ims_patches),
    (config_no_patches, path_ims_no_patches),
    (config_train_patches, path_train_patches),
    (config_malignant, pathfile_malignant),
)
for run_config, run_paths in inpainting_runs:
    inpaint_ims(run_config, run_paths)