# Import important libraries

In [1]:
import PIL  #(Pillow)
from PIL import Image
import numpy as np
from numpy import asarray
import pandas as pd
import os
import cv2
import imageio
import dlib
import seaborn as sns
import json
import random
import torchvision
import torchvision.transforms.functional as FT
from torchvision import transforms as transforms
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch import nn
from torch import tensor
from torch.utils import data
from torch.utils.data import Dataset
import torch.backends.cudnn as cudnn
import math
from math import exp, log10, sqrt
import matplotlib.pyplot as plt
import time
import skimage 
from skimage.transform import resize
from skimage.io import imread, imsave
from skimage.color import rgb2gray
from skimage.transform import resize,rescale
from urllib.request import urlopen


# Data Visualization 

In [2]:
from skimage.io import imread, imsave
from skimage.transform import resize

# Build a contact sheet of randomly chosen training images:
# num_in_col rows, each holding num_in_row tiles of 256x256 pixels.
num_in_row = 6  # tiles per row
num_in_col = 3  # number of rows
path = '../input/div2k-dataset/DIV2K_train_HR/DIV2K_train_HR'

# Sample without replacement so the same picture is not shown twice; fall back
# to sampling with replacement only when the folder holds too few files.
file_names = os.listdir(path)
num_tiles = num_in_row * num_in_col
files = [os.path.join(path, file)
         for file in np.random.choice(file_names, size=num_tiles, replace=len(file_names) < num_tiles)]

# Every tile is resized to the same shape so tiles can be concatenated freely.
tiles = [resize(imread(file), (256, 256)) for file in files]

# Concatenate tiles horizontally into rows, then stack the rows vertically.
# Unlike an incremental "current row / full image" accumulation, this also
# works when there is only a single row.
rows = [np.concatenate(tiles[start:start + num_in_row], axis=1)
        for start in range(0, len(tiles), num_in_row)]
full_img = np.concatenate(rows, axis=0)

In [3]:
# Display the assembled image grid without axis ticks or frame.
plt.figure(figsize=(15,15))
plt.imshow(full_img)
plt.axis('off')

# Also write the grid to 'images.png' in the current working directory.
imsave('images.png', full_img)

# Constants

In [4]:

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-channel (R, G, B) weights used by convert_image() for its 'y-channel' target.
rgb_weights = torch.tensor([65.481, 128.553, 24.966], dtype=torch.float32, device=device)


def convert_image(img, source, target):
    """
    Re-express an image in another value format, going through the [0, 1] range.

    :param img: image; a PIL image when source is 'pil', otherwise a tensor
    :param source: format of the input, one of 'pil' (PIL image), '[0, 1]' or '[-1, 1]' (pixel value ranges)
    :param target: wanted format, one of 'pil' (PIL image), '[0, 255]', '[0, 1]', '[-1, 1]' (pixel value ranges),
                   'y-channel' (luminance channel Y in the YCbCr color format, used to calculate PSNR and SSIM)
    :return: the converted image
    """
    valid_sources = {'pil', '[0, 1]', '[-1, 1]'}
    valid_targets = {'pil', '[0, 255]', '[0, 1]', '[-1, 1]', 'y-channel'}
    assert source in valid_sources, "Cannot convert from source format %s!" % source
    assert target in valid_targets, "Cannot convert to target format %s!" % target

    # Step 1: bring the input into the [0, 1] range ('[0, 1]' needs no work).
    if source == 'pil':
        img = FT.to_tensor(img)
    elif source == '[-1, 1]':
        img = (img + 1.) / 2.

    # Step 2: map the [0, 1] image onto the requested target format.
    if target == 'pil':
        return FT.to_pil_image(img)

    if target == '[0, 255]':
        return 255. * img

    if target == '[-1, 1]':
        return 2. * img - 1.

    if target == 'y-channel':
        # Expects a 4-D batch: channels are moved last and a 4-pixel border is
        # trimmed from both spatial dimensions before the weighted channel sum.
        trimmed = img.permute(0, 2, 3, 1)[:, 4:-4, 4:-4, :]
        return torch.matmul(255. * trimmed, rgb_weights) / 255. + 16.

    # target == '[0, 1]': nothing left to do.
    return img


