In [1]:
import sys
import os
import glob
import numpy as np
import torch
from PIL import Image

In [2]:
# Directory holding the training PNGs (T91 set); the path is relative to the
# notebook's working directory.
DATA_DIR = '../T91'
data_name_list = sorted(glob.glob(os.path.join(DATA_DIR, '*.png')))
# Fail fast: an empty list would otherwise yield an empty training set that
# only breaks several cells later (at the NHWC -> NCHW transpose).
assert data_name_list, f'No .png files found in {DATA_DIR!r}'

In [3]:
class GetTrainData(object):
    """Cut grayscale images into (LR, HR) training patch pairs.

    Each image is converted to 8-bit grayscale ('L') and tiled into
    ``patch_size`` x ``patch_size`` HR patches with step ``stride`` in both
    directions. Each HR patch is bicubic-downsampled by ``scale`` to form its
    LR counterpart. Images smaller than one patch in either dimension are
    skipped.

    Args:
        data_path_list: iterable of image file paths.
        patch_size: side length of the square HR patches (default 128).
        stride: step between neighbouring patches (default 64).
        scale: downsampling factor from HR to LR (default 2).
    """

    def __init__(self, data_path_list, patch_size=128, stride=64, scale=2):
        self.data_path_list = data_path_list
        self.patch_size = patch_size
        self.stride = stride
        self.scale = scale

    def get_data(self):
        """Return ``(X, Y)`` as float arrays scaled to [0, 1].

        X holds the LR patches, with shape
        (N, patch_size // scale, patch_size // scale, 1).
        Y holds the HR patches, with shape (N, patch_size, patch_size, 1).
        When no image yields a patch, N == 0 but the trailing dimensions are
        kept, so downstream transposes still work.
        """
        hr_size = self.patch_size
        lr_size = hr_size // self.scale
        X, Y = [], []
        for path in self.data_path_list:
            # Close the underlying file as soon as the grayscale copy exists.
            with Image.open(path) as src:
                img = src.convert('L')
            # PIL reports (width, height); numpy indexes (row, col) = (height, width).
            width, height = img.size
            if width < hr_size or height < hr_size:
                continue
            img_array = np.array(img).astype(np.uint8)
            for top in range(0, height - hr_size + 1, self.stride):
                for left in range(0, width - hr_size + 1, self.stride):
                    hr_patch = img_array[top:top + hr_size, left:left + hr_size]
                    Y.append(hr_patch.reshape(hr_size, hr_size, 1))
                    lr_img = Image.fromarray(hr_patch).resize((lr_size, lr_size), Image.BICUBIC)
                    lr_patch = np.array(lr_img).astype(np.uint8)
                    X.append(lr_patch.reshape(lr_size, lr_size, 1))
        # reshape keeps a 4-D layout even when no patches were collected.
        X = np.array(X, dtype=np.uint8).reshape(-1, lr_size, lr_size, 1)
        Y = np.array(Y, dtype=np.uint8).reshape(-1, hr_size, hr_size, 1)
        return X / 255.0, Y / 255.0

In [4]:
# Build (LR, HR) patch pairs from every training image; values are scaled to [0, 1].
X, Y = GetTrainData(data_path_list=data_name_list).get_data()

In [5]:
# nn.Conv2d consumes (N, C, H, W), so move the channel axis of the
# (N, H, W, 1) patch arrays to position 1. Each transpose is guarded so that
# re-running this cell leaves already-converted arrays untouched.
print(f'LR image size = {X.shape}')
print(f'HR image size = {Y.shape}')
if X.ndim == 4 and X.shape[-1] == 1:
    X = X.transpose(0, 3, 1, 2)
if Y.ndim == 4 and Y.shape[-1] == 1:
    Y = Y.transpose(0, 3, 1, 2)
print(f'LR image size = {X.shape}')
print(f'HR image size = {Y.shape}')

LR image size = (501, 64, 64, 1)
HR image size = (501, 128, 128, 1)
LR image size = (501, 1, 64, 64)
HR image size = (501, 1, 128, 128)


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

