In [None]:
from google.colab import drive
# Mount Google Drive so the CUHK dataset folders under /content/drive are readable.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import tensorflow as tf; tf.compat.v1.disable_eager_execution()
from keras import backend as K
from keras.layers import Input, Dense, Conv2D, Conv2DTranspose, Flatten, Lambda, Reshape
from keras.models import Model
from keras.losses import binary_crossentropy

In [None]:
import os
from tqdm import tqdm
import cv2
from google.colab.patches import cv2_imshow
from tensorflow.keras.utils import img_to_array
from keras import regularizers
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import keras
from keras.optimizers import rmsprop_v2
import keras
from keras.layers import Conv2DTranspose, ConvLSTM2D, BatchNormalization, TimeDistributed, Conv2D
from keras.models import Sequential, load_model
from keras.layers import LayerNormalization
from skimage import color
from skimage.transform import resize, rotate

In [None]:
SIZE = 256  # side length every image is resized to

# Root of the CUHK face photo/sketch dataset on the mounted Google Drive.
# Folder paths keep a trailing '/' so file names can be appended directly.
DATASET_DIR = '/content/drive/MyDrive/CUHK Dataset/'

image_path = DATASET_DIR + 'CUHK_training_cropped_photos/'
img_array = []

sketch_path = DATASET_DIR + 'CUHK_training_cropped_sketches/'
sketch_array = []

test_image_path = DATASET_DIR + 'CUHK_testing_cropped_photos/'
test_img_array = []

test_sketch_path = DATASET_DIR + 'CUHK_testing_cropped_sketches/'
test_sketch_array = []

# Extensions accepted as dataset images. Anything else that may sit in the
# folders (hidden files such as '.ipynb_checkpoints', 'desktop.ini', sub-folders)
# cannot be decoded and would make cv2.imread return None.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp')


def list_image_files(folder):
    """Return the names of the image files in `folder`, sorted by name.

    Hidden entries (leading '.') and names without an image extension are skipped.
    """
    return sorted(
        name for name in os.listdir(folder)
        if not name.startswith('.') and name.lower().endswith(IMAGE_EXTENSIONS)
    )


# Photos and sketches are paired purely by position after sorting by file name,
# so each photo folder must hold exactly as many images as its sketch folder.
image_file = list_image_files(image_path)
sketch_file = list_image_files(sketch_path)
test_image_file = list_image_files(test_image_path)
test_sketch_file = list_image_files(test_sketch_path)

assert len(image_file) == len(sketch_file), (
    f"{len(image_file)} training photos vs {len(sketch_file)} training sketches")
assert len(test_image_file) == len(test_sketch_file), (
    f"{len(test_image_file)} test photos vs {len(test_sketch_file)} test sketches")


