In [None]:
from collections import defaultdict
from PIL import Image
import numpy as np
import os
import skimage.color as color
import PIL
from matplotlib import pyplot as plt

## Load List of Images

In [None]:
from utils.imgCap import load_images_list
from zipfile import ZipFile

# Extract the dataset archive on first use; an already-extracted folder is reused as-is.
if not os.path.exists('./Flickr_Data'):
    # The archive is only needed when the extracted folder is missing.
    if not os.path.exists('./Flickr_Data.zip'):
        raise Exception('Dataset not found. Please read instructions above this cell and download dataset.')

    print("Extracting data ...")
    # The context manager closes the archive handle even if extraction fails part-way.
    with ZipFile('./Flickr_Data.zip', 'r') as archive:
        archive.extractall('./')

# Text files listing the image file names that belong to each split.
train_image_list_path = './Flickr_Data/Flickr8k_text/Flickr_8k.trainImages.txt'
test_image_list_path = './Flickr_Data/Flickr8k_text/Flickr_8k.testImages.txt'

train_image_list = load_images_list(train_image_list_path)
test_image_list = load_images_list(test_image_list_path)

print('Total train images:',len(train_image_list))
print('Total test images:', len(test_image_list))

## Method to Convert Image from RGB to Lab

In [None]:
# Show one sample photo next to the first (L) channel of its Lab conversion.
image = Image.open('./Flickr_Data/Flickr8k_Dataset/2903617548_d3e38d7f88.jpg', mode='r')

f = plt.figure()
# Left panel: the photo as loaded.
f.add_subplot(1, 2, 1).imshow(image)
# Right panel: channel 0 of the Lab image, rendered in grayscale.
right_axes = f.add_subplot(1, 2, 2)
lab_image = color.rgb2lab(np.asarray(image))
right_axes.imshow(lab_image[:, :, 0], cmap="gray")
plt.show(block=True)

## Encode Images for the Inception-ResNet Embedding

In [None]:
# Resize the original images to 299x299 (the size later fed to the pretrained
# InceptionResNetV2 feature extractor), then convert them to Lab.
# "L" is the input for the later model and "ab" is the ground truth for model output.

def _load_lab_images(image_names, images_dir, count, side):
    """Load up to `count` images as `side` x `side` Lab arrays divided by 255.

    Parameters:
        image_names: iterable of image file names relative to `images_dir`.
        images_dir: directory that contains the image files.
        count: number of rows to allocate and the maximum number of images read.
        side: width and height every image is resized to.

    Returns:
        Array of shape (count, side, side, 3). Rows beyond the number of
        available names stay zero.
    """
    data = np.zeros([count, side, side, 3])
    # zip() with range() stops after `count` images or when the names run out.
    for row, image_name in zip(range(count), image_names):
        # The context manager releases the file handle once the pixels are read.
        with Image.open(images_dir + "/" + image_name, mode='r') as image:
            # convert('RGB') keeps grayscale/RGBA/CMYK files from breaking rgb2lab,
            # which needs a 3-channel RGB array.
            rgb = np.asarray(image.convert('RGB').resize((side, side)))
        data[row] = color.rgb2lab(rgb) / 255
    return data


images_path = './Flickr_Data/Flickr8k_Dataset'
train_data = _load_lab_images(train_image_list, images_path, 600, 299)
test_data = _load_lab_images(test_image_list, images_path, 100, 299)

In [None]:
# Sanity check on train sample 1: stored channel 0 (left) and the image
# recovered by lab2rgb after undoing the /255 scaling (right).
f = plt.figure()
f.add_subplot(1, 2, 1).imshow(train_data[1, :, :, 0], cmap="gray")
f.add_subplot(1, 2, 2).imshow(color.lab2rgb(train_data[1] * 255))
plt.show(block=True)

In [None]:
# Sanity check on test sample 1: stored channel 0 (left) and the image
# recovered by lab2rgb after undoing the /255 scaling (right).
f = plt.figure()
f.add_subplot(1, 2, 1).imshow(test_data[1, :, :, 0], cmap="gray")
f.add_subplot(1, 2, 2).imshow(color.lab2rgb(test_data[1] * 255))
plt.show(block=True)

## Inception-ResNet Embedding and Stacking