In [7]:
import math
class Upsampler(nn.Sequential):
    """Sub-pixel (PixelShuffle) upsampling stack.

    Power-of-two scales are built from repeated x2 stages; scale 3 uses a
    single x3 stage. Each stage is ``conv -> PixelShuffle``, optionally
    followed by BatchNorm2d and an activation. Scale 1 yields an empty
    (identity) stack.

    Args:
        conv: factory called as ``conv(in_channels, out_channels, kernel_size, bias)``
            that returns a module.
        scale: integer upscaling factor (a power of two >= 1, or 3).
        n_feat: number of feature channels, both in and out.
        bn: append ``nn.BatchNorm2d(n_feat)`` after each stage.
        act: falsy for no activation; otherwise a zero-argument callable
            (e.g. the class ``nn.ReLU``) returning the activation module.
        bias: forwarded to ``conv``.

    Raises:
        NotImplementedError: for any other scale (including 0 and negatives).
    """

    def __init__(self, conv, scale, n_feat, bn = False, act = False, bias = True):
        if scale >= 1 and (scale & (scale - 1)) == 0:
            # Integer log2 for powers of two: avoids float rounding in math.log.
            factors = [2] * (int(scale).bit_length() - 1)
        elif scale == 3:
            factors = [3]
        else:
            raise NotImplementedError

        m = []
        for factor in factors:
            # PixelShuffle(r) folds r*r channels into an r x r spatial block.
            m.append(conv(n_feat, factor * factor * n_feat, 3, bias))
            m.append(nn.PixelShuffle(factor))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if act:
                m.append(act())

        super(Upsampler, self).__init__(*m)

In [8]:
class CALayer(nn.Module):
    """Channel attention: rescale each channel of ``x`` by a learned gate in (0, 1).

    The gate is computed by global average pooling, then a 1x1-conv
    bottleneck (``channel -> channel // reduction -> channel``) with a ReLU
    in between, and finally a sigmoid.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        squeeze = nn.Conv2d(channel, hidden, 1, padding=0, bias=True)
        excite = nn.Conv2d(hidden, channel, 1, padding=0, bias=True)
        self.conv_du = nn.Sequential(squeeze, nn.ReLU(inplace=True), excite, nn.Sigmoid())

    def forward(self, x):
        # Pooling to 1x1 and the 1x1 convs give a (N, C, 1, 1) gate that
        # broadcasts over the spatial dimensions of x.
        gate = self.conv_du(self.avg_pool(x))
        return x * gate

In [9]:
class RCAB(nn.Module):
    """Residual Channel Attention Block.

    The body is ``conv -> [BN] -> act -> conv -> [BN] -> CALayer``. The output
    is ``body(x) * res_scale + x``.

    Args:
        conv: conv factory called as ``conv(in, out, kernel_size, bias=...)``.
        n_feat: channel count (input and output).
        kernel_size: conv kernel size.
        reduction: channel reduction ratio passed to CALayer.
        bias: whether the convs use a bias.
        bn: insert BatchNorm2d after each conv.
        act: activation module placed after the first conv. The default is a
            single nn.ReLU instance created when the class is defined, so all
            RCABs built with the default share it (harmless: ReLU has no
            parameters).
        res_scale: multiplier applied to the residual branch before the skip add.
    """

    def __init__(self, conv, n_feat, kernel_size, reduction, bias = True, bn = False, act = nn.ReLU(True), res_scale=1):
        super(RCAB, self).__init__()

        modules_body = []
        for i in range(2):
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias = bias))
            if bn:
                modules_body.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                # Activation only between the two convs.
                modules_body.append(act)
        modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        # Honour res_scale; skipped at the default of 1 so that path is unchanged.
        if self.res_scale != 1:
            res = res * self.res_scale
        res += x
        return res

In [10]:
class ResidualGroup(nn.Module):
    """Residual group: ``n_resblocks`` RCAB blocks plus a trailing conv, wrapped in one skip connection.

    Args:
        conv: conv factory called as ``conv(in, out, kernel_size)``.
        n_feat: channel count throughout the group.
        kernel_size: conv kernel size.
        reduction: channel-attention reduction ratio for each RCAB.
        act: activation module for the RCABs. Each block receives its own deep
            copy, so stateful activations are not tied across blocks. Anything
            that is not an ``nn.Module`` falls back to ``nn.ReLU(True)``.
        res_scale: residual scaling forwarded to every RCAB.
        n_resblocks: number of RCAB blocks.
    """

    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
        super().__init__()
        import copy  # local import keeps this block independent of file-level imports

        def block_act():
            # A fresh activation per block, honouring the caller's choice.
            return copy.deepcopy(act) if isinstance(act, nn.Module) else nn.ReLU(True)

        modules_body = [
            RCAB(conv, n_feat, kernel_size, reduction, bias = True, bn = False, act = block_act(), res_scale = res_scale)
            for _ in range(n_resblocks)
        ]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

In [11]:
def default_conv(in_channels, out_channels, kernel_size, bias = True):
    """Build a stride-1 ``nn.Conv2d`` padded by ``kernel_size // 2``.

    For odd kernel sizes this preserves the spatial size of the input.
    """
    same_pad = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=same_pad,
        bias=bias,
    )

class Model(nn.Module):
    """RCAN-style super-resolution network.

    The pipeline is: head conv -> body (``n_resgroups`` ResidualGroups plus a
    conv, with a long skip connection) -> tail (Upsampler by ``scale`` and a
    conv back to ``n_colors`` channels). With ``default_conv`` the spatial size
    is preserved everywhere except in the Upsampler.
    """

    def __init__(self, n_colors, n_resgroups, n_resblocks, n_feats, reduction,  scale, res_scale, conv = default_conv):
        super().__init__()

        ksize = 3
        activation = nn.ReLU(True)

        head_layers = [conv(n_colors, n_feats, ksize)]

        body_layers = []
        for _ in range(n_resgroups):
            body_layers.append(
                ResidualGroup(conv, n_feats, ksize, reduction, act = activation, res_scale = res_scale, n_resblocks = n_resblocks)
            )
        body_layers.append(conv(n_feats, n_feats, ksize))

        tail_layers = [Upsampler(conv, scale, n_feats, act = False), conv(n_feats, n_colors, ksize)]

        self.head = nn.Sequential(*head_layers)
        self.body = nn.Sequential(*body_layers)
        self.tail = nn.Sequential(*tail_layers)

    def forward(self, x):
        shallow = self.head(x)
        # Long skip: the body learns a residual on top of the shallow features.
        deep = self.body(shallow)
        deep += shallow
        return self.tail(deep)

In [12]:
import torch.utils.data
class Mydatasets(torch.utils.data.Dataset):
    """Map-style Dataset over numpy arrays, optionally paired with targets.

    Indexing returns ``torch.from_numpy`` views of ``x[idx]`` (and
    ``y[idx]`` when targets were given). Length is ``x.shape[0]``.
    """

    def __init__(self, x, y=None):
        self.data = x
        self.label = None if y is None else y
        self.datanum = x.shape[0]

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        sample = torch.from_numpy(self.data[idx])
        if self.label is None:
            return sample
        return sample, torch.from_numpy(self.label[idx])

In [13]:
import time

def weights_init(m):
    """Initialise ``m`` in place if it is an ``nn.Conv2d``; other modules are left untouched.

    Conv weights get Kaiming (He) uniform init with the ReLU gain, and biases
    (when present) are zeroed. Intended for ``model.apply(weights_init)``.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
    if m.bias is not None:
        nn.init.zeros_(m.bias)
            
            
