In [None]:
# Mount the user's Google Drive at /content/drive (requires the Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Explicit line magic: bare `cd` relies on automagic, which is only applied to
# single-line cells and is shadowed by any variable named `cd`.
%cd /content/drive/MyDrive/Colab Notebooks/ADL/2/SRGAN_pytorch/

In [None]:
# Explicit %pip installs into the running kernel's environment and does not
# depend on automagic being enabled. Versions stay pinned.
%pip install torch==1.4.0 torchvision==0.5.0

# Import Packages

In [None]:
import math
import os
from math import log10
from os import listdir
from os.path import join

import pandas as pd
import pytorch_ssim
import torch
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from PIL import Image
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision.models.vgg import vgg16
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
from tqdm import tqdm

# Definitions

In [16]:
def is_image_file(filename):
    """Return True when ``filename`` ends with a recognised image extension.

    The check is a plain suffix match against the all-lower-case and
    all-upper-case spellings of ``.png``, ``.jpg`` and ``.jpeg``; mixed-case
    suffixes such as ``.Png`` are not matched.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG')
    # str.endswith accepts a tuple and returns True if any suffix matches.
    return filename.endswith(image_suffixes)


def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round ``crop_size`` down to the nearest multiple of ``upscale_factor``."""
    whole_steps, _ = divmod(crop_size, upscale_factor)
    return whole_steps * upscale_factor


def train_hr_transform(crop_size):
    """Build the high-resolution training pipeline.

    The returned transform takes a random ``crop_size`` crop of its input and
    then converts the crop with ``ToTensor``.
    """
    pipeline = [RandomCrop(crop_size), ToTensor()]
    return Compose(pipeline)


def train_lr_transform(crop_size, upscale_factor):
    """Build the low-resolution training pipeline.

    The returned transform converts its input to a PIL image, resizes it with
    bicubic interpolation to ``crop_size // upscale_factor`` and converts the
    result back with ``ToTensor``.
    """
    low_res_size = crop_size // upscale_factor
    steps = [
        ToPILImage(),
        Resize(low_res_size, interpolation=Image.BICUBIC),
        ToTensor(),
    ]
    return Compose(steps)



class TrainDatasetFromFolder(Dataset):
    """Training dataset that yields ``(lr_image, hr_image)`` pairs.

    Every image file found directly inside ``dataset_dir`` is one sample. The
    high-resolution target is a random crop of the image; the low-resolution
    input is that same crop downscaled by ``upscale_factor``.

    Args:
        dataset_dir: directory scanned (non-recursively) for image files.
        crop_size: requested HR crop size; it is rounded down to a multiple of
            ``upscale_factor`` before use.
        upscale_factor: ratio between the HR crop size and the LR size.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        """Return the ``(lr_image, hr_image)`` pair for sample ``index``."""
        # Force three channels: grayscale, palette or RGBA files would otherwise
        # produce tensors that do not match the generator's 3-channel input.
        source = Image.open(self.image_filenames[index]).convert('RGB')
        hr_image = self.hr_transform(source)
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        """Number of image files found in the dataset directory."""
        return len(self.image_filenames)


class ValDatasetFromFolder(Dataset):
    """Validation dataset yielding ``(lr, hr_restored, hr)`` tensors per image.

    For each image file found directly inside ``dataset_dir`` the largest
    centred square whose side is a multiple of ``upscale_factor`` is used as
    the HR reference. The LR input is its bicubic downscale and
    ``hr_restored`` is the bicubic upscale of that LR image back to HR size.

    Args:
        dataset_dir: directory scanned (non-recursively) for image files.
        upscale_factor: ratio between the HR and LR side lengths.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]

    def __getitem__(self, index):
        """Return ``(lr, hr_restored, hr)`` tensors for sample ``index``."""
        # Force three channels so non-RGB files (grayscale, palette, RGBA)
        # still match the generator's 3-channel input.
        hr_image = Image.open(self.image_filenames[index]).convert('RGB')
        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        to_tensor = ToTensor()
        return to_tensor(lr_image), to_tensor(hr_restore_img), to_tensor(hr_image)

    def __len__(self):
        """Number of image files found in the dataset directory."""
        return len(self.image_filenames)


