In [None]:
from os import listdir, makedirs
from os.path import join, isfile, isdir, splitext
from PIL import Image
import skimage.color as color
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf

In [None]:
from utils.imgCap import load_images_list
from zipfile import ZipFile

# Extract Data
# Only `isfile`/`isdir` (from os.path) are imported at the top of the notebook;
# a bare `os` module name is not, so `os.path.exists` would raise NameError here.
DATASET_ZIP = './Flickr_Data.zip'
DATASET_DIR = './Flickr_Data'

if not isfile(DATASET_ZIP):
    # FileNotFoundError subclasses Exception, so existing `except Exception` still catches it.
    raise FileNotFoundError('Dataset not found. Please read instructions above this cell and download dataset.')

if not isdir(DATASET_DIR):
    print("Extracting data ...")
    # The context manager closes the archive handle once extraction finishes.
    with ZipFile(DATASET_ZIP, 'r') as archive:
        archive.extractall('./')

# Files with names of corresponding images
train_image_list_path = './Flickr_Data/Flickr8k_text/Flickr_8k.trainImages.txt'
test_image_list_path = './Flickr_Data/Flickr8k_text/Flickr_8k.testImages.txt'

train_image_list = load_images_list(train_image_list_path)
test_image_list = load_images_list(test_image_list_path)

print('Total train images:',len(train_image_list))
print('Total test images:', len(test_image_list))

In [None]:
# Sample image (by filename) and the directory of source JPEGs it lives in.
filename = '667626_18933d713e.jpg'
images_path = './Flickr_Data/Flickr8k_Dataset'

In [None]:
# Preview the sample image as loaded (left) and after resizing to 299x299 (right).
img = Image.open(join(images_path, filename))
f = plt.figure()
stages = [img, img.resize((299, 299))]
for position, picture in enumerate(stages, start=1):
    f.add_subplot(1, 2, position)
    plt.imshow(picture)
# Leave `img` bound to the resized copy, as before.
img = stages[-1]
plt.show()

# Resize all Images to 299x299

In [None]:
from utils.resize import ImageResizer

# Resize every image in images_path into resized_dir at the target size.
size = (299, 299)
images_path = './Flickr_Data/Flickr8k_Dataset'
resized_dir = './Flickr_Data/resized'

# NOTE(review): `resizer` holds whatever resize_all() returns, not necessarily
# the ImageResizer instance itself — confirm in utils.resize.
resizer = ImageResizer(
    source_dir=images_path,
    dest_dir=resized_dir,
).resize_all(size=size)

In [None]:
# Text files listing the image filenames in each Flickr8k split.
train_list, dev_list, test_list = (
    './Flickr_Data/Flickr8k_text/Flickr_8k.trainImages.txt',
    './Flickr_Data/Flickr8k_text/Flickr_8k.devImages.txt',
    './Flickr_Data/Flickr8k_text/Flickr_8k.testImages.txt',
)

# Convert to Inception-ResNet-v2 TFRecords

In [None]:
# Show which physical devices TensorFlow can see. This uses the public tf.config
# API; `tensorflow.python.*` is TensorFlow's private, unsupported namespace.
print(tf.config.list_physical_devices())

In [None]:
from utils.prep import prep_for_inception, RGB_to_lab

# Sanity-check the preprocessing helpers on one resized training image.
train1 = join(resized_dir, '2903617548_d3e38d7f88.jpg')
img1 = Image.open(train1)
img1_arr = np.asarray(img1)

incep_img = prep_for_inception(img1_arr)
l_img, ab_img = RGB_to_lab(img1_arr)

# NOTE(review): imshow auto-scales 2-D arrays to their own min/max, so the x127
# factor below does not change how the a/b panels look; presumably ab_img is
# normalized to about [-1, 1] by RGB_to_lab — confirm in utils.prep.
panels = [
    (img1, 'original', {}),
    (incep_img[0], 'prep_for_inception', {}),
    (l_img[0], 'L', {'cmap': 'gray'}),
    (ab_img[0, :, :, 0] * 127, 'a (x127)', {}),
    (ab_img[0, :, :, 1] * 127, 'b (x127)', {}),
]

# A single row of 5 panels: a wide, short figure avoids large blank margins.
f = plt.figure(figsize=(15, 3))
for position, (data, title, kwargs) in enumerate(panels, start=1):
    ax = f.add_subplot(1, 5, position)
    ax.imshow(data, **kwargs)
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
from utils.tfrecord_writer import tfrecordwriter

