<a href="https://colab.research.google.com/github/Seowon-Ji/Multi-target/blob/master/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
import numpy as np
import os
import math
import argparse
import random
import model
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from dataloader import goprodataset
import time
from tqdm import tqdm
from torch import autograd

from utils import VisdomLinePlotter
from skimage.measure import compare_psnr, compare_ssim
from vgg import Vgg19, normalize_vgg

from losses import init_loss
import losses
from options.train_options import TrainOptions

def _str2bool(value):
    """Convert a command-line string such as ``"True"``/``"no"``/``"1"`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    for every non-empty string, so ``-save False`` used to turn saving *on*.
    The empty string still maps to ``False``, matching the old ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


parser = argparse.ArgumentParser(description="Deep Multi-Patch Hierarchical Network")
parser.add_argument("-e", "--epochs", type=int, default=1200)
parser.add_argument("-se", "--start_epoch", type=int, default=0)
parser.add_argument("-b", "--batchsize", type=int, default=8)
parser.add_argument("-s", "--imagesize", type=int, default=256)
parser.add_argument("-l", "--learning_rate", type=float, default=0.0001)
parser.add_argument("-g", "--gpu", type=int, default=0)
parser.add_argument("-save", "--save_file", type=_str2bool, default=False)
parser.add_argument("-visdom", "--visdom", type=_str2bool, default=False)
args = parser.parse_args()

# Hyper Parameters
METHOD = "DMPHN_1_2_4_8_l1_gan"
LEARNING_RATE = args.learning_rate
EPOCHS = args.epochs
GPU = args.gpu
BATCH_SIZE = args.batchsize
IMAGE_SIZE = args.imagesize
SAVE_FILE = args.save_file
VISDOM = args.visdom

# Options object handed to losses.init_loss() in main().
# Parsed exactly once: the previous version parsed a second time right after
# applying these overrides, silently discarding save_latest_freq / print_freq.
# NOTE(review): both `parser` above and TrainOptions appear to read sys.argv;
# confirm each accepts the other's flags.
opt = TrainOptions().parse()
# Raw string so the backslashes are path separators, not escape sequences.
opt.dataroot = r'D:\Photos\TrainingData\BlurredSharp\combined'
opt.learn_residual = True
opt.resize_or_crop = "crop"
opt.fineSize = 256
opt.gan_type = "gan"
# opt.which_model_netG = "unet_256"

# default = 5000
opt.save_latest_freq = 100

# default = 100
opt.print_freq = 20

def save_deblur_images(images, iteration, epoch):
    """Save a deblurred image batch under the per-epoch checkpoint folder.

    The file is ``./checkpoints/<METHOD>/epoch<epoch>/Iter_<iteration>_deblur.png``;
    the ``epoch<epoch>`` folder is created first if it is missing.
    """
    epoch_dir = './checkpoints/{}/epoch{}'.format(METHOD, epoch)
    if not os.path.exists(epoch_dir):
        os.makedirs(epoch_dir)
    target = '{}/Iter_{}_deblur.png'.format(epoch_dir, iteration)
    torchvision.utils.save_image(images, target)


def weight_init(m):
    """Initialise one module's parameters in place; meant for ``nn.Module.apply``.

    Dispatch is on the module's class name:

    * name contains ``Conv``: weight ~ N(0, std) with
      ``std = 0.5 * sqrt(2 / n)``, ``n = kernel_size[0] * kernel_size[1] * out_channels``;
      bias zeroed.
    * name contains ``BatchNorm``: weight filled with 1, bias zeroed.
    * name contains ``Linear``: weight ~ N(0, 0.01), bias filled with 1.

    Parameters that are ``None`` (``bias=False`` layers, ``affine=False`` norms)
    are skipped. A module whose name matches ``Conv`` but owns no ``weight``
    (e.g. a container such as ``ConvBlock``) is left alone; ``apply`` reaches
    its child layers separately.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1:
        if weight is None:
            return
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        weight.data.normal_(0.0, 0.5 * math.sqrt(2. / n))
        if bias is not None:
            bias.data.zero_()
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            weight.data.fill_(1)
        if bias is not None:
            bias.data.zero_()
    elif classname.find('Linear') != -1:
        if weight is not None:
            weight.data.normal_(0, 0.01)
        if bias is not None:
            # Fill in place so the parameter keeps its device and dtype; the old
            # code swapped in a brand-new CPU float32 tensor.
            bias.data.fill_(1)


