## Imports

In [None]:
import os
import matplotlib.pyplot as plt
import tensorflow as tf
from model import resolve_single
from data import DIV2K
from utils import load_image
from PIL import Image
import numpy as np

%matplotlib inline

### Installing and Importing LPIPS

In [None]:
!pip install lpips
import lpips
import torch

### Importing the Simple XESRGAN (SXESRGAN) Model

In [None]:
from model.sxesrgan import generator as sxesrgen
from model.sxesrgan import discriminator as sxesrdisc
from sxesrgantrain import SXESrganTrainer 
from sxesrgantrain import SXESrganGeneratorTrainer 

## Loading Datasets

In [None]:
# Both DIV2K subsets share the same images/caches directories on the mounted
# Google Drive (Colab), so the root is defined once instead of repeated inline.
DIV2K_ROOT = '/content/gdrive/My Drive/super-resolution/.div2k'
DIV2K_IMAGES_DIR = f'{DIV2K_ROOT}/images'
DIV2K_CACHES_DIR = f'{DIV2K_ROOT}/caches'

# x4 bicubic-downgraded pairs; hr_size=128 matches the discriminator's hr_size below.
div2k_train = DIV2K(scale=4, hr_size=128, subset='train', downgrade='bicubic',
                    images_dir=DIV2K_IMAGES_DIR, caches_dir=DIV2K_CACHES_DIR)
div2k_valid = DIV2K(scale=4, hr_size=128, subset='valid', downgrade='bicubic',
                    images_dir=DIV2K_IMAGES_DIR, caches_dir=DIV2K_CACHES_DIR)

In [None]:
# Batches of 8 for both subsets. Training leaves `repeat_count` at the
# DIV2K.dataset default; validation passes repeat_count=1.
# NOTE(review): validation also uses random_transform=True, so the evaluated
# samples can differ between runs — confirm this is intended (it may be needed
# to bring images to a fixed size so they can be batched).
train_ds = div2k_train.dataset(random_transform=True, batch_size=8)
valid_ds = div2k_valid.dataset(repeat_count=1, random_transform=True, batch_size=8)

## Setting Weights Directory

In [None]:
# Every weight file of this (LPIPS) experiment is stored in one directory.
weights_dir = 'weights/sxesrgan_lpips'
os.makedirs(weights_dir, exist_ok=True)


def weights_file(filename):
    """Return the path of `filename` inside the current `weights_dir`.

    `weights_dir` is looked up at call time, so rebinding it later changes
    where this helper points.
    """
    return os.path.join(weights_dir, filename)

## Training SXESRGAN

### Generator Pre-Training

In [None]:
# Pre-train the generator on its own before the adversarial (GAN) stage below.
pre_trainer = SXESrganGeneratorTrainer(
    model=sxesrgen(),
    checkpoint_dir='./ckpt_sxesr/pre_generator',
)

pre_trainer.train(
    train_ds,
    valid_ds,
    evaluate_every=50,
    steps=2000,
)

In [None]:
# Save the pre-trained generator's weights (reloaded in the Demo section).
pre_trainer.model.save_weights(filepath=weights_file('pre_generator.h5'))

### GAN Training

In [None]:
# Build the GAN pair. The generator is initialised from the pre-trained weights
# saved in the previous section. `gan_generator.h5` is only written after GAN
# training below, so loading it here fails on a fresh kernel.
gan_generator = sxesrgen()
# NOTE(review): hr_size=128 presumably must match the DIV2K hr_size used for train_ds.
gan_discriminator = sxesrdisc(hr_size=128)
gan_generator.load_weights(weights_file('pre_generator.h5'))

In [None]:
# Relativistic-average discriminator ('ragan') with the LPIPS loss variant.
gan_trainer = SXESrganTrainer(
    generator=gan_generator,
    discriminator=gan_discriminator,
    checkpoint_dir='./ckpt/sxesrgan_lpips',
    disc_type='ragan',
    loss_type='lpips',
)
gan_trainer.train(
    train_ds.take(200).repeat(),  # 200 batches from train_ds, cycled indefinitely
    steps=4050,
    evaluate_every=50,
)