# Dev split -> TFRecord. tfrecordwriter receives the resized-image directory, the
# split's image-list file, and an output directory + filename (presumably it
# writes record_path/file_name — confirm in utils.tfrecord_writer).
resized_dir = './Flickr_Data/resized'
img_list = dev_list = './Flickr_Data/Flickr8k_text/Flickr_8k.devImages.txt'
record_path = './tfrecords/'
file_name = 'dev.tfrecords'
tfrecordwriter(resized_dir, img_list, record_path, file_name)

In [None]:
from utils.tfrecord_writer import tfrecordwriter

# Test split -> TFRecord (same call shape as the dev cell, different list/filename).
resized_dir = './Flickr_Data/resized'
img_list = test_list = './Flickr_Data/Flickr8k_text/Flickr_8k.testImages.txt'
record_path = './tfrecords/'
file_name = 'test.tfrecords'
tfrecordwriter(resized_dir, img_list, record_path, file_name)

In [None]:
from utils.tfrecord_writer import tfrecordwriter

# Train split -> TFRecord (same call shape as the dev cell, different list/filename).
resized_dir = './Flickr_Data/resized'
img_list = train_list = './Flickr_Data/Flickr8k_text/Flickr_8k.trainImages.txt'
record_path = './tfrecords/'
file_name = 'train.tfrecords'
tfrecordwriter(resized_dir, img_list, record_path, file_name)

# Read from TFRecords

In [None]:
from utils.tfrecord_reader import batch_reader

# Generator over train.tfrecords used only for the preview below.
batch_size = 1
record_file = 'train.tfrecords'
record_path = './tfrecords/'
train_batch_generator = batch_reader(batch_size, record_path, record_file)

In [None]:
# Pull one batch and show the first image of its first input element in grayscale
# (presumably the L channel, matching how test_l is unpacked later).
train_batch = next(train_batch_generator)
train_inputs = train_batch[0]
plt.imshow(train_inputs[0][0], cmap='gray')

# Train model

In [None]:
from utils.tfrecord_reader import batch_reader

# Fresh generator over train.tfrecords for fitting; train_batch_generator has
# already been advanced by the preview cell.
batch_size = 1
record_file = 'train.tfrecords'
record_path = './tfrecords/'
generator = batch_reader(batch_size, record_path, record_file)

In [None]:
from utils.network import deep_color

# Training schedule. Examples seen = EPOCHS * STEPS_PER_EPOCH * batch_size.
EPOCHS = 10
STEPS_PER_EPOCH = 30

# Keep fit()'s return value (a History object, assuming deep_color is a Keras
# model) for later inspection instead of dumping its repr as cell output.
history = deep_color.fit(generator, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH)

# Get a test example

In [None]:
# Import explicitly, like the other reader cells, so this cell does not silently
# depend on an earlier cell having been run.
from utils.tfrecord_reader import batch_reader

# Generator over the held-out test TFRecord file.
record_path = './tfrecords/'
test_file = 'test.tfrecords'
batch_size = 1
test_generator = batch_reader(batch_size, record_path, test_file)

In [None]:
# One test batch: element 0 is the (L, embedding) model input, element 1 the true ab.
test_img = next(test_generator)
(test_l, test_emb), test_ab = test_img[0], test_img[1]

In [None]:
# Predict ab for the test input pair (test_l, test_emb).
test_inputs = test_img[0]
test_pred = deep_color.predict(test_inputs)

In [None]:
# Rebuild full Lab images by stacking L with the predicted / ground-truth ab
# along the last (channel) axis, then convert to RGB for display.
# np.concatenate keeps everything in NumPy, which is what skimage operates on,
# and promotes mixed dtypes, whereas tf.concat raises on them.
# NOTE(review): skimage.color.lab2rgb expects L in [0, 100] and a/b roughly in
# [-128, 127]. The ab channels appear to be normalized (the preprocessing preview
# multiplies them by 127), so a rescale may be needed here — confirm against
# utils.prep.RGB_to_lab.
test_result = color.lab2rgb(np.concatenate([test_l, test_pred], axis=-1))
test_truth = color.lab2rgb(np.concatenate([test_l, test_ab], axis=-1))

In [None]:
# Side by side: grayscale input, model colorization, ground-truth colors.
panels = [
    (test_l[0], 'Input (L)', {'cmap': 'gray'}),
    (test_result[0], 'Prediction', {}),
    (test_truth[0], 'Ground truth', {}),
]

f = plt.figure()
for position, (image, title, kwargs) in enumerate(panels, start=1):
    ax = f.add_subplot(1, 3, position)
    ax.imshow(image, **kwargs)
    ax.set_title(title)
    ax.axis('off')
plt.show()