In [0]:
# Mount Google Drive; the dataset archive and the checkpoint folder used below
# both live under "My Drive".
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [0]:
!unzip -q drive/My\ Drive/dataset/image/anime.zip

In [0]:
import os, random
import re
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision as tv
import numpy as np
from PIL import Image, ImageFilter

In [0]:
class FSRCNN(nn.Module):
    """FSRCNN-style 2x super-resolution network.

    Stages:
      * first_part  - feature extraction: four 3x3 conv + PReLU layers
                      (in_channels -> 32 -> 32 -> 64 -> 64).
      * mid_part    - shrink (1x1, 64 -> 16), four 3x3 convs, expand (1x1, 16 -> 64).
      * last_part   - 5x5 transposed conv, stride 2: doubles H and W, 64 -> 3 channels.
      * smooth_part - 5x5 conv, 3 -> out_channels.

    Args:
        in_channels: channels of the low-resolution input.
        out_channels: channels of the produced high-resolution output.

    Shapes: (N, in_channels, H, W) -> (N, out_channels, 2H, 2W).
    FSRCNN(1, 1) builds exactly the same layers and state_dict keys as before,
    so existing checkpoints keep loading.
    """

    def __init__(self, in_channels, out_channels):
        super(FSRCNN, self).__init__()
        self.first_part = self.__first(in_channels)
        self.mid_part = self.__mid()
        self.last_part = self.__last()
        self.smooth_part = self.__smooth(out_channels)

    @staticmethod
    def _init_convs(seq):
        """Apply Kaiming-normal init to the weight of every plain Conv2d in `seq`; return `seq`."""
        for m in seq.modules():
            if type(m) is nn.Conv2d:
                nn.init.kaiming_normal_(m.weight)
        return seq

    def __first(self, in_channels):
        """Feature extraction: conv/PReLU pairs named first_conv{i} / first_prelu{i}."""
        first = nn.Sequential()
        widths = [(in_channels, 32), (32, 32), (32, 64), (64, 64)]
        for i, (c_in, c_out) in enumerate(widths, start=1):
            first.add_module(f'first_conv{i}', nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            first.add_module(f'first_prelu{i}', nn.PReLU())
        return self._init_convs(first)

    def __mid(self):
        """Shrink to 16 channels, apply four 3x3 convs, expand back to 64 channels."""
        mid = nn.Sequential()
        mid.add_module('mid_conv1', nn.Conv2d(64, 16, kernel_size=1))
        mid.add_module('mid_prelu1', nn.PReLU())
        # NOTE(review): the four 3x3 convs below have no activation between them,
        # so they compose into a single linear map. Inserting PReLUs would add
        # state_dict keys and break loading existing checkpoints, so the layout
        # is intentionally left unchanged.
        for i in range(4):
            mid.add_module(f'mid_conv{i+2}', nn.Conv2d(16, 16, kernel_size=3, padding=1))
        mid.add_module('mid_prelu2', nn.PReLU())
        mid.add_module('mid_conv6', nn.Conv2d(16, 64, kernel_size=1))
        mid.add_module('mid_prelu3', nn.PReLU())
        return self._init_convs(mid)

    def __last(self):
        """2x upsampling: (H-1)*2 - 2*2 + 5 + 1 = 2H output rows (likewise for width)."""
        last = nn.ConvTranspose2d(64, 3, kernel_size=5, padding=2, stride=2, output_padding=1)
        nn.init.kaiming_normal_(last.weight)
        return last

    def __smooth(self, out_channels):
        """Size-preserving 5x5 conv from the 3 upsampled channels to `out_channels`."""
        smooth = nn.Conv2d(3, out_channels, kernel_size=5, padding=2)
        nn.init.kaiming_normal_(smooth.weight)
        return smooth

    def forward(self, x):
        """Map (N, in_channels, H, W) to (N, out_channels, 2H, 2W)."""
        x = self.first_part(x)
        x = self.mid_part(x)
        x = self.last_part(x)
        x = self.smooth_part(x)
        return x

In [0]:
class SuperResolutionDataset(Dataset):
    """(low-res, high-res) luma training pairs built from ``img_<idx>.png`` files.

    For item ``idx`` the file ``<path>/img_<idx>.png`` is read and only the Y
    (luma) plane of its YCrCb conversion is kept. The high-res target is a
    random 224x224 crop with random horizontal/vertical flips. The low-res
    input is that crop resized to 112 px, optionally Gaussian-blurred.

    Args:
        path: directory containing the ``img_<idx>.png`` files. Indices are
            assumed to run contiguously from 0 — TODO confirm for the dataset.
        blur: if True, blur the low-res input (radius 1) with probability 0.5.
        device: device the returned tensors are moved to (default ``'cuda'``,
            as before); ``None`` leaves them on the CPU. Moving tensors to CUDA
            here requires a DataLoader with ``num_workers=0``.

    Items are dicts ``{"low": (1, 112, 112) tensor, "high": (1, 224, 224) tensor}``.
    """

    def __init__(self, path, blur=False, device='cuda'):
        self.path = path
        self.n_samples = self._count_samples(path)
        self.preprocess = tv.transforms.Compose([
            tv.transforms.ToPILImage(),
            tv.transforms.RandomCrop(224),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.RandomVerticalFlip()
        ])
        self.downscale = tv.transforms.Resize(112)
        self.totensor = tv.transforms.ToTensor()
        self.blur = blur
        self.device = device

    @staticmethod
    def _count_samples(path):
        """Number of regular ``img_*.png`` files directly inside ``path``.

        Unrelated files (notes, archives, hidden files) and sub-directories are
        ignored, so they can no longer make ``__getitem__`` ask for a missing index.
        """
        return sum(
            1 for name in os.listdir(path)
            if name.startswith('img_') and name.endswith('.png')
            and os.path.isfile(os.path.join(path, name))
        )

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        file_path = os.path.join(self.path, f'img_{idx}.png')
        bgr = cv2.imread(file_path)
        if bgr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'could not read image: {file_path}')
        luma = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2YCrCb))[0]
        high = self.preprocess(luma)
        low = self.downscale(high)
        if self.blur and random.random() > .5:
            low = low.filter(ImageFilter.GaussianBlur(radius=1))
        high = self.totensor(high)
        low = self.totensor(low)
        if self.device is not None:
            high = high.to(self.device)
            low = low.to(self.device)
        return {"low": low, "high": high}

