In [0]:
# Mount Google Drive so the project folder under "My Drive" (dataset,
# checkpoints, logs) is reachable from the Colab runtime.
from google.colab import drive
drive.mount(mountpoint='/content/drive/')

In [0]:
import time
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from imageio import imread
from PIL import Image
import os
from functools import partial
from sklearn.model_selection import train_test_split
import random
import collections

# Work from the project folder on the mounted Drive: the relative paths used
# later in the notebook (data/, logs, checkpoints, results) resolve against it.
PROJECT_DIR = '/content/drive/My Drive/V&P'
os.chdir(PROJECT_DIR)
print('Current directory %s ' % os.getcwd())

In [0]:
def conv2d(x, filter, kernel, stride, padding):
    """2-D convolution built with ``tf.layers.conv2d``.

    Args:
        x: input tensor.
        filter: number of output channels.
        kernel: kernel size.
        stride: stride of the convolution.
        padding: padding mode string ('valid' or 'same' in this notebook).

    All other options use the ``tf.layers.conv2d`` defaults.
    """
    layer_options = dict(strides=stride, padding=padding)
    return tf.layers.conv2d(x, filter, kernel, **layer_options)

def conv2dTranspose(x, filter, kernel, stride, padding):
    """Transposed 2-D convolution (``tf.layers.conv2d_transpose``) with no bias term.

    Args:
        x: input tensor.
        filter: number of output channels.
        kernel: kernel size.
        stride: upsampling stride.
        padding: padding mode string ('same' in this notebook).
    """
    layer_options = dict(strides=stride, padding=padding, use_bias=False)
    return tf.layers.conv2d_transpose(x, filter, kernel, **layer_options)

def batchNormalization(x, training=False):
    """Batch normalization over axis 3 (the channel axis of NHWC tensors).

    Args:
        x: 4-D tensor whose last axis (3) is normalized.
        training: forwarded to ``tf.layers.batch_normalization``. Defaults to
            False, the value this helper has always used implicitly.

    Returns:
        The normalized tensor (momentum 0.9, epsilon 1e-5).

    NOTE: with ``training=False`` the layer normalizes with its *moving*
    mean/variance instead of batch statistics, and those moving statistics are
    not updated in that mode. Unless restored from a checkpoint where they were
    trained, they stay at their initial values, and the layer acts as a learned
    per-channel scale and shift. With ``training=True`` (a Python bool or a
    boolean tensor) batch statistics are used. In that case the ops in
    ``tf.GraphKeys.UPDATE_OPS`` must run together with the train step, or the
    moving statistics never change.
    """
    return tf.layers.batch_normalization(inputs=x, axis=3, momentum=0.9, epsilon=1e-5,
                                         training=training)

def instanceNormalization(x):
    """Instance normalization with a learned offset and scale.

    Wraps ``tf.contrib.layers.instance_norm`` with ``center=True`` and ``scale=True``.
    """
    norm_options = dict(center=True, scale=True)
    return tf.contrib.layers.instance_norm(x, **norm_options)

def relu(x):
    """Element-wise ReLU, max(x, 0)."""
    return tf.nn.relu(x)

def leakyRelu(x):
    """Element-wise leaky ReLU using TensorFlow's default negative slope (``alpha``)."""
    return tf.nn.leaky_relu(x)

def tanh(x):
    """Element-wise hyperbolic tangent (output in (-1, 1)).

    Keep this as ``tf.math.tanh``: the created op is named 'Tanh', and the
    evaluation code fetches generator outputs by names such as
    'AtoB_generator/Tanh:0'.
    """
    return tf.math.tanh(x)

def sigmoid(x):
    """Element-wise logistic sigmoid (output in (0, 1))."""
    return tf.math.sigmoid(x)

def add(x, y):
    """Element-wise sum of two tensors (used for residual skip connections)."""
    return tf.math.add(x, y)

