In [1]:
import numpy as np
import pandas as pd
import cv2
import csv
import os
import random
from matplotlib import pyplot as plt
import keras

from keras.layers import Flatten
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Input, Conv2D, MaxPooling2D, AveragePooling2D, Convolution2D, BatchNormalization, Activation, Reshape, UpSampling2D, MaxPool2D
from IPython.display import SVG
from PIL import Image
import scipy
from glob import glob

from keras.layers.merge import concatenate

from __future__ import print_function, division
import scipy

from keras import regularizers
from keras import initializers

from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.applications import VGG19
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import sys

import keras.backend as K

from keras.utils.layer_utils import convert_all_kernels_in_model
from keras.utils.data_utils import get_file

Using TensorFlow backend.


In [0]:
from google.colab import drive
drive.mount('/content/drive')

In [0]:
#! unzip "drive/My Drive/DIV2K_train_HR.zip"

In [0]:
# Reading the image labes
path = "DIV2K_train_HR"          # Give the path of the testing file you are using here
train_labels = os.listdir( path )


In [0]:
# Reading images based on labels  and rescaling the images
train_img = []
# Rescaling the image
w = 224
h = 224
for i in range(len(train_labels)):
  #img = cv2.imread(os.path.join('DIV2K_train_HR',train_labels[i]),1)
  img = Image.open(os.path.join('DIV2K_train_HR',train_labels[i]))
  imgg = img.resize((w, h), Image.BICUBIC) 
  p = os.path.join('drive/My Drive/cropped/',(str(i)+'.jpg'))
  imgg.save(p,quality = 100)
  train_img.append(imgg)

In [0]:
# Reading the image labes
path = "drive/My Drive/cropped"
rescaled_train_labels = os.listdir( path )


In [0]:
# Reading the rescaled images into a list
rescaled_train = []
for i in range(len(rescaled_train_labels)):
  img = cv2.imread(os.path.join('drive/My Drive/cropped/',rescaled_train_labels[i]),1)
  rescaled_train.append(img)

In [0]:
class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "test"
        
        path = glob('drive/My Drive/%s/*' % (self.dataset_name))

        batch_images = np.random.choice(path, size=batch_size)

        imgs_hr = []
        imgs_lr = []
        for img_path in batch_images:
            img = self.imread(img_path)

            h, w = self.img_res
            low_h, low_w = int(h / 4), int(w / 4)

            img_hr = scipy.misc.imresize(img, self.img_res)
            img_lr = scipy.misc.imresize(img, (low_h, low_w))

            # If training => do random flip
            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)

            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)

        imgs_hr = np.array(imgs_hr) / 127.5 - 1.
        imgs_lr = np.array(imgs_lr) / 127.5 - 1.

        return imgs_hr, imgs_lr


    def imread(self, path):
        return scipy.misc.imread(path, mode='RGB').astype(np.float)

