In [5]:
import os
import matplotlib.pyplot as plt

from data import DIV2K
from model.srgan import generator, discriminator
from train import SrganTrainer, SrganGeneratorTrainer

%matplotlib inline

In [6]:
# Directory holding every weights file this notebook saves or loads.
weights_dir = 'weights/srgan'


def weights_file(filename):
    """Return the path of ``filename`` inside ``weights_dir``."""
    return os.path.join(weights_dir, filename)


# Create the directory up front so save_weights() calls below can write into it.
os.makedirs(weights_dir, exist_ok=True)

Loading the DIV2K datasets (scale x4, bicubic downgrade)

In [3]:
# Both splits use the x4 track with bicubic downgrade; only `subset` differs.
div2k_train = DIV2K(subset='train', scale=4, downgrade='bicubic')
div2k_valid = DIV2K(subset='valid', scale=4, downgrade='bicubic')

In [4]:
# Training pipeline: batches of 16 with random_transform enabled
# (presumably random crop/flip augmentation -- see data.DIV2K.dataset).
train_ds = div2k_train.dataset(batch_size=16, random_transform=True)

# Validation pipeline: a single pass (repeat_count=1) WITHOUT random
# augmentation, so evaluation scores are deterministic and comparable between
# checkpoints. Without random_transform the images are presumably not cropped
# to a fixed size, and DIV2K images vary in resolution, so they cannot be
# stacked into multi-image batches -> batch_size=1.
valid_ds = div2k_valid.dataset(batch_size=1, random_transform=False, repeat_count=1)

Downloading data from http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X4.zip
  3907584/246914039 [..............................] - ETA: 1:30

KeyboardInterrupt: 

Pre-training the generator

In [None]:
# Stage 1: train the generator on its own (SrganGeneratorTrainer) before the
# adversarial stage; checkpoints go to .ckpt/pre_generator.
pre_trainer = SrganGeneratorTrainer(
    model=generator(),
    checkpoint_dir='.ckpt/pre_generator',
)
pre_trainer.train(
    train_ds,
    valid_ds.take(10),        # evaluate on the first 10 elements of valid_ds
    steps=1_000_000,
    evaluate_every=1_000,
    save_best_only=False,
)

# Persist the pre-trained generator; the GAN stage loads these weights.
pre_trainer.model.save_weights(weights_file('pre_generator.h5'))

GAN training (initialized from the pre-trained generator)

In [16]:
# Stage 2: adversarial training, starting from the pre-trained generator
# weights and a freshly built discriminator.
gan_generator = generator()
gan_generator.load_weights(weights_file('pre_generator.h5'))

gan_trainer = SrganTrainer(
    generator=gan_generator,
    discriminator=discriminator(),
)
gan_trainer.train(train_ds, steps=200_000)

KeyboardInterrupt: 

In [None]:
# Persist both networks of the GAN stage, generator first.
for trained_net, weights_name in (
    (gan_trainer.generator, 'gan_generator.h5'),
    (gan_trainer.discriminator, 'gan_discriminator.h5'),
):
    trained_net.save_weights(weights_file(weights_name))

Using the trained generators (GUI demo and inline fallback)

In [7]:
# Rebuild both generator variants and restore the saved weights:
# the pre-trained generator and the GAN-trained generator.
pre_generator, gan_generator = generator(), generator()

pre_generator.load_weights(weights_file('pre_generator.h5'))
gan_generator.load_weights(weights_file('gan_generator.h5'))

In [None]:
import sys
import cv2
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
from model.srgan import resolve_single
import numpy as np
import tensorflow.compat.v1 as tf