In [0]:
def residual_block(x):
    """Residual block: two reflect-padded 3x3 convolutions plus a skip connection.

    Layout: pad -> conv(256) -> batch norm -> ReLU -> pad -> conv(256) -> batch
    norm, and the result is added to ``x``. The 256 output channels are
    hard-coded, so ``x`` must also have 256 channels for the addition.

    NOTE(review): this block uses batchNormalization, while the rest of the
    generator uses instanceNormalization.
    """
    reflect_1px = [[0, 0], [1, 1], [1, 1], [0, 0]]
    out = x
    # Same layer creation order as conv -> BN (-> ReLU) twice, so the
    # auto-generated variable names (and checkpoints) are unaffected.
    for with_relu in (True, False):
        # 1-pixel reflect padding keeps H and W unchanged through the 3x3 'valid' conv.
        padded = tf.pad(out, reflect_1px, "REFLECT")
        out = batchNormalization(conv2d(padded, 256, 3, 1, "valid"))
        if with_relu:
            out = relu(out)
    return add(out, x)

In [0]:
def build_generator(image, A_or_B, residual_blocks=9):
    """Build a ResNet-style generator under the variable scope '<A_or_B>_generator'.

    The scope is opened with ``reuse=tf.AUTO_REUSE``. Calling this again with
    the same ``A_or_B`` therefore shares the same variables; the reconstruction
    pass relies on this.

    Args:
        image: NHWC input tensor with 3 channels (the notebook feeds
            (N, 256, 256, 3) placeholders; other sizes should be divisible by 4).
        A_or_B: scope prefix, e.g. 'AtoB' or 'BtoA'.
        residual_blocks: number of residual blocks chained at the bottleneck.

    Returns:
        A 3-channel tensor passed through tanh, i.e. with values in (-1, 1).

    NOTE(review): a checkpoint trained while every block read the encoder
    output directly still restores (same variable names), but its residual
    weights were learned for that different data flow.
    """
    with tf.variable_scope(A_or_B + '_generator', reuse=tf.AUTO_REUSE):
        # Reflect-pad by 3 so the 7x7 'valid' convolution keeps the spatial size.
        padded = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")

        # Encoder: 7x7 stride-1 conv, then two 3x3 stride-2 downsampling convs.
        h_conv1 = relu(instanceNormalization(conv2d(padded, 64, 7, 1, 'valid')))
        h_conv2 = relu(instanceNormalization(conv2d(h_conv1, 128, 3, 2, 'same')))
        h_conv3 = relu(instanceNormalization(conv2d(h_conv2, 256, 3, 2, 'same')))

        # Residual blocks are chained: each block refines the previous block's
        # output. Feeding h_conv3 to every block would leave all but the last
        # block disconnected from the generator output.
        residual = h_conv3
        for _ in range(residual_blocks):
            residual = residual_block(residual)

        # Decoder: two stride-2 transposed convs undo the two downsamplings.
        h_conv4 = relu(instanceNormalization(conv2dTranspose(residual, 128, 3, 2, 'same')))
        h_conv5 = relu(instanceNormalization(conv2dTranspose(h_conv4, 64, 3, 2, 'same')))

        # Output layer; its tanh op ('<scope>/Tanh') is fetched by name in evaluation.
        output = tanh(conv2d(h_conv5, 3, 7, 1, 'same'))

        return output

In [0]:
def build_discriminator(image, A_or_B):
    """Build a fully convolutional discriminator under '<A_or_B>_discriminator'.

    Gaussian noise (mean 0, std 0.1) is added to the input first. The network is
    then four 4x4 stride-2 conv blocks (64, 128, 256, 512 filters, leaky ReLU,
    instance norm on all but the first), followed by a 4x4 stride-1 conv that
    outputs a 1-channel map of raw scores (no activation).

    The scope uses ``reuse=tf.AUTO_REUSE``, so real and fake inputs are scored
    by the same weights.
    """
    # Noise is created outside the variable scope and applied to every input.
    noisy_image = image + tf.random_normal(shape=tf.shape(image), mean=0.0, stddev=0.1, dtype=tf.float32)

    with tf.variable_scope(A_or_B + '_discriminator', reuse=tf.AUTO_REUSE):
        # First block: no normalization.
        features = leakyRelu(conv2d(noisy_image, 64, 4, 2, "same"))

        # Hidden blocks with instance normalization.
        for n_filters in (128, 256, 512):
            features = leakyRelu(instanceNormalization(conv2d(features, n_filters, 4, 2, "same")))

        # Raw score map; the losses compare it to 0/1 targets with MSE.
        return conv2d(features, 1, 4, 1, "same")

