# SRnDeblur_joint 1104 test

## Generator

In [None]:
import argparse
import os
import random
import torch
import torchvision
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.backends import cudnn
from torch.autograd import Variable
from torch.utils import data
from torchvision import transforms
from PIL import Image

In [None]:
class Generator(nn.Module):
    """Joint super-resolution + deblurring generator.

    ``upscale8`` first enlarges the input 8x (three Conv + PixelShuffle(2)
    stages). The upscaled image then goes through a U-Net-style encoder
    (conv1 .. conv8) that feeds two decoders:

    * a shallow decoder (deconv4_0 .. deconv1_0) that starts from the conv4
      features and produces ``outLR``;
    * a full decoder (deconv8 .. deconv1) that starts from the conv8
      bottleneck and produces ``outHR``. Its 64x128x128 stage is summed with
      the shallow decoder's matching stage before the last layer.

    The shape comments below assume a 3x32x32 input.

    Args:
        batch_size: 1 selects InstanceNorm2d layers; any other value selects
            BatchNorm2d layers.
    """
    def __init__(self, batch_size):
        super(Generator, self).__init__()

        bn = None
        if batch_size == 1:
            bn = False # Instance Normalization
        else:
            bn = True # Batch Normalization

        #============================ upscale ============================#
        # 8x learned upsampling: each PixelShuffle(2) trades 4x channels for 2x H and W.
        self.upscale8 = nn.Sequential(
            # [3x32x32] -> [64x32x32]
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # [64x32x32] -> [256x32x32]
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            # [256x32x32] -> [64x64x64]
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
            # [64x64x64] -> [256x64x64]
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            # [256x64x64] -> [64x128x128]
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
            # [64x128x128] -> [256x128x128]
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            # [256x128x128] -> [64x256x256]
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
            # [64x256x256] -> [3x256x256]
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
        )
        #============================ upscale ============================#


        # nn.Conv2d(number of input channels, number of channels produced by the convolution, kernel size, stride=default 1, padding=default 0)
        # Encoder. Every stage below halves H and W (kernel 4, stride 2, padding 1).
        # [3x256x256] -> [64x128x128]
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
        # Disabled alternative: [64x256x256] -> [64x128x128]
