In [None]:
from __future__ import print_function
import argparse
from math import log10

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torch.backends.cudnn as cudnn
import csv
import torchmetrics.image as imq
import torchmetrics as tm
from custom_losses import *
from custom_metrics import *
from torch.autograd import Variable
from torch.utils.data import DataLoader
from train_utils import *
from dbpn import BasicNet as BDBPN
from data import get_training_set
from dataset import *
import pdb
import socket
import time

#### Model Configuration Settings

In [None]:
# ---- Hardware / reproducibility ------------------------------------------
gpus = 1                 # number of GPU ids handed to DataParallel
gpu_mode = True          # use CUDA when True (the setup cell raises if no GPU exists)
seed = 123               # seed for the random number generators

# ---- Model ---------------------------------------------------------------
model_type = "DBPN"      # architecture selector; only "DBPN" is built below
upscale_factor = 2       # super-resolution scale factor passed to the network
residual = False         # residual learning: add the bicubic input to the network output

# ---- Checkpoints & metadata ----------------------------------------------
pretrained = False                             # resume from `pretrained_sr` when True
save_folder = 'models/Pretrained/'             # folder checkpoints are read from / written to
pretrained_sr = 'DBPNccp2x-check_epoch_3.pth'  # checkpoint file name inside `save_folder`
meta_folder = 'Meta-folder/'                   # folder for the per-epoch losses/metrics CSV
prefix = 'CCPs2x'                              # tag embedded in checkpoint and CSV file names
snapshots = 5                                  # save a checkpoint every this many epochs

# ---- Optimisation --------------------------------------------------------
start_iter = 1           # first epoch number of the training loop
nEpochs = 100            # last epoch number (inclusive)
lrate = 1e-4             # initial Adam learning rate

In [None]:
# GPU ids for DataParallel; the first id is the device tensors are moved to.
gpus_list = list(range(gpus))
hostname = str(socket.gethostname())
# cuDNN auto-tunes convolution algorithms per input shape. This is faster, but
# it can make runs non-deterministic even with the seeds below.
cudnn.benchmark = True

cuda = gpu_mode
if cuda and not torch.cuda.is_available():
    # This notebook has no --cuda flag; the switch is the `gpu_mode` setting.
    raise RuntimeError("No GPU found, please set gpu_mode=False")

# Seed every RNG the pipeline may draw from. The NumPy seed covers any
# augmentation in the dataset code that uses NumPy's global RNG (not visible here).
torch.manual_seed(seed)
np.random.seed(seed)
if cuda:
    # Seed all visible devices, not only the current one, for multi-GPU runs.
    torch.cuda.manual_seed_all(seed)

## Loading the Training and Validation Dataset

In [None]:
print('===> Loading datasets')

# Root directory holding the training/validation image pairs.
# NOTE(review): absolute, machine-specific path; consider moving it to the config cell.
root = '/BioSR/Training/CCPs/'

# NOTE(review): not passed to either dataset below; kept only in case other code uses the name.
data_transform = transforms.Compose([transforms.ToTensor()])

# Training pairs, reshuffled every epoch.
train_set = ImagePairTrainDataset(root_dir=root, norm_flag=1)
training_data_loader = DataLoader(dataset=train_set, num_workers=1, batch_size=1, shuffle=True)

# Validation pairs. Batch order does not change the epoch averages, so iterate in
# a fixed, reproducible order instead of reshuffling every epoch.
val_set = ImagePairValidationDataset(root_dir=root, norm_flag=1)
val_data_loader = DataLoader(dataset=val_set, num_workers=1, batch_size=1, shuffle=False)

### Model Architecture

In [None]:
print('===> Building model ', model_type)

if model_type == 'DBPN':
    # Network imported as BDBPN (dbpn.BasicNet), configured for single-channel input.
    model = BDBPN(num_channels=1, base_filter=64, feat=256, num_stages=7, scale_factor=upscale_factor)
else:
    # Fail with a clear message. Otherwise `model` is undefined and the
    # DataParallel line below raises a confusing NameError.
    raise ValueError("Unsupported model_type {!r}; only 'DBPN' is implemented".format(model_type))

# DataParallel prefixes state_dict keys with 'module.'.
model = torch.nn.DataParallel(model, device_ids=gpus_list)
# FDL is star-imported (presumably from custom_losses); nn.L1Loss() was the earlier choice.
criterion = FDL()

print('---------- Networks architecture -------------')
#print_network(model)
print('----------------------------------------------')

In [None]:
if cuda:
    model = model.cuda(gpus_list[0])
    criterion = criterion.cuda(gpus_list[0])