In [None]:
def image_preprocessing(file_name, img_path, size):
  """Load, resize and normalise a list of images into one array.

  Args:
    file_name: iterable of image file names inside `img_path`.
    img_path: folder that contains the files.
    size: target side length; every image is resized to (size, size).

  Returns:
    float32 array of shape (len(file_name), size, size, 3), RGB channel order,
    values scaled to [0, 1].

  Raises:
    OSError: if OpenCV cannot read one of the files.
  """
  storage_array = []
  for img_file in tqdm(file_name):
    x = os.path.join(img_path, img_file)
    bgr = cv2.imread(x, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None rather than raising.
    if bgr is None:
      raise OSError(f"Could not read image: {x}")
    # Resize to the requested `size` (the caller's argument, not a global).
    img = cv2.resize(bgr.astype('float32'), (size, size)) / 255.0
    # OpenCV loads BGR; convert to RGB so matplotlib displays true colours.
    storage_array.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

  if not storage_array:
    # np.array([]) would have shape (0,); keep the 4-D layout for empty input.
    return np.empty((0, size, size, 3), dtype='float32')
  # Returning storage array where we have stored all our pre-processed images
  return np.array(storage_array)
  

In [None]:
# Run every split (train/test x photo/sketch) through the same loader.
img_array = image_preprocessing(file_name=image_file, img_path=image_path, size=SIZE)
sketch_array = image_preprocessing(file_name=sketch_file, img_path=sketch_path, size=SIZE)

test_img_array = image_preprocessing(file_name=test_image_file, img_path=test_image_path, size=SIZE)
test_sketch_array = image_preprocessing(file_name=test_sketch_file, img_path=test_sketch_path, size=SIZE)

In [None]:
# Sanity check: how many images were loaded from each folder.
for caption, images in (
    ("Total number of Training Colored images:", img_array),
    ("Total number of Training sketch images:", sketch_array),
    ("Total number of Testing Colored images:", test_img_array),
    ("Total number of Testing sketch images:", test_sketch_array),
):
    print(caption, len(images))

In [None]:
# image_preprocessing already returns NumPy arrays, so np.asarray simply reuses
# them; np.array would allocate a second full copy of every image stack.
img_array_n = np.asarray(img_array)
sketch_array_n = np.asarray(sketch_array)

test_img_array_n = np.asarray(test_img_array)
test_sketch_array_n = np.asarray(test_sketch_array)

# Shapes are (num_images, height, width, channels).
print("The shape of the train colored image array is:", img_array_n.shape)
print("The shape of the train sketched image array is:", sketch_array_n.shape)

print("The shape of the test colored image array is:", test_img_array_n.shape)
print("The shape of the test sketched image array is:", test_sketch_array_n.shape)

In [None]:
# Preview: first five training photos (top row) above the sketches at the same indices.
some_photos = np.hstack(list(img_array_n[:5]))
some_sketches = np.hstack(list(sketch_array_n[:5]))
photo_sketch_grid = np.vstack([some_photos, some_sketches])

plt.figure(figsize=(20,10))
plt.imshow(photo_sketch_grid)
plt.axis("OFF")
plt.show()

In [None]:
# Preview: first ten test photos (top row) above the sketches at the same indices.
rows = []
for stack in (test_img_array_n, test_sketch_array_n):
    rows.append(np.hstack(list(stack[:10])))
some_photos, some_sketches = rows

plt.figure(figsize=(20,10))
plt.imshow(np.vstack(rows))
plt.axis("OFF")
plt.show()

In [None]:
# Photos are the model inputs, sketches the targets.
X_train, y_train = img_array_n, sketch_array_n
X_test, y_test = test_img_array_n, test_sketch_array_n

In [None]:
tuple(arr.shape for arr in (X_train, y_train, X_test, y_test))

## Construction of Encoder

In [None]:
# Per-image dimensions taken from the training photos, e.g. (256, 256, 3).
img_height, img_width, num_channels = X_train.shape[1:4]

input_shape = (img_height, img_width, num_channels)
latent_dim = 16   # dimensionality of the VAE latent space

In [None]:
# Encoder: seven stride-2 convolutions, each halving the spatial resolution
# (a 256x256 input ends as a 2x2x512 feature map), followed by two dense heads.
#
# (filters, use_batch_norm, use_leaky_relu) for each downsampling stage.
ENCODER_STAGES = [
    (8, False, False),
    (16, False, True),
    (32, True, True),
    (64, False, True),
    (128, True, True),
    (256, True, True),
    (512, True, True),
]

encoder_input = Input(shape=input_shape)

encoder_conv = encoder_input
for filters, use_batch_norm, use_leaky_relu in ENCODER_STAGES:
    # Keep the convolution linear when a LeakyReLU follows: a 'relu' here would
    # zero every negative value first and turn the LeakyReLU into a no-op.
    encoder_conv = Conv2D(filters=filters, kernel_size=3, strides=2, padding='same',
                          activation=None if use_leaky_relu else 'relu')(encoder_conv)
    if use_batch_norm:
        encoder_conv = keras.layers.BatchNormalization()(encoder_conv)
    if use_leaky_relu:
        encoder_conv = keras.layers.LeakyReLU()(encoder_conv)

encoder = Flatten()(encoder_conv)

# Parameters of q(z | x): `mu` is the latent mean and `sigma` the latent
# log-variance (the sampling Lambda and the KL term both exponentiate it).
mu = Dense(latent_dim)(encoder)
sigma = Dense(latent_dim)(encoder)

## Sampling the latent vector: $z = \mu + e^{\sigma/2}\,\epsilon$, $\epsilon \sim \mathcal{N}(0, I)$ (`sigma` holds the log-variance)

In [None]:
def compute_latent(x):
    """Sample a latent vector with the reparameterization trick.

    Args:
        x: pair ``[mu, log_var]`` of tensors shaped ``(batch, latent_dim)``;
            the second entry is a log-variance, so ``exp(log_var / 2)`` is the
            standard deviation.

    Returns:
        ``mu + exp(log_var / 2) * eps`` with ``eps`` drawn from a standard normal.
    """
    z_mean, z_log_var = x
    noise_shape = (K.shape(z_mean)[0], K.int_shape(z_mean)[1])
    epsilon = K.random_normal(shape=noise_shape)
    return z_mean + K.exp(z_log_var / 2) * epsilon

## Reparameterization layer (wraps the sampling function)

In [None]:
# Wrap the sampler as a layer so the random draw is part of the model graph.
sampling_layer = Lambda(compute_latent, output_shape=(latent_dim,))
latent_space = sampling_layer([mu, sigma])

In [None]:
# Static shape (batch, height, width, channels) of the last encoder feature map;
# the decoder reshapes its dense output to this.
conv_shape = tuple(K.int_shape(encoder_conv))

In [None]:
tuple(conv_shape)

## Construction of Decoder

In [None]:
# Decoder: latent vector -> Dense -> reshape to the encoder's last feature-map
# shape (conv_shape) -> seven stride-2 transposed convolutions that undo the
# encoder's downsampling -> 3x3 sigmoid layer with num_channels outputs.
#
# (filters, dropout_rate, use_leaky_relu) for each upsampling stage.
DECODER_STAGES = [
    (512, 0.1, True),
    (256, 0.0, True),
    (128, 0.1, True),
    (64, 0.0, True),
    (32, 0.0, True),
    (16, 0.0, True),
    (8, 0.0, False),
]

decoder_input = Input(shape=(latent_dim,))

decoder = Dense(conv_shape[1]*conv_shape[2]*conv_shape[3], activation='relu')(decoder_input)

decoder = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(decoder)

decoder_conv = decoder
for filters, dropout_rate, use_leaky_relu in DECODER_STAGES:
    # Keep the transposed convolution linear when a LeakyReLU follows: a 'relu'
    # here would zero every negative value first, making the LeakyReLU a no-op.
    decoder_conv = Conv2DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                   activation=None if use_leaky_relu else 'relu')(decoder_conv)
    if dropout_rate:
        decoder_conv = keras.layers.Dropout(dropout_rate)(decoder_conv)
    if use_leaky_relu:
        decoder_conv = keras.layers.LeakyReLU()(decoder_conv)

