In [None]:
import torch
import torch.backends.cudnn as cudnn
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
import numpy as np
import PIL.Image as pil_image
import os
import copy
from scipy.io import savemat
from tqdm import tqdm
import cv2
import time

from dataset import TrainDataset, EvalDataset
from datasets import SRDataset

from models import VDSR, VDSR_mod
from imresize import imresize
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr, calc_ssim, AverageMeter, preprocess_kernels, GradLoss

In [None]:
def loss_fn(out, gt):
    """
    Training loss: an L1 term plus a down-weighted gradient term.

    :param out: network output tensor
    :param gt: target tensor that `out` is compared against
    :return: 1.0 * L1(out, gt) + 0.1 * GradLoss(out, gt)
    """
    l1_weight, grad_weight = 1.0, 0.1

    l1_term = nn.L1Loss()(out, gt)

    # The gradient criterion is rebuilt on every call and moved to the
    # notebook-level `device` before being applied.
    grad_criterion = GradLoss().to(device)
    grad_term = grad_criterion(out, gt)

    # A Sobel-masked L1 term against the bicubic upsampling was tried here
    # and is currently disabled.

    return l1_weight * l1_term + grad_weight * grad_term


In [None]:
def main():
    """
    Fine-tune a model on the image referenced by `train_file` and return it.

    Reads the notebook-level settings `checkpoint`, `weights_file`, `lr`,
    `device`, `train_file`, `scale`, `crop_size`, `kernel_file`, `batch_size`,
    `num_workers`, `start_epoch`, `epochs` and `outputs_dir`.

    When `checkpoint` is None a fresh VDSR_mod is built, the weights in
    `weights_file` are copied into it and every loaded weight whose name does
    not contain 'block' is frozen. Otherwise model, optimizer and start epoch
    are restored from the checkpoint file.

    :return: the trained model (moved to `device`)
    """
    global start_epoch, epoch

    # Initialize model or load checkpoint
    if checkpoint is None:
        model = VDSR_mod()
        state_dict = model.state_dict()
        named_params = dict(model.named_parameters())

        # Copy the pre-trained weights into the model. state_dict() hands back
        # detached tensors that share storage with the parameters, so copy_()
        # updates the model in place.
        pretrained = torch.load(weights_file, map_location=lambda storage, loc: storage)
        for n, p in pretrained.items():
            if n in state_dict:
                state_dict[n].copy_(p)
                # Freeze loaded weights outside the 'block' modules. This must be
                # done on the real parameter: flipping requires_grad on the detached
                # state_dict tensor is a no-op and leaves the whole net trainable.
                if 'block' not in n and n in named_params:
                    named_params[n].requires_grad_(False)

        # Initialize the optimizer over the parameters that are still trainable
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

    else:
        # Load into a local so the global `checkpoint` path is not replaced by the
        # loaded dict (main() is called once per image and must be re-entrant).
        ckpt = torch.load(checkpoint)
        start_epoch = ckpt['epoch'] + 1
        model = ckpt['model']
        optimizer = ckpt['optimizer']

    # Move to default device
    model = model.to(device)

    # The schedule is built but not stepped (scheduler.step() is disabled below).
    # last_epoch/verbose are left at their defaults; `verbose` is deprecated in
    # recent PyTorch releases.
    miles = [20, 50, 80]
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=miles, gamma=0.5)

    # Custom dataloaders
    train_dataset = TrainDataset(train_file, scale, crop_size, kernel_file)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers,
                              pin_memory=True,
                              drop_last=True)

    # Epochs
    for epoch in range(start_epoch, epochs):
        # One epoch's training
        train(train_loader=train_loader,
              model=model,
              optimizer=optimizer,
              epoch=epoch)

        # Overwritten after every epoch, so the file holds the latest weights
        torch.save(model.state_dict(), os.path.join(outputs_dir, 'best.pth'))
        #scheduler.step()

    return model
        
