# SRnDeblur_joint 1001

## Network

![title](images/SRnDeblurN.png)

In [1]:
import torch
import argparse
import os
import random
import math
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TTF
import torchvision.models as models
from torch.backends import cudnn
from torch import optim
from torch.autograd import Variable
from torch.utils import data
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import make_grid
from PIL import Image
from tensorboardX import SummaryWriter

## Generator

![title](images/generator.PNG)

In [2]:
class Generator(nn.Module):
    """Joint super-resolution / deblurring generator.

    The input is first upscaled 8x by three conv + pixel-shuffle stages.  A
    U-Net style encoder then runs over the upscaled image and feeds two
    decoders: a shallow one that starts from the fourth encoder stage and
    produces ``outLR``, and a deep one that starts from the bottleneck and
    produces ``outHR``.  Gaussian noise (mean 1.0, std 0.1) is added to the
    network input, to every encoder output and to every skip-concatenated
    decoder tensor; ``forward`` does this regardless of the training mode.

    Shape comments assume a 3x32x32 input.

    Args:
        batch_size: a value of 1 selects instance normalization, any other
            value selects batch normalization.
    """

    def __init__(self, batch_size):
        super(Generator, self).__init__()

        # Instance normalization for single-sample batches, batch
        # normalization otherwise.
        norm_layer = nn.InstanceNorm2d if batch_size == 1 else nn.BatchNorm2d

        # ---- 8x upscaling head: [3x32x32] -> [3x256x256] ----
        # 9x9 conv to 64 channels, then three times
        # (3x3 conv to 256 channels -> PixelShuffle(2) back to 64 channels
        # at twice the resolution -> LeakyReLU), then a 3x3 conv to RGB.
        upscale = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for _ in range(3):
            upscale += [
                nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
                nn.PixelShuffle(2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        upscale.append(
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
        )
        self.upscale8 = nn.Sequential(*upscale)

        # ---- encoder ----
        # nn.Conv2d(input channels, channels produced by the convolution,
        #           kernel size, stride (default 1), padding (default 0))
        # [3x256x256] -> [64x128x128]
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
        self.conv2 = self._encoder_stage(64, 128, norm_layer)    # -> [128x64x64]
        self.conv3 = self._encoder_stage(128, 256, norm_layer)   # -> [256x32x32]
        self.conv4 = self._encoder_stage(256, 512, norm_layer)   # -> [512x16x16]
        self.conv5 = self._encoder_stage(512, 512, norm_layer)   # -> [512x8x8]
        self.conv6 = self._encoder_stage(512, 512, norm_layer)   # -> [512x4x4]
        self.conv7 = self._encoder_stage(512, 512, norm_layer)   # -> [512x2x2]
        self.conv8 = self._encoder_stage(512, 512, norm_layer)   # -> [512x1x1]

        # ---- decoders ----
        # Attributes are assigned in the same order as before so that the
        # state_dict layout and the parameter-initialisation order stay put.
        # Stages suffixed "_0" belong to the shallow (LR) decoder.
        # [512x1x1] -> [512x2x2]
        self.deconv8 = self._decoder_stage(512, 512, norm_layer, dropout=True)
        # [(512+512)x2x2] -> [512x4x4]
        self.deconv7 = self._decoder_stage(512 * 2, 512, norm_layer, dropout=True)
        # [(512+512)x4x4] -> [512x8x8]
        self.deconv6 = self._decoder_stage(512 * 2, 512, norm_layer, dropout=True)
        # [(512+512)x8x8] -> [512x16x16]
        self.deconv5 = self._decoder_stage(512 * 2, 512, norm_layer)
        # [(512+512)x16x16] -> [256x32x32]
        self.deconv4 = self._decoder_stage(512 * 2, 256, norm_layer)
        # [512x16x16] -> [256x32x32]  (fed directly by the 4th encoder stage)
        self.deconv4_0 = self._decoder_stage(512 * 1, 256, norm_layer)
        # [(256+256)x32x32] -> [128x64x64]
        self.deconv3 = self._decoder_stage(256 * 2, 128, norm_layer)
        self.deconv3_0 = self._decoder_stage(256 * 2, 128, norm_layer)
        # [(128+128)x64x64] -> [64x128x128]
        self.deconv2 = self._decoder_stage(128 * 2, 64, norm_layer)
        self.deconv2_0 = self._decoder_stage(128 * 2, 64, norm_layer)
        # [(64+64)x128x128] -> [3x256x256]
        self.deconv1 = self._output_stage()
        self.deconv1_0 = self._output_stage()

    @staticmethod
    def _encoder_stage(in_ch, out_ch, norm_layer):
        """LeakyReLU(0.2) -> 4x4 stride-2 conv -> normalization."""
        return nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_ch, out_ch, 4, 2, 1),
            norm_layer(out_ch),
        )

    @staticmethod
    def _decoder_stage(in_ch, out_ch, norm_layer, dropout=False):
        """ReLU -> 4x4 stride-2 transposed conv -> normalization [-> Dropout(0.5)]."""
        layers = [
            nn.ReLU(),
            nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
            norm_layer(out_ch),
        ]
        if dropout:
            layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    @staticmethod
    def _output_stage():
        """ReLU -> 4x4 stride-2 transposed conv to 3 channels -> Tanh."""
        return nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(64 * 2, 3, 4, 2, 1),
            nn.Tanh(),
        )

    @staticmethod
    def _jitter(t):
        """Return ``t`` plus fresh Gaussian noise (mean 1.0, std 0.1) as a new tensor."""
        return t + torch.empty_like(t).normal_(mean=1.0, std=0.1)

    def forward(self, x):
        """Run the generator.

        Args:
            x: input batch in NCHW layout.

        Returns:
            Tuple ``(upx, outLR, outHR)``: the output of the upscaling head
            and the outputs of the shallow and deep decoders.
        """
        jitter = self._jitter

        upx = self.upscale8(jitter(x))

        # Encoder.  The clean outputs c1..c7 are kept for the skip
        # connections; each following stage sees a noisy copy.
        c1 = self.conv1(upx)
        c2 = self.conv2(jitter(c1))
        c3 = self.conv3(jitter(c2))
        c4 = self.conv4(jitter(c3))
        # The noisy copy of c4 is consumed twice: by conv5 and by the shallow
        # decoder below.
        c4_noisy = jitter(c4)
        c5 = self.conv5(c4_noisy)
        c6 = self.conv6(jitter(c5))
        c7 = self.conv7(jitter(c6))
        c8 = self.conv8(jitter(c7))
        bottleneck = jitter(c8)

        # Shallow decoder (LR output).  It runs before the deep decoder so the
        # order of random draws matches the original implementation.
        lr = torch.cat((c3, self.deconv4_0(c4_noisy)), dim=1)
        lr = torch.cat((c2, self.deconv3_0(jitter(lr))), dim=1)
        lr_up = self.deconv2_0(jitter(lr))
        lr = torch.cat((c1, lr_up), dim=1)
        outLR = self.deconv1_0(jitter(lr))

        # Deep decoder (HR output): upsample, concatenate the matching
        # encoder output, add noise, repeat.
        hr = bottleneck
        skip_and_stage = (
            (c7, self.deconv8),
            (c6, self.deconv7),
            (c5, self.deconv6),
            (c4, self.deconv5),
            (c3, self.deconv4),
            (c2, self.deconv3),
        )
        for skip, stage in skip_and_stage:
            hr = jitter(torch.cat((skip, stage(hr)), dim=1))

        # The shallow decoder's 128x128 features are added in before the last
        # skip connection.
        hr = torch.add(self.deconv2(hr), lr_up)
        hr = torch.cat((c1, hr), dim=1)
        outHR = self.deconv1(jitter(hr))

        return upx, outLR, outHR

