In [2]:
import os
from math import log10

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from model_super_resolution import Net
from super_resolution_data_loader import *

# Seed PyTorch's global RNG so repeated runs start from the same random state.
torch.manual_seed(1)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without a usable CUDA device instead of failing at the first .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
# SubPixelCNN parameters

# Network: factor handed to Net(upscale_factor=...).
upscale_factor = 4

# Optimisation: number of passes of the training loop and the learning rate given to Adam.
epochs, lr = 500, 0.01

# Data loading: samples per batch and DataLoader worker count (num_workers).
batch_size, threads = 64, 4

In [12]:
# Image directories handed to DatasetSuperRes.
# NOTE(review): presumably the first holds the low-resolution inputs and the
# second the reference targets - confirm against DatasetSuperRes.
img_path_ref = 'E:\\Datasets\\3DFaces\\300W-3D-ALL\\images-128'
img_path_low = 'E:\\Datasets\\3DFaces\\300W-3D-ALL\\images-low-res'

train_set = DatasetSuperRes(img_path_low, img_path_ref)

# Shuffled mini-batches, loaded by `threads` worker processes.
training_data_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=threads,
)

In [13]:
print('===> Building model')

# Network, moved onto the selected device.
model = Net(upscale_factor=upscale_factor)
model = model.to(device)

# Mean-squared-error training objective.
criterion = nn.MSELoss()

# Adam over all model parameters; StepLR multiplies the learning rate by 0.1
# every 100 scheduler steps.
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=100)

===> Building model


In [14]:
def train(epoch):
    """Run one training epoch over ``training_data_loader`` and print a summary.

    Works on the notebook-level ``model``, ``criterion``, ``optimizer``,
    ``scheduler`` and ``device``. For every batch it moves the two tensors to
    ``device``, computes the loss, back-propagates and takes an optimizer step.
    After the last batch it advances the scheduler once and prints the mean
    batch loss together with ``10 * log10(1 / mean_loss)``.

    Args:
        epoch: Epoch number; used only in the printed summary line.

    Raises:
        ValueError: If ``training_data_loader`` yields no batches.
    """
    # Explicitly select training mode so an earlier model.eval() in the kernel
    # cannot leak into training (relevant if Net has mode-dependent layers).
    model.train()

    epoch_loss = 0.0
    num_batches = 0
    for batch in training_data_loader:
        input_, target = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()
        loss = criterion(model(input_), target)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
        num_batches += 1

    if num_batches == 0:
        # Previously surfaced as an opaque ZeroDivisionError in the PSNR line.
        raise ValueError("training_data_loader yielded no batches; nothing to train on")

    scheduler.step()  # One scheduler tick per epoch (StepLR decays the learning rate)

    avg_loss = epoch_loss / num_batches
    # NOTE(review): this PSNR form assumes a peak signal value of 1 - confirm
    # against the data loader. A zero loss is reported as infinite PSNR rather
    # than raising ZeroDivisionError.
    psnr_epoch = float("inf") if avg_loss == 0 else 10 * log10(1 / avg_loss)
    print("===> Epoch {} Complete: Avg. Loss: {:.4f} / PSNR: {:.4f}".format(epoch, avg_loss, psnr_epoch))

# Disabled evaluation pass: the whole function sits inside a bare string
# literal, so it is never defined or executed.
# NOTE(review): it reads `testing_data_loader`, which is not defined in the
# visible cells - build a test DataLoader before re-enabling this.
'''
def test():
    avg_psnr = 0
    with torch.no_grad():
        for batch in testing_data_loader:
            input, target = batch[0].to(device), batch[1].to(device)

            prediction = model(input)
            mse = criterion(prediction, target)
            psnr = 10 * log10(1 / mse.item())
            avg_psnr += psnr
    print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(testing_data_loader)))
'''

def checkpoint(epoch):
    """Save the notebook-level ``model`` to ``trained_models/`` and print the path.

    The complete module object is passed to ``torch.save`` (not only its
    ``state_dict``); the on-disk format is kept as-is so existing code that
    loads these files keeps working.

    Args:
        epoch: Epoch number embedded in the output file name.
    """
    model_out_path = "trained_models/super_res_model_epoch_{}.pth".format(epoch)
    # torch.save does not create missing directories; without this the first
    # checkpoint fails on a fresh checkout, after the training epochs already ran.
    os.makedirs(os.path.dirname(model_out_path), exist_ok=True)
    torch.save(model, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

In [15]:
# Save a checkpoint every `checkpoint_every` epochs.
checkpoint_every = 30

for epoch in range(1, epochs + 1):
    train(epoch)
    # Evaluation (test()) is currently disabled.
    # Also checkpoint after the final epoch: when `epochs` is not a multiple of
    # `checkpoint_every` (e.g. 500 vs 30) the last epochs would otherwise never be saved.
    if epoch % checkpoint_every == 0 or epoch == epochs:
        checkpoint(epoch)

===> Epoch 1 Complete: Avg. Loss: 0.6513 / PSNR: 1.8624
===> Epoch 2 Complete: Avg. Loss: 0.0056 / PSNR: 22.4925
===> Epoch 3 Complete: Avg. Loss: 0.0038 / PSNR: 24.1902
===> Epoch 4 Complete: Avg. Loss: 0.0030 / PSNR: 25.2540
===> Epoch 5 Complete: Avg. Loss: 0.0023 / PSNR: 26.4751
===> Epoch 6 Complete: Avg. Loss: 0.0020 / PSNR: 27.0674
===> Epoch 7 Complete: Avg. Loss: 0.0017 / PSNR: 27.6386
===> Epoch 8 Complete: Avg. Loss: 0.0016 / PSNR: 27.8619
===> Epoch 9 Complete: Avg. Loss: 0.0015 / PSNR: 28.1679
===> Epoch 10 Complete: Avg. Loss: 0.0015 / PSNR: 28.3521
===> Epoch 11 Complete: Avg. Loss: 0.0014 / PSNR: 28.5427
===> Epoch 12 Complete: Avg. Loss: 0.0013 / PSNR: 28.7497
===> Epoch 13 Complete: Avg. Loss: 0.0014 / PSNR: 28.6003
===> Epoch 14 Complete: Avg. Loss: 0.0011 / PSNR: 29.4441
===> Epoch 15 Complete: Avg. Loss: 0.0033 / PSNR: 24.7815
===> Epoch 16 Complete: Avg. Loss: 0.0017 / PSNR: 27.6110
===> Epoch 17 Complete: Avg. Loss: 0.0012 / PSNR: 29.2190
===> Epoch 18 Complete: 