In [1]:
from math import log10
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset.data_loader_RGB import *
from utils.pytorch_ssim import *
from utils.loss import*

from models.FSRCNN_coord_model import Net
from facenet_pytorch import InceptionResnetV1

# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable between runs.
torch.manual_seed(1)

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing at
# the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Hyper-parameters for the FSRCNN training run.

# Tag embedded in the names of the result CSV and checkpoint files.
upscale_factor = 4

# Optimisation settings.
lr = 0.001          # learning rate handed to the Adam optimizer
epochs = 50         # default number of epochs (the training cell may override it)

# Data loading settings.
batch_size = 24     # samples per mini-batch
threads = 4         # passed to the DataLoader as num_workers

In [4]:
# Active dataset: CelebA low-resolution inputs and their high-resolution references.
img_path_low = '/media/angelo/DATEN/Datasets/CelebA/LR_56/train/'
img_path_ref = '/media/angelo/DATEN/Datasets/CelebA/HR/train/'

# Alternative dataset locations (uncomment one pair to switch):
#img_path_low = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-56/train'
#img_path_ref = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-224/train'
#img_path_low = '/home/jupyter/dataset/LR_56/train/'
#img_path_ref = '/home/jupyter/dataset/HR/train/'

# Pair every low-resolution image with its reference and serve them in
# shuffled mini-batches.
train_set = DatasetSuperRes(img_path_low, img_path_ref)
training_data_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=threads,
)

In [6]:
print('===> Building model')
# Super-resolution network being trained.
model = Net().to(device)
model.weight_init(mean=0.0, std=0.2)

# Face-embedding network (weights pretrained on VGGFace2) that is handed to the
# identity loss. It is put in eval mode and is never given to the optimizer.
feature_extraction_model = InceptionResnetV1(pretrained='vggface2').eval()
# Freeze its weights: gradients still flow *through* the network back to the
# super-resolved image, but autograd no longer computes and stores a gradient
# for every extractor parameter on each backward pass.
feature_extraction_model.requires_grad_(False)
face_loss = FaceIdentityLoss(feature_extraction_model).to(device)

# Pixel-wise reconstruction loss; its value is also used to derive PSNR.
criterion = nn.MSELoss()

# Only the super-resolution network's parameters are optimised.
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))
# Multiply the learning rate by 0.2 every 15 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.2)

===> Building model


In [7]:
# Directories for the per-epoch metrics CSV and for the model checkpoints.
out_path = 'results/'
out_model_path = 'checkpoints/'

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(out_path, exist_ok=True)
os.makedirs(out_model_path, exist_ok=True)

# Per-epoch training metrics; one value is appended to each list per epoch.
results = {'avg_loss': [], 'psnr': [], 'ssim': []}