def _save_checkpoint(netG, netD, epoch, directory):
    """Save generator/discriminator weights and the epoch index into *directory*.

    Writes ``netG.pkl``, ``netD.pkl`` and ``epoch.pkl`` (a ``{'epoch': epoch}``
    dict), creating *directory* and any missing parents first.
    """
    os.makedirs(directory, exist_ok=True)
    torch.save(netG.state_dict(), directory + "/netG.pkl")
    torch.save(netD.state_dict(), directory + "/netD.pkl")
    torch.save({'epoch': epoch}, directory + "/epoch.pkl")


def _validate(netG, epoch):
    """Run *netG* over the validation split and return ``(mean_psnr, mean_ssim)``.

    Blurred inputs get the same -0.5 shift used in training; outputs are
    shifted back by +0.5 before being scored against the sharp images.
    If SAVE_FILE is set, each output batch is also written to disk.
    """
    print("valid...")
    valid_dataset = goprodataset(
        blur_dir='../deblur_data/gopro_reset/valid/blur',
        sharp_dir='../deblur_data/gopro_reset/valid/sharp'
    )
    valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
    test_time = 0
    num_batches = 0
    total_psnr = []
    total_ssim = []

    netG.eval()
    with torch.no_grad():
        for iteration, images in enumerate(tqdm(valid_dataloader)):
            start = time.time()
            images_lv1 = (images['blur_image'] - 0.5).cuda(GPU)
            deblur_image = netG(images_lv1)
            stop = time.time()
            test_time += stop - start
            num_batches += 1

            # Undo the -0.5 input shift before saving / scoring.
            deblur_image = deblur_image + 0.5
            if SAVE_FILE:
                save_deblur_images(deblur_image, iteration, epoch)

            out_np = deblur_image.cpu().numpy()
            rgb_images_np = images['sharp_image'].numpy()
            for i in range(out_np.shape[0]):
                psnr = compare_psnr(im_true=rgb_images_np[i], im_test=out_np[i])
                # Per-channel SSIM averaged over the 3 colour channels.
                # NOTE(review): without data_range, skimage infers it from the
                # dtype; for float inputs this may assume a [-1, 1] range.
                # Consider data_range=1. Left unchanged so numbers stay
                # comparable with earlier logs.
                ssim = (compare_ssim(X=rgb_images_np[i][0], Y=out_np[i][0]) +
                        compare_ssim(X=rgb_images_np[i][1], Y=out_np[i][1]) +
                        compare_ssim(X=rgb_images_np[i][2], Y=out_np[i][2])) / 3
                total_psnr.append(psnr)
                total_ssim.append(ssim)

    if num_batches:
        print('RunTime:%.4f' % (stop - start), '  Average Runtime:%.4f' % (test_time / num_batches))

    return np.mean(total_psnr), np.mean(total_ssim)