class ImageTransforms(object):
    """
    Image transformation pipeline: crop an HR patch from a source image, downsample it
    to obtain the matching LR patch, and convert both to the requested value formats.
    """

    def __init__(self, split, crop_size, scaling_factor, lr_img_type, hr_img_type):
        """
        :param split: one of 'train', 'valid' (random fixed-size crop) or 'test' (largest center crop)
        :param crop_size: crop size of HR images (used by the 'train' and 'valid' splits)
        :param scaling_factor: LR images will be downsampled from the HR images by this factor
        :param lr_img_type: the target format for the LR image; see convert_image() above for available formats
        :param hr_img_type: the target format for the HR image; see convert_image() above for available formats
        """
        self.split = split.lower()
        self.crop_size = crop_size
        self.scaling_factor = scaling_factor
        self.lr_img_type = lr_img_type
        self.hr_img_type = hr_img_type

        # 'test' is accepted so that the center-crop path below is actually reachable.
        assert self.split in {'train', 'valid', 'test'}

    def _crop_box(self, width, height):
        """
        Compute the (left, top, right, bottom) box of the HR crop for an image of the given size.

        :param width: width of the source image in pixels
        :param height: height of the source image in pixels
        :return: crop box as a 4-tuple, suitable for PIL's Image.crop()
        :raises ValueError: if a random crop is requested from an image smaller than crop_size
        """
        if self.split in {'train', 'valid'}:
            if width < self.crop_size or height < self.crop_size:
                raise ValueError("Image of size %dx%d is smaller than the crop size %d!"
                                 % (width, height, self.crop_size))
            # Offsets start at 0 so that the top/left border can be sampled and an image
            # that is exactly crop_size wide/high is still usable.
            left = random.randint(0, width - self.crop_size)
            top = random.randint(0, height - self.crop_size)
            return left, top, left + self.crop_size, top + self.crop_size

        # 'test': take the largest possible center-crop whose dimensions are perfectly
        # divisible by the scaling factor.
        x_remainder = width % self.scaling_factor
        y_remainder = height % self.scaling_factor
        left = x_remainder // 2
        top = y_remainder // 2
        return left, top, left + (width - x_remainder), top + (height - y_remainder)

    def __call__(self, img):
        """
        :param img: a PIL source image from which the HR image will be cropped, and then downsampled to create the LR image
        :return: LR and HR images in the specified format
        """
        # Crop the high-resolution (HR) patch
        hr_img = img.crop(self._crop_box(img.width, img.height))

        # Downsample it with bicubic interpolation to obtain the low-resolution (LR) patch
        lr_img = hr_img.resize((int(hr_img.width / self.scaling_factor), int(hr_img.height / self.scaling_factor)),
                               Image.BICUBIC)

        # Sanity check: the LR patch must scale back exactly to the HR patch size
        assert hr_img.width == lr_img.width * self.scaling_factor and hr_img.height == lr_img.height * self.scaling_factor

        lr_img = convert_image(lr_img, source='pil', target=self.lr_img_type)
        hr_img = convert_image(hr_img, source='pil', target=self.hr_img_type)

        return lr_img, hr_img


