In [1]:
from math import log10
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from super_resolution_data_loader_resize import *
from pytorch_ssim import *

from SRCNN_coord_conv import Net

torch.manual_seed(1)
device = torch.device("cuda")

In [2]:
# SRCNN CoordConv parameters

batch_size = 32
epochs = 10
lr = 0.01
threads = 4
upscale_factor = 4

In [3]:
img_path_low = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-56/train'
img_path_ref = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-224/train'

#img_path_low = '../dataset/300W-3D-crap-56/train'
#img_path_ref = '../dataset/300W-3D-low-res-224/train'

train_set = DatasetSuperRes(img_path_low, img_path_ref)
training_data_loader = DataLoader(dataset=train_set, num_workers=threads, batch_size=batch_size, shuffle=True)

In [4]:
print('===> Building model')
model = Net().to(device)
criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)

===> Building model


In [5]:
out_path = 'results/'
out_model_path = 'models/'

if not os.path.exists(out_path):
    os.makedirs(out_path)    

if not os.path.exists(out_model_path):
    os.makedirs(out_model_path)   
    
results = {'avg_loss': [], 'psnr': [], 'ssim': []}

In [6]:
def train(epoch):
    epoch_loss = 0
    model.train()
    for iteration, batch in enumerate(training_data_loader, 1):
        input_, target = batch[0].to(device), batch[1].to(device)
        
        optimizer.zero_grad()
        upsampled_img = model(input_)
        loss = criterion(upsampled_img, target)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

        print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(training_data_loader), loss.item()))
    
    scheduler.step() # Decrease learning rate after 100 epochs to 10% of its value
    
    psnr_epoch = 10*log10(1/(epoch_loss / len(training_data_loader)))
    ssim_epoch = ssim(upsampled_img, target).item()
    avg_loss_batch = epoch_loss/len(training_data_loader)
    
    results['psnr'].append(psnr_epoch)
    results['ssim'].append(ssim_epoch)
    results['avg_loss'].append(avg_loss_batch)
    
    print("===> Epoch {} Complete: Avg. Loss: {:.4f} / PSNR: {:.4f} / SSIM {:.4f}".format(epoch, 
                                                                                          avg_loss_batch, 
                                                                                          psnr_epoch,
                                                                                          ssim_epoch))
    if epoch % (epochs // 2) == 0:
    
        data_frame = pd.DataFrame(
                data={'Avg. Loss': results['avg_loss'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))

        data_frame.to_csv(out_path + 'SRCNN_coordx' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')
        
        checkpoint(epoch)
    
def checkpoint(epoch):
    path = out_model_path + "SRCNN_coord_x{}_epoch_{}.pth".format(upscale_factor, epoch)
    torch.save(model, path)
    print("Checkpoint saved to {}".format(path))

In [7]:
for epoch in range(1, epochs + 1):
    train(epoch)

===> Epoch[1](1/97): Loss: 0.0996
===> Epoch[1](2/97): Loss: 12.3834
===> Epoch[1](3/97): Loss: 0.4383
===> Epoch[1](4/97): Loss: 0.1461
===> Epoch[1](5/97): Loss: 0.1600
===> Epoch[1](6/97): Loss: 0.1464
===> Epoch[1](7/97): Loss: 0.0857
===> Epoch[1](8/97): Loss: 0.0660
===> Epoch[1](9/97): Loss: 0.0578
===> Epoch[1](10/97): Loss: 0.0711
===> Epoch[1](11/97): Loss: 0.0487
===> Epoch[1](12/97): Loss: 0.0333
===> Epoch[1](13/97): Loss: 0.0375
===> Epoch[1](14/97): Loss: 0.0462
===> Epoch[1](15/97): Loss: 0.0462
===> Epoch[1](16/97): Loss: 0.0367
===> Epoch[1](17/97): Loss: 0.0208
===> Epoch[1](18/97): Loss: 0.0269
===> Epoch[1](19/97): Loss: 0.0283
===> Epoch[1](20/97): Loss: 0.0332
===> Epoch[1](21/97): Loss: 0.0272
===> Epoch[1](22/97): Loss: 0.0242
===> Epoch[1](23/97): Loss: 0.0191
===> Epoch[1](24/97): Loss: 0.0193
===> Epoch[1](25/97): Loss: 0.0199
===> Epoch[1](26/97): Loss: 0.0173
===> Epoch[1](27/97): Loss: 0.0175
===> Epoch[1](28/97): Loss: 0.0140
===> Epoch[1](29/97): Loss: 

===> Epoch[3](38/97): Loss: 0.0034
===> Epoch[3](39/97): Loss: 0.0035
===> Epoch[3](40/97): Loss: 0.0031
===> Epoch[3](41/97): Loss: 0.0037
===> Epoch[3](42/97): Loss: 0.0033
===> Epoch[3](43/97): Loss: 0.0039
===> Epoch[3](44/97): Loss: 0.0033
===> Epoch[3](45/97): Loss: 0.0029
===> Epoch[3](46/97): Loss: 0.0038
===> Epoch[3](47/97): Loss: 0.0037
===> Epoch[3](48/97): Loss: 0.0035
===> Epoch[3](49/97): Loss: 0.0032
===> Epoch[3](50/97): Loss: 0.0031
===> Epoch[3](51/97): Loss: 0.0030
===> Epoch[3](52/97): Loss: 0.0032
===> Epoch[3](53/97): Loss: 0.0037
===> Epoch[3](54/97): Loss: 0.0035
===> Epoch[3](55/97): Loss: 0.0032
===> Epoch[3](56/97): Loss: 0.0031
===> Epoch[3](57/97): Loss: 0.0034
===> Epoch[3](58/97): Loss: 0.0027
===> Epoch[3](59/97): Loss: 0.0031
===> Epoch[3](60/97): Loss: 0.0034
===> Epoch[3](61/97): Loss: 0.0029
===> Epoch[3](62/97): Loss: 0.0033
===> Epoch[3](63/97): Loss: 0.0035
===> Epoch[3](64/97): Loss: 0.0045
===> Epoch[3](65/97): Loss: 0.0050
===> Epoch[3](66/97)

===> Epoch[5](75/97): Loss: 0.0022
===> Epoch[5](76/97): Loss: 0.0026
===> Epoch[5](77/97): Loss: 0.0023
===> Epoch[5](78/97): Loss: 0.0022
===> Epoch[5](79/97): Loss: 0.0024
===> Epoch[5](80/97): Loss: 0.0022
===> Epoch[5](81/97): Loss: 0.0025
===> Epoch[5](82/97): Loss: 0.0025
===> Epoch[5](83/97): Loss: 0.0025
===> Epoch[5](84/97): Loss: 0.0028
===> Epoch[5](85/97): Loss: 0.0022
===> Epoch[5](86/97): Loss: 0.0026
===> Epoch[5](87/97): Loss: 0.0025
===> Epoch[5](88/97): Loss: 0.0022
===> Epoch[5](89/97): Loss: 0.0025
===> Epoch[5](90/97): Loss: 0.0027
===> Epoch[5](91/97): Loss: 0.0028
===> Epoch[5](92/97): Loss: 0.0022
===> Epoch[5](93/97): Loss: 0.0024
===> Epoch[5](94/97): Loss: 0.0019
===> Epoch[5](95/97): Loss: 0.0022
===> Epoch[5](96/97): Loss: 0.0021
===> Epoch[5](97/97): Loss: 0.0024
===> Epoch 5 Complete: Avg. Loss: 0.0024 / PSNR: 26.1546 / SSIM 0.7505
Checkpoint saved to models/SRCNN_coord_x4_epoch_5.pth
===> Epoch[6](1/97): Loss: 0.0022
===> Epoch[6](2/97): Loss: 0.0024
==

===> Epoch[8](12/97): Loss: 0.0027
===> Epoch[8](13/97): Loss: 0.0023
===> Epoch[8](14/97): Loss: 0.0025
===> Epoch[8](15/97): Loss: 0.0023
===> Epoch[8](16/97): Loss: 0.0022
===> Epoch[8](17/97): Loss: 0.0025
===> Epoch[8](18/97): Loss: 0.0021
===> Epoch[8](19/97): Loss: 0.0020
===> Epoch[8](20/97): Loss: 0.0024
===> Epoch[8](21/97): Loss: 0.0028
===> Epoch[8](22/97): Loss: 0.0026
===> Epoch[8](23/97): Loss: 0.0023
===> Epoch[8](24/97): Loss: 0.0019
===> Epoch[8](25/97): Loss: 0.0026
===> Epoch[8](26/97): Loss: 0.0022
===> Epoch[8](27/97): Loss: 0.0021
===> Epoch[8](28/97): Loss: 0.0023
===> Epoch[8](29/97): Loss: 0.0025
===> Epoch[8](30/97): Loss: 0.0025
===> Epoch[8](31/97): Loss: 0.0022
===> Epoch[8](32/97): Loss: 0.0019
===> Epoch[8](33/97): Loss: 0.0026
===> Epoch[8](34/97): Loss: 0.0024
===> Epoch[8](35/97): Loss: 0.0023
===> Epoch[8](36/97): Loss: 0.0023
===> Epoch[8](37/97): Loss: 0.0022
===> Epoch[8](38/97): Loss: 0.0021
===> Epoch[8](39/97): Loss: 0.0028
===> Epoch[8](40/97)

===> Epoch[10](48/97): Loss: 0.0028
===> Epoch[10](49/97): Loss: 0.0024
===> Epoch[10](50/97): Loss: 0.0025
===> Epoch[10](51/97): Loss: 0.0022
===> Epoch[10](52/97): Loss: 0.0021
===> Epoch[10](53/97): Loss: 0.0025
===> Epoch[10](54/97): Loss: 0.0027
===> Epoch[10](55/97): Loss: 0.0023
===> Epoch[10](56/97): Loss: 0.0021
===> Epoch[10](57/97): Loss: 0.0024
===> Epoch[10](58/97): Loss: 0.0024
===> Epoch[10](59/97): Loss: 0.0025
===> Epoch[10](60/97): Loss: 0.0023
===> Epoch[10](61/97): Loss: 0.0024
===> Epoch[10](62/97): Loss: 0.0023
===> Epoch[10](63/97): Loss: 0.0026
===> Epoch[10](64/97): Loss: 0.0023
===> Epoch[10](65/97): Loss: 0.0022
===> Epoch[10](66/97): Loss: 0.0027
===> Epoch[10](67/97): Loss: 0.0023
===> Epoch[10](68/97): Loss: 0.0024
===> Epoch[10](69/97): Loss: 0.0020
===> Epoch[10](70/97): Loss: 0.0022
===> Epoch[10](71/97): Loss: 0.0023
===> Epoch[10](72/97): Loss: 0.0023
===> Epoch[10](73/97): Loss: 0.0022
===> Epoch[10](74/97): Loss: 0.0025
===> Epoch[10](75/97): Loss: