In [1]:
from math import log10
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import pandas as pd
import os 

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader

from dataset.data_loader_RGB import *
from utils.pytorch_ssim import *

from facenet_pytorch import InceptionResnetV1
from GAN_model import Generator, Discriminator
from utils.loss import*

torch.manual_seed(1)
device = torch.device("cuda")

In [2]:
# SRGAN parameters

batch_size = 32
num_epochs = 10
lr = 1e-4
threads = 4
upscale_factor = 4

In [3]:
#img_path_low = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-56/train'
#img_path_ref = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-224/train'

#img_path_low = '/home/jupyter/dataset/LR_56/train/'
#img_path_ref = '/home/jupyter/dataset/HR/train/'

img_path_low = '/media/angelo/DATEN/Datasets/CelebA/LR_56/train/'
img_path_ref = '/media/angelo/DATEN/Datasets/CelebA/HR/train/'

train_set = DatasetSuperRes(img_path_low, img_path_ref)
training_data_loader = DataLoader(dataset=train_set, num_workers=threads, batch_size=batch_size, shuffle=True)

In [4]:
netG = Generator(upscale_factor).to(device)
netD = Discriminator().to(device)

#netG.weight_init(mean=0.0, std=0.2)
#netD.weight_init(mean=0.0, std=0.2)

feature_extraction_model = InceptionResnetV1(pretrained='vggface2').eval()

generator_criterion = GeneratorLoss().to(device)
face_loss = FaceIdentityLoss(feature_extraction_model).to(device)
MSE_criterion = nn.MSELoss()

In [5]:
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
    
results = {'d_loss': [], 'g_loss': [], 'd_score': [],
           'g_score': [], 'psnr': [], 'ssim': []}

In [None]:
# Let's pretrain the Generator
pre_train_epochs = 5
for epoch in range(1, pre_train_epochs+1):
    
    netG.train()
    
    for data, target in training_data_loader:
        data, target = data.to(device), target.to(device)
        
        netG.zero_grad()
        loss = MSE_criterion(netG(data), target)
        loss.backward()
        optimizerG.step()
    
    print('Last loss: {}'.format(loss.item()))
    print('Finished pre-training {}/{}'.format(epoch,pre_train_epochs))

In [6]:
def train(epoch):
    
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 
                       'd_score': 0, 'g_score': 0}
    
    netG.train()
    netD.train()
    iteration = 0
    for data, target in training_data_loader:
        iteration+=1
        g_update_first = True
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        ############################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        real_img = Variable(target).to(device)
        z = Variable(data).to(device)
        
        fake_img = netG(z)

        netD.zero_grad()
        
        real_out = netD(real_img).mean()
        fake_out = netD(fake_img).mean()
        
        d_loss = 1 - real_out + fake_out
        
        if d_loss.item() > 0.4:
            d_loss.backward(retain_graph=True)
            optimizerD.step()

        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ###########################
        netG.zero_grad()
        
        #mse_adversarial_content_loss = generator_criterion(fake_out, fake_img, real_img)
        #face_perceptual_loss = face_loss(fake_img, real_img)
        
        g_loss = generator_criterion(fake_out, fake_img, real_img) + face_loss(fake_img, real_img)
        
        g_loss.backward()

        fake_img = netG(z)
        fake_out = netD(fake_img).mean()

        optimizerG.step()

        # Loss for current batch before optimization 

        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.item() * batch_size
        running_results['g_score'] += fake_out.item() * batch_size

        print("===> Epoch[{}]({}/{}): Discriminator Loss: {:.4f}".format(epoch, iteration, len(training_data_loader), d_loss.item()))

    print('[{}/{}] Loss_D: {:.4f} Loss_G: {:.4f} D(x): {:.4f} D(G(z)): {:.4f}'.format(
            epoch, num_epochs, 
            running_results['d_loss'] / running_results['batch_sizes'],
            running_results['g_loss'] / running_results['batch_sizes'],
            running_results['d_score'] / running_results['batch_sizes'],
            running_results['g_score'] / running_results['batch_sizes']))

    #netG.eval()

    batch_mse = ((fake_img - real_img) ** 2).data.mean()
    batch_ssim = ssim(fake_img, real_img).item()
    batch_psnr = 10 * log10(1 /batch_mse)

    out_path = 'results/'
    out_model_path = 'models/'
    
    if not os.path.exists(out_path):
        os.makedirs(out_path)    
    
    if not os.path.exists(out_model_path):
        os.makedirs(out_model_path)    

    # Save loss\scores\psnr\ssim
    results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
    results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
    results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
    results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
    results['psnr'].append(batch_psnr)
    results['ssim'].append(batch_ssim)

    if epoch % (num_epochs // 2) == 0:

        # Save model parameters
        torch.save(netG, out_model_path + 'SRGAN_x%d_epoch_%d.pth' % (upscale_factor, epoch))
        #torch.save(netD, 'netD_x%d_epoch_%d.pt' % (upscale_factor, epoch))

        data_frame = pd.DataFrame(
            data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                  'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
            index=range(1, epoch + 1))
        data_frame.to_csv(out_path + 'SRGAN_x' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')

In [7]:
num_epochs = 30

for epoch in range(1, num_epochs + 1):
    train(epoch)



===> Epoch[1](1/4495): Discriminator Loss: 1.0000
===> Epoch[1](2/4495): Discriminator Loss: 0.9854
===> Epoch[1](3/4495): Discriminator Loss: 0.9807
===> Epoch[1](4/4495): Discriminator Loss: 0.9675
===> Epoch[1](5/4495): Discriminator Loss: 0.9328
===> Epoch[1](6/4495): Discriminator Loss: 0.8980
===> Epoch[1](7/4495): Discriminator Loss: 0.8654
===> Epoch[1](8/4495): Discriminator Loss: 0.9001
===> Epoch[1](9/4495): Discriminator Loss: 0.8928
===> Epoch[1](10/4495): Discriminator Loss: 0.8568
===> Epoch[1](11/4495): Discriminator Loss: 0.8094
===> Epoch[1](12/4495): Discriminator Loss: 0.9393
===> Epoch[1](13/4495): Discriminator Loss: 0.9063
===> Epoch[1](14/4495): Discriminator Loss: 0.9257
===> Epoch[1](15/4495): Discriminator Loss: 0.8749
===> Epoch[1](16/4495): Discriminator Loss: 0.8457
===> Epoch[1](17/4495): Discriminator Loss: 0.8218
===> Epoch[1](18/4495): Discriminator Loss: 0.8518
===> Epoch[1](19/4495): Discriminator Loss: 0.8580
===> Epoch[1](20/4495): Discriminator Lo

Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 