In [None]:
# Persist both trained GAN networks; the generator weights are reloaded below.
gan_trainer.generator.save_weights(filepath=weights_file('gan_generator.h5'))
gan_trainer.discriminator.save_weights(filepath=weights_file('gan_discriminator.h5'))

### Network Interpolation

In [None]:
# Two fresh generators: one receives the pre-trained weights, the other the GAN weights.
sxesrpre_generator, sxesrgan_generator = sxesrgen(), sxesrgen()


In [None]:
# Load the pre-trained generator via an explicit path instead of re-pointing the
# global `weights_dir` / `weights_file` helpers, which the interpolation and demo
# cells below resolve their paths through.
# NOTE(review): this reads the pre-trained generator from the non-LPIPS run
# ('weights/sxesrgan'), not the one saved above in 'weights/sxesrgan_lpips'.
# Interpolation is only meaningful if the GAN generator was fine-tuned from
# these exact weights — confirm which directory is intended.
sxesrpre_generator.load_weights(os.path.join('weights/sxesrgan', 'pre_generator.h5'))

In [None]:
# (Re)point the weight helpers at this experiment's directory; the interpolation
# and demo cells below resolve their paths through them.
weights_dir = 'weights/sxesrgan_lpips'


def weights_file(filename):
    """Return the path of `filename` inside the current `weights_dir`."""
    return os.path.join(weights_dir, filename)


sxesrgan_generator.load_weights(weights_file('gan_generator.h5'))


In [None]:
# Network interpolation: blend the pre-trained and GAN generators parameter-wise,
#   w = (1 - ALPHA) * w_pre + ALPHA * w_gan
# ALPHA = 0 reproduces the pre-trained generator, ALPHA = 1 the GAN generator.
ALPHA = 0.5
sxesrgan = sxesrgen()

# get_weights() returns *all* weights (trainable and non-trainable). Blending
# only trainable_variables would leave any non-trainable weights of the fresh
# `sxesrgan` model at their initial values.
pre_weights = sxesrpre_generator.get_weights()
gan_weights = sxesrgan_generator.get_weights()
assert len(pre_weights) == len(gan_weights) == len(sxesrgan.weights), \
    'generators must share the same architecture to be interpolated'

sxesrgan.set_weights([
    (1 - ALPHA) * w_pre + ALPHA * w_gan
    for w_pre, w_gan in zip(pre_weights, gan_weights)
])

sxesrgan.save_weights(weights_file('gan_interp.h5'))

## Demo

In [None]:
# Rebuild both generators for the demo. `generator` is not defined anywhere in
# this notebook (NameError on a fresh kernel); the generator factory imported
# above is `sxesrgen`.
pre_generator = sxesrgen()
gan_generator = sxesrgen()

# NOTE(review): these paths go through the global `weights_file`, i.e. whatever
# `weights_dir` was last bound to ('weights/sxesrgan_lpips' on a top-to-bottom run).
pre_generator.load_weights(weights_file('pre_generator.h5'))
gan_generator.load_weights(weights_file('gan_generator.h5'))

In [None]:
from model import resolve_single
from utils import load_image

def resolve_and_plot(lr_image_path):
    """Super-resolve one LR image with both generators and plot the results.

    Draws a 2x2 grid: the LR input at top-left, the pre-trained generator's
    output at bottom-left and the GAN generator's output at bottom-right
    (top-right stays empty). Uses the global `pre_generator` and
    `gan_generator` models.

    Args:
        lr_image_path: path of the low-resolution image to load.
    """
    lr_img = load_image(lr_image_path)

    # (subplot position, title, image); both SR results are computed before
    # the figure is created.
    panels = (
        (1, 'LR', lr_img),
        (3, 'SR (PRE)', resolve_single(pre_generator, lr_img)),
        (4, 'SR (GAN)', resolve_single(gan_generator, lr_img)),
    )

    fig = plt.figure(figsize=(20, 20))
    for position, title, img in panels:
        ax = fig.add_subplot(2, 2, position)
        ax.imshow(img)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

In [None]:
# Compare pre-trained vs. GAN super-resolution on the demo crop.
lr_image_path = 'demo/0869x4-crop.png'
resolve_and_plot(lr_image_path);