def train(train_loader, model, optimizer, epoch):
    """
    One epoch's training.
    :param train_loader: DataLoader for training data; yields (lr_imgs, hr_imgs) batches
    :param model: model
    :param optimizer: optimizer
    :param epoch: epoch number (only referenced by the disabled status print)

    Also reads the notebook-level `device` and `grad_clip`, and uses `loss_fn`
    as the training loss.
    """
    model.train()  # training mode enables batch normalization

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss

    start = time.time()

    # Batches
    for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to default device
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

        # Forward prop.
        sr_imgs = model(lr_imgs)

        # Loss
        loss = loss_fn(sr_imgs, hr_imgs)

        # Backward prop.
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients, if necessary. The `clip_gradient` helper previously
        # called here is not defined or imported in the visible notebook, so
        # clamp element-wise to [-grad_clip, grad_clip] with the torch utility.
        # NOTE(review): element-wise clamping is assumed to be the intended
        # behaviour of the missing helper - confirm.
        if grad_clip is not None:
            nn.utils.clip_grad_value_(
                [p for group in optimizer.param_groups for p in group['params']],
                grad_clip)

        # Update model
        optimizer.step()

        # Keep track of loss
        losses.update(loss.item(), lr_imgs.size(0))

        # Keep track of batch time
        batch_time.update(time.time() - start)

        # Reset start time
        start = time.time()

        # Print status
        #if i % print_freq == 0:
        #    print('Epoch: [{0}][{1}/{2}]----'
        #          'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})----'
        #          'Data Time {data_time.val:.3f} ({data_time.avg:.3f})----'
        #          'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(epoch, i, len(train_loader),
        #                                                           batch_time=batch_time,
        #                                                            data_time=data_time, loss=losses))

    # No explicit `del` of the batch tensors: they are locals released on
    # return, and deleting them raised when the loader yielded no batches.
    
def evaluate_single(dataset_dir, image, model, output_dir, scale, Kernel, device):
    """
    Evaluate `model` on one HR image and return its PSNR and SSIM.

    The HR image is downscaled by `scale` with `Kernel`, upscaled back to the
    HR shape with the 'cubic' kernel, converted to YCbCr, and the first
    channel (divided by 255) is fed to the model. Metrics compare the clamped
    prediction with the first YCbCr channel of the HR image.

    :param dataset_dir: directory containing the HR image
    :param image: file name of the HR image inside `dataset_dir`
    :param model: network to evaluate; switched to eval mode here
    :param output_dir: directory for the (currently disabled) image dump
    :param scale: down/up-scaling factor
    :param Kernel: downscaling kernel passed to `imresize`
    :param device: torch device the tensors are moved to
    :return: (psnr, ssim); psnr is the `calc_psnr` result converted to a numpy
             value, ssim is whatever `calc_ssim` returns
    """
    hr_file = os.path.join(dataset_dir, image)

    hr = pil_image.open(hr_file).convert('RGB')
    hr = np.array(hr).astype(np.float32)
    lr = imresize(hr, scale=1./scale, kernel=Kernel)
    lr = imresize(lr, scale=scale, output_shape=hr.shape, kernel='cubic')

    ycbcr_lr = convert_rgb_to_ycbcr(lr)

    y_lr = ycbcr_lr[..., 0]
    y_lr /= 255.  # in place: also rescales channel 0 of ycbcr_lr
    y_lr = torch.from_numpy(y_lr).to(device)
    y_lr = y_lr.unsqueeze(0).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        preds = model(y_lr.float()).clamp(0.0, 1.0)

    ycbcr_hr = convert_rgb_to_ycbcr(hr)
    y_hr = ycbcr_hr[..., 0]
    y_hr /= 255.
    y_hr = torch.from_numpy(y_hr).to(device)
    y_hr = y_hr.unsqueeze(0).unsqueeze(0)

    psnr = calc_psnr(y_hr, preds)
    psnr = psnr.cpu().numpy()

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
    y_hr = y_hr.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
    # Pass the caller's scale instead of a hard-coded 2 (identical for x2 runs).
    # NOTE(review): assumes calc_ssim's `scale` argument is the SR factor - confirm.
    ssim = calc_ssim(y_hr, preds, scale=scale)

    # Optional image dump (disabled). The file name follows `scale` rather than
    # a fixed '_x2' suffix.
    #output_file = os.path.join(output_dir, '{}_x{}.png'.format(os.path.splitext(image)[0], scale))
    #output = np.array([preds, ycbcr_lr[..., 1], ycbcr_lr[..., 2]]).transpose([1, 2, 0])
    #output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    #output = pil_image.fromarray(output)
    #output.save(output_file)
    #print(psnr, ssim)

    return (psnr, ssim)


In [None]:
# Per-image fine-tuning followed by evaluation: for every HR image in the
# dataset, main() fine-tunes a fresh model on that image and evaluate_single()
# scores it. The names below are read as globals by main(), train() and loss_fn().
dataset = 'Urban100'
dataset_dir = '../../data/{}/HR'.format(dataset)
weights_file = '../MAML-simple/aniso_5000.pth'
# NOTE(review): the output folder is named after Set5 although `dataset` is
# Urban100 - confirm this is intended.
outputs_dir = 'output/Set5_mod'
scale = 2