## Discriminator

![title](images/discriminator.PNG)

In [3]:
class Discriminator(nn.Module):
    """Fully convolutional PatchGAN-style discriminator.

    Three stride-2 4x4 convolutions followed by two stride-1 4x4
    convolutions; the last one produces a single-channel score map and no
    sigmoid is applied.  For a 256x256 input the spatial sizes are
    128 -> 64 -> 32 -> 31 -> 30.

    Args:
        inch: number of channels of the input tensor.
        batch_size: a value of 1 selects instance normalization, any other
            value selects batch normalization.
    """

    def __init__(self, inch, batch_size):
        super(Discriminator, self).__init__()

        norm_layer = nn.InstanceNorm2d if batch_size == 1 else nn.BatchNorm2d

        # [inch x256x256] -> [64x128x128]; the first conv has no normalization.
        layers = [nn.Conv2d(inch, 64, 4, 2, 1)]

        # (in channels, out channels, stride) of the normalized stages:
        # -> [128x64x64] -> [256x32x32] -> [512x31x31]
        for in_ch, out_ch, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(in_ch, out_ch, 4, stride, 1),
                norm_layer(out_ch),
            ]

        # -> [1x30x30] patch score map
        layers += [
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 1),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the patch score map for ``x``."""
        return self.main(x)

## VGG for loss

In [4]:
# class Vgg19(torch.nn.Module):
#     def __init__(self, requires_grad=False):
#         super(Vgg19, self).__init__()
#         vgg_pretrained_features = models.vgg19(pretrained=True).features
#         self.slice1 = torch.nn.Sequential()
#         self.slice2 = torch.nn.Sequential()
#         self.slice3 = torch.nn.Sequential()
#         self.slice4 = torch.nn.Sequential()
#         self.slice5 = torch.nn.Sequential()
#         for x in range(2):
#             self.slice1.add_module(str(x), vgg_pretrained_features[x])
#         for x in range(2, 7):
#             self.slice2.add_module(str(x), vgg_pretrained_features[x])
#         for x in range(7, 12):
#             self.slice3.add_module(str(x), vgg_pretrained_features[x])
#         for x in range(12, 21):
#             self.slice4.add_module(str(x), vgg_pretrained_features[x])
#         for x in range(21, 30):
#             self.slice5.add_module(str(x), vgg_pretrained_features[x])
#         if not requires_grad:
#             for param in self.parameters():
#                 param.requires_grad = False

#     def forward(self, X):
#         h_relu1 = self.slice1(X)
#         h_relu2 = self.slice2(h_relu1)
#         h_relu3 = self.slice3(h_relu2)
#         h_relu4 = self.slice4(h_relu3)
#         h_relu5 = self.slice5(h_relu4)
#         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
#         return out
    
# class VGGLoss(nn.Module):
#     def __init__(self):
#         super(VGGLoss, self).__init__()
#         self.vgg = Vgg19().cuda()
#         self.criterion = nn.L1Loss()
#         self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

#     def forward(self, x, y):
#         x_vgg, y_vgg = self.vgg(x), self.vgg(y)
#         loss = 0
#         for i in range(len(x_vgg)):
#             loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
#         return loss

# VGG16-face
class VGG_16(torch.nn.Module):
    """VGG-16 classifier with a 2622-way output layer.

    Thirteen 3x3 convolutions arranged in five blocks (2, 2, 3, 3, 3 layers,
    each block followed by 2x2 max pooling) and three fully connected layers
    ``fc6``..``fc8``.
    """

    def __init__(self):
        super().__init__()
        # Number of convolutions in each of the five blocks.
        self.block_size = [2, 2, 3, 3, 3]
        block_channels = (64, 128, 256, 512, 512)

        # Registers conv_1_1 .. conv_5_3 in network order.
        in_ch = 3
        for block, (n_convs, out_ch) in enumerate(zip(self.block_size, block_channels), start=1):
            for idx in range(1, n_convs + 1):
                setattr(self, "conv_%d_%d" % (block, idx),
                        nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1))
                in_ch = out_ch

        # 512 channels at 7x7 after five poolings of a 224x224 image.
        self.fc6 = nn.Linear(512 * 7 * 7, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, 2622)

    def load_weights(self, path="pretrained/VGG_FACE.t7"):
        """Copy weights from a Torch7 model file into this network.

        Layers of the loaded model that carry a weight are matched, in order,
        to ``conv_1_1`` .. ``conv_5_3`` and then to ``fc6`` .. ``fc8``.

        Args:
            path: location of the Torch7 ``.t7`` file.

        Raises:
            ImportError: if the ``torchfile`` package is not installed.
        """
        # Imported here because the module was never imported at file level
        # (the call below used to fail with NameError) and only this method
        # needs it.
        import torchfile

        model = torchfile.load(path)
        counter = 1
        block = 1
        for layer in model.modules:
            if layer.weight is None:
                continue  # activation / pooling / dropout entries
            if block <= 5:
                target = getattr(self, "conv_%d_%d" % (block, counter))
                counter += 1
                if counter > self.block_size[block - 1]:
                    counter = 1
                    block += 1
            else:
                # After the fifth conv block the counter reads 6, 7, 8,
                # which matches the names fc6, fc7, fc8.
                target = getattr(self, "fc%d" % (block))
                block += 1
            target.weight.data[...] = torch.tensor(layer.weight).view_as(target.weight)[...]
            target.bias.data[...] = torch.tensor(layer.bias).view_as(target.bias)[...]

    def forward(self, x):
        """Compute class logits.

        Args:
            x: image batch in NCHW layout; it is resized to 224x224 with
                bilinear interpolation before entering the network.

        Returns:
            The output of ``fc8`` (one row of logits per sample).
        """
        # F.upsample is deprecated; interpolate with align_corners=False is
        # what it resolved to for bilinear mode.
        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)

        conv_blocks = (
            (self.conv_1_1, self.conv_1_2),
            (self.conv_2_1, self.conv_2_2),
            (self.conv_3_1, self.conv_3_2, self.conv_3_3),
            (self.conv_4_1, self.conv_4_2, self.conv_4_3),
            (self.conv_5_1, self.conv_5_2, self.conv_5_3),
        )
        for convs in conv_blocks:
            for conv in convs:
                x = F.relu(conv(x))
            x = F.max_pool2d(x, 2, 2)

        x = x.view(x.size(0), -1)
        x = F.relu(self.fc6(x))
        x = F.dropout(x, 0.5, self.training)
        x = F.relu(self.fc7(x))
        x = F.dropout(x, 0.5, self.training)
        return self.fc8(x)
    
class VGGLoss(nn.Module):
    """L1 distance between VGG_16 outputs of two image batches.

    NOTE(review): the constructor does not call ``VGG_16.load_weights``, so
    the feature extractor keeps its random initialisation unless weights are
    loaded into ``self.vgg`` afterwards — confirm this is intended.
    """

    def __init__(self):
        super(VGGLoss, self).__init__()
        vgg = VGG_16().double()
        # Only move to the GPU when one exists, like the other helpers in
        # this file.
        if torch.cuda.is_available():
            vgg = vgg.cuda()
        # The extractor is a fixed measuring device: no parameter gradients,
        # and eval mode so its dropout layers do not randomise the features.
        for param in vgg.parameters():
            param.requires_grad = False
        vgg.eval()
        self.vgg = vgg

        self.criterion = nn.L1Loss()
        # Per-feature weights, used when the extractor returns a list of
        # feature maps (shallow to deep).
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    def forward(self, x, y):
        """Return the weighted L1 distance between features of ``x`` and ``y``.

        Args:
            x: batch the gradient should flow into.
            y: reference batch; its features are treated as constants.

        Returns:
            A scalar tensor.

        Raises:
            ValueError: if the extractor returns more feature maps than there
                are weights.
        """
        # The extractor is double precision; cast the inputs so the
        # convolutions do not fail on a dtype mismatch.
        ref = next(self.vgg.parameters(), None)
        if ref is not None:
            x = x.to(ref.dtype)
            y = y.to(ref.dtype)

        x_vgg = self.vgg(x)
        with torch.no_grad():
            y_vgg = self.vgg(y)

        if torch.is_tensor(x_vgg):
            # VGG_16.forward returns one tensor.  Iterating over it would
            # walk the batch dimension and index self.weights per sample
            # (IndexError for batches larger than 5), so compare it as a
            # single feature with the last (deepest) weight.
            return self.weights[-1] * self.criterion(x_vgg, y_vgg.detach())

        if len(x_vgg) > len(self.weights):
            raise ValueError('got %d feature maps but only %d weights'
                             % (len(x_vgg), len(self.weights)))
        loss = 0
        for weight, fx, fy in zip(self.weights, x_vgg, y_vgg):
            loss = loss + weight * self.criterion(fx, fy.detach())
        return loss

####################################################################################
#     model = VGG_16().double()
#     model.load_weights()
#     im = cv2.imread("images/ak.png")
#     im = torch.Tensor(im).permute(2, 0, 1).view(1, 3, 224, 224).double()
#     import numpy as np

#     model.eval()
#     im -= torch.Tensor(np.array([129.1863, 104.7624, 93.5940])).double().view(1, 3, 1, 1)
#     preds = F.softmax(model(im), dim=1)
#     values, indices = preds.max(-1)
    

## Helper Functions

In [5]:
##### Helper Functions for Data Loading & Pre-processing
class ImageFolder(data.Dataset):
    """Dataset of aligned image tuples stored under ``opt.dataroot``.

    Every file in ``<dataroot>/trainx8`` defines one sample; its companions
    are looked up by file stem in sibling folders:

        A : trainx8/<stem>.*   input, cropped to 32x32 (blurred + low-res)
        B : wblur/<stem>.png   cropped to 256x256
        C : GT/<stem>.jpg      ground truth, cropped to 256x256
        D : fmask/<stem>.png   mask, cropped to 256x256 (not normalized)
        E : A resized to 256x256 with bicubic interpolation

    A, B, C and E are normalized to [-1, 1]; D stays in [0, 1].
    """

    def __init__(self, opt):
        """Index the samples.

        Args:
            opt: namespace providing ``dataroot``, ``no_resize_or_crop`` and
                ``no_flip``.  The two flags are stored but not consulted by
                the methods of this class.
        """
        self.root = opt.dataroot
        self.no_resize_or_crop = opt.no_resize_or_crop
        self.no_flip = opt.no_flip
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])
        self.transformM = transforms.Compose([transforms.ToTensor()])

        self.dir_A = os.path.join(opt.dataroot, 'trainx8')
        # Sorted so that indices are reproducible across runs and machines;
        # hidden entries and sub-directories (e.g. .DS_Store,
        # .ipynb_checkpoints) are skipped because they cannot be opened as
        # images.
        self.Aimg_paths = [
            os.path.join(self.dir_A, name)
            for name in sorted(os.listdir(self.dir_A))
            if not name.startswith('.') and os.path.isfile(os.path.join(self.dir_A, name))
        ]

    def _companion_path(self, A_path, folder, extension):
        """Return ``<root>/<folder>/<stem of A_path><extension>``.

        Uses os.path instead of string slicing, so extensions of any length
        and roots that happen to contain 'trainx8' are handled.
        """
        stem = os.path.splitext(os.path.basename(A_path))[0]
        return os.path.join(self.root, folder, stem + extension)

    def __getitem__(self, index):
        """Load sample ``index`` and return a dict of tensors keyed 'A'..'E'."""
        A_path = self.Aimg_paths[index]
        B_path = self._companion_path(A_path, 'wblur', '.png')
        C_path = self._companion_path(A_path, 'GT', '.jpg')
        D_path = self._companion_path(A_path, 'fmask', '.png')

        # convert() returns a fully loaded copy, so the files can be closed
        # right away.
        with Image.open(A_path) as img:
            A = img.convert('RGB')
        with Image.open(B_path) as img:
            B = img.convert('RGB')
        with Image.open(C_path) as img:
            C = img.convert('RGB')
        # Bicubic upsampling of the raw input, produced for side-by-side
        # comparison images.
        E = A.resize((256,256),Image.BICUBIC)

        A = self.transform(A)
        B = self.transform(B)
        C = self.transform(C)
        E = self.transform(E)
        # NOTE(review): the mask keeps its stored image mode, so its channel
        # count depends on the file — confirm it broadcasts against the
        # 3-channel images it is multiplied with.
        with Image.open(D_path) as img:
            D = self.transformM(img)

        # Top-left crops to the sizes the networks expect.
        A = A[:,:32,:32]
        B = B[:,:256,:256]
        C = C[:,:256,:256]
        D = D[:,:256,:256]
        E = E[:,:256,:256]

        return {'A':A, 'B':B, 'C':C, 'D':D, 'E':E}

    def __len__(self):
        """Number of samples (files found in ``trainx8``)."""
        return len(self.Aimg_paths)

##### Helper Function for GPU Training
def to_variable(x):
    """Wrap ``x`` in a ``Variable``, moving it to the GPU first when CUDA is available."""
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed)

##### Helper Function for Math
def denorm(x):
    """Map values from [-1, 1] to [0, 1] and clamp anything outside that range."""
    return ((x + 1) / 2).clamp(0, 1)

##### Helper Functions for GAN Loss (4D Loss Comparison)
def GAN_Loss(input, target, criterion):
    """Compare a discriminator score map against an all-real or all-fake label map.

    Args:
        input: discriminator output of any shape.
        target: truthy for "real" labels (all ones), falsy for "fake" labels
            (all zeros).
        criterion: callable ``criterion(prediction, labels)`` returning the
            loss.

    Returns:
        Whatever ``criterion`` returns.
    """
    # Build the labels next to the prediction (same shape, dtype and device)
    # instead of allocating a CPU FloatTensor and moving it according to
    # global CUDA availability, which mismatches whenever the prediction
    # lives on a different device.  The labels never require gradients.
    labels = torch.full_like(input, 1.0 if target else 0.0)
    return criterion(input, labels)


## Training arguments

In [6]:
# GPU selection: order devices by PCI bus id and expose GPUs 0 and 1.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"
parser = argparse.ArgumentParser(description='Implementation of SRnDeblur')

# Task
# The help text names the subfolders that ImageFolder actually reads.
parser.add_argument('--dataroot', required=True, help='path to the dataset root (should have subfolders trainx8, wblur, GT and fmask)')
parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')

# Pre-processing
# This is a boolean flag, so the help text describes a switch rather than a
# list of modes.
parser.add_argument('--no_resize_or_crop', action='store_true', help='if specified, do not resize or crop the images at load time')
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')

# Hyper-parameters
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
parser.add_argument('--lr', type=float, default=0.0002)
parser.add_argument('--beta1', type=float, default=0.5)  # momentum1 in Adam
parser.add_argument('--beta2', type=float, default=0.999)  # momentum2 in Adam
parser.add_argument('--lambda_A', type=float, default=10.0)  # weight of the L1 terms in the generator loss

# misc
parser.add_argument('--model_path', type=str, default='./SRnDeblur_Neverywhere/models')  # generator checkpoints
parser.add_argument('--sample_path', type=str, default='./SRnDeblur_Neverywhere/results')  # sample images
parser.add_argument('--log_step', type=int, default=10)  # print losses every N batches
parser.add_argument('--sample_step', type=int, default=100)  # save sample images every N batches
parser.add_argument('--num_workers', type=int, default=2)  # DataLoader worker processes

_StoreAction(option_strings=['--num_workers'], dest='num_workers', nargs=None, const=None, default=2, type=<class 'int'>, choices=None, help=None, metavar=None)

## Main

In [7]:
# Pre-settings
# Let cuDNN pick the fastest convolution algorithms for the input sizes used.
cudnn.benchmark = True

# Arguments are given as an explicit list because this runs inside a notebook.
# (The former `global args` statement was a no-op at module level.)
args = parser.parse_args(['--dataroot','./datasets/face_SRnDeblur','--which_direction','AtoB',
                          '--num_epochs','1001','--batchSize','64','--no_resize_or_crop',
                          '--log_step','100'])
print(args)

dataset = ImageFolder(args)

data_loader = data.DataLoader(dataset=dataset, batch_size=args.batchSize, shuffle=True, num_workers=args.num_workers)

# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(args.model_path, exist_ok=True)
os.makedirs(args.sample_path, exist_ok=True)
    

Namespace(batchSize=64, beta1=0.5, beta2=0.999, dataroot='./datasets/face_SRnDeblur', lambda_A=10.0, log_step=100, lr=0.0002, model_path='./SRnDeblur_Neverywhere/models', no_flip=False, no_resize_or_crop=True, num_epochs=1001, num_workers=2, sample_path='./SRnDeblur_Neverywhere/results', sample_step=100, which_direction='AtoB')


In [None]:
# tensorboardX summary
summary = SummaryWriter('SRnDeblur_Neverywhere/log')
# Networks
# generator = nn.DataParallel(Generator(args.batchSize))
# discriminator = nn.DataParallel(Discriminator(6,args.batchSize))
# discriminatorM = nn.DataParallel(Discriminator(6,args.batchSize))
generator = nn.DataParallel(Generator(args.batchSize))
discriminatorHR = nn.DataParallel(Discriminator(6,args.batchSize))
discriminatorHRM = nn.DataParallel(Discriminator(3,args.batchSize))

# dummmy = torch.zeros(1,3,32,32)
# summary.add_graph(generator,dummmy)
# summary.add_graph(discriminator,(dummmy,dummmy))

# Losses -> vanilaGAN의 loss, MSE는 LSGAN의 loss
# criterionGAN = nn.BCELoss() 
criterionGAN = nn.MSELoss()
criterionL1 = nn.L1Loss()
vgg_loss = VGGLoss()
# Optimizers
d_parameters = list(discriminatorHR.parameters()) + list(discriminatorHRM.parameters())
g_optimizer = optim.Adam(generator.parameters(), args.lr, [args.beta1, args.beta2])
d_optimizer = optim.Adam(d_parameters, args.lr, [args. beta1, args.beta2])

if torch.cuda.is_available():
    generator = generator.cuda()
    discriminatorHR = discriminatorHR.cuda()
    discriminatorHRM = discriminatorHRM.cuda()
    
### train generator and discriminator
# for printing the log
total_step = len(data_loader)

for epoch  in range(args.num_epochs):
    for i, sample in enumerate(data_loader):
        # A : 32x32 (blur + LR)
        # B : 256x256 (LR)
        # C : 256x256 (GT)
        # D : 256x256 (fmask)
        input_A = sample['A']
        GTLR = sample['B']
        GTHR = sample['C']
        fmaskGT = sample['D'].to("cuda")
        inputbi = sample['E']
        
        #======================== preparing ========================#
        in_blurLR = to_variable(input_A)
        upx, fakeLR, fakeHR = generator(in_blurLR)
        v_GTLR = to_variable(GTLR)
        v_GTHR = to_variable(GTHR)
        v_inputbi = to_variable(inputbi)
        fakeHR_md = fakeHR * fmaskGT
        GTHR_md = v_GTHR * fmaskGT
        
        #========================= train D =========================#
        # zero_grad : 역전파 실행 전 변화도 0으로 만듦
        discriminatorHR.zero_grad()
        discriminatorHRM.zero_grad()
        
        pred_fakeHR = discriminatorHR(torch.cat((upx.detach(), fakeHR.detach()),dim=1))
        loss_D_fake = GAN_Loss(pred_fakeHR, False, criterionGAN)
        pred_realHR = discriminatorHR(torch.cat((upx.detach(), v_GTHR),dim=1))
        loss_D_real = GAN_Loss(pred_realHR, True, criterionGAN)
        
        pred_fakeM = discriminatorHRM(fakeHR_md.detach())
        loss_D_fakeM = GAN_Loss(pred_fakeM, False, criterionGAN)
        pred_realM = discriminatorHRM(GTHR_md)
        loss_D_realM = GAN_Loss(pred_realM, True, criterionGAN)
        
        # combined loss
        loss_D = (loss_D_fake + loss_D_real + loss_D_fakeM + loss_D_realM) * 0.25 #0.25 & 0.5
        loss_D.backward()
        d_optimizer.step()
        #========================= train G =========================#
        generator.zero_grad()
        
        pred_fake = discriminatorHR(torch.cat((upx, fakeHR), dim=1))
        pred_fakeM = discriminatorHRM(fakeHR_md)
        loss_G_GAN = GAN_Loss(pred_fake, True, criterionGAN) + GAN_Loss(pred_fakeM, True, criterionGAN)
        
        loss_G_L1_LR = criterionL1(fakeLR, v_GTLR)
        loss_G_L1_HR = criterionL1(fakeHR, v_GTHR)
        loss_G_L1_HRM = criterionL1(fakeHR_md, GTHR_md)
        
        loss_G_vgg_HR = vgg_loss(fakeHR, v_GTHR)
        
        loss_G = loss_G_GAN + (loss_G_L1_LR + loss_G_L1_HR + loss_G_L1_HRM) * args.lambda_A + loss_G_vgg_HR * 10.0
        
        loss_G.backward()
        g_optimizer.step()
        
        # print the log information
        if (i+1) % args.log_step == 0:
            print('Epoch [%d/%d], BatchStep[%d/%d]' % (epoch + 1, args.num_epochs, i + 1, total_step))
            print('D_loss: %.4f, G_loss: %.4f' % (loss_D.data, loss_G.data))
            print('D_RealHR_loss: %.4f, D_FakeHR_loss: %.4f, D_RealHRM_loss: %.4f, D_FakeHRM_loss: %.4f' 
                  % (loss_D_real.data, loss_D_fake.data, loss_D_realM.data, loss_D_fakeM.data))          
            print('GAN_loss: %.4f, G_L1_LR_loss: %.4f, G_L1_HR_loss: %.4f, G_L1_HRM_loss: %.4f, G_vgg_HR_loss: %.4f '
                  % (loss_G_GAN.data, loss_G_L1_LR.data, loss_G_L1_HR.data,loss_G_L1_HRM.data, loss_G_vgg_HR.data))
            
        # save the sampled images
        if (i+1)%args.sample_step == 0:
            in_Ar = upx[0:4,:,:,:]
            fake_Br = fakeHR[0:4,:,:,:]
            real_Br = v_GTHR[0:4,:,:,:]
            bilinear_in = v_inputbi[0:4,:,:,:]
            
            resHR = torch.cat((torch.cat((in_Ar, fake_Br),dim=3), real_Br), dim=3)
            resLR = torch.cat((torch.cat((bilinear_in, fake_Br),dim=3), real_Br),dim=3)
            torchvision.utils.save_image(denorm(resHR.data), os.path.join(args.sample_path, 'HRwF-%d-%d.png' % (epoch + 1, i + 1)))
            torchvision.utils.save_image(denorm(resLR.data), os.path.join(args.sample_path, 'HRwB-%d-%d.png' % (epoch + 1, i + 1)))

#             resMasked_0 = torch.cat((torch.cat((torch.cat((fake_Br,real_Br),dim=2),fmasked_FB[0:4,:,:,:]),dim=2),fmasked_GT[0:4,:,:,:]),dim=2)
#             resMasked = torch.cat((torch.cat((resMasked_0,parmasked_FB[0:4,:,:,:]),dim=2),parmasked_FB[0:4,:,:,:]),dim=2)
#             torchvision.utils.save_image(denorm(resMasked.data), os.path.join(args.sample_path, 'Generated_RM-%d-%d.png' % (epoch + 1, i + 1)))
            
    # save summary
    # Log one point per epoch for every loss term. The values are whatever the
    # loss variables hold when the batch loop finishes, i.e. the most recently
    # processed batch, not an average over the epoch.
    epoch_scalars = (
        ('loss/loss_D_real', loss_D_real),
        ('loss/loss_D_fake', loss_D_fake),
        ('loss/loss_D_realM', loss_D_realM),
        ('loss/loss_D_fakeM', loss_D_fakeM),
        ('loss/loss_D', loss_D),
        ('loss/loss_G_GAN', loss_G_GAN),
        ('loss/loss_G_L1_LR', loss_G_L1_LR),
        ('loss/loss_G_L1_HR', loss_G_L1_HR),
        ('loss/loss_G_L1_HRM', loss_G_L1_HRM),
        ('loss/loss_G_vgg_HR', loss_G_vgg_HR),
        ('loss/loss_G', loss_G),
    )
    for tag, loss_value in epoch_scalars:
        summary.add_scalar(tag, loss_value.item(), epoch)

    # Build the image grids straight from the latest batch tensors rather than
    # from `fake_Br` / `real_Br`: those names are only bound inside the
    # `sample_step` branch of the batch loop, so they are undefined (NameError)
    # or stale whenever that branch did not run for the final batch.
    # `.detach()` keeps the grid construction out of the autograd graph.
    fakeOutHR = make_grid(fakeHR[0:4].detach(), normalize=True, scale_each=True)
    realGTHR = make_grid(v_GTHR[0:4].detach(), normalize=True, scale_each=True)

    summary.add_image('0_GT_HR', realGTHR, epoch)
    summary.add_image('1_generated HR', fakeOutHR, epoch)

    
    # Checkpoint the generator weights on every 10th epoch index
    # (epoch = 0, 10, 20, ...); the file name uses the 1-based epoch number,
    # so the files are generator-1.pkl, generator-11.pkl, ...
    if epoch % 10 == 0:
        g_name = 'generator-%d.pkl' % (epoch + 1)
        g_path = os.path.join(args.model_path, g_name)
        torch.save(generator.state_dict(), g_path)

Epoch [1/1001], BatchStep[100/450]
D_loss: 0.0261, G_loss: 20.3309
D_RealHR_loss: 0.0178, D_FakeHR_loss: 0.0439, D_RealHRM_loss: 0.0231, D_FakeHRM_loss: 0.0196
GAN_loss: 1.9430, G_L1_LR_loss: 0.2119, G_L1_HR_loss: 0.4025, G_L1_HRM_loss: 0.1277, G_vgg_HR_loss: 1.0967 
Epoch [1/1001], BatchStep[200/450]
D_loss: 0.0158, G_loss: 18.2189
D_RealHR_loss: 0.0190, D_FakeHR_loss: 0.0290, D_RealHRM_loss: 0.0056, D_FakeHRM_loss: 0.0095
GAN_loss: 1.9812, G_L1_LR_loss: 0.1954, G_L1_HR_loss: 0.3284, G_L1_HRM_loss: 0.0952, G_vgg_HR_loss: 1.0047 
Epoch [1/1001], BatchStep[300/450]
D_loss: 0.0105, G_loss: 16.8386
D_RealHR_loss: 0.0199, D_FakeHR_loss: 0.0101, D_RealHRM_loss: 0.0048, D_FakeHRM_loss: 0.0071
GAN_loss: 2.0075, G_L1_LR_loss: 0.1653, G_L1_HR_loss: 0.2966, G_L1_HRM_loss: 0.0894, G_vgg_HR_loss: 0.9317 
Epoch [1/1001], BatchStep[400/450]
D_loss: 0.0047, G_loss: 15.8431
D_RealHR_loss: 0.0040, D_FakeHR_loss: 0.0051, D_RealHRM_loss: 0.0062, D_FakeHRM_loss: 0.0035
GAN_loss: 1.9600, G_L1_LR_loss: 0.16

Epoch [8/1001], BatchStep[400/450]
D_loss: 0.0135, G_loss: 14.8094
D_RealHR_loss: 0.0008, D_FakeHR_loss: 0.0010, D_RealHRM_loss: 0.0243, D_FakeHRM_loss: 0.0280
GAN_loss: 1.9914, G_L1_LR_loss: 0.1275, G_L1_HR_loss: 0.2273, G_L1_HRM_loss: 0.0771, G_vgg_HR_loss: 0.8499 
Epoch [9/1001], BatchStep[100/450]
D_loss: 0.0050, G_loss: 14.1838
D_RealHR_loss: 0.0019, D_FakeHR_loss: 0.0016, D_RealHRM_loss: 0.0051, D_FakeHRM_loss: 0.0112
GAN_loss: 2.0431, G_L1_LR_loss: 0.1398, G_L1_HR_loss: 0.2178, G_L1_HRM_loss: 0.0700, G_vgg_HR_loss: 0.7865 
Epoch [9/1001], BatchStep[200/450]
D_loss: 0.0320, G_loss: 15.9926
D_RealHR_loss: 0.0042, D_FakeHR_loss: 0.0014, D_RealHRM_loss: 0.0849, D_FakeHRM_loss: 0.0374
GAN_loss: 2.0510, G_L1_LR_loss: 0.1465, G_L1_HR_loss: 0.2733, G_L1_HRM_loss: 0.0967, G_vgg_HR_loss: 0.8776 
Epoch [9/1001], BatchStep[300/450]
D_loss: 0.0061, G_loss: 14.3355
D_RealHR_loss: 0.0024, D_FakeHR_loss: 0.0030, D_RealHRM_loss: 0.0082, D_FakeHRM_loss: 0.0108
GAN_loss: 2.0035, G_L1_LR_loss: 0.12

Epoch [16/1001], BatchStep[300/450]
D_loss: 0.0028, G_loss: 12.7996
D_RealHR_loss: 0.0049, D_FakeHR_loss: 0.0030, D_RealHRM_loss: 0.0021, D_FakeHRM_loss: 0.0014
GAN_loss: 2.0574, G_L1_LR_loss: 0.1065, G_L1_HR_loss: 0.1685, G_L1_HRM_loss: 0.0524, G_vgg_HR_loss: 0.7468 
Epoch [16/1001], BatchStep[400/450]
D_loss: 0.0022, G_loss: 12.4205
D_RealHR_loss: 0.0042, D_FakeHR_loss: 0.0021, D_RealHRM_loss: 0.0016, D_FakeHRM_loss: 0.0009
GAN_loss: 1.9652, G_L1_LR_loss: 0.1049, G_L1_HR_loss: 0.1654, G_L1_HRM_loss: 0.0517, G_vgg_HR_loss: 0.7235 
Epoch [17/1001], BatchStep[100/450]
D_loss: 0.0026, G_loss: 12.4867
D_RealHR_loss: 0.0014, D_FakeHR_loss: 0.0010, D_RealHRM_loss: 0.0040, D_FakeHRM_loss: 0.0041
GAN_loss: 2.0205, G_L1_LR_loss: 0.1084, G_L1_HR_loss: 0.1667, G_L1_HRM_loss: 0.0513, G_vgg_HR_loss: 0.7203 
Epoch [17/1001], BatchStep[200/450]
D_loss: 0.0046, G_loss: 13.5733
D_RealHR_loss: 0.0010, D_FakeHR_loss: 0.0013, D_RealHRM_loss: 0.0087, D_FakeHRM_loss: 0.0074
GAN_loss: 2.0682, G_L1_LR_loss: 

Epoch [24/1001], BatchStep[200/450]
D_loss: 0.1693, G_loss: 11.7508
D_RealHR_loss: 0.0291, D_FakeHR_loss: 0.6438, D_RealHRM_loss: 0.0025, D_FakeHRM_loss: 0.0018
GAN_loss: 1.7314, G_L1_LR_loss: 0.1104, G_L1_HR_loss: 0.1419, G_L1_HRM_loss: 0.0476, G_vgg_HR_loss: 0.7020 
Epoch [24/1001], BatchStep[300/450]
D_loss: 0.0711, G_loss: 11.6905
D_RealHR_loss: 0.0395, D_FakeHR_loss: 0.2146, D_RealHRM_loss: 0.0160, D_FakeHRM_loss: 0.0142
GAN_loss: 1.4283, G_L1_LR_loss: 0.1140, G_L1_HR_loss: 0.1458, G_L1_HRM_loss: 0.0491, G_vgg_HR_loss: 0.7172 
Epoch [24/1001], BatchStep[400/450]
D_loss: 0.1527, G_loss: 11.3977
D_RealHR_loss: 0.4836, D_FakeHR_loss: 0.1257, D_RealHRM_loss: 0.0010, D_FakeHRM_loss: 0.0005
GAN_loss: 1.3293, G_L1_LR_loss: 0.1148, G_L1_HR_loss: 0.1430, G_L1_HRM_loss: 0.0482, G_vgg_HR_loss: 0.7009 
Epoch [25/1001], BatchStep[100/450]
D_loss: 0.0849, G_loss: 11.8633
D_RealHR_loss: 0.0186, D_FakeHR_loss: 0.3098, D_RealHRM_loss: 0.0060, D_FakeHRM_loss: 0.0053
GAN_loss: 1.6594, G_L1_LR_loss: 