In [None]:
# Mount Google Drive; train() writes its model checkpoints under
# /content/gdrive/My Drive/.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Download the paired sketch/colour dataset archive from Kaggle.
# NOTE(review): assumes Kaggle API credentials are already configured in this runtime.
! kaggle datasets download -d ktaebum/anime-sketch-colorization-pair # download the dataset

In [None]:
! unzip anime-sketch-colorization-pair.zip

In [None]:
! mkdir data_2

! mkdir data_2/train
! mkdir data_2/train/Images
! mkdir data_2/train/Sketch

!mkdir data_2/validation
!mkdir data_2/validation/Images
!mkdir data_2/validation/Sketch

In [None]:
# Select the TensorFlow 1.x runtime in Colab. The GAN loss further down closes
# over symbolic tensors, which presumably relies on TF1 graph-mode Keras.
# NOTE(review): recent Colab runtimes may no longer support this magic / TF 1.x;
# confirm before running, or port the loss code to TF2.
%tensorflow_version 1.x

In [None]:
"""
Importing the required libraries.
"""
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import glob
from tqdm import tqdm
import cv2
import os
import random

In [None]:


# In[10]:


"""
Displaying the sample sketch and color images.
"""
# Preview five raw training files. Each PNG holds two 512x512 panels side by
# side: the left half is used as the colour image and the right half as the
# sketch (the same split the preprocessing step below applies).
for sample_path in glob.glob('data/train/*.png')[:5]:
    _, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = axes.flatten()

    pair = Image.open(sample_path).convert('RGB')
    for ax, box in zip(panels, [(0, 0, 512, 512), (512, 0, 1024, 512)]):
        ax.imshow(pair.crop(box))
        ax.axis('off')

    plt.show()
    print(sample_path)


# In[ ]:



"""
Preprocessing and saving the training data to corresponding directory. 
"""
def split_pairs(src_pattern, dst_root):
    """
    Split each side-by-side PNG matching `src_pattern` into its two 512x512 halves.

    The left half is saved as `<dst_root>/Images/<idx>.png` (colour image) and
    the right half as `<dst_root>/Sketch/<idx>.png` (sketch), so both halves of
    a pair share the same file name. Source files are processed in sorted order
    so the index given to each file is reproducible; raw glob order depends on
    the filesystem.

    Parameters:
        src_pattern (str): glob pattern matching the combined PNG files.
        dst_root (str): directory that already contains `Images/` and `Sketch/`.
    """
    # tqdm over a list (not an enumerate object) so the bar knows its total.
    for idx, path in enumerate(tqdm(sorted(glob.glob(src_pattern)))):
        with Image.open(path) as src:
            pair = src.convert('RGB')
        pair.crop((0, 0, 512, 512)).save('{}/Images/{}.png'.format(dst_root, idx))
        pair.crop((512, 0, 1024, 512)).save('{}/Sketch/{}.png'.format(dst_root, idx))


# Preprocess the training pairs into data_2/train/{Images,Sketch}.
split_pairs('data/train/*.png', 'data_2/train')

# Preprocess the validation/test pairs into data_2/validation/{Images,Sketch}.
split_pairs('data/val/*.png', 'data_2/validation')

In [None]:
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.activations import tanh, sigmoid
import tensorflow.keras.backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import concatenate, ReLU, GlobalAveragePooling2D, Input, Dense, Dropout, Flatten, LeakyReLU, Conv2D, BatchNormalization, Conv2DTranspose
import tensorflow as tf

In [None]:
# Collect the preprocessed training pairs. Both arrays are sorted, so position i
# in img_paths and sketch_paths refers to the same pair (both halves of a pair
# were saved under the same file name).
img_paths = np.array(sorted(glob.glob('data_2/train/Images/*.png')))
sketch_paths = np.array(sorted(glob.glob('data_2/train/Sketch/*.png')))

