In [1]:
# get root
import os
import sys
sys.path.append("../")

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import torch
import skimage.io as io

import model.edsr as edsr
import torch.nn.functional as F

from skimage.io import imread, imshow
from skimage.transform import rescale, resize, downscale_local_mean
from skimage.filters import gaussian
from skimage import data, color

# options for EDSR
class Opt:
    """Static hyper-parameter bundle handed to ``edsr.Net``.

    All settings are plain class attributes, so an ``Opt()`` instance and the
    class itself expose identical values. These presumably need to match the
    architecture of the checkpoints loaded later in the notebook.
    """

    scale: int = 4  # presumably the SR factor; the checkpoint files are named "X4"
    num_blocks: int = 32
    num_channels: int = 256
    res_scale: float = 0.1

def im2tensor(im):
    """Turn an H x W x C numpy image into a float32 C x H x W torch tensor.

    The channel axis is moved to the front and the data is made contiguous
    before wrapping it with ``torch.from_numpy``. No value rescaling happens,
    so uint8 pixels stay in 0-255.
    """
    channels_first = im.transpose(2, 0, 1)
    contiguous = np.ascontiguousarray(channels_first)
    return torch.from_numpy(contiguous).float()

def tensor2im(tensor):
    """Turn a (1, C, H, W) or (C, H, W) tensor into an H x W x C uint8 numpy image.

    The tensor is detached from autograd and a leading size-1 batch axis is
    dropped. Values are clamped to [0, 255] and rounded to the nearest integer
    before the uint8 cast, and the result is moved to the CPU.
    """
    chw = tensor.detach().squeeze(0)
    pixels = chw.clamp(min=0, max=255).round()
    as_bytes = pixels.cpu().byte()
    return as_bytes.permute(1, 2, 0).numpy()

