# Fetching Dataset from Google Drive

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Copy the dataset file from the mounted Drive to the Colab location given as
# the second argument. Both placeholder paths must be filled in before running.
import shutil
shutil.copy("/content/drive/<enter drive path>", "<enter path to save file on colab>")

# Importing Libraries

In [None]:
import os
import time
import datetime
import tensorflow as tf
import matplotlib.pyplot as plt
from numpy.random import randint
from tensorflow.keras import Input
from numpy import load, zeros, ones
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import Conv2D, UpSampling2D, LeakyReLU, Activation
from tensorflow.keras.layers import Concatenate, Dropout, BatchNormalization, LeakyReLU

# Creating Functions

## Defining Discriminator

In [None]:
def define_discriminator(image_shape):
    """Build and compile the discriminator that scores (source, target) pairs.

    Both inputs have shape ``image_shape``. They are concatenated and passed
    through four stride-2 4x4 convolutions, one stride-1 4x4 convolution and
    a final single-filter convolution followed by a sigmoid, so the output is
    a spatial map of probabilities (one per patch) rather than a single score.

    Args:
        image_shape: Shape of one image, without the batch axis.

    Returns:
        A Keras ``Model`` taking ``[source_image, target_image]``, compiled
        with binary cross-entropy (loss weight 0.5) and Adam
        (learning rate 0.0002, beta_1 0.5).
    """
    init = RandomNormal(stddev=0.02)

    in_src_image = Input(shape=image_shape)
    in_target_image = Input(shape=image_shape)
    merged = Concatenate()([in_src_image, in_target_image])

    # First downsampling block: no batch normalization.
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)

    # Three further stride-2 blocks with batch normalization.
    for filters in (128, 256, 512):
        d = Conv2D(filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
        d = BatchNormalization()(d)
        d = LeakyReLU(alpha=0.2)(d)

    # Stride-1 block: keeps the spatial size.
    d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)

    # Single-channel patch output squashed to (0, 1).
    d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)

    model = Model([in_src_image, in_target_image], patch_out)

    # `learning_rate` is the supported keyword; `lr` is a deprecated alias
    # that newer Keras releases no longer honour.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model

## Defining Generator

In [None]:
def define_generator(image_shape = (256, 256, 3)):
    """Build the (uncompiled) image-to-image generator.

    Layout, in order:
      * a 7x7 convolution followed by two stride-2 3x3 convolutions;
      * six two-convolution blocks at 256 filters, each concatenated with its
        own input (so the channel count grows after every block);
      * two upsample / 1x1-conv / dropout stages, each concatenated with the
        matching encoder activation;
      * a 7x7 convolution to 3 channels, batch normalization and ``tanh``.

    Every batch-normalization and dropout layer is called with
    ``training=True``.

    Args:
        image_shape: Shape of one input image, without the batch axis.

    Returns:
        A Keras ``Model`` mapping the input image to the generated image.
    """
    weight_init = RandomNormal(stddev=0.02)
    source = Input(shape=image_shape)

    def conv_bn(tensor, filters, kernel, strides=(1, 1), activate=True):
        # Conv -> BatchNorm (-> LeakyReLU), always with 'same' padding.
        tensor = Conv2D(filters, kernel, strides=strides, padding='same',
                        kernel_initializer=weight_init)(tensor)
        tensor = BatchNormalization()(tensor, training=True)
        if activate:
            tensor = LeakyReLU(alpha=0.2)(tensor)
        return tensor

    def upsample_and_merge(tensor, skip, filters):
        # 2x upsample -> 1x1 conv -> dropout -> concat skip -> BN -> LeakyReLU.
        tensor = UpSampling2D((2, 2))(tensor)
        tensor = Conv2D(filters, (1, 1), kernel_initializer=weight_init)(tensor)
        tensor = Dropout(0.5)(tensor, training=True)
        tensor = Concatenate()([tensor, skip])
        tensor = BatchNormalization()(tensor, training=True)
        return LeakyReLU(alpha=0.2)(tensor)

    # Encoder: keep each activation for the skip connections below.
    skip_full = conv_bn(source, 64, (7, 7))
    skip_half = conv_bn(skip_full, 128, (3, 3), strides=(2, 2))
    features = conv_bn(skip_half, 256, (3, 3), strides=(2, 2))

    # Six blocks, each concatenated onto its own input.
    for _ in range(6):
        branch = conv_bn(features, 256, (3, 3))
        branch = conv_bn(branch, 256, (3, 3), activate=False)
        features = Concatenate()([branch, features])

    # Decoder with skip connections back to the encoder.
    decoded = upsample_and_merge(features, skip_half, 128)
    decoded = upsample_and_merge(decoded, skip_full, 64)

    # Output head.
    decoded = conv_bn(decoded, 3, (7, 7), activate=False)
    out_image = Activation('tanh')(decoded)

    return Model(source, out_image)