def main():
    """Train the deblurring GAN (netG + netD), validating every 10 epochs.

    Resuming: if ``./checkpoints/<METHOD>/`` holds netG.pkl / netD.pkl /
    epoch.pkl, the weights are loaded into their respective networks. Training
    continues at the epoch after the saved one, unless ``--start_epoch`` is
    given explicitly (non-zero), in which case that value wins. The learning
    rate for any epoch is ``LEARNING_RATE * 0.1 ** (epoch // 550)``, applied by
    StepLR.
    """
    # Make GPU the current CUDA device so tensors created without an explicit
    # device (e.g. inside the loss objects) land on the same GPU as the data.
    torch.cuda.set_device(GPU)

    ckpt_dir = './checkpoints/' + METHOD
    os.makedirs(ckpt_dir, exist_ok=True)

    if VISDOM:
        plotter_valid = VisdomLinePlotter(env_name='Plots')

    first_epoch = args.start_epoch
    epoch_file = ckpt_dir + "/epoch.pkl"
    if os.path.exists(epoch_file):
        checkpoint = torch.load(epoch_file)
        saved_epoch = checkpoint['epoch']
        print("loaded epoch %d" % saved_epoch)
        if args.start_epoch == 0:
            # epoch.pkl records the last *completed* epoch.
            first_epoch = saved_epoch + 1

    print("init data folders")

    netG = model.Generator()
    # netD = model.Discriminator_deblurganv1()
    netD = model.NLayerDiscriminator()
    netG.apply(weight_init).cuda(GPU)
    netD.apply(weight_init).cuda(GPU)
    criticUpdates = 1

    step_size = 550
    Tensor = torch.cuda.FloatTensor
    # define loss functions
    discLoss, contentLoss = init_loss(opt, Tensor)

    # Build the optimizers with the *undecayed* rate: StepLR.step(epoch) below
    # applies the 0.1 ** (epoch // step_size) decay itself. Starting from a
    # pre-decayed rate would apply the decay twice after a resume.
    netG_optim = torch.optim.Adam(netG.parameters(), lr=LEARNING_RATE)
    netG_scheduler = StepLR(netG_optim, step_size=step_size, gamma=0.1)
    netD_optim = torch.optim.Adam(netD.parameters(), lr=LEARNING_RATE)
    netD_scheduler = StepLR(netD_optim, step_size=step_size, gamma=0.1)

    netG_file = ckpt_dir + "/netG.pkl"
    if os.path.exists(netG_file):
        netG.load_state_dict(torch.load(netG_file))
        print("load netG success")
    netD_file = ckpt_dir + "/netD.pkl"
    if os.path.exists(netD_file):
        netD.load_state_dict(torch.load(netD_file))
        print("load netD success")

    for epoch in range(first_epoch, EPOCHS):
        netG.train()
        netD.train()

        netG_scheduler.step(epoch)
        netD_scheduler.step(epoch)

        print("Training...")

        train_dataset = goprodataset(
            blur_dir='../deblur_data/gopro_reset/train/blur',
            sharp_dir='../deblur_data/gopro_reset/train/sharp',
            crop=True,
            crop_size=IMAGE_SIZE,
        )

        train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        for param_group in netG_optim.param_groups:
            print("netG_current lr: %f" % param_group['lr'])
        for param_group in netD_optim.param_groups:
            print("netD_current lr: %f" % param_group['lr'])

        start = time.time()
        for iteration, images in enumerate(tqdm(train_dataloader)):
            # Images are shifted by -0.5 before entering the networks.
            gt = (images['sharp_image'] - 0.5).cuda(GPU)
            images_lv1 = (images['blur_image'] - 0.5).cuda(GPU)

            deblur_image = netG(images_lv1)

            # (1) Update D network
            for _ in range(criticUpdates):
                netD_optim.zero_grad()
                loss_D = discLoss.get_loss(netD, images_lv1, deblur_image, gt)
                # retain_graph presumably keeps deblur_image's graph alive for
                # the G update below -- confirm against losses.DiscLoss.
                loss_D.backward(retain_graph=True)
                netD_optim.step()

            # (2) Update G network: weighted adversarial term + content term
            adv_weight = 0.01
            netG_optim.zero_grad()
            loss_G_GAN = discLoss.get_g_loss(netD, images_lv1, deblur_image)
            loss_G_Content = contentLoss.get_loss(deblur_image, gt)
            loss_G = adv_weight * loss_G_GAN + loss_G_Content
            loss_G.backward()
            netG_optim.step()

            if (iteration + 1) % 10 == 0:
                stop = time.time()
                print("epoch:", epoch, "iteration:", iteration + 1, "loss:%.4f" % loss_G.item(),
                      "adversarial_loss:%.4f" % loss_G_GAN.item(), "deblur_dis:%.4f" % loss_D.item(),
                      'time:%.4f' % (stop - start))
                start = time.time()

        if epoch % 10 == 0:
            os.makedirs(ckpt_dir + '/epoch' + str(epoch), exist_ok=True)
            _save_checkpoint(netG, netD, epoch, ckpt_dir + '/parameter' + '/epoch' + str(epoch))

            total_psnr, total_ssim = _validate(netG, epoch)

            with open('valid_0006_psnr.txt', 'a') as f:
                f.write("epoch: %d , PSNR: %.4f , SSIM: %.4f\n" % (epoch, total_psnr, total_ssim))

            # Print psnr, ssim
            message = '\t {}: {:.2f}\t {}: {:.4f}'.format('psnr', total_psnr, 'ssim', total_ssim)
            print(message)
            if VISDOM:
                plotter_valid.plot('loss', 'psnr', 'validation', epoch, total_psnr)

        # Rolling "latest" checkpoint used for resuming.
        _save_checkpoint(netG, netD, epoch, ckpt_dir)


if __name__ == '__main__':
    main()





