In [1]:
from math import log10
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import pandas as pd
import os 

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader

from dataset.data_loader_RGB import *
from utils.pytorch_ssim import *

from facenet_pytorch import InceptionResnetV1
from models.GAN_coord_model import Generator, Discriminator
from utils.loss import*

torch.manual_seed(1)
device = torch.device("cuda")

In [2]:
# SRGAN parameters

batch_size = 32
num_epochs = 30
lr = 1e-4
threads = 4
upscale_factor = 4

In [3]:
#img_path_low = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-56/train'
#img_path_ref = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-224/train'

img_path_low = '/home/jupyter/dataset/LR_56/train/'
img_path_ref = '/home/jupyter/dataset/HR/train/'

#img_path_low = '/media/angelo/DATEN/Datasets/CelebA/LR_56/train/'
#img_path_ref = '/media/angelo/DATEN/Datasets/CelebA/HR/train/'

train_set = DatasetSuperRes(img_path_low, img_path_ref)
training_data_loader = DataLoader(dataset=train_set, num_workers=threads, batch_size=batch_size, shuffle=True)

In [4]:
netG = Generator(upscale_factor).to(device)
netD = Discriminator().to(device)

#netG = torch.load('models/SRGAN_coord_FaceLoss_x4_epoch_1.pth', map_location=lambda storage, loc: storage).to(device)
#netD = torch.load('models/netD_FaceLoss_x4_epoch_12.pth', map_location=lambda storage, loc: storage).to(device)

#netG.weight_init(mean=0.0, std=0.2)
#netD.weight_init(mean=0.0, std=0.2)

feature_extraction_model = InceptionResnetV1(pretrained='vggface2').eval()

generator_criterion = GeneratorLoss().to(device)
MSE_criterion = nn.MSELoss().to(device)

In [5]:
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
    
results = {'d_loss': [], 'g_loss': [], 'd_score': [],
           'g_score': [], 'psnr': [], 'ssim': []}

In [6]:
# Let's pretrain the Generator
pre_train_epochs = 5
for epoch in range(1, pre_train_epochs+1):
    netG.train()
    
    for data, target in training_data_loader:
        data, target = data.to(device), target.to(device)
        netG.zero_grad()
        loss = MSE_criterion(netG(data), target)
        loss.backward()
        optimizerG.step()
    
    print('Last loss: {}'.format(loss.item()))
    print('Finished pre-training {}/{}'.format(epoch,pre_train_epochs))

Last loss: 0.002991111483424902
Finished pre-training 1/5
Last loss: 0.0026030167937278748
Finished pre-training 2/5
Last loss: 0.0024772584438323975
Finished pre-training 3/5
Last loss: 0.00225018966011703
Finished pre-training 4/5
Last loss: 0.0026883746031671762
Finished pre-training 5/5


In [7]:
out_path = 'results/'
out_model_path = 'checkpoints/'

if not os.path.exists(out_path):
    os.makedirs(out_path)    

if not os.path.exists(out_model_path):
    os.makedirs(out_model_path)  

In [10]:
def train(epoch):
    
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 
                       'd_score': 0, 'g_score': 0}
    
    netG.train()
    netD.train()
    iteration = 0
    for data, target in training_data_loader:
        iteration+=1
        g_update_first = True
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        ############################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        real_img = Variable(target).to(device)
        z = Variable(data).to(device)
        
        fake_img = netG(z)

        netD.zero_grad()
        
        real_out = netD(real_img).mean()
        fake_out = netD(fake_img).mean()
        
        d_loss = 1 - real_out + fake_out
        
        if d_loss.item() > 0.5:
            d_loss.backward(retain_graph=True)
            optimizerD.step()

        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss + FaceID Loss
        ###########################
        netG.zero_grad()
        
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        
        g_loss.backward()

        fake_img = netG(z)
        fake_out = netD(fake_img).mean()

        optimizerG.step()

        # Loss for current batch before optimization 

        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.item() * batch_size
        running_results['g_score'] += fake_out.item() * batch_size

        print("===> Epoch[{}]({}/{}): D(x) {:.4f} Loss_D: {:.4f}".format(epoch, iteration, 
                                                                         len(training_data_loader), 
                                                                         real_out.item(),
                                                                         d_loss.item()))

    batch_mse = ((fake_img - real_img) ** 2).data.mean()
    batch_ssim = ssim(fake_img, real_img).item()
    batch_psnr = 10 * log10(1 /batch_mse)  
    
    print('[{}/{}] Loss_D: {:.4f} Loss_G: {:.4f} D(x): {:.4f} D(G(z)): {:.4f} PSNR {:.4f}'.format(
        epoch, num_epochs, 
        running_results['d_loss'] / running_results['batch_sizes'],
        running_results['g_loss'] / running_results['batch_sizes'],
        running_results['d_score'] / running_results['batch_sizes'],
        running_results['g_score'] / running_results['batch_sizes'], batch_psnr))
          
    # Save loss\scores\psnr\ssim
    results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
    results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
    results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
    results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
    results['psnr'].append(batch_psnr)
    results['ssim'].append(batch_ssim)

    if epoch % 1 == 0:

        # Save model parameters
        torch.save(netG, out_model_path + 'SRGAN_coord_x%d_epoch_%d.pth' % (upscale_factor, epoch))
        torch.save(netD, out_model_path + 'netD_coord_x%d_epoch_%d.pth' % (upscale_factor, epoch))

        data_frame = pd.DataFrame(
            data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                  'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
            index=range(1, epoch + 1))
        data_frame.to_csv(out_path + 'SRGAN_coord_x' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')

In [None]:
#num_epochs = 30

for epoch in range(1, num_epochs+1):
    train(epoch)

===> Epoch[1](1/562): D(x) 0.5021 Loss_D: 0.9999
===> Epoch[1](2/562): D(x) 0.4975 Loss_D: 0.9998
===> Epoch[1](3/562): D(x) 0.4967 Loss_D: 0.9996
===> Epoch[1](4/562): D(x) 0.5004 Loss_D: 0.9989
===> Epoch[1](5/562): D(x) 0.5004 Loss_D: 0.9994
===> Epoch[1](6/562): D(x) 0.4987 Loss_D: 0.9984
===> Epoch[1](7/562): D(x) 0.4974 Loss_D: 0.9983
===> Epoch[1](8/562): D(x) 0.4966 Loss_D: 0.9986
===> Epoch[1](9/562): D(x) 0.4942 Loss_D: 0.9971
===> Epoch[1](10/562): D(x) 0.4971 Loss_D: 0.9970
===> Epoch[1](11/562): D(x) 0.5087 Loss_D: 0.9952
===> Epoch[1](12/562): D(x) 0.5094 Loss_D: 0.9953
===> Epoch[1](13/562): D(x) 0.5090 Loss_D: 0.9952
===> Epoch[1](14/562): D(x) 0.5214 Loss_D: 0.9917
===> Epoch[1](15/562): D(x) 0.5108 Loss_D: 0.9930
===> Epoch[1](16/562): D(x) 0.5247 Loss_D: 0.9909
===> Epoch[1](17/562): D(x) 0.5435 Loss_D: 0.9837
===> Epoch[1](18/562): D(x) 0.5416 Loss_D: 0.9821
===> Epoch[1](19/562): D(x) 0.5563 Loss_D: 0.9823
===> Epoch[1](20/562): D(x) 0.4860 Loss_D: 0.9819
===> Epoc