Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

The released pretrained model cannot obtain the PSNR reported in the paper? #48

Closed
DoctorYy opened this issue Oct 14, 2019 · 3 comments
Closed

Comments

@DoctorYy
Copy link

Hi, thanks for your work.
I downloaded the pretrained model from your link and tested the RBPN_4x.pth model on the Vid4 test set. However, the results confused me. I got the following performance: calendar=22.2495 (paper 23.99),
city=26.1635 (paper 27.73), foliage=24.7383 (paper 26.22), walk=29.2091 (paper 30.70). I wonder why it underperforms the expected results in the paper. It would be nice if you could tell me how to reproduce the reported results using the pretrained models.
Thanks a lot.

BTW, I think there is a small bug in eval.py line73. It should be count=0.

@DoctorYy
Copy link
Author

Hi, I think I figured out this problem~
The output images should be converted into YCbCr space first and evaluated only in Y channel. I added this part of codes and got approximate results similar to the paper.

BTW, the bug in eval.py line73 still remains.

@monika5296
Copy link

Hi, I think I figured out this problem~
The output images should be converted into YCbCr space first and evaluated only in Y channel. I added this part of codes and got approximate results similar to the paper.

BTW, the bug in eval.py line73 still remains.

Can you please highlight what changes you have made? I am getting all pixels as invalid and PSNR as NaN.
Thanks in advance.

@DoctorYy
Copy link
Author

Hi, I think I figured out this problem~
The output images should be converted into YCbCr space first and evaluated only in Y channel. I added this part of codes and got approximate results similar to the paper.
BTW, the bug in eval.py line73 still remains.

Can you please highlight what changes you have made? I am getting all pixels as invalid and PSNR as NaN.
Thanks in advance.

Here is the eval.py which I have modified.

from __future__ import print_function
import argparse

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from rbpn import Net as RBPN
from data import get_test_set
from functools import reduce
import numpy as np

from scipy.misc import imsave
import scipy.io as sio
import time
import cv2
import math
import pdb