def downsample_img(img, factor=2):
    """Simulate a degraded, blocky version of ``img`` at its original size.

    The image is shrunk by ``factor`` with anti-aliased bilinear resampling
    (``order=1``). It is then enlarged back to ``img.shape`` with
    nearest-neighbour interpolation (``order=0``). The output therefore has the
    same spatial size as the input but has lost fine detail.

    Args:
        img: image array of shape (H, W) or (H, W, C). Assumed to be in the
            0-255 range (e.g. uint8 from ``skimage.io.imread``), because the
            result is divided by 255.
        factor: positive shrink factor applied to both spatial axes.

    Returns:
        Float array with shape ``img.shape``, scaled by 1/255 (so in [0, 1]
        for 0-255 input).

    Raises:
        ValueError: if ``factor`` is not positive.
    """
    if factor <= 0:
        raise ValueError(f"factor must be positive, got {factor}")
    # Keep at least one pixel per axis. Otherwise a factor larger than the
    # image would request a degenerate zero-sized intermediate image.
    small_shape = (max(1, img.shape[0] // factor), max(1, img.shape[1] // factor))
    downsampled = resize(img, small_shape, order=1, mode='reflect',
                         clip=True, preserve_range=True, anti_aliasing=True)
    # Nearest-neighbour upsampling back to the full input shape (channels kept).
    upsampled = resize(downsampled, img.shape, order=0, mode='reflect',
                       clip=True, preserve_range=True, anti_aliasing=False)
    return upsampled / 255

# Build two identically-configured EDSR networks on the best available device.
# The later cells load the baseline weights into one and the MoA weights into the other.
opt = Opt()
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

net_base, net_moa = (edsr.Net(opt).to(dev) for _ in range(2))

# Passed as `factor` to downsample_img when synthesising the degraded input.
image_scaling = 8
# Alternative test image: "Canon_003_HR.png"
image_name = "Galah-cockatoo.jpg"

print("Setup Complete!")
 

  return torch._C._cuda_getDeviceCount() > 0
Setup Complete!


## DIV2K pre-trained model

In [2]:
# Load the DIV2K-trained checkpoints and super-resolve a CutBlurred input.
# The checkpoint directory can be overridden with the PT_MODEL_DIR environment
# variable instead of editing the hardcoded absolute path.
model_dir = os.environ.get("PT_MODEL_DIR", "/data/pt_models")
path_image_HR = image_name
# Alternative: read a real LR image, e.g. path_image_LR = "Canon_003_LR4.png"
path_base = os.path.join(model_dir, "DIV2K_EDSR_X4_base.pt")
path_moa = os.path.join(model_dir, "DIV2K_EDSR_X4_moa.pt")

# NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source.
# map_location="cpu" restores every tensor onto the CPU regardless of where it was saved.
state_base = torch.load(path_base, map_location="cpu")
state_moa = torch.load(path_moa, map_location="cpu")
net_base.load_state_dict(state_base)
net_moa.load_state_dict(state_moa)
# Inference only: disable any training-mode layer behaviour.
net_base.eval()
net_moa.eval()

HR = io.imread(path_image_HR)
HR_tensor = im2tensor(HR).unsqueeze(0).to(dev)

# downsample_img returns an HR-sized image scaled by 1/255; rescale to 0-255 to match HR_tensor.
LR = downsample_img(HR, factor=image_scaling)
LR_tensor = (im2tensor(LR) * 255).unsqueeze(0).to(dev)

# apply CutBlur: paste the HR pixels of a square window into the LR input
cutblur_window = (..., slice(200, 600), slice(200, 600))
LR_tensor[cutblur_window] = HR_tensor[cutblur_window]

with torch.no_grad():
    SR_base = tensor2im(net_base(LR_tensor))
    SR_moa = tensor2im(net_moa(LR_tensor))

FileNotFoundError: [Errno 2] No such file or directory: '/data/pt_models/DIV2K_EDSR_X4_base.pt'

In [None]:
# Convert every image to floats (divided by 255) for display.
LR_plot = tensor2im(LR_tensor) / 255
HR_plot = HR / 255
SR_base_plot = SR_base / 255
SR_moa_plot = SR_moa / 255

# Per-pixel ABSOLUTE error against HR, averaged over colour channels and
# amplified 10x. A signed difference would be clipped by vmin=0 below and
# would hide every pixel where the SR output is brighter than HR.
diff_SR_base = np.abs(HR_plot - SR_base_plot).mean(2) * 10
diff_SR_moa = np.abs(HR_plot - SR_moa_plot).mean(2) * 10

error_style = dict(vmin=0, vmax=1, cmap="viridis")
fig, axarr = plt.subplots(3, 2, figsize=(18, 24))
panels = {
    (0, 0): (LR_plot, "Input (Cutblurred LR)", {}),
    (1, 0): (SR_base_plot, "EDSR w/o MoA", {}),
    (1, 1): (diff_SR_base, "EDSR w/o MoA (Δ)", error_style),
    (2, 0): (SR_moa_plot, "EDSR w/ MoA", {}),
    (2, 1): (diff_SR_moa, "EDSR w/ MoA (Δ)", error_style),
}
for (row, col), (image, title, imshow_kwargs) in panels.items():
    axarr[row, col].imshow(image, **imshow_kwargs)
    axarr[row, col].set_title(title, fontsize=18)
# Hide the axes on every panel, including the unused top-right slot.
for ax in axarr.flat:
    ax.axis("off")

plt.show()

## RealSR pre-trained model

In [None]:
# Load the RealSR-trained checkpoints and super-resolve a CutBlurred input.
# The checkpoint directory can be overridden with the PT_MODEL_DIR environment
# variable instead of editing the hardcoded absolute path.
model_dir = os.environ.get("PT_MODEL_DIR", "/data/pt_models")
path_image_HR = image_name
# Alternative: read a real LR image, e.g. path_image_LR = "Canon_003_LR4.png"
path_base = os.path.join(model_dir, "RealSR_EDSR_X4_base.pt")
path_moa = os.path.join(model_dir, "RealSR_EDSR_X4_moa.pt")

# NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source.
# map_location="cpu" restores every tensor onto the CPU regardless of where it was saved.
state_base = torch.load(path_base, map_location="cpu")
state_moa = torch.load(path_moa, map_location="cpu")
net_base.load_state_dict(state_base)
net_moa.load_state_dict(state_moa)
# Inference only: disable any training-mode layer behaviour.
net_base.eval()
net_moa.eval()

HR = io.imread(path_image_HR)
HR_tensor = im2tensor(HR).unsqueeze(0).to(dev)

# downsample_img returns an HR-sized image scaled by 1/255; rescale to 0-255 to match HR_tensor.
LR = downsample_img(HR, factor=image_scaling)
LR_tensor = (im2tensor(LR) * 255).unsqueeze(0).to(dev)

# apply CutBlur: paste the HR pixels of a square window into the LR input
cutblur_window = (..., slice(200, 600), slice(200, 600))
LR_tensor[cutblur_window] = HR_tensor[cutblur_window]

with torch.no_grad():
    SR_base = tensor2im(net_base(LR_tensor))
    SR_moa = tensor2im(net_moa(LR_tensor))

In [None]:
# Convert every image to floats (divided by 255) for display.
LR_plot = tensor2im(LR_tensor) / 255
HR_plot = HR / 255
SR_base_plot = SR_base / 255
SR_moa_plot = SR_moa / 255

# Per-pixel ABSOLUTE error against HR, averaged over colour channels and
# amplified 10x. A signed difference would be clipped by vmin=0 below and
# would hide every pixel where the SR output is brighter than HR.
diff_SR_base = np.abs(HR_plot - SR_base_plot).mean(2) * 10
diff_SR_moa = np.abs(HR_plot - SR_moa_plot).mean(2) * 10

error_style = dict(vmin=0, vmax=1, cmap="viridis")
fig, axarr = plt.subplots(3, 2, figsize=(18, 24))
panels = {
    (0, 0): (LR_plot, "Input (Cutblurred LR)", {}),
    (1, 0): (SR_base_plot, "EDSR w/o MoA", {}),
    (1, 1): (diff_SR_base, "EDSR w/o MoA (Δ)", error_style),
    (2, 0): (SR_moa_plot, "EDSR w/ MoA", {}),
    (2, 1): (diff_SR_moa, "EDSR w/ MoA (Δ)", error_style),
}
for (row, col), (image, title, imshow_kwargs) in panels.items():
    axarr[row, col].imshow(image, **imshow_kwargs)
    axarr[row, col].set_title(title, fontsize=18)
# Hide the axes on every panel, including the unused top-right slot.
for ax in axarr.flat:
    ax.axis("off")

plt.show()