In [1]:
from PIL import Image
import os
import numpy as np
import cv2
import torch
from torch.autograd import Variable
import pickle as pk
import logging
import visdom

In [2]:
class Option(object):
    def __init__(self):
        pass

In [3]:
opt = Option()
opt.max_epochs = 100
opt.gpu_ids = [0]
opt.task = 'test'
opt.display_port = 8097
opt.num_images_per_batch = 1
opt.save_dir = '/home/guang/SuperRes/pytorch-srgan/test/'
opt.pretrained_state = '/home/guang/SuperRes/pytorch-srgan/snapshots/perform_test/state_epoch100_iter0.pkl' 
opt.snapshot_subdir = 'test01'  # 'nightingale'
opt.snapshot_prefix_G = 'unet256'
opt.snapshot_prefix_D = 'basic'
opt.snapshot_interval_epochs = 10 # In epochs
opt.snapshot_interval_iters = 5000
opt.display_interval = 1  # In iterations
opt.display_env = opt.snapshot_subdir
opt.num_average_minibatches = None 
opt.learning_rate = 1e-4
opt.lambda_G = 1e-3
opt.no_lsgan = True
opt.dataset = 'test' #'test' #'starcraft1688' #'ilsvrc2012'
opt.num_crops_per_image = 1 # should be the same value as num_images_per_batch
opt.log_file = os.path.join('/home/guang/SuperRes/pytorch-srgan/test/', 
                            opt.snapshot_subdir,
                            'log.txt')
opt.eval_dir = os.path.join('/home/guang/SuperRes/evaluation/SRGAN/')

opt.save_dir = os.path.join(opt.save_dir, opt.snapshot_subdir)
if not os.path.isdir(opt.save_dir):
    os.makedirs(opt.save_dir)
if not os.path.isdir(opt.eval_dir):
    os.makedirs(opt.eval_dir)

In [4]:
logger = logging.getLogger("pytorch-srgan")
logger.setLevel(logging.DEBUG)
logger.propagate = False

log_file = logging.FileHandler(opt.log_file)
log_file.setLevel(logging.DEBUG)

fmt = '%(asctime)s %(levelname)-8s: %(message)s'
fmt = logging.Formatter(fmt)

log_file.setFormatter(fmt)
logger.addHandler(log_file)

logger.info('Pretrained state "{}"'.format(opt.pretrained_state))
logger.info('Snapshot will be saved in "{}"'.format(opt.save_dir))

In [5]:
from models.networks import define_D, GANLoss, define_G

In [6]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 1e-2)

In [7]:
#TODO no_lsgan and use sigmoid
opt.use_sigmoid = opt.no_lsgan
model_D = define_D(6, 64, 'basic', use_sigmoid=opt.use_sigmoid)
model_G = define_G(3, 3, 64, which_model_netG='unet_256')

loss_history = {
    'L2_loss': [],
    'Gan_G_loss': [], 
    'D_loss_real': [],
    'D_loss_fake': [],
}

if opt.pretrained_state is None:
    state = {
        'loss_history': {
            'L2_loss': [],
            'Gan_G_loss': [], 
            'D_loss_real': [],
            'D_loss_fake': [],
        },
        'model_G': None,
        'model_D': None,
        'optimizer_G': None,
        'optimizer_D': None,        
        
        'epochs': 0
    }
    model_G.apply(weights_init)
    model_D.apply(weights_init)
    model_G.cuda(device=opt.gpu_ids[0])
    model_D.cuda(device=opt.gpu_ids[0])
    optimizer_G = torch.optim.Adam(model_G.parameters(), lr=opt.learning_rate)
    optimizer_D = torch.optim.Adam(model_D.parameters(), lr=opt.learning_rate)

else:
    with open(opt.pretrained_state) as f:
        state = pk.load(f)
    model_G.load_state_dict(torch.load(os.path.join(os.path.dirname(opt.pretrained_state), state['model_G'])))
    model_D.load_state_dict(torch.load(os.path.join(os.path.dirname(opt.pretrained_state), state['model_D'])))
    model_G.cuda(device=opt.gpu_ids[0])
    model_D.cuda(device=opt.gpu_ids[0])
    optimizer_G = torch.optim.Adam(model_G.parameters(), lr=opt.learning_rate)
    optimizer_D = torch.optim.Adam(model_D.parameters(), lr=opt.learning_rate)
    optimizer_G.load_state_dict(torch.load(os.path.join(os.path.dirname(opt.pretrained_state), state['optimizer_G'])))
    optimizer_D.load_state_dict(torch.load(os.path.join(os.path.dirname(opt.pretrained_state), state['optimizer_D'])))   
    
loss_history = state['loss_history']
Tensor = torch.cuda.FloatTensor