class AverageMeter(object):
    """
    Running statistics for a metric: the latest value, the weighted sum,
    the number of samples seen, and the resulting average.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record a new observation.

        :param val: the latest value of the metric
        :param n: weight of this observation (e.g. the batch size it was averaged over)
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def clip_gradient(optimizer, grad_clip):
    """
    Clamp every gradient element into [-grad_clip, grad_clip], in place, to keep gradients from exploding.
    :param optimizer: optimizer whose parameters carry the gradients to be clipped
    :param grad_clip: clip value
    """
    every_param = (p for group in optimizer.param_groups for p in group['params'])
    for p in every_param:
        grad = p.grad
        if grad is None:
            # This parameter took no part in the last backward pass.
            continue
        grad.data.clamp_(min=-grad_clip, max=grad_clip)


def save_checkpoint(state, filename):
    """
    Save model checkpoint.
    :param state: checkpoint contents
    :param filename: destination passed straight to torch.save()
    """
    torch.save(state, filename)


def adjust_learning_rate(optimizer, shrink_factor):
    """
    Multiply the learning rate of every parameter group by a given factor and report the result.
    :param optimizer: optimizer whose learning rate must be shrunk.
    :param shrink_factor: factor in interval (0, 1) to multiply learning rate with.
    """
    print("\nDECAYING learning rate.")

    groups = optimizer.param_groups
    for group in groups:
        group['lr'] = group['lr'] * shrink_factor

    # Only the first group's rate is reported.
    new_lr = groups[0]['lr']
    print("The new learning rate is %f\n" % (new_lr,))

import numpy as np 
import pandas as pd

# Dataset class

In [5]:

class SRDataset(Dataset):
    """
    A PyTorch Dataset to be used by a PyTorch DataLoader.

    Each item is an (LR, HR) pair produced from one DIV2K image by ImageTransforms.
    """

    def __init__(self, data_folder, split, crop_size, scaling_factor, lr_img_type, hr_img_type):
        """
        :param data_folder: root folder containing 'DIV2K_train_HR/DIV2K_train_HR' and 'DIV2K_valid_HR/DIV2K_valid_HR'
        :param split: 'train' selects the training images; any other value selects the validation images
        :param crop_size: crop size of target HR images
        :param scaling_factor: the input LR images will be downsampled from the target HR images by this factor; the scaling done in the super-resolution
        :param lr_img_type: the format for the LR image supplied to the model; see convert_image() for available formats
        :param hr_img_type: the format for the HR image supplied to the model; see convert_image() for available formats
        """

        self.data_folder = data_folder
        self.split = split.lower()
        self.crop_size = int(crop_size)
        self.scaling_factor = int(scaling_factor)
        self.lr_img_type = lr_img_type
        self.hr_img_type = hr_img_type

        assert lr_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]'}
        assert hr_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]'}

        if self.split == 'train':
            assert self.crop_size % self.scaling_factor == 0, "Crop dimensions are not perfectly divisible by scaling factor! This will lead to a mismatch in the dimensions of the original HR patches and their super-resolved (SR) versions!"

        # Read list of image-paths
        subfolder = 'DIV2K_train_HR' if self.split == 'train' else 'DIV2K_valid_HR'
        self.images = self._list_images(os.path.join(data_folder, subfolder, subfolder))

        # Select the correct set of transforms
        self.transform = ImageTransforms(split=self.split,
                                         crop_size=self.crop_size,
                                         scaling_factor=self.scaling_factor,
                                         lr_img_type=self.lr_img_type,
                                         hr_img_type=self.hr_img_type)

    @staticmethod
    def _list_images(path):
        """
        List the entries of a folder as full paths, in sorted order.

        os.listdir() returns entries in arbitrary order, so sorting is what makes
        index i refer to the same image on every run and every machine.

        :param path: folder to list
        :return: list of paths, sorted by file name
        """
        return [os.path.join(path, name) for name in sorted(os.listdir(path))]

    def __getitem__(self, i):
        """
        This method is required to be defined for use in the PyTorch DataLoader.
        :param i: index to retrieve
        :return: the 'i'th pair LR and HR images to be fed into the model
        """
        # Read image; the context manager closes the underlying file as soon as
        # the pixel data has been converted into a standalone RGB image.
        with Image.open(self.images[i], mode='r') as source:
            img = source.convert('RGB')

        # Diagnostic: report images that are not larger than the requested crop.
        if img.width <= self.crop_size or img.height <= self.crop_size:
            print(self.images[i], img.width, img.height)
        lr_img, hr_img = self.transform(img)

        return lr_img, hr_img

    def __len__(self):
        """
        This method is required to be defined for use in the PyTorch DataLoader.
        :return: size of this data (in number of images)
        """
        return len(self.images)

# Convolutional building blocks of the model

In [6]:

class ConvolutionalBlock(nn.Module):
    """
    Convolution, optionally followed by batch normalization and an activation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, batch_norm=False, activation=None):
        """
        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param kernel_size: kernel size
        :param stride: stride
        :param batch_norm: include a BN layer? (only the literal value True enables it)
        :param activation: Type of activation, case-insensitive: 'prelu', 'leakyrelu', 'tanh' or 'sigmoid'; None if none
        """
        super().__init__()

        # Supported activations, keyed by lower-case name; each entry builds a fresh layer.
        activation_factories = {
            'prelu': nn.PReLU,
            'leakyrelu': lambda: nn.LeakyReLU(0.2),
            'tanh': nn.Tanh,
            'sigmoid': nn.Sigmoid,
        }
        if activation is not None:
            activation = activation.lower()
            assert activation in activation_factories

        # The convolution always comes first; padding of half the kernel size
        # keeps the spatial size unchanged for odd kernels at stride 1.
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size // 2)]

        # Optional batch normalization (BN)
        if batch_norm is True:
            layers.append(nn.BatchNorm2d(out_channels))

        # Optional activation
        if activation is not None:
            layers.append(activation_factories[activation]())

        # Chain the collected layers into a single sequential module
        self.conv_block = nn.Sequential(*layers)

    def forward(self, input):
        """
        Forward propagation.
        :param input: input images, a tensor of size (N, in_channels, w, h)
        :return: output images, a tensor of size (N, out_channels, w, h)
        """
        return self.conv_block(input)  # (N, out_channels, w, h)