In [None]:
from tensorflow.keras.applications import InceptionResNetV2
from tensorflow.keras import Model
# Load InceptionResNetV2 with ImageNet weights and expose the output of its
# second-to-last layer, so predict() returns that layer's activations
# (presumably the pooled features just before the classifier - TODO confirm).
pre_trained_model = InceptionResNetV2(weights='imagenet')
penultimate_output = pre_trained_model.layers[-2].output
feature_extractor = Model(
    inputs=pre_trained_model.input,
    outputs=penultimate_output,
)

In [None]:
def _lightness_as_inception_input(lab_scaled):
    """Turn the stored L channel into a 3-channel float32 batch scaled to [-1, 1].

    `lab_scaled` holds Lab images divided by 255, so channel 0 times 255 is the
    lightness L (nominally 0..100). The ImageNet-pretrained InceptionResNetV2
    expects pixel values scaled to [-1, 1]; feeding the raw L/255 values
    (roughly 0..0.39) gives it a dark, low-contrast input it was not trained on.

    Parameters:
        lab_scaled: array of shape (N, H, W, 3) with Lab values divided by 255.

    Returns:
        float32 array of shape (N, H, W, 3) whose three channels all hold the
        lightness mapped linearly from [0, 100] to [-1, 1].
    """
    lightness = lab_scaled[:, :, :, 0] * 255.0
    batch = np.empty(lightness.shape + (3,), dtype='float32')
    # Broadcasting copies the single lightness plane into all three channels.
    batch[...] = (lightness / 50.0 - 1.0)[..., np.newaxis]
    return batch


# Grayscale "RGB" batches for the feature extractor (train: 600 images, test: 100).
train_embeddings = _lightness_as_inception_input(train_data)
test_embeddings = _lightness_as_inception_input(test_data)

# One feature vector per image from the truncated pretrained network.
train_emb = feature_extractor.predict(train_embeddings)
test_emb = feature_extractor.predict(test_embeddings)

In [None]:
# Make sure the cache directory is present before writing into it.
if os.path.exists('./encoded_images') is False:
    os.mkdir('./encoded_images')

# Write the train embeddings first, then the test embeddings, as .npy files.
for npy_path, embedding in (
    ('./encoded_images/train_emb.npy', train_emb),
    ('./encoded_images/test_emb.npy', test_emb),
):
    np.save(npy_path, embedding)

In [None]:
# Read the cached embeddings back and give each 1536-value vector a 1x1
# spatial footprint: (N, 1, 1, 1536).
train_emb = np.load('./encoded_images/train_emb.npy').reshape([600, 1, 1, 1536])
test_emb = np.load('./encoded_images/test_emb.npy').reshape([100, 1, 1, 1536])

In [None]:
# Replicate each image's embedding over the whole 28x28 grid so it can be
# concatenated with the encoder's 28x28 feature map at every spatial position.
# Writing only to [i, j, j, :] filled just the diagonal and left every other
# position zero.
train_embeddings = np.empty([train_emb.shape[0], 28, 28, 1536], dtype='float32')
# Broadcasting expands (N, 1, 1, 1536) to (N, 28, 28, 1536).
train_embeddings[...] = train_emb

train_embeddings.shape

In [None]:
# Replicate each image's embedding over the whole 28x28 grid so it can be
# concatenated with the encoder's 28x28 feature map at every spatial position.
# Writing only to [i, j, j, :] filled just the diagonal and left every other
# position zero.
test_embeddings = np.empty([test_emb.shape[0], 28, 28, 1536], dtype='float32')
# Broadcasting expands (N, 1, 1, 1536) to (N, 28, 28, 1536).
test_embeddings[...] = test_emb

test_embeddings.shape

## Prepare Data for Model

In [None]:
# Rebuild the Lab arrays at 224x224, the resolution of the colorization model's
# grayscale input. This rebinds train_data/test_data, replacing the 299x299
# arrays used for the embeddings.
train_data = np.zeros([600, 224, 224, 3])
test_data = np.zeros([100, 224, 224, 3])

images_path = './Flickr_Data/Flickr8k_Dataset'
# One pass per split; each array is filled in place, row by row.
for lab_array, image_names in ((train_data, train_image_list), (test_data, test_image_list)):
    # zip() with range() stops once the array is full or the names run out;
    # rows that are never written stay zero.
    for i, image_name in zip(range(lab_array.shape[0]), image_names):
        path = images_path + "/" + image_name
        # The context manager releases the file handle once the pixels are read.
        with Image.open(path, mode='r') as image:
            # convert('RGB') keeps grayscale/RGBA/CMYK files from breaking rgb2lab,
            # which needs a 3-channel RGB array.
            resized = image.convert('RGB').resize((224, 224))
        # Store Lab values divided by 255, the same scaling the rest of the notebook uses.
        lab_array[i] = color.rgb2lab(np.asarray(resized)) / 255

