In [27]:
import os
import cv2
import numpy as np
from PIL import Image
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore import context
import mindspore.ops as ops
from src.model.generator import Generator
from src.dataset.create_loader import create_test_dataloader
import time
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

In [28]:
def bicubic(path, scale=4, out_dir='bicubic'):
    '''
        Downsample every image found directly under `path` with bicubic
        interpolation and write the result, under the same file name, to
        `out_dir`.

        @param path    : raw set5
        @param scale   : factor each side is divided by (default 4)
        @param out_dir : directory receiving the downsampled images; it is
                         created when missing (default 'bicubic')
        @raise IOError : when OpenCV reports that an image could not be written
    '''
    print("======downsample img")
    # cv2.imwrite does not create directories; it just reports failure.
    os.makedirs(out_dir, exist_ok=True)
    for name in sorted(os.listdir(path)):
        imgPath = os.path.join(path, name)
        img = cv2.imread(imgPath)
        if img is None:
            # cv2.imread returns None instead of raising for anything it
            # cannot decode (sub-directories, non-image files, ...).
            print("skip '%s': not a readable image" % imgPath)
            continue
        height, width = img.shape[:2]
        # cv2.resize expects the target size as (width, height).
        target_size = (int(width / scale), int(height / scale))
        bicubic_img = cv2.resize(img, target_size, interpolation=cv2.INTER_CUBIC)
        # basename strips the directory part with whatever separator the
        # platform uses; splitting on '\\' only did so on Windows.
        out_file = os.path.join(out_dir, os.path.basename(imgPath))
        if not cv2.imwrite(out_file, bicubic_img):
            raise IOError("could not write '%s'" % out_file)

    print("img saved in '%s' " % out_dir)

In [29]:
def infer(path):
    '''
        Run the generator on every sample of the test loader built from
        `path` and save each output as 'set5_hr/<index>.jpg'.

        The checkpoint is read from 'ckpt/G_model_1000.ckpt'. Only the time
        spent inside the generator call is accumulated for the timing report.

        @param path : downsample set5
    '''
    print('======get hr')
    out_dir = 'set5_hr'
    # random seed
    set_seed(1)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_id=0, save_graphs=False)
    # dataloader
    # NOTE(review): the leading 1 is presumably the batch size - confirm
    # against create_test_dataloader.
    test_ds = create_test_dataloader(1, path, inference=True)
    data_size = test_ds.get_dataset_size()
    test_data_loader = test_ds.create_dict_iterator()
    # generator
    generator = Generator(4)
    params = load_checkpoint('ckpt/G_model_1000.ckpt')
    print("======load checkpoint")
    load_param_into_net(generator, params)
    # Summing over axis 0 without keeping it removes the leading axis; this
    # only equals a plain squeeze if that axis has length 1 (assumed here).
    op = ops.ReduceSum(keep_dims=False)
    print("=======starting test=====")
    # Image.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    time_total = 0
    # infer
    for i, data in enumerate(test_data_loader):
        time_begin = time.time()
        lr = data['LR']
        output = generator(lr)
        time_total += time.time() - time_begin
        output = op(output, 0).asnumpy()
        # Clamp to [-1, 1], rescale to [0, 1], then move the first axis last
        # (presumably CHW -> HWC for PIL - TODO confirm generator layout).
        output = np.clip(output, -1.0, 1.0)
        output = ((output + 1.0) / 2.0).transpose(1, 2, 0)
        result = Image.fromarray((output * 255.0).astype(np.uint8))
        # save the output image
        result.save(f"{out_dir}/{i}.jpg")
    # An empty loader would otherwise raise ZeroDivisionError here.
    per_image_ms = (time_total / data_size) * 1000 if data_size else 0.0
    print("Total %d images need %.0fms, per image needs %.0fms." % (data_size, time_total * 1000, \
        per_image_ms))
    print("Images saved in 'set5_hr'")
    print("Inference End.")


In [30]:
def calculatePsnrAndSsim(lr_path, hr_path):
    '''
        calculate psnr and ssim

        Files of both directories are listed in sorted order and compared
        pairwise by position (first with first, ...), so the two listings
        must sort into matching order.

        @param lr_path: raw set5
        @param hr_path: hr set5
        @return: tuple (average psnr, average ssim)
        @raise ValueError: when `lr_path` is empty, when `hr_path` holds fewer
                           files than `lr_path`, or when a file cannot be
                           decoded as an image
    '''
    print('calculating avg psnr and avg ssim')

    # file path
    lr = [os.path.join(lr_path, x) for x in sorted(os.listdir(lr_path))]
    hr = [os.path.join(hr_path, x) for x in sorted(os.listdir(hr_path))]
    if not lr:
        # Would otherwise end in a ZeroDivisionError when averaging.
        raise ValueError("no images found in '%s'" % lr_path)
    if len(hr) < len(lr):
        # Would otherwise end in an IndexError halfway through the loop.
        raise ValueError("'%s' holds %d files but '%s' only %d"
                         % (lr_path, len(lr), hr_path, len(hr)))
    psnr_all, ssim_all = 0, 0

    # calculate
    # zip stops after len(lr) pairs, so surplus files in hr_path are ignored.
    for lr_file, hr_file in zip(lr, hr):
        lr_img = cv2.imread(lr_file)
        hr_img = cv2.imread(hr_file)
        # cv2.imread returns None instead of raising on undecodable files.
        # Skipping such a file would shift the positional pairing, so fail.
        if lr_img is None:
            raise ValueError("cannot read image '%s'" % lr_file)
        if hr_img is None:
            raise ValueError("cannot read image '%s'" % hr_file)
        # calculate psnr
        psnr_all += peak_signal_noise_ratio(lr_img, hr_img, data_range=255)
        # calculate ssim
        # NOTE(review): both the legacy `multichannel` flag and `channel_axis`
        # are passed, presumably so the call works across scikit-image
        # versions - confirm against the installed version.
        ssim_all += structural_similarity(lr_img, hr_img, data_range=255, multichannel=True, channel_axis=2)

    avg_psnr = psnr_all / len(lr)
    avg_ssim = ssim_all / len(lr)
    print(f'avg psnr: {avg_psnr:.2f}')
    print(f'avg ssim: {avg_ssim:.2f}')
    return avg_psnr, avg_ssim


In [31]:
if __name__ == '__main__':
    # Directories used by the pipeline: the raw Set5 images, the generator
    # outputs, and the downsampled images fed to the generator.
    lr_path, hr_path, bicubic_path = 'set5_lr', 'set5_hr', 'bicubic'

    # Step 1: write downsampled copies of the raw images.
    bicubic(lr_path)

    # Step 2: run the generator on the downsampled copies.
    infer(bicubic_path)

    # Step 3: score the generator outputs against the raw images.
    calculatePsnrAndSsim(lr_path, hr_path)



img saved in 'bicubic' 




Total 5 images need 5280ms, per image needs 1056ms.
Images saved in 'set5_hr'
Inference End.
calculating avg psnr and avg ssim




avg psnr: 22.78
avg ssim: 0.68