# Fail loudly instead of silently training on mismatched pairs (e.g. after an
# interrupted preprocessing run).
assert len(img_paths) > 0, 'no training images found under data_2/train/Images'
assert [os.path.basename(p) for p in img_paths] == [os.path.basename(p) for p in sketch_paths], \
    'data_2/train/Images and data_2/train/Sketch do not contain the same file names'
def generate_samples(sketch_paths, img_paths, n_samples):
  """
  Load a random batch of (colour image, sketch) training pairs.

  Indices are drawn uniformly, with replacement, over the whole path arrays,
  so every pair can be selected.

  Parameters:
    sketch_paths(numpy.array): Paths to the black-and-white sketches, i.e. the input images.
    img_paths(numpy.array): Paths to the coloured images, i.e. the target images.
      Must be aligned with `sketch_paths` (element i of both is the same pair).
    n_samples(int): The number of pairs to load.

  Returns:
    X_images(numpy.array): float32 coloured images scaled to [-1, 1].
    X_sketches(numpy.array): float32 sketches scaled to [-1, 1].
    Note the order: images first, sketches second.

  Raises:
    ValueError: if the path arrays are empty or differ in length.
  """
  n_pairs = len(sketch_paths)
  if n_pairs == 0 or n_pairs != len(img_paths):
    raise ValueError(
        'sketch_paths and img_paths must be non-empty and of equal length '
        '(got {} and {})'.format(n_pairs, len(img_paths)))

  # randint's upper bound is exclusive, so this covers indices 0..n_pairs-1.
  # The bound comes from the data rather than a hard-coded count.
  idxs = np.random.randint(0, n_pairs, n_samples)
  X_sketches = []
  X_images = []

  for sket, img in zip(sketch_paths[idxs], img_paths[idxs]):
    X_sketches.append(np.array(Image.open(sket).convert('RGB')))
    X_images.append(np.array(Image.open(img).convert('RGB')))

  # Normalizing the values to be between [-1, 1].
  X_sketches = (np.array(X_sketches, dtype='float32')-127.5)/127.5
  X_images = (np.array(X_images, dtype='float32')-127.5)/127.5

  return X_images, X_sketches

In [None]:
# Shape of a single image tensor for the generator/discriminator inputs.
# `image_data_format` is a function and must be called; comparing the function
# object itself to a string is always False.
# NOTE(review): the data pipeline (PIL -> np.array) yields H x W x C arrays and
# the total-variation loss slices axes 1/2 as height/width, so only the
# channels-last branch is actually supported by the rest of the notebook.
if K.image_data_format() == "channels_first":
  input_dim = (3, 512, 512)
else:
  input_dim = (512, 512, 3)

In [None]:
# Adam with lr = 2e-4 and beta_1 = 0.5. `learning_rate` replaces the deprecated
# `lr` keyword, which newer Keras versions no longer accept.
adam = Adam(learning_rate = 0.0002, beta_1 = 0.5)

In [None]:
# U-Net style generator: 8 stride-2 Conv2D encoder stages followed by 8 stride-2
# Conv2DTranspose decoder stages, with skip connections from encoder to decoder.
# With input_dim = (512, 512, 3) and padding='same', the encoder feature maps are
# 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 pixels per side, and the decoder
# upsamples 2 -> 4 -> ... -> 512 back to the input resolution.
# Every kernel uses a truncated-normal initializer with stddev 0.02.
# The layers are built in a fixed order, so their auto-generated names and the
# saved-model topology depend on the order of the statements below.
input_1 = Input(input_dim)


