In [2]:
import math
from math import sqrt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import random
import torch.optim as optim
import torchvision.utils as vutils


In [3]:
from google.colab import drive

# Mount Google Drive at the standard Colab location.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


Mounted at /content/drive


In [None]:
!unzip 'datasets.zip'


# Define Model 

## EDSR

In [5]:
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Create a stride-1 ``nn.Conv2d`` padded by ``kernel_size // 2``.

    For odd kernel sizes this keeps the spatial size of the input unchanged.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: square kernel size.
        bias: whether the convolution has a learnable bias.
    """
    half_kernel = kernel_size // 2
    layer = nn.Conv2d(in_channels, out_channels, kernel_size,
                      padding=half_kernel, bias=bias)
    return layer


class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution that shifts each RGB channel by a scaled mean.

    Output channel ``c`` is ``x_c / std_c + sign * rgb_range * mean_c / std_c``.
    With the default ``sign=-1`` the mean is subtracted; with ``sign=1`` it is
    added back. All parameters have ``requires_grad=False``.

    Args:
        rgb_range: scale factor applied to ``rgb_mean`` (the pixel value range
            the layer operates on).
        rgb_mean: per-channel mean, as fractions of ``rgb_range``.
        rgb_std: per-channel divisor applied to both the input and the shift.
        sign: -1 to subtract the mean, +1 to add it.
    """

    def __init__(self, rgb_range, rgb_mean=(0.4488, 0.4371, 0.4040),
                 rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # Diagonal kernel: each output channel only sees its own input channel.
        kernel = torch.eye(3).div(std.view(3, 1)).view(3, 3, 1, 1)
        shift = sign * rgb_range * torch.Tensor(rgb_mean) / std
        with torch.no_grad():
            self.weight.copy_(kernel)
            self.bias.copy_(shift)
        # A fixed preprocessing step: never updated by the optimizer.
        self.requires_grad_(False)


class ResBlock(nn.Module):
    """Residual block computing ``x + res_scale * body(x)``.

    ``body`` is ``conv -> [BatchNorm2d] -> act -> conv -> [BatchNorm2d]``; the
    channel count ``n_feats`` is preserved.

    Args:
        conv: factory called as ``conv(in_ch, out_ch, kernel_size, bias=...)``.
        n_feats: number of input and output channels.
        kernel_size: kernel size passed to ``conv``.
        bias: passed to each conv.
        bn: if true, insert ``nn.BatchNorm2d`` after each conv.
        act: activation module placed after the first conv. ``None`` (the
            default) creates a fresh ``nn.ReLU(inplace=True)`` for this block.
        res_scale: multiplier applied to the residual branch.
    """

    def __init__(self, conv, n_feats, kernel_size,
                 bias=True, bn=False, act=None, res_scale=1):
        super(ResBlock, self).__init__()
        if act is None:
            # Build the default per instance. A module used as a default
            # argument is created once at class definition and would be shared
            # (registered as a child) by every ResBlock relying on it.
            act = nn.ReLU(True)

        layers = []
        for i in range(2):
            layers.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                layers.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                layers.append(act)

        # Layer order (and hence the body.N.* state_dict keys) is unchanged.
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``x + res_scale * body(x)``."""
        res = self.body(x).mul(self.res_scale)
        res += x
        return res


class Upsampler(nn.Sequential):
    """Sub-pixel (PixelShuffle) upsampling module.

    Each stage is ``conv(n_feats -> r*r*n_feats) -> PixelShuffle(r)``, optionally
    followed by ``BatchNorm2d`` and a ReLU/PReLU activation. Each stage keeps
    the channel count and multiplies height and width by ``r``.

    Supported scales:
        * powers of two: one x2 stage per factor of two (scale 1 gives an
          empty, identity module);
        * 3: a single x3 stage.

    Args:
        conv: factory called as ``conv(in_ch, out_ch, kernel_size, bias)``.
        scale: total upsampling factor.
        n_feats: number of feature channels in and out.
        bn: if true, add ``nn.BatchNorm2d`` after each PixelShuffle.
        act: ``'relu'``, ``'prelu'``, or anything else for no activation.
        bias: passed to each conv.

    Raises:
        NotImplementedError: for any other scale (including scale < 1).
    """

    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
        if scale >= 1 and (scale & (scale - 1)) == 0:    # Is scale = 2^n?
            # bit_length() gives an exact integer log2. A float log followed by
            # int() truncation can silently drop a stage on rounding error.
            stage_factors = [2] * (int(scale).bit_length() - 1)
        elif scale == 3:
            stage_factors = [3]
        else:
            raise NotImplementedError(
                f"Upsampler supports scale 2^n or 3, got {scale}")

        layers = []
        for r in stage_factors:
            layers.extend(self._stage(conv, r, n_feats, bn, act, bias))

        super(Upsampler, self).__init__(*layers)

    @staticmethod
    def _stage(conv, r, n_feats, bn, act, bias):
        """Build the layers for one x``r`` stage, in checkpoint-compatible order."""
        layers = [conv(n_feats, r * r * n_feats, 3, bias), nn.PixelShuffle(r)]
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))
        if act == 'relu':
            layers.append(nn.ReLU(True))
        elif act == 'prelu':
            layers.append(nn.PReLU(n_feats))
        return layers


class EDSR(nn.Module):
    """EDSR single-image super-resolution network.

    Pipeline: ``sub_mean`` -> ``head`` (one conv to ``n_feats`` channels) ->
    ``body`` (``n_resblocks`` ResBlocks plus a conv, wrapped in a global skip
    connection) -> ``tail`` (Upsampler by ``scale``, then a conv back to 3
    channels) -> ``add_mean``.

    Args:
        n_resblocks: number of ResBlocks in the body.
        n_feats: number of feature channels inside the network.
        scale: upsampling factor handed to ``Upsampler``.
        res_scale: multiplier on each ResBlock's residual branch.
        pretrained: accepted but never read by ``__init__``; no weights are
            loaded here, so restore a state dict separately.

    NOTE: the order in which submodules are built below fixes the random
    initialisation order. The attribute names plus list order fix the
    ``state_dict`` keys (``sub_mean``, ``add_mean``, ``head.*``, ``body.*``,
    ``tail.*``) that saved checkpoints are matched against.
    """
    def __init__(self, n_resblocks, n_feats, scale, res_scale,
                 pretrained=False):
        super(EDSR, self).__init__()
        self.scale = scale

        kernel_size = 3
        n_colors = 3
        # Pixel values are handled on a 0-255 scale inside the network
        # (forward multiplies its input by 255 and divides the output by 255).
        rgb_range = 255
        conv = default_conv
        # A single ReLU instance, shared by every ResBlock built below.
        act = nn.ReLU(True)
        # Fixed (non-trainable) mean subtraction / re-addition layers.
        self.sub_mean = MeanShift(rgb_range)
        self.add_mean = MeanShift(rgb_range, sign=1)

        # define head module
        m_head = [conv(n_colors, n_feats, kernel_size)]

        # define body module
        m_body = [
            ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=res_scale
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module (upsampling without activation, then back to RGB)
        m_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, n_colors, kernel_size)
        ]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x, scale=None):
        """Super-resolve ``x`` by ``self.scale``.

        Args:
            x: image batch; presumably (N, 3, H, W) with values in [0, 1]
                (it is multiplied by 255 before the mean shift). TODO confirm.
            scale: optional sanity check; if given it must equal ``self.scale``.

        Returns:
            The upsampled batch, divided back by 255, spatially ``self.scale``
            times larger than ``x``. Values are not clamped.

        Raises:
            ValueError: if ``scale`` is given and differs from ``self.scale``.
        """
        if scale is not None and scale != self.scale:
            raise ValueError(f"Network scale is {self.scale}, not {scale}")
        x = self.sub_mean(255 * x)
        x = self.head(x)

        # Global residual: skip connection from the head output around the body.
        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x) / 255

        return x


def edsr_r16f64(scale, pretrained=False):
    """Build an EDSR with 16 residual blocks, 64 feature channels, res_scale 1.0."""
    return EDSR(n_resblocks=16, n_feats=64, scale=scale, res_scale=1.0,
                pretrained=pretrained)


def edsr_r32f256(scale, pretrained=False):
    """Build an EDSR with 32 residual blocks, 256 feature channels, res_scale 0.1."""
    return EDSR(n_resblocks=32, n_feats=256, scale=scale, res_scale=0.1,
                pretrained=pretrained)


def edsr_baseline(scale, pretrained=False):
    """Baseline EDSR configuration; an alias for ``edsr_r16f64``."""
    return edsr_r16f64(scale=scale, pretrained=pretrained)


def edsr(scale, pretrained=False):
    """Full-size EDSR configuration; an alias for ``edsr_r32f256``."""
    return edsr_r32f256(scale=scale, pretrained=pretrained)


# Set up the device and model

In [6]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


In [7]:
# x3 EDSR baseline on `device`; its weights are restored from a checkpoint later.
model = edsr_baseline(pretrained=False, scale=3).to(device=device)


# Inference

In [8]:
# map_location lets a checkpoint saved in a GPU session load on a CPU-only
# runtime (and vice versa). torch.load unpickles the file, so only load
# checkpoints from a trusted source.
model.load_state_dict(torch.load('SR_model.pkl', map_location=device))


<All keys matched successfully>

In [9]:
dataroot = '/content/datasets/testing_lr_images'
# Keep only regular, non-hidden files. Jupyter/Colab can create an
# `.ipynb_checkpoints` directory here and macOS-made zips may add `.DS_Store`;
# opening either as an image would crash the inference loop.
image_filenames = sorted(
    name for name in os.listdir(dataroot)
    if not name.startswith('.')
    and os.path.isfile(os.path.join(dataroot, name))
)

# ToTensor maps a PIL image to a float tensor in [0, 1]. No Normalize step is
# applied, because EDSR.forward does its own 255-scaling and mean shift.
trans_tensor = transforms.Compose([transforms.ToTensor()])

trans_img = transforms.ToPILImage()


In [10]:
def quantize(img, rgb_range=1):
    """Snap ``img`` onto the 8-bit grid of its value range.

    Values are scaled so ``rgb_range`` maps to 255, clamped to [0, 255],
    rounded to whole levels, then scaled back. With the default
    ``rgb_range=1`` the result lies in [0, 1].
    """
    levels_per_unit = 255 / rgb_range
    levels = (img * levels_per_unit).clamp(min=0, max=255).round()
    return levels / levels_per_unit


In [11]:
model.eval()

with torch.no_grad():
    for f in image_filenames:
        path = os.path.join(dataroot, f)
        # splitext handles any extension length; f[:-4] mangled e.g. '.jpeg'.
        out_name = os.path.splitext(f)[0] + '_pred.png'
        # Close the file promptly, and force 3 channels: the model's head conv
        # takes RGB input, so grayscale/RGBA/palette images would otherwise fail.
        with Image.open(path) as pil_img:
            img = trans_tensor(pil_img.convert('RGB'))
        pred = model(img.unsqueeze(0).to(device))
        # Clamp to [0, 1] and snap to the 8-bit grid.
        pred = quantize(pred)
        output_path = os.path.join('./', out_name)
        # Save as-is: normalize=True would min-max stretch each prediction and
        # distort the reconstructed intensities.
        vutils.save_image(pred, output_path)
