In [1]:
from math import log10
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset.data_loader_RGB_resize import *
from SRCNN_model import Net
from utils.pytorch_ssim import *
from utils.loss import *

# Reproducibility: seed torch's global RNG (weight init, DataLoader shuffling) and
# numpy, in case the dataset or helper code draws from it.
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when present; otherwise fall back to CPU instead of failing at the
# first `.to(device)` call on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# SRCNN training hyperparameters (the model built below is SRCNN_model.Net).

batch_size = 64       # images per optimisation step
epochs = 10           # epochs per training run; train() also checkpoints every epochs // 2 epochs
lr = 0.01             # initial Adam learning rate, decayed by the StepLR scheduler
                      # NOTE(review): 0.01 is 10x Adam's default (1e-3) — confirm intended.
threads = 4           # DataLoader worker processes
upscale_factor = 4    # SR factor; here only used in result/checkpoint file names
                      # (Net() and DatasetSuperRes do not receive it)

In [3]:
# Paired low-resolution / high-resolution images (56 px LR, per the directory name).
# NOTE(review): both paths point at the *test* split although this loader is used for
# training; the full train split previously lived at machine-specific absolute paths
# (.../Datasets/Experiment_Masters/300W-3D-low-res-{56,224}/train). Confirm this is
# intentional — otherwise any later evaluation on the test split is not held out.
img_path_low, img_path_ref = '../dataset/LR_56/test', '../dataset/HR/test'

train_set = DatasetSuperRes(img_path_low, img_path_ref)
training_data_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=threads,
)

In [4]:
print('===> Building model')
model = Net().to(device)

# Pixel-wise MSE. train() only uses it to monitor PSNR; gradients come from the face loss.
criterion = nn.MSELoss()

# Adam with step decay: the LR is multiplied by gamma every step_size scheduler steps
# (train() calls scheduler.step() once per epoch).
optimizer = optim.Adam(params=model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)

===> Building model


In [5]:
# Face-perception loss: the objective train() back-propagates.
# NOTE(review): presumably initialize_senet50_2048() returns a SENet-50 feature extractor
# with 2048-d outputs (inferred from the name only) — confirm in utils.loss.
feature_extraction_model = initialize_senet50_2048()
face_loss = FacePerceptionLoss(feature_extraction_model)
face_loss = face_loss.to(device)

In [6]:
# Output locations for the per-epoch metrics CSV and the model checkpoints.
# checkpoint() appends the file name directly to out_model_path, so keep the trailing '/'.
out_path = 'results/'
out_model_path = 'models/'

# exist_ok=True avoids the exists-then-create race and makes the cell idempotent on re-run.
os.makedirs(out_path, exist_ok=True)
os.makedirs(out_model_path, exist_ok=True)

# Per-epoch metric history; train() appends one entry per key per completed epoch.
results = {'avg_loss': [], 'psnr': [], 'ssim': []}

