In [1]:
from math import log10
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import pandas as pd
import os 

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader

from super_resolution_data_loader_GAN import *
from pytorch_ssim import *

from SRGAN_model import Generator, Discriminator
from loss import GeneratorLoss

torch.manual_seed(1)
device = torch.device("cuda")

In [2]:
# SRGAN parameters

batch_size = 32
num_epochs = 10
lr = 0.01
threads = 4
upscale_factor = 4

In [3]:
#img_path_low = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-56/train'
#img_path_ref = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-224/train'

img_path_low = '../dataset/300W-3D-crap-56/train'
img_path_ref = '../dataset/300W-3D-low-res-224/train'

train_set = DatasetSuperRes(img_path_low, img_path_ref)
training_data_loader = DataLoader(dataset=train_set, num_workers=threads, batch_size=batch_size, shuffle=True)

In [4]:
netG = Generator(upscale_factor).to(device)
print('# Generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator().to(device)
print('# Discriminator parameters:', sum(param.numel() for param in netD.parameters()))
generator_criterion = GeneratorLoss().to(device)

# Generator parameters: 734219
# Discriminator parameters: 5215425


In [5]:
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())
    
results = {'d_loss': [], 'g_loss': [], 'd_score': [],
           'g_score': [], 'psnr': [], 'ssim': []}

In [6]:
def train(epoch):
    
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 
                       'd_score': 0, 'g_score': 0}

    netG.train()
    netD.train()
    
    for data, target in training_data_loader:
        #train_bar = tqdm(training_data_loader)
        g_update_first = True
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        ############################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        real_img = Variable(target).to(device)
        z = Variable(data).to(device)
        
        fake_img = netG(z)

        netD.zero_grad()
        
        real_out = netD(real_img).mean()
        fake_out = netD(fake_img).mean()
        
        d_loss = 1 - real_out + fake_out
        d_loss.backward(retain_graph=True)
        
        optimizerD.step()

        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ###########################
        netG.zero_grad()
        
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()

        fake_img = netG(z)
        fake_out = netD(fake_img).mean()


        optimizerG.step()

        # Loss for current batch before optimization 

        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.item() * batch_size
        running_results['g_score'] += fake_out.item() * batch_size

    #train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
    #        epoch, num_epochs, running_results['d_loss'] / running_results['batch_sizes'],
    #        running_results['g_loss'] / running_results['batch_sizes'],
    #        running_results['d_score'] / running_results['batch_sizes'],
    #        running_results['g_score'] / running_results['batch_sizes']))

    print('[{}/{}] Loss_D: {} Loss_G: {} D(x): {} D(G(z)): {}'.format(
            epoch, num_epochs, 
            running_results['d_loss'] / running_results['batch_sizes'],
            running_results['g_loss'] / running_results['batch_sizes'],
            running_results['d_score'] / running_results['batch_sizes'],
            running_results['g_score'] / running_results['batch_sizes']))

    netG.eval()

    batch_mse = ((fake_img - real_img) ** 2).data.mean()
    batch_ssim = ssim(fake_img, real_img).item()
    batch_psnr = 10 * log10(1 /batch_mse)

    out_path = 'SRGAN_results/'

    if not os.path.exists(out_path):
        os.makedirs(out_path)    

    # Save loss\scores\psnr\ssim
    results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
    results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
    results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
    results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
    results['psnr'].append(batch_psnr)
    results['ssim'].append(batch_ssim)

    if epoch % 5 == 0:

        # Save model parameters
        torch.save(netG, 'netG_epoch_%d_%d.pth' % (upscale_factor, epoch))
        #torch.save(netD, 'netD_epoch_%d_%d.pth' % (upscale_factor, epoch))

        data_frame = pd.DataFrame(
            data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                  'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
            index=range(1, epoch + 1))
        data_frame.to_csv(out_path + 'SRGAN_x' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')

In [7]:
for epoch in range(1, num_epochs + 1):
    train(epoch)

[1/10] Loss_D: 0.9606085551169611 Loss_G: 0.027038411163995343 D(x): 0.25200491034936523 D(G(z)): 0.1966510102080722
[2/10] Loss_D: 1.0030325412750245 Loss_G: 0.011348660630324194 D(x): 0.19897432315734126 D(G(z)): 0.19661091198844294
[3/10] Loss_D: 1.0007452163388653 Loss_G: 0.009760453354927801 D(x): 0.42396586933443625 D(G(z)): 0.42544204235076905
[4/10] Loss_D: 1.000786513051679 Loss_G: 0.008780144039661654 D(x): 0.7122460011513002 D(G(z)): 0.7123295261782985
[5/10] Loss_D: 1.0027116019495073 Loss_G: 0.00774245516126675 D(x): 0.7360384329672782 D(G(z)): 0.7347663675585101
[6/10] Loss_D: 1.0048047944038145 Loss_G: 0.007435147737062746 D(x): 0.6086329825462834 D(G(z)): 0.6032832504472425
[7/10] Loss_D: 0.9987230831576932 Loss_G: 0.007897250700381494 D(x): 0.12551476273895992 D(G(z)): 0.12414137604180724
[8/10] Loss_D: 1.0030247536013204 Loss_G: 0.007360470990740484 D(x): 0.29405972267350844 D(G(z)): 0.2989126700355161
[9/10] Loss_D: 1.0007373746748893 Loss_G: 0.00712976171484878 D(x)