kernel_file = '../../data/Set5/LR_noncubic/non_cubic.mat'
#kernel_file = 'cubic'

# Learning parameters
checkpoint = None  # path to model checkpoint, None if none
batch_size = 1
crop_size = 128
start_epoch = 0  # start at this epoch
epochs = 50
num_workers = 8  # number of workers for loading data in the DataLoader
print_freq = 50  # print training status once every __ batches
lr = 1e-4
grad_clip = None  # clip if gradients are exploding

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cudnn.benchmark = True

#image = 'woman'
# Sorted so the processing order is reproducible across runs and machines
images = sorted(os.listdir(dataset_dir))
Kernel = preprocess_kernels(kernel_file, sf=scale)

lst_psnr, lst_ssim = [], []

os.makedirs(outputs_dir, exist_ok=True)

for i, image in enumerate(images):
    # main() picks the training image up from the global `train_file`
    train_file = os.path.join(dataset_dir, image)
    model = main()

    psnr, ssim = evaluate_single(dataset_dir, image, model, outputs_dir, scale, Kernel, device)

    lst_psnr.append(psnr)
    lst_ssim.append(ssim)
    print(i)

print('Done')
# Averaged with numpy, matching the baseline-evaluation cell
print(np.mean(lst_psnr), np.mean(lst_ssim))
    

In [None]:
#Testing VDSR/any model - complete set of images (not individuals) without fine-tuning on test images

# Prefer the second GPU as before, but fall back to the first GPU (or the CPU)
# when fewer than two CUDA devices are present, instead of failing on 'cuda:1'.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

model = VDSR().to(device)

state_dict = model.state_dict()
dataset = 'BSD100'
method = 'og'

dataset_dir = '../../data/{}/HR'.format(dataset)
output_dir = 'output/{}_{}'.format(dataset, method)

#weights_file = 'output/Set14_{}/best.pth'.format(method)
#weights_file = 'output/{}_{}/best.pth'.format(dataset, method)
weights_file = 'pretrained_models/vdsr_x2.pth'

#kernel_file = '../../data/Set5/LR_noncubic/iso_1.mat'
kernel_file = 'cubic'
scale = 2

#images = os.listdir(dataset_dir)
images = sorted(os.listdir(dataset_dir))
Kernel = preprocess_kernels(kernel_file, sf=scale)

# Copy the pre-trained weights into the model; any entry in the file that the
# model does not know is treated as an error.
pretrained = torch.load(weights_file, map_location=lambda storage, loc: storage)
for n, p in pretrained.items():
    if n not in state_dict:
        raise KeyError(n)
    state_dict[n].copy_(p)

lst_psnr, lst_ssim = [], []

os.makedirs(output_dir, exist_ok=True)

for i, image in enumerate(images):
    print(i)

    psnr, ssim = evaluate_single(dataset_dir, image, model, output_dir, scale, Kernel, device)

    lst_psnr.append(psnr)
    lst_ssim.append(ssim)

print('Done')
print(np.mean(lst_psnr), np.mean(lst_ssim))

In [None]:
# Entire net meta-trained; only the adapters are fine-tuned on the test images (beta=1e-3, alpha=1e-6, 1e-4).
# Trained on DIV2K; the test scenario is real-world (Lr:Lr-son).

# Set5 (cubic)        - VDSR (x2)                 - 37.5549/0.9595
# Set5 (noncubic)     - VDSR (x2)                 - 28.7537/0.8339


# Set14 (cubic)       - VDSR (x2)                 - 32.9226/0.9133
# Set14 (noncubic)    - VDSR (x2)                 - 26.4692/0.7440


# Urban100 (cubic)    - VDSR (x2)                 - 30.3503/0.9155
# Urban100 (noncubic) - VDSR (x2)                 - 23.2416/0.6723

# Set5 (cubic)        - VDSR (x4)                 - 31.2365/0.8832
# Set5 (noncubic)     - VDSR (x4)                 - 29.4586/0.8479


# Set14 (cubic)       - VDSR (x4)                 - 27.7680/0.7694
# Set14 (noncubic)    - VDSR (x4)                 - 26.2408/0.7229


# Urban100 (cubic)    - VDSR (x4)                 - 24.3303/0.7189
# Urban100 (noncubic) - VDSR (x4)                 - 23.1669/0.6638