In [17]:
class Generator(nn.Module):
    """Super-resolution generator.

    Layout: a 9x9 convolution with PReLU (``block1``), sixteen
    ``ResidualBlock(64)`` modules (``block2`` .. ``block17``), a 3x3
    convolution with batch norm (``block18``), a skip connection from
    ``block1``, and ``block19`` which stacks one 2x ``UpsampleBLock`` per
    factor of two in ``scale_factor`` followed by a 9x9 convolution back to
    3 channels.

    Args:
        scale_factor: total upscaling factor. The number of 2x upsampling
            stages is ``int(log2(scale_factor))``, so values that are not
            powers of two are rounded down to the previous power of two.
    """

    # Attribute suffixes of the residual trunk. The names block2..block17 are
    # kept so that existing checkpoints (state_dict keys) still load.
    _RESIDUAL_BLOCK_IDS = range(2, 18)

    def __init__(self, scale_factor):
        # math.log2 avoids the floating-point division done by math.log(x, 2),
        # whose rounding error could be truncated by int().
        upsample_block_num = int(math.log2(scale_factor))

        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        # Registered in ascending order, exactly as the hand-written version did.
        for block_id in self._RESIDUAL_BLOCK_IDS:
            setattr(self, 'block%d' % block_id, ResidualBlock(64))
        self.block18 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        block19 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block19.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block19 = nn.Sequential(*block19)

    def forward(self, x):
        """Upscale the batch ``x`` and return values in the range [0, 1]."""
        block1 = self.block1(x)

        features = block1
        for block_id in self._RESIDUAL_BLOCK_IDS:
            features = getattr(self, 'block%d' % block_id)(features)
        block18 = self.block18(features)

        # Long skip connection around the whole residual trunk.
        block19 = self.block19(block1 + block18)

        # tanh gives [-1, 1]; shift and scale to [0, 1].
        return (torch.tanh(block19) + 1) / 2


class Discriminator(nn.Module):
    """Convolutional discriminator that maps an image batch to one score per image.

    ``self.net`` is a single ``nn.Sequential``: a 3->64 convolution with
    LeakyReLU, seven conv/batch-norm/LeakyReLU stages (every second one with
    stride 2), global average pooling, a 1x1 convolution to 1024 channels and
    two linear layers ending in a single output. ``forward`` applies a
    sigmoid to that output.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # The stem is the only convolution without batch normalisation.
        stages = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2)]

        # (in_channels, out_channels, stride) for each conv/BN/LeakyReLU stage.
        conv_plan = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]
        for in_ch, out_ch, stride in conv_plan:
            stages.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.LeakyReLU(0.2))

        # Classification head.
        stages.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1),
        ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return a 1-D tensor with one sigmoid score per item in ``x``."""
        logits = self.net(x)
        return torch.sigmoid(logits.view(x.size(0)))


class ResidualBlock(nn.Module):
    """Residual unit: conv -> BN -> PReLU -> conv -> BN, added back onto the input.

    Args:
        channels: number of input channels; the output has the same number.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.prelu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return x + branch


class UpsampleBLock(nn.Module):
    """Sub-pixel upsampling stage: conv -> PixelShuffle -> PReLU.

    Args:
        in_channels: number of input channels; the output has the same number.
        up_scale: spatial upscaling factor passed to ``nn.PixelShuffle``.
    """

    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        # PixelShuffle folds up_scale**2 channel groups into spatial positions,
        # so the convolution has to produce that many times more channels.
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Return ``x`` upscaled by ``up_scale`` in both spatial dimensions."""
        return self.prelu(self.pixel_shuffle(self.conv(x)))


In [None]:
class GeneratorLoss(nn.Module):
    """Combined generator loss: pixel MSE + adversarial term + VGG feature MSE.

    ``forward`` returns
    ``image_loss + 0.001 * adversarial_loss + 0.06 * perception_loss`` where
    the perception term compares activations of the first 31 modules of a
    pretrained ``vgg16().features``, whose parameters are frozen.
    """

    # Lower bound applied before taking the log so that a saturated
    # discriminator output of exactly 0 cannot produce an infinite loss.
    _LOG_EPS = 1e-8

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        # The former `self.tv_loss = TVLoss()` was dropped: TVLoss is not
        # defined or imported in this notebook (NameError on construction)
        # and the attribute was never used by forward().

    def forward(self, out_labels, out_images, target_images):
        """Compute the generator loss.

        Args:
            out_labels: discriminator outputs for the generated images; the
                log is taken directly, so values are expected in (0, 1].
            out_images: images produced by the generator.
            target_images: ground-truth high-resolution images.

        Returns:
            Scalar tensor with the weighted sum of the three loss terms.
        """
        # Adversarial Loss: -log D(G(z)), averaged over the batch.
        adversarial_loss = torch.mean(-torch.log(out_labels.clamp(min=self._LOG_EPS)))
        # Perception Loss: MSE between VGG feature maps.
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        # Image Loss: pixel-wise MSE.
        image_loss = self.mse_loss(out_images, target_images)
        return image_loss + 0.001 * adversarial_loss + 0.06 * perception_loss