#         self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)

        # -> [128x64x64]
        conv2 = [nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1)]
        if bn == True:
            conv2 += [nn.BatchNorm2d(128)]
        else:
            conv2 += [nn.InstanceNorm2d(128)]
        self.conv2 = nn.Sequential(*conv2)

        # -> [256x32x32]
        conv3 = [nn.LeakyReLU(0.2, inplace=True),
                 nn.Conv2d(128, 256, 4, 2, 1)]
        if bn == True:
            conv3 += [nn.BatchNorm2d(256)]
        else:
            conv3 += [nn.InstanceNorm2d(256)]
        self.conv3 = nn.Sequential(*conv3)

        # -> [512x16x16]
        conv4 = [nn.LeakyReLU(0.2, inplace=True),
                 nn.Conv2d(256, 512, 4, 2, 1)]
        if bn == True:
            conv4 += [nn.BatchNorm2d(512)]
        else:
            conv4 += [nn.InstanceNorm2d(512)]
        self.conv4 = nn.Sequential(*conv4)

        # -> [512x8x8]
        conv5 = [nn.LeakyReLU(0.2, inplace=True),
                 nn.Conv2d(512, 512, 4, 2, 1)]
        if bn == True:
            conv5 += [nn.BatchNorm2d(512)]
        else:
            conv5 += [nn.InstanceNorm2d(512)]
        self.conv5 = nn.Sequential(*conv5)

        # -> [512x4x4]
        conv6 = [nn.LeakyReLU(0.2, inplace=True),
                 nn.Conv2d(512, 512, 4, 2, 1)]
        if bn == True:
            conv6 += [nn.BatchNorm2d(512)]
        else:
            conv6 += [nn.InstanceNorm2d(512)]
        self.conv6 = nn.Sequential(*conv6)

        # -> [512x2x2]
        conv7 = [nn.LeakyReLU(0.2, inplace=True),
                 nn.Conv2d(512, 512, 4, 2, 1)]
        if bn == True:
            conv7 += [nn.BatchNorm2d(512)]
        else:
            conv7 += [nn.InstanceNorm2d(512)]
        self.conv7 = nn.Sequential(*conv7)

        # -> [512x1x1]
        conv8 = [nn.LeakyReLU(0.2, inplace=True),
                 nn.Conv2d(512, 512, 4, 2, 1)]
        if bn == True:
            conv8 += [nn.BatchNorm2d(512)]
        else:
            conv8 += [nn.InstanceNorm2d(512)]
        self.conv8 = nn.Sequential(*conv8)

        # Full decoder. Every transposed conv doubles H and W; inputs after the
        # first stage are the concatenation [skip, previous decoder output].
        # -> [512x2x2]
        deconv8 = [nn.ReLU(),
                   nn.ConvTranspose2d(512, 512, 4, 2, 1)]
        if bn == True:
            deconv8 += [nn.BatchNorm2d(512), nn.Dropout(0.5)]
        else:
            deconv8 += [nn.InstanceNorm2d(512), nn.Dropout(0.5)]
        self.deconv8 = nn.Sequential(*deconv8)

        # [(512+512)x2x2] -> [512x4x4]
        deconv7 = [nn.ReLU(),
                   nn.ConvTranspose2d(512 * 2, 512, 4, 2, 1)]
        if bn == True:
            deconv7 += [nn.BatchNorm2d(512), nn.Dropout(0.5)]
        else:
            deconv7 += [nn.InstanceNorm2d(512), nn.Dropout(0.5)]
        self.deconv7 = nn.Sequential(*deconv7)

        # [(512+512)x4x4] -> [512x8x8]
        deconv6 = [nn.ReLU(),
                   nn.ConvTranspose2d(512 * 2, 512, 4, 2, 1)]
        if bn == True:
            deconv6 += [nn.BatchNorm2d(512), nn.Dropout(0.5)]
        else:
            deconv6 += [nn.InstanceNorm2d(512), nn.Dropout(0.5)]
        self.deconv6 = nn.Sequential(*deconv6)

        # [(512+512)x8x8] -> [512x16x16]
        deconv5 = [nn.ReLU(),
                   nn.ConvTranspose2d(512 * 2, 512, 4, 2, 1)]
        if bn == True:
            deconv5 += [nn.BatchNorm2d(512)]
        else:
            deconv5 += [nn.InstanceNorm2d(512)]
        self.deconv5 = nn.Sequential(*deconv5)

        # [(512+512)x16x16] -> [256x32x32]
        deconv4 = [nn.ReLU(),
                   nn.ConvTranspose2d(512 * 2, 256, 4, 2, 1)]
        if bn == True:
            deconv4 += [nn.BatchNorm2d(256)]
        else:
            deconv4 += [nn.InstanceNorm2d(256)]
        self.deconv4 = nn.Sequential(*deconv4)

        # Shallow decoder, first stage: takes the conv4 features directly (no concat).
        # [512x16x16] -> [256x32x32]
        deconv4_0 = [nn.ReLU(),
                   nn.ConvTranspose2d(512 * 1, 256, 4, 2, 1)]
        if bn == True:
            deconv4_0 += [nn.BatchNorm2d(256)]
        else:
            deconv4_0 += [nn.InstanceNorm2d(256)]
        self.deconv4_0 = nn.Sequential(*deconv4_0)

        # [(256+256)x32x32] -> [128x64x64]
        deconv3 = [nn.ReLU(),
                   nn.ConvTranspose2d(256 * 2, 128, 4, 2, 1)]
        if bn == True:
            deconv3 += [nn.BatchNorm2d(128)]
        else:
            deconv3 += [nn.InstanceNorm2d(128)]
        self.deconv3 = nn.Sequential(*deconv3)

        # Shallow decoder: [(256+256)x32x32] -> [128x64x64]
        deconv3_0 = [nn.ReLU(),
                   nn.ConvTranspose2d(256 * 2, 128, 4, 2, 1)]
        if bn == True:
            deconv3_0 += [nn.BatchNorm2d(128)]
        else:
            deconv3_0 += [nn.InstanceNorm2d(128)]
        self.deconv3_0 = nn.Sequential(*deconv3_0)

        # [(128+128)x64x64] -> [64x128x128]
        deconv2 = [nn.ReLU(),
                   nn.ConvTranspose2d(128 * 2, 64, 4, 2, 1)]
        if bn == True:
            deconv2 += [nn.BatchNorm2d(64)]
        else:
            deconv2 += [nn.InstanceNorm2d(64)]
        self.deconv2 = nn.Sequential(*deconv2)

        # Shallow decoder: [(128+128)x64x64] -> [64x128x128]
        deconv2_0 = [nn.ReLU(),
                   nn.ConvTranspose2d(128 * 2, 64, 4, 2, 1)]
        if bn == True:
            deconv2_0 += [nn.BatchNorm2d(64)]
        else:
            deconv2_0 += [nn.InstanceNorm2d(64)]
        self.deconv2_0 = nn.Sequential(*deconv2_0)

        # Full-decoder output head; Tanh bounds the image to [-1, 1].
        # [(64+64)x128x128] -> [3x256x256]
        self.deconv1 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(64 * 2, 3, 4, 2, 1),
            nn.Tanh()
        )

        # Shallow-decoder output head; Tanh bounds the image to [-1, 1].
        # [(64+64)x128x128] -> [3x256x256]
        self.deconv1_0 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(64 * 2, 3, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Upscale ``x`` and run it through both decoders.

        Returns:
            (upx, outLR, outHR):
            * upx: the 8x-upscaled input produced by ``upscale8``;
            * outLR: the shallow decoder's output;
            * outHR: the full decoder's output.
            outLR and outHR both end in Tanh, so they lie in [-1, 1].

        NOTE(review): Gaussian noise with mean 1.0 (not 0) and std 0.1 is
        added to the input and after every stage. It is not gated on
        ``self.training``, so outputs stay stochastic after ``eval()``.
        Presumably this matches how the weights were trained -- confirm
        before changing it.
        """
        # Assumes x is NCHW with 3 channels (see the shape comments in __init__).

        noise = torch.empty_like(x).normal_(mean=1.0,std=0.1)
        inN = x + noise
        upx = self.upscale8(inN)

        # Encoder. The noisy cX_n feed the next stage; the clean cX are kept as skips.
        c1 = self.conv1(upx)
        c1_1 = torch.empty_like(c1).normal_(mean=1.0,std=0.1)
        c1_n = c1 + c1_1
        c2 = self.conv2(c1_n)
        c2_1 = torch.empty_like(c2).normal_(mean=1.0,std=0.1)
        c2_n = c2 + c2_1
        c3 = self.conv3(c2_n)
        c3_1 = torch.empty_like(c3).normal_(mean=1.0,std=0.1)
        c3_n = c3 + c3_1
        c4 = self.conv4(c3_n)
        c4_1 = torch.empty_like(c4).normal_(mean=1.0,std=0.1)
        c4_n = c4 + c4_1
        c5 = self.conv5(c4_n)
        c5_1 = torch.empty_like(c5).normal_(mean=1.0,std=0.1)
        c5_n = c5 + c5_1
        c6 = self.conv6(c5_n)
        c6_1 = torch.empty_like(c6).normal_(mean=1.0,std=0.1)
        c6_n = c6 + c6_1
        c7 = self.conv7(c6_n)
        c7_1 = torch.empty_like(c7).normal_(mean=1.0,std=0.1)
        c7_n = c7 + c7_1
        c8 = self.conv8(c7_n)
        c8_1 = torch.empty_like(c8).normal_(mean=1.0,std=0.1)
        c8_n = c8 + c8_1

        # Shallow decoder: starts from the conv4 features with skips c3, c2, c1.
        d3_0 = self.deconv4_0(c4_n)
        d3_0 = torch.cat((c3,d3_0), dim=1)
        d3_1 = torch.empty_like(d3_0).normal_(mean=1.0,std=0.1)
        d3_n = d3_0 + d3_1
        d2_0 = self.deconv3_0(d3_n)
        d2_0 = torch.cat((c2,d2_0), dim=1)
        d2_1 = torch.empty_like(d2_0).normal_(mean=1.0,std=0.1)
        d2_n = d2_0 + d2_1
        # d1_0 (before concatenation) is reused by the full decoder below.
        d1_0 = self.deconv2_0(d2_n)
        d1_00 = torch.cat((c1,d1_0), dim=1)
        d1_1 = torch.empty_like(d1_00).normal_(mean=1.0,std=0.1)
        d1_n = d1_00 + d1_1
        outLR = self.deconv1_0(d1_n)

        # Full decoder: starts from the conv8 bottleneck with skips c7 .. c1.
        d7 = self.deconv8(c8_n)
        d7 = torch.cat((c7, d7), dim=1)
        d17_1 = torch.empty_like(d7).normal_(mean=1.0,std=0.1)
        d17_n = d7 + d17_1
        d6 = self.deconv7(d17_n)
        d6 = torch.cat((c6, d6), dim=1)
        d16_1 = torch.empty_like(d6).normal_(mean=1.0,std=0.1)
        d16_n = d6 + d16_1
        d5 = self.deconv6(d16_n)
        d5 = torch.cat((c5, d5), dim=1)
        d15_1 = torch.empty_like(d5).normal_(mean=1.0,std=0.1)
        d15_n = d5 + d15_1
        d4 = self.deconv5(d15_n)
        d4 = torch.cat((c4, d4), dim=1)
        d14_1 = torch.empty_like(d4).normal_(mean=1.0,std=0.1)
        d14_n = d4 + d14_1
        d3 = self.deconv4(d14_n)
        d3 = torch.cat((c3, d3), dim=1)
        d13_1 = torch.empty_like(d3).normal_(mean=1.0,std=0.1)
        d13_n = d3 + d13_1
        d2 = self.deconv3(d13_n)
        d2 = torch.cat((c2, d2), dim=1)
        d12_1 = torch.empty_like(d2).normal_(mean=1.0,std=0.1)
        d12_n = d2 + d12_1
        d1 = self.deconv2(d12_n)
        # Fuse the two decoders: element-wise sum of their 64x128x128 features.
        d1 = torch.add(d1,d1_0)
        d1 = torch.cat((c1, d1), dim=1)
        d11_1 = torch.empty_like(d1).normal_(mean=1.0,std=0.1)
        d11_n = d1 + d11_1
        outHR = self.deconv1(d11_n)
#         output = torch.add(outLR,outHR)
#         d1 = torch.cat((c1, d1), dim=1)
#         outHR = self.deconv1(d1)


#         return outLR, outHR
        return upx,outLR, outHR

In [None]:
def _build_cli_parser():
    """Create the command-line parser used by this evaluation script."""
    cli = argparse.ArgumentParser(description='Implementation of Pix2Pix')
    # (flag, add_argument keyword arguments), in registration order.
    option_specs = [
        # Task
        ('--dataroot', {'required': True,
                        'help': 'path to images (should have subfolders train, val, etc)'}),
        ('--which_direction', {'type': str, 'default': 'AtoB', 'help': 'AtoB or BtoA'}),
        # Options
        ('--no_resize_or_crop', {'action': 'store_true',
                                 'help': 'scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]'}),
        ('--no_flip', {'action': 'store_true',
                       'help': 'if specified, do not flip the images for data augmentation'}),
        ('--num_epochs', {'type': int, 'default': 100}),
        ('--batchSize', {'type': int, 'default': 1, 'help': 'test Batch size'}),
        # misc
        ('--model_path', {'type': str, 'default': './models'}),
        ('--sample_path', {'type': str, 'default': './test_results'}),
        ('--results_txt', {'type': str, 'default': './test_MSE_PSNR_SSIM.txt'}),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli


parser = _build_cli_parser()

##### Helper Functions for Data Loading & Pre-processing
# Recognised image extensions, each listed in lower- then upper-case form.
IMG_EXTENSIONS = [
    variant
    for ext in ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')
    for variant in (ext, ext.upper())
]

In [None]:
##### Helper Functions for Data Loading & Pre-processing
class ImageFolder(data.Dataset):
    """Paired evaluation dataset.

    Inputs come from ``<dataroot>/valx8``. Each is matched by file stem with
    a ground-truth image at ``<dataroot>/val_GT/<stem>.jpg``.

    Each item is a dict:
      'A'     : input image as a tensor normalised to [-1, 1], cropped to at most 32x32
      'B'     : input bicubically resized to 256x256, normalised to [-1, 1]
      'C'     : ground truth resized to 256x256, normalised to [-1, 1]
      'fname' : input file name without its extension
    """

    # Only these files are treated as images. Anything else in the folder
    # (e.g. '.ipynb_checkpoints', 'Thumbs.db') would otherwise crash Image.open.
    _IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')

    def __init__(self, opt):
        self.root = opt.dataroot
        self.no_resize_or_crop = opt.no_resize_or_crop
        self.no_flip = opt.no_flip
        # PIL -> tensor in [-1, 1] via per-channel (x - 0.5) / 0.5.
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])
        # Plain ToTensor ([0, 1]); kept for compatibility, not used below.
        self.transformM = transforms.Compose([transforms.ToTensor()])
        self.dir_A = os.path.join(opt.dataroot, 'valx8')
        self.dir_C = os.path.join(opt.dataroot, 'val_GT')
        # Sort so the index -> file mapping is deterministic; os.listdir order is arbitrary.
        self.Aimg_paths = [
            os.path.join(self.dir_A, name)
            for name in sorted(os.listdir(self.dir_A))
            if name.lower().endswith(self._IMG_EXTENSIONS)
        ]

    def __getitem__(self, index):
        A_path = self.Aimg_paths[index]
        # splitext handles extensions of any length (e.g. '.jpeg'), unlike a fixed 4-char slice.
        stem = os.path.splitext(os.path.basename(A_path))[0]
        # The ground truth shares the input's stem but is always stored as .jpg.
        C_path = os.path.join(self.dir_C, stem + '.jpg')
        # Context managers close the underlying files; convert() returns a loaded copy.
        with Image.open(A_path) as img:
            A = img.convert('RGB')
        with Image.open(C_path) as img:
            C = img.convert('RGB')
        B = A.resize((256, 256), Image.BICUBIC)
        C = C.resize((256, 256), Image.BICUBIC)
        A = self.transform(A)[:, :32, :32]
        B = self.transform(B)[:, :256, :256]
        C = self.transform(C)[:, :256, :256]

        return {'A': A, 'B': B, 'C': C, 'fname': stem}

    def __len__(self):
        return len(self.Aimg_paths)

##### Helper Function for GPU Training
def to_variable(x):
    """Wrap ``x`` for the model, moving it to the GPU first when CUDA is available."""
    if not torch.cuda.is_available():
        return Variable(x)
    return Variable(x.cuda())

##### Helper Function for Math
def denorm(x):
    """Map values from [-1, 1] back to [0, 1], clipping anything outside that range."""
    shifted = x + 1
    return (shifted / 2).clamp(0, 1)

##### Helper Functions for GAN Loss (4D Loss Comparison)
def GAN_Loss(input, target, criterion):
    """Compare a discriminator output with a constant real/fake label map.

    Args:
        input: discriminator output tensor (any shape).
        target: truthy to compare against all-ones ("real") labels, falsy
            for all-zeros ("fake") labels.
        criterion: callable taking ``(input, labels)``, e.g. ``nn.MSELoss()``.

    Returns:
        The value of ``criterion(input, labels)``.
    """
    # *_like builds labels with input's shape, dtype and device, so CPU inputs
    # work even when CUDA is available and no host->device copy is needed.
    # The labels are constants and do not require grad.
    if target:
        labels = torch.ones_like(input)
    else:
        labels = torch.zeros_like(input)
    return criterion(input, labels)
##### Helper Function for Math
def denorm(x):
    """Rescale a tensor from [-1, 1] to [0, 1], clamping out-of-range values.

    This redefinition replaces any earlier ``denorm`` in this module.
    """
    return ((x + 1) / 2).clamp(min=0, max=1)

def to_numpy(x):
    """Turn a CHW tensor in [-1, 1] into an HWC numpy array in [0, 1].

    Values are shifted and scaled but not clipped.
    """
    arr = x.cpu().detach().numpy()
    return ((arr + 1) / 2).transpose(1, 2, 0)

def mse(x, y):
    """Return the mean squared error between two equally shaped arrays.

    This is the mean of the squared differences, not the L2 norm of the
    difference (``np.linalg.norm(x - y)`` is the square root of the *sum* of
    squares and grows with image size).
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return float(np.mean(diff ** 2))


In [None]:
# Pre-settings
cudnn.benchmark = True
# (A module-level `global` statement is a no-op, so it is not used here.)
args = parser.parse_args(['--dataroot','./datasets/face_SRnDeblur','--which_direction','AtoB',
                          '--num_epochs','551','--batchSize','16','--no_resize_or_crop',
                          '--model_path','./1031_final_withNevery/models',
                          '--sample_path','./1031_final_withNevery/results_e551',
                          '--results_txt','./1031_final_withNevery/PSNRSSIM_e551.txt'])
# Other checkpoint epochs (presumably also evaluated): 741 751 761 771 781
print(args)

dataset = ImageFolder(args)
# Evaluation needs no shuffling; a fixed order keeps the results file reproducible.
data_loader = data.DataLoader(dataset=dataset,
                              batch_size=args.batchSize,
                              shuffle=False,
                              num_workers=2)

os.makedirs(args.model_path, exist_ok=True)
os.makedirs(args.sample_path, exist_ok=True)

g_path = os.path.join(args.model_path, 'generator-%d.pkl' % (args.num_epochs))
print(g_path)

In [None]:
# Load pre-trained model
generator = Generator(args.batchSize)
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# the model is moved to the GPU below when one is available.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model_w = torch.load(g_path, map_location='cpu')
# Checkpoints saved from nn.DataParallel prefix every key with 'module.'.
# Strip it only where present; blindly dropping 7 chars corrupts plain keys.
_prefix = 'module.'
model_w1 = {
    (k[len(_prefix):] if k.startswith(_prefix) else k): v
    for k, v in model_w.items()
}

generator.load_state_dict(model_w1)
generator.eval()

if torch.cuda.is_available():
    generator = generator.cuda()

total_step = len(data_loader)  # For Print Log

from PIL import Image
# Import only img_as_float: `from skimage import data` would shadow torch.utils.data.
from skimage import img_as_float
from skimage.measure import compare_ssim as ssim
from skimage.measure import compare_psnr as psnr

# Output folders. 'LR256' receives the bicubic inputs saved below; the other
# two are for the (currently disabled) generated / side-by-side outputs.
for _subdir in ('Generated_1', 'LRGEGT', 'LR256'):
    os.makedirs(os.path.join(args.sample_path, _subdir), exist_ok=True)

mse_in_all, mse_out_all, psnr_in_all, psnr_out_all, ssim_in_all, ssim_out_all = 0, 0, 0, 0, 0, 0
num_images = 0  # images actually evaluated; used as the averaging denominator

with open(args.results_txt, 'w') as f:
    for i, sample in enumerate(data_loader):
        input_A = sample['A']
        input_A_Bi = sample['B']
        GTHR = sample['C']
        testfileN = sample['fname']

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            in_blurLR = to_variable(input_A)
            in_bili = to_variable(input_A_Bi)
            upx, fakeLR, fakeHR = generator(in_blurLR)
            v_GTHR = to_variable(GTHR)

        # print the log info
        print('Validation[%d/%d]' % (i + 1, total_step))

        in_Ar = upx[:, 0:3, :, :]
        in_Ar_bi = in_bili[:, 0:3, :, :]
        fake_Br = fakeHR[:, 0:3, :, :]
        real_Br = v_GTHR[:, 0:3, :, :]

        # Iterate over the real batch length: the final batch can be smaller than args.batchSize.
        for k in range(fake_Br.size(0)):
            in_Ar_ = img_as_float(to_numpy(in_Ar[k, :, :, :]))
            fake_Br_ = img_as_float(to_numpy(fake_Br[k, :, :, :]))
            real_Br_ = img_as_float(to_numpy(real_Br[k, :, :, :]))

            mse_in = mse(real_Br_, in_Ar_)
            mse_out = mse(real_Br_, fake_Br_)
            psnr_in = psnr(real_Br_, in_Ar_, data_range=in_Ar_.max() - in_Ar_.min())
            psnr_out = psnr(real_Br_, fake_Br_, data_range=fake_Br_.max() - fake_Br_.min())
            ssim_in = ssim(real_Br_, in_Ar_, data_range=in_Ar_.max() - in_Ar_.min(), multichannel=True)
            ssim_out = ssim(real_Br_, fake_Br_, data_range=fake_Br_.max() - fake_Br_.min(), multichannel=True)

            mse_in_all += mse_in
            mse_out_all += mse_out
            psnr_in_all += psnr_in
            psnr_out_all += psnr_out
            ssim_in_all += ssim_in
            ssim_out_all += ssim_out
            num_images += 1

            f.write('%s.png \n' % testfileN[k])
            f.write('mse_in : %f, psnr_in : %f, ssim_in : %f \n' % (mse_in, psnr_in, ssim_in))
            f.write('mse_out : %f, psnr_out : %f, ssim_out : %f \n' % (mse_out, psnr_out, ssim_out))

            # save the sampled images
#             torchvision.utils.save_image(denorm(fake_Br[k,:,:,:].data), os.path.join(args.sample_path+'/Generated_1', '%s.png' % testfileN[k]))
            torchvision.utils.save_image(denorm(in_Ar_bi[k, :, :, :].data), os.path.join(args.sample_path, 'LR256', '%s.png' % testfileN[k]))

#             res = torch.cat((torch.cat((in_Ar_bi[k,:,:,:], fake_Br[k,:,:,:]), dim=2), real_Br[k,:,:,:]), dim=2)
#             torchvision.utils.save_image(denorm(res.data), os.path.join(args.sample_path+'/LRGEGT', '%s.png' % testfileN[k]))

    # Average over the images actually evaluated (guarding against an empty set).
    denom = max(num_images, 1)
    summary_in = 'mse_in : %f, psnr_in : %f, ssim_in : %f \n' % (mse_in_all / denom, psnr_in_all / denom, ssim_in_all / denom)
    summary_out = 'mse_out : %f, psnr_out : %f, ssim_out : %f \n' % (mse_out_all / denom, psnr_out_all / denom, ssim_out_all / denom)
    f.write('Average of MSE PSNR SSIM \n')
    f.write(summary_in)
    f.write(summary_out)

print('Average of MSE PSNR SSIM \n')
print(summary_in)
print(summary_out)

In [1]:
import argparse
import os
import random
import torch
import torchvision
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.backends import cudnn
from torch.autograd import Variable
from torch.utils import data
from torchvision import transforms
from PIL import Image

In [2]:
class Generator(nn.Module):
    """Joint super-resolution + deblurring generator.

    ``upscale8`` first enlarges the input 8x (three Conv + PixelShuffle(2)
    stages). The upscaled image then goes through a U-Net-style encoder
    (conv1 .. conv8) that feeds two decoders:

    * a shallow decoder (deconv4_0 .. deconv1_0) that starts from the conv4
      features and produces ``outLR``;
    * a full decoder (deconv8 .. deconv1) that starts from the conv8
      bottleneck and produces ``outHR``. Its 64x128x128 stage is summed with
      the shallow decoder's matching stage before the last layer.

    The shape comments below assume a 3x32x32 input.

    Args:
        batch_size: 1 selects InstanceNorm2d layers; any other value selects
            BatchNorm2d layers.
    """
    def __init__(self, batch_size):
        super(Generator, self).__init__()

        bn = None
        if batch_size == 1:
            bn = False # Instance Normalization
        else:
            bn = True # Batch Normalization

        #============================ upscale ============================#
        # 8x learned upsampling: each PixelShuffle(2) trades 4x channels for 2x H and W.
        self.upscale8 = nn.Sequential(
            # [3x32x32] -> [64x32x32]
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # [64x32x32] -> [256x32x32]
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            # [256x32x32] -> [64x64x64]
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
            # [64x64x64] -> [256x64x64]
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            # [256x64x64] -> [64x128x128]
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
            # [64x128x128] -> [256x128x128]
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            # [256x128x128] -> [64x256x256]
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
            # [64x256x256] -> [3x256x256]
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
        )
        #============================ upscale ============================#


        # nn.Conv2d(number of input channels, number of channels produced by the convolution, kernel size, stride=default 1, padding=default 0)
        # Encoder. Every stage below halves H and W (kernel 4, stride 2, padding 1).
        # [3x256x256] -> [64x128x128]
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
        # Disabled alternative: [64x256x256] -> [64x128x128]
#         self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)

        # -> [128x64x64]
        conv2 = [nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1)]
        if bn == True:
            conv2 += [nn.BatchNorm2d(128)]
        else:
            conv2 += [nn.InstanceNorm2d(128)]
        self.conv2 = nn.Sequential(*conv2)

        # -> [256x32x32]
        conv3 = [nn.LeakyReLU(0.2, inplace=True),
                 nn.Conv2d(128, 256, 4, 2, 1)]
        if bn == True:
            conv3 += [nn.BatchNorm2d(256)]
        else:
            conv3 += [nn.InstanceNorm2d(256)]
        self.conv3 = nn.Sequential(*conv3)

        # -> [512x16x16]
        conv4 = [nn.LeakyReLU(0.2, inplace=True),
                 nn.Conv2d(256, 512, 4, 2, 1)]
        if bn == True:
            conv4 += [nn.BatchNorm2d(512)]
        else:
            conv4 += [nn.InstanceNorm2d(512)]
        self.conv4 = nn.Sequential(*conv4)

        # -> [512x8x8]
        conv5 = [nn.LeakyReLU(0.2, inplace=True),
                 nn.Conv2d(512, 512, 4, 2, 1)]
        if bn == True:
            conv5 += [nn.BatchNorm2d(512)]
        else:
            conv5 += [nn.InstanceNorm2d(512)]
        self.conv5 = nn.Sequential(*conv5)

        # -> [512x4x4]
        conv6 = [nn.LeakyReLU(0.2, inplace=True),
                 nn.Conv2d(512, 512, 4, 2, 1)]
        if bn == True:
            conv6 += [nn.BatchNorm2d(512)]
        else:
            conv6 += [nn.InstanceNorm2d(512)]
        self.conv6 = nn.Sequential(*conv6)

        # -> [512x2x2]
        conv7 = [nn.LeakyReLU(0.2, inplace=True),
                 nn.Conv2d(512, 512, 4, 2, 1)]
        if bn == True:
            conv7 += [nn.BatchNorm2d(512)]
        else:
            conv7 += [nn.InstanceNorm2d(512)]
        self.conv7 = nn.Sequential(*conv7)

        # -> [512x1x1]
        conv8 = [nn.LeakyReLU(0.2, inplace=True),
                 nn.Conv2d(512, 512, 4, 2, 1)]
        if bn == True:
            conv8 += [nn.BatchNorm2d(512)]
        else:
            conv8 += [nn.InstanceNorm2d(512)]
        self.conv8 = nn.Sequential(*conv8)

        # Full decoder. Every transposed conv doubles H and W; inputs after the
        # first stage are the concatenation [skip, previous decoder output].
        # -> [512x2x2]
        deconv8 = [nn.ReLU(),
                   nn.ConvTranspose2d(512, 512, 4, 2, 1)]
        if bn == True:
            deconv8 += [nn.BatchNorm2d(512), nn.Dropout(0.5)]
        else:
            deconv8 += [nn.InstanceNorm2d(512), nn.Dropout(0.5)]
        self.deconv8 = nn.Sequential(*deconv8)

        # [(512+512)x2x2] -> [512x4x4]
        deconv7 = [nn.ReLU(),
                   nn.ConvTranspose2d(512 * 2, 512, 4, 2, 1)]
        if bn == True:
            deconv7 += [nn.BatchNorm2d(512), nn.Dropout(0.5)]
        else:
            deconv7 += [nn.InstanceNorm2d(512), nn.Dropout(0.5)]
        self.deconv7 = nn.Sequential(*deconv7)

        # [(512+512)x4x4] -> [512x8x8]
        deconv6 = [nn.ReLU(),
                   nn.ConvTranspose2d(512 * 2, 512, 4, 2, 1)]
        if bn == True:
            deconv6 += [nn.BatchNorm2d(512), nn.Dropout(0.5)]
        else:
            deconv6 += [nn.InstanceNorm2d(512), nn.Dropout(0.5)]
        self.deconv6 = nn.Sequential(*deconv6)

        # [(512+512)x8x8] -> [512x16x16]
        deconv5 = [nn.ReLU(),
                   nn.ConvTranspose2d(512 * 2, 512, 4, 2, 1)]
        if bn == True:
            deconv5 += [nn.BatchNorm2d(512)]
        else:
            deconv5 += [nn.InstanceNorm2d(512)]
        self.deconv5 = nn.Sequential(*deconv5)

        # [(512+512)x16x16] -> [256x32x32]
        deconv4 = [nn.ReLU(),
                   nn.ConvTranspose2d(512 * 2, 256, 4, 2, 1)]
        if bn == True:
            deconv4 += [nn.BatchNorm2d(256)]
        else:
            deconv4 += [nn.InstanceNorm2d(256)]
        self.deconv4 = nn.Sequential(*deconv4)

        # Shallow decoder, first stage: takes the conv4 features directly (no concat).
        # [512x16x16] -> [256x32x32]
        deconv4_0 = [nn.ReLU(),
                   nn.ConvTranspose2d(512 * 1, 256, 4, 2, 1)]
        if bn == True:
            deconv4_0 += [nn.BatchNorm2d(256)]
        else:
            deconv4_0 += [nn.InstanceNorm2d(256)]
        self.deconv4_0 = nn.Sequential(*deconv4_0)

        # [(256+256)x32x32] -> [128x64x64]
        deconv3 = [nn.ReLU(),
                   nn.ConvTranspose2d(256 * 2, 128, 4, 2, 1)]
        if bn == True:
            deconv3 += [nn.BatchNorm2d(128)]
        else:
            deconv3 += [nn.InstanceNorm2d(128)]
        self.deconv3 = nn.Sequential(*deconv3)

        # Shallow decoder: [(256+256)x32x32] -> [128x64x64]
        deconv3_0 = [nn.ReLU(),
                   nn.ConvTranspose2d(256 * 2, 128, 4, 2, 1)]
        if bn == True:
            deconv3_0 += [nn.BatchNorm2d(128)]
        else:
            deconv3_0 += [nn.InstanceNorm2d(128)]
        self.deconv3_0 = nn.Sequential(*deconv3_0)

        # [(128+128)x64x64] -> [64x128x128]
        deconv2 = [nn.ReLU(),
                   nn.ConvTranspose2d(128 * 2, 64, 4, 2, 1)]
        if bn == True:
            deconv2 += [nn.BatchNorm2d(64)]
        else:
            deconv2 += [nn.InstanceNorm2d(64)]
        self.deconv2 = nn.Sequential(*deconv2)

        # Shallow decoder: [(128+128)x64x64] -> [64x128x128]
        deconv2_0 = [nn.ReLU(),
                   nn.ConvTranspose2d(128 * 2, 64, 4, 2, 1)]
        if bn == True:
            deconv2_0 += [nn.BatchNorm2d(64)]
        else:
            deconv2_0 += [nn.InstanceNorm2d(64)]
        self.deconv2_0 = nn.Sequential(*deconv2_0)

        # Full-decoder output head; Tanh bounds the image to [-1, 1].
        # [(64+64)x128x128] -> [3x256x256]
        self.deconv1 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(64 * 2, 3, 4, 2, 1),
            nn.Tanh()
        )

        # Shallow-decoder output head; Tanh bounds the image to [-1, 1].
        # [(64+64)x128x128] -> [3x256x256]
        self.deconv1_0 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(64 * 2, 3, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Upscale ``x`` and run it through both decoders.

        Returns:
            (upx, outLR, outHR):
            * upx: the 8x-upscaled input produced by ``upscale8``;
            * outLR: the shallow decoder's output;
            * outHR: the full decoder's output.
            outLR and outHR both end in Tanh, so they lie in [-1, 1].

        NOTE(review): Gaussian noise with mean 1.0 (not 0) and std 0.1 is
        added to the input and after every stage. It is not gated on
        ``self.training``, so outputs stay stochastic after ``eval()``.
        Presumably this matches how the weights were trained -- confirm
        before changing it.
        """
        # Assumes x is NCHW with 3 channels (see the shape comments in __init__).

        noise = torch.empty_like(x).normal_(mean=1.0,std=0.1)
        inN = x + noise
        upx = self.upscale8(inN)

        # Encoder. The noisy cX_n feed the next stage; the clean cX are kept as skips.
        c1 = self.conv1(upx)
        c1_1 = torch.empty_like(c1).normal_(mean=1.0,std=0.1)
        c1_n = c1 + c1_1
        c2 = self.conv2(c1_n)
        c2_1 = torch.empty_like(c2).normal_(mean=1.0,std=0.1)
        c2_n = c2 + c2_1
        c3 = self.conv3(c2_n)
        c3_1 = torch.empty_like(c3).normal_(mean=1.0,std=0.1)
        c3_n = c3 + c3_1
        c4 = self.conv4(c3_n)
        c4_1 = torch.empty_like(c4).normal_(mean=1.0,std=0.1)
        c4_n = c4 + c4_1
        c5 = self.conv5(c4_n)
        c5_1 = torch.empty_like(c5).normal_(mean=1.0,std=0.1)
        c5_n = c5 + c5_1
        c6 = self.conv6(c5_n)
        c6_1 = torch.empty_like(c6).normal_(mean=1.0,std=0.1)
        c6_n = c6 + c6_1
        c7 = self.conv7(c6_n)
        c7_1 = torch.empty_like(c7).normal_(mean=1.0,std=0.1)
        c7_n = c7 + c7_1
        c8 = self.conv8(c7_n)
        c8_1 = torch.empty_like(c8).normal_(mean=1.0,std=0.1)
        c8_n = c8 + c8_1

        # Shallow decoder: starts from the conv4 features with skips c3, c2, c1.
        d3_0 = self.deconv4_0(c4_n)
        d3_0 = torch.cat((c3,d3_0), dim=1)
        d3_1 = torch.empty_like(d3_0).normal_(mean=1.0,std=0.1)
        d3_n = d3_0 + d3_1
        d2_0 = self.deconv3_0(d3_n)
        d2_0 = torch.cat((c2,d2_0), dim=1)
        d2_1 = torch.empty_like(d2_0).normal_(mean=1.0,std=0.1)
        d2_n = d2_0 + d2_1
        # d1_0 (before concatenation) is reused by the full decoder below.
        d1_0 = self.deconv2_0(d2_n)
        d1_00 = torch.cat((c1,d1_0), dim=1)
        d1_1 = torch.empty_like(d1_00).normal_(mean=1.0,std=0.1)
        d1_n = d1_00 + d1_1
        outLR = self.deconv1_0(d1_n)

        # Full decoder: starts from the conv8 bottleneck with skips c7 .. c1.
        d7 = self.deconv8(c8_n)
        d7 = torch.cat((c7, d7), dim=1)
        d17_1 = torch.empty_like(d7).normal_(mean=1.0,std=0.1)
        d17_n = d7 + d17_1
        d6 = self.deconv7(d17_n)
        d6 = torch.cat((c6, d6), dim=1)
        d16_1 = torch.empty_like(d6).normal_(mean=1.0,std=0.1)
        d16_n = d6 + d16_1
        d5 = self.deconv6(d16_n)
        d5 = torch.cat((c5, d5), dim=1)
        d15_1 = torch.empty_like(d5).normal_(mean=1.0,std=0.1)
        d15_n = d5 + d15_1
        d4 = self.deconv5(d15_n)
        d4 = torch.cat((c4, d4), dim=1)
        d14_1 = torch.empty_like(d4).normal_(mean=1.0,std=0.1)
        d14_n = d4 + d14_1
        d3 = self.deconv4(d14_n)
        d3 = torch.cat((c3, d3), dim=1)
        d13_1 = torch.empty_like(d3).normal_(mean=1.0,std=0.1)
        d13_n = d3 + d13_1
        d2 = self.deconv3(d13_n)
        d2 = torch.cat((c2, d2), dim=1)
        d12_1 = torch.empty_like(d2).normal_(mean=1.0,std=0.1)
        d12_n = d2 + d12_1
        d1 = self.deconv2(d12_n)
        # Fuse the two decoders: element-wise sum of their 64x128x128 features.
        d1 = torch.add(d1,d1_0)
        d1 = torch.cat((c1, d1), dim=1)
        d11_1 = torch.empty_like(d1).normal_(mean=1.0,std=0.1)
        d11_n = d1 + d11_1
        outHR = self.deconv1(d11_n)
#         output = torch.add(outLR,outHR)
#         d1 = torch.cat((c1, d1), dim=1)
#         outHR = self.deconv1(d1)


#         return outLR, outHR
        return upx,outLR, outHR

In [3]:
def _build_cli_parser():
    """Create the command-line parser used by this evaluation script."""
    cli = argparse.ArgumentParser(description='Implementation of Pix2Pix')
    # (flag, add_argument keyword arguments), in registration order.
    option_specs = [
        # Task
        ('--dataroot', {'required': True,
                        'help': 'path to images (should have subfolders train, val, etc)'}),
        ('--which_direction', {'type': str, 'default': 'AtoB', 'help': 'AtoB or BtoA'}),
        # Options
        ('--no_resize_or_crop', {'action': 'store_true',
                                 'help': 'scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]'}),
        ('--no_flip', {'action': 'store_true',
                       'help': 'if specified, do not flip the images for data augmentation'}),
        ('--num_epochs', {'type': int, 'default': 100}),
        ('--batchSize', {'type': int, 'default': 1, 'help': 'test Batch size'}),
        # misc
        ('--model_path', {'type': str, 'default': './models'}),
        ('--sample_path', {'type': str, 'default': './test_results'}),
        ('--results_txt', {'type': str, 'default': './test_MSE_PSNR_SSIM.txt'}),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli


parser = _build_cli_parser()

##### Helper Functions for Data Loading & Pre-processing
# Recognised image extensions, each listed in lower- then upper-case form.
IMG_EXTENSIONS = [
    variant
    for ext in ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')
    for variant in (ext, ext.upper())
]

In [4]:
##### Helper Functions for Data Loading & Pre-processingclass ImageFolder(data.Dataset):
class ImageFolder(data.Dataset):
    """Dataset over every image file in ``<dataroot>/val_wmine``.

    Each item is a dict with:
      * ``'A'``: the image resized to 32x32 (bicubic), normalized to [-1, 1];
        this is the generator input.
      * ``'B'``: the 32x32 image upsampled to 256x256 (bicubic), normalized
        to [-1, 1]; a bicubic baseline.
      * ``'fname'``: the file name without directory or extension.

    Args:
        opt: parsed options. ``dataroot`` is used; ``no_resize_or_crop`` and
            ``no_flip`` are stored but not consulted by this class.
    """

    _SUBDIR = 'val_wmine'

    def __init__(self, opt):
        self.root = opt.dataroot
        self.no_resize_or_crop = opt.no_resize_or_crop
        self.no_flip = opt.no_flip
        # ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        # ToTensor only, no normalization (not used by this class).
        self.transformM = transforms.Compose([transforms.ToTensor()])
        self.dir_A = os.path.join(opt.dataroot, self._SUBDIR)
        # Keep only image files and sort them: os.listdir order is arbitrary,
        # and stray non-image files would otherwise crash Image.open.
        image_exts = tuple(IMG_EXTENSIONS)
        self.Aimg_paths = [
            os.path.join(self.dir_A, name)
            for name in sorted(os.listdir(self.dir_A))
            if name.endswith(image_exts)
        ]

    def __getitem__(self, index):
        """Load one image and return its 32x32 input, 256x256 bicubic upsample and name."""
        A_path = self.Aimg_paths[index]
        A = Image.open(A_path).convert('RGB')
        A = A.resize((32, 32), Image.BICUBIC)
        B = A.resize((256, 256), Image.BICUBIC)
        A = self.transform(A)
        B = self.transform(B)
        # Drop directory and extension; works for any extension length
        # (a fixed 4-character cut mangled '.jpeg' names).
        fname = os.path.splitext(os.path.basename(A_path))[0]
        return {'A': A, 'B': B, 'fname': fname}

    def __len__(self):
        return len(self.Aimg_paths)

##### Helper Function for GPU Training
def to_variable(x):
    """Return ``x`` wrapped in ``Variable``, moved to the GPU if CUDA is available.

    In current PyTorch ``Variable(t)`` just returns a Tensor, so this is
    effectively "put on GPU when possible".
    """
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed)

##### Helper Function for Math
def denorm(x):
    """Map a tensor from [-1, 1] back to [0, 1], clamping anything outside."""
    shifted = x.add(1).div(2)
    return shifted.clamp(min=0, max=1)

##### Helper Functions for GAN Loss (4D Loss Comparison)
def GAN_Loss(input, target, criterion):
    """Compare a discriminator output against a constant real/fake label map.

    Args:
        input: discriminator output tensor (any shape).
        target: truthy for "real" labels (1.0), falsy for "fake" labels (0.0).
        criterion: loss callable invoked as ``criterion(input, labels)``.

    Returns:
        The value of ``criterion(input, labels)``, where ``labels`` matches
        ``input`` in shape, dtype and device and does not require grad.
    """
    label_value = 1.0 if target else 0.0
    # Build labels directly on input's device/dtype. Creating a CPU float32
    # tensor and moving it to CUDA whenever CUDA exists broke for CPU inputs
    # on GPU machines and for non-float32 inputs.
    labels = torch.full_like(input, label_value)
    return criterion(input, labels)
##### Helper Function for Math
# NOTE(review): this redefines denorm with the same behavior as the earlier
# definition in this cell; one of the two copies can be dropped.
def denorm(x):
    """Rescale ``x`` from [-1, 1] to [0, 1] and clip to that interval."""
    half = (x + 1) * 0.5
    return torch.clamp(half, 0, 1)

def to_numpy(x):
    """Convert a (C, H, W) tensor to an (H, W, C) NumPy array mapped by (x + 1) / 2.

    Values in [-1, 1] land in [0, 1]; values outside are NOT clamped.
    """
    arr = x.detach().cpu().numpy()
    arr = (arr + 1) / 2
    return arr.transpose(1, 2, 0)

def mse(x, y):
    """Return the mean squared error between array-likes ``x`` and ``y``.

    The previous body returned ``np.linalg.norm(x - y)``, i.e. the square root
    of the *sum* of squared differences, not the mean. Inputs are cast to
    float64 first so integer images (e.g. uint8) do not wrap on subtraction.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return float(np.mean(diff ** 2))


In [5]:
# Pre-settings
# Let cuDNN benchmark and cache convolution algorithms (input size is fixed).
cudnn.benchmark = True
# Arguments are hard-coded so the notebook runs without a command line.
args = parser.parse_args(['--dataroot', './datasets/face_SRnDeblur', '--which_direction', 'AtoB',
                          '--num_epochs', '551', '--batchSize', '16', '--no_resize_or_crop',
                          '--model_path', './1031_final_withNevery/models',
                          '--sample_path', './1031_final_withNevery/valmine551',
                          '--results_txt', './1031_final_withNevery/PSNRSSIM_e551.txt'])

# 741 751 761 771 781 (presumably other checkpoint epochs tried — confirm)
print(args)

dataset = ImageFolder(args)
# Evaluation pass: keep a fixed order so runs and logs are reproducible.
data_loader = data.DataLoader(dataset=dataset,
                              batch_size=args.batchSize,
                              shuffle=False,
                              num_workers=2)

os.makedirs(args.model_path, exist_ok=True)
os.makedirs(args.sample_path, exist_ok=True)

g_path = os.path.join(args.model_path, 'generator-%d.pkl' % args.num_epochs)
print(g_path)

Namespace(batchSize=16, dataroot='./datasets/face_SRnDeblur', model_path='./1031_final_withNevery/models', no_flip=False, no_resize_or_crop=True, num_epochs=551, results_txt='./1031_final_withNevery/PSNRSSIM_e551.txt', sample_path='./1031_final_withNevery/valmine551', which_direction='AtoB')
./1031_final_withNevery/models/generator-551.pkl


In [9]:
# Load the pre-trained generator and run it over the validation set, saving
# the bicubic-upsampled input and the generated HR image for every sample.
generator = Generator(args.batchSize)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# the model is moved to the GPU below when one is available.
model_w = torch.load(g_path, map_location='cpu')
# nn.DataParallel checkpoints prefix every key with 'module.'; strip it only
# where present so plain checkpoints load unchanged.
prefix = 'module.'
model_w1 = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in model_w.items()}
generator.load_state_dict(model_w1)
generator.eval()

