In [1]:
import math
from math import exp
import os
from os import listdir
from os.path import join
import numpy as np
import torch
from torch import nn
from PIL import Image
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import torchvision.utils as utils
from torch.utils.data import DataLoader
from tqdm import tqdm
import pytorch_ssim
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
from torch.utils.data.dataset import Dataset

import pandas as pd

import warnings
warnings.filterwarnings('ignore')

In [2]:
device = torch.device('cuda:1')

In [3]:
class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()
    def forward(self, x):
        return self.prelu(self.pixel_shuffle(self.conv(x)))

class ResidualBlock(nn.Module):
    def __init__(self, channels=64):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):        
        return self.bn2(self.conv2(self.prelu(self.bn1(self.conv1(x))))) + x


In [4]:
class Generator(nn.Module):
    def __init__(self, scale_factor):
        upsample_block_num = int(math.log(scale_factor, 2))

        super(Generator, self).__init__()
        self.b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=9, padding=4), nn.PReLU())
        self.b2 = nn.Sequential(*[ResidualBlock(64) for _ in range(16)])
        self.b3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64))
        self.b4 = nn.Sequential(*[UpsampleBLock(64, 2) for _ in range(upsample_block_num)])
        self.tail = nn.Conv2d(64, 3, kernel_size=9, padding=4)

    def forward(self, x):
        start = self.b1(x)
        end = self.b4(self.b3(self.b2(start)) + start)
        return self.tail(end)


In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        ch = [64, 64, 128, 128, 256, 256, 512, 512]
        body = []
        for i in range(1, len(ch)):
            body.extend([nn.Conv2d(ch[i-1], ch[i], kernel_size=3, stride=2, padding=1),
                         nn.BatchNorm2d(ch[i]),
                         nn.LeakyReLU(0.2)])
        self.start = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2))    
        self.body = nn.Sequential(*body)
        self.tail = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(512, 1024, kernel_size=1),
                                  nn.LeakyReLU(0.2), nn.Conv2d(1024, 1, kernel_size=1))

    def forward(self, x):
        return torch.sigmoid(self.tail(self.body(self.start(x))).view(x.size(0)))


