In [1]:
from math import log10
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from super_resolution_data_loader_resize import *
from pytorch_ssim import *

from VDSR_model import Net

torch.manual_seed(1)
device = torch.device("cuda")

In [2]:
# VDSR parameters

batch_size = 10
epochs = 2
lr = 0.01
threads = 4
upscale_factor = 4
clip = 0.4

In [3]:
#img_path_low = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-56/train'
#img_path_ref = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-224/train'

img_path_low = '../dataset/300W-3D-crap-56/train'
img_path_ref = '../dataset/300W-3D-low-res-224/train'

train_set = DatasetSuperRes(img_path_low, img_path_ref)
training_data_loader = DataLoader(dataset=train_set, num_workers=threads, batch_size=batch_size, shuffle=True)

In [4]:
print('===> Building model')
model = Net().to(device)
#criterion = nn.MSELoss()
criterion = nn.MSELoss(reduction='sum')

#optimizer = optim.Adam(model.parameters(), lr=lr)
#optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

===> Building model


In [5]:
out_path = 'results/'
out_model_path = 'models/'

if not os.path.exists(out_path):
    os.makedirs(out_path)    

if not os.path.exists(out_model_path):
    os.makedirs(out_model_path)   
    
results = {'avg_loss': [], 'psnr': [], 'ssim': []}

