In [1]:
import os
import time
import os.path as osp

import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import ToTensor
import cv2
import glob
import easydict
import numpy as np

## Settings

In [2]:
TEST_IMAGE_FOLDER = 'lr/*'

In [3]:
args = {
    'cpu' : False, # True if you don't have a CUDA based GPU
    'scale' : 2, 
    'inp_images' : TEST_IMAGE_FOLDER
}

## Code

### ESRGAN

In [4]:
import architecture as arch

In [7]:
esrgan = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
                        mode='CNA', res_scale=1, upsample_mode='upconv')
esrgan.load_state_dict(torch.load('RRDB_PSNR_x4.pth'), strict=True)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [8]:
esrgan.eval()

RRDB_Net(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Identity + 
    |Sequential(
    |  (0): RRDB(
    |    (RDB1): ResidualDenseBlock_5C(
    |      (conv1): Sequential(
    |        (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |        (1): LeakyReLU(negative_slope=0.2, inplace)
    |      )
    |      (conv2): Sequential(
    |        (0): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |        (1): LeakyReLU(negative_slope=0.2, inplace)
    |      )
    |      (conv3): Sequential(
    |        (0): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |        (1): LeakyReLU(negative_slope=0.2, inplace)
    |      )
    |      (conv4): Sequential(
    |        (0): Conv2d(160, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |        (1): LeakyReLU(negative_slope=0.2, inplace)
    |      )
    |      (conv5): Sequential(
    |        (0): C

#### Super resolve images

In [9]:
with torch.no_grad():
    esrgan = esrgan.to('cpu' if args['cpu'] else 'cuda')

    idx = 0
    for path in glob.glob(args['inp_images']):
        idx += 1
        base = os.path.splitext(os.path.basename(path))[0]
        print(idx, base)
        # read image
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        img = img * 1.0 / 255
        img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
        img_LR = img.unsqueeze(0)
        img_LR = img_LR.to('cpu' if args['cpu'] else 'cuda')
        
        start_time = time.time()
        output = esrgan(img_LR)
        end_time = time.time()
        print('Time taken :', end_time - start_time, '\n')
        
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        output = (output * 255.0).round()
        cv2.imwrite('hr/{:s}_ESRGAN.png'.format(base), output)


1 0853x4
Time taken : 2.379230499267578 

2 119082
Time taken : 1.0742292404174805 

3 42012
Time taken : 1.0762825012207031 



### ProGANSR

In [10]:
import multiproc
from progressive_loader import Dataset
import copy
import skimage.io as io

from collections import OrderedDict

In [11]:
from layers import (_DenseBlock, CompressionBlock, Conv2d, DenseResidualBlock,
                     init_weights, PixelShuffleUpsampler, ResidualBlock)
from enum import Enum
from math import log2
from logger import info, error
from utils import (get_filenames, IMG_EXTENSIONS, print_evaluation,
                         tensor2im)

In [12]:
class DataLoader(multiproc.MyDataLoader):
    def __init__(self, dataset, batch_size, scale=None):
        self.dataset = dataset
        self.phase = 'test'

        super(DataLoader, self).__init__(
            self.dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=16,
            random_vars=None,
            drop_last=True,
            sampler=None
        )

In [13]:
class ProSR(nn.Module):
    """docstring for PyramidDenseNet"""

    def __init__(self, residual_denseblock, num_init_features, bn_size,
                 growth_rate, ps_woReLU, level_config, level_compression,
                 res_factor, max_num_feature, max_scale, **kwargs):
        super(ProSR, self).__init__()
        self.max_scale = max_scale
        self.n_pyramids = int(log2(self.max_scale))

        # used in curriculum learning, initially set to the last scale
        self.current_scale_idx = self.n_pyramids - 1

        self.residual_denseblock = residual_denseblock
        self.DenseBlock = _DenseBlock
        self.Upsampler = PixelShuffleUpsampler
        self.upsample_args = {'woReLU': ps_woReLU}

        denseblock_params = {
            'num_layers': None,
            'num_input_features': num_init_features,
            'bn_size': bn_size,
            'growth_rate': growth_rate,
        }

        num_features = denseblock_params['num_input_features']

        # Initiate network

        # each scale has its own init_conv
        for s in range(1, self.n_pyramids + 1):
            self.add_module('init_conv_%d' % s, Conv2d(3, num_init_features,
                                                       3))

        # Each denseblock forms a pyramid
        for i in range(self.n_pyramids):
            block_config = level_config[i]
            pyramid_residual = OrderedDict()

            # starting from the second pyramid, compress the input features
            if i != 0:
                out_planes = num_init_features if level_compression <= 0 else int(
                    level_compression * num_features)
                pyramid_residual['compression_%d' % i] = CompressionBlock(
                    in_planes=num_features, out_planes=out_planes)
                num_features = out_planes

            # serial connect blocks
            for b, num_layers in enumerate(block_config):
                denseblock_params['num_layers'] = num_layers
                denseblock_params['num_input_features'] = num_features
                # residual dense block used in ProSRL and ProSRGAN
                if self.residual_denseblock:
                    pyramid_residual['residual_denseblock_%d' %
                                     (b + 1)] = DenseResidualBlock(
                                         res_factor=res_factor,
                                         **denseblock_params)
                else:
                    block, num_features = self.create_denseblock(
                        denseblock_params,
                        with_compression=(b != len(block_config) - 1),
                        compression_rate=kwargs['block_compression'])
                    pyramid_residual['denseblock_%d' % (b + 1)] = block

            # conv before upsampling
            block, num_features = self.create_finalconv(
                num_features, max_num_feature)
            pyramid_residual['final_conv'] = block
            self.add_module('pyramid_residual_%d' % (i + 1),
                            nn.Sequential(pyramid_residual))

            # upsample the residual by 2 before reconstruction and next level
            self.add_module(
                'pyramid_residual_%d_residual_upsampler' % (i + 1),
                self.Upsampler(2, num_features, **self.upsample_args))

            # reconstruction convolutions
            reconst_branch = OrderedDict()
            out_channels = num_features
            reconst_branch['final_conv'] = Conv2d(out_channels, 3, 3)
            self.add_module('reconst_%d' % (i + 1),
                            nn.Sequential(reconst_branch))

        init_weights(self)

    def get_init_conv(self, idx):
        """choose which init_conv based on curr_scale_idx (1-based)"""
        return getattr(self, 'init_conv_%d' % idx)

    def forward(self, x, upscale_factor=None, blend=1.0):
        if upscale_factor is None:
            upscale_factor = self.max_scale
        else:
            valid_upscale_factors = [
                2**(i + 1) for i in range(self.n_pyramids)
            ]
            if upscale_factor not in valid_upscale_factors:
                error("Invalid upscaling factor {}: choose one of: {}".format(
                    upscale_factor, valid_upscale_factors))
                raise SystemExit(1)

        feats = self.get_init_conv(log2(upscale_factor))(x)
        for s in range(1, int(log2(upscale_factor)) + 1):
            if self.residual_denseblock:
                feats = getattr(self, 'pyramid_residual_%d' % s)(feats) + feats
            else:
                feats = getattr(self, 'pyramid_residual_%d' % s)(feats)
            feats = getattr(
                self, 'pyramid_residual_%d_residual_upsampler' % s)(feats)

            # reconst residual image if reached desired scale /
            # use intermediate as base_img / use blend and s is one step lower than desired scale
            if 2**s == upscale_factor or (blend != 1.0 and 2**
                                          (s + 1) == upscale_factor):
                tmp = getattr(self, 'reconst_%d' % s)(feats)
                # if using blend, upsample the second last feature via bilinear upsampling
                if (blend != 1.0 and s == self.current_scale_idx):
                    base_img = nn.functional.upsample(
                        tmp,
                        scale_factor=2,
                        mode='bilinear',
                        align_corners=True)
                if 2**s == upscale_factor:
                    if (blend != 1.0) and s == self.current_scale_idx + 1:
                        tmp = tmp * blend + (1 - blend) * base_img
                    output = tmp

        return output

    def create_denseblock(self,
                          denseblock_params,
                          with_compression=True,
                          compression_rate=0.5):
        block = OrderedDict()
        block['dense'] = self.DenseBlock(**denseblock_params)
        num_features = denseblock_params['num_input_features']
        num_features += denseblock_params['num_layers'] * denseblock_params['growth_rate']

        if with_compression:
            out_planes = num_features if compression_rate <= 0 else int(
                compression_rate * num_features)
            block['comp'] = CompressionBlock(
                in_planes=num_features, out_planes=out_planes)
            num_features = out_planes

        return nn.Sequential(block), num_features

    def create_finalconv(self, in_channels, max_channels=None):
        block = OrderedDict()
        if in_channels > max_channels:
            block['final_comp'] = CompressionBlock(in_channels, max_channels)
            block['final_conv'] = Conv2d(max_channels, max_channels, (3, 3))
            out_channels = max_channels
        else:
            block['final_conv'] = Conv2d(in_channels, in_channels, (3, 3))
            out_channels = in_channels
        return nn.Sequential(block), out_channels

    def class_name(self):
        return 'ProSR'


In [14]:
checkpoint = torch.load('proSR_x2.pth')
progansr = ProSR(**checkpoint['params']['G'])
progansr.load_state_dict(checkpoint['state_dict'])

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [15]:
progansr = progansr.to('cpu' if args['cpu'] else 'cuda')
progansr.eval()

ProSR(
  (init_conv_1): Conv2d(
    (conv): Sequential(
      (0): ReflectionPad2d((1, 1, 1, 1))
      (1): Conv2d(3, 160, kernel_size=(3, 3), stride=(1, 1))
    )
  )
  (init_conv_2): Conv2d(
    (conv): Sequential(
      (0): ReflectionPad2d((1, 1, 1, 1))
      (1): Conv2d(3, 160, kernel_size=(3, 3), stride=(1, 1))
    )
  )
  (init_conv_3): Conv2d(
    (conv): Sequential(
      (0): ReflectionPad2d((1, 1, 1, 1))
      (1): Conv2d(3, 160, kernel_size=(3, 3), stride=(1, 1))
    )
  )
  (pyramid_residual_1): Sequential(
    (residual_denseblock_1): DenseResidualBlock(
      (dense_block): _DenseBlock(
        (denselayer1): _DenseLayer(
          (conv_1): Conv2d(160, 160, kernel_size=(1, 1), stride=(1, 1))
          (relu_2): ReLU(inplace)
          (conv_2): Conv2d(
            (conv): Sequential(
              (0): ReflectionPad2d((1, 1, 1, 1))
              (1): Conv2d(160, 40, kernel_size=(3, 3), stride=(1, 1))
            )
          )
        )
        (denselayer2): _DenseLayer

In [16]:
params = checkpoint['params']

dataset = Dataset(
    'test',
    get_filenames('lr/', IMG_EXTENSIONS),
    [],
    2,
    input_size=None,
    mean=params['train']['dataset']['mean'],
    stddev=params['train']['dataset']['stddev'],
    downscale=False
)

data_loader = DataLoader(dataset, batch_size=1)

In [17]:
mean = params['train']['dataset']['mean']
stddev = params['train']['dataset']['stddev']

#### Super resolve images

In [18]:
with torch.no_grad():
    for iid, dt in enumerate(data_loader):
        input = dt['input']
        if not args['cpu']:
            input = input.cuda()
        
        start_time = time.time()
        output = progansr(input, args['scale']).cpu() + dt['bicubic']
        end_time = time.time()
        
        sr_img = tensor2im(output, mean, stddev)

        fn = osp.join('hr/', osp.basename(dt['input_fn'][0]))
        print(osp.basename(dt['input_fn'][0]))
        print('Time taken :', end_time - start_time, '\n')
        io.imsave(fn, sr_img)

119082.jpg
Time taken : 2.8227767944335938 

42012.jpg
Time taken : 2.6782026290893555 

0853x4.png
Time taken : 2.99727725982666 



### SRCNN

In [19]:
class SRCNN(torch.nn.Module):
    def __init__(self, num_channels, base_filter, upscale_factor=2):
        super(SRCNN, self).__init__()

        self.layers = torch.nn.Sequential(
            nn.Conv2d(in_channels=num_channels, out_channels=base_filter, kernel_size=9, stride=1, padding=4, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=base_filter, out_channels=base_filter // 2, kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=base_filter // 2, out_channels=num_channels * (upscale_factor ** 2), kernel_size=5, stride=1, padding=2, bias=True),
            nn.PixelShuffle(upscale_factor)
        )

    def forward(self, x):
        out = self.layers(x)
        return out

    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)


def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()


In [20]:
srcnn = SRCNN(1, 64)
srcnn.load_state_dict(torch.load("srcnn_x2.pth"))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [21]:
srcnn = srcnn.to('cpu' if args['cpu'] else 'cuda')
srcnn.eval()

SRCNN(
  (layers): Sequential(
    (0): Conv2d(1, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
    (1): ReLU(inplace)
    (2): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
    (3): ReLU(inplace)
    (4): Conv2d(32, 4, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): PixelShuffle(upscale_factor=2)
  )
)

#### Super resolve images

In [22]:
lr_imgs = glob.glob(args['inp_images'])
with torch.no_grad():
    for lr_img in lr_imgs:
        base = os.path.splitext(os.path.basename(lr_img))[0]
        lr_img = Image.open(lr_img).convert('YCbCr')
        y, cb, cr = lr_img.split()

        dt = (ToTensor()(y)).view(1, -1, y.size[1], y.size[0])
        dt = dt.to('cpu' if args['cpu'] else 'cuda')
        
        start_time = time.time()
        out = srcnn(dt)
        end_time = time.time()
        
        print(base)
        print('Time taken :', end_time - start_time, '\n')
        
        out = out.cpu()
        out_img_y = out.data[0].numpy()
        out_img_y *= 255.0
        out_img_y = out_img_y.clip(0, 255)
        out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L')

        out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
        out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
        out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')

        out_img.save('hr/{:s}_SRCNN.png'.format(base))

0853x4
Time taken : 0.0006754398345947266 

119082
Time taken : 0.0004901885986328125 

42012
Time taken : 0.0004978179931640625 



### FSRCNN

In [23]:
class FSRCNN(torch.nn.Module):
    def __init__(self, num_channels, upscale_factor, d=64, s=12, m=4):
        super(FSRCNN, self).__init__()

        self.first_part = nn.Sequential(nn.Conv2d(in_channels=num_channels, out_channels=d, kernel_size=5, stride=1, padding=2),
                                        nn.PReLU())

        self.layers = []
        self.layers.append(nn.Sequential(nn.Conv2d(in_channels=d, out_channels=s, kernel_size=1, stride=1, padding=0),
                                         nn.PReLU()))
        for _ in range(m):
            self.layers.append(nn.Conv2d(in_channels=s, out_channels=s, kernel_size=3, stride=1, padding=1))
        self.layers.append(nn.PReLU())
        self.layers.append(nn.Sequential(nn.Conv2d(in_channels=s, out_channels=d, kernel_size=1, stride=1, padding=0),
                                         nn.PReLU()))

        self.mid_part = torch.nn.Sequential(*self.layers)

        # Deconvolution
        self.last_part = nn.ConvTranspose2d(in_channels=d, out_channels=num_channels, kernel_size=9, stride=upscale_factor, padding=3, output_padding=1)

    def forward(self, x):
        out = self.first_part(x)
        out = self.mid_part(out)
        out = self.last_part(out)
        return out

    def weight_init(self, mean=0.0, std=0.02):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(mean, std)
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0.0, 0.0001)
                if m.bias is not None:
                    m.bias.data.zero_()


In [24]:
fsrcnn = FSRCNN(1, 4)
fsrcnn.load_state_dict(torch.load("fsrcnn_x4.pth"))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [25]:
fsrcnn = fsrcnn.to('cpu' if args['cpu'] else 'cuda')
fsrcnn.eval()

FSRCNN(
  (first_part): Sequential(
    (0): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): PReLU(num_parameters=1)
  )
  (mid_part): Sequential(
    (0): Sequential(
      (0): Conv2d(64, 12, kernel_size=(1, 1), stride=(1, 1))
      (1): PReLU(num_parameters=1)
    )
    (1): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): PReLU(num_parameters=1)
    (6): Sequential(
      (0): Conv2d(12, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): PReLU(num_parameters=1)
    )
  )
  (last_part): ConvTranspose2d(64, 1, kernel_size=(9, 9), stride=(4, 4), padding=(3, 3), output_padding=(1, 1))
)

#### Super resolve images

In [26]:
lr_imgs = glob.glob(args['inp_images'])
with torch.no_grad():
    for lr_img in lr_imgs:
        base = os.path.splitext(os.path.basename(lr_img))[0]
        lr_img = Image.open(lr_img).convert('YCbCr')
        y, cb, cr = lr_img.split()

        dt = (ToTensor()(y)).view(1, -1, y.size[1], y.size[0])
        dt = dt.to('cpu' if args['cpu'] else 'cuda')
        
        start_time = time.time()
        out = fsrcnn(dt)
        end_time = time.time()
        
        print(base)
        print('Time taken :', end_time - start_time, '\n')

        out = out.cpu()
        out_img_y = out.data[0].numpy()
        out_img_y *= 255.0
        out_img_y = out_img_y.clip(0, 255)
        out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L')

        out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
        out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
        out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')

        out_img.save('hr/{:s}_FSRCNN.png'.format(base))

0853x4
Time taken : 0.2557401657104492 

119082
Time taken : 0.0009124279022216797 

42012
Time taken : 0.0009877681732177734 



SRGAN

In [27]:
import torch.nn.functional as F

In [28]:
def swish(x):
    return x * F.sigmoid(x)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, kernel, out_channels, stride):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=kernel // 2)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel, stride=stride, padding=kernel // 2)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        y = swish(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(y)) + x


class UpsampleBlock(nn.Module):
    def __init__(self, in_channels):
        super(UpsampleBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * 4, kernel_size=3, stride=1, padding=1)
        self.shuffler = nn.PixelShuffle(2)

    def forward(self, x):
        return swish(self.shuffler(self.conv(x)))


class SRGANGenerator(nn.Module):
    def __init__(self, n_residual_blocks, upsample_factor, num_channel=1, base_filter=64):
        super(SRGANGenerator, self).__init__()
        self.n_residual_blocks = n_residual_blocks
        self.upsample_factor = upsample_factor

        self.conv1 = nn.Conv2d(num_channel, base_filter, kernel_size=9, stride=1, padding=4)

        for i in range(self.n_residual_blocks):
            self.add_module('residual_block' + str(i + 1), ResidualBlock(in_channels=base_filter, out_channels=base_filter, kernel=3, stride=1))

        self.conv2 = nn.Conv2d(base_filter, base_filter, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(base_filter)

        for i in range(self.upsample_factor // 2):
            self.add_module('upsample' + str(i + 1), UpsampleBlock(base_filter))

        self.conv3 = nn.Conv2d(base_filter, num_channel, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        x = swish(self.conv1(x))

        y = x.clone()
        for i in range(self.n_residual_blocks):
            y = self.__getattr__('residual_block' + str(i + 1))(y)

        x = self.bn2(self.conv2(y)) + x

        for i in range(self.upsample_factor // 2):
            x = self.__getattr__('upsample' + str(i + 1))(x)

        return self.conv3(x)

    def weight_init(self, mean=0.0, std=0.02):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()


In [29]:
srgan = SRGANGenerator(16, 4, )
srgan.load_state_dict(torch.load("srgan_x4.pth"))
srgan = srgan.to('cpu' if args['cpu'] else 'cuda')
srgan.eval()

SRGANGenerator(
  (conv1): Conv2d(1, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (residual_block1): ResidualBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (residual_block2): ResidualBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (residual_block3): ResidualBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e

In [31]:
lr_imgs = glob.glob(args['inp_images'])
with torch.no_grad():
    for lr_img in lr_imgs:
        base = os.path.splitext(os.path.basename(lr_img))[0]
        lr_img = Image.open(lr_img).convert('YCbCr')
        y, cb, cr = lr_img.split()

        dt = (ToTensor()(y)).view(1, -1, y.size[1], y.size[0])
        dt = dt.to('cpu' if args['cpu'] else 'cuda')
        
        start_time = time.time()
        out = srgan(dt)
        end_time = time.time()
        
        print(base)
        print('Time taken :', end_time - start_time, '\n')

        out_img_y = out.cpu().data[0].numpy()
        out_img_y *= 255.0
        out_img_y = out_img_y.clip(0, 255)
        out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L')

        out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
        out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
        out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')

        out_img.save('hr/{:s}_SRGAN.png'.format(base))

0853x4
Time taken : 0.007898569107055664 

119082
Time taken : 0.006199359893798828 

42012
Time taken : 0.006172657012939453 



EDSR

In [32]:
import math

In [33]:
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)

class MeanShift(nn.Conv2d):
    def __init__(
        self, rgb_range,
        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False

class BasicBlock(nn.Sequential):
    def __init__(
        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
        bn=True, act=nn.ReLU(True)):

        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)

        super(BasicBlock, self).__init__(*m)

class ResBlock(nn.Module):
    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res

class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feats))
                if act == 'relu':
                    m.append(nn.ReLU(True))
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))

        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)


class EDSR(nn.Module):
    def __init__(self, conv=default_conv):
        super(EDSR, self).__init__()

        n_resblocks = 32
        n_feats = 256
        kernel_size = 3 
        scale = 2
        act = nn.ReLU(True)
        self.sub_mean = MeanShift(255)
        self.add_mean = MeanShift(255, sign=1)

        # define head module
        m_head = [conv(3, n_feats, kernel_size)]

        # define body module
        m_body = [
            ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=0.1
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, 3, kernel_size)
        ]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)

        return x 


In [34]:
edsr = EDSR().cuda()

In [35]:
edsr.load_state_dict(torch.load('edsr_x2-0edfb8a3.pt'))
edsr.eval()

EDSR(
  (sub_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (add_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (head): Sequential(
    (0): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (body): Sequential(
    (0): ResBlock(
      (body): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (1): ResBlock(
      (body): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (2): ResBlock(
      (body): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
   

In [38]:
lr_imgs = glob.glob(args['inp_images'])
with torch.no_grad():
    for lr_img in lr_imgs:
        base = os.path.splitext(os.path.basename(lr_img))[0]
        y = Image.open(lr_img).convert('RGB')
    
        dt = (ToTensor()(y)).view(1, -1, y.size[1], y.size[0])
        dt = dt.to('cpu' if args['cpu'] else 'cuda')
        
        start_time = time.time()
        hr_img = edsr(dt)
        end_time = time.time()
        
        print(base)
        print('Time taken :', end_time - start_time, '\n')
        
        out = hr_img.cpu()
        out_img = out.data[0].numpy()
        out_img *= 255.0
        out_img = out_img.clip(0, 255)
        out_img = np.transpose(out_img[[0, 1, 2], :, :], (1, 2, 0))
        out_img = Image.fromarray(np.uint8(out_img), 'RGB')
        out_img.save('hr/{:s}_EDSR.png'.format(base))

0853x4
Time taken : 0.008395195007324219 

119082
Time taken : 0.00790858268737793 

42012
Time taken : 0.008928775787353516 

