In [1]:
import cv2
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [4]:
class SRGAN:
    def __init__(self):
        self.gen = self.generator()
        self.dis = self.discriminator()
        self.vgg = self.vgg19()
        self.model_generator = self.build_generator(self.gen, self.dis, self.vgg)

    def generator(self):
        inputs = keras.layers.Input(shape=(270, 480, 3))
        cnn = keras.layers.Conv2D(64, 9, padding='same', activation='relu')(inputs)

        cnn = keras.layers.Conv2D(64, 3, padding='same')(cnn)
        cnn = keras.layers.BatchNormalization()(cnn)
        cnn = tf.nn.relu(cnn)

        cnn_first = cnn
        for index in range(16):
            cnn_ori = cnn
            cnn = keras.layers.Conv2D(64, 3, padding='same')(cnn)
            cnn = keras.layers.BatchNormalization()(cnn)
            cnn = tf.nn.relu(cnn)

            cnn = keras.layers.Conv2D(64, 3, padding='same')(cnn)
            cnn = keras.layers.BatchNormalization()(cnn)
            cnn = cnn + cnn_ori

        cnn = keras.layers.Conv2D(64, 3, padding='same')(cnn)
        cnn = keras.layers.BatchNormalization()(cnn)
        cnn = cnn_first + cnn

        for index in range(2):
            cnn = keras.layers.Conv2D(1, 3, padding='same')(cnn)
            cnn = keras.layers.UpSampling2D(size=2)(cnn)
            cnn = tf.nn.relu(cnn)

        cnn = keras.layers.Conv2D(3, 9, padding='same')(cnn)
        outputs = tf.nn.tanh(cnn)

        model = keras.models.Model(inputs=[inputs], outputs=[outputs])
        model.compile(optimizer=tf.optimizers.Adam(1e-1), loss='mse', metrics=['mse'])
        return model

    def discriminator(self):
        inputs = keras.layers.Input(shape=(270 * 4, 480 * 4, 3))

        cnn = keras.layers.Conv2D(64, 3, strides=1, padding='same')(inputs)
        cnn = tf.nn.relu(cnn)

        cnn = keras.layers.Conv2D(64, 3, strides=2, padding='same')(cnn)
        cnn = keras.layers.BatchNormalization()(cnn)
        cnn = tf.nn.relu(cnn)

        for filters in [128, 256, 512]:
            for strides in [1, 2]:
                cnn = keras.layers.Conv2D(filters, 3, strides=strides, padding='same')(cnn)
                cnn = keras.layers.BatchNormalization()(cnn)
                cnn = tf.nn.relu(cnn)
        cnn = keras.layers.Flatten()(cnn)
        dense = keras.layers.Dense(1024)(cnn)
        dense = tf.nn.relu(dense)
        dense = keras.layers.Dense(1)(dense)

        outputs = tf.nn.sigmoid(dense)

        model = keras.models.Model(inputs=[inputs], outputs=[outputs])
        model.compile(optimizer=tf.optimizers.Adam(1e-1, loss='binary_crossentropy', metrics=['accuracy']))
        return model

    def vgg19(self):
        vgg19 = keras.applications.VGG19(include_top=False, weights='imagenet', input_shape=(270 * 4, 480 * 4, 3))
        vgg19.trainable = False
        for l in vgg19.layers:
            l.trainable = False
        model = keras.Model(inputs=vgg19.input, outputs=vgg19.get_layer('block5_conv4').output)
        model.compile(optimizer=tf.optimizers.Adam(1e-1), loss='mse', metrics=["mse"])
        return model

    def build_generator(self, gen, dis, vgg):
        discriminator = dis
        discriminator.trainable = False

        vgg = vgg
        vgg.trainable = False

        generator = gen
        discriminator_prob = discriminator(generator.outputs)
        vgg_features = vgg(generator.outputs)

        model = keras.Model(inputs=[generator.inputs], outputs=[generator.outputs, vgg_features, discriminator_prob])

        model.compile(optimizer=tf.optimizers.Adam(1e-1), loss=['mse', 'mse', 'binary_crossentropy'],
                      loss_weights=[1., 2e-6, 1e-3])
        return model

    def model_train(self, img_l, img_h):
        batch_size = img_l.shape[0]
        label_real = np.array([[0]] * batch_size)
        label_fake = np.array([[1]] * batch_size)
        self.dis.trainable = False
        self.vgg.trainable = False

        img_GT_vgg_features = self.vgg.predict(img_h)

        self.model_generator.train_on_batch(img_l, [img_h, img_GT_vgg_features, label_real])

        self.dis.trainable = True

        img_pred = self.model_generator.predict(img_l)
        self.dis.train_on_batch(img_pred, label_fake)
        self.dis.train_on_batch(img_h, label_real)

    def model_pred(self, img_l):
        img_super, vgg_features, prob_fake_or_real = self.model_generator.predict(img_l)
        return img_super

    def model_evaluate(self, img_l, img_GT):
        batch_size = img_l.shape[0]
        label_real = np.array([[0]] * batch_size)
        img_GT_vgg_features = self.vgg.predict(img_GT)
        generator_loss, vgg_features_content_loss, discriminator_loss = self.model_generator.evaluate(img_l, [img_GT,
                                                                                                              img_GT_vgg_features,
                                                                                                              label_real])
        return generator_loss, vgg_features_content_loss, discriminator_loss

In [5]:
srgan = SRGAN()
import os
img_path = os.path.join('youku', 'IMG_LOW', 'Youku_00000_l', '034.bmp')
img_l = cv2.imread(img_path)
img_path = os.path.join('youku', 'IMG_HIGH', 'Youku_00000_h_GT', '034.bmp')
img_h = cv2.imread(img_path)
x = np.array([img_l, img_l])
y = np.array([img_h, img_h])
srgan.model_train(x, y)
srgan.model_evaluate(x, y)
pic_super = srgan.model_pred(x)
cv2.imwrite("youku/GEN/srgan_00.bmp", pic_super[0])

ResourceExhaustedError: OOM when allocating tensor with shape[4177920,1024] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator cpu [Op:Mul]