In [7]:
def train(epoch):
    epoch_loss = 0
    epoch_total_loss = 0
    model.train()
    for iteration, batch in enumerate(training_data_loader, 1):
        input_, target = batch[0].to(device), batch[1].to(device)
        
        optimizer.zero_grad()
        upsampled_img = model(input_)
        # MSE Loss for PSNR estimation
        loss = criterion(upsampled_img, target)
        epoch_loss += loss.item()
        # Face Loss
        total_loss = face_loss(upsampled_img, target)
        epoch_total_loss += total_loss.item()
        #loss.backward()
        total_loss.backward()
        optimizer.step()

        print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(training_data_loader),
                                                           total_loss.item()))
    
    scheduler.step() # Decrease learning rate after 100 epochs to 10% of its value
    
    psnr_epoch = 10*log10(1/(epoch_loss / len(training_data_loader)))
    ssim_epoch = ssim(upsampled_img, target).item()
    #avg_loss_batch = epoch_loss/len(training_data_loader)
    avg_loss_batch = epoch_total_loss/len(training_data_loader)
    
    results['psnr'].append(psnr_epoch)
    results['ssim'].append(ssim_epoch)
    results['avg_loss'].append(avg_loss_batch)
    
    print("===> Epoch {} Complete: Avg.Total Loss: {:.4f} / PSNR: {:.4f} / SSIM {:.4f}".format(epoch, 
                                                                                          avg_loss_batch, 
                                                                                          psnr_epoch,
                                                                                          ssim_epoch))
    if epoch % (epochs // 2) == 0:
    
        data_frame = pd.DataFrame(
                data={'Avg. Total Loss': results['avg_loss'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))

        data_frame.to_csv(out_path + 'SRCNN_Loss_x' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')
        
        checkpoint(epoch)

def checkpoint(epoch):
    """Save the current model under ``out_model_path`` and report the path.

    The file name encodes the upscale factor and the epoch number.

    NOTE(review): ``torch.save(model, ...)`` pickles the whole module, so the
    ``Net`` class must be importable when loading. Saving ``model.state_dict()``
    is the PyTorch-recommended format; it is kept as-is for compatibility with
    existing .pth files.
    """
    filename = "SRCNN_Loss_x{}_epoch_{}.pth".format(upscale_factor, epoch)
    target_path = out_model_path + filename
    torch.save(model, target_path)
    print("Checkpoint saved to {}".format(target_path))

In [8]:
# Initial training run of `epochs` epochs. `epochs` comes from the hyperparameter
# cell, which also drives the checkpoint interval inside train(), so it is not
# re-assigned here.
for epoch in range(1, epochs + 1):
    train(epoch)

===> Epoch[1](1/32): Loss: 23140.2031
===> Epoch[1](2/32): Loss: 286992.8438
===> Epoch[1](3/32): Loss: 14641.0146
===> Epoch[1](4/32): Loss: 2581.7200
===> Epoch[1](5/32): Loss: 19474.5430
===> Epoch[1](6/32): Loss: 4241.9165
===> Epoch[1](7/32): Loss: 1031.5955
===> Epoch[1](8/32): Loss: 4084.9775
===> Epoch[1](9/32): Loss: 6789.3882
===> Epoch[1](10/32): Loss: 6652.1992
===> Epoch[1](11/32): Loss: 5006.8901
===> Epoch[1](12/32): Loss: 3356.2263
===> Epoch[1](13/32): Loss: 2277.9170
===> Epoch[1](14/32): Loss: 2991.7910
===> Epoch[1](15/32): Loss: 2601.0361
===> Epoch[1](16/32): Loss: 1346.8875
===> Epoch[1](17/32): Loss: 893.4781
===> Epoch[1](18/32): Loss: 1820.7660
===> Epoch[1](19/32): Loss: 3346.6521
===> Epoch[1](20/32): Loss: 4269.7983
===> Epoch[1](21/32): Loss: 3746.7119
===> Epoch[1](22/32): Loss: 2039.6263
===> Epoch[1](23/32): Loss: 1141.6293
===> Epoch[1](24/32): Loss: 1855.1998
===> Epoch[1](25/32): Loss: 2305.9873
===> Epoch[1](26/32): Loss: 1237.7905
===> Epoch[1](27/

===> Epoch[7](15/32): Loss: 359.0991
===> Epoch[7](16/32): Loss: 369.5098
===> Epoch[7](17/32): Loss: 359.5314
===> Epoch[7](18/32): Loss: 378.0017
===> Epoch[7](19/32): Loss: 409.6355
===> Epoch[7](20/32): Loss: 364.7599
===> Epoch[7](21/32): Loss: 340.8898
===> Epoch[7](22/32): Loss: 356.2024
===> Epoch[7](23/32): Loss: 396.3860
===> Epoch[7](24/32): Loss: 395.8489
===> Epoch[7](25/32): Loss: 383.8347
===> Epoch[7](26/32): Loss: 372.8183
===> Epoch[7](27/32): Loss: 389.4428
===> Epoch[7](28/32): Loss: 375.5167
===> Epoch[7](29/32): Loss: 361.6542
===> Epoch[7](30/32): Loss: 405.0774
===> Epoch[7](31/32): Loss: 342.7524
===> Epoch[7](32/32): Loss: 425.9633
===> Epoch 7 Complete: Avg.Total Loss: 388.5189 / PSNR: 22.2413 / SSIM 0.4379
===> Epoch[8](1/32): Loss: 318.6097
===> Epoch[8](2/32): Loss: 356.1222
===> Epoch[8](3/32): Loss: 356.9418
===> Epoch[8](4/32): Loss: 368.5544
===> Epoch[8](5/32): Loss: 356.2720
===> Epoch[8](6/32): Loss: 349.7952
===> Epoch[8](7/32): Loss: 376.1325
===>

In [9]:
# Continue training for another `epochs` epochs, numbered after the epochs already
# recorded in `results` (11-20 after a complete first run of 10). This keeps the CSV
# index and checkpoint names consecutive if the first run was interrupted or this
# cell is re-run.
first_epoch = len(results['psnr']) + 1
for epoch in range(first_epoch, first_epoch + epochs):
    train(epoch)

===> Epoch[11](1/32): Loss: 286.3771
===> Epoch[11](2/32): Loss: 312.0671
===> Epoch[11](3/32): Loss: 284.5405
===> Epoch[11](4/32): Loss: 282.8788
===> Epoch[11](5/32): Loss: 270.2509
===> Epoch[11](6/32): Loss: 271.4466
===> Epoch[11](7/32): Loss: 302.0374
===> Epoch[11](8/32): Loss: 286.5351
===> Epoch[11](9/32): Loss: 283.7441
===> Epoch[11](10/32): Loss: 272.6861
===> Epoch[11](11/32): Loss: 291.1337
===> Epoch[11](12/32): Loss: 273.2654
===> Epoch[11](13/32): Loss: 281.8201
===> Epoch[11](14/32): Loss: 295.7831
===> Epoch[11](15/32): Loss: 287.1376
===> Epoch[11](16/32): Loss: 274.7680
===> Epoch[11](17/32): Loss: 301.1264
===> Epoch[11](18/32): Loss: 287.8368
===> Epoch[11](19/32): Loss: 310.7317
===> Epoch[11](20/32): Loss: 288.5804
===> Epoch[11](21/32): Loss: 270.4119
===> Epoch[11](22/32): Loss: 294.8414
===> Epoch[11](23/32): Loss: 280.6215
===> Epoch[11](24/32): Loss: 279.7087
===> Epoch[11](25/32): Loss: 282.5111
===> Epoch[11](26/32): Loss: 289.1137
===> Epoch[11](27/32)

===> Epoch[17](10/32): Loss: 272.2487
===> Epoch[17](11/32): Loss: 301.0264
===> Epoch[17](12/32): Loss: 294.0424
===> Epoch[17](13/32): Loss: 266.9948
===> Epoch[17](14/32): Loss: 297.7023
===> Epoch[17](15/32): Loss: 281.5451
===> Epoch[17](16/32): Loss: 277.5119
===> Epoch[17](17/32): Loss: 292.3018
===> Epoch[17](18/32): Loss: 289.3124
===> Epoch[17](19/32): Loss: 286.1448
===> Epoch[17](20/32): Loss: 269.7347
===> Epoch[17](21/32): Loss: 280.9667
===> Epoch[17](22/32): Loss: 258.5630
===> Epoch[17](23/32): Loss: 268.2055
===> Epoch[17](24/32): Loss: 294.2051
===> Epoch[17](25/32): Loss: 276.0740
===> Epoch[17](26/32): Loss: 270.1336
===> Epoch[17](27/32): Loss: 277.5269
===> Epoch[17](28/32): Loss: 291.2243
===> Epoch[17](29/32): Loss: 251.6869
===> Epoch[17](30/32): Loss: 278.7131
===> Epoch[17](31/32): Loss: 269.7232
===> Epoch[17](32/32): Loss: 307.6551
===> Epoch 17 Complete: Avg.Total Loss: 280.3613 / PSNR: 23.6595 / SSIM 0.4259
===> Epoch[18](1/32): Loss: 281.8600
===> Epoch