# SRNet

In [None]:
# SRNet stage: train the generator on its own with a pixel-wise MSE loss.
CROP_SIZE = 96
UPSCALE_FACTOR = 4
NUM_EPOCHS = 1000
load_model=False
gen_path='Pretrained Generator model Path'

# One device object instead of repeated torch.cuda.is_available() checks.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_set = TrainDatasetFromFolder('data/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)

netG = Generator(UPSCALE_FACTOR).to(device)
# This loop calls the criterion as criterion(fake, real). GeneratorLoss.forward
# needs a third argument (discriminator outputs) and there is no discriminator
# at this stage, so the two-argument pixel-wise MSE is used here.
generator_criterion = nn.MSELoss().to(device)

if load_model:
    # map_location lets a checkpoint saved on GPU be restored on a CPU runtime.
    netG.load_state_dict(torch.load(gen_path, map_location=device))

optimizerG = optim.Adam(netG.parameters(),lr=0.0001)

results = {'g_loss': []}

# torch.save does not create missing directories.
os.makedirs('epochs', exist_ok=True)

for epoch in range(1, NUM_EPOCHS + 1):
    train_bar = tqdm(train_loader)
    running_results = {'batch_sizes': 0, 'g_loss': 0}

    netG.train()
    for data, target in train_bar:
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        real_img = target.to(device)
        z = data.to(device)
        fake_img = netG(z)

        # Update G network
        optimizerG.zero_grad()
        g_loss = generator_criterion(fake_img, real_img)
        g_loss.backward()
        optimizerG.step()

        running_results['g_loss'] += g_loss.item() * batch_size

        # Running mean in the progress bar instead of printing a tensor per batch.
        train_bar.set_description(desc='[%d/%d] Loss_G: %.4f' % (
            epoch, NUM_EPOCHS, running_results['g_loss'] / running_results['batch_sizes']))

    results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])

    if epoch % 20 == 0:
        torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d_SRNet.pth' % (UPSCALE_FACTOR, epoch))


# SRGAN

In [None]:
# SRGAN stage: adversarial training of the generator against the discriminator.
CROP_SIZE = 96
UPSCALE_FACTOR = 4
NUM_EPOCHS = 1000
load_model=True
gen_path='Pretrained Generator model Path'
dec_path='Pretrained Discriminator model Path'
D_UPDATE_EVERY = 5   # the discriminator is only stepped on every 5th epoch
LOG_EPS = 1e-8       # keeps log() finite when the discriminator saturates

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_set = TrainDatasetFromFolder('data/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)

netG = Generator(UPSCALE_FACTOR).to(device)
netD = Discriminator().to(device)

generator_criterion = GeneratorLoss().to(device)

if load_model:
    print('Loading previous model')
    # map_location lets a checkpoint saved on GPU be restored on a CPU runtime.
    netG.load_state_dict(torch.load(gen_path, map_location=device))
    # The discriminator starts from scratch. To resume one as well, load
    # dec_path into netD here.

optimizerG = optim.Adam(netG.parameters(),lr=0.0001)
optimizerD = optim.Adam(netD.parameters(),lr=0.00001)

results = {'d_loss': [], 'g_loss': [], 'r_score': [], 'f_score': []}

# torch.save / DataFrame.to_csv do not create missing directories.
out_path = 'statistics/'
os.makedirs('epochs', exist_ok=True)
os.makedirs(out_path, exist_ok=True)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f'Epoch : {epoch}')
    train_bar = tqdm(train_loader)
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'r_score':0,'f_score':0}
    update_d = epoch % D_UPDATE_EVERY == 0

    netG.train()
    netD.train()
    for data, target in train_bar:
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        real_img = target.to(device)
        z = data.to(device)
        fake_img = netG(z)

        # Update D network: maximise log D(x) + log(1 - D(G(z))).
        # The fake batch is detached so this pass never reaches the generator,
        # and no graph is built on epochs where D is not stepped.
        with torch.set_grad_enabled(update_d):
            real_dout = netD(real_img)
            fake_dout_detached = netD(fake_img.detach())
            d_loss = -(torch.log(real_dout + LOG_EPS).mean()
                       + torch.log(1 - fake_dout_detached + LOG_EPS).mean())
        if update_d:
            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        # Update G network: score the fakes again, after the D step, so the
        # backward pass does not run through weights that were changed in place.
        fake_dout = netD(fake_img)
        g_loss = generator_criterion(fake_dout, fake_img, real_img)
        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # r_score / f_score are the mean discriminator outputs, which is what
        # the "D(x)" and "D(G(z))" labels below describe.
        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['r_score'] += real_dout.mean().item() * batch_size
        running_results['f_score'] += fake_dout_detached.mean().item() * batch_size

        train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
            epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
            running_results['g_loss'] / running_results['batch_sizes'],
            running_results['r_score'] / running_results['batch_sizes'],
            running_results['f_score'] / running_results['batch_sizes']))

    results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
    results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
    results['r_score'].append(running_results['r_score'] / running_results['batch_sizes'])
    results['f_score'].append(running_results['f_score'] / running_results['batch_sizes'])

    if epoch % 10 == 0:
        torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d_SRGAN.pth' % (UPSCALE_FACTOR, epoch))
        torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d_SRGAN.pth' % (UPSCALE_FACTOR, epoch))
        data_frame = pd.DataFrame(
            data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_real': results['r_score'],
                  'Score_fake': results['f_score'] },index=range(1, (epoch + 1)))
        data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results_SRGAN.csv', index_label='Epoch')