# Sigmoid keeps outputs in [0, 1], matching the /255-normalised targets and the
# binary cross-entropy reconstruction loss.
decoder_conv = Conv2DTranspose(filters=num_channels, kernel_size=3, padding='same', activation='sigmoid')(decoder_conv)

# The loss compares this output pixel-for-pixel with the target images, so the
# upsampling stack must land exactly on the input resolution.
assert K.int_shape(decoder_conv)[1:] == tuple(input_shape), (
    f"decoder output {K.int_shape(decoder_conv)[1:]} != input shape {input_shape}")

## Building the standalone encoder and decoder models

In [None]:
# Standalone sub-models: photo -> sampled latent vector, and latent vector -> image.
encoder = Model(inputs=encoder_input, outputs=latent_space)
decoder = Model(inputs=decoder_input, outputs=decoder_conv)

## Building the VAE: the encoder's sampled latent vector is fed into the decoder, whose output is the VAE's output

In [None]:
# End-to-end model: photo -> encoder (sampled z) -> decoder -> generated image.
vae_output = decoder(encoder(encoder_input))
vae = Model(inputs=encoder_input, outputs=vae_output)

## Summary of autoencoder

In [None]:
vae.summary(line_length=None)  # default layout; the summary is printed, not returned

## Summary of encoder

In [None]:
encoder.summary(line_length=None)  # default layout; the summary is printed, not returned

## Summary of decoder

In [None]:
decoder.summary(line_length=None)  # default layout; the summary is printed, not returned

## Defining the loss function (reconstruction + KL divergence)

In [None]:
def kl_reconstruction_loss(true, pred):
    """Negative ELBO: per-sample reconstruction BCE plus KL divergence, batch-averaged.

    Args:
        true: target images, shape (batch, height, width, channels), values in [0, 1].
        pred: decoder output with the same shape, values in (0, 1).

    Returns:
        Scalar loss tensor.

    Note: the KL term reads the module-level encoder tensors `mu` (latent mean)
    and `sigma` (latent log-variance) rather than taking them as arguments.
    Presumably this is why eager execution is disabled at import time in this
    notebook — confirm before re-enabling it.
    """
    # Reconstruction: BCE summed over every pixel *and* channel of each sample.
    # Summing all H*W*C terms keeps it on the same per-sample scale as the KL sum
    # below; scaling a mean by H*W alone would drop the channel factor.
    reconstruction_loss = K.sum(K.binary_crossentropy(true, pred), axis=[1, 2, 3])

    # KL(q(z|x) || N(0, I)) per sample, with `sigma` being the log-variance.
    kl_loss = -0.5 * K.sum(1 + sigma - K.square(mu) - K.exp(sigma), axis=-1)

    # Unweighted sum of both terms (not a 50/50 split), averaged over the batch.
    return K.mean(reconstruction_loss + kl_loss)

## Compiling the model