# Evaluation settings
def _str2bool(value):
    """Parse a command-line boolean such as "True", "false", "1" or "no".

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including "False") is parsed as True. This converter
    makes ``--flag False`` actually disable the flag.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, exactly as it was under bool('').
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % value)


parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
parser.add_argument('--upscale_factor', type=int, default=4, help="super resolution upscale factor")
parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size')
parser.add_argument('--gpu_mode', type=_str2bool, default=True)
parser.add_argument('--chop_forward', type=_str2bool, default=False)
parser.add_argument('--threads', type=int, default=1, help='number of threads for data loader to use')
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
parser.add_argument('--gpus', default=1, type=int, help='number of gpu')
parser.add_argument('--data_dir', type=str, default='./Vid4')
parser.add_argument('--file_list', type=str, default='walk.txt')      # calendar  city  foliage  walk
parser.add_argument('--other_dataset', type=_str2bool, default=True, help="use other dataset than vimeo-90k")
parser.add_argument('--future_frame', type=_str2bool, default=True, help="use future frame")
parser.add_argument('--nFrames', type=int, default=7)
parser.add_argument('--model_type', type=str, default='RBPN')
parser.add_argument('--residual', type=_str2bool, default=False)
parser.add_argument('--output', default='Results/', help='Location to save checkpoint models')
parser.add_argument('--model', default='weights/4x_gpuserver4-1RBPNF7_epoch_29.pth', help='sr pretrained base model')

opt = parser.parse_args()

# Device indices 0 .. gpus-1; element 0 is the device tensors are moved to.
gpus_list = list(range(opt.gpus))
print(opt)

cuda = opt.gpu_mode
if cuda and not torch.cuda.is_available():
    # The switch that controls GPU use in this script is --gpu_mode.
    raise Exception("No GPU found, please run with --gpu_mode False")

# Seed the CPU (and, when used, the CUDA) random number generators.
torch.manual_seed(opt.seed)
if cuda:
    torch.cuda.manual_seed(opt.seed)

print('===> Loading datasets')
test_set = get_test_set(opt.data_dir, opt.nFrames, opt.upscale_factor, opt.file_list, opt.other_dataset, opt.future_frame)
# shuffle=False keeps the frames in dataset order.
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)

print('===> Building model ', opt.model_type)
if opt.model_type == 'RBPN':
    model = RBPN(num_channels=3, base_filter=256,  feat = 64, num_stages=3, n_resblock=5, nFrames=opt.nFrames, scale_factor=opt.upscale_factor)
else:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError('Unsupported model_type: %s' % opt.model_type)

if cuda:
    # The model is wrapped before the weights are loaded, so the state dict
    # is loaded into the DataParallel wrapper.
    # NOTE(review): presumably the checkpoint keys carry DataParallel's
    # 'module.' prefix -- confirm before relying on --gpu_mode False.
    model = torch.nn.DataParallel(model, device_ids=gpus_list)

# map_location returns each storage unchanged, i.e. weights are loaded on CPU.
model.load_state_dict(torch.load(opt.model, map_location=lambda storage, loc: storage))
print('Pre-trained SR model is loaded.')

if cuda:
    model = model.cuda(gpus_list[0])

def eval():
    """Run the loaded model over the test loader and report average PSNR.

    For every batch the prediction is computed (through chop_forward when
    opt.chop_forward is set, with the bicubic input added when opt.residual
    is set), written to disk with save_img, and scored against the target
    with PSNR in both 'rgb' and 'y' modes using a 4-pixel shaved border.
    Per-frame values and the final averages are printed.

    Reads the module-level names model, testing_data_loader, opt, cuda and
    gpus_list. Returns nothing.

    NOTE(review): only element 0 of each prediction batch is scored and the
    target is squeezed, so this looks like it assumes testBatchSize == 1.
    NOTE(review): PSNR(mode='y') goes through bgr2ycbcr, which weights the
    channels in BGR order; save_img swaps channels before cv2.imwrite, which
    suggests these arrays are RGB -- confirm the channel order.
    """
    model.eval()
    count = 0
    avg_psnr_rgb = 0.0
    avg_psnr_y = 0.0
    for batch in testing_data_loader:
        lr_frame, target, neigbor, flow, bicubic = batch[0], batch[1], batch[2], batch[3], batch[4]

        # Move the inputs to the first GPU only when GPU mode is enabled;
        # previously .cuda() was called unconditionally.
        if cuda:
            device = gpus_list[0]
            lr_frame = lr_frame.cuda(device)
            bicubic = bicubic.cuda(device)
            neigbor = [j.cuda(device) for j in neigbor]
            flow = [j.cuda(device).float() for j in flow]
        else:
            flow = [j.float() for j in flow]

        t0 = time.time()
        with torch.no_grad():
            if opt.chop_forward:
                prediction = chop_forward(lr_frame, neigbor, flow, model, opt.upscale_factor)
            else:
                prediction = model(lr_frame, neigbor, flow)

        if opt.residual:
            prediction = prediction + bicubic

        t1 = time.time()
        print("===> Processing: %d || Timer: %.4f sec." % (count, (t1 - t0)))

        prediction = prediction.detach().cpu()
        save_img(prediction, str(count), True)

        # Scale both images from the [0, 1] tensor range to [0, 255].
        prediction = prediction[0].numpy().astype(np.float32) * 255.
        target = target.squeeze().numpy().astype(np.float32) * 255.

        # [3, H, W]  --->   [H, W, 3]
        prediction = np.transpose(prediction, (1, 2, 0)).clip(0, 255)
        target = np.transpose(target, (1, 2, 0)).clip(0, 255)

        # Copies keep the metric helpers from ever touching the arrays that
        # are reused for the next metric.
        psnr_rgb = PSNR(prediction.copy(), target.copy(), shave_border=4, mode='rgb')
        print('PSNR (RGB):', psnr_rgb)
        psnr_y = PSNR(prediction.copy(), target.copy(), shave_border=4, mode='y')
        print('PSNR (Y):', psnr_y)
        avg_psnr_rgb += psnr_rgb
        avg_psnr_y += psnr_y
        count += 1

    print('###############################')
    print('Testing set: %s'%opt.file_list)
    if count == 0:
        # An empty loader used to end in ZeroDivisionError.
        print('No frames were evaluated.')
    else:
        print("AVG PSNR (RGB) = %6.4f"%(avg_psnr_rgb/count))
        print("AVG PSNR (Y) = %6.4f"%(avg_psnr_y/count))
    print('###############################')

def save_img(img, img_name, pred_flag):
    """Write one image tensor as a PNG under the configured output folder.

    Args:
        img: tensor that, once squeezed, is indexed as (channel, row, col);
            values are clamped to [0, 1] and scaled by 255 before writing.
        img_name: file name stem (a string).
        pred_flag: when true, the stem is suffixed with
            '_<model_type>F<nFrames>'; otherwise the bare stem is used.

    The target folder is <opt.output>/<opt.data_dir>/<file_list stem>_<scale>x
    and is created when missing.
    """
    # Clamp first, then reorder the axes to (row, col, channel).
    pixels = img.squeeze().clamp(0, 1).numpy().transpose(1, 2, 0)

    list_stem = os.path.splitext(opt.file_list)[0]
    folder_name = list_stem + '_' + str(opt.upscale_factor) + 'x'
    save_dir = os.path.join(opt.output, opt.data_dir, folder_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    tag = ('_' + opt.model_type + 'F' + str(opt.nFrames)) if pred_flag else ''
    save_fn = save_dir + '/' + img_name + tag + '.png'

    # Swap the first and last channel for OpenCV, then write uncompressed.
    swapped = cv2.cvtColor(pixels * 255, cv2.COLOR_BGR2RGB)
    cv2.imwrite(save_fn, swapped, [cv2.IMWRITE_PNG_COMPRESSION, 0])

def bgr2ycbcr(img, only_y=True):
    """Convert a BGR image to YCbCr (BGR counterpart of rgb2ycbcr).

    Args:
        img: array whose last axis holds the B, G, R channels. uint8 input
            is read as [0, 255]; any other dtype is read as [0, 1].
        only_y: when true return only the Y channel (last axis dropped),
            otherwise return Y, Cb and Cr along the last axis.

    Returns:
        Array with the dtype of `img`: rounded values for uint8 input,
        values divided by 255 for other dtypes.

    The input array is never modified. The previous version discarded the
    result of `img.astype(...)` and then scaled the caller's array in place.
    """
    in_img_type = img.dtype
    # astype returns a new array, so all scaling below happens on a copy.
    work = img.astype(np.float64)
    if in_img_type != np.uint8:
        work = work * 255.
    # convert
    if only_y:
        rlt = np.dot(work, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(work, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                               [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def PSNR(pred, gt, shave_border=0, mode='rgb'):
    """Return the PSNR (in dB, peak value 255) between two images.

    Args:
        pred: predicted image array, rows x cols x channels.
        gt: reference image array of the same shape.
        shave_border: number of pixels removed from every side before the
            error is measured.
        mode: 'rgb' compares all channels; 'y' first converts both images
            with bgr2ycbcr and compares the single resulting channel. Any
            other value prints a warning and compares the unshaved images.

    Returns:
        100 when the images are identical inside the compared region,
        otherwise 20 * log10(255 / rmse).

    NOTE(review): the peak is fixed at 255, so float inputs are expected on
    a 0-255 scale -- confirm against callers.
    """
    if mode == 'y':
        # Convert copies so the colour conversion can never scale the
        # caller's arrays in place.
        pred = bgr2ycbcr(pred.copy())
        gt = bgr2ycbcr(gt.copy())
    # np.float was removed in NumPy 1.24; np.float64 is the same type.
    pred = pred.astype(np.float64)
    gt = gt.astype(np.float64)
    height, width = pred.shape[:2]

    if mode in ('rgb', 'y'):
        rows = slice(shave_border, height - shave_border)
        cols = slice(shave_border, width - shave_border)
        pred = pred[rows, cols]
        gt = gt[rows, cols]
    else:
        print('Not recognized color mode!')

    imdff = pred - gt
    rmse = math.sqrt(np.mean(imdff ** 2))
    if rmse == 0:
        return 100
    return 20 * math.log10(255.0 / rmse)

def compute_PSNR(imgs1, imgs2):
    """Return the PSNR (peak value 255) between two array-likes.

    Both inputs are converted to float arrays and compared over all their
    elements at once; the result is a single value, 100 when the mean
    squared error is exactly zero.
    """
    reference = np.array(imgs1, dtype=np.float64)
    candidate = np.array(imgs2, dtype=np.float64)
    mean_sq_err = np.mean(np.square(reference - candidate))
    if mean_sq_err != 0:
        peak = 255.0
        return 20 * math.log10(peak / math.sqrt(mean_sq_err))
    return 100
    
def chop_forward(x, neigbor, flow, model, scale, shave=8, min_size=2000, nGPUs=opt.gpus):
    """Run the model on four overlapping quadrants and stitch the outputs.

    Args:
        x: input tensor of size (b, c, h, w).
        neigbor: list of tensors cropped with the same windows as x.
        flow: list of tensors cropped with the same windows as x.
        model: callable invoked as model(x_patch, neigbor_patches, flow_patches).
        scale: factor by which the model enlarges height and width.
        shave: overlap (in input pixels) added to each half-size quadrant.
        min_size: a quadrant is sent to the model once its area
            (h_size * w_size) is below this value; otherwise it is split
            again recursively.
        nGPUs: kept for interface compatibility and passed through the
            recursion; it no longer affects how patches are batched.

    Returns:
        Tensor of size (b, c, scale * h, scale * w) assembled from the four
        quadrant outputs.

    The previous base case stepped through the quadrants in strides of
    nGPUs and chunked each output by nGPUs, so for nGPUs > 1 it skipped
    quadrants and produced a mis-sized patch list. Every quadrant is now
    processed exactly once; for nGPUs == 1 the result is unchanged.
    """
    b, c, h, w = x.size()
    h_half, w_half = h // 2, w // 2
    h_size, w_size = h_half + shave, w_half + shave
    # Quadrant order: top-left, top-right, bottom-left, bottom-right.
    inputlist = [
        [x[:, :, 0:h_size, 0:w_size], [j[:, :, 0:h_size, 0:w_size] for j in neigbor], [j[:, :, 0:h_size, 0:w_size] for j in flow]],
        [x[:, :, 0:h_size, (w - w_size):w], [j[:, :, 0:h_size, (w - w_size):w] for j in neigbor], [j[:, :, 0:h_size, (w - w_size):w] for j in flow]],
        [x[:, :, (h - h_size):h, 0:w_size], [j[:, :, (h - h_size):h, 0:w_size] for j in neigbor], [j[:, :, (h - h_size):h, 0:w_size] for j in flow]],
        [x[:, :, (h - h_size):h, (w - w_size):w], [j[:, :, (h - h_size):h, (w - w_size):w] for j in neigbor], [j[:, :, (h - h_size):h, (w - w_size):w] for j in flow]]]

    if w_size * h_size < min_size:
        outputlist = []
        for patch_x, patch_neigbor, patch_flow in inputlist:
            with torch.no_grad():
                outputlist.append(model(patch_x, patch_neigbor, patch_flow))
    else:
        outputlist = [
            chop_forward(patch[0], patch[1], patch[2], model, scale, shave, min_size, nGPUs) \
            for patch in inputlist]

    # From here on all sizes refer to the upscaled output.
    h, w = scale * h, scale * w
    h_half, w_half = scale * h_half, scale * w_half
    h_size, w_size = scale * h_size, scale * w_size

    with torch.no_grad():
        output = Variable(x.data.new(b, c, h, w))
    # Each quadrant contributes only its non-overlapping part: the left/top
    # patches from their start, the right/bottom patches from their end.
    output[:, :, 0:h_half, 0:w_half] \
        = outputlist[0][:, :, 0:h_half, 0:w_half]
    output[:, :, 0:h_half, w_half:w] \
        = outputlist[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
    output[:, :, h_half:h, 0:w_half] \
        = outputlist[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
    output[:, :, h_half:h, w_half:w] \
        = outputlist[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

    return output

# Script entry point: run the evaluation loop once over the test set.
# Note that eval() here is the function defined in this file, which shadows
# Python's builtin of the same name.
eval()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants