## Required functions
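Run this cell first: it imports the dependencies and defines the data-loading, training, evaluation and visualization helpers.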

In [None]:
import os, shutil, glob, warnings, os
import argparse
import sys
from datetime import datetime

import matplotlib.pyplot as plt
from torchvision import transforms, utils
import torchvision
from skimage import exposure, color, io, img_as_float, img_as_ubyte
from skimage.util import view_as_windows
from PIL import Image, ImageFilter

import pandas as pd
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from collections import OrderedDict

import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.optim.lr_scheduler as lr_scheduler

import datasets_sisr as data
import srmodels as models

def data_loader(model_config, csv='train'):
    """Build a DataLoader over the paired input/output dataset for one split.

    Args:
        model_config: configuration dict; the "resolution", "batch-size" and
            "threads" keys are read.
        csv: split to load, either 'train' or 'valid'. It is passed to
            ``data.compress_csv_path`` to locate that split's CSV file.

    Returns:
        A ``torch.utils.data.DataLoader``. The training split is shuffled;
        the validation split is iterated in a fixed order.

    Raises:
        ValueError: if ``csv`` is neither 'train' nor 'valid'.
    """
    # Reject unknown splits up front rather than failing later with an
    # UnboundLocalError on an unassigned loader.
    if csv not in ('train', 'valid'):
        raise ValueError("csv must be 'train' or 'valid', got {!r}".format(csv))
    transformed_dataset = data.Pair_Dataset(
        csv_file=data.compress_csv_path(csv),
        transform=data.Compose([
            data.Recale(model_config["resolution"]),
            data.ToTensor(),
        ]),
    )
    return DataLoader(
        transformed_dataset,
        batch_size=model_config["batch-size"],
        # Shuffle only for training so validation order is deterministic.
        shuffle=(csv == 'train'),
        num_workers=model_config["threads"],
    )

def train(model_config, epoch, run, generator, dataloader, criterion, optimizer):
    """Train ``generator`` for one epoch, then checkpoint its weights.

    Args:
        model_config: configuration dict; only "epochs" is read (for the
            progress line).
        epoch: index of the current epoch, shown in progress output.
        run: run name; weights are written to ``weights/<run>/generator.pth``.
        generator: model being trained. Batches are moved to the device of
            its first parameter.
        dataloader: sized iterable of batches, each a mapping with 'input'
            and 'output' tensors.
        criterion: loss called as ``criterion(prediction, target)``.
        optimizer: optimizer over ``generator``'s parameters.

    Returns:
        Mean per-batch loss for the epoch, or NaN if ``dataloader`` is empty.
    """
    generator.train()
    device = next(generator.parameters()).device
    num_batches = len(dataloader)
    epoch_loss = 0.0
    for iteration, batch in enumerate(dataloader):
        # Plain tensors do not require grad by default, so the deprecated
        # ``Variable(..., requires_grad=False)`` wrapper is unnecessary.
        img_input = batch['input'].float().to(device)
        img_target = batch['output'].float().to(device)
        optimizer.zero_grad()
        img_output = generator(img_input)
        loss = criterion(img_output, img_target)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        epoch_loss += batch_loss

        sys.stdout.write('\r[%d/%d][%d/%d] Generator_L1_Loss: %.4f'
                         % (epoch, model_config["epochs"], iteration, num_batches, batch_loss))
        # The line is redrawn in place with '\r' and ends without a newline,
        # so flush explicitly or the progress may not appear.
        sys.stdout.flush()
    # Avoid ZeroDivisionError for an empty loader.
    avg_loss = epoch_loss / num_batches if num_batches else float('nan')
    print("\n ===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, avg_loss))

    run_dir = os.path.join('weights', run)
    os.makedirs(run_dir, exist_ok=True)
    torch.save(generator.state_dict(), os.path.join(run_dir, 'generator.pth'))
    return avg_loss

def test(generator, dataloader, criterion):
    device = next(generator.parameters()).device
    with torch.no_grad():
        epoch_loss = 0
        generator.eval()
        for iteration, batch in enumerate(dataloader):
            img_input = Variable(batch['input'].float().to(device), requires_grad=False)
            img_target = Variable(batch['output'].float().to(device), requires_grad=False)
            img_output = generator(img_input)
            epoch_loss = epoch_loss + criterion(img_output, img_target).item()
        return epoch_loss / len(dataloader)
    
def _show_grid(images, path):
    """Tile a batch of images into one grid, save it to ``path`` and plot it in a new figure."""
    grid = utils.make_grid(images).cpu()
    utils.save_image(grid, path)
    plt.figure(figsize=(20, 6))
    # Reorder the (C, H, W) grid to (H, W, C) for imshow. Clip to [0, 1] so
    # imshow does not have to clip out-of-range float data itself.
    plt.imshow(np.clip(grid.numpy().transpose((1, 2, 0)), 0, 1), interpolation='bicubic')
    plt.axis('off')


def print_output(trained_model, dataloader_valid):
    """Visualize the model on the first batch of ``dataloader_valid``.

    The input, prediction and target batches are each tiled into a grid.
    Each grid is saved to ``print-sr/input.tif``, ``print-sr/output.tif``
    and ``print-sr/target.tif`` respectively, and plotted in its own figure.
    Nothing is plotted if the loader yields no batches.

    Args:
        trained_model: model to run. Batches are moved to the device of its
            first parameter, and the model is switched to eval mode.
        dataloader_valid: iterable of batches, each a mapping with 'input'
            and 'output' tensors.
    """
    device = next(trained_model.parameters()).device
    trained_model.eval()
    os.makedirs('print-sr', exist_ok=True)
    with torch.no_grad():
        print("===> 8x:")
        batch = next(iter(dataloader_valid), None)
        if batch is None:
            return
        # ``.float()`` replaces the original ``.type(Tensor)``, which
        # referenced an undefined name and raised NameError.
        imgs_input = batch['input'].to(device).float()
        prediction = trained_model(imgs_input)
        target = batch['output'].to(device).float()
        _show_grid(imgs_input, 'print-sr/input.tif')
        _show_grid(prediction, 'print-sr/output.tif')
        _show_grid(target, 'print-sr/target.tif')

## Training and model configurations
Adjust the values below as needed, then run this cell.

In [None]:
# Hyper-parameters and runtime options read by the cells below.
model_config = {
    "image-channel": 1,          # passed to models.Generator as its channel count
    "batch-size": 32,            # DataLoader batch size
    "epochs": 100,               # number of epochs; also the cosine-annealing period
    "learning-rate": 0.0002,     # initial Adam learning rate
    "resolution": 512,           # passed to data.Recale
    "run-from": None,            # name of an earlier run whose weights to load, or None
    "cnn-base-channel": 8,       # passed to models.Generator as base_channel
    "normalization": "batch",    # passed to models.Generator as norm
    # Use the GPU only when CUDA is actually available, so the notebook also
    # runs on CPU-only machines. Set True/False explicitly to override.
    "gpu": torch.cuda.is_available(),
    "threads": 4,                # DataLoader worker processes
    "test-interval": 1,          # evaluate on the validation split every N epochs
}

## Preview examples from the dataset
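Run this cell and wait for a grid of sample images to be displayed. This cell also builds the generator, loss function, optimizer and learning-rate scheduler, and loads earlier weights if `run-from` is set.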

In [None]:
# Select the device, build the validation loader, preview a sample, then
# create the model, loss, optimizer and learning-rate schedule.
use_gpu = model_config["gpu"] and torch.cuda.is_available()
if model_config["gpu"] and not use_gpu:
    print("CUDA is not available; falling back to CPU.")
device = torch.device('cuda:0' if use_gpu else 'cpu')

data.generate_compress_csv()
valid_dataset = data_loader(model_config, 'valid')  # DataLoader over the validation split
data.show_patch(valid_dataset, 0)

generator = models.Generator(
    model_config["image-channel"],
    base_channel=model_config["cnn-base-channel"],
    norm=model_config["normalization"],
)
generator.to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(generator.parameters(), lr=model_config["learning-rate"])
# Cosine-anneal over `epochs` scheduler steps, down to 10% of the initial LR.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, model_config["epochs"], model_config["learning-rate"] * 0.1)

if model_config["run-from"] is not None:
    # Prefer the legacy location if it exists. Otherwise use the path that
    # train() writes to (weights/<run>/generator.pth); previously the
    # notebook could not resume from its own checkpoints.
    checkpoint = os.path.join('model-weights', model_config["run-from"], 'model-sisr.pth')
    if not os.path.exists(checkpoint):
        checkpoint = os.path.join('weights', model_config["run-from"], 'generator.pth')
    # map_location lets weights saved on a GPU load on the selected device.
    generator.load_state_dict(torch.load(checkpoint, map_location=device))

## Training loop

In [None]:
# Train for the configured number of epochs, stepping the LR schedule once per
# epoch and reporting validation loss every `test-interval` epochs.
run = "sisr-{}".format(datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))
train_dataset = data_loader(model_config, 'train')
total_epochs = model_config["epochs"]
interval = model_config["test-interval"]
for epoch in range(total_epochs):
    train(model_config, epoch, run, generator, train_dataset, criterion, optimizer)
    scheduler.step()
    if epoch % interval:
        continue
    test_loss = test(generator, valid_dataset, criterion)
    print(f'\r>>>> [{epoch}/{total_epochs}] test_loss: {test_loss}')

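## Testing: visualize predictions
Shows the input, prediction and target for the first validation batch, and saves them under `print-sr/`.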

In [None]:
# Plot and save the model's output on the first validation batch.
print_output(trained_model=generator, dataloader_valid=valid_dataset)