In [None]:
import tensorflow as tf

"""# Input Image Dimension"""

# Change the dimension below if you want to change the input dimension
# if you are changing the dimension than be carefull to tweak the model
# architecture in model.py

HEIGHT = 256
WIDTH  = 256

"""# Model"""

# Don't Change kernel size if you are not sure about what it is
# basically this is the Size of Convolution kernels
# In the original implementation of 'Let There be Colour' kernel size was kept
# as (3,3) so I have left it as it was in original

KERNEL_SIZE = (3,3) 

# I tried using multiple kinds of activation functions but found that sigmoid
# turns out to be a better choice initially I started with tanh activation than
# noticed that output images are becoming more red also since tanh has a range
# of -1 to 1 it is not good for final output layer if your input images are 
# scaled between 0 to 1
  
ACTIVATION_FUNCTION = 'sigmoid'

# Loss Functions that I tried were MSE , MASE, and MAE. I found that MSE is 
# better that or equal to all other in results So I have kept it below although
# MASE will also produce similar results 

LOSS_FUNCTION = tf.keras.losses.MeanSquaredError()

LEARNING_RATE = 1e-3

"""# Paths"""

TRAIN_DIR_PATH   = "/kaggle/input/ukraine-images-2023"
TEST_DIR_PATH    = ""
VAL_DIR_PATH     = ""
SAVE_MODEL_PATH  = ""
SAVE_OUTPUT_PATH = ""
LOAD_MODEL_PATH  = ""
SAVE_CSV_PATH    = ""
LOAD_CSV_PATH    = ""

"""# other """

# NUMBER_OF_TRAINING_EXAMPLES = 500
NUMBER_OF_TRAINING_EXAMPLES = 20
NUMBER_OF_TEST_EXAMPLES     = 40
NUMBER_OF_VAL_EXAMPLES      = 50
# NUMBER_OF_EPOCHS            = 100
NUMBER_OF_EPOCHS            = 1
STEPS_PER_EPOCHS            = 100
VALIDATION_STEPS            = 20
BATCH_SIZE                  = 50

In [26]:

import os
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img


def loadImagesToArray(dir_path, num_of_img=-1, search_inside =False):
  """
  dir_path : path of directory from which images will be imported
  num_of_imgs (Integer): number of images to be imported from the directory if not 
              given than all images will be imported 
  search_inside (boolean, default : False) : If true all images inside that directory
              along with the images in subdirectory will be added to output array
  """
  images = []
  count = -1
  if search_inside==False:
      for filename in os.listdir(dir_path):
          count+=1
          if(count==num_of_img):
              break
          images.append(img_to_array(load_img(dir_path+os.sep+filename).resize((256, 256), resample=3)))
  if search_inside==True:
      for root,dirs,files in os.walk(dir_path):
        for filename in files:
            count+=1
            if(count==num_of_img):
                break
            images.append(img_to_array(load_img(root+os.sep+filename).resize((256, 256), resample=3)))
  # print([img.shape for img in images])
  return np.array(images,dtype=float)/255.0

def DataGenerator():
    DataGen = ImageDataGenerator(        
        shear_range=0.2,
        zoom_range=0.2,
        rotation_range=20,
        horizontal_flip=True)
    return DataGen

def RGB2GRAY(img,add_channel_dim=False):
  conv_matrix = np.array([0.212671 ,0.715160,0.072169])
  gray_img = img @ conv_matrix
  if add_channel_dim==True:
    return gray_img.reshape(np.array([*list(gray_img.shape),1]))
  else:
    return gray_img

def RGB2ab(img,use_skimage=True):
  """
  Refrences
  * https://en.wikipedia.org/wiki/Lab_color_space
  * https://github.com/scikit-image/scikit-image/blob/main/skimage/color/colorconv.py#L990-L1050
  """
  if use_skimage==False:
    def finv(cie):
      cond = cie > 0.008856
      cie[cond] = np.cbrt(cie[cond])
      cie[~cond] = 7.787 * cie[~cond] + 16. / 116.
      return cie     

    conv_matrix =np.array( [[0.412453, 0.357580, 0.180423],
            [0.212671, 0.715160, 0.072169],
            [0.019334, 0.119193, 0.950227]])
    CIE = np.matmul(img,conv_matrix.T)
    CIE[0] = CIE[0]/0.95047
    CIE[2] = CIE[2]/1.08883
    CIE = finv(CIE)
    x, y, z = CIE[..., 0], CIE[..., 1], CIE[..., 2]
    a =  (500*(x-y)+127)/255.0
    b =  (200*(y-z)+127)/255.0
    return np.concatenate([x[..., np.newaxis] for x in [a, b]], axis=-1)
  else:
    Lab = rgb2lab(img)
    a = (Lab[...,1]+127)/255.0
    b = (Lab[...,2]+127)/255.0
    return np.concatenate([x[..., np.newaxis] for x in [a, b]], axis=-1)