In [6]:
def train(epoch):
    epoch_loss = 0
    model.train()
    for iteration, batch in enumerate(training_data_loader, 1):
        input_, target = batch[0].to(device), batch[1].to(device)
        
        optimizer.zero_grad()
        upsampled_img = model(input_)
        loss = criterion(upsampled_img, target)
        epoch_loss += loss.item()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(training_data_loader), loss.item()))
    
    scheduler.step() # Decrease learning rate after 10 epochs to 10% of its value
    
    psnr_epoch = 10*log10(1/(epoch_loss / len(training_data_loader)))
    ssim_epoch = ssim(upsampled_img, target).item()
    avg_loss_batch = epoch_loss/len(training_data_loader)
    
    results['psnr'].append(psnr_epoch)
    results['ssim'].append(ssim_epoch)
    results['avg_loss'].append(avg_loss_batch)
    
    print("===> Epoch {} Complete: Avg. Loss: {:.4f} / PSNR: {:.4f} / SSIM {:.4f}".format(epoch, 
                                                                                          avg_loss_batch, 
                                                                                          psnr_epoch,
                                                                                          ssim_epoch))
    if epoch % (epochs // 2) == 0:
    
        data_frame = pd.DataFrame(
                data={'Avg. Loss': results['avg_loss'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))

        data_frame.to_csv(out_path + 'VDSR_x' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')
        
        checkpoint(epoch)
    
def checkpoint(epoch):
    path = out_model_path + "VDSR_x{}_epoch_{}.pth".format(upscale_factor, epoch)
    torch.save(model, path)
    print("Checkpoint saved to {}".format(path))

In [7]:
for epoch in range(1, epochs + 1):
    train(epoch)

===> Epoch[1](1/310): Loss: 1503400.3750
===> Epoch[1](2/310): Loss: 45730035138560.0000
===> Epoch[1](3/310): Loss: 1955.9928
===> Epoch[1](4/310): Loss: 562456.3125
===> Epoch[1](5/310): Loss: 2427.4365
===> Epoch[1](6/310): Loss: 2766.4075
===> Epoch[1](7/310): Loss: 20244.9473
===> Epoch[1](8/310): Loss: 6325.9766
===> Epoch[1](9/310): Loss: 2129.4104
===> Epoch[1](10/310): Loss: 2834.1235
===> Epoch[1](11/310): Loss: 7925.2656
===> Epoch[1](12/310): Loss: 1919.3485
===> Epoch[1](13/310): Loss: 2124.1406
===> Epoch[1](14/310): Loss: 2540.4839
===> Epoch[1](15/310): Loss: 6363.8730
===> Epoch[1](16/310): Loss: 2224.4751
===> Epoch[1](17/310): Loss: 1601.7109
===> Epoch[1](18/310): Loss: 1954.0804
===> Epoch[1](19/310): Loss: 1850.3380
===> Epoch[1](20/310): Loss: 2476.3950
===> Epoch[1](21/310): Loss: 2075.2229
===> Epoch[1](22/310): Loss: 2004.1355
===> Epoch[1](23/310): Loss: 2634.8691
===> Epoch[1](24/310): Loss: 2056.5549
===> Epoch[1](25/310): Loss: 2217.0256
===> Epoch[1](26/3

===> Epoch[1](208/310): Loss: 17346.5020
===> Epoch[1](209/310): Loss: 11300.1982
===> Epoch[1](210/310): Loss: 2182.2686
===> Epoch[1](211/310): Loss: 1870.3076
===> Epoch[1](212/310): Loss: 2098.2563
===> Epoch[1](213/310): Loss: 2506.9155
===> Epoch[1](214/310): Loss: 2093.1335
===> Epoch[1](215/310): Loss: 2104.9294
===> Epoch[1](216/310): Loss: 2372.5283
===> Epoch[1](217/310): Loss: 2052.2358
===> Epoch[1](218/310): Loss: 2190.7812
===> Epoch[1](219/310): Loss: 2225.3677
===> Epoch[1](220/310): Loss: 1996.3370
===> Epoch[1](221/310): Loss: 1979.7551
===> Epoch[1](222/310): Loss: 2547.6755
===> Epoch[1](223/310): Loss: 2485.7202
===> Epoch[1](224/310): Loss: 2225.4087
===> Epoch[1](225/310): Loss: 2498.8235
===> Epoch[1](226/310): Loss: 2301.8977
===> Epoch[1](227/310): Loss: 2222.9768
===> Epoch[1](228/310): Loss: 3494.8232
===> Epoch[1](229/310): Loss: 2702.3657
===> Epoch[1](230/310): Loss: 1814.1924
===> Epoch[1](231/310): Loss: 1383.8340
===> Epoch[1](232/310): Loss: 2022.189

===> Epoch[2](103/310): Loss: 2161.6367
===> Epoch[2](104/310): Loss: 2316.4028
===> Epoch[2](105/310): Loss: 1919.2620
===> Epoch[2](106/310): Loss: 2825.9282
===> Epoch[2](107/310): Loss: 1995.8745
===> Epoch[2](108/310): Loss: 2238.9619
===> Epoch[2](109/310): Loss: 2661.0227
===> Epoch[2](110/310): Loss: 1505.9053
===> Epoch[2](111/310): Loss: 1530.5667
===> Epoch[2](112/310): Loss: 2892.8923
===> Epoch[2](113/310): Loss: 1876.2600
===> Epoch[2](114/310): Loss: 1825.7676
===> Epoch[2](115/310): Loss: 2289.2229
===> Epoch[2](116/310): Loss: 2448.9658
===> Epoch[2](117/310): Loss: 2195.3438
===> Epoch[2](118/310): Loss: 2002.2190
===> Epoch[2](119/310): Loss: 2617.7659
===> Epoch[2](120/310): Loss: 2449.3630
===> Epoch[2](121/310): Loss: 1383.0450
===> Epoch[2](122/310): Loss: 1679.2716
===> Epoch[2](123/310): Loss: 2861.7751
===> Epoch[2](124/310): Loss: 2015.1135
===> Epoch[2](125/310): Loss: 2085.5728
===> Epoch[2](126/310): Loss: 2248.8862
===> Epoch[2](127/310): Loss: 2287.3110


===> Epoch[2](308/310): Loss: 2346.3743
===> Epoch[2](309/310): Loss: 1676.0118
===> Epoch[2](310/310): Loss: 2486.2146
===> Epoch 2 Complete: Avg. Loss: 10405.6834 / PSNR: -40.1727 / SSIM 0.6511
Checkpoint saved to models/VDSR_x4_epoch_2.pth