In [0]:
class SRGAN():
    def __init__(self):
        # Input shape
        self.channels = 3
        self.lr_height = 64                 # Low resolution height
        self.lr_width = 64                  # Low resolution width
        self.lr_shape = (self.lr_height, self.lr_width, self.channels)
        self.hr_height = self.lr_height*4   # High resolution height
        self.hr_width = self.lr_width*4     # High resolution width
        self.hr_shape = (self.hr_height, self.hr_width, self.channels)


        optimizer = Adam(0.0002, 0.5)

        # We use a pre-trained VGG19 model to extract image features from the high resolution
        # and the generated high resolution images and minimize the mse between them
        self.vgg = self.build_vgg()
        self.vgg.trainable = False
        self.vgg.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Configure data loader
        self.dataset_name = 'cropped'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.hr_height, self.hr_width))

        # Calculate output shape of D (PatchGAN)
        patch = int(self.hr_height / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 64
        self.df = 64

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # High res. and low res. images
        img_hr = Input(shape=self.hr_shape)
        img_lr = Input(shape=self.lr_shape)

        # Generate high res. version from low res.
        fake_hr = self.generator(img_lr)

        # Extract image features of the generated img
        fake_features = self.vgg(fake_hr)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # Discriminator determines validity of generated high res. images
        validity = self.discriminator(fake_hr)

        self.combined = Model([img_lr, img_hr], [validity, fake_features])
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=optimizer)


    def build_vgg(self):
        """
        Builds a pre-trained VGG19 model that outputs image features extracted at the
        third block of the model
        """
        vgg = VGG19(weights="imagenet")
        # Set outputs to outputs of last conv. layer in block 3
        # See architecture at: https://github.com/keras-team/keras/blob/master/keras/applications/vgg19.py
        vgg.outputs = [vgg.layers[9].output]

        img = Input(shape=self.hr_shape)

        # Extract image features
        img_features = vgg(img)

        return Model(img, img_features)

    def build_generator(self):

        def inception_block(layer_input, filters):
            """Inception block described in report"""
            d1 = Conv2D(filters, kernel_size=1, strides=1, padding='same', activation='relu')(layer_input)
            d1 = Conv2D(filters, kernel_size=3, strides=1, padding='same', activation='relu')(d1)
            d2 = Conv2D(filters, kernel_size=1, strides=1, padding='same', activation='relu')(layer_input)
            d2 = Conv2D(filters, kernel_size=5, strides=1, padding='same', activation='relu')(d2)
            d3 = MaxPooling2D((3,3), (1,1), padding='same')(layer_input)
            d3 = Conv2D(filters, kernel_size=1, strides=1, padding='same', activation='relu')(d3)
            
            d = keras.layers.concatenate([d1, d2, d3], axis = 3)
            return d

        def deconv2d(layer_input):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
            u = Activation('relu')(u)
            return u

        # Low resolution image input
        img_lr = Input(shape=self.lr_shape)

        # Pre-inception block
        c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)

        # Propogate through inception blocks
        ci = inception_block(c1, self.gf)

        # Post-inception block
        c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(ci)
        c2 = BatchNormalization(momentum=0.8)(c2)
        c2 = Add()([c2, c1])

        # Upsampling
        u1 = deconv2d(c2)
        u2 = deconv2d(u1)

        # Generate high resolution output
        gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)

        return Model(img_lr, gen_hr)

    def build_discriminator(self):

        def d_block(layer_input, filters, strides=1, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        # Input img
        d0 = Input(shape=self.hr_shape)

        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df*2)
        d4 = d_block(d3, self.df*2, strides=2)
        d5 = d_block(d4, self.df*4)
        d6 = d_block(d5, self.df*4, strides=2)
        d7 = d_block(d6, self.df*8)
        d8 = d_block(d7, self.df*8, strides=2)

        d9 = Dense(self.df*16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)

        return Model(d0, validity)

    def train(self, epochs, batch_size=1, sample_interval=50):

        start_time = datetime.datetime.now()

        for epoch in range(epochs):

            # ----------------------
            #  Train Discriminator
            # ----------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # From low res. image generate high res. version
            fake_hr = self.generator.predict(imgs_lr)

            valid = np.ones((batch_size,) + self.disc_patch)
            fake = np.zeros((batch_size,) + self.disc_patch)

            # Train the discriminators (original images = real / generated = Fake)
            d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ------------------
            #  Train Generator
            # ------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # The generators want the discriminators to label the generated images as real
            valid = np.ones((batch_size,) + self.disc_patch)

            # Extract ground truth image features using pre-trained VGG19 model
            image_features = self.vgg.predict(imgs_hr)

            # Train the generators
            g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])

            elapsed_time = datetime.datetime.now() - start_time
            # Plot the progress
            print ("%d time: %s" % (epoch, elapsed_time))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 2

        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=2, is_testing=True)
        fake_hr = self.generator.predict(imgs_lr)

        # Rescale images 0 - 1
        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5

        # Save generated images and the high resolution originals
        titles = ['Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for row in range(r):
            for col, image in enumerate([fake_hr, imgs_hr]):
                axs[row, col].imshow(image[row])
                axs[row, col].set_title(titles[col])
                axs[row, col].axis('off')
            cnt += 1
        fig.savefig("images/%s/%d.png" % (self.dataset_name, epoch))
        plt.close()

        # Save low resolution images for comparison
        for i in range(r):
            fig = plt.figure()
            plt.imshow(imgs_lr[i])
            fig.savefig('images/%s/%d_lowres%d.png' % (self.dataset_name, epoch, i))
            plt.close()

if __name__ == '__main__':
    gan = SRGAN()
    gan.train(epochs=5000, batch_size=1, sample_interval=50)

Instructions for updating:
Colocations handled automatically by placer.
Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5


`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.


Instructions for updating:
Use tf.cast instead.


  'Discrepancy between trainable weights and collected trainable'


0 time: 0:00:34.820688
1 time: 0:01:03.922073
2 time: 0:01:25.162588
3 time: 0:01:46.589195
4 time: 0:02:08.944501
5 time: 0:02:30.272577
6 time: 0:02:51.508253
7 time: 0:03:12.883134
8 time: 0:03:33.927574
9 time: 0:03:54.664986
10 time: 0:04:16.566019
11 time: 0:04:37.151341
12 time: 0:04:58.183738
13 time: 0:05:18.723584
14 time: 0:05:39.779084
15 time: 0:06:00.661360
16 time: 0:06:22.039606
17 time: 0:06:42.988789
18 time: 0:07:04.008025
19 time: 0:07:24.828802
20 time: 0:07:45.809628
21 time: 0:08:06.960369
22 time: 0:08:28.168847
23 time: 0:08:49.330305
24 time: 0:09:10.237442
25 time: 0:09:31.340385
26 time: 0:09:52.336146
27 time: 0:10:13.304160
28 time: 0:10:33.695357
29 time: 0:10:54.363069
30 time: 0:11:14.773747
31 time: 0:11:35.739949
32 time: 0:11:56.586008
33 time: 0:12:17.225732
34 time: 0:12:38.412879
35 time: 0:12:59.187650
36 time: 0:13:19.798985
37 time: 0:13:41.066428
38 time: 0:14:01.569368
39 time: 0:14:22.561442
40 time: 0:14:43.342129
41 time: 0:15:04.197549
42

In [0]:
1200