In [1]:
import glob
import os
import sys
import warnings

import numpy as np
import torch
from PIL import Image

In [2]:
# Source images for the training patches. The pattern is relative to the
# notebook's working directory; sorting makes patch order deterministic.
T91_GLOB = '../T91/*.png'
data_name_list = sorted(glob.glob(T91_GLOB))
if not data_name_list:
    # An empty list would silently yield zero training patches downstream.
    warnings.warn(f'No files matched {T91_GLOB!r}; training data will be empty.')

In [3]:
class GetTrainData(object):
    """Cut grayscale images into paired low-/high-resolution training patches.

    Each image is converted to 8-bit grayscale ('L') and tiled with square
    high-resolution (HR) patches. Each HR patch is bicubically downsampled
    to produce its low-resolution (LR) counterpart.

    Args:
        data_path_list: Sequence of image file paths readable by PIL.
    """

    def __init__(self, data_path_list):
        self.data_path_list = data_path_list

    def get_data(self, patch_size=128, stride=64, scale=2):
        """Extract aligned LR/HR patch pairs from every image in the list.

        Images smaller than ``patch_size`` in either dimension are skipped.

        Args:
            patch_size: Side length of each HR patch (default 128).
            stride: Step between neighbouring patch origins (default 64).
            scale: Downsampling factor from HR to LR patch (default 2).

        Returns:
            Tuple ``(X, Y)`` of float arrays divided by 255:
            ``X`` has shape (N, patch_size // scale, patch_size // scale, 1)
            and holds the LR inputs; ``Y`` has shape
            (N, patch_size, patch_size, 1) and holds the HR targets.
            When no patch is extracted, N is 0 and the trailing
            dimensions are still present.
        """
        lr_size = patch_size // scale
        X, Y = [], []
        for path in self.data_path_list:
            # Context manager closes the underlying file deterministically;
            # convert() returns an independent in-memory image.
            with Image.open(path) as src:
                img = src.convert('L')
            # PIL reports (width, height); the numpy array is (height, width).
            width, height = img.size
            if width < patch_size or height < patch_size:
                continue
            img_array = np.array(img).astype(np.uint8)
            for top in range(0, height - patch_size + 1, stride):
                for left in range(0, width - patch_size + 1, stride):
                    hr_patch = img_array[top:top + patch_size, left:left + patch_size]
                    Y.append(hr_patch.reshape(patch_size, patch_size, 1))
                    lr_img = Image.fromarray(hr_patch).resize((lr_size, lr_size), Image.BICUBIC)
                    lr_patch = np.array(lr_img).astype(np.uint8)
                    X.append(lr_patch.reshape(lr_size, lr_size, 1))
        # reshape keeps the 4-D layout even when no patches were collected.
        X = np.array(X, dtype=np.uint8).reshape(-1, lr_size, lr_size, 1)
        Y = np.array(Y, dtype=np.uint8).reshape(-1, patch_size, patch_size, 1)
        return X / 255.0, Y / 255.0

In [4]:
# Paired training arrays: X holds the LR patches, Y the matching HR patches.
X, Y = GetTrainData(data_path_list=data_name_list).get_data()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

In [None]:
import math
class Upsampler(nn.Sequential):
    """Sub-pixel (PixelShuffle) upsampling stack.

    Power-of-two scales are built from repeated x2 stages; scale 3 uses a
    single x3 stage. Each stage is ``conv -> PixelShuffle``, optionally
    followed by BatchNorm and an activation.

    Args:
        conv: Factory called as ``conv(in_channels, out_channels, 3, bias)``.
        scale: Total upsampling factor; a positive power of two, or 3.
            Scale 1 produces no stages (an empty Sequential).
        n_feat: Feature channels entering and leaving every stage.
        bn: If truthy, append ``nn.BatchNorm2d(n_feat)`` after each shuffle.
        act: If truthy, it is called with no arguments and the result is
            appended after each stage (e.g. pass a class such as ``nn.ReLU``).
        bias: Forwarded to ``conv``.

    Raises:
        NotImplementedError: If ``scale`` is not a positive power of two or 3.
    """

    def __init__(self, conv, scale, n_feat, bn = False, act = False, bias = True):
        if scale > 0 and (scale & (scale - 1)) == 0:
            # Exact integer log2 for powers of two (no float rounding).
            factors = [2] * (int(scale).bit_length() - 1)
        elif scale == 3:
            factors = [3]
        else:
            raise NotImplementedError
        m = []
        for factor in factors:
            # Expand channels by factor**2 so PixelShuffle can fold them into a
            # factor-times larger spatial grid with n_feat channels.
            m.append(conv(n_feat, factor * factor * n_feat, 3, bias))
            m.append(nn.PixelShuffle(factor))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if act:
                m.append(act())

        super(Upsampler, self).__init__(*m)