In [7]:
class GLoss(nn.Module):
    def __init__(self):
        super(GLoss, self).__init__()
        vgg = models.vgg16(pretrained=True)
        vgg_part = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in vgg_part.parameters():
            param.requires_grad = False
        self.vgg = vgg_part
        self.mse_loss = nn.MSELoss()

    def forward(self, out_labels, out_images, target_images):
        adversarial_loss = torch.mean(1 - out_labels)
        perception_loss = self.mse_loss(self.vgg(out_images), self.vgg(target_images))
        image_loss = self.mse_loss(out_images, target_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss

In [8]:
def display_transform():
    return Compose([
        ToPILImage(),
        Resize(400),
        CenterCrop(400),
        ToTensor()
    ])

class TrainDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        self.image_filenames = [join(dataset_dir, i) for i in listdir(dataset_dir) if i.endswith('.jpg')][:90000]
        self.upsc = upscale_factor
        self.crop_size = crop_size - (crop_size % upscale_factor)
        
    def __getitem__(self, index):
        img = Image.open(self.image_filenames[index]).convert(mode='RGB')
        hr_image = RandomCrop(self.crop_size)(img)
        lr_image = Resize(self.crop_size // self.upsc, interpolation=Image.BICUBIC)(hr_image)
        try:
            return ToTensor()(lr_image), ToTensor()(hr_image)
        except:
            print(self.image_filenames[index])
            
    def __len__(self):
        return len(self.image_filenames)


class ValDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        self.crop_size = crop_size - (crop_size % upscale_factor)
        self.image_filenames = [join(dataset_dir, i) for i in listdir(dataset_dir) if i.endswith('.jpg')][90000:]
        

    def __getitem__(self, index):
        hr_image = Image.open(self.image_filenames[index]).convert(mode='RGB')
        w, h = hr_image.size
        crop_size = min(w, h) - (min(w, h) % self.upscale_factor)
        hr_image = CenterCrop(self.crop_size)(hr_image)
        lr_image = Resize(self.crop_size // self.upscale_factor, interpolation=Image.BICUBIC)(hr_image)
        hr_restore_img = Resize(self.crop_size, interpolation=Image.BICUBIC)(lr_image)
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.image_filenames)

In [9]:
def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

In [10]:
params = {'crop_size': 96, 'upscale_factor': 4, 'num_epochs':40, 'batch_size': 64}
crop_size = params['crop_size']
upscale = params['upscale_factor']
num_epochs = params['num_epochs']
batch_size = params['batch_size']

train_set = TrainDatasetFromFolder('/mnt/storage-500g/datasets/VGDataset/VG_100K_vkz', crop_size=crop_size, upscale_factor=upscale)
val_set = ValDatasetFromFolder('/mnt/storage-500g/datasets/VGDataset/VG_100K_vkz', crop_size=crop_size, upscale_factor=upscale)
train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=batch_size, shuffle=False)


G = Generator(upscale)
D = Discriminator()
G.load_state_dict(torch.load('epochs/G_best.pth'))
D.load_state_dict(torch.load('epochs/D_best.pth'))
G = G.to(device)
D = D.to(device)
GLoss = GLoss().to(device)


optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.9, 0.999))
optimizerD = optim.Adam(D.parameters(), lr=0.0002, betas=(0.9, 0.999))
schedulerD = torch.optim.lr_scheduler.StepLR(optimizerD, step_size=params['num_epochs'] // 2, gamma=0.1)
schedulerG = torch.optim.lr_scheduler.StepLR(optimizerG, step_size=params['num_epochs'] // 2, gamma=0.1)

results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

n_iter = 0

for epoch in range(1, num_epochs + 1):
    train_bar = tqdm(train_loader)
    epoch_res = {'batch_size': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

    G.train()
    D.train()
    for n, (data, real_img) in enumerate(train_bar):
        if data is None:
            continue
        epoch_res['batch_size'] += batch_size
        data = data.to(device)
        real_img = real_img.to(device)
        fake_img = G(data)

        D.zero_grad()
        real_out = D(real_img).mean()
        fake_out = D(fake_img).mean()
        d_loss = 1 - real_out + fake_out
        d_loss.backward(retain_graph=True)
        optimizerD.step()
        G.zero_grad()
        g_loss = GLoss(Variable(fake_out.detach(), requires_grad=True),
                       Variable(fake_img.detach(), requires_grad=True),
                       Variable(real_img.detach(), requires_grad=True))
        g_loss.backward()

        fake_img = G(data)
        fake_out = D(fake_img).mean()

        optimizerG.step()

        epoch_res['g_loss'] += g_loss.item()
        epoch_res['d_loss'] += d_loss.item()
        epoch_res['d_score'] += real_out.item()
        epoch_res['g_score'] += fake_out.item()
        n_iter += 1
        if n % 1000 == 0:
            mul = batch_size / epoch_res['batch_size']
            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, num_epochs, epoch_res['d_loss'] * mul, epoch_res['g_loss'] * mul,
                epoch_res['d_score'] * mul, epoch_res['g_score'] * mul))

    G.eval()

    with torch.no_grad():
        val_bar = tqdm(val_loader)
        val_res = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_size': 0}
        for n, (lr, hr_restore, hr) in enumerate(val_bar):
            val_res['batch_size'] += batch_size
            lr = lr.to(device)
            hr = hr.to(device)
            sr = G(lr)

            batch_mse = ((sr - hr) ** 2).data.mean()
            val_res['mse'] += batch_mse * batch_size
            batch_ssim = pytorch_ssim.ssim(sr, hr).item()
            val_res['ssims'] += batch_ssim * batch_size
            val_res['psnr'] = 10 * math.log10((hr.max()**2) / (val_res['mse'] / val_res['batch_size']))
            val_res['ssim'] = val_res['ssims'] / val_res['batch_size']
            if n % 150 == 0:
                val_bar.set_description(desc='PSNR: %.4f dB SSIM: %.4f' % (val_res['psnr'], val_res['ssim']))
    torch.save(G.state_dict(), 'epochs/G_epoch_%d.pth' % (epoch))
    torch.save(D.state_dict(), 'epochs/D_epoch_%d.pth' % (epoch))

    results['d_loss'].append(epoch_res['d_loss'] / epoch_res['batch_size'])
    results['g_loss'].append(epoch_res['g_loss'] / epoch_res['batch_size'])
    results['d_score'].append(epoch_res['d_score'] / epoch_res['batch_size'])
    results['g_score'].append(epoch_res['g_score'] / epoch_res['batch_size'])
    results['psnr'].append(val_res['psnr'])
    results['ssim'].append(val_res['ssim'])
    
    if epoch != 0:
        out_path = 'statistics/'
        data_frame = pd.DataFrame(
            data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                  'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
            index=range(1, epoch + 1))
        data_frame.to_csv(out_path + 'srf_' + str(upscale) + '_train_results.csv', index_label='Epoch')

[1/40] Loss_D: 1.0000 Loss_G: 0.0087 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 1407/1407 [09:38<00:00,  2.43it/s]
PSNR: 20.8755 dB SSIM: 0.6449: 100%|██████████| 283/283 [00:20<00:00, 14.14it/s]
[2/40] Loss_D: 1.0000 Loss_G: 0.0087 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 1407/1407 [09:40<00:00,  2.42it/s]
PSNR: 20.9046 dB SSIM: 0.6468: 100%|██████████| 283/283 [00:20<00:00, 13.98it/s]
[3/40] Loss_D: 1.0000 Loss_G: 0.0087 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 1407/1407 [09:40<00:00,  2.43it/s]
PSNR: 20.8827 dB SSIM: 0.6449: 100%|██████████| 283/283 [00:19<00:00, 14.20it/s]
[4/40] Loss_D: 1.0000 Loss_G: 0.0087 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 1407/1407 [09:40<00:00,  2.43it/s]
PSNR: 20.8643 dB SSIM: 0.6447: 100%|██████████| 283/283 [00:19<00:00, 14.25it/s]
[5/40] Loss_D: 1.0000 Loss_G: 0.0086 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 1407/1407 [09:40<00:00,  2.43it/s]
PSNR: 20.9151 dB SSIM: 0.6475: 100%|██████████| 283/283 [00:19<00:00, 14.27it/s]