if torch.cuda.is_available():
    generator = generator.cuda()

gen_dir = os.path.join(args.sample_path, 'Generated')
input_dir = os.path.join(gen_dir, 'input')
# Create both output folders once, up front (the 'input' subfolder was
# previously never created, so saving the inputs failed).
os.makedirs(input_dir, exist_ok=True)

total_step = len(data_loader)  # For Print Log

# Inference only: no autograd graph needed.
with torch.no_grad():
    for i, sample in enumerate(data_loader):
        in_blurLR = to_variable(sample['A'])
        in_bili = to_variable(sample['B'])
        testfileN = sample['fname']

        upx, fakeLR, fakeHR = generator(in_blurLR)

        # print the log info
        print('Validation[%d/%d]' % (i + 1, total_step))

        fake_Br = fakeHR[:, 0:3, :, :]
        # Iterate over the actual batch; the last batch may be smaller than
        # args.batchSize, so a fixed range(16) would raise IndexError.
        for k, name in enumerate(testfileN):
            torchvision.utils.save_image(denorm(in_bili[k]), os.path.join(input_dir, '%s_input.png' % name))
            torchvision.utils.save_image(denorm(fake_Br[k]), os.path.join(gen_dir, '%s.png' % name))


Validation[1/2]
Validation[2/2]


In [None]:
# Show where the generated images were written.
print(f"{args.sample_path}/Generated")