In [None]:
class ResidualGroup(nn.Module):
    """``n_resblocks`` RCAB blocks plus one trailing conv, with a skip connection.

    Args:
        conv: Conv factory called as ``conv(n_feat, n_feat, kernel_size)``.
        n_feat: Channel count used for every block and the trailing conv.
        kernel_size: Kernel size for every conv.
        reduction: Forwarded to each RCAB (presumably the channel-attention
            reduction ratio; RCAB is not visible here — confirm).
        act: Activation forwarded to each RCAB.
        res_scale: Residual scaling forwarded to each RCAB.
        n_resblocks: Number of RCAB blocks in the group.
    """

    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
        super().__init__()
        # NOTE(review): RCAB is not defined in any cell of this notebook; it must
        # be defined before this class is instantiated.
        modules_body = [
            RCAB(conv, n_feat, kernel_size, reduction, bias = True, bn = False, act = act, res_scale = res_scale)
            for _ in range(n_resblocks)
        ]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        """Return ``self.body(x) + x``."""
        res = self.body(x)
        res += x
        return res

In [None]:
def default_conv(in_channels, out_channels, kernel_size, bias = True):
    """Create an ``nn.Conv2d`` whose padding is ``kernel_size // 2``.

    With the default stride of 1 and an odd kernel size, the output keeps the
    input's spatial size.
    """
    same_padding = kernel_size // 2
    conv_layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=same_padding,
        bias=bias,
    )
    return conv_layer

class RCAN(nn.Module):
    """RCAN-style upscaling network.

    Pipeline: head conv -> ``n_resgroups`` ResidualGroups and one conv
    (wrapped in a long skip connection) -> Upsampler -> conv to ``n_colors``.

    Args:
        n_colors: Channels of the input and output images (e.g. 1 for grayscale).
        n_resgroups: Number of ResidualGroup modules in the body.
        n_resblocks: Number of residual blocks inside each ResidualGroup.
        n_feats: Feature channels used throughout head, body and tail.
        reduction: Forwarded to every ResidualGroup.
        scale: Upsampling factor forwarded to Upsampler.
        res_scale: Forwarded to every ResidualGroup.
        conv: Conv factory ``conv(in_channels, out_channels, kernel_size)``;
            defaults to ``default_conv``.
    """

    def __init__(self, n_colors, n_resgroups, n_resblocks, n_feats, reduction, scale, res_scale, conv = default_conv):
        super().__init__()

        kernel_size = 3
        act = nn.ReLU(True)

        modules_head = [conv(n_colors, n_feats, kernel_size)]
        # The group count is the ``n_resgroups`` parameter.
        modules_body = [
            ResidualGroup(conv, n_feats, kernel_size, reduction, act = act, res_scale = res_scale, n_resblocks = n_resblocks)
            for _ in range(n_resgroups)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))

        modules_tail = [
            Upsampler(conv, scale, n_feats, act = False),
            conv(n_feats, n_colors, kernel_size),
        ]

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        """Run head, residual body with a long skip, then the upsampling tail."""
        x = self.head(x)

        res = self.body(x)
        # Long skip connection around the entire body.
        res += x

        return self.tail(res)