In [8]:
# Load the dataset
if opt.dataset == 'ilsvrc2012':
    from datasets.ilsvrc2012 import ILSVRC2012
    ilsvrc = ILSVRC2012('/home/guang/datasets/ILSVRC2012', N=opt.num_crops_per_image, scales=[2,3,4])
    dataset = torch.utils.data.DataLoader(
                ilsvrc,
                batch_size=opt.num_images_per_batch,
                num_workers=int(1),
                shuffle=True)
elif opt.dataset == 'starcraft1688':
    from datasets.starcraft1688 import StarCraft1688Dataset
    starcraft = StarCraft1688Dataset('/home/guang/datasets/starcraft1688', N=opt.num_crops_per_image)
    dataset = torch.utils.data.DataLoader(
                starcraft,
                batch_size=opt.num_images_per_batch,
                num_workers=int(1),
                shuffle=True)
    n_image_pairs = len(starcraft)    
elif opt.dataset == 'test':
    from datasets.testset import testset
    testset = testset('/home/guang/datasets/starcraft_scene2', N=opt.num_images_per_batch)
    dataset = torch.utils.data.DataLoader(
                testset,
                batch_size=opt.num_images_per_batch,
                num_workers=int(1),
                shuffle=True)
    n_image_pairs = len(testset) 
    print('Test images in total: {}').format(n_image_pairs)
    
if opt.num_average_minibatches is None:
    opt.num_average_minibatches = len(dataset)

Test images in total: 244


In [9]:
from torch.nn import functional as F
def center_crop(x, height, width):
    crop_h = torch.FloatTensor([x.size()[2]]).sub(height).div(-2)
    crop_w = torch.FloatTensor([x.size()[3]]).sub(width).div(-2)

    return F.pad(x, [
        crop_w.ceil().int()[0], crop_w.floor().int()[0],
        crop_h.ceil().int()[0], crop_h.floor().int()[0],
    ])

In [10]:
import progressbar as pb
import scipy.misc
from Quality import PSNR, SSIM 
import time

input_tensor = Tensor(1, 3, 512, 512)
target_tensor = Tensor(1, 3, 512, 512)
# label_fake = Variable(Tensor(opt.batch_size).fill_(0))
# label_real = Variable(Tensor(opt.batch_size;t).fill_(1))

