In [1]:
from math import log10
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset.data_loader_YCbCr import *
from SUB_CNN_model import Net
from utils.pytorch_ssim import *

# Seed PyTorch's global RNG so repeated runs start from the same random state.
torch.manual_seed(1)

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing at
# the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# SubPixelCNN parameters

# --- model ---
upscale_factor = 4   # forwarded to Net(upscale_factor=...)

# --- optimisation ---
epochs = 10          # number of passes made by the training loop
lr = 0.01            # learning rate handed to the Adam optimizer
batch_size = 64      # samples per batch drawn by the DataLoader

# --- data loading ---
threads = 4          # DataLoader worker processes (num_workers)

In [None]:
# Earlier dataset locations, kept for reference:
#img_path_low = '../media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-56/train'
#img_path_ref = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-224/train'

# Low-resolution inputs and the reference images they are compared against.
# NOTE(review): the folder names suggest 56px inputs and 224px references,
# which would match upscale_factor = 4 — confirm against the dataset.
img_path_low = '../dataset/300W-3D-crap-56/train'
img_path_ref = '../dataset/300W-3D-low-res-224/train'

# Pair the two folders into one dataset and serve it in shuffled mini-batches.
train_set = DatasetSuperRes(img_path_low, img_path_ref)
training_data_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=threads,
)

In [4]:
print('===> Building model')

# Instantiate the network, move it to the selected device and train it
# against a pixel-wise mean-squared-error objective.
model = Net(upscale_factor=upscale_factor)
model = model.to(device)
criterion = nn.MSELoss()

# Adam optimizer with a step-wise learning-rate decay (x0.1 every 50 scheduler steps).
# NOTE(review): train() steps the scheduler once per epoch, so with
# step_size=50 and epochs = 10 the learning rate is never reduced during this
# run — confirm that this is intended.
optimizer = optim.Adam(params=model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=50)

===> Building model


In [6]:
# Folders for the per-epoch metrics CSV and for the model checkpoints.
out_path = 'results/'
out_model_path = 'models/'

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create
# race of `if not os.path.exists(...)`; it also raises if the path exists but
# is a regular file instead of silently continuing.
os.makedirs(out_path, exist_ok=True)
os.makedirs(out_model_path, exist_ok=True)

# Per-epoch training metrics; train() appends one value to each list per epoch.
results = {'avg_loss': [], 'psnr': [], 'ssim': []}

In [9]:
def train(epoch):
    """Run one training epoch and record its metrics in the global ``results``.

    Iterates once over ``training_data_loader``, optimising ``model`` against
    ``criterion``, then steps ``scheduler`` once. Appends to ``results`` the
    average batch loss, a PSNR derived from that average loss, and the SSIM of
    the *last* batch of the epoch. Every ``epochs // 2`` epochs (at least every
    epoch), and always on the final epoch, the accumulated metrics are written
    to a CSV file in ``out_path`` and ``checkpoint(epoch)`` is called.

    Args:
        epoch: 1-based epoch number; used for logging, for the save schedule
            and for the checkpoint file name.

    Raises:
        ValueError: if ``training_data_loader`` contains no batches.
    """
    num_batches = len(training_data_loader)
    if num_batches == 0:
        # Fail early with a clear message instead of a ZeroDivisionError /
        # unbound `upsampled_img` further down.
        raise ValueError("training_data_loader is empty; nothing to train on")

    epoch_loss = 0.0
    model.train()
    for batch in training_data_loader:
        input_, target = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()
        upsampled_img = model(input_)
        loss = criterion(upsampled_img, target)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

    # Advance the learning-rate schedule once per epoch; the decay interval and
    # factor are whatever the scheduler was constructed with.
    scheduler.step()

    avg_loss_batch = epoch_loss / num_batches
    # PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1.
    # NOTE(review): assumes images are scaled to [0, 1] — confirm in the dataset.
    # A zero loss would divide by zero, so report an infinite PSNR instead.
    psnr_epoch = 10 * log10(1 / avg_loss_batch) if avg_loss_batch > 0 else float('inf')
    # SSIM is measured on the last batch only; detach so the metric does not
    # build an autograd graph.
    ssim_epoch = ssim(upsampled_img.detach(), target).item()

    results['psnr'].append(psnr_epoch)
    results['ssim'].append(ssim_epoch)
    results['avg_loss'].append(avg_loss_batch)

    print("===> Epoch {} Complete: Avg. Loss: {:.4f} / PSNR: {:.4f} / SSIM {:.4f}".format(epoch,
                                                                                          avg_loss_batch,
                                                                                          psnr_epoch,
                                                                                          ssim_epoch))

    # `epochs // 2` is 0 when epochs == 1, which would make the modulo raise;
    # clamp it to 1. Also save on the last epoch so that an odd `epochs` still
    # ends with a checkpoint of the final model.
    save_every = max(1, epochs // 2)
    if epoch % save_every == 0 or epoch == epochs:
        # Index by the number of recorded epochs rather than by `epoch`, so a
        # re-run that did not reset `results` cannot trigger a length mismatch.
        data_frame = pd.DataFrame(
                data={'Avg. Loss': results['avg_loss'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, len(results['avg_loss']) + 1))

        data_frame.to_csv(out_path + 'SubCNN_x' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')

        checkpoint(epoch)

def checkpoint(epoch):
    """Save the current ``model`` to ``out_model_path`` and report the file name.

    The file is named after the upscale factor and the given epoch number.
    ``torch.save`` is given the model object itself, not its ``state_dict``.

    Args:
        epoch: epoch number embedded in the checkpoint file name.
    """
    file_name = "SubCNN_x{}_epoch_{}.pth".format(upscale_factor, epoch)
    path = out_model_path + file_name
    torch.save(model, path)
    print("Checkpoint saved to {}".format(path))

In [10]:
# Train for `epochs` epochs, numbered from 1 so that the log lines and the
# checkpoint file names written by train() are 1-based.
for epoch in range(1, epochs + 1):
    train(epoch)

===> Epoch 1 Complete: Avg. Loss: 0.0064 / PSNR: 21.9451 / SSIM 0.5926
===> Epoch 2 Complete: Avg. Loss: 0.0051 / PSNR: 22.9294 / SSIM 0.6458
===> Epoch 3 Complete: Avg. Loss: 0.0045 / PSNR: 23.4286 / SSIM 0.6882
===> Epoch 4 Complete: Avg. Loss: 0.0042 / PSNR: 23.7899 / SSIM 0.6789
===> Epoch 5 Complete: Avg. Loss: 0.0041 / PSNR: 23.9181 / SSIM 0.6945
Checkpoint saved to models/SubCNN_x4_epoch_5.pth
===> Epoch 6 Complete: Avg. Loss: 0.0040 / PSNR: 23.9775 / SSIM 0.7059
===> Epoch 7 Complete: Avg. Loss: 0.0040 / PSNR: 24.0232 / SSIM 0.7228
===> Epoch 8 Complete: Avg. Loss: 0.0041 / PSNR: 23.9042 / SSIM 0.6912
===> Epoch 9 Complete: Avg. Loss: 0.0038 / PSNR: 24.1668 / SSIM 0.7020
===> Epoch 10 Complete: Avg. Loss: 0.0038 / PSNR: 24.1591 / SSIM 0.7052
Checkpoint saved to models/SubCNN_x4_epoch_10.pth