class RCAN(object):
    """Training / inference wrapper around ``Model`` for single-channel images.

    Args:
        MAX_EPOCH: number of training epochs run by ``fit``.
        BATCH_SIZE: mini-batch size used by ``fit``.
        lr: initial Adam learning rate.
        upscale: super-resolution factor, passed to ``Model`` as ``scale``.
        d, s, m: unused; kept so existing call sites keep working.
    """

    def __init__(self, MAX_EPOCH = 100, BATCH_SIZE = 32, lr = 0.00001, upscale=2, d=48, s=12, m=2):
        self.MAX_EPOCH = MAX_EPOCH
        self.BATCH_SIZE = BATCH_SIZE
        self.lr = lr
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.model = Model(n_colors=1, n_resgroups=2, n_resblocks=2, n_feats=64, reduction=2, scale=upscale, res_scale=1, conv = default_conv)
        self.model.apply(weights_init)
        self.model = self.model.to(self.device)

        self.optimizer = optim.Adam(self.model.parameters(), lr = self.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False)
        self.criterion = nn.MSELoss()
        # Learning rate is multiplied by 0.01 every 100 scheduler steps (one step per epoch in fit).
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size = 100, gamma = 0.01)

    def fit(self, X, Y):
        """Train with MSE loss on numpy arrays ``X`` (inputs) and ``Y`` (targets).

        Both arrays must be laid out as the 1-channel model expects, i.e.
        (N, 1, H, W). The mean per-batch loss of each epoch is printed and
        appended to ``self.loss_list``.

        Raises:
            ValueError: if ``X`` contains no samples.
        """
        dataset = Mydatasets(X, Y)
        if len(dataset) == 0:
            raise ValueError('fit() needs at least one training sample')
        # Built once: shuffle=True already reshuffles on every pass over the loader.
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers = 2)
        # transform() switches to eval mode; make sure we train in train mode.
        self.model.train()
        self.loss_list = []
        for epoch in range(self.MAX_EPOCH):
            running_loss = 0.0
            for inputs, labels in loader:
                inputs = inputs.float().to(self.device)
                labels = labels.float().to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                running_loss += loss.item()
            epoch_loss = running_loss / len(loader)
            print(f'Epoch[{epoch+1}/{self.MAX_EPOCH}]  loss : {epoch_loss}')
            self.loss_list.append(epoch_loss)
            self.scheduler.step()
        print(f'Finish training...')

    def transform(self, X):
        """Run the model on each sample of ``X`` one at a time.

        Returns a numpy array stacking the per-sample outputs (batch axis
        removed) and prints the mean forward time per sample. The timing
        includes the device-to-host copy.
        """
        self.model.eval()
        dataset = Mydatasets(X)
        loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
        transformed_X = []
        total_time = 0.0
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for inputs in loader:
                inputs = inputs.float().to(self.device)
                start = time.perf_counter()
                outputs = self.model(inputs).squeeze(0).cpu().numpy()
                total_time += time.perf_counter() - start
                transformed_X.append(outputs)
        # max(..., 1) avoids ZeroDivisionError on an empty X.
        print(f'Average transform time = {total_time / max(len(loader), 1)}')
        return np.array(transformed_X)

    def save_model(self, path=None):
        """Save the model's state_dict to ``path`` (default './model.pth')."""
        if path is None:
            path = './model.pth'
        torch.save(self.model.state_dict(), path)
        print("SAVE MODEL SUCCESS!!")