In [None]:
# Cache the 224x224 Lab arrays on disk: train split first, then test split.
for npy_path, lab_array in (
    ('./encoded_images/train_data.npy', train_data),
    ('./encoded_images/test_data.npy', test_data),
):
    np.save(npy_path, lab_array)

## Define Model

In [None]:
import datetime
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, AveragePooling2D, MaxPooling2D, Concatenate,UpSampling2D
from tensorflow.keras import Model
from tensorflow.keras.applications import InceptionResNetV2
from tensorflow.keras.layers import Dense, Dropout, Input, LSTM, Embedding, Add, Bidirectional, Concatenate, RepeatVector, GRU


# --- Encoder: 224x224x1 grayscale input reduced to a 28x28x256 feature map
# by three stride-2 convolutions.
start = Input(shape=(224,224,1))
encoder = Conv2D(64, (3, 3), activation='relu', padding='same', strides=(2,2))(start)
encoder = Conv2D(128, (3, 3), activation='relu', padding='same', strides=(1,1))(encoder)
encoder = Conv2D(128, (3, 3), activation='relu', padding='same', strides=(2,2))(encoder)
encoder = Conv2D(256, (3, 3), activation='relu', padding='same', strides=(1,1))(encoder)
encoder = Conv2D(256, (3, 3), activation='relu', padding='same', strides=(2,2))(encoder)
encoder = Conv2D(512, (3, 3), activation='relu', padding='same', strides=(1,1))(encoder)
encoder = Conv2D(512, (3, 3), activation='relu', padding='same', strides=(1,1))(encoder)
encoder = Conv2D(256, (3, 3), activation='relu', padding='same', strides=(1,1))(encoder)

# --- Fusion: second model input carrying the stacked 28x28x1536 image embedding.
# It has its own name so it does not overwrite the `feature_extractor` Keras
# model, which must stay callable for computing embeddings.
embedding_input = Input(shape=(28,28,1536))
fusion = Concatenate()([encoder,embedding_input])
# 1x1 convolution mixes encoder features with the embedding at every position.
fusion = Conv2D(256, (1, 1), activation='relu', padding='same', strides=1)(fusion)

# --- Decoder: three 2x upsamplings bring 28x28 back to 224x224; the 2-channel
# tanh layer produces the predicted colour channels.
decoder = Conv2D(128, (3, 3), activation="relu", padding="same")(fusion)
decoder = UpSampling2D((2, 2))(decoder)
decoder = Conv2D(64, (3, 3), activation="relu", padding="same")(decoder)
decoder = Conv2D(64, (3, 3), activation="relu", padding="same")(decoder)
decoder = UpSampling2D((2, 2))(decoder)
decoder = Conv2D(32, (3, 3), activation="relu", padding="same")(decoder)
decoder = Conv2D(2, (3, 3), activation="tanh", padding="same")(decoder)
decoder = UpSampling2D((2, 2))(decoder)

# Inputs: [grayscale image, stacked embedding]; output: 224x224x2 colour channels.
deep_color = Model([start,embedding_input],decoder)

# NOTE(review): 'accuracy' is not very informative for an MSE regression target;
# it is kept so the reported metrics stay the same.
deep_color.compile(optimizer='Adam',loss='mse',metrics=['accuracy'])

deep_color.summary()

## Train Model

In [None]:
# Restore the cached 224x224 Lab training array from disk.
train_data = np.load(file='./encoded_images/train_data.npy')

In [None]:
# Channel 0 (lightness) is the model's image input; reshape restores the
# trailing channel axis that indexing removes.
lightness_plane = train_data[..., 0]
train_input = lightness_plane.reshape((600, 224, 224, 1))
train_input.shape

In [None]:
# Display the shape of the stacked embedding array that is passed to fit() as the second input.
train_embeddings.shape

In [None]:
# Channels 1 and 2 (the two colour channels) are the regression target.
train_output = train_data[..., 1:]
train_output.shape

In [None]:
# Train on (grayscale image, stacked embedding) -> colour channels
# for 20 epochs in batches of 60 images.
deep_color.fit(
    x=[train_input, train_embeddings],
    y=train_output,
    batch_size=60,
    epochs=20,
)