In [0]:
def find_ckpt(path):
    """Return the file name of the newest ``sfsrcnn-<epoch>.pt`` checkpoint in ``path``.

    Only names matching ``sfsrcnn-<int>.pt`` exactly are considered. Other
    ``.pt`` files (e.g. ``model.pt``) are ignored instead of crashing the parse.

    Args:
        path: checkpoint directory, as written by ``save_ckpt``.
    Returns:
        The bare file name (not joined with ``path``) with the largest epoch.
    Raises:
        FileNotFoundError: if ``path`` does not exist or holds no matching checkpoint.
    """
    pattern = re.compile(r'sfsrcnn-(\d+)\.pt')
    epochs = [
        int(match.group(1))
        for match in map(pattern.fullmatch, os.listdir(path))
        if match is not None
    ]
    if not epochs:
        raise FileNotFoundError(f'no sfsrcnn-<epoch>.pt checkpoint found in {path!r}')
    return f'sfsrcnn-{max(epochs)}.pt'

def save_ckpt(path, model, optimizer, scheduler, epoch, step, last_step):
    """Serialize the full training state to ``<path>/sfsrcnn-<epoch>.pt``.

    Args:
        path: target directory (not created here).
        model: module whose ``state_dict`` is stored.
        optimizer: optimizer whose ``state_dict`` is stored.
        scheduler: LR scheduler whose ``state_dict`` is stored.
        epoch: epoch number; also used in the file name.
        step: global batch counter to store.
        last_step: additional step counter to store, as passed by the caller.
    """
    state = dict(
        epoch=epoch,
        step=step,
        last_step=last_step,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
    )
    target = os.path.join(path, f'sfsrcnn-{epoch}.pt')
    torch.save(state, target)

def load_ckpt(path, model, optimizer=None, scheduler=None, epoch=0, map_location=None):
    """Restore training state written by ``save_ckpt``.

    Args:
        path: checkpoint directory.
        model: module whose weights are overwritten in place.
        optimizer: if given, its state is restored from the checkpoint as well.
        scheduler: if given, its state is restored from the checkpoint as well.
        epoch: epoch of the checkpoint to load (``sfsrcnn-<epoch>.pt``); any
            value <= 0 selects the newest checkpoint via ``find_ckpt``.
        map_location: forwarded to ``torch.load``. For example, ``'cpu'``
            restores a checkpoint written on a GPU runtime onto a CPU-only one.
            The default ``None`` keeps the previous behaviour.
    Returns:
        ``(epoch, step, last_step)`` as stored in the checkpoint.
    """
    ckpt_file = f'sfsrcnn-{epoch}.pt' if epoch > 0 else find_ckpt(path)
    state = torch.load(os.path.join(path, ckpt_file), map_location=map_location)
    model.load_state_dict(state['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(state['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(state['scheduler_state_dict'])
    return state['epoch'], state['step'], state['last_step']

In [0]:
class SuperImage:
    def __run(self, model, image):
        model.eval()
        image = image/255.
        model.cuda()
        image = image.cuda()
        outp = model(image)
        outp = torch.clamp(outp, 0., 1.)
        outp = outp.detach().cpu().numpy().squeeze()
        return (outp*255.).astype(np.uint8)
    
    def scale2x(self, model, image):
        if isinstance(image, str):
            image = cv2.imread(image)
        y, cr, cb = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb))
        h, w = y.shape
        cr = cv2.resize(cr, (2*w,2*h), cv2.INTER_LANCZOS4)
        cb = cv2.resize(cb, (2*w,2*h), cv2.INTER_LANCZOS4)
        y = torch.from_numpy(y[np.newaxis,np.newaxis,:,:]).type(torch.FloatTensor)
        outp = self.__run(model, y)
        image2x = np.stack((outp, cr, cb), axis=2)
        image2x = cv2.cvtColor(image2x, cv2.COLOR_YCrCb2BGR)
        return image2x

