In [1]:
from math import log10
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset.data_loader_RGB import *
from utils.pytorch_ssim import *

from models.FSRCNN_model import Net

torch.manual_seed(1)
device = torch.device("cuda")

In [2]:
# FSRCNN parameters

batch_size = 32
epochs = 50
lr = 0.001
threads = 4
upscale_factor = 4

In [6]:
#img_path_low = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-56/train'
#img_path_ref = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-224/train'

#img_path_low = '/media/angelo/DATEN/Datasets/CelebA/LR_112/train/'
#img_path_ref = '/media/angelo/DATEN/Datasets/CelebA/HR/train/'

img_path_low = '../LR_56/train/'
img_path_ref = '../HR/train/'

train_set = DatasetSuperRes(img_path_low, img_path_ref)
training_data_loader = DataLoader(dataset=train_set, num_workers=threads, batch_size=batch_size, shuffle=True)

In [7]:
print('===> Building model')
model = Net().to(device)
model.weight_init(mean=0.0, std=0.2)

criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.2)

===> Building model


In [8]:
out_path = 'results/'
out_model_path = 'checkpoints/'

if not os.path.exists(out_path):
    os.makedirs(out_path)    

if not os.path.exists(out_model_path):
    os.makedirs(out_model_path)   
    
results = {'avg_loss': [], 'psnr': [], 'ssim': []}

In [9]:
def train(epoch):
    epoch_loss = 0
    model.train()
    for iteration, batch in enumerate(training_data_loader, 1):
        input_, target = batch[0].to(device), batch[1].to(device)
        
        optimizer.zero_grad()
        upsampled_img = model(input_)
        loss = criterion(upsampled_img, target)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
        
        #print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(training_data_loader), loss.item()))
    
    scheduler.step() # Decrease learning rate after 15 epochs to 20% of its value
    
    psnr_epoch = 10*log10(1/(epoch_loss / len(training_data_loader)))
    ssim_epoch = ssim(upsampled_img, target).item()
    avg_loss_batch = epoch_loss/len(training_data_loader)
    
    results['psnr'].append(psnr_epoch)
    results['ssim'].append(ssim_epoch)
    results['avg_loss'].append(avg_loss_batch)
    
    print("===> Epoch {} Complete: Avg. Loss: {:.4f} / PSNR: {:.4f} / SSIM {:.4f}".format(epoch, 
                                                                                          avg_loss_batch, 
                                                                                          psnr_epoch,
                                                                                          ssim_epoch))
    if epoch % (epochs // 10) == 0:
    
        data_frame = pd.DataFrame(
                data={'Avg. Loss': results['avg_loss'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))

        data_frame.to_csv(out_path + 'FSRCNN_x' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')
        
        checkpoint(epoch)
    
def checkpoint(epoch):
    path = out_model_path + "FSRCNN_x{}_epoch_{}.pth".format(upscale_factor, epoch)
    torch.save(model, path)
    print("Checkpoint saved to {}".format(path))

In [10]:
#epochs=10
#optimizer.param_groups[0]['lr'] = 0.001

for epoch in range(1, epochs+1):
    train(epoch)

===> Epoch 1 Complete: Avg. Loss: 0.0143 / PSNR: 18.4558 / SSIM 0.7331
===> Epoch 2 Complete: Avg. Loss: 0.0051 / PSNR: 22.9456 / SSIM 0.7703
===> Epoch 3 Complete: Avg. Loss: 0.0037 / PSNR: 24.3568 / SSIM 0.7717
===> Epoch 4 Complete: Avg. Loss: 0.0034 / PSNR: 24.6761 / SSIM 0.7860
===> Epoch 5 Complete: Avg. Loss: 0.0032 / PSNR: 24.9615 / SSIM 0.7678
Checkpoint saved to checkpoints/FSRCNN_x2_epoch_5.pth
===> Epoch 6 Complete: Avg. Loss: 0.0031 / PSNR: 25.0732 / SSIM 0.7891
===> Epoch 7 Complete: Avg. Loss: 0.0030 / PSNR: 25.1762 / SSIM 0.7911
===> Epoch 8 Complete: Avg. Loss: 0.0029 / PSNR: 25.3187 / SSIM 0.7882
===> Epoch 9 Complete: Avg. Loss: 0.0029 / PSNR: 25.3946 / SSIM 0.7799
===> Epoch 10 Complete: Avg. Loss: 0.0028 / PSNR: 25.4616 / SSIM 0.7845
Checkpoint saved to checkpoints/FSRCNN_x2_epoch_10.pth
===> Epoch 11 Complete: Avg. Loss: 0.0028 / PSNR: 25.4943 / SSIM 0.7585
===> Epoch 12 Complete: Avg. Loss: 0.0027 / PSNR: 25.7472 / SSIM 0.8073
===> Epoch 13 Complete: Avg. Loss: 0