<a href="https://colab.research.google.com/github/Seowon-Ji/Multi-target/blob/master/test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
import numpy as np
import os
import math
import argparse
import random
import model
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from dataloader import goprodataset
import time
from tqdm import tqdm

from skimage.measure import compare_psnr, compare_ssim
from HED import HED

parser = argparse.ArgumentParser(description="Deep Multi-Patch Hierarchical Network")
parser.add_argument("-e", "--epochs", type=int, default=4000)
# main() overwrites this with the epoch stored in epoch.pkl when that file exists.
parser.add_argument("-se", "--start_epoch", type=int, default=0)
parser.add_argument("-b", "--batchsize", type=int, default=8)
parser.add_argument("-s", "--imagesize", type=int, default=256)
parser.add_argument("-l", "--learning_rate", type=float, default=0.0001)
# CUDA device index handed to .cuda(GPU).
parser.add_argument("-g", "--gpu", type=int, default=0)
# parse_known_args (instead of parse_args) tolerates extra argv entries such as
# the "-f <kernel.json>" that Jupyter/Colab kernels pass; parse_args would abort
# with "unrecognized arguments" when this script runs as a notebook.
args, _unknown_args = parser.parse_known_args()

# Hyper Parameters
METHOD = "DMPHN_1_2_4_8_l1_sobel_005"  # sub-directory name under ./checkpoints/
LEARNING_RATE = args.learning_rate
EPOCHS = args.epochs
GPU = args.gpu
BATCH_SIZE = args.batchsize
IMAGE_SIZE = args.imagesize


def save_deblur_images(images, iteration, epoch):
    """Save deblurred images for one test iteration.

    The file is written to
    ``./checkpoints/<METHOD>/test/epoch<epoch>/Iter_<iteration>_deblur.png``;
    the directory is created first if it does not exist.

    Args:
        images: whatever ``torchvision.utils.save_image`` accepts as its
            first argument (a tensor or list of tensors).
        iteration: iteration index embedded in the file name.
        epoch: epoch number embedded in the directory name.
    """
    out_dir = './checkpoints/' + METHOD + "/test/epoch" + str(epoch)
    # exist_ok=True replaces the exists()-then-makedirs() check, which could
    # race with another process creating the same directory.
    os.makedirs(out_dir, exist_ok=True)
    filename = out_dir + "/" + "Iter_" + str(iteration) + "_deblur.png"
    torchvision.utils.save_image(images, filename)

def save_images(images, name, out_dir='./test_out/'):
    """Save ``images`` as ``<out_dir>/<name>`` using torchvision.

    The output directory is created if it is missing. Otherwise the first
    write into a non-existent ``./test_out/`` would fail with
    FileNotFoundError.

    Args:
        images: whatever ``torchvision.utils.save_image`` accepts as its
            first argument (a tensor or list of tensors).
        name: file name, including extension, written inside ``out_dir``.
        out_dir: destination directory; defaults to ``./test_out/``.
    """
    os.makedirs(out_dir, exist_ok=True)
    filename = os.path.join(out_dir, name)
    torchvision.utils.save_image(images, filename)