In [8]:
def train(epoch):
    """Train the model for one epoch and record the epoch's metrics.

    Runs one full pass over ``training_data_loader``, optimising the sum of the
    MSE loss and the face identity loss, then steps the learning-rate
    scheduler. The epoch's average MSE, the PSNR derived from it and an SSIM
    value are appended to the global ``results`` dict. Periodically the
    accumulated metrics are written to a CSV file and a checkpoint is saved.

    Args:
        epoch: 1-based epoch number, used for logging, for deciding when to
            save, and in the checkpoint file name.

    Notes:
        * "Avg. Loss" and PSNR are based on the MSE term only, not on the
          combined loss that is back-propagated.
        * PSNR uses a peak value of 1 -- assumes images are scaled to [0, 1];
          TODO confirm against the dataset class.
        * SSIM is measured on the last mini-batch of the epoch only, not
          averaged over the epoch.
    """
    num_batches = len(training_data_loader)
    epoch_loss = 0
    model.train()
    for iteration, batch in enumerate(training_data_loader, 1):
        input_, target = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()
        upsampled_img = model(input_)
        # MSE Loss for PSNR estimation
        mse_loss = criterion(upsampled_img, target)
        epoch_loss += mse_loss.item()
        # Face Loss
        total_loss = mse_loss + face_loss(upsampled_img, target)
        total_loss.backward()
        optimizer.step()

        print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, num_batches,
                                                           total_loss.item()))

    # One scheduler step per epoch; the decay schedule itself is defined where
    # the scheduler is constructed.
    scheduler.step()

    avg_loss_batch = epoch_loss / num_batches
    psnr_epoch = 10 * log10(1 / avg_loss_batch)
    # Metric only: no need to track gradients for the SSIM computation.
    with torch.no_grad():
        ssim_epoch = ssim(upsampled_img, target).item()

    results['psnr'].append(psnr_epoch)
    results['ssim'].append(ssim_epoch)
    results['avg_loss'].append(avg_loss_batch)

    print("===> Epoch {} Complete: Avg. Loss: {:.4f} / PSNR: {:.4f} / SSIM {:.4f}".format(epoch,
                                                                                          avg_loss_batch,
                                                                                          psnr_epoch,
                                                                                          ssim_epoch))

    # Save roughly ten times per run. max(1, ...) keeps the modulus non-zero
    # when fewer than 10 epochs are configured (epochs // 10 would be 0).
    save_every = max(1, epochs // 10)
    if epoch % save_every == 0:
        # Index by the number of recorded epochs rather than by `epoch`, so the
        # frame stays consistent if `results` already holds entries from an
        # earlier run of the training cell.
        data_frame = pd.DataFrame(
                data={'Avg. Loss': results['avg_loss'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, len(results['avg_loss']) + 1))

        data_frame.to_csv(out_path + 'FSRCNN_coord_Loss_x' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')

        checkpoint(epoch)
    
def checkpoint(epoch):
    """Save the current model to the checkpoint directory.

    The file name encodes the upscale factor and the given epoch number.

    Args:
        epoch: Epoch number to embed in the checkpoint file name.
    """
    # NOTE(review): this pickles the whole module object, so loading it later
    # needs the same model class to be importable. Saving
    # ``model.state_dict()`` is the more portable option, but would change the
    # checkpoint format for existing loaders.
    file_name = "FSRCNN_coord_Loss_x{}_epoch_{}.pth".format(upscale_factor, epoch)
    destination = out_model_path + file_name
    torch.save(model, destination)
    print("Checkpoint saved to {}".format(destination))

In [None]:
# Override the configured epoch count for this run. `train` reads the global
# `epochs` to decide how often to save, so this also changes the save interval.
epochs = 10
#optimizer.param_groups[0]['lr'] = 0.001

# Epochs are numbered from 1 so that logs, CSV rows and checkpoint names match.
for epoch in range(1, epochs + 1):
    train(epoch)

===> Epoch[1](1/750): Loss: 1.2505
===> Epoch[1](2/750): Loss: 1.2024
===> Epoch[1](3/750): Loss: 1.1412
===> Epoch[1](4/750): Loss: 1.1344
===> Epoch[1](5/750): Loss: 1.1714
===> Epoch[1](6/750): Loss: 1.0784
===> Epoch[1](7/750): Loss: 1.1000
===> Epoch[1](8/750): Loss: 1.0817
===> Epoch[1](9/750): Loss: 1.0050
===> Epoch[1](10/750): Loss: 0.9702
===> Epoch[1](11/750): Loss: 1.0030
===> Epoch[1](12/750): Loss: 0.9445
===> Epoch[1](13/750): Loss: 0.9619
===> Epoch[1](14/750): Loss: 1.0068
===> Epoch[1](15/750): Loss: 1.0023
===> Epoch[1](16/750): Loss: 1.0191
===> Epoch[1](17/750): Loss: 0.9574
===> Epoch[1](18/750): Loss: 0.9687
===> Epoch[1](19/750): Loss: 1.0246
===> Epoch[1](20/750): Loss: 1.0027
===> Epoch[1](21/750): Loss: 0.9743
===> Epoch[1](22/750): Loss: 0.9658
===> Epoch[1](23/750): Loss: 1.0301
===> Epoch[1](24/750): Loss: 0.9685
===> Epoch[1](25/750): Loss: 1.0139
===> Epoch[1](26/750): Loss: 0.9782
===> Epoch[1](27/750): Loss: 0.9366
===> Epoch[1](28/750): Loss: 1.0183
=

===> Epoch[1](226/750): Loss: 0.7260
===> Epoch[1](227/750): Loss: 0.6853
===> Epoch[1](228/750): Loss: 0.7171
===> Epoch[1](229/750): Loss: 0.7000
===> Epoch[1](230/750): Loss: 0.6726
===> Epoch[1](231/750): Loss: 0.6502
===> Epoch[1](232/750): Loss: 0.7261
===> Epoch[1](233/750): Loss: 0.7256
===> Epoch[1](234/750): Loss: 0.7411
===> Epoch[1](235/750): Loss: 0.7706
===> Epoch[1](236/750): Loss: 0.6985
===> Epoch[1](237/750): Loss: 0.6676
===> Epoch[1](238/750): Loss: 0.6773
===> Epoch[1](239/750): Loss: 0.7465
===> Epoch[1](240/750): Loss: 0.6778
===> Epoch[1](241/750): Loss: 0.7024
===> Epoch[1](242/750): Loss: 0.6970
===> Epoch[1](243/750): Loss: 0.6705
===> Epoch[1](244/750): Loss: 0.7046
===> Epoch[1](245/750): Loss: 0.7190
===> Epoch[1](246/750): Loss: 0.6639
===> Epoch[1](247/750): Loss: 0.6673
===> Epoch[1](248/750): Loss: 0.5866
===> Epoch[1](249/750): Loss: 0.7080
===> Epoch[1](250/750): Loss: 0.6708
===> Epoch[1](251/750): Loss: 0.7031
===> Epoch[1](252/750): Loss: 0.6324
=

===> Epoch[1](448/750): Loss: 0.4925
===> Epoch[1](449/750): Loss: 0.4543
===> Epoch[1](450/750): Loss: 0.4912
===> Epoch[1](451/750): Loss: 0.5144
===> Epoch[1](452/750): Loss: 0.5165
===> Epoch[1](453/750): Loss: 0.5560
===> Epoch[1](454/750): Loss: 0.5046
===> Epoch[1](455/750): Loss: 0.4924
===> Epoch[1](456/750): Loss: 0.4368
===> Epoch[1](457/750): Loss: 0.4541
===> Epoch[1](458/750): Loss: 0.4449
===> Epoch[1](459/750): Loss: 0.4549
===> Epoch[1](460/750): Loss: 0.4796
===> Epoch[1](461/750): Loss: 0.5078
===> Epoch[1](462/750): Loss: 0.4859
===> Epoch[1](463/750): Loss: 0.4825
===> Epoch[1](464/750): Loss: 0.4393
===> Epoch[1](465/750): Loss: 0.5729
===> Epoch[1](466/750): Loss: 0.6482
===> Epoch[1](467/750): Loss: 0.5738
===> Epoch[1](468/750): Loss: 0.5952
===> Epoch[1](469/750): Loss: 0.5361
===> Epoch[1](470/750): Loss: 0.4699
===> Epoch[1](471/750): Loss: 0.4812
===> Epoch[1](472/750): Loss: 0.4943
===> Epoch[1](473/750): Loss: 0.4336
===> Epoch[1](474/750): Loss: 0.4811
=

===> Epoch[1](670/750): Loss: 0.3977
===> Epoch[1](671/750): Loss: 0.4319
===> Epoch[1](672/750): Loss: 0.3928
===> Epoch[1](673/750): Loss: 0.3952
===> Epoch[1](674/750): Loss: 0.3714
===> Epoch[1](675/750): Loss: 0.4094
===> Epoch[1](676/750): Loss: 0.3642
===> Epoch[1](677/750): Loss: 0.3636
===> Epoch[1](678/750): Loss: 0.3674
===> Epoch[1](679/750): Loss: 0.4174
===> Epoch[1](680/750): Loss: 0.4075
===> Epoch[1](681/750): Loss: 0.4088
===> Epoch[1](682/750): Loss: 0.4681
===> Epoch[1](683/750): Loss: 0.4282
===> Epoch[1](684/750): Loss: 0.4499
===> Epoch[1](685/750): Loss: 0.5140
===> Epoch[1](686/750): Loss: 0.4696
===> Epoch[1](687/750): Loss: 0.4555
===> Epoch[1](688/750): Loss: 0.4216
===> Epoch[1](689/750): Loss: 0.4158
===> Epoch[1](690/750): Loss: 0.4448
===> Epoch[1](691/750): Loss: 0.4188
===> Epoch[1](692/750): Loss: 0.4330
===> Epoch[1](693/750): Loss: 0.4288
===> Epoch[1](694/750): Loss: 0.3979
===> Epoch[1](695/750): Loss: 0.3999
===> Epoch[1](696/750): Loss: 0.3888
=

===> Epoch[2](141/750): Loss: 0.3070
===> Epoch[2](142/750): Loss: 0.3256
===> Epoch[2](143/750): Loss: 0.3759
===> Epoch[2](144/750): Loss: 0.3830
===> Epoch[2](145/750): Loss: 0.3592
===> Epoch[2](146/750): Loss: 0.3317
===> Epoch[2](147/750): Loss: 0.3099
===> Epoch[2](148/750): Loss: 0.3399
===> Epoch[2](149/750): Loss: 0.4108
===> Epoch[2](150/750): Loss: 0.3290
===> Epoch[2](151/750): Loss: 0.3925
===> Epoch[2](152/750): Loss: 0.3416
===> Epoch[2](153/750): Loss: 0.3921
===> Epoch[2](154/750): Loss: 0.3824
===> Epoch[2](155/750): Loss: 0.4077
===> Epoch[2](156/750): Loss: 0.3834
===> Epoch[2](157/750): Loss: 0.3611
===> Epoch[2](158/750): Loss: 0.3834
===> Epoch[2](159/750): Loss: 0.3464
===> Epoch[2](160/750): Loss: 0.3839
===> Epoch[2](161/750): Loss: 0.4414
===> Epoch[2](162/750): Loss: 0.4018
===> Epoch[2](163/750): Loss: 0.4253
===> Epoch[2](164/750): Loss: 0.3469
===> Epoch[2](165/750): Loss: 0.3503
===> Epoch[2](166/750): Loss: 0.3887
===> Epoch[2](167/750): Loss: 0.4198
=

===> Epoch[2](363/750): Loss: 0.3688
===> Epoch[2](364/750): Loss: 0.3198
===> Epoch[2](365/750): Loss: 0.2976
===> Epoch[2](366/750): Loss: 0.3642
===> Epoch[2](367/750): Loss: 0.3333
===> Epoch[2](368/750): Loss: 0.3674
===> Epoch[2](369/750): Loss: 0.3504
===> Epoch[2](370/750): Loss: 0.3188
===> Epoch[2](371/750): Loss: 0.3012
===> Epoch[2](372/750): Loss: 0.3097
===> Epoch[2](373/750): Loss: 0.3460
===> Epoch[2](374/750): Loss: 0.3763
===> Epoch[2](375/750): Loss: 0.3367
===> Epoch[2](376/750): Loss: 0.3419
===> Epoch[2](377/750): Loss: 0.3559
===> Epoch[2](378/750): Loss: 0.3367
===> Epoch[2](379/750): Loss: 0.3927
===> Epoch[2](380/750): Loss: 0.3858
===> Epoch[2](381/750): Loss: 0.3432
===> Epoch[2](382/750): Loss: 0.3362
===> Epoch[2](383/750): Loss: 0.3473
===> Epoch[2](384/750): Loss: 0.3478
===> Epoch[2](385/750): Loss: 0.3474
===> Epoch[2](386/750): Loss: 0.3659
===> Epoch[2](387/750): Loss: 0.3578
===> Epoch[2](388/750): Loss: 0.3551
===> Epoch[2](389/750): Loss: 0.3042
=

===> Epoch[2](585/750): Loss: 0.3372
===> Epoch[2](586/750): Loss: 0.3317
===> Epoch[2](587/750): Loss: 0.3065
===> Epoch[2](588/750): Loss: 0.3532
===> Epoch[2](589/750): Loss: 0.3089
===> Epoch[2](590/750): Loss: 0.3255
===> Epoch[2](591/750): Loss: 0.3110
===> Epoch[2](592/750): Loss: 0.3388
===> Epoch[2](593/750): Loss: 0.3356
===> Epoch[2](594/750): Loss: 0.3285
===> Epoch[2](595/750): Loss: 0.3126
===> Epoch[2](596/750): Loss: 0.3512
===> Epoch[2](597/750): Loss: 0.4078
===> Epoch[2](598/750): Loss: 0.3626
===> Epoch[2](599/750): Loss: 0.4087
===> Epoch[2](600/750): Loss: 0.3811
===> Epoch[2](601/750): Loss: 0.3566
===> Epoch[2](602/750): Loss: 0.3801
===> Epoch[2](603/750): Loss: 0.3613
===> Epoch[2](604/750): Loss: 0.2975
===> Epoch[2](605/750): Loss: 0.3435
===> Epoch[2](606/750): Loss: 0.2804
===> Epoch[2](607/750): Loss: 0.3238
===> Epoch[2](608/750): Loss: 0.3449
===> Epoch[2](609/750): Loss: 0.2978
===> Epoch[2](610/750): Loss: 0.2431
===> Epoch[2](611/750): Loss: 0.3268
=

===> Epoch[3](55/750): Loss: 0.3463
===> Epoch[3](56/750): Loss: 0.3177
===> Epoch[3](57/750): Loss: 0.3673
===> Epoch[3](58/750): Loss: 0.2830
===> Epoch[3](59/750): Loss: 0.2869
===> Epoch[3](60/750): Loss: 0.3424
===> Epoch[3](61/750): Loss: 0.3422
===> Epoch[3](62/750): Loss: 0.3162
===> Epoch[3](63/750): Loss: 0.3293
===> Epoch[3](64/750): Loss: 0.4018
===> Epoch[3](65/750): Loss: 0.3685
===> Epoch[3](66/750): Loss: 0.3287
===> Epoch[3](67/750): Loss: 0.2991
===> Epoch[3](68/750): Loss: 0.3775
===> Epoch[3](69/750): Loss: 0.3356
===> Epoch[3](70/750): Loss: 0.3376
===> Epoch[3](71/750): Loss: 0.2878
===> Epoch[3](72/750): Loss: 0.3201
===> Epoch[3](73/750): Loss: 0.3690
===> Epoch[3](74/750): Loss: 0.2658
===> Epoch[3](75/750): Loss: 0.3080
===> Epoch[3](76/750): Loss: 0.3367
===> Epoch[3](77/750): Loss: 0.3307
===> Epoch[3](78/750): Loss: 0.2929
===> Epoch[3](79/750): Loss: 0.3325
===> Epoch[3](80/750): Loss: 0.3318
===> Epoch[3](81/750): Loss: 0.2823
===> Epoch[3](82/750): Loss:

===> Epoch[3](278/750): Loss: 0.2899
===> Epoch[3](279/750): Loss: 0.3448
===> Epoch[3](280/750): Loss: 0.2842
===> Epoch[3](281/750): Loss: 0.3217
===> Epoch[3](282/750): Loss: 0.3617
===> Epoch[3](283/750): Loss: 0.2878
===> Epoch[3](284/750): Loss: 0.3229
===> Epoch[3](285/750): Loss: 0.2902
===> Epoch[3](286/750): Loss: 0.3576
===> Epoch[3](287/750): Loss: 0.3314
===> Epoch[3](288/750): Loss: 0.2779
===> Epoch[3](289/750): Loss: 0.3574
===> Epoch[3](290/750): Loss: 0.3134
===> Epoch[3](291/750): Loss: 0.3528
===> Epoch[3](292/750): Loss: 0.3744
===> Epoch[3](293/750): Loss: 0.3124
===> Epoch[3](294/750): Loss: 0.3319
===> Epoch[3](295/750): Loss: 0.3070
===> Epoch[3](296/750): Loss: 0.3767
===> Epoch[3](297/750): Loss: 0.3227
===> Epoch[3](298/750): Loss: 0.3546
===> Epoch[3](299/750): Loss: 0.3756
===> Epoch[3](300/750): Loss: 0.3603
===> Epoch[3](301/750): Loss: 0.2773
===> Epoch[3](302/750): Loss: 0.3655
===> Epoch[3](303/750): Loss: 0.3692
===> Epoch[3](304/750): Loss: 0.3257
=

===> Epoch[3](500/750): Loss: 0.3047
===> Epoch[3](501/750): Loss: 0.3196
===> Epoch[3](502/750): Loss: 0.3015
===> Epoch[3](503/750): Loss: 0.3063
===> Epoch[3](504/750): Loss: 0.3447
===> Epoch[3](505/750): Loss: 0.2993
===> Epoch[3](506/750): Loss: 0.2910
===> Epoch[3](507/750): Loss: 0.2667
===> Epoch[3](508/750): Loss: 0.3364
===> Epoch[3](509/750): Loss: 0.2927
===> Epoch[3](510/750): Loss: 0.3249
===> Epoch[3](511/750): Loss: 0.3441
===> Epoch[3](512/750): Loss: 0.3598
===> Epoch[3](513/750): Loss: 0.3067
===> Epoch[3](514/750): Loss: 0.3161
===> Epoch[3](515/750): Loss: 0.3193
===> Epoch[3](516/750): Loss: 0.3234
===> Epoch[3](517/750): Loss: 0.2664
===> Epoch[3](518/750): Loss: 0.2786
===> Epoch[3](519/750): Loss: 0.3017
===> Epoch[3](520/750): Loss: 0.3062
===> Epoch[3](521/750): Loss: 0.3754
===> Epoch[3](522/750): Loss: 0.3043
===> Epoch[3](523/750): Loss: 0.3384
===> Epoch[3](524/750): Loss: 0.3731
===> Epoch[3](525/750): Loss: 0.2856
===> Epoch[3](526/750): Loss: 0.3250
=

===> Epoch[3](722/750): Loss: 0.2762
===> Epoch[3](723/750): Loss: 0.3285
===> Epoch[3](724/750): Loss: 0.3048
===> Epoch[3](725/750): Loss: 0.3092
===> Epoch[3](726/750): Loss: 0.2885
===> Epoch[3](727/750): Loss: 0.3099
===> Epoch[3](728/750): Loss: 0.4110
===> Epoch[3](729/750): Loss: 0.2779
===> Epoch[3](730/750): Loss: 0.3041
===> Epoch[3](731/750): Loss: 0.2869
===> Epoch[3](732/750): Loss: 0.3545
===> Epoch[3](733/750): Loss: 0.3513
===> Epoch[3](734/750): Loss: 0.2970
===> Epoch[3](735/750): Loss: 0.3399
===> Epoch[3](736/750): Loss: 0.2974
===> Epoch[3](737/750): Loss: 0.2883
===> Epoch[3](738/750): Loss: 0.3646
===> Epoch[3](739/750): Loss: 0.3515
===> Epoch[3](740/750): Loss: 0.2842
===> Epoch[3](741/750): Loss: 0.2960
===> Epoch[3](742/750): Loss: 0.3351
===> Epoch[3](743/750): Loss: 0.3256
===> Epoch[3](744/750): Loss: 0.3402
===> Epoch[3](745/750): Loss: 0.3189
===> Epoch[3](746/750): Loss: 0.2962
===> Epoch[3](747/750): Loss: 0.3320
===> Epoch[3](748/750): Loss: 0.3570
=

===> Epoch[4](193/750): Loss: 0.2981
===> Epoch[4](194/750): Loss: 0.2916
===> Epoch[4](195/750): Loss: 0.3056
===> Epoch[4](196/750): Loss: 0.2684
===> Epoch[4](197/750): Loss: 0.3072
===> Epoch[4](198/750): Loss: 0.2816
===> Epoch[4](199/750): Loss: 0.3350
===> Epoch[4](200/750): Loss: 0.3344
===> Epoch[4](201/750): Loss: 0.2725
===> Epoch[4](202/750): Loss: 0.2573
===> Epoch[4](203/750): Loss: 0.3248
===> Epoch[4](204/750): Loss: 0.2998
===> Epoch[4](205/750): Loss: 0.2427
===> Epoch[4](206/750): Loss: 0.3234
===> Epoch[4](207/750): Loss: 0.3552
===> Epoch[4](208/750): Loss: 0.2923
===> Epoch[4](209/750): Loss: 0.3171
===> Epoch[4](210/750): Loss: 0.3532
===> Epoch[4](211/750): Loss: 0.2779
===> Epoch[4](212/750): Loss: 0.3173
===> Epoch[4](213/750): Loss: 0.2719
===> Epoch[4](214/750): Loss: 0.3195
===> Epoch[4](215/750): Loss: 0.3363
===> Epoch[4](216/750): Loss: 0.3220
===> Epoch[4](217/750): Loss: 0.3019
===> Epoch[4](218/750): Loss: 0.2906
===> Epoch[4](219/750): Loss: 0.2852
=