In [None]:
vae.compile(loss=kl_reconstruction_loss, optimizer='adam')  # Adam with default settings

## Training the model

In [None]:
# Photos are the inputs and sketches the targets; the test split doubles as
# validation data so its loss is reported every epoch.
history = vae.fit(
    x=X_train,
    y=y_train,
    batch_size=16,
    epochs=100,
    validation_data=(X_test, y_test),
)

## Training and validation loss per epoch (the validation data is the test split)

In [None]:
# Per-epoch loss curves recorded by fit(); "validation" is the test split that
# was passed as validation_data.
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='Training loss')
ax.plot(history.history['val_loss'], label='Validation loss')
ax.set(title='VAE loss per epoch', xlabel='Epoch', ylabel='Loss (reconstruction + KL)')
ax.legend()
plt.show()


## Encoding the training photos into the latent space

In [None]:
# The encoder ends at the sampling Lambda, so these are sampled latent vectors
# (one row per training photo) and differ slightly between runs.
encoded = encoder.predict(x=X_train)

## Decoding data points in latent space

In [None]:
def _latent_line_points(x_start, y_start, x_end, y_end, no_of_imgs, latent_dim):
    """Return `no_of_imgs` latent vectors evenly spaced along a straight line.

    The line runs from (x_start, y_start) to (x_end, y_end) in the first two
    latent dimensions; all remaining dimensions are held at 0 (the prior mean).
    """
    if latent_dim < 2:
        raise ValueError(f"need a latent space with at least 2 dims, got {latent_dim}")
    points = np.zeros((no_of_imgs, latent_dim), dtype='float32')
    points[:, 0] = np.linspace(x_start, x_end, no_of_imgs)
    points[:, 1] = np.linspace(y_start, y_end, no_of_imgs)
    return points


def display_image_sequence(x_start, y_start, x_end, y_end, no_of_imgs):
    """Decode and plot images along a straight line through the latent space.

    Args:
        x_start, y_start: start coordinates in latent dimensions 0 and 1.
        x_end, y_end: end coordinates in latent dimensions 0 and 1.
        no_of_imgs: number of evenly spaced points to decode (>= 1).

    Raises:
        ValueError: if `no_of_imgs` < 1 or the latent space has fewer than 2 dims.
    """
    if no_of_imgs < 1:
        raise ValueError(f"no_of_imgs must be >= 1, got {no_of_imgs}")

    # The decoder needs full latent vectors of its declared input width; feeding
    # bare 2-D (x, y) points fails whenever that width is not 2.
    latent_dim = decoder.input_shape[-1]
    new_points = _latent_line_points(x_start, y_start, x_end, y_end, no_of_imgs, latent_dim)
    new_images = decoder.predict(new_points)

    # squeeze=False keeps `axes` 2-D, so indexing works even for a single image.
    fig, axes = plt.subplots(nrows=1, ncols=no_of_imgs, squeeze=False,
                             sharex=False, sharey=True, figsize=(20, 7))
    for ax, image in zip(axes[0], new_images):
        if image.shape[-1] == 1:
            # imshow cannot draw (H, W, 1); show single-channel output as grayscale.
            ax.imshow(image[..., 0], cmap='gray')
        else:
            ax.imshow(image)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:

# Ground-truth sketches (top row) vs. VAE-generated sketches (bottom row) for
# every 7th test image starting at index 9.
sample_indices = list(range(9, 90, 7))

# One forward pass over just the selected photos, instead of re-running
# vae.predict on the entire test set on every loop iteration.
generated = vae.predict(X_test[sample_indices])

some_photos = [y_test[i] for i in sample_indices]   # ground-truth sketches
sketch_photos = list(generated)                      # generated sketches

X = np.concatenate(some_photos, axis=1)
Y = np.concatenate(sketch_photos, axis=1)

plt.figure(figsize=(20,10))
plt.imshow(np.concatenate([X,Y]))
plt.axis("OFF")

plt.show()

In [None]:
# Per-image comparison for test indices 80-89: input photo, ground-truth sketch,
# and the sketch generated by the VAE.
panel_indices = list(range(80, 90))
# One forward pass over just these photos, instead of re-running vae.predict on
# the entire test set for every image.
panel_predictions = vae.predict(X_test[panel_indices])
for img_no, generated_sketch in zip(panel_indices, panel_predictions):
  fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(5, 3))
  panels = (
      ("Photo", X_test[img_no]),
      ("Ground truth", y_test[img_no]),
      ("Generated", generated_sketch),
  )
  for ax, (title, image) in zip(axes, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")

  fig.tight_layout()
  # Show each figure as it is built so open figures do not accumulate.
  plt.show()