# Convert LR to SR

In [None]:
# Super-resolve a single low-resolution image with a trained generator.
UPSCALE_FACTOR=4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

netG = Generator(UPSCALE_FACTOR).to(device)
# map_location lets a checkpoint saved on GPU be restored on a CPU runtime.
netG.load_state_dict(torch.load('Pretrained Generator model Path', map_location=device))
# Inference mode: BatchNorm must use its running statistics, not those of
# this single-image batch.
netG.eval()

IMAGE_NAME='0863x4.png'
# Force three channels to match the generator's 3-channel input.
image = Image.open(IMAGE_NAME).convert('RGB')
image = ToTensor()(image).unsqueeze(0).to(device)

# `Variable(..., volatile=True)` no longer disables autograd; no_grad() does.
with torch.no_grad():
    out = netG(image)
out_img = ToPILImage()(out[0].cpu())
out_img.save('out_srf_' + str(UPSCALE_FACTOR) + '_SRGAN_' + IMAGE_NAME)

# Compute PSNR and SSIM for the validation set

In [None]:
# Evaluate a trained generator on the validation set (running PSNR and SSIM).
UPSCALE_FACTOR=4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

netG = Generator(UPSCALE_FACTOR).to(device)
# The closing quote was missing here, which made the whole cell a SyntaxError.
# map_location lets a checkpoint saved on GPU be restored on a CPU runtime.
netG.load_state_dict(torch.load('Pretrained Generator model Path', map_location=device))

netG.eval()
out_path = 'updatd_training_results/SRF_' + str(UPSCALE_FACTOR) + '/'

val_set = ValDatasetFromFolder('data/DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)
val_loader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

with torch.no_grad():
    val_bar = tqdm(val_loader)
    valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
    # The bicubic-restored image returned by the dataset is not needed here.
    for val_lr, _, val_hr in val_bar:
        batch_size = val_lr.size(0)
        valing_results['batch_sizes'] += batch_size
        lr = val_lr.to(device)
        hr = val_hr.to(device)
        sr = netG(lr)

        # .item() keeps the accumulators as plain Python floats, so log10
        # below never receives a (possibly CUDA) tensor.
        batch_mse = ((sr - hr) ** 2).mean().item()
        valing_results['mse'] += batch_mse * batch_size
        batch_ssim = pytorch_ssim.ssim(sr, hr).item()
        valing_results['ssims'] += batch_ssim * batch_size
        # Running PSNR: current batch peak value against the mean MSE so far.
        valing_results['psnr'] = 10 * log10((hr.max().item() ** 2) / (valing_results['mse'] / valing_results['batch_sizes']))
        valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
        val_bar.set_description(desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                valing_results['psnr'], valing_results['ssim']))