def Lab2RGB(gray,ab):
  """
    Parameters
    ----------
    gray : nd array
        lumminnance component of a image.
    ab : TYPE
        a and b componenets of a CIE L*a*b image.

    Returns
    -------
    ndarray with R G B components.
  """
  ab = ab*255.0 -127
  gray = gray*100
  Lab =np.concatenate([x[..., np.newaxis] for x in [gray[...,0], ab[...,0],ab[...,1]]], axis=-1)
  return lab2rgb(Lab)

def compare_results(img_gt,img_in,img_out,save_results=False,save_as=""):
  """
    Parameters
    ----------
    img_gt : ndarray with RGB components 
        Original Required image model is expected to produce this as ouput.
    img_in : grayscaled ndarray
        image used as input to the model.
    img_out : nd array with RGB componets.
        The ouput from the model.
    save_results : boolean, optional
        If True matplotlib.plt will be used to save model. The default is False.
    save_as : String, optional
        Output file name along with path. The default is "".

    Returns
    -------
    None.

  """
  fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
  ax1.imshow(img_gt)
  ax1.set_title('Ground Truth')
  ax2.imshow(img_in,cmap='gray')
  ax2.set_title('Input')
  ax3.imshow(img_out)
  ax3.set_title('Output')
  axes = [ax1,ax2,ax3]
  for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
  plt.show()
  if save_results==True:
    path = save_as+'.svg'
    fig.savefig(path,dpi=300)

def BatchGenerator(data,imgDataGen,batch_size=64):
  for batch in imgDataGen.flow(data, batch_size=batch_size):
    yield RGB2GRAY(batch,True), RGB2ab(batch)


    

In [27]:


import tensorflow as tf
# import config
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, RepeatVector,Reshape,Dense, Flatten, Input, Concatenate


def build_model(ks=(3,3),act='sigmoid',learning_rate=1e-2):
    
  # Input Layer
  input_lvl = Input(shape = (HEIGHT,WIDTH,1))
  
  # Initial Shared Network of Low - Level Features
  low_lvl = Conv2D(64 ,kernel_size=ks,strides=(2,2),activation=act,padding='SAME')(input_lvl)
  low_lvl = layers.BatchNormalization()(low_lvl)
  low_lvl = Conv2D(128,kernel_size=ks,strides=(1,1),activation=act,padding='SAME')(low_lvl) 
  low_lvl = layers.BatchNormalization()(low_lvl)
  low_lvl = Conv2D(128,kernel_size=ks,strides=(2,2),activation=act,padding='SAME')(low_lvl) 
  low_lvl = layers.BatchNormalization()(low_lvl)
  low_lvl = Conv2D(256,kernel_size=ks,strides=(1,1),activation=act,padding='SAME')(low_lvl) 
  low_lvl = layers.BatchNormalization()(low_lvl)
  low_lvl = Conv2D(256,kernel_size=ks,strides=(2,2),activation=act,padding='SAME')(low_lvl)
  low_lvl = layers.BatchNormalization()(low_lvl)
  low_lvl = Conv2D(512,kernel_size=ks,strides=(1,1),activation=act,padding='SAME')(low_lvl)
  low_lvl = layers.BatchNormalization()(low_lvl)

  # Path one for  Mid-Level Features Network
  mid_lvl = Conv2D(512,kernel_size=ks,strides=(1,1),activation=act,padding='SAME')(low_lvl)
  mid_lvl = layers.BatchNormalization()(mid_lvl)
  mid_lvl = Conv2D(256,kernel_size=ks,strides=(1,1),activation=act,padding='SAME')(mid_lvl)
  mid_lvl = layers.BatchNormalization()(mid_lvl)

  # Path two for Global Features Network
  global_lvl = Conv2D(512,kernel_size=ks,strides=(2,2),activation=act,padding='SAME')(low_lvl)
  global_lvl = layers.BatchNormalization()(global_lvl)
  global_lvl = Conv2D(512,kernel_size=ks,strides=(1,1),activation=act,padding='SAME')(global_lvl)
  global_lvl = layers.BatchNormalization()(global_lvl)
  global_lvl = Conv2D(512,kernel_size=ks,strides=(2,2),activation=act,padding='SAME')(global_lvl)
  global_lvl = layers.BatchNormalization()(global_lvl)
  global_lvl = Conv2D(512,kernel_size=ks,strides=(1,1),activation=act,padding='SAME')(global_lvl)
  global_lvl = layers.BatchNormalization()(global_lvl)
  global_lvl = Flatten()(global_lvl) 
  global_lvl = Dense(1024,activation=act)(global_lvl)
  global_lvl = Dense(512 ,activation=act)(global_lvl)
  global_lvl = Dense(256 ,activation=act)(global_lvl)
  

  # Fusing the output of above two paths
  fusion_lvl = RepeatVector(mid_lvl.shape[1] * mid_lvl.shape[1])(global_lvl) 
  fusion_lvl = Reshape(([mid_lvl.shape[1],mid_lvl.shape[1]  , 256]))(fusion_lvl)
  fusion_lvl = Concatenate( axis=3)([mid_lvl, fusion_lvl]) 
  fusion_lvl = Conv2D(256, kernel_size=ks,strides =(1, 1), activation=act,padding='SAME')(fusion_lvl)

  # Colorization Network
  # Instead of UpSampling Layers I am using 2D Convolutional Transpose or deconv for upscaling the images 
  color_lvl = Conv2DTranspose(128,kernel_size = ks,strides = (1,1),padding='SAME',activation=act)(fusion_lvl)
  color_lvl = layers.BatchNormalization()(color_lvl)
  color_lvl = Conv2DTranspose(64,kernel_size = ks,strides = (2,2),padding='SAME',activation=act)(color_lvl)
  color_lvl = layers.BatchNormalization()(color_lvl)
  color_lvl = Conv2DTranspose(64,kernel_size = ks,strides = (1,1),padding='SAME',activation=act)(color_lvl)
  color_lvl = layers.BatchNormalization()(color_lvl)
  color_lvl = Conv2DTranspose(32,kernel_size = ks,strides = (2,2),padding='SAME',activation=act)(color_lvl)
  color_lvl = layers.BatchNormalization()(color_lvl)
  
  # I added the below mentioned two lines when I trained the model for 100 X 100 sized images
  # Ignore if you are using 256 X 256
  # color_lvl = Conv2D(32,kernel_size = ks,strides = (1,1),padding='VALID',activation=act)(color_lvl)
  # color_lvl = layers.BatchNormalization()(color_lvl)

  # Output Layer
  output_lvl = Conv2DTranspose(2,kernel_size=ks,strides=(2,2),padding='SAME',activation='sigmoid')(color_lvl)


  # Model Parameters
  model = Model(inputs = input_lvl, outputs = output_lvl)
  optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
  model.compile(
      loss = LOSS_FUNCTION,
      optimizer = optimizer,
      metrics = ['accuracy',
          tf.keras.metrics.CosineSimilarity()
          ])
  return model