def weight_init(m):
    """Initialise the parameters of one module in place.

    Intended for ``nn.Module.apply``, which calls it on every sub-module.
    Dispatch is by substring of the class name:

    * contains ``Conv``: weight ~ N(0, (0.5 * sqrt(2 / n))^2), where
      n = out_channels * product(kernel_size); bias, if present, is zeroed.
    * contains ``BatchNorm``: weight (scale), if present, is set to 1;
      bias, if present, is zeroed.
    * contains ``Linear``: weight ~ N(0, 0.01^2); bias, if present, is set to 1.

    Any other module is left untouched.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        # Multiply over every kernel dimension so Conv1d/Conv3d work too.
        # For Conv2d this equals kernel_size[0] * kernel_size[1] * out_channels.
        n = m.out_channels
        for k in m.kernel_size:
            n *= k
        m.weight.data.normal_(0.0, 0.5 * math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm' in classname:
        # BatchNorm layers built with affine=False have no weight/bias.
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'Linear' in classname:
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
            # In-place fill keeps the bias on its current device and dtype.
            # Rebinding .data to a fresh torch.ones(...) would give a CPU
            # float32 tensor.
            m.bias.data.fill_(1)


def _load_if_exists(net, name):
    """Load ``./checkpoints/<METHOD>/<name>.pkl`` into ``net`` if the file exists.

    Args:
        net: module whose ``load_state_dict`` receives the loaded object.
        name: checkpoint stem, e.g. ``"encoder_lv1"``; also used in the log line.
    """
    path = './checkpoints/' + METHOD + '/' + name + '.pkl'
    if os.path.exists(path):
        net.load_state_dict(torch.load(path))
        print("load %s success" % name)


def _dmphn_forward(images_lv1, encoders, decoders):
    """Run the 1-2-4-8 multi-patch hierarchy and return decoder_lv1's output.

    Args:
        images_lv1: input batch; dim 2 is treated as height and dim 3 as
            width by the patch slicing below.
        encoders: ``(encoder_lv1, encoder_lv2, encoder_lv3, encoder_lv4)``.
        decoders: ``(decoder_lv1, decoder_lv2, decoder_lv3, decoder_lv4)``.

    Returns:
        The output of ``decoder_lv1``.
    """
    encoder_lv1, encoder_lv2, encoder_lv3, encoder_lv4 = encoders
    decoder_lv1, decoder_lv2, decoder_lv3, decoder_lv4 = decoders
    H = images_lv1.size(2)
    W = images_lv1.size(3)

    # Level 2: top / bottom halves.
    images_lv2_1 = images_lv1[:, :, 0:H // 2, :]
    images_lv2_2 = images_lv1[:, :, H // 2:H, :]
    # Level 3: left / right split of each half.
    images_lv3_1 = images_lv2_1[:, :, :, 0:W // 2]
    images_lv3_2 = images_lv2_1[:, :, :, W // 2:W]
    images_lv3_3 = images_lv2_2[:, :, :, 0:W // 2]
    images_lv3_4 = images_lv2_2[:, :, :, W // 2:W]
    # Level 4: top / bottom split of each quarter, in order lv4_1 .. lv4_8.
    images_lv4 = []
    for quarter in (images_lv3_1, images_lv3_2, images_lv3_3, images_lv3_4):
        images_lv4.append(quarter[:, :, 0:H // 4, :])
        images_lv4.append(quarter[:, :, H // 4:H // 2, :])

    feature_lv4 = [encoder_lv4(patch) for patch in images_lv4]
    feature_lv4_top_left = torch.cat((feature_lv4[0], feature_lv4[1]), 2)
    feature_lv4_top_right = torch.cat((feature_lv4[2], feature_lv4[3]), 2)
    feature_lv4_bot_left = torch.cat((feature_lv4[4], feature_lv4[5]), 2)
    feature_lv4_bot_right = torch.cat((feature_lv4[6], feature_lv4[7]), 2)
    feature_lv4_top = torch.cat((feature_lv4_top_left, feature_lv4_top_right), 3)
    feature_lv4_bot = torch.cat((feature_lv4_bot_left, feature_lv4_bot_right), 3)

    residual_lv4_top_left = decoder_lv4(feature_lv4_top_left)
    residual_lv4_top_right = decoder_lv4(feature_lv4_top_right)
    residual_lv4_bot_left = decoder_lv4(feature_lv4_bot_left)
    residual_lv4_bot_right = decoder_lv4(feature_lv4_bot_right)

    feature_lv3_1 = encoder_lv3(images_lv3_1 + residual_lv4_top_left)
    feature_lv3_2 = encoder_lv3(images_lv3_2 + residual_lv4_top_right)
    feature_lv3_3 = encoder_lv3(images_lv3_3 + residual_lv4_bot_left)
    feature_lv3_4 = encoder_lv3(images_lv3_4 + residual_lv4_bot_right)

    feature_lv3_top = torch.cat((feature_lv3_1, feature_lv3_2), 3) + feature_lv4_top
    feature_lv3_bot = torch.cat((feature_lv3_3, feature_lv3_4), 3) + feature_lv4_bot
    residual_lv3_top = decoder_lv3(feature_lv3_top)
    residual_lv3_bot = decoder_lv3(feature_lv3_bot)

    feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
    feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
    feature_lv2 = torch.cat((feature_lv2_1, feature_lv2_2), 2) + torch.cat(
        (feature_lv3_top, feature_lv3_bot), 2)
    residual_lv2 = decoder_lv2(feature_lv2)

    feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
    return decoder_lv1(feature_lv1)


def main():
    """Evaluate the DMPHN 1-2-4-8 model on the GoPro test split.

    Steps:
      1. Build four encoder/decoder levels, initialise them with
         ``weight_init``, move them to ``GPU`` and switch them to eval mode.
      2. Load every ``./checkpoints/<METHOD>/{encoder,decoder}_lv{1..4}.pkl``
         that exists. Take the epoch number from ``epoch.pkl`` when present,
         otherwise from ``--start_epoch``.
      3. Deblur every image of the test loader and save each result with
         ``save_images`` under the name taken from the blur directory listing.
      4. Print the runtime, the mean PSNR and the mean SSIM, and append the
         PSNR/SSIM line to ``test_sobel_005_psnr.txt``.

    Raises:
        RuntimeError: if the test loader yields no batches.
    """
    print("init data folders")

    # Built and initialised in the original order (encoders lv1..lv4, then
    # decoders lv1..lv4), so random initialisation consumes the RNG identically.
    encoders = (model.Encoder(), model.Encoder(), model.Encoder(), model.Encoder())
    decoders = (model.Decoder(), model.Decoder(), model.Decoder(), model.Decoder())
    for net in encoders + decoders:
        net.apply(weight_init).cuda(GPU)

    for level, net in enumerate(encoders, 1):
        _load_if_exists(net, "encoder_lv%d" % level)
    for level, net in enumerate(decoders, 1):
        _load_if_exists(net, "decoder_lv%d" % level)

    # Inference only: disable training-time behaviour (dropout, batch-norm
    # statistic updates) in case the model contains any such layers.
    for net in encoders + decoders:
        net.eval()

    epoch_path = './checkpoints/' + METHOD + "/epoch.pkl"
    if os.path.exists(epoch_path):
        checkpoint = torch.load(epoch_path)
        args.start_epoch = checkpoint['epoch']
        print("load epoch %d success" % args.start_epoch)
    # Without epoch.pkl, fall back to --start_epoch instead of a NameError.
    epoch = args.start_epoch

    # makedirs also creates ./checkpoints/<METHOD> and any missing parents.
    checkpoint_dir = './checkpoints/' + METHOD
    os.makedirs(checkpoint_dir + '/epoch' + str(epoch), exist_ok=True)
    os.makedirs(checkpoint_dir + '/parameter' + '/epoch' + str(epoch), exist_ok=True)

    print("test...")
    blur_dir = '../deblur_data/gopro_reset/test/blur'
    test_dataset = goprodataset(
        blur_dir=blur_dir,
        sharp_dir='../deblur_data/gopro_reset/test/sharp'
    )
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    # NOTE(review): output file names come from a separate os.listdir() call.
    # This assumes goprodataset enumerates blur_dir in the same order -- confirm.
    images_name = os.listdir(blur_dir)

    test_time = 0
    last_runtime = 0.0
    num_batches = 0
    total_psnr = []
    total_ssim = []
    with torch.no_grad():
        for iteration, images in enumerate(tqdm(test_dataloader)):
            start = time.time()
            images_lv1 = (images['blur_image'] - 0.5).cuda(GPU)
            deblur_image = _dmphn_forward(images_lv1, encoders, decoders)
            # CUDA runs kernels asynchronously. Wait for them so the timing
            # covers the forward pass itself, not only the kernel launches.
            torch.cuda.synchronize(GPU)
            stop = time.time()
            last_runtime = stop - start
            test_time += last_runtime
            num_batches += 1

            restored = deblur_image + 0.5
            save_images(restored, images_name[iteration])

            out_np = restored.cpu().numpy()
            rgb_images_np = images['sharp_image'].numpy()
            for i in range(out_np.shape[0]):
                psnr = compare_psnr(im_true=rgb_images_np[i], im_test=out_np[i])
                # NOTE(review): no data_range is passed, so skimage infers it
                # from the float dtype (believed to be 2.0 for SSIM rather than
                # 1.0). Confirm before comparing these SSIM values elsewhere.
                ssim = (compare_ssim(X=rgb_images_np[i][0], Y=out_np[i][0]) +
                        compare_ssim(X=rgb_images_np[i][1], Y=out_np[i][1]) +
                        compare_ssim(X=rgb_images_np[i][2], Y=out_np[i][2])) / 3
                total_psnr.append(psnr)
                total_ssim.append(ssim)

    if num_batches == 0:
        raise RuntimeError("no test images found in " + blur_dir)

    print('RunTime:%.4f' % last_runtime, '  Average Runtime:%.4f' % (test_time / num_batches))

    # average
    total_psnr = np.mean(total_psnr)
    total_ssim = np.mean(total_ssim)

    with open('test_sobel_005_psnr.txt', 'a') as f:
        f.write("epoch: %d , PSNR: %.4f , SSIM: %.4f \n" % (epoch, total_psnr, total_ssim))

    # Print psnr, ssim
    message = '\t {}: {:.2f}\t {}: {:.4f}'.format('psnr', total_psnr, 'ssim', total_ssim)
    print(message)
    print("DONE")




if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()