psnr, ssim = [], []
time_points = []
for minibatch_i, data in enumerate(dataset):

    # Stacking crops together as different inputs/targets
    input_data = []
    target_data = []
    for crop_i in range(opt.num_crops_per_image):
        input_data.append(data[crop_i*2])
        target_data.append(data[crop_i*2+1])
    input_data = torch.cat(input_data, dim=0)    
    target_data = torch.cat(target_data, dim=0)
    
    input_tensor.resize_(input_data.size()).copy_(input_data)
    target_tensor.resize_(target_data.size()).copy_(target_data)
    input = Variable(input_tensor)
    target = Variable(target_tensor)
    
    crop_h, crop_w = 360, 640
    start = time.time()
    output = model_G.forward(input)
    point = time.time() - start
    time_points.append(point)
    print('Testing cost time in {:.0f}m {:.0f}s\n'.format(point // 60, point % 60))
    
    #psnr.append(PSNR(output.data.cpu().numpy(), target.data.cpu().numpy()))
    #ssim.append(SSIM(Variable(output.data.cpu()), Variable(target.data.cpu())))
    
    output = center_crop(output, crop_h, crop_w)
    output = output.data.cpu().numpy()
    
    mean = [ 0.5, 0.5, 0.5 ]
    std = [ 0.5, 0.5, 0.5 ]
    for c in range(3):
        pass
        output[:, c, :, :] *= std[c]
        output[:, c, :, :] += mean[c]
            
    for save_i in range(opt.num_images_per_batch):
        im_np = output[save_i,:,:,:]
        scipy.misc.imsave(os.path.join(opt.save_dir, 'test_{}.bmp'.format(minibatch_i*opt.num_crops_per_image+save_i)), im_np)

#print psnr
#print ssim
np.save(os.path.join(opt.eval_dir, "time.npy"), time_points)
print time_points
print len(time_points), np.mean(time_points)

Testing cost time in 0m 3s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in 0m 0s

Testing cost time in

In [69]:
ssim_tmp = []
#print ssim[0].data.cpu().numpy()[0]
for i in range(len(ssim)):
    ssim_tmp.append(ssim[i].data.cpu().numpy()[0])
#print ssim_tmp
np.save(os.path.join(opt.eval_dir, "psnr.npy"), psnr)
np.save(os.path.join(opt.eval_dir, "ssim.npy"), ssim_tmp)

s = np.load(os.path.join(opt.eval_dir, "ssim.npy"))
p = np.load(os.path.join(opt.eval_dir, "psnr.npy"))
print s.size, p.size
print s
print ssim

244 244
[0.47781494 0.43329963 0.3342373  0.3327878  0.70980066 0.3654462
 0.60616595 0.36899504 0.3136771  0.42293277 0.3027359  0.55297124
 0.38609046 0.4021718  0.6002045  0.36817762 0.44297576 0.396913
 0.42642096 0.34386697 0.40278804 0.5285475  0.509858   0.34337878
 0.49067256 0.31302547 0.3813199  0.41302076 0.51902854 0.5138504
 0.5482495  0.5581454  0.36425292 0.29265252 0.5211884  0.39268583
 0.50024426 0.38993314 0.36985856 0.47908875 0.5032304  0.29878843
 0.41463014 0.33173573 0.52971804 0.36280036 0.50544244 0.33991563
 0.37110215 0.37617847 0.44744092 0.5060992  0.4736998  0.33111584
 0.4188926  0.65168834 0.2918025  0.47215784 0.5868749  0.4768651
 0.36913633 0.55280364 0.29167965 0.33295003 0.38133946 0.47418287
 0.43252906 0.39876363 0.40673012 0.5265573  0.42249063 0.31260455
 0.36426198 0.55915046 0.4392343  0.3034012  0.566916   0.3538673
 0.49221054 0.3377267  0.38493884 0.4423349  0.5200705  0.296323
 0.41843688 0.35755607 0.36650828 0.37563044 0.4966811  0.5477

In [62]:
viz.line(Y=psnr_tmp, opts=dict(showlegend=True))

u'window_3626776a8c095e'

In [61]:
np.random.rand(10).size
psnr_tmp = np.array(psnr)
print psnr_tmp.size

244


In [100]:
import scipy.misc

if opt.dataset == 'test':
    from datasets.testset import testset
    testset = testset('/home/guang/datasets/starcraft_scene2', N=opt.num_images_per_batch)
    dataset = torch.utils.data.DataLoader(
                testset,
                batch_size=opt.num_images_per_batch,
                num_workers=int(1),
                shuffle=True)
    n_image_pairs = len(testset)
    print('Test images in total: {}').format(n_image_pairs)

input_tensor = Tensor(1, 3, 1920, 1072)
target_tensor = Tensor(1, 3, 1920, 1072)
# label_fake = Variable(Tensor(opt.batch_size).fill_(0))
# label_real = Variable(Tensor(opt.batch_size).fill_(1))

i_tmp, t_tmp, o_tmp, os_tmp = [], [], [], []

for minibatch_i, data in enumerate(dataset):

    # Stacking crops together as different inputs/targets
    if minibatch_i > 0:
        break
    input_data = []
    target_data = []
    for crop_i in range(opt.num_crops_per_image):
        input_data.append(data[crop_i*2])
        target_data.append(data[crop_i*2+1])
    input_data = torch.cat(input_data, dim=0)    
    target_data = torch.cat(target_data, dim=0)
    
    input_tensor.resize_(input_data.size()).copy_(input_data)
    target_tensor.resize_(target_data.size()).copy_(target_data)

    
    input = Variable(input_tensor)
    target = Variable(target_tensor)
    
    crop_h, crop_w = 360, 640
    output = model_G.forward(input)
    os_tmp = output
    #output = center_crop(output, crop_h, crop_w)
    output = output.data.cpu().numpy()
    
    i_tmp, t_tmp = input.data.cpu().numpy(), target.data.cpu().numpy()
    
    mean = [ 0.5, 0.5, 0.5 ]
    std = [ 0.5, 0.5, 0.5 ]
    for c in range(3):
        pass
        output[:, c, :, :] *= std[c]
        output[:, c, :, :] += mean[c]
    o_tmp = output    
    
    for save_i in range(opt.num_images_per_batch):
        im_np = output[save_i,:,:,:]
        scipy.misc.imsave(os.path.join(opt.save_dir, opt.snapshot_subdir, 'test_{}.bmp'.format(minibatch_i*opt.num_crops_per_image+save_i)), im_np )   
        
           

Test images in total: 244


In [87]:
import numpy
import math

def PSNR(img1, img2):
    mse = numpy.mean( (img1 - img2) ** 2 )
    if mse == 0:
        return 100
    PIXEL_MAX = 255.0
    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))

In [103]:
p = PSNR(o_tmp, t_tmp)
print p

print(SSIM(os_tmp, target))

ssim_loss = SSIM(window_size = 11)

print(ssim_loss(os_tmp, target))

51.2505619128
Variable containing:
 0.5022
[torch.cuda.FloatTensor of size 1 (GPU 0)]

Variable containing:
 0.5022
[torch.cuda.FloatTensor of size 1 (GPU 0)]



In [102]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

def SSIM(img1, img2, window_size = 11, size_average = True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)
    
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    
    return _ssim(img1, img2, window, window_size, channel, size_average)