In [1]:
from math import log10
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset.data_loader_RGB_resize import *
from SRCNN_model import Net
from utils.pytorch_ssim import *
from utils.loss import *

# Seed torch (and numpy, in case the dataset/utilities draw from it) for reproducible runs.
torch.manual_seed(1)
np.random.seed(1)
# Fall back to CPU when no GPU is present instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# SRCNN training hyperparameters (the model class is imported from SRCNN_model).
upscale_factor = 4    # upscaling factor (per its name); used here only to tag output file names
batch_size = 32       # samples per optimisation step
threads = 4           # DataLoader worker processes
lr = 0.01             # initial Adam learning rate, later decayed by the StepLR scheduler
epochs = 50           # default epoch count; train() also derives its CSV/checkpoint cadence from it

In [4]:
# Low-resolution inputs and their high-resolution targets, relative to the notebook directory.
img_path_low = '../dataset/LR_112/train'
img_path_ref = '../dataset/HR/train'

# Fail fast with a clear message rather than an obscure error from the dataset or DataLoader.
missing_dirs = [p for p in (img_path_low, img_path_ref) if not os.path.isdir(p)]
if missing_dirs:
    raise FileNotFoundError("Training image directory not found: {}".format(", ".join(missing_dirs)))

train_set = DatasetSuperRes(img_path_low, img_path_ref)
# Pinned host memory speeds up host->GPU copies; it is only useful when training on CUDA.
training_data_loader = DataLoader(dataset=train_set, num_workers=threads, batch_size=batch_size,
                                  shuffle=True, pin_memory=(device.type == 'cuda'))

In [5]:
print('===> Building model')
model = Net().to(device)

# Pixel-wise MSE; train() uses it to report PSNR (the backward pass uses the face loss).
criterion = nn.MSELoss()

# Adam over the SR model's parameters only; StepLR multiplies the LR by gamma=0.1
# every step_size=10 calls to scheduler.step() (train() calls it once per epoch).
optimizer = optim.Adam(params=model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

===> Building model


In [6]:
# Perceptual face loss built on a SENet-50 feature extractor
# (2048-d features, judging by the helper's name — confirm in utils.loss).
feature_extraction_model = initialize_senet50_2048()
face_loss = FacePerceptionLoss(feature_extraction_model)
# NOTE(review): confirm FacePerceptionLoss freezes the extractor / puts it in eval mode;
# the optimizer only receives model.parameters(), so Adam never updates it.
face_loss = face_loss.to(device)

In [7]:
# Where per-epoch metric CSVs and model checkpoints are written.
out_path = 'results/'
out_model_path = 'models/'

# exist_ok=True avoids the check-then-create race and makes the cell safely re-runnable.
os.makedirs(out_path, exist_ok=True)
os.makedirs(out_model_path, exist_ok=True)

# Per-epoch training metrics, appended to by train(). Re-run this cell to start a fresh log.
results = {'avg_loss': [], 'psnr': [], 'ssim': []}

In [8]:
def train(epoch):
    """Run one training epoch, record its metrics, and periodically persist results.

    The model is optimised on the face perception loss only; the pixel MSE and
    SSIM are computed purely for monitoring (PSNR is derived from the MSE).

    Args:
        epoch: 1-based epoch number, used for logging and for the save cadence.

    Side effects:
        Steps the global ``optimizer`` and ``scheduler``. Appends the epoch's average
        total loss, PSNR and mean per-batch SSIM to the global ``results`` dict. Every
        ``max(1, epochs // 2)`` epochs it writes all recorded results to a CSV in
        ``out_path`` and calls ``checkpoint(epoch)``.

    Raises:
        RuntimeError: if ``training_data_loader`` yields no batches.
    """
    num_batches = len(training_data_loader)
    if num_batches == 0:
        raise RuntimeError("training_data_loader is empty; nothing to train on")

    epoch_mse = 0.0
    epoch_ssim = 0.0
    epoch_total_loss = 0.0
    model.train()
    for iteration, batch in enumerate(training_data_loader, 1):
        input_, target = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()
        upsampled_img = model(input_)

        # Monitoring metrics only: no autograd graph is needed for them.
        with torch.no_grad():
            epoch_mse += criterion(upsampled_img, target).item()
            epoch_ssim += ssim(upsampled_img, target).item()

        # The face loss is the training objective.
        total_loss = face_loss(upsampled_img, target)
        epoch_total_loss += total_loss.item()
        total_loss.backward()
        optimizer.step()

        print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, num_batches,
                                                           total_loss.item()))

    # One scheduler step per epoch; with the StepLR built alongside the optimizer this
    # decays the LR every `step_size` epochs.
    scheduler.step()

    avg_mse = epoch_mse / num_batches
    # PSNR with peak value 1 — assumes images are scaled to [0, 1] (TODO confirm in the dataset).
    psnr_epoch = 10 * log10(1 / avg_mse) if avg_mse > 0 else float('inf')
    ssim_epoch = epoch_ssim / num_batches
    avg_loss_batch = epoch_total_loss / num_batches

    results['psnr'].append(psnr_epoch)
    results['ssim'].append(ssim_epoch)
    results['avg_loss'].append(avg_loss_batch)

    print("===> Epoch {} Complete: Avg.Total Loss: {:.4f} / PSNR: {:.4f} / SSIM {:.4f}".format(epoch,
                                                                                          avg_loss_batch,
                                                                                          psnr_epoch,
                                                                                          ssim_epoch))

    # max(1, ...) avoids a modulo-by-zero when fewer than 2 epochs are configured.
    save_every = max(1, epochs // 2)
    if epoch % save_every == 0:
        # Index by the number of recorded epochs, not by `epoch`: re-running the training
        # cell keeps appending to `results`, so the two can diverge.
        num_recorded = len(results['avg_loss'])
        data_frame = pd.DataFrame(
            data={'Avg. Total Loss': results['avg_loss'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
            index=range(1, num_recorded + 1))

        csv_name = 'SRCNN_Loss_x' + str(upscale_factor) + '_train_results.csv'
        data_frame.to_csv(os.path.join(out_path, csv_name), index_label='Epoch')

        checkpoint(epoch)

def checkpoint(epoch):
    """Save the global ``model`` for ``epoch`` into ``out_model_path``.

    The file is named ``SRCNN_Loss_x{upscale_factor}_epoch_{epoch}.pth``.

    Args:
        epoch: epoch number embedded in the file name.

    Returns:
        The path the model was written to.
    """
    # os.path.join keeps this correct even if out_model_path lacks a trailing slash.
    path = os.path.join(out_model_path, "SRCNN_Loss_x{}_epoch_{}.pth".format(upscale_factor, epoch))
    # NOTE(review): torch.save(model) pickles the whole module object, so loading it needs
    # the same SRCNN_model source importable; model.state_dict() is the more portable choice.
    torch.save(model, path)
    print("Checkpoint saved to {}".format(path))
    return path

In [16]:
# Short inspection run: override the configured epoch count for this session.
# train() reads this global `epochs` to decide when to write the CSV and checkpoint.
# Re-running this cell continues training the same model/optimizer and keeps
# appending to `results`.
epochs = 10

for epoch in range(1, epochs + 1):
    train(epoch)

===> Epoch[7](1/562): Loss: 1403067.5000
===> Epoch[7](2/562): Loss: 767648.3750
===> Epoch[7](3/562): Loss: 224933.3594
===> Epoch[7](4/562): Loss: 22327.1426
===> Epoch[7](5/562): Loss: 167450.7500
===> Epoch[7](6/562): Loss: 413131.7500


KeyboardInterrupt: 