In [1]:
from math import log10
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from super_resolution_data_loader_resize import *
from VDSR_model import Net

# Fix the RNG seed so weight initialisation and shuffling are repeatable.
torch.manual_seed(1)
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# VDSR parameters

# --- data loading ---
batch_size = 10        # samples per batch handed to the DataLoader
threads = 4            # DataLoader worker processes (num_workers)

# --- optimisation ---
epochs = 500           # number of passes over the training set
lr = 0.01              # initial learning rate given to the optimizer
clip = 0.4             # max gradient norm used by clip_grad_norm_

# --- model / task ---
upscale_factor = 4     # not referenced by the cells visible here

In [3]:
# Directories holding the training images: low-resolution inputs and the
# reference targets (directory names suggest 56 px and 224 px -- verify against the dataset).
img_path_low = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-56/train'
img_path_ref = '/media/angelo/DATEN/Datasets/Experiment_Masters/300W-3D-low-res-224/train'

# Paired dataset wrapped in a shuffling, multi-worker loader.
train_set = DatasetSuperRes(img_path_low, img_path_ref)
training_data_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=threads,
)

In [4]:
print('===> Building model')

# Instantiate the network and move its parameters to the selected device.
model = Net()
model = model.to(device)

# Squared error summed (not averaged) over every element of the batch.
criterion = nn.MSELoss(reduction='sum')

# SGD with momentum and L2 weight decay (Adam was tried as an alternative).
optimizer = optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=0.9,
    weight_decay=1e-4,
)

# Multiply the learning rate by 0.1 every 10 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

===> Building model


In [5]:
def train(epoch):
    """Run one training epoch over ``training_data_loader`` and print progress.

    For every batch: forward pass, loss, backward pass, gradient-norm
    clipping to ``clip``, optimizer step, and a per-batch loss line.  After
    the last batch the LR scheduler is stepped once and an epoch summary
    (average per-batch loss and PSNR) is printed.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer``,
    ``scheduler``, ``device``, ``clip`` and ``training_data_loader``.

    Args:
        epoch: Epoch number, used only in the printed messages.
    """
    epoch_loss = 0
    # PSNR needs the *per-element* mean squared error.  The criterion's value
    # cannot be used for that when it is sum-reduced, so the squared error and
    # the element count are accumulated separately.
    squared_error_sum = 0.0
    element_count = 0
    num_batches = len(training_data_loader)

    model.train()
    for iteration, batch in enumerate(training_data_loader, 1):
        input_, target = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()
        output = model(input_)
        loss = criterion(output, target)
        batch_loss = loss.item()
        epoch_loss += batch_loss

        # detach(): bookkeeping only, must not become part of the autograd graph.
        squared_error_sum += (output.detach() - target).pow(2).sum().item()
        element_count += target.numel()

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, num_batches, batch_loss))

    scheduler.step()  # advance the LR schedule once per epoch

    # An empty loader has no meaningful averages; report NaN instead of
    # raising ZeroDivisionError.
    avg_loss = epoch_loss / num_batches if num_batches else float('nan')
    if element_count == 0:
        psnr_epoch = float('nan')
    elif squared_error_sum == 0:
        psnr_epoch = float('inf')  # perfect reconstruction
    else:
        mse = squared_error_sum / element_count
        # Peak signal value of 1 assumes images are scaled to [0, 1] -- TODO confirm.
        psnr_epoch = 10 * log10(1 / mse)
    print("===> Epoch {} Complete: Avg. Loss: {:.4f} / PSNR: {:.4f}".format(epoch, avg_loss, psnr_epoch))

def checkpoint(epoch):
    """Save the notebook-level ``model`` to disk and print the output path.

    The file is written to ``trained_models/super_res_model_epoch_<epoch>.pth``
    relative to the current working directory; the directory is created if
    it does not exist yet.

    Args:
        epoch: Epoch number embedded in the checkpoint file name.
    """
    import os  # local import so this cell does not depend on edits to the import cell

    model_out_path = "trained_models/super_res_model_epoch_{}.pth".format(epoch)
    os.makedirs(os.path.dirname(model_out_path), exist_ok=True)
    # Actually write the checkpoint; previously the save was commented out
    # while the message below still claimed success.
    torch.save(model, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

In [6]:
# Train for `epochs` epochs (numbered from 1), checkpointing on every 30th.
for epoch in range(1, epochs + 1):
    train(epoch)
    if epoch % 30:
        continue  # not a checkpoint epoch
    checkpoint(epoch)

===> Epoch[1](1/310): Loss: 1727650.7500
===> Epoch[1](2/310): Loss: 914960.7500
===> Epoch[1](3/310): Loss: 583179.1250
===> Epoch[1](4/310): Loss: 182701.6562
===> Epoch[1](5/310): Loss: 14385.4961
===> Epoch[1](6/310): Loss: 11691.7939
===> Epoch[1](7/310): Loss: 24788.3086
===> Epoch[1](8/310): Loss: 7413.8696
===> Epoch[1](9/310): Loss: 3771.6907
===> Epoch[1](10/310): Loss: 4512.3843
===> Epoch[1](11/310): Loss: 2875.6536
===> Epoch[1](12/310): Loss: 2429.1719
===> Epoch[1](13/310): Loss: 1336.6414
===> Epoch[1](14/310): Loss: 1757.2948
===> Epoch[1](15/310): Loss: 1866.8221
===> Epoch[1](16/310): Loss: 1802.9515
===> Epoch[1](17/310): Loss: 1378.0841
===> Epoch[1](18/310): Loss: 1051.8872
===> Epoch[1](19/310): Loss: 1475.4111
===> Epoch[1](20/310): Loss: 1263.3350
===> Epoch[1](21/310): Loss: 1570.4371
===> Epoch[1](22/310): Loss: 1083.7412
===> Epoch[1](23/310): Loss: 1220.1390
===> Epoch[1](24/310): Loss: 897.1041
===> Epoch[1](25/310): Loss: 701.9407
===> Epoch[1](26/310): L

KeyboardInterrupt: 