In [0]:
def load_images(data_dirA, data_dirB):
    """Load the two (unpaired) training domains as normalized arrays.

    If ``data_dirA`` and ``data_dirB`` differ, each is a folder that directly
    holds the images of one domain. If they are the same path, it is treated as
    a dataset root with ``trainA/`` and ``trainB/`` sub-folders.

    Images are paired by their position in the two glob results. Only
    ``min(len(A), len(B))`` pairs are loaded, so the domains may differ in size.
    Every image is read as RGB and resized to 256x256. Each pair is flipped
    left-right together with probability 0.5.

    Returns:
        (allImagesA, allImagesB): float32 arrays of shape (N, 256, 256, 3)
        with pixel values mapped from [0, 255] to [-1, 1].
    """
    if data_dirA != data_dirB:
        imagesA = glob(data_dirA + '/*.*')
        imagesB = glob(data_dirB + '/*.*')
    else:
        imagesA = glob(data_dirA + '/trainA/*.*')
        imagesB = glob(data_dirA + '/trainB/*.*')

    n_pairs = min(len(imagesA), len(imagesB))
    # Report progress about every 10% of this dataset, instead of with step
    # sizes tuned to one particular dataset.
    report_every = max(1, int(np.ceil(n_pairs / 10)))

    allImagesA = []
    allImagesB = []
    # zip() stops at the smaller domain; indexing the A list with positions
    # taken from the B list raises IndexError when A has fewer images.
    for count, (fileA, fileB) in enumerate(zip(imagesA, imagesB)):
        if count % report_every == 0:
            print("Stored %d%% images" % (100 * count // n_pairs))
        imgA = imread(fileA, pilmode='RGB')
        imgB = imread(fileB, pilmode='RGB')

        imgA = np.array(Image.fromarray(imgA).resize((256, 256)))
        imgB = np.array(Image.fromarray(imgB).resize((256, 256)))

        # Same horizontal flip for both images of the pair.
        if np.random.random() > 0.5:
            imgA = np.fliplr(imgA)
            imgB = np.fliplr(imgB)

        allImagesA.append(imgA)
        allImagesB.append(imgB)

    # Normalize [0, 255] -> [-1, 1], the tanh output range of the generators.
    # float32 matches the graph's float32 placeholders and halves memory use.
    allImagesA = np.asarray(allImagesA, dtype=np.float32) / 127.5 - 1.
    allImagesB = np.asarray(allImagesB, dtype=np.float32) / 127.5 - 1.

    return allImagesA, allImagesB

In [0]:
def load_images_test(data_dirA, data_dirB, random_flip=True):
    """Load the two (unpaired) test domains as normalized arrays.

    Follows the same directory convention as load_images. If the two
    directories differ, each one holds a domain's images directly. If they are
    the same path, it is a dataset root with ``testA/`` and ``testB/``
    sub-folders.

    Images are paired by position, and only ``min(len(A), len(B))`` pairs are
    loaded. The number of files found in each domain is printed.

    Args:
        data_dirA, data_dirB: see above.
        random_flip: if True (default), flip each pair left-right with
            probability 0.5. Pass False for deterministic evaluation.

    Returns:
        (allImagesA, allImagesB): float32 arrays of shape (N, 256, 256, 3)
        with pixel values mapped from [0, 255] to [-1, 1].
    """
    if data_dirA != data_dirB:
        imagesA = glob(data_dirA + '/*.*')
        imagesB = glob(data_dirB + '/*.*')
    else:
        imagesA = glob(data_dirA + '/testA/*.*')
        imagesB = glob(data_dirA + '/testB/*.*')
    print(len(imagesA))
    print(len(imagesB))

    allImagesA = []
    allImagesB = []
    # zip() stops at the smaller domain instead of raising IndexError.
    for fileA, fileB in zip(imagesA, imagesB):
        imgA = imread(fileA, pilmode='RGB')
        imgB = imread(fileB, pilmode='RGB')

        imgA = np.array(Image.fromarray(imgA).resize((256, 256)))
        imgB = np.array(Image.fromarray(imgB).resize((256, 256)))

        if random_flip and np.random.random() > 0.5:
            imgA = np.fliplr(imgA)
            imgB = np.fliplr(imgB)

        allImagesA.append(imgA)
        allImagesB.append(imgB)

    # Normalize [0, 255] -> [-1, 1]; float32 matches the graph's placeholders.
    allImagesA = np.asarray(allImagesA, dtype=np.float32) / 127.5 - 1.
    allImagesB = np.asarray(allImagesB, dtype=np.float32) / 127.5 - 1.

    return allImagesA, allImagesB

In [0]:
def load_test_batch(data_dir, batch_size):
    """Draw a random batch of test images from each domain.

    ``batch_size`` file paths are sampled (with replacement, via
    ``np.random.choice``) from ``<data_dir>/testA`` and ``<data_dir>/testB``.
    Each image is read as RGB and resized to 256x256.

    Returns:
        (batchA, batchB): arrays of shape (batch_size, 256, 256, 3) scaled to [-1, 1].
    """
    def _read_resized(path):
        # Read as RGB and resize to the network's 256x256 input size.
        return np.array(Image.fromarray(imread(path, pilmode='RGB')).resize((256, 256)))

    candidatesA = glob(data_dir + '/testA/*.*')
    candidatesB = glob(data_dir + '/testB/*.*')

    pickedA = np.random.choice(candidatesA, batch_size)
    pickedB = np.random.choice(candidatesB, batch_size)

    batchA, batchB = [], []
    for pathA, pathB in zip(pickedA, pickedB):
        batchA.append(_read_resized(pathA))
        batchB.append(_read_resized(pathB))

    def _to_unit_range(images):
        return np.array(images) / 127.5 - 1.0

    return _to_unit_range(batchA), _to_unit_range(batchB)

In [0]:
def save_images(originalA, generatedB, reconstructedA, originalB, generatedA, reconstructedB, path, predicting):
    """Save a 2x3 grid of CycleGAN results to ``path``.

    Row 1: originalA -> generatedB -> reconstructedA.
    Row 2: originalB -> generatedA -> reconstructedB.

    Args:
        originalA, generatedB, reconstructedA, originalB, generatedA,
        reconstructedB: HxWx3 images with values in [-1, 1] (the range produced
            by the loaders and by the tanh generators).
        path: destination passed to ``Figure.savefig``.
        predicting: if True, interactive plotting is switched off and the figure
            is closed after saving (batch evaluation). If False, the figure is
            left open so the notebook can display it.
    """
    if predicting:
        # Disable interactive plotting when predicting
        plt.ioff()

    panels = [
        (originalA, "Original"), (generatedB, "Generated"), (reconstructedA, "Reconstructed"),
        (originalB, "Original"), (generatedA, "Generated"), (reconstructedB, "Reconstructed"),
    ]

    fig = plt.figure(figsize=(7, 7))
    for position, (image, title) in enumerate(panels, start=1):
        ax = fig.add_subplot(2, 3, position)
        # Map [-1, 1] to [0, 1]: imshow clips float RGB data outside [0, 1],
        # which would black out every negative value.
        ax.imshow(np.clip((np.asarray(image) + 1.0) / 2.0, 0.0, 1.0))
        ax.axis("off")
        ax.set_title(title)

    # Save through the figure object. plt.savefig targets pyplot's *current*
    # figure, which is not this one once it has been closed. Close only after saving.
    fig.savefig(path)
    if predicting:
        plt.close(fig)


In [0]:
def add_summary(writer, name, value, global_step):
    """Write one scalar ``value`` under tag ``name`` at ``global_step`` via ``writer``."""
    scalar = tf.Summary.Value(tag=name, simple_value=value)
    writer.add_summary(tf.Summary(value=[scalar]), global_step=global_step)

In [0]:
def CycleGAN():
    """Build the CycleGAN training graph in the current default graph.

    Creates:
    - the two input placeholders;
    - both generators (each called twice: translation, then reconstruction);
    - both discriminators;
    - the least-squares adversarial losses and the L1 cycle-consistency loss (weight 10);
    - one Adam training op per network.

    Returns:
        (real_imageA, real_imageB, fake_imageA, fake_imageB, reconstructedA,
        reconstructedB, da_total_loss, db_total_loss, g_total_loss,
        trainerDA, trainerDB, trainerG)
    """
    generatorA2B = partial(build_generator, A_or_B='AtoB')
    generatorB2A = partial(build_generator, A_or_B='BtoA')
    discriminatorA = partial(build_discriminator, A_or_B='A')
    discriminatorB = partial(build_discriminator, A_or_B='B')

    real_imageA = tf.placeholder("float", shape=[None, 256, 256, 3], name="Real_Image_A")
    real_imageB = tf.placeholder("float", shape=[None, 256, 256, 3], name="Real_Image_B")

    # Call order fixes the op names: the first call of each generator builds
    # under '<scope>/', the second under '<scope>_1/'. Evaluation code fetches
    # tensors by these names.
    fake_imageA = generatorB2A(real_imageB)
    fake_imageB = generatorA2B(real_imageA)
    reconstructedA = generatorB2A(fake_imageB)
    reconstructedB = generatorA2B(fake_imageA)

    probIsRealA = discriminatorA(real_imageA)
    probIsFakeA = discriminatorA(fake_imageA)

    probIsRealB = discriminatorB(real_imageB)
    probIsFakeB = discriminatorB(fake_imageB)

    # Least-squares GAN losses: the target (1 = real, 0 = fake) is passed as
    # `labels` and the discriminator output as `predictions`.
    # Scope name kept for existing graphs/TensorBoard; it holds the whole generator loss.
    with tf.variable_scope('cyclic_loss'):
        # Generators try to make each discriminator output 1 on their fakes.
        g_loss_a_to_b = tf.losses.mean_squared_error(labels=tf.ones_like(probIsFakeB), predictions=probIsFakeB)
        g_loss_b_to_a = tf.losses.mean_squared_error(labels=tf.ones_like(probIsFakeA), predictions=probIsFakeA)
        # L1 cycle-consistency: A -> B -> A and B -> A -> B should recover the inputs.
        cyc_loss_a = tf.losses.absolute_difference(real_imageA, reconstructedA)
        cyc_loss_b = tf.losses.absolute_difference(real_imageB, reconstructedB)
        g_total_loss = g_loss_a_to_b + g_loss_b_to_a + cyc_loss_a * 10.0 + cyc_loss_b * 10.0

    with tf.variable_scope("discriminator_A_loss"):
        da_loss_real = tf.losses.mean_squared_error(labels=tf.ones_like(probIsRealA), predictions=probIsRealA)
        da_loss_b_to_a_fake = tf.losses.mean_squared_error(labels=tf.zeros_like(probIsFakeA), predictions=probIsFakeA)
        da_total_loss = da_loss_real + da_loss_b_to_a_fake

    with tf.variable_scope("discriminator_B_loss"):
        db_loss_real = tf.losses.mean_squared_error(labels=tf.ones_like(probIsRealB), predictions=probIsRealB)
        db_loss_a_to_b_sample = tf.losses.mean_squared_error(labels=tf.zeros_like(probIsFakeB), predictions=probIsFakeB)
        db_total_loss = db_loss_real + db_loss_a_to_b_sample

    with tf.variable_scope("train"):
        # Each training op only updates its own network's variables.
        tvars = tf.trainable_variables()
        dA_vars = [var for var in tvars if 'A_discriminator' in var.name]
        dB_vars = [var for var in tvars if 'B_discriminator' in var.name]
        g_vars = [var for var in tvars if 'AtoB_generator' in var.name or 'BtoA_generator' in var.name]

        # NOTE(review): one AdamOptimizer instance is shared by all three
        # minimize() calls. Its slot variables are per-variable, but TF1 Adam's
        # beta1_power/beta2_power accumulators appear to be per-optimizer. Each
        # train op would then advance them, skewing bias correction early in
        # training. Separate optimizers would fix this, but they add variables,
        # so existing checkpoints would no longer restore.
        adam = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5)
        trainerDA = adam.minimize(da_total_loss, var_list=dA_vars)
        trainerDB = adam.minimize(db_total_loss, var_list=dB_vars)
        trainerG = adam.minimize(g_total_loss, var_list=g_vars)

    return real_imageA, real_imageB, fake_imageA, fake_imageB, reconstructedA, reconstructedB, da_total_loss, db_total_loss, g_total_loss, trainerDA, trainerDB, trainerG

In [0]:
# Output locations, relative to the current working directory.
log_dir = "logs2/"                               # TensorBoard summaries
checkpoint_save_path = "checkpoint2/model.ckpt"  # periodic checkpoints during training
complete_model_path = 'model2/model.ckpt'        # model saved when training ends or is interrupted
meta_graph_path = complete_model_path + '.meta'  # graph definition written next to it

In [0]:
# Dataset root: trainA/ and trainB/ (and testA/, testB/) live under it.
dataset_path = "data/summer2winter_yosemite"
imagesA, imagesB = load_images(data_dirA=dataset_path, data_dirB=dataset_path)

In [0]:
def batch_generator(A, B, batch_size):
    """Yield aligned slices ``(A[i:i + batch_size], B[i:i + batch_size])``.

    Batches are produced in order from index 0, and the last one may be
    shorter. The number of batches is driven by ``len(A)``; ``B`` is sliced
    with the same offsets.
    """
    yield from (
        (A[offset:offset + batch_size], B[offset:offset + batch_size])
        for offset in range(0, len(A), batch_size)
    )

In [0]:
# Training configuration.
epochs = 500
BATCH_SIZE = 32  # only used by the disabled mini-batch path; the loop below trains on one pair per step
# First (0-based) epoch to run. NOTE(review): 81 resumes an earlier interrupted
# run from its checkpoint; set it to 0 when training from scratch.
start_epoch = 81
n_preview = 2  # random test pairs rendered at every checkpoint

# tf.train.Saver.save and savefig fail if the parent directory does not exist.
for directory in (os.path.dirname(checkpoint_save_path), os.path.dirname(complete_model_path), "results"):
    os.makedirs(directory, exist_ok=True)

tf.reset_default_graph()
g = tf.Graph()
with g.as_default():
    (real_imageA, real_imageB, fake_imageA, fake_imageB, reconstructedA, reconstructedB,
     da_total_loss, db_total_loss, g_total_loss, trainerDA, trainerDB, trainerG) = CycleGAN()
    saver = tf.train.Saver()

with tf.Session(graph=g) as sess:
    latest_checkpoint = tf.train.latest_checkpoint(os.path.dirname(checkpoint_save_path))
    if latest_checkpoint:
        print("Checkpoint present. Restoring model.")
        saver.restore(sess, latest_checkpoint)
    else:
        print("Model not present. Initializing variables.")
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
    print("\nStarting training...")
    train_writer = tf.summary.FileWriter(log_dir, sess.graph)
    try:
        for epoch in range(start_epoch, epochs):
            print("\nEpoch", epoch + 1)
            epoch_da_loss, epoch_db_loss, epoch_g_loss = 0., 0., 0.
            mb = 0
            start = time.perf_counter()
            print("=======" * 10)
            # One image pair per step; zip() stops at the smaller domain.
            for imageA, imageB in zip(imagesA, imagesB):
                feed = {real_imageA: np.reshape(imageA, [1, 256, 256, 3]),
                        real_imageB: np.reshape(imageB, [1, 256, 256, 3])}
                mb += 1
                # Generator update (adversarial + cycle-consistency loss).
                g_loss_val, _ = sess.run([g_total_loss, trainerG], feed_dict=feed)
                epoch_g_loss += g_loss_val
                # The discriminators are only updated on even epochs.
                if epoch % 2 == 0:
                    d_A_loss_val, _ = sess.run([da_total_loss, trainerDA], feed_dict=feed)
                    epoch_da_loss += d_A_loss_val
                    d_B_loss_val, _ = sess.run([db_total_loss, trainerDB], feed_dict=feed)
                    epoch_db_loss += d_B_loss_val
            elapsed = time.perf_counter() - start
            print('Elapsed %.3f seconds. \n' % elapsed)
            # Average over the steps actually run (not len(imagesA)), so the
            # printed and the logged values agree.
            n_steps = max(mb, 1)
            epoch_g_loss /= n_steps
            epoch_da_loss /= n_steps
            epoch_db_loss /= n_steps
            print("Cyclic Loss: {:.4f}\tDiscriminator A Loss: {:.4f}\tDiscriminator B Loss: {:.4f} ".format(epoch_g_loss, epoch_da_loss, epoch_db_loss), end="\r")
            add_summary(train_writer, "epoch_g_loss", epoch_g_loss, epoch)
            add_summary(train_writer, "epoch_da_loss", epoch_da_loss, epoch)
            add_summary(train_writer, "epoch_db_loss", epoch_db_loss, epoch)
            print("\n")
            print()
            print("=======" * 10)
            if epoch % 10 == 0:
                saver.save(sess, checkpoint_save_path)
                # Render a few random test pairs with the current weights.
                batchA, batchB = load_test_batch(data_dir=dataset_path, batch_size=n_preview)
                fakeB, fakeA, reconstructed_imageA, reconstructed_imageB = sess.run(
                    [fake_imageB, fake_imageA, reconstructedA, reconstructedB],
                    feed_dict={real_imageA: batchA, real_imageB: batchB})
                for i in range(n_preview):
                    save_images(originalA=batchA[i], generatedB=fakeB[i], reconstructedA=reconstructed_imageA[i],
                                originalB=batchB[i], generatedA=fakeA[i], reconstructedB=reconstructed_imageB[i],
                                path="results/gen_{}_{}".format(epoch, i), predicting=False)
    except KeyboardInterrupt:
        print("Keyboard interruption. Saving")
    # Save the final model once, whether training finished or was interrupted.
    saver.save(sess, complete_model_path)
    train_writer.close()


# **EVALUATION**

In [0]:
# Paths for the evaluation section, redefined so this section does not depend
# on the training cells having been run.
log_dir = "logs2/"
checkpoint_save_path = "checkpoint2/model.ckpt"
complete_model_path = 'model2/model.ckpt'
meta_graph_path = complete_model_path + '.meta'
dataset_path = "data/summer2winter_yosemite"
test_imagesA, test_imagesB = load_images_test(data_dirA=dataset_path, data_dirB=dataset_path)  # testA/ and testB/ under dataset_path

In [0]:
# Restore the final model and save a result grid for every test pair.
os.makedirs("test_results", exist_ok=True)

# A fresh graph keeps this cell re-runnable: importing the meta-graph twice into
# the same default graph would create duplicate, suffixed tensors.
eval_graph = tf.Graph()
with eval_graph.as_default(), tf.Session(graph=eval_graph) as sess:
    # Restore variables from disk.
    saver = tf.train.import_meta_graph(meta_graph_path)
    saver.restore(sess, complete_model_path)
    print("Model restored \n")

    # Tensors are looked up by the names CycleGAN() created. The first call
    # of each generator builds the translation; the second, suffixed '_1',
    # builds the reconstruction.
    real_imageA = eval_graph.get_tensor_by_name("Real_Image_A:0")
    real_imageB = eval_graph.get_tensor_by_name("Real_Image_B:0")
    fake_imageB = eval_graph.get_tensor_by_name('AtoB_generator/Tanh:0')
    fake_imageA = eval_graph.get_tensor_by_name('BtoA_generator/Tanh:0')
    reconstructedA = eval_graph.get_tensor_by_name("BtoA_generator_1/Tanh:0")
    reconstructedB = eval_graph.get_tensor_by_name("AtoB_generator_1/Tanh:0")

    for i, (imageA, imageB) in enumerate(zip(test_imagesA, test_imagesB), start=1):
        imageA = np.reshape(imageA, [1, 256, 256, 3])
        imageB = np.reshape(imageB, [1, 256, 256, 3])
        fakeB, fakeA, reconstructed_imageA, reconstructed_imageB = sess.run(
            [fake_imageB, fake_imageA, reconstructedA, reconstructedB],
            feed_dict={real_imageA: imageA, real_imageB: imageB})
        save_images(originalA=np.squeeze(imageA), generatedB=np.squeeze(fakeB),
                    reconstructedA=np.squeeze(reconstructed_imageA), originalB=np.squeeze(imageB),
                    generatedA=np.squeeze(fakeA), reconstructedB=np.squeeze(reconstructed_imageB),
                    path="test_results/image_{}".format(i), predicting=True)
    

In [0]:
%load_ext tensorboard
%tensorboard --logdir logs/