class Ui_MainWindow(QtWidgets.QWidget):
    """Small PyQt5 window that super-resolves a user-chosen image and displays it.

    Uses the notebook-level generators ``pre_generator`` and ``gan_generator``
    (restored from the saved weights files) through ``resolve_single``.
    """

    # (width, height) that the pre-generator output is resized to before the
    # GAN pass; cv2.resize takes dsize as (width, height). With a presumably
    # 4x generator (trained on scale=4 data) the result is about 512x384,
    # which fits inside the 641x481 display label.
    PREVIEW_SIZE = (128, 96)

    def __init__(self, parent=None):
        super(Ui_MainWindow, self).__init__(parent)

        self.set_ui()
        self.slot_init()

        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.__flag_work = 0

    def set_ui(self):
        """Create widgets and layouts.

        The Open/Exit buttons go in a left column and the image label in the
        middle column. An (empty) text column is appended on the right.
        """
        self.setWindowTitle(u'Final Project')
        self.setFixedSize(960, 720)

        self.__layout_main = QtWidgets.QHBoxLayout()
        self.__layout_fun_button = QtWidgets.QVBoxLayout()
        self.__layout_show = QtWidgets.QVBoxLayout()
        self.__layout_text = QtWidgets.QVBoxLayout()

        self.button_open = QtWidgets.QPushButton(u'Open a picture to do super resolution')
        # NOTE(review): move()/resize() are presumably overridden once the
        # button is managed by a layout below.
        self.button_open.move(30, 105)
        self.button_open.resize(self.button_open.sizeHint())

        self.button_close = QtWidgets.QPushButton(u'Exit')
        self.button_close.resize(self.button_close.sizeHint())

        self.label_show = QtWidgets.QLabel()
        self.label_show.setFixedSize(641, 481)
        self.label_show.setAutoFillBackground(False)

        self.__layout_fun_button.addWidget(self.button_open)
        self.__layout_fun_button.addWidget(self.button_close)

        self.__layout_show.addWidget(self.label_show)

        self.__layout_main.addLayout(self.__layout_fun_button)
        self.__layout_main.addLayout(self.__layout_show)
        self.__layout_main.addLayout(self.__layout_text)

        self.setLayout(self.__layout_main)

    def slot_init(self):
        """Wire the buttons: Open -> open_click, Exit -> close (which triggers closeEvent)."""
        self.button_open.clicked.connect(self.open_click)
        self.button_close.clicked.connect(self.close)

    def open_click(self):
        """Ask for an image, super-resolve it and show the result in the label.

        Does nothing if the file dialog is cancelled. Shows a warning if OpenCV
        cannot decode the chosen file.
        """
        fname, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, 'Choose an image', 'img/', 'Image files(*.jpg *.gif *.png)')
        if not fname:
            # Dialog cancelled: getOpenFileName returns an empty path.
            return

        bgr = cv2.imread(fname)
        if bgr is None:
            # cv2.imread signals unreadable/unsupported files by returning None.
            QtWidgets.QMessageBox.warning(self, u'Open image', u'Could not read image:\n' + fname)
            return

        # OpenCV loads BGR; convert once so both generators and the RGB888
        # QImage all see RGB.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        self.label_show.setPixmap(self._to_pixmap(self._super_resolve(rgb)))

    def _super_resolve(self, rgb):
        """Run pre_generator, resize to PREVIEW_SIZE, run gan_generator; return a numpy array."""
        # resolve_single returns an eager tensor; .numpy() gives an array cv2 can use.
        pre_sr = resolve_single(pre_generator, rgb).numpy()
        preview = cv2.resize(pre_sr, self.PREVIEW_SIZE)
        return resolve_single(gan_generator, preview).numpy()

    @staticmethod
    def _to_pixmap(rgb):
        """Convert an H x W x 3 RGB array to a QPixmap."""
        # Format_RGB888 needs packed uint8 data (no-op if it already is).
        rgb = np.ascontiguousarray(rgb, dtype=np.uint8)
        height, width = rgb.shape[:2]
        # QImage does not own the numpy buffer; .copy() detaches it so the
        # array may be freed safely.
        image = QtGui.QImage(rgb.data, width, height, 3 * width,
                             QtGui.QImage.Format_RGB888).copy()
        return QtGui.QPixmap.fromImage(image)

    def closeEvent(self, event):
        """Ask for confirmation before closing; clear the displayed image on close."""
        ok = QtWidgets.QPushButton()
        cancel = QtWidgets.QPushButton()
        msg = QtWidgets.QMessageBox(QtWidgets.QMessageBox.Warning, u'Close', u'Whether to close')
        msg.addButton(ok, QtWidgets.QMessageBox.ActionRole)
        msg.addButton(cancel, QtWidgets.QMessageBox.RejectRole)
        ok.setText(u'yes')
        cancel.setText(u'no')
        msg.exec_()
        # With custom buttons exec_() returns an opaque value, so check which
        # button was clicked. Escape maps to the RejectRole button (cancel).
        if msg.clickedButton() is cancel:
            event.ignore()
        else:
            self.label_show.clear()
            event.accept()

if __name__ == '__main__':
    # Reuse an existing QApplication (e.g. when this cell is re-run in the same
    # kernel): constructing a second one in the same process fails.
    App = QApplication.instance() or QApplication(sys.argv)
    win = Ui_MainWindow()
    win.show()
    sys.exit(App.exec_())

In [None]:
# If the GUI does not work, use resolve_and_plot below to test the SRGAN generators directly in the notebook.
from model.srgan import resolve_single

import numpy as np
from utils import load_image, plot_sample


def resolve_and_plot(lr_image_path):
    """Super-resolve one image with both generators and plot the results.

    The figure is a 2x2 grid. The LR input is top-left, the pre-trained
    generator's output is bottom-left and the GAN generator's output is
    bottom-right. The top-right cell is left empty.

    Args:
        lr_image_path: path of the low-resolution image, read with
            ``utils.load_image``.
    """
    lr = load_image(lr_image_path)

    pre_sr = resolve_single(pre_generator, lr)
    gan_sr = resolve_single(gan_generator, lr)

    fig = plt.figure(figsize=(20, 20))

    images = [lr, pre_sr, gan_sr]
    titles = ['LR', 'SR (PRE)', 'SR (GAN)']
    positions = [1, 3, 4]  # row-major cells of the 2x2 grid

    for img, title, pos in zip(images, titles, positions):
        ax = fig.add_subplot(2, 2, pos)
        ax.imshow(img)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

In [None]:
# Demo: compare LR input and both SR outputs on a bundled DIV2K crop.
resolve_and_plot(lr_image_path='demo/0869x4-crop.png')