class SubPixelConvolutionalBlock(nn.Module):
    """
    Upscaling block: convolution, then pixel shuffle, then PReLU.
    """

    def __init__(self, kernel_size=3, n_channels=64, scaling_factor=2):
        """
        :param kernel_size: kernel size of the convolution
        :param n_channels: number of input and output channels
        :param scaling_factor: factor to scale input images by (along both dimensions)
        """
        super().__init__()

        # The convolution multiplies the channel count by scaling_factor^2 ...
        expanded_channels = n_channels * (scaling_factor ** 2)
        self.conv = nn.Conv2d(n_channels, expanded_channels, kernel_size, padding=kernel_size // 2)
        # ... and the pixel shuffle turns those extra channels into extra pixels,
        # enlarging each spatial dimension by the scaling factor.
        self.pixel_shuffle = nn.PixelShuffle(scaling_factor)
        self.prelu = nn.PReLU()

    def forward(self, input):
        """
        Forward propagation.
        :param input: input images, a tensor of size (N, n_channels, w, h)
        :return: scaled output images, a tensor of size (N, n_channels, w * scaling factor, h * scaling factor)
        """
        return self.prelu(self.pixel_shuffle(self.conv(input)))




# Class SRResNet

In [7]:
class ResidualBlock(nn.Module):
    """
    Two convolutional blocks wrapped by a skip connection that adds the block input to its output.
    """

    def __init__(self, kernel_size=3, n_channels=64):
        """
        :param kernel_size: kernel size
        :param n_channels: number of input and output channels (same because the input must be added to the output)
        """
        super().__init__()

        # Settings shared by both convolutional blocks
        shared = dict(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, batch_norm=True)

        # First block: conv + BN + PReLU
        self.conv_block1 = ConvolutionalBlock(activation='PReLu', **shared)

        # Second block: conv + BN, no activation
        self.conv_block2 = ConvolutionalBlock(activation=None, **shared)

    def forward(self, input):
        """
        Forward propagation.
        :param input: input images, a tensor of size (N, n_channels, w, h)
        :return: output images, a tensor of size (N, n_channels, w, h)
        """
        transformed = self.conv_block2(self.conv_block1(input))  # (N, n_channels, w, h)

        # Skip connection
        return transformed + input  # (N, n_channels, w, h)


class SRResNet(nn.Module):
    """
    The SRResNet, as defined in the paper.
    """

    def __init__(self, large_kernel_size=9, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4):
        """
        :param large_kernel_size: kernel size of the first and last convolutions which transform the inputs and outputs
        :param small_kernel_size: kernel size of all convolutions in-between, i.e. those in the residual and subpixel convolutional blocks
        :param n_channels: number of channels in-between, i.e. the input and output channels for the residual and subpixel convolutional blocks
        :param n_blocks: number of residual blocks
        :param scaling_factor: factor to scale input images by (along both dimensions) in the subpixel convolutional block
        """
        super().__init__()

        # Only power-of-two factors up to 8 are supported
        scaling_factor = int(scaling_factor)
        assert scaling_factor in {2, 4, 8}, "The scaling factor must be 2, 4, or 8!"

        # Entry block: lifts the 3 image channels to n_channels
        self.conv_block1 = ConvolutionalBlock(in_channels=3, out_channels=n_channels, kernel_size=large_kernel_size,
                                              batch_norm=False, activation='PReLu')

        # Trunk: n_blocks residual blocks, each with its own skip connection
        trunk = [ResidualBlock(kernel_size=small_kernel_size, n_channels=n_channels) for _ in range(n_blocks)]
        self.residual_blocks = nn.Sequential(*trunk)

        # Block closing the trunk, before the long skip connection is added
        self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels,
                                              kernel_size=small_kernel_size,
                                              batch_norm=True, activation=None)

        # Upscaling: log2(scaling_factor) sub-pixel blocks, each doubling the resolution
        n_subpixel_convolution_blocks = int(math.log2(scaling_factor))
        upscalers = [SubPixelConvolutionalBlock(kernel_size=small_kernel_size, n_channels=n_channels, scaling_factor=2)
                     for _ in range(n_subpixel_convolution_blocks)]
        self.subpixel_convolutional_blocks = nn.Sequential(*upscalers)

        # Exit block: back to 3 image channels, squashed by tanh
        self.conv_block3 = ConvolutionalBlock(in_channels=n_channels, out_channels=3, kernel_size=large_kernel_size,
                                              batch_norm=False, activation='tanh')

    def forward(self, lr_imgs):
        """
        Forward prop.
        :param lr_imgs: low-resolution input images, a tensor of size (N, 3, w, h)
        :return: super-resolution output images, a tensor of size (N, 3, w * scaling factor, h * scaling factor)
        """
        features = self.conv_block1(lr_imgs)  # (N, n_channels, w, h)

        trunk_out = self.residual_blocks(features)  # (N, n_channels, w, h)
        trunk_out = self.conv_block2(trunk_out)  # (N, n_channels, w, h)

        # Long skip connection around the whole residual trunk
        merged = trunk_out + features  # (N, n_channels, w, h)

        upscaled = self.subpixel_convolutional_blocks(merged)  # (N, n_channels, w * scaling factor, h * scaling factor)
        return self.conv_block3(upscaled)  # (N, 3, w * scaling factor, h * scaling factor)