optimizer = optim.Adam(model.parameters(), lr=lrate, betas=(0.9, 0.999), eps=1e-8)

if pretrained:
    # Join as two components so this also works if save_folder has no trailing slash.
    model_name = os.path.join(save_folder, pretrained_sr)
    if os.path.exists(model_name):
        # Load onto CPU first; load_state_dict copies values into the already-placed tensors.
        chk = torch.load(model_name, map_location=lambda storage, loc: storage)
        model.load_state_dict(chk['model_state_dict'])
        print('Pre-trained SR model is loaded.')
        # Restores Adam's moment estimates and the saved learning rate (this overrides `lrate`).
        optimizer.load_state_dict(chk['optimizer_state_dict'])
        print('Stored optimizer state is loaded.')
    else:
        # Do not silently start from scratch when a resume was requested.
        print('WARNING: pretrained=True but checkpoint {} was not found; '
              'training from scratch.'.format(model_name))

def checkpoint(epoch):
    """Save the current model and optimizer state for `epoch` under `save_folder`.

    The file is named ``<model_type><prefix>_epoch_<epoch>.pth``. It stores a dict
    with the keys ``'model_state_dict'`` and ``'optimizer_state_dict'``, which the
    resume code reads back.

    Args:
        epoch: epoch number embedded in the file name.
    """
    if save_folder:
        # Create the folder on first use instead of crashing at the end of an epoch.
        os.makedirs(save_folder, exist_ok=True)
    # os.path.join also handles a save_folder without a trailing separator.
    model_out_path = os.path.join(save_folder, model_type + prefix + "_epoch_{}.pth".format(epoch))
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()}, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

def train(epoch):
    """Run one training epoch and print the average loss (training only).

    NOTE(review): the next cell redefines `train` (adding validation), so this
    version is shadowed before the epoch loop calls `train`.

    Args:
        epoch: epoch number; only used in the printed summary.
    """
    epoch_loss = 0.0
    model.train()
    for batch in training_data_loader:
        inputs, target = batch[0], batch[1]
        # batch[2] (the bicubic input) is read only when residual learning is on.
        bicubic = batch[2] if residual else None
        if cuda:
            inputs = inputs.cuda(gpus_list[0])
            target = target.cuda(gpus_list[0])
            if bicubic is not None:
                bicubic = bicubic.cuda(gpus_list[0])

        optimizer.zero_grad()
        prediction = model(inputs)
        if bicubic is not None:
            prediction = prediction + bicubic

        loss = criterion(prediction, target)
        # Accumulate a plain Python float rather than a device tensor.
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

    num_batches = len(training_data_loader)
    # Report NaN instead of raising ZeroDivisionError on an empty loader.
    avg_loss = epoch_loss / num_batches if num_batches else float('nan')
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, avg_loss))

In [None]:
def _unpack_batch(batch):
    """Split a loader batch into (input, target, bicubic) on the training device.

    ``batch[2]`` is read only when residual learning is enabled; otherwise the
    returned ``bicubic`` is None.
    """
    parts = [batch[0], batch[1], batch[2] if residual else None]
    if cuda:
        parts = [p.cuda(gpus_list[0]) if p is not None else None for p in parts]
    return tuple(parts)


def _predict(inputs, bicubic):
    """Forward pass; adds the bicubic input back when residual learning is on."""
    prediction = model(inputs)
    return prediction if bicubic is None else prediction + bicubic


def _mean(values):
    """Arithmetic mean of a list of floats, or NaN for an empty list."""
    return sum(values) / len(values) if values else float('nan')


def _train_epoch():
    """One optimisation pass over `training_data_loader`; returns the mean batch loss."""
    model.train()
    losses = []
    for batch in training_data_loader:
        inputs, target, bicubic = _unpack_batch(batch)
        optimizer.zero_grad()
        loss = criterion(_predict(inputs, bicubic), target)
        loss.backward()
        optimizer.step()
        # Keep Python floats, not device tensors, so the history serialises cleanly.
        losses.append(loss.item())
    return _mean(losses)