# Encoder. The first stage has no BatchNormalization.
generator_1 = Conv2D(16, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(input_1)
activation_1 = LeakyReLU(0.2)(generator_1)

generator_2 = Conv2D(32, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(activation_1)
batch_2 = BatchNormalization()(generator_2)
activation_2 = LeakyReLU(0.2)(batch_2)

generator_3 = Conv2D(64, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(activation_2)
batch_3 = BatchNormalization()(generator_3)
activation_3 = LeakyReLU(0.2)(batch_3)


generator_4 = Conv2D(128, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(activation_3)
batch_4 = BatchNormalization()(generator_4)
activation_4 = LeakyReLU(0.2)(batch_4)

generator_5 = Conv2D(128, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(activation_4)
batch_5 = BatchNormalization()(generator_5)
activation_5 = LeakyReLU(0.2)(batch_5)

generator_6 = Conv2D(128, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(activation_5)
batch_6 = BatchNormalization()(generator_6)
activation_6 = LeakyReLU(0.2)(batch_6)

generator_7 = Conv2D(128, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(activation_6)
batch_7 = BatchNormalization()(generator_7)
activation_7 = LeakyReLU(0.2)(batch_7)

# Bottleneck stage (2x2 spatial for a 512x512 input).
generator_8 = Conv2D(128, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(activation_7)
batch_8 = BatchNormalization()(generator_8)
activation_8 = LeakyReLU(0.2)(batch_8)

## Decoder
# From the second stage on, each decoder stage concatenates the previous decoder
# output with the encoder output of the same resolution (generator_7 ... generator_1).
# NOTE(review): the skips take the raw Conv2D outputs (before BatchNormalization
# and LeakyReLU); confirm that this is intended.
# The decoder also uses LeakyReLU(0.2).

generator_9 = Conv2DTranspose(128, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(activation_8)
batch_9 = BatchNormalization()(generator_9)
activation_9 = LeakyReLU(0.2)(batch_9)

# Dropout(0.5) is applied after the three innermost skip-connected stages.
generator_10 = Conv2DTranspose(128, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(concatenate([activation_9,generator_7]))
batch_10 = BatchNormalization()(generator_10)
activation_10 = LeakyReLU(0.2)(batch_10)
dropout_1 = Dropout(0.5)(activation_10)


generator_11 = Conv2DTranspose(128, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(concatenate([dropout_1, generator_6]))
batch_11 = BatchNormalization()(generator_11)
activation_11 = LeakyReLU(0.2)(batch_11)
dropout_2 = Dropout(0.5)(activation_11)



generator_12 = Conv2DTranspose(128, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(concatenate([dropout_2,generator_5]))
batch_12 = BatchNormalization()(generator_12)
activation_12 = LeakyReLU(0.2)(batch_12)
dropout_3 = Dropout(0.5)(activation_12)


generator_13 = Conv2DTranspose(64, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(concatenate([dropout_3, generator_4]))
batch_13 = BatchNormalization()(generator_13)
activation_13 = LeakyReLU(0.2)(batch_13)

generator_14 = Conv2DTranspose(32, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(concatenate([activation_13, generator_3]))
batch_14 = BatchNormalization()(generator_14)
activation_14 = LeakyReLU(0.2)(batch_14)

generator_15 = Conv2DTranspose(16, 4, strides=(2,2), padding='same',kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(concatenate([activation_14, generator_2]) )
batch_15 = BatchNormalization()(generator_15)
activation_15 = LeakyReLU(0.2)(batch_15)

# Final stage: 3 output channels. tanh keeps outputs in (-1, 1), matching the
# [-1, 1] scaling applied to the training images.
generator_16 = Conv2DTranspose(3, 4, padding='same', strides=(2,2),kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02))(concatenate([activation_15, generator_1]))
activation_16 = tanh(generator_16)


generator = Model(inputs = input_1, outputs = activation_16)

generator.summary()

In [None]:
# Conditional (pix2pix-style) discriminator: it scores a (sketch, colour image)
# pair. The two inputs are concatenated along the last (channel) axis.
# Layers are built in a fixed order, so their auto-generated names and the
# saved-model topology depend on the order of the statements below.
inp1 = Input(shape = input_dim) # sketch input
inp2 = Input(shape = input_dim) # colored input

inp = concatenate([inp1,inp2])

# Four stride-2 'same' convolutions: 512 -> 256 -> 128 -> 64 -> 32 pixels per
# side for a 512x512 input. The first stage has no BatchNormalization.
discriminator_1 = Conv2D(16, 4, strides=(2,2),kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02), padding='same')(inp)
dactivation_1 = LeakyReLU(0.2)(discriminator_1)

discriminator_2 = Conv2D(32, 4, strides=(2,2),kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02), padding='same')(dactivation_1)
dbatch_1 = BatchNormalization()(discriminator_2)
dactivation_2 = LeakyReLU(0.2)(dbatch_1)

discriminator_3 = Conv2D(64, 4, strides=(2,2),kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02), padding='same')(dactivation_2)
dbatch_2 = BatchNormalization()(discriminator_3)
dactivation_3 = LeakyReLU(0.2)(dbatch_2)

discriminator_4 = Conv2D(128, 4, strides=(2,2),kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02), padding='same')(dactivation_3)
dbatch_3 = BatchNormalization()(discriminator_4)
dactivation_4 = LeakyReLU(0.2)(dbatch_3)

# Two 2x2 'valid' convolutions with stride 1: 32 -> 31 -> 30 pixels per side.
discriminator_5 = Conv2D(128, 2, strides=(1,1),kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02), padding='valid')(dactivation_4)
dbatch_4 = BatchNormalization()(discriminator_5)
dactivation_5 = LeakyReLU(0.2)(dbatch_4)

# One-channel map of per-patch real/fake probabilities (sigmoid).
discriminator_6 = Conv2D(1, 2, strides=(1,1),kernel_initializer=tf.keras.initializers.truncated_normal(stddev=.02), padding='valid')(dactivation_5)
dactivation_6 = sigmoid(discriminator_6)

# Average the patch probabilities into one score per pair, shape (batch, 1).
average_layer = GlobalAveragePooling2D()(dactivation_6)

discriminator = Model(inputs = [inp1,inp2], outputs = average_layer)


# Compiled while still trainable, so discriminator.train_on_batch updates its weights.
discriminator.compile(loss='binary_crossentropy',
              optimizer=adam,
              metrics=['accuracy'])

discriminator.summary()

In [None]:
# ImageNet-pretrained VGG16 used only as a fixed feature extractor for the
# feature-level (perceptual) loss. The classifier head is never used, so
# include_top=False skips loading its large fully-connected weights.
# input_shape matches the 224x224 resize applied before the feature loss.
vgg = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# The extractor must never be updated; freeze each layer explicitly so this
# holds regardless of how the TF version propagates `trainable`.
for layer in vgg.layers:
  layer.trainable = False

# Activations of block2_conv2, the 4th convolution.
# NOTE(review): Keras' VGG16 already applies a ReLU inside block2_conv2, so the
# extra ReLU should be a no-op; it is kept to preserve the original graph.
vgg_net1 = Model(inputs=vgg.input, outputs=ReLU()(vgg.get_layer('block2_conv2').output))
# Target and generated images pass through the same extractor. The second name
# is kept because the GAN cell calls both vgg_net1 and vgg_net2.
vgg_net2 = vgg_net1

In [None]:
# Print the layer stack of the VGG16 feature extractor used by the feature-level loss.
vgg_net1.summary()

In [None]:
def totalVariation_loss(y, g):
  """
  Build a smoothness penalty on the generated image batch `g`.

  The returned closure evaluates

      sqrt( sum((g[:, i+1, j, :] - g[:, i, j, :])**2)
          + sum((g[:, i, j+1, :] - g[:, i, j, :])**2) )

  which is the L2 norm of all vertical and horizontal neighbour differences,
  summed over the whole batch and every channel. Axes 1 and 2 are treated as
  height and width, i.e. channels-last input is assumed. Penalising the term
  discourages high-frequency noise in the output.

  Parameters:
    y (Tensor): The target images (unused; kept so every loss factory has the same signature).
    g (Tensor): The output images of the generator.

  Returns:
    function: A Keras-style `(y_true, y_pred)` loss that ignores its arguments
      and evaluates the penalty on `g`.
  """
  import tensorflow.keras.backend as K

  def finalTVLoss(y_true, y_pred):
    vertical = K.sum(K.square(g[:, 1:, :, :] - g[:, :-1, :, :]))
    horizontal = K.sum(K.square(g[:, :, 1:, :] - g[:, :, :-1, :]))
    return K.abs(K.sqrt(vertical + horizontal))

  return finalTVLoss

In [None]:
def featureLevel_loss(y, g):
  """
  Build a feature-matching loss between two VGG16 feature tensors.

  `y` and `g` are the block2_conv2 activations (the 4th convolution of VGG16)
  of the target and generated images. The closure returns the Euclidean
  distance sqrt(sum((y - g)**2)), taken over the whole batch at once. The outer
  mean reduces a scalar, so it does not change the value.

  Parameters:
    y (Tensor): Features of the target colour images.
    g (Tensor): Features of the generated colour images.

  Returns:
    function: A Keras-style `(y_true, y_pred)` loss that ignores its arguments
      and evaluates the distance between `y` and `g`.
  """
  import tensorflow.keras.backend as K

  def finalFLoss(y_true, y_pred):
    squared_error = K.square(y - g)
    distance = K.sqrt(K.sum(squared_error))
    return K.mean(distance)

  return finalFLoss

In [None]:
def pixelLevel_loss(y, g):
  """
  Build a pixel-wise L1 reconstruction loss.

  The closure returns mean(|y - g|) over every element of the batch. This keeps
  the generated colours aligned with the target image pixel by pixel, so colour
  does not bleed across the sketch's edges.

  Parameters:
    y (Tensor): The real target colour images.
    g (Tensor): The output images of the generator.

  Returns:
    function: A Keras-style `(y_true, y_pred)` loss that ignores its arguments
      and evaluates the L1 distance between `y` and `g`.
  """
  import tensorflow.keras.backend as K

  def finalPLLoss(y_true, y_pred):
    absolute_error = K.abs(y - g)
    return K.mean(absolute_error)

  return finalPLLoss

In [None]:
# Combined model used to train the generator. `discriminator` was compiled in
# its own cell while still trainable, so discriminator.train_on_batch keeps
# updating it. Models compiled after this flag (i.e. `gan`) treat it as frozen.
discriminator.trainable = False

ganInput = Input(input_dim)  # sketch fed to the generator
x = generator([ganInput])    # generated colour image (tanh output)

# The discriminator judges (sketch, generated image) pairs.
ganOutput = discriminator([ganInput,x])

# Ground-truth colour image; only the auxiliary losses below read it.
color_inp = Input(input_dim)

pixelLevelLoss = pixelLevel_loss(color_inp, x) # pixel loss

totalVariationLoss = totalVariation_loss(color_inp, x) # total variation loss

# Feature loss: compare VGG16 block2_conv2 activations of target and output.
# NOTE(review): the ImageNet VGG16 weights expect input prepared by
# keras.applications.vgg16.preprocess_input, but these images are in [-1, 1].
# Adding that preprocessing would change the loss scale (re-tune the 0.01 weight).
net1_outp = vgg_net1([tf.image.resize(color_inp, (224, 224), tf.image.ResizeMethod.BILINEAR)])
net2_outp = vgg_net2([tf.image.resize(x, (224, 224), tf.image.ResizeMethod.BILINEAR)])

featureLevelLoss = featureLevel_loss(net1_outp,net2_outp)

gan = Model(inputs=[ganInput, color_inp], outputs=ganOutput)

# Give the combined model its own optimizer with the same hyperparameters as
# `adam`, instead of reusing the instance compiled into `discriminator`.
# Sharing one optimizer also shares its step counter and state across both
# models, and newer TF/Keras optimizers can error when one instance updates two
# different variable sets.
gan_optimizer = Adam.from_config(adam.get_config())

# Total loss: adversarial BCE + 100 * L1 + 1e-4 * TV + 0.01 * VGG feature distance.
# The auxiliary terms ignore y_true/y_pred and close over the tensors built above.
gan.compile(loss=lambda y_true, y_pred : tf.keras.losses.binary_crossentropy(y_true, y_pred) +
                                             100 * pixelLevelLoss(y_true, y_pred) +
                                             0.0001 * totalVariationLoss(y_true, y_pred) +
                                             0.01 * featureLevelLoss(y_true, y_pred), optimizer=gan_optimizer)

gan.summary()

In [None]:
def generate_and_plot(num_examples = 5):
  """
  Show the generator's current output on a few random training sketches.

  Exactly `num_examples` pairs are drawn with generate_samples; the old version
  loaded and ran 8 pairs but displayed only 5. The plot has three rows: input
  sketch, generated colourisation and ground-truth colour image.

  Parameters:
    num_examples(int): Number of random pairs to visualise (default 5).
  """
  real_images, sketches = generate_samples(sketch_paths, img_paths, num_examples)

  # Map generator output (tanh) and the [-1, 1] inputs back to [0, 1] for imshow.
  generated_images = (generator.predict(sketches) + 1) / 2.0
  rows = [('Sketch', (sketches + 1) / 2.0),
          ('Generated', generated_images),
          ('Target', (real_images + 1) / 2.0)]

  _, axes = plt.subplots(len(rows), num_examples,
                         figsize=(2 * num_examples, 6), squeeze=False)
  for r, (label, batch) in enumerate(rows):
    for c in range(num_examples):
      axes[r, c].imshow(batch[c])
      axes[r, c].axis('off')
    axes[r, 0].set_title(label, loc='left')

  plt.show()

In [None]:
def train(epochs = 1, batch_size = 8, save_every = 5):
  """
  Adversarially train the discriminator and the generator (through `gan`).

  Each epoch runs len(img_paths) // batch_size iterations. In iteration j:
    * every 2nd iteration the discriminator sees real (sketch, colour) pairs
      with soft real labels of 0.9;
    * every 3rd iteration it sees (sketch, generator output) pairs labelled 0;
    * every iteration the generator is updated through `gan` with target label 1.
  Every `save_every` epochs, and after the final epoch, the latest losses are
  printed, `gan` is saved to Google Drive as '<epoch>_model.h5', and a preview
  is plotted.

  Parameters:
    epochs(int): Number of epochs to run.
    batch_size(int): Number of pairs per batch. It is used for both the loaded
      samples and the label arrays, so their shapes always agree.
    save_every(int): Checkpoint/preview interval in epochs (default 5).
  """
  m = len(img_paths)  # number of training pairs (14224 for this dataset)
  batch_count = m // batch_size

  generate_and_plot()

  # Pre-bind so the progress print never hits an unbound name (e.g. batch_count == 0).
  d_loss_1 = d_loss_2 = gan_loss = None

  for e in range(epochs):
    print(f"Epoch: {e}")
    for j in tqdm(range(batch_count)):
      ## Train the discriminator on real pairs (every 2nd batch).
      if j % 2 == 0:
        X_real_imgs, X_real_skets = generate_samples(sketch_paths, img_paths, batch_size)
        y_real = np.full((batch_size, 1), 0.9)
        d_loss_1, _ = discriminator.train_on_batch([X_real_skets, X_real_imgs], y_real)

      ## Train the discriminator on generated pairs (every 3rd batch).
      if j % 3 == 0:
        _, X_fake_skets = generate_samples(sketch_paths, img_paths, batch_size)
        y_fake = np.zeros((batch_size, 1))
        generated_images = generator.predict(X_fake_skets)
        d_loss_2, _ = discriminator.train_on_batch([X_fake_skets, generated_images], y_fake)

      ## Train the generator through the combined model (every batch).
      X_gan_imgs, X_gan_skets = generate_samples(sketch_paths, img_paths, batch_size)
      y_gan = np.ones((batch_size, 1))
      gan_loss = gan.train_on_batch([X_gan_skets, X_gan_imgs], y_gan)

    # Always checkpoint the final epoch too; a plain `e % save_every` test would
    # drop the last epochs (e.g. 31-34 of train(35)).
    if e % save_every == 0 or e == epochs - 1:
      print(d_loss_1, d_loss_2, gan_loss)
      model_save_name = f'{e}_model.h5'
      path = f"/content/gdrive/My Drive/{model_save_name}"
      gan.save(path)
      generate_and_plot()

In [None]:
# Run 35 training epochs with the default batch size.
train(epochs=35)

In [None]:
# Re-point the *global* path arrays at the validation split.
# NOTE: generate_and_plot and train pass these globals to the sample loaders, so
# after this cell they operate on validation data. Re-run the training-paths
# cell before training again.
# Both arrays are sorted, so index i refers to the same pair in each.
img_paths = np.array(sorted(glob.glob('data_2/validation/Images/*.png')))
sketch_paths = np.array(sorted(glob.glob('data_2/validation/Sketch/*.png')))

# Fail loudly instead of silently evaluating mismatched pairs.
assert len(img_paths) > 0, 'no validation images found under data_2/validation/Images'
assert [os.path.basename(p) for p in img_paths] == [os.path.basename(p) for p in sketch_paths], \
    'data_2/validation/Images and data_2/validation/Sketch do not contain the same file names'

def generate_validate(sketch_paths, img_paths, n_samples):
  """
  Load a random batch of (colour image, sketch) validation pairs.

  Same contract as the training loader: indices are drawn uniformly, with
  replacement, over the whole path arrays, and both arrays are scaled to [-1, 1].

  Parameters:
    sketch_paths(numpy.array): Paths to the validation sketches (generator input).
    img_paths(numpy.array): Paths to the validation colour images, aligned
      with `sketch_paths`.
    n_samples(int): The number of pairs to load.

  Returns:
    X_images(numpy.array): float32 colour images scaled to [-1, 1].
    X_sketches(numpy.array): float32 sketches scaled to [-1, 1].

  Raises:
    ValueError: if the path arrays are empty or differ in length.
  """
  n_pairs = len(sketch_paths)
  if n_pairs == 0 or n_pairs != len(img_paths):
    raise ValueError(
        'sketch_paths and img_paths must be non-empty and of equal length '
        '(got {} and {})'.format(n_pairs, len(img_paths)))

  # Bound derived from the data rather than a hard-coded split size.
  idxs = np.random.randint(0, n_pairs, n_samples)
  X_sketches = []
  X_images = []

  for sket, img in zip(sketch_paths[idxs], img_paths[idxs]):
    X_sketches.append(np.array(Image.open(sket).convert('RGB')))
    X_images.append(np.array(Image.open(img).convert('RGB')))

  # Normalizing the values to be between [-1, 1].
  X_sketches = (np.array(X_sketches, dtype='float32')-127.5)/127.5
  X_images = (np.array(X_images, dtype='float32')-127.5)/127.5

  return X_images, X_sketches


In [None]:
def generate_and_plot_validate():
  """
  Plot 8 random validation pairs, one per row: the target colour image, the
  generator's colourisation of the sketch, and the input sketch itself.
  """
  targets, sketches = generate_validate(sketch_paths, img_paths, 8)

  # Predict from the [-1, 1] sketches first, then map everything to [0, 1] for imshow.
  outputs = (generator.predict(sketches) + 1) / 2.0
  targets = (targets + 1) / 2.0
  sketches = (sketches + 1) / 2.0

  _, grid = plt.subplots(8, 3, figsize=(12,60))
  for row, panels in enumerate(zip(targets, outputs, sketches)):
    for col, picture in enumerate(panels):
      grid[row, col].imshow(picture)
      grid[row, col].axis('off')
  plt.show()

In [None]:
# Visualise 8 random validation pairs (target, generated, sketch) with the trained generator.
generate_and_plot_validate()