In [1]:
from math import log10
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader

from super_resolution_data_loader_GAN import *
from pytorch_ssim import *

from SRGAN_model import Generator, Discriminator
from loss import GeneratorLoss

torch.manual_seed(1)
device = torch.device("cuda")

In [2]:
# SRGAN parameters

batch_size = 2
num_epochs = 100
lr = 0.01
threads = 4
upscale_factor = 4

In [3]:
img_path_low = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-56/train'
img_path_ref = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-224/train'

train_set = DatasetSuperRes(img_path_low, img_path_ref)
training_data_loader = DataLoader(dataset=train_set, num_workers=threads, batch_size=batch_size, shuffle=True)

In [4]:
netG = Generator(upscale_factor).to(device)
print('# Generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator().to(device)
print('# Discriminator parameters:', sum(param.numel() for param in netD.parameters()))
generator_criterion = GeneratorLoss().to(device)

# Generator parameters: 734219
# Discriminator parameters: 5215425


In [5]:
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())
    
results = {'d_loss': [], 'g_loss': [], 'd_score': [],
           'g_score': [], 'psnr': [], 'ssim': []}

In [6]:
def train(epoch):
    
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 
                       'd_score': 0, 'g_score': 0}

    netG.train()
    netD.train()
    
    for data, target in training_data_loader:
        train_bar = tqdm(training_data_loader)
        g_update_first = True
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        ############################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        real_img = Variable(target).to(device)
        z = Variable(data).to(device)
        
        fake_img = netG(z)

        netD.zero_grad()
        
        real_out = netD(real_img).mean()
        fake_out = netD(fake_img).mean()
        
        d_loss = 1 - real_out + fake_out
        d_loss.backward(retain_graph=True)
        
        optimizerD.step()

        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ###########################
        netG.zero_grad()
        
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()

        fake_img = netG(z)
        fake_out = netD(fake_img).mean()


        optimizerG.step()

        # Loss for current batch before optimization 
        
        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.item() * batch_size
        running_results['g_score'] += fake_out.item() * batch_size

        train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, num_epochs, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))

In [7]:
for epoch in range(1, num_epochs + 1):
    train(epoch)