def _validate():
    """Evaluate on `val_data_loader` without gradients.

    Returns:
        Mean (loss, MSE, PSNR, SSIM, NRMSE) over validation batches, as floats.
    """
    model.eval()
    # Choose the metric device from `cuda` so validation also works when gpu_mode=False.
    device = torch.device('cuda', gpus_list[0]) if cuda else torch.device('cpu')
    val_loss, mse, psnr, ssim, nrmse = [], [], [], [], []
    with torch.no_grad():
        for batch in val_data_loader:
            inputs, target, bicubic = _unpack_batch(batch)
            prediction = _predict(inputs, bicubic)
            val_loss.append(criterion(prediction, target).item())
            # A fresh metric object per batch keeps every appended value batch-local.
            # NOTE(review): normalisation differs per metric. MSE and SSIM normalise
            # only the prediction, PSNR normalises both, NRMSE neither. Confirm intent.
            mse.append(float(tm.MeanSquaredError().to(device)(percentile_normalize(prediction), target)))
            psnr.append(float(calcPSNR(percentile_normalize(prediction), percentile_normalize(target))))
            ssim.append(float(imq.StructuralSimilarityIndexMeasure().to(device)(percentile_normalize(prediction), target)))
            nrmse.append(float(NormalizedRootMeanSquaredError(normalization='l2').to(device)(prediction, target)))
    return _mean(val_loss), _mean(mse), _mean(psnr), _mean(ssim), _mean(nrmse)


def train(epoch):
    """Train for one epoch, then validate; print and return the epoch summary.

    Args:
        epoch: epoch number; only used in the printed summary.

    Returns:
        Tuple of Python floats ``(train_loss, val_loss, mse, psnr, ssim, nrmse)``.
        Each value is averaged over the batches of its loader (NaN if that loader
        is empty).
    """
    train_loss = _train_epoch()
    val_loss, mse, psnr, ssim, nrmse = _validate()
    # A single contiguous format string. A backslash continuation inside the
    # literal would embed the next line's indentation in the printed message.
    print("===> Epoch {} Complete: Avg. Training Loss: {:.4f}, Avg. Validation Loss: {:.4f}, "
          "Avg. MSE: {:.4f}, Avg. PSNR: {:.4f}, Avg. SSIM: {:.4f}, Avg. NRMSE: {:.4f}"
          .format(epoch, train_loss, val_loss, mse, psnr, ssim, nrmse))
    return train_loss, val_loss, mse, psnr, ssim, nrmse

In [None]:
# Per-epoch history (one entry per epoch), written to CSV by the next cell.
train_losses = []
val_losses = []
mses = []
psnrs = []
msssims = []
nrmses = []

# The learning rate is divided by 10 after each half of the run. Integer
# division (floored at 1) keeps the period a whole number of epochs. The old
# float period `nEpochs/2` was fractional for odd nEpochs and 0.5 for
# nEpochs == 1, which decayed every epoch.
lr_decay_every = max(1, nEpochs // 2)

for epoch in range(start_iter, nEpochs + 1):
    print("Starting Epoch {}".format(epoch))
    epoch_train_loss, epoch_val_loss, epoch_mse, epoch_psnr, epoch_msssim, epoch_nrmse = train(epoch)
    print("Epoch {} Completed".format(epoch))
    train_losses.append(epoch_train_loss)
    val_losses.append(epoch_val_loss)
    mses.append(epoch_mse)
    psnrs.append(epoch_psnr)
    msssims.append(epoch_msssim)
    nrmses.append(epoch_nrmse)

    # `epoch` is 1-based, so test it directly. An `(epoch + 1)` test decays
    # one epoch early (after epochs 49 and 99 for nEpochs == 100).
    if epoch % lr_decay_every == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] /= 10.0
        print('Learning rate decay: lr={}'.format(optimizer.param_groups[0]['lr']))

    # Snapshot every `snapshots` epochs (0 disables periodic snapshots). Always
    # snapshot after the final epoch, which an `(epoch + 1) % snapshots` test
    # misses (e.g. epoch 100 of 100).
    if (snapshots and epoch % snapshots == 0) or epoch == nEpochs:
        checkpoint(epoch)


In [None]:
# Persist the per-epoch history gathered by the training loop.
if meta_folder:
    # Create the folder up front so a missing directory does not lose the results.
    os.makedirs(meta_folder, exist_ok=True)
metrics_path = os.path.join(meta_folder, model_type + prefix + 'losses_metrics.csv')
with open(metrics_path, 'w', newline='') as f:
    writer = csv.writer(f)
    # NOTE(review): the 'MSSSIM' column holds StructuralSimilarityIndexMeasure
    # (single-scale SSIM) values; the header is kept unchanged for existing consumers.
    writer.writerow(['Epoch_Number', 'Training_Loss', 'Validation_Loss', 'MSE', 'PSNR', 'MSSSIM', 'NRMSE'])
    history = zip(train_losses, val_losses, mses, psnrs, msssims, nrmses)
    # Number rows from start_iter so resumed runs keep their real epoch numbers.
    for epoch_number, row in enumerate(history, start=start_iter):
        # float() turns 0-d tensors into plain numbers. Otherwise csv writes their
        # repr, e.g. "tensor(0.0123, device='cuda:0')".
        writer.writerow([epoch_number] + [float(v) for v in row])