In [28]:

import os
# import config
# import tools
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
# import modelArchitecture as net


train_data = loadImagesToArray(
    dir_path   = TRAIN_DIR_PATH,
    num_of_img = NUMBER_OF_TRAINING_EXAMPLES,
    search_inside = True)

val_data = loadImagesToArray(
    dir_path   = VAL_DIR_PATH,
    num_of_img = NUMBER_OF_VAL_EXAMPLES,
    search_inside = True)


imgDataGen = DataGenerator()
valimgDataGen = DataGenerator()

"""# Training """
colorize_model = build_model(
    ks = KERNEL_SIZE,
    act=ACTIVATION_FUNCTION,
    learning_rate=LEARNING_RATE)

# Load Previously trained model
try:
    colorize_model.load_weights(LOAD_MODEL_PATH)
except:
    pass

log = tf.keras.callbacks.CSVLogger(SAVE_CSV_PATH,append=True, separator=',')
callbacks = [log]

print(train_data.shape)

history = colorize_model.fit(
    BatchGenerator(train_data, imgDataGen,BATCH_SIZE),
    validation_data = BatchGenerator(val_data, valimgDataGen),
    validation_steps=VALIDATION_STEPS,
    steps_per_epoch =STEPS_PER_EPOCHS,
    epochs=NUMBER_OF_EPOCHS,
    callbacks=callbacks)


#Plotting and saving the history
pd.DataFrame(history.history).plot(figsize=(8,5))
plt.show()


colorize_model.save(SAVE_MODEL_PATH)

"""# Testing """
test_images = loadImagesToArray(
    dir_path = TEST_DIR_PATH,
    num_of_img = NUMBER_OF_TEST_EXAMPLES,
    search_inside=True)

gray = RGB2GRAY(test_images,True)
gray2 = RGB2GRAY(test_images)
pred = colorize_model.predict(gray)

for i in range(NUMBER_OF_TEST_EXAMPLES):
    output = Lab2RGB(gray[i],pred[i])
    path =  SAVE_OUTPUT_PATH+os.sep+"img_"+str(i)
    compare_results(test_images[i],gray2[i],output.reshape(test_images[i].shape),save_results=True,save_as=path)

(20, 256, 256, 3)

ValueError: ('Input data in `NumpyArrayIterator` should have rank 4. You passed an array with shape', (0,))