# Execution

In [8]:
# Data parameters
data_folder = '../input/div2k-dataset/'  # dataset root; SRDataset looks for the DIV2K_train_HR / DIV2K_valid_HR folders below it
crop_size = 96  # crop size of target HR images
scaling_factor = 4  # the scaling factor for the generator; the input LR images will be downsampled from the target HR images by this factor

# Model parameters
large_kernel_size = 9  # kernel size of the first and last convolutions which transform the inputs and outputs
small_kernel_size = 3  # kernel size of all convolutions in-between, i.e. those in the residual and subpixel convolutional blocks
n_channels = 128  # number of channels in-between, i.e. the input and output channels for the residual and subpixel convolutional blocks
n_blocks = 32  # number of residual blocks

# Checkpoint discovery: prefer a checkpoint in the working directory, then one
# shipped as an input dataset, and otherwise start from scratch.

if os.path.exists('./checkpoint_srresnet.pth.tar'):
    print('Read from ./checkpoint_srresnet.pth.tar')
    checkpoint = './checkpoint_srresnet.pth.tar'
elif os.path.exists('../input/weightsresnet/checkpoint_srresnet.pth.tar'):
    print('Read from ../input/weightsresnet/checkpoint_srresnet.pth.tar')
    checkpoint = '../input/weightsresnet/checkpoint_srresnet.pth.tar'
else:
    checkpoint = None  # path to model checkpoint, None if none
    print('Ther are no any Weights!')

# Learning parameters
batch_size = 32  # batch size

start_epoch = 0  # start at this epoch
epochs = 17  # main() trains for epoch indices start_epoch .. epochs - 1

workers = 4  # number of workers for loading data in the DataLoader
print_freq = 500  # print training status once every __ batches
lr = 1e-4  # learning rate
grad_clip = None  # clip if gradients are exploding

cudnn.benchmark = True

# Per-epoch average losses, appended to by main() and plotted afterwards
loss_curve_train = []
loss_curve_valid = []