## Initializing GAN training

In [None]:
def define_gan(g_model, d_model, image_shape):
    """Chain the generator and discriminator into the composite training model.

    Every discriminator layer except ``BatchNormalization`` layers is marked
    non-trainable before the composite model is built. The composite model
    feeds a source image through the generator, then feeds the source image
    together with the generated image through the discriminator.

    Args:
        g_model: Generator model.
        d_model: Discriminator model taking ``[source, target]``.
        image_shape: Shape of one source image, without the batch axis.

    Returns:
        A compiled Keras ``Model`` with outputs
        ``[discriminator_output, generated_image]``, trained with binary
        cross-entropy and MAE weighted 1:100, using Adam
        (learning rate 0.0002, beta_1 0.5).
    """
    # Freeze the discriminator's weights inside the composite model;
    # batch-normalization layers are deliberately left untouched.
    for layer in d_model.layers:
        if not isinstance(layer, BatchNormalization):
            layer.trainable = False

    in_src = Input(shape=image_shape)
    gen_out = g_model(in_src)
    dis_out = d_model([in_src, gen_out])
    model = Model(in_src, [dis_out, gen_out])

    # `learning_rate` is the supported keyword; `lr` is a deprecated alias
    # that newer Keras releases no longer honour.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
    return model

## Loading Real Samples

In [None]:
def load_real_samples(filename):
    """Load the two image arrays from an ``.npz`` archive and rescale them.

    The archive must contain the entries ``arr_0`` and ``arr_1``. Each array
    is mapped with ``(x - 127.5) / 127.5``, which sends the range [0, 255]
    to [-1, 1].

    Args:
        filename: Path (or file object) of the ``.npz`` archive.

    Returns:
        ``[scaled arr_1, scaled arr_0]`` -- note that the order is swapped
        relative to the order in the archive.
    """
    # The object returned for an .npz archive holds the file open until it is
    # closed, so read both arrays inside the `with` block and release it.
    with load(filename) as data:
        X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X2, X1]

## Generating Real Samples

In [None]:
def generate_real_samples(dataset, n_samples, patch_shape):
    """Draw a random batch of paired images together with 'real' labels.

    Args:
        dataset: Two-element sequence ``(trainA, trainB)`` of arrays indexed
            along their first axis.
        n_samples: Number of pairs to draw (with replacement).
        patch_shape: Side length of the square label map.

    Returns:
        ``([batchA, batchB], labels)`` where both batches use the same random
        indices and ``labels`` is an all-ones array of shape
        ``(n_samples, patch_shape, patch_shape, 1)``.
    """
    sources, targets = dataset
    # One shared index vector keeps the two batches paired.
    picks = randint(0, sources.shape[0], n_samples)
    label_shape = (n_samples, patch_shape, patch_shape, 1)
    real_labels = ones(label_shape)
    return [sources[picks], targets[picks]], real_labels

## Generating Fake Samples

In [None]:
def generate_fake_samples(g_model, samples, patch_shape):
    """Run the generator on a batch and pair its output with 'fake' labels.

    Args:
        g_model: Object exposing ``predict(samples)``.
        samples: Batch passed unchanged to ``g_model.predict``.
        patch_shape: Side length of the square label map.

    Returns:
        ``(generated, labels)`` where ``labels`` is an all-zeros array of
        shape ``(len(generated), patch_shape, patch_shape, 1)``.
    """
    generated = g_model.predict(samples)
    label_shape = (len(generated), patch_shape, patch_shape, 1)
    return generated, zeros(label_shape)

## Summarizing Training and Saving Model

