In [1]:
import tensorflow as tf

# Restrict TensorFlow to the first physical GPU only.
# Visible devices must be configured before any GPU is initialised;
# otherwise TensorFlow raises a RuntimeError, which is printed here.
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
if len(gpus) > 0:
    try:
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    except RuntimeError as err:
        # Device visibility can only be set at program start-up.
        print(err)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU')]


In [1]:
from __future__ import print_function, division
import scipy

import tensorflow as tf
from tensorflow import keras 

from tensorflow.keras.applications import VGG19
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential, Model

from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from tensorflow.keras.layers import PReLU, LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D

import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os

import tensorflow.keras.backend as K

In [2]:
# Seed every RNG the notebook may touch so runs are repeatable.
# tf.random.set_seed only affects TensorFlow ops; NumPy and Python's `random`
# keep independent global generators (the DataLoader's sampling presumably
# uses one of them -- TODO confirm in data_loader.py). GPU kernels may still
# introduce some non-determinism.
import random

SEED = 2
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [3]:
class SRGAN():
    """SRGAN for 4x single-image super-resolution (64x64 -> 256x256 RGB).

    Builds:
      * ``self.generator``: LR image -> HR image (tanh output).
      * ``self.discriminator``: HR image -> ``self.disc_patch`` (16x16x1) map
        of per-patch validity scores (sigmoid output).
      * ``self.vgg``: VGG19 feature extractor (``block3_conv4``) used for the
        content/perceptual loss; marked non-trainable.
      * ``self.combined``: generator followed by the frozen discriminator and
        frozen VGG, trained so generated images fool the discriminator while
        matching the VGG features of the real HR images.
    """

    def __init__(self):
        # Input shapes: HR is 4x the LR size in each spatial dimension.
        self.channels = 3
        self.lr_height = 64                 # Low resolution height
        self.lr_width = 64                  # Low resolution width
        self.lr_shape = (self.lr_height, self.lr_width, self.channels)
        self.hr_height = self.lr_height*4   # High resolution height
        self.hr_width = self.lr_width*4     # High resolution width
        self.hr_shape = (self.hr_height, self.hr_width, self.channels)

        # Number of residual blocks in the generator
        self.n_residual_blocks = 16

        # One optimizer instance per trained model. Sharing a single Adam
        # couples the models' step counters (Adam bias correction) and slot
        # state, and tf.keras optimizers from TF 2.11 onward raise when one
        # instance is asked to update variables of a second model.
        d_optimizer = Adam(0.0002, 0.5)
        g_optimizer = Adam(0.0002, 0.5)

        # A pre-trained VGG19 extracts features from real and generated HR
        # images; the generator minimises the MSE between them. The VGG is
        # only used via predict() and as a frozen sub-model, so it is not
        # compiled.
        self.vgg = self.build_vgg()
        self.vgg.trainable = False

        # Configure data loader
        self.dataset_name = 'img_align_celeba'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.hr_height, self.hr_width))

        # Output shape of D (PatchGAN): D has four stride-2 convolutions,
        # so the HR size is divided by 2**4.
        patch = self.hr_height // 2**4
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 64
        self.df = 64

        # Build and compile the discriminator (trained directly in train()).
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
                                   optimizer=d_optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # High res. and low res. images
        img_hr = Input(shape=self.hr_shape)
        img_lr = Input(shape=self.lr_shape)

        # Generate high res. version from low res. and extract its features.
        fake_hr = self.generator(img_lr)
        fake_features = self.vgg(fake_hr)

        # Freeze D only inside the combined model: D was compiled above while
        # trainable, so its own train_on_batch calls still update it.
        self.discriminator.trainable = False

        # Discriminator determines validity of generated high res. images
        validity = self.discriminator(fake_hr)

        # img_hr is not connected to any output; it is kept so the combined
        # model's inputs stay [lr, hr], as passed by train().
        # NOTE(review): D is trained with MSE, but the adversarial term here
        # is binary cross-entropy -- confirm this mix is intentional.
        self.combined = Model([img_lr, img_hr], [validity, fake_features])
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=g_optimizer)

    def build_vgg(self):
        """Return a VGG19 feature extractor for HR-sized images.

        Maps an image of shape ``self.hr_shape`` to the activations of
        ``block3_conv4`` (the last conv layer of VGG19's third block), using
        ImageNet weights and no classifier head.

        NOTE(review): ImageNet VGG19 expects inputs prepared with
        ``tf.keras.applications.vgg19.preprocess_input`` (0-255, BGR, mean
        subtracted). This notebook appears to feed images in [-1, 1] (see
        the rescaling in ``sample_images``). Left unchanged so the trained
        baseline stays reproducible; confirm whether inputs should be
        preprocessed.
        """
        vgg19 = VGG19(weights="imagenet", include_top=False,
                      input_tensor=Input(shape=self.hr_shape))
        # Truncate at the last conv layer of block 3. Architecture reference:
        # https://github.com/keras-team/keras/blob/master/keras/applications/vgg19.py
        return Model(inputs=vgg19.input,
                     outputs=vgg19.get_layer('block3_conv4').output)

    def build_generator(self):
        """Build the generator: LR image (``self.lr_shape``) -> HR image.

        Layout: 9x9 conv + ReLU, then ``self.n_residual_blocks`` residual
        blocks with ``self.gf`` filters, then conv + BN with a long skip
        connection. Two 2x upsampling stages (4x total) follow, and a final
        9x9 conv with tanh produces ``self.channels`` channels.
        """

        def residual_block(layer_input, filters):
            """Residual block adapted from the SRGAN paper.

            The paper orders Conv -> BN -> PReLU; this uses Conv -> ReLU -> BN.
            ``filters`` must equal the channel count of ``layer_input`` for
            the Add.
            """
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
            d = Activation('relu')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d

        def deconv2d(layer_input):
            """2x upsampling (UpSampling2D) followed by a 256-filter 3x3 conv + ReLU."""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
            u = Activation('relu')(u)
            return u

        # Low resolution image input
        img_lr = Input(shape=self.lr_shape)

        # Pre-residual block. Uses self.gf (not a literal 64) so the residual
        # Adds and the long skip below stay shape-consistent if gf changes.
        c1 = Conv2D(self.gf, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)

        # Propagate through the residual blocks
        r = c1
        for _ in range(self.n_residual_blocks):
            r = residual_block(r, self.gf)

        # Post-residual block with a long skip connection back to c1
        c2 = Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(r)
        c2 = BatchNormalization(momentum=0.8)(c2)
        c2 = Add()([c2, c1])

        # Upsampling: 2x twice -> 4x
        u1 = deconv2d(c2)
        u2 = deconv2d(u1)

        # Generate the high resolution output (tanh range [-1, 1])
        gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)

        return Model(img_lr, gen_hr)

    def build_discriminator(self):
        """Build the PatchGAN discriminator: HR image -> ``self.disc_patch`` map.

        There are eight conv blocks; four use stride 2, so the spatial size
        shrinks by 16. Two Dense layers are then applied at each spatial
        position, and the sigmoid output gives one validity score per patch.
        """

        def d_block(layer_input, filters, strides=1, bn=True):
            """3x3 conv -> LeakyReLU(0.2) -> optional BatchNorm."""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        # Input img
        d0 = Input(shape=self.hr_shape)

        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df*2)
        d4 = d_block(d3, self.df*2, strides=2)
        d5 = d_block(d4, self.df*4)
        d6 = d_block(d5, self.df*4, strides=2)
        d7 = d_block(d6, self.df*8)
        d8 = d_block(d7, self.df*8, strides=2)

        # Dense on a 4-D tensor acts on the last (channel) axis only, so the
        # spatial patch grid is preserved.
        d9 = Dense(self.df*16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)

        return Model(d0, validity)

    def train(self, epochs, batch_size=1, sample_interval=50,
              log_interval=10, checkpoint_epoch=3000,
              checkpoint_path='./saved_model/teacher_epoch3000'):
        """Alternately train the discriminator and the generator.

        Each "epoch" here is a single iteration. D is updated on one real
        batch and one generated batch, then G (through ``self.combined``) is
        updated on a freshly sampled batch.

        Args:
            epochs: number of iterations to run.
            batch_size: images per batch.
            sample_interval: save sample images every this many iterations.
            log_interval: print losses every this many iterations.
            checkpoint_epoch: iteration at which the generator is saved to
                ``checkpoint_path``; ``None`` disables the checkpoint.
            checkpoint_path: destination of the generator checkpoint.
        """
        start_time = datetime.datetime.now()

        # Target label maps are identical every iteration, so build them once.
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(epochs):

            # ----------------------
            #  Train Discriminator
            # ----------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # From low res. image generate high res. version
            fake_hr = self.generator.predict(imgs_lr)

            # Real HR images are labelled valid (1), generated ones fake (0).
            d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
            # [loss, accuracy] averaged over the real and fake batches
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ------------------
            #  Train Generator
            # ------------------

            # Sample a fresh batch for the generator update
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # Extract ground truth image features using pre-trained VGG19 model
            image_features = self.vgg.predict(imgs_hr)

            # G wants D to label its outputs as valid; returns
            # [total_loss, adversarial_loss, content_loss].
            g_loss, _, _ = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])

            elapsed_time = datetime.datetime.now() - start_time
            # Plot the progress
            if epoch % log_interval == 0:
                print("%d time: %s g_loss: %s d_loss: %s" % (epoch, elapsed_time, g_loss, d_loss[0]))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

            if checkpoint_epoch is not None and epoch == checkpoint_epoch:
                self.generator.save(checkpoint_path)

    def sample_images(self, epoch):
        """Save generated vs. original HR test images, plus the LR inputs.

        Writes a 2x2 grid (rows: 2 test images; columns: Generated, Original)
        to ``images/<dataset_name>/<epoch>.png``. Each LR input is written to
        ``images/<dataset_name>/<epoch>_lowres<i>.png``.
        """
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 2

        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=2, is_testing=True)
        fake_hr = self.generator.predict(imgs_lr)

        # Rescale to [0, 1]. This assumes images are in [-1, 1], the
        # generator's tanh range. Clip so imshow gets a valid float range
        # (it would otherwise clip with a warning).
        imgs_lr = np.clip(0.5 * imgs_lr + 0.5, 0, 1)
        fake_hr = np.clip(0.5 * fake_hr + 0.5, 0, 1)
        imgs_hr = np.clip(0.5 * imgs_hr + 0.5, 0, 1)

        # Save generated images and the high resolution originals
        titles = ['Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        for row in range(r):
            for col, image in enumerate([fake_hr, imgs_hr]):
                axs[row, col].imshow(image[row])
                axs[row, col].set_title(titles[col])
                axs[row, col].axis('off')
        fig.savefig("images/%s/%d.png" % (self.dataset_name, epoch))
        plt.close(fig)

        # Save low resolution images for comparison
        for i in range(r):
            fig, ax = plt.subplots()
            ax.imshow(imgs_lr[i])
            fig.savefig('images/%s/%d_lowres%d.png' % (self.dataset_name, epoch, i))
            plt.close(fig)

In [None]:
# Build the SRGAN (loads VGG19 weights and the data loader), then train it.
# Each "epoch" of SRGAN.train is a single mini-batch iteration.
gan = SRGAN()
gan.train(
    epochs=4000,
    batch_size=1,
    sample_interval=50,
)

0 time: 0:00:09.022362 g_loss: 391.3881530761719 d_loss: 0.3115745559334755
10 time: 0:00:22.582775 g_loss: 143.034423828125 d_loss: 0.036898551508784294
20 time: 0:00:34.996710 g_loss: 151.3246307373047 d_loss: 0.11423621326684952
30 time: 0:00:47.345715 g_loss: 206.63092041015625 d_loss: 0.0080098154139705
40 time: 0:00:59.735214 g_loss: 107.7813949584961 d_loss: 0.02149177179671824
50 time: 0:01:12.103327 g_loss: 96.79265594482422 d_loss: 0.44683945178985596
60 time: 0:01:25.404711 g_loss: 88.2862319946289 d_loss: 0.03172209765762091
70 time: 0:01:37.733550 g_loss: 87.76997375488281 d_loss: 0.29922191936930176
80 time: 0:01:50.026630 g_loss: 76.18598175048828 d_loss: 0.035597119480371475
90 time: 0:02:02.315358 g_loss: 67.90010070800781 d_loss: 0.0013631187612190843
100 time: 0:02:14.607920 g_loss: 63.84009552001953 d_loss: 0.04158050729893148
110 time: 0:02:27.849296 g_loss: 79.8370132446289 d_loss: 0.01905744755640626
120 time: 0:02:40.091827 g_loss: 111.9331283569336 d_loss: 0.00

1030 time: 0:21:25.595704 g_loss: 17.102270126342773 d_loss: 0.0005492285272339359
1040 time: 0:21:37.707357 g_loss: 37.70618438720703 d_loss: 0.007794791570631787
1050 time: 0:21:49.783192 g_loss: 20.927274703979492 d_loss: 0.3296257547335699
1060 time: 0:22:02.827873 g_loss: 12.160476684570312 d_loss: 0.0001967746575246565
1070 time: 0:22:14.935056 g_loss: 16.992748260498047 d_loss: 0.005221899375101202
1080 time: 0:22:27.084902 g_loss: 31.913658142089844 d_loss: 0.0019283852889202535
1090 time: 0:22:39.173964 g_loss: 24.314308166503906 d_loss: 0.08906803795616725
1100 time: 0:22:51.265717 g_loss: 14.11486530303955 d_loss: 0.0012529421946965158
1110 time: 0:23:04.262268 g_loss: 32.973960876464844 d_loss: 0.0010189711465500295
1120 time: 0:23:16.330789 g_loss: 19.38916015625 d_loss: 0.05257016736504738
1130 time: 0:23:28.448324 g_loss: 34.52748489379883 d_loss: 0.00018491544324206188
1140 time: 0:23:40.477677 g_loss: 21.88515853881836 d_loss: 0.00022743943554814905
1150 time: 0:23:52.

2040 time: 0:42:10.400305 g_loss: 26.091428756713867 d_loss: 0.04544722006539814
2050 time: 0:42:22.543933 g_loss: 21.452810287475586 d_loss: 0.00012579606845974922
2060 time: 0:42:35.630767 g_loss: 16.85350799560547 d_loss: 0.00018598911083245184
2070 time: 0:42:47.748754 g_loss: 27.578283309936523 d_loss: 0.0006790529587306082
2080 time: 0:42:59.896348 g_loss: 29.074953079223633 d_loss: 0.00010011155609390698
2090 time: 0:43:11.996436 g_loss: 20.804988861083984 d_loss: 0.007706864969804883
2100 time: 0:43:24.090588 g_loss: 32.778377532958984 d_loss: 0.00011132494182675146
2110 time: 0:43:37.250811 g_loss: 17.914865493774414 d_loss: 0.0006666344415862113
2120 time: 0:43:49.348163 g_loss: 28.24638557434082 d_loss: 0.0003240539226680994
2130 time: 0:44:01.449234 g_loss: 8.326123237609863 d_loss: 0.1367810462070338
2140 time: 0:44:13.532745 g_loss: 10.642097473144531 d_loss: 0.001668968383455649
2150 time: 0:44:25.625862 g_loss: 31.04475975036621 d_loss: 0.0004169304847891908
2160 time: 

In [None]:
# Persist the trained generator; the directory name records the run length.
model_name = 'first_baseline_epoch4000_model'
save_path = './saved_model/%s' % model_name
gan.generator.save(save_path)

In [4]:
# Inspect the discriminator architecture on a separate, untrained instance.
# Rebinding `gan` here would silently discard the model trained above.
# Constructing SRGAN also loads VGG19 weights and builds the data loader.
gan_for_summary = SRGAN()
gan_for_summary.discriminator.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 256, 256, 64)      1792      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 256, 256, 64)      0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 128, 128, 64)      36928     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 128, 128, 64)      0         
_________________________________________________________________
batch_normalization (BatchNo (None, 128, 128, 64)      256       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 128, 128, 128)     7385