def main():
    """
    Training.

    Builds the model and optimizer (or restores both from the checkpoint file named by the
    module-level ``checkpoint``), then trains and validates for epochs ``start_epoch`` to
    ``epochs - 1``. After every epoch the average losses are appended to ``loss_curve_train``
    / ``loss_curve_valid`` and a checkpoint is written to 'checkpoint_srresnet.pth.tar'.
    """
    global start_epoch, epoch

    # Initialize model or load checkpoint
    if checkpoint is None:
        model = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
                         n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)
        # Initialize the optimizer
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                                     lr=lr)

    else:
        # The loaded contents go into a local name so that the module-level `checkpoint`
        # keeps holding the path and main() can be called again.
        # map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
        # NOTE: the file stores whole pickled objects; torch.load unpickles them, so only
        # load checkpoint files from a trusted source.
        state = torch.load(checkpoint, map_location=device)
        start_epoch = state['epoch'] + 1
        model = state['model']
        optimizer = state['optimizer']

    # Move to default device
    model = model.to(device)
    criterion = nn.L1Loss().to(device)

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='[-1, 1]',
                              hr_img_type='[-1, 1]')

    valid_dataset = SRDataset(data_folder,
                              split='valid',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='[-1, 1]',
                              hr_img_type='[-1, 1]')

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):
        # One epoch's training
        train_losses = train(train_loader=train_loader,
                             model=model,
                             criterion=criterion,
                             optimizer=optimizer)

        # One pass over the validation data
        valid_losses = valid(valid_loader=valid_loader,
                             model=model,
                             criterion=criterion)

        print('Epoch: [{0}] ---- '
              'Train Average Loss: {train_loss.avg:.4f} ---- '
              'Valid Average Loss: {valid_loss.avg:.4f}'
              .format(epoch,
                      train_loss=train_losses,
                      valid_loss=valid_losses))

        loss_curve_train.append(train_losses.avg)
        loss_curve_valid.append(valid_losses.avg)

        # Save checkpoint (whole model and optimizer objects; later cells read
        # checkpoint['model'] directly, so this layout is kept)
        torch.save({
            'epoch': epoch,
            'model': model,
            'optimizer': optimizer
        },
            'checkpoint_srresnet.pth.tar'
        )


def train(train_loader, model, criterion, optimizer):
    """
    One epoch's training.
    :param train_loader: DataLoader for training data
    :param model: model
    :param criterion: loss function comparing super-resolved and HR images
    :param optimizer: optimizer
    :return: AverageMeter holding the per-image average loss of this epoch
    """
    model.train()  # training mode enables batch normalization

    losses = AverageMeter()

    for lr_imgs, hr_imgs in train_loader:

        # Move to default device
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

        # Forward prop.
        sr_imgs = model(lr_imgs)

        # Loss
        loss = criterion(sr_imgs, hr_imgs)

        # Backward prop.
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients, if necessary
        if grad_clip is not None:
            clip_gradient(optimizer, grad_clip)

        # Update model
        optimizer.step()

        # Keep track of loss, weighted by the number of images in the batch
        losses.update(loss.item(), lr_imgs.size(0))

    # The batch tensors are plain locals and are released on return; no explicit
    # `del` is used because it would fail on an empty loader, where the loop
    # variables are never bound.
    return losses
    
    
def valid(valid_loader, model, criterion):
    """
    One pass over the validation data, without updating the model.
    :param valid_loader: DataLoader for validation data
    :param model: model
    :param criterion: loss function comparing super-resolved and HR images
    :return: AverageMeter holding the per-image average validation loss
    """
    model.eval()  # evaluation mode: batch normalization uses its running statistics

    losses = AverageMeter()

    # No gradients are needed for validation; disabling autograd avoids building
    # (and holding in memory) a computation graph for every batch.
    with torch.no_grad():
        for lr_imgs, hr_imgs in valid_loader:

            # Move to default device
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)

            # Forward prop.
            sr_imgs = model(lr_imgs)

            # Loss
            loss = criterion(sr_imgs, hr_imgs)  # scalar

            # Keep track of loss, weighted by the number of images in the batch
            losses.update(loss.item(), lr_imgs.size(0))

    # No explicit `del` of the batch tensors: it would fail on an empty loader,
    # and the locals are released on return anyway.
    return losses

# Start training when this file is executed directly (or as a notebook cell).
if __name__ == '__main__':
    main()

In [9]:
import matplotlib.pyplot as plt

# Plot the per-epoch average losses collected by main().
plt.plot(loss_curve_train, label='train')
plt.plot(loss_curve_valid, label='valid')
plt.legend()
plt.xlabel('number of training epochs')
plt.ylabel('pixel loss')

# Results

In [10]:
# Qualitative check on one training image: LR input, HR target and model prediction.
# NOTE: this overwrites the module-level crop_size / scaling_factor used by main().
crop_size = 8*96  # crop size of target HR images
scaling_factor = 8
train_dataset = SRDataset('../input/div2k-dataset',
                          split='train',
                          crop_size=crop_size,
                          scaling_factor=scaling_factor,
                          lr_img_type='[-1, 1]',
                          hr_img_type='[-1, 1]')
