In [74]:
from google.cloud import storage
import tempfile
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras import layers, Model
from sklearn.model_selection import train_test_split

import numpy as np
from keras import Model
from keras.layers import Conv2D, PReLU,BatchNormalization, Flatten
from keras.layers import UpSampling2D, LeakyReLU, Dense, Input, add
from tqdm import tqdm

### Load the 100 high-resolution (HR) images from the Kaggle100-original dataset

In [32]:
# Google Cloud Storage bucket that holds the SRGAN datasets.
BUCKET_NAME = "srgan-wagon-project"
# Object-name prefix under which the high-resolution (HR) images are stored.
STORAGE_LOCATION_hr = "datasets/kaggle100-original/HR"

In [69]:
def get_images_gcp(prefix, bucket_name=None):
    """Download and decode every image stored under ``prefix`` in a GCS bucket.

    Args:
        prefix: Object-name prefix to list, e.g. ``'datasets/kaggle100-original/HR'``.
        bucket_name: Bucket to read from. Defaults to the notebook-level
            ``BUCKET_NAME`` constant when omitted.

    Returns:
        ``np.array`` of the decoded images, in the order the bucket lists them.
        Images come from ``cv2.imread``, so channels are in BGR order. When all
        images share one size this is a dense ``(n, height, width, 3)`` array;
        mixed sizes are not supported (assumption — TODO confirm dataset is uniform).

    Objects that OpenCV cannot decode (e.g. zero-byte "folder" placeholder
    objects) are skipped rather than appended as ``None``.
    """
    if bucket_name is None:
        bucket_name = BUCKET_NAME

    client = storage.Client()
    bucket = client.bucket(bucket_name)

    # `prefix` must be passed by keyword: the first positional parameter of
    # Bucket.list_blobs is `max_results`, not `prefix`.
    blobs = bucket.list_blobs(prefix=prefix)

    images = []
    for blob in blobs:
        # Skip "directory" placeholder objects; they are not images.
        if blob.name.endswith("/"):
            continue

        fd, temp_local_filename = tempfile.mkstemp()
        # mkstemp returns an open OS-level descriptor; close it so we do not
        # leak one file handle per downloaded image.
        os.close(fd)
        try:
            # Download file from bucket, then decode it with OpenCV.
            blob.download_to_filename(temp_local_filename)
            img = cv2.imread(temp_local_filename)
        finally:
            # Always remove the temp file, even if download/decode raised.
            os.remove(temp_local_filename)

        if img is None:
            continue  # not a decodable image
        images.append(img)
    return np.array(images)

In [70]:
# Download all HR images from the bucket into a single array.
images_hr = get_images_gcp(prefix=STORAGE_LOCATION_hr)

In [71]:
np.shape(images_hr)  # (n_images, height, width, channels)

(100, 384, 384, 3)

In [1]:
# cv2.imread loads channels in BGR order while matplotlib expects RGB, so
# convert before display; otherwise red and blue come out swapped.
plt.imshow(cv2.cvtColor(images_hr[0], cv2.COLOR_BGR2RGB))
plt.title("HR sample 0")
plt.axis("off");


KeyboardInterrupt



### Load the low-resolution (LR) images

In [None]:
# BUCKET_NAME is already defined in the HR configuration cell; it is not
# redefined here so HR and LR images are guaranteed to come from the same bucket.
# Object-name prefix under which the low-resolution (LR) images are stored.
STORAGE_LOCATION_LR = 'datasets/kaggle100-original/LR'

In [None]:
# Download all LR images. They are paired with HR images by position, which
# presumably relies on both folders holding identically named files — TODO confirm.
images_lr = get_images_gcp(prefix=STORAGE_LOCATION_LR)

### Create X (LR inputs) and y (HR targets) and scale pixel values to [0, 1]

In [None]:
# Scale 8-bit pixels to [0, 1]. Casting to float32 first avoids the float64
# arrays that a plain `/ 255` produces (twice the memory for the same data).
X = images_lr.astype(np.float32) / 255.0
y = images_hr.astype(np.float32) / 255.0

### Train/Test split

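*The train/test split is currently disabled (commented out below), so training uses all LR/HR image pairs.*

<!-- #Split to train and test
X_train, X_test, y_train, y_test = train_test_split(X,
                                                    y,
                                                    test_size=0.33,
                                                    random_state=42) -->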

### Create Batches

In [None]:
# Split LR (X) and HR (y) images into consecutive mini-batches; the training
# loop fetches one batch at a time from these lists.
batch_size = 1

# Batches pair LR and HR images by position, so both arrays must be equally long;
# otherwise X slices would silently come out empty or short.
assert len(X) == len(y), f"LR/HR count mismatch: {len(X)} vs {len(y)}"

# A trailing partial batch is dropped: the training loop builds label arrays
# with exactly `batch_size` rows, so every batch must be full.
num_batches = len(y) // batch_size
y_batches = [y[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]
X_batches = [X[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]

In [None]:
# Per-image shapes (height, width, channels), i.e. without the leading batch axis.
X_shape = X.shape[1:]
y_shape = y.shape[1:]

# Symbolic Keras inputs: LR images feed the generator, HR images the discriminator.
X_ip = Input(shape=X_shape)
y_ip = Input(shape=y_shape)

### Create model

In [None]:
def res_block(ip):
    """Residual block: (3x3 conv/64 -> BN -> PReLU -> 3x3 conv/64 -> BN) + identity skip.

    ``ip`` must already have 64 channels so the element-wise ``add`` shapes match.
    """
    branch = ip
    branch = Conv2D(64, (3, 3), padding="same")(branch)
    branch = BatchNormalization(momentum=0.5)(branch)
    # One PReLU slope per channel, shared across height and width.
    branch = PReLU(shared_axes=[1, 2])(branch)
    branch = Conv2D(64, (3, 3), padding="same")(branch)
    branch = BatchNormalization(momentum=0.5)(branch)
    return add([ip, branch])

def upscale_block(ip):
    """Double the spatial size of ``ip``.

    Applies a 3x3 conv to 256 channels, then ``UpSampling2D(size=2)``, then a
    PReLU whose slope is shared across height and width. Output keeps 256 channels.
    """
    conv = Conv2D(256, (3, 3), padding="same")
    upsample = UpSampling2D(size=2)
    activation = PReLU(shared_axes=[1, 2])
    return activation(upsample(conv(ip)))

In [None]:
# Generator model
def create_gen(gen_ip, num_res_block):
    """Build the SRGAN generator.

    Args:
        gen_ip: Keras ``Input`` tensor for the low-resolution images.
        num_res_block: number of residual blocks in the trunk.

    Returns:
        A ``Model`` mapping ``gen_ip`` to a 3-channel image whose height and
        width are 4x the input's (two ``upscale_block`` stages of 2x each).
    """
    # Feature extraction: 9x9 conv + PReLU; kept for the long skip connection.
    head = Conv2D(64, (9, 9), padding="same")(gen_ip)
    head = PReLU(shared_axes=[1, 2])(head)

    # Trunk of residual blocks.
    x = head
    for _ in range(num_res_block):
        x = res_block(x)

    # Post-trunk conv + BN, then add the long skip from the head.
    x = Conv2D(64, (3, 3), padding="same")(x)
    x = BatchNormalization(momentum=0.5)(x)
    x = add([x, head])

    # Two 2x upscaling stages -> 4x overall.
    for _ in range(2):
        x = upscale_block(x)

    # Project to 3 output channels; no activation argument is passed.
    output = Conv2D(3, (9, 9), padding="same")(x)
    return Model(inputs=gen_ip, outputs=output)

In [None]:
# Discriminator block used to construct the discriminator
def discriminator_block(ip, filters, strides=1, bn=True):
    """3x3 conv -> optional BatchNormalization(momentum=0.8) -> LeakyReLU(0.2).

    Args:
        ip: input tensor.
        filters: number of conv filters.
        strides: conv stride; 2 halves the spatial size (with "same" padding).
        bn: whether to insert batch normalization after the conv.
    """
    out = Conv2D(filters, (3, 3), strides=strides, padding="same")(ip)
    out = BatchNormalization(momentum=0.8)(out) if bn else out
    return LeakyReLU(alpha=0.2)(out)

In [None]:
# Discriminator, as described in the original paper
def create_disc(disc_ip):
    """Build the discriminator: 8 conv blocks, then Dense(1024) -> LeakyReLU -> sigmoid.

    Returns a ``Model`` mapping ``disc_ip`` to one sigmoid score per image
    (trained with label 1 for real HR images and 0 for generated ones).
    """
    df = 64
    # (filters, strides) for each block; only the first block skips batch norm.
    block_specs = [
        (df, 1), (df, 2),
        (df * 2, 1), (df * 2, 2),
        (df * 4, 1), (df * 4, 2),
        (df * 8, 1), (df * 8, 2),
    ]
    x = disc_ip
    for idx, (filters, strides) in enumerate(block_specs):
        x = discriminator_block(x, filters, strides=strides, bn=idx > 0)

    x = Flatten()(x)
    x = Dense(df * 16)(x)
    x = LeakyReLU(alpha=0.2)(x)
    validity = Dense(1, activation="sigmoid")(x)
    return Model(disc_ip, validity)

In [None]:
from keras.applications import VGG19

def build_vgg(y_shape, layer_name="block3_conv4"):
    """Return an ImageNet VGG19 feature extractor truncated at ``layer_name``.

    Args:
        y_shape: (height, width, channels) of the images the extractor will see.
        layer_name: VGG19 layer whose activations are returned. The default,
            "block3_conv4", is the layer at ``vgg.layers[10]`` when the
            InputLayer is counted as index 0 (as in Keras 2).

    Returns:
        A ``Model`` from the VGG19 input to that layer's output.
    """
    vgg = VGG19(weights="imagenet", include_top=False, input_shape=y_shape)
    # Look the layer up by name: a positional index silently shifts depending
    # on whether the framework counts the InputLayer in `model.layers`.
    return Model(inputs=vgg.inputs, outputs=vgg.get_layer(layer_name).output)


In [None]:
# Combined model
def create_comb(gen_model, disc_model, vgg, X_ip, y_ip):
    """Chain the generator into the discriminator and VGG as one model.

    Side effect: sets ``disc_model.trainable = False`` so the discriminator is
    frozen inside the combined model.

    Returns:
        ``Model`` with inputs ``[X_ip, y_ip]`` and outputs
        ``[validity, gen_features]``: the discriminator's score for the generated
        image and the VGG features of the generated image. ``y_ip`` is listed
        as an input but is not used to compute either output.
    """
    sr_img = gen_model(X_ip)
    sr_features = vgg(sr_img)

    disc_model.trainable = False
    sr_validity = disc_model(sr_img)

    return Model(inputs=[X_ip, y_ip], outputs=[sr_validity, sr_features])

In [None]:
generator = create_gen(X_ip, num_res_block=16)
generator.summary()

discriminator = create_disc(y_ip)
# Compiled while still trainable, so discriminator.train_on_batch updates it.
discriminator.compile(loss="binary_crossentropy", optimizer="adam", metrics=['accuracy'])
discriminator.summary()

# VGG is applied to the generator output and to the HR images, both of which
# have the HR shape, so build it from y_shape rather than a hard-coded size.
vgg = build_vgg(y_shape)
vgg.summary()  # summary() prints by itself and returns None
vgg.trainable = False

gan_model = create_comb(generator, discriminator, vgg, X_ip, y_ip)
# The combined model must be compiled before train_on_batch can be called on it.
# Losses match its two outputs: adversarial BCE on the discriminator score
# (weighted 1e-3, as in the SRGAN paper) and MSE on the VGG content features.
gan_model.compile(loss=["binary_crossentropy", "mse"], loss_weights=[1e-3, 1], optimizer="adam")

### Train

In [None]:
epochs = 5
save_every = 10  # save the generator every `save_every` epochs and after the last one

# Labels are identical for every batch, so build them once.
fake_label = np.zeros((batch_size, 1))  # label 0 for generated (fake) images
real_label = np.ones((batch_size, 1))   # label 1 for real HR images

# Enumerate training over epochs
for e in range(epochs):
    # Per-batch losses, averaged for reporting at the end of the epoch.
    g_losses = []
    d_losses = []

    # Enumerate training over batches.
    for b in tqdm(range(len(y_batches))):
        X_imgs = X_batches[b]  # batch of LR images
        y_imgs = y_batches[b]  # matching batch of HR images

        fake_imgs = generator.predict_on_batch(X_imgs)

        # 1) Train the discriminator on generated and real HR images.
        discriminator.trainable = True
        d_loss_gen = discriminator.train_on_batch(fake_imgs, fake_label)
        d_loss_real = discriminator.train_on_batch(y_imgs, real_label)

        # 2) Freeze the discriminator again before training the generator.
        discriminator.trainable = False

        # Average the two discriminator results (loss, accuracy), for reporting only.
        d_loss = 0.5 * np.add(d_loss_gen, d_loss_real)

        # VGG features of the real HR images: the target for the content loss.
        # predict_on_batch avoids predict()'s per-call overhead and progress bar.
        image_features = vgg.predict_on_batch(y_imgs)

        # Train the generator through the combined model: adversarial loss
        # (aiming for the "real" label) plus content (VGG feature) loss.
        g_loss, _, _ = gan_model.train_on_batch([X_imgs, y_imgs], [real_label, image_features])

        d_losses.append(d_loss)
        g_losses.append(g_loss)

    # Mean losses over all batches of this epoch.
    g_loss = np.mean(g_losses, axis=0)
    d_loss = np.mean(d_losses, axis=0)

    # Report the progress during training.
    print("epoch:", e + 1, "g_loss:", g_loss, "d_loss:", d_loss)

    # Save periodically and always after the final epoch, so runs shorter
    # than `save_every` epochs still produce a saved generator.
    if (e + 1) % save_every == 0 or (e + 1) == epochs:
        generator.save("gen_e_" + str(e + 1) + ".h5")