In [None]:
def summarize_performance(step, g_model, d_model, dataset, n_samples=3):
    """Save a comparison plot and both models for the given training step.

    Draws ``n_samples`` random pairs from ``dataset``, runs the generator on
    the source images and plots a 3-row grid (source, generated, target).
    The plot is written under the module-level ``step_output`` prefix and the
    generator / discriminator are saved under the module-level
    ``model_output`` prefix; both names must be defined before this runs.

    Args:
        step: Zero-based training step; files are numbered ``step + 1``.
        g_model: Generator model.
        d_model: Discriminator model.
        dataset: Two-element sequence of paired image arrays.
        n_samples: Number of columns in the plot.
    """
    [X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
    X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)

    def to_unit_range(batch):
        # Map pixel values from [-1, 1] to [0, 1] for plotting.
        return (batch + 1) / 2.0

    # One entry per plot row, top to bottom.
    rows = (
        ('Low-Light', to_unit_range(X_realA)),
        ('Generated', to_unit_range(X_fakeB)),
        ('Ground Truth', to_unit_range(X_realB)),
    )

    plt.figure(figsize=(14, 14))
    for row, (title, images) in enumerate(rows):
        for col in range(n_samples):
            plt.subplot(3, n_samples, 1 + row * n_samples + col)
            plt.axis('off')
            plt.title(title)
            plt.imshow(images[col])

    tag = step + 1

    # Write the figure, then release it.
    plot_path = step_output + 'plot_%06d.png' % (tag)
    plt.savefig(plot_path)
    plt.close()

    # Checkpoint both networks.
    gen_path = model_output + 'gen_model_%06d.h5' % (tag)
    g_model.save(gen_path)
    disc_path = model_output + 'disc_model_%06d.h5' % (tag)
    d_model.save(disc_path)

    print('[.] Saved Step : %s' % (plot_path))
    print('[.] Saved Model: %s' % (gen_path))
    print('[.] Saved Model: %s' % (disc_path))

## Training Function

In [None]:
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=12):
    """Run the adversarial training loop.

    Each step draws a random real batch, generates the matching fake batch,
    updates the discriminator on the real pairs and then on the fake pairs,
    and finally updates the generator through ``gan_model``. A progress line
    is printed every step, and ``summarize_performance`` is called once every
    ``len(trainA) // n_batch`` steps.

    Args:
        d_model: Compiled discriminator.
        g_model: Generator.
        gan_model: Compiled composite model.
        dataset: Two-element sequence ``(trainA, trainB)``.
        n_epochs: Number of passes used to compute the total step count.
        n_batch: Number of pairs per step.
    """
    # Side length of the discriminator's output map; labels must match it.
    n_patch = d_model.output_shape[1]
    trainA, _ = dataset
    bat_per_epo = int(len(trainA) / n_batch)
    n_steps = bat_per_epo * n_epochs
    print("[!] Number of steps {}".format(n_steps))
    print("[!] Saves model/step output at every {}".format(bat_per_epo * 1))

    progress_fmt = '[*] %06d, d1[%.3f] d2[%.3f] g[%06.3f] ---> time[%.2f], time_left[%.08s]'

    for step in range(1, n_steps + 1):
        tick = time.time()

        [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)

        # Discriminator: real pairs first, then generated pairs.
        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
        # Generator, via the composite model.
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])

        time_taken = time.time() - tick
        # Remaining time is extrapolated from this step's duration only.
        remaining = datetime.timedelta(seconds=((time_taken) * (n_steps - step)))
        time_left = str(remaining).split('.')[0].zfill(8)
        print(progress_fmt % (step, d_loss1, d_loss2, g_loss, time_taken, time_left))

        if step % (bat_per_epo * 1) == 0:
            summarize_performance(step - 1, g_model, d_model, dataset)

# Main Function

## Loading Dataset

In [None]:
# Load and rescale the paired image arrays; the placeholder path must be
# replaced with the location of the .npz file.
dataset = load_real_samples('<enter path to save file on colab>')
print('Loaded', dataset[0].shape, dataset[1].shape)
# Per-image shape: every dimension after the leading sample axis.
image_shape = dataset[0].shape[1:]

## Defining Models

In [None]:
# Build the discriminator, the generator, and the composite model that chains
# them, all sized from the dataset's per-image shape.
d_model = define_discriminator(image_shape)
g_model = define_generator(image_shape)
gan_model = define_gan(g_model, d_model, image_shape)

## Creating model Directory and Calling Train Function

In [None]:
# Base folder for all training artefacts; replace the placeholder before
# running. NOTE: the name `dir` shadows the builtin and is kept only so that
# any other cell relying on it keeps working.
dir = '<enter path to save model on colab>'
fileName = 'Enhancement Model'

# summarize_performance() builds file names by plain string concatenation, so
# both prefixes must end with a path separator; the trailing "" gives that.
step_output = os.path.join(dir, fileName, "Step Output", "")
model_output = os.path.join(dir, fileName, "Model Output", "")

# Create whatever is missing. Unlike a one-off mkdir guarded by a listdir
# check, this also recovers when the parent folder exists but a sub-folder
# does not, and it is safe to re-run.
os.makedirs(step_output, exist_ok=True)
os.makedirs(model_output, exist_ok=True)

# train() names its batch-size parameter `n_batch`; passing `batch=` raises
# TypeError (unexpected keyword argument).
train(d_model, g_model, gan_model, dataset, n_batch=12)