img_lr, img_hr = train_dataset[25]

# Tensors are (C, H, W) in [-1, 1]; matplotlib wants (H, W, C) in [0, 1].
plt.figure(figsize=(20,20))
plt.subplot(131)
plt.imshow(((img_lr.permute(1, 2, 0) + 1) * 0.5).numpy())
plt.title('LR image')
plt.subplot(132)
plt.imshow(((img_hr.permute(1, 2, 0) + 1) * 0.5).numpy())
plt.title('HR image')

# map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load('./checkpoint_srresnet.pth.tar', map_location=device)
model = checkpoint['model'].to(device)
model.eval()  # inference: batch normalization must use its running statistics

# img_lr is already a tensor, so it is used directly instead of being wrapped in
# torch.tensor(); autograd is switched off because no gradients are needed.
with torch.no_grad():
    imr_pred = model(img_lr.unsqueeze(0).to(device))

plt.subplot(133)
plt.imshow((imr_pred.cpu().squeeze(0).permute(1, 2, 0).numpy() + 1) * 0.5)
plt.title('Predicted image')

# Apply the trained model to an arbitrary image

In [11]:
# Fetch an example picture from the web and divide by 255 to get floats
# (assumes the decoded image has 8-bit channels, i.e. values end up in [0, 1]).
img = imread('https://www.popsci.com/uploads/2019/01/07/GMY7X5OKPSBGSTSFEZXDEYW6JU-1024x1024.jpg').astype(float) / 255

In [12]:

# Downscale the picture by 4 to obtain the low-resolution input.
img_lr = resize(img,(int(img.shape[0]/4), int(img.shape[1]/4)))
plt.figure(figsize=(15,38))
plt.subplot(311)
plt.imshow(img_lr)
plt.title('Low resolution image')
plt.axis('off')

plt.subplot(312)
plt.imshow(img)
plt.title('Original image')
plt.axis('off')

# map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load('./checkpoint_srresnet.pth.tar', map_location=device)
model = checkpoint['model'].to(device)
model.eval()  # inference: batch normalization must use its running statistics

# The network was trained on inputs and targets in [-1, 1] (see the lr_img_type /
# hr_img_type passed to SRDataset in main()), so the [0, 1] picture is rescaled
# before the forward pass: (H, W, C) in [0, 1] -> (1, C, H, W) in [-1, 1].
lr_batch = torch.tensor(img_lr).permute(2, 0, 1).unsqueeze(0).to(device).float()
lr_batch = 2. * lr_batch - 1.
with torch.no_grad():
    imr_pred = model(lr_batch)

# Map the tanh output from [-1, 1] back to [0, 1] for display.
sr_img = (imr_pred.cpu().squeeze(0).permute(1, 2, 0).numpy() + 1.) * 0.5

plt.subplot(313)
plt.imshow(sr_img)
plt.title('Generated image')
plt.axis('off')

# plt.savefig('earth.png', dpi=300, bbox_inches='tight')

In [13]:
# Same comparison as the previous cell, laid out side by side.
img_lr = resize(img,(int(img.shape[0]/4), int(img.shape[1]/4)))
plt.figure(figsize=(17,10))
plt.subplot(131)
plt.imshow(img_lr)
plt.title('Low resolution image')
plt.axis('off')

plt.subplot(132)
plt.imshow(img)
plt.title('Original image')
plt.axis('off')

# map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load('./checkpoint_srresnet.pth.tar', map_location=device)
model = checkpoint['model'].to(device)
model.eval()  # inference: batch normalization must use its running statistics

# The network was trained on inputs and targets in [-1, 1] (see the lr_img_type /
# hr_img_type passed to SRDataset in main()), so the [0, 1] picture is rescaled
# before the forward pass: (H, W, C) in [0, 1] -> (1, C, H, W) in [-1, 1].
lr_batch = torch.tensor(img_lr).permute(2, 0, 1).unsqueeze(0).to(device).float()
lr_batch = 2. * lr_batch - 1.
with torch.no_grad():
    imr_pred = model(lr_batch)

# Map the tanh output from [-1, 1] back to [0, 1] for display.
sr_img = (imr_pred.cpu().squeeze(0).permute(1, 2, 0).numpy() + 1.) * 0.5

plt.subplot(133)
plt.imshow(sr_img)
plt.title('Generated image')
plt.axis('off')