In [0]:
def training_loop(model, dataloader, path, lr1=1e-3, lr2=1e-4, epoch=100, resume=-1, load_optim=True,
                  demo_in='demo_half.png', demo_out='demo_sakura.png'):
    """Train `model` with per-pixel MSE, saving a checkpoint every 10 epochs.

    Args:
        model: network exposing ``first_part``, ``mid_part``, ``last_part`` and
            ``smooth_part`` submodules (e.g. ``FSRCNN``); it is moved to CUDA.
        dataloader: yields dicts with ``"low"`` (input) and ``"high"`` (target)
            tensors. This loop does not move them, so they are expected to
            already be on the GPU.
        path: checkpoint directory for ``save_ckpt`` / ``load_ckpt``.
        lr1: Adam learning rate for ``first_part`` and ``mid_part``.
        lr2: Adam learning rate for ``last_part`` and ``smooth_part``.
        epoch: epoch count to train up to (a resumed run continues from the
            checkpoint's epoch).
        resume: negative trains from scratch; ``0`` resumes from the newest
            checkpoint; ``k > 0`` resumes from ``sfsrcnn-k.pt``.
        load_optim: when resuming, also restore optimizer and scheduler state.
        demo_in: image upscaled after every epoch as a visual check; ``None``
            disables the demo.
        demo_out: file the upscaled demo image is written to.
    """
    model.cuda()
    mse = nn.MSELoss().cuda()
    optimizer = optim.Adam([
        {'params': model.first_part.parameters(), 'lr': lr1},
        {'params': model.mid_part.parameters(), 'lr': lr1},
        {'params': model.last_part.parameters(), 'lr': lr2},
        {'params': model.smooth_part.parameters(), 'lr': lr2}
    ])
    # Stepped once per epoch: every 75 epochs all learning rates are scaled by 0.1.
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 75, gamma=.1)
    start_epoch, step = 0, 0
    if resume >= 0:
        # The stored last_step is not needed: per-epoch batch counts are tracked locally.
        if load_optim:
            start_epoch, step, _ = load_ckpt(path, model, optimizer, lr_scheduler, epoch=resume)
        else:
            start_epoch, step, _ = load_ckpt(path, model, epoch=resume)
        print(f'restart after epoch {start_epoch}')
    solver = SuperImage()
    print(f'model is training: {model.training}')
    for ep in range(start_epoch, epoch):
        epoch_start_step = step
        epoch_loss = 0.
        # Snapshot before the scheduler steps, so the log shows this epoch's rates.
        lrs = [group['lr'] for group in optimizer.param_groups]
        # The per-epoch demo below runs the model through SuperImage, which may
        # switch it to eval mode; (re-)enable training mode once per epoch.
        model.train()
        for sample in dataloader:
            inp = sample["low"]
            gt = sample["high"]
            outp = model(inp)
            # NOTE(review): clamping before the loss gives zero gradient to pixels
            # predicted outside [0, 1]; kept to preserve the training objective.
            outp = torch.clamp(outp, 0., 1.)
            loss = mse(outp, gt)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() keeps the running sum a Python float; summing the tensor
            # would keep every batch's autograd graph alive for the whole epoch.
            epoch_loss += loss.item()
            step += 1
        lr_scheduler.step()
        # Count only this epoch's batches (also correct for the first epoch after
        # a resume); max() guards against an empty dataloader.
        n_batches = max(step - epoch_start_step, 1)
        print(f"Ep.{ep+1} - loss: {epoch_loss/n_batches:.6f} - lr: {lrs}")
        if (ep+1) % 10 == 0:
            # last_step keeps its original meaning: the global step at the start of this epoch.
            save_ckpt(path, model, optimizer, lr_scheduler, ep+1, step, epoch_start_step)
            print(f'Ep.{ep+1} - model saved')
        if demo_in is not None:
            cv2.imwrite(demo_out, solver.scale2x(model, demo_in))
        torch.cuda.empty_cache()

In [0]:
dataset = SuperResolutionDataset('data')
# Shuffle so every epoch sees the images in a fresh order (shuffle=False fed the
# same sequence every epoch). num_workers stays at its default of 0 because the
# dataset moves each sample to CUDA inside __getitem__.
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [0]:
# Luma-only model: one channel in (Y plane), one channel out.
fsrcnn = FSRCNN(in_channels=1, out_channels=1)

In [0]:
# Train from scratch with default hyper-parameters; checkpoints go to the mounted Drive.
training_loop(fsrcnn, dataloader, path='drive/My Drive/checkpoints/FSRCNN')