[1/100] Loss_D: 1.0002 Loss_G: 0.1300 D(x): 0.4914 D(G(z)): 0.3750:   0%|          | 0/1550 [00:01<?, ?it/s]
  0%|          | 0/1550 [00:00<?, ?it/s][A
[1/100] Loss_D: 1.0002 Loss_G: 0.1300 D(x): 0.4914 D(G(z)): 0.3750:   0%|          | 0/1550 [00:01<?, ?it/s]
[1/100] Loss_D: 0.9999 Loss_G: 0.1361 D(x): 0.4440 D(G(z)): 0.3550:   0%|          | 0/1550 [00:01<?, ?it/s]
  0%|          | 0/1550 [00:00<?, ?it/s][A
[1/100] Loss_D: 0.9999 Loss_G: 0.1361 D(x): 0.4440 D(G(z)): 0.3550:   0%|          | 0/1550 [00:01<?, ?it/s]
[1/100] Loss_D: 0.9999 Loss_G: 0.1378 D(x): 0.4018 D(G(z)): 0.3337:   0%|          | 0/1550 [00:01<?, ?it/s]
  0%|          | 0/1550 [00:00<?, ?it/s][A
[1/100] Loss_D: 0.9999 Loss_G: 0.1378 D(x): 0.4018 D(G(z)): 0.3337:   0%|          | 0/1550 [00:01<?, ?it/s]
[1/100] Loss_D: 0.9985 Loss_G: 0.1289 D(x): 0.3686 D(G(z)): 0.2925:   0%|          | 0/1550 [00:01<?, ?it/s]
  0%|          | 0/1550 [00:00<?, ?it/s][A
[1/100] Loss_D: 0.9985 Loss_G: 0.1289 D(x): 0.3686 D(G(z)): 0

[1/100] Loss_D: 0.9860 Loss_G: 0.0689 D(x): 0.2107 D(G(z)): 0.1696:   0%|          | 0/1550 [00:01<?, ?it/s][A[A[A[A




  0%|          | 0/1550 [00:00<?, ?it/s][A[A[A[A[A




[1/100] Loss_D: 0.9860 Loss_G: 0.0689 D(x): 0.2107 D(G(z)): 0.1696:   0%|          | 0/1550 [00:01<?, ?it/s]




[1/100] Loss_D: 0.9958 Loss_G: 0.0679 D(x): 0.2078 D(G(z)): 0.1787:   0%|          | 0/1550 [00:01<?, ?it/s][A[A[A[A




  0%|          | 0/1550 [00:00<?, ?it/s][A[A[A[A[A




[1/100] Loss_D: 0.9958 Loss_G: 0.0679 D(x): 0.2078 D(G(z)): 0.1787:   0%|          | 0/1550 [00:01<?, ?it/s]




[1/100] Loss_D: 0.9945 Loss_G: 0.0668 D(x): 0.2162 D(G(z)): 0.1847:   0%|          | 0/1550 [00:01<?, ?it/s][A[A[A[A




  0%|          | 0/1550 [00:00<?, ?it/s][A[A[A[A[A




[1/100] Loss_D: 0.9945 Loss_G: 0.0668 D(x): 0.2162 D(G(z)): 0.1847:   0%|          | 0/1550 [00:01<?, ?it/s]




[1/100] Loss_D: 0.9897 Loss_G: 0.0661 D(x): 0.2199 D(G(z)): 0.1847:   0%|          | 0/1550 [00:01<?, ?it/

  0%|          | 0/1550 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A








[1/100] Loss_D: 0.9932 Loss_G: 0.0529 D(x): 0.2052 D(G(z)): 0.1607:   0%|          | 0/1550 [00:01<?, ?it/s]








[1/100] Loss_D: 0.9964 Loss_G: 0.0530 D(x): 0.2028 D(G(z)): 0.1604:   0%|          | 0/1550 [00:01<?, ?it/s][A[A[A[A[A[A[A[A








  0%|          | 0/1550 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A








[1/100] Loss_D: 0.9964 Loss_G: 0.0530 D(x): 0.2028 D(G(z)): 0.1604:   0%|          | 0/1550 [00:01<?, ?it/s]








[1/100] Loss_D: 0.9965 Loss_G: 0.0525 D(x): 0.1995 D(G(z)): 0.1578:   0%|          | 0/1550 [00:01<?, ?it/s][A[A[A[A[A[A[A[A








  0%|          | 0/1550 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A








[1/100] Loss_D: 0.9965 Loss_G: 0.0525 D(x): 0.1995 D(G(z)): 0.1578:   0%|          | 0/1550 [00:01<?, ?it/s]








[1/100] Loss_D: 0.9958 Loss_G: 0.0521 D(x): 0.1970 D(G(z)): 0.1552:   0%|          | 0/1550 [00:01<?, ?it/s][A[A[A[A[A[A[A[A








[1/100] Loss_D: 0.9973 Loss_G: 0.0472 D(x): 0.1477 D(G(z)): 0.1163:   0%|          | 0/1550 [00:01<?, ?it/s]











[1/100] Loss_D: 0.9973 Loss_G: 0.0468 D(x): 0.1459 D(G(z)): 0.1150:   0%|          | 0/1550 [00:01<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A











  0%|          | 0/1550 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A











[1/100] Loss_D: 0.9974 Loss_G: 0.0469 D(x): 0.1442 D(G(z)): 0.1136:   0%|          | 0/1550 [00:01<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A












  0%|          | 0/1550 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A












[1/100] Loss_D: 0.9974 Loss_G: 0.0469 D(x): 0.1442 D(G(z)): 0.1136:   0%|          | 0/1550 [00:01<?, ?it/s]












[1/100] Loss_D: 0.9974 Loss_G: 0.0468 D(x): 0.1426 D(G(z)): 0.1124:   0%|          | 0/1550 [00:01<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A












  0%|          | 0/1550 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A












[1/100] Loss_D: 0.99

KeyboardInterrupt: 

In [None]:
netG.eval()
out_path = 'training_results/SRF_' + str(upscale_factor) + '/'

if not os.path.exists(out_path):
    os.makedirs(out_path)    

with torch.no_grad():
    val_bar = tqdm(val_loader)
    valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
    val_images = []
    for val_lr, val_hr_restore, val_hr in val_bar:
        batch_size = val_lr.size(0)
        valing_results['batch_sizes'] += batch_size
        lr = val_lr
        hr = val_hr
        if torch.cuda.is_available():
            lr = lr.cuda()
            hr = hr.cuda()
        sr = netG(lr)

        batch_mse = ((sr - hr) ** 2).data.mean()
        valing_results['mse'] += batch_mse * batch_size
        batch_ssim = pytorch_ssim.ssim(sr, hr).item()
        valing_results['ssims'] += batch_ssim * batch_size
        valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes']))
        valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
        val_bar.set_description(
            desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                valing_results['psnr'], valing_results['ssim']))

        val_images.extend(
            [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
             display_transform()(sr.data.cpu().squeeze(0))])
    val_images = torch.stack(val_images)
    val_images = torch.chunk(val_images, val_images.size(0) // 15)
    val_save_bar = tqdm(val_images, desc='[saving training results]')
    index = 1
    for image in val_save_bar:
        image = utils.make_grid(image, nrow=3, padding=5)
        utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
        index += 1

# save model parameters
torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (upscale_factor, epoch))
torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (upscale_factor, epoch))
# save loss\scores\psnr\ssim
results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
results['psnr'].append(valing_results['psnr'])
results['ssim'].append(valing_results['ssim'])

if epoch % 10 == 0 and epoch != 0:
    out_path = 'statistics/'
    data_frame = pd.DataFrame(
        data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
              'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
        index=range(1, epoch + 1))
    data_frame.to_csv(out_path + 'srf_' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')