In [14]:
# Seed torch's RNGs so weight initialisation and the DataLoader shuffle order
# repeat on a fresh kernel (cuDNN kernels may still be nondeterministic on GPU).
SEED = 0
torch.manual_seed(SEED)
rcan = RCAN(MAX_EPOCH = 200, BATCH_SIZE = 16, lr = 0.0001)

In [None]:
# Train for rcan.MAX_EPOCH epochs; the per-epoch mean loss is kept in rcan.loss_list.
rcan.fit(X=X, Y=Y)

Epoch[1/200]  loss : 19.704859424382448
Epoch[2/200]  loss : 1.5385871464386582
Epoch[3/200]  loss : 0.8685214258730412
Epoch[4/200]  loss : 0.6110734576359391
Epoch[5/200]  loss : 0.44842894468456507
Epoch[6/200]  loss : 0.34336212556809187
Epoch[7/200]  loss : 0.27877776604145765
Epoch[8/200]  loss : 0.23810729943215847
Epoch[9/200]  loss : 0.20813189866021276
Epoch[10/200]  loss : 0.18108301609754562
Epoch[11/200]  loss : 0.1616389281116426
Epoch[12/200]  loss : 0.14675945392809808
Epoch[13/200]  loss : 0.1315316583495587
Epoch[14/200]  loss : 0.1209555109962821
Epoch[15/200]  loss : 0.1108209565281868
Epoch[16/200]  loss : 0.10345213324762881
Epoch[17/200]  loss : 0.09636175679042935
Epoch[18/200]  loss : 0.0892292937496677
Epoch[19/200]  loss : 0.08593877137172967
Epoch[20/200]  loss : 0.08151182380970567
Epoch[21/200]  loss : 0.07571697351522744
Epoch[22/200]  loss : 0.07201872032601386
Epoch[23/200]  loss : 0.0686193557921797
Epoch[24/200]  loss : 0.06499593600165099
Epoch[25/20

In [None]:
import matplotlib.pyplot as plt
# Compare HR ground truth with the model's SR output for the first few patches.
# NOTE: X is the training set, so this shows training fit, not generalisation.
transformed_X = rcan.transform(X)
SR_image = transformed_X.transpose(0, 2, 3, 1)
HR_image = Y.transpose(0, 2, 3, 1)
N_SHOW = min(10, len(HR_image))
for i in range(N_SHOW):
    fig, (ax_hr, ax_sr) = plt.subplots(1, 2, figsize=(8, 4))
    # A fixed [0, 1] display range (the range of Y) keeps HR and SR comparable
    # instead of stretching each image to its own min/max.
    ax_hr.imshow(HR_image[i].squeeze(-1), cmap='gray', vmin=0.0, vmax=1.0)
    ax_hr.set_title(f'HR #{i}')
    ax_sr.imshow(SR_image[i].squeeze(-1), cmap='gray', vmin=0.0, vmax=1.0)
    ax_sr.set_title(f'SR #{i}')
    for ax in (ax_hr, ax_sr):
        ax.axis('off')
    plt.show()